Skip to content

Commit bb032f9

Browse files
authored
Fixes and refactors the KG writer component (#92)
* Fixes and refactors the KG writer component * Fixed mypy error * Made start_node_id and end_node_id parameters in UPSERT_RELATIONSHIP_QUERY
1 parent 9b22736 commit bb032f9

File tree

6 files changed

+118
-113
lines changed

6 files changed

+118
-113
lines changed

src/neo4j_genai/components/kg_writer.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,24 +73,23 @@ def _upsert_node(self, node: Neo4jNode) -> None:
7373
node (Neo4jNode): The node to upsert into the database.
7474
"""
7575
# Create the initial node
76-
properties = "{" + f"id: {node.id}"
76+
parameters = {"id": node.id}
7777
if node.properties:
78-
properties += (
79-
", " + ", ".join(f"{p.key}: {p.value}" for p in node.properties) + "}"
80-
)
81-
else:
82-
properties += "}"
78+
parameters.update(node.properties)
79+
properties = (
80+
"{" + ", ".join(f"{key}: ${key}" for key in parameters.keys()) + "}"
81+
)
8382
query = UPSERT_NODE_QUERY.format(label=node.label, properties=properties)
84-
result = self.driver.execute_query(query)
83+
result = self.driver.execute_query(query, parameters_=parameters)
8584
node_id = result.records[0]["elementID(n)"]
8685
# Add the embedding properties to the node
8786
if node.embedding_properties:
88-
for prop in node.embedding_properties:
87+
for prop, vector in node.embedding_properties.items():
8988
upsert_vector(
9089
driver=self.driver,
9190
node_id=node_id,
92-
embedding_property=prop.key,
93-
vector=prop.value,
91+
embedding_property=prop,
92+
vector=vector,
9493
neo4j_database=self.neo4j_database,
9594
)
9695

@@ -101,27 +100,31 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
101100
rel (Neo4jRelationship): The relationship to upsert into the database.
102101
"""
103102
# Create the initial relationship
104-
properties = (
105-
"{" + ", ".join(f"{p.key}: {p.value}" for p in rel.properties) + "}"
106-
if rel.properties
107-
else "{}"
108-
)
103+
parameters = {
104+
"start_node_id": rel.start_node_id,
105+
"end_node_id": rel.end_node_id,
106+
}
107+
if rel.properties:
108+
properties = (
109+
"{" + ", ".join(f"{key}: ${key}" for key in rel.properties.keys()) + "}"
110+
)
111+
parameters.update(rel.properties)
112+
else:
113+
properties = "{}"
109114
query = UPSERT_RELATIONSHIP_QUERY.format(
110-
start_node_id=rel.start_node_id,
111-
end_node_id=rel.end_node_id,
112115
type=rel.type,
113116
properties=properties,
114117
)
115-
result = self.driver.execute_query(query)
118+
result = self.driver.execute_query(query, parameters_=parameters)
116119
rel_id = result.records[0]["elementID(r)"]
117120
# Add the embedding properties to the relationship
118121
if rel.embedding_properties:
119-
for prop in rel.embedding_properties:
122+
for prop, vector in rel.embedding_properties.items():
120123
upsert_vector_on_relationship(
121124
driver=self.driver,
122125
rel_id=rel_id,
123-
embedding_property=prop.key,
124-
vector=prop.value,
126+
embedding_property=prop,
127+
vector=vector,
125128
neo4j_database=self.neo4j_database,
126129
)
127130

src/neo4j_genai/components/types.py

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Any, Optional
1818

1919
from neo4j_genai.pipeline.component import DataModel
20-
from pydantic import BaseModel
20+
from pydantic import BaseModel, field_validator
2121

2222

2323
class TextChunk(BaseModel):
@@ -42,44 +42,29 @@ class TextChunks(DataModel):
4242
chunks: list[TextChunk]
4343

4444

45-
class Neo4jProperty(BaseModel):
46-
"""Represents a Neo4j property.
47-
48-
Attributes:
49-
key (str): The property name.
50-
value (Any): The property value.
51-
"""
52-
53-
key: str
54-
value: Any
55-
56-
57-
class Neo4jEmbeddingProperty(BaseModel):
58-
"""Represents a Neo4j embedding property.
59-
60-
Attributes:
61-
key (str): The property name.
62-
value (list[float]): The embedding vector.
63-
"""
64-
65-
key: str
66-
value: list[float]
67-
68-
6945
class Neo4jNode(BaseModel):
7046
"""Represents a Neo4j node.
7147
7248
Attributes:
7349
id (str): The ID of the node.
7450
label (str): The label of the node.
75-
properties (Optional[list[Neo4jProperty]]): A list of properties associated with the node.
76-
embedding_properties (Optional[list[Neo4jEmbeddingProperty]]): A list of embedding properties associated with the node.
51+
properties (Optional[dict[str, Any]]): A dictionary of properties attached to the node.
52+
embedding_properties (Optional[dict[str, list[float]]]): A list of embedding properties attached to the node.
7753
"""
7854

7955
id: str
8056
label: str
81-
properties: Optional[list[Neo4jProperty]] = None
82-
embedding_properties: Optional[list[Neo4jEmbeddingProperty]] = None
57+
properties: Optional[dict[str, Any]] = None
58+
embedding_properties: Optional[dict[str, list[float]]] = None
59+
60+
@field_validator("properties", "embedding_properties")
61+
@classmethod
62+
def check_for_id_properties(
63+
cls, v: Optional[dict[str, Any]]
64+
) -> Optional[dict[str, Any]]:
65+
if v and "id" in v.keys():
66+
raise TypeError("'id' as a property name is not allowed")
67+
return v
8368

8469

8570
class Neo4jRelationship(BaseModel):
@@ -89,15 +74,15 @@ class Neo4jRelationship(BaseModel):
8974
start_node_id (str): The ID of the start node.
9075
end_node_id (str): The ID of the end node.
9176
type (str): The relationship type.
92-
properties (Optional[list[Neo4jProperty]]): A list of properties associated with the relationship.
93-
embedding_properties (Optional[list[Neo4jEmbeddingProperty]]): A list of embedding properties associated with the relationship.
77+
properties (Optional[dict[str, Any]]): A dictionary of properties attached to the relationship.
78+
embedding_properties (Optional[dict[str, list[float]]]): A list of embedding properties attached to the relationship.
9479
"""
9580

9681
start_node_id: str
9782
end_node_id: str
9883
type: str
99-
properties: Optional[list[Neo4jProperty]] = None
100-
embedding_properties: Optional[list[Neo4jEmbeddingProperty]] = None
84+
properties: Optional[dict[str, Any]] = None
85+
embedding_properties: Optional[dict[str, list[float]]] = None
10186

10287

10388
class Neo4jGraph(BaseModel):

src/neo4j_genai/neo4j_queries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
UPSERT_NODE_QUERY = "MERGE (n:`{label}` {properties}) RETURN elementID(n)"
4545

4646
UPSERT_RELATIONSHIP_QUERY = (
47-
"MATCH (start {{ id: {start_node_id} }}), (end {{ id: {end_node_id} }}) "
47+
"MATCH (start {{ id: $start_node_id }}), (end {{ id: $end_node_id }}) "
4848
"MERGE (start)-[r:{type} {properties}]->(end) "
4949
"RETURN elementID(r)"
5050
)

tests/e2e/test_kg_construction_e2e.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717
import pytest
1818
from neo4j_genai.components.kg_writer import Neo4jWriter
1919
from neo4j_genai.components.types import (
20-
Neo4jEmbeddingProperty,
2120
Neo4jGraph,
2221
Neo4jNode,
23-
Neo4jProperty,
2422
Neo4jRelationship,
2523
)
2624

@@ -31,21 +29,14 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
3129
start_node = Neo4jNode(
3230
id="1",
3331
label="Document",
34-
properties=[Neo4jProperty(key="chunk", value=1)],
35-
embedding_properties=[
36-
Neo4jEmbeddingProperty(key="vectorProperty", value=[1.0, 2.0, 3.0])
37-
],
32+
properties={"chunk": 1},
33+
embedding_properties={"vectorProperty": [1.0, 2.0, 3.0]},
3834
)
3935
end_node = Neo4jNode(
4036
id="2",
4137
label="Document",
42-
properties=[Neo4jProperty(key="chunk", value=2)],
43-
embedding_properties=[
44-
Neo4jEmbeddingProperty(
45-
key="vectorProperty",
46-
value=[1.0, 2.0, 3.0],
47-
)
48-
],
38+
properties={"chunk": 2},
39+
embedding_properties={"vectorProperty": [1.0, 2.0, 3.0]},
4940
)
5041
relationship = Neo4jRelationship(
5142
start_node_id="1", end_node_id="2", type="NEXT_CHUNK"
@@ -56,7 +47,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
5647
await neo4j_writer.run(graph=graph)
5748

5849
query = """
59-
MATCH (a:Document {id: 1})-[r:NEXT_CHUNK]->(b:Document {id: 2})
50+
MATCH (a:Document {id: '1'})-[r:NEXT_CHUNK]->(b:Document {id: '2'})
6051
RETURN a, r, b
6152
"""
6253
record = driver.execute_query(query).records[0]
@@ -66,25 +57,25 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
6657
assert start_node.label == list(node_a.labels)[0]
6758
assert start_node.id == str(node_a.get("id"))
6859
if start_node.properties:
69-
for prop in start_node.properties:
70-
assert prop.key in node_a.keys()
71-
assert prop.value == node_a.get(prop.key)
60+
for key, val in start_node.properties.items():
61+
assert key in node_a.keys()
62+
assert val == node_a.get(key)
7263
if start_node.embedding_properties:
73-
for embedding_prop in start_node.embedding_properties:
74-
assert embedding_prop.key in node_a.keys()
75-
assert node_a.get(embedding_prop.key) == [1.0, 2.0, 3.0]
64+
for key, val in start_node.embedding_properties.items():
65+
assert key in node_a.keys()
66+
assert node_a.get(key) == [1.0, 2.0, 3.0]
7667

7768
node_b = record["b"]
7869
assert end_node.label == list(node_b.labels)[0]
7970
assert end_node.id == str(node_b.get("id"))
8071
if end_node.properties:
81-
for prop in end_node.properties:
82-
assert prop.key in node_b.keys()
83-
assert prop.value == node_b.get(prop.key)
72+
for key, val in end_node.properties.items():
73+
assert key in node_b.keys()
74+
assert val == node_b.get(key)
8475
if end_node.embedding_properties:
85-
for embedding_prop in end_node.embedding_properties:
86-
assert embedding_prop.key in node_b.keys()
87-
assert node_b.get(embedding_prop.key) == [1.0, 2.0, 3.0]
76+
for key, val in end_node.embedding_properties.items():
77+
assert key in node_b.keys()
78+
assert node_b.get(key) == [1.0, 2.0, 3.0]
8879

8980
rel = record["r"]
9081
assert rel.type == relationship.type

0 commit comments

Comments
 (0)