
Using Layer Wise Upcasting with WAN gives OOM with loras #11223

@asomoza


Describe the bug

When using layerwise upcasting with WAN it works as intended, but when it's combined with a LoRA it OOMs. I've only tested this on a 24GB GPU.

Reproduction

Tested with this lora:

import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanTransformer3DModel
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel, UMT5EncoderModel

model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"

image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)

text_encoder = UMT5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
)
vae = AutoencoderKLWan.from_pretrained(
    model_id, subfolder="vae", torch_dtype=torch.float32
)
transformer = WanTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)

state_dict = WanImageToVideoPipeline.lora_state_dict("loras/super_saiyan_35_epochs.safetensors")
WanImageToVideoPipeline.load_lora_into_transformer(state_dict, transformer)

transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    vae=vae,
    transformer=transformer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    torch_dtype=torch.bfloat16,
)

pipe.enable_model_cpu_offload()

image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/wan_i2v/bird.png")

max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = ("a bird is standing on a branch, its hair brightens to glowing yellow, spiking up as gold energy around its body. The background pules with yellow" 
          "light, and sparks crackle in the air during its 5up3r super saiyan transformation")
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

num_frames = 65

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=5.0,
).frames[0]

export_to_video(output, "wan-i2v_lora.mp4", fps=16)
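
As a diagnostic sketch (not part of the repro above, and assuming the PEFT lora_A/lora_B parameter naming visible in the traceback), one can check whether the LoRA weights keep their original dtype after enable_layerwise_casting, which would account for extra memory on top of the float8-stored base weights:

# Diagnostic sketch: count parameters per (group, dtype) after
# enable_layerwise_casting, splitting LoRA params from base params.
from collections import Counter

dtype_counts = Counter()
for name, param in transformer.named_parameters():
    group = "lora" if "lora_" in name else "base"
    dtype_counts[(group, str(param.dtype))] += 1
print(dtype_counts)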

Logs

Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.08it/s]
  0%|                                                                                                                 | 0/50 [00:10<?, ?it/s]
Traceback (most recent call last):
  File "/home/ec2-user/diffusers/test_wan_i2v_layerwise.py", line 55, in <module>
    output = pipe(
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/ec2-user/diffusers/src/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 627, in __call__
    noise_pred = self.transformer(
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/ec2-user/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 440, in forward
    hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 285, in forward
    ff_output = self.ffn(norm_hidden_states)
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/diffusers/src/diffusers/models/attention.py", line 1250, in forward
    hidden_states = module(hidden_states)
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/diffusers/src/diffusers/models/activations.py", line 88, in forward
    hidden_states = self.proj(hidden_states)
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/diffusers/.venv/lib64/python3.9/site-packages/peft/tuners/lora/layer.py", line 621, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 700.00 MiB. GPU 0 has a total capacity of 22.09 GiB of which 581.44 MiB is free. Including non-PyTorch memory, this process has 21.51 GiB memory in use. Of the allocated memory 19.27 GiB is allocated by PyTorch, and 1.94 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
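
Since the error message reports ~1.9 GiB reserved but unallocated and suggests expandable segments, a quick thing to try (a sketch only, it likely won't cover the full LoRA overhead) is setting the allocator config before torch touches the GPU:

# Sketch: enable expandable segments as suggested in the OOM message.
# Must be set before any CUDA memory is allocated, e.g. at the very top of
# the script (or exported in the shell before launching it).
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # imported after setting the env var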

System Info

24GB GPU AWS Instance

Who can help?

@sayakpaul @a-r-r-o-w @DN6
