
Llama4 training does not automatically use bfloat16 when FSDP2 is enabled #1332

@danielvegamyhre


Bug description

For Llama3, when using FSDP2 the model weights are automatically converted to bfloat16 for mixed precision training. However, I noticed that with Llama4 and FSDP2 I have to manually cast the weights to bfloat16 before this line, otherwise they stay in fp32 and I can't use float8 GEMMs.
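
For context, a minimal sketch of how FSDP2 mixed precision is normally applied (the behavior I expected for Llama4). This assumes the `torch.distributed.fsdp.fully_shard` / `MixedPrecisionPolicy` import paths from recent PyTorch releases, and `model.layers` is an illustrative attribute, not torchtitan's exact parallelization code:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


def shard_with_bf16(model: torch.nn.Module) -> torch.nn.Module:
    # Single mesh dimension over all ranks; torchtitan builds richer meshes,
    # this is just the minimal FSDP2 setup.
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,  # unsharded params / compute in bf16
        reduce_dtype=torch.float32,  # keep gradient reduction in fp32
    )
    # Shard transformer blocks individually, then the root module,
    # mirroring the usual FSDP2 wrapping pattern.
    for block in model.layers:
        fully_shard(block, mesh=mesh, mp_policy=mp_policy)
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)
    return model
```

With this policy the registered parameters stay in fp32 but the unsharded compute copies are bf16; the problem I'm seeing is that for Llama4 the compute still runs in fp32 unless I cast the weights myself.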

Versions

torchtitan latest main branch
