Description
The pyadjoint tape holds references to backend variables (for example Firedrake `Function` objects). Consequently, any memory allocated for forward variables during the forward calculation stays referenced by the tape, even after the user code has dropped its own references. This can prevent checkpointing from reducing memory usage.
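The mechanism can be illustrated without Firedrake. The following is a minimal pure-Python sketch, not pyadjoint code: `Variable`, `ToyTape`, `run` and `count_alive` are hypothetical stand-ins, and `weakref.ref` is used only as a probe to check whether each step's output is still alive after the loop deletes its own reference.

```python
import gc
import weakref


class Variable:
    """Hypothetical stand-in for a backend variable that owns a data buffer."""

    def __init__(self, size=1024):
        self.data = bytearray(size)


class ToyTape:
    """Hypothetical tape that records every step's output by strong reference."""

    def __init__(self):
        self.blocks = []

    def record(self, output):
        # The tape's list now holds a strong reference to `output`.
        self.blocks.append(output)


def run(n_steps, tape=None):
    """Run a toy forward loop; return weak-reference probes to each output."""
    probes = []
    for _ in range(n_steps):
        u = Variable()
        if tape is not None:
            tape.record(u)
        probes.append(weakref.ref(u))
        del u  # the user code drops its own reference
    gc.collect()
    return probes


def count_alive(probes):
    """Number of outputs that are still reachable from somewhere."""
    return sum(probe() is not None for probe in probes)


tape = ToyTape()
print(count_alive(run(100, tape)))  # 100: `tape` keeps every output alive
print(count_alive(run(100)))        # 0: nothing else references the outputs
```

The `del u` in the loop has no effect on memory while the tape is recording: the buffers are released only when the tape itself is discarded.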
Example
```python
from firedrake import *
from firedrake.adjoint import *
from checkpoint_schedules import MultistageCheckpointSchedule

N = 100

mesh = UnitIntervalMesh(1)
space = FunctionSpace(mesh, "Lagrange", 1)

# Checkpoint with at most 3 snapshots in RAM and none on disk.
tape = get_working_tape()
tape.enable_checkpointing(MultistageCheckpointSchedule(N, 3, 0))

u = Function(space, name="u").interpolate(Constant(2.0))

continue_annotation()
for _ in tape.timestepper(iter(range(N))):
    # Each step allocates a new Function; the user code keeps no
    # reference to the previous one.
    u_ = Function(space, name="u")
    assemble(Interpolate(u + u, space), tensor=u_)
    u = u_
    del u_
pause_annotation()
del u

# Collect the distinct Functions still referenced by the tape's blocks,
# identified by their unique count().
deps = set()
outputs = set()
for block in tape._blocks:
    for dep in block.get_dependencies():
        if isinstance(dep.output, Function):
            deps.add(dep.output.count())
    for dep in block.get_outputs():
        if isinstance(dep.output, Function):
            outputs.add(dep.output.count())

print(f"{len(deps)=}")
print(f"{len(outputs)=}")
```
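This leads to the output:

```
len(deps)=100
len(outputs)=100
```

That is, 100 distinct Functions remain referenced by the tape, even though the loop deletes its own references and the checkpoint schedule keeps at most 3 snapshots in RAM.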
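For contrast, the memory behaviour that checkpointing is meant to allow can be sketched in a toy pure-Python setting. This is a hypothetical design for illustration only, not pyadjoint's API and not a proposed fix: `Variable`, `WeakToyTape` and `forward` are made-up names, the tape holds only weak references to forward outputs, and it owns explicit copies of the state at checkpointed steps (every 25 steps here, an arbitrary choice). The forward loop mirrors the example's `u + u` update.

```python
import weakref


class Variable:
    """Hypothetical stand-in for a backend variable holding a value and a buffer."""

    def __init__(self, value, size=1024):
        self.value = value
        self.data = bytearray(size)


class WeakToyTape:
    """Hypothetical tape: weak references to outputs, explicit checkpoint copies."""

    def __init__(self, checkpoint_every):
        self.checkpoint_every = checkpoint_every
        self.outputs = []      # weak references: these do not keep outputs alive
        self.checkpoints = {}  # step -> copied value, owned by the tape

    def record(self, step, output):
        self.outputs.append(weakref.ref(output))
        if step % self.checkpoint_every == 0:
            # Copy only the state needed to restart the forward run from here.
            self.checkpoints[step] = output.value


def forward(n_steps, tape):
    """Toy forward loop mirroring the example: u <- u + u at each step."""
    u = Variable(2.0)
    for step in range(n_steps):
        u_new = Variable(u.value + u.value)
        tape.record(step, u_new)
        u = u_new


tape = WeakToyTape(checkpoint_every=25)
forward(100, tape)
alive = sum(ref() is not None for ref in tape.outputs)
print(alive)                     # 0: no forward output is kept alive
print(sorted(tape.checkpoints))  # [0, 25, 50, 75]
```

Here the memory retained after the forward run is bounded by the number of checkpoints rather than the number of steps, which is the reduction the Firedrake example above does not achieve.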