
Commit 3fa38aa

[Float8] Add static constructor that will be used in Observer workflow (#869)
1 parent 16b40fd commit 3fa38aa

6 files changed: +386 -68 lines changed

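Before the per-file diffs, a minimal sketch of how the new `float8_static_activation_float8_weight` constructor is meant to be called. The model, scale value, and device below are illustrative stand-ins; the tests in this commit exercise the same API on CUDA devices with compute capability >= 8.9:

```python
import torch
from torchao.quantization.quant_api import (
    quantize_,
    float8_static_activation_float8_weight,
)
from torchao.quantization.observer import PerTensor

# Any nn.Module with Linear layers; ToyLinearModel plays this role in the tests.
model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to("cuda")

# In the Observer workflow this scale would come from calibration;
# a fixed per-tensor scale stands in for it here.
activation_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

# Swap each Linear weight for a float8 tensor subclass that statically
# quantizes incoming activations with the provided scale.
quantize_(
    model,
    float8_static_activation_float8_weight(
        scale=activation_scale, granularity=PerTensor()
    ),
)
```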

ruff.toml
Lines changed: 2 additions & 0 deletions

@@ -11,4 +11,6 @@ include = [
     "torchao/quantization/linear_activation_weight_observer.py",
     "test/quantization/test_observer.py",
     "test/dtypes/test_affine_quantized_float.py",
+    "torchao/quantization/weight_tensor_linear_activation_quantization.py"
+
 ]

test/dtypes/test_affine_quantized_float.py
Lines changed: 86 additions & 57 deletions

@@ -14,6 +14,10 @@
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
 )
+from torchao.quantization.quant_api import (
+    float8_static_activation_float8_weight,
+)
+from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
 from torchao.quantization.observer import PerTensor, PerRow
 from torchao.float8.float8_utils import compute_error
 import torch
@@ -50,7 +54,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize(
         "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
@@ -60,45 +64,57 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
         "sizes",
         [
             ((128,), 256, 128),
-            ((256,), 512, 256),
-            ((64,), 128, 64),
             ((32, 128), 64, 256),
-            ((64, 256), 512, 128),
         ],
     )
     def test_fp8_linear_variants(
         self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
     ):
-        raises = (
-            isinstance(granularity, PerRow)
-            and mode == "dynamic"
-            and dtype != torch.bfloat16
-        )
-        context = (
-            nullcontext()
-            if not raises
-            else pytest.raises(
-                AssertionError,
-                match="PerRow quantization only works for bfloat16 precision",
-            )
+        error_message = None
+        if isinstance(granularity, PerRow):
+            if mode == "dynamic" and dtype != torch.bfloat16:
+                error_message = "PerRow quantization only works for bfloat16 precision"
+            elif mode == "static":
+                error_message = (
+                    "Static quantization only supports PerTensor granularity"
+                )
+
+        error_context = (
+            pytest.raises(AssertionError, match=error_message)
+            if error_message
+            else nullcontext()
         )
-        with context:
+
+        with error_context:
             M, N, K = sizes
             input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
-
+            # Get a "reasonable" scale for the input tensor even though
+            # we use the same scale for multiple activations
+            scale, _ = choose_qparams_affine(
+                input_tensor,
+                MappingType.SYMMETRIC,
+                input_tensor.shape,
+                torch.float8_e4m3fn,
+                scale_dtype=torch.float32,
+            )
             mode_map = {
                 "dynamic": partial(
                     float8_dynamic_activation_float8_weight, granularity=granularity
                 ),
                 "weight-only": float8_weight_only,
+                "static": partial(
+                    float8_static_activation_float8_weight,
+                    scale=scale,
+                    granularity=granularity,
+                ),
             }

             # Create a linear layer with bfloat16 dtype
             model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")

             quantized_model = copy.deepcopy(model)
             factory = mode_map[mode]()
-            quantize_(model, factory)
+            quantize_(quantized_model, factory)

             if compile:
                 quantized_model = torch.compile(quantized_model, fullgraph=True)
@@ -145,14 +161,23 @@ def test_per_row_with_float32(self):

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
-    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
         model = ToyLinearModel(16, 32).to(device="cuda")
-        if mode == "dynamic":
-            factory = float8_dynamic_activation_float8_weight()
-        else:
-            factory = float8_weight_only()
+
+        mode_map = {
+            "dynamic": partial(
+                float8_dynamic_activation_float8_weight, granularity=PerTensor()
+            ),
+            "weight-only": float8_weight_only,
+            "static": partial(
+                float8_static_activation_float8_weight,
+                scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
+                granularity=PerTensor(),
+            ),
+        }
+        factory = mode_map[mode]()
         quantize_(model, factory)

         # Save the state dict to an in-memory buffer
@@ -163,46 +188,50 @@ def test_serialization(self, mode: str):
         buffer.seek(0)

         # Load the state dict from the buffer
-        loaded_state_dict = torch.load(buffer)
+        weights_only_load = True
+        if mode == "dynamic":
+            # TODO will fix in followup
+            weights_only_load = False
+
+        loaded_state_dict = torch.load(buffer, weights_only=weights_only_load)

         # Create a new model and load the state dict
         with torch.device("meta"):
             new_model = ToyLinearModel(16, 32)
+            if mode == "static":
+                quantize_(new_model, factory)
             new_model.load_state_dict(loaded_state_dict, assign=True)

         # Compare the original and loaded models
-        if mode == "weight-only":
-            model_weight_1 = model.linear1.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_1 = new_model.linear1.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-            model_weight_2 = model.linear2.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_2 = new_model.linear2.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-        else:
-            model_weight_1 = model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_1 = new_model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-            model_weight_2 = model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_2 = new_model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-        assert torch.allclose(model_weight_1, new_model_weight_1)
-        assert torch.allclose(model_weight_2, new_model_weight_2)
+        for layer_name in ["linear1", "linear2"]:
+            original_layer = getattr(model, layer_name)
+            new_layer = getattr(new_model, layer_name)
+
+            # Compare weights
+            if mode == "weight-only":
+                original_weight = original_layer.weight.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+                new_weight = new_layer.weight.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+            else:
+                original_weight = original_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+                new_weight = new_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+
+            assert torch.allclose(
+                original_weight, new_weight
+            ), f"Weights do not match for {layer_name}"

+            # Compare scales
+            if hasattr(original_layer.weight, "scale"):
+                assert torch.allclose(
+                    original_layer.weight.scale, new_layer.weight.scale
+                ), f"Scales do not match for {layer_name}"


 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
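The static path needs an activation scale before any input is seen; the test derives a "reasonable" one from a sample batch with `choose_qparams_affine`. Pulled out of the test, the pattern looks like the sketch below (the helper name is hypothetical; the `choose_qparams_affine` call mirrors the one in the diff above):

```python
import torch
from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType

def per_tensor_fp8_scale(sample: torch.Tensor) -> torch.Tensor:
    # A symmetric mapping over a block covering the whole tensor yields a
    # single float32 scale suitable for float8_e4m3fn activations.
    scale, _zero_point = choose_qparams_affine(
        sample,
        MappingType.SYMMETRIC,
        sample.shape,  # block_size == full shape -> per-tensor
        torch.float8_e4m3fn,
        scale_dtype=torch.float32,
    )
    return scale

sample = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
scale = per_tensor_fp8_scale(sample)  # reused for all later activations
```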

torchao/float8/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -23,6 +23,7 @@
     LinearMMConfig,
     ScaledMMConfig,
 )
+from torchao.float8.inference import Float8MMConfig
 from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
@@ -31,7 +32,7 @@
 if TORCH_VERSION_AT_LEAST_2_5:
     # Needed to load Float8Tensor with weights_only = True
     from torch.serialization import add_safe_globals
-    add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig])
+    add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig, Float8MMConfig])

 __all__ = [
     # configuration
torchao/quantization/observer.py
Lines changed: 5 additions & 0 deletions

@@ -5,6 +5,7 @@
     MappingType,
     ZeroPointDomain,
 )
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
@@ -222,3 +223,7 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
             self.preserve_zero,
             self.zero_point_domain,
         )
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([PerRow, PerTensor])
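`PerTensor` and `PerRow` instances travel inside the pickled quantization metadata, hence the same safe-globals treatment. As a reminder of what the two granularities mean, here is a hypothetical restatement of the block sizes that `get_block_size` (used in quant_api.py below) derives from them; the helper name and implementation are illustrative, not the library's:

```python
from torchao.quantization.observer import PerTensor, PerRow

def block_size_for(shape, granularity):
    # PerTensor -> one scale for the whole tensor (block spans every dim).
    # PerRow    -> one scale per row (block spans only the last dim).
    if isinstance(granularity, PerTensor):
        return tuple(shape)
    return (1,) * (len(shape) - 1) + (shape[-1],)

assert block_size_for((32, 128), PerTensor()) == (32, 128)
assert block_size_for((32, 128), PerRow()) == (1, 128)
```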

torchao/quantization/quant_api.py
Lines changed: 90 additions & 10 deletions

@@ -27,6 +27,7 @@
 from torchao.dtypes import (
     to_affine_quantized_intx,
     to_affine_quantized_floatx,
+    to_affine_quantized_floatx_static,
     TensorCoreTiledLayoutType,
     PlainLayoutType,
     AffineQuantizedTensor,
@@ -47,6 +48,9 @@
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
+from torchao.quantization.weight_tensor_linear_activation_quantization import (
+    to_weight_tensor_with_linear_activation_quantization_metadata,
+)

 from .quant_primitives import (
     MappingType,
@@ -678,24 +682,40 @@ def _normalize_granularity(
     raise ValueError(f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported.")


-def _input_quant_func_dyanmic_fp8(
+def _input_activation_quant_func_fp8(
     x: torch.Tensor,
     activation_granularity: _fp8_granularities,
     activation_dtype: torch.dtype,
+    scale: Optional[torch.Tensor] = None,
+    zero_point: Optional[torch.Tensor] = None,
 ):
+    """Quantizes the input activation tensor for an aqt_float variant. If a scale
+    is not provided, the scale is calculated dynamically; otherwise the provided scale is used.
+    """
+    assert zero_point is None, "Zero point is not supported for dynamic FP8 quantization"
     if isinstance(activation_granularity, PerRow):
         assert (
             x.dtype == torch.bfloat16
         ), "PerRow quantization only works for bfloat16 precision input activation"

     block_size = get_block_size(x.shape, activation_granularity)
-    activation = to_affine_quantized_floatx(
-        input_float=x,
-        block_size=block_size,
-        target_dtype=activation_dtype,
-        scale_dtype=torch.float32,
-        layout_type=Float8LayoutType(mm_config=None),  # Config is stored on weight
-    )
+    if scale is None:
+        activation = to_affine_quantized_floatx(
+            input_float=x,
+            block_size=block_size,
+            target_dtype=activation_dtype,
+            scale_dtype=torch.float32,
+            layout_type=Float8LayoutType(mm_config=None),  # Config is stored on weight
+        )
+    else:
+        assert isinstance(activation_granularity, PerTensor), "Static quantization only supports PerTensor granularity"
+        activation = to_affine_quantized_floatx_static(
+            input_float=x,
+            block_size=block_size,
+            scale=scale,
+            target_dtype=activation_dtype,
+            layout_type=Float8LayoutType(mm_config=None),  # Config is stored on weight
+        )
     return activation
@@ -742,7 +762,7 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
         )

         input_quant_func = partial(
-            _input_quant_func_dyanmic_fp8,
+            _input_activation_quant_func_fp8,
             activation_granularity=activation_granularity,
             activation_dtype=activation_dtype,
         )
@@ -755,6 +775,60 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
     return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)


+def float8_static_activation_float8_weight(
+    scale: torch.Tensor,
+    activation_dtype: torch.dtype = torch.float8_e4m3fn,
+    weight_dtype: torch.dtype = torch.float8_e4m3fn,
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None,
+    mm_config: Optional[Float8MMConfig] = None,
+):
+    """
+    Applies float8 static symmetric quantization to both activations and weights of linear layers.
+
+    Args:
+        scale (torch.Tensor): The scale tensor for activation quantization.
+        activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
+        weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
+        granularity: The granularity for quantization; static quantization only supports PerTensor activation granularity.
+        mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
+    """
+    if mm_config is None:
+        mm_config = Float8MMConfig(use_fast_accum=True)
+
+    activation_granularity, weight_granularity = _normalize_granularity(granularity)
+    assert isinstance(
+        activation_granularity, PerTensor
+    ), "Static quantization only supports PerTensor granularity"
+
+    def apply_float8_static_activation_quant(weight: torch.Tensor):
+        block_size = get_block_size(weight.shape, weight_granularity)
+        quantized_weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=weight_dtype,
+            scale_dtype=torch.float32,
+            layout_type=Float8LayoutType(mm_config=mm_config),
+        )
+
+        input_quant_func = _input_activation_quant_func_fp8
+        input_quant_kwargs = {
+            "activation_granularity": activation_granularity,
+            "activation_dtype": activation_dtype,
+        }
+
+        quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
+            quantized_weight,
+            input_quant_func,
+            scale=scale,
+            zero_point=None,
+            quant_kwargs=input_quant_kwargs,
+        )
+        return quantized_weight
+
+    return _get_linear_subclass_inserter(apply_float8_static_activation_quant)
+
+
 def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
     """
     Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
@@ -836,4 +910,10 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:


 if TORCH_VERSION_AT_LEAST_2_5:
-    torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])
+    torch.serialization.add_safe_globals(
+        [
+            _int8_asymm_per_token_quant,
+            _int8_symm_per_token_reduced_range_quant,
+            _input_activation_quant_func_fp8,
+        ]
+    )
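Taken together with the commit title, the intended Observer workflow is presumably: observe activations during calibration, turn the observed range into a scale, then freeze that scale into the model with the new constructor. A sketch under that assumption (the amax-tracking loop is a stand-in for the real observer classes, which this commit only prepares for):

```python
import torch
from torchao.quantization.quant_api import (
    quantize_,
    float8_static_activation_float8_weight,
)
from torchao.quantization.observer import PerTensor

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16).to("cuda")

# 1. Observe: track the activation range over calibration batches.
amax = torch.tensor(0.0, device="cuda")
for _ in range(8):  # placeholder calibration data
    x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
    amax = torch.maximum(amax, x.abs().amax().float())

# 2. Convert the observed range into a per-tensor float8 scale
#    (convention: quantized = input / scale).
scale = (amax / torch.finfo(torch.float8_e4m3fn).max).to(torch.float32)

# 3. Freeze the scale into the model via the new static constructor.
quantize_(
    model,
    float8_static_activation_float8_weight(scale=scale, granularity=PerTensor()),
)
```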
