diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py
index ec496c2104..e09c761d17 100644
--- a/hls4ml/backends/fpga/fpga_backend.py
+++ b/hls4ml/backends/fpga/fpga_backend.py
@@ -14,7 +14,9 @@
     Activation,
     BatchNormalization,
     Conv1D,
+    Conv1DTranspose,
     Conv2D,
+    Conv2DTranspose,
     Dense,
     Dot,
     Embedding,
@@ -52,7 +54,9 @@ def __init__(self, name):
         accum_layers = [
             Dense,
             Conv1D,
+            Conv1DTranspose,
             Conv2D,
+            Conv2DTranspose,
             SeparableConv1D,
             SeparableConv2D,
             Pooling1D,
@@ -158,6 +162,22 @@ def get_layer_mult_size(self, layer):
             n_out = layer.get_attr('n_out')
             return n_in, n_out
 
+        if 'Conv1DTranspose' in layer.class_name:
+            trfilt_width = (layer.get_attr('filt_width') + layer.get_attr('stride_width') - 1) \
+                // layer.get_attr('stride_width')
+            n_in = layer.get_attr('n_chan') * trfilt_width
+            n_out = layer.get_attr('n_filt')
+            return n_in, n_out
+
+        if 'Conv2DTranspose' in layer.class_name:
+            trfilt_width = (layer.get_attr('filt_width') + layer.get_attr('stride_width') - 1) \
+                // layer.get_attr('stride_width')
+            trfilt_height = (layer.get_attr('filt_height') + layer.get_attr('stride_height') - 1) \
+                // layer.get_attr('stride_height')
+            n_in = layer.get_attr('n_chan') * trfilt_height * trfilt_width
+            n_out = layer.get_attr('n_filt')
+            return n_in, n_out
+
         if 'Conv1D' in layer.class_name:
             n_in = layer.get_attr('n_chan') * layer.get_attr('filt_width')
             n_out = layer.get_attr('n_filt')
@@ -713,7 +733,67 @@ def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, ke
             "    ) {{\n"
         ).format(index=layer_idx)
         indent = '    '
+        for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)):
+            generated_code += indent * 2 + 'if (partition == {:>3}) {{\n'.format(partition_idx)
+            for pixel_idx, arr in enumerate(partition):
+                buffer_stmts = []
+                for j, v in enumerate(arr):
+                    if v == 0:
+                        val = '0'
+                    else:
+                        val = 'data[{}]'.format(int(v - 1))
+                    buffer_stmts.append('buffer[{}][{}] = {:>10};'.format(pixel_idx, j, val))
+                generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n'
+            generated_code += '\n' + indent * 2 + '}\n'
+
+        generated_code += indent + '}\n'
+        generated_code += '};\n'
+
+        return generated_code
+
+    def _compute_conv1d_tr_im2col(self, input_shape, out_w, kernel=3, stride=1):
+        W, C = input_shape
+
+        tr_kernel = (kernel + stride - 1) // stride
+
+        input_img = np.arange(1, W * C + 1)
+        im_matrix = np.zeros((tr_kernel * C * out_w, ))
+
+        index = 0
+        for i_ow in range(out_w):
+            for i_kw in range(tr_kernel):
+                for i_c in range(C):
+                    # The input column is just the output column, shifted
+                    input_col = i_ow - (tr_kernel - 1) + i_kw
+                    if input_col >= 0 and input_col < W:
+                        im_matrix[index] = input_img[input_col * C + i_c]
+                    else:
+                        im_matrix[index] = 0
+                    index += 1
+        im_matrix = im_matrix.reshape(out_w, -1)
+        return im_matrix
+
+    def generate_conv1d_tr_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, out_W, kernel=3, stride=1):
+        im2col_matrix = self._compute_conv1d_tr_im2col(
+            (in_W, in_C),
+            out_W,
+            kernel,
+            stride,
+        )
+
+        generated_code = (
+            "template<class data_T, typename CONFIG_T>\n"
+            "class fill_buffer_{index} : public FillConv1DBuffer<data_T, CONFIG_T> {{\n"
+            "  public:\n"
+            "    static void fill_buffer(\n"
+            "        data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n"
+            "        data_T buffer[CONFIG_T::n_pixels][CONFIG_T::trfilt_width * CONFIG_T::n_chan],\n"
+            "        const unsigned partition\n"
+            "    ) {{\n"
+        ).format(index=layer_idx)
+        indent = '    '
+        for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)):
+            generated_code += indent * 2 + f'if (partition == {partition_idx:>3}) {{\n'
+            for pixel_idx, arr in enumerate(partition):
+                buffer_stmts = []
+                for j, v in enumerate(arr):
+                    if v == 0:
+                        val = '0'
+                    else:
+                        val = 'data[{}]'.format(int(v - 1))
+                    buffer_stmts.append('buffer[{}][{}] = {:>10};'.format(pixel_idx, j, val))
+                generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n'
+            generated_code += '\n' + indent * 2 + '}\n'
+
+        generated_code += indent + '}\n'
+        generated_code += '};\n'
+
+        return generated_code
+
@@ -862,6 +942,91 @@ def generate_conv2d_line_buffer_fn(
 
         return generated_code
 
+    def _compute_conv2d_tr_im2col(self, input_shape, out_shape, kernel=(3, 3), stride=(1, 1)):
+        H, W, C = input_shape
+        kernel_h, kernel_w = kernel
+        stride_h, stride_w = stride
+        out_h, out_w = out_shape
+
+        tr_kernel_h = (kernel_h + stride_h - 1) // stride_h
+        tr_kernel_w = (kernel_w + stride_w - 1) // stride_w
+
+        input_img = np.arange(1, H * W * C + 1)
+        im_matrix = np.zeros((tr_kernel_h * tr_kernel_w * C * out_h * out_w, ))
+
+        index = 0
+        for i_oh in range(out_h):
+            for i_ow in range(out_w):
+                for i_kh in range(tr_kernel_h):
+                    input_row = i_oh - (tr_kernel_h - 1) + i_kh
+                    for i_kw in range(tr_kernel_w):
+                        for i_c in range(C):
+                            if input_row < 0 or input_row >= H:
+                                im_matrix[index] = 0
+                            else:
+                                input_col = i_ow - (tr_kernel_w - 1) + i_kw
+                                if input_col >= 0 and input_col < W:
+                                    im_matrix[index] = input_img[input_row * W * C + input_col * C + i_c]
+                                else:
+                                    im_matrix[index] = 0
+                            index += 1
+
+        im_matrix = im_matrix.reshape(out_h * out_w, -1)
+        return im_matrix
+
+    def generate_conv2d_tr_line_buffer_fn(self, layer_idx, n_partitions, in_H, in_W, in_C, out_H, out_W, kernel=(3, 3), stride=(1, 1)):
+        if isinstance(kernel, Iterable):
+            kernel_height = kernel[0]
+            kernel_width = kernel[1]
+        else:
+            kernel_height = kernel
+            kernel_width = kernel
+
+        if isinstance(stride, Iterable):
+            stride_height = stride[0]
+            stride_width = stride[1]
+        else:
+            stride_height = stride
+            stride_width = stride
+
+        im2col_matrix = self._compute_conv2d_tr_im2col(
+            (in_H, in_W, in_C),
+            (out_H, out_W),
+            (kernel_height, kernel_width),
+            (stride_height, stride_width),
+        )
+
+        generated_code = (
+            "template<class data_T, typename CONFIG_T>\n"
+            "class fill_buffer_{index} : public FillConv2DBuffer<data_T, CONFIG_T> {{\n"
+            "  public:\n"
+            "    static void fill_buffer(\n"
+            "        data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan],\n"
+            "        data_T buffer[CONFIG_T::n_pixels][CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan],\n"
+            "        const unsigned partition\n"
+            "    ) {{\n"
+        ).format(index=layer_idx)
+        indent = '    '
+
+        for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)):
+            generated_code += indent * 2 + 'if (partition == {:>3}) {{\n'.format(partition_idx)
+            for pixel_idx, arr in enumerate(partition):
+                buffer_stmts = []
+                for j, v in enumerate(arr):
+                    if v == 0:
+                        val = '0'
+                    else:
+                        val = 'data[{}]'.format(int(v - 1))
+                    buffer_stmts.append('buffer[{}][{}] = {:>10};'.format(pixel_idx, j, val))
+                generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n'
+            generated_code += '\n' + indent * 2 + '}\n'
+
+        generated_code += indent + '}\n'
+        generated_code += '};\n'
+
+        return generated_code
+
     @model_optimizer()
     def write_hls(self, model):
         self.writer.write_hls(model)
diff --git a/hls4ml/backends/fpga/fpga_types.py b/hls4ml/backends/fpga/fpga_types.py
index 6b1e63a469..fcd98fdf0b 100644
--- a/hls4ml/backends/fpga/fpga_types.py
+++ b/hls4ml/backends/fpga/fpga_types.py
@@ -326,6 +326,15 @@ def __init__(self, type_converter):
 
 class StaticWeightVariableDefinition(VariableDefinition):
     def definition_cpp(self, name_suffix='', as_reference=False):
+        if self.keep_dims > 0:
+            size_str = ''
+            for dim in range(self.keep_dims):
+                size_str += '[{cur_dim}]'.format(cur_dim=self.shape[dim])
+            final_dim = 1
+            for dim in range(self.keep_dims, len(self.shape)):
+                final_dim *= self.shape[dim]
+            size_str += '[{last_dim}]'.format(last_dim=final_dim)
+            return '{type} {name}{sizes}'.format(type=self.type.name, name=self.name, sizes=size_str)
         return '{type} {name}[{size}]'.format(type=self.type.name, name=self.name, size=self.data_length)
 
 class StaticWeightVariableConverter(object):
diff --git a/hls4ml/backends/fpga/passes/codegen.py b/hls4ml/backends/fpga/passes/codegen.py
index f031645091..037280a7fd 100644
--- a/hls4ml/backends/fpga/passes/codegen.py
+++ b/hls4ml/backends/fpga/passes/codegen.py
@@ -1,17 +1,21 @@
 from hls4ml.model.optimizer import OptimizerPass
-from hls4ml.model.layers import Conv1D, Conv2D
+from hls4ml.model.layers import Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose
 from hls4ml.model.types import Source
 
 class GenerateConvIm2col(OptimizerPass):
     ''' Generates code for the im2col step of 1D/2D convolution '''
     def match(self, node):
-        return isinstance(node, (Conv1D, Conv2D)) and \
+        return isinstance(node, (Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose)) and \
             node.model.config.get_config_value('IOType') == 'io_parallel'
 
     def transform(self, model, node):
         node_class = node.__class__.__name__
-        if '1D' in node_class:
+        if '1DTranspose' in node_class:
+            self._generate_im2col_1d_transpose(node)
+        elif '1D' in node_class:
             self._generate_im2col_1d(node)
+        elif '2DTranspose' in node_class:
+            self._generate_im2col_2d_transpose(node)
         elif '2D' in node_class:
             self._generate_im2col_2d(node)
         else:
@@ -30,6 +34,19 @@ def _generate_im2col_1d(self, node):
 
         node.set_attr('line_buffer_codegen', Source(code_str))
 
+    def _generate_im2col_1d_transpose(self, node):
+        code_str = node.model.config.backend.generate_conv1d_tr_line_buffer_fn(
+            node.get_attr('index'),
+            node.get_attr('n_partitions'),
+            node.get_input_variable().shape[0],
+            node.get_input_variable().shape[1],
+            node.get_attr('proc_width'),
+            kernel=node.get_attr('filt_width'),
+            stride=node.get_attr('stride_width'),
+        )
+
+        node.set_attr('line_buffer_codegen', Source(code_str))
+
     def _generate_im2col_2d(self, node):
         code_str = node.model.config.backend.generate_conv2d_line_buffer_fn(
             node.get_attr('index'),
@@ -43,3 +60,18 @@ def _generate_im2col_2d(self, node):
         )
 
         node.set_attr('line_buffer_codegen', Source(code_str))
+
+    def _generate_im2col_2d_transpose(self, node):
+        code_str = node.model.config.backend.generate_conv2d_tr_line_buffer_fn(
+            node.get_attr('index'),
+            node.get_attr('n_partitions'),
+            node.get_input_variable().shape[0],
+            node.get_input_variable().shape[1],
+            node.get_input_variable().shape[2],
+            node.get_attr('proc_height'),
+            node.get_attr('proc_width'),
+            kernel=(node.get_attr('filt_height'), node.get_attr('filt_width')),
+            stride=(node.get_attr('stride_height'), node.get_attr('stride_width')),
+        )
+
+        node.set_attr('line_buffer_codegen', Source(code_str))
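For intuition, the im2col pattern that the generated fill_buffer classes implement can be previewed in plain NumPy. The sketch below mirrors _compute_conv1d_tr_im2col from fpga_backend.py above (same logic, written standalone for illustration); entries are 1-based input indices and 0 marks taps that fall outside the input:

    import numpy as np

    def compute_conv1d_tr_im2col(input_shape, out_w, kernel=3, stride=1):
        W, C = input_shape
        tr_kernel = (kernel + stride - 1) // stride  # effective (transposed) kernel width
        input_img = np.arange(1, W * C + 1)
        im_matrix = np.zeros((tr_kernel * C * out_w, ))
        index = 0
        for i_ow in range(out_w):
            for i_kw in range(tr_kernel):
                for i_c in range(C):
                    input_col = i_ow - (tr_kernel - 1) + i_kw
                    if 0 <= input_col < W:
                        im_matrix[index] = input_img[input_col * C + i_c]
                    else:
                        im_matrix[index] = 0
                    index += 1
        return im_matrix.reshape(out_w, -1)

    print(compute_conv1d_tr_im2col((4, 1), out_w=5, kernel=3, stride=2))
    # [[0. 1.]
    #  [1. 2.]
    #  [2. 3.]
    #  [3. 4.]
    #  [4. 0.]]

Each row is the trfilt_width-wide sliding window that one output position multiplies against, which is exactly what the generated buffer[...] assignments hard-code per partition.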
diff --git a/hls4ml/backends/vivado/passes/conv_same_pad.py b/hls4ml/backends/vivado/passes/conv_same_pad.py
index dc27076574..54c676386a 100644
--- a/hls4ml/backends/vivado/passes/conv_same_pad.py
+++ b/hls4ml/backends/vivado/passes/conv_same_pad.py
@@ -1,5 +1,5 @@
 from hls4ml.model.optimizer import OptimizerPass
-from hls4ml.model.layers import Conv1D, SeparableConv1D, Conv2D, SeparableConv2D
+from hls4ml.model.layers import Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, Conv1DTranspose, Conv2DTranspose
 
 class InsertZeroPaddingBeforeConv1D(OptimizerPass):
     name = 'insert_zero_padding_before_conv1d'
@@ -46,6 +46,53 @@ def transform(self, model, node):
 
         return True
 
+class InsertZeroPaddingBeforeConv1DTranspose(OptimizerPass):
+    name = 'insert_zero_padding_before_conv1dtranspose'
+
+    def match(self, node):
+        is_match = isinstance(node, Conv1DTranspose) and \
+            node.get_attr('padding') == 'same' and \
+            node.get_attr('filt_width') != 1
+        return is_match
+
+    def transform(self, model, node):
+        if model.config.get_config_value('IOType') != 'io_stream':
+            return False
+
+        # Get the padding parameters from the Conv1DTranspose layer
+        pad_left = node.get_attr('pad_left')
+        pad_right = node.get_attr('pad_right')
+        convtr_out_width = node.get_attr('out_width')
+        in_width = node.get_attr('in_width')
+        stride_width = node.get_attr('stride_width')
+        trfilt_width = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) \
+            // node.get_attr('stride_width')
+
+        add_right = (convtr_out_width + pad_left) // stride_width - (in_width - 1)
+
+        out_width = in_width + add_right + trfilt_width - 1
+
+        attrs = {
+            'pad_left': trfilt_width - 1,
+            'pad_right': add_right,
+            'in_width': in_width,
+            'out_width': out_width,
+            'n_chan': node.get_attr('n_chan'),
+            'data_format': node.get_attr('data_format', 'channels_last')
+        }
+
+        # Mark the Conv1DTranspose as 'valid': the 'same' cropping is folded into the
+        # pad_left offset below, and the inserted ZeroPadding1D supplies the input halo
+        node.set_attr('padding', 'valid')
+        node.set_attr('in_width', out_width)
+        node.set_attr('pad_left', pad_left + (trfilt_width - 1) * stride_width)
+
+        # Insert the new ZeroPadding1D node above the Conv1DTranspose
+        padding_layer = model.make_node('ZeroPadding1D', 'zp1d_' + node.name, attrs, node.inputs.copy())
+        padding_layer.get_output_variable().type.precision = node.get_input_variable().type.precision
+        model.insert_node(padding_layer)
+
+        return True
+
 class InsertZeroPaddingBeforeConv2D(OptimizerPass):
     name = 'insert_zero_padding_before_conv2d'
@@ -100,3 +147,66 @@ def transform(self, model, node):
         model.insert_node(padding_layer, before=node)
 
         return True
+
+class InsertZeroPaddingBeforeConv2DTranspose(OptimizerPass):
+    name = 'insert_zero_padding_before_conv2dtranspose'
+
+    def match(self, node):
+        is_match = isinstance(node, Conv2DTranspose) and \
+            node.get_attr('padding') == 'same' and \
+            node.get_attr('filt_width') != 1
+        return is_match
+
+    def transform(self, model, node):
+        if model.config.get_config_value('IOType') != 'io_stream':
+            return False
+
+        # Get the padding parameters from the Conv2DTranspose layer
+        pad_left = node.get_attr('pad_left')
+        pad_right = node.get_attr('pad_right')
+        pad_top = node.get_attr('pad_top')
+        pad_bottom = node.get_attr('pad_bottom')
+        convtr_out_width = node.get_attr('out_width')
+        convtr_out_height = node.get_attr('out_height')
+        in_width = node.get_attr('in_width')
+        in_height = node.get_attr('in_height')
+        stride_width = node.get_attr('stride_width')
+        stride_height = node.get_attr('stride_height')
+        trfilt_width = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) \
+            // node.get_attr('stride_width')
+        trfilt_height = (node.get_attr('filt_height') + node.get_attr('stride_height') - 1) \
+            // node.get_attr('stride_height')
+
+        add_right = (convtr_out_width + pad_left) // stride_width - (in_width - 1)
+        add_bottom = (convtr_out_height + pad_top) // stride_height - (in_height - 1)
+
+        out_width = in_width + add_right + trfilt_width - 1
+        out_height = in_height + add_bottom + trfilt_height - 1
+
+        attrs = {
+            'pad_left': trfilt_width - 1,
+            'pad_right': add_right,
+            'pad_top': trfilt_height - 1,
+            'pad_bottom': add_bottom,
+            'in_width': in_width,
+            'in_height': in_height,
+            'out_width': out_width,
+            'out_height': out_height,
+            'n_chan': node.get_attr('n_chan'),
+            'data_format': node.get_attr('data_format', 'channels_last')
+        }
+
+        # Mark the Conv2DTranspose as 'valid': the 'same' cropping is folded into the
+        # pad_left/pad_top offsets below, and the inserted ZeroPadding2D supplies the input halo
+        node.set_attr('padding', 'valid')
+        node.set_attr('in_width', out_width)
+        node.set_attr('in_height', out_height)
+        node.set_attr('pad_left', pad_left + (trfilt_width - 1) * stride_width)
+        node.set_attr('pad_top', pad_top + (trfilt_height - 1) * stride_height)
+
+        # Insert the new ZeroPadding2D node above the Conv2DTranspose
+        padding_layer = model.make_node('ZeroPadding2D', 'zp2d_' + node.name, attrs, node.inputs.copy())
+        padding_layer.get_output_variable().type.precision = node.get_input_variable().type.precision
+        model.insert_node(padding_layer, before=node)
+
+        return True
diff --git a/hls4ml/backends/vivado/passes/conv_stream.py b/hls4ml/backends/vivado/passes/conv_stream.py
index e0bb853d83..fbca22d5fa 100644
--- a/hls4ml/backends/vivado/passes/conv_stream.py
+++ b/hls4ml/backends/vivado/passes/conv_stream.py
@@ -1,4 +1,4 @@
-from hls4ml.model.layers import Conv1D, Conv2D, SeparableConv1D, SeparableConv2D
+from hls4ml.model.layers import Conv1D, Conv2D, SeparableConv1D, SeparableConv2D, Conv2DTranspose, Conv1DTranspose
 from hls4ml.model.optimizer import OptimizerPass
 
@@ -6,7 +6,7 @@ class GenerateConvStreamingInstructions(OptimizerPass):
     '''Generates the instructions for streaming implementation of CNNs'''
 
     def match(self, node):
-        return isinstance(node, (Conv1D, SeparableConv1D, Conv2D, SeparableConv2D))
+        return isinstance(node, (Conv1D, Conv1DTranspose, SeparableConv1D, Conv2D, SeparableConv2D, Conv2DTranspose))
 
     def transform(self, model, node):
         node_class = node.__class__.__name__
@@ -18,12 +18,19 @@ def transform(self, model, node):
         raise Exception(f'Cannot generate instructions for node {node.name} ({node_class})')
 
     def _generate_1d_instructions(self, node):
+        kernel_width = node.get_attr('filt_width')
+        stride_width = node.get_attr('stride_width')
+        if isinstance(node, Conv1DTranspose):
+            # use trfilt_width as the kernel width and a stride of 1 (the effective kernel dimensions of the transpose)
+            kernel_width = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) \
+                // node.get_attr('stride_width')
+            stride_width = 1
         if node.model.config.get_config_value('IOType') == 'io_stream':
             min_w, instructions = node.model.config.backend.compute_conv1d_instructions(
                 node.get_input_variable().shape[0],
                 node.get_input_variable().shape[1],
-                node.get_attr('filt_width'),
-                node.get_attr('stride_width'),
+                kernel_width,
+                stride_width,
             )
             instructions_str = ','.join(str(i) for i in instructions)
             node.set_attr('min_width', min_w)
@@ -34,13 +41,20 @@ def _generate_1d_instructions(self, node):
             node.set_attr('instructions', '0')
 
     def _generate_2d_instructions(self, node):
+        kernel_height = node.get_attr('filt_height')
+        stride_height = node.get_attr('stride_height')
+        if isinstance(node, Conv2DTranspose):
+            # use trfilt_height as the kernel height and a stride of 1 (the effective kernel dimensions of the transpose)
+            kernel_height = (node.get_attr('filt_height') + node.get_attr('stride_height') - 1) \
+                // node.get_attr('stride_height')
+            stride_height = 1
         if node.model.config.get_config_value('IOType') == 'io_stream':
             min_h, min_w, instructions = node.model.config.backend.compute_conv2d_instructions(
                 node.get_input_variable().shape[0],
                 node.get_input_variable().shape[1],
                 node.get_input_variable().shape[2],
-                node.get_attr('filt_height'),
-                node.get_attr('stride_height'),
+                kernel_height,
+                stride_height,
             )
             instructions_str = ','.join(str(i) for i in instructions)
             node.set_attr('min_height', min_h)
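The pass above relies on the identity that a stride-S transposed convolution equals S interleaved stride-1 convolutions, each using an effective kernel of trfilt = ceil(filt/stride) taps. A hypothetical NumPy check of that identity (function names are illustrative, not part of the patch):

    import numpy as np

    def conv1d_transpose(x, w, stride):
        # Direct 'valid' transposed convolution: scatter-add each input pixel
        out = np.zeros(stride * (len(x) - 1) + len(w))
        for i, xi in enumerate(x):
            out[i * stride : i * stride + len(w)] += xi * w
        return out

    def conv1d_transpose_phased(x, w, stride):
        # Phase view: output n = o*stride + p only uses kernel taps t with t % stride == p,
        # i.e. a dense multiply over trfilt = ceil(len(w)/stride) taps per phase
        trfilt = (len(w) + stride - 1) // stride
        out = np.zeros(stride * (len(x) - 1) + len(w))
        for n in range(len(out)):
            p, o = n % stride, n // stride
            for k in range(trfilt):
                i, t = o - k, k * stride + p  # contributing input pixel and original kernel tap
                if 0 <= i < len(x) and t < len(w):
                    out[n] += x[i] * w[t]
        return out

    x = np.arange(1.0, 7.0)
    w = np.array([1.0, -2.0, 3.0, 0.5, 2.0])
    assert np.allclose(conv1d_transpose(x, w, 2), conv1d_transpose_phased(x, w, 2))

This is why the streaming-instruction generator can reuse the ordinary conv machinery with kernel = trfilt and stride = 1.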
diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py
index 195fc00b58..1759bedba8 100644
--- a/hls4ml/backends/vivado/passes/convolution_templates.py
+++ b/hls4ml/backends/vivado/passes/convolution_templates.py
@@ -1,6 +1,6 @@
 from hls4ml.backends.backend import get_backend
 from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
-from hls4ml.model.layers import Conv1D, Conv2D, Conv2DBatchnorm, DepthwiseConv2D, SeparableConv1D, SeparableConv2D
+from hls4ml.model.layers import Conv1D, Conv2D, Conv2DBatchnorm, DepthwiseConv2D, SeparableConv1D, SeparableConv2D, Conv1DTranspose, Conv2DTranspose
 
 # Shared multiplication template
 
@@ -102,6 +102,82 @@ def format(self, node):
 
         return self.template.format(**params)
 
+# Conv1DTranspose Templates
+
+conv1dtranspose_config_template = """struct config{index} : nnet::conv1dtranspose_config {{
+    static const unsigned pad_left = {pad_left};
+    static const unsigned pad_right = {pad_right};
+    static const unsigned in_width = {in_width};
+    static const unsigned n_chan = {n_chan};
+    static const unsigned filt_width = {filt_width};
+    static const unsigned kernel_size = filt_width;
+    static const unsigned n_filt = {n_filt};
+    static const unsigned stride_width = {stride_width};
+    static const unsigned dilation = {dilation};
+    static const unsigned out_width = {out_width};
+    static const unsigned reuse_factor = {reuse};
+    static const unsigned n_zeros = {nzeros};
+    static const unsigned trfilt_width = {trfilt_width};
+    static const bool store_weights_in_bram = false;
+    static const unsigned strategy = nnet::{strategy};
+    static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
+    static const unsigned min_width = {min_width};
+    static const ap_uint<1> pixels[min_width];
+    static const unsigned n_partitions = {n_partitions};
+    static const unsigned proc_width = {proc_width};
+    static const unsigned n_pixels = proc_width / n_partitions;
+    template<class data_T, class CONFIG_T>
+    using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;
+    typedef {accum_t.name} accum_t;
+    typedef {bias_t.name} bias_t;
+    typedef {weight_t.name} weight_t;
+    typedef {config_t} mult_config;
+}};
+const ap_uint<1> config{index}::pixels[] = {{{instructions}}};\n"""
+
+conv1dtranspose_function_template = 'nnet::conv_1d_transpose_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
+
+conv1dtranspose_include_list = ['nnet_utils/nnet_conv1dtranspose.h', 'nnet_utils/nnet_conv1dtranspose_stream.h']
+
+class Conv1DTransposeConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(Conv1DTranspose)
+        self.template = conv1dtranspose_config_template
+        self.mult_template = conv_mult_config_template
+
+    def format(self, node):
+        params = self._default_config_params(node)
+        params['dilation'] = node.get_attr('dilation', 1)
+        params['nzeros'] = node.get_weights('weight').nzeros
+
+        params['config_t'] = 'config{}_mult'.format(node.index)
+        if node.model.config.get_config_value('IOType') == 'io_parallel':
+            params['fill_fn'] = 'fill_buffer_{}'.format(node.index)
+        else:
+            params['fill_fn'] = 'FillConv1DBuffer'
+        conv_config = self.template.format(**params)
+
+        mult_params = self._default_config_params(node)
+        mult_params['n_in'] = node.get_attr('n_chan') * \
+            (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) // node.get_attr('stride_width')
+        mult_params['n_out'] = node.get_attr('n_filt')
+        mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
+        mult_config = self.mult_template.format(**mult_params)
+
+        return mult_config + '\n' + conv_config
+
+class Conv1DTransposeFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(Conv1DTranspose, include_header=conv1dtranspose_include_list)
+        self.template = conv1dtranspose_function_template
+
+    def format(self, node):
+        params = self._default_function_params(node)
+        params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl'
+        params['w'] = node.get_weights('weight').name
+        params['b'] = node.get_weights('bias').name
+
+        return self.template.format(**params)
 
 # Conv2D Templates
 
@@ -212,6 +288,92 @@ def __init__(self):
         super(Conv2DFunctionTemplate, self).__init__(DepthwiseConv2D, include_header=sepconv2d_include_list)
         self.template = depthconv2d_function_template
 
+# Conv2DTranspose Templates
+
+conv2dtranspose_config_template = """struct config{index} : nnet::conv2dtranspose_config {{
+    static const unsigned pad_top = {pad_top};
+    static const unsigned pad_bottom = {pad_bottom};
+    static const unsigned pad_left = {pad_left};
+    static const unsigned pad_right = {pad_right};
+    static const unsigned in_height = {in_height};
+    static const unsigned in_width = {in_width};
+    static const unsigned n_chan = {n_chan};
+    static const unsigned filt_height = {filt_height};
+    static const unsigned filt_width = {filt_width};
+    static const unsigned kernel_size = filt_height * filt_width;
+    static const unsigned n_filt = {n_filt};
+    static const unsigned stride_height = {stride_height};
+    static const unsigned stride_width = {stride_width};
+    static const unsigned out_height = {out_height};
+    static const unsigned out_width = {out_width};
+    static const unsigned reuse_factor = {reuse};
+    static const unsigned n_zeros = {nzeros};
+    static const unsigned trfilt_width = {trfilt_width};
+    static const unsigned trfilt_height = {trfilt_height};
+    static const bool store_weights_in_bram = false;
+    static const unsigned strategy = nnet::{strategy};
+    static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
+    static const unsigned min_height = {min_height};
+    static const unsigned min_width = {min_width};
+    static const ap_uint<1> pixels[min_height * min_width];
+    static const unsigned n_partitions = {n_partitions};
+    static const unsigned proc_height = {proc_height};
+    static const unsigned proc_width = {proc_width};
+    static const unsigned n_pixels = proc_height * proc_width / n_partitions;
+    template<class data_T, class CONFIG_T>
+    using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;
+    typedef {accum_t.name} accum_t;
+    typedef {bias_t.name} bias_t;
+    typedef {weight_t.name} weight_t;
+    typedef {config_t} mult_config;
+}};
+const ap_uint<1> config{index}::pixels[] = {{{instructions}}};\n"""
+
+conv2dtranspose_function_template = 'nnet::conv_2d_transpose_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
+
+conv2dtranspose_include_list = ['nnet_utils/nnet_conv2dtranspose.h', 'nnet_utils/nnet_conv2dtranspose_stream.h']
+
+class Conv2DTransposeConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(Conv2DTranspose)
+        self.template = conv2dtranspose_config_template
+        self.mult_template = conv_mult_config_template
+
+    def format(self, node):
+        params = self._default_config_params(node)
+        params['dilation'] = node.get_attr('dilation', 1)
+        params['nzeros'] = node.get_weights('weight').nzeros
+        params['trfilt_width'] = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) \
+            // node.get_attr('stride_width')
+        params['trfilt_height'] = (node.get_attr('filt_height') + node.get_attr('stride_height') - 1) \
+            // node.get_attr('stride_height')
+
+        params['config_t'] = 'config{}_mult'.format(node.index)
+        if node.model.config.get_config_value('IOType') == 'io_parallel':
+            params['fill_fn'] = 'fill_buffer_{}'.format(node.index)
+        else:
+            params['fill_fn'] = 'FillConv2DBuffer'
+        conv_config = self.template.format(**params)
+
+        mult_params = self._default_config_params(node)
+        mult_params['n_in'] = node.get_attr('n_chan') * params['trfilt_width'] * params['trfilt_height']
+        mult_params['n_out'] = node.get_attr('n_filt')
+        mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
+        mult_config = self.mult_template.format(**mult_params)
+
+        return mult_config + '\n' + conv_config
+
+class Conv2DTransposeFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(Conv2DTranspose, include_header=conv2dtranspose_include_list)
+        self.template = conv2dtranspose_function_template
+
+    def format(self, node):
+        params = self._default_function_params(node)
+        params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl'
+        params['w'] = node.get_weights('weight').name
+        params['b'] = node.get_weights('bias').name
+
+        return self.template.format(**params)
 
 # SeparableConv1D/2D Templates
 
diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py
index 7698bc680d..9988987075 100644
--- a/hls4ml/backends/vivado/vivado_backend.py
+++ b/hls4ml/backends/vivado/vivado_backend.py
@@ -11,7 +11,9 @@
     GRU,
     LSTM,
     Conv1D,
+    Conv1DTranspose,
     Conv2D,
+    Conv2DTranspose,
     Dense,
     DepthwiseConv2D,
     Embedding,
@@ -82,6 +84,8 @@ def _register_flows(self):
     'vivado:clone_output',
     'vivado:insert_zero_padding_before_conv1d',
     'vivado:insert_zero_padding_before_conv2d',
+    'vivado:insert_zero_padding_before_conv1dtranspose',
+    'vivado:insert_zero_padding_before_conv2dtranspose',
     'vivado:broadcast_stream',
 ]
 streaming_flow = register_flow('streaming', streaming_passes, requires=[init_flow], backend=self.name)
@@ -265,6 +269,34 @@ def init_conv1d(self, layer):
 
         self._validate_conv_strategy(layer)
 
+    @layer_optimizer(Conv1DTranspose)
+    def init_conv1dtranspose(self, layer):
+        if layer.model.config.is_resource_strategy(layer):
+            layer.set_attr('strategy', 'resource')
+            n_in, n_out = self.get_layer_mult_size(layer)
+            self.set_target_reuse_factor(layer)
+            self.set_closest_reuse_factor(layer, n_in, n_out)
+        else:
+            layer.set_attr('strategy', 'latency')
+
+        in_width = layer.get_input_variable().shape[0]
+        proc_width = (layer.get_output_variable().shape[0] + layer.get_attr('pad_left') + layer.get_attr('stride_width') - 1) \
+            // layer.get_attr('stride_width')
+        chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1)
+        valid_pf = self.get_valid_conv_partition_splits(1, proc_width)
+        if chosen_pf not in valid_pf:
+            closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf)
+            print('WARNING: Invalid ParallelizationFactor={} in layer "{}". Using ParallelizationFactor={} instead. Valid ParallelizationFactor(s): {}.'
+ .format(chosen_pf, layer.name, closest_pf, ','.join(map(str, valid_pf)))) + else: + closest_pf = chosen_pf + layer.set_attr('n_partitions', proc_width // closest_pf) + layer.set_attr('proc_width', proc_width) + + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + self._validate_conv_strategy(layer) + @layer_optimizer(SeparableConv1D) def init_sepconv1d(self, layer): if layer.model.config.is_resource_strategy(layer): @@ -311,6 +343,42 @@ def init_conv2d(self, layer): self._validate_conv_strategy(layer) + @layer_optimizer(Conv2DTranspose) + def init_conv2dtranspose(self, layer): + if len(layer.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D + layer.weights['weight'].data = np.expand_dims(layer.weights['weight'].data, axis=(0,1)) + + if layer.model.config.is_resource_strategy(layer): + layer.set_attr('strategy', 'resource') + self.set_target_reuse_factor(layer) + n_in, n_out = self.get_layer_mult_size(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + else: + layer.set_attr('strategy', 'latency') + + in_height = layer.get_input_variable().shape[0] + in_width = layer.get_input_variable().shape[1] + + proc_height = (layer.get_output_variable().shape[0] + layer.get_attr('pad_top') + layer.get_attr('stride_height')-1) \ + // layer.get_attr('stride_height') + proc_width = (layer.get_output_variable().shape[1] + layer.get_attr('pad_left') + layer.get_attr('stride_width')-1) \ + // layer.get_attr('stride_width') + chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1) + valid_pf = self.get_valid_conv_partition_splits(proc_height, proc_width) + if chosen_pf not in valid_pf: + closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf) + print('WARNING: Invalid ParallelizationFactor={} in layer "{}". Using ParallelizationFactor={} instead. Valid ParallelizationFactor(s): {}.' 
+ .format(chosen_pf, layer.name, closest_pf, ','.join(map(str, valid_pf)))) + else: + closest_pf = chosen_pf + layer.set_attr('n_partitions', proc_height * proc_width // closest_pf) + layer.set_attr('proc_height', proc_height) + layer.set_attr('proc_width', proc_width) + + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + self._validate_conv_strategy(layer) + @layer_optimizer(SeparableConv2D) def init_sepconv2d(self, layer): if layer.model.config.is_resource_strategy(layer): diff --git a/hls4ml/converters/keras/convolution.py b/hls4ml/converters/keras/convolution.py index 3c402496a8..61c70f12a7 100644 --- a/hls4ml/converters/keras/convolution.py +++ b/hls4ml/converters/keras/convolution.py @@ -1,5 +1,5 @@ from hls4ml.converters.keras_to_hls import keras_handler, parse_default_keras_layer -from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d, parse_data_format +from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d, parse_data_format, compute_padding_1d_transpose, compute_padding_2d_transpose @keras_handler('Conv1D', 'SeparableConv1D') @@ -26,6 +26,39 @@ def parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader): return layer, output_shape +@keras_handler('Conv1DTranspose') +def parse_conv1dtranspose_layer(keras_layer, input_names, input_shapes, data_reader): + assert('Conv1DTranspose' in keras_layer['class_name']) + layer = parse_default_keras_layer(keras_layer, input_names) + + ( + layer['in_width'], + layer['n_chan'] + ) = parse_data_format(input_shapes[0], layer['data_format']) + + layer['n_filt'] = keras_layer['config']['filters'] + layer['filt_width'] = keras_layer['config']['kernel_size'][0] + layer['stride_width'] = keras_layer['config']['strides'][0] + layer['padding'] = keras_layer['config']['padding'] + layer['trfilt_width'] = (layer['filt_width'] + layer['stride_width'] - 1)//layer['stride_width'] + + ( + layer['out_width'], + layer['pad_left'], + layer['pad_right'], + ) = compute_padding_1d_transpose( + layer['padding'], + layer['in_width'], + layer['stride_width'], + layer['filt_width'] + ) + + if layer['data_format'] == 'channels_last': + output_shape = [input_shapes[0][0], layer['out_width'], layer['n_filt']] + elif layer['data_format'] == 'channels_first': + output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_width']] + + return layer, output_shape @keras_handler('Conv2D', 'SeparableConv2D', 'DepthwiseConv2D') def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader): @@ -68,3 +101,51 @@ def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader): output_shape = [input_shapes[0][0], layer['out_height'], layer['out_width'], layer['n_filt']] return layer, output_shape + +@keras_handler('Conv2DTranspose') +def parse_conv2dtranspose_layer(keras_layer, input_names, input_shapes, data_reader): + assert('Conv2DTranspose' in keras_layer['class_name']) + + layer = parse_default_keras_layer(keras_layer, input_names) + + ( + layer['in_height'], + layer['in_width'], + layer['n_chan'] + ) = parse_data_format(input_shapes[0], layer['data_format']) + + if 'filters' in keras_layer['config']: + layer['n_filt'] = keras_layer['config']['filters'] + else: + layer['n_filt'] = layer['n_chan'] + layer['filt_height'] = keras_layer['config']['kernel_size'][0] + layer['filt_width'] = keras_layer['config']['kernel_size'][1] + layer['stride_height'] = keras_layer['config']['strides'][0] + layer['stride_width'] = keras_layer['config']['strides'][1] + 
+    layer['padding'] = keras_layer['config']['padding']
+    layer['trfilt_height'] = (layer['filt_height'] + layer['stride_height'] - 1) // layer['stride_height']
+    layer['trfilt_width'] = (layer['filt_width'] + layer['stride_width'] - 1) // layer['stride_width']
+
+    (
+        layer['out_height'],
+        layer['out_width'],
+        layer['pad_top'],
+        layer['pad_bottom'],
+        layer['pad_left'],
+        layer['pad_right']
+    ) = compute_padding_2d_transpose(
+        layer['padding'],
+        layer['in_height'],
+        layer['in_width'],
+        layer['stride_height'],
+        layer['stride_width'],
+        layer['filt_height'],
+        layer['filt_width']
+    )
+
+    if layer['data_format'] == 'channels_first':
+        output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_height'], layer['out_width']]
+    else:
+        output_shape = [input_shapes[0][0], layer['out_height'], layer['out_width'], layer['n_filt']]
+
+    return layer, output_shape
\ No newline at end of file
diff --git a/hls4ml/converters/utils.py b/hls4ml/converters/utils.py
index d3fe6edfdd..7618801f1f 100644
--- a/hls4ml/converters/utils.py
+++ b/hls4ml/converters/utils.py
@@ -45,6 +45,21 @@ def compute_padding_1d(pad_type, in_size, stride, filt_size):
 
     return (n_out, pad_left, pad_right)
 
+def compute_padding_1d_transpose(pad_type, in_size, stride, filt_size):
+    if pad_type.lower() == 'same':
+        n_out = stride * in_size
+        pad_along_size = max(filt_size - stride, 0)
+        pad_left = pad_along_size // 2
+        pad_right = pad_along_size - pad_left
+    elif pad_type.lower() == 'valid':
+        n_out = stride * (in_size - 1) + filt_size
+        pad_left = 0
+        pad_right = 0
+    else:
+        raise Exception('Unknown padding type: {}'.format(pad_type))
+
+    return (n_out, pad_left, pad_right)
+
 def compute_padding_2d(pad_type, in_height, in_width, stride_height, stride_width, filt_height, filt_width):
     if pad_type.lower() == 'same':
         #Height
@@ -74,4 +89,30 @@ def compute_padding_2d(pad_type, in_height, in_width, stride_height, stride_widt
     else:
         raise Exception('Unknown padding type: {}'.format(pad_type))
 
     return (out_height, out_width, pad_top, pad_bottom, pad_left, pad_right)
+
+def compute_padding_2d_transpose(pad_type, in_height, in_width, stride_height, stride_width, filt_height, filt_width):
+    if pad_type.lower() == 'same':
+        #Height
+        out_height = stride_height * in_height
+        pad_along_height = max(filt_height - stride_height, 0)
+        pad_top = pad_along_height // 2
+        pad_bottom = pad_along_height - pad_top
+        #Width
+        out_width = stride_width * in_width
+        pad_along_width = max(filt_width - stride_width, 0)
+        pad_left = pad_along_width // 2
+        pad_right = pad_along_width - pad_left
+    elif pad_type.lower() == 'valid':
+        #No implicit padding; the kernel emits stride * in positions per axis
+        out_height = stride_height * in_height
+        out_width = stride_width * in_width
+
+        pad_top = 0
+        pad_bottom = 0
+        pad_left = 0
+        pad_right = 0
+    else:
+        raise Exception('Unknown padding type: {}'.format(pad_type))
+
+    return (out_height, out_width, pad_top, pad_bottom, pad_left, pad_right)
\ No newline at end of file
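Worked numbers for the 'same' transpose padding above (illustrative; assumes the patch is installed so the helper is importable): with in_size=8, stride=2, filt_size=5, the output grows to stride*in = 16 and the max(filt - stride, 0) = 3 pad pixels split as 1 left and 2 right:

    from hls4ml.converters.utils import compute_padding_1d_transpose

    n_out, pad_left, pad_right = compute_padding_1d_transpose('same', 8, 2, 5)
    assert (n_out, pad_left, pad_right) == (16, 1, 2)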
diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py
index b8a3a1a4d9..14b8b857f5 100644
--- a/hls4ml/model/layers.py
+++ b/hls4ml/model/layers.py
@@ -232,11 +232,11 @@ def add_output_variable(
 
         self.set_attr(out_name, out)
 
-    def add_weights(self, quantizer=None, compression=False):
+    def add_weights(self, quantizer=None, compression=False, keep_dims=0):
         data = self.model.get_weights_data(self.name, 'kernel')
 
         self.add_weights_variable(
-            name='weight', var_name='w{index}', data=data, quantizer=quantizer, compression=compression
+            name='weight', var_name='w{index}', data=data, quantizer=quantizer, compression=compression, keep_dims=keep_dims
         )
 
     def add_bias(self, quantizer=None):
@@ -254,7 +254,7 @@ def add_bias(self, quantizer=None):
         )
 
     def add_weights_variable(
-        self, name, var_name=None, type_name=None, precision=None, data=None, quantizer=None, compression=False
+        self, name, var_name=None, type_name=None, precision=None, data=None, quantizer=None, compression=False, keep_dims=0
     ):
         if var_name is None:
             var_name = name + '{index}'
@@ -300,7 +300,7 @@ def add_weights_variable(
             )
         else:
             var = WeightVariable(
-                var_name, type_name=type_name, precision=precision, quantizer=quantizer, data=data, index=self.index
+                var_name, type_name=type_name, precision=precision, quantizer=quantizer, data=data, index=self.index, keep_dims=keep_dims
             )
 
             var.data_unquantized = data_unquantized
@@ -423,6 +423,60 @@ def initialize(self):
         self.add_bias(quantizer=self.get_attr('bias_quantizer'))
 
+class Conv1DTranspose(Layer):
+    _expected_attributes = [
+        Attribute('in_width'),
+        Attribute('out_width'),
+
+        Attribute('n_chan'),
+        Attribute('n_filt'),
+
+        Attribute('filt_width'),
+        Attribute('stride_width'),
+
+        Attribute('pad_left'),
+        Attribute('pad_right'),
+
+        WeightAttribute('weight'),
+        WeightAttribute('bias'),
+
+        TypeAttribute('weight'),
+        TypeAttribute('bias'),
+    ]
+
+    def initialize(self):
+        if self.get_attr('data_format') == 'channels_last':
+            shape = [self.attributes['out_width'], self.attributes['n_filt']]
+            dims = ['N_OUTPUTS_{}'.format(self.index), 'N_FILT_{}'.format(self.index)]
+        else:
+            shape = [self.attributes['n_filt'], self.attributes['out_width']]
+            dims = ['N_FILT_{}'.format(self.index), 'N_OUTPUTS_{}'.format(self.index)]
+
+        data = self.model.get_weights_data(self.name, 'kernel')
+        # Transform the entire kernel:
+        # (W, F, C) => (F, W, C)
+        data = np.transpose(data, axes=[1, 0, 2])
+        # then split the kernel into stride_width phase kernels, (F, W, C) -> (S, F, trfilt_width, C),
+        # where trfilt_width = ceil(W / S); taps that fall outside the original kernel stay zero
+        n_filts, kern_width, n_chan = data.shape
+        new_weights = np.zeros((self.attributes['stride_width'], n_filts, self.attributes['trfilt_width'], n_chan))
+        for i_sw in range(self.attributes['stride_width']):
+            for i_fw in range(self.attributes['trfilt_width']):
+                filt_ind = i_sw + (self.attributes['trfilt_width'] - i_fw - 1) * self.attributes['stride_width']
+                for i_nf in range(n_filts):
+                    for i_nc in range(n_chan):
+                        if filt_ind < kern_width:
+                            new_weights[i_sw][i_nf][i_fw][i_nc] = data[i_nf][filt_ind][i_nc]
+        data = new_weights
+
+        self.add_output_variable(shape, dims)
+        self.add_weights_variable(name='weight', var_name='w{index}',
+            data=data, quantizer=self.get_attr('weight_quantizer'), keep_dims=1)
+        self.add_bias(quantizer=self.get_attr('bias_quantizer'))
+
 class SeparableConv1D(Layer):
     _expected_attributes = [
         Attribute('in_width'),
@@ -500,6 +554,69 @@ def initialize(self):
         self.add_weights(quantizer=self.get_attr('weight_quantizer'))
         self.add_bias(quantizer=self.get_attr('bias_quantizer'))
 
+class Conv2DTranspose(Layer):
+    _expected_attributes = [
+        Attribute('in_height'),
+        Attribute('in_width'),
+
+        Attribute('out_height'),
+        Attribute('out_width'),
+
+        Attribute('n_chan'),
+        Attribute('n_filt'),
+
+        Attribute('filt_height'),
+        Attribute('filt_width'),
+        Attribute('stride_height'),
+        Attribute('stride_width'),
+
+        Attribute('pad_top'),
+        Attribute('pad_bottom'),
+        Attribute('pad_left'),
+        Attribute('pad_right'),
+
+        WeightAttribute('weight'),
+        WeightAttribute('bias'),
+
+        TypeAttribute('weight'),
+        TypeAttribute('bias'),
+    ]
+
+    def initialize(self):
+        if self.get_attr('data_format') == 'channels_last':
+            shape = [self.attributes['out_height'], self.attributes['out_width'], self.attributes['n_filt']]
+            dims = ['OUT_HEIGHT_{}'.format(self.index), 'OUT_WIDTH_{}'.format(self.index), 'N_FILT_{}'.format(self.index)]
+        else:
+            shape = [self.attributes['n_filt'], self.attributes['out_height'], self.attributes['out_width']]
+            dims = ['N_FILT_{}'.format(self.index), 'OUT_HEIGHT_{}'.format(self.index), 'OUT_WIDTH_{}'.format(self.index)]
+
+        data = self.model.get_weights_data(self.name, 'kernel')
+        # Transform the entire kernel:
+        # (H, W, F, C) => (F, H, W, C)
+        data = np.transpose(data, axes=[2, 0, 1, 3])
+        # then split the kernel into stride_height x stride_width phase kernels,
+        # (F, H, W, C) -> (Sh, Sw, F, trfilt_height, trfilt_width, C)
+        n_filts, kern_height, kern_width, n_chan = data.shape
+        new_weights = np.zeros((self.attributes['stride_height'], self.attributes['stride_width'],
+            n_filts, self.attributes['trfilt_height'], self.attributes['trfilt_width'], n_chan))
+        for i_sh in range(self.attributes['stride_height']):
+            for i_sw in range(self.attributes['stride_width']):
+                for i_fh in range(self.attributes['trfilt_height']):
+                    for i_fw in range(self.attributes['trfilt_width']):
+                        filt_h_ind = i_sh + (self.attributes['trfilt_height'] - i_fh - 1) * self.attributes['stride_height']
+                        filt_w_ind = i_sw + (self.attributes['trfilt_width'] - i_fw - 1) * self.attributes['stride_width']
+                        for i_nf in range(n_filts):
+                            for i_nc in range(n_chan):
+                                if filt_h_ind < kern_height and filt_w_ind < kern_width:
+                                    new_weights[i_sh][i_sw][i_nf][i_fh][i_fw][i_nc] = \
+                                        data[i_nf][filt_h_ind][filt_w_ind][i_nc]
+        data = new_weights
+
+        self.add_output_variable(shape, dims)
+        self.add_weights_variable(name='weight', var_name='w{index}',
+            data=data, quantizer=self.get_attr('weight_quantizer'), keep_dims=2)
+        self.add_bias(quantizer=self.get_attr('bias_quantizer'))
+
 class Conv2DBatchnorm(Conv2D):
     def _get_folded_weights(self):
@@ -1274,6 +1391,8 @@ def _initialize_transforms(self):
     'Conv2D': Conv2D,
     'BinaryConv2D': Conv2D,
     'QConv2D': Conv2D,
+    'Conv1DTranspose': Conv1DTranspose,
+    'Conv2DTranspose': Conv2DTranspose,
     'QConv2DBatchnorm': Conv2DBatchnorm,
     'SeparableConv1D': SeparableConv1D,
     'SeparableConv2D': SeparableConv2D,
diff --git a/hls4ml/model/types.py b/hls4ml/model/types.py
index 115ff5cce0..2e94135038 100644
--- a/hls4ml/model/types.py
+++ b/hls4ml/model/types.py
@@ -104,9 +104,12 @@ def __str__(self):
         args = [self.width, self.integer, self.rounding_mode, self.saturation_mode, self.saturation_bits]
         args = ','.join([str(arg) for arg in args if arg is not None])
-        typestring = '{signed}fixed<{args}>'.format(signed='u' if not self.signed else '', args=args)
+        typestring = 'ap_{signed}fixed<{args}>'.format(signed='u' if not self.signed else '', args=args)
         return typestring
 
     def __eq__(self, other):
+        if not isinstance(other, FixedPrecisionType):
+            return False
         eq = self.width == other.width
         eq = eq and self.integer == other.integer
         eq = eq and self.fractional == other.fractional
@@ -224,7 +227,7 @@ def definition_cpp(self, name_suffix='', as_reference=False):
         return None
 
 class WeightVariable(Variable):
-    def __init__(self, var_name, type_name, precision, data, quantizer=None, **kwargs):
+    def __init__(self, var_name, type_name, precision, data, quantizer=None, keep_dims=0, **kwargs):
         super(WeightVariable, self).__init__(var_name, NamedType(type_name, precision, **kwargs), **kwargs)
         self.data = data
         self.nzeros = -1
@@ -237,6 +240,7 @@ def __init__(self, var_name, type_name, precision, data, quantizer=None, **kwarg
         self._iterator = None
         self.update_precision(precision)
         self.quantizer = quantizer
+        self.keep_dims = keep_dims
 
     def __iter__(self):
         self._iterator = np.nditer(self.data, order='C')
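The keep_dims plumbing above feeds StaticWeightVariableDefinition.definition_cpp in fpga_types.py: the first keep_dims axes stay separate C++ array dimensions and the remaining axes are flattened into one. A minimal sketch of that size computation (standalone mirror of the logic, with hypothetical shapes):

    def keep_dims_sizes(shape, keep_dims):
        # Mirrors StaticWeightVariableDefinition.definition_cpp above
        size_str = ''
        for dim in range(keep_dims):
            size_str += '[{}]'.format(shape[dim])
        final_dim = 1
        for dim in range(keep_dims, len(shape)):
            final_dim *= shape[dim]
        return size_str + '[{}]'.format(final_dim)

    # e.g. a Conv1DTranspose weight split into 2 stride phases, keep_dims=1:
    assert keep_dims_sizes((2, 4, 3, 8), 1) == '[2][96]'
    # and a Conv2DTranspose weight split into 2x2 phases, keep_dims=2:
    assert keep_dims_sizes((2, 2, 4, 3, 3, 8), 2) == '[2][2][288]'

This is what lets the HLS kernels index weights[i_sw][...] (1D) and weights[i_sh][i_sw][...] (2D) per stride phase.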
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h
new file mode 100644
index 0000000000..9794645721
--- /dev/null
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h
@@ -0,0 +1,50 @@
+#ifndef NNET_CONV1DTRANSPOSE_H_
+#define NNET_CONV1DTRANSPOSE_H_
+
+#include "nnet_common.h"
+#include "nnet_conv1dtranspose_resource.h"
+#include <cstdlib>
+
+namespace nnet {
+
+struct conv1dtranspose_config
+{
+    // Internal data type definitions
+    typedef float bias_t;
+    typedef float weight_t;
+    typedef float accum_t;
+
+    // Convolutional parameters
+    static const unsigned pad_left = 0;
+    static const unsigned pad_right = 0;
+    static const unsigned in_width = 10;
+    static const unsigned n_chan = 0;
+    static const unsigned filt_width = 1;
+    static const unsigned kernel_size = filt_width;
+    static const unsigned stride_width = 1;
+    static const unsigned dilation = 1;
+    static const unsigned out_width = 10;
+
+    static const unsigned reuse_factor = 1;
+    static const bool store_weights_in_bram = false;
+    static const unsigned n_zeros = 0;
+};
+
+template<class data_T, class res_T, typename CONFIG_T>
+void conv_1d_transpose_cl(
+    data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
+    res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+    typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][
+        CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
+    ],
+    typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+)
+{
+    #pragma HLS INLINE region
+    // For now, only the resource strategy is implemented
+    conv_1d_transpose_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+}
+
+}
+
+#endif
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h
new file mode 100644
index 0000000000..033dbf676e
--- /dev/null
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h
@@ -0,0 +1,132 @@
+#ifndef NNET_CONV1DTRANSPOSE_RESOURCE_H_
+#define NNET_CONV1DTRANSPOSE_RESOURCE_H_
+
+#include "nnet_common.h"
+#include "nnet_dense.h"
+
+namespace nnet {
+
+template<class data_T, class res_T, typename CONFIG_T>
+void conv_1d_transpose_resource_cl(
+    data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
+    res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+    typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][
+        CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
+    ],
+    typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+)
+{
+    constexpr unsigned mult_n_in = CONFIG_T::trfilt_width * CONFIG_T::n_chan;
+    constexpr unsigned mult_n_out = CONFIG_T::n_filt;
+    constexpr unsigned block_factor = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor);
+    constexpr unsigned multscale = block_factor / mult_n_out;
+
+    assert((block_factor % mult_n_out == 0 || CONFIG_T::reuse_factor >= mult_n_in) && "The current Reuse Factor is not allowed");
+    assert((CONFIG_T::reuse_factor <= CONFIG_T::trfilt_width * CONFIG_T::n_chan) && "This function is correct only for RF <= TRFILT_WIDTH * N_CHAN");
+
+    data_T data_buf[CONFIG_T::n_pixels][mult_n_in];
+    #pragma HLS ARRAY_PARTITION variable=data_buf complete dim=0
+    #pragma HLS ARRAY_PARTITION variable=biases complete
+
+    typename CONFIG_T::accum_t acc[CONFIG_T::n_pixels][mult_n_out][CONFIG_T::stride_width];
+    #pragma HLS ARRAY_PARTITION variable=acc complete dim=0
+
+    #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor dim=2
+
+    PartitionLoop:
+    for (unsigned i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) {
+        CONFIG_T::template fill_buffer<data_T, CONFIG_T>::fill_buffer(data, data_buf, i_part);
+
+        PixelInitAccumLoop:
+        for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) {
+            #pragma HLS UNROLL
+
+            InitAccumLoop:
+            for (unsigned i_acc = 0; i_acc < mult_n_out; i_acc++) {
+                #pragma HLS UNROLL
+
+                InitStrideLoop:
+                for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
+                    #pragma HLS UNROLL
+                    acc[i_pxl][i_acc][i_sw] = (typename CONFIG_T::accum_t) biases[i_acc];
+                }
+            }
+        }
+
+        ReuseLoop:
+        for (unsigned i_rf = 0; i_rf < CONFIG_T::reuse_factor; i_rf++) {
+            #pragma HLS PIPELINE II=1 rewind
+
+            unsigned i_w = i_rf;
+            unsigned i_in = i_rf;
+            unsigned i_out = 0;
+            unsigned i_acc = 0;
+
+            MultLoop:
+            for (unsigned i_blk = 0; i_blk < block_factor; i_blk++) {
+                #pragma HLS UNROLL
+
+                PixelMultLoop:
+                for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) {
+                    #pragma HLS UNROLL
+
+                    StrideMultLoop:
+                    for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
+                        #pragma HLS UNROLL
+
+                        acc[i_pxl][i_out][i_sw] += static_cast<typename CONFIG_T::accum_t>(
+                            CONFIG_T::mult_config::template product<data_T, typename CONFIG_T::mult_config::weight_t>::product(
+                                data_buf[i_pxl][i_in], weights[i_sw][i_w]
+                            )
+                        );
+                    }
+                }
+
+                // Increment i_w
+                i_w += CONFIG_T::reuse_factor;
+                // Increment i_in
+                i_in += CONFIG_T::reuse_factor;
+                if (i_in >= mult_n_in) {
+                    i_in = i_rf;
+                }
+                // Increment i_out
+                if (i_acc + 1 >= multscale) {
+                    i_acc = 0;
+                    i_out++;
+                } else {
+                    i_acc++;
+                }
+            }
+        }
+
+        PixelResultLoop:
+        for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) {
+            #pragma HLS UNROLL
+
+            StrideResultLoop:
+            for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
+                #pragma HLS UNROLL
+
+                unsigned output_index = i_pxl * CONFIG_T::n_partitions * CONFIG_T::stride_width
+                    + i_part * CONFIG_T::stride_width + i_sw;
+
+                if (output_index >= CONFIG_T::pad_left &&
+                        output_index < CONFIG_T::out_width + CONFIG_T::pad_left) {
+                    ResultLoop:
+                    for (unsigned i_res = 0; i_res < mult_n_out; i_res++) {
+                        #pragma HLS UNROLL
+
+                        res[(output_index - CONFIG_T::pad_left) * mult_n_out + i_res] =
+                            cast<data_T, res_T, CONFIG_T>(acc[i_pxl][i_res][i_sw]);
+                    }
+                }
+            }
+        }
+    }
+}
+
+}
+#endif
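In the 1D resource kernel above, each (pixel, partition, stride-phase) triple maps to output_index = i_pxl * n_partitions * stride + i_part * stride + i_sw. A quick sanity check, with toy sizes and purely for illustration, that this mixed-radix mapping enumerates every padded output position exactly once:

    n_partitions, n_pixels, stride_width = 3, 4, 2  # toy sizes; n_pixels = proc_width / n_partitions

    indices = sorted(
        i_pxl * n_partitions * stride_width + i_part * stride_width + i_sw
        for i_part in range(n_partitions)
        for i_pxl in range(n_pixels)
        for i_sw in range(stride_width)
    )
    assert indices == list(range(n_pixels * n_partitions * stride_width))

The pad_left guard then drops the positions that belong to the 'same' cropping before writing to res.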
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h
new file mode 100644
index 0000000000..d9dca82007
--- /dev/null
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h
@@ -0,0 +1,141 @@
+#ifndef NNET_CONV1DTRANSPOSE_STREAM_H
+#define NNET_CONV1DTRANSPOSE_STREAM_H
+
+#include "nnet_common.h"
+#include "nnet_conv_stream.h"
+#include "hls_stream.h"
+
+namespace nnet {
+
+template<class data_T, typename CONFIG_T>
+void kernel_shift_tr_1d(
+    const data_T& in_elem,
+    typename data_T::value_type kernel_window[CONFIG_T::trfilt_width * CONFIG_T::n_chan]
+) {
+    #pragma HLS INLINE
+
+    // Shift kernel_window by one step to the left (manual shift operation)
+    static const int filt_width = CONFIG_T::trfilt_width - 1;
+    KernelShiftWidth:
+    for (int i_iw = 0; i_iw < filt_width; i_iw++) {
+        #pragma HLS PIPELINE II = 1
+        KernelShiftChannel:
+        for (unsigned i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) {
+            #pragma HLS UNROLL
+            // Shift every element in kernel_window to the left
+            kernel_window[i_iw * CONFIG_T::n_chan + i_ic] = kernel_window[(i_iw + 1) * CONFIG_T::n_chan + i_ic];
+        }
+    }
+
+    // Insert the new pixel into the right-most column of the kernel window
+    static const int lastheight = (CONFIG_T::trfilt_width - 1) * CONFIG_T::n_chan;
+    KernelPushChannel:
+    for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) {
+        #pragma HLS UNROLL
+        kernel_window[lastheight + i_ic] = in_elem[i_ic];
+    }
+}
+
+// Conv 1D transpose compute output
+template<class data_T, class res_T, typename CONFIG_T>
+void compute_output_buffer_tr_1d(
+    const data_T& in_elem,
+    hls::stream<res_T> &res_stream,
+    typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][
+        CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
+    ],
+    typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+)
+{
+    #pragma HLS INLINE
+
+    // Thresholds
+    const static int lShiftX = CONFIG_T::trfilt_width - 1;
+
+    // Counters
+    static int pX = 0; // pixel counter
+    static int oX = 0; // output counter (deals with 'padding')
+
+    static typename data_T::value_type kernel_data[CONFIG_T::trfilt_width * CONFIG_T::n_chan];
+    #pragma HLS ARRAY_PARTITION variable=kernel_data complete
+
+    typename res_T::value_type res_out[CONFIG_T::n_filt];
+    #pragma HLS ARRAY_PARTITION variable=res_out complete dim=0
+
+    res_T res_pack;
+    #pragma HLS DATA_PACK variable=res_pack
+
+    // Add pixel to buffer
+    nnet::kernel_shift_tr_1d<data_T, CONFIG_T>(in_elem, kernel_data);
+
+    // Always do stride_width many multiplications
+    StrideLoop:
+    for (int idx = 0; idx < CONFIG_T::stride_width; idx++) {
+        #pragma HLS UNROLL
+        #pragma HLS INLINE region
+        // Dense multiply
+        if (CONFIG_T::strategy == nnet::latency) {
+            dense_latency<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
+                kernel_data, res_out, weights[idx], biases);
+        } else {
+            dense_resource<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
+                kernel_data, res_out, weights[idx], biases);
+        }
+
+        // Pack the output and write it to the stream when it is ready
+        if (oX >= CONFIG_T::pad_left && oX < CONFIG_T::pad_left + CONFIG_T::out_width) {
+            CastLoop:
+            for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) {
+                #pragma HLS UNROLL
+                res_pack[i_ic] = res_out[i_ic];
+            }
+            res_stream.write(res_pack);
+        }
+        oX++;
+    }
+
+    // Static variable housekeeping
+    if (pX + 1 == CONFIG_T::in_width) // done with all of the inputs
+    {
+        pX = 0;
+        oX = 0;
+    } else {
+        pX = pX + 1;
+    }
+}
+
+template<class data_T, class res_T, typename CONFIG_T>
+void conv_1d_transpose_buffer_cl(
+    hls::stream<data_T> &data,
+    hls::stream<res_T> &res,
+    typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][
+        CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
+    ],
+    typename CONFIG_T::bias_t biases[CONFIG_T::n_filt])
+{
+    ReadInputWidth:
+    for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) {
+        #pragma HLS LOOP_FLATTEN
+        compute_output_buffer_tr_1d<data_T, res_T, CONFIG_T>(data.read(), res, weights, biases);
+    }
+}
+
+template<class data_T, class res_T, typename CONFIG_T>
+void conv_1d_transpose_cl(
+    hls::stream<data_T> &data,
+    hls::stream<res_T> &res,
+    typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][
+        CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
+    ],
+    typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+)
+{
+    switch (CONFIG_T::implementation) {
+        #pragma HLS inline region
+        case conv_implementation::linebuffer:
+            conv_1d_transpose_buffer_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+            break;
+    }
+}
+
+}
+#endif
+// TODO: pad the input or clear the kernel window between inferences
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h
new file mode 100644
index 0000000000..aec7eee7f0
--- /dev/null
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h
@@ -0,0 +1,60 @@
+#ifndef NNET_CONV2DTRANSPOSE_H
+#define NNET_CONV2DTRANSPOSE_H
+
+#include "nnet_common.h"
+#include "nnet_conv2dtranspose_resource.h"
+#include <cstdlib>
+
+namespace nnet {
+
+struct conv2dtranspose_config
+{
+    // Internal data type definitions
+    typedef float bias_t;
+    typedef float weight_t;
+    typedef float accum_t;
+
+    // Convolutional parameters
+    static const unsigned pad_top = 0;
+    static const unsigned pad_bottom = 0;
+    static const unsigned pad_left = 0;
+    static const unsigned pad_right = 0;
+    static const unsigned in_height = 10;
+    static const unsigned in_width = 10;
+    static const unsigned n_chan = 1;
+    static const unsigned filt_height = 1;
+    static const unsigned filt_width = 1;
+    static const unsigned kernel_size = filt_height * filt_width;
+    static const unsigned n_filt = 1;
+    static const unsigned stride_height = 1;
+    static const unsigned stride_width = 1;
+    static const unsigned out_height = 10;
+    static const unsigned out_width = 10;
+    static const unsigned dilation_height = 1;
+    static const unsigned dilation_width = 1;
+    static const unsigned trfilt_height = 1;
+    static const unsigned trfilt_width = 1;
+
+    static const unsigned reuse_factor = 1;
+    static const bool store_weights_in_bram = false;
+    static const unsigned n_zeros = 0; // not used yet
+};
+
+template<class data_T, class res_T, typename CONFIG_T>
+void conv_2d_transpose_cl(
+    data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan],
+    res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt],
+    typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width][
+        CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
+    ],
+    typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+)
+{
+    #pragma HLS INLINE region
+    // Only the resource strategy is implemented as of now
+    conv_2d_transpose_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+}
+
+}
+
+#endif
\ No newline at end of file
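The 2D resource kernel that follows uses the analogous output mapping: px_ind = i_pxl * n_partitions + i_part selects a processing pixel, and (px_ind // proc_width, px_ind % proc_width) scaled by the strides plus the phase offsets gives the padded output coordinate. A toy Python check, illustrative only, that this covers the (proc_height * stride_height) x (proc_width * stride_width) grid exactly once:

    n_partitions, proc_height, proc_width = 2, 2, 3
    stride_height, stride_width = 2, 2
    n_pixels = proc_height * proc_width // n_partitions

    coords = set()
    for i_part in range(n_partitions):
        for i_pxl in range(n_pixels):
            px_ind = i_pxl * n_partitions + i_part
            for i_sh in range(stride_height):
                for i_sw in range(stride_width):
                    height_ind = (px_ind // proc_width) * stride_height + i_sh
                    width_ind = (px_ind % proc_width) * stride_width + i_sw
                    coords.add((height_ind, width_ind))
    assert len(coords) == proc_height * stride_height * proc_width * stride_width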
+    #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor dim=3
+
+    PartitionLoop:
+    for (unsigned i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) {
+        CONFIG_T::template fill_buffer<data_T, CONFIG_T>::fill_buffer(data, data_buf, i_part);
+
+        PixelInitAccumLoop:
+        for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) {
+            #pragma HLS UNROLL
+
+            InitAccumLoop:
+            for (unsigned i_acc = 0; i_acc < mult_n_out; i_acc++) {
+                #pragma HLS UNROLL
+
+                InitStrideHeightLoop:
+                for (unsigned i_sh = 0; i_sh < CONFIG_T::stride_height; i_sh++) {
+                    #pragma HLS UNROLL
+
+                    InitStrideWidthLoop:
+                    for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
+                        #pragma HLS UNROLL
+                        acc[i_pxl][i_acc][i_sh][i_sw] = (typename CONFIG_T::accum_t) biases[i_acc];
+                    }
+                }
+            }
+        }
+
+        ReuseLoop:
+        for (unsigned i_rf = 0; i_rf < CONFIG_T::reuse_factor; i_rf++) {
+            #pragma HLS PIPELINE II=1 rewind
+
+            unsigned i_w = i_rf;
+            unsigned i_in = i_rf;
+            unsigned i_out = 0;
+            unsigned i_acc = 0;
+
+            MultLoop:
+            for (unsigned i_blk = 0; i_blk < block_factor; i_blk++) {
+                #pragma HLS UNROLL
+                PixelMultLoop:
+                for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) {
+                    #pragma HLS UNROLL
+                    StrideHeightMultLoop:
+                    for (unsigned i_sh = 0; i_sh < CONFIG_T::stride_height; i_sh++) {
+                        #pragma HLS UNROLL
+                        StrideWidthMultLoop:
+                        for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
+                            #pragma HLS UNROLL
+
+                            acc[i_pxl][i_out][i_sh][i_sw] += static_cast<typename CONFIG_T::accum_t>(
+                                CONFIG_T::mult_config::template product<data_T, typename CONFIG_T::weight_t>::product(
+                                    data_buf[i_pxl][i_in], weights[i_sh][i_sw][i_w]
+                                )
+                            );
+                        }
+                    }
+                }
+
+                // Increment i_w
+                i_w += CONFIG_T::reuse_factor;
+                // Increment i_in
+                i_in += CONFIG_T::reuse_factor;
+                if (i_in >= mult_n_in) {
+                    i_in = i_rf;
+                }
+                // Increment i_out
+                if (i_acc + 1 >= multscale) {
+                    i_acc = 0;
+                    i_out++;
+                } else {
+                    i_acc++;
+                }
+            }
+        }
+
+        PixelResultLoop:
+        for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) {
+            #pragma HLS UNROLL
+
+            StrideHeightResultLoop:
+            for (unsigned i_sh = 0; i_sh < CONFIG_T::stride_height; i_sh++) {
+                #pragma HLS UNROLL
+                StrideWidthResultLoop:
+                for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
+                    #pragma HLS UNROLL
+
+                    unsigned px_ind = i_pxl * CONFIG_T::n_partitions + i_part;
+                    unsigned height_ind = (px_ind / CONFIG_T::proc_width) * CONFIG_T::stride_height + i_sh;
+                    unsigned width_ind = (px_ind % CONFIG_T::proc_width) * CONFIG_T::stride_width + i_sw;
+
+                    if (height_ind >= CONFIG_T::pad_top && height_ind < CONFIG_T::out_height + CONFIG_T::pad_top &&
+                        width_ind >= CONFIG_T::pad_left && width_ind < CONFIG_T::out_width + CONFIG_T::pad_left) {
+                        ResultLoop: for (unsigned i_res = 0; i_res < mult_n_out; i_res++) {
+                            #pragma HLS UNROLL
+
+                            res[((height_ind - CONFIG_T::pad_top) * CONFIG_T::out_width + width_ind - CONFIG_T::pad_left) *
+                                CONFIG_T::n_filt + i_res] =
+                                cast<data_T, res_T, CONFIG_T>(
+                                    acc[i_pxl][i_res][i_sh][i_sw]
+                                );
+                        }
+                    }
+                }
+            }
+
+        }
+    }
+}
+
+}
+
+#endif
\ No newline at end of file
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h
new file mode 100644
index 0000000000..eff09cac91
--- /dev/null
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h
@@ -0,0 +1,209 @@
+#ifndef NNET_CONV2DTRANSPOSE_STREAM_H
+#define NNET_CONV2DTRANSPOSE_STREAM_H
+
+#include "ap_shift_reg.h"
+#include "nnet_conv_stream.h"
+#include "nnet_common.h"
+#include "hls_stream.h"
+
+namespace nnet {
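+
+// Streaming variant: pixels arrive one at a time (all channels together); a
+// line buffer of trfilt_height - 1 rows rebuilds the sliding trfilt window.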
+template <class data_T, typename CONFIG_T>
+void kernel_shift_tr_2d(
+    typename data_T::value_type shift_buffer[CONFIG_T::trfilt_height][CONFIG_T::n_chan],
+    typename data_T::value_type kernel_window[CONFIG_T::trfilt_width * CONFIG_T::trfilt_height * CONFIG_T::n_chan]
+) {
+    #pragma HLS inline
+
+    // Shift kernel_window by one step to the left (manual shift operation)
+    static const int filt_width = CONFIG_T::trfilt_width - 1;
+    KernelShiftWidth: for (int i_iw = 0; i_iw < filt_width; i_iw++) {
+        #pragma HLS PIPELINE II = 1
+        KernelShiftHeight: for (unsigned i_ih = 0; i_ih < CONFIG_T::trfilt_height; i_ih++) {
+            KernelShiftChannel: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) {
+                // Shift every element in kernel_window to the left
+                kernel_window[i_ih * CONFIG_T::trfilt_width * CONFIG_T::n_chan + i_iw * CONFIG_T::n_chan + i_ic] =
+                    kernel_window[i_ih * CONFIG_T::trfilt_width * CONFIG_T::n_chan + (i_iw + 1) * CONFIG_T::n_chan + i_ic];
+            }
+        }
+    }
+
+    // Insert shift_buffer column into right-most column of kernel
+    static const int lastheight = (CONFIG_T::trfilt_width - 1) * CONFIG_T::n_chan;
+    KernelPushHeight: for (int i_ih = 0; i_ih < CONFIG_T::trfilt_height; i_ih++) {
+        #pragma HLS UNROLL
+        KernelPushChannel: for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) {
+            kernel_window[lastheight + i_ih * CONFIG_T::trfilt_width * CONFIG_T::n_chan + i_ic] = shift_buffer[i_ih][i_ic];
+        }
+    }
+}
+
+template <class data_T, typename CONFIG_T>
+void shift_line_buffer_tr(
+    const data_T& in_elem,
+    ap_shift_reg<typename data_T::value_type, CONFIG_T::in_width> line_buffer[MAX(CONFIG_T::trfilt_height - 1, 1)][CONFIG_T::n_chan],
+    typename data_T::value_type kernel_window[CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan]
+) {
+
+    #pragma HLS PIPELINE
+
+    // Temporary buffer for popped (shifted) elements
+    typename data_T::value_type shift_buffer[CONFIG_T::trfilt_height][CONFIG_T::n_chan];
+    #pragma HLS ARRAY_PARTITION variable = shift_buffer complete dim = 0
+
+    UpdateBuffer: for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) {
+        #pragma HLS UNROLL
+
+        // Insert pixel(s) at end of shift buffer
+        shift_buffer[CONFIG_T::trfilt_height - 1][i_ic] = in_elem[i_ic];
+    }
+
+    LineBufferDataIn: for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) {
+        // Shift the shift buffer into the line buffer
+        LineBufferShift: for (unsigned i_ih = 1; i_ih < CONFIG_T::trfilt_height; i_ih++) {
+            #pragma HLS UNROLL
+            // Shift the line buffer and return the popped pixel
+            typename data_T::value_type pop_elem =
+                line_buffer[i_ih - 1][i_ic].shift(shift_buffer[CONFIG_T::trfilt_height - i_ih][i_ic]);
+            // Place the popped element back into shift_buffer, one row up
+            shift_buffer[CONFIG_T::trfilt_height - i_ih - 1][i_ic] = pop_elem;
+        }
+    }
+    kernel_shift_tr_2d<data_T, CONFIG_T>(shift_buffer, kernel_window);
+}
+
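+// Per input pixel, evaluate the stride_height * stride_width dense kernels and
+// buffer their outputs; a full row of outputs is flushed to the stream once the
+// input row is complete, skipping rows and columns that fall in the padding.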
+template <class data_T, class res_T, typename CONFIG_T>
+void compute_output_buffer_tr_2d(
+    const data_T& in_elem,
+    ap_shift_reg<typename data_T::value_type, CONFIG_T::in_width> line_buffer[MAX(CONFIG_T::trfilt_height - 1, 1)][CONFIG_T::n_chan],
+    hls::stream<res_T> &res_stream,
+    typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width][
+        CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
+    ],
+    typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+)
+{
+    #pragma HLS INLINE
+
+    // Counters
+    static int pX = 0; // pixel counters
+    static int pY = 0;
+
+    static typename data_T::value_type kernel_data[CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan];
+    #pragma HLS ARRAY_PARTITION variable=kernel_data complete
+
+    typename res_T::value_type res_out[CONFIG_T::n_filt];
+    #pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0
+
+    static typename res_T::value_type output_buffer[
+        CONFIG_T::in_width * CONFIG_T::stride_width * CONFIG_T::stride_height * CONFIG_T::n_filt
+    ];
+    #pragma HLS ARRAY_PARTITION variable=output_buffer complete dim = 0
+
+    res_T res_pack;
+    #pragma HLS DATA_PACK variable = res_pack
+
+    // Add pixel to the buffer
+    nnet::shift_line_buffer_tr<data_T, CONFIG_T>(in_elem, line_buffer, kernel_data);
+
+    WidthStrideLoop: for (int w_idx = 0; w_idx < CONFIG_T::stride_width; w_idx++) {
+        // #pragma HLS PIPELINE
+        #pragma HLS UNROLL
+        HeightStrideLoop: for (int h_idx = 0; h_idx < CONFIG_T::stride_height; h_idx++) {
+            #pragma HLS UNROLL
+
+            #pragma HLS INLINE region
+
+            if (CONFIG_T::strategy == nnet::latency) {
+                dense_latency<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
+                    kernel_data, res_out, weights[h_idx][w_idx], biases
+                );
+            } else {
+                dense_resource<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
+                    kernel_data, res_out, weights[h_idx][w_idx], biases
+                );
+            }
+
+            BufferOutputLoop: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) {
+                #pragma HLS UNROLL
+                output_buffer[
+                    (pX * CONFIG_T::stride_width + w_idx) * CONFIG_T::stride_height * CONFIG_T::n_filt
+                    + h_idx * CONFIG_T::n_filt + i_ic
+                ] = res_out[i_ic];
+            }
+        }
+    }
+
+    // Counter housekeeping and writing buffered output
+    if (pX + 1 == CONFIG_T::in_width) {
+        pX = 0;
+        // Write all of the buffered output rows that fall inside the un-padded range
+        HeightOutputLoop: for (unsigned h_idx = 0; h_idx < CONFIG_T::stride_height; h_idx++) {
+            // #pragma HLS PIPELINE
+            if (pY * CONFIG_T::stride_height + h_idx >= CONFIG_T::pad_top &&
+                pY * CONFIG_T::stride_height + h_idx < CONFIG_T::pad_top + CONFIG_T::out_height) {
+                WidthOutputLoop: for (unsigned oX = CONFIG_T::pad_left; oX < CONFIG_T::pad_left + CONFIG_T::out_width; oX++) {
+                    #pragma HLS PIPELINE
+                    CastLoop: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) {
+                        #pragma HLS UNROLL
+                        res_pack[i_ic] = output_buffer[
+                            oX * CONFIG_T::stride_height * CONFIG_T::n_filt
+                            + h_idx * CONFIG_T::n_filt + i_ic
+                        ];
+                    }
+                    res_stream.write(res_pack);
+                }
+            }
+        }
+
+        if (pY + 1 == CONFIG_T::in_height) {
+            pY = 0;
+        } else {
+            pY = pY + 1;
+        }
+    } else {
+        pX = pX + 1;
+    }
+}
+
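+// Reads the input stream pixel by pixel and delegates to the row-buffered
+// computation above.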
+template <class data_T, class res_T, typename CONFIG_T>
+void conv_2d_transpose_buffer_cl(
+    hls::stream<data_T> &data,
+    hls::stream<res_T> &res,
+    typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width][
+        CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
+    ],
+    typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+)
+{
+    static ap_shift_reg<typename data_T::value_type, CONFIG_T::in_width> line_buffer[MAX(CONFIG_T::trfilt_height - 1, 1)][CONFIG_T::n_chan];
+    #pragma HLS ARRAY_PARTITION variable = line_buffer complete dim = 2
+
+    ReadInputHeight: for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) {
+        ReadInputWidth: for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) {
+            #pragma HLS LOOP_FLATTEN
+            if (CONFIG_T::strategy == nnet::latency) {
+                #pragma HLS PIPELINE II=CONFIG_T::reuse_factor
+            }
+            compute_output_buffer_tr_2d<data_T, res_T, CONFIG_T>(data.read(), line_buffer, res, weights, biases);
+        }
+    }
+}
+
+template <class data_T, class res_T, typename CONFIG_T>
+void conv_2d_transpose_cl(
+    hls::stream<data_T> &data,
+    hls::stream<res_T> &res,
+    typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width][
+        CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
+    ],
+    typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+)
+{
+    #pragma HLS INLINE region
+    switch (CONFIG_T::implementation) {
+        case conv_implementation::linebuffer:
+            conv_2d_transpose_buffer_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+            break;
+    }
+}
+
+}
+#endif
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h b/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h
index eed64fc701..9d030772fd 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h
@@ -67,6 +67,83 @@ void load_weights_from_txt(T *w, const char* fname) {
     }
 }
 
+template<class T, size_t DIM_1, size_t DIM_2>
+void load_weights_from_txt(T w[DIM_1][DIM_2], const char* fname) {
+
+    std::string full_path = std::string(WEIGHTS_DIR) + "/" + std::string(fname);
+    std::ifstream infile(full_path.c_str(), std::ios::binary);
+
+    if (infile.fail()) {
+        std::cerr << "ERROR: file " << std::string(fname) << " does not exist" << std::endl;
+        exit(1);
+    }
+
+    std::string line;
+    if (std::getline(infile, line)) {
+        std::istringstream iss(line);
+        std::string token;
+
+        size_t i = 0;
+        size_t j = 0;
+        size_t tot = 0;
+        while (std::getline(iss, token, ',')) {
+            std::istringstream(token) >> w[i][j];
+            j++;
+            if (j == DIM_2) {
+                j = 0;
+                i++;
+            }
+            tot++;
+        }
+
+        if (DIM_1 * DIM_2 != tot) {
+            std::cerr << "ERROR: Expected " << DIM_1 * DIM_2 << " values";
+            std::cerr << " but read only " << tot << " values" << std::endl;
+        }
+    }
+}
+
+template<class T, size_t DIM_1, size_t DIM_2, size_t DIM_3>
+void load_weights_from_txt(T w[DIM_1][DIM_2][DIM_3], const char* fname) {
+
+    std::string full_path = std::string(WEIGHTS_DIR) + "/" + std::string(fname);
+    std::ifstream infile(full_path.c_str(), std::ios::binary);
+
+    if (infile.fail()) {
+        std::cerr << "ERROR: file " << std::string(fname) << " does not exist" << std::endl;
+        exit(1);
+    }
+
+    std::string line;
+    if (std::getline(infile, line)) {
+        std::istringstream iss(line);
+        std::string token;
+
+        size_t i = 0;
+        size_t j = 0;
+        size_t k = 0;
+        size_t tot = 0;
+        while (std::getline(iss, token, ',')) {
+            std::istringstream(token) >> w[i][j][k];
+            k++;
+            if (k == DIM_3) {
+                k = 0;
+                j++;
+                if (j == DIM_2) {
+                    j = 0;
+                    i++;
+                }
+            }
+            tot++;
+        }
+
+        if (DIM_1 * DIM_2 * DIM_3 != tot) {
+            std::cerr << "ERROR: Expected " << DIM_1 * DIM_2 * DIM_3 << " values";
+            std::cerr << " but read only " << tot << " values" << std::endl;
+        }
+    }
+}
+
 template<class T, size_t SIZE>
 void load_compressed_weights_from_txt(T *w, const char* fname) {
diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py
index bcf752b835..152069436a 100644
--- a/hls4ml/writer/vivado_writer.py
+++ b/hls4ml/writer/vivado_writer.py
@@ -43,17 +43,27 @@ def print_array_to_cpp(self, var, odir, write_txt_file=True):
         h_file.write(var.definition_cpp() + ";\n")
         h_file.write("#else\n")
-        h_file.write(var.definition_cpp() + " = {")
-
-        # fill c++ array.
-        # not including internal brackets for multidimensional case
-        sep = ''
-        for x in var:
-            h_file.write(sep + x)
+        h_file.write(var.definition_cpp() + " = ")
+
+        factors = np.ones(len(var.shape)+1)
+        for idx in range(len(var.shape)-1, -1, -1):
+            factors[idx] = var.shape[idx] * factors[idx+1]
+        # fill c++ array, keeping the first keep_dims dimensions intact
+        for idx, x in enumerate(var):
+            for dim in range(var.keep_dims+1):
+                if idx % factors[dim] == 0:
+                    h_file.write("{")
+            h_file.write(x)
             if write_txt_file:
-                txt_file.write(sep + x)
-            sep = ", "
-        h_file.write("};\n")
+                txt_file.write(x)
+            for dim in range(var.keep_dims+1):
+                if idx % factors[dim] == factors[dim]-1:
+                    h_file.write("}")
+            if idx < factors[0]-1:  # no comma after the last element
+                h_file.write(", ")
+                if write_txt_file:
+                    txt_file.write(", ")
+        h_file.write(";\n")
         if write_txt_file:
             h_file.write("#endif\n")
             txt_file.close()
@@ -150,9 +160,12 @@ def write_project_cpp(self, model):
                         w.type.name, w.data_length, w.name, w.name
                     )
                 else:
-                    newline += indent + '    nnet::load_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format(
-                        w.type.name, w.data_length, w.name, w.name
-                    )
+                    dim_info = w.data_length
+                    if w.keep_dims == 1:
+                        dim_info = '{}, {}'.format(w.shape[0], w.data_length//w.shape[0])
+                    elif w.keep_dims == 2:
+                        dim_info = '{}, {}, {}'.format(w.shape[0], w.shape[1], w.data_length//(w.shape[0]*w.shape[1]))
+                    newline += indent + '    nnet::load_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format(w.type.name, dim_info, w.name, w.name)
 
             # Add input/output type
             elif '//hls-fpga-machine-learning insert IO' in line: