
Commit 80d7843

fixed embedding extraction
1 parent a2ffefe · commit 80d7843

7 files changed (+58, -10 lines)

load-asap.sh

Lines changed: 1 addition & 1 deletion
@@ -9,4 +9,4 @@ source deactivate
 source deactivate
 source activate 2D-VQ-AE-2
 
-export PYTHONPATH=$PYTHONPATH:~/.conda/envs/2D-VQ-AE-2/lib/python3.9/site-packages/
+export PYTHONPATH=~/.conda/envs/2D-VQ-AE-2/lib/python3.9/site-packages/:$PYTHONPATH
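A note on the intent here (my reading, not stated in the commit): PYTHONPATH entries are searched in order, so prepending the env's site-packages lets its packages shadow same-named packages elsewhere on the path instead of being shadowed by them. A quick way to check which copy wins, assuming the 2D-VQ-AE-2 environment from load-asap.sh is active:

```python
# Illustrative check only; "numpy" stands in for any package installed both in the
# conda env and elsewhere on PYTHONPATH.
import numpy

# With the prepended entry, this should resolve inside
# ~/.conda/envs/2D-VQ-AE-2/lib/python3.9/site-packages/
print(numpy.__file__)
```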

scripts/extract_embeddings/conf/camelyon16_embeddings.yaml

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+defaults:
+  - override hydra/launcher: submitit_snellius
+
 run_path: ???
 force_outputs_or_multirun_root: True
 dataset_target_hotswap: datamodules.camelyon16.CAMELYON16SlicePatchDataSet
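For context (my sketch, not the repository's actual entry point): the added defaults block tells Hydra to swap its launcher for the submitit_snellius config added below, so a --multirun invocation gets submitted to SLURM through submitit instead of running locally. A minimal consumer of this config could look like:

```python
# Hypothetical Hydra app reading conf/camelyon16_embeddings.yaml; the function and
# file names here are illustrative only.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="conf", config_name="camelyon16_embeddings")
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))  # run_path, dataset_target_hotswap, ...

if __name__ == "__main__":
    main()
```

Run with --multirun (and a concrete run_path), it would go through the submitit launcher; a plain run still executes locally.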
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+timeout_min: 7200 # 5 days
+partition: 'gpu'
+cpus_per_task: 18
+tasks_per_node: 4
+gpus_per_node: 4
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+timeout_min: 7200 # 5 days
+partition: 'gpu'
+tasks_per_node: 1
+gpus_per_task: 1
+cpus_per_gpu: 18
+additional_parameters:
+  gpu_bind: closest
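Rough translation of the node settings above (my sketch, not code from the commit): Hydra's submitit launcher forwards these keys to a submitit executor, approximately as below. The 7200-minute timeout is 120 hours, i.e. the 5 days in the comment.

```python
# Approximate submitit equivalent of the single-GPU-per-task node config above;
# SLURM-specific keys use AutoExecutor's "slurm_" prefix.
import submitit

executor = submitit.AutoExecutor(folder="submitit_logs")
executor.update_parameters(
    timeout_min=7200,                                    # 7200 / 60 / 24 = 5 days
    tasks_per_node=1,
    slurm_partition="gpu",
    slurm_gpus_per_task=1,
    slurm_cpus_per_gpu=18,
    slurm_additional_parameters={"gpu_bind": "closest"},
)
```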
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+- set -e
+
+- echo "Loading modules"
+# TODO: remove this hacky relative path finding
+- source ${hydra.runtime.config_sources.1.path}/../../load-asap.sh
+
+- echo "Setting environment variables"
+- export OMP_NUM_THREADS=18
+- export PYTHONFAULTHANDLER=1
+- export NCCL_DEBUG=INFO
+- export NCCL_ASYNC_ERROR_HANDLING=1
+
+- if [[ -d ${oc.env:CAMELYON16_PATH} ]]; then
+- echo "Detected folder ${oc.env:CAMELYON16_PATH}, skipping copy"
+- else
+- echo "${oc.env:CAMELYON16_PATH} is not a valid file"
+- exit 1
+- fi
+
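A detail in the setup list above worth spelling out (my reading, not part of the commit): ${oc.env:CAMELYON16_PATH} is OmegaConf's environment-variable resolver, so the interpolation is substituted when the launcher config is composed, and the resulting shell lines then run at job start-up. A tiny demonstration of the resolver:

```python
# Minimal demonstration of the ${oc.env:...} resolver used in the setup lines above.
import os
from omegaconf import OmegaConf

os.environ.setdefault("CAMELYON16_PATH", "/tmp/camelyon16")
cfg = OmegaConf.create({"check": "if [[ -d ${oc.env:CAMELYON16_PATH} ]]; then echo ok; fi"})
print(cfg.check)  # if [[ -d /tmp/camelyon16 ]]; then echo ok; fi
```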
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+defaults:
+  - submitit_slurm
+  - node@_here_: gpu_a100_shared
+  - setup: camelyon16_setup

scripts/extract_embeddings/extract_embeddings.py

Lines changed: 19 additions & 9 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import functools
+import logging
 from collections.abc import Iterable
 from dataclasses import dataclass
 from functools import partial
@@ -50,6 +51,13 @@ def get_slices(patch_idx):
     for dim in (torch.cat([patch_idx[None], patch_idx[None]+1]).T.swapaxes(0, 1) * patch_size)
 ])
 
+def cast_to_lowest_dtype(array: np.ndarray) -> np.ndarray:
+    return array.astype(
+        bool
+        if (array_min := array.min()) == 0 and (array_max := array.max()) == 1
+        else np.result_type(*map(np.min_scalar_type, (array_min, array_max)))
+    )
+
 arrays, counts, patch_size = {}, {}, None
 
 for ret_values in run_eval(model, dataset):
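To see what the new cast_to_lowest_dtype helper buys over the old astype(np.min_scalar_type(arr.max())): binary arrays collapse to bool, and the target dtype now accounts for the minimum as well as the maximum. One caveat as written: when the minimum is non-zero, the condition short-circuits before array_max is bound, so the else branch would raise a NameError; the helper therefore implicitly assumes arrays whose minimum is exactly 0. A small self-contained check (toy data, not repo code):

```python
# Copy of the helper from the diff above, exercised on toy arrays.
import numpy as np

def cast_to_lowest_dtype(array: np.ndarray) -> np.ndarray:
    return array.astype(
        bool
        if (array_min := array.min()) == 0 and (array_max := array.max()) == 1
        else np.result_type(*map(np.min_scalar_type, (array_min, array_max)))
    )

print(cast_to_lowest_dtype(np.array([0, 1, 1, 0])).dtype)  # bool
print(cast_to_lowest_dtype(np.array([0, 13, 200])).dtype)  # uint8
print(cast_to_lowest_dtype(np.array([0, 70_000])).dtype)   # uint32 (70_000 exceeds uint16)
```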
@@ -60,8 +68,8 @@ def get_slices(patch_idx):
 
 slices = get_slices(patch_idx)
 
-u_names, u_idx, u_inverse, u_counts = np.unique(
-    names, return_counts=True, return_index=True, return_inverse=True
+u_names, u_idx, u_counts = np.unique(
+    names, return_counts=True, return_index=True
 )
 
 for name, image_index, count in zip(u_names, img_idx[u_idx], u_counts):
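The grouping change restated on toy data (illustrative, not repo code): return_inverse is dropped, and the per-image mask in the next hunk is built directly from img_idx, which holds global image indices, rather than from u_inverse, which indexes positions in the sorted unique-name array. Presumably this mismatch is part of what "fixed embedding extraction" refers to.

```python
import numpy as np

# Toy batch: five patches from two slides, with hypothetical global image indices.
names = np.array(["slide_b", "slide_a", "slide_b", "slide_b", "slide_a"])
img_idx = np.array([7, 3, 7, 7, 3])

u_names, u_idx, u_counts = np.unique(names, return_counts=True, return_index=True)
for name, image_index, count in zip(u_names, img_idx[u_idx], u_counts):
    mask = img_idx == image_index          # new mask: compare against the image index itself
    print(name, image_index, count, mask)  # slide_a selects patches 1 and 4; slide_b selects 0, 2, 3
```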
@@ -72,23 +80,23 @@ def get_slices(patch_idx):
     device=encodings.device
 ))
 
-mask = u_inverse == int(image_index)
+mask = img_idx == image_index
 current_array[slices[mask].swapaxes(0, 1)] = encodings[mask]
 current_count -= count # persistent because of np.array
 
 if current_count == 0:
     counts.pop(name)
-    yield name, (arr := arrays.pop(name).cpu().numpy()).astype(np.min_scalar_type(arr.max()))
+    yield name, cast_to_lowest_dtype(arrays.pop(name).cpu().numpy())
 
 
 @torch.no_grad()
-def run_eval(model, dataset, batch_size=1800):
+def run_eval(model, dataset, batch_size=2500):
 
     dataloader = DataLoader(
         dataset,
         batch_size=batch_size,
         pin_memory=True,
-        num_workers=6,
+        num_workers=18,
         prefetch_factor=10
     )
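Side note on the loader changes (an inference, not stated in the commit): num_workers goes from 6 to 18, matching the 18 CPUs per task/GPU requested in the new node configs above, and the batch size grows from 1800 to 2500. A stand-alone sketch of the resulting loader settings:

```python
# Stand-in dataset purely for illustration; the real one is a CAMELYON16 slice/patch dataset.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(10_000, 3, 64, 64, dtype=torch.uint8))

dataloader = DataLoader(
    dataset,
    batch_size=2500,
    pin_memory=True,
    num_workers=18,      # one worker per allocated CPU core
    prefetch_factor=10,  # only applies when num_workers > 0
)
```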

@@ -103,11 +111,13 @@ def extract_path(path: str) -> str:
 
 max_pool = None
 
+logging.info("Setup complete, starting encoding")
+
 for imgs, labels, (img_index, patch_index, img_path, label_path) in dataloader:
 
     imgs, labels = (
-        imgs.to(device, non_blocking=True, dtype=torch.half),
-        labels.to(device, non_blocking=True, dtype=torch.int16)
+        imgs.to(device, non_blocking=True),
+        labels.to(device, non_blocking=True)
     )
 
     with torch.autocast('cuda'):
121131
yield (
122132
(data, list(map(extract_path, paths)), img_index, patch_index)
123133
for data, paths in (
124-
(encoding_indices.to(torch.int16), img_path), # make a ndim < 2**15 assumption
134+
(encoding_indices, img_path),
125135
(labels_pooled, label_path)
126136
)
127137
)
