
Commit 04ecb46

fix: small bugfixes to base encoder
1 parent 2136e8e commit 04ecb46

3 files changed (218 additions, 1 deletion)

cellseg_models_pytorch/encoders/_base.py

Lines changed: 5 additions & 1 deletion
@@ -10,6 +10,7 @@
 class BaseTrEncoder(nn.Module):
     def __init__(
         self,
+        name: str,
         checkpoint_path: Optional[Union[str, Path]] = None,
         out_indices: Optional[Tuple[int, ...]] = None,
         depth: int = 4,
@@ -19,6 +20,8 @@ def __init__(
 
         Parameters
         ----------
+        name : str
+            Name of the backbone.
         checkpoint_path : Optional[Union[Path, str]], optional
             Path to the weights of the backbone. If None, the backbone is initialized
             with random weights. Defaults to None.
@@ -31,6 +34,7 @@ def __init__(
             features will be the last `depth` features of the backbone. Defaults to 4.
         """
         super().__init__()
+        self.name = name
         self.depth = depth
 
         # set checkpoint path
@@ -80,5 +84,5 @@ def load_checkpoint(self) -> None:
         except BaseException as e:
             raise RuntimeError(f"Error loading checkpoint: {e}")
 
-        print(f"Loading checkpoint: {msg}")
+        print(f"Loading pre-trained {self.name} checkpoint: {msg}")
         self.backbone = backbone
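
Since `name` is now the first required positional parameter of `BaseTrEncoder.__init__`, every subclass has to forward it explicitly, as the two diffs below do. A minimal sketch of the new contract; `ToyEncoder` and its `name` value are hypothetical, not part of this commit:

import torch.nn as nn

from cellseg_models_pytorch.encoders._base import BaseTrEncoder


class ToyEncoder(BaseTrEncoder):
    """Hypothetical subclass illustrating the new required `name` argument."""

    def __init__(self, backbone: nn.Module, checkpoint_path=None) -> None:
        # `name` must now be passed up to BaseTrEncoder; it is stored as
        # `self.name` and interpolated into the checkpoint-loading message.
        super().__init__(
            name="Toy-encoder",
            checkpoint_path=checkpoint_path,
            out_indices=(0, 1, 2, 3),
        )
        self.backbone = backbone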
Lines changed: 212 additions & 0 deletions (new file)
@@ -0,0 +1,212 @@
+"""Adapted from https://github.com/jopo666/HistoEncoder.
+
+Copyright 2023 Joona Pohjonen
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import timm
+import torch
+import torch.nn as nn
+
+from cellseg_models_pytorch.encoders._base import BaseTrEncoder
+
+__all__ = ["HistoEncoder", "build_histo_encoder"]
+
+# histo_encoder model name to timm model name mapping
+NAME_TO_MODEL = {
+    "histo_encoder_prostate_s": "xcit_small_12_p16_224",
+    "histo_encoder_prostate_m": "xcit_medium_24_p16_224",
+}
+
+# name to pre-trained weights mapping
+MODEL_URLS = {
+    "histo_encoder_prostate_s": "https://dl.dropboxusercontent.com/s/tbff9wslc8p7ie3/prostate_small.pth?dl=0",  # noqa
+    "histo_encoder_prostate_m": "https://dl.dropboxusercontent.com/s/k1fr09x5auki8sp/prostate_medium.pth?dl=0",  # noqa
+}
+
+
+class HistoEncoder(BaseTrEncoder):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        checkpoint_path: Optional[Union[Path, str]] = None,
+        out_indices: Optional[Tuple[int, ...]] = None,
+        num_blocks: int = 1,
+        embed_dim: int = 384,
+        patch_size: int = 16,
+        avg_pool: bool = False,
+        **kwargs,
+    ) -> None:
+        """Create HistoEncoder backbone.
+
+        HistoEncoder: https://github.com/jopo666/HistoEncoder
+
+        Parameters
+        ----------
+        checkpoint_path : Optional[Union[Path, str]], optional
+            Path to the weights of the backbone. If None and pretrained is False,
+            the backbone is initialized randomly. Defaults to None.
+        num_blocks : int, optional
+            Number of attention blocks to include in the extracted features.
+            When `num_blocks > 1`, the outputs of the last `num_blocks` attention
+            blocks are concatenated to make up the features. Defaults to 1.
+        avg_pool : bool, optional
+            Whether to average pool the outputs of the last attention block.
+            Defaults to False.
+        """
+        super().__init__(
+            name="Histo-encoder",
+            checkpoint_path=checkpoint_path,
+            out_indices=out_indices,
+        )
+
+        self.backbone = backbone
+        self.avg_pool = avg_pool
+        self.num_blocks = num_blocks
+        self.embed_dim = embed_dim
+        self.patch_size = patch_size
+
+        if checkpoint_path is not None:
+            self.load_checkpoint()
+
+    @property
+    def n_blocks(self):
+        """Get the number of attention blocks in the backbone."""
+        return len(self.backbone.blocks)
+
+    def forward_features(
+        self, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+        """Forward pass of the backbone and return all the features.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor (input image). Shape: (B, C, H, W).
+
+        Returns
+        -------
+        Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+            torch.Tensor: Output of the last layers (all tokens, without classification)
+            torch.Tensor: Classification output
+            List[torch.Tensor]: All the intermediate features from the attention blocks
+        """
+        B = x.shape[0]
+        x, (Hp, Wp) = self.backbone.patch_embed(x)
+
+        if self.backbone.pos_embed is not None:
+            pos_encoding = (
+                self.backbone.pos_embed(B, Hp, Wp)
+                .reshape(B, -1, x.shape[1])
+                .permute(0, 2, 1)
+            )
+            x = x + pos_encoding
+
+        x = self.backbone.pos_drop(x)
+
+        # Collect intermediate outputs.
+        intermediate_outputs = []
+        res_outputs = []
+        for i, blk in enumerate(self.backbone.blocks):
+            x = blk(x, Hp, Wp)
+            intermediate_outputs.append(x)
+            if i in self.out_indices:
+                res_outputs.append(
+                    x.reshape(B, Hp, Wp, self.embed_dim).permute(0, 3, 1, 2)
+                )
+
+        # collect intermediate outputs and add cls token block
+        cls_tokens = self.backbone.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+        for j, blk in enumerate(self.backbone.cls_attn_blocks, i + 1):
+            x = blk(x)
+            intermediate_outputs.append(x)
+            if j in self.out_indices:
+                res_outputs.append(
+                    x[:, 1:, :].reshape(B, Wp, Hp, self.embed_dim).permute(0, 3, 1, 2)
+                )
+
+        norm_outputs = [
+            self.backbone.norm(x) for x in intermediate_outputs[-self.num_blocks :]
+        ]
+        output = torch.cat([x[:, 0] for x in norm_outputs], axis=-1)
+
+        if self.avg_pool:
+            output = torch.cat(
+                [output, torch.mean(norm_outputs[-1][:, 1:], dim=1)], axis=-1
+            )
+
+        return torch.mean(norm_outputs[-1][:, 1:], dim=1), output, res_outputs
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the histo-encoder backbone."""
+        logits, cls_token, features = self.forward_features(x)
+
+        return features
+
+
+def build_histo_encoder(
+    name: str, pretrained: bool = True, checkpoint_path: str = None, **kwargs
+) -> HistoEncoder:
+    """Build HistoEncoder backbone.
+
+    Parameters
+    ----------
+    name : str
+        Name of the encoder. Must be one of "histo_encoder_prostate_s",
+        "histo_encoder_prostate_m".
+    pretrained : bool, optional
+        If True, load pretrained weights. Defaults to True.
+    checkpoint_path : str, optional
+        Path to the weights of the backbone. If None and pretrained is False,
+        the backbone is initialized randomly. Defaults to None.
+
+    Returns
+    -------
+    HistoEncoder
+        The initialized Histo-encoder.
+    """
+    if name not in ("histo_encoder_prostate_s", "histo_encoder_prostate_m"):
+        raise ValueError(
+            f"Unknown encoder name: {name}, "
+            "allowed values are 'histo_encoder_prostate_s', 'histo_encoder_prostate_m'"
+        )
+
+    if checkpoint_path is None and pretrained:
+        checkpoint_path = MODEL_URLS[name]
+
+    # init XCiT backbone
+    backbone = timm.create_model(NAME_TO_MODEL[name], num_classes=0, **kwargs)
+
+    if name == "histo_encoder_prostate_s":
+        histo_encoder = HistoEncoder(
+            backbone=backbone,
+            out_indices=(2, 5, 10, 13),
+            checkpoint_path=checkpoint_path,
+            embed_dim=384,
+            patch_size=16,
+        )
+    elif name == "histo_encoder_prostate_m":
+        histo_encoder = HistoEncoder(
+            backbone=backbone,
+            out_indices=(4, 11, 20, 25),
+            checkpoint_path=checkpoint_path,
+            embed_dim=512,
+            patch_size=16,
+        )
+
+    return histo_encoder
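
A rough usage sketch for the new builder. The import path below is an assumption (the commit view does not show the new file's name), it assumes timm provides the XCiT models named above, and `pretrained=False` skips downloading the Dropbox weights:

import torch

# NOTE: import path is assumed; the new file's path is not shown in this view.
from cellseg_models_pytorch.encoders.histo_encoder import build_histo_encoder

encoder = build_histo_encoder("histo_encoder_prostate_s", pretrained=False)

x = torch.randn(2, 3, 224, 224)
features = encoder(x)  # forward() discards the token outputs, returns the feature list

# One map per index in out_indices=(2, 5, 10, 13). With patch_size=16 and
# embed_dim=384, a 224x224 input should yield four (2, 384, 14, 14) tensors.
for f in features:
    print(f.shape)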

cellseg_models_pytorch/encoders/vit_det_SAM.py

Lines changed: 1 addition & 0 deletions
@@ -628,6 +628,7 @@ def __init__(
             Indexes for blocks using global attention.
         """
         super().__init__(
+            name="SAM-VitDet",
             checkpoint_path=checkpoint_path,
             out_indices=out_indices,
         )
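
The user-visible effect of threading `name` through all three files is a more informative log line from `load_checkpoint`. A toy reproduction of the formatting change; the `msg` value is only a stand-in for the report `load_state_dict` returns:

# Stand-in for the report string `load_checkpoint` receives from
# `load_state_dict`; the real value lists matched/missing keys.
msg = "<All keys matched successfully>"

for name in ("Histo-encoder", "SAM-VitDet"):
    print(f"Loading pre-trained {name} checkpoint: {msg}")  # new format
    # old format: print(f"Loading checkpoint: {msg}")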
