Skip to content

Commit 971e72b

Browse files
[SYCLomatic][PTX] Enable migration of ldmatrix (#2692)
1 parent c97f93a commit 971e72b

File tree

9 files changed

+397
-21
lines changed

9 files changed

+397
-21
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,10 +556,15 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
556556
return SYCLGenError();
557557
OS() << ", ";
558558
switch (T->getKind()) {
559+
case InlineAsmVectorType::x1:
560+
OS() << 1;
561+
break;
559562
case InlineAsmVectorType::v2:
563+
case InlineAsmVectorType::x2:
560564
OS() << 2;
561565
break;
562566
case InlineAsmVectorType::v4:
567+
case InlineAsmVectorType::x4:
563568
OS() << 4;
564569
break;
565570
case InlineAsmVectorType::v8:
@@ -589,9 +594,9 @@ bool SYCLGenBase::emitVariableDeclaration(const InlineAsmVarDecl *D) {
589594

590595
bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
591596
// Address expressions are only supported for the ld/st/atom/prefetch/red/cp/ldmatrix instructions.
592-
if (!CurrInst ||
593-
!CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
594-
asmtok::op_prefetch, asmtok::op_red, asmtok::op_cp)) {
597+
if (!CurrInst || !CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
598+
asmtok::op_prefetch, asmtok::op_red,
599+
asmtok::op_cp, asmtok::op_ldmatrix)) {
595600
return SYCLGenError();
596601
}
597602
std::string Type;
@@ -624,6 +629,8 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
624629
if (CurrInst->is(asmtok::op_prefetch, asmtok::op_red) ||
625630
CanSuppressCast(Dst->getSymbol()))
626631
OS() << llvm::formatv("{0}", Reg);
632+
else if (CurrInst->is(asmtok::op_ldmatrix))
633+
OS() << llvm::formatv("(uintptr_t){0}", Reg);
627634
else
628635
OS() << llvm::formatv("(({0} *)(uintptr_t){1})", Type, Reg);
629636
break;
@@ -1305,6 +1312,64 @@ class SYCLGen : public SYCLGenBase {
13051312
return SYCLGenSuccess();
13061313
}
13071314

1315+
bool handle_ldmatrix(const InlineAsmInstruction *Inst) override {
1316+
if (Inst->getNumInputOperands() != 1)
1317+
return SYCLGenError();
1318+
1319+
const auto *Type = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
1320+
1321+
if (!Type || Type->getKind() != InlineAsmBuiltinType::b16)
1322+
return SYCLGenError();
1323+
1324+
const InlineAsmVectorExpr *VE;
1325+
if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand())) {
1326+
auto numOutputOperands = VE->getNumElements();
1327+
if (Inst->hasAttr(InstAttr::x1)) {
1328+
if (numOutputOperands != 1)
1329+
return SYCLGenError();
1330+
} else if (Inst->hasAttr(InstAttr::x2)) {
1331+
if (numOutputOperands != 2)
1332+
return SYCLGenError();
1333+
} else if (Inst->hasAttr(InstAttr::x4)) {
1334+
if (numOutputOperands != 4)
1335+
return SYCLGenError();
1336+
}
1337+
} else {
1338+
return SYCLGenError();
1339+
}
1340+
1341+
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
1342+
CurrInst = Inst;
1343+
const auto *Src =
1344+
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getInputOperand(0));
1345+
if (!Src)
1346+
return false;
1347+
1348+
OS() << MapNames::getDpctNamespace() << "experimental::matrix::ldmatrix(";
1349+
if (emitStmt(Src)) {
1350+
return SYCLGenError();
1351+
}
1352+
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
1353+
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1354+
continue;
1355+
OS() << ", &";
1356+
if (emitStmt(VE->getElement(Inst)))
1357+
return SYCLGenError();
1358+
}
1359+
if (Inst->hasAttr(InstAttr::trans))
1360+
OS() << ", true";
1361+
OS() << ");";
1362+
const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
1363+
if (KernelDecl) {
1364+
auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl);
1365+
if (FuncInfo)
1366+
FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(),
1367+
DpctGlobalInfo::getSubGroup(GAS));
1368+
}
1369+
1370+
return SYCLGenSuccess();
1371+
}
1372+
13081373
bool handle_prefetch(const InlineAsmInstruction *Inst) override {
13091374
if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1)
13101375
return SYCLGenError();
@@ -2881,6 +2946,7 @@ class SYCLGen : public SYCLGenBase {
28812946
bool handle_ld(const InlineAsmInstruction *Inst) override {
28822947
if (Inst->getNumInputOperands() != 1)
28832948
return SYCLGenError();
2949+
28842950
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
28852951
CurrInst = Inst;
28862952
const auto *Src =

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
116116
// This class is used for device asm vector types.
117117
class InlineAsmVectorType : public InlineAsmType {
118118
public:
119-
enum VecKind { v2, v4, v8 };
119+
enum VecKind { v2, v4, v8, x1, x2, x4 };
120120

121121
private:
122122
VecKind Kind;

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ InlineAsmStmtResult InlineAsmParser::ParseInstruction() {
327327
if (!Tok.getIdentifier() || !Tok.getIdentifier()->isInstruction())
328328
return AsmStmtError();
329329

330-
InlineAsmIdentifierInfo *Opcode = Tok.getIdentifier();
330+
Opcode = Tok.getIdentifier();
331331
ConsumeToken();
332332

333333
SmallVector<InstAttr, 4> Attrs;
@@ -736,20 +736,38 @@ InlineAsmExprResult InlineAsmParser::ActOnParenExpr(InlineAsmExpr *SubExpr) {
736736
InlineAsmExprResult
737737
InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
738738

739-
// Vector size must be 2, 4, or 8.
739+
// Vector sizes for ldmatrix are 1, 2, or 4.
740+
// size(x) = 2 * sizeof(v).
740741
InlineAsmVectorType::VecKind Kind;
741-
switch (Vec.size()) {
742-
case 2:
743-
Kind = InlineAsmVectorType::v2;
744-
break;
745-
case 4:
746-
Kind = InlineAsmVectorType::v4;
747-
break;
748-
case 8:
749-
Kind = InlineAsmVectorType::v8;
750-
break;
751-
default:
752-
return AsmExprError();
742+
if (Opcode->getTokenID() == asmtok::op_ldmatrix) {
743+
switch (Vec.size()) {
744+
case 1:
745+
Kind = InlineAsmVectorType::x1;
746+
break;
747+
case 2:
748+
Kind = InlineAsmVectorType::x2;
749+
break;
750+
case 4:
751+
Kind = InlineAsmVectorType::x4;
752+
break;
753+
default:
754+
return AsmExprError();
755+
}
756+
} else {
757+
// Vector size must be 2, 4, or 8.
758+
switch (Vec.size()) {
759+
case 2:
760+
Kind = InlineAsmVectorType::v2;
761+
break;
762+
case 4:
763+
Kind = InlineAsmVectorType::v4;
764+
break;
765+
case 8:
766+
Kind = InlineAsmVectorType::v8;
767+
break;
768+
default:
769+
return AsmExprError();
770+
}
753771
}
754772

755773
InlineAsmBuiltinType *ElementType = nullptr;

clang/lib/DPCT/RulesAsm/Parser/AsmParser.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ class InlineAsmParser {
247247
};
248248

249249
public:
250+
InlineAsmIdentifierInfo *Opcode;
251+
250252
InlineAsmParser(InlineAsmContext &Ctx, SourceMgr &Mgr)
251253
: Lexer(*Mgr.getMemoryBuffer(Mgr.getMainFileID())), Context(Ctx),
252254
SrcMgr(Mgr), CurScope(nullptr) {

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,14 @@ MODIFIER(v2, ".v2")
274274
MODIFIER(v4, ".v4")
275275
MODIFIER(v8, ".v8")
276276

277+
// Matrix modifiers
278+
MODIFIER(x1, ".x1")
279+
MODIFIER(x2, ".x2")
280+
MODIFIER(x4, ".x4")
281+
282+
// Matrix shape
283+
MODIFIER(m8n8, ".m8n8")
284+
277285
STATE_SPACE(reg, ".reg")
278286
STATE_SPACE(sreg, ".sreg")
279287
STATE_SPACE(const, ".const")
@@ -420,6 +428,8 @@ MODIFIER(ecr, ".ecr")
420428
MODIFIER(rc16, ".rc16")
421429
MODIFIER(cs, ".cs")
422430
MODIFIER(to, ".to")
431+
MODIFIER(aligned, ".aligned")
432+
MODIFIER(trans, ".trans")
423433

424434
#undef LINKAGE
425435
#undef TARGET

clang/lib/DPCT/SrcAPI/APINames_ASM.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ ENTRY("griddepcontrol", "griddepcontrol", false, NO_FLAG, P1, "Comment")
7575
ENTRY("isspacep", "isspacep", false, NO_FLAG, P1, "Comment")
7676
ENTRY("istypep", "istypep", false, NO_FLAG, P1, "Comment")
7777
ENTRY("ld", "ld", true, NO_FLAG, P1, "Partial")
78-
ENTRY("ldmatrix", "ldmatrix", false, NO_FLAG, P1, "Comment")
78+
ENTRY("ldmatrix", "ldmatrix", true, NO_FLAG, P1, "Successful")
7979
ENTRY("ldu", "ldu", false, NO_FLAG, P1, "Comment")
8080
ENTRY("lg2", "lg2", true, NO_FLAG, P1, "Successful")
8181
ENTRY("lop3", "lop3", true, NO_FLAG, P1, "Successful")

0 commit comments

Comments
 (0)