-
Notifications
You must be signed in to change notification settings - Fork 4
Open
Description
"""
    f(t, y, B)

Right-hand side of the linear ODE `dy/dt = B * y`: return the product `B * y`.
The time argument `t` is accepted but not used.
"""
f(t, y, B) = B * y
"""
    solve(y0, t_i, t_f, A, dt=0.2)

Integrate `dy/dt = f(t, y, A)` from `t_i` to `t_f` with explicit Euler steps and
return `sum(y)` of the final state as a stand-in scalar target. The state starts
as a copy of `y0`, so `y0` itself is not modified.

Time starts at `t_i`. Full `dt` steps are taken while more than a small rounding
tolerance of the interval remains, and the last step is shortened so that the
integration ends at `t_f`. If `t_f <= t_i`, no step is taken and `sum(y0)` is
returned.

Throws `ArgumentError` if `dt` is not positive, because the loop could then never
reach `t_f`.
"""
function solve(y0, t_i, t_f, A, dt=0.2)
    dt > 0 || throw(ArgumentError("dt must be positive, got $dt"))
    # A remainder this small is rounding residue from the repeated `t += h`
    # (e.g. 0.2 summed ten times is 1.9999999999999998, not 2.0); taking an
    # extra step for it would be spurious.
    tol = sqrt(eps(typeof(float(dt)))) * dt
    t = t_i + zero(dt)          # start at t_i, promoted like `t_i + dt`
    y = copy(y0)
    while t_f - t > tol
        h = min(dt, t_f - t)    # a full step, or the shorter final one
        y .+= f(t, y, A) .* h
        t += h
    end
    return sum(y)               # stand-in for the target
end
# Problem data: integration interval [t_i, t_f], a 2×2 initial state and the
# system matrix, followed by a plain forward evaluation (no AD) of `solve`.
t_i, t_f = 0.0, 2.0
y0 = [1.0 9.0;
      1.0 9.0]
A = [   0.0 1.0;
     -100.0 0.0]
sol = solve(y0, t_i, t_f, A)
using Enzyme
# Reference gradients: reverse-mode Enzyme applied directly to `solve`.
# Enzyme accumulates the derivatives into the zero-initialised shadows `dy0`
# (for y0) and `dA` (for A); t_i and t_f are held constant.
dy0, dA = Enzyme.make_zero(y0), Enzyme.make_zero(A)
autodiff(
    Reverse, solve, Active,
    Duplicated(y0, dy0), Const(t_i), Const(t_f), Duplicated(A, dA)
)
using Checkpointing
"""
    solve2(y0, t_i, t_f, A, dt=0.2)

Checkpointed Euler stepping loop. Starting from a copy of `y0`, it repeatedly
applies `y .+= f(t, y, A) .* dt` while `t < t_f`, then returns `sum(y)` as a
stand-in scalar target.

The loop is wrapped in Checkpointing.jl's `@ad_checkpoint` macro with scheme
`Online_r2(2)` and `y` as the named state. This is presumably an online
checkpointing scheme with 2 checkpoint slots (TODO: confirm against the
Checkpointing.jl docs).

Step-count caveat:
- `t` starts at `t_i + dt` and the condition is checked before each step, so in
  exact arithmetic the step ending exactly at `t_f` would be skipped.
- Whether that step is taken depends on rounding of the repeated `t += dt`. With
  `t_i = 0.0`, `t_f = 2.0`, `dt = 0.2`, the running `t` reaches
  1.9999999999999998 < 2.0, which gives 10 steps.
- `t = min(t, t_f)` clamps `t`, but the step size is always the full `dt`.
"""
function solve2(y0, t_i, t_f, A, dt=0.2)
    t = t_i + dt
    y = copy(y0)
    # NOTE(review): `t` is reassigned inside the checkpointed loop, but only `y`
    # is named in the macro call. Confirm whether Checkpointing.jl saves and
    # restores `t` when it recomputes forward steps during the reverse sweep;
    # this may be relevant to the gradient mismatch. (`f` ignores `t` in this
    # file, so `t` only affects the loop condition.)
    @ad_checkpoint Online_r2(2) y while t < t_f
        y .+= f(t, y, A) .* dt
        t += dt
        t = min(t, t_f)
    end
    sum(y) # standin for target
end
# Same gradients as above, but obtained through the checkpointed `solve2`.
# These are expected to match `dy0` and `dA`.
sol = solve2(y0, t_i, t_f, A)
dy0_2, dA_2 = Enzyme.make_zero(y0), Enzyme.make_zero(A)
autodiff(
    Reverse, solve2, Active,
    Duplicated(y0, dy0_2), Const(t_i), Const(t_f), Duplicated(A, dA_2)
)
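@sriharikrishna and I have found `dy0` and `dA` (Enzyme applied directly to `solve`) to agree with diffrax. However, the gradients computed through Checkpointing (`dy0_2` and `dA_2`, from `solve2`) differ from those computed with Enzyme directly.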
Metadata
Metadata
Assignees
Labels
No labels