import typing

import keras
from keras import layers
from keras import utils

from kimm.models.base_model import BaseModel
from kimm.utils import add_model_to_registry
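# ConvMixer, introduced in "Patches Are All You Need?" (Trockman & Kolter),
# embeds an image as non-overlapping patches with a strided convolution, then
# applies `depth` identical blocks, each a residual depthwise convolution
# followed by a pointwise (1x1) convolution, keeping the channel width
# constant throughout.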
def apply_convmixer_block(
    inputs, output_channels, kernel_size, activation, name="convmixer_block"
):
    x = inputs

    # Depthwise
    x = layers.DepthwiseConv2D(
        kernel_size,
        1,
        padding="same",
        activation=activation,
        use_bias=True,
        name=f"{name}_0_fn_0_dwconv2d",
    )(x)
    x = layers.BatchNormalization(
        momentum=0.9, epsilon=1e-5, name=f"{name}_0_fn_2"
    )(x)
    x = layers.Add()([x, inputs])

    # Pointwise
    x = layers.Conv2D(
        output_channels,
        1,
        1,
        activation=activation,
        use_bias=True,
        name=f"{name}_1_conv2d",
    )(x)
    x = layers.BatchNormalization(
        momentum=0.9, epsilon=1e-5, name=f"{name}_3"
    )(x)
    return x
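# Shape sketch: for channels-last inputs of shape (B, H, W, C), the depthwise
# stage preserves (H, W, C) (stride 1, padding="same"), which makes the
# residual add valid; the 1x1 convolution then maps C to `output_channels`.
# In the ConvMixer variants below, `output_channels` always equals C, so the
# block is shape-preserving end to end.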

class ConvMixer(BaseModel):
    def __init__(
        self,
        depth: int = 32,
        hidden_channels: int = 768,
        patch_size: int = 7,
        kernel_size: int = 7,
        activation: str = "relu",
        **kwargs,
    ):
        parsed_kwargs = self.parse_kwargs(kwargs)
        img_input = self.determine_input_tensor(
            parsed_kwargs["input_tensor"],
            parsed_kwargs["input_shape"],
            parsed_kwargs["default_size"],
        )
        x = img_input

        if parsed_kwargs["include_preprocessing"]:
            x = self.build_preprocessing(x, "imagenet")

        # Prepare feature extraction
        features = {}

        # Stem
        x = layers.Conv2D(
            hidden_channels,
            patch_size,
            patch_size,
            activation=activation,
            use_bias=True,
            name="stem_conv2d",
        )(x)
        x = layers.BatchNormalization(
            momentum=0.9, epsilon=1e-5, name="stem_bn"
        )(x)
        features["STEM"] = x
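        # The stem above is a non-overlapping patch embedding: kernel size
        # and stride both equal `patch_size`, so an (H, W, 3) image becomes
        # an (H / patch_size, W / patch_size, hidden_channels) feature map.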

        # Blocks
        for i in range(depth):
            x = apply_convmixer_block(
                x, hidden_channels, kernel_size, activation, name=f"blocks_{i}"
            )
            # Add feature
            features[f"BLOCK{i}"] = x
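        # Every block output is recorded so the model can act as a feature
        # extractor; the valid keys are listed by `available_feature_keys`
        # in each concrete variant below.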

        # Head
        if parsed_kwargs["include_top"]:
            x = self.build_top(
                x,
                parsed_kwargs["classes"],
                parsed_kwargs["classifier_activation"],
                parsed_kwargs["dropout_rate"],
            )
        else:
            if parsed_kwargs["pooling"] == "avg":
                x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
            elif parsed_kwargs["pooling"] == "max":
                x = layers.GlobalMaxPooling2D(name="max_pool")(x)

        # Ensure that the model takes into account
        # any potential predecessors of `input_tensor`.
        if parsed_kwargs["input_tensor"] is not None:
            inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"])
        else:
            inputs = img_input

        super().__init__(inputs=inputs, outputs=x, features=features, **kwargs)

        # Attribute assignments must come after `super().__init__` so that
        # the functional model is fully built before `self` is mutated.
        self.add_references(parsed_kwargs)
        self.depth = depth
        self.hidden_channels = hidden_channels
        self.patch_size = patch_size
        self.kernel_size = kernel_size
        self.activation = activation
    @staticmethod
    def available_feature_keys():
        raise NotImplementedError(
            "Subclasses must implement `available_feature_keys`."
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "depth": self.depth,
                "hidden_channels": self.hidden_channels,
                "patch_size": self.patch_size,
                "kernel_size": self.kernel_size,
                "activation": self.activation,
            }
        )
        return config
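    # `fix_config` strips the architecture-defining keys from a serialized
    # config so that the fixed-configuration subclasses below, which hardcode
    # these values, can be re-instantiated without receiving them twice.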
    def fix_config(self, config):
        unused_kwargs = [
            "depth",
            "hidden_channels",
            "patch_size",
            "kernel_size",
            "activation",
        ]
        for k in unused_kwargs:
            config.pop(k, None)
        return config

"""
Model Definition
"""
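# Variant            depth  hidden  patch  kernel  activation
# ConvMixer736D32       32     768      7       7  "relu"
# ConvMixer1024D20      20    1024     14       9  "gelu"
# ConvMixer1536D20      20    1536      7       9  "gelu"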

class ConvMixer736D32(ConvMixer):
    def __init__(
        self,
        input_tensor: typing.Optional[keras.KerasTensor] = None,
        input_shape: typing.Optional[typing.Sequence[int]] = None,
        include_preprocessing: bool = True,
        include_top: bool = True,
        pooling: typing.Optional[str] = None,
        dropout_rate: float = 0.0,
        classes: int = 1000,
        classifier_activation: str = "softmax",
        weights: typing.Optional[str] = None,
        name: str = "ConvMixer736D32",
        **kwargs,
    ):
        kwargs = self.fix_config(kwargs)
        super().__init__(
            depth=32,
            hidden_channels=768,
            patch_size=7,
            kernel_size=7,
            activation="relu",
            input_tensor=input_tensor,
            input_shape=input_shape,
            include_preprocessing=include_preprocessing,
            include_top=include_top,
            pooling=pooling,
            dropout_rate=dropout_rate,
            classes=classes,
            classifier_activation=classifier_activation,
            weights=weights,
            name=name,
            **kwargs,
        )

    @staticmethod
    def available_feature_keys():
        feature_keys = ["STEM"]
        feature_keys.extend([f"BLOCK{i}" for i in range(32)])
        return feature_keys

class ConvMixer1024D20(ConvMixer):
    def __init__(
        self,
        input_tensor: typing.Optional[keras.KerasTensor] = None,
        input_shape: typing.Optional[typing.Sequence[int]] = None,
        include_preprocessing: bool = True,
        include_top: bool = True,
        pooling: typing.Optional[str] = None,
        dropout_rate: float = 0.0,
        classes: int = 1000,
        classifier_activation: str = "softmax",
        weights: typing.Optional[str] = None,
        name: str = "ConvMixer1024D20",
        **kwargs,
    ):
        kwargs = self.fix_config(kwargs)
        super().__init__(
            depth=20,
            hidden_channels=1024,
            patch_size=14,
            kernel_size=9,
            activation="gelu",
            input_tensor=input_tensor,
            input_shape=input_shape,
            include_preprocessing=include_preprocessing,
            include_top=include_top,
            pooling=pooling,
            dropout_rate=dropout_rate,
            classes=classes,
            classifier_activation=classifier_activation,
            weights=weights,
            name=name,
            **kwargs,
        )

    @staticmethod
    def available_feature_keys():
        feature_keys = ["STEM"]
        feature_keys.extend([f"BLOCK{i}" for i in range(20)])
        return feature_keys

class ConvMixer1536D20(ConvMixer):
    def __init__(
        self,
        input_tensor: typing.Optional[keras.KerasTensor] = None,
        input_shape: typing.Optional[typing.Sequence[int]] = None,
        include_preprocessing: bool = True,
        include_top: bool = True,
        pooling: typing.Optional[str] = None,
        dropout_rate: float = 0.0,
        classes: int = 1000,
        classifier_activation: str = "softmax",
        weights: typing.Optional[str] = None,
        name: str = "ConvMixer1536D20",
        **kwargs,
    ):
        kwargs = self.fix_config(kwargs)
        super().__init__(
            depth=20,
            hidden_channels=1536,
            patch_size=7,
            kernel_size=9,
            activation="gelu",
            input_tensor=input_tensor,
            input_shape=input_shape,
            include_preprocessing=include_preprocessing,
            include_top=include_top,
            pooling=pooling,
            dropout_rate=dropout_rate,
            classes=classes,
            classifier_activation=classifier_activation,
            weights=weights,
            name=name,
            **kwargs,
        )

    @staticmethod
    def available_feature_keys():
        feature_keys = ["STEM"]
        feature_keys.extend([f"BLOCK{i}" for i in range(20)])
        return feature_keys

add_model_to_registry(ConvMixer736D32, "imagenet")
add_model_to_registry(ConvMixer1024D20, "imagenet")
add_model_to_registry(ConvMixer1536D20, "imagenet")
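
# Minimal usage sketch (illustrative only, not part of the library): build a
# variant without pretrained weights and run a forward pass on random data.
# Assumes a Keras 3 backend where `keras.random.uniform` is available and
# that the model's default input size is 224x224.
if __name__ == "__main__":
    model = ConvMixer736D32(weights=None)
    x = keras.random.uniform([1, 224, 224, 3])
    y = model(x, training=False)
    print(y.shape)  # expected: (1, 1000) with the default classifier head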