Skip to content

Commit 7e061f6

Browse files
authored
Add files via upload
1 parent dc5c413 commit 7e061f6

28 files changed

+203
-0
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Sun Feb 14 15:26:57 2021
4+
5+
@author: trabz
6+
"""
7+
import albumentations as A
8+
9+
import torch
10+
11+
12+
from xmlParser import Parser
13+
import glob
14+
import numpy as np
15+
from matplotlib import pyplot as plt
16+
import cv2
17+
import torchvision.transforms as transforms
18+
19+
20+
class dataloader():
21+
22+
def __init__(self,path,transform=None):
23+
24+
self.paths=glob.glob(path)
25+
self.transform=transform
26+
27+
28+
def __len__(self):
29+
return len(self.paths)
30+
31+
def __getitem__(self,idx):
32+
33+
annotation=(Parser.myType(self.paths[idx],idx,classes=['bird','zebra']))
34+
image=plt.imread(self.paths[annotation['image_id']].replace('xml','jpg'))
35+
36+
if self.transform is not None:
37+
38+
augmented = self.transform(image=image, bboxes=annotation['bbox'], labels=annotation['label'])
39+
40+
41+
return image,augmented,annotation
42+
43+
44+
def collate_fn(batch):
45+
return tuple(zip(*batch))
46+
47+
48+
bbox_params = A.BboxParams(
49+
format='pascal_voc',
50+
min_area=1,
51+
min_visibility=0.5,
52+
label_fields=['labels']
53+
)
54+
55+
aug = A.Compose({
56+
#A.Resize(500, 500,p=0.2),
57+
A.RGBShift(r_shift_limit=40,g_shift_limit=40,b_shift_limit=40,p=0.04),
58+
A.RandomBrightness(p=0.01),
59+
A.RandomContrast(p=0.01),
60+
A.CLAHE(p=0.02),
61+
A.ToGray(p=0.4),
62+
A.Blur(blur_limit=8,p=0.1),
63+
A.RandomBrightness(p=0.1),
64+
A.CenterCrop(100, 100,p=0.01),
65+
A.RandomCrop(222, 222,p=0.1),
66+
A.HorizontalFlip(p=0.1),
67+
A.Rotate(limit=(-90, 90),p=0.2),
68+
A.VerticalFlip(p=0.1),
69+
A.ShiftScaleRotate(),
70+
71+
},bbox_params=bbox_params)
72+
73+
74+
path=path= 'Image A*/train/*.xml'
75+
76+
dataset = dataloader(path,aug)
77+
data_loader = torch.utils.data.DataLoader(
78+
dataset, batch_size=1, collate_fn=collate_fn,shuffle=False)
79+
80+
81+
82+
83+
plt.figure(figsize=(12,12))
84+
for idx,(image,imgO,result) in enumerate(data_loader):
85+
86+
87+
imgA=imgO[0]['image']
88+
image=image[0]
89+
for idc,bbox in enumerate(imgO[0]['bboxes']):
90+
xmin,ymin,xmax,ymax=bbox
91+
xmin,xmax=np.clip([xmin,xmax],5,imgA.shape[0]-5).astype('int')
92+
ymin,ymax=np.clip([ymin,ymax],5,imgA.shape[1]-5).astype('int')
93+
94+
xpastmin,ypastmin,xpastmax,ypastmax=np.clip((result[0]['bbox'][idc]),0,max(image.shape)-10)
95+
96+
imgA=cv2.rectangle(np.array(imgA),((xmin),(ymin)),((xmax),(ymax))
97+
,color=[0,245,0],thickness=4)
98+
99+
image=cv2.rectangle(image,((xpastmin),(ypastmin)),((xpastmax),(ypastmax))
100+
,color=[112,9,11],thickness=4)
101+
102+
103+
104+
105+
sz,wd,_=np.array(image.shape)-np.array((imgA).shape)
106+
img2=imgA
107+
imgA=np.pad((imgA),((sz//2,sz//2),(wd//2,wd//2),(0,0)))
108+
109+
110+
plt.imsave(f'Albumentations/images/{idx}.png',np.hstack(((imgA).astype('uint8'),image)))
111+
plt.axis('off')
112+
plt.tight_layout()
113+
#plt.imshow(np.dstack((img[:,:,2],img[:,:,1],img[:,:,0])))
114+
115+
116+
117+
118+

Albumentations/images/0.png

265 KB
Loading

Albumentations/images/1.png

319 KB
Loading

Albumentations/images/10.png

375 KB
Loading

Albumentations/images/11.png

451 KB
Loading

Albumentations/images/12.png

161 KB
Loading

Albumentations/images/13.png

399 KB
Loading

Albumentations/images/14.png

441 KB
Loading

Albumentations/images/15.png

383 KB
Loading

Albumentations/images/16.png

489 KB
Loading

Albumentations/images/17.png

473 KB
Loading

Albumentations/images/18.png

494 KB
Loading

Albumentations/images/19.png

473 KB
Loading

Albumentations/images/2.png

399 KB
Loading

Albumentations/images/20.png

456 KB
Loading

Albumentations/images/21.png

238 KB
Loading

Albumentations/images/22.png

190 KB
Loading

Albumentations/images/23.png

479 KB
Loading

Albumentations/images/24.png

641 KB
Loading

Albumentations/images/25.png

540 KB
Loading

Albumentations/images/3.png

681 KB
Loading

Albumentations/images/4.png

113 KB
Loading

Albumentations/images/5.png

512 KB
Loading

Albumentations/images/6.png

261 KB
Loading

Albumentations/images/7.png

537 KB
Loading

Albumentations/images/8.png

274 KB
Loading

Albumentations/images/9.png

482 KB
Loading

Albumentations/xmlParser.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Thu Feb 11 21:09:43 2021
4+
5+
@author: trabz
6+
"""
7+
import random
8+
import string
9+
10+
from timeit import default_timer as timer
11+
import glob
12+
13+
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
14+
15+
16+
class Parser():
17+
18+
def parse_etree_lxml(file):
19+
from lxml import etree as etree_lxml
20+
#parse xml
21+
xml_as_bytes = Parser.sample_xml('rb',file=file)
22+
23+
timer_start = timer()
24+
25+
print('[etree lxml] Starting to parse XML')
26+
27+
tree = etree_lxml.fromstring(xml_as_bytes)
28+
## Find <object> <object\> in the xml
29+
xml_etree_lxml = tree.findall('object')
30+
31+
seconds = timer() - timer_start
32+
33+
print(f'[etree lxml] Finished parsing XML in {seconds} seconds')
34+
return xml_etree_lxml
35+
def sample_xml(opts,file):
36+
"""Return the sample XML file as a string."""
37+
with open(file, opts) as xml:
38+
return xml.read()
39+
40+
41+
def recursive_dict(element):
42+
return element.tag, \
43+
dict(map(Parser.recursive_dict, element)) or element.text
44+
#return all elements in xml with recursive way
45+
46+
def finalRun(path):
47+
listw=glob.glob(path)
48+
#taking all path
49+
alls=[]
50+
finals=[]
51+
for xmls in listw:
52+
a0=Parser.parse_etree_lxml(xmls) # parse xmls
53+
for obj in a0:
54+
alls.append(Parser.recursive_dict(obj)[1]) #append all the dicts
55+
56+
finals.append(alls)
57+
alls=[]
58+
def getValue(element,name):
59+
result=element.find(name).text
60+
return element.tag,name, result
61+
def myType(path,idx,classes=['bird','zebra']):
62+
objects=Parser.parse_etree_lxml(path)
63+
boxes=[]
64+
labels=[]
65+
bbs=[]
66+
for obj in objects:
67+
xmin=int(Parser.getValue(obj.find('bndbox'),'xmin')[2])
68+
ymin=int(Parser.getValue(obj.find('bndbox'),'ymin')[2])
69+
xmax=int(Parser.getValue(obj.find('bndbox'),'xmax')[2])
70+
ymax=int(Parser.getValue(obj.find('bndbox'),'ymax')[2])
71+
boxes.append([xmin,ymin,xmax,ymax])
72+
bbs.append(BoundingBox(x1=xmin,y1=ymin,x2=xmax,y2=ymax))
73+
label=Parser.getValue(obj,'name')
74+
label= classes.index(label[2])
75+
labels.append(label)
76+
77+
return {'image_id':idx,'label':labels,'bbox':boxes,'bbs':bbs}
78+
79+
80+
81+
82+
83+
84+
85+

0 commit comments

Comments
 (0)