From d561f6552d262ca9a1cfa235c1e014199b358b7c Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 3 Jun 2025 08:01:42 +0000 Subject: [PATCH 01/10] frst commit --- .github/workflows/nightly_tests.yml | 2 + docs/source/en/_toctree.yml | 2 + docs/source/en/api/quantization.md | 5 + .../source/en/quantization/finegrained_fp8.md | 15 + docs/source/en/quantization/overview.md | 2 +- src/diffusers/__init__.py | 4 + src/diffusers/models/modeling_utils.py | 2 + src/diffusers/quantizers/auto.py | 4 + .../quantizers/finegrained_fp8/__init__.py | 1 + .../finegrained_fp8_quantizer.py | 206 +++++++++ .../quantizers/finegrained_fp8/utils.py | 422 ++++++++++++++++++ .../quantizers/quantization_config.py | 44 +- 12 files changed, 707 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/quantization/finegrained_fp8.md create mode 100644 src/diffusers/quantizers/finegrained_fp8/__init__.py create mode 100644 src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py create mode 100644 src/diffusers/quantizers/finegrained_fp8/utils.py diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 4f92717df8b7..a89199c96919 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -473,6 +473,8 @@ jobs: additional_deps: [] - backend: "optimum_quanto" test_location: "quanto" + - backend: "finegrained_fp8" + test_location: "finegrained_fp8" additional_deps: [] runs-on: group: aws-g6e-xlarge-plus diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a6f265fd3fb5..91cc4b8e6324 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -174,6 +174,8 @@ title: torchao - local: quantization/quanto title: quanto + - local: quantization/finegrained_fp8 + title: finegrained_fp8 title: Quantization Methods - sections: - local: optimization/fp16 diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index e2ca990190e6..c6baed8abe11 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -41,6 +41,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui [[autodoc]] TorchAoConfig +## FinegrainedFP8Config + +[[autodoc]] FinegrainedFP8Config + ## DiffusersQuantizer [[autodoc]] quantizers.base.DiffusersQuantizer + diff --git a/docs/source/en/quantization/finegrained_fp8.md b/docs/source/en/quantization/finegrained_fp8.md new file mode 100644 index 000000000000..4173238e2df7 --- /dev/null +++ b/docs/source/en/quantization/finegrained_fp8.md @@ -0,0 +1,15 @@ + + +# FinegrainedFP8 + +## Overview + +## Usage + diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 68b99f524ec0..2c422d79125b 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -37,7 +37,7 @@ Diffusers currently supports the following quantization methods. - [TorchAO](./torchao) - [GGUF](./gguf) - [Quanto](./quanto.md) - +- [FinegrainedFP8](./finegrained_fp8.md) [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. 
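Note: the new `docs/source/en/quantization/finegrained_fp8.md` page added above ships with empty Overview and Usage sections. A minimal usage sketch, assembled from the `FinegrainedFP8Config` docstring and the tests added later in this series, could seed that page; the model id and block size below are illustrative, and per `validate_environment` on-the-fly quantization expects a CUDA GPU with compute capability >= 8.9.

```python
# Hedged usage sketch for the new finegrained_fp8 backend.
# Model id, block size, and excluded modules are illustrative choices,
# mirroring the values used in the config docstring and tests of this PR.
import torch
from diffusers import FluxTransformer2DModel, FinegrainedFP8Config

quant_config = FinegrainedFP8Config(
    weight_block_size=(128, 128),          # per-block weight scale granularity
    modules_to_not_convert=["proj_out"],   # keep sensitive layers unquantized
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```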
## Pipeline-level quantization diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9ab973351c86..2019d06d6402 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -96,6 +96,8 @@ else: _import_structure["quantizers.quantization_config"].append("TorchAoConfig") +_import_structure["quantizers.quantization_config"].append("FinegrainedFP8Config") + try: if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available(): raise OptionalDependencyNotAvailable() @@ -724,6 +726,8 @@ else: from .quantizers.quantization_config import QuantoConfig + from .quantizers.quantization_config import FinegrainedFP8Config + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 55ce0cf79fb9..638c5fbfb009 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1238,6 +1238,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P } # Dispatch model with hooks on all devices if necessary + print(model.transformer_blocks[0].attn.to_q.weight) + print(model.transformer_blocks[0].attn.to_q.weight_scale_inv) if device_map is not None: device_map_kwargs = { "device_map": device_map, diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index ce214ae7bc17..d4c5ce720124 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -28,9 +28,11 @@ QuantizationMethod, QuantoConfig, TorchAoConfig, + FinegrainedFP8Config, ) from .quanto import QuantoQuantizer from .torchao import TorchAoHfQuantizer +from .finegrained_fp8 import FinegrainedFP8Quantizer AUTO_QUANTIZER_MAPPING = { @@ -39,6 +41,7 @@ "gguf": GGUFQuantizer, "quanto": QuantoQuantizer, "torchao": TorchAoHfQuantizer, + "finegrained_fp8": FinegrainedFP8Quantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -47,6 +50,7 @@ "gguf": GGUFQuantizationConfig, "quanto": QuantoConfig, "torchao": TorchAoConfig, + "finegrained_fp8": FinegrainedFP8Config, } diff --git a/src/diffusers/quantizers/finegrained_fp8/__init__.py b/src/diffusers/quantizers/finegrained_fp8/__init__.py new file mode 100644 index 000000000000..1763207f09ce --- /dev/null +++ b/src/diffusers/quantizers/finegrained_fp8/__init__.py @@ -0,0 +1 @@ +from .finegrained_fp8_quantizer import FinegrainedFP8Quantizer \ No newline at end of file diff --git a/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py new file mode 100644 index 000000000000..5dec8b0b8699 --- /dev/null +++ b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py @@ -0,0 +1,206 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from ...utils import is_accelerate_available, is_torch_available, logging +from ..base import DiffusersQuantizer +from ...utils import get_module_from_name + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + +class FinegrainedFP8Quantizer(DiffusersQuantizer): + """ + FP8 quantization implementation supporting both standard and MoE models. + Supports both e4m3fn formats based on platform. 
+ """ + + requires_parameters_quantization = True + requires_calibration = False + required_packages = ["accelerate"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + self.quantization_config = quantization_config + + def validate_environment(self, *args, **kwargs): + if not is_torch_available(): + raise ImportError( + "Using fp8 quantization requires torch >= 2.1.0" + "Please install the latest version of torch ( pip install --upgrade torch )" + ) + + if not is_accelerate_available(): + raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)") + + if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): + raise ValueError( + "Converting into FP8 weights from tf/flax weights is currently not supported, " + "please make sure the weights are in PyTorch format." + ) + + if torch.cuda.is_available(): + compute_capability = torch.cuda.get_device_capability() + major, minor = compute_capability + if (major < 8) or (major == 8 and minor < 9): + raise ValueError( + "FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)" + f", actual = `{major}.{minor}`" + ) + + device_map = kwargs.get("device_map", None) + if device_map is None: + logger.warning_once( + "You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set " + "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. " + ) + elif device_map is not None: + if ( + not self.pre_quantized + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to load an FP8 model with a device_map that contains a cpu/disk device." + "This is not supported when the model is quantized on the fly. " + "Please use a quantized checkpoint or remove the cpu/disk device from the device_map." 
+ ) + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + logger.info("Setting torch_dtype to torch.float32 as no torch_dtype was specified in from_pretrained") + torch_dtype = torch.float32 + return torch_dtype + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + **kwargs, + ): + """ + Quantizes weights to FP8 format using Block-wise quantization + """ + # print("############ create quantized param ########") + from accelerate.utils import set_module_tensor_to_device + + set_module_tensor_to_device(model, param_name, target_device, param_value) + + module, tensor_name = get_module_from_name(model, param_name) + + # Get FP8 min/max values + fp8_min = torch.finfo(torch.float8_e4m3fn).min + fp8_max = torch.finfo(torch.float8_e4m3fn).max + + block_size_m, block_size_n = self.quantization_config.weight_block_size + + rows, cols = param_value.shape[-2:] + + if rows % block_size_m != 0 or cols % block_size_n != 0: + raise ValueError( + f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})" + ) + param_value_orig_shape = param_value.shape + + param_value = param_value.reshape( + -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n + ).permute(0, 1, 3, 2, 4) + + # Calculate scaling factor for each block + max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2)) + scale = fp8_max / max_abs + scale_orig_shape = scale.shape + scale = scale.unsqueeze(-1).unsqueeze(-1) + + # Quantize the weights + quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + quantized_param = quantized_param.permute(0, 1, 3, 2, 4) + # Reshape back to matrix shape + quantized_param = quantized_param.reshape(param_value_orig_shape) + + # Reshape scale to match the number of blocks + scale = scale.reshape(scale_orig_shape).squeeze().reciprocal() + + # Load into the model + module._buffers[tensor_name] = quantized_param.to(target_device) + module._buffers["weight_scale_inv"] = scale.to(target_device) + # print("_buffers[0]", module._buffers["weight_scale_inv"]) + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + from .utils import FP8Linear + + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module, FP8Linear): + if self.pre_quantized or tensor_name == "bias": + if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn: + raise ValueError("Expect quantized weights but got an unquantized weight") + return False + else: + if tensor_name == "weight_scale_inv": + raise ValueError("Expect unquantized weights but got a quantized weight_scale") + return True + return False + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + keep_in_fp32_modules: Optional[List[str]] = None, + **kwargs, + ): + from .utils import replace_with_fp8_linear + + if self.quantization_config.modules_to_not_convert is not None: + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) + + model = replace_with_fp8_linear( + model, + modules_to_not_convert=self.modules_to_not_convert, + quantization_config=self.quantization_config, + ) + + model.config.quantization_config = self.quantization_config + + def 
_process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): + return model + + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: + from .utils import FP8Linear + + not_missing_keys = [] + for name, module in model.named_modules(): + if isinstance(module, FP8Linear): + for missing in missing_keys: + if ( + (name in missing or name in f"{prefix}.{missing}") + and not missing.endswith(".weight") + and not missing.endswith(".bias") + ): + not_missing_keys.append(missing) + return [k for k in missing_keys if k not in not_missing_keys] + + def is_serializable(self, safe_serialization=None): + return True + + @property + def is_trainable(self) -> bool: + return False + + def get_cuda_warm_up_factor(self): + # Pre-processing is done cleanly, so we can allocate everything here + return 2 diff --git a/src/diffusers/quantizers/finegrained_fp8/utils.py b/src/diffusers/quantizers/finegrained_fp8/utils.py new file mode 100644 index 000000000000..5bc5cb984647 --- /dev/null +++ b/src/diffusers/quantizers/finegrained_fp8/utils.py @@ -0,0 +1,422 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +from ...utils import is_accelerate_available, is_torch_available, logging + + +if is_torch_available(): + import torch + import torch.nn as nn + import triton + import triton.language as tl + from torch.nn import functional as F + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +logger = logging.get_logger(__name__) + + +# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py +@triton.jit +def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + + +def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.is_contiguous() + assert x.shape[-1] % block_size == 0 + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + + def grid(meta): + return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),) + + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py +@triton.jit +def _w8a8_block_fp8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: 
tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and + store the result in output tensor `C`. + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def w8a8_block_fp8_matmul_triton( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise + quantization. + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should + be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + Returns: + torch.Tensor: The result of matmul. 
+ """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N,) + C = A.new_empty(C_shape, dtype=output_dtype) + + BLOCK_SIZE_M = 128 + if M < BLOCK_SIZE_M: + BLOCK_SIZE_M = triton.next_power_of_2(M) + BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16) + BLOCK_SIZE_K = block_k + assert block_k % BLOCK_SIZE_K == 0 + BLOCK_SIZE_N = block_n + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + ) + + return C + + +# Python version of the above triton function, it's much slower than the triton version, for testing +@torch.compile +def w8a8_block_fp8_matmul_compile( + input_q: torch.Tensor, # [batch, seq_len, hidden_dim] + weight_q: torch.Tensor, # [out_features, hidden_dim] + input_scale: torch.Tensor, # [batch * seq_len, num_input_groups] + weight_scale: torch.Tensor, # [num_weight_blocks_m, num_weight_blocks_n] + block_size: Optional[Tuple[int, int]] = None, # (M=128, N=128) for weights for example + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Performs blocked matrix multiplication with FP8 quantized matrices. 
+ + Args: + input_q: Quantized input tensor with 1x128 block quantization + weight_q: Quantized weight tensor with 128x128 block quantization + input_scale: Scaling factors for input blocks + weight_scale: Scaling factors for weight blocks + block_size: Tuple of (M, N) for weight block dimensions + output_dtype: Desired output dtype + """ + batch_size, seq_len, hidden_dim = input_q.shape if input_q.ndim == 3 else (1, input_q.shape[0], input_q.shape[1]) + out_features = weight_q.shape[0] + + # Reshape input for batched matmul + input_reshaped = input_q.view(-1, hidden_dim) # [batch*seq_len, hidden_dim] + input_scale_reshaped = input_scale.view(input_scale.shape[0], -1) # [batch*seq_len, 1] + # Calculate number of blocks + num_weight_blocks_m = out_features // block_size[0] + num_weight_blocks_n = hidden_dim // block_size[1] + + output = torch.zeros((batch_size * seq_len, out_features), dtype=torch.float32, device=input_q.device) + + for i in range(num_weight_blocks_m): + m_start = i * block_size[0] + m_end = m_start + block_size[0] + + for j in range(num_weight_blocks_n): + n_start = j * block_size[1] + n_end = n_start + block_size[1] + + # Extract current blocks + input_block = input_reshaped[:, n_start:n_end] + weight_block = weight_q[m_start:m_end, n_start:n_end] + + # Get corresponding scales + curr_input_scale = input_scale_reshaped[:, j : j + 1] # [batch*seq_len, 1] + curr_weight_scale = weight_scale[i, j] # scalar + + block_result = ( + torch._scaled_mm( + input_block, + weight_block.t(), + scale_a=torch.tensor(1, dtype=torch.float32, device=input_q.device), + scale_b=curr_weight_scale, + out_dtype=output_dtype, + ) + * curr_input_scale + ) + + output[:, m_start:m_end] += block_result + + output = output.view(batch_size, seq_len, out_features) + + return output.to(output_dtype) + + +class FP8Linear(nn.Linear): + dtype = torch.float8_e4m3fn + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + dtype=None, + block_size: Optional[Tuple[int, int]] = None, + device=None, + activation_scheme="dynamic", + ): + super().__init__(in_features, out_features) + self.in_features = in_features + self.out_features = out_features + + self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device)) + + if self.weight.element_size() == 1: + scale_out_features = (out_features + block_size[0] - 1) // block_size[0] + scale_in_features = (in_features + block_size[1] - 1) // block_size[1] + self.weight_scale_inv = nn.Parameter( + torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device) + ) + else: + self.register_parameter("weight_scale_inv", None) + + self.block_size = block_size + + self.activation_scheme = activation_scheme + + if bias: + self.bias = nn.Parameter(torch.empty(self.out_features)) + else: + self.register_parameter("bias", None) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.weight.element_size() > 1: + return F.linear(input, self.weight, self.bias) + else: + # Context manager used to switch among the available accelerators + torch_accelerator_module = getattr(torch, "cuda", torch.cuda) + with torch_accelerator_module.device(input.device): + qinput, scale = act_quant(input, self.block_size[1]) + output = w8a8_block_fp8_matmul_triton( + qinput, + self.weight, + scale, + self.weight_scale_inv, + self.block_size, + output_dtype=input.dtype, + ) + # Blocks the CPU until all accelerator operations on the specified device are complete. 
It is used to ensure that the results of the + # preceding operations are ready before proceeding + torch_accelerator_module.synchronize() + if self.bias is not None: + output = output + self.bias + return output.to(dtype=input.dtype) + + +def _replace_with_fp8_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + """Replace Linear layers with FP8Linear.""" + if current_key_name is None: + current_key_name = [] + + for name, module in model.named_children(): + current_key_name.append(name) + + if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []): + current_key_name_str = ".".join(current_key_name) + if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): + with init_empty_weights(): + model._modules[name] = FP8Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + device=module.weight.device, + dtype=module.weight.dtype, + activation_scheme=quantization_config.activation_scheme, + block_size=quantization_config.weight_block_size, + ) + has_been_replaced = True + # when changing a layer the TP PLAN for that layer should be updated. TODO + + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_fp8_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + + current_key_name.pop(-1) + + return model, has_been_replaced + + +def replace_with_fp8_linear( + model, + modules_to_not_convert=None, + quantization_config=None, +): + """Helper function to replace model layers with FP8 versions.""" + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + + if quantization_config.modules_to_not_convert is not None: + modules_to_not_convert.extend(quantization_config.modules_to_not_convert) + modules_to_not_convert = list(set(modules_to_not_convert)) + model, has_been_replaced = _replace_with_fp8_linear( + model, + modules_to_not_convert=modules_to_not_convert, + quantization_config=quantization_config, + ) + + if not has_been_replaced: + logger.warning( + "You are loading your model using fp8 but no linear modules were found in your model." + " Please double check your model architecture." + ) + + return model diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 609c9ad15a31..dc9b6c7b9ef8 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -28,7 +28,7 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Tuple from packaging import version @@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum): GGUF = "gguf" TORCHAO = "torchao" QUANTO = "quanto" + FINEGRAINED_FP8 = "finegrained_fp8" if is_torchao_available(): @@ -722,3 +723,44 @@ def post_init(self): accepted_weights = ["float8", "int8", "int4", "int2"] if self.weights_dtype not in accepted_weights: raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") + + + +@dataclass +class FinegrainedFP8Config(QuantizationConfigMixin): + """ + FinegrainedFP8Config is a configuration class for fine-grained FP8 quantization used mainly for deepseek models. 
+ + Args: + activation_scheme (`str`, *optional*, defaults to `"dynamic"`): + The scheme used for activation, the defaults and only support scheme for now is "dynamic". + weight_block_size (`typing.Tuple[int, int]`, *optional*, defaults to `(128, 128)`): + The size of the weight blocks for quantization, default is (128, 128). + modules_to_not_convert (`list`, *optional*): + A list of module names that should not be converted during quantization. + """ + + def __init__( + self, + activation_scheme: str = "dynamic", + weight_block_size: Tuple[int, int] = (128, 128), + modules_to_not_convert: Optional[List] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.FINEGRAINED_FP8 + self.modules_to_not_convert = modules_to_not_convert + self.activation_scheme = activation_scheme + self.weight_block_size = weight_block_size + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + self.activation_scheme = self.activation_scheme.lower() + if self.activation_scheme not in ["dynamic"]: + raise ValueError(f"Activation scheme {self.activation_scheme} not supported") + if len(self.weight_block_size) != 2: + raise ValueError("weight_block_size must be a tuple of two integers") + if self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0: + raise ValueError("weight_block_size must be a tuple of two positive integers") From 83163bcd5adf540380a626c5547ba005c89b3e4e Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 4 Jun 2025 12:00:13 +0000 Subject: [PATCH 02/10] buffer to parameter --- .../quantizers/finegrained_fp8/finegrained_fp8_quantizer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py index 5dec8b0b8699..8e91060d909b 100644 --- a/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py +++ b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py @@ -131,9 +131,8 @@ def create_quantized_param( scale = scale.reshape(scale_orig_shape).squeeze().reciprocal() # Load into the model - module._buffers[tensor_name] = quantized_param.to(target_device) - module._buffers["weight_scale_inv"] = scale.to(target_device) - # print("_buffers[0]", module._buffers["weight_scale_inv"]) + module._parameters[tensor_name] = quantized_param.to(target_device) + module._parameters["weight_scale_inv"] = scale.to(target_device) def check_if_quantized_param( self, From e97e2ef1157e6957cfdf5048075a07189903989a Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 12 Jun 2025 12:07:16 +0000 Subject: [PATCH 03/10] resolve nan issue --- src/diffusers/models/modeling_utils.py | 2 -- src/diffusers/quantizers/finegrained_fp8/utils.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 638c5fbfb009..55ce0cf79fb9 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1238,8 +1238,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P } # Dispatch model with hooks on all devices if necessary - print(model.transformer_blocks[0].attn.to_q.weight) - print(model.transformer_blocks[0].attn.to_q.weight_scale_inv) if device_map is not None: device_map_kwargs = { "device_map": device_map, diff --git a/src/diffusers/quantizers/finegrained_fp8/utils.py b/src/diffusers/quantizers/finegrained_fp8/utils.py index 
5bc5cb984647..1d3c297d70bb 100644 --- a/src/diffusers/quantizers/finegrained_fp8/utils.py +++ b/src/diffusers/quantizers/finegrained_fp8/utils.py @@ -38,7 +38,7 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offs).to(tl.float32) - s = tl.max(tl.abs(x)) / 448.0 + s = tl.max(tl.abs(x)) / 448.0 + 1e-6 y = x / s y = y.to(y_ptr.dtype.element_ty) tl.store(y_ptr + offs, y) From 33b2be7af7a5f859749b3b0c131882c0eba33cc4 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 12 Jun 2025 12:45:08 +0000 Subject: [PATCH 04/10] fix style --- src/diffusers/quantizers/auto.py | 4 ++-- src/diffusers/quantizers/finegrained_fp8/__init__.py | 2 +- .../quantizers/finegrained_fp8/finegrained_fp8_quantizer.py | 4 ++-- src/diffusers/quantizers/quantization_config.py | 3 +-- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index d4c5ce720124..6f9b1e6a26a0 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -20,19 +20,19 @@ from typing import Dict, Optional, Union from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer +from .finegrained_fp8 import FinegrainedFP8Quantizer from .gguf import GGUFQuantizer from .quantization_config import ( BitsAndBytesConfig, + FinegrainedFP8Config, GGUFQuantizationConfig, QuantizationConfigMixin, QuantizationMethod, QuantoConfig, TorchAoConfig, - FinegrainedFP8Config, ) from .quanto import QuantoQuantizer from .torchao import TorchAoHfQuantizer -from .finegrained_fp8 import FinegrainedFP8Quantizer AUTO_QUANTIZER_MAPPING = { diff --git a/src/diffusers/quantizers/finegrained_fp8/__init__.py b/src/diffusers/quantizers/finegrained_fp8/__init__.py index 1763207f09ce..107f58184ee4 100644 --- a/src/diffusers/quantizers/finegrained_fp8/__init__.py +++ b/src/diffusers/quantizers/finegrained_fp8/__init__.py @@ -1 +1 @@ -from .finegrained_fp8_quantizer import FinegrainedFP8Quantizer \ No newline at end of file + src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.pyfrom .finegrained_fp8_quantizer import FinegrainedFP8Quantizer diff --git a/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py index 8e91060d909b..97c9df7979e3 100644 --- a/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py +++ b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py @@ -1,8 +1,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional -from ...utils import is_accelerate_available, is_torch_available, logging +from ...utils import get_module_from_name, is_accelerate_available, is_torch_available, logging from ..base import DiffusersQuantizer -from ...utils import get_module_from_name if is_torch_available(): @@ -13,6 +12,7 @@ if TYPE_CHECKING: from ...models.modeling_utils import ModelMixin + class FinegrainedFP8Quantizer(DiffusersQuantizer): """ FP8 quantization implementation supporting both standard and MoE models. 
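The `resolve nan issue` commit above adds a `1e-6` epsilon to the per-group activation scale in `act_quant_kernel`. A plain-PyTorch reference sketch (an illustration, not part of the patch) makes the failure mode explicit: an all-zero 128-element group otherwise produces a zero scale, and the subsequent `0 / 0` division yields NaN activations.

```python
# Non-Triton reference for per-group activation quantization, assuming the same
# 1x128 grouping and float8_e4m3fn range (max |x| = 448) as act_quant in utils.py.
import torch

def act_quant_ref(x: torch.Tensor, block_size: int = 128, eps: float = 1e-6):
    assert x.shape[-1] % block_size == 0
    x_blocks = x.float().reshape(*x.shape[:-1], -1, block_size)
    # Per-group scale; without eps an all-zero group gives s == 0 and 0 / 0 -> NaN,
    # which is what the "resolve nan issue" commit guards against in the Triton kernel.
    s = x_blocks.abs().amax(dim=-1, keepdim=True) / 448.0 + eps
    y = (x_blocks / s).to(torch.float8_e4m3fn).reshape_as(x)
    return y, s.squeeze(-1)
```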
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index dc9b6c7b9ef8..d38cad104057 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -28,7 +28,7 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Dict, List, Optional, Union, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from packaging import version @@ -725,7 +725,6 @@ def post_init(self): raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") - @dataclass class FinegrainedFP8Config(QuantizationConfigMixin): """ From 0eb39894a10bdb7af1e91688689d4c856dcc5528 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 12 Jun 2025 12:49:37 +0000 Subject: [PATCH 05/10] update --- src/diffusers/quantizers/finegrained_fp8/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/finegrained_fp8/__init__.py b/src/diffusers/quantizers/finegrained_fp8/__init__.py index 107f58184ee4..83cb327c08a4 100644 --- a/src/diffusers/quantizers/finegrained_fp8/__init__.py +++ b/src/diffusers/quantizers/finegrained_fp8/__init__.py @@ -1 +1 @@ - src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.pyfrom .finegrained_fp8_quantizer import FinegrainedFP8Quantizer +from .finegrained_fp8_quantizer import FinegrainedFP8Quantizer From e6e0c4aa9d9d7814a0a3d1c6baa1172a62da6beb Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 12 Jun 2025 14:02:16 +0000 Subject: [PATCH 06/10] fix shapes --- .../finegrained_fp8/finegrained_fp8_quantizer.py | 10 +++++----- src/diffusers/quantizers/finegrained_fp8/utils.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py index 97c9df7979e3..4f53b5b712f1 100644 --- a/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py +++ b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py @@ -106,13 +106,13 @@ def create_quantized_param( if rows % block_size_m != 0 or cols % block_size_n != 0: raise ValueError( - f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})" + f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n} for {param_name})" ) param_value_orig_shape = param_value.shape param_value = param_value.reshape( - -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n - ).permute(0, 1, 3, 2, 4) + rows // block_size_m, block_size_m, cols // block_size_n, block_size_n + ).permute(0, 2, 1, 3) # Calculate scaling factor for each block max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2)) @@ -123,12 +123,12 @@ def create_quantized_param( # Quantize the weights quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - quantized_param = quantized_param.permute(0, 1, 3, 2, 4) + quantized_param = quantized_param.permute(0, 2, 1, 3) # Reshape back to matrix shape quantized_param = quantized_param.reshape(param_value_orig_shape) # Reshape scale to match the number of blocks - scale = scale.reshape(scale_orig_shape).squeeze().reciprocal() + scale = scale.reshape(scale_orig_shape).reciprocal() # Load into the model module._parameters[tensor_name] = quantized_param.to(target_device) diff --git 
a/src/diffusers/quantizers/finegrained_fp8/utils.py b/src/diffusers/quantizers/finegrained_fp8/utils.py index 1d3c297d70bb..30c510d795d1 100644 --- a/src/diffusers/quantizers/finegrained_fp8/utils.py +++ b/src/diffusers/quantizers/finegrained_fp8/utils.py @@ -175,7 +175,7 @@ def w8a8_block_fp8_matmul_triton( assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] M = A.numel() // A.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2, f"B {B.shape} and Bs {Bs.shape}" N, K = B.shape assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] From 243173f7ae9e7461268f04502d366d91ec4dcea8 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 12 Jun 2025 14:06:47 +0000 Subject: [PATCH 07/10] small fix --- .../quantizers/finegrained_fp8/finegrained_fp8_quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py index 4f53b5b712f1..1ce7cda317b1 100644 --- a/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py +++ b/src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py @@ -106,7 +106,7 @@ def create_quantized_param( if rows % block_size_m != 0 or cols % block_size_n != 0: raise ValueError( - f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n} for {param_name})" + f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n}) for {param_name}" ) param_value_orig_shape = param_value.shape From c3cf41fe62f3c2c48fe6999725bed22c08c01aa9 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Fri, 13 Jun 2025 12:04:26 +0000 Subject: [PATCH 08/10] adding some tests --- .../quantizers/quantization_config.py | 15 + .../test_finegrained_fp8.py | 574 ++++++++++++++++++ 2 files changed, 589 insertions(+) create mode 100644 tests/quantization/finegrained_fp8.py/test_finegrained_fp8.py diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index d38cad104057..b39cff54f824 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -737,6 +737,19 @@ class FinegrainedFP8Config(QuantizationConfigMixin): The size of the weight blocks for quantization, default is (128, 128). modules_to_not_convert (`list`, *optional*): A list of module names that should not be converted during quantization. 
+ + Example: + ```python + from diffusers import FluxTransformer2DModel, FinegrainedFP8Config + + quantization_config = FinegrainedFP8Config() + transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/Flux.1-Dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + ``` """ def __init__( @@ -759,7 +772,9 @@ def post_init(self): self.activation_scheme = self.activation_scheme.lower() if self.activation_scheme not in ["dynamic"]: raise ValueError(f"Activation scheme {self.activation_scheme} not supported") + if len(self.weight_block_size) != 2: raise ValueError("weight_block_size must be a tuple of two integers") + if self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0: raise ValueError("weight_block_size must be a tuple of two positive integers") diff --git a/tests/quantization/finegrained_fp8.py/test_finegrained_fp8.py b/tests/quantization/finegrained_fp8.py/test_finegrained_fp8.py new file mode 100644 index 000000000000..420677775d76 --- /dev/null +++ b/tests/quantization/finegrained_fp8.py/test_finegrained_fp8.py @@ -0,0 +1,574 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest +from typing import List + +import numpy as np +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, + TorchAoConfig, +) +from diffusers.models.attention_processor import Attention +from diffusers.utils.testing_utils import ( + enable_full_determinism, + is_torch_available, + nightly, + numpy_cosine_similarity_distance, + require_torch, + require_torch_gpu, + require_torchao_version_greater_or_equal, + slow, + torch_device, +) + +from diffusers.quantizers.quantization_config import FinegrainedFP8Config +from diffusers.quantizers.finegrained_fp8.utils import FP8Linear + + +enable_full_determinism() + + +if is_torch_available(): + import torch + import torch.nn as nn + + from ..utils import LoRALayer, get_memory_consumption_stat + + +@require_torch +@require_torch_gpu +class FinegrainedFP8ConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Makes sure the config format is properly set + """ + quantization_config = FinegrainedFP8Config() + finegrained_fp8_orig_config = quantization_config.to_dict() + + for key in finegrained_fp8_orig_config: + self.assertEqual(getattr(quantization_config, key), finegrained_fp8_orig_config[key]) + + def test_post_init_check(self): + """ + Test kwargs validations in FinegrainedFP8Config + """ + _ = FinegrainedFP8Config() + with self.assertRaisesRegex(ValueError, "weight_block_size must be a tuple of two integers"): + _ = FinegrainedFP8Config(weight_block_size=(1, 32, 32)) + + with self.assertRaisesRegex(ValueError, "weight_block_size must be a tuple of two positive integers"): + _ = 
FinegrainedFP8Config(weight_block_size=(0, 32)) + + +@require_torch +@require_torch_gpu +class FinegrainedFP8Test(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_components( + self, quantization_config: FinegrainedFP8Config, model_id: str = "hf-internal-testing/tiny-flux-pipe" + ): + transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16) + text_encoder_2 = T5EncoderModel.from_pretrained( + model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16 + ) + tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") + tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16) + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + def get_model_size_in_bytes(self, model, ignore_embeddings=False): + """ + Returns the model size in bytes. The option to ignore embeddings + is useful for models with disproportionately large embeddings compared + to other model parameters that get quantized/sparsified. + """ + model_size = 0 + for name, child in model.named_children(): + if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings): + for p in child.parameters(recurse=False): + model_size += p.numel() * p.element_size() + for b in child.buffers(recurse=False): + model_size += b.numel() * b.element_size() + model_size += self.get_model_size_in_bytes(child, ignore_embeddings) + return model_size + + def get_dummy_inputs(self, device: torch.device, seed: int = 0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + inputs = { + "prompt": "an astronaut riding a horse in space", + "height": 32, + "width": 32, + "num_inference_steps": 2, + "output_type": "np", + "generator": generator, + } + + return inputs + + def get_dummy_tensor_inputs(self, device=None, seed: int = 0): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + torch.manual_seed(seed) + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( + device, dtype=torch.bfloat16 + ) + + torch.manual_seed(seed) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) + + timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "txt_ids": text_ids, + "img_ids": image_ids, + "timestep": timestep, + } + + def _test_quantization_output(self, 
quantization_config: FinegrainedFP8Config, expected_slice: List[float], model_id: str): + components = self.get_dummy_components(quantization_config, model_id) + pipe = FluxPipeline(**components) + pipe.to(device=torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + output_slice = output[-1, -1, -3:, -3:].flatten() + + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_quantization(self): + expected_slice = [np.array([0.34179688, -0.03613281, 0.01428223, -0.22949219, -0.49609375, 0.4375, -0.1640625, -0.66015625, 0.43164062]), np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])] + for index, model_id in enumerate(["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]): + quantization_config = FinegrainedFP8Config( + modules_to_not_convert=["x_embedder", "proj_out"], + weight_block_size=(32, 32) + ) + self._test_quantization_output(quantization_config, model_id, expected_slice[index]) + + def test_dtype(self): + """ + Tests whether the dtype of the weight and weight_scale_inv are correct + """ + quantization_config = FinegrainedFP8Config( + modules_to_not_convert=["x_embedder", "proj_out"], + weight_block_size=(32, 32) + ) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + + weight = quantized_model.transformer_blocks[0].ff.net[2].weight + weight_scale_inv = quantized_model.transformer_blocks[0].ff.net[2].weight_scale_inv + self.assertTrue(isinstance(weight, FP8Linear)) + self.assertEqual(weight_scale_inv.dtype, torch.bfloat16) + self.assertEqual(weight.weight.dtype, torch.float8_e4m3fn) + + def test_device_map_auto(self): + """ + Test if the quantized model is working properly with "auto" + """ + + inputs = self.get_dummy_tensor_inputs(torch_device) + # requires with different expected slices since models are different due to offload (we don't quantize modules offloaded to cpu/disk) + expected_slice_auto = np.array( + [ + 0.34179688, + -0.03613281, + 0.01428223, + -0.22949219, + -0.49609375, + 0.4375, + -0.1640625, + -0.66015625, + 0.43164062, + ] + ) + + quantization_config = FinegrainedFP8Config( + modules_to_not_convert=["x_embedder", "proj_out"], + weight_block_size=(32, 32) + ) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + device_map="auto", + torch_dtype=torch.bfloat16, + ) + + output = quantized_model(**inputs)[0] + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice_auto) < 1e-3) + + + def test_modules_to_not_convert(self): + quantization_config = FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["transformer_blocks.0", "proj_out", "x_embedder"]) + quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + + unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2] + self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear)) + self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16) + + quantization_config = FinegrainedFP8Config( + 
modules_to_not_convert=["x_embedder", "proj_out"], + weight_block_size=(32, 32) + ) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + + size_quantized_with_not_convert = self.get_model_size_in_bytes(quantized_model_with_not_convert) + size_quantized = self.get_model_size_in_bytes(quantized_model) + + self.assertTrue(size_quantized < size_quantized_with_not_convert) + + def test_training(self): + quantization_config = FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"]) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + + for param in quantized_model.parameters(): + # freeze the model as only adapter layers will be trained + param.requires_grad = False + if param.ndim == 1: + param.data = param.data.to(torch.float32) + + for _, module in quantized_model.named_modules(): + if isinstance(module, Attention): + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16): + inputs = self.get_dummy_tensor_inputs(torch_device) + output = quantized_model(**inputs)[0] + output.norm().backward() + + for module in quantized_model.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + + @nightly + def test_torch_compile(self): + r"""Test that verifies if torch.compile works with torchao quantization.""" + for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: + quantization_config = FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"]) + components = self.get_dummy_components(quantization_config, model_id=model_id) + pipe = FluxPipeline(**components) + pipe.to(device=torch_device) + + inputs = self.get_dummy_inputs(torch_device) + normal_output = pipe(**inputs)[0].flatten()[-32:] + + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False) + inputs = self.get_dummy_inputs(torch_device) + compile_output = pipe(**inputs)[0].flatten()[-32:] + + # Note: Seems to require higher tolerance + self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: + transformer_quantized = self.get_dummy_components(FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"]), model_id=model_id)["transformer"] + transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"] + + for name, module in transformer_quantized.named_modules(): + if isinstance(module, nn.Linear) and name not in ["x_embedder", "proj_out"]: + self.assertTrue(isinstance(module.weight, FP8Linear)) + + + total_quantized = 
self.get_model_size_in_bytes(transformer_quantized) + total_bf16 = self.get_model_size_in_bytes(transformer_bf16) + + self.assertTrue(total_quantized < total_bf16) + + def test_model_memory_usage(self): + model_id = "hf-internal-testing/tiny-flux-pipe" + expected_memory_saving_ratio = 2.0 + + inputs = self.get_dummy_tensor_inputs(device=torch_device) + + transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"] + transformer_bf16.to(torch_device) + unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs) + del transformer_bf16 + + transformer_quantized = self.get_dummy_components(FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"]), model_id=model_id)["transformer"] + transformer_quantized.to(torch_device) + quantized_model_memory = get_memory_consumption_stat(transformer_quantized, inputs) + self.assertTrue(unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio) + + def test_exception_of_cpu_in_device_map(self): + r""" + A test that checks if inference runs as expected when sequential cpu offloading is enabled. + """ + quantization_config = FinegrainedFP8Config( + modules_to_not_convert=["x_embedder", "proj_out"], + weight_block_size=(32, 32) + ) + device_map = {"transformer_blocks.0": "cpu"} + + with self.assertRaisesRegex(ValueError, "You are attempting to load an FP8 model with a device_map that contains a cpu/disk device."): + _ = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map, + torch_dtype=torch.bfloat16, + ) + + + +@require_torch +@require_torch_gpu +class TorchAoSerializationTest(unittest.TestCase): + model_name = "hf-internal-testing/tiny-flux-pipe" + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_model(self, device=None): + quantization_config = FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"]) + quantized_model = FluxTransformer2DModel.from_pretrained( + self.model_name, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + return quantized_model.to(device) + + def get_dummy_tensor_inputs(self, device=None, seed: int = 0): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + torch.manual_seed(seed) + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( + device, dtype=torch.bfloat16 + ) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) + text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) + image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) + timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "txt_ids": text_ids, + "img_ids": image_ids, + "timestep": timestep, + } + + def _test_original_model_expected_slice(self, expected_slice): + quantized_model = self.get_dummy_model(torch_device) + inputs = self.get_dummy_tensor_inputs(torch_device) + output = 
+        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+
+    def _check_serialization_expected_slice(self, expected_slice, device):
+        quantized_model = self.get_dummy_model(device)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
+            loaded_quantized_model = FluxTransformer2DModel.from_pretrained(
+                tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
+            ).to(device=torch_device)
+
+        inputs = self.get_dummy_tensor_inputs(torch_device)
+        output = loaded_quantized_model(**inputs)[0]
+
+        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+
+    def test_slice_output(self):
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        device = "cuda"
+        self._test_original_model_expected_slice(expected_slice)
+        self._check_serialization_expected_slice(expected_slice, device)
+
+
+@require_torch
+@require_torch_gpu
+@slow
+@nightly
+class SlowTorchAoTests(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_dummy_components(self, quantization_config: FinegrainedFP8Config):
+        # This is just for convenience, so that we can modify it at one place for custom environments and locally testing
+        cache_dir = None
+        model_id = "black-forest-labs/FLUX.1-dev"
+        transformer = FluxTransformer2DModel.from_pretrained(
+            model_id,
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+            cache_dir=cache_dir,
+        )
+        text_encoder = CLIPTextModel.from_pretrained(
+            model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
+        )
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
+        )
+        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
+        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        return {
+            "scheduler": scheduler,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "transformer": transformer,
+            "vae": vae,
+        }
+
+    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator().manual_seed(seed)
+
+        inputs = {
+            "prompt": "an astronaut riding a horse in space",
+            "height": 512,
+            "width": 512,
+            "num_inference_steps": 20,
+            "output_type": "np",
+            "generator": generator,
+        }
+
+        return inputs
+
+    def _test_quant_output(self, quantization_config, expected_slice):
+        components = self.get_dummy_components(quantization_config)
+        pipe = FluxPipeline(**components)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output = pipe(**inputs)[0].flatten()
+        output_slice = np.concatenate((output[:16], output[-16:]))
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+    def test_quantization(self):
+        # fmt: off
+        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+        # fmt: on
+
+        quantization_config = FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"])
+        self._test_quant_output(quantization_config, expected_slice)
+
+    def test_serialization(self):
+        quantization_config = FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"])
+        components = self.get_dummy_components(quantization_config)
+        pipe = FluxPipeline(**components)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output = pipe(**inputs)[0].flatten()[:128]
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False)
+            pipe.remove_all_hooks()
+            del pipe.transformer
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+            transformer = FluxTransformer2DModel.from_pretrained(
+                tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
+            )
+            pipe.transformer = transformer
+
+        loaded_output = pipe(**inputs)[0].flatten()[:128]
+        # Seems to require higher tolerance depending on which machine it is being run.
+        # A difference of 0.06 in normalized pixel space (-1 to 1), corresponds to a difference of
+        # 0.06 / 2 * 255 = 7.65 in pixel space (0 to 255). On our CI runners, the difference is about 0.04,
+        # on DGX it is 0.06, and on audace it is 0.037. So, we are using a tolerance of 0.06 here.
+        self.assertTrue(np.allclose(output, loaded_output, atol=0.06))
\ No newline at end of file

From c7072fea08d515b3b249ffccbd00f8212bf0f8fa Mon Sep 17 00:00:00 2001
From: MekkCyber
Date: Fri, 13 Jun 2025 12:30:08 +0000
Subject: [PATCH 09/10] rename

---
 .../test_finegrained_fp8.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename tests/quantization/{finegrained_fp8.py => finegrained_fp8}/test_finegrained_fp8.py (100%)

diff --git a/tests/quantization/finegrained_fp8.py/test_finegrained_fp8.py b/tests/quantization/finegrained_fp8/test_finegrained_fp8.py
similarity index 100%
rename from tests/quantization/finegrained_fp8.py/test_finegrained_fp8.py
rename to tests/quantization/finegrained_fp8/test_finegrained_fp8.py

From f54dcc0f0e4b9f91df5a8724259b27fb2eb89593 Mon Sep 17 00:00:00 2001
From: MekkCyber
Date: Mon, 16 Jun 2025 07:55:51 +0000
Subject: [PATCH 10/10] update tests

---
 .../finegrained_fp8/test_finegrained_fp8.py | 123 +++++++++++++-----
 1 file changed, 88 insertions(+), 35 deletions(-)

diff --git a/tests/quantization/finegrained_fp8/test_finegrained_fp8.py b/tests/quantization/finegrained_fp8/test_finegrained_fp8.py
index 420677775d76..3db5cd60337a 100644
--- a/tests/quantization/finegrained_fp8/test_finegrained_fp8.py
+++ b/tests/quantization/finegrained_fp8/test_finegrained_fp8.py
@@ -26,8 +26,8 @@
     FlowMatchEulerDiscreteScheduler,
     FluxPipeline,
     FluxTransformer2DModel,
-    TorchAoConfig,
 )
+
 from diffusers.models.attention_processor import Attention
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
@@ -36,7 +36,6 @@
     numpy_cosine_similarity_distance,
     require_torch,
     require_torch_gpu,
-    require_torchao_version_greater_or_equal,
     slow,
     torch_device,
 )
@@ -52,7 +51,7 @@
     import torch
     import torch.nn as nn
 
-    from ..utils import LoRALayer, get_memory_consumption_stat
+    from tests.quantization.utils import LoRALayer, get_memory_consumption_stat
 
 
 @require_torch
@@ -95,14 +94,23 @@ def get_dummy_components(
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=torch_device,
+        )
+        text_encoder = CLIPTextModel.from_pretrained(
+            model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
         text_encoder_2 = T5EncoderModel.from_pretrained(
             model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
         )
-        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
-        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
+        tokenizer = CLIPTokenizer.from_pretrained(
+            model_id, subfolder="tokenizer"
+        )
+        tokenizer_2 = AutoTokenizer.from_pretrained(
+            model_id, subfolder="tokenizer_2"
+        )
+        vae = AutoencoderKL.from_pretrained(
+            model_id, subfolder="vae", torch_dtype=torch.bfloat16
+        )
         scheduler = FlowMatchEulerDiscreteScheduler()
 
         return {
@@ -195,13 +203,18 @@ def _test_quantization_output(self, quantization_config: FinegrainedFP8Config, e
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_quantization(self):
-        expected_slice = [np.array([0.34179688, -0.03613281, 0.01428223, -0.22949219, -0.49609375, 0.4375, -0.1640625, -0.66015625, 0.43164062]), np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])]
+        expected_slice = [
+            np.array([0.46679688, 0.51953125, 0.5546875, 0.421875, 0.44140625, 0.64453125, 0.43359375, 0.453125, 0.5625]),
+            np.array([0.46679688, 0.51953125, 0.5546875, 0.421875, 0.44140625, 0.64453125, 0.43359375, 0.453125, 0.5625])
+        ]
+
         for index, model_id in enumerate(["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]):
             quantization_config = FinegrainedFP8Config(
                 modules_to_not_convert=["x_embedder", "proj_out"],
                 weight_block_size=(32, 32)
             )
-            self._test_quantization_output(quantization_config, model_id, expected_slice[index])
+
+            self._test_quantization_output(quantization_config, expected_slice[index], model_id)
 
     def test_dtype(self):
         """
@@ -216,13 +229,18 @@ def test_dtype(self):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=torch_device,
         )
+
+        layer = quantized_model.transformer_blocks[0].ff.net[2]
 
         weight = quantized_model.transformer_blocks[0].ff.net[2].weight
         weight_scale_inv = quantized_model.transformer_blocks[0].ff.net[2].weight_scale_inv
-        self.assertTrue(isinstance(weight, FP8Linear))
+
+        self.assertTrue(isinstance(layer, FP8Linear))
+
         self.assertEqual(weight_scale_inv.dtype, torch.bfloat16)
-        self.assertEqual(weight.weight.dtype, torch.float8_e4m3fn)
+
+        self.assertEqual(weight.dtype, torch.float8_e4m3fn)
 
     def test_device_map_auto(self):
         """
@@ -232,16 +250,16 @@ def test_device_map_auto(self):
         inputs = self.get_dummy_tensor_inputs(torch_device)
         # requires with different expected slices since models are different due to offload (we don't quantize modules offloaded to cpu/disk)
         expected_slice_auto = np.array(
-            [
-                0.34179688,
-                -0.03613281,
-                0.01428223,
-                -0.22949219,
-                -0.49609375,
+            [
+                0.34375,
+                -0.0402832,
+                0.01226807,
+                -0.22851562,
+                -0.49414062,
                 0.4375,
-                -0.1640625,
+                -0.16992188,
                 -0.66015625,
-                0.43164062,
+                0.43164062
             ]
         )
@@ -259,6 +277,7 @@ def test_device_map_auto(self):
 
         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
         self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice_auto) < 1e-3)
 
 
@@ -269,6 +288,7 @@ def test_modules_to_not_convert(self):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=torch_device,
         )
 
         unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
@@ -284,6 +304,7 @@ def test_modules_to_not_convert(self):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=torch_device,
         )
 
         size_quantized_with_not_convert = self.get_model_size_in_bytes(quantized_model_with_not_convert)
@@ -298,7 +319,8 @@ def test_training(self):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
-        ).to(torch_device)
+            device_map=torch_device,
+        )
 
         for param in quantized_model.parameters():
             # freeze the model as only adapter layers will be trained
@@ -321,10 +343,10 @@ def test_training(self):
             if isinstance(module, LoRALayer):
                 self.assertTrue(module.adapter[1].weight.grad is not None)
                 self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
-
+
     @nightly
     def test_torch_compile(self):
-        r"""Test that verifies if torch.compile works with torchao quantization."""
+        r"""Test that verifies if torch.compile works with fp8 quantization."""
         for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
             quantization_config = FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"])
             components = self.get_dummy_components(quantization_config, model_id=model_id)
@@ -350,9 +372,14 @@ def test_memory_footprint(self):
             transformer_quantized = self.get_dummy_components(FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"]), model_id=model_id)["transformer"]
             transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
 
-            for name, module in transformer_quantized.named_modules():
-                if isinstance(module, nn.Linear) and name not in ["x_embedder", "proj_out"]:
-                    self.assertTrue(isinstance(module.weight, FP8Linear))
+            for (name, module_quantized), (name_bf16, module_bf16) in zip(transformer_quantized.named_modules(), transformer_bf16.named_modules()):
+                if isinstance(module_bf16, nn.Linear) and name_bf16.split(".")[-1] not in ["x_embedder", "proj_out"]:
+
+                    self.assertTrue(isinstance(module_quantized, FP8Linear))
+
+                    self.assertEqual(module_quantized.weight.shape, module_bf16.weight.shape)
+
+                    self.assertEqual(name, name_bf16)
 
 
             total_quantized = self.get_model_size_in_bytes(transformer_quantized)
@@ -373,7 +400,9 @@ def test_model_memory_usage(self):
 
         transformer_quantized = self.get_dummy_components(FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"]), model_id=model_id)["transformer"]
         transformer_quantized.to(torch_device)
+
        quantized_model_memory = get_memory_consumption_stat(transformer_quantized, inputs)
+
         self.assertTrue(unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio)
 
     def test_exception_of_cpu_in_device_map(self):
@@ -413,6 +442,7 @@ def get_dummy_model(self, device=None):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=torch_device,
         )
         return quantized_model.to(device)
 
@@ -448,6 +478,7 @@ def _test_original_model_expected_slice(self, expected_slice):
         inputs = self.get_dummy_tensor_inputs(torch_device)
         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
         self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
 
     def _check_serialization_expected_slice(self, expected_slice, device):
@@ -463,18 +494,34 @@ def _check_serialization_expected_slice(self, expected_slice, device):
         output = loaded_quantized_model(**inputs)[0]
 
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
         self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
 
     def test_slice_output(self):
-        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-        device = "cuda"
+        expected_slice = np.array(
+            [
+                0.34960938,
+                -0.12109375,
+                -0.02648926,
+                -0.25195312,
+                -0.45898438,
+                0.49609375,
+                -0.14453125,
+                -0.69921875,
+                0.44921875
+            ]
+        )
+
+        device = torch_device
+
         self._test_original_model_expected_slice(expected_slice)
+
         self._check_serialization_expected_slice(expected_slice, device)
 
 
 @require_torch
 @require_torch_gpu
 @slow
-@nightly
+
 class SlowTorchAoTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
@@ -490,6 +537,7 @@ def get_dummy_components(self, quantization_config: FinegrainedFP8Config):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
             cache_dir=cache_dir,
+            device_map=torch_device,
         )
         text_encoder = CLIPTextModel.from_pretrained(
             model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
@@ -497,9 +545,15 @@ def get_dummy_components(self, quantization_config: FinegrainedFP8Config):
         text_encoder_2 = T5EncoderModel.from_pretrained(
             model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
         )
-        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
-        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
+        tokenizer = CLIPTokenizer.from_pretrained(
+            model_id, subfolder="tokenizer", cache_dir=cache_dir
+        )
+        tokenizer_2 = AutoTokenizer.from_pretrained(
+            model_id, subfolder="tokenizer_2", cache_dir=cache_dir
+        )
+        vae = AutoencoderKL.from_pretrained(
+            model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir
+        )
         scheduler = FlowMatchEulerDiscreteScheduler()
 
         return {
@@ -536,12 +590,11 @@ def _test_quant_output(self, quantization_config, expected_slice):
         inputs = self.get_dummy_inputs(torch_device)
         output = pipe(**inputs)[0].flatten()
         output_slice = np.concatenate((output[:16], output[-16:]))
+
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_quantization(self):
-        # fmt: off
-        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-        # fmt: on
+        expected_slice = np.array([0., -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
 
         quantization_config = FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"])
         self._test_quant_output(quantization_config, expected_slice)
@@ -562,7 +615,7 @@ def test_serialization(self):
             torch.cuda.empty_cache()
             torch.cuda.synchronize()
             transformer = FluxTransformer2DModel.from_pretrained(
-                tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
+                tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False, device_map=torch_device
             )
             pipe.transformer = transformer
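For readers following these tests, the scheme they exercise is fine-grained (block-wise) FP8 weight quantization: with weight_block_size=(32, 32) every 32x32 tile of a Linear weight gets its own scale, the payload is stored as torch.float8_e4m3fn, and the per-block inverse scales are kept in higher precision (the tests check an FP8Linear layer whose weight is float8_e4m3fn and whose weight_scale_inv is bfloat16). The snippet below is a minimal, self-contained sketch of that idea, not the diffusers implementation: the helper names quantize_blockwise_fp8/dequantize_blockwise_fp8 are hypothetical, and the exact scaling policy (max-magnitude-to-448 per block) is an assumption for illustration only.

# Illustrative sketch only -- not the diffusers FP8Linear implementation.
import torch


def quantize_blockwise_fp8(weight: torch.Tensor, block_size=(32, 32)):
    out_features, in_features = weight.shape
    bh, bw = block_size
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3fn

    # Tile the weight into (bh, bw) blocks; assumes both dimensions divide evenly.
    blocks = weight.reshape(out_features // bh, bh, in_features // bw, bw)
    # One scale per block so the block's largest magnitude maps onto the FP8 range.
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = fp8_max / amax
    quantized = (blocks * scale).clamp(-fp8_max, fp8_max).reshape(out_features, in_features)

    weight_fp8 = quantized.to(torch.float8_e4m3fn)
    # Inverse scales, one per block -- analogous to the `weight_scale_inv` buffer above.
    scale_inv = (1.0 / scale).reshape(out_features // bh, in_features // bw)
    return weight_fp8, scale_inv


def dequantize_blockwise_fp8(weight_fp8: torch.Tensor, scale_inv: torch.Tensor, block_size=(32, 32)):
    bh, bw = block_size
    out_features, in_features = weight_fp8.shape
    blocks = weight_fp8.to(torch.float32).reshape(out_features // bh, bh, in_features // bw, bw)
    # Undo the per-block scaling.
    return (blocks * scale_inv[:, None, :, None]).reshape(out_features, in_features)


# Round-trip check on a toy weight, mirroring the weight / weight_scale_inv pair the tests inspect.
w = torch.randn(64, 64)
w_fp8, s_inv = quantize_blockwise_fp8(w)
print(w_fp8.dtype, s_inv.shape)  # torch.float8_e4m3fn, (2, 2)
print((w - dequantize_blockwise_fp8(w_fp8, s_inv)).abs().max())

This also explains why the tests exclude x_embedder and proj_out via modules_to_not_convert and why a cpu/disk device_map raises: dimensions that do not tile evenly or offloaded modules are simply kept in the original precision rather than quantized.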