Skip to content

Commit 1e3a10e

Browse files
committed
Reduce has to be optional (depending on whether we compute derivatives of neural network parameters or other expressions.
1 parent 7d09283 commit 1e3a10e

File tree

4 files changed

+23
-14
lines changed

4 files changed

+23
-14
lines changed

src/pullback.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function SymbolicPullback(nn::SymbolicNeuralNetwork, loss::NetworkLoss)
9393
@variables soutput[1:output_dimension(nn.model)]
9494
symbolic_loss = loss(nn.model, nn.params, nn.input, soutput)
9595
symbolic_pullbacks = symbolic_pullback(symbolic_loss, nn)
96-
pbs_executable = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput)
96+
pbs_executable = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput; reduce = +)
9797
function pbs(input, output, params)
9898
pullback(::Union{Real, AbstractArray{<:Real}}) = _get_contents(_get_params(pbs_executable(input, output, params)))
9999
pullback

src/utils/build_function.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ function build_nn_function(eq::EqT, nn::AbstractSymbolicNeuralNetwork)
2222
build_nn_function(eq, nn.params, nn.input)
2323
end
2424

25-
function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr)
25+
function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr; reduce = hcat)
2626
gen_fun = _build_nn_function(eq, sparams, sinput)
27-
gen_fun_returned(x, ps) = mapreduce(k -> gen_fun(x, ps, k), hcat, axes(x, 2))
27+
gen_fun_returned(x, ps) = mapreduce(k -> gen_fun(x, ps, k), reduce, axes(x, 2))
2828
function gen_fun_returned(x::Union{AbstractVector, Symbolics.Arr}, ps)
2929
output_not_reshaped = gen_fun_returned(reshape(x, length(x), 1), ps)
3030
# for vectors we do not reshape, as the output may be a matrix

src/utils/build_function2.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ function build_nn_function(eqs, nn::AbstractSymbolicNeuralNetwork, soutput)
1616
build_nn_function(eqs, nn.params, nn.input, soutput)
1717
end
1818

19-
function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr)
19+
function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr; reduce = hcat)
2020
gen_fun = _build_nn_function(eq, sparams, sinput, soutput)
21-
gen_fun_returned(input, output, ps) = mapreduce(k -> gen_fun(input, output, ps, k), hcat, axes(input, 2))
21+
gen_fun_returned(input, output, ps) = mapreduce(k -> gen_fun(input, output, ps, k), reduce, axes(input, 2))
2222
function gen_fun_returned(x::AT, y::AT, ps) where {AT <: Union{AbstractVector, Symbolics.Arr}}
2323
output_not_reshaped = gen_fun_returned(reshape(x, length(x), 1), reshape(y, length(y), 1), ps)
2424
# for vectors we do not reshape, as the output may be a matrix
@@ -27,11 +27,20 @@ function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Sy
2727
# check this! (definitely not correct in all cases!)
2828
function gen_fun_returned(x::AT, y::AT, ps) where {AT <: AbstractArray{<:Number, 3}}
2929
output_not_reshaped = gen_fun_returned(reshape(x, size(x, 1), size(x, 2) * size(x, 3)), reshape(y, size(y, 1), size(y, 2) * size(y, 3)), ps)
30-
reshape(output_not_reshaped, size(output_not_reshaped, 1), size(x, 2), size(x, 3))
30+
# if arrays are added together then don't reshape!
31+
optional_reshape(output_not_reshaped, reduce, x)
3132
end
3233
gen_fun_returned
3334
end
3435

36+
function optional_reshape(output_not_reshaped::AbstractVecOrMat, ::typeof(+), ::AbstractArray{<:Number, 3})
37+
output_not_reshaped
38+
end
39+
40+
function optional_reshape(output_not_reshaped::AbstractVecOrMat, ::typeof(hcat), input::AbstractArray{<:Number, 3})
41+
reshape(output_not_reshaped, size(output_not_reshaped, 1), size(input, 2), size(input, 3))
42+
end
43+
3544
"""
3645
_build_nn_function(eq, params, sinput, soutput)
3746

src/utils/build_function_arrays.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ funcs_evaluated = funcs(input, ps)
2929
(true, true, true)
3030
```
3131
"""
32-
function build_nn_function(eqs::AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...)
33-
ps_semi = [function_valued_parameters(eq, sparams, sinput...) for eq in eqs]
32+
function build_nn_function(eqs::AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...; reduce = hcat)
33+
ps_semi = [function_valued_parameters(eq, sparams, sinput...; reduce = reduce) for eq in eqs]
3434

3535
_pbs_executable(ps_functions, params, input...) = apply_element_wise(ps_functions, params, input...)
3636
__pbs_executable(input, params) = _pbs_executable(ps_semi, params, input)
@@ -72,8 +72,8 @@ funcs_evaluated = funcs(input, ps)
7272
7373
Internally this is using [`function_valued_parameters`](@ref) and [`apply_element_wise`](@ref).
7474
"""
75-
function build_nn_function(eqs::Union{NamedTuple, NeuralNetworkParameters}, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...)
76-
ps = function_valued_parameters(eqs, sparams, sinput...)
75+
function build_nn_function(eqs::Union{NamedTuple, NeuralNetworkParameters}, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...; reduce = hcat)
76+
ps = function_valued_parameters(eqs, sparams, sinput...; reduce = reduce)
7777
_pbs_executable(ps::Union{NamedTuple, NeuralNetworkParameters}, params::NeuralNetworkParameters, input::AbstractArray...) = apply_element_wise(ps, params, input...)
7878
__pbs_executable(input::AbstractArray, params::NeuralNetworkParameters) = _pbs_executable(ps, params, input)
7979
# return this one if sinput & soutput are supplied
@@ -110,13 +110,13 @@ b = c(input, ps).^2
110110
(true, true)
111111
```
112112
"""
113-
function function_valued_parameters(eqs::NeuralNetworkParameters, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...)
114-
vals = Tuple(build_nn_function(eqs[key], sparams, sinput...) for key in keys(eqs))
113+
function function_valued_parameters(eqs::NeuralNetworkParameters, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...; reduce = hcat)
114+
vals = Tuple(build_nn_function(eqs[key], sparams, sinput...; reduce = reduce) for key in keys(eqs))
115115
NeuralNetworkParameters{keys(eqs)}(vals)
116116
end
117117

118-
function function_valued_parameters(eqs::NamedTuple, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...)
119-
vals = Tuple(build_nn_function(eqs[key], sparams, sinput...) for key in keys(eqs))
118+
function function_valued_parameters(eqs::NamedTuple, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr...; reduce = hcat)
119+
vals = Tuple(build_nn_function(eqs[key], sparams, sinput...; reduce = reduce) for key in keys(eqs))
120120
NamedTuple{keys(eqs)}(vals)
121121
end
122122

0 commit comments

Comments
 (0)