
Commit fffe6cc

docs: typo fixes in notebooks
1 parent 7d08002 commit fffe6cc

File tree: 2 files changed, +33 −105 lines

  examples/pannuke_nuclei_segmentation_cppnet.ipynb
  examples/pannuke_nuclei_segmentation_omnipose.ipynb


examples/pannuke_nuclei_segmentation_cppnet.ipynb

Lines changed: 19 additions & 95 deletions
@@ -41,10 +41,12 @@
 ],
 "source": [
 "# version info\n",
-"import torch\n",
+"from platform import python_version\n",
+"\n",
 "import lightning\n",
+"import torch\n",
+"\n",
 "import cellseg_models_pytorch\n",
-"from platform import python_version\n",
 "\n",
 "print(\"torch version:\", torch.__version__)\n",
 "print(\"lightning version:\", lightning.__version__)\n",
@@ -86,6 +88,7 @@
 ],
 "source": [
 "from pathlib import Path\n",
+"\n",
 "from cellseg_models_pytorch.datamodules import PannukeDataModule\n",
 "\n",
 "# fold1 and fold2 are used for training, fold3 is used for validation\n",
@@ -145,17 +148,18 @@
 }
 ],
 "source": [
-"import numpy as np\n",
 "import matplotlib.pyplot as plt\n",
+"import numpy as np\n",
 "from skimage.color import label2rgb\n",
 "\n",
-"# filehandler contains methods to read and write images and masks\n",
-"from cellseg_models_pytorch.utils import FileHandler\n",
 "from cellseg_models_pytorch.transforms.functional import (\n",
-"    gen_stardist_maps,\n",
 "    gen_dist_maps,\n",
+"    gen_stardist_maps,\n",
 ")\n",
 "\n",
+"# filehandler contains methods to read and write images and masks\n",
+"from cellseg_models_pytorch.utils import FileHandler\n",
+"\n",
 "img_dir = save_dir / \"train\" / \"images\"\n",
 "mask_dir = save_dir / \"train\" / \"labels\"\n",
 "imgs = sorted(img_dir.glob(\"*\"))\n",
@@ -216,13 +220,12 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"from typing import Dict, List, Optional, Tuple\n",
+"\n",
+"import lightning.pytorch as pl\n",
 "import torch\n",
 "import torch.nn as nn\n",
 "import torch.optim as optim\n",
-"import torchmetrics\n",
-"import lightning.pytorch as pl\n",
-"from copy import deepcopy\n",
-"from typing import List, Tuple, Dict, Optional\n",
 "\n",
 "\n",
 "class SegmentationExperiment(pl.LightningModule):\n",
@@ -232,7 +235,6 @@
 "        multitask_loss: Dict[str, nn.Module],\n",
 "        optimizer: optim.Optimizer,\n",
 "        scheduler: optim.lr_scheduler._LRScheduler,\n",
-"        branch_metrics: Optional[Dict[str, List[torchmetrics.Metric]]] = None,\n",
 "        optimizer_kwargs: Optional[Dict[str, float]] = None,\n",
 "        scheduler_kwargs: Optional[Dict[str, float]] = None,\n",
 "        **kwargs,\n",
@@ -249,16 +251,10 @@
 "        self.optimizer_kwargs = optimizer_kwargs or {}\n",
 "        self.scheduler_kwargs = scheduler_kwargs or {}\n",
 "\n",
-"        self.branch_metrics = branch_metrics\n",
 "        self.criterion = multitask_loss\n",
 "\n",
 "        self.save_hyperparameters(ignore=\"model\")\n",
 "\n",
-"        metrics = self.configure_metrics()\n",
-"        self.train_metrics = deepcopy(metrics)\n",
-"        self.val_metrics = deepcopy(metrics)\n",
-"        self.test_metrics = deepcopy(metrics)\n",
-"\n",
 "    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:\n",
 "        \"\"\"Forward pass.\"\"\"\n",
 "        return self.model(x)\n",
@@ -298,16 +294,9 @@
 "        # forward backward pass\n",
 "        loss, soft_masks, targets = self.step(batch)\n",
 "\n",
-"        # compute training metrics\n",
-"        metrics = self.compute_metrics(soft_masks, targets, \"train\")\n",
-"\n",
 "        # log the loss\n",
 "        self.log(\"train_loss\", loss, on_step=True, on_epoch=False, prog_bar=True)\n",
 "\n",
-"        # log the metrics\n",
-"        for metric_name, metric in metrics.items():\n",
-"            self.log(metric_name, metric, on_step=True, on_epoch=False, prog_bar=True)\n",
-"\n",
 "        return loss\n",
 "\n",
 "    def validation_step(\n",
@@ -317,16 +306,9 @@
 "        # forward pass\n",
 "        loss, soft_masks, targets = self.step(batch)\n",
 "\n",
-"        # compute validation metrics\n",
-"        metrics = self.compute_metrics(soft_masks, targets, \"val\")\n",
-"\n",
 "        # log the loss\n",
 "        self.log(\"val_loss\", loss, on_step=False, on_epoch=True, prog_bar=False)\n",
 "\n",
-"        # log the metrics\n",
-"        for metric_name, metric in metrics.items():\n",
-"            self.log(metric_name, metric, on_step=True, on_epoch=True, prog_bar=True)\n",
-"\n",
 "        return loss\n",
 "\n",
 "    def test_step(\n",
@@ -336,50 +318,11 @@
 "        # forward pass\n",
 "        loss, soft_masks, targets = self.step(batch)\n",
 "\n",
-"        # compute validation metrics\n",
-"        metrics = self.compute_metrics(soft_masks, targets, \"test\")\n",
-"\n",
 "        # log the loss\n",
 "        self.log(\"test_loss\", loss, on_step=False, on_epoch=True, prog_bar=False)\n",
 "\n",
-"        # log the metrics\n",
-"        for metric_name, metric in metrics.items():\n",
-"            self.log(metric_name, metric, on_step=False, on_epoch=True, prog_bar=False)\n",
-"\n",
 "        return loss\n",
 "\n",
-"    def compute_metrics(\n",
-"        self,\n",
-"        preds: Dict[str, torch.Tensor],\n",
-"        targets: Dict[str, torch.Tensor],\n",
-"        phase: str,\n",
-"    ) -> Dict[str, torch.Tensor]:\n",
-"        \"\"\"Compute metrics for logging.\"\"\"\n",
-"        if phase == \"train\":\n",
-"            metrics_dict = self.train_metrics\n",
-"        elif phase == \"val\":\n",
-"            metrics_dict = self.val_metrics\n",
-"        elif phase == \"test\":\n",
-"            metrics_dict = self.test_metrics\n",
-"\n",
-"        ret = {}\n",
-"        for metric_name, metric in metrics_dict.items():\n",
-"            if metric is not None:\n",
-"                branch = metric_name.split(\"_\")[0]\n",
-"                ret[metric_name] = metric(preds[branch], targets[branch])\n",
-"\n",
-"        return ret\n",
-"\n",
-"    def configure_metrics(self) -> Dict[str, torchmetrics.Metric]:\n",
-"        \"\"\"Configure the metrics for the model.\"\"\"\n",
-"        # We can put all the metrics in a ModuleDict and return it.\n",
-"        metrics = nn.ModuleDict()\n",
-"        for branch, metric_list in self.branch_metrics.items():\n",
-"            for metric in metric_list:\n",
-"                metrics[f\"{branch}_{metric}\"] = metric\n",
-"\n",
-"        return metrics\n",
-"\n",
 "    def configure_optimizers(self) -> List[optim.Optimizer]:\n",
 "        \"\"\"Configure the optimizers for the model.\"\"\"\n",
 "        opt = self.optimizer(self.parameters(), **self.optimizer_kwargs)\n",
@@ -411,9 +354,6 @@
 "\n",
 "**Branch losses.** For each output of the model, we define a joint-loss function. These losses are summed together during backprop to form a multi-task loss. In `CPP-Net`, we get two outputs for the `stardist` output instead of one. The reason is that `CPP-Net` was designed to refine the `stardist`-output for better segmentation performance; thus, the model returns both the regular `stardist` and the refined `stardist_refined` output. For both of these outputs we set the masked `MAE`-loss. For the `\"dist\"`-outputs we use the masked `BCELoss` (binary-cross-entropy) with `MSE`, and for the `\"type\"`-output we will use the masked multiclass categorical `CELoss` with `DiceLoss`.\n",
 "\n",
-"**Logging metrics.**\n",
-"For the nuclei type masks we will monitor Jaccard-index i.e. the mIoU metric during training. The metric is averaged over the class-specific mIoUs. The metrics are from the `torchmetrics` library. Also, we will monitor the MSE-metric for the `stardist`-outputs.\n",
-"\n",
 "**Optimizer and scheduler.**\n",
 "The optimizer used here is [AdamW](https://arxiv.org/abs/1711.05101). The learning rate is scheduled with the `ReduceLROnPlateau` schedule. The learning rate will be set to 0.0003.\n",
 "\n",
@@ -438,18 +378,17 @@
 ],
 "source": [
 "import torch.optim as optim\n",
-"from torchmetrics import JaccardIndex, MeanSquaredError\n",
 "\n",
-"from cellseg_models_pytorch.models import cppnet_base_multiclass\n",
 "from cellseg_models_pytorch.losses import (\n",
 "    MAE,\n",
 "    MSE,\n",
-"    DiceLoss,\n",
 "    BCELoss,\n",
 "    CELoss,\n",
+"    DiceLoss,\n",
 "    JointLoss,\n",
 "    MultiTaskLoss,\n",
 ")\n",
+"from cellseg_models_pytorch.models import cppnet_base_multiclass\n",
 "\n",
 "# seed the experiment for reproducibility\n",
 "pl.seed_everything(42)\n",
@@ -481,21 +420,6 @@
 "    }, # weights for each branch\n",
 ")\n",
 "\n",
-"# Define the metrics that will be logged for each model output.\n",
-"# We will log the Jaccard Index for the type maps and the Mean Squared Error\n",
-"# for the horizontal and vertical gradients output. For the type maps, the metric\n",
-"# is the average computed over each class separately ('macro').\n",
-"branch_metrics = {\n",
-"    \"dist\": [None],\n",
-"    \"type\": [\n",
-"        JaccardIndex(\n",
-"            task=\"multiclass\", average=\"macro\", num_classes=6, compute_on_cpu=True\n",
-"        )\n",
-"    ],\n",
-"    \"stardist\": [MeanSquaredError(compute_on_cpu=True)],\n",
-"    \"stardist_refined\": [MeanSquaredError(compute_on_cpu=True)],\n",
-"}\n",
-"\n",
 "# Initialize the optimizer.\n",
 "# We will be using the AdamW optimizer from the torch.optim library with learning rate of 0.0003.\n",
 "adamw = optim.AdamW\n",
@@ -511,7 +435,6 @@
 "experiment = SegmentationExperiment(\n",
 "    model=model,\n",
 "    multitask_loss=multitask_loss,\n",
-"    branch_metrics=branch_metrics,\n",
 "    optimizer=adamw,\n",
 "    scheduler=scheduler,\n",
 "    optimizer_kwargs=optimizer_kwargs,\n",
@@ -560,7 +483,7 @@
 "    monitor=\"val_loss\",\n",
 "    mode=\"min\",\n",
 ")\n",
-"ckpt_callback.CHECKPOINT_NAME_LAST = f\"cppnet_last\"\n",
+"ckpt_callback.CHECKPOINT_NAME_LAST = \"cppnet_last\"\n",
 "callbacks.append(ckpt_callback)\n",
 "\n",
 "# Lightning training\n",
@@ -1002,14 +925,14 @@
 "inferer = ResizeInferer(\n",
 "    model=experiment,\n",
 "    input_path=save_dir / \"valid\" / \"images\",\n",
-"    checkpoint_path=Path.home() / \"pannuke\" / \"cppnet\" / \"cppnet_last.ckpt\",\n",
 "    out_activations=out_acts,\n",
 "    out_boundary_weights=out_weights,\n",
 "    resize=(256, 256), # Not actually resizing anything,\n",
 "    instance_postproc=\"stardist\",\n",
 "    batch_size=8,\n",
 "    n_images=50, # Use only the 50 first images of the folder,\n",
 "    normalization=\"percentile\",\n",
+"    checkpoint_path=trainer.checkpoint_callback.best_model_path,\n",
 ")\n",
 "inferer.infer()"
 ]
