Problems occur when loading 8-bit quantized models using FSDP #1709

@Piggy-ch

Description

I quantized my model to 8 bits, then tried to wrap it with FSDP and hit the following error:

Traceback (most recent call last):
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/hg-vgen-base-model/generate_multitalk.py", line 625, in <module>
    generate(args)
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/hg-vgen-base-model/generate_multitalk.py", line 571, in generate
    wan_i2v = wan.MultiTalkPipeline(
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/hg-vgen-base-model/wan/multitalk.py", line 247, in __init__
    self.model = shard_fn(self.model)
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/hg-vgen-base-model/wan/distributed/fsdp.py", line 22, in shard_model
    model = FSDP(
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
    _auto_wrap(
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 102, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 562, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 491, in _wrap
    return wrapper_cls(module, **kwargs)
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
    _init_param_handle_from_module(
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 603, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 615, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 583, in __init__
    self._init_flat_param_and_metadata(
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 633, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 771, in _validate_tensors_to_flatten
    raise ValueError(
ValueError: Must flatten tensors with uniform dtype but got torch.float16 and torch.int8

My quantization code is:

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

# WanModel and checkpoint_dir come from my Wan2.1 setup.
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
model = WanModel.from_pretrained(
    checkpoint_dir, quantization_config=quant_config, torch_dtype=torch.bfloat16)
model.eval().requires_grad_(False)

model.save_pretrained("weights/Wan2.1-I2V-14B-480P-int8-bit")
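As I understand it, a bitsandbytes 8-bit checkpoint stores int8 weights alongside floating-point scales and other parameters, while FSDP requires every tensor it packs into one FlatParameter to share a single dtype. A simplified sketch of that check (my own hypothetical `validate_uniform_dtype` helper, mimicking FSDP's `_validate_tensors_to_flatten`, not the real implementation) reproduces the error on plain tensors:

```python
import torch

def validate_uniform_dtype(tensors):
    # Hypothetical helper: FSDP refuses to flatten a parameter group
    # whose tensors do not all share one dtype.
    dtypes = {t.dtype for t in tensors}
    if len(dtypes) > 1:
        raise ValueError(
            f"Must flatten tensors with uniform dtype but got {sorted(map(str, dtypes))}")
    return dtypes.pop()

# An 8-bit-quantized block holds both dtypes at once, so the check fails:
mixed = [torch.zeros(4, dtype=torch.float16), torch.zeros(4, dtype=torch.int8)]
try:
    validate_uniform_dtype(mixed)
except ValueError as e:
    print(e)
```

So the failure is not specific to my model: any module whose parameters mix float16 and int8 will trip this check when FSDP tries to flatten it.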

My FSDP wrapping code is:

from functools import partial

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        # mixed_precision=MixedPrecision(
        #     param_dtype=param_dtype,
        #     reduce_dtype=reduce_dtype,
        #     buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model
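To narrow down which wrapped blocks actually contain mixed dtypes, I used a small diagnostic (my own hypothetical `find_mixed_dtype_modules` helper, shown here on a toy stand-in module rather than the real WanModel) that lists every submodule whose direct parameters span more than one dtype:

```python
import torch
import torch.nn as nn

def find_mixed_dtype_modules(model):
    # Hypothetical helper: report each submodule whose *direct* parameters
    # span more than one dtype -- these are the groups FSDP cannot flatten.
    mixed = {}
    for name, module in model.named_modules():
        dtypes = {p.dtype for p in module.parameters(recurse=False)}
        if len(dtypes) > 1:
            mixed[name or "<root>"] = sorted(str(d) for d in dtypes)
    return mixed

# Toy stand-in for an 8-bit-quantized layer: int8 weight next to an fp16 bias.
class FakeInt8Linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(
            torch.zeros(4, 4, dtype=torch.int8), requires_grad=False)
        self.bias = nn.Parameter(
            torch.zeros(4, dtype=torch.float16), requires_grad=False)

model = nn.Sequential(FakeInt8Linear(), nn.Linear(4, 4))
print(find_mixed_dtype_modules(model))  # only the quantized block is reported
```

On my real model, every transformer block shows up in this report, which matches the traceback: the auto-wrap policy hands each block to FSDP, and flattening fails on the first one.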
