Skip to content

Commit b657943

Browse files
author
Vladimir Kurmanov
committed
Dataset train for 10 days
1 parent 7014a7c commit b657943

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

code/train_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
ARCHITECTURES_PATH = "/kaggle/input/second-dataset/dataset"
24-
MAX_EPOCHS = 70
24+
MAX_EPOCHS = 60
2525
LEARNING_RATE = 0.025
2626
BATCH_SIZE = 96
2727
NUM_MODLES = 2000
@@ -78,7 +78,7 @@ def get_data_loaders(batch_size=512):
7878
)
7979
num_samples = len(train_data)
8080
indices = np.random.permutation(num_samples)
81-
split = int(num_samples * 0.5)
81+
split = int(num_samples * 0.75)
8282

8383
search_train_loader = DataLoader(
8484
train_data,
@@ -106,7 +106,7 @@ def train_model(
106106
fast_dev_run=False
107107
):
108108
with model_context(architecture):
109-
model = DartsSpace(width=16, num_cells=3, dataset='cifar')
109+
model = DartsSpace(width=16, num_cells=10, dataset='cifar')
110110

111111
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
112112
#if torch.cuda.device_count() > 1:

0 commit comments

Comments
 (0)