Skip to content

Commit a383dbd

Browse files
committed
Add integer WMMA support
Add integer WMMA support Fix Int8 test and cleanup Further cleanup Merge Integer and Float generation and testing Extended valid WMMA matrix shapes Update documentation and remove bit/half-byte code
1 parent 9d69cab commit a383dbd

File tree

2 files changed

+260
-121
lines changed

2 files changed

+260
-121
lines changed

src/device/intrinsics/wmma.jl

Lines changed: 154 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,72 @@ using Core: LLVMPtr
1010

1111
# Maps PTX types to Julia array types
1212
const map_ptx_to_jl_array = Dict(
13+
"u8" => UInt8,
14+
"s8" => Int8,
15+
"s32" => Int32,
1316
"f16" => Float16,
1417
"f32" => Float32
1518
)
1619

1720
# Maps PTX types to Julia fragment types
1821
const map_ptx_to_jl_frag = Dict(
22+
"u8" => UInt32,
23+
"s8" => UInt32,
24+
"s32" => Int32,
1925
"f16" => NTuple{2, VecElement{Float16}},
2026
"f32" => Float32
2127
)
2228

2329
# Maps matrix & PTX types to fragment sizes
2430
const map_frag_sizes = Dict(
25-
"a.f16" => 8,
26-
"b.f16" => 8,
27-
"c.f16" => 4,
28-
"c.f32" => 8,
29-
"d.f16" => 4,
30-
"d.f32" => 8
31+
# A
32+
"a.u8.m16n16k16" => 2,
33+
"a.u8.m8n32k16" => 1,
34+
"a.u8.m32n8k16" => 4,
35+
36+
"a.s8.m16n16k16" => 2,
37+
"a.s8.m8n32k16" => 1,
38+
"a.s8.m32n8k16" => 4,
39+
40+
"a.f16.m16n16k16" => 8,
41+
"a.f16.m8n32k16" => 8,
42+
"a.f16.m32n8k16" => 8,
43+
# B
44+
"b.u8.m16n16k16" => 2,
45+
"b.u8.m8n32k16" => 4,
46+
"b.u8.m32n8k16" => 1,
47+
48+
"b.s8.m16n16k16" => 2,
49+
"b.s8.m8n32k16" => 4,
50+
"b.s8.m32n8k16" => 1,
51+
52+
"b.f16.m16n16k16" => 8,
53+
"b.f16.m8n32k16" => 8,
54+
"b.f16.m32n8k16" => 8,
55+
# C
56+
"c.s32.m16n16k16" => 8,
57+
"c.s32.m8n32k16" => 8,
58+
"c.s32.m32n8k16" => 8,
59+
60+
"c.f16.m16n16k16" => 4,
61+
"c.f16.m8n32k16" => 4,
62+
"c.f16.m32n8k16" => 4,
63+
64+
"c.f32.m16n16k16" => 8,
65+
"c.f32.m8n32k16" => 8,
66+
"c.f32.m32n8k16" => 8,
67+
# D
68+
"d.s32.m16n16k16" => 8,
69+
"d.s32.m8n32k16" => 8,
70+
"d.s32.m32n8k16" => 8,
71+
72+
"d.f16.m16n16k16" => 4,
73+
"d.f16.m8n32k16" => 4,
74+
"d.f16.m32n8k16" => 4,
75+
76+
"d.f32.m16n16k16" => 8,
77+
"d.f32.m8n32k16" => 8,
78+
"d.f32.m32n8k16" => 8,
3179
)
3280

3381
# Maps PTX AS to CUDA.AS
@@ -37,15 +85,41 @@ const map_ptx_as_to_as_ty = Dict(
3785
"global" => AS.Global
3886
)
3987

88+
# Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type
89+
90+
# Half-Precision Floating Point
91+
const ldst_half_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["f16"]
92+
const ldst_half_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["f16", "f32"]
93+
const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f32"], ["f16", "f32"]
94+
# Integer
95+
const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
96+
const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
97+
const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
98+
99+
const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
100+
ldst_int_ab_ops, ldst_int_cd_ops)
101+
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)
102+
103+
# Valid WMMA operation shapes
104+
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
105+
40106
################################################################################
41107
# HELPER FUNCTIONS
42108
################################################################################
43109

