
[wip][poc] make group offloading work with disk/nvme transfers #11682


Draft · wants to merge 10 commits into main

Conversation

@sayakpaul sayakpaul (Member) commented Jun 9, 2025

What does this PR do?

Group offloading is a crucial feature for achieving a good speed-memory trade-off with large models on consumer hardware. However, because group offloading relies heavily on RAM, it can be bottlenecked by RAM availability. As a result, on machines where GPU VRAM exceeds the available RAM, or on machines with limited RAM, group offloading can be far from ideal.

This PR takes a stab at supporting disk/NVMe serialization/deserialization inside group offloading so that users can rely on secondary storage to onload/offload model params while still benefiting from the overlap between compute and data transfer.
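For a quick look at the user-facing API this PoC experiments with, here is a minimal sketch; the offload_to_disk and offload_path arguments are the tentative additions from this PR, and the full benchmark script is further below:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
# Group-offload the transformer; offloaded groups are serialized to disk/NVMe
# instead of being kept resident in RAM.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    offload_to_disk=True,    # new in this PR (PoC)
    offload_path="offload",  # directory that will hold the serialized groups
)
```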

Below are some numbers I have gathered with this PR:

| Mode | Time (s) | RAM (GB) | GPU (GB) |
| --- | --- | --- | --- |
| base | 6.594 | 1.838 | 33.85 |
| model CPU offload | 21.682 | 35.504 | 22.64 |
| sequential CPU offload | 68.406 | 33.196 | 2.41 |
| group offload | 47.693 | 36.421 | 11.68 |
| group offload with disk / NVMe support | 55.467 | 2.814 | 11.68 |
| same + compile | 55.296 | 3.036 | 11.68 |
Code
from diffusers import DiffusionPipeline
import torch.utils.benchmark as benchmark
import torch
import psutil
import os
import json
import argparse

def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"

def run_inference(pipe, pipe_kwargs):
    _ = pipe(**pipe_kwargs)

def initialize_pipeline():
    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )
    pipe.set_progress_bar_config(disable=True)
    return pipe

def maybe_apply_offloading(pipe, args):
    if not args.model_cpu_offload and not args.seq_cpu_offload and not args.group_offload:
        pipe = pipe.to("cuda")
    else:
        if args.model_cpu_offload:
            pipe.enable_model_cpu_offload()
        elif args.seq_cpu_offload:
            pipe.enable_sequential_cpu_offload()
        elif args.group_offload:
            pipe.transformer.enable_group_offload(
                onload_device=torch.device("cuda"), 
                offload_device=torch.device("cpu"), 
                offload_type="block_level",
                num_blocks_per_group=1,
                use_stream=True,
                non_blocking=False,
                offload_to_disk=True if args.offload_to_disk else False,
                offload_path="." if args.offload_to_disk else None,
                record_stream=True
            )
            
            # For the rest of the components, just place on CUDA.
            for name, component in pipe.components.items():
                if name != "transformer" and isinstance(component, torch.nn.Module):
                    component.cuda()

    return pipe


def main(args):
    process = psutil.Process(os.getpid())
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    pipe = initialize_pipeline()
    pipe = maybe_apply_offloading(pipe, args)
    pipe_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "guidance_scale": 3.5,
        "num_inference_steps": 28,
        "max_sequence_length": 512,
        "generator": torch.manual_seed(0),
    }
    time = benchmark_fn(run_inference, pipe, pipe_kwargs)
    inference_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)
    inference_memory = float(f"{inference_memory:.2f}")
    ram_bytes = process.memory_info().rss
    ram_gb = ram_bytes / (1024 ** 3)

    # report
    print(f"Peak GPU memory: {inference_memory} GB")
    print(f"Resident CPU memory (RSS): {ram_gb:.2f} GB")

    prefix = "base"
    for key, value in vars(args).items():
        prefix += f"_{key}@{value}"
    
    image = pipe(**pipe_kwargs).images[0]
    image.save(f"{prefix}.png")
    
    artifact_dict = {"time": time, "memory": inference_memory, "ram": ram_gb}
    artifact_dict.update(vars(args))
    with open(f"{prefix}.json", "w") as f:
        json.dump(artifact_dict, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cpu_offload", action="store_true")
    parser.add_argument("--seq_cpu_offload", action="store_true")
    parser.add_argument("--group_offload", action="store_true")
    parser.add_argument("--offload_to_disk", action="store_true")
    args = parser.parse_args()

    main(args)
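To reproduce a given row, the script can be invoked with the corresponding flags, e.g. `python benchmark_offload.py --group_offload --offload_to_disk` (the file name is illustrative).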

Quality comparison:
[preview image]

The stark background color difference with regular group offloading exists on the main branch as well, so I am not sure what is happening there.

Group offloading with disk serialization/deserialization works with torch.compile(), too.

This PR is a PoC and hence has some rough edges that could be improved. I'd be fine if the PR is dropped entirely or if someone else wants to take it over and see it through to completion. Otherwise, I am completely fine continuing to work on it.

@asomoza I think you will be quite interested in this one.

@sayakpaul sayakpaul requested review from DN6 and a-r-r-o-w June 9, 2025 06:59
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w a-r-r-o-w (Member) left a comment


Really cool work getting this started!

  • Can we see some results with a more compute-heavy model like Wan?
  • We probably need to look at some profiles to see if there is any overlap happening when streams are used with disk offload (reason: I think there's a blocking operation that prevents this, but I'm not 100% sure).
  • re: the stark background color difference; weird, I'll take a look.
  • Can we also benchmark the disk usage?

Edit: For the benchmark, I think a fair comparison across all methods would require using group offloading on all components instead of just the transformer. Maybe the benchmark could be updated to show the memory usage with (1) just the transformer and (2) all components.
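For reference, a hedged sketch of what configuration (2) could look like, extending maybe_apply_offloading from the script above. Only diffusers ModelMixin components expose enable_group_offload; other modules are simply placed on CUDA here, so this is an approximation rather than the PR's actual benchmark:

```python
import torch
from diffusers import ModelMixin

def apply_group_offload_to_all_components(pipe, offload_to_disk=False, offload_path="."):
    # Configuration (2): group-offload every diffusers model in the pipeline,
    # not just the transformer.
    for name, component in pipe.components.items():
        if isinstance(component, ModelMixin):
            component.enable_group_offload(
                onload_device=torch.device("cuda"),
                offload_device=torch.device("cpu"),
                offload_type="block_level",
                num_blocks_per_group=1,
                use_stream=True,
                offload_to_disk=offload_to_disk,                        # flag added by this PR (PoC)
                offload_path=offload_path if offload_to_disk else None,
            )
        elif isinstance(component, torch.nn.Module):
            # Non-diffusers modules (e.g. text encoders) are simply placed on CUDA here.
            component.cuda()
    return pipe
```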

all_tensors.extend(list(module.buffers()))
all_tensors.extend(self.parameters)
all_tensors.extend(self.buffers)
all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates
Member

(nit): will there be duplicates? I cannot think of a quick example, so maybe we can remove this.

Member Author

There shouldn't really be any. But I kept it to guard against edge cases, having read about something similar elsewhere.

self._is_offloaded_to_disk = False

if self.offload_to_disk:
    if self.offload_path is None:
Member

(nit): we could probably package this into a separate helper function, like what's done with _init_cpu_param_dict

Member Author

Since it's a one-liner now, I think it's okay here.

Comment on lines +160 to +166
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
    pinned_tensor = loaded_cpu_tensors[key].pin_memory()
    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
    if self.record_stream:
        tensor_obj.data.record_stream(current_stream)
Member

I think a cleaner approach would be to provide a callable to map_location (assuming we were using torch.load instead of safetensors), which could pin each tensor and move it to the device. Do we know if there is an equivalent to passing a callable with safetensors? If not, this is okay too.

Member Author

Do we know of other alternatives to this code path? If not, I think it's better as is. From skimming the safetensors documentation, I couldn't find any equivalent of map_location.
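For what it's worth, one possible alternative (not what this PR does) would be lazy per-tensor loading via safetensors' safe_open, which avoids materializing the whole group as a dict before pinning. A rough sketch, reusing the key_to_tensor mapping from the snippet above:

```python
import torch
from safetensors import safe_open

def onload_from_disk_lazily(path, key_to_tensor, onload_device, non_blocking=True):
    # Stream tensors one at a time from the safetensors file instead of
    # loading the whole group into a dict first.
    with safe_open(path, framework="pt", device="cpu") as f:
        for key, tensor_obj in key_to_tensor.items():
            pinned = f.get_tensor(key).pin_memory()
            tensor_obj.data = pinned.to(onload_device, non_blocking=non_blocking)
```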

@@ -169,6 +219,18 @@ def onload_(self):
    @torch.compiler.disable()
    def offload_(self):
        r"""Offloads the group of modules to the offload_device."""
        if self.offload_to_disk:
Member

(nit): we probably need to refactor this a bit and break it into smaller methods so we don't have to branch and do early returns every time a new feature is added (we can do the refactor once we have everything working, so it's not urgent)

Member Author

Indeed. Can I do it in an immediate follow-up PR so that it's easier to review?

self._is_offloaded_to_disk = True

for tensor_obj in self.tensor_to_key.keys():
    tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
Member

What's the reason for this to be different from the non-disk-offload counterpart? That is, is there a reason we're not doing buffer.data.to(self.offload_device, non_blocking=self.non_blocking)?

Member Author

So, we first free up the memory of the accelerator with:

key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()

However, since we're also optimizing for RAM usage (which can be made clearer through documentation, I believe), we need to free the RAM that is holding the tensor data. Once the data has been safely written from RAM to disk, this step replaces the large data tensors in RAM with uninitialized placeholders so that the memory can be released.
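To make the flow concrete, here is a minimal sketch (not the PR's exact hook code) of the disk-offload path described above; tensor_to_key and the file path mirror the attributes quoted in the diff:

```python
import safetensors.torch
import torch

def offload_group_to_disk(tensor_to_key, safetensors_file_path, offload_device=torch.device("cpu")):
    # 1) Free the accelerator memory by gathering CPU copies keyed by their serialization names.
    cpu_tensors = {key: tensor.data.to(offload_device) for tensor, key in tensor_to_key.items()}
    # 2) Persist the whole group to disk.
    safetensors.torch.save_file(cpu_tensors, safetensors_file_path)
    # 3) Swap the live tensors for uninitialized placeholders so the large CPU buffers
    #    can be reclaimed; the actual data now lives only in the safetensors file.
    for tensor in tensor_to_key:
        tensor.data = torch.empty_like(tensor.data, device=offload_device)
```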

@sayakpaul
Member Author

@a-r-r-o-w thanks for your comments. I will work on them.

> We probably need to look at some profiles to see if there is any overlap happening when streams are used with disk offload

I can gather this. Should we gather the CPU and GPU activities through the profiler and export a trace? If you have any references for me to consider, feel free to send over.

@a-r-r-o-w
Member

a-r-r-o-w commented Jun 9, 2025

> I can gather this. Should we gather the CPU and GPU activities through the profiler and export a trace? If you have any references for me to consider, feel free to send over.

I think CPU/GPU activities will measure all the operations in the model, so if you could collect just the filtered stream-related operations and the associated onloading/offloading times, that would be helpful!

You probably already know this, but for readers, the following will help with gathering traces for visualization:
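The comment originally pointed to a reference that is not reproduced here. As a hedged sketch, a trace like the following can be exported with torch.profiler and inspected in chrome://tracing or Perfetto to check whether the H2D copies issued on the side stream overlap with compute on the default stream:

```python
import torch
from torch.profiler import ProfilerActivity, profile

def trace_inference(pipe, pipe_kwargs, trace_path="group_offload_trace.json"):
    # Record both CPU and CUDA activity during one pipeline call.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        _ = pipe(**pipe_kwargs)
    # Export a Chrome trace; the stream assignment of each kernel/copy is visible in the viewer.
    prof.export_chrome_trace(trace_path)
```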

@SunMarc SunMarc self-requested a review June 9, 2025 12:51
@sayakpaul
Member Author

@a-r-r-o-w for Wan:

| | Time (s) | GPU memory (GB) | RAM (GB) | model_cpu_offload | seq_cpu_offload | group_offload | offload_to_disk |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 237.827 | 15.15 | 2.80788 | False | False | True | True |
| 1 | 151.684 | 41.89 | 1.96027 | False | False | False | False |
| 2 | 226.929 | 15.15 | 42.4634 | False | False | True | False |
| 3 | 172.826 | 28.7 | 39.4348 | True | False | False | False |

@a-r-r-o-w
Member

a-r-r-o-w commented Jun 9, 2025

Awesome, thanks for sharing! The numbers look good.

re: the weird color results with group offloading: I looked into it, and it seems to only happen with block_level offloading (I don't know why yet). I think it could be due to some incorrect/missing synchronization somewhere, so I will try to fix it. If you use leaf_level, it should produce the same result.
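For readers who hit the block_level artifact, a hedged sketch of the leaf_level configuration mentioned above, assuming pipe is the FLUX pipeline from the script earlier (note that leaf_level does not take num_blocks_per_group):

```python
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",  # offload at leaf-module granularity instead of block_level
    use_stream=True,
)
```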

@sayakpaul sayakpaul added the performance label (Anything related to performance improvements, profiling and benchmarking) on Jun 10, 2025
@sayakpaul
Member Author

@a-r-r-o-w could I ask for another review at this point? I have also added a simple test to make sure it's working, but I can also add a heavier integration test that checks VRAM and RAM usage. Apart from that, I think only the docs are missing.

@sayakpaul sayakpaul requested a review from a-r-r-o-w June 12, 2025 05:41
Labels
performance: Anything related to performance improvements, profiling and benchmarking
3 participants