Skip to content

Commit 5269aa6

Browse files
committed
[Test] factor out ModelWrapper init get/set tests, add init rm test
1 parent e5bb71c commit 5269aa6

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

tests/core/test_modelwrapper.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,6 @@ def test_modelwrapper():
5454
assert first_conv_iname != "" and (first_conv_iname is not None)
5555
assert first_conv_wname != "" and (first_conv_wname is not None)
5656
assert first_conv_oname != "" and (first_conv_oname is not None)
57-
first_conv_weights = model.get_initializer(first_conv_wname)
58-
assert first_conv_weights.shape == (8, 1, 5, 5)
59-
first_conv_weights_rand = np.random.randn(8, 1, 5, 5)
60-
model.set_initializer(first_conv_wname, first_conv_weights_rand)
61-
assert (model.get_initializer(first_conv_wname) == first_conv_weights_rand).all()
6257
inp_cons = model.find_consumer(first_conv_iname)
6358
assert inp_cons == first_conv
6459
out_prod = model.find_producer(first_conv_oname)
@@ -75,6 +70,21 @@ def test_modelwrapper():
7570
assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity
7671

7772

73+
def test_modelwrapper_set_get_rm_initializer():
74+
raw_m = get_data("qonnx.data", "onnx/mnist-conv/model.onnx")
75+
model = ModelWrapper(raw_m)
76+
conv_nodes = model.get_nodes_by_op_type("Conv")
77+
first_conv = conv_nodes[0]
78+
first_conv_wname = first_conv.input[1]
79+
first_conv_weights = model.get_initializer(first_conv_wname)
80+
assert first_conv_weights.shape == (8, 1, 5, 5)
81+
first_conv_weights_rand = np.random.randn(8, 1, 5, 5)
82+
model.set_initializer(first_conv_wname, first_conv_weights_rand)
83+
assert (model.get_initializer(first_conv_wname) == first_conv_weights_rand).all()
84+
model.del_initializer(first_conv_wname)
85+
assert model.get_initializer(first_conv_wname) is None
86+
87+
7888
def test_modelwrapper_graph_order():
7989
# create small network with properties to be tested
8090
Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"])

0 commit comments

Comments
 (0)