@@ -4,7 +4,7 @@
 import dataclasses
 from itertools import chain
 from sys import maxsize
-from typing import Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 
 import numpy as np
 
@@ -36,7 +36,6 @@
 from aesara.scan.utils import (
     ScanArgs,
     compress_outs,
-    expand_empty,
     reconstruct_graph,
     safe_new,
     scan_can_remove_outs,
@@ -48,18 +47,22 @@
 from aesara.tensor.rewriting.basic import constant_folding, local_useless_switch
 from aesara.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
 from aesara.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
-from aesara.tensor.shape import shape
+from aesara.tensor.shape import shape, shape_tuple
 from aesara.tensor.subtensor import (
     IncSubtensor,
     Subtensor,
     get_canonical_form_slice,
     get_idx_list,
     get_slice_elements,
+    indices_from_subtensor,
     set_subtensor,
 )
 from aesara.tensor.var import TensorConstant, get_unique_value
 
 
+if TYPE_CHECKING:
+    from aesara.tensor.var import TensorVariable
+
 list_opt_slice = [
     local_abs_merge,
     local_mul_switch_sink,
@@ -1103,6 +1106,72 @@ def sanitize(x):
         return at.as_tensor_variable(x)
 
 
+def reshape_output_storage(
+    out_storage: "TensorVariable",
+    steps_needed: "TensorVariable",
+    tap_spread: int,
+) -> "TensorVariable":
+    """Resize the first dimension of ``storage`` in ``set_subtensor(storage[:tap_spread], initial_tap_vals)``.
+
+    This is used by `save_mem_new_scan` to reduce the amount of storage
+    (pre)allocated for `Scan` output arrays (i.e. ``storage`` is assumed to be
+    an `AllocEmpty` output).
+
+    Parameters
+    ----------
+    out_storage
+        This corresponds to a graph with the form
+        ``set_subtensor(storage[:tap_spread], initial_tap_vals)``.
+    tap_spread
+        The spread of the relevant tap. This will generally be the length of
+        ``initial_tap_vals``, but sometimes not (e.g. because the initial
+        values broadcast across the indices/slice).
+
+    Returns
+    -------
+    Return a graph like
+    ``set_subtensor(new_storage[:tap_spread], initial_tap_vals)``,
+    where ``new_storage`` is an `AllocEmpty` with a first dimension
+    of length at least ``maximum(steps_needed, tap_spread)``.
+
+    """
+    out_storage_node = out_storage.owner
+    if (
+        out_storage_node
+        and isinstance(out_storage_node.op, IncSubtensor)
+        and out_storage_node.op.set_instead_of_inc
+        and len(out_storage_node.op.idx_list) == 1
+        and isinstance(out_storage_node.op.idx_list[0], slice)
+    ):
+        # The "fill" value of the `IncSubtensor` across the
+        # slice. This should generally consist of the initial
+        # values.
+        initial_tap_vals = out_storage_node.inputs[1]
+
+        storage_slice = indices_from_subtensor(
+            out_storage_node.inputs[2:], out_storage_node.op.idx_list
+        )
+        inner_storage_var = out_storage_node.inputs[0]
+
+        # Why this size exactly? (N.B. This is what the original Theano logic ultimately did.)
+        max_storage_size = at.switch(
+            at.lt(steps_needed, tap_spread), steps_needed + 2 * tap_spread, steps_needed
+        )
+        new_inner_storage_var = at.empty(
+            (
+                max_storage_size,
+                *shape_tuple(inner_storage_var)[1:],
+            ),
+            dtype=initial_tap_vals.dtype,
+        )
+        res = at.set_subtensor(new_inner_storage_var[storage_slice], initial_tap_vals)
+    else:
+        max_storage_size = maximum(steps_needed, tap_spread)
+        res = out_storage[:max_storage_size]
+
+    return cast("TensorVariable", res)
+
+
 @node_rewriter([Scan])
 def save_mem_new_scan(fgraph, node):
     r"""Graph optimizer that reduces scan memory consumption.
@@ -1398,13 +1467,16 @@ def save_mem_new_scan(fgraph, node):
         # by the inner function .. )
         replaced_outs = []
         offset = 1 + op_info.n_seqs + op_info.n_mit_mot
-        for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
+        for idx, steps_needed in enumerate(store_steps[op_info.n_mit_mot :]):
             i = idx + op_info.n_mit_mot
-            if not (isinstance(_val, int) and _val <= 0 and i not in required):
-                if idx + op_info.n_mit_mot in required:
-                    val = 1
-                else:
-                    val = _val
+            if not (
+                isinstance(steps_needed, int)
+                and steps_needed <= 0
+                and i not in required
+            ):
+                if i in required:
+                    steps_needed = 1
+
                 # If the memory for this output has been pre-allocated
                 # before going into the scan op (by an alloc node)
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
@@ -1413,38 +1485,18 @@ def save_mem_new_scan(fgraph, node):
                     # a) the input is a set_subtensor, in that case we
                     # can replace the initial tensor by a slice,
                     # b) it is not, and we simply take a slice of it.
-                    # TODO: commit change below with Razvan
-                    if (
-                        nw_inputs[offset + idx].owner
-                        and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
-                        and isinstance(
-                            nw_inputs[offset + idx].owner.op.idx_list[0], slice
-                        )
-                    ):
-                        assert isinstance(
-                            nw_inputs[offset + idx].owner.op, IncSubtensor
-                        )
-                        _nw_input = nw_inputs[offset + idx].owner.inputs[1]
-                        cval = at.as_tensor_variable(val)
-                        initl = at.as_tensor_variable(init_l[i])
-                        tmp_idx = at.switch(cval < initl, cval + initl, cval - initl)
-                        nw_input = expand_empty(_nw_input, tmp_idx)
-                    else:
-                        tmp = at.as_tensor_variable(val)
-                        initl = at.as_tensor_variable(init_l[i])
-                        tmp = maximum(tmp, initl)
-                        nw_input = nw_inputs[offset + idx][:tmp]
+                    out_storage = nw_inputs[offset + idx]
+                    tap_spread = init_l[i]
+                    nw_input = reshape_output_storage(
+                        out_storage, steps_needed, tap_spread
+                    )
 
                     nw_inputs[offset + idx] = nw_input
-                    replaced_outs.append(op_info.n_mit_mot + idx)
-                    odx = op_info.n_mit_mot + idx
+                    replaced_outs.append(i)
                     old_outputs += [
                         (
-                            odx,
-                            [
-                                x[0].outputs[0]
-                                for x in fgraph.clients[node.outputs[odx]]
-                            ],
+                            i,
+                            [x[0].outputs[0] for x in fgraph.clients[node.outputs[i]]],
                         )
                     ]
                 # If there is no memory pre-allocated for this output
@@ -1457,48 +1509,28 @@ def save_mem_new_scan(fgraph, node):
                         + op_info.n_shared_outs
                     )
                     if nw_inputs[pos] == node.inputs[0]:
-                        nw_inputs[pos] = val
-                        odx = op_info.n_mit_mot + idx
-                        replaced_outs.append(odx)
+                        nw_inputs[pos] = steps_needed
+                        replaced_outs.append(i)
                         old_outputs += [
                             (
-                                odx,
-                                [
-                                    x[0].outputs[0]
-                                    for x in fgraph.clients[node.outputs[odx]]
-                                ],
+                                i,
+                                [x[0].outputs[0] for x in fgraph.clients[node.outputs[i]]],
                             )
                         ]
         # 3.4. Recompute inputs for everything else based on the new
         # number of steps
         if global_nsteps is not None:
-            for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
-                if val == 0:
-                    # val == 0 means that we want to keep all intermediate
+            for idx, steps_needed in enumerate(store_steps[op_info.n_mit_mot :]):
+                if steps_needed == 0:
+                    # steps_needed == 0 means that we want to keep all intermediate
                     # results for that state, including the initial values.
                     if idx < op_info.n_mit_sot + op_info.n_sit_sot:
                         in_idx = offset + idx
-                        # Number of steps in the initial state
-                        initl = init_l[op_info.n_mit_mot + idx]
-
-                        # If the initial buffer has the form
-                        # inc_subtensor(zeros(...)[...], _nw_input)
-                        # we want to make the zeros tensor as small as
-                        # possible (nw_steps + initl), and call
-                        # inc_subtensor on that instead.
-                        # Otherwise, simply take 0:(nw_steps+initl).
-                        if (
-                            nw_inputs[in_idx].owner
-                            and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
-                            and isinstance(
-                                nw_inputs[in_idx].owner.op.idx_list[0], slice
-                            )
-                        ):
-                            _nw_input = nw_inputs[in_idx].owner.inputs[1]
-                            nw_input = expand_empty(_nw_input, nw_steps)
-                            nw_inputs[in_idx] = nw_input
-                        else:
-                            nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
+                        out_storage = nw_inputs[in_idx]
+                        tap_spread = init_l[op_info.n_mit_mot + idx]
+                        nw_input = reshape_output_storage(
+                            out_storage, steps_needed, tap_spread
+                        )
 
                     elif (
                         idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
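
To make the two sizing rules in `reshape_output_storage` concrete, here is a small numeric sketch (illustrative only; plain Python integers rather than symbolic graphs):

def sketch_storage_sizes(steps_needed: int, tap_spread: int):
    # Mirrors the `at.switch` expression used when the buffer is a
    # `set_subtensor` over an `AllocEmpty`.
    set_subtensor_case = (
        steps_needed + 2 * tap_spread if steps_needed < tap_spread else steps_needed
    )
    # Mirrors `maximum(steps_needed, tap_spread)` used in the fallback branch.
    fallback_case = max(steps_needed, tap_spread)
    return set_subtensor_case, fallback_case


print(sketch_storage_sizes(steps_needed=1, tap_spread=3))   # (7, 3)
print(sketch_storage_sizes(steps_needed=10, tap_spread=3))  # (10, 10)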