Skip to content

Commit bf80ebc

Browse files
Integrate permission checking into mention decorator
1 parent 39e03dd commit bf80ebc

File tree

5 files changed

+435
-113
lines changed

5 files changed

+435
-113
lines changed

src/django_github_app/permissions.py

Lines changed: 112 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,12 @@
55

66
import cachetools
77
import gidgethub
8+
from gidgethub import sansio
89

910
from django_github_app.github import AsyncGitHubAPI
1011
from django_github_app.github import SyncGitHubAPI
1112

1213

13-
class PermissionCacheKey(NamedTuple):
14-
owner: str
15-
repo: str
16-
username: str
17-
18-
1914
class Permission(int, Enum):
2015
NONE = 0
2116
READ = 1
@@ -47,6 +42,12 @@ def from_string(cls, permission: str) -> Permission:
4742
)
4843

4944

45+
class PermissionCacheKey(NamedTuple):
46+
owner: str
47+
repo: str
48+
username: str
49+
50+
5051
async def aget_user_permission(
5152
gh: AsyncGitHubAPI, owner: str, repo: str, username: str
5253
) -> Permission:
@@ -92,18 +93,121 @@ def get_user_permission(
9293
try:
9394
# Check if user is a collaborator and get their permission
9495
data = gh.getitem(f"/repos/{owner}/{repo}/collaborators/{username}/permission")
95-
permission_str = data.get("permission", "none")
96+
permission_str = data.get("permission", "none") # type: ignore[attr-defined]
9697
permission = Permission.from_string(permission_str)
9798
except gidgethub.HTTPException as e:
9899
if e.status_code == 404:
99100
# User is not a collaborator, they have read permission if repo is public
100101
# Check if repo is public
101102
try:
102103
repo_data = gh.getitem(f"/repos/{owner}/{repo}")
103-
if not repo_data.get("private", True):
104+
if not repo_data.get("private", True): # type: ignore[attr-defined]
104105
permission = Permission.READ
105106
except gidgethub.HTTPException:
106107
pass
107108

108109
cache[cache_key] = permission
109110
return permission
111+
112+
113+
class EventInfo(NamedTuple):
114+
comment_author: str | None
115+
owner: str | None
116+
repo: str | None
117+
118+
@classmethod
119+
def from_event(cls, event: sansio.Event) -> EventInfo:
120+
comment_author = None
121+
owner = None
122+
repo = None
123+
124+
if "comment" in event.data:
125+
comment_author = event.data["comment"]["user"]["login"]
126+
127+
if "repository" in event.data:
128+
owner = event.data["repository"]["owner"]["login"]
129+
repo = event.data["repository"]["name"]
130+
131+
return cls(comment_author=comment_author, owner=owner, repo=repo)
132+
133+
134+
class PermissionCheck(NamedTuple):
135+
has_permission: bool
136+
error_message: str | None
137+
138+
139+
PERMISSION_CHECK_ERROR_MESSAGE = """
140+
❌ **Permission Denied**
141+
142+
@{comment_author}, you need at least **{required_permission}** permission to use this command.
143+
144+
Your current permission level: **{user_permission}**
145+
"""
146+
147+
148+
async def acheck_mention_permission(
149+
event: sansio.Event, gh: AsyncGitHubAPI, required_permission: Permission
150+
) -> PermissionCheck:
151+
comment_author, owner, repo = EventInfo.from_event(event)
152+
153+
if not (comment_author and owner and repo):
154+
return PermissionCheck(has_permission=False, error_message=None)
155+
156+
user_permission = await aget_user_permission(gh, owner, repo, comment_author)
157+
158+
if user_permission >= required_permission:
159+
return PermissionCheck(has_permission=True, error_message=None)
160+
161+
return PermissionCheck(
162+
has_permission=False,
163+
error_message=PERMISSION_CHECK_ERROR_MESSAGE.format(
164+
comment_author=comment_author,
165+
required_permission=required_permission.name.lower(),
166+
user_permission=user_permission.name.lower(),
167+
),
168+
)
169+
170+
171+
def check_mention_permission(
172+
event: sansio.Event, gh: SyncGitHubAPI, required_permission: Permission
173+
) -> PermissionCheck:
174+
comment_author, owner, repo = EventInfo.from_event(event)
175+
176+
if not (comment_author and owner and repo):
177+
return PermissionCheck(has_permission=False, error_message=None)
178+
179+
user_permission = get_user_permission(gh, owner, repo, comment_author)
180+
181+
if user_permission >= required_permission:
182+
return PermissionCheck(has_permission=True, error_message=None)
183+
184+
return PermissionCheck(
185+
has_permission=False,
186+
error_message=PERMISSION_CHECK_ERROR_MESSAGE.format(
187+
comment_author=comment_author,
188+
required_permission=required_permission.name.lower(),
189+
user_permission=user_permission.name.lower(),
190+
),
191+
)
192+
193+
194+
def get_comment_post_url(event: sansio.Event) -> str | None:
195+
if event.data.get("action") != "created":
196+
return None
197+
198+
_, owner, repo = EventInfo.from_event(event)
199+
200+
if not (owner and repo):
201+
return None
202+
203+
if "issue" in event.data:
204+
issue_number = event.data["issue"]["number"]
205+
return f"/repos/{owner}/{repo}/issues/{issue_number}/comments"
206+
elif "pull_request" in event.data:
207+
pr_number = event.data["pull_request"]["number"]
208+
return f"/repos/{owner}/{repo}/issues/{pr_number}/comments"
209+
elif "commit_sha" in event.data:
210+
commit_sha = event.data["commit_sha"]
211+
return f"/repos/{owner}/{repo}/commits/{commit_sha}/comments"
212+
213+
return None

src/django_github_app/routing.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,15 @@
1414
from gidgethub.routing import Router as GidgetHubRouter
1515

1616
from ._typing import override
17+
from .github import AsyncGitHubAPI
18+
from .github import SyncGitHubAPI
1719
from .mentions import MentionScope
1820
from .mentions import check_event_for_mention
1921
from .mentions import check_event_scope
22+
from .permissions import Permission
23+
from .permissions import acheck_mention_permission
24+
from .permissions import check_mention_permission
25+
from .permissions import get_comment_post_url
2026

2127
AsyncCallback = Callable[..., Awaitable[None]]
2228
SyncCallback = Callable[..., None]
@@ -76,7 +82,7 @@ def decorator(func: CB) -> CB:
7682

7783
@wraps(func)
7884
async def async_wrapper(
79-
event: sansio.Event, *args: Any, **wrapper_kwargs: Any
85+
event: sansio.Event, gh: AsyncGitHubAPI, *args: Any, **kwargs: Any
8086
) -> None:
8187
# TODO: Get actual bot username from installation/app data
8288
username = "bot" # Placeholder
@@ -87,12 +93,29 @@ async def async_wrapper(
8793
if not check_event_scope(event, scope):
8894
return
8995

90-
# TODO: Check permissions. For now, just call through.
91-
await func(event, *args, **wrapper_kwargs) # type: ignore[func-returns-value]
96+
# Check permissions if required
97+
if permission is not None:
98+
required_perm = Permission.from_string(permission)
99+
permission_check = await acheck_mention_permission(
100+
event, gh, required_perm
101+
)
102+
103+
if not permission_check.has_permission:
104+
# Post error comment if we have an error message
105+
if permission_check.error_message:
106+
comment_url = get_comment_post_url(event)
107+
if comment_url:
108+
await gh.post(
109+
comment_url,
110+
data={"body": permission_check.error_message},
111+
)
112+
return
113+
114+
await func(event, gh, *args, **kwargs) # type: ignore[func-returns-value]
92115

93116
@wraps(func)
94117
def sync_wrapper(
95-
event: sansio.Event, *args: Any, **wrapper_kwargs: Any
118+
event: sansio.Event, gh: SyncGitHubAPI, *args: Any, **kwargs: Any
96119
) -> None:
97120
# TODO: Get actual bot username from installation/app data
98121
username = "bot" # Placeholder
@@ -103,8 +126,25 @@ def sync_wrapper(
103126
if not check_event_scope(event, scope):
104127
return
105128

106-
# TODO: Check permissions. For now, just call through.
107-
func(event, *args, **wrapper_kwargs)
129+
# Check permissions if required
130+
if permission is not None:
131+
required_perm = Permission.from_string(permission)
132+
permission_check = check_mention_permission(
133+
event, gh, required_perm
134+
)
135+
136+
if not permission_check.has_permission:
137+
# Post error comment if we have an error message
138+
if permission_check.error_message:
139+
comment_url = get_comment_post_url(event)
140+
if comment_url:
141+
gh.post( # type: ignore[unused-coroutine]
142+
comment_url,
143+
data={"body": permission_check.error_message},
144+
)
145+
return
146+
147+
func(event, gh, *args, **kwargs)
108148

109149
wrapper: MentionHandler
110150
if iscoroutinefunction(func):

tests/conftest.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,31 @@ async def mock_getiter(*args, **kwargs):
148148
return _get_mock_github_api
149149

150150

151+
@pytest.fixture
152+
def get_mock_github_api_sync():
153+
def _get_mock_github_api_sync(return_data):
154+
from django_github_app.github import SyncGitHubAPI
155+
156+
mock_api = MagicMock(spec=SyncGitHubAPI)
157+
158+
def mock_getitem(*args, **kwargs):
159+
return return_data
160+
161+
def mock_getiter(*args, **kwargs):
162+
yield from return_data
163+
164+
def mock_post(*args, **kwargs):
165+
pass
166+
167+
mock_api.getitem = mock_getitem
168+
mock_api.getiter = mock_getiter
169+
mock_api.post = mock_post
170+
171+
return mock_api
172+
173+
return _get_mock_github_api_sync
174+
175+
151176
@pytest.fixture
152177
def installation(get_mock_github_api, baker):
153178
installation = baker.make(

0 commit comments

Comments
 (0)