Skip to content

Commit 48409dd

Browse files
Add support for update placeholders
1 parent bbd4400 commit 48409dd

File tree

3 files changed

+129
-9
lines changed

3 files changed

+129
-9
lines changed

aesara/compile/function/types.py

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
import time
77
import warnings
88
from itertools import chain
9-
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
9+
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
1010

1111
import numpy as np
12+
from typing_extensions import Literal
1213

1314
import aesara
1415
import aesara.compile.profiling
15-
from aesara.compile.io import In, SymbolicInput, SymbolicOutput
16-
from aesara.compile.ops import deep_copy_op, view_op
16+
from aesara.compile.io import In, Out, SymbolicInput, SymbolicOutput
17+
from aesara.compile.ops import deep_copy_op, update_placeholder, view_op
18+
from aesara.compile.profiling import ProfileStats
1719
from aesara.configdefaults import config
1820
from aesara.graph.basic import (
1921
Constant,
@@ -732,10 +734,10 @@ def checkSV(sv_ori, sv_rpl):
732734
message = name
733735
else:
734736
message = str(profile.message) + " copy"
735-
profile = aesara.compile.profiling.ProfileStats(message=message)
737+
profile = ProfileStats(message=message)
736738
# profile -> object
737739
elif isinstance(profile, str):
738-
profile = aesara.compile.profiling.ProfileStats(message=profile)
740+
profile = ProfileStats(message=profile)
739741

740742
f_cpy = maker.__class__(
741743
inputs=ins,
@@ -1392,13 +1394,35 @@ def check_unused_inputs(inputs, outputs, on_unused_input):
13921394

13931395
@staticmethod
13941396
def prepare_fgraph(
1395-
inputs,
1396-
outputs,
1397-
additional_outputs,
1397+
inputs: List[In],
1398+
outputs: List[Out],
1399+
additional_outputs: List[Out],
13981400
fgraph: FunctionGraph,
13991401
mode: "Mode",
1400-
profile,
1402+
profile: Union[Optional[ProfileStats], Literal[False]],
14011403
):
1404+
r"""Perform rewrites on a graph, insert `DeepCopyOp`\s, and remove unused updates.
1405+
1406+
.. warning::
1407+
1408+
The `additional_outputs` list and `fgraph.outputs` are updated in-place by this method.
1409+
1410+
Parameters
1411+
==========
1412+
inputs
1413+
The wrapped inputs.
1414+
outputs
1415+
The wrapped outputs (i.e. wrapped with `Out`).
1416+
additional_outputs
1417+
Output graphs that essentially serve as updates to mutable `inputs`.
1418+
fgraph
1419+
The `FunctionGraph` to be prepared.
1420+
mode
1421+
The `Mode` that determines--for example--which rewrites are applied.
1422+
profile
1423+
The profile object/setting to use.
1424+
1425+
"""
14021426

14031427
rewriter = mode.optimizer
14041428

@@ -1419,6 +1443,40 @@ def prepare_fgraph(
14191443
rewrite_time = end_rewriter - start_rewriter
14201444
_logger.debug(f"Rewriting took {rewrite_time:f} seconds")
14211445

1446+
fgraph_outputs = tuple(fgraph.outputs)
1447+
update_mappings = tuple(fgraph.update_mapping.items())
1448+
outputs_to_remove = []
1449+
additional_outputs_to_remove = []
1450+
1451+
# Remove unused updates
1452+
for i, (out_idx, in_idx) in enumerate(update_mappings):
1453+
update = fgraph_outputs[out_idx]
1454+
1455+
if update.owner and update.owner.op == update_placeholder:
1456+
1457+
# TODO: Consider removing the corresponding
1458+
# `FunctionGraph` input when it has no other
1459+
# references?
1460+
# updated_var = fgraph_inputs[in_idx]
1461+
# if not fgraph.clients[updated_var]:
1462+
# fgraph.remove_input(updated_var)
1463+
1464+
# Remove the update entry from the wrapped inputs
1465+
inputs[in_idx].update = None
1466+
1467+
# We assume that the orders of `fgraph.update_mapping` and
1468+
# `additional_outputs` correspond (and they should)
1469+
additional_outputs_to_remove.append(additional_outputs[i])
1470+
1471+
outputs_to_remove.append(fgraph.outputs[out_idx])
1472+
1473+
for add_out, out in zip(
1474+
additional_outputs_to_remove, outputs_to_remove
1475+
):
1476+
additional_outputs.remove(add_out)
1477+
fgraph_out_idx = fgraph.outputs.index(out)
1478+
fgraph.remove_output(fgraph_out_idx)
1479+
14221480
# Add deep copy to respect the memory interface
14231481
insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
14241482
finally:

aesara/compile/ops.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,3 +329,45 @@ def make_op(fn):
329329
return FromFunctionOp(fn, itypes, otypes, infer_shape)
330330

331331
return make_op
332+
333+
334+
class UpdatePlaceholder(Op):
335+
"""A placeholder for a `SharedVariable` update that hasn't been set.
336+
337+
These will appear in `FunctionGraph.outputs` and represent potential updates
338+
(i.e. the `updates` argument to `aesara.function` and/or updates specified via
339+
`SharedVariable.default_update`s) that could be specified by rewrites.
340+
341+
When present, these should be removed by non-rewrite steps in the
342+
compilation pipeline (and before any thunks are created for them).
343+
344+
.. note::
345+
346+
One reason these can't be removed during the rewrite passes is that the
347+
`FunctionGraph.outputs` list entries containing them need to be
348+
entirely removed, and we don't want to add/remove
349+
`FunctionGraph.outputs` during rewriting.
350+
351+
"""
352+
353+
view_map = {0: [0]}
354+
355+
def make_node(self, x):
356+
return Apply(self, [x], [x.type()])
357+
358+
def perform(self, node, inp, out): # pragma: no cover
359+
(x,) = inp
360+
(z,) = out
361+
z[0] = x
362+
363+
def __str__(self): # pragma: no cover
364+
return f"{self.__class__.__name__}"
365+
366+
def infer_shape(self, fgraph, node, input_shapes):
367+
return input_shapes
368+
369+
def grad(self, args, g_outs): # pragma: no cover
370+
return g_outs
371+
372+
373+
update_placeholder = UpdatePlaceholder()

tests/compile/function/test_types.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from aesara.compile.function.types import UnusedInputError
1212
from aesara.compile.io import In, Out
1313
from aesara.compile.mode import Mode, get_default_mode
14+
from aesara.compile.ops import update_placeholder
1415
from aesara.configdefaults import config
1516
from aesara.graph.basic import Constant
1617
from aesara.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter
@@ -1265,3 +1266,22 @@ def test_empty_givens_updates():
12651266
y = x * 2
12661267
function([In(x)], y, givens={})
12671268
function([In(x)], y, updates={})
1269+
1270+
1271+
def test_update_placeholder():
1272+
a, x, s, m, n = scalars("axsmn")
1273+
1274+
f1 = function(
1275+
[
1276+
x,
1277+
In(a, value=1.0, name="a"),
1278+
In(m, value=0.0, update=update_placeholder(m), mutable=True),
1279+
In(s, value=0.0, update=s + a * x, mutable=True),
1280+
In(n, value=0.0, update=update_placeholder(n), mutable=True),
1281+
],
1282+
s + a * x,
1283+
)
1284+
1285+
# The second update shouldn't be present
1286+
assert len(f1.maker.fgraph.outputs) == 2
1287+
assert f1.maker.fgraph.update_mapping == {1: 3}

0 commit comments

Comments
 (0)