@@ -16,7 +16,7 @@ nn = SymbolicNeuralNetwork(c)
16
16
loss = FeedForwardLoss()
17
17
pb = SymbolicPullback(nn, loss)
18
18
ps = initialparameters(c) |> NeuralNetworkParameters
19
- pv_values = pb(ps, nn.model, (rand(2), rand(1)))[2](1) |> typeof
19
+ pb_values = pb(ps, nn.model, (rand(2), rand(1)))[2](1) |> typeof
20
20
21
21
# output
22
22
@@ -47,19 +47,20 @@ import Random
47
47
Random.seed!(123)
48
48
49
49
c = Chain(Dense(2, 1, tanh))
50
- nn = SymbolicNeuralNetwork(c)
50
+ nn = NeuralNetwork(c)
51
+ snn = SymbolicNeuralNetwork(nn)
51
52
loss = FeedForwardLoss()
52
- pb = SymbolicPullback(nn, loss)
53
- ps = initialparameters(c) |> NeuralNetworkParameters
53
+ pb = SymbolicPullback(snn, loss)
54
54
input_output = (rand(2), rand(1))
55
- loss_and_pullback = pb(ps, nn.model, input_output)
56
- pv_values = loss_and_pullback[2](1)
55
+ loss_and_pullback = pb(nn.params, nn.model, input_output)
56
+ # note that we apply the second argument to another input `1`
57
+ pb_values = loss_and_pullback[2](1)
57
58
58
59
@variables soutput[1:SymbolicNeuralNetworks.output_dimension(nn.model)]
59
60
symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, nn.params, nn.input, soutput), nn)
60
- pv_values2 = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput)(input_output[1], input_output[2], ps)
61
+ pb_values2 = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput)(input_output[1], input_output[2], ps)
61
62
62
- pv_values == (pv_values2 |> SymbolicNeuralNetworks._get_params |> SymbolicNeuralNetworks._get_contents)
63
+ pb_values == (pb_values2 |> SymbolicNeuralNetworks._get_params |> SymbolicNeuralNetworks._get_contents)
63
64
64
65
# output
65
66
@@ -106,6 +107,7 @@ Return the `NamedTuple` that's equivalent to the `NeuralNetworkParameters`.
106
107
"""
107
108
_get_params (nt:: NamedTuple ) = nt
108
109
_get_params (ps:: NeuralNetworkParameters ) = ps. params
110
+ _get_params (ps:: NamedTuple{(:params,), Tuple{NT}} ) where {NT<: NamedTuple } = ps. params
109
111
_get_params (ps:: AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}} ) = [_get_params (nt) for nt in ps]
110
112
111
113
"""
@@ -134,6 +136,7 @@ function __get_contents(nt::AbstractArray{<:NamedTuple})
134
136
nt
135
137
end
136
138
_get_contents (nt:: AbstractArray{<:NamedTuple} ) = __get_contents (nt)
139
+ _get_contents (nt:: Tuple{<:NamedTuple} ) = nt[1 ]
137
140
138
141
# (_pullback::SymbolicPullback)(ps, model, input_nt::QPTOAT)::Tuple = Zygote.pullback(ps -> _pullback.loss(model, ps, input_nt), ps)
139
142
function (_pullback:: SymbolicPullback )(ps, model, input_nt_output_nt:: Tuple{<:QPTOAT, <:QPTOAT} ):: Tuple
0 commit comments