20
20
from torch_geometric .loader import DataLoader
21
21
22
22
class SimpleGCN (nn .Module ):
23
- def __init__ (
24
- self , input_dim , hidden_dim , embedding_dim , dropout = 0.5 , pooling = "max"
25
- ):
23
+ def __init__ (self , input_dim , embedding_dim , hidden_dim = 64 , dropout = 0.5 , pooling = "max" ):
26
24
"""
27
25
input_dim: размер входных признаков узлов
28
26
hidden_dim: размер скрытого пространства в графовых свёрточных слоях
@@ -35,16 +33,18 @@ def __init__(
35
33
self .input_dim = input_dim
36
34
self .hidden_dim = hidden_dim
37
35
self .embedding_dim = embedding_dim
36
+ self .pooling = pooling
38
37
39
38
self .gc1 = GCNConv (input_dim , hidden_dim )
40
39
self .gc2 = GCNConv (hidden_dim , hidden_dim )
40
+
41
41
self .graph_norm = GraphNorm (hidden_dim )
42
42
self .layer_norm = nn .LayerNorm (hidden_dim )
43
43
self .dropout = nn .Dropout (dropout )
44
- self . pooling = pooling
44
+
45
45
self .fc = nn .Linear (hidden_dim , embedding_dim )
46
46
47
- def forward (self , x , edge_index ):
47
+ def forward (self , x , edge_index , batch = None ):
48
48
x = F .relu (self .gc1 (x , edge_index ))
49
49
x = self .graph_norm (x )
50
50
x = self .layer_norm (x )
@@ -55,22 +55,21 @@ def forward(self, x, edge_index):
55
55
x = self .layer_norm (x )
56
56
x = self .dropout (x )
57
57
58
- # Глобальное агрегирование узловых признаков для получения представления всего графа
58
+ # Пулинг по графу
59
59
if self .pooling == "max" :
60
- pooled = torch . max (x , dim = 0 ). values
60
+ x = global_max_pool (x , batch )
61
61
elif self .pooling == "mean" :
62
- pooled = torch . mean (x , dim = 0 )
62
+ x = global_mean_pool (x , batch )
63
63
elif self .pooling == "sum" :
64
- pooled = torch . sum (x , dim = 0 )
64
+ x = global_add_pool (x , batch )
65
65
else :
66
- raise ValueError ("Unsupported pooling method. Use 'max', 'mean' или 'sum'." )
66
+ raise ValueError ("Unsupported pooling method. Use 'max', 'mean' or 'sum'." )
67
67
68
- # Преобразование агрегированного представления в эмбеддинг графа
69
- embedding = self .fc (pooled )
68
+ x = self .fc (x )
70
69
71
70
if self .embedding_dim == 1 :
72
- embedding = nn . Sigmoid ()( embedding )
73
- return embedding
71
+ x = torch . sigmoid ( x )
72
+ return x
74
73
75
74
76
75
class GCN (nn .Module ):
@@ -135,70 +134,85 @@ def forward(self, x, edge_index, batch=None):
135
134
return x
136
135
137
136
138
- class SimpleGAT (nn .Module ):
139
- def __init__ (
140
- self ,
141
- input_dim ,
142
- embed_dim ,
143
- hidden_dim ,
144
- out_dim ,
145
- dropout = 0.5 ,
146
- pooling = "max" ,
147
- heads = 1 ,
148
- ):
149
- """
150
- input_dim: размер one-hot признаков узлов
151
- embed_dim: размер обучаемого эмбеддинга узлов
152
- hidden_dim: скрытое пространство для GAT
153
- out_dim: размер итогового графового эмбеддинга
154
- pooling: 'max', 'mean', или 'sum'
155
- """
156
- super (SimpleGAT , self ).__init__ ()
157
-
158
- self .pooling = pooling
159
- self .dropout = nn .Dropout (dropout )
160
-
161
- self .node_encoder = nn .Linear (
162
- input_dim , embed_dim
163
- ) # обучаемый слой проекции one-hot -> dense
137
+ class GAT (nn .Module ):
138
+ def __init__ (self , input_dim , output_dim = 16 , dropout = 0.5 , pooling = "max" , heads = 4 ):
139
+ super (GAT , self ).__init__ ()
164
140
165
- self .gat1 = GATConv (embed_dim , hidden_dim , heads = heads , concat = True )
166
- self .norm1 = GraphNorm (hidden_dim * heads )
141
+ self .input_dim = input_dim
142
+ self .output_dim = output_dim
143
+ self .hidden_dim = 64 # базовая размерность скрытого слоя
144
+ self .heads = heads
145
+
146
+ # Attention Conv слои
147
+ self .gat1 = GATConv (input_dim , self .hidden_dim // heads , heads = heads )
148
+ self .gat2 = GATConv (self .hidden_dim , 256 // heads , heads = heads )
149
+ self .gat3 = GATConv (256 , 256 // heads , heads = heads )
150
+ self .gat4 = GATConv (256 , self .hidden_dim // heads , heads = heads )
151
+
152
+ # Проекции для full residual
153
+ self .res1 = nn .Linear (input_dim , self .hidden_dim )
154
+ self .res2 = nn .Linear (self .hidden_dim , 256 )
155
+ self .res3 = nn .Linear (256 , 256 )
156
+ self .res4 = nn .Linear (256 , self .hidden_dim )
157
+
158
+ # Нормализация после каждого блока
159
+ self .norm1 = GraphNorm (self .hidden_dim )
160
+ self .norm2 = GraphNorm (256 )
161
+ self .norm3 = GraphNorm (256 )
162
+ self .norm4 = GraphNorm (self .hidden_dim )
167
163
168
- self .gat2 = GATConv ( hidden_dim * heads , hidden_dim , heads = heads , concat = True )
169
- self .norm2 = GraphNorm ( hidden_dim * heads )
164
+ self .dropout = nn . Dropout ( dropout )
165
+ self .pooling = pooling
170
166
171
- self .fc = nn .Linear (hidden_dim * heads , out_dim )
167
+ # Полносвязная часть для графового эмбеддинга
168
+ self .fc1 = nn .Linear (self .hidden_dim , self .hidden_dim )
169
+ self .fc_norm = nn .LayerNorm (self .hidden_dim )
170
+ self .fc2 = nn .Linear (self .hidden_dim , output_dim )
172
171
173
172
def forward (self , x , edge_index , batch = None ):
174
- x = self .node_encoder (x )
175
- x = F .elu (self .gat1 (x , edge_index ))
176
- x = self .norm1 (x )
177
- x = self .dropout (x )
178
-
179
- x = F .elu (self .gat2 (x , edge_index ))
180
- x = self .norm2 (x )
181
- x = self .dropout (x )
182
-
183
- # Пуллинг по графу (поддержка батча)
184
- if batch is None :
185
- batch = torch .zeros (x .size (0 ), dtype = torch .long , device = x .device )
186
-
173
+ # Layer 1
174
+ h1 = self .gat1 (x , edge_index )
175
+ h1 = F .leaky_relu (h1 + self .res1 (x ))
176
+ h1 = self .norm1 (h1 )
177
+ h1 = self .dropout (h1 )
178
+
179
+ # Layer 2
180
+ h2 = self .gat2 (h1 , edge_index )
181
+ h2 = F .leaky_relu (h2 + self .res2 (h1 ))
182
+ h2 = self .norm2 (h2 )
183
+ h2 = self .dropout (h2 )
184
+
185
+ # Layer 3
186
+ h3 = self .gat3 (h2 , edge_index )
187
+ h3 = F .leaky_relu (h3 + self .res3 (h2 ))
188
+ h3 = self .norm3 (h3 )
189
+ h3 = self .dropout (h3 )
190
+
191
+ # Layer 4
192
+ h4 = self .gat4 (h3 , edge_index )
193
+ h4 = F .leaky_relu (h4 + self .res4 (h3 ))
194
+ h4 = self .norm4 (h4 )
195
+ h4 = self .dropout (h4 )
196
+
197
+ # Глобальное агрегирование
187
198
if self .pooling == "max" :
188
- pooled = global_max_pool (x , batch )
199
+ hg = global_max_pool (h4 , batch )
189
200
elif self .pooling == "mean" :
190
- pooled = global_mean_pool (x , batch )
201
+ hg = global_mean_pool (h4 , batch )
191
202
elif self .pooling == "sum" :
192
- pooled = global_add_pool (x , batch )
203
+ hg = global_add_pool (h4 , batch )
193
204
else :
194
205
raise ValueError ("Unsupported pooling method. Use 'max', 'mean' or 'sum'." )
195
206
196
- embedding = self .fc (pooled )
197
-
198
- if self .fc .out_features == 1 :
199
- embedding = torch .sigmoid (embedding )
207
+ # Финальный MLP
208
+ out = self .fc1 (hg )
209
+ out = self .fc_norm (out )
210
+ out = F .leaky_relu (out )
211
+ out = self .fc2 (out )
212
+ if self .output_dim == 1 :
213
+ out = torch .sigmoid (out )
214
+ return out
200
215
201
- return embedding
202
216
203
217
class CustomDataset (Dataset ):
204
218
@staticmethod
0 commit comments