Skip to content

Commit 6e37c90

Browse files
Fix import bug and refactor model registry (#12)
* Fix import error * Rename `list_models` and add predicitons to features
1 parent d7804ac commit 6e37c90

24 files changed

+187
-164
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import kimm
2424
print(kimm.list_models())
2525

2626
# Specify the name and other arguments to filter the result
27-
print(kimm.list_models("efficientnet", has_pretrained=True)) # fuzzy search
27+
print(kimm.list_models("efficientnet", weights="imagenet")) # fuzzy search
2828

2929
# Initialize the model with pretrained weights
3030
model = kimm.models.EfficientNetV2B0(weights="imagenet")
@@ -36,7 +36,7 @@ print(y.shape)
3636

3737
# Initialize the model as a feature extractor with pretrained weights
3838
model = kimm.models.EfficientNetV2B0(
39-
as_feature_extractor=True, weights="imagenet"
39+
feature_extractor=True, weights="imagenet"
4040
)
4141

4242
# Extract features for downstream tasks

kimm/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from kimm.models.base_model import BaseModel
2+
from kimm.models.densenet import * # noqa:F403
23
from kimm.models.efficientnet import * # noqa:F403
34
from kimm.models.ghostnet import * # noqa:F403
5+
from kimm.models.inception_v3 import * # noqa:F403
46
from kimm.models.mobilenet_v2 import * # noqa:F403
57
from kimm.models.mobilenet_v3 import * # noqa:F403
8+
from kimm.models.mobilevit import * # noqa:F403
69
from kimm.models.resnet import * # noqa:F403
710
from kimm.models.vision_transformer import * # noqa:F403

kimm/models/base_model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ def __init__(
1717
feature_keys: typing.Optional[typing.List[str]] = None,
1818
**kwargs,
1919
):
20-
self.as_feature_extractor = kwargs.pop("as_feature_extractor", False)
20+
self.feature_extractor = kwargs.pop("feature_extractor", False)
2121
self.feature_keys = feature_keys
22-
if self.as_feature_extractor:
22+
if self.feature_extractor:
2323
if features is None:
2424
raise ValueError(
2525
"`features` must be set when "
26-
f"`as_feature_extractor=True`. Got features={features}"
26+
f"`feature_extractor=True`. Received features={features}"
2727
)
2828
if self.feature_keys is None:
2929
self.feature_keys = list(features.keys())
@@ -35,6 +35,9 @@ def __init__(
3535
f"are: {list(features.keys())}"
3636
)
3737
filtered_features[k] = features[k]
38+
# add outputs
39+
if backend.is_keras_tensor(outputs):
40+
filtered_features["TOP"] = outputs
3841
super().__init__(inputs=inputs, outputs=filtered_features, **kwargs)
3942
else:
4043
del features
@@ -139,7 +142,7 @@ def get_config(self):
139142
"name": self.name,
140143
"trainable": self.trainable,
141144
# feature extractor
142-
"as_feature_extractor": self.as_feature_extractor,
145+
"feature_extractor": self.feature_extractor,
143146
"feature_keys": self.feature_keys,
144147
# common
145148
"input_shape": self.input_shape[1:],

kimm/models/base_model_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,31 +39,31 @@ def test_feature_extractor(self):
3939
x = random.uniform([1, 224, 224, 3])
4040

4141
# availiable_feature_keys
42-
self.assertEqual(
43-
SampleModel.available_feature_keys(),
42+
self.assertContainsSubset(
4443
["S2", "S4", "S8", "S16", "S32"],
44+
SampleModel.available_feature_keys(),
4545
)
4646

47-
# as_feature_extractor=False
47+
# feature_extractor=False
4848
model = SampleModel()
4949

5050
y = model(x, training=False)
5151

5252
self.assertNotIsInstance(y, dict)
5353
self.assertEqual(list(y.shape), [1, 7, 7, 3])
5454

55-
# as_feature_extractor=True
56-
model = SampleModel(as_feature_extractor=True)
55+
# feature_extractor=True
56+
model = SampleModel(feature_extractor=True)
5757

5858
y = model(x, training=False)
5959

6060
self.assertIsInstance(y, dict)
6161
self.assertEqual(list(y["S2"].shape), [1, 112, 112, 3])
6262
self.assertEqual(list(y["S32"].shape), [1, 7, 7, 3])
6363

64-
# as_feature_extractor=True with feature_keys
64+
# feature_extractor=True with feature_keys
6565
model = SampleModel(
66-
as_feature_extractor=True, feature_keys=["S2", "S16", "S32"]
66+
feature_extractor=True, feature_keys=["S2", "S16", "S32"]
6767
)
6868

6969
y = model(x, training=False)

kimm/models/densenet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def __init__(
315315
)
316316

317317

318-
add_model_to_registry(DenseNet121, True)
319-
add_model_to_registry(DenseNet161, True)
320-
add_model_to_registry(DenseNet169, True)
321-
add_model_to_registry(DenseNet201, True)
318+
add_model_to_registry(DenseNet121, "imagenet")
319+
add_model_to_registry(DenseNet161, "imagenet")
320+
add_model_to_registry(DenseNet169, "imagenet")
321+
add_model_to_registry(DenseNet201, "imagenet")

kimm/models/densenet_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@ def test_densenet_base(self, model_class):
2121
@parameterized.named_parameters([(DenseNet121.__name__, DenseNet121)])
2222
def test_densenet_feature_extractor(self, model_class):
2323
x = random.uniform([1, 224, 224, 3]) * 255.0
24-
model = model_class(
25-
input_shape=[224, 224, 3], as_feature_extractor=True
26-
)
24+
model = model_class(input_shape=[224, 224, 3], feature_extractor=True)
2725

2826
y = model(x, training=False)
2927

3028
self.assertIsInstance(y, dict)
31-
self.assertAllEqual(
32-
list(y.keys()), model_class.available_feature_keys()
29+
self.assertContainsSubset(
30+
model_class.available_feature_keys(),
31+
list(y.keys()),
3332
)
3433
self.assertEqual(list(y["STEM_S4"].shape), [1, 56, 56, 64])
3534
self.assertEqual(list(y["BLOCK0_S8"].shape), [1, 28, 28, 128])

kimm/models/efficientnet.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,29 +1514,29 @@ def __init__(
15141514
)
15151515

15161516

1517-
add_model_to_registry(EfficientNetB0, True)
1518-
add_model_to_registry(EfficientNetB1, True)
1519-
add_model_to_registry(EfficientNetB2, True)
1520-
add_model_to_registry(EfficientNetB3, True)
1521-
add_model_to_registry(EfficientNetB4, True)
1522-
add_model_to_registry(EfficientNetB5, True)
1523-
add_model_to_registry(EfficientNetB6, True)
1524-
add_model_to_registry(EfficientNetB7, True)
1525-
add_model_to_registry(EfficientNetLiteB0, True)
1526-
add_model_to_registry(EfficientNetLiteB1, True)
1527-
add_model_to_registry(EfficientNetLiteB2, True)
1528-
add_model_to_registry(EfficientNetLiteB3, True)
1529-
add_model_to_registry(EfficientNetLiteB4, True)
1530-
add_model_to_registry(EfficientNetV2S, True)
1531-
add_model_to_registry(EfficientNetV2M, True)
1532-
add_model_to_registry(EfficientNetV2L, True)
1533-
add_model_to_registry(EfficientNetV2XL, True)
1534-
add_model_to_registry(EfficientNetV2B0, True)
1535-
add_model_to_registry(EfficientNetV2B1, True)
1536-
add_model_to_registry(EfficientNetV2B2, True)
1537-
add_model_to_registry(EfficientNetV2B3, True)
1538-
add_model_to_registry(TinyNetA, True)
1539-
add_model_to_registry(TinyNetB, True)
1540-
add_model_to_registry(TinyNetC, True)
1541-
add_model_to_registry(TinyNetD, True)
1542-
add_model_to_registry(TinyNetE, True)
1517+
add_model_to_registry(EfficientNetB0, "imagenet")
1518+
add_model_to_registry(EfficientNetB1, "imagenet")
1519+
add_model_to_registry(EfficientNetB2, "imagenet")
1520+
add_model_to_registry(EfficientNetB3, "imagenet")
1521+
add_model_to_registry(EfficientNetB4, "imagenet")
1522+
add_model_to_registry(EfficientNetB5, "imagenet")
1523+
add_model_to_registry(EfficientNetB6, "imagenet")
1524+
add_model_to_registry(EfficientNetB7, "imagenet")
1525+
add_model_to_registry(EfficientNetLiteB0, "imagenet")
1526+
add_model_to_registry(EfficientNetLiteB1, "imagenet")
1527+
add_model_to_registry(EfficientNetLiteB2, "imagenet")
1528+
add_model_to_registry(EfficientNetLiteB3, "imagenet")
1529+
add_model_to_registry(EfficientNetLiteB4, "imagenet")
1530+
add_model_to_registry(EfficientNetV2S, "imagenet")
1531+
add_model_to_registry(EfficientNetV2M, "imagenet")
1532+
add_model_to_registry(EfficientNetV2L, "imagenet")
1533+
add_model_to_registry(EfficientNetV2XL, "imagenet")
1534+
add_model_to_registry(EfficientNetV2B0, "imagenet")
1535+
add_model_to_registry(EfficientNetV2B1, "imagenet")
1536+
add_model_to_registry(EfficientNetV2B2, "imagenet")
1537+
add_model_to_registry(EfficientNetV2B3, "imagenet")
1538+
add_model_to_registry(TinyNetA, "imagenet")
1539+
add_model_to_registry(TinyNetB, "imagenet")
1540+
add_model_to_registry(TinyNetC, "imagenet")
1541+
add_model_to_registry(TinyNetD, "imagenet")
1542+
add_model_to_registry(TinyNetE, "imagenet")

kimm/models/efficientnet_test.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,14 @@ def test_efficentnet_feature_extractor(
4848
self, model_class, width, fix_stem_channels
4949
):
5050
x = random.uniform([1, 224, 224, 3]) * 255.0
51-
model = model_class(
52-
input_shape=[224, 224, 3], as_feature_extractor=True
53-
)
51+
model = model_class(input_shape=[224, 224, 3], feature_extractor=True)
5452

5553
y = model(x, training=False)
5654

5755
self.assertIsInstance(y, dict)
58-
self.assertAllEqual(
59-
list(y.keys()), model_class.available_feature_keys()
56+
self.assertContainsSubset(
57+
model_class.available_feature_keys(),
58+
list(y.keys()),
6059
)
6160
if fix_stem_channels:
6261
self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 32])
@@ -86,15 +85,14 @@ def test_efficentnet_feature_extractor(
8685
)
8786
def test_efficentnet_v2_feature_extractor(self, model_class, width):
8887
x = random.uniform([1, 224, 224, 3]) * 255.0
89-
model = model_class(
90-
input_shape=[224, 224, 3], as_feature_extractor=True
91-
)
88+
model = model_class(input_shape=[224, 224, 3], feature_extractor=True)
9289

9390
y = model(x, training=False)
9491

9592
self.assertIsInstance(y, dict)
96-
self.assertAllEqual(
97-
list(y.keys()), model_class.available_feature_keys()
93+
self.assertContainsSubset(
94+
model_class.available_feature_keys(),
95+
list(y.keys()),
9896
)
9997
if "EfficientNetV2S" in model_class.__name__:
10098
self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 24])

kimm/models/ghostnet.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -595,9 +595,9 @@ def __init__(
595595
)
596596

597597

598-
add_model_to_registry(GhostNet050, False)
599-
add_model_to_registry(GhostNet100, True)
600-
add_model_to_registry(GhostNet130, True)
601-
add_model_to_registry(GhostNet100V2, True)
602-
add_model_to_registry(GhostNet130V2, True)
603-
add_model_to_registry(GhostNet160V2, True)
598+
add_model_to_registry(GhostNet050)
599+
add_model_to_registry(GhostNet100, "imagenet")
600+
add_model_to_registry(GhostNet130, "imagenet")
601+
add_model_to_registry(GhostNet100V2, "imagenet")
602+
add_model_to_registry(GhostNet130V2, "imagenet")
603+
add_model_to_registry(GhostNet160V2, "imagenet")

kimm/models/ghostnet_test.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@ def test_ghostnet_base(self, model_class):
2222
@parameterized.named_parameters([(GhostNet100.__name__, GhostNet100)])
2323
def test_ghostnet_feature_extractor(self, model_class):
2424
x = random.uniform([1, 224, 224, 3]) * 255.0
25-
model = model_class(as_feature_extractor=True)
25+
model = model_class(feature_extractor=True)
2626

2727
y = model(x, training=False)
2828

2929
self.assertIsInstance(y, dict)
30-
self.assertAllEqual(
31-
list(y.keys()), model_class.available_feature_keys()
30+
self.assertContainsSubset(
31+
model_class.available_feature_keys(),
32+
list(y.keys()),
3233
)
3334
self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 16])
3435
self.assertEqual(list(y["BLOCK1_S4"].shape), [1, 56, 56, 24])
@@ -49,13 +50,14 @@ def test_ghostnetv2_base(self, model_class):
4950
@parameterized.named_parameters([(GhostNet100V2.__name__, GhostNet100V2)])
5051
def test_ghostnetv2_feature_extractor(self, model_class):
5152
x = random.uniform([1, 224, 224, 3]) * 255.0
52-
model = model_class(as_feature_extractor=True)
53+
model = model_class(feature_extractor=True)
5354

5455
y = model(x, training=False)
5556

5657
self.assertIsInstance(y, dict)
57-
self.assertAllEqual(
58-
list(y.keys()), model_class.available_feature_keys()
58+
self.assertContainsSubset(
59+
model_class.available_feature_keys(),
60+
list(y.keys()),
5961
)
6062
self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 16])
6163
self.assertEqual(list(y["BLOCK1_S4"].shape), [1, 56, 56, 24])

0 commit comments

Comments
 (0)