
Commit d3a281b

re-add bench
1 parent 8ffc85a commit d3a281b

Some content is hidden: large commits have some content hidden by default.

56 files changed (+5567, -0 lines changed)

src/hest/bench/__init__.py

Lines changed: 2 additions & 0 deletions
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .data_modules import st_dataset
from .training.predict_expression import benchmark_encoder
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .models import *
Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
from .models.custom_weight_loaders import load_pretrained_weights_into_model_cocavit
from .models.vision_transformer_custom import vit_large_w_pooler
from .models.densenetbackbone import DenseNetBackbone
from .models.remedis_models import resnet152_remedis
from .models.phikon import ibot_vit
from .models import TimmCNNEncoder, TimmViTEncoder, HFViTEncoder
import timm
import os
from functools import partial
import torch
from loguru import logger
from .model_registry import _MODEL_CONFIGS
from .utils import get_eval_transforms, get_constants
import torch.nn as nn
import torchvision.models as models
from .models.post_processor import CLIPVisionModelPostProcessor

def get_encoder(model_name, overwrite_kwargs={}, img_size=224):
    config = _MODEL_CONFIGS[model_name]
    for k in overwrite_kwargs:
        if k not in config:
            raise ValueError(f"Invalid overwrite key: {k}")
        config[k] = overwrite_kwargs[k]
    model, eval_transform = build_model(config)
    mean, std = get_constants(config['img_norm'])

    if eval_transform is None:
        eval_transform = get_eval_transforms(mean, std, target_img_size=img_size)
    return model, eval_transform, config

def load_resnet18_ciga(ckpt_path):
    def clean_state_dict_ciga(state_dict):
        state_dict = {k.replace("model.resnet.", ''): v for k, v in state_dict.items() if 'fc.' not in k}
        return state_dict
    base_encoder = models.resnet18(weights=None)
    base_encoder.fc = nn.Identity()
    state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
    state_dict = clean_state_dict_ciga(state_dict)
    base_encoder.load_state_dict(state_dict, strict=True)
    return base_encoder


def build_model(config):
    logger.info(f"Building model with config: {config['name']}")
    load_state_dict = False
    eval_transform = None
    if config.get("checkpoint_path", None) is not None:
        if not os.path.exists(config["checkpoint_path"]):
            if os.environ.get("CHECKPOINT_PATH", None) is not None:
                config["checkpoint_path"] = os.environ["CHECKPOINT_PATH"]
            else:
                raise ValueError(f"checkpoint_path does not exist: {config['checkpoint_path']} and no CHECKPOINT_PATH environment variable set")
        load_state_dict = True
    if config['loader'] == 'timm_wrapper_cnn':
        # uses timm to load a CNN model, then wraps it in a custom module that adds pooling
        model = TimmCNNEncoder(**config['loader_kwargs'])
    elif config['loader'] == 'hf_wrapper_vit':
        model = HFViTEncoder(**config['loader_kwargs'])
    elif config['loader'] == 'conch_openclip_custom':
        from conch.open_clip_custom import create_model_from_pretrained
        model, _ = create_model_from_pretrained(**config['loader_kwargs'], checkpoint_path=config["checkpoint_path"])
        model.forward = partial(model.encode_image, proj_contrast=False, normalize=False)
    elif config['loader'] == 'timm':
        # uses timm to load a model
        model = timm.create_model(**config['loader_kwargs'])
    elif config['loader'] == 'ctranspath_loader':
        from .models.ctran import ctranspath
        ckpt_path = config["checkpoint_path"]
        assert os.path.isfile(ckpt_path)
        model = ctranspath(img_size=224)
        model.head = nn.Identity()
        state_dict = torch.load(ckpt_path)['model']
        state_dict = {key: val for key, val in state_dict.items() if 'attn_mask' not in key}
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        load_state_dict = False
    ### Kimia Net
    elif config['loader'] == 'kimianet_loader':
        ckpt_path = config["checkpoint_path"]
        assert os.path.isfile(ckpt_path)
        model = models.densenet121()
        state_dict = torch.load(ckpt_path, map_location='cpu')
        state_dict = {"features." + k[len("module.model.0."):]: v for k, v in state_dict.items() if "fc_4" not in k}
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        assert missing_keys == ['classifier.weight', 'classifier.bias']
        model = DenseNetBackbone(model)
        load_state_dict = False
    elif config['loader'] == 'ciga_loader':
        model = load_resnet18_ciga(config["checkpoint_path"])
        load_state_dict = False
    elif config['loader'] == 'remedis_loader':
        ckpt_path = config["checkpoint_path"]
        model = resnet152_remedis(ckpt_path=ckpt_path, pretrained=True)
        load_state_dict = False
    elif config['loader'] == 'plip_loader':
        from transformers import CLIPImageProcessor, CLIPVisionModel
        model_name = "vinid/plip"
        img_transforms_clip = CLIPImageProcessor.from_pretrained(model_name)
        model = CLIPVisionModel.from_pretrained(model_name)  # Use for feature extraction
        model = CLIPVisionModelPostProcessor(model)
        def _eval_transform(img):
            return img_transforms_clip(img, return_tensors='pt', padding=True)['pixel_values'].squeeze(0)
        eval_transform = _eval_transform
    elif config['loader'] == 'ibot_uni':
        ckpt_path = config["checkpoint_path"]
        model = ibot_vit.iBOTViT(architecture="vit_base_pancan", encoder="teacher", weights_path=ckpt_path)

        load_state_dict = False
    elif config['loader'] == 'pathchat':
        kwargs = {}
        add_kwargs = {'pooler_n_queries_contrast': 1}
        add_kwargs['legacy'] = False
        kwargs.update(add_kwargs)
        model = vit_large_w_pooler(**kwargs, init_values=1e-6)
        ckpt_path = config["checkpoint_path"]
        checkpoint = ckpt_path.split('/')[-1]
        enc_name = os.path.dirname(ckpt_path).split('/')[-1]
        assets_dir = os.path.dirname(os.path.dirname(ckpt_path))
        load_pretrained_weights_into_model_cocavit(
            model, enc_name, checkpoint, assets_dir)

        load_state_dict = False

    elif config['loader'] == 'gigapath':
        from torchvision import transforms
        model = timm.create_model(model_name='vit_giant_patch14_dinov2',
                                  **{'img_size': 224, 'in_chans': 3,
                                     'patch_size': 16, 'embed_dim': 1536,
                                     'depth': 40, 'num_heads': 24, 'init_values': 1e-05,
                                     'mlp_ratio': 5.33334, 'num_classes': 0})
        ckpt_path = config["checkpoint_path"]
        state_dict = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
        eval_transform = transforms.Compose(
            [
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
            ]
        )
        load_state_dict = False

    else:
        raise ValueError(f"Unsupported loader type: {config['loader']}")
    if load_state_dict:
        ckpt_path = config["checkpoint_path"]
        strict = config.get("load_state_dict_strict", False)
        logger.info(f"Loading model from checkpoint: {ckpt_path}")
        logger.info(f"load_state_dict_strict: {strict}")
        missing, unexpected = model.load_state_dict(torch.load(ckpt_path, map_location="cpu"),
                                                    strict=strict)
        logger.info(f"Missing keys: {missing}")
        logger.info(f"Unexpected keys: {unexpected}")
    return model, eval_transform
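
As a quick orientation, the sketch below shows how get_encoder could be called to obtain an encoder and its eval transform. The import path, the "resnet50" config name, and the dummy tile are assumptions for illustration, not values taken from this commit.

import torch
# module path is an assumption; get_encoder is the function defined above
from hest.bench.encoder_builder import get_encoder

# "resnet50" stands in for any config name registered under pretrained_configs/
model, eval_transform, config = get_encoder("resnet50", img_size=224)
model.eval()

dummy_tile = torch.randn(1, 3, 224, 224)  # placeholder batch of one 224x224 RGB patch
with torch.no_grad():
    features = model(dummy_tile)
print(features.shape)

In practice eval_transform would be applied to the raw image patches before batching, since it carries the normalization constants selected by config['img_norm'].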
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
from pathlib import Path
import json
import re

_MODEL_CONFIG_PATHS = [Path(__file__).parent / "pretrained_configs/", Path(__file__).parent.parent / "private/pretrained_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs

def _natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]

def _rescan_model_configs():
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    for cf in config_files:
        with open(cf, 'r') as f:
            model_cfg = json.load(f)
        _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}


_rescan_model_configs()  # initial populate of model config registry
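
The registry keys each JSON file by its stem, so a config dropped into pretrained_configs/ becomes addressable by name. A minimal sketch of creating such a file follows; the "resnet50" name and every field value are illustrative assumptions, and only the key names (name, loader, loader_kwargs, img_norm) come from what build_model reads above.

import json
from pathlib import Path

# hypothetical config; values are placeholders, key names mirror build_model()
cfg = {
    "name": "resnet50",
    "loader": "timm_wrapper_cnn",
    "loader_kwargs": {"model_name": "resnet50"},  # accepted kwargs depend on TimmCNNEncoder, not shown here
    "img_norm": "imagenet",
}
Path("pretrained_configs").mkdir(exist_ok=True)
with open(Path("pretrained_configs") / "resnet50.json", "w") as f:
    json.dump(cfg, f, indent=2)
# after _rescan_model_configs() runs, this entry is available as _MODEL_CONFIGS["resnet50"]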
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
from .vision_transformer_latest import *
from .vision_transformer_dinov2 import (vit_small as vit_small_dinov2,
                                        vit_base as vit_base_dinov2,
                                        vit_large as vit_large_dinov2,
                                        clean_state_dict as clean_state_dict_dinov2)
from .vision_transformer_ijepa import (vit_huge as vit_huge_ijepa, clean_state_dict as clean_state_dict_ijepa)
from .timm_wrappers import *
from .hf_wrappers import *
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@

from timm_ctp.models.layers.helpers import to_2tuple
from timm_ctp import create_model as ctp_create_model
import torch.nn as nn
from functools import partial
import pdb

class ConvStem(nn.Module):

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()

        assert patch_size == 4
        assert embed_dim % 8 == 0

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten


        stem = []
        input_dim, output_dim = 3, embed_dim // 8
        for l in range(2):
            stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
            stem.append(nn.BatchNorm2d(output_dim))
            stem.append(nn.ReLU(inplace=True))
            input_dim = output_dim
            output_dim *= 2
        stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
        self.proj = nn.Sequential(*stem)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

def ctranspath(img_size=224, **kwargs):
    model = ctp_create_model('swin_tiny_patch4_window7_224',
                             embed_layer=ConvStem,
                             pretrained=False,
                             img_size=img_size,
                             **kwargs)
    return model
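
ConvStem replaces the Swin patch embedding with two stride-2 convolutions, which downsample by a factor of 4 and so match the asserted patch_size of 4. It is a standalone nn.Module, so its output shape can be sanity-checked without a CTransPath checkpoint; in the sketch below, embed_dim=96 is an assumption chosen to match swin_tiny, and the random tensor is a placeholder input.

import torch

stem = ConvStem(img_size=224, patch_size=4, embed_dim=96)  # 96 matches swin_tiny's embed dim
tokens = stem(torch.randn(1, 3, 224, 224))                 # placeholder 224x224 RGB tile
print(tokens.shape)  # torch.Size([1, 3136, 96]); (224 / 4) ** 2 = 3136 patch tokens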
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
from timm_ctp.models.layers.helpers import to_2tuple
import timm_ctp
import torch.nn as nn
from functools import partial
import pdb

class ConvStem(nn.Module):

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()

        assert patch_size == 4
        assert embed_dim % 8 == 0

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten


        stem = []
        input_dim, output_dim = 3, embed_dim // 8
        for l in range(2):
            stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
            stem.append(nn.BatchNorm2d(output_dim))
            stem.append(nn.ReLU(inplace=True))
            input_dim = output_dim
            output_dim *= 2
        stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
        self.proj = nn.Sequential(*stem)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

def ctranspath(img_size=224, **kwargs):
    model = timm_ctp.create_model('swin_tiny_patch4_window7_224',
                                  embed_layer=ConvStem,
                                  pretrained=False,
                                  img_size=img_size,
                                  **kwargs)
    return model
