Skip to content

Commit e1b4658

Browse files
Add mention decorator for GitHub command handling
1 parent d9440ff commit e1b4658

File tree

3 files changed

+292
-1
lines changed

3 files changed

+292
-1
lines changed

src/django_github_app/commands.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from __future__ import annotations
2+
3+
from enum import Enum
4+
from typing import NamedTuple
5+
6+
7+
class EventAction(NamedTuple):
8+
event: str
9+
action: str
10+
11+
12+
class CommandScope(str, Enum):
13+
COMMIT = "commit"
14+
ISSUE = "issue"
15+
PR = "pr"
16+
17+
def get_events(self) -> list[EventAction]:
18+
match self:
19+
case CommandScope.ISSUE:
20+
return [
21+
EventAction("issue_comment", "created"),
22+
]
23+
case CommandScope.PR:
24+
return [
25+
EventAction("issue_comment", "created"),
26+
EventAction("pull_request_review_comment", "created"),
27+
EventAction("pull_request_review", "submitted"),
28+
]
29+
case CommandScope.COMMIT:
30+
return [
31+
EventAction("commit_comment", "created"),
32+
]
33+
34+
@classmethod
35+
def all_events(cls) -> list[EventAction]:
36+
return list(
37+
dict.fromkeys(
38+
event_action for scope in cls for event_action in scope.get_events()
39+
)
40+
)
41+

src/django_github_app/routing.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,111 @@
11
from __future__ import annotations
22

3+
from asyncio import iscoroutinefunction
34
from collections.abc import Awaitable
45
from collections.abc import Callable
6+
from functools import wraps
57
from typing import Any
8+
from typing import Protocol
69
from typing import TypeVar
10+
from typing import cast
711

812
from django.utils.functional import classproperty
913
from gidgethub import sansio
1014
from gidgethub.routing import Router as GidgetHubRouter
1115

1216
from ._typing import override
17+
from .commands import CommandScope
1318

1419
AsyncCallback = Callable[..., Awaitable[None]]
1520
SyncCallback = Callable[..., None]
1621

1722
CB = TypeVar("CB", AsyncCallback, SyncCallback)
1823

1924

25+
class MentionHandlerBase(Protocol):
26+
_mention_command: str | None
27+
_mention_scope: CommandScope | None
28+
_mention_permission: str | None
29+
30+
31+
class AsyncMentionHandler(MentionHandlerBase, Protocol):
32+
async def __call__(
33+
self, event: sansio.Event, *args: Any, **kwargs: Any
34+
) -> None: ...
35+
36+
37+
class SyncMentionHandler(MentionHandlerBase, Protocol):
38+
def __call__(self, event: sansio.Event, *args: Any, **kwargs: Any) -> None: ...
39+
40+
41+
MentionHandler = AsyncMentionHandler | SyncMentionHandler
42+
43+
2044
class GitHubRouter(GidgetHubRouter):
2145
_routers: list[GidgetHubRouter] = []
2246

2347
def __init__(self, *args) -> None:
2448
super().__init__(*args)
2549
GitHubRouter._routers.append(self)
2650

51+
@override
52+
def add(
53+
self, func: AsyncCallback | SyncCallback, event_type: str, **data_detail: Any
54+
) -> None:
55+
"""Override to accept both async and sync callbacks."""
56+
super().add(cast(AsyncCallback, func), event_type, **data_detail)
57+
2758
@classproperty
2859
def routers(cls):
2960
return list(cls._routers)
3061

3162
def event(self, event_type: str, **kwargs: Any) -> Callable[[CB], CB]:
3263
def decorator(func: CB) -> CB:
33-
self.add(func, event_type, **kwargs) # type: ignore[arg-type]
64+
self.add(func, event_type, **kwargs)
65+
return func
66+
67+
return decorator
68+
69+
def mention(self, **kwargs: Any) -> Callable[[CB], CB]:
70+
def decorator(func: CB) -> CB:
71+
command = kwargs.pop("command", None)
72+
scope = kwargs.pop("scope", None)
73+
permission = kwargs.pop("permission", None)
74+
75+
@wraps(func)
76+
async def async_wrapper(
77+
event: sansio.Event, *args: Any, **wrapper_kwargs: Any
78+
) -> None:
79+
# TODO: Parse comment body for mentions
80+
# TODO: If command specified, check if it matches
81+
# TODO: Check permissions
82+
# For now, just call through
83+
await func(event, *args, **wrapper_kwargs) # type: ignore[func-returns-value]
84+
85+
@wraps(func)
86+
def sync_wrapper(
87+
event: sansio.Event, *args: Any, **wrapper_kwargs: Any
88+
) -> None:
89+
# TODO: Parse comment body for mentions
90+
# TODO: If command specified, check if it matches
91+
# TODO: Check permissions
92+
# For now, just call through
93+
func(event, *args, **wrapper_kwargs)
94+
95+
wrapper: MentionHandler
96+
if iscoroutinefunction(func):
97+
wrapper = cast(AsyncMentionHandler, async_wrapper)
98+
else:
99+
wrapper = cast(SyncMentionHandler, sync_wrapper)
100+
101+
wrapper._mention_command = command.lower() if command else None
102+
wrapper._mention_scope = scope
103+
wrapper._mention_permission = permission
104+
105+
events = scope.get_events() if scope else CommandScope.all_events()
106+
for event_action in events:
107+
self.add(wrapper, event_action.event, action=event_action.action, **kwargs)
108+
34109
return func
35110

36111
return decorator

tests/test_routing.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
35
import pytest
46
from django.http import HttpRequest
57
from django.http import JsonResponse
8+
from gidgethub import sansio
69

