Skip to content

Commit b8f600f

Browse files
authored
Adds commit classification rule (#397)
This PR adds a new rule that uses the `LLMService`. It sends the diff of a commit (together with the repository name and the commit message) to the LLM and asks whether the commit is security-relevant. The rule's relevance is set to 32 for now, but this value can be adjusted after evaluation. Thanks to @tommasoaiello
1 parent 53446b0 commit b8f600f

File tree

5 files changed

+102
-8
lines changed

5 files changed

+102
-8
lines changed

prospector/llm/llm_service.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import validators
44
from langchain_core.language_models.llms import LLM
55
from langchain_core.output_parsers import StrOutputParser
6+
from requests import HTTPError
67

78
from llm.instantiation import create_model_instance
8-
from llm.prompts import prompt_best_guess
9+
from llm.prompts.classify_commit import zero_shot as cc_zero_shot
10+
from llm.prompts.get_repository_url import prompt_best_guess
911
from log.logger import logger
1012
from util.config_parser import LLMServiceConfig
1113
from util.singleton import Singleton
@@ -74,3 +76,53 @@ def get_repository_url(self, advisory_description, advisory_references) -> str:
7476
raise RuntimeError(f"Prompt-model chain could not be invoked: {e}")
7577

7678
return url
79+
80+
def classify_commit(
    self, diff: str, repository_name: str, commit_message: str
) -> bool:
    """Ask an LLM whether a commit is security relevant or not.

    The prompt is filled with the commit's diff, the repository name and the
    commit message; the model's textual answer is then mapped to a boolean.

    Args:
        diff (str): The diff of the commit to classify.
        repository_name (str): The name of the repository the commit belongs to.
        commit_message (str): The message of the commit.

    Returns:
        True if the commit is deemed security relevant, False if not. False is
        also returned when the request fails with an HTTP 400 error.

    Raises:
        RuntimeError: if the prompt-model chain could not be invoked or the
            model's response is not one of the accepted answers.
    """
    try:
        chain = cc_zero_shot | self.model | StrOutputParser()

        is_relevant = chain.invoke(
            {
                "diff": diff,
                "repository_name": repository_name,
                "commit_message": commit_message,
            }
        )
        logger.info(f"LLM returned is_relevant={is_relevant}")

    except HTTPError as e:
        # if the diff is too big, a 400 error is returned -> silently ignore by returning False for this commit
        # The exception may carry no response object, so look it up defensively
        # instead of letting an AttributeError escape from this handler.
        response = getattr(e, "response", None)
        if response is not None and response.status_code == 400:
            return False
        raise RuntimeError(f"Prompt-model chain could not be invoked: {e}") from e
    except Exception as e:
        raise RuntimeError(f"Prompt-model chain could not be invoked: {e}") from e

    # Tolerate leading/trailing whitespace or newlines around the answer; the
    # accepted answers themselves are unchanged.
    answer = is_relevant.strip()

    if answer in ("True", "ANSWER:True", "```ANSWER:True```"):
        return True
    if answer in ("False", "ANSWER:False", "```ANSWER:False```"):
        return False
    raise RuntimeError(f"The model returned an invalid response: {is_relevant}")
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from langchain.prompts import PromptTemplate
2+
3+
# Zero-shot prompt for commit classification. It expects the placeholders
# `repository_name`, `commit_message` and `diff`, and instructs the model to
# answer with a bare True or False.
_ZERO_SHOT_TEMPLATE = """Is the following commit security relevant or not?
Please provide the output as a boolean value, either True or False.
If it is security relevant just answer True otherwise answer False. Do not return anything else.

To provide you with some context, the name of the repository is: {repository_name}, and the
commit message is: {commit_message}.

Finally, here is the diff of the commit:
{diff}\n


Your answer:\n"""

zero_shot = PromptTemplate.from_template(_ZERO_SHOT_TEMPLATE)

prospector/rules/rules.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,18 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
413413
return False
414414

415415

416+
class CommitIsSecurityRelevant(Rule):
    """Rule that matches when the LLM-based commit classifier flags the commit as security relevant."""

    def apply(self, candidate: Commit) -> bool:
        # Pass the candidate's diff, repository and message to the LLM service
        # and relay its boolean verdict unchanged.
        classifier = LLMService()
        verdict = classifier.classify_commit(
            candidate.diff, candidate.repository, candidate.message
        )
        return verdict
426+
427+
416428
RULES_PHASE_1: List[Rule] = [
417429
VulnIdInMessage("VULN_ID_IN_MESSAGE", 64),
418430
# CommitMentionedInAdv("COMMIT_IN_ADVISORY", 64),
@@ -433,4 +445,6 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
433445
CommitHasTwins("COMMIT_HAS_TWINS", 2),
434446
]
435447

436-
RULES_PHASE_2: List[Rule] = []
448+
# Phase 2 holds the LLM-backed rule. Its relevance of 32 is a provisional
# value that may be adjusted after evaluation.
RULES_PHASE_2: List[Rule] = [CommitIsSecurityRelevant("COMMIT_IS_SECURITY_RELEVANT", 32)]

prospector/rules/rules_test.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def candidates():
8989
changed_files={
9090
"core/src/main/java/org/apache/cxf/workqueue/AutomaticWorkQueueImpl.java"
9191
},
92-
minhash=get_encoded_minhash(get_msg("Insecure deserialization", 50)),
92+
minhash=get_encoded_minhash(
93+
get_msg("Insecure deserialization", 50)
94+
),
9395
),
9496
# TODO: Not matched by existing tests: GHSecurityAdvInMessage, ReferencesBug, ChangesRelevantCode, TwinMentionedInAdv, VulnIdInLinkedIssue, SecurityKeywordInLinkedGhIssue, SecurityKeywordInLinkedBug, CrossReferencedBug, CrossReferencedGh, CommitHasTwins, ChangesRelevantFiles, CommitMentionedInAdv, RelevantWordsInMessage
9597
]
@@ -109,37 +111,47 @@ def advisory_record():
109111
)
110112

