Skip to content

Commit 7e9ae0f

Browse files
committed
implement async load for dataarray and dataset
1 parent 629ab31 commit 7e9ae0f

File tree

3 files changed

+84
-8
lines changed

3 files changed

+84
-8
lines changed

xarray/core/dataarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,14 @@ def load(self, **kwargs) -> Self:
11601160
self._coords = new._coords
11611161
return self
11621162

1163+
async def async_load(self, **kwargs) -> Self:
1164+
temp_ds = self._to_temp_dataset()
1165+
ds = await temp_ds.async_load(**kwargs)
1166+
new = self._from_temp_dataset(ds)
1167+
self._variable = new._variable
1168+
self._coords = new._coords
1169+
return self
1170+
11631171
def compute(self, **kwargs) -> Self:
11641172
"""Manually trigger loading of this array's data from disk or a
11651173
remote source into memory and return a new array.

xarray/core/dataset.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,31 @@ def load(self, **kwargs) -> Self:
552552

553553
return self
554554

555+
async def async_load(self, **kwargs) -> Self:
556+
# this blocks on chunked arrays but not on lazily indexed arrays
557+
558+
# access .data to coerce everything to numpy or dask arrays
559+
lazy_data = {
560+
k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
561+
}
562+
if lazy_data:
563+
chunkmanager = get_chunked_array_type(*lazy_data.values())
564+
565+
# evaluate all the chunked arrays simultaneously
566+
evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
567+
*lazy_data.values(), **kwargs
568+
)
569+
570+
for k, data in zip(lazy_data, evaluated_data, strict=False):
571+
self.variables[k].data = data
572+
573+
# load everything else sequentially
574+
for k, v in self.variables.items():
575+
if k not in lazy_data:
576+
await v.async_load()
577+
578+
return self
579+
555580
def __dask_tokenize__(self) -> object:
556581
from dask.base import normalize_token
557582

xarray/tests/test_async.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,27 +101,70 @@ def memorystore() -> "MemoryStore":
101101
@requires_zarr_v3
102102
@pytest.mark.asyncio
103103
class TestAsyncLoad:
104-
async def test_async_load_variable(self, memorystore):
105-
N_LOADS = 5
106-
LATENCY = 1.0
104+
N_LOADS = 10
105+
LATENCY = 1.0
107106

108-
latencystore = LatencyStore(memorystore, latency=LATENCY)
107+
# TODO refactor these tests
108+
async def test_async_load_variable(self, memorystore):
109+
latencystore = LatencyStore(memorystore, latency=self.LATENCY)
109110
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
110111

111112
# TODO add async load to Dataset and DataArray as well as to Variable
112113
# TODO change the syntax to `.async.load()`?
113114
start_time = time.time()
114-
tasks = [ds["foo"].variable.async_load() for _ in range(N_LOADS)]
115+
tasks = [ds["foo"].variable.async_load() for _ in range(self.N_LOADS)]
115116
results = await asyncio.gather(*tasks)
116117
total_time = time.time() - start_time
117118

118119
for result in results:
119120
xrt.assert_identical(result, ds["foo"].variable.load())
120121

121-
assert total_time > LATENCY # Cannot possibly be quicker than this
122+
assert total_time > self.LATENCY # Cannot possibly be quicker than this
123+
assert (
124+
total_time < self.LATENCY * self.N_LOADS
125+
) # If this isn't true we're gaining nothing from async
126+
assert (
127+
abs(total_time - self.LATENCY) < 0.5
128+
) # Should take approximately LATENCY seconds, but allow some buffer
129+
130+
async def test_async_load_dataarray(self, memorystore):
131+
latencystore = LatencyStore(memorystore, latency=self.LATENCY)
132+
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
133+
134+
# TODO change the syntax to `.async.load()`?
135+
start_time = time.time()
136+
tasks = [ds["foo"].async_load() for _ in range(self.N_LOADS)]
137+
results = await asyncio.gather(*tasks)
138+
total_time = time.time() - start_time
139+
140+
for result in results:
141+
xrt.assert_identical(result, ds["foo"].load())
142+
143+
assert total_time > self.LATENCY # Cannot possibly be quicker than this
144+
assert (
145+
total_time < self.LATENCY * self.N_LOADS
146+
) # If this isn't true we're gaining nothing from async
147+
assert (
148+
abs(total_time - self.LATENCY) < 0.5
149+
) # Should take approximately LATENCY seconds, but allow some buffer
150+
151+
async def test_async_load_dataset(self, memorystore):
152+
latencystore = LatencyStore(memorystore, latency=self.LATENCY)
153+
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
154+
155+
# TODO change the syntax to `.async.load()`?
156+
start_time = time.time()
157+
tasks = [ds.async_load() for _ in range(self.N_LOADS)]
158+
results = await asyncio.gather(*tasks)
159+
total_time = time.time() - start_time
160+
161+
for result in results:
162+
xrt.assert_identical(result, ds.load())
163+
164+
assert total_time > self.LATENCY # Cannot possibly be quicker than this
122165
assert (
123-
total_time < LATENCY * N_LOADS
166+
total_time < self.LATENCY * self.N_LOADS
124167
) # If this isn't true we're gaining nothing from async
125168
assert (
126-
abs(total_time - LATENCY) < 0.5
169+
abs(total_time - self.LATENCY) < 0.5
127170
) # Should take approximately LATENCY seconds, but allow some buffer

0 commit comments

Comments
 (0)