From 4896387f16887a53440883b3d7e532d41f9cd298 Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Tue, 25 Mar 2025 19:53:02 +0800 Subject: [PATCH 1/6] Added support for mma migration --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 77 +++++++++++++++++++ clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h | 12 ++- .../DPCT/RulesAsm/Parser/AsmTokenKinds.def | 8 ++ clang/runtime/dpct-rt/include/dpct/math.hpp | 73 ++++++++++++++++++ clang/test/dpct/asm/mma.cu | 45 +++++++++++ .../ASM_API_migration_status.csv | 2 +- 6 files changed, 212 insertions(+), 5 deletions(-) create mode 100644 clang/test/dpct/asm/mma.cu diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index fc9cb1be635e..7f1810f023ce 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -1305,6 +1305,83 @@ class SYCLGen : public SYCLGenBase { return SYCLGenSuccess(); } + bool handle_mma(const InlineAsmInstruction *Inst) override { + if (Inst->getNumInputOperands() != 3) + return SYCLGenError(); + + if (!Inst->hasAttr(InstAttr::m16n8k16)) + return SYCLGenError(); + + // Only row Layout is supported for of A matrix and + // only col Layout is supported for of B matrix + if (Inst->getAttr(3) != InstAttr::row || + Inst->getAttr(4) != InstAttr::col) { + return SYCLGenError(); + } + + // Only f16 type is supported for A and B matrix data + const auto *AType = dyn_cast(Inst->getType(1)); + const auto *BType = dyn_cast(Inst->getType(2)); + + std::string TypeStr; + if (!AType || !BType || + (AType->getKind() != InlineAsmBuiltinType::f16 || + BType->getKind() != InlineAsmBuiltinType::f16)) { + return SYCLGenError(); + } else { + if (tryEmitType(TypeStr, AType)) + return SYCLGenError(); + } + + const InlineAsmVectorExpr *VE = + dyn_cast(Inst->getOutputOperand()); + if (VE && VE->getNumElements() != 4) { + return SYCLGenError(); + } + + OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma"; + OS() << "<" << TypeStr << ">("; + + // Add D matrix address values to store the MAD result + for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) { + if (isa(VE->getElement(Inst))) + continue; + OS() << "&"; + if (emitStmt(VE->getElement(Inst))) + return SYCLGenError(); + OS() << ", "; + } + + // Add A, B & C matrix values to compute MAD + for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); + InputOp++) { + if (VE = dyn_cast(Inst->getInputOperand(InputOp))) { + for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) { + if (isa(VE->getElement(Inst))) + continue; + if (emitStmt(VE->getElement(Inst))) + return SYCLGenError(); + OS() << ", "; + } + } else { + return SYCLGenError(); + } + } + + OS() << DpctGlobalInfo::getItem(GAS); + OS() << ");"; + + const auto *KernelDecl = getImmediateOuterFuncDecl(GAS); + if (KernelDecl) { + auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl); + if (FuncInfo) + FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(), + DpctGlobalInfo::getSubGroup(GAS)); + } + + return SYCLGenSuccess(); + } + bool handle_prefetch(const InlineAsmInstruction *Inst) override { if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1) return SYCLGenError(); diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h index 1922185e50df..6f8ab9a7b5b3 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h @@ -322,7 +322,7 @@ class InlineAsmInstruction : public InlineAsmStmt { /// This represents arrtibutes like: comparsion operator, rounding modifiers, /// ... e.g. instruction setp.eq.s32 has a comparsion operator 'eq'. - SmallSet Attributes; + SmallVector Attributes; /// This represents types in instruction, e.g. mov.u32. SmallVector Types; @@ -350,11 +350,11 @@ class InlineAsmInstruction : public InlineAsmStmt { OutputOp(Out), PredOutputOp(Pred), InputOps(InOps) { StateSpaces.insert(StateSpaces.begin(), AsmStateSpaces.begin(), AsmStateSpaces.end()); - Attributes.insert(Attrs.begin(), Attrs.end()); + Attributes.insert(Attributes.begin(), Attrs.begin(), Attrs.end()); } using attr_range = - llvm::iterator_range::const_iterator>; + llvm::iterator_range::const_iterator>; using type_range = llvm::iterator_range::const_iterator>; using op_range = @@ -369,12 +369,16 @@ class InlineAsmInstruction : public InlineAsmStmt { } template bool hasAttr(Ts... Attrs) const { - return (Attributes.contains(Attrs) || ...); + return (llvm::is_contained(Attributes, Attrs) || ...); } const InlineAsmIdentifierInfo *getOpcodeID() const { return Opcode; } asmtok::TokenKind getOpcode() const { return Opcode->getTokenID(); } ArrayRef getTypes() const { return Types; } const InlineAsmType *getType(unsigned I) const { return Types[I]; } + InstAttr getAttr(unsigned I) const { + assert(I < Attributes.size() && "Attributes index out of range"); + return Attributes[I]; + } unsigned getNumTypes() const { return Types.size(); } const InlineAsmExpr *getOutputOperand() const { return OutputOp; } const InlineAsmExpr *getPredOutputOperand() const { return PredOutputOp; } diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def index 563d5595ec65..75d16636edb3 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def @@ -274,6 +274,13 @@ MODIFIER(v2, ".v2") MODIFIER(v4, ".v4") MODIFIER(v8, ".v8") +// Matrix modifiers +MODIFIER(row, ".row") +MODIFIER(col, ".col") + +// Matrix shape +MODIFIER(m16n8k16, ".m16n8k16") + STATE_SPACE(reg, ".reg") STATE_SPACE(sreg, ".sreg") STATE_SPACE(const, ".const") @@ -420,6 +427,7 @@ MODIFIER(ecr, ".ecr") MODIFIER(rc16, ".rc16") MODIFIER(cs, ".cs") MODIFIER(to, ".to") +MODIFIER(aligned, ".aligned") #undef LINKAGE #undef TARGET diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index f23ee2d8e83a..be5fb29d5cce 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2055,6 +2055,79 @@ class joint_matrix { matrix_accessor x; const size_t num_elements; }; + +/// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32 +/// matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \tparam [in] ItemT The type of the sycl::nd_item index space class +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +/// \param [in] item The sycl::nd_item index space class +template +__attribute__((optnone)) void +mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, + CDType c3, const ItemT &item) { + int lane = item.get_sub_group().get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane / 4); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + ABType recv_a[4 * 4], recv_b[4 * 4]; + for (int i = 0; i < 4; i++) { + recv_a[0 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a0, + ROW_LOAD_OFFSET + i); + recv_a[1 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a2, + ROW_LOAD_OFFSET + i); + recv_a[2 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a1, + ROW_LOAD_OFFSET + i); + recv_a[3 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a3, + ROW_LOAD_OFFSET + i); + + recv_b[0 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b0, + COL_LOAD_OFFSET + i); + recv_b[1 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b1, + COL_LOAD_OFFSET + i); + recv_b[2 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b0, + COL_LOAD_OFFSET + 4 + i); + recv_b[3 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b1, + COL_LOAD_OFFSET + 4 + i); + } + + auto *a = reinterpret_cast(recv_a); + auto *b = reinterpret_cast(recv_b); + for (int i = 0; i < 16; i++) { + auto a0 = static_cast(a[i]); + auto a1 = static_cast(a[i + 16]); + auto b0 = static_cast(b[i]); + auto b1 = static_cast(b[i + 16]); + + c0 += a0 * b0; + c1 += a0 * b1; + c2 += a1 * b0; + c3 += a1 * b1; + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + } // namespace matrix } // namespace experimental diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu new file mode 100644 index 000000000000..75d2d8ab5aba --- /dev/null +++ b/clang/test/dpct/asm/mma.cu @@ -0,0 +1,45 @@ +// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2 +// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2 +// RUN: dpct --format-range=none -out-root %T/mma %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only +// RUN: FileCheck %s --match-full-lines --input-file %T/mma/mma.dp.cpp +// RUN: %if build_lit %{icpx -c -DNO_BUILD_TEST -fsycl %T/mma/mma.dp.cpp -o %T/mma/mma.dp.o %} + +// clang-format off +#include +#include + +/* +mma.sync.aligned.m16n8k16.alayout.blayout.dtype.f16.f16.ctype d, a, b, c; + +Below are the currenly supported configurations: + +.alayout = {.row}; +.blayout = {.col}; +.ctype = {.f32}; +.dtype = {.f32}; +*/ + +__global__ void mma_kernel() { + int a[4]; + int b[2]; + float c[4]; + + // CHECK: dpct::experimental::matrix::mma(&c[0], &c[1], &c[2], &c[3], a[0], a[1], a[2], a[3], b[0], b[1], c[0], c[1], c[2], c[3], item_ct1); + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %0, %1, %2, %3 };" + : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1])); +} + + +int main () { + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel<<<1, 32>>>(); + + return 0; +} +// clang-format on diff --git a/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv b/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv index 0cd876f76810..24f2adac899d 100644 --- a/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv +++ b/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv @@ -54,7 +54,7 @@ max,YES, mbarrier,NO, membar,YES, Partial min,YES, -mma,NO, +mma,YES, Partial mov,YES, movmatrix,NO, mul,YES, Partial From 44f54344f571f886b3c7bd426fe970d9a3257188 Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Tue, 15 Apr 2025 12:06:26 +0800 Subject: [PATCH 2/6] Added comments and removed optnone attr --- clang/runtime/dpct-rt/include/dpct/math.hpp | 76 +++++++++++---------- 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index be5fb29d5cce..a247289ae21e 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2058,6 +2058,7 @@ class joint_matrix { /// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32 /// matrix +/// Requires the sub-group size of kernel calling this function to be 32 /// \tparam [in] MulType The type of the multiplication result /// \tparam [in] ABType The type of the input matrices /// \tparam [in] CDType The type of the output matrix @@ -2078,48 +2079,51 @@ class joint_matrix { /// \param [in] c3 The 4th element from C matrix to be added with d3 /// \param [in] item The sycl::nd_item index space class template -__attribute__((optnone)) void -mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, - CDType c3, const ItemT &item) { +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, + CDType c2, CDType c3, const ItemT &item) { int lane = item.get_sub_group().get_local_linear_id(); short ROW_LOAD_OFFSET = 4 * (lane / 4); short COL_LOAD_OFFSET = 8 * (lane % 4); - ABType recv_a[4 * 4], recv_b[4 * 4]; for (int i = 0; i < 4; i++) { - recv_a[0 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a0, - ROW_LOAD_OFFSET + i); - recv_a[1 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a2, - ROW_LOAD_OFFSET + i); - recv_a[2 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a1, - ROW_LOAD_OFFSET + i); - recv_a[3 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a3, - ROW_LOAD_OFFSET + i); - - recv_b[0 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b0, - COL_LOAD_OFFSET + i); - recv_b[1 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b1, - COL_LOAD_OFFSET + i); - recv_b[2 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b0, - COL_LOAD_OFFSET + 4 + i); - recv_b[3 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b1, - COL_LOAD_OFFSET + 4 + i); - } - - auto *a = reinterpret_cast(recv_a); - auto *b = reinterpret_cast(recv_b); - for (int i = 0; i < 16; i++) { - auto a0 = static_cast(a[i]); - auto a1 = static_cast(a[i + 16]); - auto b0 = static_cast(b[i]); - auto b1 = static_cast(b[i + 16]); - - c0 += a0 * b0; - c1 += a0 * b1; - c2 += a1 * b0; - c3 += a1 * b1; + ABType recv_a[4], recv_b[4]; + + recv_a[0] = dpct::select_from_sub_group(item.get_sub_group(), a0, + ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(item.get_sub_group(), a2, + ROW_LOAD_OFFSET + i); + recv_a[2] = dpct::select_from_sub_group(item.get_sub_group(), a1, + ROW_LOAD_OFFSET + i); + recv_a[3] = dpct::select_from_sub_group(item.get_sub_group(), a3, + ROW_LOAD_OFFSET + i); + + recv_b[0] = dpct::select_from_sub_group(item.get_sub_group(), b0, + COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(item.get_sub_group(), b1, + COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(item.get_sub_group(), b0, + COL_LOAD_OFFSET + 4 + i); + recv_b[3] = dpct::select_from_sub_group(item.get_sub_group(), b1, + COL_LOAD_OFFSET + 4 + i); + + auto *ra0 = reinterpret_cast(recv_a); + auto *ra1 = reinterpret_cast(recv_a + 2); + auto *rb0 = reinterpret_cast(recv_b); + auto *rb1 = reinterpret_cast(recv_b + 2); + + for (int j = 0; j < 2 * 2; j++) { + auto a0 = static_cast(ra0[j]); + auto a1 = static_cast(ra1[j]); + auto b0 = static_cast(rb0[j]); + auto b1 = static_cast(rb1[j]); + + c0 += a0 * b0; + c1 += a0 * b1; + c2 += a1 * b0; + c3 += a1 * b1; + } } *d0 = c0; From 891e459a883fd396625890e31f7fbe40d32b45be Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Tue, 22 Apr 2025 17:30:44 +0800 Subject: [PATCH 3/6] Added support for m8n8k4 shape --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 117 +++++++--- clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h | 2 +- clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp | 3 + clang/lib/DPCT/RulesAsm/Parser/AsmParser.h | 2 +- .../DPCT/RulesAsm/Parser/AsmTokenKinds.def | 2 + clang/runtime/dpct-rt/include/dpct/math.hpp | 206 +++++++++++++++++- 6 files changed, 300 insertions(+), 32 deletions(-) diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index 7f1810f023ce..a8a3e594997f 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -556,6 +556,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) { return SYCLGenError(); OS() << ", "; switch (T->getKind()) { + case InlineAsmVectorType::v1: + OS() << 1; + break; case InlineAsmVectorType::v2: OS() << 2; break; @@ -1309,45 +1312,109 @@ class SYCLGen : public SYCLGenBase { if (Inst->getNumInputOperands() != 3) return SYCLGenError(); - if (!Inst->hasAttr(InstAttr::m16n8k16)) + const InlineAsmVectorExpr *DMatVE = + dyn_cast(Inst->getOutputOperand()); + if (!DMatVE) return SYCLGenError(); // Only row Layout is supported for of A matrix and // only col Layout is supported for of B matrix - if (Inst->getAttr(3) != InstAttr::row || - Inst->getAttr(4) != InstAttr::col) { + if (Inst->getAttr(3) != InstAttr::row || Inst->getAttr(4) != InstAttr::col) return SYCLGenError(); - } // Only f16 type is supported for A and B matrix data + const auto *DType = dyn_cast(Inst->getType(0)); const auto *AType = dyn_cast(Inst->getType(1)); const auto *BType = dyn_cast(Inst->getType(2)); + const auto *CType = dyn_cast(Inst->getType(3)); - std::string TypeStr; - if (!AType || !BType || - (AType->getKind() != InlineAsmBuiltinType::f16 || - BType->getKind() != InlineAsmBuiltinType::f16)) { + if (!(AType && BType && CType && DType)) return SYCLGenError(); - } else { - if (tryEmitType(TypeStr, AType)) - return SYCLGenError(); - } - const InlineAsmVectorExpr *VE = - dyn_cast(Inst->getOutputOperand()); - if (VE && VE->getNumElements() != 4) { + // Data types of matrix elements for A&B and C&D matrices should be same + if ((AType->getKind() != BType->getKind()) || + (CType->getKind() != DType->getKind())) return SYCLGenError(); + + // Check the validity of AB & CD types + std::string ABType, CDType; + if (tryEmitType(ABType, AType)) + return SYCLGenError(); + + if (tryEmitType(CDType, CType)) + return SYCLGenError(); + + // Register sizes for vector elements of A, B, C & D matrices + unsigned NumVecElements[4] = {0}; + + // Data type used to multiply A & B matrices + std::string MulType; + if (Inst->hasAttr(InstAttr::m16n8k16)) { + // Only f16 type is supported for A and B matrix data for m16n8k16 + if (AType->getKind() == InlineAsmBuiltinType::f16) { + // If A matrix type is f16, then C&D matrix types can only be f16 + if (CType->getKind() == AType->getKind()) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 4; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m8n8k4)) { + // f16 & f64 types are supported for A and B matrix data for m8n8k4 + if (AType->getKind() == InlineAsmBuiltinType::f16) { + // If A matrix type is f16, then C&D matrix types can only be f16/f32 + if (CType->getKind() == AType->getKind()) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else if (CType->getKind() == InlineAsmBuiltinType::f32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 8; // C + NumVecElements[3] = 8; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::f64) { + // If A matrix type is f64, then C&D matrix types can only be f64 + if (CType->getKind() == AType->getKind()) { + NumVecElements[0] = 1; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else + return SYCLGenError(); + + // Check the register sizes for vector elements of A, B, C & D matrices + for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); + InputOp++) { + if (auto VE = + dyn_cast(Inst->getInputOperand(InputOp))) { + if (VE->getNumElements() != NumVecElements[InputOp]) + return SYCLGenError(); + } else + return SYCLGenError(); } + if (DMatVE->getNumElements() != NumVecElements[3]) + return SYCLGenError(); + MulType = ABType; OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma"; - OS() << "<" << TypeStr << ">("; + OS() << "<" << MulType << ">("; // Add D matrix address values to store the MAD result - for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) { - if (isa(VE->getElement(Inst))) + for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) { + if (isa(DMatVE->getElement(Inst))) continue; OS() << "&"; - if (emitStmt(VE->getElement(Inst))) + if (emitStmt(DMatVE->getElement(Inst))) return SYCLGenError(); OS() << ", "; } @@ -1355,7 +1422,8 @@ class SYCLGen : public SYCLGenBase { // Add A, B & C matrix values to compute MAD for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); InputOp++) { - if (VE = dyn_cast(Inst->getInputOperand(InputOp))) { + if (auto VE = + dyn_cast(Inst->getInputOperand(InputOp))) { for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) { if (isa(VE->getElement(Inst))) continue; @@ -2607,11 +2675,10 @@ class SYCLGen : public SYCLGenBase { Op = std::move(NewOp); } - bool HasHalfOrBfloat16 = - SrcType->getKind() == InlineAsmBuiltinType::f16 || - DesType->getKind() == InlineAsmBuiltinType::f16 || - SrcType->getKind() == InlineAsmBuiltinType::bf16 || - DesType->getKind() == InlineAsmBuiltinType::bf16; + bool HasHalfOrBfloat16 = SrcType->getKind() == InlineAsmBuiltinType::f16 || + DesType->getKind() == InlineAsmBuiltinType::f16 || + SrcType->getKind() == InlineAsmBuiltinType::bf16 || + DesType->getKind() == InlineAsmBuiltinType::bf16; if (DpctGlobalInfo::useIntelDeviceMath() && HasHalfOrBfloat16) { insertHeader(HeaderType::HT_SYCL_Math); if (SrcNeedBitCast) diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h index 6f8ab9a7b5b3..7fbf6f6fcae0 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h @@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType { // This class is used for device asm vector types. class InlineAsmVectorType : public InlineAsmType { public: - enum VecKind { v2, v4, v8 }; + enum VecKind { v1, v2, v4, v8 }; private: VecKind Kind; diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp index 8c8b7e9ff022..4b2e763bba47 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp @@ -739,6 +739,9 @@ InlineAsmParser::ActOnVectorExpr(ArrayRef Vec) { // Vector size must be 2, 4, or 8. InlineAsmVectorType::VecKind Kind; switch (Vec.size()) { + case 1: + Kind = InlineAsmVectorType::v1; + break; case 2: Kind = InlineAsmVectorType::v2; break; diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h index ca3196110015..f94d6c4a3df3 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h @@ -496,7 +496,7 @@ class InlineAsmParser { /// .reg .sreg .const .local .param .shared .tex /// /// vector-specifier: one of - /// .v2 .v4 .v8 + /// .v1 .v2 .v4 .v8 /// /// type-specifier: one of /// .b8 .b16 .b32 .b64 .s8 .s16 .s32 .s64 diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def index 75d16636edb3..7d4134678f42 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def @@ -270,6 +270,7 @@ BUILTIN_TYPE(s16x2, ".s16x2") BUILTIN_TYPE(u16x2, ".u16x2") // Vector modifiers +MODIFIER(v1, ".v1") MODIFIER(v2, ".v2") MODIFIER(v4, ".v4") MODIFIER(v8, ".v8") @@ -280,6 +281,7 @@ MODIFIER(col, ".col") // Matrix shape MODIFIER(m16n8k16, ".m16n8k16") +MODIFIER(m8n8k4, ".m8n8k4") STATE_SPACE(reg, ".reg") STATE_SPACE(sreg, ".sreg") diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index a247289ae21e..15201ac35f55 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -9,8 +9,8 @@ #ifndef __DPCT_MATH_HPP__ #define __DPCT_MATH_HPP__ -#include #include +#include #include #include @@ -1636,7 +1636,8 @@ inline constexpr unsigned extend_vcompare2_add(AT a, BT b, unsigned c, /// \returns The extend vectorized average of the two values template inline constexpr RetT extend_vavrg2(AT a, BT b, RetT c) { - return detail::extend_vbinary2(a, b, c, detail::average()); + return detail::extend_vbinary2(a, b, c, + detail::average()); } /// Compute vectorized average of \p a and \p b, with each value treated as a 2 @@ -1933,7 +1934,8 @@ inline constexpr unsigned extend_vcompare4_add(AT a, BT b, unsigned c, /// \returns The extend vectorized average of the two values template inline constexpr RetT extend_vavrg4(AT a, BT b, RetT c) { - return detail::extend_vbinary4(a, b, c, detail::average()); + return detail::extend_vbinary4(a, b, c, + detail::average()); } /// Compute vectorized average of \p a and \p b, with each value treated as a 4 @@ -2056,6 +2058,199 @@ class joint_matrix { const size_t num_elements; }; +/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b16 +/// matrix +/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \tparam [in] ItemT The type of the sycl::nd_item index space class +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +/// \param [in] item The sycl::nd_item index space class +template +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, CDType c3, + const ItemT &item) { + int lane = item.get_sub_group().get_local_linear_id(); + + short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4); + + ABType recv_a[2]; + recv_a[0] = a0; + recv_a[1] = a1; + + MulType *ra = reinterpret_cast(recv_a); + MulType *c_h[4]; + c_h[0] = reinterpret_cast(&c0); + c_h[1] = reinterpret_cast(&c1); + c_h[2] = reinterpret_cast(&c2); + c_h[3] = reinterpret_cast(&c3); + for (int i = 0; i < 4; i++) { + ABType recv_b[4]; + + recv_b[0] = dpct::select_from_sub_group(item.get_sub_group(), b0, + COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(item.get_sub_group(), b1, + COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(item.get_sub_group(), b0, + COL_LOAD_OFFSET + 16 + i); + recv_b[3] = dpct::select_from_sub_group(item.get_sub_group(), b1, + COL_LOAD_OFFSET + 16 + i); + + MulType *rb = reinterpret_cast(recv_b); + // Iterate for k times + for (int j = 0; j < 4; j++) { + c_h[(i >> 1)][i % 2] += + static_cast(ra[j]) * static_cast(rb[j]); + c_h[2 + (i >> 1)][i % 2] += + static_cast(ra[j]) * static_cast(rb[4 + j]); + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b32 +/// matrix +/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \tparam [in] ItemT The type of the sycl::nd_item index space class +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] d4 The 5th element to be written to the output D matrix +/// \param [in] d5 The 6th element to be written to the output D matrix +/// \param [in] d6 The 7th element to be written to the output D matrix +/// \param [in] d7 The 8th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +/// \param [in] c4 The 5th element from C matrix to be added with d4 +/// \param [in] c5 The 6th element from C matrix to be added with d5 +/// \param [in] c6 The 7th element from C matrix to be added with d6 +/// \param [in] c7 The 8th element from C matrix to be added with d7 +/// \param [in] item The sycl::nd_item index space class +template +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5, + CDType *d6, CDType *d7, ABType a0, ABType a1, ABType b0, ABType b1, + CDType c0, CDType c1, CDType c2, CDType c3, CDType c4, CDType c5, + CDType c6, CDType c7, const ItemT &item) { + int lane = item.get_sub_group().get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane / 4) + (lane % 2); + short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4) + 2 * ((lane / 2) % 2); + + ABType recv_a[2 * 2], recv_b[4 * 2]; + + recv_a[0] = + dpct::select_from_sub_group(item.get_sub_group(), a0, ROW_LOAD_OFFSET); + recv_a[1] = + dpct::select_from_sub_group(item.get_sub_group(), a1, ROW_LOAD_OFFSET); + recv_a[2] = dpct::select_from_sub_group(item.get_sub_group(), a0, + ROW_LOAD_OFFSET + 2); + recv_a[3] = dpct::select_from_sub_group(item.get_sub_group(), a1, + ROW_LOAD_OFFSET + 2); + + recv_b[0] = + dpct::select_from_sub_group(item.get_sub_group(), b0, COL_LOAD_OFFSET); + recv_b[1] = + dpct::select_from_sub_group(item.get_sub_group(), b1, COL_LOAD_OFFSET); + recv_b[2] = dpct::select_from_sub_group(item.get_sub_group(), b0, + COL_LOAD_OFFSET + 1); + recv_b[3] = dpct::select_from_sub_group(item.get_sub_group(), b1, + COL_LOAD_OFFSET + 1); + recv_b[4] = dpct::select_from_sub_group(item.get_sub_group(), b0, + COL_LOAD_OFFSET + 16); + recv_b[5] = dpct::select_from_sub_group(item.get_sub_group(), b1, + COL_LOAD_OFFSET + 16); + recv_b[6] = dpct::select_from_sub_group(item.get_sub_group(), b0, + COL_LOAD_OFFSET + 17); + recv_b[7] = dpct::select_from_sub_group(item.get_sub_group(), b1, + COL_LOAD_OFFSET + 17); + + MulType *ra = reinterpret_cast(recv_a); + MulType *rb = reinterpret_cast(recv_b); + for (int i = 0; i < 4 /*k*/; i++) { + c0 += static_cast(ra[i]) * static_cast(rb[i]); + c1 += static_cast(ra[i]) * static_cast(rb[i + 4]); + c2 += static_cast(ra[i + 4]) * static_cast(rb[i]); + c3 += static_cast(ra[i + 4]) * static_cast(rb[i + 4]); + c4 += static_cast(ra[i]) * static_cast(rb[i + 8]); + c5 += static_cast(ra[i]) * static_cast(rb[i + 12]); + c6 += static_cast(ra[i + 4]) * static_cast(rb[i + 8]); + c7 += static_cast(ra[i + 4]) * static_cast(rb[i + 12]); + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; + *d4 = c4; + *d5 = c5; + *d6 = c6; + *d7 = c7; +} + +/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b32 +/// matrix +/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \tparam [in] ItemT The type of the sycl::nd_item index space class +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] item The sycl::nd_item index space class +template +void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1, + const ItemT &item) { + int lane = item.get_sub_group().get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane / 4); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + for (int i = 0; i < 4; i++) { + ABType recv_a = dpct::select_from_sub_group(item.get_sub_group(), a0, + ROW_LOAD_OFFSET + i); + ABType recv_b = dpct::select_from_sub_group(item.get_sub_group(), b0, + COL_LOAD_OFFSET + i); + c0 += recv_a * recv_b; + + recv_b = dpct::select_from_sub_group(item.get_sub_group(), b0, + COL_LOAD_OFFSET + i + 4); + c1 += recv_a * recv_b; + } + + *d0 = c0; + *d1 = c1; +} + /// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32 /// matrix /// Requires the sub-group size of kernel calling this function to be 32 @@ -2084,7 +2279,7 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, CDType c2, CDType c3, const ItemT &item) { int lane = item.get_sub_group().get_local_linear_id(); - short ROW_LOAD_OFFSET = 4 * (lane / 4); + short ROW_LOAD_OFFSET = 4 * (lane >> 2); short COL_LOAD_OFFSET = 8 * (lane % 4); for (int i = 0; i < 4; i++) { @@ -2113,7 +2308,8 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, auto *rb0 = reinterpret_cast(recv_b); auto *rb1 = reinterpret_cast(recv_b + 2); - for (int j = 0; j < 2 * 2; j++) { + // Iterate for k (i * j) times + for (int j = 0; j < 4; j++) { auto a0 = static_cast(ra0[j]); auto a1 = static_cast(ra1[j]); auto b0 = static_cast(rb0[j]); From c94ff565a12c0fa038bec0e1d5c685dbe9cafe90 Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Tue, 29 Apr 2025 12:55:32 +0800 Subject: [PATCH 4/6] Added support for 10 shapes --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 274 +++- clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h | 6 +- .../DPCT/RulesAsm/Parser/AsmTokenKinds.def | 15 +- clang/runtime/dpct-rt/include/dpct/math.hpp | 1396 +++++++++++++++-- clang/test/dpct/asm/mma.cu | 383 ++++- 5 files changed, 1898 insertions(+), 176 deletions(-) diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index a8a3e594997f..f0d7dba3f198 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -514,14 +514,17 @@ bool SYCLGenBase::emitType(const InlineAsmType *T) { bool SYCLGenBase::emitBuiltinType(const InlineAsmBuiltinType *T) { switch (T->getKind()) { // clang-format off + case InlineAsmBuiltinType::b1: OS() << "uint8_t"; break; case InlineAsmBuiltinType::b8: OS() << "uint8_t"; break; case InlineAsmBuiltinType::b16: OS() << "uint16_t"; break; case InlineAsmBuiltinType::b32: OS() << "uint32_t"; break; case InlineAsmBuiltinType::b64: OS() << "uint64_t"; break; + case InlineAsmBuiltinType::u4: OS() << "uint8_t"; break; case InlineAsmBuiltinType::u8: OS() << "uint8_t"; break; case InlineAsmBuiltinType::u16: OS() << "uint16_t"; break; case InlineAsmBuiltinType::u32: OS() << "uint32_t"; break; case InlineAsmBuiltinType::u64: OS() << "uint64_t"; break; + case InlineAsmBuiltinType::s4: OS() << "int8_t"; break; case InlineAsmBuiltinType::s8: OS() << "int8_t"; break; case InlineAsmBuiltinType::s16: OS() << "int16_t"; break; case InlineAsmBuiltinType::s32: OS() << "int32_t"; break; @@ -1347,44 +1350,276 @@ class SYCLGen : public SYCLGenBase { // Register sizes for vector elements of A, B, C & D matrices unsigned NumVecElements[4] = {0}; + // Sizes of A & B matrices + std::string M, N, K; + + // Operator for m8n8k128/m16n8k128/m16n8k256 + std::string MatrixOp; + // Data type used to multiply A & B matrices std::string MulType; - if (Inst->hasAttr(InstAttr::m16n8k16)) { - // Only f16 type is supported for A and B matrix data for m16n8k16 + if (Inst->hasAttr(InstAttr::m8n8k4)) { + M = "8"; + N = "8"; + K = "4"; + // f16 & f64 types are supported for A and B matrices of m8n8k4 if (AType->getKind() == InlineAsmBuiltinType::f16) { - // If A matrix type is f16, then C&D matrix types can only be f16 + // If A matrix type is f16, then C&D matrix types can only be f16/f32 if (CType->getKind() == AType->getKind()) { NumVecElements[0] = 2; // A - NumVecElements[1] = 4; // B + NumVecElements[1] = 2; // B NumVecElements[2] = 4; // C NumVecElements[3] = 4; // D + } else if (CType->getKind() == InlineAsmBuiltinType::f32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 8; // C + NumVecElements[3] = 8; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::f64) { + // If A matrix type is f64, then C&D matrix types can only be f64 + if (CType->getKind() == AType->getKind()) { + NumVecElements[0] = 1; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m8n8k16)) { + M = "8"; + N = "8"; + K = "16"; + // Only s8/u8 types are supported for A and B matrices of m8n8k16 + if (AType->getKind() == InlineAsmBuiltinType::s8 || + AType->getKind() == InlineAsmBuiltinType::u8) { + // If A matrix type is s8/u8, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 1; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D } else return SYCLGenError(); } else return SYCLGenError(); - } else if (Inst->hasAttr(InstAttr::m8n8k4)) { - // f16 & f64 types are supported for A and B matrix data for m8n8k4 + } else if (Inst->hasAttr(InstAttr::m8n8k32)) { + M = "8"; + N = "8"; + K = "32"; + // Only s4/u4 types are supported for A and B matrices of m16n8k32 + if (AType->getKind() == InlineAsmBuiltinType::s4 || + AType->getKind() == InlineAsmBuiltinType::u4) { + // If A matrix type is s4/u4, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 1; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m8n8k128)) { + M = "8"; + N = "8"; + K = "128"; + // Only b1 type is supported for A and B matrices of m16n8k128 + if (AType->getKind() == InlineAsmBuiltinType::b1) { + // If A matrix type is b1, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 1; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + + // Only and/xor bitwise operations are supported for m8n8k128 + if (Inst->hasAttr(InstAttr::op_and)) + MatrixOp = "and"; + else if (Inst->hasAttr(InstAttr::op_xor)) + MatrixOp = "xor"; + else + return SYCLGenError(); + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k4)) { + M = "16"; + N = "8"; + K = "4"; + // Only f64 type is supported for A and B matrices of m16n8k4 + if (AType->getKind() == InlineAsmBuiltinType::f64) { + // If A matrix type is f64, then C&D matrix types can only be f64 + if (CType->getKind() == InlineAsmBuiltinType::f64) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k8)) { + M = "16"; + N = "8"; + K = "8"; + // Only f16/f64 types are supported for A and B matrices of m16n8k8 if (AType->getKind() == InlineAsmBuiltinType::f16) { // If A matrix type is f16, then C&D matrix types can only be f16/f32 - if (CType->getKind() == AType->getKind()) { + if (CType->getKind() == InlineAsmBuiltinType::f16) { NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + } else if (CType->getKind() == InlineAsmBuiltinType::f32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::f64) { + // If A matrix type is f64, then C&D matrix types can only be f64 + if (CType->getKind() == InlineAsmBuiltinType::f64) { + NumVecElements[0] = 4; // A NumVecElements[1] = 2; // B NumVecElements[2] = 4; // C NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k16)) { + M = "16"; + N = "8"; + K = "16"; + // Only f16/f64/s8/u8 type is supported for A and B matrices of m16n8k16 + if (AType->getKind() == InlineAsmBuiltinType::f16) { + // If A matrix type is f16, then C&D matrix types can only be f16/f32 + if (CType->getKind() == AType->getKind()) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D } else if (CType->getKind() == InlineAsmBuiltinType::f32) { - NumVecElements[0] = 2; // A + NumVecElements[0] = 4; // A NumVecElements[1] = 2; // B - NumVecElements[2] = 8; // C - NumVecElements[3] = 8; // D + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D } else return SYCLGenError(); } else if (AType->getKind() == InlineAsmBuiltinType::f64) { // If A matrix type is f64, then C&D matrix types can only be f64 if (CType->getKind() == AType->getKind()) { - NumVecElements[0] = 1; // A + NumVecElements[0] = 8; // A + NumVecElements[1] = 4; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::s8 || + AType->getKind() == InlineAsmBuiltinType::u8) { + // If A matrix type is s8/u8, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 2; // A NumVecElements[1] = 1; // B - NumVecElements[2] = 2; // C - NumVecElements[3] = 2; // D + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k32)) { + M = "16"; + N = "8"; + K = "32"; + // Only s4/s8/u4/u8 types are supported for A and B matrices of m16n8k32 + if (AType->getKind() == InlineAsmBuiltinType::s4 || + AType->getKind() == InlineAsmBuiltinType::u4) { + // If A matrix type is s4/u4, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::s8 || + AType->getKind() == InlineAsmBuiltinType::u8) { + // If A matrix type is s8/u8, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k64)) { + M = "16"; + N = "8"; + K = "64"; + // Only s4/u4 types are supported for A and B matrices of m16n8k64 + if (AType->getKind() == InlineAsmBuiltinType::s4 || + AType->getKind() == InlineAsmBuiltinType::u4) { + // If A matrix type is s4/u4, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k128)) { + M = "16"; + N = "8"; + K = "128"; + // Only b1 type is supported for A and B matrices of m16n8k128 + if (AType->getKind() == InlineAsmBuiltinType::b1) { + // If A matrix type is b1, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + + // Only and/xor bitwise operations are supported for m16n8k128 + if (Inst->hasAttr(InstAttr::op_and)) + MatrixOp = "and"; + else if (Inst->hasAttr(InstAttr::op_xor)) + MatrixOp = "xor"; + else + return SYCLGenError(); + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else if (Inst->hasAttr(InstAttr::m16n8k256)) { + M = "16"; + N = "8"; + K = "256"; + // Only b1 type is supported for A and B matrices of m16n8k256 + if (AType->getKind() == InlineAsmBuiltinType::b1) { + // If A matrix type is b1, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + + // Only and/xor bitwise operations are supported for m16n8k256 + if (Inst->hasAttr(InstAttr::op_and)) + MatrixOp = "and"; + else if (Inst->hasAttr(InstAttr::op_xor)) + MatrixOp = "xor"; + else + return SYCLGenError(); } else return SYCLGenError(); } else @@ -1407,7 +1642,12 @@ class SYCLGen : public SYCLGenBase { MulType = ABType; OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma"; - OS() << "<" << MulType << ">("; + if (!MatrixOp.empty()) { + OS() << "_" << MatrixOp; + } + OS() << "<"; + OS() << M << ", " << N << ", " << K << ", "; + OS() << MulType << ">("; // Add D matrix address values to store the MAD result for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) { @@ -1416,7 +1656,8 @@ class SYCLGen : public SYCLGenBase { OS() << "&"; if (emitStmt(DMatVE->getElement(Inst))) return SYCLGenError(); - OS() << ", "; + if ((Inst + 1) != DMatVE->getNumElements()) + OS() << ", "; } // Add A, B & C matrix values to compute MAD @@ -1427,16 +1668,15 @@ class SYCLGen : public SYCLGenBase { for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) { if (isa(VE->getElement(Inst))) continue; + OS() << ", "; if (emitStmt(VE->getElement(Inst))) return SYCLGenError(); - OS() << ", "; } } else { return SYCLGenError(); } } - OS() << DpctGlobalInfo::getItem(GAS); OS() << ");"; const auto *KernelDecl = getImmediateOuterFuncDecl(GAS); diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h index 7fbf6f6fcae0..04f7485e3476 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h @@ -99,9 +99,9 @@ class InlineAsmBuiltinType : public InlineAsmType { return ((K == Kind) || ...); } template bool isNot(Ks... K) { return ((K != Kind) && ...); } - bool isBit() const { return isOneOf(b8, b16, b32, b64); } - bool isSigned() const { return isOneOf(s8, s16, s32, s64); } - bool isUnsigned() const { return isOneOf(u8, u16, u32, u64); } + bool isBit() const { return isOneOf(b1, b8, b16, b32, b64); } + bool isSigned() const { return isOneOf(s4, s8, s16, s32, s64); } + bool isUnsigned() const { return isOneOf(u4, u8, u16, u32, u64); } bool isInt() const { return isSigned() || isUnsigned(); } bool isFloat() const { return isOneOf(f16, f32, f64); } bool isScalar() const { return isInt() || isFloat(); } diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def index 7d4134678f42..4155daea4aca 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def @@ -240,14 +240,17 @@ SPECIAL_REG(warpid, "%warpid", s64) SPECIAL_REG(WARP_SZ, "WARP_SZ", s64) // Built-in type names +BUILTIN_TYPE(b1, ".b1") BUILTIN_TYPE(b8, ".b8") BUILTIN_TYPE(b16, ".b16") BUILTIN_TYPE(b32, ".b32") BUILTIN_TYPE(b64, ".b64") +BUILTIN_TYPE(u4, ".u4") BUILTIN_TYPE(u8, ".u8") BUILTIN_TYPE(u16, ".u16") BUILTIN_TYPE(u32, ".u32") BUILTIN_TYPE(u64, ".u64") +BUILTIN_TYPE(s4, ".s4") BUILTIN_TYPE(s8, ".s8") BUILTIN_TYPE(s16, ".s16") BUILTIN_TYPE(s32, ".s32") @@ -280,8 +283,17 @@ MODIFIER(row, ".row") MODIFIER(col, ".col") // Matrix shape -MODIFIER(m16n8k16, ".m16n8k16") MODIFIER(m8n8k4, ".m8n8k4") +MODIFIER(m8n8k16, ".m8n8k16") +MODIFIER(m8n8k32, ".m8n8k32") +MODIFIER(m8n8k128, ".m8n8k128") +MODIFIER(m16n8k4, ".m16n8k4") +MODIFIER(m16n8k8, ".m16n8k8") +MODIFIER(m16n8k16, ".m16n8k16") +MODIFIER(m16n8k32, ".m16n8k32") +MODIFIER(m16n8k64, ".m16n8k64") +MODIFIER(m16n8k128, ".m16n8k128") +MODIFIER(m16n8k256, ".m16n8k256") STATE_SPACE(reg, ".reg") STATE_SPACE(sreg, ".sreg") @@ -377,6 +389,7 @@ MODIFIER(max, ".max") MODIFIER(op_or, ".or") MODIFIER(op_xor, ".xor") MODIFIER(op_and, ".and") +MODIFIER(op_popc, ".popc") MODIFIER(cas, ".cas") MODIFIER(exch, ".exch") MODIFIER(inc, ".inc") diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 15201ac35f55..e00763a0291d 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2059,12 +2059,15 @@ class joint_matrix { }; /// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b16 -/// matrix +/// matrix (m8n8k4.row.col.f16.f16.f16.f16) /// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix /// \tparam [in] MulType The type of the multiplication result /// \tparam [in] ABType The type of the input matrices /// \tparam [in] CDType The type of the output matrix -/// \tparam [in] ItemT The type of the sycl::nd_item index space class +/// In: 4, 2, 2, 4 /// \param [in] d0 The 1st element to be written to the output D matrix /// \param [in] d1 The 2nd element to be written to the output D matrix /// \param [in] d2 The 3rd element to be written to the output D matrix @@ -2077,45 +2080,61 @@ class joint_matrix { /// \param [in] c1 The 2nd element from C matrix to be added with d1 /// \param [in] c2 The 3rd element from C matrix to be added with d2 /// \param [in] c3 The 4th element from C matrix to be added with d3 -/// \param [in] item The sycl::nd_item index space class -template -void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, CDType c3, - const ItemT &item) { - int lane = item.get_sub_group().get_local_linear_id(); +template +void mma(volatile CDType *d0, volatile CDType *d1, volatile CDType *d2, + volatile CDType *d3, ABType a0, ABType a1, ABType b0, ABType b1, + CDType c0, CDType c1, CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4); - ABType recv_a[2]; - recv_a[0] = a0; - recv_a[1] = a1; - - MulType *ra = reinterpret_cast(recv_a); - MulType *c_h[4]; - c_h[0] = reinterpret_cast(&c0); - c_h[1] = reinterpret_cast(&c1); - c_h[2] = reinterpret_cast(&c2); - c_h[3] = reinterpret_cast(&c3); - for (int i = 0; i < 4; i++) { - ABType recv_b[4]; - - recv_b[0] = dpct::select_from_sub_group(item.get_sub_group(), b0, - COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(item.get_sub_group(), b1, - COL_LOAD_OFFSET + i); - recv_b[2] = dpct::select_from_sub_group(item.get_sub_group(), b0, - COL_LOAD_OFFSET + 16 + i); - recv_b[3] = dpct::select_from_sub_group(item.get_sub_group(), b1, - COL_LOAD_OFFSET + 16 + i); + if (M == 8 && N == 8 && K == 4) { + ABType recv_a[2], recv_b[4]; + recv_a[0] = a0; + recv_a[1] = a1; + MulType *ra = reinterpret_cast(recv_a); MulType *rb = reinterpret_cast(recv_b); - // Iterate for k times - for (int j = 0; j < 4; j++) { - c_h[(i >> 1)][i % 2] += - static_cast(ra[j]) * static_cast(rb[j]); - c_h[2 + (i >> 1)][i % 2] += - static_cast(ra[j]) * static_cast(rb[4 + j]); + + float c_f[8] = {0.0f}; + + for (int i = 0; i < 4; i++) { + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 16 + i); + recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 + i); + + for (int j = 0; j < 4; j++) { + c_f[i] += static_cast(ra[j]) * static_cast(rb[j]); + c_f[i + 4] += static_cast(ra[j]) * static_cast(rb[j + 4]); + } } + + auto c_h = reinterpret_cast(&c0); + c_f[0] += static_cast(c_h[0]); + c_f[1] += static_cast(c_h[1]); + c_h[0] = c_f[0]; + c_h[1] = c_f[1]; + + c_h = reinterpret_cast(&c1); + c_f[2] += static_cast(c_h[0]); + c_f[3] += static_cast(c_h[1]); + c_h[0] = c_f[2]; + c_h[1] = c_f[3]; + + c_h = reinterpret_cast(&c2); + c_f[4] += static_cast(c_h[0]); + c_f[5] += static_cast(c_h[1]); + c_h[0] = c_f[4]; + c_h[1] = c_f[5]; + + c_h = reinterpret_cast(&c3); + c_f[6] += static_cast(c_h[0]); + c_f[7] += static_cast(c_h[1]); + c_h[0] = c_f[6]; + c_h[1] = c_f[7]; } *d0 = c0; @@ -2125,11 +2144,15 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, } /// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b32 -/// matrix +/// matrix (m8n8k4.row.col.f32.f32.f32.f32) /// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix /// \tparam [in] MulType The type of the multiplication result /// \tparam [in] ABType The type of the input matrices /// \tparam [in] CDType The type of the output matrix +/// In: 8, 2, 2, 8 /// \tparam [in] ItemT The type of the sycl::nd_item index space class /// \param [in] d0 The 1st element to be written to the output D matrix /// \param [in] d1 The 2nd element to be written to the output D matrix @@ -2151,56 +2174,49 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, /// \param [in] c5 The 6th element from C matrix to be added with d5 /// \param [in] c6 The 7th element from C matrix to be added with d6 /// \param [in] c7 The 8th element from C matrix to be added with d7 -/// \param [in] item The sycl::nd_item index space class -template +template void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5, CDType *d6, CDType *d7, ABType a0, ABType a1, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, CDType c3, CDType c4, CDType c5, - CDType c6, CDType c7, const ItemT &item) { - int lane = item.get_sub_group().get_local_linear_id(); + CDType c6, CDType c7) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); - short ROW_LOAD_OFFSET = 4 * (lane / 4) + (lane % 2); + short ROW_LOAD_OFFSET = 4 * (lane >> 2) + (lane % 2); short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4) + 2 * ((lane / 2) % 2); - ABType recv_a[2 * 2], recv_b[4 * 2]; - - recv_a[0] = - dpct::select_from_sub_group(item.get_sub_group(), a0, ROW_LOAD_OFFSET); - recv_a[1] = - dpct::select_from_sub_group(item.get_sub_group(), a1, ROW_LOAD_OFFSET); - recv_a[2] = dpct::select_from_sub_group(item.get_sub_group(), a0, - ROW_LOAD_OFFSET + 2); - recv_a[3] = dpct::select_from_sub_group(item.get_sub_group(), a1, - ROW_LOAD_OFFSET + 2); - - recv_b[0] = - dpct::select_from_sub_group(item.get_sub_group(), b0, COL_LOAD_OFFSET); - recv_b[1] = - dpct::select_from_sub_group(item.get_sub_group(), b1, COL_LOAD_OFFSET); - recv_b[2] = dpct::select_from_sub_group(item.get_sub_group(), b0, - COL_LOAD_OFFSET + 1); - recv_b[3] = dpct::select_from_sub_group(item.get_sub_group(), b1, - COL_LOAD_OFFSET + 1); - recv_b[4] = dpct::select_from_sub_group(item.get_sub_group(), b0, - COL_LOAD_OFFSET + 16); - recv_b[5] = dpct::select_from_sub_group(item.get_sub_group(), b1, - COL_LOAD_OFFSET + 16); - recv_b[6] = dpct::select_from_sub_group(item.get_sub_group(), b0, - COL_LOAD_OFFSET + 17); - recv_b[7] = dpct::select_from_sub_group(item.get_sub_group(), b1, - COL_LOAD_OFFSET + 17); - - MulType *ra = reinterpret_cast(recv_a); - MulType *rb = reinterpret_cast(recv_b); - for (int i = 0; i < 4 /*k*/; i++) { - c0 += static_cast(ra[i]) * static_cast(rb[i]); - c1 += static_cast(ra[i]) * static_cast(rb[i + 4]); - c2 += static_cast(ra[i + 4]) * static_cast(rb[i]); - c3 += static_cast(ra[i + 4]) * static_cast(rb[i + 4]); - c4 += static_cast(ra[i]) * static_cast(rb[i + 8]); - c5 += static_cast(ra[i]) * static_cast(rb[i + 12]); - c6 += static_cast(ra[i + 4]) * static_cast(rb[i + 8]); - c7 += static_cast(ra[i + 4]) * static_cast(rb[i + 12]); + if (M == 8 && N == 8 && K == 4) { + ABType recv_a[2 * 2], recv_b[4 * 2]; + + for (int i = 0; i < 2; i++) { + recv_a[2 * i] = + dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + 2 * i); + recv_a[2 * i + 1] = + dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + 2 * i); + + recv_b[4 * i] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 16 * i); + recv_b[4 * i + 1] = + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 * i); + recv_b[4 * i + 2] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 16 * i + 1); + recv_b[4 * i + 3] = + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 * i + 1); + } + + MulType *ra = reinterpret_cast(recv_a); + MulType *rb = reinterpret_cast(recv_b); + for (int i = 0; i < 4; i++) { + c0 += static_cast(ra[i]) * static_cast(rb[i]); + c1 += static_cast(ra[i]) * static_cast(rb[i + 4]); + c2 += static_cast(ra[i + 4]) * static_cast(rb[i]); + c3 += static_cast(ra[i + 4]) * static_cast(rb[i + 4]); + c4 += static_cast(ra[i]) * static_cast(rb[i + 8]); + c5 += static_cast(ra[i]) * static_cast(rb[i + 12]); + c6 += static_cast(ra[i + 4]) * static_cast(rb[i + 8]); + c7 += static_cast(ra[i + 4]) * static_cast(rb[i + 12]); + } } *d0 = c0; @@ -2213,38 +2229,477 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5, *d7 = c7; } -/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b32 -/// matrix +/// Multiplies 2 8x4 & 4x8 f64 matrices and accumulates the result to a 8x8 b64 +/// matrix (m8n8k4.row.col.f64.f64.f64.f64) /// Requires the sub-group size of kernel calling this function to be 32 +/// In: 2, 1, 1, 2 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +template +typename std::enable_if_t, void> +mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 8 && N == 8 && K == 4) { + for (int i = 0; i < 4; i++) { + ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + c0 += recv_a * recv_b; + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + c1 += recv_a * recv_b; + } + } + + *d0 = c0; + *d1 = c1; +} + +/// Multiplies 2 8x16 & 16x8 u8/s8 matrices and accumulates the result to a 8x8 +/// s32 matrix (m8n8k16.row.col.s32.u8.u8.s32 / m8n8k16.row.col.s32.s8.s8.s32) +/// Multiplies 2 8x32 & 32x8 u4/s4 matrices and accumulates the result to a 8x8 +/// s32 matrix (m8n8k32.row.col.s32.u4.u4.s32 / m8n8k32.row.col.s32.s4.s4.s32) +/// Requires the sub-group size of kernel calling this function to be 32 +/// In: 2, 1, 1, 2 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix /// \tparam [in] MulType The type of the multiplication result /// \tparam [in] ABType The type of the input matrices /// \tparam [in] CDType The type of the output matrix -/// \tparam [in] ItemT The type of the sycl::nd_item index space class /// \param [in] d0 The 1st element to be written to the output D matrix /// \param [in] d1 The 2nd element to be written to the output D matrix /// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix /// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix /// \param [in] c0 The 1st element from C matrix to be added with d0 /// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] item The sycl::nd_item index space class -template -void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1, - const ItemT &item) { - int lane = item.get_sub_group().get_local_linear_id(); +template +std::enable_if_t, void> +mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); - short ROW_LOAD_OFFSET = 4 * (lane / 4); + short ROW_LOAD_OFFSET = 4 * (lane >> 2); short COL_LOAD_OFFSET = 8 * (lane % 4); - for (int i = 0; i < 4; i++) { - ABType recv_a = dpct::select_from_sub_group(item.get_sub_group(), a0, - ROW_LOAD_OFFSET + i); - ABType recv_b = dpct::select_from_sub_group(item.get_sub_group(), b0, - COL_LOAD_OFFSET + i); - c0 += recv_a * recv_b; + if (M == 8 && N == 8 && K == 16) { + for (int i = 0; i < 4; i++) { + ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b = dpct::select_from_sub_group(item.get_sub_group(), b0, - COL_LOAD_OFFSET + i + 4); - c1 += recv_a * recv_b; + MulType *a = reinterpret_cast(&recv_a); + MulType *b = reinterpret_cast(&recv_b); + + for (int k = 0; k < 4; k++) { + c0 += a[k] * b[k]; + } + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + for (int k = 0; k < 4; k++) { + c1 += a[k] * b[k]; + } + } + } else if (M == 8 && N == 8 && K == 32) { + for (int i = 0; i < 4; i++) { + ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + + MulType *a = reinterpret_cast(&recv_a); + MulType *b = reinterpret_cast(&recv_b); + + for (int k = 0; k < 4; k++) { + MulType a0 = a[k] >> 4; + MulType a1 = a[k] & 0x0F; + MulType b0 = b[k] >> 4; + MulType b1 = b[k] & 0x0F; + + c0 += a0 * b0; + c0 += a1 * b1; + } + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + for (int k = 0; k < 4; k++) { + MulType a0 = a[k] >> 4; + MulType a1 = a[k] & 0x0F; + MulType b0 = b[k] >> 4; + MulType b1 = b[k] & 0x0F; + + c1 += a0 * b0; + c1 += a1 * b1; + } + } + } + + *d0 = c0; + *d1 = c1; +} + +/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 +/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc) +/// Requires the sub-group size of kernel calling this function to be 32 +/// In: 2, 1, 1, 2 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +template +void mma_and(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, + CDType c1) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 8 && N == 8 && K == 128) { + for (int i = 0; i < 4; i++) { + ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + + c0 += sycl::popcount(recv_a & recv_b); + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c1 += sycl::popcount(recv_a & recv_b); + } + } + + *d0 = c0; + *d1 = c1; +} + +/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 +/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc) +/// Requires the sub-group size of kernel calling this function to be 32 +/// In: 2, 1, 1, 2 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +template +void mma_xor(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, + CDType c1) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 8 && N == 8 && K == 128) { + for (int i = 0; i < 4; i++) { + ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + + c0 += sycl::popcount(recv_a ^ recv_b); + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c1 += sycl::popcount(recv_a ^ recv_b); + } + } + + *d0 = c0; + *d1 = c1; +} + +/// Multiplies 2 16x8 & 8x8 f16 matrices and accumulates the result to a +/// 16x8 f16 matrix (m16n8k8.row.col.f16.f16.f16.f16) +/// Requires the sub-group size of kernel +/// calling this function to be 32 +/// In: 2, 2, 1, 2 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +template +void mma(CDType *d0, CDType *d1, ABType a0, ABType a1, ABType b0, CDType c0, + CDType c1) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 8) { + auto c0_h = reinterpret_cast(&c0); + auto c1_h = reinterpret_cast(&c1); + + float c_f[4] = {c0_h[0], c0_h[1], c1_h[0], c1_h[1]}; + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 2; j++) { + c_f[0] += static_cast(ra[j]) * static_cast(rb[j]); + c_f[1] += static_cast(ra[j]) * static_cast(rb[j + 2]); + c_f[2] += static_cast(ra[j + 2]) * static_cast(rb[j]); + c_f[3] += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); + } + } + + c0_h[0] = c_f[0]; + c0_h[1] = c_f[1]; + c1_h[0] = c_f[2]; + c1_h[1] = c_f[3]; + } + + *d0 = c0; + *d1 = c1; +} + +/// Multiplies 2 16x4 & 4x8 f32/f64 matrices and accumulates the result to a +/// 16x8 f32/f64 matrix (m16n8k4.row.col.f32.f32.f32.f32 / +/// m16n8k4.row.col.f64.f64.f64.f64) +/// Requires the sub-group size of kernel +/// calling this function to be 32 +/// In: 4, 2, 1, 4 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +std::enable_if_t, void> +mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 4) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += recv_a[0] * recv_b[0]; + c1 += recv_a[0] * recv_b[1]; + c2 += recv_a[1] * recv_b[0]; + c3 += recv_a[1] * recv_b[1]; + } + } else if (M == 16 && N == 8 && K == 8) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 2; j++) { + c0 += static_cast(ra[j]) * static_cast(rb[j]); + c1 += static_cast(ra[j]) * static_cast(rb[j + 2]); + c2 += static_cast(ra[j + 2]) * static_cast(rb[j]); + c3 += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); + } + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x8 & 8x8 f16 matrices and accumulates the result to a +/// 16x8 f32 matrix (m16n8k8.row.col.f32.f16.f16.f32) +/// Requires the sub-group size of kernel +/// calling this function to be 32 +/// In: 4, 2, 1, 4 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +std::enable_if_t, void> +mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 8) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 2; j++) { + c0 += static_cast(ra[j]) * static_cast(rb[j]); + c1 += static_cast(ra[j]) * static_cast(rb[j + 2]); + c2 += static_cast(ra[j + 2]) * static_cast(rb[j]); + c3 += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); + } + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x16 & 16x8 f16 matrices and accumulates the result to a 16x8 +/// f16 matrix (m16n8k16.row.col.f16.f16.f16.f16) Requires the sub-group size of +/// kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 2, 4, 2, 2 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +template +void mma(volatile CDType *d0, volatile CDType *d1, ABType a0, ABType a1, + ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 16) { + auto c0_h = reinterpret_cast(&c0); + auto c1_h = reinterpret_cast(&c1); + + float c_f[4] = {c0_h[0], c0_h[1], c1_h[0], c1_h[1]}; + + for (int i = 0; i < 4; i++) { + ABType recv_a[4], recv_b[4]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[2] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[3] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 4; j++) { + c_f[0] += static_cast(ra[j]) * static_cast(rb[j]); + c_f[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); + c_f[2] += static_cast(ra[j + 4]) * static_cast(rb[j]); + c_f[3] += static_cast(ra[j + 4]) * static_cast(rb[j + 4]); + } + } + + c0_h[0] = c_f[0]; + c0_h[1] = c_f[1]; + c1_h[0] = c_f[2]; + c1_h[1] = c_f[3]; } *d0 = c0; @@ -2252,12 +2707,15 @@ void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1, } /// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32 -/// matrix +/// matrix (m16n8k16.row.col.f32.f16.f16.f32) /// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix /// \tparam [in] MulType The type of the multiplication result /// \tparam [in] ABType The type of the input matrices /// \tparam [in] CDType The type of the output matrix -/// \tparam [in] ItemT The type of the sycl::nd_item index space class +/// In: 4, 4, 2, 4 /// \param [in] d0 The 1st element to be written to the output D matrix /// \param [in] d1 The 2nd element to be written to the output D matrix /// \param [in] d2 The 3rd element to be written to the output D matrix @@ -2272,53 +2730,707 @@ void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1, /// \param [in] c1 The 2nd element from C matrix to be added with d1 /// \param [in] c2 The 3rd element from C matrix to be added with d2 /// \param [in] c3 The 4th element from C matrix to be added with d3 -/// \param [in] item The sycl::nd_item index space class -template +template +std::enable_if_t, void> +mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, + CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 16) { + for (int i = 0; i < 4; i++) { + ABType recv_a[4], recv_b[4]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[2] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[3] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 4 + i); + recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 4 + i); + + auto *ra0 = reinterpret_cast(recv_a); + auto *ra1 = reinterpret_cast(recv_a + 2); + auto *rb0 = reinterpret_cast(recv_b); + auto *rb1 = reinterpret_cast(recv_b + 2); + + // Iterate for k (i * j) times + for (int j = 0; j < 4; j++) { + auto a0 = static_cast(ra0[j]); + auto a1 = static_cast(ra1[j]); + auto b0 = static_cast(rb0[j]); + auto b1 = static_cast(rb1[j]); + + c0 += a0 * b0; + c1 += a0 * b1; + c2 += a1 * b0; + c3 += a1 * b1; + } + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x8 & 8x8 u4/s4 matrices and accumulates the result to a 16x8 +/// f64 matrix (m16n8k8.row.col.f64.f64.f64.f64) Requires the sub-group size of +/// kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 4, 2, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +std::enable_if_t, void> +mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, + CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 8) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += recv_a[0] * recv_b[0]; + c1 += recv_a[0] * recv_b[1]; + c2 += recv_a[1] * recv_b[0]; + c3 += recv_a[1] * recv_b[1]; + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + c0 += recv_a[0] * recv_b[0]; + c1 += recv_a[0] * recv_b[1]; + c2 += recv_a[1] * recv_b[0]; + c3 += recv_a[1] * recv_b[1]; + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x32 & 32x8 u8/s8 matrices and accumulates the result to a +/// 16x8 b32 matrix (m16n8k32.row.col.s32.u8.u8.s32 / +/// m16n8k32.row.col.s32.s8.s8.s32) Multiplies 2 16x64 & 64x8 u4/s4 matrices and +/// accumulates the result to a 16x8 b32 matrix (m16n8k64.row.col.s32.u4.u4.s32 +/// / m16n8k64.row.col.s32.s4.s4.s32) Requires the sub-group size of kernel +/// calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 4, 2, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +std::enable_if_t, void> +mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, + CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 32) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + c0 += a[k] * b[k]; + c1 += a[k] * b[k + 4]; + c2 += a[k + 4] * b[k]; + c3 += a[k + 4] * b[k + 4]; + } + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + c0 += a[k] * b[k]; + c1 += a[k] * b[k + 4]; + c2 += a[k + 4] * b[k]; + c3 += a[k + 4] * b[k + 4]; + } + } + } else if (M == 16 && N == 8 && K == 64) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; + + c0 += a00 * b00; + c0 += a01 * b01; + + c1 += a00 * b10; + c1 += a01 * b11; + + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; + + c0 += a00 * b00; + c0 += a01 * b01; + + c1 += a00 * b10; + c1 += a01 * b11; + + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x16 & 16x8 f64 matrices and accumulates the result to a 16x8 +/// f64 matrix (m16n8k16.row.col.f64.f64.f64.f64) Requires the sub-group size of +/// kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 8, 4, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] a4 The 5th element from A matrix to be multiplied with B matrix +/// \param [in] a5 The 6th element from A matrix to be multiplied with B matrix +/// \param [in] a6 The 7th element from A matrix to be multiplied with B matrix +/// \param [in] a7 The 8th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] b2 The 3rd element from B matrix to be multiplied with A matrix +/// \param [in] b3 The 4th element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, - CDType c2, CDType c3, const ItemT &item) { - int lane = item.get_sub_group().get_local_linear_id(); + ABType a2, ABType a3, ABType a4, ABType a5, ABType a6, ABType a7, + ABType b0, ABType b1, ABType b2, ABType b3, CDType c0, CDType c1, + CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 16) { + ABType recv_a[16 * 2], recv_b[16 * 2]; + + for (int i = 0; i < 4; i++) { + recv_a[i] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[i + 4] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[i + 8] = dpct::select_from_sub_group(sg, a4, ROW_LOAD_OFFSET + i); + recv_a[i + 12] = dpct::select_from_sub_group(sg, a6, ROW_LOAD_OFFSET + i); + recv_a[i + 16] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[i + 20] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_a[i + 24] = dpct::select_from_sub_group(sg, a5, ROW_LOAD_OFFSET + i); + recv_a[i + 28] = dpct::select_from_sub_group(sg, a7, ROW_LOAD_OFFSET + i); + + recv_b[i] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[i + 4] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[i + 8] = dpct::select_from_sub_group(sg, b2, COL_LOAD_OFFSET + i); + recv_b[i + 12] = dpct::select_from_sub_group(sg, b3, COL_LOAD_OFFSET + i); + recv_b[i + 16] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + recv_b[i + 20] = + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + recv_b[i + 24] = + dpct::select_from_sub_group(sg, b2, COL_LOAD_OFFSET + i + 4); + recv_b[i + 28] = + dpct::select_from_sub_group(sg, b3, COL_LOAD_OFFSET + i + 4); + } + + for (int i = 0; i < 16; i++) { + c0 += recv_a[i] * recv_b[i]; + c1 += recv_a[i] * recv_b[i + 16]; + c2 += recv_a[i + 16] * recv_b[i]; + c3 += recv_a[i + 16] * recv_b[i + 16]; + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x32 & 32x8 u4/s4 matrices and accumulates the result to a +/// 16x8 s32 matrix (m16n8k32.row.col.s32.u4.u4.s32 / +/// m16n8k32.row.col.s32.s4.s4.s32) +/// Multiplies 2 16x16 & 16x8 u8/s8 matrices and accumulates the result to a +/// 16x8 s32 matrix (m16n8k16.row.col.s32.u8.u8.s32 / +/// m16n8k16.row.col.s32.s8.s8.s32) +/// Requires the sub-group size of kernel +/// calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 2, 1, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +std::enable_if_t, void> +mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 16) { + ABType recv_a[4 * 2], recv_b[4 * 2]; + + for (int i = 0; i < 4; i++) { + recv_a[i] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[i + 4] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + + recv_b[i] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[i + 4] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + } + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + for (int i = 0; i < 16; i++) { + c0 += a[i] * b[i]; + c1 += a[i] * b[i + 16]; + c2 += a[i + 16] * b[i]; + c3 += a[i + 16] * b[i + 16]; + } + } else if (M == 16 && N == 8 && K == 32) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; + + c0 += a00 * b00; + c0 += a01 * b01; + + c1 += a00 * b10; + c1 += a01 * b11; + + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc) +/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 2, 1, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +void mma_and(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, + ABType a1, ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 128) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(recv_a[0] & recv_b[0]); + c1 += sycl::popcount(recv_a[0] & recv_b[1]); + c2 += sycl::popcount(recv_a[1] & recv_b[0]); + c3 += sycl::popcount(recv_a[1] & recv_b[1]); + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc) +/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 2, 1, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +void mma_xor(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, + ABType a1, ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 128) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(recv_a[0] ^ recv_b[0]); + c1 += sycl::popcount(recv_a[0] ^ recv_b[1]); + c2 += sycl::popcount(recv_a[1] ^ recv_b[0]); + c3 += sycl::popcount(recv_a[1] ^ recv_b[1]); + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc) +/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 4, 2, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +void mma_and(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, + ABType a1, ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, + CDType c1, CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); short ROW_LOAD_OFFSET = 4 * (lane >> 2); short COL_LOAD_OFFSET = 8 * (lane % 4); - for (int i = 0; i < 4; i++) { - ABType recv_a[4], recv_b[4]; - - recv_a[0] = dpct::select_from_sub_group(item.get_sub_group(), a0, - ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(item.get_sub_group(), a2, - ROW_LOAD_OFFSET + i); - recv_a[2] = dpct::select_from_sub_group(item.get_sub_group(), a1, - ROW_LOAD_OFFSET + i); - recv_a[3] = dpct::select_from_sub_group(item.get_sub_group(), a3, - ROW_LOAD_OFFSET + i); - - recv_b[0] = dpct::select_from_sub_group(item.get_sub_group(), b0, - COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(item.get_sub_group(), b1, - COL_LOAD_OFFSET + i); - recv_b[2] = dpct::select_from_sub_group(item.get_sub_group(), b0, - COL_LOAD_OFFSET + 4 + i); - recv_b[3] = dpct::select_from_sub_group(item.get_sub_group(), b1, - COL_LOAD_OFFSET + 4 + i); - - auto *ra0 = reinterpret_cast(recv_a); - auto *ra1 = reinterpret_cast(recv_a + 2); - auto *rb0 = reinterpret_cast(recv_b); - auto *rb1 = reinterpret_cast(recv_b + 2); - - // Iterate for k (i * j) times - for (int j = 0; j < 4; j++) { - auto a0 = static_cast(ra0[j]); - auto a1 = static_cast(ra1[j]); - auto b0 = static_cast(rb0[j]); - auto b1 = static_cast(rb1[j]); - - c0 += a0 * b0; - c1 += a0 * b1; - c2 += a1 * b0; - c3 += a1 * b1; + if (M == 16 && N == 8 && K == 256) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(recv_a[0] & recv_b[0]); + c1 += sycl::popcount(recv_a[0] & recv_b[1]); + c2 += sycl::popcount(recv_a[1] & recv_b[0]); + c3 += sycl::popcount(recv_a[1] & recv_b[1]); + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(recv_a[0] & recv_b[0]); + c1 += sycl::popcount(recv_a[0] & recv_b[1]); + c2 += sycl::popcount(recv_a[1] & recv_b[0]); + c3 += sycl::popcount(recv_a[1] & recv_b[1]); + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc) +/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 4, 2, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +void mma_xor(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, + ABType a1, ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, + CDType c1, CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 256) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(recv_a[0] ^ recv_b[0]); + c1 += sycl::popcount(recv_a[0] ^ recv_b[1]); + c2 += sycl::popcount(recv_a[1] ^ recv_b[0]); + c3 += sycl::popcount(recv_a[1] ^ recv_b[1]); + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(recv_a[0] ^ recv_b[0]); + c1 += sycl::popcount(recv_a[0] ^ recv_b[1]); + c2 += sycl::popcount(recv_a[1] ^ recv_b[0]); + c3 += sycl::popcount(recv_a[1] ^ recv_b[1]); } } diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu index 75d2d8ab5aba..a959e383b2b1 100644 --- a/clang/test/dpct/asm/mma.cu +++ b/clang/test/dpct/asm/mma.cu @@ -9,36 +9,393 @@ #include /* -mma.sync.aligned.m16n8k16.alayout.blayout.dtype.f16.f16.ctype d, a, b, c; +As per PTX ASM 8.1, below is the status of supported configurations -Below are the currenly supported configurations: +--------- --------- ---------- ----------- ------------- +| Shape | | A | | B | | C / D | | Supported | +--------- --------- ---------- ----------- ------------- +m8n8k4 .f16 .f16 .f16/.f32 Yes + .f64 .f64 .f64 Yes +m8n8k16 .s8/.u8 .s8/.u8 .s32 Yes +m8n8k32 .s4/.u4 .s4/.u4 .s32 Yes +m8n8k128 .b1 .b1 .s32 Yes -.alayout = {.row}; -.blayout = {.col}; -.ctype = {.f32}; -.dtype = {.f32}; +m16n8k4 .tf32 .tf32 .tf32 No + .f64 .f64 .f64 Yes +m16n8k8 .f16/.bf16 .f16/.bf16 .f16/.f32 Partial (.f16.f16.f16.f16 / .f32.f16.f16.f32) + .tf32 .tf32 .tf32 No + .f64 .f64 .f64 Yes +m16n8k16 .f16/.bf16 .f16/.bf16 .f16/.f32 Partial (.f16.f16.f16.f16 / .f32.f16.f16.f32) + .f64 .f64 .f64 Yes + .s8/.u8 .s8/.u8 .s32 Yes +m16n8k32 .s4/.u4 .s4/.u4 .s32 Yes + .s8/.u8 .s8/.u8 .s32 Yes +m16n8k64 .s4/.u4 .s4/.u4 .s32 Yes +m16n8k128 .b1 .b1 .s32 Yes +m16n8k256 .b1 .b1 .s32 Yes + +A Layout: row +B Layout: col */ -__global__ void mma_kernel() { - int a[4]; - int b[2]; - float c[4]; +__global__ void mma_kernel_m8n8k4(int *a, int *b, float *c, double *d) { + // CHECK: dpct::experimental::matrix::mma<8, 8, 4, sycl::half>(&c[0], &c[1], &c[2], &c[3], a[0], a[1], b[0], b[1], c[0], c[1], c[2], c[3]); + asm("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6, %7 }, " + " { %8, %9, %10, %11 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + + // CHECK: dpct::experimental::matrix::mma<8, 8, 4, sycl::half>(&c[0], &c[1], &c[2], &c[3], &c[4], &c[5], &c[6], &c[7], a[0], a[1], b[0], b[1], c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]); + asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " + " { %0, %1, %2, %3, %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11 }, " + " { %0, %1, %2, %3, %4, %5, %6, %7 };" + : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3]), "+f"(c[4]), "+f"(c[5]), "+f"(c[6]), "+f"(c[7]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), "r"(b[1])); + + // CHECK: dpct::experimental::matrix::mma<8, 8, 4, double>(&d[0], &d[1], a[0], b[0], d[0], d[1]); + asm("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " + " { %0, %1 }, " + " { %2 }, " + " { %3 }, " + " { %4, %5 };" + : "=d"(d[0]), "=d"(d[1]) + : "d"(a[0]), + "d"(b[0]), + "d"(d[0]), "d"(d[1])); +} + +__global__ void mma_kernel_m8n8k16(int *a, int *b, int *c, int *d) { + // CHECK: dpct::experimental::matrix::mma<8, 8, 16, int8_t>(&d[0], &d[1], a[0], b[0], c[0], c[1]); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 " + " { %0, %1 }, " + " { %2 }, " + " { %3 }, " + " { %4, %5 };" + : "=r"(d[0]), "=r"(d[1]) + : "r"(a[0]), + "r"(b[0]), + "r"(c[0]), "r"(c[1])); + + // CHECK: dpct::experimental::matrix::mma<8, 8, 16, uint8_t>(&d[0], &d[1], a[0], b[0], c[0], c[1]); + asm("mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " + " { %0, %1 }, " + " { %2 }, " + " { %3 }, " + " { %4, %5 };" + : "=r"(d[0]), "=r"(d[1]) + : "r"(a[0]), + "r"(b[0]), + "r"(c[0]), "r"(c[1])); +} + +__global__ void mma_kernel_m8n8k32(int *a, int *b, int *c, int *d) { + // CHECK: dpct::experimental::matrix::mma<8, 8, 32, int8_t>(&d[0], &d[1], a[0], b[0], c[0], c[1]); + asm("mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " + " { %0, %1 }, " + " { %2 }, " + " { %3 }, " + " { %4, %5 };" + : "=r"(d[0]), "=r"(d[1]) + : "r"(a[0]), + "r"(b[0]), + "r"(c[0]), "r"(c[1])); + + // CHECK: dpct::experimental::matrix::mma<8, 8, 32, uint8_t>(&d[0], &d[1], a[0], b[0], c[0], c[1]); + asm("mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " + " { %0, %1 }, " + " { %2 }, " + " { %3 }, " + " { %4, %5 };" + : "=r"(d[0]), "=r"(d[1]) + : "r"(a[0]), + "r"(b[0]), + "r"(c[0]), "r"(c[1])); +} + +__global__ void mma_kernel_m8n8k128(int *a, int *b, int *c, int *d) { + // CHECK: dpct::experimental::matrix::mma<8, 8, 128, uint8_t, sycl::bit_and<>>(&d[0], &d[1], a[0], b[0], c[0], c[1]); + asm ("mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc " + " { %0, %1 }, " + " { %2 }, " + " { %3 }, " + " { %4, %5 };" + : "=r"(d[0]), "=r"(d[1]) + : "r"(a[0]), + "r"(b[0]), + "r"(c[0]), "r"(c[1])); + + // CHECK: dpct::experimental::matrix::mma<8, 8, 128, uint8_t, sycl::bit_xor<>>(&d[0], &d[1], a[0], b[0], c[0], c[1]); + asm ("mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " + " { %0, %1 }, " + " { %2 }, " + " { %3 }, " + " { %4, %5 };" + : "=r"(d[0]), "=r"(d[1]) + : "r"(a[0]), + "r"(b[0]), + "r"(c[0]), "r"(c[1])); +} + +__global__ void mma_kernel_m16n8k4(float *fa, float *fb, float *fc, float *fd, double *da, double *db, double *dc, double *dd) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 4, double>(&dd[0], &dd[1], &dd[2], &dd[3], da[0], da[1], db[0], dc[0], dc[1], dc[2], dc[3]); + asm("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=d"(dd[0]), "=d"(dd[1]), "=d"(dd[2]), "=d"(dd[3]) + : "d"(da[0]), "d"(da[1]), + "d"(db[0]), + "d"(dc[0]), "d"(dc[1]), "d"(dc[2]), "d"(dc[3])); +} + +__global__ void mma_kernel_m16n8k8(int *a, int *b, uint *c, uint *d, float *fa, float *fb, float *fc, float *fd, double *da, double *db, double *dc, double *dd) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 8, sycl::half>(&d[0], &d[1], (*(reinterpret_cast(&a[0]))), (*(reinterpret_cast(&a[1]))), (*(reinterpret_cast(&b[0]))), c[0], c[1]); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + " { %0, %1 }, " + " { %2, %3 }, " + " { %4 }, " + " { %5, %6 };" + : "=r"(d[0]), "=r"(d[1]) + : "r"(*(reinterpret_cast(&a[0]))), + "r"(*(reinterpret_cast(&a[1]))), + "r"(*(reinterpret_cast(&b[0]))), + "r"(c[0]), "r"(c[1])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 8, sycl::half>(&fd[0], &fd[1], &fd[2], &fd[3], *(reinterpret_cast(&a[0])), *(reinterpret_cast(&a[1])), *(reinterpret_cast(&b[0])), fc[0], fc[1], fc[2], fc[3]); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=f"(fd[0]), "=f"(fd[1]), "=f"(fd[2]), "=f"(fd[3]) + : "r"(*(reinterpret_cast(&a[0]))), + "r"(*(reinterpret_cast(&a[1]))), + "r"(*(reinterpret_cast(&b[0]))), + "f"(fc[0]), "f"(fc[1]), "f"(fc[2]), "f"(fc[3])); - // CHECK: dpct::experimental::matrix::mma(&c[0], &c[1], &c[2], &c[3], a[0], a[1], a[2], a[3], b[0], b[1], c[0], c[1], c[2], c[3], item_ct1); + // CHECK: dpct::experimental::matrix::mma<16, 8, 8, double>(&dd[0], &dd[1], &dd[2], &dd[3], da[0], da[1], da[2], da[3], db[0], db[1], dc[0], dc[1], dc[2], dc[3]); + asm("mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=d"(dd[0]), "=d"(dd[1]), "=d"(dd[2]), "=d"(dd[3]) + : "d"(da[0]), "d"(da[1]), "d"(da[2]), "d"(da[3]), + "d"(db[0]), "d"(db[1]), + "d"(dc[0]), "d"(dc[1]), "d"(dc[2]), "d"(dc[3])); +} + +__global__ void mma_kernel_m16n8k16(uint *ua, uint *ub, uint *uc, uint *ud, int *a, int *b, int *c, float *fc, int *d, double *da, double *db, double *dc, double *dd) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, sycl::half>(&ud[0], &ud[1], a[0], a[1], a[2], a[3], b[0], b[1], uc[0], uc[1]); + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + " { %0, %1 }, " + " { %2, %3, %4, %5 }, " + " { %6, %7 }, " + " { %8, %9 };" + : "=r"(ud[0]), "=r"(ud[1]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(uc[0]), "r"(uc[1])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, sycl::half>(&fc[0], &fc[1], &fc[2], &fc[3], a[0], a[1], a[2], a[3], b[0], b[1], fc[0], fc[1], fc[2], fc[3]); asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " " { %0, %1, %2, %3 }, " " { %4, %5, %6, %7 }, " " { %8, %9 }, " " { %0, %1, %2, %3 };" - : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3]) + : "+f"(fc[0]), "+f"(fc[1]), "+f"(fc[2]), "+f"(fc[3]) : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, double>(&dd[0], &dd[1], &dd[2], &dd[3], da[0], da[1], da[2], da[3], da[4], da[5], da[6], da[7], db[0], db[1], db[2], db[3], dc[0], dc[1], dc[2], dc[3]); + asm("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7, %8, %9, %10, %11 }, " + " { %12, %13, %14, %15 }, " + " { %16, %17, %18, %19 };" + : "=d"(dd[0]), "=d"(dd[1]), "=d"(dd[2]), "=d"(dd[3]) + : "d"(da[0]), "d"(da[1]), "d"(da[2]), "d"(da[3]), "d"(da[4]), "d"(da[5]), "d"(da[6]), "d"(da[7]), + "d"(db[0]), "d"(db[1]), "d"(db[2]), "d"(db[3]), + "d"(dc[0]), "d"(dc[1]), "d"(dc[2]), "d"(dc[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, uint8_t>(&ud[0], &ud[1], &ud[2], &ud[3], ua[0], ua[1], ub[0], uc[0], uc[1], uc[2], uc[3]); + asm("mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(ud[0]), "=r"(ud[1]), "=r"(ud[2]), "=r"(ud[3]) + : "r"(ua[0]), "r"(ua[1]), + "r"(ub[0]), + "r"(uc[0]), "r"(uc[1]), "r"(uc[2]), "r"(uc[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, int8_t>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], b[0], c[0], c[1], c[2], c[3]); + asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +__global__ void mma_kernel_m16n8k32(uint *ua, uint *ub, uint *uc, uint *ud, int *a, int *b, int *c, int *d) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 32, uint8_t>(&ud[0], &ud[1], &ud[2], &ud[3], ua[0], ua[1], ub[0], uc[0], uc[1], uc[2], uc[3]); + asm("mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(ud[0]), "=r"(ud[1]), "=r"(ud[2]), "=r"(ud[3]) + : "r"(ua[0]), "r"(ua[1]), + "r"(ub[0]), + "r"(uc[0]), "r"(uc[1]), "r"(uc[2]), "r"(uc[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 32, int8_t>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], b[0], c[0], c[1], c[2], c[3]); + asm("mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 32, uint8_t>(&ud[0], &ud[1], &ud[2], &ud[3], ua[0], ua[1], ua[2], ua[3], ub[0], ub[1], uc[0], uc[1], uc[2], uc[3]); + asm("mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(ud[0]), "=r"(ud[1]), "=r"(ud[2]), "=r"(ud[3]) + : "r"(ua[0]), "r"(ua[1]), "r"(ua[2]), "r"(ua[3]), + "r"(ub[0]), "r"(ub[1]), + "r"(uc[0]), "r"(uc[1]), "r"(uc[2]), "r"(uc[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 32, int8_t>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], a[2], a[3], b[0], b[1], c[0], c[1], c[2], c[3]); + asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +__global__ void mma_kernel_m16n8k64(uint *ua, uint *ub, uint *uc, uint *ud, int *a, int *b, int *c, int *d) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 64, uint8_t>(&ud[0], &ud[1], &ud[2], &ud[3], ua[0], ua[1], ua[2], ua[3], ub[0], ub[1], uc[0], uc[1], uc[2], uc[3]); + asm("mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(ud[0]), "=r"(ud[1]), "=r"(ud[2]), "=r"(ud[3]) + : "r"(ua[0]), "r"(ua[1]), "r"(ua[2]), "r"(ua[3]), + "r"(ub[0]), "r"(ub[1]), + "r"(uc[0]), "r"(uc[1]), "r"(uc[2]), "r"(uc[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 64, int8_t>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], a[2], a[3], b[0], b[1], c[0], c[1], c[2], c[3]); + asm("mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +__global__ void mma_kernel_m16n8k128(int *a, int *b, int *c, int *d) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 128, uint8_t, sycl::bit_and<>>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], b[0], c[0], c[1], c[2], c[3]); + asm ("mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 128, uint8_t, sycl::bit_xor<>>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], b[0], c[0], c[1], c[2], c[3]); + asm ("mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +__global__ void mma_kernel_m16n8k256(int *a, int *b, int *c, int *d) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 256, uint8_t, sycl::bit_and<>>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], a[2], a[3], b[0], b[1], c[0], c[1], c[2], c[3]); + asm ("mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 256, uint8_t, sycl::bit_xor<>>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], a[2], a[3], b[0], b[1], c[0], c[1], c[2], c[3]); + asm ("mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %10, %11, %12, %13 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); } int main () { + uint *uint_a, *uint_b, *uint_c, *uint_d; + int *int_a, *int_b, *int_c, *int_d; + float *float_a, *float_b, *float_c, *float_d; + double *double_a, *double_b, *double_c, *double_d; + + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m8n8k4<<<1, 32>>>(int_a, int_b, float_c, double_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m8n8k16<<<1, 32>>>(int_a, int_b, int_c, int_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m8n8k32<<<1, 32>>>(int_a, int_b, int_c, int_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m8n8k128<<<1, 32>>>(int_a, int_b, int_c, int_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k4<<<1, 32>>>(float_a, float_b, float_c, float_d, double_a, double_b, double_c, double_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k8<<<1, 32>>>(int_a, int_b, uint_c, uint_d, float_a, float_b, float_c, float_d, double_a, double_b, double_c, double_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k16<<<1, 32>>>(uint_a, uint_b, uint_c, uint_d, int_a, int_b, int_c, float_c, int_d, double_a, double_b, double_c, double_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k32<<<1, 32>>>(uint_a, uint_b, uint_c, uint_d, int_a, int_b, int_c, int_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k64<<<1, 32>>>(uint_a, uint_b, uint_c, uint_d, int_a, int_b, int_c, int_d); + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k128<<<1, 32>>>(int_a, int_b, int_c, int_d); // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { - mma_kernel<<<1, 32>>>(); + mma_kernel_m16n8k256<<<1, 32>>>(int_a, int_b, int_c, int_d); return 0; } From 0769b4ad9505d5c59fa22796b8d4e415854423c2 Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Mon, 5 May 2025 21:32:50 +0800 Subject: [PATCH 5/6] Merged type based overloads using constexpr --- clang/runtime/dpct-rt/include/dpct/math.hpp | 654 +++++++------------- 1 file changed, 228 insertions(+), 426 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index e00763a0291d..8e97ca6bd1ca 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2231,6 +2231,10 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5, /// Multiplies 2 8x4 & 4x8 f64 matrices and accumulates the result to a 8x8 b64 /// matrix (m8n8k4.row.col.f64.f64.f64.f64) +/// Multiplies 2 8x16 & 16x8 u8/s8 matrices and accumulates the result to a 8x8 +/// s32 matrix (m8n8k16.row.col.s32.u8.u8.s32 / m8n8k16.row.col.s32.s8.s8.s32) +/// Multiplies 2 8x32 & 32x8 u4/s4 matrices and accumulates the result to a 8x8 +/// s32 matrix (m8n8k32.row.col.s32.u4.u4.s32 / m8n8k32.row.col.s32.s4.s4.s32) /// Requires the sub-group size of kernel calling this function to be 32 /// In: 2, 1, 1, 2 /// \tparam [in] M The rows of A/C/D matrix @@ -2247,8 +2251,7 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5, /// \param [in] c1 The 2nd element from C matrix to be added with d1 template -typename std::enable_if_t, void> -mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1) { +void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1) { auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); int lane = sg.get_local_linear_id(); @@ -2264,41 +2267,7 @@ mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1) { recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); c1 += recv_a * recv_b; } - } - - *d0 = c0; - *d1 = c1; -} - -/// Multiplies 2 8x16 & 16x8 u8/s8 matrices and accumulates the result to a 8x8 -/// s32 matrix (m8n8k16.row.col.s32.u8.u8.s32 / m8n8k16.row.col.s32.s8.s8.s32) -/// Multiplies 2 8x32 & 32x8 u4/s4 matrices and accumulates the result to a 8x8 -/// s32 matrix (m8n8k32.row.col.s32.u4.u4.s32 / m8n8k32.row.col.s32.s4.s4.s32) -/// Requires the sub-group size of kernel calling this function to be 32 -/// In: 2, 1, 1, 2 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -template -std::enable_if_t, void> -mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 8 && N == 8 && K == 16) { + } else if (M == 8 && N == 8 && K == 16) { for (int i = 0; i < 4; i++) { ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); @@ -2317,33 +2286,37 @@ mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1) { } } } else if (M == 8 && N == 8 && K == 32) { - for (int i = 0; i < 4; i++) { - ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - - MulType *a = reinterpret_cast(&recv_a); - MulType *b = reinterpret_cast(&recv_b); - - for (int k = 0; k < 4; k++) { - MulType a0 = a[k] >> 4; - MulType a1 = a[k] & 0x0F; - MulType b0 = b[k] >> 4; - MulType b1 = b[k] & 0x0F; - - c0 += a0 * b0; - c0 += a1 * b1; - } - - recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - for (int k = 0; k < 4; k++) { - MulType a0 = a[k] >> 4; - MulType a1 = a[k] & 0x0F; - MulType b0 = b[k] >> 4; - MulType b1 = b[k] & 0x0F; - - c1 += a0 * b0; - c1 += a1 * b1; + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a = + dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + + MulType *a = reinterpret_cast(&recv_a); + MulType *b = reinterpret_cast(&recv_b); + + for (int k = 0; k < 4; k++) { + MulType a0 = a[k] >> 4; + MulType a1 = a[k] & 0x0F; + MulType b0 = b[k] >> 4; + MulType b1 = b[k] & 0x0F; + + c0 += a0 * b0; + c0 += a1 * b1; + } + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + for (int k = 0; k < 4; k++) { + MulType a0 = a[k] >> 4; + MulType a1 = a[k] & 0x0F; + MulType b0 = b[k] >> 4; + MulType b1 = b[k] & 0x0F; + + c1 += a0 * b0; + c1 += a1 * b1; + } } } } @@ -2501,144 +2474,9 @@ void mma(CDType *d0, CDType *d1, ABType a0, ABType a1, ABType b0, CDType c0, *d1 = c1; } -/// Multiplies 2 16x4 & 4x8 f32/f64 matrices and accumulates the result to a -/// 16x8 f32/f64 matrix (m16n8k4.row.col.f32.f32.f32.f32 / -/// m16n8k4.row.col.f64.f64.f64.f64) -/// Requires the sub-group size of kernel -/// calling this function to be 32 -/// In: 4, 2, 1, 4 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template -std::enable_if_t, void> -mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 4) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - c0 += recv_a[0] * recv_b[0]; - c1 += recv_a[0] * recv_b[1]; - c2 += recv_a[1] * recv_b[0]; - c3 += recv_a[1] * recv_b[1]; - } - } else if (M == 16 && N == 8 && K == 8) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - auto ra = reinterpret_cast(recv_a); - auto rb = reinterpret_cast(recv_b); - - for (int j = 0; j < 2; j++) { - c0 += static_cast(ra[j]) * static_cast(rb[j]); - c1 += static_cast(ra[j]) * static_cast(rb[j + 2]); - c2 += static_cast(ra[j + 2]) * static_cast(rb[j]); - c3 += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); - } - } - } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; -} - -/// Multiplies 2 16x8 & 8x8 f16 matrices and accumulates the result to a -/// 16x8 f32 matrix (m16n8k8.row.col.f32.f16.f16.f32) -/// Requires the sub-group size of kernel -/// calling this function to be 32 -/// In: 4, 2, 1, 4 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template -std::enable_if_t, void> -mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 8) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - auto ra = reinterpret_cast(recv_a); - auto rb = reinterpret_cast(recv_b); - - for (int j = 0; j < 2; j++) { - c0 += static_cast(ra[j]) * static_cast(rb[j]); - c1 += static_cast(ra[j]) * static_cast(rb[j + 2]); - c2 += static_cast(ra[j + 2]) * static_cast(rb[j]); - c3 += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); - } - } - } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; -} - /// Multiplies 2 16x16 & 16x8 f16 matrices and accumulates the result to a 16x8 -/// f16 matrix (m16n8k16.row.col.f16.f16.f16.f16) Requires the sub-group size of -/// kernel calling this function to be 32 +/// f16 matrix (m16n8k16.row.col.f16.f16.f16.f16). +/// Requires the sub-group size of kernel calling this function to be 32 /// \tparam [in] M The rows of A/C/D matrix /// \tparam [in] N The columns of B/C/D matrix /// \tparam [in] K The columns/rows of A/B matrix @@ -2706,85 +2544,17 @@ void mma(volatile CDType *d0, volatile CDType *d1, ABType a0, ABType a1, *d1 = c1; } -/// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32 -/// matrix (m16n8k16.row.col.f32.f16.f16.f32) -/// Requires the sub-group size of kernel calling this function to be 32 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 4, 4, 2, 4 -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix -/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template -std::enable_if_t, void> -mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, - CDType c3) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 16) { - for (int i = 0; i < 4; i++) { - ABType recv_a[4], recv_b[4]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); - recv_a[2] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_a[3] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); - - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 4 + i); - recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 4 + i); - - auto *ra0 = reinterpret_cast(recv_a); - auto *ra1 = reinterpret_cast(recv_a + 2); - auto *rb0 = reinterpret_cast(recv_b); - auto *rb1 = reinterpret_cast(recv_b + 2); - - // Iterate for k (i * j) times - for (int j = 0; j < 4; j++) { - auto a0 = static_cast(ra0[j]); - auto a1 = static_cast(ra1[j]); - auto b0 = static_cast(rb0[j]); - auto b1 = static_cast(rb1[j]); - - c0 += a0 * b0; - c1 += a0 * b1; - c2 += a1 * b0; - c3 += a1 * b1; - } - } - } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; -} - /// Multiplies 2 16x8 & 8x8 u4/s4 matrices and accumulates the result to a 16x8 -/// f64 matrix (m16n8k8.row.col.f64.f64.f64.f64) Requires the sub-group size of -/// kernel calling this function to be 32 +/// f64 matrix (m16n8k8.row.col.f64.f64.f64.f64). +/// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32 +/// matrix (m16n8k16.row.col.f32.f16.f16.f32). +/// Multiplies 2 16x32 & 32x8 u8/s8 matrices and accumulates the result to a +/// 16x8 b32 matrix (m16n8k32.row.col.s32.u8.u8.s32 / +/// m16n8k32.row.col.s32.s8.s8.s32). +/// Multiplies 2 16x64 & 64x8 u4/s4 matrices and +/// accumulates the result to a 16x8 b32 matrix (m16n8k64.row.col.s32.u4.u4.s32 +/// / m16n8k64.row.col.s32.s4.s4.s32). +/// Requires the sub-group size of kernel calling this function to be 32. /// \tparam [in] M The rows of A/C/D matrix /// \tparam [in] N The columns of B/C/D matrix /// \tparam [in] K The columns/rows of A/B matrix @@ -2808,10 +2578,9 @@ mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, /// \param [in] c3 The 4th element from C matrix to be added with d3 template -std::enable_if_t, void> -mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, - CDType c3) { +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, + CDType c2, CDType c3) { auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); int lane = sg.get_local_linear_id(); @@ -2846,54 +2615,39 @@ mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, c2 += recv_a[1] * recv_b[0]; c3 += recv_a[1] * recv_b[1]; } - } + } else if (M == 16 && N == 8 && K == 16) { + for (int i = 0; i < 4; i++) { + ABType recv_a[4], recv_b[4]; - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; -} + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[2] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[3] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); -/// Multiplies 2 16x32 & 32x8 u8/s8 matrices and accumulates the result to a -/// 16x8 b32 matrix (m16n8k32.row.col.s32.u8.u8.s32 / -/// m16n8k32.row.col.s32.s8.s8.s32) Multiplies 2 16x64 & 64x8 u4/s4 matrices and -/// accumulates the result to a 16x8 b32 matrix (m16n8k64.row.col.s32.u4.u4.s32 -/// / m16n8k64.row.col.s32.s4.s4.s32) Requires the sub-group size of kernel -/// calling this function to be 32 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 4, 4, 2, 4 -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix -/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template -std::enable_if_t, void> -mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, - CDType c3) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 4 + i); + recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 4 + i); - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); + auto *ra0 = reinterpret_cast(recv_a); + auto *ra1 = reinterpret_cast(recv_a + 2); + auto *rb0 = reinterpret_cast(recv_b); + auto *rb1 = reinterpret_cast(recv_b + 2); - if (M == 16 && N == 8 && K == 32) { + // Iterate for k (i * j) times + for (int j = 0; j < 4; j++) { + auto a0 = static_cast(ra0[j]); + auto a1 = static_cast(ra1[j]); + auto b0 = static_cast(rb0[j]); + auto b1 = static_cast(rb1[j]); + + c0 += a0 * b0; + c1 += a0 * b1; + c2 += a1 * b0; + c3 += a1 * b1; + } + } + } else if (M == 16 && N == 8 && K == 32) { for (int i = 0; i < 4; i++) { ABType recv_a[2], recv_b[2]; @@ -2932,73 +2686,77 @@ mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, } } } else if (M == 16 && N == 8 && K == 64) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; + + c0 += a00 * b00; + c0 += a01 * b01; + + c1 += a00 * b10; + c1 += a01 * b11; + + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } + } - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; - MulType *a = reinterpret_cast(recv_a); - MulType *b = reinterpret_cast(recv_b); + recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); - for (int k = 0; k < 4; k++) { - MulType a00 = a[k] >> 4; - MulType a01 = a[k] & 0x0F; - MulType a10 = a[k + 4] >> 4; - MulType a11 = a[k + 4] & 0x0F; - MulType b00 = b[k] >> 4; - MulType b01 = b[k] & 0x0F; - MulType b10 = b[k + 4] >> 4; - MulType b11 = b[k + 4] & 0x0F; - - c0 += a00 * b00; - c0 += a01 * b01; - - c1 += a00 * b10; - c1 += a01 * b11; - - c2 += a10 * b00; - c2 += a11 * b01; - - c3 += a10 * b10; - c3 += a11 * b11; - } - } + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; - recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + c0 += a00 * b00; + c0 += a01 * b01; - MulType *a = reinterpret_cast(recv_a); - MulType *b = reinterpret_cast(recv_b); + c1 += a00 * b10; + c1 += a01 * b11; - for (int k = 0; k < 4; k++) { - MulType a00 = a[k] >> 4; - MulType a01 = a[k] & 0x0F; - MulType a10 = a[k + 4] >> 4; - MulType a11 = a[k + 4] & 0x0F; - MulType b00 = b[k] >> 4; - MulType b01 = b[k] & 0x0F; - MulType b10 = b[k + 4] >> 4; - MulType b11 = b[k + 4] & 0x0F; - - c0 += a00 * b00; - c0 += a01 * b01; - - c1 += a00 * b10; - c1 += a01 * b11; - - c2 += a10 * b00; - c2 += a11 * b01; - - c3 += a10 * b10; - c3 += a11 * b11; + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } } } } @@ -3092,12 +2850,21 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, *d3 = c3; } -/// Multiplies 2 16x32 & 32x8 u4/s4 matrices and accumulates the result to a -/// 16x8 s32 matrix (m16n8k32.row.col.s32.u4.u4.s32 / -/// m16n8k32.row.col.s32.s4.s4.s32) +/// Multiplies 2 16x4 & 4x8 f16 matrices and accumulates the result to a +/// 16x8 f32 matrix (m16n8k4.row.col.f16.f16.f16.f16 / +/// m16n8k4.row.col.f32.f16.f16.f32) +/// Multiplies 2 16x4 & 4x8 f64 matrices and accumulates the result to a +/// 16x8 f64 matrix (m16n8k4.row.col.f64.f64.f64.f64) +/// Multiplies 2 16x8 & 8x8 f16 matrices and accumulates the result to a +/// 16x8 f32 matrix (m16n8k8.row.col.f32.f16.f16.f32) +/// Multiplies 2 16x8 & 8x8 f64 matrices and accumulates the result to a +/// 16x8 f64 matrix (m16n8k8.row.col.f64.f64.f64.f64) /// Multiplies 2 16x16 & 16x8 u8/s8 matrices and accumulates the result to a /// 16x8 s32 matrix (m16n8k16.row.col.s32.u8.u8.s32 / /// m16n8k16.row.col.s32.s8.s8.s32) +/// Multiplies 2 16x32 & 32x8 u4/s4 matrices and accumulates the result to a +/// 16x8 s32 matrix (m16n8k32.row.col.s32.u4.u4.s32 / +/// m16n8k32.row.col.s32.s4.s4.s32) /// Requires the sub-group size of kernel /// calling this function to be 32 /// \tparam [in] M The rows of A/C/D matrix @@ -3120,16 +2887,48 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, /// \param [in] c3 The 4th element from C matrix to be added with d3 template -std::enable_if_t, void> -mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); int lane = sg.get_local_linear_id(); short ROW_LOAD_OFFSET = 4 * (lane >> 2); short COL_LOAD_OFFSET = 8 * (lane % 4); - if (M == 16 && N == 8 && K == 16) { + if (M == 16 && N == 8 && K == 4) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += recv_a[0] * recv_b[0]; + c1 += recv_a[0] * recv_b[1]; + c2 += recv_a[1] * recv_b[0]; + c3 += recv_a[1] * recv_b[1]; + } + } else if (M == 16 && N == 8 && K == 8) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 2; j++) { + c0 += static_cast(ra[j]) * static_cast(rb[j]); + c1 += static_cast(ra[j]) * static_cast(rb[j + 2]); + c2 += static_cast(ra[j + 2]) * static_cast(rb[j]); + c3 += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); + } + } + } else if (M == 16 && N == 8 && K == 16) { ABType recv_a[4 * 2], recv_b[4 * 2]; for (int i = 0; i < 4; i++) { @@ -3150,38 +2949,41 @@ mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, c3 += a[i + 16] * b[i + 16]; } } else if (M == 16 && N == 8 && K == 32) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - MulType *a = reinterpret_cast(recv_a); - MulType *b = reinterpret_cast(recv_b); - - for (int k = 0; k < 4; k++) { - MulType a00 = a[k] >> 4; - MulType a01 = a[k] & 0x0F; - MulType a10 = a[k + 4] >> 4; - MulType a11 = a[k + 4] & 0x0F; - MulType b00 = b[k] >> 4; - MulType b01 = b[k] & 0x0F; - MulType b10 = b[k + 4] >> 4; - MulType b11 = b[k + 4] & 0x0F; - - c0 += a00 * b00; - c0 += a01 * b01; - - c1 += a00 * b10; - c1 += a01 * b11; - - c2 += a10 * b00; - c2 += a11 * b01; - - c3 += a10 * b10; - c3 += a11 * b11; + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; + + c0 += a00 * b00; + c0 += a01 * b01; + + c1 += a00 * b10; + c1 += a01 * b11; + + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } } } } From 7ad002a402be896ed51c393b0e75164754df3324 Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Mon, 5 May 2025 22:33:47 +0800 Subject: [PATCH 6/6] Added Op template arg to pass bit operator --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 9 +- clang/runtime/dpct-rt/include/dpct/math.hpp | 437 ++++---------------- 2 files changed, 95 insertions(+), 351 deletions(-) diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index f0d7dba3f198..a53c5d5d0d94 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -1642,12 +1642,13 @@ class SYCLGen : public SYCLGenBase { MulType = ABType; OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma"; - if (!MatrixOp.empty()) { - OS() << "_" << MatrixOp; - } OS() << "<"; OS() << M << ", " << N << ", " << K << ", "; - OS() << MulType << ">("; + OS() << MulType; + if (!MatrixOp.empty()) { + OS() << ", sycl::bit_" << MatrixOp << "<>"; + } + OS() << ">("; // Add D matrix address values to store the MAD result for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) { diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 8e97ca6bd1ca..a95f8b934e70 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2230,12 +2230,16 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5, } /// Multiplies 2 8x4 & 4x8 f64 matrices and accumulates the result to a 8x8 b64 -/// matrix (m8n8k4.row.col.f64.f64.f64.f64) +/// matrix (m8n8k4.row.col.f64.f64.f64.f64). /// Multiplies 2 8x16 & 16x8 u8/s8 matrices and accumulates the result to a 8x8 -/// s32 matrix (m8n8k16.row.col.s32.u8.u8.s32 / m8n8k16.row.col.s32.s8.s8.s32) +/// s32 matrix (m8n8k16.row.col.s32.u8.u8.s32 / m8n8k16.row.col.s32.s8.s8.s32). /// Multiplies 2 8x32 & 32x8 u4/s4 matrices and accumulates the result to a 8x8 -/// s32 matrix (m8n8k32.row.col.s32.u4.u4.s32 / m8n8k32.row.col.s32.s4.s4.s32) -/// Requires the sub-group size of kernel calling this function to be 32 +/// s32 matrix (m8n8k32.row.col.s32.u4.u4.s32 / m8n8k32.row.col.s32.s4.s4.s32). +/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 +/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc). +/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 +/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc). +/// Requires the sub-group size of kernel calling this function to be 32. /// In: 2, 1, 1, 2 /// \tparam [in] M The rows of A/C/D matrix /// \tparam [in] N The columns of B/C/D matrix @@ -2249,9 +2253,10 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5, /// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix /// \param [in] c0 The 1st element from C matrix to be added with d0 /// \param [in] c1 The 2nd element from C matrix to be added with d1 -template -void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1) { +template , + typename ABType, typename CDType> +void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1, + Op op = Op{}) { auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); int lane = sg.get_local_linear_id(); @@ -2319,91 +2324,20 @@ void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1) { } } } - } - - *d0 = c0; - *d1 = c1; -} - -/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 -/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc) -/// Requires the sub-group size of kernel calling this function to be 32 -/// In: 2, 1, 1, 2 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -template -void mma_and(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, - CDType c1) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 8 && N == 8 && K == 128) { - for (int i = 0; i < 4; i++) { - ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - - c0 += sycl::popcount(recv_a & recv_b); - - recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - c1 += sycl::popcount(recv_a & recv_b); - } - } - - *d0 = c0; - *d1 = c1; -} - -/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 -/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc) -/// Requires the sub-group size of kernel calling this function to be 32 -/// In: 2, 1, 1, 2 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -template -void mma_xor(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, - CDType c1) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 8 && N == 8 && K == 128) { - for (int i = 0; i < 4; i++) { - ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + } else if (M == 8 && N == 8 && K == 128) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a = + dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - c0 += sycl::popcount(recv_a ^ recv_b); + c0 += sycl::popcount(op(recv_a, recv_b)); - recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - c1 += sycl::popcount(recv_a ^ recv_b); + c1 += sycl::popcount(op(recv_a, recv_b)); + } } } @@ -2554,6 +2488,10 @@ void mma(volatile CDType *d0, volatile CDType *d1, ABType a0, ABType a1, /// Multiplies 2 16x64 & 64x8 u4/s4 matrices and /// accumulates the result to a 16x8 b32 matrix (m16n8k64.row.col.s32.u4.u4.s32 /// / m16n8k64.row.col.s32.s4.s4.s32). +/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc). +/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc). /// Requires the sub-group size of kernel calling this function to be 32. /// \tparam [in] M The rows of A/C/D matrix /// \tparam [in] N The columns of B/C/D matrix @@ -2576,11 +2514,11 @@ void mma(volatile CDType *d0, volatile CDType *d1, ABType a0, ABType a1, /// \param [in] c1 The 2nd element from C matrix to be added with d1 /// \param [in] c2 The 3rd element from C matrix to be added with d2 /// \param [in] c3 The 4th element from C matrix to be added with d3 -template +template , + typename ABType, typename CDType> void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, - CDType c2, CDType c3) { + CDType c2, CDType c3, Op op = Op{}) { auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); int lane = sg.get_local_linear_id(); @@ -2759,6 +2697,38 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, } } } + } else if (M == 16 && N == 8 && K == 256) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(op(recv_a[0], recv_b[0])); + c1 += sycl::popcount(op(recv_a[0], recv_b[1])); + c2 += sycl::popcount(op(recv_a[1], recv_b[0])); + c3 += sycl::popcount(op(recv_a[1], recv_b[1])); + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(op(recv_a[0], recv_b[0])); + c1 += sycl::popcount(op(recv_a[0], recv_b[1])); + c2 += sycl::popcount(op(recv_a[1], recv_b[0])); + c3 += sycl::popcount(op(recv_a[1], recv_b[1])); + } + } } *d0 = c0; @@ -2852,20 +2822,24 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, /// Multiplies 2 16x4 & 4x8 f16 matrices and accumulates the result to a /// 16x8 f32 matrix (m16n8k4.row.col.f16.f16.f16.f16 / -/// m16n8k4.row.col.f32.f16.f16.f32) +/// m16n8k4.row.col.f32.f16.f16.f32). /// Multiplies 2 16x4 & 4x8 f64 matrices and accumulates the result to a -/// 16x8 f64 matrix (m16n8k4.row.col.f64.f64.f64.f64) +/// 16x8 f64 matrix (m16n8k4.row.col.f64.f64.f64.f64). /// Multiplies 2 16x8 & 8x8 f16 matrices and accumulates the result to a -/// 16x8 f32 matrix (m16n8k8.row.col.f32.f16.f16.f32) +/// 16x8 f32 matrix (m16n8k8.row.col.f32.f16.f16.f32). /// Multiplies 2 16x8 & 8x8 f64 matrices and accumulates the result to a -/// 16x8 f64 matrix (m16n8k8.row.col.f64.f64.f64.f64) +/// 16x8 f64 matrix (m16n8k8.row.col.f64.f64.f64.f64). /// Multiplies 2 16x16 & 16x8 u8/s8 matrices and accumulates the result to a /// 16x8 s32 matrix (m16n8k16.row.col.s32.u8.u8.s32 / -/// m16n8k16.row.col.s32.s8.s8.s32) +/// m16n8k16.row.col.s32.s8.s8.s32). /// Multiplies 2 16x32 & 32x8 u4/s4 matrices and accumulates the result to a /// 16x8 s32 matrix (m16n8k32.row.col.s32.u4.u4.s32 / -/// m16n8k32.row.col.s32.s4.s4.s32) -/// Requires the sub-group size of kernel +/// m16n8k32.row.col.s32.s4.s4.s32). +/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc). +/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc). +/// Requires the sub-group size of kernel. /// calling this function to be 32 /// \tparam [in] M The rows of A/C/D matrix /// \tparam [in] N The columns of B/C/D matrix @@ -2885,10 +2859,10 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, /// \param [in] c1 The 2nd element from C matrix to be added with d1 /// \param [in] c2 The 3rd element from C matrix to be added with d2 /// \param [in] c3 The 4th element from C matrix to be added with d3 -template +template , + typename ABType, typename CDType> void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { + ABType b0, CDType c0, CDType c1, CDType c2, CDType c3, Op op = Op{}) { auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); int lane = sg.get_local_linear_id(); @@ -2986,253 +2960,22 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, } } } - } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; -} - -/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 -/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc) -/// Requires the sub-group size of kernel calling this function to be 32 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 4, 2, 1, 4 -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template -void mma_and(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, - ABType a1, ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 128) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - c0 += sycl::popcount(recv_a[0] & recv_b[0]); - c1 += sycl::popcount(recv_a[0] & recv_b[1]); - c2 += sycl::popcount(recv_a[1] & recv_b[0]); - c3 += sycl::popcount(recv_a[1] & recv_b[1]); - } - } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; -} - -/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 -/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc) -/// Requires the sub-group size of kernel calling this function to be 32 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 4, 2, 1, 4 -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template -void mma_xor(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, - ABType a1, ABType b0, CDType c0, CDType c1, CDType c2, CDType c3) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 128) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - c0 += sycl::popcount(recv_a[0] ^ recv_b[0]); - c1 += sycl::popcount(recv_a[0] ^ recv_b[1]); - c2 += sycl::popcount(recv_a[1] ^ recv_b[0]); - c3 += sycl::popcount(recv_a[1] ^ recv_b[1]); - } - } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; -} - -/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 -/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc) -/// Requires the sub-group size of kernel calling this function to be 32 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 4, 4, 2, 4 -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix -/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template -void mma_and(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, - ABType a1, ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, - CDType c1, CDType c2, CDType c3) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 256) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - c0 += sycl::popcount(recv_a[0] & recv_b[0]); - c1 += sycl::popcount(recv_a[0] & recv_b[1]); - c2 += sycl::popcount(recv_a[1] & recv_b[0]); - c3 += sycl::popcount(recv_a[1] & recv_b[1]); - } - - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); - - c0 += sycl::popcount(recv_a[0] & recv_b[0]); - c1 += sycl::popcount(recv_a[0] & recv_b[1]); - c2 += sycl::popcount(recv_a[1] & recv_b[0]); - c3 += sycl::popcount(recv_a[1] & recv_b[1]); - } - } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; -} - -/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 -/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc) -/// Requires the sub-group size of kernel calling this function to be 32 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 4, 4, 2, 4 -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix -/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template -void mma_xor(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, - ABType a1, ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, - CDType c1, CDType c2, CDType c3) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 256) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - c0 += sycl::popcount(recv_a[0] ^ recv_b[0]); - c1 += sycl::popcount(recv_a[0] ^ recv_b[1]); - c2 += sycl::popcount(recv_a[1] ^ recv_b[0]); - c3 += sycl::popcount(recv_a[1] ^ recv_b[1]); - } - - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; + } else if (M == 16 && N == 8 && K == 128) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; - recv_a[0] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - c0 += sycl::popcount(recv_a[0] ^ recv_b[0]); - c1 += sycl::popcount(recv_a[0] ^ recv_b[1]); - c2 += sycl::popcount(recv_a[1] ^ recv_b[0]); - c3 += sycl::popcount(recv_a[1] ^ recv_b[1]); + c0 += sycl::popcount(op(recv_a[0], recv_b[0])); + c1 += sycl::popcount(op(recv_a[0], recv_b[1])); + c2 += sycl::popcount(op(recv_a[1], recv_b[0])); + c3 += sycl::popcount(op(recv_a[1], recv_b[1])); + } } }