Skip to content

Commit efdc3d8

Browse files
committed
Merge branch 'main' into output-to-f
2 parents 8421be6 + ad63009 commit efdc3d8

File tree

7 files changed

+25
-41
lines changed

7 files changed

+25
-41
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ julia = "1.10"
2828
[extras]
2929
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3030
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
31+
GeometricMachineLearning = "194d25b2-d3f5-49f0-af24-c124f4aa80cc"
3132
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
3233
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3334
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3435
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3536
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
36-
GeometricMachineLearning = "194d25b2-d3f5-49f0-af24-c124f4aa80cc"
3737

3838
[targets]
3939
test = ["Test", "ForwardDiff", "Random", "Documenter", "Latexify", "SafeTestsets", "Zygote", "GeometricMachineLearning"]

src/build_function/build_function.jl

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,31 +69,16 @@ This first calls `Symbolics.build_function` with the keyword argument `expressio
6969
7070
See the docstrings for those functions for details on how the code is modified.
7171
"""
72-
function _build_nn_function(eq::EqT, params::NeuralNetworkParameters, sinput::Symbolics.Arr)
72+
function _build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr)
7373
sc_eq = Symbolics.scalarize(eq)
74-
code = build_function(sc_eq, sinput, values(params)...; expression = Val{true}) |> _reduce_code
74+
code = build_function(sc_eq, sinput, values(sparams)...; expression = Val{true}) |> _reduce
7575
rewritten_code = fix_map_reduce(modify_input_arguments(rewrite_arguments(fix_create_array(code))))
7676
parallelized_code = make_kernel(rewritten_code)
7777
@RuntimeGeneratedFunction(parallelized_code)
7878
end
7979

80-
"""
81-
_reduce_code(code)
82-
83-
Reduce the code.
84-
85-
For some reason `Symbolics.build_function` sometimes returns a tuple and sometimes it doesn't.
86-
87-
This function takes care of this.
88-
If `build_function` returns a tuple `reduce_code` checks which of the expressions is in-place and then returns the other (not in-place) expression.
89-
"""
90-
function _reduce_code(code::Expr)
91-
code
92-
end
93-
94-
function _reduce_code(code::Tuple{Expr, Expr})
95-
contains(string(code[1]), "ˍ₋out") ? code[2] : code[1]
96-
end
80+
_reduce(a) = a
81+
_reduce(a::Tuple) = a[1]
9782

9883
"""
9984
rewrite_arguments(s)

src/build_function/build_function_double_input.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ See the docstrings for those functions for details on how the code is modified.
4242
"""
4343
function _build_nn_function(eq::EqT, params::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr)
4444
sc_eq = Symbolics.scalarize(eq)
45-
code = build_function(sc_eq, sinput, soutput, values(params)...; expression = Val{true}) |> _reduce_code
45+
code = build_function(sc_eq, sinput, soutput, values(params)...; expression = Val{true}) |> _reduce
4646
rewritten_code = fix_map_reduce(modify_input_arguments2(rewrite_arguments2(fix_create_array(code))))
4747
parallelized_code = make_kernel2(rewritten_code)
4848
@RuntimeGeneratedFunction(parallelized_code)

src/derivatives/derivative.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Derivative
33
"""
4-
abstract type Derivative{ST, FT, SDT} end
4+
abstract type Derivative{OT, SDT, ST <: AbstractSymbolicNeuralNetwork} end
55

66
derivative(::DT) where {DT <: Derivative} = error("No method of function `derivative` defined for type $(DT).")
77

src/derivatives/gradient.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ nn = SymbolicNeuralNetwork(c)
3232
3333
Internally the constructors are using [`symbolic_pullback`](@ref).
3434
"""
35-
struct Gradient{ST, FT, SDT} <: Derivative{ST, FT, SDT}
36-
nn::ST
37-
f::FT
35+
struct Gradient{OT, SDT, ST} <: Derivative{OT, SDT, ST}
36+
output::OT
3837
∇::SDT
38+
nn::ST
3939
end
4040

4141
"""
@@ -65,7 +65,7 @@ derivative(g::Gradient) = g.∇
6565

6666
function Gradient(output::EqT, nn::SymbolicNeuralNetwork)
6767
typeof(output) <: AbstractArray ? nothing : (@warn "You should only use `Gradient` together with array expressions! Maybe you wanted to use `SymbolicPullback`.")
68-
Gradient(nn, output, symbolic_pullback(output, nn))
68+
Gradient(output, symbolic_pullback(output, nn), nn)
6969
end
7070

7171
function Gradient(nn::SymbolicNeuralNetwork)
@@ -86,14 +86,13 @@ using SymbolicNeuralNetworks: SymbolicNeuralNetwork, symbolic_pullback
8686
using AbstractNeuralNetworks
8787
using AbstractNeuralNetworks: params
8888
using LinearAlgebra: norm
89-
using Latexify: latexify
9089
9190
c = Chain(Dense(2, 1, tanh))
9291
nn = SymbolicNeuralNetwork(c)
9392
output = c(nn.input, params(nn))
9493
spb = symbolic_pullback(output, nn)
9594
96-
spb[1].L1.b |> latexify
95+
spb[1].L1.b
9796
```
9897
"""
9998
function symbolic_pullback(f::EqT, nn::AbstractSymbolicNeuralNetwork)::Union{AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}, Union{NamedTuple, NeuralNetworkParameters}}

src/derivatives/jacobian.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,28 +74,28 @@ jacobian1(input, ps) ≈ [analytic_jacobian(i, j) for j ∈ 1:output_dim, i ∈
7474
true
7575
```
7676
"""
77-
struct Jacobian{ST, FT, SDT} <: Derivative{ST, FT, SDT}
78-
nn::ST
79-
f::FT
77+
struct Jacobian{OT, SDT, ST} <: Derivative{OT, SDT, ST}
78+
output::OT
8079
□::SDT
80+
nn::ST
8181
end
8282

8383
derivative(j::Jacobian) = j.□
8484

85-
function Jacobian(nn::AbstractSymbolicNeuralNetwork)
86-
87-
# Evaluation of the symbolic output
88-
soutput = nn.model(nn.input, params(nn))
89-
90-
Jacobian(soutput, nn)
91-
end
92-
9385
function Jacobian(f::EqT, nn::AbstractSymbolicNeuralNetwork)
9486
# make differential
9587
Dx = symbolic_differentials(nn.input)
9688

9789
# Evaluation of gradient
9890
s∇f = hcat([expand_derivatives.(Symbolics.scalarize(dx(f))) for dx in Dx]...)
9991

100-
Jacobian(nn, f, s∇f)
92+
Jacobian(f, s∇f, nn)
93+
end
94+
95+
function Jacobian(nn::AbstractSymbolicNeuralNetwork)
96+
97+
# Evaluation of the symbolic output
98+
soutput = nn.model(nn.input, params(nn))
99+
100+
Jacobian(soutput, nn)
101101
end

test/derivatives/jacobian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function test_jacobian(n::Integer, T = Float32)
3232
@test build_nn_function(derivative(g), nn)(input, _params) ≈ ForwardDiff.jacobian(input -> c(input, _params), input)
3333
end
3434

35-
for n ∈ 1:10
35+
for n ∈ 10:1
3636
for T ∈ (Float32, Float64)
3737
test_jacobian(n, T)
3838
end

0 commit comments

Comments
 (0)