Skip to content

Python: New tests for azure_cosmos_db_mongodb_collection and local_step #11518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
382955c
work in progress
Mar 18, 2025
9ce5b0e
Merge remote-tracking branch 'upstream/main' into new-tests
Mar 18, 2025
8cca51a
new tests for OllamaChatCompletion and OpenAIRealTime classes
Mar 19, 2025
f057ef2
fix
Mar 24, 2025
f1c1a85
updates to tests
Mar 28, 2025
3302266
fixes in tests
Mar 30, 2025
1f32149
fix pre-commit errors
Mar 31, 2025
ef1f0db
fix pre-commit errors
Mar 31, 2025
78f5d6b
fix pre-commit errors
Mar 31, 2025
4dc951b
merge conflict solved
Apr 1, 2025
8c4d3e5
pre-commit
Apr 1, 2025
61e04ed
pre-commit
Apr 1, 2025
e49a62a
pre-commit
Apr 1, 2025
6807474
Merge pull request #1 from gaudyb/new-tests
gaudyb Apr 2, 2025
aa2ac64
revert to original version
Apr 2, 2025
5b0f8c8
Merge remote-tracking branch 'upstream/main'
Apr 2, 2025
fb17b30
feedback implemented
Apr 3, 2025
d61993a
feedback implemented
Apr 3, 2025
47f7586
feedback implemented
Apr 4, 2025
268fb25
Merge remote-tracking branch 'upstream/main'
Apr 8, 2025
0cbde4f
new tests progress
Apr 9, 2025
c5e9eb1
new test
Apr 9, 2025
ac95200
new tests for azure_cosmos_db_mongodb_collection and local_step
Apr 10, 2025
129845a
Merge remote-tracking branch 'upstream/main' into new_tests
Apr 10, 2025
77d1ee6
Merge remote-tracking branch 'upstream/main'
Apr 10, 2025
3cfb325
Merge remote-tracking branch 'origin/main' into new_tests
Apr 10, 2025
bf129df
Merge pull request #2 from gaudyb/new_tests
gaudyb Apr 15, 2025
3df55df
feedback implemented
Apr 24, 2025
4d29984
merge conflict solved
Apr 24, 2025
6fcd98c
uv file updated
Apr 24, 2025
d140381
issues fixed
Apr 24, 2025
d2bbb0c
Merge remote-tracking branch 'upstream/main'
Apr 24, 2025
e9abbbd
Merge remote-tracking branch 'upstream/main'
Apr 25, 2025
d78304b
merge conflict solved
Apr 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# Copyright (c) Microsoft. All rights reserved.

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel, ValidationError
from pydantic_core import InitErrorDetails
from pymongo import AsyncMongoClient

import semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_mongodb_collection as cosmos_collection
import semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_mongodb_settings as cosmos_settings
from semantic_kernel.data.const import DistanceFunction, IndexKind
from semantic_kernel.data.record_definition import (
VectorStoreRecordDataField,
VectorStoreRecordDefinition,
VectorStoreRecordKeyField,
VectorStoreRecordVectorField,
)
from semantic_kernel.exceptions import VectorStoreInitializationException


async def test_constructor_with_mongo_client_provided() -> None:
"""
Test the constructor of AzureCosmosDBforMongoDBCollection when a mongo_client
is directly provided. Expect that the class is successfully initialized
and doesn't attempt to manage the client.
"""
mock_client = AsyncMock(spec=AsyncMongoClient)
collection_name = "test_collection"
fake_definition = VectorStoreRecordDefinition(
fields={
"id": VectorStoreRecordKeyField(),
"content": VectorStoreRecordDataField(),
"vector": VectorStoreRecordVectorField(),
}
)

collection = cosmos_collection.AzureCosmosDBforMongoDBCollection(
collection_name=collection_name,
data_model_type=dict,
mongo_client=mock_client,
data_model_definition=fake_definition,
)

assert collection.mongo_client == mock_client
assert collection.collection_name == collection_name
assert not collection.managed_client, "Should not be managing client when provided"


