Commit 2368ed3

Fix bad initial value shape assumptions in save_mem_new_scan
1 parent 6cb7a38 commit 2368ed3
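
The rewrite touched here, `save_mem_new_scan`, trims the storage that `Scan` preallocates for its output buffers; the bad assumption being fixed concerns initial values that span several steps (e.g. a `taps=[-4]` state whose initial buffer holds four rows). The snippet below is a minimal, runnable sketch of that scenario, adapted from the `test_seq_tap_bug_jeremiah` test removed further down in this commit (variable names are illustrative):

    import numpy as np
    from aesara import config, function, scan, shared
    from aesara.tensor import matrix

    # A state with taps=[-4]: its initial value covers four steps, so the
    # preallocated output buffer must keep those four rows ahead of the rows
    # produced by the loop itself.
    inp = np.arange(10).reshape(-1, 1).astype(config.floatX)

    def onestep(x, x_tm4):
        return x, x_tm4

    seq = matrix("seq")
    initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
    outputs_info = [dict(initial=initial_value, taps=[-4]), None]
    results, updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)

    f = function([seq], results[1])

    expected = np.zeros((10, 1), dtype=config.floatX)
    expected[4:] = inp[:-4]
    assert np.allclose(f(inp), expected)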

3 files changed: +247 -85 lines changed

aesara/scan/rewriting.py

Lines changed: 101 additions & 69 deletions
@@ -4,7 +4,7 @@
 import dataclasses
 from itertools import chain
 from sys import maxsize
-from typing import Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 
 import numpy as np
 
@@ -36,7 +36,6 @@
 from aesara.scan.utils import (
     ScanArgs,
     compress_outs,
-    expand_empty,
     reconstruct_graph,
     safe_new,
     scan_can_remove_outs,
@@ -48,18 +47,22 @@
 from aesara.tensor.rewriting.basic import constant_folding, local_useless_switch
 from aesara.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
 from aesara.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
-from aesara.tensor.shape import shape
+from aesara.tensor.shape import shape, shape_tuple
 from aesara.tensor.subtensor import (
     IncSubtensor,
     Subtensor,
     get_canonical_form_slice,
     get_idx_list,
     get_slice_elements,
+    indices_from_subtensor,
     set_subtensor,
 )
 from aesara.tensor.var import TensorConstant, get_unique_value
 
 
+if TYPE_CHECKING:
+    from aesara.tensor.var import TensorVariable
+
 list_opt_slice = [
     local_abs_merge,
     local_mul_switch_sink,
@@ -1103,6 +1106,72 @@ def sanitize(x):
         return at.as_tensor_variable(x)
 
 
+def reshape_output_storage(
+    out_storage: "TensorVariable",
+    steps_needed: "TensorVariable",
+    tap_spread: int,
+) -> "TensorVariable":
+    """Resize the first dimension of ``storage`` in ``set_subtensor(storage[:tap_spread], initial_tap_vals)``.
+
+    This is used by `save_mem_new_scan` to reduce the amount of storage
+    (pre)allocated for `Scan` output arrays (i.e. ``storage`` is assumed to be
+    an `AllocEmpty` output).
+
+    Parameters
+    ----------
+    out_storage
+        This corresponds to a graph with the form
+        ``set_subtensor(storage[:tap_spread], initial_tap_vals)``.
+    tap_spread
+        The spread of the relevant tap. This will generally be the length of
+        ``initial_tap_vals``, but sometimes not (e.g. because the initial
+        values broadcast across the indices/slice).
+
+    Returns
+    -------
+    Return a graph like
+    ``set_subtensor(new_storage[:tap_spread], initial_tap_vals)``,
+    where ``new_storage`` is an `AllocEmpty` with a first
+    dimension having length ``maximum(steps_needed_var, tap_spread)``.
+
+    """
+    out_storage_node = out_storage.owner
+    if (
+        out_storage_node
+        and isinstance(out_storage_node.op, IncSubtensor)
+        and out_storage_node.op.set_instead_of_inc
+        and len(out_storage_node.op.idx_list) == 1
+        and isinstance(out_storage_node.op.idx_list[0], slice)
+    ):
+        # The "fill" value of the `IncSubtensor` across the
+        # slice. This should generally consist of the initial
+        # values.
+        initial_tap_vals = out_storage_node.inputs[1]
+
+        storage_slice = indices_from_subtensor(
+            out_storage_node.inputs[2:], out_storage_node.op.idx_list
+        )
+        inner_storage_var = out_storage_node.inputs[0]
+
+        # Why this size exactly? (N.B. This is what the original Theano logic ultimately did.)
+        max_storage_size = at.switch(
+            at.lt(steps_needed, tap_spread), steps_needed + 2 * tap_spread, steps_needed
+        )
+        new_inner_storage_var = at.empty(
+            (
+                max_storage_size,
+                *shape_tuple(inner_storage_var)[1:],
+            ),
+            dtype=initial_tap_vals.dtype,
+        )
+        res = at.set_subtensor(new_inner_storage_var[storage_slice], initial_tap_vals)
+    else:
+        max_storage_size = maximum(steps_needed, tap_spread)
+        res = out_storage[:max_storage_size]
+
+    return cast("TensorVariable", res)
+
+
 @node_rewriter([Scan])
 def save_mem_new_scan(fgraph, node):
     r"""Graph optimizer that reduces scan memory consumption.
@@ -1398,13 +1467,16 @@ def save_mem_new_scan(fgraph, node):
         # by the inner function .. )
         replaced_outs = []
         offset = 1 + op_info.n_seqs + op_info.n_mit_mot
-        for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
+        for idx, steps_needed in enumerate(store_steps[op_info.n_mit_mot :]):
             i = idx + op_info.n_mit_mot
-            if not (isinstance(_val, int) and _val <= 0 and i not in required):
-                if idx + op_info.n_mit_mot in required:
-                    val = 1
-                else:
-                    val = _val
+            if not (
+                isinstance(steps_needed, int)
+                and steps_needed <= 0
+                and i not in required
+            ):
+                if i in required:
+                    steps_needed = 1
+
                 # If the memory for this output has been pre-allocated
                 # before going into the scan op (by an alloc node)
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
@@ -1413,38 +1485,18 @@
                     # a) the input is a set_subtensor, in that case we
                     # can replace the initial tensor by a slice,
                     # b) it is not, and we simply take a slice of it.
-                    # TODO: commit change below with Razvan
-                    if (
-                        nw_inputs[offset + idx].owner
-                        and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
-                        and isinstance(
-                            nw_inputs[offset + idx].owner.op.idx_list[0], slice
-                        )
-                    ):
-                        assert isinstance(
-                            nw_inputs[offset + idx].owner.op, IncSubtensor
-                        )
-                        _nw_input = nw_inputs[offset + idx].owner.inputs[1]
-                        cval = at.as_tensor_variable(val)
-                        initl = at.as_tensor_variable(init_l[i])
-                        tmp_idx = at.switch(cval < initl, cval + initl, cval - initl)
-                        nw_input = expand_empty(_nw_input, tmp_idx)
-                    else:
-                        tmp = at.as_tensor_variable(val)
-                        initl = at.as_tensor_variable(init_l[i])
-                        tmp = maximum(tmp, initl)
-                        nw_input = nw_inputs[offset + idx][:tmp]
+                    out_storage = nw_inputs[offset + idx]
+                    tap_spread = init_l[i]
+                    nw_input = reshape_output_storage(
+                        out_storage, steps_needed, tap_spread
+                    )
 
                     nw_inputs[offset + idx] = nw_input
-                    replaced_outs.append(op_info.n_mit_mot + idx)
-                    odx = op_info.n_mit_mot + idx
+                    replaced_outs.append(i)
                     old_outputs += [
                         (
-                            odx,
-                            [
-                                x[0].outputs[0]
-                                for x in fgraph.clients[node.outputs[odx]]
-                            ],
+                            i,
+                            [x[0].outputs[0] for x in fgraph.clients[node.outputs[i]]],
                         )
                     ]
                 # If there is no memory pre-allocated for this output
@@ -1457,48 +1509,28 @@
                         + op_info.n_shared_outs
                     )
                     if nw_inputs[pos] == node.inputs[0]:
-                        nw_inputs[pos] = val
-                        odx = op_info.n_mit_mot + idx
-                        replaced_outs.append(odx)
+                        nw_inputs[pos] = steps_needed
+                        replaced_outs.append(i)
                         old_outputs += [
                             (
-                                odx,
-                                [
-                                    x[0].outputs[0]
-                                    for x in fgraph.clients[node.outputs[odx]]
-                                ],
+                                i,
+                                [x[0].outputs[0] for x in fgraph.clients[node.outputs[i]]],
                             )
                         ]
         # 3.4. Recompute inputs for everything else based on the new
        # number of steps
         if global_nsteps is not None:
-            for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
-                if val == 0:
-                    # val == 0 means that we want to keep all intermediate
+            for idx, steps_needed in enumerate(store_steps[op_info.n_mit_mot :]):
+                if steps_needed == 0:
+                    # steps_needed == 0 means that we want to keep all intermediate
                     # results for that state, including the initial values.
                     if idx < op_info.n_mit_sot + op_info.n_sit_sot:
                         in_idx = offset + idx
-                        # Number of steps in the initial state
-                        initl = init_l[op_info.n_mit_mot + idx]
-
-                        # If the initial buffer has the form
-                        # inc_subtensor(zeros(...)[...], _nw_input)
-                        # we want to make the zeros tensor as small as
-                        # possible (nw_steps + initl), and call
-                        # inc_subtensor on that instead.
-                        # Otherwise, simply take 0:(nw_steps+initl).
-                        if (
-                            nw_inputs[in_idx].owner
-                            and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
-                            and isinstance(
-                                nw_inputs[in_idx].owner.op.idx_list[0], slice
-                            )
-                        ):
-                            _nw_input = nw_inputs[in_idx].owner.inputs[1]
-                            nw_input = expand_empty(_nw_input, nw_steps)
-                            nw_inputs[in_idx] = nw_input
-                        else:
-                            nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
+                        out_storage = nw_inputs[in_idx]
+                        tap_spread = init_l[op_info.n_mit_mot + idx]
+                        nw_input = reshape_output_storage(
+                            out_storage, steps_needed, tap_spread
+                        )
 
                     elif (
                         idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
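
For context, here is a hedged sketch of the graph shape the new `reshape_output_storage` helper operates on, assuming the imports and API introduced above (the buffer length, names, and shapes below are illustrative, not taken from the commit):

    import aesara.tensor as at
    from aesara.scan.rewriting import reshape_output_storage

    # A preallocated output buffer of the form
    #     set_subtensor(storage[:tap_spread], initial_tap_vals)
    # whose first `tap_spread` rows hold the initial tap values.
    tap_spread = 4
    initial_tap_vals = at.matrix("initial_tap_vals")
    storage = at.empty((100, 1), dtype=initial_tap_vals.dtype)
    out_storage = at.set_subtensor(storage[:tap_spread], initial_tap_vals)

    # The number of steps this output actually has to keep around.
    steps_needed = at.iscalar("steps_needed")

    # The helper rebuilds the same set_subtensor graph on top of a new, smaller
    # AllocEmpty whose first dimension is derived from `steps_needed` and
    # `tap_spread` rather than the original (over-)allocated length.
    new_out_storage = reshape_output_storage(out_storage, steps_needed, tap_spread)

Note the sizing rule in the set_subtensor branch of the helper: when `steps_needed` is smaller than `tap_spread`, the new buffer gets `steps_needed + 2 * tap_spread` rows (e.g. 2 steps with a 4-step tap gives 10 rows), otherwise just `steps_needed`; per the in-code comment, this mirrors what the original Theano logic ultimately did.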

tests/scan/test_basic.py

Lines changed: 0 additions & 16 deletions
@@ -2950,22 +2950,6 @@ def rec_fn(*args):
         utt.assert_allclose(outs[2], v_w + 3)
         utt.assert_allclose(sh.get_value(), v_w + 4)
 
-    def test_seq_tap_bug_jeremiah(self):
-        inp = np.arange(10).reshape(-1, 1).astype(config.floatX)
-        exp_out = np.zeros((10, 1)).astype(config.floatX)
-        exp_out[4:] = inp[:-4]
-
-        def onestep(x, x_tm4):
-            return x, x_tm4
-
-        seq = matrix()
-        initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
-        outputs_info = [OrderedDict([("initial", initial_value), ("taps", [-4])]), None]
-        results, updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)
-
-        f = function([seq], results[1])
-        assert np.all(exp_out == f(inp))
-
     def test_shared_borrow(self):
         """
         This tests two things. The first is a bug occurring when scan wrongly
