Skip to content

Commit fe99758

Browse files
authored
Fix tutorials (#2516)
stack-info: PR: #2516, branch: drisspg/stack/83
1 parent ddd4021 commit fe99758

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

tutorials/calibration_flow/awq_like.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,12 @@ def weight_quant_func(weight):
121121
weight, weight_scale, weight_zero_point, block_size, target_dtype
122122
)
123123
elif target_dtype == torch.float8_e4m3fn:
124+
scale_2d = (
125+
weight_scale.view(1, -1) if weight_scale.dim() == 1 else weight_scale
126+
)
124127
return to_affine_quantized_floatx_static(
125128
weight,
126-
weight_scale,
129+
scale_2d,
127130
block_size,
128131
target_dtype,
129132
Float8Layout(mm_config=None),

tutorials/calibration_flow/gptq_like.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@
4848
LinearActivationQuantizedTensor,
4949
MappingType,
5050
PerTensor,
51-
_fake_quantize_affine,
5251
quantize_,
5352
to_linear_activation_quantized,
5453
)
5554
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
55+
from torchao.quantization.quant_primitives import _fake_quantize_affine
5656
from torchao.quantization.transform_module import (
5757
register_quantize_module_handler,
5858
)

0 commit comments

Comments
 (0)