Skip to content

Commit 997018e

Browse files
Handle cases where the spaCy model is not yet downloaded
1 parent cd99060 commit 997018e

File tree

1 file changed

+33
-16
lines changed

1 file changed

+33
-16
lines changed

src/neo4j_graphrag/experimental/components/resolver.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ class EntityResolver(Component, abc.ABC):
3535
"""
3636

3737
def __init__(
38-
self,
39-
driver: neo4j.Driver,
40-
filter_query: Optional[str] = None,
38+
self,
39+
driver: neo4j.Driver,
40+
filter_query: Optional[str] = None,
4141
) -> None:
4242
self.driver = driver_config.override_user_agent(driver)
4343
self.filter_query = filter_query
@@ -74,11 +74,11 @@ class SinglePropertyExactMatchResolver(EntityResolver):
7474
"""
7575

7676
def __init__(
77-
self,
78-
driver: neo4j.Driver,
79-
filter_query: Optional[str] = None,
80-
resolve_property: str = "name",
81-
neo4j_database: Optional[str] = None,
77+
self,
78+
driver: neo4j.Driver,
79+
filter_query: Optional[str] = None,
80+
resolve_property: str = "name",
81+
neo4j_database: Optional[str] = None,
8282
) -> None:
8383
super().__init__(driver, filter_query)
8484
self.resolve_property = resolve_property
@@ -174,19 +174,19 @@ class SpaCySemanticMatchResolver(EntityResolver):
174174
"""
175175

176176
def __init__(
177-
self,
178-
driver: neo4j.Driver,
179-
filter_query: Optional[str] = None,
180-
resolve_properties: Optional[List[str]] = None,
181-
similarity_threshold: float = 0.8,
182-
spacy_model: str = "en_core_web_lg",
183-
neo4j_database: Optional[str] = None,
177+
self,
178+
driver: neo4j.Driver,
179+
filter_query: Optional[str] = None,
180+
resolve_properties: Optional[List[str]] = None,
181+
similarity_threshold: float = 0.8,
182+
spacy_model: str = "en_core_web_lg",
183+
neo4j_database: Optional[str] = None,
184184
) -> None:
185185
super().__init__(driver, filter_query)
186186
self.resolve_properties = resolve_properties or ["name"]
187187
self.similarity_threshold = similarity_threshold
188188
self.neo4j_database = neo4j_database
189-
self.nlp = spacy.load(spacy_model)
189+
self.nlp = self._load_or_download_spacy_model(spacy_model)
190190

191191
async def run(self) -> ResolutionStats:
192192
"""Resolve entities based on the following rules:
@@ -296,3 +296,20 @@ def _cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
296296
return 0.0
297297
return float(dot_product / (norm1 * norm2))
298298

299+
@staticmethod
def _load_or_download_spacy_model(model_name: str) -> "spacy.language.Language":
    """Load a spaCy pipeline by name, downloading it first if it is missing.

    Args:
        model_name: Name of the spaCy model package to load, for example
            ``"en_core_web_lg"``.

    Returns:
        The pipeline object returned by ``spacy.load``.

    Raises:
        OSError: If loading fails for a reason other than the model not being
            installed, or if loading still fails after the download attempt.
    """
    try:
        return spacy.load(model_name)
    except OSError as e:
        message = str(e)
        # spaCy tags the "can't find model" failure with error code E050.
        # The human-readable wording differs between spaCy versions, so match
        # on the code first and keep the original wording as a fallback.
        model_missing = (
            "[E050]" in message
            or "doesn't seem to be a Python package or a valid path" in message
        )
        if not model_missing:
            # Unrelated OSError (e.g. a corrupt model directory): propagate
            # it unchanged, with its original traceback.
            raise

        # Local imports keep this block self-contained. `download` is imported
        # explicitly so the call does not rely on `spacy.cli` having been
        # bound as an attribute of the top-level package.
        import logging

        from spacy.cli import download as spacy_download

        # Library code should report through logging rather than print so the
        # host application decides where (and whether) the message appears.
        logging.getLogger(__name__).info(
            "Model '%s' not found. Downloading...", model_name
        )
        spacy_download(model_name)
        return spacy.load(model_name)

0 commit comments

Comments
 (0)