Commit 03d84b4

Add ConvNeXt and refactor BaseModel (#16)
* Add `ConvNeXt`
* Update `requirements.txt`
* Refactor `BaseModel` to reduce redundant code
1 parent face1a0 commit 03d84b4

26 files changed: 1,046 additions & 450 deletions

kimm/blocks/transformer_block.py

Lines changed: 13 additions & 2 deletions
@@ -11,18 +11,29 @@ def apply_mlp_block(
     normalization=None,
     use_bias=True,
     dropout_rate=0.0,
+    use_conv_mlp=False,
     name="mlp_block",
 ):
     input_dim = inputs.shape[-1]
     output_dim = output_dim or input_dim

     x = inputs
-    x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
+    if use_conv_mlp:
+        x = layers.Conv2D(
+            hidden_dim, 1, use_bias=use_bias, name=f"{name}_fc1_conv2d"
+        )(x)
+    else:
+        x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
     x = layers.Activation(activation, name=f"{name}_act")(x)
     x = layers.Dropout(dropout_rate, name=f"{name}_drop1")(x)
     if normalization is not None:
         x = normalization(name=f"{name}_norm")(x)
-    x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x)
+    if use_conv_mlp:
+        x = layers.Conv2D(
+            output_dim, 1, use_bias=use_bias, name=f"{name}_fc2_conv2d"
+        )(x)
+    else:
+        x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x)
     x = layers.Dropout(dropout_rate, name=f"{name}_drop2")(x)
     return x
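
The new `use_conv_mlp` flag swaps the two `Dense` projections for 1x1 `Conv2D` layers so the block can operate directly on 4D feature maps, which is what the ConvNeXt port needs. Below is a minimal sketch of driving the flag; the leading `inputs` and `hidden_dim` parameters and the `gelu` activation are assumptions here, since the hunk only shows the signature from `normalization` onward:

import keras
from keras import layers

from kimm.blocks.transformer_block import apply_mlp_block

# NHWC feature map: the Conv2D branch uses 1x1 kernels, so the spatial
# layout is preserved and no flatten/reshape is required.
inputs = layers.Input(shape=(14, 14, 192))
x = apply_mlp_block(
    inputs,
    768,  # hidden_dim (assumed positional parameter)
    activation="gelu",
    use_conv_mlp=True,  # 1x1 Conv2D for fc1/fc2 instead of Dense
)
model = keras.Model(inputs, x)

With `use_conv_mlp=False` (the default), the block follows the original Dense path unchanged, so existing transformer models are unaffected.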

kimm/layers/attention.py

Lines changed: 2 additions & 0 deletions
@@ -1,7 +1,9 @@
+import keras
 from keras import layers
 from keras import ops


+@keras.saving.register_keras_serializable(package="kimm")
 class Attention(layers.Layer):
     def __init__(
         self,

kimm/layers/layer_scale.py

Lines changed: 2 additions & 0 deletions
@@ -1,8 +1,10 @@
+import keras
 from keras import initializers
 from keras import layers
 from keras import ops


+@keras.saving.register_keras_serializable(package="kimm")
 class LayerScale(layers.Layer):
     def __init__(
         self,

kimm/layers/position_embedding.py

Lines changed: 2 additions & 0 deletions
@@ -1,7 +1,9 @@
+import keras
 from keras import layers
 from keras import ops


+@keras.saving.register_keras_serializable(package="kimm")
 class PositionEmbedding(layers.Layer):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
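
Registering these layers under the "kimm" package lets a saved model that contains them be reloaded without a `custom_objects` mapping. A round-trip sketch (the input shape and file name are illustrative; `PositionEmbedding` takes no required arguments, per the diff above):

import keras
from keras import layers

from kimm.layers.position_embedding import PositionEmbedding

inputs = layers.Input(shape=(196, 768))
x = PositionEmbedding()(inputs)
model = keras.models.Model(inputs, x)

model.save("with_kimm_layers.keras")
# The registered name ("kimm>PositionEmbedding") resolves automatically,
# so no custom_objects argument is needed at load time.
restored = keras.models.load_model("with_kimm_layers.keras")

The same applies to `Attention` and `LayerScale` above.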

kimm/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 from kimm.models.base_model import BaseModel
 from kimm.models.convmixer import *  # noqa:F403
+from kimm.models.convnext import *  # noqa:F403
 from kimm.models.densenet import *  # noqa:F403
 from kimm.models.efficientnet import *  # noqa:F403
 from kimm.models.ghostnet import *  # noqa:F403

kimm/models/base_model.py

Lines changed: 111 additions & 60 deletions
@@ -1,10 +1,13 @@
 import abc
+import pathlib
 import typing
+import urllib.parse

 from keras import KerasTensor
 from keras import backend
 from keras import layers
 from keras import models
+from keras import utils
 from keras.src.applications import imagenet_utils

@@ -14,53 +17,79 @@ def __init__(
         inputs,
         outputs,
         features: typing.Optional[typing.Dict[str, KerasTensor]] = None,
-        feature_keys: typing.Optional[typing.List[str]] = None,
         **kwargs,
     ):
-        self.feature_extractor = kwargs.pop("feature_extractor", False)
-        self.feature_keys = feature_keys
-        if self.feature_extractor:
-            if features is None:
-                raise ValueError(
-                    "`features` must be set when "
-                    f"`feature_extractor=True`. Received features={features}"
-                )
-            if self.feature_keys is None:
-                self.feature_keys = list(features.keys())
-            filtered_features = {}
-            for k in self.feature_keys:
-                if k not in features:
-                    raise KeyError(
-                        f"'{k}' is not a key of `features`. Available keys "
-                        f"are: {list(features.keys())}"
-                    )
-                filtered_features[k] = features[k]
-            # add outputs
-            if backend.is_keras_tensor(outputs):
-                filtered_features["TOP"] = outputs
-            super().__init__(inputs=inputs, outputs=filtered_features, **kwargs)
-        else:
+        if not hasattr(self, "_feature_extractor"):
             del features
             super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+        else:
+            if not hasattr(self, "_feature_keys"):
+                raise AttributeError(
+                    "`self._feature_keys` must be set when initializing "
+                    "BaseModel"
+                )
+            if self._feature_extractor:
+                if features is None:
+                    raise ValueError(
+                        "`features` must be set when `feature_extractor=True`. "
+                        f"Received features={features}"
+                    )
+                if self._feature_keys is None:
+                    self._feature_keys = list(features.keys())
+                filtered_features = {}
+                for k in self._feature_keys:
+                    if k not in features:
+                        raise KeyError(
+                            f"'{k}' is not a key of `features`. Available keys "
+                            f"are: {list(features.keys())}"
+                        )
+                    filtered_features[k] = features[k]
+                # Add outputs
+                if backend.is_keras_tensor(outputs):
+                    filtered_features["TOP"] = outputs
+                super().__init__(
+                    inputs=inputs, outputs=filtered_features, **kwargs
+                )
+            else:
+                del features
+                super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+
+        if hasattr(self, "_weights_url"):
+            self.load_pretrained_weights(self._weights_url)

-    def parse_kwargs(
+    def set_properties(
         self, kwargs: typing.Dict[str, typing.Any], default_size: int = 224
     ):
-        result = {
-            "input_tensor": kwargs.pop("input_tensor", None),
-            "input_shape": kwargs.pop("input_shape", None),
-            "include_preprocessing": kwargs.pop("include_preprocessing", True),
-            "include_top": kwargs.pop("include_top", True),
-            "pooling": kwargs.pop("pooling", None),
-            "dropout_rate": kwargs.pop("dropout_rate", 0.0),
-            "classes": kwargs.pop("classes", 1000),
-            "classifier_activation": kwargs.pop(
-                "classifier_activation", "softmax"
-            ),
-            "weights": kwargs.pop("weights", "imagenet"),
-            "default_size": kwargs.pop("default_size", default_size),
-        }
-        return result
+        """Must be called in the initialization of the class.
+
+        This method will add the following common properties to the model
+        object:
+        - input_shape
+        - include_preprocessing
+        - include_top
+        - pooling
+        - dropout_rate
+        - classes
+        - classifier_activation
+        - _weights
+        - weights_url
+        - default_size
+        """
+        self._input_shape = kwargs.pop("input_shape", None)
+        self._include_preprocessing = kwargs.pop("include_preprocessing", True)
+        self._include_top = kwargs.pop("include_top", True)
+        self._pooling = kwargs.pop("pooling", None)
+        self._dropout_rate = kwargs.pop("dropout_rate", 0.0)
+        self._classes = kwargs.pop("classes", 1000)
+        self._classifier_activation = kwargs.pop(
+            "classifier_activation", "softmax"
+        )
+        self._weights = kwargs.pop("weights", None)
+        self._weights_url = kwargs.pop("weights_url", None)
+        self._default_size = kwargs.pop("default_size", default_size)
+        # feature extractor
+        self._feature_extractor = kwargs.pop("feature_extractor", False)
+        self._feature_keys = kwargs.pop("feature_keys", None)
+        print("self._feature_keys", self._feature_keys)

     def determine_input_tensor(
         self,
@@ -87,10 +116,12 @@ def determine_input_tensor(
         if not backend.is_keras_tensor(input_tensor):
             x = layers.Input(tensor=input_tensor, shape=input_shape)
         else:
-            x = input_tensor
+            x = utils.get_source_inputs(input_tensor)
         return x

     def build_preprocessing(self, inputs, mode="imagenet"):
+        if self._include_preprocessing is False:
+            return inputs
         if mode == "imagenet":
             # [0, 255] to [0, 1] and apply ImageNet mean and variance
             x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
@@ -118,15 +149,30 @@ def build_top(self, inputs, classes, classifier_activation, dropout_rate):
         )(x)
         return x

-    def add_references(self, parsed_kwargs: typing.Dict[str, typing.Any]):
-        self.include_preprocessing = parsed_kwargs["include_preprocessing"]
-        self.include_top = parsed_kwargs["include_top"]
-        self.pooling = parsed_kwargs["pooling"]
-        self.dropout_rate = parsed_kwargs["dropout_rate"]
-        self.classes = parsed_kwargs["classes"]
-        self.classifier_activation = parsed_kwargs["classifier_activation"]
-        # `self.weights` is been used internally
-        self._weights = parsed_kwargs["weights"]
+    def build_head(self, inputs):
+        x = inputs
+        if self._include_top:
+            x = self.build_top(
+                x,
+                self._classes,
+                self._classifier_activation,
+                self._dropout_rate,
+            )
+        else:
+            if self._pooling == "avg":
+                x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
+            elif self._pooling == "max":
+                x = layers.GlobalMaxPooling2D(name="max_pool")(x)
+        return x
+
+    def load_pretrained_weights(self, weights_url: typing.Optional[str] = None):
+        if weights_url is not None:
+            result = urllib.parse.urlparse(weights_url)
+            file_name = pathlib.Path(result.path).name
+            weights_path = utils.get_file(
+                file_name, weights_url, cache_subdir="kimm_models"
+            )
+            self.load_weights(weights_path)

     @staticmethod
     @abc.abstractmethod
@@ -141,20 +187,25 @@ def get_config(self):
             # models.Model
             "name": self.name,
             "trainable": self.trainable,
-            # feature extractor
-            "feature_extractor": self.feature_extractor,
-            "feature_keys": self.feature_keys,
-            # common
             "input_shape": self.input_shape[1:],
-            "include_preprocessing": self.include_preprocessing,
-            "include_top": self.include_top,
-            "pooling": self.pooling,
-            "dropout_rate": self.dropout_rate,
-            "classes": self.classes,
-            "classifier_activation": self.classifier_activation,
+            # common
+            "include_preprocessing": self._include_preprocessing,
+            "include_top": self._include_top,
+            "pooling": self._pooling,
+            "dropout_rate": self._dropout_rate,
+            "classes": self._classes,
+            "classifier_activation": self._classifier_activation,
             "weights": self._weights,
+            "weights_url": self._weights_url,
+            # feature extractor
+            "feature_extractor": self._feature_extractor,
+            "feature_keys": self._feature_keys,
         }
         return config

     def fix_config(self, config: typing.Dict):
         return config
+
+    @property
+    def default_origin(self):
+        return "https://github.com/james77777778/keras-aug/releases/download/v0.5.0"

kimm/models/base_model_test.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@

 class SampleModel(BaseModel):
     def __init__(self, **kwargs):
+        self.set_properties(kwargs)
         inputs = layers.Input(shape=[224, 224, 3])

         features = {}
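
The one-line fix above is load-bearing: `BaseModel.__init__` now dispatches on `hasattr(self, "_feature_extractor")`, so a subclass that skips `set_properties` silently takes the plain construction path and only fails later. A hypothetical counter-example mirroring `SampleModel`:

from keras import layers

from kimm.models.base_model import BaseModel


class BrokenModel(BaseModel):  # hypothetical: set_properties never called
    def __init__(self, **kwargs):
        inputs = layers.Input(shape=[224, 224, 3])
        outputs = layers.GlobalAveragePooling2D()(inputs)
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)


model = BrokenModel()  # builds fine
# model.get_config() would raise AttributeError, because the self._*
# properties (e.g. _include_preprocessing) were never populated.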

kimm/models/convmixer.py

Lines changed: 10 additions & 29 deletions
@@ -2,7 +2,6 @@

 import keras
 from keras import layers
-from keras import utils

 from kimm.models.base_model import BaseModel
 from kimm.utils import add_model_to_registry
@@ -42,6 +41,7 @@ def apply_convmixer_block(
     return x


+@keras.saving.register_keras_serializable(package="kimm")
 class ConvMixer(BaseModel):
     def __init__(
         self,
@@ -52,16 +52,16 @@ def __init__(
         activation: str = "relu",
         **kwargs,
     ):
-        parsed_kwargs = self.parse_kwargs(kwargs)
-        img_input = self.determine_input_tensor(
-            parsed_kwargs["input_tensor"],
-            parsed_kwargs["input_shape"],
-            parsed_kwargs["default_size"],
+        input_tensor = kwargs.pop("input_tensor", None)
+        self.set_properties(kwargs)
+        inputs = self.determine_input_tensor(
+            input_tensor,
+            self._input_shape,
+            self._default_size,
         )
-        x = img_input
+        x = inputs

-        if parsed_kwargs["include_preprocessing"]:
-            x = self.build_preprocessing(x, "imagenet")
+        x = self.build_preprocessing(x, "imagenet")

         # Prepare feature extraction
         features = {}
@@ -89,30 +89,11 @@ def __init__(
             features[f"BLOCK{i}"] = x

         # Head
-        if parsed_kwargs["include_top"]:
-            x = self.build_top(
-                x,
-                parsed_kwargs["classes"],
-                parsed_kwargs["classifier_activation"],
-                parsed_kwargs["dropout_rate"],
-            )
-        else:
-            if parsed_kwargs["pooling"] == "avg":
-                x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
-            elif parsed_kwargs["pooling"] == "max":
-                x = layers.GlobalMaxPooling2D(name="max_pool")(x)
-
-        # Ensure that the model takes into account
-        # any potential predecessors of `input_tensor`.
-        if parsed_kwargs["input_tensor"] is not None:
-            inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"])
-        else:
-            inputs = img_input
+        x = self.build_head(x)

         super().__init__(inputs=inputs, outputs=x, features=features, **kwargs)

         # All references to `self` below this line
-        self.add_references(parsed_kwargs)
         self.depth = depth
         self.hidden_channels = hidden_channels
         self.patch_size = patch_size
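
The ConvMixer rewrite shows the payoff: input handling, preprocessing, the head, and pretrained-weight loading all route through `BaseModel`. A usage sketch; keyword arguments are used because the full positional signature is not shown in the hunk, and the parameter names are inferred from the attribute assignments at the end of `__init__`:

from kimm.models.convmixer import ConvMixer

# The remaining kwargs (include_top, pooling, feature_extractor, ...) are
# consumed by set_properties before the graph is built.
model = ConvMixer(
    depth=20,
    hidden_channels=1024,
    patch_size=14,
    include_top=False,
    pooling="avg",
    feature_extractor=True,
    feature_keys=["BLOCK0", "BLOCK19"],
    weights=None,
)
# With feature_extractor=True the model outputs a dict:
# {"BLOCK0": ..., "BLOCK19": ..., "TOP": ...}, where "TOP" is the pooled
# output produced by build_head.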
