Skip to content

Commit 9e0a49a

Browse files
committed
[Test] add more testcases for ExtractQuantScaleZeroPt
1 parent 95279e4 commit 9e0a49a

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

tests/transformation/test_extract_quant_scale_zeropt.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828

29+
import pytest
30+
2931
import numpy as np
3032
import onnx.parser as oprs
3133

@@ -34,19 +36,22 @@
3436
from qonnx.transformation.extract_quant_scale_zeropt import ExtractQuantScaleZeroPt
3537

3638

37-
def make_test_model():
38-
ishp = (1, 10)
39+
def make_test_model(ishp, channelwise, bitwidth, need_extraction_scale, need_extraction_zeropt):
3940
ishp_str = str(list(ishp))
40-
channelwise = True
41-
bitwidth = np.asarray(4.0, dtype=np.float32)
4241
if channelwise:
4342
q_attr_shp = ishp
4443
else:
45-
q_attr_shp = 1
44+
q_attr_shp = (1,)
4645
attrshp_str = str(list(q_attr_shp))
4746
np.random.seed(0)
48-
scale = np.random.rand(*q_attr_shp).astype(np.float32)
49-
zeropt = np.random.rand(*q_attr_shp).astype(np.float32)
47+
if need_extraction_scale:
48+
scale = np.random.rand(*q_attr_shp).astype(np.float32)
49+
else:
50+
scale = np.ones(q_attr_shp, dtype=np.float32)
51+
if need_extraction_zeropt:
52+
zeropt = np.random.rand(*q_attr_shp).astype(np.float32)
53+
else:
54+
zeropt = np.zeros(q_attr_shp, dtype=np.float32)
5055
signed = 1
5156
narrow = 1
5257
rounding_mode = "ROUND"
@@ -78,8 +83,13 @@ def make_test_model():
7883
return model
7984

8085

81-
def test_extract_quant_scale_zeropt():
82-
model = make_test_model()
86+
@pytest.mark.parametrize("need_extraction_scale", [True, False])
87+
@pytest.mark.parametrize("need_extraction_zeropt", [True, False])
88+
@pytest.mark.parametrize("channelwise", [True, False])
89+
def test_extract_quant_scale_zeropt(channelwise, need_extraction_scale, need_extraction_zeropt):
90+
ishp = (1, 10)
91+
bitwidth = np.asarray(4.0, dtype=np.float32)
92+
model = make_test_model(ishp, channelwise, bitwidth, need_extraction_scale, need_extraction_zeropt)
8393
ishp = model.get_tensor_shape("in0")
8494
inp = np.random.rand(*ishp).astype(np.float32)
8595
y_golden = execute_onnx(model, {"in0": inp})["out0"]
@@ -88,6 +98,12 @@ def test_extract_quant_scale_zeropt():
8898
assert np.allclose(y_golden, y_ret)
8999
qnt_node = model_new.get_nodes_by_op_type("Quant")[0]
90100
new_scale = model_new.get_initializer(qnt_node.input[1])
91-
assert new_scale == 1
101+
assert (new_scale == 1).all()
92102
new_zeropt = model_new.get_initializer(qnt_node.input[2])
93-
assert new_zeropt == 0
103+
assert (new_zeropt == 0).all()
104+
if need_extraction_scale:
105+
assert len(model_new.get_nodes_by_op_type("Mul")) == 1
106+
assert len(model_new.get_nodes_by_op_type("Div")) == 1
107+
if need_extraction_zeropt:
108+
assert len(model_new.get_nodes_by_op_type("Add")) == 1
109+
assert len(model_new.get_nodes_by_op_type("Sub")) == 1

0 commit comments

Comments
 (0)