@@ -134,20 +134,16 @@ def fix_config(self, config):
134
134
return config
135
135
136
136
137
- """
138
- Model Definition
139
- """
137
+ # Model Definition
140
138
141
139
142
- class ConvMixer736D32 (ConvMixer ):
143
- available_feature_keys = ["STEM" , * [f"BLOCK{ i } " for i in range (32 )]]
144
- available_weights = [
145
- (
146
- "imagenet" ,
147
- ConvMixer .default_origin ,
148
- "convmixer736d32_convmixer_768_32.in1k.keras" ,
149
- )
150
- ]
140
+ class ConvMixerVariant (ConvMixer ):
141
+ # Parameters
142
+ depth = None
143
+ hidden_channels = None
144
+ patch_size = None
145
+ kernel_size = None
146
+ activation = None
151
147
152
148
def __init__ (
153
149
self ,
@@ -160,16 +156,21 @@ def __init__(
160
156
classes : int = 1000 ,
161
157
classifier_activation : str = "softmax" ,
162
158
weights : typing .Optional [str ] = "imagenet" ,
163
- name : str = "ConvMixer736D32" ,
159
+ name : typing . Optional [ str ] = None ,
164
160
** kwargs ,
165
161
):
162
+ if type (self ) is ConvMixerVariant :
163
+ raise NotImplementedError (
164
+ f"Cannot instantiate base class: { self .__class__ .__name__ } . "
165
+ "You should use its subclasses."
166
+ )
166
167
kwargs = self .fix_config (kwargs )
167
168
super ().__init__ (
168
- 32 ,
169
- 768 ,
170
- 7 ,
171
- 7 ,
172
- "relu" ,
169
+ depth = self . depth ,
170
+ hidden_channels = self . hidden_channels ,
171
+ patch_size = self . patch_size ,
172
+ kernel_size = self . kernel_size ,
173
+ activation = self . activation ,
173
174
input_tensor = input_tensor ,
174
175
input_shape = input_shape ,
175
176
include_preprocessing = include_preprocessing ,
@@ -179,12 +180,30 @@ def __init__(
179
180
classes = classes ,
180
181
classifier_activation = classifier_activation ,
181
182
weights = weights ,
182
- name = name ,
183
+ name = name or str ( self . __class__ . __name__ ) ,
183
184
** kwargs ,
184
185
)
185
186
186
187
187
- class ConvMixer1024D20 (ConvMixer ):
188
+ class ConvMixer736D32 (ConvMixerVariant ):
189
+ available_feature_keys = ["STEM" , * [f"BLOCK{ i } " for i in range (32 )]]
190
+ available_weights = [
191
+ (
192
+ "imagenet" ,
193
+ ConvMixer .default_origin ,
194
+ "convmixer736d32_convmixer_768_32.in1k.keras" ,
195
+ )
196
+ ]
197
+
198
+ # Parameters
199
+ depth = 32
200
+ hidden_channels = 768
201
+ patch_size = 7
202
+ kernel_size = 7
203
+ activation = "relu"
204
+
205
+
206
+ class ConvMixer1024D20 (ConvMixerVariant ):
188
207
available_feature_keys = ["STEM" , * [f"BLOCK{ i } " for i in range (20 )]]
189
208
available_weights = [
190
209
(
@@ -194,42 +213,15 @@ class ConvMixer1024D20(ConvMixer):
194
213
)
195
214
]
196
215
197
- def __init__ (
198
- self ,
199
- input_tensor : keras .KerasTensor = None ,
200
- input_shape : typing .Optional [typing .Sequence [int ]] = None ,
201
- include_preprocessing : bool = True ,
202
- include_top : bool = True ,
203
- pooling : typing .Optional [str ] = None ,
204
- dropout_rate : float = 0.0 ,
205
- classes : int = 1000 ,
206
- classifier_activation : str = "softmax" ,
207
- weights : typing .Optional [str ] = "imagenet" ,
208
- name : str = "ConvMixer1024D20" ,
209
- ** kwargs ,
210
- ):
211
- kwargs = self .fix_config (kwargs )
212
- super ().__init__ (
213
- 20 ,
214
- 1024 ,
215
- 14 ,
216
- 9 ,
217
- "gelu" ,
218
- input_tensor = input_tensor ,
219
- input_shape = input_shape ,
220
- include_preprocessing = include_preprocessing ,
221
- include_top = include_top ,
222
- pooling = pooling ,
223
- dropout_rate = dropout_rate ,
224
- classes = classes ,
225
- classifier_activation = classifier_activation ,
226
- weights = weights ,
227
- name = name ,
228
- ** kwargs ,
229
- )
216
+ # Parameters
217
+ depth = 20
218
+ hidden_channels = 1024
219
+ patch_size = 14
220
+ kernel_size = 9
221
+ activation = "gelu"
230
222
231
223
232
- class ConvMixer1536D20 (ConvMixer ):
224
+ class ConvMixer1536D20 (ConvMixerVariant ):
233
225
available_feature_keys = ["STEM" , * [f"BLOCK{ i } " for i in range (20 )]]
234
226
available_weights = [
235
227
(
@@ -239,39 +231,12 @@ class ConvMixer1536D20(ConvMixer):
239
231
)
240
232
]
241
233
242
- def __init__ (
243
- self ,
244
- input_tensor : keras .KerasTensor = None ,
245
- input_shape : typing .Optional [typing .Sequence [int ]] = None ,
246
- include_preprocessing : bool = True ,
247
- include_top : bool = True ,
248
- pooling : typing .Optional [str ] = None ,
249
- dropout_rate : float = 0.0 ,
250
- classes : int = 1000 ,
251
- classifier_activation : str = "softmax" ,
252
- weights : typing .Optional [str ] = "imagenet" ,
253
- name : str = "ConvMixer1536D20" ,
254
- ** kwargs ,
255
- ):
256
- kwargs = self .fix_config (kwargs )
257
- super ().__init__ (
258
- 20 ,
259
- 1536 ,
260
- 7 ,
261
- 9 ,
262
- "gelu" ,
263
- input_tensor = input_tensor ,
264
- input_shape = input_shape ,
265
- include_preprocessing = include_preprocessing ,
266
- include_top = include_top ,
267
- pooling = pooling ,
268
- dropout_rate = dropout_rate ,
269
- classes = classes ,
270
- classifier_activation = classifier_activation ,
271
- weights = weights ,
272
- name = name ,
273
- ** kwargs ,
274
- )
234
+ # Parameters
235
+ depth = 20
236
+ hidden_channels = 1536
237
+ patch_size = 7
238
+ kernel_size = 9
239
+ activation = "gelu"
275
240
276
241
277
242
add_model_to_registry (ConvMixer736D32 , "imagenet" )
0 commit comments