Skip to content

Commit c172873

Browse files
GustavoDCCMatthijsBurgh
authored andcommitted
Add params to load model and use for detection
1 parent 4325f07 commit c172873

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import print_function
4+
import argparse
5+
from image_recognition_footwear.model import Model
6+
from image_recognition_footwear.process_data import heroPreprocess, detection_RGB
7+
from PIL import Image
8+
import os
9+
import torch
10+
11+
# Assign description to the help doc
12+
parser = argparse.ArgumentParser(description='Get footwear detected using PyTorch')
13+
14+
# Add arguments
15+
parser.add_argument('image', type=str, help='Image')
16+
parser.add_argument('--weights-path', type=str, help='Path to the weights of the VGG model',
17+
default=os.path.expanduser('~/data/pytorch_models/footwearModel.pth'))
18+
19+
parser.add_argument('--input-channel', type=int, help='Size of the input model channel', default=3)
20+
parser.add_argument('--channel1-size', type=int, help='Size channel 1', default=128)
21+
parser.add_argument('--channel2-size', type=int, help='Size channel 2', default=256)
22+
parser.add_argument('--channel3-size', type=int, help='Size channel 3', default=512)
23+
parser.add_argument('--nodes-fclayer1-size', type=int, help='Size fully connected layer 1 neurons', default=1024)
24+
parser.add_argument('--nodes-fclayer2-size', type=int, help='Size fully connected layer 2 neurons', default=1024)
25+
parser.add_argument('--class-size', type=int, help='Classes of the network', default=2)
26+
27+
device = torch.device('cuda')
28+
dtype = torch.float32
29+
30+
args = parser.parse_args()
31+
32+
# Read the image and preprocess
33+
img = Image.open(args.image)
34+
preprocessed_img = heroPreprocess(img)
35+
36+
# Load the model
37+
model = Model(in_channel=args.input_channel, channel_1=args.channel1_size, channel_2=args.channel2_size, channel_3=args.channel3_size, node_1=args.nodes_fclayer1_size, node_2=args.nodes_fclayer2_size, num_classes=args.class_size)
38+
model.load_state_dict(torch.load(args.weights_path))
39+
model.to(device=device)
40+
41+
# Detection
42+
detector = detection_RGB(preprocessed_img, model)
43+
44+
print(detector)
45+

0 commit comments

Comments
 (0)