17 changes: 13 additions & 4 deletions clang/lib/DPCT/RulesAsm/AsmMigration.cpp
@@ -1488,6 +1488,8 @@ class SYCLGen : public SYCLGenBase {
// Data types of A, B & C matrices respectively in the PTX arguments
std::string InMatrixType[3];

InMatrixType[2] = CDType;

if (Inst->hasAttr(InstAttr::m8n8k4)) {
M = "8";
N = "8";
@@ -1560,13 +1562,22 @@ class SYCLGen : public SYCLGenBase {
InMatrixType[0] = "uint32_t"; // A type is .f16/.bf16x2
InMatrixType[1] = "uint32_t"; // B type is .f16/.bf16x2

// If A matrix type is f16, then C&D matrix types can only be f32
// If A matrix type is f16, then C&D matrix types can be f32
if (CType->getKind() == InlineAsmBuiltinType::f32) {
NumVecElements[0] = 4; // A
NumVecElements[1] = 2; // B
NumVecElements[2] = 4; // C
NumVecElements[3] = 4; // D
} else
}
// C&D matrix types can be f16.
else if (CType->getKind() == InlineAsmBuiltinType::f16) {
NumVecElements[0] = 4; // A
NumVecElements[1] = 2; // B
NumVecElements[2] = 2; // C
NumVecElements[3] = 2; // D
InMatrixType[2] = "uint32_t"; // C type is f16*2
}
else
return SYCLGenError();
} else if (AType->getKind() == InlineAsmBuiltinType::s8) {
InMatrixType[0] = "uint32_t"; // A type is .s8x4
@@ -1605,8 +1616,6 @@ class SYCLGen : public SYCLGenBase {
} else
return SYCLGenError();

InMatrixType[2] = CDType;

// Check the register sizes for vector elements of A, B, C & D matrices
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
InputOp++) {
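For context, the branch added above encodes the per-thread register layout that PTX defines for mma.m16n8k16 with f16 inputs: an f32 accumulator uses four 32-bit C/D registers per thread, while an f16 accumulator packs two halves per register and so needs only two. Below is a minimal standalone sketch (not part of the patch; all names are hypothetical) that mirrors this register-count logic.

#include <cassert>
#include <cstdio>

struct RegCounts { int a, b, c, d; };

// For f16 A/B inputs on m16n8k16: A spans 4 .f16x2 registers, B spans 2;
// C/D span 4 f32 registers when accumulating in f32, or 2 .f16x2
// registers when accumulating in f16.
RegCounts m16n8k16_counts(bool f32_accum) {
  return f32_accum ? RegCounts{4, 2, 4, 4} : RegCounts{4, 2, 2, 2};
}

int main() {
  RegCounts h = m16n8k16_counts(/*f32_accum=*/false);
  assert(h.c == 2 && h.d == 2); // f16 accumulate: two halves per register
  std::printf("f16 accum: A=%d B=%d C=%d D=%d\n", h.a, h.b, h.c, h.d);
}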
62 changes: 61 additions & 1 deletion clang/runtime/dpct-rt/include/dpct/math.hpp
@@ -2394,7 +2394,7 @@ template <typename T> struct MMAType {
/// - m8n8k4 (f32.f16.f16.f32)
/// - m8n8k16 (s32.s8.s8.s32)
/// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
/// - m16n8k16 (f32.f16.f16.f32 & f32.bf16.bf16.f32 & s32.s8.s8.s32)
/// - m16n8k16 (f32.f16.f16.f32 & f16.f16.f16.f16 & f32.bf16.bf16.f32 & s32.s8.s8.s32)
/// - m16n8k32 (s32.s8.s8.s32)
/// Here, m, n & k define the shapes of A, B & C matrices respectively
/// (A = [m x k], B = [k x n], C = [m x n]).
@@ -2671,6 +2671,66 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag,
static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j + 4]);
}
}
} else if constexpr (std::is_same_v<CDType, sycl::half>) {
// Init D matrix fragment with C matrix fragment
sycl::half *d0 = const_cast<sycl::half *>(d[0]);
sycl::half *d1 = d0 + 1;
sycl::half *d2 = const_cast<sycl::half *>(d[1]);
sycl::half *d3 = d2 + 1;
*d0 = c[0];
*d1 = c[1];
*d2 = c[2];
*d3 = c[3];

// Each sub-group is responsible for computing a 16*8-element fragment
// of matrix D.
// Each work item computes 4 elements of matrix D by gathering
// its corresponding row & col matrix fragments of length k (8)
// from the A & B matrices respectively, using the mapping logic below:
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
// col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
// As each row & col fragment of the A & B matrices is distributed across
// 4 work items, each iteration of the loop below loads a partial fragment
// of matrix A (row) and matrix B (col) using the row & col offsets.
for (int i = 0; i < 4; i++) {
typename MMAType<ABType>::PackType recv_a[4], recv_b[4];

// Load partial fragment from row0 of matrix A ({a0, a1})
recv_a[0] = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
// Load partial fragment from row0 of matrix A ({a2, a3})
recv_a[1] = dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
// Load partial fragment from row1 of matrix A ({a0, a1})
recv_a[2] = dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
// Load partial fragment from row1 of matrix A ({a2, a3})
recv_a[3] = dpct::select_from_sub_group(sg, a[3], row_load_offset + i);

// Load partial fragment from col0 of matrix B ({b0, b1})
recv_b[0] = dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
// Load partial fragment from col0 of matrix B ({b2, b3})
recv_b[1] = dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
// Load partial fragment from col1 of matrix B ({b0, b1})
recv_b[2] =
dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
// Load partial fragment from col1 of matrix B ({b2, b3})
recv_b[3] =
dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);

auto ra = reinterpret_cast<ABType *>(recv_a);
auto rb = reinterpret_cast<ABType *>(recv_b);

// Each work item calculates a partial product of A & B matrix fragments
// and adds it to the corresponding D matrix fragment
// d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
// d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
// d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
// d3 += row1{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
for (int j = 0; j < 4; j++) {
*d0 += ra[j] * rb[j];
*d1 += ra[j] * rb[j + 4];
*d2 += ra[j + 4] * rb[j];
*d3 += ra[j + 4] * rb[j + 4];
}
}
} else if constexpr (std::is_integral_v<ABType>) {
// Init D matrix with fragments of C matrix
*d[0] = c[0];
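The lane-to-element mapping in the comments above is easy to sanity-check in isolation: across a 32-lane sub-group, the four indices per lane tile the full 16x8 D fragment exactly once. A minimal standalone sketch of that check (not part of the patch; names are hypothetical):

#include <cassert>

struct Mapping { int row0, row1, col0, col1; };

// Mirrors the mapping from the comments:
// row0 = lane >> 2, row1 = row0 + 8,
// col0 = (lane % 4) * 2, col1 = col0 + 1.
Mapping d_fragment_mapping(int lane) {
  Mapping m;
  m.row0 = lane >> 2;          // rows 0..7, first half of the tile
  m.row1 = (lane >> 2) + 8;    // rows 8..15, second half
  m.col0 = (lane % 4) * 2;     // even column of the pair
  m.col1 = (lane % 4) * 2 + 1; // odd column of the pair
  return m;
}

int main() {
  // 32 lanes x 4 elements (2 rows x 2 cols) cover all 16 x 8 = 128
  // elements of D with no gaps.
  bool covered[16][8] = {};
  for (int lane = 0; lane < 32; ++lane) {
    Mapping m = d_fragment_mapping(lane);
    covered[m.row0][m.col0] = covered[m.row0][m.col1] = true;
    covered[m.row1][m.col0] = covered[m.row1][m.col1] = true;
  }
  for (int r = 0; r < 16; ++r)
    for (int c = 0; c < 8; ++c)
      assert(covered[r][c]);
}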
22 changes: 21 additions & 1 deletion clang/test/dpct/asm/mma.cu
@@ -19,7 +19,8 @@ m8n8k16 .s8 .s8 .s32
m16n8k8 .f16/.bf16 .f16/.bf16 .f32
m16n8k16 .f16 .f16 .f32
.bf16 .bf16 .f32
.s8 .s8 .s32
.s8 .s8 .s32
.f16 .f16 .f16
m16n8k32 .s8 .s8 .s32

Except for m8n8k4, all other shapes are supported for row/col layout of A/B matrices respectively.
@@ -100,6 +101,25 @@ __global__ void mma_kernel_m16n8k8(int *a, int *b, float *fc, float *fd) {
"f"(fc[0]), "f"(fc[1]), "f"(fc[2]), "f"(fc[3]));
}

__global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, int *d) {
// CHECK: {
// CHECK-NEXT: volatile void *d_mat_frag_ct1[2] = { &d[0], &d[1] };
// CHECK-NEXT: sycl::vec<uint32_t, 4> a_mat_frag_ct1(a[0], a[1], a[2], a[3]);
// CHECK-NEXT: sycl::vec<uint32_t, 2> b_mat_frag_ct1(b[0], b[1]);
// CHECK-NEXT: sycl::vec<uint32_t, 2> c_mat_frag_ct1(c[0], c[1]);
// CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, sycl::half>(reinterpret_cast<volatile void **>(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1);
// CHECK-NEXT: }
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
" { %0, %1 }, "
" { %2, %3, %4, %5 }, "
" { %6, %7 }, "
" { %8, %9 };"
: "+r"(d[0]), "+r"(d[1])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
"r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]));
}

__global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) {
// CHECK: {
// CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] };
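A note on the test above: the "r"-constrained int operands each carry a packed .f16x2 value, which is why the migrated code builds sycl::vec<uint32_t, 2> fragments for C and D from two ints. A small host-side sketch of that packing (illustrative only, not part of the patch; assumes IEEE binary16 encoding):

#include <cstdint>
#include <cstdio>

// Split one 32-bit .f16x2 register into its two half-precision
// bit patterns (low half first, matching little-endian packing).
void unpack_f16x2(uint32_t reg, uint16_t &lo, uint16_t &hi) {
  lo = static_cast<uint16_t>(reg & 0xFFFFu);
  hi = static_cast<uint16_t>(reg >> 16);
}

int main() {
  uint32_t d0 = 0x3C004000u; // hi = 0x3C00 (1.0h), lo = 0x4000 (2.0h)
  uint16_t lo, hi;
  unpack_f16x2(d0, lo, hi);
  std::printf("lo=0x%04X hi=0x%04X\n", lo, hi); // the two packed halves
}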