Commit 379cb75

Remove zero_point_domain from quant configs (#2058)
* init
* up
* up
* up
1 parent 7513042 commit 379cb75

8 files changed: +179 −223 lines
benchmarks/microbenchmarks/utils.py

Lines changed: 15 additions & 9 deletions
@@ -255,24 +255,30 @@ def string_to_config(
         group_size = int(_quant_args[2])
         return UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq)
     elif "int8_dynamic_activation_intx_weight" in quantization:
-        from torchao.experimental.quant_api import (
-            Int8DynamicActivationIntxWeightConfig,
-        )
-        from torchao.quantization.granularity import PerGroup
-
         assert (
             high_precision_dtype == torch.float32
         ), "int8_dynamic_activation_intx_weight requires using high_precision_dtype=torch.float32"

+        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
+        from torchao.quantization.granularity import PerAxis, PerGroup
+        from torchao.quantization.quant_api import (
+            Int8DynamicActivationIntxWeightConfig,
+        )
+
         # Quantize model
         _quant_args = quantization.split("-")
         weight_dtype = getattr(torch, f"int{_quant_args[1]}")
-        granularity = PerGroup(int(_quant_args[2]))
-        has_weight_zeros = bool(_quant_args[3])
+        group_size = int(_quant_args[2])
+        granularity = PerGroup(group_size) if group_size > 0 else PerAxis(0)
+        is_asymmetric = bool(_quant_args[3])
         return Int8DynamicActivationIntxWeightConfig(
             weight_dtype=weight_dtype,
-            granularity=granularity,
-            has_weight_zeros=has_weight_zeros,
+            weight_granularity=granularity,
+            weight_mapping_type=MappingType.ASYMMETRIC
+            if is_asymmetric
+            else MappingType.SYMMETRIC,
+            weight_scale_dtype=torch.bfloat16,
+            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
         )
     elif "float8wo" in quantization:
         return Float8WeightOnlyConfig()
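
For reference, a minimal sketch of the configuration this file now builds from its CLI-style string: bit width, group size, and the asymmetry flag map onto the new config arguments as below. The helper name make_intx_config and the MappingType import path are assumptions for illustration, not part of this commit.

# Illustrative sketch (not part of the commit): building the new-style config.
import torch
from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.quant_primitives import MappingType  # assumed import path

def make_intx_config(bit_width: int, group_size: int, is_asymmetric: bool):
    # group_size == 0 selects per-channel (PerAxis(0)) granularity, otherwise per-group
    granularity = PerGroup(group_size) if group_size > 0 else PerAxis(0)
    return Int8DynamicActivationIntxWeightConfig(
        weight_dtype=getattr(torch, f"int{bit_width}"),
        weight_granularity=granularity,
        # asymmetry is now expressed through the mapping type, not a zero_point_domain kwarg
        weight_mapping_type=MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC,
        weight_scale_dtype=torch.bfloat16,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
    )

config = make_intx_config(bit_width=4, group_size=32, is_asymmetric=False)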

torchao/_models/llama/generate.py

Lines changed: 4 additions & 6 deletions
@@ -568,24 +568,22 @@ def ffn_or_attn_only(mod, fqn):
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import (
             Int8DynamicActivationIntxWeightConfig,
-            ZeroPointDomain,
         )

         # Quantize model
         _quant_args = quantization.split("-")
         weight_dtype = getattr(torch, f"int{_quant_args[1]}")
         group_size = int(_quant_args[2])
         granularity = PerGroup(group_size) if group_size > 0 else PerAxis(0)
-        has_weight_zeros = bool(_quant_args[3])
+        is_asymmetric = bool(_quant_args[3])
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 weight_granularity=granularity,
-                weight_zero_point_domain=ZeroPointDomain.INT
-                if has_weight_zeros
-                else ZeroPointDomain.NONE,
-                weight_mapping_type=MappingType.ASYMMETRIC,
+                weight_mapping_type=MappingType.ASYMMETRIC
+                if is_asymmetric
+                else MappingType.SYMMETRIC,
                 weight_scale_dtype=torch.bfloat16,
                 layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
             ),
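
A minimal usage sketch of the updated quantize_ call on a toy module, mirroring the symmetric (no weight zeros) path that previously required weight_zero_point_domain=ZeroPointDomain.NONE. The placeholder model and the quantize_/MappingType import paths are assumptions for illustration.

# Illustrative sketch (not part of the commit): symmetric int4 weights, per-group scales.
import torch
import torch.nn as nn
from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.quant_primitives import MappingType  # assumed import path

model = nn.Sequential(nn.Linear(256, 256))  # placeholder model
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
        # SYMMETRIC replaces the removed weight_zero_point_domain=ZeroPointDomain.NONE
        weight_mapping_type=MappingType.SYMMETRIC,
        weight_scale_dtype=torch.bfloat16,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
    ),
)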

torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py

Lines changed: 96 additions & 82 deletions
@@ -156,48 +156,109 @@ def from_plain(
         zero_point: Optional[torch.Tensor],
         layout: Layout,
         bias: Optional[torch.Tensor] = None,
+        *,
+        validate_inputs: bool = True,
     ):
         assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
-        assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
         assert layout.target in [
             t for t, _ in _TARGET_AND_STR
         ], f"Unexpected target: {layout.target}"
+        assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
+
+        if layout.target != Target.ATEN:
+            _check_torchao_ops_loaded()
+        else:
+            assert (
+                TORCH_VERSION_AT_LEAST_2_6
+            ), "aten target is requires torch version > 2.6.0"
+            assert (
+                torch.backends.kleidiai.is_available()
+            ), "ATEN target requires torch.backends.kleidiai.is_available()"
+            layout.bit_width == 4, "ATEN target only supports torch.int4"
+            assert not layout.has_weight_zeros, "ATEN target does not support zeros"
+
+        data_dtype = getattr(torch, f"int{layout.bit_width}")
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype]

         int_types = [torch.int8, torch.int16, torch.int32, torch.int64]

+        # Check int_data
+        assert int_data.device == torch.device("cpu")
+        assert int_data.dtype in int_types
         n, k = int_data.shape
-        assert int_data.dtype in int_types, f"int_data.dtype must be {int_types}"
         assert k % layout.group_size == 0, "k must be divisible by group_size"
+        if validate_inputs:
+            assert int_data.min().item() >= qmin
+            assert int_data.max().item() <= qmax
         int_data = int_data.to(torch.int8)

-        assert scale.dtype == torch.float32, "scale must be float32"
+        # Check scale
+        assert scale.device == torch.device("cpu")
+        if scale.dtype != torch.float32:
+            logging.info(f"scale has dtype {scale.dtype}, converting to torch.float32")
+            scale = scale.to(torch.float32)
+        n_, _ = scale.shape
+        assert n_ == n
         assert (
             scale.numel() * layout.group_size == int_data.numel()
         ), "must have 1 scale per group"
-
-        assert (zero_point is not None) == (
-            layout.has_weight_zeros
-        ), "zero_point being None must be consistent with layout.has_weight_zeros"
-        if zero_point is not None:
+        if validate_inputs:
+            assert scale.min().item() > 0
+            # Some targets round scales to bfloat16, give warning if scales are at higher precision
+            scale_is_rounded_to_bf16 = torch.allclose(
+                scale, scale.to(torch.bfloat16).to(torch.float32)
+            )
+            if not scale_is_rounded_to_bf16:
+                if layout.target == Target.ATEN and (layout.group_size < k):
+                    logging.warning(
+                        "When using Target.ATEN with group_size < k, scales will be rounded to bfloat16"
+                    )
+                if layout.target in [Target.AUTO, Target.KLEIDIAI]:
+                    logging.warning(
+                        "When using [Target.AUTO, Target.KLEIDIAI], scales will be rounded to bfloat16"
+                    )
+
+        # Check zero_point
+        if zero_point is None:
             assert (
-                zero_point.dtype in int_types
-            ), f"zero_point.dtype must be {int_types}"
+                not layout.has_weight_zeros
+            ), "zero_point must be provided if has_weight_zeros=True"
+        else:
+            assert zero_point.device == torch.device("cpu")
+            assert zero_point.shape == scale.shape
+            assert zero_point.dtype in int_types
             assert (
                 zero_point.numel() * layout.group_size == int_data.numel()
             ), "must have 1 zero_point per group"
+            if validate_inputs:
+                zero_point_min = zero_point.min().item()
+                zero_point_max = zero_point.max().item()
+                assert zero_point.min().item() >= qmin
+                assert zero_point.max().item() <= qmax
+                has_weight_zeros = True
+                if zero_point_min == 0 and zero_point_max == 0:
+                    has_weight_zeros = False
+                assert (
+                    has_weight_zeros == layout.has_weight_zeros
+                ), "zero_point being all zeros must be consistent with layout.has_weight_zeros"
             zero_point = zero_point.to(torch.int8)

-        assert (bias is not None) == (
-            layout.has_bias
+        # Check bias
+        has_bias = bias is not None
+        assert (
+            has_bias == layout.has_bias
         ), "bias being None must be consistent with layout.has_bias"
-        if bias is not None:
-            assert bias.dtype == torch.float32, "bias.dtype must be float32"
-            assert bias.shape == (n,), "bias must have shape n"
+        if has_bias:
+            assert bias.device == torch.device("cpu")
+            if bias.dtype != torch.float32:
+                logging.info(
+                    f"bias has dtype {bias.dtype}, converting to torch.float32"
+                )
+                bias = bias.to(torch.float32)
+            assert bias.shape == (n,)

+        # Construct packed_weight
         if layout.target == Target.ATEN:
-            assert (
-                TORCH_VERSION_AT_LEAST_2_6
-            ), "aten target is requires torch version > 2.6.0"
             int_data = int_data.add(8)
             int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)

@@ -213,12 +274,11 @@ def from_plain(
         args = [
             int_data,
             scale.reshape(-1),
-            zero_point.reshape(-1) if zero_point is not None else None,
+            zero_point.reshape(-1) if layout.has_weight_zeros else None,
             layout.group_size,
             bias,
             target_to_str(layout.target) if layout.target != Target.AUTO else None,
         ]
-
         packed_weight = getattr(
             torch.ops.torchao,
             f"_pack_8bit_act_{layout.bit_width}bit_weight",
@@ -358,79 +418,35 @@ def make_packed_linear_int8_dynamic_activation_intx_weight_tensor(
     assert TORCH_VERSION_AT_LEAST_2_6, "Using PackedLinearInt8DynamicActivationIntxWeightLayout requires torch version > 2.6.0"

     layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=target)
-    if layout.target != Target.ATEN:
-        _check_torchao_ops_loaded()
-    else:
-        assert (
-            torch.backends.kleidiai.is_available()
-        ), "ATEN target requires torch.backends.kleidiai.is_available()"
-        assert data_dtype == torch.int4, "ATEN target only supports torch.int4"
-        assert zero_point is None, "ATEN target does not support zeros"

-    assert data_dtype in [getattr(torch, f"int{x}") for x in range(1, 9)]
-    qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype]
     bit_width = _DTYPE_TO_BIT_WIDTH[data_dtype]
+    qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype]

-    int_types = [torch.int8, torch.int16, torch.int32, torch.int64]
-
-    # Check int_data
-    assert int_data.device == torch.device("cpu")
-    assert int_data.dtype in int_types
     n, k = int_data.shape
-    if validate_inputs:
-        assert int_data.min().item() >= qmin
-        assert int_data.max().item() <= qmax
-
-    # Check scale
-    assert scale.device == torch.device("cpu")
-    if scale.dtype != torch.float32:
-        logging.info(f"scale has dtype {scale.dtype}, converting to torch.float32")
-        scale = scale.to(torch.float32)
     n_, groups_per_k = scale.shape
-    assert n_ == n
     assert k % groups_per_k == 0
     group_size = k // groups_per_k
-    if validate_inputs:
-        assert scale.min().item() > 0

-    if validate_inputs:
-        # Some targets round scales to bfloat16, give warning if scales are at higher precision
-        scale_is_rounded_to_bf16 = torch.allclose(
-            scale, scale.to(torch.bfloat16).to(torch.float32)
-        )
-        if not scale_is_rounded_to_bf16:
-            if layout.target == Target.ATEN and (group_size < k):
-                logging.warning(
-                    "When using Target.ATEN with group_size < k, scales will be rounded to bfloat16"
-                )
-            if layout.target in [Target.AUTO, Target.KLEIDIAI]:
-                logging.warning(
-                    "When using [Target.AUTO, Target.KLEIDIAI], scales will be rounded to bfloat16"
-                )
-
-    # Check zero_point
-    has_weight_zeros = zero_point is not None
-    if has_weight_zeros:
-        assert zero_point.device == torch.device("cpu")
-        assert zero_point.shape == scale.shape
-        assert zero_point.dtype in int_types
-        if validate_inputs:
-            assert zero_point.min().item() >= qmin
-            assert zero_point.max().item() <= qmax
+    has_weight_zeros = True
+    if zero_point is None:
+        has_weight_zeros = False
+    else:
+        zero_point_min = zero_point.min().item()
+        zero_point_max = zero_point.max().item()
+        if zero_point_min == 0 and zero_point_max == 0:
+            has_weight_zeros = False

-    # Check bias
     has_bias = bias is not None
-    if has_bias:
-        assert bias.device == torch.device("cpu")
-        if bias.dtype != torch.float32:
-            logging.info(f"bias has dtype {bias.dtype}, converting to torch.float32")
-            bias = bias.to(torch.float32)
-        assert bias.shape == (n,)

     layout.set_params(bit_width, group_size, has_weight_zeros, has_bias)
     assert layout.has_params_set()
     tensor_impl = PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl.from_plain(
-        int_data, scale, zero_point, layout, bias
+        int_data,
+        scale,
+        zero_point,
+        layout,
+        bias,
+        validate_inputs=validate_inputs,
     )

     return AffineQuantizedTensor(
@@ -439,7 +455,5 @@ def make_packed_linear_int8_dynamic_activation_intx_weight_tensor(
         shape=int_data.shape,
         quant_min=qmin,
         quant_max=qmax,
-        zero_point_domain=ZeroPointDomain.INT
-        if has_weight_zeros
-        else ZeroPointDomain.NONE,
+        zero_point_domain=ZeroPointDomain.INT,
     )
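
With zero_point_domain removed from the config surface, make_packed_linear_int8_dynamic_activation_intx_weight_tensor now infers whether weight zeros are present from the zero_point tensor itself. The helper below restates that inference in isolation; the function name is illustrative, not part of the commit.

# Illustrative sketch: has_weight_zeros is now derived from the zero_point contents.
from typing import Optional
import torch

def infer_has_weight_zeros(zero_point: Optional[torch.Tensor]) -> bool:
    if zero_point is None:
        return False
    # An all-zero zero_point is treated the same as no zero_point at all
    return not (zero_point.min().item() == 0 and zero_point.max().item() == 0)

print(infer_has_weight_zeros(None))                                 # False
print(infer_has_weight_zeros(torch.zeros(8, 4, dtype=torch.int8)))  # False
print(infer_has_weight_zeros(torch.ones(8, 4, dtype=torch.int8)))   # True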
