diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs index 70d6210fc355..d9d5b67ee4af 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Diagnostics; using System.Linq; using System.Runtime.CompilerServices; @@ -22,11 +23,62 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; /// public class AzureCosmosDBNoSQLMemoryStore : IMemoryStore, IDisposable { + private const string EmbeddingPath = "/embedding"; + private readonly CosmosClient _cosmosClient; private readonly VectorEmbeddingPolicy _vectorEmbeddingPolicy; private readonly IndexingPolicy _indexingPolicy; private readonly string _databaseName; + /// + /// Initiates a AzureCosmosDBNoSQLMemoryStore instance using a Azure Cosmos DB connection string + /// and other properties required for vector search. + /// + /// Connection string required to connect to Azure Cosmos DB. + /// The database name to connect to. + /// The number of dimensions the embedding vectors to be stored. + /// The data type of the embedding vectors to be stored. + /// The type of index to use for the embedding vectors to be stored. + /// The application name to use in requests. + public AzureCosmosDBNoSQLMemoryStore( + string connectionString, + string databaseName, + ulong dimensions, + VectorDataType vectorDataType, + VectorIndexType vectorIndexType, + string? applicationName = null) + : this( + new CosmosClient( + connectionString, + new CosmosClientOptions + { + ApplicationName = applicationName ?? HttpHeaderConstant.Values.UserAgent, + Serializer = new CosmosSystemTextJsonSerializer(JsonSerializerOptions.Default), + }), + databaseName, + new VectorEmbeddingPolicy( + [ + new Embedding + { + DataType = vectorDataType, + Dimensions = dimensions, + DistanceFunction = DistanceFunction.Cosine, + Path = EmbeddingPath, + } + ]), + new IndexingPolicy + { + VectorIndexes = new Collection { + new() + { + Path = EmbeddingPath, + Type = vectorIndexType, + }, + }, + }) + { + } + /// /// Initiates a AzureCosmosDBNoSQLMemoryStore instance using a Azure Cosmos DB connection string /// and other properties required for vector search. @@ -71,14 +123,29 @@ public AzureCosmosDBNoSQLMemoryStore( VectorEmbeddingPolicy vectorEmbeddingPolicy, IndexingPolicy indexingPolicy) { - if (!vectorEmbeddingPolicy.Embeddings.Any(e => e.Path == "/embedding")) + var embedding = vectorEmbeddingPolicy.Embeddings.FirstOrDefault(e => e.Path == EmbeddingPath); + if (embedding is null) { throw new InvalidOperationException($""" In order for {nameof(GetNearestMatchAsync)} to function, {nameof(vectorEmbeddingPolicy)} should - contain an embedding path at /embedding. It's also recommended to include a that path in the + contain an embedding path at {EmbeddingPath}. It's also recommended to include that path in the {nameof(indexingPolicy)} to improve performance and reduce cost for searches. """); } + else if (embedding.DistanceFunction != DistanceFunction.Cosine) + { + throw new InvalidOperationException($""" + In order for {nameof(GetNearestMatchAsync)} to reliably return relevance information, the {nameof(DistanceFunction)} should + be specified as {nameof(DistanceFunction)}.{nameof(DistanceFunction.Cosine)}. + """); + } + else if (embedding.DataType != VectorDataType.Float16 && embedding.DataType != VectorDataType.Float32) + { + throw new NotSupportedException($""" + Only {nameof(VectorDataType)}.{nameof(VectorDataType.Float16)} and {nameof(VectorDataType)}.{nameof(VectorDataType.Float32)} + are supported. + """); + } this._cosmosClient = cosmosClient; this._databaseName = databaseName; this._vectorEmbeddingPolicy = vectorEmbeddingPolicy; @@ -164,6 +231,12 @@ public async Task UpsertAsync( MemoryRecord record, CancellationToken cancellationToken = default) { + // In some cases we're expected to generate the key to use. Do so if one isn't provided. + if (string.IsNullOrEmpty(record.Key)) + { + record.Key = Guid.NewGuid().ToString(); + } + var result = await this._cosmosClient .GetDatabase(this._databaseName) .GetContainer(collectionName) @@ -193,6 +266,7 @@ public async IAsyncEnumerable UpsertBatchAsync( bool withEmbedding = false, CancellationToken cancellationToken = default) { + // TODO: Consider using a query when `withEmbedding` is false to avoid passing it over the wire. var result = await this._cosmosClient .GetDatabase(this._databaseName) .GetContainer(collectionName) @@ -330,9 +404,10 @@ ORDER BY VectorDistance(x.embedding, @embedding) { foreach (var memoryRecord in await feedIterator.ReadNextAsync(cancellationToken).ConfigureAwait(false)) { - if (memoryRecord.SimilarityScore >= minRelevanceScore) + var relevanceScore = (memoryRecord.SimilarityScore + 1) / 2; + if (relevanceScore >= minRelevanceScore) { - yield return (memoryRecord, memoryRecord.SimilarityScore); + yield return (memoryRecord, relevanceScore); } } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs index 0e8aee320856..e75116e34893 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs @@ -1,9 +1,14 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; +using System.Threading; using System.Threading.Tasks; +using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; +using Microsoft.SemanticKernel.Embeddings; using Microsoft.SemanticKernel.Memory; using MongoDB.Driver; using Xunit; @@ -117,6 +122,54 @@ public async Task ItCanGetNearestMatchesAsync(int limit, bool withEmbeddings) await memoryStore.DeleteCollectionAsync(collectionName); } + [Theory(Skip = SkipReason)] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanSaveReferenceGetTextAndSearchTextAsync(bool withEmbedding) + { + var collectionName = this._fixture.CollectionName; + var memoryStore = this._fixture.MemoryStore; + var textMemory = new SemanticTextMemory(memoryStore, new MockTextEmbeddingGenerationService()); + var textToStore = "SampleText"; + var id = "MyExternalId"; + var source = "MyExternalSource"; + var refId = await textMemory.SaveReferenceAsync(collectionName, textToStore, id, source); + Assert.NotNull(refId); + + var expectedQueryResult = new MemoryQueryResult( + new MemoryRecordMetadata(isReference: true, id, text: "", description: "", source, additionalMetadata: ""), + 1.0, + withEmbedding ? DataHelper.VectorSearchTestEmbedding : null); + + var queryResult = await textMemory.GetAsync(collectionName, refId, withEmbedding); + AssertQueryResultEqual(expectedQueryResult, queryResult, withEmbedding); + + var searchResults = await textMemory.SearchAsync(collectionName, textToStore, withEmbeddings: withEmbedding).ToListAsync(); + Assert.Equal(1, searchResults?.Count); + AssertQueryResultEqual(expectedQueryResult, searchResults?[0], compareEmbeddings: true); + + await textMemory.RemoveAsync(collectionName, refId); + } + + private static void AssertQueryResultEqual(MemoryQueryResult expected, MemoryQueryResult? actual, bool compareEmbeddings) + { + Assert.NotNull(actual); + Assert.Equal(expected.Relevance, actual.Relevance); + Assert.Equal(expected.Metadata.Id, actual.Metadata.Id); + Assert.Equal(expected.Metadata.Text, actual.Metadata.Text); + Assert.Equal(expected.Metadata.Description, actual.Metadata.Description); + Assert.Equal(expected.Metadata.ExternalSourceName, actual.Metadata.ExternalSourceName); + Assert.Equal(expected.Metadata.AdditionalMetadata, actual.Metadata.AdditionalMetadata); + Assert.Equal(expected.Metadata.IsReference, actual.Metadata.IsReference); + + if (compareEmbeddings) + { + Assert.NotNull(expected.Embedding); + Assert.NotNull(actual.Embedding); + Assert.Equal(expected.Embedding.Value.Span, actual.Embedding.Value.Span); + } + } + private static void AssertMemoryRecordEqual( MemoryRecord expectedRecord, MemoryRecord actualRecord, @@ -147,4 +200,15 @@ private static void AssertMemoryRecordEqual( Assert.True(actualRecord.Embedding.Span.IsEmpty); } } + + private sealed class MockTextEmbeddingGenerationService : ITextEmbeddingGenerationService + { + public IReadOnlyDictionary Attributes { get; } = ReadOnlyDictionary.Empty; + + public Task>> GenerateEmbeddingsAsync(IList data, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + IList> result = new List> { DataHelper.VectorSearchTestEmbedding }; + return Task.FromResult(result); + } + } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs index 93cbea170f40..1df46166e63f 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.ObjectModel; using System.Threading.Tasks; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.Configuration; @@ -35,28 +34,9 @@ public AzureCosmosDBNoSQLMemoryStoreTestsFixture() this.MemoryStore = new AzureCosmosDBNoSQLMemoryStore( connectionString, this.DatabaseName, - new VectorEmbeddingPolicy( - new Collection - { - new() - { - DataType = VectorDataType.Float32, - Dimensions = 3, - DistanceFunction = DistanceFunction.Cosine, - Path = "/embedding" - } - }), - new() - { - VectorIndexes = new Collection { - new() - { - Path = "/embedding", - Type = VectorIndexType.Flat, - }, - }, - } - ); + dimensions: 3, + VectorDataType.Float32, + VectorIndexType.Flat); } public Task InitializeAsync()