Commit 7e2a740

Merge pull request #2 from mahmoodlab/more_models
Support for additional patch encoders
2 parents 51348c3 + eb9d0a4 commit 7e2a740

File tree

10 files changed

+219
-19
lines changed

README.md

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ This project was developed by the [Mahmood Lab](https://faisal.ai/) at Harvard M
1515

1616
- **Tissue Segmentation**: Extract tissue from background using a DeepLabv3 model (supports H&E, IHC, penmark and artifact removal, etc.).
1717
- **Patch Extraction**: Extract tissue patches of any size and magnification.
18-
- **Patch Feature Extraction**: Extract patch embeddings from tissue patches using 13 popular foundation models, including [UNI](https://www.nature.com/articles/s41591-024-02857-3), [CONCH](https://www.nature.com/articles/s41591-024-02856-4), [Virchow](https://www.nature.com/articles/s41591-024-03141-0), [H-Optimus-0](https://github.com/bioptimus/releases/tree/main/models/h-optimus/v0) and many more...
18+
- **Patch Feature Extraction**: Extract patch embeddings from tissue patches using 20 popular foundation models, including [UNI](https://www.nature.com/articles/s41591-024-02857-3), [CONCH](https://www.nature.com/articles/s41591-024-02856-4), [Virchow](https://www.nature.com/articles/s41591-024-03141-0), [H-Optimus-0](https://github.com/bioptimus/releases/tree/main/models/h-optimus/v0) and many more...
1919
- **Slide Feature Extraction**: Extract slide embeddings from pre-extracted patch embeddings using 5 whole-slide foundation models, including [Threads](https://arxiv.org/abs/2501.16652) (coming soon!), [Titan](https://arxiv.org/abs/2411.19666),
2020
[PRISM](https://arxiv.org/abs/2405.10254), [GigaPath](https://www.nature.com/articles/s41586-024-07441-w) and [CHIEF](https://www.nature.com/articles/s41586-024-07894-z).
2121

@@ -93,7 +93,7 @@ python run_single_slide.py --slide_path wsis/xxxx.svs --job_dir ./trident_proces
9393
- **Outputs**:
9494
- Features are saved as h5 files in `./trident_processed/20x_256px/features_uni_v1`. (Shape: `(n_patches, feature_dim)`)
9595

96-
Trident supports 13 patch encoders, loaded via a patch-level [`encoder_factory`](https://github.com/mahmoodlab/trident/blob/main/trident/patch_encoder_models/load.py#L14). Models requiring specific installations will return error messages with additional instructions. Gated models on HuggingFace require access requests.
96+
Trident supports 20 patch encoders, loaded via a patch-level [`encoder_factory`](https://github.com/mahmoodlab/trident/blob/main/trident/patch_encoder_models/load.py#L14). Models requiring specific installations will return error messages with additional instructions. Gated models on HuggingFace require access requests.
9797

9898
- **UNI**: [MahmoodLab/UNI](https://huggingface.co/MahmoodLab/UNI) (`--patch_encoder uni_v1`)
9999
- **UNIv2**: [MahmoodLab/UNI2-h](https://huggingface.co/MahmoodLab/UNI2-h) (`--patch_encoder uni_v2`)
@@ -106,8 +106,11 @@ Trident supports 13 patch encoders, loaded via a patch-level [`encoder_factory`]
106106
- **Prov-Gigapath**: [prov-gigapath](https://huggingface.co/prov-gigapath/prov-gigapath) (`--patch_encoder gigapath`)
107107
- **H-Optimus-0**: [bioptimus/H-optimus-0](https://huggingface.co/bioptimus/H-optimus-0) (`--patch_encoder hoptimus0`)
108108
- **MUSK**: [xiangjx/musk](https://huggingface.co/xiangjx/musk) (`--patch_encoder musk`)
109+
- **Kaiko**: Hosted on TorchHub (`--patch_encoder kaiko-vits8, kaiko-vits16, kaiko-vitb8, kaiko-vitb16, kaiko-vitl14`)
110+
- **Lunit**: [1aurent/vit_small_patch8_224.lunit_dino](https://huggingface.co/1aurent/vit_small_patch8_224.lunit_dino) (`--patch_encoder lunit-vits8`)
111+
- **Hibou**: [histai/hibou-L](https://huggingface.co/histai/hibou-L) (`--patch_encoder hibou_l`)
109112
- **CTransPath-CHIEF**: Automatic download (`--patch_encoder ctranspath`)
110-
- **ResNet50**: Pretrained on ImageNet via torchvision. (`--patch_encoder resnet50`)
113+
- **ResNet50**: Hosted on torchvision. (`--patch_encoder resnet50`)
111114

112115
**Step 3b: Slide Feature Extraction:** Extracts slide embeddings using a slide encoder. Will also automatically extract patch embeddings.
113116
- **Command**:
@@ -124,11 +127,12 @@ Trident supports 13 patch encoders, loaded via a patch-level [`encoder_factory`]
124127
- Features are saved as h5 files in `./trident_processed/20x_256px/slide_features_titan`. (Shape: `(feature_dim)`)
125128

126129
Trident supports 5 slide encoders, loaded via a slide-level [`encoder_factory`](https://github.com/mahmoodlab/trident/blob/main/trident/slide_encoder_models/load.py#L14). Models requiring specific installations will return error messages with additional instructions. Gated models on HuggingFace require access requests.
127-
- **Threads**: Coming Soon! [MahmoodLab/threads](https://huggingface.co/MahmoodLab/threads) (`--slide_encoder threads`).
128-
- **Titan**: [MahmoodLab/TITAN](https://huggingface.co/MahmoodLab/TITAN) (`--slide_encoder titan`)
129-
- **PRISM**: [paige-ai/Prism](https://huggingface.co/paige-ai/Prism) (`--slide_encoder prism`)
130-
- **CHIEF**: [CHIEF](https://github.com/hms-dbmi/CHIEF) (`--slide_encoder chief`)
131-
- **GigaPath**: [prov-gigapath]() (`--slide_encoder gigapath`)
130+
- **Threads**: Coming Soon! [MahmoodLab/threads](https://huggingface.co/MahmoodLab/threads) (`--slide_encoder threads`). Based on `conch_v15` with `512x512` @20x.
131+
- **Titan**: [MahmoodLab/TITAN](https://huggingface.co/MahmoodLab/TITAN) (`--slide_encoder titan`). Based on `conch_v15` with `512x512` @20x.
132+
- **PRISM**: [paige-ai/Prism](https://huggingface.co/paige-ai/Prism) (`--slide_encoder prism`). Based on `virchow` with `256x256` @20x.
133+
- **CHIEF**: [CHIEF](https://github.com/hms-dbmi/CHIEF) (`--slide_encoder chief`). Based on `ctranspath` with `256x256` @10x.
134+
- **GigaPath**: [prov-gigapath](https://huggingface.co/prov-gigapath/prov-gigapath) (`--slide_encoder gigapath`). Based on `gigapath` with `256x256` @20x.
135+
- **Madeleine**: [MahmoodLab/madeleine](https://huggingface.co/MahmoodLab/madeleine) (`--slide_encoder madeleine`). Based on `conch_v1` with `256x256` @10x.
132136

133137
> [!NOTE]
134138
> If you have a patient containing multiple slides, you have two ways for constructing whole-patient embeddings: processing each slide independently and taking the average of the slide features (late fusion) or pooling all patches together and processing that as a single "pseudo-slide" (early fusion). You can use Trident-generated slide embeddings in your own late fusion pipeline, or use Trident-generated patch embeddings in your own early fusion pipeline. For an implementation of both fusion strategies, please check out our sister repository [Patho-Bench](https://github.com/mahmoodlab/Patho-Bench).
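
The two fusion strategies described in the note above can be illustrated with a minimal sketch. It uses plain Python lists and a mean-pooling stand-in for the slide encoder; `mean_pool`, `late_fusion`, and `early_fusion` are hypothetical names chosen for illustration and are not part of Trident or Patho-Bench.

```python
from statistics import fmean


def mean_pool(vectors):
    """Stand-in slide encoder (assumption): average vectors column-wise."""
    return [fmean(column) for column in zip(*vectors)]


def late_fusion(slides, slide_encoder=mean_pool):
    """Encode each slide independently, then average the slide embeddings."""
    slide_embeddings = [slide_encoder(patches) for patches in slides]
    return mean_pool(slide_embeddings)


def early_fusion(slides, slide_encoder=mean_pool):
    """Pool all patches into one pseudo-slide, then encode it once."""
    pseudo_slide = [patch for patches in slides for patch in patches]
    return slide_encoder(pseudo_slide)


# Toy patient with two slides: two patches and one patch, 2-dim embeddings.
slide_a = [[1.0, 0.0], [3.0, 0.0]]
slide_b = [[0.0, 6.0]]

late = late_fusion([slide_a, slide_b])
early = early_fusion([slide_a, slide_b])
print(late, early)
```

With slides of unequal patch counts the two results differ: late fusion weights every slide equally, whereas early fusion weights every patch equally.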

run_batch_of_slides.py

Lines changed: 4 additions & 2 deletions
@@ -60,10 +60,12 @@ def parse_arguments():
6060
parser.add_argument('--patch_encoder', type=str, default='conch_v15',
6161
choices=['conch_v1', 'uni_v1', 'uni_v2', 'ctranspath', 'phikon',
6262
'resnet50', 'gigapath', 'virchow', 'virchow2',
63-
'hoptimus0', 'phikon_v2', 'conch_v15', 'musk'],
63+
'hoptimus0', 'phikon_v2', 'conch_v15', 'musk', 'hibou_l',
64+
'kaiko-vits8', 'kaiko-vits16', 'kaiko-vitb8', 'kaiko-vitb16',
65+
'kaiko-vitl14', 'lunit-vits8'],
6466
help='Patch encoder to use')
6567
parser.add_argument('--slide_encoder', type=str, default=None,
66-
choices=['threads', 'titan', 'prism', 'gigapath', 'chief',
68+
choices=['threads', 'titan', 'prism', 'gigapath', 'chief', 'madeleine',
6769
'mean-virchow', 'mean-virchow2', 'mean-conch_v1', 'mean-conch_v15', 'mean-ctranspath',
6870
'mean-gigapath', 'mean-resnet50', 'mean-hoptimus0', 'mean-phikon', 'mean-phikon_v2',
6971
'mean-musk', 'mean-uni_v1', 'mean-uni_v2',
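
Because each pretrained slide encoder expects features from a specific patch encoder, patch size, and magnification, a pre-flight check can catch mismatched arguments early. The sketch below transcribes the pairings listed in the README hunk of this commit; the table name and `check_compatibility` helper are hypothetical and do not exist in Trident.

```python
# Pairings transcribed from the README list in this commit:
# slide encoder -> (patch encoder, patch size in px, magnification).
# The table and helper are illustrative assumptions, not Trident code.
SLIDE_ENCODER_REQUIREMENTS = {
    "threads": ("conch_v15", 512, 20),
    "titan": ("conch_v15", 512, 20),
    "prism": ("virchow", 256, 20),
    "chief": ("ctranspath", 256, 10),
    "gigapath": ("gigapath", 256, 20),
    "madeleine": ("conch_v1", 256, 10),
}


def check_compatibility(slide_encoder, patch_encoder, patch_size, mag):
    """Return a list of mismatch messages; an empty list means compatible."""
    expected_encoder, expected_size, expected_mag = SLIDE_ENCODER_REQUIREMENTS[slide_encoder]
    problems = []
    if patch_encoder != expected_encoder:
        problems.append(f"{slide_encoder} expects patch encoder {expected_encoder}, got {patch_encoder}")
    if patch_size != expected_size:
        problems.append(f"{slide_encoder} expects {expected_size}px patches, got {patch_size}px")
    if mag != expected_mag:
        problems.append(f"{slide_encoder} expects {expected_mag}x magnification, got {mag}x")
    return problems


print(check_compatibility("chief", "ctranspath", 256, 20))
```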

run_single_slide.py

Lines changed: 7 additions & 5 deletions
@@ -22,11 +22,13 @@ def parse_arguments():
2222
parser.add_argument("--gpu", type=int, default=0, help="GPU index to use for processing tasks")
2323
parser.add_argument("--slide_path", type=str, required=True, help="Path to the WSI file to process")
2424
parser.add_argument("--job_dir", type=str, required=True, help="Directory to store outputs")
25-
parser.add_argument("--patch_encoder", type=str, default="uni_v1",
26-
choices=["conch_v1", "uni_v1", "uni_v2", "ctranspath", "phikon",
27-
"resnet50", "gigapath", "virchow", "virchow2",
28-
"hoptimus0", "phikon_v2", "conch_v15", "musk"],
29-
help="Patch encoder for feature extraction")
25+
parser.add_argument('--patch_encoder', type=str, default='conch_v15',
26+
choices=['conch_v1', 'uni_v1', 'uni_v2', 'ctranspath', 'phikon',
27+
'resnet50', 'gigapath', 'virchow', 'virchow2',
28+
'hoptimus0', 'phikon_v2', 'conch_v15', 'musk', 'hibou_l',
29+
'kaiko-vits8', 'kaiko-vits16', 'kaiko-vitb8', 'kaiko-vitb16',
30+
'kaiko-vitl14', 'lunit-vits8'],
31+
help='Patch encoder to use')
3032
parser.add_argument("--mag", type=int, choices=[5, 10, 20, 40], default=20,
3133
help="Magnification at which patches/features are extracted")
3234
parser.add_argument("--patch_size", type=int, default=256, help="Patch size at which coords/features are extracted")
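
Both scripts now carry the same 20-entry `choices` list. A minimal standalone sketch of that argument, with the list factored into one constant, could look like the following; `PATCH_ENCODERS` and `build_parser` are hypothetical names, and only the `--patch_encoder` option is reproduced.

```python
import argparse

# The 20 patch encoder names accepted after this commit, as listed in the diff.
PATCH_ENCODERS = [
    "conch_v1", "uni_v1", "uni_v2", "ctranspath", "phikon",
    "resnet50", "gigapath", "virchow", "virchow2",
    "hoptimus0", "phikon_v2", "conch_v15", "musk", "hibou_l",
    "kaiko-vits8", "kaiko-vits16", "kaiko-vitb8", "kaiko-vitb16",
    "kaiko-vitl14", "lunit-vits8",
]


def build_parser():
    """Build a parser exposing only the --patch_encoder option."""
    parser = argparse.ArgumentParser(description="Patch encoder selection sketch")
    parser.add_argument("--patch_encoder", type=str, default="conch_v15",
                        choices=PATCH_ENCODERS, help="Patch encoder to use")
    return parser


args = build_parser().parse_args(["--patch_encoder", "kaiko-vitb16"])
print(args.patch_encoder)
```

Sharing one constant between `run_single_slide.py` and `run_batch_of_slides.py` would keep the two lists from drifting apart, though the commit itself duplicates the list in each script.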

tests/test_patch_encoders.py

Lines changed: 15 additions & 0 deletions
@@ -78,6 +78,21 @@ def test_hoptimus0_forward(self):
7878

7979
def test_musk_forward(self):
8080
self._test_encoder_forward('musk')
81+
82+
def test_hibou_l_forward(self):
83+
self._test_encoder_forward('hibou_l')
84+
85+
def test_kaiko_forward(self):
86+
self._test_encoder_forward('kaiko-vits8')
87+
self._test_encoder_forward('kaiko-vits16')
88+
self._test_encoder_forward('kaiko-vitb8')
89+
self._test_encoder_forward('kaiko-vitb16')
90+
self._test_encoder_forward('kaiko-vitl14')
91+
92+
def test_lunitvits8_forward(self):
93+
self._test_encoder_forward('lunit-vits8')
94+
95+
8196

8297
if __name__ == '__main__':
8398
unittest.main()
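
`test_kaiko_forward` runs five variants inside one test method, so the first failure hides the remaining ones. One possible alternative, sketched here under the assumption that a helper like `_test_encoder_forward` exists, is `unittest`'s `subTest`, which reports each variant separately. The helper below is a trivial stand-in and does not load any model.

```python
import unittest

KAIKO_VARIANTS = ["kaiko-vits8", "kaiko-vits16", "kaiko-vitb8",
                  "kaiko-vitb16", "kaiko-vitl14"]


class TestKaikoVariants(unittest.TestCase):

    def _test_encoder_forward(self, encoder_name):
        # Hypothetical stand-in for the real helper, which would build the
        # encoder and run a forward pass; here we only check the name.
        self.assertTrue(encoder_name.startswith("kaiko-vit"))

    def test_kaiko_forward(self):
        for encoder_name in KAIKO_VARIANTS:
            # Each variant is reported on its own if it fails.
            with self.subTest(encoder=encoder_name):
                self._test_encoder_forward(encoder_name)


def run_suite():
    """Run the test case programmatically and return the result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestKaikoVariants)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```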

tests/test_slide_encoders.py

Lines changed: 7 additions & 0 deletions
@@ -69,10 +69,17 @@ def test_slide_encoder_factory_with_valid_names(self):
6969
('chief', CHIEFSlideEncoder),
7070
('gigapath', GigaPathSlideEncoder),
7171
('titan', TitanSlideEncoder),
72+
('madeleine', MadeleineSlideEncoder),
7273
]:
7374
encoder = encoder_factory(model_name)
7475
self.assertIsInstance(encoder, expected_class)
7576

77+
def test_madeleine_encoder_initialization(self):
78+
sample_batch = {
79+
'features': torch.randn(1, 100, 512),
80+
}
81+
self._test_encoder_forward(MadeleineSlideEncoder(), sample_batch, torch.bfloat16)
82+
7683
def test_slide_encoder_factory_invalid_name(self):
7784
print("\033[95m" + "Testing Slide Encoder Factory with invalid names" + "\033[0m")
7885
with self.assertRaises(ValueError):

trident/patch_encoder_models/load.py

Lines changed: 112 additions & 0 deletions
@@ -49,6 +49,20 @@ def encoder_factory(model_name, **kwargs):
4949
enc = Phikonv2InferenceEncoder
5050
elif model_name == 'musk':
5151
enc = MuskInferenceEncoder
52+
elif model_name == 'hibou_l':
53+
enc = HibouLInferenceEncoder
54+
elif model_name == 'kaiko-vitb8':
55+
enc = KaikoB8InferenceEncoder
56+
elif model_name == 'kaiko-vitb16':
57+
enc = KaikoB16InferenceEncoder
58+
elif model_name == 'kaiko-vits8':
59+
enc = KaikoS8InferenceEncoder
60+
elif model_name == 'kaiko-vits16':
61+
enc = KaikoS16InferenceEncoder
62+
elif model_name == 'kaiko-vitl14':
63+
enc = KaikoL14InferenceEncoder
64+
elif model_name == 'lunit-vits8':
65+
enc = LunitS8InferenceEncoder
5266
else:
5367
raise ValueError(f"Unknown encoder name {model_name}")
5468

@@ -130,6 +144,7 @@ def forward(self, x):
130144
return_global=self.return_global
131145
)[0] # Forward pass yields (vision_cls, text_cls). We only need vision_cls.
132146

147+
133148
class Conchv1InferenceEncoder(BasePatchEncoder):
134149

135150
def _build(self, with_proj = False, normalize = False, **kwargs):
@@ -235,6 +250,83 @@ def forward_features(self, x):
235250
return out
236251

237252

253+
class HibouLInferenceEncoder(BasePatchEncoder):
254+
def _build(self, **kwargs):
255+
256+
from transformers import AutoModel
257+
from torchvision.transforms import InterpolationMode
258+
259+
self.enc_name = 'hibou_l'
260+
weights_path = get_weights_path('patch', self.enc_name)
261+
262+
if os.path.exists(weights_path):
263+
model = AutoModel.from_pretrained(weights_path)
264+
else:
265+
model = AutoModel.from_pretrained("histai/hibou-L", trust_remote_code=True)
266+
os.makedirs(weights_path, exist_ok=True)
267+
model.save_pretrained(weights_path)
268+
269+
mean, std = get_constants('hibou')
270+
eval_transform = get_eval_transforms(mean, std, target_img_size=224, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True)
271+
precision = torch.float32
272+
273+
return model, eval_transform, precision
274+
275+
def forward(self, x):
276+
out = self.forward_features(x)
277+
out = out.pooler_output
278+
return out
279+
280+
def forward_features(self, x):
281+
out = self.model(pixel_values=x)
282+
return out
283+
284+
285+
class KaikoInferenceEncoder(BasePatchEncoder):
286+
MODEL_NAME = None # set in subclasses
287+
288+
def _build(self, **kwargs):
289+
from torchvision.transforms import InterpolationMode
290+
self.enc_name = f"kaiko-{self.MODEL_NAME}"
291+
weights_path = get_weights_path("patch", self.enc_name)
292+
293+
if os.path.exists(weights_path):
294+
model = torch.load(weights_path, map_location="cpu", weights_only=False)
295+
else:
296+
model = torch.hub.load("kaiko-ai/towards_large_pathology_fms", self.MODEL_NAME, trust_repo=True)
297+
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
298+
torch.save(model, weights_path)
299+
300+
mean, std = get_constants("kaiko")
301+
eval_transform = get_eval_transforms(mean, std, target_img_size=224, center_crop=True, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=True)
302+
precision = torch.float32
303+
304+
return model, eval_transform, precision
305+
306+
def forward(self, x):
307+
return self.model(x)
308+
309+
310+
class KaikoS16InferenceEncoder(KaikoInferenceEncoder):
311+
MODEL_NAME = "vits16"
312+
313+
314+
class KaikoS8InferenceEncoder(KaikoInferenceEncoder):
315+
MODEL_NAME = "vits8"
316+
317+
318+
class KaikoB16InferenceEncoder(KaikoInferenceEncoder):
319+
MODEL_NAME = "vitb16"
320+
321+
322+
class KaikoB8InferenceEncoder(KaikoInferenceEncoder):
323+
MODEL_NAME = "vitb8"
324+
325+
326+
class KaikoL14InferenceEncoder(KaikoInferenceEncoder):
327+
MODEL_NAME = "vitl14"
328+
329+
238330
class ResNet50InferenceEncoder(BasePatchEncoder):
239331
def _build(
240332
self,
@@ -273,7 +365,27 @@ def forward_features(self, x):
273365
out = out[0]
274366
return out
275367

368+
369+
class LunitS8InferenceEncoder(BasePatchEncoder):
370+
def _build(self, **kwargs):
371+
import timm
372+
from timm.data import resolve_model_data_config
373+
from timm.data.transforms_factory import create_transform
374+
375+
self.enc_name = 'lunit-vits8'
376+
377+
model = timm.create_model(
378+
model_name="hf-hub:1aurent/vit_small_patch8_224.lunit_dino",
379+
pretrained=True,
380+
)
381+
382+
data_config = resolve_model_data_config(model)
383+
eval_transform = create_transform(**data_config, is_training=False)
384+
precision = torch.float32
385+
386+
return model, eval_transform, precision
276387

388+
277389
class UNIInferenceEncoder(BasePatchEncoder):
278390
def _build(
279391
self,
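
The Kaiko encoders in this file share one `_build` and differ only in the `MODEL_NAME` class attribute, and the factory selects a class by name. The following sketch illustrates that pattern with stub classes and a dictionary lookup in place of the `elif` chain. It is an alternative layout shown for illustration, not the commit's implementation; `BasePatchEncoder` here is a simplified stub and no weights are loaded.

```python
class BasePatchEncoder:
    """Simplified stub (assumption) of a base class that calls _build on init."""

    def __init__(self, **kwargs):
        self.enc_name = None
        self._build(**kwargs)

    def _build(self, **kwargs):
        raise NotImplementedError


class KaikoEncoder(BasePatchEncoder):
    MODEL_NAME = None  # set in subclasses

    def _build(self, **kwargs):
        # The real encoder would load weights here; the stub only sets the name.
        self.enc_name = f"kaiko-{self.MODEL_NAME}"


def make_kaiko_variant(model_name):
    """Create a KaikoEncoder subclass bound to one architecture name."""
    return type(f"Kaiko_{model_name}", (KaikoEncoder,), {"MODEL_NAME": model_name})


REGISTRY = {
    f"kaiko-{variant}": make_kaiko_variant(variant)
    for variant in ("vits8", "vits16", "vitb8", "vitb16", "vitl14")
}


def encoder_factory(model_name, **kwargs):
    """Look up the encoder class by name and instantiate it."""
    try:
        encoder_class = REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"Unknown encoder name {model_name}") from None
    return encoder_class(**kwargs)


print(encoder_factory("kaiko-vitl14").enc_name)
```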

trident/patch_encoder_models/local_ckpts.json

Lines changed: 7 additions & 0 deletions
@@ -10,6 +10,13 @@
1010
"virchow2": "",
1111
"hoptimus0": "",
1212
"phikon_v2": "./phikon_v2",
13+
"hibou_l": "./hibou_l",
14+
"kaiko-vitb8": "./kaiko_b8",
15+
"kaiko-vitb16": "./kaiko_b16",
16+
"kaiko-vits8": "./kaiko_s8",
17+
"kaiko-vits16": "./kaiko_s16",
18+
"kaiko-vitl14": "./kaiko_l14",
19+
"lunit-vits8": "./lunit_s8",
1320
"conch_v15": "./conchv1_5/pytorch_model_vision.bin",
1421
"custom_encoder": ""
1522
}
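
The new encoders follow a load-from-local-path-or-download-and-cache pattern keyed on entries in this JSON file. The sketch below shows one way such a lookup and cache could work. How Trident's `get_weights_path` actually resolves entries is not shown in this diff, so `resolve_weights_path` and `load_or_fetch` are hypothetical helpers, and resolving relative to the JSON file's folder is an assumption.

```python
import json
import os
import tempfile


def resolve_weights_path(registry_file, encoder_name):
    """Resolve a checkpoint entry relative to the registry file's folder.

    Returns "" when the encoder has no local entry (assumed convention).
    """
    with open(registry_file) as f:
        registry = json.load(f)
    entry = registry.get(encoder_name, "")
    if not entry:
        return ""
    return os.path.normpath(os.path.join(os.path.dirname(registry_file), entry))


def load_or_fetch(path, fetch, load, save):
    """Load from the local path if present; otherwise fetch and cache."""
    if path and os.path.exists(path):
        return load(path)
    obj = fetch()
    if path:
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        save(obj, path)
    return obj


def demo():
    """Exercise the cache with fake weights in a temporary directory."""
    calls = []

    def fetch():
        calls.append("download")
        return "fake-weights"

    def save(obj, path):
        with open(path, "w") as f:
            f.write(obj)

    def load(path):
        with open(path) as f:
            return f.read()

    with tempfile.TemporaryDirectory() as root:
        registry_file = os.path.join(root, "local_ckpts.json")
        with open(registry_file, "w") as f:
            json.dump({"kaiko-vitb8": "./kaiko_b8", "virchow2": ""}, f)
        path = resolve_weights_path(registry_file, "kaiko-vitb8")
        first = load_or_fetch(path, fetch, load, save)   # downloads and caches
        second = load_or_fetch(path, fetch, load, save)  # served from the cache
        empty = resolve_weights_path(registry_file, "virchow2")
    return first, second, calls, empty
```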

trident/patch_encoder_models/utils/constants.py

Lines changed: 9 additions & 1 deletion
@@ -2,6 +2,10 @@
22
IMAGENET_STD = [0.229, 0.224, 0.225]
33
OPENAI_MEAN = [0.48145466, 0.4578275, 0.40821073]
44
OPENAI_STD = [0.26862954, 0.26130258, 0.27577711]
5+
HIBOU_MEAN = [0.7068, 0.5755, 0.722]
6+
HIBOU_STD = [0.195, 0.2316, 0.1816]
7+
KAIKO_MEAN = [0.5, 0.5, 0.5]
8+
KAIKO_STD = [0.5, 0.5, 0.5]
59
NONE_MEAN = None
610
NONE_STD = None
711

@@ -10,7 +14,11 @@ def get_constants(norm='imagenet'):
1014
return IMAGENET_MEAN, IMAGENET_STD
1115
elif norm == 'openai_clip':
1216
return OPENAI_MEAN, OPENAI_STD
17+
elif norm == 'hibou':
18+
return HIBOU_MEAN, HIBOU_STD
1319
elif norm == 'none':
1420
return NONE_MEAN, NONE_STD
21+
elif norm == 'kaiko':
22+
return KAIKO_MEAN, KAIKO_STD
1523
else:
16-
raise ValueError(f"Invalid norm: {norm}")
24+
raise ValueError(f"Invalid norm: {norm}")
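
To show what the new constants do, here is a minimal sketch of channel-wise normalization applied to a single RGB pixel scaled to [0, 1]. The mean and standard deviation values are copied from this diff; the dictionary-based `get_constants` and the `normalize_pixel` helper are illustrative assumptions rather than Trident code.

```python
# Values copied from the constants added in this commit.
NORM_CONSTANTS = {
    "hibou": ([0.7068, 0.5755, 0.722], [0.195, 0.2316, 0.1816]),
    "kaiko": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
}


def get_constants(norm):
    """Dictionary-based variant of the lookup (illustrative only)."""
    try:
        return NORM_CONSTANTS[norm]
    except KeyError:
        raise ValueError(f"Invalid norm: {norm}") from None


def normalize_pixel(rgb, mean, std):
    """Apply (x - mean) / std per channel to one RGB pixel in [0, 1]."""
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]


mean, std = get_constants("kaiko")
print(normalize_pixel([1.0, 0.5, 0.0], mean, std))
```

With a mean and standard deviation of 0.5, the Kaiko normalization maps the [0, 1] range onto [-1, 1].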

trident/slide_encoder_models/load.py

Lines changed: 44 additions & 2 deletions
@@ -38,6 +38,8 @@ def encoder_factory(model_name, pretrained=True, freeze=True, **kwargs):
3838
enc = CHIEFSlideEncoder
3939
elif 'gigapath' in model_name:
4040
enc = GigaPathSlideEncoder
41+
elif 'madeleine' in model_name:
42+
enc = MadeleineSlideEncoder
4143
elif 'abmil' in model_name:
4244
enc = ABMILSlideEncoder
4345
else:
@@ -53,7 +55,8 @@ def encoder_factory(model_name, pretrained=True, freeze=True, **kwargs):
5355
'tcga': 'conch_v15',
5456
'prism': 'virchow',
5557
'chief': 'ctranspath',
56-
'gigapath': 'gigapath'
58+
'gigapath': 'gigapath',
59+
'madeleine': 'conch_v1',
5760
}
5861

5962
####################################################################################################
@@ -286,6 +289,31 @@ def forward(self, batch, device='cuda'):
286289
return z
287290

288291

292+
class MadeleineSlideEncoder(BaseSlideEncoder):
293+
294+
def _build(self, pretrained=True, **kwargs):
295+
296+
assert pretrained, "MadeleineSlideEncoder has no non-pretrained models. Please load with pretrained=True."
297+
298+
self.enc_name = 'madeleine'
299+
weights_path = get_weights_path('slide', self.enc_name)
300+
embedding_dim = 512
301+
302+
try:
303+
from madeleine.models.factory import create_model_from_pretrained
304+
except:
305+
traceback.print_exc()
306+
raise Exception("Please install Madeleine using `pip install git+https://github.com/mahmoodlab/MADELEINE.git`")
307+
308+
model, precision = create_model_from_pretrained(weights_path)
309+
310+
return model, precision, embedding_dim
311+
312+
def forward(self, x, device='cuda'):
313+
z = self.model.encode_he(x['features'], device)
314+
return z
315+
316+
289317
class ThreadsSlideEncoder(BaseSlideEncoder):
290318

291319
def _build(self, pretrained=True, **kwargs):
@@ -297,7 +325,7 @@ def _build(self, pretrained=True, **kwargs):
297325
except:
298326
traceback.print_exc()
299327
raise Exception("Coming Soon! Thanks for your patience.")
300-
328+
301329
return None, None, None
302330

303331
def forward(self, batch, device='cuda', return_raw_attention=False):
@@ -351,6 +379,20 @@ def _build(self, model_name = 'mean-default', **kwargs):
351379
embedding_dim = 1024
352380
elif model_name == 'mean-musk':
353381
embedding_dim = 1024
382+
elif model_name == 'mean-hibou_l':
383+
embedding_dim = 1024
384+
elif model_name == 'mean-kaiko-vits8':
385+
embedding_dim = 384
386+
elif model_name == 'mean-kaiko-vits16':
387+
embedding_dim = 384
388+
elif model_name == 'mean-kaiko-vitb8':
389+
embedding_dim = 768
390+
elif model_name == 'mean-kaiko-vitb16':
391+
embedding_dim = 768
392+
elif model_name == 'mean-kaiko-vitl14':
393+
embedding_dim = 1024
394+
elif model_name == 'mean-lunit-vits8':
395+
embedding_dim = 384
354396
else:
355397
print(f"\033[93mWARNING: Could not automatically infer embedding_dim for mean encoder {self.enc_name}. Setting to None.\033[0m")
356398
embedding_dim = None
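
The mean-pooling slide encoder infers its embedding dimension from a name of the form `mean-<patch_encoder>`. A table-driven lookup keyed on the patch encoder name is one way to keep the two naming schemes in sync. The sketch below uses the dimensions given in this diff for the newly added encoders; `PATCH_EMBEDDING_DIMS` and `infer_mean_embedding_dim` are hypothetical names.

```python
# Embedding dimensions for the encoders added in this commit, as given in the diff.
PATCH_EMBEDDING_DIMS = {
    "hibou_l": 1024,
    "kaiko-vits8": 384,
    "kaiko-vits16": 384,
    "kaiko-vitb8": 768,
    "kaiko-vitb16": 768,
    "kaiko-vitl14": 1024,
    "lunit-vits8": 384,
}


def infer_mean_embedding_dim(model_name):
    """Return the dimension for 'mean-<patch_encoder>', or None if unknown."""
    prefix = "mean-"
    if not model_name.startswith(prefix):
        return None
    patch_encoder = model_name[len(prefix):]
    return PATCH_EMBEDDING_DIMS.get(patch_encoder)


print(infer_mean_embedding_dim("mean-kaiko-vitb16"))
```

Deriving the key by stripping the `mean-` prefix means a patch encoder name only has to be spelled once.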
Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
11
{
2-
"chief": "./CHIEF"
2+
"chief": "./CHIEF",
3+
"madeleine": "./MADELEINE"
34
}
