on blackwell with/without warpspec.
"""

+import functools
+import logging
from typing import Optional

import torch
# TODO: Add proton support


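+# Map a torch dtype to the matching Triton dtype so the kernels can take the
+# element type as a tl.constexpr DTYPE instead of the old FP8_OUTPUT flag.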
+def torch_dtype_to_triton_dtype(dtype):
+    if dtype == torch.float16:
+        return tl.float16
+    elif dtype == torch.float32:
+        return tl.float32
+    elif dtype == torch.float8_e4m3fn:
+        return tl.float8e4nv
+    elif dtype == torch.bfloat16:
+        return tl.bfloat16
+    else:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
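+# The TMA path needs 16-byte-aligned strides on every dimension except the
+# last, and a contiguous innermost dimension; validate that up front.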
+def check_tma_alignment(strides, elem_bytes):
+    for stride in strides[:-1]:
+        if (stride * elem_bytes) % 16 != 0:
+            raise RuntimeError("strides must be 16-byte aligned")
+    if strides[-1] != 1:
+        raise RuntimeError("Last dimension must be contiguous")
+
+
def _matmul_launch_metadata(grid, kernel, args):
    ret = {}
    M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False)
@@ -21,7 +44,8 @@ def _matmul_launch_metadata(grid, kernel, args):
    if "c_ptr" in args:
        bytes_per_elem = args["c_ptr"].element_size()
    else:
-        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
+        # ceil division to capture the correct number of bytes
+        bytes_per_elem = (args["DTYPE"].int_bitwidth + 7) // 8
    ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K
    ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
    return ret
@@ -77,10 +101,10 @@ def matmul_kernel_tma(
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
-    FP8_OUTPUT: tl.constexpr,  #
    WARP_SPECIALIZE: tl.constexpr,  #
+    DTYPE: tl.constexpr,
):
-    dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
+    dtype = DTYPE

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
112
136
c_desc .store ([offs_cm , offs_cn ], c )
113
137
114
138
139
+ @functools .lru_cache
140
+ def warn_once (msg : str ):
141
+ """
142
+ Wrapper around logging.warning to try minimize the number of warnings when
143
+ a function is repeatedly called.
144
+ """
145
+ logging .warning (
146
+ "Incompatible dimensions, B is transposed. We are transposing B which may impact results"
147
+ )
148
+
149
+
def blackwell_matmul_tma(a, b, warp_specialize: bool):
    # Check constraints.
+    if a.shape[1] != b.shape[1]:
+        warn_once(
+            "Incompatible dimensions, B is transposed. We are transposing B, which may impact results"
+        )
+        b = b.T.contiguous()
    assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
    assert a.dtype == b.dtype, "Incompatible dtypes"

@@ -141,8 +181,8 @@ def grid(META):
        M,
        N,
        K,  #
-        FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
        WARP_SPECIALIZE=warp_specialize,  #
+        DTYPE=torch_dtype_to_triton_dtype(dtype),  #
    )
    return c

@@ -196,12 +236,12 @@ def matmul_kernel_tma_persistent(
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
-    FP8_OUTPUT: tl.constexpr,  #
    EPILOGUE_SUBTILE: tl.constexpr,  #
    NUM_SMS: tl.constexpr,  #
    WARP_SPECIALIZE: tl.constexpr,  #
+    DTYPE: tl.constexpr,
):
-    dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
+    dtype = DTYPE
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
@@ -256,9 +296,17 @@ def matmul_kernel_tma_persistent(

def blackwell_matmul_tma_persistent(a, b, warp_specialize: bool):
    # Check constraints.
+    if a.shape[1] != b.shape[1]:
+        warn_once(
+            "Incompatible dimensions, B is transposed. We are transposing B, which may impact results"
+        )
+        b = b.T.contiguous()
    assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
    assert a.dtype == b.dtype, "Incompatible dtypes"

+    check_tma_alignment(a.stride(), (torch.finfo(a.dtype).bits + 7) // 8)
+    check_tma_alignment(b.stride(), (torch.finfo(b.dtype).bits + 7) // 8)
+
    M, K = a.shape
    N, K = b.shape
    dtype = a.dtype
@@ -291,9 +339,9 @@ def grid(META):
        M,
        N,
        K,  #
-        FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
        NUM_SMS=NUM_SMS,  #
        WARP_SPECIALIZE=warp_specialize,  #
+        DTYPE=torch_dtype_to_triton_dtype(dtype),  #
    )
    return c

@@ -403,6 +451,11 @@ def matmul_kernel_descriptor_persistent(

def blackwell_matmul_descriptor_persistent(a, b, warp_specialize: bool):
    # Check constraints.
+    if a.shape[1] != b.shape[1]:
+        warn_once(
+            "Incompatible dimensions, B is transposed. We are transposing B, which may impact results"
+        )
+        b = b.T.contiguous()
    assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
    assert a.dtype == b.dtype, "Incompatible dtypes"
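
A minimal usage sketch of the new transpose-fallback path (not part of this diff; it assumes the module is importable as blackwell_matmul, which is a hypothetical name, and that a CUDA device supporting these kernels is available):

    import torch
    from blackwell_matmul import blackwell_matmul_tma  # hypothetical module name

    # a is (M, K); the kernels expect b already transposed, i.e. (N, K).
    a = torch.randn(512, 256, device="cuda", dtype=torch.float16)
    b = torch.randn(128, 256, device="cuda", dtype=torch.float16)
    c = blackwell_matmul_tma(a, b, warp_specialize=False)  # shapes match, no warning

    # Passing b in (K, N) layout instead trips the new check: warn_once() fires
    # and b is re-laid-out via b.T.contiguous() before the kernel launch.
    b_kn = torch.randn(256, 128, device="cuda", dtype=torch.float16)
    c2 = blackwell_matmul_tma(a, b_kn, warp_specialize=False)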