I quantized my model to 8-bit, then tried to wrap it with FSDP and ran into the following error:
Must flatten tensors with uniform dtype but got torch.float16 and torch.int8
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 771, in _validate_tensors_to_flatten
raise ValueError(
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 633, in _init_flat_param_and_metadata
) = self._validate_tensors_to_flatten(params)
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 583, in __init__
self._init_flat_param_and_metadata(
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 615, in _init_param_handle_from_params
handle = FlatParamHandle(
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 603, in _init_param_handle_from_module
_init_param_handle_from_params(state, managed_params, fully_sharded_module)
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
_init_param_handle_from_module(
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 491, in _wrap
return wrapper_cls(module, **kwargs)
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 562, in _recursive_wrap
return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 102, in _auto_wrap
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/conda/envs/multi-talk-quant/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
_auto_wrap(
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/hg-vgen-base-model/wan/distributed/fsdp.py", line 22, in shard_model
model = FSDP(
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/hg-vgen-base-model/wan/multitalk.py", line 247, in __init__
self.model = shard_fn(self.model)
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/hg-vgen-base-model/generate_multitalk.py", line 571, in generate
wan_i2v = wan.MultiTalkPipeline(
File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-camera3d/liuxiangyu18/hg-vgen-base-model/generate_multitalk.py", line 625, in <module>
generate(args)
ValueError: Must flatten tensors with uniform dtype but got torch.float16 and torch.int8
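
For context, FSDP flattens all parameters of each wrapped module into a single FlatParameter, which only works if they all share one dtype. A simplified illustration of that constraint (not the actual FSDP source) shows why an 8-bit checkpoint trips it:

import torch

def validate_uniform_dtype(tensors):
    # Simplified stand-in for FSDP's _validate_tensors_to_flatten: every
    # tensor packed into one FlatParameter must share a single dtype.
    dtypes = {t.dtype for t in tensors}
    if len(dtypes) != 1:
        raise ValueError(
            f"Must flatten tensors with uniform dtype but got {dtypes}")

# A bitsandbytes 8-bit checkpoint mixes int8 linear weights with fp16/bf16
# norms, biases and quantization scales, so the check fails:
validate_uniform_dtype([
    torch.zeros(4, dtype=torch.int8),     # quantized linear weight
    torch.zeros(4, dtype=torch.float16),  # norm weight / bias
])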
My quantization code is:
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from wan.modules.model import WanModel

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
model = WanModel.from_pretrained(checkpoint_dir, quantization_config=quant_config, torch_dtype=torch.bfloat16)
model.eval().requires_grad_(False)
model.save_pretrained("weights/Wan2.1-I2V-14B-480P-int8-bit")
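
A quick way to see where the mixed dtypes come from is to tally the parameter dtypes right after this load (sketch; `model` is the quantized WanModel loaded above, and `model.blocks` are the same blocks the wrap policy below iterates over):

import collections

# Overall tally: bitsandbytes keeps Linear weights in int8, while norms,
# biases and quantization statistics stay in a floating dtype.
print(collections.Counter(p.dtype for p in model.parameters()))

# Per-block view, matching the lambda_auto_wrap_policy over model.blocks:
# each block mixes dtypes, so FSDP cannot flatten it into one FlatParameter.
for i, block in enumerate(model.blocks):
    print(i, {p.dtype for p in block.parameters()})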
My FSDP wrapping code is:
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        # mixed_precision=MixedPrecision(
        #     param_dtype=param_dtype,
        #     reduce_dtype=reduce_dtype,
        #     buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model
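
Would excluding the 8-bit layers from flattening via FSDP's ignored_modules be a reasonable direction? A minimal sketch of that idea (assuming the 8-bit load produces bitsandbytes Linear8bitLt layers), not a verified fix:

from functools import partial

import bitsandbytes as bnb
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

def shard_model_ignoring_int8(model, device_id, sync_module_states=True):
    # Collect the 8-bit linears so FSDP neither flattens nor shards their
    # int8 weights together with the remaining fp16/bf16 parameters.
    int8_layers = [m for m in model.modules()
                   if isinstance(m, bnb.nn.Linear8bitLt)]
    return FSDP(
        module=model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        ignored_modules=int8_layers,  # leave the quantized layers to bitsandbytes
        device_id=device_id,
        sync_module_states=sync_module_states)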