Description
Why this feature is necessary:
Resolving this would enable us to use larger batches and possibly train with data that hasn't been moved to locally mounted SSDs.
A possible solution is:
The maximum time to get N batches should occur when max_workers=1. Adding more workers should let us parallelize batch building, but right now there appears to be some blocking going on.
I have considered the following alternatives:
The issue can likely be traced to the implementation of
def enqueue_batches(self) -> None:
I have experimented with a few different ways to use workers in this method and have not seen significant improvement. Attempts included using a ThreadPoolExecutor combined with an as_completed loop over the futures, and also skipping the as_completed loop and queueing the futures themselves. Rough sketches of both attempts are shown below.
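A minimal sketch of the two patterns I tried, assuming attribute names like self.max_workers, self.queue, self.queue_cap, and self.sample_batch (the actual names used inside enqueue_batches may differ):

from concurrent.futures import ThreadPoolExecutor, as_completed

def enqueue_batches_as_completed(self) -> None:
    """Attempt 1: submit sample_batch calls and enqueue results as they finish."""
    needed = self.queue_cap - self.queue.qsize()
    with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
        futures = [pool.submit(self.sample_batch) for _ in range(needed)]
        for future in as_completed(futures):
            self.queue.put(future.result())

def enqueue_batches_futures(self) -> None:
    """Attempt 2: enqueue the futures themselves; consumers call .result() on dequeue."""
    needed = self.queue_cap - self.queue.qsize()
    with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
        for _ in range(needed):
            self.queue.put(pool.submit(self.sample_batch))

Neither version reduced the wall time relative to max_workers=1.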
Additional context
Reproduce this with the following:
import time

bh = BatchHandler(..., sample_shape=(60, 60, 10),
                  queue_cap=50, batch_size=16,
                  n_batches=16, mode='lazy',
                  max_workers=1)
start = time.time()
batches = list(bh)
print(time.time() - start)
bh = BatchHandler(..., sample_shape=(60, 60, 10),
                  queue_cap=50, batch_size=16,
                  n_batches=16, mode='lazy',
                  max_workers=10)
start = time.time()
batches = list(bh)
print(time.time() - start)
I have profiled both of these code blocks with cProfile and see some strange differences in the timing and number of calls for the sample_batch function, but I don't know what to make of those differences. Without moving the data to a local SSD:
With max_workers=1, sample_batch shows ~40 seconds per call and 16 calls.
With max_workers=10, sample_batch shows ~400 seconds per call and 2 calls.
Per-call time for get_batch goes from ~20 seconds to ~60 seconds.
A sketch of the profiling setup is below.
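A minimal sketch of one way to reproduce the profiling with cProfile directly (the strings passed to print_stats just filter the report to the functions mentioned above):

import cProfile
import pstats

profiler = cProfile.Profile()
profiler.enable()
batches = list(bh)  # same iteration as in the blocks above
profiler.disable()

# Sort by cumulative time and limit the report to the functions of interest.
stats = pstats.Stats(profiler).sort_stats('cumulative')
stats.print_stats('sample_batch')
stats.print_stats('get_batch')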
Profile output is attached for max_workers=1 and max_workers=10.
Urgency / Timeframe
Not urgent. max_workers=1 works well currently with training data on local SSD.