Skip to content

Commit 8285d56

Browse files
authored
Improved schema definition in the SimpleKGPipeline (#176)
* some tests * Simplify * Remove SchemaEntity/SchemaRelation as valid inputs in SimpleKGPipeline (not "simple") * Fix mypy * Fix ruff * Ruff again * Ruff again * Update tests, examples and CHANGELOG * Fix mypy again N * Update type, switch to model_validate method due to typing issues * Ruff
1 parent 0eb888f commit 8285d56

File tree

5 files changed

+116
-12
lines changed

5 files changed

+116
-12
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### Added
66
- Introduced optional lexical graph configuration for SimpleKGPipeline, enhancing flexibility in customizing node labels and relationship types in the lexical graph.
7+
- Ability to provide a description and a list of properties for entities and relations in the SimpleKGPipeline constructor.
78

89
## 1.2.0
910

examples/build_graph/simple_kg_builder_from_text.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from neo4j_graphrag.embeddings import OpenAIEmbeddings
1212
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
1313
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
14+
from neo4j_graphrag.experimental.pipeline.types import (
15+
EntityInputType,
16+
RelationInputType,
17+
)
1418
from neo4j_graphrag.llm import LLMInterface
1519
from neo4j_graphrag.llm.openai_llm import OpenAILLM
1620

@@ -21,12 +25,28 @@
2125

2226
# Text to process
2327
TEXT = """The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House Atreides,
24-
an aristocratic family that rules the planet Caladan."""
28+
an aristocratic family that rules the planet Caladan, the rainy planet, since 10191."""
2529

2630
# Instantiate Entity and Relation objects. This defines the
2731
# entities and relations the LLM will be looking for in the text.
28-
ENTITIES = ["Person", "House", "Planet"]
29-
RELATIONS = ["PARENT_OF", "HEIR_OF", "RULES"]
32+
ENTITIES: list[EntityInputType] = [
33+
# entities can be defined with a simple label...
34+
"Person",
35+
# ... or with a dict if more details are needed,
36+
# such as a description:
37+
{"label": "House", "description": "Family the person belongs to"},
38+
# or a list of properties the LLM will try to attach to the entity:
39+
{"label": "Planet", "properties": [{"name": "weather", "type": "STRING"}]},
40+
]
41+
# same thing for relationships:
42+
RELATIONS: list[RelationInputType] = [
43+
"PARENT_OF",
44+
{
45+
"label": "HEIR_OF",
46+
"description": "Used for inheritor relationship between father and sons",
47+
},
48+
{"label": "RULES", "properties": [{"name": "fromYear", "type": "INTEGER"}]},
49+
]
3050
POTENTIAL_SCHEMA = [
3151
("Person", "PARENT_OF", "Person"),
3252
("Person", "HEIR_OF", "House"),

src/neo4j_graphrag/experimental/pipeline/kg_builder.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from __future__ import annotations
1717

18-
from typing import Any, List, Optional, Union
18+
from typing import Any, List, Optional, Sequence, Union
1919

2020
import neo4j
2121
from pydantic import BaseModel, ConfigDict, Field
@@ -42,6 +42,10 @@
4242
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
4343
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
4444
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult
45+
from neo4j_graphrag.experimental.pipeline.types import (
46+
EntityInputType,
47+
RelationInputType,
48+
)
4549
from neo4j_graphrag.generation.prompts import ERExtractionTemplate
4650
from neo4j_graphrag.llm.base import LLMInterface
4751

@@ -74,8 +78,16 @@ class SimpleKGPipeline:
7478
llm (LLMInterface): An instance of an LLM to use for entity and relation extraction.
7579
driver (neo4j.Driver): A Neo4j driver instance for database connection.
7680
embedder (Embedder): An instance of an embedder used to generate chunk embeddings from text chunks.
77-
entities (Optional[List[str]]): A list of entity labels as strings.
78-
relations (Optional[List[str]]): A list of relation labels as strings.
81+
entities (Optional[Sequence[EntityInputType]]): A list of either:
82+
83+
- str: entity labels
84+
- dict: following the SchemaEntity schema, i.e. with label, description and properties keys
85+
86+
relations (Optional[Sequence[RelationInputType]]): A list of either:
87+
88+
- str: relation label
89+
- dict: following the SchemaRelation schema, i.e. with label, description and properties keys
90+
7991
potential_schema (Optional[List[tuple]]): A list of potential schema relationships.
8092
from_pdf (bool): Determines whether to include the PdfLoader in the pipeline.
8193
If True, expects `file_path` input in `run` methods.
@@ -94,8 +106,8 @@ def __init__(
94106
llm: LLMInterface,
95107
driver: neo4j.Driver,
96108
embedder: Embedder,
97-
entities: Optional[List[str]] = None,
98-
relations: Optional[List[str]] = None,
109+
entities: Optional[Sequence[EntityInputType]] = None,
110+
relations: Optional[Sequence[RelationInputType]] = None,
99111
potential_schema: Optional[List[tuple[str, str, str]]] = None,
100112
from_pdf: bool = True,
101113
text_splitter: Optional[Any] = None,
@@ -106,9 +118,9 @@ def __init__(
106118
perform_entity_resolution: bool = True,
107119
lexical_graph_config: Optional[LexicalGraphConfig] = None,
108120
):
109-
self.entities = [SchemaEntity(label=label) for label in entities or []]
110-
self.relations = [SchemaRelation(label=label) for label in relations or []]
111-
self.potential_schema = potential_schema if potential_schema is not None else []
121+
self.potential_schema = potential_schema or []
122+
self.entities = [self.to_schema_entity(e) for e in entities or []]
123+
self.relations = [self.to_schema_relation(r) for r in relations or []]
112124

113125
try:
114126
on_error_enum = OnError(on_error)
@@ -150,6 +162,18 @@ def __init__(
150162

151163
self.pipeline = self._build_pipeline()
152164

165+
@staticmethod
166+
def to_schema_entity(entity: EntityInputType) -> SchemaEntity:
167+
if isinstance(entity, dict):
168+
return SchemaEntity.model_validate(entity)
169+
return SchemaEntity(label=entity)
170+
171+
@staticmethod
172+
def to_schema_relation(relation: RelationInputType) -> SchemaRelation:
173+
if isinstance(relation, dict):
174+
return SchemaRelation.model_validate(relation)
175+
return SchemaRelation(label=relation)
176+
153177
def _build_pipeline(self) -> Pipeline:
154178
pipe = Pipeline()
155179

src/neo4j_graphrag/experimental/pipeline/types.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
from typing import Union
18+
1719
from pydantic import BaseModel, ConfigDict
1820

1921
from neo4j_graphrag.experimental.pipeline.component import Component
@@ -35,3 +37,12 @@ class ConnectionConfig(BaseModel):
3537
class PipelineConfig(BaseModel):
3638
components: list[ComponentConfig]
3739
connections: list[ConnectionConfig]
40+
41+
42+
EntityInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]]
43+
RelationInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]]
44+
"""Types derived from the SchemaEntity and SchemaRelation types,
45+
so the possible types for dict values are:
46+
- str (for label and description)
47+
- list[dict[str, str]] (for properties)
48+
"""

tests/unit/experimental/pipeline/test_kg_builder.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
import pytest
2020
from neo4j_graphrag.embeddings import Embedder
2121
from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError
22-
from neo4j_graphrag.experimental.components.schema import SchemaEntity, SchemaRelation
22+
from neo4j_graphrag.experimental.components.schema import (
23+
SchemaEntity,
24+
SchemaProperty,
25+
SchemaRelation,
26+
)
2327
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
2428
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
2529
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
@@ -379,3 +383,47 @@ async def test_knowledge_graph_builder_with_lexical_graph_config(_: Mock) -> Non
379383
assert pipe_inputs["extractor"] == {
380384
"lexical_graph_config": lexical_graph_config
381385
}
386+
387+
388+
def test_knowledge_graph_builder_to_schema_entity_method() -> None:
389+
assert SimpleKGPipeline.to_schema_entity("EntityType") == SchemaEntity(
390+
label="EntityType"
391+
)
392+
assert SimpleKGPipeline.to_schema_entity({"label": "EntityType"}) == SchemaEntity(
393+
label="EntityType"
394+
)
395+
assert SimpleKGPipeline.to_schema_entity(
396+
{"label": "EntityType", "description": "A special entity"}
397+
) == SchemaEntity(label="EntityType", description="A special entity")
398+
assert SimpleKGPipeline.to_schema_entity(
399+
{"label": "EntityType", "properties": []}
400+
) == SchemaEntity(label="EntityType")
401+
assert SimpleKGPipeline.to_schema_entity(
402+
{
403+
"label": "EntityType",
404+
"properties": [{"name": "entityProperty", "type": "DATE"}],
405+
}
406+
) == SchemaEntity(
407+
label="EntityType",
408+
properties=[SchemaProperty(name="entityProperty", type="DATE")],
409+
)
410+
411+
412+
def test_knowledge_graph_builder_to_schema_relation_method() -> None:
413+
assert SimpleKGPipeline.to_schema_relation("REL_TYPE") == SchemaRelation(
414+
label="REL_TYPE"
415+
)
416+
assert SimpleKGPipeline.to_schema_relation({"label": "REL_TYPE"}) == SchemaRelation(
417+
label="REL_TYPE"
418+
)
419+
assert SimpleKGPipeline.to_schema_relation(
420+
{"label": "REL_TYPE", "description": "A rel type"}
421+
) == SchemaRelation(label="REL_TYPE", description="A rel type")
422+
assert SimpleKGPipeline.to_schema_relation(
423+
{"label": "REL_TYPE", "properties": []}
424+
) == SchemaRelation(label="REL_TYPE")
425+
assert SimpleKGPipeline.to_schema_relation(
426+
{"label": "REL_TYPE", "properties": [{"name": "relProperty", "type": "DATE"}]}
427+
) == SchemaRelation(
428+
label="REL_TYPE", properties=[SchemaProperty(name="relProperty", type="DATE")]
429+
)

0 commit comments

Comments
 (0)