Commit d9501e0

add 2d parallel fsdp+tp support for moe training

Parent: 34dffa5

4 files changed, +265 -0 lines changed
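At a high level, the new test composes tensor parallelism with FSDP2 over a 2-D device mesh: TP is applied on the inner "tp" mesh dimension first, then the module is wrapped with fully_shard on the outer "dp" dimension. Below is a minimal sketch of that ordering (not part of the commit), assuming launch via torchrun so RANK/WORLD_SIZE are set, and a hypothetical build_moe_model() standing in for the torchtitan MoE used in the test:

import os
import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["RANK"]))

# 2-D mesh: outer "dp" dim for FSDP sharding, inner "tp" dim for tensor parallelism
world_size = dist.get_world_size()
mesh = init_device_mesh("cuda", (world_size // 2, 2), mesh_dim_names=("dp", "tp"))

model = build_moe_model().to(torch.bfloat16).cuda()  # hypothetical constructor

# 1) tensor-parallelize the MoE layer on the "tp" submesh
#    (the test does this via apply_moe_ep_tp / parallelize_module, shown below)
# 2) then shard parameters across the "dp" submesh with FSDP2
fully_shard(model, mesh=mesh["dp"])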

test/float8/test_everything.sh

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,9 @@ then
 ./test/float8/test_fsdp_compile.sh
 ./test/float8/test_dtensor.sh
 python test/float8/test_fsdp2/test_fsdp2.py
+./test/prototype/moe_training/test_fsdp.sh
+./test/prototype/moe_training/test_tp.sh
+./test/prototype/moe_training/test_fsdp_tp.sh
 fi

 echo "all tests successful"

test/prototype/moe_training/test_fsdp.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+######################################################################
+#
+# To run these unit tests, use the following command:
+#
+# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp.py
+#
+#######################################################################
+
 import copy
 import os

test/prototype/moe_training/test_fsdp_tp.py

Lines changed: 248 additions & 0 deletions
@@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
######################################################################
#
# To run these unit tests, use the following command:
#
# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp_tp.py
#
#######################################################################

import copy
import os

import pytest
import torch
from torch import distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import Partial, Replicate, Shard
from torch.distributed.tensor.parallel import (
    PrepareModuleInputOutput,
    parallelize_module,
)
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
    pytest.skip(
        "CUDA not available or compute capability < 8.9", allow_module_level=True
    )

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

from .testing_utils import _validate_model_conversion

# this test requires torchtitan
try:
    from torchtitan.experiments.llama4.infra.expert_parallel import (
        ExpertParallel,
        ExpertTensorParallel,
        NoParallel,
        TensorParallel,
    )
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    import warnings

    warnings.warn("torchtitan not installed, skipping MoE tests.")
    pytest.skip(allow_module_level=True)


@pytest.mark.parametrize(
    "target_fqns",
    [
        ["experts"],
        # TODO: investigate hang when shared_expert is converted
        # ["experts,shared_expert"],
    ],
)
def test_moe_float8_training_fsdp_tp(target_fqns: list[str]):
    assert torch.cuda.is_available()

    # setup distributed for tp
    mesh = setup_distributed()

    # define model args
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=8,
        dim=256,
        vocab_size=1024,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(1)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # validate that only the experts were converted
    _validate_model_conversion(
        model,
        target_fqns=target_fqns,
    )

    # apply TP
    apply_moe_ep_tp(model, tp_mesh=mesh["tp"], ep_mesh=None, ep_tp_mesh=None)
    apply_moe_ep_tp(ref_model, tp_mesh=mesh["tp"], ep_mesh=None, ep_tp_mesh=None)

    # apply FSDP2
    fsdp_config = {"mesh": mesh["dp"]}
    fully_shard(model, **fsdp_config)
    fully_shard(ref_model, **fsdp_config)

    # Rough validation that parallelization was applied properly.
    assert isinstance(model.experts.w1.data, DTensor), (
        "test model experts.w1 is not a DTensor"
    )
    assert isinstance(model.experts.w2.data, DTensor), (
        "test model experts.w2 is not a DTensor"
    )
    assert isinstance(model.experts.w3.data, DTensor), (
        "test model experts.w3 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w1.data, DTensor), (
        "ref model experts.w1 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w2.data, DTensor), (
        "ref model experts.w2 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w3.data, DTensor), (
        "ref model experts.w3 is not a DTensor"
    )

    # inputs
    batch, seq, dim = 8, 2048, 256
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    # forward pass
    ref_out = ref_model(ref_x)
    out = model(x)

    # validate output
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

    # compute loss
    labels = torch.ones_like(ref_out)
    ref_loss = F.mse_loss(ref_out, labels)
    out_loss = F.mse_loss(out, labels)

    # backward pass
    ref_loss.backward()
    out_loss.backward()

    # validate input gradient
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
    assert input_grad_sqnr.item() >= 28.0, (
        f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}."
    )

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        param_grad_sqnr = compute_error(param1.grad, param2.grad)
        assert param_grad_sqnr.item() >= 25.0, (
            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
        )

    dist.destroy_process_group()


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
    device_mesh = init_device_mesh(
        "cuda",
        (world_size // 2, 2),
        mesh_dim_names=("dp", "tp"),
    )

    # seed must be the same in all processes
    torch.manual_seed(1)
    torch.cuda.set_device(rank)
    return device_mesh


def apply_moe_ep_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh | None,
    ep_mesh: DeviceMesh | None,
    ep_tp_mesh: DeviceMesh | None,
):
    # Modified version of moe parallelization from https://github.com/pytorch/torchtitan/pull/1324/
    # that supports single MoE layer independent of a transformer.
    if tp_mesh is not None:
        moe_layer_plan = {
            # input / output sharding on the seqlen dim
            # all-gather for input, reduce-scatter for output
            "moe": PrepareModuleInputOutput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
                use_local_input=True,
                output_layouts=(Partial(),),
                desired_output_layouts=(Shard(1),),
            ),
            # replicate computation for the router
            "moe.router.gate": NoParallel(),
            # input Replicate, output Partial
            "moe.shared_expert": TensorParallel(),
        }
        parallelize_module(
            module=model,
            device_mesh=tp_mesh,
            parallelize_plan=moe_layer_plan,
        )

    # if ep_mesh is not None:
    experts_mesh, experts_plan = None, None
    if ep_mesh is None:
        experts_mesh = tp_mesh
        # input Replicate, output Partial
        experts_plan = TensorParallel()
    elif tp_mesh is None:
        experts_mesh = ep_mesh
        # input / output sharding on the batch / tokens dim
        experts_plan = ExpertParallel()
    else:
        experts_mesh = ep_tp_mesh
        experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)

    parallelize_module(
        module=model.experts,
        device_mesh=experts_mesh,
        parallelize_plan=experts_plan,
    )
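Outside of the test harness, the conversion step on its own is just quantize_ with an MoETrainingConfig and a module filter. A minimal single-process sketch (not part of the commit), assuming torchtitan is installed and reusing the same illustrative model args as the test above:

import torch
from torchtitan.experiments.llama4.model.args import TransformerModelArgs
from torchtitan.experiments.llama4.model.moe import MoE

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

# same illustrative sizes as the test above
args = TransformerModelArgs(moe_enabled=True, num_experts=8, dim=256, vocab_size=1024)
model = MoE(args).to(torch.bfloat16).cuda()

# convert only modules whose fully-qualified name contains "experts"
quantize_(model, config=MoETrainingConfig(), filter_fn=lambda mod, fqn: "experts" in fqn)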
test/prototype/moe_training/test_fsdp_tp.sh

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
torchrun --nproc_per_node=4 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp_tp.py -s
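For reference, --nproc_per_node=4 lines up with the mesh built in setup_distributed(): with a world size of 4, the (world_size // 2, 2) shape yields a 2 x 2 mesh, i.e. 2-way FSDP ("dp") by 2-way TP ("tp"). A quick check of that arithmetic (illustrative only, not part of the commit):

world_size = 4                 # --nproc_per_node=4
dp, tp = world_size // 2, 2    # mesh shape used in setup_distributed()
assert (dp, tp) == (2, 2)      # 2-way FSDP ("dp") x 2-way TP ("tp")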
