Skip to content

Commit 91a3316

Browse files
committed
fix bugs and add weights
1 parent b4204be commit 91a3316

File tree

5 files changed

+155
-255
lines changed

5 files changed

+155
-255
lines changed

code/GCN.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(self, input_dim, output_dim=16, dropout=0.5, pooling="max"):
8989
self.residual_proj = (
9090
nn.Linear(input_dim, self.hidden_dim) if input_dim != self.hidden_dim else nn.Identity()
9191
)
92-
92+
9393
self.layer_norm = nn.LayerNorm(self.hidden_dim)
9494
self.dropout = nn.Dropout(dropout)
9595
self.pooling = pooling
@@ -121,7 +121,7 @@ def forward(self, x, edge_index, batch=None):
121121
elif self.pooling == "mean":
122122
x = global_mean_pool(x, batch)
123123
elif self.pooling == "sum":
124-
x = torch.sum(x, dim=0)
124+
x = global_add_pool(x, batch)
125125
else:
126126
raise ValueError("Unsupported pooling method. Use 'max', 'mean' or 'sum'.")
127127

@@ -134,6 +134,7 @@ def forward(self, x, edge_index, batch=None):
134134
x = torch.sigmoid(x)
135135
return x
136136

137+
137138
class SimpleGAT(nn.Module):
138139
def __init__(
139140
self,
@@ -443,7 +444,7 @@ def train_model_accuracy(
443444
for epoch in tqdm(range(num_epochs), desc="Training Progress"):
444445
model.train()
445446
train_loss = 0
446-
n_train_batches = 0
447+
n_train_samples = 0 # изменено: считаем количество графов, а не батчей
447448

448449
for i, data in enumerate(train_loader):
449450
if developer_mode and i > 0:
@@ -452,25 +453,24 @@ def train_model_accuracy(
452453
data = data.to(device)
453454
optimizer.zero_grad()
454455

455-
# Явно передаем `edge_index` и `x`
456-
prediction = model(data.x, data.edge_index, data.batch).squeeze()
456+
prediction = model(data.x, data.edge_index, data.batch).squeeze()
457457
target = data.y.float()
458458

459459
loss = criterion(prediction, target)
460460
loss.backward()
461461
optimizer.step()
462462

463-
train_loss += loss.item()
464-
n_train_batches += 1
463+
train_loss += loss.item() * data.num_graphs # весим loss по числу графов
464+
n_train_samples += data.num_graphs
465465

466466
scheduler.step()
467-
avg_train_loss = train_loss / max(1, n_train_batches)
467+
avg_train_loss = train_loss / max(1, n_train_samples)
468468
train_losses.append(avg_train_loss)
469469

470470
# Validation
471471
model.eval()
472472
valid_loss = 0
473-
n_val_batches = 0
473+
n_val_samples = 0
474474

475475
with torch.no_grad():
476476
for i, data in enumerate(valid_loader):
@@ -483,10 +483,10 @@ def train_model_accuracy(
483483
target = data.y.float()
484484

485485
loss = criterion(prediction, target)
486-
valid_loss += loss.item()
487-
n_val_batches += 1
486+
valid_loss += loss.item() * data.num_graphs
487+
n_val_samples += data.num_graphs
488488

489-
avg_valid_loss = valid_loss / max(1, n_val_batches)
489+
avg_valid_loss = valid_loss / max(1, n_val_samples)
490490
valid_losses.append(avg_valid_loss)
491491

492492
try:
@@ -496,11 +496,13 @@ def train_model_accuracy(
496496
pass
497497

498498
plt.figure(figsize=(12, 6))
499-
plt.plot(range(1, len(train_losses) + 1), train_losses, marker="o", label="Train Loss")
500-
plt.plot(range(1, len(valid_losses) + 1), valid_losses, marker="o", label="Valid Loss")
499+
tmp_train_losses = np.array(train_losses)
500+
tmp_valid_losses = np.array(valid_losses)
501+
plt.plot(range(1, len(tmp_train_losses) + 1), tmp_train_losses * 1e4, marker="o", label="Train Loss")
502+
plt.plot(range(1, len(tmp_valid_losses) + 1), tmp_valid_losses * 1e4, marker="o", label="Valid Loss")
501503
plt.xlabel("Epoch")
502-
plt.ylabel("Loss")
503-
# plt.ylim(0, 0.002)
504+
plt.ylabel("Loss * 1e4")
505+
plt.ylim(0, 10)
504506
plt.grid(True)
505507
plt.legend()
506508
plt.show()

code/dependecies.zip

1.17 KB
Binary file not shown.

code/gcn-training.ipynb

Lines changed: 137 additions & 239 deletions
Large diffs are not rendered by default.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)