Skip to content

Commit 31e3ecd

Browse files
committed
refactor some code
1 parent 6199cb4 commit 31e3ecd

File tree

7 files changed

+64
-105
lines changed

7 files changed

+64
-105
lines changed

code/dependencies/GCN.py

Lines changed: 56 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from joblib import Parallel, delayed
2020
from torch_geometric.data import Data, Batch
2121
from torch_geometric.loader import DataLoader
22+
from torch_geometric.utils import dropout_edge
2223

2324

2425
class SimpleGCN(nn.Module):
@@ -146,22 +147,20 @@ def __init__(self, input_dim, output_dim=16, dropout=0.5, pooling="max", heads=4
146147

147148
self.input_dim = input_dim
148149
self.output_dim = output_dim
149-
self.hidden_dim = 64 # базовая размерность скрытого слоя
150+
self.hidden_dim = 64
150151
self.heads = heads
151152

152-
# Attention Conv слои
153153
self.gat1 = GATv2Conv(input_dim, self.hidden_dim // heads, heads=heads)
154154
self.gat2 = GATv2Conv(self.hidden_dim, 256 // heads, heads=heads)
155155
self.gat3 = GATv2Conv(256, 256 // heads, heads=heads)
156156
self.gat4 = GATv2Conv(256, self.hidden_dim // heads, heads=heads)
157157

158-
# Проекции для full residual
158+
# residual projection
159159
self.res1 = nn.Linear(input_dim, self.hidden_dim)
160160
self.res2 = nn.Linear(self.hidden_dim, 256)
161161
self.res3 = nn.Linear(256, 256)
162162
self.res4 = nn.Linear(256, self.hidden_dim)
163163

164-
# Нормализация после каждого блока
165164
self.norm1 = GraphNorm(self.hidden_dim)
166165
self.norm2 = GraphNorm(256)
167166
self.norm3 = GraphNorm(256)
@@ -170,7 +169,6 @@ def __init__(self, input_dim, output_dim=16, dropout=0.5, pooling="max", heads=4
170169
self.dropout = nn.Dropout(dropout)
171170
self.pooling = pooling
172171

173-
# Полносвязная часть для графового эмбеддинга
174172
self.fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
175173
self.fc_norm = nn.LayerNorm(self.hidden_dim)
176174
self.fc2 = nn.Linear(self.hidden_dim, output_dim)
@@ -200,7 +198,7 @@ def forward(self, x, edge_index, batch=None):
200198
h4 = self.norm4(h4)
201199
h4 = self.dropout(h4)
202200

203-
# Глобальное агрегирование
201+
# Global pooling
204202
if self.pooling == "max":
205203
hg = global_max_pool(h4, batch)
206204
elif self.pooling == "mean":
@@ -210,7 +208,7 @@ def forward(self, x, edge_index, batch=None):
210208
else:
211209
raise ValueError("Unsupported pooling method. Use 'max', 'mean' or 'sum'.")
212210

213-
# Финальный MLP
211+
# Final MLP
214212
out = self.fc1(hg)
215213
out = self.fc_norm(out)
216214
out = F.leaky_relu(out)
@@ -223,7 +221,7 @@ def forward(self, x, edge_index, batch=None):
223221
class CustomDataset(Dataset):
224222
@staticmethod
225223
def preprocess(adj, features):
226-
"""Преобразует матрицу смежности и признаки в тензоры."""
224+
"""Transforms the adjacency matrix and features into tensors."""
227225
adj = torch.tensor(adj, dtype=torch.float)
228226
features = torch.tensor(features, dtype=torch.float)
229227
return adj, features
@@ -272,84 +270,84 @@ def __len__(self):
272270
class TripletGraphDataset(Dataset):
273271
def __init__(self, base_dataset, diversity_matrix):
274272
"""
275-
base_dataset: CustomDataset, отдающий Data с .index
276-
diversity_matrix: матрица [M, M], M >= N, значений {1, -1, 0}
273+
base_dataset: CustomDataset that yields Data objects carrying an .index attribute
274+
diversity_matrix: matrix of shape [M, M], M >= N, with values in {1, -1, 0}
277275
"""
278276
self.base = base_dataset
279277
self.div = diversity_matrix
280278
self.N = len(self.base)
281279

282-
# Строим отображение оригинального индекса -> внутренний
283-
# пример: если base_dataset[5].index == 42, то orig2int[42] = 5
280+
# Build a mapping from original index -> internal index
281+
# example: if base_dataset[5].index == 42, then orig2int[42] = 5
284282
self.orig2int = {self.base[i].index: i for i in range(self.N)}
285283

286284
def __len__(self):
287285
return self.N
288286

289287
def __getitem__(self, idx):
290-
# 1) Получаем Data и его оригинальный индекс
288+
# 1) Get Data and its original index
291289
anchor = self.base[idx]
292-
anchor_orig = anchor.index # в диапазоне [0, M-1]
290+
anchor_orig = anchor.index # in range [0, M-1]
293291

294-
# 2) Берём строку diversity_matrix по оригинальному индексу
295-
row = self.div[anchor_orig] # длина M
292+
# 2) Get the row of diversity_matrix by the original index
293+
row = self.div[anchor_orig] # length M
296294

297-
# 3) Находим оригинальные индексы положительных и отрицательных
295+
# 3) Find original indices of positive and negative examples
298296
pos_orig = np.where((row == 1) & (np.arange(len(row)) != anchor_orig))[0]
299297
neg_orig = np.where(row == -1)[0]
300298

301-
# 4) Фильтруем по наличию в self.orig2int
299+
# 4) Filter by presence in self.orig2int
302300
pos_orig = [i for i in pos_orig if i in self.orig2int]
303301
neg_orig = [i for i in neg_orig if i in self.orig2int]
304302

305-
# 5) Проверка наличия хотя бы одного положительного и отрицательного
303+
# 5) Check for at least one positive and negative example
306304
if len(pos_orig) == 0 or len(neg_orig) == 0:
307305
raise IndexError(f"No valid pos/neg for original index {anchor_orig}")
308306

309-
# 6) Случайно выбираем подходящие индексы
307+
# 6) Randomly select appropriate indices
310308
pos_o = int(np.random.choice(pos_orig))
311309
neg_o = int(np.random.choice(neg_orig))
312310

313-
# 7) Переводим в внутренние индексы и получаем Data
311+
# 7) Convert to internal indices and get Data
314312
pos_int = self.orig2int[pos_o]
315313
neg_int = self.orig2int[neg_o]
316314

317315
positive = self.base[pos_int]
318316
negative = self.base[neg_int]
319317

320-
# 8) Возвращаем три Data и тензор оригинальных индексов
318+
# 8) Return three Data and a tensor of original indices
321319
idx_triplet = torch.tensor([anchor_orig, pos_o, neg_o], dtype=torch.long)
322320
return anchor, positive, negative, idx_triplet
323321

324322

325323
def collate_triplets(batch):
326324
"""
327-
batch: list of tuples (anchor, pos, neg, idx_triplet)
328-
Возвращаем:
329-
- три Batched Data
330-
- один LongTensor [batch_size, 3] с исходными индексами
325+
batch: list of tuples (anchor, pos, neg, idx_triplet)
326+
Returns:
327+
- three batched Data objects
328+
- one LongTensor [batch_size, 3] with the original indexes
331329
"""
332330
anchors, positives, negatives, idxs = zip(*batch)
333331
batch_anchor = Batch.from_data_list(anchors)
334332
batch_positive = Batch.from_data_list(positives)
335333
batch_negative = Batch.from_data_list(negatives)
336-
# соберём матрицу индексов shape=(batch_size,3)
334+
# assemble the index matrix of shape (batch_size, 3)
337335
idx_tensor = torch.cat(idxs, dim=0).view(-1, 3)
338336
return batch_anchor, batch_positive, batch_negative, idx_tensor
339337

340338

341339
def collate_graphs(batch):
342340
"""
343341
batch: list of torch_geometric.data.Data
344-
Возвращает Batch, который можно подать в GNN.
342+
Returns Batch, which can be passed to GNN.
345343
"""
346344
return Batch.from_data_list(batch)
347345

348346

349347
def train_model_diversity(
350348
model,
351-
train_loader, # DataLoader выдаёт (anchor_batch, pos_batch, neg_batch)
352-
valid_loader, # То же для валидации
349+
train_loader, # DataLoader returns (anchor_batch, pos_batch, neg_batch, idx_triplet)
350+
valid_loader, # The same for validation
353351
optimizer,
354352
criterion,
355353
num_epochs,
@@ -364,7 +362,7 @@ def train_model_diversity(
364362

365363
for epoch in tqdm(range(num_epochs), desc="Training Progress"):
366364
# --------------------
367-
# 1) Тренировочный проход
365+
# 1) Training pass
368366
# --------------------
369367
model.train()
370368
running_loss = 0.0
@@ -377,19 +375,19 @@ def train_model_diversity(
377375
break
378376

379377
optimizer.zero_grad()
380-
# Переносим весь батч на device
378+
# Move the entire batch to device
381379
anchor_batch = anchor_batch.to(device)
382380
pos_batch = pos_batch.to(device)
383381
neg_batch = neg_batch.to(device)
384382

385-
# Прогоняем через модель
383+
# Feed through the model
386384
emb_anchor = model(
387385
anchor_batch.x, anchor_batch.edge_index, anchor_batch.batch
388386
)
389387
emb_pos = model(pos_batch.x, pos_batch.edge_index, pos_batch.batch)
390388
emb_neg = model(neg_batch.x, neg_batch.edge_index, neg_batch.batch)
391389

392-
# Считаем loss, backward, step
390+
# Calculate loss, backward, step
393391
loss = criterion(emb_anchor, emb_pos, emb_neg)
394392
loss.backward()
395393
optimizer.step()
@@ -402,7 +400,7 @@ def train_model_diversity(
402400
train_losses.append(avg_train_loss)
403401

404402
# --------------------
405-
# 2) Валидация
403+
# 2) Validation
406404
# --------------------
407405
model.eval()
408406
val_loss = 0.0
@@ -415,7 +413,7 @@ def train_model_diversity(
415413
if developer_mode and i > 0:
416414
break
417415

418-
# перенос на device
416+
# Move the entire batch to device
419417
anchor_batch = anchor_batch.to(device)
420418
pos_batch = pos_batch.to(device)
421419
neg_batch = neg_batch.to(device)
@@ -433,43 +431,14 @@ def train_model_diversity(
433431
avg_valid_loss = val_loss / max(1, n_val_batches)
434432
valid_losses.append(avg_valid_loss)
435433

436-
# --------------------
437-
# 3) Лог и визуализация
438-
# --------------------
439-
try:
440-
from IPython.display import clear_output
441-
442-
clear_output(wait=True)
443-
except ImportError:
444-
pass
445-
446434
lr = scheduler.get_last_lr()[0]
447435
print(
448436
f"Epoch {epoch+1}/{num_epochs} — "
449437
f"Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_valid_loss:.4f}, LR: {lr:.6f}"
450438
)
451439

452440
if draw_figure:
453-
plt.figure(figsize=(12, 6))
454-
plt.rc("font", size=20)
455-
plt.plot(
456-
range(1, len(train_losses) + 1),
457-
train_losses,
458-
marker="o",
459-
label="Train Loss",
460-
)
461-
plt.plot(
462-
range(1, len(valid_losses) + 1),
463-
valid_losses,
464-
marker="s",
465-
label="Valid Loss",
466-
)
467-
plt.xlabel("Epoch")
468-
plt.ylabel("Loss")
469-
plt.grid(True)
470-
plt.legend()
471-
plt.tight_layout()
472-
plt.show()
441+
plot_train_valid_losses(train_losses, valid_losses)
473442

474443
return train_losses, valid_losses
475444

@@ -495,7 +464,7 @@ def train_model_accuracy(
495464
for epoch in tqdm(range(num_epochs), desc="Training Progress"):
496465
model.train()
497466
train_loss = 0
498-
n_train_samples = 0 # изменено: считаем количество графов, а не батчей
467+
n_train_samples = 0
499468

500469
for i, data in enumerate(train_loader):
501470
if developer_mode and i > 0:
@@ -540,45 +509,37 @@ def train_model_accuracy(
540509
avg_valid_loss = valid_loss / max(1, n_val_samples)
541510
valid_losses.append(avg_valid_loss)
542511

543-
try:
544-
from IPython.display import clear_output
545-
546-
clear_output(wait=True)
547-
except:
548-
pass
549-
550512
lr = scheduler.get_last_lr()[0]
551513
print(
552514
f"Epoch {epoch+1}, Train Loss: {avg_train_loss * 1e4:.4f}, "
553515
f"Valid Loss: {avg_valid_loss * 1e4:.4f}, LR: {lr:.6f}"
554516
)
517+
555518
if draw_figure:
556-
plt.figure(figsize=(12, 6))
557-
plt.rc("font", size=20)
558519
tmp_train_losses = np.array(train_losses)
559-
tmp_valid_losses = np.array(valid_losses)
560-
plt.plot(
561-
range(1, len(tmp_train_losses) + 1),
562-
tmp_train_losses * 1e4,
563-
marker="o",
564-
label="Train Loss",
565-
)
566-
plt.plot(
567-
range(1, len(tmp_valid_losses) + 1),
568-
tmp_valid_losses * 1e4,
569-
marker="o",
570-
label="Valid Loss",
571-
)
572-
plt.xlabel("Epoch")
573-
plt.ylabel("Loss * 1e4")
574-
plt.ylim(0, 10)
575-
plt.grid(True)
576-
plt.legend()
577-
plt.show()
520+
tmp_valid_losses = np.array(valid_losses) * 1e4
521+
plot_train_valid_losses(tmp_train_losses, tmp_valid_losses)
578522

579523
return train_losses, valid_losses
580524

581525

526+
def plot_train_valid_losses(train_losses, valid_losses):
527+
plt.figure(figsize=(12, 6))
528+
plt.rc("font", size=20)
529+
plt.plot(
530+
range(1, len(train_losses) + 1), train_losses, marker="o", label="Train Loss"
531+
)
532+
plt.plot(
533+
range(1, len(valid_losses) + 1), valid_losses, marker="s", label="Valid Loss"
534+
)
535+
plt.xlabel("Epoch")
536+
plt.ylabel("Loss")
537+
plt.grid(True)
538+
plt.legend()
539+
plt.tight_layout()
540+
plt.show()
541+
542+
582543
def get_positive_and_negative(diversity_matrix, indices, dataset=None):
583544
positive_indices = []
584545
negative_indices = []

code/dependencies/Graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
'dil_conv_5x5',
1717
]
1818

19-
encoder = OneHotEncoder(handle_unknown='ignore')
2019
encoder = OneHotEncoder(handle_unknown='ignore')
2120
ops_array = np.array(DARTS_OPS).reshape(-1, 1)
2221

code/greedy-finding-best-models.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@
162162
},
163163
{
164164
"cell_type": "code",
165-
"execution_count": 5,
165+
"execution_count": null,
166166
"metadata": {
167167
"execution": {
168168
"iopub.execute_input": "2025-05-21T16:52:59.345061Z",

code/surrogate_hp.json

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,27 @@
22
"seed":42,
33
"dataset_path": "third_dataset/",
44
"device": "cpu",
5-
"developer_mode": true,
5+
"developer_mode": false,
66
"n_models": 1300,
77

8-
"upper_margin": 0.9,
9-
"lower_margin": 0.1,
8+
"upper_margin": 0.75,
9+
"lower_margin": 0.25,
1010
"diversity_matrix_metric": "overlap",
1111

1212
"train_size": 0.8,
1313
"batch_size": 8,
1414
"input_dim":8,
1515

16-
"acc_num_epochs": 20,
16+
"acc_num_epochs": 40,
1717
"acc_lr": 1e-2,
1818
"acc_final_lr": 1e-5,
1919
"acc_dropout": 0.2,
2020
"acc_n_heads": 16,
2121
"draw_fig_acc": false,
2222

23-
"div_num_epochs": 15,
23+
"div_num_epochs": 25,
2424
"div_lr": 1e-3,
25-
"div_final_lr": 1e-5,
25+
"div_final_lr": 1e-6,
2626
"div_dropout": 0.1,
2727
"div_n_heads": 4,
2828
"margin": 1,
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)