Skip to content

Commit 4d49525

Browse files
authored
Use the neo4j_database parameter everywhere (#216)
* Use self.neo4j_database for all queries in Neo4jWriter * Make sure all execute_query can be run against a custom database * Update CHANGELOG * Update docstring + update examples not to use undocumented feature for neo4j driver * Expose neo4j_database in SimpleKGBuilder * Update CHANGELOG * Simplify changelog
1 parent 8285d56 commit 4d49525

29 files changed

+158
-59
lines changed

CHANGELOG.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,15 @@
33
## Next
44

55
### Added
6-
- Introduced optional lexical graph configuration for SimpleKGPipeline, enhancing flexibility in customizing node labels and relationship types in the lexical graph.
7-
- Ability to provide description and list of properties for entities and relations in the SimpleKGPipeline constructor.
6+
- Introduced optional lexical graph configuration for `SimpleKGPipeline`, enhancing flexibility in customizing node labels and relationship types in the lexical graph.
7+
- Introduced optional `neo4j_database` parameter for `SimpleKGPipeline`, `Neo4jChunkReader`and `Text2CypherRetriever`.
8+
- Ability to provide description and list of properties for entities and relations in the `SimpleKGPipeline` constructor.
9+
10+
### Fixed
11+
- `neo4j_database` parameter is now used for all queries in the `Neo4jWriter`.
12+
13+
### Changed
14+
- Updated all examples to use `neo4j_database` parameter instead of an undocumented neo4j driver constructor.
815

916
## 1.2.0
1017

examples/build_graph/simple_kg_builder_from_pdf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ async def define_and_run_pipeline(
5050
entities=ENTITIES,
5151
relations=RELATIONS,
5252
potential_schema=POTENTIAL_SCHEMA,
53+
neo4j_database=DATABASE,
5354
)
5455
return await kg_builder.run_async(file_path=str(file_path))
5556

@@ -62,7 +63,7 @@ async def main() -> PipelineResult:
6263
"response_format": {"type": "json_object"},
6364
},
6465
)
65-
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
66+
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
6667
res = await define_and_run_pipeline(driver, llm)
6768
await llm.async_client.close()
6869
return res

examples/build_graph/simple_kg_builder_from_text.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# Neo4j db infos
2222
URI = "neo4j://localhost:7687"
2323
AUTH = ("neo4j", "password")
24-
DATABASE = "neo4j"
24+
DATABASE = "newdb"
2525

