Skip to content

Fixes and refactors the KG writer component #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 24 additions & 21 deletions src/neo4j_genai/components/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,23 @@ def _upsert_node(self, node: Neo4jNode) -> None:
node (Neo4jNode): The node to upsert into the database.
"""
# Create the initial node
properties = "{" + f"id: {node.id}"
parameters = {"id": node.id}
if node.properties:
properties += (
", " + ", ".join(f"{p.key}: {p.value}" for p in node.properties) + "}"
)
else:
properties += "}"
parameters.update(node.properties)
properties = (
"{" + ", ".join(f"{key}: ${key}" for key in parameters.keys()) + "}"
)
query = UPSERT_NODE_QUERY.format(label=node.label, properties=properties)
result = self.driver.execute_query(query)
result = self.driver.execute_query(query, parameters_=parameters)
node_id = result.records[0]["elementID(n)"]
# Add the embedding properties to the node
if node.embedding_properties:
for prop in node.embedding_properties:
for prop, vector in node.embedding_properties.items():
upsert_vector(
driver=self.driver,
node_id=node_id,
embedding_property=prop.key,
vector=prop.value,
embedding_property=prop,
vector=vector,
neo4j_database=self.neo4j_database,
)

Expand All @@ -101,27 +100,31 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
rel (Neo4jRelationship): The relationship to upsert into the database.
"""
# Create the initial relationship
properties = (
"{" + ", ".join(f"{p.key}: {p.value}" for p in rel.properties) + "}"
if rel.properties
else "{}"
)
parameters = {
"start_node_id": rel.start_node_id,
"end_node_id": rel.end_node_id,
}
if rel.properties:
properties = (
"{" + ", ".join(f"{key}: ${key}" for key in rel.properties.keys()) + "}"
)
parameters.update(rel.properties)
else:
properties = "{}"
query = UPSERT_RELATIONSHIP_QUERY.format(
start_node_id=rel.start_node_id,
end_node_id=rel.end_node_id,
type=rel.type,
properties=properties,
)
result = self.driver.execute_query(query)
result = self.driver.execute_query(query, parameters_=parameters)
rel_id = result.records[0]["elementID(r)"]
# Add the embedding properties to the relationship
if rel.embedding_properties:
for prop in rel.embedding_properties:
for prop, vector in rel.embedding_properties.items():
upsert_vector_on_relationship(
driver=self.driver,
rel_id=rel_id,
embedding_property=prop.key,
vector=prop.value,
embedding_property=prop,
vector=vector,
neo4j_database=self.neo4j_database,
)

Expand Down
51 changes: 18 additions & 33 deletions src/neo4j_genai/components/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, Optional

from neo4j_genai.pipeline.component import DataModel
from pydantic import BaseModel
from pydantic import BaseModel, field_validator


class TextChunk(BaseModel):
Expand All @@ -42,44 +42,29 @@ class TextChunks(DataModel):
chunks: list[TextChunk]


class Neo4jProperty(BaseModel):
"""Represents a Neo4j property.

Attributes:
key (str): The property name.
value (Any): The property value.
"""

key: str
value: Any


class Neo4jEmbeddingProperty(BaseModel):
"""Represents a Neo4j embedding property.

Attributes:
key (str): The property name.
value (list[float]): The embedding vector.
"""

key: str
value: list[float]


class Neo4jNode(BaseModel):
"""Represents a Neo4j node.

Attributes:
id (str): The ID of the node.
label (str): The label of the node.
properties (Optional[list[Neo4jProperty]]): A list of properties associated with the node.
embedding_properties (Optional[list[Neo4jEmbeddingProperty]]): A list of embedding properties associated with the node.
properties (Optional[dict[str, Any]]): A dictionary of properties attached to the node.
embedding_properties (Optional[dict[str, list[float]]]): A list of embedding properties attached to the node.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for now, but I'm wondering if we could detect embeddings if they are stored in the global property list, since there must be an index in the database for them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah maybe we could include a call to the DB which checks which properties are indexed before uploading anything. I guess the disadvantage there is that I believe you can upload the properties to the nodes then create the vector index, and that would disallow that workflow.

"""

id: str
label: str
properties: Optional[list[Neo4jProperty]] = None
embedding_properties: Optional[list[Neo4jEmbeddingProperty]] = None
properties: Optional[dict[str, Any]] = None
embedding_properties: Optional[dict[str, list[float]]] = None

@field_validator("properties", "embedding_properties")
@classmethod
def check_for_id_properties(
cls, v: Optional[dict[str, Any]]
) -> Optional[dict[str, Any]]:
if v and "id" in v.keys():
raise TypeError("'id' as a property name is not allowed")
return v


