
Commit c0e9611

Support for multi graph build (#1174)
* Split `ModelGraph` at a specified layer name. feat: add `make_multi_graph` classmethod to `ModelGraph`:
  - The method returns two instances of the `ModelGraph` class.
  - Each instance is initialized with the same config; only the output folder changes, allowing separate models to be created in one call.
  - This improves usability by simplifying the process of generating multiple graphs from a single configuration input.
* `make_multi_graph` can now support an arbitrary number of graphs:
  - takes the split_layer_names as split points
  - returns a list of `ModelGraph`
  - works for dense/fc layers at the moment
  - need to find the input_shape of the split layer for conv layers; currently in dense/fc we find it through the 'n_in' key
* Pass output_shapes to `make_multi_graph` to detect input shapes of split layers; fixed layer index in the newly created graph; fix minor mistakes.
* Add TCL script for automatic connection of subgraph IPs in Vivado:
  - automatically scans and adds HLS IP cores for subgraphs in Vivado
  - automatically detects the interface types used by the IPs (either unpacked or AXI stream) and configures the connections accordingly
  - also updated the multigraph logic to copy the precision of the last layer from the previous graph and apply it to the input layer of the next graph
* Some minor fixes in the TCL script and `make_multi_graph`.
* Support for parallel subgraph builds; `make_multi_graph` now returns a `MultiModelGraph` instance.
* New TCL script: connected external and control signals.
* Integrate the ip_stitcher TCL script into hls4ml.
* Fix in TCL; folder creation for the stitch project.
* Package the final stitched IP in hls4ml. Notes:
  - missing X_INTERFACE_INFO for AXI interfaces in the generated HDL during packaging
  - Vivado throws the warning "Misformed interface info"
  - we omit this warning at the moment, as the IP can still be packaged
* Support for multiple inputs/outputs in the first/last layer of the stitched IP.
* Initial support for stitched IP simulation:
  - generate a Verilog testbench for the stitched IP
  - read the testbench output
  - minor changes; improvements in testbench generation and the build interface
  - general improvements; only simulate stitched_design; better Verilog testbench
  - prepare the testbench input from the user; support for user-defined input in the Verilog testbench of the stitched IP
* Fix for multi-input/output layers in graph splitting.
* Documentation for the MultiModelGraph flow.
* Faster RTL simulation.
* Unwrap list if it has a single element.
* Make MultiModelGraph adaptable to user-defined names.
* Stitch script time verbose.
* Fix with existing stitch project folder.
* Initial support for multigraph compilation in the bridge file.
* Stitched report fix for VivadoSynth aggregate.
* Use the log_to_stdout flag for parallel builds.
* Small change.
* Remove bridged multigraph compilation for now.
* [pre-commit.ci] auto fixes from pre-commit hooks.
* Fix 'ap_rst' port polarity for the active-high case.
* Support for partition interface in the Verilog testbench.
* Support for MultiModelGraph predict using the chained bridge file.
* Add pytest for multi-graph and fix minor issues.
* Pre-commit fixes.
* Removed pandas dependency in read_testbench_log.
* Ensure stitched RTL simulation results align with CSim output.
* Parallel subgraph compilation.
* Added additional checks in ip_stitcher.
* Small improvements on MultiModelGraph.
* Correct AXIS port slicing for Verilog simulation.
* Generate Verilog testbench inputs using the C++ bridge.
* Fix rebase conflict in ModelGraph object creation.
* Major rewrite of multi-graph splitting; it now uses the optimized ModelGraph.
* Minor fixes and improvements.
* Minor fixes.
* Skip stitching if a graph failed.
* Final changes.
* Remove synthesis test.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 89dc007 commit c0e9611

File tree

16 files changed (+2423, -46 lines)

docs/img/logo_small.png

6.28 KB

docs/ir/multimodelgraph.rst

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+=======================
+MultiModelGraph Class
+=======================
+
+This page documents the ``MultiModelGraph`` class, which enables handling multiple subgraphs (each represented as a ``ModelGraph``) derived from a single original model.
+The central concept is the division of a larger model into multiple smaller subgraphs at given layers, which can be useful for:
+
+* Very large models
+* Step-wise optimization
+* Modular design flows
+
+A ``MultiModelGraph`` manages these subgraphs, facilitating:
+
+* Parallel building and synthesis
+* Stitched designs (merging the subgraphs in HW after synthesis)
+* Simulation and performance estimation of the stitched design
+
+--------------
+Keras Example
+--------------
+
+For example, when converting a Keras model, you can specify the layers at which to split the model directly:
+
+.. code-block:: python
+
+    config = hls4ml.utils.config_from_keras_model(model, granularity='model')
+
+    hls_model = hls4ml.converters.convert_from_keras_model(
+        model,
+        hls_config=config,
+        backend='vitis',
+    )
+    hls_multigraph_model = hls4ml.model.to_multi_model_graph(hls_model, ['layer3', 'layer7'])
+
+Here, ``hls_multigraph_model`` is a ``MultiModelGraph`` containing three subgraphs. Each subgraph is a ``ModelGraph`` accessible via indexing: ``hls_multigraph_model[i]``.
+
+
+----------------------------------
+Key Methods for MultiModelGraph
+----------------------------------
+
+* :ref:`compile <mmg-compile-method>`
+* :ref:`predict <mmg-predict-method>`
+* :ref:`build <mmg-build-method>`
+* :ref:`trace <mmg-trace-method>`
+
+----
+
+.. _mmg-compile-method:
+
+``compile`` method
+==================
+
+Compiles all the individual ``ModelGraph`` subgraphs within the ``MultiModelGraph``. It also compiles a chained bridge file, with all the subgraphs linked together, that can be used by the ``predict`` function.
+
+.. code-block:: python
+
+    hls_multigraph_model.compile()
+
+----
+
+.. _mmg-build-method:
+
+``build`` method
+================
+
+Builds all subgraphs in parallel, each as if it were a standalone ``ModelGraph`` project, and returns a report for each subgraph. If configured, it then runs the stitching flow in Vivado, connecting the individual exported IPs and allowing you to simulate the stitched design at the RTL level.
+
+.. code-block:: python
+
+    report = hls_multigraph_model.build(.., export=True, stitch_design=True, sim_stitched_design=True, export_stitched_design=True)
+
+The returned ``report`` contains results from each subgraph's build and, if stitching was performed, a combined report of the stitched design. Reports for individual ``ModelGraph`` instances are always accessible via
+``MultiModelGraph.graph_reports``.
+
+
+----
+
+.. _mmg-predict-method:
+
+``predict`` method
+==================
+
+Performs a forward pass through the chained bridge file using C simulation (``sim='csim'``), providing 1-to-1 output with the original model. You can also use RTL simulation (``sim='rtl'``) to perform the forward pass at the register-transfer level. In this case, a Verilog testbench is dynamically generated and executed against the stitched IP design, providing a behavioral simulation that accurately verifies latency and output at the hardware level. Note that the input data for the RTL simulation must have a single batch dimension.
+
+.. code-block:: python
+
+    # Perform prediction using C simulation (default)
+    y_csim = hls_multigraph_model.predict(X, sim='csim')
+
+    # Perform prediction using RTL simulation (behavioral)
+    y_rtl = hls_multigraph_model.predict(X, sim='rtl')
+
+
+
+--------------------------
+Summary
+--------------------------
+
+The ``MultiModelGraph`` class is a tool for modular hardware design. By splitting a large neural network into multiple subgraphs, building each independently, and then stitching them together, you gain flexibility and parallelism, and you enable hierarchical design, incremental optimization, and integrated system-level simulation.
+
+
+Notes and Known Issues
+=======================
+
+Graph Splitting
+---------------
+
+- Splitting in the middle of a branched architecture (e.g., ResNet skip connections) is currently unsupported.
+- Each split subgraph must have exactly one input.
+
+Multiple Inputs & Outputs
+-------------------------
+
+- The final NN output can support multiple output layers.
+- For networks with multiple input layers (a relatively uncommon case), proper synchronization is required in the testbench to drive the inputs, especially for ``io_stream`` interfaces.
+
+Simulation Discrepancies
+------------------------
+
+- Users should carefully verify functional equivalence, particularly for models that use the ``io_stream`` interface.
+- These discrepancies are more noticeable with raw output logits; applying a softmax layer at the model output can often help mask these differences, but this should be used with caution.
+
+TODOs
+-----------------------
+
+- Currently tested with Vitis 2024.1. Investigate compatibility with other versions.
+- Add support for Verilator-based simulation to enable faster RTL simulation.
+- Investigate the ``io_stream`` interface (output discrepancies, FIFO optimization).
+- Investigate differences in resource utilization for the ``io_parallel`` interface.
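
Taken together, the documented API suggests an end-to-end flow along the lines of the sketch below. This is an illustrative summary of the examples above, not part of the commit; the layer names and the input array `X` are placeholders.

import hls4ml

config = hls4ml.utils.config_from_keras_model(model, granularity='model')
hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, backend='vitis')

# Split at two placeholder layer names -> a MultiModelGraph with three ModelGraph subgraphs
hls_multigraph_model = hls4ml.model.to_multi_model_graph(hls_model, ['layer3', 'layer7'])

hls_multigraph_model.compile()
y_csim = hls_multigraph_model.predict(X, sim='csim')  # C simulation through the chained bridge file

report = hls_multigraph_model.build(
    export=True, stitch_design=True, sim_stitched_design=True, export_stitched_design=True
)
y_rtl = hls_multigraph_model.predict(X, sim='rtl')  # behavioral RTL simulation of the stitched IP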

hls4ml/backends/vitis/vitis_backend.py

Lines changed: 131 additions & 19 deletions
@@ -1,9 +1,19 @@
+import importlib.util
+import json
 import os
+import shutil
+import subprocess
 import sys
 
 from hls4ml.backends import VivadoBackend
 from hls4ml.model.flow import get_flow, register_flow
-from hls4ml.report import parse_vivado_report
+from hls4ml.report import aggregate_graph_reports, parse_vivado_report
+from hls4ml.utils.simulation_utils import (
+    annotate_axis_stream_widths,
+    prepare_tb_inputs,
+    read_testbench_log,
+    write_verilog_testbench,
+)
 
 
 class VitisBackend(VivadoBackend):
@@ -98,29 +108,131 @@ def build(
         export=False,
         vsynth=False,
         fifo_opt=False,
+        log_to_stdout=True,
     ):
         if 'linux' in sys.platform:
             found = os.system('command -v vitis_hls > /dev/null')
             if found != 0:
                 raise Exception('Vitis HLS installation not found. Make sure "vitis_hls" is on PATH.')
 
-        curr_dir = os.getcwd()
-        os.chdir(model.config.get_output_dir())
-        os.system(
-            (
-                'vitis_hls -f build_prj.tcl "reset={reset} csim={csim} synth={synth} cosim={cosim} '
-                'validation={validation} export={export} vsynth={vsynth} fifo_opt={fifo_opt}"'
-            ).format(
-                reset=reset,
-                csim=csim,
-                synth=synth,
-                cosim=cosim,
-                validation=validation,
-                export=export,
-                vsynth=vsynth,
-                fifo_opt=fifo_opt,
-            )
+        build_command = (
+            'vitis_hls -f build_prj.tcl "reset={reset} csim={csim} synth={synth} cosim={cosim} '
+            'validation={validation} export={export} vsynth={vsynth} fifo_opt={fifo_opt}"'
+        ).format(
+            reset=reset,
+            csim=csim,
+            synth=synth,
+            cosim=cosim,
+            validation=validation,
+            export=export,
+            vsynth=vsynth,
+            fifo_opt=fifo_opt,
         )
-        os.chdir(curr_dir)
 
-        return parse_vivado_report(model.config.get_output_dir())
+        output_dir = model.config.get_output_dir()
+        stdout_log = os.path.join(output_dir, 'build_stdout.log')
+        stderr_log = os.path.join(output_dir, 'build_stderr.log')
+
+        stdout_target = None if log_to_stdout else open(stdout_log, 'w')
+        stderr_target = None if log_to_stdout else open(stderr_log, 'w')
+
+        try:
+            process = subprocess.Popen(
+                build_command, shell=True, cwd=output_dir, stdout=stdout_target, stderr=stderr_target, text=True
+            )
+            process.communicate()
+
+            if process.returncode != 0:
+                raise Exception(f'Build failed for {model.config.get_project_name()}. See logs for details.')
+        finally:
+            if not log_to_stdout:
+                stdout_target.close()
+                stderr_target.close()
+
+        return parse_vivado_report(output_dir)
+
+    def build_stitched_design(
+        self,
+        model,
+        stitch_design=True,
+        sim_stitched_design=False,
+        export_stitched_design=False,
+        graph_reports=None,
+        simulation_input_data=None,
+    ):
+
+        nn_config = model.nn_config
+        os.makedirs(nn_config['OutputDir'], exist_ok=True)
+        stitched_design_dir = os.path.join(nn_config['OutputDir'], nn_config['StitchedProjectName'])
+        if stitch_design:
+            if os.path.exists(stitched_design_dir):
+                shutil.rmtree(stitched_design_dir)
+            os.makedirs(stitched_design_dir)
+
+        spec = importlib.util.find_spec('hls4ml')
+        hls4ml_path = os.path.dirname(spec.origin)
+        ip_stitcher_path = os.path.join(hls4ml_path, 'templates/vivado/ip_stitcher.tcl')
+        stdout_log = os.path.join(stitched_design_dir, 'stitcher_stdout.log')
+        stderr_log = os.path.join(stitched_design_dir, 'stitcher_stderr.log')
+        nn_config_path = os.path.join(stitched_design_dir, 'nn_config.json')
+        testbench_path = os.path.join(stitched_design_dir, 'testbench.v')
+        testbench_log_path = os.path.join(stitched_design_dir, 'testbench_log.csv')
+
+        try:
+            shutil.copy(ip_stitcher_path, stitched_design_dir)
+        except Exception as e:
+            print(f"Error: {e}. Cannot copy 'ip_stitcher.tcl' to {nn_config['StitchedProjectName']} folder.")
+
+        # Verilog output bitwidths are rounded up and may differ from HLS output bitwidths
+        if nn_config['outputs'][0]['pragma'] == 'stream':
+            last_graph_project_path = os.path.join(
+                model.graphs[-1].config.get_output_dir(), model.graphs[-1].config.get_project_dir()
+            )
+            annotate_axis_stream_widths(nn_config, last_graph_project_path)
+        with open(nn_config_path, "w") as file:
+            json.dump(nn_config, file, indent=4)
+
+        if sim_stitched_design:
+            write_verilog_testbench(nn_config, testbench_path)
+            tb_inputs = prepare_tb_inputs(simulation_input_data, nn_config['inputs'])
+            model.write_tb_inputs(tb_inputs, stitched_design_dir)
+            print('Verilog testbench and its input data were generated.')
+
+        print('Running build process of stitched IP...\n')
+        stitch_command = [
+            'vivado',
+            '-mode',
+            'batch',
+            '-nojournal',
+            '-nolog',
+            '-notrace',
+            '-source',
+            ip_stitcher_path,
+            '-tclargs',
+            f'stitch_design={int(stitch_design)}',
+            f'sim_design={int(sim_stitched_design)}',
+            f'export_design={int(export_stitched_design)}',
+            f"stitch_project_name={nn_config['StitchedProjectName']}",
+            f"original_project_name={nn_config['OriginalProjectName']}",
+            'sim_verilog_file=testbench.v',
+        ]
+
+        with open(stdout_log, 'w') as stdout_file, open(stderr_log, 'w') as stderr_file:
+            process = subprocess.Popen(
+                stitch_command, cwd=stitched_design_dir, stdout=stdout_file, stderr=stderr_file, text=True, shell=False
+            )
+            process.communicate()
+            if process.returncode != 0:
+                raise Exception(f"Stitching failed for {nn_config['StitchedProjectName']}. See logs for details.")
+
+        stitched_report = {'StitchedDesignReport': {}}
+        if stitch_design:
+            stitched_report = aggregate_graph_reports(graph_reports)
+
+        if sim_stitched_design:
+            testbench_output = read_testbench_log(testbench_log_path, nn_config['outputs'])
+            stitched_report['BehavSimResults'] = testbench_output['BehavSimResults']
+            stitched_report['StitchedDesignReport']['BestLatency'] = testbench_output['BestLatency']
+            stitched_report['StitchedDesignReport']['WorstLatency'] = testbench_output['WorstLatency']
+
+        return stitched_report
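
The new `log_to_stdout` flag above makes it possible to run several subgraph builds concurrently without interleaving their Vitis HLS output, since each build writes to its own `build_stdout.log`/`build_stderr.log`. Below is a minimal sketch of such a parallel build (not part of the commit), assuming `hls_multigraph_model` is the `MultiModelGraph` from the documentation above and that `ModelGraph.build` forwards keyword arguments to the backend:

from concurrent.futures import ThreadPoolExecutor

def build_one(graph):
    # Each subgraph is built like a standalone ModelGraph project;
    # its Vitis HLS logs go to the subgraph's own output directory.
    return graph.build(synth=True, export=True, log_to_stdout=False)

with ThreadPoolExecutor(max_workers=4) as pool:
    graph_reports = list(pool.map(build_one, hls_multigraph_model.graphs))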

hls4ml/backends/vivado/passes/transform_types.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ def transform(self, model, node):
                 new_var = self.array_var_converter.convert(var, pragma='stream')
             elif io_type == 'io_parallel':
                 if out_name in node.model.inputs:
+                    # NOTE this needs to be changed to partition
                     new_var = self.array_var_converter.convert(var, pragma='reshape')
                 elif isinstance(var, InplaceTensorVariable):
                     new_var = self.inplace_array_var_converter.convert(var, pragma='')

hls4ml/converters/keras_v2_to_hls.py

Lines changed: 1 addition & 3 deletions
@@ -357,6 +357,4 @@ def parse_keras_model(model_arch, reader):
 def keras_v2_to_hls(config):
     model_arch, reader = get_model_arch(config)
     layer_list, input_layers, output_layers, _ = parse_keras_model(model_arch, reader)
-    print('Creating HLS model')
-    hls_model = ModelGraph.from_layer_list(config, layer_list, input_layers, output_layers)
-    return hls_model
+    return ModelGraph.from_layer_list(config, layer_list, input_layers, output_layers)

hls4ml/converters/pytorch_to_hls.py

Lines changed: 1 addition & 3 deletions
@@ -425,6 +425,4 @@ def parse_pytorch_model(config, verbose=True):
 @requires('_torch')
 def pytorch_to_hls(config):
     layer_list, input_layers, output_layers = parse_pytorch_model(config)
-    print('Creating HLS model')
-    hls_model = ModelGraph.from_layer_list(config, layer_list, inputs=input_layers, outputs=output_layers)
-    return hls_model
+    return ModelGraph.from_layer_list(config, layer_list, inputs=input_layers, outputs=output_layers)

hls4ml/model/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from hls4ml.model.graph import HLSConfig, ModelGraph  # noqa: F401
+from hls4ml.model.graph import HLSConfig, ModelGraph, MultiModelGraph, to_multi_model_graph  # noqa: F401
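
For reference, a short usage sketch of the newly exported names (hypothetical; 'fc2' is a placeholder split layer and `hls_model` an existing `ModelGraph`):

from hls4ml.model import MultiModelGraph, to_multi_model_graph

multi_graph = to_multi_model_graph(hls_model, ['fc2'])  # split the model at 'fc2'
assert isinstance(multi_graph, MultiModelGraph)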
