Lack of Quanto support for transforming a WAN 2.1 model #11022

@ukaprch

Description

Describe the bug

Loading the WAN 2.1 transformer with `WanTransformer3DModel.from_pretrained` doesn't work. It spits out garbage (essentially every entry in the model's state dict is reported as unused), resulting in an empty transformer that can't be quantized.

The call `WanTransformer3DModel.from_pretrained(base_model, torch_dtype=dtype, use_safetensors=True)` (marked `#<<<< FAILS` in the reproduction below) is not currently supported. After loading the six safetensors files, the transformer discards every checkpoint weight as unexpected, resulting in an empty transformer that can't be quantized.

```
Fetching 6 files: 100%|██████████| 6/6 [16:17<00:00, 162.85s/it]
The config attributes {'dim': 5120, 'in_dim': 16, 'model_type': 't2v', 'num_heads': 40, 'out_dim': 16, 'text_len': 512} were passed to WanTransformer3DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 54.16it/s]
Some weights of the model checkpoint at Wan-AI/Wan2.1-T2V-14B were not used when initializing WanTransformer3DModel:
['blocks.35.self_attn.q.bias, blocks.9.self_attn.o.bias, blocks.21.self_attn.q.weight, time_embedding.2.weight, blocks.19.ffn.0.bias, blocks.11.self_attn.norm_q.weight, blocks.24.self_attn.k.weight, blocks.6.self_attn.q.weight, blocks.12.cross_attn.v.weight, blocks.20.self_attn.k.weight, blocks.28.cross_attn.o.weight, blocks.9.self_attn.norm_q.weight, blocks.22.cross_attn.v.weight, blocks.21.cross_attn.v.weight, blocks.31.cross_attn.q.bias, blocks.3.cross_attn.q.weight, blocks.18.self_attn.k.weight, blocks.26.self_attn.v.weight, blocks.33.self_attn.v.weight, blocks.26.cross_attn.norm_k.weight, blocks.33.cross_attn.norm_k.weight, blocks.20.ffn.2.weight, blocks.13.ffn.0.weight, blocks.0.ffn.0.bias, blocks.11.cross_attn.v.bias, blocks.26.cross_attn.o.weight, blocks.7.cross_attn.k.weight, blocks.30.cross_attn.k.weight, time_embedding.2.bias, blocks.31.ffn.2.bias, blocks.31.self_attn.k.weight, blocks.25.self_attn.k.weight, blocks.30.cross_attn.norm_q.weight, blocks.35.cross_attn.k.bias, blocks.29.cross_attn.norm_q.weight, blocks.26.norm3.bias, blocks.29.self_attn.k.bias, blocks.3.cross_attn.o.bias, blocks.20.cross_attn.v.weight, blocks.39.cross_attn.q.bias, blocks.32.ffn.0.bias, blocks.27.self_attn.k.weight, blocks.20.self_attn.v.bias, blocks.25.cross_attn.k.bias, blocks.29.cross_attn.v.bias, blocks.33.ffn.0.bias, blocks.6.cross_attn.k.bias, blocks.5.ffn.2.weight, blocks.7.self_attn.k.weight, blocks.8.self_attn.o.weight, blocks.23.self_attn.o.bias, blocks.3.self_attn.k.weight, blocks.4.self_attn.v.weight, blocks.8.self_attn.norm_q.weight, blocks.12.ffn.0.bias, blocks.13.cross_attn.q.weight, blocks.1.self_attn.v.weight, blocks.18.cross_attn.q.bias, blocks.8.self_attn.k.bias, blocks.19.cross_attn.v.bias, blocks.31.self_attn.o.bias,
....
```
and so on for every remaining parameter in the checkpoint.
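For context, the "were not used" warning means none of the checkpoint keys matched the parameter names the diffusers class expects (the checkpoint keys use the original Wan layout, e.g. `blocks.35.self_attn.q.bias`). A minimal pure-Python sketch of why that leaves the model empty — the parameter names below are hypothetical placeholders, not the actual diffusers internals:

```python
def load_matching_weights(expected_names, checkpoint):
    """Keep only checkpoint entries whose names the model expects;
    everything else is reported back as unused."""
    loaded = {k: v for k, v in checkpoint.items() if k in expected_names}
    unused = sorted(set(checkpoint) - set(expected_names))
    return loaded, unused

# Hypothetical names for illustration: diffusers-style parameter names on
# the left, original Wan-style checkpoint keys on the right.
expected = {"blocks.0.attn1.to_q.weight", "blocks.0.attn1.to_k.weight"}
checkpoint = {
    "blocks.0.self_attn.q.weight": [0.1],
    "blocks.0.self_attn.k.weight": [0.2],
}

loaded, unused = load_matching_weights(expected, checkpoint)
# loaded is empty -> the model keeps only its random init, hence "garbage"
```

With zero matching names, every key lands in `unused` (the long warning above) and nothing is loaded, which is exactly the symptom described.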

Reproduction

```python
import json
from pathlib import Path

import torch
from diffusers import WanTransformer3DModel
from optimum.quanto import freeze, qint8, quantization_map, quantize


def quantize_wan_model():
    ################## QUANTIZE WAN MODEL #########################
    base_model = "Wan-AI/Wan2.1-T2V-14B"
    dtype = torch.bfloat16

    print("Quantize Transformer")
    transformer = WanTransformer3DModel.from_pretrained(  # <<<< FAILS
        base_model,
        torch_dtype=dtype,
        use_safetensors=True,
    )
    quantize(transformer, weights=qint8)
    freeze(transformer)

    save_directory = "./wan quantro 14B/basemodel/wantransformer3dmodel_qint8"
    transformer.save_pretrained(save_directory)
    qmap_name = Path(save_directory, "quanto_qmap.json")
    qmap = quantization_map(transformer)
    with open(qmap_name, "w", encoding="utf8") as f:
        json.dump(qmap, f, indent=4)
    print("Transformer done")
```
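For reference, `quantize(transformer, weights=qint8)` replaces the model's weights with int8 representations; conceptually this resembles symmetric round-to-nearest int8 quantization. A pure-Python sketch of that idea (an illustration only, not Quanto's actual implementation):

```python
def quantize_int8(values):
    """Symmetric per-tensor int8 quantization: derive a scale from the
    largest magnitude, round each value to the nearest step, clamp to int8."""
    scale = max(abs(v) for v in values) / 127.0 or 1.0  # avoid a zero scale
    return [max(-128, min(127, round(v / scale))) for v in values], scale

def dequantize_int8(quantized, scale):
    """Recover approximate float values from the int8 codes."""
    return [q * scale for q in quantized]

weights = [0.5, -1.27, 0.0, 1.0]
q, scale = quantize_int8(weights)
# q ≈ [50, -127, 0, 100], scale ≈ 0.01
restored = dequantize_int8(q, scale)
```

None of this can run here, of course, because `from_pretrained` never returns a populated transformer to quantize.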

Logs

See the log output included in the bug description above.

System Info

Windows 10
Python 3.12.5
diffusers 0.33.0.dev0
transformers 4.49.0
optimum 1.24.0
optimum-quanto 0.2.6

Who can help?

No response
