Skip to content

Commit 6fbd38a

Browse files
authored
doc/search_methods (neo4j#68)
* Improve search method documentation * Merge with upstream/main * Fix import for older versions of Python * Ruff * Rename method according to discussion * Update docstring in the right method * Update CHANGELOG
1 parent 6d69a95 commit 6fbd38a

File tree

9 files changed

+147
-27
lines changed

9 files changed

+147
-27
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
- Added PromptTemplate and RagTemplate for customizable prompt generation.
1212
- Added LLMInterface with implementation for OpenAI LLM.
1313
- Updated project configuration to support multiple Python versions (3.8 to 3.12) in CI workflows.
14+
- Improved developer experience by copying the docstring from the `Retriever.get_search_results` method to the `Retriever.search` method
1415

1516
### Changed
1617
- Refactored import paths for retrievers to neo4j_genai.retrievers.
1718
- Implemented exception chaining for all re-raised exceptions to improve stack trace readability.
1819
- Made error messages in `index.py` more consistent.
20+
- Renamed `Retriever._get_search_results` to `Retriever.get_search_results`
1921

2022
## 0.2.0
2123

docs/source/api.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,27 @@ VectorRetriever
1818
===============
1919

2020
.. autoclass:: neo4j_genai.retrievers.vector.VectorRetriever
21-
:members:
21+
:members: search
2222

2323
VectorCypherRetriever
2424
=====================
2525

2626
.. autoclass:: neo4j_genai.retrievers.vector.VectorCypherRetriever
27-
:members:
27+
:members: search
2828

2929

3030
HybridRetriever
3131
===============
3232

3333
.. autoclass:: neo4j_genai.retrievers.hybrid.HybridRetriever
34-
:members:
34+
:members: search
3535

3636

3737
HybridCypherRetriever
3838
=====================
3939

4040
.. autoclass:: neo4j_genai.retrievers.hybrid.HybridCypherRetriever
41-
:members:
41+
:members: search
4242

4343

4444

@@ -53,14 +53,14 @@ WeaviateNeo4jRetriever
5353
======================
5454

5555
.. autoclass:: neo4j_genai.retrievers.external.weaviate.weaviate.WeaviateNeo4jRetriever
56-
:members:
56+
:members: search
5757

5858

5959
PineconeNeo4jRetriever
6060
======================
6161

6262
.. autoclass:: neo4j_genai.retrievers.external.pinecone.pinecone.PineconeNeo4jRetriever
63-
:members:
63+
:members: search
6464

6565

6666
******

src/neo4j_genai/retrievers/base.py

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,66 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from __future__ import annotations
16-
from abc import ABC, abstractmethod
17-
from typing import Optional, Callable, Any
16+
import types
17+
import inspect
18+
from abc import ABC, abstractmethod, ABCMeta
19+
from typing import Optional, Callable, Any, TypeVar
20+
from typing_extensions import ParamSpec
21+
1822
import neo4j
1923

2024
from neo4j_genai.types import RawSearchResult, RetrieverResult, RetrieverResultItem
2125
from neo4j_genai.exceptions import Neo4jVersionError
2226

27+
T = ParamSpec("T")
28+
P = TypeVar("P")
29+
30+
31+
def copy_function(f: Callable[T, P]) -> Callable[T, P]:
32+
"""Based on https://stackoverflow.com/a/30714299"""
33+
g = types.FunctionType(
34+
f.__code__,
35+
f.__globals__,
36+
name=f.__name__,
37+
argdefs=f.__defaults__,
38+
closure=f.__closure__,
39+
)
40+
# in case f was given attrs (note this dict is a shallow copy):
41+
g.__dict__.update(f.__dict__)
42+
return g
43+
44+
45+
class RetrieverMetaclass(ABCMeta):
46+
"""This metaclass is used to copy the docstring from the
47+
`get_search_results` method, instantiated in all subclasses,
48+
to the `search` method in the base class.
49+
"""
50+
51+
def __new__(
52+
meta, name: str, bases: tuple[type, ...], attrs: dict[str, Any]
53+
) -> type:
54+
if "search" in attrs:
55+
# search method was explicitly overridden, do nothing
56+
return type.__new__(meta, name, bases, attrs)
57+
# otherwise, we copy the signature and doc of the get_search_results
58+
# method to a copy of the search method
59+
get_search_results_method = attrs.get("get_search_results")
60+
search_method = None
61+
for b in bases:
62+
search_method = getattr(b, "search", None)
63+
if search_method is not None:
64+
break
65+
if search_method and get_search_results_method:
66+
new_search_method = copy_function(search_method)
67+
new_search_method.__doc__ = get_search_results_method.__doc__
68+
new_search_method.__signature__ = inspect.signature( # type: ignore
69+
get_search_results_method
70+
)
71+
attrs["search"] = new_search_method
72+
return type.__new__(meta, name, bases, attrs)
73+
2374

24-
class Retriever(ABC):
75+
class Retriever(ABC, metaclass=RetrieverMetaclass):
2576
"""
2677
Abstract class for Neo4j retrievers
2778
"""
@@ -78,11 +129,11 @@ def _fetch_index_infos(self) -> None:
78129
raise Exception(f"No index with name {self.index_name} found") from e
79130

80131
def search(self, *args: Any, **kwargs: Any) -> RetrieverResult:
132+
"""Search method. Call the `get_search_results` method that returns
133+
a list of `neo4j.Record`, and format them using the function returned by
134+
`get_result_formatter` to return `RetrieverResult`.
81135
"""
82-
Search method. Call the get_search_result method that returns
83-
a list of neo4j.Record, and format them to return RetrieverResult.
84-
"""
85-
raw_result = self._get_search_results(*args, **kwargs)
136+
raw_result = self.get_search_results(*args, **kwargs)
86137
formatter = self.get_result_formatter()
87138
search_items = [formatter(record) for record in raw_result.records]
88139
metadata = raw_result.metadata or {}
@@ -93,7 +144,20 @@ def search(self, *args: Any, **kwargs: Any) -> RetrieverResult:
93144
)
94145

95146
@abstractmethod
96-
def _get_search_results(self, *args: Any, **kwargs: Any) -> RawSearchResult:
147+
def get_search_results(self, *args: Any, **kwargs: Any) -> RawSearchResult:
148+
"""This method must be implemented in each child class. It will
149+
receive the same parameters provided to the public interface via
150+
the `search` method, after validation. It returns a `RawSearchResult`
151+
object which comprises a list of `neo4j.Record` objects and an optional
152+
`metadata` dictionary that can contain retriever-level information.
153+
154+
Note that, even though this method is not intended to be called from
155+
outside the class, we make it public to make it clearer for the developers
156+
that it should be implemented in child classes.
157+
158+
Returns:
159+
RawSearchResult: List of Neo4j Records and optional metadata dict
160+
"""
97161
pass
98162

99163
def get_result_formatter(self) -> Callable[[neo4j.Record], RetrieverResultItem]:
@@ -127,7 +191,7 @@ def __init__(
127191
self.id_property_neo4j = id_property_neo4j
128192

129193
@abstractmethod
130-
def _get_search_results(
194+
def get_search_results(
131195
self,
132196
query_vector: Optional[list[float]] = None,
133197
query_text: Optional[str] = None,
@@ -137,7 +201,7 @@ def _get_search_results(
137201
"""
138202
139203
Returns:
140-
list[neo4j.Record]: List of Neo4j Records
204+
RawSearchResult: List of Neo4j Records and optional metadata dict
141205
142206
"""
143207
pass

src/neo4j_genai/retrievers/external/pinecone/pinecone.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def __init__(
129129
self.retrieval_query = validated_data.retrieval_query
130130
self.result_formatter = validated_data.result_formatter
131131

132-
def _get_search_results(
132+
def get_search_results(
133133
self,
134134
query_vector: Optional[list[float]] = None,
135135
query_text: Optional[str] = None,

src/neo4j_genai/retrievers/external/weaviate/weaviate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(
117117
self.retrieval_query = validated_data.retrieval_query
118118
self.result_formatter = validated_data.result_formatter
119119

120-
def _get_search_results(
120+
def get_search_results(
121121
self,
122122
query_vector: Optional[list[float]] = None,
123123
query_text: Optional[str] = None,

src/neo4j_genai/retrievers/hybrid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def default_record_formatter(self, record: neo4j.Record) -> RetrieverResultItem:
115115
metadata=metadata,
116116
)
117117

118-
def _get_search_results(
118+
def get_search_results(
119119
self,
120120
query_text: str,
121121
query_vector: Optional[list[float]] = None,
@@ -243,7 +243,7 @@ def __init__(
243243
)
244244
self.result_formatter = result_formatter
245245

246-
def _get_search_results(
246+
def get_search_results(
247247
self,
248248
query_text: str,
249249
query_vector: Optional[list[float]] = None,

src/neo4j_genai/retrievers/text2cypher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __init__(
9595
f"Failed to fetch schema for Text2CypherRetriever: {error_message}"
9696
) from e
9797

98-
def _get_search_results(
98+
def get_search_results(
9999
self,
100100
query_text: str,
101101
) -> RawSearchResult:

src/neo4j_genai/retrievers/vector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def default_record_formatter(self, record: neo4j.Record) -> RetrieverResultItem:
118118
metadata=metadata,
119119
)
120120

121-
def _get_search_results(
121+
def get_search_results(
122122
self,
123123
query_vector: Optional[list[float]] = None,
124124
query_text: Optional[str] = None,
@@ -245,7 +245,7 @@ def __init__(
245245
self._embedding_dimension = None
246246
self._fetch_index_infos()
247247

248-
def _get_search_results(
248+
def get_search_results(
249249
self,
250250
query_vector: Optional[list[float]] = None,
251251
query_text: Optional[str] = None,

tests/unit/retrievers/test_base.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@
1414
# limitations under the License.
1515
from __future__ import annotations # Reminder: May be removed after Python 3.9 is EOL.
1616

17+
import inspect
18+
1719
import pytest
1820

21+
from typing import Union, Any
22+
from unittest.mock import MagicMock, patch
23+
1924
from neo4j_genai.exceptions import Neo4jVersionError
2025
from neo4j_genai.retrievers.base import Retriever
21-
from unittest.mock import MagicMock
22-
from typing import Union, Any
26+
from neo4j_genai.types import RawSearchResult, RetrieverResult
2327

2428

2529
@pytest.mark.parametrize(
@@ -37,12 +41,62 @@ def test_retriever_version_support(
3741
expected_exception: Union[type[ValueError], None],
3842
) -> None:
3943
class MockRetriever(Retriever):
40-
def _get_search_results(self, *args: Any, **kwargs: Any) -> None: # type: ignore
41-
pass
44+
def get_search_results(self, *args: Any, **kwargs: Any) -> RawSearchResult:
45+
return RawSearchResult(records=[])
4246

4347
driver.execute_query.return_value = [[{"versions": db_version}], None, None]
4448
if expected_exception:
4549
with pytest.raises(expected_exception):
4650
MockRetriever(driver=driver)
4751
else:
4852
MockRetriever(driver=driver)
53+
54+
55+
@patch("neo4j_genai.retrievers.base.Retriever._verify_version")
56+
def test_retriever_search_docstring_copied(
57+
_verify_version_mock: MagicMock,
58+
driver: MagicMock,
59+
) -> None:
60+
class MockRetriever(Retriever):
61+
def get_search_results(self, query: str, top_k: int = 10) -> RawSearchResult:
62+
"""My fabulous docstring"""
63+
return RawSearchResult(records=[])
64+
65+
retriever = MockRetriever(driver=driver)
66+
assert retriever.search.__doc__ == "My fabulous docstring"
67+
signature = inspect.signature(retriever.search)
68+
assert "query" in signature.parameters
69+
query_param = signature.parameters["query"]
70+
assert query_param.default == query_param.empty
71+
assert query_param.annotation == "str"
72+
assert "top_k" in signature.parameters
73+
top_k_param = signature.parameters["top_k"]
74+
assert top_k_param.default == 10
75+
assert top_k_param.annotation == "int"
76+
77+
78+
@patch("neo4j_genai.retrievers.base.Retriever._verify_version")
79+
def test_retriever_search_docstring_unchanged(
80+
_verify_version_mock: MagicMock,
81+
driver: MagicMock,
82+
) -> None:
83+
class MockRetrieverForNoise(Retriever):
84+
def get_search_results(self, query: str, top_k: int = 10) -> RawSearchResult:
85+
"""My fabulous docstring"""
86+
return RawSearchResult(records=[])
87+
88+
class MockRetriever(Retriever):
89+
def get_search_results(self, *args: Any, **kwargs: Any) -> RawSearchResult:
90+
return RawSearchResult(records=[])
91+
92+
def search(self, query: str, top_k: int = 10) -> RetrieverResult:
93+
"""My fabulous docstring that I do not want to be updated"""
94+
return RetrieverResult(items=[])
95+
96+
assert MockRetrieverForNoise.search is not MockRetriever.search
97+
98+
retriever = MockRetriever(driver=driver)
99+
assert (
100+
retriever.search.__doc__
101+
== "My fabulous docstring that I do not want to be updated"
102+
)

0 commit comments

Comments
 (0)