Skip to content

Commit 244c796

Browse files
committed
fix: fix minor bugs and more verbose errors
1 parent a0527b9 commit 244c796

File tree

5 files changed

+23
-11
lines changed

5 files changed

+23
-11
lines changed

cellseg_models_pytorch/datasets/folder_dataset_train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ def read_img_mask(
142142
out["inst"] = masks["inst_map"]
143143
except KeyError:
144144
raise KeyError(
145-
f"The file {self.fnames_masks[ix]} does not contain key `inst_map`."
145+
f"The file {self.fnames_masks[ix]} does not contain key `inst_map`. "
146+
"Try setting `return_inst=False`."
146147
)
147148

148149
if return_type:
@@ -151,6 +152,7 @@ def read_img_mask(
151152
except KeyError:
152153
raise KeyError(
153154
f"The file {self.fnames_masks[ix]} does not contain key `type_map`."
155+
" Try setting `return_type=False`."
154156
)
155157

156158
if return_sem:
@@ -159,6 +161,7 @@ def read_img_mask(
159161
except KeyError:
160162
raise KeyError(
161163
f"The file {self.fnames_masks[ix]} does not contain key `sem_map`."
164+
"Try setting `return_sem=False`."
162165
)
163166

164167
return out

cellseg_models_pytorch/datasets/hdf5_dataset.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,20 +119,27 @@ def read_h5_patch(
119119
out["inst"] = h5.root.insts[ix, ...]
120120
except Exception:
121121
raise IOError(
122-
"The HDF5 database does not contain instance labelled masks."
122+
"The HDF5 database does not contain instance labelled masks. Try "
123+
"setting `return_inst=False`"
123124
)
124125

125126
if return_type:
126127
try:
127128
out["type"] = h5.root.types[ix, ...]
128129
except Exception:
129-
raise IOError("The HDF5 database does not contain type masks.")
130+
raise IOError(
131+
"The HDF5 database does not contain type masks. Try setting "
132+
"`return_type = False` "
133+
)
130134

131135
if return_sem:
132136
try:
133137
out["sem"] = h5.root.areas[ix, ...]
134138
except Exception:
135-
raise IOError("The HDF5 database does not contain semantic masks.")
139+
raise IOError(
140+
"The HDF5 database does not contain semantic masks. Try "
141+
"setting `return_sem = False`"
142+
)
136143

137144
return out
138145

cellseg_models_pytorch/losses/weighted_base_loss.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def apply_ls_to_target(
7171
) -> torch.Tensor:
7272
"""Apply regular label smoothing to the target map.
7373
74+
https://arxiv.org/abs/1512.00567
75+
7476
Parameters
7577
----------
7678
target : torch.Tensor

cellseg_models_pytorch/training/callbacks/wandb_callbacks.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def train_batch_end(
147147
outputs: Dict[str, torch.Tensor],
148148
batch: Dict[str, torch.Tensor],
149149
batch_idx: int,
150-
dataloader_idx: int,
150+
dataloader_idx: int = 0,
151151
) -> None:
152152
"""Log the inputs and outputs of the model to wandb."""
153153
self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="train")
@@ -159,7 +159,7 @@ def validation_batch_end(
159159
outputs: Dict[str, torch.Tensor],
160160
batch: Dict[str, torch.Tensor],
161161
batch_idx: int,
162-
dataloader_idx: int,
162+
dataloader_idx: int = 0,
163163
) -> None:
164164
"""Log the inputs and outputs of the model to wandb."""
165165
self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="val")
@@ -171,7 +171,7 @@ def test_batch_end(
171171
outputs: Dict[str, torch.Tensor],
172172
batch: Dict[str, torch.Tensor],
173173
batch_idx: int,
174-
dataloader_idx: int,
174+
dataloader_idx: int = 0,
175175
) -> None:
176176
"""Log the inputs and outputs of the model to wandb."""
177177
self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="test")
@@ -198,6 +198,7 @@ def batch_end(
198198
batch: Dict[str, torch.Tensor],
199199
batch_idx: int,
200200
phase: str,
201+
dataloader_idx: int = None,
201202
) -> None:
202203
"""Log metrics at every 100th step to wandb."""
203204
if batch_idx % self.freq == 0:
@@ -223,7 +224,7 @@ def on_validation_batch_end(
223224
outputs: Dict[str, torch.Tensor],
224225
batch: Dict[str, torch.Tensor],
225226
batch_idx: int,
226-
dataloader_idx: int,
227+
dataloader_idx: int = 0,
227228
) -> None:
228229
"""Call the callback at val time."""
229230
self.validation_batch_end(
@@ -237,7 +238,7 @@ def on_train_batch_end(
237238
outputs: Dict[str, torch.Tensor],
238239
batch: Dict[str, torch.Tensor],
239240
batch_idx: int,
240-
dataloader_idx: int,
241+
dataloader_idx: int = 0,
241242
) -> None:
242243
"""Call the callback at val time."""
243244
self.train_batch_end(

cellseg_models_pytorch/training/tests/test_training.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from cellseg_models_pytorch.training.lit import SegmentationExperiment
1010

1111

12-
# @pytest.mark.parametrize
13-
def test_training(img_patch_dir, mask_patch_dir):
12+
def test_training(img_patch_dir, mask_patch_dir, auto_lr):
1413
train_ds = SegmentationFolderDataset(
1514
path=img_patch_dir.as_posix(),
1615
mask_path=mask_patch_dir.as_posix(),

0 commit comments

Comments
 (0)