Skip to content

Commit 9b905dc

Browse files
Fix mypy errors
1 parent e4dc3e3 commit 9b905dc

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

src/neo4j_graphrag/experimental/components/resolver.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818

1919
import neo4j
2020
import numpy as np
21+
from numpy.typing import NDArray
22+
import spacy
23+
from spacy.language import Language
24+
from spacy.cli.download import download as spacy_download
2125

2226
from neo4j_graphrag.experimental.components.types import ResolutionStats
2327
from neo4j_graphrag.experimental.pipeline import Component
2428
from neo4j_graphrag.utils import driver_config
2529

26-
import spacy
27-
2830

2931
class EntityResolver(Component, abc.ABC):
3032
"""Entity resolution base class
@@ -241,7 +243,9 @@ async def run(self) -> ResolutionStats:
241243
# identify pairs to merge
242244
pairs_to_merge = []
243245
for (id1, emb1), (id2, emb2) in combinations(node_embeddings.items(), 2):
244-
sim = self._cosine_similarity(emb1, emb2)
246+
sim = self._cosine_similarity(
247+
np.asarray(emb1, dtype=np.float64),
248+
np.asarray(emb2, dtype=np.float64))
245249
if sim >= self.similarity_threshold:
246250
pairs_to_merge.append({id1, id2})
247251

@@ -273,9 +277,9 @@ async def run(self) -> ResolutionStats:
273277
)
274278

275279
@staticmethod
276-
def _consolidate_sets(pairs: List[set]) -> List[set]:
280+
def _consolidate_sets(pairs: List[set[str]]) -> List[set[str]]:
277281
"""Consolidate overlapping sets of node pairs into unique sets."""
278-
consolidated = []
282+
consolidated: List[set[str]] = []
279283
for pair in pairs:
280284
merged = False
281285
for cons in consolidated:
@@ -289,7 +293,10 @@ def _consolidate_sets(pairs: List[set]) -> List[set]:
289293
return consolidated
290294

291295
@staticmethod
292-
def _cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
296+
def _cosine_similarity(
297+
vec1: NDArray[np.float64],
298+
vec2: NDArray[np.float64]
299+
) -> float:
293300
"""Calculate cosine similarity between two embedding vectors."""
294301
dot_product = np.dot(vec1, vec2)
295302
norm1 = np.linalg.norm(vec1)
@@ -299,7 +306,7 @@ def _cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
299306
return float(dot_product / (norm1 * norm2))
300307

301308
@staticmethod
302-
def _load_or_download_spacy_model(model_name: str):
309+
def _load_or_download_spacy_model(model_name: str) -> Language:
303310
"""
304311
Attempt to load the specified spaCy model by name.
305312
If not installed, automatically download and then load it.
@@ -311,7 +318,7 @@ def _load_or_download_spacy_model(model_name: str):
311318
# so you may want to be broader or narrower with handling logic:
312319
if "doesn't seem to be a Python package or a valid path" in str(e):
313320
print(f"Model '{model_name}' not found. Downloading...")
314-
spacy.cli.download(model_name)
321+
spacy_download(model_name)
315322
return spacy.load(model_name)
316323
else:
317324
raise e

0 commit comments

Comments
 (0)