
Commit ea9a69e

Merge pull request #44 from rajewsky-lab/fix_imaging_preprocessing
Fix imaging preprocessing
2 parents 6b377e9 + 69eefc9 commit ea9a69e

20 files changed: +386 −189 lines

docs/computational/generate_expression_matrix.md

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@ To create such a spatial cell-by-gene ($M\times G$) expression matrix, you will
 
 We efficiently segment cells (or nuclei) from staining images using [cellpose](https://github.com/MouseLand/cellpose).
 We provide a model that we fine-tuned for segmentation of fresh-frozen, H&E-stained tissue,
-[here](https://github.com/danilexn/openst/blob/main/models/HE_cellpose_rajewsky).
+[here](http://bimsbstatic.mdc-berlin.de/rajewsky/openst-public-data/models/HE_cellpose_rajewsky).
 You can specify any other model that works best for your data -
 refer to the [cellpose](https://cellpose.readthedocs.io/en/latest/index.html) documentation.
 
@@ -31,7 +31,7 @@ By default, segmentation is extended radially 10 pixels. This can be changed wit
 Make sure to populate the arguments with the values specific to your dataset. Here, we provide `--h5-in` consistent
 with the previous steps, `--image-in` and `--mask-out` will read and write the staining image and mask inside the Open-ST h5 object,
 and `--model` is `HE_cellpose_rajewsky`, the default used in our manuscript. This is the model we recommend for H&E images, and
-weights are automatically downloaded. It is also [provided in our repo](https://github.com/rajewsky-lab/openst/blob/main/models/HE_cellpose_rajewsky).
+weights are automatically downloaded. It is also [provided in our repo](http://bimsbstatic.mdc-berlin.de/rajewsky/openst-public-data/models/HE_cellpose_rajewsky).
 The rest of the parameters can be checked with `openst segment --help`.
 
 !!! tip
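
For readers who want to sanity-check the fine-tuned model outside the `openst segment` CLI, here is a minimal sketch using the cellpose Python API directly. The file paths and channel settings are illustrative assumptions, not values taken from this commit:

```python
# Minimal sketch: running the downloaded HE_cellpose_rajewsky weights through
# the cellpose Python API. Paths and channels are illustrative assumptions.
import tifffile
from cellpose import models

image = tifffile.imread("Image_Stitched_Composite.tif")  # hypothetical input
model = models.CellposeModel(pretrained_model="models/HE_cellpose_rajewsky")

# channels=[0, 0] treats the image as grayscale; adjust for RGB H&E input
masks, flows, styles = model.eval(image, channels=[0, 0])
tifffile.imwrite("segmentation_mask.tif", masks)
```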

docs/computational/preprocessing_imaging.md

Lines changed: 25 additions & 8 deletions
@@ -63,16 +63,33 @@ tile-scan. You can run this by running the following command on the stitched ima
 
 ```bash
 openst image_preprocess \
-    --input=<path_to_input_image> \
-    --CUT \
-    --CUT-model=<path_to_model> \
-    --output=<path_to_output>
+    --image-in Image_Stitched_Composite.tif \
+    --image-out Image_Stitched_Composite_Restored.tif
+    # --device cuda # in case you have a CUDA-compatible GPU
 ```
 
-Make sure to replace the placeholders (`<...>`). For instance,
-`<path_to_input_image>` is the full path and file name of the previously stitched image; `<path_to_model>`
-is the filename of our pre-trained [CUT model](https://github.com/rajewsky-lab/openst/models/CUT.pth), and `<output_image>`
-is the path to a (writeable) folder and the desired filename for the output image.
+If you ran `openst merge_modalities`, the imaging data will be contained inside the Open-ST h5 object, and the command
+can be adapted:
+
+```bash
+openst image_preprocess \
+    --h5-in multimodal/spots_stitched.h5ad # just a placeholder, adapt
+    # --device cuda # in case you have a CUDA-compatible GPU
+```
+
+By default, the image will be loaded from the key `uns/spatial/staining_image`, and the CUT-restored image will be saved
+to `uns/spatial/staining_image_restored`. You can preview the image restoration results using:
+
+```bash
+openst preview \
+    --h5-in multimodal/spots_stitched.h5ad \
+    --image-key uns/spatial/staining_image uns/spatial/staining_image_restored
+```
+
+This will load the two images and visualize them using `napari`. Later, you can run segmentation and pairwise alignment
+using either the default merged image (`uns/spatial/staining_image`) or the restored image (`uns/spatial/staining_image_restored`).
+Always assess these preprocessing choices (quantitatively and qualitatively) to decide whether they make sense for your data.
 
 ## Expected output
 After running the stitching (and optionally correction algorithms), you will have a single image file per sample. This, together with
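
As a complement to `openst preview`, the two image keys added by this change can also be inspected directly; a minimal sketch with `h5py` and `matplotlib`, reusing the placeholder path from the docs above:

```python
# Minimal sketch: reading the raw and CUT-restored staining images from the
# Open-ST h5 object. The file path is the docs' placeholder; the keys are the
# defaults introduced by this commit.
import h5py
import matplotlib.pyplot as plt

with h5py.File("multimodal/spots_stitched.h5ad", "r") as f:
    raw = f["uns/spatial/staining_image"][:]
    restored = f["uns/spatial/staining_image_restored"][:]

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, img, title in zip(axes, (raw, restored),
                          ("staining_image", "staining_image_restored")):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")
plt.show()
```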

docs/examples/adult_mouse/generate_expression_matrix.md

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ into a cell-by-genes matrix, where cells are defined from the segmentation mask.
 To create such a spatial cell-by-gene ($M\times G$) expression matrix, you will first need a segmentation mask.
 
 We efficiently segment cells (or nuclei) from staining images using [cellpose](https://github.com/MouseLand/cellpose).
-For the H&E-stained tissue provided in this example, we used our [fine-tuned model](https://github.com/rajewsky-lab/openst/blob/main/models/HE_cellpose_rajewsky).
+For the H&E-stained tissue provided in this example, we used our [fine-tuned model](http://bimsbstatic.mdc-berlin.de/rajewsky/openst-public-data/models/HE_cellpose_rajewsky).
 Make sure to download it and save it into a new `models` folder that you need to create under the `openst_adult_demo` main folder.
 
 You can run the segmentation on the previously created `openst_demo_adult_mouse_spatial_beads_puck_collection_aligned.h5ad` file, which

docs/examples/e13_mouse/generate_expression_matrix.md

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ into a cell-by-genes matrix, where cells are defined from the segmentation mask.
 To create such a spatial cell-by-gene ($M\times G$) expression matrix, you will first need a segmentation mask.
 
 We efficiently segment cells (or nuclei) from staining images using [cellpose](https://github.com/MouseLand/cellpose).
-For the H&E-stained tissue provided in this example, we used our [fine-tuned model](https://github.com/rajewsky-lab/openst/blob/main/models/HE_cellpose_rajewsky).
+For the H&E-stained tissue provided in this example, we used our [fine-tuned model](http://bimsbstatic.mdc-berlin.de/rajewsky/openst-public-data/models/HE_cellpose_rajewsky).
 Make sure to download it and save it into a new `models` folder that you need to create under the `openst_e13_demo` main folder.
 
 You can run the segmentation on the previously created `openst_demo_e13_mouse_head_spatial_beads_puck_collection_aligned.h5ad` file, which

openst/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.0.2'
+__version__ = '0.0.11'

openst/cli.py

Lines changed: 34 additions & 5 deletions
@@ -680,17 +680,46 @@ def get_image_preprocess_parser():
         allow_abbrev=False,
         add_help=False,
     )
-    parser.add_argument("--input_img", type=str, required=True, help="path to input image")
-    parser.add_argument("--cut_dir", type=str, required=True, help="path to CUT directory (to save patched images)")
-    parser.add_argument("--tile_size_px", type=int, required=True, help="size of the tile in pixels")
+    parser.add_argument(
+        "--h5-in",
+        type=str,
+        default="",
+        help="""If set, the image is loaded from the Open-ST h5 object (key in --image-in),
+        and the restored image is saved there (to the key in --image-out)""",
+    )
+    parser.add_argument(
+        "--image-in",
+        type=str,
+        default="uns/spatial/staining_image",
+        help="Key or path to the input image",
+    )
+    parser.add_argument(
+        "--image-out",
+        type=str,
+        default="uns/spatial/staining_image_restored",
+        help="Key or path where the restored image will be written",
+    )
+    parser.add_argument(
+        "--tile-size-px",
+        type=int,
+        default=512,
+        help="The input image is split into square tiles of side --tile-size-px for inference. " +
+        "Larger values avoid boundary effects, but require more memory.",
+    )
+    parser.add_argument("--model", type=str, default="HE_CUT_rajewsky", help="CUT model used for image restoration")
     parser.add_argument(
         "--device",
         type=str,
         default="cpu",
         choices=["cpu", "cuda"],
-        help="Device used to run feature matching model. Can be ['cpu', 'cuda']",
+        help="Device used to run the CUT restoration model. Can be ['cpu', 'cuda']",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=-1,
+        help="Number of CPU workers for parallel processing",
     )
-    parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here")
 
     return parser
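
Because every new argument now carries a default, the parser can be exercised with an empty argv; a quick sketch for sanity-checking (assuming `get_image_preprocess_parser` is importable from `openst.cli` as shown here):

```python
# Minimal sketch: verifying the new image_preprocess defaults. Assumes the
# factory above is importable from openst.cli.
from openst.cli import get_image_preprocess_parser

args = get_image_preprocess_parser().parse_args([])  # no flags -> defaults
print(args.image_in)      # uns/spatial/staining_image
print(args.image_out)     # uns/spatial/staining_image_restored
print(args.tile_size_px)  # 512
print(args.model)         # HE_CUT_rajewsky
print(args.device)        # cpu
```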

openst/preprocessing/CUT/models/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -19,8 +19,9 @@
 """
 
 import importlib
-from models.base_model import BaseModel
+import logging
 
+from openst.preprocessing.CUT.models.base_model import BaseModel
 
 def find_model_using_name(model_name):
     """Import the module "models/[model_name]_model.py".
@@ -29,7 +30,7 @@ def find_model_using_name(model_name):
     be instantiated. It has to be a subclass of BaseModel,
     and it is case-insensitive.
     """
-    model_filename = "models." + model_name + "_model"
+    model_filename = "openst.preprocessing.CUT.models." + model_name + "_model"
     modellib = importlib.import_module(model_filename)
     model = None
     target_model_name = model_name.replace('_', '') + 'model'
@@ -63,5 +64,5 @@ def create_model(opt):
     """
     model = find_model_using_name(opt.model)
     instance = model(opt)
-    print("model [%s] was created" % type(instance).__name__)
+    logging.info("Model architecture `%s` was created" % type(instance).__name__)
     return instance
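
The lookup performed by `find_model_using_name` is a common plugin pattern: import `<package>.<name>_model` and select the class whose lowercased name equals `<name>model`. A standalone sketch of the same pattern (package and model names are hypothetical):

```python
# Standalone sketch of the dynamic lookup above; package/model names are
# hypothetical. The real version also checks issubclass(..., BaseModel).
import importlib

def find_model(package: str, model_name: str):
    module = importlib.import_module(f"{package}.{model_name}_model")
    target = model_name.replace("_", "") + "model"
    for attr_name, attr in vars(module).items():
        if attr_name.lower() == target:  # case-insensitive class match
            return attr
    raise ImportError(f"No class matching {target!r} in {module.__name__}")

# e.g. find_model("openst.preprocessing.CUT.models", "cut") -> CUTModel
```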

openst/preprocessing/CUT/models/base_model.py

Lines changed: 9 additions & 6 deletions
@@ -1,8 +1,10 @@
 import os
 import torch
+import logging
 from collections import OrderedDict
 from abc import ABC, abstractmethod
-from . import networks
+
+from openst.preprocessing.CUT.models import networks
 
 
 class BaseModel(ABC):
@@ -211,7 +213,7 @@ def load_networks(self, epoch):
         net = getattr(self, 'net' + name)
         if isinstance(net, torch.nn.DataParallel):
             net = net.module
-        print('loading the model from %s' % load_path)
+        logging.info(f'Loading model weights from {load_path}')
         # if you are using PyTorch newer than 0.4 (e.g., built from
         # GitHub source), you can remove str() on self.device
         state_dict = torch.load(load_path, map_location=str(self.device))
@@ -229,17 +231,18 @@ def print_networks(self, verbose):
         Parameters:
             verbose (bool) -- if verbose: print the network architecture
         """
-        print('---------- Networks initialized -------------')
+        message = '---------- (Start) Networks initialized -------------\n'
         for name in self.model_names:
             if isinstance(name, str):
                 net = getattr(self, 'net' + name)
                 num_params = 0
                 for param in net.parameters():
                     num_params += param.numel()
                 if verbose:
-                    print(net)
-                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
-                print('-----------------------------------------------')
+                    message += f"{net}\n"
+                message += '[Network %s] Total number of parameters : %.3f M\n' % (name, num_params / 1e6)
+        message += '---------- (End) Networks initialized -------------'
+        logging.debug(message)
 
     def set_requires_grad(self, nets, requires_grad=False):
         """Set requires_grad=False for all the networks to avoid unnecessary computations

openst/preprocessing/CUT/models/cut_model.py

Lines changed: 8 additions & 11 deletions
@@ -1,10 +1,10 @@
 import numpy as np
 import torch
-from .base_model import BaseModel
-from . import networks
-from .patchnce import PatchNCELoss
-import util.util as util
 
+from openst.preprocessing.CUT.models.base_model import BaseModel
+from openst.preprocessing.CUT.models import networks
+from openst.preprocessing.CUT.models.patchnce import PatchNCELoss
+from openst.preprocessing.CUT.util import util
 
 class CUTModel(BaseModel):
     """ This class implements CUT and FastCUT model, described in the paper
@@ -98,10 +98,10 @@ def data_dependent_initialize(self, data):
         initialized at the first feedforward pass with some input images.
         Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
         """
-        bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
+        bs_per_gpu = self.opt.batch_size
         self.set_input(data)
         self.real_A = self.real_A[:bs_per_gpu]
-        self.real_B = self.real_B[:bs_per_gpu]
+        self.real_B = None
         self.forward()  # compute fake images: G(A)
         if self.opt.isTrain:
             self.compute_D_loss().backward()  # calculate gradients for D
@@ -138,14 +138,11 @@ def set_input(self, input):
             input (dict): include the data itself and its metadata information.
             The option 'direction' can be used to swap domain A and domain B.
         """
-        AtoB = self.opt.direction == 'AtoB'
-        self.real_A = input['A' if AtoB else 'B'].to(self.device)
-        self.real_B = input['B' if AtoB else 'A'].to(self.device)
-        self.image_paths = input['A_paths' if AtoB else 'B_paths']
+        self.real_A = input['A'].to(self.device)
 
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
-        self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
+        self.real = self.real_A
         if self.opt.flip_equivariance:
             self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
             if self.flipped_for_equivariance:
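
Taken together, these changes reduce `set_input`/`forward` to an inference-only, domain-A-only path. A hypothetical sketch of how a batch of tiles could be restored through this simplified interface (model construction is not shown in this diff and is assumed):

```python
# Hypothetical sketch of the simplified inference path: only domain-A tiles
# are required, and the generator output is read back from model.fake_B.
# Model construction and the attribute name are assumptions for illustration.
import torch

def restore_tiles(model, tiles: torch.Tensor) -> torch.Tensor:
    """Run the CUT generator on a batch of image tiles shaped (N, C, H, W)."""
    with torch.no_grad():
        model.set_input({"A": tiles})  # domain B is no longer needed
        model.forward()                # computes fake_B = G(A)
    return model.fake_B
```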

openst/preprocessing/CUT/models/networks.py

Lines changed: 0 additions & 14 deletions
@@ -5,12 +5,6 @@
 import functools
 from torch.optim import lr_scheduler
 import numpy as np
-from .stylegan_networks import StyleGAN2Discriminator, StyleGAN2Generator, TileStyleGAN2Discriminator
-
-###############################################################################
-# Helper Functions
-###############################################################################
-
 
 def get_filter(filt_size=3):
     if(filt_size == 1):
@@ -256,10 +250,6 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
         net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
     elif netG == 'unet_256':
         net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
-    elif netG == 'stylegan2':
-        net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, opt=opt)
-    elif netG == 'smallstylegan2':
-        net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, n_blocks=2, opt=opt)
     elif netG == 'resnet_cat':
         n_blocks = 8
         net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
@@ -323,8 +313,6 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
         net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias,)
     elif netD == 'pixel':  # classify if each pixel is real or fake
         net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
-    elif 'stylegan2' in netD:
-        net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, no_antialias=no_antialias, opt=opt)
     else:
         raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
     return init_net(net, init_type, init_gain, gpu_ids,
@@ -1214,8 +1202,6 @@ def forward(self, input):
 
 class UnetSkipConnectionBlock(nn.Module):
     """Defines the Unet submodule with skip connection.
-        X -------------------identity----------------------
-        |-- downsampling -- |submodule| -- upsampling --|
     """
 
     def __init__(self, outer_nc, inner_nc, input_nc=None,
openst/preprocessing/CUT/models/patchnce.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+from packaging import version
+import torch
+from torch import nn
+
+class PatchNCELoss(nn.Module):
+    def __init__(self, opt):
+        super().__init__()
+        self.opt = opt
+        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
+        self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
+
+    def forward(self, feat_q, feat_k):
+        num_patches = feat_q.shape[0]
+        dim = feat_q.shape[1]
+        feat_k = feat_k.detach()
+
+        # pos logit
+        l_pos = torch.bmm(
+            feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
+        l_pos = l_pos.view(num_patches, 1)
+
+        # neg logit
+
+        # Should the negatives from the other samples of a minibatch be utilized?
+        # In CUT and FastCUT, we found that it's best to only include negatives
+        # from the same image. Therefore, we set
+        # --nce_includes_all_negatives_from_minibatch as False
+        # However, for single-image translation, the minibatch consists of
+        # crops from the "same" high-resolution image.
+        # Therefore, we will include the negatives from the entire minibatch.
+        if self.opt.nce_includes_all_negatives_from_minibatch:
+            # reshape features as if they are all negatives of minibatch of size 1.
+            batch_dim_for_bmm = 1
+        else:
+            batch_dim_for_bmm = self.opt.batch_size
+
+        # reshape features to batch size
+        feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
+        feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
+        npatches = feat_q.size(1)
+        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
+
+        # diagonal entries are similarity between same features, and hence meaningless.
+        # just fill the diagonal with a very small number, which is exp(-10) and almost zero
+        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
+        l_neg_curbatch.masked_fill_(diagonal, -10.0)
+        l_neg = l_neg_curbatch.view(-1, npatches)
+
+        out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
+
+        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
+                                                        device=feat_q.device))
+
+        return loss
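
For readers new to contrastive unpaired translation, a minimal sketch of exercising this loss with random features; the options namespace is a stand-in for the real CUT option object:

```python
# Minimal sketch: calling PatchNCELoss on random features. SimpleNamespace
# stands in for the CUT option object; values are illustrative.
from types import SimpleNamespace
import torch

from openst.preprocessing.CUT.models.patchnce import PatchNCELoss

opt = SimpleNamespace(
    batch_size=1,
    nce_T=0.07,  # temperature for the softmax over (1 positive + N-1 negatives)
    nce_includes_all_negatives_from_minibatch=False,
)
criterion = PatchNCELoss(opt)

feat_q = torch.randn(256, 64)  # 256 query patches, 64-dim features
feat_k = torch.randn(256, 64)  # the corresponding key patches
loss = criterion(feat_q, feat_k)  # per-patch losses, shape (256,)
print(loss.mean())
```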

openst/preprocessing/CUT/models/template_model.py

Lines changed: 4 additions & 3 deletions
@@ -9,15 +9,16 @@
     Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
         min_<netG> ||netG(data_A) - data_B||_1
     You need to implement the following functions:
-        <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
+        <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
         <__init__>: Initialize this model class.
         <set_input>: Unpack input data and perform data pre-processing.
         <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
         <optimize_parameters>: Update network weights; it will be called in every training iteration.
 """
 import torch
-from .base_model import BaseModel
-from . import networks
+
+from openst.preprocessing.CUT.models.base_model import BaseModel
+from openst.preprocessing.CUT.models import networks
 
 
 class TemplateModel(BaseModel):
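
The docstring above spells out the contract every model must satisfy; a compressed, hypothetical sketch of a minimal subclass (names, options, and the chosen generator are illustrative, and optimizer setup is omitted):

```python
# Compressed sketch of the contract listed in the docstring above. All names
# and option fields are illustrative; optimizer setup is omitted for brevity.
import torch

from openst.preprocessing.CUT.models.base_model import BaseModel
from openst.preprocessing.CUT.models import networks


class MinimalModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser  # no extra options for this sketch

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['G_L1']
        self.model_names = ['G']
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      'resnet_9blocks', gpu_ids=self.gpu_ids)

    def set_input(self, input):
        self.real_A = input['A'].to(self.device)
        self.real_B = input['B'].to(self.device)

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def optimize_parameters(self):
        self.forward()
        self.loss_G_L1 = torch.nn.functional.l1_loss(self.fake_B, self.real_B)
        self.loss_G_L1.backward()
```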
