Skip to content

Commit dd1fbd0

Browse files
authored
Merge pull request #1180 from CNOCycle:tflite/ops
* Simpify generating permutation testes for tflite * Simpify converting keras into TF for tflite tests * Add global_pool_2d tests for tflite models
1 parent 723bdf2 commit dd1fbd0

7 files changed

+35
-55
lines changed

testdata/dnn/tflite/generate.py

Lines changed: 35 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -93,77 +93,57 @@ def split(x):
9393
inp = np.random.standard_normal((1, 3)).astype(np.float32)
9494
save_tflite_model(split, inp, 'split')
9595

96+
def keras_to_tf(model, input_shape):
97+
tf_func = tf.function(
98+
model.call,
99+
input_signature=[tf.TensorSpec(input_shape, tf.float32)],
100+
)
101+
inp = np.random.standard_normal((input_shape)).astype(np.float32)
102+
103+
return tf_func, inp
96104

97105
fully_connected = tf.keras.models.Sequential([
98106
tf.keras.layers.Dense(3),
99107
tf.keras.layers.ReLU(),
100108
tf.keras.layers.Softmax(),
101109
])
102110

103-
fully_connected = tf.function(
104-
fully_connected.call,
105-
input_signature=[tf.TensorSpec((1,2), tf.float32)],
106-
)
107-
108-
inp = np.random.standard_normal((1, 2)).astype(np.float32)
111+
fully_connected, inp = keras_to_tf(fully_connected, (1, 2))
109112
save_tflite_model(fully_connected, inp, 'fully_connected')
110113

111114
permutation_3d = tf.keras.models.Sequential([
112-
tf.keras.layers.Permute((2,1))
115+
tf.keras.layers.Permute((2, 1))
113116
])
114117

115-
permutation_3d = tf.function(
116-
permutation_3d.call,
117-
input_signature=[tf.TensorSpec((1,2,3), tf.float32)],
118-
)
119-
inp = np.random.standard_normal((1, 2, 3)).astype(np.float32)
118+
permutation_3d, inp = keras_to_tf(permutation_3d, (1, 2, 3))
120119
save_tflite_model(permutation_3d, inp, 'permutation_3d')
121120

122-
# Temporarily disabled as TFLiteConverter produces a incorrect graph in this case
123-
#permutation_4d_0123 = tf.keras.models.Sequential([
124-
# tf.keras.layers.Permute((1,2,3)),
125-
# tf.keras.layers.Conv2D(3,1)
126-
#])
127-
#
128-
#permutation_4d_0123 = tf.function(
129-
# permutation_4d_0123.call,
130-
# input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
131-
#)
132-
#inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
133-
#save_tflite_model(permutation_4d_0123, inp, 'permutation_4d_0123')
134-
135-
permutation_4d_0132 = tf.keras.models.Sequential([
136-
tf.keras.layers.Permute((1,3,2)),
137-
tf.keras.layers.Conv2D(3,1)
138-
])
139-
140-
permutation_4d_0132 = tf.function(
141-
permutation_4d_0132.call,
142-
input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
143-
)
144-
inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
145-
save_tflite_model(permutation_4d_0132, inp, 'permutation_4d_0132')
146-
147-
permutation_4d_0213 = tf.keras.models.Sequential([
148-
tf.keras.layers.Permute((2,1,3)),
149-
tf.keras.layers.Conv2D(3,1)
121+
# (1, 2, 3) is temporarily disabled as TFLiteConverter produces a incorrect graph in this case
122+
permutation_4d_list = [(1, 3, 2), (2, 1, 3), (2, 3, 1)]
123+
for perm_axis in permutation_4d_list:
124+
permutation_4d_model = tf.keras.models.Sequential([
125+
tf.keras.layers.Permute(perm_axis),
126+
tf.keras.layers.Conv2D(3, 1)
127+
])
128+
129+
permutation_4d_model, inp = keras_to_tf(permutation_4d_model, (1, 2, 3, 4))
130+
model_name = f"permutation_4d_0{''.join(map(str, perm_axis))}"
131+
save_tflite_model(permutation_4d_model, inp, model_name)
132+
133+
global_average_pooling_2d = tf.keras.models.Sequential([
134+
tf.keras.layers.GlobalAveragePooling2D(keepdims=True),
135+
tf.keras.layers.ZeroPadding2D(1),
136+
tf.keras.layers.GlobalAveragePooling2D(keepdims=False)
150137
])
151138

152-
permutation_4d_0213 = tf.function(
153-
permutation_4d_0213.call,
154-
input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
155-
)
156-
inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
157-
save_tflite_model(permutation_4d_0213, inp, 'permutation_4d_0213')
139+
global_average_pooling_2d, inp = keras_to_tf(global_average_pooling_2d, (1, 7, 7, 5))
140+
save_tflite_model(global_average_pooling_2d, inp, 'global_average_pooling_2d')
158141

159-
permutation_4d_0231 = tf.keras.models.Sequential([
160-
tf.keras.layers.Permute((2,3,1)),
161-
tf.keras.layers.Conv2D(3,1)
142+
global_max_pool = tf.keras.models.Sequential([
143+
tf.keras.layers.GlobalMaxPool2D(keepdims=True),
144+
tf.keras.layers.ZeroPadding2D(1),
145+
tf.keras.layers.GlobalMaxPool2D(keepdims=True)
162146
])
163147

164-
permutation_4d_0231 = tf.function(
165-
permutation_4d_0231.call,
166-
input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
167-
)
168-
inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
169-
save_tflite_model(permutation_4d_0231, inp, 'permutation_4d_0231')
148+
global_max_pool, inp = keras_to_tf(global_max_pool, (1, 7, 7, 5))
149+
save_tflite_model(global_max_pool, inp, 'global_max_pooling_2d')
Binary file not shown.
Binary file not shown.
Binary file not shown.
1.25 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)