Skip to content

Commit f6eba92

Browse files
committed
update paper, size up number of epochs for accuracy surrogate and add preview information about ensemble accuracy on random architectures
1 parent 22addd5 commit f6eba92

11 files changed

+394
-323
lines changed

Udeneev2025Surrogate.pdf

67.2 KB
Binary file not shown.

code/GCN.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
GraphNorm,
1212
)
1313
from torch_geometric.utils import dense_to_sparse
14-
from torch.optim.lr_scheduler import StepLR
14+
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
1515
from tqdm.notebook import tqdm
1616
import matplotlib.pyplot as plt
1717
import numpy as np
@@ -347,10 +347,11 @@ def train_model_diversity(
347347
num_epochs,
348348
device="cpu",
349349
developer_mode=False,
350+
final_lr=0.001
350351
):
351352
model.to(device)
352353
train_losses, valid_losses = [], []
353-
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
354+
scheduler = CosineAnnealingLR(optimizer, num_epochs, eta_min=final_lr)
354355

355356
for epoch in tqdm(range(num_epochs), desc="Training Progress"):
356357
# --------------------
@@ -424,7 +425,8 @@ def train_model_diversity(
424425
except ImportError:
425426
pass
426427

427-
plt.figure(figsize=(10, 5))
428+
plt.figure(figsize=(12, 6))
429+
plt.rc('font', size=20)
428430
plt.plot(range(1, len(train_losses)+1), train_losses, marker='o', label='Train Loss')
429431
plt.plot(range(1, len(valid_losses)+1), valid_losses, marker='s', label='Valid Loss')
430432
plt.xlabel('Epoch')
@@ -449,12 +451,13 @@ def train_model_accuracy(
449451
num_epochs,
450452
device="cpu",
451453
developer_mode=False,
454+
final_lr=0.001,
452455
):
453456
model.to(device)
454457
train_losses = []
455458
valid_losses = []
456459

457-
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
460+
scheduler = CosineAnnealingLR(optimizer, num_epochs, eta_min=final_lr)
458461

459462
for epoch in tqdm(range(num_epochs), desc="Training Progress"):
460463
model.train()
@@ -511,6 +514,7 @@ def train_model_accuracy(
511514
pass
512515

513516
plt.figure(figsize=(12, 6))
517+
plt.rc('font', size=20)
514518
tmp_train_losses = np.array(train_losses)
515519
tmp_valid_losses = np.array(valid_losses)
516520
plt.plot(range(1, len(tmp_train_losses) + 1), tmp_train_losses * 1e4, marker="o", label="Train Loss")

code/best_models.zip

5.24 KB
Binary file not shown.

code/data_generator.ipynb

Lines changed: 60 additions & 33 deletions
Large diffs are not rendered by default.

code/dependecies.zip

39 Bytes
Binary file not shown.

code/gcn-training.ipynb

Lines changed: 293 additions & 286 deletions
Large diffs are not rendered by default.
21.3 KB
Loading

code/graphics/darts_cell.png

61 KB
Loading
25.2 KB
Loading

code/random_arch_ensemble_acc.txt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
========================================================================
2+
Ensemble Accuracy: 93.62%
3+
Model 1 Accuracy: 90.95%
4+
Model 2 Accuracy: 90.92%
5+
Model 3 Accuracy: 90.57%
6+
Model 4 Accuracy: 91.02%
7+
Model 5 Accuracy: 90.76%
8+
Model 6 Accuracy: 90.38%
9+
========================================================================
10+
Ensemble Accuracy: 93.80%
11+
Model 1 Accuracy: 91.24%
12+
Model 2 Accuracy: 91.25%
13+
Model 3 Accuracy: 91.36%
14+
Model 4 Accuracy: 91.89%
15+
Model 5 Accuracy: 91.64%
16+
Model 6 Accuracy: 90.58%
17+
========================================================================
18+
Ensemble Accuracy: 94.18%
19+
Model 1 Accuracy: 91.64%
20+
Model 2 Accuracy: 91.51%
21+
Model 3 Accuracy: 91.24%
22+
Model 4 Accuracy: 92.18%
23+
Model 5 Accuracy: 91.99%
24+
Model 6 Accuracy: 91.70%
25+
========================================================================

code/test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import math
2+
3+
def binom(n, m):
4+
return math.factorial(n) // (math.factorial(n - m) * math.factorial(m))
5+
6+
result = binom(2, 2) * binom(3, 2) * binom(4, 2) * binom(5, 2) * 7 ** 8
7+
8+
print(result // 1e9)

0 commit comments

Comments
 (0)