Skip to content

Commit fe5f667

Browse files
committed
Add statistics
1 parent 8671a2d commit fe5f667

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

src/neural_de.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -254,15 +254,16 @@ function (n::NeuralDAE)(u_du::Tuple, p, st)
254254
nn_out = model(vcat(u, du), p)
255255
alg_out = n.constraints_model(u, p, t)
256256
iter_nn, iter_const = 0, 0
257-
map(n.differential_vars) do isdiff
257+
res = map(n.differential_vars) do isdiff
258258
if isdiff
259259
iter_nn += 1
260-
selectdim(nn_out, 1, iter_nn)
260+
nn_out[iter_nn]
261261
else
262262
iter_const += 1
263-
selectdim(alg_out, 1, iter_const)
263+
alg_out[iter_const]
264264
end
265265
end
266+
return res
266267
end
267268

268269
prob = DAEProblem{false}(f, du0, u0, n.tspan, p; n.differential_vars)

test/neural_dae.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using ComponentArrays, DiffEqFlux, Zygote, Optimization, OrdinaryDiffEq, Random
1+
using ComponentArrays,
2+
DiffEqFlux, Zygote, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Random
23

34
#A desired MWE for now, not a test yet.
45

@@ -27,6 +28,7 @@ tspan = (0.0, 10.0)
2728
ndae = NeuralDAE(dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan, DImplicitEuler();
2829
differential_vars = [true, true, false])
2930
ps, st = Lux.setup(Xoshiro(0), ndae)
31+
ps = ComponentArray(ps)
3032
truedu0 = similar(u₀)
3133

3234
ndae((u₀, truedu0), ps, st)
@@ -36,13 +38,11 @@ predict_n_dae(p) = first(ndae(u₀, p, st))
3638
function loss(p)
3739
pred = predict_n_dae(p)
3840
loss = sum(abs2, sol .- pred)
39-
loss, pred
41+
return loss, pred
4042
end
4143

42-
p = p .+ rand(3) .* p
43-
4444
optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
45-
optprob = Optimization.OptimizationProblem(optfunc, p)
45+
optprob = Optimization.OptimizationProblem(optfunc, ps)
4646
res = Optimization.solve(optprob, BFGS(; initial_stepnorm = 0.0001))
4747

4848
# Same stuff with Lux

test/neural_gde.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using DiffEqFlux, ComponentArrays, GeometricFlux, GraphSignals, OrdinaryDiffEq, Random,
2-
Test, OptimizationOptimisers, Optimization
2+
Test, OptimizationOptimisers, Optimization, Statistics
33
import Flux
44

55
# Fully Connected Graph

0 commit comments

Comments
 (0)