Unable to run FSDP2 with low-bit optimizers like 8-bit Adam #1403

Description

@nighting0le01

Feature request

Saving a training checkpoint via accelerator.save_state() fails when the model is wrapped in FSDP and the optimizer is a low-bit optimizer such as 8-bit Adam:

Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/projects/scripts/train_ranker.py", line 404, in <module>
    train(accelerator, args)
  File "/home/ubuntu/projects/scripts/train_ranker.py", line 357, in train
    save_training_artifacts(
  File "/home/ubuntu/projects/scripts/train_ranker.py", line 74, in save_training_artifacts
    accelerator.save_state(save_dir)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/accelerate/accelerator.py", line 2958, in save_state
    save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 168, in save_fsdp_optimizer
    optim_state = FSDP.optim_state_dict(model, optimizer)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1840, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1263, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1971, in _optim_state_dict
    fsdp_osd_state = convert_fn(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1794, in _convert_state_with_orig_params
    _gather_all_orig_param_state(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1688, in _gather_all_orig_param_state
    output_states = _allgather_orig_param_states(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1518, in _allgather_orig_param_states
    dtype, state_buffers = _convert_all_state_info(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1377, in _convert_all_state_info
    assert dtype == info.dtype
AssertionError
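
For context, a minimal sketch of the failure mode, assuming torchao's 8-bit AdamW (imported from torchao.prototype.low_bit_optim) and an Accelerator launched with an FSDP config; the toy model, shapes, and checkpoint path are placeholders rather than anything from train_ranker.py:

import torch
import torch.nn as nn
from accelerate import Accelerator
from torchao.prototype.low_bit_optim import AdamW8bit

# Assumes the script is started with `accelerate launch` using an FSDP
# config, so prepare() wraps the model in FSDP.
accelerator = Accelerator()

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 8))
optimizer = AdamW8bit(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

x = torch.randn(4, 1024, device=accelerator.device)
loss = model(x).square().mean()
accelerator.backward(loss)
optimizer.step()  # optimizer state is now materialized in quantized (8-bit) form
optimizer.zero_grad()

# Crashes here: save_state() ends up in FSDP.optim_state_dict(), whose
# all-gather path assumes every shard of a given optimizer-state tensor
# shares one dtype; the quantized state from the low-bit optimizer trips
# `assert dtype == info.dtype` in _convert_all_state_info.
accelerator.save_state("checkpoint/")

With a plain torch.optim.AdamW the same save path is expected to work, since the unquantized optimizer state shards all share the parameter dtype.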

Motivation

To be able to use FSDP1 and FSDP2 together with low-bit optimizers such as 8-bit Adam.

Your contribution

Please let me know how I can help.

Labels

Duplicate, FSDP, Optimizers
