diff --git a/src/neo4j_graphrag/indexes.py b/src/neo4j_graphrag/indexes.py index a8e8d6234..4abafa7bc 100644 --- a/src/neo4j_graphrag/indexes.py +++ b/src/neo4j_graphrag/indexes.py @@ -608,7 +608,11 @@ def _sort_by_index_name( def retrieve_vector_index_info( - driver: neo4j.Driver, index_name: str, label_or_type: str, embedding_property: str + driver: neo4j.Driver, + index_name: str, + label_or_type: str, + embedding_property: str, + neo4j_database: Optional[str] = None, ) -> Optional[neo4j.Record]: """ Check if a vector index exists in a Neo4j database and return its @@ -621,6 +625,9 @@ def retrieve_vector_index_info( of the index. embedding_property (str): The name of the property containing the embeddings. + neo4j_database (Optional[str]): The name of the Neo4j database. + If not provided, this defaults to the server's default database ("neo4j" by default) + (`see reference to documentation `_). Returns: Optional[Dict[str, Any]]: @@ -640,6 +647,7 @@ def retrieve_vector_index_info( "label_or_type": label_or_type, "embedding_property": embedding_property, }, + database_=neo4j_database, ) index_information = _sort_by_index_name(result.records, index_name) if len(index_information) > 0: @@ -653,6 +661,7 @@ def retrieve_fulltext_index_info( index_name: str, label_or_type: str, text_properties: List[str] = [], + neo4j_database: Optional[str] = None, ) -> Optional[neo4j.Record]: """ Check if a full text index exists in a Neo4j database and return its @@ -664,6 +673,9 @@ def retrieve_fulltext_index_info( label_or_type (str): The label (for nodes) or type (for relationships) of the index. text_properties (List[str]): The names of the text properties indexed. + neo4j_database (Optional[str]): The name of the Neo4j database. + If not provided, this defaults to the server's default database ("neo4j" by default) + (`see reference to documentation `_). Returns: Optional[Dict[str, Any]]: @@ -683,6 +695,7 @@ def retrieve_fulltext_index_info( "label_or_type": label_or_type, "text_properties": text_properties, }, + database_=neo4j_database, ) index_information = _sort_by_index_name(result.records, index_name) if len(index_information) > 0: