-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgt_position_extractor.py
91 lines (79 loc) · 3.43 KB
/
gt_position_extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from typing import List
import pandas as pd
import os
from argparse import ArgumentParser
import xml.etree.ElementTree as ElementTree
class GTObject:
    """A single ground-truth bounding box parsed from one comma-separated line.

    Columns are read by position: frame, x, y, width, height, class_name.
    The first five are converted with ``int``; the sixth is kept as the raw
    string. Any columns after the sixth are ignored.
    """

    def __init__(self, line):
        # Split once; every attribute below is taken by position from this list.
        columns = line.split(',')
        numeric = [int(columns[index]) for index in range(5)]
        self.frame, self.x, self.y, self.width, self.height = numeric
        self.class_name = columns[5]
class GtPositionExtractor:
    """Convert per-frame ILSVRC2015 VID XML annotations of one video into a CSV.

    Each CSV row is ``[frame_name, x_min, y_min, width, height, class_id, 0]``,
    where ``frame_name`` is the zero-padded XML file stem (e.g. ``"000000"``).
    The trailing ``0`` is written as a constant; its meaning is not
    established in this class.
    """

    def __init__(self, file_name, train_folder_path=None, file_prefix="./sample/ILSVRC2015/Annotations/VID/"):
        """Resolve the annotation folder and the output CSV path for one video.

        Args:
            file_name: Video folder name. If it contains ``'val'``, annotations
                are read from ``<file_prefix>val/<file_name>``. Otherwise, if it
                contains ``'train'``, they are read from
                ``<file_prefix>train/<train_folder_path>/<file_name>``. If it
                contains neither, ``file_prefix`` itself is used.
            train_folder_path: Sub-folder under ``train/``; required for
                ``'train'`` videos.
            file_prefix: Root of the VID annotation tree; expected to end in ``/``.

        Raises:
            ValueError: ``file_name`` names a train video but
                ``train_folder_path`` is missing.

        Side effect: creates ``./output/.../ground_truth`` (relative to the
        current working directory) if it does not already exist.
        """
        self.file_name = file_name
        self.train_folder_path = train_folder_path
        self.folder_path: str = file_prefix
        csv_folder_path = './output/'
        if 'val' in file_name:
            self.folder_path += 'val/' + file_name
            csv_folder_path += file_name
        elif 'train' in file_name:
            if train_folder_path is None:
                # Fail clearly instead of with a str + None TypeError below.
                raise ValueError(
                    f"train_folder_path is required for train video {file_name!r}")
            self.folder_path += 'train/' + train_folder_path + '/' + file_name
            csv_folder_path += train_folder_path + '/' + file_name
        csv_folder_path += '/ground_truth'
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(csv_folder_path, exist_ok=True)
        self.csv_file_path = csv_folder_path + '/ground_truth.csv'

    def run(self):
        """Parse ``000000.xml``, ``000001.xml``, ... in order and write the CSV.

        Parsing stops at the first missing frame file, which marks the end of
        the sequence. A missing *first* frame is reported, since it usually
        means a wrong folder path. Any other error while parsing a frame is
        printed and also stops parsing (best effort). Rows collected so far are
        still written. The CSV at ``self.csv_file_path`` is overwritten with no
        header and no index column.
        """
        self.result: List[List] = []
        frame_number = 0
        while True:
            frame_name = str(frame_number).zfill(6)
            try:
                self.xml_parser(frame_name)
            except FileNotFoundError as ex:
                if frame_number == 0:
                    print(ex)
                break
            except Exception as ex:  # best effort: keep the rows parsed so far
                print(ex)
                break
            frame_number += 1
        csv = pd.DataFrame(self.result)
        csv.to_csv(self.csv_file_path, header=False, index=False, mode='w')

    def xml_parser(self, file_name):
        """Append one row per ``<object>`` in ``<folder_path>/<file_name>.xml``.

        Rows go to ``self.result``, which ``run`` initialises. ``xmin`` and
        ``ymin`` are stored as the raw XML text. Width and height are
        ``xmax - xmin`` and ``ymax - ymin`` as ints.

        Raises:
            FileNotFoundError: the XML file does not exist.
            xml.etree.ElementTree.ParseError: the file is not well-formed XML.
            AttributeError: an ``<object>`` lacks ``name``/``bndbox`` or a
                coordinate element.
            ValueError: a coordinate is not an integer.
        """
        tree = ElementTree.parse(f'{self.folder_path}/{file_name}.xml')
        for obj in tree.findall('object'):
            class_id = obj.find('name').text
            box = obj.find('bndbox')
            x_min = box.find('xmin').text
            y_min = box.find('ymin').text
            width = int(box.find('xmax').text) - int(x_min)
            height = int(box.find('ymax').text) - int(y_min)
            self.result.append([file_name, x_min, y_min, width, height, class_id, 0])

    def get_gtobjects_from_csv(self):
        """Load the CSV written by ``run`` and group its rows by frame.

        Returns:
            dict mapping ``str(frame)`` (the integer frame number without zero
            padding, e.g. ``'0'``) to the list of ``GTObject`` rows for that
            frame, in file order. Blank lines are skipped. If the CSV does not
            exist, an error is printed and an empty dict is returned.
        """
        gt_objects = {}
        try:
            # pandas.to_csv writes UTF-8 by default; read it back the same way.
            with open(self.csv_file_path, 'r', encoding='utf-8') as csv_file:
                for line in csv_file:
                    line = line.rstrip('\n')
                    if not line.strip():
                        continue  # e.g. the lone newline of an empty export
                    gt_object = GTObject(line)
                    gt_objects.setdefault(str(gt_object.frame), []).append(gt_object)
        except FileNotFoundError:
            print(f''' [Error] Failed to read csv file: no such file named {self.csv_file_path}''')
        return gt_objects
if __name__ == "__main__":
    # CLI entry point: extract ground truth for one VID video into
    # ./output/.../ground_truth/ground_truth.csv.
    parser = ArgumentParser()
    # required=True: without it a missing flag crashes later with `'val' in None`.
    parser.add_argument('--video_file_name', type=str, required=True, help='file name without extension')
    parser.add_argument('--train_folder_path', type=str, help="only required for 'train'")
    args = parser.parse_args()
    extractor = GtPositionExtractor(args.video_file_name, train_folder_path=args.train_folder_path)
    extractor.run()