Commit 2f3bb13

add tp support for fp8 moe training
1 parent 6ca070d commit 2f3bb13

File tree

8 files changed (+330, -88 lines)

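At a high level, the flow this commit enables is: convert the MoE experts to fp8 grouped-GEMM training with quantize_ and MoETrainingConfig, then shard the converted experts with a tensor-parallel plan. Below is a minimal sketch of that flow, condensed from the new test_tp.py in this commit; it assumes torchtitan is installed and the script is launched with torchrun --nproc_per_node=2. The full parallelization plan (router, shared expert, input/output layouts) lives in the apply_moe_ep_tp helper added in the new test file.

```python
# Sketch only: condensed from the new test_tp.py added by this commit.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from torchtitan.experiments.llama4.infra.expert_parallel import TensorParallel
from torchtitan.experiments.llama4.model.args import TransformerModelArgs
from torchtitan.experiments.llama4.model.moe import MoE

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

torch.cuda.set_device(int(os.environ["RANK"]))
tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))  # 1D mesh -> TP only

model_args = TransformerModelArgs(moe_enabled=True, num_experts=8, dim=256, vocab_size=1024)
model = MoE(model_args).to(torch.bfloat16).cuda()

# swap expert weights to ScaledGroupedMMTensor so the grouped GEMMs run in fp8
quantize_(model, config=MoETrainingConfig(), filter_fn=lambda mod, fqn: "experts" in fqn)

# TP-only branch of the new apply_moe_ep_tp helper: shard the experts across the mesh
parallelize_module(module=model.experts, device_mesh=tp_mesh, parallelize_plan=TensorParallel())
```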

test/prototype/moe_training/test_fsdp.py

Lines changed: 2 additions & 31 deletions
@@ -16,9 +16,10 @@

 from torchao.float8.float8_utils import compute_error
 from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
-from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
 from torchao.quantization.quant_api import quantize_

+from .testing_utils import _validate_model_conversion
+
 # this test requires torchtitan
 try:
     from torchtitan.experiments.llama4.model.args import TransformerModelArgs
@@ -119,36 +120,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     dist.destroy_process_group()


-def _validate_model_conversion(
-    root_module: nn.Module,
-    target_fqns: list[str],
-):
-    def _recursive_validate(
-        module: nn.Module,
-        cur_fqn: str,
-    ):
-        is_allowed_module = cur_fqn in target_fqns
-
-        # check current module params
-        for param_name, param in module.named_parameters(recurse=False):
-            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
-            if is_converted_type:
-                assert is_allowed_module, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-            if not is_allowed_module:
-                assert not is_converted_type, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-
-        # recursively check child modules
-        for child_name, child_module in module.named_children():
-            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
-            _recursive_validate(child_module, child_fqn)
-
-    _recursive_validate(root_module, "")
-
-
 def setup_distributed():
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py -s
Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
######################################################################
#
# To run these unit tests, use the following command:
#
# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_tp.py
#
#######################################################################

import copy
import os

import pytest
import torch
from torch import distributed as dist
from torch import nn
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import Partial, Replicate, Shard
from torch.distributed.tensor.parallel import (
    PrepareModuleInputOutput,
    parallelize_module,
)
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
    pytest.skip(
        "CUDA not available or compute capability < 8.9", allow_module_level=True
    )

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

from .testing_utils import _validate_model_conversion

# this test requires torchtitan
try:
    from torchtitan.experiments.llama4.infra.expert_parallel import (
        ExpertParallel,
        ExpertTensorParallel,
        NoParallel,
        TensorParallel,
    )
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    import warnings

    warnings.warn("torchtitan not installed, skipping MoE tests.")
    pytest.skip(allow_module_level=True)


@pytest.mark.parametrize(
    "target_fqns",
    [
        ["experts"],
        # TODO: investigate hang when shared_expert is converted
        # ["experts,shared_expert"],
    ],
)
def test_moe_float8_training_tp(target_fqns: list[str]):
    assert torch.cuda.is_available()

    # setup distributed for tp
    mesh = setup_distributed()

    # define model args
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=8,
        dim=256,
        vocab_size=1024,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(1)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # validate that only the experts were converted
    _validate_model_conversion(
        model,
        target_fqns=target_fqns,
    )

    # apply TP
    apply_moe_ep_tp(model, tp_mesh=mesh, ep_mesh=None, ep_tp_mesh=None)
    apply_moe_ep_tp(ref_model, tp_mesh=mesh, ep_mesh=None, ep_tp_mesh=None)

    # Rough validation that parallelization was applied properly.
    assert isinstance(model.experts.w1.data, DTensor), (
        "test model experts.w1 is not a DTensor"
    )
    assert isinstance(model.experts.w2.data, DTensor), (
        "test model experts.w2 is not a DTensor"
    )
    assert isinstance(model.experts.w3.data, DTensor), (
        "test model experts.w3 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w1.data, DTensor), (
        "ref model experts.w1 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w2.data, DTensor), (
        "ref model experts.w2 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w3.data, DTensor), (
        "ref model experts.w3 is not a DTensor"
    )

    # inputs
    batch, seq, dim = 8, 2048, 256
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    # forward pass
    ref_out = ref_model(ref_x)
    out = model(x)

    # validate output
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

    # compute loss
    labels = torch.ones_like(ref_out)
    ref_loss = F.mse_loss(ref_out, labels)
    out_loss = F.mse_loss(out, labels)

    # backward pass
    ref_loss.backward()
    out_loss.backward()

    # validate input gradient
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
    assert input_grad_sqnr.item() >= 28.0, (
        f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}."
    )

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        param_grad_sqnr = compute_error(param1.grad, param2.grad)
        assert param_grad_sqnr.item() >= 25.0, (
            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
        )

    dist.destroy_process_group()


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    device_mesh = init_device_mesh("cuda", (world_size,))
    # seed must be the same in all processes
    torch.manual_seed(1)
    torch.cuda.set_device(rank)
    return device_mesh


def apply_moe_ep_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh | None,
    ep_mesh: DeviceMesh | None,
    ep_tp_mesh: DeviceMesh | None,
):
    # Modified version of moe parallelization from https://github.com/pytorch/torchtitan/pull/1324/
    # that supports single MoE layer independent of a transformer.
    if tp_mesh is not None:
        moe_layer_plan = {
            # input / output sharding on the seqlen dim
            # all-gather for input, reduce-scatter for output
            "moe": PrepareModuleInputOutput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
                use_local_input=True,
                output_layouts=(Partial(),),
                desired_output_layouts=(Shard(1),),
            ),
            # replicate computation for the router
            "moe.router.gate": NoParallel(),
            # input Replicate, output Partial
            "moe.shared_expert": TensorParallel(),
        }
        parallelize_module(
            module=model,
            device_mesh=tp_mesh,
            parallelize_plan=moe_layer_plan,
        )

    # if ep_mesh is not None:
    experts_mesh, experts_plan = None, None
    if ep_mesh is None:
        experts_mesh = tp_mesh
        # input Replicate, output Partial
        experts_plan = TensorParallel()
    elif tp_mesh is None:
        experts_mesh = ep_mesh
        # input / output sharding on the batch / tokens dim
        experts_plan = ExpertParallel()
    else:
        experts_mesh = ep_tp_mesh
        experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)

    parallelize_module(
        module=model.experts,
        device_mesh=experts_mesh,
        parallelize_plan=experts_plan,
    )
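Note that apply_moe_ep_tp also accepts expert-parallel and combined meshes. A rough sketch of how those meshes could be built with init_device_mesh follows; the 2x2 shape and the "ep"/"tp" dimension names are illustrative assumptions, not part of this commit.

```python
# Hypothetical mesh setup for the EP / EP+TP branches of apply_moe_ep_tp above.
from torch.distributed.device_mesh import init_device_mesh

# TP only (what test_moe_float8_training_tp uses): one 1D mesh over all ranks
tp_only_mesh = init_device_mesh("cuda", (4,))

# EP + TP: a 2D mesh, sliced by dimension name into per-dim sub-meshes
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("ep", "tp"))
ep_mesh, tp_mesh = mesh_2d["ep"], mesh_2d["tp"]

# apply_moe_ep_tp(model, tp_mesh=tp_mesh, ep_mesh=ep_mesh, ep_tp_mesh=mesh_2d)
```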
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_tp.py -s

test/prototype/moe_training/test_training.py

Lines changed: 2 additions & 31 deletions
@@ -13,9 +13,10 @@

 from torchao.float8.float8_utils import compute_error
 from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
-from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
 from torchao.quantization.quant_api import quantize_

+from .testing_utils import _validate_model_conversion
+
 # this test requires torchtitan
 try:
     from torchtitan.experiments.llama4.model.args import TransformerModelArgs
@@ -108,33 +109,3 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         assert param_grad_sqnr.item() >= 25.0, (
             f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
         )
-
-
-def _validate_model_conversion(
-    root_module: nn.Module,
-    target_fqns: list[str],
-):
-    def _recursive_validate(
-        module: nn.Module,
-        cur_fqn: str,
-    ):
-        is_allowed_module = cur_fqn in target_fqns
-
-        # check current module params
-        for param_name, param in module.named_parameters(recurse=False):
-            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
-            if is_converted_type:
-                assert is_allowed_module, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-            if not is_allowed_module:
-                assert not is_converted_type, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-
-        # recursively check child modules
-        for child_name, child_module in module.named_children():
-            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
-            _recursive_validate(child_module, child_fqn)
-
-    _recursive_validate(root_module, "")
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
from torch import nn

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor


def _validate_model_conversion(
    root_module: nn.Module,
    target_fqns: list[str],
):
    def _recursive_validate(
        module: nn.Module,
        cur_fqn: str,
    ):
        is_allowed_module = any([target_fqn in cur_fqn for target_fqn in target_fqns])

        # check current module params
        for param_name, param in module.named_parameters(recurse=False):
            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
            if is_converted_type:
                assert is_allowed_module, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )
            if not is_allowed_module:
                assert not is_converted_type, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )

        # recursively check child modules
        for child_name, child_module in module.named_children():
            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
            _recursive_validate(child_module, child_fqn)

    _recursive_validate(root_module, "")
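For context, a small sketch (a hypothetical helper, not part of this commit) of how _validate_model_conversion pairs with quantize_ and MoETrainingConfig, as both tests above do:

```python
# Hypothetical convenience wrapper around the pattern used in test_fsdp.py / test_tp.py.
from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

from .testing_utils import _validate_model_conversion


def convert_and_validate(model: nn.Module, target_fqns: list[str]) -> None:
    """Convert matching submodules to fp8 MoE training, then check the conversion scope."""
    quantize_(
        model,
        config=MoETrainingConfig(),
        filter_fn=lambda mod, fqn: any(t in fqn for t in target_fqns),
    )
    # asserts that ScaledGroupedMMTensor params appear only under target_fqns
    _validate_model_conversion(model, target_fqns=target_fqns)
```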
