
Commit 5a297a2

modify train models for preparing dataset
1 parent 849fe95 commit 5a297a2

8 files changed: +180, -222175 lines

code/darts_classification_module.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+import torch
+from nni.nas.evaluator.pytorch import ClassificationModule
+from torch.nn import DataParallel
+
+
+class DartsClassificationModule(ClassificationModule):
+    def __init__(
+        self,
+        learning_rate: float = 0.001,
+        weight_decay: float = 0.,
+        auxiliary_loss_weight: float = 0.4,
+        max_epochs: int = 600
+    ):
+        self.auxiliary_loss_weight = auxiliary_loss_weight
+        # Training length will be used in LR scheduler
+        self.max_epochs = max_epochs
+        super().__init__(learning_rate=learning_rate, weight_decay=weight_decay, export_onnx=False, num_classes=10)
+
+    def configure_optimizers(self):
+        """Customized optimizer with momentum, as well as a scheduler."""
+        optimizer = torch.optim.SGD(
+            self.parameters(),
+            momentum=0.9,
+            lr=self.hparams.learning_rate,
+            weight_decay=self.hparams.weight_decay
+        )
+        # Cosine annealing scheduler with T_max equal to total epochs
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, T_max=self.max_epochs, eta_min=1e-3
+        )
+        return {
+            'optimizer': optimizer,
+            'lr_scheduler': {
+                'scheduler': scheduler,
+                'interval': 'epoch',
+                'frequency': 1
+            }
+        }
+
+    def training_step(self, batch, batch_idx):
+        """Training step, customized with auxiliary loss and flexible unpacking."""
+        x, y = batch
+        out = self(x)
+        # Handle auxiliary output if present
+        if self.auxiliary_loss_weight and isinstance(out, (tuple, list)) and len(out) == 2:
+            y_hat, y_aux = out
+            loss_main = self.criterion(y_hat, y)
+            loss_aux = self.criterion(y_aux, y)
+            self.log('train_loss_main', loss_main)
+            self.log('train_loss_aux', loss_aux)
+            loss = loss_main + self.auxiliary_loss_weight * loss_aux
+        else:
+            # single output or no auxiliary
+            y_hat = out[0] if isinstance(out, (tuple, list)) else out
+            loss = self.criterion(y_hat, y)
+        self.log('train_loss', loss, prog_bar=True)
+        for name, metric in self.metrics.items():
+            self.log('train_' + name, metric(y_hat, y), prog_bar=True)
+        return loss
+
+    def on_train_epoch_start(self):
+        # Handle DataParallel wrapper when adjusting drop path
+        model = self.trainer.model
+        if isinstance(model, DataParallel):
+            target_model = model.module
+        else:
+            target_model = model
+
+        # Set drop path probability before every epoch, scaled by epoch ratio
+        if hasattr(target_model, 'set_drop_path_prob') and hasattr(target_model, 'drop_path_prob'):
+            drop_prob = target_model.drop_path_prob * self.current_epoch / self.max_epochs
+            target_model.set_drop_path_prob(drop_prob)
+
+        # Logging learning rate at the beginning of every epoch
+        lr = self.trainer.optimizers[0].param_groups[0]['lr']
+        self.log('lr', lr)
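
A rough usage sketch (not part of this commit): the module above is consumed by the train-best-models notebook through NNI's Lightning/Trainer evaluator, which that notebook already imports. The dummy dataloaders, the hyperparameter values marked as illustrative, and the final fit call are assumptions for the sake of a self-contained example, not code from the repository.

import torch
from torch.utils.data import DataLoader, TensorDataset
from nni.nas.evaluator.pytorch import Lightning, Trainer

from darts_classification_module import DartsClassificationModule

# Tiny CIFAR-10-shaped dummy data so the sketch is self-contained (assumption).
dummy_x = torch.randn(8, 3, 32, 32)
dummy_y = torch.randint(0, 10, (8,))
train_loader = DataLoader(TensorDataset(dummy_x, dummy_y), batch_size=4)
valid_loader = DataLoader(TensorDataset(dummy_x, dummy_y), batch_size=4)

max_epochs = 600  # should match the horizon of the cosine schedule and drop-path ramp

evaluator = Lightning(
    DartsClassificationModule(
        learning_rate=0.001,
        weight_decay=3e-4,          # illustrative value
        auxiliary_loss_weight=0.4,
        max_epochs=max_epochs,
    ),
    Trainer(
        gradient_clip_val=5.0,      # mirrors the notebook's Trainer settings
        max_epochs=max_epochs,
    ),
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)
# evaluator.fit(model)  # `model` would be a fixed architecture to retrain (assumed)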

code/data_generator.ipynb

Lines changed: 5 additions & 5 deletions
Large diffs are not rendered by default.

code/dataset.zip

36.5 KB
Binary file not shown.

code/dependecies.zip

1.24 KB
Binary file not shown.

code/train-best-models.ipynb

Lines changed: 4 additions & 101 deletions
@@ -36,7 +36,7 @@
 "# !pip install random --quiet --break-system-packages\n",
 "# !pip install wandb --quiet --break-system-packages\n",
 "# !pip install pytorch-lightning --quiet --break-system-packages\n",
-"!pip install torchmetrics --quiet --break-system-packages"
+"# !pip install torchmetrics --quiet --break-system-packages"
 ]
 },
 {
@@ -68,6 +68,8 @@
 "from nni.nas.evaluator.pytorch import ClassificationModule\n",
 "from nni.nas.evaluator.pytorch import Lightning, Trainer\n",
 "\n",
+"from darts_classification_module import DartsClassificationModule\n",
+"\n",
 "import matplotlib.pyplot as plt\n",
 "from IPython.display import clear_output\n",
 "\n",
@@ -206,94 +208,6 @@
 "arch_dicts = load_json_from_directory('../home/best_models_greed_cluster')"
 ]
 },
-{
-"cell_type": "code",
-"execution_count": 6,
-"metadata": {
-"execution": {
-"iopub.execute_input": "2025-05-09T15:55:45.281299Z",
-"iopub.status.busy": "2025-05-09T15:55:45.280801Z",
-"iopub.status.idle": "2025-05-09T15:55:45.290522Z",
-"shell.execute_reply": "2025-05-09T15:55:45.289841Z",
-"shell.execute_reply.started": "2025-05-09T15:55:45.281274Z"
-},
-"trusted": true
-},
-"outputs": [],
-"source": [
-"class DartsClassificationModule(ClassificationModule):\n",
-"    def __init__(\n",
-"        self,\n",
-"        learning_rate: float = 0.001,\n",
-"        weight_decay: float = 0.,\n",
-"        auxiliary_loss_weight: float = 0.4,\n",
-"        max_epochs: int = 600\n",
-"    ):\n",
-"        self.auxiliary_loss_weight = auxiliary_loss_weight\n",
-"        # Training length will be used in LR scheduler\n",
-"        self.max_epochs = max_epochs\n",
-"        super().__init__(learning_rate=learning_rate, weight_decay=weight_decay, export_onnx=False, num_classes=10)\n",
-"    \n",
-"    def configure_optimizers(self):\n",
-"        \"\"\"Customized optimizer with momentum, as well as a scheduler.\"\"\"\n",
-"        optimizer = torch.optim.SGD(\n",
-"            self.parameters(),\n",
-"            momentum=0.9,\n",
-"            lr=self.hparams.learning_rate,\n",
-"            weight_decay=self.hparams.weight_decay\n",
-"        )\n",
-"        # Cosine annealing scheduler with T_max equal to total epochs\n",
-"        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
-"            optimizer, T_max=self.max_epochs, eta_min=1e-3\n",
-"        )\n",
-"        return {\n",
-"            'optimizer': optimizer,\n",
-"            'lr_scheduler': {\n",
-"                'scheduler': scheduler,\n",
-"                'interval': 'epoch',\n",
-"                'frequency': 1\n",
-"            }\n",
-"        }\n",
-"\n",
-"    def training_step(self, batch, batch_idx):\n",
-"        \"\"\"Training step, customized with auxiliary loss and flexible unpacking.\"\"\"\n",
-"        x, y = batch\n",
-"        out = self(x)\n",
-"        # Handle auxiliary output if present\n",
-"        if self.auxiliary_loss_weight and isinstance(out, (tuple, list)) and len(out) == 2:\n",
-"            y_hat, y_aux = out\n",
-"            loss_main = self.criterion(y_hat, y)\n",
-"            loss_aux = self.criterion(y_aux, y)\n",
-"            self.log('train_loss_main', loss_main)\n",
-"            self.log('train_loss_aux', loss_aux)\n",
-"            loss = loss_main + self.auxiliary_loss_weight * loss_aux\n",
-"        else:\n",
-"            # single output or no auxiliary\n",
-"            y_hat = out[0] if isinstance(out, (tuple, list)) else out\n",
-"            loss = self.criterion(y_hat, y)\n",
-"        self.log('train_loss', loss, prog_bar=True)\n",
-"        for name, metric in self.metrics.items():\n",
-"            self.log('train_' + name, metric(y_hat, y), prog_bar=True)\n",
-"        return loss\n",
-"\n",
-"    def on_train_epoch_start(self):\n",
-"        # Handle DataParallel wrapper when adjusting drop path\n",
-"        model = self.trainer.model\n",
-"        if isinstance(model, DataParallel):\n",
-"            target_model = model.module\n",
-"        else:\n",
-"            target_model = model\n",
-"\n",
-"        # Set drop path probability before every epoch, scaled by epoch ratio\n",
-"        if hasattr(target_model, 'set_drop_path_prob') and hasattr(target_model, 'drop_path_prob'):\n",
-"            drop_prob = target_model.drop_path_prob * self.current_epoch / self.max_epochs\n",
-"            target_model.set_drop_path_prob(drop_prob)\n",
-"\n",
-"        # Logging learning rate at the beginning of every epoch\n",
-"        lr = self.trainer.optimizers[0].param_groups[0]['lr']\n",
-"        self.log('lr', lr)\n"
-]
-},
 {
 "cell_type": "code",
 "execution_count": 10,
@@ -351,7 +265,6 @@
 "            gradient_clip_val=5.0,\n",
 "            max_epochs=max_epochs,\n",
 "            fast_dev_run=fast_dev_run,\n",
-"            precision='16-mixed',\n",
 "            logger=WandbLogger(experiment=wandb.run)\n",
 "        ),\n",
 "        train_dataloaders=train_loader,\n",
@@ -363,15 +276,6 @@
 "    return model"
 ]
 },
-{
-"cell_type": "code",
-"execution_count": 11,
-"metadata": {},
-"outputs": [],
-"source": [
-"torch.set_float32_matmul_precision('medium')"
-]
-},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -624,8 +528,7 @@
 "        folder_name (str, optional): Name of the folder where results are saved. Defaults to \"results\".\n",
 "    Exceptions:\n",
 "        ValueError: If the number of models and architectures does not match.\n",
-"    Results:\n",
-"        For each model, a JSON file with the results is created, containing:\n",
+"    Results: For each model, a JSON file with the results is created, containing:\n",
 "        - architecture: The model architecture.\n",
 "        - valid_predictions: The model's predictions on the validation dataset.\n",
 "        - valid_accuracy: The model's accuracy on the validation dataset.\n",

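The docstring above describes one JSON results file per model. Purely as an illustration of that format (the helper name, file layout, and dummy values are assumptions, not code from this commit), writing such a file could look like:

import json
import os

def save_model_results(architecture, valid_predictions, valid_accuracy,
                        folder_name="results", model_name="model_0"):
    # Write one JSON file per model with the fields listed in the docstring:
    # architecture, valid_predictions, valid_accuracy.
    os.makedirs(folder_name, exist_ok=True)
    results = {
        "architecture": architecture,
        "valid_predictions": valid_predictions,
        "valid_accuracy": valid_accuracy,
    }
    path = os.path.join(folder_name, f"{model_name}.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    return path

# Example call with dummy values:
# save_model_results({"normal_cell": "..."}, [3, 1, 7], 0.91)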