Skip to content

Commit fbd4e63

Browse files
committed
adds to raise exception to main caller thread from prefetch generator
1 parent dba0ce8 commit fbd4e63

File tree

4 files changed

+51
-13
lines changed

4 files changed

+51
-13
lines changed

labelbox/data/annotation_types/collection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,11 @@ class LabelGenerator(PrefetchGenerator):
172172
than the LabelList but will be much more memory efficient.
173173
"""
174174

175-
def __init__(self, data: Generator[Label, None, None], *args, **kwargs):
175+
def __init__(self,
176+
data: Generator[Label, None, None],
177+
multithread: bool = False,
178+
*args,
179+
**kwargs):
176180
self._fns = {}
177181
super().__init__(data, *args, **kwargs)
178182

labelbox/data/generator.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import logging
21
import threading
32
from queue import Queue
43
from typing import Any, Iterable
5-
import threading
64

7-
logger = logging.getLogger(__name__)
5+
from labelbox.exceptions import ThreadException
86

97

108
class ThreadSafeGen:
@@ -27,13 +25,43 @@ def __next__(self):
2725
return next(self.iterable)
2826

2927

28+
class PrefetchThread(threading.Thread):
29+
"""Class to override the Thread class. Helps raise
30+
exceptions to the main caller thread
31+
"""
32+
33+
def __init__(self, **kwargs):
34+
super().__init__(**kwargs)
35+
self.exc = None
36+
37+
def run(self):
38+
try:
39+
super().run()
40+
except BaseException as e:
41+
self.exc = e
42+
43+
def join(self, timeout=None):
44+
threading.Thread.join(self)
45+
if self.exc:
46+
raise self.exc
47+
48+
3049
class PrefetchGenerator:
3150
"""
3251
Applys functions asynchronously to the output of a generator.
3352
Useful for modifying the generator results based on data from a network
3453
"""
3554

36-
def __init__(self, data: Iterable[Any], prefetch_limit=20, num_executors=4):
55+
#maybe change num exec to just 1, and if 1, make sync
56+
#instead of self.get qeue in next, itll return just self._data.next
57+
#kwarg on export for multithread, and all other things that use prefetch
58+
59+
# def __init__(self, data: Iterable[Any], prefetch_limit=20, num_executors=4):
60+
def __init__(self,
61+
data: Iterable[Any],
62+
prefetch_limit=20,
63+
num_executors=1,
64+
multithread: bool = False):
3765
if isinstance(data, (list, tuple)):
3866
self._data = (r for r in data)
3967
else:
@@ -44,14 +72,17 @@ def __init__(self, data: Iterable[Any], prefetch_limit=20, num_executors=4):
4472
self.completed_threads = 0
4573
# Can only iterate over once it the queue.get hangs forever.
4674
self.done = False
75+
if multithread:
76+
num_executors = 4
4777
self.num_executors = num_executors
4878
self.threads = [
49-
threading.Thread(target=self.fill_queue)
50-
for _ in range(num_executors)
79+
PrefetchThread(target=self.fill_queue) for _ in range(num_executors)
5180
]
5281
for thread in self.threads:
5382
thread.daemon = True
5483
thread.start()
84+
for thread in self.threads:
85+
thread.join()
5586

5687
def _process(self, value) -> Any:
5788
raise NotImplementedError("Abstract method needs to be implemented")
@@ -63,11 +94,9 @@ def fill_queue(self):
6394
if value is None:
6495
raise ValueError("Unexpected None")
6596
self.queue.put(value)
66-
except Exception as e:
67-
logger.warning("Unexpected exception while filling the queue. %r",
68-
e)
69-
finally:
70-
self.queue.put(None)
97+
except:
98+
raise ThreadException(
99+
"Unexpected exception while filling the queue.")
71100

72101
def __iter__(self):
73102
return self

labelbox/data/serialization/labelbox_v1/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class LBV1VideoIterator(PrefetchGenerator):
8282
Generator that fetches video annotations in the background to be faster.
8383
"""
8484

85-
def __init__(self, examples, client):
85+
def __init__(self, examples, client, multithread: bool = False):
8686
self.client = client
8787
super().__init__(examples)
8888

labelbox/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,8 @@ class MALValidationError(LabelboxError):
123123
class OperationNotAllowedException(Exception):
124124
"""Raised when user does not have permissions to a resource or has exceeded usage limit"""
125125
pass
126+
127+
128+
class ThreadException(Exception):
129+
"""Raised when there is an issue with a thread, typically in Prefetch generator"""
130+
pass

0 commit comments

Comments
 (0)