Skip to content

Commit f9f7143

Browse files
authored
Merge pull request #1160 from dkurt:tflite_new_layers
* Split layer test * FullyConnected test * Update FullyConnected test date
1 parent d0f8e8c commit f9f7143

File tree

7 files changed

+31
-4
lines changed

7 files changed

+31
-4
lines changed
956 Bytes
Binary file not shown.
136 Bytes
Binary file not shown.
Binary file not shown.

testdata/dnn/tflite/generate.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,18 @@ def save_tflite_model(model, inp, name):
5252
converter = tf.lite.TFLiteConverter.from_concrete_functions([func])
5353
tflite_model = converter.convert()
5454

55-
interpreter = tf.lite.Interpreter(model_content=tflite_model)
56-
5755
with open(f'{name}.tflite', 'wb') as f:
5856
f.write(tflite_model)
5957

6058
out = model(inp)
59+
out = np.array(out)
60+
61+
if len(inp.shape) == 4:
62+
inp = inp.transpose(0, 3, 1, 2)
63+
out = out.transpose(0, 3, 1, 2)
6164

62-
np.save(f'{name}_inp.npy', inp.transpose(0, 3, 1, 2))
63-
np.save(f'{name}_out_Identity.npy', np.array(out).transpose(0, 3, 1, 2))
65+
np.save(f'{name}_inp.npy', inp)
66+
np.save(f'{name}_out_Identity.npy', out)
6467

6568

6669
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 3, 3, 1], dtype=tf.float32)])
@@ -75,3 +78,27 @@ def replicate_by_pack(x):
7578
inp = np.random.standard_normal((1, 3, 3, 1)).astype(np.float32)
7679
save_tflite_model(replicate_by_pack, inp, 'replicate_by_pack')
7780

81+
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 3], dtype=tf.float32)])
82+
def split(x):
83+
splitted = tf.split(
84+
x, 3, axis=-1, num=None, name='split'
85+
)
86+
return tf.concat((splitted[2], splitted[1], splitted[0]), axis=-1)
87+
88+
inp = np.random.standard_normal((1, 3)).astype(np.float32)
89+
save_tflite_model(split, inp, 'split')
90+
91+
92+
fully_connected = tf.keras.models.Sequential([
93+
tf.keras.layers.Dense(3),
94+
tf.keras.layers.ReLU(),
95+
tf.keras.layers.Softmax(),
96+
])
97+
98+
fully_connected = tf.function(
99+
fully_connected.call,
100+
input_signature=[tf.TensorSpec((1,2), tf.float32)],
101+
)
102+
103+
inp = np.random.standard_normal((1, 2)).astype(np.float32)
104+
save_tflite_model(fully_connected, inp, 'fully_connected')

testdata/dnn/tflite/split.tflite

1.04 KB
Binary file not shown.

testdata/dnn/tflite/split_inp.npy

140 Bytes
Binary file not shown.
140 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)