Skip to content

Commit be56247

Browse files
authored
Add SimpleKGPipeline class (#165)
* Add KnowledgeGraphBuilder class * Simplify schema and relations * Add kg_writer to pydantic data validation * Refactor pipe_input to be passed to run and async run of KnowledgeGraphBuilder * Fixed mypy errors * Refactor KnowledgeGraphBuilder class * Fix build_pipeline * Removed SimpleKGPipeline from init * Update README * Handle event loop creation for python 3.9 * Fix typo in SchemaProperty docstring * Use PipelineDefinitionError for SimpleKGPipeline * PipelineDefinitionError in tests * Update kg_builder.run() in README * Allow users to pass strings instead of enums * Pass OnError as string in example * Update test to async * Fixed OnError mypy errors * Fixed test case for IGNORE OnError * Update SimpleKGPipeline example in README
1 parent 84397f0 commit be56247

File tree

6 files changed

+602
-2
lines changed

6 files changed

+602
-2
lines changed

README.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,47 @@ Follow installation instructions [here](https://pygraphviz.github.io/documentati
3636

3737
## Examples
3838

39+
### Knowledge graph construction
40+
41+
```python
42+
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
43+
from neo4j_graphrag.llm.openai_llm import OpenAILLM
44+
45+
# Instantiate Entity and Relation objects
46+
entities = ["PERSON", "ORGANIZATION", "LOCATION"]
47+
relations = ["SITUATED_AT", "INTERACTS", "LED_BY"]
48+
potential_schema = [
49+
("PERSON", "SITUATED_AT", "LOCATION"),
50+
("PERSON", "INTERACTS", "PERSON"),
51+
("ORGANIZATION", "LED_BY", "PERSON"),
52+
]
53+
54+
# Instantiate the LLM
55+
llm = OpenAILLM(
56+
model_name="gpt-4o",
57+
model_params={
58+
"max_tokens": 2000,
59+
"response_format": {"type": "json_object"},
60+
},
61+
)
62+
63+
# Create an instance of the SimpleKGPipeline
64+
kg_builder = SimpleKGPipeline(
65+
llm=llm,
66+
driver=driver,
67+
file_path=file_path,
68+
entities=entities,
69+
relations=relations,
70+
)
71+
72+
await kg_builder.run_async(text="""
73+
Albert Einstein was a German physicist born in 1879 who wrote many groundbreaking
74+
papers especially about general relativity and quantum mechanics.
75+
""")
76+
```
77+
78+
79+
3980
### Creating a vector index
4081

4182
When creating a vector index, make sure you match the number of dimensions in the index with the number of dimensions the embeddings have.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import annotations
16+
17+
import asyncio
18+
import logging
19+
20+
import neo4j
21+
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
22+
from neo4j_graphrag.llm.openai_llm import OpenAILLM
23+
24+
# Configure the root logger so INFO-level (and above) messages are shown
# while the example runs.
logging.basicConfig(level=logging.INFO)
25+
26+
27+
async def main(neo4j_driver: neo4j.Driver) -> None:
    """Build a knowledge graph twice: once from a PDF file, once from raw text.

    Args:
        neo4j_driver: Open Neo4j driver handed to both pipelines.
    """
    # Entity and relation labels are given as plain strings
    entities = ["PERSON", "ORGANIZATION", "HORCRUX", "LOCATION"]
    relations = ["SITUATED_AT", "INTERACTS", "OWNS", "LED_BY"]
    potential_schema = [
        ("PERSON", "SITUATED_AT", "LOCATION"),
        ("PERSON", "INTERACTS", "PERSON"),
        ("PERSON", "OWNS", "HORCRUX"),
        ("ORGANIZATION", "LED_BY", "PERSON"),
    ]

    # Instantiate the LLM
    llm = OpenAILLM(
        model_name="gpt-4o",
        model_params={
            "max_tokens": 2000,
            "response_format": {"type": "json_object"},
        },
    )

    # Both pipelines use on_error="RAISE", so any run may raise; the
    # try/finally guarantees the LLM's async client is closed regardless.
    try:
        # Create an instance of the SimpleKGPipeline for PDF input
        kg_builder_pdf = SimpleKGPipeline(
            llm=llm,
            driver=neo4j_driver,
            entities=entities,
            relations=relations,
            potential_schema=potential_schema,
            from_pdf=True,
            on_error="RAISE",
        )

        # Run the knowledge graph building process asynchronously
        pdf_file_path = (
            "examples/pipeline/Harry Potter and the Death Hallows Summary.pdf"
        )
        pdf_result = await kg_builder_pdf.run_async(file_path=pdf_file_path)
        print(f"PDF Processing Result: {pdf_result}")

        # Create an instance of the SimpleKGPipeline for text input
        kg_builder_text = SimpleKGPipeline(
            llm=llm,
            driver=neo4j_driver,
            entities=entities,
            relations=relations,
            potential_schema=potential_schema,
            from_pdf=False,
            on_error="RAISE",
        )

        # Run the knowledge graph building process with text input
        text_input = "John Doe lives in New York City."
        text_result = await kg_builder_text.run_async(text=text_input)
        print(f"Text Processing Result: {text_result}")
    finally:
        await llm.async_client.close()
80+
81+
82+
if __name__ == "__main__":
    # Connection settings for the Neo4j instance used by this example.
    uri = "bolt://localhost:7687"
    credentials = ("neo4j", "password")

    # The context manager closes the driver once the async run has finished.
    with neo4j.GraphDatabase.driver(uri, auth=credentials) as driver:
        asyncio.run(main(driver))

src/neo4j_graphrag/experimental/components/pdf_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from fsspec.implementations.local import LocalFileSystem
2626

2727
from neo4j_graphrag.exceptions import PdfLoaderError
28-
from neo4j_graphrag.experimental.pipeline import Component, DataModel
28+
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
2929

3030

3131
class DocumentInfo(DataModel):

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pydantic import BaseModel, ValidationError, model_validator, validate_call
2020

2121
from neo4j_graphrag.exceptions import SchemaValidationError
22-
from neo4j_graphrag.experimental.pipeline import Component, DataModel
22+
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
2323

2424

2525
class SchemaProperty(BaseModel):
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from __future__ import annotations
17+
18+
from typing import Any, List, Optional, Union
19+
20+
import neo4j
21+
from pydantic import BaseModel, ConfigDict, Field
22+
23+
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
24+
LLMEntityRelationExtractor,
25+
OnError,
26+
)
27+
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
28+
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
29+
from neo4j_graphrag.experimental.components.schema import (
30+
SchemaBuilder,
31+
SchemaEntity,
32+
SchemaRelation,
33+
)
34+
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
35+
FixedSizeSplitter,
36+
)
37+
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
38+
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult
39+
from neo4j_graphrag.generation.prompts import ERExtractionTemplate
40+
from neo4j_graphrag.llm.base import LLMInterface
41+
42+
43+
class SimpleKGPipelineConfig(BaseModel):
    """Pydantic model that validates the settings passed to SimpleKGPipeline."""

    # `LLMInterface` and `neo4j.Driver` are not pydantic types, so arbitrary
    # types must be allowed for them to be used as field annotations.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Required settings.
    llm: LLMInterface
    driver: neo4j.Driver
    from_pdf: bool

    # Schema description; each defaults to a fresh empty list.
    entities: list[SchemaEntity] = Field(default_factory=list)
    relations: list[SchemaRelation] = Field(default_factory=list)
    potential_schema: list[tuple[str, str, str]] = Field(default_factory=list)

    # Optional component overrides; None means "not provided".
    pdf_loader: Optional[Any] = None
    kg_writer: Optional[Any] = None
    text_splitter: Optional[Any] = None

    # Extraction behaviour.
    on_error: OnError = OnError.RAISE
    prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
