Skip to content

Commit 7014a7c

Browse files
author
Vladimir Kurmanov
committed
correct
1 parent b55650c commit 7014a7c

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

code/train_models.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
ARCHITECTURES_PATH = "/kaggle/input/second-dataset/dataset"
24-
MAX_EPOCHS = 1
24+
MAX_EPOCHS = 70
2525
LEARNING_RATE = 0.025
2626
BATCH_SIZE = 96
2727
NUM_MODLES = 2000
@@ -78,19 +78,19 @@ def get_data_loaders(batch_size=512):
7878
)
7979
num_samples = len(train_data)
8080
indices = np.random.permutation(num_samples)
81-
split = int(num_samples * 0.75)
81+
split = int(num_samples * 0.5)
8282

8383
search_train_loader = DataLoader(
8484
train_data,
8585
batch_size=batch_size,
86-
num_workers=6,
86+
num_workers=10,
8787
sampler=SubsetRandomSampler(indices[:split]),
8888
)
8989

9090
search_valid_loader = DataLoader(
9191
train_data,
9292
batch_size=batch_size,
93-
num_workers=6,
93+
num_workers=10,
9494
sampler=SequentialSampler(indices[split:]),
9595
)
9696

@@ -106,7 +106,7 @@ def train_model(
106106
fast_dev_run=False
107107
):
108108
with model_context(architecture):
109-
model = DartsSpace(width=16, num_cells=10, dataset='cifar')
109+
model = DartsSpace(width=16, num_cells=3, dataset='cifar')
110110

111111
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
112112
#if torch.cuda.device_count() > 1:
@@ -163,7 +163,7 @@ def evaluate_and_save_results(
163163

164164
with torch.no_grad():
165165
for images, labels in valid_loader:
166-
print(labels)
166+
# print(labels)
167167
images, labels = images.to(device), labels.to(device)
168168
outputs = model(images)
169169
outputs = torch.softmax(outputs, dim=1)

0 commit comments

Comments
 (0)