@@ -57,10 +57,10 @@ loss_and_pullback = pb(nn.params, nn.model, input_output)
57
57
pb_values = loss_and_pullback[2](1)
58
58
59
59
@variables soutput[1:SymbolicNeuralNetworks.output_dimension(nn.model)]
60
- symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, nn .params, nn .input, soutput), nn )
61
- pb_values2 = build_nn_function(symbolic_pullbacks, nn .params, nn .input, soutput)(input_output[1], input_output[2], ps )
60
+ symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, snn .params, snn .input, soutput), snn )
61
+ pb_values2 = build_nn_function(symbolic_pullbacks, snn .params, snn .input, soutput)(input_output[1], input_output[2], nn.params )
62
62
63
- pb_values == (pb_values2 |> SymbolicNeuralNetworks._get_params |> SymbolicNeuralNetworks._get_contents )
63
+ pb_values == (pb_values2 |> SymbolicNeuralNetworks._get_contents |> SymbolicNeuralNetworks._get_params )
64
64
65
65
# output
66
66
@@ -128,15 +128,15 @@ _get_contents([(a = "element_contained_in_vector", )])
128
128
```
129
129
"""
130
130
_get_contents (nt:: NamedTuple ) = nt
131
- function _get_contents (nt:: AbstractVector{<:NamedTuple} )
131
+ function _get_contents (nt:: AbstractVector{<:Union{ NamedTuple, NeuralNetworkParameters} } )
132
132
length (nt) == 1 ? nt[1 ] : __get_contents (nt)
133
133
end
134
- function __get_contents (nt:: AbstractArray{<:NamedTuple} )
134
+ function __get_contents (nt:: AbstractArray{<:Union{ NamedTuple, NeuralNetworkParameters} } )
135
135
@warn " The pullback returns an array expression. There is probably a bug in the code somewhere."
136
136
nt
137
137
end
138
- _get_contents (nt:: AbstractArray{<:NamedTuple} ) = __get_contents (nt)
139
- _get_contents (nt:: Tuple{<:NamedTuple} ) = nt[1 ]
138
+ _get_contents (nt:: AbstractArray{<:Union{ NamedTuple, NeuralNetworkParameters} } ) = __get_contents (nt)
139
+ _get_contents (nt:: Tuple{<:Union{ NamedTuple, NeuralNetworkParameters} } ) = nt[1 ]
140
140
141
141
# (_pullback::SymbolicPullback)(ps, model, input_nt::QPTOAT)::Tuple = Zygote.pullback(ps -> _pullback.loss(model, ps, input_nt), ps)
142
142
function (_pullback:: SymbolicPullback )(ps, model, input_nt_output_nt:: Tuple{<:QPTOAT, <:QPTOAT} ):: Tuple
0 commit comments