Skip to content

Commit cb7f6fe

Browse files
committed
Merge branch 'main' into fix-reshape
2 parents 0ed585f + 5343bae commit cb7f6fe

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+530
-1000
lines changed

Project.toml

Lines changed: 10 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -1,7 +1,7 @@
11
name = "SymbolicNeuralNetworks"
22
uuid = "aed23131-dcd0-47ca-8090-d21e605652e3"
33
authors = ["Michael Kraus"]
4-
version = "0.2.0"
4+
version = "0.3.0"
55

66
[deps]
77
AbstractNeuralNetworks = "60874f82-5ada-4c70-bd1c-fa6be7711c8a"
@@ -10,25 +10,30 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
1111
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1212

13+
[weakdeps]
14+
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
15+
1316
[compat]
14-
AbstractNeuralNetworks = "0.3, 0.4"
17+
AbstractNeuralNetworks = "0.3, 0.4, 0.5, 0.6"
1518
Documenter = "1.8.0"
1619
ForwardDiff = "0.10.38"
20+
GeometricMachineLearning = "0.4"
1721
Latexify = "0.16.5"
1822
RuntimeGeneratedFunctions = "0.5"
19-
SafeTestsets = "0.1"
23+
SymbolicUtils = "<3.8.0"
2024
Symbolics = "5, 6"
2125
Zygote = "0.6.73"
22-
julia = "1.6"
26+
julia = "1.10"
2327

2428
[extras]
2529
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
2630
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
31+
GeometricMachineLearning = "194d25b2-d3f5-49f0-af24-c124f4aa80cc"
2732
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
2833
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2934
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3035
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3136
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3237

3338
[targets]
34-
test = ["Test", "ForwardDiff", "Random", "Documenter", "Latexify", "SafeTestsets", "Zygote"]
39+
test = ["Test", "ForwardDiff", "Random", "Documenter", "Latexify", "SafeTestsets", "Zygote", "GeometricMachineLearning"]

