Skip to content

Commit fe4dfa7

Browse files
committed
params now used as function instead of as a keyword.
1 parent 53643c5 commit fe4dfa7

16 files changed

+87
-81
lines changed

docs/src/double_derivative.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
```@example jacobian_gradient
1919
using AbstractNeuralNetworks
2020
using SymbolicNeuralNetworks
21-
using SymbolicNeuralNetworks: Jacobian, Gradient, derivative
21+
using SymbolicNeuralNetworks: Jacobian, Gradient, derivative, params
2222
using Latexify: latexify
2323
2424
c = Chain(Dense(2, 1, tanh; use_bias = false))
@@ -92,7 +92,7 @@ x = \begin{pmatrix} 1 \\ 0 \end{pmatrix}, \quad W = \begin{bmatrix} 1 & 0 \\ 0 &
9292
```
9393

9494
```@example jacobian_gradient
95-
built_function = build_nn_function(derivative(g), nn.params, nn.input)
95+
built_function = build_nn_function(derivative(g), params(nn), nn.input)
9696
9797
x = [1., 0.]
9898
ps = NeuralNetworkParameters((L1 = (W = [1. 0.; 0. 1.], b = [0., 0.]), ))

docs/src/symbolic_neural_networks.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ We first call the symbolic neural network that only consists of one layer:
66

77
```@example snn
88
using SymbolicNeuralNetworks
9-
using AbstractNeuralNetworks: Chain, Dense, initialparameters
9+
using AbstractNeuralNetworks: Chain, Dense, initialparameters, params
1010
1111
input_dim = 2
1212
output_dim = 1
@@ -23,7 +23,7 @@ using Symbolics
2323
using Latexify: latexify
2424
2525
@variables sinput[1:input_dim]
26-
soutput = nn.model(sinput, nn.params)
26+
soutput = nn.model(sinput, params(nn))
2727
2828
soutput
2929
```
@@ -101,7 +101,7 @@ We now compare the neural network-approximated curve to the original one:
101101
fig = Figure()
102102
ax = Axis3(fig[1, 1])
103103
104-
surface!(x_vec, y_vec, [c([x, y], nn_cpu.params)[1] for x in x_vec, y in y_vec]; alpha = .8, colormap = :darkterrain, transparency = true)
104+
surface!(x_vec, y_vec, [c([x, y], params(nn_cpu))[1] for x in x_vec, y in y_vec]; alpha = .8, colormap = :darkterrain, transparency = true)
105105
fig
106106
```
107107

scripts/pullback_comparison.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@ output = rand(1, batch_size)
1919
# output sensitivities
2020
_do = 1.
2121

22-
# spb(nn_cpu.params, nn.model, (input, output))[2](_do)
23-
# zpb(nn_cpu.params, nn.model, (input, output))[2](_do)
24-
# @time spb_evaluated = spb(nn_cpu.params, nn.model, (input, output))[2](_do)
25-
# @time zpb_evaluated = zpb(nn_cpu.params, nn.model, (input, output))[2](_do)[1].params
22+
# spb(params(nn_cpu), nn.model, (input, output))[2](_do)
23+
# zpb(params(nn_cpu), nn.model, (input, output))[2](_do)
24+
# @time spb_evaluated = spb(params(nn_cpu), nn.model, (input, output))[2](_do)
25+
# @time zpb_evaluated = zpb(params(nn_cpu), nn.model, (input, output))[2](_do)[1].params
2626
# @assert values(spb_evaluated) .≈ values(zpb_evaluated)
2727

2828
function timenn(pb, params, model, input, output, _do = 1.)
2929
pb(params, model, (input, output))[2](_do)
3030
@time pb(params, model, (input, output))[2](_do)
3131
end
3232

33-
timenn(spb, nn_cpu.params, nn.model, input, output)
34-
timenn(zpb, nn_cpu.params, nn.model, input, output)
33+
timenn(spb, params(nn_cpu), nn.model, input, output)
34+
timenn(zpb, params(nn_cpu), nn.model, input, output)

