Skip to content

Commit 89dc007

Browse files
calad0i authored and vloncar committed
fix merge templates
1 parent 395bcae commit 89dc007

File tree

9 files changed, with 32 additions and 27 deletions.

hls4ml/templates/catapult/nnet_utils/nnet_merge.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -58,21 +58,21 @@ void multiply(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem]
5858
template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
5959
void average(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) {
6060
for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
61-
res[ii] = (data1[ii] + data2[ii]) / (res_T)2;
61+
res[ii] = (data1[ii] + data2[ii]) * ac_fixed<1, 0, false>(0.5);
6262
}
6363
}
6464

6565
template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
6666
void maximum(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) {
6767
for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
68-
res[ii] = (data1[ii] > data2[ii]) ? data1[ii] : data2[ii];
68+
res[ii] = (data1[ii] > data2[ii]) ? static_cast<res_T>(data1[ii]) : static_cast<res_T>(data2[ii]);
6969
}
7070
}
7171

7272
template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
7373
void minimum(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) {
7474
for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
75-
res[ii] = (data1[ii] < data2[ii]) ? data1[ii] : data2[ii];
75+
res[ii] = (data1[ii] < data2[ii]) ? static_cast<res_T>(data1[ii]) : static_cast<res_T>(data2[ii]);
7676
}
7777
}
7878

hls4ml/templates/catapult/nnet_utils/nnet_merge_stream.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ void average(ac_channel<input1_T> &data1, ac_channel<input2_T> &data2, ac_channe
9797
AveragePack:
9898
for (int j = 0; j < res_T::size; j++) {
9999
// #pragma HLS UNROLL
100-
out_data[j] = (in_data1[j] + in_data2[j]) / (typename res_T::value_type)2;
100+
out_data[j] = (in_data1[j] + in_data2[j]) * ac_fixed<1, 0, false>(0.5);
101101
}
102102

103103
res.write(out_data);
@@ -122,7 +122,7 @@ void maximum(ac_channel<input1_T> &data1, ac_channel<input2_T> &data2, ac_channe
122122
MaximumPack:
123123
for (int j = 0; j < res_T::size; j++) {
124124
// #pragma HLS UNROLL
125-
out_data[j] = (in_data1[j] > in_data2[j]) ? in_data1[j] : in_data2[j];
125+
out_data[j] = (in_data1[j] > in_data2[j]) ? static_cast<res_T>(in_data1[j]) : static_cast<res_T>(in_data2[j]);
126126
}
127127

128128
res.write(out_data);
@@ -147,7 +147,7 @@ void minimum(ac_channel<input1_T> &data1, ac_channel<input2_T> &data2, ac_channe
147147
MinimumPack:
148148
for (int j = 0; j < res_T::size; j++) {
149149
// #pragma HLS UNROLL
150-
out_data[j] = (in_data1[j] < in_data2[j]) ? in_data1[j] : in_data2[j];
150+
out_data[j] = (in_data1[j] < in_data2[j]) ? static_cast<res_T>(in_data1[j]) : static_cast<res_T>(in_data2[j]);
151151
}
152152

153153
res.write(out_data);

hls4ml/templates/oneapi/firmware/nnet_utils/nnet_merge.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -68,15 +68,17 @@ template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
6868
void maximum(const input1_T &data1, const input2_T &data2, res_T &res) {
6969
#pragma unroll
7070
for (int i = 0; i < CONFIG_T::n_elem; i++) {
71-
res[i] = static_cast<typename res_T::value_type>((data1[i] > data2[i]) ? data1[i] : data2[i]);
71+
res[i] = static_cast<typename res_T::value_type>((data1[i] > data2[i]) ? static_cast<res_T>(data1[i])
72+
: static_cast<res_T>(data2[i]));
7273
}
7374
}
7475

7576
template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
7677
void minimum(const input1_T &data1, const input2_T &data2, res_T &res) {
7778
#pragma unroll
7879
for (int i = 0; i < CONFIG_T::n_elem; i++) {
79-
res[i] = static_cast<typename res_T::value_type>((data1[i] < data2[i]) ? data1[i] : data2[i]);
80+
res[i] = static_cast<typename res_T::value_type>((data1[i] < data2[i]) ? static_cast<res_T>(data1[i])
81+
: static_cast<res_T>(data2[i]));
8082
}
8183
}
8284

hls4ml/templates/oneapi/firmware/nnet_utils/nnet_merge_stream.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ template <class input1_pipe, class input2_pipe, class res_pipe, typename CONFIG_
8585
#pragma unroll
8686
for (int j = 0; j < outputSize; j++) {
8787
out_data[j] = static_cast<typename ExtractPipeType<res_pipe>::value_type::value_type>(
88-
(in_data1[j] + in_data2[j]) / (typename ExtractPipeType<res_pipe>::value_type::value_type)2);
88+
(in_data1[j] + in_data2[j]) * ac_fixed<1, 0, false>(0.5));
8989
}
9090

9191
res_pipe::write(out_data);
@@ -108,7 +108,7 @@ template <class input1_pipe, class input2_pipe, class res_pipe, typename CONFIG_
108108
#pragma unroll
109109
for (int j = 0; j < outputSize; j++) {
110110
out_data[j] = static_cast<typename ExtractPipeType<res_pipe>::value_type::value_type>(
111-
(in_data1[j] > in_data2[j]) ? in_data1[j] : in_data2[j]);
111+
(in_data1[j] > in_data2[j]) ? static_cast<res_T>(in_data1[j]) : static_cast<res_T>(in_data2[j]));
112112
}
113113

114114
res_pipe::write(out_data);
@@ -131,7 +131,7 @@ template <class input1_pipe, class input2_pipe, class res_pipe, typename CONFIG_
131131
#pragma unroll
132132
for (int j = 0; j < outputSize; j++) {
133133
out_data[j] = static_cast<typename ExtractPipeType<res_pipe>::value_type::value_type>(
134-
(in_data1[j] < in_data2[j]) ? in_data1[j] : in_data2[j]);
134+
(in_data1[j] < in_data2[j]) ? static_cast<res_T>(in_data1[j]) : static_cast<res_T>(in_data2[j]););
135135
}
136136

