Skip to content

Commit 39d0e91

Browse files
authored
Merge pull request #1089 from vloncar/weight_txt_path
Hardcore weight txt path
2 parents 352c124 + 12034d3 commit 39d0e91

File tree

4 files changed

+20
-24
lines changed

4 files changed

+20
-24
lines changed

hls4ml/model/graph.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -763,32 +763,24 @@ def predict(self, x):
763763
n_inputs = len(self.get_input_variables())
764764
n_outputs = len(self.get_output_variables())
765765

766-
curr_dir = os.getcwd()
767-
os.chdir(self.config.get_output_dir() + '/firmware')
768-
769766
output = []
770767
if n_samples == 1 and n_inputs == 1:
771768
x = [x]
772769

773-
try:
774-
for i in range(n_samples):
775-
predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()]
776-
if n_inputs == 1:
777-
inp = [np.asarray(x[i])]
778-
else:
779-
inp = [np.asarray(xj[i]) for xj in x]
780-
argtuple = inp
781-
argtuple += predictions
782-
argtuple = tuple(argtuple)
783-
top_function(*argtuple)
784-
output.append(predictions)
785-
786-
# Convert to list of numpy arrays (one for each output)
787-
output = [
788-
np.asarray([output[i_sample][i_output] for i_sample in range(n_samples)]) for i_output in range(n_outputs)
789-
]
790-
finally:
791-
os.chdir(curr_dir)
770+
for i in range(n_samples):
771+
predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()]
772+
if n_inputs == 1:
773+
inp = [np.asarray(x[i])]
774+
else:
775+
inp = [np.asarray(xj[i]) for xj in x]
776+
argtuple = inp
777+
argtuple += predictions
778+
argtuple = tuple(argtuple)
779+
top_function(*argtuple)
780+
output.append(predictions)
781+
782+
# Convert to list of numpy arrays (one for each output)
783+
output = [np.asarray([output[i_sample][i_output] for i_sample in range(n_samples)]) for i_output in range(n_outputs)]
792784

793785
if n_samples == 1 and n_outputs == 1:
794786
return output[0][0]

hls4ml/templates/catapult/myproject_bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <algorithm>
77
#include <map>
88

9-
static std::string s_weights_dir = "weights";
9+
// hls-fpga-machine-learning insert weights dir
1010

1111
const char *get_weights_dir() { return s_weights_dir.c_str(); }
1212

hls4ml/templates/vivado/build_lib.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ LDFLAGS=
1111
INCFLAGS="-Ifirmware/ap_types/"
1212
PROJECT=myproject
1313
LIB_STAMP=mystamp
14-
WEIGHTS_DIR="\"weights\""
14+
BASEDIR="$(cd "$(dirname "$0")" && pwd)"
15+
WEIGHTS_DIR="\"${BASEDIR}/firmware/weights\""
1516

1617
${CC} ${CFLAGS} ${INCFLAGS} -D WEIGHTS_DIR=${WEIGHTS_DIR} -c firmware/${PROJECT}.cpp -o ${PROJECT}.o
1718
${CC} ${CFLAGS} ${INCFLAGS} -D WEIGHTS_DIR=${WEIGHTS_DIR} -c ${PROJECT}_bridge.cpp -o ${PROJECT}_bridge.o

hls4ml/writer/catapult_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,9 @@ def write_bridge(self, model):
676676
newline = line.replace('MYPROJECT', format(model.config.get_project_name().upper()))
677677
elif 'myproject' in line:
678678
newline = line.replace('myproject', format(model.config.get_project_name()))
679+
elif '// hls-fpga-machine-learning insert weights dir' in line:
680+
weights_dir = (Path(fout.name).parent / 'firmware/weights').resolve()
681+
newline = f'static std::string s_weights_dir = "{weights_dir}";\n'
679682
elif '// hls-fpga-machine-learning insert bram' in line:
680683
newline = line
681684
for bram in model_brams:

0 commit comments

Comments
 (0)