110+
# Returns shape information as a string
111+
function get_hl_shape(M, N, K)
112+
if (M, N, K) in valid_shapes
113+
return "m$(M)n$(N)k$(K)"
114+
end
115+
error("Invalid shape for WMMA: (M, N, K) = ($M, $N, $K)")
116+
end
117+
44118
# Returns (Julia array type, Julia fragment type, fragment size)
45-
get_frag_info(matrix, ptx_el_type) = (
119+
get_frag_info(matrix, ptx_el_type, shape) = (
46120
map_ptx_to_jl_array[ptx_el_type],
47121
map_ptx_to_jl_frag[ptx_el_type],
48-
map_frag_sizes["$matrix.$ptx_el_type"]
122+
map_frag_sizes["$matrix.$ptx_el_type.$shape"]
49123
)
50124

51125
get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space])
@@ -86,27 +160,26 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.load.{matrix}.sync.{layout}.{
86160
# Placeholders
87161
- `{matrix}`: The matrix to load. Can be `a`, `b` or `c`.
88162
- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively.
89-
- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`.
163+
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
90164
- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`.
91-
- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). Note that `f32` is only valid for the matrix ``C``.
165+
- `{elem_type}`: The type of each element in the matrix. For `a` and `b` matrices, valid values are `u8` (byte unsigned integer),
166+
`s8` (byte signed integer), and `f16` (half precision floating point). For `c` and `d` matrices, valid values are
167+
`s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
92168
"""
93169
llvm_wmma_load() = error("Cannot call llvm_wmma_load without values for placeholders!")
94170
export llvm_wmma_load
95171

96-
for mat in ["a", "b", "c"],
172+
for ops in all_ldst_ops,
173+
mnk in ops[1],
174+
mat in ops[2],
175+
elem_type in ops[3],
97176
layout in ["col", "row"],
98-
shape in ["m16n16k16"],
99177
addr_space in ["", "shared", "global"],
100-
stride in ["stride"],
101-
elem_type in ["f16", "f32"]
178+
stride in ["stride"]
102179

180+
shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
103181
# TODO: Non-stride versions?
104182

105-
# Float32 is only supported for C
106-
if (elem_type == "f32") && (mat != "c")
107-
continue
108-
end
109-
110183
addr_space_int = get_addrspace_info(addr_space)
111184

112185
# Name of the Julia wrapper function
@@ -116,7 +189,7 @@ for mat in ["a", "b", "c"],
116189
llvm_intr = "llvm.nvvm.wmma.$shape.load.$mat.$layout.stride.$elem_type.p$(addr_space_int)i8"
117190

118191
# Determine types + size for this (matrix, elem_type) combination
119-
arr_ty, frag_ty, sz = get_frag_info(mat, elem_type)
192+
arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape)
120193

121194
ccall_name = "extern $llvm_intr"
122195

@@ -144,19 +217,28 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.store.d.sync.{layout}.{shape}
144217
145218
# Placeholders
146219
- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively.
147-
- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`.
220+
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
148221
- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`.
149-
- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point).
222+
- `{elem_type}`: The type of each element in the matrix. For `a` and `b` matrices, valid values are `u8` (byte unsigned integer),
223+
`s8` (byte signed integer), and `f16` (half precision floating point). For `c` and `d` matrices, valid values are
224+
`s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
150225
"""
151226
llvm_wmma_store() = error("Cannot call llvm_wmma_store without values for placeholders!")
152227
export llvm_wmma_store
153228

154-
for mat in ["d"],
155-
layout in ["col", "row"],
156-
shape in ["m16n16k16"],
157-
addr_space in ["", "shared", "global"],
158-
stride in ["stride"],
159-
elem_type in ["f16", "f32"]
229+
for ops in all_ldst_ops,
230+
mnk in ops[1],
231+
mat in ops[2],
232+
elem_type in ops[3],
233+
layout in ["col", "row"],
234+
addr_space in ["", "shared", "global"],
235+
stride in ["stride"]
236+
237+
if mat != "d"
238+
continue
239+
end
240+
241+
shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
160242

161243
# TODO: Non-stride versions?
162244

@@ -169,7 +251,7 @@ for mat in ["d"],
169251
llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)i8"
170252

171253
# Determine types + size for this (matrix, elem_type) combination
172-
arr_ty, frag_ty, sz = get_frag_info(mat, elem_type)
254+
arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape)
173255

174256
ccall_name = "extern $llvm_intr"
175257
frag_types = ntuple(i -> frag_ty, sz)
@@ -187,9 +269,11 @@ end
187269
# --------------------------
188270

189271
@doc """
190-
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c)
272+
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c) or
273+
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{a_elem_type}(a, b, c)
191274
192-
Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`.
275+
For floating point operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`
276+
For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{a_elem_type}`
193277
194278
# Arguments
195279
- `a`: The WMMA fragment corresponding to the matrix ``A``.
@@ -199,9 +283,10 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout
199283
# Placeholders
200284
- `{a_layout}`: The storage layout for matrix ``A``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation.
201285
- `{b_layout}`: The storage layout for matrix ``B``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation.
202-
- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`.
203-
- `{d_elem_type}`: The type of each element in the resultant ``D`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point).
204-
- `{c_elem_type}`: The type of each element in the ``C`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point).
286+
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
287+
- `{a_elem_type}`: The type of each element in the ``A`` matrix. Valid values are `u8` (byte unsigned integer), `s8` (byte signed integer), and `f16` (half precision floating point).
288+
- `{d_elem_type}`: The type of each element in the resultant ``D`` matrix. Valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
289+
- `{c_elem_type}`: The type of each element in the ``C`` matrix. Valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
205290
206291
!!! warning
207292
@@ -211,25 +296,34 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout
211296
llvm_wmma_mma() = error("Cannot call llvm_wmma_mma without values for placeholders!")
212297
export llvm_wmma_mma
213298

