diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index fc9cb1be635e..a53c5d5d0d94 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -514,14 +514,17 @@ bool SYCLGenBase::emitType(const InlineAsmType *T) { bool SYCLGenBase::emitBuiltinType(const InlineAsmBuiltinType *T) { switch (T->getKind()) { // clang-format off + case InlineAsmBuiltinType::b1: OS() << "uint8_t"; break; case InlineAsmBuiltinType::b8: OS() << "uint8_t"; break; case InlineAsmBuiltinType::b16: OS() << "uint16_t"; break; case InlineAsmBuiltinType::b32: OS() << "uint32_t"; break; case InlineAsmBuiltinType::b64: OS() << "uint64_t"; break; + case InlineAsmBuiltinType::u4: OS() << "uint8_t"; break; case InlineAsmBuiltinType::u8: OS() << "uint8_t"; break; case InlineAsmBuiltinType::u16: OS() << "uint16_t"; break; case InlineAsmBuiltinType::u32: OS() << "uint32_t"; break; case InlineAsmBuiltinType::u64: OS() << "uint64_t"; break; + case InlineAsmBuiltinType::s4: OS() << "int8_t"; break; case InlineAsmBuiltinType::s8: OS() << "int8_t"; break; case InlineAsmBuiltinType::s16: OS() << "int16_t"; break; case InlineAsmBuiltinType::s32: OS() << "int32_t"; break; @@ -556,6 +559,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) { return SYCLGenError(); OS() << ", "; switch (T->getKind()) { + case InlineAsmVectorType::v1: + OS() << 1; + break; case InlineAsmVectorType::v2: OS() << 2; break; @@ -1305,6 +1311,386 @@ class SYCLGen : public SYCLGenBase { return SYCLGenSuccess(); } + bool handle_mma(const InlineAsmInstruction *Inst) override { + if (Inst->getNumInputOperands() != 3) + return SYCLGenError(); + + const InlineAsmVectorExpr *DMatVE = + dyn_cast(Inst->getOutputOperand()); + if (!DMatVE) + return SYCLGenError(); + + // Only row Layout is supported for of A matrix and + // only col Layout is supported for of B matrix + if (Inst->getAttr(3) != InstAttr::row || Inst->getAttr(4) != InstAttr::col) 
+ return SYCLGenError(); + + // Only f16 type is supported for A and B matrix data + const auto *DType = dyn_cast(Inst->getType(0)); + const auto *AType = dyn_cast(Inst->getType(1)); + const auto *BType = dyn_cast(Inst->getType(2)); + const auto *CType = dyn_cast(Inst->getType(3)); + + if (!(AType && BType && CType && DType)) + return SYCLGenError(); + + // Data types of matrix elements for A&B and C&D matrices should be same + if ((AType->getKind() != BType->getKind()) || + (CType->getKind() != DType->getKind())) + return SYCLGenError(); + + // Check the validity of AB & CD types + std::string ABType, CDType; + if (tryEmitType(ABType, AType)) + return SYCLGenError(); + + if (tryEmitType(CDType, CType)) + return SYCLGenError(); + + // Register sizes for vector elements of A, B, C & D matrices + unsigned NumVecElements[4] = {0}; + + // Sizes of A & B matrices + std::string M, N, K; + + // Operator for m8n8k128/m16n8k128/m16n8k256 + std::string MatrixOp; + + // Data type used to multiply A & B matrices + std::string MulType; + if (Inst->hasAttr(InstAttr::m8n8k4)) { + M = "8"; + N = "8"; + K = "4"; + // f16 & f64 types are supported for A and B matrices of m8n8k4 + if (AType->getKind() == InlineAsmBuiltinType::f16) { + // If A matrix type is f16, then C&D matrix types can only be f16/f32 + if (CType->getKind() == AType->getKind()) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else if (CType->getKind() == InlineAsmBuiltinType::f32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 8; // C + NumVecElements[3] = 8; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::f64) { + // If A matrix type is f64, then C&D matrix types can only be f64 + if (CType->getKind() == AType->getKind()) { + NumVecElements[0] = 1; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + } else + 
return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m8n8k16)) { + M = "8"; + N = "8"; + K = "16"; + // Only s8/u8 types are supported for A and B matrices of m8n8k16 + if (AType->getKind() == InlineAsmBuiltinType::s8 || + AType->getKind() == InlineAsmBuiltinType::u8) { + // If A matrix type is s8/u8, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 1; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m8n8k32)) { + M = "8"; + N = "8"; + K = "32"; + // Only s4/u4 types are supported for A and B matrices of m16n8k32 + if (AType->getKind() == InlineAsmBuiltinType::s4 || + AType->getKind() == InlineAsmBuiltinType::u4) { + // If A matrix type is s4/u4, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 1; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m8n8k128)) { + M = "8"; + N = "8"; + K = "128"; + // Only b1 type is supported for A and B matrices of m16n8k128 + if (AType->getKind() == InlineAsmBuiltinType::b1) { + // If A matrix type is b1, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 1; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + + // Only and/xor bitwise operations are supported for m8n8k128 + if (Inst->hasAttr(InstAttr::op_and)) + MatrixOp = "and"; + else if (Inst->hasAttr(InstAttr::op_xor)) + MatrixOp = "xor"; + else + return SYCLGenError(); + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k4)) { + M = "16"; + N = 
"8"; + K = "4"; + // Only f64 type is supported for A and B matrices of m16n8k4 + if (AType->getKind() == InlineAsmBuiltinType::f64) { + // If A matrix type is f64, then C&D matrix types can only be f64 + if (CType->getKind() == InlineAsmBuiltinType::f64) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k8)) { + M = "16"; + N = "8"; + K = "8"; + // Only f16/f64 types are supported for A and B matrices of m16n8k8 + if (AType->getKind() == InlineAsmBuiltinType::f16) { + // If A matrix type is f16, then C&D matrix types can only be f16/f32 + if (CType->getKind() == InlineAsmBuiltinType::f16) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + } else if (CType->getKind() == InlineAsmBuiltinType::f32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::f64) { + // If A matrix type is f64, then C&D matrix types can only be f64 + if (CType->getKind() == InlineAsmBuiltinType::f64) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k16)) { + M = "16"; + N = "8"; + K = "16"; + // Only f16/f64/s8/u8 type is supported for A and B matrices of m16n8k16 + if (AType->getKind() == InlineAsmBuiltinType::f16) { + // If A matrix type is f16, then C&D matrix types can only be f16/f32 + if (CType->getKind() == AType->getKind()) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + } else if (CType->getKind() == 
InlineAsmBuiltinType::f32) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::f64) { + // If A matrix type is f64, then C&D matrix types can only be f64 + if (CType->getKind() == AType->getKind()) { + NumVecElements[0] = 8; // A + NumVecElements[1] = 4; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::s8 || + AType->getKind() == InlineAsmBuiltinType::u8) { + // If A matrix type is s8/u8, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k32)) { + M = "16"; + N = "8"; + K = "32"; + // Only s4/s8/u4/u8 types are supported for A and B matrices of m16n8k32 + if (AType->getKind() == InlineAsmBuiltinType::s4 || + AType->getKind() == InlineAsmBuiltinType::u4) { + // If A matrix type is s4/u4, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::s8 || + AType->getKind() == InlineAsmBuiltinType::u8) { + // If A matrix type is s8/u8, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k64)) { + M = "16"; + N = "8"; + K = "64"; + // Only 
s4/u4 types are supported for A and B matrices of m16n8k64 + if (AType->getKind() == InlineAsmBuiltinType::s4 || + AType->getKind() == InlineAsmBuiltinType::u4) { + // If A matrix type is s4/u4, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k128)) { + M = "16"; + N = "8"; + K = "128"; + // Only b1 type is supported for A and B matrices of m16n8k128 + if (AType->getKind() == InlineAsmBuiltinType::b1) { + // If A matrix type is b1, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + + // Only and/xor bitwise operations are supported for m16n8k128 + if (Inst->hasAttr(InstAttr::op_and)) + MatrixOp = "and"; + else if (Inst->hasAttr(InstAttr::op_xor)) + MatrixOp = "xor"; + else + return SYCLGenError(); + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k256)) { + M = "16"; + N = "8"; + K = "256"; + // Only b1 type is supported for A and B matrices of m16n8k256 + if (AType->getKind() == InlineAsmBuiltinType::b1) { + // If A matrix type is b1, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + + // Only and/xor bitwise operations are supported for m16n8k256 + if (Inst->hasAttr(InstAttr::op_and)) + MatrixOp = "and"; + else if (Inst->hasAttr(InstAttr::op_xor)) + MatrixOp = "xor"; + else + return SYCLGenError(); + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else + return SYCLGenError(); + + // Check 
the register sizes for vector elements of A, B, C & D matrices + for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); + InputOp++) { + if (auto VE = + dyn_cast(Inst->getInputOperand(InputOp))) { + if (VE->getNumElements() != NumVecElements[InputOp]) + return SYCLGenError(); + } else + return SYCLGenError(); + } + if (DMatVE->getNumElements() != NumVecElements[3]) + return SYCLGenError(); + + MulType = ABType; + OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma"; + OS() << "<"; + OS() << M << ", " << N << ", " << K << ", "; + OS() << MulType; + if (!MatrixOp.empty()) { + OS() << ", sycl::bit_" << MatrixOp << "<>"; + } + OS() << ">("; + + // Add D matrix address values to store the MAD result + for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) { + if (isa(DMatVE->getElement(Inst))) + continue; + OS() << "&"; + if (emitStmt(DMatVE->getElement(Inst))) + return SYCLGenError(); + if ((Inst + 1) != DMatVE->getNumElements()) + OS() << ", "; + } + + // Add A, B & C matrix values to compute MAD + for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); + InputOp++) { + if (auto VE = + dyn_cast(Inst->getInputOperand(InputOp))) { + for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) { + if (isa(VE->getElement(Inst))) + continue; + OS() << ", "; + if (emitStmt(VE->getElement(Inst))) + return SYCLGenError(); + } + } else { + return SYCLGenError(); + } + } + + OS() << ");"; + + const auto *KernelDecl = getImmediateOuterFuncDecl(GAS); + if (KernelDecl) { + auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl); + if (FuncInfo) + FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(), + DpctGlobalInfo::getSubGroup(GAS)); + } + + return SYCLGenSuccess(); + } + bool handle_prefetch(const InlineAsmInstruction *Inst) override { if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1) return SYCLGenError(); @@ -2530,11 +2916,10 @@ class SYCLGen : public SYCLGenBase { Op = std::move(NewOp); } - bool 
HasHalfOrBfloat16 = - SrcType->getKind() == InlineAsmBuiltinType::f16 || - DesType->getKind() == InlineAsmBuiltinType::f16 || - SrcType->getKind() == InlineAsmBuiltinType::bf16 || - DesType->getKind() == InlineAsmBuiltinType::bf16; + bool HasHalfOrBfloat16 = SrcType->getKind() == InlineAsmBuiltinType::f16 || + DesType->getKind() == InlineAsmBuiltinType::f16 || + SrcType->getKind() == InlineAsmBuiltinType::bf16 || + DesType->getKind() == InlineAsmBuiltinType::bf16; if (DpctGlobalInfo::useIntelDeviceMath() && HasHalfOrBfloat16) { insertHeader(HeaderType::HT_SYCL_Math); if (SrcNeedBitCast) diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h index 1922185e50df..04f7485e3476 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h @@ -99,9 +99,9 @@ class InlineAsmBuiltinType : public InlineAsmType { return ((K == Kind) || ...); } template bool isNot(Ks... K) { return ((K != Kind) && ...); } - bool isBit() const { return isOneOf(b8, b16, b32, b64); } - bool isSigned() const { return isOneOf(s8, s16, s32, s64); } - bool isUnsigned() const { return isOneOf(u8, u16, u32, u64); } + bool isBit() const { return isOneOf(b1, b8, b16, b32, b64); } + bool isSigned() const { return isOneOf(s4, s8, s16, s32, s64); } + bool isUnsigned() const { return isOneOf(u4, u8, u16, u32, u64); } bool isInt() const { return isSigned() || isUnsigned(); } bool isFloat() const { return isOneOf(f16, f32, f64); } bool isScalar() const { return isInt() || isFloat(); } @@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType { // This class is used for device asm vector types. class InlineAsmVectorType : public InlineAsmType { public: - enum VecKind { v2, v4, v8 }; + enum VecKind { v1, v2, v4, v8 }; private: VecKind Kind; @@ -322,7 +322,7 @@ class InlineAsmInstruction : public InlineAsmStmt { /// This represents arrtibutes like: comparsion operator, rounding modifiers, /// ... e.g. 
instruction setp.eq.s32 has a comparsion operator 'eq'. - SmallSet Attributes; + SmallVector Attributes; /// This represents types in instruction, e.g. mov.u32. SmallVector Types; @@ -350,11 +350,11 @@ class InlineAsmInstruction : public InlineAsmStmt { OutputOp(Out), PredOutputOp(Pred), InputOps(InOps) { StateSpaces.insert(StateSpaces.begin(), AsmStateSpaces.begin(), AsmStateSpaces.end()); - Attributes.insert(Attrs.begin(), Attrs.end()); + Attributes.insert(Attributes.begin(), Attrs.begin(), Attrs.end()); } using attr_range = - llvm::iterator_range::const_iterator>; + llvm::iterator_range::const_iterator>; using type_range = llvm::iterator_range::const_iterator>; using op_range = @@ -369,12 +369,16 @@ class InlineAsmInstruction : public InlineAsmStmt { } template bool hasAttr(Ts... Attrs) const { - return (Attributes.contains(Attrs) || ...); + return (llvm::is_contained(Attributes, Attrs) || ...); } const InlineAsmIdentifierInfo *getOpcodeID() const { return Opcode; } asmtok::TokenKind getOpcode() const { return Opcode->getTokenID(); } ArrayRef getTypes() const { return Types; } const InlineAsmType *getType(unsigned I) const { return Types[I]; } + InstAttr getAttr(unsigned I) const { + assert(I < Attributes.size() && "Attributes index out of range"); + return Attributes[I]; + } unsigned getNumTypes() const { return Types.size(); } const InlineAsmExpr *getOutputOperand() const { return OutputOp; } const InlineAsmExpr *getPredOutputOperand() const { return PredOutputOp; } diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp index 8c8b7e9ff022..4b2e763bba47 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp @@ -739,6 +739,9 @@ InlineAsmParser::ActOnVectorExpr(ArrayRef Vec) { // Vector size must be 2, 4, or 8. 
InlineAsmVectorType::VecKind Kind; switch (Vec.size()) { + case 1: + Kind = InlineAsmVectorType::v1; + break; case 2: Kind = InlineAsmVectorType::v2; break; diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h index ca3196110015..f94d6c4a3df3 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h @@ -496,7 +496,7 @@ class InlineAsmParser { /// .reg .sreg .const .local .param .shared .tex /// /// vector-specifier: one of - /// .v2 .v4 .v8 + /// .v1 .v2 .v4 .v8 /// /// type-specifier: one of /// .b8 .b16 .b32 .b64 .s8 .s16 .s32 .s64 diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def index 563d5595ec65..4155daea4aca 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def @@ -240,14 +240,17 @@ SPECIAL_REG(warpid, "%warpid", s64) SPECIAL_REG(WARP_SZ, "WARP_SZ", s64) // Built-in type names +BUILTIN_TYPE(b1, ".b1") BUILTIN_TYPE(b8, ".b8") BUILTIN_TYPE(b16, ".b16") BUILTIN_TYPE(b32, ".b32") BUILTIN_TYPE(b64, ".b64") +BUILTIN_TYPE(u4, ".u4") BUILTIN_TYPE(u8, ".u8") BUILTIN_TYPE(u16, ".u16") BUILTIN_TYPE(u32, ".u32") BUILTIN_TYPE(u64, ".u64") +BUILTIN_TYPE(s4, ".s4") BUILTIN_TYPE(s8, ".s8") BUILTIN_TYPE(s16, ".s16") BUILTIN_TYPE(s32, ".s32") @@ -270,10 +273,28 @@ BUILTIN_TYPE(s16x2, ".s16x2") BUILTIN_TYPE(u16x2, ".u16x2") // Vector modifiers +MODIFIER(v1, ".v1") MODIFIER(v2, ".v2") MODIFIER(v4, ".v4") MODIFIER(v8, ".v8") +// Matrix modifiers +MODIFIER(row, ".row") +MODIFIER(col, ".col") + +// Matrix shape +MODIFIER(m8n8k4, ".m8n8k4") +MODIFIER(m8n8k16, ".m8n8k16") +MODIFIER(m8n8k32, ".m8n8k32") +MODIFIER(m8n8k128, ".m8n8k128") +MODIFIER(m16n8k4, ".m16n8k4") +MODIFIER(m16n8k8, ".m16n8k8") +MODIFIER(m16n8k16, ".m16n8k16") +MODIFIER(m16n8k32, ".m16n8k32") +MODIFIER(m16n8k64, ".m16n8k64") +MODIFIER(m16n8k128, ".m16n8k128") +MODIFIER(m16n8k256, ".m16n8k256") + 
STATE_SPACE(reg, ".reg") STATE_SPACE(sreg, ".sreg") STATE_SPACE(const, ".const") @@ -368,6 +389,7 @@ MODIFIER(max, ".max") MODIFIER(op_or, ".or") MODIFIER(op_xor, ".xor") MODIFIER(op_and, ".and") +MODIFIER(op_popc, ".popc") MODIFIER(cas, ".cas") MODIFIER(exch, ".exch") MODIFIER(inc, ".inc") @@ -420,6 +442,7 @@ MODIFIER(ecr, ".ecr") MODIFIER(rc16, ".rc16") MODIFIER(cs, ".cs") MODIFIER(to, ".to") +MODIFIER(aligned, ".aligned") #undef LINKAGE #undef TARGET diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index f23ee2d8e83a..a95f8b934e70 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -9,8 +9,8 @@ #ifndef __DPCT_MATH_HPP__ #define __DPCT_MATH_HPP__ -#include #include +#include #include #include @@ -1636,7 +1636,8 @@ inline constexpr unsigned extend_vcompare2_add(AT a, BT b, unsigned c, /// \returns The extend vectorized average of the two values template inline constexpr RetT extend_vavrg2(AT a, BT b, RetT c) { - return detail::extend_vbinary2(a, b, c, detail::average()); + return detail::extend_vbinary2(a, b, c, + detail::average()); } /// Compute vectorized average of \p a and \p b, with each value treated as a 2 @@ -1933,7 +1934,8 @@ inline constexpr unsigned extend_vcompare4_add(AT a, BT b, unsigned c, /// \returns The extend vectorized average of the two values template inline constexpr RetT extend_vavrg4(AT a, BT b, RetT c) { - return detail::extend_vbinary4(a, b, c, detail::average()); + return detail::extend_vbinary4(a, b, c, + detail::average()); } /// Compute vectorized average of \p a and \p b, with each value treated as a 4 @@ -2055,6 +2057,934 @@ class joint_matrix { matrix_accessor x; const size_t num_elements; }; + +/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b16 +/// matrix (m8n8k4.row.col.f16.f16.f16.f16) +/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of 
A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 2, 2, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +void mma(volatile CDType *d0, volatile CDType *d1, volatile CDType *d2, + volatile CDType *d3, ABType a0, ABType a1, ABType b0, ABType b1, + CDType c0, CDType c1, CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4); + + if (M == 8 && N == 8 && K == 4) { + ABType recv_a[2], recv_b[4]; + recv_a[0] = a0; + recv_a[1] = a1; + + MulType *ra = reinterpret_cast(recv_a); + MulType *rb = reinterpret_cast(recv_b); + + float c_f[8] = {0.0f}; + + for (int i = 0; i < 4; i++) { + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(sg, 
b0, COL_LOAD_OFFSET + 16 + i); + recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 + i); + + for (int j = 0; j < 4; j++) { + c_f[i] += static_cast(ra[j]) * static_cast(rb[j]); + c_f[i + 4] += static_cast(ra[j]) * static_cast(rb[j + 4]); + } + } + + auto c_h = reinterpret_cast(&c0); + c_f[0] += static_cast(c_h[0]); + c_f[1] += static_cast(c_h[1]); + c_h[0] = c_f[0]; + c_h[1] = c_f[1]; + + c_h = reinterpret_cast(&c1); + c_f[2] += static_cast(c_h[0]); + c_f[3] += static_cast(c_h[1]); + c_h[0] = c_f[2]; + c_h[1] = c_f[3]; + + c_h = reinterpret_cast(&c2); + c_f[4] += static_cast(c_h[0]); + c_f[5] += static_cast(c_h[1]); + c_h[0] = c_f[4]; + c_h[1] = c_f[5]; + + c_h = reinterpret_cast(&c3); + c_f[6] += static_cast(c_h[0]); + c_f[7] += static_cast(c_h[1]); + c_h[0] = c_f[6]; + c_h[1] = c_f[7]; + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b32 +/// matrix (m8n8k4.row.col.f32.f32.f32.f32) +/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 8, 2, 2, 8 +/// \tparam [in] ItemT The type of the sycl::nd_item index space class +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] d4 The 5th element to be written to the output D matrix +/// \param [in] d5 The 6th element to be written to the output D matrix +/// \param [in] d6 The 7th element to be written to the output D matrix +/// 
\param [in] d7 The 8th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +/// \param [in] c4 The 5th element from C matrix to be added with d4 +/// \param [in] c5 The 6th element from C matrix to be added with d5 +/// \param [in] c6 The 7th element from C matrix to be added with d6 +/// \param [in] c7 The 8th element from C matrix to be added with d7 +template +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5, + CDType *d6, CDType *d7, ABType a0, ABType a1, ABType b0, ABType b1, + CDType c0, CDType c1, CDType c2, CDType c3, CDType c4, CDType c5, + CDType c6, CDType c7) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2) + (lane % 2); + short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4) + 2 * ((lane / 2) % 2); + + if (M == 8 && N == 8 && K == 4) { + ABType recv_a[2 * 2], recv_b[4 * 2]; + + for (int i = 0; i < 2; i++) { + recv_a[2 * i] = + dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + 2 * i); + recv_a[2 * i + 1] = + dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + 2 * i); + + recv_b[4 * i] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 16 * i); + recv_b[4 * i + 1] = + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 * i); + recv_b[4 * i + 2] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 16 * i + 1); + recv_b[4 * i + 
3] = + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 * i + 1); + } + + MulType *ra = reinterpret_cast(recv_a); + MulType *rb = reinterpret_cast(recv_b); + for (int i = 0; i < 4; i++) { + c0 += static_cast(ra[i]) * static_cast(rb[i]); + c1 += static_cast(ra[i]) * static_cast(rb[i + 4]); + c2 += static_cast(ra[i + 4]) * static_cast(rb[i]); + c3 += static_cast(ra[i + 4]) * static_cast(rb[i + 4]); + c4 += static_cast(ra[i]) * static_cast(rb[i + 8]); + c5 += static_cast(ra[i]) * static_cast(rb[i + 12]); + c6 += static_cast(ra[i + 4]) * static_cast(rb[i + 8]); + c7 += static_cast(ra[i + 4]) * static_cast(rb[i + 12]); + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; + *d4 = c4; + *d5 = c5; + *d6 = c6; + *d7 = c7; +} + +/// Multiplies 2 8x4 & 4x8 f64 matrices and accumulates the result to a 8x8 b64 +/// matrix (m8n8k4.row.col.f64.f64.f64.f64). +/// Multiplies 2 8x16 & 16x8 u8/s8 matrices and accumulates the result to a 8x8 +/// s32 matrix (m8n8k16.row.col.s32.u8.u8.s32 / m8n8k16.row.col.s32.s8.s8.s32). +/// Multiplies 2 8x32 & 32x8 u4/s4 matrices and accumulates the result to a 8x8 +/// s32 matrix (m8n8k32.row.col.s32.u4.u4.s32 / m8n8k32.row.col.s32.s4.s4.s32). +/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 +/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc). +/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 +/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc). +/// Requires the sub-group size of kernel calling this function to be 32. 
+/// In: 2, 1, 1, 2 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +template , + typename ABType, typename CDType> +void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1, + Op op = Op{}) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 8 && N == 8 && K == 4) { + for (int i = 0; i < 4; i++) { + ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + c0 += recv_a * recv_b; + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + c1 += recv_a * recv_b; + } + } else if (M == 8 && N == 8 && K == 16) { + for (int i = 0; i < 4; i++) { + ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + + MulType *a = reinterpret_cast(&recv_a); + MulType *b = reinterpret_cast(&recv_b); + + for (int k = 0; k < 4; k++) { + c0 += a[k] * b[k]; + } + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + for (int k = 0; k < 4; k++) { + c1 += a[k] * b[k]; + } + } + } else if (M 
== 8 && N == 8 && K == 32) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a = + dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + + MulType *a = reinterpret_cast(&recv_a); + MulType *b = reinterpret_cast(&recv_b); + + for (int k = 0; k < 4; k++) { + MulType a0 = a[k] >> 4; + MulType a1 = a[k] & 0x0F; + MulType b0 = b[k] >> 4; + MulType b1 = b[k] & 0x0F; + + c0 += a0 * b0; + c0 += a1 * b1; + } + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + for (int k = 0; k < 4; k++) { + MulType a0 = a[k] >> 4; + MulType a1 = a[k] & 0x0F; + MulType b0 = b[k] >> 4; + MulType b1 = b[k] & 0x0F; + + c1 += a0 * b0; + c1 += a1 * b1; + } + } + } + } else if (M == 8 && N == 8 && K == 128) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a = + dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + + c0 += sycl::popcount(op(recv_a, recv_b)); + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c1 += sycl::popcount(op(recv_a, recv_b)); + } + } + } + + *d0 = c0; + *d1 = c1; +} + +/// Multiplies 2 16x8 & 8x8 f16 matrices and accumulates the result to a +/// 16x8 f16 matrix (m16n8k8.row.col.f16.f16.f16.f16) +/// Requires the sub-group size of kernel +/// calling this function to be 32 +/// In: 2, 2, 1, 2 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element 
from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +template +void mma(CDType *d0, CDType *d1, ABType a0, ABType a1, ABType b0, CDType c0, + CDType c1) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 8) { + auto c0_h = reinterpret_cast(&c0); + auto c1_h = reinterpret_cast(&c1); + + float c_f[4] = {c0_h[0], c0_h[1], c1_h[0], c1_h[1]}; + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 2; j++) { + c_f[0] += static_cast(ra[j]) * static_cast(rb[j]); + c_f[1] += static_cast(ra[j]) * static_cast(rb[j + 2]); + c_f[2] += static_cast(ra[j + 2]) * static_cast(rb[j]); + c_f[3] += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); + } + } + + c0_h[0] = c_f[0]; + c0_h[1] = c_f[1]; + c1_h[0] = c_f[2]; + c1_h[1] = c_f[3]; + } + + *d0 = c0; + *d1 = c1; +} + +/// Multiplies 2 16x16 & 16x8 f16 matrices and accumulates the result to a 16x8 +/// f16 matrix (m16n8k16.row.col.f16.f16.f16.f16). 
+/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 2, 4, 2, 2 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +template +void mma(volatile CDType *d0, volatile CDType *d1, ABType a0, ABType a1, + ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 16) { + auto c0_h = reinterpret_cast(&c0); + auto c1_h = reinterpret_cast(&c1); + + float c_f[4] = {c0_h[0], c0_h[1], c1_h[0], c1_h[1]}; + + for (int i = 0; i < 4; i++) { + ABType recv_a[4], recv_b[4]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[2] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[3] = 
dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 4; j++) { + c_f[0] += static_cast(ra[j]) * static_cast(rb[j]); + c_f[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); + c_f[2] += static_cast(ra[j + 4]) * static_cast(rb[j]); + c_f[3] += static_cast(ra[j + 4]) * static_cast(rb[j + 4]); + } + } + + c0_h[0] = c_f[0]; + c0_h[1] = c_f[1]; + c1_h[0] = c_f[2]; + c1_h[1] = c_f[3]; + } + + *d0 = c0; + *d1 = c1; +} + +/// Multiplies 2 16x8 & 8x8 u4/s4 matrices and accumulates the result to a 16x8 +/// f64 matrix (m16n8k8.row.col.f64.f64.f64.f64). +/// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32 +/// matrix (m16n8k16.row.col.f32.f16.f16.f32). +/// Multiplies 2 16x32 & 32x8 u8/s8 matrices and accumulates the result to a +/// 16x8 b32 matrix (m16n8k32.row.col.s32.u8.u8.s32 / +/// m16n8k32.row.col.s32.s8.s8.s32). +/// Multiplies 2 16x64 & 64x8 u4/s4 matrices and +/// accumulates the result to a 16x8 b32 matrix (m16n8k64.row.col.s32.u4.u4.s32 +/// / m16n8k64.row.col.s32.s4.s4.s32). +/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc). +/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc). +/// Requires the sub-group size of kernel calling this function to be 32. 
+/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 4, 2, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template , + typename ABType, typename CDType> +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, + CDType c2, CDType c3, Op op = Op{}) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 8) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); 
+ recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += recv_a[0] * recv_b[0]; + c1 += recv_a[0] * recv_b[1]; + c2 += recv_a[1] * recv_b[0]; + c3 += recv_a[1] * recv_b[1]; + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + c0 += recv_a[0] * recv_b[0]; + c1 += recv_a[0] * recv_b[1]; + c2 += recv_a[1] * recv_b[0]; + c3 += recv_a[1] * recv_b[1]; + } + } else if (M == 16 && N == 8 && K == 16) { + for (int i = 0; i < 4; i++) { + ABType recv_a[4], recv_b[4]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[2] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[3] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 4 + i); + recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 4 + i); + + auto *ra0 = reinterpret_cast(recv_a); + auto *ra1 = reinterpret_cast(recv_a + 2); + auto *rb0 = reinterpret_cast(recv_b); + auto *rb1 = reinterpret_cast(recv_b + 2); + + // Iterate for k (i * j) times + for (int j = 0; j < 4; j++) { + auto a0 = static_cast(ra0[j]); + auto a1 = static_cast(ra1[j]); + auto b0 = static_cast(rb0[j]); + auto b1 = static_cast(rb1[j]); + + c0 += a0 * b0; + c1 += a0 * b1; + c2 += a1 * b0; + c3 += a1 * b1; + } + } + } else if (M == 
16 && N == 8 && K == 32) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + c0 += a[k] * b[k]; + c1 += a[k] * b[k + 4]; + c2 += a[k + 4] * b[k]; + c3 += a[k + 4] * b[k + 4]; + } + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + c0 += a[k] * b[k]; + c1 += a[k] * b[k + 4]; + c2 += a[k + 4] * b[k]; + c3 += a[k + 4] * b[k + 4]; + } + } + } else if (M == 16 && N == 8 && K == 64) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; + + 
c0 += a00 * b00; + c0 += a01 * b01; + + c1 += a00 * b10; + c1 += a01 * b11; + + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; + + c0 += a00 * b00; + c0 += a01 * b01; + + c1 += a00 * b10; + c1 += a01 * b11; + + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } + } + } + } else if (M == 16 && N == 8 && K == 256) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(op(recv_a[0], recv_b[0])); + c1 += sycl::popcount(op(recv_a[0], recv_b[1])); + c2 += sycl::popcount(op(recv_a[1], recv_b[0])); + c3 += sycl::popcount(op(recv_a[1], recv_b[1])); + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] 
= + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(op(recv_a[0], recv_b[0])); + c1 += sycl::popcount(op(recv_a[0], recv_b[1])); + c2 += sycl::popcount(op(recv_a[1], recv_b[0])); + c3 += sycl::popcount(op(recv_a[1], recv_b[1])); + } + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x16 & 16x8 f64 matrices and accumulates the result to a 16x8 +/// f64 matrix (m16n8k16.row.col.f64.f64.f64.f64) Requires the sub-group size of +/// kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 8, 4, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] a4 The 5th element from A matrix to be multiplied with B matrix +/// \param [in] a5 The 6th element from A matrix to be multiplied with B matrix +/// \param [in] a6 The 7th element from A matrix to be multiplied with B matrix +/// \param [in] a7 The 8th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] 
b2 The 3rd element from B matrix to be multiplied with A matrix +/// \param [in] b3 The 4th element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType a2, ABType a3, ABType a4, ABType a5, ABType a6, ABType a7, + ABType b0, ABType b1, ABType b2, ABType b3, CDType c0, CDType c1, + CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 16) { + ABType recv_a[16 * 2], recv_b[16 * 2]; + + for (int i = 0; i < 4; i++) { + recv_a[i] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[i + 4] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[i + 8] = dpct::select_from_sub_group(sg, a4, ROW_LOAD_OFFSET + i); + recv_a[i + 12] = dpct::select_from_sub_group(sg, a6, ROW_LOAD_OFFSET + i); + recv_a[i + 16] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[i + 20] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_a[i + 24] = dpct::select_from_sub_group(sg, a5, ROW_LOAD_OFFSET + i); + recv_a[i + 28] = dpct::select_from_sub_group(sg, a7, ROW_LOAD_OFFSET + i); + + recv_b[i] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[i + 4] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[i + 8] = dpct::select_from_sub_group(sg, b2, COL_LOAD_OFFSET + i); + recv_b[i + 12] = dpct::select_from_sub_group(sg, b3, COL_LOAD_OFFSET + i); + recv_b[i + 16] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + recv_b[i + 20] = + 
dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + recv_b[i + 24] = + dpct::select_from_sub_group(sg, b2, COL_LOAD_OFFSET + i + 4); + recv_b[i + 28] = + dpct::select_from_sub_group(sg, b3, COL_LOAD_OFFSET + i + 4); + } + + for (int i = 0; i < 16; i++) { + c0 += recv_a[i] * recv_b[i]; + c1 += recv_a[i] * recv_b[i + 16]; + c2 += recv_a[i + 16] * recv_b[i]; + c3 += recv_a[i + 16] * recv_b[i + 16]; + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x4 & 4x8 f16 matrices and accumulates the result to a +/// 16x8 f32 matrix (m16n8k4.row.col.f16.f16.f16.f16 / +/// m16n8k4.row.col.f32.f16.f16.f32). +/// Multiplies 2 16x4 & 4x8 f64 matrices and accumulates the result to a +/// 16x8 f64 matrix (m16n8k4.row.col.f64.f64.f64.f64). +/// Multiplies 2 16x8 & 8x8 f16 matrices and accumulates the result to a +/// 16x8 f32 matrix (m16n8k8.row.col.f32.f16.f16.f32). +/// Multiplies 2 16x8 & 8x8 f64 matrices and accumulates the result to a +/// 16x8 f64 matrix (m16n8k8.row.col.f64.f64.f64.f64). +/// Multiplies 2 16x16 & 16x8 u8/s8 matrices and accumulates the result to a +/// 16x8 s32 matrix (m16n8k16.row.col.s32.u8.u8.s32 / +/// m16n8k16.row.col.s32.s8.s8.s32). +/// Multiplies 2 16x32 & 32x8 u4/s4 matrices and accumulates the result to a +/// 16x8 s32 matrix (m16n8k32.row.col.s32.u4.u4.s32 / +/// m16n8k32.row.col.s32.s4.s4.s32). +/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc). +/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc). +/// Requires the sub-group size of kernel. 
+/// calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 2, 1, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template , + typename ABType, typename CDType> +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType b0, CDType c0, CDType c1, CDType c2, CDType c3, Op op = Op{}) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 4) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += recv_a[0] * 
recv_b[0]; + c1 += recv_a[0] * recv_b[1]; + c2 += recv_a[1] * recv_b[0]; + c3 += recv_a[1] * recv_b[1]; + } + } else if (M == 16 && N == 8 && K == 8) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 2; j++) { + c0 += static_cast(ra[j]) * static_cast(rb[j]); + c1 += static_cast(ra[j]) * static_cast(rb[j + 2]); + c2 += static_cast(ra[j + 2]) * static_cast(rb[j]); + c3 += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); + } + } + } else if (M == 16 && N == 8 && K == 16) { + ABType recv_a[4 * 2], recv_b[4 * 2]; + + for (int i = 0; i < 4; i++) { + recv_a[i] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[i + 4] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + + recv_b[i] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[i + 4] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + } + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + for (int i = 0; i < 16; i++) { + c0 += a[i] * b[i]; + c1 += a[i] * b[i + 16]; + c2 += a[i + 16] * b[i]; + c3 += a[i + 16] * b[i + 16]; + } + } else if (M == 16 && N == 8 && K == 32) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = 
reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; + + c0 += a00 * b00; + c0 += a01 * b01; + + c1 += a00 * b10; + c1 += a01 * b11; + + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } + } + } + } else if (M == 16 && N == 8 && K == 128) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(op(recv_a[0], recv_b[0])); + c1 += sycl::popcount(op(recv_a[0], recv_b[1])); + c2 += sycl::popcount(op(recv_a[1], recv_b[0])); + c3 += sycl::popcount(op(recv_a[1], recv_b[1])); + } + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + } // namespace matrix } // namespace experimental diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu new file mode 100644 index 000000000000..a959e383b2b1 --- /dev/null +++ b/clang/test/dpct/asm/mma.cu @@ -0,0 +1,402 @@ +// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2 +// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2 +// RUN: dpct --format-range=none -out-root %T/mma %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only +// RUN: FileCheck %s --match-full-lines --input-file %T/mma/mma.dp.cpp +// RUN: %if build_lit %{icpx -c -DNO_BUILD_TEST -fsycl %T/mma/mma.dp.cpp -o %T/mma/mma.dp.o %} + +// clang-format off +#include +#include + +/* +As per PTX ASM 8.1, below is the status of supported configurations + 
---------   ---------    ----------   -----------  -------------
| Shape |   |   A   |    |   B    |   |  C / D  |  | Supported |
---------   ---------    ----------   -----------  -------------
m8n8k4      .f16         .f16         .f16/.f32    Yes
            .f64         .f64         .f64         Yes
