
Commit 50a555b

Add quantize_ nn.Parameter support (#3083)
This PR adds `nn.Parameter` quantization support to `quantize_`.

### bc-breaking changes

The top-level `quantize_` API has the following bc-breaking changes:

1) Passing in both `filter_fn` and `ModuleFqnToConfig` is no longer supported and will now throw a `ValueError` if both are specified. Previously, we would quantize all modules that were both matched by `filter_fn` and specified in `ModuleFqnToConfig`. Users should now manually specify `filter_fn=None` when using `ModuleFqnToConfig`/`FqnToConfig`.
2) The semantics of `filter_fn=None` have changed. Previously, passing in `None` would default to `_is_linear` when running `quantize_`. Now, when `filter_fn=None` is specified, we ignore `filter_fn` completely and rely only on `FqnToConfig` to quantize the model. Note that this is equivalent to passing in `filter_fn=lambda mod, fqn: True` in the previous API.
3) The default `filter_fn` has changed from `None` to `_is_linear`, and `_default` in `ModuleFqnToConfig` now only applies to linear layers. Previously, `_default` would apply to all modules that passed `filter_fn`. We plan to deprecate `_default` in the future; please see #3229 for more details.

Before:

```python
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.Linear(128, 128),
    torch.nn.Conv2d(128, 128, 3, 1, 1),
).cuda().to(torch.bfloat16)

config = ModuleFqnToConfig({
    "0": Float8DynamicActivationFloat8WeightConfig(),
})

# these are equivalent
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config, filter_fn=None)
quantize_(model, config)
```

```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, bias=True)
  (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
```

After:

```python
# user must specify None
quantize_(model, config, filter_fn=None)
```

```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, bias=True)
  (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
```

After:

```python
# these now error
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config)
```

```
> ValueError: Custom filter_fn and FqnToConfig were both specified. Only filter_fn=None is supported when FqnToConfig is specified.
```
#### Example for `_default` changes

Before:

```python
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("weight", torch.nn.Parameter(torch.randn(128, 128)))

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.Linear(128, 128),
    MyModule(),
).cuda().to(torch.bfloat16)

config = ModuleFqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(),
})

quantize_(
    model,
    config,
    filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) or isinstance(mod, MyModule),
)
```

```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (2): MyModule(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
)
```

After:

```python
# _default is now only applied to linear layers
quantize_(model, config, filter_fn=None)
```

```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (2): MyModule()
)
```

### Summary

`ModuleFqnToConfig` has been renamed to `FqnToConfig`, which now accepts both module FQNs and parameter FQNs. `ModuleFqnToConfig` has been kept as an alias to maintain BC.
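To tie the summary together, here is a minimal sketch of quantizing a raw `nn.Parameter` by its FQN. The module layout and FQNs are illustrative; per this PR, only `Float8DynamicActivationFloat8WeightConfig` supports parameter FQNs, and `filter_fn=None` must be passed whenever an `FqnToConfig` is used:

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    FqnToConfig,
    quantize_,
)
from torchao.quantization.granularity import PerRow

class MyModule(torch.nn.Module):
    """Toy module that holds a raw nn.Parameter rather than an nn.Linear."""
    def __init__(self):
        super().__init__()
        self.register_parameter("weight", torch.nn.Parameter(torch.randn(128, 128)))

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    MyModule(),
).cuda().to(torch.bfloat16)

config = FqnToConfig({
    # module FQN: quantizes the linear's weight, as before
    "0": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
    # parameter FQN: quantizes the raw nn.Parameter directly (new in this PR)
    "1.weight": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
})

# filter_fn must be explicitly None when an FqnToConfig is passed
quantize_(model, config, filter_fn=None)
```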
The keys to `FqnToConfig` can be one of the following (in order of precedence):

1) exact parameter FQN

```python
quant_config = FqnToConfig({
    "linear1.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

2) exact module FQN

```python
quant_config = FqnToConfig({
    "linear1": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

3) regex that matches a parameter FQN (prefixed with `re:`)

```python
quant_config = FqnToConfig({
    "re:linear*.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

4) regex that matches a module FQN (prefixed with `re:`)

```python
quant_config = FqnToConfig({
    "re:linear*": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

5) `_default`, which only applies to `nn.Linear` layers

```python
quant_config = FqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

To enable parameter-FQN support for a particular config, we need to add the `parameter_name` kwarg to the config signature and update `CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS`. See the changes [here](https://github.com/pytorch/ao/pull/3083/files#diff-bf4d50867e3d649de2d89146592bf47d2f258c4c19126c8acf0e120ee904b726R1874) for more details. `Float8DynamicActivationFloat8WeightConfig` has been enabled by this PR; other configs will throw a `NotImplementedError`.

### Test Plan

1) unit tests for the new config:

```
pytest test/quantization/test_quant_api.py::TestFqnToConfig
```

2) regression test for `ModuleFqnToConfig`:

```
pytest test/quantization/test_quant_api.py -k test_module_fqn_to_config
```

3) Make sure that we can load old HF checkpoints to maintain BC: run [this](https://huggingface.co/torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev#test-loading).

4) Make sure that this doesn't break BC with transformers:

```
pytest tests/quantization/torchao_integration/test_torchao.py -k test_module_fqn_to_config
```

5) Make sure that this doesn't break BC in vLLM:

```
pytest tests/quantization/test_torchao.py
```

___

## How do our configs translate for MoEs?

Currently, we define a number of configs for dense `nn.Linear` modules; how do these configs translate in the case of MoE inference?

### Some background on MoE inference

There are two ways that the forward pass is implemented for MoE (a schematic sketch of both styles is shown below):

- A for loop of `nn.Linear`: we break down the 3d weight x activation matmul into a for loop of 2d weight x activation matmuls. This can be seen [here](https://github.com/huggingface/transformers/blob/6cade29278c4aee3f174f8950f97a3873bdb212f/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L117). **In this case, I argue that the semantics of the configs do not change at all from the normal `nn.Linear` case, as we are just doing a bunch of normal 2d linear matmuls.**
- bmm/grouped mm on the 3d weights/activations directly. **For this case, we'd need to add additional op support (bmm) for the forward pass. Depending on whether the subclass is an AQT subclass or a non-AQT subclass, this will be added differently.**

I plan to only support parameter quantization for non-AQT subclasses, my reasoning being that those are the most popular / important configs anyway (Float8Dynamic, Int4WeightOnly).
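As a rough illustration of the two forward styles above (the shapes are arbitrary and the code is schematic, not taken from any particular model):

```python
import torch

E, N, K, T = 4, 256, 128, 32  # experts, out_features, in_features, tokens per expert

w = torch.randn(E, N, K)      # 3d expert weight
x = torch.randn(E, T, K)      # activations already routed per expert

# Style 1: a for loop of ordinary 2d linear matmuls.
# Each slice w[e] behaves exactly like an nn.Linear weight, so the existing
# per-linear config semantics carry over unchanged.
out_loop = torch.stack(
    [torch.nn.functional.linear(x[e], w[e]) for e in range(E)]
)

# Style 2: a single bmm over the 3d weight/activations.
# This is the path that needs extra op support (bmm) on the quantized tensor
# subclass before parameter quantization can be used here.
out_bmm = torch.bmm(x, w.transpose(-2, -1))

# Both styles compute the same result (loose tolerances for accumulation-order differences).
torch.testing.assert_close(out_loop, out_bmm, rtol=1e-4, atol=1e-4)
```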
Below is a breakdown of which configs map to AQT / non-AQT subclasses:

| not using AQT | AffineQuantizedTensor |
|-----------|---------------|
| Float8DynamicActivationFloat8WeightConfig | FPXWeightOnlyConfig |
| Float8DynamicActivationInt4WeightConfig | Float8WeightOnlyConfig |
| Float8StaticActivationFloat8WeightConfig | Float8DynamicActivationFloat8SemiSparseWeightConfig |
| Int4WeightOnlyConfig (v2) | GemliteUIntXWeightOnlyConfig |
| | Int4DynamicActivationInt4WeightConfig |
| | Int8DynamicActivationInt4WeightConfig |
| | Int8DynamicActivationInt8WeightConfig |
| | Int8WeightOnlyConfig |
| | IntxWeightOnlyConfig |
| | UIntXWeightOnlyConfig |

For these, the majority of the semantics remain the same; the only semantics that really change are for `PerRow` granularity, and there is a very natural extension of `PerRow` to the 3d case (apply it along the last dimension). I took a look at the keys of the non-AQT configs below and what they would mean for MoEs.

#### Float8DynamicActivationFloat8WeightConfig

```
[('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity', typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.List[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('activation_value_lb', typing.Optional[float]),
 ('activation_value_ub', typing.Optional[float]),
 ('kernel_preference', <enum 'KernelPreference'>),
 ('set_inductor_config', <class 'bool'>),
 ('version', <class 'int'>)]
```

`activation_dtype`, `weight_dtype`, `activation_value_lb`, and `activation_value_ub` do not change meaning semantically. `granularity=PerTensor()` does not change semantic meaning: we still use a single scale for the entire weight tensor. `granularity=PerRow()` does change meaning: we now calculate a scale for each row along the last dimension [-1], i.e. for a weight of shape (E, N, K) we would expect PerRow to create scales with block size (1, 1, K). `mm_config`, `kernel_preference`, and `set_inductor_config` stay the same as well.

#### Float8StaticActivationFloat8WeightConfig

```
[('scale', <class 'torch.Tensor'>),
 ('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity', typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.Tuple[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')], typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('set_inductor_config', <class 'bool'>)]
```

`scale` should be passed in as a 3d tensor instead of a 2d tensor in the case of `PerRow` granularity.

#### Float8DynamicActivationInt4WeightConfig

```
[('int4_packing_format', <enum 'Int4PackingFormat'>)]
```

`int4_packing_format`: only "preshuffled" is supported, and Int4PreshuffledTensor [supports](https://github.com/pytorch/ao/blob/895573980e085b02a2c6abbc82239bae7f1318d6/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py#L154) 3d weights.
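To make the `PerRow` extension described above concrete, here is a hedged sketch of how per-row float8 scales could be computed for a 3d expert weight with block size (1, 1, K). This is illustrative math, not the torchao kernel:

```python
import torch

E, N, K = 4, 256, 128
w = torch.randn(E, N, K)

# PerRow on a 3d (E, N, K) weight: one scale per row of the last dimension,
# i.e. block size (1, 1, K) -> scales of shape (E, N, 1).
f8_max = torch.finfo(torch.float8_e4m3fn).max
scale = w.abs().amax(dim=-1, keepdim=True) / f8_max      # (E, N, 1)
w_f8 = (w / scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)

# PerTensor is unchanged: a single scale for the whole 3d tensor.
scale_pt = w.abs().amax() / f8_max                        # scalar

print(scale.shape, w_f8.shape, scale_pt.shape)
# torch.Size([4, 256, 1]) torch.Size([4, 256, 128]) torch.Size([])
```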
#### Int4WeightOnlyConfig

```
[('group_size', <class 'int'>),
 ('layout', typing.Optional[torchao.dtypes.uintx.tensor_core_tiled_layout.TensorCoreTiledLayout]),
 ('use_hqq', <class 'bool'>),
 ('zero_point_domain', typing.Optional[torchao.quantization.quant_primitives.ZeroPointDomain]),
 ('set_inductor_config', <class 'bool'>),
 ('preserve_zero', typing.Optional[bool]),
 ('int4_packing_format', <enum 'Int4PackingFormat'>),
 ('int4_choose_qparams_algorithm', <enum 'Int4ChooseQParamsAlgorithm'>),
 ('version', <class 'int'>)]
```

`group_size`, `int4_packing_format`, `int4_choose_qparams_algorithm`, and `set_inductor_config` are the only keys used by the v2 config. I don't think the semantics of these change, although there are some packing formats that do not support 3d weights; it looks like `Int4PackingFormat.PLAIN_INT32` and `Int4PackingFormat.MARLIN_SPARSE` are among them.
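Putting this together for the for-loop style of MoE forward, where each expert projection is just an `nn.Linear`, a regex FQN key can fan a single config out across all experts. This is a hedged sketch: the module names are hypothetical, and I'm assuming the `re:` keys are matched as standard Python regexes against the FQNs:

```python
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    FqnToConfig,
    quantize_,
)
from torchao.quantization.granularity import PerRow

# Hypothetical expert layout: model.layers.{i}.mlp.experts.{j}.up_proj / .down_proj,
# each an ordinary nn.Linear in the for-loop implementation.
config = FqnToConfig({
    r"re:model\.layers\..*\.mlp\.experts\..*\.up_proj": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
    r"re:model\.layers\..*\.mlp\.experts\..*\.down_proj": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})

# quantize_(model, config, filter_fn=None)  # `model` is assumed to be an MoE model loaded elsewhere
```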

File tree

6 files changed: +566 -133 lines changed


docs/source/serving.rst

Lines changed: 2 additions & 2 deletions
```diff
@@ -175,7 +175,7 @@ Quantizing the model for mobile deployment using TorchAO's ``Int8DynamicActivati
 from torchao.quantization.quant_api import (
     IntxWeightOnlyConfig,
     Int8DynamicActivationIntxWeightConfig,
-    ModuleFqnToConfig,
+    FqnToConfig,
     quantize_,
 )
 from torchao.quantization.granularity import PerGroup, PerAxis
@@ -198,7 +198,7 @@ Quantizing the model for mobile deployment using TorchAO's ``Int8DynamicActivati
     weight_granularity=PerGroup(32),
     weight_scale_dtype=torch.bfloat16,
 )
-quant_config = ModuleFqnToConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
+quant_config = FqnToConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
 quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[])

 # either use `untied_model_id` or `untied_model_local_path`
```

docs/source/torchao_vllm_integration.md

Lines changed: 5 additions & 6 deletions
````diff
@@ -54,22 +54,21 @@ assert isinstance(config, AOBaseConfig)
 All quantization configurations inherit from {class}`torchao.core.config.AOBaseConfig`, which provides serialization and validation capabilities.
 ```

-(module-level-configuration)=
-### 3. Module-Level Configuration
+(fqn-configuration)=
+### 3. FQN Configuration

-For granular control, use `ModuleFqnToConfig`:
+For granular control, use `FqnToConfig`:

 ```python
-from torchao.quantization import ModuleFqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig
+from torchao.quantization import FqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig

-config = ModuleFqnToConfig({
+config = FqnToConfig({
     "model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64),
     "model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64),
     "model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(),
     "_default": Int4WeightOnlyConfig(group_size=128, version=1) # Default for other modules
 })
 ```
-
 (usage-examples)=
 ## Usage Examples

````

test/prototype/test_parq.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -588,11 +588,9 @@ def test_int8_dynamic_activation_intx_e2e(
             n for n, m in model.named_modules() if isinstance(m, nn.Embedding)
         }
         reg_param_names.add("_default")
-        module_fqn_to_config = (
-            model.config.quantization_config.quant_type.module_fqn_to_config
-        )
-        self.assertEqual(set(module_fqn_to_config.keys()), reg_param_names)
-        for torchao_config in module_fqn_to_config.values():
+        fqn_to_config = model.config.quantization_config.quant_type.fqn_to_config
+        self.assertEqual(set(fqn_to_config.keys()), reg_param_names)
+        for torchao_config in fqn_to_config.values():
             self.assertTrue(isinstance(torchao_config, config.__class__))

```