|
13 | 13 | # See the License for the specific language governing permissions and
|
14 | 14 | # limitations under the License.
|
15 | 15 | from unittest.mock import MagicMock
|
| 16 | +from warnings import catch_warnings |
16 | 17 |
|
17 | 18 | import pytest
|
18 | 19 | from neo4j_genai.exceptions import RagInitializationError, SearchValidationError
|
|
21 | 22 | from neo4j_genai.generation.types import RagResultModel
|
22 | 23 | from neo4j_genai.llm import LLMResponse
|
23 | 24 | from neo4j_genai.types import RetrieverResult, RetrieverResultItem
|
| 25 | +from pydantic import ValidationError |
24 | 26 |
|
25 | 27 |
|
26 | 28 | def test_graphrag_prompt_template() -> None:
|
@@ -99,3 +101,21 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non
|
99 | 101 | with pytest.raises(SearchValidationError) as excinfo:
|
100 | 102 | rag.search(10) # type: ignore
|
101 | 103 | assert "Input should be a valid string" in str(excinfo)
|
| 104 | + |
| 105 | + |
| 106 | +def test_graphrag_search_query_deprecation_warning( |
| 107 | + retriever_mock: MagicMock, llm: MagicMock |
| 108 | +) -> None: |
| 109 | + with catch_warnings(record=True) as warn_list: |
| 110 | + rag = GraphRAG( |
| 111 | + retriever=retriever_mock, |
| 112 | + llm=llm, |
| 113 | + ) |
| 114 | + with pytest.raises(ValidationError): |
| 115 | + rag.search(query="Some query text") |
| 116 | + |
| 117 | + assert len(warn_list) == 1 |
| 118 | + assert ( |
| 119 | + str(warn_list[0].message) |
| 120 | + == "'query' is deprecated and will be removed in a future version, please use 'query_text' instead." |
| 121 | + ) |
0 commit comments