Skip to content

Commit 48e4534

Browse files
committed
concurrently load different variables in ds.load_async using asyncio.gather
1 parent 2079d7e commit 48e4534

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

xarray/core/dataset.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import copy
45
import datetime
56
import math
@@ -531,49 +532,50 @@ def load(self, **kwargs) -> Self:
531532
dask.compute
532533
"""
533534
# access .data to coerce everything to numpy or dask arrays
534-
lazy_data = {
535+
chunked_data = {
535536
k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
536537
}
537-
if lazy_data:
538-
chunkmanager = get_chunked_array_type(*lazy_data.values())
538+
if chunked_data:
539+
chunkmanager = get_chunked_array_type(*chunked_data.values())
539540

540541
# evaluate all the chunked arrays simultaneously
541542
evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
542-
*lazy_data.values(), **kwargs
543+
*chunked_data.values(), **kwargs
543544
)
544545

545-
for k, data in zip(lazy_data, evaluated_data, strict=False):
546+
for k, data in zip(chunked_data, evaluated_data, strict=False):
546547
self.variables[k].data = data
547548

548549
# load everything else sequentially
549-
for k, v in self.variables.items():
550-
if k not in lazy_data:
551-
v.load()
550+
[v.load_async() for k, v in self.variables.items() if k not in chunked_data]
552551

553552
return self
554553

555554
async def load_async(self, **kwargs) -> Self:
555+
# TODO refactor this to pull out the common chunked_data codepath
556+
556557
# this blocks on chunked arrays but not on lazily indexed arrays
557558

558559
# access .data to coerce everything to numpy or dask arrays
559-
lazy_data = {
560+
chunked_data = {
560561
k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
561562
}
562-
if lazy_data:
563-
chunkmanager = get_chunked_array_type(*lazy_data.values())
563+
if chunked_data:
564+
chunkmanager = get_chunked_array_type(*chunked_data.values())
564565

565566
# evaluate all the chunked arrays simultaneously
566567
evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
567-
*lazy_data.values(), **kwargs
568+
*chunked_data.values(), **kwargs
568569
)
569570

570-
for k, data in zip(lazy_data, evaluated_data, strict=False):
571+
for k, data in zip(chunked_data, evaluated_data, strict=False):
571572
self.variables[k].data = data
572573

573-
# load everything else sequentially
574-
for k, v in self.variables.items():
575-
if k not in lazy_data:
576-
await v.load_async()
574+
# load everything else concurrently
575+
tasks = [
576+
v.load_async() for k, v in self.variables.items() if k not in chunked_data
577+
]
578+
await asyncio.gather(*tasks)
577579

578580
return self
579581

0 commit comments

Comments
 (0)