Skip to content

Commit 565366a

Browse files
authored
Merge pull request #193 from fastmachinelearning/feature/analysis-subgraph-traversal
Add apply_to_subgraph parameter to ModelWrapper.analysis()
2 parents d3348e3 + 01f558d commit 565366a

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/qonnx/core/modelwrapper.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929

3030
import copy
31+
import inspect
3132
import onnx
3233
import onnx.helper as oh
3334
import onnx.numpy_helper as np_helper
@@ -125,9 +126,13 @@ def save(self, filename):
125126
"""Saves the wrapper ONNX ModelProto into a file with given name."""
126127
onnx.save(self._model_proto, filename)
127128

128-
def analysis(self, analysis_fxn):
129+
def analysis(self, analysis_fxn, apply_to_subgraphs=False):
129130
"""Runs given anaylsis_fxn on this model and return resulting dict."""
130-
return analysis_fxn(self)
131+
if apply_to_subgraphs == True:
132+
assert "apply_to_subgraphs" in inspect.signature(analysis_fxn), "analysis_fxn must have 'apply_to_subgraphs' argument when apply_to_subgraphs == True"
133+
return analysis_fxn(self, apply_to_subgraphs)
134+
else:
135+
return analysis_fxn(self)
131136

132137
def transform_subgraphs(self, transformation, make_deepcopy=True, cleanup=True, apply_to_subgraphs=False, use_preorder_traversal=True):
133138
"""Applies given Transformation to all subgraphs of this ModelWrapper instance.

tests/core/test_subgraph_traversal.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,17 @@ def test_traversal_nested(tree, cleanup, make_deepcopy):
241241

242242
check_all_visted_once(tree, transform.dummy_transform)
243243
check_all_subgraphs_transformed(t_model.model.graph)
244+
245+
def dummy_analysis_fxn(model_wrapper):
246+
"""
247+
A dummy analysis function that simply returns the model wrapper.
248+
This is used to test that analysis functions are called correctly.
249+
"""
250+
d = {}
251+
return d
252+
253+
@pytest.mark.xfail(reason="Analysis functions require apply_to_subgraphs when traversing subgraphs")
254+
def test_analysis_fxn_without_apply_to_subgraphs_fails():
255+
# Check that an analysis function fails when apply_to_subgraphs is False
256+
model = make_subgraph_model(("top", [("sub1", []), ("sub2", [])]))
257+
model.analysis(dummy_analysis_fxn, apply_to_subgraphs=True)

0 commit comments

Comments
 (0)