
Commit 09b7ed4

mgoin and shanjiaz authored
Support DeepSeekV3-style block FP8 quantization (#372)
* Support DeepSeekV3-style block FP8 quantization
  Signed-off-by: mgoin <michael@neuralmagic.com>
* Remove math
  Signed-off-by: mgoin <michael@neuralmagic.com>
* Enforce divisible shapes
  Signed-off-by: mgoin <michael@neuralmagic.com>
* Format
  Signed-off-by: mgoin <michael@neuralmagic.com>
* Remove validation
  Signed-off-by: mgoin <michael@neuralmagic.com>
* Fix string
  Signed-off-by: mgoin <michael@neuralmagic.com>
* fix failed tests
  Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
* address review comments
  Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
* new line
  Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>

---------

Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Co-authored-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent 180226b commit 09b7ed4

File tree

8 files changed: +234 additions, -18 deletions

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 68 additions & 5 deletions
@@ -111,11 +111,18 @@ def dequantize(
         elif scale.ndim == 2:
             if scale.shape[1] == 1:
                 args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
-            else:
+            # Scale height matches input or is 1 -> group quantization across columns
+            #
+            # Example 1: scale.shape[0] == 1
+            #   x_q: (4, 8), scale: (1, 4) -> 2 columns per group
+            #
+            # Example 2: scale.shape[0] == x_q.shape[0]
+            #   x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
+            elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
                 group_size = int(x_q.shape[1] / scale.shape[1])
-                args = QuantizationArgs(
-                    strategy=QuantizationStrategy.GROUP, group_size=group_size
-                )
+                args = QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=group_size)
+            else:
+                args = QuantizationArgs(strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape)
         else:
             raise ValueError(
                 f"Could not infer a quantization strategy from scale with {scale.ndim} "
@@ -189,7 +196,63 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
-    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+    # blockwise FP8: quantize per 2D block, supports block_structure for static block quant
+    if args.strategy == QuantizationStrategy.BLOCK:
+        original_shape = x.shape
+        rows, cols = x.shape[-2], x.shape[-1]
+        block_height, block_width = args.block_structure
+
+        # Ensure exact division (tensor dimensions must be divisible by block size)
+        if rows % block_height != 0:
+            raise ValueError(
+                f"Tensor height {rows} is not divisible by block_height {block_height}. "
+                f"Block quantization requires exact division."
+            )
+        if cols % block_width != 0:
+            raise ValueError(
+                f"Tensor width {cols} is not divisible by block_width {block_width}. "
+                f"Block quantization requires exact division."
+            )
+
+        # reshape into blocks and transpose to make each block contiguous
+        num_rows_blocks = rows // block_height
+        num_cols_blocks = cols // block_width
+        x_blocks = x.reshape(
+            num_rows_blocks,
+            block_height,
+            num_cols_blocks,
+            block_width,
+        ).transpose(1, 2)
+
+        # expand scale/zero_point for blocks
+        sb = scale.unsqueeze(-1).unsqueeze(-1)
+        zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
+        if do_quantize:
+            # quantize blocks
+            x_blocks = _quantize(
+                x=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
+                dtype=dtype,
+                global_scale=global_scale,
+            )
+        if do_dequantize:
+            # dequantize blocks
+            x_blocks = _dequantize(
+                x_q=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                global_scale=global_scale,
+            )
+        # restore original shape
+        output = x_blocks.transpose(1, 2).reshape(original_shape)
+    elif args.strategy in (
+        QuantizationStrategy.GROUP,
+        QuantizationStrategy.TENSOR_GROUP,
+    ):
         n_dims = x.shape
         if len(n_dims) > 2:
             x = x.squeeze(0)
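
For intuition, the BLOCK path above amounts to the following self-contained round-trip. This is a hedged sketch, not the library's _quantize/_dequantize: symmetric FP8 is approximated by round-and-clamp to ±448 (the float8_e4m3fn max), and fake_quant_block is a hypothetical name:

import torch

def fake_quant_block(x: torch.Tensor, scale: torch.Tensor, bh: int, bw: int) -> torch.Tensor:
    rows, cols = x.shape
    assert rows % bh == 0 and cols % bw == 0, "dims must divide evenly into blocks"
    # (rows, cols) -> (num_row_blocks, num_col_blocks, bh, bw)
    blocks = x.reshape(rows // bh, bh, cols // bw, bw).transpose(1, 2)
    sb = scale.unsqueeze(-1).unsqueeze(-1)  # broadcast one scale per block
    q = torch.clamp(torch.round(blocks / sb), -448, 448)  # stand-in for the FP8 cast
    return (q * sb).transpose(1, 2).reshape(rows, cols)  # dequantize, restore shape

x = torch.randn(256, 512)
bh = bw = 128
# per-block absmax scales, shape (num_row_blocks, num_col_blocks) = (2, 4)
scale = x.reshape(256 // bh, bh, 512 // bw, bw).transpose(1, 2).abs().amax(dim=(-2, -1)) / 448
out = fake_quant_block(x, scale, bh, bw)
assert out.shape == x.shape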

src/compressed_tensors/quantization/quant_args.py

Lines changed: 31 additions & 8 deletions
@@ -14,7 +14,7 @@
 
 import warnings
 from enum import Enum
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from compressed_tensors.utils import Aliasable
@@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     :param symmetric: whether or not quantization scale is symmetric about zero-point
     :param strategy: string id determining the scope of scale/zero-point to apply
     :param group_size: group length to use for the group strategy
-    :param block_structure: 2d block structure to use for the block strategy, must be
-        of the format "2x4", "8x16", etc.
+    :param block_structure: 2d block structure to use for the block strategy; must be
+        a list of two ints [rows, cols] like [128, 128].
     :param dynamic: set True to perform dynamic quantization - values will not be
         calibrated during calibration phase, instead during inference new quantization
         ranges will be observed with every sample. Defaults to False for static
@@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
-    block_structure: Optional[str] = None
+    block_structure: Optional[List[int]] = None
     dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
     observer: Optional[str] = Field(
@@ -207,6 +207,28 @@ def validate_group(cls, value) -> Union[int, None]:
 
         return value
 
+    @field_validator("block_structure", mode="before")
+    def validate_block_structure(cls, value) -> Optional[List[int]]:
+        if value is None:
+            return value
+        # For backward compatibility, allow string format "2x4", "8x16", etc.
+        if isinstance(value, str):
+            try:
+                return [int(x) for x in value.split("x")]
+            except Exception:
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+        if isinstance(value, (list, tuple)):
+            if len(value) != 2 or not all(isinstance(v, int) for v in value):
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+            return list(value)
+        raise ValueError(
+            f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+        )
+
     @field_validator("strategy", mode="before")
     def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
         if isinstance(value, str):
@@ -277,14 +299,15 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
 
         # infer observer w.r.t. dynamic
        if dynamic:
-            if strategy not in (
+            supported_strategies = (
                 QuantizationStrategy.TOKEN,
                 QuantizationStrategy.TENSOR,
                 QuantizationStrategy.TENSOR_GROUP,
-            ):
+                QuantizationStrategy.GROUP,
+            )
+            if strategy not in supported_strategies:
                 raise ValueError(
-                    f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} "
-                    "must be used for dynamic quantization",
+                    f"One of {supported_strategies} must be used for dynamic quantization"
                 )
 
         if (
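
Given the backward-compatibility branch in the validator, the legacy "RxC" string and the new [rows, cols] list should normalize to the same canonical value. A small usage sketch, assuming QuantizationArgs remains importable from compressed_tensors.quantization as in the existing tests:

from compressed_tensors.quantization import QuantizationArgs

legacy = QuantizationArgs(
    num_bits=8, type="float", strategy="block", block_structure="128x128"
)
modern = QuantizationArgs(
    num_bits=8, type="float", strategy="block", block_structure=[128, 128]
)
# both spellings resolve to the canonical list form
assert legacy.block_structure == modern.block_structure == [128, 128]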

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 41 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Optional
 
@@ -52,6 +53,7 @@ class QuantizationScheme(BaseModel):
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
         inputs = model.input_activations
         outputs = model.output_activations
+        weights = model.weights
 
         if inputs is not None:
             if inputs.actorder is not None:
@@ -61,6 +63,21 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
             if outputs.actorder is not None:
                 raise ValueError("Cannot apply actorder to output activations")
 
+        if (
+            inputs and weights
+            and weights.strategy == QuantizationStrategy.GROUP
+            and inputs.strategy == QuantizationStrategy.GROUP
+            and weights.group_size != inputs.group_size
+        ):
+            warnings.warn(
+                "Using GROUP strategy for both weights and input_activations "
+                f"with different group sizes ({weights.group_size} vs {inputs.group_size}) "
+                "may complicate fused kernel implementations. Consider using "
+                "TENSOR_GROUP strategy for both or matching group sizes.",
+                UserWarning,
+                stacklevel=2
+            )
+
         return model
 
 
@@ -243,6 +260,29 @@ def is_preset_scheme(name: str) -> bool:
     ),
 )
 
+# Block-wise FP8 (deepseekv3-style quantization):
+# static 128x128 per-block weights and
+# dynamic per-token-group activations
+FP8_BLOCK = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.BLOCK,
+        symmetric=True,
+        dynamic=False,
+        block_structure=[128, 128],
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=True,
+        observer=None,
+        group_size=128,
+    ),
+)
+
 PRESET_SCHEMES = {
     # Unquantized (no-op)
     "UNQUANTIZED": UNQUANTIZED,
@@ -257,6 +297,7 @@ def is_preset_scheme(name: str) -> bool:
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
+    "FP8_BLOCK": FP8_BLOCK,
     "NVFP4A16": NVFP4A16,
     "NVFP4": NVFP4,
 }
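
Once registered in PRESET_SCHEMES, the new preset should resolve by name. A hedged sketch, assuming preset_name_to_scheme keeps its existing (name, targets) signature; the "Linear" target is illustrative:

from compressed_tensors.quantization import preset_name_to_scheme

scheme = preset_name_to_scheme("FP8_BLOCK", targets=["Linear"])
# static 128x128 blocks on weights, dynamic 128-wide groups on activations
assert scheme.weights.block_structure == [128, 128]
assert scheme.input_activations.dynamic is True
assert scheme.input_activations.group_size == 128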

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 11 additions & 2 deletions
@@ -171,7 +171,10 @@ def compute_dynamic_scales_and_zp(
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
-    elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    elif args.strategy in (
+        QuantizationStrategy.TENSOR_GROUP,
+        QuantizationStrategy.GROUP,
+    ):
         if len(value.shape) > 2:
             value = value.squeeze(0)
 
@@ -187,9 +190,15 @@ def compute_dynamic_scales_and_zp(
             ),
         )
     else:
+        supported_strategies = (
+            QuantizationStrategy.TOKEN,
+            QuantizationStrategy.TENSOR,
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        )
         raise ValueError(
             "Dynamic quantization is only supported for ",
-            f"{QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP}",
+            f"{supported_strategies}",
         )
 
     if not reduce_dims:
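
The GROUP branch above lets dynamic quantization compute one scale per token per column group. A standalone sketch of that reduction in plain torch (symmetric FP8 assumed, max magnitude mapped to 448; dynamic_group_scales is a made-up name, not the library call):

import torch

def dynamic_group_scales(value: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    tokens, hidden = value.shape
    assert hidden % group_size == 0, "hidden dim must divide into groups"
    # (tokens, hidden) -> (tokens, num_groups, group_size), absmax per group
    groups = value.reshape(tokens, hidden // group_size, group_size)
    return (groups.abs().amax(dim=-1) / 448.0).clamp(min=1e-12)

acts = torch.randn(16, 512)
scales = dynamic_group_scales(acts)
assert scales.shape == (16, 4)  # one scale per (token, group)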

tests/test_examples/test_bitmask_compression_ipynb.py

Lines changed: 3 additions & 1 deletion
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import nbformat
 import pytest
+
+
+nbformat = pytest.importorskip("nbformat")
 from nbconvert.preprocessors import ExecutePreprocessor
 

tests/test_quantization/lifecycle/test_forward.py

Lines changed: 50 additions & 0 deletions
@@ -13,9 +13,12 @@
 # limitations under the License.
 
 
+import math
+
 import pytest
 import torch
 from compressed_tensors.quantization.lifecycle.forward import (
+    _process_quantization,
     dequantize,
     forward_quantize,
     quantize,
@@ -29,6 +32,7 @@
     QuantizationStrategy,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.utils.helpers import calculate_range
 from torch.nn import Linear
 
 
@@ -203,3 +207,49 @@ def test_dequantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx):
         dtype=None,
         g_idx=g_idx,
     )
+
+
+def test_process_quantization_block_static():
+    """
+    Static block quantization (QuantizationStrategy.BLOCK) should split a 2D tensor
+    into blocks, quantize each block, and reassemble without changing shape.
+    """
+    rows, cols = 8, 8
+    bh, bw = 2, 4
+    x = torch.randn(rows, cols)
+    args = QuantizationArgs(
+        num_bits=8,
+        type="float",
+        strategy=QuantizationStrategy.BLOCK,
+        symmetric=True,
+        dynamic=False,
+        block_structure=[bh, bw],
+    )
+    num_rb = math.ceil(rows / bh)
+    num_cb = math.ceil(cols / bw)
+    scale = torch.rand(num_rb, num_cb) + 0.1
+    zp = torch.zeros_like(scale)
+    q_min, q_max = calculate_range(args, x.device)
+    out = _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zp,
+        args=args,
+        do_quantize=True,
+        do_dequantize=False,
+        dtype=None,
+        global_scale=None,
+    )
+    assert out.shape == x.shape
+    # full fake-quantize roundtrip
+    out2 = _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zp,
+        args=args,
+        do_quantize=True,
+        do_dequantize=True,
+        dtype=None,
+        global_scale=None,
+    )
+    assert out2.shape == x.shape

tests/test_quantization/test_quant_args.py

Lines changed: 2 additions & 1 deletion
@@ -59,7 +59,8 @@ def test_block():
 
     block = QuantizationArgs(**kwargs)
     assert block.strategy == QuantizationStrategy.BLOCK
-    assert block.block_structure == kwargs["block_structure"]
+    assert block.block_structure == [2, 4]
+    assert block.block_structure != kwargs["block_structure"]  # "2x4" != [2, 4]
 
 
 def test_infer_strategy():
