
Commit 3de8176

fix(datasets): ret_binary & ret_inst opts

Split the dataset return options: return_inst now controls whether the
labelled instance mask is returned (default changed from True to False),
while the new return_binary flag controls the binarized instance mask
(default True). The Lightning experiment's per-phase metric logging is
also deduplicated into a single log_step method.

1 parent: dd356cc

4 files changed: 64 additions & 34 deletions

cellseg_models_pytorch/datasets/_base_dataset.py
Lines changed: 13 additions & 5 deletions

@@ -27,7 +27,8 @@ def __init__(
         img_transforms: List[str],
         inst_transforms: List[str],
         normalization: str = None,
-        return_inst: bool = True,
+        return_inst: bool = False,
+        return_binary: bool = True,
         return_type: bool = True,
         return_sem: bool = False,
         return_weight: bool = False,
@@ -48,7 +49,9 @@ def __init__(
         normalization : str, optional
             Apply img normalization after all the transformations. One of "minmax",
             "norm", "percentile", None.
-        return_inst : bool, default=True
+        return_inst : bool, default=False
+            If True, returns the instance labelled mask. (If the db contains these.)
+        return_binary : bool, default=True
             If True, returns a binarized instance mask. (If the db contains these.)
         return_type : bool, default=True
             If True, returns a type mask. (If the db contains these.)
@@ -79,6 +82,7 @@ def __init__(
         )

         # Return masks
+        self.return_binary = return_binary
         self.return_inst = return_inst
         self.return_type = return_type
         self.return_sem = return_sem
@@ -89,7 +93,7 @@ def __init__(
             img_transforms.append(NORM_TRANSFORMS[normalization]())

         inst_transforms = [INST_TRANSFORMS[tr](**kwargs) for tr in inst_transforms]
-        if return_inst:
+        if return_binary:
             inst_transforms.append(INST_TRANSFORMS["binarize"]())

         if return_weight:
@@ -142,10 +146,14 @@ def _getitem(
             out[n] = aux_map

         # remove redundant target (not needed in downstream).
+
         if self.return_inst:
-            out["inst"] = out["binary"]
-            del out["binary"]
+            out["inst_map"] = out["inst"]
         else:
             del out["inst"]

+        if self.return_binary:
+            out["inst"] = out["binary"]
+            del out["binary"]
+
         return out
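
For clarity, a minimal standalone sketch of the key shuffling that the new _getitem logic performs. The toy numpy arrays stand in for the real transform outputs; in the actual dataset the "binary" key only exists when return_binary=True appended the "binarize" transform:

# Standalone sketch of the `_getitem` key shuffling above (toy data, not the
# real transform pipeline).
import numpy as np

def shuffle_keys(out: dict, return_inst: bool, return_binary: bool) -> dict:
    if return_inst:
        out["inst_map"] = out["inst"]  # keep the labelled instance mask
    else:
        del out["inst"]
    if return_binary:
        out["inst"] = out.pop("binary")  # binarized mask is served under "inst"
    return out

out = {"inst": np.array([[0, 1], [2, 0]]), "binary": np.array([[0, 1], [1, 0]])}
print(sorted(shuffle_keys(out, return_inst=True, return_binary=True)))
# -> ['inst', 'inst_map']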

cellseg_models_pytorch/datasets/folder_dataset_train.py
Lines changed: 6 additions & 2 deletions

@@ -21,7 +21,8 @@ def __init__(
         img_transforms: List[str],
         inst_transforms: List[str],
         normalization: str = None,
-        return_inst: bool = True,
+        return_inst: bool = False,
+        return_binary: bool = True,
         return_type: bool = True,
         return_sem: bool = False,
         return_weight: bool = False,
@@ -46,7 +47,9 @@ def __init__(
         normalization : str, optional
             Apply img normalization after all the transformations. One of "minmax",
             "norm", "percentile", None.
-        return_inst : bool, default=True
+        return_inst : bool, default=False
+            If True, returns the instance labelled mask. (If the db contains these.)
+        return_binary : bool, default=True
             If True, returns a binarized instance mask. (If the db contains these.)
         return_type : bool, default=True
             If True, returns a type mask. (If the db contains these.)
@@ -67,6 +70,7 @@ def __init__(
             return_type=return_type,
             return_sem=return_sem,
             return_weight=return_weight,
+            return_binary=return_binary,
             **kwargs,
         )


cellseg_models_pytorch/datasets/hdf5_dataset.py
Lines changed: 6 additions & 2 deletions

@@ -24,7 +24,8 @@ def __init__(
         img_transforms: List[str],
         inst_transforms: List[str],
         normalization: str = None,
-        return_inst: bool = True,
+        return_inst: bool = False,
+        return_binary: bool = True,
         return_type: bool = True,
         return_sem: bool = False,
         return_weight: bool = False,
@@ -49,7 +50,9 @@ def __init__(
         normalization : str, optional
             Apply img normalization after all the transformations. One of "minmax",
             "norm", "percentile", None.
-        return_inst : bool, default=True
+        return_inst : bool, default=False
+            If True, returns the instance labelled mask. (If the db contains these.)
+        return_binary : bool, default=True
             If True, returns a binarized instance mask. (If the db contains these.)
         return_type : bool, default=True
             If True, returns a type mask. (If the db contains these.)
@@ -66,6 +69,7 @@ def __init__(
             return_type=return_type,
             return_sem=return_sem,
             return_weight=return_weight,
+            return_binary=return_binary,
         )

         self.path = Path(path)
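
A hedged usage sketch of the new options. Only the keyword arguments are taken from the diff; the class name SegmentationHDF5Dataset, the file path, and the transform names are assumptions, as none of them appear in this diff:

# Hypothetical usage of the new flags; class and transform names are assumed.
from cellseg_models_pytorch.datasets.hdf5_dataset import SegmentationHDF5Dataset

ds = SegmentationHDF5Dataset(
    path="train_patches.h5",   # hypothetical HDF5 db
    img_transforms=["blur"],   # hypothetical transform names
    inst_transforms=["smooth"],
    normalization="minmax",
    return_inst=True,          # opt back in to the labelled mask (new default: False)
    return_binary=True,        # default: appends the "binarize" transform
)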

cellseg_models_pytorch/training/lit/lightning_experiment.py
Lines changed: 39 additions & 25 deletions

@@ -140,52 +140,66 @@ def step(
         loss = self.criterion(yhats=soft_masks, targets=targets)
         metrics = self.compute_metrics(soft_masks, targets, phase)

+        if batch_idx % self.log_freq == 0:
+            ret_masks = soft_masks
+        elif phase == "test":
+            ret_masks = soft_masks
+        else:
+            ret_masks = None
+
         ret = {
-            "soft_masks": soft_masks if batch_idx % self.log_freq == 0 else None,
+            "soft_masks": ret_masks,
             "loss": loss,
         }

         return {**ret, **metrics}

-    def training_step(
-        self, batch: Dict[str, torch.Tensor], batch_idx: int
+    def log_step(
+        self, batch: Dict[str, torch.Tensor], batch_idx: int, phase: str
     ) -> Dict[str, torch.Tensor]:
-        """Training step + train metric logs."""
-        res = self.step(batch, batch_idx, "train")
+        """Do the logging."""
+        on_epoch = phase in ("val", "test")
+        on_step = phase == "train"
+        prog_bar = phase == "train"
+
+        res = self.step(batch, batch_idx, phase)

         # log all the metrics
-        self.log("train_loss", res["loss"], prog_bar=True, on_epoch=False, on_step=True)
+        self.log(
+            f"{phase}_loss",
+            res["loss"],
+            prog_bar=prog_bar,
+            on_epoch=on_epoch,
+            on_step=on_step,
+        )
+
         for k, val in res.items():
             if k not in ("loss", "soft_masks"):
-                self.log(f"train_{k}", val, prog_bar=True, on_epoch=False, on_step=True)
+                self.log(
+                    f"{phase}_{k}",
+                    val,
+                    prog_bar=prog_bar,
+                    on_epoch=on_epoch,
+                    on_step=on_step,
+                )

         return res

+    def training_step(
+        self, batch: Dict[str, torch.Tensor], batch_idx: int
+    ) -> Dict[str, torch.Tensor]:
+        """Training step + train metric logs."""
+        return self.log_step(batch, batch_idx, "train")
+
     def validation_step(
         self, batch: Dict[str, torch.Tensor], batch_idx: int
     ) -> Dict[str, torch.Tensor]:
         """Validate step + validation metric logs + example outputs for logging."""
-        res = self.step(batch, batch_idx, "val")
-
-        # log all the metrics
-        self.log("val_loss", res["loss"], prog_bar=False, on_epoch=True, on_step=False)
-        for k, val in res.items():
-            if k not in ("loss", "soft_masks"):
-                self.log(f"val_{k}", val, prog_bar=False, on_epoch=True, on_step=False)
-
-        return res
+        return self.log_step(batch, batch_idx, "val")

     def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
         """Test step + test metric logs."""
-        res = self.step(batch, batch_idx, "test")
-
-        del res["soft_masks"]  # soft masks not needed for logging
-        loss = res.pop("loss")
-
-        # log all the metrics
-        self.log("test_loss", loss, prog_bar=False, on_epoch=True, on_step=False)
-        for k, val in res.items():
-            self.log(f"test_{k}", val, prog_bar=False, on_epoch=True, on_step=False)
+        return self.log_step(batch, batch_idx, "test")

     def compute_metrics(
         self,
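
The three Lightning hooks now delegate to a single log_step, which derives the logging flags from the phase. A minimal, dependency-free sketch of that mapping:

# Per-phase `self.log(...)` kwargs as derived in `log_step` above.
def phase_flags(phase: str) -> dict:
    return {
        "prog_bar": phase == "train",          # progress-bar metrics only while training
        "on_epoch": phase in ("val", "test"),  # aggregate per epoch for eval phases
        "on_step": phase == "train",           # per-step logging only while training
    }

assert phase_flags("train") == {"prog_bar": True, "on_epoch": False, "on_step": True}
assert phase_flags("test") == {"prog_bar": False, "on_epoch": True, "on_step": False}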
