2
2
from queue import Queue
3
3
from typing import Any , Iterable
4
4
5
- from labelbox .exceptions import ThreadException
6
-
7
5
8
6
class ThreadSafeGen :
9
7
"""
@@ -25,27 +23,6 @@ def __next__(self):
25
23
return next (self .iterable )
26
24
27
25
28
- class PrefetchThread (threading .Thread ):
29
- """Class to override the Thread class. Helps raise
30
- exceptions to the main caller thread
31
- """
32
-
33
- def __init__ (self , ** kwargs ):
34
- super ().__init__ (** kwargs )
35
- self .exc = None
36
-
37
- def run (self ):
38
- try :
39
- super ().run ()
40
- except BaseException as e :
41
- self .exc = e
42
-
43
- def join (self , timeout = None ):
44
- threading .Thread .join (self )
45
- if self .exc :
46
- raise self .exc
47
-
48
-
49
26
class PrefetchGenerator :
50
27
"""
51
28
Applys functions asynchronously to the output of a generator.
@@ -59,7 +36,7 @@ class PrefetchGenerator:
59
36
def __init__ (self ,
60
37
data : Iterable [Any ],
61
38
prefetch_limit = 20 ,
62
- num_executors = 1 ,
39
+ num_executors = 4 ,
63
40
multithread : bool = False ):
64
41
if isinstance (data , (list , tuple )):
65
42
self ._data = (r for r in data )
@@ -71,16 +48,18 @@ def __init__(self,
71
48
self .completed_threads = 0
72
49
# Can only iterate over once it the queue.get hangs forever.
73
50
self .done = False
51
+
74
52
if multithread :
75
- num_executors = 4
76
- self .num_executors = num_executors
77
- self .threads = [
78
- PrefetchThread (target = self .fill_queue ) for _ in range (num_executors )
79
- ]
80
- for thread in self .threads :
81
- thread .daemon = True
82
- thread .start ()
83
- thread .join ()
53
+ self .num_executors = num_executors
54
+ self .threads = [
55
+ threading .Thread (target = self .fill_queue )
56
+ for _ in range (num_executors )
57
+ ]
58
+ for thread in self .threads :
59
+ thread .daemon = True
60
+ thread .start ()
61
+ else :
62
+ self .fill_queue ()
84
63
85
64
def _process (self , value ) -> Any :
86
65
raise NotImplementedError ("Abstract method needs to be implemented" )
@@ -93,16 +72,18 @@ def fill_queue(self):
93
72
raise ValueError ("Unexpected None" )
94
73
self .queue .put (value )
95
74
except :
96
- raise ThreadException (
97
- "Unexpected exception while filling the queue." )
75
+ self . queue . put (
76
+ ValueError ( "Unexpected exception while filling queue." ) )
98
77
99
78
def __iter__ (self ):
100
79
return self
101
80
102
81
def __next__ (self ) -> Any :
103
82
if self .done :
104
83
raise StopIteration
105
- value = self .queue .get ()
84
+ value = self .queue .get (block = False )
85
+ if isinstance (value , ValueError ):
86
+ raise value
106
87
while value is None :
107
88
self .completed_threads += 1
108
89
if self .completed_threads == self .num_executors :
0 commit comments