NVFP4 Emulation #59

Closed · wants to merge 12 commits
24 changes: 24 additions & 0 deletions run_fp4.py
@@ -0,0 +1,24 @@

import numpy
import torch

from vllm import LLM, SamplingParams

prompts = [
"The Swiss Alps are", "The president of the USA is",
"The Boston Bruins are"
]

# Sampling params (note: temperature > 0, so this is not greedy decoding)
sampling_params = SamplingParams(temperature=0.80, top_p=0.95, max_tokens=40, min_tokens=10)
#llm = LLM('nm-testing/Llama-3.1-8B-Instruct-FP4-Weight')
#llm = LLM("/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/TinyLlama-1.1B-Chat-v1.0-FP4")
#llm = LLM("/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/Llama-3.1-8B-Instruct-NVFP4A16")
#llm = LLM("/home/dsikka/llm-compressor/examples/quantization_w4a16_fp4/Llama-3.1-8B-Instruct-NVFP4A16-MSE")
#llm = LLM("nm-testing/Llama-3.3-70B-Instruct-NVFP4A16", max_model_len=4096)
llm = LLM("nvidia/Llama-3.3-70B-Instruct-FP4", max_model_len=4096, quantization="nvfp4", enforce_eager=True)
output = llm.generate(prompts, sampling_params)

# Print the outputs.
for o in output:
    print(o.outputs[0].text)
    print("\n")
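
As an optional usage sketch (not part of this change): with temperature set to 0, vLLM decodes greedily, which makes the emulated-FP4 outputs deterministic and easier to compare across runs.

# Optional sketch: greedy decoding with the same llm object as above.
greedy_params = SamplingParams(temperature=0.0, max_tokens=40)
for o in llm.generate(prompts, greedy_params):
    print(o.outputs[0].text)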
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -24,7 +24,7 @@
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, CompressedTensorsW4A4Fp4)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
@@ -299,6 +299,13 @@
        # All conditions satisfied.
        return True

    def _is_fp4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
        is_group_quant = (
            weight_quant.strategy == QuantizationStrategy.GROUP.value)
        is_group_size_16 = weight_quant.group_size == 16
        is_float_type = weight_quant.type == QuantizationType.FLOAT
        is_4_bits = weight_quant.num_bits == 4
        return (is_group_quant and is_float_type and is_4_bits
                and is_group_size_16)

    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
        input_quant_none = input_quant is None
@@ -313,6 +320,8 @@
            self, weight_quant: BaseModel,
            input_quant: BaseModel) -> "CompressedTensorsScheme":

        if self._is_fp4_nvfp4(weight_quant, input_quant):
            return CompressedTensorsW4A4Fp4()
        # Detect If Mixed Precision
        if self._is_wNa16_group_channel(weight_quant, input_quant):
            if (self.quant_format == CompressionFormat.marlin_24.value
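
For reference, a minimal sketch of the weight quantization arguments that the _is_fp4_nvfp4 check above is intended to match (illustrative field names mirroring the checks in the diff, not an authoritative compressed-tensors schema):

# Hypothetical NVFP4 weight_quant arguments matched by _is_fp4_nvfp4:
nvfp4_weight_quant = {
    "num_bits": 4,        # is_4_bits
    "type": "float",      # is_float_type (FP4 / E2M1 values)
    "strategy": "group",  # is_group_quant
    "group_size": 16,     # is_group_size_16: one FP8 scale per 16 weights
}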
@@ -8,6 +8,7 @@
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16)
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4

from .compressed_tensors_24 import CompressedTensors24 # isort: skip

@@ -16,5 +17,5 @@
"CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensors24"
"CompressedTensors24", "CompressedTensorsW4A4Fp4"
]
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -0,0 +1,113 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
    dequantize_to_dtype)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)

__all__ = ["CompressedTensorsW4A4Fp4"]


class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):

    def __init__(self):
        self.group_size = 16

    @classmethod
    def get_min_capability(cls) -> int:
        # do not restrict capability; this scheme runs as an emulation
        return 80

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        # Weight
        self.output_partition_sizes = output_partition_sizes
        self.params_dtype = params_dtype
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                sum(output_partition_sizes),
                input_size_per_partition // 2,
                dtype=torch.uint8),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader)
        layer.register_parameter("weight_packed", weight)

        # Global Weight Scale
        weight_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("weight_global_scale", weight_global_scale)

        # Per Group Weight Scale
        weight_scale = GroupQuantScaleParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition // self.group_size,
            dtype=torch.float8_e4m3fn,
        ),
                                                input_dim=1,
                                                output_dim=0,
                                                weight_loader=weight_loader)

        layer.register_parameter("weight_scale", weight_scale)

    def swizzle_blockscale(self, scale: torch.Tensor):
        assert (scale.dtype == torch.float8_e4m3fn)
        # Pad and blockwise interleave weight_scale
        scale_ndim = scale.ndim
        if scale.ndim == 2:
            scale = scale.unsqueeze(0)
        assert scale.ndim == 3
        B, M, K = scale.shape
        round_up_multiple = lambda x, m: (x + m - 1) // m * m
        M_padded = round_up_multiple(M, 128)
        K_padded = round_up_multiple(K, 4)
        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
        padded_scale[:B, :M, :K] = scale
        batches, rows, cols = padded_scale.shape
        assert rows % 128 == 0
        assert cols % 4 == 0
        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
                                            cols // 4, 4)
        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
        swizzled_scale = swizzled_scale.contiguous().cuda()
        return (swizzled_scale.reshape(M, K)
                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))

    def process_weights_after_loading(self, layer) -> None:
        layer.weight_global_scale = Parameter(
            layer.weight_global_scale.max().to(torch.float32),
            requires_grad=False)
        # Note: swizzling the block scales is a post-weight-loading step, but
        # it is not strictly required for the emulation.
        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                requires_grad=False)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        w_fp4 = layer.weight_packed.data
        w_global_scale = layer.weight_global_scale
        w_blockscale = layer.weight_scale_swizzled.data
        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
                                   x.dtype, x.device, self.group_size)
        out = F.linear(x, w_dq, bias)
        del w_dq
        return out
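
As a self-contained illustration (not part of this diff) of how apply_weights interprets the packed tensors created above: each uint8 in weight_packed holds two FP4 (E2M1) codes, and each group of 16 decoded values is scaled by its FP8 block scale divided by the per-tensor global scale, which is the math dequantize_to_dtype performs.

import torch

# FP4 (E2M1) magnitude lookup table used by the emulation utilities.
kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.])


def unpack_fp4(packed: torch.Tensor) -> torch.Tensor:
    """Decode two FP4 (E2M1) codes per uint8, low nibble first."""
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    codes = torch.stack((low, high), dim=-1).flatten(-2)
    sign = torch.where((codes & 0x08) != 0, -1.0, 1.0)
    return kE2M1ToFloat[(codes & 0x07).long()] * sign


# Toy example: one output row, one group of 16 weights (8 packed bytes).
packed = torch.randint(0, 256, (1, 8), dtype=torch.uint8)
block_scale = torch.tensor([[0.5]])    # stored as float8_e4m3fn in checkpoints
global_scale = torch.tensor(100.0)     # hypothetical per-tensor scale
w_dq = unpack_fp4(packed) * (block_scale / global_scale)  # shape (1, 16)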
132 changes: 11 additions & 121 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -3,6 +3,7 @@
from typing import Any, Dict, List, Optional, Union

import torch
import torch._dynamo
from torch.nn import Module
from torch.nn.parameter import Parameter

@@ -13,116 +14,23 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
    dequantize_to_dtype, ref_nvfp4_quant)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

torch._dynamo.config.suppress_errors = True

logger = init_logger(__name__)

QUANT_ALGOS = ["FP8", "NVFP4"]
KV_CACHE_QUANT_ALGOS = ["FP8"]

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()

kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
dtype=torch.float32)


def break_fp4_bytes(a, dtype):
    assert a.dtype == torch.uint8
    m, n = a.shape
    # Vectorized nibble processing
    a_flat = a.flatten()
    high = (a_flat & 0xF0) >> 4  # Upper nibbles
    low = a_flat & 0x0F  # Lower nibbles
    # Combine nibbles for batch processing
    combined = torch.stack((low, high), dim=1).flatten()
    # Vectorized sign and magnitude extraction
    signs = (combined & 0x08).to(torch.bool)  # Sign bits
    abs_vals = (combined & 0x07).to(torch.long)
    # Device-aware lookup and sign application
    kE2M1 = kE2M1ToFloat.to(device=a.device)
    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
    # Reshape to final form
    return values.reshape(m, n * 2).to(dtype=dtype)


def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    return out[0:m, 0:k]
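
A quick way to sanity-check the block-scale layout (illustrative only, not part of this diff): swizzling a scale tensor with CompressedTensorsW4A4Fp4.swizzle_blockscale and then undoing it with convert_swizzled_to_linear should return the original values. The sketch below assumes a CUDA device (the swizzle moves data to the GPU) and shapes that need no padding (rows a multiple of 128, groups a multiple of 4).

import torch

# Assumes both helpers are importable from their modules in this PR.
M, num_groups, block_size = 128, 4, 16
scale = torch.randn(M, num_groups).to(torch.float8_e4m3fn)

scheme = CompressedTensorsW4A4Fp4()
swizzled = scheme.swizzle_blockscale(scale)  # returned on the GPU
restored = convert_swizzled_to_linear(swizzled, M,
                                      num_groups * block_size, block_size)
assert torch.equal(restored.float(), scale.cuda().float())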


def dequantize_to_dtype(tensor_fp4,
                        tensor_sf,
                        global_scale,
                        dtype,
                        device,
                        block_size=16):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

    # scale the tensor
    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out.to(dtype)


def cast_to_fp4(x):
    sign = torch.sign(x)
    x = torch.abs(x)
    x[(x >= 0.0) & (x <= 0.25)] = 0.0
    x[(x > 0.25) & (x < 0.75)] = 0.5
    x[(x >= 0.75) & (x <= 1.25)] = 1.0
    x[(x > 1.25) & (x < 1.75)] = 1.5
    x[(x >= 1.75) & (x <= 2.5)] = 2.0
    x[(x > 2.5) & (x < 3.5)] = 3.0
    x[(x >= 3.5) & (x <= 5.0)] = 4.0
    x[x > 5.0] = 6.0
    return x * sign
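
A worked example of the rounding above: E2M1 can only represent the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and cast_to_fp4 snaps each input to the nearest one while preserving sign.

# Example of the mapping above:
#   cast_to_fp4(torch.tensor([0.3, -1.6, 2.4, 5.5]))
#   -> tensor([ 0.5000, -1.5000,  2.0000,  6.0000])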


def get_reciprocal(x):
    if isinstance(x, torch.Tensor):
        return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
    elif isinstance(x, (float, int)):
        return 0.0 if x == 0 else 1.0 / x
    else:
        raise TypeError("Input must be a float, int, or a torch.Tensor.")


def ref_nvfp4_quant(x, global_scale, block_size):
    assert global_scale.dtype == torch.float32
    assert x.ndim == 2
    m, n = x.shape
    x = torch.reshape(x, (m, n // block_size, block_size))
    vec_max = torch.max(torch.abs(x), dim=-1,
                        keepdim=True)[0].to(torch.float32)
    scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))

    scaled_x = x.to(torch.float32) * output_scale
    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
    # both outputs are float32
    return cast_to_fp4(clipped_x), scale.squeeze(-1)
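
To make the scaling convention concrete, here is a small round-trip sketch (illustrative, not part of this diff): quantize with ref_nvfp4_quant, then invert it by multiplying the FP4 values by scale / global_scale per 16-element block, which is the same math dequantize_to_dtype applies to packed weights. global_scale = 1.0 is an arbitrary illustrative choice.

import torch

# Assumes ref_nvfp4_quant (above) is importable.
torch.manual_seed(0)
x = torch.randn(4, 64, dtype=torch.float32)
global_scale = torch.tensor(1.0, dtype=torch.float32)
block_size = 16

x_fp4, scale = ref_nvfp4_quant(x, global_scale, block_size)

# Dequantize: undo the per-block scaling.
m, n = x.shape
x_dq = (x_fp4.reshape(m, n // block_size, block_size) *
        (scale / global_scale).unsqueeze(-1)).reshape(m, n)

# FP4 keeps only ~1 mantissa bit, so expect coarse agreement, not equality.
print((x - x_dq).abs().max())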


class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8."""
@@ -289,7 +197,7 @@

    @classmethod
    def get_min_capability(cls) -> int:
        return 89
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
@@ -483,20 +391,11 @@

        # for input only the contracting dimension has a constraint.
        x_m, x_k = x.shape
        w_n, w_k = layer.weight.shape
        # print(f"{x.shape=}")
        # print(f"{layer.weight.shape=}")
        output_shape = [x_m, w_n]
        block_size = 16
        block_size = group_size = 16

        # quantize input to (FP4 and interleaved block scale)
        # x_global_scale = layer.input_scale
        x_global_scale = 1 / layer.input_scale
        # x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)
        x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size)
        # x_blockscale = self.swizzle_blockscale(x_blockscale)
        # print(f"{x_fp4.shape=}")
        # print(f"{x_blockscale.shape=}")

        # dequantize input
        x_fp4 = x_fp4.reshape(x_m, x_k // block_size, block_size)
@@ -507,20 +406,11 @@
        # dequantize weight
        w_fp4 = layer.weight.data.view(torch.uint8)
        w_blockscale = layer.weight_scale_swizzled.data
        w_global_scale = layer.weight_scale_2
        # print(f"{w_fp4.shape=}")
        # print(f"{w_blockscale.shape=}")
        # print(f"{w_global_scale.shape=}")
        w_global_scale = 1 / layer.weight_scale_2
        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
                                   output_dtype, x.device,
                                   block_size).to(output_dtype)
        # print(f"{w_dq.shape=}")
                                   output_dtype, x.device, block_size)

        # matmul
        out = torch.matmul(x_dq, w_dq.t())
        del x_dq, w_dq
        # print(f"{out.shape=}")

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
        del w_dq, x_dq
        return out