Skip to content

Commit 1c345f7

Browse files
change 1d transpose weight input to be 2-dimensional (passed from python code)
1 parent c6719fc commit 1c345f7

File tree

11 files changed

+133
-114
lines changed

11 files changed

+133
-114
lines changed

hls4ml/backends/fpga/fpga_types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,15 @@ def __init__(self, type_converter):
326326

327327
class StaticWeightVariableDefinition(VariableDefinition):
328328
def definition_cpp(self, name_suffix='', as_reference=False):
    """Return the C++ declaration string for this static weight array.

    With ``keep_dims`` equal to 0 (or absent) the weights are declared as a
    flat array of ``data_length`` elements. With a positive ``keep_dims`` the
    first ``keep_dims`` axes of ``shape`` are emitted as separate array
    dimensions and all remaining axes are multiplied together into a single
    trailing dimension, e.g. shape ``[2, 3, 4, 5]`` with ``keep_dims=1``
    yields ``type name[2][60]``.

    Args:
        name_suffix: Accepted for signature compatibility; not used here.
        as_reference: Accepted for signature compatibility; not used here.

    Returns:
        The declaration as ``'<type> <name>[d0]...[dn]'``.
    """
    # Tolerate instances on which keep_dims was never set; treat them as flat.
    keep_dims = getattr(self, 'keep_dims', 0)
    if keep_dims <= 0:
        return '{type} {name}[{size}]'.format(type=self.type.name, name=self.name, size=self.data_length)

    # Slicing clamps keep_dims to the rank, so an over-large value cannot
    # index past the end of the shape.
    kept_axes = list(self.shape[:keep_dims])
    collapsed_axes = list(self.shape[keep_dims:])

    size_str = ''.join('[{cur_dim}]'.format(cur_dim=dim) for dim in kept_axes)

    # Only append the collapsed dimension when there is something to collapse;
    # otherwise a spurious trailing "[1]" would change the array's rank.
    if collapsed_axes:
        final_dim = 1
        for dim in collapsed_axes:
            final_dim *= dim
        size_str += '[{last_dim}]'.format(last_dim=final_dim)

    return '{type} {name}{sizes}'.format(type=self.type.name, name=self.name, sizes=size_str)
330339

331340
class StaticWeightVariableConverter(object):

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,6 @@ def format(self, node):
138138
params = self._default_config_params(node)
139139
params['dilation'] = node.get_attr('dilation', 1)
140140
params['nzeros'] = node.get_weights('weight').nzeros
141-
params['trfilt_width'] = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) \
142-
// node.get_attr('stride_width')
143141

144142
params['config_t'] = 'config{}_mult'.format(node.index)
145143
if node.model.config.get_config_value('IOType') == 'io_parallel':

hls4ml/backends/vivado/passes/resource_strategy.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,22 @@ def transform(self, model, node):
2020
elif isinstance(node, Conv1D):
2121
node.weights['weight'].data = np.transpose(node.weights['weight'].data, axes=[2, 0, 1]) #(W,C,F) => (F,W,C)
2222
elif isinstance(node, Conv1DTranspose):
23-
node.weights['weight'].data = np.transpose(node.weights['weight'].data, axes=[1, 0, 2]) #(W,F,C) => (F,W,C)
23+
pass
24+
# #(W,F,C) => (F,W,C)
25+
# node.weights['weight'].data = np.transpose(node.weights['weight'].data, axes=[1, 0, 2])
26+
# # now split the kernel into stride width kernels (F, W, C) -> (S, F, W/S, C)
27+
# n_filts, kern_width, n_chan = node.weights['weight'].data.shape
28+
# new_weights = np.zeros((node.get_attr('stride_width'), n_filts, node.get_attr('trfilt_width'), n_chan))
29+
# for i_sw in range(node.get_attr('stride_width')):
30+
# for i_fw in range(node.get_attr('trfilt_width')):
31+
# filt_ind = i_sw + (node.get_attr('trfilt_width')-i_fw-1) * node.get_attr('stride_width')
32+
# for i_nf in range(n_filts):
33+
# for i_nc in range(n_chan):
34+
# if filt_ind < kern_width:
35+
# new_weights[i_sw][i_nf][i_fw][i_nc] = \
36+
# node.weights['weight'].data[i_nf][filt_ind][i_nc]
37+
# node.weights['weight'].data = new_weights
38+
# print("Updated shape:", node.weights['weight'].data.shape)
2439
elif isinstance(node, SeparableConv1D):
2540
node.weights['depthwise'].data = np.transpose(node.weights['depthwise'].data, axes=[2, 0, 1]) #(W,C,F) => (F,W,C)
2641
node.weights['pointwise'].data = np.transpose(node.weights['pointwise'].data, axes=[2, 0, 1]) #(W,C,F) => (F,W,C)