class Neo4jRelationship(BaseModel):
Expand All @@ -89,15 +74,15 @@ class Neo4jRelationship(BaseModel):
start_node_id (str): The ID of the start node.
end_node_id (str): The ID of the end node.
type (str): The relationship type.
properties (Optional[list[Neo4jProperty]]): A list of properties associated with the relationship.
embedding_properties (Optional[list[Neo4jEmbeddingProperty]]): A list of embedding properties associated with the relationship.
properties (Optional[dict[str, Any]]): A dictionary of properties attached to the relationship.
embedding_properties (Optional[dict[str, list[float]]]): A list of embedding properties attached to the relationship.
"""

start_node_id: str
end_node_id: str
type: str
properties: Optional[list[Neo4jProperty]] = None
embedding_properties: Optional[list[Neo4jEmbeddingProperty]] = None
properties: Optional[dict[str, Any]] = None
embedding_properties: Optional[dict[str, list[float]]] = None


class Neo4jGraph(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_genai/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
UPSERT_NODE_QUERY = "MERGE (n:`{label}` {properties}) RETURN elementID(n)"

UPSERT_RELATIONSHIP_QUERY = (
"MATCH (start {{ id: {start_node_id} }}), (end {{ id: {end_node_id} }}) "
"MATCH (start {{ id: $start_node_id }}), (end {{ id: $end_node_id }}) "
"MERGE (start)-[r:{type} {properties}]->(end) "
"RETURN elementID(r)"
)
Expand Down
43 changes: 17 additions & 26 deletions tests/e2e/test_kg_construction_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
import pytest
from neo4j_genai.components.kg_writer import Neo4jWriter
from neo4j_genai.components.types import (
Neo4jEmbeddingProperty,
Neo4jGraph,
Neo4jNode,
Neo4jProperty,
Neo4jRelationship,
)

Expand All @@ -31,21 +29,14 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
start_node = Neo4jNode(
id="1",
label="Document",
properties=[Neo4jProperty(key="chunk", value=1)],
embedding_properties=[
Neo4jEmbeddingProperty(key="vectorProperty", value=[1.0, 2.0, 3.0])
],
properties={"chunk": 1},
embedding_properties={"vectorProperty": [1.0, 2.0, 3.0]},
)
end_node = Neo4jNode(
id="2",
label="Document",
properties=[Neo4jProperty(key="chunk", value=2)],
embedding_properties=[
Neo4jEmbeddingProperty(
key="vectorProperty",
value=[1.0, 2.0, 3.0],
)
],
properties={"chunk": 2},
embedding_properties={"vectorProperty": [1.0, 2.0, 3.0]},
)
relationship = Neo4jRelationship(
start_node_id="1", end_node_id="2", type="NEXT_CHUNK"
Expand All @@ -56,7 +47,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
await neo4j_writer.run(graph=graph)

query = """
MATCH (a:Document {id: 1})-[r:NEXT_CHUNK]->(b:Document {id: 2})
MATCH (a:Document {id: '1'})-[r:NEXT_CHUNK]->(b:Document {id: '2'})
RETURN a, r, b
"""
record = driver.execute_query(query).records[0]
Expand All @@ -66,25 +57,25 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
assert start_node.label == list(node_a.labels)[0]
assert start_node.id == str(node_a.get("id"))
if start_node.properties:
for prop in start_node.properties:
assert prop.key in node_a.keys()
assert prop.value == node_a.get(prop.key)
for key, val in start_node.properties.items():
assert key in node_a.keys()
assert val == node_a.get(key)
if start_node.embedding_properties:
for embedding_prop in start_node.embedding_properties:
assert embedding_prop.key in node_a.keys()
assert node_a.get(embedding_prop.key) == [1.0, 2.0, 3.0]
for key, val in start_node.embedding_properties.items():
assert key in node_a.keys()
assert node_a.get(key) == [1.0, 2.0, 3.0]

node_b = record["b"]
assert end_node.label == list(node_b.labels)[0]
assert end_node.id == str(node_b.get("id"))
if end_node.properties:
for prop in end_node.properties:
assert prop.key in node_b.keys()
assert prop.value == node_b.get(prop.key)
for key, val in end_node.properties.items():
assert key in node_b.keys()
assert val == node_b.get(key)
if end_node.embedding_properties:
for embedding_prop in end_node.embedding_properties:
assert embedding_prop.key in node_b.keys()
assert node_b.get(embedding_prop.key) == [1.0, 2.0, 3.0]
for key, val in end_node.embedding_properties.items():
assert key in node_b.keys()
assert node_b.get(key) == [1.0, 2.0, 3.0]

rel = record["r"]
assert rel.type == relationship.type
Expand Down
Loading