@@ -40,30 +40,17 @@ def apply_compile(model: nn.Module):
     torch._dynamo.config.fail_on_recompile_limit_hit = True
     for layer_id, transformer_block in model.layers.named_children():
         if transformer_block.moe_enabled:
-            # compile the experts directly which can be wrapped by fsdp
             moe = transformer_block.moe
-
-            # transformer_block.moe.experts = torch.compile(transformer_block.moe.experts, fullgraph=True))
+            # Individually compile modules to keep fullgraph=True on FSDP wrapped experts
             moe.experts = torch.compile(moe.experts, fullgraph=True)
-            moe.router = torch.compile(moe.router, fullgraph=True)
             moe.shared_expert = torch.compile(moe.shared_expert, fullgraph=True)
+
+            # Separately compile the code around the FSDP wrapped experts
+            moe.router = torch.compile(moe.router, fullgraph=True)
         else:
             transformer_block = torch.compile(transformer_block, fullgraph=True)
         model.layers.register_module(layer_id, transformer_block)

-    # def _compile_child(parent:nn.Module, child_name: str, child: nn.Module):
-    #     parent.register_module(child_name, torch.compile(child, fullgraph=True))
-
-    # torch._dynamo.config.fail_on_recompile_limit_hit = True
-    # for layer_id, transformer_block in model.layers.named_children():
-    #     if transformer_block.moe_enabled:
-    #         # compile the experts directly which can be wrapped by fsdp
-    #         moe = transformer_block.moe
-    #         # for submod_id, submod in moe.named_children():
-    #         #     _compile_child(moe, submod_id, submod)
-    #     else:
-    #         _compile_child(transformer_block, layer_id, transformer_block)
-
     logger.info("Compiling each TransformerBlock with torch.compile")
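
For reference, a minimal sketch of how apply_compile reads after this change. The imports and the module-level logger are assumptions about the surrounding file (they are not part of this hunk); the logging setup below is only a stand-in.

import logging

import torch
import torch.nn as nn

# Assumption: the real module defines its own logger; this is a placeholder.
logger = logging.getLogger(__name__)


def apply_compile(model: nn.Module):
    # Error out instead of silently falling back to eager when the recompile limit is hit.
    torch._dynamo.config.fail_on_recompile_limit_hit = True
    for layer_id, transformer_block in model.layers.named_children():
        if transformer_block.moe_enabled:
            moe = transformer_block.moe
            # Individually compile modules to keep fullgraph=True on FSDP wrapped experts.
            moe.experts = torch.compile(moe.experts, fullgraph=True)
            moe.shared_expert = torch.compile(moe.shared_expert, fullgraph=True)

            # Separately compile the code around the FSDP wrapped experts.
            moe.router = torch.compile(moe.router, fullgraph=True)
        else:
            transformer_block = torch.compile(transformer_block, fullgraph=True)
        model.layers.register_module(layer_id, transformer_block)

    logger.info("Compiling each TransformerBlock with torch.compile")

Compiling experts, shared_expert, and router as separate torch.compile regions keeps each region fullgraph while still letting FSDP wrap the experts module at its own boundary, which is why the router is now compiled after (rather than alongside) the experts.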