@@ -198,209 +198,3 @@ def forward(self, x: torch.Tensor, style: torch.Tensor = None) -> torch.Tensor:
198
198
x = x + identity
199
199
200
200
return x
201
-
202
-
203
- ##############
204
- ##############
205
- ##############
206
- # from typing import List, Tuple
207
-
208
- # import torch
209
- # import torch.nn as nn
210
-
211
- # from .conv_block import ConvBlock
212
- # from .misc_modules import ChannelPool
213
-
214
- # __all__ = ["ConvLayer"]
215
-
216
-
217
- # class ConvLayer(nn.Module):
218
- # def __init__(
219
- # self,
220
- # in_channels: int,
221
- # out_channels: int,
222
- # n_blocks: int = 2,
223
- # layer_residual: bool = False,
224
- # short_skip: str = "residual",
225
- # style_channels: int = None,
226
- # expand_ratios: Tuple[float, ...] = (1.0, 1.0),
227
- # block_types: Tuple[str, ...] = ("basic", "basic"),
228
- # normalizations: Tuple[str, ...] = ("bn", "bn"),
229
- # activations: Tuple[str, ...] = ("relu", "relu"),
230
- # convolutions: Tuple[str, ...] = ("conv", "conv"),
231
- # kernel_sizes: Tuple[int, ...] = (3, 3),
232
- # groups: Tuple[int, ...] = (1, 1),
233
- # biases: Tuple[bool, ...] = (True, True),
234
- # preactivates: Tuple[bool, ...] = (False, False),
235
- # attentions: Tuple[str, ...] = (None, None),
236
- # preattends: Tuple[bool, ...] = (False, False),
237
- # use_styles: Tuple[bool, ...] = (False, False),
238
- # **kwargs,
239
- # ) -> None:
240
- # """Chain conv-blocks in a ModuleDict to compose a full layer.
241
-
242
- # Optional:
243
- # - add a style vector to the output at the end of each conv block(Cellpose)
244
-
245
- # Parameters
246
- # ----------
247
- # in_channels : int
248
- # Number of input channels.
249
- # out_channels : int
250
- # Number of output channels.
251
- # n_blocks : int, default=2
252
- # Number of ConvBlocks used in this layer.
253
- # layer_residual : bool, default=False
254
- # Apply a layer level residual skip. I.e x + layer(x). NOTE: residual
255
- # skips can be also applied inside the ConvBlocks, so this is justextra.
256
- # style_channels : int, default=None
257
- # Number of style vector channels. If None, style vectors are ignored.
258
- # short_skip : str, default="residual"
259
- # The name of the short skip method. One of: "residual", "dense","basic"
260
- # expand_ratios : Tuple[float, ...], default=(1.0, 1.0):
261
- # Expansion/Squeeze ratios for the out channels of each conv block.
262
- # block_types : Tuple[str, ...], default=("basic", "basic")
263
- # The name of the conv-blocks. Length of the tuple has toequal`n_blocks`
264
- # One of: "basic". "mbconv", "fmbconv" "dws", "bottleneck".
265
- # normalizations : Tuple[str, ...], default=("bn", "bn"):
266
- # Normalization methods. One of: "bn", "bcn", "gn", "in", "ln", "lrn"
267
- # activations : Tuple[str, ...], default=("relu", "relu")
268
- # Activation methods. One of: "mish", "swish", "relu", "relu6", "rrelu",
269
- # "selu", "celu", "gelu", "glu", "tanh", "sigmoid", "silu", "prelu",
270
- # "leaky-relu", "elu", "hardshrink", "tanhshrink", "hardsigmoid"
271
- # convolutions : Tuple[str, ...], default=("conv", "conv")
272
- # The convolution method. One of: "conv", "wsconv", "scaled_wsconv"
273
- # preactivates : Tuple[bool, ...], default=(False, False)
274
- # Pre-activations flags for the conv-blocks.
275
- # kernel_sizes : Tuple[int, ...], default=(3, 3)
276
- # The size of the convolution kernels in each conv block.
277
- # groups : int, default=(1, 1)
278
- # Number of groups for the kernels in each convolution blocks.
279
- # biases : Tuple[bool, ...], default=(True, True)
280
- # Include bias terms in the convolution blocks.
281
- # attentions : Tuple[str, ...], default=(None, None)
282
- # Attention methods. One of: "se", "scse", "gc", "eca", None
283
- # preattends : Tuple[bool, ...], default=(False, False)
284
- # If True, Attention is applied at the beginning of forward pass.
285
- # use_styles : Tuple[bool, ...], default=(False, False)
286
- # If True and `style_channels` is not None, adds a style vec to the
287
- # ConvBlock outputs.
288
-
289
- # Raises
290
- # ------
291
- # ValueError:
292
- # If lengths of the tuple arguments are not equal to `n_blocks`.
293
- # """
294
- # super().__init__()
295
- # self.layer_residual = layer_residual
296
- # self.short_skip = short_skip
297
- # self.in_channels = in_channels
298
-
299
- # illegal_args = [
300
- # (k, a)
301
- # for k, a in locals().items()
302
- # if isinstance(a, tuple) and len(a) != n_blocks
303
- # ]
304
-
305
- # if illegal_args:
306
- # raise ValueError(
307
- # f"All the tuple-arg lengths need to be equalto`n_blocks`={n_blocks}. "
308
- # f"Illegal args: {illegal_args}"
309
- # )
310
-
311
- # self.conv_blocks = nn.ModuleDict()
312
- # blocks = list(range(n_blocks))
313
- # for i in blocks:
314
- # out = int(out_channels * expand_ratios[i])
315
-
316
- # conv_block = ConvBlock(
317
- # name=block_types[i],
318
- # in_channels=in_channels,
319
- # out_channels=out,
320
- # style_channels=style_channels,
321
- # short_skip=short_skip,
322
- # kernel_size=kernel_sizes[i],
323
- # groups=groups[i],
324
- # bias=biases[i],
325
- # normalization=normalizations[i],
326
- # convolution=convolutions[i],
327
- # activation=activations[i],
328
- # attention=attentions[i],
329
- # preactivate=preactivates[i],
330
- # preattend=preattends[i],
331
- # use_style=use_styles[i],
332
- # **kwargs,
333
- # )
334
- # self.conv_blocks[f"{short_skip}_{block_types[i]}_{i + 1}"] = conv_block
335
-
336
- # if short_skip == "dense":
337
- # in_channels += conv_block.out_channels
338
- # else:
339
- # in_channels = conv_block.out_channels
340
-
341
- # self.out_channels = conv_block.out_channels
342
-
343
- # if short_skip == "dense":
344
- # self.transition = ConvBlock(
345
- # name="basic",
346
- # in_channels=in_channels,
347
- # short_skip="basic",
348
- # out_channels=out_channels,
349
- # same_padding=False,
350
- # bias=False,
351
- # kernel_size=1,
352
- # convolution=conv_block.block.conv_choice,
353
- # normalization=normalizations[-1],
354
- # activation=activations[-1],
355
- # preactivate=preactivates[-1],
356
- # )
357
- # self.out_channels = self.transition.out_channels
358
-
359
- # self.downsample = None
360
- # if layer_residual and self.in_channels != self.out_channels:
361
- # self.downsample = ChannelPool(
362
- # in_channels=self.in_channels,
363
- # out_channels=self.out_channels,
364
- # convolution=convolutions[-1],
365
- # normalization=normalizations[-1],
366
- # )
367
-
368
- # def forward_features_dense(
369
- # self, init_features: List[torch.Tensor], style: torch.Tensor = None
370
- # ) -> torch.Tensor:
371
- # """Dense forward pass."""
372
- # features = [init_features]
373
- # for conv_block in self.conv_blocks.values():
374
- # new_features = conv_block(features, style)
375
- # features.append(new_features)
376
-
377
- # x = torch.cat(features, 1)
378
- # x = self.transition(x)
379
-
380
- # return x
381
-
382
- # def forward_features(
383
- # self, x: torch.Tensor, style: torch.Tensor = None
384
- # ) -> torch.Tensor:
385
- # """Regular forward pass."""
386
- # for conv_block in self.conv_blocks.values():
387
- # x = conv_block(x, style)
388
-
389
- # return x
390
-
391
- # def forward(self, x: torch.Tensor, style: torch.Tensor = None) -> torch.Tensor:
392
- # """Forward pass of the conv-layer."""
393
- # if self.layer_residual:
394
- # identity = x
395
- # if self.downsample is not None:
396
- # identity = self.downsample(x)
397
-
398
- # if self.short_skip == "dense":
399
- # x = self.forward_features_dense(x, style)
400
- # else:
401
- # x = self.forward_features(x, style)
402
-
403
- # if self.layer_residual:
404
- # x = x + identity
405
-
406
- # return x
0 commit comments