
Commit 927370b

Add docstrings for `kimm.blocks.*` (#50)

* Add docstrings for `kimm.blocks.*`
* Fix argument

1 parent: ac667c4

10 files changed: +74 −24 lines


kimm/_src/blocks/conv2d.py
Lines changed: 13 additions & 8 deletions

@@ -2,6 +2,7 @@
 
 from keras import backend
 from keras import layers
+from keras.src.utils.argument_validation import standardize_tuple
 
 from kimm._src.kimm_export import kimm_export
 
@@ -10,29 +11,33 @@
 def apply_conv2d_block(
     inputs,
     filters: typing.Optional[int] = None,
-    kernel_size: typing.Optional[
-        typing.Union[int, typing.Sequence[int]]
-    ] = None,
+    kernel_size: typing.Union[int, typing.Sequence[int]] = 1,
     strides: int = 1,
     groups: int = 1,
     activation: typing.Optional[str] = None,
     use_depthwise: bool = False,
-    add_skip: bool = False,
+    has_skip: bool = False,
     bn_momentum: float = 0.9,
     bn_epsilon: float = 1e-5,
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name="conv2d_block",
 ):
+    """(ZeroPadding) + Conv2D/DepthwiseConv2D + BN + (Activation)."""
     if kernel_size is None:
         raise ValueError(
             f"kernel_size must be passed. Received: kernel_size={kernel_size}"
         )
-    if isinstance(kernel_size, int):
-        kernel_size = [kernel_size, kernel_size]
+    kernel_size = standardize_tuple(kernel_size, 2, "kernel_size")
 
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
-    input_channels = inputs.shape[channels_axis]
-    has_skip = add_skip and strides == 1 and input_channels == filters
+    input_filters = inputs.shape[channels_axis]
+    if has_skip and (strides != 1 or input_filters != filters):
+        raise ValueError(
+            "If `has_skip=True`, strides must be 1 and `filters` must be the "
+            "same as input_filters. "
+            f"Received: strides={strides}, filters={filters}, "
+            f"input_filters={input_filters}"
+        )
     x = inputs
 
     if padding is None:
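For context, `standardize_tuple` is the Keras helper that normalizes an int or a sequence into a tuple of the requested length, which is what the removed `isinstance` branch did by hand. A minimal sketch of the updated call, with illustrative shapes and a channels_last data format assumed:

import keras
from keras.src.utils.argument_validation import standardize_tuple

from kimm._src.blocks.conv2d import apply_conv2d_block

# standardize_tuple broadcasts an int to a tuple of the requested length.
assert standardize_tuple(3, 2, "kernel_size") == (3, 3)

# Illustrative input; assumes channels_last image data format.
inputs = keras.layers.Input(shape=(32, 32, 64))

# Valid skip: strides == 1 and filters == input channels.
x = apply_conv2d_block(inputs, filters=64, kernel_size=3, has_skip=True)

# This would now raise a ValueError instead of silently dropping the
# residual connection (filters != input channels with has_skip=True):
# x = apply_conv2d_block(inputs, filters=128, kernel_size=3, has_skip=True)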

kimm/_src/blocks/depthwise_separation.py
Lines changed: 12 additions & 5 deletions

@@ -11,7 +11,7 @@
 @kimm_export(parent_path=["kimm.blocks"])
 def apply_depthwise_separation_block(
     inputs,
-    output_channels: int,
+    filters: int,
     depthwise_kernel_size: int = 3,
     pointwise_kernel_size: int = 1,
     strides: int = 1,
@@ -21,14 +21,21 @@ def apply_depthwise_separation_block(
     se_gate_activation: typing.Optional[str] = "sigmoid",
     se_make_divisible_number: typing.Optional[int] = None,
     pw_activation: typing.Optional[str] = None,
-    skip: bool = True,
+    has_skip: bool = True,
     bn_epsilon: float = 1e-5,
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name: str = "depthwise_separation_block",
 ):
+    """Conv2D block + (SqueezeAndExcitation) + Conv2D."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
-    input_channels = inputs.shape[channels_axis]
-    has_skip = skip and (strides == 1 and input_channels == output_channels)
+    input_filters = inputs.shape[channels_axis]
+    if has_skip and (strides != 1 or input_filters != filters):
+        raise ValueError(
+            "If `has_skip=True`, strides must be 1 and `filters` must be the "
+            "same as input_filters. "
+            f"Received: strides={strides}, filters={filters}, "
+            f"input_filters={input_filters}"
+        )
 
     x = inputs
     x = apply_conv2d_block(
@@ -52,7 +59,7 @@ def apply_depthwise_separation_block(
     )
     x = apply_conv2d_block(
         x,
-        output_channels,
+        filters,
         pointwise_kernel_size,
         1,
         activation=pw_activation,
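The same caller-facing change applies here: `output_channels` becomes `filters`, `skip` becomes `has_skip`, and an invalid skip configuration now raises instead of being silently disabled. A usage sketch, again with illustrative channels_last shapes:

import keras

from kimm._src.blocks.depthwise_separation import (
    apply_depthwise_separation_block,
)

inputs = keras.layers.Input(shape=(32, 32, 24))

# has_skip=True is only valid when strides == 1 and `filters` equals the
# input channel count; any other combination raises a ValueError.
x = apply_depthwise_separation_block(
    inputs, filters=24, strides=1, has_skip=True
)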

kimm/_src/blocks/inverted_residual.py
Lines changed: 4 additions & 3 deletions

@@ -12,7 +12,7 @@
 @kimm_export(parent_path=["kimm.blocks"])
 def apply_inverted_residual_block(
     inputs,
-    output_channels: int,
+    filters: int,
     depthwise_kernel_size: int = 3,
     expansion_kernel_size: int = 1,
     pointwise_kernel_size: int = 1,
@@ -28,10 +28,11 @@ def apply_inverted_residual_block(
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name: str = "inverted_residual_block",
 ):
+    """Conv2D block + DepthwiseConv2D block + (SE) + Conv2D."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
     input_channels = inputs.shape[channels_axis]
     hidden_channels = make_divisible(input_channels * expansion_ratio)
-    has_skip = strides == 1 and input_channels == output_channels
+    has_skip = strides == 1 and input_channels == filters
 
     x = inputs
     # Point-wise expansion
@@ -70,7 +71,7 @@ def apply_inverted_residual_block(
     # Point-wise linear projection
     x = apply_conv2d_block(
         x,
-        output_channels,
+        filters,
         pointwise_kernel_size,
         1,
         activation=None,

kimm/_src/blocks/squeeze_and_excitation.py
Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ def apply_se_block(
     se_input_channels: typing.Optional[int] = None,
     name: str = "se_block",
 ):
+    """Squeeze and Excitation."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
     input_channels = inputs.shape[channels_axis]
     if se_input_channels is None:

kimm/_src/blocks/transformer.py
Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,7 @@ def apply_mlp_block(
     data_format: typing.Optional[str] = None,
     name: str = "mlp_block",
 ):
+    """Dense/Conv2D + Activation + Dense/Conv2D."""
     if data_format is None:
         data_format = backend.image_data_format()
     dim_axis = -1 if data_format == "channels_last" else 1
@@ -56,6 +57,7 @@ def apply_transformer_block(
     activation: str = "gelu",
     name: str = "transformer_block",
 ):
+    """LN + Attention + LN + MLP block."""
     # data_format must be "channels_last"
     x = inputs
     residual_1 = x

kimm/_src/layers/reparameterizable_conv2d.py
Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ def __init__(
         self,
         filters,
         kernel_size,
-        strides=(1, 1),
+        strides=1,
         padding=None,
         has_skip: bool = True,
         has_scale: bool = True,
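The scalar default matches the convention used by the block functions above, where `strides: int = 1`. A minimal instantiation sketch, assuming the class exported from this module is named `ReparameterizableConv2D`:

from kimm._src.layers.reparameterizable_conv2d import ReparameterizableConv2D

# A scalar stride is broadcast to both spatial dimensions, following the
# usual Keras convolution convention.
layer = ReparameterizableConv2D(filters=64, kernel_size=3, strides=1)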

kimm/_src/models/efficientnet.py
Lines changed: 14 additions & 2 deletions

@@ -245,16 +245,28 @@ def __init__(
                 "activation": activation,
             }
             if block_type == "ds":
+                has_skip = x.shape[channels_axis] == c and s == 1
                 x = apply_depthwise_separation_block(
-                    x, c, k, 1, s, se, se_activation=activation, **_kwargs
+                    x,
+                    c,
+                    k,
+                    1,
+                    s,
+                    se,
+                    se_activation=activation,
+                    has_skip=has_skip,
+                    **_kwargs,
                 )
             elif block_type == "ir":
                 se_c = x.shape[channels_axis]
                 x = apply_inverted_residual_block(
                     x, c, k, 1, 1, s, e, se, se_channels=se_c, **_kwargs
                 )
             elif block_type == "cn":
-                x = apply_conv2d_block(x, c, k, s, add_skip=True, **_kwargs)
+                has_skip = x.shape[channels_axis] == c and s == 1
+                x = apply_conv2d_block(
+                    x, c, k, s, has_skip=has_skip, **_kwargs
+                )
             elif block_type == "er":
                 x = apply_edge_residual_block(x, c, k, 1, s, e, **_kwargs)
             current_stride *= s
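With `has_skip` no longer inferred inside the blocks, each model builder computes it at the call site. A self-contained sketch of that pattern, with hypothetical stage parameters `c` (target channels) and `s` (stride):

import keras
from keras import backend

x = keras.layers.Input(shape=(32, 32, 64))  # illustrative input
c, s = 64, 1  # hypothetical stage parameters: target channels and stride

channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
has_skip = x.shape[channels_axis] == c and s == 1  # True here under channels_last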

kimm/_src/models/hgnet.py
Lines changed: 3 additions & 3 deletions

@@ -267,7 +267,7 @@ def apply_high_perf_gpu_block(
     hidden_channels,
     output_channels,
     kernel_size,
-    add_skip=False,
+    has_skip=False,
     use_light_block=False,
     use_learnable_affine=False,
     aggregation="ese",
@@ -329,7 +329,7 @@ def apply_high_perf_gpu_block(
         name=f"{name}_aggregation_0",
     )
     x = apply_ese_module(x, output_channels, name=f"{name}_aggregation_1")
-    if add_skip:
+    if has_skip:
        x = layers.Add()([x, inputs])
     return x
 
@@ -375,7 +375,7 @@ def apply_high_perf_gpu_stage(
             hidden_channels,
             output_channels,
             kernel_size,
-            add_skip=False if i == 0 else True,
+            has_skip=False if i == 0 else True,
             use_light_block=use_light_block,
             use_learnable_affine=use_learnable_affine,
             aggregation=aggregation,

kimm/_src/models/mobilenet_v2.py
Lines changed: 14 additions & 1 deletion

@@ -3,6 +3,7 @@
 import typing
 
 import keras
+from keras import backend
 
 from kimm._src.blocks.conv2d import apply_conv2d_block
 from kimm._src.blocks.depthwise_separation import (
@@ -55,6 +56,10 @@ def __init__(
         )
 
         self.set_properties(kwargs)
+        channels_axis = (
+            -1 if backend.image_data_format() == "channels_last" else -3
+        )
+
         inputs = self.determine_input_tensor(
             input_tensor,
             self._input_shape,
@@ -93,8 +98,16 @@ def __init__(
             s = s if current_layer_idx == 0 else 1
             name = f"blocks_{current_block_idx}_{current_layer_idx}"
             if block_type == "ds":
+                has_skip = x.shape[channels_axis] == c and s == 1
                 x = apply_depthwise_separation_block(
-                    x, c, k, 1, s, activation="relu6", name=name
+                    x,
+                    c,
+                    k,
+                    1,
+                    s,
+                    activation="relu6",
+                    has_skip=has_skip,
+                    name=name,
                 )
             elif block_type == "ir":
                 x = apply_inverted_residual_block(

kimm/_src/models/mobilenet_v3.py
Lines changed: 10 additions & 1 deletion

@@ -4,6 +4,7 @@
 import warnings
 
 import keras
+from keras import backend
 from keras import layers
 
 from kimm._src.blocks.conv2d import apply_conv2d_block
@@ -124,6 +125,10 @@ def __init__(
         padding = kwargs.pop("padding", None)
 
         self.set_properties(kwargs)
+        channels_axis = (
+            -1 if backend.image_data_format() == "channels_last" else -3
+        )
+
         inputs = self.determine_input_tensor(
             input_tensor,
             self._input_shape,
@@ -181,6 +186,10 @@ def __init__(
                 ),
             }
             if block_type in ("ds", "dsa"):
+                if block_type == "dsa":
+                    has_skip = False
+                else:
+                    has_skip = x.shape[channels_axis] == c and s == 1
                 x = apply_depthwise_separation_block(
                     x,
                     c,
@@ -193,7 +202,7 @@ def __init__(
                     se_gate_activation="hard_sigmoid",
                     se_make_divisible_number=8,
                     pw_activation=act if block_type == "dsa" else None,
-                    skip=False if block_type == "dsa" else True,
+                    has_skip=has_skip,
                     **_kwargs,
                 )
             elif block_type == "ir":
