[DSV3] Add PP support for DSV3 #1345
base: main
Conversation
```python
        Returns:
            torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
        """
        h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
```
Does this "if xx is not None else xxx" pattern also apply to the forward function when there's no PP? Should we add it to forward directly?
Good question: it only applies to PP, since in PP we split the model and then set some modules to None. For Llama, we directly modified the model in titan to make this work (torchtitan/torchtitan/models/llama3/model/model.py, lines 419 to 432 in aefe15a):
```python
        if self.model_args.use_flex_attn:
            init_attention_mask(
                input_batch if input_batch is not None else tokens, eos_id=self.eos_id
            )

        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)

        h = self.norm(h) if self.norm else h
        output = self.output(h) if self.output else h
        return output
```
compared to the regular Llama forward. But I thought it would be cleaner to leave the model forward as-is and only modify it when using PP. Maybe that will come back to bite me lol
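For readers following along, here is a rough sketch of that idea, i.e. swapping in a PP-aware forward only when pipeline parallelism is enabled so the original `model.py` stays untouched. This is not the code from this PR; the helper names (`pp_stage_forward`, `maybe_patch_for_pp`) and the simplified layer call are assumptions for illustration only:

```python
import types

import torch
import torch.nn as nn


def pp_stage_forward(self, tokens: torch.Tensor) -> torch.Tensor:
    # Passthrough for modules that were removed from this PP stage
    # (the split sets them to None on stages that don't own them).
    h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
    for layer in self.layers.values():
        h = layer(h)  # simplified: the real layers also take freqs_cis etc.
    h = self.norm(h) if self.norm is not None else h
    return self.output(h) if self.output is not None else h


def maybe_patch_for_pp(stage_model: nn.Module, pp_enabled: bool) -> nn.Module:
    # Only replace forward when PP is on; otherwise the model's own forward is used.
    if pp_enabled:
        stage_model.forward = types.MethodType(pp_stage_forward, stage_model)
    return stage_model
```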
```python
    module_names_per_stage = []
    current_layer = 0

    for stage_idx in range(num_stages):
```
It seems the LoC is quite a bit longer than https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/pipeline.py#L106-L121, since you are generating actual module names per stage rather than split points. (I'm assuming they are functionally the same, but I didn't check.) But the complexity in `pipeline_deepseekv3_module_split` later on doesn't seem to be reduced. I'd like to learn more about the reasoning behind the change of UI.
Module names are more flexible since they can be applied to all models. `generate_module_names_per_stage` is specific to DeepSeek-V3; however, the new helper method I refactored, `pipeline_module_split`, is model agnostic. So I am thinking of upstreaming it as a utility in PyTorch core. With that, we can reduce the LoC needed in pipeline.py.
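To illustrate the shape of that UI (a sketch only, not this PR's implementation; the helper below and the toy stage assignment are made up to show the idea): each stage lists the fully qualified module names it owns, and a model-agnostic helper keeps those modules on that stage while dropping or None-ing out everything else.

```python
import copy

import torch.nn as nn


def split_by_module_names(
    model: nn.Module, module_names_per_stage: list[list[str]]
) -> list[nn.Module]:
    """Very simplified, model-agnostic split: each stage keeps only the modules
    named for it; everything else is deleted (layers) or set to None."""
    all_names = [n for names in module_names_per_stage for n in names]
    stages = []
    for names in module_names_per_stage:
        stage = copy.deepcopy(model)  # a real implementation would avoid full copies
        for name in all_names:
            if name in names:
                continue
            parent = stage
            *path, leaf = name.split(".")
            for part in path:
                parent = getattr(parent, part)
            if isinstance(parent, nn.ModuleDict):
                del parent[leaf]  # drop layers this stage doesn't own
            else:
                setattr(parent, leaf, None)  # forward treats None as a passthrough
        stages.append(stage)
    return stages


# Toy example: 2 stages over embeddings, 4 transformer layers, norm and output.
module_names_per_stage = [
    ["tok_embeddings", "layers.0", "layers.1"],
    ["layers.2", "layers.3", "norm", "output"],
]
```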
Makes sense, but can the user still manually specify the split via https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L313, or is it not encouraged any more?
Also, maybe we should try to upstream more functions in https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/pipeline.py
> Makes sense, but can the user still manually specify the split via https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L313, or is it not encouraged any more?

I think we should move away from this.

> Also, maybe we should try to upstream more functions in https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/pipeline.py

Yeah, I agree. I think this can be cleaned up; let me think of a way.
(force-pushed from 50e9230 to d9dbb5b)
(force-pushed from 4d2ad19 to 3f0bfc8)
```python
    if num_stages == 1:
        # Single stage gets everything
        layer_names = [f"layers.{i}" for i in range(num_layers)]
```
Did you take into consideration that the layers in DSV3 are not evenly distributed: several dense layers, followed by MoE layers?
https://github.com/pytorch/torchtitan/pull/1373/files#diff-ed005d894ae945a545c92c33136fba3bde35e70f1b7052f78242a1f69e862ab8R273
we don't take this into consideration yet 🫤.
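One possible way to account for this later (purely illustrative; the greedy heuristic, the cost weights, and the `n_dense_layers` argument below are assumptions, not part of this PR) is to split by a per-layer cost estimate instead of by raw layer count:

```python
def split_layers_by_cost(
    num_layers: int, n_dense_layers: int, num_stages: int, moe_weight: float = 2.0
) -> list[list[str]]:
    """Greedy split that weights MoE layers more heavily than dense layers."""
    assert 1 <= num_stages <= num_layers
    costs = [1.0 if i < n_dense_layers else moe_weight for i in range(num_layers)]
    target = sum(costs) / num_stages  # per-stage cost budget
    stages: list[list[str]] = []
    current: list[str] = []
    acc = 0.0
    for i, cost in enumerate(costs):
        current.append(f"layers.{i}")
        acc += cost
        remaining_layers = num_layers - i - 1
        remaining_stages = num_stages - len(stages) - 1
        # Close the stage once it hits its budget (or when later stages would
        # otherwise run out of layers), keeping every stage non-empty.
        if remaining_stages > 0 and remaining_layers >= remaining_stages and (
            acc >= target or remaining_layers == remaining_stages
        ):
            stages.append(current)
            current, acc = [], 0.0
    stages.append(current)
    return stages


# e.g. a 16B-DSV3-like stack (27 layers, the first one dense) over 4 stages:
# split_layers_by_cost(27, 1, 4) -> stages of 8 / 7 / 7 / 5 layers, roughly cost-balanced
```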
(force-pushed from df8288b to d5988d0)
## Supported Features
- FSDP, HSDP
- Activation checkpointing
- Tensor Parallel (TP) from @tianyu-l
- Expert Parallel (EP)

## To be added
- Modeling
  - Merge DeepSeek-V3 and Llama4 MoE common components
- Parallelism
  - Context Parallel support for DeepSeek-V3
  - PP support for DeepSeek-V3 (@H-Huang is working on #1345)
- torch.compile
- Quantization
- Testing
  - performance and loss converging tests
  - CI integration: @wwwjn will work on this after PyTorch side diffs (mentioned in #1324) get into PyTorch nightly

## Test
1. With FSDP=8, EP=2 (['dp_shard_mod_ep', 'dp_shard_in_ep'], [4, 2])
```
[rank0]:[titan] 2025-07-08 15:15:43,068 - root - INFO - step:  1  loss: 12.2616  grad_norm: 0.3918  memory: 65.53GiB(68.98%)  tps: 1,482  tflops: 0.61  mfu: 0.06%
[rank0]:[titan] 2025-07-08 15:15:43,068 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-08 15:15:43,543 - root - INFO - step:  2  loss: 12.0093  grad_norm: 0.5745  memory: 65.54GiB(68.99%)  tps: 69,111  tflops: 28.68  mfu: 2.90%
[rank0]:[titan] 2025-07-08 15:15:43,981 - root - INFO - step:  3  loss: 11.1697  grad_norm: 1.2095  memory: 65.54GiB(68.99%)  tps: 74,931  tflops: 31.09  mfu: 3.14%
[rank0]:[titan] 2025-07-08 15:15:44,015 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:15:44,409 - root - INFO - step:  4  loss: 10.7248  grad_norm: 1.2230  memory: 65.54GiB(68.99%)  tps: 76,668  tflops: 31.81  mfu: 3.22%
[rank0]:[titan] 2025-07-08 15:15:44,838 - root - INFO - step:  5  loss: 10.5484  grad_norm: 1.1633  memory: 65.54GiB(68.99%)  tps: 76,416  tflops: 31.71  mfu: 3.21%
[rank0]:[titan] 2025-07-08 15:15:45,339 - root - INFO - step:  6  loss: 10.3509  grad_norm: 1.1611  memory: 65.54GiB(68.99%)  tps: 65,490  tflops: 27.18  mfu: 2.75%
[rank0]:[titan] 2025-07-08 15:15:45,401 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:15:46,121 - root - INFO - step:  7  loss: 10.2153  grad_norm: 1.1410  memory: 65.54GiB(68.99%)  tps: 41,934  tflops: 17.40  mfu: 1.76%
[rank0]:[titan] 2025-07-08 15:15:46,733 - root - INFO - step:  8  loss: 10.0801  grad_norm: 1.1487  memory: 65.54GiB(68.99%)  tps: 53,599  tflops: 22.24  mfu: 2.25%
[rank0]:[titan] 2025-07-08 15:15:47,137 - root - INFO - step:  9  loss:  9.9781  grad_norm: 1.1257  memory: 65.54GiB(68.99%)  tps: 81,051  tflops: 33.63  mfu: 3.40%
[rank0]:[titan] 2025-07-08 15:15:47,554 - root - INFO - step: 10  loss:  9.9183  grad_norm: 1.1012  memory: 65.54GiB(68.99%)  tps: 78,712  tflops: 32.66  mfu: 3.30%
```
2. With FSDP=4, TP=2
```
[rank0]:[titan] 2025-07-08 15:16:25,927 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-08 15:16:34,993 - root - INFO - step:  1  loss: 12.2768  grad_norm: 0.3836  memory: 41.14GiB(43.31%)  tps: 1,750  tflops: 0.73  mfu: 0.07%
[rank0]:[titan] 2025-07-08 15:16:34,993 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-08 15:16:35,310 - root - INFO - step:  2  loss: 12.0284  grad_norm: 0.5423  memory: 41.29GiB(43.46%)  tps: 51,796  tflops: 21.49  mfu: 2.17%
[rank0]:[titan] 2025-07-08 15:16:35,605 - root - INFO - step:  3  loss: 11.2398  grad_norm: 1.2037  memory: 41.29GiB(43.46%)  tps: 55,575  tflops: 23.06  mfu: 2.33%
[rank0]:[titan] 2025-07-08 15:16:35,912 - root - INFO - step:  4  loss: 10.8246  grad_norm: 1.2360  memory: 41.29GiB(43.46%)  tps: 53,553  tflops: 22.22  mfu: 2.25%
[rank0]:[titan] 2025-07-08 15:16:36,206 - root - INFO - step:  5  loss: 10.6295  grad_norm: 1.1951  memory: 41.29GiB(43.46%)  tps: 55,732  tflops: 23.13  mfu: 2.34%
[rank0]:[titan] 2025-07-08 15:16:36,502 - root - INFO - step:  6  loss: 10.5240  grad_norm: 1.1296  memory: 41.29GiB(43.46%)  tps: 55,564  tflops: 23.06  mfu: 2.33%
[rank0]:[titan] 2025-07-08 15:16:36,793 - root - INFO - step:  7  loss: 10.3426  grad_norm: 1.1630  memory: 41.29GiB(43.46%)  tps: 56,295  tflops: 23.36  mfu: 2.36%
[rank0]:[titan] 2025-07-08 15:16:36,824 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:16:37,081 - root - INFO - step:  8  loss: 10.2127  grad_norm: 1.1499  memory: 41.29GiB(43.46%)  tps: 57,052  tflops: 23.67  mfu: 2.39%
[rank0]:[titan] 2025-07-08 15:16:37,374 - root - INFO - step:  9  loss: 10.0537  grad_norm: 1.1814  memory: 41.29GiB(43.46%)  tps: 56,019  tflops: 23.25  mfu: 2.35%
[rank0]:[titan] 2025-07-08 15:16:37,664 - root - INFO - step: 10  loss: 10.0311  grad_norm: 1.1082  memory: 41.29GiB(43.46%)  tps: 56,504  tflops: 23.45  mfu: 2.37%
```
---------
Co-authored-by: Tianyu Liu <lty@fb.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
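As a side note on the mesh in test 1, `(['dp_shard_mod_ep', 'dp_shard_in_ep'], [4, 2])` corresponds to a 2-D device mesh over the 8 ranks. Below is a minimal sketch of building such a mesh, assuming `torch.distributed` is already initialized with 8 ranks (e.g. via `torchrun --nproc-per-node 8`); the interpretation of the dimension names is taken from the log above, not from this PR:

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 FSDP ranks factored into 4 x 2 so that EP=2 can reuse the inner
# "dp_shard_in_ep" dimension, as the mesh names in the log suggest.
mesh = init_device_mesh(
    "cuda",
    mesh_shape=(4, 2),
    mesh_dim_names=("dp_shard_mod_ep", "dp_shard_in_ep"),
)
print(mesh)
```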
@tianyu-l no worries haha, I will rebase
I feel it's fine to land this for now; the complete tests/benchmarks can come later. Could you show a screenshot of running FSDP+TP+PP together on 8 GPUs?
Besides, I had a question on DCP save / load / reshard. It'd be great if you could demonstrate an example of resharding during loading.
```python
        # Handle simple module attributes (e.g., "linear", "norm")
        elif module_name not in modules_to_keep:
            # Replace with identity module instead of None
            setattr(model, module_name, nn.Identity())
```
What's the advantage of doing this instead of setting it to None?
I'm worried about DCP loading / resharding, as multiple PP ranks will have the same FQNs.
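For context on that concern, a quick illustrative check (not code from this PR): `nn.Identity` carries no parameters or buffers, so it adds no weight FQNs to a stage's state dict, although the module key itself still exists on every stage.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
# Simulate dropping the linear layer from a PP stage by swapping in Identity.
model[0] = nn.Identity()

print(list(model.state_dict().keys()))    # [] -> no parameter FQNs left
print(list(dict(model.named_modules())))  # ['', '0', '1'] -> module keys remain
x = torch.randn(2, 4)
print(model(x).shape)                     # torch.Size([2, 4]) -> acts as a passthrough
```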
```toml
enable_async_tensor_parallel = false
expert_parallel_degree = 1
pipeline_parallel_degree = 4
```
maybe still set this to 1 in debug_model
```
@@ -51,6 +51,8 @@ fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
expert_parallel_degree = 1
pipeline_parallel_degree = 1
```
we can add this in the new 671B toml too
This takes a different approach from what we do in Llama 3/4:
- We override the model's `forward()` method only when using PP (therefore no `model.py` changes required).

Can run the 16B DSV3 on 8 GPUs with PP.

TODO:
- Upstream `pipeline_module_split` to `torch.distributed.pipelining`?