
Commit 4ccfee4

Add export_tflite and export_onnx and support channels_first (#28)
1 parent 01007b3 commit 4ccfee4

32 files changed: +623 −79 lines

.github/workflows/actions.yml

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ jobs:
     runs-on: ubuntu-latest
     env:
       PYTHON: ${{ matrix.python-version }}
+      KERAS_BACKEND: ${{ matrix.backend }}
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python

.gitignore

Lines changed: 5 additions & 1 deletion
@@ -161,4 +161,8 @@ cython_debug/
 
 # Keras
 *.keras
-exported
+exported
+
+# Exported model
+*.tflite
+*.onnx

conftest.py

Lines changed: 10 additions & 0 deletions
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+from keras import backend
 
 
 def pytest_addoption(parser):
@@ -27,6 +28,10 @@ def pytest_configure(config):
     config.addinivalue_line(
         "markers", "serialization: mark test as a serialization test"
     )
+    config.addinivalue_line(
+        "markers",
+        "requires_trainable_backend: mark test for trainable backend only",
+    )
 
 
 def pytest_collection_modifyitems(config, items):
@@ -35,6 +40,11 @@ def pytest_collection_modifyitems(config, items):
         not run_serialization_tests,
         reason="need --run_serialization option to run",
     )
+    requires_trainable_backend = pytest.mark.skipif(
+        backend.backend() == "numpy", reason="require trainable backend"
+    )
     for item in items:
+        if "requires_trainable_backend" in item.keywords:
+            item.add_marker(requires_trainable_backend)
         if "serialization" in item.name:
             item.add_marker(skip_serialization)
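
With the hook above in place, a test opts in by marker name alone; nothing needs to be imported from conftest. A minimal sketch (the test function itself is hypothetical, not part of this commit):

import pytest


@pytest.mark.requires_trainable_backend
def test_training_behavior():
    # pytest_collection_modifyitems finds "requires_trainable_backend"
    # in item.keywords and attaches the skipif marker, so this test is
    # skipped automatically when KERAS_BACKEND=numpy.
    ...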

kimm/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+from kimm import export
 from kimm import models  # force to add models to the registry
 from kimm.utils.model_registry import list_models
 

kimm/blocks/base_block.py

Lines changed: 15 additions & 4 deletions
@@ -1,5 +1,6 @@
 import typing
 
+from keras import backend
 from keras import layers
 
 from kimm.utils import make_divisible
@@ -43,7 +44,9 @@ def apply_conv2d_block(
         )
     if isinstance(kernel_size, int):
         kernel_size = [kernel_size, kernel_size]
-    input_channels = inputs.shape[-1]
+
+    channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
+    input_channels = inputs.shape[channels_axis]
     has_skip = add_skip and strides == 1 and input_channels == filters
     x = inputs
 
@@ -74,7 +77,10 @@ def apply_conv2d_block(
             name=f"{name}_dwconv2d",
         )(x)
     x = layers.BatchNormalization(
-        name=f"{name}_bn", momentum=bn_momentum, epsilon=bn_epsilon
+        axis=channels_axis,
+        name=f"{name}_bn",
+        momentum=bn_momentum,
+        epsilon=bn_epsilon,
     )(x)
     x = apply_activation(x, activation, name=name)
     if has_skip:
@@ -91,7 +97,8 @@ def apply_se_block(
     se_input_channels: typing.Optional[int] = None,
     name: str = "se_block",
 ):
-    input_channels = inputs.shape[-1]
+    channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
+    input_channels = inputs.shape[channels_axis]
     if se_input_channels is None:
         se_input_channels = input_channels
     if make_divisible_number is None:
@@ -102,7 +109,11 @@ def apply_se_block(
     )
 
     x = inputs
-    x = layers.GlobalAveragePooling2D(keepdims=True, name=f"{name}_mean")(x)
+    x = layers.GlobalAveragePooling2D(
+        data_format=backend.image_data_format(),
+        keepdims=True,
+        name=f"{name}_mean",
+    )(x)
     x = layers.Conv2D(
         se_channels, 1, use_bias=True, name=f"{name}_conv_reduce"
     )(x)
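
The choice of -3 (rather than 1) for the channels_first axis is deliberate: negative indexing finds the channels dimension whether or not the tensor carries a batch dimension. A minimal sketch (shapes made up for illustration); the same pattern repeats in the depthwise separation, inverted residual, and MLP blocks below:

from keras import backend
from keras import ops

channels_last = backend.image_data_format() == "channels_last"
if channels_last:
    x = ops.ones((2, 32, 32, 24))  # NHWC
else:
    x = ops.ones((2, 24, 32, 32))  # NCHW

channels_axis = -1 if channels_last else -3
assert x.shape[channels_axis] == 24
# Negative indexing also works for unbatched 3-D inputs such as
# (24, 32, 32), where axis 1 would point at the wrong dimension.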

kimm/blocks/depthwise_separation_block.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,6 @@
 import typing
 
+from keras import backend
 from keras import layers
 
 from kimm.blocks.base_block import apply_conv2d_block
@@ -23,7 +24,8 @@ def apply_depthwise_separation_block(
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name: str = "depthwise_separation_block",
 ):
-    input_channels = inputs.shape[-1]
+    channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
+    input_channels = inputs.shape[channels_axis]
     has_skip = skip and (strides == 1 and input_channels == output_channels)
 
     x = inputs

kimm/blocks/inverted_residual_block.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,6 @@
 import typing
 
+from keras import backend
 from keras import layers
 
 from kimm.blocks.base_block import apply_conv2d_block
@@ -25,7 +26,8 @@ def apply_inverted_residual_block(
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name: str = "inverted_residual_block",
 ):
-    input_channels = inputs.shape[-1]
+    channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
+    input_channels = inputs.shape[channels_axis]
     hidden_channels = make_divisible(input_channels * expansion_ratio)
     has_skip = strides == 1 and input_channels == output_channels
 

kimm/blocks/transformer_block.py

Lines changed: 7 additions & 1 deletion
@@ -1,5 +1,6 @@
 import typing
 
+from keras import backend
 from keras import layers
 
 from kimm import layers as kimm_layers
@@ -13,9 +14,13 @@ def apply_mlp_block(
     use_bias: bool = True,
     dropout_rate: float = 0.0,
     use_conv_mlp: bool = False,
+    data_format: typing.Optional[str] = None,
     name: str = "mlp_block",
 ):
-    input_dim = inputs.shape[-1]
+    if data_format is None:
+        data_format = backend.image_data_format()
+    dim_axis = -1 if data_format == "channels_last" else 1
+    input_dim = inputs.shape[dim_axis]
     output_dim = output_dim or input_dim
 
     x = inputs
@@ -71,6 +76,7 @@ def apply_transformer_block(
         int(dim * mlp_ratio),
         activation=activation,
         dropout_rate=projection_dropout_rate,
+        data_format="channels_last",  # TODO: let backend decides
        name=f"{name}_mlp",
     )
     x = layers.Add()([residual_2, x])
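
Note the asymmetry with the conv blocks: apply_mlp_block indexes channels_first inputs at axis 1 (not -3), and the transformer path pins data_format="channels_last" because its tokens are always laid out as (batch, sequence, dim). A minimal sketch mirroring the call in the hunk above (shapes, activation, and name are illustrative assumptions):

from keras import layers

from kimm.blocks.transformer_block import apply_mlp_block

tokens = layers.Input(shape=(196, 192))  # (sequence, dim); batch implied
x = apply_mlp_block(
    tokens,
    int(192 * 4.0),  # hidden width, as in int(dim * mlp_ratio)
    activation="gelu",
    data_format="channels_last",  # token tensors are channels_last by construction
    name="demo_mlp",
)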

kimm/export/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from kimm.export.export_onnx import export_onnx
+from kimm.export.export_tflite import export_tflite

kimm/export/export_onnx.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+import pathlib
+import tempfile
+import typing
+
+from keras import backend
+from keras import layers
+from keras import models
+from keras import ops
+
+from kimm.models import BaseModel
+
+
+def _export_onnx_tf(
+    model: BaseModel,
+    inputs_as_nchw,
+    export_path: typing.Union[str, pathlib.Path],
+):
+    try:
+        import tf2onnx
+        import tf2onnx.tf_loader
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError(
+            "Failed to import 'tf2onnx'. Please install it by the following "
+            "instruction:\n'pip install tf2onnx'"
+        )
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = pathlib.Path(temp_dir, "temp_saved_model")
+        model.export(temp_path)
+
+        (
+            graph_def,
+            inputs,
+            outputs,
+            tensors_to_rename,
+        ) = tf2onnx.tf_loader.from_saved_model(
+            temp_path,
+            None,
+            None,
+            return_tensors_to_rename=True,
+        )
+
+        tf2onnx.convert.from_graph_def(
+            graph_def,
+            input_names=inputs,
+            output_names=outputs,
+            output_path=export_path,
+            inputs_as_nchw=inputs_as_nchw,
+            tensors_to_rename=tensors_to_rename,
+        )
+
+
+def _export_onnx_torch(
+    model: BaseModel,
+    input_shape: typing.Union[int, int, int],
+    export_path: typing.Union[str, pathlib.Path],
+):
+    try:
+        import torch
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError(
+            "Failed to import 'torch'. Please install it before calling"
+            "`export_onnx` using torch backend"
+        )
+    full_input_shape = [1] + list(input_shape)
+    dummy_inputs = ops.ones(full_input_shape)
+    scripted_model = torch.jit.trace(model, dummy_inputs).eval()
+    torch.onnx.export(scripted_model, dummy_inputs, export_path)
+
+
+def export_onnx(
+    model: BaseModel,
+    input_shape: typing.Union[int, typing.Sequence[int]],
+    export_path: typing.Union[str, pathlib.Path],
+    batch_size: int = 1,
+    use_nchw: bool = True,
+):
+    if backend.backend() not in ("tensorflow", "torch"):
+        raise ValueError(
+            "Currently, `export_onnx` only supports tensorflow and torch "
+            "backend"
+        )
+    try:
+        import onnx
+        import onnxoptimizer
+        import onnxsim
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError(
+            "Failed to import 'onnx', 'onnxsim' or 'onnxoptimizer'. Please "
+            "install them by the following instruction:\n"
+            "'pip install onnx onnxsim onnxoptimizer'"
+        )
+
+    if isinstance(input_shape, int):
+        input_shape = [input_shape, input_shape, 3]
+    elif len(input_shape) == 2:
+        input_shape = [input_shape[0], input_shape[1], 3]
+    elif len(input_shape) == 3:
+        input_shape = input_shape
+    if use_nchw:
+        if backend.backend() == "torch":
+            raise ValueError(
+                "Currently, torch backend doesn't support `use_nchw=True`. "
+                "You can use tensorflow backend to overcome this issue or "
+                "set `use_nchw=False`. "
+                "Note that there might be a significant performance "
+                "degradation when using torch backend to export onnx due to "
+                "the pre- and post-transpose of the Conv2D."
+            )
+        elif backend.backend() == "tensorflow":
+            inputs_as_nchw = ["inputs"]
+        else:
+            inputs_as_nchw = None
+    else:
+        inputs_as_nchw = None
+
+    # Fix input shape
+    inputs = layers.Input(
+        shape=input_shape, batch_size=batch_size, name="inputs"
+    )
+    outputs = model(inputs, training=False)
+    model = models.Model(inputs, outputs)
+
+    if backend.backend() == "tensorflow":
+        _export_onnx_tf(model, inputs_as_nchw, export_path)
+    elif backend.backend() == "torch":
+        _export_onnx_torch(model, input_shape, export_path)
+
+    # Further optimization
+    model = onnx.load(export_path)
+    model_simp, _ = onnxsim.simplify(model)
+    model_simp = onnxoptimizer.optimize(model_simp)
+    onnx.save(model_simp, export_path)
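
Taken together with the test below, the intended entry point looks like this. A minimal usage sketch (run under KERAS_BACKEND=tensorflow; the model constructor is the one the test file exercises, and the output filename is arbitrary):

from kimm import export
from kimm import models

model = models.MobileNet050V3Small(include_preprocessing=False)

# Exports a temporary SavedModel, converts it with tf2onnx (inputs
# transposed to NCHW by default), then applies onnxsim and
# onnxoptimizer passes before saving.
export.export_onnx(model, [224, 224, 3], "mobilenet_050_v3_small.onnx")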

kimm/export/export_onnx_test.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+import pytest
+from absl.testing import parameterized
+from keras import backend
+from keras.src import testing
+
+from kimm import export
+from kimm import models
+
+
+class ExportOnnxTest(testing.TestCase, parameterized.TestCase):
+    def get_model(self):
+        input_shape = [224, 224, 3]
+        model = models.MobileNet050V3Small(include_preprocessing=False)
+        return input_shape, model
+
+    @pytest.mark.skipif(
+        backend.backend() != "tensorflow",  # TODO: test torch
+        reason="Requires tensorflow or torch backend.",
+    )
+    def test_export_onnx_use(self):
+        input_shape, model = self.get_model()
+
+        temp_dir = self.get_temp_dir()
+
+        if backend.backend() == "tensorflow":
+            export.export_onnx(model, input_shape, f"{temp_dir}/model.onnx")
+        elif backend.backend() == "torch":
+            export.export_onnx(
+                model, input_shape, f"{temp_dir}/model.onnx", use_nchw=False
+            )
