Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 5a8a333

Browse files
authored
Audit the onnx pathways to make them robust against >2GB models (#1540)
* initial commit
* fix the bad rebase
1 parent 2f86f08 commit 5a8a333

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+154
-173
lines changed

src/sparseml/exporters/onnx_to_deepsparse.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121

2222
from sparseml.exporters import transforms as sparseml_transforms
2323
from sparseml.exporters.base_exporter import BaseExporter
24+
from sparsezoo import save_onnx
2425

2526

2627
class ONNXToDeepsparse(BaseExporter):
@@ -109,7 +110,7 @@ def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto:
109110

110111
def export(self, pre_transforms_model: onnx.ModelProto, file_path: str):
111112
if self.export_input_model or os.getenv("SAVE_PREQAT_ONNX", False):
112-
onnx.save(pre_transforms_model, file_path.replace(".onnx", ".preqat.onnx"))
113+
save_onnx(pre_transforms_model, file_path.replace(".onnx", ".preqat.onnx"))
113114

114115
post_transforms_model: onnx.ModelProto = self.apply(pre_transforms_model)
115-
onnx.save(post_transforms_model, file_path)
116+
save_onnx(post_transforms_model, file_path)

src/sparseml/exporters/transforms/onnx_transform.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,8 @@
2020

2121
from sparseml.exporters.transforms import BaseTransform
2222
from sparseml.exporters.transforms.utils import MatchResult
23-
from sparseml.onnx.utils import ONNXGraph, check_load_model, validate_onnx_file
23+
from sparseml.onnx.utils import ONNXGraph
24+
from sparsezoo.utils import load_model, validate_onnx
2425

2526

2627
__all__ = ["OnnxTransform"]
@@ -80,8 +81,8 @@ def pre_validate(self, model: Union[ModelProto, str]) -> ModelProto:
8081
f"Invalid model type: {type(model)}. "
8182
"Must be a string (path to the .onnx file) or ONNX ModelProto"
8283
)
83-
model = check_load_model(model)
84-
validate_onnx_file(model)
84+
model = load_model(model)
85+
validate_onnx(model)
8586
self._nodes_to_delete.clear()
8687
self._nodes_to_add.clear()
8788
self._num_matches = 0
@@ -102,5 +103,5 @@ def post_validate(self, model: ModelProto) -> ModelProto:
102103
graph = ONNXGraph(model)
103104
graph.delete_unused_initializers()
104105
graph.sort_nodes_topologically()
105-
validate_onnx_file(model)
106+
validate_onnx(model)
106107
return model

src/sparseml/onnx/optim/analyzer_model.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,6 @@
2727
from sparseml.onnx.utils import (
2828
NodeShape,
2929
calculate_flops,
30-
check_load_model,
3130
extract_node_id,
3231
extract_node_shapes,
3332
get_kernel_shape,
@@ -38,6 +37,7 @@
3837
is_prunable_node,
3938
)
4039
from sparseml.utils import clean_path, create_parent_dirs
40+
from sparsezoo.utils import load_model
4141

4242

