
Commit b3b8219

Defaults to weights=imagenet if possible (#19)
* Nit
* Add pretrained weights
* Fix InceptionV3
* Fix `MobileNet100V3Large`
* Update version
1 parent 22136f9 commit b3b8219

20 files changed, +483 -136 lines
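
In practice, the constructors now download their matching ImageNet checkpoint unless told otherwise. A minimal usage sketch, assuming the model classes are reachable under kimm.models as the imports in kimm/__init__.py below suggest:

    import kimm

    # Pretrained ImageNet weights are now the default; the checkpoint is
    # downloaded from the kimm release assets on first use.
    model = kimm.models.ConvNeXtAtto()

    # Pass weights=None to keep the previous behavior (random initialization).
    model = kimm.models.ConvNeXtAtto(weights=None)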

kimm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from kimm import models  # force to add models to the registry
 from kimm.utils.model_registry import list_models
 
-__version__ = "0.1.1"
+__version__ = "0.1.2"

kimm/models/base_model.py

Lines changed: 1 addition & 2 deletions
@@ -89,7 +89,6 @@ def set_properties(
         # feature extractor
         self._feature_extractor = kwargs.pop("feature_extractor", False)
         self._feature_keys = kwargs.pop("feature_keys", None)
-        print("self._feature_keys", self._feature_keys)
 
     def determine_input_tensor(
         self,
@@ -208,4 +207,4 @@ def fix_config(self, config: typing.Dict):
 
     @property
     def default_origin(self):
-        return "https://github.com/james77777778/keras-aug/releases/download/v0.5.0"
+        return "https://github.com/james77777778/kimm/releases/download/0.1.0/"

kimm/models/convmixer.py

Lines changed: 12 additions & 3 deletions
@@ -146,11 +146,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: str = "ConvMixer736D32",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "convmixer736d32_convmixer_768_32.in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             32,
             768,
@@ -188,11 +191,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: str = "ConvMixer1024D20",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "convmixer1024d20_convmixer_1024_20_ks9_p14.in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             20,
             1024,
@@ -230,11 +236,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: str = "ConvMixer1536D20",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "convmixer1536d20_convmixer_1536_20.in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             20,
             1536,
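
Every variant follows the same recipe: when weights="imagenet", an asset file name (apparently {lowercased class name}_{timm checkpoint name}.keras) is joined onto default_origin and handed to the base class as weights_url. Reproducing the URL built for ConvMixer736D32 above, and fetching it manually with the standard Keras download helper (not necessarily what kimm uses internally):

    import keras

    # Values copied from the diff above.
    default_origin = "https://github.com/james77777778/kimm/releases/download/0.1.0/"
    file_name = "convmixer736d32_convmixer_768_32.in1k.keras"
    weights_url = f"{default_origin}/{file_name}"

    # keras.utils.get_file caches the download locally and returns the path.
    path = keras.utils.get_file(origin=weights_url)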

kimm/models/convnext.py

Lines changed: 33 additions & 9 deletions
@@ -247,11 +247,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: str = "ConvNeXtAtto",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "convnextatto_convnext_atto.d2_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             (2, 2, 6, 2),
             (40, 80, 160, 320),
@@ -284,11 +287,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: str = "ConvNeXtFemto",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "convnextfemto_convnext_femto.d1_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             (2, 2, 6, 2),
             (48, 96, 192, 384),
@@ -321,11 +327,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: str = "ConvNeXtPico",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "convnextpico_convnext_pico.d1_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             (2, 2, 6, 2),
             (64, 128, 256, 512),
@@ -358,11 +367,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: str = "ConvNeXtNano",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "convnextnano_convnext_nano.in12k_ft_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             (2, 2, 8, 2),
             (80, 160, 320, 640),
@@ -395,11 +407,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: str = "ConvNeXtTiny",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "convnexttiny_convnext_tiny.in12k_ft_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             (3, 3, 9, 3),
             (96, 192, 384, 768),
@@ -432,11 +447,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: str = "ConvNeXtSmall",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "convnextsmall_convnext_small.in12k_ft_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             (3, 3, 27, 3),
             (96, 192, 384, 768),
@@ -469,11 +487,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: str = "ConvNeXtBase",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "convnextbase_convnext_base.fb_in22k_ft_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             (3, 3, 27, 3),
             (128, 256, 512, 1024),
@@ -506,11 +527,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: str = "ConvNeXtLarge",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "convnextlarge_convnext_large.fb_in22k_ft_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             (3, 3, 27, 3),
             (192, 384, 768, 1536),
@@ -577,4 +601,4 @@ def __init__(
 add_model_to_registry(ConvNeXtSmall, "imagenet")
 add_model_to_registry(ConvNeXtBase, "imagenet")
 add_model_to_registry(ConvNeXtLarge, "imagenet")
-add_model_to_registry(ConvNeXtXLarge, "imagenet")
+add_model_to_registry(ConvNeXtXLarge)
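
ConvNeXtXLarge is the one variant left without the "imagenet" tag, since no pretrained checkpoint is wired up for it here. A sketch of how the registry might be queried through list_models (exported in kimm/__init__.py above); the filtering arguments are an assumption, not confirmed by this diff:

    import kimm

    # List every registered model name.
    print(kimm.list_models())

    # Assumed filters: restrict to ConvNeXt variants that ship ImageNet weights,
    # which would now exclude ConvNeXtXLarge.
    print(kimm.list_models("convnext", weights="imagenet"))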

kimm/models/densenet.py

Lines changed: 16 additions & 4 deletions
@@ -171,11 +171,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,  # TODO: imagenet
+        weights: typing.Optional[str] = "imagenet",
         name: str = "DenseNet121",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "densenet121_densenet121.ra_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             32,
             [6, 12, 24, 16],
@@ -205,11 +208,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,  # TODO: imagenet
+        weights: typing.Optional[str] = "imagenet",
         name: str = "DenseNet161",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "densenet161_densenet161.tv_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             48,
             [6, 12, 36, 24],
@@ -239,11 +245,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,  # TODO: imagenet
+        weights: typing.Optional[str] = "imagenet",
         name: str = "DenseNet169",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "densenet169_densenet169.tv_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             32,
             [6, 12, 32, 32],
@@ -273,11 +282,14 @@ def __init__(
         dropout_rate: float = 0.0,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,  # TODO: imagenet
+        weights: typing.Optional[str] = "imagenet",
         name: str = "DenseNet201",
         **kwargs,
     ):
         kwargs = self.fix_config(kwargs)
+        if weights == "imagenet":
+            file_name = "densenet201_densenet201.tv_in1k.keras"
+            kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
         super().__init__(
             32,
             [6, 12, 48, 32],
