@@ -816,59 +816,83 @@ async def command_error(ctx, error):
816
816
message = f"Your command needs at least { 'three parameters to return self, context, and the' if self .extension else 'two parameter to return context and' } error." ,
817
817
)
818
818
819
- self .error_callback = self .__wrap_coro (coro )
819
+ self .error_callback = self .__wrap_coro (coro , error_callback = True )
820
820
return coro
821
821
822
822
async def __call (
823
823
self ,
824
824
coro : Callable [..., Awaitable ],
825
825
ctx : "CommandContext" ,
826
- * args ,
826
+ * args , # empty for now since all parameters are dispatched as kwargs
827
827
_name : Optional [str ] = None ,
828
828
_res : Optional [Union [BaseResult , GroupResult ]] = None ,
829
829
** kwargs ,
830
830
) -> Optional [Any ]:
831
831
"""Handles calling the coroutine based on parameter count."""
832
- param_len = len (signature (coro ).parameters )
833
- opt_len = self .num_options .get (_name , len (args ) + len (kwargs ))
832
+ params = signature (coro ).parameters
833
+ param_len = len (params )
834
+ opt_len = self .num_options .get (_name , len (args ) + len (kwargs )) # options of slash command
835
+ last = params [list (params )[- 1 ]] # last parameter
836
+ has_args = any (param .kind == param .VAR_POSITIONAL for param in params .values ()) # any *args
837
+ index_of_var_pos = next (
838
+ (i for i , param in enumerate (params .values ()) if param .kind == param .VAR_POSITIONAL ),
839
+ param_len ,
840
+ ) # index of *args
841
+ par_opts = list (params .keys ())[
842
+ (num := 2 if self .extension else 1 ) : (
843
+ - 1 if last .kind in (last .VAR_POSITIONAL , last .VAR_KEYWORD ) else index_of_var_pos
844
+ )
845
+ ] # parameters that are before *args and **kwargs
846
+ keyword_only_args = list (params .keys ())[index_of_var_pos :] # parameters after *args
834
847
835
848
try :
836
849
_coro = coro if hasattr (coro , "_wrapped" ) else self .__wrap_coro (coro )
837
850
838
- if param_len < (2 if self .extension else 1 ):
851
+ if last .kind == last .VAR_KEYWORD : # foo(ctx, ..., **kwargs)
852
+ return await _coro (ctx , * args , ** kwargs )
853
+ if last .kind == last .VAR_POSITIONAL : # foo(ctx, ..., *args)
854
+ return await _coro (
855
+ ctx ,
856
+ * (kwargs [opt ] for opt in par_opts if opt in kwargs ),
857
+ * args ,
858
+ )
859
+ if has_args : # foo(ctx, ..., *args, ..., **kwargs) OR foo(ctx, *args, ...)
860
+ return await _coro (
861
+ ctx ,
862
+ * (kwargs [opt ] for opt in par_opts if opt in kwargs ), # pos before *args
863
+ * args ,
864
+ * (
865
+ kwargs [opt ]
866
+ for opt in kwargs
867
+ if opt not in par_opts and opt not in keyword_only_args
868
+ ), # additional args
869
+ ** {
870
+ opt : kwargs [opt ]
871
+ for opt in kwargs
872
+ if opt not in par_opts and opt in keyword_only_args
873
+ }, # kwargs after *args
874
+ )
875
+
876
+ if param_len < num :
877
+ inner_msg : str = f"{ num } parameter{ 's' if num > 1 else '' } to return" + (
878
+ " self and" if self .extension else ""
879
+ )
839
880
raise LibraryException (
840
- code = 11 ,
841
- message = f"Your command needs at least { 'two parameters to return self and' if self .extension else 'one parameter to return' } context." ,
881
+ code = 11 , message = f"Your command needs at least { inner_msg } context."
842
882
)
843
883
844
- if param_len == ( 2 if self . extension else 1 ) :
884
+ if param_len == num :
845
885
return await _coro (ctx )
846
886
847
887
if _res :
848
- if param_len - opt_len == ( 2 if self . extension else 1 ) :
888
+ if param_len - opt_len == num :
849
889
return await _coro (ctx , * args , ** kwargs )
850
- elif param_len - opt_len == ( 3 if self . extension else 2 ) :
890
+ elif param_len - opt_len == num + 1 :
851
891
return await _coro (ctx , _res , * args , ** kwargs )
852
892
853
893
return await _coro (ctx , * args , ** kwargs )
854
894
except CancelledError :
855
895
pass
856
- except Exception as e :
857
- if self .error_callback :
858
- num_params = len (signature (self .error_callback ).parameters )
859
-
860
- if num_params == (3 if self .extension else 2 ):
861
- await self .error_callback (ctx , e )
862
- elif num_params == (4 if self .extension else 3 ):
863
- await self .error_callback (ctx , e , _res )
864
- else :
865
- await self .error_callback (ctx , e , _res , * args , ** kwargs )
866
- elif self .listener and "on_command_error" in self .listener .events :
867
- self .listener .dispatch ("on_command_error" , ctx , e )
868
- else :
869
- raise e
870
-
871
- return StopCommand
872
896
873
897
def __check_command (self , command_type : str ) -> None :
874
898
"""Checks if subcommands, groups, or autocompletions are created on context menus."""
@@ -895,7 +919,9 @@ async def __no_group(self, *args, **kwargs) -> None:
895
919
"""This is the coroutine used when no group coroutine is provided."""
896
920
pass
897
921
898
- def __wrap_coro (self , coro : Callable [..., Awaitable ]) -> Callable [..., Awaitable ]:
922
+ def __wrap_coro (
923
+ self , coro : Callable [..., Awaitable ], / , * , error_callback : bool = False
924
+ ) -> Callable [..., Awaitable ]:
899
925
"""Wraps a coroutine to make sure the :class:`interactions.client.bot.Extension` is passed to the coroutine, if any."""
900
926
901
927
@wraps (coro )
@@ -907,11 +933,28 @@ async def wrapper(ctx: "CommandContext", *args, **kwargs):
907
933
except CancelledError :
908
934
pass
909
935
except Exception as e :
936
+ if error_callback :
937
+ raise e
910
938
if self .error_callback :
911
- num_params = len (signature (self .error_callback ).parameters )
912
-
913
- if num_params == (3 if self .extension else 2 ):
939
+ params = signature (self .error_callback ).parameters
940
+ num_params = len (params )
941
+ last = params [list (params )[- 1 ]]
942
+ num = 2 if self .extension else 1
943
+
944
+ if num_params == num :
945
+ await self .error_callback (ctx )
946
+ elif num_params == num + 1 :
914
947
await self .error_callback (ctx , e )
948
+ elif last .kind == last .VAR_KEYWORD :
949
+ if num_params == num + 2 :
950
+ await self .error_callback (ctx , e , ** kwargs )
951
+ elif num_params >= num + 3 :
952
+ await self .error_callback (ctx , e , * args , ** kwargs )
953
+ elif last .kind == last .VAR_POSITIONAL :
954
+ if num_params == num + 2 :
955
+ await self .error_callback (ctx , e , * args )
956
+ elif num_params >= num + 3 :
957
+ await self .error_callback (ctx , e , * args , ** kwargs )
915
958
else :
916
959
await self .error_callback (ctx , e , * args , ** kwargs )
917
960
elif self .listener and "on_command_error" in self .listener .events :
0 commit comments