Skip to content

Commit 0da18f0

Browse files
Optimize the conv transpose resource implementation so that it works reasonably well. Some minor optimization may still remain.
1 parent 8092409 commit 0da18f0

File tree

1 file changed

+24
-12
lines changed

1 file changed

+24
-12
lines changed

hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,33 +30,41 @@ void conv_1d_transpose_resource_cl(
3030
typename CONFIG_T::accum_t acc[CONFIG_T::n_pixels][mult_n_out][CONFIG_T::stride_width];
3131
#pragma HLS ARRAY_PARTITION variable=acc complete dim=0
3232

33-
typename CONFIG_T::weight_t trfilt_weights[CONFIG_T::stride_width][
33+
typename CONFIG_T::weight_t trfilt_weights[
3434
CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
35-
];
35+
][CONFIG_T::stride_width];
3636

3737
for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
38+
#pragma HLS UNROLL
39+
3840
for (unsigned i_fw = 0; i_fw < CONFIG_T::trfilt_width; i_fw++) {
41+
#pragma HLS UNROLL
42+
3943
unsigned filt_ind = i_sw + (CONFIG_T::trfilt_width-i_fw-1)*CONFIG_T::stride_width;
4044
for (unsigned i_nf = 0; i_nf < CONFIG_T::n_filt; i_nf++) {
45+
#pragma HLS UNROLL
46+
4147
for (unsigned i_nc = 0; i_nc < CONFIG_T::n_chan; i_nc++) {
42-
if (i_fw < CONFIG_T::filt_width) {
43-
trfilt_weights[i_sw][
48+
#pragma HLS UNROLL
49+
50+
if (filt_ind < CONFIG_T::filt_width) {
51+
trfilt_weights[
4452
i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_width + i_fw * CONFIG_T::n_chan + i_nc
45-
] = weights[
53+
][i_sw] = weights[
4654
filt_ind * CONFIG_T::n_filt * CONFIG_T::n_chan + i_nf * CONFIG_T::n_chan + i_nc
4755
];
4856
}
4957
else {
50-
trfilt_weights[i_sw][
58+
trfilt_weights[
5159
i_fw * CONFIG_T::n_chan + i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_width + i_nc
52-
] = 0;
60+
][i_sw] = 0;
5361
}
5462
}
5563
}
5664
}
5765
}
5866

59-
#pragma HLS ARRAY_RESHAPE variable=trfilt_weights block factor=block_factor dim=2
67+
#pragma HLS ARRAY_RESHAPE variable=trfilt_weights block factor=block_factor dim=1
6068

6169
PartitionLoop:
6270
for (unsigned i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) {
@@ -101,7 +109,7 @@ void conv_1d_transpose_resource_cl(
101109

102110
acc[i_pxl][i_out][i_sw] += static_cast<typename CONFIG_T::accum_t>(
103111
CONFIG_T::mult_config::template product<data_T, typename CONFIG_T::mult_config::weight_t>::product(
104-
data_buf[i_pxl][i_in], trfilt_weights[i_sw][i_w]
112+
data_buf[i_pxl][i_in], trfilt_weights[i_w][i_sw]
105113
)
106114
);
107115
}
@@ -124,21 +132,25 @@ void conv_1d_transpose_resource_cl(
124132
}
125133
}
126134

135+
127136
PixelResultLoop:
128137
for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) {
129138
#pragma HLS UNROLL
130139

131140
StrideResultLoop:
132141
for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
133142
#pragma HLS UNROLL
143+
144+
unsigned output_index = i_pxl * CONFIG_T::n_partitions * CONFIG_T::stride_width +
145+
i_part * CONFIG_T::stride_width + i_sw;
134146

135-
if (i_pxl * CONFIG_T::n_partitions * CONFIG_T::stride_width + i_part * CONFIG_T::stride_width + i_sw >= CONFIG_T::pad_left &&
136-
i_pxl * CONFIG_T::n_partitions * CONFIG_T::stride_width + i_part * CONFIG_T::stride_width + i_sw < CONFIG_T::out_width + CONFIG_T::pad_left) {
147+
if (output_index >= CONFIG_T::pad_left &&
148+
output_index < CONFIG_T::out_width + CONFIG_T::pad_left) {
137149
ResultLoop:
138150
for (unsigned i_res = 0; i_res < mult_n_out; i_res++) {
139151
#pragma HLS UNROLL
140152

141-
*(res++) = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[i_pxl][i_res][i_sw]);
153+
res[output_index][i_res] = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[i_pxl][i_res][i_sw]);
142154
}
143155
}
144156
}

0 commit comments

Comments
 (0)