Commit deccc47

add some fixes
1 parent 91a3316 commit deccc47

File tree

3 files changed: +281 -194 lines changed

code/GCN.py

Lines changed: 80 additions & 66 deletions
@@ -20,9 +20,7 @@
 from torch_geometric.loader import DataLoader

 class SimpleGCN(nn.Module):
-    def __init__(
-        self, input_dim, hidden_dim, embedding_dim, dropout=0.5, pooling="max"
-    ):
+    def __init__(self, input_dim, embedding_dim, hidden_dim=64, dropout=0.5, pooling="max"):
         """
         input_dim: size of the input node features
         hidden_dim: size of the hidden space in the graph convolutional layers
@@ -35,16 +33,18 @@ def __init__(
         self.input_dim = input_dim
         self.hidden_dim = hidden_dim
         self.embedding_dim = embedding_dim
+        self.pooling = pooling

         self.gc1 = GCNConv(input_dim, hidden_dim)
         self.gc2 = GCNConv(hidden_dim, hidden_dim)
+
         self.graph_norm = GraphNorm(hidden_dim)
         self.layer_norm = nn.LayerNorm(hidden_dim)
         self.dropout = nn.Dropout(dropout)
-        self.pooling = pooling
+
         self.fc = nn.Linear(hidden_dim, embedding_dim)

-    def forward(self, x, edge_index):
+    def forward(self, x, edge_index, batch=None):
         x = F.relu(self.gc1(x, edge_index))
         x = self.graph_norm(x)
         x = self.layer_norm(x)
@@ -55,22 +55,21 @@ def forward(self, x, edge_index):
         x = self.layer_norm(x)
         x = self.dropout(x)

-        # Global aggregation of node features to obtain a representation of the whole graph
+        # Graph-level pooling
         if self.pooling == "max":
-            pooled = torch.max(x, dim=0).values
+            x = global_max_pool(x, batch)
         elif self.pooling == "mean":
-            pooled = torch.mean(x, dim=0)
+            x = global_mean_pool(x, batch)
         elif self.pooling == "sum":
-            pooled = torch.sum(x, dim=0)
+            x = global_add_pool(x, batch)
         else:
-            raise ValueError("Unsupported pooling method. Use 'max', 'mean' или 'sum'.")
+            raise ValueError("Unsupported pooling method. Use 'max', 'mean' or 'sum'.")

-        # Transform the aggregated representation into a graph embedding
-        embedding = self.fc(pooled)
+        x = self.fc(x)

         if self.embedding_dim == 1:
-            embedding = nn.Sigmoid()(embedding)
-        return embedding
+            x = torch.sigmoid(x)
+        return x


 class GCN(nn.Module):
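
The switch from torch.max(x, dim=0) to global_max_pool(x, batch) is what makes SimpleGCN batch-aware: with the batch vector that torch_geometric.loader.DataLoader attaches to each mini-batch, every graph is pooled separately instead of all nodes collapsing into a single vector. A minimal usage sketch, assuming the class above; the toy graphs and feature sizes are invented for illustration:

# Usage sketch (not part of the commit); toy data with made-up sizes.
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

model = SimpleGCN(input_dim=8, embedding_dim=16)  # hidden_dim defaults to 64

# Two small graphs; edge_index has shape [2, num_edges].
g1 = Data(x=torch.randn(5, 8), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))
g2 = Data(x=torch.randn(3, 8), edge_index=torch.tensor([[0, 1], [1, 2]]))
loader = DataLoader([g1, g2], batch_size=2)

mb = next(iter(loader))
# mb.batch maps each node to its graph id (0 or 1), so pooling
# returns one row per graph: the embedding has shape [2, 16].
emb = model(mb.x, mb.edge_index, mb.batch)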
@@ -135,70 +134,85 @@ def forward(self, x, edge_index, batch=None):
         return x


-class SimpleGAT(nn.Module):
-    def __init__(
-        self,
-        input_dim,
-        embed_dim,
-        hidden_dim,
-        out_dim,
-        dropout=0.5,
-        pooling="max",
-        heads=1,
-    ):
-        """
-        input_dim: size of the one-hot node features
-        embed_dim: size of the learnable node embedding
-        hidden_dim: hidden space for the GAT
-        out_dim: size of the final graph embedding
-        pooling: 'max', 'mean', or 'sum'
-        """
-        super(SimpleGAT, self).__init__()
-
-        self.pooling = pooling
-        self.dropout = nn.Dropout(dropout)
-
-        self.node_encoder = nn.Linear(
-            input_dim, embed_dim
-        )  # learnable projection layer: one-hot -> dense
+class GAT(nn.Module):
+    def __init__(self, input_dim, output_dim=16, dropout=0.5, pooling="max", heads=4):
+        super(GAT, self).__init__()

-        self.gat1 = GATConv(embed_dim, hidden_dim, heads=heads, concat=True)
-        self.norm1 = GraphNorm(hidden_dim * heads)
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.hidden_dim = 64  # base hidden-layer size
+        self.heads = heads
+
+        # Attention conv layers
+        self.gat1 = GATConv(input_dim, self.hidden_dim // heads, heads=heads)
+        self.gat2 = GATConv(self.hidden_dim, 256 // heads, heads=heads)
+        self.gat3 = GATConv(256, 256 // heads, heads=heads)
+        self.gat4 = GATConv(256, self.hidden_dim // heads, heads=heads)
+
+        # Projections for full residual connections
+        self.res1 = nn.Linear(input_dim, self.hidden_dim)
+        self.res2 = nn.Linear(self.hidden_dim, 256)
+        self.res3 = nn.Linear(256, 256)
+        self.res4 = nn.Linear(256, self.hidden_dim)
+
+        # Normalization after each block
+        self.norm1 = GraphNorm(self.hidden_dim)
+        self.norm2 = GraphNorm(256)
+        self.norm3 = GraphNorm(256)
+        self.norm4 = GraphNorm(self.hidden_dim)

-        self.gat2 = GATConv(hidden_dim * heads, hidden_dim, heads=heads, concat=True)
-        self.norm2 = GraphNorm(hidden_dim * heads)
+        self.dropout = nn.Dropout(dropout)
+        self.pooling = pooling

-        self.fc = nn.Linear(hidden_dim * heads, out_dim)
+        # Fully connected head for the graph embedding
+        self.fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
+        self.fc_norm = nn.LayerNorm(self.hidden_dim)
+        self.fc2 = nn.Linear(self.hidden_dim, output_dim)

     def forward(self, x, edge_index, batch=None):
-        x = self.node_encoder(x)
-        x = F.elu(self.gat1(x, edge_index))
-        x = self.norm1(x)
-        x = self.dropout(x)
-
-        x = F.elu(self.gat2(x, edge_index))
-        x = self.norm2(x)
-        x = self.dropout(x)
-
-        # Graph-level pooling (batch support)
-        if batch is None:
-            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
-
+        # Layer 1
+        h1 = self.gat1(x, edge_index)
+        h1 = F.leaky_relu(h1 + self.res1(x))
+        h1 = self.norm1(h1)
+        h1 = self.dropout(h1)
+
+        # Layer 2
+        h2 = self.gat2(h1, edge_index)
+        h2 = F.leaky_relu(h2 + self.res2(h1))
+        h2 = self.norm2(h2)
+        h2 = self.dropout(h2)
+
+        # Layer 3
+        h3 = self.gat3(h2, edge_index)
+        h3 = F.leaky_relu(h3 + self.res3(h2))
+        h3 = self.norm3(h3)
+        h3 = self.dropout(h3)
+
+        # Layer 4
+        h4 = self.gat4(h3, edge_index)
+        h4 = F.leaky_relu(h4 + self.res4(h3))
+        h4 = self.norm4(h4)
+        h4 = self.dropout(h4)
+
+        # Global aggregation
         if self.pooling == "max":
-            pooled = global_max_pool(x, batch)
+            hg = global_max_pool(h4, batch)
         elif self.pooling == "mean":
-            pooled = global_mean_pool(x, batch)
+            hg = global_mean_pool(h4, batch)
         elif self.pooling == "sum":
-            pooled = global_add_pool(x, batch)
+            hg = global_add_pool(h4, batch)
         else:
             raise ValueError("Unsupported pooling method. Use 'max', 'mean' or 'sum'.")

-        embedding = self.fc(pooled)
-
-        if self.fc.out_features == 1:
-            embedding = torch.sigmoid(embedding)
+        # Final MLP
+        out = self.fc1(hg)
+        out = self.fc_norm(out)
+        out = F.leaky_relu(out)
+        out = self.fc2(out)
+        if self.output_dim == 1:
+            out = torch.sigmoid(out)
+        return out

-        return embedding

 class CustomDataset(Dataset):
     @staticmethod
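
In the new GAT, every GATConv block carries a linear residual projection that is summed in before the LeakyReLU; since GATConv concatenates heads by default, GATConv(in_dim, d // heads, heads=heads) produces d output channels, so the residual shapes line up (64 -> 256 -> 256 -> 64). A quick smoke-test sketch, assuming the class above; all sizes are made up:

# Smoke test (not part of the commit); sizes are made up.
import torch

model = GAT(input_dim=8, output_dim=16, heads=4)

x = torch.randn(10, 8)                      # 10 nodes, 8 features
edge_index = torch.randint(0, 10, (2, 30))  # 30 random edges
batch = torch.zeros(10, dtype=torch.long)   # all nodes in graph 0

out = model(x, edge_index, batch)
print(out.shape)  # torch.Size([1, 16]) -- one embedding per graph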
