Skip to content

Commit 81c7135

Browse files
committed
Fixed problem with non-compiling docs (reduce was hcat by default.
1 parent 3e9346b commit 81c7135

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/build_function/build_function_double_input.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
build_nn_function(eqs, nn, soutput)
33
4-
Build an executable function that can also depend on an output. It is then called with:
4+
Build an executable function that can also depend on an output. The resulting `built_function` is then called with:
55
```julia
66
built_function(input, output, ps)
77
```
@@ -17,6 +17,7 @@ function build_nn_function(eqs, nn::AbstractSymbolicNeuralNetwork, soutput)
1717
end
1818

1919
function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr; reduce = hcat)
20+
@assert ( (reduce == hcat) || (reduce == +) ) "Keyword reduce either has to be + or hcat!"
2021
gen_fun = _build_nn_function(eq, sparams, sinput, soutput)
2122
gen_fun_returned(input, output, ps) = mapreduce(k -> gen_fun(input, output, ps, k), reduce, axes(input, 2))
2223
function gen_fun_returned(x::AT, y::AT, ps) where {AT <: Union{AbstractVector, Symbolics.Arr}}

src/derivatives/pullback.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function SymbolicPullback(nn::SymbolicNeuralNetwork, loss::NetworkLoss)
9292
@variables soutput[1:output_dimension(nn.model)]
9393
symbolic_loss = loss(nn.model, params(nn), nn.input, soutput)
9494
symbolic_pullbacks = symbolic_pullback(symbolic_loss, nn)
95-
pbs_executable = build_nn_function(symbolic_pullbacks, params(nn), nn.input, soutput)
95+
pbs_executable = build_nn_function(symbolic_pullbacks, params(nn), nn.input, soutput; reduce = +)
9696
function pbs(input, output, params)
9797
pullback(::Union{Real, AbstractArray{<:Real}}) = _get_contents(_get_params(pbs_executable(input, output, params)))
9898
pullback
@@ -146,4 +146,4 @@ _get_contents(nt::Tuple{<:Union{NamedTuple, NeuralNetworkParameters}}) = nt[1]
146146
# (_pullback::SymbolicPullback)(ps, model, input_nt::QPTOAT)::Tuple = Zygote.pullback(ps -> _pullback.loss(model, ps, input_nt), ps)
147147
function (_pullback::SymbolicPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT})::Tuple
148148
_pullback.loss(model, ps, input_nt_output_nt...), _pullback.fun(input_nt_output_nt..., ps)
149-
end
149+
end

0 commit comments

Comments
 (0)