#include "MCTargetDesc/SPIRVInstPrinter.h"
#include "SPIRV.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVMCInstLower.h"
#include "SPIRVModuleAnalysis.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "TargetInfo/SPIRVTargetInfo.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
#define DEBUG_TYPE "asm-printer"
namespace {
/// AsmPrinter for the SPIR-V target.
///
/// Besides the usual per-function hooks, this printer writes a set of
/// module-level sections (see outputModuleSections()) exactly once, guarded by
/// ModuleSectionsEmitted.
class SPIRVAsmPrinter : public AsmPrinter {
public:
  /// All cached pointers and the emission flag start out in a well-defined
  /// state so that no member is ever read uninitialized, even if a hook is
  /// invoked before doInitialization() / outputModuleSections().
  explicit SPIRVAsmPrinter(TargetMachine &TM,
                           std::unique_ptr<MCStreamer> Streamer)
      : AsmPrinter(TM, std::move(Streamer)), ModuleSectionsEmitted(false),
        ST(nullptr), TII(nullptr), MAI(nullptr) {}

  /// True once outputModuleSections() has been run for the current module;
  /// reset to false in doInitialization().
  bool ModuleSectionsEmitted;
  /// Cached subtarget; assigned in outputModuleSections() and refreshed in
  /// emitFunctionHeader().
  const SPIRVSubtarget *ST;
  /// Cached instruction info obtained from ST.
  const SPIRVInstrInfo *TII;

  StringRef getPassName() const override { return "SPIRV Assembly Printer"; }

