Commit 1161f7f
2 parents: 1aff468 + 57b8876

Update
[ghstack-poisoned]

6 files changed (+33, -38 lines)


test/prototype/moe_training/test_fsdp.py

Lines changed: 1 addition & 4 deletions
@@ -46,8 +46,8 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
     from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -62,9 +62,6 @@ def device_mesh_1d() -> DeviceMesh:
     """
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
-    if not dist.is_initialized():
-        dist.init_process_group("nccl", rank=rank, world_size=world_size)
-
     device_mesh = init_device_mesh("cuda", (world_size,))
     torch.manual_seed(1)
     torch.cuda.set_device(rank)

test/prototype/moe_training/test_fsdp_tp.py

Lines changed: 1 addition & 1 deletion
@@ -65,9 +65,9 @@
         ExpertTensorParallel,
         NoParallel,
         TensorParallel,
-        set_token_group_alignment_size_m,
     )
     from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True

test/prototype/moe_training/test_tp.py

Lines changed: 2 additions & 5 deletions
@@ -58,14 +58,14 @@
 
 # this test requires torchtitan
 try:
+    from torchtitan.distributed import NoParallel
     from torchtitan.distributed.expert_parallel import (
         ExpertParallel,
         ExpertTensorParallel,
-        NoParallel,
         TensorParallel,
-        set_token_group_alignment_size_m,
     )
     from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -80,9 +80,6 @@ def device_mesh_1d() -> DeviceMesh:
     """
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
-    if not dist.is_initialized():
-        dist.init_process_group("nccl", rank=rank, world_size=world_size)
-
     device_mesh = init_device_mesh("cuda", (world_size,))
     torch.manual_seed(1)
     torch.cuda.set_device(rank)

test/prototype/moe_training/test_training.py

Lines changed: 2 additions & 2 deletions
@@ -22,10 +22,10 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.distributed.expert_parallel import (
+    from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True

torchao/prototype/moe_training/README.md

Lines changed: 22 additions & 25 deletions
@@ -2,13 +2,10 @@
 
 This prototype provides:
 
-1. Quantized building block for low precision MoE training: [_quantize_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L42). It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in original precision. See runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below.
+1. Quantized building block for low precision MoE training: [_to_mxfp8_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L677). It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in original precision. See runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below.
     - Using MXFP8 on a B200 GPU, this provides:
         - **~1.4x - 1.8x speedups** over bfloat16 `torch._grouped_mm` for Llama4 Scout shapes
-        - **~1.15 - 1.3x speedups** over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes
-    - We also provide the following convenience functions for specific recipes:
-        - [_to_mxfp8_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L677)
-        - [_to_fp8_rowwise_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L678)
+        - **~1.19 - 1.6x speedups** over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes
 
 
 
@@ -28,12 +25,12 @@ This prototype provides:
 - [Limitations](#limitations)
 
 ## Examples
-#### _quantize_then_scaled_grouped_mm usage
+#### _to_mxfp8_then_scaled_grouped_mm usage
 ```python
 import torch
 from torch.nn import functional as F
 from torchao.prototype.moe_training import (
-    _quantize_then_scaled_grouped_mm
+    _to_mxfp8_then_scaled_grouped_mm,
 )
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.utils import generate_jagged_offs
@@ -48,11 +45,10 @@ B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda", requires_
 offs = generate_jagged_offs(num_groups, total_M, device="cuda")
 
 # Forward and backward example
-out = _quantize_then_scaled_grouped_mm(
+out = _to_mxfp8_then_scaled_grouped_mm(
     A,
     B.transpose(-2, -1),
-    offs=offs,
-    scaling_type=MoEScalingType.MXFP8,
+    offs,
 )
 
 # (Fake labels for demonstration purposes)
@@ -63,20 +59,20 @@ loss.backward()
 
 #### Model conversion API example: end-to-end training
 ```python
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 import torch
 from torch import nn
 from torch.nn import functional as F
 
-# this feature requires CUDA 12.8+ and SM100+
+# This feature requires CUDA 12.8+ and SM100+
 assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)
 
 from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
 from torchao.quantization.quant_api import quantize_
 
-# this example uses torchtitan llama4 MoE, see
-# this benchmark requires torchtitan
+# This example uses torchtitan Llama4 MoE.
 try:
-    from torchtitan.distributed.expert_parallel import (
+    from torchtitan.models.moe.utils import (
         set_token_group_alignment_size_m,
     )
     from torchtitan.models.moe import MoE, MoEArgs
@@ -86,7 +82,7 @@ except ImportError:
     )
 
 
-# initialize model
+# Initialize model
 device = torch.device("cuda")
 moe_args = MoEArgs(
     num_experts=8,
@@ -96,7 +92,7 @@ model = MoE(moe_args, dim, hidden_dim).to(torch.bfloat16).to(device)
 init_std = 0.02
 model.init_weights(init_std, device)
 
-# module filter function to define which modules to quantize
+# Module filter function to define which modules to quantize
 target_fqns = ["experts"]
 
 
@@ -106,31 +102,32 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
             return True
     return False
 
-# Token group alignment size must be 32 for MXFP8 training
-alignment_size = 32 if recipe == MoEScalingType.MXFP8 else 16
+# Token group sizes must be padded to a multiple of the MXFP8 scaling block size (1x32)
+alignment_size = 32
 set_token_group_alignment_size_m(alignment_size)
 
-# quantize the model
-config = MoETrainingConfig()
+# Convert model to use MXFP8 scaled grouped GEMMs
+config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
 quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
-# training loop
+# Training loop
 optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
 batch_size, seq_len = 2, 2048
 for step in range(10):
+    # Simulate random batch of input data
     x = torch.randn(
-        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
+        batch_size, seq_len, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
 
-    # forward pass
+    # Forward pass
     out = model(x)
 
-    # compute loss
+    # Compute loss with fake labels for demonstration purposes
     labels = torch.ones_like(out)
     out_loss = F.mse_loss(out, labels)
     print(f"step {step} loss: {out_loss.item()}")
 
-    # backward pass
+    # Backward pass
     out_loss.backward()
     optimizer.step()
     optimizer.zero_grad()
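
The README examples above pass an `offs` tensor built with `generate_jagged_offs`. For readers unfamiliar with that argument, here is a minimal sketch of building such offsets by hand. It assumes, as is typical for grouped GEMMs, that `offs` holds the cumulative end index of each expert's token group along the M dimension; `make_offs` is a hypothetical helper for illustration, not part of torchao.

```python
import torch

# Hypothetical helper: `offs` is assumed to be the cumulative end index of each
# expert's token group along M (the layout generate_jagged_offs produces).
def make_offs(tokens_per_expert: list[int], device: str = "cuda") -> torch.Tensor:
    counts = torch.tensor(tokens_per_expert, dtype=torch.int32, device=device)
    return torch.cumsum(counts, dim=0).to(torch.int32)

# Example: 4 experts receiving 64, 32, 96, and 64 tokens. Each count is a
# multiple of 32, matching set_token_group_alignment_size_m(32) for MXFP8.
offs = make_offs([64, 32, 96, 64])
print(offs)  # tensor([ 64,  96, 192, 256], ...)
```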
Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,9 @@
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _quantize_then_scaled_grouped_mm,
+    _to_mxfp8_then_scaled_grouped_mm,
 )
 
-__all__ = ["_quantize_then_scaled_grouped_mm"]
+__all__ = [
+    "_quantize_then_scaled_grouped_mm",
+    "_to_mxfp8_then_scaled_grouped_mm",
+]
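
With this re-export in place, `_to_mxfp8_then_scaled_grouped_mm` is importable from the package root alongside `_quantize_then_scaled_grouped_mm`. A minimal forward/backward sketch using the new export, mirroring the README example above; the shape values are illustrative, and per the README this requires CUDA 12.8+ and an SM100+ GPU:

```python
import torch
from torch.nn import functional as F

# Importable from the package root after this change.
from torchao.prototype.moe_training import _to_mxfp8_then_scaled_grouped_mm
from torchao.prototype.moe_training.utils import generate_jagged_offs

# Illustrative problem sizes: num_groups experts, total_M routed tokens.
num_groups, total_M, N, K = 8, 4096, 8192, 5120
A = torch.randn(total_M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
offs = generate_jagged_offs(num_groups, total_M, device="cuda")

# Differentiable drop-in for torch._grouped_mm: quantizes inputs to MXFP8,
# runs the scaled grouped GEMM, and returns bfloat16 outputs.
out = _to_mxfp8_then_scaled_grouped_mm(A, B.transpose(-2, -1), offs)

# Fake labels, for demonstration purposes only.
labels = torch.ones_like(out)
loss = F.mse_loss(out, labels)
loss.backward()
```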
