@@ -89,7 +89,7 @@ def __init__(self, input_dim, output_dim=16, dropout=0.5, pooling="max"):
89
89
self .residual_proj = (
90
90
nn .Linear (input_dim , self .hidden_dim ) if input_dim != self .hidden_dim else nn .Identity ()
91
91
)
92
-
92
+
93
93
self .layer_norm = nn .LayerNorm (self .hidden_dim )
94
94
self .dropout = nn .Dropout (dropout )
95
95
self .pooling = pooling
@@ -121,7 +121,7 @@ def forward(self, x, edge_index, batch=None):
121
121
elif self .pooling == "mean" :
122
122
x = global_mean_pool (x , batch )
123
123
elif self .pooling == "sum" :
124
- x = torch . sum (x , dim = 0 )
124
+ x = global_add_pool (x , batch )
125
125
else :
126
126
raise ValueError ("Unsupported pooling method. Use 'max', 'mean' or 'sum'." )
127
127
@@ -134,6 +134,7 @@ def forward(self, x, edge_index, batch=None):
134
134
x = torch .sigmoid (x )
135
135
return x
136
136
137
+
137
138
class SimpleGAT (nn .Module ):
138
139
def __init__ (
139
140
self ,
@@ -443,7 +444,7 @@ def train_model_accuracy(
443
444
for epoch in tqdm (range (num_epochs ), desc = "Training Progress" ):
444
445
model .train ()
445
446
train_loss = 0
446
- n_train_batches = 0
447
+ n_train_samples = 0 # изменено: считаем количество графов, а не батчей
447
448
448
449
for i , data in enumerate (train_loader ):
449
450
if developer_mode and i > 0 :
@@ -452,25 +453,24 @@ def train_model_accuracy(
452
453
data = data .to (device )
453
454
optimizer .zero_grad ()
454
455
455
- # Явно передаем `edge_index` и `x`
456
- prediction = model (data .x , data .edge_index , data .batch ).squeeze ()
456
+ prediction = model (data .x , data .edge_index , data .batch ).squeeze ()
457
457
target = data .y .float ()
458
458
459
459
loss = criterion (prediction , target )
460
460
loss .backward ()
461
461
optimizer .step ()
462
462
463
- train_loss += loss .item ()
464
- n_train_batches += 1
463
+ train_loss += loss .item () * data . num_graphs # весим loss по числу графов
464
+ n_train_samples += data . num_graphs
465
465
466
466
scheduler .step ()
467
- avg_train_loss = train_loss / max (1 , n_train_batches )
467
+ avg_train_loss = train_loss / max (1 , n_train_samples )
468
468
train_losses .append (avg_train_loss )
469
469
470
470
# Validation
471
471
model .eval ()
472
472
valid_loss = 0
473
- n_val_batches = 0
473
+ n_val_samples = 0
474
474
475
475
with torch .no_grad ():
476
476
for i , data in enumerate (valid_loader ):
@@ -483,10 +483,10 @@ def train_model_accuracy(
483
483
target = data .y .float ()
484
484
485
485
loss = criterion (prediction , target )
486
- valid_loss += loss .item ()
487
- n_val_batches += 1
486
+ valid_loss += loss .item () * data . num_graphs
487
+ n_val_samples += data . num_graphs
488
488
489
- avg_valid_loss = valid_loss / max (1 , n_val_batches )
489
+ avg_valid_loss = valid_loss / max (1 , n_val_samples )
490
490
valid_losses .append (avg_valid_loss )
491
491
492
492
try :
@@ -496,11 +496,13 @@ def train_model_accuracy(
496
496
pass
497
497
498
498
plt .figure (figsize = (12 , 6 ))
499
- plt .plot (range (1 , len (train_losses ) + 1 ), train_losses , marker = "o" , label = "Train Loss" )
500
- plt .plot (range (1 , len (valid_losses ) + 1 ), valid_losses , marker = "o" , label = "Valid Loss" )
499
+ tmp_train_losses = np .array (train_losses )
500
+ tmp_valid_losses = np .array (valid_losses )
501
+ plt .plot (range (1 , len (tmp_train_losses ) + 1 ), tmp_train_losses * 1e4 , marker = "o" , label = "Train Loss" )
502
+ plt .plot (range (1 , len (tmp_valid_losses ) + 1 ), tmp_valid_losses * 1e4 , marker = "o" , label = "Valid Loss" )
501
503
plt .xlabel ("Epoch" )
502
- plt .ylabel ("Loss" )
503
- # plt.ylim(0, 0.002 )
504
+ plt .ylabel ("Loss * 1e4 " )
505
+ plt .ylim (0 , 10 )
504
506
plt .grid (True )
505
507
plt .legend ()
506
508
plt .show ()
0 commit comments