Skip to content

Commit 79f8f37

Browse files
authored
Merge pull request #5 from robmarkcole/development
Development
2 parents c698108 + 1ea0023 commit 79f8f37

File tree

10 files changed

+298
-461
lines changed

10 files changed

+298
-461
lines changed

.devcontainer/Dockerfile

Lines changed: 0 additions & 37 deletions
This file was deleted.

.devcontainer/devcontainer.json

Lines changed: 0 additions & 21 deletions
This file was deleted.

.devcontainer/noop.txt

Lines changed: 0 additions & 3 deletions
This file was deleted.

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,6 @@ dmypy.json
115115
# Pyre type checker
116116
.pyre/
117117
.DS_Store
118+
119+
venv
120+
.vscode

deepstack/core.py

Lines changed: 77 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,117 @@
11
"""
22
Deepstack core.
33
"""
4-
import imghdr
54
import requests
65
from PIL import Image
6+
from typing import Union, List, Set, Dict
77

88
## Const
9-
CLASSIFIER = "deepstack"
109
HTTP_OK = 200
1110
HTTP_BAD_REQUEST = 400
1211
HTTP_UNAUTHORIZED = 401
13-
TIMEOUT = 30 # seconds
12+
DEFAULT_TIMEOUT = 10 # seconds
1413

1514

16-
def get_matched_faces(predictions: dict):
15+
def format_confidence(confidence: Union[str, float]) -> float:
16+
"""Takes a confidence from the API like
17+
0.55623 and returne 55.6 (%).
1718
"""
18-
Get the predicted faces and their confidence.
19+
return round(float(confidence) * 100, 1)
20+
21+
22+
def get_confidences_above_threshold(
23+
confidences: List[float], confidence_threshold: float
24+
) -> List[float]:
25+
"""Takes a list of confidences and returns those above a confidence_threshold."""
26+
return [val for val in confidences if val >= confidence_threshold]
27+
28+
29+
def get_object_labels(predictions: List[Dict]) -> Set[str]:
1930
"""
20-
try:
21-
matched_faces = {
22-
face["userid"]: round(face["confidence"] * 100, 1)
23-
for face in predictions
24-
if not face["userid"] == "unknown"
25-
}
26-
return matched_faces
27-
except:
28-
return {}
31+
Get a list of the unique object labels predicted.
32+
"""
33+
labels = [pred["label"] for pred in predictions]
34+
return set(labels)
2935

3036

31-
def is_valid_image(file_path: str):
37+
def get_label_confidences(predictions: List[Dict], target_label: str):
3238
"""
33-
Check file_path is valid image, using PIL then imghdr.
39+
Return the list of confidences of instances of target label.
3440
"""
35-
try:
36-
with Image.open(file_path):
37-
pass
41+
confidences = [
42+
pred["confidence"] for pred in predictions if pred["label"] == target_label
43+
]
44+
return confidences
3845

39-
image_extension = imghdr.what(file_path)
40-
if image_extension in ["jpeg", ".jpg", ".png"]:
41-
return True
42-
return False
43-
except Exception as exc:
44-
print(exc)
45-
return False
4646

47+
def get_objects_summary(predictions: List[Dict]):
48+
"""
49+
Get a summary of the objects detected.
50+
"""
51+
labels = get_object_labels(predictions)
52+
return {
53+
label: len(get_label_confidences(predictions, target_label=label))
54+
for label in labels
55+
}
4756

48-
def post_image(url: str, image: bytes):
49-
"""Post an image to the classifier."""
57+
58+
def post_image(url: str, image: bytes, api_key: str, timeout: int):
59+
"""Post an image to Deepstack."""
5060
try:
51-
response = requests.post(url, files={"image": image}, timeout=TIMEOUT)
61+
response = requests.post(
62+
url, files={"image": image}, data={"api_key": api_key}, timeout=timeout
63+
)
5264
return response
53-
except requests.exceptions.ConnectionError:
54-
print("ConnectionError: Is %s running?", CLASSIFIER)
55-
return None
5665
except requests.exceptions.Timeout:
57-
print("Timeout error from %s", CLASSIFIER)
58-
return None
66+
raise DeepstackException(
67+
f"Timeout connecting to Deepstack, current timeout is {timeout} seconds"
68+
)
69+
5970

71+
class DeepstackException(Exception):
72+
pass
6073

61-
class DeepstackFace:
62-
"""Work with faces."""
6374

64-
def __init__(self, ip_address: str, port: str):
75+
class DeepstackObject:
76+
"""The object detection API locates and classifies 80
77+
different kinds of objects in a single image.."""
6578

66-
self._url_check = "http://{}:{}/v1/vision/face/recognize".format(
79+
def __init__(
80+
self,
81+
ip_address: str,
82+
port: str,
83+
api_key: str = "",
84+
timeout: int = DEFAULT_TIMEOUT,
85+
):
86+
87+
self._url_object_detection = "http://{}:{}/v1/vision/detection".format(
6788
ip_address, port
6889
)
69-
70-
self._faces = None
71-
self._matched = {}
90+
self._api_key = api_key
91+
self._timeout = timeout
92+
self._predictions = []
7293

7394
def process_file(self, file_path: str):
7495
"""Process an image file."""
75-
if is_valid_image(file_path):
76-
with open(file_path, "rb") as image_bytes:
77-
self.process_image_bytes(image_bytes)
96+
with open(file_path, "rb") as image_bytes:
97+
self.process_image_bytes(image_bytes)
7898

7999
def process_image_bytes(self, image_bytes: bytes):
80100
"""Process an image."""
81-
response = post_image(self._url_check, image_bytes)
82-
if response:
83-
if response.status_code == HTTP_OK:
84-
predictions_json = response.json()["predictions"]
85-
self._faces = len(predictions_json)
86-
self._matched = get_matched_faces(predictions_json)
101+
self._predictions = []
102+
103+
response = post_image(
104+
self._url_object_detection, image_bytes, self._api_key, self._timeout
105+
)
87106

88-
else:
89-
self._faces = None
90-
self._matched = {}
107+
if response.status_code == HTTP_OK:
108+
if response.json()["success"]:
109+
self._predictions = response.json()["predictions"]
110+
else:
111+
error = response.json()["error"]
112+
raise DeepstackException(f"Error from Deepstack: {error}")
91113

92114
@property
93-
def attributes(self):
115+
def predictions(self):
94116
"""Return the classifier attributes."""
95-
return {
96-
"faces": self._faces,
97-
"matched_faces": self._matched,
98-
"total_matched_faces": len(self._matched),
99-
}
117+
return self._predictions

0 commit comments

Comments
 (0)