indiscriminately clearing checkpoints with SingleMemoryStorageSchedule corrupts the adjoints #211

@sghelichkhani

This is the second part of an issue we encountered in g-adopt/g-adopt#237. Consider the case of this reproducer. When run with no scheduler:

python minimal_clearing_cache.py none

we get:

using scheduler: NoneType
        J1: 12.520709541433629
        dJdm 1: 3205.301642607008

But when run with a scheduler:

python minimal_clearing_cache.py memory

we get:

using scheduler: SingleMemoryStorageSchedule
        J1: 12.520709541433629
        dJdm 1: 73765.11424588156

Cause

SingleMemoryStorageSchedule explicitly deletes every checkpoint that did not appear in the previous step’s adjoint_dependencies:

if isinstance(self._schedule, SingleMemoryStorageSchedule):
    if step > 1 and var not in self.tape.timesteps[step - 1].adjoint_dependencies:
        var._checkpoint = None

If a variable is part of a long-range dependency (e.g. it is used only every third step), it can be missing from the immediately preceding dependency set even though it is still required later. Its checkpoint is therefore cleared, and the reverse pass reconstructs an incorrect value, giving a wrong gradient.
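
To make the failure mode concrete, here is a minimal toy sketch. It is not pyadjoint code: the list of dependency sets, the variable names and both helper functions are made up for illustration, and only the condition inside `cleared_by_rule` mirrors the snippet quoted above.

```python
# Toy model only (not pyadjoint): one set of adjoint dependencies per timestep.
adjoint_dependencies = [
    {"m", "u0"},  # step 0
    {"u1"},       # step 1: "m" is not used here
    {"u2"},       # step 2
    {"m", "u3"},  # step 3: "m" is needed again
]


def cleared_by_rule(deps, step, var):
    """Mirror the quoted condition: only the previous step is inspected."""
    return step > 1 and var not in deps[step - 1]


def needed_from(deps, step, var):
    """True if any step from `step` onwards still lists `var`."""
    return any(var in d for d in deps[step:])


step = 2
m_is_cleared = cleared_by_rule(adjoint_dependencies, step, "m")
m_is_still_needed = needed_from(adjoint_dependencies, step, "m")
```

In this toy setup both flags evaluate to True at step 2: the previous-step check would drop the checkpoint of "m" although step 3 still lists it.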

Given the code comment “Handle the case for SingleMemoryStorageSchedule”, I would have thought there is a reason for this special case, but no explanation is given.

Ideas (!?)

  • Remove the special-case clearing and rely on the revised dependency machinery to decide what can be freed, or

  • Only clear if the variable is guaranteed not to re-appear later (e.g. by inspecting the adjoint_dependencies of every entry in timesteps[step:]).
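
The second idea could look roughly like the sketch below. This is an assumption-laden illustration rather than a proposed patch: `ToyTimeStep`, `ToyVar` and `clear_if_unused_later` are hypothetical stand-ins, and only the attribute names `adjoint_dependencies` and `_checkpoint` are taken from the snippet quoted above.

```python
from dataclasses import dataclass, field


@dataclass
class ToyTimeStep:
    """Hypothetical stand-in for a tape timestep."""
    adjoint_dependencies: set = field(default_factory=set)


@dataclass(eq=False)  # eq=False keeps identity hashing, so instances fit in sets
class ToyVar:
    """Hypothetical stand-in for a variable holding a checkpoint."""
    name: str
    _checkpoint: object = None


def clear_if_unused_later(timesteps, step, var):
    """Clear the checkpoint only if no timestep from `step` onwards lists `var`."""
    still_needed = any(var in ts.adjoint_dependencies for ts in timesteps[step:])
    if not still_needed:
        var._checkpoint = None
    return not still_needed


m, u = ToyVar("m", 1.0), ToyVar("u", 2.0)
timesteps = [
    ToyTimeStep({m, u}),
    ToyTimeStep({u}),
    ToyTimeStep(set()),
    ToyTimeStep({m}),  # long-range reuse of m
]
m_cleared = clear_if_unused_later(timesteps, 2, m)  # m is kept
u_cleared = clear_if_unused_later(timesteps, 2, u)  # u is cleared
```

Scanning `timesteps[step:]` for every variable costs time proportional to the number of remaining steps, so a real implementation would probably want to precompute each variable's last use instead.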
