Skip to content

Commit 1be5a06

Browse files
committed
revert the handler to original state
1 parent 3d97deb commit 1be5a06

File tree

1 file changed

+10
-27
lines changed
  • strawberry/subscriptions/protocols/graphql_transport_ws

1 file changed

+10
-27
lines changed

strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@
88
Any,
99
Awaitable,
1010
Callable,
11-
Coroutine,
1211
Dict,
1312
List,
1413
Optional,
15-
Union,
1614
)
1715

1816
from graphql import GraphQLError, GraphQLSyntaxError, parse
@@ -105,7 +103,7 @@ def on_request_accepted(self) -> None:
105103

106104
async def handle_connection_init_timeout(self) -> None:
107105
task = asyncio.current_task()
108-
assert task is not None # for typecheckers
106+
assert task
109107
try:
110108
delay = self.connection_init_wait_timeout.total_seconds()
111109
await asyncio.sleep(delay=delay)
@@ -266,13 +264,12 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
266264
operation_name=message.payload.operationName,
267265
)
268266

269-
# create AsyncGenerator returning a single result
270-
async def single_result() -> AsyncIterator[ExecutionResult]:
271-
yield result # type: ignore
267+
operation = Operation(self, message.id, operation_type)
272268

273269
# Create task to handle this subscription, reserve the operation ID
274-
operation = Operation(self, message.id, operation_type, start_operation)
275-
operation.task = asyncio.create_task(self.operation_task(operation))
270+
operation.task = asyncio.create_task(
271+
self.operation_task(result_source, operation)
272+
)
276273
self.operations[message.id] = operation
277274

278275
async def operation_task(
@@ -302,11 +299,9 @@ async def operation_task(
302299
self.operations.pop(operation.id, None)
303300
raise
304301
finally:
305-
# Clenaup. Remove the operation from the list of active operations
306-
if operation.id in self.operations:
307-
del self.operations[operation.id]
308-
# TODO: Stop collecting background tasks, not necessary.
309-
# Add this task to a list to be reaped later
302+
# add this task to a list to be reaped later
303+
task = asyncio.current_task()
304+
assert task is not None
310305
self.completed_tasks.append(task)
311306

312307
def forget_id(self, id: str) -> None:
@@ -344,35 +339,23 @@ async def reap_completed_tasks(self) -> None:
344339
class Operation:
345340
"""A class encapsulating a single operation with its id. Helps enforce protocol state transition."""
346341

347-
__slots__ = [
348-
"handler",
349-
"id",
350-
"operation_type",
351-
"start_operation",
352-
"completed",
353-
"task",
354-
]
342+
__slots__ = ["handler", "id", "operation_type", "completed", "task"]
355343

356344
def __init__(
357345
self,
358346
handler: BaseGraphQLTransportWSHandler,
359347
id: str,
360348
operation_type: OperationType,
361-
start_operation: Callable[
362-
[], Coroutine[Any, Any, Union[Any, AsyncGenerator[Any, None]]]
363-
],
364349
) -> None:
365350
self.handler = handler
366351
self.id = id
367352
self.operation_type = operation_type
368-
self.start_operation = start_operation
369353
self.completed = False
370354
self.task: Optional[asyncio.Task] = None
371355

372356
async def send_message(self, message: GraphQLTransportMessage) -> None:
373-
# defensive check, should never happen
374357
if self.completed:
375-
return # pragma: no cover
358+
return
376359
if isinstance(message, (CompleteMessage, ErrorMessage)):
377360
self.completed = True
378361
# de-register the operation _before_ sending the final message

0 commit comments

Comments
 (0)