docs/src/double_derivative.md

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -59,7 +59,7 @@ We can easily differentiate a neural network twice by using [`SymbolicNeuralNetw
5959
```@example jacobian_gradient
6060
using AbstractNeuralNetworks
6161
using SymbolicNeuralNetworks
62-
using SymbolicNeuralNetworks: Jacobian, Gradient, derivative
62+
using SymbolicNeuralNetworks: Jacobian, Gradient, derivative, params
6363
using Latexify: latexify
6464
6565
c = Chain(Dense(2, 1, tanh))
@@ -92,7 +92,7 @@ x = \begin{pmatrix} 1 \\ 0 \end{pmatrix}, \quad W = \begin{bmatrix} 1 & 0 \\ 0 &
9292
```
9393

9494
```@example jacobian_gradient
95-
built_function = build_nn_function(derivative(g), nn.params, nn.input)
95+
built_function = build_nn_function(derivative(g), params(nn), nn.input)
9696
9797
x = [1., 0.]
9898
ps = NeuralNetworkParameters((L1 = (W = [1. 0.; 0. 1.], b = [0., 0.]), ))

docs/src/symbolic_neural_networks.md

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -6,7 +6,7 @@ We first call the symbolic neural network that only consists of one layer:
66

77
```@example snn
88
using SymbolicNeuralNetworks
9-
using AbstractNeuralNetworks: Chain, Dense, initialparameters
9+
using AbstractNeuralNetworks: Chain, Dense, params
1010
1111
input_dim = 2
1212
output_dim = 1
@@ -23,7 +23,7 @@ using Symbolics
2323
using Latexify: latexify
2424
2525
@variables sinput[1:input_dim]
26-
soutput = nn.model(sinput, nn.params)
26+
soutput = nn.model(sinput, params(nn))
2727
2828
soutput
2929
```
@@ -101,7 +101,7 @@ We now compare the neural network-approximated curve to the original one:
101101
fig = Figure()
102102
ax = Axis3(fig[1, 1])
103103
104-
surface!(x_vec, y_vec, [c([x, y], nn_cpu.params)[1] for x in x_vec, y in y_vec]; alpha = .8, colormap = :darkterrain, transparency = true)
104+
surface!(x_vec, y_vec, [c([x, y], params(nn_cpu))[1] for x in x_vec, y in y_vec]; alpha = .8, colormap = :darkterrain, transparency = true)
105105
fig
106106
```
107107

scripts/pullback_comparison.jl

Lines changed: 10 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -1,7 +1,7 @@
11
using SymbolicNeuralNetworks
22
using AbstractNeuralNetworks
33
using GeometricMachineLearning
4-
using AbstractNeuralNetworks: FeedForwardLoss
4+
using AbstractNeuralNetworks: FeedForwardLoss, params
55
using GeometricMachineLearning: ZygotePullback
66
import Random
77
Random.seed!(123)
@@ -19,16 +19,16 @@ output = rand(1, batch_size)
1919
# output sensitivities
2020
_do = 1.
2121

22-
# spb(nn_cpu.params, nn.model, (input, output))[2](_do)
23-
# zpb(nn_cpu.params, nn.model, (input, output))[2](_do)
24-
# @time spb_evaluated = spb(nn_cpu.params, nn.model, (input, output))[2](_do)
25-
# @time zpb_evaluated = zpb(nn_cpu.params, nn.model, (input, output))[2](_do)[1].params
22+
# spb(params(nn_cpu), nn.model, (input, output))[2](_do)
23+
# zpb(params(nn_cpu), nn.model, (input, output))[2](_do)
24+
# @time spb_evaluated = spb(params(nn_cpu), nn.model, (input, output))[2](_do)
25+
# @time zpb_evaluated = zpb(params(nn_cpu), nn.model, (input, output))[2](_do)[1].params
2626
# @assert values(spb_evaluated) .≈ values(zpb_evaluated)
2727

28-
function timenn(pb, params, model, input, output, _do = 1.)
29-
pb(params, model, (input, output))[2](_do)
30-
@time pb(params, model, (input, output))[2](_do)
28+
function timenn(pb, _params, model, input, output, _do = 1.)
29+
pb(_params, model, (input, output))[2](_do)
30+
@time pb(_params, model, (input, output))[2](_do)
3131
end
3232

33-
timenn(spb, nn_cpu.params, nn.model, input, output)
34-
timenn(zpb, nn_cpu.params, nn.model, input, output)
33+
timenn(spb, params(nn_cpu), nn.model, input, output)
34+
timenn(zpb, params(nn_cpu), nn.model, input, output)

src/SymbolicNeuralNetworks.jl

Lines changed: 9 additions & 25 deletions
Original file line number · Diff line number · Diff line change
@@ -15,44 +15,28 @@ module SymbolicNeuralNetworks
1515

1616
RuntimeGeneratedFunctions.init(@__MODULE__)
1717

18-
include("equation_types.jl")
18+
include("custom_definitions_and_extensions/equation_types.jl")
1919

20-
export symbolize
21-
include("utils/symbolize.jl")
20+
include("symbolic_neuralnet/symbolize.jl")
2221

2322
include("utils/create_array.jl")
2423

2524
export AbstractSymbolicNeuralNetwork
26-
export SymbolicNeuralNetwork, SymbolicModel
27-
export HamiltonianSymbolicNeuralNetwork, HNNLoss
28-
export architecture, model, params, equations, functions
25+
export SymbolicNeuralNetwork
2926

30-
# make symbolic parameters (`NeuralNetworkParameters`)
31-
export symbolicparameters
32-
include("layers/abstract.jl")
33-
include("layers/dense.jl")
34-
include("layers/linear.jl")
35-
include("chain.jl")
36-
37-
export evaluate_equations
38-
include("symbolic_neuralnet.jl")
39-
40-
export symbolic_hamiltonian
41-
include("hamiltonian.jl")
27+
include("symbolic_neuralnet/symbolic_neuralnet.jl")
4228

4329
export build_nn_function
44-
include("utils/build_function.jl")
45-
include("utils/build_function2.jl")
46-
include("utils/build_function_arrays.jl")
30+
include("build_function/build_function.jl")
31+
include("build_function/build_function_double_input.jl")
32+
include("build_function/build_function_arrays.jl")
4733

4834
export SymbolicPullback
49-
include("pullback.jl")
35+
include("derivatives/pullback.jl")
5036

5137
include("derivatives/derivative.jl")
5238
include("derivatives/jacobian.jl")
5339
include("derivatives/gradient.jl")
5440

55-
include("custom_equation.jl")
56-
57-
include("utils/latexraw.jl")
41+
include("custom_definitions_and_extensions/latexraw.jl")
5842
end

src/utils/build_function.jl renamed to src/build_function/build_function.jl

Lines changed: 18 additions & 32 deletions
Original file line number · Diff line number · Diff line change
@@ -19,7 +19,7 @@ The functions mentioned in the implementation section were adjusted ad-hoc to de
1919
Other problems may occur. In case you bump into one please [open an issue on github](https://github.com/JuliaGNI/SymbolicNeuralNetworks.jl/issues).
2020
"""
2121
function build_nn_function(eq::EqT, nn::AbstractSymbolicNeuralNetwork)
22-
build_nn_function(eq, nn.params, nn.input)
22+
build_nn_function(eq, params(nn), nn.input)
2323
end
2424

2525
function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr; reduce = hcat)
@@ -46,25 +46,26 @@ Build a function that can process a matrix. This is used as a starting point for
4646
# Examples
4747
4848
```jldoctest
49-
using SymbolicNeuralNetworks: _build_nn_function, symbolicparameters
50-
using Symbolics
51-
using AbstractNeuralNetworks
49+
using SymbolicNeuralNetworks: _build_nn_function, SymbolicNeuralNetwork
50+
using AbstractNeuralNetworks: params, Chain, Dense, NeuralNetwork
51+
import Random
52+
Random.seed!(123)
5253
5354
c = Chain(Dense(2, 1, tanh))
54-
params = symbolicparameters(c)
55-
@variables sinput[1:2]
56-
eq = c(sinput, params)
57-
built_function = _build_nn_function(eq, params, sinput)
58-
ps = initialparameters(c)
59-
input = rand(2, 2)
60-
61-
(built_function(input, ps, 1), built_function(input, ps, 2)) .≈ (c(input[:, 1], ps), c(input[:, 2], ps))
55+
nn = NeuralNetwork(c)
56+
snn = SymbolicNeuralNetwork(nn)
57+
eq = c(snn.input, params(snn))
58+
built_function = _build_nn_function(eq, params(snn), snn.input)
59+
built_function([1. 2.; 3. 4.], params(nn), 1)
6260
6361
# output
6462
65-
(true, true)
63+
1-element Vector{Float64}:
64+
0.9912108161055604
6665
```
6766
67+
Note that we have to supply an extra argument (index) to `_build_nn_function` that we do not have to supply to [`build_nn_function`](@ref).
68+
6869
# Implementation
6970
7071
This first calls `Symbolics.build_function` with the keyword argument `expression = Val{true}` and then modifies the generated code by calling:
@@ -75,31 +76,16 @@ This first calls `Symbolics.build_function` with the keyword argument `expressio
7576
7677
See the docstrings for those functions for details on how the code is modified.
7778
"""
78-
function _build_nn_function(eq::EqT, params::NeuralNetworkParameters, sinput::Symbolics.Arr)
79+
function _build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr)
7980
sc_eq = Symbolics.scalarize(eq)
80-
code = build_function(sc_eq, sinput, values(params)...; expression = Val{true}) |> _reduce_code
81+
code = build_function(sc_eq, sinput, values(sparams)...; expression = Val{true}) |> _reduce
8182
rewritten_code = fix_map_reduce(modify_input_arguments(rewrite_arguments(fix_create_array(code))))
8283
parallelized_code = make_kernel(rewritten_code)
8384
@RuntimeGeneratedFunction(parallelized_code)
8485
end
8586

86-
"""
87-
_reduce_code(code)
88-
89-
Reduce the code.
90-
91-
For some reason `Symbolics.build_function` sometimes returns a tuple and sometimes it doesn't.
92-
93-
This function takes care of this.
94-
If `build_function` returns a tuple `reduce_code` checks which of the expressions is in-place and then returns the other (not in-place) expression.
95-
"""
96-
function _reduce_code(code::Expr)
97-
code
98-
end
99-
100-
function _reduce_code(code::Tuple{Expr, Expr})
101-
contains(string(code[1]), "ˍ₋out") ? code[2] : code[1]
102-
end
87+
_reduce(a) = a
88+
_reduce(a::Tuple) = a[1]
10389

10490
"""
10591
rewrite_arguments(s)

src/utils/build_function_arrays.jl renamed to src/build_function/build_function_arrays.jl

Lines changed: 23 additions & 29 deletions
Original file line number · Diff line number · Diff line change
@@ -1,32 +1,29 @@
11
"""
22
build_nn_function(eqs::AbstractArray{<:NeuralNetworkParameters}, sparams, sinput...)
33
4-
Build an executable function based on `eqs` that potentially also has a symbolic output.
4+
Build an executable function based on an array of symbolic equations `eqs`.
55
66
# Examples
77
88
```jldoctest
99
using SymbolicNeuralNetworks: build_nn_function, SymbolicNeuralNetwork
10-
using AbstractNeuralNetworks: Chain, Dense, initialparameters, NeuralNetworkParameters
10+
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork, params
1111
import Random
1212
Random.seed!(123)
1313
1414
ch = Chain(Dense(2, 1, tanh))
15-
nn = SymbolicNeuralNetwork(ch)
16-
eqs = [(a = ch(nn.input, nn.params), b = ch(nn.input, nn.params).^2), (c = ch(nn.input, nn.params).^3, )]
17-
funcs = build_nn_function(eqs, nn.params, nn.input)
15+
nn = NeuralNetwork(ch)
16+
snn = SymbolicNeuralNetwork(nn)
17+
eqs = [(a = ch(snn.input, params(snn)), b = ch(snn.input, params(snn)).^2), (c = ch(snn.input, params(snn)).^3, )]
18+
funcs = build_nn_function(eqs, params(snn), snn.input)
1819
input = [1., 2.]
19-
ps = initialparameters(ch) |> NeuralNetworkParameters
20-
a = ch(input, ps)
21-
b = ch(input, ps).^2
22-
c = ch(input, ps).^3
23-
funcs_evaluated = funcs(input, ps)
24-
25-
(funcs_evaluated[1].a, funcs_evaluated[1].b, funcs_evaluated[2].c) .≈ (a, b, c)
20+
funcs_evaluated = funcs(input, params(nn))
2621
2722
# output
2823
29-
(true, true, true)
24+
2-element Vector{NamedTuple}:
25+
(a = [0.985678060655224], b = [0.9715612392570434])
26+
(c = [0.9576465981186686],)
3027
```
3128
"""
3229
function build_nn_function(eqs::AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...; reduce = hcat)
@@ -47,25 +44,21 @@ Return a function that takes an input, (optionally) an output and neural network
4744
4845
```jldoctest
4946
using SymbolicNeuralNetworks: build_nn_function, SymbolicNeuralNetwork
50-
using AbstractNeuralNetworks: Chain, Dense, initialparameters, NeuralNetworkParameters
47+
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork, params
5148
import Random
5249
Random.seed!(123)
5350
5451
c = Chain(Dense(2, 1, tanh))
55-
nn = SymbolicNeuralNetwork(c)
56-
eqs = (a = c(nn.input, nn.params), b = c(nn.input, nn.params).^2)
57-
funcs = build_nn_function(eqs, nn.params, nn.input)
52+
nn = NeuralNetwork(c)
53+
snn = SymbolicNeuralNetwork(nn)
54+
eqs = (a = c(snn.input, params(snn)), b = c(snn.input, params(snn)).^2)
55+
funcs = build_nn_function(eqs, params(snn), snn.input)
5856
input = [1., 2.]
59-
ps = initialparameters(c) |> NeuralNetworkParameters
60-
a = c(input, ps)
61-
b = c(input, ps).^2
62-
funcs_evaluated = funcs(input, ps)
63-
64-
(funcs_evaluated.a, funcs_evaluated.b) .≈ (a, b)
57+
funcs_evaluated = funcs(input, params(nn))
6558
6659
# output
6760
68-
(true, true)
61+
(a = [0.985678060655224], b = [0.9715612392570434])
6962
```
7063
7164
# Implementation
@@ -90,16 +83,17 @@ Return an executable function for each entry in `eqs`. This still has to be proc
9083
9184
```jldoctest
9285
using SymbolicNeuralNetworks: function_valued_parameters, SymbolicNeuralNetwork
93-
using AbstractNeuralNetworks: Chain, Dense, initialparameters, NeuralNetworkParameters
86+
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork, params
9487
import Random
9588
Random.seed!(123)
9689
9790
c = Chain(Dense(2, 1, tanh))
98-
nn = SymbolicNeuralNetwork(c)
99-
eqs = (a = c(nn.input, nn.params), b = c(nn.input, nn.params).^2)
100-
funcs = function_valued_parameters(eqs, nn.params, nn.input)
91+
nn = NeuralNetwork(c)
92+
snn = SymbolicNeuralNetwork(nn)
93+
eqs = (a = c(snn.input, params(snn)), b = c(snn.input, params(snn)).^2)
94+
funcs = function_valued_parameters(eqs, params(snn), snn.input)
10195
input = [1., 2.]
102-
ps = initialparameters(c) |> NeuralNetworkParameters
96+
ps = params(nn)
10397
a = c(input, ps)
10498
b = c(input, ps).^2
10599

src/utils/build_function2.jl renamed to src/build_function/build_function_double_input.jl

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -13,7 +13,7 @@ Also compare this to [`build_nn_function(::EqT, ::AbstractSymbolicNeuralNetwork)
1313
See the *extended help section* of [`build_nn_function(::EqT, ::AbstractSymbolicNeuralNetwork)`](@ref).
1414
"""
1515
function build_nn_function(eqs, nn::AbstractSymbolicNeuralNetwork, soutput)
16-
build_nn_function(eqs, nn.params, nn.input, soutput)
16+
build_nn_function(eqs, params(nn), nn.input, soutput)
1717
end
1818

1919
function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr; reduce = hcat)
@@ -59,7 +59,7 @@ See the docstrings for those functions for details on how the code is modified.
5959
"""
6060
function _build_nn_function(eq::EqT, params::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr)
6161
sc_eq = Symbolics.scalarize(eq)
62-
code = build_function(sc_eq, sinput, soutput, values(params)...; expression = Val{true}) |> _reduce_code
62+
code = build_function(sc_eq, sinput, soutput, values(params)...; expression = Val{true}) |> _reduce
6363
rewritten_code = fix_map_reduce(modify_input_arguments2(rewrite_arguments2(fix_create_array(code))))
6464
parallelized_code = make_kernel2(rewritten_code)
6565
@RuntimeGeneratedFunction(parallelized_code)

src/chain.jl

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments (0)