Describe the bug
Layerwise upcasting with Wan works as intended on its own, but when combined with a LoRA it runs out of memory. I've only tested this on a 24GB GPU.
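Not part of the repro itself, but a minimal diagnostic sketch (assuming peft's lora_A/lora_B module naming, which also shows up in the traceback below) to check what dtype the injected LoRA weights end up in, compared with the float8 storage dtype requested for the base layers:
def print_lora_dtypes(transformer):
    # transformer: the WanTransformer3DModel from the script below, after
    # load_lora_into_transformer and enable_layerwise_casting have been called
    for name, module in transformer.named_modules():
        if ("lora_A" in name or "lora_B" in name) and hasattr(module, "weight"):
            print(name, module.weight.dtype)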
Reproduction
Tested with this LoRA:
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel, UMT5EncoderModel
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
)
vae = AutoencoderKLWan.from_pretrained(
model_id, subfolder="vae", torch_dtype=torch.float32
)
transformer = WanTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
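# load the LoRA weights into the transformer before building the pipeline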
state_dict = WanImageToVideoPipeline.lora_state_dict("loras/super_saiyan_35_epochs.safetensors")
WanImageToVideoPipeline.load_lora_into_transformer(state_dict, transformer)
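# store the transformer weights in float8 and upcast to bfloat16 at compute time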
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16,
)
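# offload idle pipeline components to the CPU to fit on a 24GB GPU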
pipe.enable_model_cpu_offload()
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/wan_i2v/bird.png")
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = ("a bird is standing on a branch, its hair brightens to glowing yellow, spiking up as gold energy around its body. The background pules with yellow"
"light, and sparks crackle in the air during its 5up3r super saiyan transformation")
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 65
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v_lora.mp4", fps=16)
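Not in the original script, but a small optional addition (standard torch.cuda API, reusing the import torch from above) to compare peak memory between the LoRA and non-LoRA runs:
# optional: report peak CUDA memory after the run to compare the LoRA and non-LoRA cases
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak CUDA memory allocated: {peak_gib:.2f} GiB")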
Logs
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 9.08it/s]
0%| | 0/50 [00:10<?, ?it/s]
Traceback (most recent call last):
File "/home/ec2-user/diffusers/test_wan_i2v_layerwise.py", line 55, in <module>
output = pipe(
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/ec2-user/diffusers/src/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 627, in __call__
noise_pred = self.transformer(
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/accelerate/hooks.py", line 170, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/ec2-user/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 440, in forward
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ec2-user/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 285, in forward
ff_output = self.ffn(norm_hidden_states)
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ec2-user/diffusers/src/diffusers/models/attention.py", line 1250, in forward
hidden_states = module(hidden_states)
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ec2-user/diffusers/src/diffusers/models/activations.py", line 88, in forward
hidden_states = self.proj(hidden_states)
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/peft/tuners/lora/layer.py", line 621, in forward
result = result + lora_B(lora_A(dropout(x))) * scaling
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 700.00 MiB. GPU 0 has a total capacity of 22.09 GiB of which 581.44 MiB is free. Including non-PyTorch memory, this process has 21.51 GiB memory in use. Of the allocated memory 19.27 GiB is allocated by PyTorch, and 1.94 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
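As a side note, the mitigation suggested by the error message itself can be set from Python; it only helps with fragmentation and is not expected to fix the underlying LoRA + layerwise casting interaction:
import os
# must run before torch initializes CUDA (ideally before importing torch)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"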
System Info
AWS instance with a 24GB GPU
Who can help?