Skip to content

Commit 11b2080

Browse files
committed
Remove file handling
1 parent 7406da7 commit 11b2080

File tree

3 files changed

+27
-25
lines changed

3 files changed

+27
-25
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ TODO: add face registration and recognition.
1414
* Use `venv` -> `source venv/bin/activate`
1515
* `pip install -r requirements-dev.txt`
1616
* Run tests with `venv/bin/pytest tests/*`
17-
* Black format with `venv/bin/black deepstack/core.py`
17+
* Black format with `venv/bin/black deepstack/core.py` and `venv/bin/black tests/test_deepstack.py`

deepstack/core.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,14 @@ def get_objects_summary(predictions: List[Dict]):
5959
}
6060

6161

62-
def post_image(url: str, image: bytes, api_key: str, timeout: int, data: dict = {}):
62+
def post_image(
63+
url: str, image_bytes: bytes, api_key: str, timeout: int, data: dict = {}
64+
):
6365
"""Post an image to Deepstack."""
6466
try:
6567
data["api_key"] = api_key
6668
response = requests.post(
67-
url, files={"image": image}, data=data, timeout=timeout
69+
url, files={"image": image_bytes}, data=data, timeout=timeout
6870
)
6971
return response
7072
except requests.exceptions.Timeout:
@@ -96,13 +98,8 @@ def __init__(
9698
self._timeout = timeout
9799
self._predictions = []
98100

99-
def process_file(self, file_path: str):
100-
"""Process an image file."""
101-
with open(file_path, "rb") as image_bytes:
102-
self.process_image_bytes(image_bytes)
103-
104-
def process_image_bytes(self, image_bytes: bytes):
105-
"""Process an image, performing detection."""
101+
def detect(self, image_bytes: bytes):
102+
"""Process image_bytes, performing detection."""
106103
self._predictions = []
107104
url = self._url_detection.format(self._ip_address, self._port)
108105

@@ -150,19 +147,18 @@ def __init__(
150147
ip_address, port, api_key, timeout, url_detection=URL_FACE_DETECTION
151148
)
152149

153-
def register_face(self, name: str, file_path: str):
150+
def register_face(self, name: str, image_bytes: bytes):
154151
"""
155152
Register a face name to a file.
156153
"""
157154

158-
with open(file_path, "rb") as image:
159-
response = post_image(
160-
url=URL_FACE_REGISTRATION.format(self._ip_address, self._port),
161-
image=image,
162-
api_key=self._api_key,
163-
timeout=self._timeout,
164-
data={"userid": name},
165-
)
155+
response = post_image(
156+
url=URL_FACE_REGISTRATION.format(self._ip_address, self._port),
157+
image_bytes=image_bytes,
158+
api_key=self._api_key,
159+
timeout=self._timeout,
160+
data={"userid": name},
161+
)
166162

167163
if response.status_code == 200 and response.json()["success"] == True:
168164
print("Taught face {} using file {}".format(name, file_path))

tests/test_deepstack.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,23 +48,25 @@
4848
CONFIDENCE_THRESHOLD = 0.7
4949

5050

51-
def test_DeepstackObject_process_image_bytes():
51+
def test_DeepstackObject_detect():
5252
"""Test a good response from server."""
5353
with requests_mock.Mocker() as mock_req:
54-
mock_req.post(MOCK_URL, status_code=ds.HTTP_OK, json=MOCK_OBJECT_DETECTION_RESPONSE)
54+
mock_req.post(
55+
MOCK_URL, status_code=ds.HTTP_OK, json=MOCK_OBJECT_DETECTION_RESPONSE
56+
)
5557

5658
dsobject = ds.DeepstackObject(MOCK_IP_ADDRESS, MOCK_PORT)
57-
dsobject.process_image_bytes(MOCK_BYTES)
59+
dsobject.detect(MOCK_BYTES)
5860
assert dsobject.predictions == MOCK_OBJECT_PREDICTIONS
5961

6062

61-
def test_DeepstackObject_process_image_bytes_timeout():
63+
def test_DeepstackObject_detect_timeout():
6264
"""Test a timeout. THIS SHOULD FAIL"""
6365
with pytest.raises(ds.DeepstackException) as excinfo:
6466
with requests_mock.Mocker() as mock_req:
6567
mock_req.post(MOCK_URL, exc=requests.exceptions.ConnectTimeout)
6668
dsobject = ds.DeepstackObject(MOCK_IP_ADDRESS, MOCK_PORT)
67-
dsobject.process_image_bytes(MOCK_BYTES)
69+
dsobject.detect(MOCK_BYTES)
6870
assert False
6971
assert "SHOULD FAIL" in str(excinfo.value)
7072

@@ -90,6 +92,10 @@ def test_get_object_confidences():
9092

9193
def test_get_confidences_above_threshold():
9294
assert (
93-
len(ds.get_confidences_above_threshold(MOCK_OBJECT_CONFIDENCES, CONFIDENCE_THRESHOLD))
95+
len(
96+
ds.get_confidences_above_threshold(
97+
MOCK_OBJECT_CONFIDENCES, CONFIDENCE_THRESHOLD
98+
)
99+
)
94100
== 1
95101
)

0 commit comments

Comments
 (0)