Skip to content

Commit d3d5d30

Browse files
authored
Fix the PT2E UT (#2071)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent eed8cb9 commit d3d5d30

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

test/3x/torch/quantization/test_pt2e_quant.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,6 @@ def calib_fn(model):
134134
logger.warning("out shape is %s", out.shape)
135135
assert out is not None
136136

137-
# TODO: AssertionError: Node 'call_function' <OpOverload(op='quantized_decomposed.quantize_per_tensor', overload='default')> should occur 2 times, but 1
138-
@pytest.mark.skipif(True, reason="TODO: fix AssertionError")
139137
@pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
140138
def test_quantize_simple_model_with_set_local(self, force_not_import_ipex):
141139
model, example_inputs = self.build_simple_torch_model_and_example_inputs()
@@ -149,12 +147,14 @@ def calib_fn(model):
149147
quant_config = get_default_static_config()
150148
quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
151149
q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)
152-
153-
# check the half node
154150
expected_node_occurrence = {
155151
# Only quantize the `fc2`
156-
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
157-
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
152+
# Quantize/Dequantize input
153+
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
154+
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
155+
# Quantize/Dequantize weight
156+
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
157+
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
158158
}
159159
expected_node_occurrence = {
160160
torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()

0 commit comments

Comments
 (0)