Commit 6821971
[moe training] Add TP support for routed experts (#2473)
Add TP support for FP8 MoE training.
1 parent 01f7352 commit 6821971
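
For context, here is a rough sketch (not part of the diff) of the workflow the new TP test below exercises: convert a torchtitan MoE layer's routed experts to float8 training with quantize_, then shard the experts over a tensor-parallel device mesh. The sketch assumes a 2-GPU run launched via torchrun with NCCL available and a CUDA device with compute capability 8.9+, and it reuses the apply_moe_ep_tp helper defined in the new TP test file in this commit.

    # Hedged sketch, assuming torchrun has launched 2 ranks and set the usual env vars.
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
    from torchao.quantization.quant_api import quantize_
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())

    # single MoE layer in bf16
    args = TransformerModelArgs(moe_enabled=True, num_experts=8, dim=256, vocab_size=1024)
    moe = MoE(args).to(torch.bfloat16).cuda()

    # swap routed-expert params to ScaledGroupedMMTensor for scaled grouped GEMMs
    quantize_(moe, config=MoETrainingConfig(), filter_fn=lambda mod, fqn: "experts" in fqn)

    # shard the routed experts over a 1-D tensor-parallel mesh
    tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    apply_moe_ep_tp(moe, tp_mesh=tp_mesh, ep_mesh=None, ep_tp_mesh=None)  # helper defined in the new test below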

File tree

8 files changed (+340, -90 lines)


test/prototype/moe_training/test_fsdp.py

Lines changed: 2 additions & 31 deletions
@@ -16,9 +16,10 @@
 
 from torchao.float8.float8_utils import compute_error
 from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
-from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
 from torchao.quantization.quant_api import quantize_
 
+from .testing_utils import _validate_model_conversion
+
 # this test requires torchtitan
 try:
     from torchtitan.experiments.llama4.model.args import TransformerModelArgs
@@ -119,36 +120,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     dist.destroy_process_group()
 
 
-def _validate_model_conversion(
-    root_module: nn.Module,
-    target_fqns: list[str],
-):
-    def _recursive_validate(
-        module: nn.Module,
-        cur_fqn: str,
-    ):
-        is_allowed_module = cur_fqn in target_fqns
-
-        # check current module params
-        for param_name, param in module.named_parameters(recurse=False):
-            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
-            if is_converted_type:
-                assert is_allowed_module, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-            if not is_allowed_module:
-                assert not is_converted_type, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-
-        # recursively check child modules
-        for child_name, child_module in module.named_children():
-            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
-            _recursive_validate(child_module, child_fqn)
-
-    _recursive_validate(root_module, "")
-
-
 def setup_distributed():
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py -s
Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+######################################################################
+#
+# To run these unit tests, use the following command:
+#
+# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_tp.py
+#
+#######################################################################
+
+import copy
+import os
+
+import pytest
+import torch
+from torch import distributed as dist
+from torch import nn
+from torch.distributed._tensor import DTensor
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.tensor import Partial, Replicate, Shard
+from torch.nn import functional as F
+
+try:
+    from torch.distributed.tensor.parallel import (
+        PrepareModuleInputOutput,
+        parallelize_module,
+    )
+except ImportError:
+    import warnings
+
+    warnings.warn(
+        "torch version is too old, these tests require nightly build. Skipping MoE training tests."
+    )
+    pytest.skip(allow_module_level=True)
+
+
+# this feature requires CUDA and SM89+
+if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
+    pytest.skip(
+        "CUDA not available or compute capability < 8.9", allow_module_level=True
+    )
+
+from torchao.float8.float8_utils import compute_error
+from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.quantization.quant_api import quantize_
+
+from .testing_utils import _validate_model_conversion
+
+# this test requires torchtitan
+try:
+    from torchtitan.experiments.llama4.infra.expert_parallel import (
+        ExpertParallel,
+        ExpertTensorParallel,
+        NoParallel,
+        TensorParallel,
+    )
+    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
+    from torchtitan.experiments.llama4.model.moe import MoE
+except ImportError:
+    import warnings
+
+    warnings.warn("torchtitan not installed, skipping MoE tests.")
+    pytest.skip(allow_module_level=True)
+
+
+@pytest.mark.parametrize(
+    "target_fqns",
+    [
+        ["experts"],
+        # TODO: investigate hang when shared_expert is converted
+        # ["experts,shared_expert"],
+    ],
+)
+def test_moe_float8_training_tp(target_fqns: list[str]):
+    assert torch.cuda.is_available()
+
+    # setup distributed for tp
+    mesh = setup_distributed()
+
+    # define model args
+    model_args = TransformerModelArgs(
+        moe_enabled=True,
+        num_experts=8,
+        dim=256,
+        vocab_size=1024,
+    )
+    init_std = 0.02
+    device = torch.device("cuda")
+
+    # reference bf16 MoE
+    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    torch.manual_seed(1)
+    ref_model.init_weights(init_std, device)
+
+    # target MoE for testing conversion
+    model = copy.deepcopy(ref_model)
+
+    # assert starting params are identical for both models
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        assert torch.equal(param1, param2)
+
+    # convert MoE to float8 training
+    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+        for target_fqn in target_fqns:
+            if target_fqn in cur_fqn:
+                return True
+        return False
+
+    # quantize test model
+    config = MoETrainingConfig()
+    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+
+    # validate that only the experts were converted
+    _validate_model_conversion(
+        model,
+        target_fqns=target_fqns,
+    )
+
+    # apply TP
+    apply_moe_ep_tp(model, tp_mesh=mesh, ep_mesh=None, ep_tp_mesh=None)
+    apply_moe_ep_tp(ref_model, tp_mesh=mesh, ep_mesh=None, ep_tp_mesh=None)
+
+    # Rough validation that parallelization was applied properly.
+    assert isinstance(model.experts.w1.data, DTensor), (
+        "test model experts.w1 is not a DTensor"
+    )
+    assert isinstance(model.experts.w2.data, DTensor), (
+        "test model experts.w2 is not a DTensor"
+    )
+    assert isinstance(model.experts.w3.data, DTensor), (
+        "test model experts.w3 is not a DTensor"
+    )
+    assert isinstance(ref_model.experts.w1.data, DTensor), (
+        "ref model experts.w1 is not a DTensor"
+    )
+    assert isinstance(ref_model.experts.w2.data, DTensor), (
+        "ref model experts.w2 is not a DTensor"
+    )
+    assert isinstance(ref_model.experts.w3.data, DTensor), (
+        "ref model experts.w3 is not a DTensor"
+    )
+
+    # inputs
+    batch, seq, dim = 8, 2048, 256
+    ref_x = torch.randn(
+        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
+    )
+    x = ref_x.detach().clone().requires_grad_(True)
+
+    # forward pass
+    ref_out = ref_model(ref_x)
+    out = model(x)
+
+    # validate output
+    out_sqnr = compute_error(out, ref_out)
+    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+
+    # compute loss
+    labels = torch.ones_like(ref_out)
+    ref_loss = F.mse_loss(ref_out, labels)
+    out_loss = F.mse_loss(out, labels)
+
+    # backward pass
+    ref_loss.backward()
+    out_loss.backward()
+
+    # validate input gradient
+    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
+    assert input_grad_sqnr.item() >= 28.0, (
+        f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}."
+    )
+
+    # validate param gradients
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        param_grad_sqnr = compute_error(param1.grad, param2.grad)
+        assert param_grad_sqnr.item() >= 25.0, (
+            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        )
+
+    dist.destroy_process_group()
+
+
+def setup_distributed():
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    device_mesh = init_device_mesh("cuda", (world_size,))
+    # seed must be the same in all processes
+    torch.manual_seed(1)
+    torch.cuda.set_device(rank)
+    return device_mesh
+
+
+def apply_moe_ep_tp(
+    model: nn.Module,
+    tp_mesh: DeviceMesh | None,
+    ep_mesh: DeviceMesh | None,
+    ep_tp_mesh: DeviceMesh | None,
+):
+    # Modified version of moe parallelization from https://github.com/pytorch/torchtitan/pull/1324/
+    # that supports single MoE layer independent of a transformer.
+    if tp_mesh is not None:
+        moe_layer_plan = {
+            # input / output sharding on the seqlen dim
+            # all-gather for input, reduce-scatter for output
+            "moe": PrepareModuleInputOutput(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+                use_local_input=True,
+                output_layouts=(Partial(),),
+                desired_output_layouts=(Shard(1),),
+            ),
+            # replicate computation for the router
+            "moe.router.gate": NoParallel(),
+            # input Replicate, output Partial
+            "moe.shared_expert": TensorParallel(),
+        }
+        parallelize_module(
+            module=model,
+            device_mesh=tp_mesh,
+            parallelize_plan=moe_layer_plan,
+        )
+
+    # if ep_mesh is not None:
+    experts_mesh, experts_plan = None, None
+    if ep_mesh is None:
+        experts_mesh = tp_mesh
+        # input Replicate, output Partial
+        experts_plan = TensorParallel()
+    elif tp_mesh is None:
+        experts_mesh = ep_mesh
+        # input / output sharding on the batch / tokens dim
+        experts_plan = ExpertParallel()
+    else:
+        experts_mesh = ep_tp_mesh
+        experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
+
+    parallelize_module(
+        module=model.experts,
+        device_mesh=experts_mesh,
+        parallelize_plan=experts_plan,
+    )

test/prototype/moe_training/test_tp.sh

File mode changed: 100644 → 100755
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-torchrun --nproc_per_node=2 -m pytest test/prototype/moe_training/test_tp.py
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_tp.py -s

test/prototype/moe_training/test_training.py

Lines changed: 2 additions & 31 deletions
@@ -13,9 +13,10 @@
 
 from torchao.float8.float8_utils import compute_error
 from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
-from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
 from torchao.quantization.quant_api import quantize_
 
+from .testing_utils import _validate_model_conversion
+
 # this test requires torchtitan
 try:
     from torchtitan.experiments.llama4.model.args import TransformerModelArgs
@@ -108,33 +109,3 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         assert param_grad_sqnr.item() >= 25.0, (
             f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
         )
-
-
-def _validate_model_conversion(
-    root_module: nn.Module,
-    target_fqns: list[str],
-):
-    def _recursive_validate(
-        module: nn.Module,
-        cur_fqn: str,
-    ):
-        is_allowed_module = cur_fqn in target_fqns
-
-        # check current module params
-        for param_name, param in module.named_parameters(recurse=False):
-            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
-            if is_converted_type:
-                assert is_allowed_module, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-            if not is_allowed_module:
-                assert not is_converted_type, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-
-        # recursively check child modules
-        for child_name, child_module in module.named_children():
-            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
-            _recursive_validate(child_module, child_fqn)
-
-    _recursive_validate(root_module, "")
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+from torch import nn
+
+from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+
+
+def _validate_model_conversion(
+    root_module: nn.Module,
+    target_fqns: list[str],
+):
+    def _recursive_validate(
+        module: nn.Module,
+        cur_fqn: str,
+    ):
+        is_allowed_module = any([target_fqn in cur_fqn for target_fqn in target_fqns])
+
+        # check current module params
+        for param_name, param in module.named_parameters(recurse=False):
+            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
+            if is_converted_type:
+                assert is_allowed_module, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+            if not is_allowed_module:
+                assert not is_converted_type, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+
+        # recursively check child modules
+        for child_name, child_module in module.named_children():
+            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
+            _recursive_validate(child_module, child_fqn)
+
+    _recursive_validate(root_module, "")
