Commit 19237e3
Update QAT README and API docstrings (#2465)
Previously they pointed to the 0.7.0 code. Now they point to the corresponding API page on our docs. Also move the docstring from the private function to the public `IntXQuantizationAwareTrainingConfig` so it shows up on the doc page.
1 parent 44d1dd3 · commit 19237e3

2 files changed: +23 −25 lines

torchao/quantization/qat/README.md

Lines changed: 6 additions & 6 deletions
@@ -71,9 +71,9 @@ def train_loop(m: torch.nn.Module):
 
 The recommended way to run QAT in torchao is through the `quantize_` API:
 1. **Prepare:** specify how weights and/or activations are to be quantized through
-[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242)
+[`FakeQuantizeConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.FakeQuantizeConfig.html#torchao.quantization.qat.FakeQuantizeConfig) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.IntXQuantizationAwareTrainingConfig.html#torchao.quantization.qat.IntXQuantizationAwareTrainingConfig)
 2. **Convert:** quantize the model using the standard post-training quantization (PTQ)
-functions such as [`Int8DynamicActivationInt4WeightConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)
+functions such as [`Int8DynamicActivationInt4WeightConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int8DynamicActivationInt4WeightConfig.html#torchao.quantization.Int8DynamicActivationInt4WeightConfig)
 
 For example:
 
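The README's example body is outside this hunk's context. For reference, a minimal sketch of the two-step prepare/convert flow described above, using the config names from the linked docs pages (the `FakeQuantizeConfig` arguments, the model, and the group size are illustrative assumptions, not part of this commit):

```python
import torch

from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(torch.nn.Linear(512, 512))

# Prepare: insert fake quantization ops simulating int8 per-token dynamic
# activations and int4 per-group symmetric weights during training
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, IntXQuantizationAwareTrainingConfig(activation_config, weight_config))

# ... training loop runs on the fake-quantized model ...

# Convert: strip the fake quantization ops, then apply the real PTQ config
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```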

@@ -137,9 +137,9 @@ quantize_(
 
 Alternatively, torchao provides a few hardcoded quantization settings through
 the following Quantizers:
-- [Int8DynActInt4QATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L126) (linear), targeting int8 per-token dynamic asymmetric activation + int4 per-group symmetric weight
-- [Int4WeightOnlyQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L308) (linear), targeting int4 per-group asymmetric weight using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
-- [Int4WeightOnlyEmbeddingQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/embedding.py#L94) (embedding), targeting int4 per-group symmetric weight
+- [Int8DynActInt4QATQuantizer](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer.html#torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer) (linear), targeting int8 per-token dynamic asymmetric activation + int4 per-group symmetric weight
+- [Int4WeightOnlyQATQuantizer](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.Int4WeightOnlyQATQuantizer.html#torchao.quantization.qat.Int4WeightOnlyQATQuantizer) (linear), targeting int4 per-group asymmetric weight using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
+- [Int4WeightOnlyEmbeddingQATQuantizer](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer.html#torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer) (embedding), targeting int4 per-group symmetric weight
 
 For example:
 ```python
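The fenced example body is elided from this diff. A hedged sketch of the Quantizer flow it illustrates (note the first link's text says `Int8DynActInt4QATQuantizer` while the docs page it targets is `Int8DynActInt4WeightQATQuantizer`; the constructor argument and model here are assumptions):

```python
import torch

from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

model = torch.nn.Sequential(torch.nn.Linear(512, 512))

# Prepare: swap nn.Linear modules for fake-quantized equivalents
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
model = qat_quantizer.prepare(model)

# ... training loop ...

# Convert: replace fake quantization with actual quantized weights and kernels
model = qat_quantizer.convert(model)
```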
@@ -162,7 +162,7 @@ model = qat_quantizer.convert(model)
 ```
 
 To use multiple Quantizers in the same model for different layer types,
-users can also leverage the [ComposableQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L242)
+users can also leverage the [ComposableQATQuantizer](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.ComposableQATQuantizer.html#torchao.quantization.qat.ComposableQATQuantizer)
 as follows:
 
 ```python
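This example body is likewise outside the diff context. A sketch of composing Quantizers per layer type, assuming default constructor arguments and a toy model:

```python
import torch

from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(1000, 512)
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        return self.linear(self.embed(x))

model = TinyModel()

# Each child Quantizer only transforms the layer type it targets:
# linears get int8 dynamic activation + int4 weight fake quantization,
# embeddings get int4 weight-only fake quantization
composable_quantizer = ComposableQATQuantizer([
    Int8DynActInt4WeightQATQuantizer(),
    Int4WeightOnlyEmbeddingQATQuantizer(),
])
model = composable_quantizer.prepare(model)
# ... training loop ...
model = composable_quantizer.convert(model)
```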

torchao/quantization/qat/api.py

Lines changed: 17 additions & 19 deletions
@@ -255,24 +255,8 @@ def __setattr__(self, name: str, value: Any):
 
 @dataclass
 class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
-    activation_config: Optional[FakeQuantizeConfig] = None
-    weight_config: Optional[FakeQuantizeConfig] = None
-
-
-# for BC
-intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
-
-
-@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
-def _intx_quantization_aware_training_transform(
-    module: torch.nn.Module,
-    config: IntXQuantizationAwareTrainingConfig,
-) -> torch.nn.Module:
     """
-    THIS IS NOT A PUBLIC API - any usage of this outside of torchao
-    can break at any time.
-
-    Apply fake quantization to a `torch.nn.Module`.
+    Config for applying fake quantization to a `torch.nn.Module`.
     to be used with :func:`~torchao.quantization.quant_api.quantize_`.
 
     Example usage::
@@ -290,11 +274,25 @@ def _intx_quantization_aware_training_transform(
             IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
         )
 
-    Note: If the returned function is applied on a module that is not
+    Note: If the config is applied on a module that is not
     `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
     `torch.nn.Embedding` with an activation config, then we will raise
     ValueError as these are not supported.
     """
+
+    activation_config: Optional[FakeQuantizeConfig] = None
+    weight_config: Optional[FakeQuantizeConfig] = None
+
+
+# for BC
+intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
+
+
+@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
+def _intx_quantization_aware_training_transform(
+    module: torch.nn.Module,
+    config: IntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
     from .embedding import FakeQuantizedEmbedding
     from .linear import FakeQuantizedLinear
 
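The moved docstring documents behavior that still lives in the private transform registered above. A simplified, hypothetical sketch of what that handler does, per the Note in this hunk (the `from_linear`/`from_embedding` constructors and their signatures are assumptions inferred from the imports shown here):

```python
import torch

from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear

def transform_sketch(module, config):
    # Hypothetical simplification of _intx_quantization_aware_training_transform:
    # swap supported module types for fake-quantized versions, reject the rest
    act, weight = config.activation_config, config.weight_config
    if isinstance(module, torch.nn.Linear):
        return FakeQuantizedLinear.from_linear(module, act, weight)
    if isinstance(module, torch.nn.Embedding):
        if act is not None:
            # embeddings have no activations to fake-quantize
            raise ValueError("Activation fake quantization is not supported for embeddings")
        return FakeQuantizedEmbedding.from_embedding(module, weight)
    raise ValueError(f"Module of type {type(module)} is not supported")
```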

@@ -320,7 +318,7 @@ def _intx_quantization_aware_training_transform(
 
 class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
     """
-    Object that knows how to convert a model with fake quantized modules,
+    Config for converting a model with fake quantized modules,
     such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
     and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
     back to model with the original, corresponding modules without
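Rounding out the pair: a minimal sketch of using this config to restore the original module types after QAT, assuming usage mirrors the prepare step (the model is a stand-in for one already prepared and trained):

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import FromIntXQuantizationAwareTrainingConfig

model = torch.nn.Sequential(torch.nn.Linear(512, 512))  # stand-in for a QAT-prepared model

# Swap FakeQuantizedLinear/FakeQuantizedEmbedding back to plain
# nn.Linear/nn.Embedding modules, keeping the trained weights
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
```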
