|
36 | 36 | "# !pip install random --quiet --break-system-packages\n",
|
37 | 37 | "# !pip install wandb --quiet --break-system-packages\n",
|
38 | 38 | "# !pip install pytorch-lightning --quiet --break-system-packages\n",
|
39 |
| - "!pip install torchmetrics --quiet --break-system-packages" |
| 39 | + "# !pip install torchmetrics --quiet --break-system-packages" |
40 | 40 | ]
|
41 | 41 | },
|
42 | 42 | {
|
|
68 | 68 | "from nni.nas.evaluator.pytorch import ClassificationModule\n",
|
69 | 69 | "from nni.nas.evaluator.pytorch import Lightning, Trainer\n",
|
70 | 70 | "\n",
|
| 71 | + "from darts_classification_module import DartsClassificationModule\n", |
| 72 | + "\n", |
71 | 73 | "import matplotlib.pyplot as plt\n",
|
72 | 74 | "from IPython.display import clear_output\n",
|
73 | 75 | "\n",
|
|
206 | 208 | "arch_dicts = load_json_from_directory('../home/best_models_greed_cluster')"
|
207 | 209 | ]
|
208 | 210 | },
|
209 |
| - { |
210 |
| - "cell_type": "code", |
211 |
| - "execution_count": 6, |
212 |
| - "metadata": { |
213 |
| - "execution": { |
214 |
| - "iopub.execute_input": "2025-05-09T15:55:45.281299Z", |
215 |
| - "iopub.status.busy": "2025-05-09T15:55:45.280801Z", |
216 |
| - "iopub.status.idle": "2025-05-09T15:55:45.290522Z", |
217 |
| - "shell.execute_reply": "2025-05-09T15:55:45.289841Z", |
218 |
| - "shell.execute_reply.started": "2025-05-09T15:55:45.281274Z" |
219 |
| - }, |
220 |
| - "trusted": true |
221 |
| - }, |
222 |
| - "outputs": [], |
223 |
| - "source": [ |
224 |
| - "class DartsClassificationModule(ClassificationModule):\n", |
225 |
| - " def __init__(\n", |
226 |
| - " self,\n", |
227 |
| - " learning_rate: float = 0.001,\n", |
228 |
| - " weight_decay: float = 0.,\n", |
229 |
| - " auxiliary_loss_weight: float = 0.4,\n", |
230 |
| - " max_epochs: int = 600\n", |
231 |
| - " ):\n", |
232 |
| - " self.auxiliary_loss_weight = auxiliary_loss_weight\n", |
233 |
| - " # Training length will be used in LR scheduler\n", |
234 |
| - " self.max_epochs = max_epochs\n", |
235 |
| - " super().__init__(learning_rate=learning_rate, weight_decay=weight_decay, export_onnx=False, num_classes=10)\n", |
236 |
| - " \n", |
237 |
| - " def configure_optimizers(self):\n", |
238 |
| - " \"\"\"Customized optimizer with momentum, as well as a scheduler.\"\"\"\n", |
239 |
| - " optimizer = torch.optim.SGD(\n", |
240 |
| - " self.parameters(),\n", |
241 |
| - " momentum=0.9,\n", |
242 |
| - " lr=self.hparams.learning_rate,\n", |
243 |
| - " weight_decay=self.hparams.weight_decay\n", |
244 |
| - " )\n", |
245 |
| - " # Cosine annealing scheduler with T_max equal to total epochs\n", |
246 |
| - " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n", |
247 |
| - " optimizer, T_max=self.max_epochs, eta_min=1e-3\n", |
248 |
| - " )\n", |
249 |
| - " return {\n", |
250 |
| - " 'optimizer': optimizer,\n", |
251 |
| - " 'lr_scheduler': {\n", |
252 |
| - " 'scheduler': scheduler,\n", |
253 |
| - " 'interval': 'epoch',\n", |
254 |
| - " 'frequency': 1\n", |
255 |
| - " }\n", |
256 |
| - " }\n", |
257 |
| - "\n", |
258 |
| - " def training_step(self, batch, batch_idx):\n", |
259 |
| - " \"\"\"Training step, customized with auxiliary loss and flexible unpacking.\"\"\"\n", |
260 |
| - " x, y = batch\n", |
261 |
| - " out = self(x)\n", |
262 |
| - " # Handle auxiliary output if present\n", |
263 |
| - " if self.auxiliary_loss_weight and isinstance(out, (tuple, list)) and len(out) == 2:\n", |
264 |
| - " y_hat, y_aux = out\n", |
265 |
| - " loss_main = self.criterion(y_hat, y)\n", |
266 |
| - " loss_aux = self.criterion(y_aux, y)\n", |
267 |
| - " self.log('train_loss_main', loss_main)\n", |
268 |
| - " self.log('train_loss_aux', loss_aux)\n", |
269 |
| - " loss = loss_main + self.auxiliary_loss_weight * loss_aux\n", |
270 |
| - " else:\n", |
271 |
| - " # single output or no auxiliary\n", |
272 |
| - " y_hat = out[0] if isinstance(out, (tuple, list)) else out\n", |
273 |
| - " loss = self.criterion(y_hat, y)\n", |
274 |
| - " self.log('train_loss', loss, prog_bar=True)\n", |
275 |
| - " for name, metric in self.metrics.items():\n", |
276 |
| - " self.log('train_' + name, metric(y_hat, y), prog_bar=True)\n", |
277 |
| - " return loss\n", |
278 |
| - "\n", |
279 |
| - " def on_train_epoch_start(self):\n", |
280 |
| - " # Handle DataParallel wrapper when adjusting drop path\n", |
281 |
| - " model = self.trainer.model\n", |
282 |
| - " if isinstance(model, DataParallel):\n", |
283 |
| - " target_model = model.module\n", |
284 |
| - " else:\n", |
285 |
| - " target_model = model\n", |
286 |
| - "\n", |
287 |
| - " # Set drop path probability before every epoch, scaled by epoch ratio\n", |
288 |
| - " if hasattr(target_model, 'set_drop_path_prob') and hasattr(target_model, 'drop_path_prob'):\n", |
289 |
| - " drop_prob = target_model.drop_path_prob * self.current_epoch / self.max_epochs\n", |
290 |
| - " target_model.set_drop_path_prob(drop_prob)\n", |
291 |
| - "\n", |
292 |
| - " # Logging learning rate at the beginning of every epoch\n", |
293 |
| - " lr = self.trainer.optimizers[0].param_groups[0]['lr']\n", |
294 |
| - " self.log('lr', lr)\n" |
295 |
| - ] |
296 |
| - }, |
297 | 211 | {
|
298 | 212 | "cell_type": "code",
|
299 | 213 | "execution_count": 10,
|
|
351 | 265 | " gradient_clip_val=5.0,\n",
|
352 | 266 | " max_epochs=max_epochs,\n",
|
353 | 267 | " fast_dev_run=fast_dev_run,\n",
|
354 |
| - " precision='16-mixed',\n", |
355 | 268 | " logger=WandbLogger(experiment=wandb.run)\n",
|
356 | 269 | " ),\n",
|
357 | 270 | " train_dataloaders=train_loader,\n",
|
|
363 | 276 | " return model"
|
364 | 277 | ]
|
365 | 278 | },
|
366 |
| - { |
367 |
| - "cell_type": "code", |
368 |
| - "execution_count": 11, |
369 |
| - "metadata": {}, |
370 |
| - "outputs": [], |
371 |
| - "source": [ |
372 |
| - "torch.set_float32_matmul_precision('medium')" |
373 |
| - ] |
374 |
| - }, |
375 | 279 | {
|
376 | 280 | "cell_type": "code",
|
377 | 281 | "execution_count": null,
|
|
624 | 528 | " folder_name (str, необязательно): Имя папки для сохранения результатов. По умолчанию \"results\".\n",
|
625 | 529 | " Исключения:\n",
|
626 | 530 | " ValueError: Если количество моделей и архитектур не совпадает.\n",
|
627 |
| - " Результаты:\n", |
628 |
| - " Для каждой модели создается файл JSON с результатами, содержащий:\n", |
| 531 | + " Результаты: Для каждой модели создается файл JSON с результатами, содержащий:\n", |
629 | 532 | " - architecture: Архитектура модели.\n",
|
630 | 533 | " - valid_predictions: Предсказания модели на валидационном наборе данных.\n",
|
631 | 534 | " - valid_accuracy: Точность модели на валидационном наборе данных.\n",
|
|
0 commit comments