Skip to content

Commit a9546c7

Browse files
Authored and committed by Enrico Lupi
ADD tests for bidirectional layer
1 parent 5eef679 commit a9546c7

File tree

1 file changed

+83
-48
lines changed

1 file changed

+83
-48
lines changed

test/pytest/test_rnn.py

Lines changed: 83 additions & 48 deletions
Original file line number · Diff line number · Diff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
import pytest
5-
from tensorflow.keras.layers import GRU, LSTM, Input, SimpleRNN
5+
from tensorflow.keras.layers import GRU, LSTM, Bidirectional, Input, SimpleRNN
66
from tensorflow.keras.models import Model, Sequential
77

88
import hls4ml
@@ -14,13 +14,21 @@
1414

1515
@pytest.mark.parametrize('rnn_layer', rnn_layers)
1616
@pytest.mark.parametrize('return_sequences', [True, False])
17-
def test_rnn_parsing(rnn_layer, return_sequences):
17+
@pytest.mark.parametrize('bidirectional', [True, False])
18+
def test_rnn_parsing(rnn_layer, return_sequences, bidirectional):
19+
20+
if rnn_layer is SimpleRNN and bidirectional:
21+
pytest.skip("SimpleRNN does not support bidirectional layers")
22+
1823
time_steps = 3
1924
input_size = 8
2025
input_shape = (time_steps, input_size)
2126

2227
model_input = Input(shape=input_shape)
23-
model_output = rnn_layer(64, return_sequences=return_sequences)(model_input)
28+
if not bidirectional:
29+
model_output = rnn_layer(64, return_sequences=return_sequences)(model_input)
30+
else:
31+
model_output = Bidirectional(rnn_layer(64, return_sequences=return_sequences))(model_input)
2432

2533
model = Model(model_input, model_output)
2634
model.compile(optimizer='adam', loss='mse')
@@ -34,13 +42,26 @@ def test_rnn_parsing(rnn_layer, return_sequences):
3442
keras_layer = model.layers[1]
3543

3644
# Basic sanity check, I/O, activations
37-
assert hls_layer.class_name == rnn_layer.__name__
38-
assert hls_layer.attributes['n_out'] == keras_layer.units
39-
assert hls_layer.attributes['activation'] == keras_layer.activation.__name__
40-
if 'recurrent_activation' in hls_layer.attributes: # SimpleRNN doesn't have this
41-
assert hls_layer.attributes['recurrent_activation'] == keras_layer.recurrent_activation.__name__
42-
assert hls_layer.get_input_variable().shape == list(input_shape)
43-
assert hls_layer.get_output_variable().shape == model_output.shape.as_list()[1:] # Ignore the batch size
45+
if not bidirectional:
46+
assert hls_layer.class_name == rnn_layer.__name__
47+
assert hls_layer.attributes['n_out'] == keras_layer.units
48+
assert hls_layer.attributes['activation'] == keras_layer.activation.__name__
49+
if 'recurrent_activation' in hls_layer.attributes: # SimpleRNN doesn't have this
50+
assert hls_layer.attributes['recurrent_activation'] == keras_layer.recurrent_activation.__name__
51+
assert hls_layer.get_input_variable().shape == list(input_shape)
52+
assert hls_layer.get_output_variable().shape == model_output.shape.as_list()[1:] # Ignore the batch size
53+
else:
54+
assert hls_layer.class_name == 'Bidirectional' + rnn_layer.__name__
55+
assert hls_layer.attributes['merge_mode'] == keras_layer.merge_mode
56+
if hls_layer.attributes['merge_mode'] == 'concat':
57+
assert hls_layer.attributes['n_out'] == 2 * keras_layer.forward_layer.units
58+
else:
59+
assert hls_layer.attributes['n_out'] == keras_layer.forward_layer.units
60+
assert hls_layer.attributes['activation'] == keras_layer.forward_layer.activation.__name__
61+
if 'recurrent_activation' in hls_layer.attributes: # SimpleRNN doesn't have this
62+
assert hls_layer.attributes['recurrent_activation'] == keras_layer.forward_layer.recurrent_activation.__name__
63+
assert hls_layer.get_input_variable().shape == list(input_shape)
64+
assert hls_layer.get_output_variable().shape == model_output.shape.as_list()[1:] # Ignore the batch size
4465

