@@ -33,15 +33,7 @@ class PrefetchGenerator:
33
33
Useful for modifying the generator results based on data from a network
34
34
"""
35
35
36
- #maybe change num exec to just 1, and if 1, make sync
37
- #instead of self.get qeue in next, itll return just self._data.next
38
- #kwarg on export for multithread, and all other things that use prefetch
39
-
40
- def __init__ (self ,
41
- data : Iterable [Any ],
42
- prefetch_limit = 20 ,
43
- num_executors = 4 ,
44
- multithread : bool = False ):
36
+ def __init__ (self , data : Iterable [Any ], prefetch_limit = 20 , num_executors = 1 ):
45
37
if isinstance (data , (list , tuple )):
46
38
self ._data = (r for r in data )
47
39
else :
@@ -51,8 +43,8 @@ def __init__(self,
51
43
self ._data = ThreadSafeGen (self ._data )
52
44
self .completed_threads = 0
53
45
# Can only be iterated over once; a second pass would hang forever on queue.get.
46
+ self .multithread = False if num_executors == 1 else True
54
47
self .done = False
55
- self .multithread = multithread
56
48
57
49
if self .multithread :
58
50
self .num_executors = num_executors
@@ -64,7 +56,7 @@ def __init__(self,
64
56
thread .daemon = True
65
57
thread .start ()
66
58
else :
67
- self .fill_queue ( )
59
+ self ._data = iter ( self . _data )
68
60
69
61
def _process (self , value ) -> Any :
70
62
raise NotImplementedError ("Abstract method needs to be implemented" )
@@ -79,25 +71,29 @@ def fill_queue(self):
79
71
except Exception as e :
80
72
self .queue .put (
81
73
ValueError ("Unexpected exception while filling queue. %r" , e ))
74
+ finally :
75
+ self .queue .put (None )
82
76
83
77
def __iter__(self):
    """Return the object itself; it acts as its own iterator."""
    return self
85
79
86
80
def __next__(self) -> Any:
    """Return the next processed value.

    Synchronous mode (``self.multithread`` is false): pull one item from
    ``self._data`` and return ``self._process(item)``.

    Multithreaded mode: take items off ``self.queue``.  ``fill_queue``
    puts ``None`` on the queue when it finishes and puts an exception
    instance when it fails, so:

    * ``None`` is counted as a finished worker (presumably one sentinel
      per worker thread -- confirm against ``fill_queue``); once
      ``self.num_executors`` sentinels have been seen the threads are
      joined and iteration ends.
    * an ``Exception`` instance is re-raised in the consumer.
    * anything else is returned to the caller.

    Raises:
        StopIteration: when the source is exhausted.
        Exception: whatever exception object a worker placed on the queue.
    """
    if self.done:
        raise StopIteration

    if not self.multithread:
        try:
            item = next(self._data)
        except StopIteration:
            # Remember exhaustion so later calls stop without touching
            # the underlying iterator again.
            self.done = True
            raise
        return self._process(item)

    while True:
        value = self.queue.get()

        if value is None:
            # End-of-work sentinel from a worker.
            self.completed_threads += 1
            if self.completed_threads == self.num_executors:
                self.done = True
                for thread in self.threads:
                    thread.join()
                raise StopIteration
            continue

        # Check every dequeued item, not only the first one: an error can
        # arrive after another worker's sentinel and must not be handed
        # to the caller as if it were data.
        if isinstance(value, Exception):
            raise value

        return value
0 commit comments