Skip to content

Commit 31563bb

Browse files
committed
fix object
1 parent 2238202 commit 31563bb

File tree

3 files changed

+166
-193
lines changed

3 files changed

+166
-193
lines changed

deepstack/core.py

Lines changed: 102 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,39 @@
88
## Const
99
DEFAULT_API_KEY = ""
1010
DEFAULT_TIMEOUT = 10 # seconds
11+
DEFAULT_IP = "localhost"
12+
DEFAULT_PORT = 80
1113

1214
## HTTP codes
1315
HTTP_OK = 200
1416
BAD_URL = 404
1517

1618
## API urls
17-
URL_CUSTOM = "http://{ip}:{port}/v1/vision/custom/{custom_model}"
18-
URL_OBJECT_DETECTION = "http://{ip}:{port}/v1/vision/detection"
19-
URL_FACE_DETECTION = "http://{ip}:{port}/v1/vision/face"
20-
URL_FACE_REGISTER = "http://{ip}:{port}/v1/vision/face/register"
21-
URL_FACE_RECOGNIZE = "http://{ip}:{port}/v1/vision/face/recognize"
22-
URL_SCENE_DETECTION = "http://{ip}:{port}/v1/vision/scene"
19+
URL_BASE_VISION = "http://{ip}:{port}/v1/vision"
20+
URL_CUSTOM = "/custom/{custom_model}"
21+
URL_OBJECT_DETECTION = "/detection"
22+
URL_FACE_DETECTION = "/face"
23+
URL_FACE_REGISTER = "/face/register"
24+
URL_FACE_RECOGNIZE = "/face/recognize"
25+
URL_SCENE_DETECTION = "/scene"
26+
27+
28+
class DeepstackException(Exception):
29+
pass
2330

2431

2532
def format_confidence(confidence: Union[str, float]) -> float:
26-
"""Takes a confidence from the API like
33+
"""
34+
Takes a confidence from the API like
2735
0.55623 and returns 55.6 (%).
2836
"""
2937
DECIMALS = 1
3038
return round(float(confidence) * 100, DECIMALS)
3139

3240

33-
def get_confidences_above_threshold(confidences: List[float], confidence_threshold: float) -> List[float]:
41+
def get_confidences_above_threshold(
42+
confidences: List[float], confidence_threshold: float
43+
) -> List[float]:
3444
"""Takes a list of confidences and returns those above a confidence_threshold."""
3545
return [val for val in confidences if val >= confidence_threshold]
3646

@@ -58,11 +68,15 @@ def get_objects(predictions: List[Dict]) -> List[str]:
5868
return sorted(list(set(labels)))
5969

6070

61-
def get_object_confidences(predictions: List[Dict], target_object: str):
71+
def get_object_confidences(predictions: List[Dict], target_object: str) -> List[float]:
6272
"""
6373
Return the list of confidences of instances of target label.
6474
"""
65-
confidences = [pred["confidence"] for pred in predictions if pred["label"] == target_object]
75+
confidences = [
76+
float(pred["confidence"])
77+
for pred in predictions
78+
if pred["label"] == target_object
79+
]
6680
return confidences
6781

6882

@@ -71,56 +85,66 @@ def get_objects_summary(predictions: List[Dict]):
7185
Get a summary of the objects detected.
7286
"""
7387
objects = get_objects(predictions)
74-
return {target_object: len(get_object_confidences(predictions, target_object)) for target_object in objects}
88+
return {
89+
target_object: len(get_object_confidences(predictions, target_object))
90+
for target_object in objects
91+
}
7592

7693

77-
def post_image(url: str, image_bytes: bytes, api_key: str, timeout: int, data: dict = {}):
78-
"""Post an image to Deepstack."""
94+
def post_image(
95+
url: str, image_bytes: bytes, api_key: str, timeout: int, data: dict = {}
96+
):
97+
"""Post an image to Deepstack. Only handles excpetions."""
7998
try:
8099
data["api_key"] = api_key # Insert the api_key
81-
response = requests.post(url, files={"image": image_bytes}, data=data, timeout=timeout)
82-
if response.status_code == HTTP_OK:
83-
return response
84-
elif response.status_code == BAD_URL:
85-
raise DeepstackException(f"Bad url supplied, url {url} raised error {BAD_URL}")
86-
else:
87-
raise DeepstackException(f"Error from Deepstack request, status code: {response.status_code}")
100+
return requests.post(
101+
url, files={"image": image_bytes}, data=data, timeout=timeout
102+
)
88103
except requests.exceptions.Timeout:
89-
raise DeepstackException(f"Timeout connecting to Deepstack, current timeout is {timeout} seconds")
104+
raise DeepstackException(
105+
f"Timeout connecting to Deepstack, current timeout is {timeout} seconds"
106+
)
90107
except requests.exceptions.ConnectionError as exc:
91108
raise DeepstackException(f"Connection error: {exc}")
92109

93110

94-
class DeepstackException(Exception):
95-
pass
111+
def process_image(url: str, image_bytes: bytes, api_key: str, timeout: int):
112+
"""Process image_bytes and detect. Handles common status codes"""
113+
response = post_image(url, image_bytes, api_key, timeout)
114+
if response.status_code == HTTP_OK:
115+
return response.json()
116+
elif response.status_code == BAD_URL:
117+
raise DeepstackException(f"Bad url supplied, url {url} raised error {BAD_URL}")
118+
else:
119+
raise DeepstackException(
120+
f"Error from Deepstack request, status code: {response.status_code}"
121+
)
96122

97123

98-
class Deepstack(object):
99-
"""Base class for deepstack."""
124+
class DeepstackVision:
125+
"""Base class for Deepstack vision."""
100126

