Skip to content

Commit 4ce3b56

Browse files
authored
Adds database, timeout, and sanitize options to schema functions (#278)
1 parent 5101575 commit 4ce3b56

File tree

2 files changed

+148
-19
lines changed

2 files changed

+148
-19
lines changed

src/neo4j_graphrag/schema.py

Lines changed: 147 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,13 @@ def query_database(
180180
return json_data
181181

182182

183-
def get_schema(driver: neo4j.Driver, is_enhanced: bool = False) -> str:
183+
def get_schema(
184+
driver: neo4j.Driver,
185+
is_enhanced: bool = False,
186+
database: Optional[str] = None,
187+
timeout: Optional[float] = None,
188+
sanitize: bool = False,
189+
) -> str:
184190
"""
185191
Returns the schema of the graph as a string with following format:
186192
@@ -197,16 +203,34 @@ def get_schema(driver: neo4j.Driver, is_enhanced: bool = False) -> str:
197203
driver (neo4j.Driver): Neo4j Python driver instance.
198204
is_enhanced (bool): Flag indicating whether to format the schema with
199205
detailed statistics (True) or in a simpler overview format (False).
206+
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
207+
timeout (Optional[float]): The timeout for transactions in seconds.
208+
Useful for terminating long-running queries.
209+
By default, there is no timeout set.
210+
sanitize (bool): A flag to indicate whether to remove lists with
211+
more than 128 elements from results. Useful for removing
212+
embedding-like properties from database responses. Default is False.
213+
200214
201215
Returns:
202216
str: the graph schema information in a serialized format.
203217
"""
204-
structured_schema = get_structured_schema(driver, is_enhanced)
218+
structured_schema = get_structured_schema(
219+
driver=driver,
220+
is_enhanced=is_enhanced,
221+
database=database,
222+
timeout=timeout,
223+
sanitize=sanitize,
224+
)
205225
return format_schema(structured_schema, is_enhanced)
206226

207227

208228
def get_structured_schema(
209-
driver: neo4j.Driver, is_enhanced: bool = False
229+
driver: neo4j.Driver,
230+
is_enhanced: bool = False,
231+
database: Optional[str] = None,
232+
timeout: Optional[float] = None,
233+
sanitize: bool = False,
210234
) -> dict[str, Any]:
211235
"""
212236
Returns the structured schema of the graph.
@@ -249,45 +273,75 @@ def get_structured_schema(
249273
driver (neo4j.Driver): Neo4j Python driver instance.
250274
is_enhanced (bool): Flag indicating whether to format the schema with
251275
detailed statistics (True) or in a simpler overview format (False).
276+
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
277+
timeout (Optional[float]): The timeout for transactions in seconds.
278+
Useful for terminating long-running queries.
279+
By default, there is no timeout set.
280+
sanitize (bool): A flag to indicate whether to remove lists with
281+
more than 128 elements from results. Useful for removing
282+
embedding-like properties from database responses. Default is False.
252283
253284
Returns:
254285
dict[str, Any]: the graph schema information in a structured format.
255286
"""
256287
node_properties = [
257288
data["output"]
258289
for data in query_database(
259-
driver,
260-
NODE_PROPERTIES_QUERY,
290+
driver=driver,
291+
query=NODE_PROPERTIES_QUERY,
261292
params={
262293
"EXCLUDED_LABELS": EXCLUDED_LABELS
263294
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
264295
},
296+
database=database,
297+
timeout=timeout,
298+
sanitize=sanitize,
265299
)
266300
]
267301

268302
rel_properties = [
269303
data["output"]
270304
for data in query_database(
271-
driver, REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_RELS}
305+
driver=driver,
306+
query=REL_PROPERTIES_QUERY,
307+
params={"EXCLUDED_LABELS": EXCLUDED_RELS},
308+
database=database,
309+
timeout=timeout,
310+
sanitize=sanitize,
272311
)
273312
]
274313

275314
relationships = [
276315
data["output"]
277316
for data in query_database(
278-
driver,
279-
REL_QUERY,
317+
driver=driver,
318+
query=REL_QUERY,
280319
params={
281320
"EXCLUDED_LABELS": EXCLUDED_LABELS
282321
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
283322
},
323+
database=database,
324+
timeout=timeout,
325+
sanitize=sanitize,
284326
)
285327
]
286328

287329
# Get constraints and indexes
288330
try:
289-
constraint = query_database(driver, "SHOW CONSTRAINTS")
290-
index = query_database(driver, INDEX_QUERY)
331+
constraint = query_database(
332+
driver=driver,
333+
query="SHOW CONSTRAINTS",
334+
database=database,
335+
timeout=timeout,
336+
sanitize=sanitize,
337+
)
338+
index = query_database(
339+
driver=driver,
340+
query=INDEX_QUERY,
341+
database=database,
342+
timeout=timeout,
343+
sanitize=sanitize,
344+
)
291345
except ClientError:
292346
constraint = []
293347
index = []
@@ -299,7 +353,13 @@ def get_structured_schema(
299353
"metadata": {"constraint": constraint, "index": index},
300354
}
301355
if is_enhanced:
302-
enhance_schema(driver=driver, structured_schema=structured_schema)
356+
enhance_schema(
357+
driver=driver,
358+
structured_schema=structured_schema,
359+
database=database,
360+
timeout=timeout,
361+
sanitize=sanitize,
362+
)
303363
return structured_schema
304364

305365

@@ -436,6 +496,9 @@ def _build_str_clauses(
436496
label_or_type: str,
437497
exhaustive: bool,
438498
prop_index: Optional[List[Any]] = None,
499+
database: Optional[str] = None,
500+
timeout: Optional[float] = None,
501+
sanitize: bool = False,
439502
) -> Tuple[List[str], List[str]]:
440503
"""
441504
Build Cypher clauses for string property statistics.
@@ -455,6 +518,13 @@ def _build_str_clauses(
455518
prop_index (Optional[List[Any]]): Optional metadata about the property's
456519
index. If provided, certain optimizations are applied based on
457520
distinct value limits and index availability.
521+
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
522+
timeout (Optional[float]): The timeout for transactions in seconds.
523+
Useful for terminating long-running queries.
524+
By default, there is no timeout set.
525+
sanitize (bool): A flag to indicate whether to remove lists with
526+
more than 128 elements from results. Useful for removing
527+
embedding-like properties from database responses. Default is False.
458528
459529
Returns:
460530
Tuple[List[str], List[str]]:
@@ -471,9 +541,14 @@ def _build_str_clauses(
471541
and prop_index[0].get("distinctValues") <= DISTINCT_VALUE_LIMIT
472542
):
473543
distinct_values = query_database(
474-
driver,
475-
f"CALL apoc.schema.properties.distinct("
476-
f"'{label_or_type}', '{prop_name}') YIELD value",
544+
driver=driver,
545+
query=(
546+
f"CALL apoc.schema.properties.distinct("
547+
f"'{label_or_type}', '{prop_name}') YIELD value"
548+
),
549+
database=database,
550+
timeout=timeout,
551+
sanitize=sanitize,
477552
)[0]["value"]
478553
return_clauses.append(
479554
(f"values: {distinct_values}," f" distinct_count: {len(distinct_values)}")
@@ -582,6 +657,9 @@ def get_enhanced_schema_cypher(
582657
exhaustive: bool,
583658
sample_size: int = 5,
584659
is_relationship: bool = False,
660+
database: Optional[str] = None,
661+
timeout: Optional[float] = None,
662+
sanitize: bool = False,
585663
) -> str:
586664
"""
587665
Build a Cypher query for enhanced schema information.
@@ -605,6 +683,13 @@ def get_enhanced_schema_cypher(
605683
exhaustive is False. Defaults to 5.
606684
is_relationship (bool, optional): Indicates if the query is for
607685
a relationship type (True) or a node label (False). Defaults to False.
686+
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
687+
timeout (Optional[float]): The timeout for transactions in seconds.
688+
Useful for terminating long-running queries.
689+
By default, there is no timeout set.
690+
sanitize (bool): A flag to indicate whether to remove lists with
691+
more than 128 elements from results. Useful for removing
692+
embedding-like properties from database responses. Default is False.
608693
609694
Returns:
610695
str: A Cypher query string that gathers enhanced property metadata.
@@ -643,6 +728,9 @@ def get_enhanced_schema_cypher(
643728
label_or_type=label_or_type,
644729
exhaustive=exhaustive,
645730
prop_index=prop_index,
731+
database=database,
732+
timeout=timeout,
733+
sanitize=sanitize,
646734
)
647735
with_clauses += str_w_clauses
648736
return_clauses += str_r_clauses
@@ -682,6 +770,9 @@ def enhance_properties(
682770
structured_schema: Dict[str, Any],
683771
prop_dict: Dict[str, Any],
684772
is_relationship: bool,
773+
database: Optional[str] = None,
774+
timeout: Optional[float] = None,
775+
sanitize: bool = False,
685776
) -> None:
686777
"""
687778
Enhance the structured schema with detailed statistics for a single node label or relationship type.
@@ -699,6 +790,13 @@ def enhance_properties(
699790
relationship type to be enhanced.
700791
is_relationship (bool): Indicates whether the properties to be enhanced belong to a relationship
701792
(True) or a node (False).
793+
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
794+
timeout (Optional[float]): The timeout for transactions in seconds.
795+
Useful for terminating long-running queries.
796+
By default, there is no timeout set.
797+
sanitize (bool): A flag to indicate whether to remove lists with
798+
more than 128 elements from results. Useful for removing
799+
embedding-like properties from database responses. Default is False.
702800
703801
Returns:
704802
None
@@ -720,6 +818,9 @@ def enhance_properties(
720818
properties=props,
721819
exhaustive=count < EXHAUSTIVE_SEARCH_LIMIT,
722820
is_relationship=is_relationship,
821+
database=database,
822+
timeout=timeout,
823+
sanitize=sanitize,
723824
)
724825
# Due to schema-flexible nature of neo4j errors can happen
725826
try:
@@ -733,9 +834,12 @@ def enhance_properties(
733834
else {}
734835
)
735836
enhanced_info = query_database(
736-
driver,
737-
enhanced_cypher,
837+
driver=driver,
838+
query=enhanced_cypher,
738839
session_params=session_params,
840+
database=database,
841+
timeout=timeout,
842+
sanitize=sanitize,
739843
)[0]["output"]
740844
for prop in props:
741845
if prop["property"] in enhanced_info:
@@ -744,7 +848,13 @@ def enhance_properties(
744848
return
745849

746850

747-
def enhance_schema(driver: neo4j.Driver, structured_schema: Dict[str, Any]) -> None:
851+
def enhance_schema(
852+
driver: neo4j.Driver,
853+
structured_schema: Dict[str, Any],
854+
database: Optional[str] = None,
855+
timeout: Optional[float] = None,
856+
sanitize: bool = False,
857+
) -> None:
748858
"""
749859
Enhance the structured schema with detailed property statistics.
750860
@@ -759,18 +869,34 @@ def enhance_schema(driver: neo4j.Driver, structured_schema: Dict[str, Any]) -> N
759869
structured_schema (Dict[str, Any]): The initial structured schema
760870
containing node and relationship properties, which will be updated
761871
with enhanced statistics.
872+
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
873+
timeout (Optional[float]): The timeout for transactions in seconds.
874+
Useful for terminating long-running queries.
875+
By default, there is no timeout set.
876+
sanitize (bool): A flag to indicate whether to remove lists with
877+
more than 128 elements from results. Useful for removing
878+
embedding-like properties from database responses. Default is False.
762879
763880
Returns:
764881
None
765882
"""
766-
schema_counts = query_database(driver, SCHEMA_COUNTS_QUERY)
883+
schema_counts = query_database(
884+
driver=driver,
885+
query=SCHEMA_COUNTS_QUERY,
886+
database=database,
887+
timeout=timeout,
888+
sanitize=sanitize,
889+
)
767890
# Update node info
768891
for node in schema_counts[0]["nodes"]:
769892
enhance_properties(
770893
driver=driver,
771894
structured_schema=structured_schema,
772895
prop_dict=node,
773896
is_relationship=False,
897+
database=database,
898+
timeout=timeout,
899+
sanitize=sanitize,
774900
)
775901
# Update rel info
776902
for rel in schema_counts[0]["relationships"]:
@@ -779,4 +905,7 @@ def enhance_schema(driver: neo4j.Driver, structured_schema: Dict[str, Any]) -> N
779905
structured_schema=structured_schema,
780906
prop_dict=rel,
781907
is_relationship=True,
908+
database=database,
909+
timeout=timeout,
910+
sanitize=sanitize,
782911
)

tests/unit/test_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939

4040
def _query_return_value(*args: Any, **kwargs: Any) -> list[Any]:
41-
query = args[1]
41+
query = kwargs.get("query", args[1] if len(args) > 1 else None)
4242
if NODE_PROPERTIES_QUERY in query:
4343
return [
4444
{

0 commit comments

Comments
 (0)