Skip to content

Commit d82fea1

Browse files
Merge pull request #62 from mahmoodlab/image_readers
Image readers
2 parents fe5075c + c943396 commit d82fea1

28 files changed

+2054
-648
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ This project was developed by the [Mahmood Lab](https://faisal.ai/) at Harvard M
1919
- **Slide Feature Extraction**: Extract slide embeddings from 5+ slide foundation models, including [Threads](https://arxiv.org/abs/2501.16652) (coming soon!), [Titan](https://arxiv.org/abs/2411.19666), and [GigaPath](https://www.nature.com/articles/s41586-024-07441-w).
2020

2121
### Updates:
22+
- 04.25: Native support for PIL.Image and CuCIM (use `wsi = load_wsi(xxx.svs)`). Support for seg + patch encoding without Internet.
2223
- 04.25: Remove artifacts from the tissue segmentation with `--remove_artifacts`. Works well for H&E.
2324
- 02.25: New image converter from `czi`, `png`, etc to `tiff`.
2425
- 02.25: Support for [GrandQC](https://www.nature.com/articles/s41467-024-54769-y) tissue vs. background segmentation.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "trident"
3-
version = "0.0.5"
3+
version = "0.1.0"
44
description = "A package for preprocessing whole-slide images."
55
authors = [
66
"Andrew Zhang <andrewzh@mit.edu>",
@@ -43,6 +43,7 @@ build-backend = "poetry.core.masonry.api"
4343
[tool.poetry.package]
4444
include = [
4545
{ format = "file", path = "trident/slide_encoder_models/local_ckpts.json" },
46-
{ format = "file", path = "trident/patch_encoder_models/local_ckpts.json" }
46+
{ format = "file", path = "trident/patch_encoder_models/local_ckpts.json" },
47+
{ format = "file", path = "trident/segmentation_models/local_ckpts.json" },
4748
]
4849
include_package_data = true

run_batch_of_slides.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import argparse
1111
import torch
1212
from trident import Processor
13+
from trident import WSIReaderType
1314

1415

1516
def parse_arguments():
@@ -23,6 +24,15 @@ def parse_arguments():
2324
choices=['cache', 'seg', 'coords', 'feat', 'all'],
2425
help='Task to run: cache, seg (segmentation), coords (save tissue coordinates), img (save tissue images), feat (extract features)')
2526
parser.add_argument('--job_dir', type=str, required=True, help='Directory to store outputs')
27+
parser.add_argument('--wsi_cache', type=str, default=None,
28+
help='Directory to copy slides to for local processing')
29+
parser.add_argument('--clear_cache', action='store_true', default=False,
30+
help='Delete slides from cache after processing')
31+
parser.add_argument('--skip_errors', action='store_true', default=False,
32+
help='Skip errored slides and continue processing')
33+
parser.add_argument('--max_workers', type=int, default=None, help='Maximum number of workers. Set to 0 to use main process.')
34+
35+
# Slide-related arguments
2636
parser.add_argument('--wsi_dir', type=str, required=True,
2737
help='Directory containing WSI files (no nesting allowed)')
2838
parser.add_argument('--wsi_ext', type=str, nargs='+', default=None,
@@ -31,12 +41,9 @@ def parse_arguments():
3141
help='Custom keys used to store the resolution as MPP (micron per pixel) in your list of whole-slide image.')
3242
parser.add_argument('--custom_list_of_wsis', type=str, default=None,
3343
help='Custom list of WSIs specified in a csv file.')
34-
parser.add_argument('--wsi_cache', type=str, default=None,
35-
help='Directory to copy slides to for local processing')
36-
parser.add_argument('--clear_cache', action='store_true', default=False,
37-
help='Delete slides from cache after processing')
38-
parser.add_argument('--skip_errors', action='store_true', default=False,
39-
help='Skip errored slides and continue processing')
44+
parser.add_argument('--reader_type', type=str, choices=['openslide', 'image', 'cucim'], default=None,
45+
help='Force the use of a specific WSI image reader. Options are ["openslide", "image", "cucim"]. Defaults to None (auto-determine which reader to use).')
46+
4047
# Segmentation arguments
4148
parser.add_argument('--segmenter', type=str, default='hest',
4249
choices=['hest', 'grandqc'],
@@ -66,6 +73,16 @@ def parse_arguments():
6673
'kaiko-vits8', 'kaiko-vits16', 'kaiko-vitb8', 'kaiko-vitb16',
6774
'kaiko-vitl14', 'lunit-vits8'],
6875
help='Patch encoder to use')
76+
parser.add_argument(
77+
'--patch_encoder_ckpt_path', type=str, default=None,
78+
help=(
79+
"Optional local path to a patch encoder checkpoint (.pt, .pth, .bin, or .safetensors). "
80+
"This is only needed in offline environments (e.g., compute clusters without internet). "
81+
"If not provided, models are downloaded automatically from Hugging Face. "
82+
"You can also specify local paths via the model registry at "
83+
"`./trident/patch_encoder_models/local_ckpts.json`."
84+
)
85+
)
6986
parser.add_argument('--slide_encoder', type=str, default=None,
7087
choices=['threads', 'titan', 'prism', 'gigapath', 'chief', 'madeleine',
7188
'mean-virchow', 'mean-virchow2', 'mean-conch_v1', 'mean-conch_v15', 'mean-ctranspath',
@@ -89,7 +106,9 @@ def initialize_processor(args):
89106
clear_cache=args.clear_cache,
90107
skip_errors=args.skip_errors,
91108
custom_mpp_keys=args.custom_mpp_keys,
92-
custom_list_of_wsis=args.custom_list_of_wsis
109+
custom_list_of_wsis=args.custom_list_of_wsis,
110+
max_workers=args.max_workers,
111+
reader_type=args.reader_type
93112
)
94113

95114
def run_task(processor, args):
@@ -107,12 +126,10 @@ def run_task(processor, args):
107126
segmentation_model = segmentation_model_factory(
108127
args.segmenter,
109128
confidence_thresh=args.seg_conf_thresh,
110-
device=f'cuda:{args.gpu}'
111129
)
112130
if args.remove_artifacts:
113131
artifact_remover_model = segmentation_model_factory(
114132
'grandqc_artifact',
115-
device=f'cuda:{args.gpu}'
116133
)
117134
else:
118135
artifact_remover_model = None
@@ -122,7 +139,8 @@ def run_task(processor, args):
122139
segmentation_model,
123140
seg_mag=segmentation_model.target_mag,
124141
holes_are_tissue= not args.remove_holes,
125-
artifact_remover_model=artifact_remover_model
142+
artifact_remover_model=artifact_remover_model,
143+
device=f'cuda:{args.gpu}',
126144
)
127145
elif args.task == 'coords':
128146
# Minimal example for tissue patching:
@@ -139,7 +157,7 @@ def run_task(processor, args):
139157
# Minimal example for feature extraction:
140158
# python run_batch_of_slides.py --task feat --wsi_dir wsis --job_dir trident_processed --patch_encoder uni_v1 --mag 20 --patch_size 256
141159
from trident.patch_encoder_models.load import encoder_factory
142-
encoder = encoder_factory(args.patch_encoder)
160+
encoder = encoder_factory(args.patch_encoder, weights_path=args.patch_encoder_ckpt_path)
143161
processor.run_patch_feature_extraction_job(
144162
coords_dir=args.coords_dir or f'{args.mag}x_{args.patch_size}px_{args.overlap}px_overlap',
145163
patch_encoder=encoder,

run_single_slide.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import argparse
1010
import os
1111

12-
from trident import OpenSlideWSI
12+
from trident import load_wsi
1313
from trident.segmentation_models import segmentation_model_factory
1414
from trident.patch_encoder_models import encoder_factory
1515

@@ -53,19 +53,19 @@ def process_slide(args):
5353

5454
# Initialize the WSI
5555
print(f"Processing slide: {args.slide_path}")
56-
slide = OpenSlideWSI(slide_path=args.slide_path, lazy_init=False, custom_mpp_keys=args.custom_mpp_keys)
56+
slide = load_wsi(slide_path=args.slide_path, lazy_init=False, custom_mpp_keys=args.custom_mpp_keys)
5757

5858
# Step 1: Tissue Segmentation
5959
print("Running tissue segmentation...")
6060
segmentation_model = segmentation_model_factory(
6161
model_name=args.segmenter,
6262
confidence_thresh=args.seg_conf_thresh,
63-
device=f"cuda:{args.gpu}"
6463
)
6564
slide.segment_tissue(
6665
segmentation_model=segmentation_model,
6766
target_mag=segmentation_model.target_mag,
68-
job_dir=args.job_dir
67+
job_dir=args.job_dir,
68+
device=f"cuda:{args.gpu}"
6969
)
7070
print(f"Tissue segmentation completed. Results saved to {args.job_dir}contours_geojson and {args.job_dir}contours")
7171

tests/test_encoder_same_local_hf.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import torch
2+
import numpy as np
3+
from PIL import Image
4+
import unittest
5+
import json
6+
from pathlib import Path
7+
8+
try:
9+
import lovely_tensors; lovely_tensors.monkey_patch()
10+
except:
11+
pass
12+
13+
import sys; sys.path.append('../')
14+
from trident.patch_encoder_models import *
15+
16+
17+
class TestEncoderConsistency(unittest.TestCase):
18+
@classmethod
19+
def setUpClass(cls):
20+
cls.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21+
cls.dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
22+
23+
def _load_encoder(self, encoder_name, source, weights_path=None, **kwargs):
24+
print(f" 🔧 Loading {encoder_name} ({source})")
25+
encoder = encoder_factory(encoder_name, weights_path=weights_path, **kwargs)
26+
encoder = encoder.to(self.device)
27+
encoder.eval()
28+
return encoder
29+
30+
def _run_forward(self, encoder, encoder_name, source):
31+
with torch.inference_mode(), torch.amp.autocast('cuda', dtype=encoder.precision):
32+
dummy_input = encoder.eval_transforms(self.dummy_image).to(self.device).unsqueeze(dim=0)
33+
output = encoder(dummy_input)
34+
print(f" 📐 Output shape from {source}: {tuple(output.shape)}")
35+
return output
36+
37+
def _compare_architecture(self, enc1, enc2):
38+
keys1 = set(enc1.state_dict().keys())
39+
keys2 = set(enc2.state_dict().keys())
40+
if keys1 != keys2:
41+
print("\033[1;33m⚠️ Architecture mismatch in keys:\033[0m")
42+
print(" Only in default :", keys1 - keys2)
43+
print(" Only in local :", keys2 - keys1)
44+
return False
45+
return True
46+
47+
def _compare_weights(self, enc1, enc2):
48+
diffs = []
49+
for k in enc1.state_dict().keys():
50+
w1 = enc1.state_dict()[k]
51+
w2 = enc2.state_dict()[k]
52+
if not torch.allclose(w1, w2, atol=1e-5, rtol=1e-4):
53+
abs_diff = (w1 - w2).abs()
54+
max_diff = abs_diff.max().item()
55+
mean_diff = abs_diff.mean().item()
56+
diffs.append((k, max_diff, mean_diff))
57+
if diffs:
58+
print("\033[1;33m⚠️ Weight differences found:\033[0m")
59+
for k, max_d, mean_d in sorted(diffs, key=lambda x: -x[1])[:10]:
60+
print(f" 🔍 {k:<50} max diff: {max_d:.4e}, mean diff: {mean_d:.4e}")
61+
return False
62+
return True
63+
64+
65+
def generate_encoder_test(encoder_name, weights_path, **kwargs):
66+
def test(self):
67+
header = f"🧪 TEST: {encoder_name}"
68+
if kwargs:
69+
kwarg_str = ', '.join(f"{k}={v}" for k, v in kwargs.items())
70+
header += f" ({kwarg_str})"
71+
print(f"\n\033[1;36m{'=' * len(header)}\n{header}\n{'=' * len(header)}\033[0m")
72+
73+
# Load models
74+
enc_default = self._load_encoder(encoder_name, source="default", **kwargs)
75+
enc_local = self._load_encoder(encoder_name, source="local checkpoint", weights_path=weights_path, **kwargs)
76+
77+
# # Compare architecture
78+
# arch_match = self._compare_architecture(enc_default, enc_local)
79+
# self.assertTrue(arch_match, f"Architecture mismatch in {encoder_name}")
80+
81+
# # Compare weights
82+
# weights_match = self._compare_weights(enc_default, enc_local)
83+
# self.assertTrue(weights_match, f"Weight mismatch in {encoder_name}")
84+
85+
# Run inference
86+
out_default = self._run_forward(enc_default, encoder_name, source="default")
87+
out_local = self._run_forward(enc_local, encoder_name, source="local checkpoint")
88+
89+
if torch.allclose(out_default, out_local, atol=1e-5, rtol=1e-4):
90+
print(f"\033[1;32m✅ Outputs match for {encoder_name}\033[0m")
91+
else:
92+
diff = (out_default - out_local).abs().max().item()
93+
print(f"\033[1;31m❌ Outputs do NOT match (max abs diff = {diff:.4e})\033[0m")
94+
self.fail(f"Output mismatch for {encoder_name} with kwargs={kwargs}")
95+
return test
96+
97+
98+
# Dynamically register tests before unittest.main()
99+
def register_tests():
100+
ckpt_path = Path('../trident/patch_encoder_models/local_ckpts_guillaume.json')
101+
with open(ckpt_path) as f:
102+
local_ckpts = json.load(f)
103+
104+
# local ckpt not supported
105+
local_ckpts.pop('musk')
106+
local_ckpts.pop('custom_encoder')
107+
local_ckpts.pop('hibou_l')
108+
109+
for encoder_name, path in local_ckpts.items():
110+
test_name = f"test_{encoder_name}"
111+
test_fn = generate_encoder_test(encoder_name, path)
112+
setattr(TestEncoderConsistency, test_name, test_fn)
113+
114+
115+
register_tests()
116+
117+
if __name__ == '__main__':
118+
unittest.main()

tests/test_openslidewsi.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch # Check for CUDA availability
44

55
import sys; sys.path.append('../')
6-
from trident import OpenSlideWSI
6+
from trident import load_wsi
77
from trident.segmentation_models import segmentation_model_factory
88
from trident.patch_encoder_models import encoder_factory
99

@@ -19,8 +19,8 @@ class TestOpenSlideWSI(unittest.TestCase):
1919
HF_REPO = "MahmoodLab/unit-testing"
2020
TEST_SLIDE_FILENAMES = [
2121
"394140.svs",
22-
# "TCGA-AN-A0XW-01Z-00-DX1.811E11E7-FA67-46BB-9BC6-1FD0106B789D.svs",
23-
# "TCGA-B6-A0IJ-01Z-00-DX1.BF2E062F-06DA-4CA8-86C4-36674C035CAA.svs"
22+
"TCGA-AN-A0XW-01Z-00-DX1.811E11E7-FA67-46BB-9BC6-1FD0106B789D.svs",
23+
"TCGA-B6-A0IJ-01Z-00-DX1.BF2E062F-06DA-4CA8-86C4-36674C035CAA.svs"
2424
]
2525
TEST_OUTPUT_DIR = "test_single_slide_processing/"
2626
TEST_PATCH_ENCODER = "uni_v1"
@@ -50,11 +50,11 @@ def test_integration(self):
5050
for slide_filename in self.TEST_SLIDE_FILENAMES:
5151
with self.subTest(slide=slide_filename):
5252
slide_path = os.path.join(self.local_wsi_dir, slide_filename)
53-
slide = OpenSlideWSI(slide_path=slide_path, lazy_init=False)
53+
slide = load_wsi(slide_path=slide_path, lazy_init=False)
5454

5555
# Step 1: Tissue segmentation
56-
segmentation_model = segmentation_model_factory("hest", device=self.TEST_DEVICE)
57-
slide.segment_tissue(segmentation_model=segmentation_model, target_mag=10, job_dir=self.TEST_OUTPUT_DIR)
56+
segmentation_model = segmentation_model_factory("hest")
57+
slide.segment_tissue(segmentation_model=segmentation_model, target_mag=10, job_dir=self.TEST_OUTPUT_DIR, device=self.TEST_DEVICE)
5858

5959
# Step 2: Tissue coordinate extraction
6060
coords_path = slide.extract_tissue_coords(

tests/test_patch_encoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,4 @@ def test_lunitvits8_forward(self):
9898

9999

100100
if __name__ == '__main__':
101-
unittest.main()
101+
unittest.main()

tests/test_processor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,11 @@ def test_tissue_processing(self):
8686
wsi_ext=self.TEST_WSI_EXT
8787
)
8888

89-
segmentation_model = segmentation_model_factory('hest', device=f'cuda:{self.TEST_GPU_INDEX}')
89+
segmentation_model = segmentation_model_factory('hest')
9090
self.processor.run_segmentation_job(
9191
segmentation_model=segmentation_model,
92-
seg_mag=5
92+
seg_mag=5,
93+
device=f'cuda:{self.TEST_GPU_INDEX}'
9394
)
9495
output_dirs = ["contours", "contours_geojson"]
9596
for dir_name in output_dirs:

tests/test_segmentation_models.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ def setUp(self):
2323
def _test_forward(self, encoder_name):
2424
print("\033[95m" + f"Testing {encoder_name} forward pass" + "\033[0m")
2525
device = 'cuda' if torch.cuda.is_available() else 'cpu'
26-
encoder = segmentation_model_factory(encoder_name, device=device)
26+
encoder = segmentation_model_factory(encoder_name).to(device)
2727

2828
self.dummy_image = np.random.randint(0, 256, (encoder.input_size, encoder.input_size, 3), dtype=np.uint8)
2929
self.dummy_image = Image.fromarray(self.dummy_image)
3030

3131
with torch.inference_mode():
32-
dummy_input = encoder.eval_transforms(self.dummy_image).unsqueeze(dim=0)
32+
dummy_input = encoder.eval_transforms(self.dummy_image).unsqueeze(dim=0).to(device)
3333
output = encoder(dummy_input)
3434

3535
self.assertIsNotNone(output)
@@ -39,7 +39,8 @@ def _test_forward(self, encoder_name):
3939
def test_hest(self):
4040
self._test_forward('hest')
4141

42-
# Add more segmentation models here
42+
def test_grandqc(self):
43+
self._test_forward('grandqc')
4344

4445
if __name__ == '__main__':
4546
unittest.main()

tests/test_slide_encoders.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import torch
33

44
import sys; sys.path.append('../')
5-
6-
# New imports to test
75
from trident.slide_encoder_models import *
86

97
"""

trident/Converter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
# PIL
1717
PIL_EXTENSIONS = {'.png', '.jpg', '.jpeg'}
1818

19+
# OpenSlide
20+
OPENSLIDE_EXTENSIONS = {'.svs', '.tif', '.dcm', '.vms', '.vmu', '.ndpi', '.scn', '.mrxs', '.tiff', '.svslide', '.bif', '.czi'}
21+
1922
# Combined with CZI
2023
SUPPORTED_EXTENSIONS = BIOFORMAT_EXTENSIONS | PIL_EXTENSIONS | {'.czi'}
2124

0 commit comments

Comments
 (0)