Skip to content

[PTX] Enable migration of mma PTX ASM instruction for 11 shapes #2746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: SYCLomatic
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
395 changes: 390 additions & 5 deletions clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Large diffs are not rendered by default.

20 changes: 12 additions & 8 deletions clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ class InlineAsmBuiltinType : public InlineAsmType {
return ((K == Kind) || ...);
}
template <class... Ks> bool isNot(Ks... K) { return ((K != Kind) && ...); }
bool isBit() const { return isOneOf(b8, b16, b32, b64); }
bool isSigned() const { return isOneOf(s8, s16, s32, s64); }
bool isUnsigned() const { return isOneOf(u8, u16, u32, u64); }
bool isBit() const { return isOneOf(b1, b8, b16, b32, b64); }
bool isSigned() const { return isOneOf(s4, s8, s16, s32, s64); }
bool isUnsigned() const { return isOneOf(u4, u8, u16, u32, u64); }
bool isInt() const { return isSigned() || isUnsigned(); }
bool isFloat() const { return isOneOf(f16, f32, f64); }
bool isScalar() const { return isInt() || isFloat(); }
Expand All @@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
// This class is used for device asm vector types.
class InlineAsmVectorType : public InlineAsmType {
public:
enum VecKind { v2, v4, v8 };
enum VecKind { v1, v2, v4, v8 };

private:
VecKind Kind;
Expand Down Expand Up @@ -322,7 +322,7 @@ class InlineAsmInstruction : public InlineAsmStmt {

/// This represents arrtibutes like: comparsion operator, rounding modifiers,
/// ... e.g. instruction setp.eq.s32 has a comparsion operator 'eq'.
SmallSet<InstAttr, 4> Attributes;
SmallVector<InstAttr, 4> Attributes;

/// This represents types in instruction, e.g. mov.u32.
SmallVector<InlineAsmType *, 4> Types;
Expand Down Expand Up @@ -350,11 +350,11 @@ class InlineAsmInstruction : public InlineAsmStmt {
OutputOp(Out), PredOutputOp(Pred), InputOps(InOps) {
StateSpaces.insert(StateSpaces.begin(), AsmStateSpaces.begin(),
AsmStateSpaces.end());
Attributes.insert(Attrs.begin(), Attrs.end());
Attributes.insert(Attributes.begin(), Attrs.begin(), Attrs.end());
}

using attr_range =
llvm::iterator_range<SmallSet<InstAttr, 4>::const_iterator>;
llvm::iterator_range<SmallVector<InstAttr, 4>::const_iterator>;
using type_range =
llvm::iterator_range<SmallVector<InlineAsmType *, 4>::const_iterator>;
using op_range =
Expand All @@ -369,12 +369,16 @@ class InlineAsmInstruction : public InlineAsmStmt {
}

template <typename... Ts> bool hasAttr(Ts... Attrs) const {
return (Attributes.contains(Attrs) || ...);
return (llvm::is_contained(Attributes, Attrs) || ...);
}
const InlineAsmIdentifierInfo *getOpcodeID() const { return Opcode; }
asmtok::TokenKind getOpcode() const { return Opcode->getTokenID(); }
ArrayRef<InlineAsmType *> getTypes() const { return Types; }
const InlineAsmType *getType(unsigned I) const { return Types[I]; }
InstAttr getAttr(unsigned I) const {
assert(I < Attributes.size() && "Attributes index out of range");
return Attributes[I];
}
unsigned getNumTypes() const { return Types.size(); }
const InlineAsmExpr *getOutputOperand() const { return OutputOp; }
const InlineAsmExpr *getPredOutputOperand() const { return PredOutputOp; }
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,9 @@ InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
// Vector size must be 2, 4, or 8.
InlineAsmVectorType::VecKind Kind;
switch (Vec.size()) {
case 1:
Kind = InlineAsmVectorType::v1;
break;
case 2:
Kind = InlineAsmVectorType::v2;
break;
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/DPCT/RulesAsm/Parser/AsmParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ class InlineAsmParser {
/// .reg .sreg .const .local .param .shared .tex
///
/// vector-specifier: one of
/// .v2 .v4 .v8
/// .v1 .v2 .v4 .v8
///
/// type-specifier: one of
/// .b8 .b16 .b32 .b64 .s8 .s16 .s32 .s64
Expand Down
23 changes: 23 additions & 0 deletions clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -240,14 +240,17 @@ SPECIAL_REG(warpid, "%warpid", s64)
SPECIAL_REG(WARP_SZ, "WARP_SZ", s64)

// Built-in type names
BUILTIN_TYPE(b1, ".b1")
BUILTIN_TYPE(b8, ".b8")
BUILTIN_TYPE(b16, ".b16")
BUILTIN_TYPE(b32, ".b32")
BUILTIN_TYPE(b64, ".b64")
BUILTIN_TYPE(u4, ".u4")
BUILTIN_TYPE(u8, ".u8")
BUILTIN_TYPE(u16, ".u16")
BUILTIN_TYPE(u32, ".u32")
BUILTIN_TYPE(u64, ".u64")
BUILTIN_TYPE(s4, ".s4")
BUILTIN_TYPE(s8, ".s8")
BUILTIN_TYPE(s16, ".s16")
BUILTIN_TYPE(s32, ".s32")
Expand All @@ -270,10 +273,28 @@ BUILTIN_TYPE(s16x2, ".s16x2")
BUILTIN_TYPE(u16x2, ".u16x2")

// Vector modifiers
MODIFIER(v1, ".v1")
MODIFIER(v2, ".v2")
MODIFIER(v4, ".v4")
MODIFIER(v8, ".v8")

// Matrix modifiers
MODIFIER(row, ".row")
MODIFIER(col, ".col")

// Matrix shape
MODIFIER(m8n8k4, ".m8n8k4")
MODIFIER(m8n8k16, ".m8n8k16")
MODIFIER(m8n8k32, ".m8n8k32")
MODIFIER(m8n8k128, ".m8n8k128")
MODIFIER(m16n8k4, ".m16n8k4")
MODIFIER(m16n8k8, ".m16n8k8")
MODIFIER(m16n8k16, ".m16n8k16")
MODIFIER(m16n8k32, ".m16n8k32")
MODIFIER(m16n8k64, ".m16n8k64")
MODIFIER(m16n8k128, ".m16n8k128")
MODIFIER(m16n8k256, ".m16n8k256")

STATE_SPACE(reg, ".reg")
STATE_SPACE(sreg, ".sreg")
STATE_SPACE(const, ".const")
Expand Down Expand Up @@ -368,6 +389,7 @@ MODIFIER(max, ".max")
MODIFIER(op_or, ".or")
MODIFIER(op_xor, ".xor")
MODIFIER(op_and, ".and")
MODIFIER(op_popc, ".popc")
MODIFIER(cas, ".cas")
MODIFIER(exch, ".exch")
MODIFIER(inc, ".inc")
Expand Down Expand Up @@ -420,6 +442,7 @@ MODIFIER(ecr, ".ecr")
MODIFIER(rc16, ".rc16")
MODIFIER(cs, ".cs")
MODIFIER(to, ".to")
MODIFIER(aligned, ".aligned")

#undef LINKAGE
#undef TARGET
Expand Down
Loading