137137
res_pipe::write(out_data);

hls4ml/templates/quartus/firmware/nnet_utils/nnet_merge.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
6060
void average(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) {
6161
#pragma unroll
6262
for (int i = 0; i < CONFIG_T::n_elem; i++) {
63-
res[i] = static_cast<res_T>((data1[i] + data2[i]) / (res_T)2);
63+
res[i] = static_cast<res_T>((data1[i] + data2[i]) * ac_fixed<1, 0, false>(0.5));
6464
}
6565
}
6666

hls4ml/templates/quartus/firmware/nnet_utils/nnet_merge_stream.h

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -84,8 +84,7 @@ void average(stream<input1_T> &data1, stream<input2_T> &data2, stream<res_T> &re
8484
AvgPack:
8585
#pragma unroll
8686
for (int j = 0; j < res_T::size; j++) {
87-
out_data[j] =
88-
static_cast<typename res_T::value_type>((in_data1[j] + in_data2[j]) / (typename res_T::value_type)2);
87+
out_data[j] = static_cast<typename res_T::value_type>((in_data1[j] + in_data2[j]) * ac_fixed<1, 0, false>(0.5));
8988
}
9089

9190
res.write(out_data);
@@ -107,8 +106,8 @@ void maximum(stream<input1_T> &data1, stream<input2_T> &data2, stream<res_T> &re
107106
MaxPack:
108107
#pragma unroll
109108
for (int j = 0; j < res_T::size; j++) {
110-
out_data[j] = static_cast<typename res_T::value_type>(out_data[j] = (in_data1[j] > in_data2[j]) ? in_data1[j]
111-
: in_data2[j]);
109+
out_data[j] = (in_data1[j] > in_data2[j]) ? static_cast<typename res_T::value_type>(in_data1[j])
110+
: static_cast<typename res_T::value_type>(in_data2[j]);
112111
}
113112

114113
res.write(out_data);
@@ -130,8 +129,8 @@ void minimum(stream<input1_T> &data1, stream<input2_T> &data2, stream<res_T> &re
130129
MinPack:
131130
#pragma unroll
132131
for (int j = 0; j < res_T::size; j++) {
133-
out_data[j] = static_cast<typename res_T::value_type>(out_data[j] = (in_data1[j] < in_data2[j]) ? in_data1[j]
134-
: in_data2[j]);
132+
out_data[j] = (in_data1[j] < in_data2[j]) ? static_cast<typename res_T::value_type>(in_data1[j])
133+
: static_cast<typename res_T::value_type>(in_data2[j]);
135134
}
136135

137136
res.write(out_data);

hls4ml/templates/vivado/nnet_utils/nnet_merge.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ void average(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem],
6565
#pragma HLS PIPELINE
6666

6767
for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
68-
res[ii] = (data1[ii] + data2[ii]) / (res_T)2;
68+
res[ii] = (data1[ii] + data2[ii]) * ap_ufixed<1, 0>(0.5);
6969
}
7070
}
7171

@@ -74,7 +74,7 @@ void maximum(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem],
7474
#pragma HLS PIPELINE
7575

7676
for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
77-
res[ii] = (data1[ii] > data2[ii]) ? data1[ii] : data2[ii];
77+
res[ii] = (data1[ii] > data2[ii]) ? static_cast<res_T>(data1[ii]) : static_cast<res_T>(data2[ii]);
7878
}
7979
}
8080