101127
def __init__(
102128
self,
103-
api_key: str = "",
129+
ip: str = DEFAULT_IP,
130+
port: int = DEFAULT_PORT,
131+
api_key: str = DEFAULT_API_KEY,
104132
timeout: int = DEFAULT_TIMEOUT,
105-
url_detect: str = None,
106-
url_recognize: str = None,
107-
url_register: str = None,
133+
url_detect: str = "",
134+
url_recognize: str = "",
135+
url_register: str = "",
108136
):
109137

110-
self._url_detect = url_detect
111-
self._url_recognize = url_recognize
112-
self._url_register = url_register
138+
self._url_base = URL_BASE_VISION.format(ip=ip, port=port)
139+
self._url_detect = self._url_base + url_detect
140+
self._url_recognize = self._url_base + url_recognize
141+
self._url_register = self._url_base + url_register
113142
self._api_key = api_key
114143
self._timeout = timeout
115-
self._response = None
116144

117-
def detect(self, image_bytes: bytes):
145+
def detect(self):
118146
"""Process image_bytes and detect."""
119-
self._response = None
120-
self._response = post_image(self._url_detect, image_bytes, self._api_key, self._timeout).json()
121-
if not self._response["success"]:
122-
error = self._response["error"]
123-
raise DeepstackException(f"Error from Deepstack: {error}")
147+
raise NotImplementedError
124148

125149
def recognize(self):
126150
"""Process image_bytes and recognize."""
@@ -130,59 +154,68 @@ def register(self):
130154
"""Perform a registration."""
131155
raise NotImplementedError
132156

133-
@property
134-
def predictions(self):
135-
"""Return the predictions."""
136-
raise NotImplementedError
137157

138-
139-
class DeepstackObject(Deepstack):
158+
class DeepstackObject(DeepstackVision):
140159
"""Work with objects"""
141160

142161
def __init__(
143162
self,
144-
ip: str,
145-
port: str,
163+
ip: str = DEFAULT_IP,
164+
port: int = DEFAULT_PORT,
146165
api_key: str = DEFAULT_API_KEY,
147166
timeout: int = DEFAULT_TIMEOUT,
148167
custom_model: str = None,
149168
):
150169
if not custom_model:
151170
super().__init__(
152-
api_key, timeout, url_detect=URL_OBJECT_DETECTION.format(ip=ip, port=port),
171+
ip, port, api_key, timeout, url_detect=URL_OBJECT_DETECTION,
153172
)
154173
else:
155174
super().__init__(
156-
api_key, timeout, url_detect=URL_CUSTOM.format(ip=ip, port=port, custom_model=custom_model),
175+
ip,
176+
port,
177+
api_key,
178+
timeout,
179+
url_detect=URL_CUSTOM.format(custom_model=custom_model),
157180
)
158181

159-
@property
160-
def predictions(self):
161-
"""Return the predictions."""
162-
return self._response["predictions"]
182+
def detect(self, image_bytes: bytes):
183+
"""Process image_bytes and detect."""
184+
response_json = process_image(
185+
self._url_detect, image_bytes, self._api_key, self._timeout
186+
)
187+
return response_json["predictions"]
163188

164189

165-
class DeepstackScene(Deepstack):
190+
class DeepstackScene(DeepstackVision):
166191
"""Work with scenes"""
167192

168193
def __init__(
169-
self, ip: str, port: str, api_key: str = DEFAULT_API_KEY, timeout: int = DEFAULT_TIMEOUT,
194+
self,
195+
ip: str,
196+
port: str,
197+
api_key: str = DEFAULT_API_KEY,
198+
timeout: int = DEFAULT_TIMEOUT,
170199
):
171200
super().__init__(
172-
api_key, timeout, url_detect=URL_SCENE_DETECTION.format(ip=self._ip, port=self._port),
201+
api_key, timeout, url_detect=URL_SCENE_DETECTION,
173202
)
174203

175-
@property
176-
def predictions(self):
177-
"""Return the predictions."""
178-
return self._response
204+
def detect(self, image_bytes: bytes):
205+
"""Process image_bytes and detect."""
206+
response_json = process_image(self, image_bytes, self._api_key, self._timeout)
207+
return response_json
179208

180209

181-
class DeepstackFace(Deepstack):
210+
class DeepstackFace(DeepstackVision):
182211
"""Work with objects"""
183212

184213
def __init__(
185-
self, ip: str, port: str, api_key: str = DEFAULT_API_KEY, timeout: int = DEFAULT_TIMEOUT,
214+
self,
215+
ip: str,
216+
port: str,
217+
api_key: str = DEFAULT_API_KEY,
218+
timeout: int = DEFAULT_TIMEOUT,
186219
):
187220
super().__init__(
188221
api_key,
@@ -192,10 +225,10 @@ def __init__(
192225
url_recognize=URL_FACE_RECOGNIZE.format(ip=self._ip, port=self._port),
193226
)
194227

195-
@property
196-
def predictions(self):
197-
"""Return the classifier attributes."""
198-
return self._response["predictions"]
228+
def detect(self, image_bytes: bytes):
229+
"""Process image_bytes and detect."""
230+
response_json = process_image(self, image_bytes, self._api_key, self._timeout)
231+
return response_json
199232

200233
def register(self, name: str, image_bytes: bytes):
201234
"""
@@ -220,7 +253,9 @@ def register(self, name: str, image_bytes: bytes):
220253
def recognize(self, image_bytes: bytes):
221254
"""Process image_bytes, performing recognition."""
222255

223-
response = post_image(self._url_recognize, image_bytes, self._api_key, self._timeout)
256+
response = post_image(
257+
self._url_recognize, image_bytes, self._api_key, self._timeout
258+
)
224259

225260
self._response = response.json()
226261
if not self._response["success"]:

0 commit comments

Comments
 (0)