src/build_function/build_function.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ The functions mentioned in the implementation section were adjusted ad-hoc to de
1919
Other problems may occur. In case you bump into one please [open an issue on github](https://github.com/JuliaGNI/SymbolicNeuralNetworks.jl/issues).
2020
"""
2121
function build_nn_function(eq::EqT, nn::AbstractSymbolicNeuralNetwork)
22-
build_nn_function(eq, nn.params, nn.input)
22+
build_nn_function(eq, params(nn), nn.input)
2323
end
2424

2525
function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr)
@@ -40,16 +40,16 @@ Build a function that can process a matrix. This is used as a starting point for
4040
4141
```jldoctest
4242
using SymbolicNeuralNetworks: _build_nn_function, SymbolicNeuralNetwork
43-
using AbstractNeuralNetworks
43+
using AbstractNeuralNetworks: params, Chain, Dense, NeuralNetwork
4444
import Random
4545
Random.seed!(123)
4646
4747
c = Chain(Dense(2, 1, tanh))
4848
nn = NeuralNetwork(c)
4949
snn = SymbolicNeuralNetwork(nn)
50-
eq = c(snn.input, snn.params)
51-
built_function = _build_nn_function(eq, snn.params, snn.input)
52-
built_function([1. 2.; 3. 4.], nn.params, 1)
50+
eq = c(snn.input, params(snn))
51+
built_function = _build_nn_function(eq, params(snn), snn.input)
52+
built_function([1. 2.; 3. 4.], params(nn), 1)
5353
5454
# output
5555

src/build_function/build_function_arrays.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ Build an executable function based on an array of symbolic equations `eqs`.
77
88
```jldoctest
99
using SymbolicNeuralNetworks: build_nn_function, SymbolicNeuralNetwork
10-
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork
10+
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork, params
1111
import Random
1212
Random.seed!(123)
1313
1414
ch = Chain(Dense(2, 1, tanh))
1515
nn = NeuralNetwork(ch)
1616
snn = SymbolicNeuralNetwork(nn)
17-
eqs = [(a = ch(snn.input, snn.params), b = ch(snn.input, snn.params).^2), (c = ch(snn.input, snn.params).^3, )]
18-
funcs = build_nn_function(eqs, snn.params, snn.input)
17+
eqs = [(a = ch(snn.input, params(snn)), b = ch(snn.input, params(snn)).^2), (c = ch(snn.input, params(snn)).^3, )]
18+
funcs = build_nn_function(eqs, params(snn), snn.input)
1919
input = [1., 2.]
20-
funcs_evaluated = funcs(input, nn.params)
20+
funcs_evaluated = funcs(input, params(nn))
2121
2222
# output
2323
@@ -44,17 +44,17 @@ Return a function that takes an input, (optionally) an output and neural network
4444
4545
```jldoctest
4646
using SymbolicNeuralNetworks: build_nn_function, SymbolicNeuralNetwork
47-
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork
47+
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork, params
4848
import Random
4949
Random.seed!(123)
5050
5151
c = Chain(Dense(2, 1, tanh))
5252
nn = NeuralNetwork(c)
5353
snn = SymbolicNeuralNetwork(nn)
54-
eqs = (a = c(snn.input, snn.params), b = c(snn.input, snn.params).^2)
55-
funcs = build_nn_function(eqs, snn.params, snn.input)
54+
eqs = (a = c(snn.input, params(snn)), b = c(snn.input, params(snn)).^2)
55+
funcs = build_nn_function(eqs, params(snn), snn.input)
5656
input = [1., 2.]
57-
funcs_evaluated = funcs(input, nn.params)
57+
funcs_evaluated = funcs(input, params(nn))
5858
5959
# output
6060
@@ -83,14 +83,14 @@ Return an executable function for each entry in `eqs`. This still has to be proc
8383
8484
```jldoctest
8585
using SymbolicNeuralNetworks: function_valued_parameters, SymbolicNeuralNetwork
86-
using AbstractNeuralNetworks: Chain, Dense, initialparameters, NeuralNetworkParameters
86+
using AbstractNeuralNetworks: Chain, Dense, initialparameters, NeuralNetworkParameters, params
8787
import Random
8888
Random.seed!(123)
8989
9090
c = Chain(Dense(2, 1, tanh))
9191
nn = SymbolicNeuralNetwork(c)
92-
eqs = (a = c(nn.input, nn.params), b = c(nn.input, nn.params).^2)
93-
funcs = function_valued_parameters(eqs, nn.params, nn.input)
92+
eqs = (a = c(nn.input, params(nn)), b = c(nn.input, params(nn)).^2)
93+
funcs = function_valued_parameters(eqs, params(nn), nn.input)
9494
input = [1., 2.]
9595
ps = initialparameters(c) |> NeuralNetworkParameters
9696
a = c(input, ps)

src/build_function/build_function_double_input.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Also compare this to [`build_nn_function(::EqT, ::AbstractSymbolicNeuralNetwork)
1313
See the *extended help section* of [`build_nn_function(::EqT, ::AbstractSymbolicNeuralNetwork)`](@ref).
1414
"""
1515
function build_nn_function(eqs, nn::AbstractSymbolicNeuralNetwork, soutput)
16-
build_nn_function(eqs, nn.params, nn.input, soutput)
16+
build_nn_function(eqs, params(nn), nn.input, soutput)
1717
end
1818

1919
function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr)

