Skip to content

Commit 7a0f2e7

Browse files
Update readme (#9)
1 parent ca90fa4 commit 7a0f2e7

File tree

2 files changed

+136
-9
lines changed

2 files changed

+136
-9
lines changed

README.md

Lines changed: 115 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,123 @@
11
# Keras Image Models
22

3-
## Unit Tests
3+
## Description
4+
5+
**K**eras **Im**age **M**odels (`kimm`) is a collection of image models, blocks and layers written in Keras 3. The goal is to offer SOTA models with pretrained weights in a user-friendly manner.
6+
7+
## Installation
48

59
```bash
6-
# KERAS_BACKEND=jax|numpy|tensorflow|torch
7-
CUDA_VISIBLE_DEVICES= KERAS_BACKEND=tensorflow pytest
10+
# In a working [jax/tensorflow/torch/numpy] backend environment
11+
pip install keras kimm
12+
```
13+
14+
## Quickstart
15+
16+
### Use Pretrained Model
17+
18+
```python
19+
from keras import random
20+
21+
import kimm
22+
23+
# Use `kimm.list_models` to get the list of available models
24+
print(kimm.list_models())
25+
26+
# Specify the name and other arguments to filter the result
27+
print(kimm.list_models("efficientnet", has_pretrained=True)) # fuzzy search
28+
29+
# Initialize the model with pretrained weights
30+
model = kimm.models.EfficientNetV2B0(weights="imagenet")
31+
32+
# Predict
33+
x = random.uniform([1, 192, 192, 3]) * 255.0
34+
y = model.predict(x)
35+
print(y.shape)
36+
37+
# Initialize the model as a feature extractor with pretrained weights
38+
model = kimm.models.EfficientNetV2B0(
39+
as_feature_extractor=True, weights="imagenet"
40+
)
41+
42+
# Extract features for downstream tasks
43+
y = model.predict(x)
44+
print(y.keys())
45+
print(y["BLOCK5_S32"].shape)
46+
```
47+
48+
### Transfer Learning
49+
50+
```python
51+
from keras import layers
52+
from keras import models
53+
from keras import random
54+
55+
import kimm
56+
57+
# Initialize the model as a backbone with pretrained weights
58+
backbone = kimm.models.EfficientNetV2B0(
59+
input_shape=[224, 224, 3],
60+
include_top=False,
61+
pooling="avg",
62+
weights="imagenet",
63+
)
64+
65+
# Freeze the backbone for transfer learning
66+
backbone.trainable = False
67+
68+
# Construct the model with new head
69+
inputs = layers.Input([224, 224, 3])
70+
x = backbone(inputs, training=False)
71+
x = layers.Dropout(0.2)(x)
72+
outputs = layers.Dense(2)(x)
73+
model = models.Model(inputs, outputs)
74+
75+
# Train the new model (put your own logic here)
76+
77+
# Predict
78+
x = random.uniform([1, 224, 224, 3]) * 255.0
79+
y = model.predict(x)
80+
print(y.shape)
881
```
982

10-
## Work in Progress
83+
## License
84+
85+
Please refer to [timm](https://github.com/huggingface/pytorch-image-models#licenses) as this project is built upon it.
86+
87+
### `kimm` Code
88+
89+
The code here is licensed Apache 2.0.
90+
91+
## Acknowledgements
1192

12-
- Test pretrained weights
93+
Thanks for these awesome projects that were used in `kimm`
1394

14-
## Acknowledgments
95+
- [https://github.com/keras-team/keras](https://github.com/keras-team/keras)
96+
- [https://github.com/huggingface/pytorch-image-models](https://github.com/huggingface/pytorch-image-models)
97+
98+
## Citing
99+
100+
### BibTeX
101+
102+
```bash
103+
@misc{rw2019timm,
104+
author = {Ross Wightman},
105+
title = {PyTorch Image Models},
106+
year = {2019},
107+
publisher = {GitHub},
108+
journal = {GitHub repository},
109+
doi = {10.5281/zenodo.4414861},
110+
howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
111+
}
112+
```
113+
114+
```bash
115+
@misc{hy2024kimm,
116+
author = {Hongyu Chiu},
117+
title = {Keras Image Models},
118+
year = {2024},
119+
publisher = {GitHub},
120+
journal = {GitHub repository},
121+
howpublished = {\url{https://github.com/james77777778/kimm}}
122+
}
123+
```

kimm/utils/model_registry.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,23 @@
99
MODEL_REGISTRY: typing.List[typing.Dict[str, typing.Union[str, bool]]] = []
1010

1111

12+
def _match_string(query: str, target: str):
13+
query = query.lower().replace(" ", "").replace("_", "").replace(".", "")
14+
target = target.lower()
15+
matched_idx = -1
16+
for q_char in query:
17+
matched = False
18+
for idx, t_char in enumerate(target):
19+
if matched:
20+
break
21+
if q_char == t_char and idx > matched_idx:
22+
matched_idx = idx
23+
matched = True
24+
if not matched:
25+
return False
26+
return True
27+
28+
1229
def clear_registry():
1330
MODEL_REGISTRY.clear()
1431

@@ -47,9 +64,10 @@ def list_models(
4764
result_names.add(info["name"])
4865
need_remove = False
4966

50-
# filter by the args
51-
if name is not None and name.lower() not in info["name"].lower():
52-
need_remove = True
67+
# match string (simple implementation)
68+
if name is not None:
69+
need_remove = not _match_string(name, info["name"])
70+
# filter by support_feature and has_pretrained
5371
if (
5472
support_feature is not None
5573
and info["support_feature"] is not support_feature

0 commit comments

Comments
 (0)