Commit e398eba

feat(encoders): update encoder module to use only timm encoders. Re-build the unetTR upsampler module for generic image transformer encoders.
1 parent b972309 commit e398eba

3 files changed: +337 −140 lines changed
Lines changed: 39 additions & 73 deletions
@@ -1,98 +1,64 @@
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Tuple
 
 import torch
 import torch.nn as nn
 
-from .dino_vit import build_dinov2_encoder
-from .histo_encoder import build_histo_encoder
+from .encoder_upsampler import EncoderUpsampler
 from .timm_encoder import TimmEncoder
-from .unettr_encoder import EncoderUnetTR
-from .vit_det_SAM import build_sam_encoder
 
 __all__ = ["Encoder"]
 
 
-TR_ENCODERS = {
-    "histo_encoder_prostate_s": build_histo_encoder,
-    "histo_encoder_prostate_m": build_histo_encoder,
-    "sam_vit_l": build_sam_encoder,
-    "sam_vit_b": build_sam_encoder,
-    "sam_vit_h": build_sam_encoder,
-    "dinov2_vit_small": build_dinov2_encoder,
-    "dinov2_vit_base": build_dinov2_encoder,
-    "dinov2_vit_large": build_dinov2_encoder,
-    "dinov2_vit_giant": build_dinov2_encoder,
-}
-
-
 class Encoder(nn.Module):
     def __init__(
         self,
-        name: str,
-        pretrained: bool = False,
-        checkpoint_path: str = None,
-        in_channels: int = 3,
-        depth: int = 4,
-        out_indices: Tuple[int] = None,
-        unettr_kwargs: Dict[str, Any] = None,
-        **kwargs,
+        timm_encoder_name: str,
+        timm_encoder_out_indices: Tuple[int, ...],
+        pixel_decoder_out_channels: Tuple[int, ...],
+        timm_encoder_pretrained: bool = True,
+        timm_extra_kwargs: Dict[str, Any] = {},
     ) -> None:
-        """Wrap timm conv-based encoders and transformer-based encoders to one class.
-
-        NOTE: Refer to the docstring of the `TimmEncoder` and `EncoderUnetTR` for the
-        input key-word arguments (**kwargs).
+        """Wrap timm encoders to one class.
 
         Parameters
         ----------
-        name : str
+        timm_encoder_name : str
             Name of the encoder. If the name is in `TR_ENCODERS.keys()`, a transformer
             will be used. Otherwise, a timm encoder will be used.
-        pretrained : bool, optional, default=False
-            If True, load imagenet pretrained weights, by default False.
-        checkpoint_path : str, optional
-            Path to the weights of the encoder. If None, the encoder is initialized
-            with imagenet pre-trained weights if `enc_pretrain` argument is set to True
-            or with random weights if set to False. Defaults to None.
-        in_channels : int, optional
-            Number of input channels, by default 3.
-        depth : int, optional
-            Number of output features, by default 4. Ignored for transformer encoders.
-        out_indices : Tuple[int], optional
-            Indices of the output features, by default None. If None,
-            out_indices is set to range(len(depth)). Overrides the `depth` argument.
-        unettr_kwargs : Dict[str, Any]
-            Key-word arguments for the transformer encoder. These arguments are used
-            only if the encoder is transformer based. Refer to the docstring of the
-            `EncoderUnetTR`
-        **kwargs : Dict[str, Any]
-            Key-word arguments for any `timm` based encoder. These arguments are used
-            in `timm.create_model(**kwargs)` function call.
+        timm_encoder_out_indices : Tuple[int], optional
+            Indices of the output features.
+        pixel_decoder_out_channels : Tuple[int], optional
+            Number of output channels at each upsampling stage.
+        timm_encoder_pretrained : bool, optional, default=True
+            If True, load pretrained timm weights, by default True.
+        timm_extra_kwargs : Dict[str, Any], optional, default={}
+            Key-word arguments for any `timm` based encoder. These arguments are
+            used in `timm.create_model(**kwargs)` function call.
         """
         super().__init__()
 
-        if name not in TR_ENCODERS.keys():
-            self.encoder = TimmEncoder(
-                name,
-                pretrained=pretrained,
-                checkpoint_path=checkpoint_path,
-                in_channels=in_channels,
-                depth=depth,
-                out_indices=out_indices,
-                **kwargs,
-            )
-        else:
-            self.encoder = EncoderUnetTR(
-                backbone=TR_ENCODERS[name](
-                    name,
-                    pretrained=pretrained,
-                    checkpoint_path=checkpoint_path,
-                ),
-                **unettr_kwargs if unettr_kwargs is not None else {},
+        # initialize timm encoder
+        self.encoder = TimmEncoder(
+            timm_encoder_name,
+            pretrained=timm_encoder_pretrained,
+            out_indices=timm_encoder_out_indices,
+            extra_kwargs=timm_extra_kwargs,
+        )
+
+        # initialize feature upsampler if encoder is a vision transformer
+        feature_info = self.encoder.feature_info
+        reductions = [finfo["reduction"] for finfo in feature_info]
+        if all(element == reductions[0] for element in reductions):
+            self.encoder = EncoderUpsampler(
+                backbone=self.encoder,
+                out_channels=pixel_decoder_out_channels,
             )
+            feature_info = self.encoder.feature_info
 
-        self.out_channels = self.encoder.out_channels
-        self.feature_info = self.encoder.feature_info
+        self.out_channels = [f["num_chs"] for f in self.encoder.feature_info][::-1]
+        self.feature_info = self.encoder.feature_info[::-1]
 
-    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Forward pass of the encoder and return all the features."""
-        return self.encoder(x)
+        output, feats = self.encoder(x)
+        return output, feats[::-1]
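The branch that wraps the encoder in `EncoderUpsampler` hinges on timm's `feature_info` metadata: conv backbones report increasing `reduction` values (e.g. 4, 8, 16, 32), while a plain vision transformer reports the same reduction at every tapped block, so uniform reductions signal that a feature pyramid has to be built. Below is a minimal standalone sketch of that check; the `feature_info` entries are hand-written stand-ins for what a timm `features_only=True` model would report, not values taken from this commit.

# Hand-written stand-ins for timm `feature_info` entries; real timm models
# expose "num_chs"/"reduction" metadata of this shape via features_only=True.
resnet_like_info = [
    {"num_chs": 256, "reduction": 4},
    {"num_chs": 512, "reduction": 8},
    {"num_chs": 1024, "reduction": 16},
    {"num_chs": 2048, "reduction": 32},
]
vit_like_info = [{"num_chs": 768, "reduction": 16} for _ in range(4)]


def needs_upsampler(feature_info) -> bool:
    """Mirror the check above: True if all tapped maps share one reduction."""
    reductions = [finfo["reduction"] for finfo in feature_info]
    return all(r == reductions[0] for r in reductions)


print(needs_upsampler(resnet_like_info))  # False -> conv pyramid used as-is
print(needs_upsampler(vit_like_info))     # True  -> wrapped in EncoderUpsampler

Note that the check treats any backbone with uniform reductions as transformer-like, which is exactly what makes it generic: no hard-coded registry of transformer names is needed anymore, hence the removal of `TR_ENCODERS`.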
Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+__all__ = ["EncoderUpsampler", "FeatUpsampleBlock"]
+
+
+class FeatUpsampleBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int = None,
+        scale_factor: int = 2,
+    ) -> None:
+        """Upsample 2D dimensions of a feature.
+
+        TransConv + Conv layers
+
+        Parameters:
+            in_channels (int):
+                Number of input channels.
+            out_channels (int):
+                Number of output channels. If None, set to `in_channels`.
+            scale_factor (int):
+                Scale factor for upsampling. Defaults to 2.
+        """
+        super().__init__()
+        if out_channels is None:
+            out_channels = in_channels
+
+        self.scale_factor = scale_factor
+        self.out_channels = out_channels
+
+        if isinstance(scale_factor, int):
+            self.up = nn.ConvTranspose2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=2 ** (scale_factor - 1),
+                stride=2 ** (scale_factor - 1),
+                padding=0,
+                output_padding=0,
+            )
+        else:
+            self.up = nn.Upsample(
+                scale_factor=scale_factor,
+                mode="bilinear",
+                align_corners=True,
+            )
+
+        self.conv_block = nn.Conv2d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            padding=1,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.up(x)
+        x = self.conv_block(x)
+        return x
+
+
+class EncoderUpsampler(nn.Module):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        out_channels: Tuple[int, ...],
+    ) -> None:
+        """Feature upsampler for transformer-like backbones.
+
+        Note:
+            This is a U-NetTR like upsampler that takes the features from the
+            backbone and upsamples them such that the scale factor between the
+            upsampled features is two. Builds an image-pyramid like structure.
+
+        Parameters:
+            backbone (nn.Module):
+                Backbone network that extracts features.
+            out_channels (Tuple[int, ...]):
+                Number of channels in the output tensor of each upsampling block.
+                Must have the same length as `backbone.feature_info`.
+        """
+        print(out_channels, backbone.feature_info)
+        super().__init__()
+        if len(out_channels) != len(backbone.feature_info):
+            raise ValueError(
+                "`out_channels` must have the same len as `backbone.feature_info`. "
+                f"Got {len(out_channels)} and {len(backbone.feature_info)} respectively."
+            )
+
+        self.backbone = backbone
+        self.out_channels = out_channels
+        self.feature_info = []
+
+        # flip the feature info so that we start building the
+        # upsampling blocks from the bottleneck layer
+        feature_info = backbone.feature_info[::-1]
+
+        # bottleneck layer
+        self.bottleneck = nn.Conv2d(
+            in_channels=feature_info[0]["num_chs"],
+            out_channels=self.out_channels[0],
+            kernel_size=1,
+        )
+
+        # add timm-like feature info of the bottleneck layer
+        self.feature_info.append(
+            {
+                "num_chs": self.out_channels[0],
+                "module": "bottleneck",
+                "reduction": float(feature_info[0]["reduction"]),
+            }
+        )
+
+        self.up_blocks = nn.ModuleDict()
+        n_up_blocks = list(range(1, len(self.out_channels)))
+        for i, (out_chls, finfo, n_blocks) in enumerate(
+            zip(self.out_channels[1:], feature_info[1:], n_up_blocks)
+        ):
+            up_blocks = []
+            squeeze_rates = list(range(n_blocks))[::-1]
+
+            for j, squeeze_ratio in zip(range(n_blocks), squeeze_rates):
+                if j == 0:
+                    in_channels = finfo["num_chs"]
+                else:
+                    in_channels = up.out_channels  # noqa
+
+                up = FeatUpsampleBlock(
+                    in_channels=in_channels,
+                    out_channels=out_chls * (2**squeeze_ratio),
+                    scale_factor=2,
+                )
+                up_blocks.append(up)
+
+            # add feature info
+            self.feature_info.append(
+                {
+                    "num_chs": out_chls,
+                    "module": f"up{i + 1}",
+                    "reduction": finfo["reduction"] / 2**n_blocks,
+                }
+            )
+            self.up_blocks[f"up{i + 1}"] = nn.Sequential(*up_blocks)
+
+        # flip the feature info back to the original order to match the top-down
+        # order of timm feature_info. (high to low res)
+        self.feature_info = self.feature_info[::-1]
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
+        # get the features from the backbone
+        final_feat, feats = self.backbone(x)
+
+        # flip the features so that we start from the bottleneck (low res)
+        feats = feats[::-1]
+
+        # bottleneck feature
+        up_feat = self.bottleneck(feats[0])
+        intermediate_features = [up_feat]
+
+        # upsampled features
+        for i, feat in enumerate(feats[1:]):
+            up_feat = self.up_blocks[f"up{i + 1}"](feat)
+            intermediate_features.append(up_feat)
+
+        return final_feat, tuple(intermediate_features[::-1])  # feats in top-down order
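To see the pyramid the upsampler builds, here is a hedged, self-contained sketch. `DummyViT` is a hypothetical stand-in (not part of this commit) that mimics the `(final_feat, feats)` contract and the ViT-style uniform `feature_info` that `EncoderUpsampler` expects; the example assumes `EncoderUpsampler` and `FeatUpsampleBlock` from the new module above are in scope.

import torch
import torch.nn as nn

# Hypothetical stand-in backbone: four tapped features, all at reduction 16,
# mimicking what a patch-16 ViT wrapped by TimmEncoder would produce.
class DummyViT(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.feature_info = [{"num_chs": 768, "reduction": 16} for _ in range(4)]
        self.proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)

    def forward(self, x: torch.Tensor):
        feat = self.proj(x)  # (B, 768, H/16, W/16)
        return feat, [feat, feat, feat, feat]


upsampler = EncoderUpsampler(backbone=DummyViT(), out_channels=(512, 256, 128, 64))
final_feat, feats = upsampler(torch.randn(1, 3, 256, 256))
for f in feats:
    print(tuple(f.shape))
# (1, 64, 128, 128)  <- up3: three x2 blocks, reduction 2
# (1, 128, 64, 64)   <- up2: two x2 blocks, reduction 4
# (1, 256, 32, 32)   <- up1: one x2 block, reduction 8
# (1, 512, 16, 16)   <- bottleneck 1x1 conv, reduction 16

Inside each `up{i}` stage the channel widths follow the `squeeze_rates` schedule, halving per block down to `out_chls`: for `up3` above, 768 → 256 → 128 → 64 while the spatial size doubles at every block.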