src/derivatives/gradient.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ function Gradient(output::EqT, nn::SymbolicNeuralNetwork)
7575
end
7676

7777
function Gradient(nn::SymbolicNeuralNetwork)
78-
Gradient(nn.model(nn.input, nn.params), nn)
78+
Gradient(nn.model(nn.input, params(nn)), nn)
7979
end
8080

8181
@doc raw"""
@@ -90,12 +90,13 @@ This is used by [`Gradient`](@ref) and [`SymbolicPullback`](@ref).
9090
```jldoctest
9191
using SymbolicNeuralNetworks: SymbolicNeuralNetwork, symbolic_pullback
9292
using AbstractNeuralNetworks
93+
using AbstractNeuralNetworks: params
9394
using LinearAlgebra: norm
9495
using Latexify: latexify
9596
9697
c = Chain(Dense(2, 1, tanh))
9798
nn = SymbolicNeuralNetwork(c)
98-
output = c(nn.input, nn.params)
99+
output = c(nn.input, params(nn))
99100
spb = symbolic_pullback(output, nn)
100101
101102
spb[1].L1.b |> latexify
@@ -113,6 +114,6 @@ L"\begin{equation}
113114
```
114115
"""
115116
function symbolic_pullback(soutput::EqT, nn::AbstractSymbolicNeuralNetwork)::Union{AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}, Union{NamedTuple, NeuralNetworkParameters}}
116-
symbolic_diffs = symbolic_differentials(nn.params)
117+
symbolic_diffs = symbolic_differentials(params(nn))
117118
[symbolic_derivative(soutput_single, symbolic_diffs) for soutput_single in soutput]
118119
end

src/derivatives/jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ The output of `Jacobian` consists of a `NamedTuple` that has the following keys:
1818
If `output` is not supplied as an input argument than it is taken to be:
1919
2020
```julia
21-
soutput = nn.model(nn.input, nn.params)
21+
soutput = nn.model(nn.input, params(nn))
2222
```
2323
2424
# Implementation
@@ -82,7 +82,7 @@ derivative(j::Jacobian) = j.□
8282
function Jacobian(nn::AbstractSymbolicNeuralNetwork)
8383

8484
# Evaluation of the symbolic output
85-
soutput = nn.model(nn.input, nn.params)
85+
soutput = nn.model(nn.input, params(nn))
8686

8787
Jacobian(soutput, nn)
8888
end

