Commit e60ce6d

Add VGG and Xception (#14)
1 parent cdfd212 commit e60ce6d

10 files changed: +865, -2 lines changed

kimm/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,4 +8,6 @@
 from kimm.models.mobilevit import *  # noqa:F403
 from kimm.models.regnet import *  # noqa:F403
 from kimm.models.resnet import *  # noqa:F403
+from kimm.models.vgg import *  # noqa:F403
 from kimm.models.vision_transformer import *  # noqa:F403
+from kimm.models.xception import *  # noqa:F403
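
With the wildcard exports above, the classes defined in the new modules become importable directly from kimm.models. A minimal usage sketch for the VGG side (the class names exported by kimm.models.xception are not part of this diff, so they are omitted; nothing below is added by the commit itself):

from kimm.models import VGG16  # re-exported through the wildcard import added above

# Untrained ImageNet-style classifier on 224x224 RGB inputs; weights stay None
# because pretrained weights are still marked TODO in kimm/models/vgg.py.
model = VGG16(input_shape=(224, 224, 3), include_top=True, classes=1000)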

kimm/models/vgg.py

Lines changed: 358 additions & 0 deletions
@@ -0,0 +1,358 @@
import typing

import keras
from keras import layers
from keras import utils

from kimm.models import BaseModel
from kimm.utils import add_model_to_registry

DEFAULT_VGG11_CONFIG = [
    64,
    "M",
    128,
    "M",
    256,
    256,
    "M",
    512,
    512,
    "M",
    512,
    512,
    "M",
]
DEFAULT_VGG13_CONFIG = [
    64,
    64,
    "M",
    128,
    128,
    "M",
    256,
    256,
    "M",
    512,
    512,
    "M",
    512,
    512,
    "M",
]
DEFAULT_VGG16_CONFIG = [
    64,
    64,
    "M",
    128,
    128,
    "M",
    256,
    256,
    256,
    "M",
    512,
    512,
    512,
    "M",
    512,
    512,
    512,
    "M",
]
DEFAULT_VGG19_CONFIG = [
    64,
    64,
    "M",
    128,
    128,
    "M",
    256,
    256,
    256,
    256,
    "M",
    512,
    512,
    512,
    512,
    "M",
    512,
    512,
    512,
    512,
    "M",
]


def apply_conv_mlp_layer(
    inputs,
    output_channels,
    kernel_size,
    mlp_ratio=1.0,
    dropout_rate=0.2,
    name="conv_mlp_layer",
):
    mid_channels = int(output_channels * mlp_ratio)

    x = inputs
    x = layers.Conv2D(
        mid_channels, kernel_size, 1, use_bias=True, name=f"{name}_fc1conv2d"
    )(x)
    x = layers.ReLU()(x)
    x = layers.Dropout(dropout_rate, name=f"{name}_drop")(x)
    x = layers.Conv2D(
        output_channels, 1, 1, use_bias=True, name=f"{name}_fc2conv2d"
    )(x)
    x = layers.ReLU()(x)
    return x


class VGG(BaseModel):
    def __init__(self, config: typing.Union[str, typing.List], **kwargs):
        _available_configs = ["vgg11", "vgg13", "vgg16", "vgg19"]
        if config == "vgg11":
            _config = DEFAULT_VGG11_CONFIG
        elif config == "vgg13":
            _config = DEFAULT_VGG13_CONFIG
        elif config == "vgg16":
            _config = DEFAULT_VGG16_CONFIG
        elif config == "vgg19":
            _config = DEFAULT_VGG19_CONFIG
        else:
            raise ValueError(
                f"config must be one of {_available_configs} using string. "
                f"Received: config={config}"
            )

        parsed_kwargs = self.parse_kwargs(kwargs)
        img_input = self.determine_input_tensor(
            parsed_kwargs["input_tensor"],
            parsed_kwargs["input_shape"],
            parsed_kwargs["default_size"],
        )
        x = img_input

        if parsed_kwargs["include_preprocessing"]:
            x = self.build_preprocessing(x, "imagenet")

        # Prepare feature extraction
        features = {}

        # Blocks
        current_stage_idx = 0
        current_block_idx = 0
        current_stride = 1
        for c in _config:
            name = f"features_{current_block_idx}"
            if c == "M":
                features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x
                x = layers.MaxPooling2D(2, 2, name=name)(x)
                current_stride *= 2
                current_stage_idx += 1
                current_block_idx += 1
            else:
                x = layers.Conv2D(
                    c,
                    3,
                    1,
                    padding="same",
                    use_bias=True,
                    name=f"features_{current_block_idx}conv2d",
                )(x)
                x = layers.BatchNormalization(
                    momentum=0.9,
                    epsilon=1e-5,
                    name=f"features_{current_block_idx + 1}",
                )(x)
                x = layers.ReLU(name=f"features_{current_block_idx + 2}")(x)
                current_block_idx += 3

        features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x
        x = apply_conv_mlp_layer(x, 4096, 7, 1.0, 0.0, name="pre_logits")

        # Head
        if parsed_kwargs["include_top"]:
            x = self.build_top(
                x,
                parsed_kwargs["classes"],
                parsed_kwargs["classifier_activation"],
                parsed_kwargs["dropout_rate"],
            )
        else:
            if parsed_kwargs["pooling"] == "avg":
                x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
            elif parsed_kwargs["pooling"] == "max":
                x = layers.GlobalMaxPooling2D(name="max_pool")(x)

        # Ensure that the model takes into account
        # any potential predecessors of `input_tensor`.
        if parsed_kwargs["input_tensor"] is not None:
            inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"])
        else:
            inputs = img_input

        super().__init__(inputs=inputs, outputs=x, features=features, **kwargs)

        # All references to `self` below this line
        self.add_references(parsed_kwargs)
        self.config = config

    @staticmethod
    def available_feature_keys():
        return [
            f"BLOCK{i}_S{j}" for i, j in zip(range(6), [1, 2, 4, 8, 16, 32])
        ]

    def get_config(self):
        config = super().get_config()
        config.update({"config": self.config})
        return config

    def fix_config(self, config: typing.Dict):
        unused_kwargs = ["config"]
        for k in unused_kwargs:
            config.pop(k, None)
        return config


"""
Model Definition
"""


class VGG11(VGG):
    def __init__(
        self,
        input_tensor: keras.KerasTensor = None,
        input_shape: typing.Optional[typing.Sequence[int]] = None,
        include_preprocessing: bool = True,
        include_top: bool = True,
        pooling: typing.Optional[str] = None,
        dropout_rate: float = 0.0,
        classes: int = 1000,
        classifier_activation: str = "softmax",
        weights: typing.Optional[str] = None,  # TODO: imagenet
        name: str = "VGG11",
        **kwargs,
    ):
        kwargs = self.fix_config(kwargs)
        super().__init__(
            "vgg11",
            input_tensor=input_tensor,
            input_shape=input_shape,
            include_preprocessing=include_preprocessing,
            include_top=include_top,
            pooling=pooling,
            dropout_rate=dropout_rate,
            classes=classes,
            classifier_activation=classifier_activation,
            weights=weights,
            name=name,
            default_size=224,
            **kwargs,
        )


class VGG13(VGG):
    def __init__(
        self,
        input_tensor: keras.KerasTensor = None,
        input_shape: typing.Optional[typing.Sequence[int]] = None,
        include_preprocessing: bool = True,
        include_top: bool = True,
        pooling: typing.Optional[str] = None,
        dropout_rate: float = 0.0,
        classes: int = 1000,
        classifier_activation: str = "softmax",
        weights: typing.Optional[str] = None,  # TODO: imagenet
        name: str = "VGG13",
        **kwargs,
    ):
        kwargs = self.fix_config(kwargs)
        super().__init__(
            "vgg13",
            input_tensor=input_tensor,
            input_shape=input_shape,
            include_preprocessing=include_preprocessing,
            include_top=include_top,
            pooling=pooling,
            dropout_rate=dropout_rate,
            classes=classes,
            classifier_activation=classifier_activation,
            weights=weights,
            name=name,
            default_size=224,
            **kwargs,
        )


class VGG16(VGG):
    def __init__(
        self,
        input_tensor: keras.KerasTensor = None,
        input_shape: typing.Optional[typing.Sequence[int]] = None,
        include_preprocessing: bool = True,
        include_top: bool = True,
        pooling: typing.Optional[str] = None,
        dropout_rate: float = 0.0,
        classes: int = 1000,
        classifier_activation: str = "softmax",
        weights: typing.Optional[str] = None,  # TODO: imagenet
        name: str = "VGG16",
        **kwargs,
    ):
        kwargs = self.fix_config(kwargs)
        super().__init__(
            "vgg16",
            input_tensor=input_tensor,
            input_shape=input_shape,
            include_preprocessing=include_preprocessing,
            include_top=include_top,
            pooling=pooling,
            dropout_rate=dropout_rate,
            classes=classes,
            classifier_activation=classifier_activation,
            weights=weights,
            name=name,
            default_size=224,
            **kwargs,
        )


class VGG19(VGG):
    def __init__(
        self,
        input_tensor: keras.KerasTensor = None,
        input_shape: typing.Optional[typing.Sequence[int]] = None,
        include_preprocessing: bool = True,
        include_top: bool = True,
        pooling: typing.Optional[str] = None,
        dropout_rate: float = 0.0,
        classes: int = 1000,
        classifier_activation: str = "softmax",
        weights: typing.Optional[str] = None,  # TODO: imagenet
        name: str = "VGG19",
        **kwargs,
    ):
        kwargs = self.fix_config(kwargs)
        super().__init__(
            "vgg19",
            input_tensor=input_tensor,
            input_shape=input_shape,
            include_preprocessing=include_preprocessing,
            include_top=include_top,
            pooling=pooling,
            dropout_rate=dropout_rate,
            classes=classes,
            classifier_activation=classifier_activation,
            weights=weights,
            name=name,
            default_size=224,
            **kwargs,
        )


add_model_to_registry(VGG11, "imagenet")
add_model_to_registry(VGG13, "imagenet")
add_model_to_registry(VGG16, "imagenet")
add_model_to_registry(VGG19, "imagenet")
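
As a quick sanity check on the pieces defined in this file, the sketch below lists the advertised feature keys and builds a headless backbone; it is not part of the commit and relies only on the constructor signature and static method shown above:

from kimm.models.vgg import VGG16

# Static method defined above; expected output:
# ['BLOCK0_S1', 'BLOCK1_S2', 'BLOCK2_S4', 'BLOCK3_S8', 'BLOCK4_S16', 'BLOCK5_S32']
print(VGG16.available_feature_keys())

# Headless backbone with global average pooling; weights default to None
# since pretrained ImageNet weights are still a TODO in this commit.
backbone = VGG16(include_top=False, pooling="avg", input_shape=(224, 224, 3))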

0 commit comments
