Skip to content

Commit 30a0045

Browse files
committed
Fixed docstring problem.
1 parent a260ae5 commit 30a0045

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/derivatives/pullback.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ loss_and_pullback = pb(nn.params, nn.model, input_output)
5757
pb_values = loss_and_pullback[2](1)
5858
5959
@variables soutput[1:SymbolicNeuralNetworks.output_dimension(nn.model)]
60-
symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, nn.params, nn.input, soutput), nn)
61-
pb_values2 = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput)(input_output[1], input_output[2], ps)
60+
symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, snn.params, snn.input, soutput), snn)
61+
pb_values2 = build_nn_function(symbolic_pullbacks, snn.params, snn.input, soutput)(input_output[1], input_output[2], nn.params)
6262
63-
pb_values == (pb_values2 |> SymbolicNeuralNetworks._get_params |> SymbolicNeuralNetworks._get_contents)
63+
pb_values == (pb_values2 |> SymbolicNeuralNetworks._get_contents |> SymbolicNeuralNetworks._get_params)
6464
6565
# output
6666
@@ -128,15 +128,15 @@ _get_contents([(a = "element_contained_in_vector", )])
128128
```
129129
"""
130130
_get_contents(nt::NamedTuple) = nt
131-
function _get_contents(nt::AbstractVector{<:NamedTuple})
131+
function _get_contents(nt::AbstractVector{<:Union{NamedTuple, NeuralNetworkParameters}})
132132
length(nt) == 1 ? nt[1] : __get_contents(nt)
133133
end
134-
function __get_contents(nt::AbstractArray{<:NamedTuple})
134+
function __get_contents(nt::AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}})
135135
@warn "The pullback returns an array expression. There is probably a bug in the code somewhere."
136136
nt
137137
end
138-
_get_contents(nt::AbstractArray{<:NamedTuple}) = __get_contents(nt)
139-
_get_contents(nt::Tuple{<:NamedTuple}) = nt[1]
138+
_get_contents(nt::AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}) = __get_contents(nt)
139+
_get_contents(nt::Tuple{<:Union{NamedTuple, NeuralNetworkParameters}}) = nt[1]
140140

141141
# (_pullback::SymbolicPullback)(ps, model, input_nt::QPTOAT)::Tuple = Zygote.pullback(ps -> _pullback.loss(model, ps, input_nt), ps)
142142
function (_pullback::SymbolicPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT})::Tuple

0 commit comments

Comments
 (0)