26
26
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
27
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
28
29
+ import pytest
30
+
29
31
import numpy as np
30
32
import onnx .parser as oprs
31
33
34
36
from qonnx .transformation .extract_quant_scale_zeropt import ExtractQuantScaleZeroPt
35
37
36
38
37
- def make_test_model ():
38
- ishp = (1 , 10 )
39
+ def make_test_model (ishp , channelwise , bitwidth , need_extraction_scale , need_extraction_zeropt ):
39
40
ishp_str = str (list (ishp ))
40
- channelwise = True
41
- bitwidth = np .asarray (4.0 , dtype = np .float32 )
42
41
if channelwise :
43
42
q_attr_shp = ishp
44
43
else :
45
- q_attr_shp = 1
44
+ q_attr_shp = ( 1 ,)
46
45
attrshp_str = str (list (q_attr_shp ))
47
46
np .random .seed (0 )
48
- scale = np .random .rand (* q_attr_shp ).astype (np .float32 )
49
- zeropt = np .random .rand (* q_attr_shp ).astype (np .float32 )
47
+ if need_extraction_scale :
48
+ scale = np .random .rand (* q_attr_shp ).astype (np .float32 )
49
+ else :
50
+ scale = np .ones (q_attr_shp , dtype = np .float32 )
51
+ if need_extraction_zeropt :
52
+ zeropt = np .random .rand (* q_attr_shp ).astype (np .float32 )
53
+ else :
54
+ zeropt = np .zeros (q_attr_shp , dtype = np .float32 )
50
55
signed = 1
51
56
narrow = 1
52
57
rounding_mode = "ROUND"
@@ -78,8 +83,13 @@ def make_test_model():
78
83
return model
79
84
80
85
81
- def test_extract_quant_scale_zeropt ():
82
- model = make_test_model ()
86
+ @pytest .mark .parametrize ("need_extraction_scale" , [True , False ])
87
+ @pytest .mark .parametrize ("need_extraction_zeropt" , [True , False ])
88
+ @pytest .mark .parametrize ("channelwise" , [True , False ])
89
+ def test_extract_quant_scale_zeropt (channelwise , need_extraction_scale , need_extraction_zeropt ):
90
+ ishp = (1 , 10 )
91
+ bitwidth = np .asarray (4.0 , dtype = np .float32 )
92
+ model = make_test_model (ishp , channelwise , bitwidth , need_extraction_scale , need_extraction_zeropt )
83
93
ishp = model .get_tensor_shape ("in0" )
84
94
inp = np .random .rand (* ishp ).astype (np .float32 )
85
95
y_golden = execute_onnx (model , {"in0" : inp })["out0" ]
@@ -88,6 +98,12 @@ def test_extract_quant_scale_zeropt():
88
98
assert np .allclose (y_golden , y_ret )
89
99
qnt_node = model_new .get_nodes_by_op_type ("Quant" )[0 ]
90
100
new_scale = model_new .get_initializer (qnt_node .input [1 ])
91
- assert new_scale == 1
101
+ assert ( new_scale == 1 ). all ()
92
102
new_zeropt = model_new .get_initializer (qnt_node .input [2 ])
93
- assert new_zeropt == 0
103
+ assert (new_zeropt == 0 ).all ()
104
+ if need_extraction_scale :
105
+ assert len (model_new .get_nodes_by_op_type ("Mul" )) == 1
106
+ assert len (model_new .get_nodes_by_op_type ("Div" )) == 1
107
+ if need_extraction_zeropt :
108
+ assert len (model_new .get_nodes_by_op_type ("Add" )) == 1
109
+ assert len (model_new .get_nodes_by_op_type ("Sub" )) == 1
0 commit comments