
Commit ca90fa4

Add MobileViT (#8)
* Add `MobileViT`
* Improve `MobileViT`
1 parent 5fe9c62 commit ca90fa4

11 files changed: 804 additions & 280 deletions

kimm/blocks/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
 from kimm.blocks.base_block import apply_activation
 from kimm.blocks.base_block import apply_conv2d_block
 from kimm.blocks.base_block import apply_se_block
+from kimm.blocks.inverted_residual_block import apply_inverted_residual_block
+from kimm.blocks.transformer_block import apply_mlp_block
+from kimm.blocks.transformer_block import apply_transformer_block
kimm/blocks/inverted_residual_block.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+from keras import layers
+
+from kimm.blocks.base_block import apply_conv2d_block
+from kimm.blocks.base_block import apply_se_block
+from kimm.utils import make_divisible
+
+
+def apply_inverted_residual_block(
+    inputs,
+    output_channels,
+    depthwise_kernel_size=3,
+    expansion_kernel_size=1,
+    pointwise_kernel_size=1,
+    strides=1,
+    expansion_ratio=1.0,
+    se_ratio=0.0,
+    activation="swish",
+    se_input_channels=None,
+    se_activation=None,
+    se_gate_activation="sigmoid",
+    se_make_divisible_number=None,
+    bn_epsilon=1e-5,
+    padding=None,
+    name="inverted_residual_block",
+):
+    input_channels = inputs.shape[-1]
+    hidden_channels = make_divisible(input_channels * expansion_ratio)
+    has_skip = strides == 1 and input_channels == output_channels
+
+    x = inputs
+    # Point-wise expansion
+    x = apply_conv2d_block(
+        x,
+        hidden_channels,
+        expansion_kernel_size,
+        1,
+        activation=activation,
+        bn_epsilon=bn_epsilon,
+        padding=padding,
+        name=f"{name}_conv_pw",
+    )
+    # Depth-wise convolution
+    x = apply_conv2d_block(
+        x,
+        kernel_size=depthwise_kernel_size,
+        strides=strides,
+        activation=activation,
+        use_depthwise=True,
+        bn_epsilon=bn_epsilon,
+        padding=padding,
+        name=f"{name}_conv_dw",
+    )
+    # Squeeze-and-excitation
+    if se_ratio > 0:
+        x = apply_se_block(
+            x,
+            se_ratio,
+            activation=se_activation or activation,
+            gate_activation=se_gate_activation,
+            se_input_channels=se_input_channels,
+            make_divisible_number=se_make_divisible_number,
+            name=f"{name}_se",
+        )
+    # Point-wise linear projection
+    x = apply_conv2d_block(
+        x,
+        output_channels,
+        pointwise_kernel_size,
+        1,
+        activation=None,
+        bn_epsilon=bn_epsilon,
+        padding=padding,
+        name=f"{name}_conv_pwl",
+    )
+    if has_skip:
+        x = layers.Add()([x, inputs])
+    return x
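
For readers skimming the diff: a minimal usage sketch of the new shared block, assuming Keras 3 with kimm installed; the input shape, argument values, and block name below are illustrative, not taken from this commit.

import keras

from kimm.blocks import apply_inverted_residual_block

# Hypothetical 32x32 feature map with 24 channels.
inputs = keras.layers.Input(shape=(32, 32, 24))
# Expand 4x, downsample with stride 2, and apply squeeze-and-excitation.
x = apply_inverted_residual_block(
    inputs,
    output_channels=48,
    strides=2,
    expansion_ratio=4.0,
    se_ratio=0.25,
    name="example_ir",
)
model = keras.Model(inputs, x)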

kimm/blocks/transformer_block.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+from keras import layers
+
+from kimm import layers as kimm_layers
+
+
+def apply_mlp_block(
+    inputs,
+    hidden_dim,
+    output_dim=None,
+    activation="gelu",
+    normalization=None,
+    use_bias=True,
+    dropout_rate=0.0,
+    name="mlp_block",
+):
+    input_dim = inputs.shape[-1]
+    output_dim = output_dim or input_dim
+
+    x = inputs
+    x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
+    x = layers.Activation(activation, name=f"{name}_act")(x)
+    x = layers.Dropout(dropout_rate, name=f"{name}_drop1")(x)
+    if normalization is not None:
+        x = normalization(name=f"{name}_norm")(x)
+    x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x)
+    x = layers.Dropout(dropout_rate, name=f"{name}_drop2")(x)
+    return x
+
+
+def apply_transformer_block(
+    inputs,
+    dim,
+    num_heads,
+    mlp_ratio=4.0,
+    use_qkv_bias=False,
+    use_qk_norm=False,
+    projection_dropout_rate=0.0,
+    attention_dropout_rate=0.0,
+    activation="gelu",
+    name="transformer_block",
+):
+    x = inputs
+    residual_1 = x
+
+    x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm1")(x)
+    x = kimm_layers.Attention(
+        dim,
+        num_heads,
+        use_qkv_bias,
+        use_qk_norm,
+        attention_dropout_rate,
+        projection_dropout_rate,
+        name=f"{name}_attn",
+    )(x)
+    x = layers.Add()([residual_1, x])
+
+    residual_2 = x
+    x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm2")(x)
+    x = apply_mlp_block(
+        x,
+        int(dim * mlp_ratio),
+        activation=activation,
+        dropout_rate=projection_dropout_rate,
+        name=f"{name}_mlp",
+    )
+    x = layers.Add()([residual_2, x])
+    return x
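
Similarly, a minimal sketch of applying the new transformer block to a token sequence; the sequence length, embedding dimension, and head count are illustrative assumptions.

import keras

from kimm.blocks import apply_transformer_block

# Hypothetical sequence of 196 tokens with an embedding dimension of 192.
inputs = keras.layers.Input(shape=(196, 192))
# Pre-norm attention followed by an MLP with a 2x expansion, each with a residual add.
x = apply_transformer_block(
    inputs,
    dim=192,
    num_heads=4,
    mlp_ratio=2.0,
    use_qkv_bias=True,
    name="example_transformer",
)
model = keras.Model(inputs, x)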

kimm/models/efficientnet.py

Lines changed: 2 additions & 68 deletions
@@ -8,6 +8,7 @@
 from keras.src.applications import imagenet_utils
 
 from kimm.blocks import apply_conv2d_block
+from kimm.blocks import apply_inverted_residual_block
 from kimm.blocks import apply_se_block
 from kimm.models.feature_extractor import FeatureExtractor
 from kimm.utils import make_divisible
@@ -130,73 +131,6 @@ def apply_depthwise_separation_block(
     return x
 
 
-def apply_inverted_residual_block(
-    inputs,
-    output_channels,
-    depthwise_kernel_size=3,
-    expansion_kernel_size=1,
-    pointwise_kernel_size=1,
-    strides=1,
-    expansion_ratio=1.0,
-    se_ratio=0.0,
-    activation="swish",
-    bn_epsilon=1e-5,
-    padding=None,
-    name="inverted_residual_block",
-):
-    input_channels = inputs.shape[-1]
-    hidden_channels = make_divisible(input_channels * expansion_ratio)
-    has_skip = strides == 1 and input_channels == output_channels
-
-    x = inputs
-    # Point-wise expansion
-    x = apply_conv2d_block(
-        x,
-        hidden_channels,
-        expansion_kernel_size,
-        1,
-        activation=activation,
-        bn_epsilon=bn_epsilon,
-        padding=padding,
-        name=f"{name}_conv_pw",
-    )
-    # Depth-wise convolution
-    x = apply_conv2d_block(
-        x,
-        kernel_size=depthwise_kernel_size,
-        strides=strides,
-        activation=activation,
-        use_depthwise=True,
-        bn_epsilon=bn_epsilon,
-        padding=padding,
-        name=f"{name}_conv_dw",
-    )
-    # Squeeze-and-excitation
-    if se_ratio > 0:
-        x = apply_se_block(
-            x,
-            se_ratio,
-            activation=activation,
-            gate_activation="sigmoid",
-            se_input_channels=input_channels,
-            name=f"{name}_se",
-        )
-    # Point-wise linear projection
-    x = apply_conv2d_block(
-        x,
-        output_channels,
-        pointwise_kernel_size,
-        1,
-        activation=None,
-        bn_epsilon=bn_epsilon,
-        padding=padding,
-        name=f"{name}_conv_pwl",
-    )
-    if has_skip:
-        x = layers.Add()([x, inputs])
-    return x
-
-
 def apply_edge_residual_block(
     inputs,
     output_channels,
@@ -271,7 +205,7 @@ def __init__(
         classes: int = 1000,
         classifier_activation: str = "softmax",
         weights: typing.Optional[str] = None,  # TODO: imagenet
-        config: typing.Union[str, typing.List] = "default",
+        config: typing.Union[str, typing.List] = "v1",
         **kwargs,
     ):
         _available_configs = [

kimm/models/mobilenet_v2.py

Lines changed: 2 additions & 50 deletions
@@ -8,6 +8,7 @@
 from keras.src.applications import imagenet_utils
 
 from kimm.blocks import apply_conv2d_block
+from kimm.blocks import apply_inverted_residual_block
 from kimm.models.feature_extractor import FeatureExtractor
 from kimm.utils import make_divisible
 from kimm.utils.model_registry import add_model_to_registry
@@ -58,55 +59,6 @@ def apply_depthwise_separation_block(
     return x
 
 
-def apply_inverted_residual_block(
-    inputs,
-    output_channels,
-    depthwise_kernel_size=3,
-    expansion_kernel_size=1,
-    pointwise_kernel_size=1,
-    strides=1,
-    expansion_ratio=1.0,
-    activation="relu6",
-    name="inverted_residual_block",
-):
-    input_channels = inputs.shape[-1]
-    hidden_channels = make_divisible(input_channels * expansion_ratio)
-    has_skip = strides == 1 and input_channels == output_channels
-
-    x = inputs
-
-    # Point-wise expansion
-    x = apply_conv2d_block(
-        x,
-        hidden_channels,
-        expansion_kernel_size,
-        1,
-        activation=activation,
-        name=f"{name}_conv_pw",
-    )
-    # Depth-wise convolution
-    x = apply_conv2d_block(
-        x,
-        kernel_size=depthwise_kernel_size,
-        strides=strides,
-        activation=activation,
-        use_depthwise=True,
-        name=f"{name}_conv_dw",
-    )
-    # Point-wise linear projection
-    x = apply_conv2d_block(
-        x,
-        output_channels,
-        pointwise_kernel_size,
-        1,
-        activation=None,
-        name=f"{name}_conv_pwl",
-    )
-    if has_skip:
-        x = layers.Add()([x, inputs])
-    return x
-
-
 class MobileNetV2(FeatureExtractor):
     def __init__(
         self,
@@ -189,7 +141,7 @@ def __init__(
                 )
             elif block_type == "ir":
                 x = apply_inverted_residual_block(
-                    x, c, k, 1, 1, s, e, name=name
+                    x, c, k, 1, 1, s, e, activation="relu6", name=name
                 )
             current_stride *= s
         features[f"BLOCK{current_block_idx}_S{current_stride}"] = x
