Skip to content

Commit f2e95a5

Browse files
authored
add embedder tests (#430)
1 parent 6b85e92 commit f2e95a5

File tree

4 files changed

+415
-0
lines changed

4 files changed

+415
-0
lines changed

tests/embedder/embedder_fixtures.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""
2+
Copyright 2024, Zep Software, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
18+
def create_embedding_values(multiplier: float = 0.1, dimension: int = 1536) -> list[float]:
19+
"""Create embedding values with the specified multiplier and dimension."""
20+
return [multiplier] * dimension

tests/embedder/test_gemini.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""
2+
Copyright 2024, Zep Software, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from collections.abc import Generator
18+
from typing import Any
19+
from unittest.mock import AsyncMock, MagicMock, patch
20+
21+
import pytest
22+
23+
from graphiti_core.embedder.gemini import (
24+
DEFAULT_EMBEDDING_MODEL,
25+
GeminiEmbedder,
26+
GeminiEmbedderConfig,
27+
)
28+
from tests.embedder.embedder_fixtures import create_embedding_values
29+
30+
31+
def create_gemini_embedding(multiplier: float = 0.1) -> MagicMock:
32+
"""Create a mock Gemini embedding with specified value multiplier."""
33+
mock_embedding = MagicMock()
34+
mock_embedding.values = create_embedding_values(multiplier)
35+
return mock_embedding
36+
37+
38+
@pytest.fixture
39+
def mock_gemini_response() -> MagicMock:
40+
"""Create a mock Gemini embeddings response."""
41+
mock_result = MagicMock()
42+
mock_result.embeddings = [create_gemini_embedding()]
43+
return mock_result
44+
45+
46+
@pytest.fixture
47+
def mock_gemini_batch_response() -> MagicMock:
48+
"""Create a mock Gemini batch embeddings response."""
49+
mock_result = MagicMock()
50+
mock_result.embeddings = [
51+
create_gemini_embedding(0.1),
52+
create_gemini_embedding(0.2),
53+
create_gemini_embedding(0.3),
54+
]
55+
return mock_result
56+
57+
58+
@pytest.fixture
59+
def mock_gemini_client() -> Generator[Any, Any, None]:
60+
"""Create a mocked Gemini client."""
61+
with patch('google.genai.Client') as mock_client:
62+
mock_instance = mock_client.return_value
63+
mock_instance.aio = MagicMock()
64+
mock_instance.aio.models = MagicMock()
65+
mock_instance.aio.models.embed_content = AsyncMock()
66+
yield mock_instance
67+
68+
69+
@pytest.fixture
70+
def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
71+
"""Create a GeminiEmbedder with a mocked client."""
72+
config = GeminiEmbedderConfig(api_key='test_api_key')
73+
client = GeminiEmbedder(config=config)
74+
client.client = mock_gemini_client
75+
return client
76+
77+
78+
@pytest.mark.asyncio
79+
async def test_create_calls_api_correctly(
80+
gemini_embedder: GeminiEmbedder, mock_gemini_client: Any, mock_gemini_response: MagicMock
81+
) -> None:
82+
"""Test that create method correctly calls the API and processes the response."""
83+
# Setup
84+
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
85+
86+
# Call method
87+
result = await gemini_embedder.create('Test input')
88+
89+
# Verify API is called with correct parameters
90+
mock_gemini_client.aio.models.embed_content.assert_called_once()
91+
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
92+
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
93+
assert kwargs['contents'] == ['Test input']
94+
95+
# Verify result is processed correctly
96+
assert result == mock_gemini_response.embeddings[0].values
97+
98+
99+
@pytest.mark.asyncio
100+
async def test_create_batch_processes_multiple_inputs(
101+
gemini_embedder: GeminiEmbedder, mock_gemini_client: Any, mock_gemini_batch_response: MagicMock
102+
) -> None:
103+
"""Test that create_batch method correctly processes multiple inputs."""
104+
# Setup
105+
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response
106+
input_batch = ['Input 1', 'Input 2', 'Input 3']
107+
108+
# Call method
109+
result = await gemini_embedder.create_batch(input_batch)
110+
111+
# Verify API is called with correct parameters
112+
mock_gemini_client.aio.models.embed_content.assert_called_once()
113+
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
114+
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
115+
assert kwargs['contents'] == input_batch
116+
117+
# Verify all results are processed correctly
118+
assert len(result) == 3
119+
assert result == [
120+
mock_gemini_batch_response.embeddings[0].values,
121+
mock_gemini_batch_response.embeddings[1].values,
122+
mock_gemini_batch_response.embeddings[2].values,
123+
]
124+
125+
126+
if __name__ == '__main__':
127+
pytest.main(['-xvs', __file__])

tests/embedder/test_openai.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""
2+
Copyright 2024, Zep Software, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from collections.abc import Generator
18+
from typing import Any
19+
from unittest.mock import AsyncMock, MagicMock, patch
20+
21+
import pytest
22+
23+
from graphiti_core.embedder.openai import (
24+
DEFAULT_EMBEDDING_MODEL,
25+
OpenAIEmbedder,
26+
OpenAIEmbedderConfig,
27+
)
28+
from tests.embedder.embedder_fixtures import create_embedding_values
29+
30+
31+
def create_openai_embedding(multiplier: float = 0.1) -> MagicMock:
32+
"""Create a mock OpenAI embedding with specified value multiplier."""
33+
mock_embedding = MagicMock()
34+
mock_embedding.embedding = create_embedding_values(multiplier)
35+
return mock_embedding
36+
37+
38+
@pytest.fixture
39+
def mock_openai_response() -> MagicMock:
40+
"""Create a mock OpenAI embeddings response."""
41+
mock_result = MagicMock()
42+
mock_result.data = [create_openai_embedding()]
43+
return mock_result
44+
45+
46+
@pytest.fixture
47+
def mock_openai_batch_response() -> MagicMock:
48+
"""Create a mock OpenAI batch embeddings response."""
49+
mock_result = MagicMock()
50+
mock_result.data = [
51+
create_openai_embedding(0.1),
52+
create_openai_embedding(0.2),
53+
create_openai_embedding(0.3),
54+
]
55+
return mock_result
56+
57+
58+
@pytest.fixture
59+
def mock_openai_client() -> Generator[Any, Any, None]:
60+
"""Create a mocked OpenAI client."""
61+
with patch('openai.AsyncOpenAI') as mock_client:
62+
mock_instance = mock_client.return_value
63+
mock_instance.embeddings = MagicMock()
64+
mock_instance.embeddings.create = AsyncMock()
65+
yield mock_instance
66+
67+
68+
@pytest.fixture
69+
def openai_embedder(mock_openai_client: Any) -> OpenAIEmbedder:
70+
"""Create an OpenAIEmbedder with a mocked client."""
71+
config = OpenAIEmbedderConfig(api_key='test_api_key')
72+
client = OpenAIEmbedder(config=config)
73+
client.client = mock_openai_client
74+
return client
75+
76+
77+
@pytest.mark.asyncio
78+
async def test_create_calls_api_correctly(
79+
openai_embedder: OpenAIEmbedder, mock_openai_client: Any, mock_openai_response: MagicMock
80+
) -> None:
81+
"""Test that create method correctly calls the API and processes the response."""
82+
# Setup
83+
mock_openai_client.embeddings.create.return_value = mock_openai_response
84+
85+
# Call method
86+
result = await openai_embedder.create('Test input')
87+
88+
# Verify API is called with correct parameters
89+
mock_openai_client.embeddings.create.assert_called_once()
90+
_, kwargs = mock_openai_client.embeddings.create.call_args
91+
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
92+
assert kwargs['input'] == 'Test input'
93+
94+
# Verify result is processed correctly
95+
assert result == mock_openai_response.data[0].embedding[: openai_embedder.config.embedding_dim]
96+
97+
98+
@pytest.mark.asyncio
99+
async def test_create_batch_processes_multiple_inputs(
100+
openai_embedder: OpenAIEmbedder, mock_openai_client: Any, mock_openai_batch_response: MagicMock
101+
) -> None:
102+
"""Test that create_batch method correctly processes multiple inputs."""
103+
# Setup
104+
mock_openai_client.embeddings.create.return_value = mock_openai_batch_response
105+
input_batch = ['Input 1', 'Input 2', 'Input 3']
106+
107+
# Call method
108+
result = await openai_embedder.create_batch(input_batch)
109+
110+
# Verify API is called with correct parameters
111+
mock_openai_client.embeddings.create.assert_called_once()
112+
_, kwargs = mock_openai_client.embeddings.create.call_args
113+
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
114+
assert kwargs['input'] == input_batch
115+
116+
# Verify all results are processed correctly
117+
assert len(result) == 3
118+
assert result == [
119+
mock_openai_batch_response.data[0].embedding[: openai_embedder.config.embedding_dim],
120+
mock_openai_batch_response.data[1].embedding[: openai_embedder.config.embedding_dim],
121+
mock_openai_batch_response.data[2].embedding[: openai_embedder.config.embedding_dim],
122+
]
123+
124+
125+
if __name__ == '__main__':
126+
pytest.main(['-xvs', __file__])

0 commit comments

Comments
 (0)