Skip to content

Commit 632b12f

Browse files
authored
Add schema for kg builder (neo4j#88)
* Add schema for kg builder and tests * Fixed mypy checks * Reverted kg builder example with schema * Revert to List and Dict due to Python3.8 issue with using get_type_hints * Added properties to Entity and Relation * Add test for missing properties * Fix type annotations in test * Add property types * Refactored entity, relation, and property types * Unused import
1 parent bb032f9 commit 632b12f

File tree

6 files changed

+656
-1
lines changed

6 files changed

+656
-1
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import annotations
16+
17+
import asyncio
18+
import logging
19+
from typing import Any
20+
21+
from neo4j_genai.kg_construction.schema import (
22+
SchemaBuilder,
23+
SchemaEntity,
24+
SchemaRelation,
25+
)
26+
from neo4j_genai.pipeline import Component, DataModel
27+
from pydantic import BaseModel, validate_call
28+
29+
logging.basicConfig(level=logging.DEBUG)
30+
31+
32+
class DocumentChunkModel(DataModel):
    """Result of the chunking step: the list of text chunks."""

    chunks: list[str]
34+
35+
36+
class DocumentChunker(Component):
    """Example chunker that cuts a text into pieces at every period."""

    async def run(self, text: str) -> DocumentChunkModel:
        """Split ``text`` on ``.`` and keep the trimmed, non-empty pieces.

        Args:
            text: The raw text to split.

        Returns:
            A DocumentChunkModel whose ``chunks`` holds the pieces in their
            original order, each stripped of surrounding whitespace.
        """
        trimmed = (piece.strip() for piece in text.split("."))
        # Pieces made only of whitespace collapse to "" and are dropped.
        non_empty = [piece for piece in trimmed if piece]
        return DocumentChunkModel(chunks=non_empty)
40+
41+
42+
class EntityModel(BaseModel):
    """A labelled item with string-valued properties."""

    label: str
    properties: dict[str, str]
45+
46+
47+
class Neo4jGraph(DataModel):
    """Extraction result: raw entity and relation dictionaries."""

    entities: list[dict[str, Any]]
    relations: list[dict[str, Any]]
50+
51+
52+
class ERExtractor(Component):
    """Stub extractor that yields the same hard-coded entity for every chunk."""

    async def _process_chunk(self, chunk: str, schema: str) -> dict[str, Any]:
        """Return a fixed extraction result; ``chunk`` and ``schema`` are not read."""
        person = {"label": "Person", "properties": {"name": "John Doe"}}
        return {"entities": [person], "relations": []}

    async def run(self, chunks: list[str], schema: str) -> Neo4jGraph:
        """Process every chunk concurrently and concatenate the results.

        Args:
            chunks: Text chunks to process.
            schema: Schema forwarded unchanged to each per-chunk call.

        Returns:
            A Neo4jGraph holding the entities and relations of all chunks,
            in chunk order.
        """
        per_chunk = await asyncio.gather(
            *(self._process_chunk(chunk, schema) for chunk in chunks)
        )
        all_entities: list[dict[str, Any]] = []
        all_relations: list[dict[str, Any]] = []
        for extraction in per_chunk:
            all_entities.extend(extraction["entities"])
            all_relations.extend(extraction["relations"])
        return Neo4jGraph(entities=all_entities, relations=all_relations)
69+
70+
71+
class WriterModel(DataModel):
    """Output of the writer step: a status string plus the converted items."""

    status: str
    entities: list[EntityModel]
    # Relations reuse EntityModel (label + properties) in this example.
    relations: list[EntityModel]
75+
76+
77+
class Writer(Component):
    """Stub writer that converts the raw graph dictionaries into models."""

    @validate_call
    async def run(self, graph: Neo4jGraph) -> WriterModel:
        """Build EntityModel objects from the graph's entities and relations.

        Args:
            graph: The graph produced by the extraction step.

        Returns:
            A WriterModel with status "OK" and the converted items.
        """
        entity_models = [EntityModel(**raw) for raw in graph.entities]
        relation_models = [EntityModel(**raw) for raw in graph.relations]
        return WriterModel(
            status="OK",
            entities=entity_models,
            relations=relation_models,
        )
87+
88+
89+
if __name__ == "__main__":
    from neo4j_genai.pipeline import Pipeline

    # Node types for the schema; "AGE" is declared with a label only.
    entities = [
        SchemaEntity(label="PERSON", description="An individual human being."),
        SchemaEntity(
            label="ORGANIZATION",
            description="A structured group of people with a common purpose.",
        ),
        SchemaEntity(label="AGE"),
    ]
    # Relationship types for the schema.
    relations = [
        SchemaRelation(
            label="EMPLOYED_BY",
            description="Indicates employment relationship.",
        ),
        SchemaRelation(
            label="ORGANIZED_BY",
            description="Indicates organization responsible for an event.",
        ),
        SchemaRelation(
            label="ATTENDED_BY",
            description="Indicates attendance at an event.",
        ),
    ]
    # (source entity, relation, target entity) triples.
    potential_schema = [
        ("PERSON", "EMPLOYED_BY", "ORGANIZATION"),
        ("ORGANIZATION", "ATTENDED_BY", "PERSON"),
    ]

    # Register the components, then wire their outputs to downstream inputs.
    pipe = Pipeline()
    pipe.add_component("chunker", DocumentChunker())
    pipe.add_component("schema", SchemaBuilder())
    pipe.add_component("extractor", ERExtractor())
    pipe.add_component("writer", Writer())
    pipe.connect("chunker", "extractor", input_config={"chunks": "chunker.chunks"})
    pipe.connect("schema", "extractor", input_config={"schema": "schema"})
    pipe.connect("extractor", "writer", input_config={"graph": "extractor"})

    sample_text = """Graphs are everywhere.
GraphRAG is the future of Artificial Intelligence.
Robots are already running the world."""

    # Inputs are keyed by the name of the component that consumes them.
    pipe_inputs = {
        "chunker": {"text": sample_text},
        "schema": {
            "entities": entities,
            "relations": relations,
            "potential_schema": potential_schema,
        },
    }
    result = asyncio.run(pipe.run(pipe_inputs))
    print(result)

src/neo4j_genai/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,9 @@ class SchemaFetchError(Neo4jGenAiError):
104104
"""Exception raised when a Neo4jSchema cannot be fetched."""
105105

106106
pass
107+
108+
109+
# NOTE(review): SchemaFetchError in this module derives from Neo4jGenAiError,
# while this class derives from Exception directly - confirm that is intended.
class SchemaValidationError(Exception):
    """Raised when a schema configuration is invalid."""
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# Copyright (c) "Neo4j"
16+
# Neo4j Sweden AB [https://neo4j.com]
17+
# #
18+
# Licensed under the Apache License, Version 2.0 (the "License");
19+
# you may not use this file except in compliance with the License.
20+
# You may obtain a copy of the License at
21+
# #
22+
# https://www.apache.org/licenses/LICENSE-2.0
23+
# #
24+
# Unless required by applicable law or agreed to in writing, software
25+
# distributed under the License is distributed on an "AS IS" BASIS,
26+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27+
# See the License for the specific language governing permissions and
28+
# limitations under the License.
29+
from __future__ import annotations
30+
31+
from typing import Any, Dict, List, Literal, Tuple
32+
33+
from neo4j_genai.exceptions import SchemaValidationError
34+
from neo4j_genai.pipeline import Component, DataModel
35+
from pydantic import BaseModel, ValidationError, model_validator, validate_call
36+
37+
38+
class SchemaProperty(BaseModel):
    """
    Represents a named, typed property of a node or relationship.
    """

    name: str
    # Allowed values follow the Neo4j property types listed at
    # https://neo4j.com/docs/cypher-manual/current/values-and-types/property-structural-constructed/#property-types
    type: Literal[
        "BOOLEAN",
        "DATE",
        "DURATION",
        "FLOAT",
        "INTEGER",
        "LIST",
        "LOCAL_DATETIME",
        "LOCAL_TIME",
        "POINT",
        "STRING",
        "ZONED_DATETIME",
        "ZONED_TIME",
    ]
    description: str = ""
56+
57+
58+
class SchemaEntity(BaseModel):
    """
    Describes a kind of node that may appear in the graph.
    """

    label: str
    description: str = ""
    properties: List[SchemaProperty] = []
66+
67+
68+
class SchemaRelation(BaseModel):
    """
    Describes a kind of relationship that may connect nodes in the graph.
    """

    label: str
    description: str = ""
    properties: List[SchemaProperty] = []
76+
77+
78+
class SchemaConfig(DataModel):
    """
    Represents possible relationships between entities and relations in the graph.

    ``entities`` and ``relations`` map a label to its definition, and
    ``potential_schema`` lists (source entity, relation, target entity) label
    triples.
    """

    entities: Dict[str, Dict[str, Any]]
    relations: Dict[str, Dict[str, Any]]
    potential_schema: List[Tuple[str, str, str]]

    @model_validator(mode="before")
    @classmethod
    def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Check that every label used in ``potential_schema`` is defined.

        Runs on the raw input, before field validation.

        Args:
            data (Dict[str, Any]): Raw input with the optional keys
                "entities", "relations" and "potential_schema".

        Returns:
            Dict[str, Any]: The unchanged input data.

        Raises:
            SchemaValidationError: If a triple references an entity or a
                relation label that is not defined.
        """
        entities = data.get("entities", {})
        relations = data.get("relations", {})
        potential_schema = data.get("potential_schema", [])

        for entity1, relation, entity2 in potential_schema:
            if entity1 not in entities:
                raise SchemaValidationError(
                    f"Entity '{entity1}' is not defined in the provided entities."
                )
            if relation not in relations:
                raise SchemaValidationError(
                    f"Relation '{relation}' is not defined in the provided relations."
                )
            if entity2 not in entities:
                # Report the target entity, which is the one that is missing.
                raise SchemaValidationError(
                    f"Entity '{entity2}' is not defined in the provided entities."
                )

        return data
108+
109+
110+
class SchemaBuilder(Component):
    """
    A builder class for constructing SchemaConfig objects from given entities,
    relations, and their interrelationships defined in a potential schema.
    """

    @staticmethod
    def create_schema_model(
        entities: List[SchemaEntity],
        relations: List[SchemaRelation],
        potential_schema: List[Tuple[str, str, str]],
    ) -> SchemaConfig:
        """
        Creates a SchemaConfig object from lists of entity and relation objects
        and a list of triples defining potential relationships.

        Args:
            entities (List[SchemaEntity]): List of Entity objects.
            relations (List[SchemaRelation]): List of Relation objects.
            potential_schema (List[Tuple[str, str, str]]): Triples of
                (source entity label, relation label, target entity label).

        Returns:
            SchemaConfig: A configured schema object.

        Raises:
            SchemaValidationError: If building the SchemaConfig fails
                validation.
        """
        # Key each definition by its label; a repeated label keeps the last one.
        entity_dict = {entity.label: entity.model_dump() for entity in entities}
        relation_dict = {
            relation.label: relation.model_dump() for relation in relations
        }

        try:
            return SchemaConfig(
                entities=entity_dict,
                relations=relation_dict,
                potential_schema=potential_schema,
            )
        except (ValidationError, SchemaValidationError) as e:
            # Surface one exception type and keep the original as the cause.
            raise SchemaValidationError(e) from e

    @validate_call
    async def run(
        self,
        entities: List[SchemaEntity],
        relations: List[SchemaRelation],
        potential_schema: List[Tuple[str, str, str]],
    ) -> SchemaConfig:
        """
        Asynchronously constructs and returns a SchemaConfig object.

        Args:
            entities (List[SchemaEntity]): List of Entity objects.
            relations (List[SchemaRelation]): List of Relation objects.
            potential_schema (List[Tuple[str, str, str]]): Triples of
                (source entity label, relation label, target entity label).

        Returns:
            SchemaConfig: A configured schema object, constructed asynchronously.
        """
        return self.create_schema_model(entities, relations, potential_schema)

src/neo4j_genai/pipeline/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class TaskPipelineNode(PipelineNode):
5858
"""Runnable node. It must have:
5959
- a name (unique within the pipeline)
6060
- a component instance
61-
- a reference to the pipline it belongs to
61+
- a reference to the pipeline it belongs to
6262
(to find dependent tasks)
6363
"""
6464

tests/e2e/test_schema_e2e.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
from __future__ import annotations
1516

1617
import pytest
1718
from neo4j import Driver

0 commit comments

Comments
 (0)