
Commit d7804ac

Add DenseNet, InceptionV3 and refactor BaseModel (#11)
* Fix export name
* Add `DenseNet`
* Cleanup
* Add `InceptionV3`
* Refactor `BaseModel`
* Simplify `build_preprocessing` and `build_top`
* Simplify code
* Format
* Mark serialization tests and skip them by default
1 parent ce979af commit d7804ac

40 files changed (+2142, −1245 lines)

conftest.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -1,7 +1,18 @@
 import os
 
+import pytest
 
-def pytest_configure():
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--run_serialization",
+        action="store_true",
+        default=False,
+        help="run serialization tests",
+    )
+
+
+def pytest_configure(config):
     import tensorflow as tf
 
     # disable tensorflow gpu memory preallocation
@@ -12,3 +23,18 @@ def pytest_configure():
     # disable jax gpu memory preallocation
     # https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
     os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+
+    config.addinivalue_line(
+        "markers", "serialization: mark test as a serialization test"
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    run_serialization_tests = config.getoption("--run_serialization")
+    skip_serialization = pytest.mark.skipif(
+        not run_serialization_tests,
+        reason="need --run_serialization option to run",
+    )
+    for item in items:
+        if "serialization" in item.name:
+            item.add_marker(skip_serialization)
```
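For context, a minimal sketch of a test this hook would affect: any test whose *name* contains "serialization" is skipped unless pytest runs with the new flag. The test body and file name below are illustrative, not from this commit:

```python
# test_serialization.py (hypothetical) -- skipped by default because
# "serialization" appears in the test name; run with:
#   pytest --run_serialization
import keras


def test_dense_serialization(tmp_path):
    inputs = keras.Input(shape=(4,))
    model = keras.Model(inputs, keras.layers.Dense(1)(inputs))
    path = str(tmp_path / "model.keras")
    model.save(path)
    restored = keras.models.load_model(path)
    assert restored.count_params() == model.count_params()
```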

kimm/blocks/base_block.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -34,6 +34,8 @@ def apply_conv2d_block(
         raise ValueError(
             f"kernel_size must be passed. Received: kernel_size={kernel_size}"
         )
+    if isinstance(kernel_size, int):
+        kernel_size = [kernel_size, kernel_size]
     input_channels = inputs.shape[-1]
     has_skip = add_skip and strides == 1 and input_channels == filters
     x = inputs
@@ -42,7 +44,9 @@ def apply_conv2d_block(
         padding = "same"
         if strides > 1:
             padding = "valid"
-            x = layers.ZeroPadding2D(kernel_size // 2, name=f"{name}_pad")(x)
+            x = layers.ZeroPadding2D(
+                (kernel_size[0] // 2, kernel_size[1] // 2), name=f"{name}_pad"
+            )(x)
 
     if not use_depthwise:
         x = layers.Conv2D(
```
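The hunk above normalizes an int `kernel_size` to a two-element list so the explicit padding can be computed per spatial dimension, which also makes non-square kernels work. A standalone sketch of that arithmetic (the helper name is mine, not from the commit):

```python
# Expand an int kernel size to [kh, kw] and derive the symmetric padding
# that ZeroPadding2D needs for an odd kernel with strides > 1.
def normalize_kernel_size(kernel_size):
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    return kernel_size


k = normalize_kernel_size(3)
assert (k[0] // 2, k[1] // 2) == (1, 1)  # 3x3 kernel -> pad 1 per side

k = normalize_kernel_size((3, 5))        # non-square kernels now work too
assert (k[0] // 2, k[1] // 2) == (1, 2)
```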

kimm/blocks/inverted_residual_block.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -15,7 +15,7 @@ def apply_inverted_residual_block(
     expansion_ratio=1.0,
     se_ratio=0.0,
     activation="swish",
-    se_input_channels=None,
+    se_channels=None,
     se_activation=None,
     se_gate_activation="sigmoid",
     se_make_divisible_number=None,
@@ -57,7 +57,7 @@ def apply_inverted_residual_block(
         se_ratio,
         activation=se_activation or activation,
         gate_activation=se_gate_activation,
-        se_input_channels=se_input_channels,
+        se_input_channels=se_channels,
         make_divisible_number=se_make_divisible_number,
         name=f"{name}_se",
     )
```

kimm/layers/attention.py

Lines changed: 0 additions & 15 deletions
```diff
@@ -118,18 +118,3 @@ def get_config(self):
             }
         )
         return config
-
-
-if __name__ == "__main__":
-    from keras import models
-    from keras import random
-
-    inputs = layers.Input(shape=[197, 768])
-    outputs = Attention(768)(inputs)
-
-    model = models.Model(inputs, outputs)
-    model.summary()
-
-    inputs = random.uniform([1, 197, 768])
-    outputs = model(inputs)
-    print(outputs.shape)
```

kimm/layers/layer_scale.py

Lines changed: 0 additions & 15 deletions
```diff
@@ -35,18 +35,3 @@ def get_config(self):
             }
         )
         return config
-
-
-if __name__ == "__main__":
-    from keras import models
-    from keras import random
-
-    inputs = layers.Input(shape=[197, 768])
-    outputs = LayerScale(768)(inputs)
-
-    model = models.Model(inputs, outputs)
-    model.summary()
-
-    inputs = random.uniform([1, 197, 768])
-    outputs = model(inputs)
-    print(outputs.shape)
```

kimm/layers/position_embedding.py

Lines changed: 0 additions & 22 deletions
```diff
@@ -38,25 +38,3 @@ def compute_output_shape(self, input_shape):
 
     def get_config(self):
         return super().get_config()
-
-
-if __name__ == "__main__":
-    from keras import models
-    from keras import random
-
-    inputs = layers.Input([224, 224, 3])
-    x = layers.Conv2D(
-        768,
-        16,
-        16,
-        use_bias=True,
-    )(inputs)
-    x = layers.Reshape((-1, 768))(x)
-    outputs = PositionEmbedding()(x)
-
-    model = models.Model(inputs, outputs)
-    model.summary()
-
-    inputs = random.uniform([1, 224, 224, 3])
-    outputs = model(inputs)
-    print(outputs.shape)
```

kimm/models/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,5 +1,5 @@
+from kimm.models.base_model import BaseModel
 from kimm.models.efficientnet import *  # noqa:F403
-from kimm.models.feature_extractor import FeatureExtractor
 from kimm.models.ghostnet import *  # noqa:F403
 from kimm.models.mobilenet_v2 import *  # noqa:F403
 from kimm.models.mobilenet_v3 import *  # noqa:F403
```
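With this re-export, downstream code can import the renamed class from the subpackage root:

```python
from kimm.models import BaseModel  # previously: FeatureExtractor
```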

kimm/models/base_model.py

Lines changed: 157 additions & 0 deletions
```python
import abc
import typing

from keras import KerasTensor
from keras import backend
from keras import layers
from keras import models
from keras.src.applications import imagenet_utils


class BaseModel(models.Model):
    def __init__(
        self,
        inputs,
        outputs,
        features: typing.Optional[typing.Dict[str, KerasTensor]] = None,
        feature_keys: typing.Optional[typing.List[str]] = None,
        **kwargs,
    ):
        self.as_feature_extractor = kwargs.pop("as_feature_extractor", False)
        self.feature_keys = feature_keys
        if self.as_feature_extractor:
            if features is None:
                raise ValueError(
                    "`features` must be set when "
                    f"`as_feature_extractor=True`. Got features={features}"
                )
            if self.feature_keys is None:
                self.feature_keys = list(features.keys())
            filtered_features = {}
            for k in self.feature_keys:
                if k not in features:
                    raise KeyError(
                        f"'{k}' is not a key of `features`. Available keys "
                        f"are: {list(features.keys())}"
                    )
                filtered_features[k] = features[k]
            super().__init__(inputs=inputs, outputs=filtered_features, **kwargs)
        else:
            del features
            super().__init__(inputs=inputs, outputs=outputs, **kwargs)

    def parse_kwargs(
        self, kwargs: typing.Dict[str, typing.Any], default_size: int = 224
    ):
        result = {
            "input_tensor": kwargs.pop("input_tensor", None),
            "input_shape": kwargs.pop("input_shape", None),
            "include_preprocessing": kwargs.pop("include_preprocessing", True),
            "include_top": kwargs.pop("include_top", True),
            "pooling": kwargs.pop("pooling", None),
            "dropout_rate": kwargs.pop("dropout_rate", 0.0),
            "classes": kwargs.pop("classes", 1000),
            "classifier_activation": kwargs.pop(
                "classifier_activation", "softmax"
            ),
            "weights": kwargs.pop("weights", "imagenet"),
            "default_size": kwargs.pop("default_size", default_size),
        }
        return result

    def determine_input_tensor(
        self,
        input_tensor=None,
        input_shape=None,
        default_size=224,
        min_size=32,
        require_flatten=False,
        static_shape=False,
    ):
        """Determine the input tensor by the arguments."""
        input_shape = imagenet_utils.obtain_input_shape(
            input_shape,
            default_size=default_size,
            min_size=min_size,
            data_format="channels_last",  # always channels_last
            require_flatten=require_flatten or static_shape,
            weights=None,
        )

        if input_tensor is None:
            x = layers.Input(shape=input_shape)
        else:
            if not backend.is_keras_tensor(input_tensor):
                x = layers.Input(tensor=input_tensor, shape=input_shape)
            else:
                x = input_tensor
        return x

    def build_preprocessing(self, inputs, mode="imagenet"):
        if mode == "imagenet":
            # [0, 255] to [0, 1] and apply ImageNet mean and variance
            x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
            x = layers.Normalization(
                mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225]
            )(x)
        elif mode == "0_1":
            # [0, 255] to [0, 1]
            x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
        elif mode == "-1_1":
            # [0, 255] to [-1, 1]
            x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(inputs)
        else:
            raise ValueError(
                "`mode` must be one of ('imagenet', '0_1', '-1_1'). "
                f"Received: mode={mode}"
            )
        return x

    def build_top(self, inputs, classes, classifier_activation, dropout_rate):
        x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs)
        x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x)
        x = layers.Dense(
            classes, activation=classifier_activation, name="classifier"
        )(x)
        return x

    def add_references(self, parsed_kwargs: typing.Dict[str, typing.Any]):
        self.include_preprocessing = parsed_kwargs["include_preprocessing"]
        self.include_top = parsed_kwargs["include_top"]
        self.pooling = parsed_kwargs["pooling"]
        self.dropout_rate = parsed_kwargs["dropout_rate"]
        self.classes = parsed_kwargs["classes"]
        self.classifier_activation = parsed_kwargs["classifier_activation"]
        # `self.weights` is used internally by Keras, so store it privately
        self._weights = parsed_kwargs["weights"]

    @staticmethod
    @abc.abstractmethod
    def available_feature_keys():
        # TODO: add docstring
        raise NotImplementedError

    def get_config(self):
        # Don't chain to super here. The default `get_config()` for functional
        # models is nested and cannot be passed to BaseModel.
        config = {
            # models.Model
            "name": self.name,
            "trainable": self.trainable,
            # feature extractor
            "as_feature_extractor": self.as_feature_extractor,
            "feature_keys": self.feature_keys,
            # common
            "input_shape": self.input_shape[1:],
            "include_preprocessing": self.include_preprocessing,
            "include_top": self.include_top,
            "pooling": self.pooling,
            "dropout_rate": self.dropout_rate,
            "classes": self.classes,
            "classifier_activation": self.classifier_activation,
            "weights": self._weights,
        }
        return config

    def fix_config(self, config: typing.Dict):
        return config
```
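As a sanity check on the three `build_preprocessing` modes above, the rescaling arithmetic can be verified standalone (a NumPy sketch of mine, not code from the commit):

```python
import numpy as np

x = np.array([0.0, 127.5, 255.0])  # representative uint8-range pixel values

# mode="0_1": Rescaling(scale=1/255) maps [0, 255] -> [0, 1]
assert np.allclose(x / 255.0, [0.0, 0.5, 1.0])

# mode="-1_1": Rescaling(scale=1/127.5, offset=-1) maps [0, 255] -> [-1, 1]
assert np.allclose(x / 127.5 - 1.0, [-1.0, 0.0, 1.0])

# mode="imagenet" first maps to [0, 1], then applies layers.Normalization
# with the per-channel ImageNet statistics (mean 0.485, 0.456, 0.406, ...).
```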

kimm/models/feature_extractor_test.py renamed to kimm/models/base_model_test.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -3,10 +3,10 @@
 from keras import random
 from keras.src import testing
 
-from kimm.models.feature_extractor import FeatureExtractor
+from kimm.models.base_model import BaseModel
 
 
-class SampleModel(FeatureExtractor):
+class SampleModel(BaseModel):
     def __init__(self, **kwargs):
         inputs = layers.Input(shape=[224, 224, 3])
 
@@ -34,7 +34,7 @@ def get_config(self):
         return super().get_config()
 
 
-class GhostNetTest(testing.TestCase, parameterized.TestCase):
+class BaseModelTest(testing.TestCase, parameterized.TestCase):
     def test_feature_extractor(self):
         x = random.uniform([1, 224, 224, 3])
```