@@ -1050,9 +973,10 @@
 }
 ],
 "source": [
+"import matplotlib.patches as mpatches\n",
 "import numpy as np\n",
+"\n",
 "from cellseg_models_pytorch.utils import draw_thing_contours\n",
-"import matplotlib.patches as mpatches\n",
 "\n",
 "fig, ax = plt.subplots(5, 2, figsize=(10, 17))\n",
 "ax = ax.flatten()\n",

examples/pannuke_nuclei_segmentation_omnipose.ipynb

Lines changed: 14 additions & 10 deletions
@@ -23,7 +23,8 @@
 "outputs": [],
 "source": [
 "# !pip install cellseg-models-pytorch\n",
-"# !pip install accelerate"
+"# !pip install accelerate\n",
+"# !pip install torchmetrics"
 ]
 },
 {
@@ -43,10 +44,12 @@
 }
 ],
 "source": [
-"import torch\n",
+"from platform import python_version\n",
+"\n",
 "import accelerate\n",
+"import torch\n",
+"\n",
 "import cellseg_models_pytorch\n",
-"from platform import python_version\n",
 "\n",
 "print(\"torch version:\", torch.__version__)\n",
 "print(\"accelerate version:\", accelerate.__version__)\n",
@@ -88,6 +91,7 @@
 ],
 "source": [
 "from pathlib import Path\n",
+"\n",
 "from cellseg_models_pytorch.datamodules import PannukeDataModule\n",
 "\n",
 "fold_split = {\"fold1\": \"train\", \"fold2\": \"train\", \"fold3\": \"valid\"}\n",
@@ -148,13 +152,14 @@
 }
 ],
 "source": [
-"import numpy as np\n",
 "import matplotlib.pyplot as plt\n",
+"import numpy as np\n",
 "from skimage.color import label2rgb\n",
 "\n",
+"from cellseg_models_pytorch.transforms.functional import gen_omni_flow_maps\n",
+"\n",
 "# filehandler contains methods to read and write images and masks\n",
 "from cellseg_models_pytorch.utils import FileHandler\n",
-"from cellseg_models_pytorch.transforms.functional import gen_omni_flow_maps\n",
 "\n",
 "img_dir = save_dir / \"train\" / \"images\"\n",
 "mask_dir = save_dir / \"train\" / \"labels\"\n",
@@ -238,15 +243,14 @@
 "outputs": [],
 "source": [
 "import accelerate\n",
-"from accelerate.utils import set_seed\n",
-"\n",
 "import torch.nn as nn\n",
 "import torch.nn.functional as F\n",
+"from accelerate.utils import set_seed\n",
 "from torchmetrics import JaccardIndex, MeanSquaredError\n",
 "from tqdm import tqdm\n",
 "\n",
+"from cellseg_models_pytorch.losses import CELoss, DiceLoss, JointLoss, MultiTaskLoss\n",
 "from cellseg_models_pytorch.models import omnipose_base\n",
-"from cellseg_models_pytorch.losses import DiceLoss, CELoss, JointLoss, MultiTaskLoss\n",
 "\n",
 "\n",
 "# Quick wrapper for MAE loss to make it fit the JointLoss API\n",
@@ -961,10 +965,10 @@
 }
 ],
 "source": [
-"import numpy as np\n",
-"from cellseg_models_pytorch.utils import draw_thing_contours\n",
 "import matplotlib.patches as mpatches\n",
+"import numpy as np\n",
 "\n",
+"from cellseg_models_pytorch.utils import draw_thing_contours\n",
 "\n",
 "fig, ax = plt.subplots(5, 2, figsize=(10, 17))\n",
 "ax = ax.flatten()\n",
