Checkpointing and referencing of variables #169

@jrmaddison

The pyadjoint tape holds references to backend variables. As a result, any memory allocated for forward variables during the forward calculation stays referenced by the tape and cannot be freed. This can prevent checkpointing from reducing memory usage.

Example

from firedrake import *
from firedrake.adjoint import *
from checkpoint_schedules import MultistageCheckpointSchedule

N = 100

mesh = UnitIntervalMesh(1)
space = FunctionSpace(mesh, "Lagrange", 1)

# Enable checkpointing: N steps, 3 checkpoints in memory, 0 on disk
tape = get_working_tape()
tape.enable_checkpointing(MultistageCheckpointSchedule(N, 3, 0))

u = Function(space, name="u").interpolate(Constant(2.0))
continue_annotation()
for _ in tape.timestepper(iter(range(N))):
    # Each step allocates a new Function and drops the user's
    # references to the previous one
    u_ = Function(space, name="u")
    assemble(Interpolate(u + u, space), tensor=u_)
    u = u_
    del u_
pause_annotation()
del u

# Count the distinct Functions still referenced by the tape,
# as block dependencies and as block outputs
deps = set()
outputs = set()
for block in tape._blocks:
    for dep in block.get_dependencies():
        if isinstance(dep.output, Function):
            deps.add(dep.output.count())
    for dep in block.get_outputs():
        if isinstance(dep.output, Function):
            outputs.add(dep.output.count())

print(f"{len(deps)=}")
print(f"{len(outputs)=}")

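This leads to the following output, i.e. the tape still references 100 distinct Function dependencies and 100 distinct Function outputs, even though the user code has deleted all of its own references: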

len(deps)=100
len(outputs)=100
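
The underlying mechanism can be illustrated without Firedrake. The following is only a toy sketch, not the pyadjoint implementation: the "tape" is a plain Python list, and `Field`, `run_forward` and `count_alive` are hypothetical names introduced here for illustration. It shows that objects recorded by a container holding strong references stay alive after the user code has dropped its own references, and are only released once the container lets go of them.

```python
import gc
import weakref


class Field:
    """Hypothetical stand-in for a forward variable (e.g. a Function)."""

    def __init__(self, value):
        self.value = value


def run_forward(n, tape):
    """Run n toy steps and return weak references to every Field created."""
    refs = []
    u = Field(2.0)
    refs.append(weakref.ref(u))
    for _ in range(n):
        u_new = Field(u.value + u.value)
        # The toy tape records a (dependency, output) pair using
        # ordinary strong references.
        tape.append((u, u_new))
        refs.append(weakref.ref(u_new))
        u = u_new
    # On return, the local names u and u_new go out of scope.
    return refs


def count_alive(refs):
    """Count how many of the weakly referenced objects still exist."""
    gc.collect()
    return sum(r() is not None for r in refs)


tape = []
refs = run_forward(10, tape)
alive_with_tape = count_alive(refs)

# Dropping the tape's references is what actually releases the memory.
tape.clear()
alive_without_tape = count_alive(refs)

print(f"{alive_with_tape=}")
print(f"{alive_without_tape=}")
```

In this sketch every `Field` stays alive for as long as the toy tape holds it, regardless of what the user code deletes. By analogy, a checkpointing schedule can only reduce memory usage if the tape itself releases forward variables that are not currently checkpointed.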
