Skip to content

Commit 09695b4

Browse files
committed
Quartus Streaming GRU
1 parent e7ab058 commit 09695b4

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed
Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,67 @@
1-
/*
2-
* PLACEHOLDER - TODO - Implement once PR #557 is merged
3-
*/
4-
51
#ifndef NNET_RECURRENT_STREAM_H_
62
#define NNET_RECURRENT_STREAM_H_
73

8-
namespace nnet {}
4+
#include "nnet_common.h"
5+
#include "nnet_dense.h"
6+
#include "nnet_recurrent_activation.h"
7+
8+
namespace nnet {
9+
template<class data_T, class res_T, typename CONFIG_T>
10+
void gru(
11+
stream<data_T> &data_stream,
12+
stream<res_T> &res_stream,
13+
const typename CONFIG_T::weight_t weights[3 * CONFIG_T::n_units * CONFIG_T::n_in],
14+
const typename CONFIG_T::weight_t recurrent_weights[3 * CONFIG_T::n_units * CONFIG_T::n_units],
15+
const typename CONFIG_T::bias_t bias[3 * CONFIG_T::n_units],
16+
const typename CONFIG_T::bias_t recurrent_bias[3 * CONFIG_T::n_units]
17+
) {
18+
19+
hls_register typename res_T::value_type h[CONFIG_T::n_units];
20+
#pragma unroll
21+
for(int i = 0; i < CONFIG_T::n_units; i++) {
22+
h[i] = 0;
23+
}
24+
25+
hls_register typename data_T::value_type x[CONFIG_T::n_in];
26+
27+
DataPropagation:
28+
for(int i_in = 0; i_in < CONFIG_T::n_timesteps * CONFIG_T::n_in / data_T::size; i_in++) {
29+
data_T data_pack = data_stream.read();
30+
31+
DataPack:
32+
#pragma unroll
33+
for (int i_pack = 0; i_pack < data_T::size; i_pack++) {
34+
x[i_pack] = data_pack[i_pack];
35+
}
36+
37+
nnet::gru_cell<typename data_T::value_type, typename res_T::value_type, CONFIG_T>(x, h, weights, recurrent_weights, bias, recurrent_bias);
38+
39+
if (CONFIG_T::return_sequences) {
40+
res_T res_pack;
41+
42+
ResPackRetSeq:
43+
#pragma unroll
44+
for (int i_pack = 0; i_pack < res_T::size; i_pack++) {
45+
res_pack[i_pack] = h[i_pack];
46+
}
47+
48+
res_stream.write(res_pack);
49+
}
50+
}
51+
52+
if (!CONFIG_T::return_sequences) {
53+
res_T res_pack;
54+
55+
ResPackNoRetSeq:
56+
#pragma unroll
57+
for (int i_pack = 0; i_pack < res_T::size; i_pack++) {
58+
res_pack[i_pack] = h[i_pack];
59+
}
60+
61+
res_stream.write(res_pack);
62+
}
63+
}
64+
65+
}
966

1067
#endif

test/pytest/test_rnn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def test_rnn_parsing(rnn_layer, return_sequences):
6969
(GRU, 'Vivado', 'io_parallel'),
7070
(GRU, 'Vivado', 'io_stream'),
7171
(GRU, 'Quartus', 'io_parallel'),
72+
(GRU, 'Quartus', 'io_stream'),
7273
])
7374
@pytest.mark.parametrize('return_sequences', [True, False])
7475
@pytest.mark.parametrize('static', [True, False])

0 commit comments

Comments
 (0)