Skip to content

Commit 875b0f0

Browse files
committed
feat: support coroutine checks in bot,wait_for
Without breaking!
1 parent d4738f1 commit 875b0f0

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

interactions/client/client.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
Sequence,
2323
Type,
2424
Union,
25+
Awaitable,
2526
)
2627

2728
import interactions.api.events as events
@@ -935,6 +936,17 @@ async def stop(self) -> None:
935936
await self.http.close()
936937
await self._connection_state.stop()
937938

939+
async def _process_waits(self, event: events.BaseEvent) -> None:
940+
if _waits := self.waits.get(event.resolved_name, []):
941+
index_to_remove = []
942+
for i, _wait in enumerate(_waits):
943+
result = await _wait(event)
944+
if result:
945+
index_to_remove.append(i)
946+
947+
for idx in sorted(index_to_remove, reverse=True):
948+
_waits.pop(idx)
949+
938950
def dispatch(self, event: events.BaseEvent, *args, **kwargs) -> None:
939951
"""
940952
Dispatch an event.
@@ -954,15 +966,7 @@ def dispatch(self, event: events.BaseEvent, *args, **kwargs) -> None:
954966
f"An error occurred attempting during {event.resolved_name} event processing"
955967
) from e
956968

957-
if _waits := self.waits.get(event.resolved_name, []):
958-
index_to_remove = []
959-
for i, _wait in enumerate(_waits):
960-
result = _wait(event)
961-
if result:
962-
index_to_remove.append(i)
963-
964-
for idx in sorted(index_to_remove, reverse=True):
965-
_waits.pop(idx)
969+
asyncio.create_task(self._process_waits(event))
966970

967971
if "event" in self.listeners:
968972
# special meta event listener
@@ -976,7 +980,7 @@ async def wait_until_ready(self) -> None:
976980
def wait_for(
977981
self,
978982
event: Union[str, "BaseEvent"],
979-
checks: Absent[Optional[Callable[..., bool]]] = MISSING,
983+
checks: Absent[Optional[Union[Callable[..., bool], Callable[..., Awaitable[bool]]]]] = MISSING,
980984
timeout: Optional[float] = None,
981985
) -> Any:
982986
"""
@@ -989,7 +993,6 @@ def wait_for(
989993
990994
Returns:
991995
The event object.
992-
993996
"""
994997
event = get_event_name(event)
995998

interactions/models/internal/wait.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
1-
from asyncio import Future
2-
from typing import Callable, Optional
1+
from asyncio import Future, iscoroutinefunction
2+
from typing import Callable, Optional, Union, Awaitable
33

44
__all__ = ("Wait",)
55

66

77
class Wait:
88
"""Class for waiting for a future event to happen. Internally used by wait_for."""
99

10-
def __init__(self, event: str, checks: Optional[Callable[..., bool]], future: Future) -> None:
11-
self.event = event
12-
self.checks = checks
13-
self.future = future
10+
def __init__(
11+
self, event: str, checks: Optional[Union[Callable[..., bool], Callable[..., Awaitable[bool]]]], future: Future
12+
) -> None:
13+
self.event: str = event
14+
self.check: Optional[Union[Callable[..., bool], Callable[..., Awaitable[bool]]]] = checks
15+
self.future: Future = future
1416

15-
def __call__(self, *args, **kwargs) -> bool:
17+
async def __call__(self, *args, **kwargs) -> bool:
1618
if self.future.cancelled():
1719
return True
1820

19-
if self.checks:
21+
if self.check:
2022
try:
21-
check_result = self.checks(*args, **kwargs)
23+
if iscoroutinefunction(self.check):
24+
check_result = await self.check(*args, **kwargs)
25+
else:
26+
check_result = self.check(*args, **kwargs)
2227
except Exception as exc:
2328
self.future.set_exception(exc)
2429
return True

0 commit comments

Comments
 (0)