Skip to content

Commit 6309e06

Browse files
Refactor model definition (#42)
* Refactor
* Update version
* Untrack `GlobalResponseNormalization`
* Add warning if no available pretrained weights
* Allow UserWarning
1 parent 195a622 commit 6309e06

24 files changed

+1661
-4720
lines changed

.github/workflows/actions.yml

Lines changed: 11 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,17 @@ permissions:
1212
contents: read
1313

1414
jobs:
15+
format:
16+
name: Check the code format
17+
runs-on: ubuntu-latest
18+
steps:
19+
- uses: actions/checkout@v4
20+
- name: Set up Python 3.9
21+
uses: actions/setup-python@v5
22+
with:
23+
python-version: '3.9'
24+
- uses: pre-commit/action@v3.0.1
25+
1526
build:
1627
strategy:
1728
fail-fast: false
@@ -54,29 +65,3 @@ jobs:
5465
files: coverage.xml
5566
flags: kimm,kimm-${{ matrix.backend }}
5667
fail_ci_if_error: false
57-
58-
format:
59-
name: Check the code format
60-
runs-on: ubuntu-latest
61-
steps:
62-
- uses: actions/checkout@v4
63-
- name: Set up Python 3.9
64-
uses: actions/setup-python@v5
65-
with:
66-
python-version: '3.9'
67-
- name: Get pip cache dir
68-
id: pip-cache
69-
run: |
70-
python -m pip install --upgrade pip setuptools
71-
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
72-
- name: pip cache
73-
uses: actions/cache@v4
74-
with:
75-
path: ${{ steps.pip-cache.outputs.dir }}
76-
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}-${{ hashFiles('requirements.txt') }}
77-
- name: Install dependencies
78-
run: |
79-
pip install -r requirements.txt --progress-bar off --upgrade
80-
pip install -e ".[tests]" --progress-bar off --upgrade
81-
- name: Lint
82-
run: bash shell/lint.sh

.pre-commit-config.yaml

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,33 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v4.6.0
4+
hooks:
5+
- id: check-ast
6+
- id: check-merge-conflict
7+
- id: check-toml
8+
- id: check-yaml
9+
- id: end-of-file-fixer
10+
files: \.py$
11+
- id: debug-statements
12+
files: \.py$
13+
- id: trailing-whitespace
14+
files: \.py$
15+
16+
- repo: https://github.com/pycqa/isort
17+
rev: 5.13.2
18+
hooks:
19+
- id: isort
20+
name: isort (python)
21+
22+
- repo: https://github.com/psf/black-pre-commit-mirror
23+
rev: 24.4.2
24+
hooks:
25+
- id: black
26+
27+
- repo: https://github.com/astral-sh/ruff-pre-commit
28+
rev: v0.4.4
29+
hooks:
30+
- id: ruff
31+
args:
32+
- --fix
33+
- id: ruff-format

kimm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,4 +3,4 @@
33
from kimm import utils
44
from kimm.utils.model_registry import list_models
55

6-
__version__ = "0.1.8"
6+
__version__ = "0.2.0"

kimm/models/convmixer.py

Lines changed: 52 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -134,20 +134,16 @@ def fix_config(self, config):
134134
return config
135135

136136

137-
"""
138-
Model Definition
139-
"""
137+
# Model Definition
140138

141139

142-
class ConvMixer736D32(ConvMixer):
143-
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(32)]]
144-
available_weights = [
145-
(
146-
"imagenet",
147-
ConvMixer.default_origin,
148-
"convmixer736d32_convmixer_768_32.in1k.keras",
149-
)
150-
]
140+
class ConvMixerVariant(ConvMixer):
141+
# Parameters
142+
depth = None
143+
hidden_channels = None
144+
patch_size = None
145+
kernel_size = None
146+
activation = None
151147

