Skip to content

Commit 69a24f2

Browse files
Add RepVGG (#32)
* Add `RepVGG` * Update tests * Add test for `get_reparameterized_model` * Update README * Minor update * Fix test * Fix readme * Fix readme
1 parent 219fe28 commit 69a24f2

14 files changed

+1214
-64
lines changed

README.md

Lines changed: 64 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -17,66 +17,73 @@
1717

1818
**K**eras **Im**age **M**odels (`kimm`) is a collection of image models, blocks and layers written in Keras 3. The goal is to offer SOTA models with pretrained weights in a user-friendly manner.
1919

20-
`kimm` is:
20+
KIMM is:
2121

22-
- 🚀 A model zoo where almost all models come with pre-trained weights on ImageNet.
22+
🚀 A model zoo where almost all models come with pre-trained weights on ImageNet.
2323

24-
> **Note:**
25-
> The accuracy of the converted models can be found at [results-imagenet.csv (timm)](https://github.com/huggingface/pytorch-image-models/blob/main/results/results-imagenet.csv) and [https://keras.io/api/applications/ (keras)](https://keras.io/api/applications/),
26-
> and the numerical differences of the converted models can be verified in `tools/convert_*.py`
24+
> [!NOTE]
25+
> The accuracy of the converted models can be found at [results-imagenet.csv (timm)](https://github.com/huggingface/pytorch-image-models/blob/main/results/results-imagenet.csv) and [https://keras.io/api/applications/ (keras)](https://keras.io/api/applications/),
26+
> and the numerical differences of the converted models can be verified in `tools/convert_*.py`.
2727
28-
- ✨ Exposing a common API identical to offcial `keras.applications.*`.
28+
✨ Exposing a common API identical to offcial `keras.applications.*`.
2929

30-
```python
31-
model = kimm.models.RegNetY002(
32-
input_tensor: keras.KerasTensor = None,
33-
input_shape: typing.Optional[typing.Sequence[int]] = None,
34-
include_preprocessing: bool = True,
35-
include_top: bool = True,
36-
pooling: typing.Optional[str] = None,
37-
dropout_rate: float = 0.0,
38-
classes: int = 1000,
39-
classifier_activation: str = "softmax",
40-
weights: typing.Optional[str] = "imagenet",
41-
name: str = "RegNetY002",
42-
)
43-
```
44-
45-
- 🔥 Integrated with feature extraction capability.
46-
47-
```python
48-
from keras import random
49-
import kimm
50-
51-
model = kimm.models.ConvNeXtAtto(feature_extractor=True)
52-
x = random.uniform([1, 224, 224, 3])
53-
y = model(x, training=False)
54-
# y becomes a dict
55-
for k, v in y.items():
56-
print(k, v.shape)
57-
```
58-
59-
- 🧰 Providing APIs to export models to `.tflite` and `.onnx`.
60-
61-
```python
62-
# in tensorflow backend
63-
from keras import backend
64-
import kimm
65-
66-
backend.set_image_data_format("channels_last")
67-
model = kimm.models.MobileNet050V3Small()
68-
kimm.export.export_tflite(model, [224, 224, 3], "model.tflite")
69-
```
70-
71-
```python
72-
# in torch backend
73-
from keras import backend
74-
import kimm
75-
76-
backend.set_image_data_format("channels_first")
77-
model = kimm.models.MobileNet050V3Small()
78-
kimm.export.export_onnx(model, [3, 224, 224], "model.onnx")
79-
```
30+
```python
31+
model = kimm.models.RegNetY002(
32+
input_tensor: keras.KerasTensor = None,
33+
input_shape: typing.Optional[typing.Sequence[int]] = None,
34+
include_preprocessing: bool = True,
35+
include_top: bool = True,
36+
pooling: typing.Optional[str] = None,
37+
dropout_rate: float = 0.0,
38+
classes: int = 1000,
39+
classifier_activation: str = "softmax",
40+
weights: typing.Optional[str] = "imagenet",
41+
name: str = "RegNetY002",
42+
)
43+
```
44+
45+
🔥 Integrated with feature extraction capability.
46+
47+
```python
48+
model = kimm.models.ConvNeXtAtto(feature_extractor=True)
49+
x = keras.random.uniform([1, 224, 224, 3])
50+
y = model.predict(x)
51+
# y becomes a dict
52+
for k, v in y.items():
53+
print(k, v.shape)
54+
```
55+
56+
🧰 Providing APIs to export models to `.tflite` and `.onnx`.
57+
58+
```python
59+
# tensorflow backend
60+
keras.backend.set_image_data_format("channels_last")
61+
model = kimm.models.MobileNet050V3Small()
62+
kimm.export.export_tflite(model, [224, 224, 3], "model.tflite")
63+
```
64+
65+
```python
66+
# torch backend
67+
keras.backend.set_image_data_format("channels_first")
68+
model = kimm.models.MobileNet050V3Small()
69+
kimm.export.export_onnx(model, [3, 224, 224], "model.onnx")
70+
```
71+
72+
> [!IMPORTANT]
73+
> `kimm.export.export_tflite` is currently restricted to `tensorflow` backend and `channels_last`.
74+
> `kimm.export.export_onnx` is currently restricted to `torch` backend and `channels_first`.
75+
76+
🔧 Supporting the reparameterization technique.
77+
78+
```python
79+
model = kimm.models.RepVGGA0()
80+
reparameterized_model = kimm.utils.get_reparameterized_model(model)
81+
# or
82+
# reparameterized_model = model.get_reparameterized_model()
83+
y1 = model.predict(x)
84+
y2 = model.predict(x)
85+
np.testing.assert_allclose(y1, y2, atol=1e-5)
86+
```
8087

8188
## Installation
8289

@@ -175,6 +182,7 @@ Reference: [Grad-CAM class activation visualization (keras.io)](https://keras.io
175182
|MobileViT|[ICLR 2022](https://arxiv.org/abs/2110.02178)|`timm`|`kimm.models.MobileViT*`|
176183
|MobileViTV2|[arXiv 2022](https://arxiv.org/abs/2206.02680)|`timm`|`kimm.models.MobileViTV2*`|
177184
|RegNet|[CVPR 2020](https://arxiv.org/abs/2003.13678)|`timm`|`kimm.models.RegNet*`|
185+
|RepVGG|[CVPR 2021](https://arxiv.org/abs/2101.03697)|`timm`|`kimm.models.RepVGG*`|
178186
|ResNet|[CVPR 2015](https://arxiv.org/abs/1512.03385)|`timm`|`kimm.models.ResNet*`|
179187
|TinyNet|[NeurIPS 2020](https://arxiv.org/abs/2010.14819)|`timm`|`kimm.models.TinyNet*`|
180188
|VGG|[ICLR 2015](https://arxiv.org/abs/1409.1556)|`timm`|`kimm.models.VGG*`|

kimm/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from kimm.layers.attention import Attention
22
from kimm.layers.layer_scale import LayerScale
33
from kimm.layers.position_embedding import PositionEmbedding
4+
from kimm.layers.rep_conv2d import RepConv2D

kimm/layers/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(
3030
self.qkv = layers.Dense(
3131
hidden_dim * 3,
3232
use_bias=use_qkv_bias,
33-
dtype=self.dtype,
33+
dtype=self.dtype_policy,
3434
name=f"{name}_qkv",
3535
)
3636
if use_qk_norm:

0 commit comments

Comments
 (0)