4343
__all__ = ["NodeAnalyzer", "ModelAnalyzer"]
@@ -358,7 +358,7 @@ def __init__(
358358
raise ValueError("model or nodes must be None, both cannot be passed")
359359

360360
if model is not None:
361-
model = check_load_model(model)
361+
model = load_model(model)
362362
node_shapes = extract_node_shapes(model)
363363
self._nodes = [
364364
NodeAnalyzer(

src/sparseml/onnx/optim/quantization/calibration.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import onnx
2626

2727
from sparseml.onnx.utils import ORTModelRunner, fold_conv_bns, get_node_output_nodes
28+
from sparsezoo.utils import save_onnx, validate_onnx
2829

2930

3031
__all__ = ["CalibrationSession"]
@@ -68,7 +69,7 @@ def __init__(
6869
suffix=".onnx", delete=True
6970
)
7071
self._augmented_model_path = self._augmented_model_tmp_file.name
71-
onnx.save(self._model_augmented, self._augmented_model_path)
72+
save_onnx(self._model_augmented, self._augmented_model_path)
7273

7374
self._sessions = {} # batch_size -> session
7475
self._quantization_thresholds = {} # Dict[node.name, Tuple(min_val, max_val)]
@@ -101,13 +102,11 @@ def _optimize_model(self) -> Union[str, None]:
101102
if model_optimized is None:
102103
# no optimization performed, skip the rest of this block
103104
raise Exception()
104-
onnx.checker.check_model(
105-
model_optimized
106-
) # should raise exception if broken
105+
validate_onnx(model_optimized) # should raise exception if broken
107106
optimized_model_path = tempfile.NamedTemporaryFile(
108107
suffix=".onnx", delete=False
109108
)
110-
onnx.save(model_optimized, optimized_model_path.name)
109+
save_onnx(model_optimized, optimized_model_path.name)
111110
self._model = model_optimized
112111
print("Optimization successful")
113112
return optimized_model_path.name

src/sparseml/onnx/optim/quantization/quantize_model_post_training.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
from sparseml.onnx.optim.quantization.calibration import CalibrationSession
2525
from sparseml.onnx.optim.quantization.quantize import QuantizationMode, quantize
2626
from sparseml.onnx.utils import DataLoader, quantize_resnet_identity_add_inputs
27+
from sparsezoo.utils import save_onnx
2728

2829

2930
__all__ = ["quantize_model_post_training"]
@@ -105,4 +106,4 @@ def quantize_model_post_training(
105106
if output_model_path is None:
106107
return calibrated_quantized_model
107108
else:
108-
onnx.save(calibrated_quantized_model, output_model_path)
109+
save_onnx(calibrated_quantized_model, output_model_path)

src/sparseml/onnx/optim/sensitivity_pruning.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,6 @@
3030
DeepSparseAnalyzeModelRunner,
3131
DeepSparseModelRunner,
3232
ORTModelRunner,
33-
check_load_model,
3433
extract_node_id,
3534
get_node_params,
3635
get_prunable_nodes,
@@ -46,6 +45,7 @@
4645
default_pruning_sparsities_perf,
4746
)
4847
from sparseml.utils import flatten_iterable
48+
from sparsezoo.utils import load_model
4949

5050

5151
_LOGGER = logging.getLogger(__name__)
@@ -142,7 +142,7 @@ def pruning_loss_sens_magnitude_iter(
142142
:return: the analysis results for the model with an additional layer at each
143143
iteration along with a float representing the iteration progress
144144
"""
145-
model = check_load_model(model)
145+
model = load_model(model)
146146
prunable = get_prunable_nodes(model)
147147
analysis = PruningLossSensitivityAnalysis()
148148
num_layers = len(prunable)
@@ -251,7 +251,7 @@ def pruning_loss_sens_one_shot_iter(
251251
:return: the sensitivity results for every node that is prunable,
252252
yields update at each layer along with iteration progress
253253
"""
254-
model = check_load_model(model)
254+
model = load_model(model)
255255
prunable_nodes = get_prunable_nodes(model)
256256
analysis = PruningLossSensitivityAnalysis()
257257
num_updates = len(prunable_nodes) * len(sparsity_levels) + 1

src/sparseml/onnx/utils/data.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,13 +25,13 @@
2525
from onnx import ModelProto
2626

2727
from sparseml.onnx.utils.helpers import (
28-
check_load_model,
2928
extract_shape,
3029
get_numpy_dtype,
3130
model_inputs,
3231
model_outputs,
3332
)
3433
from sparseml.utils import NumpyArrayBatcher, load_labeled_data
34+
from sparsezoo.utils import load_model
3535

3636

3737
__all__ = ["DataLoader"]
@@ -171,7 +171,7 @@ def from_model_random(
171171
and outputs, typically the batch dimension
172172
:return: the created DataLoader instance with the random data
173173
"""
174-
model = check_load_model(model)
174+
model = load_model(model)
175175
inputs = model_inputs(model)
176176
outputs = model_outputs(model)
177177
data_shapes = OrderedDict(

src/sparseml/onnx/utils/helpers.py

Lines changed: 6 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -28,14 +28,12 @@
2828
from onnx.helper import get_attribute_value, make_empty_tensor_value_info
2929

3030
from sparseml.onnx.base import require_onnxruntime
31-
from sparseml.utils import clean_path
31+
from sparsezoo.utils import load_model, save_onnx
3232

3333

3434
_LOGGER = logging.getLogger(__name__)
3535

3636
__all__ = [
37-
"validate_onnx_file",
38-
"check_load_model",
3937
"extract_node_id",
4038
"get_node_by_id",
4139
"get_nodes_by_input_id",
@@ -78,46 +76,6 @@
7876
]
7977

8078

81-
def validate_onnx_file(path: str):
82-
"""
83-
Validate that a file at a given path is a valid ONNX model
84-
85-
:param path: the path of the file to validate
86-
:raise ValueError: if not a valid ONNX model
87-
"""
88-
try:
89-
onnx_model = check_load_model(path)
90-
91-
if onnx_model.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
92-
onnx.checker.check_model(onnx_model)
93-
else:
94-
_LOGGER.warning(
95-
"onnx check_model skipped as model exceeds maximum protobuf size of 2GB"
96-
)
97-
98-
if not onnx_model.opset_import:
99-
raise ValueError("could not parse opset_import")
100-
except Exception as err:
101-
raise ValueError(f"Invalid onnx model: {err}")
102-
103-
104-
def check_load_model(model: Union[str, ModelProto]) -> ModelProto:
105-
"""
106-
Load an ONNX model from a given file path if supplied.
107-
If already a model proto, then returns.
108-
109-
:param model: the model proto or path to the model ONNX file to check for loading
110-
:return: the loaded ONNX ModelProto
111-
"""
112-
if isinstance(model, ModelProto):
113-
return model
114-
115-
if isinstance(model, str):
116-
return onnx.load(clean_path(model))
117-
118-
raise ValueError(f"unknown type given for model: {type(model)}")
119-
120-
12179
def extract_node_id(node: NodeProto) -> str:
12280
"""
12381
Get the node id for a given node from an ONNX model.
@@ -915,7 +873,7 @@ def get_prunable_nodes(model: Union[str, ModelProto]) -> List[Any]:
915873
:param model: the model proto loaded from the ONNX file
916874
:return: a list of nodes from the model proto
917875
"""
918-
model = check_load_model(model)
876+
model = load_model(model)
919877
prunable_nodes = []
920878

921879
for node in model.graph.node:
@@ -951,7 +909,7 @@ def onnx_nodes_sparsities(
951909
:return: a tuple containing the overall sparsity measurement for the model,
952910
each conv or gemm node found in the model
953911
"""
954-
model = check_load_model(model)
912+
model = load_model(model)
955913
node_inp_sparsities = OrderedDict() # type: Dict[str, SparsityMeasurement]
956914
params_count = 0
957915
params_zero_count = 0
@@ -991,7 +949,7 @@ def model_inputs(model: Union[str, ModelProto]) -> List:
991949
to get the model inputs for
992950
:return: the input to the model
993951
"""
994-
model = check_load_model(model)
952+
model = load_model(model)
995953
inputs_all = [node.name for node in model.graph.input]
996954
inputs_init = [node.name for node in model.graph.initializer]
997955
input_names = list(set(inputs_all) - set(inputs_init))
@@ -1009,7 +967,7 @@ def model_outputs(model: Union[str, ModelProto]) -> List:
1009967
to get the model outputs for
1010968
:return: the output from the model
1011969
"""
1012-
model = check_load_model(model)
970+
model = load_model(model)
1013971
outputs = [node for node in model.graph.output]
1014972

1015973
return outputs
@@ -1272,4 +1230,4 @@ def override_model_input_shape(model: Union[str, onnx.ModelProto], shape: List[i
12721230
set_tensor_dim_shape(model.graph.input[0], dim, dim_size)
12731231

12741232
if model_path:
1275-
onnx.save(model, model_path)
1233+
save_onnx(model, model_path)

src/sparseml/onnx/utils/model.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,13 +34,13 @@
3434
from sparseml.onnx.utils.data import DataLoader
3535
from sparseml.onnx.utils.graph_editor import override_model_batch_size
3636
from sparseml.onnx.utils.helpers import (
37-
check_load_model,
3837
extract_node_id,
3938
get_node_by_id,
4039
get_prunable_node_from_foldable,
4140
is_foldable_node,
4241
)
4342
from sparsezoo import File, Model
43+
from sparsezoo.utils import load_model
4444

4545

4646
try:
@@ -464,7 +464,7 @@ def __init__(
464464
import onnxruntime # import protected by @require_onnxruntime()
465465

466466
super().__init__(loss)
467-
self._model = check_load_model(model)
467+
self._model = load_model(model)
468468

469469
if batch_size is not None:
470470
override_model_batch_size(self._model, batch_size)
@@ -712,7 +712,7 @@ def correct_nm_analyze_model_node_ids(nm_result: Dict, model: Union[str, ModelPr
712712
:param model: the onnx model proto or path to the onnx file that the
713713
nm_result was for
714714
"""
715-
model = check_load_model(model)
715+
model = load_model(model)
716716

717717
for layer in nm_result["layer_info"]:
718718
node_id = (

src/sparseml/openpifpaf/export.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,12 +19,12 @@
1919
import os
2020
from typing import Optional
2121

22-
import onnx
2322
import torch
2423

2524
import openpifpaf
2625
from sparseml.pytorch.optim.manager import ScheduledModifierManager
2726
from sparseml.pytorch.utils import ModuleExporter
27+
from sparsezoo.utils import validate_onnx
2828

2929

3030
LOG = logging.getLogger(__name__)
@@ -115,7 +115,7 @@ def export(
115115
input_names=["input_batch"],
116116
output_names=[meta.name for meta in datamodule.head_metas],
117117
)
118-
onnx.checker.check_model(os.path.join(save_dir, name))
118+
validate_onnx(os.path.join(save_dir, name))
119119
exporter.create_deployment_folder()
120120

121121

0 commit comments

Comments
 (0)