Skip to content

Commit e302b6c

Browse files
committed
Add all files for the full surrogate-function pipeline. See the readme for usage.
1 parent 31e3ecd commit e302b6c

File tree

12 files changed

+656
-210
lines changed

12 files changed

+656
-210
lines changed

code/dependencies/train_config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,17 @@ class TrainConfig:
4242

4343
best_models_save_path: str
4444

45+
n_epochs_final: int
46+
lr_final: float
47+
batch_size_final: int
48+
dataset_name: str
49+
final_dataset_path: str
50+
output_path: str
51+
width: int
52+
num_cells: int
53+
num_workers: int
54+
n_ece_bins: int
55+
4556
seed: Optional[int] = None
4657

4758
# Internal fields

code/inference_surrogate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import gc
2020
from torch.utils.data import DataLoader
2121
from collections import deque
22+
import shutil
2223

2324
# Custom imports
2425
import sys
@@ -98,7 +99,7 @@ def architecture_search(self):
9899
dataset,
99100
batch_size=self.config.batch_size_inference,
100101
shuffle=False,
101-
num_workers=4,
102+
num_workers=self.config.num_workers,
102103
collate_fn=collate_graphs,
103104
)
104105

@@ -259,6 +260,7 @@ def select_central_models_by_clusters(self):
259260
plt.show()
260261

261262
def save_models(self):
263+
shutil.rmtree(self.config.best_models_save_path, ignore_errors=True)
262264
os.makedirs(self.config.best_models_save_path, exist_ok=True)
263265

264266
# Сохраняем архитектуры по одной

code/old_train_models.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import os
2+
import json
3+
import numpy as np
4+
import torch
5+
import nni
6+
from torch.utils.data import SubsetRandomSampler, SequentialSampler
7+
from torchvision import transforms
8+
from torchvision.datasets import CIFAR10, CIFAR100
9+
from nni.nas.evaluator.pytorch import DataLoader, Classification
10+
11+
from DartsSpace import DARTS_with_CIFAR100 as DartsSpace
12+
13+
14+
from nni.nas.space import model_context
15+
from tqdm import tqdm
16+
from IPython.display import clear_output
17+
from nni.nas.evaluator.pytorch import Lightning, Trainer
18+
19+
from dependecies.data_generator import generate_arch_dicts
20+
from dependecies.darts_classification_module import DartsClassificationModule
21+
22+
# Module-level configuration for the training script.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST = False

ARCHITECTURES_PATH = "/kaggle/input/second-dataset/dataset"
MAX_EPOCHS = 60
LEARNING_RATE = 0.025
BATCH_SIZE = 96
# NOTE: the name is misspelled ("MODLES") but is kept as-is because the
# script entry point references it under this spelling.
NUM_MODLES = 2000

DATASET = "CIFAR100"

# Per-channel normalization statistics passed to transforms.Normalize.
if DATASET == "CIFAR10":
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
elif DATASET == "CIFAR100":
    MEAN = [0.5071, 0.4867, 0.4408]
    STD = [0.2673, 0.2564, 0.2762]
else:
    # Fail fast: otherwise MEAN/STD stay undefined and the script only
    # crashes later with a NameError when the transforms are built.
    raise ValueError(f"Unsupported DATASET: {DATASET!r}")

# Seed NumPy and PyTorch (CPU and all GPUs) for reproducibility.
# NOTE: Python's built-in `random` module is not seeded here.
SEED = 228
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # seeds every GPU, if any are present
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
48+
49+
50+
def load_json_from_directory(directory_path):
    """
    Recursively load every ``.json`` file found under ``directory_path``.

    Files that cannot be parsed as JSON are reported on stdout and skipped.

    Args:
        directory_path: Root directory to walk.

    Returns:
        list: Parsed JSON objects, one per successfully decoded file.
    """
    loaded = []
    for current_dir, _, names in os.walk(directory_path):
        json_paths = (
            os.path.join(current_dir, name)
            for name in names
            if name.endswith('.json')
        )
        for json_path in json_paths:
            with open(json_path, 'r', encoding='utf-8') as handle:
                try:
                    loaded.append(json.load(handle))
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON from file {json_path}: {e}")
    return loaded
