
Commit 8aa4684

Fix types and format

Parent: 7b8ef44

10 files changed: +46 −40

src/neo4j_graphrag/embeddings/cohere.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
 try:
     import cohere
 except ImportError:
-    cohere = None  # type: ignore
+    cohere = None


 class CohereEmbeddings(Embedder):

src/neo4j_graphrag/embeddings/mistral.py

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@
 try:
     from mistralai import Mistral
 except ImportError:
-    Mistral = None  # type: ignore
+    Mistral = None


 class MistralAIEmbeddings(Embedder):
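Both embedding modules use the same optional-dependency guard: the import is attempted once at module load, and the name is bound to None when the package is missing. Dropping the # type: ignore suggests the project's type configuration now accepts this rebinding. A minimal sketch of the pattern, with the runtime check that usually accompanies it (the constructor and error message below are illustrative assumptions, not lines from this commit):

    # Optional-dependency guard, as in cohere.py and mistral.py above.
    try:
        import cohere
    except ImportError:
        cohere = None


    class CohereEmbeddings:
        def __init__(self) -> None:
            # Fail at construction time rather than import time when the
            # optional extra is not installed (assumed companion check).
            if cohere is None:
                raise ImportError(
                    "Could not import cohere. Install it with `pip install cohere`."
                )
            self.client = cohere.Client()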

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 0 additions & 6 deletions

@@ -251,9 +251,6 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
                 color="#4C8BF5",  # Blue for component nodes
                 caption_align=CaptionAlignment.CENTER,
                 caption_size=12,
-                pinned=False,
-                x=0,
-                y=0,
             )
         )
         node_counter += 1
@@ -270,9 +267,6 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
                 color="#34A853",  # Green for output nodes
                 caption_align=CaptionAlignment.CENTER,
                 caption_size=10,
-                pinned=False,
-                x=0,
-                y=0,
             )
         )
         # Connect component to its output
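The deleted keyword arguments pinned every node at the origin; without them, the visualization library is free to lay nodes out itself. A rough sketch of the node construction that remains, assuming neo4j-viz's Node model (both import paths below are assumptions, not confirmed by this commit):

    from neo4j_viz import Node  # assumed import path
    from neo4j_viz.node import CaptionAlignment  # assumed import path

    node = Node(
        id=0,  # hypothetical id; the pipeline tracks its own node_counter
        caption="my_component",  # hypothetical caption
        color="#4C8BF5",  # blue for component nodes, as in the diff
        caption_align=CaptionAlignment.CENTER,
        caption_size=12,
        # pinned=False, x=0, y=0 removed: the defaults let the layout
        # engine position nodes freely instead of stacking them at (0, 0).
    )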

src/neo4j_graphrag/llm/anthropic_llm.py

Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@ def get_messages(
            raise LLMGenerationError(e.errors()) from e
        messages.extend(cast(Iterable[dict[str, Any]], message_history))
        messages.append(UserMessage(content=input).model_dump())
-        return messages  # type: ignore
+        return messages

    def invoke(
        self,
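The same one-line change recurs in cohere_llm.py and ollama_llm.py below: once messages carries a concrete element type, mypy accepts the return statement without a suppression comment. A minimal, self-contained sketch of the idea (the list[dict[str, Any]] annotation is an assumption about how the real method is typed):

    from typing import Any, Iterable, cast


    def get_messages(
        message_history: Iterable[dict[str, Any]], input: str
    ) -> list[dict[str, Any]]:
        # With the declared return type matching the element type of
        # `messages`, the trailing "# type: ignore" becomes unnecessary.
        messages: list[dict[str, Any]] = []
        messages.extend(cast(Iterable[dict[str, Any]], message_history))
        messages.append({"role": "user", "content": input})
        return messages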

src/neo4j_graphrag/llm/cohere_llm.py

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@ def get_messages(
            raise LLMGenerationError(e.errors()) from e
        messages.extend(cast(Iterable[dict[str, Any]], message_history))
        messages.append(UserMessage(content=input).model_dump())
-        return messages  # type: ignore
+        return messages

    def invoke(
        self,

src/neo4j_graphrag/llm/mistralai_llm.py

Lines changed: 2 additions & 2 deletions

@@ -35,8 +35,8 @@
 try:
     from mistralai import Messages, Mistral
     from mistralai.models.sdkerror import SDKError
 except ImportError:
-    Mistral = None  # type: ignore
-    SDKError = None  # type: ignore
+    Mistral = None
+    SDKError = None


 class MistralAILLM(LLMInterface):

src/neo4j_graphrag/llm/ollama_llm.py

Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ def get_messages(
            raise LLMGenerationError(e.errors()) from e
        messages.extend(cast(Iterable[dict[str, Any]], message_history))
        messages.append(UserMessage(content=input).model_dump())
-        return messages  # type: ignore
+        return messages

    def invoke(
        self,

tests/unit/llm/test_anthropic_llm.py

Lines changed: 31 additions & 19 deletions

@@ -49,7 +49,7 @@ def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None:
     input_text = "may thy knife chip and shatter"
     response = llm.invoke(input_text)
     assert response.content == "generated text"
-    llm.client.messages.create.assert_called_once_with(  # type: ignore
+    llm.client.messages.create.assert_called_once_with(
         messages=[{"role": "user", "content": input_text}],
         model="claude-3-opus-20240229",
         system=anthropic.NOT_GIVEN,
@@ -66,16 +66,22 @@ def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock)
         "claude-3-opus-20240229",
         model_params=model_params,
     )
-    message_history = [
-        {"role": "user", "content": "When does the sun come up in the summer?"},
-        {"role": "assistant", "content": "Usually around 6am."},
-    ]
+    from neo4j_graphrag.message_history import InMemoryMessageHistory
+    from neo4j_graphrag.types import LLMMessage
+
+    message_history = InMemoryMessageHistory()
+    message_history.add_message(
+        LLMMessage(role="user", content="When does the sun come up in the summer?")
+    )
+    message_history.add_message(
+        LLMMessage(role="assistant", content="Usually around 6am.")
+    )
     question = "What about next season?"

-    response = llm.invoke(question, message_history)  # type: ignore
+    response = llm.invoke(question, message_history)
     assert response.content == "generated text"
-    message_history.append({"role": "user", "content": question})
-    llm.client.messages.create.assert_called_once_with(  # type: ignore[attr-defined]
+    message_history.add_message(LLMMessage(role="user", content=question))
+    llm.client.messages.create.assert_called_once_with(
         messages=message_history,
         model="claude-3-opus-20240229",
         system=anthropic.NOT_GIVEN,
@@ -101,14 +107,14 @@ def test_anthropic_invoke_with_system_instruction(
     assert isinstance(response, LLMResponse)
     assert response.content == "generated text"
     messages = [{"role": "user", "content": question}]
-    llm.client.messages.create.assert_called_with(  # type: ignore[attr-defined]
+    llm.client.messages.create.assert_called_with(
         model="claude-3-opus-20240229",
         system=system_instruction,
         messages=messages,
         **model_params,
     )

-    assert llm.client.messages.create.call_count == 1  # type: ignore
+    assert llm.client.messages.create.call_count == 1


 def test_anthropic_invoke_with_message_history_and_system_instruction(
@@ -123,24 +129,30 @@ def test_anthropic_invoke_with_message_history_and_system_instruction(
         "claude-3-opus-20240229",
         model_params=model_params,
     )
-    message_history = [
-        {"role": "user", "content": "When does the sun come up in the summer?"},
-        {"role": "assistant", "content": "Usually around 6am."},
-    ]
+    from neo4j_graphrag.message_history import InMemoryMessageHistory
+    from neo4j_graphrag.types import LLMMessage
+
+    message_history = InMemoryMessageHistory()
+    message_history.add_message(
+        LLMMessage(role="user", content="When does the sun come up in the summer?")
+    )
+    message_history.add_message(
+        LLMMessage(role="assistant", content="Usually around 6am.")
+    )

     question = "When does it come up in the winter?"
-    response = llm.invoke(question, message_history, system_instruction)  # type: ignore
+    response = llm.invoke(question, message_history, system_instruction)
     assert isinstance(response, LLMResponse)
     assert response.content == "generated text"
-    message_history.append({"role": "user", "content": question})
-    llm.client.messages.create.assert_called_with(  # type: ignore[attr-defined]
+    message_history.add_message(LLMMessage(role="user", content=question))
+    llm.client.messages.create.assert_called_with(
         model="claude-3-opus-20240229",
         system=system_instruction,
         messages=message_history,
         **model_params,
     )

-    assert llm.client.messages.create.call_count == 1  # type: ignore
+    assert llm.client.messages.create.call_count == 1


 def test_anthropic_invoke_with_message_history_validation_error(
@@ -178,7 +190,7 @@ async def test_anthropic_ainvoke_happy_path(mock_anthropic: Mock) -> None:
     input_text = "may thy knife chip and shatter"
     response = await llm.ainvoke(input_text)
     assert response.content == "Return text"
-    llm.async_client.messages.create.assert_awaited_once_with(  # type: ignore
+    llm.async_client.messages.create.assert_awaited_once_with(
         model="claude-3-opus-20240229",
         system=anthropic.NOT_GIVEN,
         messages=[{"role": "user", "content": input_text}],
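The recurring change in this test file replaces raw dict lists with the library's typed message history, which is what removes the need for # type: ignore at the invoke call sites. A condensed sketch of the new pattern outside the test harness (llm stands in for any LLMInterface implementation):

    from neo4j_graphrag.message_history import InMemoryMessageHistory
    from neo4j_graphrag.types import LLMMessage

    history = InMemoryMessageHistory()
    history.add_message(
        LLMMessage(role="user", content="When does the sun come up in the summer?")
    )
    history.add_message(LLMMessage(role="assistant", content="Usually around 6am."))

    # The history object is passed directly; no cast or ignore needed.
    response = llm.invoke("What about next season?", history)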

tests/unit/llm/test_mistralai_llm.py

Lines changed: 3 additions & 3 deletions

@@ -71,7 +71,7 @@ def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None:
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
-    llm.client.chat.complete.assert_called_once_with(  # type: ignore[attr-defined]
+    llm.client.chat.complete.assert_called_once_with(
         messages=messages,
         model=model,
     )
@@ -103,12 +103,12 @@ def test_mistralai_llm_invoke_with_message_history_and_system_instruction(
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
-    llm.client.chat.complete.assert_called_once_with(  # type: ignore[attr-defined]
+    llm.client.chat.complete.assert_called_once_with(
         messages=messages,
         model=model,
     )

-    assert llm.client.chat.complete.call_count == 1  # type: ignore
+    assert llm.client.chat.complete.call_count == 1


 @patch("neo4j_graphrag.llm.mistralai_llm.Mistral")

tests/unit/llm/test_ollama_llm.py

Lines changed: 5 additions & 5 deletions

@@ -55,7 +55,7 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None:
     messages = [
         {"role": "user", "content": question},
     ]
-    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
+    llm.client.chat.assert_called_once_with(
         model=model, messages=messages, options=model_params
     )

@@ -80,7 +80,7 @@ def test_ollama_invoke_with_system_instruction_happy_path(mock_import: Mock) ->
     assert response.content == "ollama chat response"
     messages = [{"role": "system", "content": system_instruction}]
     messages.append({"role": "user", "content": question})
-    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
+    llm.client.chat.assert_called_once_with(
         model=model, messages=messages, options=model_params
     )

@@ -108,7 +108,7 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> Non
     assert response.content == "ollama chat response"
     messages = [m for m in message_history]
     messages.append({"role": "user", "content": question})
-    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
+    llm.client.chat.assert_called_once_with(
         model=model, messages=messages, options=model_params
     )

@@ -144,10 +144,10 @@ def test_ollama_invoke_with_message_history_and_system_instruction(
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
-    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
+    llm.client.chat.assert_called_once_with(
         model=model, messages=messages, options=model_params
     )
-    assert llm.client.chat.call_count == 1  # type: ignore
+    assert llm.client.chat.call_count == 1


 @patch("builtins.__import__")
