18
18
19
19
import neo4j
20
20
import numpy as np
21
+ from numpy .typing import NDArray
22
+ import spacy
23
+ from spacy .language import Language
24
+ from spacy .cli .download import download as spacy_download
21
25
22
26
from neo4j_graphrag .experimental .components .types import ResolutionStats
23
27
from neo4j_graphrag .experimental .pipeline import Component
24
28
from neo4j_graphrag .utils import driver_config
25
29
26
- import spacy
27
-
28
30
29
31
class EntityResolver (Component , abc .ABC ):
30
32
"""Entity resolution base class
@@ -241,7 +243,9 @@ async def run(self) -> ResolutionStats:
241
243
# identify pairs to merge
242
244
pairs_to_merge = []
243
245
for (id1 , emb1 ), (id2 , emb2 ) in combinations (node_embeddings .items (), 2 ):
244
- sim = self ._cosine_similarity (emb1 , emb2 )
246
+ sim = self ._cosine_similarity (
247
+ np .asarray (emb1 , dtype = np .float64 ),
248
+ np .asarray (emb2 , dtype = np .float64 ))
245
249
if sim >= self .similarity_threshold :
246
250
pairs_to_merge .append ({id1 , id2 })
247
251
@@ -273,9 +277,9 @@ async def run(self) -> ResolutionStats:
273
277
)
274
278
275
279
@staticmethod
276
- def _consolidate_sets (pairs : List [set ] ) -> List [set ]:
280
+ def _consolidate_sets (pairs : List [set [ str ]] ) -> List [set [ str ] ]:
277
281
"""Consolidate overlapping sets of node pairs into unique sets."""
278
- consolidated = []
282
+ consolidated : List [ set [ str ]] = []
279
283
for pair in pairs :
280
284
merged = False
281
285
for cons in consolidated :
@@ -289,7 +293,10 @@ def _consolidate_sets(pairs: List[set]) -> List[set]:
289
293
return consolidated
290
294
291
295
@staticmethod
292
- def _cosine_similarity (vec1 : np .ndarray , vec2 : np .ndarray ) -> float :
296
+ def _cosine_similarity (
297
+ vec1 : NDArray [np .float64 ],
298
+ vec2 : NDArray [np .float64 ]
299
+ ) -> float :
293
300
"""Calculate cosine similarity between two embedding vectors."""
294
301
dot_product = np .dot (vec1 , vec2 )
295
302
norm1 = np .linalg .norm (vec1 )
@@ -299,7 +306,7 @@ def _cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
299
306
return float (dot_product / (norm1 * norm2 ))
300
307
301
308
@staticmethod
302
- def _load_or_download_spacy_model (model_name : str ):
309
+ def _load_or_download_spacy_model (model_name : str ) -> Language :
303
310
"""
304
311
Attempt to load the specified spaCy model by name.
305
312
If not installed, automatically download and then load it.
@@ -311,7 +318,7 @@ def _load_or_download_spacy_model(model_name: str):
311
318
# so you may want to be broader or narrower with handling logic:
312
319
if "doesn't seem to be a Python package or a valid path" in str (e ):
313
320
print (f"Model '{ model_name } ' not found. Downloading..." )
314
- spacy . cli . download (model_name )
321
+ spacy_download (model_name )
315
322
return spacy .load (model_name )
316
323
else :
317
324
raise e
0 commit comments