@@ -171,30 +171,36 @@ def test_activation_checkpointing():
 @pytest.mark.skipif(
     is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
 )
-@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
+@pytest.mark.parametrize(
+    "recipe_name",
+    ["mxfp8_emulated", "mxfp4_emulated", "mxfp8_cutlass", "mxfp4_cutlass"],
+)
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
 # autocast is on
-@pytest.mark.parametrize(
-    "use_autocast",
-    [
-        False,
-    ],
-)
-def test_linear_compile(elem_dtype, bias, use_autocast):
+def test_linear_compile(recipe_name, bias):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
-    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+    if recipe_name in ["mxfp8_emulated", "mxfp8_cutlass"]:
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
-    M, K, N = 4, 8, 6
+
+    if recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+        if not is_sm_at_least_100():
+            pytest.skip("CUDA capability >= 10.0 required for MX gemms")
+
+    if bias and recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+        # TODO(future PR): fix this, things are clearly broken with bias=True
+        pytest.skip("this test is broken for cutlass recipes with bias=True")
+
+    M, K, N = 128, 256, 512
     input_shape = (M, K)
     grad_shape = (M, N)
     m_mx = nn.Sequential(
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
-    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    config = MXLinearConfig.from_recipe_name(recipe_name)
     swap_linear_with_mx_linear(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
@@ -203,13 +209,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
     x = copy.deepcopy(x_ref)
     g = torch.randn(*grad_shape, device="cuda")
 
-    if use_autocast:
-        with torch.autocast("cuda", dtype=torch.bfloat16):
-            y_ref = m_mx(x_ref)
-            y = m_mx_c(x)
-    else:
-        y_ref = m_mx(x_ref)
-        y = m_mx_c(x)
+    y_ref = m_mx(x_ref)
+    y = m_mx_c(x)
     torch.testing.assert_close(y_ref, y, atol=0, rtol=0)
 
     y_ref.backward(g)
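
For context, the test body above reduces to the following standalone sketch of the recipe-based flow: build a config with MXLinearConfig.from_recipe_name, swap the nn.Linear modules for their MX equivalents, compile a deep copy, and compare eager vs. compiled numerics bitwise. This is a minimal sketch, not the test itself; the torchao import paths are assumptions based on the identifiers used in the diff and may differ from the actual module layout.

    # Minimal sketch of the flow exercised by test_linear_compile above.
    # NOTE: import paths below are assumptions; only the identifiers
    # (MXLinearConfig.from_recipe_name, swap_linear_with_mx_linear) come from the diff.
    import copy

    import torch
    import torch.nn as nn

    from torchao.prototype.mx_formats.config import MXLinearConfig  # assumed path
    from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear  # assumed path

    M, K, N = 128, 256, 512
    m_mx = nn.Sequential(nn.Linear(K, N, bias=False, device="cuda"))

    # Build the MX config from a named recipe instead of raw elem_dtype/block_size.
    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
    swap_linear_with_mx_linear(m_mx, config=config)

    # Compile a deep copy so the eager and compiled paths can be compared directly.
    m_mx_c = torch.compile(copy.deepcopy(m_mx), fullgraph=True, backend="inductor")

    x_ref = torch.randn(M, K, device="cuda", requires_grad=True)
    x = copy.deepcopy(x_ref)

    y_ref = m_mx(x_ref)
    y = m_mx_c(x)
    # atol=0, rtol=0 asserts that compile is bit-exact with eager.
    torch.testing.assert_close(y_ref, y, atol=0, rtol=0)

Comparing with atol=0, rtol=0 is the key design choice in the test: compile must not change MX linear numerics at all, not merely stay within a tolerance.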