Skip to content

Commit 197afe2

Browse files
committed
Merge branch 'main' of https://github.com/neo4j/neo4j-genai-python into feature/kg_builder
# Conflicts: # CHANGELOG.md
2 parents 74b60e9 + 149d1e9 commit 197afe2

File tree

14 files changed

+246
-41
lines changed

14 files changed

+246
-41
lines changed

.github/workflows/pr-e2e-tests.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@ jobs:
1515
runs-on: ubuntu-latest
1616
strategy:
1717
matrix:
18-
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
18+
python-version: ['3.8', '3.12']
1919
neo4j-version:
2020
- 5
2121
neo4j-edition:
22-
- community
2322
- enterprise
2423
services:
2524
t2v-transformers:
@@ -50,6 +49,10 @@ jobs:
5049
steps:
5150
- name: Check out repository code
5251
uses: actions/checkout@v4
52+
- name: Docker Prune
53+
run: |
54+
docker system prune -af
55+
docker volume prune -f
5356
- name: Set up Python ${{ matrix.python-version }}
5457
uses: actions/setup-python@v5
5558
with:
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
name: 'Neo4j-GenAI Scheduled E2E Tests'
2+
3+
on:
4+
schedule:
5+
- cron: '0 6,9,12,15,18 * * 1-5' # Runs every 3 hours from 06:00 to 18:00 UTC, Monday to Friday
6+
push:
7+
branches:
8+
- main
9+
10+
jobs:
11+
e2e-tests:
12+
runs-on: ubuntu-latest
13+
strategy:
14+
max-parallel: 6
15+
matrix:
16+
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
17+
neo4j-version:
18+
- 5
19+
neo4j-edition:
20+
- community
21+
- enterprise
22+
services:
23+
t2v-transformers:
24+
image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-all-MiniLM-L6-v2-onnx
25+
env:
26+
ENABLE_CUDA: '0'
27+
weaviate:
28+
image: cr.weaviate.io/semitechnologies/weaviate:1.25.1
29+
env:
30+
TRANSFORMERS_INFERENCE_API: 'http://t2v-transformers:8080'
31+
AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
32+
DEFAULT_VECTORIZER_MODULE: 'text2vec-transformers'
33+
ENABLE_MODULES: 'text2vec-transformers'
34+
CLUSTER_HOSTNAME: 'node1'
35+
ports:
36+
- 8080:8080
37+
- 50051:50051
38+
neo4j:
39+
image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }}
40+
env:
41+
NEO4J_AUTH: neo4j/password
42+
NEO4J_ACCEPT_LICENSE_AGREEMENT: 'eval'
43+
NEO4J_PLUGINS: '["apoc"]'
44+
ports:
45+
- 7687:7687
46+
- 7474:7474
47+
48+
steps:
49+
- name: Check out repository code
50+
uses: actions/checkout@v4
51+
- name: Docker Prune
52+
run: |
53+
docker system prune -af
54+
docker volume prune -f
55+
- name: Set up Python ${{ matrix.python-version }}
56+
uses: actions/setup-python@v5
57+
with:
58+
python-version: ${{ matrix.python-version }}
59+
- name: Install Poetry
60+
uses: snok/install-poetry@v1
61+
with:
62+
virtualenvs-create: true
63+
virtualenvs-in-project: true
64+
installer-parallel: true
65+
- name: Set Python version for Poetry
66+
run: poetry env use python${{ matrix.python-version }}
67+
- name: Load cached venv
68+
id: cached-poetry-dependencies
69+
uses: actions/cache@v4
70+
with:
71+
path: .venv
72+
key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
73+
- name: Install dependencies
74+
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
75+
run: poetry install --no-interaction --no-root
76+
- name: Install root project
77+
run: poetry install --no-interaction
78+
- name: Install dependencies
79+
run: poetry install --with dev
80+
- name: Wait for Weaviate to start
81+
shell: bash
82+
run: |
83+
set +e
84+
count=0; until curl -s --fail localhost:8080/v1/.well-known/ready; do ((count++)); [ $count -ge 10 ] && echo "Reached maximum retry limit" && exit 1; sleep 15; done
85+
- name: Run tests
86+
shell: bash
87+
run: |
88+
if [[ "${{ matrix.neo4j-edition }}" == "community" ]]; then
89+
poetry run pytest -m 'not enterprise_only' ./tests/e2e
90+
else
91+
poetry run pytest ./tests/e2e
92+
fi

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@
33
## Next
44

