|
41 | 41 | ],
|
42 | 42 | "source": [
|
43 | 43 | "# version info\n",
|
44 |    | - "import torch\n",

   | 44 | + "from platform import python_version\n",

   | 45 | + "\n",

45 | 46 | "import lightning\n",
|
   | 47 | + "import torch\n",

   | 48 | + "\n",

46 | 49 | "import cellseg_models_pytorch\n",
|
47 |    | - "from platform import python_version\n",

48 | 50 | "\n",
|
49 | 51 | "print(\"torch version:\", torch.__version__)\n",
|
50 | 52 | "print(\"lightning version:\", lightning.__version__)\n",
|
|
86 | 88 | ],
|
87 | 89 | "source": [
|
88 | 90 | "from pathlib import Path\n",
|
   | 91 | + "\n",

89 | 92 | "from cellseg_models_pytorch.datamodules import PannukeDataModule\n",
|
90 | 93 | "\n",
|
91 | 94 | "# fold1 and fold2 are used for training, fold3 is used for validation\n",
|
|
144 | 147 | }
|
145 | 148 | ],
|
146 | 149 | "source": [
|
147 |     | - "import numpy as np\n",

148 | 150 | "import matplotlib.pyplot as plt\n",

    | 151 | + "import numpy as np\n",

149 | 152 | "from skimage.color import label2rgb\n",

150 | 153 | "\n",

151 |     | - "# filehandler contains methods to read and write images and masks\n",

152 |     | - "from cellseg_models_pytorch.utils import FileHandler\n",

153 | 154 | "from cellseg_models_pytorch.transforms.functional import (\n",

154 |     | - "    gen_stardist_maps,\n",

155 | 155 | "    gen_dist_maps,\n",

    | 156 | + "    gen_stardist_maps,\n",

156 | 157 | ")\n",

157 | 158 | "\n",

    | 159 | + "# filehandler contains methods to read and write images and masks\n",

    | 160 | + "from cellseg_models_pytorch.utils import FileHandler\n",

    | 161 | + "\n",

158 | 162 | "img_dir = save_dir / \"train\" / \"images\"\n",
|
159 | 163 | "mask_dir = save_dir / \"train\" / \"labels\"\n",
|
160 | 164 | "imgs = sorted(img_dir.glob(\"*\"))\n",
|
|
216 | 220 | "metadata": {},
|
217 | 221 | "outputs": [],
|
218 | 222 | "source": [
|
    | 223 | + "from typing import Dict, List, Optional, Tuple\n",

    | 224 | + "\n",

    | 225 | + "import lightning.pytorch as pl\n",

219 | 226 | "import torch\n",
|
220 | 227 | "import torch.nn as nn\n",
|
221 | 228 | "import torch.optim as optim\n",
|
222 | 229 | "import torchmetrics\n",
|
223 |     | - "import lightning.pytorch as pl\n",

224 |     | - "from typing import List, Tuple, Dict, Optional\n",

225 | 230 | "\n",
|
226 | 231 | "\n",
|
227 | 232 | "class SegmentationExperiment(pl.LightningModule):\n",
|
|
392 | 397 | "source": [
|
393 | 398 | "import torch.optim as optim\n",
|
394 | 399 | "\n",
|
395 |     | - "from cellseg_models_pytorch.models import stardist_base_multiclass\n",

396 | 400 | "from cellseg_models_pytorch.losses import (\n",

397 | 401 | "    MAE,\n",

398 | 402 | "    MSE,\n",

399 |     | - "    DiceLoss,\n",

400 | 403 | "    BCELoss,\n",

401 | 404 | "    CELoss,\n",

    | 405 | + "    DiceLoss,\n",

402 | 406 | "    JointLoss,\n",

403 | 407 | "    MultiTaskLoss,\n",

404 | 408 | ")\n",

    | 409 | + "from cellseg_models_pytorch.models import stardist_base_multiclass\n",

405 | 410 | "\n",
|
406 | 411 | "# seed the experiment for reproducibility\n",
|
407 | 412 | "pl.seed_everything(42)\n",
|
|
888 | 893 | }
|
889 | 894 | ],
|
890 | 895 | "source": [
|
    | 896 | + "import matplotlib.patches as mpatches\n",

891 | 897 | "import numpy as np\n",
|
    | 898 | + "\n",

892 | 899 | "from cellseg_models_pytorch.utils import draw_thing_contours\n",
|
893 |     | - "import matplotlib.patches as mpatches\n",

894 | 900 | "\n",
|
895 | 901 | "fig, ax = plt.subplots(5, 2, figsize=(10, 17))\n",
|
896 | 902 | "ax = ax.flatten()\n",
|
|