
Commit 15ce7b6

add tests for ep
1 parent 80c88fc commit 15ce7b6

4 files changed: +268 -0 lines changed

test/prototype/moe_training/test_everything.sh

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ then
 ./test/prototype/moe_training/test_fsdp.sh
 ./test/prototype/moe_training/test_tp.sh
 ./test/prototype/moe_training/test_fsdp_tp.sh
+./test/prototype/moe_training/test_fsdp_tp_ep.sh
 fi

 echo "all tests successful"

test/prototype/moe_training/test_fsdp_tp.py

Lines changed: 2 additions & 0 deletions
@@ -191,6 +191,8 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 def setup_distributed():
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
+    assert world_size >= 4, "world size must be >= 4 for 2D parallel"
+
     dist.init_process_group("nccl", rank=rank, world_size=world_size)

     # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
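For context: the new assertion exists because the FSDP + TP test builds a 2D device mesh, and a 2D layout needs at least 2 x 2 = 4 ranks. A minimal, hypothetical sketch of such a mesh (the shape and dim names below are illustrative assumptions, not taken from this diff):

# Hypothetical 2D mesh: dp x tp = 2 x 2, the smallest layout the assert allows.
from torch.distributed.device_mesh import init_device_mesh

device_mesh = init_device_mesh(
    "cuda",
    (2, 2),
    mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = device_mesh["dp"], device_mesh["tp"]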
test/prototype/moe_training/test_fsdp_tp_ep.py

Lines changed: 264 additions & 0 deletions
@@ -0,0 +1,264 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
######################################################################
#
# To run these unit tests, use the following command:
#
# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp_tp_ep.py
#
#######################################################################

import copy
import os

import pytest
import torch
from torch import distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import Partial, Replicate, Shard
from torch.nn import functional as F

try:
    from torch.distributed.tensor.parallel import (
        PrepareModuleInputOutput,
        parallelize_module,
    )
except ImportError:
    import warnings

    warnings.warn(
        "torch version is too old, these tests require nightly build. Skipping MoE training tests."
    )
    pytest.skip(allow_module_level=True)

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
    pytest.skip(
        "CUDA not available or compute capability < 8.9", allow_module_level=True
    )

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

from .testing_utils import _validate_model_conversion

# this test requires torchtitan
try:
    from torchtitan.experiments.llama4.infra.expert_parallel import (
        ExpertParallel,
        ExpertTensorParallel,
        NoParallel,
        TensorParallel,
    )
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    import warnings

    warnings.warn("torchtitan not installed, skipping MoE tests.")
    pytest.skip(allow_module_level=True)


@pytest.mark.parametrize(
    "target_fqns",
    [
        ["experts"],
        # TODO: investigate hang when shared_expert is converted
        # ["experts,shared_expert"],
    ],
)
def test_moe_float8_training_fsdp_tp(target_fqns: list[str]):
    assert torch.cuda.is_available()

    # setup distributed for tp
    mesh = setup_distributed()

    # define model args
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=8,
        dim=256,
        vocab_size=1024,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(1)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # validate that only the experts were converted
    _validate_model_conversion(
        model,
        target_fqns=target_fqns,
    )

    # apply TP2EP
    apply_moe_ep_tp(
        model, tp_mesh=mesh["tp"], ep_mesh=mesh["ep"], ep_tp_mesh=mesh["ep", "tp"]
    )
    apply_moe_ep_tp(
        ref_model, tp_mesh=mesh["tp"], ep_mesh=mesh["ep"], ep_tp_mesh=mesh["ep", "tp"]
    )

    # apply FSDP2
    fsdp_config = {"mesh": mesh["dp"]}
    fully_shard(model, **fsdp_config)
    fully_shard(ref_model, **fsdp_config)

    # Rough validation that parallelization was applied properly.
    assert isinstance(model.experts.w1.data, DTensor), (
        "test model experts.w1 is not a DTensor"
    )
    assert isinstance(model.experts.w2.data, DTensor), (
        "test model experts.w2 is not a DTensor"
    )
    assert isinstance(model.experts.w3.data, DTensor), (
        "test model experts.w3 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w1.data, DTensor), (
        "ref model experts.w1 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w2.data, DTensor), (
        "ref model experts.w2 is not a DTensor"
    )
    assert isinstance(ref_model.experts.w3.data, DTensor), (
        "ref model experts.w3 is not a DTensor"
    )

    # inputs
    batch, seq, dim = 8, 2048, 256
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    # forward pass
    ref_out = ref_model(ref_x)
    out = model(x)

    # validate output
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

    # compute loss
    labels = torch.ones_like(ref_out)
    ref_loss = F.mse_loss(ref_out, labels)
    out_loss = F.mse_loss(out, labels)

    # backward pass
    ref_loss.backward()
    out_loss.backward()

    # validate input gradient
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
    assert input_grad_sqnr.item() >= 28.0, (
        f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}."
    )

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        param_grad_sqnr = compute_error(param1.grad, param2.grad)
        assert param_grad_sqnr.item() >= 25.0, (
            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
        )

    dist.destroy_process_group()


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert world_size == 8, "world size must be == 8 for 3D parallel test"

    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
    device_mesh = init_device_mesh(
        "cuda",
        (2, 2, 2),
        mesh_dim_names=("dp", "ep", "tp"),
    )

    # seed must be the same in all processes
    torch.manual_seed(1)
    torch.cuda.set_device(rank)
    return device_mesh


def apply_moe_ep_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh | None,
    ep_mesh: DeviceMesh | None,
    ep_tp_mesh: DeviceMesh | None,
):
    # Modified version of moe parallelization from https://github.com/pytorch/torchtitan/pull/1324/
    # that supports single MoE layer independent of a transformer.
    if tp_mesh is not None:
        moe_layer_plan = {
            # input / output sharding on the seqlen dim
            # all-gather for input, reduce-scatter for output
            "moe": PrepareModuleInputOutput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
                use_local_input=True,
                output_layouts=(Partial(),),
                desired_output_layouts=(Shard(1),),
            ),
            # replicate computation for the router
            "moe.router.gate": NoParallel(),
            # input Replicate, output Partial
            "moe.shared_expert": TensorParallel(),
        }
        parallelize_module(
            module=model,
            device_mesh=tp_mesh,
            parallelize_plan=moe_layer_plan,
        )

    # if ep_mesh is not None:
    experts_mesh, experts_plan = None, None
    if ep_mesh is None:
        experts_mesh = tp_mesh
        # input Replicate, output Partial
        experts_plan = TensorParallel()
    elif tp_mesh is None:
        experts_mesh = ep_mesh
        # input / output sharding on the batch / tokens dim
        experts_plan = ExpertParallel()
    else:
        experts_mesh = ep_tp_mesh
        experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)

    # torch.distributed.breakpoint()  # debug leftover; disabled so the test can run unattended
    parallelize_module(
        module=model.experts,
        device_mesh=experts_mesh,
        parallelize_plan=experts_plan,
    )
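For context (not part of this commit): the branching in apply_moe_ep_tp also covers TP-only and EP-only layouts, depending on which meshes are passed as None. A minimal, hypothetical sketch of the EP-only path, assuming an 8-rank 1D expert-parallel mesh:

# Hypothetical EP-only usage: with tp_mesh=None the TP module plan is skipped and
# the experts are parallelized with ExpertParallel() over the 1D "ep" mesh.
ep_only_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("ep",))
apply_moe_ep_tp(
    model,
    tp_mesh=None,
    ep_mesh=ep_only_mesh,
    ep_tp_mesh=None,
)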
test/prototype/moe_training/test_fsdp_tp_ep.sh

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
torchrun --nproc_per_node=8 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp_tp_ep.py -s

0 commit comments
