#!/usr/bin/env python

import os
import re
import unittest

import rospkg
import torch
from PIL import Image

from image_recognition_footwear.model import Model
from image_recognition_footwear.process_data import heroPreprocess, detection_RGB


@unittest.skip("requires ~/data/pytorch_models/footwearModel.pth to be available locally")
def test_footwear():
    # Path to the locally stored footwear classification model.
    local_path = os.path.expanduser("~/data/pytorch_models/footwearModel.pth")

    if not os.path.exists(local_path):
        print("File does not exist: {}".format(local_path))

    def is_there_footwear_from_asset_name(asset_name):
        # Asset filenames encode the ground truth as "yes_shoe*" or "no_shoe*".
        binary_str = re.search(r"(\w+)_shoe", asset_name).group(1)
        return binary_str == "yes"

    # Pair each test image with the ground-truth label derived from its filename.
    assets_path = os.path.join(rospkg.RosPack().get_path("image_recognition_footwear"), "test/assets")
    images_gt = [(Image.open(os.path.join(assets_path, asset)), is_there_footwear_from_asset_name(asset))
                 for asset in os.listdir(assets_path)]

    # Load the trained model; fall back to CPU when no GPU is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model(in_channel=3, channel_1=128, channel_2=256, channel_3=512, node_1=1024, node_2=1024, num_classes=2)
    model.load_state_dict(torch.load(local_path, map_location=device))
    model.to(device=device)

    # Run the detector on all test images and compare against the ground truth.
    detections = detection_RGB([image for image, _ in images_gt], model)

    for (_, is_footwear_gt), binary_detection in zip(images_gt, detections):
        binary_detection = int(binary_detection)
        assert is_footwear_gt == binary_detection, f"{binary_detection=}, {is_footwear_gt=}"


if __name__ == "__main__":
    test_footwear()