Skip to content

Commit 6073263

Browse files
add conv2dtranspose io_parallel implementation. Can still be optimized
1 parent 12ba91e commit 6073263

File tree

6 files changed

+343
-16
lines changed

6 files changed

+343
-16
lines changed

hls4ml/backends/fpga/fpga_backend.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,91 @@ def generate_conv2d_line_buffer_fn(self, layer_idx, n_partitions, in_H, in_W, in
656656

657657
return generated_code
658658

659+
def _compute_conv2d_tr_im2col(self, input_shape, out_shape, kernel=(3, 3), stride=(1, 1)):
660+
H, W, C = input_shape
661+
kernel_h, kernel_w = kernel
662+
stride_h, stride_w = stride
663+
out_h, out_w = out_shape
664+
665+
tr_kernel_h = (kernel_h+stride_h-1)//stride_h
666+
tr_kernel_w = (kernel_w+stride_w-1)//stride_w
667+
668+
input_img = np.arange(1, H * W * C + 1)
669+
im_matrix = np.zeros((tr_kernel_h * tr_kernel_w * C * out_h * out_w, ))
670+
671+
index = 0
672+
for i_oh in range(out_h):
673+
for i_ow in range(out_w):
674+
for i_kh in range(tr_kernel_h):
675+
input_row = i_oh - (tr_kernel_h-1) + i_kh
676+
for i_kw in range(tr_kernel_w):
677+
for i_c in range(C):
678+
if (input_row < 0 or input_row >= H):
679+
im_matrix[index] = 0
680+
else:
681+
input_col = i_ow - (tr_kernel_w-1) + i_kw
682+
if (input_col >= 0 and input_col < W):
683+
im_matrix[index] = input_img[input_row * W * C + input_col * C + i_c]
684+
else:
685+
im_matrix[index] = 0
686+
index += 1
687+
688+
im_matrix = im_matrix.reshape(out_h * out_w, -1)
689+
return im_matrix
690+
691+
692+
def generate_conv2d_tr_line_buffer_fn(self, layer_idx, n_partitions, in_H, in_W, in_C, out_H, out_W, kernel=(3, 3), stride=(1, 1)):
693+
if isinstance(kernel, Iterable):
694+
kernel_height = kernel[0]
695+
kernel_width = kernel[1]
696+
else:
697+
kernel_height = kernel
698+
kernel_width = kernel
699+
700+
if isinstance(stride, Iterable):
701+
stride_height = stride[0]
702+
stride_width = stride[1]
703+
else:
704+
stride_height = stride
705+
stride_width = stride
706+
707+
im2col_matrix = self._compute_conv2d_tr_im2col(
708+
(in_H, in_W, in_C),
709+
(out_W, out_W),
710+
(kernel_height, kernel_width),
711+
(stride_height, stride_width),
712+
)
713+
714+
generated_code = (
715+
"template<class data_T, typename CONFIG_T>\n"
716+
"class fill_buffer_{index} : public FillConv2DBuffer<data_T, CONFIG_T> {{\n"
717+
" public:\n"
718+
" static void fill_buffer(\n"
719+
" data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan],\n"
720+
" data_T buffer[CONFIG_T::n_pixels][CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan],\n"
721+
" const unsigned partition\n"
722+
" ) {{\n"
723+
).format(index=layer_idx)
724+
indent = ' '
725+
726+
for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)):
727+
generated_code += indent * 2 + 'if (partition == {:>3}) {{\n'.format(partition_idx)
728+
for pixel_idx, arr in enumerate(partition):
729+
buffer_stmts = []
730+
for j, v in enumerate(arr):
731+
if v == 0:
732+
val = '0'
733+
else:
734+
val = 'data[{}]'.format(int(v-1))
735+
buffer_stmts.append('buffer[{}][{}] = {:>10};'.format(pixel_idx, j, val))
736+
generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n'
737+
generated_code += '\n' + indent * 2 + '}\n'
738+
739+
generated_code += indent + '}\n'
740+
generated_code += '};\n'
741+
742+
return generated_code
743+
659744
@model_optimizer()
660745
def write_hls(self, model):
661746
self.writer.write_hls(model)

hls4ml/backends/fpga/passes/codegen.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from hls4ml.model.optimizer import OptimizerPass
2-
from hls4ml.model.layers import Conv1D, Conv2D, Conv1DTranspose
2+
from hls4ml.model.layers import Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose
33
from hls4ml.model.types import Source
44

