|
1 | 1 | # Keras Image Models
|
2 | 2 |
|
3 |
| -## Unit Tests |
| 3 | +## Description |
| 4 | + |
| 5 | +**K**eras **Im**age **M**odels (`kimm`) is a collection of image models, blocks and layers written in Keras 3. The goal is to offer SOTA models with pretrained weights in a user-friendly manner. |
| 6 | + |
| 7 | +## Installation |
4 | 8 |
|
5 | 9 | ```bash
|
6 |
| -# KERAS_BACKEND=jax|numpy|tensorflow|torch |
7 |
| -CUDA_VISIBLE_DEVICES= KERAS_BACKEND=tensorflow pytest |
| 10 | +# In a working [jax/tensorflow/torch/numpy] backend environment |
| 11 | +pip install keras kimm |
| 12 | +``` |
| 13 | + |
| 14 | +## Quickstart |
| 15 | + |
| 16 | +### Use Pretrained Model |
| 17 | + |
| 18 | +```python |
| 19 | +from keras import random |
| 20 | + |
| 21 | +import kimm |
| 22 | + |
| 23 | +# Use `kimm.list_models` to get the list of available models |
| 24 | +print(kimm.list_models()) |
| 25 | + |
| 26 | +# Specify the name and other arguments to filter the result |
| 27 | +print(kimm.list_models("efficientnet", has_pretrained=True)) # fuzzy search |
| 28 | + |
| 29 | +# Initialize the model with pretrained weights |
| 30 | +model = kimm.models.EfficientNetV2B0(weights="imagenet") |
| 31 | + |
| 32 | +# Predict |
| 33 | +x = random.uniform([1, 192, 192, 3]) * 255.0 |
| 34 | +y = model.predict(x) |
| 35 | +print(y.shape) |
| 36 | + |
| 37 | +# Initialize the model as a feature extractor with pretrained weights |
| 38 | +model = kimm.models.EfficientNetV2B0( |
| 39 | + as_feature_extractor=True, weights="imagenet" |
| 40 | +) |
| 41 | + |
| 42 | +# Extract features for downstream tasks |
| 43 | +y = model.predict(x) |
| 44 | +print(y.keys()) |
| 45 | +print(y["BLOCK5_S32"].shape) |
| 46 | +``` |
| 47 | + |
| 48 | +### Transfer Learning |
| 49 | + |
| 50 | +```python |
| 51 | +from keras import layers |
| 52 | +from keras import models |
| 53 | +from keras import random |
| 54 | + |
| 55 | +import kimm |
| 56 | + |
| 57 | +# Initialize the model as a backbone with pretrained weights |
| 58 | +backbone = kimm.models.EfficientNetV2B0( |
| 59 | + input_shape=[224, 224, 3], |
| 60 | + include_top=False, |
| 61 | + pooling="avg", |
| 62 | + weights="imagenet", |
| 63 | +) |
| 64 | + |
| 65 | +# Freeze the backbone for transfer learning |
| 66 | +backbone.trainable = False |
| 67 | + |
| 68 | +# Construct the model with new head |
| 69 | +inputs = layers.Input([224, 224, 3]) |
| 70 | +x = backbone(inputs, training=False) |
| 71 | +x = layers.Dropout(0.2)(x) |
| 72 | +outputs = layers.Dense(2)(x) |
| 73 | +model = models.Model(inputs, outputs) |
| 74 | + |
| 75 | +# Train the new model (put your own logic here) |
| 76 | + |
| 77 | +# Predict |
| 78 | +x = random.uniform([1, 224, 224, 3]) * 255.0 |
| 79 | +y = model.predict(x) |
| 80 | +print(y.shape) |
8 | 81 | ```
|
9 | 82 |
|
10 |
| -## Work in Progress |
| 83 | +## License |
| 84 | + |
| 85 | +Please refer to [timm](https://github.com/huggingface/pytorch-image-models#licenses) as this project is built upon it. |
| 86 | + |
| 87 | +### `kimm` Code |
| 88 | + |
| 89 | +The code here is licensed Apache 2.0. |
| 90 | + |
| 91 | +## Acknowledgements |
11 | 92 |
|
12 |
| -- Test pretrained weights |
| 93 | +Thanks for these awesome projects that were used in `kimm` |
13 | 94 |
|
14 |
| -## Acknowledgments |
| 95 | +- [https://github.com/keras-team/keras](https://github.com/keras-team/keras) |
| 96 | +- [https://github.com/huggingface/pytorch-image-models](https://github.com/huggingface/pytorch-image-models) |
| 97 | + |
| 98 | +## Citing |
| 99 | + |
| 100 | +### BibTeX |
| 101 | + |
| 102 | +```bash |
| 103 | +@misc{rw2019timm, |
| 104 | + author = {Ross Wightman}, |
| 105 | + title = {PyTorch Image Models}, |
| 106 | + year = {2019}, |
| 107 | + publisher = {GitHub}, |
| 108 | + journal = {GitHub repository}, |
| 109 | + doi = {10.5281/zenodo.4414861}, |
| 110 | + howpublished = {\url{https://github.com/rwightman/pytorch-image-models}} |
| 111 | +} |
| 112 | +``` |
| 113 | + |
| 114 | +```bash |
| 115 | +@misc{hy2024kimm, |
| 116 | + author = {Hongyu Chiu}, |
| 117 | + title = {Keras Image Models}, |
| 118 | + year = {2024}, |
| 119 | + publisher = {GitHub}, |
| 120 | + journal = {GitHub repository}, |
| 121 | + howpublished = {\url{https://github.com/james77777778/kimm}} |
| 122 | +} |
| 123 | +``` |
0 commit comments