@@ -287,7 +287,7 @@ def closeTab(self, index):
287
287
288
288
# delete the digest model to free up used memory
289
289
if unique_id in self .digest_models :
290
- del self .digest_models [ unique_id ]
290
+ self .digest_models . pop ( unique_id )
291
291
292
292
self .ui .tabWidget .removeTab (index )
293
293
if self .ui .tabWidget .count () == 0 :
@@ -486,20 +486,41 @@ def load_onnx(self, file_path: str):
486
486
# Every time an onnx is loaded we should emulate a model summary button click
487
487
self .summary_clicked ()
488
488
489
- # Before opening the file, check to see if it is already opened.
489
+ model_proto = None
490
+
491
+ # Before opening the ONNX file, check to see if it is already opened.
490
492
for index in range (self .ui .tabWidget .count ()):
491
493
widget = self .ui .tabWidget .widget (index )
492
- if isinstance (widget , modelSummary ) and file_path == widget .file :
493
- self .ui .tabWidget .setCurrentIndex (index )
494
- return
494
+ if (
495
+ isinstance (widget , modelSummary )
496
+ and isinstance (widget .digest_model , DigestOnnxModel )
497
+ and file_path == widget .file
498
+ ):
499
+ # Check if the model proto is different
500
+ if widget .digest_model .model_proto :
501
+ model_proto = onnx_utils .load_onnx (
502
+ file_path , load_external_data = False
503
+ )
504
+ # If they are equivalent, set the GUI to show the existing
505
+ # report and return
506
+ if model_proto == widget .digest_model .model_proto :
507
+ self .ui .tabWidget .setCurrentIndex (index )
508
+ return
509
+ # If they aren't equivalent, then the proto has been modified. In this case,
510
+ # we close the tab associated with the stale model, remove from the model list,
511
+ # then go through the standard process of adding it to the tabWidget. In the
512
+ # future, it may be slightly better to have an update tab function.
513
+ else :
514
+ self .closeTab (index )
495
515
496
516
try :
497
517
498
518
progress = ProgressDialog ("Loading & Optimizing ONNX Model..." , 8 , self )
499
519
QApplication .processEvents () # Process pending events
500
520
501
- model = onnx_utils .load_onnx (file_path , load_external_data = False )
502
- opt_model , opt_passed = onnx_utils .optimize_onnx_model (model )
521
+ if not model_proto :
522
+ model_proto = onnx_utils .load_onnx (file_path , load_external_data = False )
523
+ opt_model , opt_passed = onnx_utils .optimize_onnx_model (model_proto )
503
524
progress .step ()
504
525
505
526
basename = os .path .splitext (os .path .basename (file_path ))
@@ -918,6 +939,9 @@ def load_pytorch(self, file_path: str):
918
939
basename = os .path .splitext (os .path .basename (file_path ))
919
940
model_name = basename [0 ]
920
941
942
+ # The current support for PyTorch includes exporting it to ONNX. In this case,
943
+ # an ingest window will pop up giving the user options to export. This window
944
+ # will block the main GUI until the ingest window is closed
921
945
self .pytorch_ingest = PyTorchIngest (file_path , model_name )
922
946
self .pytorch_ingest_window = PopupDialog (
923
947
self .pytorch_ingest , "PyTorch Ingest" , self
@@ -1027,35 +1051,39 @@ def save_reports(self):
1027
1051
os .path .join (save_directory , f"{ model_name } _histogram.png" ), "PNG"
1028
1052
)
1029
1053
1030
- # Save csv of node type counts
1031
- node_type_file_path = os .path .join (
1032
- save_directory , f"{ model_name } _node_type_counts.csv"
1033
- )
1034
- digest_model .save_node_type_counts_csv_report (node_type_file_path )
1035
-
1036
- # Save (copy) the similarity image
1037
- png_file_path = self .model_similarity_thread [
1038
- digest_model .unique_id
1039
- ].png_file_path
1040
- png_save_path = os .path .join (save_directory , f"{ model_name } _heatmap.png" )
1041
- if png_file_path and os .path .exists (png_file_path ):
1042
- shutil .copy (png_file_path , png_save_path )
1043
-
1044
- # Save the text report
1045
- txt_report_file_path = os .path .join (save_directory , f"{ model_name } _report.txt" )
1046
- digest_model .save_text_report (txt_report_file_path )
1047
-
1048
- # Save the yaml report
1049
- yaml_report_file_path = os .path .join (
1050
- save_directory , f"{ model_name } _report.yaml"
1051
- )
1052
- digest_model .save_yaml_report (yaml_report_file_path )
1054
+ # Save csv of node type counts
1055
+ node_type_file_path = os .path .join (
1056
+ save_directory , f"{ model_name } _node_type_counts.csv"
1057
+ )
1058
+ digest_model .save_node_type_counts_csv_report (node_type_file_path )
1059
+
1060
+ # Save (copy) the similarity image
1061
+ png_file_path = self .model_similarity_thread [
1062
+ digest_model .unique_id
1063
+ ].png_file_path
1064
+ png_save_path = os .path .join (save_directory , f"{ model_name } _heatmap.png" )
1065
+ if png_file_path and os .path .exists (png_file_path ):
1066
+ shutil .copy (png_file_path , png_save_path )
1067
+
1068
+ # Save the text report
1069
+ txt_report_file_path = os .path .join (
1070
+ save_directory , f"{ model_name } _report.txt"
1071
+ )
1072
+ digest_model .save_text_report (txt_report_file_path )
1073
+
1074
+ # Save the yaml report
1075
+ yaml_report_file_path = os .path .join (
1076
+ save_directory , f"{ model_name } _report.yaml"
1077
+ )
1078
+ digest_model .save_yaml_report (yaml_report_file_path )
1053
1079
1054
- # Save the node list
1055
- nodes_report_file_path = os .path .join (save_directory , f"{ model_name } _nodes.csv" )
1056
- self .save_nodes_csv (nodes_report_file_path , False )
1080
+ # Save the node list
1081
+ nodes_report_file_path = os .path .join (
1082
+ save_directory , f"{ model_name } _nodes.csv"
1083
+ )
1084
+ self .save_nodes_csv (nodes_report_file_path , False )
1057
1085
1058
- self .save_nodes_csv (nodes_report_filepath , False )
1086
+ self .save_nodes_csv (nodes_report_file_path , False )
1059
1087
except Exception as exception : # pylint: disable=broad-exception-caught
1060
1088
self .status_dialog = StatusDialog (f"{ exception } " )
1061
1089
self .status_dialog .show ()
0 commit comments