Commit f25eeb5

Add Fuzzy match resolver and refactor code
1 parent f041b59 commit f25eeb5

File tree: 1 file changed

src/neo4j_graphrag/experimental/components/resolver.py

Lines changed: 113 additions & 48 deletions
@@ -15,15 +15,17 @@
 import abc
 import logging
 from itertools import combinations
-from typing import Any, Optional, List
+from typing import Any, List, Optional

-import neo4j
 import numpy as np
-from numpy.typing import NDArray
+import rapidfuzz.fuzz
 import spacy
-from spacy.language import Language
+from numpy.typing import NDArray
+from rapidfuzz import utils
 from spacy.cli.download import download as spacy_download
+from spacy.language import Language

+import neo4j
 from neo4j_graphrag.experimental.components.types import ResolutionStats
 from neo4j_graphrag.experimental.pipeline import Component
 from neo4j_graphrag.utils import driver_config
@@ -148,34 +150,28 @@ async def run(self) -> ResolutionStats:
         )


-class SpaCySemanticMatchResolver(EntityResolver):
+class BasePropertySimilarityResolver(EntityResolver, abc.ABC):
     """
+    Base class for similarity-based matching of properties for entity resolution.
     Resolve entities with same label and similar set of textual properties (default is
-    ["name"]) based on spaCy's static embeddings and cosine similarities.
+    ["name"]):
+    - Group entities by label
+    - Concatenate the specified textual properties
+    - Compute similarity between each pair
+    - Consolidate overlapping sets
+    - Merge similar nodes via APOC (See apoc.refactor.mergeNodes documentation for more
+      details).
+
+    Subclasses implement `compute_similarity` based on different strategies, and return
+    a similarity score between 0 and 1.

     Args:
         driver (neo4j.Driver): The Neo4j driver to connect to the database.
         filter_query (Optional[str]): Optional Cypher WHERE clause to reduce the resolution scope.
         resolve_properties (Optional[List[str]]): The list of properties to consider for embeddings Defaults to ["name"].
         similarity_threshold (float): The similarity threshold above which nodes are merged. Defaults to 0.8.
-        spacy_model (str): The name of the spaCy model to load. Defaults to "en_core_web_lg".
         neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).

-    Example:
-
-    .. code-block:: python
-
-        from neo4j import GraphDatabase
-        from neo4j_graphrag.experimental.components.resolver import SinglePropertyExactMatchResolver
-
-        URI = "neo4j://localhost:7687"
-        AUTH = ("neo4j", "password")
-        DATABASE = "neo4j"
-
-        driver = GraphDatabase.driver(URI, auth=AUTH)
-        resolver = SinglePropertyExactMatchResolver(driver=driver, neo4j_database=DATABASE)
-        await resolver.run()  # no expected parameters
-
     """

     def __init__(
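To illustrate the extension point introduced by this hunk (an editorial sketch, not part of the commit), a subclass of BasePropertySimilarityResolver only needs to implement compute_similarity and return a score between 0 and 1; the base class's run() handles grouping, pairwise comparison, consolidation, and the APOC merge. The class name JaccardMatchResolver below is hypothetical:

    # Illustrative sketch only (not part of this commit): a hypothetical subclass
    # showing how the new BasePropertySimilarityResolver contract can be used.
    from neo4j_graphrag.experimental.components.resolver import (
        BasePropertySimilarityResolver,
    )


    class JaccardMatchResolver(BasePropertySimilarityResolver):
        def compute_similarity(self, text_a: str, text_b: str) -> float:
            # Jaccard overlap of lowercased tokens: |A ∩ B| / |A ∪ B|, in [0, 1]
            tokens_a = set(text_a.lower().split())
            tokens_b = set(text_b.lower().split())
            if not tokens_a or not tokens_b:
                return 0.0
            return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)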
@@ -184,22 +180,21 @@ def __init__(
         filter_query: Optional[str] = None,
         resolve_properties: Optional[List[str]] = None,
         similarity_threshold: float = 0.8,
-        spacy_model: str = "en_core_web_lg",
         neo4j_database: Optional[str] = None,
     ) -> None:
         super().__init__(driver, filter_query)
         self.resolve_properties = resolve_properties or ["name"]
         self.similarity_threshold = similarity_threshold
         self.neo4j_database = neo4j_database
-        self.nlp = self._load_or_download_spacy_model(spacy_model)

-    async def run(self) -> ResolutionStats:
-        """Resolve entities based on the following rules:
-        For each entity label, entities with similar 'resolve_properties'
-        (cosine similarity on embedding vectors) are merged into a single node.
-
-        See apoc.refactor.mergeNodes documentation for more details.
+    @abc.abstractmethod
+    def compute_similarity(self, text_a: str, text_b: str) -> float:
         """
+        Compute similarity between two textual strings.
+        """
+        pass
+
+    async def run(self) -> ResolutionStats:
         match_query = "MATCH (entity:__Entity__)"
         if self.filter_query:
             match_query += f" {self.filter_query}"
@@ -212,7 +207,7 @@ async def run(self) -> ResolutionStats:
         # matches extracted entities
         # filters entities if filter_query is provided
         # unwinds labels to skip reserved ones
-        # collects all properties needed for embeddings
+        # collects all properties needed for the calculation of similarity
         query = f"""
         {match_query}
         UNWIND labels(entity) AS lab
@@ -224,41 +219,37 @@ async def run(self) -> ResolutionStats:

         records, _, _ = self.driver.execute_query(query, database_=self.neo4j_database)

-        total_entities_embedded = 0
+        total_entities = 0
         total_merged_nodes = 0

         # for each row, 'lab' is the label, 'labelCluster' is a list of dicts (id + textual properties)
         for row in records:
             entities = row["labelCluster"]

-            # build node embeddings
-            node_embeddings = {}
+            node_texts = {}
             for ent in entities:
                 # concatenate all textual properties (if non-null) into a single string
                 texts = [
                     str(ent[p]) for p in self.resolve_properties if p in ent and ent[p]
                 ]
                 combined_text = " ".join(texts).strip()
                 if combined_text:
-                    node_embeddings[ent["id"]] = self.nlp(combined_text).vector
-            total_entities_embedded += len(node_embeddings)
+                    node_texts[ent["id"]] = combined_text
+            total_entities += len(node_texts)

-            # identify pairs to merge
+            # compute pairwise similarity and mark those above the threshold
             pairs_to_merge = []
-            for (id1, emb1), (id2, emb2) in combinations(node_embeddings.items(), 2):
-                sim = self._cosine_similarity(
-                    np.asarray(emb1, dtype=np.float64),
-                    np.asarray(emb2, dtype=np.float64),
-                )
+            for (id1, text1), (id2, text2) in combinations(node_texts.items(), 2):
+                sim = self.compute_similarity(text1, text2)
                 if sim >= self.similarity_threshold:
                     pairs_to_merge.append({id1, id2})

-            # consolidate overlapping sets of node IDs
-            resolved_sets = self._consolidate_sets(pairs_to_merge)
+            # consolidate overlapping pairs into unique merge sets.
+            merged_sets = self._consolidate_sets(pairs_to_merge)

-            # perform merges in the db using APOC
+            # perform merges in the db using APOC.
             merged_count = 0
-            for node_id_set in resolved_sets:
+            for node_id_set in merged_sets:
                 if len(node_id_set) > 1:
                     merge_query = (
                         "MATCH (n) WHERE id(n) IN $ids "
@@ -272,11 +263,10 @@ async def run(self) -> ResolutionStats:
                         database_=self.neo4j_database,
                     )
                     merged_count += len(result)
-
             total_merged_nodes += merged_count

         return ResolutionStats(
-            number_of_nodes_to_resolve=total_entities_embedded,
+            number_of_nodes_to_resolve=total_entities,
             number_of_created_nodes=total_merged_nodes,
         )

@@ -296,6 +286,63 @@ def _consolidate_sets(pairs: List[set[str]]) -> List[set[str]]:
                 consolidated.append(set(pair))
         return consolidated

+
+class SpaCySemanticMatchResolver(BasePropertySimilarityResolver):
+    """
+    Resolve entities with same label and similar set of textual properties (default is
+    ["name"]) based on spaCy's static embeddings and cosine similarities.
+
+    Args:
+        driver (neo4j.Driver): The Neo4j driver to connect to the database.
+        filter_query (Optional[str]): Optional Cypher WHERE clause to reduce the resolution scope.
+        resolve_properties (Optional[List[str]]): The list of properties to consider for embeddings Defaults to ["name"].
+        similarity_threshold (float): The similarity threshold above which nodes are merged. Defaults to 0.8.
+        spacy_model (str): The name of the spaCy model to load. Defaults to "en_core_web_lg".
+        neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
+
+    Example:
+
+    .. code-block:: python
+
+        from neo4j import GraphDatabase
+        from neo4j_graphrag.experimental.components.resolver import SpaCySemanticMatchResolver
+
+        URI = "neo4j://localhost:7687"
+        AUTH = ("neo4j", "password")
+        DATABASE = "neo4j"
+
+        driver = GraphDatabase.driver(URI, auth=AUTH)
+        resolver = SpaCySemanticMatchResolver(driver=driver, neo4j_database=DATABASE)
+        await resolver.run()  # no expected parameters
+
+    """
+
+    def __init__(
+        self,
+        driver: neo4j.Driver,
+        filter_query: Optional[str] = None,
+        resolve_properties: Optional[List[str]] = None,
+        similarity_threshold: float = 0.8,
+        spacy_model: str = "en_core_web_lg",
+        neo4j_database: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            driver,
+            filter_query,
+            resolve_properties,
+            similarity_threshold,
+            neo4j_database,
+        )
+        self.nlp = self._load_or_download_spacy_model(spacy_model)
+
+    def compute_similarity(self, text_a: str, text_b: str) -> float:
+        emb1 = self.nlp(text_a).vector
+        emb2 = self.nlp(text_b).vector
+        sim = self._cosine_similarity(
+            np.asarray(emb1, dtype=np.float64), np.asarray(emb2, dtype=np.float64)
+        )
+        return sim
+
     @staticmethod
     def _cosine_similarity(
         vec1: NDArray[np.float64], vec2: NDArray[np.float64]
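compute_similarity above delegates to the existing _cosine_similarity static method, whose body lies outside this diff. A typical implementation of that formula, sim = (v1 · v2) / (||v1|| ||v2||), looks like the following sketch (an assumption for illustration, not necessarily the code in the file):

    import numpy as np
    from numpy.typing import NDArray


    def cosine_similarity(vec1: NDArray[np.float64], vec2: NDArray[np.float64]) -> float:
        # Cosine similarity: dot product divided by the product of the norms.
        # Returning 0.0 when either vector is all zeros avoids division by zero
        # (a defensive choice made for this sketch).
        norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if norm_product == 0:
            return 0.0
        return float(np.dot(vec1, vec2) / norm_product)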
@@ -324,3 +371,21 @@ def _load_or_download_spacy_model(model_name: str) -> Language:
             return spacy.load(model_name)
         else:
             raise e
+
+
+class FuzzyMatchResolver(BasePropertySimilarityResolver):
+    """
+    Resolve entities with the same label and similar set of textual properties using
+    RapidFuzz for fuzzy matching. Similarity scores are normalized to a value between 0
+    and 1.
+    """
+
+    def compute_similarity(self, text_a: str, text_b: str) -> float:
+        # RapidFuzz's fuzz.WRatio returns a score from 0 to 100
+        # normalize the input strings before the comparison is done (processor=utils.default_process)
+        # e.g., lowercase the text, strip whitespace, and remove punctuation
+        # normalize the score to the 0..1 range
+        return (
+            rapidfuzz.fuzz.WRatio(text_a, text_b, processor=utils.default_process)
+            / 100.0
+        )
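The new FuzzyMatchResolver is used the same way as the SpaCySemanticMatchResolver example in the docstring above; a usage sketch follows (the connection details are placeholders mirroring that example, not values from this commit):

    from neo4j import GraphDatabase
    from neo4j_graphrag.experimental.components.resolver import FuzzyMatchResolver

    # Placeholder connection details, mirroring the docstring example above.
    URI = "neo4j://localhost:7687"
    AUTH = ("neo4j", "password")
    DATABASE = "neo4j"

    driver = GraphDatabase.driver(URI, auth=AUTH)
    resolver = FuzzyMatchResolver(driver=driver, neo4j_database=DATABASE)
    await resolver.run()  # no expected parameters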

0 commit comments