Commit e2d31cc

Add arbitrary shape support for ViT series (#46)
* Improvements:
  1. Add arbitrary shape support to ViT and MobileViT
  2. Simplify the logic in BaseModel
  3. Add more model weights
* Update version number
* Improve test coverage
* Fix numpy tests
1 parent 2160b3e · commit e2d31cc

33 files changed, +509 −240 lines

kimm/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -13,4 +13,4 @@
 from kimm._src.utils.model_registry import list_models
 from kimm._src.version import version

-__version__ = "0.2.0"
+__version__ = "0.2.1"

kimm/_src/blocks/transformer.py

Lines changed: 2 additions & 3 deletions

@@ -51,12 +51,12 @@ def apply_transformer_block(
     num_heads: int,
     mlp_ratio: float = 4.0,
     use_qkv_bias: bool = False,
-    use_qk_norm: bool = False,
     projection_dropout_rate: float = 0.0,
     attention_dropout_rate: float = 0.0,
     activation: str = "gelu",
     name: str = "transformer_block",
 ):
+    # data_format must be "channels_last"
     x = inputs
     residual_1 = x

@@ -65,7 +65,6 @@ def apply_transformer_block(
         dim,
         num_heads,
         use_qkv_bias,
-        use_qk_norm,
         attention_dropout_rate,
         projection_dropout_rate,
         name=f"{name}_attn",
@@ -79,7 +78,7 @@
         int(dim * mlp_ratio),
         activation=activation,
         dropout_rate=projection_dropout_rate,
-        data_format="channels_last",  # TODO: let backend decides
+        data_format="channels_last",
         name=f"{name}_mlp",
     )
     x = layers.Add()([residual_2, x])

kimm/_src/layers/attention.py

Lines changed: 37 additions & 32 deletions

@@ -1,4 +1,5 @@
 import keras
+from keras import InputSpec
 from keras import layers
 from keras import ops

@@ -13,7 +14,6 @@ def __init__(
         hidden_dim: int,
         num_heads: int = 8,
         use_qkv_bias: bool = False,
-        use_qk_norm: bool = False,
         attention_dropout_rate: float = 0.0,
         projection_dropout_rate: float = 0.0,
         **kwargs,
@@ -24,7 +24,6 @@ def __init__(
         self.head_dim = hidden_dim // num_heads
         self.scale = self.head_dim ** (-0.5)
         self.use_qkv_bias = use_qkv_bias
-        self.use_qk_norm = use_qk_norm
         self.attention_dropout_rate = attention_dropout_rate
         self.projection_dropout_rate = projection_dropout_rate

@@ -34,16 +33,6 @@ def __init__(
             dtype=self.dtype_policy,
             name=f"{self.name}_qkv",
         )
-        if use_qk_norm:
-            self.q_norm = layers.LayerNormalization(
-                dtype=self.dtype_policy, name=f"{self.name}_q_norm"
-            )
-            self.k_norm = layers.LayerNormalization(
-                dtype=self.dtype_policy, name=f"{self.name}_k_norm"
-            )
-        else:
-            self.q_norm = layers.Identity(dtype=self.dtype_policy)
-            self.k_norm = layers.Identity(dtype=self.dtype_policy)

         self.attention_dropout = layers.Dropout(
             attention_dropout_rate,
@@ -60,11 +49,16 @@
         )

     def build(self, input_shape):
+        self.input_spec = InputSpec(ndim=len(input_shape))
+        if self.input_spec.ndim not in (3, 4):
+            raise ValueError(
+                "The ndim of the inputs must be 3 or 4. "
+                f"Received: input_shape={input_shape}"
+            )
+
         self.qkv.build(input_shape)
         qkv_output_shape = list(input_shape)
         qkv_output_shape[-1] = qkv_output_shape[-1] * 3
-        self.q_norm.build(qkv_output_shape)
-        self.k_norm.build(qkv_output_shape)
         attention_input_shape = [
             input_shape[0],
             self.num_heads,
@@ -79,30 +73,42 @@ def build(self, input_shape):
     def call(self, inputs, training=None, mask=None):
         input_shape = ops.shape(inputs)
         qkv = self.qkv(inputs)
-        qkv = ops.reshape(
-            qkv,
-            [
-                input_shape[0],
-                input_shape[1],
-                3,
-                self.num_heads,
-                self.head_dim,
-            ],
-        )
-        qkv = ops.transpose(qkv, [2, 0, 3, 1, 4])
-        q, k, v = ops.unstack(qkv, 3, axis=0)
-        q = self.q_norm(q)
-        k = self.k_norm(k)
+        if self.input_spec.ndim == 3:
+            qkv = ops.reshape(
+                qkv,
+                [
+                    input_shape[0],
+                    input_shape[1],
+                    3,
+                    self.num_heads,
+                    self.head_dim,
+                ],
+            )
+            qkv = ops.transpose(qkv, [0, 3, 2, 1, 4])
+            q, k, v = ops.unstack(qkv, 3, axis=2)
+        else:
+            # self.input_spec.ndim==4
+            qkv = ops.reshape(
+                qkv,
+                [
+                    input_shape[0],
+                    input_shape[1],
+                    input_shape[2],
+                    3,
+                    self.num_heads,
+                    self.head_dim,
+                ],
+            )
+            qkv = ops.transpose(qkv, [0, 1, 4, 3, 2, 5])
+            q, k, v = ops.unstack(qkv, 3, axis=3)

         # attention
         q = ops.multiply(q, self.scale)
         attn = ops.matmul(q, ops.swapaxes(k, -2, -1))
         attn = ops.softmax(attn)
         attn = self.attention_dropout(attn)
         x = ops.matmul(attn, v)
-
-        x = ops.swapaxes(x, 1, 2)
-        x = ops.reshape(x, input_shape)
+        x = ops.reshape(ops.swapaxes(x, -3, -2), input_shape)
         x = self.projection(x)
         x = self.projection_dropout(x)
         return x
@@ -114,7 +120,6 @@ def get_config(self):
             "hidden_dim": self.hidden_dim,
             "num_heads": self.num_heads,
             "use_qkv_bias": self.use_qkv_bias,
-            "use_qk_norm": self.use_qk_norm,
             "attention_dropout_rate": self.attention_dropout_rate,
             "projection_dropout_rate": self.projection_dropout_rate,
             "name": self.name,

kimm/_src/layers/attention_test.py

Lines changed: 30 additions & 1 deletion

@@ -1,3 +1,4 @@
+import keras
 import pytest
 from absl.testing import parameterized
 from keras.src import testing
@@ -7,7 +8,7 @@

 class AttentionTest(testing.TestCase, parameterized.TestCase):
     @pytest.mark.requires_trainable_backend
-    def test_attention_basic(self):
+    def test_basic_3d(self):
         self.run_layer_test(
             Attention,
             init_kwargs={"hidden_dim": 20, "num_heads": 2},
@@ -18,3 +19,31 @@ def test_attention_basic(self):
             expected_num_losses=0,
             supports_masking=False,
         )
+
+    @pytest.mark.requires_trainable_backend
+    def test_basic_4d(self):
+        self.run_layer_test(
+            Attention,
+            init_kwargs={"hidden_dim": 20, "num_heads": 2},
+            input_shape=(1, 2, 10, 20),
+            expected_output_shape=(1, 2, 10, 20),
+            expected_num_trainable_weights=3,
+            expected_num_non_trainable_weights=0,
+            expected_num_losses=0,
+            supports_masking=False,
+        )
+
+    def test_invalid_ndim(self):
+        # Test 2D
+        inputs = keras.Input(shape=[1])
+        with self.assertRaisesRegex(
+            ValueError, "The ndim of the inputs must be 3 or 4."
+        ):
+            Attention(1, 1)(inputs)
+
+        # Test 5D
+        inputs = keras.Input(shape=[1, 2, 3, 4])
+        with self.assertRaisesRegex(
+            ValueError, "The ndim of the inputs must be 3 or 4."
+        ):
+            Attention(1, 1)(inputs)

kimm/_src/layers/layer_scale_test.py

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@

 class LayerScaleTest(testing.TestCase, parameterized.TestCase):
     @pytest.mark.requires_trainable_backend
-    def test_layer_scale_basic(self):
+    def test_basic(self):
         self.run_layer_test(
             LayerScale,
             init_kwargs={"axis": -1},

kimm/_src/layers/learnable_affine_test.py

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@

 class LearnableAffineTest(testing.TestCase, parameterized.TestCase):
     @pytest.mark.requires_trainable_backend
-    def test_layer_scale_basic(self):
+    def test_basic(self):
         self.run_layer_test(
             LearnableAffine,
             init_kwargs={"scale_value": 1.0, "bias_value": 0.0},

kimm/_src/layers/mobile_one_conv2d_test.py

Lines changed: 2 additions & 2 deletions

@@ -73,7 +73,7 @@
 class MobileOneConv2DTest(testing.TestCase, parameterized.TestCase):
     @parameterized.parameters(TEST_CASES)
     @pytest.mark.requires_trainable_backend
-    def test_mobile_one_conv2d_basic(
+    def test_basic(
         self,
         filters,
         kernel_size,
@@ -113,7+113,7 @@ def test_mobile_one_conv2d_basic(
         )

     @parameterized.parameters(TEST_CASES)
-    def test_mobile_one_conv2d_get_reparameterized_weights(
+    def test_get_reparameterized_weights(
         self,
         filters,
         kernel_size,

kimm/_src/layers/position_embedding.py

Lines changed: 47 additions & 2 deletions

@@ -8,15 +8,24 @@
 @kimm_export(parent_path=["kimm.layers"])
 @keras.saving.register_keras_serializable(package="kimm")
 class PositionEmbedding(layers.Layer):
-    def __init__(self, **kwargs):
+    def __init__(self, height, width, **kwargs):
         super().__init__(**kwargs)
+        # We need height and width for saving and loading
+        self.height = int(height)
+        self.width = int(width)

     def build(self, input_shape):
         if len(input_shape) != 3:
             raise ValueError(
                 "PositionEmbedding only accepts 3-dimensional input. "
                 f"Received: input_shape={input_shape}"
             )
+        if self.height * self.width != input_shape[-2]:
+            raise ValueError(
+                "The embedding size doesn't match the height and width. "
+                f"Received: height={self.height}, width={self.width}, "
+                f"input_shape={input_shape}"
+            )
         self.pos_embed = self.add_weight(
             shape=[1, input_shape[-2] + 1, input_shape[-1]],
             initializer="random_normal",
@@ -41,5 +50,41 @@ def compute_output_shape(self, input_shape):
         output_shape[1] = output_shape[1] + 1
         return output_shape

+    def save_own_variables(self, store):
+        super().save_own_variables(store)
+        # Add height and width information
+        store["height"] = self.height
+        store["width"] = self.width
+
+    def load_own_variables(self, store):
+        old_height = int(store["height"][...])
+        old_width = int(store["width"][...])
+        if old_height == self.height and old_width == self.width:
+            self.pos_embed.assign(store["0"])
+            self.cls_token.assign(store["1"])
+            return
+
+        # Resize the embedding if there is a shape mismatch
+        pos_embed = store["0"]
+        pos_embed_prefix, pos_embed = pos_embed[:, :1], pos_embed[:, 1:]
+        pos_embed_dim = pos_embed.shape[-1]
+        pos_embed = ops.cast(pos_embed, "float32")
+        pos_embed = ops.reshape(pos_embed, [1, old_height, old_width, -1])
+        pos_embed = ops.image.resize(
+            pos_embed,
+            size=[self.height, self.width],
+            interpolation="bilinear",
+            antialias=True,
+            data_format="channels_last",
+        )
+        pos_embed = ops.reshape(pos_embed, [1, -1, pos_embed_dim])
+        pos_embed = ops.concatenate([pos_embed_prefix, pos_embed], axis=1)
+        self.pos_embed.assign(pos_embed)
+        self.cls_token.assign(store["1"])
+
     def get_config(self):
-        return super().get_config()
+        config = super().get_config()
+        config.update(
+            {"height": self.height, "width": self.width, "name": self.name}
+        )
+        return config
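
The resizing branch in load_own_variables above amounts to a bilinear resize of the grid part of the stored embedding, with the extra first position (the class-token slot) carried over unchanged. Below is a standalone sketch of just that arithmetic using keras.ops, with made-up sizes (a 16x16 grid resized to 8x8 at embedding dim 8); it mirrors the ops calls in the diff rather than adding anything new:

from keras import ops

old_h = old_w = 16
new_h = new_w = 8
dim = 8

# Stored embedding: [1, old_h * old_w + 1, dim] -- one extra slot plus the grid tokens
pos_embed = ops.ones([1, old_h * old_w + 1, dim])

prefix, grid = pos_embed[:, :1], pos_embed[:, 1:]  # keep the extra slot as-is
grid = ops.cast(grid, "float32")
grid = ops.reshape(grid, [1, old_h, old_w, dim])   # back to a 2-D grid
grid = ops.image.resize(
    grid,
    size=[new_h, new_w],
    interpolation="bilinear",
    antialias=True,
    data_format="channels_last",
)
grid = ops.reshape(grid, [1, -1, dim])             # flatten the resized grid
resized = ops.concatenate([prefix, grid], axis=1)
print(resized.shape)  # (1, 65, 8) == (1, new_h * new_w + 1, dim)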

kimm/_src/layers/position_embedding_test.py

Lines changed: 18 additions & 4 deletions

@@ -1,17 +1,18 @@
 import pytest
 from absl.testing import parameterized
 from keras import layers
+from keras import models
 from keras.src import testing

 from kimm._src.layers.position_embedding import PositionEmbedding


 class PositionEmbeddingTest(testing.TestCase, parameterized.TestCase):
     @pytest.mark.requires_trainable_backend
-    def test_position_embedding_basic(self):
+    def test_basic(self):
         self.run_layer_test(
             PositionEmbedding,
-            init_kwargs={},
+            init_kwargs={"height": 2, "width": 5},
             input_shape=(1, 10, 10),
             expected_output_shape=(1, 11, 10),
             expected_num_trainable_weights=2,
@@ -20,10 +21,23 @@ def test_position_embedding_basic(self):
             supports_masking=False,
         )

+    def test_embedding_resizing(self):
+        temp_dir = self.get_temp_dir()
+        model = models.Sequential(
+            [layers.Input(shape=[256, 8]), PositionEmbedding(16, 16)]
+        )
+        model.save(f"{temp_dir}/model.keras")
+
+        # Resize from (16, 16) to (8, 8)
+        model = models.Sequential(
+            [layers.Input(shape=[64, 8]), PositionEmbedding(8, 8)]
+        )
+        model.load_weights(f"{temp_dir}/model.keras")
+
     @pytest.mark.requires_trainable_backend
-    def test_position_embedding_invalid_input_shape(self):
+    def test_invalid_input_shape(self):
         inputs = layers.Input([3])
         with self.assertRaisesRegex(
             ValueError, "PositionEmbedding only accepts 3-dimensional input."
         ):
-            PositionEmbedding()(inputs)
+            PositionEmbedding(2, 2)(inputs)
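
Note that PositionEmbedding now requires the patch-grid height and width up front, and height * width has to match the token axis of the input; the output still gains one extra position, which is why the test above expects (1, 11, 10) from a (1, 10, 10) input. A small usage sketch with arbitrary sizes, assuming kimm is installed:

import keras

from kimm._src.layers.position_embedding import PositionEmbedding

# An 8 x 8 patch grid -> 64 tokens of dimension 32
x = keras.random.normal([2, 64, 32])
y = PositionEmbedding(height=8, width=8)(x)
print(y.shape)  # (2, 65, 32): one extra (class-token) position is added

# Mismatched sizes fail fast in build():
# PositionEmbedding(height=4, width=4)(x) raises
# "The embedding size doesn't match the height and width."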

kimm/_src/layers/rep_conv2d_test.py

Lines changed: 2 additions & 2 deletions

@@ -53,7 +53,7 @@
 class RepConv2DTest(testing.TestCase, parameterized.TestCase):
     @parameterized.parameters(TEST_CASES)
     @pytest.mark.requires_trainable_backend
-    def test_rep_conv2d_basic(
+    def test_basic(
         self,
         filters,
         kernel_size,
@@ -89,7 +89,7 @@ def test_rep_conv2d_basic(
         )

     @parameterized.parameters(TEST_CASES)
-    def test_rep_conv2d_get_reparameterized_weights(
+    def test_get_reparameterized_weights(
         self,
         filters,
         kernel_size,
