1
- import logging
2
1
import threading
3
2
from queue import Queue
4
3
from typing import Any , Iterable
5
- import threading
6
4
7
- logger = logging . getLogger ( __name__ )
5
+ from labelbox . exceptions import ThreadException
8
6
9
7
10
8
class ThreadSafeGen :
@@ -27,13 +25,43 @@ def __next__(self):
27
25
return next (self .iterable )
28
26
29
27
28
class PrefetchThread(threading.Thread):
    """Thread subclass that surfaces exceptions to the thread calling join().

    Any exception raised while the thread runs is stored on ``self.exc``
    instead of being lost in the worker thread, and is re-raised from
    ``join()`` once the thread has been joined.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``threading.Thread``."""
        super().__init__(**kwargs)
        # Exception captured by run(); None while nothing has failed.
        self.exc = None

    def run(self):
        """Run the thread's target, recording any exception it raises."""
        try:
            super().run()
        except BaseException as e:
            # Deliberately broad: everything is stored so join() can re-raise it.
            self.exc = e

    def join(self, timeout=None):
        """Wait for the thread, then re-raise the exception it captured, if any.

        Args:
            timeout: Optional number of seconds to wait, passed through to
                ``threading.Thread.join``. ``None`` waits until the thread ends.

        Raises:
            BaseException: Whatever exception ``run()`` stored in ``self.exc``.
        """
        # Honour the caller's timeout; the original dropped it and always
        # blocked until the thread finished.
        super().join(timeout)
        # Identity check so an exception type that evaluates falsy is still raised.
        if self.exc is not None:
            raise self.exc
47
+
48
+
30
49
class PrefetchGenerator :
31
50
"""
32
51
Applies functions asynchronously to the output of a generator.
33
52
Useful for modifying the generator results based on data from a network
34
53
"""
35
54
36
- def __init__ (self , data : Iterable [Any ], prefetch_limit = 20 , num_executors = 4 ):
55
+ # TODO: maybe default num_executors to 1 and, when it is 1, run synchronously:
56
+ # instead of reading from the queue, __next__ would just return next(self._data).
57
+ # TODO: add a multithread kwarg on export and on everything else that uses prefetch.
58
+
59
+ # def __init__(self, data: Iterable[Any], prefetch_limit=20, num_executors=4):
60
+ def __init__ (self ,
61
+ data : Iterable [Any ],
62
+ prefetch_limit = 20 ,
63
+ num_executors = 1 ,
64
+ multithread : bool = False ):
37
65
if isinstance (data , (list , tuple )):
38
66
self ._data = (r for r in data )
39
67
else :
@@ -44,14 +72,17 @@ def __init__(self, data: Iterable[Any], prefetch_limit=20, num_executors=4):
44
72
self .completed_threads = 0
45
73
# Can only be iterated over once; otherwise queue.get hangs forever.
46
74
self .done = False
75
+ if multithread :
76
+ num_executors = 4
47
77
self .num_executors = num_executors
48
78
self .threads = [
49
- threading .Thread (target = self .fill_queue )
50
- for _ in range (num_executors )
79
+ PrefetchThread (target = self .fill_queue ) for _ in range (num_executors )
51
80
]
52
81
for thread in self .threads :
53
82
thread .daemon = True
54
83
thread .start ()
84
+ for thread in self .threads :
85
+ thread .join ()
55
86
56
87
def _process (self , value ) -> Any :
57
88
raise NotImplementedError ("Abstract method needs to be implemented" )
@@ -63,11 +94,9 @@ def fill_queue(self):
63
94
if value is None :
64
95
raise ValueError ("Unexpected None" )
65
96
self .queue .put (value )
66
- except Exception as e :
67
- logger .warning ("Unexpected exception while filling the queue. %r" ,
68
- e )
69
- finally :
70
- self .queue .put (None )
97
+ except :
98
+ raise ThreadException (
99
+ "Unexpected exception while filling the queue." )
71
100
72
101
def __iter__ (self ):
73
102
return self
0 commit comments