Skip to content

Commit 645f8f4

Browse files
move transposing of weight matrix to resource_strategy for transpose layers
1 parent 8d42dc1 commit 645f8f4

File tree

5 files changed

+34
-41
lines changed

5 files changed

+34
-41
lines changed

hls4ml/backends/vivado/passes/resource_strategy.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import numpy as np
22

33
from hls4ml.model.optimizer import OptimizerPass
4-
from hls4ml.model.layers import Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D, LSTM, GRU
4+
from hls4ml.model.layers import Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose, Dense, SeparableConv1D, SeparableConv2D, LSTM, GRU
55

66
class ApplyResourceStrategy(OptimizerPass):
77
''' Transposes the weights to use the dense_resource matrix multiply routine '''
88
def match(self, node):
99

10-
node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU))
10+
node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, Conv1DTranspose, Conv2DTranspose, SeparableConv2D, LSTM, GRU))
11+
1112
is_resource_strategy = node.get_attr('strategy', '').lower() == 'resource'
1213
already_transformed = node.get_attr('_weights_transposed', False) == True
1314

@@ -18,11 +19,15 @@ def transform(self, model, node):
1819
node.weights['weight'].data = np.transpose(node.weights['weight'].data)
1920
elif isinstance(node, Conv1D):
2021
node.weights['weight'].data = np.transpose(node.weights['weight'].data, axes=[2, 0, 1]) #(W,C,F) => (F,W,C)
22+
elif isinstance(node, Conv1DTranspose):
23+
node.weights['weight'].data = np.transpose(node.weights['weight'].data, axes=[1, 0, 2]) #(W,F,C) => (F,W,C)
2124
elif isinstance(node, SeparableConv1D):
2225
node.weights['depthwise'].data = np.transpose(node.weights['depthwise'].data, axes=[2, 0, 1]) #(W,C,F) => (F,W,C)
2326
node.weights['pointwise'].data = np.transpose(node.weights['pointwise'].data, axes=[2, 0, 1]) #(W,C,F) => (F,W,C)
2427
elif isinstance(node, Conv2D):
2528
node.weights['weight'].data = np.transpose(node.weights['weight'].data, axes=[3, 0, 1, 2]) #(H,W,C,F) => (F,H,W,C)
29+
elif isinstance(node, Conv2DTranspose):
30+
node.weights['weight'].data = np.transpose(node.weights['weight'].data, axes=[2, 0, 1, 3]) #(H,W,F,C) => (F,H,W,C)
2631
elif isinstance(node, SeparableConv2D):
2732
node.weights['depthwise'].data = np.transpose(node.weights['depthwise'].data, axes=[3, 0, 1, 2]) #(H,W,C,F) => (F,H,W,C)
2833
node.weights['pointwise'].data = np.transpose(node.weights['pointwise'].data, axes=[3, 0, 1, 2]) #(H,W,C,F) => (F,H,W,C)

hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ void conv_1d_transpose_resource_cl(
5151
trfilt_weights[
5252
i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_width + i_fw * CONFIG_T::n_chan + i_nc
5353
][i_sw] = weights[
54-
filt_ind * CONFIG_T::n_filt * CONFIG_T::n_chan + i_nf * CONFIG_T::n_chan + i_nc
54+
i_nf * CONFIG_T::n_chan * CONFIG_T::filt_width + filt_ind * CONFIG_T::n_chan + i_nc
5555
];
5656
}
5757
else {
5858
trfilt_weights[
59-
i_fw * CONFIG_T::n_chan + i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_width + i_nc
59+
i_nf * CONFIG_T::n_chan * CONFIG_T::trfilt_width + i_fw * CONFIG_T::n_chan + i_nc
6060
][i_sw] = 0;
6161
}
6262
}

hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ void weights_trim(
3333
#pragma HLS UNROLL
3434
#pragma HLS PIPELINE
3535
for (int chan_ind = 0; chan_ind < CONFIG_T::n_chan; chan_ind++) {
36-
// #pragma HLS LOOP_FLATTEN
3736
#pragma HLS UNROLL
3837
#pragma HLS PIPELINE
3938
if (row_indices[step] >= CONFIG_T::filt_width) {
@@ -42,14 +41,12 @@ void weights_trim(
4241
} else {
4342
row_weights[filt_ind * CONFIG_T::trfilt_width * CONFIG_T::n_chan +
4443
step * CONFIG_T::n_chan + chan_ind] =
45-
weights[row_indices[step] * CONFIG_T::n_filt * CONFIG_T::n_chan +
46-
filt_ind * CONFIG_T::n_chan + chan_ind];
44+
weights[filt_ind * CONFIG_T::filt_width * CONFIG_T::n_chan +
45+
row_indices[step] * CONFIG_T::n_chan + chan_ind];
4746
}
4847
}
4948
}
5049
}
51-
//try to split if else into two loops
52-
//pre-compute
5350
}
5451