src/derivatives/pullback.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ We note the following seeming peculiarity:
4141
4242
```jldoctest
4343
using SymbolicNeuralNetworks
44-
using AbstractNeuralNetworks
44+
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork, FeedForwardLoss, params
4545
using Symbolics
4646
import Random
4747
Random.seed!(123)
@@ -52,13 +52,13 @@ snn = SymbolicNeuralNetwork(nn)
5252
loss = FeedForwardLoss()
5353
pb = SymbolicPullback(snn, loss)
5454
input_output = (rand(2), rand(1))
55-
loss_and_pullback = pb(nn.params, nn.model, input_output)
55+
loss_and_pullback = pb(params(nn), nn.model, input_output)
5656
# note that we apply the second argument to another input `1`
5757
pb_values = loss_and_pullback[2](1)
5858
5959
@variables soutput[1:SymbolicNeuralNetworks.output_dimension(nn.model)]
60-
symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, snn.params, snn.input, soutput), snn)
61-
pb_values2 = build_nn_function(symbolic_pullbacks, snn.params, snn.input, soutput)(input_output[1], input_output[2], nn.params)
60+
symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, params(snn), snn.input, soutput), snn)
61+
pb_values2 = build_nn_function(symbolic_pullbacks, params(snn), snn.input, soutput)(input_output[1], input_output[2], params(nn))
6262
6363
pb_values == (pb_values2 |> SymbolicNeuralNetworks._get_contents |> SymbolicNeuralNetworks._get_params)
6464
@@ -88,9 +88,9 @@ end
8888

8989
function SymbolicPullback(nn::SymbolicNeuralNetwork, loss::NetworkLoss)
9090
@variables soutput[1:output_dimension(nn.model)]
91-
symbolic_loss = loss(nn.model, nn.params, nn.input, soutput)
91+
symbolic_loss = loss(nn.model, params(nn), nn.input, soutput)
9292
symbolic_pullbacks = symbolic_pullback(symbolic_loss, nn)
93-
pbs_executable = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput)
93+
pbs_executable = build_nn_function(symbolic_pullbacks, params(nn), nn.input, soutput)
9494
function pbs(input, output, params)
9595
pullback(::Union{Real, AbstractArray{<:Real}}) = _get_contents(_get_params(pbs_executable(input, output, params)))
9696
pullback
@@ -106,7 +106,7 @@ SymbolicPullback(nn::SymbolicNeuralNetwork) = SymbolicPullback(nn, AbstractNeura
106106
Return the `NamedTuple` that's equivalent to the `NeuralNetworkParameters`.
107107
"""
108108
_get_params(nt::NamedTuple) = nt
109-
_get_params(ps::NeuralNetworkParameters) = ps.params
109+
_get_params(ps::NeuralNetworkParameters) = params(ps)
110110
_get_params(ps::NamedTuple{(:params,), Tuple{NT}}) where {NT<:NamedTuple} = ps.params
111111
_get_params(ps::AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}) = [_get_params(nt) for nt in ps]
112112

src/symbolic_neuralnet/symbolic_neuralnet.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ end
3131

3232
function SymbolicNeuralNetwork(nn::NeuralNetwork)
3333
cache = Dict()
34-
sparams = symbolize!(cache, nn.params, :W)
34+
sparams = symbolize!(cache, params(nn), :W)
3535
@variables sinput[1:input_dimension(nn.model)]
3636

3737
SymbolicNeuralNetwork(nn.architecture, nn.model, sparams, sinput)
@@ -54,6 +54,8 @@ function SymbolicNeuralNetwork(d::AbstractExplicitLayer)
5454
SymbolicNeuralNetwork(UnknownArchitecture(), d)
5555
end
5656

57+
params(snn::AbstractSymbolicNeuralNetwork) = snn.params
58+
5759
apply(snn::AbstractSymbolicNeuralNetwork, x, args...) = snn(x, args...)
5860

5961
input_dimension(::AbstractExplicitLayer{M}) where M = M
@@ -68,5 +70,5 @@ function Base.show(io::IO, snn::SymbolicNeuralNetwork)
6870
print(io, "\nModel = ")
6971
print(io, snn.model)
7072
print(io, "\nSymbolic Params = ")
71-
print(io, snn.params)
73+
print(io, params(snn))
7274
end

0 commit comments

Comments
 (0)