55
### Added
6+
- Added optional `custom_prompt` argument to the `Text2CypherRetriever` class.
67
- Introduced support for Component/Pipeline flexible architecture
78

9+
### Changed
10+
- `GraphRAG.search` method first parameter has been renamed `query_text` (was `query`) for consistency with the retrievers interface.
11+
- Made `GraphRAG.search` method backwards compatible with the `query` parameter, raising a deprecation warning to encourage using `query_text` instead.
12+
813
## 0.3.1
914

1015
### Fixed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ Assumption: Neo4j running with a defined vector index
6868

6969
```python
7070
from neo4j import GraphDatabase
71-
from neo4j_genai.indexes import upsert_query
71+
from neo4j_genai.indexes import upsert_vector
7272

7373
URI = "neo4j://localhost:7687"
7474
AUTH = ("neo4j", "password")
@@ -78,7 +78,7 @@ driver = GraphDatabase.driver(URI, auth=AUTH)
7878

7979
# Upsert the vector
8080
vector = ...
81-
upsert_query(
81+
upsert_vector(
8282
driver,
8383
node_id=1,
8484
embedding_property="vectorProperty",

docs/source/index.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ Note that the below example is not the only way you can upsert data into your Ne
115115
.. code:: python
116116
117117
from neo4j import GraphDatabase
118-
from neo4j_genai.indexes import upsert_query
118+
from neo4j_genai.indexes import upsert_vector
119119
120120
URI = "neo4j://localhost:7687"
121121
AUTH = ("neo4j", "password")
@@ -125,7 +125,7 @@ Note that the below example is not the only way you can upsert data into your Ne
125125
126126
# Upsert the vector
127127
vector = ...
128-
upsert_query(
128+
upsert_vector(
129129
driver,
130130
node_id=1,
131131
embedding_property="vectorProperty",

examples/graphrag_custom_prompt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem:
6060
{context}
6161
6262
Question:
63-
{query}
63+
{query_text}
6464
6565
Answer:
6666
"""

src/neo4j_genai/generation/graphrag.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import logging
18+
import warnings
1819
from typing import Any, Optional
1920

2021
from pydantic import ValidationError
@@ -53,43 +54,60 @@ def __init__(
5354

5455
def search(
5556
self,
56-
query: str,
57+
query_text: str = "",
5758
examples: str = "",
5859
retriever_config: Optional[dict[str, Any]] = None,
5960
return_context: bool = False,
61+
query: Optional[str] = None,
6062
) -> RagResultModel:
6163
"""This method performs a full RAG search:
6264
1. Retrieval: context retrieval
6365
2. Augmentation: prompt formatting
6466
3. Generation: answer generation with LLM
6567
6668
Args:
67-
query (str): The user question
69+
query_text (str): The user question
6870
examples: Examples added to the LLM prompt.
6971
retriever_config (Optional[dict]): Parameters passed to the retriever
7072
search method; e.g.: top_k
7173
return_context (bool): Whether to return the retriever result (default: False)
74+
query (Optional[str]): The user question. Deprecated: use query_text instead; if both are provided, query_text is used.
7275
7376
Returns:
7477
RagResultModel: The LLM-generated answer
7578
7679
"""
7780
try:
81+
if query is not None:
82+
if query_text:
83+
warnings.warn(
84+
"Both 'query' and 'query_text' are provided, 'query_text' will be used.",
85+
DeprecationWarning,
86+
stacklevel=2,
87+
)
88+
elif isinstance(query, str):
89+
warnings.warn(
90+
"'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
91+
DeprecationWarning,
92+
stacklevel=2,
93+
)
94+
query_text = query
95+
7896
validated_data = RagSearchModel(
79-
query=query,
97+
query_text=query_text,
8098
examples=examples,
8199
retriever_config=retriever_config or {},
82100
return_context=return_context,
83101
)
84102
except ValidationError as e:
85103
raise SearchValidationError(e.errors())
86-
query = validated_data.query
104+
query_text = validated_data.query_text
87105
retriever_result: RetrieverResult = self.retriever.search(
88-
query_text=query, **validated_data.retriever_config
106+
query_text=query_text, **validated_data.retriever_config
89107
)
90108
context = "\n".join(item.content for item in retriever_result.items)
91109
prompt = self.prompt_template.format(
92-
query=query, context=context, examples=validated_data.examples
110+
query_text=query_text, context=context, examples=validated_data.examples
93111
)
94112
logger.debug(f"RAG: retriever_result={retriever_result}")
95113
logger.debug(f"RAG: prompt={prompt}")

src/neo4j_genai/generation/prompts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,14 @@ class RagTemplate(PromptTemplate):
8787
{examples}
8888
8989
Question:
90-
{query}
90+
{query_text}
9191
9292
Answer:
9393
"""
94-
EXPECTED_INPUTS = ["context", "query", "examples"]
94+
EXPECTED_INPUTS = ["context", "query_text", "examples"]
9595

96-
def format(self, query: str, context: str, examples: str) -> str:
97-
return super().format(query=query, context=context, examples=examples)
96+
def format(self, query_text: str, context: str, examples: str) -> str:
97+
return super().format(query_text=query_text, context=context, examples=examples)
9898

9999

100100
class Text2CypherTemplate(PromptTemplate):

src/neo4j_genai/generation/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def check_llm(cls, value: Any) -> Any:
3939

4040

4141
class RagSearchModel(BaseModel):
42-
query: str
42+
query_text: str
4343
examples: str = ""
4444
retriever_config: dict[str, Any] = {}
4545
return_context: bool = False

src/neo4j_genai/retrievers/text2cypher.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class Text2CypherRetriever(Retriever):
5555
llm (neo4j_genai.generation.llm.LLMInterface): LLM object to generate the Cypher query.
5656
neo4j_schema (Optional[str]): Neo4j schema used to generate the Cypher query.
5757
examples (Optional[list[str]], optional): Optional user input/query pairs for the LLM to use as examples.
58+
custom_prompt (Optional[str]): Optional custom prompt to use instead of the auto-generated prompt. If provided, the neo4j_schema and examples args are not included in the prompt.
5859
5960
Raises:
6061
RetrieverInitializationError: If validation of the input arguments fail.
@@ -69,6 +70,7 @@ def __init__(
6970
result_formatter: Optional[
7071
Callable[[neo4j.Record], RetrieverResultItem]
7172
] = None,
73+
custom_prompt: Optional[str] = None,
7274
) -> None:
7375
try:
7476
driver_model = Neo4jDriverModel(driver=driver)
@@ -82,6 +84,7 @@ def __init__(
8284
neo4j_schema_model=neo4j_schema_model,
8385
examples=examples,
8486
result_formatter=result_formatter,
87+
custom_prompt=custom_prompt,
8588
)
8689
except ValidationError as e:
8790
raise RetrieverInitializationError(e.errors()) from e
@@ -90,12 +93,17 @@ def __init__(
9093
self.llm = validated_data.llm_model.llm
9194
self.examples = validated_data.examples
9295
self.result_formatter = validated_data.result_formatter
96+
self.custom_prompt = validated_data.custom_prompt
9397
try:
94-
self.neo4j_schema = (
95-
validated_data.neo4j_schema_model.neo4j_schema
96-
if validated_data.neo4j_schema_model
97-
else get_schema(validated_data.driver_model.driver)
98-
)
98+
if (
99+
not validated_data.custom_prompt
100+
): # don't need schema for a custom prompt
101+
self.neo4j_schema = (
102+
validated_data.neo4j_schema_model.neo4j_schema
103+
if validated_data.neo4j_schema_model
104+
else get_schema(validated_data.driver_model.driver)
105+
)
106+
99107
except (Neo4jError, DriverError) as e:
100108
error_message = getattr(e, "message", str(e))
101109
raise SchemaFetchError(
@@ -124,12 +132,16 @@ def get_search_results(
124132
except ValidationError as e:
125133
raise SearchValidationError(e.errors()) from e
126134

127-
prompt_template = Text2CypherTemplate()
128-
prompt = prompt_template.format(
129-
schema=self.neo4j_schema,
130-
examples="\n".join(self.examples) if self.examples else "",
131-
query=validated_data.query_text,
132-
)
135+
if not self.custom_prompt:
136+
prompt_template = Text2CypherTemplate()
137+
prompt = prompt_template.format(
138+
schema=self.neo4j_schema,
139+
examples="\n".join(self.examples) if self.examples else "",
140+
query=validated_data.query_text,
141+
)
142+
else:
143+
prompt = self.custom_prompt
144+
133145
logger.debug("Text2CypherRetriever prompt: %s", prompt)
134146

135147
try:

0 commit comments

Comments
 (0)