Skip to content

Commit 4f20fcb

Browse files
committed
Add custom
1 parent 063f722 commit 4f20fcb

File tree

1 file changed

+30
-66
lines changed

1 file changed

+30
-66
lines changed

deepstack/core.py

Lines changed: 30 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,22 @@
1010
DEFAULT_TIMEOUT = 10 # seconds
1111

1212
## API urls
13-
URL_OBJECT_DETECTION = "http://{}:{}/v1/vision/detection"
14-
URL_FACE_DETECTION = "http://{}:{}/v1/vision/face"
15-
URL_FACE_REGISTRATION = "http://{}:{}/v1/vision/face/register"
16-
URL_FACE_RECOGNITION = "http://{}:{}/v1/vision/face/recognize"
17-
URL_SCENE_DETECTION = "http://{}:{}/v1/vision/scene"
13+
URL_CUSTOM = "http://{ip}:{port}/v1/vision/custom/{custom_model}"
14+
URL_OBJECT_DETECTION = "http://{ip}:{port}/v1/vision/detection"
15+
URL_FACE_DETECTION = "http://{ip}:{port}/v1/vision/face"
16+
URL_FACE_REGISTRATION = "http://{ip}:{port}/v1/vision/face/register"
17+
URL_FACE_RECOGNITION = "http://{ip}:{port}/v1/vision/face/recognize"
18+
URL_SCENE_DETECTION = "http://{ip}:{port}/v1/vision/scene"
1819

1920

2021
def format_confidence(confidence: Union[str, float]) -> float:
2122
"""Takes a confidence from the API like
22-
0.55623 and returne 55.6 (%).
23+
0.55623 and returns 55.6 (%).
2324
"""
2425
return round(float(confidence) * 100, 1)
2526

2627

27-
def get_confidences_above_threshold(
28-
confidences: List[float], confidence_threshold: float
29-
) -> List[float]:
28+
def get_confidences_above_threshold(confidences: List[float], confidence_threshold: float) -> List[float]:
3029
"""Takes a list of confidences and returns those above a confidence_threshold."""
3130
return [val for val in confidences if val >= confidence_threshold]
3231

@@ -58,9 +57,7 @@ def get_object_confidences(predictions: List[Dict], target_object: str):
5857
"""
5958
Return the list of confidences of instances of target label.
6059
"""
61-
confidences = [
62-
pred["confidence"] for pred in predictions if pred["label"] == target_object
63-
]
60+
confidences = [pred["confidence"] for pred in predictions if pred["label"] == target_object]
6461
return confidences
6562

6663

