Skip to content

Commit 39927ca

Browse files
committed
reversion to original code + addition for non-multithread option
1 parent fbc6f77 commit 39927ca

File tree

2 files changed

+18
-42
lines changed

2 files changed

+18
-42
lines changed

labelbox/data/generator.py

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
from queue import Queue
33
from typing import Any, Iterable
44

5-
from labelbox.exceptions import ThreadException
6-
75

86
class ThreadSafeGen:
97
"""
@@ -25,27 +23,6 @@ def __next__(self):
2523
return next(self.iterable)
2624

2725

28-
class PrefetchThread(threading.Thread):
29-
"""Class to override the Thread class. Helps raise
30-
exceptions to the main caller thread
31-
"""
32-
33-
def __init__(self, **kwargs):
34-
super().__init__(**kwargs)
35-
self.exc = None
36-
37-
def run(self):
38-
try:
39-
super().run()
40-
except BaseException as e:
41-
self.exc = e
42-
43-
def join(self, timeout=None):
44-
threading.Thread.join(self)
45-
if self.exc:
46-
raise self.exc
47-
48-
4926
class PrefetchGenerator:
5027
"""
5128
Applys functions asynchronously to the output of a generator.
@@ -59,7 +36,7 @@ class PrefetchGenerator:
5936
def __init__(self,
6037
data: Iterable[Any],
6138
prefetch_limit=20,
62-
num_executors=1,
39+
num_executors=4,
6340
multithread: bool = False):
6441
if isinstance(data, (list, tuple)):
6542
self._data = (r for r in data)
@@ -71,16 +48,18 @@ def __init__(self,
7148
self.completed_threads = 0
7249
# Can only iterate over once it the queue.get hangs forever.
7350
self.done = False
51+
7452
if multithread:
75-
num_executors = 4
76-
self.num_executors = num_executors
77-
self.threads = [
78-
PrefetchThread(target=self.fill_queue) for _ in range(num_executors)
79-
]
80-
for thread in self.threads:
81-
thread.daemon = True
82-
thread.start()
83-
thread.join()
53+
self.num_executors = num_executors
54+
self.threads = [
55+
threading.Thread(target=self.fill_queue)
56+
for _ in range(num_executors)
57+
]
58+
for thread in self.threads:
59+
thread.daemon = True
60+
thread.start()
61+
else:
62+
self.fill_queue()
8463

8564
def _process(self, value) -> Any:
8665
raise NotImplementedError("Abstract method needs to be implemented")
@@ -93,16 +72,18 @@ def fill_queue(self):
9372
raise ValueError("Unexpected None")
9473
self.queue.put(value)
9574
except:
96-
raise ThreadException(
97-
"Unexpected exception while filling the queue.")
75+
self.queue.put(
76+
ValueError("Unexpected exception while filling queue."))
9877

9978
def __iter__(self):
10079
return self
10180

10281
def __next__(self) -> Any:
10382
if self.done:
10483
raise StopIteration
105-
value = self.queue.get()
84+
value = self.queue.get(block=False)
85+
if isinstance(value, ValueError):
86+
raise value
10687
while value is None:
10788
self.completed_threads += 1
10889
if self.completed_threads == self.num_executors:

labelbox/exceptions.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,4 @@ class MALValidationError(LabelboxError):
122122

123123
class OperationNotAllowedException(Exception):
124124
"""Raised when user does not have permissions to a resource or has exceeded usage limit"""
125-
pass
126-
127-
128-
class ThreadException(Exception):
129-
"""Raised when there is an issue with a thread, typically in Prefetch generator"""
130-
pass
125+
pass

0 commit comments

Comments
 (0)