Skip to content

Commit 497c7a5

Browse files
authored
fix: autodefer with args and kwargs in commands (#1074)
* fix: args and kwargs in commands * fix: autodefer errors
1 parent 7d5a369 commit 497c7a5

File tree

2 files changed

+77
-33
lines changed

2 files changed

+77
-33
lines changed

interactions/client/models/command.py

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -816,59 +816,83 @@ async def command_error(ctx, error):
816816
message=f"Your command needs at least {'three parameters to return self, context, and the' if self.extension else 'two parameter to return context and'} error.",
817817
)
818818

819-
self.error_callback = self.__wrap_coro(coro)
819+
self.error_callback = self.__wrap_coro(coro, error_callback=True)
820820
return coro
821821

822822
async def __call(
823823
self,
824824
coro: Callable[..., Awaitable],
825825
ctx: "CommandContext",
826-
*args,
826+
*args, # empty for now since all parameters are dispatched as kwargs
827827
_name: Optional[str] = None,
828828
_res: Optional[Union[BaseResult, GroupResult]] = None,
829829
**kwargs,
830830
) -> Optional[Any]:
831831
"""Handles calling the coroutine based on parameter count."""
832-
param_len = len(signature(coro).parameters)
833-
opt_len = self.num_options.get(_name, len(args) + len(kwargs))
832+
params = signature(coro).parameters
833+
param_len = len(params)
834+
opt_len = self.num_options.get(_name, len(args) + len(kwargs)) # options of slash command
835+
last = params[list(params)[-1]] # last parameter
836+
has_args = any(param.kind == param.VAR_POSITIONAL for param in params.values()) # any *args
837+
index_of_var_pos = next(
838+
(i for i, param in enumerate(params.values()) if param.kind == param.VAR_POSITIONAL),
839+
param_len,
840+
) # index of *args
841+
par_opts = list(params.keys())[
842+
(num := 2 if self.extension else 1) : (
843+
-1 if last.kind in (last.VAR_POSITIONAL, last.VAR_KEYWORD) else index_of_var_pos
844+
)
845+
] # parameters that are before *args and **kwargs
846+
keyword_only_args = list(params.keys())[index_of_var_pos:] # parameters after *args
834847

835848
try:
836849
_coro = coro if hasattr(coro, "_wrapped") else self.__wrap_coro(coro)
837850

838-
if param_len < (2 if self.extension else 1):
851+
if last.kind == last.VAR_KEYWORD: # foo(ctx, ..., **kwargs)
852+
return await _coro(ctx, *args, **kwargs)
853+
if last.kind == last.VAR_POSITIONAL: # foo(ctx, ..., *args)
854+
return await _coro(
855+
ctx,
856+
*(kwargs[opt] for opt in par_opts if opt in kwargs),
857+
*args,
858+
)
859+
if has_args: # foo(ctx, ..., *args, ..., **kwargs) OR foo(ctx, *args, ...)
860+
return await _coro(
861+
ctx,
862+
*(kwargs[opt] for opt in par_opts if opt in kwargs), # pos before *args
863+
*args,
864+
*(
865+
kwargs[opt]
866+
for opt in kwargs
867+
if opt not in par_opts and opt not in keyword_only_args
868+
), # additional args
869+
**{
870+
opt: kwargs[opt]
871+
for opt in kwargs
872+
if opt not in par_opts and opt in keyword_only_args
873+
}, # kwargs after *args
874+
)
875+
876+
if param_len < num:
877+
inner_msg: str = f"{num} parameter{'s' if num > 1 else ''} to return" + (
878+
" self and" if self.extension else ""
879+
)
839880
raise LibraryException(
840-
code=11,
841-
message=f"Your command needs at least {'two parameters to return self and' if self.extension else 'one parameter to return'} context.",
881+
code=11, message=f"Your command needs at least {inner_msg} context."
842882
)
843883

844-
if param_len == (2 if self.extension else 1):
884+
if param_len == num:
845885
return await _coro(ctx)
846886

