From 2ac41fb90fe4e860022e85cee1f203ff770c3475 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 6 May 2025 04:57:11 -0700 Subject: [PATCH 1/3] [PT2E] Fix per-tensor observer issue with varying shape & rank --- torchao/quantization/pt2e/observer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchao/quantization/pt2e/observer.py b/torchao/quantization/pt2e/observer.py index f6534308d8..e98d85e7cf 100644 --- a/torchao/quantization/pt2e/observer.py +++ b/torchao/quantization/pt2e/observer.py @@ -1793,7 +1793,7 @@ def get_block_size( "Please provide an instance of Granularity, not subclass of it" ) if isinstance(granularity, PerTensor): - return input_shape + return (-1,) * len(input_shape) elif isinstance(granularity, PerAxis): block_size = list(input_shape) block_size[granularity.axis] = 1 @@ -1891,6 +1891,10 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): assert self.original_dtype is not None, ( "Expecting original_dtype to be populated" ) + # Since input shape & rank may change (e.g. Resnet18), here we need to update block_size for each input + self.block_size = get_block_size( + observer_node.args[0].meta["tensor_meta"].shape, self.granularity + ) if hasattr(self, "is_dynamic") and self.is_dynamic: choose_qparams_affine = model.graph.call_function( torch.ops.torchao.choose_qparams_affine, From 330e2c0e51a0694728ead2a46464db28443946a9 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 8 May 2025 00:13:43 -0700 Subject: [PATCH 2/3] block_size = [-1] for per-tensor quantization --- torchao/quantization/pt2e/_affine_quantization.py | 3 ++- torchao/quantization/pt2e/observer.py | 6 +----- torchao/quantization/quant_primitives.py | 9 +++++++++ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/torchao/quantization/pt2e/_affine_quantization.py b/torchao/quantization/pt2e/_affine_quantization.py index e02bee03ce..906629bd8a 100644 --- a/torchao/quantization/pt2e/_affine_quantization.py +++ b/torchao/quantization/pt2e/_affine_quantization.py @@ -113,7 +113,8 @@ def _get_reduction_params(block_size, input_size): shape_for_reduction: (3, 3, 5, 2, 10) reduction_dim: [0, 1, 3, 4] """ - assert len(block_size) == len(input_size) + assert block_size == [-1] or len(block_size) == len(input_size) + block_size = [-1] * len(input_size) if block_size == [-1] else block_size shape_for_reduction = [] reduction_dims = [] cur_dim = 0 diff --git a/torchao/quantization/pt2e/observer.py b/torchao/quantization/pt2e/observer.py index e98d85e7cf..3d09255cd1 100644 --- a/torchao/quantization/pt2e/observer.py +++ b/torchao/quantization/pt2e/observer.py @@ -1793,7 +1793,7 @@ def get_block_size( "Please provide an instance of Granularity, not subclass of it" ) if isinstance(granularity, PerTensor): - return (-1,) * len(input_shape) + return [-1] elif isinstance(granularity, PerAxis): block_size = list(input_shape) block_size[granularity.axis] = 1 @@ -1891,10 +1891,6 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): assert self.original_dtype is not None, ( "Expecting original_dtype to be populated" ) - # Since input shape & rank may change (e.g. Resnet18), here we need to update block_size for each input - self.block_size = get_block_size( - observer_node.args[0].meta["tensor_meta"].shape, self.granularity - ) if hasattr(self, "is_dynamic") and self.is_dynamic: choose_qparams_affine = model.graph.call_function( torch.ops.torchao.choose_qparams_affine, diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index d13ac330a0..ca740ef862 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -365,6 +365,9 @@ def _quantize_affine( # torch.uintx dtypes yet if output_dtype in _SUB_BYTE_UINT_BOUNDS: output_dtype = torch.uint8 + if block_size == [-1]: + # per-tensor quantization + block_size = [-1] * input.dim() return _quantize_affine_no_dtype_cast( input, block_size, @@ -520,6 +523,9 @@ def _dequantize_affine( torch.float16, torch.bfloat16, ], f"Unsupported output dtype: {output_dtype}" + if block_size == [-1]: + # per-tensor quantization + block_size = [-1] * input.dim() quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) return _dequantize_affine_no_dtype_check( input, @@ -878,6 +884,9 @@ def _choose_qparams_affine( scale_dtype = input.dtype if eps is None: eps = torch.finfo(input.dtype).eps + if block_size == [-1]: + # per-tensor quantization + block_size = [-1] * input.dim() assert len(block_size) == input.dim(), ( f"Got input dim:{input.dim()}, block_size: {block_size}" From b03b1e6a1db4f759e691fbe9aec0b7e741ee277a Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 16 May 2025 03:19:55 -0700 Subject: [PATCH 3/3] Refine code --- torchao/quantization/observer.py | 2 +- torchao/quantization/pt2e/_affine_quantization.py | 2 +- torchao/quantization/quant_primitives.py | 9 +++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index e103f0a59e..cbf7208edd 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -73,7 +73,7 @@ def get_block_size( granularity: The granularity type of the quantization """ if isinstance(granularity, PerTensor): - return input_shape + return [-1] elif isinstance(granularity, PerAxis): block_size = list(input_shape) block_size[granularity.axis] = 1 diff --git a/torchao/quantization/pt2e/_affine_quantization.py b/torchao/quantization/pt2e/_affine_quantization.py index 906629bd8a..2bd50893fe 100644 --- a/torchao/quantization/pt2e/_affine_quantization.py +++ b/torchao/quantization/pt2e/_affine_quantization.py @@ -114,7 +114,7 @@ def _get_reduction_params(block_size, input_size): reduction_dim: [0, 1, 3, 4] """ assert block_size == [-1] or len(block_size) == len(input_size) - block_size = [-1] * len(input_size) if block_size == [-1] else block_size + block_size = input_size if block_size == [-1] else block_size shape_for_reduction = [] reduction_dims = [] cur_dim = 0 diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index ca740ef862..1069194eda 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -255,7 +255,8 @@ def _get_reduction_params(block_size, input_size): shape_for_reduction: (3, 3, 5, 2, 10) reduction_dim: [0, 1, 3, 4] """ - assert len(block_size) == len(input_size) + assert block_size == [-1] or len(block_size) == len(input_size) + block_size = input_size if block_size == [-1] else block_size shape_for_reduction = [] reduction_dims = [] cur_dim = 0 @@ -367,7 +368,7 @@ def _quantize_affine( output_dtype = torch.uint8 if block_size == [-1]: # per-tensor quantization - block_size = [-1] * input.dim() + block_size = input.shape return _quantize_affine_no_dtype_cast( input, block_size, @@ -525,7 +526,7 @@ def _dequantize_affine( ], f"Unsupported output dtype: {output_dtype}" if block_size == [-1]: # per-tensor quantization - block_size = [-1] * input.dim() + block_size = input.shape quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) return _dequantize_affine_no_dtype_check( input, @@ -886,7 +887,7 @@ def _choose_qparams_affine( eps = torch.finfo(input.dtype).eps if block_size == [-1]: # per-tensor quantization - block_size = [-1] * input.dim() + block_size = input.shape assert len(block_size) == input.dim(), ( f"Got input dim:{input.dim()}, block_size: {block_size}"