
Llama4 TP bug: DTensor local tensor dtype does not match DTensorSpec tensor meta dtype, causing meta registration error #1355


Description

@danielvegamyhre

Bug description

When I apply FSDP+TP to the Llama4 debug model with plain eager bf16 training, the MoE routed-experts weights are DTensors. The local tensor dtype is bf16, but the DTensor spec's tensor meta dtype (self.w1._spec.tensor_meta.dtype) is fp32. This mismatch seems to cause the meta registration error below.
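
For reference, here is a small debugging sketch that surfaces this kind of mismatch on any model (the helper name and print format are mine, not torchtitan code; it relies on the private _local_tensor and _spec attributes shown in the pdb log below):

import torch
from torch.distributed.tensor import DTensor

def report_meta_mismatches(model: torch.nn.Module) -> None:
    # Compare each DTensor parameter's actual local dtype against the
    # dtype cached in its DTensorSpec's TensorMeta.
    for name, param in model.named_parameters():
        if not isinstance(param, DTensor):
            continue
        local_dtype = param._local_tensor.dtype
        meta_dtype = param._spec.tensor_meta.dtype
        if local_dtype != meta_dtype:
            print(f"{name}: local dtype {local_dtype} != tensor_meta dtype {meta_dtype}")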

Repro command

NGPU=4 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=100 --parallelism.tensor_parallel_degree=2 

Meta registration error

   File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_meta_registrations.py", line 7527, in _meta_grouped_mm_common
      torch._check(
      ~~~~~~~~~~~~^
          mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
          lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/__init__.py", line 1702, in _check
      _check_with(RuntimeError, cond, message)
      ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/__init__.py", line 1684, in _check_with
      raise error_type(message_evaluated)
  RuntimeError: Expected inputs of BF16 type but got mat_a.dtype=torch.bfloat16 and mat_b.dtype=torch.float32.
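
Note that the check requires both grouped-GEMM operands to be bf16, and mat_b (self.w1) reports fp32, presumably because the dtype is read from the DTensorSpec's TensorMeta rather than from the local tensor. The same check can be triggered standalone (a sketch with arbitrary shapes; torch._grouped_mm is a private op and this assumes hardware that supports it):

import torch

# 16 tokens of dim 512 routed to 2 experts (8 tokens each);
# expert weights are (num_experts, dim, hidden_dim).
x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
w = torch.randn(2, 512, 256, device="cuda", dtype=torch.float32)  # deliberately fp32
offs = torch.tensor([8, 16], device="cuda", dtype=torch.int32)
torch._grouped_mm(x, w, offs=offs)
# RuntimeError: Expected inputs of BF16 type but got mat_a.dtype=torch.bfloat16
# and mat_b.dtype=torch.float32.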

PDB log

The following pdb session inspects self.w1 in the MoE layer, confirming that the DTensor's local tensor dtype is bf16 while the DTensorSpec's tensor meta dtype is fp32. This is consistent with the dtype mismatch reported by the meta registration check above.

[rank0]: 86  ->         torch.distributed.breakpoint()
[rank0]: 87             h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
[rank0]: 88             h = h * torch._grouped_mm(x, self.w3, offs=offsets)
[rank0]: 89             out = torch._grouped_mm(h, self.w2, offs=offsets)
[rank0]: 90  
[rank0]: 91             return out

self.w1
[rank0]:(Pdb) [rank0]:DTensor(local_tensor=tensor([[[-0.0050, -0.0244,  0.0243,  ...,  0.0317,  0.0069, -0.0222],
[rank0]:         [-0.0125,  0.0201, -0.0250,  ...,  0.0376,  0.0055, -0.0094],
[rank0]:         [-0.0045, -0.0300, -0.0115,  ..., -0.0493, -0.0259,  0.0117],
[rank0]:         ...,
[rank0]:         [-0.0112, -0.0012, -0.0051,  ..., -0.0104,  0.0087, -0.0325],
[rank0]:         [ 0.0209,  0.0086,  0.0109,  ..., -0.0430, -0.0036,  0.0359],
[rank0]:         [ 0.0110, -0.0234, -0.0066,  ..., -0.0238,  0.0148, -0.0304]],
[rank0]:
[rank0]:        [[-0.0168, -0.0038,  0.0179,  ...,  0.0076, -0.0461, -0.0182],
[rank0]:         [-0.0109, -0.0120,  0.0427,  ..., -0.0027, -0.0048, -0.0131],
[rank0]:         [-0.0156,  0.0018, -0.0083,  ...,  0.0189,  0.0309,  0.0066],
[rank0]:         ...,
[rank0]:         [-0.0021, -0.0231,  0.0132,  ..., -0.0095, -0.0050, -0.0168],
[rank0]:         [-0.0422,  0.0035,  0.0017,  ...,  0.0339,  0.0195,  0.0003],
[rank0]:         [ 0.0183,  0.0415,  0.0552,  ...,  0.0084,  0.0159,  0.0229]],
[rank0]:
[rank0]:        [[ 0.0036, -0.0337,  0.0398,  ...,  0.0027, -0.0219,  0.0043],
[rank0]:         [-0.0107, -0.0270,  0.0166,  ...,  0.0044, -0.0030,  0.0432],
[rank0]:         [ 0.0233,  0.0203,  0.0106,  ..., -0.0018, -0.0118, -0.0060],
[rank0]:         ...,
[rank0]:         [-0.0247, -0.0038, -0.0322,  ...,  0.0172,  0.0156, -0.0047],
[rank0]:         [-0.0225,  0.0289,  0.0299,  ...,  0.0025, -0.0221,  0.0134],
[rank0]:         [ 0.0093,  0.0255, -0.0039,  ...,  0.0045, -0.0226, -0.0170]],
[rank0]:
[rank0]:        ...,
[rank0]:
[rank0]:        [[-0.0120, -0.0054, -0.0262,  ...,  0.0086, -0.0012, -0.0043],
[rank0]:         [-0.0192, -0.0245,  0.0143,  ..., -0.0083,  0.0111,  0.0067],
[rank0]:         [ 0.0220, -0.0182,  0.0442,  ...,  0.0008,  0.0240,  0.0167],
[rank0]:         ...,
[rank0]:         [ 0.0165, -0.0152,  0.0175,  ...,  0.0027,  0.0120,  0.0100],
[rank0]:         [ 0.0050, -0.0135,  0.0160,  ...,  0.0311,  0.0106,  0.0571],
[rank0]:         [ 0.0199, -0.0073,  0.0215,  ...,  0.0131,  0.0327,  0.0097]],
[rank0]:
[rank0]:        [[ 0.0113,  0.0044, -0.0234,  ...,  0.0009,  0.0026, -0.0031],
[rank0]:         [ 0.0059, -0.0195, -0.0089,  ...,  0.0269, -0.0195,  0.0033],
[rank0]:         [ 0.0366,  0.0199,  0.0055,  ..., -0.0400, -0.0101, -0.0386],
[rank0]:         ...,
[rank0]:         [-0.0040, -0.0228, -0.0114,  ..., -0.0342, -0.0032, -0.0157],
[rank0]:         [ 0.0277, -0.0120, -0.0300,  ...,  0.0079,  0.0038,  0.0342],
[rank0]:         [-0.0057,  0.0148, -0.0048,  ..., -0.0192, -0.0291,  0.0187]],
[rank0]:
[rank0]:        [[-0.0291, -0.0271,  0.0058,  ...,  0.0035,  0.0095,  0.0045],
[rank0]:         [ 0.0508,  0.0175, -0.0264,  ...,  0.0070, -0.0014, -0.0064],
[rank0]:         [ 0.0208, -0.0004, -0.0386,  ..., -0.0505, -0.0194,  0.0293],
[rank0]:         ...,
[rank0]:         [-0.0161,  0.0170,  0.0060,  ...,  0.0023,  0.0280, -0.0095],
[rank0]:         [-0.0298, -0.0276, -0.0089,  ...,  0.0184, -0.0007, -0.0137],
[rank0]:         [ 0.0201, -0.0154,  0.0211,  ...,  0.0220, -0.0047, -0.0265]]],
[rank0]:       device='cuda:0', dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=2),))

type(self.w1)
[rank0]:(Pdb) [rank0]:<class 'torch.distributed.tensor.DTensor'>

self.w1.dtype
[rank0]:(Pdb) [rank0]:torch.bfloat16

self.w1._local_tensor.dtype
[rank0]:(Pdb) [rank0]:torch.bfloat16

self.w1._spec 
[rank0]:(Pdb) [rank0]:DTensorSpec(mesh=DeviceMesh('cuda', [0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=2),), tensor_meta=TensorMeta(shape=torch.Size([8, 256, 512]), stride=(131072, 512, 1), dtype=torch.float32))

n
[rank0]:(Pdb) [rank0]:> /home/danvm/torchtitan/torchtitan/experiments/llama4/model/moe.py(87)forward()
[rank0]:-> h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
[rank0]:/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/distributed_c10d.py:4805: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. 
[rank0]:  warnings.warn(  # warn only once

n
[rank0]:(Pdb) [rank0]:RuntimeError: Expected inputs of BF16 type but got mat_a.dtype=torch.bfloat16 and mat_b.dtype=torch.float32.
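
As a stopgap while debugging, rebuilding the affected parameters from their local shards makes the cached TensorMeta consistent again, since DTensor.from_local recomputes it from the actual local tensor (a workaround sketch only, not a proposed fix; rebuild_with_fresh_meta is a hypothetical helper):

from torch.distributed.tensor import DTensor

def rebuild_with_fresh_meta(param: DTensor) -> DTensor:
    # from_local rebuilds the DTensorSpec, so tensor_meta.dtype then
    # reflects the bf16 local shard instead of the stale fp32 entry.
    return DTensor.from_local(
        param._local_tensor,
        device_mesh=param.device_mesh,
        placements=param.placements,
        run_check=False,
    )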

Versions

torchtitan, latest main branch
