Skip to content

Commit 94e9da1

Browse files
committed
Resolved merge conflict.
1 parent 49baf61 commit 94e9da1

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

src/build_function/build_function_arrays.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,17 @@ Return an executable function for each entry in `eqs`. This still has to be proc
8383
8484
```jldoctest
8585
using SymbolicNeuralNetworks: function_valued_parameters, SymbolicNeuralNetwork
86-
using AbstractNeuralNetworks: Chain, Dense,fffff, NeuralNetworkParameters, params
86+
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork, params
8787
import Random
8888
Random.seed!(123)
8989
9090
c = Chain(Dense(2, 1, tanh))
91-
nn = SymbolicNeuralNetwork(c)
92-
eqs = (a = c(nn.input, params(nn)), b = c(nn.input, params(nn)).^2)
93-
funcs = function_valued_parameters(eqs, params(nn), nn.input)
91+
nn = NeuralNetwork(c)
92+
snn = SymbolicNeuralNetwork(nn)
93+
eqs = (a = c(snn.input, params(snn)), b = c(snn.input, params(snn)).^2)
94+
funcs = function_valued_parameters(eqs, params(snn), snn.input)
9495
input = [1., 2.]
95-
ps = initialparameters(c) |> NeuralNetworkParameters
96+
ps = params(nn)
9697
a = c(input, ps)
9798
b = c(input, ps).^2
9899

src/derivatives/pullback.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,4 @@ _get_contents(nt::Tuple{<:Union{NamedTuple, NeuralNetworkParameters}}) = nt[1]
143143
# (_pullback::SymbolicPullback)(ps, model, input_nt::QPTOAT)::Tuple = Zygote.pullback(ps -> _pullback.loss(model, ps, input_nt), ps)
144144
function (_pullback::SymbolicPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT})::Tuple
145145
_pullback.loss(model, ps, input_nt_output_nt...), _pullback.fun(input_nt_output_nt..., ps)
146-
end
146+
end

0 commit comments

Comments
 (0)