Skip to content

Commit d288351

Browse files
committed
factor out common logic
1 parent 7e9ae0f commit d288351

File tree

1 file changed

+42
-37
lines changed

1 file changed

+42
-37
lines changed

xarray/tests/test_async.py

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import time
33
from collections.abc import Iterable
4+
from contextlib import asynccontextmanager
45
from typing import TypeVar
56

67
import numpy as np
@@ -98,73 +99,77 @@ def memorystore() -> "MemoryStore":
9899
return memorystore
99100

100101

102+
class AsyncTimer:
103+
"""Context manager for timing async operations and making assertions about their execution time."""
104+
105+
start_time: float
106+
end_time: float
107+
total_time: float
108+
109+
@asynccontextmanager
110+
async def measure(self):
111+
"""Measure the execution time of the async code within this context."""
112+
self.start_time = time.time()
113+
try:
114+
yield self
115+
finally:
116+
self.end_time = time.time()
117+
self.total_time = self.end_time - self.start_time
118+
119+
101120
@requires_zarr_v3
102121
@pytest.mark.asyncio
103122
class TestAsyncLoad:
104-
N_LOADS = 10
105-
LATENCY = 1.0
123+
N_LOADS: int = 5
124+
LATENCY: float = 1.0
125+
126+
def assert_time_as_expected(self, total_time: float) -> None:
127+
assert total_time > self.LATENCY # Cannot possibly be quicker than this
128+
assert (
129+
total_time < self.LATENCY * self.N_LOADS
130+
) # If this isn't true we're gaining nothing from async
131+
assert (
132+
abs(total_time - self.LATENCY) < 0.5
133+
) # Should take approximately LATENCY seconds, but allow some buffer
106134

107-
# TODO refactor these tests
108135
async def test_async_load_variable(self, memorystore):
109136
latencystore = LatencyStore(memorystore, latency=self.LATENCY)
110137
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
111138

112-
# TODO add async load to Dataset and DataArray as well as to Variable
113139
# TODO change the syntax to `.async.load()`?
114-
start_time = time.time()
115-
tasks = [ds["foo"].variable.async_load() for _ in range(self.N_LOADS)]
116-
results = await asyncio.gather(*tasks)
117-
total_time = time.time() - start_time
140+
async with AsyncTimer().measure() as timer:
141+
tasks = [ds["foo"].variable.async_load() for _ in range(self.N_LOADS)]
142+
results = await asyncio.gather(*tasks)
118143

119144
for result in results:
120145
xrt.assert_identical(result, ds["foo"].variable.load())
121146

122-
assert total_time > self.LATENCY # Cannot possibly be quicker than this
123-
assert (
124-
total_time < self.LATENCY * self.N_LOADS
125-
) # If this isn't true we're gaining nothing from async
126-
assert (
127-
abs(total_time - self.LATENCY) < 0.5
128-
) # Should take approximately LATENCY seconds, but allow some buffer
147+
self.assert_time_as_expected(timer.total_time)
129148

130149
async def test_async_load_dataarray(self, memorystore):
131150
latencystore = LatencyStore(memorystore, latency=self.LATENCY)
132151
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
133152

134153
# TODO change the syntax to `.async.load()`?
135-
start_time = time.time()
136-
tasks = [ds["foo"].async_load() for _ in range(self.N_LOADS)]
137-
results = await asyncio.gather(*tasks)
138-
total_time = time.time() - start_time
154+
async with AsyncTimer().measure() as timer:
155+
tasks = [ds["foo"].async_load() for _ in range(self.N_LOADS)]
156+
results = await asyncio.gather(*tasks)
139157

140158
for result in results:
141159
xrt.assert_identical(result, ds["foo"].load())
142160

143-
assert total_time > self.LATENCY # Cannot possibly be quicker than this
144-
assert (
145-
total_time < self.LATENCY * self.N_LOADS
146-
) # If this isn't true we're gaining nothing from async
147-
assert (
148-
abs(total_time - self.LATENCY) < 0.5
149-
) # Should take approximately LATENCY seconds, but allow some buffer
161+
self.assert_time_as_expected(timer.total_time)
150162

151163
async def test_async_load_dataset(self, memorystore):
152164
latencystore = LatencyStore(memorystore, latency=self.LATENCY)
153165
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
154166

155167
# TODO change the syntax to `.async.load()`?
156-
start_time = time.time()
157-
tasks = [ds.async_load() for _ in range(self.N_LOADS)]
158-
results = await asyncio.gather(*tasks)
159-
total_time = time.time() - start_time
168+
async with AsyncTimer().measure() as timer:
169+
tasks = [ds.async_load() for _ in range(self.N_LOADS)]
170+
results = await asyncio.gather(*tasks)
160171

161172
for result in results:
162173
xrt.assert_identical(result, ds.load())
163174

164-
assert total_time > self.LATENCY # Cannot possibly be quicker than this
165-
assert (
166-
total_time < self.LATENCY * self.N_LOADS
167-
) # If this isn't true we're gaining nothing from async
168-
assert (
169-
abs(total_time - self.LATENCY) < 0.5
170-
) # Should take approximately LATENCY seconds, but allow some buffer
175+
self.assert_time_as_expected(timer.total_time)

0 commit comments

Comments
 (0)