Skip to content

Commit f08a869

Browse files
authored
Merge pull request #113 from bwintermann/main
Updated partitioning function for PartitionFromDict
2 parents 2c91d6d + 5d246e8 commit f08a869

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/qonnx/transformation/create_generic_partitions.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,14 @@ def __init__(self, partitioning={}, partition_dir=None):
195195
def apply(self, model):
196196
# prepare node -> int assignment fct.
197197
def partitioning_func(node):
198-
partition_id = -1
199-
for key in self.partitioning:
200-
if node in list(model.graph.node) and list(model.graph.node).index(node) in list(self.partitioning[key]):
201-
assert partition_id == -1, """single node assigned to multiple partitions"""
202-
partition_id = key
203-
204-
return partition_id
198+
if node not in model.graph.node:
199+
return -1
200+
node_index = list(model.graph.node).index(node)
201+
candidates = list(filter(lambda key_value: node_index in key_value[1], self.partitioning.items()))
202+
if len(candidates) == 0:
203+
return -1
204+
assert len(candidates) == 1, f"single node assigned to multiple partitions: {candidates}"
205+
return candidates[0][0] # partition_id
205206

206207
# apply partitioning
207208
model = model.transform(PartitionFromLambda(partitioning_func, self.partition_dir))

0 commit comments

Comments
 (0)