@@ -60,8 +60,8 @@ def forward(
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D, B=3D.
-        assert A.ndim == 2, "A must be 2D"
+        # torchao _scaled_grouped_mm only supports A=2D|3D + B=3D.
+        assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"

         assert A.size(-1) % 16 == 0, (
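Not part of the patch: a minimal, hypothetical shape sketch of what the relaxed assertion now admits. Tensor names and sizes below are illustrative, not taken from the PR.

import torch

A_2d = torch.randn(256, 512)      # (M, K), e.g. tokens already flattened
A_3d = torch.randn(8, 32, 512)    # (B, S, K)
B_t = torch.randn(4, 512, 1024)   # (E, K, N), one weight matrix per group/expert

for A in (A_2d, A_3d):
    assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
    assert B_t.ndim == 3, "B must be 3D"
    assert A.size(-1) % 16 == 0  # divisibility check carried over from the hunk's trailing context line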
@@ -150,12 +150,25 @@ def forward(
         assert _is_column_major(B_t_fp8_col_major), (
             "B must be column-major for output = A @ B"
         )
+
+        # TODO: remove excessive logging once prototype is more mature.
+        logger.debug(
+            (
+                f"forward scaled_grouped_mm: A_fp8_row_major.shape={A_fp8_row_major.shape}, "
+                f"A_scale.shape={A_scales.squeeze(-1).shape}, "
+                f"B_t_fp8_col_major.shape={B_t_fp8_col_major.shape}, "
+                f"B_t_scale.shape={B_t_scales.squeeze(1).shape}, "
+                f"offs={offs if offs is not None else None}"
+            )
+        )
         return torch._scaled_grouped_mm(
             A_fp8_row_major,
             B_t_fp8_col_major,
-            A_scales.squeeze().reciprocal(),
-            B_t_scales.squeeze().reciprocal(),
-            offs,
+            # Squeeze A scales to: (B, S, 1) => (B, M), or (B*S, 1) => (B*S)
+            A_scales.squeeze(-1).reciprocal(),
+            # Squeeze B scales to: (B, 1, N) => (B, N)
+            B_t_scales.squeeze(1).reciprocal(),
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
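A hedged sketch of the scale-shape bookkeeping behind the new dimension-specific squeezes. The amax-based scale computation below is a stand-in for illustration, not the PR's quantization helpers; E denotes the number of groups/experts.

import torch

M, K, E, N = 256, 512, 4, 1024
A = torch.randn(M, K)
B_t = torch.randn(E, K, N)

A_scales = A.abs().amax(dim=-1, keepdim=True)     # (M, 1): one scale per row of A
B_t_scales = B_t.abs().amax(dim=1, keepdim=True)  # (E, 1, N): one scale per column of each B_t[e]

# The diff squeezes only the singleton dims before taking reciprocals:
assert A_scales.squeeze(-1).shape == (M,)         # (M, 1)    -> (M,)
assert B_t_scales.squeeze(1).shape == (E, N)      # (E, 1, N) -> (E, N)

# A bare .squeeze() would also drop any other size-1 dim (e.g. a group or batch
# dim of 1), which is presumably why it was replaced with targeted squeezes.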
@@ -192,12 +205,20 @@ def backward(ctx, grad_output: torch.Tensor):
         assert _is_column_major(B_fp8_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )
+        logger.debug(
+            (
+                f"backward grad_A: grad_output_fp8_row_major.shape={grad_output_fp8_row_major.shape}, "
+                f"grad_output_scale.shape={grad_output_scales.shape}, "
+                f"B_fp8_col_major.shape={B_fp8_col_major.shape}, "
+                f"B_scale.shape={B_scales.shape}, "
+            )
+        )
         grad_A = torch._scaled_grouped_mm(
             grad_output_fp8_row_major,
             B_fp8_col_major,
-            grad_output_scales.squeeze().reciprocal(),
-            B_scales.squeeze().reciprocal(),
-            offs,
+            grad_output_scales.squeeze(-1).reciprocal(),
+            B_scales.squeeze(1).reciprocal(),
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
@@ -237,12 +258,21 @@ def backward(ctx, grad_output: torch.Tensor):
         assert _is_column_major(A_fp8_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )
+
+        logger.debug(
+            (
+                f"backward grad_B: grad_output_t_fp8_row_major.shape={grad_output_t_fp8_row_major.shape}, "
+                f"grad_output_t_scale.shape={grad_output_t_scales.shape}, "
+                f"A_fp8_col_major.shape={A_fp8_col_major.shape}, "
+                f"A_scale.shape={A_scales.shape}, "
+            )
+        )
         grad_B = torch._scaled_grouped_mm(
             grad_output_t_fp8_row_major,
             A_fp8_col_major,
             grad_output_t_scales.reciprocal(),
             A_scales.reciprocal(),
-            offs,
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
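Also not part of the patch: a hedged sketch of the `offs` argument that each call now passes by keyword. The assumption here is that offs holds the cumulative end offset of each group's rows in the 2D A, one entry per group; names and sizes are illustrative.

import torch

group_sizes = torch.tensor([96, 64, 64, 32], dtype=torch.int32)  # rows routed to each group
offs = torch.cumsum(group_sizes, dim=0).to(torch.int32)          # tensor([96, 160, 224, 256], dtype=torch.int32)

# Group g covers rows offs[g-1]:offs[g] of A (implicit start of 0 for the first
# group) and is multiplied against B_t[g]; passing offs by keyword keeps the
# call site unambiguous.
total_rows = int(offs[-1])  # equals A.shape[0] when A is the flattened 2D input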