TorchAO compile + offloading tests #11697

Draft · wants to merge 1 commit into main
4 changes: 4 additions & 0 deletions src/diffusers/hooks/group_offloading.py
@@ -219,6 +219,7 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
return module

def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
breakpoint()
# If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
# method is the onload_leader of the group.
if self.group.onload_leader is None:
@@ -285,6 +286,7 @@ def callback():
return module

def post_forward(self, module, output):
breakpoint()
# At this point, for the current modules' submodules, we know the execution order of the layers. We can now
# remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each
# group offloading hook.
@@ -624,7 +626,9 @@ def _apply_group_offloading_leaf_level(
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
print("unsupported module", name, type(submodule))
continue
print("applying group offloading to", name, type(submodule))
group = ModuleGroup(
modules=[submodule],
offload_device=offload_device,
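
Note on the hunks above: the breakpoint() and print calls are temporary debugging instrumentation for checking which submodules end up with their own leaf-level offloading group. As a rough sketch of the path being traced — assuming only the enable_group_offload API that the tests in this PR call, with the Flux checkpoint as an illustrative stand-in, not something this diff pins down:

import torch
from diffusers import FluxTransformer2DModel  # illustrative model; any ModelMixin works

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# leaf_level offloading gives every supported leaf module (Linear, Conv, ...) its own
# ModuleGroup, so the pre_forward/post_forward hooks instrumented above fire per leaf.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=False,  # the streamed variant is exercised separately by the tests below
)
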
2 changes: 1 addition & 1 deletion tests/quantization/bnb/test_4bit.py
@@ -881,4 +881,4 @@ def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

def test_torch_compile_with_group_offload(self):
super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
super()._test_torch_compile_with_group_offload_leaf_stream(quantization_config=self.quantization_config)
2 changes: 1 addition & 1 deletion tests/quantization/bnb/test_mixed_int8.py
@@ -845,6 +845,6 @@ def test_torch_compile_with_cpu_offload(self):

@pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
def test_torch_compile_with_group_offload(self):
super()._test_torch_compile_with_group_offload(
super()._test_torch_compile_with_group_offload_leaf_stream(
quantization_config=self.quantization_config, torch_dtype=torch.float16
)
25 changes: 23 additions & 2 deletions tests/quantization/test_torch_compile_utils.py
@@ -64,7 +64,29 @@ def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
def _test_torch_compile_with_group_offload_leaf(self, quantization_config, torch_dtype=torch.bfloat16):
torch._dynamo.config.cache_size_limit = 10000

pipe = self._init_pipeline(quantization_config, torch_dtype)
group_offload_kwargs = {
"onload_device": torch.device("cuda"),
"offload_device": torch.device("cpu"),
"offload_type": "leaf_level",
"num_blocks_per_group": 1,
"use_stream": False,
}
pipe.transformer.enable_group_offload(**group_offload_kwargs)
# pipe.transformer.compile()
for name, component in pipe.components.items():
if name != "transformer" and isinstance(component, torch.nn.Module):
if torch.device(component.device).type == "cpu":
component.to("cuda")

for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

def _test_torch_compile_with_group_offload_leaf_stream(self, quantization_config, torch_dtype=torch.bfloat16):
torch._dynamo.config.cache_size_limit = 10000

pipe = self._init_pipeline(quantization_config, torch_dtype)
@@ -73,7 +95,6 @@ def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtyp
"offload_device": torch.device("cpu"),
"offload_type": "leaf_level",
"use_stream": True,
"non_blocking": True,
}
pipe.transformer.enable_group_offload(**group_offload_kwargs)
pipe.transformer.compile()
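
The refactor above splits the old _test_torch_compile_with_group_offload helper into a non-streamed leaf variant (with the compile call commented out while debugging) and a streamed variant, and drops the non_blocking flag from the streamed kwargs. Pulled together, the flow these helpers exercise looks roughly like the sketch below; it assumes a pipeline already built by _init_pipeline and combines pieces from both variants rather than reproducing either exactly:

import torch
from diffusers import DiffusionPipeline

def run_compile_with_leaf_offload(pipe: DiffusionPipeline, use_stream: bool) -> None:
    # Offloading hooks introduce graph breaks/recompiles, so raise the dynamo cache limit.
    torch._dynamo.config.cache_size_limit = 10000
    pipe.transformer.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=use_stream,
    )
    pipe.transformer.compile()
    # Keep every component other than the offloaded transformer on the accelerator.
    for name, component in pipe.components.items():
        if name != "transformer" and isinstance(component, torch.nn.Module):
            component.to("cuda")
    # Run twice at a small resolution, as in the helpers above.
    for _ in range(2):
        pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
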
31 changes: 31 additions & 0 deletions tests/quantization/torchao/test_torchao.py
@@ -29,6 +29,7 @@
TorchAoConfig,
)
from diffusers.models.attention_processor import Attention
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_synchronize,
@@ -44,6 +45,8 @@
torch_device,
)

from ..test_torch_compile_utils import QuantCompileTests


enable_full_determinism()

@@ -625,6 +628,34 @@ def test_int_a16w8_cpu(self):
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)


@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests):
quantization_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
},
)

def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)

def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

def test_torch_compile_with_group_offload_leaf(self):
from diffusers.utils.logging import set_verbosity_debug

set_verbosity_debug()
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)

@unittest.skip(
"Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO."
)
def test_torch_compile_with_group_offload_leaf_stream(self):
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
super()._test_torch_compile_with_group_offload_leaf_stream(quantization_config=self.quantization_config)


# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
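
For context on the new TorchAoCompileTest: the PipelineQuantizationConfig above requests TorchAO int8 weight-only quantization for the transformer at pipeline load time, and the skipped stream test records that streamed group offloading currently fails because TorchAO's AffineQuantizedTensor does not implement aten.is_pinned. A minimal sketch of how such a config can be consumed — the checkpoint name is an assumption standing in for whatever _init_pipeline actually loads:

import torch
from diffusers import DiffusionPipeline, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig

quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(quant_type="int8_weight_only")},
)

# Quantize the transformer while loading, then compile it.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.transformer.compile()

pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
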