|
| 1 | +# Copyright (c) "Neo4j" |
| 2 | +# Neo4j Sweden AB [https://neo4j.com] |
| 3 | +# # |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# # |
| 8 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# # |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +from __future__ import annotations |
| 17 | + |
| 18 | +from typing import Any, List, Optional, Union |
| 19 | + |
| 20 | +import neo4j |
| 21 | +from pydantic import BaseModel, ConfigDict, Field |
| 22 | + |
| 23 | +from neo4j_graphrag.experimental.components.entity_relation_extractor import ( |
| 24 | + LLMEntityRelationExtractor, |
| 25 | + OnError, |
| 26 | +) |
| 27 | +from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter |
| 28 | +from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader |
| 29 | +from neo4j_graphrag.experimental.components.schema import ( |
| 30 | + SchemaBuilder, |
| 31 | + SchemaEntity, |
| 32 | + SchemaRelation, |
| 33 | +) |
| 34 | +from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( |
| 35 | + FixedSizeSplitter, |
| 36 | +) |
| 37 | +from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError |
| 38 | +from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult |
| 39 | +from neo4j_graphrag.generation.prompts import ERExtractionTemplate |
| 40 | +from neo4j_graphrag.llm.base import LLMInterface |
| 41 | + |
| 42 | + |
class SimpleKGPipelineConfig(BaseModel):
    """Pydantic model gathering the settings used to build a SimpleKGPipeline.

    Required fields are ``llm``, ``driver`` and ``from_pdf``. The schema lists
    default to empty lists, the optional components default to ``None``,
    ``on_error`` defaults to ``OnError.RAISE`` and ``prompt_template`` defaults
    to an ``ERExtractionTemplate`` instance.
    """

    # Needed because several fields (e.g. the LLM and the Neo4j driver) are
    # annotated with classes that are not pydantic models.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # --- required settings ---
    llm: LLMInterface
    driver: neo4j.Driver
    from_pdf: bool

    # --- schema description; each list gets its own fresh empty default ---
    entities: List[SchemaEntity] = Field(default_factory=list)
    relations: List[SchemaRelation] = Field(default_factory=list)
    potential_schema: List[tuple[str, str, str]] = Field(default_factory=list)

    # --- optional components; None means "not provided by the caller" ---
    pdf_loader: Any = None
    kg_writer: Any = None
    text_splitter: Any = None

    # --- extraction behaviour ---
    on_error: OnError = OnError.RAISE
    prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
