Hi, I noticed that when I swap out torch.optim.AdamW for bnb.optim.AdamW (both in 32-bit precision) and fine-tune a model loaded in bf16, the model trained with bnb's AdamW ends up performing better overall than the one trained with torch's. Furthermore, if I fine-tune a model loaded in fp16, torch's AdamW produces NaN values after the very first optim.step(), while bnb's AdamW trains fine without any NaNs. Is there a reason for this discrepancy, especially in the fp16 case? Does bnb use a different dtype internally than torch does, or is something else going on? Thank you
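
For reference, here is roughly how I'm swapping the two optimizers (a minimal sketch; the model checkpoint, learning rate, and dummy batch below are placeholders, not my exact setup):

```python
import torch
import bitsandbytes as bnb
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; my real setup uses a different model and dataset.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.bfloat16,  # or torch.float16 for the fp16 case
).cuda()
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# The only change between the two runs is which optimizer class is used:
# optim = torch.optim.AdamW(model.parameters(), lr=2e-5)                 # NaNs in the fp16 case
optim = bnb.optim.AdamW(model.parameters(), lr=2e-5, optim_bits=32)      # trains without NaNs

batch = tokenizer("some training text", return_tensors="pt").to("cuda")
outputs = model(**batch, labels=batch["input_ids"])
outputs.loss.backward()
optim.step()  # with torch's AdamW on fp16 weights, parameters go NaN right after this call
optim.zero_grad()
```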