Skip to content

Commit 40973e8

Browse files
committed
updates to files to remove multithread param and instead just set based on num exec
1 parent b78d67b commit 40973e8

File tree

3 files changed

+23
-31
lines changed

3 files changed

+23
-31
lines changed

labelbox/data/annotation_types/collection.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,7 @@ class LabelGenerator(PrefetchGenerator):
172172
than the LabelList but will be much more memory efficient.
173173
"""
174174

175-
def __init__(self,
176-
data: Generator[Label, None, None],
177-
multithread: bool = False,
178-
*args,
179-
**kwargs):
175+
def __init__(self, data: Generator[Label, None, None], *args, **kwargs):
180176
self._fns = {}
181177
super().__init__(data, *args, **kwargs)
182178

labelbox/data/generator.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,7 @@ class PrefetchGenerator:
3333
Useful for modifying the generator results based on data from a network
3434
"""
3535

36-
#maybe change num exec to just 1, and if 1, make sync
37-
#instead of self.get qeue in next, itll return just self._data.next
38-
#kwarg on export for multithread, and all other things that use prefetch
39-
40-
def __init__(self,
41-
data: Iterable[Any],
42-
prefetch_limit=20,
43-
num_executors=4,
44-
multithread: bool = False):
36+
def __init__(self, data: Iterable[Any], prefetch_limit=20, num_executors=1):
4537
if isinstance(data, (list, tuple)):
4638
self._data = (r for r in data)
4739
else:
@@ -51,8 +43,8 @@ def __init__(self,
5143
self._data = ThreadSafeGen(self._data)
5244
self.completed_threads = 0
5345
# Can only iterate over once it the queue.get hangs forever.
46+
self.multithread = False if num_executors == 1 else True
5447
self.done = False
55-
self.multithread = multithread
5648

5749
if self.multithread:
5850
self.num_executors = num_executors
@@ -64,7 +56,7 @@ def __init__(self,
6456
thread.daemon = True
6557
thread.start()
6658
else:
67-
self.fill_queue()
59+
self._data = iter(self._data)
6860

6961
def _process(self, value) -> Any:
7062
raise NotImplementedError("Abstract method needs to be implemented")
@@ -79,25 +71,29 @@ def fill_queue(self):
7971
except Exception as e:
8072
self.queue.put(
8173
ValueError("Unexpected exception while filling queue. %r", e))
74+
finally:
75+
self.queue.put(None)
8276

8377
def __iter__(self):
8478
return self
8579

8680
def __next__(self) -> Any:
87-
if self.done or self.queue.empty():
81+
if self.done:
8882
raise StopIteration
89-
value = self.queue.get()
90-
if isinstance(value, ValueError):
91-
raise value
92-
while value is None:
93-
if not self.multithread:
94-
value = self.queue.get()
95-
continue
96-
self.completed_threads += 1
97-
if self.completed_threads == self.num_executors:
98-
self.done = True
99-
for thread in self.threads:
100-
thread.join()
101-
raise StopIteration
83+
84+
if self.multithread:
10285
value = self.queue.get()
86+
if isinstance(value, Exception):
87+
raise value
88+
89+
while value is None:
90+
self.completed_threads += 1
91+
if self.completed_threads == self.num_executors:
92+
self.done = True
93+
for thread in self.threads:
94+
thread.join()
95+
raise StopIteration
96+
value = self.queue.get()
97+
else:
98+
value = self._process(next(self._data))
10399
return value

labelbox/data/serialization/labelbox_v1/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class LBV1VideoIterator(PrefetchGenerator):
8282
Generator that fetches video annotations in the background to be faster.
8383
"""
8484

85-
def __init__(self, examples, client, multithread: bool = False):
85+
def __init__(self, examples, client):
8686
self.client = client
8787
super().__init__(examples)
8888

0 commit comments

Comments
 (0)