Node classification issues on heterogeneous graph - missing initial embeddings #10333
norapettersson-smf asked this question in Q&A
Hi!
I have been working on implementing a node classification GNN using torch_geometric and I'm having trouble getting the model to work. As I'm developing this as part of my work, I'm not at liberty to share the data, but I hope I can get some conceptual feedback on whether the idea even makes sense, or whether there are limitations I'm not aware of.

The Graph:
So far, I have a heterogeneous graph with two node types, A and B. Edges only go from A to B or from B to A; that is, A nodes are never directly connected to other A nodes, and likewise for B. The graph has ~30k nodes of type A and ~500k nodes of type B.
Truth classes exist for nodes B and are used to train the network.
Nodes of type A have an initial embedding, but B lacks any prior knowledge (its initial embedding is set to a 1-dim vector). The idea was to learn embeddings for B from the interactions, and then do classification on the B nodes.
Ideally, if the network works, I want to add additional classifiers to the model to predict more properties.
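On the missing initial embeddings for B: one common alternative (not from the post) to a constant 1-dim feature is a learnable embedding table, so per-node features for B are learned end-to-end during training. The sizes and names below are toy values for illustration:

```python
import torch
import torch.nn as nn

# Learnable per-node features for the featureless node type B.
# Toy sizes here; the real graph has ~500k B nodes.
num_b_nodes, hidden_dim = 1000, 64
b_embedding = nn.Embedding(num_b_nodes, hidden_dim)

# In a forward pass, look up the rows for the B nodes in the batch:
b_node_ids = torch.arange(10)      # e.g. global indices of sampled B nodes
x_b = b_embedding(b_node_ids)      # shape [10, hidden_dim], trained with the GNN
```

The embedding weights are ordinary parameters, so they receive gradients through the message-passing layers just like the rest of the model.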
The Model:
The model consists of three SAGEConv layers with BatchNorm1d, a GELU non-linearity, and dropout. At the end is a linear layer whose output size is the number of classes, 10. A dummy of the class looks something like below. Afterwards the model is adapted to heterogeneous data via to_hetero().

The training loop:
The training is pretty straightforward: a CrossEntropyLoss compares the output with the truth labels in y of the nodes. I'm using NeighborLoader to split the graph into batches, seeded on node type B.
For some reason I can't get the network to learn: there is a slight rise in accuracy/precision over the first couple of epochs, but then all metrics plateau around epoch 20-30, close to random guessing. I have validated all the data and formatting through training and everything looks OK.
Disabling all regularisation doesn't help either.
I have tried different optimisers as well as LR schedulers to decrease the LR along the way, but the end result seems to be the same.
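For reference, an optimiser-plus-scheduler combination of the kind described can be sketched as below (ReduceLROnPlateau as one example; the stand-in linear model and all hyperparameters are made up):

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 10)                 # stand-in for the GNN
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5)

for epoch in range(3):
    loss = F.cross_entropy(model(torch.randn(32, 4)),
                           torch.randint(0, 10, (32,)))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step(loss.item())    # halve the LR when the loss plateaus
```

If the metrics plateau regardless of the schedule, the LR is usually not the bottleneck, which is consistent with what is described above.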
Questions
I'm just looking for some general feedback on whether my approach makes sense.
I'm not sure how much of a problem it is that one node type lacks initial embeddings. There is also quite a gap between the number of A and B nodes, which I could see being an issue (though that should show up more as a failure to generalise). Unfortunately, there is no way to fix the ratio by getting more data.
Are there any ideas that would be good to try, or obvious pitfalls to avoid?