10+
from django_github_app.commands import CommandScope
711
from django_github_app.github import SyncGitHubAPI
812
from django_github_app.routing import GitHubRouter
913
from django_github_app.views import BaseWebhookView
@@ -109,3 +113,174 @@ def test_router_memory_stress_test_legacy(self):
109113

110114
assert len(views) == view_count
111115
assert not all(view.router is view1_router for view in views)
116+
117+
118+
class TestMentionDecorator:
119+
def test_basic_mention_no_command(self, test_router):
120+
handler_called = False
121+
handler_args = None
122+
123+
@test_router.mention()
124+
def handle_mention(event, *args, **kwargs):
125+
nonlocal handler_called, handler_args
126+
handler_called = True
127+
handler_args = (event, args, kwargs)
128+
129+
event = sansio.Event(
130+
{"action": "created", "comment": {"body": "@bot hello"}},
131+
event="issue_comment",
132+
delivery_id="123",
133+
)
134+
test_router.dispatch(event, None)
135+
136+
assert handler_called
137+
assert handler_args[0] == event
138+
139+
def test_mention_with_command(self, test_router):
140+
handler_called = False
141+
142+
@test_router.mention(command="help")
143+
def help_command(event, *args, **kwargs):
144+
nonlocal handler_called
145+
handler_called = True
146+
return "help response"
147+
148+
event = sansio.Event(
149+
{"action": "created", "comment": {"body": "@bot help"}},
150+
event="issue_comment",
151+
delivery_id="123",
152+
)
153+
test_router.dispatch(event, None)
154+
155+
assert handler_called
156+
157+
def test_mention_with_scope(self, test_router):
158+
pr_handler_called = False
159+
160+
@test_router.mention(command="deploy", scope=CommandScope.PR)
161+
def deploy_command(event, *args, **kwargs):
162+
nonlocal pr_handler_called
163+
pr_handler_called = True
164+
165+
pr_event = sansio.Event(
166+
{"action": "created", "comment": {"body": "@bot deploy"}},
167+
event="pull_request_review_comment",
168+
delivery_id="123",
169+
)
170+
test_router.dispatch(pr_event, None)
171+
172+
assert pr_handler_called
173+
174+
issue_event = sansio.Event(
175+
{"action": "created", "comment": {"body": "@bot deploy"}},
176+
event="commit_comment", # This is NOT a PR event
177+
delivery_id="124",
178+
)
179+
pr_handler_called = False # Reset
180+
181+
test_router.dispatch(issue_event, None)
182+
183+
assert not pr_handler_called
184+
185+
def test_mention_with_permission(self, test_router):
186+
handler_called = False
187+
188+
@test_router.mention(command="delete", permission="admin")
189+
def delete_command(event, *args, **kwargs):
190+
nonlocal handler_called
191+
handler_called = True
192+
193+
event = sansio.Event(
194+
{"action": "created", "comment": {"body": "@bot delete"}},
195+
event="issue_comment",
196+
delivery_id="123",
197+
)
198+
test_router.dispatch(event, None)
199+
200+
assert handler_called
201+
202+
def test_case_insensitive_command(self, test_router):
203+
handler_called = False
204+
205+
@test_router.mention(command="HELP")
206+
def help_command(event, *args, **kwargs):
207+
nonlocal handler_called
208+
handler_called = True
209+
210+
event = sansio.Event(
211+
{"action": "created", "comment": {"body": "@bot help"}},
212+
event="issue_comment",
213+
delivery_id="123",
214+
)
215+
test_router.dispatch(event, None)
216+
217+
assert handler_called
218+
219+
def test_multiple_decorators_on_same_function(self, test_router):
220+
call_count = 0
221+
222+
@test_router.mention(command="help")
223+
@test_router.mention(command="h")
224+
@test_router.mention(command="?")
225+
def help_command(event, *args, **kwargs):
226+
nonlocal call_count
227+
call_count += 1
228+
return f"help called {call_count} times"
229+
230+
event1 = sansio.Event(
231+
{"action": "created", "comment": {"body": "@bot help"}},
232+
event="issue_comment",
233+
delivery_id="123",
234+
)
235+
test_router.dispatch(event1, None)
236+
237+
assert call_count == 3
238+
239+
call_count = 0
240+
event2 = sansio.Event(
241+
{"action": "created", "comment": {"body": "@bot h"}},
242+
event="issue_comment",
243+
delivery_id="124",
244+
)
245+
test_router.dispatch(event2, None)
246+
247+
assert call_count == 3
248+
249+
# This behavior will change once we implement command parsing
250+
251+
def test_async_mention_handler(self, test_router):
252+
handler_called = False
253+
254+
@test_router.mention(command="async-test")
255+
async def async_handler(event, *args, **kwargs):
256+
nonlocal handler_called
257+
handler_called = True
258+
return "async response"
259+
260+
event = sansio.Event(
261+
{"action": "created", "comment": {"body": "@bot async-test"}},
262+
event="issue_comment",
263+
delivery_id="123",
264+
)
265+
266+
asyncio.run(test_router.adispatch(event, None))
267+
268+
assert handler_called
269+
270+
def test_sync_mention_handler(self, test_router):
271+
handler_called = False
272+
273+
@test_router.mention(command="sync-test")
274+
def sync_handler(event, *args, **kwargs):
275+
nonlocal handler_called
276+
handler_called = True
277+
return "sync response"
278+
279+
event = sansio.Event(
280+
{"action": "created", "comment": {"body": "@bot sync-test"}},
281+
event="issue_comment",
282+
delivery_id="123",
283+
)
284+
test_router.dispatch(event, None)
285+
286+
assert handler_called

0 commit comments

Comments
 (0)