
Commit cade740

Author: Philip Colangelo
Commit message: further support for ingesting pytorch
1 parent 1573a71 commit cade740

File tree: 9 files changed (+153, -108 lines)


examples/analysis.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def main(onnx_files: str, output_dir: str):
                 print(f"dim: {dynamic_shape}")

         digest_model = DigestOnnxModel(
-            model_proto, onnx_filepath=onnx_file, model_name=model_name
+            model_proto, onnx_file_path=onnx_file, model_name=model_name
         )

         # Update the global model dictionary

src/digest/main.py

Lines changed: 31 additions & 7 deletions
@@ -287,7 +287,7 @@ def closeTab(self, index):

         # delete the digest model to free up used memory
         if unique_id in self.digest_models:
-            del self.digest_models[unique_id]
+            self.digest_models.pop(unique_id)

         self.ui.tabWidget.removeTab(index)
         if self.ui.tabWidget.count() == 0:

@@ -486,20 +486,41 @@ def load_onnx(self, file_path: str):
         # Every time an onnx is loaded we should emulate a model summary button click
         self.summary_clicked()

-        # Before opening the file, check to see if it is already opened.
+        model_proto = None
+
+        # Before opening the ONNX file, check to see if it is already opened.
         for index in range(self.ui.tabWidget.count()):
             widget = self.ui.tabWidget.widget(index)
-            if isinstance(widget, modelSummary) and file_path == widget.file:
-                self.ui.tabWidget.setCurrentIndex(index)
-                return
+            if (
+                isinstance(widget, modelSummary)
+                and isinstance(widget.digest_model, DigestOnnxModel)
+                and file_path == widget.file
+            ):
+                # Check if the model proto is different
+                if widget.digest_model.model_proto:
+                    model_proto = onnx_utils.load_onnx(
+                        file_path, load_external_data=False
+                    )
+                    # If they are equivalent, set the GUI to show the existing
+                    # report and return
+                    if model_proto == widget.digest_model.model_proto:
+                        self.ui.tabWidget.setCurrentIndex(index)
+                        return
+                    # If they aren't equivalent, then the proto has been modified. In this case,
+                    # we close the tab associated with the stale model, remove from the model list,
+                    # then go through the standard process of adding it to the tabWidget. In the
+                    # future, it may be slightly better to have an update tab function.
+                    else:
+                        self.closeTab(index)

         try:

             progress = ProgressDialog("Loading & Optimizing ONNX Model...", 8, self)
             QApplication.processEvents()  # Process pending events

-            model = onnx_utils.load_onnx(file_path, load_external_data=False)
-            opt_model, opt_passed = onnx_utils.optimize_onnx_model(model)
+            if not model_proto:
+                model_proto = onnx_utils.load_onnx(file_path, load_external_data=False)
+            opt_model, opt_passed = onnx_utils.optimize_onnx_model(model_proto)
             progress.step()

             basename = os.path.splitext(os.path.basename(file_path))
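
The reuse-or-refresh check above hinges on comparing a freshly loaded ModelProto against the proto cached on the already open tab. A minimal standalone sketch of that comparison follows; it calls onnx.load directly rather than the repo's onnx_utils wrapper, and the file path is a placeholder:

import onnx

def should_reuse_tab(cached_proto: onnx.ModelProto, file_path: str) -> bool:
    """Return True when the model on disk still matches the cached proto."""
    # Reload without external data, mirroring the load in load_onnx above.
    fresh_proto = onnx.load(file_path, load_external_data=False)
    # ModelProto is a protobuf message, so == performs a field-by-field comparison.
    return fresh_proto == cached_proto

# Example: reuse the existing report only if nothing changed on disk.
# if should_reuse_tab(widget.digest_model.model_proto, "model.onnx"):
#     ...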
@@ -918,6 +939,9 @@ def load_pytorch(self, file_path: str):
         basename = os.path.splitext(os.path.basename(file_path))
         model_name = basename[0]

+        # The current support for PyTorch includes exporting it to ONNX. In this case,
+        # an ingest window will pop up giving the user options to export. This window
+        # will block the main GUI until the ingest window is closed
         self.pytorch_ingest = PyTorchIngest(file_path, model_name)
         self.pytorch_ingest_window = PopupDialog(
             self.pytorch_ingest, "PyTorch Ingest", self

src/digest/model_class/digest_pytorch_model.py

Lines changed: 13 additions & 8 deletions
@@ -2,7 +2,7 @@

 import os
 from collections import OrderedDict
-from typing import List, Tuple, Optional, Any, Union
+from typing import List, Tuple, Optional, Union
 import inspect
 import onnx
 import torch

@@ -37,7 +37,9 @@ def __init__(

         # Input dictionary to contain the names and shapes
         # required for exporting the ONNX model
-        self.input_tensor_info: OrderedDict[str, List[Any]] = OrderedDict()
+        self.input_tensor_info: OrderedDict[
+            str, Tuple[torch.dtype, List[Union[str, int]]]
+        ] = OrderedDict()

         self.pytorch_model = torch.load(pytorch_file_path)

@@ -58,21 +60,24 @@ def save_yaml_report(self, file_path: str) -> None:
     def save_text_report(self, file_path: str) -> None:
         """This will be done in the DigestOnnxModel"""

-    def generate_random_tensor(self, shape: List[Union[str, int]]):
+    def generate_random_tensor(self, dtype: torch.dtype, shape: List[Union[str, int]]):
         static_shape = [dim if isinstance(dim, int) else 1 for dim in shape]
-        return torch.rand(static_shape)
+        if dtype in (torch.float16, torch.float32, torch.float64):
+            return torch.rand(static_shape, dtype=dtype)
+        else:
+            return torch.randint(0, 100, static_shape, dtype=dtype)

     def export_to_onnx(self, output_onnx_path: str) -> Union[onnx.ModelProto, None]:

         dummy_input_names: List[str] = list(self.input_tensor_info.keys())
         dummy_inputs: List[torch.Tensor] = []

-        for shape in self.input_tensor_info.values():
-            dummy_inputs.append(self.generate_random_tensor(shape))
+        for dtype, shape in self.input_tensor_info.values():
+            dummy_inputs.append(self.generate_random_tensor(dtype, shape))

         dynamic_axes = {
             name: {i: dim for i, dim in enumerate(shape) if isinstance(dim, str)}
-            for name, shape in self.input_tensor_info.items()
+            for name, (_, shape) in self.input_tensor_info.items()
         }

         try:

@@ -92,7 +97,7 @@ def export_to_onnx(self, output_onnx_path: str) -> Union[onnx.ModelProto, None]:

             return onnx.load(output_onnx_path)

-        except (TypeError, RuntimeError) as err:
+        except (ValueError, TypeError, RuntimeError) as err:
             print(f"Failed to export ONNX: {err}")
             raise
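
A condensed, standalone sketch of how the new (dtype, shape) entries drive dummy-input creation and the dynamic_axes mapping during export; the toy torch.nn.Linear model and the input name "x" are illustrative assumptions, not part of this commit:

from collections import OrderedDict
from typing import List, Tuple, Union

import torch

# One entry per model input: name -> (dtype, shape); string dims mark dynamic axes.
input_tensor_info: "OrderedDict[str, Tuple[torch.dtype, List[Union[str, int]]]]" = OrderedDict(
    [("x", (torch.float32, ["batch_size", 16]))]
)

def generate_random_tensor(dtype: torch.dtype, shape: List[Union[str, int]]) -> torch.Tensor:
    # Symbolic (string) dims are pinned to 1 so a concrete dummy tensor can be built.
    static_shape = [dim if isinstance(dim, int) else 1 for dim in shape]
    if dtype in (torch.float16, torch.float32, torch.float64):
        return torch.rand(static_shape, dtype=dtype)
    # Integer dtypes (e.g. token ids) need randint; torch.rand only supports floats.
    return torch.randint(0, 100, static_shape, dtype=dtype)

dummy_inputs = [generate_random_tensor(dtype, shape) for dtype, shape in input_tensor_info.values()]
# Map each input name to its dynamic (string-named) axes, as in export_to_onnx above.
dynamic_axes = {
    name: {i: dim for i, dim in enumerate(shape) if isinstance(dim, str)}
    for name, (_, shape) in input_tensor_info.items()
}

model = torch.nn.Linear(16, 4)  # toy model standing in for the loaded pytorch_model
torch.onnx.export(
    model,
    tuple(dummy_inputs),
    "toy_model.onnx",
    input_names=list(input_tensor_info.keys()),
    dynamic_axes=dynamic_axes,
)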

src/digest/multi_model_selection_page.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def run(self):
             model_proto = onnx_utils.load_onnx(file, False)
             self.model_dict[file] = DigestOnnxModel(
                 model_proto,
-                onnx_filepath=file,
+                onnx_file_path=file,
                 model_name=model_name,
                 save_proto=False,
             )
