Skip to content

Commit b0a41e4

Browse files
author
Vladimir Kurmanov
committed
2 parents 97d1367 + a572bca commit b0a41e4

22 files changed

+790
-809
lines changed

Udeneev2025Surrogate.pdf

720 KB
Binary file not shown.

code/best_models.zip

36 Bytes
Binary file not shown.

code/data_generator.ipynb

Lines changed: 60 additions & 33 deletions
Large diffs are not rendered by default.

code/dependecies.zip

39 Bytes
Binary file not shown.

code/GCN.py renamed to code/dependecies/GCN.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
GraphNorm,
1212
)
1313
from torch_geometric.utils import dense_to_sparse
14-
from torch.optim.lr_scheduler import StepLR
14+
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
1515
from tqdm.notebook import tqdm
1616
import matplotlib.pyplot as plt
1717
import numpy as np
@@ -347,10 +347,11 @@ def train_model_diversity(
347347
num_epochs,
348348
device="cpu",
349349
developer_mode=False,
350+
final_lr=0.001
350351
):
351352
model.to(device)
352353
train_losses, valid_losses = [], []
353-
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
354+
scheduler = CosineAnnealingLR(optimizer, num_epochs, eta_min=final_lr)
354355

355356
for epoch in tqdm(range(num_epochs), desc="Training Progress"):
356357
# --------------------
@@ -424,7 +425,8 @@ def train_model_diversity(
424425
except ImportError:
425426
pass
426427

427-
plt.figure(figsize=(10, 5))
428+
plt.figure(figsize=(12, 6))
429+
plt.rc('font', size=20)
428430
plt.plot(range(1, len(train_losses)+1), train_losses, marker='o', label='Train Loss')
429431
plt.plot(range(1, len(valid_losses)+1), valid_losses, marker='s', label='Valid Loss')
430432
plt.xlabel('Epoch')
@@ -449,12 +451,13 @@ def train_model_accuracy(
449451
num_epochs,
450452
device="cpu",
451453
developer_mode=False,
454+
final_lr=0.001,
452455
):
453456
model.to(device)
454457
train_losses = []
455458
valid_losses = []
456459

457-
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
460+
scheduler = CosineAnnealingLR(optimizer, num_epochs, eta_min=final_lr)
458461

459462
for epoch in tqdm(range(num_epochs), desc="Training Progress"):
460463
model.train()
@@ -511,6 +514,7 @@ def train_model_accuracy(
511514
pass
512515

513516
plt.figure(figsize=(12, 6))
517+
plt.rc('font', size=20)
514518
tmp_train_losses = np.array(train_losses)
515519
tmp_valid_losses = np.array(valid_losses)
516520
plt.plot(range(1, len(tmp_train_losses) + 1), tmp_train_losses * 1e4, marker="o", label="Train Loss")
File renamed without changes.
File renamed without changes.

code/gcn-training.ipynb

Lines changed: 293 additions & 286 deletions
Large diffs are not rendered by default.
21.3 KB
Loading

code/graphics/clusters_GAT.png

666 KB
Loading

code/graphics/darts_cell.png

61 KB
Loading
25.2 KB
Loading

code/graphics/surrogate_arch.png

-31 Bytes
Loading

code/graphics/surrogate_arch.xml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
<mxfile host="app.diagrams.net" agent="Mozilla/5.0 (X11; Linux x86_64; rv:138.0) Gecko/20100101 Firefox/138.0" version="27.0.2">
1+
<mxfile host="app.diagrams.net" agent="Mozilla/5.0 (X11; Linux x86_64; rv:138.0) Gecko/20100101 Firefox/138.0" version="27.0.5">
22
<diagram name="GAT Architecture" id="gat-diagram">
3-
<mxGraphModel dx="1072" dy="577" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
3+
<mxGraphModel dx="1554" dy="845" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
44
<root>
55
<mxCell id="0" />
66
<mxCell id="1" parent="0" />
@@ -17,7 +17,7 @@
1717
<mxGeometry x="130" y="480" width="160" height="40" as="geometry" />
1818
</mxCell>
1919
<mxCell id="fc1" value="&lt;div&gt;Linear&lt;/div&gt;" style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#F5F5F5;" parent="1" vertex="1">
20-
<mxGeometry x="400" y="300" width="160" height="30" as="geometry" />
20+
<mxGeometry x="400" y="300" width="160" height="40" as="geometry" />
2121
</mxCell>
2222
<mxCell id="fc_norm" value="&lt;div&gt;Leaky ReLU&lt;/div&gt;&lt;div&gt;LayerNorm&lt;/div&gt;" style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#D5E8D4;" parent="1" vertex="1">
2323
<mxGeometry x="400" y="360" width="160" height="40" as="geometry" />
@@ -48,19 +48,19 @@
4848
<mxCell id="e13" style="edgeStyle=orthogonalEdgeStyle;rounded=0;" parent="1" source="fc2" target="out" edge="1">
4949
<mxGeometry relative="1" as="geometry" />
5050
</mxCell>
51-
<mxCell id="1mhNudxIzqTfyHtObyKU-1" value="&lt;div&gt;&lt;font style=&quot;font-size: 17px;&quot;&gt;&lt;/font&gt;&lt;/div&gt;" style="text;html=1;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" vertex="1" parent="1">
51+
<mxCell id="1mhNudxIzqTfyHtObyKU-1" value="&lt;div&gt;&lt;font style=&quot;font-size: 17px;&quot;&gt;&lt;/font&gt;&lt;/div&gt;" style="text;html=1;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
5252
<mxGeometry x="50" y="360" width="60" height="30" as="geometry" />
5353
</mxCell>
54-
<mxCell id="1mhNudxIzqTfyHtObyKU-2" value="&lt;div&gt;Output Layer&lt;/div&gt;" style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#DAE8FC;" vertex="1" parent="1">
54+
<mxCell id="1mhNudxIzqTfyHtObyKU-2" value="&lt;div&gt;Output Layer&lt;/div&gt;" style="shape=rectangle;rounded=1;whiteSpace=wrap;html=1;fillColor=#DAE8FC;" parent="1" vertex="1">
5555
<mxGeometry x="400" y="550" width="160" height="40" as="geometry" />
5656
</mxCell>
57-
<mxCell id="1mhNudxIzqTfyHtObyKU-3" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=1;exitDx=0;exitDy=0;" edge="1" parent="1" source="out" target="1mhNudxIzqTfyHtObyKU-2">
57+
<mxCell id="1mhNudxIzqTfyHtObyKU-3" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=1;exitDx=0;exitDy=0;" parent="1" source="out" target="1mhNudxIzqTfyHtObyKU-2" edge="1">
5858
<mxGeometry width="50" height="50" relative="1" as="geometry">
5959
<mxPoint x="340" y="470" as="sourcePoint" />
6060
<mxPoint x="390" y="420" as="targetPoint" />
6161
</mxGeometry>
6262
</mxCell>
63-
<mxCell id="1mhNudxIzqTfyHtObyKU-4" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="pool" target="fc1">
63+
<mxCell id="1mhNudxIzqTfyHtObyKU-4" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" parent="1" source="pool" target="fc1" edge="1">
6464
<mxGeometry width="50" height="50" relative="1" as="geometry">
6565
<mxPoint x="340" y="510" as="sourcePoint" />
6666
<mxPoint x="390" y="460" as="targetPoint" />
@@ -72,7 +72,7 @@
7272
</Array>
7373
</mxGeometry>
7474
</mxCell>
75-
<mxCell id="1mhNudxIzqTfyHtObyKU-8" value="" style="shape=curlyBracket;whiteSpace=wrap;html=1;rounded=1;labelPosition=left;verticalLabelPosition=middle;align=right;verticalAlign=middle;" vertex="1" parent="1">
75+
<mxCell id="1mhNudxIzqTfyHtObyKU-8" value="" style="shape=curlyBracket;whiteSpace=wrap;html=1;rounded=1;labelPosition=left;verticalLabelPosition=middle;align=right;verticalAlign=middle;" parent="1" vertex="1">
7676
<mxGeometry x="100" y="290" width="20" height="180" as="geometry" />
7777
</mxCell>
7878
</root>

0 commit comments

Comments
 (0)