4 changes: 2 additions & 2 deletions example.env
@@ -16,9 +16,9 @@ PROMETHEUS_NEO4J_BATCH_SIZE=1000
# Knowledge Graph settings
PROMETHEUS_WORKING_DIRECTORY=working_dir/
PROMETHEUS_KNOWLEDGE_GRAPH_MAX_AST_DEPTH=3
PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE=10000
PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE=8000
PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_OVERLAP=1000
PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT=10000
PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT=8000

# LLM model settings
PROMETHEUS_ADVANCED_MODEL=gpt-4o
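For orientation (not shown in this diff): the chunk size and overlap act as a sliding window over whatever the knowledge graph indexes. A minimal sketch of that behavior, assuming character-based units and only the env-var names visible in the file above:

```
import os

# Defaults mirror the new settings in example.env above.
chunk_size = int(os.getenv("PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE", "8000"))
chunk_overlap = int(os.getenv("PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_OVERLAP", "1000"))

def chunk_text(text: str) -> list[str]:
    # Consecutive chunks share `chunk_overlap` units, so each window advances by the difference.
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)] or [""]
```

With the new 8000/1000 values, each chunk overlaps its neighbor by 1000 units and the window advances by 7000.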
2 changes: 2 additions & 0 deletions prometheus/lang_graph/nodes/edit_node.py
@@ -51,6 +51,7 @@ class EditNode:
- Only one match of old_content should exist in the file
- If multiple matches exist, more context is needed
- If no matches exist, content must be verified
- Do not write any tests; your change will be tested by reproduction and regression tests later

EXAMPLES:

@@ -114,6 +115,7 @@ def other_method():
4. When replacing multiple lines, include all lines in old_content
5. If multiple matches found, include more context
6. Verify uniqueness of matches before changes
7. NEVER write tests; your change will be tested by reproduction and regression tests later
"""

def __init__(self, model: BaseChatModel, local_path: str):
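Both new prompt rules enforce the same contract: the edit must match old_content exactly once, and the edit itself must not introduce tests. A hedged sketch of the uniqueness check the rules describe (the helper name is hypothetical; EditNode's actual apply logic is not part of this diff):

```
def apply_unique_edit(file_text: str, old_content: str, new_content: str) -> str:
    # The rules above require exactly one occurrence of old_content before editing.
    matches = file_text.count(old_content)
    if matches == 0:
        raise ValueError("old_content not found; verify the content first")
    if matches > 1:
        raise ValueError("old_content is ambiguous; include more surrounding context")
    return file_text.replace(old_content, new_content, 1)
```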
34 changes: 22 additions & 12 deletions prometheus/lang_graph/nodes/final_patch_selection_node.py
@@ -1,5 +1,6 @@
import logging
import threading
from typing import Sequence

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
@@ -132,18 +133,13 @@ def __init__(self, model: BaseChatModel, max_retries: int = 2):
)
self.majority_voting_times = 10

def format_human_message(self, state: IssueNotVerifiedBugState):
if state["run_regression_test"]:
patches = [result.patch for result in state["tested_patch_result"] if result.passed]
else:
patches = state["edit_patches"]

def format_human_message(self, patches: Sequence[str], state: IssueNotVerifiedBugState):
patches_str = ""
for index, patch in enumerate(patches):
patches_str += f"Patch at index {index}:\n"
patches_str += f"{patch}\n\n"
patches_str += (
f"You must select a patch with index from 0 to {len(state['edit_patches']) - 1},"
f"You must select a patch with index from 0 to {len(patches) - 1},"
f" and provide your reasoning."
)

@@ -156,28 +152,42 @@ def format_human_message(self, state: IssueNotVerifiedBugState):
)

def __call__(self, state: IssueNotVerifiedBugState):
human_prompt = self.format_human_message(state)
result = [0 for _ in range(len(state["edit_patches"]))]
# Determine candidate patches
if state["run_regression_test"]:
patches = [result.patch for result in state["tested_patch_result"] if result.passed]
else:
patches = state["deduplicated_patches"]

# Formalize Human Message
human_prompt = self.format_human_message(patches, state)

# Majority voting
result = [0 for _ in range(len(patches))]
for turn in range(self.majority_voting_times):
# Call the model
response = self.model.invoke({"human_prompt": human_prompt})
self._logger.info(
f"FinalPatchSelectionNode response at {turn + 1}/{self.majority_voting_times} try:"
f"Selected patch index: {response.patch_index}, "
)

if 0 <= response.patch_index < len(state["edit_patches"]):
# Tally the vote if the index is valid
if 0 <= response.patch_index < len(patches):
result[response.patch_index] += 1

# Early stopping if a patch has already secured majority
if max(result) > self.majority_voting_times // 2:
selected_patch_index = result.index(max(result))
self._logger.info(
f"FinalPatchSelectionNode early stopping at turn {turn + 1} with result: {result},"
f"selected patch index: {selected_patch_index}"
)
return {"final_patch": state["edit_patches"][selected_patch_index]}
return {"final_patch": patches[selected_patch_index]}

# Select the maximum voted patch index
selected_patch_index = result.index(max(result))
self._logger.info(
f"FinalPatchSelectionNode voting results: {result}, "
f"selected patch index: {selected_patch_index}"
)
return {"final_patch": state["edit_patches"][selected_patch_index]}
return {"final_patch": patches[selected_patch_index]}
22 changes: 9 additions & 13 deletions prometheus/lang_graph/nodes/patch_normalization_node.py
@@ -9,7 +9,9 @@
import threading
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List
from typing import Dict, List, Sequence

from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState


@dataclass
@@ -93,7 +95,7 @@ def calculate_patch_metrics(self, normalized_patch: str) -> PatchMetrics:
"""Calculate basic metrics for a patch"""
return PatchMetrics()

def deduplicate_patches(self, patches: List[str]) -> List[NormalizedPatch]:
def deduplicate_patches(self, patches: Sequence[str]) -> List[NormalizedPatch]:
"""Deduplicate patches using normalization

Returns list of unique normalized patches with occurrence counts.
@@ -136,20 +138,17 @@ def deduplicate_patches(self, patches: List[str]) -> List[NormalizedPatch]:

return deduplicated

def __call__(self, state: Dict) -> Dict:
def __call__(self, state: IssueNotVerifiedBugState) -> Dict:
"""Node call interface

Process edit_patches in state, return normalized, deduplicated patches and selected best patch
Process edit_patches in state, return normalized, deduplicated patches
"""
patches = state.get("edit_patches", [])

if not patches:
self._logger.warning("No patches found to process")
return {
"normalized_patches": [],
"final_patch": "",
"original_patch_count": 0,
"unique_patch_count": 0,
"deduplicated_patches": [],
}

self._logger.info(f"Starting to process {len(patches)} patches")
@@ -161,12 +160,9 @@ def __call__(self, state: Dict) -> Dict:
deduplicated_patches = [patch.original_content for patch in normalized_patches]

self._logger.info(
f"Patch processing complete, deduplicated to {len(normalized_patches)} unique patches"
f"Patch processing complete, deduplicated to {len(deduplicated_patches)} unique patches"
)

return {
"normalized_patches": normalized_patches,
"edit_patches": deduplicated_patches, # Return deduplicated patches for selection
"original_patch_count": len(patches),
"unique_patch_count": len(normalized_patches),
"deduplicated_patches": deduplicated_patches,
}
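The node now writes its results to `deduplicated_patches`, which downstream nodes read explicitly, rather than overwriting `edit_patches`. For illustration only, a minimal deduplication-by-normalization sketch; the node's real normalization rules live elsewhere in this file and may differ:

```
def deduplicate_patches(patches: list[str]) -> list[str]:
    # Hypothetical normalization: drop blank lines and surrounding whitespace so purely
    # cosmetic variants of the same patch collapse to one entry.
    seen: dict[str, str] = {}
    for patch in patches:
        key = "\n".join(line.strip() for line in patch.splitlines() if line.strip())
        seen.setdefault(key, patch)  # keep the first original form of each unique patch
    return list(seen.values())
```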
15 changes: 11 additions & 4 deletions prometheus/lang_graph/nodes/run_regression_tests_structure_node.py
@@ -18,7 +18,8 @@ class RunRegressionTestsStructureOutput(BaseModel):
description="If the test failed, contains the complete test FAILURE log. Otherwise empty string"
)
total_tests_run: int = Field(
description="Total number of tests run, including both passed and failed tests"
description="Total number of tests run, including both passed and failed tests, or 0 if no tests were run",
default=0,
)


@@ -31,13 +32,14 @@ class RunRegressionTestsStructuredNode:
- Test summary showing "passed" or "PASSED"
- Warning is ok
- No "FAILURES" section
2. If a test fails, capture the complete failure output
2. If a test fails, capture the complete failure output; otherwise, return an empty string for the failure log
3. Return the exact test identifiers that passed
4. Count the total number of tests run. Only count tests that were actually executed! If tests were unable to run due to an error, do not count them!

Return:
- passed_regression_tests: List of test identifiers of the regression tests that passed (e.g., class name and method name)
- regression_test_fail_log: empty string if all tests pass, exact complete test output if a test fails
- total_tests_run: Total number of tests run, including both passed and failed tests
- total_tests_run: Total number of tests run, including both passed and failed tests. If you can't find any test run, return 0

Example 1:
```
@@ -67,7 +69,7 @@ class RunRegressionTestsStructuredNode:
"test_file_operation.py::test_edit_file",
"test_file_operation.py::test_create_file_already_exists"
],
"reproducing_test_fail_log": "" # ONLY output the log exact and complete test FAILURE log when test failure. Otherwise empty string,
"reproducing_test_fail_log": "",
"total_tests_run": 7
}}

@@ -76,19 +78,24 @@
- A single failing test means the test is not passing
- Include complete test output in failure log
- Do NOT output any log when no test was executed. ONLY output the exact and complete test FAILURE log when a test fails!
- Do not forget to return the total number of tests run! If tests were unable to run due to an error, do not count them!
- If you can't find any test run, return 0 for the total number of tests run!
"""
HUMAN_PROMPT = """
We have run the selected regression tests on the codebase.
The following regression tests were selected to run:
--- BEGIN SELECTED REGRESSION TESTS ---
{selected_regression_tests}
--- END SELECTED REGRESSION TESTS ---

Run Regression Tests Logs:
--- BEGIN LOG ---
{run_regression_tests_messages}
--- END LOG ---

Please analyze the logs and determine which regression tests passed. You should return the exact test identifiers
that we gave to you.
Don't forget to return the total number of tests run!
"""

def __init__(self, model: BaseChatModel):
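The schema change above gives `total_tests_run` a default so parsing no longer fails when the model omits the count entirely. A simplified stand-in for that pattern (field names mirror the diff; the class name is hypothetical):

```
from pydantic import BaseModel, Field

class RegressionTestsOutputSketch(BaseModel):  # simplified stand-in for the real schema
    passed_regression_tests: list[str] = Field(default_factory=list)
    regression_test_fail_log: str = Field(default="")
    # default=0 keeps structured-output parsing from failing when the count is missing.
    total_tests_run: int = Field(default=0)
```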
@@ -130,7 +130,7 @@ def invoke(self, query: str, max_refined_query_loop: int) -> Dict[str, Sequence[
- "context" (Sequence[Context]): A list of selected context snippets relevant to the query.
"""
# Set the recursion limit based on the maximum number of refined query loops
config = {"recursion_limit": max_refined_query_loop * 50}
config = {"recursion_limit": (max_refined_query_loop + 1) * 40}

input_state = {
"query": query,
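The new formula presumably guards the degenerate case: with the old `max_refined_query_loop * 50`, a loop count of 0 yields a recursion limit of 0, while `(max_refined_query_loop + 1) * 40` always leaves headroom. Plain arithmetic for a few values:

```
# Old vs. new recursion limits for a few loop counts (illustrative arithmetic, not project code):
for loops in (0, 1, 3):
    old_limit = loops * 50          # 0, 50, 150
    new_limit = (loops + 1) * 40    # 40, 80, 160
    print(loops, old_limit, new_limit)
```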
@@ -24,6 +24,8 @@ class IssueNotVerifiedBugState(TypedDict):

edit_patches: Annotated[Sequence[str], add]

deduplicated_patches: Sequence[str]

final_patch: str

run_regression_test: bool
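The new `deduplicated_patches` key is a plain field, unlike the accumulating `edit_patches` reducer above it: each write replaces the previous value instead of appending. An illustrative subset of such a state (not the full project state):

```
from operator import add
from typing import Annotated, Sequence, TypedDict

class StateSketch(TypedDict):  # illustrative subset, not the full project state
    # Reducer field: values returned by nodes are appended via operator.add.
    edit_patches: Annotated[Sequence[str], add]
    # Plain field: each node write replaces the previous value.
    deduplicated_patches: Sequence[str]
    final_patch: str
```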
@@ -21,7 +21,7 @@
from prometheus.lang_graph.nodes.issue_bug_analyzer_message_node import IssueBugAnalyzerMessageNode
from prometheus.lang_graph.nodes.issue_bug_analyzer_node import IssueBugAnalyzerNode
from prometheus.lang_graph.nodes.issue_bug_context_message_node import IssueBugContextMessageNode
from prometheus.lang_graph.nodes.noop_node import NoopNode
from prometheus.lang_graph.nodes.patch_normalization_node import PatchNormalizationNode
from prometheus.lang_graph.nodes.reset_messages_node import ResetMessagesNode
from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState

@@ -37,8 +37,6 @@ def __init__(
neo4j_driver: neo4j.Driver,
max_token_per_neo4j_result: int,
):
noop_node = NoopNode()

issue_bug_context_message_node = IssueBugContextMessageNode()
context_retrieval_subgraph_node = ContextRetrievalSubgraphNode(
model=base_model,
@@ -66,20 +64,22 @@ reset_issue_bug_analyzer_messages_node = ResetMessagesNode("issue_bug_analyzer_messages")
reset_issue_bug_analyzer_messages_node = ResetMessagesNode("issue_bug_analyzer_messages")
reset_edit_messages_node = ResetMessagesNode("edit_messages")

# Patch Normalization Node
patch_normalization_node = PatchNormalizationNode()

# Get pass regression test patch subgraph node
get_pass_regression_test_patch_subgraph_node = GetPassRegressionTestPatchSubgraphNode(
model=base_model,
container=container,
git_repo=git_repo,
testing_patch_key="edit_patches",
testing_patch_key="deduplicated_patches",
is_testing_patch_list=True,
)

# Final patch selection node
final_patch_selection_node = FinalPatchSelectionNode(advanced_model)

workflow = StateGraph(IssueNotVerifiedBugState)
workflow.add_node("noop_node", noop_node)

workflow.add_node("issue_bug_context_message_node", issue_bug_context_message_node)
workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node)
@@ -98,6 +98,8 @@
)
workflow.add_node("reset_edit_messages_node", reset_edit_messages_node)

workflow.add_node("patch_normalization_node", patch_normalization_node)

workflow.add_node(
"get_pass_regression_test_patch_subgraph_node",
get_pass_regression_test_patch_subgraph_node,
@@ -124,11 +126,11 @@
lambda state: len(state["edit_patches"]) < state["number_of_candidate_patch"],
{
True: "git_reset_node",
False: "noop_node",
False: "patch_normalization_node",
},
)
workflow.add_conditional_edges(
"noop_node",
"patch_normalization_node",
lambda state: state["run_regression_test"],
{
True: "get_pass_regression_test_patch_subgraph_node",
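With NoopNode gone, `patch_normalization_node` takes its place in the routing: every batch of candidate patches is normalized and deduplicated before either regression testing or final selection. A hedged sketch of the conditional-edge wiring this diff adds; the node bodies are stubs, and the False target is a placeholder since the real one lies below the truncated hunk:

```
from typing import TypedDict

from langgraph.graph import StateGraph

class MiniState(TypedDict):  # just enough state for the routing sketch
    run_regression_test: bool
    deduplicated_patches: list
    final_patch: str

workflow = StateGraph(MiniState)
workflow.add_node("patch_normalization_node", lambda state: {"deduplicated_patches": []})
workflow.add_node("get_pass_regression_test_patch_subgraph_node", lambda state: {})
workflow.add_node("final_patch_selection_node", lambda state: {"final_patch": ""})

# Route on the boolean state key, as the edges added in this diff do.
workflow.add_conditional_edges(
    "patch_normalization_node",
    lambda state: state["run_regression_test"],
    {
        True: "get_pass_regression_test_patch_subgraph_node",
        False: "final_patch_selection_node",  # placeholder target for the sketch
    },
)
```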