
Commit dffb3a0

[mxfp8 moe training] bench and profile mxfp8 a2a fwd and bwd separately (#3203)
1 parent d089c6a commit dffb3a0

File tree

1 file changed (+78, −47 lines)


benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py

Lines changed: 78 additions & 47 deletions
@@ -42,8 +42,10 @@ class ExperimentConfig:
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    bf16_ms: float
-    mxfp8_ms: float
+    fwd_bf16_ms: float
+    fwd_mxfp8_ms: float
+    bwd_bf16_ms: float
+    bwd_mxfp8_ms: float
 
 
 @dataclass(frozen=True)
@@ -55,6 +57,10 @@ class Experiment:
 def get_configs() -> List[ExperimentConfig]:
     # (batch_size, seq_len, dim)
     input_shapes = [
+        (1, 8192, 5120),
+        (2, 8192, 5120),
+        (4, 8192, 5120),
+        (8, 8192, 5120),
         (16, 8192, 5120),
     ]
     configs = []
@@ -67,9 +73,8 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-def default_a2a_fwd_bwd(
+def default_a2a_fwd(
     routed_input: torch.Tensor,
-    labels: torch.Tensor,
     output_splits_list: list[int],
     input_splits_list: list[int],
     device_mesh: DeviceMesh,
@@ -81,17 +86,12 @@ def default_a2a_fwd_bwd(
         device_mesh.get_group(),
     )
     routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
-
-    loss = F.mse_loss(routed_input, labels)
-    loss.backward()
-
     torch.cuda.synchronize()
     return routed_input
 
 
-def mxfp8_a2a_fwd_bwd(
+def mxfp8_a2a_fwd(
     routed_input: torch.Tensor,
-    labels: torch.Tensor,
     output_splits_list: list[int],
     input_splits_list: list[int],
     device_mesh: DeviceMesh,
@@ -102,16 +102,22 @@ def mxfp8_a2a_fwd_bwd(
         input_splits_list,
         device_mesh.get_group(),
     )
+    torch.cuda.synchronize()
+    return routed_input
+
 
+def mse_loss_and_bwd(
+    routed_input: torch.Tensor,
+    labels: torch.Tensor,
+):
     loss = F.mse_loss(routed_input, labels)
     loss.backward()
     torch.cuda.synchronize()
     return routed_input
 
 
 # Compile target funcs
-default_a2a_sync_compiled = torch.compile(default_a2a_fwd_bwd)
-mxfp8_a2a_sync_compiled = torch.compile(mxfp8_a2a_fwd_bwd)
+mse_loss_and_bwd_compiled = torch.compile(mse_loss_and_bwd)
 
 
 def run_experiment(
@@ -129,82 +135,105 @@ def run_experiment(
     # Set up device mesh
     mesh = init_device_mesh("cuda", (dist.get_world_size(),))
 
-    # Max output tokens per rank is worst case where one rank receives all tokens
-    input_tokens_per_rank = batch_size * seq_len
-
     def warmup(func_no_args):
         for _ in range(2):
             func_no_args()
 
+    input_tokens_per_rank = batch_size * seq_len
     num_experts_per_rank = 2
-    num_splits = dist.get_world_size() * num_experts_per_rank
-    input_splits = generate_split_sizes(
-        num_splits, input_tokens_per_rank, device=device
-    )
+    num_experts = dist.get_world_size() * num_experts_per_rank
+    input_tokens_per_expert = input_tokens_per_rank // num_experts
+    input_splits = torch.tensor(
+        input_tokens_per_expert, dtype=torch.int32, device=device
+    ).repeat(num_experts)
     input_splits_list, output_splits_list = get_split_lists(input_splits, mesh)
 
     # Generate labels
     labels_shape = (sum(output_splits_list), dim)
     labels = x.new_ones(*labels_shape)
 
-    # Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
-    warmup(
-        lambda: default_a2a_sync_compiled(
-            ref_x, labels, output_splits_list, input_splits_list, mesh
-        )
-    )
+    # Bench default a2a fwd (exclude d2h sync from preparing input splits_list and output_splits_list)
+    warmup(lambda: default_a2a_fwd(ref_x, output_splits_list, input_splits_list, mesh))
     start_sec = time.perf_counter()
-    default_a2a_sync_compiled(
-        ref_x, labels, output_splits_list, input_splits_list, mesh
+    bf16_routed_input = default_a2a_fwd(
+        ref_x, output_splits_list, input_splits_list, mesh
     )
     end_sec = time.perf_counter()
-    bf16_ms = (end_sec - start_sec) * 1e3
+    fwd_bf16_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            default_a2a_sync_compiled,
+            default_a2a_fwd,
             ref_x,
-            labels,
             output_splits_list,
             input_splits_list,
             mesh,
             distributed=True,
-            profile_name="all_to_all_single_autograd",
+            profile_name="default_a2a_fwd",
         )
 
-    # Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
-    warmup(
-        lambda: mxfp8_a2a_sync_compiled(
-            x, labels, output_splits_list, input_splits_list, mesh
+    # Bench default a2a backward
+    warmup(lambda: mse_loss_and_bwd_compiled(bf16_routed_input, labels))
+    start_sec = time.perf_counter()
+    mse_loss_and_bwd_compiled(bf16_routed_input, labels)
+    end_sec = time.perf_counter()
+    bwd_bf16_ms = (end_sec - start_sec) * 1e3
+    if args.profile:
+        profile_fn(
+            mse_loss_and_bwd_compiled,
+            bf16_routed_input,
+            labels,
+            distributed=True,
+            profile_name="bf16_a2a_bwd",
         )
-    )
+
+    # Bench mxfp8 sync a2a fwd (exclude d2h sync from preparing input splits_list and output_splits_list)
+    warmup(lambda: mxfp8_a2a_fwd(x, output_splits_list, input_splits_list, mesh))
     start_sec = time.perf_counter()
-    mxfp8_a2a_sync_compiled(x, labels, output_splits_list, input_splits_list, mesh)
+    mxfp8_routed_input = mxfp8_a2a_fwd(x, output_splits_list, input_splits_list, mesh)
     end_sec = time.perf_counter()
-    mxfp8_ms = (end_sec - start_sec) * 1e3
+    fwd_mxfp8_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            mxfp8_a2a_sync_compiled,
+            mxfp8_a2a_fwd,
             x,
-            labels,
             output_splits_list,
             input_splits_list,
             mesh,
             distributed=True,
-            profile_name="to_mxfp8_a2a_dequant",
+            profile_name="mxfp8_a2a_fwd",
+        )
+
+    # Bench mxfp8 sync a2a backward
+    warmup(lambda: mse_loss_and_bwd_compiled(mxfp8_routed_input, labels))
+    start_sec = time.perf_counter()
+    mse_loss_and_bwd_compiled(mxfp8_routed_input, labels)
+    end_sec = time.perf_counter()
+    bwd_mxfp8_ms = (end_sec - start_sec) * 1e3
+    if args.profile:
+        profile_fn(
+            mse_loss_and_bwd_compiled,
+            mxfp8_routed_input,
+            labels,
+            distributed=True,
+            profile_name="mxfp8_a2a_bwd",
         )
 
     return ExperimentResult(
-        bf16_ms=bf16_ms,
-        mxfp8_ms=mxfp8_ms,
+        fwd_bf16_ms=fwd_bf16_ms,
+        fwd_mxfp8_ms=fwd_mxfp8_ms,
+        bwd_bf16_ms=bwd_bf16_ms,
+        bwd_mxfp8_ms=bwd_mxfp8_ms,
     )
 
 
 def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape",
         "num_splits",
-        "bf16_ms",
-        "mxfp8_ms",
+        "fwd_bf16_ms",
+        "fwd_mxfp8_ms",
+        "bwd_bf16_ms",
+        "bwd_mxfp8_ms",
     ]
     rows = []
     num_splits = dist.get_world_size()
@@ -213,8 +242,10 @@ def print_results(experiments: List[Experiment]):
             [
                 str(experiment.config.input_shape),
                 num_splits,
-                experiment.result.bf16_ms,
-                experiment.result.mxfp8_ms,
+                experiment.result.fwd_bf16_ms,
+                experiment.result.fwd_mxfp8_ms,
+                experiment.result.bwd_bf16_ms,
+                experiment.result.bwd_mxfp8_ms,
             ]
         )
     print(tabulate(rows, headers=headers))
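
For readers unfamiliar with the timing pattern this commit moves to, below is a minimal, self-contained sketch (not code from this repository) of timing a forward op and its backward pass separately on a GPU, in the spirit of the split into default_a2a_fwd / mxfp8_a2a_fwd plus mse_loss_and_bwd above. The stand-in op, the tensor shapes, and the time_ms helper are illustrative assumptions, not the benchmark's own API.

import time

import torch
import torch.nn.functional as F


def time_ms(fn, warmup_iters=2):
    # Warm up first (lazy CUDA init, compile caches), then time a single call.
    for _ in range(warmup_iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure the launched kernels actually finished
    return (time.perf_counter() - start) * 1e3, out


device = "cuda" if torch.cuda.is_available() else "cpu"
# Illustrative shape only (roughly one rank's tokens in the benchmark configs).
x = torch.randn(8192, 5120, device=device, requires_grad=True)
labels = torch.ones(8192, 5120, device=device)

# Forward timed on its own (a simple elementwise op stands in for the a2a dispatch).
fwd_ms, routed = time_ms(lambda: x * 2.0)

# Backward timed separately through an MSE loss, mirroring mse_loss_and_bwd.
# retain_graph=True lets the warmup and timed calls reuse the same autograd graph.
bwd_ms, _ = time_ms(lambda: F.mse_loss(routed, labels).backward(retain_graph=True))

print(f"fwd: {fwd_ms:.3f} ms | bwd: {bwd_ms:.3f} ms")

The real benchmark is distributed (it builds a device mesh and calls torch.distributed collectives), so it would presumably be launched with torchrun across multiple ranks; the sketch above is single-process only and exists just to illustrate the separated fwd/bwd timing.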
