Describe the bug
The Wan transformer cannot be loaded. `WanTransformer3DModel.from_pretrained` reports every tensor in the checkpoint as unused (the warning below lists essentially the entire state dict), which leaves an empty transformer that cannot be quantized. After fetching the six safetensors files, the call below fails:

```python
transformer = WanTransformer3DModel.from_pretrained(  # <<<< FAILS
    base_model,
    torch_dtype=dtype,
    use_safetensors=True,
)
```
```
Fetching 6 files: 100%|██████████| 6/6 [16:17<00:00, 162.85s/it]
The config attributes {'dim': 5120, 'in_dim': 16, 'model_type': 't2v', 'num_heads': 40, 'out_dim': 16, 'text_len': 512} were passed to WanTransformer3DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 54.16it/s]
Some weights of the model checkpoint at Wan-AI/Wan2.1-T2V-14B were not used when initializing WanTransformer3DModel:
['blocks.35.self_attn.q.bias, blocks.9.self_attn.o.bias, blocks.21.self_attn.q.weight, time_embedding.2.weight, blocks.19.ffn.0.bias, blocks.11.self_attn.norm_q.weight, blocks.24.self_attn.k.weight, blocks.6.self_attn.q.weight, blocks.12.cross_attn.v.weight, blocks.20.self_attn.k.weight, blocks.28.cross_attn.o.weight, blocks.9.self_attn.norm_q.weight, blocks.22.cross_attn.v.weight, blocks.21.cross_attn.v.weight, blocks.31.cross_attn.q.bias, blocks.3.cross_attn.q.weight, blocks.18.self_attn.k.weight, blocks.26.self_attn.v.weight, blocks.33.self_attn.v.weight, blocks.26.cross_attn.norm_k.weight, blocks.33.cross_attn.norm_k.weight, blocks.20.ffn.2.weight, blocks.13.ffn.0.weight, blocks.0.ffn.0.bias, blocks.11.cross_attn.v.bias, blocks.26.cross_attn.o.weight, blocks.7.cross_attn.k.weight, blocks.30.cross_attn.k.weight, time_embedding.2.bias, blocks.31.ffn.2.bias, blocks.31.self_attn.k.weight, blocks.25.self_attn.k.weight, blocks.30.cross_attn.norm_q.weight, blocks.35.cross_attn.k.bias, blocks.29.cross_attn.norm_q.weight, blocks.26.norm3.bias, blocks.29.self_attn.k.bias, blocks.3.cross_attn.o.bias, blocks.20.cross_attn.v.weight, blocks.39.cross_attn.q.bias, blocks.32.ffn.0.bias, blocks.27.self_attn.k.weight, blocks.20.self_attn.v.bias, blocks.25.cross_attn.k.bias, blocks.29.cross_attn.v.bias, blocks.33.ffn.0.bias, blocks.6.cross_attn.k.bias, blocks.5.ffn.2.weight, blocks.7.self_attn.k.weight, blocks.8.self_attn.o.weight, blocks.23.self_attn.o.bias, blocks.3.self_attn.k.weight, blocks.4.self_attn.v.weight, blocks.8.self_attn.norm_q.weight, blocks.12.ffn.0.bias, blocks.13.cross_attn.q.weight, blocks.1.self_attn.v.weight, blocks.18.cross_attn.q.bias, blocks.8.self_attn.k.bias, blocks.19.cross_attn.v.bias, blocks.31.self_attn.o.bias,
....
```
…and the warning continues through essentially every parameter in the checkpoint.
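For comparison, the warnings above suggest the `Wan-AI/Wan2.1-T2V-14B` repo holds the original (non-diffusers) checkpoint layout. A minimal sketch of what loading a diffusers-format checkpoint would look like, assuming the separate `Wan-AI/Wan2.1-T2V-14B-Diffusers` repo with the transformer under a `transformer` subfolder is the intended layout:

```python
import torch
from diffusers import WanTransformer3DModel

# Assumption: the diffusers-format weights live in the "-Diffusers" repo,
# with the transformer stored under the "transformer" subfolder.
transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
```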
Reproduction
```python
def quantize_wan_model():
    import torch
    from transformers import AutoTokenizer, UMT5EncoderModel
    from diffusers import AutoencoderKLWan, WanTransformer3DModel
    from optimum.quanto import freeze, qint8, quantize, quantization_map
    from pathlib import Path
    import json

    ################## QUANTIZE WAN MODEL #########################
    base_model = 'Wan-AI/Wan2.1-T2V-14B'
    dtype = torch.bfloat16

    print('Quantize Transformer')
    transformer = WanTransformer3DModel.from_pretrained(  # <<<< FAILS
        base_model,
        torch_dtype=dtype,
        use_safetensors=True,
    )

    # Quantize the weights to int8 and freeze them in place.
    quantize(transformer, weights=qint8)
    freeze(transformer)

    # Save the quantized weights plus the quantization map needed to reload them.
    save_directory = "./wan quantro 14B/basemodel/wantransformer3dmodel_qint8"
    transformer.save_pretrained(save_directory)
    qmap_name = Path(save_directory, "quanto_qmap.json")
    qmap = quantization_map(transformer)
    with open(qmap_name, "w", encoding="utf8") as f:
        json.dump(qmap, f, indent=4)
    print('Transformer done')
```
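Once saving works, reloading the quantized transformer is not a plain `from_pretrained` call. A minimal sketch of the reload path following optimum-quanto's `requantize` API, assuming the saved checkpoint fits in a single safetensors shard (a 14B model will typically be sharded, in which case the shards would need to be merged into one state dict first):

```python
import json
import torch
from diffusers import WanTransformer3DModel
from optimum.quanto import requantize
from safetensors.torch import load_file

save_directory = "./wan quantro 14B/basemodel/wantransformer3dmodel_qint8"

# Rebuild the model skeleton from its saved config without allocating real weights.
config = WanTransformer3DModel.load_config(save_directory)
with torch.device("meta"):
    transformer = WanTransformer3DModel.from_config(config)

# Restore the quantized state dict using the quantization map saved above.
# Assumption: a single (unsharded) safetensors file with this default name.
state_dict = load_file(f"{save_directory}/diffusion_pytorch_model.safetensors")
with open(f"{save_directory}/quanto_qmap.json", "r", encoding="utf8") as f:
    qmap = json.load(f)
requantize(transformer, state_dict, qmap, device=torch.device("cuda"))
```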
Logs
See the log output included in the bug description above.
System Info
Windows 10
Python 3.12.5
diffusers 0.33.0.dev0
transformers 4.49.0
optimum 1.24.0
optimum-quanto 0.2.6
Who can help?
No response