Skip to content

Commit a260ae5

Browse files
committed
Added test that compares symbolic pullback to zygote pullback.
Flipped order of functions. Added _get_contents method for Tuple as argument. Added another method to deal with Zygote idiosyncracies. Fixed method for _get_params.
1 parent f495dc8 commit a260ae5

File tree

4 files changed

+54
-10
lines changed

4 files changed

+54
-10
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1414
AbstractNeuralNetworks = "0.3, 0.4"
1515
Documenter = "1.8.0"
1616
ForwardDiff = "0.10.38"
17+
GeometricMachineLearning = "0.3.7"
1718
Latexify = "0.16.5"
1819
RuntimeGeneratedFunctions = "0.5"
1920
Symbolics = "5, 6"
@@ -28,6 +29,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2829
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2930
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3031
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
32+
GeometricMachineLearning = "194d25b2-d3f5-49f0-af24-c124f4aa80cc"
3133

3234
[targets]
33-
test = ["Test", "ForwardDiff", "Random", "Documenter", "Latexify", "SafeTestsets", "Zygote"]
35+
test = ["Test", "ForwardDiff", "Random", "Documenter", "Latexify", "SafeTestsets", "Zygote", "GeometricMachineLearning"]

src/derivatives/pullback.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ nn = SymbolicNeuralNetwork(c)
1616
loss = FeedForwardLoss()
1717
pb = SymbolicPullback(nn, loss)
1818
ps = initialparameters(c) |> NeuralNetworkParameters
19-
pv_values = pb(ps, nn.model, (rand(2), rand(1)))[2](1) |> typeof
19+
pb_values = pb(ps, nn.model, (rand(2), rand(1)))[2](1) |> typeof
2020
2121
# output
2222
@@ -47,19 +47,20 @@ import Random
4747
Random.seed!(123)
4848
4949
c = Chain(Dense(2, 1, tanh))
50-
nn = SymbolicNeuralNetwork(c)
50+
nn = NeuralNetwork(c)
51+
snn = SymbolicNeuralNetwork(nn)
5152
loss = FeedForwardLoss()
52-
pb = SymbolicPullback(nn, loss)
53-
ps = initialparameters(c) |> NeuralNetworkParameters
53+
pb = SymbolicPullback(snn, loss)
5454
input_output = (rand(2), rand(1))
55-
loss_and_pullback = pb(ps, nn.model, input_output)
56-
pv_values = loss_and_pullback[2](1)
55+
loss_and_pullback = pb(nn.params, nn.model, input_output)
56+
# note that we apply the second argument to another input `1`
57+
pb_values = loss_and_pullback[2](1)
5758
5859
@variables soutput[1:SymbolicNeuralNetworks.output_dimension(nn.model)]
5960
symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, nn.params, nn.input, soutput), nn)
60-
pv_values2 = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput)(input_output[1], input_output[2], ps)
61+
pb_values2 = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput)(input_output[1], input_output[2], ps)
6162
62-
pv_values == (pv_values2 |> SymbolicNeuralNetworks._get_params |> SymbolicNeuralNetworks._get_contents)
63+
pb_values == (pb_values2 |> SymbolicNeuralNetworks._get_params |> SymbolicNeuralNetworks._get_contents)
6364
6465
# output
6566
@@ -106,6 +107,7 @@ Return the `NamedTuple` that's equivalent to the `NeuralNetworkParameters`.
106107
"""
107108
_get_params(nt::NamedTuple) = nt
108109
_get_params(ps::NeuralNetworkParameters) = ps.params
110+
_get_params(ps::NamedTuple{(:params,), Tuple{NT}}) where {NT<:NamedTuple} = ps.params
109111
_get_params(ps::AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}) = [_get_params(nt) for nt in ps]
110112

111113
"""
@@ -134,6 +136,7 @@ function __get_contents(nt::AbstractArray{<:NamedTuple})
134136
nt
135137
end
136138
_get_contents(nt::AbstractArray{<:NamedTuple}) = __get_contents(nt)
139+
_get_contents(nt::Tuple{<:NamedTuple}) = nt[1]
137140

138141
# (_pullback::SymbolicPullback)(ps, model, input_nt::QPTOAT)::Tuple = Zygote.pullback(ps -> _pullback.loss(model, ps, input_nt), ps)
139142
function (_pullback::SymbolicPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT})::Tuple

test/derivatives/pullback.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using SymbolicNeuralNetworks
2+
using SymbolicNeuralNetworks: _get_params, _get_contents
3+
using AbstractNeuralNetworks
4+
using Symbolics
5+
using GeometricMachineLearning: ZygotePullback
6+
using Test
7+
import Random
8+
Random.seed!(123)
9+
10+
compare_values(arr1::Array, arr2::Array) = @test arr1 arr2
11+
function compare_values(nt1::NamedTuple, nt2::NamedTuple)
12+
@assert keys(nt1) == keys(nt2)
13+
NamedTuple{keys(nt1)}((compare_values(arr1, arr2) for (arr1, arr2) in zip(values(nt1), values(nt2))))
14+
end
15+
16+
function compare_symbolic_pullback_to_zygote_pullback(input_dim::Integer, output_dim::Integer, second_dim::Integer=1)
17+
c = Chain(Dense(input_dim, output_dim, tanh))
18+
nn = NeuralNetwork(c)
19+
snn = SymbolicNeuralNetwork(nn)
20+
loss = FeedForwardLoss()
21+
spb = SymbolicPullback(snn, loss)
22+
input_output = (rand(input_dim, second_dim), rand(output_dim, second_dim))
23+
loss_and_pullback = spb(nn.params, nn.model, input_output)
24+
# note that we apply the second argument to another input `1`
25+
pb_values = loss_and_pullback[2](1)
26+
27+
zpb = ZygotePullback(loss)
28+
loss_and_pullback_zygote = zpb(nn.params, nn.model, input_output)
29+
pb_values_zygote = loss_and_pullback_zygote[2](1) |> _get_contents |> _get_params
30+
31+
compare_values(pb_values, pb_values_zygote)
32+
end
33+
34+
for input_dim (2, 3)
35+
for output_dim (1, 2)
36+
compare_symbolic_pullback_to_zygote_pullback(input_dim, output_dim)
37+
end
38+
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ using SafeTestsets
77
@safetestset "Symbolic Params " begin include("symbolic_neuralnet/symbolize.jl") end
88
@safetestset "Tests associated with 'build_function.jl' " begin include("build_function/build_function.jl") end
99
@safetestset "Tests associated with 'build_function_double_input.jl' " begin include("build_function/build_function_double_input.jl") end
10-
@safetestset "Tests associated with 'build_function_array.jl " begin include("build_function/build_function_arrays.jl") end
10+
@safetestset "Tests associated with 'build_function_array.jl " begin include("build_function/build_function_arrays.jl") end
11+
@safetestset "Compare Zygote Pullback with Symbolic Pullback " begin include("derivatives/pullback.jl") end

0 commit comments

Comments
 (0)