Skip to content

Commit ec481ac

Browse files
authored
feat💥: make subcmds able to inherit checks from base cmds (#1274)
Co-authored-by: Astrea49 <25420078+Astrea49@users.noreply.github.com>
1 parent e8db8ce commit ec481ac

File tree

4 files changed

+13
-29
lines changed

4 files changed

+13
-29
lines changed

interactions/ext/prefixed_commands/command.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,6 @@ class PrefixedCommand(BaseCommand):
346346
),
347347
default=True,
348348
)
349-
hierarchical_checking: bool = attrs.field(
350-
metadata=docs(
351-
"If `True` and if the base of a subcommand, every subcommand underneath it will run this command's checks"
352-
" and cooldowns before its own. Otherwise, only the subcommand's checks are checked."
353-
),
354-
default=True,
355-
)
356349
help: Optional[str] = attrs.field(repr=False, metadata=docs("The long help text for the command."), default=None)
357350
brief: Optional[str] = attrs.field(repr=False, metadata=docs("The short help text for the command."), default=None)
358351
parent: Optional["PrefixedCommand"] = attrs.field(
@@ -638,7 +631,7 @@ def subcommand(
638631
enabled: bool = True,
639632
hidden: bool = False,
640633
ignore_extra: bool = True,
641-
hierarchical_checking: bool = True,
634+
inherit_checks: bool = True,
642635
) -> Callable[..., Self]:
643636
"""
644637
A decorator to declare a subcommand for a prefixed command.
@@ -654,8 +647,7 @@ def subcommand(
654647
hidden: If `True`, the default help command (when it is added) does not show this in the help output.
655648
ignore_extra: If `True`, ignores extraneous strings passed to a command if all its requirements are met \
656649
(e.g. ?foo a b c when only expecting a and b). Otherwise, an error is raised.
657-
hierarchical_checking: If `True` and if the base of a subcommand, every subcommand underneath it will \
658-
run this command's checks before its own. Otherwise, only the subcommand's checks are checked.
650+
inherit_checks: If `True`, the subcommand will inherit its checks from the parent command.
659651
"""
660652

661653
def wrapper(func: Callable) -> Self:
@@ -670,7 +662,7 @@ def wrapper(func: Callable) -> Self:
670662
enabled=enabled,
671663
hidden=hidden,
672664
ignore_extra=ignore_extra,
673-
hierarchical_checking=hierarchical_checking,
665+
checks=self.checks if inherit_checks else [],
674666
)
675667
self.add_command(cmd)
676668
return cmd
@@ -776,7 +768,6 @@ def prefixed_command(
776768
enabled: bool = True,
777769
hidden: bool = False,
778770
ignore_extra: bool = True,
779-
hierarchical_checking: bool = True,
780771
) -> Callable[..., PrefixedCommand]:
781772
"""
782773
A decorator to declare a coroutine as a prefixed command.
@@ -792,8 +783,6 @@ def prefixed_command(
792783
hidden: If `True`, the default help command (when it is added) does not show this in the help output.
793784
ignore_extra: If `True`, ignores extraneous strings passed to a command if all its requirements are \
794785
met (e.g. ?foo a b c when only expecting a and b). Otherwise, an error is raised.
795-
hierarchical_checking: If `True` and if the base of a subcommand, every subcommand underneath it will \
796-
run this command's checks before its own. Otherwise, only the subcommand's checks are checked.
797786
"""
798787

799788
def wrapper(func: Callable) -> PrefixedCommand:
@@ -807,7 +796,6 @@ def wrapper(func: Callable) -> PrefixedCommand:
807796
enabled=enabled,
808797
hidden=hidden,
809798
ignore_extra=ignore_extra,
810-
hierarchical_checking=hierarchical_checking,
811799
)
812800

813801
return wrapper

interactions/ext/prefixed_commands/manager.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -324,18 +324,6 @@ async def _dispatch_prefixed_commands(self, event: RawGatewayEvent) -> None:
324324
command = new_command
325325
content_parameters = content_parameters.removeprefix(first_word).strip()
326326

327-
if command.subcommands and command.hierarchical_checking:
328-
try:
329-
await new_command._can_run(context) # will error out if we can't run this command
330-
except Exception as e:
331-
if new_command.error_callback:
332-
await new_command.error_callback(e, context)
333-
elif new_command.extension and new_command.extension.extension_error:
334-
await new_command.extension.extension_error(e, context)
335-
else:
336-
self.client.dispatch(CommandError(ctx=context, error=e))
337-
return
338-
339327
if not isinstance(command, PrefixedCommand) or not command.enabled:
340328
return
341329

interactions/models/internal/application_commands.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,14 +613,16 @@ def wrapper(call: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
613613
option_name = option_name.lower()
614614
return wrapper
615615

616-
def group(self, name: str = None, description: str = "No Description Set") -> "SlashCommand":
617-
616+
def group(
617+
self, name: str = None, description: str = "No Description Set", inherit_checks: bool = True
618+
) -> "SlashCommand":
618619
return SlashCommand(
619620
name=self.name,
620621
description=self.description,
621622
group_name=name,
622623
group_description=description,
623624
scopes=self.scopes,
625+
checks=self.checks if inherit_checks else [],
624626
)
625627

626628
def subcommand(
@@ -631,6 +633,7 @@ def subcommand(
631633
group_description: Absent[LocalisedDesc | str] = MISSING,
632634
options: List[Union[SlashCommandOption, Dict]] = None,
633635
nsfw: bool = False,
636+
inherit_checks: bool = True,
634637
) -> Callable[..., "SlashCommand"]:
635638
def wrapper(call: Callable[..., Coroutine]) -> "SlashCommand":
636639
nonlocal sub_cmd_description
@@ -654,6 +657,7 @@ def wrapper(call: Callable[..., Coroutine]) -> "SlashCommand":
654657
callback=call,
655658
scopes=self.scopes,
656659
nsfw=nsfw,
660+
checks=self.checks if inherit_checks else [],
657661
)
658662

659663
return wrapper

interactions/models/internal/command.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,10 @@ async def _can_run(self, context: "BaseContext") -> bool:
304304
await self.max_concurrency.release(context)
305305
raise
306306

307+
def add_check(self, check: Callable[..., Awaitable[bool]]) -> None:
308+
"""Adds a check into the command."""
309+
self.checks.append(check)
310+
307311
def error(self, call: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
308312
"""A decorator to declare a coroutine as one that will be run upon an error."""
309313
if not asyncio.iscoroutinefunction(call):

0 commit comments

Comments
 (0)