@@ -153,7 +153,7 @@ def to_kelvin(value: float) -> int:
153
153
FuncT = TypeVar ("FuncT" , bound = Callable [..., Any ])
154
154
155
155
156
- def _websocket_exception_handler (request_fn : FuncT ) -> FuncT :
156
+ def _websocket_retry_wrapper (request_fn : FuncT ) -> FuncT :
157
157
retry_on_exceptions = (
158
158
websockets .InvalidHandshake ,
159
159
websockets .InvalidState ,
@@ -260,18 +260,10 @@ def _encode_pair(
260
260
261
261
return address , raw_value
262
262
263
- @_websocket_exception_handler
264
263
async def _websocket_request (self , payload : bytes ) -> bytes :
265
- async with websockets .connect (
266
- f"ws://{ self .ip_address } /" ,
267
- open_timeout = WEBSOCKETS_OPEN_TIMEOUT ,
268
- logger = logger ,
269
- ) as ws :
270
- await ws .send (payload )
271
- task = asyncio .create_task (ws .recv ())
272
- return await asyncio .wait_for (task , timeout = WEBSOCKETS_RECV_TIMEOUT )
264
+ return (await self ._websocket_request_multiple (payload , 1 ))[0 ]
273
265
274
- @_websocket_exception_handler
266
+ @_websocket_retry_wrapper
275
267
async def _websocket_request_multiple (
276
268
self , payload : bytes , read_packets : int
277
269
) -> List [bytes ]:
@@ -281,8 +273,13 @@ async def _websocket_request_multiple(
281
273
logger = logger ,
282
274
) as ws :
283
275
await ws .send (payload )
284
- tasks = asyncio .gather (* [ws .recv () for _ in range (0 , read_packets )])
285
- return await asyncio .wait_for (tasks , timeout = WEBSOCKETS_RECV_TIMEOUT )
276
+
277
+ async def _get_responses () -> List [bytes ]:
278
+ return [await ws .recv () for _ in range (0 , read_packets )]
279
+
280
+ return await asyncio .wait_for (
281
+ _get_responses (), timeout = WEBSOCKETS_RECV_TIMEOUT * read_packets
282
+ )
286
283
287
284
async def fetch_metrics (
288
285
self , metric_keys : Optional [List [str ]] = None
0 commit comments