4566
# Compare weights
4667
hls_weights = list(hls_layer.get_weights()) # [weights, recurrent_weights, bias, recurrent_bias]
@@ -66,54 +87,66 @@ def test_rnn_parsing(rnn_layer, return_sequences):
6687

6788

6889
@pytest.mark.parametrize(
69-
'rnn_layer, backend, io_type, strategy',
90+
'rnn_layer, bidirectional, backend, io_type, strategy',
7091
[
71-
(SimpleRNN, 'Quartus', 'io_parallel', 'resource'),
72-
(SimpleRNN, 'oneAPI', 'io_parallel', 'resource'),
73-
(LSTM, 'Vivado', 'io_parallel', 'resource'),
74-
(LSTM, 'Vivado', 'io_parallel', 'latency'),
75-
(LSTM, 'Vitis', 'io_parallel', 'resource'),
76-
(LSTM, 'Vitis', 'io_parallel', 'latency'),
77-
(LSTM, 'Quartus', 'io_parallel', 'resource'),
78-
(LSTM, 'oneAPI', 'io_parallel', 'resource'),
79-
(LSTM, 'Vivado', 'io_stream', 'resource'),
80-
(LSTM, 'Vivado', 'io_stream', 'latency'),
81-
(LSTM, 'Vitis', 'io_stream', 'resource'),
82-
(LSTM, 'Vitis', 'io_stream', 'latency'),
83-
(GRU, 'Vivado', 'io_parallel', 'resource'),
84-
(GRU, 'Vivado', 'io_parallel', 'latency'),
85-
(GRU, 'Vitis', 'io_parallel', 'resource'),
86-
(GRU, 'Vitis', 'io_parallel', 'latency'),
87-
(GRU, 'Quartus', 'io_parallel', 'resource'),
88-
(GRU, 'oneAPI', 'io_parallel', 'resource'),
89-
(GRU, 'Vivado', 'io_stream', 'resource'),
90-
(GRU, 'Vivado', 'io_stream', 'latency'),
91-
(GRU, 'Vitis', 'io_stream', 'resource'),
92-
(GRU, 'Vitis', 'io_stream', 'latency'),
93-
(GRU, 'Quartus', 'io_stream', 'resource'),
94-
(GRU, 'oneAPI', 'io_stream', 'resource'),
92+
(SimpleRNN, False, 'Quartus', 'io_parallel', 'resource'),
93+
(SimpleRNN, False, 'oneAPI', 'io_parallel', 'resource'),
94+
(LSTM, False, 'Vivado', 'io_parallel', 'resource'),
95+
(LSTM, False, 'Vivado', 'io_parallel', 'latency'),
96+
(LSTM, False, 'Vitis', 'io_parallel', 'resource'),
97+
(LSTM, False, 'Vitis', 'io_parallel', 'latency'),
98+
(LSTM, True, 'Vivado', 'io_parallel', 'resource'),
99+
(LSTM, True, 'Vivado', 'io_parallel', 'latency'),
100+
(LSTM, True, 'Vitis', 'io_parallel', 'resource'),
101+
(LSTM, True, 'Vitis', 'io_parallel', 'latency'),
102+
(LSTM, False, 'Quartus', 'io_parallel', 'resource'),
103+
(LSTM, False, 'oneAPI', 'io_parallel', 'resource'),
104+
(LSTM, False, 'Vivado', 'io_stream', 'resource'),
105+
(LSTM, False, 'Vivado', 'io_stream', 'latency'),
106+
(LSTM, False, 'Vitis', 'io_stream', 'resource'),
107+
(LSTM, False, 'Vitis', 'io_stream', 'latency'),
108+
(GRU, False, 'Vivado', 'io_parallel', 'resource'),
109+
(GRU, False, 'Vivado', 'io_parallel', 'latency'),
110+
(GRU, False, 'Vitis', 'io_parallel', 'resource'),
111+
(GRU, False, 'Vitis', 'io_parallel', 'latency'),
112+
(GRU, True, 'Vivado', 'io_parallel', 'resource'),
113+
(GRU, True, 'Vivado', 'io_parallel', 'latency'),
114+
(GRU, True, 'Vitis', 'io_parallel', 'resource'),
115+
(GRU, True, 'Vitis', 'io_parallel', 'latency'),
116+
(GRU, False, 'Quartus', 'io_parallel', 'resource'),
117+
(GRU, False, 'oneAPI', 'io_parallel', 'resource'),
118+
(GRU, False, 'Vivado', 'io_stream', 'resource'),
119+
(GRU, False, 'Vivado', 'io_stream', 'latency'),
120+
(GRU, False, 'Vitis', 'io_stream', 'resource'),
121+
(GRU, False, 'Vitis', 'io_stream', 'latency'),
122+
(GRU, False, 'Quartus', 'io_stream', 'resource'),
123+
(GRU, False, 'oneAPI', 'io_stream', 'resource'),
95124
],
96125
)
97126
@pytest.mark.parametrize('return_sequences', [True, False])
98127
@pytest.mark.parametrize('static', [True, False])
99-
def test_rnn_accuracy(rnn_layer, return_sequences, backend, io_type, strategy, static):
128+
def test_rnn_accuracy(rnn_layer, bidirectional, return_sequences, backend, io_type, strategy, static):
100129
# Subtract 0.5 to include negative values
101130
input_shape = (12, 8)
102131
X = np.random.rand(50, *input_shape) - 0.5
103132

