|
1 | 1 | import asyncio
|
2 | 2 | import time
|
3 | 3 | from collections.abc import Iterable
|
| 4 | +from contextlib import asynccontextmanager |
4 | 5 | from typing import TypeVar
|
5 | 6 |
|
6 | 7 | import numpy as np
|
@@ -98,73 +99,77 @@ def memorystore() -> "MemoryStore":
|
98 | 99 | return memorystore
|
99 | 100 |
|
100 | 101 |
|
| 102 | +class AsyncTimer: |
| 103 | + """Context manager for timing async operations and making assertions about their execution time.""" |
| 104 | + |
| 105 | + start_time: float |
| 106 | + end_time: float |
| 107 | + total_time: float |
| 108 | + |
| 109 | + @asynccontextmanager |
| 110 | + async def measure(self): |
| 111 | + """Measure the execution time of the async code within this context.""" |
| 112 | + self.start_time = time.time() |
| 113 | + try: |
| 114 | + yield self |
| 115 | + finally: |
| 116 | + self.end_time = time.time() |
| 117 | + self.total_time = self.end_time - self.start_time |
| 118 | + |
| 119 | + |
101 | 120 | @requires_zarr_v3
|
102 | 121 | @pytest.mark.asyncio
|
103 | 122 | class TestAsyncLoad:
|
104 |
| - N_LOADS = 10 |
105 |
| - LATENCY = 1.0 |
| 123 | + N_LOADS: int = 5 |
| 124 | + LATENCY: float = 1.0 |
| 125 | + |
| 126 | + def assert_time_as_expected(self, total_time: float) -> None: |
| 127 | + assert total_time > self.LATENCY # Cannot possibly be quicker than this |
| 128 | + assert ( |
| 129 | + total_time < self.LATENCY * self.N_LOADS |
| 130 | + ) # If this isn't true we're gaining nothing from async |
| 131 | + assert ( |
| 132 | + abs(total_time - self.LATENCY) < 0.5 |
| 133 | + ) # Should take approximately LATENCY seconds, but allow some buffer |
106 | 134 |
|
107 |
| - # TODO refactor these tests |
108 | 135 | async def test_async_load_variable(self, memorystore):
|
109 | 136 | latencystore = LatencyStore(memorystore, latency=self.LATENCY)
|
110 | 137 | ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
|
111 | 138 |
|
112 |
| - # TODO add async load to Dataset and DataArray as well as to Variable |
113 | 139 | # TODO change the syntax to `.async.load()`?
|
114 |
| - start_time = time.time() |
115 |
| - tasks = [ds["foo"].variable.async_load() for _ in range(self.N_LOADS)] |
116 |
| - results = await asyncio.gather(*tasks) |
117 |
| - total_time = time.time() - start_time |
| 140 | + async with AsyncTimer().measure() as timer: |
| 141 | + tasks = [ds["foo"].variable.async_load() for _ in range(self.N_LOADS)] |
| 142 | + results = await asyncio.gather(*tasks) |
118 | 143 |
|
119 | 144 | for result in results:
|
120 | 145 | xrt.assert_identical(result, ds["foo"].variable.load())
|
121 | 146 |
|
122 |
| - assert total_time > self.LATENCY # Cannot possibly be quicker than this |
123 |
| - assert ( |
124 |
| - total_time < self.LATENCY * self.N_LOADS |
125 |
| - ) # If this isn't true we're gaining nothing from async |
126 |
| - assert ( |
127 |
| - abs(total_time - self.LATENCY) < 0.5 |
128 |
| - ) # Should take approximately LATENCY seconds, but allow some buffer |
| 147 | + self.assert_time_as_expected(timer.total_time) |
129 | 148 |
|
130 | 149 | async def test_async_load_dataarray(self, memorystore):
|
131 | 150 | latencystore = LatencyStore(memorystore, latency=self.LATENCY)
|
132 | 151 | ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
|
133 | 152 |
|
134 | 153 | # TODO change the syntax to `.async.load()`?
|
135 |
| - start_time = time.time() |
136 |
| - tasks = [ds["foo"].async_load() for _ in range(self.N_LOADS)] |
137 |
| - results = await asyncio.gather(*tasks) |
138 |
| - total_time = time.time() - start_time |
| 154 | + async with AsyncTimer().measure() as timer: |
| 155 | + tasks = [ds["foo"].async_load() for _ in range(self.N_LOADS)] |
| 156 | + results = await asyncio.gather(*tasks) |
139 | 157 |
|
140 | 158 | for result in results:
|
141 | 159 | xrt.assert_identical(result, ds["foo"].load())
|
142 | 160 |
|
143 |
| - assert total_time > self.LATENCY # Cannot possibly be quicker than this |
144 |
| - assert ( |
145 |
| - total_time < self.LATENCY * self.N_LOADS |
146 |
| - ) # If this isn't true we're gaining nothing from async |
147 |
| - assert ( |
148 |
| - abs(total_time - self.LATENCY) < 0.5 |
149 |
| - ) # Should take approximately LATENCY seconds, but allow some buffer |
| 161 | + self.assert_time_as_expected(timer.total_time) |
150 | 162 |
|
151 | 163 | async def test_async_load_dataset(self, memorystore):
|
152 | 164 | latencystore = LatencyStore(memorystore, latency=self.LATENCY)
|
153 | 165 | ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
|
154 | 166 |
|
155 | 167 | # TODO change the syntax to `.async.load()`?
|
156 |
| - start_time = time.time() |
157 |
| - tasks = [ds.async_load() for _ in range(self.N_LOADS)] |
158 |
| - results = await asyncio.gather(*tasks) |
159 |
| - total_time = time.time() - start_time |
| 168 | + async with AsyncTimer().measure() as timer: |
| 169 | + tasks = [ds.async_load() for _ in range(self.N_LOADS)] |
| 170 | + results = await asyncio.gather(*tasks) |
160 | 171 |
|
161 | 172 | for result in results:
|
162 | 173 | xrt.assert_identical(result, ds.load())
|
163 | 174 |
|
164 |
| - assert total_time > self.LATENCY # Cannot possibly be quicker than this |
165 |
| - assert ( |
166 |
| - total_time < self.LATENCY * self.N_LOADS |
167 |
| - ) # If this isn't true we're gaining nothing from async |
168 |
| - assert ( |
169 |
| - abs(total_time - self.LATENCY) < 0.5 |
170 |
| - ) # Should take approximately LATENCY seconds, but allow some buffer |
| 175 | + self.assert_time_as_expected(timer.total_time) |
0 commit comments