@@ -42,8 +42,10 @@ class ExperimentConfig:
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    bf16_ms: float
-    mxfp8_ms: float
+    fwd_bf16_ms: float
+    fwd_mxfp8_ms: float
+    bwd_bf16_ms: float
+    bwd_mxfp8_ms: float
 
 
 @dataclass(frozen=True)
@@ -55,6 +57,10 @@ class Experiment:
 def get_configs() -> List[ExperimentConfig]:
     # (batch_size, seq_len, dim)
     input_shapes = [
+        (1, 8192, 5120),
+        (2, 8192, 5120),
+        (4, 8192, 5120),
+        (8, 8192, 5120),
         (16, 8192, 5120),
     ]
     configs = []
@@ -67,9 +73,8 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-def default_a2a_fwd_bwd(
+def default_a2a_fwd(
     routed_input: torch.Tensor,
-    labels: torch.Tensor,
     output_splits_list: list[int],
     input_splits_list: list[int],
     device_mesh: DeviceMesh,
@@ -81,17 +86,12 @@ def default_a2a_fwd_bwd(
         device_mesh.get_group(),
     )
     routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
-
-    loss = F.mse_loss(routed_input, labels)
-    loss.backward()
-
     torch.cuda.synchronize()
     return routed_input
 
 
-def mxfp8_a2a_fwd_bwd(
+def mxfp8_a2a_fwd(
     routed_input: torch.Tensor,
-    labels: torch.Tensor,
     output_splits_list: list[int],
     input_splits_list: list[int],
     device_mesh: DeviceMesh,
@@ -102,16 +102,22 @@ def mxfp8_a2a_fwd_bwd(
         input_splits_list,
         device_mesh.get_group(),
     )
+    torch.cuda.synchronize()
+    return routed_input
+
 
+def mse_loss_and_bwd(
+    routed_input: torch.Tensor,
+    labels: torch.Tensor,
+):
     loss = F.mse_loss(routed_input, labels)
     loss.backward()
     torch.cuda.synchronize()
     return routed_input
 
 
 # Compile target funcs
-default_a2a_sync_compiled = torch.compile(default_a2a_fwd_bwd)
-mxfp8_a2a_sync_compiled = torch.compile(mxfp8_a2a_fwd_bwd)
+mse_loss_and_bwd_compiled = torch.compile(mse_loss_and_bwd)
 
 
 def run_experiment(
@@ -129,82 +135,105 @@ def run_experiment(
     # Set up device mesh
     mesh = init_device_mesh("cuda", (dist.get_world_size(),))
 
-    # Max output tokens per rank is worst case where one rank receives all tokens
-    input_tokens_per_rank = batch_size * seq_len
-
     def warmup(func_no_args):
         for _ in range(2):
             func_no_args()
 
+    input_tokens_per_rank = batch_size * seq_len
     num_experts_per_rank = 2
-    num_splits = dist.get_world_size() * num_experts_per_rank
-    input_splits = generate_split_sizes(
-        num_splits, input_tokens_per_rank, device=device
-    )
+    num_experts = dist.get_world_size() * num_experts_per_rank
+    input_tokens_per_expert = input_tokens_per_rank // num_experts
+    input_splits = torch.tensor(
+        input_tokens_per_expert, dtype=torch.int32, device=device
+    ).repeat(num_experts)
     input_splits_list, output_splits_list = get_split_lists(input_splits, mesh)
 
     # Generate labels
     labels_shape = (sum(output_splits_list), dim)
     labels = x.new_ones(*labels_shape)
 
-    # Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
-    warmup(
-        lambda: default_a2a_sync_compiled(
-            ref_x, labels, output_splits_list, input_splits_list, mesh
-        )
-    )
+    # Bench default a2a fwd (exclude d2h sync from preparing input splits_list and output_splits_list)
+    warmup(lambda: default_a2a_fwd(ref_x, output_splits_list, input_splits_list, mesh))
     start_sec = time.perf_counter()
-    default_a2a_sync_compiled(
-        ref_x, labels, output_splits_list, input_splits_list, mesh
+    bf16_routed_input = default_a2a_fwd(
+        ref_x, output_splits_list, input_splits_list, mesh
     )
     end_sec = time.perf_counter()
-    bf16_ms = (end_sec - start_sec) * 1e3
+    fwd_bf16_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            default_a2a_sync_compiled,
+            default_a2a_fwd,
             ref_x,
-            labels,
             output_splits_list,
             input_splits_list,
             mesh,
             distributed=True,
-            profile_name="all_to_all_single_autograd",
+            profile_name="default_a2a_fwd",
         )
 
-    # Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
-    warmup(
-        lambda: mxfp8_a2a_sync_compiled(
-            x, labels, output_splits_list, input_splits_list, mesh
+    # Bench default a2a backward
+    warmup(lambda: mse_loss_and_bwd_compiled(bf16_routed_input, labels))
+    start_sec = time.perf_counter()
+    mse_loss_and_bwd_compiled(bf16_routed_input, labels)
+    end_sec = time.perf_counter()
+    bwd_bf16_ms = (end_sec - start_sec) * 1e3
+    if args.profile:
+        profile_fn(
+            mse_loss_and_bwd_compiled,
+            bf16_routed_input,
+            labels,
+            distributed=True,
+            profile_name="bf16_a2a_bwd",
         )
-    )
+
+    # Bench mxfp8 sync a2a fwd (exclude d2h sync from preparing input splits_list and output_splits_list)
+    warmup(lambda: mxfp8_a2a_fwd(x, output_splits_list, input_splits_list, mesh))
     start_sec = time.perf_counter()
-    mxfp8_a2a_sync_compiled(x, labels, output_splits_list, input_splits_list, mesh)
+    mxfp8_routed_input = mxfp8_a2a_fwd(x, output_splits_list, input_splits_list, mesh)
     end_sec = time.perf_counter()
-    mxfp8_ms = (end_sec - start_sec) * 1e3
+    fwd_mxfp8_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            mxfp8_a2a_sync_compiled,
+            mxfp8_a2a_fwd,
             x,
-            labels,
             output_splits_list,
             input_splits_list,
             mesh,
             distributed=True,
-            profile_name="to_mxfp8_a2a_dequant",
+            profile_name="mxfp8_a2a_fwd",
+        )
+
+    # Bench mxfp8 sync a2a backward
+    warmup(lambda: mse_loss_and_bwd_compiled(mxfp8_routed_input, labels))
+    start_sec = time.perf_counter()
+    mse_loss_and_bwd_compiled(mxfp8_routed_input, labels)
+    end_sec = time.perf_counter()
+    bwd_mxfp8_ms = (end_sec - start_sec) * 1e3
+    if args.profile:
+        profile_fn(
+            mse_loss_and_bwd_compiled,
+            mxfp8_routed_input,
+            labels,
+            distributed=True,
+            profile_name="mxfp8_a2a_bwd",
         )
 
     return ExperimentResult(
-        bf16_ms=bf16_ms,
-        mxfp8_ms=mxfp8_ms,
+        fwd_bf16_ms=fwd_bf16_ms,
+        fwd_mxfp8_ms=fwd_mxfp8_ms,
+        bwd_bf16_ms=bwd_bf16_ms,
+        bwd_mxfp8_ms=bwd_mxfp8_ms,
     )
 
 
 def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape",
         "num_splits",
-        "bf16_ms",
-        "mxfp8_ms",
+        "fwd_bf16_ms",
+        "fwd_mxfp8_ms",
+        "bwd_bf16_ms",
+        "bwd_mxfp8_ms",
     ]
     rows = []
     num_splits = dist.get_world_size()
@@ -213,8 +242,10 @@ def print_results(experiments: List[Experiment]):
             [
                 str(experiment.config.input_shape),
                 num_splits,
-                experiment.result.bf16_ms,
-                experiment.result.mxfp8_ms,
+                experiment.result.fwd_bf16_ms,
+                experiment.result.fwd_mxfp8_ms,
+                experiment.result.bwd_bf16_ms,
+                experiment.result.bwd_mxfp8_ms,
             ]
         )
     print(tabulate(rows, headers=headers))
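To make the new uniform-split setup concrete, here is a small standalone sketch of the `input_splits` tensor built in `run_experiment`. It assumes a hypothetical 2-rank run (the real benchmark reads the rank count from `dist.get_world_size()`) and keeps the diff's `num_experts_per_rank = 2`; the point is that every expert now receives an identical token count, unlike the earlier randomized `generate_split_sizes` path. Note also that each benched function ends with `torch.cuda.synchronize()`, so the `time.perf_counter()` deltas capture the device-side cost of the collective or backward pass, not just the host-side launch.

```python
import torch

# Hypothetical standalone illustration (not part of the benchmark script):
# uniform input_splits for an assumed 2-rank run with the smallest new input shape.
batch_size, seq_len = 1, 8192
world_size, num_experts_per_rank = 2, 2            # world_size is an assumption here

input_tokens_per_rank = batch_size * seq_len        # 8192 tokens on this rank
num_experts = world_size * num_experts_per_rank     # 4 experts in total
input_tokens_per_expert = input_tokens_per_rank // num_experts  # 2048 tokens each

# Same construction as the diff: a scalar repeated once per expert,
# so every expert gets an equal share of the rank's tokens.
input_splits = torch.tensor(input_tokens_per_expert, dtype=torch.int32).repeat(num_experts)
print(input_splits)  # tensor([2048, 2048, 2048, 2048], dtype=torch.int32)
```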