Bug description
When I apply FSDP+TP to the Llama4 debug model with plain eager bf16 training, the MoE routed-expert weights are DTensors whose local tensors have dtype bf16, but the DTensorSpec's tensor meta (`self.w1._spec.tensor_meta.dtype`) reports fp32. This mismatch appears to cause the meta registration error below.
Repro command
```shell
NGPU=4 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=100 --parallelism.tensor_parallel_degree=2
```
Meta registration error
```
  File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_meta_registrations.py", line 7527, in _meta_grouped_mm_common
    torch._check(
    ~~~~~~~~~~~~^
        mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/__init__.py", line 1702, in _check
    _check_with(RuntimeError, cond, message)
    ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/__init__.py", line 1684, in _check_with
    raise error_type(message_evaluated)
RuntimeError: Expected inputs of BF16 type but got mat_a.dtype=torch.bfloat16 and mat_b.dtype=torch.float32.
```
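The failing guard is a plain dtype check. A torch-free sketch of the same logic (dtypes modeled as strings; `check_grouped_mm_dtypes` is illustrative, not the real `_meta_grouped_mm_common`) shows why it trips here: meta propagation for `mat_b` (`self.w1`) seems to see the spec's cached `tensor_meta.dtype` (fp32) rather than the local tensor's bf16 dtype.

```python
def check_grouped_mm_dtypes(mat_a_dtype, mat_b_dtype):
    # Mirrors the guard's logic: both operands must be bf16,
    # otherwise a RuntimeError is raised with the dtypes seen.
    if not (mat_a_dtype == "torch.bfloat16" and mat_b_dtype == "torch.bfloat16"):
        raise RuntimeError(
            f"Expected inputs of BF16 type but got "
            f"mat_a.dtype={mat_a_dtype} and mat_b.dtype={mat_b_dtype}."
        )

# mat_a is the bf16 activation; mat_b is w1, whose spec-level dtype is fp32.
try:
    check_grouped_mm_dtypes("torch.bfloat16", "torch.float32")
except RuntimeError as e:
    print(e)
```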
PDB log
The following pdb session inspects `self.w1` in the MoE layer, confirming that the DTensor's local tensor dtype is bf16 while the DTensorSpec's tensor meta dtype is fp32; this mismatch appears to be what triggers the meta registration error.
```
 86  ->     torch.distributed.breakpoint()
 87         h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
 88         h = h * torch._grouped_mm(x, self.w3, offs=offsets)
 89         out = torch._grouped_mm(h, self.w2, offs=offsets)
 90
 91         return out
(Pdb) self.w1
DTensor(local_tensor=tensor([[[-0.0050, -0.0244,  0.0243,  ...,  0.0317,  0.0069, -0.0222],
         ...,
         [ 0.0201, -0.0154,  0.0211,  ...,  0.0220, -0.0047, -0.0265]]],
       device='cuda:0', dtype=torch.bfloat16),
       device_mesh=DeviceMesh('cuda', [0, 1], mesh_dim_names=('tp',)),
       placements=(Shard(dim=2),))
(Pdb) type(self.w1)
<class 'torch.distributed.tensor.DTensor'>
(Pdb) self.w1.dtype
torch.bfloat16
(Pdb) self.w1._local_tensor.dtype
torch.bfloat16
(Pdb) self.w1._spec
DTensorSpec(mesh=DeviceMesh('cuda', [0, 1], mesh_dim_names=('tp',)),
            placements=(Shard(dim=2),),
            tensor_meta=TensorMeta(shape=torch.Size([8, 256, 512]),
                                   stride=(131072, 512, 1),
                                   dtype=torch.float32))
(Pdb) n
> /home/danvm/torchtitan/torchtitan/experiments/llama4/model/moe.py(87)forward()
-> h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
(Pdb) n
RuntimeError: Expected inputs of BF16 type but got mat_a.dtype=torch.bfloat16 and mat_b.dtype=torch.float32.
```
(`[rank0]:` prefixes stripped and the full tensor printout trimmed for readability)
Versions
torchtitan: latest main branch