diff --git a/src/django_github_app/mentions.py b/src/django_github_app/mentions.py new file mode 100644 index 0000000..13321f3 --- /dev/null +++ b/src/django_github_app/mentions.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from enum import Enum +from typing import NamedTuple + +from gidgethub import sansio + + +class EventAction(NamedTuple): + event: str + action: str + + +class MentionScope(str, Enum): + COMMIT = "commit" + ISSUE = "issue" + PR = "pr" + + def get_events(self) -> list[EventAction]: + match self: + case MentionScope.ISSUE: + return [ + EventAction("issue_comment", "created"), + ] + case MentionScope.PR: + return [ + EventAction("issue_comment", "created"), + EventAction("pull_request_review_comment", "created"), + EventAction("pull_request_review", "submitted"), + ] + case MentionScope.COMMIT: + return [ + EventAction("commit_comment", "created"), + ] + + @classmethod + def all_events(cls) -> list[EventAction]: + return list( + dict.fromkeys( + event_action for scope in cls for event_action in scope.get_events() + ) + ) + + @classmethod + def from_event(cls, event: sansio.Event) -> MentionScope | None: + if event.event == "issue_comment": + issue = event.data.get("issue", {}) + is_pull_request = ( + "pull_request" in issue and issue["pull_request"] is not None + ) + return cls.PR if is_pull_request else cls.ISSUE + + for scope in cls: + scope_events = scope.get_events() + if any(event_action.event == event.event for event_action in scope_events): + return scope + + return None + + +@dataclass +class RawMention: + match: re.Match[str] + username: str + position: int + end: int + + +CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```", re.MULTILINE) +INLINE_CODE_PATTERN = re.compile(r"`[^`]+`") +BLOCKQUOTE_PATTERN = re.compile(r"^\s*>.*$", re.MULTILINE) +# GitHub username rules: +# - 1-39 characters long +# - Can only contain alphanumeric characters or hyphens +# - Cannot start or end with a 
class LineInfo(NamedTuple):
    """The line of a comment on which a mention sits."""

    # 1-based line number, counting ``"\n"`` characters before the mention.
    lineno: int
    # Text of that line without its terminator (``"\n"`` or ``"\r\n"``).
    text: str

    @classmethod
    def for_mention_in_comment(cls, comment: str, mention_position: int) -> LineInfo:
        """Build the ``LineInfo`` for the character offset ``mention_position``.

        Lines are delimited by ``"\n"`` only, with a trailing ``"\r"`` dropped
        so that Windows line endings behave like Unix ones. Line number and
        line text are derived from the same delimiter; using
        ``str.splitlines()`` for the text would also split on form feeds,
        ``\u2028`` and similar characters that are not counted in the line
        number, and so return the text of a different line.

        Args:
            comment: Full comment body.
            mention_position: Offset of the mention in ``comment``. An offset
                past the end is treated as the end of the comment.

        Returns:
            The 1-based line number and the text of that line. The text is
            ``""`` when the offset lies after a final newline.
        """
        lineno = comment.count("\n", 0, mention_position) + 1

        # ``rfind`` yields -1 when there is no earlier newline, which makes the
        # first line start at index 0 without a special case.
        line_start = comment.rfind("\n", 0, mention_position) + 1
        line_end = comment.find("\n", line_start)
        if line_end == -1:
            line_end = len(comment)

        text = comment[line_start:line_end].removesuffix("\r")
        return cls(lineno=lineno, text=text)
@dataclass
class Mention:
    """A parsed mention paired with the scope it was found in."""

    mention: ParsedMention
    scope: MentionScope | None

    @classmethod
    def from_event(
        cls,
        event: sansio.Event,
        *,
        username: str | re.Pattern[str] | None = None,
        scope: MentionScope | None = None,
    ):
        """Lazily yield one ``Mention`` per mention extracted from ``event``.

        ``username`` is passed through to ``extract_mentions_from_event`` as
        the filter; ``scope`` is attached unchanged to every yielded item.
        Nothing is extracted until the returned generator is first advanced.
        """
        yield from (
            cls(mention=parsed, scope=scope)
            for parsed in extract_mentions_from_event(event, username)
        )
def mention(
    self,
    *,
    username: str | re.Pattern[str] | None = None,
    scope: MentionScope | None = None,
    **kwargs: Any,
) -> Callable[[CB], CB]:
    """Register a handler that is invoked once per matching mention.

    The decorated function is wrapped and the wrapper is added to this
    router for every event/action pair of ``scope`` (or of all scopes when
    ``scope`` is ``None``). On dispatch the wrapper calls the handler once
    for each mention found in the event, passing it as the ``context``
    keyword argument.

    Args:
        username: Exact name or compiled pattern used to filter mentions;
            forwarded to ``Mention.from_event``.
        scope: When given, events whose detected scope differs are ignored.
        **kwargs: Extra routing details forwarded to ``self.add``.

    Returns:
        A decorator that registers the wrapper and returns the original
        function unchanged.
    """
    # Kept under its own name so the wrappers' keyword arguments below cannot
    # be confused with (or shadow) the routing details given to the decorator.
    route_details = kwargs

    def decorator(func: CB) -> CB:
        def mentions_for(event: sansio.Event) -> list[Mention]:
            # Shared by both wrappers: scope gate first, then extraction.
            event_scope = MentionScope.from_event(event)
            if scope is not None and event_scope != scope:
                return []
            return list(
                Mention.from_event(event, username=username, scope=event_scope)
            )

        @wraps(func)
        async def async_wrapper(
            event: sansio.Event,
            gh: AsyncGitHubAPI,
            *args: Any,
            **handler_kwargs: Any,
        ) -> None:
            for found in mentions_for(event):
                await func(event, gh, *args, context=found, **handler_kwargs)  # type: ignore[func-returns-value]

        @wraps(func)
        def sync_wrapper(
            event: sansio.Event,
            gh: SyncGitHubAPI,
            *args: Any,
            **handler_kwargs: Any,
        ) -> None:
            for found in mentions_for(event):
                func(event, gh, *args, context=found, **handler_kwargs)

        # The wrapper must match the handler's flavour: an async handler
        # has to be awaited, a sync one called directly.
        wrapper: MentionHandler
        if iscoroutinefunction(func):
            wrapper = cast(AsyncMentionHandler, async_wrapper)
        else:
            wrapper = cast(SyncMentionHandler, sync_wrapper)

        events = scope.get_events() if scope else MentionScope.all_events()
        for event_action in events:
            self.add(
                wrapper,
                event_action.event,
                action=event_action.action,
                **route_details,
            )

        return func

    return decorator
mock_github_api = aget_mock_github_api( [ {"id": seq.next(), "node_id": "node1", "full_name": "owner/repo1"}, {"id": seq.next(), "node_id": "node2", "full_name": "owner/repo2"}, @@ -167,11 +194,11 @@ def installation(get_mock_github_api, baker): @pytest_asyncio.fixture -async def ainstallation(get_mock_github_api, baker): +async def ainstallation(aget_mock_github_api, baker): installation = await sync_to_async(baker.make)( "django_github_app.Installation", installation_id=seq.next() ) - mock_github_api = get_mock_github_api( + mock_github_api = aget_mock_github_api( [ {"id": seq.next(), "node_id": "node1", "full_name": "owner/repo1"}, {"id": seq.next(), "node_id": "node2", "full_name": "owner/repo2"}, @@ -183,14 +210,14 @@ async def ainstallation(get_mock_github_api, baker): @pytest.fixture -def repository(installation, get_mock_github_api, baker): +def repository(installation, aget_mock_github_api, baker): repository = baker.make( "django_github_app.Repository", repository_id=seq.next(), full_name="owner/repo", installation=installation, ) - mock_github_api = get_mock_github_api( + mock_github_api = aget_mock_github_api( [ { "number": 1, @@ -210,14 +237,14 @@ def repository(installation, get_mock_github_api, baker): @pytest_asyncio.fixture -async def arepository(ainstallation, get_mock_github_api, baker): +async def arepository(ainstallation, aget_mock_github_api, baker): repository = await sync_to_async(baker.make)( "django_github_app.Repository", repository_id=seq.next(), full_name="owner/repo", installation=ainstallation, ) - mock_github_api = get_mock_github_api( + mock_github_api = aget_mock_github_api( [ { "number": 1, @@ -247,19 +274,40 @@ def _create_event(event_type, delivery_id=None, **data): if delivery_id is None: delivery_id = seq.next() - if event_type == "issue_comment" and "comment" not in data: - data["comment"] = {"body": faker.sentence()} + # Auto-create comment field for comment events + if ( + event_type + in ["issue_comment", 
"pull_request_review_comment", "commit_comment"] + and "comment" not in data + ): + data["comment"] = {"body": f"@{faker.user_name()} {faker.sentence()}"} - if "comment" in data and isinstance(data["comment"], str): - # Allow passing just the comment body as a string - data["comment"] = {"body": data["comment"]} + # Auto-create review field for pull request review events + if event_type == "pull_request_review" and "review" not in data: + data["review"] = {"body": f"@{faker.user_name()} {faker.sentence()}"} + # Add user to comment if not present if "comment" in data and "user" not in data["comment"]: data["comment"]["user"] = {"login": faker.user_name()} + # Add user to review if not present + if "review" in data and "user" not in data["review"]: + data["review"]["user"] = {"login": faker.user_name()} + + if event_type == "issue_comment" and "issue" not in data: + data["issue"] = {"number": faker.random_int(min=1, max=1000)} + + if event_type == "commit_comment" and "commit" not in data: + data["commit"] = {"sha": faker.sha1()} + + if event_type == "pull_request_review_comment" and "pull_request" not in data: + data["pull_request"] = {"number": faker.random_int(min=1, max=1000)} + if "repository" not in data and event_type in [ "issue_comment", "pull_request", + "pull_request_review_comment", + "commit_comment", "push", ]: data["repository"] = { diff --git a/tests/settings.py b/tests/settings.py index 3561b2a..6b0ce1d 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -20,4 +20,5 @@ "django.contrib.auth.hashers.MD5PasswordHasher", ], "SECRET_KEY": "not-a-secret", + "USE_TZ": True, } diff --git a/tests/test_mentions.py b/tests/test_mentions.py new file mode 100644 index 0000000..f0587a8 --- /dev/null +++ b/tests/test_mentions.py @@ -0,0 +1,683 @@ +from __future__ import annotations + +import re +import time + +import pytest + +from django_github_app.mentions import LineInfo +from django_github_app.mentions import Mention +from django_github_app.mentions 
import MentionScope +from django_github_app.mentions import extract_all_mentions +from django_github_app.mentions import extract_mentions_from_event +from django_github_app.mentions import matches_pattern + + +@pytest.fixture(autouse=True) +def setup_test_app_name(override_app_settings): + with override_app_settings(NAME="bot"): + yield + + +class TestExtractAllMentions: + @pytest.mark.parametrize( + "text,expected_mentions", + [ + # Valid usernames + ("@validuser", [("validuser", 0, 10)]), + ("@Valid-User-123", [("Valid-User-123", 0, 15)]), + ("@123startswithnumber", [("123startswithnumber", 0, 20)]), + # Multiple mentions + ( + "@alice review @bob help @charlie test", + [("alice", 0, 6), ("bob", 14, 18), ("charlie", 24, 32)], + ), + # Invalid patterns - partial extraction + ("@-invalid", []), # Can't start with hyphen + ("@invalid-", [("invalid", 0, 8)]), # Hyphen at end not included + ("@in--valid", [("in", 0, 3)]), # Stops at double hyphen + # Long username - truncated to 39 chars + ( + "@toolongusernamethatexceedsthirtyninecharacters", + [("toolongusernamethatexceedsthirtyninecha", 0, 40)], + ), + # Special blocks tested in test_preserves_positions_with_special_blocks + # Edge cases + ("@", []), # Just @ symbol + ("@@double", []), # Double @ symbol + ("email@example.com", []), # Email (not at start of word) + ("@123", [("123", 0, 4)]), # Numbers only + ("@user_name", [("user", 0, 5)]), # Underscore stops extraction + ("test@user", []), # Not at word boundary + ("@user@another", [("user", 0, 5)]), # Second @ not at boundary + ], + ) + def test_extract_all_mentions(self, text, expected_mentions): + mentions = extract_all_mentions(text) + + assert len(mentions) == len(expected_mentions) + for i, (username, start, end) in enumerate(expected_mentions): + assert mentions[i].username == username + assert mentions[i].position == start + assert mentions[i].end == end + + @pytest.mark.parametrize( + "text,expected_mentions", + [ + # Code block with triple backticks + ( 
+ "Before code\n```\n@codebot ignored\n```\n@realbot after", + [("realbot", 37, 45)], + ), + # Inline code with single backticks + ( + "Use `@inlinebot command` here, but @realbot works", + [("realbot", 35, 43)], + ), + # Blockquote with > + ( + "> @quotedbot ignored\n@realbot visible", + [("realbot", 21, 29)], + ), + # Multiple code blocks + ( + "```\n@bot1\n```\nMiddle @bot2\n```\n@bot3\n```\nEnd @bot4", + [("bot2", 21, 26), ("bot4", 45, 50)], + ), + # Nested backticks in code block + ( + "```\n`@nestedbot`\n```\n@realbot after", + [("realbot", 21, 29)], + ), + # Multiple inline codes + ( + "`@bot1` and `@bot2` but @bot3 and @bot4", + [("bot3", 24, 29), ("bot4", 34, 39)], + ), + # Mixed special blocks + ( + "Start\n```\n@codebot\n```\n`@inline` text\n> @quoted line\n@realbot end", + [("realbot", 53, 61)], + ), + # Empty code block + ( + "Before\n```\n\n```\n@realbot after", + [("realbot", 16, 24)], + ), + # Code block at start + ( + "```\n@ignored\n```\n@realbot only", + [("realbot", 17, 25)], + ), + # Multiple blockquotes + ( + "> @bot1 quoted\n> @bot2 also quoted\n@bot3 not quoted", + [("bot3", 35, 40)], + ), + ], + ) + def test_preserves_positions_with_special_blocks(self, text, expected_mentions): + mentions = extract_all_mentions(text) + + assert len(mentions) == len(expected_mentions) + for i, (username, start, end) in enumerate(expected_mentions): + assert mentions[i].username == username + assert mentions[i].position == start + assert mentions[i].end == end + # Verify positions are preserved despite replacements + assert text[mentions[i].position : mentions[i].end] == f"@{username}" + + +class TestExtractMentionsFromEvent: + @pytest.mark.parametrize( + "body,username,expected", + [ + # Simple mention with command + ( + "@mybot help", + "mybot", + [{"username": "mybot"}], + ), + # Mention without command + ("@mybot", "mybot", [{"username": "mybot"}]), + # Case insensitive matching - preserves original case + ("@MyBot help", "mybot", [{"username": 
"MyBot"}]), + # Command case preserved + ("@mybot HELP", "mybot", [{"username": "mybot"}]), + # Mention in middle + ("Hey @mybot help me", "mybot", [{"username": "mybot"}]), + # With punctuation + ("@mybot help!", "mybot", [{"username": "mybot"}]), + # No space after mention + ( + "@mybot, please help", + "mybot", + [{"username": "mybot"}], + ), + # Multiple spaces before command + ("@mybot help", "mybot", [{"username": "mybot"}]), + # Hyphenated command + ( + "@mybot async-test", + "mybot", + [{"username": "mybot"}], + ), + # Special character command + ("@mybot ?", "mybot", [{"username": "mybot"}]), + # Hyphenated username matches pattern + ("@my-bot help", "my-bot", [{"username": "my-bot"}]), + # Username with underscore - doesn't match pattern + ("@my_bot help", "my_bot", []), + # Empty text + ("", "mybot", []), + ], + ) + def test_mention_extraction_scenarios(self, body, username, expected, create_event): + event = create_event("issue_comment", comment={"body": body} if body else {}) + + mentions = extract_mentions_from_event(event, username) + + assert len(mentions) == len(expected) + for i, exp in enumerate(expected): + assert mentions[i].username == exp["username"] + + @pytest.mark.parametrize( + "body,bot_pattern,expected_mentions", + [ + # Multiple mentions of same bot + ( + "@mybot help and then @mybot deploy", + "mybot", + ["mybot", "mybot"], + ), + # Filter specific mentions, ignore others + ( + "@otheruser help @mybot deploy @someone else", + "mybot", + ["mybot"], + ), + # Default pattern (None matches all mentions) + ("@bot help @otherbot test", None, ["bot", "otherbot"]), + # Specific bot name pattern + ( + "@bot help @deploy-bot test @test-bot check", + "deploy-bot", + ["deploy-bot"], + ), + ], + ) + def test_mention_filtering_and_patterns( + self, body, bot_pattern, expected_mentions, create_event + ): + event = create_event("issue_comment", comment={"body": body}) + + mentions = extract_mentions_from_event(event, bot_pattern) + + assert 
len(mentions) == len(expected_mentions) + for i, username in enumerate(expected_mentions): + assert mentions[i].username == username + + def test_missing_comment_body(self, create_event): + event = create_event("issue_comment") + + mentions = extract_mentions_from_event(event, "mybot") + + assert mentions == [] + + def test_mention_linking(self, create_event): + event = create_event( + "issue_comment", + comment={"body": "@bot1 first @bot2 second @bot3 third"}, + ) + + mentions = extract_mentions_from_event(event, re.compile(r"bot\d")) + + assert len(mentions) == 3 + + first = mentions[0] + second = mentions[1] + third = mentions[2] + + assert first.previous_mention is None + assert first.next_mention is second + + assert second.previous_mention is first + assert second.next_mention is third + + assert third.previous_mention is second + assert third.next_mention is None + + +class TestMentionScope: + @pytest.mark.parametrize( + "event_type,data,expected", + [ + ("issue_comment", {}, MentionScope.ISSUE), + ( + "issue_comment", + {"issue": {"pull_request": {"url": "..."}}}, + MentionScope.PR, + ), + ("issue_comment", {"issue": {"pull_request": None}}, MentionScope.ISSUE), + ("pull_request_review", {}, MentionScope.PR), + ("pull_request_review_comment", {}, MentionScope.PR), + ("commit_comment", {}, MentionScope.COMMIT), + ("unknown_event", {}, None), + ], + ) + def test_from_event(self, event_type, data, expected, create_event): + event = create_event(event_type=event_type, **data) + + assert MentionScope.from_event(event) == expected + + +class TestReDoSProtection: + def test_redos_vulnerability(self, create_event): + # Create a malicious comment that would cause potentially cause ReDoS + # Pattern: (bot|ai|assistant)+ matching "botbotbot...x" + malicious_username = "bot" * 20 + "x" + event = create_event( + "issue_comment", comment={"body": f"@{malicious_username} hello"} + ) + + pattern = re.compile(r"(bot|ai|assistant)+") + + start_time = time.time() + mentions = 
extract_mentions_from_event(event, pattern) + execution_time = time.time() - start_time + + assert execution_time < 0.1 + # The username gets truncated at 39 chars, and the 'x' is left out + # So it will match the pattern, but the important thing is it completes quickly + assert len(mentions) == 1 + assert mentions[0].username == "botbotbotbotbotbotbotbotbotbotbotbotbot" + + def test_nested_quantifier_pattern(self, create_event): + event = create_event( + "issue_comment", comment={"body": "@deploy-bot-bot-bot test command"} + ) + + # This type of pattern could cause issues: (word)+ + pattern = re.compile(r"(deploy|bot)+") + + start_time = time.time() + mentions = extract_mentions_from_event(event, pattern) + execution_time = time.time() - start_time + + assert execution_time < 0.1 + # Username contains hyphens, so it won't match this pattern + assert len(mentions) == 0 + + def test_alternation_with_quantifier(self, create_event): + event = create_event( + "issue_comment", comment={"body": "@mybot123bot456bot789 deploy"} + ) + + # Pattern like (a|b)* that could be dangerous + pattern = re.compile(r"(my|bot|[0-9])+") + + start_time = time.time() + mentions = extract_mentions_from_event(event, pattern) + execution_time = time.time() - start_time + + assert execution_time < 0.1 + # Should match safely + assert len(mentions) == 1 + assert mentions[0].username == "mybot123bot456bot789" + + def test_complex_regex_patterns_handled_safely(self, create_event): + event = create_event( + "issue_comment", + comment={ + "body": "@test @test-bot @test-bot-123 @testbotbotbot @verylongusername123456789" + }, + ) + + patterns = [ + re.compile(r".*bot.*"), # Wildcards + re.compile(r"test.*"), # Leading wildcard + re.compile(r".*"), # Match all + re.compile(r"(test|bot)+"), # Alternation with quantifier + re.compile(r"[a-z]+[0-9]+"), # Character classes with quantifiers + ] + + for pattern in patterns: + start_time = time.time() + extract_mentions_from_event(event, pattern) + 
class TestLineInfo:
    """Tests for ``LineInfo.for_mention_in_comment``."""

    @pytest.mark.parametrize(
        "comment,position,expected_lineno,expected_text",
        [
            # Single line mentions
            ("@user hello", 0, 1, "@user hello"),
            ("Hey @user how are you?", 4, 1, "Hey @user how are you?"),
            ("Thanks @user", 7, 1, "Thanks @user"),
            # Multi-line mentions
            (
                "@user please review\nthis pull request\nthanks!",
                0,
                1,
                "@user please review",
            ),
            ("Hello there\n@user can you help?\nThanks!", 12, 2, "@user can you help?"),
            ("First line\nSecond line\nThanks @user", 31, 3, "Thanks @user"),
            # Empty and edge cases
            ("", 0, 1, ""),
            (
                "Simple comment with @user mention",
                20,
                1,
                "Simple comment with @user mention",
            ),
            # Blank lines
            (
                "First line\n\n@user on third line\n\nFifth line",
                12,
                3,
                "@user on third line",
            ),
            ("\n\n\n@user appears here", 3, 4, "@user appears here"),
            # Unicode/emoji. The literals must stay real code points: each
            # emoji is one character, so offset 14 lies on the second line.
            (
                "First line 👋\n@user こんにちは 🎉\nThird line",
                14,
                2,
                "@user こんにちは 🎉",
            ),
        ],
    )
    def test_for_mention_in_comment(
        self, comment, position, expected_lineno, expected_text
    ):
        """Line number and line text are reported for the given offset."""
        line_info = LineInfo.for_mention_in_comment(comment, position)

        assert line_info.lineno == expected_lineno
        assert line_info.text == expected_text

    @pytest.mark.parametrize(
        "comment,position,expected_lineno,expected_text",
        [
            # Trailing newlines should be stripped from line text
            ("Hey @user\n", 4, 1, "Hey @user"),
            # Position beyond comment length
            ("Short", 100, 1, "Short"),
            # Unix-style line endings
            ("Line 1\n@user line 2", 7, 2, "@user line 2"),
            # Windows-style line endings (\r\n handled as single separator)
            ("Line 1\r\n@user line 2", 8, 2, "@user line 2"),
        ],
    )
    def test_edge_cases(self, comment, position, expected_lineno, expected_text):
        """Terminators and out-of-range offsets are handled."""
        line_info = LineInfo.for_mention_in_comment(comment, position)

        assert line_info.lineno == expected_lineno
        assert line_info.text == expected_text

    @pytest.mark.parametrize(
        "comment,position,expected_lineno",
        [
            ("Hey @alice and @bob, please review", 4, 1),
            ("Hey @alice and @bob, please review", 15, 1),
        ],
    )
    def test_multiple_mentions_same_line(self, comment, position, expected_lineno):
        """Two mentions on one line resolve to the same line info."""
        line_info = LineInfo.for_mention_in_comment(comment, position)

        assert line_info.lineno == expected_lineno
        assert line_info.text == comment
regex patterns + ("deploy", r"deploy", 0, True), + ("deploy prod", r"deploy", 0, False), # fullmatch requires entire string + ("deploy", r".*deploy.*", 0, True), + ("redeploy", r".*deploy.*", 0, True), + # Case sensitivity with regex - moved to test_pattern_flags_preserved + # Complex regex patterns + ("deploy-prod", r"deploy-(prod|staging|dev)", 0, True), + ("deploy-staging", r"deploy-(prod|staging|dev)", 0, True), + ("deploy-test", r"deploy-(prod|staging|dev)", 0, False), + # Anchored patterns (fullmatch behavior) + ("deploy prod", r"^deploy$", 0, False), + ("deploy", r"^deploy$", 0, True), + # Wildcards and quantifiers + ("deploy", r"dep.*", 0, True), + ("deployment", r"deploy.*", 0, True), + ("dep", r"deploy?", 0, False), # fullmatch requires entire string + # Character classes + ("deploy123", r"deploy\d+", 0, True), + ("deploy-abc", r"deploy\d+", 0, False), + # Empty pattern + ("anything", r".*", 0, True), + ("", r".*", 0, True), + # Suffix matching (from removed test) + ("deploy-bot", r".*-bot", 0, True), + ("test-bot", r".*-bot", 0, True), + ("user", r".*-bot", 0, False), + # Prefix with digits (from removed test) + ("mybot1", r"mybot\d+", 0, True), + ("mybot2", r"mybot\d+", 0, True), + ("otherbot", r"mybot\d+", 0, False), + ], + ) + def test_regex_pattern_matching(self, text, pattern_str, flags, expected): + pattern = re.compile(pattern_str, flags) + + assert matches_pattern(text, pattern) == expected + + @pytest.mark.parametrize( + "text,expected", + [ + # re.match would return True for these, but fullmatch returns False + ("deploy prod", False), + ("deployment", False), + # Only exact full matches should return True + ("deploy", True), + ], + ) + def test_regex_fullmatch_vs_match_behavior(self, text, expected): + pattern = re.compile(r"deploy") + + assert matches_pattern(text, pattern) is expected + + @pytest.mark.parametrize( + "text,pattern_str,flags,expected", + [ + # Case insensitive pattern + ("DEPLOY", r"deploy", re.IGNORECASE, True), + ("Deploy", 
r"deploy", re.IGNORECASE, True), + ("deploy", r"deploy", re.IGNORECASE, True), + # Case sensitive pattern (default) + ("DEPLOY", r"deploy", 0, False), + ("Deploy", r"deploy", 0, False), + ("deploy", r"deploy", 0, True), + # DOTALL flag allows . to match newlines + ("line1\nline2", r"line1.*line2", re.DOTALL, True), + ( + "line1\nline2", + r"line1.*line2", + 0, + False, + ), # Without DOTALL, . doesn't match \n + ("line1 line2", r"line1.*line2", 0, True), + ], + ) + def test_pattern_flags_preserved(self, text, pattern_str, flags, expected): + pattern = re.compile(pattern_str, flags) + + assert matches_pattern(text, pattern) == expected + + +class TestMention: + @pytest.mark.parametrize( + "event_type,event_data,username,expected_count,expected_mentions", + [ + # Basic mention extraction + ( + "issue_comment", + {"comment": {"body": "@bot help"}}, + "bot", + 1, + [{"username": "bot"}], + ), + # No mentions in event + ( + "issue_comment", + {"comment": {"body": "No mentions here"}}, + None, + 0, + [], + ), + # Multiple mentions, filter by username + ( + "issue_comment", + {"comment": {"body": "@bot1 help @bot2 deploy @user test"}}, + re.compile(r"bot\d"), + 2, + [ + {"username": "bot1"}, + {"username": "bot2"}, + ], + ), + # Issue comment with issue data + ( + "issue_comment", + {"comment": {"body": "@bot help"}, "issue": {}}, + "bot", + 1, + [{"username": "bot"}], + ), + # PR comment (issue_comment with pull_request) + ( + "issue_comment", + {"comment": {"body": "@bot help"}, "issue": {"pull_request": {}}}, + "bot", + 1, + [{"username": "bot"}], + ), + # No username filter matches all mentions + ( + "issue_comment", + {"comment": {"body": "@alice review @bot help"}}, + None, + 2, + [{"username": "alice"}, {"username": "bot"}], + ), + # Get all mentions with wildcard regex pattern + ( + "issue_comment", + {"comment": {"body": "@alice review @bob help"}}, + re.compile(r".*"), + 2, + [ + {"username": "alice"}, + {"username": "bob"}, + ], + ), + # PR review comment + ( + 
"pull_request_review_comment", + {"comment": {"body": "@reviewer please check"}}, + "reviewer", + 1, + [{"username": "reviewer"}], + ), + # Commit comment + ( + "commit_comment", + {"comment": {"body": "@bot test this commit"}}, + "bot", + 1, + [{"username": "bot"}], + ), + # Empty comment body + ( + "issue_comment", + {"comment": {"body": ""}}, + None, + 0, + [], + ), + # Mentions in code blocks (should be ignored) + ( + "issue_comment", + {"comment": {"body": "```\n@bot deploy\n```\n@bot help"}}, + "bot", + 1, + [{"username": "bot"}], + ), + ], + ) + def test_from_event( + self, + create_event, + event_type, + event_data, + username, + expected_count, + expected_mentions, + ): + event = create_event(event_type, **event_data) + scope = MentionScope.from_event(event) + + mentions = list(Mention.from_event(event, username=username, scope=scope)) + + assert len(mentions) == expected_count + for mention, expected in zip(mentions, expected_mentions, strict=False): + assert mention.mention.username == expected["username"] + assert mention.scope == scope diff --git a/tests/test_models.py b/tests/test_models.py index bc1d5c6..a562931 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -272,10 +272,10 @@ async def test_arefresh_from_gh( account_type, private_key, ainstallation, - get_mock_github_api, + aget_mock_github_api, override_app_settings, ): - mock_github_api = get_mock_github_api({"foo": "bar"}) + mock_github_api = aget_mock_github_api({"foo": "bar"}) ainstallation.get_gh_client = MagicMock(return_value=mock_github_api) with override_app_settings(PRIVATE_KEY=private_key): @@ -289,10 +289,10 @@ def test_refresh_from_gh( account_type, private_key, installation, - get_mock_github_api, + aget_mock_github_api, override_app_settings, ): - mock_github_api = get_mock_github_api({"foo": "bar"}) + mock_github_api = aget_mock_github_api({"foo": "bar"}) installation.get_gh_client = MagicMock(return_value=mock_github_api) with 
override_app_settings(PRIVATE_KEY=private_key): diff --git a/tests/test_routing.py b/tests/test_routing.py index 6646d0c..d3c7e4e 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,10 +1,13 @@ from __future__ import annotations +import re + import pytest from django.http import HttpRequest from django.http import JsonResponse from django_github_app.github import SyncGitHubAPI +from django_github_app.mentions import MentionScope from django_github_app.routing import GitHubRouter from django_github_app.views import BaseWebhookView @@ -109,3 +112,418 @@ def test_router_memory_stress_test_legacy(self): assert len(views) == view_count assert not all(view.router is view1_router for view in views) + + +class TestMentionDecorator: + def test_mention(self, test_router, get_mock_github_api, create_event): + calls = [] + + @test_router.mention() + def handle_mention(event, *args, **kwargs): + calls.append((event, args, kwargs)) + + event = create_event( + "issue_comment", + action="created", + comment={"body": "@bot hello"}, + ) + + test_router.dispatch(event, get_mock_github_api({})) + + assert len(calls) > 0 + + @pytest.mark.asyncio + async def test_async_mention(self, test_router, aget_mock_github_api, create_event): + calls = [] + + @test_router.mention() + async def async_handle_mention(event, *args, **kwargs): + calls.append((event, args, kwargs)) + + event = create_event( + "issue_comment", + action="created", + comment={"body": "@bot async hello"}, + ) + + await test_router.adispatch(event, aget_mock_github_api({})) + + assert len(calls) > 0 + + @pytest.mark.parametrize( + "username,body,expected_call_count", + [ + ("bot", "@bot help", 1), + ("bot", "@other-bot help", 0), + (re.compile(r".*-bot"), "@deploy-bot start @test-bot check @user help", 2), + (re.compile(r".*"), "@alice review @bob deploy @charlie test", 3), + ("", "@alice review @bob deploy @charlie test", 3), + ], + ) + def test_mention_with_username( + self, + test_router, + 
get_mock_github_api, + create_event, + username, + body, + expected_call_count, + ): + calls = [] + + @test_router.mention(username=username) + def help_handler(event, *args, **kwargs): + calls.append((event, args, kwargs)) + + event = create_event( + "issue_comment", + action="created", + comment={"body": body}, + ) + + test_router.dispatch(event, get_mock_github_api({})) + + assert len(calls) == expected_call_count + + @pytest.mark.parametrize( + "username,body,expected_call_count", + [ + ("bot", "@bot help", 1), + ("bot", "@other-bot help", 0), + (re.compile(r".*-bot"), "@deploy-bot start @test-bot check @user help", 2), + (re.compile(r".*"), "@alice review @bob deploy @charlie test", 3), + ("", "@alice review @bob deploy @charlie test", 3), + ], + ) + @pytest.mark.asyncio + async def test_async_mention_with_username( + self, + test_router, + aget_mock_github_api, + create_event, + username, + body, + expected_call_count, + ): + calls = [] + + @test_router.mention(username=username) + async def help_handler(event, *args, **kwargs): + calls.append((event, args, kwargs)) + + event = create_event( + "issue_comment", + action="created", + comment={"body": body}, + ) + + await test_router.adispatch(event, aget_mock_github_api({})) + + assert len(calls) == expected_call_count + + @pytest.mark.parametrize( + "scope", [MentionScope.PR, MentionScope.ISSUE, MentionScope.COMMIT] + ) + def test_mention_with_scope( + self, + test_router, + get_mock_github_api, + create_event, + scope, + ): + calls = [] + + @test_router.mention(scope=scope) + def scoped_handler(event, *args, **kwargs): + calls.append((event, args, kwargs)) + + mock_gh = get_mock_github_api({}) + + expected_events = scope.get_events() + + # Test all events that should match this scope + for event_action in expected_events: + # Special case: PR scope issue_comment needs pull_request field + event_kwargs = {} + if scope == MentionScope.PR and event_action.event == "issue_comment": + event_kwargs["issue"] = 
{"pull_request": {"url": "..."}} + + event = create_event( + event_action.event, action=event_action.action, **event_kwargs + ) + + test_router.dispatch(event, mock_gh) + + assert len(calls) == len(expected_events) + + # Test that events from other scopes don't trigger this handler + for other_scope in MentionScope: + if other_scope == scope: + continue + + for event_action in other_scope.get_events(): + # Ensure the event has the right structure for its intended scope + event_kwargs = {} + if ( + other_scope == MentionScope.PR + and event_action.event == "issue_comment" + ): + event_kwargs["issue"] = {"pull_request": {"url": "..."}} + elif ( + other_scope == MentionScope.ISSUE + and event_action.event == "issue_comment" + ): + # Explicitly set empty issue (no pull_request) + event_kwargs["issue"] = {} + + event = create_event( + event_action.event, action=event_action.action, **event_kwargs + ) + test_router.dispatch(event, mock_gh) + + assert len(calls) == len(expected_events) + + @pytest.mark.parametrize( + "scope", [MentionScope.PR, MentionScope.ISSUE, MentionScope.COMMIT] + ) + @pytest.mark.asyncio + async def test_async_mention_with_scope( + self, + test_router, + aget_mock_github_api, + create_event, + scope, + ): + calls = [] + + @test_router.mention(scope=scope) + async def async_scoped_handler(event, *args, **kwargs): + calls.append((event, args, kwargs)) + + mock_gh = aget_mock_github_api({}) + + expected_events = scope.get_events() + + # Test all events that should match this scope + for event_action in expected_events: + # Special case: PR scope issue_comment needs pull_request field + event_kwargs = {} + if scope == MentionScope.PR and event_action.event == "issue_comment": + event_kwargs["issue"] = {"pull_request": {"url": "..."}} + + event = create_event( + event_action.event, action=event_action.action, **event_kwargs + ) + + await test_router.adispatch(event, mock_gh) + + assert len(calls) == len(expected_events) + + # Test that events from other 
scopes don't trigger this handler + for other_scope in MentionScope: + if other_scope == scope: + continue + + for event_action in other_scope.get_events(): + # Ensure the event has the right structure for its intended scope + event_kwargs = {} + if ( + other_scope == MentionScope.PR + and event_action.event == "issue_comment" + ): + event_kwargs["issue"] = {"pull_request": {"url": "..."}} + elif ( + other_scope == MentionScope.ISSUE + and event_action.event == "issue_comment" + ): + # Explicitly set empty issue (no pull_request) + event_kwargs["issue"] = {} + + event = create_event( + event_action.event, action=event_action.action, **event_kwargs + ) + + await test_router.adispatch(event, mock_gh) + + assert len(calls) == len(expected_events) + + def test_issue_scope_excludes_pr_comments( + self, test_router, get_mock_github_api, create_event + ): + calls = [] + + @test_router.mention(scope=MentionScope.ISSUE) + def issue_only_handler(event, *args, **kwargs): + calls.append((event, args, kwargs)) + + mock_gh = get_mock_github_api({}) + + # Test that regular issue comments trigger the handler + issue_event = create_event( + "issue_comment", + action="created", + comment={"body": "@bot help"}, + issue={}, # No pull_request field + ) + + test_router.dispatch(issue_event, mock_gh) + + assert len(calls) == 1 + + # Test that PR comments don't trigger the handler + pr_event = create_event( + "issue_comment", + action="created", + comment={"body": "@bot help"}, + issue={"pull_request": {"url": "https://github.com/test/repo/pull/1"}}, + ) + + test_router.dispatch(pr_event, mock_gh) + + # Should still be 1 - no new calls + assert len(calls) == 1 + + @pytest.mark.parametrize( + "event_kwargs,expected_call_count", + [ + # All conditions met + ( + { + "comment": {"body": "@deploy-bot deploy now"}, + "issue": {"pull_request": {"url": "..."}}, + }, + 1, + ), + # Wrong username + ( + { + "comment": {"body": "@bot deploy now"}, + "issue": {"pull_request": {"url": "..."}}, + }, + 
0, + ), + # Different mention text (shouldn't matter without pattern) + ( + { + "comment": {"body": "@deploy-bot help"}, + "issue": {"pull_request": {"url": "..."}}, + }, + 1, + ), + # Wrong scope (issue instead of PR) + ( + { + "comment": {"body": "@deploy-bot deploy now"}, + "issue": {}, # No pull_request field + }, + 0, + ), + ], + ) + def test_combined_mention_filters( + self, + test_router, + get_mock_github_api, + create_event, + event_kwargs, + expected_call_count, + ): + calls = [] + + @test_router.mention( + username=re.compile(r".*-bot"), + scope=MentionScope.PR, + ) + def combined_filter_handler(event, *args, **kwargs): + calls.append((event, args, kwargs)) + + event = create_event("issue_comment", action="created", **event_kwargs) + + test_router.dispatch(event, get_mock_github_api({})) + + assert len(calls) == expected_call_count + + def test_mention_context(self, test_router, get_mock_github_api, create_event): + calls = [] + + @test_router.mention() + def test_handler(event, *args, **kwargs): + calls.append((event, args, kwargs)) + + event = create_event( + "issue_comment", + action="created", + comment={"body": "@bot test"}, + ) + + test_router.dispatch(event, get_mock_github_api({})) + + captured_mention = calls[0][2]["context"] + + assert captured_mention.scope.name == "ISSUE" + + triggered = captured_mention.mention + + assert triggered.username == "bot" + assert triggered.position == 0 + assert triggered.line_info.lineno == 1 + + @pytest.mark.asyncio + async def test_async_mention_context( + self, test_router, aget_mock_github_api, create_event + ): + calls = [] + + @test_router.mention() + async def async_handler(event, *args, **kwargs): + calls.append((event, args, kwargs)) + + event = create_event( + "issue_comment", + action="created", + comment={"body": "@bot async-test now"}, + ) + + await test_router.adispatch(event, aget_mock_github_api({})) + + captured_mention = calls[0][2]["context"] + + assert captured_mention.scope.name == "ISSUE" + 
+ triggered = captured_mention.mention + + assert triggered.username == "bot" + assert triggered.position == 0 + assert triggered.line_info.lineno == 1 + + def test_mention_context_multiple_mentions( + self, test_router, get_mock_github_api, create_event + ): + calls = [] + + @test_router.mention() + def deploy_handler(event, *args, **kwargs): + calls.append((event, args, kwargs)) + + event = create_event( + "issue_comment", + action="created", + comment={"body": "@bot help\n@second-bot deploy production"}, + ) + + test_router.dispatch(event, get_mock_github_api({})) + + assert len(calls) == 2 + + first = calls[0][2]["context"].mention + second = calls[1][2]["context"].mention + + assert first.username == "bot" + assert first.line_info.lineno == 1 + assert first.previous_mention is None + assert first.next_mention is second + + assert second.username == "second-bot" + assert second.line_info.lineno == 2 + assert second.previous_mention is first + assert second.next_mention is None