55
class GenerateConvIm2col(OptimizerPass):
66
''' Generates tcode for im2col step of 1D/2d convolution '''
77
def match(self, node):
8-
return isinstance(node, (Conv1D, Conv2D, Conv1DTranspose)) and \
8+
return isinstance(node, (Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose)) and \
99
node.model.config.get_config_value('IOType') == 'io_parallel'
1010

1111
def transform(self, model, node):
@@ -14,6 +14,8 @@ def transform(self, model, node):
1414
self._generate_im2col_1d_transpose(node)
1515
elif '1D' in node_class:
1616
self._generate_im2col_1d(node)
17+
elif '2DTranspose' in node_class:
18+
self._generate_im2col_2d_transpose(node)
1719
elif '2D' in node_class:
1820
self._generate_im2col_2d(node)
1921
else:
@@ -38,7 +40,7 @@ def _generate_im2col_1d_transpose(self, node):
3840
node.get_attr('n_partitions'),
3941
node.get_input_variable().shape[0],
4042
node.get_input_variable().shape[1],
41-
node.get_attr('num_out'),
43+
node.get_attr('proc_width'),
4244
kernel=node.get_attr('filt_width'),
4345
stride=node.get_attr('stride_width'),
4446
)
@@ -58,3 +60,18 @@ def _generate_im2col_2d(self, node):
5860
)
5961

6062
node.set_attr('line_buffer_codegen', Source(code_str))
63+
64+
def _generate_im2col_2d_transpose(self, node):
65+
code_str = node.model.config.backend.generate_conv2d_tr_line_buffer_fn(
66+
node.get_attr('index'),
67+
node.get_attr('n_partitions'),
68+
node.get_input_variable().shape[0],
69+
node.get_input_variable().shape[1],
70+
node.get_input_variable().shape[2],
71+
node.get_attr('proc_height'),
72+
node.get_attr('proc_width'),
73+
kernel=(node.get_attr('filt_height'), node.get_attr('filt_width')),
74+
stride=(node.get_attr('stride_height'), node.get_attr('stride_width')),
75+
)
76+
77+
node.set_attr('line_buffer_codegen', Source(code_str))

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,18 +111,18 @@ def format(self, node):
111111
static const unsigned strategy = nnet::{strategy};
112112
static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
113113
static const unsigned min_width = {min_width};
114-
static const ap_uint<filt_width> pixels[min_width];
114+
static const ap_uint<trfilt_width> pixels[min_width];
115115
static const unsigned n_partitions = {n_partitions};
116-
static const unsigned num_out = {num_out};
117-
static const unsigned n_pixels = num_out / n_partitions;
116+
static const unsigned proc_width = {proc_width};
117+
static const unsigned n_pixels = proc_width / n_partitions;
118118
template<class data_T, class CONFIG_T>
119119
using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;
120120
typedef {accum_t.name} accum_t;
121121
typedef {bias_t.name} bias_t;
122122
typedef {weight_t.name} weight_t;
123123
typedef {config_t} mult_config;
124124
}};
125-
const ap_uint<config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""
125+
const ap_uint<config{index}::trfilt_width> config{index}::pixels[] = {{{instructions}}};\n"""
126126

127127
conv1dtranspose_function_template = 'nnet::conv_1d_transpose_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
128128

@@ -282,13 +282,19 @@ def __init__(self):
282282
static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
283283
static const unsigned min_height = {min_height};
284284
static const unsigned min_width = {min_width};
285-
static const ap_uint<filt_height * filt_width> pixels[min_height * min_width];
285+
static const ap_uint<trfilt_height * trfilt_width> pixels[min_height * min_width];
286+
static const unsigned n_partitions = {n_partitions};
287+
static const unsigned proc_height = {proc_height};
288+
static const unsigned proc_width = {proc_width};
289+
static const unsigned n_pixels = proc_height * proc_width / n_partitions;
290+
template<class data_T, class CONFIG_T>
291+
using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;
286292
typedef {accum_t.name} accum_t;
287293
typedef {bias_t.name} bias_t;
288294
typedef {weight_t.name} weight_t;
289295
typedef {config_t} mult_config;
290296
}};
291-
const ap_uint<config{index}::filt_height * config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""
297+
const ap_uint<config{index}::trfilt_height * config{index}::trfilt_width> config{index}::pixels[] = {{{instructions}}};\n"""
292298

293299
conv2dtranspose_function_template = 'nnet::conv_2d_transpose_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
294300

@@ -310,6 +316,10 @@ def format(self, node):
310316
// node.get_attr('stride_height')
311317

