**K**eras **Im**age **M**odels (`kimm`) is a collection of image models, blocks and layers written in Keras 3. The goal is to offer SOTA models with pretrained weights in a user-friendly manner.

KIMM is:

🚀 A model zoo where almost all models come with pre-trained weights on ImageNet.

> [!NOTE]
> The accuracy of the converted models can be found at [results-imagenet.csv (timm)](https://github.com/huggingface/pytorch-image-models/blob/main/results/results-imagenet.csv) and [https://keras.io/api/applications/ (keras)](https://keras.io/api/applications/),
> and the numerical differences of the converted models can be verified in `tools/convert_*.py`.
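
The zoo can also be browsed programmatically; a sketch, assuming the `kimm.list_models` helper and its filtering arguments from the project docs:

```python
import kimm

# List all model names, then narrow to MobileNet variants that ship
# ImageNet weights (helper name and arguments assumed, not verified here).
print(kimm.list_models())
print(kimm.list_models("mobilenet", weights="imagenet"))
```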

✨ Exposing a common API identical to the official `keras.applications.*`.

```python
import kimm

# All kimm models share this keras.applications-style constructor;
# the values shown are the defaults (types noted in the comments).
model = kimm.models.RegNetY002(
    input_tensor=None,  # keras.KerasTensor
    input_shape=None,  # typing.Optional[typing.Sequence[int]]
    include_preprocessing=True,
    include_top=True,
    pooling=None,  # typing.Optional[str]
    dropout_rate=0.0,
    classes=1000,
    classifier_activation="softmax",
    weights="imagenet",  # typing.Optional[str]
    name="RegNetY002",
)
```
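
A minimal end-to-end sketch using those defaults (downloading the ImageNet weights on first use; the random tensor stands in for a real image batch):

```python
import keras
import kimm

model = kimm.models.RegNetY002(weights="imagenet")
x = keras.random.uniform([1, 224, 224, 3])  # stand-in for a preprocessed image batch
probs = model.predict(x)
print(probs.shape)  # (1, 1000) with the default 1000-class softmax head
```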

🔥 Integrated with feature extraction capability.

```python
import keras
import kimm

model = kimm.models.ConvNeXtAtto(feature_extractor=True)
x = keras.random.uniform([1, 224, 224, 3])
y = model.predict(x)
# y is a dict mapping feature names to feature maps
for k, v in y.items():
    print(k, v.shape)
```
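
Each key names a backbone stage and maps to that stage's feature map, so intermediate representations at several strides can feed detection or segmentation heads directly; the exact key names vary per architecture, so inspect `y.keys()` as above.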

🧰 Providing APIs to export models to `.tflite` and `.onnx`.

```python
# tensorflow backend
import keras
import kimm

keras.backend.set_image_data_format("channels_last")
model = kimm.models.MobileNet050V3Small()
kimm.export.export_tflite(model, [224, 224, 3], "model.tflite")
```
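
To spot-check the exported file, it can be loaded back with the TensorFlow Lite interpreter (a sketch; assumes `tensorflow` is installed and `model.tflite` was produced above):

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
# Feed a dummy image and read the logits back to confirm the graph runs.
interpreter.set_tensor(inp["index"], np.zeros(inp["shape"], dtype=np.float32))
interpreter.invoke()
print(interpreter.get_tensor(out["index"]).shape)
```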

```python
# torch backend
import keras
import kimm

keras.backend.set_image_data_format("channels_first")
model = kimm.models.MobileNet050V3Small()
kimm.export.export_onnx(model, [3, 224, 224], "model.onnx")
```

> [!IMPORTANT]
> `kimm.export.export_tflite` is currently restricted to the `tensorflow` backend and `channels_last`.
> `kimm.export.export_onnx` is currently restricted to the `torch` backend and `channels_first`.
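
Likewise, an exported `.onnx` graph can be sanity-checked with `onnxruntime` (a sketch; assumes `pip install onnxruntime` and the `model.onnx` produced above):

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
# Run a dummy channels-first batch through the exported graph.
dummy = np.zeros([1, 3, 224, 224], dtype=np.float32)
(logits,) = session.run(None, {input_name: dummy})
print(logits.shape)
```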

🔧 Supporting the reparameterization technique.

```python
import keras
import numpy as np
import kimm

model = kimm.models.RepVGGA0()
reparameterized_model = kimm.utils.get_reparameterized_model(model)
# or
# reparameterized_model = model.get_reparameterized_model()

# The original and reparameterized models should agree numerically.
x = keras.random.uniform([1, 224, 224, 3])
y1 = model.predict(x)
y2 = reparameterized_model.predict(x)
np.testing.assert_allclose(y1, y2, atol=1e-5)
```
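
Reparameterization folds RepVGG's multi-branch training-time blocks into plain convolutions, so the reparameterized model matches the original numerically (within the tolerance above) while running faster at inference, making it the variant to export or benchmark.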

## Installation
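
A minimal install sketch, assuming the package is published on PyPI as `kimm` (Keras 3 plus a backend of your choice is required):

```bash
pip install keras kimm
```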

|MobileViT|[ICLR 2022](https://arxiv.org/abs/2110.02178)|`timm`|`kimm.models.MobileViT*`|
|MobileViTV2|[arXiv 2022](https://arxiv.org/abs/2206.02680)|`timm`|`kimm.models.MobileViTV2*`|
|RegNet|[CVPR 2020](https://arxiv.org/abs/2003.13678)|`timm`|`kimm.models.RegNet*`|
|RepVGG|[CVPR 2021](https://arxiv.org/abs/2101.03697)|`timm`|`kimm.models.RepVGG*`|
|ResNet|[CVPR 2015](https://arxiv.org/abs/1512.03385)|`timm`|`kimm.models.ResNet*`|
|TinyNet|[NeurIPS 2020](https://arxiv.org/abs/2010.14819)|`timm`|`kimm.models.TinyNet*`|
|VGG|[ICLR 2015](https://arxiv.org/abs/1409.1556)|`timm`|`kimm.models.VGG*`|