2626
# Text to process
2727
TEXT = """The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House Atreides,
@@ -67,6 +67,7 @@ async def define_and_run_pipeline(
6767
relations=RELATIONS,
6868
potential_schema=POTENTIAL_SCHEMA,
6969
from_pdf=False,
70+
neo4j_database=DATABASE,
7071
)
7172
return await kg_builder.run_async(text=TEXT)
7273

@@ -79,7 +80,7 @@ async def main() -> PipelineResult:
7980
"response_format": {"type": "json_object"},
8081
},
8182
)
82-
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
83+
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
8384
res = await define_and_run_pipeline(driver, llm)
8485
await llm.async_client.close()
8586
return res

examples/customize/answer/custom_prompt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
driver = neo4j.GraphDatabase.driver(
2424
URI,
2525
auth=AUTH,
26-
database=DATABASE,
2726
)
2827

2928
embedder = OpenAIEmbeddings()
@@ -33,6 +32,7 @@
3332
index_name=INDEX,
3433
retrieval_query="WITH node, score RETURN node.title as title, node.plot as plot",
3534
embedder=embedder,
35+
neo4j_database=DATABASE,
3636
)
3737

3838
llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})

examples/customize/answer/langchain_compatiblity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
driver = neo4j.GraphDatabase.driver(
2222
URI,
2323
auth=AUTH,
24-
database=DATABASE,
2524
)
2625

2726
embedder = OpenAIEmbeddings(model="text-embedding-ada-002")
@@ -31,6 +30,7 @@
3130
index_name=INDEX,
3231
retrieval_query="WITH node, score RETURN node.title as title, node.plot as plot",
3332
embedder=embedder, # type: ignore[arg-type, unused-ignore]
33+
neo4j_database=DATABASE,
3434
)
3535

3636
llm = ChatOpenAI(model="gpt-4o", temperature=0)

examples/customize/retrievers/result_formatter_vector_cypher_retriever.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def my_result_formatter(record: neo4j.Record) -> RetrieverResultItem:
3838
)
3939

4040

41-
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
41+
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
4242
# Initialize the retriever
4343
retriever = VectorCypherRetriever(
4444
driver=driver,
@@ -48,7 +48,7 @@ def my_result_formatter(record: neo4j.Record) -> RetrieverResultItem:
4848
retrieval_query=RETRIEVAL_QUERY,
4949
result_formatter=my_result_formatter,
5050
# optionally, set neo4j database
51-
# neo4j_database="neo4j",
51+
neo4j_database=DATABASE,
5252
)
5353

5454
# Perform the similarity search for a text query

examples/customize/retrievers/result_formatter_vector_retriever.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636

3737
# Connect to Neo4j database
38-
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
38+
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)
3939

4040

4141
query_text = "Find a movie about astronauts"
@@ -52,6 +52,7 @@
5252
index_name=INDEX_NAME,
5353
embedder=OpenAIEmbeddings(),
5454
return_properties=["title", "plot"],
55+
neo4j_database=DATABASE,
5556
)
5657
print(retriever.search(query_text=query_text, top_k=top_k_results))
5758
print()

examples/question_answering/graphrag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem:
3535
driver = neo4j.GraphDatabase.driver(
3636
URI,
3737
auth=AUTH,
38-
database=DATABASE,
3938
)
4039

4140
embedder = OpenAIEmbeddings()
@@ -46,6 +45,7 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem:
4645
retrieval_query="with node, score return node.title as title, node.plot as plot",
4746
result_formatter=formatter,
4847
embedder=embedder,
48+
neo4j_database=DATABASE,
4949
)
5050

5151
llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})

examples/retrieve/hybrid_cypher_retriever.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
# the name of all actors starring in that movie
2525
RETRIEVAL_QUERY = " MATCH (node)<-[:ACTED_IN]-(p:Person) RETURN node.title as movieTitle, node.plot as moviePlot, collect(p.name) as actors, score as similarityScore"
2626

27-
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
27+
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
2828
# Initialize the retriever
2929
retriever = HybridCypherRetriever(
3030
driver=driver,
@@ -37,7 +37,7 @@
3737
# (see corresponding example in 'customize' directory)
3838
# result_formatter=None,
3939
# optionally, set neo4j database
40-
# neo4j_database="neo4j",
40+
neo4j_database=DATABASE,
4141
)
4242

4343
# Perform the similarity search for a text query

examples/retrieve/hybrid_retriever.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
FULLTEXT_INDEX_NAME = "movieFulltext"
1919

2020

21-
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
21+
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
2222
# Initialize the retriever
2323
retriever = HybridRetriever(
2424
driver=driver,
@@ -31,7 +31,7 @@
3131
# (see corresponding example in 'customize' directory)
3232
# result_formatter=None,
3333
# optionally, set neo4j database
34-
# neo4j_database="neo4j",
34+
neo4j_database=DATABASE,
3535
)
3636

3737
# Perform the similarity search for a text query

examples/retrieve/similarity_search_for_text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
INDEX_NAME = "moviePlotsEmbedding"
1818

1919

20-
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
20+
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
2121
# Initialize the retriever
2222
retriever = VectorRetriever(
2323
driver=driver,
@@ -29,7 +29,7 @@
2929
# (see corresponding example in 'customize' directory)
3030
# result_formatter=None,
3131
# optionally, set neo4j database
32-
# neo4j_database="neo4j",
32+
neo4j_database=DATABASE,
3333
)
3434

3535
# Perform the similarity search for a text query

examples/retrieve/similarity_search_for_vector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
INDEX_NAME = "moviePlotsEmbedding"
1818

1919

20-
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
20+
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
2121
# Initialize the retriever
2222
retriever = VectorRetriever(
2323
driver=driver,
2424
index_name=INDEX_NAME,
25+
neo4j_database=DATABASE,
2526
)
2627

2728
# Perform the similarity search for a vector query

examples/retrieve/text2cypher_search.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
# optionally, you can also provide your own prompt
4848
# for the text2Cypher generation step
4949
# custom_prompt="",
50+
neo4j_database=DATABASE,
5051
)
5152

5253
# Generate a Cypher query using the LLM, send it to the Neo4j database, and return the results

examples/retrieve/vector_cypher_retriever.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
# the name of all actors starring in that movie
2323
RETRIEVAL_QUERY = " MATCH (node)<-[:ACTED_IN]-(p:Person) RETURN node.title as movieTitle, node.plot as moviePlot, collect(p.name) as actors, score as similarityScore"
2424

25-
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
25+
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
2626
# Initialize the retriever
2727
retriever = VectorCypherRetriever(
2828
driver=driver,
@@ -34,7 +34,7 @@
3434
# (see corresponding example in 'customize' directory)
3535
# result_formatter=None,
3636
# optionally, set neo4j database
37-
# neo4j_database="neo4j",
37+
neo4j_database=DATABASE,
3838
)
3939

4040
# Perform the similarity search for a text query

src/neo4j_graphrag/experimental/components/kg_writer.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class Neo4jWriter(KGWriter):
8484
8585
Args:
8686
driver (neo4j.driver): The Neo4j driver to connect to the database.
87-
neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
87+
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
8888
batch_size (int): The number of nodes or relationships to write to the database in a batch. Defaults to 1000.
8989
9090
Example:
@@ -99,7 +99,7 @@ class Neo4jWriter(KGWriter):
9999
AUTH = ("neo4j", "password")
100100
DATABASE = "neo4j"
101101
102-
driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
102+
driver = GraphDatabase.driver(URI, auth=AUTH)
103103
writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE)
104104
105105
pipeline = Pipeline()
@@ -119,10 +119,11 @@ def __init__(
119119
self.is_version_5_23_or_above = self._check_if_version_5_23_or_above()
120120

121121
def _db_setup(self) -> None:
122-
# create index on __Entity__.id
122+
# create index on __KGBuilder__.id
123123
# used when creating the relationships
124124
self.driver.execute_query(
125-
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
125+
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)",
126+
database_=self.neo4j_database,
126127
)
127128

128129
@staticmethod
@@ -150,10 +151,16 @@ def _upsert_nodes(
150151
parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)}
151152
if self.is_version_5_23_or_above:
152153
self.driver.execute_query(
153-
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
154+
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE,
155+
parameters_=parameters,
156+
database_=self.neo4j_database,
154157
)
155158
else:
156-
self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)
159+
self.driver.execute_query(
160+
UPSERT_NODE_QUERY,
161+
parameters_=parameters,
162+
database_=self.neo4j_database,
163+
)
157164

158165
def _get_version(self) -> tuple[int, ...]:
159166
records, _, _ = self.driver.execute_query(
@@ -187,10 +194,16 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
187194
parameters = {"rows": [rel.model_dump() for rel in rels]}
188195
if self.is_version_5_23_or_above:
189196
self.driver.execute_query(
190-
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
197+
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
198+
parameters_=parameters,
199+
database_=self.neo4j_database,
191200
)
192201
else:
193-
self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters)
202+
self.driver.execute_query(
203+
UPSERT_RELATIONSHIP_QUERY,
204+
parameters_=parameters,
205+
database_=self.neo4j_database,
206+
)
194207

195208
@validate_call
196209
async def run(
@@ -202,7 +215,7 @@ async def run(
202215
203216
Args:
204217
graph (Neo4jGraph): The knowledge graph to upsert into the database.
205-
lexical_graph_config (LexicalGraphConfig):
218+
lexical_graph_config (LexicalGraphConfig): Node labels and relationship types for the lexical graph.
206219
"""
207220
try:
208221
self._db_setup()

