Skip to content

Commit 354619c

Browse files
committed
Refactor core
1 parent 4f20fcb commit 354619c

File tree

3 files changed

+147
-51
lines changed

3 files changed

+147
-51
lines changed

deepstack/core.py

Lines changed: 84 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919

2020

2121
def format_confidence(confidence: Union[str, float]) -> float:
22-
"""Takes a confidence from the API like
23-
0.55623 and returns 55.6 (%).
22+
"""Takes a confidence from the API like
23+
0.55623 and returns 55.6 (%).
2424
"""
2525
return round(float(confidence) * 100, 1)
2626

2727

28-
def get_confidences_above_threshold(confidences: List[float], confidence_threshold: float) -> List[float]:
28+
def get_confidences_above_threshold(
29+
confidences: List[float], confidence_threshold: float
30+
) -> List[float]:
2931
"""Takes a list of confidences and returns those above a confidence_threshold."""
3032
return [val for val in confidences if val >= confidence_threshold]
3133

@@ -57,7 +59,9 @@ def get_object_confidences(predictions: List[Dict], target_object: str):
5759
"""
5860
Return the list of confidences of instances of target label.
5961
"""
60-
confidences = [pred["confidence"] for pred in predictions if pred["label"] == target_object]
62+
confidences = [
63+
pred["confidence"] for pred in predictions if pred["label"] == target_object
64+
]
6165
return confidences
6266

6367

@@ -66,17 +70,26 @@ def get_objects_summary(predictions: List[Dict]):
6670
Get a summary of the objects detected.
6771
"""
6872
objects = get_objects(predictions)
69-
return {target_object: len(get_object_confidences(predictions, target_object)) for target_object in objects}
73+
return {
74+
target_object: len(get_object_confidences(predictions, target_object))
75+
for target_object in objects
76+
}
7077

7178

72-
def post_image(url: str, image_bytes: bytes, api_key: str, timeout: int, data: dict = {}):
79+
def post_image(
80+
url: str, image_bytes: bytes, api_key: str, timeout: int, data: dict = {}
81+
):
7382
"""Post an image to Deepstack."""
7483
try:
7584
data["api_key"] = api_key
76-
response = requests.post(url, files={"image": image_bytes}, data=data, timeout=timeout)
85+
response = requests.post(
86+
url, files={"image": image_bytes}, data=data, timeout=timeout
87+
)
7788
return response
7889
except requests.exceptions.Timeout:
79-
raise DeepstackException(f"Timeout connecting to Deepstack, current timeout is {timeout} seconds")
90+
raise DeepstackException(
91+
f"Timeout connecting to Deepstack, current timeout is {timeout} seconds"
92+
)
8093
except requests.exceptions.ConnectionError as exc:
8194
raise DeepstackException(f"Connection error: {exc}")
8295

@@ -89,25 +102,32 @@ class Deepstack(object):
89102
"""Base class for deepstack."""
90103

91104
def __init__(
92-
self, ip: str, port: str, api_key: str = "", timeout: int = DEFAULT_TIMEOUT, url: str = "",
105+
self,
106+
api_key: str = "",
107+
timeout: int = DEFAULT_TIMEOUT,
108+
url_detect: str = None,
109+
url_recognise: str = None,
110+
url_register: str = None,
93111
):
94112

95-
self._ip = ip
96-
self._port = port
97-
self._url = url
113+
self._url_detect = url_detect
114+
self._url_recognise = url_recognise
115+
self._url_register = url_register
98116
self._api_key = api_key
99117
self._timeout = timeout
100118
self._response = None
101119

102120
def detect(self, image_bytes: bytes):
103121
"""Process image_bytes, performing detection."""
104122
self._response = None
105-
url = self._url.format(ip=self._ip, port=self._port)
106-
107-
response = post_image(url, image_bytes, self._api_key, self._timeout)
123+
response = post_image(
124+
self._url_detect, image_bytes, self._api_key, self._timeout
125+
)
108126

109127
if not response.status_code == HTTP_OK:
110-
raise DeepstackException(f"Error from request, status code: {response.status_code}")
128+
raise DeepstackException(
129+
f"Error from request, status code: {response.status_code}"
130+
)
111131
return
112132

113133
self._response = response.json()
@@ -125,12 +145,27 @@ class DeepstackObject(Deepstack):
125145
"""Work with objects"""
126146

