Skip to content

Commit 5ad7160

Browse files
Handle caching for spaCy embeddings
1 parent 3d0acbe commit 5ad7160

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

src/neo4j_graphrag/experimental/components/resolver.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,15 +334,21 @@ def __init__(
334334
neo4j_database,
335335
)
336336
self.nlp = self._load_or_download_spacy_model(spacy_model)
337+
self.embedding_cache: dict[str, np.ndarray] = {}
337338

338339
def compute_similarity(self, text_a: str, text_b: str) -> float:
339-
emb1 = self.nlp(text_a).vector
340-
emb2 = self.nlp(text_b).vector
340+
emb1 = self._get_embedding(text_a)
341+
emb2 = self._get_embedding(text_b)
341342
sim = self._cosine_similarity(
342343
np.asarray(emb1, dtype=np.float64), np.asarray(emb2, dtype=np.float64)
343344
)
344345
return sim
345346

347+
def _get_embedding(self, text: str) -> np.ndarray:
348+
if text not in self.embedding_cache:
349+
self.embedding_cache[text] = self.nlp(text).vector
350+
return self.embedding_cache[text]
351+
346352
@staticmethod
347353
def _cosine_similarity(
348354
vec1: NDArray[np.float64], vec2: NDArray[np.float64]

tests/unit/experimental/components/test_resolver.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from unittest.mock import MagicMock, call
15+
from unittest.mock import MagicMock, call, patch
1616

1717
import pytest
1818
from neo4j_graphrag.experimental.components.resolver import (
@@ -237,3 +237,52 @@ async def test_fuzzy_match_resolver_normalization(driver: MagicMock) -> None:
237237

238238
sim = resolver.compute_similarity(" ALICE ", "alice!")
239239
assert sim == 1
240+
241+
242+
@pytest.mark.asyncio
243+
async def test_spacy_resolver_caching(driver: MagicMock) -> None:
244+
driver.execute_query.side_effect = [
245+
(
246+
[
247+
neo4j.Record(
248+
{
249+
"lab": "Person",
250+
"labelCluster": [
251+
{"id": 1, "name": "Alice"},
252+
{"id": 2, "name": "Alice"},
253+
{"id": 3, "name": "Bob"},
254+
],
255+
}
256+
)
257+
],
258+
None,
259+
None,
260+
),
261+
(
262+
[neo4j.Record({"id(node)": 1})],
263+
None,
264+
None,
265+
),
266+
(
267+
[neo4j.Record({"id(node)": 3})],
268+
None,
269+
None,
270+
),
271+
]
272+
273+
resolver = SpaCySemanticMatchResolver(driver=driver)
274+
275+
# patch spaCy NLP call to track how often embeddings are computed
276+
with patch.object(resolver, "nlp", wraps=resolver.nlp) as mock_nlp:
277+
await resolver.run()
278+
279+
# "Alice" should be embedded only once, despite being used twice.
280+
# "Bob" should be embedded once.
281+
assert mock_nlp.call_count == 2, (
282+
f"Expected spaCy to embed each unique text once. Got {mock_nlp.call_count} "
283+
f"calls."
284+
)
285+
286+
# "Alice" and "Bob" are expected to be the only two distinct texts passed to spaCy.
287+
called_texts = {call.args[0] for call in mock_nlp.call_args_list}
288+
assert called_texts == {"Alice", "Bob"}

0 commit comments

Comments
 (0)