[NVFP4][WIP] Add NVFp4 Support #287

Closed · wants to merge 22 commits
@@ -14,5 +14,6 @@
# flake8: noqa

from .base import *
from .modelopt_quantized import *
from .naive_quantized import *
from .pack_quantized import *
@@ -113,6 +113,9 @@ def compress(
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
global_scale = model_state.get(
merge_names(prefix, "weight_global_scale"), None
)
if scale is not None:
# weight is quantized, compress it
if isinstance(names_to_scheme[prefix], tuple):
@@ -125,6 +128,7 @@
scale=scale,
zero_point=zp,
g_idx=g_idx,
global_scale=global_scale,
quantization_args=quant_args,
device="cpu",
)
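
For context, a minimal sketch of the lookup this hunk adds: the new global scale is read from the checkpoint the same way as the existing scale and zero-point parameters. The prefix and the dot-joined key layout below are illustrative assumptions, not taken from this PR.

import torch

# Hypothetical checkpoint entries for one quantized layer (names are assumptions).
prefix = "model.layers.0.mlp.gate_proj"
model_state = {
    f"{prefix}.weight": torch.randn(256, 128),
    f"{prefix}.weight_scale": torch.rand(256, 8),           # per-group scales
    f"{prefix}.weight_global_scale": torch.tensor(256.0),   # new per-tensor scale
}

# merge_names(prefix, "weight_global_scale") is expected to produce the key below;
# .get(..., None) keeps layers without a global scale (non-FP4 schemes) working.
global_scale = model_state.get(f"{prefix}.weight_global_scale", None)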
@@ -0,0 +1,168 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict, Optional, Tuple

import numpy
import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.quantized_compressors.base import (
BaseQuantizationCompressor,
)
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from torch import Tensor


__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"]

FLOAT_TO_E2M1 = [
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
]


@BaseCompressor.register(name=CompressionFormat.modelopt_quantized.value)
class ModelOptCompressor(BaseQuantizationCompressor):
"""
Implements compression for NVFP4-quantized models. The weight of each
quantized layer is quantized to FP4 (E2M1) values and packed two values per
byte into a uint8 weight_packed tensor, stored alongside its weight_scale and
weight_global_scale parameters.
"""

@property
def compression_param_names(self) -> Tuple[str]:
"""
Returns a tuple of compression parameter names introduced by
the compressor during compression
"""
return (
"weight_packed",
"weight_scale",
"weight_zero_point",
"weight_global_scale",
)

def compress_weight(
self,
weight: Tensor,
scale: Tensor,
global_scale: Tensor,
quantization_args: QuantizationArgs,
device: Optional[torch.device] = None,
zero_point: Optional[torch.Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:

quantized_weight = quantize(
x=weight,
scale=scale,
global_scale=global_scale,
zero_point=zero_point,
args=quantization_args,
)
compressed_dict = {}
weight_packed = pack_fp4_to_uint8(quantized_weight)
if device is not None:
weight_packed = weight_packed.to(device)
compressed_dict["weight_packed"] = weight_packed
return compressed_dict

def decompress_weight(
self,
compressed_data: Dict[str, Tensor],
quantization_args: Optional[QuantizationArgs] = None,
) -> torch.Tensor:

weight = compressed_data["weight_packed"]
scale = compressed_data["weight_scale"]
global_scale = compressed_data["weight_global_scale"]
m, n = weight.shape
# TODO: we may not always use the global_scale dtype as the dtype to dequant
# We need to pass in the pretrained model dtype to the compressors
unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
decompressed_weight = dequantize(
x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
)

return decompressed_weight


def pack_fp4_to_uint8(x: torch.Tensor):
m, n = x.shape
device = x.device

# Create lookup table for FP4 values to indices
# Map the absolute values to 0-7 indices
kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype)

# Find closest valid FP4 value index for each element
abs_x = torch.abs(x)
abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
for i, val in enumerate(kE2M1):
abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)

# Apply sign bit (bit 3) to get final 4-bit representation
indices = abs_indices + (torch.signbit(x) * 8).to(torch.long)

# Reshape to prepare for packing pairs of values
indices = indices.reshape(-1)

# Handle odd length by padding if necessary
if indices.numel() % 2 != 0:
indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])

# Reshape to pair consecutive elements
indices = indices.reshape(-1, 2)

# Pack pairs of 4-bit values into 8-bit values
packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)

return packed.reshape(m, n // 2)


kE2M1ToFloat = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)

# reference: https://github.com/vllm-project/vllm/pull/16362
def unpack_fp4_from_uint8(a: torch.Tensor, m: int, n: int, dtype=torch.bfloat16):
assert a.dtype == torch.uint8

# Vectorized nibble processing
a_flat = a.flatten()
high = (a_flat & 0xF0) >> 4 # Upper nibbles
low = a_flat & 0x0F # Lower nibbles

# Combine nibbles for batch processing
combined = torch.stack((low, high), dim=1).flatten()

# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices

# Device-aware lookup and sign application
kE2M1 = kE2M1ToFloat.to(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

# Reshape to final form
return values.reshape(m, n).to(dtype=dtype)
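
A quick round-trip check for the two packing helpers above; a minimal sketch, assuming the new module is importable at the path shown below (the path is inferred from this diff, not stated in it).

import torch

# Module path inferred from the new file's package (an assumption).
from compressed_tensors.compressors.quantized_compressors.modelopt_quantized import (
    pack_fp4_to_uint8,
    unpack_fp4_from_uint8,
)

# Inputs must already sit on the E2M1 grid: +/- {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
x = torch.tensor(
    [
        [0.0, 0.5, -1.0, 1.5, -2.0, 3.0, -4.0, 6.0],
        [6.0, -6.0, 0.5, -0.5, 1.0, 2.0, 3.0, 4.0],
    ],
    dtype=torch.bfloat16,
)

packed = pack_fp4_to_uint8(x)                       # shape (2, 4), dtype uint8
restored = unpack_fp4_from_uint8(packed, *x.shape)  # shape (2, 8), bfloat16

assert torch.equal(restored, x)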
1 change: 1 addition & 0 deletions src/compressed_tensors/config/base.py
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
naive_quantized = "naive-quantized"
pack_quantized = "pack-quantized"
marlin_24 = "marlin-24"
modelopt_quantized = "modelopt-quantized"


@unique
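
A small check of the new enum member, assuming the package is installed with this change applied; the new format string resolves through the enum like the existing values do.

from compressed_tensors.config import CompressionFormat

fmt = CompressionFormat("modelopt-quantized")
assert fmt is CompressionFormat.modelopt_quantized
assert fmt.value == "modelopt-quantized"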
59 changes: 58 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
@@ -28,7 +28,11 @@
from compressed_tensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_args import (
FP4_E2M1_DATA,
FP8_E4M3_DATA,
QuantizationArgs,
)
from compressed_tensors.quantization.quant_config import (
QuantizationConfig,
QuantizationStatus,
@@ -238,6 +242,55 @@ def process_kv_cache_config(
return config


def is_attention_module(module: Module):
return "attention" in module.__class__.__name__.lower() and (
hasattr(module, "k_proj")
or hasattr(module, "v_proj")
or hasattr(module, "qkv_proj")
)


def is_mlp_module(module: Module):
return "mlp" in module.__class__.__name__.lower() and (
hasattr(module, "gate_proj") or hasattr(module, "up_proj")
)


def update_fp4_global_scales(model):
for name, submodule in iter_named_quantizable_modules(
model,
include_attn=True,
include_mlp=True,
):
if is_attention_module(submodule):
q_weight = submodule.q_proj.weight.data
v_weight = submodule.v_proj.weight.data
k_weight = submodule.k_proj.weight.data
all_data = torch.cat((q_weight, v_weight, k_weight), dim=0)

scale_dtype = FP8_E4M3_DATA.dtype
tensor_amax = torch.abs(all_data.data).max().to(torch.float32)
value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
value = value.to(torch.float32)

update_parameter_data(submodule.q_proj, value, "weight_global_scale")
update_parameter_data(submodule.k_proj, value, "weight_global_scale")
update_parameter_data(submodule.v_proj, value, "weight_global_scale")

if is_mlp_module(submodule):
gate_data = submodule.gate_proj.weight.data
up_data = submodule.up_proj.weight.data
all_data = torch.cat((gate_data, up_data), dim=0)

scale_dtype = FP8_E4M3_DATA.dtype
tensor_amax = torch.abs(all_data.data).max().to(torch.float32)
value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
value = value.to(torch.float32)

update_parameter_data(submodule.gate_proj, value, "weight_global_scale")
update_parameter_data(submodule.up_proj, value, "weight_global_scale")


def apply_quantization_status(model: Module, status: QuantizationStatus):
"""
Applies in place the quantization lifecycle up to the given status
@@ -266,6 +319,10 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
)
)

# hack (WIP): fused FP4 global scales can only be computed once parameters are initialized
if status == QuantizationStatus.INITIALIZED:
update_fp4_global_scales(model)

if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
model.apply(compress_quantized_weights)
