@@ -48,31 +48,30 @@ def run_mediapipe_solution(solution, inp_size):
48
48
49
49
# Save TensorFlow model as TFLite
50
50
def save_tflite_model (model , inp , name ):
51
- func = model .__call__ . get_concrete_function ()
52
- converter = tf .lite .TFLiteConverter .from_concrete_functions ([func ], model )
51
+ func = model .get_concrete_function ()
52
+ converter = tf .lite .TFLiteConverter .from_concrete_functions ([func ])
53
53
tflite_model = converter .convert ()
54
54
55
+ interpreter = tf .lite .Interpreter (model_content = tflite_model )
56
+
55
57
with open (f'{ name } .tflite' , 'wb' ) as f :
56
58
f .write (tflite_model )
57
59
58
60
out = model (inp )
59
61
60
62
np .save (f'{ name } _inp.npy' , inp .transpose (0 , 3 , 1 , 2 ))
61
- np .save (f'{ name } _out_PartitionedCall:0 .npy' , np .array (out ).transpose (0 , 3 , 1 , 2 ))
63
+ np .save (f'{ name } _out_Identity .npy' , np .array (out ).transpose (0 , 3 , 1 , 2 ))
62
64
63
65
64
- class ReplicateByPack (tf .Module ):
65
- @tf .function (input_signature = [tf .TensorSpec (shape = [1 , 3 , 3 , 1 ], dtype = tf .float32 )])
66
- def __call__ (self , x ):
67
- pack_1 = tf .stack ([x , x ], axis = 3 )
68
- reshape_1 = tf .reshape (pack_1 , [1 , 3 , 6 , 1 ])
69
- pack_2 = tf .stack ([reshape_1 , reshape_1 ], axis = 2 )
70
- reshape_2 = tf .reshape (pack_2 , [1 , 6 , 6 , 1 ])
71
- scaled = tf .image .resize (reshape_2 , size = (3 , 3 ), method = tf .image .ResizeMethod .NEAREST_NEIGHBOR )
72
- return scaled + x
66
+ @tf .function (input_signature = [tf .TensorSpec (shape = [1 , 3 , 3 , 1 ], dtype = tf .float32 )])
67
+ def replicate_by_pack (x ):
68
+ pack_1 = tf .stack ([x , x ], axis = 3 )
69
+ reshape_1 = tf .reshape (pack_1 , [1 , 3 , 6 , 1 ])
70
+ pack_2 = tf .stack ([reshape_1 , reshape_1 ], axis = 2 )
71
+ reshape_2 = tf .reshape (pack_2 , [1 , 6 , 6 , 1 ])
72
+ scaled = tf .image .resize (reshape_2 , size = (3 , 3 ), method = tf .image .ResizeMethod .NEAREST_NEIGHBOR )
73
+ return scaled + x
73
74
74
- model = ReplicateByPack ()
75
75
inp = np .random .standard_normal ((1 , 3 , 3 , 1 )).astype (np .float32 )
76
-
77
- save_tflite_model (model , inp , 'replicate_by_pack' )
76
+ save_tflite_model (replicate_by_pack , inp , 'replicate_by_pack' )
78
77
0 commit comments