@@ -134,8 +134,6 @@ def calib_fn(model):
134
134
logger .warning ("out shape is %s" , out .shape )
135
135
assert out is not None
136
136
137
- # TODO: AssertionError: Node 'call_function' <OpOverload(op='quantized_decomposed.quantize_per_tensor', overload='default')> should occur 2 times, but 1
138
- @pytest .mark .skipif (True , reason = "TODO: fix AssertionError" )
139
137
@pytest .mark .skipif (not GT_OR_EQUAL_TORCH_VERSION_2_5 , reason = "Requires torch>=2.5" )
140
138
def test_quantize_simple_model_with_set_local (self , force_not_import_ipex ):
141
139
model , example_inputs = self .build_simple_torch_model_and_example_inputs ()
@@ -149,12 +147,14 @@ def calib_fn(model):
149
147
quant_config = get_default_static_config ()
150
148
quant_config .set_local ("fc1" , StaticQuantConfig (w_dtype = "fp32" , act_dtype = "fp32" ))
151
149
q_model = quantize (model = model , quant_config = quant_config , run_fn = calib_fn )
152
-
153
- # check the half node
154
150
expected_node_occurrence = {
155
151
# Only quantize the `fc2`
156
- torch .ops .quantized_decomposed .quantize_per_tensor .default : 2 ,
157
- torch .ops .quantized_decomposed .quantize_per_tensor .default : 2 ,
152
+ # Quantize/Dequantize input
153
+ torch .ops .quantized_decomposed .quantize_per_tensor .default : 1 ,
154
+ torch .ops .quantized_decomposed .dequantize_per_tensor .default : 1 ,
155
+ # Quantize/Dequantize weight
156
+ torch .ops .quantized_decomposed .quantize_per_channel .default : 1 ,
157
+ torch .ops .quantized_decomposed .quantize_per_channel .default : 1 ,
158
158
}
159
159
expected_node_occurrence = {
160
160
torch_test_quant_common .NodeSpec .call_function (k ): v for k , v in expected_node_occurrence .items ()
0 commit comments