Skip to content

Commit af27578

Browse files
authored
Fixed query parameter bug in GraphRAG class (#109)
1 parent 6e544df commit af27578

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

src/neo4j_genai/generation/graphrag.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,13 @@ def search(
8585
DeprecationWarning,
8686
stacklevel=2,
8787
)
88-
elif isinstance(query, str):
89-
warnings.warn(
90-
"'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
91-
DeprecationWarning,
92-
stacklevel=2,
93-
)
94-
query_text = query
88+
elif isinstance(query, str):
89+
warnings.warn(
90+
"'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
91+
DeprecationWarning,
92+
stacklevel=2,
93+
)
94+
query_text = query
9595

9696
validated_data = RagSearchModel(
9797
query_text=query_text,

tests/unit/test_graphrag.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from unittest.mock import MagicMock
16+
from warnings import catch_warnings
1617

1718
import pytest
1819
from neo4j_genai.exceptions import RagInitializationError, SearchValidationError
@@ -21,6 +22,7 @@
2122
from neo4j_genai.generation.types import RagResultModel
2223
from neo4j_genai.llm import LLMResponse
2324
from neo4j_genai.types import RetrieverResult, RetrieverResultItem
25+
from pydantic import ValidationError
2426

2527

2628
def test_graphrag_prompt_template() -> None:
@@ -99,3 +101,21 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non
99101
with pytest.raises(SearchValidationError) as excinfo:
100102
rag.search(10) # type: ignore
101103
assert "Input should be a valid string" in str(excinfo)
104+
105+
106+
def test_graphrag_search_query_deprecation_warning(
107+
retriever_mock: MagicMock, llm: MagicMock
108+
) -> None:
109+
with catch_warnings(record=True) as warn_list:
110+
rag = GraphRAG(
111+
retriever=retriever_mock,
112+
llm=llm,
113+
)
114+
with pytest.raises(ValidationError):
115+
rag.search(query="Some query text")
116+
117+
assert len(warn_list) == 1
118+
assert (
119+
str(warn_list[0].message)
120+
== "'query' is deprecated and will be removed in a future version, please use 'query_text' instead."
121+
)

0 commit comments

Comments
 (0)