Skip to content

Commit b74b210

Browse files
change how stream loads in weights to be like parallel for conv transposes. unroll all stride steps completely
1 parent 645f8f4 commit b74b210

File tree

2 files changed

+107
-110
lines changed

2 files changed

+107
-110
lines changed

hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,42 +7,43 @@
77

88
namespace nnet {
99

10-
template<typename CONFIG_T>
11-
void weights_trim(
12-
typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
13-
typename CONFIG_T::weight_t row_weights[
14-
CONFIG_T::n_filt * CONFIG_T::trfilt_width * CONFIG_T::n_chan
10+
template <typename CONFIG_T>
11+
void load_trfilt_weights_1d(
12+
typename CONFIG_T::weight_t trfilt_weights[CONFIG_T::stride_width][
13+
CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
1514
],
16-
const int weight_start
15+
typename CONFIG_T::weight_t weights[
16+
CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt
17+
]
1718
)
1819
{
1920
#pragma HLS INLINE
20-
#pragma HLS PIPELINE II = 1
2121

22-
int row_indices[CONFIG_T::trfilt_width];
23-
for (int step = 0; step < CONFIG_T::trfilt_width; step++) {
24-
// #pragma HLS PIPELINE
22+
for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
2523
#pragma HLS UNROLL
26-
row_indices[step] = weight_start - step * CONFIG_T::stride_width;
27-
}
2824

29-
WeightsLoop: for (int step = 0; step < CONFIG_T::trfilt_width; step++) {
30-
#pragma HLS UNROLL
31-
#pragma HLS PIPELINE
32-
for (int filt_ind = 0; filt_ind < CONFIG_T::n_filt; filt_ind++) {
33-
#pragma HLS UNROLL
34-
#pragma HLS PIPELINE
35-
for (int chan_ind = 0; chan_ind < CONFIG_T::n_chan; chan_ind++) {
36-
#pragma HLS UNROLL
37-
#pragma HLS PIPELINE
38-
if (row_indices[step] >= CONFIG_T::filt_width) {
39-
row_weights[filt_ind * CONFIG_T::trfilt_width * CONFIG_T::n_chan +
40-
step * CONFIG_T::n_chan + chan_ind] = 0;
41-
} else {
42-
row_weights[filt_ind * CONFIG_T::trfilt_width * CONFIG_T::n_chan +
43-
step * CONFIG_T::n_chan + chan_ind] =
44-
weights[filt_ind * CONFIG_T::filt_width * CONFIG_T::n_chan +
45-
row_indices[step] * CONFIG_T::n_chan + chan_ind];
25+
for (unsigned i_fw = 0; i_fw < CONFIG_T::trfilt_width; i_fw++) {
26+
#pragma HLS UNROLL
27+
28+
unsigned filt_ind = i_sw + (CONFIG_T::trfilt_width-i_fw-1)*CONFIG_T::stride_width;
29+
for (unsigned i_nf = 0; i_nf < CONFIG_T::n_filt; i_nf++) {
30+
#pragma HLS UNROLL
31+
32+
for (unsigned i_nc = 0; i_nc < CONFIG_T::n_chan; i_nc++) {
33+
#pragma HLS UNROLL
34+
35+
if (filt_ind < CONFIG_T::filt_width) {
36+
trfilt_weights[i_sw][
37+
i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_width + i_fw * CONFIG_T::n_chan + i_nc
38+
] = weights[
39+
i_nf * CONFIG_T::n_chan * CONFIG_T::filt_width + filt_ind * CONFIG_T::n_chan + i_nc
40+
];
41+
}
42+
else {
43+
trfilt_weights[i_sw][
44+
i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_width + i_fw * CONFIG_T::n_chan + i_nc
45+
] = 0;
46+
}
4647
}
4748
}
4849
}
@@ -96,10 +97,6 @@ void compute_output_buffer_tr_1d(
9697
static typename data_T::value_type kernel_data[CONFIG_T::trfilt_width * CONFIG_T::n_chan];
9798
#pragma HLS ARRAY_PARTITION variable=kernel_data complete
9899

99-
typename CONFIG_T::weight_t row_weights[
100-
CONFIG_T::n_filt * CONFIG_T::trfilt_width * CONFIG_T::n_chan
101-
];
102-
103100
typename res_T::value_type res_out[CONFIG_T::n_filt];
104101
#pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0
105102

@@ -109,22 +106,23 @@ void compute_output_buffer_tr_1d(
109106
// Add pixel to buffer
110107
nnet::kernel_shift_tr_1d<data_T, CONFIG_T>(in_elem, kernel_data);
111108

112-
int weight_start = CONFIG_T::stride_width * (CONFIG_T::trfilt_width-1);
109+
static typename CONFIG_T::weight_t trfilt_weights[CONFIG_T::stride_width][
110+
CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
111+
];
112+
113+
load_trfilt_weights_1d<CONFIG_T>(trfilt_weights, weights);
113114

114115
//always do stride number of multiplications
115116
StrideLoop: for (int idx = 0; idx < CONFIG_T::stride_width; idx++) {
116-
//load in the weights for this multiplication
117-
weights_trim<CONFIG_T>(
118-
weights, row_weights, weight_start
119-
);
120-
117+
#pragma HLS UNROLL
118+
#pragma HLS INLINE region
121119
// Dense multiply
122120
if (CONFIG_T::strategy == nnet::latency) {
123121
dense_latency<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
124-
kernel_data, res_out, row_weights, biases);
122+
kernel_data, res_out, trfilt_weights[idx], biases);
125123
} else {
126124
dense_resource<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
127-
kernel_data, res_out, row_weights, biases);
125+
kernel_data, res_out, trfilt_weights[idx], biases);
128126
}
129127

130128
// Pack output
@@ -137,11 +135,10 @@ void compute_output_buffer_tr_1d(
137135
}
138136
// Write output to stream when output ready
139137
oX++;
140-
weight_start++;
138+
// weight_start++;
141139
}
142140

143141
// static var housekeeping
144-
// might need to zero the kernel? unsure...
145142
if (pX + 1 == CONFIG_T::in_width) // done with all of the inputs
146143
{
147144
pX = 0;

hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h

Lines changed: 67 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ void kernel_shift_tr_2d(
2121
#pragma HLS PIPELINE II = 1
2222
KernelShiftHeight: for (unsigned i_ih = 0; i_ih < CONFIG_T::trfilt_height; i_ih++) {
2323
KernelShiftChannel: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) {
24-
// Shift every element in kernel_window to the left
24+
// Shift every element in kernel_window to the left
2525
kernel_window[i_ih * CONFIG_T::trfilt_width * CONFIG_T::n_chan + i_iw * CONFIG_T::n_chan + i_ic] = kernel_window[i_ih * CONFIG_T::trfilt_width * CONFIG_T::n_chan + (i_iw + 1) * CONFIG_T::n_chan + i_ic];
2626
}
2727
}
@@ -69,53 +69,50 @@ void shift_line_buffer_tr(const data_T& in_elem,
6969
}
7070

7171
template<typename CONFIG_T>
72-
void load_tr_kern_weights(
73-
typename CONFIG_T::weight_t weights[CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
74-
typename CONFIG_T::weight_t kernel_weights[
75-
CONFIG_T::n_filt * CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan
72+
void load_trfilt_weights(
73+
typename CONFIG_T::weight_t trfilt_weights[CONFIG_T::stride_height][CONFIG_T::stride_width][
74+
CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
7675
],
77-
const int weight_x_start,
78-
const int weight_y_start
76+
typename CONFIG_T::weight_t weights[
77+
CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan * CONFIG_T::n_filt
78+
]
7979
)
8080
{
81-
int x_indices[CONFIG_T::trfilt_width];
82-
int y_indices[CONFIG_T::trfilt_height];
83-
for (int step = 0; step < CONFIG_T::trfilt_width; step++) {
84-
x_indices[step] = weight_x_start - step * CONFIG_T::stride_width;
85-
}
86-
for (int step = 0; step < CONFIG_T::trfilt_height; step++) {
87-
y_indices[step] = weight_y_start - step * CONFIG_T::stride_height;
88-
}
89-
90-
WeightsLoop: for (int x_step = 0; x_step < CONFIG_T::trfilt_width; x_step++) {
91-
#pragma HLS UNROLL
92-
#pragma HLS PIPELINE
93-
for (int y_step = 0; y_step < CONFIG_T::trfilt_height; y_step++) {
94-
#pragma HLS UNROLL
95-
#pragma HLS PIPELINE
96-
for (int filt_ind = 0; filt_ind < CONFIG_T::n_filt; filt_ind++) {
97-
#pragma HLS UNROLL
98-
#pragma HLS PIPELINE
99-
for (int chan_ind = 0; chan_ind < CONFIG_T::n_chan; chan_ind++) {
100-
#pragma HLS UNROLL
101-
#pragma HLS PIPELINE
102-
if (x_indices[x_step] >= CONFIG_T::filt_width || y_indices[y_step] >= CONFIG_T::filt_height) {
103-
kernel_weights[
104-
filt_ind * CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan +
105-
y_step * CONFIG_T::trfilt_width * CONFIG_T::n_chan +
106-
x_step * CONFIG_T::n_chan + chan_ind
107-
] = 0;
108-
}
109-
else {
110-
kernel_weights[
111-
filt_ind * CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan +
112-
y_step * CONFIG_T::trfilt_width * CONFIG_T::n_chan +
113-
x_step * CONFIG_T::n_chan + chan_ind
114-
] = weights[
115-
filt_ind * CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan +
116-
y_indices[y_step] * CONFIG_T::filt_width * CONFIG_T::n_chan +
117-
x_indices[x_step] * CONFIG_T::n_chan + chan_ind
118-
];
81+
#pragma HLS INLINE
82+
//pull out the individual filter weights (split kernel into stride_height x stride_width kernels)
83+
TrfiltWeightsLoop: for (unsigned i_sh = 0; i_sh < CONFIG_T::stride_height; i_sh++) {
84+
#pragma HLS UNROLL
85+
for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
86+
#pragma HLS UNROLL
87+
for (unsigned i_fh = 0; i_fh < CONFIG_T::trfilt_height; i_fh++) {
88+
#pragma HLS UNROLL
89+
for (unsigned i_fw = 0; i_fw < CONFIG_T::trfilt_width; i_fw++) {
90+
#pragma HLS UNROLL
91+
unsigned filt_h_ind = i_sh + (CONFIG_T::trfilt_height-i_fh-1)*CONFIG_T::stride_height;
92+
unsigned filt_w_ind = i_sw + (CONFIG_T::trfilt_width-i_fw-1)*CONFIG_T::stride_width;
93+
for (unsigned i_nf = 0; i_nf < CONFIG_T::n_filt; i_nf++) {
94+
#pragma HLS UNROLL
95+
for (unsigned i_nc = 0; i_nc < CONFIG_T::n_chan; i_nc++) {
96+
#pragma HLS UNROLL
97+
if (filt_h_ind < CONFIG_T::filt_height && filt_w_ind < CONFIG_T::filt_width) {
98+
trfilt_weights[i_sh][i_sw][
99+
i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_height * CONFIG_T::trfilt_width +
100+
i_fh * CONFIG_T::trfilt_width * CONFIG_T::n_chan +
101+
i_fw * CONFIG_T::n_chan + i_nc
102+
]= weights[
103+
i_nf * CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width +
104+
filt_h_ind * CONFIG_T::n_chan * CONFIG_T::filt_width +
105+
filt_w_ind * CONFIG_T::n_chan + i_nc
106+
];
107+
}
108+
else {
109+
trfilt_weights[i_sh][i_sw][
110+
i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_height * CONFIG_T::trfilt_width +
111+
i_fh * CONFIG_T::trfilt_width * CONFIG_T::n_chan +
112+
i_fw * CONFIG_T::n_chan + i_nc
113+
] = 0;
114+
}
115+
}
119116
}
120117
}
121118
}
@@ -141,64 +138,67 @@ void compute_output_buffer_tr_2d(
141138
static typename data_T::value_type kernel_data[CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan];
142139
#pragma HLS ARRAY_PARTITION variable=kernel_data complete
143140

144-
typename CONFIG_T::weight_t kernel_weights[
145-
CONFIG_T::n_filt * CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan
141+
static typename CONFIG_T::weight_t trfilt_weights[CONFIG_T::stride_height][CONFIG_T::stride_width][
142+
CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
146143
];
147144

145+
load_trfilt_weights<CONFIG_T>(trfilt_weights, weights);
146+
148147
typename res_T::value_type res_out[CONFIG_T::n_filt];
149148
#pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0
150149

151150
static typename res_T::value_type output_buffer[
152151
CONFIG_T::in_width*CONFIG_T::stride_width*CONFIG_T::stride_height*CONFIG_T::n_filt
153152
];
153+
#pragma HLS ARRAY_PARTITION variable=output_buffer complete dim = 0
154154

155155
res_T res_pack;
156156
#pragma HLS DATA_PACK variable = res_pack
157157

158158
//Add pixel to the buffer
159159
nnet::shift_line_buffer_tr<data_T, CONFIG_T>(in_elem, line_buffer, kernel_data);
160160

161-
int weight_x_start = CONFIG_T::stride_width * (CONFIG_T::trfilt_width-1);
162-
int weight_y_start = CONFIG_T::stride_height * (CONFIG_T::trfilt_height-1);
161+
HeightStrideLoop: for (int w_idx = 0; w_idx < CONFIG_T::stride_width; w_idx++) {
162+
// #pragma HLS PIPELINE
163+
#pragma HLS UNROLL
164+
WidthStrideLoop: for (int h_idx = 0; h_idx < CONFIG_T::stride_height; h_idx++) {
165+
#pragma HLS UNROLL
163166

164-
WidthStrideLoop: for (int h_idx = 0; h_idx < CONFIG_T::stride_height; h_idx++) {
165-
weight_x_start = CONFIG_T::stride_height * (CONFIG_T::trfilt_width-1);
166-
HeightStrideLoop: for (int w_idx = 0; w_idx < CONFIG_T::stride_width; w_idx++) {
167-
load_tr_kern_weights<CONFIG_T>(
168-
weights, kernel_weights, weight_x_start, weight_y_start
169-
);
167+
#pragma HLS INLINE region
170168

171169
if (CONFIG_T::strategy == nnet::latency) {
172170
dense_latency<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
173-
kernel_data, res_out, kernel_weights, biases
171+
kernel_data, res_out, trfilt_weights[h_idx][w_idx], biases
174172
);
175173
} else {
176174
dense_resource<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
177-
kernel_data, res_out, kernel_weights, biases
175+
kernel_data, res_out, trfilt_weights[h_idx][w_idx], biases
178176
);
179177
}
180178

181179
BufferOutputLoop: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) {
180+
#pragma HLS UNROLL
182181
output_buffer[
183182
(pX*CONFIG_T::stride_width+w_idx)*CONFIG_T::stride_height*CONFIG_T::n_filt +
184183
h_idx*CONFIG_T::n_filt + i_ic
185184
] = res_out[i_ic];
185+
// res_pack[i_ic] = res_out[i_ic];
186186
}
187+
// res_stream.write(res_pack);
187188

188-
weight_x_start++;
189189
}
190-
weight_y_start++;
191190
}
192191

193-
//Counter Housekeeping
194-
if (pX + 1 == CONFIG_T::in_width) //HAVE TO THINK ABOUT oX, oY STUFF. NOT AS EASY AS INCREMENTING
195-
{
192+
//Counter Housekeeping and printing buffered output
193+
if (pX + 1 == CONFIG_T::in_width) {
196194
pX = 0;
197-
//write all of the buffered output
198-
for (int h_idx = 0; h_idx < CONFIG_T::stride_height; h_idx++) {
195+
//write all of the buffered output for outputs we want
196+
HeightOutputLoop: for (unsigned h_idx = 0; h_idx < CONFIG_T::stride_height; h_idx++) {
197+
// #pragma HLS PIPELINE
199198
if (pY*CONFIG_T::stride_height + h_idx >= CONFIG_T::pad_top &&
200-
pY*CONFIG_T::stride_height +h_idx < CONFIG_T::pad_top + CONFIG_T::out_height) {
201-
for (int oX = CONFIG_T::pad_left; oX < CONFIG_T::pad_left + CONFIG_T::out_width; oX++) {
199+
pY*CONFIG_T::stride_height + h_idx < CONFIG_T::pad_top + CONFIG_T::out_height) {
200+
WidthOutputLoop: for (unsigned oX = CONFIG_T::pad_left; oX < CONFIG_T::pad_left + CONFIG_T::out_width; oX++) {
201+
#pragma HLS PIPELINE
202202
CastLoop: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) {
203203
#pragma HLS UNROLL
204204
res_pack[i_ic] = output_buffer[
@@ -219,7 +219,7 @@ void compute_output_buffer_tr_2d(
219219
} else {
220220
pX = pX + 1;
221221
}
222-
222+
223223
}
224224

225225
template<class data_T, class res_T, typename CONFIG_T>

0 commit comments

Comments
 (0)