63+
64+
65+
def get_data_loaders(batch_size=512, num_workers=10):
    """
    Build the training and validation data loaders for the configured dataset.

    The training split of ``DATASET`` is permuted once and divided in half:
    the first half of the permuted indices feeds the training loader, the
    second half is held out for the validation loader.

    Args:
        batch_size (int): Batch size for both loaders. Defaults to 512.
        num_workers (int): Worker processes per loader. Defaults to 10.

    Returns:
        tuple: A pair of DataLoader objects:
            - search_train_loader: loader over the training half (random order).
            - search_valid_loader: loader over the held-out half (fixed order).

    Raises:
        ValueError: If ``DATASET`` is neither 'CIFAR10' nor 'CIFAR100'.
    """
    transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(MEAN, STD),
        ]
    )

    dataset_classes = {'CIFAR10': CIFAR10, 'CIFAR100': CIFAR100}
    if DATASET not in dataset_classes:
        raise ValueError(f"Unsupported DATASET: {DATASET!r}")
    train_data = nni.trace(dataset_classes[DATASET])(
        root="./data", train=True, download=True, transform=transform
    )

    num_samples = len(train_data)
    indices = np.random.permutation(num_samples)
    split = int(num_samples * 0.5)

    search_train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=SubsetRandomSampler(indices[:split]),
    )

    # SequentialSampler(x) yields range(len(x)), i.e. dataset positions
    # 0..len(x)-1, NOT the values stored in x -- using it here made validation
    # overlap the training subset. DataLoader accepts any iterable with
    # __len__ as a sampler, so pass the held-out indices directly; this also
    # keeps the evaluation order fixed.
    # NOTE(review): validation shares `train_data`, so the random crop/flip
    # above is also applied to validation images -- confirm this is intended.
    search_valid_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=indices[split:].tolist(),
    )

    return search_train_loader, search_valid_loader
112+
113+
114+
def train_model(
    architecture,
    train_loader,
    valid_loader,
    max_epochs=600,
    learning_rate=0.025,
    fast_dev_run=False
):
    """
    Instantiate a DARTS-space model for ``architecture`` and train it.

    Args:
        architecture: Architecture description passed to ``model_context``.
        train_loader: DataLoader used for training.
        valid_loader: Accepted for interface compatibility but currently
            unused -- no validation loader is passed to the evaluator.
        max_epochs (int): Number of epochs, passed to both the classification
            module and the trainer. Defaults to 600.
        learning_rate (float): Learning rate for the classification module.
        fast_dev_run (bool): Forwarded to the trainer's ``fast_dev_run`` flag.

    Returns:
        The model after ``evaluator.fit`` has run.

    Raises:
        ValueError: If ``DATASET`` is neither 'CIFAR10' nor 'CIFAR100'.
    """
    # Value of the `dataset` argument expected by DartsSpace per dataset.
    dataset_flags = {'CIFAR10': 'cifar', 'CIFAR100': 'cifar100'}
    if DATASET not in dataset_flags:
        # Previously `model` was left unbound and failed later with an
        # unrelated-looking error.
        raise ValueError(f"Unsupported DATASET: {DATASET!r}")

    # The model must be constructed inside model_context so that the given
    # architecture is applied to the search space.
    with model_context(architecture):
        model = DartsSpace(width=16, num_cells=10, dataset=dataset_flags[DATASET])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # NOTE(review): `devices=[0]` pins the trainer to device index 0 --
    # confirm this behaves as intended on CPU-only machines.
    evaluator = Lightning(
        DartsClassificationModule(
            learning_rate=learning_rate,
            weight_decay=3e-4,
            auxiliary_loss_weight=0.4,
            max_epochs=max_epochs
        ),
        trainer=Trainer(
            gradient_clip_val=5.0,
            max_epochs=max_epochs,
            fast_dev_run=fast_dev_run,
            devices=[0]
        ),
        train_dataloaders=train_loader
    )

    evaluator.fit(model)
    return model
