|
8 | 8 | Any,
|
9 | 9 | Awaitable,
|
10 | 10 | Callable,
|
11 |
| - Coroutine, |
12 | 11 | Dict,
|
13 | 12 | List,
|
14 | 13 | Optional,
|
15 |
| - Union, |
16 | 14 | )
|
17 | 15 |
|
18 | 16 | from graphql import GraphQLError, GraphQLSyntaxError, parse
|
@@ -105,7 +103,7 @@ def on_request_accepted(self) -> None:
|
105 | 103 |
|
106 | 104 | async def handle_connection_init_timeout(self) -> None:
|
107 | 105 | task = asyncio.current_task()
|
108 |
| - assert task is not None # for typecheckers |
| 106 | + assert task |
109 | 107 | try:
|
110 | 108 | delay = self.connection_init_wait_timeout.total_seconds()
|
111 | 109 | await asyncio.sleep(delay=delay)
|
@@ -266,13 +264,12 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
|
266 | 264 | operation_name=message.payload.operationName,
|
267 | 265 | )
|
268 | 266 |
|
269 |
| - # create AsyncGenerator returning a single result |
270 |
| - async def single_result() -> AsyncIterator[ExecutionResult]: |
271 |
| - yield result # type: ignore |
| 267 | + operation = Operation(self, message.id, operation_type) |
272 | 268 |
|
273 | 269 | # Create task to handle this subscription, reserve the operation ID
|
274 |
| - operation = Operation(self, message.id, operation_type, start_operation) |
275 |
| - operation.task = asyncio.create_task(self.operation_task(operation)) |
| 270 | + operation.task = asyncio.create_task( |
| 271 | + self.operation_task(result_source, operation) |
| 272 | + ) |
276 | 273 | self.operations[message.id] = operation
|
277 | 274 |
|
278 | 275 | async def operation_task(
|
@@ -302,11 +299,9 @@ async def operation_task(
|
302 | 299 | self.operations.pop(operation.id, None)
|
303 | 300 | raise
|
304 | 301 | finally:
|
305 |
| - # Clenaup. Remove the operation from the list of active operations |
306 |
| - if operation.id in self.operations: |
307 |
| - del self.operations[operation.id] |
308 |
| - # TODO: Stop collecting background tasks, not necessary. |
309 |
| - # Add this task to a list to be reaped later |
| 302 | + # add this task to a list to be reaped later |
| 303 | + task = asyncio.current_task() |
| 304 | + assert task is not None |
310 | 305 | self.completed_tasks.append(task)
|
311 | 306 |
|
312 | 307 | def forget_id(self, id: str) -> None:
|
@@ -344,35 +339,23 @@ async def reap_completed_tasks(self) -> None:
|
344 | 339 | class Operation:
|
345 | 340 | """A class encapsulating a single operation with its id. Helps enforce protocol state transition."""
|
346 | 341 |
|
347 |
| - __slots__ = [ |
348 |
| - "handler", |
349 |
| - "id", |
350 |
| - "operation_type", |
351 |
| - "start_operation", |
352 |
| - "completed", |
353 |
| - "task", |
354 |
| - ] |
| 342 | + __slots__ = ["handler", "id", "operation_type", "completed", "task"] |
355 | 343 |
|
356 | 344 | def __init__(
|
357 | 345 | self,
|
358 | 346 | handler: BaseGraphQLTransportWSHandler,
|
359 | 347 | id: str,
|
360 | 348 | operation_type: OperationType,
|
361 |
| - start_operation: Callable[ |
362 |
| - [], Coroutine[Any, Any, Union[Any, AsyncGenerator[Any, None]]] |
363 |
| - ], |
364 | 349 | ) -> None:
|
365 | 350 | self.handler = handler
|
366 | 351 | self.id = id
|
367 | 352 | self.operation_type = operation_type
|
368 |
| - self.start_operation = start_operation |
369 | 353 | self.completed = False
|
370 | 354 | self.task: Optional[asyncio.Task] = None
|
371 | 355 |
|
372 | 356 | async def send_message(self, message: GraphQLTransportMessage) -> None:
|
373 |
| - # defensive check, should never happen |
374 | 357 | if self.completed:
|
375 |
| - return # pragma: no cover |
| 358 | + return |
376 | 359 | if isinstance(message, (CompleteMessage, ErrorMessage)):
|
377 | 360 | self.completed = True
|
378 | 361 | # de-register the operation _before_ sending the final message
|
|
0 commit comments