Skip to content

Commit e9a6c59

Browse files
clear unit tests
1 parent b9dc848 commit e9a6c59

File tree

2 files changed

+2
-20
lines changed

2 files changed

+2
-20
lines changed

tests/test_slide_encoders.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import unittest
2-
from unittest.mock import patch, MagicMock
32
import torch
4-
import warnings
53

64
import sys; sys.path.append('../')
75

@@ -32,20 +30,6 @@ def _test_encoder_forward(self, encoder, batch, expected_precision):
3230
self.assertTrue(output.shape[-1] == encoder.embedding_dim)
3331
print("\033[94m"+ f" {encoder.__class__.__name__} forward pass success with output shape {output.shape}" + "\033[0m")
3432

35-
def test_threads_encoder_initialization(self):
36-
sample_batch = {
37-
'features': torch.randn(1, 100, 768),
38-
'coords': torch.randn(1, 100, 2),
39-
}
40-
self._test_encoder_forward(ThreadsSlideEncoder(), sample_batch, torch.bfloat16)
41-
42-
# def test_trilobite_encoder_initialization(self):
43-
# sample_batch = {
44-
# 'features': torch.randn(1, 100, 768),
45-
# 'coords': torch.randn(1, 100, 2),
46-
# }
47-
# self._test_encoder_forward(TrilobiteSlideEncoder(), sample_batch, torch.float16)
48-
4933
def test_prism_encoder_initialization(self):
5034
sample_batch = {
5135
'features': torch.randn(1, 100, 2560),
@@ -79,13 +63,11 @@ def test_slide_encoder_factory_with_valid_names(self):
7963
print("\033[95m" + "Testing Slide Encoder Factory with valid names" + "\033[0m")
8064
# Test factory method for valid model names
8165
for model_name, expected_class in [
82-
('mean-conch-v15', MeanSlideEncoder),
66+
('mean-conch_v15', MeanSlideEncoder),
8367
('mean-blahblah', MeanSlideEncoder),
8468
('prism', PRISMSlideEncoder),
8569
('chief', CHIEFSlideEncoder),
8670
('gigapath', GigaPathSlideEncoder),
87-
('threads', ThreadsSlideEncoder),
88-
# ('trilobite==trimodal-200k-new_model==epoch101', TrilobiteSlideEncoder),
8971
('titan', TitanSlideEncoder),
9072
]:
9173
encoder = encoder_factory(model_name)

trident/patch_encoder_models/load.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _build(self, inference_aug = False, with_proj = False, out_norm = False, ret
101101
from musk import utils, modeling
102102
except:
103103
traceback.print_exc()
104-
raise Exception("Please install MUSK `pip install git+https://github.com/lilab-stanford/MUSK`")
104+
raise Exception("Please install MUSK `pip install fairscale git+https://github.com/lilab-stanford/MUSK`")
105105

106106
try:
107107
from timm.models import create_model

0 commit comments

Comments
 (0)