6
6
import time
7
7
import warnings
8
8
from itertools import chain
9
- from typing import TYPE_CHECKING , List , Optional , Tuple , Type
9
+ from typing import TYPE_CHECKING , List , Optional , Tuple , Type , Union
10
10
11
11
import numpy as np
12
+ from typing_extensions import Literal
12
13
13
14
import aesara
14
15
import aesara .compile .profiling
15
- from aesara .compile .io import In , SymbolicInput , SymbolicOutput
16
- from aesara .compile .ops import deep_copy_op , view_op
16
+ from aesara .compile .io import In , Out , SymbolicInput , SymbolicOutput
17
+ from aesara .compile .ops import deep_copy_op , update_placeholder , view_op
18
+ from aesara .compile .profiling import ProfileStats
17
19
from aesara .configdefaults import config
18
20
from aesara .graph .basic import (
19
21
Constant ,
@@ -732,10 +734,10 @@ def checkSV(sv_ori, sv_rpl):
732
734
message = name
733
735
else :
734
736
message = str (profile .message ) + " copy"
735
- profile = aesara . compile . profiling . ProfileStats (message = message )
737
+ profile = ProfileStats (message = message )
736
738
# profile -> object
737
739
elif isinstance (profile , str ):
738
- profile = aesara . compile . profiling . ProfileStats (message = profile )
740
+ profile = ProfileStats (message = profile )
739
741
740
742
f_cpy = maker .__class__ (
741
743
inputs = ins ,
@@ -1392,13 +1394,35 @@ def check_unused_inputs(inputs, outputs, on_unused_input):
1392
1394
1393
1395
@staticmethod
1394
1396
def prepare_fgraph (
1395
- inputs ,
1396
- outputs ,
1397
- additional_outputs ,
1397
+ inputs : List [ In ] ,
1398
+ outputs : List [ Out ] ,
1399
+ additional_outputs : List [ Out ] ,
1398
1400
fgraph : FunctionGraph ,
1399
1401
mode : "Mode" ,
1400
- profile ,
1402
+ profile : Union [ Optional [ ProfileStats ], Literal [ False ]] ,
1401
1403
):
1404
+ r"""Perform rewrites on a graph, insert `DeepCopyOp`\s, and remove unused updates.
1405
+
1406
+ .. warning::
1407
+
1408
+ The `additional_outputs` list and `fgraph.outputs` are updated in-place by this method.
1409
+
1410
+ Parameters
1411
+ ==========
1412
+ inputs
1413
+ The wrapped inputs.
1414
+ outputs
1415
+ The wrapped outputs (i.e. wrapped with `Out`).
1416
+ additional_outputs
1417
+ Output graphs that essentially serve as updates to mutable `inputs`.
1418
+ fgraph
1419
+ The `FunctionGraph` to be prepared.
1420
+ mode
1421
+ The `Mode` that determines--for example--which rewrites are applied.
1422
+ profile
1423
+ The profile object/setting to use.
1424
+
1425
+ """
1402
1426
1403
1427
rewriter = mode .optimizer
1404
1428
@@ -1419,6 +1443,40 @@ def prepare_fgraph(
1419
1443
rewrite_time = end_rewriter - start_rewriter
1420
1444
_logger .debug (f"Rewriting took { rewrite_time :f} seconds" )
1421
1445
1446
+ fgraph_outputs = tuple (fgraph .outputs )
1447
+ update_mappings = tuple (fgraph .update_mapping .items ())
1448
+ outputs_to_remove = []
1449
+ additional_outputs_to_remove = []
1450
+
1451
+ # Remove unused updates
1452
+ for i , (out_idx , in_idx ) in enumerate (update_mappings ):
1453
+ update = fgraph_outputs [out_idx ]
1454
+
1455
+ if update .owner and update .owner .op == update_placeholder :
1456
+
1457
+ # TODO: Consider removing the corresponding
1458
+ # `FunctionGraph` input when it has no other
1459
+ # references?
1460
+ # updated_var = fgraph_inputs[in_idx]
1461
+ # if not fgraph.clients[updated_var]:
1462
+ # fgraph.remove_input(updated_var)
1463
+
1464
+ # Remove the update entry from the wrapped inputs
1465
+ inputs [in_idx ].update = None
1466
+
1467
+ # We assume that the orders of `fgraph.update_mapping` and
1468
+ # `additional_outputs` correspond (and they should)
1469
+ additional_outputs_to_remove .append (additional_outputs [i ])
1470
+
1471
+ outputs_to_remove .append (fgraph .outputs [out_idx ])
1472
+
1473
+ for add_out , out in zip (
1474
+ additional_outputs_to_remove , outputs_to_remove
1475
+ ):
1476
+ additional_outputs .remove (add_out )
1477
+ fgraph_out_idx = fgraph .outputs .index (out )
1478
+ fgraph .remove_output (fgraph_out_idx )
1479
+
1422
1480
# Add deep copy to respect the memory interface
1423
1481
insert_deepcopy (fgraph , inputs , outputs + additional_outputs )
1424
1482
finally :
0 commit comments