5552
template <class data_T, typename CONFIG_T>
@@ -116,36 +113,27 @@ void compute_output_buffer_tr_1d(
116113

117114
//always do stride number of multiplications
118115
StrideLoop: for (int idx = 0; idx < CONFIG_T::stride_width; idx++) {
119-
// #pragma HLS DATAFLOW
120-
// #pragma HLS PIPELINE
121-
// #pragma HLS INLINE region
122116
//load in the weights for this multiplication
123-
WeightsRegion: {
124-
weights_trim<CONFIG_T>(
125-
weights, row_weights, weight_start
126-
);
117+
weights_trim<CONFIG_T>(
118+
weights, row_weights, weight_start
119+
);
120+
121+
// Dense multiply
122+
if (CONFIG_T::strategy == nnet::latency) {
123+
dense_latency<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
124+
kernel_data, res_out, row_weights, biases);
125+
} else {
126+
dense_resource<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
127+
kernel_data, res_out, row_weights, biases);
127128
}
128129

129-
MultRegion: {
130-
// Dense multiply
131-
if (CONFIG_T::strategy == nnet::latency) {
132-
dense_latency<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
133-
kernel_data, res_out, row_weights, biases);
134-
} else {
135-
dense_resource<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(
136-
kernel_data, res_out, row_weights, biases);
137-
}
138-
}
139-
140-
PackRegion: {
141-
// Pack output
142-
if (oX >= CONFIG_T::pad_left && oX < CONFIG_T::pad_left + CONFIG_T::out_width) {
143-
CastLoop: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) {
144-
#pragma HLS UNROLL
145-
res_pack[i_ic] = res_out[i_ic];
146-
}
147-
res_stream.write(res_pack);
130+
// Pack output
131+
if (oX >= CONFIG_T::pad_left && oX < CONFIG_T::pad_left + CONFIG_T::out_width) {
132+
CastLoop: for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) {
133+
#pragma HLS UNROLL
134+
res_pack[i_ic] = res_out[i_ic];
148135
}
136+
res_stream.write(res_pack);
149137
}
150138
// Write output to stream when output ready
151139
oX++;

hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_resource.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ void conv_2d_transpose_resource_cl(
5757
i_fh * CONFIG_T::trfilt_width * CONFIG_T::n_chan +
5858
i_fw * CONFIG_T::n_chan + i_nc
5959
][i_sh][i_sw] = weights[
60-
filt_h_ind * CONFIG_T::filt_width * CONFIG_T::n_filt * CONFIG_T::n_chan +
61-
filt_w_ind * CONFIG_T::n_filt * CONFIG_T::n_chan +
62-
i_nf * CONFIG_T::n_chan + i_nc
60+
i_nf * CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width +
61+
filt_h_ind * CONFIG_T::n_chan * CONFIG_T::filt_width +
62+
filt_w_ind * CONFIG_T::n_chan + i_nc
6363
];
6464
}
6565
else {

hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ void load_tr_kern_weights(
112112
y_step * CONFIG_T::trfilt_width * CONFIG_T::n_chan +
113113
x_step * CONFIG_T::n_chan + chan_ind
114114
] = weights[
115-
y_indices[y_step] * CONFIG_T::filt_width * CONFIG_T::n_filt * CONFIG_T::n_chan +
116-
x_indices[x_step] * CONFIG_T::n_filt * CONFIG_T::n_chan +
117-
filt_ind * CONFIG_T::n_chan + chan_ind
115+
filt_ind * CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan +
116+
y_indices[y_step] * CONFIG_T::filt_width * CONFIG_T::n_chan +
117+
x_indices[x_step] * CONFIG_T::n_chan + chan_ind
118118
];
119119
}
120120
}

0 commit comments

Comments
 (0)