Skip to content

Commit 89d03b0

Browse files
committed
Update test data
1 parent 06b9136 commit 89d03b0

File tree

4 files changed

+14
-15
lines changed

4 files changed

+14
-15
lines changed

testdata/dnn/tflite/generate.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,31 +48,30 @@ def run_mediapipe_solution(solution, inp_size):
4848

4949
# Save TensorFlow model as TFLite
5050
def save_tflite_model(model, inp, name):
51-
func = model.__call__.get_concrete_function()
52-
converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model)
51+
func = model.get_concrete_function()
52+
converter = tf.lite.TFLiteConverter.from_concrete_functions([func])
5353
tflite_model = converter.convert()
5454

55+
interpreter = tf.lite.Interpreter(model_content=tflite_model)
56+
5557
with open(f'{name}.tflite', 'wb') as f:
5658
f.write(tflite_model)
5759

5860
out = model(inp)
5961

6062
np.save(f'{name}_inp.npy', inp.transpose(0, 3, 1, 2))
61-
np.save(f'{name}_out_PartitionedCall:0.npy', np.array(out).transpose(0, 3, 1, 2))
63+
np.save(f'{name}_out_Identity.npy', np.array(out).transpose(0, 3, 1, 2))
6264

6365

64-
class ReplicateByPack(tf.Module):
65-
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 3, 3, 1], dtype=tf.float32)])
66-
def __call__(self, x):
67-
pack_1 = tf.stack([x, x], axis=3)
68-
reshape_1 = tf.reshape(pack_1, [1, 3, 6, 1])
69-
pack_2 = tf.stack([reshape_1, reshape_1], axis=2)
70-
reshape_2 = tf.reshape(pack_2, [1, 6, 6, 1])
71-
scaled = tf.image.resize(reshape_2, size=(3, 3), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
72-
return scaled + x
66+
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 3, 3, 1], dtype=tf.float32)])
67+
def replicate_by_pack(x):
68+
pack_1 = tf.stack([x, x], axis=3)
69+
reshape_1 = tf.reshape(pack_1, [1, 3, 6, 1])
70+
pack_2 = tf.stack([reshape_1, reshape_1], axis=2)
71+
reshape_2 = tf.reshape(pack_2, [1, 6, 6, 1])
72+
scaled = tf.image.resize(reshape_2, size=(3, 3), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
73+
return scaled + x
7374

74-
model = ReplicateByPack()
7575
inp = np.random.standard_normal((1, 3, 3, 1)).astype(np.float32)
76-
77-
save_tflite_model(model, inp, 'replicate_by_pack')
76+
save_tflite_model(replicate_by_pack, inp, 'replicate_by_pack')
7877

-124 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)