
Commit ff8c6ee

Merge pull request #20 from JuliaGNI/fix-reshape
Fix reshape
2 parents 5343bae + 81c7135

File tree: 9 files changed, +92 -119 lines


docs/src/hamiltonian_neural_network.md

Lines changed: 0 additions & 99 deletions
This file was deleted.

src/SymbolicNeuralNetworks.jl

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ module SymbolicNeuralNetworks
 
 include("symbolic_neuralnet/symbolize.jl")
 
+include("utils/create_array.jl")
+
 export AbstractSymbolicNeuralNetwork
 export SymbolicNeuralNetwork

src/build_function/build_function.jl

Lines changed: 11 additions & 4 deletions
@@ -22,12 +22,19 @@ function build_nn_function(eq::EqT, nn::AbstractSymbolicNeuralNetwork)
     build_nn_function(eq, params(nn), nn.input)
 end
 
-function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr)
+function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr; reduce = hcat)
     gen_fun = _build_nn_function(eq, sparams, sinput)
-    gen_fun_returned(x, ps) = mapreduce(k -> gen_fun(x, ps, k), hcat, axes(x, 2))
-    gen_fun_returned(x::Union{AbstractVector, Symbolics.Arr}, ps) = gen_fun_returned(reshape(x, length(x), 1), ps)
+    gen_fun_returned(x, ps) = mapreduce(k -> gen_fun(x, ps, k), reduce, axes(x, 2))
+    function gen_fun_returned(x::Union{AbstractVector, Symbolics.Arr}, ps)
+        output_not_reshaped = gen_fun_returned(reshape(x, length(x), 1), ps)
+        # for vectors we do not reshape, as the output may be a matrix
+        output_not_reshaped
+    end
     # check this! (definitely not correct in all cases!)
-    gen_fun_returned(x::AbstractArray{<:Number, 3}, ps) = reshape(gen_fun_returned(reshape(x, size(x, 1), size(x, 2) * size(x, 3)), ps), size(x, 1), size(x, 2), size(x, 3))
+    function gen_fun_returned(x::AbstractArray{<:Number, 3}, ps)
+        output_not_reshaped = gen_fun_returned(reshape(x, size(x, 1), size(x, 2) * size(x, 3)), ps)
+        reshape(output_not_reshaped, size(output_not_reshaped, 1), size(x, 2), size(x, 3))
+    end
     gen_fun_returned
 end
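The new `reduce` keyword decides how the per-column results of the generated function are combined. A minimal sketch of that pattern, with a made-up `dummy_fun` standing in for the generated `gen_fun` (none of these names are package API):

```julia
# Minimal sketch of the mapreduce-over-columns pattern above; dummy_fun is hypothetical.
dummy_fun(x, ps, k) = ps.W * x[:, k]   # acts on the k-th column only

apply_columns(x, ps; reduce = hcat) = mapreduce(k -> dummy_fun(x, ps, k), reduce, axes(x, 2))

ps = (W = [1.0 0.0; 0.0 2.0],)
x  = [1.0 3.0; 2.0 4.0]

apply_columns(x, ps)              # reduce = hcat: one output column per input column
apply_columns(x, ps; reduce = +)  # reduce = +: the per-column results are summed
```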

src/build_function/build_function_arrays.jl

Lines changed: 8 additions & 8 deletions
@@ -26,8 +26,8 @@ funcs_evaluated = funcs(input, params(nn))
 (c = [0.9576465981186686],)
 ```
 """
-function build_nn_function(eqs::AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...)
-    ps_semi = [function_valued_parameters(eq, sparams, sinput...) for eq in eqs]
+function build_nn_function(eqs::AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...; reduce = hcat)
+    ps_semi = [function_valued_parameters(eq, sparams, sinput...; reduce = reduce) for eq in eqs]
 
     _pbs_executable(ps_functions, params, input...) = apply_element_wise(ps_functions, params, input...)
     __pbs_executable(input, params) = _pbs_executable(ps_semi, params, input)
@@ -65,8 +65,8 @@ funcs_evaluated = funcs(input, params(nn))
 
 Internally this is using [`function_valued_parameters`](@ref) and [`apply_element_wise`](@ref).
 """
-function build_nn_function(eqs::Union{NamedTuple, NeuralNetworkParameters}, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...)
-    ps = function_valued_parameters(eqs, sparams, sinput...)
+function build_nn_function(eqs::Union{NamedTuple, NeuralNetworkParameters}, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...; reduce = hcat)
+    ps = function_valued_parameters(eqs, sparams, sinput...; reduce = reduce)
     _pbs_executable(ps::Union{NamedTuple, NeuralNetworkParameters}, params::NeuralNetworkParameters, input::AbstractArray...) = apply_element_wise(ps, params, input...)
     __pbs_executable(input::AbstractArray, params::NeuralNetworkParameters) = _pbs_executable(ps, params, input)
     # return this one if sinput & soutput are supplied
@@ -104,13 +104,13 @@ b = c(input, ps).^2
 (true, true)
 ```
 """
-function function_valued_parameters(eqs::NeuralNetworkParameters, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...)
-    vals = Tuple(build_nn_function(eqs[key], sparams, sinput...) for key in keys(eqs))
+function function_valued_parameters(eqs::NeuralNetworkParameters, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...; reduce = hcat)
+    vals = Tuple(build_nn_function(eqs[key], sparams, sinput...; reduce = reduce) for key in keys(eqs))
     NeuralNetworkParameters{keys(eqs)}(vals)
 end
 
-function function_valued_parameters(eqs::NamedTuple, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...)
-    vals = Tuple(build_nn_function(eqs[key], sparams, sinput...) for key in keys(eqs))
+function function_valued_parameters(eqs::NamedTuple, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...; reduce = hcat)
+    vals = Tuple(build_nn_function(eqs[key], sparams, sinput...; reduce = reduce) for key in keys(eqs))
     NamedTuple{keys(eqs)}(vals)
 end
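These changes only thread the new `reduce` keyword through to `build_nn_function` for every entry of the `NamedTuple` or `NeuralNetworkParameters`. A rough sketch of that per-key pattern, with purely illustrative names (`build_one`, `build_all`, `square`, `scale` are not package functions):

```julia
# Illustrative only: turn a NamedTuple of per-column functions into a NamedTuple of
# batched callables, all sharing the same reduce keyword (mirrors the structure above).
build_one(eq; reduce = hcat) = (x, ps) -> mapreduce(k -> eq(x[:, k], ps), reduce, axes(x, 2))

function build_all(eqs::NamedTuple; reduce = hcat)
    vals = Tuple(build_one(eqs[key]; reduce = reduce) for key in keys(eqs))
    NamedTuple{keys(eqs)}(vals)
end

fns = build_all((square = (xk, ps) -> xk .^ 2, scale = (xk, ps) -> ps.a .* xk))
fns.square(rand(2, 3), (a = 2.0,))   # 2×3 matrix: per-column results hcat-ed together
```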

src/build_function/build_function_double_input.jl

Lines changed: 23 additions & 5 deletions
@@ -1,7 +1,7 @@
 """
     build_nn_function(eqs, nn, soutput)
 
-Build an executable function that can also depend on an output. It is then called with:
+Build an executable function that can also depend on an output. The resulting `built_function` is then called with:
 ```julia
 built_function(input, output, ps)
 ```
@@ -16,14 +16,32 @@ function build_nn_function(eqs, nn::AbstractSymbolicNeuralNetwork, soutput)
     build_nn_function(eqs, params(nn), nn.input, soutput)
 end
 
-function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr)
+function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr; reduce = hcat)
+    @assert ( (reduce == hcat) || (reduce == +) ) "Keyword reduce either has to be + or hcat!"
     gen_fun = _build_nn_function(eq, sparams, sinput, soutput)
-    gen_fun_returned(input, output, ps) = mapreduce(k -> gen_fun(input, output, ps, k), +, axes(input, 2))
-    gen_fun_returned(input::AT, output::AT, ps) where {AT <: Union{AbstractVector, Symbolics.Arr}} = gen_fun_returned(reshape(input, length(input), 1), reshape(output, length(output), 1), ps)
-    gen_fun_returned(input::AT, output::AT, ps) where {T, AT <: AbstractArray{T, 3}} = gen_fun_returned(reshape(input, size(input, 1), size(input, 2) * size(input, 3)), reshape(output, size(output, 1), size(output, 2) * size(output, 3)), ps)
+    gen_fun_returned(input, output, ps) = mapreduce(k -> gen_fun(input, output, ps, k), reduce, axes(input, 2))
+    function gen_fun_returned(x::AT, y::AT, ps) where {AT <: Union{AbstractVector, Symbolics.Arr}}
+        output_not_reshaped = gen_fun_returned(reshape(x, length(x), 1), reshape(y, length(y), 1), ps)
+        # for vectors we do not reshape, as the output may be a matrix
+        output_not_reshaped
+    end
+    # check this! (definitely not correct in all cases!)
+    function gen_fun_returned(x::AT, y::AT, ps) where {AT <: AbstractArray{<:Number, 3}}
+        output_not_reshaped = gen_fun_returned(reshape(x, size(x, 1), size(x, 2) * size(x, 3)), reshape(y, size(y, 1), size(y, 2) * size(y, 3)), ps)
+        # if arrays are added together then don't reshape!
+        optional_reshape(output_not_reshaped, reduce, x)
+    end
     gen_fun_returned
 end
 
+function optional_reshape(output_not_reshaped::AbstractVecOrMat, ::typeof(+), ::AbstractArray{<:Number, 3})
+    output_not_reshaped
+end
+
+function optional_reshape(output_not_reshaped::AbstractVecOrMat, ::typeof(hcat), input::AbstractArray{<:Number, 3})
+    reshape(output_not_reshaped, size(output_not_reshaped, 1), size(input, 2), size(input, 3))
+end
+
 """
     _build_nn_function(eq, params, sinput, soutput)
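Why the 3-tensor result is reshaped only for `hcat`: with `reduce = +` the per-column contributions are summed, so the column axis disappears and there is nothing to fold back into the third dimension. A small standalone illustration (the sizes and the stand-in `col_result` are made up):

```julia
# Standalone illustration of the optional_reshape logic; col_result is a stand-in
# for the generated per-column function.
x = rand(2, 3, 4)                                        # 3-tensor input
xmat = reshape(x, size(x, 1), size(x, 2) * size(x, 3))   # flattened to 2×12
col_result(k) = 2 .* xmat[:, k]

summed  = mapreduce(col_result, +, axes(xmat, 2))        # 2-vector: no column axis left, keep as is
stacked = mapreduce(col_result, hcat, axes(xmat, 2))     # 2×12: fold back into 2×3×4 below
reshape(stacked, size(stacked, 1), size(x, 2), size(x, 3))
```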
src/derivatives/pullback.jl

Lines changed: 2 additions & 2 deletions
@@ -92,7 +92,7 @@ function SymbolicPullback(nn::SymbolicNeuralNetwork, loss::NetworkLoss)
     @variables soutput[1:output_dimension(nn.model)]
     symbolic_loss = loss(nn.model, params(nn), nn.input, soutput)
     symbolic_pullbacks = symbolic_pullback(symbolic_loss, nn)
-    pbs_executable = build_nn_function(symbolic_pullbacks, params(nn), nn.input, soutput)
+    pbs_executable = build_nn_function(symbolic_pullbacks, params(nn), nn.input, soutput; reduce = +)
     function pbs(input, output, params)
         pullback(::Union{Real, AbstractArray{<:Real}}) = _get_contents(_get_params(pbs_executable(input, output, params)))
         pullback
@@ -146,4 +146,4 @@ _get_contents(nt::Tuple{<:Union{NamedTuple, NeuralNetworkParameters}}) = nt[1]
 # (_pullback::SymbolicPullback)(ps, model, input_nt::QPTOAT)::Tuple = Zygote.pullback(ps -> _pullback.loss(model, ps, input_nt), ps)
 function (_pullback::SymbolicPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT})::Tuple
     _pullback.loss(model, ps, input_nt_output_nt...), _pullback.fun(input_nt_output_nt..., ps)
-end
+end
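A hedged rationale for `reduce = +` here: the pullback of a batch loss with respect to the parameters is the sum of the per-column pullbacks, so summing rather than concatenating the column results gives the gradient of the whole batch. A toy check with a quadratic loss (not the package's `NetworkLoss`):

```julia
# Toy check (not the package's loss): for L(W) = Σ_k ‖W * X[:, k] - Y[:, k]‖²,
# summing the per-column gradients equals the gradient of the full batch loss.
W = rand(3, 2); X = rand(2, 5); Y = rand(3, 5)

grad_col(k) = 2 * (W * X[:, k] - Y[:, k]) * X[:, k]'   # gradient contribution of column k
total_grad = mapreduce(grad_col, +, axes(X, 2))

total_grad ≈ 2 * (W * X - Y) * X'                      # true (up to floating point)
```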

src/utils/create_array.jl

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# TODO: this shouldn't be there (type piracy); remove once https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/679 has been merged!
+function Symbolics.SymbolicUtils.Code.create_array(::Type{<:Base.ReshapedArray{T, N, P}}, S, nd::Val, d::Val, elems...) where {T, N, P}
+    Symbolics.SymbolicUtils.Code.create_array(P, S, nd, d, elems...)
+end
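For context (hedged, based on how the reshapes above interact with Symbolics-generated code): reshaping a view, as done in `reshape_test.jl` below, produces a `Base.ReshapedArray` rather than a plain `Array`, and the method above forwards `create_array` to the parent array type `P` in that case. A quick way to see such a type arise:

```julia
# Reshaping a view typically yields a Base.ReshapedArray wrapping the SubArray,
# which is the case the create_array method above forwards to its parent type P.
A = rand(2, 5)
B = reshape((@view A[:, 1:2]), 2, 1, 2)
typeof(B)   # Base.ReshapedArray{Float64, 3, <:SubArray, ...} on current Julia versions
```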

test/reshape_test.jl

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+using SymbolicNeuralNetworks
+using AbstractNeuralNetworks
+using Symbolics
+using Test
+
+function set_up_network()
+    c = Chain(Dense(2, 3))
+    nn = SymbolicNeuralNetwork(c)
+    soutput = nn.model(nn.input, nn.params)
+    nn_cpu = NeuralNetwork(c)
+    nn, soutput, nn_cpu
+end
+
+function test_for_input()
+    nn, soutput, nn_cpu = set_up_network()
+    input = rand(2, 5)
+    input2 = reshape((@view input[:, 1:2]), 2, 1, 2)
+    built_function = build_nn_function(soutput, nn.params, nn.input)
+    outputs = built_function(input2, nn_cpu.params)
+    for i in 1:2
+        @test nn.model(input[:, i], nn_cpu.params) ≈ outputs[:, 1, i]
+    end
+end
+
+function test_for_input_and_output()
+    nn, soutput2, nn_cpu = set_up_network()
+    input = rand(2, 5)
+    output = rand(3, 5)
+    input2 = reshape((@view input[:, 1:2]), 2, 1, 2)
+    output2 = reshape((@view output[:, 1:2]), 3, 1, 2)
+    @variables soutput[1:3]
+    built_function = build_nn_function((soutput - soutput2).^2, nn.params, nn.input, soutput)
+    outputs = built_function(input2, output2, nn_cpu.params)
+    for i in 1:2
+        @test (nn.model(input[:, i], nn_cpu.params) - output[:, i]).^2 ≈ outputs[:, 1, i]
+    end
+end
+
+test_for_input()
+test_for_input_and_output()

test/runtests.jl

Lines changed: 2 additions & 1 deletion
@@ -1,10 +1,11 @@
 using SymbolicNeuralNetworks
 using SafeTestsets
 
+@safetestset "Check if reshape works in the correct way with the generated functions. " begin include("reshape_test.jl") end
 @safetestset "Symbolic gradient " begin include("derivatives/symbolic_gradient.jl") end
 @safetestset "Symbolic Neural network " begin include("derivatives/jacobian.jl") end
 @safetestset "Symbolic Params " begin include("symbolic_neuralnet/symbolize.jl") end
 @safetestset "Tests associated with 'build_function.jl' " begin include("build_function/build_function.jl") end
 @safetestset "Tests associated with 'build_function_double_input.jl' " begin include("build_function/build_function_double_input.jl") end
 @safetestset "Tests associated with 'build_function_array.jl " begin include("build_function/build_function_arrays.jl") end
-@safetestset "Compare Zygote Pullback with Symbolic Pullback " begin include("derivatives/pullback.jl") end
+@safetestset "Compare Zygote Pullback with Symbolic Pullback " begin include("derivatives/pullback.jl") end