@@ -69,26 +66,17 @@ def get_objects_summary(predictions: List[Dict]):
6966
Get a summary of the objects detected.
7067
"""
7168
objects = get_objects(predictions)
72-
return {
73-
target_object: len(get_object_confidences(predictions, target_object))
74-
for target_object in objects
75-
}
69+
return {target_object: len(get_object_confidences(predictions, target_object)) for target_object in objects}
7670

7771

78-
def post_image(
79-
url: str, image_bytes: bytes, api_key: str, timeout: int, data: dict = {}
80-
):
72+
def post_image(url: str, image_bytes: bytes, api_key: str, timeout: int, data: dict = {}):
8173
"""Post an image to Deepstack."""
8274
try:
8375
data["api_key"] = api_key
84-
response = requests.post(
85-
url, files={"image": image_bytes}, data=data, timeout=timeout
86-
)
76+
response = requests.post(url, files={"image": image_bytes}, data=data, timeout=timeout)
8777
return response
8878
except requests.exceptions.Timeout:
89-
raise DeepstackException(
90-
f"Timeout connecting to Deepstack, current timeout is {timeout} seconds"
91-
)
79+
raise DeepstackException(f"Timeout connecting to Deepstack, current timeout is {timeout} seconds")
9280
except requests.exceptions.ConnectionError as exc:
9381
raise DeepstackException(f"Connection error: {exc}")
9482

@@ -101,32 +89,25 @@ class Deepstack(object):
10189
"""Base class for deepstack."""
10290

10391
def __init__(
104-
self,
105-
ip_address: str,
106-
port: str,
107-
api_key: str = "",
108-
timeout: int = DEFAULT_TIMEOUT,
109-
url_detection: str = "",
92+
self, ip: str, port: str, api_key: str = "", timeout: int = DEFAULT_TIMEOUT, url: str = "",
11093
):
11194

112-
self._ip_address = ip_address
95+
self._ip = ip
11396
self._port = port
114-
self._url_detection = url_detection
97+
self._url = url
11598
self._api_key = api_key
11699
self._timeout = timeout
117100
self._response = None
118101

119102
def detect(self, image_bytes: bytes):
120103
"""Process image_bytes, performing detection."""
121104
self._response = None
122-
url = self._url_detection.format(self._ip_address, self._port)
105+
url = self._url.format(ip=self._ip, port=self._port)
123106

124107
response = post_image(url, image_bytes, self._api_key, self._timeout)
125108

126109
if not response.status_code == HTTP_OK:
127-
raise DeepstackException(
128-
f"Error from request, status code: {response.status_code}"
129-
)
110+
raise DeepstackException(f"Error from request, status code: {response.status_code}")
130111
return
131112

132113
self._response = response.json()
@@ -144,15 +125,12 @@ class DeepstackObject(Deepstack):
144125
"""Work with objects"""
145126

146127
def __init__(
147-
self,
148-
ip_address: str,
149-
port: str,
150-
api_key: str = "",
151-
timeout: int = DEFAULT_TIMEOUT,
128+
self, ip: str, port: str, api_key: str = "", timeout: int = DEFAULT_TIMEOUT, custom_model: str = None,
152129
):
153-
super().__init__(
154-
ip_address, port, api_key, timeout, url_detection=URL_OBJECT_DETECTION
155-
)
130+
if not custom_model:
131+
super().__init__(ip, port, api_key, timeout, url=URL_OBJECT_DETECTION)
132+
else:
133+
super().__init__(ip, port, api_key, timeout, url=URL_CUSTOM.format(custom_model=custom_model))
156134

157135
@property
158136
def predictions(self):
@@ -164,15 +142,9 @@ class DeepstackScene(Deepstack):
164142
"""Work with scenes"""
165143

166144
def __init__(
167-
self,
168-
ip_address: str,
169-
port: str,
170-
api_key: str = "",
171-
timeout: int = DEFAULT_TIMEOUT,
145+
self, ip: str, port: str, api_key: str = "", timeout: int = DEFAULT_TIMEOUT,
172146
):
173-
super().__init__(
174-
ip_address, port, api_key, timeout, url_detection=URL_SCENE_DETECTION
175-
)
147+
super().__init__(ip, port, api_key, timeout, url=URL_SCENE_DETECTION)
176148

177149
@property
178150
def predictions(self):
@@ -184,15 +156,9 @@ class DeepstackFace(Deepstack):
184156
"""Work with objects"""
185157

186158
def __init__(
187-
self,
188-
ip_address: str,
189-
port: str,
190-
api_key: str = "",
191-
timeout: int = DEFAULT_TIMEOUT,
159+
self, ip: str, port: str, api_key: str = "", timeout: int = DEFAULT_TIMEOUT,
192160
):
193-
super().__init__(
194-
ip_address, port, api_key, timeout, url_detection=URL_FACE_DETECTION
195-
)
161+
super().__init__(ip, port, api_key, timeout, url=URL_FACE_DETECTION)
196162

197163
@property
198164
def predictions(self):
@@ -205,7 +171,7 @@ def register_face(self, name: str, image_bytes: bytes):
205171
"""
206172

207173
response = post_image(
208-
url=URL_FACE_REGISTRATION.format(self._ip_address, self._port),
174+
url=URL_FACE_REGISTRATION.format(ip=self._ip, port=self._port),
209175
image_bytes=image_bytes,
210176
api_key=self._api_key,
211177
timeout=self._timeout,
@@ -220,14 +186,12 @@ def register_face(self, name: str, image_bytes: bytes):
220186

221187
def recognise(self, image_bytes: bytes):
222188
"""Process image_bytes, performing recognition."""
223-
url = URL_FACE_RECOGNITION.format(self._ip_address, self._port)
189+
url = URL_FACE_RECOGNITION.format(ip=self._ip, port=self._port)
224190

225191
response = post_image(url, image_bytes, self._api_key, self._timeout)
226192

227193
if not response.status_code == HTTP_OK:
228-
raise DeepstackException(
229-
f"Error from request, status code: {response.status_code}"
230-
)
194+
raise DeepstackException(f"Error from request, status code: {response.status_code}")
231195
return
232196

233197
self._response = response.json()

0 commit comments

Comments
 (0)