104-
layer_name = rnn_layer.__name__
133+
layer_name = ("Bidirectional" if bidirectional else "") + rnn_layer.__name__
105134
keras_model = Sequential()
106-
keras_model.add(
107-
rnn_layer(
108-
units=32,
109-
input_shape=input_shape,
110-
kernel_initializer='lecun_uniform',
111-
recurrent_initializer='lecun_uniform',
112-
bias_initializer='lecun_uniform',
113-
return_sequences=return_sequences,
114-
name=layer_name,
115-
)
135+
keras_model.add(Input(shape=input_shape))
136+
test_layer = rnn_layer(
137+
units=32,
138+
input_shape=input_shape,
139+
kernel_initializer='lecun_uniform',
140+
recurrent_initializer='lecun_uniform',
141+
bias_initializer='lecun_uniform',
142+
return_sequences=return_sequences,
143+
name=layer_name,
116144
)
145+
if not bidirectional:
146+
keras_model.add(test_layer)
147+
else:
148+
keras_model.add(Bidirectional(test_layer, name=layer_name))
149+
117150
keras_model.compile()
118151

119152
default_precision = 'ap_fixed<32, 16>' if backend in ['Vivado', 'Vitis'] else 'ac_fixed<32, 16, true>'
@@ -123,7 +156,9 @@ def test_rnn_accuracy(rnn_layer, return_sequences, backend, io_type, strategy, s
123156
hls_config['LayerName'][layer_name]['static'] = static
124157
hls_config['LayerName'][layer_name]['Strategy'] = strategy
125158
prj_name = (
126-
f'hls4mlprj_rnn_accuracy_{layer_name}_static_{int(static)}_ret_seq_{int(return_sequences)}_'
159+
'hls4mlprj_rnn_accuracy_'
160+
+ ('bidirectional_' if bidirectional else '')
161+
+ f'{layer_name}_static_{int(static)}_ret_seq_{int(return_sequences)}_'
127162
f'{backend}_{io_type}_{strategy}'
128163
)
129164
output_dir = str(test_root_path / prj_name)

0 commit comments

Comments (0)