Wrong gradient for Euler timestepping #65

@vchuravy

function f(t, y, B)
  return B * y
end

function solve(y0, t_i, t_f, A, dt=0.2)
  t = t_i + dt
  y = copy(y0)
  while t < t_f
    y .+= f(t, y, A) .* dt
    t += dt
    t = min(t, t_f)
  end
  sum(y) # stand-in for the target functional
end

t_i = 0.0
t_f = 2.0

y0 = [[1.0 9.0]
      [1.0 9.0]]

A  = [[0 1.0]
      [-100.0 0]]

sol = solve(y0, t_i, t_f, A)
using Enzyme

dy0 = Enzyme.make_zero(y0)
dA = Enzyme.make_zero(A)
autodiff(Reverse, solve, Active, Duplicated(y0, dy0), Const(t_i), Const(t_f), Duplicated(A, dA))

using Checkpointing
function solve2(y0, t_i, t_f, A, dt=0.2)
  t = t_i + dt
  y = copy(y0)
  @ad_checkpoint Online_r2(2) y while t < t_f
    y .+= f(t, y, A) .* dt
    t += dt
    t = min(t, t_f)
  end
  sum(y) # stand-in for the target functional
end

sol = solve2(y0, t_i, t_f, A)

dy0_2 = Enzyme.make_zero(y0)
dA_2 = Enzyme.make_zero(A)
autodiff(Reverse, solve2, Active, Duplicated(y0, dy0_2), Const(t_i), Const(t_f), Duplicated(A, dA_2))

@sriharikrishna and I have found that dy0 and dA, computed with Enzyme directly, agree with diffrax, but the gradients computed through Checkpointing (dy0_2 and dA_2) differ from them.
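
For an independent check that needs neither Enzyme nor Checkpointing, the adjoint of this explicit Euler loop can be written by hand. The sketch below is only an illustration: `reference_gradients` is a hypothetical helper name, not part of either package, and it assumes the same loop structure as `solve` above (fixed `dt` on every step, `t` clamped to `t_f`). It stores every intermediate state, so it is meant for small test cases like this one.

```julia
# Hand-written reverse pass for y_{k+1} = y_k + dt * A * y_k with loss sum(y_N).
# Hypothetical helper for cross-checking; not part of Enzyme or Checkpointing.
function reference_gradients(y0, t_i, t_f, A, dt=0.2)
    # Forward pass: same stepping logic as `solve`, but keep every state.
    t = t_i + dt
    ys = [copy(y0)]
    while t < t_f
        y = ys[end]
        push!(ys, y .+ (A * y) .* dt)
        t = min(t + dt, t_f)
    end

    # Reverse pass: seed with d(sum(y))/dy = ones, then walk the steps backwards.
    ybar = ones(size(y0))
    Abar = zeros(size(A))
    for k in (length(ys) - 1):-1:1
        Abar .+= dt .* (ybar * ys[k]')      # contribution of step k to dL/dA
        ybar = ybar .+ dt .* (A' * ybar)    # propagate adjoint from y_{k+1} to y_k
    end
    return ybar, Abar
end

# Usage: compare against the autodiff results from both variants.
# dy0_ref, dA_ref = reference_gradients(y0, t_i, t_f, A)
# isapprox(dy0, dy0_ref), isapprox(dy0_2, dy0_ref)
# isapprox(dA, dA_ref),   isapprox(dA_2, dA_ref)
```

Comparing both `dy0`/`dA` and `dy0_2`/`dA_2` against these reference values would show which of the two paths deviates, without relying on an external library.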
