|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
17 | 17 | import asyncio
|
18 |
| -import logging |
| 18 | +import logging.config |
19 | 19 | from typing import Any
|
20 | 20 |
|
21 |
| -from neo4j_genai.pipeline import Component, DataModel |
22 |
| -from pydantic import BaseModel, validate_call |
23 |
| - |
24 |
| -logging.basicConfig(level=logging.DEBUG) |
25 |
| - |
26 |
| - |
27 |
| -class DocumentChunkModel(DataModel): |
28 |
| - chunks: list[str] |
29 |
| - |
30 |
| - |
31 |
| -class DocumentChunker(Component): |
32 |
| - async def run(self, text: str) -> DocumentChunkModel: |
33 |
| - chunks = [t.strip() for t in text.split(".") if t.strip()] |
34 |
| - return DocumentChunkModel(chunks=chunks) |
35 |
| - |
36 |
| - |
37 |
| -class SchemaModel(DataModel): |
38 |
| - data_schema: str |
39 |
| - |
40 |
| - |
41 |
| -class SchemaBuilder(Component): |
42 |
| - async def run(self, schema: str) -> SchemaModel: |
43 |
| - return SchemaModel(data_schema=schema) |
44 |
| - |
45 |
| - |
46 |
| -class EntityModel(BaseModel): |
47 |
| - label: str |
48 |
| - properties: dict[str, str] |
49 |
| - |
50 |
| - |
51 |
| -class Neo4jGraph(DataModel): |
52 |
| - entities: list[dict[str, Any]] |
53 |
| - relations: list[dict[str, Any]] |
54 |
| - |
55 |
| - |
56 |
| -class ERExtractor(Component): |
57 |
| - async def _process_chunk(self, chunk: str, schema: str) -> dict[str, Any]: |
58 |
| - return { |
59 |
| - "entities": [{"label": "Person", "properties": {"name": "John Doe"}}], |
60 |
| - "relations": [], |
61 |
| - } |
62 |
| - |
63 |
| - async def run(self, chunks: list[str], schema: str) -> Neo4jGraph: |
64 |
| - tasks = [self._process_chunk(chunk, schema) for chunk in chunks] |
65 |
| - result = await asyncio.gather(*tasks) |
66 |
| - merged_result: dict[str, Any] = {"entities": [], "relations": []} |
67 |
| - for res in result: |
68 |
| - merged_result["entities"] += res["entities"] |
69 |
| - merged_result["relations"] += res["relations"] |
70 |
| - return Neo4jGraph( |
71 |
| - entities=merged_result["entities"], relations=merged_result["relations"] |
72 |
| - ) |
73 |
| - |
74 |
| - |
75 |
| -class WriterModel(DataModel): |
76 |
| - status: str |
77 |
| - entities: list[EntityModel] |
78 |
| - relations: list[EntityModel] |
79 |
| - |
80 |
| - |
81 |
| -class Writer(Component): |
82 |
| - @validate_call |
83 |
| - async def run(self, graph: Neo4jGraph) -> WriterModel: |
84 |
| - entities = graph.entities |
85 |
| - relations = graph.relations |
86 |
| - return WriterModel( |
87 |
| - status="OK", |
88 |
| - entities=[EntityModel(**e) for e in entities], |
89 |
| - relations=[EntityModel(**r) for r in relations], |
90 |
| - ) |
91 |
| - |
92 |
| - |
93 |
| -if __name__ == "__main__": |
94 |
| - from neo4j_genai.pipeline import Pipeline |
95 |
| - |
| 21 | +import neo4j |
| 22 | +from langchain_text_splitters import CharacterTextSplitter |
| 23 | +from neo4j_genai.components.entity_relation_extractor import ( |
| 24 | + LLMEntityRelationExtractor, |
| 25 | + OnError, |
| 26 | +) |
| 27 | +from neo4j_genai.components.kg_writer import Neo4jWriter |
| 28 | +from neo4j_genai.components.schema import ( |
| 29 | + SchemaBuilder, |
| 30 | + SchemaEntity, |
| 31 | + SchemaProperty, |
| 32 | + SchemaRelation, |
| 33 | +) |
| 34 | +from neo4j_genai.components.text_splitters.langchain import LangChainTextSplitterAdapter |
| 35 | +from neo4j_genai.llm import OpenAILLM |
| 36 | +from neo4j_genai.pipeline import Pipeline |
| 37 | + |
| 38 | +# set log level to DEBUG for all neo4j_genai.* loggers |
| 39 | +logging.config.dictConfig( |
| 40 | + { |
| 41 | + "version": 1, |
| 42 | + "handlers": { |
| 43 | + "console": { |
| 44 | + "class": "logging.StreamHandler", |
| 45 | + } |
| 46 | + }, |
| 47 | + "loggers": { |
| 48 | + "root": { |
| 49 | + "handlers": ["console"], |
| 50 | + }, |
| 51 | + "neo4j_genai": { |
| 52 | + "level": "DEBUG", |
| 53 | + }, |
| 54 | + }, |
| 55 | + } |
| 56 | +) |
| 57 | + |
| 58 | + |
| 59 | +async def main(neo4j_driver: neo4j.Driver) -> dict[str, Any]: |
| 60 | + """This is where we define and run the KG builder pipeline, instantiating a few |
| 61 | + components: |
| 62 | + - Text Splitter: in this example we use a text splitter from the LangChain package |
| 63 | + - Schema Builder: this component takes a list of entities, relationships and |
| 64 | + possible triplets as inputs, validates them and returns a schema ready to use |
| 65 | + for the rest of the pipeline |
| 66 | + - LLM Entity Relation Extractor is an LLM-based entity and relation extractor: |
| 67 | + based on the provided schema, the LLM will do its best to identify these |
| 68 | + entities and their relations within the provided text |
| 69 | + - KG writer: once entities and relations are extracted, they can be written |
| 70 | + to a Neo4j database |
| 71 | + """ |
96 | 72 | pipe = Pipeline()
|
97 |
| - pipe.add_component("chunker", DocumentChunker()) |
| 73 | + # define the components |
| 74 | + pipe.add_component( |
| 75 | + "splitter", |
| 76 | + LangChainTextSplitterAdapter( |
| 77 | + # chunk_size=50 for the sake of this demo |
| 78 | + CharacterTextSplitter(chunk_size=50, chunk_overlap=10, separator=".") |
| 79 | + ), |
| 80 | + ) |
98 | 81 | pipe.add_component("schema", SchemaBuilder())
|
99 |
| - pipe.add_component("extractor", ERExtractor()) |
100 |
| - pipe.add_component("writer", Writer()) |
101 |
| - pipe.connect("chunker", "extractor", input_config={"chunks": "chunker.chunks"}) |
102 |
| - pipe.connect("schema", "extractor", input_config={"schema": "schema.data_schema"}) |
| 82 | + pipe.add_component( |
| 83 | + "extractor", |
| 84 | + LLMEntityRelationExtractor( |
| 85 | + llm=OpenAILLM( |
| 86 | + model_name="gpt-4o", |
| 87 | + model_params={ |
| 88 | + "max_tokens": 1000, |
| 89 | + "response_format": {"type": "json_object"}, |
| 90 | + }, |
| 91 | + ), |
| 92 | + on_error=OnError.RAISE, |
| 93 | + ), |
| 94 | + ) |
| 95 | + pipe.add_component("writer", Neo4jWriter(neo4j_driver)) |
| 96 | + # define the execution order of components |
| 97 | + # and how the output of previous components must be used |
| 98 | + pipe.connect("splitter", "extractor", input_config={"chunks": "splitter"}) |
| 99 | + pipe.connect("schema", "extractor", input_config={"schema": "schema"}) |
103 | 100 | pipe.connect(
|
104 | 101 | "extractor",
|
105 | 102 | "writer",
|
106 | 103 | input_config={"graph": "extractor"},
|
107 | 104 | )
|
108 |
| - |
| 105 | + # user input: |
| 106 | + # the initial text |
| 107 | + # and the list of entities and relations we are looking for |
109 | 108 | pipe_inputs = {
|
110 |
| - "chunker": { |
111 |
| - "text": """Graphs are everywhere. |
112 |
| - GraphRAG is the future of Artificial Intelligence. |
113 |
| - Robots are already running the world.""" |
| 109 | + "splitter": { |
| 110 | + "text": """Albert Einstein was a German physicist born in 1879 who |
| 111 | + wrote many groundbreaking papers especially about general relativity |
| 112 | + and quantum mechanics. He worked for many different institutions, including |
| 113 | + the University of Bern in Switzerland and the University of Oxford.""" |
| 114 | + }, |
| 115 | + "schema": { |
| 116 | + "entities": [ |
| 117 | + SchemaEntity( |
| 118 | + label="Person", |
| 119 | + properties=[ |
| 120 | + SchemaProperty(name="name", type="STRING"), |
| 121 | + SchemaProperty(name="place_of_birth", type="STRING"), |
| 122 | + SchemaProperty(name="date_of_birth", type="DATE"), |
| 123 | + ], |
| 124 | + ), |
| 125 | + SchemaEntity( |
| 126 | + label="Organization", |
| 127 | + properties=[ |
| 128 | + SchemaProperty(name="name", type="STRING"), |
| 129 | + SchemaProperty(name="country", type="STRING"), |
| 130 | + ], |
| 131 | + ), |
| 132 | + SchemaEntity( |
| 133 | + label="Field", |
| 134 | + properties=[ |
| 135 | + SchemaProperty(name="name", type="STRING"), |
| 136 | + ], |
| 137 | + ), |
| 138 | + ], |
| 139 | + "relations": [ |
| 140 | + SchemaRelation( |
| 141 | + label="WORKED_ON", |
| 142 | + ), |
| 143 | + SchemaRelation( |
| 144 | + label="WORKED_FOR", |
| 145 | + ), |
| 146 | + ], |
| 147 | + "potential_schema": [ |
| 148 | + ("Person", "WORKED_ON", "Field"), |
| 149 | + ("Person", "WORKED_FOR", "Organization"), |
| 150 | + ], |
114 | 151 | },
|
115 |
| - "schema": {"schema": "Person OWNS House"}, |
116 | 152 | }
|
117 |
| - print(asyncio.run(pipe.run(pipe_inputs))) |
| 153 | + # run the pipeline |
| 154 | + return await pipe.run(pipe_inputs) |
| 155 | + |
| 156 | + |
| 157 | +if __name__ == "__main__": |
| 158 | + with neo4j.GraphDatabase.driver( |
| 159 | + "bolt://localhost:7687", auth=("neo4j", "password") |
| 160 | + ) as driver: |
| 161 | + print(asyncio.run(main(driver))) |
0 commit comments