Skip to content

Commit c98226f

Browse files
Add MobileOne (#36)

* Update export scripts
* Add `MobileOne`
* Update version
* Update README
* Fix tests
* Remove explicit naming
1 parent e9cfd19 commit c98226f

28 files changed

+1167
-46
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,7 @@ Reference: [Grad-CAM class activation visualization (keras.io)](https://keras.io
193193
|LCNet|[arXiv 2021](https://arxiv.org/abs/2109.15099)|`timm`|`kimm.models.LCNet*`|
194194
|MobileNetV2|[CVPR 2018](https://arxiv.org/abs/1801.04381)|`timm`|`kimm.models.MobileNetV2*`|
195195
|MobileNetV3|[ICCV 2019](https://arxiv.org/abs/1905.02244)|`timm`|`kimm.models.MobileNetV3*`|
196+
|MobileOne|[CVPR 2023](https://arxiv.org/abs/2206.04040)|`timm`|`kimm.models.MobileOne*`|
196197
|MobileViT|[ICLR 2022](https://arxiv.org/abs/2110.02178)|`timm`|`kimm.models.MobileViT*`|
197198
|MobileViTV2|[arXiv 2022](https://arxiv.org/abs/2206.02680)|`timm`|`kimm.models.MobileViTV2*`|
198199
|RegNet|[CVPR 2020](https://arxiv.org/abs/2003.13678)|`timm`|`kimm.models.RegNet*`|

kimm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,4 +2,4 @@
22
from kimm import models # force to add models to the registry
33
from kimm.utils.model_registry import list_models
44

5-
__version__ = "0.1.6"
5+
__version__ = "0.1.7"

kimm/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from kimm.layers.attention import Attention
22
from kimm.layers.layer_scale import LayerScale
3+
from kimm.layers.mobile_one_conv2d import MobileOneConv2D
34
from kimm.layers.position_embedding import PositionEmbedding
45
from kimm.layers.rep_conv2d import RepConv2D

kimm/layers/attention.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@ def __init__(
1313
use_qk_norm: bool = False,
1414
attention_dropout_rate: float = 0.0,
1515
projection_dropout_rate: float = 0.0,
16-
name: str = "attention",
1716
**kwargs,
1817
):
1918
super().__init__(**kwargs)
@@ -25,20 +24,19 @@ def __init__(
2524
self.use_qk_norm = use_qk_norm
2625
self.attention_dropout_rate = attention_dropout_rate
2726
self.projection_dropout_rate = projection_dropout_rate
28-
self.name = name
2927

3028
self.qkv = layers.Dense(
3129
hidden_dim * 3,
3230
use_bias=use_qkv_bias,
3331
dtype=self.dtype_policy,
34-
name=f"{name}_qkv",
32+
name=f"{self.name}_qkv",
3533
)
3634
if use_qk_norm:
3735
self.q_norm = layers.LayerNormalization(
38-
dtype=self.dtype_policy, name=f"{name}_q_norm"
36+
dtype=self.dtype_policy, name=f"{self.name}_q_norm"
3937
)
4038
self.k_norm = layers.LayerNormalization(
41-
dtype=self.dtype_policy, name=f"{name}_k_norm"
39+
dtype=self.dtype_policy, name=f"{self.name}_k_norm"
4240
)
4341
else:
4442
self.q_norm = layers.Identity(dtype=self.dtype_policy)
@@ -47,15 +45,15 @@ def __init__(
4745
self.attention_dropout = layers.Dropout(
4846
attention_dropout_rate,
4947
dtype=self.dtype_policy,
50-
name=f"{name}_attn_drop",
48+
name=f"{self.name}_attn_drop",
5149
)
5250
self.projection = layers.Dense(
53-
hidden_dim, dtype=self.dtype_policy, name=f"{name}_proj"
51+
hidden_dim, dtype=self.dtype_policy, name=f"{self.name}_proj"
5452
)
5553
self.projection_dropout = layers.Dropout(
5654
projection_dropout_rate,
5755
dtype=self.dtype_policy,
58-
name=f"{name}_proj_drop",
56+
name=f"{self.name}_proj_drop",
5957
)
6058

6159
def build(self, input_shape):

kimm/layers/layer_scale.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,13 +11,11 @@ def __init__(
1111
self,
1212
axis: int = -1,
1313
initializer: Initializer = initializers.Constant(1e-5),
14-
name: str = "layer_scale",
1514
**kwargs,
1615
):
1716
super().__init__(**kwargs)
1817
self.axis = axis
1918
self.initializer = initializer
20-
self.name = name
2119

2220
def build(self, input_shape):
2321
if isinstance(self.axis, list):

0 commit comments

Comments (0)