m8n8k16     .s8/.u8      .s8/.u8      .s32         Yes
m8n8k32     .s4/.u4      .s4/.u4      .s32         Yes
m8n8k128    .b1          .b1          .s32         Yes

m16n8k4     .tf32        .tf32        .tf32        No
            .f64         .f64         .f64         Yes
m16n8k8     .f16/.bf16   .f16/.bf16   .f16/.f32    Partial (.f16.f16.f16.f16 / .f32.f16.f16.f32)
            .tf32        .tf32        .tf32        No
            .f64         .f64         .f64         Yes
m16n8k16    .f16/.bf16   .f16/.bf16   .f16/.f32    Partial (.f16.f16.f16.f16 / .f32.f16.f16.f32)
            .f64         .f64         .f64         Yes
            .s8/.u8      .s8/.u8      .s32         Yes
m16n8k32    .s4/.u4      .s4/.u4      .s32         Yes
            .s8/.u8      .s8/.u8      .s32         Yes
m16n8k64    .s4/.u4      .s4/.u4      .s32         Yes
m16n8k128   .b1          .b1          .s32         Yes
m16n8k256   .b1          .b1          .s32         Yes

A Layout: row
B Layout: col
*/

// Exercises migration of m8n8k4 mma variants (f16, f32-accumulate, f64).
__global__ void mma_kernel_m8n8k4(int *a, int *b, float *c, double *d) {
  // CHECK: dpct::experimental::matrix::mma<8, 8, 4, sycl::half>(&c[0], &c[1], &c[2], &c[3], a[0], a[1], b[0], b[1], c[0], c[1], c[2], c[3]);
  asm("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 "
      " { %0, %1, %2, %3 }, "
      " { %4, %5 }, "
      " { %6, %7 }, "
      " { %8, %9, %10, %11 };"
      : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(b[1]),
        "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));

  // CHECK: dpct::experimental::matrix::mma<8, 8, 4, sycl::half>(&c[0], &c[1], &c[2], &c[3], &c[4], &c[5], &c[6], &c[7], a[0], a[1], b[0], b[1], c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]);
  asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
      " { %0, %1, %2, %3, %4, %5, %6, %7 }, "
      " { %8, %9 }, "
      " { %10, %11 }, "
      " { %0, %1, %2, %3, %4, %5, %6, %7 };"
      : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3]), "+f"(c[4]), "+f"(c[5]), "+f"(c[6]), "+f"(c[7])
      : "r"(a[0]), "r"(a[1]),
        "r"(b[0]), "r"(b[1]));

  // CHECK: dpct::experimental::matrix::mma<8, 8, 4, double>(&d[0], &d[1], a[0], b[0], d[0], d[1]);
  asm("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 "
      " { %0, %1 }, "
      " { %2 }, "
      " { %3 }, "
      " { %4, %5 };"
      : "=d"(d[0]), "=d"(d[1])
      : "d"(a[0]),
        "d"(b[0]),
        "d"(d[0]), "d"(d[1]));
}

