Skip to content

Commit cbb1567

Browse files
toppercjph-13
authored andcommitted
[SPIR-V] Stop using Register to represent target specific virtual registers. (llvm#129362)
These were using the virtual register encoding in Register which required including Register.h in MC layer code which is a layering violation. This also required converting Register with bit 31 set to MCRegister which should be an error. Register with bit 31 set should only be used for codegen virtual register. I'd like to add assertions to enforce this. Migrate to MCRegister and manually create an encoding with bit 31 set. WebAssembly also does this. We could consider adding interfaces to MCRegister for target specific virtual registers.
1 parent 1686a01 commit cbb1567

File tree

7 files changed

+65
-53
lines changed

7 files changed

+65
-53
lines changed

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#include "SPIRV.h"
1515
#include "SPIRVBaseInfo.h"
1616
#include "llvm/ADT/APFloat.h"
17-
#include "llvm/CodeGen/Register.h"
1817
#include "llvm/MC/MCAsmInfo.h"
1918
#include "llvm/MC/MCExpr.h"
2019
#include "llvm/MC/MCInst.h"
@@ -97,7 +96,7 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
9796
}
9897

9998
void SPIRVInstPrinter::recordOpExtInstImport(const MCInst *MI) {
100-
Register Reg = MI->getOperand(0).getReg();
99+
MCRegister Reg = MI->getOperand(0).getReg();
101100
auto Name = getSPIRVStringOperand(*MI, 1);
102101
auto Set = getExtInstSetFromString(Name);
103102
ExtInstSetIDs.insert({Reg, Set});
@@ -335,7 +334,7 @@ void SPIRVInstPrinter::printOperand(const MCInst *MI, unsigned OpNo,
335334
if (OpNo < MI->getNumOperands()) {
336335
const MCOperand &Op = MI->getOperand(OpNo);
337336
if (Op.isReg())
338-
O << '%' << (Register(Op.getReg()).virtRegIndex() + 1);
337+
O << '%' << (getIDFromRegister(Op.getReg().id()) + 1);
339338
else if (Op.isImm())
340339
O << formatImm((int64_t)Op.getImm());
341340
else if (Op.isDFPImm())

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
14-
#include "llvm/CodeGen/Register.h"
1514
#include "llvm/MC/MCCodeEmitter.h"
1615
#include "llvm/MC/MCFixup.h"
1716
#include "llvm/MC/MCInst.h"
@@ -77,7 +76,8 @@ static void emitOperand(const MCOperand &Op, SmallVectorImpl<char> &CB) {
7776
if (Op.isReg()) {
7877
// Emit the id index starting at 1 (0 is an invalid index).
7978
support::endian::write<uint32_t>(
80-
CB, Register(Op.getReg()).virtRegIndex() + 1, llvm::endianness::little);
79+
CB, SPIRV::getIDFromRegister(Op.getReg().id()) + 1,
80+
llvm::endianness::little);
8181
} else if (Op.isImm()) {
8282
support::endian::write(CB, static_cast<uint32_t>(Op.getImm()),
8383
llvm::endianness::little);

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCTargetDesc.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVMCTARGETDESC_H
1515

1616
#include "llvm/Support/DataTypes.h"
17+
#include <cassert>
1718
#include <memory>
1819

1920
namespace llvm {
@@ -50,4 +51,11 @@ std::unique_ptr<MCObjectTargetWriter> createSPIRVObjectTargetWriter();
5051
#define GET_SUBTARGETINFO_ENUM
5152
#include "SPIRVGenSubtargetInfo.inc"
5253

54+
namespace llvm::SPIRV {
55+
inline unsigned getIDFromRegister(unsigned Reg) {
56+
assert(Reg & (1U << 31));
57+
return Reg & ~(1U << 31);
58+
}
59+
} // namespace llvm::SPIRV
60+
5361
#endif // LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVMCTARGETDESC_H

llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ class SPIRVAsmPrinter : public AsmPrinter {
7070
void outputOpMemoryModel();
7171
void outputOpFunctionEnd();
7272
void outputExtFuncDecls();
73-
void outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
73+
void outputExecutionModeFromMDNode(MCRegister Reg, MDNode *Node,
7474
SPIRV::ExecutionMode::ExecutionMode EM,
7575
unsigned ExpectMDOps, int64_t DefVal);
7676
void outputExecutionModeFromNumthreadsAttribute(
77-
const Register &Reg, const Attribute &Attr,
77+
const MCRegister &Reg, const Attribute &Attr,
7878
SPIRV::ExecutionMode::ExecutionMode EM);
7979
void outputExecutionMode(const Module &M);
8080
void outputAnnotations(const Module &M);
@@ -316,7 +316,7 @@ void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) {
316316
void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) {
317317
for (auto &CU : MAI->ExtInstSetMap) {
318318
unsigned Set = CU.first;
319-
Register Reg = CU.second;
319+
MCRegister Reg = CU.second;
320320
MCInst Inst;
321321
Inst.setOpcode(SPIRV::OpExtInstImport);
322322
Inst.addOperand(MCOperand::createReg(Reg));
@@ -341,7 +341,7 @@ void SPIRVAsmPrinter::outputOpMemoryModel() {
341341
// the interface of this entry point.
342342
void SPIRVAsmPrinter::outputEntryPoints() {
343343
// Find all OpVariable IDs with required StorageClass.
344-
DenseSet<Register> InterfaceIDs;
344+
DenseSet<MCRegister> InterfaceIDs;
345345
for (const MachineInstr *MI : MAI->GlobalVarList) {
346346
assert(MI->getOpcode() == SPIRV::OpVariable);
347347
auto SC = static_cast<SPIRV::StorageClass::StorageClass>(
@@ -353,7 +353,7 @@ void SPIRVAsmPrinter::outputEntryPoints() {
353353
if (ST->isAtLeastSPIRVVer(VersionTuple(1, 4)) ||
354354
SC == SPIRV::StorageClass::Input || SC == SPIRV::StorageClass::Output) {
355355
const MachineFunction *MF = MI->getMF();
356-
Register Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
356+
MCRegister Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
357357
InterfaceIDs.insert(Reg);
358358
}
359359
}
@@ -363,7 +363,7 @@ void SPIRVAsmPrinter::outputEntryPoints() {
363363
SPIRVMCInstLower MCInstLowering;
364364
MCInst TmpInst;
365365
MCInstLowering.lower(MI, TmpInst, MAI);
366-
for (Register Reg : InterfaceIDs) {
366+
for (MCRegister Reg : InterfaceIDs) {
367367
assert(Reg.isValid());
368368
TmpInst.addOperand(MCOperand::createReg(Reg));
369369
}
@@ -444,7 +444,7 @@ static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
444444
if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
445445
Inst.addOperand(MCOperand::createImm(Const->getZExtValue()));
446446
} else if (auto *CE = dyn_cast<Function>(C)) {
447-
Register FuncReg = MAI->getFuncReg(CE);
447+
MCRegister FuncReg = MAI->getFuncReg(CE);
448448
assert(FuncReg.isValid());
449449
Inst.addOperand(MCOperand::createReg(FuncReg));
450450
}
@@ -453,7 +453,7 @@ static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
453453
}
454454

455455
void SPIRVAsmPrinter::outputExecutionModeFromMDNode(
456-
Register Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM,
456+
MCRegister Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM,
457457
unsigned ExpectMDOps, int64_t DefVal) {
458458
MCInst Inst;
459459
Inst.setOpcode(SPIRV::OpExecutionMode);
@@ -470,7 +470,7 @@ void SPIRVAsmPrinter::outputExecutionModeFromMDNode(
470470
}
471471

472472
void SPIRVAsmPrinter::outputExecutionModeFromNumthreadsAttribute(
473-
const Register &Reg, const Attribute &Attr,
473+
const MCRegister &Reg, const Attribute &Attr,
474474
SPIRV::ExecutionMode::ExecutionMode EM) {
475475
assert(Attr.isValid() && "Function called with an invalid attribute.");
476476

@@ -508,7 +508,7 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
508508
// <Entry Point> operands of OpExecutionMode
509509
if (F.isDeclaration() || !isEntryPoint(F))
510510
continue;
511-
Register FReg = MAI->getFuncReg(&F);
511+
MCRegister FReg = MAI->getFuncReg(&F);
512512
assert(FReg.isValid());
513513
if (MDNode *Node = F.getMetadata("reqd_work_group_size"))
514514
outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSize,
@@ -560,7 +560,7 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
560560
if (!isa<Function>(AnnotatedVar))
561561
report_fatal_error("Unsupported value in llvm.global.annotations");
562562
Function *Func = cast<Function>(AnnotatedVar);
563-
Register Reg = MAI->getFuncReg(Func);
563+
MCRegister Reg = MAI->getFuncReg(Func);
564564
if (!Reg.isValid()) {
565565
std::string DiagMsg;
566566
raw_string_ostream OS(DiagMsg);

llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI,
3434
default:
3535
llvm_unreachable("unknown operand type");
3636
case MachineOperand::MO_GlobalAddress: {
37-
Register FuncReg = MAI->getFuncReg(dyn_cast<Function>(MO.getGlobal()));
37+
MCRegister FuncReg = MAI->getFuncReg(dyn_cast<Function>(MO.getGlobal()));
3838
if (!FuncReg.isValid()) {
3939
std::string DiagMsg;
4040
raw_string_ostream OS(DiagMsg);
@@ -49,13 +49,14 @@ void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI,
4949
MCOp = MCOperand::createReg(MAI->getOrCreateMBBRegister(*MO.getMBB()));
5050
break;
5151
case MachineOperand::MO_Register: {
52-
Register NewReg = MAI->getRegisterAlias(MF, MO.getReg());
53-
MCOp = MCOperand::createReg(NewReg.isValid() ? NewReg : MO.getReg());
52+
MCRegister NewReg = MAI->getRegisterAlias(MF, MO.getReg());
53+
MCOp = MCOperand::createReg(NewReg.isValid() ? NewReg
54+
: MO.getReg().asMCReg());
5455
break;
5556
}
5657
case MachineOperand::MO_Immediate:
5758
if (MI->getOpcode() == SPIRV::OpExtInst && i == 2) {
58-
Register Reg = MAI->getExtInstSetReg(MO.getImm());
59+
MCRegister Reg = MAI->getExtInstSetReg(MO.getImm());
5960
MCOp = MCOperand::createReg(Reg);
6061
} else {
6162
MCOp = MCOperand::createImm(MO.getImm());

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,7 @@ void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
213213
if (ST->isOpenCLEnv()) {
214214
// TODO: check if it's required by default.
215215
MAI.ExtInstSetMap[static_cast<unsigned>(
216-
SPIRV::InstructionSet::OpenCL_std)] =
217-
Register::index2VirtReg(MAI.getNextID());
216+
SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
218217
}
219218
}
220219

@@ -306,7 +305,8 @@ void SPIRVModuleAnalysis::visitFunPtrUse(
306305
} while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
307306
OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
308307
// associate the function pointer with the newly assigned global number
309-
Register GlobalFunDefReg = MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg());
308+
MCRegister GlobalFunDefReg =
309+
MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg());
310310
assert(GlobalFunDefReg.isValid() &&
311311
"Function definition must refer to a global register");
312312
MAI.setRegisterAlias(MF, OpReg, GlobalFunDefReg);
@@ -353,10 +353,10 @@ void SPIRVModuleAnalysis::visitDecl(
353353
"No unique definition is found for the virtual register");
354354
}
355355

356-
Register GReg;
356+
MCRegister GReg;
357357
bool IsFunDef = false;
358358
if (TII->isSpecConstantInstr(MI)) {
359-
GReg = Register::index2VirtReg(MAI.getNextID());
359+
GReg = MAI.getNextIDRegister();
360360
MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
361361
} else if (Opcode == SPIRV::OpFunction ||
362362
Opcode == SPIRV::OpFunctionParameter) {
@@ -366,7 +366,7 @@ void SPIRVModuleAnalysis::visitDecl(
366366
const MachineInstr *NextInstr = MI.getNextNode();
367367
while (NextInstr &&
368368
NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) {
369-
Register Tmp = handleTypeDeclOrConstant(*NextInstr, SignatureToGReg);
369+
MCRegister Tmp = handleTypeDeclOrConstant(*NextInstr, SignatureToGReg);
370370
MAI.setRegisterAlias(MF, NextInstr->getOperand(0).getReg(), Tmp);
371371
MAI.setSkipEmission(NextInstr);
372372
NextInstr = NextInstr->getNextNode();
@@ -389,7 +389,7 @@ void SPIRVModuleAnalysis::visitDecl(
389389
MAI.setSkipEmission(&MI);
390390
}
391391

392-
Register SPIRVModuleAnalysis::handleFunctionOrParameter(
392+
MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
393393
const MachineFunction *MF, const MachineInstr &MI,
394394
std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
395395
const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
@@ -402,27 +402,27 @@ Register SPIRVModuleAnalysis::handleFunctionOrParameter(
402402
auto It = GlobalToGReg.find(GObj);
403403
if (It != GlobalToGReg.end())
404404
return It->second;
405-
Register GReg = Register::index2VirtReg(MAI.getNextID());
405+
MCRegister GReg = MAI.getNextIDRegister();
406406
GlobalToGReg[GObj] = GReg;
407407
if (!IsFunDef)
408408
MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(&MI);
409409
return GReg;
410410
}
411411

412-
Register
412+
MCRegister
413413
SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
414414
InstrGRegsMap &SignatureToGReg) {
415415
InstrSignature MISign = instrToSignature(MI, MAI, false);
416416
auto It = SignatureToGReg.find(MISign);
417417
if (It != SignatureToGReg.end())
418418
return It->second;
419-
Register GReg = Register::index2VirtReg(MAI.getNextID());
419+
MCRegister GReg = MAI.getNextIDRegister();
420420
SignatureToGReg[MISign] = GReg;
421421
MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
422422
return GReg;
423423
}
424424

425-
Register SPIRVModuleAnalysis::handleVariable(
425+
MCRegister SPIRVModuleAnalysis::handleVariable(
426426
const MachineFunction *MF, const MachineInstr &MI,
427427
std::map<const Value *, unsigned> &GlobalToGReg) {
428428
MAI.GlobalVarList.push_back(&MI);
@@ -431,7 +431,7 @@ Register SPIRVModuleAnalysis::handleVariable(
431431
auto It = GlobalToGReg.find(GObj);
432432
if (It != GlobalToGReg.end())
433433
return It->second;
434-
Register GReg = Register::index2VirtReg(MAI.getNextID());
434+
MCRegister GReg = MAI.getNextIDRegister();
435435
GlobalToGReg[GObj] = GReg;
436436
MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
437437
return GReg;
@@ -507,7 +507,7 @@ void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
507507
} else if (MI.getOpcode() == SPIRV::OpFunction) {
508508
// Record all internal OpFunction declarations.
509509
Register Reg = MI.defs().begin()->getReg();
510-
Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
510+
MCRegister GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
511511
assert(GlobalReg.isValid());
512512
MAI.FuncMap[F] = GlobalReg;
513513
}
@@ -599,14 +599,14 @@ void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
599599
Register Reg = Op.getReg();
600600
if (MAI.hasRegisterAlias(MF, Reg))
601601
continue;
602-
Register NewReg = Register::index2VirtReg(MAI.getNextID());
602+
MCRegister NewReg = MAI.getNextIDRegister();
603603
MAI.setRegisterAlias(MF, Reg, NewReg);
604604
}
605605
if (MI.getOpcode() != SPIRV::OpExtInst)
606606
continue;
607607
auto Set = MI.getOperand(2).getImm();
608608
if (!MAI.ExtInstSetMap.contains(Set))
609-
MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
609+
MAI.ExtInstSetMap[Set] = MAI.getNextIDRegister();
610610
}
611611
}
612612
}
@@ -1938,7 +1938,7 @@ static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
19381938
Register Reg = MRI.createGenericVirtualRegister(LLT::scalar(64));
19391939
MRI.setRegClass(Reg, &SPIRV::IDRegClass);
19401940
buildOpName(Reg, MBB.getName(), *std::prev(MBB.end()), TII);
1941-
Register GlobalReg = MAI.getOrCreateMBBRegister(MBB);
1941+
MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
19421942
MAI.setRegisterAlias(MF, Reg, GlobalReg);
19431943
}
19441944
}
@@ -1992,6 +1992,7 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) {
19921992

19931993
// Process type/const/global var/func decl instructions, number their
19941994
// destination registers from 0 to N, collect Extensions and Capabilities.
1995+
collectReqs(M, MAI, MMI, *ST);
19951996
collectDeclarations(M);
19961997

19971998
// Number rest of registers from N+1 onwards.

0 commit comments

Comments
 (0)