Commit c61add8

add tp integration test
1 parent accbb27 commit c61add8

2 files changed: +231 -0 lines changed

test/prototype/moe_training/test_fsdp.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+######################################################################
+#
+# To run these unit tests, use the following command:
+#
+# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp.py
+#
+#######################################################################
+
 import copy
 import os

test/prototype/moe_training/test_tp.py (new file)

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
######################################################################
#
# To run these unit tests, use the following command:
#
# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_tp.py
#
#######################################################################

import copy
import os

import pytest
import torch
from torch import distributed as dist
from torch import nn
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
    pytest.skip(
        "CUDA not available or compute capability < 8.9", allow_module_level=True
    )

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.quant_api import quantize_

# this test requires torchtitan
try:
    from torchtitan.experiments.llama4.infra.parallelize import apply_moe_tp
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    import warnings

    warnings.warn("torchtitan not installed, skipping MoE tests.")
    pytest.skip(allow_module_level=True)

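# Overview: convert the target experts of a llama4-style MoE layer to float8
# training via quantize_ + MoETrainingConfig, apply the same TP/SP plan to both
# the converted model and a bf16 reference, then assert SQNR thresholds on the
# outputs, input gradients, and parameter gradients.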
@pytest.mark.parametrize(
    "target_fqns",
    [
        ["experts"],
        ["experts,shared_expert"],
    ],
)
def test_moe_float8_training_tp_sp(target_fqns: list[str]):
    assert torch.cuda.is_available()

    # setup distributed for tp
    mesh = setup_distributed()

    # define model args
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=8,
        dim=256,
        vocab_size=1024,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(1)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # validate that only the experts were converted
    _validate_model_conversion(
        model,
        target_fqns=target_fqns,
    )

    # apply TP
    apply_moe_tp(model, mesh)
    apply_moe_tp(ref_model, mesh)

    # inputs
    batch, seq, dim = 8, 2048, 256
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    # forward pass
    ref_out = ref_model(ref_x)
    out = model(x)

    # validate output
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

    # compute loss
    labels = torch.ones_like(ref_out)
    ref_loss = F.mse_loss(ref_out, labels)
    out_loss = F.mse_loss(out, labels)

    # backward pass
    ref_loss.backward()
    out_loss.backward()

    # validate input gradient
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
    assert input_grad_sqnr.item() >= 30.0, (
        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
    )

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        param_grad_sqnr = compute_error(param1.grad, param2.grad)
        assert param_grad_sqnr.item() >= 25.0, (
            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
        )

    dist.destroy_process_group()

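# Helper: recursively walk the module tree and assert that ScaledGroupedMMTensor
# params appear only on modules whose FQN matches one of target_fqns.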
def _validate_model_conversion(
    root_module: nn.Module,
    target_fqns: list[str],
):
    def _recursive_validate(
        module: nn.Module,
        cur_fqn: str,
    ):
        is_allowed_module = any([target_fqn in cur_fqn for target_fqn in target_fqns])

        # check current module params
        for param_name, param in module.named_parameters(recurse=False):
            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
            if is_converted_type:
                assert is_allowed_module, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )
            if not is_allowed_module:
                assert not is_converted_type, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )

        # recursively check child modules
        for child_name, child_module in module.named_children():
            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
            _recursive_validate(child_module, child_fqn)

    _recursive_validate(root_module, "")

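# Helper: initialize the NCCL process group from the torchrun-provided env vars
# and return a 1D device mesh spanning all ranks.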
def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    device_mesh = init_device_mesh("cuda", (world_size,))
    # seed must be the same in all processes
    torch.manual_seed(1)
    torch.cuda.set_device(rank)
    return device_mesh

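# Helper: apply a TP/SP plan to the MoE layer: sequence-sharded module input/output,
# a replicated router gate, and tensor-parallel routed and shared experts producing
# partial outputs. Note that this local definition shadows the torchtitan
# apply_moe_tp imported above.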
def apply_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
):
    from torch.distributed.tensor import Partial, Replicate, Shard
    from torch.distributed.tensor.parallel import (
        PrepareModuleInputOutput,
        parallelize_module,
    )
    from torchtitan.experiments.llama4.infra.expert_parallel import (
        NoParallel,
        TensorParallel,
    )

    moe_layer_plan = {
        # input / output sharding on the seqlen dim
        # all-gather for input, reduce-scatter for output
        "moe": PrepareModuleInputOutput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
            use_local_input=True,
            output_layouts=(Partial(),),
            desired_output_layouts=(Shard(1),),
        ),
        # replicate computation for the router
        "moe.router.gate": NoParallel(),
        # input Replicate, output Partial
        "moe.experts": TensorParallel(output_layout=Partial()),
        "moe.shared_expert": TensorParallel(output_layout=Partial()),
    }
    parallelize_module(
        module=model,
        device_mesh=tp_mesh,
        parallelize_plan=moe_layer_plan,
    )
