     DTYPE_FP6_E3M2,
     SUPPORTED_ELEM_DTYPES,
 )
-from torchao.prototype.mx_formats.custom_cast import pack_uint4
+from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
 from torchao.prototype.mx_formats.mx_tensor import (
     E8M0_EXPONENT_NAN_VAL,
     MXTensor,
@@ -75,7 +75,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_hello_world(elem_dtype):
     data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
@@ -92,7 +92,7 @@ def test_realistic_numerics(elem_dtype, scale_calculation_mode):
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_all_zeros(elem_dtype):
     data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
@@ -102,7 +102,7 @@ def test_some_zeros(elem_dtype):
     data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
     data[0, :] = 0.0
     data[:, 2] = 0.0
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
@@ -114,33 +114,46 @@ def test_exponent_nan_in(elem_dtype):
     value is set to is NaN
     """
     tensor_hp = torch.tensor(
-        [float("nan"), 1, 2, 3, 4, 5], device="cuda", dtype=torch.bfloat16
+        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
     )
-    block_size = 2
+    block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
     assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
     assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_exponent_nan_out(elem_dtype):
+@pytest.mark.parametrize("pack_fp6", [False, True])
+def test_exponent_nan_out(elem_dtype, pack_fp6):
     """
     If block exponent value is NaN, the MX tensor block value is NaN
     """
     scale_e8m0_bits = torch.tensor(
-        [E8M0_EXPONENT_NAN_VAL, 23, 42], dtype=torch.uint8, device="cuda"
+        [E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
     )
+
+    block_size = 4
+
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=elem_dtype, device="cuda")  # noqa: E501
+        data_bits = torch.tensor(
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda"
+        )  # noqa: E501
     elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda")  # noqa: E501
+        data_bits = torch.tensor(
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+        )  # noqa: E501
+        if pack_fp6:
+            data_bits = data_bits.reshape(-1, block_size)
+            data_bits = pack_uint6(data_bits)
     elif elem_dtype == DTYPE_FP4:
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda")  # noqa: E501
+        data_bits = torch.tensor(
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+        )  # noqa: E501
         data_bits = pack_uint4(data_bits)
     else:
         raise AssertionError("unsupported")
-    block_size = 2
+    block_size = 4
     use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
         scale_e8m0_bits,
@@ -150,10 +163,11 @@ def test_exponent_nan_out(elem_dtype):
         torch.float,
         use_fp4_custom_triton_dequant_kernel,
         MXGemmKernelChoice.EMULATED,
+        pack_fp6,
     )
     tensor_hp = tensor_mx.to_dtype(torch.float)
-    assert torch.all(torch.isnan(tensor_hp[0:1]))
-    assert not torch.any(torch.isnan(tensor_hp[2:]))
+    assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
+    assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -162,24 +176,26 @@ def test_ranks(elem_dtype):
     """
     The reshaping logic works for various ranks
     """
-    B = 2
-    shapes = ((B * 4,), (B * 4, 2), (B * 4, 2, 2), (B * 4, 2, 2, 2))
+    B = 4
+    shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
     for s in shapes:
         tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
         _test_mx(tensor_hp, elem_dtype, B)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_block_sizes(elem_dtype):
+@pytest.mark.parametrize("B", [1, 4, 32])
+def test_block_sizes(elem_dtype, B):
     """
     Smoke test for various block sizes
     """
-    for B in (1, 2, 32):
-        if B == 1 and elem_dtype == DTYPE_FP4:
-            pytest.skip("unsupported configuration")
-        tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
-        _test_mx(tensor_hp, elem_dtype, B)
+    if B == 1 and elem_dtype == DTYPE_FP4:
+        pytest.skip("unsupported configuration")
+    elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
+        pytest.skip("unsupported configuration")
+    tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
+    _test_mx(tensor_hp, elem_dtype, B)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -224,14 +240,30 @@ def test_cast_autograd(elem_dtype):
     torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_view(elem_dtype):
-    x = torch.randn(1, 2, 4)
-    block_size = 2
+    x = torch.randn(1, 2, 4, device="cuda")
+    block_size = 4
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
     x_mx_2 = x_mx.view(2, 4)  # noqa: F841
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
+@pytest.mark.parametrize("pack_fp6", [False, True])
+def test_fp6_packing(elem_dtype, pack_fp6):
+    x = torch.randn(1, 2, 4, device="cuda")
+    block_size = 4
+    x_mx = MXTensor.to_mx(x, elem_dtype, block_size, pack_fp6=pack_fp6)
+    if pack_fp6:
+        expected_packed_shape = torch.Size([*x.shape[:-1], 3 * x.shape[-1] // 4])
+    else:
+        expected_packed_shape = x.shape
+
+    assert x_mx._data.shape == expected_packed_shape
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
     is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
@@ -253,7 +285,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
     else:
         x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
-    block_size = 2
+    block_size = 4
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
 
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
@@ -269,13 +301,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     to_dtype_c = torch.compile(to_dtype, fullgraph=True)
 
     use_fp4_custom_triton_dequant_kernel = False
+    pack_fp6 = False
     x_mx_dq = to_dtype(
         x_mx._data,
         x_mx._scale_e8m0,
         x_mx._elem_dtype,
         x_mx._block_size,
         hp_dtype,  # noqa: E501
         use_fp4_custom_triton_dequant_kernel,
+        pack_fp6,
     )
     x_mx_c_dq = to_dtype_c(
         x_mx_c._data,
@@ -284,6 +318,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         x_mx_c._block_size,
         hp_dtype,
         use_fp4_custom_triton_dequant_kernel,
+        pack_fp6,
     )
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)