127147
def __init__(
128-
self, ip: str, port: str, api_key: str = "", timeout: int = DEFAULT_TIMEOUT, custom_model: str = None,
148+
self,
149+
ip: str,
150+
port: str,
151+
api_key: str = "",
152+
timeout: int = DEFAULT_TIMEOUT,
153+
custom_model: str = None,
129154
):
130155
if not custom_model:
131-
super().__init__(ip, port, api_key, timeout, url=URL_OBJECT_DETECTION)
156+
super().__init__(
157+
api_key,
158+
timeout,
159+
url_detect=URL_OBJECT_DETECTION.format(ip=ip, port=port),
160+
)
132161
else:
133-
super().__init__(ip, port, api_key, timeout, url=URL_CUSTOM.format(custom_model=custom_model))
162+
super().__init__(
163+
api_key,
164+
timeout,
165+
url_detect=URL_CUSTOM.format(
166+
ip=ip, port=port, custom_model=custom_model
167+
),
168+
)
134169

135170
@property
136171
def predictions(self):
@@ -142,9 +177,17 @@ class DeepstackScene(Deepstack):
142177
"""Work with scenes"""
143178

144179
def __init__(
145-
self, ip: str, port: str, api_key: str = "", timeout: int = DEFAULT_TIMEOUT,
180+
self,
181+
ip: str,
182+
port: str,
183+
api_key: str = "",
184+
timeout: int = DEFAULT_TIMEOUT,
146185
):
147-
super().__init__(ip, port, api_key, timeout, url=URL_SCENE_DETECTION)
186+
super().__init__(
187+
api_key,
188+
timeout,
189+
url_detect=URL_SCENE_DETECTION.format(ip=self._ip, port=self._port),
190+
)
148191

149192
@property
150193
def predictions(self):
@@ -156,9 +199,19 @@ class DeepstackFace(Deepstack):
156199
"""Work with objects"""
157200

158201
def __init__(
159-
self, ip: str, port: str, api_key: str = "", timeout: int = DEFAULT_TIMEOUT,
202+
self,
203+
ip: str,
204+
port: str,
205+
api_key: str = "",
206+
timeout: int = DEFAULT_TIMEOUT,
160207
):
161-
super().__init__(ip, port, api_key, timeout, url=URL_FACE_DETECTION)
208+
super().__init__(
209+
api_key,
210+
timeout,
211+
url_detect=URL_FACE_DETECTION.format(ip=self._ip, port=self._port),
212+
url_register=URL_FACE_REGISTRATION.format(ip=self._ip, port=self._port),
213+
url_recognise=URL_FACE_RECOGNITION.format(ip=self._ip, port=self._port),
214+
)
162215

163216
@property
164217
def predictions(self):
@@ -171,7 +224,7 @@ def register_face(self, name: str, image_bytes: bytes):
171224
"""
172225

173226
response = post_image(
174-
url=URL_FACE_REGISTRATION.format(ip=self._ip, port=self._port),
227+
url=self._url_register,
175228
image_bytes=image_bytes,
176229
api_key=self._api_key,
177230
timeout=self._timeout,
@@ -180,18 +233,22 @@ def register_face(self, name: str, image_bytes: bytes):
180233

181234
if response.status_code == 200 and response.json()["success"] == True:
182235
return
236+
183237
elif response.status_code == 200 and response.json()["success"] == False:
184238
error = response.json()["error"]
185239
raise DeepstackException(f"Error from Deepstack: {error}")
186240

187241
def recognise(self, image_bytes: bytes):
188242
"""Process image_bytes, performing recognition."""
189-
url = URL_FACE_RECOGNITION.format(ip=self._ip, port=self._port)
190243

191-
response = post_image(url, image_bytes, self._api_key, self._timeout)
244+
response = post_image(
245+
self._url_recognise, image_bytes, self._api_key, self._timeout
246+
)
192247

193248
if not response.status_code == HTTP_OK:
194-
raise DeepstackException(f"Error from request, status code: {response.status_code}")
249+
raise DeepstackException(
250+
f"Error from request, status code: {response.status_code}"
251+
)
195252
return
196253

197254
self._response = response.json()

tests/images/masked.jpg

66.3 KB
Loading

usage-object-detection.ipynb

Lines changed: 63 additions & 24 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)