Skip to content

Commit 7936d0d

Browse files
authored
Fix static AQT flow (#2046)
**Summary:** Fixes this error with the static AQT flow: ``` File "/home/andrewor/local/ao/test/dtypes/test_affine_quantized.py", line 269, in test_to_affine_quantized_intx_static to_affine_quantized_intx_static( File "/home/andrewor/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 325, in from_hp_to_intx_static input_float, scale, zero_point = _layout.pre_process_static( TypeError: Layout.pre_process_static() missing 1 required positional argument: 'block_size' ``` **Test Plan:** python test/dtypes/test_affine_quantized.py -k test_to_affine_quantized_intx_static python tutorials/calibration_flow/static_quant.py
1 parent a81322e commit 7936d0d

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
CutlassInt4PackedLayout,
2020
Int4CPULayout,
2121
Int4XPULayout,
22+
PlainLayout,
2223
SemiSparseLayout,
24+
to_affine_quantized_intx_static,
2325
)
2426
from torchao.quantization import (
2527
Int4WeightOnlyConfig,
@@ -280,6 +282,16 @@ def test_copy__mismatch_metadata(self, apply_quant):
280282
):
281283
ql2.weight.copy_(ql.weight)
282284

285+
def test_to_affine_quantized_intx_static(self):
286+
to_affine_quantized_intx_static(
287+
torch.randn(2, 3),
288+
scale=torch.randn(1),
289+
zero_point=torch.zeros(1),
290+
block_size=(2, 3),
291+
target_dtype=torch.int8,
292+
_layout=PlainLayout(),
293+
)
294+
283295

284296
class TestAffineQuantizedBasic(TestCase):
285297
COMMON_DEVICES = (

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,12 @@ def from_hp_to_intx_static(
337337
zero_point_domain,
338338
)
339339

340-
int_data, scale, zero_point = _layout.post_process(int_data, scale, zero_point)
340+
int_data, scale, zero_point = _layout.post_process(
341+
int_data,
342+
scale,
343+
zero_point,
344+
block_size,
345+
)
341346

342347
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
343348
tensor_impl = tensor_impl_ctr(int_data, scale, zero_point, _layout)

0 commit comments

Comments
 (0)