This prototype provides:
1. Quantized building block for low precision MoE training: [_to_mxfp8_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L677). It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in original precision. See runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below.

Using MXFP8 on a B200 GPU, this provides:
- **~1.4x - 1.8x speedups** over bfloat16 `torch._grouped_mm` for Llama4 Scout shapes
- **~1.15x - 1.3x speedups** over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes
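
Below is a minimal sketch of the forward and backward pass described above. The import path, parameter names, and the `offs` convention are assumptions, modeled on the `torch._grouped_mm` interface and the file linked above; consult the pinned source for the authoritative signature.

```python
import torch

# Assumed import path, mirroring the file linked above; the module layout
# may differ across torchao versions.
from torchao.prototype.moe_training.scaled_grouped_mm import (
    _to_mxfp8_then_scaled_grouped_mm,
)

# MXFP8 grouped GEMMs require a Blackwell-class GPU (e.g. B200).
device = "cuda"

num_experts, total_tokens, dim, hidden = 4, 128, 256, 512

# Token activations for all experts, concatenated along dim 0.
A = torch.randn(
    total_tokens, dim, device=device, dtype=torch.bfloat16, requires_grad=True
)
# One weight matrix per expert.
B = torch.randn(
    num_experts, dim, hidden, device=device, dtype=torch.bfloat16, requires_grad=True
)
# End offset of each expert's token group, as in torch._grouped_mm:
# tokens [0:32) go to expert 0, [32:64) to expert 1, and so on.
offs = torch.arange(32, 160, 32, device=device, dtype=torch.int32)

# Forward: inputs are dynamically quantized to MXFP8, the scaled grouped
# GEMM runs in low precision, and the output is returned in bfloat16.
out = _to_mxfp8_then_scaled_grouped_mm(A, B, offs=offs)

# Backward: the op is differentiable, so gradients flow to A and B.
out.sum().backward()
assert A.grad.shape == A.shape and B.grad.shape == B.shape
```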

We also provide the following convenience functions for specific recipes: