For study purposes, we refined the original DeepSeek-V3 HuggingFace modeling code (https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py) to make it capable of training with FSDP+EP
at scale, and implemented the parts missing from the original modeling:
- Multi-Token Prediction;
- Auxiliary-Loss-Free Load Balancing (sketched below);
- Grouped GEMM for Experts;
- Expert Parallelism;
based on the details in the DeepSeek-V3 Technical Report (https://arxiv.org/abs/2412.19437) and other open-source projects.
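For reference, here is a minimal sketch of the auxiliary-loss-free load balancing strategy described in the DeepSeek-V3 report: a per-expert bias is added to the routing scores only for top-k expert selection, and after each training step the bias of overloaded experts is decreased by a small step gamma (the bias update speed) while that of underloaded experts is increased. The names `expert_bias`, `gamma`, and the shapes below are illustrative assumptions, not this repo's API:

```python
import torch

num_experts, top_k, gamma = 8, 2, 1e-3          # illustrative values
expert_bias = torch.zeros(num_experts)           # not trained by gradient descent

def route(scores: torch.Tensor):
    """scores: [num_tokens, num_experts] expert affinity scores."""
    # The bias only influences which experts get selected ...
    _, topk_idx = torch.topk(scores + expert_bias, k=top_k, dim=-1)
    # ... the gating weights are still taken from the unbiased scores,
    # normalized over the selected experts.
    topk_weight = torch.gather(scores, dim=-1, index=topk_idx)
    topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
    return topk_idx, topk_weight

def update_bias(topk_idx: torch.Tensor):
    """After each step: lower the bias of overloaded experts, raise underloaded ones."""
    load = torch.bincount(topk_idx.flatten(), minlength=num_experts).float()
    expert_bias.add_(gamma * torch.sign(load.mean() - load))
```

This keeps the router balanced without adding an auxiliary loss term to the training objective.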
To convert the original HuggingFace checkpoint into a DCP (torch.distributed.checkpoint) checkpoint, run:

```bash
python convert_ckpt_hf2dcp.py --input input_hf_ckpt_path --output output_dcp_ckpt_path
```
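As an illustration only (not the actual convert_ckpt_hf2dcp.py, which may differ), a minimal conversion could look like the sketch below, assuming the refined modeling keeps the original HuggingFace parameter names and the weights fit in host memory:

```python
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM

input_hf_ckpt_path = "path/to/hf_ckpt"    # placeholder input directory
output_dcp_ckpt_path = "path/to/dcp_ckpt" # placeholder output directory

# Load the HuggingFace checkpoint on CPU and take its plain state dict.
model = AutoModelForCausalLM.from_pretrained(
    input_hf_ckpt_path, torch_dtype=torch.bfloat16, trust_remote_code=True
)
state_dict = model.state_dict()

# Re-save the state dict in torch.distributed.checkpoint (DCP) format;
# dcp.save also works in a single process without an initialized process group.
dcp.save(state_dict=state_dict, checkpoint_id=output_dcp_ckpt_path)
```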
Because two assertions in the FSDP source code are too strict for this wrapping scheme, we need to comment them out:
```python
#if _get_module_fsdp_state(module):
#    # TODO: We may relax this by taking the FSDP instance's wrapped
#    # module to provide more flexibility to the user.
#    raise ValueError("`ignored_modules` should not include FSDP modules")

## TODO: We may relax this no-nested-wrapping constraint to support manual
## wrapping followed by auto wrapping.
#_check_nested_wrapping(root_module)
```
We have discussed the details with the FSDP developers, and accuracy is guaranteed.
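These assertions live in FSDP's internal helper modules (in the PyTorch 2.x releases, `_init_utils.py` and `_wrap_utils.py`; the exact location may shift between versions). The snippet below is an illustrative helper, not part of the repo, that prints the files to patch in your installation:

```python
# Print the source files that contain the two assertions.
import torch.distributed.fsdp._init_utils as init_utils  # raises when `ignored_modules` contains FSDP modules
import torch.distributed.fsdp._wrap_utils as wrap_utils  # calls _check_nested_wrapping(root_module)

print(init_utils.__file__)
print(wrap_utils.__file__)
```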
Assume `expert_data_process_group` is the process group across which you want to shard the Expert modules, and `data_process_group` is the process group across which you want to shard the non-Expert modules.
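How these groups are built depends on your parallelism plan. As one hedged example (not necessarily the layout used in this repo), a 2-D device mesh can provide an expert-data-parallel ("edp") dimension for sharding the experts, while the non-Expert modules are sharded over the whole world; `ep_size` and the dimension names below are illustrative assumptions:

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

ep_size = 8  # illustrative expert-parallel degree; pick your own
world_size = dist.get_world_size()

# 2-D mesh: ranks in the same "edp" group hold the same experts and shard them
# with FSDP; the "ep" dimension is where expert parallelism happens.
mesh = init_device_mesh(
    "cuda",
    (world_size // ep_size, ep_size),
    mesh_dim_names=("edp", "ep"),
)
expert_data_process_group = mesh["edp"].get_group()  # shards Expert modules
data_process_group = dist.group.WORLD                # shards non-Expert modules
```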
With these two groups, wrap the Expert modules and the rest of the model separately:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, BackwardPrefetch

# Wrap each Expert module in its own FSDP instance sharded over
# expert_data_process_group, and remember it so the outer FSDP wrap can ignore it.
ignored_mod = []
for layer_id, layer in enumerate(model.layers):
    if layer_id >= config.first_k_dense_replace:
        layer.mlp.experts = FSDP(
            layer.mlp.experts,
            process_group=expert_data_process_group,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            forward_prefetch=True,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            limit_all_gathers=True,
            use_orig_params=True,
        )
        ignored_mod.append(layer.mlp.experts)

# Wrap the remaining modules with FSDP sharded over data_process_group.
# wrap_cls is the set of module classes to auto-wrap (e.g., the decoder layer class).
model = FSDP(
    module=model,
    process_group=data_process_group,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=ModuleWrapPolicy(wrap_cls),
    forward_prefetch=True,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    limit_all_gathers=True,
    use_orig_params=True,
    ignored_modules=ignored_mod,
)
```
After the FSDP wrapping is finished, you can use the code snippet below to load the converted DCP checkpoint.
```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict

# Materialize the (sharded) model state dict, load the DCP checkpoint into it,
# then copy the loaded tensors back into the model.
state_dict = get_model_state_dict(model=model)
state_dict = {key: state_dict[key].clone().detach() for key in state_dict}
dcp.load(state_dict=state_dict, checkpoint_id=output_dcp_ckpt_path)
set_model_state_dict(model=model, model_state_dict=state_dict)
del state_dict
torch.cuda.empty_cache()
```
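For completeness, here is a hedged sketch of writing the wrapped model back out in the same DCP format (for example, to resume training later); `save_dcp_ckpt_path` is a placeholder output directory, not a name used by this repo:

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict

save_dcp_ckpt_path = "path/to/save_dcp_ckpt"  # placeholder output directory

# get_model_state_dict returns the sharded state dict for FSDP models by default,
# and dcp.save writes each rank's shards into the checkpoint directory.
state_dict = get_model_state_dict(model=model)
dcp.save(state_dict=state_dict, checkpoint_id=save_dcp_ckpt_path)
```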
For more details, please refer to pytorch/pytorch#149396.