Description
System Info
GPU type: NVIDIA RTX A6000
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.11.9
numpy==2.2.6
torch==2.7.0
transformers==4.55.3
accelerate==1.10.0
datasets==3.4.1
trl==0.16.0
peft==0.15.2
bitsandbytes==0.47.0
safetensors==0.5.3
tokenizers==0.21.1
triton==3.3.0
Reproduction
I am attempting to save the optimizer state to enable resume_from_checkpoint. However, when training without save_only_model (so that the optimizer/scheduler are also checkpointed), the run consistently fails with an assertion error.
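To make the failing combination explicit (the full script is under Training Code below), the checkpoint-related settings boil down to roughly the sketch here; save_only_model defaults to False, so the optimizer and scheduler state are written at every save:
from trl import SFTConfig

# Sketch of the relevant flags only (values abbreviated; see the full config below)
training_args = SFTConfig(
    output_dir="checkpoints",        # placeholder path
    optim="paged_adamw_8bit",        # bitsandbytes paged 8-bit AdamW
    save_strategy="steps",
    fsdp="full_shard auto_wrap",
    fsdp_config={"fsdp_version": 2, "state_dict_type": "full_state_dict"},
    # save_only_model=True,          # leaving this at the default (False) triggers the failure
)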
Training Code
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig

# args, logger, train_dataset, valid_dataset, data_collator, PerplexityCallback,
# MySFTTrainer and cleanup() are defined elsewhere in the script and omitted here.


def main():
    tokenizer = AutoTokenizer.from_pretrained(args.model_uri, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_uri,
        low_cpu_mem_usage=True,
        attn_implementation="kernels-community/vllm-flash-attn3",
        torch_dtype=torch.bfloat16,
    )
    model.config.use_cache = False

    training_args = SFTConfig(
        output_dir=args.checkpoint_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.valid_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        eval_accumulation_steps=1,
        eval_strategy="steps",
        eval_steps=args.eval_interval,
        logging_strategy="steps",
        logging_steps=args.logging_interval,
        save_strategy="steps",
        save_steps=args.model_save_interval,
        save_total_limit=2,
        max_steps=args.max_steps,
        report_to="wandb",
        dataloader_persistent_workers=True,
        dataloader_num_workers=2,
        seed=args.seed,
        average_tokens_across_devices=False,
        optim="paged_adamw_8bit",
        run_name=args.run_name,
        learning_rate=args.max_lr,
        lr_scheduler_type="warmup_stable_decay",
        lr_scheduler_kwargs={"num_decay_steps": 500},
        warmup_steps=200,
        weight_decay=args.weight_decay,
        prediction_loss_only=True,
        # save_only_model=True,
        label_names=["labels"],
        disable_tqdm=False,
        remove_unused_columns=False,
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "fsdp_version": 2,
            "reshard_after_forward": True,
            "auto_wrap_policy": "transformer_based_wrap",
            "activation_checkpointing": True,
            "mixed_precision_policy": {
                "param_dtype": "bfloat16",
                "reduce_dtype": "bfloat16",
            },
            "cpu_offload": False,
            "state_dict_type": "full_state_dict",
            "fsdp_cpu_ram_efficient_loading": True,
        },
    )

    trainer_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": train_dataset,
        "eval_dataset": valid_dataset,
        "tokenizer": tokenizer,
        "data_collator": data_collator,
        "callbacks": [PerplexityCallback()],
    }
    trainer = MySFTTrainer(**trainer_kwargs)

    if not trainer.is_world_process_zero():
        logger.disable("")
    logger.info(f"Size of train dataset: {len(train_dataset)}")
    logger.info(f"Size of valid dataset: {len(valid_dataset)}")

    torch.cuda.empty_cache()
    train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.save_model()
    trainer.save_state()
    trainer.log_metrics("train", train_result.metrics)

    logger.info("*** Evaluate ***")
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)

    logger.info("Done!")
    cleanup()


if __name__ == "__main__":
    main()
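For what it's worth, my guess is that the assertion comes from the 8-bit optimizer keeping its state in mixed dtypes (quantized uint8 buffers for large parameters, 32-bit state for tensors below the quantization threshold), which the FSDP full-state-dict gathering in _convert_all_state_info does not expect. Below is a sketch of a debugging callback (my own helper, not part of the script above) that could be appended to the callbacks list to inspect the state dtypes:
import torch
from transformers import TrainerCallback


class OptimStateDtypeLogger(TrainerCallback):
    """Hypothetical debugging helper: log which dtypes the optimizer keeps per state key."""

    def on_step_end(self, args, state, control, optimizer=None, **kwargs):
        # Only look once, right after the first optimizer step.
        if optimizer is None or state.global_step != 1:
            return
        dtypes_per_key = {}
        # optimizer.state maps each parameter to its state dict; bitsandbytes 8-bit AdamW
        # uses keys like "state1"/"state2" plus quantization buffers, and (as far as I
        # understand) keeps 32-bit state for small tensors, so dtypes can differ per key.
        for param_state in optimizer.state.values():
            for key, value in param_state.items():
                if torch.is_tensor(value):
                    dtypes_per_key.setdefault(key, set()).add(value.dtype)
        print(f"[rank {args.process_index}] optimizer state dtypes: {dtypes_per_key}")

It would be registered via "callbacks": [PerplexityCallback(), OptimStateDtypeLogger()] in the trainer kwargs above.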
Command line
torchrun --nproc_per_node=4 \
--rdzv-backend=c10d \
--rdzv-endpoint=localhost:0 \
--module scripts.sfttrainer \
--model-uri Qwen/Qwen3-8B \
--max-length 8192 \
--max-lr 5e-6 \
--log-dir logs \
--run-name $run_name \
--train-batch-size 1 \
--valid-batch-size 2 \
--gradient-accumulation-steps 1 \
--eval-interval 100 \
--logging-interval 10 \
--model-save-interval 3000 \
--max-steps 5
I got the error below:
[rank0]: Traceback (most recent call last):
[rank0]: File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]: File "<frozen runpy>", line 88, in _run_code
[rank0]: File "/home/evekimmy/main_project/joseon-analysis-v2/scripts/sfttrainer_resize_final.py", line 414, in <module>
[rank0]: main()
[rank0]: File "/home/evekimmy/main_project/joseon-analysis-v2/scripts/sfttrainer_resize_final.py", line 394, in main
[rank0]: train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/transformers/trainer.py", line 2238, in train
[rank0]: return inner_training_loop(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/transformers/trainer.py", line 3144, in _maybe_log_save_evaluate
[rank0]: self._save_checkpoint(model, trial)
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/transformers/trainer.py", line 3252, in _save_checkpoint
[rank0]: self._save_optimizer_and_scheduler(output_dir)
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/transformers/trainer.py", line 3374, in _save_optimizer_and_scheduler
[rank0]: save_fsdp_optimizer(
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 256, in save_fsdp_optimizer
[rank0]: optim_state = FSDP.optim_state_dict(model, optimizer)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1882, in optim_
state_dict
[rank0]: return FullyShardedDataParallel._optim_state_dict_impl(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1293, in _optim
_state_dict_impl
[rank0]: return _optim_state_dict(
[rank0]: ^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1958, in _optim_state_dict
[rank0]: fsdp_osd_state = convert_fn(
[rank0]: ^^^^^^^^^^^
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1781, in _convert_state_with_o
rig_params
[rank0]: _gather_all_orig_param_state(
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1675, in _gather_all_orig_para
m_state
[rank0]: output_states = _allgather_orig_param_states(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1505, in _allgather_orig_param
_states
[rank0]: dtype, state_buffers = _convert_all_state_info(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/evekimmy/miniconda3/envs/py11/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1364, in _convert_all_state_in
fo
[rank0]: assert dtype == info.dtype
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError
How can I checkpoint the 8-bit optimizer state without encountering the assertion error?
Expected behavior
Checkpoints should include the 8-bit optimizer state so that training can be resumed with resume_from_checkpoint without errors.
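For now, the only fallbacks I can think of are to stop checkpointing the optimizer or to give up the quantized optimizer, both of which defeat the purpose; sketched here only to show what I mean:
from trl import SFTConfig

# Fallbacks I would rather avoid (sketch, not a fix):
training_args = SFTConfig(
    output_dir="checkpoints",  # placeholder; otherwise same settings as in the reproduction above
    save_only_model=True,      # skips optimizer/scheduler in checkpoints, so exact resume is lost
    # or keep full checkpoints but switch to an unquantized optimizer:
    # optim="adamw_torch",
)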