
Commit 8ebd9f2

tp on routed experts working
1 parent ab79bce commit 8ebd9f2

2 files changed: +43 −13 lines changed


torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 39 additions & 9 deletions
@@ -60,8 +60,8 @@ def forward(
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D, B=3D.
-        assert A.ndim == 2, "A must be 2D"
+        # torchao _scaled_grouped_mm only supports A=2D|3D + B=3D.
+        assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"

         assert A.size(-1) % 16 == 0, (
@@ -150,12 +150,25 @@ def forward(
         assert _is_column_major(B_t_fp8_col_major), (
             "B must be column-major for output = A @ B"
         )
+
+        # TODO: remove excessive logging once prototype is more mature.
+        logger.debug(
+            (
+                f"forward scaled_grouped_mm: A_fp8_row_major.shape={A_fp8_row_major.shape}, "
+                f"A_scale.shape={A_scales.squeeze(-1).shape}, "
+                f"B_t_fp8_col_major.shape={B_t_fp8_col_major.shape}, "
+                f"B_t_scale.shape={B_t_scales.squeeze(1).shape}, "
+                f"offs={offs if offs is not None else None}"
+            )
+        )
         return torch._scaled_grouped_mm(
             A_fp8_row_major,
             B_t_fp8_col_major,
-            A_scales.squeeze().reciprocal(),
-            B_t_scales.squeeze().reciprocal(),
-            offs,
+            # Squeeze A scales to: (B, S, 1) => (B, M), or (B*S, 1) => (B*S)
+            A_scales.squeeze(-1).reciprocal(),
+            # Squeeze B scales to: (B, 1, N) => (B, N)
+            B_t_scales.squeeze(1).reciprocal(),
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
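Aside (illustrative only, not part of this commit): the switch from a blanket .squeeze() to dim-targeted squeezes matters once A may be 3D, because .squeeze() would also drop a size-1 batch dim from the scale tensors. A minimal sketch with hypothetical shapes:

    import torch

    # Hypothetical rowwise / columnwise scale shapes for a 3D A and 3D B_t.
    batch, seq, N = 1, 8, 32
    A_scales = torch.rand(batch, seq, 1) + 0.1    # one scale per row of A
    B_t_scales = torch.rand(batch, 1, N) + 0.1    # one scale per column of B_t

    # Dim-targeted squeezes keep the batch dim even when it equals 1:
    print(A_scales.squeeze(-1).shape)   # torch.Size([1, 8])
    print(B_t_scales.squeeze(1).shape)  # torch.Size([1, 32])

    # The previous blanket .squeeze() would also collapse the size-1 batch dim,
    # yielding a shape the grouped kernel does not expect for 3D inputs:
    print(A_scales.squeeze().shape)     # torch.Size([8])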
@@ -192,12 +205,20 @@ def backward(ctx, grad_output: torch.Tensor):
         assert _is_column_major(B_fp8_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )
+        logger.debug(
+            (
+                f"backward grad_A: grad_output_fp8_row_major.shape={grad_output_fp8_row_major.shape}, "
+                f"grad_output_scale.shape={grad_output_scales.shape}, "
+                f"B_fp8_col_major.shape={B_fp8_col_major.shape}, "
+                f"B_scale.shape={B_scales.shape}, "
+            )
+        )
         grad_A = torch._scaled_grouped_mm(
             grad_output_fp8_row_major,
             B_fp8_col_major,
-            grad_output_scales.squeeze().reciprocal(),
-            B_scales.squeeze().reciprocal(),
-            offs,
+            grad_output_scales.squeeze(-1).reciprocal(),
+            B_scales.squeeze(1).reciprocal(),
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
@@ -237,12 +258,21 @@ def backward(ctx, grad_output: torch.Tensor):
         assert _is_column_major(A_fp8_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )
+
+        logger.debug(
+            (
+                f"backward grad_B: grad_output_t_fp8_row_major.shape={grad_output_t_fp8_row_major.shape}, "
+                f"grad_output_t_scale.shape={grad_output_t_scales.shape}, "
+                f"A_fp8_col_major.shape={A_fp8_col_major.shape}, "
+                f"A_scale.shape={A_scales.shape}, "
+            )
+        )
         grad_B = torch._scaled_grouped_mm(
             grad_output_t_fp8_row_major,
             A_fp8_col_major,
             grad_output_t_scales.reciprocal(),
             A_scales.reciprocal(),
-            offs,
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
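Aside (illustrative only, not part of this commit): to make the grouped layout concrete, below is a naive bf16 reference of the 2D-A routed-experts product that the scaled kernel computes, assuming offs holds cumulative end offsets of each expert's token group along dim 0 of A; the fp8 quantization and scaling are omitted.

    import torch

    def grouped_mm_reference(A, B_t, offs):
        # A: (total_tokens, K), B_t: (num_groups, K, N),
        # offs: cumulative end index of each token group along dim 0 of A.
        out = torch.empty(A.shape[0], B_t.shape[-1], dtype=A.dtype, device=A.device)
        start = 0
        for group_idx, end in enumerate(offs.tolist()):
            out[start:end] = A[start:end] @ B_t[group_idx]
            start = end
        return out

    # Hypothetical sizes: 3 experts receiving 16, 32, and 16 routed tokens.
    A = torch.randn(64, 128, dtype=torch.bfloat16)
    B_t = torch.randn(3, 128, 256, dtype=torch.bfloat16)
    offs = torch.tensor([16, 48, 64], dtype=torch.int32)
    print(grouped_mm_reference(A, B_t, offs).shape)  # torch.Size([64, 256])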

torchao/prototype/moe_training/tensor.py

Lines changed: 4 additions & 4 deletions
@@ -83,12 +83,12 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             # used for shared experts. This is basically the grouped_mm
             # kernel handling a bmm.
             A, B = args[0], args[1]
-            A_is_2d = A.dim() == 2
+            A_is_2d_or_3d = A.dim() in (2, 3)
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
-            logger.info(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}")
-
-            if A_is_2d and B_is_3d:
+            logger.debug(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}")
+
+            if A_is_2d_or_3d and B_is_3d:
                 return _scaled_grouped_mm(
                     *args,
                     **kwargs,
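Aside (illustrative only, not part of this commit; the helper below is hypothetical): the relaxed guard now routes both the 2D routed-experts case and the 3D bmm-like shared-expert / tensor-parallel case to _scaled_grouped_mm.

    import torch

    def routes_to_scaled_grouped_mm(A, B):
        # Mirrors the updated guard: A may be 2D (routed experts, used with offs)
        # or 3D (bmm-like case); B must always be 3D expert weights.
        return A.dim() in (2, 3) and B.dim() == 3

    B = torch.randn(4, 64, 32)                                     # (num_experts, K, N)
    print(routes_to_scaled_grouped_mm(torch.randn(128, 64), B))    # True: 2D routed-experts path
    print(routes_to_scaled_grouped_mm(torch.randn(4, 16, 64), B))  # True: 3D bmm-like path (new)
    print(routes_to_scaled_grouped_mm(torch.randn(64), B))         # False: falls back to default dispatch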
