Skip to content

Commit 06d9889

Browse files
authored
Makes relations and potential_schema optional in SchemaBuilder (#184)
* Fixed small typo in README
* Fixed small typo in KG creation prompt
* Made relations and potential schema optional in schema component
* Updated unit tests
* Updated changelog
1 parent 20b374d commit 06d9889

File tree

5 files changed: 117 additions and 24 deletions

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## Next
44

5+
### Added
6+
- Made `relations` and `potential_schema` optional in `SchemaBuilder`.
7+
58
## 1.1.0
69

710
### Added

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ kg_builder = SimpleKGPipeline(
101101

102102
# Run the pipeline on a piece of text
103103
text = (
104-
"The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House"
104+
"The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House "
105105
"Atreides, an aristocratic family that rules the planet Caladan."
106106
)
107107
asyncio.run(kg_builder.run_async(text=text))
@@ -164,7 +164,7 @@ embedder = OpenAIEmbeddings(model="text-embedding-3-large")
164164

165165
# Generate an embedding for some text
166166
text = (
167-
"The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House"
167+
"The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House "
168168
"Atreides, an aristocratic family that rules the planet Caladan."
169169
)
170170
vector = embedder.embed_query(text)

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
from typing import Any, Dict, List, Literal, Tuple
17+
from typing import Any, Dict, List, Literal, Optional, Tuple
1818

1919
from pydantic import BaseModel, ValidationError, model_validator, validate_call
2020

@@ -72,28 +72,33 @@ class SchemaConfig(DataModel):
7272
"""
7373

7474
entities: Dict[str, Dict[str, Any]]
75-
relations: Dict[str, Dict[str, Any]]
76-
potential_schema: List[Tuple[str, str, str]]
75+
relations: Optional[Dict[str, Dict[str, Any]]]
76+
potential_schema: Optional[List[Tuple[str, str, str]]]
7777

7878
@model_validator(mode="before")
7979
def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]:
8080
entities = data.get("entities", {}).keys()
8181
relations = data.get("relations", {}).keys()
8282
potential_schema = data.get("potential_schema", [])
8383

84-
for entity1, relation, entity2 in potential_schema:
85-
if entity1 not in entities:
84+
if potential_schema:
85+
if not relations:
8686
raise SchemaValidationError(
87-
f"Entity '{entity1}' is not defined in the provided entities."
88-
)
89-
if relation not in relations:
90-
raise SchemaValidationError(
91-
f"Relation '{relation}' is not defined in the provided relations."
92-
)
93-
if entity2 not in entities:
94-
raise SchemaValidationError(
95-
f"Entity '{entity2}' is not defined in the provided entities."
87+
"Relations must also be provided when using a potential schema."
9688
)
89+
for entity1, relation, entity2 in potential_schema:
90+
if entity1 not in entities:
91+
raise SchemaValidationError(
92+
f"Entity '{entity1}' is not defined in the provided entities."
93+
)
94+
if relation not in relations:
95+
raise SchemaValidationError(
96+
f"Relation '{relation}' is not defined in the provided relations."
97+
)
98+
if entity2 not in entities:
99+
raise SchemaValidationError(
100+
f"Entity '{entity2}' is not defined in the provided entities."
101+
)
97102

98103
return data
99104

@@ -160,8 +165,8 @@ class SchemaBuilder(Component):
160165
@staticmethod
161166
def create_schema_model(
162167
entities: List[SchemaEntity],
163-
relations: List[SchemaRelation],
164-
potential_schema: List[Tuple[str, str, str]],
168+
relations: Optional[List[SchemaRelation]] = None,
169+
potential_schema: Optional[List[Tuple[str, str, str]]] = None,
165170
) -> SchemaConfig:
166171
"""
167172
Creates a SchemaConfig object from Lists of Entity and Relation objects
@@ -176,9 +181,11 @@ def create_schema_model(
176181
SchemaConfig: A configured schema object.
177182
"""
178183
entity_dict = {entity.label: entity.model_dump() for entity in entities}
179-
relation_dict = {
180-
relation.label: relation.model_dump() for relation in relations
181-
}
184+
relation_dict = (
185+
{relation.label: relation.model_dump() for relation in relations}
186+
if relations
187+
else {}
188+
)
182189

183190
try:
184191
return SchemaConfig(
@@ -193,8 +200,8 @@ def create_schema_model(
193200
async def run(
194201
self,
195202
entities: List[SchemaEntity],
196-
relations: List[SchemaRelation],
197-
potential_schema: List[Tuple[str, str, str]],
203+
relations: Optional[List[SchemaRelation]] = None,
204+
potential_schema: Optional[List[Tuple[str, str, str]]] = None,
198205
) -> SchemaConfig:
199206
"""
200207
Asynchronously constructs and returns a SchemaConfig object.

src/neo4j_graphrag/generation/prompts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class ERExtractionTemplate(PromptTemplate):
167167
{{"nodes": [ {{"id": "0", "label": "Person", "properties": {{"name": "John"}} }}],
168168
"relationships": [{{"type": "KNOWS", "start_node_id": "0", "end_node_id": "1", "properties": {{"since": "2024-08-01"}} }}] }}
169169
170-
Use only fhe following nodes and relationships (if provided):
170+
Use only the following nodes and relationships (if provided):
171171
{schema}
172172
173173
Assign a unique ID (string) to each node, and reuse it to define relationships.

tests/unit/experimental/components/test_schema.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def test_create_schema_model_valid_data(
117117
)
118118
assert schema_instance.entities["AGE"]["description"] == "Age of a person in years."
119119

120+
assert schema_instance.relations
120121
assert (
121122
schema_instance.relations["EMPLOYED_BY"]["description"]
122123
== "Indicates employment relationship."
@@ -134,6 +135,7 @@ def test_create_schema_model_valid_data(
134135
{"description": "", "name": "end_time", "type": "LOCAL_DATETIME"},
135136
]
136137

138+
assert schema_instance.potential_schema
137139
assert schema_instance.potential_schema == potential_schema
138140

139141

@@ -159,6 +161,7 @@ def test_create_schema_model_missing_description(
159161

160162
assert schema_instance.entities["ORGANIZATION"]["description"] == ""
161163
assert schema_instance.entities["AGE"]["description"] == ""
164+
assert schema_instance.relations
162165
assert schema_instance.relations["ORGANIZED_BY"]["description"] == ""
163166
assert schema_instance.relations["ATTENDED_BY"]["description"] == ""
164167

@@ -242,6 +245,7 @@ async def test_run_method(
242245
)
243246
assert schema.entities["AGE"]["description"] == "Age of a person in years."
244247

248+
assert schema.relations
245249
assert (
246250
schema.relations["EMPLOYED_BY"]["description"]
247251
== "Indicates employment relationship."
@@ -255,6 +259,7 @@ async def test_run_method(
255259
== "Indicates attendance at an event."
256260
)
257261

262+
assert schema.potential_schema
258263
assert schema.potential_schema == potential_schema
259264

260265

@@ -327,6 +332,7 @@ def test_create_schema_model_missing_properties(
327332
schema_instance.entities["AGE"]["properties"] == []
328333
), "Expected empty properties for AGE"
329334

335+
assert schema_instance.relations
330336
assert (
331337
schema_instance.relations["EMPLOYED_BY"]["properties"] == []
332338
), "Expected empty properties for EMPLOYED_BY"
@@ -336,3 +342,80 @@ def test_create_schema_model_missing_properties(
336342
assert (
337343
schema_instance.relations["ATTENDED_BY"]["properties"] == []
338344
), "Expected empty properties for ATTENDED_BY"
345+
346+
347+
def test_create_schema_model_no_potential_schema(
348+
schema_builder: SchemaBuilder,
349+
valid_entities: list[SchemaEntity],
350+
valid_relations: list[SchemaRelation],
351+
) -> None:
352+
schema_instance = schema_builder.create_schema_model(
353+
valid_entities, valid_relations
354+
)
355+
356+
assert (
357+
schema_instance.entities["PERSON"]["description"]
358+
== "An individual human being."
359+
)
360+
assert schema_instance.entities["PERSON"]["properties"] == [
361+
{"description": "", "name": "birth date", "type": "ZONED_DATETIME"},
362+
{"description": "", "name": "name", "type": "STRING"},
363+
]
364+
assert (
365+
schema_instance.entities["ORGANIZATION"]["description"]
366+
== "A structured group of people with a common purpose."
367+
)
368+
assert schema_instance.entities["AGE"]["description"] == "Age of a person in years."
369+
370+
assert schema_instance.relations
371+
assert (
372+
schema_instance.relations["EMPLOYED_BY"]["description"]
373+
== "Indicates employment relationship."
374+
)
375+
assert (
376+
schema_instance.relations["ORGANIZED_BY"]["description"]
377+
== "Indicates organization responsible for an event."
378+
)
379+
assert (
380+
schema_instance.relations["ATTENDED_BY"]["description"]
381+
== "Indicates attendance at an event."
382+
)
383+
assert schema_instance.relations["EMPLOYED_BY"]["properties"] == [
384+
{"description": "", "name": "start_time", "type": "LOCAL_DATETIME"},
385+
{"description": "", "name": "end_time", "type": "LOCAL_DATETIME"},
386+
]
387+
388+
389+
def test_create_schema_model_no_relations_or_potential_schema(
390+
schema_builder: SchemaBuilder,
391+
valid_entities: list[SchemaEntity],
392+
) -> None:
393+
schema_instance = schema_builder.create_schema_model(valid_entities)
394+
395+
assert (
396+
schema_instance.entities["PERSON"]["description"]
397+
== "An individual human being."
398+
)
399+
assert schema_instance.entities["PERSON"]["properties"] == [
400+
{"description": "", "name": "birth date", "type": "ZONED_DATETIME"},
401+
{"description": "", "name": "name", "type": "STRING"},
402+
]
403+
assert (
404+
schema_instance.entities["ORGANIZATION"]["description"]
405+
== "A structured group of people with a common purpose."
406+
)
407+
assert schema_instance.entities["AGE"]["description"] == "Age of a person in years."
408+
409+
410+
def test_create_schema_model_missing_relations(
411+
schema_builder: SchemaBuilder,
412+
valid_entities: list[SchemaEntity],
413+
potential_schema: list[tuple[str, str, str]],
414+
) -> None:
415+
with pytest.raises(SchemaValidationError) as exc_info:
416+
schema_builder.create_schema_model(
417+
entities=valid_entities, potential_schema=potential_schema
418+
)
419+
assert "Relations must also be provided when using a potential schema." in str(
420+
exc_info.value
421+
), "Should fail due to missing relations"

0 commit comments

Comments
 (0)