bnb.optim.AdamW performance differs from torch.optim.AdamW despite being called with the same hyperparameters #1756

@inkitori

Description

Hi, I noticed that when I swap out torch.optim.AdamW for bnb.optim.AdamW (both in 32-bit precision) and fine-tune a model loaded in bf16, the overall performance of the model trained with bnb's AdamW is higher than with torch's. Furthermore, if I fine-tune a model loaded in fp16, torch's AdamW produces NaN values after the very first optim.step(), while bnb's AdamW trains fine without any NaNs. Is there a reason for this performance discrepancy, especially in the latter case? Does bnb internally use a different dtype than torch, or something else? Thank you
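
Below is a minimal sketch (not my actual training script) of the kind of side-by-side comparison I mean: one optimizer step on a toy fp16 model with identical hyperparameters for both optimizers, followed by a NaN check. The toy model, input data, and hyperparameter values are placeholders; it assumes a CUDA device and bitsandbytes installed.

```python
import copy

import torch
import bitsandbytes as bnb

torch.manual_seed(0)

# Tiny fp16 model standing in for the fine-tuned model from the report.
base = torch.nn.Linear(64, 64).to(device="cuda", dtype=torch.float16)
x = torch.randn(8, 64, device="cuda", dtype=torch.float16)

for opt_cls in (torch.optim.AdamW, bnb.optim.AdamW):
    model = copy.deepcopy(base)
    # Same hyperparameters for both optimizers (placeholder values).
    opt = opt_cls(model.parameters(), lr=1e-3, eps=1e-8, weight_decay=0.01)
    loss = model(x).float().pow(2).mean()
    loss.backward()
    opt.step()
    nan_after_step = any(torch.isnan(p).any().item() for p in model.parameters())
    print(f"{opt_cls.__module__}.AdamW -> NaN after step: {nan_after_step}")
```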

Metadata

Assignees: No one assigned
Labels: Optimizers (Issues or feature requests relating to optimizers)
