from joblib import Parallel, delayed
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
+from torch_geometric.utils import dropout_edge


class SimpleGCN(nn.Module):
@@ -146,22 +147,20 @@ def __init__(self, input_dim, output_dim=16, dropout=0.5, pooling="max", heads=4
        self.input_dim = input_dim
        self.output_dim = output_dim
-       self.hidden_dim = 64  # base hidden-layer dimensionality
+       self.hidden_dim = 64
        self.heads = heads

-       # Attention Conv layers
        self.gat1 = GATv2Conv(input_dim, self.hidden_dim // heads, heads=heads)
        self.gat2 = GATv2Conv(self.hidden_dim, 256 // heads, heads=heads)
        self.gat3 = GATv2Conv(256, 256 // heads, heads=heads)
        self.gat4 = GATv2Conv(256, self.hidden_dim // heads, heads=heads)

-       # Projections for full residual connections
+       # residual projections
        self.res1 = nn.Linear(input_dim, self.hidden_dim)
        self.res2 = nn.Linear(self.hidden_dim, 256)
        self.res3 = nn.Linear(256, 256)
        self.res4 = nn.Linear(256, self.hidden_dim)

-       # Normalization after each block
        self.norm1 = GraphNorm(self.hidden_dim)
        self.norm2 = GraphNorm(256)
        self.norm3 = GraphNorm(256)
@@ -170,7 +169,6 @@ def __init__(self, input_dim, output_dim=16, dropout=0.5, pooling="max", heads=4
        self.dropout = nn.Dropout(dropout)
        self.pooling = pooling

-       # Fully connected head for the graph embedding
        self.fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc_norm = nn.LayerNorm(self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, output_dim)
@@ -200,7 +198,7 @@ def forward(self, x, edge_index, batch=None):
        h4 = self.norm4(h4)
        h4 = self.dropout(h4)

-       # Global aggregation
+       # Global pooling
        if self.pooling == "max":
            hg = global_max_pool(h4, batch)
        elif self.pooling == "mean":
@@ -210,7 +208,7 @@ def forward(self, x, edge_index, batch=None):
        else:
            raise ValueError("Unsupported pooling method. Use 'max', 'mean' or 'sum'.")

-       # Final MLP
+       # Final MLP
        out = self.fc1(hg)
        out = self.fc_norm(out)
        out = F.leaky_relu(out)
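
# A minimal forward-pass sketch (assumes `import torch` and the definitions above;
# the toy graph is an illustrative assumption, not part of the commit):
#
#   model = SimpleGCN(input_dim=8, output_dim=16, dropout=0.5, pooling="max", heads=4)
#   x = torch.randn(5, 8)                                    # 5 nodes, 8 features each
#   edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # 4 directed edges
#   batch = torch.zeros(5, dtype=torch.long)                 # all nodes in graph 0
#   emb = model(x, edge_index, batch)                        # -> shape [1, 16]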
@@ -223,7 +221,7 @@ def forward(self, x, edge_index, batch=None):
class CustomDataset(Dataset):
    @staticmethod
    def preprocess(adj, features):
-       """Converts the adjacency matrix and features into tensors."""
+       """Converts the adjacency matrix and features into tensors."""
        adj = torch.tensor(adj, dtype=torch.float)
        features = torch.tensor(features, dtype=torch.float)
        return adj, features
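
# A usage sketch for preprocess (assumes `import numpy as np`; shapes are illustrative):
#
#   adj, feats = CustomDataset.preprocess(np.eye(4), np.random.rand(4, 8))
#   adj.shape, feats.dtype  # torch.Size([4, 4]), torch.float32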
@@ -272,84 +270,84 @@ def __len__(self):
class TripletGraphDataset(Dataset):
    def __init__(self, base_dataset, diversity_matrix):
        """
-       base_dataset: CustomDataset yielding Data objects that carry an .index attribute
-       diversity_matrix: an [M, M] matrix, M >= N, with values in {1, -1, 0}
+       base_dataset: CustomDataset yielding Data objects that carry an .index attribute
+       diversity_matrix: an [M, M] matrix, M >= N, with values in {1, -1, 0}
        """
        self.base = base_dataset
        self.div = diversity_matrix
        self.N = len(self.base)

-       # Build a mapping from original index -> internal index
-       # e.g., if base_dataset[5].index == 42, then orig2int[42] = 5
+       # Build a mapping from original index -> internal index
+       # e.g., if base_dataset[5].index == 42, then orig2int[42] = 5
        self.orig2int = {self.base[i].index: i for i in range(self.N)}
    def __len__(self):
        return self.N

    def __getitem__(self, idx):
-       # 1) Get the Data object and its original index
+       # 1) Get Data and its original index
        anchor = self.base[idx]
-       anchor_orig = anchor.index  # in the range [0, M-1]
+       anchor_orig = anchor.index  # in range [0, M-1]

-       # 2) Take the diversity_matrix row for the original index
-       row = self.div[anchor_orig]  # length M
+       # 2) Get the row of diversity_matrix by the original index
+       row = self.div[anchor_orig]  # length M

-       # 3) Find the original indices of positives and negatives
+       # 3) Find original indices of positive and negative examples
        pos_orig = np.where((row == 1) & (np.arange(len(row)) != anchor_orig))[0]
        neg_orig = np.where(row == -1)[0]

-       # 4) Filter by presence in self.orig2int
+       # 4) Filter by presence in self.orig2int
        pos_orig = [i for i in pos_orig if i in self.orig2int]
        neg_orig = [i for i in neg_orig if i in self.orig2int]

-       # 5) Check that at least one positive and one negative exist
+       # 5) Check for at least one positive and negative example
        if len(pos_orig) == 0 or len(neg_orig) == 0:
            raise IndexError(f"No valid pos/neg for original index {anchor_orig}")

-       # 6) Randomly pick suitable indices
+       # 6) Randomly select appropriate indices
        pos_o = int(np.random.choice(pos_orig))
        neg_o = int(np.random.choice(neg_orig))

-       # 7) Convert to internal indices and fetch the Data objects
+       # 7) Convert to internal indices and get Data
        pos_int = self.orig2int[pos_o]
        neg_int = self.orig2int[neg_o]

        positive = self.base[pos_int]
        negative = self.base[neg_int]

-       # 8) Return the three Data objects and a tensor of original indices
+       # 8) Return three Data and a tensor of original indices
        idx_triplet = torch.tensor([anchor_orig, pos_o, neg_o], dtype=torch.long)
        return anchor, positive, negative, idx_triplet
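
# A sampling sketch (assumes `import numpy as np` and a populated CustomDataset;
# the 3x3 diversity matrix is an illustrative assumption):
#
#   div = np.array([[ 0,  1, -1],
#                   [ 1,  0, -1],
#                   [-1, -1,  0]])
#   triplet_ds = TripletGraphDataset(base_dataset, div)
#   anchor, positive, negative, idx_triplet = triplet_ds[0]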
def collate_triplets(batch):
    """
-   batch: list of tuples (anchor, pos, neg, idx_triplet)
-   Returns:
-     - three batched Data objects
-     - one LongTensor [batch_size, 3] with the original indices
+   batch: list of tuples (anchor, pos, neg, idx_triplet)
+   Returns:
+     - three batched Data objects
+     - one LongTensor [batch_size, 3] with the original indices
    """
    anchors, positives, negatives, idxs = zip(*batch)
    batch_anchor = Batch.from_data_list(anchors)
    batch_positive = Batch.from_data_list(positives)
    batch_negative = Batch.from_data_list(negatives)
-   # assemble the index matrix, shape=(batch_size, 3)
+   # assemble the index matrix, shape=(batch_size, 3)
    idx_tensor = torch.cat(idxs, dim=0).view(-1, 3)
    return batch_anchor, batch_positive, batch_negative, idx_tensor
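
# A wiring sketch (assumes torch.utils.data.DataLoader; the batch size is illustrative):
#
#   loader = torch.utils.data.DataLoader(
#       triplet_ds, batch_size=32, shuffle=True, collate_fn=collate_triplets
#   )
#   anchor_b, pos_b, neg_b, idx_b = next(iter(loader))  # idx_b.shape == (32, 3)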


def collate_graphs(batch):
    """
    batch: list of torch_geometric.data.Data
-   Returns a Batch that can be fed to a GNN.
+   Returns a Batch that can be fed to a GNN.
    """
    return Batch.from_data_list(batch)
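
# For plain (non-triplet) batching the same pattern applies (batch size illustrative):
#
#   graph_loader = torch.utils.data.DataLoader(
#       base_dataset, batch_size=64, shuffle=False, collate_fn=collate_graphs
#   )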


def train_model_diversity(
    model,
-   train_loader,  # DataLoader yields (anchor_batch, pos_batch, neg_batch)
-   valid_loader,  # the same for validation
+   train_loader,  # DataLoader yields (anchor_batch, pos_batch, neg_batch, idx_triplet)
+   valid_loader,  # the same for validation
    optimizer,
    criterion,
    num_epochs,
@@ -364,7 +362,7 @@ def train_model_diversity(
    for epoch in tqdm(range(num_epochs), desc="Training Progress"):
        # --------------------
-       # 1) Training pass
+       # 1) Training pass
        # --------------------
        model.train()
        running_loss = 0.0
@@ -377,19 +375,19 @@ def train_model_diversity(
                break

            optimizer.zero_grad()
-           # Move the whole batch to the device
+           # Move the entire batch to device
            anchor_batch = anchor_batch.to(device)
            pos_batch = pos_batch.to(device)
            neg_batch = neg_batch.to(device)

-           # Run the batches through the model
+           # Feed through the model
            emb_anchor = model(
                anchor_batch.x, anchor_batch.edge_index, anchor_batch.batch
            )
            emb_pos = model(pos_batch.x, pos_batch.edge_index, pos_batch.batch)
            emb_neg = model(neg_batch.x, neg_batch.edge_index, neg_batch.batch)

-           # Compute loss, backward, step
+           # Compute loss, backward, step
            loss = criterion(emb_anchor, emb_pos, emb_neg)
            loss.backward()
            optimizer.step()
@@ -402,7 +400,7 @@ def train_model_diversity(
        train_losses.append(avg_train_loss)

        # --------------------
-       # 2) Validation
+       # 2) Validation
        # --------------------
        model.eval()
        val_loss = 0.0
@@ -415,7 +413,7 @@ def train_model_diversity(
                if developer_mode and i > 0:
                    break

-               # move to the device
+               # Move the entire batch to device
                anchor_batch = anchor_batch.to(device)
                pos_batch = pos_batch.to(device)
                neg_batch = neg_batch.to(device)
@@ -433,43 +431,14 @@ def train_model_diversity(
        avg_valid_loss = val_loss / max(1, n_val_batches)
        valid_losses.append(avg_valid_loss)

-       # --------------------
-       # 3) Logging and visualization
-       # --------------------
-       try:
-           from IPython.display import clear_output
-
-           clear_output(wait=True)
-       except ImportError:
-           pass
-
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch + 1}/{num_epochs} — "
            f"Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_valid_loss:.4f}, LR: {lr:.6f}"
        )

        if draw_figure:
-           plt.figure(figsize=(12, 6))
-           plt.rc("font", size=20)
-           plt.plot(
-               range(1, len(train_losses) + 1),
-               train_losses,
-               marker="o",
-               label="Train Loss",
-           )
-           plt.plot(
-               range(1, len(valid_losses) + 1),
-               valid_losses,
-               marker="s",
-               label="Valid Loss",
-           )
-           plt.xlabel("Epoch")
-           plt.ylabel("Loss")
-           plt.grid(True)
-           plt.legend()
-           plt.tight_layout()
-           plt.show()
+           plot_train_valid_losses(train_losses, valid_losses)

    return train_losses, valid_losses
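
# A call sketch (hedged: the trailing parameters of the signature are elided in
# this diff; TripletMarginLoss matches the (anchor, positive, negative) call above):
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   criterion = torch.nn.TripletMarginLoss(margin=1.0)
#   train_losses, valid_losses = train_model_diversity(
#       model, train_loader, valid_loader, optimizer, criterion, num_epochs=10,
#   )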
@@ -495,7 +464,7 @@ def train_model_accuracy(
    for epoch in tqdm(range(num_epochs), desc="Training Progress"):
        model.train()
        train_loss = 0
-       n_train_samples = 0  # changed: count graphs rather than batches
+       n_train_samples = 0

        for i, data in enumerate(train_loader):
            if developer_mode and i > 0:
@@ -540,45 +509,37 @@ def train_model_accuracy(
        avg_valid_loss = valid_loss / max(1, n_val_samples)
        valid_losses.append(avg_valid_loss)

-       try:
-           from IPython.display import clear_output
-
-           clear_output(wait=True)
-       except:
-           pass
-
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch + 1}, Train Loss: {avg_train_loss * 1e4:.4f}, "
            f"Valid Loss: {avg_valid_loss * 1e4:.4f}, LR: {lr:.6f}"
        )
+
        if draw_figure:
-           plt.figure(figsize=(12, 6))
-           plt.rc("font", size=20)
-           tmp_train_losses = np.array(train_losses)
-           tmp_valid_losses = np.array(valid_losses)
-           plt.plot(
-               range(1, len(tmp_train_losses) + 1),
-               tmp_train_losses * 1e4,
-               marker="o",
-               label="Train Loss",
-           )
-           plt.plot(
-               range(1, len(tmp_valid_losses) + 1),
-               tmp_valid_losses * 1e4,
-               marker="o",
-               label="Valid Loss",
-           )
-           plt.xlabel("Epoch")
-           plt.ylabel("Loss * 1e4")
-           plt.ylim(0, 10)
-           plt.grid(True)
-           plt.legend()
-           plt.show()
+           # scale both curves by 1e4 so the plot matches the printed losses
+           tmp_train_losses = np.array(train_losses) * 1e4
+           tmp_valid_losses = np.array(valid_losses) * 1e4
+           plot_train_valid_losses(tmp_train_losses, tmp_valid_losses)

    return train_losses, valid_losses
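
# A call sketch (hedged: the full signature is elided in this diff, so the
# argument order mirrors train_model_diversity and MSELoss is an assumption
# consistent with the regression-style losses printed above):
#
#   train_losses, valid_losses = train_model_accuracy(
#       model, graph_loader, valid_graph_loader, optimizer,
#       torch.nn.MSELoss(), num_epochs=10,
#   )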


+def plot_train_valid_losses(train_losses, valid_losses):
+    plt.figure(figsize=(12, 6))
+    plt.rc("font", size=20)
+    plt.plot(
+        range(1, len(train_losses) + 1), train_losses, marker="o", label="Train Loss"
+    )
+    plt.plot(
+        range(1, len(valid_losses) + 1), valid_losses, marker="s", label="Valid Loss"
+    )
+    plt.xlabel("Epoch")
+    plt.ylabel("Loss")
+    plt.grid(True)
+    plt.legend()
+    plt.tight_layout()
+    plt.show()
+
+
def get_positive_and_negative(diversity_matrix, indices, dataset=None):
    positive_indices = []
    negative_indices = []