Skip to content

Commit f6ef9f3

Browse files
Update README.md (#38)
* Update README * Update README * Update version
1 parent 68d45f0 commit f6ef9f3

File tree

2 files changed

+57
-90
lines changed

2 files changed

+57
-90
lines changed

README.md

Lines changed: 54 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# Keras Image Models
1515

1616
- [Introduction](#introduction)
17+
- [Usage](#usage)
1718
- [Installation](#installation)
1819
- [Quickstart](#quickstart)
1920
- [Image classification with ImageNet weights](#image-classification-using-the-model-pretrained-on-imagenet)
@@ -29,74 +30,67 @@
2930

3031
**KIMM** is:
3132

32-
🚀 A model zoo where almost all models come with **pre-trained weights on ImageNet**.
33+
- 🚀 A model zoo where almost all models come with **pre-trained weights on ImageNet**.
34+
- 🧰 Providing APIs to export models to `.tflite` and `.onnx`.
35+
- 🔧 Supporting the **reparameterization** technique.
36+
- ✨ Integrated with **feature extraction** capability.
3337

34-
> [!NOTE]
35-
> The accuracy of the converted models can be found at [results-imagenet.csv (timm)](https://github.com/huggingface/pytorch-image-models/blob/main/results/results-imagenet.csv) and [https://keras.io/api/applications/ (keras)](https://keras.io/api/applications/),
36-
> and the numerical differences of the converted models can be verified in `tools/convert_*.py`.
38+
## Usage
39+
40+
- `kimm.list_models`
41+
- `kimm.models.*.available_feature_keys`
42+
- `kimm.models.*(...)`
43+
- `kimm.models.*(..., feature_extractor=True, feature_keys=[...])`
44+
- `kimm.utils.get_reparameterized_model`
45+
- `kimm.export.export_tflite`
46+
- `kimm.export.export_onnx`
3747

38-
✨ Exposing a common API identical to offcial `keras.applications.*`.
39-
4048
```python
41-
model = kimm.models.RegNetY002(
42-
input_tensor: keras.KerasTensor = None,
43-
input_shape: typing.Optional[typing.Sequence[int]] = None,
44-
include_preprocessing: bool = True,
45-
include_top: bool = True,
46-
pooling: typing.Optional[str] = None,
47-
dropout_rate: float = 0.0,
48-
classes: int = 1000,
49-
classifier_activation: str = "softmax",
50-
weights: typing.Optional[str] = "imagenet",
51-
name: str = "RegNetY002",
52-
)
53-
```
49+
import keras
50+
import kimm
51+
import numpy as np
5452

55-
🔥 Integrated with **feature extraction** capability.
5653

57-
```python
58-
model = kimm.models.ConvNeXtAtto(feature_extractor=True)
54+
# List available models
55+
print(kimm.list_models("mobileone", weights="imagenet"))
56+
# ['MobileOneS0', 'MobileOneS1', 'MobileOneS2', 'MobileOneS3']
57+
58+
# Initialize model with pretrained ImageNet weights
5959
x = keras.random.uniform([1, 224, 224, 3])
60+
model = kimm.models.MobileOneS0()
6061
y = model.predict(x)
61-
# y becomes a dict
62-
for k, v in y.items():
63-
print(k, v.shape)
64-
```
62+
print(y.shape)
63+
# (1, 1000)
6564

66-
🧰 Providing APIs to export models to `.tflite` and `.onnx`.
65+
# Get reparameterized model by kimm.utils.get_reparameterized_model
66+
reparameterized_model = kimm.utils.get_reparameterized_model(model)
67+
y2 = reparameterized_model.predict(x)
68+
np.testing.assert_allclose(
69+
keras.ops.convert_to_numpy(y), keras.ops.convert_to_numpy(y2), atol=1e-5
70+
)
6771

68-
```python
69-
# tensorflow backend
70-
keras.backend.set_image_data_format("channels_last")
71-
model = kimm.models.MobileNetV3W050Small()
72-
kimm.export.export_tflite(model, [224, 224, 3], "model.tflite")
73-
```
72+
# Export model to tflite format
73+
kimm.export.export_tflite(reparameterized_model, 224, "model.tflite")
7474

75-
```python
76-
# torch backend
77-
keras.backend.set_image_data_format("channels_first")
78-
model = kimm.models.MobileNetV3W050Small()
79-
kimm.export.export_onnx(model, [3, 224, 224], "model.onnx")
80-
```
75+
# Export model to onnx format (note: must be "channels_first" format)
76+
# kimm.export.export_onnx(reparameterized_model, 224, "model.onnx")
8177

82-
> [!IMPORTANT]
83-
> `kimm.export.export_tflite` is currently restricted to `tensorflow` backend and `channels_last`.
84-
> `kimm.export.export_onnx` is currently restricted to `torch` backend and `channels_first`.
78+
# List available feature keys of the model class
79+
print(kimm.models.MobileOneS0.available_feature_keys)
80+
# ['STEM_S2', 'BLOCK0_S4', 'BLOCK1_S8', 'BLOCK2_S16', 'BLOCK3_S32']
8581

86-
🔧 Supporting the **reparameterization** technique.
82+
# Enable feature extraction by setting `feature_extractor=True`
83+
# `feature_keys` can be optionally specified
84+
model = kimm.models.MobileOneS0(
85+
feature_extractor=True, feature_keys=["BLOCK2_S16", "BLOCK3_S32"]
86+
)
87+
features = model.predict(x)
88+
for feature_name, feature in features.items():
89+
print(feature_name, feature.shape)
90+
# BLOCK2_S16 (1, 14, 14, 256)
91+
# BLOCK3_S32 (1, 7, 7, 1024)
92+
# TOP (1, 1000)
8793

88-
```python
89-
model = kimm.models.RepVGGA0()
90-
reparameterized_model = kimm.utils.get_reparameterized_model(model)
91-
# or
92-
# reparameterized_model = model.get_reparameterized_model()
93-
model.summary()
94-
# Total params: 9,132,616 (34.84 MB)
95-
reparameterized_model.summary()
96-
# Total params: 8,309,384 (31.70 MB)
97-
y1 = model.predict(x)
98-
y2 = reparameterized_model.predict(x)
99-
np.testing.assert_allclose(y1, y2, atol=1e-5)
10094
```
10195

10296
## Installation
@@ -111,41 +105,13 @@ pip install keras kimm -U
111105

112106
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14WxYgVjlwCIO9MwqPYW-dskbTL2UHsVN?usp=sharing)
113107

114-
```python
115-
import keras
116-
from keras import ops
117-
from keras import utils
118-
from keras.applications.imagenet_utils import decode_predictions
119-
120-
import kimm
108+
Using `kimm.models.VisionTransformerTiny16`:
121109

122-
# Use `kimm.list_models` to get the list of available models
123-
print(kimm.list_models())
124-
125-
# Specify the name and other arguments to filter the result
126-
print(kimm.list_models("vision_transformer", weights="imagenet")) # fuzzy search
127-
128-
# Initialize the model with pretrained weights
129-
model = kimm.models.VisionTransformerTiny16()
130-
image_size = (model._default_size, model._default_size)
131-
132-
# Load an image as the model input
133-
image_path = keras.utils.get_file(
134-
"african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png"
135-
)
136-
image = utils.load_img(image_path, target_size=image_size)
137-
image = utils.img_to_array(image)
138-
x = ops.convert_to_tensor(image)
139-
x = ops.expand_dims(x, axis=0)
140-
141-
# Predict
142-
preds = model.predict(x)
143-
print("Predicted:", decode_predictions(preds, top=3)[0])
144-
```
110+
<div align="center">
111+
<img width="50%" src="https://github.com/james77777778/keras-image-models/assets/20734616/7caa4e5e-8561-425b-aaf2-6ae44ac3ea00" alt="african_elephant">
112+
</div>
145113

146114
```bash
147-
['ConvMixer1024D20', 'ConvMixer1536D20', 'ConvMixer736D32', 'ConvNeXtAtto', ...]
148-
['VisionTransformerBase16', 'VisionTransformerBase32', 'VisionTransformerSmall16', ...]
149115
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step
150116
Predicted: [('n02504458', 'African_elephant', 0.6895825), ('n01871265', 'tusker', 0.17934209), ('n02504013', 'Indian_elephant', 0.12927249)]
151117
```
@@ -171,7 +137,7 @@ Reference: [Transfer learning & fine-tuning (keras.io)](https://keras.io/guides/
171137
Using `kimm.models.MobileViTS`:
172138

173139
<div align="center">
174-
<img width="75%" src="https://github.com/james77777778/kimm/assets/20734616/cb5022a3-aaea-4324-a9cd-3d2e63a0a6b2" alt="grad_cam">
140+
<img width="50%" src="https://github.com/james77777778/kimm/assets/20734616/cb5022a3-aaea-4324-a9cd-3d2e63a0a6b2" alt="grad_cam">
175141
</div>
176142

177143
Reference: [Grad-CAM class activation visualization (keras.io)](https://keras.io/examples/vision/grad_cam/)

kimm/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from kimm import export
2-
from kimm import models # force to add models to the registry
2+
from kimm import models
3+
from kimm import utils
34
from kimm.utils.model_registry import list_models
45

5-
__version__ = "0.1.7"
6+
__version__ = "0.1.8"

0 commit comments

Comments
 (0)