152148
def __init__(
153149
self,
@@ -160,16 +156,21 @@ def __init__(
160156
classes: int = 1000,
161157
classifier_activation: str = "softmax",
162158
weights: typing.Optional[str] = "imagenet",
163-
name: str = "ConvMixer736D32",
159+
name: typing.Optional[str] = None,
164160
**kwargs,
165161
):
162+
if type(self) is ConvMixerVariant:
163+
raise NotImplementedError(
164+
f"Cannot instantiate base class: {self.__class__.__name__}. "
165+
"You should use its subclasses."
166+
)
166167
kwargs = self.fix_config(kwargs)
167168
super().__init__(
168-
32,
169-
768,
170-
7,
171-
7,
172-
"relu",
169+
depth=self.depth,
170+
hidden_channels=self.hidden_channels,
171+
patch_size=self.patch_size,
172+
kernel_size=self.kernel_size,
173+
activation=self.activation,
173174
input_tensor=input_tensor,
174175
input_shape=input_shape,
175176
include_preprocessing=include_preprocessing,
@@ -179,12 +180,30 @@ def __init__(
179180
classes=classes,
180181
classifier_activation=classifier_activation,
181182
weights=weights,
182-
name=name,
183+
name=name or str(self.__class__.__name__),
183184
**kwargs,
184185
)
185186

186187

187-
class ConvMixer1024D20(ConvMixer):
188+
class ConvMixer736D32(ConvMixerVariant):
189+
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(32)]]
190+
available_weights = [
191+
(
192+
"imagenet",
193+
ConvMixer.default_origin,
194+
"convmixer736d32_convmixer_768_32.in1k.keras",
195+
)
196+
]
197+
198+
# Parameters
199+
depth = 32
200+
hidden_channels = 768
201+
patch_size = 7
202+
kernel_size = 7
203+
activation = "relu"
204+
205+
206+
class ConvMixer1024D20(ConvMixerVariant):
188207
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]]
189208
available_weights = [
190209
(
@@ -194,42 +213,15 @@ class ConvMixer1024D20(ConvMixer):
194213
)
195214
]
196215

197-
def __init__(
198-
self,
199-
input_tensor: keras.KerasTensor = None,
200-
input_shape: typing.Optional[typing.Sequence[int]] = None,
201-
include_preprocessing: bool = True,
202-
include_top: bool = True,
203-
pooling: typing.Optional[str] = None,
204-
dropout_rate: float = 0.0,
205-
classes: int = 1000,
206-
classifier_activation: str = "softmax",
207-
weights: typing.Optional[str] = "imagenet",
208-
name: str = "ConvMixer1024D20",
209-
**kwargs,
210-
):
211-
kwargs = self.fix_config(kwargs)
212-
super().__init__(
213-
20,
214-
1024,
215-
14,
216-
9,
217-
"gelu",
218-
input_tensor=input_tensor,
219-
input_shape=input_shape,
220-
include_preprocessing=include_preprocessing,
221-
include_top=include_top,
222-
pooling=pooling,
223-
dropout_rate=dropout_rate,
224-
classes=classes,
225-
classifier_activation=classifier_activation,
226-
weights=weights,
227-
name=name,
228-
**kwargs,
229-
)
216+
# Parameters
217+
depth = 20
218+
hidden_channels = 1024
219+
patch_size = 14
220+
kernel_size = 9
221+
activation = "gelu"
230222

231223

232-
class ConvMixer1536D20(ConvMixer):
224+
class ConvMixer1536D20(ConvMixerVariant):
233225
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]]
234226
available_weights = [
235227
(
@@ -239,39 +231,12 @@ class ConvMixer1536D20(ConvMixer):
239231
)
240232
]
241233

242-
def __init__(
243-
self,
244-
input_tensor: keras.KerasTensor = None,
245-
input_shape: typing.Optional[typing.Sequence[int]] = None,
246-
include_preprocessing: bool = True,
247-
include_top: bool = True,
248-
pooling: typing.Optional[str] = None,
249-
dropout_rate: float = 0.0,
250-
classes: int = 1000,
251-
classifier_activation: str = "softmax",
252-
weights: typing.Optional[str] = "imagenet",
253-
name: str = "ConvMixer1536D20",
254-
**kwargs,
255-
):
256-
kwargs = self.fix_config(kwargs)
257-
super().__init__(
258-
20,
259-
1536,
260-
7,
261-
9,
262-
"gelu",
263-
input_tensor=input_tensor,
264-
input_shape=input_shape,
265-
include_preprocessing=include_preprocessing,
266-
include_top=include_top,
267-
pooling=pooling,
268-
dropout_rate=dropout_rate,
269-
classes=classes,
270-
classifier_activation=classifier_activation,
271-
weights=weights,
272-
name=name,
273-
**kwargs,
274-
)
234+
# Parameters
235+
depth = 20
236+
hidden_channels = 1536
237+
patch_size = 7
238+
kernel_size = 9
239+
activation = "gelu"
275240

276241

277242
add_model_to_registry(ConvMixer736D32, "imagenet")

0 commit comments

Comments
 (0)