Skip to content

Commit b55650c

Browse files
author
Vladimir Kurmanov
committed
seed fix
1 parent 22addd5 commit b55650c

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

code/train_models.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import torch
55
import nni
6-
from torch.utils.data import SubsetRandomSampler
6+
from torch.utils.data import SubsetRandomSampler, SequentialSampler
77
from torchvision import transforms
88
from torchvision.datasets import CIFAR10
99
from nni.nas.evaluator.pytorch import DataLoader, Classification
@@ -21,13 +21,21 @@
2121

2222

2323
ARCHITECTURES_PATH = "/kaggle/input/second-dataset/dataset"
24-
MAX_EPOCHS = 70
24+
MAX_EPOCHS = 1
2525
LEARNING_RATE = 0.025
2626
BATCH_SIZE = 96
2727
NUM_MODLES = 2000
2828
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
2929
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
3030

31+
SEED = 228
32+
# random.seed(SEED)
33+
np.random.seed(SEED)
34+
torch.manual_seed(SEED)
35+
torch.cuda.manual_seed_all(SEED) # если есть GPU
36+
torch.backends.cudnn.deterministic = True
37+
torch.backends.cudnn.benchmark = False
38+
3139

3240
def load_json_from_directory(directory_path):
3341
json_data = []
@@ -70,7 +78,7 @@ def get_data_loaders(batch_size=512):
7078
)
7179
num_samples = len(train_data)
7280
indices = np.random.permutation(num_samples)
73-
split = num_samples // 2
81+
split = int(num_samples * 0.75)
7482

7583
search_train_loader = DataLoader(
7684
train_data,
@@ -80,10 +88,10 @@ def get_data_loaders(batch_size=512):
8088
)
8189

8290
search_valid_loader = DataLoader(
83-
train_data,
91+
train_data,
8492
batch_size=batch_size,
8593
num_workers=6,
86-
sampler=SubsetRandomSampler(indices[split:]),
94+
sampler=SequentialSampler(indices[split:]),
8795
)
8896

8997
return search_train_loader, search_valid_loader
@@ -100,9 +108,9 @@ def train_model(
100108
with model_context(architecture):
101109
model = DartsSpace(width=16, num_cells=10, dataset='cifar')
102110

103-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
104-
if torch.cuda.device_count() > 1:
105-
model = torch.nn.DataParallel(model)
111+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
112+
#if torch.cuda.device_count() > 1:
113+
# model = torch.nn.DataParallel(model)
106114
model.to(device)
107115

108116
evaluator = Lightning(
@@ -115,7 +123,8 @@ def train_model(
115123
trainer=Trainer(
116124
gradient_clip_val=5.0,
117125
max_epochs=max_epochs,
118-
fast_dev_run=fast_dev_run
126+
fast_dev_run=fast_dev_run,
127+
devices=[0]
119128
),
120129
train_dataloaders=train_loader,
121130
val_dataloaders=valid_loader
@@ -130,7 +139,7 @@ def evaluate_and_save_results(
130139
architecture,
131140
model_id, # Новый обязательный параметр для идентификации модели
132141
valid_loader,
133-
folder_name="results"
142+
folder_name="results_seq_0"
134143
):
135144
"""
136145
Оценивает модель на валидационном наборе данных и сохраняет результаты в JSON.
@@ -154,6 +163,7 @@ def evaluate_and_save_results(
154163

155164
with torch.no_grad():
156165
for images, labels in valid_loader:
166+
print(labels)
157167
images, labels = images.to(device), labels.to(device)
158168
outputs = model(images)
159169
outputs = torch.softmax(outputs, dim=1)
@@ -201,5 +211,5 @@ def evaluate_and_save_results(
201211
clear_output(wait=True)
202212

203213
evaluate_and_save_results(
204-
model, architecture, idx, valid_loader=search_valid_loader, folder_name="results"
214+
model, architecture, idx, valid_loader=search_valid_loader, folder_name="results_seq_0"
205215
) # Оцениваем и сохраняем архитектуры, предсказания на тестовом наборе CIFAR10 и accuracy

0 commit comments

Comments
 (0)