|
226 | 226 | "import torch\n",
|
227 | 227 | "import torch.nn as nn\n",
|
228 | 228 | "import torch.optim as optim\n",
|
229 |
| - "import torchmetrics\n", |
230 | 229 | "\n",
|
231 | 230 | "\n",
|
232 | 231 | "class SegmentationExperiment(pl.LightningModule):\n",
|
|
236 | 235 | " multitask_loss: Dict[str, nn.Module],\n",
|
237 | 236 | " optimizer: optim.Optimizer,\n",
|
238 | 237 | " scheduler: optim.lr_scheduler._LRScheduler,\n",
|
239 |
| - " branch_metrics: Optional[Dict[str, List[torchmetrics.Metric]]] = None,\n", |
240 | 238 | " optimizer_kwargs: Optional[Dict[str, float]] = None,\n",
|
241 | 239 | " scheduler_kwargs: Optional[Dict[str, float]] = None,\n",
|
242 | 240 | " **kwargs,\n",
|
|
253 | 251 | " self.optimizer_kwargs = optimizer_kwargs or {}\n",
|
254 | 252 | " self.scheduler_kwargs = scheduler_kwargs or {}\n",
|
255 | 253 | "\n",
|
256 |
| - " self.branch_metrics = branch_metrics\n", |
257 | 254 | " self.criterion = multitask_loss\n",
|
258 | 255 | "\n",
|
259 | 256 | " self._validate_branch_args()\n",
|
|
0 commit comments