Skip to content

Commit dd32d3b

Browse files
committed
update vitis nnet_pooling with some changes from vivado backend
1 parent 2cb6fe1 commit dd32d3b

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

hls4ml/templates/vitis/nnet_utils/nnet_pooling.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct pooling1d_config {
7070
static const unsigned n_out = (n_in - pool_width) / stride_width + 1;
7171
static const unsigned pad_left = 0;
7272
static const unsigned pad_right = 0;
73+
static const bool count_pad = false;
7374
// Pooling function
7475
static const Pool_Op pool_op = Max;
7576
};
@@ -130,6 +131,7 @@ void global_pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T r
130131

131132
for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
132133
data_T pool[CONFIG_T::n_in];
134+
#pragma HLS ARRAY_PARTITION variable=pool complete dim=0
133135
for (int jj = 0; jj < CONFIG_T::n_in; jj++) {
134136
pool[jj] = data[jj * CONFIG_T::n_filt + ff];
135137
}
@@ -154,6 +156,7 @@ struct pooling2d_config {
154156
static const unsigned pad_bottom = 0;
155157
static const unsigned pad_left = 0;
156158
static const unsigned pad_right = 0;
159+
static const bool count_pad = false;
157160
// Pooling function
158161
static const Pool_Op pool_op = Max;
159162
// Reuse factor
@@ -245,6 +248,7 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
245248
// Loop over input image x in steps of stride
246249
for (int jj = 0; jj < padded_width; jj += CONFIG_T::stride_width) {
247250
data_T pool[CONFIG_T::pool_height * CONFIG_T::pool_width];
251+
#pragma HLS ARRAY_PARTITION variable=pool complete dim=0
248252
// Keep track of number of pixels in image vs padding region
249253
unsigned img_overlap = 0;
250254
// Loop over pool window y
@@ -255,10 +259,12 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
255259
jj + ll < CONFIG_T::pad_left || jj + ll >= (padded_width - CONFIG_T::pad_right)) {
256260
// Add padding
257261
pool[kk * CONFIG_T::stride_width + ll] = pad_val<data_T, CONFIG_T::pool_op>();
262+
if (CONFIG_T::count_pad)
263+
img_overlap++;
258264
} else {
259265
pool[kk * CONFIG_T::stride_width + ll] =
260-
data[(ii + kk) * CONFIG_T::in_width + ff * CONFIG_T::in_width * CONFIG_T::in_height + ll +
261-
jj];
266+
data[(ii + kk - CONFIG_T::pad_top) * CONFIG_T::in_width +
267+
ff * CONFIG_T::in_width * CONFIG_T::in_height + ll + jj - CONFIG_T::pad_left];
262268
img_overlap++;
263269
}
264270
}

0 commit comments

Comments
 (0)