Feature: TGCN should support non linear_activations #596


Merged · 10 commits · Apr 28, 2025

This PR threads a new `act` keyword through the `TGCNCell`/`TGCN` constructors in both GNNLux and GraphNeuralNetworks, making the activation of the inner `GCNConv` layers configurable rather than hard-coded (previously `sigmoid` in GNNLux and `relu` in GraphNeuralNetworks).
15 changes: 10 additions & 5 deletions GNNLux/src/layers/temporalconv.jl
@@ -33,9 +33,11 @@ LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st)
    init_state::Function
end

-function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
+function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform,
+                  init_state = zeros32, init_bias = zeros32, add_self_loops = false,
+                  use_edge_weight = true, act = sigmoid)
    in_dims, out_dims = ch
-    conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
+    conv = GCNConv(ch, act; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
    gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
    return TGCNCell(in_dims, out_dims, conv, gru, init_state)
end
@@ -57,7 +59,7 @@ function Base.show(io::IO, tgcn::TGCNCell)
end

"""
-    TGCN(in => out; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
+    TGCN(in => out; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true, act = sigmoid)

Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf).

@@ -76,7 +78,7 @@ Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell.
If `add_self_loops=true` the new weights will be set to 1.
This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
Default `false`.

- `act`: Activation function used in the GCNConv layer. Default `sigmoid`.


# Examples
@@ -91,9 +93,12 @@ rng = Random.default_rng()
g = rand_graph(rng, 5, 10)
x = rand(rng, Float32, 2, 5)

# create TGCN layer
tgcn = TGCN(2 => 6)

# create TGCN layer with custom activation
tgcn_relu = TGCN(2 => 6, act = relu)

# setup layer
ps, st = LuxCore.setup(rng, tgcn)
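# forward pass (a sketch completing the truncated example: GNNLux layers
# follow the Lux convention of returning the output and the updated state)
y, st = tgcn(g, x, ps, st)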

14 changes: 14 additions & 0 deletions GNNLux/test/layers/temporalconv.jl
@@ -10,11 +10,25 @@
tx = [x for _ in 1:5]

@testset "TGCN" begin
# Test with default activation (sigmoid)
l = TGCN(3=>3)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
y1, _ = l(g, x, ps, st)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])

# Test with custom activation (relu)
l_relu = TGCN(3=>3, act = relu)
ps_relu = LuxCore.initialparameters(rng, l_relu)
st_relu = LuxCore.initialstates(rng, l_relu)
y2, _ = l_relu(g, x, ps_relu, st_relu)

# Outputs should be different with different activation functions
@test !isapprox(y1, y2, rtol=1.0f-2)

loss_relu = (x, ps) -> sum(first(l_relu(g, x, ps, st_relu)))
test_gradients(loss_relu, x, ps_relu; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

@testset "A3TGCN" begin
26 changes: 18 additions & 8 deletions GraphNeuralNetworks/src/layers/temporalconv.jl
@@ -758,7 +758,7 @@ EvolveGCNO(args...; kws...) = GNNRecurrence(EvolveGCNOCell(args...; kws...))


"""
-    TGCNCell(in => out; kws...)
+    TGCNCell(in => out; act = relu, kws...)

Recurrent graph convolutional cell from the paper
[T-GCN: A Temporal Graph Convolutional
@@ -824,12 +824,14 @@ end

Flux.@layer :noexpand TGCNCell

-function TGCNCell((in, out)::Pair{Int, Int}; kws...)
-    conv_z = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
+function TGCNCell((in, out)::Pair{Int, Int};
+                  act = relu,
+                  kws...)
+    conv_z = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
    dense_z = Dense(2*out => out, sigmoid)
-    conv_r = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
+    conv_r = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
    dense_r = Dense(2*out => out, sigmoid)
-    conv_h = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
+    conv_h = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
dense_h = Dense(2*out => out, tanh)
return TGCNCell(in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h)
end
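The forward pass itself lies outside this hunk. As a rough sketch of how these pieces combine, assuming the standard GRU-style gating from the T-GCN paper (`tgcn_step` is a hypothetical helper, not the package's actual implementation):

```julia
# Hypothetical illustration of the T-GCN update using the fields built above.
function tgcn_step(cell, g, x, h)
    # update gate: how much of the previous state to carry over
    z = cell.dense_z(vcat(cell.conv_z(g, x), h))
    # reset gate: how much of the previous state feeds the candidate
    r = cell.dense_r(vcat(cell.conv_r(g, x), h))
    # candidate state from graph-convolved input and gated history
    h̃ = cell.dense_h(vcat(cell.conv_h(g, x), r .* h))
    # convex combination of old state and candidate
    return (1 .- z) .* h .+ z .* h̃
end
```

Note that `act` only affects the spatial (GCNConv) half of each gate; the gate nonlinearities stay `sigmoid`/`tanh`, matching the paper.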
@@ -868,6 +870,8 @@ See [`GNNRecurrence`](@ref) for more details.
# Examples

```jldoctest
julia> using Flux # Ensure activation functions are available

julia> num_nodes, num_edges = 5, 10;

julia> d_in, d_out = 2, 3;
@@ -876,9 +880,14 @@

julia> timesteps = 5;

julia> g = rand_graph(num_nodes, num_edges);

-julia> x = rand(Float32, d_in, timesteps, num_nodes);
+julia> x = rand(Float32, d_in, timesteps, g.num_nodes);

-julia> layer = TGCN(d_in => d_out)
+julia> layer = TGCN(d_in => d_out) # Default activation (relu)
GNNRecurrence(
  TGCNCell(2 => 3),  # 126 parameters
)  # Total: 18 arrays, 126 parameters, 1.469 KiB.

julia> layer_tanh = TGCN(d_in => d_out, act = tanh) # Custom activation
GNNRecurrence(
  TGCNCell(2 => 3),  # 126 parameters
)  # Total: 18 arrays, 126 parameters, 1.469 KiB.
@@ -889,5 +898,6 @@

julia> y = layer(g, x);

julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
-TGCN(args...; kws...) = GNNRecurrence(TGCNCell(args...; kws...))
+TGCN(args...; kws...) =
+    GNNRecurrence(TGCNCell(args...; kws...))
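Because `TGCN` merely wraps the cell in `GNNRecurrence`, the new `act` keyword forwards transparently to `TGCNCell`. A quick sketch of using and differentiating the layer (assumes the standard Flux/Zygote API; shapes follow the doctest above):

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(5, 10)            # 5 nodes, 10 edges
x = rand(Float32, 2, 5, 5)       # (d_in, timesteps, num_nodes)

layer = TGCN(2 => 3, act = tanh) # act is forwarded to TGCNCell
y = layer(g, x)                  # (d_out, timesteps, num_nodes)

# gradients flow through the recurrence, as the tests below verify
grads = Flux.gradient(m -> sum(abs2, m(g, x)), layer)
```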

30 changes: 30 additions & 0 deletions GraphNeuralNetworks/test/layers/temporalconv.jl
@@ -25,6 +25,8 @@ end

@testitem "TGCNCell" setup=[TemporalConvTestModule, TestModule] begin
using .TemporalConvTestModule, .TestModule

# Test with default activation function
cell = GraphNeuralNetworks.TGCNCell(in_channel => out_channel)
y, h = cell(g, g.x)
@test y === h
@@ -33,10 +35,25 @@
@test size(h) == (out_channel, g.num_nodes)
# with no initial state
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
# with initial state
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH)

# Test with custom activation function
custom_activation = tanh
cell_custom = GraphNeuralNetworks.TGCNCell(in_channel => out_channel, act = custom_activation)
y_custom, h_custom = cell_custom(g, g.x)
@test y_custom === h_custom
@test size(h_custom) == (out_channel, g.num_nodes)
# Test that outputs differ when using different activation functions
@test !isapprox(y, y_custom, rtol=RTOL_HIGH)
# with no initial state
test_gradients(cell_custom, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
# with initial state
test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH)
end

@testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin
using .TemporalConvTestModule, .TestModule

# Test with default activation function
layer = TGCN(in_channel => out_channel)
x = rand(Float32, in_channel, timesteps, g.num_nodes)
state0 = rand(Float32, out_channel, g.num_nodes)
@@ -48,6 +65,19 @@
y = layer(g, x)
@test size(y) == (out_channel, timesteps, g.num_nodes)
# with no initial state
test_gradients(layer, g, x, rtol = RTOL_HIGH)
# with initial state
test_gradients(layer, g, x, state0, rtol = RTOL_HIGH)

# Test with custom activation function
custom_activation = tanh
layer_custom = TGCN(in_channel => out_channel, act = custom_activation)
y_custom = layer_custom(g, x)
@test layer_custom isa GNNRecurrence
@test size(y_custom) == (out_channel, timesteps, g.num_nodes)
# Test that outputs differ when using different activation functions
@test !isapprox(y, y_custom, rtol = RTOL_HIGH)
# with no initial state
test_gradients(layer_custom, g, x, rtol = RTOL_HIGH)
# with initial state
test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH)

# interplay with GNNChain
model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
y = model(g, x)