[WIP] NVFP4 Emulation #58

Open · wants to merge 1 commit into main
164 changes: 140 additions & 24 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -6,8 +6,7 @@
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm._custom_ops import (cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
@@ -21,12 +20,109 @@
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

QUANT_ALGOS = ["FP8", "NVFP4"]
KV_CACHE_QUANT_ALGOS = ["FP8"]

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()

kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
dtype=torch.float32)


def break_fp4_bytes(a, dtype):
    """Unpack two FP4 (E2M1) values from each uint8 and decode them to `dtype`."""
    assert a.dtype == torch.uint8
m, n = a.shape
# Vectorized nibble processing
a_flat = a.flatten()
high = (a_flat & 0xF0) >> 4 # Upper nibbles
low = a_flat & 0x0F # Lower nibbles
# Combine nibbles for batch processing
combined = torch.stack((low, high), dim=1).flatten()
# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long)
# Device-aware lookup and sign application
kE2M1 = kE2M1ToFloat.to(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
# Reshape to final form
return values.reshape(m, n * 2).to(dtype=dtype)
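
For example, unpacking a single made-up byte (an illustrative sketch, not exercised by the layer code): the low nibble is emitted first, bit 3 of each nibble is the sign, and the low three bits index kE2M1ToFloat.

packed = torch.tensor([[0x2B]], dtype=torch.uint8)  # low nibble 0xB, high nibble 0x2
break_fp4_bytes(packed, torch.float32)              # tensor([[-1.5000, 1.0000]])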


def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Undo the 128x4-tiled (swizzled) block-scale layout into a row-major (m, k // block_size) view."""
    m_tiles = (m + 128 - 1) // 128
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
return out[0:m, 0:k]
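
Shape-wise, a sketch with assumed sizes: the swizzled buffer is padded to ceil(m/128)*128 rows by ceil(k/64)*4 scale columns, and the unswizzled result carries one FP8 scale per 16-element block.

m, k, block_size = 200, 4096, 16
sf_swizzled = torch.randn(256 * 256).to(torch.float8_e4m3fn)   # padded, tiled layout: 256 x 256 scales
sf_linear = convert_swizzled_to_linear(sf_swizzled, m, k, block_size)
assert sf_linear.shape == (m, k // block_size)                 # (200, 256)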


def dequantize_to_dtype(tensor_fp4,
tensor_sf,
global_scale,
dtype,
device,
block_size=16):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out.to(dtype)
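
A usage sketch with made-up shapes and scales (the real call site passes the checkpoint's packed weight, swizzled scales, and weight_scale_2): a 128x64 weight packed two values per byte, plus one padded 128x4 tile of FP8 block scales, dequantizes back to (128, 64).

n, k, block_size = 128, 64, 16
w_packed = torch.randint(0, 256, (n, k // 2), dtype=torch.uint8)
w_sf = torch.randn(128 * 4).to(torch.float8_e4m3fn)           # one full 128x4 scale tile
w_gs = torch.tensor(400.0, dtype=torch.float32)               # stand-in global scale
w_dq = dequantize_to_dtype(w_packed, w_sf, w_gs, torch.bfloat16, w_packed.device, block_size)
assert w_dq.shape == (n, k)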


def cast_to_fp4(x):
    """Round each element to the nearest representable FP4 (E2M1) value, keeping float storage."""
    sign = torch.sign(x)
x = torch.abs(x)
x[(x >= 0.0) & (x <= 0.25)] = 0.0
x[(x > 0.25) & (x < 0.75)] = 0.5
x[(x >= 0.75) & (x <= 1.25)] = 1.0
x[(x > 1.25) & (x < 1.75)] = 1.5
x[(x >= 1.75) & (x <= 2.5)] = 2.0
x[(x > 2.5) & (x < 3.5)] = 3.0
x[(x >= 3.5) & (x <= 5.0)] = 4.0
x[x > 5.0] = 6.0
return x * sign
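
The intervals implement round-to-nearest onto the E2M1 grid {0, 0.5, 1, 1.5, 2, 3, 4, 6}; ties land on the value with the even E2M1 code (0.25 -> 0.0, 2.5 -> 2.0, 5.0 -> 4.0). A small check:

cast_to_fp4(torch.tensor([0.2, -0.8, 1.3, 2.5, 5.1]))
# tensor([ 0.0000, -1.0000,  1.5000,  2.0000,  6.0000])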


def get_reciprocal(x):
if isinstance(x, torch.Tensor):
return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
elif isinstance(x, (float, int)):
return 0.0 if x == 0 else 1.0 / x
else:
raise TypeError("Input must be a float, int, or a torch.Tensor.")


def ref_nvfp4_quant(x, global_scale, block_size):
    """Reference NVFP4 quantization: FP4 values (stored as float32) plus one FP8-rounded scale per block."""
    assert global_scale.dtype == torch.float32
assert x.ndim == 2
m, n = x.shape
x = torch.reshape(x, (m, n // block_size, block_size))
vec_max = torch.max(torch.abs(x), dim=-1,
keepdim=True)[0].to(torch.float32)
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))

scaled_x = x.to(torch.float32) * output_scale
clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
# both outputs are float32
return cast_to_fp4(clipped_x), scale.squeeze(-1)
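
A round-trip sketch under assumed shapes: pick a per-tensor global scale (a common choice is FP8_E4M3_MAX * FP4_E2M1_MAX / amax), quantize, then undo the block scales the same way apply() does below.

x = torch.randn(4, 64, dtype=torch.float32)
global_scale = (448.0 * FLOAT4_E2M1_MAX) / x.abs().max()        # 448.0 = FP8 E4M3 max
x_fp4, x_sf = ref_nvfp4_quant(x, global_scale, block_size=16)   # (4, 64) values, (4, 4) scales
x_dq = (x_fp4.reshape(4, 4, 16) * (x_sf / global_scale).unsqueeze(-1)).reshape(4, 64)
print((x - x_dq).abs().max())                                   # coarse but bounded FP4 error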


class ModelOptFp8Config(QuantizationConfig):
"""Config class for ModelOpt FP8."""
@@ -193,7 +289,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:

@classmethod
def get_min_capability(cls) -> int:
return 100
return 89

@classmethod
def get_config_filenames(cls) -> List[str]:
@@ -262,10 +358,10 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):

def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
if not self.cutlass_nvfp4_supported:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and above.")
# self.cutlass_nvfp4_supported = cutlass_fp4_supported()
# if not self.cutlass_nvfp4_supported:
# raise ValueError("Current platform does not support NVFP4"
# " quantization. Please use Blackwell and above.")

def create_weights(
self,
Expand Down Expand Up @@ -386,25 +482,45 @@ def apply(
output_dtype = x.dtype

# for input only the contracting dimension has a constraint.
x_m, _ = x.shape
w_n, _ = layer.weight.shape
x_m, x_k = x.shape
w_n, w_k = layer.weight.shape
# print(f"{x.shape=}")
# print(f"{layer.weight.shape=}")
output_shape = [x_m, w_n]
block_size = 16

# quantize input to (FP4 and interleaved block scale)
# x_global_scale = layer.input_scale
x_global_scale = 1 / layer.input_scale
# x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)
x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size)
# x_blockscale = self.swizzle_blockscale(x_blockscale)
# print(f"{x_fp4.shape=}")
# print(f"{x_blockscale.shape=}")

# dequantize input
x_fp4 = x_fp4.reshape(x_m, x_k // block_size, block_size)
x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale
x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
del x_fp4, x_blockscale

# dequantize weight
w_fp4 = layer.weight.data.view(torch.uint8)
w_blockscale = layer.weight_scale_swizzled.data
w_global_scale = layer.weight_scale_2
# print(f"{w_fp4.shape=}")
# print(f"{w_blockscale.shape=}")
# print(f"{w_global_scale.shape=}")
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
output_dtype, x.device,
block_size).to(output_dtype)
# print(f"{w_dq.shape=}")

# matmul
out = torch.matmul(x_dq, w_dq.t())
del x_dq, w_dq
# print(f"{out.shape=}")

# quantize BF16 or FP16 to (FP4 and interleaved block scale)
s_quant = 1 / layer.input_scale
x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)

# validate dtypes of quantized input, input block scale,
# weight and weight_blockscale
assert (x_fp4.dtype == torch.uint8)
assert (layer.weight.dtype == torch.uint8)
assert (x_blockscale.dtype == torch.float8_e4m3fn)
assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
assert (layer.alpha.dtype == torch.float32)

out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
layer.weight_scale_swizzled, layer.alpha,
output_dtype)
if bias is not None:
out = out + bias
return out.view(*output_shape)
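
For reference, the emulated path above boils down to the following standalone sketch (made-up shapes; per-tensor global scales derived on the fly rather than read from the checkpoint, and bias/alpha handling omitted). emulate_nvfp4_side is only an illustrative helper, not part of the diff.

m, n, k, block_size = 8, 32, 64, 16
x = torch.randn(m, k, dtype=torch.float32)
w = torch.randn(n, k, dtype=torch.float32)

def emulate_nvfp4_side(t, rows):
    gs = (448.0 * FLOAT4_E2M1_MAX) / t.abs().max()
    t_fp4, t_sf = ref_nvfp4_quant(t, gs, block_size)
    return (t_fp4.reshape(rows, k // block_size, block_size) *
            (t_sf / gs).unsqueeze(-1)).reshape(rows, k)

out = emulate_nvfp4_side(x, m) @ emulate_nvfp4_side(w, n).t()   # emulated NVFP4 GEMM
print((out - x @ w.t()).abs().max())                            # vs. the unquantized matmul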