111113

112-
def test_apply_phase_1_rules(candidates: List[Commit], advisory_record: AdvisoryRecord):
114+
def test_apply_phase_1_rules(
115+
candidates: List[Commit], advisory_record: AdvisoryRecord
116+
):
113117
annotated_candidates = apply_rules(
114118
candidates, advisory_record, enabled_rules=enabled_rules_from_config
115119
)
116120

117121
# Repo 5: Should match: AdvKeywordsInFiles, SecurityKeywordsInMsg, CommitMentionedInReference
118122
assert len(annotated_candidates[0].matched_rules) == 3
119123

120-
matched_rules_names = [item["id"] for item in annotated_candidates[0].matched_rules]
124+
matched_rules_names = [
125+
item["id"] for item in annotated_candidates[0].matched_rules
126+
]
121127
assert "ADV_KEYWORDS_IN_FILES" in matched_rules_names
122128
assert "COMMIT_IN_REFERENCE" in matched_rules_names
123129
assert "SEC_KEYWORDS_IN_MESSAGE" in matched_rules_names
124130

125131
# Repo 1: Should match: VulnIdInMessage, ReferencesGhIssue
126132
assert len(annotated_candidates[1].matched_rules) == 2
127133

128-
matched_rules_names = [item["id"] for item in annotated_candidates[1].matched_rules]
134+
matched_rules_names = [
135+
item["id"] for item in annotated_candidates[1].matched_rules
136+
]
129137
assert "VULN_ID_IN_MESSAGE" in matched_rules_names
130138
assert "GITHUB_ISSUE_IN_MESSAGE" in matched_rules_names
131139

132140
# Repo 3: Should match: VulnIdInMessage, ReferencesGhIssue
133141
assert len(annotated_candidates[2].matched_rules) == 2
134142

135-
matched_rules_names = [item["id"] for item in annotated_candidates[2].matched_rules]
143+
matched_rules_names = [
144+
item["id"] for item in annotated_candidates[2].matched_rules
145+
]
136146
assert "VULN_ID_IN_MESSAGE" in matched_rules_names
137147
assert "GITHUB_ISSUE_IN_MESSAGE" in matched_rules_names
138148

139149
# Repo 4: Should match: SecurityKeywordsInMsg
140150
assert len(annotated_candidates[3].matched_rules) == 1
141151

142-
matched_rules_names = [item["id"] for item in annotated_candidates[3].matched_rules]
152+
matched_rules_names = [
153+
item["id"] for item in annotated_candidates[3].matched_rules
154+
]
143155
assert "SEC_KEYWORDS_IN_MESSAGE" in matched_rules_names
144156

145157
# Repo 2: Matches nothing

0 commit comments

Comments
 (0)