312318
params['config_t'] = 'config{}_mult'.format(node.index)
319+
if node.model.config.get_config_value('IOType') == 'io_parallel':
320+
params['fill_fn'] = 'fill_buffer_{}'.format(node.index)
321+
else:
322+
params['fill_fn'] = 'FillConv2DBuffer'
313323
conv_config = self.template.format(**params)
314324

315325
mult_params = self._default_config_params(node)

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,18 @@ def init_conv1dtranspose(self, layer):
190190
layer.set_attr('strategy', 'latency')
191191

192192
in_width = layer.get_input_variable().shape[0]
193-
num_out = 1 + in_width + (layer.get_output_variable().shape[1] + layer.get_attr('pad_left'))//layer.get_attr('stride_width')
193+
proc_width = (layer.get_output_variable().shape[0] + layer.get_attr('pad_left') + layer.get_attr('stride_width')-1) \
194+
// layer.get_attr('stride_width')
194195
chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1)
195-
valid_pf = self.get_valid_conv_partition_splits(1, num_out)
196+
valid_pf = self.get_valid_conv_partition_splits(1, proc_width)
196197
if chosen_pf not in valid_pf:
197198
closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf)
198199
print('WARNING: Invalid ParallelizationFactor={} in layer "{}". Using ParallelizationFactor={} instead. Valid ParallelizationFactor(s): {}.'
199200
.format(chosen_pf, layer.name, closest_pf, ','.join(map(str, valid_pf))))
200201
else:
201202
closest_pf = chosen_pf
202-
layer.set_attr('n_partitions', num_out // closest_pf)
203-
layer.set_attr('num_out', num_out)
203+
layer.set_attr('n_partitions', proc_width // closest_pf)
204+
layer.set_attr('proc_width', proc_width)
204205

205206
layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower())
206207

@@ -247,7 +248,7 @@ def init_conv2d(self, layer):
247248
self._validate_conv_strategy(layer)
248249

249250
@layer_optimizer(Conv2DTranspose)
250-
def init_conv2d(self, layer):
251+
def init_conv2dtranspose(self, layer):
251252
if len(layer.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D
252253
layer.weights['weight'].data = np.expand_dims(layer.weights['weight'].data, axis=(0,1))
253254

@@ -259,8 +260,29 @@ def init_conv2d(self, layer):
259260
else:
260261
layer.set_attr('strategy', 'latency')
261262

263+
in_height = layer.get_input_variable().shape[0]
264+
in_width = layer.get_input_variable().shape[1]
265+
266+
proc_height = (layer.get_output_variable().shape[0] + layer.get_attr('pad_top') + layer.get_attr('stride_height')-1) \
267+
// layer.get_attr('stride_height')
268+
proc_width = (layer.get_output_variable().shape[1] + layer.get_attr('pad_left') + layer.get_attr('stride_width')-1) \
269+
// layer.get_attr('stride_width')
270+
chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1)
271+
valid_pf = self.get_valid_conv_partition_splits(proc_height, proc_width)
272+
if chosen_pf not in valid_pf:
273+
closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf)
274+
print('WARNING: Invalid ParallelizationFactor={} in layer "{}". Using ParallelizationFactor={} instead. Valid ParallelizationFactor(s): {}.'
275+
.format(chosen_pf, layer.name, closest_pf, ','.join(map(str, valid_pf))))
276+
else:
277+
closest_pf = chosen_pf
278+
layer.set_attr('n_partitions', proc_height * proc_width // closest_pf)
279+
layer.set_attr('proc_height', proc_height)
280+
layer.set_attr('proc_width', proc_width)
281+
262282
layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower())
263283

284+
self._validate_conv_strategy(layer)
285+
264286
@layer_optimizer(SeparableConv2D)
265287
def init_sepconv2d(self, layer):
266288
if layer.model.config.is_resource_strategy(layer):

hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define NNET_CONV2DTRANSPOSE_H
33

44
#include "nnet_common.h"
5+
#include "nnet_conv2dtranspose_resource.h"
56
#include <cstdlib>
67

78
namespace nnet{
@@ -40,14 +41,16 @@ struct conv2dtranspose_config
4041
};
4142

4243
template<class data_T, class res_T, typename CONFIG_T>
43-
void conv_2d_cl(
44+
void conv_2d_transpose_cl(
4445
data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan],
4546
res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt],
4647
typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
4748
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
4849
)
4950
{
50-
return; //only stream is supported currently
51+
#pragma HLS INLINE region
52+
//only have resource strategy as of now
53+
conv_2d_transpose_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
5154
}
5255

5356
}

0 commit comments

Comments
 (0)