diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs
index 70d6210fc355..d9d5b67ee4af 100644
--- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs
+++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs
@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
+using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
@@ -22,11 +23,62 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL;
///
public class AzureCosmosDBNoSQLMemoryStore : IMemoryStore, IDisposable
{
+ private const string EmbeddingPath = "/embedding";
+
private readonly CosmosClient _cosmosClient;
private readonly VectorEmbeddingPolicy _vectorEmbeddingPolicy;
private readonly IndexingPolicy _indexingPolicy;
private readonly string _databaseName;
+ ///
+ /// Initiates a AzureCosmosDBNoSQLMemoryStore instance using a Azure Cosmos DB connection string
+ /// and other properties required for vector search.
+ ///
+ /// Connection string required to connect to Azure Cosmos DB.
+ /// The database name to connect to.
+ /// The number of dimensions the embedding vectors to be stored.
+ /// The data type of the embedding vectors to be stored.
+ /// The type of index to use for the embedding vectors to be stored.
+ /// The application name to use in requests.
+ public AzureCosmosDBNoSQLMemoryStore(
+ string connectionString,
+ string databaseName,
+ ulong dimensions,
+ VectorDataType vectorDataType,
+ VectorIndexType vectorIndexType,
+ string? applicationName = null)
+ : this(
+ new CosmosClient(
+ connectionString,
+ new CosmosClientOptions
+ {
+ ApplicationName = applicationName ?? HttpHeaderConstant.Values.UserAgent,
+ Serializer = new CosmosSystemTextJsonSerializer(JsonSerializerOptions.Default),
+ }),
+ databaseName,
+ new VectorEmbeddingPolicy(
+ [
+ new Embedding
+ {
+ DataType = vectorDataType,
+ Dimensions = dimensions,
+ DistanceFunction = DistanceFunction.Cosine,
+ Path = EmbeddingPath,
+ }
+ ]),
+ new IndexingPolicy
+ {
+ VectorIndexes = new Collection {
+ new()
+ {
+ Path = EmbeddingPath,
+ Type = vectorIndexType,
+ },
+ },
+ })
+ {
+ }
+
///
/// Initiates a AzureCosmosDBNoSQLMemoryStore instance using a Azure Cosmos DB connection string
/// and other properties required for vector search.
@@ -71,14 +123,29 @@ public AzureCosmosDBNoSQLMemoryStore(
VectorEmbeddingPolicy vectorEmbeddingPolicy,
IndexingPolicy indexingPolicy)
{
- if (!vectorEmbeddingPolicy.Embeddings.Any(e => e.Path == "/embedding"))
+ var embedding = vectorEmbeddingPolicy.Embeddings.FirstOrDefault(e => e.Path == EmbeddingPath);
+ if (embedding is null)
{
throw new InvalidOperationException($"""
In order for {nameof(GetNearestMatchAsync)} to function, {nameof(vectorEmbeddingPolicy)} should
- contain an embedding path at /embedding. It's also recommended to include a that path in the
+ contain an embedding path at {EmbeddingPath}. It's also recommended to include that path in the
{nameof(indexingPolicy)} to improve performance and reduce cost for searches.
""");
}
+ else if (embedding.DistanceFunction != DistanceFunction.Cosine)
+ {
+ throw new InvalidOperationException($"""
+ In order for {nameof(GetNearestMatchAsync)} to reliably return relevance information, the {nameof(DistanceFunction)} should
+ be specified as {nameof(DistanceFunction)}.{nameof(DistanceFunction.Cosine)}.
+ """);
+ }
+ else if (embedding.DataType != VectorDataType.Float16 && embedding.DataType != VectorDataType.Float32)
+ {
+ throw new NotSupportedException($"""
+ Only {nameof(VectorDataType)}.{nameof(VectorDataType.Float16)} and {nameof(VectorDataType)}.{nameof(VectorDataType.Float32)}
+ are supported.
+ """);
+ }
this._cosmosClient = cosmosClient;
this._databaseName = databaseName;
this._vectorEmbeddingPolicy = vectorEmbeddingPolicy;
@@ -164,6 +231,12 @@ public async Task UpsertAsync(
MemoryRecord record,
CancellationToken cancellationToken = default)
{
+ // In some cases we're expected to generate the key to use. Do so if one isn't provided.
+ if (string.IsNullOrEmpty(record.Key))
+ {
+ record.Key = Guid.NewGuid().ToString();
+ }
+
var result = await this._cosmosClient
.GetDatabase(this._databaseName)
.GetContainer(collectionName)
@@ -193,6 +266,7 @@ public async IAsyncEnumerable UpsertBatchAsync(
bool withEmbedding = false,
CancellationToken cancellationToken = default)
{
+ // TODO: Consider using a query when `withEmbedding` is false to avoid passing it over the wire.
var result = await this._cosmosClient
.GetDatabase(this._databaseName)
.GetContainer(collectionName)
@@ -330,9 +404,10 @@ ORDER BY VectorDistance(x.embedding, @embedding)
{
foreach (var memoryRecord in await feedIterator.ReadNextAsync(cancellationToken).ConfigureAwait(false))
{
- if (memoryRecord.SimilarityScore >= minRelevanceScore)
+ var relevanceScore = (memoryRecord.SimilarityScore + 1) / 2;
+ if (relevanceScore >= minRelevanceScore)
{
- yield return (memoryRecord, memoryRecord.SimilarityScore);
+ yield return (memoryRecord, relevanceScore);
}
}
}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs
index 0e8aee320856..e75116e34893 100644
--- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs
@@ -1,9 +1,14 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
using System.Linq;
+using System.Threading;
using System.Threading.Tasks;
+using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL;
+using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.Memory;
using MongoDB.Driver;
using Xunit;
@@ -117,6 +122,54 @@ public async Task ItCanGetNearestMatchesAsync(int limit, bool withEmbeddings)
await memoryStore.DeleteCollectionAsync(collectionName);
}
+ [Theory(Skip = SkipReason)]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task ItCanSaveReferenceGetTextAndSearchTextAsync(bool withEmbedding)
+ {
+ var collectionName = this._fixture.CollectionName;
+ var memoryStore = this._fixture.MemoryStore;
+ var textMemory = new SemanticTextMemory(memoryStore, new MockTextEmbeddingGenerationService());
+ var textToStore = "SampleText";
+ var id = "MyExternalId";
+ var source = "MyExternalSource";
+ var refId = await textMemory.SaveReferenceAsync(collectionName, textToStore, id, source);
+ Assert.NotNull(refId);
+
+ var expectedQueryResult = new MemoryQueryResult(
+ new MemoryRecordMetadata(isReference: true, id, text: "", description: "", source, additionalMetadata: ""),
+ 1.0,
+ withEmbedding ? DataHelper.VectorSearchTestEmbedding : null);
+
+ var queryResult = await textMemory.GetAsync(collectionName, refId, withEmbedding);
+ AssertQueryResultEqual(expectedQueryResult, queryResult, withEmbedding);
+
+ var searchResults = await textMemory.SearchAsync(collectionName, textToStore, withEmbeddings: withEmbedding).ToListAsync();
+ Assert.Equal(1, searchResults?.Count);
+ AssertQueryResultEqual(expectedQueryResult, searchResults?[0], compareEmbeddings: true);
+
+ await textMemory.RemoveAsync(collectionName, refId);
+ }
+
+ private static void AssertQueryResultEqual(MemoryQueryResult expected, MemoryQueryResult? actual, bool compareEmbeddings)
+ {
+ Assert.NotNull(actual);
+ Assert.Equal(expected.Relevance, actual.Relevance);
+ Assert.Equal(expected.Metadata.Id, actual.Metadata.Id);
+ Assert.Equal(expected.Metadata.Text, actual.Metadata.Text);
+ Assert.Equal(expected.Metadata.Description, actual.Metadata.Description);
+ Assert.Equal(expected.Metadata.ExternalSourceName, actual.Metadata.ExternalSourceName);
+ Assert.Equal(expected.Metadata.AdditionalMetadata, actual.Metadata.AdditionalMetadata);
+ Assert.Equal(expected.Metadata.IsReference, actual.Metadata.IsReference);
+
+ if (compareEmbeddings)
+ {
+ Assert.NotNull(expected.Embedding);
+ Assert.NotNull(actual.Embedding);
+ Assert.Equal(expected.Embedding.Value.Span, actual.Embedding.Value.Span);
+ }
+ }
+
private static void AssertMemoryRecordEqual(
MemoryRecord expectedRecord,
MemoryRecord actualRecord,
@@ -147,4 +200,15 @@ private static void AssertMemoryRecordEqual(
Assert.True(actualRecord.Embedding.Span.IsEmpty);
}
}
+
+ private sealed class MockTextEmbeddingGenerationService : ITextEmbeddingGenerationService
+ {
+ public IReadOnlyDictionary Attributes { get; } = ReadOnlyDictionary.Empty;
+
+ public Task>> GenerateEmbeddingsAsync(IList data, Kernel? kernel = null, CancellationToken cancellationToken = default)
+ {
+ IList> result = new List> { DataHelper.VectorSearchTestEmbedding };
+ return Task.FromResult(result);
+ }
+ }
}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs
index 93cbea170f40..1df46166e63f 100644
--- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
-using System.Collections.ObjectModel;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos;
using Microsoft.Extensions.Configuration;
@@ -35,28 +34,9 @@ public AzureCosmosDBNoSQLMemoryStoreTestsFixture()
this.MemoryStore = new AzureCosmosDBNoSQLMemoryStore(
connectionString,
this.DatabaseName,
- new VectorEmbeddingPolicy(
- new Collection
- {
- new()
- {
- DataType = VectorDataType.Float32,
- Dimensions = 3,
- DistanceFunction = DistanceFunction.Cosine,
- Path = "/embedding"
- }
- }),
- new()
- {
- VectorIndexes = new Collection {
- new()
- {
- Path = "/embedding",
- Type = VectorIndexType.Flat,
- },
- },
- }
- );
+ dimensions: 3,
+ VectorDataType.Float32,
+ VectorIndexType.Flat);
}
public Task InitializeAsync()