Description
This is the second part of an issue we encountered in g-adopt/g-adopt#237. Consider this reproducer. When it is run without a scheduler:
python minimal_clearing_cache.py none
we get:
using scheduler: NoneType
J1: 12.520709541433629
dJdm 1: 3205.301642607008
But when it is run with a scheduler:
python minimal_clearing_cache.py memory
we get:
using scheduler: SingleMemoryStorageSchedule
J1: 12.520709541433629
dJdm 1: 73765.11424588156
Cause
SingleMemoryStorageSchedule explicitly deletes every checkpoint that did not appear in the previous step’s adjoint_dependencies:
pyadjoint/pyadjoint/checkpointing.py
Lines 362 to 364 in a4f940a
if isinstance(self._schedule, SingleMemoryStorageSchedule):
    if step > 1 and var not in self.tape.timesteps[step - 1].adjoint_dependencies:
        var._checkpoint = None
If a variable is part of a long-range dependency (e.g. it is used only every third step), it can be missing from the immediately preceding dependency set even though it is still required later. Its checkpoint is therefore cleared, and the reverse pass reconstructs an incorrect value, giving a wrong gradient.
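To make the failure mode concrete, here is a minimal toy sketch in plain Python. It does not use pyadjoint: `run_reverse_pass` and the `deps` data are hypothetical, and the sketch assumes steps are visited from last to first with only the adjacent step's dependency set consulted before clearing.

```python
def run_reverse_pass(deps):
    """Visit steps from last to first, clearing with the adjacent-step rule.

    deps[i] holds the names of the variables step i needs (hypothetical data).
    Returns {step: variables that were needed but had already been cleared}.
    """
    stored = set().union(*deps)  # every checkpoint is available at the start
    missing = {}
    for step in reversed(range(len(deps))):
        lost = deps[step] - stored
        if lost:
            missing[step] = lost
        # Adjacent-step rule: keep a variable only if the next visited step needs it.
        next_needed = deps[step - 1] if step > 0 else set()
        stored -= deps[step] - next_needed
    return missing


# "m" is needed every third step; each "u*" variable is needed exactly once.
deps = [{"m", "u0"}, {"u1"}, {"u2"}, {"m", "u3"}, {"u4"}, {"u5"}, {"m", "u6"}]
print(run_reverse_pass(deps))  # {3: {'m'}, 0: {'m'}}
```

In this toy model, `m` is dropped after step 6 because step 5 does not list it, so it is no longer available when steps 3 and 0 need it.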
Given the comment saying “Handle the case for SingleMemoryStorageSchedule”, I assume there was a reason for this special case, but no explanation is given.
Ideas (!?)
- Remove the special-case clearing and rely on the revised dependency machinery to decide what can be freed, or
- Only clear if the variable is guaranteed not to re-appear later (e.g. by inspecting timesteps[step:].adjoint_dependencies)?
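The second idea could look roughly like the following sketch. It is plain Python over toy data, not a patch: `can_clear` and `pending_deps` are hypothetical names, and which steps count as still pending depends on the traversal direction, so the sketch leaves that choice to the caller.

```python
def can_clear(var, pending_deps):
    """Return True if no step still to be processed lists `var` as a dependency.

    pending_deps is an iterable of dependency sets, one per pending step
    (hypothetical stand-in for the adjoint_dependencies of those steps).
    """
    return all(var not in step_deps for step_deps in pending_deps)


# "m" is needed every third step; each "u*" variable is needed exactly once.
deps = [{"m", "u0"}, {"u1"}, {"u2"}, {"m", "u3"}, {"u4"}, {"u5"}, {"m", "u6"}]

# Suppose step 6 has just been processed and steps 0..5 are still pending.
pending = deps[:6]
print(can_clear("u6", pending))  # True: no pending step needs u6
print(can_clear("m", pending))   # False: steps 3 and 0 still need m
```

Scanning every pending step for each variable is linear in the number of steps. If that turns out to be too slow, one alternative would be to precompute, for each variable, the last step that needs it and compare against that index.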