|
25 | 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
26 | 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
27 | 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
28 |
| - |
29 |
| - |
30 | 28 | import numpy as np
|
| 29 | +import warnings |
31 | 30 |
|
32 | 31 | from qonnx.core.modelwrapper import ModelWrapper
|
33 | 32 | from qonnx.transformation.base import Transformation
|
@@ -58,21 +57,45 @@ def apply(self, model: ModelWrapper):
|
58 | 57 |
|
59 | 58 |
|
60 | 59 | def remove_node_and_rewire(model, node):
|
| 60 | + # Currently cannot remove and rewire join-nodes, probably not necessary to |
| 61 | + # support this |
| 62 | + if model.is_join_node(node): |
| 63 | + # Log this as a warning, so the user is aware of this, there might be |
| 64 | + # somthing wrong or some checks missing at the caller site |
| 65 | + warnings.warn( |
| 66 | + "Tried to remove join-node operation: Currently not supported" |
| 67 | + ) |
| 68 | + # Exit the function here without doing anything |
| 69 | + return |
| 70 | + # We already know that node is not a join-node, thus to rewire, we only need |
| 71 | + # to check the single producer |
61 | 72 | producer = model.find_producer(node.input[0])
|
62 |
| - if producer is not None: |
63 |
| - # wire output tensor to |
64 |
| - # output of producer node |
| 73 | + # If there is a producer which is not a fork-node, rewiring is simple |
| 74 | + if producer is not None and not model.is_fork_node(producer): |
| 75 | + # Rewire by skipping the node, letting the producer directly feed the |
| 76 | + # nodes output. |
| 77 | + # TODO: Check whether this already covers fork-node identities? |
65 | 78 | producer.output[0] = node.output[0]
|
| 79 | + # If there is no producer or the producer forks, rewiring is a bit more |
| 80 | + # complicated |
66 | 81 | else:
|
67 |
| - # node is first in graph |
| 82 | + # Now it depends on the successor nodes to rewire their inputs |
68 | 83 | successors = model.find_direct_successors(node)
|
| 84 | + # Singular node detached from the rest of the graph? |
69 | 85 | assert successors is not None, "Whole graph is one node."
|
70 |
| - for succ in successors: |
71 |
| - for i, s_inp in enumerate(succ.input): |
| 86 | + # We need to rewire the input of each successor to not detach parts of |
| 87 | + # the graph |
| 88 | + for successor in successors: |
| 89 | + # Find the inputs of the successor which are produced by the node to |
| 90 | + # be removed |
| 91 | + for i, s_inp in enumerate(successor.input): |
| 92 | + # Note: This might happen multiple times? |
72 | 93 | if s_inp == node.output[0]:
|
73 |
| - # rewire successor's input directly to graph input |
74 |
| - succ.input[i] = node.input[0] |
75 |
| - # remove node |
| 94 | + # Rewire successor's input directly to nodes input |
| 95 | + # Note: Node may not be a join-node, but there is probably |
| 96 | + # no such thing as join-node identity anyway |
| 97 | + successor.input[i] = node.input[0] |
| 98 | + # Remove node |
76 | 99 | model.graph.node.remove(node)
|
77 | 100 |
|
78 | 101 |
|
|
0 commit comments