1
+ """
2
+ @Author: Matheus Teixeira de Sousa (mtsousa14@gmail.com)
3
+
4
+ Detect forks from images with trained YOLOv7 ONNX model
5
+ """
6
+
7
+ import cv2 as cv
8
+ import numpy as np
9
+ import onnxruntime as ort
10
+ from torch .utils .data import DataLoader
11
+ from utils .dataset import TestDataset
12
+ from utils .utils import plot_one_box , show_predicted_image , adjust_image
13
+ from os .path import exists , isdir
14
+ from os import makedirs
15
+ from random import randint
16
+ import argparse
17
+
18
+ def predict_bbox (session , images ):
19
+ """
20
+ Predict bounding boxes from images
21
+ """
22
+ outname = [i .name for i in session .get_outputs ()]
23
+
24
+ dict_output = {}
25
+ for i , samples in enumerate (images ):
26
+ im , ratio , dwdh , name = samples ['image' ], samples ['ratio' ], samples ['dwdh' ], samples ['name' ]
27
+ im = np .ascontiguousarray (im / 255 )
28
+ out = session .run (outname , {'images' :im })[0 ]
29
+ dict_output [f"batch { i } " ] = {"preds" : out , "ratio" : ratio , "dwdh" : dwdh , "name" : name }
30
+
31
+ return dict_output
32
+
33
+ if __name__ == '__main__' :
34
+ # Parse command line arguments
35
+ parser = argparse .ArgumentParser (
36
+ description = 'Predict with YOLOv7-fork ONNX model' )
37
+
38
+ parser .add_argument ('--model' , required = True ,
39
+ metavar = '/path/to/model.onnx' ,
40
+ help = "Path to ONNX model" )
41
+ parser .add_argument ('--input' , required = True ,
42
+ help = "Path to images (path/to/images) or path to image (path/to/image.jpg)" )
43
+ parser .add_argument ('--batch' , default = 1 ,
44
+ help = "Batch size" )
45
+ parser .add_argument ('--save' , default = False , action = 'store_true' ,
46
+ help = "Save predicted image" )
47
+ parser .add_argument ('--dontshow' , default = False , action = 'store_true' ,
48
+ help = "Don't show predicted image" )
49
+ parser .add_argument ('--cuda' , default = False , action = 'store_true' ,
50
+ help = "Set execution on GPU" )
51
+
52
+ args = parser .parse_args ()
53
+ for key , value in args ._get_kwargs ():
54
+ if value is not None :
55
+ print (f'{ key .capitalize ()} : { value } ' )
56
+ print ()
57
+
58
+ # Check if the input is a dir
59
+ input_isdir = isdir (args .input )
60
+
61
+ # Load the model
62
+ print ('Loading model...' , flush = True )
63
+ providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ] if args .cuda else ['CPUExecutionProvider' ]
64
+ session = ort .InferenceSession (args .model , providers = providers )
65
+
66
+ # Get output name and input shape
67
+ outname = [i .name for i in session .get_outputs ()]
68
+ input_shape = session .get_inputs ()[0 ].shape
69
+ h , w = input_shape [2 ], input_shape [3 ]
70
+
71
+ # Load the images
72
+ print ('Loading images...' , flush = True )
73
+ if input_isdir :
74
+ dataset = TestDataset (args .input , shape = (h , w ))
75
+ images = DataLoader (dataset , batch_size = args .batch , shuffle = False , num_workers = 0 )
76
+ else :
77
+ images = [adjust_image (args .input , shape = (h , w ))]
78
+
79
+ # Predict from images
80
+ print ('Making predictions...' , flush = True )
81
+ dict_output = predict_bbox (session , images )
82
+
83
+ names = ['fork' ]
84
+ colors = {name : [randint (0 , 255 ) for _ in range (3 )] for name in names }
85
+ # colors = {name: [104, 184, 82] for name in names} # green
86
+
87
+ if args .save and not exists (f'data/responses' ):
88
+ makedirs (f'data/responses' )
89
+
90
+ # For each image, plot the results
91
+ print ('Plotting results...' , flush = True )
92
+ for i , key in enumerate (dict_output .keys ()):
93
+ pred , ratio , dwdh , name = dict_output [key ]['preds' ], dict_output [key ]['ratio' ][0 ], dict_output [key ]['dwdh' ], dict_output [key ]['name' ][0 ]
94
+ ratio = float (ratio )
95
+ dwdh = float (dwdh [0 ]), float (dwdh [1 ])
96
+
97
+ # Load original image
98
+ if input_isdir :
99
+ image = dataset .__getsrc__ (i )
100
+ else :
101
+ image = cv .imread (args .input )
102
+
103
+ # Adjust bounding box to original image
104
+ for prediction in pred :
105
+ batch_id , x0 , y0 , x1 , y1 , cls_id , score = prediction
106
+ box = np .array ([x0 ,y0 ,x1 ,y1 ])
107
+ box -= np .array (dwdh * 2 )
108
+ box /= ratio
109
+ box = box .round ().astype (np .int32 ).tolist ()
110
+ cls_id = int (cls_id )
111
+ score = round (float (score ),3 )
112
+ label = names [cls_id ]
113
+ color = colors [label ]
114
+ label += ' ' + str (score )
115
+ plot_one_box (box , image , label = label , color = color , line_thickness = 1 )
116
+
117
+ if args .save :
118
+ path = 'data/responses/' + name
119
+ cv .imwrite (path , image )
120
+
121
+ if not args .dontshow :
122
+ show_predicted_image (image )
123
+
0 commit comments