// Exercises migration of m8n8k16 mma variants (s8, u8).
__global__ void mma_kernel_m8n8k16(int *a, int *b, int *c, int *d) {
  // CHECK: dpct::experimental::matrix::mma<8, 8, 16, int8_t>(&d[0], &d[1], a[0], b[0], c[0], c[1]);
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 "
      " { %0, %1 }, "
      " { %2 }, "
      " { %3 }, "
      " { %4, %5 };"
      : "=r"(d[0]), "=r"(d[1])
      : "r"(a[0]),
        "r"(b[0]),
        "r"(c[0]), "r"(c[1]));

  // CHECK: dpct::experimental::matrix::mma<8, 8, 16, uint8_t>(&d[0], &d[1], a[0], b[0], c[0], c[1]);
  asm("mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 "
      " { %0, %1 }, "
      " { %2 }, "
      " { %3 }, "
      " { %4, %5 };"
      : "=r"(d[0]), "=r"(d[1])
      : "r"(a[0]),
        "r"(b[0]),
        "r"(c[0]), "r"(c[1]));
}

// Exercises migration of m8n8k32 mma variants (s4, u4).
__global__ void mma_kernel_m8n8k32(int *a, int *b, int *c, int *d) {
  // CHECK: dpct::experimental::matrix::mma<8, 8, 32, int8_t>(&d[0], &d[1], a[0], b[0], c[0], c[1]);
  asm("mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 "
      " { %0, %1 }, "
      " { %2 }, "
      " { %3 }, "
      " { %4, %5 };"
      : "=r"(d[0]), "=r"(d[1])
      : "r"(a[0]),
        "r"(b[0]),
        "r"(c[0]), "r"(c[1]));

  // CHECK: dpct::experimental::matrix::mma<8, 8, 32, uint8_t>(&d[0], &d[1], a[0], b[0], c[0], c[1]);
  asm("mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 "
      " { %0, %1 }, "
      " { %2 }, "
      " { %3 }, "
      " { %4, %5 };"
      : "=r"(d[0]), "=r"(d[1])
      : "r"(a[0]),
        "r"(b[0]),
        "r"(c[0]), "r"(c[1]));
}
dpct::experimental::matrix::mma<8, 8, 128, uint8_t, sycl::bit_xor<>>(&d[0], &d[1], a[0], b[0], c[0], c[1]); + asm ("mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " + " { %0, %1 }, " + " { %2 }, " + " { %3 }, " + " { %4, %5 };" + : "=r"(d[0]), "=r"(d[1]) + : "r"(a[0]), + "r"(b[0]), + "r"(c[0]), "r"(c[1])); +} + +__global__ void mma_kernel_m16n8k4(float *fa, float *fb, float *fc, float *fd, double *da, double *db, double *dc, double *dd) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 4, double>(&dd[0], &dd[1], &dd[2], &dd[3], da[0], da[1], db[0], dc[0], dc[1], dc[2], dc[3]); + asm("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=d"(dd[0]), "=d"(dd[1]), "=d"(dd[2]), "=d"(dd[3]) + : "d"(da[0]), "d"(da[1]), + "d"(db[0]), + "d"(dc[0]), "d"(dc[1]), "d"(dc[2]), "d"(dc[3])); +} + +__global__ void mma_kernel_m16n8k8(int *a, int *b, uint *c, uint *d, float *fa, float *fb, float *fc, float *fd, double *da, double *db, double *dc, double *dd) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 8, sycl::half>(&d[0], &d[1], (*(reinterpret_cast(&a[0]))), (*(reinterpret_cast(&a[1]))), (*(reinterpret_cast(&b[0]))), c[0], c[1]); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + " { %0, %1 }, " + " { %2, %3 }, " + " { %4 }, " + " { %5, %6 };" + : "=r"(d[0]), "=r"(d[1]) + : "r"(*(reinterpret_cast(&a[0]))), + "r"(*(reinterpret_cast(&a[1]))), + "r"(*(reinterpret_cast(&b[0]))), + "r"(c[0]), "r"(c[1])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 8, sycl::half>(&fd[0], &fd[1], &fd[2], &fd[3], *(reinterpret_cast(&a[0])), *(reinterpret_cast(&a[1])), *(reinterpret_cast(&b[0])), fc[0], fc[1], fc[2], fc[3]); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=f"(fd[0]), "=f"(fd[1]), "=f"(fd[2]), "=f"(fd[3]) + : "r"(*(reinterpret_cast(&a[0]))), + 
"r"(*(reinterpret_cast(&a[1]))), + "r"(*(reinterpret_cast(&b[0]))), + "f"(fc[0]), "f"(fc[1]), "f"(fc[2]), "f"(fc[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 8, double>(&dd[0], &dd[1], &dd[2], &dd[3], da[0], da[1], da[2], da[3], db[0], db[1], dc[0], dc[1], dc[2], dc[3]); + asm("mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=d"(dd[0]), "=d"(dd[1]), "=d"(dd[2]), "=d"(dd[3]) + : "d"(da[0]), "d"(da[1]), "d"(da[2]), "d"(da[3]), + "d"(db[0]), "d"(db[1]), + "d"(dc[0]), "d"(dc[1]), "d"(dc[2]), "d"(dc[3])); +} + +__global__ void mma_kernel_m16n8k16(uint *ua, uint *ub, uint *uc, uint *ud, int *a, int *b, int *c, float *fc, int *d, double *da, double *db, double *dc, double *dd) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, sycl::half>(&ud[0], &ud[1], a[0], a[1], a[2], a[3], b[0], b[1], uc[0], uc[1]); + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + " { %0, %1 }, " + " { %2, %3, %4, %5 }, " + " { %6, %7 }, " + " { %8, %9 };" + : "=r"(ud[0]), "=r"(ud[1]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(uc[0]), "r"(uc[1])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, sycl::half>(&fc[0], &fc[1], &fc[2], &fc[3], a[0], a[1], a[2], a[3], b[0], b[1], fc[0], fc[1], fc[2], fc[3]); + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %0, %1, %2, %3 };" + : "+f"(fc[0]), "+f"(fc[1]), "+f"(fc[2]), "+f"(fc[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, double>(&dd[0], &dd[1], &dd[2], &dd[3], da[0], da[1], da[2], da[3], da[4], da[5], da[6], da[7], db[0], db[1], db[2], db[3], dc[0], dc[1], dc[2], dc[3]); + asm("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7, %8, %9, %10, %11 }, " + " { %12, 
%13, %14, %15 }, " + " { %16, %17, %18, %19 };" + : "=d"(dd[0]), "=d"(dd[1]), "=d"(dd[2]), "=d"(dd[3]) + : "d"(da[0]), "d"(da[1]), "d"(da[2]), "d"(da[3]), "d"(da[4]), "d"(da[5]), "d"(da[6]), "d"(da[7]), + "d"(db[0]), "d"(db[1]), "d"(db[2]), "d"(db[3]), + "d"(dc[0]), "d"(dc[1]), "d"(dc[2]), "d"(dc[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, uint8_t>(&ud[0], &ud[1], &ud[2], &ud[3], ua[0], ua[1], ub[0], uc[0], uc[1], uc[2], uc[3]); + asm("mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(ud[0]), "=r"(ud[1]), "=r"(ud[2]), "=r"(ud[3]) + : "r"(ua[0]), "r"(ua[1]), + "r"(ub[0]), + "r"(uc[0]), "r"(uc[1]), "r"(uc[2]), "r"(uc[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, int8_t>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], b[0], c[0], c[1], c[2], c[3]); + asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +__global__ void mma_kernel_m16n8k32(uint *ua, uint *ub, uint *uc, uint *ud, int *a, int *b, int *c, int *d) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 32, uint8_t>(&ud[0], &ud[1], &ud[2], &ud[3], ua[0], ua[1], ub[0], uc[0], uc[1], uc[2], uc[3]); + asm("mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(ud[0]), "=r"(ud[1]), "=r"(ud[2]), "=r"(ud[3]) + : "r"(ua[0]), "r"(ua[1]), + "r"(ub[0]), + "r"(uc[0]), "r"(uc[1]), "r"(uc[2]), "r"(uc[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 32, int8_t>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], b[0], c[0], c[1], c[2], c[3]); + asm("mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(d[0]), 
"=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 32, uint8_t>(&ud[0], &ud[1], &ud[2], &ud[3], ua[0], ua[1], ua[2], ua[3], ub[0], ub[1], uc[0], uc[1], uc[2], uc[3]); + asm("mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(ud[0]), "=r"(ud[1]), "=r"(ud[2]), "=r"(ud[3]) + : "r"(ua[0]), "r"(ua[1]), "r"(ua[2]), "r"(ua[3]), + "r"(ub[0]), "r"(ub[1]), + "r"(uc[0]), "r"(uc[1]), "r"(uc[2]), "r"(uc[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 32, int8_t>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], a[2], a[3], b[0], b[1], c[0], c[1], c[2], c[3]); + asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +__global__ void mma_kernel_m16n8k64(uint *ua, uint *ub, uint *uc, uint *ud, int *a, int *b, int *c, int *d) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 64, uint8_t>(&ud[0], &ud[1], &ud[2], &ud[3], ua[0], ua[1], ua[2], ua[3], ub[0], ub[1], uc[0], uc[1], uc[2], uc[3]); + asm("mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(ud[0]), "=r"(ud[1]), "=r"(ud[2]), "=r"(ud[3]) + : "r"(ua[0]), "r"(ua[1]), "r"(ua[2]), "r"(ua[3]), + "r"(ub[0]), "r"(ub[1]), + "r"(uc[0]), "r"(uc[1]), "r"(uc[2]), "r"(uc[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 64, int8_t>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], a[2], a[3], b[0], b[1], c[0], c[1], c[2], c[3]); + asm("mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " 
{ %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +__global__ void mma_kernel_m16n8k128(int *a, int *b, int *c, int *d) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 128, uint8_t, sycl::bit_and<>>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], b[0], c[0], c[1], c[2], c[3]); + asm ("mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 128, uint8_t, sycl::bit_xor<>>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], b[0], c[0], c[1], c[2], c[3]); + asm ("mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +__global__ void mma_kernel_m16n8k256(int *a, int *b, int *c, int *d) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 256, uint8_t, sycl::bit_and<>>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], a[2], a[3], b[0], b[1], c[0], c[1], c[2], c[3]); + asm ("mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 256, uint8_t, sycl::bit_xor<>>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], a[2], a[3], b[0], b[1], c[0], c[1], c[2], c[3]); + asm ("mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc 
" + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + + +int main () { + uint *uint_a, *uint_b, *uint_c, *uint_d; + int *int_a, *int_b, *int_c, *int_d; + float *float_a, *float_b, *float_c, *float_d; + double *double_a, *double_b, *double_c, *double_d; + + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m8n8k4<<<1, 32>>>(int_a, int_b, float_c, double_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m8n8k16<<<1, 32>>>(int_a, int_b, int_c, int_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m8n8k32<<<1, 32>>>(int_a, int_b, int_c, int_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m8n8k128<<<1, 32>>>(int_a, int_b, int_c, int_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k4<<<1, 32>>>(float_a, float_b, float_c, float_d, double_a, double_b, double_c, double_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k8<<<1, 32>>>(int_a, int_b, uint_c, uint_d, float_a, float_b, float_c, float_d, double_a, double_b, double_c, double_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k16<<<1, 32>>>(uint_a, uint_b, uint_c, uint_d, int_a, int_b, int_c, float_c, int_d, double_a, double_b, double_c, double_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k32<<<1, 32>>>(uint_a, uint_b, uint_c, uint_d, int_a, int_b, int_c, int_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) 
{{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k64<<<1, 32>>>(uint_a, uint_b, uint_c, uint_d, int_a, int_b, int_c, int_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k128<<<1, 32>>>(int_a, int_b, int_c, int_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k256<<<1, 32>>>(int_a, int_b, int_c, int_d); + + return 0; +} +// clang-format on diff --git a/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv b/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv index 0cd876f76810..24f2adac899d 100644 --- a/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv +++ b/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv @@ -54,7 +54,7 @@ max,YES, mbarrier,NO, membar,YES, Partial min,YES, -mma,NO, +mma,YES, Partial mov,YES, movmatrix,NO, mul,YES, Partial