
Commit face1a0

Add ConvMixer (#15)
* Add `ConvMixer`
* Merge model tests into one file
* Fix format
1 parent e60ce6d · commit face1a0

18 files changed · +774 −933 lines changed

kimm/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from kimm.models.base_model import BaseModel
+from kimm.models.convmixer import *  # noqa:F403
 from kimm.models.densenet import *  # noqa:F403
 from kimm.models.efficientnet import *  # noqa:F403
 from kimm.models.ghostnet import *  # noqa:F403
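
This wildcard re-export (hence the `# noqa:F403` suppression) is what makes the new variants importable from the package root, e.g. `from kimm.models import ConvMixer736D32`, matching the pattern already used for the other model families.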

kimm/models/convmixer.py

Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
import typing

import keras
from keras import layers
from keras import utils

from kimm.models.base_model import BaseModel
from kimm.utils import add_model_to_registry


def apply_convmixer_block(
    inputs, output_channels, kernel_size, activation, name="convmixer_block"
):
    x = inputs

    # Depthwise
    x = layers.DepthwiseConv2D(
        kernel_size,
        1,
        padding="same",
        activation=activation,
        use_bias=True,
        name=f"{name}_0_fn_0_dwconv2d",
    )(x)
    x = layers.BatchNormalization(
        momentum=0.9, epsilon=1e-5, name=f"{name}_0_fn_2"
    )(x)
    x = layers.Add()([x, inputs])

    # Pointwise
    x = layers.Conv2D(
        output_channels,
        1,
        1,
        activation=activation,
        use_bias=True,
        name=f"{name}_1_conv2d",
    )(x)
    x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=f"{name}_3")(
        x
    )
    return x


class ConvMixer(BaseModel):
    def __init__(
        self,
        depth: int = 32,
        hidden_channels: int = 768,
        patch_size: int = 7,
        kernel_size: int = 7,
        activation: str = "relu",
        **kwargs,
    ):
        parsed_kwargs = self.parse_kwargs(kwargs)
        img_input = self.determine_input_tensor(
            parsed_kwargs["input_tensor"],
            parsed_kwargs["input_shape"],
            parsed_kwargs["default_size"],
        )
        x = img_input

        if parsed_kwargs["include_preprocessing"]:
            x = self.build_preprocessing(x, "imagenet")

        # Prepare feature extraction
        features = {}

        # Stem
        x = layers.Conv2D(
            hidden_channels,
            patch_size,
            patch_size,
            activation=activation,
            use_bias=True,
            name="stem_conv2d",
        )(x)
        x = layers.BatchNormalization(
            momentum=0.9, epsilon=1e-5, name="stem_bn"
        )(x)
        features["STEM"] = x

        # Blocks
        for i in range(depth):
            x = apply_convmixer_block(
                x, hidden_channels, kernel_size, activation, name=f"blocks_{i}"
            )
            # Add feature
            features[f"BLOCK{i}"] = x

        # Head
        if parsed_kwargs["include_top"]:
            x = self.build_top(
                x,
                parsed_kwargs["classes"],
                parsed_kwargs["classifier_activation"],
                parsed_kwargs["dropout_rate"],
            )
        else:
            if parsed_kwargs["pooling"] == "avg":
                x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
            elif parsed_kwargs["pooling"] == "max":
                x = layers.GlobalMaxPooling2D(name="max_pool")(x)

        # Ensure that the model takes into account
        # any potential predecessors of `input_tensor`.
        if parsed_kwargs["input_tensor"] is not None:
            inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"])
        else:
            inputs = img_input

        super().__init__(inputs=inputs, outputs=x, features=features, **kwargs)

        # All references to `self` below this line
        self.add_references(parsed_kwargs)
        self.depth = depth
        self.hidden_channels = hidden_channels
        self.patch_size = patch_size
        self.kernel_size = kernel_size
        self.activation = activation

    @staticmethod
    def available_feature_keys():
        raise NotImplementedError

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "depth": self.depth,
                "hidden_channels": self.hidden_channels,
                "patch_size": self.patch_size,
                "kernel_size": self.kernel_size,
                "activation": self.activation,
            }
        )
        return config

    def fix_config(self, config):
        unused_kwargs = [
            "depth",
            "hidden_channels",
            "patch_size",
            "kernel_size",
            "activation",
        ]
        for k in unused_kwargs:
            config.pop(k, None)
        return config


"""
Model Definition
"""


class ConvMixer736D32(ConvMixer):
    def __init__(
        self,
        input_tensor: keras.KerasTensor = None,
        input_shape: typing.Optional[typing.Sequence[int]] = None,
        include_preprocessing: bool = True,
        include_top: bool = True,
        pooling: typing.Optional[str] = None,
        dropout_rate: float = 0.0,
        classes: int = 1000,
        classifier_activation: str = "softmax",
        weights: typing.Optional[str] = None,
        name: str = "ConvMixer736D32",
        **kwargs,
    ):
        kwargs = self.fix_config(kwargs)
        super().__init__(
            32,
            768,
            7,
            7,
            "relu",
            input_tensor=input_tensor,
            input_shape=input_shape,
            include_preprocessing=include_preprocessing,
            include_top=include_top,
            pooling=pooling,
            dropout_rate=dropout_rate,
            classes=classes,
            classifier_activation=classifier_activation,
            weights=weights,
            name=name,
            **kwargs,
        )

    @staticmethod
    def available_feature_keys():
        feature_keys = ["STEM"]
        feature_keys.extend([f"BLOCK{i}" for i in range(32)])
        return feature_keys


class ConvMixer1024D20(ConvMixer):
    def __init__(
        self,
        input_tensor: keras.KerasTensor = None,
        input_shape: typing.Optional[typing.Sequence[int]] = None,
        include_preprocessing: bool = True,
        include_top: bool = True,
        pooling: typing.Optional[str] = None,
        dropout_rate: float = 0.0,
        classes: int = 1000,
        classifier_activation: str = "softmax",
        weights: typing.Optional[str] = None,
        name: str = "ConvMixer1024D20",
        **kwargs,
    ):
        kwargs = self.fix_config(kwargs)
        super().__init__(
            20,
            1024,
            14,
            9,
            "gelu",
            input_tensor=input_tensor,
            input_shape=input_shape,
            include_preprocessing=include_preprocessing,
            include_top=include_top,
            pooling=pooling,
            dropout_rate=dropout_rate,
            classes=classes,
            classifier_activation=classifier_activation,
            weights=weights,
            name=name,
            **kwargs,
        )

    @staticmethod
    def available_feature_keys():
        feature_keys = ["STEM"]
        feature_keys.extend([f"BLOCK{i}" for i in range(20)])
        return feature_keys


class ConvMixer1536D20(ConvMixer):
    def __init__(
        self,
        input_tensor: keras.KerasTensor = None,
        input_shape: typing.Optional[typing.Sequence[int]] = None,
        include_preprocessing: bool = True,
        include_top: bool = True,
        pooling: typing.Optional[str] = None,
        dropout_rate: float = 0.0,
        classes: int = 1000,
        classifier_activation: str = "softmax",
        weights: typing.Optional[str] = None,
        name: str = "ConvMixer1536D20",
        **kwargs,
    ):
        kwargs = self.fix_config(kwargs)
        super().__init__(
            20,
            1536,
            7,
            9,
            "gelu",
            input_tensor=input_tensor,
            input_shape=input_shape,
            include_preprocessing=include_preprocessing,
            include_top=include_top,
            pooling=pooling,
            dropout_rate=dropout_rate,
            classes=classes,
            classifier_activation=classifier_activation,
            weights=weights,
            name=name,
            **kwargs,
        )

    @staticmethod
    def available_feature_keys():
        feature_keys = ["STEM"]
        feature_keys.extend([f"BLOCK{i}" for i in range(20)])
        return feature_keys


add_model_to_registry(ConvMixer736D32, "imagenet")
add_model_to_registry(ConvMixer1024D20, "imagenet")
add_model_to_registry(ConvMixer1536D20, "imagenet")
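
For reference, a minimal usage sketch of the new variants (an editorial illustration, not part of the commit; it assumes `BaseModel` resolves a 224×224 default input size and that `weights=None` skips pretrained-weight loading):

import numpy as np

from kimm.models import ConvMixer736D32

# Headless model: global average pooling replaces the classifier.
model = ConvMixer736D32(include_top=False, pooling="avg", weights=None)

# Feature keys are "STEM" plus one entry per mixer block.
print(ConvMixer736D32.available_feature_keys()[:3])
# ['STEM', 'BLOCK0', 'BLOCK1']

x = np.random.rand(1, 224, 224, 3).astype("float32")
y = model.predict(x)
print(y.shape)  # (1, 768): the 7x7-patch stem yields 32x32 tokens, pooled over space

Note that in apply_convmixer_block the residual connection wraps only the depthwise (spatial-mixing) convolution; the pointwise (channel-mixing) convolution sits outside it, matching the reference ConvMixer design.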

kimm/models/densenet_test.py

Lines changed: 0 additions & 51 deletions
This file was deleted.
