
Commit 2160b3e

Refactor APIs using namex (#43)
* Refactor `kimm.blocks`
* Update
* Export model apis
* Update apis
* Fix version
* Update CI
* Update CI
* Update `release.yml`
* Update tests
* Speed up tests
* Update keras version
* Update keras version
1 parent 6309e06 · commit 2160b3e

109 files changed: +1325 additions, −521 deletions

.github/workflows/actions.yml

Lines changed: 29 additions & 7 deletions
@@ -21,35 +21,57 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: '3.9'
-      - uses: pre-commit/action@v3.0.1
+      - name: Lint
+        uses: pre-commit/action@v3.0.1
+      - name: Get pip cache dir
+        id: pip-cache
+        run: |
+          python -m pip install --upgrade pip setuptools
+          echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
+      - name: Cache pip
+        uses: actions/cache@v4
+        with:
+          path: ${{ steps.pip-cache.outputs.dir }}
+          key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
+      - name: Install dependencies
+        run: |
+          pip install -r requirements.txt --progress-bar off --upgrade
+          pip install -e ".[tests]" --progress-bar off --upgrade
+      - name: Check for API changes
+        run: |
+          bash shell/api_gen.sh
+          git status
+          clean=$(git status | grep "nothing to commit")
+          if [ -z "$clean" ]; then
+            echo "Please run shell/api_gen.sh to generate API."
+            exit 1
+          fi
 
   build:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [3.9]
         backend: [tensorflow, jax, torch, numpy]
     name: Run tests
     runs-on: ubuntu-latest
     env:
-      PYTHON: ${{ matrix.python-version }}
       KERAS_BACKEND: ${{ matrix.backend }}
     steps:
       - uses: actions/checkout@v4
-      - name: Set up Python
+      - name: Set up Python 3.9
         uses: actions/setup-python@v5
         with:
-          python-version: ${{ matrix.python-version }}
+          python-version: '3.9'
       - name: Get pip cache dir
         id: pip-cache
         run: |
           python -m pip install --upgrade pip setuptools
           echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-      - name: Pip cache
+      - name: Cache pip
         uses: actions/cache@v4
         with:
           path: ${{ steps.pip-cache.outputs.dir }}
-          key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}-${{ hashFiles('requirements.txt') }}
+          key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
       - name: Install dependencies
         run: |
           pip install -r requirements.txt --progress-bar off --upgrade

.github/workflows/release.yml

Lines changed: 2 additions & 2 deletions
@@ -23,11 +23,11 @@ jobs:
         run: |
           python -m pip install --upgrade pip setuptools
           echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-      - name: Pip cache
+      - name: Cache pip
         uses: actions/cache@v4
         with:
           path: ${{ steps.pip-cache.outputs.dir }}
-          key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}-${{ hashFiles('requirements.txt') }}
+          key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
       - name: Install dependencies
         run: |
           pip install -r requirements.txt --progress-bar off --upgrade

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@ repos:
     rev: 5.13.2
     hooks:
       - id: isort
-        name: isort (python)
 
   - repo: https://github.com/psf/black-pre-commit-mirror
     rev: 24.4.2

README.md

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 <div align="center">
 <img width="50%" src="https://github.com/james77777778/kimm/assets/20734616/b21db8f2-307b-4791-b93d-e913e45fb238" alt="KIMM">
 
-[![Keras](https://img.shields.io/badge/keras-v3.0.4+-success.svg)](https://github.com/keras-team/keras)
+[![Keras](https://img.shields.io/badge/keras-v3.3.0+-success.svg)](https://github.com/keras-team/keras)
 [![PyPI](https://img.shields.io/pypi/v/kimm)](https://pypi.org/project/kimm/)
 [![Contributions Welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/james77777778/kimm/issues)
 [![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/james77777778/keras-image-models/actions.yml?label=tests)](https://github.com/james77777778/keras-image-models/actions/workflows/actions.yml?query=branch%3Amain++)

api_gen.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import namex
+
+from kimm._src.version import __version__
+
+namex.generate_api_files(package="kimm", code_directory="_src")
+
+# Add version string
+
+with open("kimm/__init__.py", "r") as f:
+    contents = f.read()
+with open("kimm/__init__.py", "w") as f:
+    contents += f'__version__ = "{__version__}"\n'
+    f.write(contents)
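Note: `api_gen.py` uses `namex.generate_api_files` to regenerate the public `kimm.*` packages from the private `kimm._src` code, and the `@kimm_export(...)` decorator seen throughout this diff marks which symbols are exported. `kimm/_src/kimm_export.py` itself is not shown in this view, so the following is only a hedged sketch of how such a decorator can be built on top of `namex.export`; the class body and names are assumptions, not the committed code.

# Hypothetical sketch of kimm/_src/kimm_export.py (NOT the committed file).
# Assumes the Keras 3 pattern: a thin wrapper around namex.export that turns
# parent_path entries such as "kimm.blocks" into full export paths.
import namex


class kimm_export:
    def __init__(self, parent_path):
        # e.g. parent_path=["kimm.blocks"]
        self.parent_path = parent_path

    def __call__(self, symbol):
        # Record "kimm.blocks.apply_se_block"-style paths for the decorated
        # symbol; namex.generate_api_files later writes the public kimm/
        # package from these records.
        paths = [f"{parent}.{symbol.__name__}" for parent in self.parent_path]
        return namex.export(package="kimm", path=paths)(symbol)

Running `python api_gen.py` (or `bash shell/api_gen.sh`, as the CI now does) then rewrites `kimm/__init__.py`, which is what the autogenerated file below reflects.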

kimm/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -1,6 +1,16 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm import blocks
 from kimm import export
+from kimm import layers
 from kimm import models
+from kimm import timm_utils
 from kimm import utils
-from kimm.utils.model_registry import list_models
+from kimm._src.utils.model_registry import list_models
+from kimm._src.version import version
 
 __version__ = "0.2.0"

kimm/_src/blocks/__init__.py

Whitespace-only changes.

kimm/blocks/base_block.py renamed to kimm/_src/blocks/conv2d.py

Lines changed: 4 additions & 56 deletions
@@ -3,25 +3,10 @@
 from keras import backend
 from keras import layers
 
-from kimm.utils import make_divisible
-
-
-def apply_activation(
-    inputs, activation: typing.Optional[str] = None, name: str = "activation"
-):
-    x = inputs
-    if activation is not None:
-        if isinstance(activation, str):
-            x = layers.Activation(activation, name=name)(x)
-        elif isinstance(activation, layers.Layer):
-            x = activation(x)
-        else:
-            NotImplementedError(
-                f"Unsupported activation type: {type(activation)}"
-            )
-    return x
+from kimm._src.kimm_export import kimm_export
 
 
+@kimm_export(parent_path=["kimm.blocks"])
 def apply_conv2d_block(
     inputs,
     filters: typing.Optional[int] = None,
@@ -83,45 +68,8 @@ def apply_conv2d_block(
         momentum=bn_momentum,
         epsilon=bn_epsilon,
     )(x)
-    x = apply_activation(x, activation, name=name)
+    if activation is not None:
+        x = layers.Activation(activation, name=name)(x)
     if has_skip:
         x = layers.Add()([x, inputs])
     return x
-
-
-def apply_se_block(
-    inputs,
-    se_ratio: float = 0.25,
-    activation: typing.Optional[str] = "relu",
-    gate_activation: typing.Optional[str] = "sigmoid",
-    make_divisible_number: typing.Optional[int] = None,
-    se_input_channels: typing.Optional[int] = None,
-    name: str = "se_block",
-):
-    channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
-    input_channels = inputs.shape[channels_axis]
-    if se_input_channels is None:
-        se_input_channels = input_channels
-    if make_divisible_number is None:
-        se_channels = round(se_input_channels * se_ratio)
-    else:
-        se_channels = make_divisible(
-            se_input_channels * se_ratio, make_divisible_number
-        )
-
-    x = inputs
-    x = layers.GlobalAveragePooling2D(
-        data_format=backend.image_data_format(),
-        keepdims=True,
-        name=f"{name}_mean",
-    )(x)
-    x = layers.Conv2D(
-        se_channels, 1, use_bias=True, name=f"{name}_conv_reduce"
-    )(x)
-    x = apply_activation(x, activation, name=f"{name}_act1")
-    x = layers.Conv2D(
-        input_channels, 1, use_bias=True, name=f"{name}_conv_expand"
-    )(x)
-    x = apply_activation(x, gate_activation, name=f"{name}_gate")
-    x = layers.Multiply(name=name)([inputs, x])
-    return x

kimm/blocks/depthwise_separation_block.py renamed to kimm/_src/blocks/depthwise_separation.py

Lines changed: 4 additions & 2 deletions
@@ -3,10 +3,12 @@
 from keras import backend
 from keras import layers
 
-from kimm.blocks.base_block import apply_conv2d_block
-from kimm.blocks.base_block import apply_se_block
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.blocks.squeeze_and_excitation import apply_se_block
+from kimm._src.kimm_export import kimm_export
 
 
+@kimm_export(parent_path=["kimm.blocks"])
 def apply_depthwise_separation_block(
     inputs,
     output_channels: int,

kimm/blocks/inverted_residual_block.py renamed to kimm/_src/blocks/inverted_residual.py

Lines changed: 5 additions & 3 deletions
@@ -3,11 +3,13 @@
 from keras import backend
 from keras import layers
 
-from kimm.blocks.base_block import apply_conv2d_block
-from kimm.blocks.base_block import apply_se_block
-from kimm.utils import make_divisible
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.blocks.squeeze_and_excitation import apply_se_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.utils.make_divisble import make_divisible
 
 
+@kimm_export(parent_path=["kimm.blocks"])
 def apply_inverted_residual_block(
     inputs,
     output_channels: int,
kimm/_src/blocks/squeeze_and_excitation.py (new file; file name inferred from the `apply_se_block` imports above)

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+import typing
+
+from keras import backend
+from keras import layers
+
+from kimm._src.kimm_export import kimm_export
+from kimm._src.utils.make_divisble import make_divisible
+
+
+@kimm_export(parent_path=["kimm.blocks"])
+def apply_se_block(
+    inputs,
+    se_ratio: float = 0.25,
+    activation: typing.Optional[str] = "relu",
+    gate_activation: typing.Optional[str] = "sigmoid",
+    make_divisible_number: typing.Optional[int] = None,
+    se_input_channels: typing.Optional[int] = None,
+    name: str = "se_block",
+):
+    channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
+    input_channels = inputs.shape[channels_axis]
+    if se_input_channels is None:
+        se_input_channels = input_channels
+    if make_divisible_number is None:
+        se_channels = round(se_input_channels * se_ratio)
+    else:
+        se_channels = make_divisible(
+            se_input_channels * se_ratio, make_divisible_number
+        )
+
+    x = inputs
+    x = layers.GlobalAveragePooling2D(
+        data_format=backend.image_data_format(),
+        keepdims=True,
+        name=f"{name}_mean",
+    )(x)
+    x = layers.Conv2D(
+        se_channels, 1, use_bias=True, name=f"{name}_conv_reduce"
+    )(x)
+    if activation is not None:
+        x = layers.Activation(activation, name=f"{name}_act1")(x)
+    x = layers.Conv2D(
+        input_channels, 1, use_bias=True, name=f"{name}_conv_expand"
+    )(x)
+    if activation is not None:
+        x = layers.Activation(gate_activation, name=f"{name}_gate")(x)
+    x = layers.Multiply(name=name)([inputs, x])
+    return x
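As a quick illustration (not part of the commit), the relocated `apply_se_block` can be applied to any channels-last feature map through the Keras functional API; the shape and layer names below are made up:

# Usage sketch only; the input shape and names are illustrative. After the API
# files are regenerated, the helper should also be reachable as
# kimm.blocks.apply_se_block.
import keras
from keras import layers

from kimm._src.blocks.squeeze_and_excitation import apply_se_block

inputs = keras.Input(shape=(32, 32, 64))  # channels_last feature map
x = layers.Conv2D(64, 3, padding="same", name="demo_conv")(inputs)
x = apply_se_block(x, se_ratio=0.25, name="demo_se")  # squeeze-and-excitation gate
model = keras.Model(inputs, x)
model.summary()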

kimm/blocks/transformer_block.py renamed to kimm/_src/blocks/transformer.py

Lines changed: 5 additions & 2 deletions
@@ -3,9 +3,11 @@
 from keras import backend
 from keras import layers
 
-from kimm import layers as kimm_layers
+from kimm._src.kimm_export import kimm_export
+from kimm._src.layers.attention import Attention
 
 
+@kimm_export(parent_path=["kimm.blocks"])
 def apply_mlp_block(
     inputs,
     hidden_dim: int,
@@ -42,6 +44,7 @@ def apply_mlp_block(
     return x
 
 
+@kimm_export(parent_path=["kimm.blocks"])
 def apply_transformer_block(
     inputs,
     dim: int,
@@ -58,7 +61,7 @@ def apply_transformer_block(
     residual_1 = x
 
     x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm1")(x)
-    x = kimm_layers.Attention(
+    x = Attention(
         dim,
         num_heads,
         use_qkv_bias,

kimm/_src/export/__init__.py

Whitespace-only changes.

kimm/export/export_onnx.py renamed to kimm/_src/export/export_onnx.py

Lines changed: 4 additions & 2 deletions
@@ -6,10 +6,12 @@
 from keras import models
 from keras import ops
 
-from kimm.models import BaseModel
-from kimm.utils.module_utils import torch
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.module_utils import torch
 
 
+@kimm_export(parent_path=["kimm.export"])
 def export_onnx(
     model: BaseModel,
     input_shape: typing.Union[int, typing.Sequence[int]],

kimm/export/export_onnx_test.py renamed to kimm/_src/export/export_onnx_test.py

Lines changed: 6 additions & 4 deletions
@@ -3,14 +3,16 @@
 from keras import backend
 from keras.src import testing
 
-from kimm import export
-from kimm import models
+from kimm._src import models
+from kimm._src.export import export_onnx
 
 
 class ExportOnnxTest(testing.TestCase, parameterized.TestCase):
     def get_model(self):
         input_shape = [3, 224, 224]  # channels_first
-        model = models.MobileNetV3W050Small(include_preprocessing=False)
+        model = models.mobilenet_v3.MobileNetV3W050Small(
+            include_preprocessing=False, weights=None
+        )
         return input_shape, model
 
     @classmethod
@@ -33,4 +35,4 @@ def DISABLE_test_export_onnx_use(self):
 
         temp_dir = self.get_temp_dir()
 
-        export.export_onnx(model, input_shape, f"{temp_dir}/model.onnx")
+        export_onnx.export_onnx(model, input_shape, f"{temp_dir}/model.onnx")

kimm/export/export_tflite.py renamed to kimm/_src/export/export_tflite.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from keras import models
88
from keras.src.utils.module_utils import tensorflow as tf
99

10-
from kimm.models import BaseModel
10+
from kimm._src.kimm_export import kimm_export
11+
from kimm._src.models.base_model import BaseModel
1112

1213

14+
@kimm_export(parent_path=["kimm.export"])
1315
def export_tflite(
1416
model: BaseModel,
1517
input_shape: typing.Union[int, typing.Sequence[int]],
@@ -20,9 +22,10 @@ def export_tflite(
2022
):
2123
"""Export the model to tflite format.
2224
23-
Only tensorflow backend with 'channels_last' is supported. The tflite model
24-
will be generated using `tf.lite.TFLiteConverter.from_saved_model` and
25-
optimized through tflite built-in functions.
25+
Only TensorFlow backend with 'channels_last' is supported. The
26+
tflite model will be generated using
27+
`tf.lite.TFLiteConverter.from_saved_model` and optimized through tflite
28+
built-in functions.
2629
2730
Note that when exporting an `int8` tflite model, `representative_dataset`
2831
must be passed.
@@ -37,8 +40,8 @@ def export_tflite(
3740
batch_size: int, specifying the batch size of the input,
3841
defaults to `1`.
3942
"""
40-
if backend.backend() != "tensorflow":
41-
raise ValueError("`export_tflite` only supports tensorflow backend")
43+
if backend.backend() not in ("tensorflow",):
44+
raise ValueError("`export_tflite` only supports TensorFlow backend")
4245
if backend.image_data_format() != "channels_last":
4346
raise ValueError(
4447
"`export_tflite` only supports 'channels_last' data format."

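For reference, a hedged usage sketch of the TFLite exporter (not from the commit). The argument order mirrors `export_onnx(model, input_shape, path)` in the test above; the output-path argument and the public `kimm.models.MobileNetV3W050Small` / `kimm.export.export_tflite` paths are assumptions about the regenerated API.

# Usage sketch only. Requires the TensorFlow backend with the 'channels_last'
# image data format; the third positional argument is assumed to be the
# output path, mirroring export_onnx in the test above.
import kimm
from kimm.export import export_tflite

model = kimm.models.MobileNetV3W050Small(
    include_preprocessing=False, weights=None
)
export_tflite(model, [224, 224, 3], "mobilenet_v3_w050_small.tflite")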