  /// Inline-asm operand printing helpers.
  void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O);
  bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
                       const char *ExtraCode, raw_ostream &O) override;

  /// Low-level emission helpers: stream an MCInst, or lower a MachineInstr
  /// and stream the result.
  void outputMCInst(MCInst &Inst);
  void outputInstruction(const MachineInstr *MI);

  /// Emitters for the individual module-level sections; they are driven by
  /// outputModuleSections().
  void outputModuleSection(SPIRV::ModuleSectionType MSType);
  void outputEntryPoints();
  void outputDebugSourceAndStrings(const Module &M);
  void outputOpExtInstImports(const Module &M);
  void outputOpMemoryModel();
  void outputOpFunctionEnd();
  void outputExtFuncDecls();
  void outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
                                     SPIRV::ExecutionMode EM);
  void outputExecutionMode(const Module &M);
  void outputAnnotations(const Module &M);
  void outputModuleSections();

  /// AsmPrinter hooks. Several of them are intentionally empty overrides.
  void emitInstruction(const MachineInstr *MI) override;
  void emitFunctionEntryLabel() override {}
  void emitFunctionHeader() override;
  void emitFunctionBodyStart() override {}
  void emitFunctionBodyEnd() override;
  void emitBasicBlockStart(const MachineBasicBlock &MBB) override;
  void emitBasicBlockEnd(const MachineBasicBlock &MBB) override {}
  void emitGlobalVariable(const GlobalVariable *GV) override {}
  void emitOpLabel(const MachineBasicBlock &MBB);
  void emitEndOfAsmFile(Module &M) override;
  bool doInitialization(Module &M) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override;

  /// Module analysis results; set to &SPIRVModuleAnalysis::MAI by
  /// outputModuleSections() and null before that.
  SPIRV::ModuleAnalysisInfo *MAI;
};
} // namespace
/// Declares the analyses this printer uses: SPIRVModuleAnalysis is both
/// required and preserved, and the base AsmPrinter requirements are added on
/// top of that.
void SPIRVAsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<SPIRVModuleAnalysis>();
AU.addPreserved<SPIRVModuleAnalysis>();
// Let the generic AsmPrinter register its own dependencies as well.
AsmPrinter::getAnalysisUsage(AU);
}
/// End-of-file hook: writes the module-level sections if they have not been
/// written yet (emitFunctionHeader() is the other place that may emit them).
void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) {
  if (ModuleSectionsEmitted)
    return;

  outputModuleSections();
  ModuleSectionsEmitted = true;
}
void SPIRVAsmPrinter::emitFunctionHeader() {
if (ModuleSectionsEmitted == false) {
outputModuleSections();
ModuleSectionsEmitted = true;
}
ST = &MF->getSubtarget<SPIRVSubtarget>();
TII = ST->getInstrInfo();
const Function &F = MF->getFunction();
if (isVerbose()) {
OutStreamer->getCommentOS()
<< "-- Begin function "
<< GlobalValue::dropLLVMManglingEscape(F.getName()) << '\n';
}
auto Section = getObjFileLowering().SectionForGlobal(&F, TM);
MF->setSection(Section);
}
void SPIRVAsmPrinter::outputOpFunctionEnd() {
MCInst FunctionEndInst;
FunctionEndInst.setOpcode(SPIRV::OpFunctionEnd);
outputMCInst(FunctionEndInst);
}
/// Function-body-end hook: terminates the function with OpFunctionEnd and
/// then discards the basic-block-number-to-register map kept in MAI.
void SPIRVAsmPrinter::emitFunctionBodyEnd() {
outputOpFunctionEnd();
// NOTE(review): presumably cleared because block numbers are only meaningful
// within one function -- confirm against SPIRVModuleAnalysis.
MAI->BBNumToRegMap.clear();
}
void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
if (MAI->MBBsToSkip.contains(&MBB))
return;
MCInst LabelInst;
LabelInst.setOpcode(SPIRV::OpLabel);
LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));
outputMCInst(LabelInst);
}
/// Basic-block-start hook.
///
/// Every block other than the function's front block gets its OpLabel right
/// away. The front block must contain an OpFunction instruction; in that case
/// nothing is emitted here (emitInstruction() places the label after the
/// function header instructions). A front block without OpFunction is a fatal
/// error.
void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
  const bool IsFrontBlock = MBB.getNumber() == MF->front().getNumber();
  if (!IsFrontBlock) {
    emitOpLabel(MBB);
    return;
  }

  for (const MachineInstr &Candidate : MBB) {
    if (Candidate.getOpcode() == SPIRV::OpFunction)
      return;
  }
  report_fatal_error("OpFunction is expected in the front MBB of MF");
}
void SPIRVAsmPrinter::printOperand(const MachineInstr *MI, int OpNum,
raw_ostream &O) {
const MachineOperand &MO = MI->getOperand(OpNum);
switch (MO.getType()) {
case MachineOperand::MO_Register:
O << SPIRVInstPrinter::getRegisterName(MO.getReg());
break;
case MachineOperand::MO_Immediate:
O << MO.getImm();
break;
case MachineOperand::MO_FPImmediate:
O << MO.getFPImm();
break;
case MachineOperand::MO_MachineBasicBlock:
O << *MO.getMBB()->getSymbol();
break;
case MachineOperand::MO_GlobalAddress:
O << *getSymbol(MO.getGlobal());
break;
case MachineOperand::MO_BlockAddress: {
MCSymbol *BA = GetBlockAddressSymbol(MO.getBlockAddress());
O << BA->getName();
break;
}
case MachineOperand::MO_ExternalSymbol:
O << *GetExternalSymbolSymbol(MO.getSymbolName());
break;
case MachineOperand::MO_JumpTableIndex:
case MachineOperand::MO_ConstantPoolIndex:
default:
llvm_unreachable("<unknown operand type>");
}
}
/// Inline-asm operand hook. No operand modifiers are supported: a non-empty
/// \p ExtraCode is reported as an error (returns true). Otherwise the operand
/// is printed via printOperand() and false is returned.
bool SPIRVAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
                                      const char *ExtraCode, raw_ostream &O) {
  const bool HasModifier = ExtraCode != nullptr && ExtraCode[0] != '\0';
  if (HasModifier)
    return true;

  printOperand(MI, OpNo, O);
  return false;
}
/// Returns true if \p MI is classified as a header instruction by \p TII, or
/// if it is an OpFunction / OpFunctionParameter.
static bool isFuncOrHeaderInstr(const MachineInstr *MI,
                                const SPIRVInstrInfo *TII) {
  if (TII->isHeaderInstr(*MI))
    return true;

  const unsigned Opcode = MI->getOpcode();
  return Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter;
}
/// Hands \p Inst to the output streamer, using the subtarget info stored in
/// the MC context.
void SPIRVAsmPrinter::outputMCInst(MCInst &Inst) {
  const auto &SubtargetInfo = *OutContext.getSubtargetInfo();
  OutStreamer->emitInstruction(Inst, SubtargetInfo);
}
void SPIRVAsmPrinter::outputInstruction(const MachineInstr *MI) {
SPIRVMCInstLower MCInstLowering;
MCInst TmpInst;
MCInstLowering.lower(MI, TmpInst, MAI);
outputMCInst(TmpInst);
}
void SPIRVAsmPrinter::emitInstruction(const MachineInstr *MI) {
SPIRV_MC::verifyInstructionPredicates(MI->getOpcode(),
getSubtargetInfo().getFeatureBits());
if (!MAI->getSkipEmission(MI))
outputInstruction(MI);
const MachineInstr *NextMI = MI->getNextNode();
if (!MAI->hasMBBRegister(*MI->getParent()) && isFuncOrHeaderInstr(MI, TII) &&
(!NextMI || !isFuncOrHeaderInstr(NextMI, TII))) {
assert(MI->getParent()->getNumber() == MF->front().getNumber() &&
"OpFunction is not in the front MBB of MF");
emitOpLabel(*MI->getParent());
}
}
/// Emits, in stored order, every instruction MAI has recorded for the module
/// section \p MSType.
void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) {
  auto &SectionInstrs = MAI->getMSInstrs(MSType);
  for (MachineInstr *SectionMI : SectionInstrs) {
    outputInstruction(SectionMI);
  }
}
void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) {
for (auto &Str : MAI->SrcExt) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpSourceExtension);
addStringImm(Str.first(), Inst);
outputMCInst(Inst);
}
MCInst Inst;
Inst.setOpcode(SPIRV::OpSource);
Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->SrcLang)));
Inst.addOperand(
MCOperand::createImm(static_cast<unsigned>(MAI->SrcLangVersion)));
outputMCInst(Inst);
}
void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) {
for (auto &CU : MAI->ExtInstSetMap) {
unsigned Set = CU.first;
Register Reg = CU.second;
MCInst Inst;
Inst.setOpcode(SPIRV::OpExtInstImport);
Inst.addOperand(MCOperand::createReg(Reg));
addStringImm(getExtInstSetName(static_cast<SPIRV::InstructionSet>(Set)),
Inst);
outputMCInst(Inst);
}
}
void SPIRVAsmPrinter::outputOpMemoryModel() {
MCInst Inst;
Inst.setOpcode(SPIRV::OpMemoryModel);
Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->Addr)));
Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->Mem)));
outputMCInst(Inst);
}
void SPIRVAsmPrinter::outputEntryPoints() {
DenseSet<Register> InterfaceIDs;
for (MachineInstr *MI : MAI->GlobalVarList) {
assert(MI->getOpcode() == SPIRV::OpVariable);
auto SC = static_cast<SPIRV::StorageClass>(MI->getOperand(2).getImm());
if (ST->getSPIRVVersion() >= 14 || SC == SPIRV::StorageClass::Input ||
SC == SPIRV::StorageClass::Output) {
MachineFunction *MF = MI->getMF();
Register Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
InterfaceIDs.insert(Reg);
}
}
for (MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_EntryPoints)) {
SPIRVMCInstLower MCInstLowering;
MCInst TmpInst;
MCInstLowering.lower(MI, TmpInst, MAI);
for (Register Reg : InterfaceIDs) {
assert(Reg.isValid());
TmpInst.addOperand(MCOperand::createReg(Reg));
}
outputMCInst(TmpInst);
}
}
void SPIRVAsmPrinter::outputExtFuncDecls() {
SmallVectorImpl<MachineInstr *>::iterator
I = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).begin(),
E = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).end();
for (; I != E; ++I) {
outputInstruction(*I);
if ((I + 1) == E || (*(I + 1))->getOpcode() == SPIRV::OpFunction)
outputOpFunctionEnd();
}
}
/// Encodes \p Ty as the numeric operand used for the VecTypeHint execution
/// mode.
///
/// Scalar codes: i8 = 0, i16 = 1, i32 = 2, i64 = 3, half = 4, float = 5,
/// double = 6. For a fixed vector the element count is placed in the bits
/// above bit 15 and OR-ed with the element type's code. Any other type is
/// treated as unreachable.
static unsigned encodeVecTypeHint(Type *Ty) {
  if (Ty->isHalfTy())
    return 4;
  if (Ty->isFloatTy())
    return 5;
  if (Ty->isDoubleTy())
    return 6;

  if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
    const unsigned Width = IntTy->getIntegerBitWidth();
    if (Width == 8)
      return 0;
    if (Width == 16)
      return 1;
    if (Width == 32)
      return 2;
    if (Width == 64)
      return 3;
    llvm_unreachable("invalid integer type");
  }

  if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
    const unsigned NumElts = VecTy->getNumElements();
    const unsigned EltCode = encodeVecTypeHint(VecTy->getElementType());
    return (NumElts << 16) | EltCode;
  }

  llvm_unreachable("invalid type");
}
/// Appends the operands of \p MDN to \p Inst.
///
/// Only operands wrapping a constant are considered: a ConstantInt becomes an
/// immediate (zero-extended value), a Function becomes the register MAI has
/// on record for the function's name. Every other operand is silently
/// ignored.
static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
                             SPIRV::ModuleAnalysisInfo *MAI) {
  for (const MDOperand &MDOp : MDN->operands()) {
    auto *Wrapped = dyn_cast<ConstantAsMetadata>(MDOp);
    if (!Wrapped)
      continue;

    Constant *Payload = Wrapped->getValue();
    if (auto *IntVal = dyn_cast<ConstantInt>(Payload)) {
      Inst.addOperand(MCOperand::createImm(IntVal->getZExtValue()));
      continue;
    }
    if (auto *Fn = dyn_cast<Function>(Payload)) {
      Register FuncReg = MAI->getFuncReg(Fn->getName().str());
      assert(FuncReg.isValid());
      Inst.addOperand(MCOperand::createReg(FuncReg));
    }
  }
}
/// Emits an OpExecutionMode for the entity in \p Reg with mode \p EM, followed
/// by the operands extracted from \p Node via addOpsFromMDNode().
void SPIRVAsmPrinter::outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
                                                    SPIRV::ExecutionMode EM) {
  const auto ModeValue = static_cast<unsigned>(EM);

  MCInst ModeInst;
  ModeInst.setOpcode(SPIRV::OpExecutionMode);
  ModeInst.addOperand(MCOperand::createReg(Reg));
  ModeInst.addOperand(MCOperand::createImm(ModeValue));
  addOpsFromMDNode(Node, ModeInst, MAI);
  outputMCInst(ModeInst);
}
void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
if (Node) {
for (unsigned i = 0; i < Node->getNumOperands(); i++) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
outputMCInst(Inst);
}
}
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
const Function &F = *FI;
if (F.isDeclaration())
continue;
Register FReg = MAI->getFuncReg(F.getGlobalIdentifier());
assert(FReg.isValid());
if (MDNode *Node = F.getMetadata("reqd_work_group_size"))
outputExecutionModeFromMDNode(FReg, Node,
SPIRV::ExecutionMode::LocalSize);
if (MDNode *Node = F.getMetadata("work_group_size_hint"))
outputExecutionModeFromMDNode(FReg, Node,
SPIRV::ExecutionMode::LocalSizeHint);
if (MDNode *Node = F.getMetadata("intel_reqd_sub_group_size"))
outputExecutionModeFromMDNode(FReg, Node,
SPIRV::ExecutionMode::SubgroupSize);
if (MDNode *Node = F.getMetadata("vec_type_hint")) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
Inst.addOperand(MCOperand::createReg(FReg));
unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::VecTypeHint);
Inst.addOperand(MCOperand::createImm(EM));
unsigned TypeCode = encodeVecTypeHint(getMDOperandAsType(Node, 0));
Inst.addOperand(MCOperand::createImm(TypeCode));
outputMCInst(Inst);
}
}
}
/// Emits the annotation section recorded in MAI, then translates every entry
/// of the "llvm.global.annotations" global into an OpDecorate with the
/// UserSemantic decoration and the annotation string.
///
/// Only annotations attached to functions are supported; anything else is
/// treated as unreachable.
void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
  outputModuleSection(SPIRV::MB_Annotations);

  for (const GlobalVariable &Global : M.globals()) {
    if (Global.getName() != "llvm.global.annotations")
      continue;

    const auto *Entries = cast<ConstantArray>(Global.getOperand(0));
    for (Value *Entry : Entries->operands()) {
      auto *Record = cast<ConstantStruct>(Entry);

      // Field 0 is the annotated value; it has to be a function.
      Value *Annotated = Record->getOperand(0)->stripPointerCasts();
      if (!isa<Function>(Annotated))
        llvm_unreachable("Unsupported value in llvm.global.annotations");
      auto *AnnotatedFn = cast<Function>(Annotated);
      Register FnReg = MAI->getFuncReg(AnnotatedFn->getGlobalIdentifier());

      // Field 1 is the global holding the annotation text.
      auto *TextGV =
          cast<GlobalVariable>(Record->getOperand(1)->stripPointerCasts());
      StringRef Text;
      getConstantStringInfo(TextGV, Text);

      MCInst Decorate;
      Decorate.setOpcode(SPIRV::OpDecorate);
      Decorate.addOperand(MCOperand::createReg(FnReg));
      const unsigned Decoration =
          static_cast<unsigned>(SPIRV::Decoration::UserSemantic);
      Decorate.addOperand(MCOperand::createImm(Decoration));
      addStringImm(Text, Decorate);
      outputMCInst(Decorate);
    }
  }
}
/// Writes all module-level sections, in a fixed order, after caching the
/// subtarget, instruction info and module analysis pointers.
///
/// Each pointer is asserted before its first use, so a missing prerequisite
/// is reported by the assertion rather than by a null dereference.
void SPIRVAsmPrinter::outputModuleSections() {
  const Module *M = MMI->getModule();
  assert(M && "Module analysis is required");

  ST = static_cast<const SPIRVTargetMachine &>(TM).getSubtargetImpl();
  assert(ST && "Module analysis is required");
  TII = ST->getInstrInfo();
  MAI = &SPIRVModuleAnalysis::MAI;
  assert(TII && MAI && "Module analysis is required");

  // The emission order below is significant and must not be rearranged.
  outputOpExtInstImports(*M);
  outputOpMemoryModel();
  outputEntryPoints();
  outputExecutionMode(*M);
  outputDebugSourceAndStrings(*M);
  outputModuleSection(SPIRV::MB_DebugNames);
  outputModuleSection(SPIRV::MB_DebugModuleProcessed);
  outputAnnotations(*M);
  outputModuleSection(SPIRV::MB_TypeConstVars);
  outputExtFuncDecls();
}
/// Per-module initialization: marks the module-level sections as not yet
/// emitted and then defers to the base AsmPrinter, whose result is returned.
bool SPIRVAsmPrinter::doInitialization(Module &M) {
ModuleSectionsEmitted = false;
return AsmPrinter::doInitialization(M);
}
/// Registers SPIRVAsmPrinter as the assembly printer for both the 32-bit and
/// the 64-bit SPIR-V targets.
extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVAsmPrinter() {
  RegisterAsmPrinter<SPIRVAsmPrinter> RegisterFor32(getTheSPIRV32Target());
  RegisterAsmPrinter<SPIRVAsmPrinter> RegisterFor64(getTheSPIRV64Target());
}