
Commit 4506330

Improve test coverage and refactor BaseModel (#26)
* Update `available_feature_keys` and `available_weights`
* Fix serialization
* Update version
* Update `README`
1 parent c608f1c commit 4506330

20 files changed: +1083 -667 lines

README.md

Lines changed: 8 additions & 8 deletions
````diff
@@ -29,9 +29,9 @@ pip install keras kimm
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14WxYgVjlwCIO9MwqPYW-dskbTL2UHsVN?usp=sharing)
 
 ```python
-import cv2
 import keras
 from keras import ops
+from keras import utils
 from keras.applications.imagenet_utils import decode_predictions
 
 import kimm
@@ -43,15 +43,15 @@ print(kimm.list_models())
 print(kimm.list_models("efficientnet", weights="imagenet"))  # fuzzy search
 
 # Initialize the model with pretrained weights
-model = kimm.models.EfficientNetV2B0()
-image_size = model._default_size
+model = kimm.models.VisionTransformerTiny16()
+image_size = (model._default_size, model._default_size)
 
 # Load an image as the model input
 image_path = keras.utils.get_file(
     "african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png"
 )
-image = cv2.imread(image_path)
-image = cv2.resize(image, (image_size, image_size))
+image = utils.load_img(image_path, target_size=image_size)
+image = utils.img_to_array(image)
 x = ops.convert_to_tensor(image)
 x = ops.expand_dims(x, axis=0)
 
@@ -62,9 +62,9 @@ print("Predicted:", decode_predictions(preds, top=3)[0])
 
 ```bash
 ['ConvMixer1024D20', 'ConvMixer1536D20', 'ConvMixer736D32', 'ConvNeXtAtto', ...]
-['EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', ...]
-1/1 ━━━━━━━━━━━━━━━━━━━━ 11s 11s/step
-Predicted: [('n02504458', 'African_elephant', 0.90578836), ('n01871265', 'tusker', 0.024864597), ('n02504013', 'Indian_elephant', 0.01161992)]
+['VisionTransformerBase16', 'VisionTransformerBase32', 'VisionTransformerSmall16', ...]
+1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step
+Predicted: [('n02504458', 'African_elephant', 0.6895825), ('n01871265', 'tusker', 0.17934209), ('n02504013', 'Indian_elephant', 0.12927249)]
 ```
 
 ### An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset
````
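Put together, the updated quick-start snippet from the hunks above reads as a single runnable example. The `preds = model.predict(x)` line is assumed; it lives in the unchanged README lines between the second and third hunks.

```python
import keras
from keras import ops
from keras import utils
from keras.applications.imagenet_utils import decode_predictions

import kimm

# Initialize the model with pretrained weights
model = kimm.models.VisionTransformerTiny16()
image_size = (model._default_size, model._default_size)

# Load an image as the model input
image_path = keras.utils.get_file(
    "african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png"
)
image = utils.load_img(image_path, target_size=image_size)
image = utils.img_to_array(image)
x = ops.convert_to_tensor(image)
x = ops.expand_dims(x, axis=0)

# Run inference (this call is assumed; it sits in the unchanged README lines
# between the diff hunks shown above)
preds = model.predict(x)
print("Predicted:", decode_predictions(preds, top=3)[0])
```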

kimm/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
 from kimm import models  # force to add models to the registry
 from kimm.utils.model_registry import list_models
 
-__version__ = "0.1.2"
+__version__ = "0.1.3"
```

kimm/models/base_model.py

Lines changed: 22 additions & 10 deletions
```diff
@@ -1,4 +1,3 @@
-import abc
 import pathlib
 import typing
 import urllib.parse
@@ -12,6 +11,12 @@
 
 
 class BaseModel(models.Model):
+    default_origin = (
+        "https://github.com/james77777778/kimm/releases/download/0.1.0/"
+    )
+    available_feature_keys = []
+    available_weights = []
+
     def __init__(
         self,
         inputs,
@@ -183,12 +188,6 @@ def load_pretrained_weights(self, weights_url: typing.Optional[str] = None):
         )
         self.load_weights(weights_path)
 
-    @staticmethod
-    @abc.abstractmethod
-    def available_feature_keys():
-        # TODO: add docstring
-        raise NotImplementedError
-
     def get_config(self):
         # Don't chain to super here. The default `get_config()` for functional
         # models is nested and cannot be passed to BaseModel.
@@ -215,6 +214,19 @@ def get_config(self):
     def fix_config(self, config: typing.Dict):
         return config
 
-    @property
-    def default_origin(self):
-        return "https://github.com/james77777778/kimm/releases/download/0.1.0/"
+    def get_weights_url(self, weights):
+        if weights is None:
+            return None
+
+        for _weights, _origin, _file_name in self.available_weights:
+            if weights == _weights:
+                return f"{_origin}/{_file_name}"
+
+        # Failed to find the weights
+        _available_weights_name = [
+            _weights for _weights, _ in self.available_weights
+        ]
+        raise ValueError(
+            f"Available weights are {_available_weights_name}. "
+            f"Received weights={weights}"
+        )
```
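With this refactor, weight resolution moves out of per-subclass `if weights == "imagenet"` branches and into `BaseModel.get_weights_url`, driven by the `available_weights` class attribute. A minimal sketch of the new pattern (the `DummyNet` class and its weights file name are hypothetical, not part of the commit):

```python
# Sketch only: illustrates the attribute-driven pattern introduced above.
# `DummyNet` and "dummynet.in1k.keras" are hypothetical examples.
from kimm.models.base_model import BaseModel


class DummyNet(BaseModel):
    # Feature keys and pretrained weights are now plain class attributes
    # instead of an abstract static method and a property.
    available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(4)]]
    available_weights = [
        # (weights name, origin URL, weights file name)
        ("imagenet", BaseModel.default_origin, "dummynet.in1k.keras"),
    ]


# `get_weights_url` resolves a weights name against `available_weights`:
#   "imagenet"   -> f"{BaseModel.default_origin}/dummynet.in1k.keras"
#   None         -> None
#   unknown name -> ValueError listing the available weights
```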

kimm/models/convmixer.py

Lines changed: 29 additions & 31 deletions
```diff
@@ -52,6 +52,8 @@ def __init__(
         activation: str = "relu",
         **kwargs,
     ):
+        kwargs["weights_url"] = self.get_weights_url(kwargs["weights"])
+
         input_tensor = kwargs.pop("input_tensor", None)
         self.set_properties(kwargs)
         inputs = self.determine_input_tensor(
@@ -100,10 +102,6 @@ def __init__(
         self.kernel_size = kernel_size
         self.activation = activation
 
-    @staticmethod
-    def available_feature_keys():
-        raise NotImplementedError
-
     def get_config(self):
         config = super().get_config()
         config.update(
@@ -136,6 +134,15 @@ def fix_config(self, config):
 
 
 class ConvMixer736D32(ConvMixer):
+    available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(32)]]
+    available_weights = [
+        (
+            "imagenet",
+            ConvMixer.default_origin,
+            "convmixer736d32_convmixer_768_32.in1k.keras",
+        )
+    ]
+
     def __init__(
         self,
         input_tensor: keras.KerasTensor = None,
@@ -151,9 +158,6 @@ def __init__(
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
-        if weights == "imagenet":
-            file_name = "convmixer736d32_convmixer_768_32.in1k.keras"
-            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             32,
             768,
@@ -173,14 +177,17 @@ def __init__(
             **kwargs,
         )
 
-    @staticmethod
-    def available_feature_keys():
-        feature_keys = ["STEM"]
-        feature_keys.extend([f"BLOCK{i}" for i in range(32)])
-        return feature_keys
-
 
 class ConvMixer1024D20(ConvMixer):
+    available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]]
+    available_weights = [
+        (
+            "imagenet",
+            ConvMixer.default_origin,
+            "convmixer1024d20_convmixer_1024_20_ks9_p14.in1k.keras",
+        )
+    ]
+
     def __init__(
         self,
         input_tensor: keras.KerasTensor = None,
@@ -196,9 +203,6 @@ def __init__(
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
-        if weights == "imagenet":
-            file_name = "convmixer1024d20_convmixer_1024_20_ks9_p14.in1k.keras"
-            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             20,
             1024,
@@ -218,14 +222,17 @@ def __init__(
             **kwargs,
         )
 
-    @staticmethod
-    def available_feature_keys():
-        feature_keys = ["STEM"]
-        feature_keys.extend([f"BLOCK{i}" for i in range(20)])
-        return feature_keys
-
 
 class ConvMixer1536D20(ConvMixer):
+    available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]]
+    available_weights = [
+        (
+            "imagenet",
+            ConvMixer.default_origin,
+            "convmixer1536d20_convmixer_1536_20.in1k.keras",
+        )
+    ]
+
     def __init__(
         self,
         input_tensor: keras.KerasTensor = None,
@@ -241,9 +248,6 @@ def __init__(
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
-        if weights == "imagenet":
-            file_name = "convmixer1536d20_convmixer_1536_20.in1k.keras"
-            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             20,
             1536,
@@ -263,12 +267,6 @@ def __init__(
             **kwargs,
         )
 
-    @staticmethod
-    def available_feature_keys():
-        feature_keys = ["STEM"]
-        feature_keys.extend([f"BLOCK{i}" for i in range(20)])
-        return feature_keys
-
 
 add_model_to_registry(ConvMixer736D32, "imagenet")
 add_model_to_registry(ConvMixer1024D20, "imagenet")
```
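Because the feature keys and weight entries are now plain class attributes, they can be inspected without building a model. A small usage sketch (the output in the comments is inferred from the attributes declared in the diff above):

```python
from kimm.models.convmixer import ConvMixer736D32

# Class attributes are available without instantiating the model.
print(ConvMixer736D32.available_feature_keys[:3])
# ['STEM', 'BLOCK0', 'BLOCK1']
print(ConvMixer736D32.available_weights[0][2])
# 'convmixer736d32_convmixer_768_32.in1k.keras'
```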
