Skip to content

Added created_at:datetime on the :Session and :Message node #349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ async def run_with_context(self, context_: RunContext, value: int) -> OutputMode
for i in range(value):
await asyncio.sleep(0.5) # Simulate work
await context_.notify(
message=f"Added {i+1}/{value}", data={"current": i + 1, "total": value}
message=f"Added {i + 1}/{value}",
data={"current": i + 1, "total": value},
)
return OutputModel(result=value + self.number)

Expand All @@ -38,7 +39,8 @@ async def run_with_context(self, context_: RunContext, value: int) -> OutputMode
for i in range(3): # Always do 3 steps
await asyncio.sleep(0.7) # Simulate work
await context_.notify(
message=f"Multiplication step {i+1}/3", data={"step": i + 1, "total": 3}
message=f"Multiplication step {i + 1}/3",
data={"step": i + 1, "total": 3},
)
return OutputModel(result=value * self.multiplier)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def embed_query(self, text: str) -> list[float]:
)

# Initialize the retriever
retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author)" "RETURN author.name"
retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author)RETURN author.name"
retriever = HybridCypherRetriever(
driver, INDEX_NAME, FULLTEXT_INDEX_NAME, retrieval_query, embedder
)
Expand Down
2 changes: 1 addition & 1 deletion examples/question_answering/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


def formatter(record: neo4j.Record) -> RetrieverResultItem:
return RetrieverResultItem(content=f'{record.get("title")}: {record.get("plot")}')
return RetrieverResultItem(content=f"{record.get('title')}: {record.get('plot')}")


