Revise GNNLux Temporal Graph Layers #595

Closed
37 changes: 21 additions & 16 deletions GNNLux/docs/src_tutorials/graph_classification.jl
@@ -13,11 +13,22 @@
# The TU Dortmund University has collected a wide range of different graph classification datasets, known as the [**TUDatasets**](https://chrsmrrs.github.io/datasets/), which are also accessible via MLDatasets.jl.
# Let's import the necessary packages. Then we'll load and inspect one of the smaller ones, the **MUTAG dataset**:

using Lux, GNNLux
using MLDatasets, MLUtils
using Lux
using GNNLux
using GNNlib
using MLDatasets
using MLUtils
using LinearAlgebra, Random, Statistics
using Zygote, Optimisers, OneHotArrays


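# For graph classification we need a readout that turns node embeddings into a
# single graph embedding, so we define a small `GlobalPool` layer on top of
# `GNNlib.global_pool` (here aggregating with `mean`).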
struct GlobalPool{F} <: GNNLayer
    aggr::F
end

(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st

(l::GlobalPool)(g::GNNGraph, ps, st) = GNNGraph(g, gdata = first(l(g, node_features(g), ps, st)))

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
rng = Random.seed!(42); # for reproducibility
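
# The inspection code is elided from this diff; a minimal sketch of the loading
# step, using the standard MLDatasets loader, would be:
dataset = TUDataset("MUTAG")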

@@ -107,19 +118,20 @@ MLUtils.batch(vec_gs)

# The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:

function create_model(nin, nh, nout)
    GNNChain(GCNConv(nin => nh, relu),
             GCNConv(nh => nh, relu),
             GCNConv(nh => nh),
# Then use the `GlobalPool` readout in the model:
function create_model_graphconv(nin, nh, nout)
    GNNChain(GraphConv(nin => nh, relu),
             GraphConv(nh => nh, relu),
             GraphConv(nh => nh),
             GlobalPool(mean),
             Dropout(0.5),
             Dense(nh, nout))
end;
end

nin = 7
nh = 64
nout = 2
model = create_model(nin, nh, nout)
model = create_model_graphconv(nin, nh, nout)

ps, st = LuxCore.initialparameters(rng, model), LuxCore.initialstates(rng, model);

@@ -191,11 +203,4 @@ model, ps, st = train_model!(model, ps, st);

# This layer is implemented under the name `GraphConv` in GraphNeuralNetworks.jl.
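# For reference, `GraphConv` (Morris et al., 2019) updates node ``i`` as
# ```math
# \mathbf{x}_i' = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j,
# ```
# so, unlike `GCNConv`, the central node gets its own weight matrix and the
# neighbor sum is not degree-normalized.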

# As an exercise, you are invited to complete the following code to the extent that it makes use of `GraphConv` rather than `GCNConv`.
# This should bring you close to **82% test accuracy**.

# ## Conclusion

# In this chapter, you have learned how to apply GNNs to the task of graph classification.
# You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings.

211 changes: 211 additions & 0 deletions GNNLux/docs/src_tutorials/temporal_graph_classification_lux.jl
@@ -0,0 +1,211 @@
using Lux
using GNNLux
using GNNlib
using GNNGraphs
using MLDatasets
using MLUtils
using LinearAlgebra, Random, Statistics
using Zygote, Optimisers, OneHotArrays

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
rng = Random.seed!(42); # for reproducibility

brain_dataset = MLDatasets.TemporalBrains()

function data_loader(brain_dataset)
    graphs = brain_dataset.graphs
    dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs))
    for i in 1:length(graphs)
        graph = graphs[i]
        dataset[i] = TemporalSnapshotsGNNGraph(GNNGraphs.mlgraph2gnngraph.(graph.snapshots))
        ## Add graph and node features
        for t in 1:27
            s = dataset[i].snapshots[t]
            s.ndata.x = Float32.([I(102); s.ndata.x'])
        end
        dataset[i].tgdata.g = Float32.(onehotbatch([graph.graph_data.g], ["F", "M"]))
    end

    ## Split the dataset into an 80% training set and a 20% test set
    train_graphs = dataset[1:200]
    test_graphs = dataset[201:250]

    ## Create tuples of (graph, label) for compatibility with the training loop
    train_loader = [(g, g.tgdata.g) for g in train_graphs]
    test_loader = [(g, g.tgdata.g) for g in test_graphs]

    return train_loader, test_loader
end
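
# A quick sanity check of the loader output (hypothetical REPL session):
#
#     train_loader, test_loader = data_loader(brain_dataset)
#     g, y = first(train_loader)
#     g.num_snapshots    # 27 snapshots per temporal graph
#     size(y)            # (2, 1): one-hot gender label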

struct GlobalPool{F} <: GNNLayer
    aggr::F
end

# Method for a regular `GNNGraph` (same as in graph_classification.jl).
(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st

# Method for a `TemporalSnapshotsGNNGraph`: pool each snapshot separately, then
# average the per-snapshot embeddings over time.
function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector, ps, st)
    h = [GNNlib.global_pool(l, g.snapshots[i], x[i]) for i in 1:g.num_snapshots]
    return mean(h), st
end

# Convenience method for directly creating graph-level embeddings.
(l::GlobalPool)(g::GNNGraph, ps, st) = GNNGraph(g, gdata = first(l(g, node_features(g), ps, st)))

struct GenderPredictionModel <: AbstractLuxLayer
    gin::GINConv
    mlp::Chain
    globalpool::GlobalPool
    dense::Dense
end

# `GINConv` applied to a `TemporalSnapshotsGNNGraph`, written without mutation
# so that Zygote can differentiate through it.
function (l::GINConv)(g::TemporalSnapshotsGNNGraph, x::AbstractVector, ps, st)
    ## Use map instead of preallocation and mutation
    results = map(1:g.num_snapshots) do i
        l(g.snapshots[i], x[i], ps, st)
    end

    ## Extract the per-snapshot outputs and keep the final state
    h = [r[1] for r in results]
    st_final = results[end][2]

    return h, st_final
end

# Constructor for GenderPredictionModel using Lux components
function GenderPredictionModel(; nfeatures = 103, nhidden = 128, σ = relu)
    mlp = Chain(Dense(nfeatures => nhidden, σ), Dense(nhidden => nhidden, σ))
    gin = GINConv(mlp, 0.5f0)
    globalpool = GlobalPool(mean)
    dense = Dense(nhidden => 2)
    return GenderPredictionModel(gin, mlp, globalpool, dense)
end

# Forward pass; the type annotations make Julia raise a MethodError on
# mismatched inputs instead of silently dispatching elsewhere.
function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph, x::AbstractVector,
                                    ps::NamedTuple, st::NamedTuple)
    h, st_gin = m.gin(g, x, ps.gin, st.gin)
    h, st_globalpool = m.globalpool(g, h, ps.globalpool, st.globalpool)
    output, st_dense = m.dense(h, ps.dense, st.dense)

    st_new = (gin = st_gin, globalpool = st_globalpool, dense = st_dense)
    return output, st_new
end
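
# Shape sketch for one temporal graph with the default hyperparameters above:
# `x` is a 27-element vector of 103×102 feature matrices; `gin` maps each
# snapshot to 128×102, `globalpool` averages them into a single 128×1
# embedding, and `dense` produces the 2×1 gender logits.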

# Custom loss matching the `Lux.Training` objective signature
# `(model, ps, st, data) -> (loss, state, stats)`; it also handles states that
# Lux wraps in a `layers` field.
function custom_loss(model::GenderPredictionModel, ps::NamedTuple, st::NamedTuple,
                     data::Tuple{TemporalSnapshotsGNNGraph, AbstractVector, AbstractMatrix})
    g, x, y = data
    logitcrossentropy = CrossEntropyLoss(; logits = Val(true))

    ## Unwrap the state if it carries the `layers` wrapper
    actual_st = haskey(st, :layers) ? st.layers : st

    ## Ensure the state is in train mode
    actual_st = Lux.trainmode(actual_st)

    ## Forward pass
    ŷ, new_st = model(g, x, ps, actual_st)

    ## Re-wrap the new state if the input was wrapped
    final_st = haskey(st, :layers) ? (layers = new_st,) : new_st

    ## No extra statistics to report, hence the empty NamedTuple
    return logitcrossentropy(ŷ, y), final_st, (;)
end

# Implement the Lux interface methods for parameter and state initialization.
function LuxCore.initialparameters(rng::AbstractRNG, m::GenderPredictionModel)
    return (gin = LuxCore.initialparameters(rng, m.gin),
            mlp = LuxCore.initialparameters(rng, m.mlp),
            globalpool = LuxCore.initialparameters(rng, m.globalpool),
            dense = LuxCore.initialparameters(rng, m.dense))
end

function LuxCore.initialstates(rng::AbstractRNG, m::GenderPredictionModel)
    return (gin = LuxCore.initialstates(rng, m.gin),
            mlp = LuxCore.initialstates(rng, m.mlp),
            globalpool = LuxCore.initialstates(rng, m.globalpool),
            dense = LuxCore.initialstates(rng, m.dense))
end

# Initialize model and parameters
model = GenderPredictionModel()
ps, st = LuxCore.initialparameters(rng, model), LuxCore.initialstates(rng, model);

# Elementwise logit binary cross-entropy used for evaluation (equivalent to
# `logitbinarycrossentropy`).
lossfunction(ŷ, y) = mean(-y .* log.(sigmoid.(ŷ)) - (1 .- y) .* log.(1 .- sigmoid.(ŷ)));

function eval_loss_accuracy(model, ps, st, data_loader)
    losses = []
    accs = []
    for (g, y) in data_loader
        ## Extract the feature matrix of each snapshot
        x = [s.ndata.x for s in g.snapshots]

        ## Forward pass with the Lux model
        ŷ, _ = model(g, x, ps, st)

        ## Loss
        push!(losses, lossfunction(ŷ, y))

        ## Accuracy: fraction of samples whose predicted class matches the label
        pred_indices = [argmax(ŷ[:, i]) for i in 1:size(ŷ, 2)]
        true_indices = [argmax(y[:, i]) for i in 1:size(y, 2)]
        push!(accs, round(100 * mean(pred_indices .== true_indices), digits = 2))
    end
    return (loss = mean(losses), acc = mean(accs))
end

# Build the data loaders and train the model
train_loader, test_loader = data_loader(brain_dataset)

function train(model, ps, st, train_loader, test_loader)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(1e-2))

    function report(epoch)
        current_ps = train_state.parameters
        current_st = train_state.states
        train = eval_loss_accuracy(model, current_ps, current_st, train_loader)
        test_st = Lux.testmode(current_st)
        test = eval_loss_accuracy(model, current_ps, test_st, test_loader)
        @info (; epoch, train, test)
    end

    for epoch in 1:5
        for (g, y) in train_loader
            ## Collect one feature matrix per snapshot, matching the
            ## `x::AbstractVector` argument expected by `custom_loss`
            x = [s.ndata.x for s in g.snapshots]
            _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, (g, x, y), train_state)
        end
        report(epoch)
    end
end

train(model, ps, st, train_loader, test_loader)
@@ -9,12 +9,17 @@
#
# We start by importing the necessary libraries. We use `GraphNeuralNetworks.jl`, `Flux.jl` and `MLDatasets.jl`, among others.

## Comments from Miguel for Claudio:
# 1. Add a method to check that the datasets downloaded correctly; otherwise
#    problems may arise later. This happened to me when downloading the
#    TemporalBrains dataset.

using Flux
using GraphNeuralNetworks
using Statistics, Random
using LinearAlgebra
using MLDatasets: TemporalBrains
using CUDA # comment out if you don't have a CUDA GPU
using DataDeps
# using CUDA # uncomment if you have a CUDA GPU

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
Random.seed!(17); # for reproducibility
@@ -29,7 +34,7 @@ Random.seed!(17); # for reproducibility
# Each temporal graph has a label representing gender ('M' for male and 'F' for female) and age group (22-25, 26-30, 31-35, and 36+).
# The network's edge weights are binarized, and the threshold is set to 0.6 by default.

brain_dataset = TemporalBrains()
brain_dataset = MLDatasets.TemporalBrains()
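
# A quick look at the raw data (hypothetical REPL session; each temporal graph
# spans 27 snapshots over 102 brain regions):
#
#     graph = brain_dataset.graphs[1]
#     length(graph.snapshots)   # 27
#     graph.graph_data.g        # "F" or "M"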

# After loading the dataset from the MLDatasets.jl package, we see that there are 1000 graphs and we need to convert them to the `TemporalSnapshotsGNNGraph` format.
# So we create a function called `data_loader` that performs this conversion and splits the dataset into a training set, used to fit the model, and a test set, used to measure its performance. Due to computational cost, we use only 250 of the original 1000 graphs: 200 for training and 50 for testing.
@@ -83,6 +88,11 @@ end

Flux.@layer GenderPredictionModel

function (l::GINConv)(g::TemporalSnapshotsGNNGraph, x::AbstractVector)
    h = [l(g[i], x[i]) for i in 1:g.num_snapshots]
    return h
end
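
# The method above simply maps the wrapped `GINConv` over the snapshot
# sequence, returning one node-embedding matrix per snapshot.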

function GenderPredictionModel(; nfeatures = 103, nhidden = 128, σ = relu)
mlp = Chain(Dense(nfeatures => nhidden, σ), Dense(nhidden => nhidden, σ))
gin = GINConv(mlp, 0.5)