[BE] Convert quantization internal methods private #2568

Merged · 10 commits · Jul 24, 2025
12 changes: 6 additions & 6 deletions test/integration/test_integration.py
@@ -68,11 +68,11 @@
LoggingTensorMode,
_apply_logging_hook,
_fqn_to_op_to_shape_to_count,
_quant_int8_dynamic_per_token_linear,
_quantize_activation_per_token_absmax,
compute_error,
dequantize_per_channel,
dynamically_quantize_per_channel,
quant_int8_dynamic_per_token_linear,
quantize_activation_per_token_absmax,
)
from torchao.quantization.utils import (
compute_error as SQNR,
@@ -557,7 +557,7 @@ def test_dynamic_quant_per_channel_numerics_cuda(self):

def _test_quantize_per_token_impl(self, device, dtype):
x = torch.randn(3, 3, 3, device=device, dtype=dtype)
xq, scales = quantize_activation_per_token_absmax(x)
xq, scales = _quantize_activation_per_token_absmax(x)
block_size = (1, 1, 3)
x_dq = dequantize_affine(
xq, block_size, scales, None, torch.int8, output_dtype=x.dtype
@@ -581,7 +581,7 @@ def _test_per_token_linear_impl(self, device, dtype):
# Note: need to make the weight contiguous because we are
# testing in eager mode and cuBlas will not give correct results
# for a transposed weight
y = quant_int8_dynamic_per_token_linear(
y = _quant_int8_dynamic_per_token_linear(
x, wq.t().contiguous(), w_scales, None, dtype
)
y_ref = torch.matmul(x, w.t())
@@ -1679,9 +1679,9 @@ def forward(self, x):
assert not isinstance(mod.mha.out_proj.weight, AutoQuantizableLinearWeight)
assert isinstance(mod.lin.weight, AutoQuantizableLinearWeight)
mod(*input)
from torchao.quantization.autoquant import AUTOQUANT_CACHE
from torchao.quantization.autoquant import _AUTOQUANT_CACHE

assert len(AUTOQUANT_CACHE) > 0
assert len(_AUTOQUANT_CACHE) > 0

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
12 changes: 6 additions & 6 deletions test/quantization/test_quant_primitives.py
@@ -23,10 +23,10 @@

# TODO: remove test for utils?
from torchao.quantization.utils import (
_quantize_activation_per_token_absmax,
get_group_qparams_symmetric,
groupwise_affine_dequantize_tensor_from_qparams,
groupwise_affine_quantize_tensor_from_qparams,
quantize_activation_per_token_absmax,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
@@ -352,7 +352,7 @@ def test_choose_qparams_tensor_sym(self):
)
def test_quantize_activation_per_token_abs_max(self):
input = torch.randn(10, 10)
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input)

mapping_type = MappingType.SYMMETRIC
block_size = list(input.shape)
@@ -386,22 +386,22 @@ def test_quantize_activation_per_token_abs_max(self):
def test_quantize_activation_per_token_abs_max_zero_input(self):
input = torch.zeros(10, 10)
# make sure it still works
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input)

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_quantize_activation_per_token_abs_max_dtype(self):
input = torch.zeros(10, 10, dtype=torch.bfloat16)
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input)
self.assertTrue(scale_ref.dtype, torch.bfloat16)

input = torch.zeros(10, 10, dtype=torch.float32)
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input)
self.assertTrue(scale_ref.dtype, torch.float32)

input = torch.zeros(10, 10, dtype=torch.float16)
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input)
self.assertTrue(scale_ref.dtype, torch.float32)

@unittest.skipIf(
6 changes: 3 additions & 3 deletions torchao/_models/llama/model.py
@@ -192,7 +192,7 @@ def update(self, input_pos, k_val, v_val):
return k_out, v_out


from torchao.quantization.utils import quantize_activation_per_token_absmax
from torchao.quantization.utils import _quantize_activation_per_token_absmax


class AffineQuantizedKVCache(nn.Module):
@@ -218,13 +218,13 @@ def __init__(

def update(self, input_pos, k_val, v_val):
# quantize current k_val and store it in the cache
q_k_val, k_scale = quantize_activation_per_token_absmax(k_val)
q_k_val, k_scale = _quantize_activation_per_token_absmax(k_val)
self.k_cache[:, :, input_pos] = q_k_val
self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1)
k_out = self.k_cache * self.k_cache_scale
k_out[:, :, input_pos] = k_val

q_v_val, v_scale = quantize_activation_per_token_absmax(v_val)
q_v_val, v_scale = _quantize_activation_per_token_absmax(v_val)
self.v_cache[:, :, input_pos] = q_v_val
self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1)
v_out = self.v_cache * self.v_cache_scale
12 changes: 6 additions & 6 deletions torchao/prototype/quantization/autoquant_v2.py
@@ -45,7 +45,7 @@
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
)
from torchao.quantization.utils import quantize_activation_per_token_absmax
from torchao.quantization.utils import _quantize_activation_per_token_absmax
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_5,
@@ -110,7 +110,7 @@ def _graph_equals(g1, g2):

aten = torch.ops.aten

AUTOQUANT_CACHE = {}
_AUTOQUANT_CACHE = {}

# This is a flag to control whether we do some rewrite for graph
# to account for different batch sizes, it's a temporary solution for llama model
@@ -119,15 +119,15 @@ def _graph_equals(g1, g2):


def check_cache(gm, cls, shapes_and_dtype):
for gm_, cls_, shapes_and_dtype_ in AUTOQUANT_CACHE.keys():
for gm_, cls_, shapes_and_dtype_ in _AUTOQUANT_CACHE.keys():
graph_equals = _graph_equals(gm_.graph, gm.graph)
if graph_equals and cls_ is cls and shapes_and_dtype_ == shapes_and_dtype:
return AUTOQUANT_CACHE[(gm_, cls_, shapes_and_dtype_)]
return _AUTOQUANT_CACHE[(gm_, cls_, shapes_and_dtype_)]
return None


def update_cache(gm, cls, shapes_and_dtype, res):
AUTOQUANT_CACHE[(gm, cls, shapes_and_dtype)] = res
_AUTOQUANT_CACHE[(gm, cls, shapes_and_dtype)] = res


# adjust each input's bsz to target_bsz
@@ -638,7 +638,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
# SAM best is between .8 and 1, SDXL also performs best in this range
INTERPOLATION_CONSTANT = mode[1]
w_qtensor = cls.from_float(weight)
x_vals_int8, x_scales = quantize_activation_per_token_absmax(
x_vals_int8, x_scales = _quantize_activation_per_token_absmax(
act_mat.reshape(-1, act_mat.shape[-1])
)
quantized_matmul = (
10 changes: 5 additions & 5 deletions torchao/quantization/README.md
@@ -101,21 +101,21 @@ When used as in the example above, when the `autoquant` api is called alongside

When `model(input)` is called, the tool (under the hood) does a preliminary run with the input, during which each linear layer keeps track of the different shapes and types of activations it sees. Once the preliminary run is complete, each linear layer is checked and the tracked shapes are benchmarked against the different quantization techniques in order to pick the fastest one, attempting to take fusions into account where possible. Once the best class is found for each layer, the corresponding quantization technique is applied to that layer, and finally the normal `torch.compile` process runs on the now-quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option to add int4 quantization, which can be used for maximum performance or to avoid perf regressions from `Int4WeightOnlyConfig()`, since in certain (compute bound) regimes int4 weight only quantization can be very slow.
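
A minimal sketch of that flow, assuming a CUDA device (the toy model, shapes, and compile mode below are illustrative placeholders, not taken from this PR):

```python
import torch
import torchao

# toy stand-in model; any nn.Module containing linear layers works
model = torch.nn.Sequential(
    torch.nn.Linear(64, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 64),
).to("cuda")

# wrap with autoquant, then compile as usual
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))

# preliminary run: activation shapes/dtypes are logged per linear layer,
# candidate quantization techniques are benchmarked, and the fastest one
# is applied to each layer before the usual torch.compile run proceeds
example_input = torch.randn(32, 64, device="cuda")
model(example_input)
```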

Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods.
Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization._AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods.

```python
import pickle
import torchao.quantization

# After the first forward pass (when quantization was done)
from torchao.quantization.autoquant import AUTOQUANT_CACHE
from torchao.quantization.autoquant import _AUTOQUANT_CACHE
with open("quantization-cache.pkl", "wb") as f:
pickle.dump(AUTOQUANT_CACHE, f)
pickle.dump(_AUTOQUANT_CACHE, f)

# On load
from torchao.quantization.autoquant import AUTOQUANT_CACHE
from torchao.quantization.autoquant import _AUTOQUANT_CACHE
with open("quantization-cache.pkl", "rb") as f:
AUTOQUANT_CACHE.update(pickle.load(f))
_AUTOQUANT_CACHE.update(pickle.load(f))
```

## Quantization Techniques
28 changes: 14 additions & 14 deletions torchao/quantization/autoquant.py
@@ -27,8 +27,8 @@
ZeroPointDomain,
)
from torchao.quantization.utils import (
_quantize_activation_per_token_absmax,
compute_error,
quantize_activation_per_token_absmax,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
@@ -63,15 +63,15 @@

aten = torch.ops.aten

AUTOQUANT_CACHE = {}
_AUTOQUANT_CACHE = {}


def check_cache(cls, shapes_and_dtype):
return AUTOQUANT_CACHE.get((cls,) + shapes_and_dtype, None)
def _check_cache(cls, shapes_and_dtype):
return _AUTOQUANT_CACHE.get((cls,) + shapes_and_dtype, None)


def update_cache(cls, shapes_and_dtype, res):
AUTOQUANT_CACHE[(cls,) + shapes_and_dtype] = res
def _update_cache(cls, shapes_and_dtype, res):
_AUTOQUANT_CACHE[(cls,) + shapes_and_dtype] = res


# TODO: Document the methods
@@ -145,12 +145,12 @@ def log_shape(act_mat, w_autoquant, bias):
shapes_and_dtype, 0
)
for q_cls in w_autoquant.qtensor_class_list:
if check_cache(q_cls, shapes_and_dtype) is None:
update_cache(q_cls, shapes_and_dtype, None)
if _check_cache(q_cls, shapes_and_dtype) is None:
_update_cache(q_cls, shapes_and_dtype, None)

def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype
if check_cache(q_cls, shapes_and_dtype) is None:
if _check_cache(q_cls, shapes_and_dtype) is None:
with torch.no_grad():
act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
bias = (
@@ -183,7 +183,7 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}"
)
res = torch.inf
update_cache(q_cls, shapes_and_dtype, res)
_update_cache(q_cls, shapes_and_dtype, res)

@torch.no_grad()
def to_quantized(self, error_on_unseen, **kwargs):
@@ -223,13 +223,13 @@ def count_shapes(self, do_print=True):
total_seen = 0
shape_count = count_shapes(self, do_print=False)
for shapes_and_dtype, times_seen in self.logged_data.items():
if check_cache(q_cls, shapes_and_dtype) is None:
if _check_cache(q_cls, shapes_and_dtype) is None:
# only print shapes once
if print_shape_once:
print_shape_once = False
count_shapes(self, do_print=True)

time_for_best_shape = check_cache(best_cls, shapes_and_dtype)
time_for_best_shape = _check_cache(best_cls, shapes_and_dtype)
time_for_best_shape = (
torch.inf
if time_for_best_shape is None
@@ -238,7 +238,7 @@ def count_shapes(self, do_print=True):
self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape)
ran_new_benchmarks = True
torch._dynamo.reset()
cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
cur_time += _check_cache(q_cls, shapes_and_dtype) * times_seen
total_seen += times_seen
cur_time = cur_time / total_seen
# print aggregated time if there were multiple shapes to aggregate and some new benchmarking was done
@@ -498,7 +498,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
# SAM best is between .8 and 1, SDXL also performs best in this range
INTERPOLATION_CONSTANT = mode[1]
w_qtensor = cls.from_float(weight)
x_vals_int8, x_scales = quantize_activation_per_token_absmax(
x_vals_int8, x_scales = _quantize_activation_per_token_absmax(
act_mat.reshape(-1, act_mat.shape[-1])
)
quantized_matmul = (
4 changes: 2 additions & 2 deletions torchao/quantization/dynamic_quant.py
@@ -8,8 +8,8 @@
import torch.nn as nn

from .utils import (
_quant_int8_dynamic_per_token_linear,
dynamically_quantize_per_channel,
quant_int8_dynamic_per_token_linear,
)

__all__ = ["DynamicallyPerAxisQuantizedLinear"]
@@ -44,7 +44,7 @@ def forward(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor:

"""

Y = quant_int8_dynamic_per_token_linear(
Y = _quant_int8_dynamic_per_token_linear(
X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype
)
return Y
4 changes: 2 additions & 2 deletions torchao/quantization/smoothquant.py
@@ -16,8 +16,8 @@
import torch.nn.functional as F

from .utils import (
_quant_int8_dynamic_per_token_linear,
dynamically_quantize_per_channel,
quant_int8_dynamic_per_token_linear,
)

__all__ = [
@@ -152,7 +152,7 @@ def forward(self, X, *args, **kwargs):
W_int_repr_t = (
self.W_int_repr if self.store_w_int_repr_t else self.W_int_repr.t()
)
Y = quant_int8_dynamic_per_token_linear(
Y = _quant_int8_dynamic_per_token_linear(
X, W_int_repr_t, self.W_scales, self.bias, X.dtype
)
return Y
4 changes: 2 additions & 2 deletions torchao/quantization/subclass.py
@@ -9,10 +9,10 @@
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.quantization.utils import (
_quant_int8_dynamic_per_token_linear,
dequantize_per_channel,
dynamically_quantize_per_channel,
groupwise_affine_quantize_tensor,
quant_int8_dynamic_per_token_linear,
unpack_tinygemm_scales_and_zeros,
)
from torchao.utils import (
@@ -244,7 +244,7 @@ def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs):

@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
return quant_int8_dynamic_per_token_linear(
return _quant_int8_dynamic_per_token_linear(
act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype
)
