
Assertion Error during Optimizer State Checkpointing in FSDP2 (bf16 Model + 8-bit Optimizer) #1732

@EvelynKimm

Description

System Info

GPU type: NVIDIA RTX A6000

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.11.9

numpy==2.2.6
torch==2.7.0
transformers==4.55.3
accelerate==1.10.0
datasets==3.4.1
trl==0.16.0
peft==0.15.2
bitsandbytes==0.47.0
safetensors==0.5.3
tokenizers==0.21.1
triton==3.3.0

Reproduction

I am attempting to save the optimizer state to enable resume_from_checkpoint. However, when training without save_only_model (so that the optimizer/scheduler are also checkpointed), the run consistently fails with an assertion error.

Training Code

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig

# args, train_dataset, valid_dataset, data_collator, PerplexityCallback,
# MySFTTrainer, logger, and cleanup() are defined elsewhere in the script
# and omitted here for brevity.

def main():
    tokenizer = AutoTokenizer.from_pretrained(args.model_uri, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_uri,
        low_cpu_mem_usage=True,
        attn_implementation="kernels-community/vllm-flash-attn3",
        torch_dtype=torch.bfloat16,
    )

    model.config.use_cache = False

    training_args = SFTConfig(
        output_dir=args.checkpoint_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.valid_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        eval_accumulation_steps=1,
        eval_strategy="steps",
        eval_steps=args.eval_interval,
        logging_strategy="steps",
        logging_steps=args.logging_interval,
        save_strategy="steps",
        save_steps=args.model_save_interval,
        save_total_limit=2,
        max_steps=args.max_steps,  
        report_to="wandb", 
        dataloader_persistent_workers=True,
        dataloader_num_workers=2,
        seed=args.seed,
        average_tokens_across_devices=False,  
        optim="paged_adamw_8bit",
        run_name=args.run_name,
        learning_rate=args.max_lr,
        lr_scheduler_type="warmup_stable_decay",
        lr_scheduler_kwargs={"num_decay_steps": 500},
        warmup_steps=200,
        weight_decay=args.weight_decay,
        prediction_loss_only=True,
        # save_only_model=True,
        label_names=["labels"],
        disable_tqdm=False,
        remove_unused_columns=False,
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "fsdp_version": 2,  
            "reshard_after_forward": True,  
            "auto_wrap_policy": "transformer_based_wrap",
            "activation_checkpointing": True,
            "mixed_precision_policy": {
                "param_dtype": "bfloat16",
                "reduce_dtype": "bfloat16",
            },
            "cpu_offload": False,
            "state_dict_type": "full_state_dict",
            "fsdp_cpu_ram_efficient_loading": True,
        },
    )


    trainer_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": train_dataset,
        "eval_dataset": valid_dataset,
        "tokenizer": tokenizer,
        "data_collator": data_collator,
        "callbacks": [PerplexityCallback()],
    }

    trainer = MySFTTrainer(**trainer_kwargs)

    if not trainer.is_world_process_zero():
        logger.disable("")

    logger.info(f"Size of train dataset: {len(train_dataset)}")
    logger.info(f"Size of valid dataset: {len(valid_dataset)}")

    torch.cuda.empty_cache()

    train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.save_model()
    trainer.save_state()
    trainer.log_metrics("train", train_result.metrics)

    logger.info("*** Evaluate ***")
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)

    logger.info("Done!")

    cleanup()

if __name__ == "__main__":
    main()

Command line

torchrun --nproc_per_node=4 \
  --rdzv-backend=c10d \
  --rdzv-endpoint=localhost:0 \
  --module scripts.sfttrainer \
  --model-uri Qwen/Qwen3-8B \
  --max-length 8192 \
  --max-lr 5e-6 \
  --log-dir logs \
  --run-name $run_name \
  --train-batch-size 1 \
  --valid-batch-size 2 \
  --gradient-accumulation-steps 1 \
  --eval-interval 100 \
  --logging-interval 10 \
  --model-save-interval 3000 \
  --max-steps 5

I got the error below:

[rank0]: Traceback (most recent call last):
[rank0]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:   File "<frozen runpy>", line 88, in _run_code
[rank0]:   File "/home/evekimmy/main_project/joseon-analysis-v2/scripts/sfttrainer_resize_final.py", line 414, in <module>
[rank0]:     main()
[rank0]:   File "/home/evekimmy/main_project/joseon-analysis-v2/scripts/sfttrainer_resize_final.py", line 394, in main
[rank0]:     train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/transformers/trainer.py", line 2238, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/transformers/trainer.py", line 3144, in _maybe_log_save_evaluate 
[rank0]:     self._save_checkpoint(model, trial)
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/transformers/trainer.py", line 3252, in _save_checkpoint
[rank0]:     self._save_optimizer_and_scheduler(output_dir)
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/transformers/trainer.py", line 3374, in _save_optimizer_and_scheduler
[rank0]:     save_fsdp_optimizer(
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 256, in save_fsdp_optimizer
[rank0]:     optim_state = FSDP.optim_state_dict(model, optimizer)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1882, in optim_
state_dict
[rank0]:     return FullyShardedDataParallel._optim_state_dict_impl(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1293, in _optim
_state_dict_impl
[rank0]:     return _optim_state_dict(
[rank0]:            ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1958, in _optim_state_dict
[rank0]:     fsdp_osd_state = convert_fn(
[rank0]:                      ^^^^^^^^^^^
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1781, in _convert_state_with_o
rig_params
[rank0]:     _gather_all_orig_param_state(
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1675, in _gather_all_orig_para
m_state
[rank0]:     output_states = _allgather_orig_param_states(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1505, in _allgather_orig_param
_states
[rank0]:     dtype, state_buffers = _convert_all_state_info(
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1364, in _convert_all_state_in
fo
[rank0]:     assert dtype == info.dtype
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError
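
As far as I can tell, the assertion fires because bitsandbytes' 8-bit Adam keeps mixed-dtype state for each parameter (uint8 quantized moments plus fp32 scaling/quantization buffers), while the full optimizer-state gather in _convert_all_state_info seems to expect a single dtype per state key. This is only my interpretation, not a confirmed diagnosis. A rough standalone sketch (not part of my training script; assumes one CUDA GPU and bitsandbytes installed) that prints the per-parameter state dtypes:

import torch
import bitsandbytes as bnb

# Tiny bf16 stand-in module, mirroring the dtype used for the real model.
model = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16, device="cuda")
optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-4)

# One forward/backward/step so the optimizer materializes its state tensors.
loss = model(torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")).sum()
loss.backward()
optimizer.step()

# The quantized moments are stored as uint8 and their scaling buffers as
# float32, so the per-parameter optimizer state does not have a uniform dtype.
for param, state in optimizer.state.items():
    for key, value in state.items():
        if torch.is_tensor(value):
            print(key, value.dtype, tuple(value.shape))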

How can I checkpoint the 8-bit optimizer state without encountering the assertion error?
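
One workaround I am considering, but have not tested, is to skip the full-state-dict consolidation entirely and let each rank save its own shard of the optimizer state by changing state_dict_type in the fsdp_config (everything else unchanged):

# Untested variant of the fsdp_config above: a sharded optimizer state dict
# avoids the cross-rank full consolidation path where the assertion fires.
fsdp_config = {
    "fsdp_version": 2,
    "reshard_after_forward": True,
    "auto_wrap_policy": "transformer_based_wrap",
    "activation_checkpointing": True,
    "mixed_precision_policy": {
        "param_dtype": "bfloat16",
        "reduce_dtype": "bfloat16",
    },
    "cpu_offload": False,
    "state_dict_type": "sharded_state_dict",  # instead of "full_state_dict"
    "fsdp_cpu_ram_efficient_loading": True,
}

If that is not supported, the only fallbacks I see are keeping save_only_model=True (and losing the optimizer state) or temporarily switching to a non-quantized optimizer such as optim="adamw_torch" to confirm that the 8-bit state is what triggers the assertion.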

Expected behavior

Checkpointing should save the 8-bit optimizer state so that training can be resumed via resume_from_checkpoint without errors.
