-
Notifications
You must be signed in to change notification settings - Fork 15
Support DeepSeekV3-style block FP8 quantization #372
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,11 +111,15 @@ def dequantize( | |
elif scale.ndim == 2: | ||
if scale.shape[1] == 1: | ||
args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL) | ||
else: | ||
elif scale.shape[0] == 1: | ||
group_size = int(x_q.shape[1] / scale.shape[1]) | ||
args = QuantizationArgs( | ||
strategy=QuantizationStrategy.GROUP, group_size=group_size | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a docstring explaining why this falls into the else condition? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @dsikka Do you have a sense of why this logic exists at all / in what cases it's used? I'm not sure if inferring the quant strategy is really safe practice |
||
else: | ||
args = QuantizationArgs( | ||
strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape | ||
) | ||
else: | ||
raise ValueError( | ||
f"Could not infer a quantization strategy from scale with {scale.ndim} " | ||
|
@@ -189,7 +193,63 @@ def _process_quantization( | |
q_min, q_max = calculate_range(args, x.device) | ||
group_size = args.group_size | ||
|
||
if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): | ||
# blockwise FP8: quantize per 2D block, supports block_structure for static block quant | ||
if args.strategy == QuantizationStrategy.BLOCK: | ||
original_shape = x.shape | ||
rows, cols = x.shape[-2], x.shape[-1] | ||
block_height, block_width = args.block_structure | ||
|
||
# Ensure exact division (tensor dimensions must be divisible by block size) | ||
if rows % block_height != 0: | ||
raise ValueError( | ||
f"Tensor height {rows} is not divisible by block_height {block_height}. " | ||
f"Block quantization requires exact division." | ||
) | ||
if cols % block_width != 0: | ||
raise ValueError( | ||
f"Tensor width {cols} is not divisible by block_width {block_width}. " | ||
f"Block quantization requires exact division." | ||
) | ||
|
||
# reshape into blocks | ||
num_rows_blocks = rows // block_height | ||
num_cols_blocks = cols // block_width | ||
x_blocks = x.reshape( | ||
num_rows_blocks, | ||
block_height, | ||
num_cols_blocks, | ||
block_width, | ||
).transpose(1, 2) | ||
|
||
# expand scale/zero_point for blocks | ||
sb = scale.unsqueeze(-1).unsqueeze(-1) | ||
zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None | ||
if do_quantize: | ||
# quantize blocks | ||
x_blocks = _quantize( | ||
x=x_blocks, | ||
scale=sb, | ||
zero_point=zb, | ||
q_min=q_min, | ||
q_max=q_max, | ||
args=args, | ||
dtype=dtype, | ||
global_scale=global_scale, | ||
) | ||
if do_dequantize: | ||
# dequantize blocks | ||
x_blocks = _dequantize( | ||
x_q=x_blocks, | ||
scale=sb, | ||
zero_point=zb, | ||
global_scale=global_scale, | ||
) | ||
# restore original shape | ||
output = x_blocks.transpose(1, 2).reshape(original_shape) | ||
elif args.strategy in ( | ||
QuantizationStrategy.GROUP, | ||
QuantizationStrategy.TENSOR_GROUP, | ||
): | ||
n_dims = x.shape | ||
if len(n_dims) > 2: | ||
x = x.squeeze(0) | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -14,7 +14,7 @@ | |||||
|
||||||
import warnings | ||||||
from enum import Enum | ||||||
from typing import Any, Dict, Optional, Union | ||||||
from typing import Any, Dict, List, Optional, Union | ||||||
|
||||||
import torch | ||||||
from compressed_tensors.utils import Aliasable | ||||||
|
@@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True): | |||||
:param symmetric: whether or not quantization scale is symmetric about zero-point | ||||||
:param strategy: string id determining the scope of scale/zero-point to apply | ||||||
:param group_size: group length to use for the group strategy | ||||||
:param block_structure: 2d block structure to use for the block strategy, must be | ||||||
of the format "2x4", "8x16", etc. | ||||||
:param block_structure: 2d block structure to use for the block strategy; must be | ||||||
a list of two ints [rows, cols] like [128, 128]. | ||||||
:param dynamic: set True to perform dynamic quantization - values will not be | ||||||
calibrated during calibration phase, instead during inference new quantization | ||||||
ranges will be observed with every sample. Defaults to False for static | ||||||
|
@@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True): | |||||
symmetric: bool = True | ||||||
group_size: Optional[int] = None | ||||||
strategy: Optional[QuantizationStrategy] = None | ||||||
block_structure: Optional[str] = None | ||||||
block_structure: Optional[List[int]] = None | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I lean towards keeping it a list, since I don't think we can distinguish between a list and a tuple in the final JSON config |
||||||
dynamic: Union[DynamicType, bool] = False | ||||||
actorder: Union[ActivationOrdering, bool, None] = None | ||||||
observer: Optional[str] = Field( | ||||||
|
@@ -207,6 +207,28 @@ def validate_group(cls, value) -> Union[int, None]: | |||||
|
||||||
return value | ||||||
|
||||||
@field_validator("block_structure", mode="before") | ||||||
def validate_block_structure(cls, value) -> Optional[List[int]]: | ||||||
if value is None: | ||||||
return value | ||||||
# For backward compatibility, allow string format "2x4", "8x16", etc. | ||||||
if isinstance(value, str): | ||||||
try: | ||||||
return [int(x) for x in value.split("x")] | ||||||
except Exception: | ||||||
raise ValueError( | ||||||
f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]." | ||||||
) | ||||||
if isinstance(value, (list, tuple)): | ||||||
if len(value) != 2 or not all(isinstance(v, int) for v in value): | ||||||
raise ValueError( | ||||||
f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]." | ||||||
) | ||||||
return list(value) | ||||||
raise ValueError( | ||||||
f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]." | ||||||
) | ||||||
|
||||||
@field_validator("strategy", mode="before") | ||||||
def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]: | ||||||
if isinstance(value, str): | ||||||
|
@@ -277,14 +299,15 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": | |||||
|
||||||
# infer observer w.r.t. dynamic | ||||||
if dynamic: | ||||||
if strategy not in ( | ||||||
supported_strategies = ( | ||||||
QuantizationStrategy.TOKEN, | ||||||
QuantizationStrategy.TENSOR, | ||||||
QuantizationStrategy.TENSOR_GROUP, | ||||||
): | ||||||
QuantizationStrategy.GROUP, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is mostly an aesthetic choice, but it might have aesthetic consequences if vllm wants to support fused input-weight quantization. Ex There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might want to add some validation on quant_scheme related to this as well |
||||||
) | ||||||
if strategy not in supported_strategies: | ||||||
raise ValueError( | ||||||
f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} " | ||||||
"must be used for dynamic quantization", | ||||||
f"One of {supported_strategies} must be used for dynamic quantization" | ||||||
) | ||||||
|
||||||
if ( | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove