From c1ea360093aae44eaf3009539d2a9ec4948a07d7 Mon Sep 17 00:00:00 2001 From: Jonathan-Shoemaker Date: Mon, 20 Jun 2022 12:27:03 -0500 Subject: [PATCH 1/5] attempt to add support for conv1d transpose add new files for conv1dtranspose resource clean up so that conv code is reached. Still need to get the actual implementation matching keras implement conv1dtranspose super inefficiently (gets correct answer though) try to fix indices to make code work make the c code work for conv1dtranspose reduce weight dimensions to properly reflect transposed kernel size clean up so that transpose filter width is passes around from config fix code such that simple transpose layer gets synthesized move variables out of loops, optimize slightly and add in alternative method of computation to compute by kernel (that option is not optimized as of now) add in conv1d transpose linebuffer format code. seems to work, unsure of if it is optimized yet trying to fix stream behavior get transpose compilation working mostly as expected. weird jump in latency from reuse 1 to 2 still exists initial conv2dtranspose addition. Output is permuted as of now. output in correct order. using large array to buffer output though fix up conv1dtranspose a bit to pad correctly. fix up stream instructions for both 1d and 2d transposes fix allowed reuse factors for transpose layers update to new conv methods for io_parallel. Still some issues with multiple filters as well as some padding issues clean up error with multiple filters and larger kernels optimize conv transpose resource to get it working reasonably well. may still have slight optimization left fix output to conv1d transpose resource add conv2dtranspose io_parallel implementation. Can still be optimized small changeup to data storage in conv1d parallel fix zero padding pass addition for transpose stream layers move transposing of weight matrix to resource_strategy for transpose layers change how stream loads in weights to be like parallel for conv transposes. unroll all stride steps completely fix output of 1d transpose parallel to be faster change 1d transpose weight input to be 2-dimensional (passed from python code) change 2d transpose weight input to be 3-dimensional (passed from python code) small changes to transposes Revert "fix nondefault project name handling (#626)". The commit breaks the Vivado Accelerator workflow, and the fix is unclear to me right now. This reverts commit e8f048ad2a49c067eb5e49740a5d94c7c1e33b24. steps towards getting integer inputs to work --- hls4ml/backends/fpga/fpga_backend.py | 161 +++++++++++ hls4ml/backends/fpga/fpga_types.py | 9 + hls4ml/backends/fpga/passes/codegen.py | 38 ++- .../backends/vivado/passes/conv_same_pad.py | 112 +++++++- hls4ml/backends/vivado/passes/conv_stream.py | 26 +- .../vivado/passes/convolution_templates.py | 164 +++++++++++- .../vivado/passes/resource_strategy.py | 3 +- hls4ml/backends/vivado/vivado_backend.py | 68 +++++ hls4ml/converters/keras/convolution.py | 83 +++++- hls4ml/converters/utils.py | 41 +++ hls4ml/model/layers.py | 127 ++++++++- hls4ml/model/types.py | 6 +- hls4ml/report/vivado_report.py | 20 +- hls4ml/templates/vivado/build_prj.tcl | 252 +++++++++--------- .../vivado/nnet_utils/nnet_conv1dtranspose.h | 50 ++++ .../nnet_conv1dtranspose_resource.h | 132 +++++++++ .../nnet_utils/nnet_conv1dtranspose_stream.h | 141 ++++++++++ .../vivado/nnet_utils/nnet_conv2dtranspose.h | 60 +++++ .../nnet_conv2dtranspose_resource.h | 148 ++++++++++ .../nnet_utils/nnet_conv2dtranspose_stream.h | 209 +++++++++++++++ .../vivado/nnet_utils/nnet_helpers.h | 77 ++++++ hls4ml/templates/vivado/vivado_synth.tcl | 7 +- .../alveo/tcl_scripts/axi_stream_design.tcl | 8 +- hls4ml/writer/vivado_writer.py | 43 +-- 24 files changed, 1808 insertions(+), 177 deletions(-) create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_resource.h create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index ec496c2104..205ff292c5 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -158,6 +158,22 @@ def get_layer_mult_size(self, layer): n_out = layer.get_attr('n_out') return n_in, n_out + if 'Conv1DTranspose' in layer.class_name: + trfilt_width = (layer.get_attr('filt_width') + layer.get_attr('stride_width') - 1) \ + // layer.get_attr('stride_width') + n_in = layer.get_attr('n_chan') * trfilt_width + n_out = layer.get_attr('n_filt') + return n_in, n_out + + if 'Conv2DTranspose' in layer.class_name: + trfilt_width = (layer.get_attr('filt_width') + layer.get_attr('stride_width') - 1) \ + // layer.get_attr('stride_width') + trfilt_height = (layer.get_attr('filt_height') + layer.get_attr('stride_height') - 1) \ + // layer.get_attr('stride_height') + n_in = layer.get_attr('n_chan') * trfilt_height * trfilt_width + n_out = layer.get_attr('n_filt') + return n_in, n_out + if 'Conv1D' in layer.class_name: n_in = layer.get_attr('n_chan') * layer.get_attr('filt_width') n_out = layer.get_attr('n_filt') @@ -713,7 +729,67 @@ def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, ke " ) {{\n" ).format(index=layer_idx) indent = ' ' + for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)): + generated_code += indent * 2 + 'if (partition == {:>3}) {{\n'.format(partition_idx) + for pixel_idx, arr in enumerate(partition): + buffer_stmts = [] + for j, v in enumerate(arr): + if v == 0: + val = '0' + else: + val = 'data[{}]'.format(int(v-1)) + buffer_stmts.append('buffer[{}][{}] = {:>10};'.format(pixel_idx, j, val)) + generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n' + generated_code += '\n' + indent * 2 + '}\n' + + generated_code += indent + '}\n' + generated_code += '};\n' + + return generated_code + + def _compute_conv1d_tr_im2col(self, input_shape, out_w, kernel=3, stride=1): + W, C = input_shape + + tr_kernel = (kernel+stride-1)//stride + + input_img = np.arange(1, W * C + 1) + im_matrix = np.zeros((tr_kernel * C * out_w, )) + + index = 0 + for i_ow in range(out_w): + for i_kw in range(tr_kernel): + for i_c in range(C): + # input column is just the output column shifted + input_col = i_ow - (tr_kernel-1) + i_kw + if (input_col >= 0 and input_col < W): + im_matrix[index] = input_img[input_col * C + i_c] + else: + im_matrix[index] = 0 + index += 1 + im_matrix = im_matrix.reshape(out_w, -1) + return im_matrix + + + def generate_conv1d_tr_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, out_W, kernel=3, stride=1): + im2col_matrix = self._compute_conv1d_tr_im2col( + (in_W, in_C), + out_W, + kernel, + stride, + ) + + generated_code = ( + "template\n" + "class fill_buffer_{index} : public FillConv1DBuffer {{\n" + " public:\n" + " static void fill_buffer(\n" + " data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n" + " data_T buffer[CONFIG_T::n_pixels][CONFIG_T::trfilt_width * CONFIG_T::n_chan],\n" + " const unsigned partition\n" + " ) {{\n" + ).format(index=layer_idx) + indent = ' ' for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)): generated_code += indent * 2 + f'if (partition == {partition_idx:>3}) {{\n' for pixel_idx, arr in enumerate(partition): @@ -862,6 +938,91 @@ def generate_conv2d_line_buffer_fn( return generated_code + def _compute_conv2d_tr_im2col(self, input_shape, out_shape, kernel=(3, 3), stride=(1, 1)): + H, W, C = input_shape + kernel_h, kernel_w = kernel + stride_h, stride_w = stride + out_h, out_w = out_shape + + tr_kernel_h = (kernel_h+stride_h-1)//stride_h + tr_kernel_w = (kernel_w+stride_w-1)//stride_w + + input_img = np.arange(1, H * W * C + 1) + im_matrix = np.zeros((tr_kernel_h * tr_kernel_w * C * out_h * out_w, )) + + index = 0 + for i_oh in range(out_h): + for i_ow in range(out_w): + for i_kh in range(tr_kernel_h): + input_row = i_oh - (tr_kernel_h-1) + i_kh + for i_kw in range(tr_kernel_w): + for i_c in range(C): + if (input_row < 0 or input_row >= H): + im_matrix[index] = 0 + else: + input_col = i_ow - (tr_kernel_w-1) + i_kw + if (input_col >= 0 and input_col < W): + im_matrix[index] = input_img[input_row * W * C + input_col * C + i_c] + else: + im_matrix[index] = 0 + index += 1 + + im_matrix = im_matrix.reshape(out_h * out_w, -1) + return im_matrix + + + def generate_conv2d_tr_line_buffer_fn(self, layer_idx, n_partitions, in_H, in_W, in_C, out_H, out_W, kernel=(3, 3), stride=(1, 1)): + if isinstance(kernel, Iterable): + kernel_height = kernel[0] + kernel_width = kernel[1] + else: + kernel_height = kernel + kernel_width = kernel + + if isinstance(stride, Iterable): + stride_height = stride[0] + stride_width = stride[1] + else: + stride_height = stride + stride_width = stride + + im2col_matrix = self._compute_conv2d_tr_im2col( + (in_H, in_W, in_C), + (out_W, out_W), + (kernel_height, kernel_width), + (stride_height, stride_width), + ) + + generated_code = ( + "template\n" + "class fill_buffer_{index} : public FillConv2DBuffer {{\n" + " public:\n" + " static void fill_buffer(\n" + " data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan],\n" + " data_T buffer[CONFIG_T::n_pixels][CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan],\n" + " const unsigned partition\n" + " ) {{\n" + ).format(index=layer_idx) + indent = ' ' + + for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)): + generated_code += indent * 2 + 'if (partition == {:>3}) {{\n'.format(partition_idx) + for pixel_idx, arr in enumerate(partition): + buffer_stmts = [] + for j, v in enumerate(arr): + if v == 0: + val = '0' + else: + val = 'data[{}]'.format(int(v-1)) + buffer_stmts.append('buffer[{}][{}] = {:>10};'.format(pixel_idx, j, val)) + generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n' + generated_code += '\n' + indent * 2 + '}\n' + + generated_code += indent + '}\n' + generated_code += '};\n' + + return generated_code + @model_optimizer() def write_hls(self, model): self.writer.write_hls(model) diff --git a/hls4ml/backends/fpga/fpga_types.py b/hls4ml/backends/fpga/fpga_types.py index 6b1e63a469..fcd98fdf0b 100644 --- a/hls4ml/backends/fpga/fpga_types.py +++ b/hls4ml/backends/fpga/fpga_types.py @@ -326,6 +326,15 @@ def __init__(self, type_converter): class StaticWeightVariableDefinition(VariableDefinition): def definition_cpp(self, name_suffix='', as_reference=False): + if self.keep_dims > 0: + size_str = '' + for dim in range(self.keep_dims): + size_str += '[{cur_dim}]'.format(cur_dim=self.shape[dim]) + final_dim = 1 + for dim in range(self.keep_dims, len(self.shape)): + final_dim *= self.shape[dim] + size_str += '[{last_dim}]'.format(last_dim=final_dim) + return '{type} {name}{sizes}'.format(type=self.type.name, name=self.name, sizes=size_str) return '{type} {name}[{size}]'.format(type=self.type.name, name=self.name, size=self.data_length) class StaticWeightVariableConverter(object): diff --git a/hls4ml/backends/fpga/passes/codegen.py b/hls4ml/backends/fpga/passes/codegen.py index f031645091..037280a7fd 100644 --- a/hls4ml/backends/fpga/passes/codegen.py +++ b/hls4ml/backends/fpga/passes/codegen.py @@ -1,17 +1,21 @@ from hls4ml.model.optimizer import OptimizerPass -from hls4ml.model.layers import Conv1D, Conv2D +from hls4ml.model.layers import Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose from hls4ml.model.types import Source class GenerateConvIm2col(OptimizerPass): ''' Generates tcode for im2col step of 1D/2d convolution ''' def match(self, node): - return isinstance(node, (Conv1D, Conv2D)) and \ + return isinstance(node, (Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose)) and \ node.model.config.get_config_value('IOType') == 'io_parallel' def transform(self, model, node): node_class = node.__class__.__name__ - if '1D' in node_class: + if '1DTranspose' in node_class: + self._generate_im2col_1d_transpose(node) + elif '1D' in node_class: self._generate_im2col_1d(node) + elif '2DTranspose' in node_class: + self._generate_im2col_2d_transpose(node) elif '2D' in node_class: self._generate_im2col_2d(node) else: @@ -30,6 +34,19 @@ def _generate_im2col_1d(self, node): node.set_attr('line_buffer_codegen', Source(code_str)) + def _generate_im2col_1d_transpose(self, node): + code_str = node.model.config.backend.generate_conv1d_tr_line_buffer_fn( + node.get_attr('index'), + node.get_attr('n_partitions'), + node.get_input_variable().shape[0], + node.get_input_variable().shape[1], + node.get_attr('proc_width'), + kernel=node.get_attr('filt_width'), + stride=node.get_attr('stride_width'), + ) + + node.set_attr('line_buffer_codegen', Source(code_str)) + def _generate_im2col_2d(self, node): code_str = node.model.config.backend.generate_conv2d_line_buffer_fn( node.get_attr('index'), @@ -43,3 +60,18 @@ def _generate_im2col_2d(self, node): ) node.set_attr('line_buffer_codegen', Source(code_str)) + + def _generate_im2col_2d_transpose(self, node): + code_str = node.model.config.backend.generate_conv2d_tr_line_buffer_fn( + node.get_attr('index'), + node.get_attr('n_partitions'), + node.get_input_variable().shape[0], + node.get_input_variable().shape[1], + node.get_input_variable().shape[2], + node.get_attr('proc_height'), + node.get_attr('proc_width'), + kernel=(node.get_attr('filt_height'), node.get_attr('filt_width')), + stride=(node.get_attr('stride_height'), node.get_attr('stride_width')), + ) + + node.set_attr('line_buffer_codegen', Source(code_str)) diff --git a/hls4ml/backends/vivado/passes/conv_same_pad.py b/hls4ml/backends/vivado/passes/conv_same_pad.py index dc27076574..54c676386a 100644 --- a/hls4ml/backends/vivado/passes/conv_same_pad.py +++ b/hls4ml/backends/vivado/passes/conv_same_pad.py @@ -1,5 +1,5 @@ from hls4ml.model.optimizer import OptimizerPass -from hls4ml.model.layers import Conv1D, SeparableConv1D, Conv2D, SeparableConv2D +from hls4ml.model.layers import Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, Conv1DTranspose, Conv2DTranspose class InsertZeroPaddingBeforeConv1D(OptimizerPass): name = 'insert_zero_padding_before_conv1d' @@ -46,6 +46,53 @@ def transform(self, model, node): return True +class InsertZeroPaddingBeforeConv1DTranspose(OptimizerPass): + name = 'insert_zero_padding_before_conv1dtranspose' + + def match(self, node): + is_match = isinstance(node, (Conv1DTranspose)) and \ + node.get_attr('padding') == 'same' and \ + node.get_attr('filt_width') != 1 + return is_match + + def transform(self, model, node): + if model.config.get_config_value('IOType') != 'io_stream': + return False + + # Get the padding parameters from Conv1D layer + pad_left = node.get_attr('pad_left') + pad_right = node.get_attr('pad_right') + convtr_out_width = node.get_attr('out_width') + in_width = node.get_attr('in_width') + stride_width = node.get_attr('stride_width') + trfilt_width = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) \ + // node.get_attr('stride_width') + + add_right = (convtr_out_width + pad_left)//stride_width - (in_width-1) + + out_width = in_width + add_right + trfilt_width-1 + + attrs = { + 'pad_left': trfilt_width-1, + 'pad_right': add_right, + 'in_width': in_width, + 'out_width': out_width, + 'n_chan': node.get_attr('n_chan'), + 'data_format': node.get_attr('data_format', 'channels_last') + } + + # Switch Conv1DTranspose to be 'valid'. I think this is wrong + node.set_attr('padding', 'valid') + node.set_attr('in_width', out_width) + node.set_attr('pad_left', pad_left + (trfilt_width-1)*stride_width) + + # Insert new ZeroPadding1D node above Conv1DTranspose + padding_layer = model.make_node('ZeroPadding1D', 'zp1d_' + node.name, attrs, node.inputs.copy()) + padding_layer.get_output_variable().type.precision = node.get_input_variable().type.precision + model.insert_node(padding_layer) + + return True + class InsertZeroPaddingBeforeConv2D(OptimizerPass): name = 'insert_zero_padding_before_conv2d' @@ -100,3 +147,66 @@ def transform(self, model, node): model.insert_node(padding_layer, before=node) return True + +class InsertZeroPaddingBeforeConv2DTranspose(OptimizerPass): + name = 'insert_zero_padding_before_conv2dtranspose' + + def match(self, node): + is_match = isinstance(node, Conv2DTranspose) and \ + node.get_attr('padding') == 'same' and \ + node.get_attr('filt_width') != 1 + return is_match + + def transform(self, model, node): + if model.config.get_config_value('IOType') != 'io_stream': + return False + + # Get the padding parameters from Conv2DTranspose layer + pad_left = node.get_attr('pad_left') + pad_right = node.get_attr('pad_right') + pad_top = node.get_attr('pad_top') + pad_bottom = node.get_attr('pad_bottom') + convtr_out_width = node.get_attr('out_width') + convtr_out_height = node.get_attr('out_height') + in_width = node.get_attr('in_width') + in_height = node.get_attr('in_height') + stride_width = node.get_attr('stride_width') + stride_height = node.get_attr('stride_height') + trfilt_width = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) \ + // node.get_attr('stride_width') + trfilt_height = (node.get_attr('filt_height') + node.get_attr('stride_height') - 1) \ + // node.get_attr('stride_height') + + add_right = (convtr_out_width + pad_left)//stride_width-(in_width-1) + add_bottom = (convtr_out_height + pad_top)//stride_height-(in_height-1) + + out_width = in_width + add_right + trfilt_width-1 + out_height = in_height + add_bottom + trfilt_height-1 + + attrs = { + 'pad_left': trfilt_width-1, + 'pad_right': add_right, + 'pad_top': trfilt_height-1, + 'pad_bottom': add_bottom, + 'in_width': in_width, + 'in_height': in_height, + 'out_width': out_width, + 'out_height': out_height, + 'n_chan': node.get_attr('n_chan'), + 'data_format': node.get_attr('data_format', 'channels_last') + } + + # switch Conv2DTranspose to be 'valid'. This is technically not true though + node.set_attr('padding', 'valid') + node.set_attr('in_width', out_width) + node.set_attr('in_height', out_height) + node.set_attr('pad_left', pad_left + (trfilt_width-1)*stride_width) + node.set_attr('pad_top', pad_top + (trfilt_height-1)*stride_height) + + # insert new ZeroPadding2D ndoe above Conv2DTranspose + padding_layer = model.make_node('ZeroPadding2D', 'zp2d_' + node.name, attrs, node.inputs.copy()) + padding_layer.get_output_variable().type.precision = node.get_input_variable().type.precision + model.insert_node(padding_layer, before=node) + + return True + diff --git a/hls4ml/backends/vivado/passes/conv_stream.py b/hls4ml/backends/vivado/passes/conv_stream.py index e0bb853d83..fbca22d5fa 100644 --- a/hls4ml/backends/vivado/passes/conv_stream.py +++ b/hls4ml/backends/vivado/passes/conv_stream.py @@ -1,4 +1,4 @@ -from hls4ml.model.layers import Conv1D, Conv2D, SeparableConv1D, SeparableConv2D +from hls4ml.model.layers import Conv1D, Conv2D, SeparableConv1D, SeparableConv2D, Conv2DTranspose, Conv1DTranspose from hls4ml.model.optimizer import OptimizerPass @@ -6,7 +6,7 @@ class GenerateConvStreamingInstructions(OptimizerPass): '''Generates the instructions for streaming implementation of CNNs''' def match(self, node): - return isinstance(node, (Conv1D, SeparableConv1D, Conv2D, SeparableConv2D)) + return isinstance(node, (Conv1D, Conv1DTranspose, SeparableConv1D, Conv2D, SeparableConv2D, Conv2DTranspose)) def transform(self, model, node): node_class = node.__class__.__name__ @@ -18,12 +18,19 @@ def transform(self, model, node): raise Exception(f'Cannot generate instructions for node {node.name} ({node_class})') def _generate_1d_instructions(self, node): + kernel_width = node.get_attr('filt_width') + stride_width = node.get_attr('stride_width') + if isinstance(node, Conv1DTranspose): + # set kernel width to trfilt_width and set stride to 1 (effective kernel dimensions in transpose) + kernel_width = (node.get_attr('filt_width') + node.get_attr('stride_width')-1) \ + // node.get_attr('stride_width') + stride_width = 1 if node.model.config.get_config_value('IOType') == 'io_stream': min_w, instructions = node.model.config.backend.compute_conv1d_instructions( node.get_input_variable().shape[0], node.get_input_variable().shape[1], - node.get_attr('filt_width'), - node.get_attr('stride_width'), + kernel_width, + stride_width, ) instructions_str = ','.join(str(i) for i in instructions) node.set_attr('min_width', min_w) @@ -34,13 +41,20 @@ def _generate_1d_instructions(self, node): node.set_attr('instructions', '0') def _generate_2d_instructions(self, node): + kernel_height = node.get_attr('filt_height') + stride_height = node.get_attr('stride_height') + if isinstance(node, Conv2DTranspose): + # set actual kernel height to trfilt_height and set stride to 1 (effective kernel in transpose) + kernel_height = (node.get_attr('filt_height') + node.get_attr('stride_height') - 1) \ + // node.get_attr('stride_height') + stride_height = 1 if node.model.config.get_config_value('IOType') == 'io_stream': min_h, min_w, instructions = node.model.config.backend.compute_conv2d_instructions( node.get_input_variable().shape[0], node.get_input_variable().shape[1], node.get_input_variable().shape[2], - node.get_attr('filt_height'), - node.get_attr('stride_height'), + kernel_height, + stride_height, ) instructions_str = ','.join(str(i) for i in instructions) node.set_attr('min_height', min_h) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 195fc00b58..1759bedba8 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -1,6 +1,6 @@ from hls4ml.backends.backend import get_backend from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate -from hls4ml.model.layers import Conv1D, Conv2D, Conv2DBatchnorm, DepthwiseConv2D, SeparableConv1D, SeparableConv2D +from hls4ml.model.layers import Conv1D, Conv2D, Conv2DBatchnorm, DepthwiseConv2D, SeparableConv1D, SeparableConv2D, Conv1DTranspose, Conv2DTranspose # Shared multiplication template @@ -102,6 +102,82 @@ def format(self, node): return self.template.format(**params) +# Conv1DTranspose Templates + +conv1dtranspose_config_template = """struct config{index} : nnet::conv1dtranspose_config {{ + static const unsigned pad_left = {pad_left}; + static const unsigned pad_right = {pad_right}; + static const unsigned in_width = {in_width}; + static const unsigned n_chan = {n_chan}; + static const unsigned filt_width = {filt_width}; + static const unsigned kernel_size = filt_width; + static const unsigned n_filt = {n_filt}; + static const unsigned stride_width = {stride_width}; + static const unsigned dilation = {dilation}; + static const unsigned out_width = {out_width}; + static const unsigned reuse_factor = {reuse}; + static const unsigned n_zeros = {nzeros}; + static const unsigned trfilt_width = {trfilt_width}; + static const bool store_weights_in_bram = false; + static const unsigned strategy = nnet::{strategy}; + static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; + static const unsigned min_width = {min_width}; + static const ap_uint pixels[min_width]; + static const unsigned n_partitions = {n_partitions}; + static const unsigned proc_width = {proc_width}; + static const unsigned n_pixels = proc_width / n_partitions; + template + using fill_buffer = nnet::{fill_fn}; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + typedef {config_t} mult_config; +}}; +const ap_uint config{index}::pixels[] = {{{instructions}}};\n""" + +conv1dtranspose_function_template = 'nnet::conv_1d_transpose_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' + +conv1dtranspose_include_list = ['nnet_utils/nnet_conv1dtranspose.h', 'nnet_utils/nnet_conv1dtranspose_stream.h'] + +class Conv1DTransposeConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Conv1DTranspose) + self.template = conv1dtranspose_config_template + self.mult_template = conv_mult_config_template + + def format(self, node): + params = self._default_config_params(node) + params['dilation'] = node.get_attr('dilation', 1) + params['nzeros'] = node.get_weights('weight').nzeros + + params['config_t'] = 'config{}_mult'.format(node.index) + if node.model.config.get_config_value('IOType') == 'io_parallel': + params['fill_fn'] = 'fill_buffer_{}'.format(node.index) + else: + params['fill_fn'] = 'FillConv1DBuffer' + conv_config = self.template.format(**params) + + mult_params = self._default_config_params(node) + mult_params['n_in'] = node.get_attr('n_chan') * \ + (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) // node.get_attr('stride_width') + mult_params['n_out'] = node.get_attr('n_filt') + mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) + mult_config = self.mult_template.format(**mult_params) + + return mult_config + '\n' + conv_config + +class Conv1DTransposeFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Conv1DTranspose, include_header=conv1dtranspose_include_list) + self.template = conv1dtranspose_function_template + + def format(self, node): + params = self._default_function_params(node) + params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl' + params['w'] = node.get_weights('weight').name + params['b'] = node.get_weights('bias').name + + return self.template.format(**params) # Conv2D Templates @@ -212,6 +288,92 @@ def __init__(self): super(Conv2DFunctionTemplate, self).__init__(DepthwiseConv2D, include_header=sepconv2d_include_list) self.template = depthconv2d_function_template +# Conv2DTranspose Templates +conv2dtranspose_config_template = """struct config{index} : nnet::conv2dtranspose_config {{ + static const unsigned pad_top = {pad_top}; + static const unsigned pad_bottom = {pad_bottom}; + static const unsigned pad_left = {pad_left}; + static const unsigned pad_right = {pad_right}; + static const unsigned in_height = {in_height}; + static const unsigned in_width = {in_width}; + static const unsigned n_chan = {n_chan}; + static const unsigned filt_height = {filt_height}; + static const unsigned filt_width = {filt_width}; + static const unsigned kernel_size = filt_height * filt_width; + static const unsigned n_filt = {n_filt}; + static const unsigned stride_height = {stride_height}; + static const unsigned stride_width = {stride_width}; + static const unsigned out_height = {out_height}; + static const unsigned out_width = {out_width}; + static const unsigned reuse_factor = {reuse}; + static const unsigned n_zeros = {nzeros}; + static const unsigned trfilt_width = {trfilt_width}; + static const unsigned trfilt_height = {trfilt_height}; + static const bool store_weights_in_bram = false; + static const unsigned strategy = nnet::{strategy}; + static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; + static const unsigned min_height = {min_height}; + static const unsigned min_width = {min_width}; + static const ap_uint pixels[min_height * min_width]; + static const unsigned n_partitions = {n_partitions}; + static const unsigned proc_height = {proc_height}; + static const unsigned proc_width = {proc_width}; + static const unsigned n_pixels = proc_height * proc_width / n_partitions; + template + using fill_buffer = nnet::{fill_fn}; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + typedef {config_t} mult_config; +}}; +const ap_uint config{index}::pixels[] = {{{instructions}}};\n""" + +conv2dtranspose_function_template = 'nnet::conv_2d_transpose_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' + +conv2dtranspose_include_list = ['nnet_utils/nnet_conv2dtranspose.h', 'nnet_utils/nnet_conv2dtranspose_stream.h'] + +class Conv2DTransposeConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Conv2DTranspose) + self.template = conv2dtranspose_config_template + self.mult_template = conv_mult_config_template + + def format(self, node): + params = self._default_config_params(node) + params['dilation'] = node.get_attr('dilation', 1) + params['nzeros'] = node.get_weights('weight').nzeros + params['trfilt_width'] = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) \ + // node.get_attr('stride_width') + params['trfilt_height'] = (node.get_attr('filt_height') + node.get_attr('stride_height') - 1) \ + // node.get_attr('stride_height') + + params['config_t'] = 'config{}_mult'.format(node.index) + if node.model.config.get_config_value('IOType') == 'io_parallel': + params['fill_fn'] = 'fill_buffer_{}'.format(node.index) + else: + params['fill_fn'] = 'FillConv2DBuffer' + conv_config = self.template.format(**params) + + mult_params = self._default_config_params(node) + mult_params['n_in'] = node.get_attr('n_chan') * params['trfilt_width'] * params['trfilt_height'] + mult_params['n_out'] = node.get_attr('n_filt') + mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) + mult_config = self.mult_template.format(**mult_params) + + return mult_config + '\n' + conv_config + +class Conv2DTransposeFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Conv2DTranspose, include_header=conv2dtranspose_include_list) + self.template = conv2dtranspose_function_template + + def format(self, node): + params = self._default_function_params(node) + params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl' + params['w'] = node.get_weights('weight').name + params['b'] = node.get_weights('bias').name + + return self.template.format(**params) # SeparableConv1D/2D Templates diff --git a/hls4ml/backends/vivado/passes/resource_strategy.py b/hls4ml/backends/vivado/passes/resource_strategy.py index 9e41456f5c..eecde7a373 100644 --- a/hls4ml/backends/vivado/passes/resource_strategy.py +++ b/hls4ml/backends/vivado/passes/resource_strategy.py @@ -1,13 +1,14 @@ import numpy as np from hls4ml.model.optimizer import OptimizerPass -from hls4ml.model.layers import Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D, LSTM, GRU +from hls4ml.model.layers import Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose, Dense, SeparableConv1D, SeparableConv2D, LSTM, GRU class ApplyResourceStrategy(OptimizerPass): ''' Transposes the weights to use the dense_resource matrix multiply routine ''' def match(self, node): node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU)) + is_resource_strategy = node.get_attr('strategy', '').lower() == 'resource' already_transformed = node.get_attr('_weights_transposed', False) == True diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 7698bc680d..9988987075 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -11,7 +11,9 @@ GRU, LSTM, Conv1D, + Conv1DTranspose, Conv2D, + Conv2DTranspose, Dense, DepthwiseConv2D, Embedding, @@ -82,6 +84,8 @@ def _register_flows(self): 'vivado:clone_output', 'vivado:insert_zero_padding_before_conv1d', 'vivado:insert_zero_padding_before_conv2d', + 'vivado:insert_zero_padding_before_conv1dtranspose', + 'vivado:insert_zero_padding_before_conv2dtranspose', 'vivado:broadcast_stream', ] streaming_flow = register_flow('streaming', streaming_passes, requires=[init_flow], backend=self.name) @@ -265,6 +269,34 @@ def init_conv1d(self, layer): self._validate_conv_strategy(layer) + @layer_optimizer(Conv1DTranspose) + def init_conv1dtranspose(self, layer): + if layer.model.config.is_resource_strategy(layer): + layer.set_attr('strategy', 'resource') + n_in, n_out = self.get_layer_mult_size(layer) + self.set_target_reuse_factor(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + else: + layer.set_attr('strategy', 'latency') + + in_width = layer.get_input_variable().shape[0] + proc_width = (layer.get_output_variable().shape[0] + layer.get_attr('pad_left') + layer.get_attr('stride_width')-1) \ + // layer.get_attr('stride_width') + chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1) + valid_pf = self.get_valid_conv_partition_splits(1, proc_width) + if chosen_pf not in valid_pf: + closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf) + print('WARNING: Invalid ParallelizationFactor={} in layer "{}". Using ParallelizationFactor={} instead. Valid ParallelizationFactor(s): {}.' + .format(chosen_pf, layer.name, closest_pf, ','.join(map(str, valid_pf)))) + else: + closest_pf = chosen_pf + layer.set_attr('n_partitions', proc_width // closest_pf) + layer.set_attr('proc_width', proc_width) + + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + self._validate_conv_strategy(layer) + @layer_optimizer(SeparableConv1D) def init_sepconv1d(self, layer): if layer.model.config.is_resource_strategy(layer): @@ -311,6 +343,42 @@ def init_conv2d(self, layer): self._validate_conv_strategy(layer) + @layer_optimizer(Conv2DTranspose) + def init_conv2dtranspose(self, layer): + if len(layer.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D + layer.weights['weight'].data = np.expand_dims(layer.weights['weight'].data, axis=(0,1)) + + if layer.model.config.is_resource_strategy(layer): + layer.set_attr('strategy', 'resource') + self.set_target_reuse_factor(layer) + n_in, n_out = self.get_layer_mult_size(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + else: + layer.set_attr('strategy', 'latency') + + in_height = layer.get_input_variable().shape[0] + in_width = layer.get_input_variable().shape[1] + + proc_height = (layer.get_output_variable().shape[0] + layer.get_attr('pad_top') + layer.get_attr('stride_height')-1) \ + // layer.get_attr('stride_height') + proc_width = (layer.get_output_variable().shape[1] + layer.get_attr('pad_left') + layer.get_attr('stride_width')-1) \ + // layer.get_attr('stride_width') + chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1) + valid_pf = self.get_valid_conv_partition_splits(proc_height, proc_width) + if chosen_pf not in valid_pf: + closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf) + print('WARNING: Invalid ParallelizationFactor={} in layer "{}". Using ParallelizationFactor={} instead. Valid ParallelizationFactor(s): {}.' + .format(chosen_pf, layer.name, closest_pf, ','.join(map(str, valid_pf)))) + else: + closest_pf = chosen_pf + layer.set_attr('n_partitions', proc_height * proc_width // closest_pf) + layer.set_attr('proc_height', proc_height) + layer.set_attr('proc_width', proc_width) + + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + self._validate_conv_strategy(layer) + @layer_optimizer(SeparableConv2D) def init_sepconv2d(self, layer): if layer.model.config.is_resource_strategy(layer): diff --git a/hls4ml/converters/keras/convolution.py b/hls4ml/converters/keras/convolution.py index 3c402496a8..61c70f12a7 100644 --- a/hls4ml/converters/keras/convolution.py +++ b/hls4ml/converters/keras/convolution.py @@ -1,5 +1,5 @@ from hls4ml.converters.keras_to_hls import keras_handler, parse_default_keras_layer -from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d, parse_data_format +from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d, parse_data_format, compute_padding_1d_transpose, compute_padding_2d_transpose @keras_handler('Conv1D', 'SeparableConv1D') @@ -26,6 +26,39 @@ def parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader): return layer, output_shape +@keras_handler('Conv1DTranspose') +def parse_conv1dtranspose_layer(keras_layer, input_names, input_shapes, data_reader): + assert('Conv1DTranspose' in keras_layer['class_name']) + layer = parse_default_keras_layer(keras_layer, input_names) + + ( + layer['in_width'], + layer['n_chan'] + ) = parse_data_format(input_shapes[0], layer['data_format']) + + layer['n_filt'] = keras_layer['config']['filters'] + layer['filt_width'] = keras_layer['config']['kernel_size'][0] + layer['stride_width'] = keras_layer['config']['strides'][0] + layer['padding'] = keras_layer['config']['padding'] + layer['trfilt_width'] = (layer['filt_width'] + layer['stride_width'] - 1)//layer['stride_width'] + + ( + layer['out_width'], + layer['pad_left'], + layer['pad_right'], + ) = compute_padding_1d_transpose( + layer['padding'], + layer['in_width'], + layer['stride_width'], + layer['filt_width'] + ) + + if layer['data_format'] == 'channels_last': + output_shape = [input_shapes[0][0], layer['out_width'], layer['n_filt']] + elif layer['data_format'] == 'channels_first': + output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_width']] + + return layer, output_shape @keras_handler('Conv2D', 'SeparableConv2D', 'DepthwiseConv2D') def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader): @@ -68,3 +101,51 @@ def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader): output_shape = [input_shapes[0][0], layer['out_height'], layer['out_width'], layer['n_filt']] return layer, output_shape + +@keras_handler('Conv2DTranspose') +def parse_conv2dtranspose_layer(keras_layer, input_names, input_shapes, data_reader): + assert('Conv2DTranspose' in keras_layer['class_name']) + + layer = parse_default_keras_layer(keras_layer, input_names) + + ( + layer['in_height'], + layer['in_width'], + layer['n_chan'] + ) = parse_data_format(input_shapes[0], layer['data_format']) + + if 'filters' in keras_layer['config']: + layer['n_filt'] = keras_layer['config']['filters'] + else: + layer['n_filt'] = layer['n_chan'] + layer['filt_height'] = keras_layer['config']['kernel_size'][0] + layer['filt_width'] = keras_layer['config']['kernel_size'][1] + layer['stride_height'] = keras_layer['config']['strides'][0] + layer['stride_width'] = keras_layer['config']['strides'][1] + layer['padding'] = keras_layer['config']['padding'] + layer['trfilt_height'] = (layer['filt_height'] + layer['stride_height'] - 1)//layer['stride_height'] + layer['trfilt_width'] = (layer['filt_width'] + layer['stride_width'] - 1)//layer['stride_width'] + + ( + layer['out_height'], + layer['out_width'], + layer['pad_top'], + layer['pad_bottom'], + layer['pad_left'], + layer['pad_right'] + ) = compute_padding_2d_transpose( + layer['padding'], + layer['in_height'], + layer['in_width'], + layer['stride_height'], + layer['stride_width'], + layer['filt_height'], + layer['filt_width'] + ) + + if layer['data_format'] == 'channels_first': + output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_height'], layer['out_width']] + else: + output_shape = [input_shapes[0][0], layer['out_height'], layer['out_width'], layer['n_filt']] + + return layer, output_shape \ No newline at end of file diff --git a/hls4ml/converters/utils.py b/hls4ml/converters/utils.py index d3fe6edfdd..7618801f1f 100644 --- a/hls4ml/converters/utils.py +++ b/hls4ml/converters/utils.py @@ -45,6 +45,21 @@ def compute_padding_1d(pad_type, in_size, stride, filt_size): return (n_out, pad_left, pad_right) +def compute_padding_1d_transpose(pad_type, in_size, stride, filt_size): + if pad_type.lower() == 'same': + n_out = stride*in_size + pad_along_size = max(filt_size-stride, 0) + pad_left = pad_along_size//2 + pad_right = pad_along_size-pad_left + elif pad_type.lower() == 'valid': + n_out = stride*(in_size-1) + filt_size + pad_left = 0 + pad_right = 0 + else: + raise Exception('Unknown padding type: {}'.format(pad_type)) + + return (n_out, pad_left, pad_right) + def compute_padding_2d(pad_type, in_height, in_width, stride_height, stride_width, filt_height, filt_width): if pad_type.lower() == 'same': #Height @@ -74,4 +89,30 @@ def compute_padding_2d(pad_type, in_height, in_width, stride_height, stride_widt else: raise Exception('Unknown padding type: {}'.format(pad_type)) + return (out_height, out_width, pad_top, pad_bottom, pad_left, pad_right) + +def compute_padding_2d_transpose(pad_type, in_height, in_width, stride_height, stride_width, filt_height, filt_width): + if pad_type.lower() == 'same': + #Height + out_height = stride_height*in_height + pad_along_height = max(filt_height-stride_height, 0) + pad_top = pad_along_height//2 + pad_bottom = pad_along_height-pad_top + #Width + out_width = stride_width*in_width + pad_along_width = max(filt_width-stride_width, 0) + pad_left = pad_along_width//2 + pad_right = pad_along_width-pad_left + elif pad_type.lower() == 'valid': + #something + out_height = stride_height*in_height + out_width = stride_width*in_width + + pad_top = 0 + pad_bottom = 0 + pad_left = 0 + pad_right = 0 + else: + raise Exception('Unknown padding type: {}'.format(pad_type)) + return (out_height, out_width, pad_top, pad_bottom, pad_left, pad_right) \ No newline at end of file diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index b8a3a1a4d9..14b8b857f5 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -232,11 +232,11 @@ def add_output_variable( self.set_attr(out_name, out) - def add_weights(self, quantizer=None, compression=False): + def add_weights(self, quantizer=None, compression=False, keep_dims=0): data = self.model.get_weights_data(self.name, 'kernel') self.add_weights_variable( - name='weight', var_name='w{index}', data=data, quantizer=quantizer, compression=compression + name='weight', var_name='w{index}', data=data, quantizer=quantizer, compression=compression, keep_dims=keep_dims ) def add_bias(self, quantizer=None): @@ -254,7 +254,7 @@ def add_bias(self, quantizer=None): ) def add_weights_variable( - self, name, var_name=None, type_name=None, precision=None, data=None, quantizer=None, compression=False + self, name, var_name=None, type_name=None, precision=None, data=None, quantizer=None, compression=False, keep_dims=0 ): if var_name is None: var_name = name + '{index}' @@ -300,7 +300,7 @@ def add_weights_variable( ) else: var = WeightVariable( - var_name, type_name=type_name, precision=precision, quantizer=quantizer, data=data, index=self.index + var_name, type_name=type_name, precision=precision, quantizer=quantizer, data=data, index=self.index, keep_dims=keep_dims ) var.data_unquantized = data_unquantized @@ -423,6 +423,60 @@ def initialize(self): self.add_bias(quantizer=self.get_attr('bias_quantizer')) +class Conv1DTranspose(Layer): + _expected_attributes = [ + Attribute('in_width'), + Attribute('out_width'), + + Attribute('n_chan'), + Attribute('n_filt'), + + Attribute('filt_width'), + Attribute('stride_width'), + + Attribute('pad_left'), + Attribute('pad_right'), + + WeightAttribute('weight'), + WeightAttribute('bias'), + + TypeAttribute('weight'), + TypeAttribute('bias'), + ] + + def initialize(self): + if self.get_attr('data_format') == 'channels_last': + shape = [self.attributes['out_width'], self.attributes['n_filt']] + dims = ['N_OUTPUTS_{}'.format(self.index), 'N_FILT_{}'.format(self.index)] + else: + shape = [self.attributes['n_filt'], self.attributes['out_width']] + dims = ['N_FILT_{}'.format(self.index), 'N_OUTPUTS_{}'.format(self.index)] + + data = self.model.get_weights_data(self.name, 'kernel') + # now we transform the entire kernel + + #(W,F,C) => (F,W,C) + data = np.transpose(data, axes=[1, 0, 2]) + # now split the kernel into stride width kernels (F, W, C) -> (S, F, W/S, C) + n_filts, kern_width, n_chan = data.shape + new_weights = np.zeros((self.attributes['stride_width'], n_filts, self.attributes['trfilt_width'], n_chan)) + for i_sw in range(self.attributes['stride_width']): + for i_fw in range(self.attributes['trfilt_width']): + filt_ind = i_sw + (self.attributes['trfilt_width']-i_fw-1) * self.attributes['stride_width'] + for i_nf in range(n_filts): + for i_nc in range(n_chan): + if filt_ind < kern_width: + new_weights[i_sw][i_nf][i_fw][i_nc] = \ + data[i_nf][filt_ind][i_nc] + data = new_weights + + self.add_output_variable(shape, dims) + # self.add_weights(quantizer = self.get_attr('weight_quantizer'), keep_dims=1) + self.add_weights_variable(name='weight', var_name='w{index}', \ + data=data, quantizer=self.get_attr('weight_quantizer'), keep_dims=1) + self.add_bias(quantizer = self.get_attr('bias_quantizer')) + + class SeparableConv1D(Layer): _expected_attributes = [ Attribute('in_width'), @@ -500,6 +554,69 @@ def initialize(self): self.add_weights(quantizer=self.get_attr('weight_quantizer')) self.add_bias(quantizer=self.get_attr('bias_quantizer')) +class Conv2DTranspose(Layer): + _expected_attributes = [ + Attribute('in_height'), + Attribute('in_width'), + + Attribute('out_height'), + Attribute('out_width'), + + Attribute('n_chan'), + Attribute('n_filt'), + + Attribute('filt_height'), + Attribute('filt_width'), + Attribute('stride_height'), + Attribute('stride_width'), + + Attribute('pad_top'), + Attribute('pad_bottom'), + Attribute('pad_left'), + Attribute('pad_right'), + + WeightAttribute('weight'), + WeightAttribute('bias'), + + TypeAttribute('weight'), + TypeAttribute('bias'), + ] + + def initialize(self): + if self.get_attr('data_format') == 'channels_last': + shape = [self.attributes['out_height'], self.attributes['out_width'], self.attributes['n_filt']] + dims = ['OUT_HEIGHT_{}'.format(self.index), 'OUT_WIDTH_{}'.format(self.index), 'N_FILT_{}'.format(self.index)] + else: + shape = [self.attributes['n_filt'], self.attributes['out_height'], self.attributes['out_width']] + dims = ['N_FILT_{}'.format(self.index), 'OUT_HEIGHT_{}'.format(self.index), 'OUT_WIDTH_{}'.format(self.index)] + + data = self.model.get_weights_data(self.name, 'kernel') + # now we transform the entire kernel + + #(H,W,F,C) => (F,H,W,C) + data = np.transpose(data, axes=[2, 0, 1, 3]) + # now split the kernel into stride width kernels (F, W, C) -> (Sh, Sw, F, H/Sh, W/Sw, C) + n_filts, kern_height, kern_width, n_chan = data.shape + new_weights = np.zeros((self.attributes['stride_height'], self.attributes['stride_width'], \ + n_filts, self.attributes['trfilt_height'], self.attributes['trfilt_width'], n_chan)) + for i_sh in range(self.attributes['stride_height']): + for i_sw in range(self.attributes['stride_width']): + for i_fh in range(self.attributes['trfilt_height']): + for i_fw in range(self.attributes['trfilt_width']): + filt_h_ind = i_sh + (self.attributes['trfilt_height']-i_fh-1)*self.attributes['stride_height'] + filt_w_ind = i_sw + (self.attributes['trfilt_width']-i_fw-1)*self.attributes['stride_width'] + for i_nf in range(n_filts): + for i_nc in range(n_chan): + if filt_h_ind < kern_height and filt_w_ind < kern_width: + new_weights[i_sh][i_sw][i_nf][i_fh][i_fw][i_nc] = \ + data[i_nf][filt_h_ind][filt_w_ind][i_nc] + data = new_weights + + self.add_output_variable(shape, dims) + self.add_weights_variable(name='weight', var_name='w{index}', \ + data=data, quantizer=self.get_attr('weight_quantizer'), keep_dims=2) + self.add_bias(quantizer=self.get_attr('bias_quantizer')) + class Conv2DBatchnorm(Conv2D): def _get_folded_weights(self): @@ -1274,6 +1391,8 @@ def _initialize_transforms(self): 'Conv2D': Conv2D, 'BinaryConv2D': Conv2D, 'QConv2D': Conv2D, + 'Conv1DTranspose': Conv1DTranspose, + 'Conv2DTranspose': Conv2DTranspose, 'QConv2DBatchnorm': Conv2DBatchnorm, 'SeparableConv1D': SeparableConv1D, 'SeparableConv2D': SeparableConv2D, diff --git a/hls4ml/model/types.py b/hls4ml/model/types.py index 115ff5cce0..2e94135038 100644 --- a/hls4ml/model/types.py +++ b/hls4ml/model/types.py @@ -104,9 +104,12 @@ def __str__(self): args = [self.width, self.integer, self.rounding_mode, self.saturation_mode, self.saturation_bits] args = ','.join([str(arg) for arg in args if arg is not None]) typestring = '{signed}fixed<{args}>'.format(signed='u' if not self.signed else '', args=args) + typestring = 'ap_{signed}fixed<{args}>'.format(signed='u' if not self.signed else '', args=args) return typestring def __eq__(self, other): + if not isinstance(other, FixedPrecisionType): + return False eq = self.width == other.width eq = eq and self.integer == other.integer eq = eq and self.fractional == other.fractional @@ -224,7 +227,7 @@ def definition_cpp(self, name_suffix='', as_reference=False): return None class WeightVariable(Variable): - def __init__(self, var_name, type_name, precision, data, quantizer=None, **kwargs): + def __init__(self, var_name, type_name, precision, data, quantizer=None, keep_dims=0, **kwargs): super(WeightVariable, self).__init__(var_name, NamedType(type_name, precision, **kwargs), **kwargs) self.data = data self.nzeros = -1 @@ -237,6 +240,7 @@ def __init__(self, var_name, type_name, precision, data, quantizer=None, **kwarg self._iterator = None self.update_precision(precision) self.quantizer = quantizer + self.keep_dims = keep_dims def __iter__(self): self._iterator = np.nditer(self.data, order='C') diff --git a/hls4ml/report/vivado_report.py b/hls4ml/report/vivado_report.py index 1201770cd3..a3d00c5642 100644 --- a/hls4ml/report/vivado_report.py +++ b/hls4ml/report/vivado_report.py @@ -12,8 +12,8 @@ def read_vivado_report(hls_dir, full_report=False): prj_dir = None top_func_name = None - if os.path.isfile(hls_dir + '/project.tcl'): - prj_dir, top_func_name = _parse_project_script(hls_dir) + if os.path.isfile(hls_dir + '/build_prj.tcl'): + prj_dir, top_func_name = _parse_build_script(hls_dir + '/build_prj.tcl') if prj_dir is None or top_func_name is None: print('Unable to read project data. Exiting.') @@ -31,17 +31,21 @@ def read_vivado_report(hls_dir, full_report=False): print('Reports for solution "{}":\n'.format(sln)) _find_reports(sln_dir + '/' + sln, top_func_name, full_report) -def _parse_project_script(path): +def _parse_build_script(path): prj_dir = None top_func_name = None + build_path = path + '/build_prj.tcl' project_path = path + '/project.tcl' + with open(build_path, 'r') as f: + for line in f.readlines(): + if 'set_top' in line: + top_func_name = line.split()[-1] with open(project_path, 'r') as f: for line in f.readlines(): - if 'set project_name' in line: - top_func_name = line.split('"')[-2] - prj_dir = top_func_name + '_prj' + if 'set myproject' in line: + prj_dir = line.split('"')[-2] + '_prj' return prj_dir, top_func_name @@ -109,8 +113,8 @@ def parse_vivado_report(hls_dir): prj_dir = None top_func_name = None - if os.path.isfile(hls_dir + '/project.tcl'): - prj_dir, top_func_name = _parse_project_script(hls_dir) + if os.path.isfile(hls_dir + '/build_prj.tcl'): + prj_dir, top_func_name = _parse_build_script(hls_dir) if prj_dir is None or top_func_name is None: print('Unable to read project data. Exiting.') diff --git a/hls4ml/templates/vivado/build_prj.tcl b/hls4ml/templates/vivado/build_prj.tcl index 3b0f9ad53b..df01e459ac 100644 --- a/hls4ml/templates/vivado/build_prj.tcl +++ b/hls4ml/templates/vivado/build_prj.tcl @@ -2,14 +2,14 @@ # HLS4ML ################# array set opt { - reset 0 - csim 1 - synth 1 - cosim 1 - validation 1 - export 0 - vsynth 0 - fifo_opt 0 + reset 0 + csim 1 + synth 1 + cosim 1 + validation 1 + export 0 + vsynth 0 + fifo_opt 0 } set tcldir [file dirname [info script]] @@ -19,7 +19,7 @@ proc remove_recursive_log_wave {} { set tcldir [file dirname [info script]] source [file join $tcldir project.tcl] - set filename ${project_name}_prj/solution1/sim/verilog/${project_name}.tcl + set filename ${myproject}_prj/solution1/sim/verilog/${myproject}.tcl set timestamp [clock format [clock seconds] -format {%Y%m%d%H%M%S}] set temp $filename.new.$timestamp # set backup $filename.bak.$timestamp @@ -35,19 +35,19 @@ proc remove_recursive_log_wave {} { puts $out $line } - close $in - close $out + close $in + close $out - # move the new data to the proper filename - file delete -force $filename - file rename -force $temp $filename + # move the new data to the proper filename + file delete -force $filename + file rename -force $temp $filename } proc add_vcd_instructions_tcl {} { set tcldir [file dirname [info script]] source [file join $tcldir project.tcl] - set filename ${project_name}_prj/solution1/sim/verilog/${project_name}.tcl + set filename ${myproject}_prj/solution1/sim/verilog/${myproject}.tcl set timestamp [clock format [clock seconds] -format {%Y%m%d%H%M%S}] set temp $filename.new.$timestamp # set backup $filename.bak.$timestamp @@ -58,43 +58,45 @@ proc add_vcd_instructions_tcl {} { # line-by-line, read the original file while {[gets $in line] != -1} { if {[string equal "$line" "log_wave -r /"]} { - set line {source "../../../../project.tcl" - if {[string equal "$backend" "vivadoaccelerator"]} { - current_scope [get_scopes -regex "/apatb_${project_name}_axi_top/AESL_inst_${project_name}_axi/${project_name}_U0.*"] - set scopes [get_scopes -regexp {layer(\d*)_.*data_0_V_U.*}] - append scopes { } - current_scope "/apatb_${project_name}_axi_top/AESL_inst_${project_name}_axi" - append scopes [get_scopes -regexp {(in_local_V_data.*_0_.*)}] - append scopes { } - append scopes [get_scopes -regexp {(out_local_V_data.*_0_.*)}] - } else { - current_scope [get_scopes -regex "/apatb_${project_name}_top/AESL_inst_${project_name}"] - set scopes [get_scopes -regexp {layer(\d*)_.*data_0_V_U.*}] - } - open_vcd fifo_opt.vcd - foreach scope $scopes { - current_scope $scope - if {[catch [get_objects usedw]] == 0} { - puts "$scope skipped" - continue - } - set usedw [get_objects usedw] - set depth [get_objects DEPTH] - add_wave $usedw - log_vcd $usedw - log_wave $usedw - add_wave $depth - log_vcd $depth - log_wave $depth - } - } + set line {source "../../../../project.tcl" +if {[string equal "$backend" "vivadoaccelerator"]} { + current_scope [get_scopes -regex /apatb_${myproject}_axi_top/AESL_inst_${myproject}_axi/${myproject}_U0.*] + set scopes [get_scopes -regexp {layer(\d*)_.*data_0_V_U.*}] + append scopes { } + current_scope /apatb_${myproject}_axi_top/AESL_inst_${myproject}_axi + append scopes [get_scopes -regexp {(in_local_V_data.*_0_.*)}] + append scopes { } + append scopes [get_scopes -regexp {(out_local_V_data.*_0_.*)}] +} else { + current_scope [get_scopes -regex /apatb_${myproject}_top/AESL_inst_${myproject}] + set scopes [get_scopes -regexp {layer(\d*)_.*data_0_V_U.*}] +} +open_vcd fifo_opt.vcd +foreach scope $scopes { + current_scope $scope + if {[catch [get_objects usedw]] == 0} { + puts "$scope skipped" + continue + } + set usedw [get_objects usedw] + set depth [get_objects DEPTH] + add_wave $usedw + log_vcd $usedw + log_wave $usedw + add_wave $depth + log_vcd $depth + log_wave $depth + } + } + + set line [string map [list "myproject" $myproject] $line] } if {[string equal "$line" "quit"]} { set line {flush_vcd - close_vcd - quit - } +close_vcd +quit +} } # then write the transformed line puts $out $line @@ -109,17 +111,17 @@ proc add_vcd_instructions_tcl {} { } foreach arg $::argv { - foreach o [lsort [array names opt]] { - regexp "$o=+(\\w+)" $arg unused opt($o) - } + foreach o [lsort [array names opt]] { + regexp "$o=+(\\w+)" $arg unused opt($o) + } } proc report_time { op_name time_start time_end } { - set time_taken [expr $time_end - $time_start] - set time_s [expr ($time_taken / 1000) % 60] - set time_m [expr ($time_taken / (1000*60)) % 60] - set time_h [expr ($time_taken / (1000*60*60)) % 24] - puts "***** ${op_name} COMPLETED IN ${time_h}h${time_m}m${time_s}s *****" + set time_taken [expr $time_end - $time_start] + set time_s [expr ($time_taken / 1000) % 60] + set time_m [expr ($time_taken / (1000*60)) % 60] + set time_h [expr ($time_taken / (1000*60*60)) % 24] + puts "***** ${op_name} COMPLETED IN ${time_h}h${time_m}m${time_s}s *****" } # Compare file content: 1 = same, 0 = different @@ -147,102 +149,102 @@ set CSIM_RESULTS "./tb_data/csim_results.log" set RTL_COSIM_RESULTS "./tb_data/rtl_cosim_results.log" if {$opt(reset)} { - open_project -reset ${project_name}_prj + open_project -reset ${myproject}_prj } else { - open_project ${project_name}_prj + open_project ${myproject}_prj } -set_top ${project_name} -add_files firmware/${project_name}.cpp -cflags "-std=c++0x" -add_files -tb ${project_name}_test.cpp -cflags "-std=c++0x" +set_top myproject +add_files firmware/myproject.cpp -cflags "-std=c++0x" +add_files -tb myproject_test.cpp -cflags "-std=c++0x" add_files -tb firmware/weights add_files -tb tb_data if {$opt(reset)} { - open_solution -reset "solution1" + open_solution -reset "solution1" } else { - open_solution "solution1" + open_solution "solution1" } catch {config_array_partition -maximum_size 4096} config_compile -name_max_length 60 -set_part $part -create_clock -period $clock_period -name default +set_part {xcku115-flvb2104-2-i} +create_clock -period 5 -name default if {$opt(csim)} { - puts "***** C SIMULATION *****" - set time_start [clock clicks -milliseconds] - csim_design - set time_end [clock clicks -milliseconds] - report_time "C SIMULATION" $time_start $time_end + puts "***** C SIMULATION *****" + set time_start [clock clicks -milliseconds] + csim_design + set time_end [clock clicks -milliseconds] + report_time "C SIMULATION" $time_start $time_end } if {$opt(synth)} { - puts "***** C/RTL SYNTHESIS *****" - set time_start [clock clicks -milliseconds] - csynth_design - set time_end [clock clicks -milliseconds] - report_time "C/RTL SYNTHESIS" $time_start $time_end + puts "***** C/RTL SYNTHESIS *****" + set time_start [clock clicks -milliseconds] + csynth_design + set time_end [clock clicks -milliseconds] + report_time "C/RTL SYNTHESIS" $time_start $time_end } if {$opt(cosim)} { - puts "***** C/RTL SIMULATION *****" - # TODO: This is a workaround (Xilinx defines __RTL_SIMULATION__ only for SystemC testbenches). - add_files -tb ${project_name}_test.cpp -cflags "-std=c++0x -DRTL_SIM" - set time_start [clock clicks -milliseconds] - - cosim_design -trace_level all -setup - - if {$opt(fifo_opt)} { - puts "\[hls4ml\] - FIFO optimization started" - add_vcd_instructions_tcl - } - - remove_recursive_log_wave - set old_pwd [pwd] - cd ${project_name}_prj/solution1/sim/verilog/ - source run_sim.tcl - cd $old_pwd - - set time_end [clock clicks -milliseconds] - puts "INFO:" - if {[string equal "$backend" "vivadoaccelerator"]} { - puts [read [open ${project_name}_prj/solution1/sim/report/${project_name}_axi_cosim.rpt r]] - } else { - puts [read [open ${project_name}_prj/solution1/sim/report/${project_name}_cosim.rpt r]] - } - report_time "C/RTL SIMULATION" $time_start $time_end + puts "***** C/RTL SIMULATION *****" + # TODO: This is a workaround (Xilinx defines __RTL_SIMULATION__ only for SystemC testbenches). + add_files -tb myproject_test.cpp -cflags "-std=c++0x -DRTL_SIM" + set time_start [clock clicks -milliseconds] + + cosim_design -trace_level all -setup + + if {$opt(fifo_opt)} { + puts "\[hls4ml\] - FIFO optimization started" + add_vcd_instructions_tcl + } + + remove_recursive_log_wave + set old_pwd [pwd] + cd ${myproject}_prj/solution1/sim/verilog/ + source run_sim.tcl + cd $old_pwd + + set time_end [clock clicks -milliseconds] + puts "INFO:" + if {[string equal "$backend" "vivadoaccelerator"]} { + puts [read [open ${myproject}_prj/solution1/sim/report/${myproject}_axi_cosim.rpt r]] + } else { + puts [read [open ${myproject}_prj/solution1/sim/report/${myproject}_cosim.rpt r]] + } + report_time "C/RTL SIMULATION" $time_start $time_end } if {$opt(validation)} { - puts "***** C/RTL VALIDATION *****" - if {[compare_files $CSIM_RESULTS $RTL_COSIM_RESULTS]} { - puts "INFO: Test PASSED" - } else { - puts "ERROR: Test failed" - puts "ERROR: - csim log: $CSIM_RESULTS" - puts "ERROR: - RTL-cosim log: $RTL_COSIM_RESULTS" - exit 1 - } + puts "***** C/RTL VALIDATION *****" + if {[compare_files $CSIM_RESULTS $RTL_COSIM_RESULTS]} { + puts "INFO: Test PASSED" + } else { + puts "ERROR: Test failed" + puts "ERROR: - csim log: $CSIM_RESULTS" + puts "ERROR: - RTL-cosim log: $RTL_COSIM_RESULTS" + exit 1 + } } if {$opt(export)} { - puts "***** EXPORT IP *****" - set time_start [clock clicks -milliseconds] - export_design -format ip_catalog - set time_end [clock clicks -milliseconds] - report_time "EXPORT IP" $time_start $time_end + puts "***** EXPORT IP *****" + set time_start [clock clicks -milliseconds] + export_design -format ip_catalog + set time_end [clock clicks -milliseconds] + report_time "EXPORT IP" $time_start $time_end } if {$opt(vsynth)} { - puts "***** VIVADO SYNTHESIS *****" - if {[file exist ${project_name}_prj/solution1/syn/vhdl]} { - set time_start [clock clicks -milliseconds] - exec vivado -mode batch -source vivado_synth.tcl >@ stdout - set time_end [clock clicks -milliseconds] - report_time "VIVADO SYNTHESIS" $time_start $time_end - } else { - puts "ERROR: Cannot find generated VHDL files. Did you run C synthesis?" - exit 1 - } + puts "***** VIVADO SYNTHESIS *****" + if {[file exist ${myproject}_prj/solution1/syn/vhdl]} { + set time_start [clock clicks -milliseconds] + exec vivado -mode batch -source vivado_synth.tcl >@ stdout + set time_end [clock clicks -milliseconds] + report_time "VIVADO SYNTHESIS" $time_start $time_end + } else { + puts "ERROR: Cannot find generated VHDL files. Did you run C synthesis?" + exit 1 + } } exit diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h new file mode 100644 index 0000000000..9794645721 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h @@ -0,0 +1,50 @@ +#ifndef NNET_CONV1DTRANSPOSE_H_ +#define NNET_CONV1DTRANSPOSE_H_ + +#include "nnet_common.h" +#include "nnet_conv1dtranspose_resource.h" +#include + +namespace nnet{ + +struct conv1dtranspose_config +{ + //Internal data type definitions + typedef float bias_t; + typedef float weight_t; + typedef float accum_t; + + //Convolutional parameters + static const unsigned pad_left = 0; + static const unsigned pad_right = 0; + static const unsigned in_width = 10; + static const unsigned n_chan = 0; + static const unsigned filt_width = 1; + static const unsigned kernel_size = filt_width; + static const unsigned stride_width = 1; + static const unsigned dilation = 1; + static const unsigned out_width = 10; + + static const unsigned reuse_factor = 1; + static const bool store_weights_in_bram = false; + static const unsigned n_zeros = 0; +}; + +template +void conv_1d_transpose_cl( + data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][ + CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan + ], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] +) +{ + #pragma HLS INLINE region + //for now, we are only adding resource strategy + conv_1d_transpose_resource_cl(data, res, weights, biases); +} + +} + +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h new file mode 100644 index 0000000000..033dbf676e --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h @@ -0,0 +1,132 @@ +#ifndef NNET_CONV1DTRANSPOSE_RESOURCE_H_ +#define NNET_CONV1DTRANSPOSE_RESOURCE_H_ + +#include "nnet_common.h" +#include "nnet_dense.h" + +namespace nnet{ + +template +void conv_1d_transpose_resource_cl( + data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][ + CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan + ], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] +) +{ + constexpr unsigned mult_n_in = CONFIG_T::trfilt_width * CONFIG_T::n_chan; + constexpr unsigned mult_n_out = CONFIG_T::n_filt; + constexpr unsigned block_factor = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor); + constexpr unsigned multscale = block_factor / mult_n_out; + + assert((block_factor % mult_n_out == 0 || CONFIG_T::reuse_factor >= mult_n_in) && "The current Reuse Factor is not allowed"); + assert((CONFIG_T::reuse_factor <= CONFIG_T::trfilt_width * CONFIG_T::n_chan) && "This function is correct only for RF <= TRFILT_WIDTH * N_CHAN"); + + data_T data_buf[CONFIG_T::n_pixels][mult_n_in]; + #pragma HLS ARRAY_PARTITION variable=data_buf complete dim=0 + #pragma HLS ARRAY_PARTITION variable=biases complete + + typename CONFIG_T::accum_t acc[CONFIG_T::n_pixels][mult_n_out][CONFIG_T::stride_width]; + #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor dim=2 + + PartitionLoop: + for (unsigned i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) { + CONFIG_T::template fill_buffer::fill_buffer(data, data_buf, i_part); + + PixelInitAccumLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + InitAccumLoop: + for (unsigned i_acc = 0; i_acc < mult_n_out; i_acc++) { + #pragma HLS UNROLL + + InitStrideLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + acc[i_pxl][i_acc][i_sw] = (typename CONFIG_T::accum_t) biases[i_acc]; + } + } + } + + ReuseLoop: + for (unsigned i_rf = 0; i_rf < CONFIG_T::reuse_factor; i_rf++) { + #pragma HLS PIPELINE II=1 rewind + + unsigned i_w = i_rf; + unsigned i_in = i_rf; + unsigned i_out = 0; + unsigned i_acc = 0; + + MultLoop: + for (unsigned i_blk = 0; i_blk < block_factor; i_blk++) { + #pragma HLS UNROLL + + PixelMultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + StrideMultLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + + acc[i_pxl][i_out][i_sw] += static_cast( + CONFIG_T::mult_config::template product::product( + data_buf[i_pxl][i_in], weights[i_sw][i_w] + ) + ); + } + } + + // Increment i_w + i_w += CONFIG_T::reuse_factor; + // Increment i_in + i_in += CONFIG_T::reuse_factor; + if (i_in >= mult_n_in) { + i_in = i_rf; + } + // Increment i_out + if (i_acc + 1 >= multscale) { + i_acc = 0; + i_out++; + } else { + i_acc++; + } + } + } + + + PixelResultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + StrideResultLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + + unsigned output_index = i_pxl * CONFIG_T::n_partitions * CONFIG_T::stride_width + + i_part * CONFIG_T::stride_width + i_sw; + + if (output_index >= CONFIG_T::pad_left && + output_index < CONFIG_T::out_width + CONFIG_T::pad_left) { + ResultLoop: + for (unsigned i_res = 0; i_res < mult_n_out; i_res++) { + #pragma HLS UNROLL + + res[(output_index-CONFIG_T::pad_left)*mult_n_out + i_res] = + cast(acc[i_pxl][i_res][i_sw]); + } + } + } + } + } + +} + + +} +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h new file mode 100644 index 0000000000..d9dca82007 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h @@ -0,0 +1,141 @@ +#ifndef NNET_CONV1DTRANSPOSE_STREAM_H +#define NNET_CONV1DTRANSPOSE_STREAM_H + +#include "nnet_common.h" +#include "nnet_conv_stream.h" +#include "hls_stream.h" + +namespace nnet { + +template +void kernel_shift_tr_1d( + const data_T& in_elem, + typename data_T::value_type kernel_window[CONFIG_T::trfilt_width * CONFIG_T::n_chan] +) { + #pragma HLS INLINE + + // Shift kernel_window by one step to the left (manual shift operation) + static const int filt_width = CONFIG_T::trfilt_width - 1; + KernelShiftWidth: for (int i_iw = 0; i_iw < filt_width; i_iw++) { + #pragma HLS PIPELINE II = 1 + KernelShiftChannel: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + #pragma HLS UNROLL + // Shift every element in kernel_window to the left + kernel_window[i_iw * CONFIG_T::n_chan + i_ic] = kernel_window[(i_iw + 1) * CONFIG_T::n_chan + i_ic]; + } + } + + // Insert shift_buffer column into right-most column of kernel + static const int lastheight = (CONFIG_T::trfilt_width - 1) * CONFIG_T::n_chan; + KernelPushChannel: for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + #pragma HLS UNROLL + kernel_window[lastheight + i_ic] = in_elem[i_ic]; + } +} + +// Conv 1D transpose compute output +template +void compute_output_buffer_tr_1d( + const data_T& in_elem, + hls::stream &res_stream, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][ + CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan + ], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] +) +{ + #pragma HLS INLINE + + // Thresholds + const static int lShiftX = CONFIG_T::trfilt_width - 1; + + // Counters + static int pX = 0; // pixel counter + static int oX = 0; // output counter (deals with 'padding') + + static typename data_T::value_type kernel_data[CONFIG_T::trfilt_width * CONFIG_T::n_chan]; + #pragma HLS ARRAY_PARTITION variable=kernel_data complete + + typename res_T::value_type res_out[CONFIG_T::n_filt]; + #pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0 + + res_T res_pack; + #pragma HLS DATA_PACK variable=res_pack + + // Add pixel to buffer + nnet::kernel_shift_tr_1d(in_elem, kernel_data); + + //always do stride number of multiplications + StrideLoop: for (int idx = 0; idx < CONFIG_T::stride_width; idx++) { + #pragma HLS UNROLL + #pragma HLS INLINE region + // Dense multiply + if (CONFIG_T::strategy == nnet::latency) { + dense_latency( + kernel_data, res_out, weights[idx], biases); + } else { + dense_resource( + kernel_data, res_out, weights[idx], biases); + } + + // Pack output + if (oX >= CONFIG_T::pad_left && oX < CONFIG_T::pad_left + CONFIG_T::out_width) { + CastLoop: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + #pragma HLS UNROLL + res_pack[i_ic] = res_out[i_ic]; + } + res_stream.write(res_pack); + } + // Write output to stream when output ready + oX++; + } + + // static var housekeeping + if (pX + 1 == CONFIG_T::in_width) // done with all of the inputs + { + pX = 0; + oX = 0; + } else { + pX = pX + 1; + } +} + +template +void conv_1d_transpose_buffer_cl( + hls::stream &data, + hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][ + CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan + ], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) +{ + ReadInputWidth: for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) { + #pragma HLS LOOP_FLATTEN + // if (CONFIG_T::strategy == nnet::latency) { + // #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + // } + compute_output_buffer_tr_1d(data.read(), res, weights, biases); + } +} + +template +void conv_1d_transpose_cl( + hls::stream &data, + hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][ + CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan + ], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] +) +{ + switch(CONFIG_T::implementation) { + #pragma HLS inline region + case conv_implementation::linebuffer: + conv_1d_transpose_buffer_cl(data, res, weights, biases); + break; + } +} + +} +#endif +//NEED TO PAD INPUT OR CLEAR KERNEL diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h new file mode 100644 index 0000000000..aec7eee7f0 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h @@ -0,0 +1,60 @@ +#ifndef NNET_CONV2DTRANSPOSE_H +#define NNET_CONV2DTRANSPOSE_H + +#include "nnet_common.h" +#include "nnet_conv2dtranspose_resource.h" +#include + +namespace nnet{ + +struct conv2dtranspose_config +{ + // Internal data type definitions + typedef float bias_t; + typedef float weight_t; + typedef float accum_t; + + // Convolutional parameters + static const unsigned pad_top = 0; + static const unsigned pad_bottom = 0; + static const unsigned pad_left = 0; + static const unsigned pad_right = 0; + static const unsigned in_height = 10; + static const unsigned in_width = 10; + static const unsigned n_chan = 1; + static const unsigned filt_height = 1; + static const unsigned filt_width = 1; + static const unsigned kernel_size = filt_height * filt_width; + static const unsigned n_filt = 1; + static const unsigned stride_height = 1; + static const unsigned stride_width = 1; + static const unsigned out_height = 10; + static const unsigned out_width = 10; + static const unsigned dilation_height = 1; + static const unsigned dilation_width = 1; + static const unsigned trfilt_height = 1; + static const unsigned trfilt_width = 1; + + static const unsigned reuse_factor = 1; + static const bool store_weights_in_bram = false; + static const unsigned n_zeros = 0; // not used yet +}; + +template +void conv_2d_transpose_cl( + data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width][ + CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan + ], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] +) +{ + #pragma HLS INLINE region + //only have resource strategy as of now + conv_2d_transpose_resource_cl(data, res, weights, biases); +} + +} + +#endif \ No newline at end of file diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_resource.h new file mode 100644 index 0000000000..8f0ed748ed --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_resource.h @@ -0,0 +1,148 @@ +#ifndef NNET_CONV2DTRANSPOSE_RESOURCE_H +#define NNET_CONV2DTRANSPOSE_RESOURCE_H + +#include "nnet_common.h" +#include "nnet_dense.h" + +namespace nnet{ + +template +void conv_2d_transpose_resource_cl( + data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width][ + CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan + ], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] +) +{ + constexpr unsigned mult_n_in = CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan; + constexpr unsigned mult_n_out = CONFIG_T::n_filt; + constexpr unsigned block_factor = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor); + + constexpr unsigned multiplier_limit = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor); + constexpr unsigned multscale = multiplier_limit / mult_n_out; + + assert((multiplier_limit % mult_n_out == 0 || CONFIG_T::reuse_factor >= mult_n_in) && "The current Reuse Factor is not allowed"); + assert((multiplier_limit == block_factor) && "This function is correct only for RF <= TRFILT_HEIGHT * TRFILT_WIDTH * N_CHAN"); + + data_T data_buf[CONFIG_T::n_pixels][mult_n_in]; + #pragma HLS ARRAY_PARTITION variable=data_buf complete dim=0 + + #pragma HLS ARRAY_PARTITION variable=biases complete + + typename CONFIG_T::accum_t acc[CONFIG_T::n_pixels][mult_n_out][CONFIG_T::stride_height][CONFIG_T::stride_width]; + #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor dim=3 + + PartitionLoop: + for (unsigned i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) { + CONFIG_T::template fill_buffer::fill_buffer(data, data_buf, i_part); + + PixelInitAccumLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + InitAccumLoop: + for (unsigned i_acc = 0; i_acc < mult_n_out; i_acc++) { + #pragma HLS UNROLL + + InitStrideHeightLoop: + for (unsigned i_sh = 0; i_sh < CONFIG_T::stride_height; i_sh++) { + #pragma HLS UNROLL + + InitStrideWidthLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + acc[i_pxl][i_acc][i_sh][i_sw] = (typename CONFIG_T::accum_t) biases[i_acc]; + } + } + } + } + + ReuseLoop: + for (unsigned i_rf = 0; i_rf < CONFIG_T::reuse_factor; i_rf++) { + #pragma HLS PIPELINE II=1 rewind + + unsigned i_w = i_rf; + unsigned i_in = i_rf; + unsigned i_out = 0; + unsigned i_acc = 0; + + MultLoop: + for (unsigned i_blk = 0; i_blk < block_factor; i_blk++) { + #pragma HLS UNROLL + PixelMultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + StrideHeightMultLoop: + for (unsigned i_sh = 0; i_sh < CONFIG_T::stride_height; i_sh++) { + #pragma HLS UNROLL + StrideWidthMultLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + + acc[i_pxl][i_out][i_sh][i_sw] += static_cast( + CONFIG_T::mult_config::template product::product( + data_buf[i_pxl][i_in], weights[i_sh][i_sw][i_w] + ) + ); + } + } + } + + // Increment i_w + i_w += CONFIG_T::reuse_factor; + // Increment i_in + i_in += CONFIG_T::reuse_factor; + if (i_in >= mult_n_in) { + i_in = i_rf; + } + // Increment i_out + if (i_acc + 1 >= multscale) { + i_acc = 0; + i_out++; + } else { + i_acc++; + } + } + } + + PixelResultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + StrideHeightResultLoop: + for (unsigned i_sh = 0; i_sh < CONFIG_T::stride_height; i_sh++) { + #pragma HLS UNROLL + StrideWidthResultLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + + unsigned px_ind = i_pxl * CONFIG_T::n_partitions + i_part; + unsigned height_ind = (px_ind / CONFIG_T::proc_width) * CONFIG_T::stride_height + i_sh; + unsigned width_ind = (px_ind % CONFIG_T::proc_width) * CONFIG_T::stride_width + i_sw; + + if (height_ind >= CONFIG_T::pad_top && height_ind < CONFIG_T::out_height + CONFIG_T::pad_top && + width_ind >= CONFIG_T::pad_left && width_ind < CONFIG_T::out_width + CONFIG_T::pad_left) { + ResultLoop: for (unsigned i_res = 0; i_res < mult_n_out; i_res++) { + #pragma HLS UNROLL + + res[((height_ind-CONFIG_T::pad_top)*CONFIG_T::out_width + width_ind-CONFIG_T::pad_left)* + CONFIG_T::n_filt + i_res] = + cast( + acc[i_pxl][i_res][i_sh][i_sw] + ); + } + } + } + } + + } + } +} + +} + +#endif \ No newline at end of file diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h new file mode 100644 index 0000000000..eff09cac91 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h @@ -0,0 +1,209 @@ +#ifndef NNET_CONV2DTRANSPOSE_STREAM_H +#define NNET_CONV2DTRANSPOSE_STREAM_H + +#include "ap_shift_reg.h" +#include "nnet_conv_stream.h" +#include "nnet_common.h" +#include "hls_stream.h" + +namespace nnet { + +template +void kernel_shift_tr_2d( + typename data_T::value_type shift_buffer[CONFIG_T::trfilt_height][CONFIG_T::n_chan], + typename data_T::value_type kernel_window[CONFIG_T::trfilt_width * CONFIG_T::trfilt_height * CONFIG_T::n_chan] +) { + #pragma HLS inline + + // Shift kernel_window by one step to the left (manual shift operation) + static const int filt_width = CONFIG_T::trfilt_width - 1; + KernelShiftWidth: for (int i_iw = 0; i_iw < filt_width; i_iw++) { + #pragma HLS PIPELINE II = 1 + KernelShiftHeight: for (unsigned i_ih = 0; i_ih < CONFIG_T::trfilt_height; i_ih++) { + KernelShiftChannel: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + // Shift every element in kernel_window to the left + kernel_window[i_ih * CONFIG_T::trfilt_width * CONFIG_T::n_chan + i_iw * CONFIG_T::n_chan + i_ic] = kernel_window[i_ih * CONFIG_T::trfilt_width * CONFIG_T::n_chan + (i_iw + 1) * CONFIG_T::n_chan + i_ic]; + } + } + } + + // Insert shift_buffer column into right-most column of kernel + static const int lastheight = (CONFIG_T::trfilt_width - 1) * CONFIG_T::n_chan; + KernelPushHeight: for (int i_ih = 0; i_ih < CONFIG_T::trfilt_height; i_ih++) { + #pragma HLS UNROLL + KernelPushChannel: for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + kernel_window[lastheight + i_ih * CONFIG_T::trfilt_width * CONFIG_T::n_chan + i_ic] = shift_buffer[i_ih][i_ic]; + } + } +} + +template +void shift_line_buffer_tr(const data_T& in_elem, + ap_shift_reg line_buffer[MAX(CONFIG_T::trfilt_height - 1,1)][CONFIG_T::n_chan], + typename data_T::value_type kernel_window[CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan] +) { + + #pragma HLS PIPELINE + + // Temporary buffer for popped (shifted) elements + typename data_T::value_type shift_buffer[CONFIG_T::trfilt_height][CONFIG_T::n_chan]; + #pragma HLS ARRAY_PARTITION variable = shift_buffer complete dim = 0 + + UpdateBuffer: for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + #pragma HLS UNROLL + + // Insert pixel(s) at end of shift buffer + shift_buffer[CONFIG_T::trfilt_height - 1][i_ic] = in_elem[i_ic]; + } + + LineBufferDataIn: for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + // Shift the shift buffer into the line buffer + LineBufferShift: for (unsigned i_ih = 1; i_ih < CONFIG_T::trfilt_height; i_ih++) { + #pragma HLS UNROLL + typename data_T::value_type pop_elem = + line_buffer[i_ih - 1][i_ic].shift(shift_buffer[CONFIG_T::trfilt_height - i_ih][i_ic]); // Shift the line buffer, return the popped pixel + shift_buffer[CONFIG_T::trfilt_height - i_ih - 1][i_ic] = pop_elem; // Popped element placed back into shift_buffer, one row up. + } + } + kernel_shift_tr_2d(shift_buffer, kernel_window); +} + +template +void compute_output_buffer_tr_2d( + const data_T& in_elem, + ap_shift_reg line_buffer[MAX(CONFIG_T::trfilt_height-1, 1)][CONFIG_T::n_chan], + hls::stream &res_stream, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width][ + CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan + ], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] +) +{ + #pragma HLS INLINE + + //Counters + static int pX = 0; //pixel counters + static int pY = 0; + + static typename data_T::value_type kernel_data[CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan]; + #pragma HLS ARRAY_PARTITION variable=kernel_data complete + + typename res_T::value_type res_out[CONFIG_T::n_filt]; + #pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0 + + static typename res_T::value_type output_buffer[ + CONFIG_T::in_width*CONFIG_T::stride_width*CONFIG_T::stride_height*CONFIG_T::n_filt + ]; + #pragma HLS ARRAY_PARTITION variable=output_buffer complete dim = 0 + + res_T res_pack; + #pragma HLS DATA_PACK variable = res_pack + + //Add pixel to the buffer + nnet::shift_line_buffer_tr(in_elem, line_buffer, kernel_data); + + HeightStrideLoop: for (int w_idx = 0; w_idx < CONFIG_T::stride_width; w_idx++) { + // #pragma HLS PIPELINE + #pragma HLS UNROLL + WidthStrideLoop: for (int h_idx = 0; h_idx < CONFIG_T::stride_height; h_idx++) { + #pragma HLS UNROLL + + #pragma HLS INLINE region + + if (CONFIG_T::strategy == nnet::latency) { + dense_latency( + kernel_data, res_out, weights[h_idx][w_idx], biases + ); + } else { + dense_resource( + kernel_data, res_out, weights[h_idx][w_idx], biases + ); + } + + BufferOutputLoop: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + #pragma HLS UNROLL + output_buffer[ + (pX*CONFIG_T::stride_width+w_idx)*CONFIG_T::stride_height*CONFIG_T::n_filt + + h_idx*CONFIG_T::n_filt + i_ic + ] = res_out[i_ic]; + } + } + } + + //Counter Housekeeping and printing buffered output + if (pX + 1 == CONFIG_T::in_width) { + pX = 0; + //write all of the buffered output for outputs we want + HeightOutputLoop: for (unsigned h_idx = 0; h_idx < CONFIG_T::stride_height; h_idx++) { + // #pragma HLS PIPELINE + if (pY*CONFIG_T::stride_height + h_idx >= CONFIG_T::pad_top && + pY*CONFIG_T::stride_height + h_idx < CONFIG_T::pad_top + CONFIG_T::out_height) { + WidthOutputLoop: for (unsigned oX = CONFIG_T::pad_left; oX < CONFIG_T::pad_left + CONFIG_T::out_width; oX++) { + #pragma HLS PIPELINE + CastLoop: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + #pragma HLS UNROLL + res_pack[i_ic] = output_buffer[ + oX*CONFIG_T::stride_height*CONFIG_T::n_filt + + h_idx*CONFIG_T::n_filt + i_ic + ]; + } + res_stream.write(res_pack); + } + } + } + + if (pY + 1 == CONFIG_T::in_height) { + pY = 0; + } else { + pY = pY + 1; + } + } else { + pX = pX + 1; + } + +} + +template +void conv_2d_transpose_buffer_cl( + hls::stream &data, + hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width][ + CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan + ], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] +) +{ + static ap_shift_reg line_buffer[MAX(CONFIG_T::trfilt_height-1, 1)][CONFIG_T::n_chan]; + #pragma HLS ARRAY_PARTITION variable = line_buffer complete dim = 2 + + ReadInputHeight: for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) { + ReadInputWidth: for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) { + #pragma HLS LOOP_FLATTEN + if (CONFIG_T::strategy == nnet::latency) { + #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + } + compute_output_buffer_tr_2d(data.read(), line_buffer, res, weights, biases); + } + } +} + +template +void conv_2d_transpose_cl( + hls::stream &data, + hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width][ + CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan + ], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] +) +{ + #pragma HLS INLINE region + switch(CONFIG_T::implementation) { + case conv_implementation::linebuffer: + conv_2d_transpose_buffer_cl(data, res, weights, biases); + break; + } +} + +} +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h b/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h index eed64fc701..9d030772fd 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h @@ -67,6 +67,83 @@ void load_weights_from_txt(T *w, const char* fname) { } } +template +void load_weights_from_txt(T w[DIM_1][DIM_2], const char* fname) { + + std::string full_path = std::string(WEIGHTS_DIR) + "/" + std::string(fname); + std::ifstream infile(full_path.c_str(), std::ios::binary); + + if (infile.fail()) { + std::cerr << "ERROR: file " << std::string(fname) << " does not exist" << std::endl; + exit(1); + } + + std::string line; + if (std::getline(infile, line)) { + std::istringstream iss(line); + std::string token; + + size_t i = 0; + size_t j = 0; + size_t tot = 0; + while(std::getline(iss, token, ',')) { + std::istringstream(token) >> w[i][j]; + j++; + if (j == DIM_2) { + j = 0; + i++; + } + tot++; + } + + if (DIM_1*DIM_2 != tot) { + std::cerr << "ERROR: Expected " << DIM_1*DIM_2 << " values"; + std::cerr << " but read only " << tot << " values" << std::endl; + } + } +} + +template +void load_weights_from_txt(T w[DIM_1][DIM_2][DIM_3], const char* fname) { + + std::string full_path = std::string(WEIGHTS_DIR) + "/" + std::string(fname); + std::ifstream infile(full_path.c_str(), std::ios::binary); + + if (infile.fail()) { + std::cerr << "ERROR: file " << std::string(fname) << " does not exist" << std::endl; + exit(1); + } + + std::string line; + if (std::getline(infile, line)) { + std::istringstream iss(line); + std::string token; + + size_t i = 0; + size_t j = 0; + size_t k = 0; + size_t tot = 0; + while(std::getline(iss, token, ',')) { + std::istringstream(token) >> w[i][j][k]; + k++; + if (k == DIM_3) { + k = 0; + j++; + if (j == DIM_2) { + j = 0; + i++; + } + } + tot++; + } + + if (DIM_1*DIM_2*DIM_3 != tot) { + std::cerr << "ERROR: Expected " << DIM_1*DIM_2*DIM_3 << " values"; + std::cerr << " but read only " << tot << " values" << std::endl; + } + } +} + template void load_compressed_weights_from_txt(T *w, const char* fname) { diff --git a/hls4ml/templates/vivado/vivado_synth.tcl b/hls4ml/templates/vivado/vivado_synth.tcl index a4e57a8edb..3cef6b2549 100644 --- a/hls4ml/templates/vivado/vivado_synth.tcl +++ b/hls4ml/templates/vivado/vivado_synth.tcl @@ -1,6 +1,3 @@ -set tcldir [file dirname [info script]] -source [file join $tcldir project.tcl] - -add_files ${project_name}_prj/solution1/syn/vhdl -synth_design -top ${project_name} -part $part +add_files myproject_prj/solution1/syn/vhdl +synth_design -top myproject -part xcku115-flvb2104-2-i report_utilization -file vivado_synth.rpt \ No newline at end of file diff --git a/hls4ml/templates/vivado_accelerator/alveo/tcl_scripts/axi_stream_design.tcl b/hls4ml/templates/vivado_accelerator/alveo/tcl_scripts/axi_stream_design.tcl index 97da885770..f02c8d8449 100644 --- a/hls4ml/templates/vivado_accelerator/alveo/tcl_scripts/axi_stream_design.tcl +++ b/hls4ml/templates/vivado_accelerator/alveo/tcl_scripts/axi_stream_design.tcl @@ -1,9 +1,9 @@ set tcldir [file dirname [info script]] source [file join $tcldir project.tcl] -create_project project_1 ${project_name}_vivado_accelerator -part ${part} -force +create_project project_1 ${myproject}_vivado_accelerator -part ${part} -force -set_property ip_repo_paths ${project_name}_prj [current_project] +set_property ip_repo_paths ${myproject}_prj [current_project] update_ip_catalog @@ -12,7 +12,7 @@ import_files [list src/krnl_rtl_int.sv src/krnl_rtl_axi_read_master.sv src/krnl_ -create_ip -vlnv xilinx.com:hls:${project_name}_axi:1.0 -module_name ${project_name}_axi_0 +create_ip -vlnv xilinx.com:hls:${myproject}_axi:1.0 -module_name myproject_axi_0 ipx::package_project -root_dir hls4ml_IP -vendor fastmachinelearning.org -library hls4ml -taxonomy /UserIP -import_files -set_current false @@ -106,4 +106,4 @@ ipx::archive_core hls4ml_IP/fastmachinelearning.org_hls4ml_krnl_rtl_1.0.zip [ipx current_project project_1 -package_xo -force -xo_path xo_files/${project_name}_kernel.xo -kernel_name krnl_rtl -ip_directory hls4ml_IP +package_xo -force -xo_path xo_files/${myproject}_kernel.xo -kernel_name krnl_rtl -ip_directory hls4ml_IP diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index bcf752b835..3e79dc92da 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -43,17 +43,27 @@ def print_array_to_cpp(self, var, odir, write_txt_file=True): h_file.write(var.definition_cpp() + ";\n") h_file.write("#else\n") - h_file.write(var.definition_cpp() + " = {") - - # fill c++ array. - # not including internal brackets for multidimensional case - sep = '' - for x in var: - h_file.write(sep + x) + h_file.write(var.definition_cpp() + " = ") + + factors = np.ones(len(var.shape)+1) + for idx in range(len(var.shape)-1, -1, -1): + factors[idx] = var.shape[idx] * factors[idx+1] + #fill c++ array, keeping the first keep_dims dimensions in-tact. + for idx, x in enumerate(var): + for dim in range(var.keep_dims+1): + if idx % factors[dim] == 0: + h_file.write("{") + h_file.write(x) if write_txt_file: - txt_file.write(sep + x) - sep = ", " - h_file.write("};\n") + txt_file.write(x) + for dim in range(var.keep_dims+1): + if idx % factors[dim] == factors[dim]-1: + h_file.write("}") + if idx < factors[0]-1: # only don't put comma at the end + h_file.write(", ") + if write_txt_file: + txt_file.write(", ") + h_file.write(";\n") if write_txt_file: h_file.write("#endif\n") txt_file.close() @@ -150,9 +160,12 @@ def write_project_cpp(self, model): w.type.name, w.data_length, w.name, w.name ) else: - newline += indent + ' nnet::load_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format( - w.type.name, w.data_length, w.name, w.name - ) + dim_info = w.data_length + if w.keep_dims == 1: + dim_info = '{}, {}'.format(w.shape[0], w.data_length//w.shape[0]) + if w.keep_dims == 2: + dim_info = '{}, {}, {}'.format(w.shape[0], w.shape[1], w.data_length//(w.shape[0]*w.shape[1])) + newline += indent + ' nnet::load_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format(w.type.name, dim_info, w.name, w.name) # Add input/output type elif '//hls-fpga-machine-learning insert IO' in line: @@ -578,10 +591,6 @@ def write_build_script(self, model): f.write(f'set project_name "{model.config.get_project_name()}"\n') f.write('variable backend\n') f.write('set backend "vivado"\n') - f.write('variable part\n') - f.write('set part "{}"\n'.format(model.config.get_config_value('Part'))) - f.write('variable clock_period\n') - f.write('set clock_period {}\n'.format(model.config.get_config_value('ClockPeriod'))) f.close() # build_prj.tcl From 7933d4272dc3a60da1c3c96cfddc0b3a6d1d7060 Mon Sep 17 00:00:00 2001 From: Javier Duarte Date: Sat, 18 Mar 2023 08:09:30 -0700 Subject: [PATCH 2/5] restore build_prj.tcl --- hls4ml/templates/vivado/build_prj.tcl | 252 +++++++++++++------------- 1 file changed, 125 insertions(+), 127 deletions(-) diff --git a/hls4ml/templates/vivado/build_prj.tcl b/hls4ml/templates/vivado/build_prj.tcl index df01e459ac..3b0f9ad53b 100644 --- a/hls4ml/templates/vivado/build_prj.tcl +++ b/hls4ml/templates/vivado/build_prj.tcl @@ -2,14 +2,14 @@ # HLS4ML ################# array set opt { - reset 0 - csim 1 - synth 1 - cosim 1 - validation 1 - export 0 - vsynth 0 - fifo_opt 0 + reset 0 + csim 1 + synth 1 + cosim 1 + validation 1 + export 0 + vsynth 0 + fifo_opt 0 } set tcldir [file dirname [info script]] @@ -19,7 +19,7 @@ proc remove_recursive_log_wave {} { set tcldir [file dirname [info script]] source [file join $tcldir project.tcl] - set filename ${myproject}_prj/solution1/sim/verilog/${myproject}.tcl + set filename ${project_name}_prj/solution1/sim/verilog/${project_name}.tcl set timestamp [clock format [clock seconds] -format {%Y%m%d%H%M%S}] set temp $filename.new.$timestamp # set backup $filename.bak.$timestamp @@ -35,19 +35,19 @@ proc remove_recursive_log_wave {} { puts $out $line } - close $in - close $out + close $in + close $out - # move the new data to the proper filename - file delete -force $filename - file rename -force $temp $filename + # move the new data to the proper filename + file delete -force $filename + file rename -force $temp $filename } proc add_vcd_instructions_tcl {} { set tcldir [file dirname [info script]] source [file join $tcldir project.tcl] - set filename ${myproject}_prj/solution1/sim/verilog/${myproject}.tcl + set filename ${project_name}_prj/solution1/sim/verilog/${project_name}.tcl set timestamp [clock format [clock seconds] -format {%Y%m%d%H%M%S}] set temp $filename.new.$timestamp # set backup $filename.bak.$timestamp @@ -58,45 +58,43 @@ proc add_vcd_instructions_tcl {} { # line-by-line, read the original file while {[gets $in line] != -1} { if {[string equal "$line" "log_wave -r /"]} { - set line {source "../../../../project.tcl" -if {[string equal "$backend" "vivadoaccelerator"]} { - current_scope [get_scopes -regex /apatb_${myproject}_axi_top/AESL_inst_${myproject}_axi/${myproject}_U0.*] - set scopes [get_scopes -regexp {layer(\d*)_.*data_0_V_U.*}] - append scopes { } - current_scope /apatb_${myproject}_axi_top/AESL_inst_${myproject}_axi - append scopes [get_scopes -regexp {(in_local_V_data.*_0_.*)}] - append scopes { } - append scopes [get_scopes -regexp {(out_local_V_data.*_0_.*)}] -} else { - current_scope [get_scopes -regex /apatb_${myproject}_top/AESL_inst_${myproject}] - set scopes [get_scopes -regexp {layer(\d*)_.*data_0_V_U.*}] -} -open_vcd fifo_opt.vcd -foreach scope $scopes { - current_scope $scope - if {[catch [get_objects usedw]] == 0} { - puts "$scope skipped" - continue - } - set usedw [get_objects usedw] - set depth [get_objects DEPTH] - add_wave $usedw - log_vcd $usedw - log_wave $usedw - add_wave $depth - log_vcd $depth - log_wave $depth - } - } - - set line [string map [list "myproject" $myproject] $line] + set line {source "../../../../project.tcl" + if {[string equal "$backend" "vivadoaccelerator"]} { + current_scope [get_scopes -regex "/apatb_${project_name}_axi_top/AESL_inst_${project_name}_axi/${project_name}_U0.*"] + set scopes [get_scopes -regexp {layer(\d*)_.*data_0_V_U.*}] + append scopes { } + current_scope "/apatb_${project_name}_axi_top/AESL_inst_${project_name}_axi" + append scopes [get_scopes -regexp {(in_local_V_data.*_0_.*)}] + append scopes { } + append scopes [get_scopes -regexp {(out_local_V_data.*_0_.*)}] + } else { + current_scope [get_scopes -regex "/apatb_${project_name}_top/AESL_inst_${project_name}"] + set scopes [get_scopes -regexp {layer(\d*)_.*data_0_V_U.*}] + } + open_vcd fifo_opt.vcd + foreach scope $scopes { + current_scope $scope + if {[catch [get_objects usedw]] == 0} { + puts "$scope skipped" + continue + } + set usedw [get_objects usedw] + set depth [get_objects DEPTH] + add_wave $usedw + log_vcd $usedw + log_wave $usedw + add_wave $depth + log_vcd $depth + log_wave $depth + } + } } if {[string equal "$line" "quit"]} { set line {flush_vcd -close_vcd -quit -} + close_vcd + quit + } } # then write the transformed line puts $out $line @@ -111,17 +109,17 @@ quit } foreach arg $::argv { - foreach o [lsort [array names opt]] { - regexp "$o=+(\\w+)" $arg unused opt($o) - } + foreach o [lsort [array names opt]] { + regexp "$o=+(\\w+)" $arg unused opt($o) + } } proc report_time { op_name time_start time_end } { - set time_taken [expr $time_end - $time_start] - set time_s [expr ($time_taken / 1000) % 60] - set time_m [expr ($time_taken / (1000*60)) % 60] - set time_h [expr ($time_taken / (1000*60*60)) % 24] - puts "***** ${op_name} COMPLETED IN ${time_h}h${time_m}m${time_s}s *****" + set time_taken [expr $time_end - $time_start] + set time_s [expr ($time_taken / 1000) % 60] + set time_m [expr ($time_taken / (1000*60)) % 60] + set time_h [expr ($time_taken / (1000*60*60)) % 24] + puts "***** ${op_name} COMPLETED IN ${time_h}h${time_m}m${time_s}s *****" } # Compare file content: 1 = same, 0 = different @@ -149,102 +147,102 @@ set CSIM_RESULTS "./tb_data/csim_results.log" set RTL_COSIM_RESULTS "./tb_data/rtl_cosim_results.log" if {$opt(reset)} { - open_project -reset ${myproject}_prj + open_project -reset ${project_name}_prj } else { - open_project ${myproject}_prj + open_project ${project_name}_prj } -set_top myproject -add_files firmware/myproject.cpp -cflags "-std=c++0x" -add_files -tb myproject_test.cpp -cflags "-std=c++0x" +set_top ${project_name} +add_files firmware/${project_name}.cpp -cflags "-std=c++0x" +add_files -tb ${project_name}_test.cpp -cflags "-std=c++0x" add_files -tb firmware/weights add_files -tb tb_data if {$opt(reset)} { - open_solution -reset "solution1" + open_solution -reset "solution1" } else { - open_solution "solution1" + open_solution "solution1" } catch {config_array_partition -maximum_size 4096} config_compile -name_max_length 60 -set_part {xcku115-flvb2104-2-i} -create_clock -period 5 -name default +set_part $part +create_clock -period $clock_period -name default if {$opt(csim)} { - puts "***** C SIMULATION *****" - set time_start [clock clicks -milliseconds] - csim_design - set time_end [clock clicks -milliseconds] - report_time "C SIMULATION" $time_start $time_end + puts "***** C SIMULATION *****" + set time_start [clock clicks -milliseconds] + csim_design + set time_end [clock clicks -milliseconds] + report_time "C SIMULATION" $time_start $time_end } if {$opt(synth)} { - puts "***** C/RTL SYNTHESIS *****" - set time_start [clock clicks -milliseconds] - csynth_design - set time_end [clock clicks -milliseconds] - report_time "C/RTL SYNTHESIS" $time_start $time_end + puts "***** C/RTL SYNTHESIS *****" + set time_start [clock clicks -milliseconds] + csynth_design + set time_end [clock clicks -milliseconds] + report_time "C/RTL SYNTHESIS" $time_start $time_end } if {$opt(cosim)} { - puts "***** C/RTL SIMULATION *****" - # TODO: This is a workaround (Xilinx defines __RTL_SIMULATION__ only for SystemC testbenches). - add_files -tb myproject_test.cpp -cflags "-std=c++0x -DRTL_SIM" - set time_start [clock clicks -milliseconds] - - cosim_design -trace_level all -setup - - if {$opt(fifo_opt)} { - puts "\[hls4ml\] - FIFO optimization started" - add_vcd_instructions_tcl - } - - remove_recursive_log_wave - set old_pwd [pwd] - cd ${myproject}_prj/solution1/sim/verilog/ - source run_sim.tcl - cd $old_pwd - - set time_end [clock clicks -milliseconds] - puts "INFO:" - if {[string equal "$backend" "vivadoaccelerator"]} { - puts [read [open ${myproject}_prj/solution1/sim/report/${myproject}_axi_cosim.rpt r]] - } else { - puts [read [open ${myproject}_prj/solution1/sim/report/${myproject}_cosim.rpt r]] - } - report_time "C/RTL SIMULATION" $time_start $time_end + puts "***** C/RTL SIMULATION *****" + # TODO: This is a workaround (Xilinx defines __RTL_SIMULATION__ only for SystemC testbenches). + add_files -tb ${project_name}_test.cpp -cflags "-std=c++0x -DRTL_SIM" + set time_start [clock clicks -milliseconds] + + cosim_design -trace_level all -setup + + if {$opt(fifo_opt)} { + puts "\[hls4ml\] - FIFO optimization started" + add_vcd_instructions_tcl + } + + remove_recursive_log_wave + set old_pwd [pwd] + cd ${project_name}_prj/solution1/sim/verilog/ + source run_sim.tcl + cd $old_pwd + + set time_end [clock clicks -milliseconds] + puts "INFO:" + if {[string equal "$backend" "vivadoaccelerator"]} { + puts [read [open ${project_name}_prj/solution1/sim/report/${project_name}_axi_cosim.rpt r]] + } else { + puts [read [open ${project_name}_prj/solution1/sim/report/${project_name}_cosim.rpt r]] + } + report_time "C/RTL SIMULATION" $time_start $time_end } if {$opt(validation)} { - puts "***** C/RTL VALIDATION *****" - if {[compare_files $CSIM_RESULTS $RTL_COSIM_RESULTS]} { - puts "INFO: Test PASSED" - } else { - puts "ERROR: Test failed" - puts "ERROR: - csim log: $CSIM_RESULTS" - puts "ERROR: - RTL-cosim log: $RTL_COSIM_RESULTS" - exit 1 - } + puts "***** C/RTL VALIDATION *****" + if {[compare_files $CSIM_RESULTS $RTL_COSIM_RESULTS]} { + puts "INFO: Test PASSED" + } else { + puts "ERROR: Test failed" + puts "ERROR: - csim log: $CSIM_RESULTS" + puts "ERROR: - RTL-cosim log: $RTL_COSIM_RESULTS" + exit 1 + } } if {$opt(export)} { - puts "***** EXPORT IP *****" - set time_start [clock clicks -milliseconds] - export_design -format ip_catalog - set time_end [clock clicks -milliseconds] - report_time "EXPORT IP" $time_start $time_end + puts "***** EXPORT IP *****" + set time_start [clock clicks -milliseconds] + export_design -format ip_catalog + set time_end [clock clicks -milliseconds] + report_time "EXPORT IP" $time_start $time_end } if {$opt(vsynth)} { - puts "***** VIVADO SYNTHESIS *****" - if {[file exist ${myproject}_prj/solution1/syn/vhdl]} { - set time_start [clock clicks -milliseconds] - exec vivado -mode batch -source vivado_synth.tcl >@ stdout - set time_end [clock clicks -milliseconds] - report_time "VIVADO SYNTHESIS" $time_start $time_end - } else { - puts "ERROR: Cannot find generated VHDL files. Did you run C synthesis?" - exit 1 - } + puts "***** VIVADO SYNTHESIS *****" + if {[file exist ${project_name}_prj/solution1/syn/vhdl]} { + set time_start [clock clicks -milliseconds] + exec vivado -mode batch -source vivado_synth.tcl >@ stdout + set time_end [clock clicks -milliseconds] + report_time "VIVADO SYNTHESIS" $time_start $time_end + } else { + puts "ERROR: Cannot find generated VHDL files. Did you run C synthesis?" + exit 1 + } } exit From 45c6171c50805e3220f444031d004f7eb8a9e1fb Mon Sep 17 00:00:00 2001 From: Javier Duarte Date: Sat, 18 Mar 2023 08:23:14 -0700 Subject: [PATCH 3/5] resotre more to main --- hls4ml/report/vivado_report.py | 20 ++++++++----------- hls4ml/templates/vivado/vivado_synth.tcl | 7 +++++-- .../alveo/tcl_scripts/axi_stream_design.tcl | 8 ++++---- hls4ml/writer/vivado_writer.py | 4 ++++ 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/hls4ml/report/vivado_report.py b/hls4ml/report/vivado_report.py index a3d00c5642..1201770cd3 100644 --- a/hls4ml/report/vivado_report.py +++ b/hls4ml/report/vivado_report.py @@ -12,8 +12,8 @@ def read_vivado_report(hls_dir, full_report=False): prj_dir = None top_func_name = None - if os.path.isfile(hls_dir + '/build_prj.tcl'): - prj_dir, top_func_name = _parse_build_script(hls_dir + '/build_prj.tcl') + if os.path.isfile(hls_dir + '/project.tcl'): + prj_dir, top_func_name = _parse_project_script(hls_dir) if prj_dir is None or top_func_name is None: print('Unable to read project data. Exiting.') @@ -31,21 +31,17 @@ def read_vivado_report(hls_dir, full_report=False): print('Reports for solution "{}":\n'.format(sln)) _find_reports(sln_dir + '/' + sln, top_func_name, full_report) -def _parse_build_script(path): +def _parse_project_script(path): prj_dir = None top_func_name = None - build_path = path + '/build_prj.tcl' project_path = path + '/project.tcl' - with open(build_path, 'r') as f: - for line in f.readlines(): - if 'set_top' in line: - top_func_name = line.split()[-1] with open(project_path, 'r') as f: for line in f.readlines(): - if 'set myproject' in line: - prj_dir = line.split('"')[-2] + '_prj' + if 'set project_name' in line: + top_func_name = line.split('"')[-2] + prj_dir = top_func_name + '_prj' return prj_dir, top_func_name @@ -113,8 +109,8 @@ def parse_vivado_report(hls_dir): prj_dir = None top_func_name = None - if os.path.isfile(hls_dir + '/build_prj.tcl'): - prj_dir, top_func_name = _parse_build_script(hls_dir) + if os.path.isfile(hls_dir + '/project.tcl'): + prj_dir, top_func_name = _parse_project_script(hls_dir) if prj_dir is None or top_func_name is None: print('Unable to read project data. Exiting.') diff --git a/hls4ml/templates/vivado/vivado_synth.tcl b/hls4ml/templates/vivado/vivado_synth.tcl index 3cef6b2549..a4e57a8edb 100644 --- a/hls4ml/templates/vivado/vivado_synth.tcl +++ b/hls4ml/templates/vivado/vivado_synth.tcl @@ -1,3 +1,6 @@ -add_files myproject_prj/solution1/syn/vhdl -synth_design -top myproject -part xcku115-flvb2104-2-i +set tcldir [file dirname [info script]] +source [file join $tcldir project.tcl] + +add_files ${project_name}_prj/solution1/syn/vhdl +synth_design -top ${project_name} -part $part report_utilization -file vivado_synth.rpt \ No newline at end of file diff --git a/hls4ml/templates/vivado_accelerator/alveo/tcl_scripts/axi_stream_design.tcl b/hls4ml/templates/vivado_accelerator/alveo/tcl_scripts/axi_stream_design.tcl index f02c8d8449..97da885770 100644 --- a/hls4ml/templates/vivado_accelerator/alveo/tcl_scripts/axi_stream_design.tcl +++ b/hls4ml/templates/vivado_accelerator/alveo/tcl_scripts/axi_stream_design.tcl @@ -1,9 +1,9 @@ set tcldir [file dirname [info script]] source [file join $tcldir project.tcl] -create_project project_1 ${myproject}_vivado_accelerator -part ${part} -force +create_project project_1 ${project_name}_vivado_accelerator -part ${part} -force -set_property ip_repo_paths ${myproject}_prj [current_project] +set_property ip_repo_paths ${project_name}_prj [current_project] update_ip_catalog @@ -12,7 +12,7 @@ import_files [list src/krnl_rtl_int.sv src/krnl_rtl_axi_read_master.sv src/krnl_ -create_ip -vlnv xilinx.com:hls:${myproject}_axi:1.0 -module_name myproject_axi_0 +create_ip -vlnv xilinx.com:hls:${project_name}_axi:1.0 -module_name ${project_name}_axi_0 ipx::package_project -root_dir hls4ml_IP -vendor fastmachinelearning.org -library hls4ml -taxonomy /UserIP -import_files -set_current false @@ -106,4 +106,4 @@ ipx::archive_core hls4ml_IP/fastmachinelearning.org_hls4ml_krnl_rtl_1.0.zip [ipx current_project project_1 -package_xo -force -xo_path xo_files/${myproject}_kernel.xo -kernel_name krnl_rtl -ip_directory hls4ml_IP +package_xo -force -xo_path xo_files/${project_name}_kernel.xo -kernel_name krnl_rtl -ip_directory hls4ml_IP diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 3e79dc92da..152069436a 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -591,6 +591,10 @@ def write_build_script(self, model): f.write(f'set project_name "{model.config.get_project_name()}"\n') f.write('variable backend\n') f.write('set backend "vivado"\n') + f.write('variable part\n') + f.write('set part "{}"\n'.format(model.config.get_config_value('Part'))) + f.write('variable clock_period\n') + f.write('set clock_period {}\n'.format(model.config.get_config_value('ClockPeriod'))) f.close() # build_prj.tcl From ba89ef8a62bcd2de1e4d811c4552925f6ab0a233 Mon Sep 17 00:00:00 2001 From: Javier Duarte Date: Sat, 18 Mar 2023 08:24:21 -0700 Subject: [PATCH 4/5] apply resource --- hls4ml/backends/vivado/passes/resource_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hls4ml/backends/vivado/passes/resource_strategy.py b/hls4ml/backends/vivado/passes/resource_strategy.py index eecde7a373..ee0ee7bd03 100644 --- a/hls4ml/backends/vivado/passes/resource_strategy.py +++ b/hls4ml/backends/vivado/passes/resource_strategy.py @@ -7,7 +7,7 @@ class ApplyResourceStrategy(OptimizerPass): ''' Transposes the weights to use the dense_resource matrix multiply routine ''' def match(self, node): - node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU)) + node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU, Conv1DTranspose, Conv2DTranspose)) is_resource_strategy = node.get_attr('strategy', '').lower() == 'resource' already_transformed = node.get_attr('_weights_transposed', False) == True From 970ee1c19788f670aff78bfe8cfe52f3f9254f66 Mon Sep 17 00:00:00 2001 From: Javier Duarte Date: Sat, 18 Mar 2023 10:00:36 -0700 Subject: [PATCH 5/5] fix accum_t; no transpose for resource? --- hls4ml/backends/fpga/fpga_backend.py | 4 ++++ hls4ml/backends/vivado/passes/resource_strategy.py | 5 ++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index 205ff292c5..e09c761d17 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -14,7 +14,9 @@ Activation, BatchNormalization, Conv1D, + Conv1DTranspose, Conv2D, + Conv2DTranspose, Dense, Dot, Embedding, @@ -52,7 +54,9 @@ def __init__(self, name): accum_layers = [ Dense, Conv1D, + Conv1DTranspose, Conv2D, + Conv2DTranspose, SeparableConv1D, SeparableConv2D, Pooling1D, diff --git a/hls4ml/backends/vivado/passes/resource_strategy.py b/hls4ml/backends/vivado/passes/resource_strategy.py index ee0ee7bd03..9e41456f5c 100644 --- a/hls4ml/backends/vivado/passes/resource_strategy.py +++ b/hls4ml/backends/vivado/passes/resource_strategy.py @@ -1,14 +1,13 @@ import numpy as np from hls4ml.model.optimizer import OptimizerPass -from hls4ml.model.layers import Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose, Dense, SeparableConv1D, SeparableConv2D, LSTM, GRU +from hls4ml.model.layers import Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D, LSTM, GRU class ApplyResourceStrategy(OptimizerPass): ''' Transposes the weights to use the dense_resource matrix multiply routine ''' def match(self, node): - node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU, Conv1DTranspose, Conv2DTranspose)) - + node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU)) is_resource_strategy = node.get_attr('strategy', '').lower() == 'resource' already_transformed = node.get_attr('_weights_transposed', False) == True