From 96413a6d495787bf7924e0f077f2aa7cb76b36da Mon Sep 17 00:00:00 2001 From: Morgan Senechal Date: Wed, 28 May 2025 09:49:09 +0100 Subject: [PATCH 1/3] Added created_at:datetime on the :Session and :Message node --- src/neo4j_graphrag/message_history.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index 0238b3ea4..ebb07e7f6 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -25,7 +25,7 @@ Neo4jMessageHistoryModel, ) -CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id}})" +CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id, created_at: datetime()}})" DELETE_SESSION_AND_MESSAGES_QUERY = ( "MATCH (s:`{node_label}`) " @@ -56,7 +56,7 @@ "MATCH (s:`{node_label}`) WHERE s.id = $session_id " "OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) " "CREATE (s)-[:LAST_MESSAGE]->(new:Message) " - "SET new += {{role:$role, content:$content}} " + "SET new += {{role:$role, content:$content, created_at: datetime()}} " "WITH new, lm, last_message WHERE last_message IS NOT NULL " "CREATE (last_message)-[:NEXT]->(new) " "DELETE lm" From b034f5b0abded3a7cdbf0569786236bb2affbcc8 Mon Sep 17 00:00:00 2001 From: Morgan Senechal Date: Mon, 9 Jun 2025 13:33:58 +0100 Subject: [PATCH 2/3] Changed property from created_at to createdAt + ran ruff format . --- .../build_graph/pipeline/pipeline_streaming.py | 6 ++++-- .../hybrid_retrievers/hybrid_cypher_search.py | 2 +- examples/question_answering/graphrag.py | 2 +- .../experimental/components/neo4j_reader.py | 2 +- src/neo4j_graphrag/filters.py | 11 +++-------- src/neo4j_graphrag/message_history.py | 6 ++++-- src/neo4j_graphrag/schema.py | 12 ++++++------ tests/e2e/pinecone_e2e/test_pinecone_e2e.py | 6 +++--- tests/e2e/qdrant_e2e/test_qdrant_e2e.py | 6 +++--- tests/e2e/retrievers/test_hybrid_e2e.py | 8 ++------ tests/e2e/retrievers/test_vector_e2e.py | 8 ++------ tests/e2e/weaviate_e2e/test_weaviate_e2e.py | 6 +++--- .../experimental/components/test_kg_writer.py | 6 +++--- .../experimental/components/test_schema.py | 12 ++++++------ tests/unit/retrievers/test_text2cypher.py | 6 +++--- tests/unit/test_schema.py | 8 ++++---- tests/unit/utils/test_version_utils.py | 18 +++++++++--------- 17 files changed, 58 insertions(+), 67 deletions(-) diff --git a/examples/customize/build_graph/pipeline/pipeline_streaming.py b/examples/customize/build_graph/pipeline/pipeline_streaming.py index 13a262a9c..088901fe2 100644 --- a/examples/customize/build_graph/pipeline/pipeline_streaming.py +++ b/examples/customize/build_graph/pipeline/pipeline_streaming.py @@ -22,7 +22,8 @@ async def run_with_context(self, context_: RunContext, value: int) -> OutputMode for i in range(value): await asyncio.sleep(0.5) # Simulate work await context_.notify( - message=f"Added {i+1}/{value}", data={"current": i + 1, "total": value} + message=f"Added {i + 1}/{value}", + data={"current": i + 1, "total": value}, ) return OutputModel(result=value + self.number) @@ -38,7 +39,8 @@ async def run_with_context(self, context_: RunContext, value: int) -> OutputMode for i in range(3): # Always do 3 steps await asyncio.sleep(0.7) # Simulate work await context_.notify( - message=f"Multiplication step {i+1}/3", data={"step": i + 1, "total": 3} + message=f"Multiplication step {i + 1}/3", + data={"step": i + 1, "total": 3}, ) return OutputModel(result=value * self.multiplier) diff --git a/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py b/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py index fea0fd688..e1d59e379 100644 --- a/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py +++ b/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py @@ -40,7 +40,7 @@ def embed_query(self, text: str) -> list[float]: ) # Initialize the retriever -retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author)" "RETURN author.name" +retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author)RETURN author.name" retriever = HybridCypherRetriever( driver, INDEX_NAME, FULLTEXT_INDEX_NAME, retrieval_query, embedder ) diff --git a/examples/question_answering/graphrag.py b/examples/question_answering/graphrag.py index f1cb935de..edd9a3bc1 100644 --- a/examples/question_answering/graphrag.py +++ b/examples/question_answering/graphrag.py @@ -30,7 +30,7 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem: - return RetrieverResultItem(content=f'{record.get("title")}: {record.get("plot")}') + return RetrieverResultItem(content=f"{record.get('title')}: {record.get('plot')}") driver = neo4j.GraphDatabase.driver( diff --git a/src/neo4j_graphrag/experimental/components/neo4j_reader.py b/src/neo4j_graphrag/experimental/components/neo4j_reader.py index 24d25fd3d..f54539bc2 100644 --- a/src/neo4j_graphrag/experimental/components/neo4j_reader.py +++ b/src/neo4j_graphrag/experimental/components/neo4j_reader.py @@ -74,7 +74,7 @@ def _get_query( return_properties.append(f"{embedding_property}: null") query = ( f"MATCH (c:`{chunk_label}`) " - f"RETURN c {{ { ', '.join(return_properties) } }} as chunk " + f"RETURN c {{ {', '.join(return_properties)} }} as chunk " ) if index_property: query += f"ORDER BY c.{index_property}" diff --git a/src/neo4j_graphrag/filters.py b/src/neo4j_graphrag/filters.py index ed51863e8..497df2bc6 100644 --- a/src/neo4j_graphrag/filters.py +++ b/src/neo4j_graphrag/filters.py @@ -255,8 +255,7 @@ def _handle_field_filter( if field.startswith(OPERATOR_PREFIX): raise FilterValidationError( - f"Invalid filter condition. Expected a field but got an operator: " - f"{field}" + f"Invalid filter condition. Expected a field but got an operator: {field}" ) if isinstance(value, dict): @@ -273,8 +272,7 @@ def _handle_field_filter( # Verify that that operator is an operator if operator not in SUPPORTED_OPERATORS: raise FilterValidationError( - f"Invalid operator: {operator}. " - f"Expected one of {SUPPORTED_OPERATORS}" + f"Invalid operator: {operator}. Expected one of {SUPPORTED_OPERATORS}" ) else: # if value is not dict, then we assume an equality operator operator = OPERATOR_EQ @@ -344,10 +342,7 @@ def _construct_metadata_filter( else: raise FilterValidationError(f"Unsupported operator: {key}") query = cypher_operator.join( - [ - f"({ _construct_metadata_filter(el, param_store, node_alias)})" - for el in value - ] + [f"({_construct_metadata_filter(el, param_store, node_alias)})" for el in value] ) return query diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index ebb07e7f6..cd0ff87d5 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -25,7 +25,9 @@ Neo4jMessageHistoryModel, ) -CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id, created_at: datetime()}})" +CREATE_SESSION_NODE_QUERY = ( + "MERGE (s:`{node_label}` {{id:$session_id, createdAt: datetime()}})" +) DELETE_SESSION_AND_MESSAGES_QUERY = ( "MATCH (s:`{node_label}`) " @@ -56,7 +58,7 @@ "MATCH (s:`{node_label}`) WHERE s.id = $session_id " "OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) " "CREATE (s)-[:LAST_MESSAGE]->(new:Message) " - "SET new += {{role:$role, content:$content, created_at: datetime()}} " + "SET new += {{role:$role, content:$content, createdAt: datetime()}} " "WITH new, lm, last_message WHERE last_message IS NOT NULL " "CREATE (last_message)-[:NEXT]->(new) " "DELETE lm" diff --git a/src/neo4j_graphrag/schema.py b/src/neo4j_graphrag/schema.py index efb2d3a38..40292067f 100644 --- a/src/neo4j_graphrag/schema.py +++ b/src/neo4j_graphrag/schema.py @@ -385,7 +385,7 @@ def _format_property(prop: Dict[str, Any]) -> Optional[str]: else: return ( "Available options: " - + f'{[_clean_string_values(el) for el in prop["values"]]}' + + f"{[_clean_string_values(el) for el in prop['values']]}" ) elif prop["type"] in [ "INTEGER", @@ -395,14 +395,14 @@ def _format_property(prop: Dict[str, Any]) -> Optional[str]: "LOCAL_DATE_TIME", ]: if prop.get("min") and prop.get("max"): - return f'Min: {prop["min"]}, Max: {prop["max"]}' + return f"Min: {prop['min']}, Max: {prop['max']}" else: return f'Example: "{prop["values"][0]}"' if prop.get("values") else "" elif prop["type"] == "LIST": if not prop.get("min_size") or prop["min_size"] > LIST_LIMIT: return None else: - return f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}' + return f"Min Size: {prop['min_size']}, Max Size: {prop['max_size']}" return "" @@ -551,7 +551,7 @@ def _build_str_clauses( sanitize=sanitize, )[0]["value"] return_clauses.append( - (f"values: {distinct_values}," f" distinct_count: {len(distinct_values)}") + (f"values: {distinct_values}, distinct_count: {len(distinct_values)}") ) else: with_clauses.append( @@ -595,7 +595,7 @@ def _build_list_clauses(prop_name: str) -> Tuple[str, str]: ) return_clause = ( - f"min_size: `{prop_name}_size_min`, " f"max_size: `{prop_name}_size_max`" + f"min_size: `{prop_name}_size_min`, max_size: `{prop_name}_size_max`" ) return with_clause, return_clause @@ -630,7 +630,7 @@ def _build_num_date_clauses( return_clauses = [] if not prop_index and not exhaustive: with_clauses.append( - f"collect(distinct toString(n.`{prop_name}`)) " f"AS `{prop_name}_values`" + f"collect(distinct toString(n.`{prop_name}`)) AS `{prop_name}_values`" ) return_clauses.append(f"values: `{prop_name}_values`") else: diff --git a/tests/e2e/pinecone_e2e/test_pinecone_e2e.py b/tests/e2e/pinecone_e2e/test_pinecone_e2e.py index be1c0d2ca..9e54b4b31 100644 --- a/tests/e2e/pinecone_e2e/test_pinecone_e2e.py +++ b/tests/e2e/pinecone_e2e/test_pinecone_e2e.py @@ -30,9 +30,9 @@ @pytest.fixture(scope="module") -def sentence_transformer_embedder() -> ( - Generator[SentenceTransformerEmbeddings, Any, Any] -): +def sentence_transformer_embedder() -> Generator[ + SentenceTransformerEmbeddings, Any, Any +]: embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2") yield embedder diff --git a/tests/e2e/qdrant_e2e/test_qdrant_e2e.py b/tests/e2e/qdrant_e2e/test_qdrant_e2e.py index 5bb43977f..f81b97503 100644 --- a/tests/e2e/qdrant_e2e/test_qdrant_e2e.py +++ b/tests/e2e/qdrant_e2e/test_qdrant_e2e.py @@ -31,9 +31,9 @@ @pytest.fixture(scope="module") -def sentence_transformer_embedder() -> ( - Generator[SentenceTransformerEmbeddings, Any, Any] -): +def sentence_transformer_embedder() -> Generator[ + SentenceTransformerEmbeddings, Any, Any +]: embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2") yield embedder diff --git a/tests/e2e/retrievers/test_hybrid_e2e.py b/tests/e2e/retrievers/test_hybrid_e2e.py index d3da020b3..9936a0d23 100644 --- a/tests/e2e/retrievers/test_hybrid_e2e.py +++ b/tests/e2e/retrievers/test_hybrid_e2e.py @@ -76,9 +76,7 @@ def test_hybrid_retriever_no_neo4j_deprecation_warning( def test_hybrid_cypher_retriever_search_text( driver: Driver, random_embedder: Embedder ) -> None: - retrieval_query = ( - "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" - ) + retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author) RETURN author.name" retriever = HybridCypherRetriever( driver, "vector-index-name", @@ -127,9 +125,7 @@ def test_hybrid_retriever_search_vector(driver: Driver) -> None: @pytest.mark.usefixtures("setup_neo4j_for_retrieval") def test_hybrid_cypher_retriever_search_vector(driver: Driver) -> None: - retrieval_query = ( - "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" - ) + retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author) RETURN author.name" retriever = HybridCypherRetriever( driver, "vector-index-name", diff --git a/tests/e2e/retrievers/test_vector_e2e.py b/tests/e2e/retrievers/test_vector_e2e.py index 1ea54f80e..7c20b487c 100644 --- a/tests/e2e/retrievers/test_vector_e2e.py +++ b/tests/e2e/retrievers/test_vector_e2e.py @@ -45,9 +45,7 @@ def test_vector_retriever_search_text( def test_vector_cypher_retriever_search_text( driver: Driver, random_embedder: Embedder ) -> None: - retrieval_query = ( - "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" - ) + retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author) RETURN author.name" retriever = VectorCypherRetriever( driver, "vector-index-name", retrieval_query, random_embedder ) @@ -88,9 +86,7 @@ def test_vector_retriever_search_vector(driver: Driver) -> None: @pytest.mark.usefixtures("setup_neo4j_for_retrieval") def test_vector_cypher_retriever_search_vector(driver: Driver) -> None: - retrieval_query = ( - "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" - ) + retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author) RETURN author.name" retriever = VectorCypherRetriever(driver, "vector-index-name", retrieval_query) top_k = 5 diff --git a/tests/e2e/weaviate_e2e/test_weaviate_e2e.py b/tests/e2e/weaviate_e2e/test_weaviate_e2e.py index b55ea3930..79edbdaeb 100644 --- a/tests/e2e/weaviate_e2e/test_weaviate_e2e.py +++ b/tests/e2e/weaviate_e2e/test_weaviate_e2e.py @@ -32,9 +32,9 @@ @pytest.fixture(scope="module") -def sentence_transformer_embedder() -> ( - Generator[SentenceTransformerEmbeddings, Any, Any] -): +def sentence_transformer_embedder() -> Generator[ + SentenceTransformerEmbeddings, Any, Any +]: embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2") yield embedder diff --git a/tests/unit/experimental/components/test_kg_writer.py b/tests/unit/experimental/components/test_kg_writer.py index f149a264f..93bd4598f 100644 --- a/tests/unit/experimental/components/test_kg_writer.py +++ b/tests/unit/experimental/components/test_kg_writer.py @@ -370,6 +370,6 @@ def test_get_version( ) driver.execute_query = execute_query_mock neo4j_writer = Neo4jWriter(driver=driver) - assert ( - neo4j_writer.is_version_5_23_or_above is is_5_23_or_above - ), f"Failed is_version_5_23_or_above test case: {description}" + assert neo4j_writer.is_version_5_23_or_above is is_5_23_or_above, ( + f"Failed is_version_5_23_or_above test case: {description}" + ) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index e8fc670c2..f1827c130 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -158,9 +158,9 @@ def test_create_schema_model_invalid_entity( list(valid_relationship_types), list(patterns_with_invalid_entity), ) - assert "Entity 'NON_EXISTENT_ENTITY' is not defined" in str( - exc_info.value - ), "Should fail due to non-existent entity" + assert "Entity 'NON_EXISTENT_ENTITY' is not defined" in str(exc_info.value), ( + "Should fail due to non-existent entity" + ) def test_create_schema_model_invalid_relation( @@ -175,9 +175,9 @@ def test_create_schema_model_invalid_relation( list(valid_relationship_types), list(patterns_with_invalid_relation), ) - assert "Relation 'NON_EXISTENT_RELATION' is not defined" in str( - exc_info.value - ), "Should fail due to non-existent relation" + assert "Relation 'NON_EXISTENT_RELATION' is not defined" in str(exc_info.value), ( + "Should fail due to non-existent relation" + ) def test_create_schema_model_no_potential_schema( diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index 3ee15a5d4..bad480eb3 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -498,6 +498,6 @@ def test_t2c_retriever_with_custom_prompt_and_schema( def test_extract_cypher( description: str, cypher_query: str, expected_output: str ) -> None: - assert ( - extract_cypher(cypher_query) == expected_output - ), f"Failed test case: {description}" + assert extract_cypher(cypher_query) == expected_output, ( + f"Failed test case: {description}" + ) diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index e36b8a162..1865f0776 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -205,9 +205,9 @@ def test__value_sanitize( description: str, input_value: Dict[str, Any], expected_output: Any ) -> None: """Test the _value_sanitize function.""" - assert ( - _value_sanitize(input_value) == expected_output - ), f"Failed test case: {description}" + assert _value_sanitize(input_value) == expected_output, ( + f"Failed test case: {description}" + ) @pytest.mark.parametrize( @@ -833,7 +833,7 @@ def test_format_schema( True, 5, False, - ("MATCH (n:`Journey`)\n" "RETURN {} AS output"), + ("MATCH (n:`Journey`)\nRETURN {} AS output"), ), ( "Non-exhaustive, duration property", diff --git a/tests/unit/utils/test_version_utils.py b/tests/unit/utils/test_version_utils.py index a5ead6090..f17027013 100644 --- a/tests/unit/utils/test_version_utils.py +++ b/tests/unit/utils/test_version_utils.py @@ -67,9 +67,9 @@ def test_is_version_5_23_or_above( Ensures that the is_version_5_23_or_above function accurately determines if a given version is 5.23 or higher. """ - assert ( - is_version_5_23_or_above(version_tuple) == expected_result - ), f"Failed test case: {version_tuple}" + assert is_version_5_23_or_above(version_tuple) == expected_result, ( + f"Failed test case: {version_tuple}" + ) @pytest.mark.parametrize( @@ -87,9 +87,9 @@ def test_has_vector_index_support( Tests the has_vector_index_support function to confirm it correctly identifies if the given version and platform support vector indexing. """ - assert ( - has_vector_index_support(version_tuple) == expected_result - ), f"Failed test case: {version_tuple}" + assert has_vector_index_support(version_tuple) == expected_result, ( + f"Failed test case: {version_tuple}" + ) @pytest.mark.parametrize( @@ -110,6 +110,6 @@ def test_has_metadata_filtering_support( Tests the has_metadata_filtering_support function to confirm it correctly identifies if the given version and platform support vector index metadata filtering. """ - assert ( - has_metadata_filtering_support(version_tuple, is_aura) == expected_result - ), f"Failed test case: {version_tuple}, is_aura: {is_aura}" + assert has_metadata_filtering_support(version_tuple, is_aura) == expected_result, ( + f"Failed test case: {version_tuple}, is_aura: {is_aura}" + ) From 205e81852228171ba6c6f9d4fd85fb9c401bcd51 Mon Sep 17 00:00:00 2001 From: Morgan Senechal Date: Mon, 9 Jun 2025 13:39:44 +0100 Subject: [PATCH 3/3] Ruff through peotry conf --- tests/e2e/pinecone_e2e/test_pinecone_e2e.py | 6 +++--- tests/e2e/qdrant_e2e/test_qdrant_e2e.py | 6 +++--- tests/e2e/weaviate_e2e/test_weaviate_e2e.py | 6 +++--- .../experimental/components/test_kg_writer.py | 6 +++--- .../experimental/components/test_schema.py | 12 ++++++------ tests/unit/retrievers/test_text2cypher.py | 6 +++--- tests/unit/test_schema.py | 6 +++--- tests/unit/utils/test_version_utils.py | 18 +++++++++--------- 8 files changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/e2e/pinecone_e2e/test_pinecone_e2e.py b/tests/e2e/pinecone_e2e/test_pinecone_e2e.py index 9e54b4b31..be1c0d2ca 100644 --- a/tests/e2e/pinecone_e2e/test_pinecone_e2e.py +++ b/tests/e2e/pinecone_e2e/test_pinecone_e2e.py @@ -30,9 +30,9 @@ @pytest.fixture(scope="module") -def sentence_transformer_embedder() -> Generator[ - SentenceTransformerEmbeddings, Any, Any -]: +def sentence_transformer_embedder() -> ( + Generator[SentenceTransformerEmbeddings, Any, Any] +): embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2") yield embedder diff --git a/tests/e2e/qdrant_e2e/test_qdrant_e2e.py b/tests/e2e/qdrant_e2e/test_qdrant_e2e.py index f81b97503..5bb43977f 100644 --- a/tests/e2e/qdrant_e2e/test_qdrant_e2e.py +++ b/tests/e2e/qdrant_e2e/test_qdrant_e2e.py @@ -31,9 +31,9 @@ @pytest.fixture(scope="module") -def sentence_transformer_embedder() -> Generator[ - SentenceTransformerEmbeddings, Any, Any -]: +def sentence_transformer_embedder() -> ( + Generator[SentenceTransformerEmbeddings, Any, Any] +): embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2") yield embedder diff --git a/tests/e2e/weaviate_e2e/test_weaviate_e2e.py b/tests/e2e/weaviate_e2e/test_weaviate_e2e.py index 79edbdaeb..b55ea3930 100644 --- a/tests/e2e/weaviate_e2e/test_weaviate_e2e.py +++ b/tests/e2e/weaviate_e2e/test_weaviate_e2e.py @@ -32,9 +32,9 @@ @pytest.fixture(scope="module") -def sentence_transformer_embedder() -> Generator[ - SentenceTransformerEmbeddings, Any, Any -]: +def sentence_transformer_embedder() -> ( + Generator[SentenceTransformerEmbeddings, Any, Any] +): embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2") yield embedder diff --git a/tests/unit/experimental/components/test_kg_writer.py b/tests/unit/experimental/components/test_kg_writer.py index 93bd4598f..f149a264f 100644 --- a/tests/unit/experimental/components/test_kg_writer.py +++ b/tests/unit/experimental/components/test_kg_writer.py @@ -370,6 +370,6 @@ def test_get_version( ) driver.execute_query = execute_query_mock neo4j_writer = Neo4jWriter(driver=driver) - assert neo4j_writer.is_version_5_23_or_above is is_5_23_or_above, ( - f"Failed is_version_5_23_or_above test case: {description}" - ) + assert ( + neo4j_writer.is_version_5_23_or_above is is_5_23_or_above + ), f"Failed is_version_5_23_or_above test case: {description}" diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index f1827c130..e8fc670c2 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -158,9 +158,9 @@ def test_create_schema_model_invalid_entity( list(valid_relationship_types), list(patterns_with_invalid_entity), ) - assert "Entity 'NON_EXISTENT_ENTITY' is not defined" in str(exc_info.value), ( - "Should fail due to non-existent entity" - ) + assert "Entity 'NON_EXISTENT_ENTITY' is not defined" in str( + exc_info.value + ), "Should fail due to non-existent entity" def test_create_schema_model_invalid_relation( @@ -175,9 +175,9 @@ def test_create_schema_model_invalid_relation( list(valid_relationship_types), list(patterns_with_invalid_relation), ) - assert "Relation 'NON_EXISTENT_RELATION' is not defined" in str(exc_info.value), ( - "Should fail due to non-existent relation" - ) + assert "Relation 'NON_EXISTENT_RELATION' is not defined" in str( + exc_info.value + ), "Should fail due to non-existent relation" def test_create_schema_model_no_potential_schema( diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index bad480eb3..3ee15a5d4 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -498,6 +498,6 @@ def test_t2c_retriever_with_custom_prompt_and_schema( def test_extract_cypher( description: str, cypher_query: str, expected_output: str ) -> None: - assert extract_cypher(cypher_query) == expected_output, ( - f"Failed test case: {description}" - ) + assert ( + extract_cypher(cypher_query) == expected_output + ), f"Failed test case: {description}" diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index 1865f0776..b1ff1a19f 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -205,9 +205,9 @@ def test__value_sanitize( description: str, input_value: Dict[str, Any], expected_output: Any ) -> None: """Test the _value_sanitize function.""" - assert _value_sanitize(input_value) == expected_output, ( - f"Failed test case: {description}" - ) + assert ( + _value_sanitize(input_value) == expected_output + ), f"Failed test case: {description}" @pytest.mark.parametrize( diff --git a/tests/unit/utils/test_version_utils.py b/tests/unit/utils/test_version_utils.py index f17027013..a5ead6090 100644 --- a/tests/unit/utils/test_version_utils.py +++ b/tests/unit/utils/test_version_utils.py @@ -67,9 +67,9 @@ def test_is_version_5_23_or_above( Ensures that the is_version_5_23_or_above function accurately determines if a given version is 5.23 or higher. """ - assert is_version_5_23_or_above(version_tuple) == expected_result, ( - f"Failed test case: {version_tuple}" - ) + assert ( + is_version_5_23_or_above(version_tuple) == expected_result + ), f"Failed test case: {version_tuple}" @pytest.mark.parametrize( @@ -87,9 +87,9 @@ def test_has_vector_index_support( Tests the has_vector_index_support function to confirm it correctly identifies if the given version and platform support vector indexing. """ - assert has_vector_index_support(version_tuple) == expected_result, ( - f"Failed test case: {version_tuple}" - ) + assert ( + has_vector_index_support(version_tuple) == expected_result + ), f"Failed test case: {version_tuple}" @pytest.mark.parametrize( @@ -110,6 +110,6 @@ def test_has_metadata_filtering_support( Tests the has_metadata_filtering_support function to confirm it correctly identifies if the given version and platform support vector index metadata filtering. """ - assert has_metadata_filtering_support(version_tuple, is_aura) == expected_result, ( - f"Failed test case: {version_tuple}, is_aura: {is_aura}" - ) + assert ( + has_metadata_filtering_support(version_tuple, is_aura) == expected_result + ), f"Failed test case: {version_tuple}, is_aura: {is_aura}"