
Commit c531f4e

stellasia and oskarhane authored
RAG Pipeline (neo4j#50)
* RAG Pipeline
* Ruff
* Copyright header
* Add doc
* Examples
* Add examples to template
* RAG Pipeline
* Ruff
* Copyright header
* Add doc
* Examples
* Add examples to template
* Fix tests
* Observer interface + types for RAG
* RAG Pipeline
* Ruff
* Copyright header
* Add doc
* Examples
* Add examples to template
* RAG Pipeline
* Fix tests
* Observer interface + types for RAG
* Mypy-related stuffs
* Fix e2e test
* More e2e tests
* Fix merge
* Turn prompt package into a module since it contains a single file
* Remove observer
* Update examples/rag_custom_prompt.py
  Co-authored-by: Oskar Hane <oh@oskarhane.com>
* Use Literal type
* Update error handling (required small refactoring to avoid circular imports)
* Merge with main
* Rename RAG to GraphRAG
* Improved doc and naming
* Ruff
* Fix class name in docs
* Rename files to match class name "GraphRAG"
* Rename test cases for consistency with name "GraphRAG"
* Introduce LLMResponse model
* Update changelog
* Fix examples
* Ignore type so that mypy does not complain about None VS module types
* Changes to accept langchain (or other) models as input to the RAG pipeline + working example
* Update Text2CypherRetriever with new LLM interface
* Update poetry.lock as required by poetry during CI
* Fix tests
* Move example in README (and fix the example...)
* Update imports

---------

Co-authored-by: Oskar Hane <oh@oskarhane.com>
1 parent: dd60687 · commit: c531f4e

30 files changed (+1195, -71 lines changed)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,9 @@
 - Added `upsert_vector` utility function for attaching vectors to node properties.
 - Introduced `Neo4jInsertionError` for handling insertion failures in Neo4j.
 - Included Pinecone and Weaviate retrievers in neo4j_genai.retrievers.
+- Introduced the GraphRAG object, enabling a full RAG (Retrieval-Augmented Generation) pipeline with context retrieval, prompt formatting, and answer generation.
+- Added PromptTemplate and RagTemplate for customizable prompt generation.
+- Added LLMInterface with implementation for OpenAI LLM.
 
 ### Changed
 - Refactored import paths for retrievers to neo4j_genai.retrievers.

README.md

Lines changed: 18 additions & 4 deletions
@@ -28,15 +28,21 @@ While the library has more retrievers than shown here, the following examples sh
 
 Assumption: Neo4j running with populated vector index in place.
 
+In the following example, we use a simple vector search as retriever,
+which will perform a similarity search over the `index-name` vector index
+in Neo4j.
+
 ```python
 from neo4j import GraphDatabase
 from neo4j_genai.retrievers import VectorRetriever
-from langchain_openai import OpenAIEmbeddings
+from neo4j_genai.llm import OpenAILLM
+from neo4j_genai.generation import GraphRAG
+from neo4j_genai.embeddings.openai import OpenAIEmbeddings
 
 URI = "neo4j://localhost:7687"
 AUTH = ("neo4j", "password")
 
-INDEX_NAME = "embedding-name"
+INDEX_NAME = "index-name"
 
 # Connect to Neo4j database
 driver = GraphDatabase.driver(URI, auth=AUTH)
@@ -47,9 +53,17 @@ embedder = OpenAIEmbeddings(model="text-embedding-3-large")
 # Initialize the retriever
 retriever = VectorRetriever(driver, INDEX_NAME, embedder)
 
-# Run the similarity search
+# Initialize the LLM
+# Note: the OPENAI_API_KEY must be in the env vars
+llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
+
+# Initialize the RAG pipeline
+rag = GraphRAG(retriever=retriever, llm=llm)
+
+# Query the graph
 query_text = "How do I do similarity search in Neo4j?"
-response = retriever.search(query_text=query_text, top_k=5)
+response = rag.search(query_text=query_text, retriever_config={"top_k": 5})
+print(response.answer)
 ```
 
 ### Creating a vector index
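Editor's note: the updated README example assumes a populated vector index named `index-name` already exists (the README's next section, "Creating a vector index", covers that step). As a reminder of what the setup looks like, here is a minimal sketch. The helper and its parameter names (`label`, `property`, `dimensions`, `similarity_fn`) are assumptions based on the library's existing `neo4j_genai.indexes` module and should be checked against the installed version; the node label, property name, and dimensions below are illustrative only.

```python
from neo4j import GraphDatabase

# Assumed helper from this library; verify the signature in your version.
from neo4j_genai.indexes import create_vector_index

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
INDEX_NAME = "index-name"

driver = GraphDatabase.driver(URI, auth=AUTH)

# Hypothetical index over the `embedding` property of `Document` nodes.
# `dimensions` must match the embedding model used by the retriever
# (3072 for text-embedding-3-large unless a smaller dimension is requested).
create_vector_index(
    driver,
    INDEX_NAME,
    label="Document",
    property="embedding",
    dimensions=3072,
    similarity_fn="cosine",
)

driver.close()
```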

docs/source/api.rst

Lines changed: 34 additions & 8 deletions
@@ -3,18 +3,17 @@
 API Documentation
 #################
 
-************************************
-Retrieval-Augmented Generation (RAG)
-************************************
-RAG is a technique that enhances Large Language Model (LLM) responses by retrieving
-source information from external data stores to augment generated responses.
-
-This package enables Python developers to perform RAG using Neo4j.
-
 **********
 Retrievers
 **********
 
+RetrieverInterface
+===================
+
+.. autoclass:: neo4j_genai.retrievers.base.Retriever
+    :members:
+
+
 VectorRetriever
 ===============
 
@@ -89,6 +88,12 @@ Errors
 
 * :class:`neo4j_genai.exceptions.SchemaFetchError`
 
+* :class:`neo4j_genai.exceptions.RagInitializationError`
+
+* :class:`neo4j_genai.exceptions.PromptMissingInputError`
+
+* :class:`neo4j_genai.exceptions.LLMGenerationError`
+
 
 Neo4jGenAiError
 ===============
@@ -165,3 +170,24 @@ SchemaFetchError
 
 .. autoclass:: neo4j_genai.exceptions.SchemaFetchError
     :show-inheritance:
+
+
+RagInitializationError
+==========================
+
+.. autoclass:: neo4j_genai.exceptions.RagInitializationError
+    :show-inheritance:
+
+
+PromptMissingInputError
+==========================
+
+.. autoclass:: neo4j_genai.exceptions.PromptMissingInputError
+    :show-inheritance:
+
+
+LLMGenerationError
+==========================
+
+.. autoclass:: neo4j_genai.exceptions.LLMGenerationError
+    :show-inheritance:
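Editor's note: the three new error classes registered above line up with the stages of the GraphRAG pipeline suggested by their names. A minimal sketch of defensive usage follows; the mapping of each exception to a particular stage is inferred from the naming, not spelled out in this diff, so treat the per-exception comments as assumptions.

```python
from neo4j_genai.generation import GraphRAG
from neo4j_genai.exceptions import (
    LLMGenerationError,
    PromptMissingInputError,
    RagInitializationError,
)

# `retriever` and `llm` are assumed to be configured as in the README example above.
try:
    rag = GraphRAG(retriever=retriever, llm=llm)
    response = rag.search(query_text="How do I do similarity search in Neo4j?")
    print(response.answer)
except RagInitializationError as e:
    # Presumably raised when the pipeline is built with invalid components/configuration
    print(f"RAG setup failed: {e}")
except PromptMissingInputError as e:
    # Presumably raised when the prompt template is formatted without an expected input
    print(f"Prompt formatting failed: {e}")
except LLMGenerationError as e:
    # Presumably raised when the underlying LLM call fails while generating the answer
    print(f"Answer generation failed: {e}")
```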

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@
     "sphinx.ext.intersphinx",
     "sphinx.ext.napoleon",
     "sphinx.ext.viewcode",
+    "sphinx.ext.autosectionlabel",
 ]
 
 # The suffix(es) of source filenames.

docs/source/index.rst

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@ fast to ship new features and high performing patterns and methods.
 Topics
 ******
 
++ :ref:`rag-documentation`
 + :ref:`api-documentation`
 + :ref:`types-documentation`
 
@@ -24,6 +25,7 @@ Topics
    :caption: Contents:
    :hidden:
 
+   rag.rst
    api.rst
    types.rst
 
docs/source/rag.rst

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+.. _rag-documentation:
+
+RAG Documentation
+#################
+
+************************************
+Retrieval-Augmented Generation (RAG)
+************************************
+RAG is a technique that enhances Large Language Model (LLM) responses by retrieving
+source information from external data stores to augment generated responses.
+
+This package enables Python developers to perform RAG using Neo4j.
+
+
+************************************
+Overview
+************************************
+
+.. code:: python
+
+    from neo4j import GraphDatabase
+    from neo4j_genai.retrievers import VectorRetriever
+    from neo4j_genai.llm import OpenAILLM
+    from neo4j_genai.generation import GraphRAG
+    from neo4j_genai.embeddings.openai import OpenAIEmbeddings
+
+    URI = "neo4j://localhost:7687"
+    AUTH = ("neo4j", "password")
+
+    INDEX_NAME = "index-name"
+
+    # Connect to Neo4j database
+    driver = GraphDatabase.driver(URI, auth=AUTH)
+
+    # Create Embedder object
+    embedder = OpenAIEmbeddings(model="text-embedding-3-large")
+
+    # Initialize the retriever
+    retriever = VectorRetriever(driver, INDEX_NAME, embedder)
+
+    # Initialize the LLM
+    # Note: the OPENAI_API_KEY must be in the env vars
+    llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
+
+    # Initialize the RAG pipeline
+    rag = GraphRAG(retriever=retriever, llm=llm)
+
+    # Query the graph
+    query_text = "How do I do similarity search in Neo4j?"
+    response = rag.search(query_text=query_text, retriever_config={"top_k": 5})
+    print(response.answer)
+
+
+The retriever can be any of the :ref:`supported retrievers<retrievers>`, or any class
+inheriting from the `Retriever` interface.
+
+***************
+Advanced usage
+***************
+
+
+Using another LLM
+==================
+
+This package only provides support for the OpenAI LLM. If you need to use another LLM,
+you need to subclass the `LLMInterface`:
+
+.. autoclass:: neo4j_genai.llm.LLMInterface
+    :members:
+    :show-inheritance:
+
+Configuring the prompt
+=======================
+
+Prompts are managed through `PromptTemplate` classes. More
+specifically, the `RAG` pipeline uses a `RagTemplate` with
+a default prompt. You can use another prompt by subclassing
+the `RagTemplate` class and passing it to the `RAG` pipeline
+object during initialization:
+
+.. code:: python
+
+    from neo4j_genai.generation import RagTemplate, GraphRAG
+
+    # ...
+
+    prompt_template = RagTemplate(
+        prompt="Answer the question {question} using context {context} and examples {examples}",
+        expected_inputs=["context", "question", "examples"]
+    )
+    rag = GraphRAG(retriever=vector_retriever, llm=llm, prompt_template=prompt_template)
+
+    # ...
+
+For more details, see:
+
+.. autoclass:: neo4j_genai.generation.prompts.PromptTemplate
+    :members:
+
+and
+
+.. autoclass:: neo4j_genai.generation.prompts.RagTemplate
+    :members:
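Editor's note: the "Using another LLM" section above stops at the autoclass directive. For readers looking for a starting point, here is a minimal sketch of such a subclass. It assumes the interface expects an `invoke(input: str) -> LLMResponse` method (plus an async `ainvoke` counterpart) and that `LLMResponse` is importable alongside `LLMInterface` from `neo4j_genai.llm`; check the rendered `LLMInterface` docs for the exact contract.

```python
from neo4j_genai.llm import LLMInterface, LLMResponse


class EchoLLM(LLMInterface):
    """Toy LLM that echoes its prompt back.

    Replace the method bodies with calls to your own model or provider SDK.
    """

    def invoke(self, input: str) -> LLMResponse:
        # A real implementation would send `input` to the model identified by
        # self.model_name and wrap the generated text in an LLMResponse.
        return LLMResponse(content=f"[{self.model_name}] {input}")

    async def ainvoke(self, input: str) -> LLMResponse:
        # Only needed if the interface declares an async variant.
        return self.invoke(input)
```

Any such subclass can then be passed as the `llm` argument of `GraphRAG`, the same way `OpenAILLM` is used in the Overview example.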

examples/graphrag.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+"""End to end example of building a RAG pipeline backed by a Neo4j database.
+Requires OPENAI_API_KEY to be in the env var.
+
+This example illustrates:
+- VectorCypherRetriever with a custom formatter function to extract relevant
+  context from neo4j result
+- Logging configuration
+"""
+
+import logging
+import neo4j
+
+from neo4j_genai.embeddings.openai import OpenAIEmbeddings
+from neo4j_genai.types import RetrieverResultItem
+from neo4j_genai.retrievers import VectorCypherRetriever
+from neo4j_genai.llm import OpenAILLM
+from neo4j_genai.generation import GraphRAG
+
+URI = "neo4j://localhost:7687"
+AUTH = ("neo4j", "password")
+DATABASE = "neo4j"
+INDEX = "moviePlotsEmbedding"
+
+
+# setup logger config
+logger = logging.getLogger("neo4j_genai")
+logging.basicConfig(format="%(asctime)s - %(message)s")
+logger.setLevel(logging.DEBUG)
+
+
+def formatter(record: neo4j.Record) -> RetrieverResultItem:
+    return RetrieverResultItem(content=f'{record.get("title")}: {record.get("plot")}')
+
+
+driver = neo4j.GraphDatabase.driver(
+    URI,
+    auth=AUTH,
+    database=DATABASE,
+)
+
+embedder = OpenAIEmbeddings()
+
+retriever = VectorCypherRetriever(
+    driver,
+    index_name=INDEX,
+    retrieval_query="with node, score return node.title as title, node.plot as plot",
+    format_record_function=formatter,
+    embedder=embedder,
+)
+
+llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
+
+rag = GraphRAG(retriever=retriever, llm=llm)
+
+result = rag.search("Tell me more about Avatar movies")
+print(result.answer)
+
+driver.close()
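Editor's note: the custom `formatter` in this example flattens each record into a plain string. If downstream code also needs structured fields, a variant along these lines could keep them. It assumes `RetrieverResultItem` accepts an optional `metadata` mapping in addition to `content`, which should be verified against `neo4j_genai.types`.

```python
import neo4j
from neo4j_genai.types import RetrieverResultItem


def formatter_with_metadata(record: neo4j.Record) -> RetrieverResultItem:
    """Like `formatter`, but also keeps the raw title as structured metadata."""
    return RetrieverResultItem(
        content=f'{record.get("title")}: {record.get("plot")}',
        # Add a `score` entry here too if the retrieval_query is extended to return it.
        metadata={"title": record.get("title")},
    )
```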

examples/graphrag_custom_prompt.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+"""End to end example of building a RAG pipeline backed by a Neo4j database.
+Requires OPENAI_API_KEY to be in the env var.
+
+This example illustrates:
+- VectorCypherRetriever with a custom formatter function to extract relevant
+  context from neo4j result
+- Use of a custom prompt for RAG
+- Logging configuration
+"""
+
+import logging
+import neo4j
+
+from neo4j_genai.types import RetrieverResultItem
+from neo4j_genai.embeddings.openai import OpenAIEmbeddings
+from neo4j_genai.retrievers import VectorCypherRetriever
+from neo4j_genai.generation import GraphRAG, RagTemplate
+from neo4j_genai.llm import OpenAILLM
+
+URI = "neo4j://localhost:7687"
+AUTH = ("neo4j", "password")
+DATABASE = "neo4j"
+INDEX = "moviePlotsEmbedding"
+
+
+# setup logger config
+logger = logging.getLogger("neo4j_genai")
+logging.basicConfig(format="%(asctime)s - %(message)s")
+logger.setLevel(logging.DEBUG)
+
+
+def formatter(record: neo4j.Record) -> RetrieverResultItem:
+    return RetrieverResultItem(content=f'{record.get("title")}: {record.get("plot")}')
+
+
+driver = neo4j.GraphDatabase.driver(
+    URI,
+    auth=AUTH,
+    database=DATABASE,
+)
+
+embedder = OpenAIEmbeddings()
+
+retriever = VectorCypherRetriever(
+    driver,
+    index_name=INDEX,
+    retrieval_query="with node, score return node.title as title, node.plot as plot",
+    format_record_function=formatter,
+    embedder=embedder,
+)
+
+llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
+
+template = RagTemplate(
+    template="""You are an expert at movies and actors. Your task is to
+    answer the user's question based on the provided context. Use only the
+    information within that context.
+
+    Context:
+    {context}
+
+    Question:
+    {query}
+
+    Answer:
+    """
+)
+
+rag = GraphRAG(retriever=retriever, llm=llm, prompt_template=template)
+
+result = rag.search("Tell me more about Avatar movies")
+print(result.answer)
+
+driver.close()
