
Commit 4e6f64d

docs: minor typo fixes in example notebooks

1 parent 372db65

2 files changed: +30 -197 lines

examples/pannuke_nuclei_segmentation_cellpose.ipynb

Lines changed: 13 additions & 11 deletions
@@ -41,10 +41,12 @@
 }
 ],
 "source": [
-"import torch\n",
+"from platform import python_version\n",
+"\n",
 "import accelerate\n",
+"import torch\n",
+"\n",
 "import cellseg_models_pytorch\n",
-"from platform import python_version\n",
 "\n",
 "print(\"torch version:\", torch.__version__)\n",
 "print(\"accelerate version:\", accelerate.__version__)\n",
@@ -86,6 +88,7 @@
 ],
 "source": [
 "from pathlib import Path\n",
+"\n",
 "from cellseg_models_pytorch.datamodules import PannukeDataModule\n",
 "\n",
 "fold_split = {\"fold1\": \"train\", \"fold2\": \"train\", \"fold3\": \"valid\"}\n",
@@ -146,13 +149,14 @@
 }
 ],
 "source": [
-"import numpy as np\n",
 "import matplotlib.pyplot as plt\n",
+"import numpy as np\n",
 "from skimage.color import label2rgb\n",
 "\n",
+"from cellseg_models_pytorch.transforms.functional import gen_flow_maps\n",
+"\n",
 "# filehandler contains methods to read and write images and masks\n",
 "from cellseg_models_pytorch.utils import FileHandler\n",
-"from cellseg_models_pytorch.transforms.functional import gen_flow_maps\n",
 "\n",
 "img_dir = save_dir / \"train\" / \"images\"\n",
 "mask_dir = save_dir / \"train\" / \"labels\"\n",
@@ -216,7 +220,7 @@
 "First, we will define the CellPose nuclei segmentation model with an `imagenet` pre-trained encoder. Specifically, we will use the `convnext_small` backbone for this demonstration. Many more encoders can be used, since these are imported from the `timm` library. There is also support for some transformer-based encoders, but these are shown in other notebooks.\n",
 "\n",
 "**Branch losses.**\n",
-"For each output of the model, we define a joint-loss function. These losses are summed together during backprop to form a multi-task loss. For the `\"cellpose\"` branch output we set a joint-loss composed of SSIM-loss (Structural Similarity Index) and MSE-loss (Mean Squared Error), and for the `\"type\"` (cell type predictions) output we set a joint-loss composed of CE-loss (Cross Entropy) and DICE-loss. For the CE-loss, we will also be using [spectral decoupling](https://arxiv.org/abs/2011.09468) to regularize the model.\n",
+"For each output of the model, we define a joint-loss function. These losses are summed together during backprop to form a multi-task loss. For the `\"cellpose\"` branch output we set the MSE-loss (Mean Squared Error), and for the `\"type\"` (cell type predictions) output we set a joint-loss composed of CE-loss (Cross Entropy) and DICE-loss. For the CE-loss, we will also be using [spectral decoupling](https://arxiv.org/abs/2011.09468) to regularize the model.\n",
 "\n",
 "\n",
 "**Logging metrics.**\n",
@@ -236,16 +240,15 @@
 "outputs": [],
 "source": [
 "import accelerate\n",
-"from accelerate.utils import set_seed\n",
-"\n",
 "import torch.nn as nn\n",
 "import torch.nn.functional as F\n",
+"from accelerate.utils import set_seed\n",
 "from torch.optim.lr_scheduler import OneCycleLR\n",
 "from torchmetrics import JaccardIndex, MeanSquaredError\n",
 "from tqdm import tqdm\n",
 "\n",
+"from cellseg_models_pytorch.losses import CELoss, DiceLoss, JointLoss, MultiTaskLoss\n",
 "from cellseg_models_pytorch.models import cellpose_base\n",
-"from cellseg_models_pytorch.losses import DiceLoss, CELoss, JointLoss, MultiTaskLoss\n",
 "\n",
 "\n",
 "# Quick wrapper for MSE loss to make it fit the JointLoss API\n",
@@ -821,7 +824,6 @@
 "source": [
 "from accelerate import load_checkpoint_and_dispatch\n",
 "\n",
-"\n",
 "# The model state dict was saved in the project_dir\n",
 "model = cellpose_base(\n",
 "    enc_name=\"convnext_small\",\n",
@@ -967,10 +969,10 @@
 }
 ],
 "source": [
-"import numpy as np\n",
-"from cellseg_models_pytorch.utils import draw_thing_contours\n",
 "import matplotlib.patches as mpatches\n",
+"import numpy as np\n",
 "\n",
+"from cellseg_models_pytorch.utils import draw_thing_contours\n",
 "\n",
 "fig, ax = plt.subplots(5, 2, figsize=(10, 17))\n",
 "ax = ax.flatten()\n",

examples/pannuke_nuclei_segmentation_cellpose_dinov2.ipynb

Lines changed: 17 additions & 186 deletions
@@ -41,10 +41,12 @@
 }
 ],
 "source": [
-"import torch\n",
+"from platform import python_version\n",
+"\n",
 "import lightning\n",
+"import torch\n",
+"\n",
 "import cellseg_models_pytorch\n",
-"from platform import python_version\n",
 "\n",
 "print(\"torch version:\", torch.__version__)\n",
 "print(\"lightning version:\", lightning.__version__)\n",
@@ -88,6 +90,7 @@
 ],
 "source": [
 "from pathlib import Path\n",
+"\n",
 "from cellseg_models_pytorch.datamodules import PannukeDataModule\n",
 "\n",
 "fold_split = {\"fold1\": \"train\", \"fold2\": \"train\", \"fold3\": \"valid\"}\n",
@@ -150,13 +153,14 @@
 }
 ],
 "source": [
-"import numpy as np\n",
 "import matplotlib.pyplot as plt\n",
+"import numpy as np\n",
 "from skimage.color import label2rgb\n",
 "\n",
+"from cellseg_models_pytorch.transforms.functional import gen_flow_maps\n",
+"\n",
 "# filehandler contains methods to read and write images and masks\n",
 "from cellseg_models_pytorch.utils import FileHandler\n",
-"from cellseg_models_pytorch.transforms.functional import gen_flow_maps\n",
 "\n",
 "img_dir = save_dir / \"train\" / \"images\"\n",
 "mask_dir = save_dir / \"train\" / \"labels\"\n",
@@ -219,11 +223,12 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"from typing import Dict, List, Tuple\n",
+"\n",
+"import lightning.pytorch as pl\n",
 "import torch\n",
 "import torch.nn as nn\n",
 "import torch.optim as optim\n",
-"import lightning.pytorch as pl\n",
-"from typing import List, Tuple, Dict\n",
 "\n",
 "from cellseg_models_pytorch.losses import MultiTaskLoss\n",
 "\n",
@@ -388,7 +393,7 @@
 }
 ],
 "source": [
-"from cellseg_models_pytorch.losses import JointLoss, CELoss, DiceLoss, SSIM, MSE\n",
+"from cellseg_models_pytorch.losses import MSE, SSIM, CELoss, DiceLoss, JointLoss\n",
 "from cellseg_models_pytorch.models import cellpose_base\n",
 "\n",
 "# initialize cellpose\n",
@@ -464,183 +469,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": null,
 "metadata": {},
-"outputs": [
-[... ~170 lines of stripped cell outputs: a CUDA Tensor Cores warning suggesting `torch.set_float32_matmul_precision`, "Found all folds. Skip downloading." / "Found processed pannuke data." messages, the Lightning model summary (CellPoseUnet + MultiTaskLoss, 92.4 M total params: 3.8 M trainable, 88.6 M non-trainable, ~369 MB), progress-bar widget JSON, per-epoch checkpoint logs ('val_loss' improving from 1.21952 at epoch 0 to 0.94926 at epoch 4, each saved as top 1 to /home/leos/pannuke/dino_cellpose/), "`Trainer.fit` stopped: `max_epochs=5` reached.", and a stray "gg" debug print ...]
-],
+"outputs": [],
 "source": [
 "# Train the model\n",
 "trainer.fit(model=experiment, datamodule=pannuke_module)"
@@ -688,8 +519,8 @@
 ],
 "source": [
 "import torch.nn.functional as F\n",
-"from cellseg_models_pytorch.utils import percentile_normalize_torch\n",
 "\n",
+"from cellseg_models_pytorch.utils import percentile_normalize_torch\n",
 "\n",
 "img_dir = save_dir / \"valid\" / \"images\"\n",
 "mask_dir = save_dir / \"valid\" / \"labels\"\n",
@@ -816,10 +647,10 @@
 }
 ],
 "source": [
-"import numpy as np\n",
-"from cellseg_models_pytorch.utils import draw_thing_contours\n",
 "import matplotlib.patches as mpatches\n",
+"import numpy as np\n",
 "\n",
+"from cellseg_models_pytorch.utils import draw_thing_contours\n",
 "\n",
 "fig, ax = plt.subplots(5, 2, figsize=(10, 17))\n",
 "ax = ax.flatten()\n",
