Skip to content

Commit 7fdbd4d

Browse files
committed
multi graph dimname fix
1 parent c595c4f commit 7fdbd4d

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

hls4ml/model/graph.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
from hls4ml.backends import get_backend
1616
from hls4ml.model.flow import get_flow
17-
from hls4ml.model.layers import layer_map
17+
from hls4ml.model.layers import Layer, layer_map
1818
from hls4ml.model.optimizer import get_available_passes, optimize_model
19-
from hls4ml.model.types import Serializable
19+
from hls4ml.model.types import Serializable, TensorVariable
2020
from hls4ml.utils.string_utils import convert_to_snake_case
2121

2222

@@ -1069,8 +1069,8 @@ def from_model_graph(cls, base_model: ModelGraph, split_before_layers: list[str]
10691069
cfg_copy.config['ProjectName'] = f'{base_model.config.get_project_name()}_graph{idx + 1}'
10701070
cfg_copy.config['OutputDir'] = os.path.join(base_model.config.get_output_dir(), f'graph{idx + 1}')
10711071

1072-
subgraph = base_model.__class__(cfg_copy, inputs=[], outputs=[])
1073-
graph_dict = OrderedDict()
1072+
subgraph = ModelGraph(cfg_copy, inputs=[], outputs=[])
1073+
graph_dict: OrderedDict[str, Layer] = OrderedDict()
10741074

10751075
if idx > 0:
10761076
next_index += 1
@@ -1091,6 +1091,11 @@ def from_model_graph(cls, base_model: ModelGraph, split_before_layers: list[str]
10911091
subgraph.outputs = slice_[-1].outputs if idx < len(node_slices) - 1 else base_model.outputs
10921092
subgraph._applied_flows = base_model._applied_flows
10931093

1094+
for node in subgraph.graph.values():
1095+
# Prevent name conflict in different subgraphs
1096+
variable: TensorVariable = node.get_output_variable()
1097+
variable.dim_names = [f'G{idx}_{name}' for name in variable.dim_names]
1098+
10941099
# NOTE might need to examine other subgraph-related flows (i.e., fifo_optimizer)
10951100
subgraph.apply_flow('vivado:specific_types')
10961101
subgraph.apply_flow('vitis:apply_templates')

0 commit comments

Comments
 (0)