Skip to content

Commit e0731a0

Browse files
committed
consolidate tests via a parametrized fixture
1 parent d288351 commit e0731a0

File tree

1 file changed

+17
-35
lines changed

1 file changed

+17
-35
lines changed

xarray/tests/test_async.py

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -123,53 +123,35 @@ class TestAsyncLoad:
123123
N_LOADS: int = 5
124124
LATENCY: float = 1.0
125125

126+
@pytest.fixture(params=["ds", "da", "var"])
127+
def xr_obj(self, request, memorystore) -> xr.Dataset | xr.DataArray | xr.Variable:
128+
latencystore = LatencyStore(memorystore, latency=self.LATENCY)
129+
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
130+
131+
match request.param:
132+
case "var":
133+
return ds["foo"].variable
134+
case "da":
135+
return ds["foo"]
136+
case "ds":
137+
return ds
138+
126139
def assert_time_as_expected(self, total_time: float) -> None:
127140
assert total_time > self.LATENCY # Cannot possibly be quicker than this
128141
assert (
129142
total_time < self.LATENCY * self.N_LOADS
130143
) # If this isn't true we're gaining nothing from async
131144
assert (
132-
abs(total_time - self.LATENCY) < 0.5
145+
abs(total_time - self.LATENCY) < 2.0
133146
) # Should take approximately LATENCY seconds, but allow some buffer
134147

135-
async def test_async_load_variable(self, memorystore):
136-
latencystore = LatencyStore(memorystore, latency=self.LATENCY)
137-
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
138-
139-
# TODO change the syntax to `.async.load()`?
140-
async with AsyncTimer().measure() as timer:
141-
tasks = [ds["foo"].variable.async_load() for _ in range(self.N_LOADS)]
142-
results = await asyncio.gather(*tasks)
143-
144-
for result in results:
145-
xrt.assert_identical(result, ds["foo"].variable.load())
146-
147-
self.assert_time_as_expected(timer.total_time)
148-
149-
async def test_async_load_dataarray(self, memorystore):
150-
latencystore = LatencyStore(memorystore, latency=self.LATENCY)
151-
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
152-
153-
# TODO change the syntax to `.async.load()`?
154-
async with AsyncTimer().measure() as timer:
155-
tasks = [ds["foo"].async_load() for _ in range(self.N_LOADS)]
156-
results = await asyncio.gather(*tasks)
157-
158-
for result in results:
159-
xrt.assert_identical(result, ds["foo"].load())
160-
161-
self.assert_time_as_expected(timer.total_time)
162-
163-
async def test_async_load_dataset(self, memorystore):
164-
latencystore = LatencyStore(memorystore, latency=self.LATENCY)
165-
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
166-
148+
async def test_async_load(self, xr_obj):
167149
# TODO change the syntax to `.async.load()`?
168150
async with AsyncTimer().measure() as timer:
169-
tasks = [ds.async_load() for _ in range(self.N_LOADS)]
151+
tasks = [xr_obj.async_load() for _ in range(self.N_LOADS)]
170152
results = await asyncio.gather(*tasks)
171153

172154
for result in results:
173-
xrt.assert_identical(result, ds.load())
155+
xrt.assert_identical(result, xr_obj.load())
174156

175157
self.assert_time_as_expected(timer.total_time)

0 commit comments

Comments
 (0)