
Commit 91f1c4c

Update to new conv methods for io_parallel. There are still some issues with multiple filters, as well as some padding issues.
1 parent 6eecaa5 commit 91f1c4c

File tree

6 files changed: +216 −174 lines


hls4ml/backends/fpga/fpga_backend.py

Lines changed: 61 additions & 0 deletions
@@ -492,7 +492,68 @@ def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, ke
             "    ) {{\n"
         ).format(index=layer_idx)
         indent = '    '
+        for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)):
+            generated_code += indent * 2 + 'if (partition == {:>3}) {{\n'.format(partition_idx)
+            for pixel_idx, arr in enumerate(partition):
+                buffer_stmts = []
+                for j, v in enumerate(arr):
+                    if v == 0:
+                        val = '0'
+                    else:
+                        val = 'data[{}]'.format(int(v - 1))
+                    buffer_stmts.append('buffer[{}][{}] = {:>10};'.format(pixel_idx, j, val))
+                generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n'
+            generated_code += '\n' + indent * 2 + '}\n'
+
+        generated_code += indent + '}\n'
+        generated_code += '};\n'
+
+        return generated_code
+
+    def _compute_conv1d_tr_im2col(self, input_shape, kernel=3, stride=1):
+        W, C = input_shape
+
+        out_w = W  # working with padding in a different way for transpose layers
 
+        tr_kernel = (kernel + stride - 1) // stride
+
+        input_img = np.arange(1, W * C + 1)
+        im_matrix = np.zeros((tr_kernel * C * out_w, ))
+
+        index = 0
+        for i_ow in range(out_w):
+            for i_kw in range(tr_kernel):
+                for i_c in range(C):
+                    # input column is just the output column shifted
+                    input_col = i_ow - (tr_kernel - 1) + i_kw
+                    if (input_col >= 0 and input_col < W):
+                        im_matrix[index] = input_img[input_col * C + i_c]
+                    else:
+                        im_matrix[index] = 0
+                    index += 1
+        im_matrix = im_matrix.reshape(out_w, -1)
+        return im_matrix
+
+
+    def generate_conv1d_tr_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, kernel=3, stride=1):
+
+        im2col_matrix = self._compute_conv1d_tr_im2col(
+            (in_W, in_C),
+            kernel,
+            stride,
+        )
+
+        generated_code = (
+            "template<class data_T, typename CONFIG_T>\n"
+            "class fill_buffer_{index} : public FillConv1DBuffer<data_T, CONFIG_T> {{\n"
+            "    public:\n"
+            "    static void fill_buffer(\n"
+            "        data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n"
+            "        data_T buffer[CONFIG_T::n_pixels][CONFIG_T::trfilt_width * CONFIG_T::n_chan],\n"
+            "        const unsigned partition\n"
+            "    ) {{\n"
+        ).format(index=layer_idx)
+        indent = '    '
         for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)):
             generated_code += indent * 2 + 'if (partition == {:>3}) {{\n'.format(partition_idx)
             for pixel_idx, arr in enumerate(partition):
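
For orientation on what `_compute_conv1d_tr_im2col` produces: each row corresponds to one output pixel and holds `trfilt_width * n_chan` entries, where a value v > 0 is a 1-based index into the flattened `data` array and 0 marks a padded tap that the generated `fill_buffer` writes as a literal zero. Below is a minimal standalone sketch of the same bookkeeping; the function name and toy sizes are illustrative, not part of the commit.

```python
import numpy as np

def conv1d_tr_im2col_sketch(W, C, kernel=3, stride=1):
    # Standalone mirror of _compute_conv1d_tr_im2col from the diff above.
    # Entries are 1-based indices into the flattened input; 0 marks padding.
    tr_kernel = (kernel + stride - 1) // stride   # taps seen per output phase
    input_img = np.arange(1, W * C + 1)
    im_matrix = np.zeros((tr_kernel * C * W,))

    index = 0
    for i_ow in range(W):                         # out_w == W for the transpose case
        for i_kw in range(tr_kernel):
            for i_c in range(C):
                # input column is the output column shifted left by (tr_kernel - 1)
                input_col = i_ow - (tr_kernel - 1) + i_kw
                im_matrix[index] = input_img[input_col * C + i_c] if 0 <= input_col < W else 0
                index += 1
    return im_matrix.reshape(W, -1)

# Toy case: width 4, one channel, kernel 3, stride 2 -> tr_kernel = 2
print(conv1d_tr_im2col_sketch(W=4, C=1, kernel=3, stride=2))
# [[0. 1.]
#  [1. 2.]
#  [2. 3.]
#  [3. 4.]]
```

Row i of that matrix is exactly the recipe the generated `fill_buffer_{index}` follows when filling `buffer[i][...]`: copy `data[v - 1]` for non-zero entries, write a literal 0 otherwise.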

hls4ml/backends/fpga/passes/codegen.py

Lines changed: 17 additions & 3 deletions
@@ -1,16 +1,18 @@
 from hls4ml.model.optimizer import OptimizerPass
-from hls4ml.model.layers import Conv1D, Conv2D
+from hls4ml.model.layers import Conv1D, Conv2D, Conv1DTranspose
 from hls4ml.model.types import Source
 
 class GenerateConvIm2col(OptimizerPass):
     ''' Generates tcode for im2col step of 1D/2d convolution '''
     def match(self, node):
-        return isinstance(node, (Conv1D, Conv2D)) and \
+        return isinstance(node, (Conv1D, Conv2D, Conv1DTranspose)) and \
             node.model.config.get_config_value('IOType') == 'io_parallel'
 
     def transform(self, model, node):
         node_class = node.__class__.__name__
-        if '1D' in node_class:
+        if '1DTranspose' in node_class:
+            self._generate_im2col_1d_transpose(node)
+        elif '1D' in node_class:
             self._generate_im2col_1d(node)
         elif '2D' in node_class:
             self._generate_im2col_2d(node)
@@ -30,6 +32,18 @@ def _generate_im2col_1d(self, node):
 
         node.set_attr('line_buffer_codegen', Source(code_str))
 
+    def _generate_im2col_1d_transpose(self, node):
+        code_str = node.model.config.backend.generate_conv1d_tr_line_buffer_fn(
+            node.get_attr('index'),
+            node.get_attr('n_partitions'),
+            node.get_input_variable().shape[0],
+            node.get_input_variable().shape[1],
+            kernel=node.get_attr('filt_width'),
+            stride=node.get_attr('stride_width'),
+        )
+
+        node.set_attr('line_buffer_codegen', Source(code_str))
+
     def _generate_im2col_2d(self, node):
         code_str = node.model.config.backend.generate_conv2d_line_buffer_fn(
             node.get_attr('index'),
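
One detail worth noting in the dispatch above: the `'1DTranspose'` test has to come before the plain `'1D'` test, because `'1D'` is also a substring of `'Conv1DTranspose'`. A minimal sketch of that ordering (the helper name is illustrative):

```python
def pick_codegen(node_class):
    # Order matters: '1D' matches 'Conv1DTranspose' too, so check the more
    # specific class name first, mirroring GenerateConvIm2col.transform.
    if '1DTranspose' in node_class:
        return '_generate_im2col_1d_transpose'
    elif '1D' in node_class:
        return '_generate_im2col_1d'
    elif '2D' in node_class:
        return '_generate_im2col_2d'

assert pick_codegen('Conv1DTranspose') == '_generate_im2col_1d_transpose'
assert pick_codegen('Conv1D') == '_generate_im2col_1d'
assert pick_codegen('Conv2D') == '_generate_im2col_2d'
```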

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 8 additions & 0 deletions
@@ -112,6 +112,10 @@ def format(self, node):
     static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
     static const unsigned min_width = {min_width};
     static const ap_uint<filt_width> pixels[min_width];
+    static const unsigned n_partitions = {n_partitions};
+    static const unsigned n_pixels = in_width / n_partitions;
+    template<class data_T, class CONFIG_T>
+    using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;
     typedef {accum_t.name} accum_t;
     typedef {bias_t.name} bias_t;
     typedef {weight_t.name} weight_t;
@@ -137,6 +141,10 @@ def format(self, node):
             // node.get_attr('stride_width')
 
         params['config_t'] = 'config{}_mult'.format(node.index)
+        if node.model.config.get_config_value('IOType') == 'io_parallel':
+            params['fill_fn'] = 'fill_buffer_{}'.format(node.index)
+        else:
+            params['fill_fn'] = 'FillConv1DBuffer'
         conv_config = self.template.format(**params)
 
         mult_params = self._default_config_params(node)
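
To see where `fill_fn` ends up, here is a stripped-down sketch of the same `str.format` substitution; the template fragment, index and sizes are illustrative, not the full hls4ml config template.

```python
# Fragment of the config template shown in the diff above.
config_fragment = (
    "    static const unsigned n_partitions = {n_partitions};\n"
    "    static const unsigned n_pixels = in_width / n_partitions;\n"
    "    template<class data_T, class CONFIG_T>\n"
    "    using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;\n"
)

def render_fragment(index, n_partitions, io_type):
    params = {'n_partitions': n_partitions}
    # io_parallel gets the per-layer generated filler; anything else falls
    # back to the generic FillConv1DBuffer, mirroring the branch above.
    if io_type == 'io_parallel':
        params['fill_fn'] = 'fill_buffer_{}'.format(index)
    else:
        params['fill_fn'] = 'FillConv1DBuffer'
    return config_fragment.format(**params)

print(render_fragment(index=4, n_partitions=2, io_type='io_parallel'))
# ...using fill_buffer = nnet::fill_buffer_4<data_T, CONFIG_T>;
```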

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 13 additions & 0 deletions
@@ -194,9 +194,22 @@ def init_conv1dtranspose(self, layer):
             self.set_closest_reuse_factor(layer, n_in, n_out)
         else:
             layer.set_attr('strategy', 'latency')
+
+        in_width = layer.get_input_variable().shape[0]
+        chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1)
+        valid_pf = self.get_valid_conv_partition_splits(1, in_width)
+        if chosen_pf not in valid_pf:
+            closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf)
+            print('WARNING: Invalid ParallelizationFactor={} in layer "{}". Using ParallelizationFactor={} instead. Valid ParallelizationFactor(s): {}.'
+                .format(chosen_pf, layer.name, closest_pf, ','.join(map(str, valid_pf))))
+        else:
+            closest_pf = chosen_pf
+        layer.set_attr('n_partitions', in_width // closest_pf)
 
         layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower())
 
+        self._validate_conv_strategy(layer)
+
     @layer_optimizer(SeparableConv1D)
     def init_sepconv1d(self, layer):
         if layer.model.config.is_resource_strategy(layer):
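
The ParallelizationFactor handling added above boils down to: gather the valid partition splits for the layer width, snap an invalid user choice to the closest valid one, and derive `n_partitions` from the result. A self-contained sketch of that logic follows; the helpers stand in for the backend's `get_valid_conv_partition_splits` and `get_closest_reuse_factor`, under the assumption that a factor is valid when it divides the width evenly.

```python
def valid_partition_splits(in_width):
    # Assumption: a ParallelizationFactor is valid if it divides in_width evenly.
    return [pf for pf in range(1, in_width + 1) if in_width % pf == 0]

def closest_valid(valid_pf, chosen_pf):
    # Stand-in for get_closest_reuse_factor: pick the nearest valid value.
    return min(valid_pf, key=lambda pf: abs(pf - chosen_pf))

def resolve_n_partitions(in_width, chosen_pf=1):
    valid_pf = valid_partition_splits(in_width)
    if chosen_pf not in valid_pf:
        closest_pf = closest_valid(valid_pf, chosen_pf)
        print('WARNING: invalid ParallelizationFactor={}, using {} instead'.format(chosen_pf, closest_pf))
    else:
        closest_pf = chosen_pf
    return in_width // closest_pf

# A width-8 layer with a requested factor of 5 snaps to 4, giving 2 partitions.
print(resolve_n_partitions(in_width=8, chosen_pf=5))  # 2
```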

hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ void conv_1d_transpose_cl(
     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
 )
 {
+    #pragma HLS INLINE region
     //for now, we are only adding resource strategy
     conv_1d_transpose_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
 }
