@@ -122,6 +122,10 @@ def extraction(
122
122
# domain, ir_version
123
123
domain : str = onnx_graph .domain
124
124
ir_version : int = onnx_graph .ir_version
125
+ meta_data = {'domain' : domain , 'ir_version' : ir_version }
126
+ metadata_props = None
127
+ if hasattr (onnx_graph , 'metadata_props' ):
128
+ metadata_props = onnx_graph .metadata_props
125
129
126
130
graph = gs .import_onnx (onnx_graph )
127
131
graph .cleanup ().toposort ()
@@ -176,9 +180,14 @@ def extraction(
176
180
# Shape Estimation
177
181
extracted_graph = None
178
182
try :
179
- extracted_graph = onnx .shape_inference .infer_shapes (gs .export_onnx (graph , do_type_check = False , ** {'domain' : domain , 'ir_version' : ir_version }))
183
+ exported_onnx_graph = gs .export_onnx (graph , do_type_check = False , ** meta_data )
184
+ if metadata_props is not None :
185
+ exported_onnx_graph .metadata_props .extend (metadata_props )
186
+ extracted_graph = onnx .shape_inference .infer_shapes (exported_onnx_graph )
180
187
except Exception as e :
181
- extracted_graph = gs .export_onnx (graph , do_type_check = False , ** {'domain' : domain , 'ir_version' : ir_version })
188
+ extracted_graph = gs .export_onnx (graph , do_type_check = False , ** meta_data )
189
+ if metadata_props is not None :
190
+ exported_onnx_graph .metadata_props .extend (metadata_props )
182
191
if not non_verbose :
183
192
print (
184
193
f'{ Color .YELLOW } WARNING:{ Color .RESET } ' +
0 commit comments