diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index c27b7774..f5895c56 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -367,6 +367,13 @@ def get_initializer(self, tensor_name, return_dtype=False): else: return None + def del_initializer(self, initializer_name): + """Deletes an initializer from the model.""" + graph = self._model_proto.graph + init = util.get_by_name(graph.initializer, initializer_name) + if not (init is None): + graph.initializer.remove(init) + def find_producer(self, tensor_name): """Finds and returns the node that produces the tensor with given name.""" for x in self._model_proto.graph.node: diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index f0cb203e..06d08798 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -54,11 +54,6 @@ def test_modelwrapper(): assert first_conv_iname != "" and (first_conv_iname is not None) assert first_conv_wname != "" and (first_conv_wname is not None) assert first_conv_oname != "" and (first_conv_oname is not None) - first_conv_weights = model.get_initializer(first_conv_wname) - assert first_conv_weights.shape == (8, 1, 5, 5) - first_conv_weights_rand = np.random.randn(8, 1, 5, 5) - model.set_initializer(first_conv_wname, first_conv_weights_rand) - assert (model.get_initializer(first_conv_wname) == first_conv_weights_rand).all() inp_cons = model.find_consumer(first_conv_iname) assert inp_cons == first_conv out_prod = model.find_producer(first_conv_oname) @@ -75,6 +70,21 @@ def test_modelwrapper(): assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity +def test_modelwrapper_set_get_rm_initializer(): + raw_m = get_data("qonnx.data", "onnx/mnist-conv/model.onnx") + model = ModelWrapper(raw_m) + conv_nodes = model.get_nodes_by_op_type("Conv") + first_conv = conv_nodes[0] + first_conv_wname = first_conv.input[1] + first_conv_weights = model.get_initializer(first_conv_wname) + assert first_conv_weights.shape == (8, 1, 5, 5) + first_conv_weights_rand = np.random.randn(8, 1, 5, 5) + model.set_initializer(first_conv_wname, first_conv_weights_rand) + assert (model.get_initializer(first_conv_wname) == first_conv_weights_rand).all() + model.del_initializer(first_conv_wname) + assert model.get_initializer(first_conv_wname) is None + + def test_modelwrapper_graph_order(): # create small network with properties to be tested Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"])