57+
58+
59+
class SimpleKGPipeline:
    """
    A class to simplify the process of building a knowledge graph from text documents.
    It abstracts away the complexity of setting up the pipeline and its components.

    Args:
        llm (LLMInterface): An instance of an LLM to use for entity and relation extraction.
        driver (neo4j.Driver): A Neo4j driver instance for database connection.
        entities (Optional[List[str]]): A list of entity labels as strings.
        relations (Optional[List[str]]): A list of relation labels as strings.
        potential_schema (Optional[List[tuple]]): A list of potential schema relationships,
            each a triple of labels such as ("PERSON", "SITUATED_AT", "LOCATION").
        from_pdf (bool): Determines whether to include the PdfLoader in the pipeline.
            If True, expects `file_path` input in `run` methods.
            If False, expects `text` input in `run` methods.
        text_splitter (Optional[Any]): A text splitter component. Defaults to FixedSizeSplitter().
        pdf_loader (Optional[Any]): A PDF loader component. Defaults to PdfLoader().
        kg_writer (Optional[Any]): A knowledge graph writer component. Defaults to Neo4jWriter(driver).
        on_error (str): Error handling strategy, given as the value of an OnError member.
            Defaults to "RAISE".
        prompt_template (Union[ERExtractionTemplate, str]): Prompt template passed to the
            entity/relation extractor. Defaults to ERExtractionTemplate().

    Raises:
        PipelineDefinitionError: If `on_error` is not a valid OnError value.
    """

    def __init__(
        self,
        llm: LLMInterface,
        driver: neo4j.Driver,
        entities: Optional[List[str]] = None,
        relations: Optional[List[str]] = None,
        potential_schema: Optional[List[tuple[str, str, str]]] = None,
        from_pdf: bool = True,
        text_splitter: Optional[Any] = None,
        pdf_loader: Optional[Any] = None,
        kg_writer: Optional[Any] = None,
        on_error: str = "RAISE",
        prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(),
    ):
        # Users pass plain label strings; wrap them in the schema models.
        self.entities = [SchemaEntity(label=label) for label in entities or []]
        self.relations = [SchemaRelation(label=label) for label in relations or []]
        self.potential_schema = potential_schema if potential_schema is not None else []

        try:
            on_error_enum = OnError(on_error)
        except ValueError as err:
            # Derive the accepted values from the enum itself so the message
            # cannot drift from the actual OnError members.
            expected = " or ".join(f"'{member.value}'" for member in OnError)
            raise PipelineDefinitionError(
                f"Invalid value for on_error: {on_error}. Expected {expected}."
            ) from err

        # Run every setting through the pydantic model for validation.
        config = SimpleKGPipelineConfig(
            llm=llm,
            driver=driver,
            entities=self.entities,
            relations=self.relations,
            potential_schema=self.potential_schema,
            from_pdf=from_pdf,
            pdf_loader=pdf_loader,
            kg_writer=kg_writer,
            text_splitter=text_splitter,
            on_error=on_error_enum,
            prompt_template=prompt_template,
        )

        self.from_pdf = config.from_pdf
        self.llm = config.llm
        self.driver = config.driver
        self.on_error = config.on_error
        self.prompt_template = config.prompt_template

        # Fall back to the default components only when none was supplied.
        self.text_splitter = (
            config.text_splitter
            if config.text_splitter is not None
            else FixedSizeSplitter()
        )
        self.pdf_loader = (
            config.pdf_loader if config.pdf_loader is not None else PdfLoader()
        )
        self.kg_writer = (
            config.kg_writer if config.kg_writer is not None else Neo4jWriter(driver)
        )

        self.pipeline = self._build_pipeline()

    def _build_pipeline(self) -> Pipeline:
        """Assemble the pipeline components and wire their inputs together.

        Returns:
            Pipeline: The configured pipeline, with the PDF loader included
            only when `from_pdf` is True.
        """
        pipe = Pipeline()

        pipe.add_component(self.text_splitter, "splitter")
        pipe.add_component(SchemaBuilder(), "schema")
        pipe.add_component(
            LLMEntityRelationExtractor(
                llm=self.llm,
                on_error=self.on_error,
                prompt_template=self.prompt_template,
            ),
            "extractor",
        )
        pipe.add_component(self.kg_writer, "writer")

        if self.from_pdf:
            pipe.add_component(self.pdf_loader, "pdf_loader")

            # The loader's extracted text is what the splitter chunks.
            pipe.connect(
                "pdf_loader",
                "splitter",
                input_config={"text": "pdf_loader.text"},
            )

            # With a PDF source, the extractor also receives the loader's
            # document info alongside the schema.
            pipe.connect(
                "schema",
                "extractor",
                input_config={
                    "schema": "schema",
                    "document_info": "pdf_loader.document_info",
                },
            )
        else:
            pipe.connect(
                "schema",
                "extractor",
                input_config={
                    "schema": "schema",
                },
            )

        pipe.connect(
            "splitter",
            "extractor",
            input_config={"chunks": "splitter"},
        )

        # Connect extractor to writer
        pipe.connect(
            "extractor",
            "writer",
            input_config={"graph": "extractor"},
        )

        return pipe

    async def run_async(
        self, file_path: Optional[str] = None, text: Optional[str] = None
    ) -> PipelineResult:
        """
        Asynchronously runs the knowledge graph building process.

        Args:
            file_path (Optional[str]): The path to the PDF file to process. Required if `from_pdf` is True.
            text (Optional[str]): The text content to process. Required if `from_pdf` is False.

        Returns:
            PipelineResult: The result of the pipeline execution.

        Raises:
            PipelineDefinitionError: If the supplied arguments do not match the
                `from_pdf` setting (see `_prepare_inputs`).
        """
        pipe_inputs = self._prepare_inputs(file_path=file_path, text=text)
        return await self.pipeline.run(pipe_inputs)

    def _prepare_inputs(
        self, file_path: Optional[str], text: Optional[str]
    ) -> dict[str, Any]:
        """Validate the run arguments and build the pipeline input mapping.

        Exactly one of `file_path` / `text` must be given, and it must be the
        one matching `from_pdf`.

        Args:
            file_path (Optional[str]): Path to the PDF; required when `from_pdf` is True.
            text (Optional[str]): Raw text; required when `from_pdf` is False.

        Returns:
            dict[str, Any]: Inputs keyed by component name ("schema", plus
            either "pdf_loader" or "splitter").

        Raises:
            PipelineDefinitionError: If the expected argument is missing or the
                other one is supplied as well.
        """
        if self.from_pdf:
            if file_path is None or text is not None:
                raise PipelineDefinitionError(
                    "Expected 'file_path' argument when 'from_pdf' is True."
                )
        else:
            if text is None or file_path is not None:
                raise PipelineDefinitionError(
                    "Expected 'text' argument when 'from_pdf' is False."
                )

        pipe_inputs: dict[str, Any] = {
            "schema": {
                "entities": self.entities,
                "relations": self.relations,
                "potential_schema": self.potential_schema,
            },
        }

        # Feed the entry component: the PDF loader or the splitter directly.
        if self.from_pdf:
            pipe_inputs["pdf_loader"] = {"filepath": file_path}
        else:
            pipe_inputs["splitter"] = {"text": text}

        return pipe_inputs

0 commit comments

Comments
 (0)