async def test_constructor_without_mongo_client_success() -> None:
"""
Test the constructor of AzureCosmosDBforMongoDBCollection when a mongo_client
is not provided. Expect it to create settings and initialize an AsyncMongoClient.
"""
mock_data_model_definition = VectorStoreRecordDefinition(
fields={
"id": VectorStoreRecordKeyField(),
"content": VectorStoreRecordDataField(),
"vector": VectorStoreRecordVectorField(),
}
)

fake_client = AsyncMock(name="fake_async_client", spec=AsyncMongoClient)

with (
patch.object(
cosmos_settings.AzureCosmosDBforMongoDBSettings,
"create",
return_value=AsyncMock(
connection_string=AsyncMock(get_secret_value=lambda: "mongodb://test"), database_name="test_db"
),
) as mock_settings_create,
patch(
"semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_mongodb_collection.AsyncMongoClient",
return_value=fake_client,
spec=AsyncMongoClient,
) as mock_async_client,
):
collection = cosmos_collection.AzureCosmosDBforMongoDBCollection(
collection_name="test_collection",
data_model_type=dict,
data_model_definition=mock_data_model_definition,
connection_string="mongodb://test-env",
database_name="",
)

mock_settings_create.assert_called_once()
mock_async_client.assert_called_once()

created_client = mock_async_client.return_value

assert collection.mongo_client == created_client
assert collection.managed_client, "Should manage client when none is provided"
assert collection.database_name == "test_db"


async def test_constructor_raises_exception_on_validation_error() -> None:
"""
Test that the constructor raises VectorStoreInitializationException when
AzureCosmosDBforMongoDBSettings.create fails with ValidationError.
"""

mock_data_model_definition = VectorStoreRecordDefinition(
fields={
"id": VectorStoreRecordKeyField(),
"content": VectorStoreRecordDataField(),
"vector": VectorStoreRecordVectorField(),
}
)

class DummyModel(BaseModel):
connection_string: str

error = InitErrorDetails(
type="missing",
loc=("connection_string",),
msg="Field required",
input=None,
) # type: ignore

validation_error = ValidationError.from_exception_data("DummyModel", [error])

with patch.object(
cosmos_settings.AzureCosmosDBforMongoDBSettings,
"create",
side_effect=validation_error,
):
with pytest.raises(VectorStoreInitializationException) as exc_info:
cosmos_collection.AzureCosmosDBforMongoDBCollection(
collection_name="test_collection",
data_model_type=dict,
data_model_definition=mock_data_model_definition,
connection_string="mongodb://test-env",
database_name="",
)
assert "Failed to create Azure CosmosDB for MongoDB settings." in str(exc_info.value)


async def test_constructor_raises_exception_if_no_connection_string() -> None:
"""
Ensure that a VectorStoreInitializationException is raised if the
AzureCosmosDBforMongoDBSettings.connection_string is None.
"""
# Mock settings without a connection string
mock_settings = AsyncMock(spec=cosmos_settings.AzureCosmosDBforMongoDBSettings)
mock_settings.connection_string = None
mock_settings.database_name = "some_database"

with patch.object(cosmos_settings.AzureCosmosDBforMongoDBSettings, "create", return_value=mock_settings):
with pytest.raises(VectorStoreInitializationException) as exc_info:
cosmos_collection.AzureCosmosDBforMongoDBCollection(collection_name="test_collection", data_model_type=dict)
assert "The Azure CosmosDB for MongoDB connection string is required." in str(exc_info.value)


async def test_create_collection_calls_database_methods() -> None:
"""
Test create_collection to verify that it first creates a collection, then
calls the appropriate command to create a vector index.
"""
# Setup
mock_database = AsyncMock()
mock_database.create_collection = AsyncMock()
mock_database.command = AsyncMock()

mock_client = AsyncMock(spec=AsyncMongoClient)
mock_client.get_database = MagicMock(return_value=mock_database)

mock_data_model_definition = AsyncMock(spec=VectorStoreRecordDefinition)
# Simulate a data_model_definition with certain fields & vector_fields
mock_field = AsyncMock(spec=VectorStoreRecordDataField)
type(mock_field).name = "test_field"
type(mock_field).is_filterable = True
type(mock_field).is_full_text_searchable = True