| 57 | + |
| 58 | + |
class SimpleKGPipeline:
    """
    A class to simplify the process of building a knowledge graph from text documents.
    It abstracts away the complexity of setting up the pipeline and its components.

    Args:
        llm (LLMInterface): An instance of an LLM to use for entity and relation extraction.
        driver (neo4j.Driver): A Neo4j driver instance for database connection.
        entities (Optional[List[str]]): A list of entity labels as strings.
            Each label is wrapped in a SchemaEntity. Defaults to an empty list.
        relations (Optional[List[str]]): A list of relation labels as strings.
            Each label is wrapped in a SchemaRelation. Defaults to an empty list.
        potential_schema (Optional[List[tuple]]): A list of potential schema relationships,
            given as 3-tuples of strings. Defaults to an empty list.
        from_pdf (bool): Determines whether to include the PdfLoader in the pipeline.
            If True, expects `file_path` input in `run_async`.
            If False, expects `text` input in `run_async`.
        text_splitter (Optional[Any]): A text splitter component. Defaults to FixedSizeSplitter().
        pdf_loader (Optional[Any]): A PDF loader component. Defaults to PdfLoader().
        kg_writer (Optional[Any]): A knowledge graph writer component. Defaults to Neo4jWriter().
        on_error (str): Error handling strategy, converted to an OnError member.
            Defaults to "RAISE".
        prompt_template (Union[ERExtractionTemplate, str]): Prompt handed to the
            LLMEntityRelationExtractor. Defaults to ERExtractionTemplate().

    Raises:
        PipelineDefinitionError: If `on_error` cannot be converted to an OnError member.
    """

    def __init__(
        self,
        llm: LLMInterface,
        driver: neo4j.Driver,
        entities: Optional[List[str]] = None,
        relations: Optional[List[str]] = None,
        potential_schema: Optional[List[tuple[str, str, str]]] = None,
        from_pdf: bool = True,
        text_splitter: Optional[Any] = None,
        pdf_loader: Optional[Any] = None,
        kg_writer: Optional[Any] = None,
        on_error: str = "RAISE",
        prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(),
    ):
        # Plain string labels are promoted to schema objects; None means "no entries".
        self.entities = [SchemaEntity(label=label) for label in entities or []]
        self.relations = [SchemaRelation(label=label) for label in relations or []]
        self.potential_schema = potential_schema if potential_schema is not None else []

        try:
            on_error_enum = OnError(on_error)
        except ValueError as err:
            # Chain the original ValueError so the traceback keeps the root cause.
            raise PipelineDefinitionError(
                f"Invalid value for on_error: {on_error}. Expected 'RAISE' or 'CONTINUE'."
            ) from err

        # Route every setting through the pydantic config model before use.
        config = SimpleKGPipelineConfig(
            llm=llm,
            driver=driver,
            entities=self.entities,
            relations=self.relations,
            potential_schema=self.potential_schema,
            from_pdf=from_pdf,
            pdf_loader=pdf_loader,
            kg_writer=kg_writer,
            text_splitter=text_splitter,
            on_error=on_error_enum,
            prompt_template=prompt_template,
        )

        self.from_pdf = config.from_pdf
        self.llm = config.llm
        self.driver = config.driver
        self.on_error = config.on_error
        self.prompt_template = config.prompt_template

        # Fall back to the default component only when the caller supplied nothing.
        # All three use an explicit `is not None` test so that a user-provided
        # component is never discarded merely because it evaluates as falsy.
        self.text_splitter = (
            config.text_splitter
            if config.text_splitter is not None
            else FixedSizeSplitter()
        )
        self.pdf_loader = (
            config.pdf_loader if config.pdf_loader is not None else PdfLoader()
        )
        self.kg_writer = (
            config.kg_writer if config.kg_writer is not None else Neo4jWriter(driver)
        )

        self.pipeline = self._build_pipeline()

    def _build_pipeline(self) -> Pipeline:
        """Assemble the Pipeline from the configured components.

        Components are registered as "splitter", "schema", "extractor" and
        "writer", plus "pdf_loader" when `from_pdf` is True.

        Returns:
            Pipeline: The pipeline with all components added and connected.
        """
        pipe = Pipeline()

        pipe.add_component(self.text_splitter, "splitter")
        pipe.add_component(SchemaBuilder(), "schema")
        pipe.add_component(
            LLMEntityRelationExtractor(
                llm=self.llm,
                on_error=self.on_error,
                prompt_template=self.prompt_template,
            ),
            "extractor",
        )
        pipe.add_component(self.kg_writer, "writer")

        # The extractor always receives the schema; in PDF mode it also
        # receives the document info produced by the loader.
        schema_to_extractor: dict[str, str] = {"schema": "schema"}

        if self.from_pdf:
            pipe.add_component(self.pdf_loader, "pdf_loader")

            # The loader's text output becomes the splitter's input.
            pipe.connect(
                "pdf_loader",
                "splitter",
                input_config={"text": "pdf_loader.text"},
            )
            schema_to_extractor["document_info"] = "pdf_loader.document_info"

        pipe.connect(
            "schema",
            "extractor",
            input_config=schema_to_extractor,
        )

        pipe.connect(
            "splitter",
            "extractor",
            input_config={"chunks": "splitter"},
        )

        # Connect extractor to writer
        pipe.connect(
            "extractor",
            "writer",
            input_config={"graph": "extractor"},
        )

        return pipe

    async def run_async(
        self, file_path: Optional[str] = None, text: Optional[str] = None
    ) -> PipelineResult:
        """
        Asynchronously runs the knowledge graph building process.

        Args:
            file_path (Optional[str]): The path to the PDF file to process. Required if `from_pdf` is True.
            text (Optional[str]): The text content to process. Required if `from_pdf` is False.

        Returns:
            PipelineResult: The result of the pipeline execution.

        Raises:
            PipelineDefinitionError: If the provided arguments do not match the `from_pdf` mode.
        """
        pipe_inputs = self._prepare_inputs(file_path=file_path, text=text)
        return await self.pipeline.run(pipe_inputs)

    def _prepare_inputs(
        self, file_path: Optional[str], text: Optional[str]
    ) -> dict[str, Any]:
        """Validate the run arguments and build the pipeline input mapping.

        Exactly one of `file_path` / `text` must be given, and it must be the
        one matching the `from_pdf` mode.

        Args:
            file_path (Optional[str]): Path of the PDF file; only valid when `from_pdf` is True.
            text (Optional[str]): Raw text; only valid when `from_pdf` is False.

        Returns:
            dict[str, Any]: Inputs keyed by component name: always "schema",
            plus "pdf_loader" (PDF mode) or "splitter" (text mode).

        Raises:
            PipelineDefinitionError: If the expected argument is missing or the other one is also supplied.
        """
        if self.from_pdf:
            if file_path is None or text is not None:
                raise PipelineDefinitionError(
                    "Expected 'file_path' argument when 'from_pdf' is True."
                )
            entry_component = "pdf_loader"
            entry_inputs: dict[str, Any] = {"filepath": file_path}
        else:
            if text is None or file_path is not None:
                raise PipelineDefinitionError(
                    "Expected 'text' argument when 'from_pdf' is False."
                )
            entry_component = "splitter"
            entry_inputs = {"text": text}

        pipe_inputs: dict[str, Any] = {
            "schema": {
                "entities": self.entities,
                "relations": self.relations,
                "potential_schema": self.potential_schema,
            },
            entry_component: entry_inputs,
        }

        return pipe_inputs
0 commit comments