Commit c608f1c

Add type hints (#25)
1 parent a50003c commit c608f1c

17 files changed: 135 additions & 98 deletions

README.md

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,6 @@
 <!-- markdownlint-disable MD033 -->
 <!-- markdownlint-disable MD041 -->
 
-# Keras Image Models
-
 <div align="center">
 <img width="50%" src="https://github.com/james77777778/kimm/assets/20734616/b21db8f2-307b-4791-b93d-e913e45fb238" alt="KIMM">
 
@@ -11,6 +9,8 @@
 [![codecov](https://codecov.io/gh/james77777778/kimm/graph/badge.svg?token=eEha1SR80D)](https://codecov.io/gh/james77777778/kimm)
 </div>
 
+# Keras Image Models
+
 ## Description
 
 **K**eras **Im**age **M**odels (`kimm`) is a collection of image models, blocks and layers written in Keras 3. The goal is to offer SOTA models with pretrained weights in a user-friendly manner.

kimm/blocks/base_block.py

Lines changed: 26 additions & 20 deletions
@@ -1,9 +1,14 @@
+import typing
+
 from keras import layers
 
 from kimm.utils import make_divisible
 
 
-def apply_activation(x, activation=None, name="activation"):
+def apply_activation(
+    inputs, activation: typing.Optional[str] = None, name: str = "activation"
+):
+    x = inputs
     if activation is not None:
         if isinstance(activation, str):
             x = layers.Activation(activation, name=name)(x)
@@ -18,16 +23,18 @@ def apply_activation(x, activation=None, name="activation"):
 
 def apply_conv2d_block(
     inputs,
-    filters=None,
-    kernel_size=None,
-    strides=1,
-    groups=1,
-    activation=None,
-    use_depthwise=False,
-    add_skip=False,
-    bn_momentum=0.9,
-    bn_epsilon=1e-5,
-    padding=None,
+    filters: typing.Optional[int] = None,
+    kernel_size: typing.Optional[
+        typing.Union[int, typing.Sequence[int]]
+    ] = None,
+    strides: int = 1,
+    groups: int = 1,
+    activation: typing.Optional[str] = None,
+    use_depthwise: bool = False,
+    add_skip: bool = False,
+    bn_momentum: float = 0.9,
+    bn_epsilon: float = 1e-5,
+    padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name="conv2d_block",
 ):
     if kernel_size is None:
@@ -77,12 +84,12 @@ def apply_conv2d_block(
 
 def apply_se_block(
     inputs,
-    se_ratio=0.25,
-    activation="relu",
-    gate_activation="sigmoid",
-    make_divisible_number=None,
-    se_input_channels=None,
-    name="se_block",
+    se_ratio: float = 0.25,
+    activation: typing.Optional[str] = "relu",
+    gate_activation: typing.Optional[str] = "sigmoid",
+    make_divisible_number: typing.Optional[int] = None,
+    se_input_channels: typing.Optional[int] = None,
+    name: str = "se_block",
 ):
     input_channels = inputs.shape[-1]
     if se_input_channels is None:
@@ -94,7 +101,6 @@ def apply_se_block(
         se_input_channels * se_ratio, make_divisible_number
     )
 
-    ori_x = inputs
     x = inputs
     x = layers.GlobalAveragePooling2D(keepdims=True, name=f"{name}_mean")(x)
     x = layers.Conv2D(
@@ -105,5 +111,5 @@ def apply_se_block(
         input_channels, 1, use_bias=True, name=f"{name}_conv_expand"
     )(x)
     x = apply_activation(x, gate_activation, name=f"{name}_gate")
-    out = layers.Multiply(name=name)([ori_x, x])
-    return out
+    x = layers.Multiply(name=name)([inputs, x])
+    return x
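
The new annotations do not change the calling convention, so functional-style usage stays the same. A minimal sketch of wiring these helpers together, assuming a Keras functional graph; the filter counts, shapes and layer names are illustrative, not taken from the commit:

from keras import layers

from kimm.blocks.base_block import apply_conv2d_block, apply_se_block

# Toy stem: a strided conv-BN-activation block followed by squeeze-and-excite.
images = layers.Input(shape=(224, 224, 3))
x = apply_conv2d_block(
    images,
    filters=32,
    kernel_size=3,
    strides=2,
    activation="swish",
    name="stem",
)
x = apply_se_block(x, se_ratio=0.25, activation="relu", name="stem_se")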

kimm/blocks/depthwise_separation_block.py

Lines changed: 16 additions & 14 deletions
@@ -1,3 +1,5 @@
+import typing
+
 from keras import layers
 
 from kimm.blocks.base_block import apply_conv2d_block
@@ -6,20 +8,20 @@
 
 def apply_depthwise_separation_block(
     inputs,
-    output_channels,
-    depthwise_kernel_size=3,
-    pointwise_kernel_size=1,
-    strides=1,
-    se_ratio=0.0,
-    activation="swish",
-    se_activation="relu",
-    se_gate_activation="sigmoid",
-    se_make_divisible_number=None,
-    pw_activation=None,
-    skip=True,
-    bn_epsilon=1e-5,
-    padding=None,
-    name="depthwise_separation_block",
+    output_channels: int,
+    depthwise_kernel_size: int = 3,
+    pointwise_kernel_size: int = 1,
+    strides: int = 1,
+    se_ratio: float = 0.0,
+    activation: typing.Optional[str] = "swish",
+    se_activation: typing.Optional[str] = "relu",
+    se_gate_activation: typing.Optional[str] = "sigmoid",
+    se_make_divisible_number: typing.Optional[int] = None,
+    pw_activation: typing.Optional[str] = None,
+    skip: bool = True,
+    bn_epsilon: float = 1e-5,
+    padding: typing.Optional[typing.Literal["same", "valid"]] = None,
+    name: str = "depthwise_separation_block",
 ):
     input_channels = inputs.shape[-1]
     has_skip = skip and (strides == 1 and input_channels == output_channels)
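
With the annotations in place, `output_channels` is the only argument a caller must supply besides `inputs`. A hedged sketch; the input shape and values below are made up for illustration:

from keras import layers

from kimm.blocks.depthwise_separation_block import (
    apply_depthwise_separation_block,
)

# Stride-2 depthwise-separable block with a light squeeze-and-excite stage.
x = layers.Input(shape=(112, 112, 32))
y = apply_depthwise_separation_block(
    x,
    output_channels=64,
    depthwise_kernel_size=3,
    strides=2,
    se_ratio=0.25,
    name="dsb_1",
)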

kimm/blocks/inverted_residual_block.py

Lines changed: 17 additions & 15 deletions
@@ -1,3 +1,5 @@
+import typing
+
 from keras import layers
 
 from kimm.blocks.base_block import apply_conv2d_block
@@ -7,21 +9,21 @@
 
 def apply_inverted_residual_block(
     inputs,
-    output_channels,
-    depthwise_kernel_size=3,
-    expansion_kernel_size=1,
-    pointwise_kernel_size=1,
-    strides=1,
-    expansion_ratio=1.0,
-    se_ratio=0.0,
-    activation="swish",
-    se_channels=None,
-    se_activation=None,
-    se_gate_activation="sigmoid",
-    se_make_divisible_number=None,
-    bn_epsilon=1e-5,
-    padding=None,
-    name="inverted_residual_block",
+    output_channels: int,
+    depthwise_kernel_size: int = 3,
+    expansion_kernel_size: int = 1,
+    pointwise_kernel_size: int = 1,
+    strides: int = 1,
+    expansion_ratio: float = 1.0,
+    se_ratio: float = 0.0,
+    activation: str = "swish",
+    se_channels: typing.Optional[int] = None,
+    se_activation: typing.Optional[str] = None,
+    se_gate_activation: typing.Optional[str] = "sigmoid",
+    se_make_divisible_number: typing.Optional[int] = None,
+    bn_epsilon: float = 1e-5,
+    padding: typing.Optional[typing.Literal["same", "valid"]] = None,
+    name: str = "inverted_residual_block",
 ):
     input_channels = inputs.shape[-1]
     hidden_channels = make_divisible(input_channels * expansion_ratio)
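
A similar sketch for the inverted residual (MBConv-style) block; the expansion ratio and shape below are illustrative, not from the commit:

from keras import layers

from kimm.blocks.inverted_residual_block import apply_inverted_residual_block

# MobileNetV2-style block: expand 6x, 3x3 depthwise, project back to 24 channels.
x = layers.Input(shape=(56, 56, 24))
y = apply_inverted_residual_block(
    x,
    output_channels=24,
    expansion_ratio=6.0,
    depthwise_kernel_size=3,
    strides=1,
    name="irb_1",
)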

kimm/blocks/transformer_block.py

Lines changed: 18 additions & 19 deletions
@@ -1,18 +1,19 @@
+import typing
+
 from keras import layers
 
 from kimm import layers as kimm_layers
 
 
 def apply_mlp_block(
     inputs,
-    hidden_dim,
-    output_dim=None,
-    activation="gelu",
-    normalization=None,
-    use_bias=True,
-    dropout_rate=0.0,
-    use_conv_mlp=False,
-    name="mlp_block",
+    hidden_dim: int,
+    output_dim: typing.Optional[int] = None,
+    activation: str = "gelu",
+    use_bias: bool = True,
+    dropout_rate: float = 0.0,
+    use_conv_mlp: bool = False,
+    name: str = "mlp_block",
 ):
     input_dim = inputs.shape[-1]
     output_dim = output_dim or input_dim
@@ -26,8 +27,6 @@ def apply_mlp_block(
     x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
     x = layers.Activation(activation, name=f"{name}_act")(x)
     x = layers.Dropout(dropout_rate, name=f"{name}_drop1")(x)
-    if normalization is not None:
-        x = normalization(name=f"{name}_norm")(x)
     if use_conv_mlp:
         x = layers.Conv2D(
             output_dim, 1, use_bias=use_bias, name=f"{name}_fc2_conv2d"
@@ -40,15 +39,15 @@
 
 def apply_transformer_block(
     inputs,
-    dim,
-    num_heads,
-    mlp_ratio=4.0,
-    use_qkv_bias=False,
-    use_qk_norm=False,
-    projection_dropout_rate=0.0,
-    attention_dropout_rate=0.0,
-    activation="gelu",
-    name="transformer_block",
+    dim: int,
+    num_heads: int,
+    mlp_ratio: float = 4.0,
+    use_qkv_bias: bool = False,
+    use_qk_norm: bool = False,
+    projection_dropout_rate: float = 0.0,
+    attention_dropout_rate: float = 0.0,
+    activation: str = "gelu",
+    name: str = "transformer_block",
 ):
     x = inputs
     residual_1 = x
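
A hedged sketch of calling the typed `apply_transformer_block` on a ViT-style token sequence; the token count and width are illustrative:

from keras import layers

from kimm.blocks.transformer_block import apply_transformer_block

# 197 tokens (196 patches plus a class token) of width 192, ViT-Tiny-like sizes.
tokens = layers.Input(shape=(197, 192))
y = apply_transformer_block(
    tokens,
    dim=192,
    num_heads=3,
    mlp_ratio=4.0,
    use_qkv_bias=True,
    name="blocks_0",
)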

kimm/layers/attention.py

Lines changed: 7 additions & 7 deletions
@@ -7,13 +7,13 @@
 class Attention(layers.Layer):
     def __init__(
         self,
-        hidden_dim,
-        num_heads=8,
-        use_qkv_bias=False,
-        use_qk_norm=False,
-        attention_dropout_rate=0.0,
-        projection_dropout_rate=0.0,
-        name="attention",
+        hidden_dim: int,
+        num_heads: int = 8,
+        use_qkv_bias: bool = False,
+        use_qk_norm: bool = False,
+        attention_dropout_rate: float = 0.0,
+        projection_dropout_rate: float = 0.0,
+        name: str = "attention",
         **kwargs,
     ):
         super().__init__(**kwargs)
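
The `Attention` layer now documents its hyperparameters through the signature alone. A sketch of standalone use, assuming the layer is exported from `kimm.layers` (as the `from kimm import layers as kimm_layers` import in transformer_block.py suggests) and that its call takes a single token tensor:

from keras import layers

from kimm import layers as kimm_layers

# Multi-head self-attention over 196 tokens of width 256 (illustrative values).
tokens = layers.Input(shape=(196, 256))
attention = kimm_layers.Attention(
    hidden_dim=256,
    num_heads=8,
    use_qkv_bias=True,
    attention_dropout_rate=0.1,
    name="attention_0",
)
outputs = attention(tokens)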

kimm/layers/layer_scale.py

Lines changed: 4 additions & 3 deletions
@@ -2,15 +2,16 @@
 from keras import initializers
 from keras import layers
 from keras import ops
+from keras.initializers import Initializer
 
 
 @keras.saving.register_keras_serializable(package="kimm")
 class LayerScale(layers.Layer):
     def __init__(
         self,
-        hidden_size,
-        initializer=initializers.Constant(1e-5),
-        name="layer_scale",
+        hidden_size: int,
+        initializer: Initializer = initializers.Constant(1e-5),
+        name: str = "layer_scale",
         **kwargs,
     ):
         super().__init__(**kwargs)
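
`LayerScale` now accepts any `keras.initializers.Initializer` for its per-channel scale. A sketch, assuming the class is importable from `kimm.layers.layer_scale` (the file shown above) and is called on a single tensor:

from keras import initializers
from keras import layers

from kimm.layers.layer_scale import LayerScale

# Scale a 384-wide token embedding by learnable per-channel gammas (illustrative).
x = layers.Input(shape=(196, 384))
y = LayerScale(
    hidden_size=384,
    initializer=initializers.Constant(1e-6),
    name="layer_scale_0",
)(x)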

kimm/models/base_model.py

Lines changed: 18 additions & 8 deletions
@@ -92,12 +92,12 @@ def set_properties(
 
     def determine_input_tensor(
         self,
-        input_tensor=None,
-        input_shape=None,
-        default_size=224,
-        min_size=32,
-        require_flatten=False,
-        static_shape=False,
+        input_tensor: typing.Optional[KerasTensor] = None,
+        input_shape: typing.Optional[typing.Sequence[int]] = None,
+        default_size: int = 224,
+        min_size: int = 32,
+        require_flatten: bool = False,
+        static_shape: bool = False,
     ):
         """Determine the input tensor by the arguments."""
         input_shape = imagenet_utils.obtain_input_shape(
@@ -118,7 +118,11 @@ def determine_input_tensor(
         x = utils.get_source_inputs(input_tensor)
         return x
 
-    def build_preprocessing(self, inputs, mode="imagenet"):
+    def build_preprocessing(
+        self,
+        inputs,
+        mode: typing.Literal["imagenet", "0_1", "-1_1"] = "imagenet",
+    ):
         if self._include_preprocessing is False:
             return inputs
         if mode == "imagenet":
@@ -140,7 +144,13 @@ def build_preprocessing(self, inputs, mode="imagenet"):
         )
         return x
 
-    def build_top(self, inputs, classes, classifier_activation, dropout_rate):
+    def build_top(
+        self,
+        inputs,
+        classes: int,
+        classifier_activation: str,
+        dropout_rate: float,
+    ):
         x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs)
         x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x)
         x = layers.Dense(
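
The `Literal` annotation on `mode` pins the accepted preprocessing names to "imagenet", "0_1" and "-1_1". The diff does not show what each mode computes, so the following is a standalone sketch of the usual conventions behind those names, not kimm's implementation; the ImageNet statistics are assumed:

import typing

from keras import layers

def build_preprocessing(
    inputs,
    mode: typing.Literal["imagenet", "0_1", "-1_1"] = "imagenet",
):
    if mode == "imagenet":
        # Assumed convention: scale to [0, 1], then normalize with ImageNet
        # mean/std; kimm's actual constants may differ.
        x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
        x = layers.Normalization(
            mean=[0.485, 0.456, 0.406],
            variance=[0.229**2, 0.224**2, 0.225**2],
        )(x)
    elif mode == "0_1":
        x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
    else:  # "-1_1"
        x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(inputs)
    return x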

kimm/models/efficientnet.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ def __init__(
         fix_stem_and_head_channels: bool = False,
         fix_first_and_last_blocks: bool = False,
         activation="swish",
-        config: typing.Union[str, typing.List] = "v1",
+        config: str = "v1",
         **kwargs,
     ):
         _available_configs = [

kimm/models/ghostnet.py

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ def __init__(
         self,
         width: float = 1.0,
         config: typing.Union[str, typing.List] = "default",
-        version: str = "v1",
+        version: typing.Literal["v1", "v2"] = "v1",
         **kwargs,
     ):
         _available_configs = ["default"]
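
The same `Literal` pattern recurs in the mobilenet_v2, mobilenet_v3 and resnet diffs below. A standalone sketch of what it buys: a static checker constrains the accepted strings, while runtime behavior is unchanged unless the code also validates the value:

import typing

def pick_version(version: typing.Literal["v1", "v2"] = "v1") -> str:
    return f"ghostnet_{version}"

pick_version("v2")  # accepted
pick_version("v3")  # still runs, but mypy/pyright report an error here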

kimm/models/inception_v3.py

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ def apply_inception_aux_block(inputs, classes, name="inception_aux_block"):
 
 @keras.saving.register_keras_serializable(package="kimm")
 class InceptionV3Base(BaseModel):
-    def __init__(self, has_aux_logits=False, **kwargs):
+    def __init__(self, has_aux_logits: bool = False, **kwargs):
         input_tensor = kwargs.pop("input_tensor", None)
         self.set_properties(kwargs, 299)
         inputs = self.determine_input_tensor(

kimm/models/mobilenet_v2.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def __init__(
         width: float = 1.0,
         depth: float = 1.0,
         fix_stem_and_head_channels: bool = False,
-        config: typing.Union[str, typing.List] = "default",
+        config: typing.Literal["default"] = "default",
         **kwargs,
     ):
         _available_configs = ["default"]

kimm/models/mobilenet_v3.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def __init__(
         width: float = 1.0,
         depth: float = 1.0,
         fix_stem_and_head_channels: bool = False,
-        config: typing.Union[str, typing.List] = "large",
+        config: typing.Literal["small", "large", "lcnet"] = "large",
         minimal: bool = False,
         **kwargs,
     ):

kimm/models/mobilevit.py

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ def __init__(
         self,
         stem_channels: int = 16,
         head_channels: int = 640,
-        activation="swish",
+        activation: str = "swish",
         config: str = "v1_s",
         **kwargs,
     ):

kimm/models/resnet.py

Lines changed: 4 additions & 1 deletion
@@ -106,7 +106,10 @@ def apply_bottleneck_block(
 @keras.saving.register_keras_serializable(package="kimm")
 class ResNet(BaseModel):
     def __init__(
-        self, block_fn: str, num_blocks: typing.Sequence[int], **kwargs
+        self,
+        block_fn: typing.Literal["basic", "bottleneck"],
+        num_blocks: typing.Sequence[int],
+        **kwargs,
     ):
         if block_fn not in ("basic", "bottleneck"):
             raise ValueError(
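
Here the `Literal` annotation is backed by the existing runtime check, so an unsupported `block_fn` is caught both statically and at `__init__` time. A sketch of constructing the model, assuming the remaining `BaseModel` keyword arguments can be left at their defaults:

from kimm.models.resnet import ResNet

# ResNet-50-style layout: bottleneck blocks repeated [3, 4, 6, 3] times (illustrative).
model = ResNet(block_fn="bottleneck", num_blocks=[3, 4, 6, 3])

# ResNet(block_fn="dense", ...) would be flagged by a type checker via the
# Literal annotation and rejected at runtime by the ValueError guard above.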