847887
if _res:
848-
if param_len - opt_len == (2 if self.extension else 1):
888+
if param_len - opt_len == num:
849889
return await _coro(ctx, *args, **kwargs)
850-
elif param_len - opt_len == (3 if self.extension else 2):
890+
elif param_len - opt_len == num + 1:
851891
return await _coro(ctx, _res, *args, **kwargs)
852892

853893
return await _coro(ctx, *args, **kwargs)
854894
except CancelledError:
855895
pass
856-
except Exception as e:
857-
if self.error_callback:
858-
num_params = len(signature(self.error_callback).parameters)
859-
860-
if num_params == (3 if self.extension else 2):
861-
await self.error_callback(ctx, e)
862-
elif num_params == (4 if self.extension else 3):
863-
await self.error_callback(ctx, e, _res)
864-
else:
865-
await self.error_callback(ctx, e, _res, *args, **kwargs)
866-
elif self.listener and "on_command_error" in self.listener.events:
867-
self.listener.dispatch("on_command_error", ctx, e)
868-
else:
869-
raise e
870-
871-
return StopCommand
872896

873897
def __check_command(self, command_type: str) -> None:
874898
"""Checks if subcommands, groups, or autocompletions are created on context menus."""
@@ -895,7 +919,9 @@ async def __no_group(self, *args, **kwargs) -> None:
895919
"""This is the coroutine used when no group coroutine is provided."""
896920
pass
897921

898-
def __wrap_coro(self, coro: Callable[..., Awaitable]) -> Callable[..., Awaitable]:
922+
def __wrap_coro(
923+
self, coro: Callable[..., Awaitable], /, *, error_callback: bool = False
924+
) -> Callable[..., Awaitable]:
899925
"""Wraps a coroutine to make sure the :class:`interactions.client.bot.Extension` is passed to the coroutine, if any."""
900926

901927
@wraps(coro)
@@ -907,11 +933,28 @@ async def wrapper(ctx: "CommandContext", *args, **kwargs):
907933
except CancelledError:
908934
pass
909935
except Exception as e:
936+
if error_callback:
937+
raise e
910938
if self.error_callback:
911-
num_params = len(signature(self.error_callback).parameters)
912-
913-
if num_params == (3 if self.extension else 2):
939+
params = signature(self.error_callback).parameters
940+
num_params = len(params)
941+
last = params[list(params)[-1]]
942+
num = 2 if self.extension else 1
943+
944+
if num_params == num:
945+
await self.error_callback(ctx)
946+
elif num_params == num + 1:
914947
await self.error_callback(ctx, e)
948+
elif last.kind == last.VAR_KEYWORD:
949+
if num_params == num + 2:
950+
await self.error_callback(ctx, e, **kwargs)
951+
elif num_params >= num + 3:
952+
await self.error_callback(ctx, e, *args, **kwargs)
953+
elif last.kind == last.VAR_POSITIONAL:
954+
if num_params == num + 2:
955+
await self.error_callback(ctx, e, *args)
956+
elif num_params >= num + 3:
957+
await self.error_callback(ctx, e, *args, **kwargs)
915958
else:
916959
await self.error_callback(ctx, e, *args, **kwargs)
917960
elif self.listener and "on_command_error" in self.listener.events:

interactions/utils/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ..api.models.message import Message
2626
from ..api.models.misc import Snowflake
2727
from ..client.bot import Client, Extension
28-
from ..client.context import CommandContext
28+
from ..client.context import CommandContext # noqa F401
2929

3030
__all__ = (
3131
"autodefer",
@@ -67,7 +67,7 @@ async def command(ctx):
6767
"""
6868

6969
def decorator(coro: Callable[..., Union[Awaitable, Coroutine]]) -> Callable[..., Awaitable]:
70-
from ..client.context import ComponentContext
70+
from ..client.context import CommandContext, ComponentContext # noqa F811
7171

7272
@wraps(coro)
7373
async def deferring_func(
@@ -80,7 +80,8 @@ async def deferring_func(
8080

8181
if isinstance(args[0], (ComponentContext, CommandContext)):
8282
self = ctx
83-
ctx = list(args).pop(0)
83+
args = list(args)
84+
ctx = args.pop(0)
8485

8586
task: Task = loop.create_task(coro(self, ctx, *args, **kwargs))
8687

0 commit comments

Comments
 (0)