Skip to content

Commit 4325f07

Browse files
GustavoDCCMatthijsBurgh
authored andcommitted
Add model files
1 parent a99c97c commit 4325f07

File tree

4 files changed

+101
-0
lines changed

4 files changed

+101
-0
lines changed

image_recognition_footwear/scripts/get_footwear

Whitespace-only changes.
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
import torch.nn as nn
3+
class Model(nn.Module):
4+
def __init__(self, in_channels, channel_1, channel_2, channel_3, \
5+
node_1, node_2, num_classes):
6+
super().__init__()
7+
####### Convolutional layers ######
8+
self.conv1 = nn.Sequential(
9+
nn.Conv2d(in_channels, channel_1, kernel_size=3, padding=1, stride=1),
10+
nn.BatchNorm2d(channel_1),
11+
nn.LeakyReLU(),
12+
nn.Conv2d(channel_1, channel_1, kernel_size=3, padding=1, stride=1),
13+
nn.BatchNorm2d(channel_1),
14+
nn.LeakyReLU(),
15+
nn.MaxPool2d(kernel_size=2, stride=2),
16+
)
17+
self.conv2 = nn.Sequential(
18+
nn.Conv2d(channel_1, channel_2, kernel_size=3, padding=1, stride=1),
19+
nn.BatchNorm2d(channel_2),
20+
nn.LeakyReLU(),
21+
nn.Conv2d(channel_2, channel_2, kernel_size=3, padding=1, stride=1),
22+
nn.BatchNorm2d(channel_2),
23+
nn.LeakyReLU(),
24+
nn.MaxPool2d(kernel_size=2, stride=2),
25+
)
26+
self.conv3 = nn.Sequential(
27+
nn.Conv2d(channel_2, channel_3, kernel_size=3, padding=1, stride=1),
28+
nn.BatchNorm2d(channel_3),
29+
nn.LeakyReLU(),
30+
nn.Conv2d(channel_3, channel_3, kernel_size=3, padding=1, stride=1),
31+
nn.BatchNorm2d(channel_3),
32+
nn.LeakyReLU(),
33+
nn.MaxPool2d(kernel_size=7, stride=2),
34+
)
35+
36+
######## Affine layers ########
37+
self.fc = nn.Sequential(
38+
nn.Flatten(),
39+
nn.Linear(channel_3, node_1),
40+
nn.BatchNorm1d(node_1),
41+
nn.Dropout(p=0.5),
42+
43+
nn.Linear(node_1, node_2),
44+
nn.BatchNorm1d(node_2),
45+
46+
nn.Linear(node_2, num_classes)
47+
)
48+
49+
def forward(self, x):
50+
x = self.conv1(x)
51+
x = self.conv2(x)
52+
x = self.conv3(x)
53+
54+
scores = self.fc(x)
55+
return scores
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
from torchvision import transforms as T
3+
import torch
4+
from PIL import Image, ImageOps
5+
6+
def preprocess_RGB(img):
7+
"""preproces image:
8+
input is a PIL image.
9+
Output image should be pytorch tensor that is compatible with your model"""
10+
img = T.functional.resize(img, size=(32, 32), interpolation=Image.NEAREST)
11+
trans = T.Compose([T.ToTensor(),T.Grayscale(num_output_channels=3),T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
12+
img = trans(img)
13+
img = img.unsqueeze(0)
14+
15+
return img
16+
def heroPreprocess(img):
17+
"""preproces image:
18+
expected input is a PIL image from Hero.
19+
Output image should be pytorch tensor that is compatible with your model"""
20+
width, height = img.size # Hero image size (640x480)
21+
left = width/2 - 100
22+
top = height/2 + 140
23+
right = width/2
24+
bottom = height
25+
im1 = img.crop((left, top, right, bottom))
26+
img2 = T.functional.resize(im1, size=(32, 32), interpolation=Image.NEAREST)
27+
trans = T.Compose([T.ToTensor(),T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
28+
img_trans = trans(img2)
29+
img_trans = img_trans.unsqueeze(0)
30+
31+
return img_trans
32+
33+
def detection_RGB(img, model):
34+
"""Detection of foortwear:
35+
Input is a preprocessed image to provide to the model.
36+
Output should be binary classification [True, False], where True is the detection of the footwear."""
37+
model.eval()
38+
info = next(model.parameters()) # Retrieve the first parameter tensor from the iterator
39+
device = info.device
40+
dtype = info.dtype
41+
with torch.no_grad():
42+
img = img.to(device=device, dtype=dtype)
43+
scores = model(img)
44+
preds = torch.argmax(scores, axis=1)
45+
score_max_numpy = int(preds.cpu().detach().numpy())
46+
return score_max_numpy

image_recognition_footwear/test/test_footwear.py

Whitespace-only changes.

0 commit comments

Comments
 (0)