Skip to content

Commit ad7f91f

Browse files
authored
Add embedder for TextChunkEmbedder to SimpleKGPipeline (#166)
* Add embedder for TextChunkEmbedder to SimpleKGPipeline * E2E test * llm back to MagicMock * from_pdf = False in e2e test * Remove async client close in E2E test
1 parent 72964fe commit ad7f91f

File tree

5 files changed

+177
-5
lines changed

5 files changed

+177
-5
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ llm = OpenAILLM(
6464
kg_builder = SimpleKGPipeline(
6565
llm=llm,
6666
driver=driver,
67+
embedder=OpenAIEmbeddings(),
6768
file_path=file_path,
6869
entities=entities,
6970
relations=relations,

examples/pipeline/kg_builder_example.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import logging
1919

2020
import neo4j
21+
from neo4j_graphrag.embeddings import OpenAIEmbeddings
2122
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
2223
from neo4j_graphrag.llm.openai_llm import OpenAILLM
2324

@@ -44,10 +45,14 @@ async def main(neo4j_driver: neo4j.Driver) -> None:
4445
},
4546
)
4647

48+
# Use OpenAIEmbeddings as embedder
49+
embedder = OpenAIEmbeddings()
50+
4751
# Create an instance of the SimpleKGPipeline
4852
kg_builder_pdf = SimpleKGPipeline(
4953
llm=llm,
5054
driver=neo4j_driver,
55+
embedder=embedder,
5156
entities=entities,
5257
relations=relations,
5358
potential_schema=potential_schema,
@@ -64,6 +69,7 @@ async def main(neo4j_driver: neo4j.Driver) -> None:
6469
kg_builder_text = SimpleKGPipeline(
6570
llm=llm,
6671
driver=neo4j_driver,
72+
embedder=embedder,
6773
entities=entities,
6874
relations=relations,
6975
potential_schema=potential_schema,

src/neo4j_graphrag/experimental/pipeline/kg_builder.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import neo4j
2121
from pydantic import BaseModel, ConfigDict, Field
2222

23+
from neo4j_graphrag.embeddings import Embedder
24+
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
2325
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
2426
LLMEntityRelationExtractor,
2527
OnError,
@@ -47,6 +49,7 @@ class SimpleKGPipelineConfig(BaseModel):
4749
llm: LLMInterface
4850
driver: neo4j.Driver
4951
from_pdf: bool
52+
embedder: Embedder
5053
entities: list[SchemaEntity] = Field(default_factory=list)
5154
relations: list[SchemaRelation] = Field(default_factory=list)
5255
potential_schema: list[tuple[str, str, str]] = Field(default_factory=list)
@@ -68,6 +71,7 @@ class SimpleKGPipeline:
6871
Args:
6972
llm (LLMInterface): An instance of an LLM to use for entity and relation extraction.
7073
driver (neo4j.Driver): A Neo4j driver instance for database connection.
74+
embedder (Embedder): An instance of an embedder used to generate chunk embeddings from text chunks.
7175
entities (Optional[List[str]]): A list of entity labels as strings.
7276
relations (Optional[List[str]]): A list of relation labels as strings.
7377
potential_schema (Optional[List[tuple]]): A list of potential schema relationships.
@@ -79,12 +83,15 @@ class SimpleKGPipeline:
7983
kg_writer (Optional[Any]): A knowledge graph writer component. Defaults to Neo4jWriter().
8084
on_error (str): Error handling strategy. Defaults to "RAISE". Possible values: "RAISE" or "IGNORE".
8185
perform_entity_resolution (bool): Merge entities with same label and name. Default: True
86+
text_splitter (Optional[Any]): A text splitter component. Defaults to FixedSizeSplitter().
87+
prompt_template (str): A custom prompt template to use for extraction.
8288
"""
8389

8490
def __init__(
8591
self,
8692
llm: LLMInterface,
8793
driver: neo4j.Driver,
94+
embedder: Embedder,
8895
entities: Optional[List[str]] = None,
8996
relations: Optional[List[str]] = None,
9097
potential_schema: Optional[List[tuple[str, str, str]]] = None,
@@ -119,12 +126,14 @@ def __init__(
119126
text_splitter=text_splitter,
120127
on_error=on_error_enum,
121128
prompt_template=prompt_template,
129+
embedder=embedder,
122130
perform_entity_resolution=perform_entity_resolution,
123131
)
124132

125133
self.from_pdf = config.from_pdf
126134
self.llm = config.llm
127135
self.driver = config.driver
136+
self.embedder = config.embedder
128137
self.text_splitter = config.text_splitter or FixedSizeSplitter()
129138
self.on_error = config.on_error
130139
self.pdf_loader = config.pdf_loader if pdf_loader is not None else PdfLoader()
@@ -149,6 +158,7 @@ def _build_pipeline(self) -> Pipeline:
149158
),
150159
"extractor",
151160
)
161+
pipe.add_component(TextChunkEmbedder(embedder=self.embedder), "chunk_embedder")
152162
pipe.add_component(self.kg_writer, "writer")
153163

154164
if self.from_pdf:
@@ -178,9 +188,11 @@ def _build_pipeline(self) -> Pipeline:
178188
)
179189

180190
pipe.connect(
181-
"splitter",
182-
"extractor",
183-
input_config={"chunks": "splitter"},
191+
"splitter", "chunk_embedder", input_config={"text_chunks": "splitter"}
192+
)
193+
194+
pipe.connect(
195+
"chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"}
184196
)
185197

186198
# Connect extractor to writer
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import annotations
16+
17+
import os
18+
from unittest.mock import MagicMock
19+
20+
import neo4j
21+
import pytest
22+
from neo4j_graphrag.embeddings.base import Embedder
23+
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
24+
from neo4j_graphrag.llm import LLMInterface, LLMResponse
25+
26+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
27+
28+
29+
@pytest.fixture
30+
def llm() -> LLMInterface:
31+
llm = MagicMock(spec=LLMInterface)
32+
return llm
33+
34+
35+
@pytest.fixture
36+
def embedder() -> Embedder:
37+
embedder = MagicMock(spec=Embedder)
38+
return embedder
39+
40+
41+
@pytest.fixture
42+
def harry_potter_text() -> str:
43+
with open(os.path.join(BASE_DIR, "data/harry_potter.txt"), "r") as f:
44+
text = f.read()
45+
return text
46+
47+
48+
@pytest.mark.asyncio
49+
@pytest.mark.usefixtures("setup_neo4j_for_kg_construction")
50+
async def test_pipeline_builder_happy_path(
51+
harry_potter_text: str,
52+
llm: MagicMock,
53+
embedder: MagicMock,
54+
driver: neo4j.Driver,
55+
) -> None:
56+
"""When everything works as expected, extracted entities, relations and text
57+
chunks must be in the DB
58+
"""
59+
driver.execute_query("MATCH (n) DETACH DELETE n")
60+
embedder.embed_query.return_value = [1, 2, 3]
61+
llm.ainvoke.side_effect = [
62+
LLMResponse(
63+
content="""{
64+
"nodes": [
65+
{
66+
"id": "0",
67+
"label": "Person",
68+
"properties": {
69+
"name": "Harry Potter"
70+
}
71+
},
72+
{
73+
"id": "1",
74+
"label": "Person",
75+
"properties": {
76+
"name": "Alastor Mad-Eye Moody"
77+
}
78+
},
79+
{
80+
"id": "2",
81+
"label": "Organization",
82+
"properties": {
83+
"name": "The Order of the Phoenix"
84+
}
85+
}
86+
],
87+
"relationships": [
88+
{
89+
"type": "KNOWS",
90+
"start_node_id": "0",
91+
"end_node_id": "1"
92+
},
93+
{
94+
"type": "LED_BY",
95+
"start_node_id": "2",
96+
"end_node_id": "1"
97+
}
98+
]
99+
}"""
100+
),
101+
LLMResponse(content='{"nodes": [], "relationships": []}'),
102+
]
103+
104+
# Instantiate Entity and Relation objects
105+
entities = ["PERSON", "ORGANIZATION", "HORCRUX", "LOCATION"]
106+
relations = ["SITUATED_AT", "INTERACTS", "OWNS", "LED_BY"]
107+
potential_schema = [
108+
("PERSON", "SITUATED_AT", "LOCATION"),
109+
("PERSON", "INTERACTS", "PERSON"),
110+
("PERSON", "OWNS", "HORCRUX"),
111+
("ORGANIZATION", "LED_BY", "PERSON"),
112+
]
113+
114+
# Create an instance of the SimpleKGPipeline
115+
kg_builder_text = SimpleKGPipeline(
116+
llm=llm,
117+
driver=driver,
118+
embedder=embedder,
119+
entities=entities,
120+
relations=relations,
121+
potential_schema=potential_schema,
122+
from_pdf=False,
123+
on_error="RAISE",
124+
)
125+
126+
# Run the knowledge graph building process with text input
127+
text_input = "John Doe lives in New York City."
128+
await kg_builder_text.run_async(text=text_input)

tests/unit/experimental/pipeline/test_kg_builder.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import neo4j
1818
import pytest
19+
from neo4j_graphrag.embeddings import Embedder
1920
from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError
2021
from neo4j_graphrag.experimental.components.schema import SchemaEntity, SchemaRelation
2122
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
@@ -28,15 +29,18 @@
2829
async def test_knowledge_graph_builder_init_with_text() -> None:
2930
llm = MagicMock(spec=LLMInterface)
3031
driver = MagicMock(spec=neo4j.Driver)
32+
embedder = MagicMock(spec=Embedder)
3133

3234
kg_builder = SimpleKGPipeline(
3335
llm=llm,
3436
driver=driver,
37+
embedder=embedder,
3538
from_pdf=False,
3639
)
3740

3841
assert kg_builder.llm == llm
3942
assert kg_builder.driver == driver
43+
assert kg_builder.embedder == embedder
4044
assert kg_builder.from_pdf is False
4145
assert kg_builder.entities == []
4246
assert kg_builder.relations == []
@@ -60,10 +64,12 @@ async def test_knowledge_graph_builder_init_with_text() -> None:
6064
async def test_knowledge_graph_builder_init_with_file_path() -> None:
6165
llm = MagicMock(spec=LLMInterface)
6266
driver = MagicMock(spec=neo4j.Driver)
67+
embedder = MagicMock(spec=Embedder)
6368

6469
kg_builder = SimpleKGPipeline(
6570
llm=llm,
6671
driver=driver,
72+
embedder=embedder,
6773
from_pdf=True,
6874
)
6975

@@ -92,10 +98,12 @@ async def test_knowledge_graph_builder_init_with_file_path() -> None:
9298
async def test_knowledge_graph_builder_run_with_both_inputs() -> None:
9399
llm = MagicMock(spec=LLMInterface)
94100
driver = MagicMock(spec=neo4j.Driver)
101+
embedder = MagicMock(spec=Embedder)
95102

96103
kg_builder = SimpleKGPipeline(
97104
llm=llm,
98105
driver=driver,
106+
embedder=embedder,
99107
from_pdf=True,
100108
)
101109

@@ -114,11 +122,13 @@ async def test_knowledge_graph_builder_run_with_both_inputs() -> None:
114122
async def test_knowledge_graph_builder_run_with_no_inputs() -> None:
115123
llm = MagicMock(spec=LLMInterface)
116124
driver = MagicMock(spec=neo4j.Driver)
125+
embedder = MagicMock(spec=Embedder)
117126

118127
kg_builder = SimpleKGPipeline(
119128
llm=llm,
120129
driver=driver,
121-
from_pdf=True, # or False
130+
embedder=embedder,
131+
from_pdf=True,
122132
)
123133

124134
with pytest.raises(PipelineDefinitionError) as exc_info:
@@ -133,10 +143,12 @@ async def test_knowledge_graph_builder_run_with_no_inputs() -> None:
133143
async def test_knowledge_graph_builder_document_info_with_file() -> None:
134144
llm = MagicMock(spec=LLMInterface)
135145
driver = MagicMock(spec=neo4j.Driver)
146+
embedder = MagicMock(spec=Embedder)
136147

137148
kg_builder = SimpleKGPipeline(
138149
llm=llm,
139150
driver=driver,
151+
embedder=embedder,
140152
from_pdf=True,
141153
)
142154

@@ -159,10 +171,12 @@ async def test_knowledge_graph_builder_document_info_with_file() -> None:
159171
async def test_knowledge_graph_builder_document_info_with_text() -> None:
160172
llm = MagicMock(spec=LLMInterface)
161173
driver = MagicMock(spec=neo4j.Driver)
174+
embedder = MagicMock(spec=Embedder)
162175

163176
kg_builder = SimpleKGPipeline(
164177
llm=llm,
165178
driver=driver,
179+
embedder=embedder,
166180
from_pdf=False,
167181
)
168182

@@ -184,6 +198,7 @@ async def test_knowledge_graph_builder_document_info_with_text() -> None:
184198
async def test_knowledge_graph_builder_with_entities_and_file() -> None:
185199
llm = MagicMock(spec=LLMInterface)
186200
driver = MagicMock(spec=neo4j.Driver)
201+
embedder = MagicMock(spec=Embedder)
187202

188203
entities = ["Document", "Section"]
189204
relations = ["CONTAINS"]
@@ -192,6 +207,7 @@ async def test_knowledge_graph_builder_with_entities_and_file() -> None:
192207
kg_builder = SimpleKGPipeline(
193208
llm=llm,
194209
driver=driver,
210+
embedder=embedder,
195211
entities=entities,
196212
relations=relations,
197213
potential_schema=potential_schema,
@@ -221,10 +237,12 @@ async def test_knowledge_graph_builder_with_entities_and_file() -> None:
221237
def test_simple_kg_pipeline_on_error_conversion() -> None:
222238
llm = MagicMock(spec=LLMInterface)
223239
driver = MagicMock(spec=neo4j.Driver)
240+
embedder = MagicMock(spec=Embedder)
224241

225242
kg_builder = SimpleKGPipeline(
226243
llm=llm,
227244
driver=driver,
245+
embedder=embedder,
228246
on_error="RAISE",
229247
)
230248

@@ -234,11 +252,13 @@ def test_simple_kg_pipeline_on_error_conversion() -> None:
234252
def test_simple_kg_pipeline_on_error_invalid_value() -> None:
235253
llm = MagicMock(spec=LLMInterface)
236254
driver = MagicMock(spec=neo4j.Driver)
255+
embedder = MagicMock(spec=Embedder)
237256

238257
with pytest.raises(PipelineDefinitionError) as exc_info:
239258
SimpleKGPipeline(
240259
llm=llm,
241260
driver=driver,
261+
embedder=embedder,
242262
on_error="IGNORE",
243263
)
244264

@@ -248,9 +268,14 @@ def test_simple_kg_pipeline_on_error_invalid_value() -> None:
248268
def test_simple_kg_pipeline_no_entity_resolution() -> None:
249269
llm = MagicMock(spec=LLMInterface)
250270
driver = MagicMock(spec=neo4j.Driver)
271+
embedder = MagicMock(spec=Embedder)
251272

252273
kg_builder = SimpleKGPipeline(
253-
llm=llm, driver=driver, on_error="CONTINUE", perform_entity_resolution=False
274+
llm=llm,
275+
driver=driver,
276+
embedder=embedder,
277+
on_error="CONTINUE",
278+
perform_entity_resolution=False,
254279
)
255280

256281
assert "resolver" not in kg_builder.pipeline

0 commit comments

Comments
 (0)