Skip to content

Commit d308304

Browse files
Optimizations for reading dat + copytree bugfix
1 parent 8a617ff commit d308304

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

hls4ml/backends/vitis_accelerator/vitis_accelerator_backend.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,13 @@ def build(self, model, target="all"):
7171
raise Exception("Currently untested on non-Linux OS")
7272

7373
def _numpy_to_dat(self, model, x):
74-
if len(self.get_input_variables()) != 1:
75-
raise Exception("Currently unsupported for multi-input projects")
74+
if len(model.get_input_variables()) != 1:
75+
raise Exception("Currently unsupported for multi-input/output projects")
7676

7777
# Verify numpy array of correct shape
78-
expected_shape = (np.newaxis, model.get_input_variables()[0].size())
79-
print(f"Expected model input shape: {expected_shape}")
80-
print(f"Give numpy array shape: {x.shape}")
81-
if expected_shape != x.shape:
82-
raise Exception(f'Input shape mismatch, got {x.shape}, expected {expected_shape}')
78+
expected_shape = model.get_input_variables()[0].size()
79+
if expected_shape != x.shape[-1]:
80+
raise Exception(f'Input shape mismatch, got {x.shape}, expected (_, {expected_shape})')
8381

8482
# Write to tb_data/tb_input_features.dat
8583
input_dat = open(f'{model.config.get_output_dir()}/tb_data/tb_input_features.dat', 'w')
@@ -90,16 +88,8 @@ def _numpy_to_dat(self, model, x):
9088

9189
def _dat_to_numpy(self, model):
9290
expected_shape = model.get_output_variables()[0].size()
93-
y = np.array([], dtype=float).reshape(0, expected_shape)
94-
95-
output_dat = open(f'{model.config.get_output_dir()}/tb_data/hw_results.dat', 'r')
96-
for line in output_dat.readlines():
97-
data = [list(map(float, line.strip().split()))]
98-
if len(data) != expected_shape:
99-
raise Exception('Error in output file. Does not match expected model output shape.')
100-
y = np.concatenate(y, np.array(data)[np.newaxis, :], axis=0)
101-
output_dat.close()
102-
91+
output_file = f'{model.config.get_output_dir()}/tb_data/hw_results.dat'
92+
y = np.loadtxt(output_file, dtype=float).reshape(-1, expected_shape)
10393
return y
10494

10595
def hardware_predict(self, model, x):

hls4ml/writer/vitis_accelerator_writer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from shutil import copy, copytree
2+
from shutil import copy, copytree, rmtree
33

44
from hls4ml.writer.vitis_writer import VitisWriter
55

@@ -165,6 +165,8 @@ def write_host(self, model):
165165
# Write libraries
166166
src = os.path.join(filedir, '../templates/vitis_accelerator/libs')
167167
dst = f'{model.config.get_output_dir()}/libs'
168+
if os.path.exists(dst):
169+
rmtree(dst)
168170
copytree(src, dst, copy_function=copy)
169171

170172
def write_makefile(self, model):

0 commit comments

Comments
 (0)