Skip to content

Commit 088488f

Browse files
Ruff
1 parent 9732246 commit 088488f

File tree

3 files changed

+33
-38
lines changed

3 files changed

+33
-38
lines changed

examples/customize/build_graph/components/resolvers/spacy_entity_resolver_pre_filter.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ async def main(driver: neo4j.Driver) -> None:
2222
driver,
2323
# let's filter all entities that belong to a certain docId
2424
filter_query="WHERE (entity)-[:FROM_CHUNK]->(:Chunk)-[:FROM_DOCUMENT]->(doc:"
25-
"Document {id = 'docId'}",
25+
"Document {id = 'docId'}",
2626
# optionally, change the properties used for resolution (default is "name")
2727
# resolve_properties=["name", "ssn"],
2828
# the similarity threshold (default is 0.8)

src/neo4j_graphrag/experimental/components/resolver.py

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -35,9 +35,9 @@ class EntityResolver(Component, abc.ABC):
3535
"""
3636

3737
def __init__(
38-
self,
39-
driver: neo4j.Driver,
40-
filter_query: Optional[str] = None,
38+
self,
39+
driver: neo4j.Driver,
40+
filter_query: Optional[str] = None,
4141
) -> None:
4242
self.driver = driver_config.override_user_agent(driver)
4343
self.filter_query = filter_query
@@ -74,11 +74,11 @@ class SinglePropertyExactMatchResolver(EntityResolver):
7474
"""
7575

7676
def __init__(
77-
self,
78-
driver: neo4j.Driver,
79-
filter_query: Optional[str] = None,
80-
resolve_property: str = "name",
81-
neo4j_database: Optional[str] = None,
77+
self,
78+
driver: neo4j.Driver,
79+
filter_query: Optional[str] = None,
80+
resolve_property: str = "name",
81+
neo4j_database: Optional[str] = None,
8282
) -> None:
8383
super().__init__(driver, filter_query)
8484
self.resolve_property = resolve_property
@@ -174,13 +174,13 @@ class SpaCySemanticMatchResolver(EntityResolver):
174174
"""
175175

176176
def __init__(
177-
self,
178-
driver: neo4j.Driver,
179-
filter_query: Optional[str] = None,
180-
resolve_properties: Optional[List[str]] = None,
181-
similarity_threshold: float = 0.8,
182-
spacy_model: str = "en_core_web_lg",
183-
neo4j_database: Optional[str] = None,
177+
self,
178+
driver: neo4j.Driver,
179+
filter_query: Optional[str] = None,
180+
resolve_properties: Optional[List[str]] = None,
181+
similarity_threshold: float = 0.8,
182+
spacy_model: str = "en_core_web_lg",
183+
neo4j_database: Optional[str] = None,
184184
) -> None:
185185
super().__init__(driver, filter_query)
186186
self.resolve_properties = resolve_properties or ["name"]
@@ -230,8 +230,9 @@ async def run(self) -> ResolutionStats:
230230
node_embeddings = {}
231231
for ent in entities:
232232
# concatenate all textual properties (if non-null) into a single string
233-
texts = [str(ent[p]) for p in self.resolve_properties if
234-
p in ent and ent[p]]
233+
texts = [
234+
str(ent[p]) for p in self.resolve_properties if p in ent and ent[p]
235+
]
235236
combined_text = " ".join(texts).strip()
236237
if combined_text:
237238
node_embeddings[ent["id"]] = self.nlp(combined_text).vector
@@ -258,8 +259,9 @@ async def run(self) -> ResolutionStats:
258259
"YIELD node RETURN id(node)"
259260
)
260261
result, _, _ = self.driver.execute_query(
261-
merge_query, {"ids": list(node_id_set)},
262-
database_=self.neo4j_database
262+
merge_query,
263+
{"ids": list(node_id_set)},
264+
database_=self.neo4j_database,
263265
)
264266
merged_count += len(result)
265267

@@ -307,7 +309,7 @@ def _load_or_download_spacy_model(model_name: str):
307309
except OSError as e:
308310
# The exact error message can differ slightly depending on spaCy version,
309311
# so you may want to be broader or narrower with handling logic:
310-
if 'doesn\'t seem to be a Python package or a valid path' in str(e):
312+
if "doesn't seem to be a Python package or a valid path" in str(e):
311313
print(f"Model '{model_name}' not found. Downloading...")
312314
spacy.cli.download(model_name)
313315
return spacy.load(model_name)

tests/unit/experimental/components/test_resolver.py

Lines changed: 10 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -58,8 +58,8 @@ async def test_simple_resolver_custom_filter(driver: MagicMock) -> None:
5858
)
5959
]
6060
)
61-
62-
61+
62+
6363
@pytest.mark.asyncio
6464
async def test_spacy_resolver_match_on_name_property(driver: MagicMock) -> None:
6565
driver.execute_query.side_effect = [
@@ -70,7 +70,7 @@ async def test_spacy_resolver_match_on_name_property(driver: MagicMock) -> None:
7070
"lab": "Person",
7171
"labelCluster": [
7272
{"id": 1, "name": "Alice"},
73-
{"id": 2, "name": "Alice"}
73+
{"id": 2, "name": "Alice"},
7474
],
7575
}
7676
)
@@ -94,6 +94,7 @@ async def test_spacy_resolver_match_on_name_property(driver: MagicMock) -> None:
9494

9595
assert driver.execute_query.call_count == 2
9696

97+
9798
@pytest.mark.asyncio
9899
async def test_spacy_resolver_no_merge(driver: MagicMock) -> None:
99100
driver.execute_query.side_effect = [
@@ -104,7 +105,7 @@ async def test_spacy_resolver_no_merge(driver: MagicMock) -> None:
104105
"lab": "Person",
105106
"labelCluster": [
106107
{"id": 1, "name": "Alice"},
107-
{"id": 2, "name": "Bob"}
108+
{"id": 2, "name": "Bob"},
108109
],
109110
}
110111
)
@@ -122,9 +123,10 @@ async def test_spacy_resolver_no_merge(driver: MagicMock) -> None:
122123

123124
assert driver.execute_query.call_count == 1
124125

126+
125127
@pytest.mark.asyncio
126128
async def test_spacy_resolver_match_on_multiple_text_properties(
127-
driver: MagicMock
129+
driver: MagicMock,
128130
) -> None:
129131
driver.execute_query.side_effect = [
130132
(
@@ -133,16 +135,8 @@ async def test_spacy_resolver_match_on_multiple_text_properties(
133135
{
134136
"lab": "Person",
135137
"labelCluster": [
136-
{
137-
"id": 10,
138-
"name": "John Smith",
139-
"ssn": "23-45-6789"
140-
},
141-
{
142-
"id": 11,
143-
"name": "Jonathan Smith",
144-
"ssn": "23-45-6789"
145-
}
138+
{"id": 10, "name": "John Smith", "ssn": "23-45-6789"},
139+
{"id": 11, "name": "Jonathan Smith", "ssn": "23-45-6789"},
146140
],
147141
}
148142
)
@@ -158,8 +152,7 @@ async def test_spacy_resolver_match_on_multiple_text_properties(
158152
]
159153

160154
resolver = SpaCySemanticMatchResolver(
161-
driver=driver,
162-
resolve_properties=["name", "ssn"]
155+
driver=driver, resolve_properties=["name", "ssn"]
163156
)
164157
res = await resolver.run()
165158
assert isinstance(res, ResolutionStats)

0 commit comments

Comments
 (0)