Skip to content

Commit 4600422

Browse files
authored
Support for Azure OpenAI + Ollama + unit tests for the embeders and LLMs (#140)
* Add support for AzureOpenAI + remove OpenAI from llm imports + add tests for (Azure)OpenAI embedder and LLM + update doc * Add UT for vertex AI * Add google-cloud-aiplatform to dev dependencies * Fix vertexai - other imports are trying to import from 'self' (the vertexai.py), instead of the vertexai package * Ruff * Update code after manual tests * Ruff and co * Fix typo :'( * Update file names to prevent import errors * Revert file renaming for embeddings
1 parent ae08316 commit 4600422

22 files changed

+407
-241
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Next
44

55
### Added
6+
- Added AzureOpenAILLM and AzureOpenAIEmbeddings to support Azure served OpenAI models
67
- Add `template` validation in `PromptTemplate` class upon construction.
78
- `custom_prompt` arg is now converted to `Text2CypherTemplate` class within the `Text2CypherRetriever.get_search_results` method.
89
- `Text2CypherTemplate` and `RAGTemplate` prompt templates now require `query_text` arg and will error if it is not present. Previous `query_text` aliases may be used, but will warn of deprecation.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ from neo4j import GraphDatabase
110110
from neo4j_graphrag.retrievers import VectorRetriever
111111
from neo4j_graphrag.llm import OpenAILLM
112112
from neo4j_graphrag.generation import GraphRAG
113-
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
113+
from neo4j_graphrag.embeddings import OpenAIEmbeddings
114114

115115
URI = "neo4j://localhost:7687"
116116
AUTH = ("neo4j", "password")

docs/source/api.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ LLMInterface
176176
OpenAILLM
177177
=========
178178

179-
.. autoclass:: neo4j_graphrag.llm.OpenAILLM
179+
.. autoclass:: neo4j_graphrag.llm.openai.OpenAILLM
180180
:members:
181181

182182
VertexAILLM

docs/source/user_guide_kg_builder.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ It can be used in this way:
246246
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
247247
LLMEntityRelationExtractor,
248248
)
249-
from neo4j_graphrag.llm import OpenAILLM
249+
from neo4j_graphrag.llm.openai import OpenAILLM
250250
251251
extractor = LLMEntityRelationExtractor(
252252
llm=OpenAILLM(

docs/source/user_guide_rag.rst

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ In practice, it's done with only a few lines of code:
2323
2424
from neo4j import GraphDatabase
2525
from neo4j_graphrag.retrievers import VectorRetriever
26-
from neo4j_graphrag.llm import OpenAILLM
26+
from neo4j_graphrag.llm.openai import OpenAILLM
2727
from neo4j_graphrag.generation import GraphRAG
2828
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
2929
@@ -67,12 +67,47 @@ Each component can be configured individually: the LLM and the prompt.
6767
Using Another LLM Model
6868
========================
6969

70-
If OpenAI can not be used, there are two available alternatives:
70+
If OpenAI cannot be used directly, there are a few available alternatives:
7171

72-
1. Utilize any LangChain chat model.
73-
2. Implement a custom interface.
72+
- Use Azure OpenAI.
73+
- Use a local Ollama model.
74+
- Implement a custom interface.
75+
- Utilize any LangChain chat model.
7476

75-
Both options are illustrated below, using a local Ollama model as an example.
77+
All options are illustrated below, using a local Ollama model as an example.
78+
79+
Using Azure Open AI LLM
80+
-----------------------
81+
82+
It is possible to use Azure OpenAI switching to the `AzureOpenAILLM` class:
83+
84+
.. code:: python
85+
86+
from neo4j_graphrag.llm.openai import AzureOpenAILLM
87+
llm = AzureOpenAILLM(
88+
model_name="gpt-4o",
89+
azure_endpoint="https://example-endpoint.openai.azure.com/", # update with your endpoint
90+
api_version="2024-06-01", # update appropriate version
91+
api_key="ba3a46d86259405385f73f08078f588b", # api_key is optional and can also be set with OPENAI_API_KEY env var
92+
)
93+
llm.invoke("say something")
94+
95+
Check the OpenAI Python client [documentation](https://github.com/openai/openai-python?tab=readme-ov-file#microsoft-azure-openai)
96+
to learn more about the configuration.
97+
98+
99+
Using a Local Model via Ollama
100+
-------------------------------
101+
102+
Similarly to the official OpenAI Python client, the `OpenAILLM` can be
103+
used with Ollama. Assuming Ollama is running on the default address `127.0.0.1:11434`,
104+
it can be queried using the following:
105+
106+
.. code:: python
107+
108+
from neo4j_graphrag.llm.openai import OpenAILLM
109+
llm = OpenAILLM(api_key="ollama", base_url="http://127.0.0.1:11434/v1", model_name="orca-mini")
110+
llm.invoke("say something")
76111
77112
78113
Using a Model from LangChain
@@ -99,10 +134,11 @@ It is however not mandatory to use LangChain. The alternative is to implement
99134
a custom model.
100135

101136
Using a Custom Model
102-
-----------------------------
137+
--------------------
103138

104-
To avoid LangChain, developers can create a custom LLM class by subclassing
105-
the `LLMInterface`. Here's an example using the Python Ollama client:
139+
If the provided implementations do not match their needs, developers can create a
140+
custom LLM class by subclassing the `LLMInterface`.
141+
Here's an example using the Python Ollama client:
106142

107143

108144
.. code:: python
@@ -123,6 +159,10 @@ the `LLMInterface`. Here's an example using the Python Ollama client:
123159
content=response["message"]["content"]
124160
)
125161
162+
async def ainvoke(self, input: str) -> LLMResponse:
163+
return self.invoke(input) # TODO: implement async with ollama.AsyncClient
164+
165+
126166
# retriever = ...
127167
128168
llm = OllamaLLM("llama3:8b")
@@ -607,7 +647,7 @@ LLMs can be different.
607647
608648
from neo4j import GraphDatabase
609649
from neo4j_graphrag.retrievers import Text2CypherRetriever
610-
from neo4j_graphrag.llm import OpenAILLM
650+
from neo4j_graphrag.llm.openai import OpenAILLM
611651
612652
URI = "neo4j://localhost:7687"
613653
AUTH = ("neo4j", "password")
@@ -772,4 +812,4 @@ Drop a Vector Index
772812
773813
# Connect to Neo4j database
774814
driver = GraphDatabase.driver(URI, auth=AUTH)
775-
drop_index_if_exists(driver, INDEX_NAME)
815+
drop_index_if_exists(driver, INDEX_NAME)

examples/pipeline/kg_builder_from_pdf.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import asyncio
1818
import logging
19-
from typing import Any, Dict, List
2019

2120
import neo4j
2221
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -33,71 +32,12 @@
3332
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
3433
FixedSizeSplitter,
3534
)
36-
from neo4j_graphrag.experimental.pipeline import Component, DataModel
3735
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
3836
from neo4j_graphrag.llm import OpenAILLM
39-
from pydantic import BaseModel, validate_call
4037

4138
logging.basicConfig(level=logging.INFO)
4239

4340

44-
class DocumentChunkModel(DataModel):
45-
chunks: list[str]
46-
47-
48-
class DocumentChunker(Component):
49-
async def run(self, text: str) -> DocumentChunkModel:
50-
chunks = [t.strip() for t in text.split(".") if t.strip()]
51-
return DocumentChunkModel(chunks=chunks)
52-
53-
54-
class EntityModel(BaseModel):
55-
label: str
56-
properties: dict[str, str]
57-
58-
59-
class Neo4jGraph(DataModel):
60-
entities: list[dict[str, Any]]
61-
relations: list[dict[str, Any]]
62-
63-
64-
class ERExtractor(Component):
65-
async def _process_chunk(self, chunk: str, schema: str) -> Dict[str, Any]:
66-
return {
67-
"entities": [{"label": "Person", "properties": {"name": "John Doe"}}],
68-
"relations": [],
69-
}
70-
71-
async def run(self, chunks: List[str], schema: str) -> Neo4jGraph:
72-
tasks = [self._process_chunk(chunk, schema) for chunk in chunks]
73-
result = await asyncio.gather(*tasks)
74-
merged_result: dict[str, Any] = {"entities": [], "relations": []}
75-
for res in result:
76-
merged_result["entities"] += res["entities"]
77-
merged_result["relations"] += res["relations"]
78-
return Neo4jGraph(
79-
entities=merged_result["entities"], relations=merged_result["relations"]
80-
)
81-
82-
83-
class WriterModel(DataModel):
84-
status: str
85-
entities: list[EntityModel]
86-
relations: list[EntityModel]
87-
88-
89-
class Writer(Component):
90-
@validate_call
91-
async def run(self, graph: Neo4jGraph) -> WriterModel:
92-
entities = graph.entities
93-
relations = graph.relations
94-
return WriterModel(
95-
status="OK",
96-
entities=[EntityModel(**e) for e in entities],
97-
relations=[EntityModel(**r) for r in relations],
98-
)
99-
100-
10141
async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
10242
from neo4j_graphrag.experimental.pipeline import Pipeline
10343

examples/text2cypher_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from neo4j import GraphDatabase
22
from neo4j_graphrag.llm import OpenAILLM
3-
from neo4j_graphrag.retrievers.text2cypher import Text2CypherRetriever
3+
from neo4j_graphrag.retrievers import Text2CypherRetriever
44

55
URI = "neo4j://localhost:7687"
66
AUTH = ("neo4j", "password")

0 commit comments

Comments
 (0)