@@ -35,9 +35,9 @@ class EntityResolver(Component, abc.ABC):
35
35
"""
36
36
37
37
def __init__ (
38
- self ,
39
- driver : neo4j .Driver ,
40
- filter_query : Optional [str ] = None ,
38
+ self ,
39
+ driver : neo4j .Driver ,
40
+ filter_query : Optional [str ] = None ,
41
41
) -> None :
42
42
self .driver = driver_config .override_user_agent (driver )
43
43
self .filter_query = filter_query
@@ -74,11 +74,11 @@ class SinglePropertyExactMatchResolver(EntityResolver):
74
74
"""
75
75
76
76
def __init__ (
77
- self ,
78
- driver : neo4j .Driver ,
79
- filter_query : Optional [str ] = None ,
80
- resolve_property : str = "name" ,
81
- neo4j_database : Optional [str ] = None ,
77
+ self ,
78
+ driver : neo4j .Driver ,
79
+ filter_query : Optional [str ] = None ,
80
+ resolve_property : str = "name" ,
81
+ neo4j_database : Optional [str ] = None ,
82
82
) -> None :
83
83
super ().__init__ (driver , filter_query )
84
84
self .resolve_property = resolve_property
@@ -174,19 +174,19 @@ class SpaCySemanticMatchResolver(EntityResolver):
174
174
"""
175
175
176
176
def __init__ (
177
- self ,
178
- driver : neo4j .Driver ,
179
- filter_query : Optional [str ] = None ,
180
- resolve_properties : Optional [List [str ]] = None ,
181
- similarity_threshold : float = 0.8 ,
182
- spacy_model : str = "en_core_web_lg" ,
183
- neo4j_database : Optional [str ] = None ,
177
+ self ,
178
+ driver : neo4j .Driver ,
179
+ filter_query : Optional [str ] = None ,
180
+ resolve_properties : Optional [List [str ]] = None ,
181
+ similarity_threshold : float = 0.8 ,
182
+ spacy_model : str = "en_core_web_lg" ,
183
+ neo4j_database : Optional [str ] = None ,
184
184
) -> None :
185
185
super ().__init__ (driver , filter_query )
186
186
self .resolve_properties = resolve_properties or ["name" ]
187
187
self .similarity_threshold = similarity_threshold
188
188
self .neo4j_database = neo4j_database
189
- self .nlp = spacy . load (spacy_model )
189
+ self .nlp = self . _load_or_download_spacy_model (spacy_model )
190
190
191
191
async def run (self ) -> ResolutionStats :
192
192
"""Resolve entities based on the following rules:
@@ -296,3 +296,20 @@ def _cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
296
296
return 0.0
297
297
return float (dot_product / (norm1 * norm2 ))
298
298
299
+ @staticmethod
300
+ def _load_or_download_spacy_model (model_name : str ):
301
+ """
302
+ Attempt to load the specified spaCy model by name.
303
+ If not installed, automatically download and then load it.
304
+ """
305
+ try :
306
+ return spacy .load (model_name )
307
+ except OSError as e :
308
+ # The exact error message can differ slightly depending on spaCy version,
309
+ # so you may want to be broader or narrower with handling logic:
310
+ if 'doesn\' t seem to be a Python package or a valid path' in str (e ):
311
+ print (f"Model '{ model_name } ' not found. Downloading..." )
312
+ spacy .cli .download (model_name )
313
+ return spacy .load (model_name )
314
+ else :
315
+ raise e
0 commit comments