214-
for a_layout in ["col", "row"],
299+
for ops in all_wmma_ops,
300+
a_layout in ["col", "row"],
215301
b_layout in ["col", "row"],
216-
shape in ["m16n16k16"],
217-
d_elem_type in ["f16", "f32"],
218-
c_elem_type in ["f16", "f32"],
219-
b_elem_type in ["f16"],
220-
a_elem_type in ["f16"]
302+
mnk in ops[1],
303+
d_elem_type in ops[4],
304+
c_elem_type in ops[3],
305+
b_elem_type in ops[2]
221306

222-
# Name of the Julia wrapper function
223-
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type]), "_"))
307+
a_elem_type = b_elem_type
308+
shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
224309

225310
# Name of the LLVM intrinsic
226-
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type"
311+
# If integer/sub-byte/bit A/B types, name is determined by A/B types
312+
if d_elem_type == "s32"
313+
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type"
314+
# Name of the Julia wrapper function
315+
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_"))
316+
else # Name defined by D/C types
317+
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type"
318+
# Name of the Julia wrapper function
319+
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type]), "_"))
320+
end
227321

228322
# Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D
229-
a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type)
230-
b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type)
231-
c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type)
232-
d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type)
323+
a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type, shape)
324+
b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type, shape)
325+
c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type, shape)
326+
d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type, shape)
233327

234328
ccall_name = "extern $llvm_intr"
235329

@@ -439,17 +533,9 @@ function get_hl_layout(L)
439533
end
440534
end
441535

442-
function get_hl_shape(M, N, K)
443-
if (M, N, K) != (16, 16, 16)
444-
error("Invalid shape for WMMA: (M, N, K) = ($M, $N, $K)")
445-
end
446-
447-
return "m$(M)n$(N)k$(K)"
448-
end
449-
450536
get_hl_mat_use(mat) = map_matrix_to_use[mat]
451537

452-
function get_hl_frag_info(matrix, T)
538+
function get_hl_frag_info(matrix, T, shape)
453539
ptx_ty = nothing
454540

455541
try
@@ -460,7 +546,7 @@ function get_hl_frag_info(matrix, T)
460546

461547
try
462548
return (map_num_elems[(matrix, T)],
463-
map_frag_sizes["$matrix.$ptx_ty"],
549+
map_frag_sizes["$matrix.$ptx_ty.$shape"],
464550
map_ptx_to_jl_frag[ptx_ty],
465551
ptx_ty)
466552
catch
@@ -507,7 +593,7 @@ for mat in ["a", "b", "c"]
507593
as_str = get_hl_as_info(AS)
508594
layout = get_hl_layout(L)
509595
shape = get_hl_shape(M, N, K)
510-
num_els, _, _, arr_str = get_hl_frag_info($mat, T)
596+
num_els, _, _, arr_str = get_hl_frag_info($mat, T, shape)
511597
U = get_hl_mat_use($mat)
512598
L_ret = ($mat == "c") ? Unspecified : L
513599

@@ -552,15 +638,17 @@ mma
552638
c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator},
553639
config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T}
554640

555-
_, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T)
556-
_, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T)
557-
_, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T)
558-
d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T)
559-
560641
a_layout = get_hl_layout(A_L)
561642
b_layout = get_hl_layout(B_L)
562643
shape = get_hl_shape(M, N, K)
563644

645+
_, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T, shape)
646+
_, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T, shape)
647+
_, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape)
648+
d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape)
649+
650+
651+
564652
# Name of the Julia wrapper
565653
wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str]), "_"))
566654

@@ -611,7 +699,7 @@ store_d
611699
as_str = get_hl_as_info(AS)
612700
layout = get_hl_layout(L)
613701
shape = get_hl_shape(M, N, K)
614-
num_els, frag_sz, frag_ty, arr_str = get_hl_frag_info("d", T)
702+
num_els, frag_sz, frag_ty, arr_str = get_hl_frag_info("d", T, shape)
615703

616704
# Name of the Julia wrapper
617705
wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", "d", layout, shape, as_str, "stride", arr_str]), "_"))
@@ -648,7 +736,8 @@ fill_c
648736

649737
# We can't use closures in @generated functions, so we'll have to do it this way instead of
650738
# ntuple(i -> val, $num_els)
651-
num_els, _, _ = get_hl_frag_info("c", T)
739+
shape = get_hl_shape(M, N, K)
740+
num_els, _, _ = get_hl_frag_info("c", T, shape)
652741

653742
args = [:value for i=1:num_els]
654743
expr = :(tuple($(args...)))

0 commit comments

Comments
 (0)