|
1 | 1 | """
|
2 | 2 | Deepstack core.
|
3 | 3 | """
|
4 |
| -import imghdr |
5 | 4 | import requests
|
6 | 5 | from PIL import Image
|
| 6 | +from typing import Union, List, Set, Dict |
7 | 7 |
|
8 | 8 | ## Const
|
9 |
| -CLASSIFIER = "deepstack" |
10 | 9 | HTTP_OK = 200
|
11 | 10 | HTTP_BAD_REQUEST = 400
|
12 | 11 | HTTP_UNAUTHORIZED = 401
|
13 |
| -TIMEOUT = 30 # seconds |
| 12 | +DEFAULT_TIMEOUT = 10 # seconds |
14 | 13 |
|
15 | 14 |
|
16 |
| -def get_matched_faces(predictions: dict): |
| 15 | +def format_confidence(confidence: Union[str, float]) -> float: |
| 16 | + """Takes a confidence from the API like |
| 17 | + 0.55623 and returne 55.6 (%). |
17 | 18 | """
|
18 |
| - Get the predicted faces and their confidence. |
| 19 | + return round(float(confidence) * 100, 1) |
| 20 | + |
| 21 | + |
| 22 | +def get_confidences_above_threshold( |
| 23 | + confidences: List[float], confidence_threshold: float |
| 24 | +) -> List[float]: |
| 25 | + """Takes a list of confidences and returns those above a confidence_threshold.""" |
| 26 | + return [val for val in confidences if val >= confidence_threshold] |
| 27 | + |
| 28 | + |
| 29 | +def get_object_labels(predictions: List[Dict]) -> Set[str]: |
19 | 30 | """
|
20 |
| - try: |
21 |
| - matched_faces = { |
22 |
| - face["userid"]: round(face["confidence"] * 100, 1) |
23 |
| - for face in predictions |
24 |
| - if not face["userid"] == "unknown" |
25 |
| - } |
26 |
| - return matched_faces |
27 |
| - except: |
28 |
| - return {} |
| 31 | + Get a list of the unique object labels predicted. |
| 32 | + """ |
| 33 | + labels = [pred["label"] for pred in predictions] |
| 34 | + return set(labels) |
29 | 35 |
|
30 | 36 |
|
31 |
| -def is_valid_image(file_path: str): |
| 37 | +def get_label_confidences(predictions: List[Dict], target_label: str): |
32 | 38 | """
|
33 |
| - Check file_path is valid image, using PIL then imghdr. |
| 39 | + Return the list of confidences of instances of target label. |
34 | 40 | """
|
35 |
| - try: |
36 |
| - with Image.open(file_path): |
37 |
| - pass |
| 41 | + confidences = [ |
| 42 | + pred["confidence"] for pred in predictions if pred["label"] == target_label |
| 43 | + ] |
| 44 | + return confidences |
38 | 45 |
|
39 |
| - image_extension = imghdr.what(file_path) |
40 |
| - if image_extension in ["jpeg", ".jpg", ".png"]: |
41 |
| - return True |
42 |
| - return False |
43 |
| - except Exception as exc: |
44 |
| - print(exc) |
45 |
| - return False |
46 | 46 |
|
| 47 | +def get_objects_summary(predictions: List[Dict]): |
| 48 | + """ |
| 49 | + Get a summary of the objects detected. |
| 50 | + """ |
| 51 | + labels = get_object_labels(predictions) |
| 52 | + return { |
| 53 | + label: len(get_label_confidences(predictions, target_label=label)) |
| 54 | + for label in labels |
| 55 | + } |
47 | 56 |
|
48 |
| -def post_image(url: str, image: bytes): |
49 |
| - """Post an image to the classifier.""" |
| 57 | + |
| 58 | +def post_image(url: str, image: bytes, api_key: str, timeout: int): |
| 59 | + """Post an image to Deepstack.""" |
50 | 60 | try:
|
51 |
| - response = requests.post(url, files={"image": image}, timeout=TIMEOUT) |
| 61 | + response = requests.post( |
| 62 | + url, files={"image": image}, data={"api_key": api_key}, timeout=timeout |
| 63 | + ) |
52 | 64 | return response
|
53 |
| - except requests.exceptions.ConnectionError: |
54 |
| - print("ConnectionError: Is %s running?", CLASSIFIER) |
55 |
| - return None |
56 | 65 | except requests.exceptions.Timeout:
|
57 |
| - print("Timeout error from %s", CLASSIFIER) |
58 |
| - return None |
| 66 | + raise DeepstackException( |
| 67 | + f"Timeout connecting to Deepstack, current timeout is {timeout} seconds" |
| 68 | + ) |
| 69 | + |
59 | 70 |
|
| 71 | +class DeepstackException(Exception): |
| 72 | + pass |
60 | 73 |
|
61 |
| -class DeepstackFace: |
62 |
| - """Work with faces.""" |
63 | 74 |
|
64 |
| - def __init__(self, ip_address: str, port: str): |
| 75 | +class DeepstackObject: |
| 76 | + """The object detection API locates and classifies 80 |
| 77 | + different kinds of objects in a single image..""" |
65 | 78 |
|
66 |
| - self._url_check = "http://{}:{}/v1/vision/face/recognize".format( |
| 79 | + def __init__( |
| 80 | + self, |
| 81 | + ip_address: str, |
| 82 | + port: str, |
| 83 | + api_key: str = "", |
| 84 | + timeout: int = DEFAULT_TIMEOUT, |
| 85 | + ): |
| 86 | + |
| 87 | + self._url_object_detection = "http://{}:{}/v1/vision/detection".format( |
67 | 88 | ip_address, port
|
68 | 89 | )
|
69 |
| - |
70 |
| - self._faces = None |
71 |
| - self._matched = {} |
| 90 | + self._api_key = api_key |
| 91 | + self._timeout = timeout |
| 92 | + self._predictions = [] |
72 | 93 |
|
73 | 94 | def process_file(self, file_path: str):
|
74 | 95 | """Process an image file."""
|
75 |
| - if is_valid_image(file_path): |
76 |
| - with open(file_path, "rb") as image_bytes: |
77 |
| - self.process_image_bytes(image_bytes) |
| 96 | + with open(file_path, "rb") as image_bytes: |
| 97 | + self.process_image_bytes(image_bytes) |
78 | 98 |
|
79 | 99 | def process_image_bytes(self, image_bytes: bytes):
|
80 | 100 | """Process an image."""
|
81 |
| - response = post_image(self._url_check, image_bytes) |
82 |
| - if response: |
83 |
| - if response.status_code == HTTP_OK: |
84 |
| - predictions_json = response.json()["predictions"] |
85 |
| - self._faces = len(predictions_json) |
86 |
| - self._matched = get_matched_faces(predictions_json) |
| 101 | + self._predictions = [] |
| 102 | + |
| 103 | + response = post_image( |
| 104 | + self._url_object_detection, image_bytes, self._api_key, self._timeout |
| 105 | + ) |
87 | 106 |
|
88 |
| - else: |
89 |
| - self._faces = None |
90 |
| - self._matched = {} |
| 107 | + if response.status_code == HTTP_OK: |
| 108 | + if response.json()["success"]: |
| 109 | + self._predictions = response.json()["predictions"] |
| 110 | + else: |
| 111 | + error = response.json()["error"] |
| 112 | + raise DeepstackException(f"Error from Deepstack: {error}") |
91 | 113 |
|
92 | 114 | @property
|
93 |
| - def attributes(self): |
| 115 | + def predictions(self): |
94 | 116 | """Return the classifier attributes."""
|
95 |
| - return { |
96 |
| - "faces": self._faces, |
97 |
| - "matched_faces": self._matched, |
98 |
| - "total_matched_faces": len(self._matched), |
99 |
| - } |
| 117 | + return self._predictions |
0 commit comments