-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Refactor keras/src/export/export_lib
and add export_onnx
#20710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
from keras.src.export.export_lib import ExportArchive | ||
from keras.src.export.onnx import export_onnx | ||
from keras.src.export.saved_model import ExportArchive | ||
from keras.src.export.saved_model import export_saved_model | ||
from keras.src.export.tfsm_layer import TFSMLayer |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from keras.src import backend | ||
from keras.src import layers | ||
from keras.src import models | ||
from keras.src import ops | ||
from keras.src import tree | ||
from keras.src.utils.module_utils import tensorflow as tf | ||
|
||
|
||
def get_input_signature(model): | ||
if not isinstance(model, models.Model): | ||
raise TypeError( | ||
"The model must be a `keras.Model`. " | ||
f"Received: model={model} of the type {type(model)}" | ||
) | ||
if not model.built: | ||
raise ValueError( | ||
"The model provided has not yet been built. It must be built " | ||
"before export." | ||
) | ||
if isinstance(model, (models.Functional, models.Sequential)): | ||
input_signature = tree.map_structure(make_input_spec, model.inputs) | ||
if isinstance(input_signature, list) and len(input_signature) > 1: | ||
input_signature = [input_signature] | ||
else: | ||
input_signature = _infer_input_signature_from_model(model) | ||
if not input_signature or not model._called: | ||
raise ValueError( | ||
"The model provided has never called. " | ||
"It must be called at least once before export." | ||
) | ||
return input_signature | ||
|
||
|
||
def _infer_input_signature_from_model(model): | ||
shapes_dict = getattr(model, "_build_shapes_dict", None) | ||
if not shapes_dict: | ||
return None | ||
|
||
def _make_input_spec(structure): | ||
# We need to turn wrapper structures like TrackingDict or _DictWrapper | ||
# into plain Python structures because they don't work with jax2tf/JAX. | ||
if isinstance(structure, dict): | ||
return {k: _make_input_spec(v) for k, v in structure.items()} | ||
elif isinstance(structure, tuple): | ||
if all(isinstance(d, (int, type(None))) for d in structure): | ||
return layers.InputSpec( | ||
shape=(None,) + structure[1:], dtype=model.input_dtype | ||
) | ||
return tuple(_make_input_spec(v) for v in structure) | ||
elif isinstance(structure, list): | ||
if all(isinstance(d, (int, type(None))) for d in structure): | ||
return layers.InputSpec( | ||
shape=[None] + structure[1:], dtype=model.input_dtype | ||
) | ||
return [_make_input_spec(v) for v in structure] | ||
else: | ||
raise ValueError( | ||
f"Unsupported type {type(structure)} for {structure}" | ||
) | ||
|
||
return [_make_input_spec(value) for value in shapes_dict.values()] | ||
|
||
|
||
def make_input_spec(x): | ||
if isinstance(x, layers.InputSpec): | ||
if x.shape is None or x.dtype is None: | ||
raise ValueError( | ||
"The `shape` and `dtype` must be provided. " f"Received: x={x}" | ||
) | ||
input_spec = x | ||
elif isinstance(x, backend.KerasTensor): | ||
shape = (None,) + backend.standardize_shape(x.shape)[1:] | ||
dtype = backend.standardize_dtype(x.dtype) | ||
input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=x.name) | ||
elif backend.is_tensor(x): | ||
shape = (None,) + backend.standardize_shape(x.shape)[1:] | ||
dtype = backend.standardize_dtype(x.dtype) | ||
input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=None) | ||
else: | ||
raise TypeError( | ||
f"Unsupported x={x} of the type ({type(x)}). Supported types are: " | ||
"`keras.InputSpec`, `keras.KerasTensor` and backend tensor." | ||
) | ||
return input_spec | ||
|
||
|
||
def make_tf_tensor_spec(x): | ||
if isinstance(x, tf.TensorSpec): | ||
tensor_spec = x | ||
else: | ||
input_spec = make_input_spec(x) | ||
tensor_spec = tf.TensorSpec( | ||
input_spec.shape, dtype=input_spec.dtype, name=input_spec.name | ||
) | ||
return tensor_spec | ||
|
||
|
||
def convert_spec_to_tensor(spec, replace_none_number=None): | ||
shape = backend.standardize_shape(spec.shape) | ||
if replace_none_number is not None: | ||
replace_none_number = int(replace_none_number) | ||
shape = tuple( | ||
s if s is not None else replace_none_number for s in shape | ||
) | ||
return ops.ones(shape, spec.dtype) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
import pathlib | ||
import tempfile | ||
|
||
from keras.src import backend | ||
from keras.src import tree | ||
from keras.src.export.export_utils import convert_spec_to_tensor | ||
from keras.src.export.export_utils import get_input_signature | ||
from keras.src.export.saved_model import export_saved_model | ||
from keras.src.utils.module_utils import tensorflow as tf | ||
|
||
|
||
def export_onnx(model, filepath, verbose=True, input_signature=None, **kwargs): | ||
"""Export the model as a ONNX artifact for inference. | ||
|
||
This method lets you export a model to a lightweight ONNX artifact | ||
that contains the model's forward pass only (its `call()` method) | ||
and can be served via e.g. ONNX Runtime. | ||
|
||
The original code of the model (including any custom layers you may | ||
have used) is *no longer* necessary to reload the artifact -- it is | ||
entirely standalone. | ||
|
||
Args: | ||
filepath: `str` or `pathlib.Path` object. The path to save the artifact. | ||
verbose: `bool`. Whether to print a message during export. Defaults to | ||
True`. | ||
input_signature: Optional. Specifies the shape and dtype of the model | ||
inputs. Can be a structure of `keras.InputSpec`, `tf.TensorSpec`, | ||
`backend.KerasTensor`, or backend tensor. If not provided, it will | ||
be automatically computed. Defaults to `None`. | ||
**kwargs: Additional keyword arguments. | ||
|
||
**Note:** This feature is currently supported only with TensorFlow, JAX and | ||
Torch backends. | ||
|
||
**Note:** The dtype policy must be "float32" for the model. You can further | ||
optimize the ONNX artifact using the ONNX toolkit. Learn more here: | ||
https://onnxruntime.ai/docs/performance/. | ||
|
||
**Note:** The dynamic shape feature is not yet supported with Torch | ||
backend. As a result, you must fully define the shapes of the inputs using | ||
`input_signature`. If `input_signature` is not provided, all instances of | ||
`None` (such as the batch size) will be replaced with `1`. | ||
|
||
Example: | ||
|
||
```python | ||
# Export the model as a ONNX artifact | ||
model.export("path/to/location", format="onnx") | ||
|
||
# Load the artifact in a different process/environment | ||
ort_session = onnxruntime.InferenceSession("path/to/location") | ||
ort_inputs = { | ||
k.name: v for k, v in zip(ort_session.get_inputs(), input_data) | ||
} | ||
predictions = ort_session.run(None, ort_inputs) | ||
``` | ||
""" | ||
if input_signature is None: | ||
input_signature = get_input_signature(model) | ||
if not input_signature or not model._called: | ||
raise ValueError( | ||
"The model provided has never called. " | ||
"It must be called at least once before export." | ||
) | ||
|
||
if backend.backend() in ("tensorflow", "jax"): | ||
working_dir = pathlib.Path(filepath).parent | ||
with tempfile.TemporaryDirectory(dir=working_dir) as temp_dir: | ||
if backend.backend() == "jax": | ||
kwargs = _check_jax_kwargs(kwargs) | ||
export_saved_model( | ||
model, | ||
temp_dir, | ||
verbose, | ||
input_signature, | ||
**kwargs, | ||
) | ||
saved_model_to_onnx(temp_dir, filepath, model.name) | ||
|
||
elif backend.backend() == "torch": | ||
import torch | ||
|
||
sample_inputs = tree.map_structure( | ||
lambda x: convert_spec_to_tensor(x, replace_none_number=1), | ||
input_signature, | ||
) | ||
sample_inputs = tuple(sample_inputs) | ||
# TODO: Make dict model exportable. | ||
if any(isinstance(x, dict) for x in sample_inputs): | ||
raise ValueError( | ||
"Currently, `export_onnx` in the torch backend doesn't support " | ||
"dictionaries as inputs." | ||
) | ||
|
||
# Convert to ONNX using TorchScript-based ONNX Exporter. | ||
# TODO: Use TorchDynamo-based ONNX Exporter once | ||
# `torch.onnx.dynamo_export()` supports Keras models. | ||
torch.onnx.export(model, sample_inputs, filepath, verbose=verbose) | ||
else: | ||
raise NotImplementedError( | ||
"`export_onnx` is only compatible with TensorFlow, JAX and " | ||
"Torch backends." | ||
) | ||
|
||
|
||
def _check_jax_kwargs(kwargs): | ||
kwargs = kwargs.copy() | ||
if "is_static" not in kwargs: | ||
kwargs["is_static"] = True | ||
if "jax2tf_kwargs" not in kwargs: | ||
# TODO: These options will be deprecated in JAX. We need to | ||
# find another way to export ONNX. | ||
kwargs["jax2tf_kwargs"] = { | ||
"enable_xla": False, | ||
"native_serialization": False, | ||
} | ||
if kwargs["is_static"] is not True: | ||
raise ValueError( | ||
"`is_static` must be `True` in `kwargs` when using the jax " | ||
"backend." | ||
) | ||
if kwargs["jax2tf_kwargs"]["enable_xla"] is not False: | ||
raise ValueError( | ||
"`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` " | ||
"when using the jax backend." | ||
) | ||
if kwargs["jax2tf_kwargs"]["native_serialization"] is not False: | ||
raise ValueError( | ||
"`native_serialization` must be `False` in " | ||
"`kwargs['jax2tf_kwargs']` when using the jax backend." | ||
) | ||
return kwargs | ||
|
||
|
||
def saved_model_to_onnx(saved_model_dir, filepath, name): | ||
from keras.src.utils.module_utils import tf2onnx | ||
|
||
# Convert to ONNX using `tf2onnx` library. | ||
(graph_def, inputs, outputs, initialized_tables, tensors_to_rename) = ( | ||
tf2onnx.tf_loader.from_saved_model( | ||
saved_model_dir, | ||
None, | ||
None, | ||
return_initialized_tables=True, | ||
return_tensors_to_rename=True, | ||
) | ||
) | ||
|
||
with tf.device("/cpu:0"): | ||
_ = tf2onnx.convert._convert_common( | ||
graph_def, | ||
name=name, | ||
target=[], | ||
custom_op_handlers={}, | ||
extra_opset=[], | ||
input_names=inputs, | ||
output_names=outputs, | ||
tensors_to_rename=tensors_to_rename, | ||
initialized_tables=initialized_tables, | ||
output_path=filepath, | ||
) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.