|
41 | 41 | ],
|
42 | 42 | "source": [
|
43 | 43 | "# version info\n",
|
44 |
| - "import torch\n", |
| 44 | + "from platform import python_version\n", |
| 45 | + "\n", |
45 | 46 | "import lightning\n",
|
| 47 | + "import torch\n", |
| 48 | + "\n", |
46 | 49 | "import cellseg_models_pytorch\n",
|
47 |
| - "from platform import python_version\n", |
48 | 50 | "\n",
|
49 | 51 | "print(\"torch version:\", torch.__version__)\n",
|
50 | 52 | "print(\"lightning version:\", lightning.__version__)\n",
|
|
86 | 88 | ],
|
87 | 89 | "source": [
|
88 | 90 | "from pathlib import Path\n",
|
| 91 | + "\n", |
89 | 92 | "from cellseg_models_pytorch.datamodules import PannukeDataModule\n",
|
90 | 93 | "\n",
|
91 | 94 | "# fold1 and fold2 are used for training, fold3 is used for validation\n",
|
|
145 | 148 | }
|
146 | 149 | ],
|
147 | 150 | "source": [
|
148 |
| - "import numpy as np\n", |
149 | 151 | "import matplotlib.pyplot as plt\n",
|
| 152 | + "import numpy as np\n", |
150 | 153 | "from skimage.color import label2rgb\n",
|
151 | 154 | "\n",
|
152 |
| - "# filehandler contains methods to read and write images and masks\n", |
153 |
| - "from cellseg_models_pytorch.utils import FileHandler\n", |
154 | 155 | "from cellseg_models_pytorch.transforms.functional import (\n",
|
155 |
| - " gen_stardist_maps,\n", |
156 | 156 | " gen_dist_maps,\n",
|
| 157 | + " gen_stardist_maps,\n", |
157 | 158 | ")\n",
|
158 | 159 | "\n",
|
| 160 | + "# filehandler contains methods to read and write images and masks\n", |
| 161 | + "from cellseg_models_pytorch.utils import FileHandler\n", |
| 162 | + "\n", |
159 | 163 | "img_dir = save_dir / \"train\" / \"images\"\n",
|
160 | 164 | "mask_dir = save_dir / \"train\" / \"labels\"\n",
|
161 | 165 | "imgs = sorted(img_dir.glob(\"*\"))\n",
|
|
216 | 220 | "metadata": {},
|
217 | 221 | "outputs": [],
|
218 | 222 | "source": [
|
| 223 | + "from typing import Dict, List, Optional, Tuple\n", |
| 224 | + "\n", |
| 225 | + "import lightning.pytorch as pl\n", |
219 | 226 | "import torch\n",
|
220 | 227 | "import torch.nn as nn\n",
|
221 | 228 | "import torch.optim as optim\n",
|
222 |
| - "import torchmetrics\n", |
223 |
| - "import lightning.pytorch as pl\n", |
224 |
| - "from copy import deepcopy\n", |
225 |
| - "from typing import List, Tuple, Dict, Optional\n", |
226 | 229 | "\n",
|
227 | 230 | "\n",
|
228 | 231 | "class SegmentationExperiment(pl.LightningModule):\n",
|
|
232 | 235 | " multitask_loss: Dict[str, nn.Module],\n",
|
233 | 236 | " optimizer: optim.Optimizer,\n",
|
234 | 237 | " scheduler: optim.lr_scheduler._LRScheduler,\n",
|
235 |
| - " branch_metrics: Optional[Dict[str, List[torchmetrics.Metric]]] = None,\n", |
236 | 238 | " optimizer_kwargs: Optional[Dict[str, float]] = None,\n",
|
237 | 239 | " scheduler_kwargs: Optional[Dict[str, float]] = None,\n",
|
238 | 240 | " **kwargs,\n",
|
|
249 | 251 | " self.optimizer_kwargs = optimizer_kwargs or {}\n",
|
250 | 252 | " self.scheduler_kwargs = scheduler_kwargs or {}\n",
|
251 | 253 | "\n",
|
252 |
| - " self.branch_metrics = branch_metrics\n", |
253 | 254 | " self.criterion = multitask_loss\n",
|
254 | 255 | "\n",
|
255 | 256 | " self.save_hyperparameters(ignore=\"model\")\n",
|
256 | 257 | "\n",
|
257 |
| - " metrics = self.configure_metrics()\n", |
258 |
| - " self.train_metrics = deepcopy(metrics)\n", |
259 |
| - " self.val_metrics = deepcopy(metrics)\n", |
260 |
| - " self.test_metrics = deepcopy(metrics)\n", |
261 |
| - "\n", |
262 | 258 | " def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:\n",
|
263 | 259 | " \"\"\"Forward pass.\"\"\"\n",
|
264 | 260 | " return self.model(x)\n",
|
|
298 | 294 | " # forward backward pass\n",
|
299 | 295 | " loss, soft_masks, targets = self.step(batch)\n",
|
300 | 296 | "\n",
|
301 |
| - " # compute training metrics\n", |
302 |
| - " metrics = self.compute_metrics(soft_masks, targets, \"train\")\n", |
303 |
| - "\n", |
304 | 297 | " # log the loss\n",
|
305 | 298 | " self.log(\"train_loss\", loss, on_step=True, on_epoch=False, prog_bar=True)\n",
|
306 | 299 | "\n",
|
307 |
| - " # log the metrics\n", |
308 |
| - " for metric_name, metric in metrics.items():\n", |
309 |
| - " self.log(metric_name, metric, on_step=True, on_epoch=False, prog_bar=True)\n", |
310 |
| - "\n", |
311 | 300 | " return loss\n",
|
312 | 301 | "\n",
|
313 | 302 | " def validation_step(\n",
|
|
317 | 306 | " # forward pass\n",
|
318 | 307 | " loss, soft_masks, targets = self.step(batch)\n",
|
319 | 308 | "\n",
|
320 |
| - " # compute validation metrics\n", |
321 |
| - " metrics = self.compute_metrics(soft_masks, targets, \"val\")\n", |
322 |
| - "\n", |
323 | 309 | " # log the loss\n",
|
324 | 310 | " self.log(\"val_loss\", loss, on_step=False, on_epoch=True, prog_bar=False)\n",
|
325 | 311 | "\n",
|
326 |
| - " # log the metrics\n", |
327 |
| - " for metric_name, metric in metrics.items():\n", |
328 |
| - " self.log(metric_name, metric, on_step=True, on_epoch=True, prog_bar=True)\n", |
329 |
| - "\n", |
330 | 312 | " return loss\n",
|
331 | 313 | "\n",
|
332 | 314 | " def test_step(\n",
|
|
336 | 318 | " # forward pass\n",
|
337 | 319 | " loss, soft_masks, targets = self.step(batch)\n",
|
338 | 320 | "\n",
|
339 |
| - " # compute validation metrics\n", |
340 |
| - " metrics = self.compute_metrics(soft_masks, targets, \"test\")\n", |
341 |
| - "\n", |
342 | 321 | " # log the loss\n",
|
343 | 322 | " self.log(\"test_loss\", loss, on_step=False, on_epoch=True, prog_bar=False)\n",
|
344 | 323 | "\n",
|
345 |
| - " # log the metrics\n", |
346 |
| - " for metric_name, metric in metrics.items():\n", |
347 |
| - " self.log(metric_name, metric, on_step=False, on_epoch=True, prog_bar=False)\n", |
348 |
| - "\n", |
349 | 324 | " return loss\n",
|
350 | 325 | "\n",
|
351 |
| - " def compute_metrics(\n", |
352 |
| - " self,\n", |
353 |
| - " preds: Dict[str, torch.Tensor],\n", |
354 |
| - " targets: Dict[str, torch.Tensor],\n", |
355 |
| - " phase: str,\n", |
356 |
| - " ) -> Dict[str, torch.Tensor]:\n", |
357 |
| - " \"\"\"Compute metrics for logging.\"\"\"\n", |
358 |
| - " if phase == \"train\":\n", |
359 |
| - " metrics_dict = self.train_metrics\n", |
360 |
| - " elif phase == \"val\":\n", |
361 |
| - " metrics_dict = self.val_metrics\n", |
362 |
| - " elif phase == \"test\":\n", |
363 |
| - " metrics_dict = self.test_metrics\n", |
364 |
| - "\n", |
365 |
| - " ret = {}\n", |
366 |
| - " for metric_name, metric in metrics_dict.items():\n", |
367 |
| - " if metric is not None:\n", |
368 |
| - " branch = metric_name.split(\"_\")[0]\n", |
369 |
| - " ret[metric_name] = metric(preds[branch], targets[branch])\n", |
370 |
| - "\n", |
371 |
| - " return ret\n", |
372 |
| - "\n", |
373 |
| - " def configure_metrics(self) -> Dict[str, torchmetrics.Metric]:\n", |
374 |
| - " \"\"\"Configure the metrics for the model.\"\"\"\n", |
375 |
| - " # We can put all the metrics in a ModuleDict and return it.\n", |
376 |
| - " metrics = nn.ModuleDict()\n", |
377 |
| - " for branch, metric_list in self.branch_metrics.items():\n", |
378 |
| - " for metric in metric_list:\n", |
379 |
| - " metrics[f\"{branch}_{metric}\"] = metric\n", |
380 |
| - "\n", |
381 |
| - " return metrics\n", |
382 |
| - "\n", |
383 | 326 | " def configure_optimizers(self) -> List[optim.Optimizer]:\n",
|
384 | 327 | " \"\"\"Configure the optimizers for the model.\"\"\"\n",
|
385 | 328 | " opt = self.optimizer(self.parameters(), **self.optimizer_kwargs)\n",
|
|
411 | 354 | "\n",
|
412 | 355 | "**Branch losses.** For each output of the model, we define a joint-loss function. These losses are summed together during backprop to form a multi-task loss. In `CPP-Net`, we get two outputs for the `stardist` output instead of one. The reason is that `CPP-Net` was designed to refine the `stardist`-output for better segmentation performance, thus, the model returns the both the regular `stardist` and the refined `stardist_refined` output. For both of these outputs we set the masked `MAE`-loss. For the `\"dist\"`-outputs we use the masked `BCELoss` (binary-cross-entropy) with `MSE`, and for the `\"type\"`-output we will use the masked multiclass categorical `CELoss` with `DiceLoss`.\n",
|
413 | 356 | "\n",
|
414 |
| - "**Logging metrics.**\n", |
415 |
| - "For the nuclei type masks we will monitor Jaccard-index i.e. the mIoU metric during training. The metric is averaged over the class-specific mIoUs. The metrics are from the `torchmetrics` library. Also, we will monitor the MSE-metric for the `stardist`-outputs.\n", |
416 |
| - "\n", |
417 | 357 | "**Optimizer and scheduler.**\n",
|
418 | 358 | "The optimizer used here is [AdamW](https://arxiv.org/abs/1711.05101). The learning rate is scheduled with the `ReduceLROnPlateau` schedule. The learning rate will be set to 0.0003. \n",
|
419 | 359 | "\n",
|
|
438 | 378 | ],
|
439 | 379 | "source": [
|
440 | 380 | "import torch.optim as optim\n",
|
441 |
| - "from torchmetrics import JaccardIndex, MeanSquaredError\n", |
442 | 381 | "\n",
|
443 |
| - "from cellseg_models_pytorch.models import cppnet_base_multiclass\n", |
444 | 382 | "from cellseg_models_pytorch.losses import (\n",
|
445 | 383 | " MAE,\n",
|
446 | 384 | " MSE,\n",
|
447 |
| - " DiceLoss,\n", |
448 | 385 | " BCELoss,\n",
|
449 | 386 | " CELoss,\n",
|
| 387 | + " DiceLoss,\n", |
450 | 388 | " JointLoss,\n",
|
451 | 389 | " MultiTaskLoss,\n",
|
452 | 390 | ")\n",
|
| 391 | + "from cellseg_models_pytorch.models import cppnet_base_multiclass\n", |
453 | 392 | "\n",
|
454 | 393 | "# seed the experiment for reproducibility\n",
|
455 | 394 | "pl.seed_everything(42)\n",
|
|
481 | 420 | " }, # weights for each branch\n",
|
482 | 421 | ")\n",
|
483 | 422 | "\n",
|
484 |
| - "# Define the metrics that will be logged for each model output.\n", |
485 |
| - "# We will log the Jaccard Index for the type maps and the Mean Squared Error\n", |
486 |
| - "# for the horizontal and vertical gradients output. For the type maps, the metric\n", |
487 |
| - "# is the average computed over each class separately ('macro').\n", |
488 |
| - "branch_metrics = {\n", |
489 |
| - " \"dist\": [None],\n", |
490 |
| - " \"type\": [\n", |
491 |
| - " JaccardIndex(\n", |
492 |
| - " task=\"multiclass\", average=\"macro\", num_classes=6, compute_on_cpu=True\n", |
493 |
| - " )\n", |
494 |
| - " ],\n", |
495 |
| - " \"stardist\": [MeanSquaredError(compute_on_cpu=True)],\n", |
496 |
| - " \"stardist_refined\": [MeanSquaredError(compute_on_cpu=True)],\n", |
497 |
| - "}\n", |
498 |
| - "\n", |
499 | 423 | "# Initialize the optimizer.\n",
|
500 | 424 | "# We will be using the AdamW optimizer from the torch.optim library with learning rate of 0.0003.\n",
|
501 | 425 | "adamw = optim.AdamW\n",
|
|
511 | 435 | "experiment = SegmentationExperiment(\n",
|
512 | 436 | " model=model,\n",
|
513 | 437 | " multitask_loss=multitask_loss,\n",
|
514 |
| - " branch_metrics=branch_metrics,\n", |
515 | 438 | " optimizer=adamw,\n",
|
516 | 439 | " scheduler=scheduler,\n",
|
517 | 440 | " optimizer_kwargs=optimizer_kwargs,\n",
|
|
560 | 483 | " monitor=\"val_loss\",\n",
|
561 | 484 | " mode=\"min\",\n",
|
562 | 485 | ")\n",
|
563 |
| - "ckpt_callback.CHECKPOINT_NAME_LAST = f\"cppnet_last\"\n", |
| 486 | + "ckpt_callback.CHECKPOINT_NAME_LAST = \"cppnet_last\"\n", |
564 | 487 | "callbacks.append(ckpt_callback)\n",
|
565 | 488 | "\n",
|
566 | 489 | "# Lightning training\n",
|
|
1002 | 925 | "inferer = ResizeInferer(\n",
|
1003 | 926 | " model=experiment,\n",
|
1004 | 927 | " input_path=save_dir / \"valid\" / \"images\",\n",
|
1005 |
| - " checkpoint_path=Path.home() / \"pannuke\" / \"cppnet\" / \"cppnet_last.ckpt\",\n", |
1006 | 928 | " out_activations=out_acts,\n",
|
1007 | 929 | " out_boundary_weights=out_weights,\n",
|
1008 | 930 | " resize=(256, 256), # Not actually resizing anything,\n",
|
1009 | 931 | " instance_postproc=\"stardist\",\n",
|
1010 | 932 | " batch_size=8,\n",
|
1011 | 933 | " n_images=50, # Use only the 50 first images of the folder,\n",
|
1012 | 934 | " normalization=\"percentile\",\n",
|
| 935 | + " checkpoint_path=trainer.checkpoint_callback.best_model_path,\n", |
1013 | 936 | ")\n",
|
1014 | 937 | "inferer.infer()"
|
1015 | 938 | ]
|
|
1050 | 973 | }
|
1051 | 974 | ],
|
1052 | 975 | "source": [
|
| 976 | + "import matplotlib.patches as mpatches\n", |
1053 | 977 | "import numpy as np\n",
|
| 978 | + "\n", |
1054 | 979 | "from cellseg_models_pytorch.utils import draw_thing_contours\n",
|
1055 |
| - "import matplotlib.patches as mpatches\n", |
1056 | 980 | "\n",
|
1057 | 981 | "fig, ax = plt.subplots(5, 2, figsize=(10, 17))\n",
|
1058 | 982 | "ax = ax.flatten()\n",
|
|
0 commit comments