@@ -123,53 +123,35 @@ class TestAsyncLoad:
123
123
N_LOADS : int = 5
124
124
LATENCY : float = 1.0
125
125
126
+ @pytest .fixture (params = ["ds" , "da" , "var" ])
127
+ def xr_obj (self , request , memorystore ) -> xr .Dataset | xr .DataArray | xr .Variable :
128
+ latencystore = LatencyStore (memorystore , latency = self .LATENCY )
129
+ ds = xr .open_zarr (latencystore , zarr_format = 3 , consolidated = False , chunks = None )
130
+
131
+ match request .param :
132
+ case "var" :
133
+ return ds ["foo" ].variable
134
+ case "da" :
135
+ return ds ["foo" ]
136
+ case "ds" :
137
+ return ds
138
+
126
139
def assert_time_as_expected (self , total_time : float ) -> None :
127
140
assert total_time > self .LATENCY # Cannot possibly be quicker than this
128
141
assert (
129
142
total_time < self .LATENCY * self .N_LOADS
130
143
) # If this isn't true we're gaining nothing from async
131
144
assert (
132
- abs (total_time - self .LATENCY ) < 0.5
145
+ abs (total_time - self .LATENCY ) < 2.0
133
146
) # Should take approximately LATENCY seconds, but allow some buffer
134
147
135
- async def test_async_load_variable (self , memorystore ):
136
- latencystore = LatencyStore (memorystore , latency = self .LATENCY )
137
- ds = xr .open_zarr (latencystore , zarr_format = 3 , consolidated = False , chunks = None )
138
-
139
- # TODO change the syntax to `.async.load()`?
140
- async with AsyncTimer ().measure () as timer :
141
- tasks = [ds ["foo" ].variable .async_load () for _ in range (self .N_LOADS )]
142
- results = await asyncio .gather (* tasks )
143
-
144
- for result in results :
145
- xrt .assert_identical (result , ds ["foo" ].variable .load ())
146
-
147
- self .assert_time_as_expected (timer .total_time )
148
-
149
- async def test_async_load_dataarray (self , memorystore ):
150
- latencystore = LatencyStore (memorystore , latency = self .LATENCY )
151
- ds = xr .open_zarr (latencystore , zarr_format = 3 , consolidated = False , chunks = None )
152
-
153
- # TODO change the syntax to `.async.load()`?
154
- async with AsyncTimer ().measure () as timer :
155
- tasks = [ds ["foo" ].async_load () for _ in range (self .N_LOADS )]
156
- results = await asyncio .gather (* tasks )
157
-
158
- for result in results :
159
- xrt .assert_identical (result , ds ["foo" ].load ())
160
-
161
- self .assert_time_as_expected (timer .total_time )
162
-
163
- async def test_async_load_dataset (self , memorystore ):
164
- latencystore = LatencyStore (memorystore , latency = self .LATENCY )
165
- ds = xr .open_zarr (latencystore , zarr_format = 3 , consolidated = False , chunks = None )
166
-
148
+ async def test_async_load (self , xr_obj ):
167
149
# TODO change the syntax to `.async.load()`?
168
150
async with AsyncTimer ().measure () as timer :
169
- tasks = [ds .async_load () for _ in range (self .N_LOADS )]
151
+ tasks = [xr_obj .async_load () for _ in range (self .N_LOADS )]
170
152
results = await asyncio .gather (* tasks )
171
153
172
154
for result in results :
173
- xrt .assert_identical (result , ds .load ())
155
+ xrt .assert_identical (result , xr_obj .load ())
174
156
175
157
self .assert_time_as_expected (timer .total_time )
0 commit comments