type(mock_field).property_type = "str"

mock_vector_field = AsyncMock()
type(mock_vector_field).dimensions = 128
type(mock_vector_field).name = "embedding"
type(mock_vector_field).distance_function = DistanceFunction.COSINE_SIMILARITY
type(mock_vector_field).index_kind = IndexKind.IVF_FLAT
type(mock_vector_field).property_type = "float"

mock_data_model_definition.fields = {"test_field": mock_field}
mock_data_model_definition.vector_fields = [mock_vector_field]
mock_data_model_definition.key_field = mock_field

# Instantiate
collection = cosmos_collection.AzureCosmosDBforMongoDBCollection(
collection_name="test_collection",
data_model_type=dict,
data_model_definition=mock_data_model_definition,
mongo_client=mock_client,
database_name="test_db",
)

# Act
await collection.create_collection(customArg="customValue")

# Assert
mock_database.create_collection.assert_awaited_once_with("test_collection", customArg="customValue")
mock_database.command.assert_awaited()
command_args = mock_database.command.call_args.kwargs["command"]

assert command_args["createIndexes"] == "test_collection"
assert len(command_args["indexes"]) == 2, "One for the data field, one for the vector field"
# Check the data field index
assert command_args["indexes"][0]["name"] == "test_field_"
# Check the vector field index creation
assert command_args["indexes"][1]["name"] == "embedding_"
assert command_args["indexes"][1]["key"] == {"embedding": "cosmosSearch"}
assert command_args["indexes"][1]["cosmosSearchOptions"]["kind"] == "vector-ivf"
assert command_args["indexes"][1]["cosmosSearchOptions"]["similarity"] is not None
assert command_args["indexes"][1]["cosmosSearchOptions"]["dimensions"] == 128


async def test_context_manager_calls_aconnect_and_close_when_managed() -> None:
"""
Test that the context manager in AzureCosmosDBforMongoDBCollection calls 'aconnect' and
'close' when the client is managed (i.e., created internally).
"""
mock_client = AsyncMock(spec=AsyncMongoClient)

mock_data_model_definition = VectorStoreRecordDefinition(
fields={
"id": VectorStoreRecordKeyField(),
"content": VectorStoreRecordDataField(),
"vector": VectorStoreRecordVectorField(),
}
)

with patch(
"semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_mongodb_collection.AsyncMongoClient",
return_value=mock_client,
):
collection = cosmos_collection.AzureCosmosDBforMongoDBCollection(
collection_name="test_collection",
data_model_type=dict,
connection_string="mongodb://fake",
data_model_definition=mock_data_model_definition,
)

# "__aenter__" should call 'aconnect'
async with collection as c:
mock_client.aconnect.assert_awaited_once()
assert c is collection

# "__aexit__" should call 'close' if managed
mock_client.close.assert_awaited_once()


async def test_context_manager_does_not_close_when_not_managed() -> None:
"""
Test that the context manager in AzureCosmosDBforMongoDBCollection does not call 'close'
when the client is not managed (i.e., provided externally).
"""
mock_data_model_definition = VectorStoreRecordDefinition(
fields={
"id": VectorStoreRecordKeyField(),
"content": VectorStoreRecordDataField(),
"vector": VectorStoreRecordVectorField(),
}
)

external_client = AsyncMock(spec=AsyncMongoClient, name="external_client", value=None)
external_client.aconnect = AsyncMock(name="aconnect")
external_client.close = AsyncMock(name="close")

collection = cosmos_collection.AzureCosmosDBforMongoDBCollection(
collection_name="test_collection",
data_model_type=dict,
mongo_client=external_client,
data_model_definition=mock_data_model_definition,
)

# "__aenter__" scenario
async with collection as c:
external_client.aconnect.assert_awaited()
assert c is collection

# "__aexit__" should NOT call "close" when not managed
external_client.close.assert_not_awaited()
Loading
Loading