driver = neo4j.GraphDatabase.driver(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _get_query(
return_properties.append(f"{embedding_property}: null")
query = (
f"MATCH (c:`{chunk_label}`) "
f"RETURN c {{ { ', '.join(return_properties) } }} as chunk "
f"RETURN c {{ {', '.join(return_properties)} }} as chunk "
)
if index_property:
query += f"ORDER BY c.{index_property}"
Expand Down
11 changes: 3 additions & 8 deletions src/neo4j_graphrag/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,7 @@ def _handle_field_filter(

if field.startswith(OPERATOR_PREFIX):
raise FilterValidationError(
f"Invalid filter condition. Expected a field but got an operator: "
f"{field}"
f"Invalid filter condition. Expected a field but got an operator: {field}"
)

if isinstance(value, dict):
Expand All @@ -273,8 +272,7 @@ def _handle_field_filter(
# Verify that that operator is an operator
if operator not in SUPPORTED_OPERATORS:
raise FilterValidationError(
f"Invalid operator: {operator}. "
f"Expected one of {SUPPORTED_OPERATORS}"
f"Invalid operator: {operator}. Expected one of {SUPPORTED_OPERATORS}"
)
else: # if value is not dict, then we assume an equality operator
operator = OPERATOR_EQ
Expand Down Expand Up @@ -344,10 +342,7 @@ def _construct_metadata_filter(
else:
raise FilterValidationError(f"Unsupported operator: {key}")
query = cypher_operator.join(
[
f"({ _construct_metadata_filter(el, param_store, node_alias)})"
for el in value
]
[f"({_construct_metadata_filter(el, param_store, node_alias)})" for el in value]
)
return query

Expand Down
6 changes: 4 additions & 2 deletions src/neo4j_graphrag/message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
Neo4jMessageHistoryModel,
)

CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id}})"
CREATE_SESSION_NODE_QUERY = (
"MERGE (s:`{node_label}` {{id:$session_id, createdAt: datetime()}})"
)

DELETE_SESSION_AND_MESSAGES_QUERY = (
"MATCH (s:`{node_label}`) "
Expand Down Expand Up @@ -56,7 +58,7 @@
"MATCH (s:`{node_label}`) WHERE s.id = $session_id "
"OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) "
"CREATE (s)-[:LAST_MESSAGE]->(new:Message) "
"SET new += {{role:$role, content:$content}} "
"SET new += {{role:$role, content:$content, createdAt: datetime()}} "
"WITH new, lm, last_message WHERE last_message IS NOT NULL "
"CREATE (last_message)-[:NEXT]->(new) "
"DELETE lm"
Expand Down
12 changes: 6 additions & 6 deletions src/neo4j_graphrag/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def _format_property(prop: Dict[str, Any]) -> Optional[str]:
else:
return (
"Available options: "
+ f'{[_clean_string_values(el) for el in prop["values"]]}'
+ f"{[_clean_string_values(el) for el in prop['values']]}"
)
elif prop["type"] in [
"INTEGER",
Expand All @@ -395,14 +395,14 @@ def _format_property(prop: Dict[str, Any]) -> Optional[str]:
"LOCAL_DATE_TIME",
]:
if prop.get("min") and prop.get("max"):
return f'Min: {prop["min"]}, Max: {prop["max"]}'
return f"Min: {prop['min']}, Max: {prop['max']}"
else:
return f'Example: "{prop["values"][0]}"' if prop.get("values") else ""
elif prop["type"] == "LIST":
if not prop.get("min_size") or prop["min_size"] > LIST_LIMIT:
return None
else:
return f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}'
return f"Min Size: {prop['min_size']}, Max Size: {prop['max_size']}"
return ""


Expand Down Expand Up @@ -551,7 +551,7 @@ def _build_str_clauses(
sanitize=sanitize,
)[0]["value"]
return_clauses.append(
(f"values: {distinct_values}," f" distinct_count: {len(distinct_values)}")
(f"values: {distinct_values}, distinct_count: {len(distinct_values)}")
)
else:
with_clauses.append(
Expand Down Expand Up @@ -595,7 +595,7 @@ def _build_list_clauses(prop_name: str) -> Tuple[str, str]:
)

return_clause = (
f"min_size: `{prop_name}_size_min`, " f"max_size: `{prop_name}_size_max`"
f"min_size: `{prop_name}_size_min`, max_size: `{prop_name}_size_max`"
)
return with_clause, return_clause

Expand Down Expand Up @@ -630,7 +630,7 @@ def _build_num_date_clauses(
return_clauses = []
if not prop_index and not exhaustive:
with_clauses.append(
f"collect(distinct toString(n.`{prop_name}`)) " f"AS `{prop_name}_values`"
f"collect(distinct toString(n.`{prop_name}`)) AS `{prop_name}_values`"
)
return_clauses.append(f"values: `{prop_name}_values`")
else:
Expand Down
8 changes: 2 additions & 6 deletions tests/e2e/retrievers/test_hybrid_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def test_hybrid_retriever_no_neo4j_deprecation_warning(
def test_hybrid_cypher_retriever_search_text(
driver: Driver, random_embedder: Embedder
) -> None:
retrieval_query = (
"MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name"
)
retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author) RETURN author.name"
retriever = HybridCypherRetriever(
driver,
"vector-index-name",
Expand Down Expand Up @@ -127,9 +125,7 @@ def test_hybrid_retriever_search_vector(driver: Driver) -> None:

@pytest.mark.usefixtures("setup_neo4j_for_retrieval")
def test_hybrid_cypher_retriever_search_vector(driver: Driver) -> None:
retrieval_query = (
"MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name"
)
retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author) RETURN author.name"
retriever = HybridCypherRetriever(
driver,
"vector-index-name",
Expand Down
8 changes: 2 additions & 6 deletions tests/e2e/retrievers/test_vector_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def test_vector_retriever_search_text(
def test_vector_cypher_retriever_search_text(
driver: Driver, random_embedder: Embedder
) -> None:
retrieval_query = (
"MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name"
)
retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author) RETURN author.name"
retriever = VectorCypherRetriever(
driver, "vector-index-name", retrieval_query, random_embedder
)
Expand Down Expand Up @@ -88,9 +86,7 @@ def test_vector_retriever_search_vector(driver: Driver) -> None:

@pytest.mark.usefixtures("setup_neo4j_for_retrieval")
def test_vector_cypher_retriever_search_vector(driver: Driver) -> None:
retrieval_query = (
"MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name"
)
retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author) RETURN author.name"
retriever = VectorCypherRetriever(driver, "vector-index-name", retrieval_query)

top_k = 5
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def test_format_schema(
True,
5,
False,
("MATCH (n:`Journey`)\n" "RETURN {} AS output"),
("MATCH (n:`Journey`)\nRETURN {} AS output"),
),
(
"Non-exhaustive, duration property",
Expand Down