Commit 52be247

Removed reduce_code function.
1 parent 42aa39a commit 52be247

5 files changed: +31 −64 lines changed


src/derivatives/gradient.jl

Lines changed: 4 additions & 26 deletions
@@ -15,25 +15,14 @@ Compute the symbolic output of `nn` and differentiate it with respect to the par
 
 # Examples
 
-```jldoctest
+```julia
 using SymbolicNeuralNetworks: SymbolicNeuralNetwork, Gradient, derivative
 using AbstractNeuralNetworks
 using Latexify: latexify
 
 c = Chain(Dense(2, 1, tanh))
 nn = SymbolicNeuralNetwork(c)
-(Gradient(nn) |> derivative)[1].L1.b |> latexify
-
-# output
-
-L"\begin{equation}
-\left[
-\begin{array}{c}
-1 - \tanh^{2}\left( \mathtt{b\_1}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) \\
-\end{array}
-\right]
-\end{equation}
-"
+(Gradient(nn) |> derivative)[1].L1.b
 ```
 
 # Implementation
@@ -87,7 +76,7 @@ This is used by [`Gradient`](@ref) and [`SymbolicPullback`](@ref).
 
 # Examples
 
-```jldoctest
+```julia
 using SymbolicNeuralNetworks: SymbolicNeuralNetwork, symbolic_pullback
 using AbstractNeuralNetworks
 using LinearAlgebra: norm
@@ -98,18 +87,7 @@ nn = SymbolicNeuralNetwork(c)
 output = c(nn.input, nn.params)
 spb = symbolic_pullback(output, nn)
 
-spb[1].L1.b |> latexify
-
-# output
-
-L"\begin{equation}
-\left[
-\begin{array}{c}
-1 - \tanh^{2}\left( \mathtt{b\_1}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) \\
-\end{array}
-\right]
-\end{equation}
-"
+spb[1].L1.b
 ```
 """
 function symbolic_pullback(soutput::EqT, nn::AbstractSymbolicNeuralNetwork)::Union{AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}, Union{NamedTuple, NeuralNetworkParameters}}
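The LaTeX output deleted from both examples above rendered the same quantity: the derivative of a single `tanh` neuron with respect to its bias. As a worked form of that removed expression (same symbols, with `sinput` the symbolic network input),

```latex
\frac{\partial}{\partial b_1} \tanh\left(b_1 + W_{1,1}\,\mathrm{sinput}_1 + W_{1,2}\,\mathrm{sinput}_2\right)
  = 1 - \tanh^{2}\left(b_1 + W_{1,1}\,\mathrm{sinput}_1 + W_{1,2}\,\mathrm{sinput}_2\right).
```

Since the blocks are now plain `julia` rather than `jldoctest`, this output is no longer verified by the documentation tests.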

src/utils/build_function.jl

Lines changed: 4 additions & 19 deletions
@@ -68,31 +68,16 @@ This first calls `Symbolics.build_function` with the keyword argument `expressio
 
 See the docstrings for those functions for details on how the code is modified.
 """
-function _build_nn_function(eq::EqT, params::NeuralNetworkParameters, sinput::Symbolics.Arr)
+function _build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr)
     sc_eq = Symbolics.scalarize(eq)
-    code = build_function(sc_eq, sinput, values(params)...; expression = Val{true}) |> _reduce_code
+    code = build_function(sc_eq, sinput, values(sparams)...; expression = Val{true}) |> _reduce
     rewritten_code = fix_map_reduce(modify_input_arguments(rewrite_arguments(fix_create_array(code))))
     parallelized_code = make_kernel(rewritten_code)
     @RuntimeGeneratedFunction(parallelized_code)
 end
 
-"""
-    _reduce_code(code)
-
-Reduce the code.
-
-For some reason `Symbolics.build_function` sometimes returns a tuple and sometimes it doesn't.
-
-This function takes care of this.
-If `build_function` returns a tuple `reduce_code` checks which of the expressions is in-place and then returns the other (not in-place) expression.
-"""
-function _reduce_code(code::Expr)
-    code
-end
-
-function _reduce_code(code::Tuple{Expr, Expr})
-    contains(string(code[1]), "ˍ₋out") ? code[2] : code[1]
-end
+_reduce(a) = a
+_reduce(a::Tuple) = a[1]
 
 """
     rewrite_arguments(s)
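The helper that replaces `_reduce_code` relies on `Symbolics.build_function` returning, for array-valued expressions, a tuple whose first element is the out-of-place expression and whose second is the in-place one; taking `a[1]` therefore keeps the same allocating variant that `_reduce_code` previously selected by searching for the in-place output variable `ˍ₋out`. A minimal sketch, outside the repository, illustrating that assumption:

```julia
using Symbolics

# Two-method helper as introduced by this commit.
_reduce(a) = a
_reduce(a::Tuple) = a[1]

@variables x[1:2]
sc_eq = Symbolics.scalarize(x .^ 2)                       # array-valued expression
code = build_function(sc_eq, x; expression = Val{true})   # tuple: (out-of-place, in-place)

oop = _reduce(code)   # keeps the out-of-place Expr
oop isa Expr          # true
```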

src/utils/build_function2.jl

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ See the docstrings for those functions for details on how the code is modified.
 """
 function _build_nn_function(eq::EqT, params::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr)
     sc_eq = Symbolics.scalarize(eq)
-    code = build_function(sc_eq, sinput, soutput, values(params)...; expression = Val{true}) |> _reduce_code
+    code = build_function(sc_eq, sinput, soutput, values(params)...; expression = Val{true}) |> _reduce
     rewritten_code = fix_map_reduce(modify_input_arguments2(rewrite_arguments2(fix_create_array(code))))
     parallelized_code = make_kernel2(rewritten_code)
     @RuntimeGeneratedFunction(parallelized_code)

test/neural_network_derivative.jl

Lines changed: 5 additions & 3 deletions
@@ -28,11 +28,13 @@ function test_jacobian(n::Integer, T = Float32)
 
     params = NeuralNetwork(c, T).params
     input = rand(T, n)
-    @test build_nn_function(g.output, nn)(input, params) ≈ c(input, params)
-    @test build_nn_function(derivative(g), nn)(input, params) ≈ ForwardDiff.jacobian(input -> c(input, params), input)
+    f = build_nn_function(g.output, nn)
+    ∇f = build_nn_function(derivative(g), nn)
+    @test f(input, params) ≈ c(input, params)
+    @test ∇f(input, params) ≈ ForwardDiff.jacobian(input -> c(input, params), input)
 end
 
-for n ∈ 1:10
+for n ∈ 10:1
     for T ∈ (Float32, Float64)
         test_jacobian(n, T)
     end
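The refactor names the two generated functions before calling them; the assertions themselves are unchanged: `f` must reproduce the network output and `∇f` its Jacobian with respect to the input,

```latex
f(x, \theta) = c(x; \theta),
\qquad
\nabla f(x, \theta) = \frac{\partial c(x; \theta)}{\partial x},
```

with the Jacobian compared against `ForwardDiff.jacobian`.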

test/symbolic_gradient.jl

Lines changed: 17 additions & 15 deletions
@@ -3,25 +3,32 @@ using SymbolicNeuralNetworks: symbolic_differentials, symbolic_derivative, _buil
 using LinearAlgebra: norm
 using Symbolics, AbstractNeuralNetworks
 using AbstractNeuralNetworks: NeuralNetworkParameters
+using Test
 
 import Zygote
 import Random
 Random.seed!(123)
 
-"""
-This test checks if we perform the parallelization in the correct way.
-"""
-function test_symbolic_gradient(input_dim::Integer = 3, output_dim::Integer = 1, hidden_dim::Integer = 2, T::DataType = Float64, second_dim::Integer = 3)
-    @assert second_dim > 1 "second_dim must be greater than 1!"
+function chain_input_output_and_params(input_dim::Integer, hidden_dim::Integer, output_dim::Integer, T::DataType)
     c = Chain(Dense(input_dim, hidden_dim, tanh), Dense(hidden_dim, output_dim, tanh))
     sparams = symbolicparameters(c)
     ps = NeuralNetwork(c, T).params
     @variables sinput[1:input_dim]
     sout = norm(c(sinput, sparams)) ^ 2
     sdparams = symbolic_differentials(sparams)
     _sgrad = symbolic_derivative(sout, sdparams)
+    c, ps, sinput, sparams, _sgrad
+end
+
+"""
+This test checks if we perform the parallelization in the correct way.
+"""
+function test_symbolic_gradient(input_dim::Integer = 3, output_dim::Integer = 1, hidden_dim::Integer = 2, T::DataType = Float64, second_dim::Integer = 3)
+    @assert second_dim > 1 "second_dim must be greater than 1!"
+    c, ps, sinput, sparams, _sgrad = chain_input_output_and_params(input_dim, hidden_dim, output_dim, T)
     input = rand(T, input_dim, second_dim)
     for k in 1:second_dim
+        # derivative for one vector
         zgrad = Zygote.gradient(ps -> (norm(c(input[:, k], ps)) ^ 2), ps)[1].params
         for key1 in keys(_sgrad)
             for key2 in keys(_sgrad[key1])
@@ -35,23 +42,18 @@ function test_symbolic_gradient(input_dim::Integer = 3, output_dim::Integer = 1,
 end
 
 """
-Also checks the parallelization, but for the full function.
+Also checks the parallelization, but by calling `build_nn_function` instead of `_build_nn_function`.
 """
 function test_symbolic_gradient2(input_dim::Integer = 3, output_dim::Integer = 1, hidden_dim::Integer = 2, T::DataType = Float64, second_dim::Integer = 1, third_dim::Integer = 1)
-    c = Chain(Dense(input_dim, hidden_dim, tanh), Dense(hidden_dim, output_dim, tanh))
-    sparams = symbolicparameters(c)
-    ps = NeuralNetwork(c, T).params
-    @variables sinput[1:input_dim]
-    sout = norm(c(sinput, sparams)) ^ 2
+    c, ps, sinput, sparams, _sgrad = chain_input_output_and_params(input_dim, hidden_dim, output_dim, T)
     input = rand(T, input_dim, second_dim, third_dim)
-    zgrad = Zygote.gradient(ps -> (norm(c(input, ps)) ^ 2), ps)[1].params
-    sdparams = symbolic_differentials(sparams)
-    _sgrad = symbolic_derivative(sout, sdparams)
     sgrad = build_nn_function(_sgrad, sparams, sinput)(input, ps)
+    # derivative for whole array
+    zgrad = Zygote.gradient(ps -> (norm(c(input, ps)) ^ 2), ps)[1].params
     for key1 in keys(sgrad) for key2 in keys(sgrad[key1]) @test zgrad[key1][key2] ≈ sgrad[key1][key2] end end
 end
 
-for second_dim in (2, 3, 4)
+for second_dim in (4, )
     test_symbolic_gradient(3, 1, 2, Float64, second_dim)
 end
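Both test functions differentiate the squared norm of the network output and compare the symbolic gradient with Zygote's. By the chain rule the quantity being checked is

```latex
\nabla_{\theta}\,\lVert c(x; \theta)\rVert^{2}
  = 2\,\left(\frac{\partial c(x; \theta)}{\partial \theta}\right)^{\!\top} c(x; \theta),
```

so every entry of the gradient built from `symbolic_differentials` and `symbolic_derivative` should match the corresponding Zygote entry, which is what the nested `@test` loops assert.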