hls4ml/converters/keras/convolution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def parse_conv1dtranspose_layer(keras_layer, input_names, input_shapes, data_rea
5151
layer['filt_width'] = keras_layer['config']['kernel_size'][0]
5252
layer['stride_width'] = keras_layer['config']['strides'][0]
5353
layer['padding'] = keras_layer['config']['padding']
54+
layer['trfilt_width'] = (layer['filt_width'] + layer['stride_width'] - 1)//layer['stride_width']
5455

5556
(
5657
layer['out_width'],

hls4ml/model/layers.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,10 @@ def add_output_variable(self, shape, dim_names, out_name=None, var_name='layer{i
162162

163163
self.set_attr(out_name, out)
164164

165-
def add_weights(self, quantizer=None, compression=False):
165+
def add_weights(self, quantizer=None, compression=False, keep_dims=0):
166166
data = self.model.get_weights_data(self.name, 'kernel')
167167

168-
self.add_weights_variable(name='weight', var_name='w{index}', data=data, quantizer=quantizer, compression=compression)
168+
self.add_weights_variable(name='weight', var_name='w{index}', data=data, quantizer=quantizer, compression=compression, keep_dims=keep_dims)
169169

170170
def add_bias(self, quantizer=None):
171171
data = self.model.get_weights_data(self.name, 'bias')
@@ -179,7 +179,7 @@ def add_bias(self, quantizer=None):
179179

180180
self.add_weights_variable(name='bias', var_name='b{index}', type_name=type_name, precision=precision, data=data, quantizer=quantizer)
181181

182-
def add_weights_variable(self, name, var_name=None, type_name=None, precision=None, data=None, quantizer=None, compression=False):
182+
def add_weights_variable(self, name, var_name=None, type_name=None, precision=None, data=None, quantizer=None, compression=False, keep_dims=0):
183183
if var_name is None:
184184
var_name = name + '{index}'
185185

@@ -213,7 +213,7 @@ def add_weights_variable(self, name, var_name=None, type_name=None, precision=No
213213
elif exponent_type:
214214
var = ExponentWeightVariable(var_name, type_name=type_name, precision=precision, quantizer=quantizer, data=data, index=self.index)
215215
else:
216-
var = WeightVariable(var_name, type_name=type_name, precision=precision, quantizer=quantizer, data=data, index=self.index)
216+
var = WeightVariable(var_name, type_name=type_name, precision=precision, quantizer=quantizer, data=data, index=self.index, keep_dims=keep_dims)
217217

218218
var.data_unquantized = data_unquantized
219219

@@ -366,8 +366,28 @@ def initialize(self):
366366
shape = [self.attributes['n_filt'], self.attributes['out_width']]
367367
dims = ['N_FILT_{}'.format(self.index), 'N_OUTPUTS_{}'.format(self.index)]
368368

369+
data = self.model.get_weights_data(self.name, 'kernel')
370+
# now we transform the entire kernel
371+
372+
#(W,F,C) => (F,W,C)
373+
data = np.transpose(data, axes=[1, 0, 2])
374+
# now split the kernel into stride width kernels (F, W, C) -> (S, F, W/S, C)
375+
n_filts, kern_width, n_chan = data.shape
376+
new_weights = np.zeros((self.attributes['stride_width'], n_filts, self.attributes['trfilt_width'], n_chan))
377+
for i_sw in range(self.attributes['stride_width']):
378+
for i_fw in range(self.attributes['trfilt_width']):
379+
filt_ind = i_sw + (self.attributes['trfilt_width']-i_fw-1) * self.attributes['stride_width']
380+
for i_nf in range(n_filts):
381+
for i_nc in range(n_chan):
382+
if filt_ind < kern_width:
383+
new_weights[i_sw][i_nf][i_fw][i_nc] = \
384+
data[i_nf][filt_ind][i_nc]
385+
data = new_weights
386+
369387
self.add_output_variable(shape, dims)
370-
self.add_weights(quantizer = self.get_attr('weight_quantizer'))
388+
# self.add_weights(quantizer = self.get_attr('weight_quantizer'), keep_dims=1)
389+
self.add_weights_variable(name='weight', var_name='w{index}', \
390+
data=data, quantizer=self.get_attr('weight_quantizer'), keep_dims=1)
371391
self.add_bias(quantizer = self.get_attr('bias_quantizer'))
372392

373393

hls4ml/model/types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,12 @@ def definition_cpp(self, name_suffix='', as_reference=False):
224224
return None
225225

226226
class WeightVariable(Variable):
227-
def __init__(self, var_name, type_name, precision, data, quantizer=None, **kwargs):
227+
def __init__(self, var_name, type_name, precision, data, quantizer=None, keep_dims=0, **kwargs):
228228
super(WeightVariable, self).__init__(var_name, NamedType(type_name, precision, **kwargs), **kwargs)
229229
self.data = data
230230
self.nzeros = -1
231231
self.shape = list(self.data.shape)
232+
print("Weight Variable shape object creation:", self.shape)
232233
self.data_length = np.prod(self.data.shape)
233234
self.nonzeros = np.count_nonzero(self.data)
234235
self.nzeros = self.data_length - self.nonzeros
@@ -237,6 +238,7 @@ def __init__(self, var_name, type_name, precision, data, quantizer=None, **kwarg
237238
self._iterator = None
238239
self.update_precision(precision)
239240
self.quantizer = quantizer
241+
self.keep_dims = keep_dims
240242

241243
def __iter__(self):
242244
self._iterator = np.nditer(self.data, order='C')

hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ template<class data_T, class res_T, typename CONFIG_T>
3434
void conv_1d_transpose_cl(
3535
data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
3636
res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
37-
typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],
37+
typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][
38+
CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
39+
],
3840
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
3941
)
4042
{

hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ template<class data_T, class res_T, typename CONFIG_T>
1010
void conv_1d_transpose_resource_cl(
1111
data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
1212
res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
13-
typename CONFIG_T::weight_t weights[CONFIG_T::n_filt * CONFIG_T::filt_width * CONFIG_T::n_chan],
13+
typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][
14+
CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
15+
],
1416
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
1517
)
1618
{
17-
1819
constexpr unsigned mult_n_in = CONFIG_T::trfilt_width * CONFIG_T::n_chan;
1920
constexpr unsigned mult_n_out = CONFIG_T::n_filt;
2021
constexpr unsigned block_factor = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor);
@@ -30,41 +31,7 @@ void conv_1d_transpose_resource_cl(
3031
typename CONFIG_T::accum_t acc[CONFIG_T::n_pixels][mult_n_out][CONFIG_T::stride_width];
3132
#pragma HLS ARRAY_PARTITION variable=acc complete dim=0
3233

33-
typename CONFIG_T::weight_t trfilt_weights[
34-
CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
35-
][CONFIG_T::stride_width];
36-
37-
for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
38-
#pragma HLS UNROLL
39-
40-
for (unsigned i_fw = 0; i_fw < CONFIG_T::trfilt_width; i_fw++) {
41-
#pragma HLS UNROLL
42-
43-
unsigned filt_ind = i_sw + (CONFIG_T::trfilt_width-i_fw-1)*CONFIG_T::stride_width;
44-
for (unsigned i_nf = 0; i_nf < CONFIG_T::n_filt; i_nf++) {
45-
#pragma HLS UNROLL
46-
47-
for (unsigned i_nc = 0; i_nc < CONFIG_T::n_chan; i_nc++) {
48-
#pragma HLS UNROLL
49-
50-
if (filt_ind < CONFIG_T::filt_width) {
51-
trfilt_weights[
52-
i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_width + i_fw * CONFIG_T::n_chan + i_nc
53-
][i_sw] = weights[
54-
i_nf * CONFIG_T::n_chan * CONFIG_T::filt_width + filt_ind * CONFIG_T::n_chan + i_nc
55-
];
56-
}
57-
else {
58-
trfilt_weights[
59-
i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_width + i_fw * CONFIG_T::n_chan + i_nc
60-
][i_sw] = 0;
61-
}
62-
}
63-
}
64-
}
65-
}
66-
67-
#pragma HLS ARRAY_RESHAPE variable=trfilt_weights block factor=block_factor dim=1
34+
#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor dim=2
6835

6936
PartitionLoop:
7037
for (unsigned i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) {
@@ -109,7 +76,7 @@ void conv_1d_transpose_resource_cl(
10976

11077
acc[i_pxl][i_out][i_sw] += static_cast<typename CONFIG_T::accum_t>(
11178
CONFIG_T::mult_config::template product<data_T, typename CONFIG_T::mult_config::weight_t>::product(
112-
data_buf[i_pxl][i_in], trfilt_weights[i_w][i_sw]
79+
data_buf[i_pxl][i_in], weights[i_sw][i_w]
11380
)
11481
);
11582
}

hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h

Lines changed: 11 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -7,49 +7,6 @@
77

88
namespace nnet {
99

10-
template <typename CONFIG_T>
11-
void load_trfilt_weights_1d(
12-
typename CONFIG_T::weight_t trfilt_weights[CONFIG_T::stride_width][
13-
CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
14-
],
15-
typename CONFIG_T::weight_t weights[
16-
CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt
17-
]
18-
)
19-
{
20-
#pragma HLS INLINE
21-
22-
for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) {
23-
#pragma HLS UNROLL
24-
25-
for (unsigned i_fw = 0; i_fw < CONFIG_T::trfilt_width; i_fw++) {
26-
#pragma HLS UNROLL
27-
28-
unsigned filt_ind = i_sw + (CONFIG_T::trfilt_width-i_fw-1)*CONFIG_T::stride_width;
29-
for (unsigned i_nf = 0; i_nf < CONFIG_T::n_filt; i_nf++) {
30-
#pragma HLS UNROLL
31-
32-
for (unsigned i_nc = 0; i_nc < CONFIG_T::n_chan; i_nc++) {
33-
#pragma HLS UNROLL
34-
35-
if (filt_ind < CONFIG_T::filt_width) {
36-
trfilt_weights[i_sw][
37-
i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_width + i_fw * CONFIG_T::n_chan + i_nc
38-
] = weights[
39-
i_nf * CONFIG_T::n_chan * CONFIG_T::filt_width + filt_ind * CONFIG_T::n_chan + i_nc
40-
];
41-
}
42-
else {
43-
trfilt_weights[i_sw][
44-
i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_width + i_fw * CONFIG_T::n_chan + i_nc
45-
] = 0;
46-
}
47-
}
48-
}
49-
}
50-
}
51-
}
52-
5310
template <class data_T, typename CONFIG_T>
5411
void kernel_shift_tr_1d(
5512
const data_T& in_elem,
@@ -81,7 +38,9 @@ template<class data_T, class res_T, typename CONFIG_T>
8138
void compute_output_buffer_tr_1d(
8239
const data_T& in_elem,
8340
hls::stream<res_T> &res_stream,
84-
typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan * CONFIG_T::n_filt],
41+
typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][
42+
CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
43+
],
8544
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
8645
)
8746
{
@@ -106,23 +65,17 @@ void compute_output_buffer_tr_1d(
10665
// Add pixel to buffer
10766
nnet::kernel_shift_tr_1d<data_T, CONFIG_T>(in_elem, kernel_data);
10867

109-
static typename CONFIG_T::weight_t trfilt_weights[CONFIG_T::stride_width][
110-
CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
111-
];
112-
113-
load_trfilt_weights_1d<CONFIG_T>(trfilt_weights, weights);
114-
11568
//always do stride number of multiplications
11669
StrideLoop: for (int idx = 0; idx < CONFIG_T::stride_width; idx++) {
11770
#pragma HLS UNROLL
11871
#pragma HLS INLINE region
11972
// Dense multiply
12073
if (CONFIG_T::strategy == nnet::latency) {
12174
dense_latency<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
122-
kernel_data, res_out, trfilt_weights[idx], biases);
75+
kernel_data, res_out, weights[idx], biases);
12376
} else {
12477
dense_resource<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
125-
kernel_data, res_out, trfilt_weights[idx], biases);
78+
kernel_data, res_out, weights[idx], biases);
12679
}
12780

12881
// Pack output
@@ -135,7 +88,6 @@ void compute_output_buffer_tr_1d(
13588
}
13689
// Write output to stream when output ready
13790
oX++;
138-
// weight_start++;
13991
}
14092

14193
// static var housekeeping
@@ -152,7 +104,9 @@ template<class data_T, class res_T, typename CONFIG_T>
152104
void conv_1d_transpose_buffer_cl(
153105
hls::stream<data_T> &data,
154106
hls::stream<res_T> &res,
155-
typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
107+
typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][
108+
CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
109+
],
156110
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt])
157111
{
158112
ReadInputWidth: for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) {
@@ -168,7 +122,9 @@ template<class data_T, class res_T, typename CONFIG_T>
168122
void conv_1d_transpose_cl(
169123
hls::stream<data_T> &data,
170124
hls::stream<res_T> &res,
171-
typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
125+
typename CONFIG_T::weight_t weights[CONFIG_T::stride_width][
126+
CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan
127+
],
172128
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
173129
)
174130
{

hls4ml/templates/vivado/nnet_utils/nnet_helpers.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,42 @@ void load_weights_from_txt(T *w, const char* fname) {
6767
}
6868
}
6969

70+
template<class T, size_t DIM_1, size_t DIM_2>
71+
void load_weights_from_txt(T w[DIM_1][DIM_2], const char* fname) {
72+
73+
std::string full_path = std::string(WEIGHTS_DIR) + "/" + std::string(fname);
74+
std::ifstream infile(full_path.c_str(), std::ios::binary);
75+
76+
if (infile.fail()) {
77+
std::cerr << "ERROR: file " << std::string(fname) << " does not exist" << std::endl;
78+
exit(1);
79+
}
80+
81+
std::string line;
82+
if (std::getline(infile, line)) {
83+
std::istringstream iss(line);
84+
std::string token;
85+
86+
size_t i = 0;
87+
size_t j = 0;
88+
size_t tot = 0;
89+
while(std::getline(iss, token, ',')) {
90+
std::istringstream(token) >> w[i][j];
91+
j++;
92+
if (j == DIM_2) {
93+
j = 0;
94+
i++;
95+
}
96+
tot++;
97+
}
98+
99+
if (DIM_1*DIM_2 != tot) {
100+
std::cerr << "ERROR: Expected " << DIM_1*DIM_2 << " values";
101+
std::cerr << " but read only " << tot << " values" << std::endl;
102+
}
103+
}
104+
}
105+
70106
template<class T, size_t SIZE>
71107
void load_compressed_weights_from_txt(T *w, const char* fname) {
72108

0 commit comments

Comments
 (0)