152+
153+
154+
def evaluate_and_save_results(
    model,
    architecture,
    model_id,  # required: identifies the model in the output file name
    valid_loader,
    folder_name="results_seq_0"
):
    """
    Evaluate a model on the validation loader and save the results as JSON.

    The output file ``model_<model_id:04d>_results.json`` is written into
    ``folder_name`` and contains the architecture, the per-sample softmax
    outputs and the overall accuracy.

    Args:
        model: Trained model; called as ``model(images)``.
        architecture: Architecture description (must be JSON-serializable).
        model_id (int): Unique model identifier used in the file name.
        valid_loader: Iterable yielding ``(images, labels)`` batches.
        folder_name (str): Directory for the result file; created if missing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(folder_name, exist_ok=True)

    # Move the model to the device and switch to evaluation mode.
    model.to(device)
    model.eval()

    valid_correct = 0
    valid_total = 0
    valid_preds = []

    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = torch.softmax(model(images), dim=1)
            valid_preds.extend(outputs.cpu().tolist())
            _, predicted = torch.max(outputs, 1)
            valid_correct += (predicted == labels).sum().item()
            valid_total += labels.size(0)

    # An empty loader previously raised ZeroDivisionError; report 0.0 instead.
    valid_accuracy = valid_correct / valid_total if valid_total else 0.0

    result = {
        "architecture": architecture,
        "valid_predictions": valid_preds,
        "valid_accuracy": valid_accuracy,
    }

    # The file name is derived from model_id (zero-padded to 4 digits).
    file_name = f"model_{model_id:04d}_results.json"
    file_path = os.path.join(folder_name, file_name)

    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4)

    print(f"Results for model_{model_id} saved to {file_path}")
210+
211+
212+
if __name__ == "__main__":
    # Generate NUM_MODLES architecture records and keep only the
    # "architecture" entry of each.
    arch_dicts = [
        entry["architecture"] for entry in generate_arch_dicts(NUM_MODLES)
    ]

    # Build the loaders for the dataset selected by DATASET.
    search_train_loader, search_valid_loader = get_data_loaders(
        batch_size=BATCH_SIZE
    )

    # Derive the output folder from DATASET so results never land in a folder
    # named after a different dataset ("results_cifar100" for CIFAR100).
    results_folder = f"results_{DATASET.lower()}"

    for idx, architecture in enumerate(tqdm(arch_dicts)):
        model = train_model(
            architecture,
            search_train_loader,
            search_valid_loader,
            max_epochs=MAX_EPOCHS,
            learning_rate=LEARNING_RATE,
            fast_dev_run=False
        )
        clear_output(wait=True)

        # Save the architecture, its predictions on the held-out loader and
        # the resulting accuracy.
        evaluate_and_save_results(
            model,
            architecture,
            idx,
            valid_loader=search_valid_loader,
            folder_name=results_folder,
        )

code/output/ensemble_results.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Ensemble Top-1 Accuracy: 9.90%
2+
Ensemble ECE: 0.0125
3+
Number of models: 2
4+
Model 1 Accuracy: 10.08%
5+
Model 2 Accuracy: 9.90%

code/output/results.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Ensemble Top-1 Accuracy: 9.57%Ensemble ECE: 0.0078Model 1 Top-1 Accuracy: 9.83%Model 2 Top-1 Accuracy: 9.55%Model 3 Top-1 Accuracy: 9.89%Model 4 Top-1 Accuracy: 9.93%Model 5 Top-1 Accuracy: 10.04%

code/readme.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Чтобы запустить все этапы, выполните команду:
2+
3+
./start_all.sh
4+
5+
Гиперпараметры для настройки находятся в файле surrogate_hp.json

code/start_all.sh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/usr/bin/env bash
# Runs train_surrogate.py, inference_surrogate.py and train_models.py in order,
# passing the same hyperparameter JSON to each. Stops at the first failure.
set -euo pipefail

# Run from the script's own directory so the relative paths below resolve
# no matter where the script is invoked from.
cd "$(dirname "${BASH_SOURCE[0]}")"

CONFIG="surrogate_hp.json"

echo "=== Запуск train_surrogate.py ==="
python3 train_surrogate.py --hyperparameters_json "$CONFIG"

echo "=== Запуск inference_surrogate.py ==="
python3 inference_surrogate.py --hyperparameters_json "$CONFIG"

echo "=== Запуск train_models.py ==="
python3 train_models.py --hyperparameters_json "$CONFIG"

echo "=== Все этапы успешно завершены ==="

code/surrogate_hp.json

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
{
22
"seed":42,
3+
"num_workers": 4,
34
"dataset_path": "third_dataset/",
45
"device": "cpu",
5-
"developer_mode": false,
6+
"developer_mode": true,
67
"n_models": 1300,
78

89
"upper_margin": 0.75,
@@ -13,14 +14,14 @@
1314
"batch_size": 8,
1415
"input_dim":8,
1516

16-
"acc_num_epochs": 40,
17+
"acc_num_epochs": 10,
1718
"acc_lr": 1e-2,
1819
"acc_final_lr": 1e-5,
1920
"acc_dropout": 0.2,
2021
"acc_n_heads": 16,
2122
"draw_fig_acc": false,
2223

23-
"div_num_epochs": 25,
24+
"div_num_epochs": 5,
2425
"div_lr": 1e-3,
2526
"div_final_lr": 1e-6,
2627
"div_dropout": 0.1,
@@ -31,11 +32,21 @@
3132

3233
"surrogate_inference_path": "surrogate_models/",
3334

34-
"n_ensemble_models": 5,
35+
"n_ensemble_models": 2,
3536
"n_models_in_pool": 128,
36-
"n_models_to_generate": 5000,
37-
"batch_size_inference": 8192,
38-
"min_accuracy_for_pool": 0.83,
39-
"plot_tsne": true,
40-
"best_models_save_path": "best_models/"
37+
"n_models_to_generate": 4096,
38+
"batch_size_inference": 4096,
39+
"min_accuracy_for_pool": 0.85,
40+
"plot_tsne": false,
41+
"best_models_save_path": "best_models/",
42+
43+
"n_epochs_final": 1,
44+
"lr_final": 0.025,
45+
"batch_size_final": 96,
46+
"dataset_name": "CIFAR10",
47+
"final_dataset_path": "final_dataset/",
48+
"output_path": "output/",
49+
"width": 4,
50+
"num_cells": 3,
51+
"n_ece_bins": 15
4152
}
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)