4
4
# This source code is licensed under the BSD 3-Clause license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
- from typing import Callable , Tuple
7
+ from typing import Tuple
8
8
9
9
import fire
10
10
import torch
11
11
import triton
12
- from torch . _inductor . utils import do_bench_using_profiling
12
+ from triton . testing import do_bench
13
13
14
14
from torchao .prototype .mx_formats .kernels import (
15
15
triton_to_mxfp8_dim1 ,
@@ -64,29 +64,35 @@ def to_mx_dim1_reference(x_hp, block_size):
64
64
return data_d1 .t (), scale_d1
65
65
66
66
67
def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
    """Benchmark a CUDA function and return its median runtime in microseconds.

    Thin wrapper around triton's ``do_bench``: runs ``f(*args, **kwargs)``
    repeatedly on the GPU and takes the median wall time.

    Args:
        f: Callable to benchmark (typically a compiled or kernel-launching fn).
        *args: Positional arguments forwarded to ``f`` on every invocation.
        **kwargs: Keyword arguments forwarded to ``f`` on every invocation.
            (Restored for parity with the previous profiling-based wrapper,
            which accepted keyword arguments.)

    Returns:
        float: median runtime in microseconds. ``do_bench`` reports
        milliseconds, hence the ``* 1e3`` conversion.
    """
    return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3
72
69
73
70
74
71
def run (
75
72
M : int = 16384 ,
76
73
K : int = 16384 ,
77
74
BLOCK_SIZE : int = 32 ,
78
- mode : str = "dim0 " ,
75
+ mode : str = "dim0_floor " ,
79
76
):
80
77
print (f"M { M } K { K } BLOCK_SIZE { BLOCK_SIZE } " )
81
78
print (f"GPU: { torch .cuda .get_device_name (0 )} " )
82
79
print (f"torch version: { torch .__version__ } " )
83
80
print (f"triton version: { triton .__version__ } " )
84
81
print (f"mode: { mode } " )
85
- assert mode in ("dim0" , "dim1" , "dim0_dim1" , "dim0_mx" , "dim1_mx" , "dim1_mx_triton" )
82
+ assert mode in (
83
+ "dim0_floor" ,
84
+ "dim1_floor" ,
85
+ "dim0_dim1_floor" ,
86
+ "dim0_mx_floor" ,
87
+ "dim1_mx_floor" ,
88
+ "dim1_mx_triton_floor" ,
89
+ "dim1_mx_cuda_floor" ,
90
+ "dim1_mx_cuda_rceil" ,
91
+ )
86
92
87
93
x = torch .randn (M , K , dtype = torch .bfloat16 , device = "cuda" ) * 1000
88
94
89
- if mode == "dim0 " :
95
+ if mode == "dim0_floor " :
90
96
scale_dim0_reference_c = torch .compile (scale_dim0_reference )
91
97
y_d0 , s_d0 = scale_dim0_reference_c (x , BLOCK_SIZE )
92
98
@@ -103,7 +109,7 @@ def run(
103
109
bytes_rw = sum (t .numel () for t in [x , y_d0 , s_d0 ]) * bytes_per_el_bf16
104
110
bps = bytes_rw / (time_us / 1e6 )
105
111
106
- elif mode == "dim1 " :
112
+ elif mode == "dim1_floor " :
107
113
scale_dim1_reference_c = torch .compile (scale_dim1_reference )
108
114
y_d1 , s_d1 = scale_dim1_reference_c (x , BLOCK_SIZE )
109
115
@@ -120,7 +126,7 @@ def run(
120
126
bytes_rw = sum (t .numel () for t in [x , y_d1 , s_d1 ]) * bytes_per_el_bf16
121
127
bps = bytes_rw / (time_us / 1e6 )
122
128
123
- elif mode == "dim0_dim1 " :
129
+ elif mode == "dim0_dim1_floor " :
124
130
scale_dim0_dim1_reference_c = torch .compile (scale_dim0_dim1_reference )
125
131
y_d0 , y_d1 , s_d0 , s_d1 = scale_dim0_dim1_reference_c (x , BLOCK_SIZE )
126
132
@@ -141,7 +147,7 @@ def run(
141
147
)
142
148
bps = bytes_rw / (time_us / 1e6 )
143
149
144
- elif mode == "dim0_mx " :
150
+ elif mode == "dim0_mx_floor " :
145
151
to_mx_dim0_reference_c = torch .compile (to_mx_dim0_reference )
146
152
y_d0 , s_d0 = to_mx_dim0_reference_c (x , BLOCK_SIZE )
147
153
@@ -159,7 +165,7 @@ def run(
159
165
bytes_w = (y_d0 .numel () + s_d0 .numel ()) * bytes_per_el_fp8
160
166
bps = (bytes_r + bytes_w ) / (time_us / 1e6 )
161
167
162
- elif mode == "dim1_mx " :
168
+ elif mode == "dim1_mx_floor " :
163
169
to_mx_dim1_reference_c = torch .compile (to_mx_dim1_reference )
164
170
y_d1 , s_d1 = to_mx_dim1_reference_c (x , BLOCK_SIZE )
165
171
@@ -177,7 +183,7 @@ def run(
177
183
bytes_w = (y_d1 .numel () + s_d1 .numel ()) * bytes_per_el_fp8
178
184
bps = (bytes_r + bytes_w ) / (time_us / 1e6 )
179
185
180
- elif mode == "dim1_mx_triton " :
186
+ elif mode == "dim1_mx_triton_floor " :
181
187
y_d1 , s_d1 = triton_to_mxfp8_dim1 (x , inner_block_size = BLOCK_SIZE )
182
188
183
189
for _ in range (2 ):
@@ -194,6 +200,58 @@ def run(
194
200
bytes_w = (y_d1 .numel () + s_d1 .numel ()) * bytes_per_el_fp8
195
201
bps = (bytes_r + bytes_w ) / (time_us / 1e6 )
196
202
203
+ elif mode == "dim1_mx_cuda_floor" :
204
+ from torchao .prototype import mxfp8_cuda
205
+
206
+ _ , y_d1 , _ , s_d1 = mxfp8_cuda .quantize (
207
+ x , rowwise = False , colwise = True , scaling_mode = "floor"
208
+ )
209
+
210
+ for _ in range (2 ):
211
+ __ = mxfp8_cuda .quantize (
212
+ x , rowwise = False , colwise = True , scaling_mode = "floor"
213
+ )
214
+
215
+ time_us = benchmark_cuda_function_in_microseconds (
216
+ lambda x : mxfp8_cuda .quantize (
217
+ x , rowwise = False , colwise = True , scaling_mode = "floor"
218
+ ),
219
+ x ,
220
+ )
221
+
222
+ assert y_d1 .dtype == torch .float8_e4m3fn
223
+ assert s_d1 .dtype == torch .float8_e8m0fnu
224
+
225
+ bytes_r = x .numel () * bytes_per_el_bf16
226
+ bytes_w = (y_d1 .numel () + s_d1 .numel ()) * bytes_per_el_fp8
227
+ bps = (bytes_r + bytes_w ) / (time_us / 1e6 )
228
+
229
+ elif mode == "dim1_mx_cuda_rceil" :
230
+ from torchao .prototype import mxfp8_cuda
231
+
232
+ _ , y_d1 , _ , s_d1 = mxfp8_cuda .quantize (
233
+ x , rowwise = False , colwise = True , scaling_mode = "rceil"
234
+ )
235
+
236
+ for _ in range (2 ):
237
+ __ = mxfp8_cuda .quantize (
238
+ x , rowwise = False , colwise = True , scaling_mode = "rceil"
239
+ )
240
+
241
+ time_us = benchmark_cuda_function_in_microseconds (
242
+ lambda x : mxfp8_cuda .quantize (
243
+ x , rowwise = False , colwise = True , scaling_mode = "rceil"
244
+ ),
245
+ x ,
246
+ )
247
+
248
+ assert y_d1 .dtype == torch .float8_e4m3fn
249
+ assert s_d1 .dtype == torch .float8_e8m0fnu
250
+
251
+ bytes_r = x .numel () * bytes_per_el_bf16
252
+ bytes_w = (y_d1 .numel () + s_d1 .numel ()) * bytes_per_el_fp8
253
+ bps = (bytes_r + bytes_w ) / (time_us / 1e6 )
254
+
197
255
else :
198
256
raise AssertionError (f"unknown mode { mode } " )
199
257
0 commit comments