src/neo4j_graphrag/experimental/components/neo4j_reader.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
from typing import Optional
18+
1719
import neo4j
1820
from pydantic import validate_call
1921

@@ -26,13 +28,39 @@
2628

2729

2830
class Neo4jChunkReader(Component):
31+
"""Reads text chunks from a Neo4j database.
32+
33+
Args:
34+
driver (neo4j.driver): The Neo4j driver to connect to the database.
35+
fetch_embeddings (bool): If True, the embedding property is also returned. Default to False.
36+
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
37+
38+
Example:
39+
40+
.. code-block:: python
41+
42+
from neo4j import GraphDatabase
43+
from neo4j_graphrag.experimental.components.neo4j_reader import Neo4jChunkReader
44+
45+
URI = "neo4j://localhost:7687"
46+
AUTH = ("neo4j", "password")
47+
DATABASE = "neo4j"
48+
49+
driver = GraphDatabase.driver(URI, auth=AUTH)
50+
reader = Neo4jChunkReader(driver=driver, neo4j_database=DATABASE)
51+
await reader.run()
52+
53+
"""
54+
2955
def __init__(
3056
self,
3157
driver: neo4j.Driver,
3258
fetch_embeddings: bool = False,
59+
neo4j_database: Optional[str] = None,
3360
):
3461
self.driver = driver
3562
self.fetch_embeddings = fetch_embeddings
63+
self.neo4j_database = neo4j_database
3664

3765
def _get_query(
3866
self,
@@ -56,12 +84,20 @@ async def run(
5684
self,
5785
lexical_graph_config: LexicalGraphConfig = LexicalGraphConfig(),
5886
) -> TextChunks:
87+
"""Reads text chunks from a Neo4j database.
88+
89+
Args:
90+
lexical_graph_config (LexicalGraphConfig): Node labels and relationship types for the lexical graph.
91+
"""
5992
query = self._get_query(
6093
lexical_graph_config.chunk_node_label,
6194
lexical_graph_config.chunk_index_property,
6295
lexical_graph_config.chunk_embedding_property,
6396
)
64-
result, _, _ = self.driver.execute_query(query)
97+
result, _, _ = self.driver.execute_query(
98+
query,
99+
database_=self.neo4j_database,
100+
)
65101
chunks = []
66102
for record in result:
67103
chunk = record.get("chunk")

0 commit comments

Comments
 (0)