@@ -83,7 +83,7 @@ void minimum(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem],
8383
#pragma HLS PIPELINE
8484

8585
for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
86-
res[ii] = (data1[ii] < data2[ii]) ? data1[ii] : data2[ii];
86+
res[ii] = (data1[ii] < data2[ii]) ? static_cast<res_T>(data1[ii]) : static_cast<res_T>(data2[ii]);
8787
}
8888
}
8989

hls4ml/templates/vivado/nnet_utils/nnet_merge_stream.h

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ void average(hls::stream<input1_T> &data1, hls::stream<input2_T> &data2, hls::st
9292
AveragePack:
9393
for (int j = 0; j < res_T::size; j++) {
9494
#pragma HLS UNROLL
95-
out_data[j] = (in_data1[j] + in_data2[j]) / (typename res_T::value_type)2;
95+
out_data[j] = (in_data1[j] + in_data2[j]) * ap_ufixed<1, 0>(0.5);
9696
}
9797

9898
res.write(out_data);
@@ -115,7 +115,8 @@ void maximum(hls::stream<input1_T> &data1, hls::stream<input2_T> &data2, hls::st
115115
MaximumPack:
116116
for (int j = 0; j < res_T::size; j++) {
117117
#pragma HLS UNROLL
118-
out_data[j] = (in_data1[j] > in_data2[j]) ? in_data1[j] : in_data2[j];
118+
out_data[j] = (in_data1[j] > in_data2[j]) ? static_cast<typename res_T::value_type>(in_data1[j])
119+
: static_cast<typename res_T::value_type>(in_data2[j]);
119120
}
120121

121122
res.write(out_data);
@@ -138,7 +139,8 @@ void minimum(hls::stream<input1_T> &data1, hls::stream<input2_T> &data2, hls::st
138139
MinimumPack:
139140
for (int j = 0; j < res_T::size; j++) {
140141
#pragma HLS UNROLL
141-
out_data[j] = (in_data1[j] < in_data2[j]) ? in_data1[j] : in_data2[j];
142+
out_data[j] = (in_data1[j] < in_data2[j]) ? static_cast<typename res_T::value_type>(in_data1[j])
143+
: static_cast<typename res_T::value_type>(in_data2[j]);
142144
}
143145

144146
res.write(out_data);

test/pytest/test_merge.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,8 @@
1717
def test_merge(merge_layer, io_type, backend, swap_inputs):
1818
input_shape = (10, 10, 3)
1919

20-
in1 = Input(shape=input_shape)
21-
in2 = Input(shape=input_shape)
20+
in1 = Input(shape=input_shape, name='inp1')
21+
in2 = Input(shape=input_shape, name='inp2')
2222
if swap_inputs:
2323
out = merge_layer()([in2, in1])
2424
else:
@@ -27,11 +27,13 @@ def test_merge(merge_layer, io_type, backend, swap_inputs):
2727
model = tf.keras.models.Model(inputs=[in1, in2], outputs=out)
2828
model.compile()
2929

30-
config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,16>')
3130
output_dir = str(
3231
test_root_path
3332
/ f'hls4mlprj_merge_{"swap_inputs_" if swap_inputs else ""}{merge_layer.__name__.lower()}_{backend}_{io_type}'
3433
)
34+
35+
config = {'Model': {'Precision': 'fixed<32,16>', 'ReuseFactor': 1}, 'LayerName': {'inp2': {'Precision': 'fixed<32,15>'}}}
36+
3537
hls_model = hls4ml.converters.convert_from_keras_model(
3638
model, hls_config=config, output_dir=output_dir, io_type=io_type, backend=backend
3739
)

0 commit comments

Comments
 (0)