diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 69548fef9058..282a4a768b59 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -87,6 +87,7 @@ + diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index acf61a114486..6517eeca2b29 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -107,11 +107,14 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{958AD708-F04 EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Diagnostics", "Diagnostics", "{29E7D971-1308-4171-9872-E8E4669A1134}" ProjectSection(SolutionItems) = preProject + src\InternalUtilities\src\Diagnostics\ActivityExtensions.cs = src\InternalUtilities\src\Diagnostics\ActivityExtensions.cs src\InternalUtilities\src\Diagnostics\CompilerServicesAttributes.cs = src\InternalUtilities\src\Diagnostics\CompilerServicesAttributes.cs src\InternalUtilities\src\Diagnostics\DynamicallyAccessedMembersAttribute.cs = src\InternalUtilities\src\Diagnostics\DynamicallyAccessedMembersAttribute.cs src\InternalUtilities\src\Diagnostics\ExceptionExtensions.cs = src\InternalUtilities\src\Diagnostics\ExceptionExtensions.cs src\InternalUtilities\src\Diagnostics\ExperimentalAttribute.cs = src\InternalUtilities\src\Diagnostics\ExperimentalAttribute.cs src\InternalUtilities\src\Diagnostics\IsExternalInit.cs = src\InternalUtilities\src\Diagnostics\IsExternalInit.cs + src\InternalUtilities\src\Diagnostics\KernelVerify.cs = src\InternalUtilities\src\Diagnostics\KernelVerify.cs + src\InternalUtilities\src\Diagnostics\LoggingExtensions.cs = src\InternalUtilities\src\Diagnostics\LoggingExtensions.cs src\InternalUtilities\src\Diagnostics\NullableAttributes.cs = src\InternalUtilities\src\Diagnostics\NullableAttributes.cs src\InternalUtilities\src\Diagnostics\RequiresDynamicCodeAttribute.cs = src\InternalUtilities\src\Diagnostics\RequiresDynamicCodeAttribute.cs src\InternalUtilities\src\Diagnostics\RequiresUnreferencedCodeAttribute.cs = src\InternalUtilities\src\Diagnostics\RequiresUnreferencedCodeAttribute.cs @@ -547,6 +550,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Runtime.InProcess", "src\Ag EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Runtime.InProcess.Tests", "src\Agents\Runtime\InProcess.Tests\Runtime.InProcess.Tests.csproj", "{DA6B4ED4-ED0B-D25C-889C-9F940E714891}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "VectorData.UnitTests", "src\Connectors\VectorData.UnitTests\VectorData.UnitTests.csproj", "{AAC7B5E8-CC4E-49D0-AF6A-2B4F7B43BD84}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -1508,6 +1513,12 @@ Global {DA6B4ED4-ED0B-D25C-889C-9F940E714891}.Publish|Any CPU.Build.0 = Release|Any CPU {DA6B4ED4-ED0B-D25C-889C-9F940E714891}.Release|Any CPU.ActiveCfg = Release|Any CPU {DA6B4ED4-ED0B-D25C-889C-9F940E714891}.Release|Any CPU.Build.0 = Release|Any CPU + {AAC7B5E8-CC4E-49D0-AF6A-2B4F7B43BD84}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AAC7B5E8-CC4E-49D0-AF6A-2B4F7B43BD84}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AAC7B5E8-CC4E-49D0-AF6A-2B4F7B43BD84}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {AAC7B5E8-CC4E-49D0-AF6A-2B4F7B43BD84}.Publish|Any CPU.Build.0 = Debug|Any CPU + {AAC7B5E8-CC4E-49D0-AF6A-2B4F7B43BD84}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AAC7B5E8-CC4E-49D0-AF6A-2B4F7B43BD84}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1713,6 +1724,7 @@ Global {A4F05541-7D23-A5A9-033D-382F1E13D0FE} = {A70ED5A7-F8E1-4A57-9455-3C05989542DA} {CCC909E4-5269-A31E-0BFD-4863B4B29BBB} = {A70ED5A7-F8E1-4A57-9455-3C05989542DA} {DA6B4ED4-ED0B-D25C-889C-9F940E714891} = {A70ED5A7-F8E1-4A57-9455-3C05989542DA} + {AAC7B5E8-CC4E-49D0-AF6A-2B4F7B43BD84} = {5A7028A7-4DDF-4E4F-84A9-37CE8F8D7E89} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} diff --git a/dotnet/docs/EXPERIMENTS.md b/dotnet/docs/EXPERIMENTS.md index 99fd9b56afb4..1e143695e0f0 100644 --- a/dotnet/docs/EXPERIMENTS.md +++ b/dotnet/docs/EXPERIMENTS.md @@ -25,6 +25,8 @@ You can use the following diagnostic IDs to ignore warnings or errors for a part | SKEXP0100 | Advanced Semantic Kernel features | | SKEXP0110 | Semantic Kernel Agents | | SKEXP0120 | Native-AOT | +| MEVD9000 | Microsoft.Extensions.VectorData experimental user-facing APIs | +| MEVD9001 | Microsoft.Extensions.VectorData experimental connector-facing APIs | ## Experimental Features Tracking diff --git a/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs index 78c54df49434..4d9eb15d36a2 100644 --- a/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs +++ b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs @@ -199,8 +199,8 @@ public async Task OnPromptRenderAsync(PromptRenderContext context, Funcfalse true - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101,SKEXP0110,OPENAI001,CA1724 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101,SKEXP0110,OPENAI001,CA1724,MEVD9000 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/Concepts/Memory/HuggingFace_TextEmbeddingCustomHttpHandler.cs b/dotnet/samples/Concepts/Memory/HuggingFace_TextEmbeddingCustomHttpHandler.cs index 744274d4c527..d9ebbf568c3a 100644 --- a/dotnet/samples/Concepts/Memory/HuggingFace_TextEmbeddingCustomHttpHandler.cs +++ b/dotnet/samples/Concepts/Memory/HuggingFace_TextEmbeddingCustomHttpHandler.cs @@ -15,6 +15,7 @@ namespace Memory; /// For example, the cointegrated/LaBSE-en-ru model returns results as a 1 * 1 * 4 * 768 matrix, which is different from Hugging Face embedding generation service implementation. /// To address this, a custom can be used to modify the response before sending it back. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class HuggingFace_TextEmbeddingCustomHttpHandler(ITestOutputHelper output) : BaseTest(output) { public async Task RunInferenceApiEmbeddingCustomHttpHandlerAsync() diff --git a/dotnet/samples/Concepts/Memory/MemoryStore_CustomReadOnly.cs b/dotnet/samples/Concepts/Memory/MemoryStore_CustomReadOnly.cs deleted file mode 100644 index e8994db01afd..000000000000 --- a/dotnet/samples/Concepts/Memory/MemoryStore_CustomReadOnly.cs +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Numerics.Tensors; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using System.Text.Json; -using Microsoft.SemanticKernel.Memory; - -namespace Memory; - -/// -/// This sample provides a custom implementation of that is read only. -/// In this sample, the data is stored in a JSON string and deserialized into an -/// . For this specific sample, the implementation -/// of has a single collection, and thus does not need to be named. -/// It also assumes that the JSON formatted data can be deserialized into objects. -/// -public class MemoryStore_CustomReadOnly(ITestOutputHelper output) : BaseTest(output) -{ - [Fact] - public async Task RunAsync() - { - var store = new ReadOnlyMemoryStore(s_jsonVectorEntries); - - var embedding = new ReadOnlyMemory([22, 4, 6]); - - Console.WriteLine("Reading data from custom read-only memory store"); - var memoryRecord = await store.GetAsync("collection", "key3"); - if (memoryRecord is not null) - { - Console.WriteLine($"ID = {memoryRecord.Metadata.Id}, Embedding = {string.Join(", ", MemoryMarshal.ToEnumerable(memoryRecord.Embedding))}"); - } - - Console.WriteLine($"Getting most similar vector to {string.Join(", ", MemoryMarshal.ToEnumerable(embedding))}"); - var result = await store.GetNearestMatchAsync("collection", embedding, 0.0); - if (result.HasValue) - { - Console.WriteLine($"ID = {string.Join(", ", MemoryMarshal.ToEnumerable(result.Value.Item1.Embedding))}, Embedding = {result.Value.Item2}"); - } - } - - private sealed class ReadOnlyMemoryStore : IMemoryStore - { - private readonly MemoryRecord[]? _memoryRecords = null; - private readonly int _vectorSize = 3; - - public ReadOnlyMemoryStore(string valueString) - { - s_jsonVectorEntries = s_jsonVectorEntries.Replace("\n", string.Empty, StringComparison.Ordinal); - s_jsonVectorEntries = s_jsonVectorEntries.Replace(" ", string.Empty, StringComparison.Ordinal); - this._memoryRecords = JsonSerializer.Deserialize(valueString); - - if (this._memoryRecords is null) - { - throw new Exception("Unable to deserialize memory records"); - } - } - - public Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default) - { - throw new System.NotImplementedException(); - } - - public Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default) - { - throw new System.NotImplementedException(); - } - - public Task DoesCollectionExistAsync(string collectionName, CancellationToken cancellationToken = default) - { - throw new System.NotImplementedException(); - } - - public Task GetAsync(string collectionName, string key, bool withEmbedding = false, CancellationToken cancellationToken = default) - { - // Note: with this simple implementation, the MemoryRecord will always contain the embedding. - return Task.FromResult(this._memoryRecords?.FirstOrDefault(x => x.Key == key)); - } - - public async IAsyncEnumerable GetBatchAsync(string collectionName, IEnumerable keys, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - // Note: with this simple implementation, the MemoryRecord will always contain the embedding. - if (this._memoryRecords is not null) - { - foreach (var memoryRecord in this._memoryRecords) - { - if (keys.Contains(memoryRecord.Key)) - { - yield return memoryRecord; - } - } - } - } - - public IAsyncEnumerable GetCollectionsAsync(CancellationToken cancellationToken = default) - { - throw new System.NotImplementedException(); - } - - public async Task<(MemoryRecord, double)?> GetNearestMatchAsync(string collectionName, ReadOnlyMemory embedding, double minRelevanceScore = 0, - bool withEmbedding = false, CancellationToken cancellationToken = default) - { - // Note: with this simple implementation, the MemoryRecord will always contain the embedding. - await foreach (var item in this.GetNearestMatchesAsync( - collectionName: collectionName, - embedding: embedding, - limit: 1, - minRelevanceScore: minRelevanceScore, - withEmbeddings: withEmbedding, - cancellationToken: cancellationToken).ConfigureAwait(false)) - { - return item; - } - - return default; - } - - public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync(string collectionName, ReadOnlyMemory embedding, int limit, - double minRelevanceScore = 0, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - // Note: with this simple implementation, the MemoryRecord will always contain the embedding. - if (this._memoryRecords is null || this._memoryRecords.Length == 0) - { - yield break; - } - - if (embedding.Length != this._vectorSize) - { - throw new Exception($"Embedding vector size {embedding.Length} does not match expected size of {this._vectorSize}"); - } - - List<(MemoryRecord Record, double Score)> embeddings = []; - - foreach (var item in this._memoryRecords) - { - double similarity = TensorPrimitives.CosineSimilarity(embedding.Span, item.Embedding.Span); - if (similarity >= minRelevanceScore) - { - embeddings.Add(new(item, similarity)); - } - } - - foreach (var item in embeddings.OrderByDescending(l => l.Score).Take(limit)) - { - yield return (item.Record, item.Score); - } - } - - public Task RemoveAsync(string collectionName, string key, CancellationToken cancellationToken = default) - { - throw new System.NotImplementedException(); - } - - public Task RemoveBatchAsync(string collectionName, IEnumerable keys, CancellationToken cancellationToken = default) - { - throw new System.NotImplementedException(); - } - - public Task UpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken = default) - { - throw new System.NotImplementedException(); - } - - public IAsyncEnumerable UpsertBatchAsync(string collectionName, IEnumerable records, CancellationToken cancellationToken = default) - { - throw new System.NotImplementedException(); - } - } - - private static string s_jsonVectorEntries = """ - [ - { - "embedding": [0, 0, 0], - "metadata": { - "is_reference": false, - "external_source_name": "externalSourceName", - "id": "Id1", - "description": "description", - "text": "text", - "additional_metadata" : "value:" - }, - "key": "key1", - "timestamp": null - }, - { - "embedding": [0, 0, 10], - "metadata": { - "is_reference": false, - "external_source_name": "externalSourceName", - "id": "Id2", - "description": "description", - "text": "text", - "additional_metadata" : "value:" - }, - "key": "key2", - "timestamp": null - }, - { - "embedding": [1, 2, 3], - "metadata": { - "is_reference": false, - "external_source_name": "externalSourceName", - "id": "Id3", - "description": "description", - "text": "text", - "additional_metadata" : "value:" - }, - "key": "key3", - "timestamp": null - }, - { - "embedding": [-1, -2, -3], - "metadata": { - "is_reference": false, - "external_source_name": "externalSourceName", - "id": "Id4", - "description": "description", - "text": "text", - "additional_metadata" : "value:" - }, - "key": "key4", - "timestamp": null - }, - { - "embedding": [12, 8, 4], - "metadata": { - "is_reference": false, - "external_source_name": "externalSourceName", - "id": "Id5", - "description": "description", - "text": "text", - "additional_metadata" : "value:" - }, - "key": "key5", - "timestamp": null - } - ] - """; -} diff --git a/dotnet/samples/Concepts/Memory/SemanticTextMemory_Building.cs b/dotnet/samples/Concepts/Memory/SemanticTextMemory_Building.cs deleted file mode 100644 index 72cb44af516a..000000000000 --- a/dotnet/samples/Concepts/Memory/SemanticTextMemory_Building.cs +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Microsoft.SemanticKernel.Connectors.AzureAISearch; -using Microsoft.SemanticKernel.Connectors.OpenAI; -using Microsoft.SemanticKernel.Memory; - -namespace Memory; - -/* The files contains two examples about SK Semantic Memory. - * - * 1. Memory using Azure AI Search. - * 2. Memory using a custom embedding generator and vector engine. - * - * Semantic Memory allows to store your data like traditional DBs, - * adding the ability to query it using natural language. - */ -public class SemanticTextMemory_Building(ITestOutputHelper output) : BaseTest(output) -{ - private const string MemoryCollectionName = "SKGitHub"; - - [Fact] - public async Task RunAsync() - { - Console.WriteLine("=============================================================="); - Console.WriteLine("======== Semantic Memory using Azure AI Search ========"); - Console.WriteLine("=============================================================="); - - /* This example leverages Azure AI Search to provide SK with Semantic Memory. - * - * Azure AI Search automatically indexes your data semantically, so you don't - * need to worry about embedding generation. - */ - - var memoryWithACS = new MemoryBuilder() - .WithOpenAITextEmbeddingGeneration("text-embedding-ada-002", TestConfiguration.OpenAI.ApiKey) - .WithMemoryStore(new AzureAISearchMemoryStore(TestConfiguration.AzureAISearch.Endpoint, TestConfiguration.AzureAISearch.ApiKey)) - .Build(); - - await RunExampleAsync(memoryWithACS); - - Console.WriteLine("===================================================="); - Console.WriteLine("======== Semantic Memory (volatile, in RAM) ========"); - Console.WriteLine("===================================================="); - - /* You can build your own semantic memory combining an Embedding Generator - * with a Memory storage that supports search by similarity (ie semantic search). - * - * In this example we use a volatile memory, a local simulation of a vector DB. - * - * You can replace VolatileMemoryStore with Qdrant (see QdrantMemoryStore connector) - * or implement your connectors for Pinecone, Vespa, Postgres + pgvector, SQLite VSS, etc. - */ - - var memoryWithCustomDb = new MemoryBuilder() - .WithOpenAITextEmbeddingGeneration("text-embedding-ada-002", TestConfiguration.OpenAI.ApiKey) - .WithMemoryStore(new VolatileMemoryStore()) - .Build(); - - // Uncomment the following line to use GoogleAI embeddings - // var memoryWithCustomDb = new MemoryBuilder() - // .WithGoogleAITextEmbeddingGeneration(TestConfiguration.GoogleAI.EmbeddingModelId, TestConfiguration.GoogleAI.ApiKey) - // .WithMemoryStore(new VolatileMemoryStore()) - // .Build(); - - await RunExampleAsync(memoryWithCustomDb); - } - - private async Task RunExampleAsync(ISemanticTextMemory memory) - { - await StoreMemoryAsync(memory); - - await SearchMemoryAsync(memory, "How do I get started?"); - - /* - Output: - - Query: How do I get started? - - Result 1: - URL: : https://github.com/microsoft/semantic-kernel/blob/main/README.md - Title : README: Installation, getting started, and how to contribute - - Result 2: - URL: : https://github.com/microsoft/semantic-kernel/blob/main/samples/dotnet-jupyter-notebooks/00-getting-started.ipynb - Title : Jupyter notebook describing how to get started with the Semantic Kernel - - */ - - await SearchMemoryAsync(memory, "Can I build a chat with SK?"); - - /* - Output: - - Query: Can I build a chat with SK? - - Result 1: - URL: : https://github.com/microsoft/semantic-kernel/tree/main/prompt_template_samples/ChatPlugin/ChatGPT - Title : Sample demonstrating how to create a chat plugin interfacing with ChatGPT - - Result 2: - URL: : https://github.com/microsoft/semantic-kernel/blob/main/samples/apps/chat-summary-webapp-react/README.md - Title : README: README associated with a sample chat summary react-based webapp - - */ - } - - private async Task SearchMemoryAsync(ISemanticTextMemory memory, string query) - { - Console.WriteLine("\nQuery: " + query + "\n"); - - var memoryResults = memory.SearchAsync(MemoryCollectionName, query, limit: 2, minRelevanceScore: 0.5); - - int i = 0; - await foreach (MemoryQueryResult memoryResult in memoryResults) - { - Console.WriteLine($"Result {++i}:"); - Console.WriteLine(" URL: : " + memoryResult.Metadata.Id); - Console.WriteLine(" Title : " + memoryResult.Metadata.Description); - Console.WriteLine(" Relevance: " + memoryResult.Relevance); - Console.WriteLine(); - } - - Console.WriteLine("----------------------"); - } - - private async Task StoreMemoryAsync(ISemanticTextMemory memory) - { - /* Store some data in the semantic memory. - * - * When using Azure AI Search the data is automatically indexed on write. - * - * When using the combination of VolatileStore and Embedding generation, SK takes - * care of creating and storing the index - */ - - Console.WriteLine("\nAdding some GitHub file URLs and their descriptions to the semantic memory."); - var githubFiles = SampleData(); - var i = 0; - foreach (var entry in githubFiles) - { - await memory.SaveReferenceAsync( - collection: MemoryCollectionName, - externalSourceName: "GitHub", - externalId: entry.Key, - description: entry.Value, - text: entry.Value); - - Console.Write($" #{++i} saved."); - } - - Console.WriteLine("\n----------------------"); - } - - private static Dictionary SampleData() - { - return new Dictionary - { - ["https://github.com/microsoft/semantic-kernel/blob/main/README.md"] - = "README: Installation, getting started, and how to contribute", - ["https://github.com/microsoft/semantic-kernel/blob/main/dotnet/notebooks/02-running-prompts-from-file.ipynb"] - = "Jupyter notebook describing how to pass prompts from a file to a semantic plugin or function", - ["https://github.com/microsoft/semantic-kernel/blob/main/dotnet/notebooks/00-getting-started.ipynb"] - = "Jupyter notebook describing how to get started with the Semantic Kernel", - ["https://github.com/microsoft/semantic-kernel/tree/main/prompt_template_samples/ChatPlugin/ChatGPT"] - = "Sample demonstrating how to create a chat plugin interfacing with ChatGPT", - ["https://github.com/microsoft/semantic-kernel/blob/main/dotnet/src/Plugins/Plugins.Memory/VolatileMemoryStore.cs"] - = "C# class that defines a volatile embedding store", - }; - } -} diff --git a/dotnet/samples/Concepts/Memory/TextMemoryPlugin_GeminiEmbeddingGeneration.cs b/dotnet/samples/Concepts/Memory/TextMemoryPlugin_GeminiEmbeddingGeneration.cs deleted file mode 100644 index 0313370782e0..000000000000 --- a/dotnet/samples/Concepts/Memory/TextMemoryPlugin_GeminiEmbeddingGeneration.cs +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.Connectors.Google; -using Microsoft.SemanticKernel.Embeddings; -using Microsoft.SemanticKernel.Memory; - -namespace Memory; - -/// -/// Represents an example class for Gemini Embedding Generation with volatile memory store. -/// -public sealed class TextMemoryPlugin_GeminiEmbeddingGeneration(ITestOutputHelper output) : BaseTest(output) -{ - private const string MemoryCollectionName = "aboutMe"; - - [Fact] - public async Task GoogleAIAsync() - { - Console.WriteLine("============= Google AI - Gemini Embedding Generation ============="); - - Assert.NotNull(TestConfiguration.GoogleAI.ApiKey); - Assert.NotNull(TestConfiguration.GoogleAI.EmbeddingModelId); - - Kernel kernel = Kernel.CreateBuilder() - .AddGoogleAIGeminiChatCompletion( - modelId: TestConfiguration.GoogleAI.EmbeddingModelId, - apiKey: TestConfiguration.GoogleAI.ApiKey) - .AddGoogleAIEmbeddingGeneration( - modelId: TestConfiguration.GoogleAI.EmbeddingModelId, - apiKey: TestConfiguration.GoogleAI.ApiKey) - .Build(); - - await this.RunSimpleSampleAsync(kernel); - await this.RunTextMemoryPluginSampleAsync(kernel); - } - - [Fact] - public async Task VertexAIAsync() - { - Console.WriteLine("============= Vertex AI - Gemini Embedding Generation ============="); - - Assert.NotNull(TestConfiguration.VertexAI.BearerKey); - Assert.NotNull(TestConfiguration.VertexAI.Location); - Assert.NotNull(TestConfiguration.VertexAI.ProjectId); - Assert.NotNull(TestConfiguration.VertexAI.Gemini.ModelId); - Assert.NotNull(TestConfiguration.VertexAI.EmbeddingModelId); - - Kernel kernel = Kernel.CreateBuilder() - .AddVertexAIGeminiChatCompletion( - modelId: TestConfiguration.VertexAI.Gemini.ModelId, - bearerKey: TestConfiguration.VertexAI.BearerKey, - location: TestConfiguration.VertexAI.Location, - projectId: TestConfiguration.VertexAI.ProjectId) - .AddVertexAIEmbeddingGeneration( - modelId: TestConfiguration.VertexAI.EmbeddingModelId, - bearerKey: TestConfiguration.VertexAI.BearerKey, - location: TestConfiguration.VertexAI.Location, - projectId: TestConfiguration.VertexAI.ProjectId) - .Build(); - - // To generate bearer key, you need installed google sdk or use google web console with command: - // - // gcloud auth print-access-token - // - // Above code pass bearer key as string, it is not recommended way in production code, - // especially if IChatCompletionService and IEmbeddingGenerationService will be long lived, tokens generated by google sdk lives for 1 hour. - // You should use bearer key provider, which will be used to generate token on demand: - // - // Example: - // - // Kernel kernel = Kernel.CreateBuilder() - // .AddVertexAIGeminiChatCompletion( - // modelId: TestConfiguration.VertexAI.Gemini.ModelId, - // bearerKeyProvider: () => - // { - // // This is just example, in production we recommend using Google SDK to generate your BearerKey token. - // // This delegate will be called on every request, - // // when providing the token consider using caching strategy and refresh token logic when it is expired or close to expiration. - // return GetBearerKey(); - // }, - // location: TestConfiguration.VertexAI.Location, - // projectId: TestConfiguration.VertexAI.ProjectId) - // .AddVertexAIEmbeddingGeneration( - // modelId: embeddingModelId, - // bearerKeyProvider: () => - // { - // // This is just example, in production we recommend using Google SDK to generate your BearerKey token. - // // This delegate will be called on every request, - // // when providing the token consider using caching strategy and refresh token logic when it is expired or close to expiration. - // return GetBearerKey(); - // }, - // location: geminiLocation, - // projectId: geminiProject); - - await this.RunSimpleSampleAsync(kernel); - await this.RunTextMemoryPluginSampleAsync(kernel); - } - - private async Task RunSimpleSampleAsync(Kernel kernel) - { - Console.WriteLine("== Simple Sample: Generating Embeddings =="); - - // Obtain an embedding generator. - var embeddingGenerator = kernel.GetRequiredService(); - - var generatedEmbeddings = await embeddingGenerator.GenerateEmbeddingAsync("My name is Andrea"); - Console.WriteLine($"Generated Embeddings count: {generatedEmbeddings.Length}, " + - $"First five: {string.Join(", ", generatedEmbeddings[..5])}..."); - Console.WriteLine(); - } - - private async Task RunTextMemoryPluginSampleAsync(Kernel kernel) - { - Console.WriteLine("== Complex Sample: TextMemoryPlugin =="); - - var memoryStore = new VolatileMemoryStore(); - - // Obtain an embedding generator to use for semantic memory. - var embeddingGenerator = kernel.GetRequiredService(); - - // The combination of the text embedding generator and the memory store makes up the 'SemanticTextMemory' object used to - // store and retrieve memories. - Microsoft.SemanticKernel.Memory.SemanticTextMemory textMemory = new(memoryStore, embeddingGenerator); - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - // PART 1: Store and retrieve memories using the ISemanticTextMemory (textMemory) object. - // - // This is a simple way to store memories from a code perspective, without using the Kernel. - ///////////////////////////////////////////////////////////////////////////////////////////////////// - Console.WriteLine("== PART 1: Saving Memories through the ISemanticTextMemory object =="); - - Console.WriteLine("Saving memory with key 'info1': \"My name is Andrea\""); - await textMemory.SaveInformationAsync(MemoryCollectionName, id: "info1", text: "My name is Andrea"); - - Console.WriteLine("Saving memory with key 'info2': \"I work as a tourist operator\""); - await textMemory.SaveInformationAsync(MemoryCollectionName, id: "info2", text: "I work as a tourist operator"); - - Console.WriteLine("Saving memory with key 'info3': \"I've been living in Seattle since 2005\""); - await textMemory.SaveInformationAsync(MemoryCollectionName, id: "info3", text: "I've been living in Seattle since 2005"); - - Console.WriteLine("Saving memory with key 'info4': \"I visited France and Italy five times since 2015\""); - await textMemory.SaveInformationAsync(MemoryCollectionName, id: "info4", text: "I visited France and Italy five times since 2015"); - - Console.WriteLine(); - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - // PART 2: Create TextMemoryPlugin, store memories through the Kernel. - // - // This enables prompt functions and the AI (via Planners) to access memories - ///////////////////////////////////////////////////////////////////////////////////////////////////// - - Console.WriteLine("== PART 2: Saving Memories through the Kernel with TextMemoryPlugin and the 'Save' function =="); - - // Import the TextMemoryPlugin into the Kernel for other functions - var memoryPlugin = kernel.ImportPluginFromObject(new Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin(textMemory)); - - // Save a memory with the Kernel - Console.WriteLine("Saving memory with key 'info5': \"My family is from New York\""); - await kernel.InvokeAsync(memoryPlugin["Save"], new() - { - [Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin.InputParam] = "My family is from New York", - [Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin.CollectionParam] = MemoryCollectionName, - [Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin.KeyParam] = "info5", - }); - - Console.WriteLine(); - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - // PART 3: Recall similar ideas with semantic search - // - // Uses AI Embeddings for fuzzy lookup of memories based on intent, rather than a specific key. - ///////////////////////////////////////////////////////////////////////////////////////////////////// - - Console.WriteLine("== PART 3: Recall (similarity search) with AI Embeddings =="); - - Console.WriteLine("== PART 3a: Recall (similarity search) with ISemanticTextMemory =="); - Console.WriteLine("Ask: live in Seattle?"); - - await foreach (var answer in textMemory.SearchAsync( - collection: MemoryCollectionName, - query: "live in Seattle?", - limit: 2, - minRelevanceScore: 0.79, - withEmbeddings: true)) - { - Console.WriteLine($"Answer: {answer.Metadata.Text}"); - } - - /* Possible output: - Answer: I've been living in Seattle since 2005 - */ - - Console.WriteLine("== PART 3b: Recall (similarity search) with Kernel and TextMemoryPlugin 'Recall' function =="); - Console.WriteLine("Ask: my family is from?"); - - var result = await kernel.InvokeAsync(memoryPlugin["Recall"], new() - { - [Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin.InputParam] = "Ask: my family is from?", - [Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin.CollectionParam] = MemoryCollectionName, - [Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin.LimitParam] = "2", - [Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin.RelevanceParam] = "0.79", - }); - - Console.WriteLine($"Answer: {result.GetValue()}"); - Console.WriteLine(); - - /* Possible output: - Answer: ["My family is from New York"] - */ - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - // PART 4: TextMemoryPlugin Recall in a Prompt Function - // - // Looks up related memories when rendering a prompt template, then sends the rendered prompt to - // the text generation model to answer a natural language query. - ///////////////////////////////////////////////////////////////////////////////////////////////////// - - Console.WriteLine("== PART 4: Using TextMemoryPlugin 'Recall' function in a Prompt Function =="); - - // Build a prompt function that uses memory to find facts - const string RecallFunctionDefinition = @" -Consider only the facts below when answering questions: - -BEGIN FACTS -About me: {{recall 'live in Seattle?'}} -About me: {{recall 'my family is from?'}} -END FACTS - -Question: {{$input}} - -Answer: -"; - - result = await kernel.InvokePromptAsync(RecallFunctionDefinition, new(new GeminiPromptExecutionSettings { MaxTokens = 1000 }) - { - [Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin.InputParam] = "Where are my family from?", - [Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin.CollectionParam] = MemoryCollectionName, - [Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin.LimitParam] = "2", - [Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin.RelevanceParam] = "0.79", - }); - - Console.WriteLine("Ask: Where are my family from?"); - Console.WriteLine($"Answer: {result.GetValue()}"); - - /* Possible output: - Answer: New York - */ - - Console.WriteLine(); - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - // PART 5: Cleanup, deleting database collection - // - ///////////////////////////////////////////////////////////////////////////////////////////////////// - - Console.WriteLine("== PART 5: Cleanup, deleting database collection =="); - - Console.WriteLine("Printing Collections in DB..."); - var collections = memoryStore.GetCollectionsAsync(); - await foreach (var collection in collections) - { - Console.WriteLine(collection); - } - - Console.WriteLine(); - - Console.WriteLine($"Removing Collection {MemoryCollectionName}"); - await memoryStore.DeleteCollectionAsync(MemoryCollectionName); - Console.WriteLine(); - - Console.WriteLine($"Printing Collections in DB (after removing {MemoryCollectionName})..."); - collections = memoryStore.GetCollectionsAsync(); - await foreach (var collection in collections) - { - Console.WriteLine(collection); - } - } -} diff --git a/dotnet/samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs b/dotnet/samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs deleted file mode 100644 index 0c0f4da85bff..000000000000 --- a/dotnet/samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs +++ /dev/null @@ -1,336 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.Connectors.AzureAISearch; -using Microsoft.SemanticKernel.Connectors.Chroma; -using Microsoft.SemanticKernel.Connectors.DuckDB; -using Microsoft.SemanticKernel.Connectors.Kusto; -using Microsoft.SemanticKernel.Connectors.MongoDB; -using Microsoft.SemanticKernel.Connectors.OpenAI; -using Microsoft.SemanticKernel.Connectors.Pinecone; -using Microsoft.SemanticKernel.Connectors.Postgres; -using Microsoft.SemanticKernel.Connectors.Qdrant; -using Microsoft.SemanticKernel.Connectors.Redis; -using Microsoft.SemanticKernel.Connectors.Sqlite; -using Microsoft.SemanticKernel.Connectors.Weaviate; -using Microsoft.SemanticKernel.Memory; -using Microsoft.SemanticKernel.Plugins.Memory; -using Npgsql; -using StackExchange.Redis; - -namespace Memory; - -public class TextMemoryPlugin_MultipleMemoryStore(ITestOutputHelper output) : BaseTest(output) -{ - private const string MemoryCollectionName = "aboutMe"; - - [Theory] - [InlineData("Volatile")] - [InlineData("AzureAISearch")] - public async Task RunAsync(string provider) - { - // Volatile Memory Store - an in-memory store that is not persisted - IMemoryStore store = provider switch - { - "AzureAISearch" => CreateSampleAzureAISearchMemoryStore(), - _ => new VolatileMemoryStore(), - }; - - /////////////////////////////////////////////////////////////////////////////////////////////////// - // INSTRUCTIONS: uncomment one of the following lines to select a different memory store to use. // - /////////////////////////////////////////////////////////////////////////////////////////////////// - - // Sqlite Memory Store - a file-based store that persists data in a Sqlite database - // store = await CreateSampleSqliteMemoryStoreAsync(); - - // DuckDB Memory Store - a file-based store that persists data in a DuckDB database - // store = await CreateSampleDuckDbMemoryStoreAsync(); - - // MongoDB Memory Store - a store that persists data in a MongoDB database - // store = CreateSampleMongoDBMemoryStore(); - - // Azure AI Search Memory Store - a store that persists data in a hosted Azure AI Search database - // store = CreateSampleAzureAISearchMemoryStore(); - - // Qdrant Memory Store - a store that persists data in a local or remote Qdrant database - // store = CreateSampleQdrantMemoryStore(); - - // Chroma Memory Store - // store = CreateSampleChromaMemoryStore(); - - // Pinecone Memory Store - a store that persists data in a hosted Pinecone database - // store = CreateSamplePineconeMemoryStore(); - - // Weaviate Memory Store - // store = CreateSampleWeaviateMemoryStore(); - - // Redis Memory Store - // store = await CreateSampleRedisMemoryStoreAsync(); - - // Postgres Memory Store - // store = CreateSamplePostgresMemoryStore(); - - // Kusto Memory Store - // store = CreateSampleKustoMemoryStore(); - - await RunWithStoreAsync(store); - } - - private async Task CreateSampleSqliteMemoryStoreAsync() - { - IMemoryStore store = await SqliteMemoryStore.ConnectAsync("memories.sqlite"); - return store; - } - - private async Task CreateSampleDuckDbMemoryStoreAsync() - { - IMemoryStore store = await DuckDBMemoryStore.ConnectAsync("memories.duckdb", 1536); - return store; - } - - private IMemoryStore CreateSampleMongoDBMemoryStore() - { - IMemoryStore store = new MongoDBMemoryStore(TestConfiguration.MongoDB.ConnectionString, "memoryPluginExample"); - return store; - } - - private IMemoryStore CreateSampleAzureAISearchMemoryStore() - { - IMemoryStore store = new AzureAISearchMemoryStore(TestConfiguration.AzureAISearch.Endpoint, TestConfiguration.AzureAISearch.ApiKey); - return store; - } - - private IMemoryStore CreateSampleChromaMemoryStore() - { - IMemoryStore store = new ChromaMemoryStore(TestConfiguration.Chroma.Endpoint, this.LoggerFactory); - return store; - } - - private IMemoryStore CreateSampleQdrantMemoryStore() - { - IMemoryStore store = new QdrantMemoryStore(TestConfiguration.Qdrant.Endpoint, 1536, this.LoggerFactory); - return store; - } - - private IMemoryStore CreateSamplePineconeMemoryStore() - { - IMemoryStore store = new PineconeMemoryStore(TestConfiguration.Pinecone.Environment, TestConfiguration.Pinecone.ApiKey, this.LoggerFactory); - return store; - } - - private IMemoryStore CreateSampleWeaviateMemoryStore() - { - IMemoryStore store = new WeaviateMemoryStore(TestConfiguration.Weaviate.Endpoint, TestConfiguration.Weaviate.ApiKey); - return store; - } - - private async Task CreateSampleRedisMemoryStoreAsync() - { - string configuration = TestConfiguration.Redis.Configuration; - ConnectionMultiplexer connectionMultiplexer = await ConnectionMultiplexer.ConnectAsync(configuration); - IDatabase database = connectionMultiplexer.GetDatabase(); - IMemoryStore store = new RedisMemoryStore(database, vectorSize: 1536); - return store; - } - - private static IMemoryStore CreateSamplePostgresMemoryStore() - { - NpgsqlDataSourceBuilder dataSourceBuilder = new(TestConfiguration.Postgres.ConnectionString); - dataSourceBuilder.UseVector(); - NpgsqlDataSource dataSource = dataSourceBuilder.Build(); - IMemoryStore store = new PostgresMemoryStore(dataSource, vectorSize: 1536, schema: "public"); - return store; - } - - private static IMemoryStore CreateSampleKustoMemoryStore() - { - var connectionString = new Kusto.Data.KustoConnectionStringBuilder(TestConfiguration.Kusto.ConnectionString).WithAadUserPromptAuthentication(); - IMemoryStore store = new KustoMemoryStore(connectionString, "MyDatabase"); - return store; - } - - private async Task RunWithStoreAsync(IMemoryStore memoryStore) - { - var kernel = Kernel.CreateBuilder() - .AddOpenAIChatCompletion(TestConfiguration.OpenAI.ChatModelId, TestConfiguration.OpenAI.ApiKey) - .AddOpenAITextEmbeddingGeneration(TestConfiguration.OpenAI.EmbeddingModelId, TestConfiguration.OpenAI.ApiKey) - .Build(); - - // Create an embedding generator to use for semantic memory. - var embeddingGenerator = new OpenAITextEmbeddingGenerationService(TestConfiguration.OpenAI.EmbeddingModelId, TestConfiguration.OpenAI.ApiKey); - - // The combination of the text embedding generator and the memory store makes up the 'SemanticTextMemory' object used to - // store and retrieve memories. - SemanticTextMemory textMemory = new(memoryStore, embeddingGenerator); - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - // PART 1: Store and retrieve memories using the ISemanticTextMemory (textMemory) object. - // - // This is a simple way to store memories from a code perspective, without using the Kernel. - ///////////////////////////////////////////////////////////////////////////////////////////////////// - Console.WriteLine("== PART 1a: Saving Memories through the ISemanticTextMemory object =="); - - Console.WriteLine("Saving memory with key 'info1': \"My name is Andrea\""); - await textMemory.SaveInformationAsync(MemoryCollectionName, id: "info1", text: "My name is Andrea"); - - Console.WriteLine("Saving memory with key 'info2': \"I work as a tourist operator\""); - await textMemory.SaveInformationAsync(MemoryCollectionName, id: "info2", text: "I work as a tourist operator"); - - Console.WriteLine("Saving memory with key 'info3': \"I've been living in Seattle since 2005\""); - await textMemory.SaveInformationAsync(MemoryCollectionName, id: "info3", text: "I've been living in Seattle since 2005"); - - Console.WriteLine("Saving memory with key 'info4': \"I visited France and Italy five times since 2015\""); - await textMemory.SaveInformationAsync(MemoryCollectionName, id: "info4", text: "I visited France and Italy five times since 2015"); - - // Retrieve a memory - Console.WriteLine("== PART 1b: Retrieving Memories through the ISemanticTextMemory object =="); - MemoryQueryResult? lookup = await textMemory.GetAsync(MemoryCollectionName, "info1"); - Console.WriteLine("Memory with key 'info1':" + lookup?.Metadata.Text ?? "ERROR: memory not found"); - Console.WriteLine(); - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - // PART 2: Create TextMemoryPlugin, store and retrieve memories through the Kernel. - // - // This enables prompt functions and the AI (via Planners) to access memories - ///////////////////////////////////////////////////////////////////////////////////////////////////// - - Console.WriteLine("== PART 2a: Saving Memories through the Kernel with TextMemoryPlugin and the 'Save' function =="); - - // Import the TextMemoryPlugin into the Kernel for other functions - var memoryPlugin = kernel.ImportPluginFromObject(new TextMemoryPlugin(textMemory)); - - // Save a memory with the Kernel - Console.WriteLine("Saving memory with key 'info5': \"My family is from New York\""); - await kernel.InvokeAsync(memoryPlugin["Save"], new() - { - [TextMemoryPlugin.InputParam] = "My family is from New York", - [TextMemoryPlugin.CollectionParam] = MemoryCollectionName, - [TextMemoryPlugin.KeyParam] = "info5", - }); - - // Retrieve a specific memory with the Kernel - Console.WriteLine("== PART 2b: Retrieving Memories through the Kernel with TextMemoryPlugin and the 'Retrieve' function =="); - var result = await kernel.InvokeAsync(memoryPlugin["Retrieve"], new KernelArguments() - { - [TextMemoryPlugin.CollectionParam] = MemoryCollectionName, - [TextMemoryPlugin.KeyParam] = "info5" - }); - - Console.WriteLine("Memory with key 'info5':" + result.GetValue() ?? "ERROR: memory not found"); - Console.WriteLine(); - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - // PART 3: Recall similar ideas with semantic search - // - // Uses AI Embeddings for fuzzy lookup of memories based on intent, rather than a specific key. - ///////////////////////////////////////////////////////////////////////////////////////////////////// - - Console.WriteLine("== PART 3: Recall (similarity search) with AI Embeddings =="); - - Console.WriteLine("== PART 3a: Recall (similarity search) with ISemanticTextMemory =="); - Console.WriteLine("Ask: where did I grow up?"); - - await foreach (var answer in textMemory.SearchAsync( - collection: MemoryCollectionName, - query: "where did I grow up?", - limit: 2, - minRelevanceScore: 0.79, - withEmbeddings: true)) - { - Console.WriteLine($"Answer: {answer.Metadata.Text}"); - } - - Console.WriteLine("== PART 3b: Recall (similarity search) with Kernel and TextMemoryPlugin 'Recall' function =="); - Console.WriteLine("Ask: where do I live?"); - - result = await kernel.InvokeAsync(memoryPlugin["Recall"], new() - { - [TextMemoryPlugin.InputParam] = "Ask: where do I live?", - [TextMemoryPlugin.CollectionParam] = MemoryCollectionName, - [TextMemoryPlugin.LimitParam] = "2", - [TextMemoryPlugin.RelevanceParam] = "0.79", - }); - - Console.WriteLine($"Answer: {result.GetValue()}"); - Console.WriteLine(); - - /* - Output: - - Ask: where did I grow up? - Answer: - ["My family is from New York","I\u0027ve been living in Seattle since 2005"] - - Ask: where do I live? - Answer: - ["I\u0027ve been living in Seattle since 2005","My family is from New York"] - */ - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - // PART 4: TextMemoryPlugin Recall in a Prompt Function - // - // Looks up related memories when rendering a prompt template, then sends the rendered prompt to - // the text generation model to answer a natural language query. - ///////////////////////////////////////////////////////////////////////////////////////////////////// - - Console.WriteLine("== PART 4: Using TextMemoryPlugin 'Recall' function in a Prompt Function =="); - - // Build a prompt function that uses memory to find facts - const string RecallFunctionDefinition = @" -Consider only the facts below when answering questions: - -BEGIN FACTS -About me: {{recall 'where did I grow up?'}} -About me: {{recall 'where do I live now?'}} -END FACTS - -Question: {{$input}} - -Answer: -"; - - var aboutMeOracle = kernel.CreateFunctionFromPrompt(RecallFunctionDefinition, new OpenAIPromptExecutionSettings() { MaxTokens = 100 }); - - result = await kernel.InvokeAsync(aboutMeOracle, new() - { - [TextMemoryPlugin.InputParam] = "Do I live in the same town where I grew up?", - [TextMemoryPlugin.CollectionParam] = MemoryCollectionName, - [TextMemoryPlugin.LimitParam] = "2", - [TextMemoryPlugin.RelevanceParam] = "0.79", - }); - - Console.WriteLine("Ask: Do I live in the same town where I grew up?"); - Console.WriteLine($"Answer: {result.GetValue()}"); - - /* - Approximate Output: - Answer: No, I do not live in the same town where I grew up since my family is from New York and I have been living in Seattle since 2005. - */ - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - // PART 5: Cleanup, deleting database collection - // - ///////////////////////////////////////////////////////////////////////////////////////////////////// - - Console.WriteLine("== PART 5: Cleanup, deleting database collection =="); - - Console.WriteLine("Printing Collections in DB..."); - var collections = memoryStore.GetCollectionsAsync(); - await foreach (var collection in collections) - { - Console.WriteLine(collection); - } - Console.WriteLine(); - - Console.WriteLine($"Removing Collection {MemoryCollectionName}"); - await memoryStore.DeleteCollectionAsync(MemoryCollectionName); - Console.WriteLine(); - - Console.WriteLine($"Printing Collections in DB (after removing {MemoryCollectionName})..."); - collections = memoryStore.GetCollectionsAsync(); - await foreach (var collection in collections) - { - Console.WriteLine(collection); - } - } -} diff --git a/dotnet/samples/Concepts/Memory/TextMemoryPlugin_RecallJsonSerializationWithOptions.cs b/dotnet/samples/Concepts/Memory/TextMemoryPlugin_RecallJsonSerializationWithOptions.cs deleted file mode 100644 index 883195b68df9..000000000000 --- a/dotnet/samples/Concepts/Memory/TextMemoryPlugin_RecallJsonSerializationWithOptions.cs +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Text.Encodings.Web; -using System.Text.Json; -using System.Text.Unicode; -using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.Connectors.AzureOpenAI; -using Microsoft.SemanticKernel.Memory; -using Microsoft.SemanticKernel.Plugins.Memory; - -namespace Memory; - -/// -/// This example shows how to use custom when serializing multiple results during recall using . -/// -/// -/// When multiple results are returned during recall, has to turn these results into a string to pass back to the kernel. -/// The uses to turn the results into a string. -/// In some cases though, the default serialization options may not work, e.g. if the memories contain non-latin text, -/// will escape these characters by default. In this case, you can provide custom to the to control how the memories are serialized. -/// -public class TextMemoryPlugin_RecallJsonSerializationWithOptions(ITestOutputHelper output) : BaseTest(output) -{ - [Fact] - public async Task RunAsync() - { - // Create a Kernel. - var kernelWithoutOptions = Kernel.CreateBuilder() - .Build(); - - // Create an embedding generator to use for semantic memory. - var embeddingGenerator = new AzureOpenAITextEmbeddingGenerationService(TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, TestConfiguration.AzureOpenAIEmbeddings.Endpoint, TestConfiguration.AzureOpenAIEmbeddings.ApiKey); - - // Using an in memory store for this example. - var memoryStore = new VolatileMemoryStore(); - - // The combination of the text embedding generator and the memory store makes up the 'SemanticTextMemory' object used to - // store and retrieve memories. - SemanticTextMemory textMemory = new(memoryStore, embeddingGenerator); - await textMemory.SaveInformationAsync("samples", "First example of some text in Thai and Bengali: วรรณยุกต์ চলিতভাষা", "test-record-1"); - await textMemory.SaveInformationAsync("samples", "Second example of some text in Thai and Bengali: วรรณยุกต์ চলিতভাষা", "test-record-2"); - - // Import the TextMemoryPlugin into the Kernel without any custom JsonSerializerOptions. - var memoryPluginWithoutOptions = kernelWithoutOptions.ImportPluginFromObject(new TextMemoryPlugin(textMemory)); - - // Retrieve the memories using the TextMemoryPlugin. - var resultWithoutOptions = await kernelWithoutOptions.InvokeAsync(memoryPluginWithoutOptions["Recall"], new() - { - [TextMemoryPlugin.InputParam] = "Text examples", - [TextMemoryPlugin.CollectionParam] = "samples", - [TextMemoryPlugin.LimitParam] = "2", - [TextMemoryPlugin.RelevanceParam] = "0.79", - }); - - // The recall operation returned the following text, where the Thai and Bengali text was escaped: - // ["Second example of some text in Thai and Bengali: \u0E27\u0E23\u0E23\u0E13\u0E22\u0E38\u0E01\u0E15\u0E4C \u099A\u09B2\u09BF\u09A4\u09AD\u09BE\u09B7\u09BE","First example of some text in Thai and Bengali: \u0E27\u0E23\u0E23\u0E13\u0E22\u0E38\u0E01\u0E15\u0E4C \u099A\u09B2\u09BF\u09A4\u09AD\u09BE\u09B7\u09BE"] - Console.WriteLine(resultWithoutOptions.GetValue()); - - // Create a Kernel. - var kernelWithOptions = Kernel.CreateBuilder() - .Build(); - - // Import the TextMemoryPlugin into the Kernel with custom JsonSerializerOptions that allow Thai and Bengali script to be serialized unescaped. - var options = new JsonSerializerOptions { Encoder = JavaScriptEncoder.Create(UnicodeRanges.BasicLatin, UnicodeRanges.Thai, UnicodeRanges.Bengali) }; - var memoryPluginWithOptions = kernelWithOptions.ImportPluginFromObject(new TextMemoryPlugin(textMemory, jsonSerializerOptions: options)); - - // Retrieve the memories using the TextMemoryPlugin. - var result = await kernelWithOptions.InvokeAsync(memoryPluginWithOptions["Recall"], new() - { - [TextMemoryPlugin.InputParam] = "Text examples", - [TextMemoryPlugin.CollectionParam] = "samples", - [TextMemoryPlugin.LimitParam] = "2", - [TextMemoryPlugin.RelevanceParam] = "0.79", - }); - - // The recall operation returned the following text, where the Thai and Bengali text was not escaped: - // ["Second example of some text in Thai and Bengali: วรรณยุกต์ চলিতভাষা","First example of some text in Thai and Bengali: วรรณยุกต์ চলিতভাষা"] - Console.WriteLine(result.GetValue()); - } -} diff --git a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/GenerateTextEmbeddingAttribute.cs b/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/GenerateTextEmbeddingAttribute.cs deleted file mode 100644 index 9a8e6b17aa27..000000000000 --- a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/GenerateTextEmbeddingAttribute.cs +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -namespace Memory.VectorStoreEmbeddingGeneration; - -/// -/// An attribute that can be used for an embedding property to indicate that it should -/// be generated from one or more text properties located on the same class. -/// -/// -/// This class is part of the sample. -/// -[AttributeUsage(AttributeTargets.Property, AllowMultiple = false, Inherited = true)] -public sealed class GenerateTextEmbeddingAttribute : Attribute -{ - /// - /// Initializes a new instance of the class. - /// - /// The name of the property that the embedding should be generated from. -#pragma warning disable CA1019 // Define accessors for attribute arguments - public GenerateTextEmbeddingAttribute(string sourcePropertyName) -#pragma warning restore CA1019 // Define accessors for attribute arguments - { - this.SourcePropertyNames = [sourcePropertyName]; - } - - /// - /// Initializes a new instance of the class. - /// - /// The names of the properties that the embedding should be generated from. - public GenerateTextEmbeddingAttribute(string[] sourcePropertyNames) - { - this.SourcePropertyNames = sourcePropertyNames; - } - - /// - /// Gets the name of the property to use as the source for generating the embedding. - /// - public string[] SourcePropertyNames { get; } -} diff --git a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStore.cs b/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStore.cs deleted file mode 100644 index 6848b38af48f..000000000000 --- a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStore.cs +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Embeddings; - -namespace Memory.VectorStoreEmbeddingGeneration; - -/// -/// Decorator for a that generates embeddings for records on upsert. -/// -/// -/// This class is part of the sample. -/// -public class TextEmbeddingVectorStore : IVectorStore -{ - /// The decorated . - private readonly IVectorStore _decoratedVectorStore; - - /// The service to use for generating the embeddings. - private readonly ITextEmbeddingGenerationService _textEmbeddingGenerationService; - - /// - /// Initializes a new instance of the class. - /// - /// The decorated . - /// The service to use for generating the embeddings. - public TextEmbeddingVectorStore(IVectorStore decoratedVectorStore, ITextEmbeddingGenerationService textEmbeddingGenerationService) - { - // Verify & Assign. - this._decoratedVectorStore = decoratedVectorStore ?? throw new ArgumentNullException(nameof(decoratedVectorStore)); - this._textEmbeddingGenerationService = textEmbeddingGenerationService ?? throw new ArgumentNullException(nameof(textEmbeddingGenerationService)); - } - - /// - public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) - where TKey : notnull - { - var collection = this._decoratedVectorStore.GetCollection(name, vectorStoreRecordDefinition); - var embeddingStore = new TextEmbeddingVectorStoreRecordCollection(collection, this._textEmbeddingGenerationService); - return embeddingStore; - } - - /// - public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) - { - return this._decoratedVectorStore.ListCollectionNamesAsync(cancellationToken); - } -} diff --git a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreExtensions.cs b/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreExtensions.cs deleted file mode 100644 index e1b6c779fdb8..000000000000 --- a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreExtensions.cs +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Embeddings; - -namespace Memory.VectorStoreEmbeddingGeneration; - -/// -/// Contains extension methods to help add text embedding generation to a or -/// -/// -/// This class is part of the sample. -/// -public static class TextEmbeddingVectorStoreExtensions -{ - /// - /// Add text embedding generation to a . - /// - /// The to add text embedding generation to. - /// The service to use for generating text embeddings. - /// The with text embedding added. - public static IVectorStore UseTextEmbeddingGeneration(this IVectorStore vectorStore, ITextEmbeddingGenerationService textEmbeddingGenerationService) - { - return new TextEmbeddingVectorStore(vectorStore, textEmbeddingGenerationService); - } - - /// - /// Add text embedding generation to a . - /// - /// The to add text embedding generation to. - /// The service to use for generating text embeddings. - /// The data type of the record key. - /// The record data model to use for adding, updating and retrieving data from the store. - /// The with text embedding added. - public static IVectorStoreRecordCollection UseTextEmbeddingGeneration(this IVectorStoreRecordCollection vectorStoreRecordCollection, ITextEmbeddingGenerationService textEmbeddingGenerationService) - where TKey : notnull - { - return new TextEmbeddingVectorStoreRecordCollection(vectorStoreRecordCollection, textEmbeddingGenerationService); - } -} diff --git a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreRecordCollection.cs b/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreRecordCollection.cs deleted file mode 100644 index 000cb1ebba07..000000000000 --- a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreRecordCollection.cs +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Reflection; -using System.Runtime.CompilerServices; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Embeddings; - -namespace Memory.VectorStoreEmbeddingGeneration; - -/// -/// Decorator for a that generates embeddings for records on upsert and when using . -/// -/// -/// This class is part of the sample. -/// -/// The data type of the record key. -/// The record data model to use for adding, updating and retrieving data from the store. -#pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class TextEmbeddingVectorStoreRecordCollection : IVectorStoreRecordCollection, IVectorizableTextSearch -#pragma warning restore CA1711 // Identifiers should not have incorrect suffix - where TKey : notnull -{ - /// The decorated . - private readonly IVectorStoreRecordCollection _decoratedVectorStoreRecordCollection; - - /// The service to use for generating the embeddings. - private readonly ITextEmbeddingGenerationService _textEmbeddingGenerationService; - - /// Optional configuration options for this class. - private readonly IEnumerable<(PropertyInfo EmbeddingPropertyInfo, IList SourcePropertiesInfo)> _embeddingPropertiesWithSourceProperties; - - /// - /// Initializes a new instance of the class. - /// - /// The decorated . - /// The service to use for generating the embeddings. - /// Thrown when embedding properties are referencing data source properties that do not exist. - /// Thrown when required parameters are null. - public TextEmbeddingVectorStoreRecordCollection(IVectorStoreRecordCollection decoratedVectorStoreRecordCollection, ITextEmbeddingGenerationService textEmbeddingGenerationService) - { - // Assign. - this._decoratedVectorStoreRecordCollection = decoratedVectorStoreRecordCollection ?? throw new ArgumentNullException(nameof(decoratedVectorStoreRecordCollection)); - this._textEmbeddingGenerationService = textEmbeddingGenerationService ?? throw new ArgumentNullException(nameof(textEmbeddingGenerationService)); - - // Find all the embedding properties to generate embeddings for. - this._embeddingPropertiesWithSourceProperties = FindDataPropertiesWithEmbeddingProperties(typeof(TRecord)); - } - - /// - public string CollectionName => this._decoratedVectorStoreRecordCollection.CollectionName; - - /// - public Task CollectionExistsAsync(CancellationToken cancellationToken = default) - { - return this._decoratedVectorStoreRecordCollection.CollectionExistsAsync(cancellationToken); - } - - /// - public Task CreateCollectionAsync(CancellationToken cancellationToken = default) - { - return this._decoratedVectorStoreRecordCollection.CreateCollectionAsync(cancellationToken); - } - - /// - public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) - { - if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) - { - await this.CreateCollectionAsync(cancellationToken).ConfigureAwait(false); - } - } - - /// - public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) - { - return this._decoratedVectorStoreRecordCollection.DeleteCollectionAsync(cancellationToken); - } - - /// - public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) - { - return this._decoratedVectorStoreRecordCollection.DeleteAsync(key, cancellationToken); - } - - /// - public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) - { - return this._decoratedVectorStoreRecordCollection.DeleteBatchAsync(keys, cancellationToken); - } - - /// - public Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - return this._decoratedVectorStoreRecordCollection.GetAsync(key, options, cancellationToken); - } - - /// - public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - return this._decoratedVectorStoreRecordCollection.GetBatchAsync(keys, options, cancellationToken); - } - - /// - public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) - { - var recordWithEmbeddings = await this.AddEmbeddingsAsync(record, cancellationToken).ConfigureAwait(false); - return await this._decoratedVectorStoreRecordCollection.UpsertAsync(recordWithEmbeddings, cancellationToken).ConfigureAwait(false); - } - - /// - public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - var recordWithEmbeddingsTasks = records.Select(r => this.AddEmbeddingsAsync(r, cancellationToken)); - var recordWithEmbeddings = await Task.WhenAll(recordWithEmbeddingsTasks).ConfigureAwait(false); - var upsertResults = this._decoratedVectorStoreRecordCollection.UpsertBatchAsync(recordWithEmbeddings, cancellationToken); - await foreach (var upsertResult in upsertResults.ConfigureAwait(false)) - { - yield return upsertResult; - } - } - - /// - public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) - { - return this._decoratedVectorStoreRecordCollection.VectorizedSearchAsync(vector, options, cancellationToken); - } - - /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) - { - var embeddingValue = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); - return await this.VectorizedSearchAsync(embeddingValue, options, cancellationToken).ConfigureAwait(false); - } - - /// - /// Generate and add embeddings for each embedding field that has a on the provided record. - /// - /// The record to generate embeddings for. - /// The to monitor for cancellation requests. - /// The record with embeddings added. - private async Task AddEmbeddingsAsync(TRecord record, CancellationToken cancellationToken) - { - foreach (var (embeddingPropertyInfo, sourcePropertiesInfo) in this._embeddingPropertiesWithSourceProperties) - { - var sourceValues = sourcePropertiesInfo.Select(x => x.GetValue(record)).Cast().Where(x => !string.IsNullOrWhiteSpace(x)); - var sourceString = string.Join("\n", sourceValues); - - var embeddingValue = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(sourceString, cancellationToken: cancellationToken).ConfigureAwait(false); - embeddingPropertyInfo.SetValue(record, embeddingValue); - } - - return record; - } - - /// - /// Get the list of properties with from the data model. - /// - /// The type of the data model to find - /// The list of properties with with the properties from which the embedding can be generated. - private static IEnumerable<(PropertyInfo EmbeddingPropertyInfo, IList SourcePropertiesInfo)> FindDataPropertiesWithEmbeddingProperties(Type dataModelType) - { - var allProperties = dataModelType.GetProperties(); - var propertiesDictionary = allProperties.ToDictionary(p => p.Name); - - // Loop through all the properties to find the ones that have the GenerateTextEmbeddingAttribute. - foreach (var property in allProperties) - { - var attribute = property.GetCustomAttribute(); - if (attribute is not null) - { - // Find the source properties that the embedding should be generated from. - var sourcePropertiesInfo = new List(); - foreach (var sourcePropertyName in attribute.SourcePropertyNames) - { - if (!propertiesDictionary.TryGetValue(sourcePropertyName, out var sourcePropertyInfo)) - { - throw new ArgumentException($"The source property '{sourcePropertyName}' as referenced by embedding property '{property.Name}' does not exist in the record model."); - } - else if (sourcePropertyInfo.PropertyType != typeof(string)) - { - throw new ArgumentException($"The source property '{sourcePropertyName}' as referenced by embedding property '{property.Name}' has type {sourcePropertyInfo.PropertyType} but must be a string."); - } - else - { - sourcePropertiesInfo.Add(sourcePropertyInfo); - } - } - - yield return (property, sourcePropertiesInfo); - } - } - } -} diff --git a/dotnet/samples/Concepts/Memory/VectorStoreExtensions.cs b/dotnet/samples/Concepts/Memory/VectorStoreExtensions.cs index 3d54787aee79..3a2183ba34ee 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreExtensions.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreExtensions.cs @@ -45,6 +45,7 @@ internal static async Task> CreateCo ITextEmbeddingGenerationService embeddingGenerationService, CreateRecordFromString createRecord) where TKey : notnull + where TRecord : notnull { // Get and create collection if it doesn't exist. var collection = vectorStore.GetCollection(collectionName); @@ -80,6 +81,7 @@ internal static async Task> CreateCo ITextEmbeddingGenerationService embeddingGenerationService, CreateRecordFromTextSearchResult createRecord) where TKey : notnull + where TRecord : notnull { // Get and create collection if it doesn't exist. var collection = vectorStore.GetCollection(collectionName); diff --git a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/AzureAISearchFactory.cs b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/AzureAISearchFactory.cs deleted file mode 100644 index 2bf0cb763a7a..000000000000 --- a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/AzureAISearchFactory.cs +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Text.Json; -using System.Text.Json.Nodes; -using System.Text.Json.Serialization; -using Azure.Search.Documents.Indexes; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.AzureAISearch; - -namespace Memory.VectorStoreLangchainInterop; - -/// -/// Contains a factory method that can be used to create an Azure AI Search vector store that is compatible with datasets ingested using Langchain. -/// -/// -/// This class is used with the sample. -/// -public static class AzureAISearchFactory -{ - /// - /// Record definition that matches the storage format used by Langchain for Azure AI Search. - /// - private static readonly VectorStoreRecordDefinition s_recordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("id", typeof(string)), - new VectorStoreRecordDataProperty("content", typeof(string)), - new VectorStoreRecordDataProperty("metadata", typeof(string)), - new VectorStoreRecordVectorProperty("content_vector", typeof(ReadOnlyMemory)) { Dimensions = 1536 } - } - }; - - /// - /// Create a new Azure AI Search-backed that can be used to read data that was ingested using Langchain. - /// - /// Azure AI Search client that can be used to manage the list of indices in an Azure AI Search Service. - /// The . - public static IVectorStore CreateQdrantLangchainInteropVectorStore(SearchIndexClient searchIndexClient) - => new AzureAISearchLangchainInteropVectorStore(searchIndexClient); - - private sealed class AzureAISearchLangchainInteropVectorStore(SearchIndexClient searchIndexClient, AzureAISearchVectorStoreOptions? options = default) - : AzureAISearchVectorStore(searchIndexClient, options) - { - private readonly SearchIndexClient _searchIndexClient = searchIndexClient; - - public override IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) - { - if (typeof(TKey) != typeof(string) || typeof(TRecord) != typeof(LangchainDocument)) - { - throw new NotSupportedException("This VectorStore is only usable with string keys and LangchainDocument record types"); - } - - // Create an Azure AI Search collection. To be compatible with Langchain - // we need to use a custom record definition that matches the - // schema used by Langchain. We also need to use a custom mapper - // since the Langchain schema includes a metadata field that is - // a JSON string containing the source property. Parsing this - // string and extracting the source is not supported by the default mapper. - return (new AzureAISearchVectorStoreRecordCollection( - _searchIndexClient, - name, - new() - { - VectorStoreRecordDefinition = s_recordDefinition, - JsonObjectCustomMapper = new LangchainInteropMapper() as IVectorStoreRecordMapper - }) as IVectorStoreRecordCollection)!; - } - } - - /// - /// Custom mapper to map the metadata string field, since it contains JSON as a string and this is not supported - /// automatically by the built in mapper. - /// - private sealed class LangchainInteropMapper : IVectorStoreRecordMapper, JsonObject> - { - public JsonObject MapFromDataToStorageModel(LangchainDocument dataModel) - { - var storageDocument = new AzureAISearchLangchainDocument() - { - Key = dataModel.Key, - Content = dataModel.Content, - Metadata = $"{{\"source\": \"{dataModel.Source}\"}}", - Embedding = dataModel.Embedding - }; - - return JsonSerializer.SerializeToNode(storageDocument)!.AsObject(); - } - - public LangchainDocument MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options) - { - var storageDocument = JsonSerializer.Deserialize(storageModel)!; - var metadataDocument = JsonSerializer.Deserialize(storageDocument.Metadata); - var source = metadataDocument?["source"]?.AsValue()?.ToString(); - - return new LangchainDocument() - { - Key = storageDocument.Key, - Content = storageDocument.Content, - Source = source!, - Embedding = storageDocument.Embedding - }; - } - } - - /// - /// Model class that matches the storage format used by Langchain for Azure AI Search. - /// - private sealed class AzureAISearchLangchainDocument - { - [JsonPropertyName("id")] - public string Key { get; set; } - - [JsonPropertyName("content")] - public string Content { get; set; } - - /// - /// The storage format used by Langchain stores the source information - /// in the metadata field as a JSON string. - /// E.g. {"source": "my-doc"} - /// - [JsonPropertyName("metadata")] - public string Metadata { get; set; } - - [JsonPropertyName("content_vector")] - public ReadOnlyMemory Embedding { get; set; } - } -} diff --git a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs deleted file mode 100644 index 1951f3a6dbee..000000000000 --- a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -// TODO: Commented out as part of implementing LINQ-based filtering, since MappingVectorStoreRecordCollection is no longer easy/feasible. -// TODO: The user provides an expression tree accepting a TPublicRecord, but we require an expression tree accepting a TInternalRecord. -// TODO: This is something that the user must provide, and is quite advanced. - -#if DISABLED - -using System.Runtime.CompilerServices; -using Microsoft.Extensions.VectorData; - -namespace Memory.VectorStoreLangchainInterop; - -/// -/// Decorator class that allows conversion of keys and records between public and internal representations. -/// -/// -/// This class is useful if a vector store implementation exposes keys or records in a way that is not -/// suitable for the user of the vector store. E.g. let's say that the vector store supports Guid keys -/// but you want to work with string keys that contain Guids. This class allows you to map between the -/// public string Guids and the internal Guids. -/// -/// The type of the key that the user of this class will use. -/// The type of the key that the internal collection exposes. -/// The type of the record that the user of this class will use. -/// The type of the record that the internal collection exposes. -internal sealed class MappingVectorStoreRecordCollection : IVectorStoreRecordCollection - where TPublicKey : notnull - where TInternalKey : notnull -{ - private readonly IVectorStoreRecordCollection _collection; - private readonly Func _publicToInternalKeyMapper; - private readonly Func _internalToPublicKeyMapper; - private readonly Func _publicToInternalRecordMapper; - private readonly Func _internalToPublicRecordMapper; - - public MappingVectorStoreRecordCollection( - IVectorStoreRecordCollection collection, - Func publicToInternalKeyMapper, - Func internalToPublicKeyMapper, - Func publicToInternalRecordMapper, - Func internalToPublicRecordMapper) - { - this._collection = collection; - this._publicToInternalKeyMapper = publicToInternalKeyMapper; - this._internalToPublicKeyMapper = internalToPublicKeyMapper; - this._publicToInternalRecordMapper = publicToInternalRecordMapper; - this._internalToPublicRecordMapper = internalToPublicRecordMapper; - } - - /// - public string CollectionName => this._collection.CollectionName; - - /// - public Task CollectionExistsAsync(CancellationToken cancellationToken = default) - { - return this._collection.CollectionExistsAsync(cancellationToken); - } - - /// - public Task CreateCollectionAsync(CancellationToken cancellationToken = default) - { - return this._collection.CreateCollectionAsync(cancellationToken); - } - - /// - public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) - { - return this._collection.CreateCollectionIfNotExistsAsync(cancellationToken); - } - - /// - public Task DeleteAsync(TPublicKey key, CancellationToken cancellationToken = default) - { - return this._collection.DeleteAsync(this._publicToInternalKeyMapper(key), cancellationToken); - } - - /// - public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) - { - return this._collection.DeleteBatchAsync(keys.Select(this._publicToInternalKeyMapper), cancellationToken); - } - - /// - public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) - { - return this._collection.DeleteCollectionAsync(cancellationToken); - } - - /// - public async Task GetAsync(TPublicKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - var internalRecord = await this._collection.GetAsync(this._publicToInternalKeyMapper(key), options, cancellationToken).ConfigureAwait(false); - if (internalRecord == null) - { - return default; - } - - return this._internalToPublicRecordMapper(internalRecord); - } - - /// - public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - var internalRecords = this._collection.GetBatchAsync(keys.Select(this._publicToInternalKeyMapper), options, cancellationToken); - return internalRecords.Select(this._internalToPublicRecordMapper); - } - - /// - public async Task UpsertAsync(TPublicRecord record, CancellationToken cancellationToken = default) - { - var internalRecord = this._publicToInternalRecordMapper(record); - var internalKey = await this._collection.UpsertAsync(internalRecord, cancellationToken).ConfigureAwait(false); - return this._internalToPublicKeyMapper(internalKey); - } - - /// - public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - var internalRecords = records.Select(this._publicToInternalRecordMapper); - var internalKeys = this._collection.UpsertBatchAsync(internalRecords, cancellationToken); - await foreach (var internalKey in internalKeys.ConfigureAwait(false)) - { - yield return this._internalToPublicKeyMapper(internalKey); - } - } - - /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) - { - var searchResults = await this._collection.VectorizedSearchAsync(vector, options, cancellationToken).ConfigureAwait(false); - var publicResultRecords = searchResults.Results.Select(result => new VectorSearchResult(this._internalToPublicRecordMapper(result.Record), result.Score)); - - return new VectorSearchResults(publicResultRecords) - { - TotalCount = searchResults.TotalCount, - Metadata = searchResults.Metadata, - }; - } -} - -#endif diff --git a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/PineconeFactory.cs b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/PineconeFactory.cs index 2f878199b62a..6e391fffc16a 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/PineconeFactory.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/PineconeFactory.cs @@ -24,7 +24,7 @@ public static class PineconeFactory new VectorStoreRecordKeyProperty("Key", typeof(string)), new VectorStoreRecordDataProperty("Content", typeof(string)) { StoragePropertyName = "text" }, new VectorStoreRecordDataProperty("Source", typeof(string)) { StoragePropertyName = "source" }, - new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { StoragePropertyName = "embedding", Dimensions = 1536 } + new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory), 1536) { StoragePropertyName = "embedding" } } }; @@ -34,14 +34,18 @@ public static class PineconeFactory /// Pinecone client that can be used to manage the collections and points in a Pinecone store. /// The . public static IVectorStore CreatePineconeLangchainInteropVectorStore(Sdk.PineconeClient pineconeClient) - => new PineconeLangchainInteropVectorStore(pineconeClient); + => new PineconeLangchainInteropVectorStore(new PineconeVectorStore(pineconeClient), pineconeClient); - private sealed class PineconeLangchainInteropVectorStore(Sdk.PineconeClient pineconeClient) - : PineconeVectorStore(pineconeClient) + private sealed class PineconeLangchainInteropVectorStore( + IVectorStore innerStore, + Sdk.PineconeClient pineconeClient) + : IVectorStore { private readonly Sdk.PineconeClient _pineconeClient = pineconeClient; - public override IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + where TKey : notnull + where TRecord : notnull { if (typeof(TKey) != typeof(string) || typeof(TRecord) != typeof(LangchainDocument)) { @@ -51,7 +55,7 @@ public override IVectorStoreRecordCollection GetCollection( + return (new PineconeVectorStoreRecordCollection( _pineconeClient, name, new() @@ -59,5 +63,13 @@ public override IVectorStoreRecordCollection GetCollection)!; } + + public object? GetService(Type serviceType, object? serviceKey = null) => innerStore.GetService(serviceType, serviceKey); + + public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) => innerStore.ListCollectionNamesAsync(cancellationToken); + + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) => innerStore.CollectionExistsAsync(name, cancellationToken); + + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) => innerStore.DeleteCollectionAsync(name, cancellationToken); } } diff --git a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/QdrantFactory.cs b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/QdrantFactory.cs deleted file mode 100644 index 53f0b399af82..000000000000 --- a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/QdrantFactory.cs +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Qdrant; -using Qdrant.Client; -using Qdrant.Client.Grpc; - -namespace Memory.VectorStoreLangchainInterop; - -/// -/// Contains a factory method that can be used to create a Qdrant vector store that is compatible with datasets ingested using Langchain. -/// -/// -/// This class is used with the sample. -/// -public static class QdrantFactory -{ - /// - /// Record definition that matches the storage format used by Langchain for Qdrant. - /// There is no need to list the data fields, since they have no indexing requirements and Qdrant - /// doesn't require individual fields to be defined on index creation. - /// - private static readonly VectorStoreRecordDefinition s_recordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(Guid)), - new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { StoragePropertyName = "embedding", Dimensions = 1536 } - } - }; - - /// - /// Create a new Qdrant-backed that can be used to read data that was ingested using Langchain. - /// - /// Qdrant client that can be used to manage the collections and points in a Qdrant store. - /// The . - public static IVectorStore CreateQdrantLangchainInteropVectorStore(QdrantClient qdrantClient) - => new QdrantLangchainInteropVectorStore(qdrantClient); - - private sealed class QdrantLangchainInteropVectorStore(QdrantClient qdrantClient) - : QdrantVectorStore(qdrantClient) - { - private readonly QdrantClient _qdrantClient = qdrantClient; - - public override IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) - { - // Create a Qdrant collection. To be compatible with Langchain - // we need to use a custom record definition that matches the - // schema used by Langchain. We also need to use a custom mapper - // since the Langchain schema includes a metadata field that is - // a struct and this isn't supported by the default mapper. - // Since langchain creates collections without named vector support - // we should set HasNamedVectors to false. - var collection = new QdrantVectorStoreRecordCollection>( - _qdrantClient, - name, - new() - { - HasNamedVectors = false, - VectorStoreRecordDefinition = s_recordDefinition, - PointStructCustomMapper = new LangchainInteropMapper() - }); - - // If the user asked for a guid key, we can return the collection as is. - if (typeof(TKey) == typeof(Guid) && typeof(TRecord) == typeof(LangchainDocument)) - { - return (collection as IVectorStoreRecordCollection)!; - } - -#if DISABLED_FOR_NOW // TODO: See note on MappingVectorStoreRecordCollection - // If the user asked for a string key, we can add a decorator which converts back and forth between string and guid. - // The string that the user provides will still need to contain a valid guid, since the Langchain created collection - // uses guid keys. - // Supporting string keys like this is useful since it means you can work with the collection in the same way as with - // collections from other vector stores that support string keys. - if (typeof(TKey) == typeof(string) && typeof(TRecord) == typeof(LangchainDocument)) - { - var stringKeyCollection = new MappingVectorStoreRecordCollection, LangchainDocument>( - collection, - p => Guid.Parse(p), - i => i.ToString("D"), - p => new LangchainDocument { Key = Guid.Parse(p.Key), Content = p.Content, Source = p.Source, Embedding = p.Embedding }, - i => new LangchainDocument { Key = i.Key.ToString("D"), Content = i.Content, Source = i.Source, Embedding = i.Embedding }); - - return (stringKeyCollection as IVectorStoreRecordCollection)!; - } -#endif - - throw new NotSupportedException("This VectorStore is only usable with Guid keys and LangchainDocument record types or string keys and LangchainDocument record types"); - } - } - - /// - /// A custom mapper that is required to map the metadata struct. While the other - /// fields in the record can be mapped by the default Qdrant mapper, the default - /// mapper doesn't support complex types like metadata, which is a Qdrant struct - /// containing a source field. - /// - private sealed class LangchainInteropMapper : IVectorStoreRecordMapper, PointStruct> - { - public PointStruct MapFromDataToStorageModel(LangchainDocument dataModel) - { - var metadataStruct = new Struct() - { - Fields = { ["source"] = dataModel.Source } - }; - - var pointStruct = new PointStruct() - { - Id = new PointId() { Uuid = dataModel.Key.ToString("D") }, - Vectors = new Vectors() { Vector = dataModel.Embedding.ToArray() }, - Payload = - { - ["page_content"] = dataModel.Content, - ["metadata"] = new Value() { StructValue = metadataStruct } - }, - }; - - return pointStruct; - } - - public LangchainDocument MapFromStorageToDataModel(PointStruct storageModel, StorageToDataModelMapperOptions options) - { - return new LangchainDocument() - { - Key = new Guid(storageModel.Id.Uuid), - Content = storageModel.Payload["page_content"].StringValue, - Source = storageModel.Payload["metadata"].StructValue.Fields["source"].StringValue, - Embedding = options.IncludeVectors ? storageModel.Vectors.Vector.Data.ToArray() : null - }; - } - } -} diff --git a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/RedisFactory.cs b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/RedisFactory.cs index 23fd026401b4..86e54937bdf6 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/RedisFactory.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/RedisFactory.cs @@ -24,7 +24,7 @@ public static class RedisFactory new VectorStoreRecordKeyProperty("Key", typeof(string)), new VectorStoreRecordDataProperty("Content", typeof(string)) { StoragePropertyName = "text" }, new VectorStoreRecordDataProperty("Source", typeof(string)) { StoragePropertyName = "source" }, - new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { StoragePropertyName = "embedding", Dimensions = 1536 } + new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory), 1536) { StoragePropertyName = "embedding" } } }; @@ -34,14 +34,18 @@ public static class RedisFactory /// The redis database to read/write from. /// The . public static IVectorStore CreateRedisLangchainInteropVectorStore(IDatabase database) - => new RedisLangchainInteropVectorStore(database); + => new RedisLangchainInteropVectorStore(new RedisVectorStore(database), database); - private sealed class RedisLangchainInteropVectorStore(IDatabase database) - : RedisVectorStore(database) + private sealed class RedisLangchainInteropVectorStore( + IVectorStore innerStore, + IDatabase database) + : IVectorStore { private readonly IDatabase _database = database; - public override IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + where TKey : notnull + where TRecord : notnull { if (typeof(TKey) != typeof(string) || typeof(TRecord) != typeof(LangchainDocument)) { @@ -52,7 +56,7 @@ public override IVectorStoreRecordCollection GetCollection( + return (new RedisHashSetVectorStoreRecordCollection( _database, name, new() @@ -60,5 +64,13 @@ public override IVectorStoreRecordCollection GetCollection)!; } + + public object? GetService(Type serviceType, object? serviceKey = null) => innerStore.GetService(serviceType, serviceKey); + + public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) => innerStore.ListCollectionNamesAsync(cancellationToken); + + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) => innerStore.CollectionExistsAsync(name, cancellationToken); + + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) => innerStore.DeleteCollectionAsync(name, cancellationToken); } } diff --git a/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_AzureAISearch.cs b/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_AzureAISearch.cs index da4ae9cf7a76..12ce70374a0b 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_AzureAISearch.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_AzureAISearch.cs @@ -26,6 +26,7 @@ namespace Memory; /// dotnet user-secrets set "AzureAISearch:Endpoint" "https://myazureaisearchinstance.search.windows.net" /// dotnet user-secrets set "AzureAISearch:ApiKey" "samplesecret" /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class VectorStore_ConsumeFromMemoryStore_AzureAISearch(ITestOutputHelper output, VectorStoreQdrantContainerFixture qdrantFixture) : BaseTest(output), IClassFixture { private const int VectorSize = 1536; diff --git a/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Common.cs b/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Common.cs index 772327889f49..50782b075af6 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Common.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Common.cs @@ -19,6 +19,7 @@ namespace Memory; /// /// /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public static class VectorStore_ConsumeFromMemoryStore_Common { public static async Task CreateCollectionAndAddSampleDataAsync(IMemoryStore memoryStore, string collectionName, ITextEmbeddingGenerationService textEmbeddingService) diff --git a/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Qdrant.cs b/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Qdrant.cs index 1f21c404e312..00b85bb6b494 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Qdrant.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Qdrant.cs @@ -23,6 +23,7 @@ namespace Memory; /// To run this sample, you need a local instance of Docker running, since the associated fixture /// will try and start a Qdrant container in the local docker instance to run against. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class VectorStore_ConsumeFromMemoryStore_Qdrant(ITestOutputHelper output, VectorStoreQdrantContainerFixture qdrantFixture) : BaseTest(output), IClassFixture { private const int VectorSize = 1536; diff --git a/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Redis.cs b/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Redis.cs index 91ecae46c124..669f5d2dfa7e 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Redis.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Redis.cs @@ -22,6 +22,7 @@ namespace Memory; /// To run this sample, you need a local instance of Docker running, since the associated fixture /// will try and start a Redis container in the local docker instance to run against. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class VectorStore_ConsumeFromMemoryStore_Redis(ITestOutputHelper output, VectorStoreRedisContainerFixture redisFixture) : BaseTest(output), IClassFixture { private const int VectorSize = 1536; diff --git a/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_CustomMapper.cs b/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_CustomMapper.cs deleted file mode 100644 index 3f86c763acbb..000000000000 --- a/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_CustomMapper.cs +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Text.Json; -using System.Text.Json.Nodes; -using Azure.Identity; -using Memory.VectorStoreFixtures; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.AzureOpenAI; -using Microsoft.SemanticKernel.Connectors.Redis; -using Microsoft.SemanticKernel.Embeddings; -using StackExchange.Redis; - -namespace Memory; - -/// -/// An example showing how to ingest data into a vector store using with a custom mapper. -/// In this example, the storage model differs significantly from the data model, so a custom mapper is used to map between the two. -/// A is used to define the schema of the storage model, and this means that the connector -/// will not try and infer the schema from the data model. -/// In storage the data is stored as a JSON object that looks similar to this: -/// -/// { -/// "Term": "API", -/// "Definition": "Application Programming Interface. A set of rules and specifications that allow software components to communicate and exchange data.", -/// "DefinitionEmbedding": [ ... ] -/// } -/// -/// However, the data model is a class with a property for key and two dictionaries for the data (Term and Definition) and vector (DefinitionEmbedding). -/// -/// The example shows the following steps: -/// 1. Create an embedding generator. -/// 2. Create a Redis Vector Store using a custom factory for creating collections. -/// When constructing a collection, the factory injects a custom mapper that maps between the data model and the storage model if required. -/// 3. Ingest some data into the vector store. -/// 4. Read the data back from the vector store. -/// -/// You need a local instance of Docker running, since the associated fixture will try and start a Redis container in the local docker instance to run against. -/// -public class VectorStore_DataIngestion_CustomMapper(ITestOutputHelper output, VectorStoreRedisContainerFixture redisFixture) : BaseTest(output), IClassFixture -{ - /// - /// A record definition for the glossary entries that defines the storage schema of the record. - /// - private static readonly VectorStoreRecordDefinition s_glossaryDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("Term", typeof(string)), - new VectorStoreRecordDataProperty("Definition", typeof(string)), - new VectorStoreRecordVectorProperty("DefinitionEmbedding", typeof(ReadOnlyMemory)) { Dimensions = 1536, DistanceFunction = DistanceFunction.DotProductSimilarity } - } - }; - - [Fact] - public async Task ExampleAsync() - { - // Create an embedding generation service. - var textEmbeddingGenerationService = new AzureOpenAITextEmbeddingGenerationService( - TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, - TestConfiguration.AzureOpenAIEmbeddings.Endpoint, - new AzureCliCredential()); - - // Initiate the docker container and construct the vector store using the custom factory for creating collections. - await redisFixture.ManualInitializeAsync(); - ConnectionMultiplexer redis = ConnectionMultiplexer.Connect("localhost:6379"); - var vectorStore = new CustomRedisVectorStore(redis.GetDatabase()); - - // Get and create collection if it doesn't exist, using the record definition containing the storage model. - var collection = vectorStore.GetCollection("skglossary", s_glossaryDefinition); - await collection.CreateCollectionIfNotExistsAsync(); - - // Create glossary entries and generate embeddings for them. - var glossaryEntries = CreateGlossaryEntries().ToList(); - var tasks = glossaryEntries.Select(entry => Task.Run(async () => - { - entry.Vectors["DefinitionEmbedding"] = await textEmbeddingGenerationService.GenerateEmbeddingAsync((string)entry.Data["Definition"]); - })); - await Task.WhenAll(tasks); - - // Upsert the glossary entries into the collection and return their keys. - var upsertedKeysTasks = glossaryEntries.Select(x => collection.UpsertAsync(x)); - var upsertedKeys = await Task.WhenAll(upsertedKeysTasks); - - // Retrieve one of the upserted records from the collection. - var upsertedRecord = await collection.GetAsync(upsertedKeys.First(), new() { IncludeVectors = true }); - - // Write upserted keys and one of the upserted records to the console. - Console.WriteLine($"Upserted keys: {string.Join(", ", upsertedKeys)}"); - Console.WriteLine($"Upserted record: {JsonSerializer.Serialize(upsertedRecord)}"); - } - - /// - /// A custom mapper that maps between the data model and the storage model. - /// - private sealed class Mapper : IVectorStoreRecordMapper - { - public (string Key, JsonNode Node) MapFromDataToStorageModel(GenericDataModel dataModel) - { - var jsonObject = new JsonObject(); - - jsonObject.Add("Term", dataModel.Data["Term"].ToString()); - jsonObject.Add("Definition", dataModel.Data["Definition"].ToString()); - - var vector = (ReadOnlyMemory)dataModel.Vectors["DefinitionEmbedding"]; - var jsonArray = new JsonArray(vector.ToArray().Select(x => JsonValue.Create(x)).ToArray()); - jsonObject.Add("DefinitionEmbedding", jsonArray); - - return (dataModel.Key, jsonObject); - } - - public GenericDataModel MapFromStorageToDataModel((string Key, JsonNode Node) storageModel, StorageToDataModelMapperOptions options) - { - var dataModel = new GenericDataModel - { - Key = storageModel.Key, - Data = new Dictionary - { - { "Term", (string)storageModel.Node["Term"]! }, - { "Definition", (string)storageModel.Node["Definition"]! } - }, - Vectors = new Dictionary - { - { "DefinitionEmbedding", new ReadOnlyMemory(storageModel.Node["DefinitionEmbedding"]!.AsArray().Select(x => (float)x!).ToArray()) } - } - }; - - return dataModel; - } - } - - private sealed class CustomRedisVectorStore(IDatabase database, RedisVectorStoreOptions? options = default) - : RedisVectorStore(database, options) - { - private readonly IDatabase _database = database; - - public override IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) - { - // If the record definition is the glossary definition and the record type is the generic data model, inject the custom mapper into the collection options. - if (vectorStoreRecordDefinition == s_glossaryDefinition && typeof(TRecord) == typeof(GenericDataModel)) - { - var customCollection = new RedisJsonVectorStoreRecordCollection(_database, name, new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition, JsonNodeCustomMapper = new Mapper() }) as IVectorStoreRecordCollection; - return customCollection!; - } - - // Otherwise, just create a standard collection with the default mapper. - var collection = new RedisJsonVectorStoreRecordCollection(_database, name, new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; - return collection!; - } - } - - /// - /// Sample generic data model class that can store any data. - /// - private sealed class GenericDataModel - { - public string Key { get; set; } - - public Dictionary Data { get; set; } - - public Dictionary Vectors { get; set; } - } - - /// - /// Create some sample glossary entries using the generic data model. - /// - /// A list of sample glossary entries. - private static IEnumerable CreateGlossaryEntries() - { - yield return new GenericDataModel - { - Key = "1", - Data = new() - { - { "Term", "API" }, - { "Definition", "Application Programming Interface. A set of rules and specifications that allow software components to communicate and exchange data." } - }, - Vectors = new() - }; - - yield return new GenericDataModel - { - Key = "2", - Data = new() - { - { "Term", "Connectors" }, - { "Definition", "Connectors allow you to integrate with various services provide AI capabilities, including LLM, AudioToText, TextToAudio, Embedding generation, etc." } - }, - Vectors = new() - }; - - yield return new GenericDataModel - { - Key = "3", - Data = new() - { - { "Term", "RAG" }, - { "Definition", "Retrieval Augmented Generation - a term that refers to the process of retrieving additional data to provide as context to an LLM to use when generating a response (completion) to a user’s question (prompt)." } - }, - Vectors = new() - }; - } -} diff --git a/dotnet/samples/Concepts/Memory/VectorStore_GenericDataModel_Interop.cs b/dotnet/samples/Concepts/Memory/VectorStore_DynamicDataModel_Interop.cs similarity index 71% rename from dotnet/samples/Concepts/Memory/VectorStore_GenericDataModel_Interop.cs rename to dotnet/samples/Concepts/Memory/VectorStore_DynamicDataModel_Interop.cs index 50c99dfcd03c..d7bb667284f4 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_GenericDataModel_Interop.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_DynamicDataModel_Interop.cs @@ -12,15 +12,15 @@ namespace Memory; /// -/// Semantic Kernel provides a generic data model for vector stores that can be used with any +/// Semantic Kernel support dynamic data modeling for vector stores that can be used with any /// schema. The schema still has to be provided in the form of a record definition, but no -/// custom data model is required. +/// custom .NET data model is required; a simple dictionary can be used. /// /// The sample shows how to -/// 1. Upsert data using the generic data model and retrieve it from the vector store using a custom data model. -/// 2. Upsert data using a custom data model and retrieve it from the vector store using the generic data model. +/// 1. Upsert data using dynamic data modeling and retrieve it from the vector store using a custom data model. +/// 2. Upsert data using a custom data model and retrieve it from the vector store using the dynamic data modeling. /// -public class VectorStore_GenericDataModel_Interop(ITestOutputHelper output, VectorStoreQdrantContainerFixture qdrantFixture) : BaseTest(output), IClassFixture +public class VectorStore_DynamicDataModel_Interop(ITestOutputHelper output, VectorStoreQdrantContainerFixture qdrantFixture) : BaseTest(output), IClassFixture { private static readonly JsonSerializerOptions s_indentedSerializerOptions = new() { WriteIndented = true }; @@ -31,12 +31,12 @@ public class VectorStore_GenericDataModel_Interop(ITestOutputHelper output, Vect new VectorStoreRecordKeyProperty("Key", typeof(ulong)), new VectorStoreRecordDataProperty("Term", typeof(string)), new VectorStoreRecordDataProperty("Definition", typeof(string)), - new VectorStoreRecordVectorProperty("DefinitionEmbedding", typeof(ReadOnlyMemory)) { Dimensions = 1536 } + new VectorStoreRecordVectorProperty("DefinitionEmbedding", typeof(ReadOnlyMemory), 1536) } }; [Fact] - public async Task UpsertWithGenericRetrieveWithCustomAsync() + public async Task UpsertWithDynamicRetrieveWithCustomAsync() { // Create an embedding generation service. var textEmbeddingGenerationService = new AzureOpenAITextEmbeddingGenerationService( @@ -48,27 +48,27 @@ public async Task UpsertWithGenericRetrieveWithCustomAsync() await qdrantFixture.ManualInitializeAsync(); var vectorStore = new QdrantVectorStore(new QdrantClient("localhost")); - // Get and create collection if it doesn't exist using the generic data model and record definition that defines the schema. - var genericDataModelCollection = vectorStore.GetCollection>("skglossary", s_vectorStoreRecordDefinition); - await genericDataModelCollection.CreateCollectionIfNotExistsAsync(); + // Get and create collection if it doesn't exist using the dynamic data model and record definition that defines the schema. + var dynamicDataModelCollection = vectorStore.GetCollection>("skglossary", s_vectorStoreRecordDefinition); + await dynamicDataModelCollection.CreateCollectionIfNotExistsAsync(); // Create glossary entries and generate embeddings for them. - var glossaryEntries = CreateGenericGlossaryEntries().ToList(); + var glossaryEntries = CreateDynamicGlossaryEntries().ToList(); var tasks = glossaryEntries.Select(entry => Task.Run(async () => { - entry.Vectors["DefinitionEmbedding"] = await textEmbeddingGenerationService.GenerateEmbeddingAsync((string)entry.Data["Definition"]!); + entry["DefinitionEmbedding"] = await textEmbeddingGenerationService.GenerateEmbeddingAsync((string)entry["Definition"]!); })); await Task.WhenAll(tasks); // Upsert the glossary entries into the collection and return their keys. - var upsertedKeysTasks = glossaryEntries.Select(x => genericDataModelCollection.UpsertAsync(x)); + var upsertedKeysTasks = glossaryEntries.Select(x => dynamicDataModelCollection.UpsertAsync(x)); var upsertedKeys = await Task.WhenAll(upsertedKeysTasks); // Get the collection using the custom data model. var customDataModelCollection = vectorStore.GetCollection("skglossary"); // Retrieve one of the upserted records from the collection. - var upsertedRecord = await customDataModelCollection.GetAsync(upsertedKeys.First(), new() { IncludeVectors = true }); + var upsertedRecord = await customDataModelCollection.GetAsync((ulong)upsertedKeys.First(), new() { IncludeVectors = true }); // Write upserted keys and one of the upserted records to the console. Console.WriteLine($"Upserted keys: {string.Join(", ", upsertedKeys)}"); @@ -76,7 +76,7 @@ public async Task UpsertWithGenericRetrieveWithCustomAsync() } [Fact] - public async Task UpsertWithCustomRetrieveWithGenericAsync() + public async Task UpsertWithCustomRetrieveWithDynamicAsync() { // Create an embedding generation service. var textEmbeddingGenerationService = new AzureOpenAITextEmbeddingGenerationService( @@ -104,11 +104,11 @@ public async Task UpsertWithCustomRetrieveWithGenericAsync() var upsertedKeysTasks = glossaryEntries.Select(x => customDataModelCollection.UpsertAsync(x)); var upsertedKeys = await Task.WhenAll(upsertedKeysTasks); - // Get the collection using the generic data model. - var genericDataModelCollection = vectorStore.GetCollection>("skglossary", s_vectorStoreRecordDefinition); + // Get the collection using the dynamic data model. + var dynamicDataModelCollection = vectorStore.GetCollection>("skglossary", s_vectorStoreRecordDefinition); // Retrieve one of the upserted records from the collection. - var upsertedRecord = await genericDataModelCollection.GetAsync(upsertedKeys.First(), new() { IncludeVectors = true }); + var upsertedRecord = await dynamicDataModelCollection.GetAsync(upsertedKeys.First(), new() { IncludeVectors = true }); // Write upserted keys and one of the upserted records to the console. Console.WriteLine($"Upserted keys: {string.Join(", ", upsertedKeys)}"); @@ -166,36 +166,30 @@ private static IEnumerable CreateCustomGlossaryEntries() } /// - /// Create some sample glossary entries using the generic data model. + /// Create some sample glossary entries using dynamic data modeling. /// /// A list of sample glossary entries. - private static IEnumerable> CreateGenericGlossaryEntries() + private static IEnumerable> CreateDynamicGlossaryEntries() { - yield return new VectorStoreGenericDataModel(1) + yield return new Dictionary { - Data = new Dictionary - { - ["Term"] = "API", - ["Definition"] = "Application Programming Interface. A set of rules and specifications that allow software components to communicate and exchange data.", - } + ["Key"] = 1, + ["Term"] = "API", + ["Definition"] = "Application Programming Interface. A set of rules and specifications that allow software components to communicate and exchange data." }; - yield return new VectorStoreGenericDataModel(2) + yield return new Dictionary { - Data = new Dictionary - { - ["Term"] = "Connectors", - ["Definition"] = "Connectors allow you to integrate with various services provide AI capabilities, including LLM, AudioToText, TextToAudio, Embedding generation, etc.", - } + ["Key"] = 2, + ["Term"] = "Connectors", + ["Definition"] = "Connectors allow you to integrate with various services provide AI capabilities, including LLM, AudioToText, TextToAudio, Embedding generation, etc." }; - yield return new VectorStoreGenericDataModel(3) + yield return new Dictionary { - Data = new Dictionary - { - ["Term"] = "RAG", - ["Definition"] = "Retrieval Augmented Generation - a term that refers to the process of retrieving additional data to provide as context to an LLM to use when generating a response (completion) to a user’s question (prompt).", - } + ["Key"] = 3, + ["Term"] = "RAG", + ["Definition"] = "Retrieval Augmented Generation - a term that refers to the process of retrieving additional data to provide as context to an LLM to use when generating a response (completion) to a user’s question (prompt)." }; } } diff --git a/dotnet/samples/Concepts/Memory/VectorStore_EmbeddingGeneration.cs b/dotnet/samples/Concepts/Memory/VectorStore_EmbeddingGeneration.cs deleted file mode 100644 index b641443e878a..000000000000 --- a/dotnet/samples/Concepts/Memory/VectorStore_EmbeddingGeneration.cs +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Azure.Identity; -using Memory.VectorStoreEmbeddingGeneration; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.AzureOpenAI; -using Microsoft.SemanticKernel.Connectors.InMemory; - -namespace Memory; - -/// -/// This sample shows how to abstract embedding generation away from usage by -/// using the decorator pattern. -/// -/// In the sample we create an and then using -/// an extension method -/// we wrap the with a that will automatically generate embeddings for properties -/// that have the attribute. -/// -/// The decorated vector store also adds the additional interface to the collection -/// which allows us to search the collection using a text string without having to manually generate the embeddings. -/// -/// Note that the demonstrated here are part of this sample and not part of the Semantic Kernel libraries. -/// To use it, you will need to copy it to your own project. -/// -public class VectorStore_EmbeddingGeneration(ITestOutputHelper output) : BaseTest(output) -{ - [Fact] - public async Task UseEmbeddingGenerationViaDecoratorAsync() - { - // Create an embedding generation service. - var textEmbeddingGenerationService = new AzureOpenAITextEmbeddingGenerationService( - TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, - TestConfiguration.AzureOpenAIEmbeddings.Endpoint, - new AzureCliCredential()); - - // Construct an InMemory vector store with embedding generation. - // The UseTextEmbeddingGeneration method adds an embedding generation - // decorator class to the vector store that will automatically generate - // embeddings for properties that are decorated with the GenerateTextEmbeddingAttribute. - var vectorStore = new InMemoryVectorStore().UseTextEmbeddingGeneration(textEmbeddingGenerationService); - - // Get and create collection if it doesn't exist. - var collection = vectorStore.GetCollection("skglossary"); - await collection.CreateCollectionIfNotExistsAsync(); - - // Create and upsert glossary entries into the collection. - await collection.UpsertBatchAsync(CreateGlossaryEntries()).ToListAsync(); - - // Search the collection using a vectorizable text search. - var search = collection as IVectorizableTextSearch; - var searchString = "What is an Application Programming Interface"; - var searchResult = await search!.VectorizableTextSearchAsync(searchString, new() { Top = 1 }); - var resultRecords = await searchResult.Results.ToListAsync(); - - Console.WriteLine("Search string: " + searchString); - Console.WriteLine("Result: " + resultRecords.First().Record.Definition); - Console.WriteLine(); - } - - /// - /// Sample model class that represents a glossary entry. - /// - /// - /// Note that each property is decorated with an attribute that specifies how the property should be treated by the vector store. - /// This allows us to create a collection in the vector store and upsert and retrieve instances of this class without any further configuration. - /// - /// The property is also decorated with the attribute which - /// allows the vector store to automatically generate an embedding for the property when the record is upserted. - /// - private sealed class Glossary - { - [VectorStoreRecordKey] - public ulong Key { get; set; } - - [VectorStoreRecordData(IsFilterable = true)] - public string Category { get; set; } - - [VectorStoreRecordData] - public string Term { get; set; } - - [VectorStoreRecordData] - public string Definition { get; set; } - - [GenerateTextEmbedding(nameof(Definition))] - [VectorStoreRecordVector(1536)] - public ReadOnlyMemory DefinitionEmbedding { get; set; } - } - - /// - /// Create some sample glossary entries. - /// - /// A list of sample glossary entries. - private static IEnumerable CreateGlossaryEntries() - { - yield return new Glossary - { - Key = 1, - Category = "External Definitions", - Term = "API", - Definition = "Application Programming Interface. A set of rules and specifications that allow software components to communicate and exchange data." - }; - - yield return new Glossary - { - Key = 2, - Category = "Core Definitions", - Term = "Connectors", - Definition = "Connectors allow you to integrate with various services provide AI capabilities, including LLM, AudioToText, TextToAudio, Embedding generation, etc." - }; - - yield return new Glossary - { - Key = 3, - Category = "External Definitions", - Term = "RAG", - Definition = "Retrieval Augmented Generation - a term that refers to the process of retrieving additional data to provide as context to an LLM to use when generating a response (completion) to a user’s question (prompt)." - }; - } -} diff --git a/dotnet/samples/Concepts/Memory/VectorStore_HybridSearch_Simple_AzureAISearch.cs b/dotnet/samples/Concepts/Memory/VectorStore_HybridSearch_Simple_AzureAISearch.cs index 521b8f03434a..1ce7b2e87be0 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_HybridSearch_Simple_AzureAISearch.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_HybridSearch_Simple_AzureAISearch.cs @@ -56,8 +56,7 @@ public async Task IngestDataAndUseHybridSearch() // Search the collection using a vector search. var searchString = "What is an Application Programming Interface"; var searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); - var searchResult = await hybridSearchCollection.HybridSearchAsync(searchVector, ["Application", "Programming", "Interface"], new() { Top = 1 }); - var resultRecords = await searchResult.Results.ToListAsync(); + var resultRecords = await hybridSearchCollection.HybridSearchAsync(searchVector, ["Application", "Programming", "Interface"], top: 1).ToListAsync(); Console.WriteLine("Search string: " + searchString); Console.WriteLine("Result: " + resultRecords.First().Record.Definition); @@ -66,8 +65,7 @@ public async Task IngestDataAndUseHybridSearch() // Search the collection using a vector search. searchString = "What is Retrieval Augmented Generation"; searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); - searchResult = await hybridSearchCollection.HybridSearchAsync(searchVector, ["Retrieval", "Augmented", "Generation"], new() { Top = 1 }); - resultRecords = await searchResult.Results.ToListAsync(); + resultRecords = await hybridSearchCollection.HybridSearchAsync(searchVector, ["Retrieval", "Augmented", "Generation"], top: 1).ToListAsync(); Console.WriteLine("Search string: " + searchString); Console.WriteLine("Result: " + resultRecords.First().Record.Definition); @@ -76,8 +74,7 @@ public async Task IngestDataAndUseHybridSearch() // Search the collection using a vector search with pre-filtering. searchString = "What is Retrieval Augmented Generation"; searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); - searchResult = await hybridSearchCollection.HybridSearchAsync(searchVector, ["Retrieval", "Augmented", "Generation"], new() { Top = 3, Filter = g => g.Category == "External Definitions" }); - resultRecords = await searchResult.Results.ToListAsync(); + resultRecords = await hybridSearchCollection.HybridSearchAsync(searchVector, ["Retrieval", "Augmented", "Generation"], top: 3, new() { Filter = g => g.Category == "External Definitions" }).ToListAsync(); Console.WriteLine("Search string: " + searchString); Console.WriteLine("Number of results: " + resultRecords.Count); @@ -99,13 +96,13 @@ private sealed class Glossary [VectorStoreRecordKey] public string Key { get; set; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string Category { get; set; } [VectorStoreRecordData] public string Term { get; set; } - [VectorStoreRecordData(IsFullTextSearchable = true)] + [VectorStoreRecordData(IsFullTextIndexed = true)] public string Definition { get; set; } [VectorStoreRecordVector(1536)] diff --git a/dotnet/samples/Concepts/Memory/VectorStore_Langchain_Interop.cs b/dotnet/samples/Concepts/Memory/VectorStore_Langchain_Interop.cs index 5466e7fd30af..ca10dbe496ee 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_Langchain_Interop.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_Langchain_Interop.cs @@ -1,13 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. -using Azure; using Azure.Identity; -using Azure.Search.Documents.Indexes; using Memory.VectorStoreLangchainInterop; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using Microsoft.SemanticKernel.Embeddings; -using Qdrant.Client; using StackExchange.Redis; using Sdk = Pinecone; @@ -28,32 +25,6 @@ namespace Memory; /// public class VectorStore_Langchain_Interop(ITestOutputHelper output) : BaseTest(output) { - /// - /// Shows how to read data from an Azure AI Search collection that was created and ingested using Langchain. - /// - [Fact] - public async Task ReadDataFromLangchainAzureAISearchAsync() - { - var searchIndexClient = new SearchIndexClient( - new Uri(TestConfiguration.AzureAISearch.Endpoint), - new AzureKeyCredential(TestConfiguration.AzureAISearch.ApiKey)); - var vectorStore = AzureAISearchFactory.CreateQdrantLangchainInteropVectorStore(searchIndexClient); - await this.ReadDataFromCollectionAsync(vectorStore, "pets"); - } - - /// - /// Shows how to read data from a Qdrant collection that was created and ingested using Langchain. - /// Also adds a converter to expose keys as strings containing GUIDs instead of objects, - /// to match the document schema of the other vector stores. - /// - [Fact] - public async Task ReadDataFromLangchainQdrantAsync() - { - var qdrantClient = new QdrantClient("localhost"); - var vectorStore = QdrantFactory.CreateQdrantLangchainInteropVectorStore(qdrantClient); - await this.ReadDataFromCollectionAsync(vectorStore, "pets"); - } - /// /// Shows how to read data from a Pinecone collection that was created and ingested using Langchain. /// @@ -96,8 +67,7 @@ private async Task ReadDataFromCollectionAsync(IVectorStore vectorStore, string // Search the data set. var searchString = "I'm looking for an animal that is loyal and will make a great companion"; var searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); - var searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 1 }); - var resultRecords = await searchResult.Results.ToListAsync(); + var resultRecords = await collection.SearchEmbeddingAsync(searchVector, top: 1).ToListAsync(); this.Output.WriteLine("Search string: " + searchString); this.Output.WriteLine("Source: " + resultRecords.First().Record.Source); diff --git a/dotnet/samples/Concepts/Memory/VectorStore_MigrateFromMemoryStore_Redis.cs b/dotnet/samples/Concepts/Memory/VectorStore_MigrateFromMemoryStore_Redis.cs index c5ee0d648d91..2a0aea7e47cf 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_MigrateFromMemoryStore_Redis.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_MigrateFromMemoryStore_Redis.cs @@ -29,6 +29,7 @@ namespace Memory; /// /// To run this sample, you need a local instance of Docker running, since the associated fixture will try and start a Redis container in the local docker instance to run against. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class VectorStore_MigrateFromMemoryStore_Redis(ITestOutputHelper output, VectorStoreRedisContainerFixture redisFixture) : BaseTest(output), IClassFixture { private const int VectorSize = 1536; diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs index 10d7c05e7df1..9f50d8b56b28 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs @@ -51,8 +51,7 @@ public async Task IngestDataAndSearchAsync(string collectionName, Func(string collectionName, Func(string collectionName, Func g.Category == "External Definitions" }); - resultRecords = await searchResult.Results.ToListAsync(); + resultRecords = await collection.SearchEmbeddingAsync(searchVector, top: 3, new() { Filter = g => g.Category == "External Definitions" }).ToListAsync(); output.WriteLine("Search string: " + searchString); output.WriteLine("Number of results: " + resultRecords.Count); @@ -128,7 +125,7 @@ private sealed class Glossary [VectorStoreRecordKey] public TKey Key { get; set; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string Category { get; set; } [VectorStoreRecordData] diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiVector.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiVector.cs index 645a1040c115..2cd98b672944 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiVector.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiVector.cs @@ -54,13 +54,11 @@ public async Task VectorSearchWithMultiVectorRecordAsync() // Search the store using the description embedding. var searchString = "I am looking for a reasonably priced coffee maker"; var searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); - var searchResult = await collection.VectorizedSearchAsync( - searchVector, new() + var resultRecords = await collection.SearchEmbeddingAsync( + searchVector, top: 1, new() { - Top = 1, VectorProperty = r => r.DescriptionEmbedding - }); - var resultRecords = await searchResult.Results.ToListAsync(); + }).ToListAsync(); WriteLine("Search string: " + searchString); WriteLine("Result: " + resultRecords.First().Record.Description); @@ -70,14 +68,13 @@ public async Task VectorSearchWithMultiVectorRecordAsync() // Search the store using the feature list embedding. searchString = "I am looking for a handheld vacuum cleaner that will remove pet hair"; searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); - searchResult = await collection.VectorizedSearchAsync( + resultRecords = await collection.SearchEmbeddingAsync( searchVector, + top: 1, new() { - Top = 1, VectorProperty = r => r.FeatureListEmbedding - }); - resultRecords = await searchResult.Results.ToListAsync(); + }).ToListAsync(); WriteLine("Search string: " + searchString); WriteLine("Result: " + resultRecords.First().Record.Description); diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Paging.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Paging.cs index c8b136f72542..ad8881ea7d30 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Paging.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Paging.cs @@ -47,17 +47,17 @@ public async Task VectorSearchWithPagingAsync() while (moreResults) { // Get the next page of results by asking for 10 results, and using 'Skip' to skip the results from the previous pages. - var currentPageResults = await collection.VectorizedSearchAsync( + var currentPageResults = collection.SearchEmbeddingAsync( searchVector, + top: 10, new() { - Top = 10, Skip = page * 10 }); // Print the results. var pageCount = 0; - await foreach (var result in currentPageResults.Results) + await foreach (var result in currentPageResults) { Console.WriteLine($"Key: {result.Record.Key}, Text: {result.Record.Text}"); pageCount++; diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs index 9a43c01aeb43..9f2e7f1315db 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs @@ -50,8 +50,7 @@ public async Task ExampleAsync() // Search the collection using a vector search. var searchString = "What is an Application Programming Interface"; var searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); - var searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 1 }); - var resultRecords = await searchResult.Results.ToListAsync(); + var resultRecords = await collection.SearchEmbeddingAsync(searchVector, top: 1).ToListAsync(); Console.WriteLine("Search string: " + searchString); Console.WriteLine("Result: " + resultRecords.First().Record.Definition); @@ -60,8 +59,7 @@ public async Task ExampleAsync() // Search the collection using a vector search. searchString = "What is Retrieval Augmented Generation"; searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); - searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 1 }); - resultRecords = await searchResult.Results.ToListAsync(); + resultRecords = await collection.SearchEmbeddingAsync(searchVector, top: 1).ToListAsync(); Console.WriteLine("Search string: " + searchString); Console.WriteLine("Result: " + resultRecords.First().Record.Definition); @@ -70,8 +68,7 @@ public async Task ExampleAsync() // Search the collection using a vector search with pre-filtering. searchString = "What is Retrieval Augmented Generation"; searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); - searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 3, Filter = g => g.Category == "External Definitions" }); - resultRecords = await searchResult.Results.ToListAsync(); + resultRecords = await collection.SearchEmbeddingAsync(searchVector, top: 3, new() { Filter = g => g.Category == "External Definitions" }).ToListAsync(); Console.WriteLine("Search string: " + searchString); Console.WriteLine("Number of results: " + resultRecords.Count); @@ -93,7 +90,7 @@ private sealed class Glossary [VectorStoreRecordKey] public ulong Key { get; set; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string Category { get; set; } [VectorStoreRecordData] diff --git a/dotnet/samples/Concepts/Memory/VolatileVectorStore_LoadData.cs b/dotnet/samples/Concepts/Memory/VolatileVectorStore_LoadData.cs index 9e70c987aed3..e3a2c2dc0e64 100644 --- a/dotnet/samples/Concepts/Memory/VolatileVectorStore_LoadData.cs +++ b/dotnet/samples/Concepts/Memory/VolatileVectorStore_LoadData.cs @@ -71,8 +71,7 @@ static DataModel CreateRecord(string text, ReadOnlyMemory embedding) // Search the collection using a vector search. var searchString = "What is the Semantic Kernel?"; var searchVector = await embeddingGenerationService.GenerateEmbeddingAsync(searchString); - var searchResult = await vectorSearch!.VectorizedSearchAsync(searchVector, new() { Top = 1 }); - var resultRecords = await searchResult.Results.ToListAsync(); + var resultRecords = await vectorSearch!.SearchEmbeddingAsync(searchVector, top: 1).ToListAsync(); Console.WriteLine("Search string: " + searchString); Console.WriteLine("Result: " + resultRecords.First().Record.Text); @@ -116,8 +115,7 @@ static DataModel CreateRecord(TextSearchResult searchResult, ReadOnlyMemory(CollectionName); await collection.CreateCollectionIfNotExistsAsync(context.CancellationToken); - await collection.UpsertBatchAsync(exampleRecords, cancellationToken: context.CancellationToken).ToListAsync(context.CancellationToken); + await collection.UpsertAsync(exampleRecords, cancellationToken: context.CancellationToken); // Generate embedding for original request. var requestEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(request, cancellationToken: context.CancellationToken); // Find top N examples which are similar to original request. - var searchResults = await collection.VectorizedSearchAsync(requestEmbedding, new() { Top = TopN }, cancellationToken: context.CancellationToken); - var topNExamples = (await searchResults.Results.ToListAsync(context.CancellationToken)).Select(l => l.Record).ToList(); + var topNExamples = (await collection.SearchEmbeddingAsync(requestEmbedding, top: TopN, cancellationToken: context.CancellationToken) + .ToListAsync(context.CancellationToken)).Select(l => l.Record).ToList(); // Override arguments to use only top N examples, which will be sent to LLM. context.Arguments["Examples"] = topNExamples.Select(l => l.Example); @@ -323,7 +323,7 @@ private sealed class ExampleRecord [VectorStoreRecordData] public string Example { get; set; } - [VectorStoreRecordVector] + [VectorStoreRecordVector(1536)] public ReadOnlyMemory ExampleEmbedding { get; set; } } } diff --git a/dotnet/samples/Concepts/Optimization/PluginSelectionWithFilters.cs b/dotnet/samples/Concepts/Optimization/PluginSelectionWithFilters.cs index 695ff675e17f..73a4c8fd0815 100644 --- a/dotnet/samples/Concepts/Optimization/PluginSelectionWithFilters.cs +++ b/dotnet/samples/Concepts/Optimization/PluginSelectionWithFilters.cs @@ -298,8 +298,8 @@ public async Task> GetBestFunctionsAsync( await collection.CreateCollectionIfNotExistsAsync(cancellationToken); // Find best functions to call for original request. - var searchResults = await collection.VectorizedSearchAsync(requestEmbedding, new() { Top = numberOfBestFunctions }, cancellationToken); - var recordKeys = (await searchResults.Results.ToListAsync(cancellationToken)).Select(l => l.Record.Id); + var recordKeys = (await collection.SearchEmbeddingAsync(requestEmbedding, top: numberOfBestFunctions, cancellationToken: cancellationToken) + .ToListAsync(cancellationToken)).Select(l => l.Record.Id); return plugins .SelectMany(plugin => plugin) @@ -341,7 +341,7 @@ public async Task SaveAsync(string collectionName, KernelPluginCollection plugin var collection = vectorStore.GetCollection(collectionName); await collection.CreateCollectionIfNotExistsAsync(cancellationToken); - await collection.UpsertBatchAsync(functionRecords, cancellationToken: cancellationToken).ToListAsync(cancellationToken); + await collection.UpsertAsync(functionRecords, cancellationToken: cancellationToken); } private static List<(KernelFunction Function, string TextToVectorize)> GetFunctionsData(KernelPluginCollection plugins) @@ -422,7 +422,7 @@ private sealed class FunctionRecord [VectorStoreRecordData] public string FunctionInfo { get; set; } - [VectorStoreRecordVector] + [VectorStoreRecordVector(1536)] public ReadOnlyMemory FunctionInfoEmbedding { get; set; } } diff --git a/dotnet/samples/Concepts/RAG/WithPlugins.cs b/dotnet/samples/Concepts/RAG/WithPlugins.cs index 267a6c3618a9..24419dd7fdb7 100644 --- a/dotnet/samples/Concepts/RAG/WithPlugins.cs +++ b/dotnet/samples/Concepts/RAG/WithPlugins.cs @@ -2,10 +2,13 @@ using System.Net.Http.Headers; using System.Text.Json; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.Connectors.Chroma; -using Microsoft.SemanticKernel.Connectors.OpenAI; -using Microsoft.SemanticKernel.Memory; +using Microsoft.SemanticKernel.Connectors.InMemory; +using Microsoft.SemanticKernel.Data; +using Microsoft.SemanticKernel.PromptTemplates.Handlebars; +using OpenAI; using Resources; namespace RAG; @@ -27,23 +30,45 @@ public async Task RAGWithCustomPluginAsync() } /// - /// Shows how to use RAG pattern with . + /// Shows how to use RAG pattern with . /// - [Fact(Skip = "Requires Chroma server up and running")] - public async Task RAGWithTextMemoryPluginAsync() + [Fact] + public async Task RAGWithInMemoryVectorStoreAndPluginAsync() { - var memory = new MemoryBuilder() - .WithMemoryStore(new ChromaMemoryStore("http://localhost:8000")) - .WithOpenAITextEmbeddingGeneration(TestConfiguration.OpenAI.EmbeddingModelId, TestConfiguration.OpenAI.ApiKey) - .Build(); + var textEmbeddingGenerator = new OpenAIClient(TestConfiguration.OpenAI.ApiKey) + .GetEmbeddingClient(TestConfiguration.OpenAI.EmbeddingModelId) + .AsIEmbeddingGenerator(); var kernel = Kernel.CreateBuilder() .AddOpenAIChatCompletion(TestConfiguration.OpenAI.ChatModelId, TestConfiguration.OpenAI.ApiKey) .Build(); - kernel.ImportPluginFromObject(new Microsoft.SemanticKernel.Plugins.Memory.TextMemoryPlugin(memory)); + // Create the collection and add data + var vectorStore = new InMemoryVectorStore(new() { EmbeddingGenerator = textEmbeddingGenerator }); + var collection = vectorStore.GetCollection("finances"); + await collection.CreateCollectionAsync(); + string[] budgetInfo = + { + "The budget for 2020 is EUR 100 000", + "The budget for 2021 is EUR 120 000", + "The budget for 2022 is EUR 150 000", + "The budget for 2023 is EUR 200 000", + "The budget for 2024 is EUR 364 000" + }; + var records = budgetInfo.Select((input, index) => new FinanceInfo { Key = index.ToString(), Text = input }); + await collection.UpsertAsync(records); + + // Add the collection to the kernel as a plugin. + var textSearch = new VectorStoreTextSearch(collection); + kernel.Plugins.Add(textSearch.CreateWithSearch("FinanceSearch", "Can search for budget information")); - var result = await kernel.InvokePromptAsync("{{recall 'budget by year' collection='finances'}} What is my budget for 2024?"); + // Invoke the kernel, using the plugin from within the prompt. + KernelArguments arguments = new() { { "query", "What is my budget for 2024?" } }; + var result = await kernel.InvokePromptAsync( + "{{FinanceSearch-Search query}} {{query}}", + arguments, + templateFormat: HandlebarsPromptTemplateFactory.HandlebarsTemplateFormat, + promptTemplateFactory: new HandlebarsPromptTemplateFactory()); Console.WriteLine(result); } @@ -91,5 +116,18 @@ public async Task SearchAsync(string query) } } - #endregion + private sealed class FinanceInfo + { + [VectorStoreRecordKey] + public string Key { get; set; } = string.Empty; + + [TextSearchResultValue] + [VectorStoreRecordData] + public string Text { get; set; } = string.Empty; + + [VectorStoreRecordVector(1536)] + public string Embedding => this.Text; + } + + #endregion Custom Plugin } diff --git a/dotnet/samples/Concepts/README.md b/dotnet/samples/Concepts/README.md index 53a63c441f0b..61c6cb259343 100644 --- a/dotnet/samples/Concepts/README.md +++ b/dotnet/samples/Concepts/README.md @@ -131,13 +131,8 @@ dotnet test -l "console;verbosity=detailed" --filter "FullyQualifiedName=ChatCom - [Ollama_EmbeddingGeneration](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/Ollama_EmbeddingGeneration.cs) - [Onnx_EmbeddingGeneration](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/Onnx_EmbeddingGeneration.cs) - [HuggingFace_EmbeddingGeneration](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/HuggingFace_EmbeddingGeneration.cs) -- [MemoryStore_CustomReadOnly](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/MemoryStore_CustomReadOnly.cs) -- [SemanticTextMemory_Building](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/SemanticTextMemory_Building.cs) - [TextChunkerUsage](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/TextChunkerUsage.cs) - [TextChunkingAndEmbedding](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/TextChunkingAndEmbedding.cs) -- [TextMemoryPlugin_GeminiEmbeddingGeneration](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/TextMemoryPlugin_GeminiEmbeddingGeneration.cs) -- [TextMemoryPlugin_MultipleMemoryStore](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) -- [TextMemoryPlugin_RecallJsonSerializationWithOptions](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/TextMemoryPlugin_RecallJsonSerializationWithOptions.cs) - [VectorStore_DataIngestion_Simple: A simple example of how to do data ingestion into a vector store when getting started.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_Simple.cs) - [VectorStore_DataIngestion_MultiStore: An example of data ingestion that uses the same code to ingest into multiple vector stores types.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_MultiStore.cs) - [VectorStore_DataIngestion_CustomMapper: An example that shows how to use a custom mapper for when your data model and storage model doesn't match.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_CustomMapper.cs) @@ -146,7 +141,7 @@ dotnet test -l "console;verbosity=detailed" --filter "FullyQualifiedName=ChatCom - [VectorStore_VectorSearch_MultiVector: An example showing how to pick a target vector when doing vector search on a record that contains multiple vectors.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiVector.cs) - [VectorStore_VectorSearch_MultiStore_Common: An example showing how to write vector database agnostic code with different vector databases.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs) - [VectorStore_HybridSearch_Simple_AzureAISearch: An example showing how to do hybrid search using AzureAISearch.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_HybridSearch_Simple_AzureAISearch.cs) -- [VectorStore_GenericDataModel_Interop: An example that shows how you can use the built-in, generic data model from Semantic Kernel to read and write to a Vector Store.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_GenericDataModel_Interop.cs) +- [VectorStore_DynamicDataModel_Interop: An example that shows how you can use dynamic data modeling from Semantic Kernel to read and write to a Vector Store.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_DynamicDataModel_Interop.cs) - [VectorStore_ConsumeFromMemoryStore_AzureAISearch: An example that shows how you can use the AzureAISearchVectorStore to consume data that was ingested using the AzureAISearchMemoryStore.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_AzureAISearch.cs) - [VectorStore_ConsumeFromMemoryStore_Qdrant: An example that shows how you can use the QdrantVectorStore to consume data that was ingested using the QdrantMemoryStore.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Qdrant.cs) - [VectorStore_ConsumeFromMemoryStore_Redis: An example that shows how you can use the RedisVectorStore to consume data that was ingested using the RedisMemoryStore.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_ConsumeFromMemoryStore_Redis.cs) diff --git a/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs b/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs index f6a3d4ab6356..f6f7a4adfdbe 100644 --- a/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs +++ b/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs @@ -1,5 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +#if DISABLED + +using System.Runtime.CompilerServices; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.InMemory; using Microsoft.SemanticKernel.Connectors.OpenAI; @@ -122,6 +125,7 @@ internal static async Task> CreateCo ITextEmbeddingGenerationService embeddingGenerationService, CreateRecord createRecord) where TKey : notnull + where TRecord : notnull { // Get and create collection if it doesn't exist. var collection = vectorStore.GetCollection(collectionName); @@ -139,16 +143,29 @@ internal static async Task> CreateCo } /// - /// Decorator for a that generates embeddings for text search queries. + /// Decorator for a that generates embeddings for text search queries. /// - private sealed class VectorizedSearchWrapper(IVectorizedSearch vectorizedSearch, ITextEmbeddingGenerationService textEmbeddingGeneration) : IVectorizableTextSearch + private sealed class VectorizedSearchWrapper(IVectorSearch vectorizedSearch, ITextEmbeddingGenerationService textEmbeddingGeneration) : IVectorizableTextSearch { /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async IAsyncEnumerable> VectorizableTextSearchAsync(string searchText, int top, VectorSearchOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var vectorizedQuery = await textEmbeddingGeneration!.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); - return await vectorizedSearch.VectorizedSearchAsync(vectorizedQuery, options, cancellationToken); + await foreach (var result in vectorizedSearch.VectorizedSearchAsync(vectorizedQuery, top, options, cancellationToken)) + { + yield return result; + } + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + ArgumentNullException.ThrowIfNull(serviceType); + + return + serviceKey is null && serviceType.IsInstanceOfType(this) ? this : + vectorizedSearch.GetService(serviceType, serviceKey); } } @@ -173,3 +190,5 @@ private sealed class DataModel public ReadOnlyMemory Embedding { get; init; } } } + +#endif diff --git a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Extensions/VectorStoreExtensions.cs b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Extensions/VectorStoreExtensions.cs index dacb15ff410a..8d06423301e0 100644 --- a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Extensions/VectorStoreExtensions.cs +++ b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Extensions/VectorStoreExtensions.cs @@ -35,6 +35,7 @@ public static async Task> CreateColl ITextEmbeddingGenerationService embeddingGenerationService, CreateRecordFromString createRecord) where TKey : notnull + where TRecord : notnull { // Get and create collection if it doesn't exist. var collection = vectorStore.GetCollection(collectionName); diff --git a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Program.cs b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Program.cs index d86ef65e3d85..aef6893f40d0 100644 --- a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Program.cs +++ b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Program.cs @@ -126,12 +126,12 @@ static TextDataModel CreateRecord(string text, ReadOnlyMemory embedding) ReadOnlyMemory promptEmbedding = await embeddingGenerationService.GenerateEmbeddingAsync(prompt, cancellationToken: cancellationToken); // Retrieve top three matching records from the vector store - VectorSearchResults result = await vsCollection.VectorizedSearchAsync(promptEmbedding, new() { Top = 3 }, cancellationToken); + var result = vsCollection.SearchEmbeddingAsync(promptEmbedding, top: 3, cancellationToken: cancellationToken); // Return the records as resource contents List contents = []; - await foreach (var record in result.Results) + await foreach (var record in result) { contents.Add(new TextResourceContents() { diff --git a/dotnet/samples/Demos/OnnxSimpleRAG/OnnxSimpleRAG.csproj b/dotnet/samples/Demos/OnnxSimpleRAG/OnnxSimpleRAG.csproj index bbb5f38ba81d..24a28cf88ab8 100644 --- a/dotnet/samples/Demos/OnnxSimpleRAG/OnnxSimpleRAG.csproj +++ b/dotnet/samples/Demos/OnnxSimpleRAG/OnnxSimpleRAG.csproj @@ -3,7 +3,7 @@ Exe net8.0 - $(NoWarn);CA2007;CS0612;VSTHRD111;SKEXP0070;SKEXP0050;SKEXP0001;SKEXP0020 + $(NoWarn);CA2007;CS0612;VSTHRD111;SKEXP0070;SKEXP0050;SKEXP0001 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/Demos/OnnxSimpleRAG/Program.cs b/dotnet/samples/Demos/OnnxSimpleRAG/Program.cs index 0a8b76360850..e500c241febe 100644 --- a/dotnet/samples/Demos/OnnxSimpleRAG/Program.cs +++ b/dotnet/samples/Demos/OnnxSimpleRAG/Program.cs @@ -55,7 +55,7 @@ foreach (var factTextFile in Directory.GetFiles("Facts", "*.txt")) { var factContent = File.ReadAllText(factTextFile); - await collection.UpsertAsync(new() + await collection.UpsertAsync(new InformationItem() { Id = Guid.NewGuid().ToString(), Text = factContent, @@ -64,7 +64,10 @@ await collection.UpsertAsync(new() } // Add a plugin to search the database with. +// TODO: Once OpenAITextEmbeddingGenerationService implements MEAI's IEmbeddingGenerator (#10811), configure it with the InMemoryVectorStore above instead of passing it here. +#pragma warning disable CS0618 // VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete var vectorStoreTextSearch = new VectorStoreTextSearch(collection, embeddingService); +#pragma warning restore CS0618 kernel.Plugins.Add(vectorStoreTextSearch.CreateWithSearch("SearchPlugin")); // Start the conversation diff --git a/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.ProcessOrchestrator/ProcessFramework.Aspire.ProcessOrchestrator.csproj b/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.ProcessOrchestrator/ProcessFramework.Aspire.ProcessOrchestrator.csproj index 846843bdca9e..7d1d3995191d 100644 --- a/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.ProcessOrchestrator/ProcessFramework.Aspire.ProcessOrchestrator.csproj +++ b/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.ProcessOrchestrator/ProcessFramework.Aspire.ProcessOrchestrator.csproj @@ -6,7 +6,7 @@ enable enable - $(NoWarn);CS8618,IDE0009,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0101,SKEXP0110,OPENAI001 + $(NoWarn);CS8618,IDE0009,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0101,SKEXP0110,OPENAI001 diff --git a/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Grpc/ProcessWithCloudEvents.Grpc.csproj b/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Grpc/ProcessWithCloudEvents.Grpc.csproj index 5724e503f68e..b2d5022ffa34 100644 --- a/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Grpc/ProcessWithCloudEvents.Grpc.csproj +++ b/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Grpc/ProcessWithCloudEvents.Grpc.csproj @@ -5,7 +5,7 @@ enable enable - $(NoWarn);CA2007,CS1591,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 + $(NoWarn);CA2007,CS1591,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Processes/ProcessWithCloudEvents.Processes.csproj b/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Processes/ProcessWithCloudEvents.Processes.csproj index eb4cbc961b66..1fafc3012f07 100644 --- a/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Processes/ProcessWithCloudEvents.Processes.csproj +++ b/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Processes/ProcessWithCloudEvents.Processes.csproj @@ -5,7 +5,7 @@ enable enable - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 diff --git a/dotnet/samples/Demos/ProcessWithDapr/ProcessWithDapr.csproj b/dotnet/samples/Demos/ProcessWithDapr/ProcessWithDapr.csproj index 69628bbacda5..d1bd90408672 100644 --- a/dotnet/samples/Demos/ProcessWithDapr/ProcessWithDapr.csproj +++ b/dotnet/samples/Demos/ProcessWithDapr/ProcessWithDapr.csproj @@ -5,7 +5,7 @@ enable enable - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 diff --git a/dotnet/samples/Demos/VectorStoreRAG/DataLoader.cs b/dotnet/samples/Demos/VectorStoreRAG/DataLoader.cs index 2cd7d43ce746..c23f00ee9a56 100644 --- a/dotnet/samples/Demos/VectorStoreRAG/DataLoader.cs +++ b/dotnet/samples/Demos/VectorStoreRAG/DataLoader.cs @@ -66,8 +66,8 @@ public async Task LoadPdf(string pdfPath, int batchSize, int betweenBatchDelayIn // Upsert the records into the vector store. var records = await Task.WhenAll(recordTasks).ConfigureAwait(false); - var upsertedKeys = vectorStoreRecordCollection.UpsertBatchAsync(records, cancellationToken: cancellationToken); - await foreach (var key in upsertedKeys.ConfigureAwait(false)) + var upsertedKeys = await vectorStoreRecordCollection.UpsertAsync(records, cancellationToken: cancellationToken).ConfigureAwait(false); + foreach (var key in upsertedKeys) { Console.WriteLine($"Upserted record '{key}' into VectorDB"); } diff --git a/dotnet/samples/Demos/VectorStoreRAG/VectorStoreRAG.csproj b/dotnet/samples/Demos/VectorStoreRAG/VectorStoreRAG.csproj index 7b1557a8005c..a5d2dbd59e7a 100644 --- a/dotnet/samples/Demos/VectorStoreRAG/VectorStoreRAG.csproj +++ b/dotnet/samples/Demos/VectorStoreRAG/VectorStoreRAG.csproj @@ -5,7 +5,7 @@ net8.0 enable enable - $(NoWarn);SKEXP0001;SKEXP0010;SKEXP0020 + $(NoWarn);SKEXP0001;SKEXP0010 c4203b00-7179-47c1-8701-ee352e381412 diff --git a/dotnet/samples/GettingStarted/GettingStarted.csproj b/dotnet/samples/GettingStarted/GettingStarted.csproj index 6341c4dbae5a..c5c77c4238a2 100644 --- a/dotnet/samples/GettingStarted/GettingStarted.csproj +++ b/dotnet/samples/GettingStarted/GettingStarted.csproj @@ -7,7 +7,7 @@ true false - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj b/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj index 1698d6be44b5..90818906f219 100644 --- a/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj +++ b/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj @@ -9,7 +9,7 @@ true - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101,SKEXP0110,OPENAI001 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101,SKEXP0110,OPENAI001 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj b/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj index e77aab6ba7f3..1da2089382b7 100644 --- a/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj +++ b/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj @@ -10,7 +10,7 @@ - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0101,SKEXP0110,OPENAI001 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0101,SKEXP0110,OPENAI001 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/GettingStartedWithTextSearch/GettingStartedWithTextSearch.csproj b/dotnet/samples/GettingStartedWithTextSearch/GettingStartedWithTextSearch.csproj index 41fc7813300f..03a522206317 100644 --- a/dotnet/samples/GettingStartedWithTextSearch/GettingStartedWithTextSearch.csproj +++ b/dotnet/samples/GettingStartedWithTextSearch/GettingStartedWithTextSearch.csproj @@ -7,7 +7,7 @@ true false - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/GettingStartedWithTextSearch/InMemoryVectorStoreFixture.cs b/dotnet/samples/GettingStartedWithTextSearch/InMemoryVectorStoreFixture.cs index 2af880f4bdc2..23da3fff00ea 100644 --- a/dotnet/samples/GettingStartedWithTextSearch/InMemoryVectorStoreFixture.cs +++ b/dotnet/samples/GettingStartedWithTextSearch/InMemoryVectorStoreFixture.cs @@ -113,6 +113,7 @@ private async Task> CreateCollection string[] entries, CreateRecord createRecord) where TKey : notnull + where TRecord : notnull { // Get and create collection if it doesn't exist. var collection = this.InMemoryVectorStore.GetCollection(this.CollectionName); @@ -150,7 +151,7 @@ public sealed class DataModel [TextSearchResultLink] public string Link { get; init; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public required string Tag { get; init; } [VectorStoreRecordVector(1536)] diff --git a/dotnet/samples/GettingStartedWithTextSearch/Step4_Search_With_VectorStore.cs b/dotnet/samples/GettingStartedWithTextSearch/Step4_Search_With_VectorStore.cs index 9c48c3a6880b..a2f950fb3804 100644 --- a/dotnet/samples/GettingStartedWithTextSearch/Step4_Search_With_VectorStore.cs +++ b/dotnet/samples/GettingStartedWithTextSearch/Step4_Search_With_VectorStore.cs @@ -24,10 +24,13 @@ public async Task UsingInMemoryVectorStoreRecordTextSearchAsync() { // Use embedding generation service and record collection for the fixture. var textEmbeddingGeneration = fixture.TextEmbeddingGenerationService; - var vectorizedSearch = fixture.VectorStoreRecordCollection; + var collection = fixture.VectorStoreRecordCollection; // Create a text search instance using the InMemory vector store. - var textSearch = new VectorStoreTextSearch(vectorizedSearch, textEmbeddingGeneration); + // TODO: Once OpenAITextEmbeddingGenerationService implements MEAI's IEmbeddingGenerator (#10811), configure it with the collection +#pragma warning disable CS0618 // VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete + var textSearch = new VectorStoreTextSearch(collection, textEmbeddingGeneration); +#pragma warning restore CS0618 // Search and return results as TextSearchResult items var query = "What is the Semantic Kernel?"; @@ -57,10 +60,13 @@ public async Task RagWithInMemoryVectorStoreTextSearchAsync() // Use embedding generation service and record collection for the fixture. var textEmbeddingGeneration = fixture.TextEmbeddingGenerationService; - var vectorizedSearch = fixture.VectorStoreRecordCollection; + var collection = fixture.VectorStoreRecordCollection; // Create a text search instance using the InMemory vector store. - var textSearch = new VectorStoreTextSearch(vectorizedSearch, textEmbeddingGeneration); + // TODO: Once OpenAITextEmbeddingGenerationService implements MEAI's IEmbeddingGenerator (#10811), configure it with the collection +#pragma warning disable CS0618 // VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete + var textSearch = new VectorStoreTextSearch(collection, textEmbeddingGeneration); +#pragma warning restore CS0618 // Build a text search plugin with vector store search and add to the kernel var searchPlugin = textSearch.CreateWithGetTextSearchResults("SearchPlugin"); @@ -69,14 +75,14 @@ public async Task RagWithInMemoryVectorStoreTextSearchAsync() // Invoke prompt and use text search plugin to provide grounding information var query = "What is the Semantic Kernel?"; string promptTemplate = """ - {{#with (SearchPlugin-GetTextSearchResults query)}} - {{#each this}} + {{#with (SearchPlugin-GetTextSearchResults query)}} + {{#each this}} Name: {{Name}} Value: {{Value}} Link: {{Link}} ----------------- - {{/each}} - {{/with}} + {{/each}} + {{/with}} {{query}} @@ -108,10 +114,13 @@ public async Task FunctionCallingWithInMemoryVectorStoreTextSearchAsync() // Use embedding generation service and record collection for the fixture. var textEmbeddingGeneration = fixture.TextEmbeddingGenerationService; - var vectorizedSearch = fixture.VectorStoreRecordCollection; + var collection = fixture.VectorStoreRecordCollection; // Create a text search instance using the InMemory vector store. - var textSearch = new VectorStoreTextSearch(vectorizedSearch, textEmbeddingGeneration); + // TODO: Once OpenAITextEmbeddingGenerationService implements MEAI's IEmbeddingGenerator (#10811), configure it with the collection +#pragma warning disable CS0618 // VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete + var textSearch = new VectorStoreTextSearch(collection, textEmbeddingGeneration); +#pragma warning restore CS0618 // Build a text search plugin with vector store search and add to the kernel var searchPlugin = textSearch.CreateWithGetTextSearchResults("SearchPlugin"); diff --git a/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj b/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj index 5160921a3bbd..dec156215f6d 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj +++ b/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj @@ -7,7 +7,7 @@ true false - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/GettingStartedWithVectorStores/Glossary.cs b/dotnet/samples/GettingStartedWithVectorStores/Glossary.cs index 8fc0ee87b4ad..58491513dcbd 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/Glossary.cs +++ b/dotnet/samples/GettingStartedWithVectorStores/Glossary.cs @@ -16,7 +16,7 @@ internal sealed class Glossary [VectorStoreRecordKey] public string Key { get; set; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string Category { get; set; } [VectorStoreRecordData] diff --git a/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs b/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs index 2eda86863a60..195d638573f7 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs +++ b/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs @@ -43,13 +43,9 @@ internal static async Task> SearchVectorStoreAsync( var searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); // Search the store and get the single most relevant result. - var searchResult = await collection.VectorizedSearchAsync( + var searchResultItems = await collection.SearchEmbeddingAsync( searchVector, - new() - { - Top = 1 - }); - var searchResultItems = await searchResult.Results.ToListAsync(); + top: 1).ToListAsync(); return searchResultItems.First(); } @@ -66,14 +62,13 @@ public async Task SearchAnInMemoryVectorStoreWithFilteringAsync() var searchVector = await fixture.TextEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); // Search the store with a filter and get the single most relevant result. - var searchResult = await collection.VectorizedSearchAsync( + var searchResultItems = await collection.SearchEmbeddingAsync( searchVector, + top: 1, new() { - Top = 1, Filter = g => g.Category == "AI" - }); - var searchResultItems = await searchResult.Results.ToListAsync(); + }).ToListAsync(); // Write the search result with its score to the console. Console.WriteLine(searchResultItems.First().Record.Definition); diff --git a/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs b/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs deleted file mode 100644 index 35ca4822a824..000000000000 --- a/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -#if DISABLED_FOR_NOW // TODO: See note in MappingVectorStoreRecordCollection - -using System.Runtime.CompilerServices; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Qdrant; -using Qdrant.Client; - -namespace GettingStartedWithVectorStores; - - -/// -/// Example that shows that you can switch between different vector stores with the same code, in this case -/// with a vector store that doesn't use string keys. -/// This sample demonstrates one possible approach, however it is also possible to use generics -/// in the common code to achieve code reuse. -/// -public class Step4_NonStringKey_VectorStore(ITestOutputHelper output, VectorStoresFixture fixture) : BaseTest(output), IClassFixture -{ - /// - /// Here we are going to use the same code that we used in and - /// but now with an . - /// Qdrant uses Guid or ulong as the key type, but the common code works with a string key. The string keys of the records created - /// in contain numbers though, so it's possible for us to convert them to ulong. - /// In this example, we'll demonstrate how to do that. - /// - /// This example requires a Qdrant server up and running. To run a Qdrant server in a Docker container, use the following command: - /// docker run -d --name qdrant -p 6333:6333 -p 6334:6334 qdrant/qdrant:latest - /// - [Fact] - public async Task UseAQdrantVectorStoreAsync() - { - // Construct a Qdrant vector store collection. - var collection = new QdrantVectorStoreRecordCollection(new QdrantClient("localhost"), "skglossary"); - - // Wrap the collection using a decorator that allows us to expose a version that uses string keys, but internally - // we convert to and from ulong. - var stringKeyCollection = new MappingVectorStoreRecordCollection( - collection, - p => ulong.Parse(p), - i => i.ToString(), - p => new UlongGlossary { Key = ulong.Parse(p.Key), Category = p.Category, Term = p.Term, Definition = p.Definition, DefinitionEmbedding = p.DefinitionEmbedding }, - i => new Glossary { Key = i.Key.ToString("D"), Category = i.Category, Term = i.Term, Definition = i.Definition, DefinitionEmbedding = i.DefinitionEmbedding }); - - // Ingest data into the collection using the same code as we used in Step1 with the InMemory Vector Store. - await Step1_Ingest_Data.IngestDataIntoVectorStoreAsync(stringKeyCollection, fixture.TextEmbeddingGenerationService); - - // Search the vector store using the same code as we used in Step2 with the InMemory Vector Store. - var searchResultItem = await Step2_Vector_Search.SearchVectorStoreAsync( - stringKeyCollection, - "What is an Application Programming Interface?", - fixture.TextEmbeddingGenerationService); - - // Write the search result with its score to the console. - Console.WriteLine(searchResultItem.Record.Definition); - Console.WriteLine(searchResultItem.Score); - } - - /// - /// Data model that uses a ulong as the key type instead of a string. - /// - private sealed class UlongGlossary - { - [VectorStoreRecordKey] - public ulong Key { get; set; } - - [VectorStoreRecordData(IsFilterable = true)] - public string Category { get; set; } - - [VectorStoreRecordData] - public string Term { get; set; } - - [VectorStoreRecordData] - public string Definition { get; set; } - - [VectorStoreRecordVector(Dimensions: 1536)] - public ReadOnlyMemory DefinitionEmbedding { get; set; } - } - - /// - /// Simple decorator class that allows conversion of keys and records from one type to another. - /// - private sealed class MappingVectorStoreRecordCollection : IVectorStoreRecordCollection - where TPublicKey : notnull - where TInternalKey : notnull - { - private readonly IVectorStoreRecordCollection _collection; - private readonly Func _publicToInternalKeyMapper; - private readonly Func _internalToPublicKeyMapper; - private readonly Func _publicToInternalRecordMapper; - private readonly Func _internalToPublicRecordMapper; - - public MappingVectorStoreRecordCollection( - IVectorStoreRecordCollection collection, - Func publicToInternalKeyMapper, - Func internalToPublicKeyMapper, - Func publicToInternalRecordMapper, - Func internalToPublicRecordMapper) - { - this._collection = collection; - this._publicToInternalKeyMapper = publicToInternalKeyMapper; - this._internalToPublicKeyMapper = internalToPublicKeyMapper; - this._publicToInternalRecordMapper = publicToInternalRecordMapper; - this._internalToPublicRecordMapper = internalToPublicRecordMapper; - } - - /// - public string CollectionName => this._collection.CollectionName; - - /// - public Task CollectionExistsAsync(CancellationToken cancellationToken = default) - { - return this._collection.CollectionExistsAsync(cancellationToken); - } - - /// - public Task CreateCollectionAsync(CancellationToken cancellationToken = default) - { - return this._collection.CreateCollectionAsync(cancellationToken); - } - - /// - public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) - { - return this._collection.CreateCollectionIfNotExistsAsync(cancellationToken); - } - - /// - public Task DeleteAsync(TPublicKey key, CancellationToken cancellationToken = default) - { - return this._collection.DeleteAsync(this._publicToInternalKeyMapper(key), cancellationToken); - } - - /// - public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) - { - return this._collection.DeleteBatchAsync(keys.Select(this._publicToInternalKeyMapper), cancellationToken); - } - - /// - public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) - { - return this._collection.DeleteCollectionAsync(cancellationToken); - } - - /// - public async Task GetAsync(TPublicKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - var internalRecord = await this._collection.GetAsync(this._publicToInternalKeyMapper(key), options, cancellationToken).ConfigureAwait(false); - if (internalRecord == null) - { - return default; - } - - return this._internalToPublicRecordMapper(internalRecord); - } - - /// - public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - var internalRecords = this._collection.GetBatchAsync(keys.Select(this._publicToInternalKeyMapper), options, cancellationToken); - return internalRecords.Select(this._internalToPublicRecordMapper); - } - - /// - public async Task UpsertAsync(TPublicRecord record, CancellationToken cancellationToken = default) - { - var internalRecord = this._publicToInternalRecordMapper(record); - var internalKey = await this._collection.UpsertAsync(internalRecord, cancellationToken).ConfigureAwait(false); - return this._internalToPublicKeyMapper(internalKey); - } - - /// - public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - var internalRecords = records.Select(this._publicToInternalRecordMapper); - var internalKeys = this._collection.UpsertBatchAsync(internalRecords, cancellationToken); - await foreach (var internalKey in internalKeys.ConfigureAwait(false)) - { - yield return this._internalToPublicKeyMapper(internalKey); - } - } - - /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) - { - var searchResults = await this._collection.VectorizedSearchAsync(vector, options, cancellationToken).ConfigureAwait(false); - var publicResultRecords = searchResults.Results.Select(result => new VectorSearchResult(this._internalToPublicRecordMapper(result.Record), result.Score)); - - return new VectorSearchResults(publicResultRecords) - { - TotalCount = searchResults.TotalCount, - Metadata = searchResults.Metadata, - }; - } - } -} - -#endif diff --git a/dotnet/samples/GettingStartedWithVectorStores/Step5_Use_GenericDataModel.cs b/dotnet/samples/GettingStartedWithVectorStores/Step4_Use_DynamicDataModel.cs similarity index 67% rename from dotnet/samples/GettingStartedWithVectorStores/Step5_Use_GenericDataModel.cs rename to dotnet/samples/GettingStartedWithVectorStores/Step4_Use_DynamicDataModel.cs index 449daf1c19b1..63ed0ef1d34f 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/Step5_Use_GenericDataModel.cs +++ b/dotnet/samples/GettingStartedWithVectorStores/Step4_Use_DynamicDataModel.cs @@ -8,19 +8,19 @@ namespace GettingStartedWithVectorStores; /// -/// Example that shows that you can use the generic data model to interact with a vector database. -/// This makes it possible to use the vector store abstractions without having to create your own data model. +/// Example that shows that you can use the dynamic data modeling to interact with a vector database. +/// This makes it possible to use the vector store abstractions without having to create your own strongly-typed data model. /// -public class Step5_Use_GenericDataModel(ITestOutputHelper output, VectorStoresFixture fixture) : BaseTest(output), IClassFixture +public class Step4_Use_DynamicDataModel(ITestOutputHelper output, VectorStoresFixture fixture) : BaseTest(output), IClassFixture { /// - /// Example showing how to query a vector store that uses the generic data model. + /// Example showing how to query a vector store that uses dynamic data modeling. /// /// This example requires a Redis server running on localhost:6379. To run a Redis server in a Docker container, use the following command: /// docker run -d --name redis-stack -p 6379:6379 -p 8001:8001 redis/redis-stack:latest /// [Fact] - public async Task SearchAVectorStoreWithGenericDataModelAsync() + public async Task SearchAVectorStoreWithDynamicMappingAsync() { // Construct a redis vector store. var vectorStore = new RedisVectorStore(ConnectionMultiplexer.Connect("localhost:6379").GetDatabase()); @@ -32,7 +32,7 @@ public async Task SearchAVectorStoreWithGenericDataModelAsync() var customDataModelCollection = vectorStore.GetCollection("skglossary"); await Step1_Ingest_Data.IngestDataIntoVectorStoreAsync(customDataModelCollection, fixture.TextEmbeddingGenerationService); - // To use the generic data model, we still have to describe the storage schema to the vector store + // To use dynamic data modeling, we still have to describe the storage schema to the vector store // using a record definition. The benefit over a custom data model is that this definition // does not have to be known at compile time. // E.g. it can be read from a configuration or retrieved from a service. @@ -44,34 +44,30 @@ public async Task SearchAVectorStoreWithGenericDataModelAsync() new VectorStoreRecordDataProperty("Category", typeof(string)), new VectorStoreRecordDataProperty("Term", typeof(string)), new VectorStoreRecordDataProperty("Definition", typeof(string)), - new VectorStoreRecordVectorProperty("DefinitionEmbedding", typeof(ReadOnlyMemory)) { Dimensions = 1536 }, + new VectorStoreRecordVectorProperty("DefinitionEmbedding", typeof(ReadOnlyMemory), 1536), } }; - // Now, let's create a collection that uses the generic data model. - var genericDataModelCollection = vectorStore.GetCollection>("skglossary", recordDefinition); + // Now, let's create a collection that uses a dynamic data model. + var dynamicDataModelCollection = vectorStore.GetCollection>("skglossary", recordDefinition); // Generate an embedding from the search string. var searchString = "How do I provide additional context to an LLM?"; var searchVector = await fixture.TextEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); // Search the generic data model collection and get the single most relevant result. - var searchResult = await genericDataModelCollection.VectorizedSearchAsync( + var searchResultItems = await dynamicDataModelCollection.SearchEmbeddingAsync( searchVector, - new() - { - Top = 1, - }); - var searchResultItems = await searchResult.Results.ToListAsync(); + top: 1).ToListAsync(); // Write the search result with its score to the console. - // Note that here we can loop through all the data properties - // without knowing the schema, since the data properties are + // Note that here we can loop through all the properties + // without knowing the schema, since the properties are // stored as a dictionary of string keys and object values - // when using the generic data model. - foreach (var dataProperty in searchResultItems.First().Record.Data) + // when using the dynamic data model. + foreach (var property in searchResultItems.First().Record) { - Console.WriteLine($"{dataProperty.Key}: {dataProperty.Value}"); + Console.WriteLine($"{property.Key}: {property.Value}"); } Console.WriteLine(searchResultItems.First().Score); } diff --git a/dotnet/samples/GettingStartedWithVectorStores/Step6_Use_CustomMapper.cs b/dotnet/samples/GettingStartedWithVectorStores/Step6_Use_CustomMapper.cs deleted file mode 100644 index cc86a773b0c0..000000000000 --- a/dotnet/samples/GettingStartedWithVectorStores/Step6_Use_CustomMapper.cs +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Text.Json; -using System.Text.Json.Nodes; -using Azure; -using Azure.Search.Documents.Indexes; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.AzureAISearch; -using Microsoft.SemanticKernel.Embeddings; - -namespace GettingStartedWithVectorStores; - -/// -/// Example that shows how you can use custom mappers if you wish the data model and storage schema to differ. -/// -public class Step6_Use_CustomMapper(ITestOutputHelper output, VectorStoresFixture fixture) : BaseTest(output), IClassFixture -{ - /// - /// Example showing how to upsert and query records when using a custom mapper if you wish - /// the data model and storage schema to differ. - /// - /// This example requires an Azure AI Search service to be available. - /// - [Fact] - public async Task UseCustomMapperAsync() - { - // When using a custom mapper, we still have to describe the storage schema to the vector store - // using a record definition. Since the storage schema does not match the data model - // it won't make sense for the vector store to infer the schema from the data model. - var recordDefinition = new VectorStoreRecordDefinition - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("Category", typeof(string)), - new VectorStoreRecordDataProperty("Term", typeof(string)), - new VectorStoreRecordDataProperty("Definition", typeof(string)), - new VectorStoreRecordVectorProperty("DefinitionEmbedding", typeof(ReadOnlyMemory)) { Dimensions = 1536 }, - } - }; - - // Construct an Azure AI Search vector store collection and - // pass in the custom mapper and record definition. - var collection = new AzureAISearchVectorStoreRecordCollection( - new SearchIndexClient( - new Uri(TestConfiguration.AzureAISearch.Endpoint), - new AzureKeyCredential(TestConfiguration.AzureAISearch.ApiKey)), - "skglossary", - new() - { - JsonObjectCustomMapper = new CustomMapper(), - VectorStoreRecordDefinition = recordDefinition - }); - - // Create the collection if it doesn't exist. - // This call will use the schena defined by the record definition - // above for creating the collection. - await collection.CreateCollectionIfNotExistsAsync(); - - // Now we can upsert a record using - // the data model, even though it doesn't match the storage schema. - var definition = "A set of rules and protocols that allows one software application to interact with another."; - await collection.UpsertAsync(new ComplexGlossary - { - Key = "1", - Metadata = new Metadata - { - Category = "API", - Term = "Application Programming Interface" - }, - Definition = definition, - DefinitionEmbedding = await fixture.TextEmbeddingGenerationService.GenerateEmbeddingAsync(definition) - }); - - // Generate an embedding from the search string. - var searchVector = await fixture.TextEmbeddingGenerationService.GenerateEmbeddingAsync("How do two software applications interact with another?"); - - // Search the vector store. - var searchResult = await collection.VectorizedSearchAsync( - searchVector, - new() - { - Top = 1 - }); - var searchResultItem = await searchResult.Results.FirstAsync(); - - // Write the search result with its score to the console. - Console.WriteLine(searchResultItem.Record.Metadata.Term); - Console.WriteLine(searchResultItem.Record.Definition); - Console.WriteLine(searchResultItem.Score); - } - - /// - /// Sample mapper class that maps between the custom data model - /// and the that should match the storage schema. - /// - private sealed class CustomMapper : IVectorStoreRecordMapper - { - public JsonObject MapFromDataToStorageModel(ComplexGlossary dataModel) - { - return new JsonObject - { - ["Key"] = dataModel.Key, - ["Category"] = dataModel.Metadata.Category, - ["Term"] = dataModel.Metadata.Term, - ["Definition"] = dataModel.Definition, - ["DefinitionEmbedding"] = JsonSerializer.SerializeToNode(dataModel.DefinitionEmbedding.ToArray()) - }; - } - - public ComplexGlossary MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options) - { - return new ComplexGlossary - { - Key = storageModel["Key"]!.ToString(), - Metadata = new Metadata - { - Category = storageModel["Category"]!.ToString(), - Term = storageModel["Term"]!.ToString() - }, - Definition = storageModel["Definition"]!.ToString(), - DefinitionEmbedding = JsonSerializer.Deserialize>(storageModel["DefinitionEmbedding"]) - }; - } - } - - /// - /// Sample model class that represents a glossary entry. - /// This model differs from the model used in previous steps by having a complex property - /// that contains the category and term. - /// - private sealed class ComplexGlossary - { - public string Key { get; set; } - - public Metadata Metadata { get; set; } - - public string Definition { get; set; } - - public ReadOnlyMemory DefinitionEmbedding { get; set; } - } - - private sealed class Metadata - { - public string Category { get; set; } - - public string Term { get; set; } - } -} diff --git a/dotnet/samples/LearnResources/LearnResources.csproj b/dotnet/samples/LearnResources/LearnResources.csproj index f347bb620e21..398e4883a6a1 100644 --- a/dotnet/samples/LearnResources/LearnResources.csproj +++ b/dotnet/samples/LearnResources/LearnResources.csproj @@ -7,7 +7,7 @@ enable false - $(NoWarn);CS8618,IDE0009,CA1051,CA1050,CA1707,CA2007,VSTHRD111,CS1591,RCS1110,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0101 + $(NoWarn);CS8618,IDE0009,CA1051,CA1050,CA1707,CA2007,VSTHRD111,CS1591,RCS1110,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0101 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/src/Agents/Runtime/Abstractions/Runtime.Abstractions.csproj b/dotnet/src/Agents/Runtime/Abstractions/Runtime.Abstractions.csproj index 9f687750928e..2e32ccad3c7a 100644 --- a/dotnet/src/Agents/Runtime/Abstractions/Runtime.Abstractions.csproj +++ b/dotnet/src/Agents/Runtime/Abstractions/Runtime.Abstractions.csproj @@ -15,6 +15,7 @@ + diff --git a/dotnet/src/Agents/Runtime/Core/Runtime.Core.csproj b/dotnet/src/Agents/Runtime/Core/Runtime.Core.csproj index 2b996f882698..5607805fa2d3 100644 --- a/dotnet/src/Agents/Runtime/Core/Runtime.Core.csproj +++ b/dotnet/src/Agents/Runtime/Core/Runtime.Core.csproj @@ -14,6 +14,7 @@ + diff --git a/dotnet/src/Agents/Runtime/InProcess/Runtime.InProcess.csproj b/dotnet/src/Agents/Runtime/InProcess/Runtime.InProcess.csproj index dc585e3b2ef9..fe2326664ca5 100644 --- a/dotnet/src/Agents/Runtime/InProcess/Runtime.InProcess.csproj +++ b/dotnet/src/Agents/Runtime/InProcess/Runtime.InProcess.csproj @@ -14,6 +14,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchDynamicDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchDynamicDataModelMapperTests.cs new file mode 100644 index 000000000000..1afd89e1897b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchDynamicDataModelMapperTests.cs @@ -0,0 +1,279 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json.Nodes; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.SemanticKernel.Connectors.AzureAISearch; +using Xunit; + +namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; + +/// +/// Tests for the class. +/// +public class AzureAISearchDynamicDataModelMapperTests +{ + private static readonly VectorStoreRecordModel s_model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), + new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), + new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), + new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), + new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), + new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), + new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), + new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), + new VectorStoreRecordDataProperty("DateTimeOffsetDataProp", typeof(DateTimeOffset)), + new VectorStoreRecordDataProperty("NullableDateTimeOffsetDataProp", typeof(DateTimeOffset?)), + new VectorStoreRecordDataProperty("TagListDataProp", typeof(string[])), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?), 10), + ]); + + private static readonly float[] s_vector1 = [1.0f, 2.0f, 3.0f]; + private static readonly float[] s_vector2 = [4.0f, 5.0f, 6.0f]; + private static readonly string[] s_taglist = ["tag1", "tag2"]; + + [Fact] + public void MapFromDataToStorageModelMapsAllSupportedTypes() + { + // Arrange + var sut = new AzureAISearchDynamicDataModelMapper(s_model); + var dataModel = new Dictionary + { + ["Key"] = "key", + + ["StringDataProp"] = "string", + ["IntDataProp"] = 1, + ["NullableIntDataProp"] = 2, + ["LongDataProp"] = 3L, + ["NullableLongDataProp"] = 4L, + ["FloatDataProp"] = 5.0f, + ["NullableFloatDataProp"] = 6.0f, + ["DoubleDataProp"] = 7.0, + ["NullableDoubleDataProp"] = 8.0, + ["BoolDataProp"] = true, + ["NullableBoolDataProp"] = false, + ["DateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), + ["NullableDateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), + ["TagListDataProp"] = s_taglist, + + ["FloatVector"] = new ReadOnlyMemory(s_vector1), + ["NullableFloatVector"] = new ReadOnlyMemory(s_vector2) + }; + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal("key", (string?)storageModel["Key"]); + Assert.Equal("string", (string?)storageModel["StringDataProp"]); + Assert.Equal(1, (int?)storageModel["IntDataProp"]); + Assert.Equal(2, (int?)storageModel["NullableIntDataProp"]); + Assert.Equal(3L, (long?)storageModel["LongDataProp"]); + Assert.Equal(4L, (long?)storageModel["NullableLongDataProp"]); + Assert.Equal(5.0f, (float?)storageModel["FloatDataProp"]); + Assert.Equal(6.0f, (float?)storageModel["NullableFloatDataProp"]); + Assert.Equal(7.0, (double?)storageModel["DoubleDataProp"]); + Assert.Equal(8.0, (double?)storageModel["NullableDoubleDataProp"]); + Assert.Equal(true, (bool?)storageModel["BoolDataProp"]); + Assert.Equal(false, (bool?)storageModel["NullableBoolDataProp"]); + Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), (DateTimeOffset?)storageModel["DateTimeOffsetDataProp"]); + Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), (DateTimeOffset?)storageModel["NullableDateTimeOffsetDataProp"]); + Assert.Equal(s_taglist, storageModel["TagListDataProp"]!.AsArray().Select(x => (string)x!).ToArray()); + Assert.Equal(s_vector1, storageModel["FloatVector"]!.AsArray().Select(x => (float)x!).ToArray()); + Assert.Equal(s_vector2, storageModel["NullableFloatVector"]!.AsArray().Select(x => (float)x!).ToArray()); + } + + [Fact] + public void MapFromDataToStorageModelMapsNullValues() + { + // Arrange + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?), 10), + ]); + + var dataModel = new Dictionary + { + ["Key"] = "key", + ["StringDataProp"] = null, + ["NullableIntDataProp"] = null, + ["NullableFloatVector"] = null + }; + + var sut = new AzureAISearchDynamicDataModelMapper(model); + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Null(storageModel["StringDataProp"]); + Assert.Null(storageModel["NullableIntDataProp"]); + Assert.Null(storageModel["NullableFloatVector"]); + } + + [Fact] + public void MapFromStorageToDataModelMapsAllSupportedTypes() + { + // Arrange + var sut = new AzureAISearchDynamicDataModelMapper(s_model); + var storageModel = new JsonObject(); + storageModel["Key"] = "key"; + storageModel["StringDataProp"] = "string"; + storageModel["IntDataProp"] = 1; + storageModel["NullableIntDataProp"] = 2; + storageModel["LongDataProp"] = 3L; + storageModel["NullableLongDataProp"] = 4L; + storageModel["FloatDataProp"] = 5.0f; + storageModel["NullableFloatDataProp"] = 6.0f; + storageModel["DoubleDataProp"] = 7.0; + storageModel["NullableDoubleDataProp"] = 8.0; + storageModel["BoolDataProp"] = true; + storageModel["NullableBoolDataProp"] = false; + storageModel["DateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero); + storageModel["NullableDateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero); + storageModel["TagListDataProp"] = new JsonArray { "tag1", "tag2" }; + storageModel["FloatVector"] = new JsonArray { 1.0f, 2.0f, 3.0f }; + storageModel["NullableFloatVector"] = new JsonArray { 4.0f, 5.0f, 6.0f }; + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal("key", dataModel["Key"]); + Assert.Equal("string", dataModel["StringDataProp"]); + Assert.Equal(1, dataModel["IntDataProp"]); + Assert.Equal(2, dataModel["NullableIntDataProp"]); + Assert.Equal(3L, dataModel["LongDataProp"]); + Assert.Equal(4L, dataModel["NullableLongDataProp"]); + Assert.Equal(5.0f, dataModel["FloatDataProp"]); + Assert.Equal(6.0f, dataModel["NullableFloatDataProp"]); + Assert.Equal(7.0, dataModel["DoubleDataProp"]); + Assert.Equal(8.0, dataModel["NullableDoubleDataProp"]); + Assert.Equal(true, dataModel["BoolDataProp"]); + Assert.Equal(false, dataModel["NullableBoolDataProp"]); + Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel["DateTimeOffsetDataProp"]); + Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel["NullableDateTimeOffsetDataProp"]); + Assert.Equal(s_taglist, dataModel["TagListDataProp"]); + Assert.Equal(s_vector1, ((ReadOnlyMemory)dataModel["FloatVector"]!).ToArray()); + Assert.Equal(s_vector2, ((ReadOnlyMemory)dataModel["NullableFloatVector"]!)!.ToArray()); + } + + [Fact] + public void MapFromStorageToDataModelMapsNullValues() + { + // Arrange + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?), 10), + ]); + + var storageModel = new JsonObject(); + storageModel["Key"] = "key"; + storageModel["StringDataProp"] = null; + storageModel["NullableIntDataProp"] = null; + storageModel["NullableFloatVector"] = null; + + var sut = new AzureAISearchDynamicDataModelMapper(model); + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal("key", dataModel["Key"]); + Assert.Null(dataModel["StringDataProp"]); + Assert.Null(dataModel["NullableIntDataProp"]); + Assert.Null(dataModel["NullableFloatVector"]); + } + + [Fact] + public void MapFromStorageToDataModelThrowsForMissingKey() + { + // Arrange + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?), 10), + ]); + + var sut = new AzureAISearchDynamicDataModelMapper(model); + var storageModel = new JsonObject(); + + // Act + var exception = Assert.Throws(() => sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true })); + + // Assert + Assert.Equal("The key property 'Key' is missing from the record retrieved from storage.", exception.Message); + } + + [Fact] + public void MapFromDataToStorageModelSkipsMissingProperties() + { + // Arrange + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), + ]); + + var dataModel = new Dictionary { ["Key"] = "key" }; + var sut = new AzureAISearchDynamicDataModelMapper(model); + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal("key", (string?)storageModel["Key"]); + Assert.False(storageModel.ContainsKey("StringDataProp")); + Assert.False(storageModel.ContainsKey("FloatVector")); + } + + [Fact] + public void MapFromStorageToDataModelSkipsMissingProperties() + { + // Arrange + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), + ]); + + var storageModel = new JsonObject(); + storageModel["Key"] = "key"; + + var sut = new AzureAISearchDynamicDataModelMapper(model); + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal("key", dataModel["Key"]); + Assert.False(dataModel.ContainsKey("StringDataProp")); + Assert.False(dataModel.ContainsKey("FloatVector")); + } + + private static VectorStoreRecordModel BuildModel(List properties) + => new VectorStoreRecordJsonModelBuilder(AzureAISearchModelBuilder.s_modelBuildingOptions) + .Build( + typeof(Dictionary), + new() { Properties = properties }, + defaultEmbeddingGenerator: null); +} diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchGenericDataModelMapperTests.cs deleted file mode 100644 index 8326be0dd639..000000000000 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchGenericDataModelMapperTests.cs +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text.Json.Nodes; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.AzureAISearch; -using Xunit; - -namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; - -/// -/// Tests for the class. -/// -public class AzureAISearchGenericDataModelMapperTests -{ - private static readonly VectorStoreRecordDefinition s_vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), - new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), - new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), - new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), - new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), - new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), - new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), - new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), - new VectorStoreRecordDataProperty("DateTimeOffsetDataProp", typeof(DateTimeOffset)), - new VectorStoreRecordDataProperty("NullableDateTimeOffsetDataProp", typeof(DateTimeOffset?)), - new VectorStoreRecordDataProperty("TagListDataProp", typeof(string[])), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), - }, - }; - - private static readonly float[] s_vector1 = new float[] { 1.0f, 2.0f, 3.0f }; - private static readonly float[] s_vector2 = new float[] { 4.0f, 5.0f, 6.0f }; - private static readonly string[] s_taglist = new string[] { "tag1", "tag2" }; - - [Fact] - public void MapFromDataToStorageModelMapsAllSupportedTypes() - { - // Arrange - var sut = new AzureAISearchGenericDataModelMapper(s_vectorStoreRecordDefinition); - var dataModel = new VectorStoreGenericDataModel("key") - { - Data = - { - ["StringDataProp"] = "string", - ["IntDataProp"] = 1, - ["NullableIntDataProp"] = 2, - ["LongDataProp"] = 3L, - ["NullableLongDataProp"] = 4L, - ["FloatDataProp"] = 5.0f, - ["NullableFloatDataProp"] = 6.0f, - ["DoubleDataProp"] = 7.0, - ["NullableDoubleDataProp"] = 8.0, - ["BoolDataProp"] = true, - ["NullableBoolDataProp"] = false, - ["DateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), - ["NullableDateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), - ["TagListDataProp"] = s_taglist, - }, - Vectors = - { - ["FloatVector"] = new ReadOnlyMemory(s_vector1), - ["NullableFloatVector"] = new ReadOnlyMemory(s_vector2), - }, - }; - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", (string?)storageModel["Key"]); - Assert.Equal("string", (string?)storageModel["StringDataProp"]); - Assert.Equal(1, (int?)storageModel["IntDataProp"]); - Assert.Equal(2, (int?)storageModel["NullableIntDataProp"]); - Assert.Equal(3L, (long?)storageModel["LongDataProp"]); - Assert.Equal(4L, (long?)storageModel["NullableLongDataProp"]); - Assert.Equal(5.0f, (float?)storageModel["FloatDataProp"]); - Assert.Equal(6.0f, (float?)storageModel["NullableFloatDataProp"]); - Assert.Equal(7.0, (double?)storageModel["DoubleDataProp"]); - Assert.Equal(8.0, (double?)storageModel["NullableDoubleDataProp"]); - Assert.Equal(true, (bool?)storageModel["BoolDataProp"]); - Assert.Equal(false, (bool?)storageModel["NullableBoolDataProp"]); - Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), (DateTimeOffset?)storageModel["DateTimeOffsetDataProp"]); - Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), (DateTimeOffset?)storageModel["NullableDateTimeOffsetDataProp"]); - Assert.Equal(s_taglist, storageModel["TagListDataProp"]!.AsArray().Select(x => (string)x!).ToArray()); - Assert.Equal(s_vector1, storageModel["FloatVector"]!.AsArray().Select(x => (float)x!).ToArray()); - Assert.Equal(s_vector2, storageModel["NullableFloatVector"]!.AsArray().Select(x => (float)x!).ToArray()); - } - - [Fact] - public void MapFromDataToStorageModelMapsNullValues() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), - }, - }; - - var dataModel = new VectorStoreGenericDataModel("key") - { - Data = - { - ["StringDataProp"] = null, - ["NullableIntDataProp"] = null, - }, - Vectors = - { - ["NullableFloatVector"] = null, - }, - }; - - var sut = new AzureAISearchGenericDataModelMapper(vectorStoreRecordDefinition); - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Null(storageModel["StringDataProp"]); - Assert.Null(storageModel["NullableIntDataProp"]); - Assert.Null(storageModel["NullableFloatVector"]); - } - - [Fact] - public void MapFromStorageToDataModelMapsAllSupportedTypes() - { - // Arrange - var sut = new AzureAISearchGenericDataModelMapper(s_vectorStoreRecordDefinition); - var storageModel = new JsonObject(); - storageModel["Key"] = "key"; - storageModel["StringDataProp"] = "string"; - storageModel["IntDataProp"] = 1; - storageModel["NullableIntDataProp"] = 2; - storageModel["LongDataProp"] = 3L; - storageModel["NullableLongDataProp"] = 4L; - storageModel["FloatDataProp"] = 5.0f; - storageModel["NullableFloatDataProp"] = 6.0f; - storageModel["DoubleDataProp"] = 7.0; - storageModel["NullableDoubleDataProp"] = 8.0; - storageModel["BoolDataProp"] = true; - storageModel["NullableBoolDataProp"] = false; - storageModel["DateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero); - storageModel["NullableDateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero); - storageModel["TagListDataProp"] = new JsonArray { "tag1", "tag2" }; - storageModel["FloatVector"] = new JsonArray { 1.0f, 2.0f, 3.0f }; - storageModel["NullableFloatVector"] = new JsonArray { 4.0f, 5.0f, 6.0f }; - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); - - // Assert - Assert.Equal("key", dataModel.Key); - Assert.Equal("string", dataModel.Data["StringDataProp"]); - Assert.Equal(1, dataModel.Data["IntDataProp"]); - Assert.Equal(2, dataModel.Data["NullableIntDataProp"]); - Assert.Equal(3L, dataModel.Data["LongDataProp"]); - Assert.Equal(4L, dataModel.Data["NullableLongDataProp"]); - Assert.Equal(5.0f, dataModel.Data["FloatDataProp"]); - Assert.Equal(6.0f, dataModel.Data["NullableFloatDataProp"]); - Assert.Equal(7.0, dataModel.Data["DoubleDataProp"]); - Assert.Equal(8.0, dataModel.Data["NullableDoubleDataProp"]); - Assert.Equal(true, dataModel.Data["BoolDataProp"]); - Assert.Equal(false, dataModel.Data["NullableBoolDataProp"]); - Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel.Data["DateTimeOffsetDataProp"]); - Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel.Data["NullableDateTimeOffsetDataProp"]); - Assert.Equal(s_taglist, dataModel.Data["TagListDataProp"]); - Assert.Equal(s_vector1, ((ReadOnlyMemory)dataModel.Vectors["FloatVector"]!).ToArray()); - Assert.Equal(s_vector2, ((ReadOnlyMemory)dataModel.Vectors["NullableFloatVector"]!)!.ToArray()); - } - - [Fact] - public void MapFromStorageToDataModelMapsNullValues() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), - }, - }; - - var storageModel = new JsonObject(); - storageModel["Key"] = "key"; - storageModel["StringDataProp"] = null; - storageModel["NullableIntDataProp"] = null; - storageModel["NullableFloatVector"] = null; - - var sut = new AzureAISearchGenericDataModelMapper(vectorStoreRecordDefinition); - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); - - // Assert - Assert.Equal("key", dataModel.Key); - Assert.Null(dataModel.Data["StringDataProp"]); - Assert.Null(dataModel.Data["NullableIntDataProp"]); - Assert.Null(dataModel.Vectors["NullableFloatVector"]); - } - - [Fact] - public void MapFromStorageToDataModelThrowsForMissingKey() - { - // Arrange - var sut = new AzureAISearchGenericDataModelMapper(s_vectorStoreRecordDefinition); - var storageModel = new JsonObject(); - - // Act - var exception = Assert.Throws(() => sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true })); - - // Assert - Assert.Equal("The key property 'Key' is missing from the record retrieved from storage.", exception.Message); - } - - [Fact] - public void MapFromDataToStorageModelSkipsMissingProperties() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var dataModel = new VectorStoreGenericDataModel("key"); - var sut = new AzureAISearchGenericDataModelMapper(vectorStoreRecordDefinition); - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", (string?)storageModel["Key"]); - Assert.False(storageModel.ContainsKey("StringDataProp")); - Assert.False(storageModel.ContainsKey("FloatVector")); - } - - [Fact] - public void MapFromStorageToDataModelSkipsMissingProperties() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var storageModel = new JsonObject(); - storageModel["Key"] = "key"; - - var sut = new AzureAISearchGenericDataModelMapper(vectorStoreRecordDefinition); - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); - - // Assert - Assert.Equal("key", dataModel.Key); - Assert.False(dataModel.Data.ContainsKey("StringDataProp")); - Assert.False(dataModel.Vectors.ContainsKey("FloatVector")); - } -} diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchKernelBuilderExtensionsTests.cs index e7c567b7895c..f2f95374128c 100644 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchKernelBuilderExtensionsTests.cs @@ -105,11 +105,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = kernel.Services.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = kernel.Services.GetRequiredService>(); + var vectorizedSearch = kernel.Services.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchMemoryStoreTests.cs index 0ebda1fc706e..95b0801d23ad 100644 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchMemoryStoreTests.cs @@ -21,6 +21,7 @@ namespace SemanticKernel.Connectors.UnitTests.Memory.AzureAISearch; /// /// Unit tests for class. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public sealed class AzureAISearchMemoryStoreTests { private readonly Mock _mockSearchIndexClient = new(); diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchServiceCollectionExtensionsTests.cs index 0310aa2ca4a2..d1b24704c3f2 100644 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchServiceCollectionExtensionsTests.cs @@ -105,11 +105,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionCreateMappingTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionCreateMappingTests.cs index 24bed31f87ed..8c4aca17be0a 100644 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionCreateMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionCreateMappingTests.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using Azure.Search.Documents.Indexes.Models; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.AzureAISearch; using Xunit; @@ -18,15 +19,14 @@ public class AzureAISearchVectorStoreCollectionCreateMappingTests public void MapKeyFieldCreatesSearchableField() { // Arrange - var keyProperty = new VectorStoreRecordKeyProperty("testkey", typeof(string)); - var storagePropertyName = "test_key"; + var keyProperty = new VectorStoreRecordKeyPropertyModel("testkey", typeof(string)) { StorageName = "test_key" }; // Act - var result = AzureAISearchVectorStoreCollectionCreateMapping.MapKeyField(keyProperty, storagePropertyName); + var result = AzureAISearchVectorStoreCollectionCreateMapping.MapKeyField(keyProperty); // Assert Assert.NotNull(result); - Assert.Equal(storagePropertyName, result.Name); + Assert.Equal("test_key", result.Name); Assert.True(result.IsKey); Assert.True(result.IsFilterable); } @@ -37,16 +37,19 @@ public void MapKeyFieldCreatesSearchableField() public void MapFilterableStringDataFieldCreatesSimpleField(bool isFilterable) { // Arrange - var dataProperty = new VectorStoreRecordDataProperty("testdata", typeof(string)) { IsFilterable = isFilterable }; - var storagePropertyName = "test_data"; + var dataProperty = new VectorStoreRecordDataPropertyModel("testdata", typeof(string)) + { + IsIndexed = isFilterable, + StorageName = "test_data" + }; // Act - var result = AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty, storagePropertyName); + var result = AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty); // Assert Assert.NotNull(result); Assert.IsType(result); - Assert.Equal(storagePropertyName, result.Name); + Assert.Equal("test_data", result.Name); Assert.False(result.IsKey); Assert.Equal(isFilterable, result.IsFilterable); } @@ -57,16 +60,20 @@ public void MapFilterableStringDataFieldCreatesSimpleField(bool isFilterable) public void MapFullTextSearchableStringDataFieldCreatesSearchableField(bool isFilterable) { // Arrange - var dataProperty = new VectorStoreRecordDataProperty("testdata", typeof(string)) { IsFilterable = isFilterable, IsFullTextSearchable = true }; - var storagePropertyName = "test_data"; + var dataProperty = new VectorStoreRecordDataPropertyModel("testdata", typeof(string)) + { + IsIndexed = isFilterable, + IsFullTextIndexed = true, + StorageName = "test_data" + }; // Act - var result = AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty, storagePropertyName); + var result = AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty); // Assert Assert.NotNull(result); Assert.IsType(result); - Assert.Equal(storagePropertyName, result.Name); + Assert.Equal("test_data", result.Name); Assert.False(result.IsKey); Assert.Equal(isFilterable, result.IsFilterable); } @@ -75,11 +82,14 @@ public void MapFullTextSearchableStringDataFieldCreatesSearchableField(bool isFi public void MapFullTextSearchableStringDataFieldThrowsForInvalidType() { // Arrange - var dataProperty = new VectorStoreRecordDataProperty("testdata", typeof(int)) { IsFullTextSearchable = true }; - var storagePropertyName = "test_data"; + var dataProperty = new VectorStoreRecordDataPropertyModel("testdata", typeof(int)) + { + IsFullTextIndexed = true, + StorageName = "test_data" + }; // Act & Assert - Assert.Throws(() => AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty, storagePropertyName)); + Assert.Throws(() => AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty)); } [Theory] @@ -88,16 +98,19 @@ public void MapFullTextSearchableStringDataFieldThrowsForInvalidType() public void MapDataFieldCreatesSimpleField(bool isFilterable) { // Arrange - var dataProperty = new VectorStoreRecordDataProperty("testdata", typeof(int)) { IsFilterable = isFilterable }; - var storagePropertyName = "test_data"; + var dataProperty = new VectorStoreRecordDataPropertyModel("testdata", typeof(int)) + { + IsIndexed = isFilterable, + StorageName = "test_data" + }; // Act - var result = AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty, storagePropertyName); + var result = AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty); // Assert Assert.NotNull(result); Assert.IsType(result); - Assert.Equal(storagePropertyName, result.Name); + Assert.Equal("test_data", result.Name); Assert.Equal(SearchFieldDataType.Int32, result.Type); Assert.False(result.IsKey); Assert.Equal(isFilterable, result.IsFilterable); @@ -107,17 +120,22 @@ public void MapDataFieldCreatesSimpleField(bool isFilterable) public void MapVectorFieldCreatesVectorSearchField() { // Arrange - var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 10, IndexKind = IndexKind.Flat, DistanceFunction = DistanceFunction.DotProductSimilarity }; - var storagePropertyName = "test_vector"; + var vectorProperty = new VectorStoreRecordVectorPropertyModel("testvector", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = IndexKind.Flat, + DistanceFunction = DistanceFunction.DotProductSimilarity, + StorageName = "test_vector" + }; // Act - var (vectorSearchField, algorithmConfiguration, vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty, storagePropertyName); + var (vectorSearchField, algorithmConfiguration, vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty); // Assert Assert.NotNull(vectorSearchField); Assert.NotNull(algorithmConfiguration); Assert.NotNull(vectorSearchProfile); - Assert.Equal(storagePropertyName, vectorSearchField.Name); + Assert.Equal("test_vector", vectorSearchField.Name); Assert.Equal(vectorProperty.Dimensions, vectorSearchField.VectorSearchDimensions); Assert.Equal("test_vectorAlgoConfig", algorithmConfiguration.Name); @@ -135,11 +153,16 @@ public void MapVectorFieldCreatesVectorSearchField() public void MapVectorFieldCreatesExpectedAlgoConfigTypes(string indexKind, Type algoConfigType) { // Arrange - var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 10, IndexKind = indexKind, DistanceFunction = DistanceFunction.DotProductSimilarity }; - var storagePropertyName = "test_vector"; + var vectorProperty = new VectorStoreRecordVectorPropertyModel("testvector", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = indexKind, + DistanceFunction = DistanceFunction.DotProductSimilarity, + StorageName = "test_vector" + }; // Act - var (vectorSearchField, algorithmConfiguration, vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty, storagePropertyName); + var (vectorSearchField, algorithmConfiguration, vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty); // Assert Assert.Equal("test_vectorAlgoConfig", algorithmConfiguration.Name); @@ -150,11 +173,10 @@ public void MapVectorFieldCreatesExpectedAlgoConfigTypes(string indexKind, Type public void MapVectorFieldDefaultsToHsnwAndCosine() { // Arrange - var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 10 }; - var storagePropertyName = "test_vector"; + var vectorProperty = new VectorStoreRecordVectorPropertyModel("testvector", typeof(ReadOnlyMemory)) { Dimensions = 10 }; // Act - var (vectorSearchField, algorithmConfiguration, vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty, storagePropertyName); + var (vectorSearchField, algorithmConfiguration, vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty); // Assert Assert.IsType(algorithmConfiguration); @@ -166,22 +188,14 @@ public void MapVectorFieldDefaultsToHsnwAndCosine() public void MapVectorFieldThrowsForUnsupportedDistanceFunction() { // Arrange - var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 10, DistanceFunction = DistanceFunction.ManhattanDistance }; - var storagePropertyName = "test_vector"; + var vectorProperty = new VectorStoreRecordVectorPropertyModel("testvector", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + DistanceFunction = DistanceFunction.ManhattanDistance, + }; // Act & Assert - Assert.Throws(() => AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty, storagePropertyName)); - } - - [Fact] - public void MapVectorFieldThrowsForMissingDimensionsCount() - { - // Arrange - var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)); - var storagePropertyName = "test_vector"; - - // Act & Assert - Assert.Throws(() => AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty, storagePropertyName)); + Assert.Throws(() => AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty)); } [Theory] diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionSearchMappingTests.cs deleted file mode 100644 index 13216b9ec8be..000000000000 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionSearchMappingTests.cs +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.AzureAISearch; -using Xunit; - -namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; - -#pragma warning disable CS0618 // VectorSearchFilter is obsolete - -/// -/// Contains tests for the class. -/// -public class AzureAISearchVectorStoreCollectionSearchMappingTests -{ - [Theory] - [MemberData(nameof(DataTypeMappingOptions))] - public void BuildFilterStringBuildsCorrectEqualityStringForEachFilterType(string fieldName, object? fieldValue, string expected) - { - // Arrange. - var filter = new VectorSearchFilter().EqualTo(fieldName, fieldValue!); - - // Act. - var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(filter, new Dictionary { { fieldName, "storage_" + fieldName } }); - - // Assert. - Assert.Equal(expected, actual); - } - - [Fact] - public void BuildFilterStringBuildsCorrectTagContainsString() - { - // Arrange. - var filter = new VectorSearchFilter().AnyTagEqualTo("Tags", "mytag"); - - // Act. - var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(filter, new Dictionary { { "Tags", "storage_tags" } }); - - // Assert. - Assert.Equal("storage_tags/any(t: t eq 'mytag')", actual); - } - - [Fact] - public void BuildFilterStringCombinesFilterOptions() - { - // Arrange. - var filter = new VectorSearchFilter().EqualTo("intField", 5).AnyTagEqualTo("Tags", "mytag"); - - // Act. - var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(filter, new Dictionary { { "Tags", "storage_tags" }, { "intField", "storage_intField" } }); - - // Assert. - Assert.Equal("storage_intField eq 5 and storage_tags/any(t: t eq 'mytag')", actual); - } - - [Fact] - public void BuildFilterStringThrowsForUnknownPropertyName() - { - // Act and assert. - Assert.Throws(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(new VectorSearchFilter().EqualTo("unknown", "value"), new Dictionary())); - Assert.Throws(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(new VectorSearchFilter().AnyTagEqualTo("unknown", "value"), new Dictionary())); - } - - public static IEnumerable DataTypeMappingOptions() - { - yield return new object[] { "stringField", "value", "storage_stringField eq 'value'" }; - yield return new object[] { "boolField", true, "storage_boolField eq true" }; - yield return new object[] { "intField", 5, "storage_intField eq 5" }; - yield return new object[] { "longField", 5L, "storage_longField eq 5" }; - yield return new object[] { "floatField", 5.5f, "storage_floatField eq 5.5" }; - yield return new object[] { "doubleField", 5.5d, "storage_doubleField eq 5.5" }; - yield return new object[] { "dateTimeOffSetField", new DateTimeOffset(2000, 10, 20, 5, 55, 55, TimeSpan.Zero), "storage_dateTimeOffSetField eq 2000-10-20T05:55:55.0000000Z" }; - yield return new object[] { "nullField", null!, "storage_nullField eq null" }; - } -} diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs index b786c8d8fa58..d5fb1dbd585c 100644 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Text.Json; -using System.Text.Json.Nodes; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; @@ -23,7 +22,7 @@ namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// -/// Contains tests for the class. +/// Contains tests for the class. /// public class AzureAISearchVectorStoreRecordCollectionTests { @@ -41,6 +40,7 @@ public AzureAISearchVectorStoreRecordCollectionTests() this._searchClientMock = new Mock(MockBehavior.Strict); this._searchIndexClientMock = new Mock(MockBehavior.Strict); this._searchIndexClientMock.Setup(x => x.GetSearchClient(TestCollectionName)).Returns(this._searchClientMock.Object); + this._searchIndexClientMock.Setup(x => x.ServiceName).Returns("TestService"); } [Theory] @@ -64,7 +64,7 @@ public async Task CollectionExistsReturnsCollectionStateAsync(string collectionN .ThrowsAsync(new RequestFailedException(404, "Index not found")); } - var sut = new AzureAISearchVectorStoreRecordCollection(this._searchIndexClientMock.Object, collectionName); + var sut = new AzureAISearchVectorStoreRecordCollection(this._searchIndexClientMock.Object, collectionName); // Act. var actual = await sut.CollectionExistsAsync(this._testCancellationToken); @@ -208,7 +208,7 @@ public async Task CanGetRecordWithoutVectorsAsync(bool useDefinition, bool useCu // Arrange. var storageObject = JsonSerializer.SerializeToNode(CreateModel(TestRecordKey1, false))!.AsObject(); - var expectedSelectFields = useCustomJsonSerializerOptions ? new[] { "key", "storage_data1", "data2" } : new[] { "Key", "storage_data1", "Data2" }; + string[] expectedSelectFields = useCustomJsonSerializerOptions ? ["key", "storage_data1", "data2"] : ["Key", "storage_data1", "Data2"]; this._searchClientMock.Setup( x => x.GetDocumentAsync( TestRecordKey1, @@ -257,7 +257,7 @@ public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition) var sut = this.CreateRecordCollection(useDefinition); // Act. - var actual = await sut.GetBatchAsync( + var actual = await sut.GetAsync( [TestRecordKey1, TestRecordKey2], new() { IncludeVectors = true }, this._testCancellationToken).ToListAsync(); @@ -269,49 +269,6 @@ public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition) Assert.Equal(TestRecordKey2, actual[1].Key); } - [Fact] - public async Task CanGetRecordWithCustomMapperAsync() - { - // Arrange. - var storageObject = JsonSerializer.SerializeToNode(CreateModel(TestRecordKey1, true))!.AsObject(); - - // Arrange GetDocumentAsync mock returning JsonObject. - this._searchClientMock.Setup( - x => x.GetDocumentAsync( - TestRecordKey1, - It.Is(x => !x.SelectedFields.Any()), - this._testCancellationToken)) - .ReturnsAsync(Response.FromValue(storageObject, Mock.Of())); - - // Arrange mapper mock from JsonObject to data model. - var mapperMock = new Mock>(MockBehavior.Strict); - mapperMock.Setup( - x => x.MapFromStorageToDataModel( - storageObject, - It.Is(x => x.IncludeVectors))) - .Returns(CreateModel(TestRecordKey1, true)); - - // Arrange target with custom mapper. - var sut = new AzureAISearchVectorStoreRecordCollection( - this._searchIndexClientMock.Object, - TestCollectionName, - new() - { - JsonObjectCustomMapper = mapperMock.Object - }); - - // Act. - var actual = await sut.GetAsync(TestRecordKey1, new() { IncludeVectors = true }, this._testCancellationToken); - - // Assert. - Assert.NotNull(actual); - Assert.Equal(TestRecordKey1, actual.Key); - Assert.Equal("data 1", actual.Data1); - Assert.Equal("data 2", actual.Data2); - Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector1!.Value.ToArray()); - Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector2!.Value.ToArray()); - } - [Theory] [InlineData(true)] [InlineData(false)] @@ -368,7 +325,7 @@ public async Task CanDeleteManyRecordsWithVectorsAsync(bool useDefinition) var sut = this.CreateRecordCollection(useDefinition); // Act. - await sut.DeleteBatchAsync( + await sut.DeleteAsync( [TestRecordKey1, TestRecordKey2], cancellationToken: this._testCancellationToken); @@ -455,9 +412,9 @@ public async Task CanUpsertManyRecordsAsync(bool useDefinition) var model2 = CreateModel(TestRecordKey2, true); // Act. - var actual = await sut.UpsertBatchAsync( + var actual = await sut.UpsertAsync( [model1, model2], - cancellationToken: this._testCancellationToken).ToListAsync(); + cancellationToken: this._testCancellationToken); // Assert. Assert.NotNull(actual); @@ -473,58 +430,6 @@ public async Task CanUpsertManyRecordsAsync(bool useDefinition) Times.Once); } - [Fact] - public async Task CanUpsertRecordWithCustomMapperAsync() - { - // Arrange. -#pragma warning disable Moq1002 // Moq: No matching constructor - var indexingResult = new Mock(MockBehavior.Strict, TestRecordKey1, true, 200); - var indexingResults = new List(); - indexingResults.Add(indexingResult.Object); - var indexDocumentsResultMock = new Mock(MockBehavior.Strict, indexingResults); -#pragma warning restore Moq1002 // Moq: No matching constructor - - var model = CreateModel(TestRecordKey1, true); - var storageObject = JsonSerializer.SerializeToNode(model)!.AsObject(); - - // Arrange UploadDocumentsAsync mock returning upsert result. - this._searchClientMock.Setup( - x => x.UploadDocumentsAsync( - It.IsAny>(), - It.IsAny(), - this._testCancellationToken)) - .ReturnsAsync((IEnumerable documents, IndexDocumentsOptions options, CancellationToken cancellationToken) => - { - // Need to force a materialization of the documents enumerable here, otherwise the mapper (and therefore its mock) doesn't get invoked. - var materializedDocuments = documents.ToList(); - return Response.FromValue(indexDocumentsResultMock.Object, Mock.Of()); - }); - - // Arrange mapper mock from data model to JsonObject. - var mapperMock = new Mock>(MockBehavior.Strict); - mapperMock - .Setup(x => x.MapFromDataToStorageModel(It.IsAny())) - .Returns(storageObject); - - // Arrange target with custom mapper. - var sut = new AzureAISearchVectorStoreRecordCollection( - this._searchIndexClientMock.Object, - TestCollectionName, - new() - { - JsonObjectCustomMapper = mapperMock.Object - }); - - // Act. - await sut.UpsertAsync(model, this._testCancellationToken); - - // Assert. - mapperMock - .Verify( - x => x.MapFromDataToStorageModel(It.Is(x => x.Key == TestRecordKey1)), - Times.Once); - } - /// /// Tests that the collection can be created even if the definition and the type do not match. /// In this case, the expectation is that a custom mapper will be provided to map between the @@ -538,17 +443,17 @@ public void CanCreateCollectionWithMismatchedDefinitionAndType() { Properties = new List { - new VectorStoreRecordKeyProperty("Id", typeof(string)), - new VectorStoreRecordDataProperty("Text", typeof(string)), - new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { Dimensions = 4 }, + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("Data1", typeof(string)), + new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory), 4), } }; // Act. - var sut = new AzureAISearchVectorStoreRecordCollection( + var sut = new AzureAISearchVectorStoreRecordCollection( this._searchIndexClientMock.Object, TestCollectionName, - new() { VectorStoreRecordDefinition = definition, JsonObjectCustomMapper = Mock.Of>() }); + new() { VectorStoreRecordDefinition = definition }); } [Fact] @@ -562,7 +467,7 @@ public async Task CanSearchWithVectorAndFilterAsync() .Setup(x => x.SearchAsync(null, It.IsAny(), It.IsAny())) .ReturnsAsync(Response.FromValue(searchResultsMock, Mock.Of())); - var sut = new AzureAISearchVectorStoreRecordCollection( + var sut = new AzureAISearchVectorStoreRecordCollection( this._searchIndexClientMock.Object, TestCollectionName); var filter = new VectorSearchFilter().EqualTo(nameof(MultiPropsModel.Data1), "Data1FilterValue"); @@ -570,14 +475,14 @@ public async Task CanSearchWithVectorAndFilterAsync() // Act. var searchResults = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[4]), + top: 5, new() { - Top = 5, Skip = 3, OldFilter = filter, VectorProperty = record => record.Vector1 }, - this._testCancellationToken); + this._testCancellationToken).ToListAsync(); // Assert. this._searchClientMock.Verify( @@ -604,7 +509,7 @@ public async Task CanSearchWithTextAndFilterAsync() .Setup(x => x.SearchAsync(null, It.IsAny(), It.IsAny())) .ReturnsAsync(Response.FromValue(searchResultsMock, Mock.Of())); - var sut = new AzureAISearchVectorStoreRecordCollection( + var sut = new AzureAISearchVectorStoreRecordCollection( this._searchIndexClientMock.Object, TestCollectionName); var filter = new VectorSearchFilter().EqualTo(nameof(MultiPropsModel.Data1), "Data1FilterValue"); @@ -612,14 +517,14 @@ public async Task CanSearchWithTextAndFilterAsync() // Act. var searchResults = await sut.VectorizableTextSearchAsync( "search string", + top: 5, new() { - Top = 5, Skip = 3, OldFilter = filter, VectorProperty = record => record.Vector1 }, - this._testCancellationToken); + this._testCancellationToken).ToListAsync(); // Assert. this._searchClientMock.Verify( @@ -636,9 +541,9 @@ public async Task CanSearchWithTextAndFilterAsync() Times.Once); } - private AzureAISearchVectorStoreRecordCollection CreateRecordCollection(bool useDefinition, bool useCustomJsonSerializerOptions = false) + private AzureAISearchVectorStoreRecordCollection CreateRecordCollection(bool useDefinition, bool useCustomJsonSerializerOptions = false) { - return new AzureAISearchVectorStoreRecordCollection( + return new AzureAISearchVectorStoreRecordCollection( this._searchIndexClientMock.Object, TestCollectionName, new() @@ -673,8 +578,8 @@ private static MultiPropsModel CreateModel(string key, bool withVectors) new VectorStoreRecordKeyProperty("Key", typeof(string)), new VectorStoreRecordDataProperty("Data1", typeof(string)), new VectorStoreRecordDataProperty("Data2", typeof(string)), - new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory)) { Dimensions = 4 }, - new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory)) { Dimensions = 4 } + new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory), 4), + new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory), 4) ] }; @@ -684,7 +589,7 @@ public sealed class MultiPropsModel public string Key { get; set; } = string.Empty; [JsonPropertyName("storage_data1")] - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string Data1 { get; set; } = string.Empty; [VectorStoreRecordData] diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreTests.cs index b79b048a5f38..17e9dff36e5a 100644 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreTests.cs @@ -32,6 +32,7 @@ public AzureAISearchVectorStoreTests() this._searchClientMock = new Mock(MockBehavior.Strict); this._searchIndexClientMock = new Mock(MockBehavior.Strict); this._searchIndexClientMock.Setup(x => x.GetSearchClient(TestCollectionName)).Returns(this._searchClientMock.Object); + this._searchIndexClientMock.Setup(x => x.ServiceName).Returns("TestService"); } [Fact] @@ -45,7 +46,7 @@ public void GetCollectionReturnsCollection() // Assert. Assert.NotNull(actual); - Assert.IsType>(actual); + Assert.IsType>(actual); } #pragma warning disable CS0618 // IAzureAISearchVectorStoreRecordCollectionFactory is obsolete diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/Connectors.AzureAISearch.UnitTests.csproj b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/Connectors.AzureAISearch.UnitTests.csproj index 8583008891e7..27d2f0811843 100644 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/Connectors.AzureAISearch.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/Connectors.AzureAISearch.UnitTests.csproj @@ -8,7 +8,8 @@ enable disable false - $(NoWarn);SKEXP0001,SKEXP0020 + $(NoWarn);SKEXP0001 + $(NoWarn);MEVD9001 diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBHotelModel.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBHotelModel.cs index 7fe5e3875fb8..f8c541aa20ca 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBHotelModel.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBHotelModel.cs @@ -14,7 +14,7 @@ public class AzureCosmosDBMongoDBHotelModel(string hotelId) public string HotelId { get; init; } = hotelId; /// A string metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string? HotelName { get; set; } /// An int metadata field. @@ -39,6 +39,6 @@ public class AzureCosmosDBMongoDBHotelModel(string hotelId) public string? Description { get; set; } /// A vector field. - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance, IndexKind: IndexKind.IvfFlat)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance, IndexKind = IndexKind.IvfFlat)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBKernelBuilderExtensionsTests.cs index 41151c77eba0..46bc932fe707 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBKernelBuilderExtensionsTests.cs @@ -82,11 +82,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = kernel.Services.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = kernel.Services.GetRequiredService>(); + var vectorizedSearch = kernel.Services.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBServiceCollectionExtensionsTests.cs index 9484be5ba373..996add717588 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBServiceCollectionExtensionsTests.cs @@ -82,11 +82,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs index 9dee844e61d2..4f16055088b1 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs @@ -3,7 +3,9 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +using Microsoft.SemanticKernel.Connectors.MongoDB; using MongoDB.Bson; using Xunit; @@ -16,11 +18,18 @@ namespace SemanticKernel.Connectors.AzureCosmosDBMongoDB.UnitTests; /// public sealed class AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests { - private readonly Dictionary _storagePropertyNames = new() - { - ["Property1"] = "property_1", - ["Property2"] = "property_2", - }; + private readonly VectorStoreRecordModel _model = new MongoDBModelBuilder() + .Build( + typeof(Dictionary), + new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Property1", typeof(string)) { StoragePropertyName = "property_1" }, + new VectorStoreRecordDataProperty("Property2", typeof(string)) { StoragePropertyName = "property_2" } + ] + }, + defaultEmbeddingGenerator: null); [Fact] public void BuildFilterWithNullVectorSearchFilterReturnsNull() @@ -29,7 +38,7 @@ public void BuildFilterWithNullVectorSearchFilterReturnsNull() VectorSearchFilter? vectorSearchFilter = null; // Act - var filter = AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); + var filter = AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._model); // Assert Assert.Null(filter); @@ -42,7 +51,7 @@ public void BuildFilterWithoutFilterClausesReturnsNull() VectorSearchFilter vectorSearchFilter = new(); // Act - var filter = AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); + var filter = AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._model); // Assert Assert.Null(filter); @@ -55,7 +64,7 @@ public void BuildFilterThrowsExceptionWithUnsupportedFilterClause() var vectorSearchFilter = new VectorSearchFilter().AnyTagEqualTo("NonExistentProperty", "TestValue"); // Act & Assert - Assert.Throws(() => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._model)); } [Fact] @@ -65,7 +74,7 @@ public void BuildFilterThrowsExceptionWithNonExistentPropertyName() var vectorSearchFilter = new VectorSearchFilter().EqualTo("NonExistentProperty", "TestValue"); // Act & Assert - Assert.Throws(() => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._model)); } [Fact] @@ -77,7 +86,7 @@ public void BuildFilterThrowsExceptionWithMultipleFilterClausesOfSameType() .EqualTo("Property1", "TestValue2"); // Act & Assert - Assert.Throws(() => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._model)); } [Fact] @@ -88,8 +97,15 @@ public void BuilderFilterByDefaultReturnsValidFilter() var vectorSearchFilter = new VectorSearchFilter().EqualTo("Property1", "TestValue1"); // Act - var filter = AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); + var filter = AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._model); - Assert.Equal(filter.ToJson(), expectedFilter.ToJson()); + Assert.Equal(expectedFilter.ToJson(), filter.ToJson()); } + + private static VectorStoreRecordModel BuildModel(List properties) + => new MongoDBModelBuilder() + .Build( + typeof(Dictionary), + new() { Properties = properties }, + defaultEmbeddingGenerator: null); } diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs index 60e2584bf754..4e1fe0fd1e35 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs @@ -19,7 +19,7 @@ namespace SemanticKernel.Connectors.AzureCosmosDBMongoDB.UnitTests; /// -/// Unit tests for class. +/// Unit tests for class. /// public sealed class AzureCosmosDBMongoDBVectorStoreRecordCollectionTests { @@ -37,7 +37,7 @@ public AzureCosmosDBMongoDBVectorStoreRecordCollectionTests() public void ConstructorForModelWithoutKeyThrowsException() { // Act & Assert - var exception = Assert.Throws(() => new AzureCosmosDBMongoDBVectorStoreRecordCollection(this._mockMongoDatabase.Object, "collection")); + var exception = Assert.Throws(() => new AzureCosmosDBMongoDBVectorStoreRecordCollection(this._mockMongoDatabase.Object, "collection")); Assert.Contains("No key property found", exception.Message); } @@ -45,7 +45,7 @@ public void ConstructorForModelWithoutKeyThrowsException() public void ConstructorWithDeclarativeModelInitializesCollection() { // Act & Assert - var collection = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var collection = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -62,7 +62,7 @@ public void ConstructorWithImperativeModelInitializesCollection() }; // Act - var collection = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var collection = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection", new() { VectorStoreRecordDefinition = definition }); @@ -90,7 +90,7 @@ public async Task CollectionExistsReturnsValidResultAsync(List collectio .Setup(l => l.ListCollectionNamesAsync(It.IsAny(), It.IsAny())) .ReturnsAsync(mockCursor.Object); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, collectionName); @@ -144,7 +144,7 @@ public async Task CreateCollectionInvokesValidMethodsAsync(bool indexExists, int .Setup(l => l.ListCollectionNamesAsync(It.IsAny(), It.IsAny())) .ReturnsAsync(mockCursor.Object); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(this._mockMongoDatabase.Object, CollectionName); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(this._mockMongoDatabase.Object, CollectionName); // Act await sut.CreateCollectionAsync(); @@ -207,7 +207,7 @@ public async Task CreateCollectionIfNotExistsInvokesValidMethodsAsync() .Setup(l => l.Indexes) .Returns(mockMongoIndexManager.Object); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, CollectionName); @@ -231,7 +231,7 @@ public async Task DeleteInvokesValidMethodsAsync() // Arrange const string RecordKey = "key"; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -255,7 +255,7 @@ public async Task DeleteBatchInvokesValidMethodsAsync() // Arrange List recordKeys = ["key1", "key2"]; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -264,7 +264,7 @@ public async Task DeleteBatchInvokesValidMethodsAsync() var expectedDefinition = Builders.Filter.In(document => document["_id"].AsString, recordKeys); // Act - await sut.DeleteBatchAsync(recordKeys); + await sut.DeleteAsync(recordKeys); // Assert this._mockMongoCollection.Verify(l => l.DeleteManyAsync( @@ -279,7 +279,7 @@ public async Task DeleteCollectionInvokesValidMethodsAsync() // Arrange const string CollectionName = "collection"; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, CollectionName); @@ -316,7 +316,7 @@ public async Task GetReturnsValidRecordAsync() It.IsAny())) .ReturnsAsync(mockCursor.Object); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -354,12 +354,12 @@ public async Task GetBatchReturnsValidRecordAsync() It.IsAny())) .ReturnsAsync(mockCursor.Object); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); // Act - var results = await sut.GetBatchAsync(["key1", "key2", "key3"]).ToListAsync(); + var results = await sut.GetAsync(["key1", "key2", "key3"]).ToListAsync(); // Assert Assert.NotNull(results[0]); @@ -385,7 +385,7 @@ public async Task UpsertReturnsRecordKeyAsync() var documentSerializer = serializerRegistry.GetSerializer(); var expectedDefinition = Builders.Filter.Eq(document => document["_id"], "key"); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -413,12 +413,12 @@ public async Task UpsertBatchReturnsRecordKeysAsync() var hotel2 = new AzureCosmosDBMongoDBHotelModel("key2") { HotelName = "Test Name 2" }; var hotel3 = new AzureCosmosDBMongoDBHotelModel("key3") { HotelName = "Test Name 3" }; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); // Act - var results = await sut.UpsertBatchAsync([hotel1, hotel2, hotel3]).ToListAsync(); + var results = await sut.UpsertAsync([hotel1, hotel2, hotel3]); // Assert Assert.NotNull(results); @@ -489,110 +489,32 @@ await this.TestUpsertWithModelAsync( expectedPropertyName: "bson_hotel_name"); } - [Fact] - public async Task UpsertWithCustomMapperWorksCorrectlyAsync() - { - // Arrange - var hotel = new AzureCosmosDBMongoDBHotelModel("key") { HotelName = "Test Name" }; - - var mockMapper = new Mock>(); - - mockMapper - .Setup(l => l.MapFromDataToStorageModel(It.IsAny())) - .Returns(new BsonDocument { ["_id"] = "key", ["my_name"] = "Test Name" }); - - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( - this._mockMongoDatabase.Object, - "collection", - new() { BsonDocumentCustomMapper = mockMapper.Object }); - - // Act - var result = await sut.UpsertAsync(hotel); - - // Assert - Assert.Equal("key", result); - - this._mockMongoCollection.Verify(l => l.ReplaceOneAsync( - It.IsAny>(), - It.Is(document => - document["_id"] == "key" && - document["my_name"] == "Test Name"), - It.IsAny(), - It.IsAny()), Times.Once()); - } - - [Fact] - public async Task GetWithCustomMapperWorksCorrectlyAsync() - { - // Arrange - const string RecordKey = "key"; - - var document = new BsonDocument { ["_id"] = RecordKey, ["my_name"] = "Test Name" }; - - var mockCursor = new Mock>(); - mockCursor - .Setup(l => l.MoveNextAsync(It.IsAny())) - .ReturnsAsync(true); - - mockCursor - .Setup(l => l.Current) - .Returns([document]); - - this._mockMongoCollection - .Setup(l => l.FindAsync( - It.IsAny>(), - It.IsAny>(), - It.IsAny())) - .ReturnsAsync(mockCursor.Object); - - var mockMapper = new Mock>(); - - mockMapper - .Setup(l => l.MapFromStorageToDataModel(It.IsAny(), It.IsAny())) - .Returns(new AzureCosmosDBMongoDBHotelModel(RecordKey) { HotelName = "Name from mapper" }); - - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( - this._mockMongoDatabase.Object, - "collection", - new() { BsonDocumentCustomMapper = mockMapper.Object }); - - // Act - var result = await sut.GetAsync(RecordKey); - - // Assert - Assert.NotNull(result); - Assert.Equal(RecordKey, result.HotelId); - Assert.Equal("Name from mapper", result.HotelName); - } - [Theory] - [MemberData(nameof(VectorizedSearchVectorTypeData))] - public async Task VectorizedSearchThrowsExceptionWithInvalidVectorTypeAsync(object vector, bool exceptionExpected) + [MemberData(nameof(SearchEmbeddingVectorTypeData))] + public async Task SearchEmbeddingThrowsExceptionWithInvalidVectorTypeAsync(object vector, bool exceptionExpected) { // Arrange this.MockCollectionForSearch(); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); // Act & Assert if (exceptionExpected) { - await Assert.ThrowsAsync(async () => await sut.VectorizedSearchAsync(vector)); + await Assert.ThrowsAsync(async () => await sut.SearchEmbeddingAsync(vector, top: 3).ToListAsync()); } else { - var actual = await sut.VectorizedSearchAsync(vector); - - Assert.NotNull(actual); + Assert.NotNull(await sut.SearchEmbeddingAsync(vector, top: 3).FirstOrDefaultAsync()); } } [Theory] [InlineData("TestEmbedding1", "TestEmbedding1", 3, 3)] [InlineData("TestEmbedding2", "test_embedding_2", 4, 4)] - public async Task VectorizedSearchUsesValidQueryAsync( + public async Task SearchEmbeddingUsesValidQueryAsync( string? vectorPropertyName, string expectedVectorPropertyName, int actualTop, @@ -632,7 +554,7 @@ public async Task VectorizedSearchUsesValidQueryAsync( this.MockCollectionForSearch(); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -644,14 +566,13 @@ public async Task VectorizedSearchUsesValidQueryAsync( }; // Act - var actual = await sut.VectorizedSearchAsync(vector, new() + var actual = await sut.SearchEmbeddingAsync(vector, top: actualTop, new() { VectorProperty = vectorSelector, - Top = actualTop, - }); + }).FirstOrDefaultAsync(); // Assert - Assert.NotNull(await actual.Results.FirstOrDefaultAsync()); + Assert.NotNull(actual); this._mockMongoCollection.Verify(l => l.AggregateAsync( It.Is>(pipeline => @@ -661,36 +582,35 @@ public async Task VectorizedSearchUsesValidQueryAsync( } [Fact] - public async Task VectorizedSearchThrowsExceptionWithNonExistentVectorPropertyNameAsync() + public async Task SearchEmbeddingThrowsExceptionWithNonExistentVectorPropertyNameAsync() { // Arrange this.MockCollectionForSearch(); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); var options = new MEVD.VectorSearchOptions { VectorProperty = r => "non-existent-property" }; // Act & Assert - await Assert.ThrowsAsync(async () => await (await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f]), options)).Results.FirstOrDefaultAsync()); + await Assert.ThrowsAsync(async () => await sut.SearchEmbeddingAsync(new ReadOnlyMemory([1f, 2f, 3f]), top: 3, options).FirstOrDefaultAsync()); } [Fact] - public async Task VectorizedSearchReturnsRecordWithScoreAsync() + public async Task SearchEmbeddingReturnsRecordWithScoreAsync() { // Arrange this.MockCollectionForSearch(); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f])); + var result = await sut.SearchEmbeddingAsync(new ReadOnlyMemory([1f, 2f, 3f]), top: 3).FirstOrDefaultAsync(); // Assert - var result = await actual.Results.FirstOrDefaultAsync(); Assert.NotNull(result); Assert.Equal("key", result.Record.HotelId); Assert.Equal("Test Name", result.Record.HotelName); @@ -709,7 +629,7 @@ public async Task VectorizedSearchReturnsRecordWithScoreAsync() { [], 1 } }; - public static TheoryData VectorizedSearchVectorTypeData => new() + public static TheoryData SearchEmbeddingVectorTypeData => new() { { new ReadOnlyMemory([1f, 2f, 3f]), false }, { new ReadOnlyMemory([1f, 2f, 3f]), false }, @@ -772,7 +692,7 @@ private async Task TestUpsertWithModelAsync( new() { VectorStoreRecordDefinition = definition } : null; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection", options); @@ -861,11 +781,11 @@ private sealed class VectorSearchModel [VectorStoreRecordData] public string? HotelName { get; set; } - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance, IndexKind: IndexKind.IvfFlat, StoragePropertyName = "test_embedding_1")] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance, IndexKind = IndexKind.IvfFlat, StoragePropertyName = "test_embedding_1")] public ReadOnlyMemory TestEmbedding1 { get; set; } [BsonElement("test_embedding_2")] - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance, IndexKind: IndexKind.IvfFlat)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance, IndexKind = IndexKind.IvfFlat)] public ReadOnlyMemory TestEmbedding2 { get; set; } } #pragma warning restore CA1812 diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/Connectors.AzureCosmosDBMongoDB.UnitTests.csproj b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/Connectors.AzureCosmosDBMongoDB.UnitTests.csproj index a31e4b802b52..21b4d379162f 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/Connectors.AzureCosmosDBMongoDB.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/Connectors.AzureCosmosDBMongoDB.UnitTests.csproj @@ -8,7 +8,7 @@ enable disable false - $(NoWarn);SKEXP0001,SKEXP0020 + $(NoWarn);SKEXP0001 diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLDynamicDataModelMapperTests.cs similarity index 50% rename from dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLGenericDataModelMapperTests.cs rename to dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLDynamicDataModelMapperTests.cs index cbfc8e57f131..3b3ef3f9eb94 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLGenericDataModelMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLDynamicDataModelMapperTests.cs @@ -6,58 +6,51 @@ using System.Text.Json; using System.Text.Json.Nodes; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; using Xunit; namespace SemanticKernel.Connectors.AzureCosmosDBNoSQL.UnitTests; /// -/// Unit tests for class. +/// Unit tests for class. /// -public sealed class AzureCosmosDBNoSQLGenericDataModelMapperTests +public sealed class AzureCosmosDBNoSQLDynamicDataModelMapperTests { private static readonly JsonSerializerOptions s_jsonSerializerOptions = JsonSerializerOptions.Default; - private static readonly VectorStoreRecordDefinition s_vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), - new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), - new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), - new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), - new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), - new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), - new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), - new VectorStoreRecordDataProperty("DateTimeOffsetDataProp", typeof(DateTimeOffset)), - new VectorStoreRecordDataProperty("NullableDateTimeOffsetDataProp", typeof(DateTimeOffset?)), - new VectorStoreRecordDataProperty("TagListDataProp", typeof(List)), -#if NET5_0_OR_GREATER - new VectorStoreRecordVectorProperty("HalfVector", typeof(ReadOnlyMemory)), - new VectorStoreRecordVectorProperty("NullableHalfVector", typeof(ReadOnlyMemory?)), -#endif - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), - new VectorStoreRecordVectorProperty("ByteVector", typeof(ReadOnlyMemory)), - new VectorStoreRecordVectorProperty("NullableByteVector", typeof(ReadOnlyMemory?)), - new VectorStoreRecordVectorProperty("SByteVector", typeof(ReadOnlyMemory)), - new VectorStoreRecordVectorProperty("NullableSByteVector", typeof(ReadOnlyMemory?)), - }, - }; - - private static readonly Dictionary s_storagePropertyNames = - s_vectorStoreRecordDefinition.Properties.ToDictionary( - k => k.DataModelPropertyName, - v => v is VectorStoreRecordKeyProperty ? "id" : v.DataModelPropertyName); - -#if NET5_0_OR_GREATER - private static readonly Half[] s_halfVector = [(Half)1.0f, (Half)2.0f, (Half)3.0f]; -#endif + private static readonly VectorStoreRecordModel s_model = new AzureCosmosDBNoSQLVectorStoreModelBuilder() + .Build( + typeof(Dictionary), + new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), + new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), + new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), + new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), + new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), + new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), + new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), + new VectorStoreRecordDataProperty("DateTimeOffsetDataProp", typeof(DateTimeOffset)), + new VectorStoreRecordDataProperty("NullableDateTimeOffsetDataProp", typeof(DateTimeOffset?)), + new VectorStoreRecordDataProperty("TagListDataProp", typeof(List)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?), 10), + new VectorStoreRecordVectorProperty("ByteVector", typeof(ReadOnlyMemory), 10), + new VectorStoreRecordVectorProperty("NullableByteVector", typeof(ReadOnlyMemory?), 10), + new VectorStoreRecordVectorProperty("SByteVector", typeof(ReadOnlyMemory), 10), + new VectorStoreRecordVectorProperty("NullableSByteVector", typeof(ReadOnlyMemory?), 10), + }, + }, + defaultEmbeddingGenerator: null); + private static readonly float[] s_floatVector = [1.0f, 2.0f, 3.0f]; private static readonly byte[] s_byteVector = [1, 2, 3]; private static readonly sbyte[] s_sbyteVector = [1, 2, 3]; @@ -67,47 +60,37 @@ public sealed class AzureCosmosDBNoSQLGenericDataModelMapperTests public void MapFromDataToStorageModelMapsAllSupportedTypes() { // Arrange - var sut = new AzureCosmosDBNoSQLGenericDataModelMapper( - s_vectorStoreRecordDefinition.Properties, - s_storagePropertyNames, - s_jsonSerializerOptions); + var sut = new AzureCosmosDBNoSQLDynamicDataModelMapper(s_model, s_jsonSerializerOptions); - var dataModel = new VectorStoreGenericDataModel("key") + var dataModel = new Dictionary { - Data = - { - ["BoolDataProp"] = true, - ["NullableBoolDataProp"] = false, - ["StringDataProp"] = "string", - ["IntDataProp"] = 1, - ["NullableIntDataProp"] = 2, - ["LongDataProp"] = 3L, - ["NullableLongDataProp"] = 4L, - ["FloatDataProp"] = 5.0f, - ["NullableFloatDataProp"] = 6.0f, - ["DoubleDataProp"] = 7.0, - ["NullableDoubleDataProp"] = 8.0, - ["DateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), - ["NullableDateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), - ["TagListDataProp"] = s_taglist, - }, - Vectors = - { -#if NET5_0_OR_GREATER - ["HalfVector"] = new ReadOnlyMemory(s_halfVector), - ["NullableHalfVector"] = new ReadOnlyMemory(s_halfVector), -#endif - ["FloatVector"] = new ReadOnlyMemory(s_floatVector), - ["NullableFloatVector"] = new ReadOnlyMemory(s_floatVector), - ["ByteVector"] = new ReadOnlyMemory(s_byteVector), - ["NullableByteVector"] = new ReadOnlyMemory(s_byteVector), - ["SByteVector"] = new ReadOnlyMemory(s_sbyteVector), - ["NullableSByteVector"] = new ReadOnlyMemory(s_sbyteVector) - }, + ["Key"] = "key", + + ["BoolDataProp"] = true, + ["NullableBoolDataProp"] = false, + ["StringDataProp"] = "string", + ["IntDataProp"] = 1, + ["NullableIntDataProp"] = 2, + ["LongDataProp"] = 3L, + ["NullableLongDataProp"] = 4L, + ["FloatDataProp"] = 5.0f, + ["NullableFloatDataProp"] = 6.0f, + ["DoubleDataProp"] = 7.0, + ["NullableDoubleDataProp"] = 8.0, + ["DateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), + ["NullableDateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), + ["TagListDataProp"] = s_taglist, + + ["FloatVector"] = new ReadOnlyMemory(s_floatVector), + ["NullableFloatVector"] = new ReadOnlyMemory(s_floatVector), + ["ByteVector"] = new ReadOnlyMemory(s_byteVector), + ["NullableByteVector"] = new ReadOnlyMemory(s_byteVector), + ["SByteVector"] = new ReadOnlyMemory(s_sbyteVector), + ["NullableSByteVector"] = new ReadOnlyMemory(s_sbyteVector) }; // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); + var storageModel = sut.MapFromDataToStorageModel(dataModel, generatedEmbeddings: null); // Assert Assert.Equal("key", (string?)storageModel["id"]); @@ -125,10 +108,6 @@ public void MapFromDataToStorageModelMapsAllSupportedTypes() Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), (DateTimeOffset?)storageModel["DateTimeOffsetDataProp"]); Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), (DateTimeOffset?)storageModel["NullableDateTimeOffsetDataProp"]); Assert.Equal(s_taglist, storageModel["TagListDataProp"]!.AsArray().GetValues().ToArray()); -#if NET5_0_OR_GREATER - Assert.Equal(s_halfVector, storageModel["HalfVector"]!.AsArray().Select(l => (Half)(float)l!).ToArray()); - Assert.Equal(s_halfVector, storageModel["NullableHalfVector"]!.AsArray().Select(l => (Half)(float)l!).ToArray()); -#endif Assert.Equal(s_floatVector, storageModel["FloatVector"]!.AsArray().GetValues().ToArray()); Assert.Equal(s_floatVector, storageModel["NullableFloatVector"]!.AsArray().GetValues().ToArray()); Assert.Equal(s_byteVector, storageModel["ByteVector"]!.AsArray().GetValues().ToArray()); @@ -148,30 +127,22 @@ public void MapFromDataToStorageModelMapsNullValues() new VectorStoreRecordKeyProperty("Key", typeof(string)), new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?), 10), }, }; - var dataModel = new VectorStoreGenericDataModel("key") + var dataModel = new Dictionary { - Data = - { - ["StringDataProp"] = null, - ["NullableIntDataProp"] = null, - }, - Vectors = - { - ["NullableFloatVector"] = null, - }, + ["Key"] = "key", + ["StringDataProp"] = null, + ["NullableIntDataProp"] = null, + ["NullableFloatVector"] = null }; - var sut = new AzureCosmosDBNoSQLGenericDataModelMapper( - s_vectorStoreRecordDefinition.Properties, - s_storagePropertyNames, - s_jsonSerializerOptions); + var sut = new AzureCosmosDBNoSQLDynamicDataModelMapper(s_model, s_jsonSerializerOptions); // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); + var storageModel = sut.MapFromDataToStorageModel(dataModel, generatedEmbeddings: null); // Assert Assert.Null(storageModel["StringDataProp"]); @@ -183,10 +154,7 @@ public void MapFromDataToStorageModelMapsNullValues() public void MapFromStorageToDataModelMapsAllSupportedTypes() { // Arrange - var sut = new AzureCosmosDBNoSQLGenericDataModelMapper( - s_vectorStoreRecordDefinition.Properties, - s_storagePropertyNames, - s_jsonSerializerOptions); + var sut = new AzureCosmosDBNoSQLDynamicDataModelMapper(s_model, s_jsonSerializerOptions); var storageModel = new JsonObject { @@ -205,10 +173,6 @@ public void MapFromStorageToDataModelMapsAllSupportedTypes() ["DateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), ["NullableDateTimeOffsetDataProp"] = new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), ["TagListDataProp"] = new JsonArray(s_taglist.Select(l => (JsonValue)l).ToArray()), -#if NET5_0_OR_GREATER - ["HalfVector"] = new JsonArray(s_halfVector.Select(l => (JsonValue)(float)l).ToArray()), - ["NullableHalfVector"] = new JsonArray(s_halfVector.Select(l => (JsonValue)(float)l).ToArray()), -#endif ["FloatVector"] = new JsonArray(s_floatVector.Select(l => (JsonValue)l).ToArray()), ["NullableFloatVector"] = new JsonArray(s_floatVector.Select(l => (JsonValue)l).ToArray()), ["ByteVector"] = new JsonArray(s_byteVector.Select(l => (JsonValue)l).ToArray()), @@ -221,31 +185,27 @@ public void MapFromStorageToDataModelMapsAllSupportedTypes() var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); // Assert - Assert.Equal("key", dataModel.Key); - Assert.Equal(true, dataModel.Data["BoolDataProp"]); - Assert.Equal(false, dataModel.Data["NullableBoolDataProp"]); - Assert.Equal("string", dataModel.Data["StringDataProp"]); - Assert.Equal(1, dataModel.Data["IntDataProp"]); - Assert.Equal(2, dataModel.Data["NullableIntDataProp"]); - Assert.Equal(3L, dataModel.Data["LongDataProp"]); - Assert.Equal(4L, dataModel.Data["NullableLongDataProp"]); - Assert.Equal(5.0f, dataModel.Data["FloatDataProp"]); - Assert.Equal(6.0f, dataModel.Data["NullableFloatDataProp"]); - Assert.Equal(7.0, dataModel.Data["DoubleDataProp"]); - Assert.Equal(8.0, dataModel.Data["NullableDoubleDataProp"]); - Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel.Data["DateTimeOffsetDataProp"]); - Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel.Data["NullableDateTimeOffsetDataProp"]); - Assert.Equal(s_taglist, dataModel.Data["TagListDataProp"]); -#if NET5_0_OR_GREATER - Assert.Equal(s_halfVector, ((ReadOnlyMemory)dataModel.Vectors["HalfVector"]!).ToArray()); - Assert.Equal(s_halfVector, ((ReadOnlyMemory)dataModel.Vectors["NullableHalfVector"]!)!.ToArray()); -#endif - Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel.Vectors["FloatVector"]!).ToArray()); - Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel.Vectors["NullableFloatVector"]!)!.ToArray()); - Assert.Equal(s_byteVector, ((ReadOnlyMemory)dataModel.Vectors["ByteVector"]!).ToArray()); - Assert.Equal(s_byteVector, ((ReadOnlyMemory)dataModel.Vectors["NullableByteVector"]!)!.ToArray()); - Assert.Equal(s_sbyteVector, ((ReadOnlyMemory)dataModel.Vectors["SByteVector"]!).ToArray()); - Assert.Equal(s_sbyteVector, ((ReadOnlyMemory)dataModel.Vectors["NullableSByteVector"]!)!.ToArray()); + Assert.Equal("key", dataModel["Key"]); + Assert.Equal(true, dataModel["BoolDataProp"]); + Assert.Equal(false, dataModel["NullableBoolDataProp"]); + Assert.Equal("string", dataModel["StringDataProp"]); + Assert.Equal(1, dataModel["IntDataProp"]); + Assert.Equal(2, dataModel["NullableIntDataProp"]); + Assert.Equal(3L, dataModel["LongDataProp"]); + Assert.Equal(4L, dataModel["NullableLongDataProp"]); + Assert.Equal(5.0f, dataModel["FloatDataProp"]); + Assert.Equal(6.0f, dataModel["NullableFloatDataProp"]); + Assert.Equal(7.0, dataModel["DoubleDataProp"]); + Assert.Equal(8.0, dataModel["NullableDoubleDataProp"]); + Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel["DateTimeOffsetDataProp"]); + Assert.Equal(new DateTimeOffset(2021, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel["NullableDateTimeOffsetDataProp"]); + Assert.Equal(s_taglist, dataModel["TagListDataProp"]); + Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel["FloatVector"]!).ToArray()); + Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel["NullableFloatVector"]!)!.ToArray()); + Assert.Equal(s_byteVector, ((ReadOnlyMemory)dataModel["ByteVector"]!).ToArray()); + Assert.Equal(s_byteVector, ((ReadOnlyMemory)dataModel["NullableByteVector"]!)!.ToArray()); + Assert.Equal(s_sbyteVector, ((ReadOnlyMemory)dataModel["SByteVector"]!).ToArray()); + Assert.Equal(s_sbyteVector, ((ReadOnlyMemory)dataModel["NullableSByteVector"]!)!.ToArray()); } [Fact] @@ -259,7 +219,7 @@ public void MapFromStorageToDataModelMapsNullValues() new VectorStoreRecordKeyProperty("Key", typeof(string)), new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?), 10), }, }; @@ -271,29 +231,23 @@ public void MapFromStorageToDataModelMapsNullValues() ["NullableFloatVector"] = null }; - var sut = new AzureCosmosDBNoSQLGenericDataModelMapper( - s_vectorStoreRecordDefinition.Properties, - s_storagePropertyNames, - s_jsonSerializerOptions); + var sut = new AzureCosmosDBNoSQLDynamicDataModelMapper(s_model, s_jsonSerializerOptions); // Act var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); // Assert - Assert.Equal("key", dataModel.Key); - Assert.Null(dataModel.Data["StringDataProp"]); - Assert.Null(dataModel.Data["NullableIntDataProp"]); - Assert.Null(dataModel.Vectors["NullableFloatVector"]); + Assert.Equal("key", dataModel["Key"]); + Assert.Null(dataModel["StringDataProp"]); + Assert.Null(dataModel["NullableIntDataProp"]); + Assert.Null(dataModel["NullableFloatVector"]); } [Fact] public void MapFromStorageToDataModelThrowsForMissingKey() { // Arrange - var sut = new AzureCosmosDBNoSQLGenericDataModelMapper( - s_vectorStoreRecordDefinition.Properties, - s_storagePropertyNames, - s_jsonSerializerOptions); + var sut = new AzureCosmosDBNoSQLDynamicDataModelMapper(s_model, s_jsonSerializerOptions); var storageModel = new JsonObject(); @@ -312,18 +266,15 @@ public void MapFromDataToStorageModelSkipsMissingProperties() { new VectorStoreRecordKeyProperty("Key", typeof(string)), new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), }, }; - var dataModel = new VectorStoreGenericDataModel("key"); - var sut = new AzureCosmosDBNoSQLGenericDataModelMapper( - s_vectorStoreRecordDefinition.Properties, - s_storagePropertyNames, - s_jsonSerializerOptions); + var dataModel = new Dictionary { ["Key"] = "key" }; + var sut = new AzureCosmosDBNoSQLDynamicDataModelMapper(s_model, s_jsonSerializerOptions); // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); + var storageModel = sut.MapFromDataToStorageModel(dataModel, generatedEmbeddings: null); // Assert Assert.Equal("key", (string?)storageModel["id"]); @@ -341,7 +292,7 @@ public void MapFromStorageToDataModelSkipsMissingProperties() { new VectorStoreRecordKeyProperty("Key", typeof(string)), new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), }, }; @@ -350,17 +301,14 @@ public void MapFromStorageToDataModelSkipsMissingProperties() ["id"] = "key" }; - var sut = new AzureCosmosDBNoSQLGenericDataModelMapper( - s_vectorStoreRecordDefinition.Properties, - s_storagePropertyNames, - s_jsonSerializerOptions); + var sut = new AzureCosmosDBNoSQLDynamicDataModelMapper(s_model, s_jsonSerializerOptions); // Act var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); // Assert - Assert.Equal("key", dataModel.Key); - Assert.False(dataModel.Data.ContainsKey("StringDataProp")); - Assert.False(dataModel.Vectors.ContainsKey("FloatVector")); + Assert.Equal("key", dataModel["Key"]); + Assert.False(dataModel.ContainsKey("StringDataProp")); + Assert.False(dataModel.ContainsKey("FloatVector")); } } diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLHotel.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLHotel.cs index df06e97d3846..331758b02202 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLHotel.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLHotel.cs @@ -39,6 +39,6 @@ public class AzureCosmosDBNoSQLHotel(string hotelId) /// A vector field. [JsonPropertyName("description_embedding")] - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineSimilarity, IndexKind: IndexKind.Flat)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineSimilarity, IndexKind = IndexKind.Flat)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLKernelBuilderExtensionsTests.cs index 1ad9fc3ea68a..59a697c2e869 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLKernelBuilderExtensionsTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Reflection; +using System.Text.Json; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; @@ -18,12 +19,24 @@ namespace SemanticKernel.Connectors.AzureCosmosDBNoSQL.UnitTests; public sealed class AzureCosmosDBNoSQLKernelBuilderExtensionsTests { private readonly IKernelBuilder _kernelBuilder = Kernel.CreateBuilder(); + private readonly Mock _mockDatabase = new(); + + public AzureCosmosDBNoSQLKernelBuilderExtensionsTests() + { + var mockClient = new Mock(); + + mockClient.Setup(l => l.ClientOptions).Returns(new CosmosClientOptions() { UseSystemTextJsonSerializerWithOptions = JsonSerializerOptions.Default }); + + this._mockDatabase + .Setup(l => l.Client) + .Returns(mockClient.Object); + } [Fact] public void AddVectorStoreRegistersClass() { // Arrange - this._kernelBuilder.Services.AddSingleton(Mock.Of()); + this._kernelBuilder.Services.AddSingleton(this._mockDatabase.Object); // Act this._kernelBuilder.AddAzureCosmosDBNoSQLVectorStore(); @@ -55,7 +68,7 @@ public void AddVectorStoreWithConnectionStringRegistersClass() public void AddVectorStoreRecordCollectionRegistersClass() { // Arrange - this._kernelBuilder.Services.AddSingleton(Mock.Of()); + this._kernelBuilder.Services.AddSingleton(this._mockDatabase.Object); // Act this._kernelBuilder.AddAzureCosmosDBNoSQLVectorStoreRecordCollection("testcollection"); @@ -80,11 +93,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = kernel.Services.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = kernel.Services.GetRequiredService>(); + var vectorizedSearch = kernel.Services.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLServiceCollectionExtensionsTests.cs index 4574415ecb2e..07900cb0dfc0 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLServiceCollectionExtensionsTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Reflection; +using System.Text.Json; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; @@ -18,12 +19,24 @@ namespace SemanticKernel.Connectors.AzureCosmosDBNoSQL.UnitTests; public sealed class AzureCosmosDBNoSQLServiceCollectionExtensionsTests { private readonly IServiceCollection _serviceCollection = new ServiceCollection(); + private readonly Mock _mockDatabase = new(); + + public AzureCosmosDBNoSQLServiceCollectionExtensionsTests() + { + var mockClient = new Mock(); + + mockClient.Setup(l => l.ClientOptions).Returns(new CosmosClientOptions() { UseSystemTextJsonSerializerWithOptions = JsonSerializerOptions.Default }); + + this._mockDatabase + .Setup(l => l.Client) + .Returns(mockClient.Object); + } [Fact] public void AddVectorStoreRegistersClass() { // Arrange - this._serviceCollection.AddSingleton(Mock.Of()); + this._serviceCollection.AddSingleton(this._mockDatabase.Object); // Act this._serviceCollection.AddAzureCosmosDBNoSQLVectorStore(); @@ -56,7 +69,7 @@ public void AddVectorStoreWithConnectionStringRegistersClass() public void AddVectorStoreRecordCollectionRegistersClass() { // Arrange - this._serviceCollection.AddSingleton(Mock.Of()); + this._serviceCollection.AddSingleton(this._mockDatabase.Object); // Act this._serviceCollection.AddAzureCosmosDBNoSQLVectorStoreRecordCollection("testcollection"); @@ -81,11 +94,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs index 55d062441674..c3a9bcb78c3a 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs @@ -2,8 +2,8 @@ using System; using System.Collections.Generic; -using System.Linq; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; using Xunit; @@ -18,12 +18,19 @@ public sealed class AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests { private const string ScorePropertyName = "TestScore"; - private readonly Dictionary _storagePropertyNames = new() - { - ["TestProperty1"] = "test_property_1", - ["TestProperty2"] = "test_property_2", - ["TestProperty3"] = "test_property_3", - }; + private readonly VectorStoreRecordModel _model = new AzureCosmosDBNoSQLVectorStoreModelBuilder().Build( + typeof(Dictionary), + new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordVectorProperty("TestProperty1", typeof(ReadOnlyMemory), 10) { StoragePropertyName = "test_property_1" }, + new VectorStoreRecordDataProperty("TestProperty2", typeof(string)) { StoragePropertyName = "test_property_2" }, + new VectorStoreRecordDataProperty("TestProperty3", typeof(string)) { StoragePropertyName = "test_property_3" } + ] + }, + defaultEmbeddingGenerator: null); [Fact] public void BuildSearchQueryByDefaultReturnsValidQueryDefinition() @@ -31,7 +38,6 @@ public void BuildSearchQueryByDefaultReturnsValidQueryDefinition() // Arrange var vector = new ReadOnlyMemory([1f, 2f, 3f]); var vectorPropertyName = "test_property_1"; - var fields = this._storagePropertyNames.Values.ToList(); var filter = new VectorSearchFilter() .EqualTo("TestProperty2", "test-value-2") @@ -41,23 +47,23 @@ public void BuildSearchQueryByDefaultReturnsValidQueryDefinition() var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery, DummyType>( vector, keywords: null, - fields, - this._storagePropertyNames, + this._model, vectorPropertyName, textPropertyName: null, ScorePropertyName, oldFilter: filter, filter: null, 10, - 5); + 5, + includeVectors: true); var queryText = queryDefinition.QueryText; var queryParameters = queryDefinition.GetQueryParameters(); // Assert - Assert.Contains("SELECT x.test_property_1,x.test_property_2,x.test_property_3,VectorDistance(x.test_property_1, @vector) AS TestScore", queryText); + Assert.Contains("SELECT x.id,x.TestProperty1,x.TestProperty2,x.TestProperty3,VectorDistance(x.test_property_1, @vector) AS TestScore", queryText); Assert.Contains("FROM x", queryText); - Assert.Contains("WHERE x.test_property_2 = @cv0 AND ARRAY_CONTAINS(x.test_property_3, @cv1)", queryText); + Assert.Contains("WHERE x.TestProperty2 = @cv0 AND ARRAY_CONTAINS(x.TestProperty3, @cv1)", queryText); Assert.Contains("ORDER BY VectorDistance(x.test_property_1, @vector)", queryText); Assert.Contains("OFFSET 5 LIMIT 10", queryText); @@ -77,7 +83,6 @@ public void BuildSearchQueryWithoutOffsetReturnsQueryDefinitionWithTopParameter( // Arrange var vector = new ReadOnlyMemory([1f, 2f, 3f]); var vectorPropertyName = "test_property_1"; - var fields = this._storagePropertyNames.Values.ToList(); var filter = new VectorSearchFilter() .EqualTo("TestProperty2", "test-value-2") @@ -87,23 +92,23 @@ public void BuildSearchQueryWithoutOffsetReturnsQueryDefinitionWithTopParameter( var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery, DummyType>( vector, keywords: null, - fields, - this._storagePropertyNames, + this._model, vectorPropertyName, textPropertyName: null, ScorePropertyName, oldFilter: filter, filter: null, 10, - 0); + 0, + includeVectors: true); var queryText = queryDefinition.QueryText; var queryParameters = queryDefinition.GetQueryParameters(); // Assert - Assert.Contains("SELECT TOP 10 x.test_property_1,x.test_property_2,x.test_property_3,VectorDistance(x.test_property_1, @vector) AS TestScore", queryText); + Assert.Contains("SELECT TOP 10 x.id,x.TestProperty1,x.TestProperty2,x.TestProperty3,VectorDistance(x.test_property_1, @vector) AS TestScore", queryText); Assert.Contains("FROM x", queryText); - Assert.Contains("WHERE x.test_property_2 = @cv0 AND ARRAY_CONTAINS(x.test_property_3, @cv1)", queryText); + Assert.Contains("WHERE x.TestProperty2 = @cv0 AND ARRAY_CONTAINS(x.TestProperty3, @cv1)", queryText); Assert.Contains("ORDER BY VectorDistance(x.test_property_1, @vector)", queryText); Assert.DoesNotContain("OFFSET 0 LIMIT 10", queryText); @@ -124,7 +129,6 @@ public void BuildSearchQueryWithInvalidFilterThrowsException() // Arrange var vector = new ReadOnlyMemory([1f, 2f, 3f]); var vectorPropertyName = "test_property_1"; - var fields = this._storagePropertyNames.Values.ToList(); var filter = new VectorSearchFilter().EqualTo("non-existent-property", "test-value-2"); @@ -133,15 +137,15 @@ public void BuildSearchQueryWithInvalidFilterThrowsException() AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery, DummyType>( vector, keywords: null, - fields, - this._storagePropertyNames, + this._model, vectorPropertyName, textPropertyName: null, ScorePropertyName, oldFilter: filter, filter: null, 10, - 5)); + 5, + includeVectors: true)); } [Fact] @@ -150,21 +154,20 @@ public void BuildSearchQueryWithoutFilterDoesNotContainWhereClause() // Arrange var vector = new ReadOnlyMemory([1f, 2f, 3f]); var vectorPropertyName = "test_property_1"; - var fields = this._storagePropertyNames.Values.ToList(); // Act var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery, DummyType>( vector, keywords: null, - fields, - this._storagePropertyNames, + this._model, vectorPropertyName, textPropertyName: null, ScorePropertyName, oldFilter: null, filter: null, 10, - 5); + 5, + includeVectors: true); var queryText = queryDefinition.QueryText; var queryParameters = queryDefinition.GetQueryParameters(); @@ -182,23 +185,35 @@ public void BuildSelectQueryByDefaultReturnsValidQueryDefinition() { // Arrange const string ExpectedQueryText = """ - SELECT x.key,x.property_1,x.property_2 + SELECT x.id,x.TestProperty1,x.TestProperty2 FROM x - WHERE (x.key_property = @rk0 AND x.partition_key_property = @pk0) + WHERE (x.id = @rk0 AND x.TestProperty1 = @pk0) """; - const string KeyStoragePropertyName = "key_property"; - const string PartitionKeyPropertyName = "partition_key_property"; - - var keys = new List { new("key", "partition_key") }; - var fields = new List { "key", "property_1", "property_2" }; + const string KeyStoragePropertyName = "id"; + const string PartitionKeyPropertyName = "TestProperty1"; + + var model = new AzureCosmosDBNoSQLVectorStoreModelBuilder().Build( + typeof(Dictionary), + new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("TestProperty1", typeof(string)), + new VectorStoreRecordDataProperty("TestProperty2", typeof(string)) + ] + }, + defaultEmbeddingGenerator: null); + var keys = new List { new("id", "TestProperty1") }; // Act var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSelectQuery( + model, KeyStoragePropertyName, PartitionKeyPropertyName, keys, - fields); + includeVectors: true); var queryText = queryDefinition.QueryText; var queryParameters = queryDefinition.GetQueryParameters(); @@ -207,10 +222,10 @@ FROM x Assert.Equal(ExpectedQueryText, queryText); Assert.Equal("@rk0", queryParameters[0].Name); - Assert.Equal("key", queryParameters[0].Value); + Assert.Equal("id", queryParameters[0].Value); Assert.Equal("@pk0", queryParameters[1].Name); - Assert.Equal("partition_key", queryParameters[1].Value); + Assert.Equal("TestProperty1", queryParameters[1].Value); } [Fact] @@ -219,9 +234,8 @@ public void BuildSearchQueryWithHybridFieldsReturnsValidHybridQueryDefinition() // Arrange var vector = new ReadOnlyMemory([1f, 2f, 3f]); var keywordText = "hybrid"; - var vectorPropertyName = "test_property_1"; - var textPropertyName = "test_property_2"; - var fields = this._storagePropertyNames.Values.ToList(); + var vectorPropertyName = "TestProperty1"; + var textPropertyName = "TestProperty2"; var filter = new VectorSearchFilter() .EqualTo("TestProperty2", "test-value-2") @@ -231,24 +245,24 @@ public void BuildSearchQueryWithHybridFieldsReturnsValidHybridQueryDefinition() var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery, DummyType>( vector, [keywordText], - fields, - this._storagePropertyNames, + this._model, vectorPropertyName, textPropertyName, ScorePropertyName, oldFilter: filter, filter: null, 10, - 5); + 5, + includeVectors: true); var queryText = queryDefinition.QueryText; var queryParameters = queryDefinition.GetQueryParameters(); // Assert - Assert.Contains("SELECT x.test_property_1,x.test_property_2,x.test_property_3,VectorDistance(x.test_property_1, @vector) AS TestScore", queryText); + Assert.Contains("SELECT x.id,x.TestProperty1,x.TestProperty2,x.TestProperty3,VectorDistance(x.TestProperty1, @vector) AS TestScore", queryText); Assert.Contains("FROM x", queryText); - Assert.Contains("WHERE x.test_property_2 = @cv0 AND ARRAY_CONTAINS(x.test_property_3, @cv1)", queryText); - Assert.Contains("ORDER BY RANK RRF(VectorDistance(x.test_property_1, @vector), FullTextScore(x.test_property_2, [\"hybrid\"]))", queryText); + Assert.Contains("WHERE x.TestProperty2 = @cv0 AND ARRAY_CONTAINS(x.TestProperty3, @cv1)", queryText); + Assert.Contains("ORDER BY RANK RRF(VectorDistance(x.TestProperty1, @vector), FullTextScore(x.TestProperty2, [\"hybrid\"]))", queryText); Assert.Contains("OFFSET 5 LIMIT 10", queryText); Assert.Equal("@vector", queryParameters[0].Name); diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs index 6f33a19e0b28..073120889d47 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.Json; using System.Text.Json.Nodes; using System.Threading; using System.Threading.Tasks; @@ -17,7 +18,7 @@ namespace SemanticKernel.Connectors.AzureCosmosDBNoSQL.UnitTests; /// -/// Unit tests for class. +/// Unit tests for class. /// public sealed class AzureCosmosDBNoSQLVectorStoreRecordCollectionTests { @@ -26,24 +27,44 @@ public sealed class AzureCosmosDBNoSQLVectorStoreRecordCollectionTests public AzureCosmosDBNoSQLVectorStoreRecordCollectionTests() { + this._mockDatabase.Setup(l => l.GetContainer(It.IsAny())).Returns(this._mockContainer.Object); + + var mockClient = new Mock(); + + mockClient.Setup(l => l.ClientOptions).Returns(new CosmosClientOptions() { UseSystemTextJsonSerializerWithOptions = JsonSerializerOptions.Default }); + this._mockDatabase - .Setup(l => l.GetContainer(It.IsAny())) - .Returns(this._mockContainer.Object); + .Setup(l => l.Client) + .Returns(mockClient.Object); } [Fact] public void ConstructorForModelWithoutKeyThrowsException() { // Act & Assert - var exception = Assert.Throws(() => new AzureCosmosDBNoSQLVectorStoreRecordCollection(this._mockDatabase.Object, "collection")); + var exception = Assert.Throws(() => new AzureCosmosDBNoSQLVectorStoreRecordCollection(this._mockDatabase.Object, "collection")); Assert.Contains("No key property found", exception.Message); } + [Fact] + public void ConstructorWithoutSystemTextJsonSerializerOptionsThrowsArgumentException() + { + // Arrange + var mockDatabase = new Mock(); + var mockClient = new Mock(); + + mockDatabase.Setup(l => l.Client).Returns(mockClient.Object); + + // Act & Assert + var exception = Assert.Throws(() => new AzureCosmosDBNoSQLVectorStoreRecordCollection(mockDatabase.Object, "collection")); + Assert.Contains(nameof(CosmosClientOptions.UseSystemTextJsonSerializerWithOptions), exception.Message); + } + [Fact] public void ConstructorWithDeclarativeModelInitializesCollection() { // Act & Assert - var collection = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var collection = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, "collection"); @@ -60,7 +81,7 @@ public void ConstructorWithImperativeModelInitializesCollection() }; // Act - var collection = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var collection = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, "collection", new() { VectorStoreRecordDefinition = definition }); @@ -96,7 +117,7 @@ public async Task CollectionExistsReturnsValidResultAsync(List collectio It.IsAny())) .Returns(mockFeedIterator.Object); - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, collectionName); @@ -116,7 +137,7 @@ public async Task CreateCollectionUsesValidContainerPropertiesAsync(IndexingMode // Arrange const string CollectionName = "collection"; - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, CollectionName, new() { IndexingMode = indexingMode, Automatic = indexingMode != IndexingMode.None }); @@ -216,7 +237,7 @@ public async Task CreateCollectionIfNotExistsInvokesValidMethodsAsync(List())) .Returns(mockFeedIterator.Object); - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, CollectionName); @@ -243,18 +264,22 @@ public async Task DeleteInvokesValidMethodsAsync( const string RecordKey = "recordKey"; const string PartitionKey = "partitionKey"; - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( - this._mockDatabase.Object, - "collection"); - // Act if (useCompositeKeyCollection) { + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + this._mockDatabase.Object, + "collection"); + await ((IVectorStoreRecordCollection)sut).DeleteAsync( new AzureCosmosDBNoSQLCompositeKey(RecordKey, PartitionKey)); } else { + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + this._mockDatabase.Object, + "collection"); + await ((IVectorStoreRecordCollection)sut).DeleteAsync( RecordKey); } @@ -274,12 +299,12 @@ public async Task DeleteBatchInvokesValidMethodsAsync() // Arrange List recordKeys = ["key1", "key2"]; - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, "collection"); // Act - await sut.DeleteBatchAsync(recordKeys); + await sut.DeleteAsync(recordKeys); // Assert foreach (var key in recordKeys) @@ -297,7 +322,7 @@ public async Task DeleteBatchInvokesValidMethodsAsync() public async Task DeleteCollectionInvokesValidMethodsAsync() { // Arrange - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, "collection"); @@ -341,7 +366,7 @@ public async Task GetReturnsValidRecordAsync() It.IsAny())) .Returns(mockFeedIterator.Object); - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, "collection"); @@ -384,12 +409,12 @@ public async Task GetBatchReturnsValidRecordAsync() It.IsAny())) .Returns(mockFeedIterator.Object); - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, "collection"); // Act - var results = await sut.GetBatchAsync(["key1", "key2", "key3"]).ToListAsync(); + var results = await sut.GetAsync(["key1", "key2", "key3"]).ToListAsync(); // Assert Assert.NotNull(results[0]); @@ -411,7 +436,7 @@ public async Task UpsertReturnsRecordKeyAsync() // Arrange var hotel = new AzureCosmosDBNoSQLHotel("key") { HotelName = "Test Name" }; - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, "collection"); @@ -439,12 +464,12 @@ public async Task UpsertBatchReturnsRecordKeysAsync() var hotel2 = new AzureCosmosDBNoSQLHotel("key2") { HotelName = "Test Name 2" }; var hotel3 = new AzureCosmosDBNoSQLHotel("key3") { HotelName = "Test Name 3" }; - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, "collection"); // Act - var results = await sut.UpsertBatchAsync([hotel1, hotel2, hotel3]).ToListAsync(); + var results = await sut.UpsertAsync([hotel1, hotel2, hotel3]); // Assert Assert.NotNull(results); @@ -455,89 +480,6 @@ public async Task UpsertBatchReturnsRecordKeysAsync() Assert.Equal("key3", results[2]); } - [Fact] - public async Task UpsertWithCustomMapperWorksCorrectlyAsync() - { - // Arrange - var hotel = new AzureCosmosDBNoSQLHotel("key") { HotelName = "Test Name" }; - - var mockMapper = new Mock>(); - - mockMapper - .Setup(l => l.MapFromDataToStorageModel(It.IsAny())) - .Returns(new JsonObject { ["id"] = "key", ["my_name"] = "Test Name" }); - - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( - this._mockDatabase.Object, - "collection", - new() { JsonObjectCustomMapper = mockMapper.Object }); - - // Act - var result = await sut.UpsertAsync(hotel); - - // Assert - Assert.Equal("key", result); - - this._mockContainer.Verify(l => l.UpsertItemAsync( - It.Is(node => - node["id"]!.ToString() == "key" && - node["my_name"]!.ToString() == "Test Name"), - new PartitionKey("key"), - It.IsAny(), - It.IsAny()), - Times.Once()); - } - - [Fact] - public async Task GetWithCustomMapperWorksCorrectlyAsync() - { - // Arrange - const string RecordKey = "key"; - - var jsonObject = new JsonObject { ["id"] = RecordKey, ["HotelName"] = "Test Name" }; - - var mockFeedResponse = new Mock>(); - mockFeedResponse - .Setup(l => l.Resource) - .Returns([jsonObject]); - - var mockFeedIterator = new Mock>(); - mockFeedIterator - .SetupSequence(l => l.HasMoreResults) - .Returns(true) - .Returns(false); - - mockFeedIterator - .Setup(l => l.ReadNextAsync(It.IsAny())) - .ReturnsAsync(mockFeedResponse.Object); - - this._mockContainer - .Setup(l => l.GetItemQueryIterator( - It.IsAny(), - It.IsAny(), - It.IsAny())) - .Returns(mockFeedIterator.Object); - - var mockMapper = new Mock>(); - - mockMapper - .Setup(l => l.MapFromStorageToDataModel(It.IsAny(), It.IsAny())) - .Returns(new AzureCosmosDBNoSQLHotel(RecordKey) { HotelName = "Name from mapper" }); - - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( - this._mockDatabase.Object, - "collection", - new() { JsonObjectCustomMapper = mockMapper.Object }); - - // Act - var result = await sut.GetAsync(RecordKey); - - // Assert - Assert.NotNull(result); - Assert.Equal(RecordKey, result.HotelId); - Assert.Equal("Name from mapper", result.HotelName); - } - [Fact] public async Task VectorizedSearchReturnsValidRecordAsync() { @@ -574,14 +516,12 @@ public async Task VectorizedSearchReturnsValidRecordAsync() It.IsAny())) .Returns(mockFeedIterator.Object); - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, "collection"); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f])); - - var results = await actual.Results.ToListAsync(); + var results = await sut.SearchEmbeddingAsync(new ReadOnlyMemory([1f, 2f, 3f]), top: 3).ToListAsync(); var result = results[0]; // Assert @@ -595,20 +535,20 @@ public async Task VectorizedSearchReturnsValidRecordAsync() public async Task VectorizedSearchWithUnsupportedVectorTypeThrowsExceptionAsync() { // Arrange - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, "collection"); // Act & Assert await Assert.ThrowsAsync(async () => - await (await sut.VectorizedSearchAsync(new List([1, 2, 3]))).Results.ToListAsync()); + await sut.SearchEmbeddingAsync(new List([1, 2, 3]), top: 3).ToListAsync()); } [Fact] public async Task VectorizedSearchWithNonExistentVectorPropertyNameThrowsExceptionAsync() { // Arrange - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._mockDatabase.Object, "collection"); @@ -616,7 +556,7 @@ public async Task VectorizedSearchWithNonExistentVectorPropertyNameThrowsExcepti // Act & Assert await Assert.ThrowsAsync(async () => - await (await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f]), searchOptions)).Results.ToListAsync()); + await sut.SearchEmbeddingAsync(new ReadOnlyMemory([1f, 2f, 3f]), top: 3, searchOptions).ToListAsync()); } public static TheoryData, string, bool> CollectionExistsData => new() @@ -695,19 +635,19 @@ private sealed class TestIndexingModel [VectorStoreRecordKey] public string? Id { get; set; } - [VectorStoreRecordVector(Dimensions: 2, DistanceFunction: DistanceFunction.CosineSimilarity, IndexKind: IndexKind.Flat)] + [VectorStoreRecordVector(Dimensions: 2, DistanceFunction = DistanceFunction.CosineSimilarity, IndexKind = IndexKind.Flat)] public ReadOnlyMemory? DescriptionEmbedding2 { get; set; } - [VectorStoreRecordVector(Dimensions: 3, DistanceFunction: DistanceFunction.DotProductSimilarity, IndexKind: IndexKind.QuantizedFlat)] + [VectorStoreRecordVector(Dimensions: 3, DistanceFunction = DistanceFunction.DotProductSimilarity, IndexKind = IndexKind.QuantizedFlat)] public ReadOnlyMemory? DescriptionEmbedding3 { get; set; } - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.EuclideanDistance, IndexKind: IndexKind.DiskAnn)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.EuclideanDistance, IndexKind = IndexKind.DiskAnn)] public ReadOnlyMemory? DescriptionEmbedding4 { get; set; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string? IndexableData1 { get; set; } - [VectorStoreRecordData(IsFullTextSearchable = true)] + [VectorStoreRecordData(IsFullTextIndexed = true)] public string? IndexableData2 { get; set; } [VectorStoreRecordData] diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordMapperTests.cs index 9c2b7de29b41..9446a05a5045 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordMapperTests.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Text.Json; using System.Text.Json.Nodes; +using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; using Xunit; @@ -15,20 +16,23 @@ namespace SemanticKernel.Connectors.AzureCosmosDBNoSQL.UnitTests; /// public sealed class AzureCosmosDBNoSQLVectorStoreRecordMapperTests { - private readonly AzureCosmosDBNoSQLVectorStoreRecordMapper _sut; - - public AzureCosmosDBNoSQLVectorStoreRecordMapperTests() - { - var storagePropertyNames = new Dictionary - { - ["HotelId"] = "HotelId", - ["HotelName"] = "HotelName", - ["Tags"] = "Tags", - ["DescriptionEmbedding"] = "description_embedding", - }; - - this._sut = new("HotelId", storagePropertyNames, JsonSerializerOptions.Default); - } + private readonly AzureCosmosDBNoSQLVectorStoreRecordMapper _sut + = new( + new AzureCosmosDBNoSQLVectorStoreModelBuilder().Build( + typeof(Dictionary), + new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("HotelId", typeof(string)), + new VectorStoreRecordVectorProperty("TestProperty1", typeof(ReadOnlyMemory), 10) { StoragePropertyName = "test_property_1" }, + new VectorStoreRecordDataProperty("TestProperty2", typeof(string)) { StoragePropertyName = "test_property_2" }, + new VectorStoreRecordDataProperty("TestProperty3", typeof(string)) { StoragePropertyName = "test_property_3" } + ] + }, + defaultEmbeddingGenerator: null, + JsonSerializerOptions.Default), + JsonSerializerOptions.Default); [Fact] public void MapFromDataToStorageModelReturnsValidObject() @@ -42,7 +46,7 @@ public void MapFromDataToStorageModelReturnsValidObject() }; // Act - var document = this._sut.MapFromDataToStorageModel(hotel); + var document = this._sut.MapFromDataToStorageModel(hotel, generatedEmbeddings: null); // Assert Assert.NotNull(document); diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreTests.cs index 84ad3b36f4a6..aa1e14a4a771 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Cosmos; @@ -20,6 +21,17 @@ public sealed class AzureCosmosDBNoSQLVectorStoreTests { private readonly Mock _mockDatabase = new(); + public AzureCosmosDBNoSQLVectorStoreTests() + { + var mockClient = new Mock(); + + mockClient.Setup(l => l.ClientOptions).Returns(new CosmosClientOptions() { UseSystemTextJsonSerializerWithOptions = JsonSerializerOptions.Default }); + + this._mockDatabase + .Setup(l => l.Client) + .Returns(mockClient.Object); + } + [Fact] public void GetCollectionWithNotSupportedKeyThrowsException() { diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/Connectors.AzureCosmosDBNoSQL.UnitTests.csproj b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/Connectors.AzureCosmosDBNoSQL.UnitTests.csproj index ff8643740f11..032bd3bd9eed 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/Connectors.AzureCosmosDBNoSQL.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/Connectors.AzureCosmosDBNoSQL.UnitTests.csproj @@ -8,7 +8,8 @@ enable disable false - $(NoWarn);SKEXP0001,SKEXP0020 + $(NoWarn);SKEXP0001 + $(NoWarn);MEVD9001 diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj index adff4d81e1b0..4468b0001333 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj @@ -8,7 +8,7 @@ enable disable false - $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0050,SKEXP0070 + $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0050,SKEXP0070 diff --git a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/Connectors.InMemory.UnitTests.csproj b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/Connectors.InMemory.UnitTests.csproj index a125a758c729..1b00bcec55de 100644 --- a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/Connectors.InMemory.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/Connectors.InMemory.UnitTests.csproj @@ -8,7 +8,7 @@ enable disable false - $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0020 + $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001 diff --git a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryKernelBuilderExtensionsTests.cs index b3ce5286c9d6..aa6dceb543dd 100644 --- a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryKernelBuilderExtensionsTests.cs @@ -52,7 +52,7 @@ private void AssertVectorStoreRecordCollectionCreated() Assert.NotNull(collection); Assert.IsType>(collection); - var vectorizedSearch = kernel.Services.GetRequiredService>(); + var vectorizedSearch = kernel.Services.GetRequiredService>(); Assert.NotNull(vectorizedSearch); Assert.IsType>(vectorizedSearch); } diff --git a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryServiceCollectionExtensionsTests.cs index f195f9267711..99d64b820d2c 100644 --- a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryServiceCollectionExtensionsTests.cs @@ -52,7 +52,7 @@ private void AssertVectorStoreRecordCollectionCreated() Assert.NotNull(collection); Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); Assert.IsType>(vectorizedSearch); } diff --git a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreRecordCollectionTests.cs deleted file mode 100644 index b6ac78086915..000000000000 --- a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreRecordCollectionTests.cs +++ /dev/null @@ -1,578 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.InMemory; -using Xunit; - -namespace SemanticKernel.Connectors.InMemory.UnitTests; - -/// -/// Contains tests for the class. -/// -public class InMemoryVectorStoreRecordCollectionTests -{ - private const string TestCollectionName = "testcollection"; - private const string TestRecordKey1 = "testid1"; - private const string TestRecordKey2 = "testid2"; - private const int TestRecordIntKey1 = 1; - private const int TestRecordIntKey2 = 2; - - private readonly CancellationToken _testCancellationToken = new(false); - - private readonly ConcurrentDictionary> _collectionStore; - private readonly ConcurrentDictionary _collectionStoreTypes; - - public InMemoryVectorStoreRecordCollectionTests() - { - this._collectionStore = new(); - this._collectionStoreTypes = new(); - } - - [Theory] - [InlineData(TestCollectionName, true)] - [InlineData("nonexistentcollection", false)] - public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) - { - // Arrange - var collection = new ConcurrentDictionary(); - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = new InMemoryVectorStoreRecordCollection>( - this._collectionStore, - this._collectionStoreTypes, - collectionName); - - // Act - var actual = await sut.CollectionExistsAsync(this._testCancellationToken); - - // Assert - Assert.Equal(expectedExists, actual); - } - - [Fact] - public async Task CanCreateCollectionAsync() - { - // Arrange - var sut = this.CreateRecordCollection(false); - - // Act - await sut.CreateCollectionAsync(this._testCancellationToken); - - // Assert - Assert.True(this._collectionStore.ContainsKey(TestCollectionName)); - } - - [Fact] - public async Task DeleteCollectionRemovesCollectionFromDictionaryAsync() - { - // Arrange - var collection = new ConcurrentDictionary(); - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = this.CreateRecordCollection(false); - - // Act - await sut.DeleteCollectionAsync(this._testCancellationToken); - - // Assert - Assert.Empty(this._collectionStore); - } - - [Theory] - [InlineData(true, TestRecordKey1)] - [InlineData(true, TestRecordIntKey1)] - [InlineData(false, TestRecordKey1)] - [InlineData(false, TestRecordIntKey1)] - public async Task CanGetRecordWithVectorsAsync(bool useDefinition, TKey testKey) - where TKey : notnull - { - // Arrange - var record = CreateModel(testKey, withVectors: true); - var collection = new ConcurrentDictionary(); - collection.TryAdd(testKey!, record); - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = this.CreateRecordCollection(useDefinition); - - // Act - var actual = await sut.GetAsync( - testKey, - new() - { - IncludeVectors = true - }, - this._testCancellationToken); - - // Assert - var expectedArgs = new object[] { TestRecordKey1 }; - - Assert.NotNull(actual); - Assert.Equal(testKey, actual.Key); - Assert.Equal($"data {testKey}", actual.Data); - Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector!.Value.ToArray()); - } - - [Theory] - [InlineData(true, TestRecordKey1, TestRecordKey2)] - [InlineData(true, TestRecordIntKey1, TestRecordIntKey2)] - [InlineData(false, TestRecordKey1, TestRecordKey2)] - [InlineData(false, TestRecordIntKey1, TestRecordIntKey2)] - public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition, TKey testKey1, TKey testKey2) - where TKey : notnull - { - // Arrange - var record1 = CreateModel(testKey1, withVectors: true); - var record2 = CreateModel(testKey2, withVectors: true); - var collection = new ConcurrentDictionary(); - collection.TryAdd(testKey1!, record1); - collection.TryAdd(testKey2!, record2); - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = this.CreateRecordCollection(useDefinition); - - // Act - var actual = await sut.GetBatchAsync( - [testKey1, testKey2], - new() - { - IncludeVectors = true - }, - this._testCancellationToken).ToListAsync(); - - // Assert - Assert.NotNull(actual); - Assert.Equal(2, actual.Count); - Assert.Equal(testKey1, actual[0].Key); - Assert.Equal($"data {testKey1}", actual[0].Data); - Assert.Equal(testKey2, actual[1].Key); - Assert.Equal($"data {testKey2}", actual[1].Data); - } - - [Theory] - [InlineData(true, TestRecordKey1, TestRecordKey2)] - [InlineData(true, TestRecordIntKey1, TestRecordIntKey2)] - [InlineData(false, TestRecordKey1, TestRecordKey2)] - [InlineData(false, TestRecordIntKey1, TestRecordIntKey2)] - public async Task CanDeleteRecordAsync(bool useDefinition, TKey testKey1, TKey testKey2) - where TKey : notnull - { - // Arrange - var record1 = CreateModel(testKey1, withVectors: true); - var record2 = CreateModel(testKey2, withVectors: true); - var collection = new ConcurrentDictionary(); - collection.TryAdd(testKey1, record1); - collection.TryAdd(testKey2, record2); - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = this.CreateRecordCollection(useDefinition); - - // Act - await sut.DeleteAsync( - testKey1, - cancellationToken: this._testCancellationToken); - - // Assert - Assert.False(collection.ContainsKey(testKey1)); - Assert.True(collection.ContainsKey(testKey2)); - } - - [Theory] - [InlineData(true, TestRecordKey1, TestRecordKey2)] - [InlineData(true, TestRecordIntKey1, TestRecordIntKey2)] - [InlineData(false, TestRecordKey1, TestRecordKey2)] - [InlineData(false, TestRecordIntKey1, TestRecordIntKey2)] - public async Task CanDeleteManyRecordsWithVectorsAsync(bool useDefinition, TKey testKey1, TKey testKey2) - where TKey : notnull - { - // Arrange - var record1 = CreateModel(testKey1, withVectors: true); - var record2 = CreateModel(testKey2, withVectors: true); - var collection = new ConcurrentDictionary(); - collection.TryAdd(testKey1, record1); - collection.TryAdd(testKey2, record2); - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = this.CreateRecordCollection(useDefinition); - - // Act - await sut.DeleteBatchAsync( - [testKey1, testKey2], - cancellationToken: this._testCancellationToken); - - // Assert - Assert.False(collection.ContainsKey(testKey1)); - Assert.False(collection.ContainsKey(testKey2)); - } - - [Theory] - [InlineData(true, TestRecordKey1)] - [InlineData(true, TestRecordIntKey1)] - [InlineData(false, TestRecordKey1)] - [InlineData(false, TestRecordIntKey1)] - public async Task CanUpsertRecordAsync(bool useDefinition, TKey testKey1) - where TKey : notnull - { - // Arrange - var record1 = CreateModel(testKey1, withVectors: true); - var collection = new ConcurrentDictionary(); - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = this.CreateRecordCollection(useDefinition); - - // Act - var upsertResult = await sut.UpsertAsync( - record1, - cancellationToken: this._testCancellationToken); - - // Assert - Assert.Equal(testKey1, upsertResult); - Assert.True(collection.ContainsKey(testKey1)); - Assert.IsType>(collection[testKey1]); - Assert.Equal($"data {testKey1}", (collection[testKey1] as SinglePropsModel)!.Data); - } - - [Theory] - [InlineData(true, TestRecordKey1, TestRecordKey2)] - [InlineData(true, TestRecordIntKey1, TestRecordIntKey2)] - [InlineData(false, TestRecordKey1, TestRecordKey2)] - [InlineData(false, TestRecordIntKey1, TestRecordIntKey2)] - public async Task CanUpsertManyRecordsAsync(bool useDefinition, TKey testKey1, TKey testKey2) - where TKey : notnull - { - // Arrange - var record1 = CreateModel(testKey1, withVectors: true); - var record2 = CreateModel(testKey2, withVectors: true); - - var collection = new ConcurrentDictionary(); - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = this.CreateRecordCollection(useDefinition); - - // Act - var actual = await sut.UpsertBatchAsync( - [record1, record2], - cancellationToken: this._testCancellationToken).ToListAsync(); - - // Assert - Assert.NotNull(actual); - Assert.Equal(2, actual.Count); - Assert.Equal(testKey1, actual[0]); - Assert.Equal(testKey2, actual[1]); - - Assert.True(collection.ContainsKey(testKey1)); - Assert.IsType>(collection[testKey1]); - Assert.Equal($"data {testKey1}", (collection[testKey1] as SinglePropsModel)!.Data); - } - - [Theory] - [InlineData(true, TestRecordKey1, TestRecordKey2)] - [InlineData(true, TestRecordIntKey1, TestRecordIntKey2)] - [InlineData(false, TestRecordKey1, TestRecordKey2)] - [InlineData(false, TestRecordIntKey1, TestRecordIntKey2)] - public async Task CanSearchWithVectorAsync(bool useDefinition, TKey testKey1, TKey testKey2) - where TKey : notnull - { - // Arrange - var record1 = CreateModel(testKey1, withVectors: true, new float[] { 1, 1, 1, 1 }); - var record2 = CreateModel(testKey2, withVectors: true, new float[] { -1, -1, -1, -1 }); - - var collection = new ConcurrentDictionary(); - collection.TryAdd(testKey1, record1); - collection.TryAdd(testKey2, record2); - - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = this.CreateRecordCollection(useDefinition); - - // Act - var actual = await sut.VectorizedSearchAsync( - new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new() { IncludeVectors = true }, - this._testCancellationToken); - - // Assert - Assert.NotNull(actual); - Assert.Null(actual.TotalCount); - var actualResults = await actual.Results.ToListAsync(); - Assert.Equal(2, actualResults.Count); - Assert.Equal(testKey1, actualResults[0].Record.Key); - Assert.Equal($"data {testKey1}", actualResults[0].Record.Data); - Assert.Equal(1, actualResults[0].Score); - Assert.Equal(testKey2, actualResults[1].Record.Key); - Assert.Equal($"data {testKey2}", actualResults[1].Record.Data); - Assert.Equal(-1, actualResults[1].Score); - } - -#pragma warning disable CS0618 // VectorSearchFilter is obsolete - [Theory] - [InlineData(true, TestRecordKey1, TestRecordKey2, "Equality")] - [InlineData(true, TestRecordIntKey1, TestRecordIntKey2, "Equality")] - [InlineData(false, TestRecordKey1, TestRecordKey2, "Equality")] - [InlineData(false, TestRecordIntKey1, TestRecordIntKey2, "Equality")] - [InlineData(true, TestRecordKey1, TestRecordKey2, "TagListContains")] - [InlineData(true, TestRecordIntKey1, TestRecordIntKey2, "TagListContains")] - [InlineData(false, TestRecordKey1, TestRecordKey2, "TagListContains")] - [InlineData(false, TestRecordIntKey1, TestRecordIntKey2, "TagListContains")] - public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, TKey testKey1, TKey testKey2, string filterType) - where TKey : notnull - { - // Arrange - var record1 = CreateModel(testKey1, withVectors: true, new float[] { 1, 1, 1, 1 }); - var record2 = CreateModel(testKey2, withVectors: true, new float[] { -1, -1, -1, -1 }); - - var collection = new ConcurrentDictionary(); - collection.TryAdd(testKey1, record1); - collection.TryAdd(testKey2, record2); - - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = this.CreateRecordCollection(useDefinition); - - // Act - var filter = filterType == "Equality" ? new VectorSearchFilter().EqualTo("Data", $"data {testKey2}") : new VectorSearchFilter().AnyTagEqualTo("Tags", $"tag {testKey2}"); - var actual = await sut.VectorizedSearchAsync( - new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new() { IncludeVectors = true, OldFilter = filter, IncludeTotalCount = true }, - this._testCancellationToken); - - // Assert - Assert.NotNull(actual); - Assert.Equal(1, actual.TotalCount); - var actualResults = await actual.Results.ToListAsync(); - Assert.Single(actualResults); - Assert.Equal(testKey2, actualResults[0].Record.Key); - Assert.Equal($"data {testKey2}", actualResults[0].Record.Data); - Assert.Equal(-1, actualResults[0].Score); - } -#pragma warning restore CS0618 // Type or member is obsolete - - [Theory] - [InlineData(DistanceFunction.CosineSimilarity, 1, -1)] - [InlineData(DistanceFunction.CosineDistance, 0, 2)] - [InlineData(DistanceFunction.DotProductSimilarity, 4, -4)] - [InlineData(DistanceFunction.EuclideanDistance, 0, 4)] - public async Task CanSearchWithDifferentDistanceFunctionsAsync(string distanceFunction, double expectedScoreResult1, double expectedScoreResult2) - { - // Arrange - var record1 = CreateModel(TestRecordKey1, withVectors: true, new float[] { 1, 1, 1, 1 }); - var record2 = CreateModel(TestRecordKey2, withVectors: true, new float[] { -1, -1, -1, -1 }); - - var collection = new ConcurrentDictionary(); - collection.TryAdd(TestRecordKey1, record1); - collection.TryAdd(TestRecordKey2, record2); - - this._collectionStore.TryAdd(TestCollectionName, collection); - - VectorStoreRecordDefinition singlePropsDefinition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("Data", typeof(string)), - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { DistanceFunction = distanceFunction } - ] - }; - - var sut = new InMemoryVectorStoreRecordCollection>( - this._collectionStore, - this._collectionStoreTypes, - TestCollectionName, - new() - { - VectorStoreRecordDefinition = singlePropsDefinition - }); - - // Act - var actual = await sut.VectorizedSearchAsync( - new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new() { IncludeVectors = true }, - this._testCancellationToken); - - // Assert - Assert.NotNull(actual); - var actualResults = await actual.Results.ToListAsync(); - Assert.Equal(2, actualResults.Count); - Assert.Equal(TestRecordKey1, actualResults[0].Record.Key); - Assert.Equal($"data {TestRecordKey1}", actualResults[0].Record.Data); - Assert.Equal(expectedScoreResult1, actualResults[0].Score); - Assert.Equal(TestRecordKey2, actualResults[1].Record.Key); - Assert.Equal($"data {TestRecordKey2}", actualResults[1].Record.Data); - Assert.Equal(expectedScoreResult2, actualResults[1].Score); - } - - [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task CanSearchManyRecordsAsync(bool useDefinition) - { - // Arrange - var collection = new ConcurrentDictionary(); - for (int i = 0; i < 1000; i++) - { - if (i <= 14) - { - collection.TryAdd(i, CreateModel(i, withVectors: true, new float[] { 1, 1, 1, 1 })); - } - else - { - collection.TryAdd(i, CreateModel(i, withVectors: true, new float[] { -1, -1, -1, -1 })); - } - } - - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = this.CreateRecordCollection(useDefinition); - - // Act - var actual = await sut.VectorizedSearchAsync( - new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new() { IncludeVectors = true, Top = 10, Skip = 10, IncludeTotalCount = true }, - this._testCancellationToken); - - // Assert - Assert.NotNull(actual); - Assert.Equal(1000, actual.TotalCount); - - // Assert that top was respected - var actualResults = await actual.Results.ToListAsync(); - Assert.Equal(10, actualResults.Count); - var actualIds = actualResults.Select(r => r.Record.Key).ToList(); - for (int i = 0; i < 10; i++) - { - // Assert that skip was respected - Assert.Contains(i + 10, actualIds); - if (i <= 4) - { - Assert.Equal(1, actualResults[i].Score); - } - else - { - Assert.Equal(-1, actualResults[i].Score); - } - } - } - - [Theory] - [InlineData(TestRecordKey1, TestRecordKey2)] - [InlineData(TestRecordIntKey1, TestRecordIntKey2)] - public async Task ItCanSearchUsingTheGenericDataModelAsync(TKey testKey1, TKey testKey2) - where TKey : notnull - { - // Arrange - var record1 = new VectorStoreGenericDataModel(testKey1) - { - Data = new Dictionary - { - ["Data"] = $"data {testKey1}", - ["Tags"] = new List { "default tag", "tag " + testKey1 } - }, - Vectors = new Dictionary - { - ["Vector"] = new ReadOnlyMemory([1, 1, 1, 1]) - } - }; - var record2 = new VectorStoreGenericDataModel(testKey2) - { - Data = new Dictionary - { - ["Data"] = $"data {testKey2}", - ["Tags"] = new List { "default tag", "tag " + testKey2 } - }, - Vectors = new Dictionary - { - ["Vector"] = new ReadOnlyMemory([-1, -1, -1, -1]) - } - }; - - var collection = new ConcurrentDictionary(); - collection.TryAdd(testKey1, record1); - collection.TryAdd(testKey2, record2); - - this._collectionStore.TryAdd(TestCollectionName, collection); - - var sut = new InMemoryVectorStoreRecordCollection>( - this._collectionStore, - this._collectionStoreTypes, - TestCollectionName, - new() - { - VectorStoreRecordDefinition = this._singlePropsDefinition - }); - - // Act - var actual = await sut.VectorizedSearchAsync( - new ReadOnlyMemory([1, 1, 1, 1]), - new() { IncludeVectors = true, VectorProperty = r => r.Vectors["Vector"] }, - this._testCancellationToken); - - // Assert - Assert.NotNull(actual); - var actualResults = await actual.Results.ToListAsync(); - Assert.Equal(2, actualResults.Count); - Assert.Equal(testKey1, actualResults[0].Record.Key); - Assert.Equal($"data {testKey1}", actualResults[0].Record.Data["Data"]); - Assert.Equal(1, actualResults[0].Score); - Assert.Equal(testKey2, actualResults[1].Record.Key); - Assert.Equal($"data {testKey2}", actualResults[1].Record.Data["Data"]); - Assert.Equal(-1, actualResults[1].Score); - } - - private static SinglePropsModel CreateModel(TKey key, bool withVectors, float[]? vector = null) - { - return new SinglePropsModel - { - Key = key, - Data = "data " + key, - Tags = new List { "default tag", "tag " + key }, - Vector = vector ?? (withVectors ? new float[] { 1, 2, 3, 4 } : null), - NotAnnotated = null, - }; - } - - private InMemoryVectorStoreRecordCollection> CreateRecordCollection(bool useDefinition) - where TKey : notnull - { - return new InMemoryVectorStoreRecordCollection>( - this._collectionStore, - this._collectionStoreTypes, - TestCollectionName, - new() - { - VectorStoreRecordDefinition = useDefinition ? this._singlePropsDefinition : null - }); - } - - private readonly VectorStoreRecordDefinition _singlePropsDefinition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("Tags", typeof(List)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("Data", typeof(string)) { IsFilterable = true }, - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) - ] - }; - - public sealed class SinglePropsModel - { - [VectorStoreRecordKey] - public TKey? Key { get; set; } - - [VectorStoreRecordData(IsFilterable = true)] - public List Tags { get; set; } = new List(); - - [VectorStoreRecordData(IsFilterable = true)] - public string Data { get; set; } = string.Empty; - - [VectorStoreRecordVector] - public ReadOnlyMemory? Vector { get; set; } - - public string? NotAnnotated { get; set; } - } -} diff --git a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreTests.cs index 14d54969d8c3..fe9717c80c70 100644 --- a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreTests.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Concurrent; -using System.Linq; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.InMemory; @@ -45,23 +43,6 @@ public void GetCollectionReturnsCollectionWithNonStringKey() Assert.IsType>>(actual); } - [Fact] - public async Task ListCollectionNamesReadsDictionaryAsync() - { - // Arrange. - var collectionStore = new ConcurrentDictionary>(); - collectionStore.TryAdd("collection1", new ConcurrentDictionary()); - collectionStore.TryAdd("collection2", new ConcurrentDictionary()); - var sut = new InMemoryVectorStore(collectionStore); - - // Act. - var collectionNames = sut.ListCollectionNamesAsync(); - - // Assert. - var collectionNamesList = await collectionNames.ToListAsync(); - Assert.Equal(new[] { "collection1", "collection2" }, collectionNamesList); - } - [Fact] public async Task GetCollectionDoesNotAllowADifferentDataTypeThanPreviouslyUsedAsync() { diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchConstants.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchConstants.cs new file mode 100644 index 000000000000..8737519c2eba --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchConstants.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +internal static class AzureAISearchConstants +{ + internal const string VectorStoreSystemName = "azure.aisearch"; + + /// A set of types that a key on the provided model may have. + internal static readonly HashSet SupportedKeyTypes = [typeof(string)]; + + /// A set of types that data properties on the provided model may have. + internal static readonly HashSet SupportedDataTypes = + [ + typeof(string), + typeof(int), + typeof(long), + typeof(double), + typeof(float), + typeof(bool), + typeof(DateTimeOffset) + ]; + + /// A set of types that vectors on the provided model may have. + /// + /// Azure AI Search is adding support for more types than just float32, but these are not available for use via the + /// SDK yet. We will update this list as the SDK is updated. + /// + /// + internal static readonly HashSet SupportedVectorTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?) + ]; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchDynamicDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchDynamicDataModelMapper.cs new file mode 100644 index 000000000000..b5b4e9ed08e8 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchDynamicDataModelMapper.cs @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +/// +/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Azure AI Search. +/// +#pragma warning disable CS0618 // IVectorStoreRecordMapper is obsolete +internal sealed class AzureAISearchDynamicDataModelMapper(VectorStoreRecordModel model) +#pragma warning restore CS0618 +{ + /// + public JsonObject MapFromDataToStorageModel(Dictionary dataModel) + { + Verify.NotNull(dataModel); + + var storageJsonObject = new JsonObject(); + + // Loop through all known properties and map each from the data model json to the storage json. + foreach (var property in model.Properties) + { + switch (property) + { + case VectorStoreRecordKeyPropertyModel keyProperty: + storageJsonObject.Add(keyProperty.StorageName, (string)model.KeyProperty.GetValueAsObject(dataModel)!); + continue; + + case VectorStoreRecordDataPropertyModel dataProperty: + case VectorStoreRecordVectorPropertyModel vectorProperty: + if (dataModel.TryGetValue(property.ModelName, out var dataValue)) + { + var serializedJsonNode = JsonSerializer.SerializeToNode(dataValue); + storageJsonObject.Add(property.StorageName, serializedJsonNode); + } + continue; + + default: + throw new UnreachableException(); + } + } + + return storageJsonObject; + } + + /// + public Dictionary MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options) + { + Verify.NotNull(storageModel); + + // Create variables to store the response properties. + var result = new Dictionary(); + + // Loop through all known properties and map each from json to the data type. + foreach (var property in model.Properties) + { + switch (property) + { + case VectorStoreRecordKeyPropertyModel keyProperty: + result[keyProperty.ModelName] = (string?)storageModel[keyProperty.StorageName] + ?? throw new VectorStoreRecordMappingException($"The key property '{keyProperty.StorageName}' is missing from the record retrieved from storage."); + + continue; + + case VectorStoreRecordDataPropertyModel dataProperty: + { + if (storageModel.TryGetPropertyValue(dataProperty.StorageName, out var value)) + { + result.Add(dataProperty.ModelName, value is null ? null : GetDataPropertyValue(property.Type, value)); + } + continue; + } + + case VectorStoreRecordVectorPropertyModel vectorProperty when options.IncludeVectors: + { + if (storageModel.TryGetPropertyValue(vectorProperty.StorageName, out var value)) + { + if (value is not null) + { + ReadOnlyMemory vector = value.AsArray().Select(x => (float)x!).ToArray(); + result.Add(vectorProperty.ModelName, vector); + } + else + { + result.Add(vectorProperty.ModelName, null); + } + } + + continue; + } + + case VectorStoreRecordVectorPropertyModel vectorProperty when !options.IncludeVectors: + break; + + default: + throw new UnreachableException(); + } + } + + return result; + } + + /// + /// Get the value of the given json node as the given property type. + /// + /// The type of property that is required. + /// The json node containing the property value. + /// The value of the json node as the required type. + private static object? GetDataPropertyValue(Type propertyType, JsonNode value) + { + if (propertyType == typeof(string)) + { + return (string?)value.AsValue(); + } + + if (propertyType == typeof(int) || propertyType == typeof(int?)) + { + return (int?)value.AsValue(); + } + + if (propertyType == typeof(long) || propertyType == typeof(long?)) + { + return (long?)value.AsValue(); + } + + if (propertyType == typeof(float) || propertyType == typeof(float?)) + { + return (float?)value.AsValue(); + } + + if (propertyType == typeof(double) || propertyType == typeof(double?)) + { + return (double?)value.AsValue(); + } + + if (propertyType == typeof(bool) || propertyType == typeof(bool?)) + { + return (bool?)value.AsValue(); + } + + if (propertyType == typeof(DateTimeOffset) || propertyType == typeof(DateTimeOffset?)) + { + return value.GetValue(); + } + + if (typeof(IEnumerable).IsAssignableFrom(propertyType)) + { + return value.Deserialize(propertyType); + } + + return null; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchDynamicModelBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchDynamicModelBuilder.cs new file mode 100644 index 000000000000..059d073ca3c3 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchDynamicModelBuilder.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using Microsoft.Extensions.VectorData.ConnectorSupport; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +internal class AzureAISearchDynamicModelBuilder() : VectorStoreRecordModelBuilder(s_modelBuildingOptions) +{ + internal static readonly VectorStoreRecordModelBuildingOptions s_modelBuildingOptions = new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + + SupportedKeyPropertyTypes = AzureAISearchConstants.SupportedKeyTypes, + SupportedDataPropertyTypes = AzureAISearchConstants.SupportedDataTypes, + SupportedEnumerableDataPropertyElementTypes = AzureAISearchConstants.SupportedDataTypes, + SupportedVectorPropertyTypes = AzureAISearchConstants.SupportedVectorTypes, + + UsesExternalSerializer = true + }; + + protected override void Validate(Type type) + { + base.Validate(type); + + if (this.VectorProperties.FirstOrDefault(p => p.EmbeddingGenerator is not null) is VectorStoreRecordPropertyModel property) + { + throw new NotSupportedException( + $"The Azure AI Search connector does not currently support a custom embedding generator (configured for property '{property.ModelName}' on type '{type.Name}'). " + + "However, you can configure embedding generation in Azure AI Search itself, without requiring a .NET IEmbeddingGenerator."); + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs index 16164c2a3eca..8c9d172ca863 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs @@ -7,31 +7,35 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; using System.Text; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.ConnectorSupport.Filter; namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; internal class AzureAISearchFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; + private VectorStoreRecordModel _model = null!; private ParameterExpression _recordParameter = null!; private readonly StringBuilder _filter = new(); private static readonly char[] s_searchInDefaultDelimiter = [' ', ',']; - internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + internal string Translate(LambdaExpression lambdaExpression, VectorStoreRecordModel model) { Debug.Assert(this._filter.Length == 0); - this._storagePropertyNames = storagePropertyNames; + this._model = model; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - this.Translate(lambdaExpression.Body); + var preprocessor = new FilterTranslationPreprocessor { InlineCapturedVariables = true }; + var preprocessedExpression = preprocessor.Visit(lambdaExpression.Body); + + this.Translate(preprocessedExpression); + return this._filter.ToString(); } @@ -138,26 +142,24 @@ private void GenerateLiteral(object? value) private void TranslateMember(MemberExpression memberExpression) { - switch (memberExpression) + if (this.TryBindProperty(memberExpression, out var property)) { - case var _ when this.TryGetField(memberExpression, out var column): - this._filter.Append(column); // TODO: Escape - return; - - // Identify captured lambda variables, inline them as constants - case var _ when TryGetCapturedValue(memberExpression, out var capturedValue): - this.GenerateLiteral(capturedValue); - return; - - default: - throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + this._filter.Append(property.StorageName); // TODO: Escape + return; } + + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); } private void TranslateMethodCall(MethodCallExpression methodCall) { switch (methodCall) { + // Dictionary access for dynamic mapping (r => r["SomeString"] == "foo") + case MethodCallExpression when this.TryBindProperty(methodCall, out var property): + this._filter.Append(property.StorageName); // TODO: Escape + return; + // Enumerable.Contains() case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains when contains.Method.DeclaringType == typeof(Enumerable): @@ -188,7 +190,7 @@ private void TranslateContains(Expression source, Expression item) switch (source) { // Contains over array field (r => r.Strings.Contains("foo")) - case var _ when this.TryGetField(source, out _): + case var _ when this.TryBindProperty(source, out _): this.Translate(source); this._filter.Append("/any(t: t eq "); this.Translate(item); @@ -201,20 +203,23 @@ private void TranslateContains(Expression source, Expression item) for (var i = 0; i < newArray.Expressions.Count; i++) { - if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + if (newArray.Expressions[i] is not ConstantExpression { Value: var elementValue }) { throw new NotSupportedException("Invalid element in array"); } + if (elementValue is not string) + { + throw new NotSupportedException("Contains over non-string arrays is not supported"); + } + elements[i] = elementValue; } ProcessInlineEnumerable(elements, item); return; - // Contains over captured enumerable (we inline) - case var _ when TryGetConstant(source, out var constantEnumerable) - && constantEnumerable is IEnumerable enumerable and not string: + case ConstantExpression { Value: IEnumerable enumerable and not string }: ProcessInlineEnumerable(enumerable, item); return; @@ -224,11 +229,6 @@ private void TranslateContains(Expression source, Expression item) void ProcessInlineEnumerable(IEnumerable elements, Expression item) { - if (item.Type != typeof(string)) - { - throw new NotSupportedException("Contains over non-string arrays is not supported"); - } - this._filter.Append("search.in("); this.Translate(item); this._filter.Append(", '"); @@ -311,56 +311,59 @@ private void TranslateUnary(UnaryExpression unary) this._filter.Append(')'); return; + // Handle convert over member access, for dynamic dictionary access (r => (int)r["SomeInt"] == 8) + case ExpressionType.Convert when this.TryBindProperty(unary.Operand, out var property) && unary.Type == property.Type: + this._filter.Append(property.StorageName); // TODO: Escape + return; + default: throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); } } - private bool TryGetField(Expression expression, [NotNullWhen(true)] out string? field) + private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out VectorStoreRecordPropertyModel? property) { - if (expression is MemberExpression member && member.Expression == this._recordParameter) - { - if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out field)) - { - throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); - } + Type? convertedClrType = null; - return true; + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unary) + { + expression = unary.Operand; + convertedClrType = unary.Type; } - field = null; - return false; - } - - private static bool TryGetCapturedValue(Expression expression, out object? capturedValue) - { - if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + var modelName = expression switch { - capturedValue = fieldInfo.GetValue(constant.Value); - return true; - } + // Regular member access for strongly-typed POCO binding (e.g. r => r.SomeInt == 8) + MemberExpression memberExpression when memberExpression.Expression == this._recordParameter + => memberExpression.Member.Name, - capturedValue = null; - return false; - } + // Dictionary lookup for weakly-typed dynamic binding (e.g. r => r["SomeInt"] == 8) + MethodCallExpression + { + Method: { Name: "get_Item", DeclaringType: var declaringType }, + Arguments: [ConstantExpression { Value: string keyName }] + } methodCall when methodCall.Object == this._recordParameter && declaringType == typeof(Dictionary) + => keyName, - private static bool TryGetConstant(Expression expression, out object? constantValue) - { - switch (expression) + _ => null + }; + + if (modelName is null) { - case ConstantExpression { Value: var v }: - constantValue = v; - return true; + property = null; + return false; + } - case var _ when TryGetCapturedValue(expression, out var capturedValue): - constantValue = capturedValue; - return true; + if (!this._model.PropertyMap.TryGetValue(modelName, out property)) + { + throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name."); + } - default: - constantValue = null; - return false; + if (convertedClrType is not null && convertedClrType != property.Type) + { + throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convertedClrType.Name}', but its configured type is '{property.Type.Name}'."); } + + return true; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchGenericDataModelMapper.cs deleted file mode 100644 index 502edaed2605..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchGenericDataModelMapper.cs +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections; -using System.Collections.Generic; -using System.Linq; -using System.Text.Json; -using System.Text.Json.Nodes; -using Microsoft.Extensions.VectorData; - -namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; - -/// -/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Azure AI Search. -/// -internal class AzureAISearchGenericDataModelMapper : IVectorStoreRecordMapper, JsonObject> -{ - /// A that defines the schema of the data in the database. - private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition; - - /// - /// Initializes a new instance of the class. - /// - /// A that defines the schema of the data in the database. - public AzureAISearchGenericDataModelMapper(VectorStoreRecordDefinition vectorStoreRecordDefinition) - { - Verify.NotNull(vectorStoreRecordDefinition); - - this._vectorStoreRecordDefinition = vectorStoreRecordDefinition; - } - - /// - public JsonObject MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - Verify.NotNull(dataModel); - - var storageJsonObject = new JsonObject(); - - // Loop through all known properties and map each from the data model json to the storage json. - foreach (var property in this._vectorStoreRecordDefinition.Properties) - { - if (property is VectorStoreRecordKeyProperty keyProperty) - { - var storagePropertyName = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; - storageJsonObject.Add(storagePropertyName, dataModel.Key); - } - else if (property is VectorStoreRecordDataProperty dataProperty) - { - if (dataModel.Data is not null && dataModel.Data.TryGetValue(dataProperty.DataModelPropertyName, out var dataValue)) - { - var storagePropertyName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; - var serializedJsonNode = JsonSerializer.SerializeToNode(dataValue); - storageJsonObject.Add(storagePropertyName, serializedJsonNode); - } - } - else if (property is VectorStoreRecordVectorProperty vectorProperty) - { - if (dataModel.Vectors is not null && dataModel.Vectors.TryGetValue(vectorProperty.DataModelPropertyName, out var vectorValue)) - { - var storagePropertyName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; - var serializedJsonNode = JsonSerializer.SerializeToNode(vectorValue); - storageJsonObject.Add(storagePropertyName, serializedJsonNode); - } - } - } - - return storageJsonObject; - } - - /// - public VectorStoreGenericDataModel MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options) - { - Verify.NotNull(storageModel); - - // Create variables to store the response properties. - var dataProperties = new Dictionary(); - var vectorProperties = new Dictionary(); - string? key = null; - - // Loop through all known properties and map each from json to the data type. - foreach (var property in this._vectorStoreRecordDefinition.Properties) - { - if (property is VectorStoreRecordKeyProperty keyProperty) - { - var storagePropertyName = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; - var value = storageModel[storagePropertyName]; - if (value is null) - { - throw new VectorStoreRecordMappingException($"The key property '{storagePropertyName}' is missing from the record retrieved from storage."); - } - - key = (string)value!; - } - else if (property is VectorStoreRecordDataProperty dataProperty) - { - var storagePropertyName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; - if (!storageModel.TryGetPropertyValue(storagePropertyName, out var value)) - { - continue; - } - - if (value is not null) - { - dataProperties.Add(dataProperty.DataModelPropertyName, GetDataPropertyValue(property.PropertyType, value)); - } - else - { - dataProperties.Add(dataProperty.DataModelPropertyName, null); - } - } - else if (property is VectorStoreRecordVectorProperty vectorProperty && options.IncludeVectors) - { - var storagePropertyName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; - if (!storageModel.TryGetPropertyValue(storagePropertyName, out var value)) - { - continue; - } - - if (value is not null) - { - ReadOnlyMemory vector = value.AsArray().Select(x => (float)x!).ToArray(); - vectorProperties.Add(vectorProperty.DataModelPropertyName, vector); - } - else - { - vectorProperties.Add(vectorProperty.DataModelPropertyName, null); - } - } - } - - if (key is null) - { - throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); - } - - return new VectorStoreGenericDataModel(key) { Data = dataProperties, Vectors = vectorProperties }; - } - - /// - /// Get the value of the given json node as the given property type. - /// - /// The type of property that is required. - /// The json node containing the property value. - /// The value of the json node as the required type. - private static object? GetDataPropertyValue(Type propertyType, JsonNode value) - { - if (propertyType == typeof(string)) - { - return (string?)value.AsValue(); - } - - if (propertyType == typeof(int) || propertyType == typeof(int?)) - { - return (int?)value.AsValue(); - } - - if (propertyType == typeof(long) || propertyType == typeof(long?)) - { - return (long?)value.AsValue(); - } - - if (propertyType == typeof(float) || propertyType == typeof(float?)) - { - return (float?)value.AsValue(); - } - - if (propertyType == typeof(double) || propertyType == typeof(double?)) - { - return (double?)value.AsValue(); - } - - if (propertyType == typeof(bool) || propertyType == typeof(bool?)) - { - return (bool?)value.AsValue(); - } - - if (propertyType == typeof(DateTimeOffset) || propertyType == typeof(DateTimeOffset?)) - { - return value.GetValue(); - } - - if (typeof(IEnumerable).IsAssignableFrom(propertyType)) - { - return value.Deserialize(propertyType); - } - - return null; - } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchKernelBuilderExtensions.cs index 5096c1486f1f..76c54f8cb412 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchKernelBuilderExtensions.cs @@ -12,6 +12,7 @@ namespace Microsoft.SemanticKernel; /// /// Extension methods to register Azure AI Search instances on the . /// +[Obsolete("The IKernelBuilder extensions are being obsoleted, call the appropriate function on the Services property of your IKernelBuilder")] public static class AzureAISearchKernelBuilderExtensions { /// @@ -58,13 +59,13 @@ public static IKernelBuilder AddAzureAISearchVectorStore(this IKernelBuilder bui } /// - /// Register an Azure AI Search , and with the + /// Register an Azure AI Search , and with the /// specified service ID and where is retrieved from the dependency injection container. /// /// The type of the data model that the collection should contain. /// The builder to register the on. - /// The name of the collection that this will access. - /// Optional configuration options to pass to the . + /// The name of the collection that this will access. + /// Optional configuration options to pass to the . /// An optional service id to use as the service key. /// The kernel builder. public static IKernelBuilder AddAzureAISearchVectorStoreRecordCollection( @@ -72,21 +73,22 @@ public static IKernelBuilder AddAzureAISearchVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddAzureAISearchVectorStoreRecordCollection(collectionName, options, serviceId); return builder; } /// - /// Register an Azure AI Search , and with the + /// Register an Azure AI Search , and with the /// provided and and the specified service ID. /// /// The type of the data model that the collection should contain. /// The builder to register the on. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// The service endpoint for Azure AI Search. /// The credential to authenticate to Azure AI Search with. - /// Optional configuration options to pass to the . + /// Optional configuration options to pass to the . /// An optional service id to use as the service key. /// The kernel builder. public static IKernelBuilder AddAzureAISearchVectorStoreRecordCollection( @@ -96,21 +98,22 @@ public static IKernelBuilder AddAzureAISearchVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddAzureAISearchVectorStoreRecordCollection(collectionName, endpoint, tokenCredential, options, serviceId); return builder; } /// - /// Register an Azure AI Search , and with the + /// Register an Azure AI Search , and with the /// provided and and the specified service ID. /// /// The type of the data model that the collection should contain. /// The builder to register the on. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// The service endpoint for Azure AI Search. /// The credential to authenticate to Azure AI Search with. - /// Optional configuration options to pass to the . + /// Optional configuration options to pass to the . /// An optional service id to use as the service key. /// The kernel builder. public static IKernelBuilder AddAzureAISearchVectorStoreRecordCollection( @@ -120,6 +123,7 @@ public static IKernelBuilder AddAzureAISearchVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddAzureAISearchVectorStoreRecordCollection(collectionName, endpoint, credential, options, serviceId); return builder; diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchMemoryRecord.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchMemoryRecord.cs index 88a57a2ed4a0..49da4d8e1a4b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchMemoryRecord.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchMemoryRecord.cs @@ -1,18 +1,19 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Diagnostics.CodeAnalysis; using System.Text; using System.Text.Json.Serialization; using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// Azure AI Search record and index definition. /// Note: once defined, index cannot be modified. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being phased out, use Microsoft.Extensions.VectorData and AzureAISearchVectorStore")] internal sealed class AzureAISearchMemoryRecord { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchMemoryStore.cs index 3dab67ba52b4..eae807bb53a2 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchMemoryStore.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -21,10 +20,12 @@ namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// is a memory store implementation using Azure AI Search. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being phased out, use Microsoft.Extensions.VectorData and AzureAISearchVectorStore")] public partial class AzureAISearchMemoryStore : IMemoryStore { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchModelBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchModelBuilder.cs new file mode 100644 index 000000000000..e3ff3b5c2983 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchModelBuilder.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using Microsoft.Extensions.VectorData.ConnectorSupport; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +internal class AzureAISearchModelBuilder() : VectorStoreRecordJsonModelBuilder(s_modelBuildingOptions) +{ + internal static readonly VectorStoreRecordModelBuildingOptions s_modelBuildingOptions = new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + + SupportedKeyPropertyTypes = AzureAISearchConstants.SupportedKeyTypes, + SupportedDataPropertyTypes = AzureAISearchConstants.SupportedDataTypes, + SupportedEnumerableDataPropertyElementTypes = AzureAISearchConstants.SupportedDataTypes, + SupportedVectorPropertyTypes = AzureAISearchConstants.SupportedVectorTypes, + + UsesExternalSerializer = true + }; + + protected override void Validate(Type type) + { + base.Validate(type); + + if (this.VectorProperties.FirstOrDefault(p => p.EmbeddingGenerator is not null) is VectorStoreRecordPropertyModel property) + { + throw new NotSupportedException( + $"The Azure AI Search connector does not currently support a custom embedding generator (configured for property '{property.ModelName}' on type '{type.Name}'). " + + "However, you can configure embedding generation in Azure AI Search itself, without requiring a .NET IEmbeddingGenerator."); + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchServiceCollectionExtensions.cs index 0daa73595cbd..7ce1d876bcc9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchServiceCollectionExtensions.cs @@ -7,6 +7,7 @@ using Azure.Core.Serialization; using Azure.Search.Documents; using Azure.Search.Documents.Indexes; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.AzureAISearch; @@ -35,11 +36,12 @@ public static IServiceCollection AddAzureAISearchVectorStore(this IServiceCollec (sp, obj) => { var searchIndexClient = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new AzureAISearchVectorStore( - searchIndexClient, - selectedOptions); + return new AzureAISearchVectorStore(searchIndexClient, options); }); return services; @@ -63,14 +65,15 @@ public static IServiceCollection AddAzureAISearchVectorStore(this IServiceCollec serviceId, (sp, obj) => { - var selectedOptions = options ?? sp.GetService(); - var searchClientOptions = BuildSearchClientOptions(selectedOptions?.JsonSerializerOptions); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; + var searchClientOptions = BuildSearchClientOptions(options?.JsonSerializerOptions); var searchIndexClient = new SearchIndexClient(endpoint, tokenCredential, searchClientOptions); // Construct the vector store. - return new AzureAISearchVectorStore( - searchIndexClient, - selectedOptions); + return new AzureAISearchVectorStore(searchIndexClient, options); }); return services; @@ -94,27 +97,28 @@ public static IServiceCollection AddAzureAISearchVectorStore(this IServiceCollec serviceId, (sp, obj) => { - var selectedOptions = options ?? sp.GetService(); - var searchClientOptions = BuildSearchClientOptions(selectedOptions?.JsonSerializerOptions); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; + var searchClientOptions = BuildSearchClientOptions(options?.JsonSerializerOptions); var searchIndexClient = new SearchIndexClient(endpoint, credential, searchClientOptions); // Construct the vector store. - return new AzureAISearchVectorStore( - searchIndexClient, - selectedOptions); + return new AzureAISearchVectorStore(searchIndexClient, options); }); return services; } /// - /// Register an Azure AI Search , and with the + /// Register an Azure AI Search , and with the /// specified service ID and where is retrieved from the dependency injection container. /// /// The type of the data model that the collection should contain. /// The to register the on. - /// The name of the collection that this will access. - /// Optional configuration options to pass to the . + /// The name of the collection that this will access. + /// Optional configuration options to pass to the . /// An optional service id to use as the service key. /// The service collection. public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection( @@ -122,6 +126,7 @@ public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { // If we are not constructing the SearchIndexClient, add the IVectorStore as transient, since we // cannot make assumptions about how SearchIndexClient is being managed. @@ -130,12 +135,12 @@ public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection { var searchIndexClient = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new AzureAISearchVectorStoreRecordCollection( - searchIndexClient, - collectionName, - selectedOptions); + return new AzureAISearchVectorStoreRecordCollection(searchIndexClient, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -144,15 +149,15 @@ public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection - /// Register an Azure AI Search , and with the + /// Register an Azure AI Search , and with the /// provided and and the specified service ID. /// /// The type of the data model that the collection should contain. /// The to register the on. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// The service endpoint for Azure AI Search. /// The credential to authenticate to Azure AI Search with. - /// Optional configuration options to pass to the . + /// Optional configuration options to pass to the . /// An optional service id to use as the service key. /// The service collection. public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection( @@ -162,6 +167,7 @@ public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { Verify.NotNull(endpoint); Verify.NotNull(tokenCredential); @@ -170,15 +176,15 @@ public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection { - var selectedOptions = options ?? sp.GetService>(); - var searchClientOptions = BuildSearchClientOptions(selectedOptions?.JsonSerializerOptions); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; + var searchClientOptions = BuildSearchClientOptions(options?.JsonSerializerOptions); var searchIndexClient = new SearchIndexClient(endpoint, tokenCredential, searchClientOptions); // Construct the vector store. - return new AzureAISearchVectorStoreRecordCollection( - searchIndexClient, - collectionName, - selectedOptions); + return new AzureAISearchVectorStoreRecordCollection(searchIndexClient, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -187,15 +193,15 @@ public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection - /// Register an Azure AI Search , and with the + /// Register an Azure AI Search , and with the /// provided and and the specified service ID. /// /// The type of the data model that the collection should contain. /// The to register the on. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// The service endpoint for Azure AI Search. /// The credential to authenticate to Azure AI Search with. - /// Optional configuration options to pass to the . + /// Optional configuration options to pass to the . /// An optional service id to use as the service key. /// The service collection. public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection( @@ -205,6 +211,7 @@ public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { Verify.NotNull(endpoint); Verify.NotNull(credential); @@ -213,15 +220,15 @@ public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection { - var selectedOptions = options ?? sp.GetService>(); - var searchClientOptions = BuildSearchClientOptions(selectedOptions?.JsonSerializerOptions); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; + var searchClientOptions = BuildSearchClientOptions(options?.JsonSerializerOptions); var searchIndexClient = new SearchIndexClient(endpoint, credential, searchClientOptions); // Construct the vector store. - return new AzureAISearchVectorStoreRecordCollection( - searchIndexClient, - collectionName, - selectedOptions); + return new AzureAISearchVectorStoreRecordCollection(searchIndexClient, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -230,14 +237,14 @@ public static IServiceCollection AddAzureAISearchVectorStoreRecordCollection - /// Also register the with the given as a . + /// Also register the with the given as a . /// /// The type of the data model that the collection should contain. /// The service collection to register on. /// The service id that the registrations should use. - private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) where TRecord : notnull { - services.AddKeyedTransient>( + services.AddKeyedTransient>( serviceId, (sp, obj) => { diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStore.cs index 5329cdf3cee4..eb866c2998c8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStore.cs @@ -17,10 +17,10 @@ namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; /// /// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. /// -public class AzureAISearchVectorStore : IVectorStore +public sealed class AzureAISearchVectorStore : IVectorStore { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "AzureAISearch"; + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; /// Azure AI Search client that can be used to manage the list of indices in an Azure AI Search Service. private readonly SearchIndexClient _searchIndexClient; @@ -28,6 +28,9 @@ public class AzureAISearchVectorStore : IVectorStore /// Optional configuration options for this class. private readonly AzureAISearchVectorStoreOptions _options; + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreRecordDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreRecordKeyProperty("Key", typeof(string))] }; + /// /// Initializes a new instance of the class. /// @@ -39,11 +42,18 @@ public AzureAISearchVectorStore(SearchIndexClient searchIndexClient, AzureAISear this._searchIndexClient = searchIndexClient; this._options = options ?? new AzureAISearchVectorStoreOptions(); + + this._metadata = new() + { + VectorStoreSystemName = AzureAISearchConstants.VectorStoreSystemName, + VectorStoreName = searchIndexClient.ServiceName + }; } /// - public virtual IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + where TRecord : notnull { #pragma warning disable CS0618 // IAzureAISearchVectorStoreRecordCollectionFactor is obsolete if (this._options.VectorStoreCollectionFactory is not null) @@ -52,45 +62,69 @@ public virtual IVectorStoreRecordCollection GetCollection( + var recordCollection = new AzureAISearchVectorStoreRecordCollection( this._searchIndexClient, name, new AzureAISearchVectorStoreRecordCollectionOptions() { JsonSerializerOptions = this._options.JsonSerializerOptions, - VectorStoreRecordDefinition = vectorStoreRecordDefinition + VectorStoreRecordDefinition = vectorStoreRecordDefinition, + EmbeddingGenerator = this._options.EmbeddingGenerator }) as IVectorStoreRecordCollection; return recordCollection!; } /// - public virtual async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { var indexNamesEnumerable = this._searchIndexClient.GetIndexNamesAsync(cancellationToken).ConfigureAwait(false); var indexNamesEnumerator = indexNamesEnumerable.GetAsyncEnumerator(); - var nextResult = await GetNextIndexNameAsync(indexNamesEnumerator).ConfigureAwait(false); + var nextResult = await this.GetNextIndexNameAsync(indexNamesEnumerator).ConfigureAwait(false); while (nextResult.more) { yield return nextResult.name; - nextResult = await GetNextIndexNameAsync(indexNamesEnumerator).ConfigureAwait(false); + nextResult = await this.GetNextIndexNameAsync(indexNamesEnumerator).ConfigureAwait(false); } } + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType == typeof(SearchIndexClient) ? this._searchIndexClient : + serviceType.IsInstanceOfType(this) ? this : + null; + } + /// /// Helper method to get the next index name from the enumerator with a try catch around the move next call to convert - /// any to , since try catch is not supported + /// any to , since try catch is not supported /// around a yield return. /// /// The enumerator to get the next result from. /// A value indicating whether there are more results and the current string if true. - private static async Task<(string name, bool more)> GetNextIndexNameAsync(ConfiguredCancelableAsyncEnumerable.Enumerator enumerator) + private async Task<(string name, bool more)> GetNextIndexNameAsync( + ConfiguredCancelableAsyncEnumerable.Enumerator enumerator) { const string OperationName = "GetIndexNames"; @@ -103,7 +137,8 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = AzureAISearchConstants.VectorStoreSystemName, + VectorStoreName = this._metadata.VectorStoreName, OperationName = OperationName }; } @@ -111,7 +146,8 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = AzureAISearchConstants.VectorStoreSystemName, + VectorStoreName = this._metadata.VectorStoreName, OperationName = OperationName }; } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionCreateMapping.cs index 93f95ca69c48..14ef7f40376d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionCreateMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionCreateMapping.cs @@ -6,6 +6,7 @@ using System.Linq; using Azure.Search.Documents.Indexes.Models; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; @@ -18,53 +19,56 @@ internal static class AzureAISearchVectorStoreCollectionCreateMapping /// Map from a to an Azure AI Search . /// /// The key property definition. - /// The name of the property in storage. /// The for the provided property definition. - public static SearchableField MapKeyField(VectorStoreRecordKeyProperty keyProperty, string storagePropertyName) + public static SearchableField MapKeyField(VectorStoreRecordKeyPropertyModel keyProperty) { - return new SearchableField(storagePropertyName) { IsKey = true, IsFilterable = true }; + return new SearchableField(keyProperty.StorageName) { IsKey = true, IsFilterable = true }; } /// /// Map from a to an Azure AI Search . /// /// The data property definition. - /// The name of the property in storage. /// The for the provided property definition. /// Throws when the definition is missing required information. - public static SimpleField MapDataField(VectorStoreRecordDataProperty dataProperty, string storagePropertyName) + public static SimpleField MapDataField(VectorStoreRecordDataPropertyModel dataProperty) { - if (dataProperty.IsFullTextSearchable) + if (dataProperty.IsFullTextIndexed) { - if (dataProperty.PropertyType != typeof(string)) + if (dataProperty.Type != typeof(string)) { - throw new InvalidOperationException($"Property {nameof(dataProperty.IsFullTextSearchable)} on {nameof(VectorStoreRecordDataProperty)} '{dataProperty.DataModelPropertyName}' is set to true, but the property type is not a string. The Azure AI Search VectorStore supports {nameof(dataProperty.IsFullTextSearchable)} on string properties only."); + throw new InvalidOperationException($"Property {nameof(dataProperty.IsFullTextIndexed)} on {nameof(VectorStoreRecordDataProperty)} '{dataProperty.ModelName}' is set to true, but the property type is not a string. The Azure AI Search VectorStore supports {nameof(dataProperty.IsFullTextIndexed)} on string properties only."); } - return new SearchableField(storagePropertyName) { IsFilterable = dataProperty.IsFilterable }; + return new SearchableField(dataProperty.StorageName) + { + IsFilterable = dataProperty.IsIndexed, + // Sometimes the users ask to also OrderBy given filterable property, so we make it sortable. + IsSortable = dataProperty.IsIndexed + }; } - return new SimpleField(storagePropertyName, AzureAISearchVectorStoreCollectionCreateMapping.GetSDKFieldDataType(dataProperty.PropertyType)) { IsFilterable = dataProperty.IsFilterable }; + var fieldType = AzureAISearchVectorStoreCollectionCreateMapping.GetSDKFieldDataType(dataProperty.Type); + return new SimpleField(dataProperty.StorageName, fieldType) + { + IsFilterable = dataProperty.IsIndexed, + // Sometimes the users ask to also OrderBy given filterable property, so we make it sortable. + IsSortable = dataProperty.IsIndexed && !fieldType.IsCollection + }; } /// /// Map form a to an Azure AI Search and generate the required index configuration. /// /// The vector property definition. - /// The name of the property in storage. /// The and required index configuration. /// Throws when the definition is missing required information, or unsupported options are configured. - public static (VectorSearchField vectorSearchField, VectorSearchAlgorithmConfiguration algorithmConfiguration, VectorSearchProfile vectorSearchProfile) MapVectorField(VectorStoreRecordVectorProperty vectorProperty, string storagePropertyName) + public static (VectorSearchField vectorSearchField, VectorSearchAlgorithmConfiguration algorithmConfiguration, VectorSearchProfile vectorSearchProfile) MapVectorField(VectorStoreRecordVectorPropertyModel vectorProperty) { - if (vectorProperty.Dimensions is not > 0) - { - throw new InvalidOperationException($"Property {nameof(vectorProperty.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' must be set to a positive integer to create a collection."); - } - // Build a name for the profile and algorithm configuration based on the property name // since we'll just create a separate one for each vector property. - var vectorSearchProfileName = $"{storagePropertyName}Profile"; - var algorithmConfigName = $"{storagePropertyName}AlgoConfig"; + var vectorSearchProfileName = $"{vectorProperty.StorageName}Profile"; + var algorithmConfigName = $"{vectorProperty.StorageName}AlgoConfig"; // Read the vector index settings from the property definition and create the right index configuration. var indexKind = AzureAISearchVectorStoreCollectionCreateMapping.GetSKIndexKind(vectorProperty); @@ -74,11 +78,11 @@ public static (VectorSearchField vectorSearchField, VectorSearchAlgorithmConfigu { IndexKind.Hnsw => new HnswAlgorithmConfiguration(algorithmConfigName) { Parameters = new HnswParameters { Metric = algorithmMetric } }, IndexKind.Flat => new ExhaustiveKnnAlgorithmConfiguration(algorithmConfigName) { Parameters = new ExhaustiveKnnParameters { Metric = algorithmMetric } }, - _ => throw new InvalidOperationException($"Index kind '{indexKind}' on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Azure AI Search VectorStore.") + _ => throw new InvalidOperationException($"Index kind '{indexKind}' on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.ModelName}' is not supported by the Azure AI Search VectorStore.") }; var vectorSearchProfile = new VectorSearchProfile(vectorSearchProfileName, algorithmConfigName); - return (new VectorSearchField(storagePropertyName, vectorProperty.Dimensions.Value, vectorSearchProfileName), algorithmConfiguration, vectorSearchProfile); + return (new VectorSearchField(vectorProperty.StorageName, vectorProperty.Dimensions, vectorSearchProfileName), algorithmConfiguration, vectorSearchProfile); } /// @@ -87,15 +91,8 @@ public static (VectorSearchField vectorSearchField, VectorSearchAlgorithmConfigu /// /// The vector property definition. /// The configured or default . - public static string GetSKIndexKind(VectorStoreRecordVectorProperty vectorProperty) - { - if (vectorProperty.IndexKind is null) - { - return IndexKind.Hnsw; - } - - return vectorProperty.IndexKind; - } + public static string GetSKIndexKind(VectorStoreRecordVectorPropertyModel vectorProperty) + => vectorProperty.IndexKind ?? IndexKind.Hnsw; /// /// Get the configured from the given . @@ -104,21 +101,14 @@ public static string GetSKIndexKind(VectorStoreRecordVectorProperty vectorProper /// The vector property definition. /// The chosen . /// Thrown if a distance function is chosen that isn't supported by Azure AI Search. - public static VectorSearchAlgorithmMetric GetSDKDistanceAlgorithm(VectorStoreRecordVectorProperty vectorProperty) - { - if (vectorProperty.DistanceFunction is null) + public static VectorSearchAlgorithmMetric GetSDKDistanceAlgorithm(VectorStoreRecordVectorPropertyModel vectorProperty) + => vectorProperty.DistanceFunction switch { - return VectorSearchAlgorithmMetric.Cosine; - } - - return vectorProperty.DistanceFunction switch - { - DistanceFunction.CosineSimilarity => VectorSearchAlgorithmMetric.Cosine, + DistanceFunction.CosineSimilarity or null => VectorSearchAlgorithmMetric.Cosine, DistanceFunction.DotProductSimilarity => VectorSearchAlgorithmMetric.DotProduct, DistanceFunction.EuclideanDistance => VectorSearchAlgorithmMetric.Euclidean, - _ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Azure AI Search VectorStore.") + _ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.ModelName}' is not supported by the Azure AI Search VectorStore.") }; - } /// /// Maps the given property type to the corresponding . diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs index 732b6aeae42c..4e9240f98bb1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.Linq; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; @@ -17,10 +17,10 @@ internal static class AzureAISearchVectorStoreCollectionSearchMapping /// Build an OData filter string from the provided . /// /// The to build an OData filter string from. - /// A mapping of data model property names to the names under which they are stored. + /// The model. /// The OData filter string. /// Thrown when a provided filter value is not supported. - public static string BuildLegacyFilterString(VectorSearchFilter basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) + public static string BuildLegacyFilterString(VectorSearchFilter basicVectorSearchFilter, VectorStoreRecordModel model) { var filterString = string.Empty; if (basicVectorSearchFilter.FilterClauses is not null) @@ -28,7 +28,7 @@ public static string BuildLegacyFilterString(VectorSearchFilter basicVectorSearc // Map Equality clauses. var filterStrings = basicVectorSearchFilter?.FilterClauses.OfType().Select(x => { - string storageFieldName = GetStoragePropertyName(storagePropertyNames, x.FieldName); + string storageFieldName = GetStoragePropertyName(model, x.FieldName); return x.Value switch { @@ -49,11 +49,7 @@ public static string BuildLegacyFilterString(VectorSearchFilter basicVectorSearc // Map tag contains clauses. var tagListContainsStrings = basicVectorSearchFilter?.FilterClauses .OfType() - .Select(x => - { - string storageFieldName = GetStoragePropertyName(storagePropertyNames, x.FieldName); - return $"{storageFieldName}/any(t: t eq '{x.Value}')"; - }); + .Select(x => $"{GetStoragePropertyName(model, x.FieldName)}/any(t: t eq '{x.Value}')"); // Combine clauses. filterString = string.Join(" and ", filterStrings!.Concat(tagListContainsStrings!)); @@ -66,17 +62,17 @@ public static string BuildLegacyFilterString(VectorSearchFilter basicVectorSearc /// /// Gets the name of the name under which the property with the given name is stored. /// - /// A mapping of data model property names to the names under which they are stored. + /// The model. /// The name of the property in the data model. /// The name that the property os stored under. /// Thrown when the property name is not found. - private static string GetStoragePropertyName(IReadOnlyDictionary storagePropertyNames, string fieldName) + private static string GetStoragePropertyName(VectorStoreRecordModel model, string fieldName) { - if (!storagePropertyNames.TryGetValue(fieldName, out var storageFieldName)) + if (!model.PropertyMap.TryGetValue(fieldName, out var property)) { throw new InvalidOperationException($"Property name '{fieldName}' provided as part of the filter clause is not a valid property name."); } - return storageFieldName; + return property.StorageName; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreOptions.cs index 4c17ed4195e6..c18a882042e6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreOptions.cs @@ -3,6 +3,7 @@ using System; using System.Text.Json; using Azure.Search.Documents.Indexes; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; @@ -12,7 +13,7 @@ namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; public sealed class AzureAISearchVectorStoreOptions { /// - /// An optional factory to use for constructing instances, if a custom record collection is required. + /// An optional factory to use for constructing instances, if a custom record collection is required. /// [Obsolete("To control how collections are instantiated, extend your provider's IVectorStore implementation and override GetCollection()")] public IAzureAISearchVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } @@ -23,4 +24,9 @@ public sealed class AzureAISearchVectorStoreOptions /// to provide the same set of both here and when constructing the . /// public JsonSerializerOptions? JsonSerializerOptions { get; init; } = null; + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs index eda71258ef24..c3b5eefb210e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs @@ -2,10 +2,10 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using System.Runtime.CompilerServices; -using System.Text.Json; using System.Text.Json.Nodes; using System.Threading; using System.Threading.Tasks; @@ -15,58 +15,28 @@ using Azure.Search.Documents.Indexes.Models; using Azure.Search.Documents.Models; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; /// /// Service for storing and retrieving vector records, that uses Azure AI Search as the underlying storage. /// +/// The data type of the record key. Can be either , or for dynamic mapping. /// The data model to use for adding, updating and retrieving data from storage. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class AzureAISearchVectorStoreRecordCollection : - IVectorStoreRecordCollection, +#pragma warning disable CS0618 // IVectorizableTextSearch is obsolete +public sealed class AzureAISearchVectorStoreRecordCollection : + IVectorStoreRecordCollection, IVectorizableTextSearch, IKeywordHybridSearch + where TKey : notnull + where TRecord : notnull +#pragma warning restore CS0618 // IVectorizableTextSearch is obsolete #pragma warning restore CA1711 // Identifiers should not have incorrect suffix { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "AzureAISearch"; - - /// A set of types that a key on the provided model may have. - private static readonly HashSet s_supportedKeyTypes = - [ - typeof(string) - ]; - - /// A set of types that data properties on the provided model may have. - private static readonly HashSet s_supportedDataTypes = - [ - typeof(string), - typeof(int), - typeof(long), - typeof(double), - typeof(float), - typeof(bool), - typeof(DateTimeOffset), - typeof(int?), - typeof(long?), - typeof(double?), - typeof(float?), - typeof(bool?), - typeof(DateTimeOffset?), - ]; - - /// A set of types that vectors on the provided model may have. - /// - /// Azure AI Search is adding support for more types than just float32, but these are not available for use via the - /// SDK yet. We will update this list as the SDK is updated. - /// - /// - private static readonly HashSet s_supportedVectorTypes = - [ - typeof(ReadOnlyMemory), - typeof(ReadOnlyMemory?) - ]; + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; /// The default options for vector search. private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); @@ -80,74 +50,68 @@ public class AzureAISearchVectorStoreRecordCollection : /// Azure AI Search client that can be used to manage data in an Azure AI Search Service index. private readonly SearchClient _searchClient; - /// The name of the collection that this will access. + /// The name of the collection that this will access. private readonly string _collectionName; /// Optional configuration options for this class. private readonly AzureAISearchVectorStoreRecordCollectionOptions _options; /// A mapper to use for converting between the data model and the Azure AI Search record. - private readonly IVectorStoreRecordMapper? _mapper; + private readonly AzureAISearchDynamicDataModelMapper? _dynamicMapper; - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; + /// The model for this collection. + private readonly VectorStoreRecordModel _model; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// Azure AI Search client that can be used to manage the list of indices in an Azure AI Search Service. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// Optional configuration options for this class. /// Thrown when is null. /// Thrown when options are misconfigured. - public AzureAISearchVectorStoreRecordCollection(SearchIndexClient searchIndexClient, string collectionName, AzureAISearchVectorStoreRecordCollectionOptions? options = default) + public AzureAISearchVectorStoreRecordCollection(SearchIndexClient searchIndexClient, string name, AzureAISearchVectorStoreRecordCollectionOptions? options = default) { // Verify. Verify.NotNull(searchIndexClient); - Verify.NotNullOrWhiteSpace(collectionName); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.JsonObjectCustomMapper is not null, s_supportedKeyTypes); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + Verify.NotNullOrWhiteSpace(name); + + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(object)) + { + throw new NotSupportedException("Only string keys are supported (and object for dynamic mapping)"); + } // Assign. this._searchIndexClient = searchIndexClient; - this._collectionName = collectionName; + this._collectionName = name; this._options = options ?? new AzureAISearchVectorStoreRecordCollectionOptions(); - this._searchClient = this._searchIndexClient.GetSearchClient(collectionName); - this._propertyReader = new VectorStoreRecordPropertyReader( - typeof(TRecord), - this._options.VectorStoreRecordDefinition, - new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true, - JsonSerializerOptions = this._options.JsonSerializerOptions ?? JsonSerializerOptions.Default - }); + this._searchClient = this._searchIndexClient.GetSearchClient(name); - // Validate property types. - this._propertyReader.VerifyKeyProperties(s_supportedKeyTypes); - this._propertyReader.VerifyDataProperties(s_supportedDataTypes, supportEnumerable: true); - this._propertyReader.VerifyVectorProperties(s_supportedVectorTypes); + this._model = typeof(TRecord) == typeof(Dictionary) ? + new AzureAISearchDynamicModelBuilder().Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator) : + new AzureAISearchModelBuilder().Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator, this._options.JsonSerializerOptions); // Resolve mapper. - // First, if someone has provided a custom mapper, use that. // If they didn't provide a custom mapper, and the record type is the generic data model, use the built in mapper for that. // Otherwise, don't set the mapper, and we'll default to just using Azure AI Search's built in json serialization and deserialization. - if (this._options.JsonObjectCustomMapper is not null) + if (typeof(TRecord) == typeof(Dictionary)) { - this._mapper = this._options.JsonObjectCustomMapper; + this._dynamicMapper = new AzureAISearchDynamicDataModelMapper(this._model); } - else if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) + + this._collectionMetadata = new() { - this._mapper = new AzureAISearchGenericDataModelMapper(this._propertyReader.RecordDefinition) as IVectorStoreRecordMapper; - } + VectorStoreSystemName = AzureAISearchConstants.VectorStoreSystemName, + VectorStoreName = searchIndexClient.ServiceName, + CollectionName = name + }; } /// - public string CollectionName => this._collectionName; + public string Name => this._collectionName; /// - public virtual async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) { try { @@ -162,7 +126,8 @@ public virtual async Task CollectionExistsAsync(CancellationToken cancella { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = AzureAISearchConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, CollectionName = this._collectionName, OperationName = "GetIndex" }; @@ -170,41 +135,35 @@ public virtual async Task CollectionExistsAsync(CancellationToken cancella } /// - public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) { var vectorSearchConfig = new VectorSearch(); var searchFields = new List(); // Loop through all properties and create the search fields. - foreach (var property in this._propertyReader.Properties) + foreach (var property in this._model.Properties) { - // Key property. - if (property is VectorStoreRecordKeyProperty keyProperty) + switch (property) { - searchFields.Add(AzureAISearchVectorStoreCollectionCreateMapping.MapKeyField( - keyProperty, - this._propertyReader.KeyPropertyJsonName)); - } + case VectorStoreRecordKeyPropertyModel p: + searchFields.Add(AzureAISearchVectorStoreCollectionCreateMapping.MapKeyField(p)); + break; - // Data property. - if (property is VectorStoreRecordDataProperty dataProperty) - { - searchFields.Add(AzureAISearchVectorStoreCollectionCreateMapping.MapDataField( - dataProperty, - this._propertyReader.GetJsonPropertyName(dataProperty.DataModelPropertyName))); - } + case VectorStoreRecordDataPropertyModel p: + searchFields.Add(AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(p)); + break; - // Vector property. - if (property is VectorStoreRecordVectorProperty vectorProperty) - { - (VectorSearchField vectorSearchField, VectorSearchAlgorithmConfiguration algorithmConfiguration, VectorSearchProfile vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField( - vectorProperty, - this._propertyReader.GetJsonPropertyName(vectorProperty.DataModelPropertyName)); - - // Add the search field, plus its profile and algorithm configuration to the search config. - searchFields.Add(vectorSearchField); - vectorSearchConfig.Algorithms.Add(algorithmConfiguration); - vectorSearchConfig.Profiles.Add(vectorSearchProfile); + case VectorStoreRecordVectorPropertyModel p: + (VectorSearchField vectorSearchField, VectorSearchAlgorithmConfiguration algorithmConfiguration, VectorSearchProfile vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(p); + + // Add the search field, plus its profile and algorithm configuration to the search config. + searchFields.Add(vectorSearchField); + vectorSearchConfig.Algorithms.Add(algorithmConfiguration); + vectorSearchConfig.Profiles.Add(vectorSearchProfile); + break; + + default: + throw new UnreachableException(); } } @@ -218,7 +177,7 @@ public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = } /// - public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) { @@ -227,7 +186,7 @@ public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken can } /// - public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { return this.RunOperationAsync( "DeleteIndex", @@ -245,10 +204,8 @@ public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = } /// - public virtual Task GetAsync(string key, GetRecordOptions? options = default, CancellationToken cancellationToken = default) + public Task GetAsync(TKey key, GetRecordOptions? options = default, CancellationToken cancellationToken = default) { - Verify.NotNullOrWhiteSpace(key); - // Create Options. var innerOptions = this.ConvertGetDocumentOptions(options); var includeVectors = options?.IncludeVectors ?? false; @@ -258,7 +215,7 @@ public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = } /// - public virtual async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetAsync(IEnumerable keys, GetRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Verify.NotNull(keys); @@ -279,29 +236,35 @@ public virtual async IAsyncEnumerable GetBatchAsync(IEnumerable } /// - public virtual Task DeleteAsync(string key, CancellationToken cancellationToken = default) + public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { - Verify.NotNullOrWhiteSpace(key); + var stringKey = this.GetStringKey(key); // Remove record. return this.RunOperationAsync( "DeleteDocuments", - () => this._searchClient.DeleteDocumentsAsync(this._propertyReader.KeyPropertyJsonName, [key], new IndexDocumentsOptions(), cancellationToken)); + () => this._searchClient.DeleteDocumentsAsync(this._model.KeyProperty.StorageName, [stringKey], new IndexDocumentsOptions(), cancellationToken)); } /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { Verify.NotNull(keys); + if (!keys.Any()) + { + return Task.CompletedTask; + } + + var stringKeys = keys is IEnumerable k ? k : keys.Cast(); // Remove records. return this.RunOperationAsync( "DeleteDocuments", - () => this._searchClient.DeleteDocumentsAsync(this._propertyReader.KeyPropertyJsonName, keys, new IndexDocumentsOptions(), cancellationToken)); + () => this._searchClient.DeleteDocumentsAsync(this._model.KeyProperty.StorageName, stringKeys, new IndexDocumentsOptions(), cancellationToken)); } /// - public virtual async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { Verify.NotNull(record); @@ -310,13 +273,18 @@ public virtual async Task UpsertAsync(TRecord record, CancellationToken // Upsert record. var results = await this.MapToStorageModelAndUploadDocumentAsync([record], innerOptions, cancellationToken).ConfigureAwait(false); - return results.Value.Results[0].Key; + + return (TKey)(object)results.Value.Results[0].Key; } /// - public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { Verify.NotNull(records); + if (!records.Any()) + { + return []; + } // Create Options var innerOptions = new IndexDocumentsOptions { ThrowOnAnyError = true }; @@ -324,32 +292,36 @@ public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable x.Key).ToList(); - foreach (var resultKey in resultKeys) { yield return resultKey; } + return results.Value.Results.Select(x => (TKey)(object)x.Key).ToList(); } /// - public virtual Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable> SearchEmbeddingAsync( + TVector vector, + int top, + VectorSearchOptions? options = null, + CancellationToken cancellationToken = default) + where TVector : notnull { - var floatVector = VerifyVectorParam(vector); + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); - // Resolve options. - var internalOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(internalOptions); - var vectorPropertyName = this._propertyReader.GetJsonPropertyName(vectorProperty!.DataModelPropertyName); + var floatVector = VerifyVectorParam(vector); + Verify.NotLessThan(top, 1); // Configure search settings. - var vectorQueries = new List(); - vectorQueries.Add(new VectorizedQuery(floatVector) { KNearestNeighborsCount = internalOptions.Top, Fields = { vectorPropertyName } }); + var vectorQueries = new List + { + new VectorizedQuery(floatVector) { KNearestNeighborsCount = top, Fields = { vectorProperty.StorageName } } + }; #pragma warning disable CS0618 // VectorSearchFilter is obsolete // Build filter object. - var filter = internalOptions switch + var filter = options switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(legacyFilter, this._propertyReader.JsonPropertyNamesMap), - { Filter: Expression> newFilter } => new AzureAISearchFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + { OldFilter: VectorSearchFilter legacyFilter } => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(legacyFilter, this._model), + { Filter: Expression> newFilter } => new AzureAISearchFilterTranslator().Translate(newFilter, this._model), _ => null }; #pragma warning restore CS0618 @@ -358,9 +330,8 @@ public virtual Task> VectorizedSearchAsync var searchOptions = new SearchOptions { VectorSearch = new(), - Size = internalOptions.Top, - Skip = internalOptions.Skip, - IncludeTotalCount = internalOptions.IncludeTotalCount, + Size = top, + Skip = options.Skip, }; if (filter is not null) @@ -371,41 +342,112 @@ public virtual Task> VectorizedSearchAsync searchOptions.VectorSearch.Queries.AddRange(vectorQueries); // Filter out vector fields if requested. - if (!internalOptions.IncludeVectors) + if (!options.IncludeVectors) + { + searchOptions.Select.Add(this._model.KeyProperty.StorageName); + + foreach (var dataProperty in this._model.DataProperties) + { + searchOptions.Select.Add(dataProperty.StorageName); + } + } + + return this.SearchAndMapToDataModelAsync(null, searchOptions, options.IncludeVectors, cancellationToken); + } + + /// + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); + + /// + public IAsyncEnumerable GetAsync(Expression> filter, int top, + GetFilteredRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); + + options ??= new(); + + SearchOptions searchOptions = new() + { + VectorSearch = new(), + Size = top, + Skip = options.Skip, + Filter = new AzureAISearchFilterTranslator().Translate(filter, this._model), + }; + + // Filter out vector fields if requested. + if (!options.IncludeVectors) + { + searchOptions.Select.Add(this._model.KeyProperty.StorageName); + + foreach (var dataProperty in this._model.DataProperties) + { + searchOptions.Select.Add(dataProperty.StorageName); + } + } + + foreach (var pair in options.OrderBy.Values) { - searchOptions.Select.Add(this._propertyReader.KeyPropertyJsonName); - searchOptions.Select.AddRange(this._propertyReader.DataPropertyJsonNames); + VectorStoreRecordPropertyModel property = this._model.GetDataOrKeyProperty(pair.PropertySelector); + string name = property.StorageName; + // From https://learn.microsoft.com/dotnet/api/azure.search.documents.searchoptions.orderby: + // "Each expression can be followed by asc to indicate ascending, or desc to indicate descending". + // "The default is ascending order." + if (!pair.Ascending) + { + name += " desc"; + } + + searchOptions.OrderBy.Add(name); } - return this.SearchAndMapToDataModelAsync(null, searchOptions, internalOptions.IncludeVectors, cancellationToken); + return this.SearchAndMapToDataModelAsync(null, searchOptions, options.IncludeVectors, cancellationToken) + .SelectAsync(result => result.Record, cancellationToken); } /// - public virtual Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + CancellationToken cancellationToken = default) + where TInput : notnull { + var searchText = value switch + { + string s => s, + null => throw new ArgumentNullException(nameof(value)), + _ => throw new ArgumentException($"The provided search type '{value?.GetType().Name}' is not supported by the Azure AI Search connector, pass a string.") + }; + Verify.NotNull(searchText); + Verify.NotLessThan(top, 1); - if (this._propertyReader.FirstVectorPropertyName is null) + if (this._model.VectorProperties.Count == 0) { throw new InvalidOperationException("The collection does not have any vector fields, so vector search is not possible."); } // Resolve options. - var internalOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(internalOptions); - var vectorPropertyName = this._propertyReader.GetJsonPropertyName(vectorProperty!.DataModelPropertyName); + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); // Configure search settings. - var vectorQueries = new List(); - vectorQueries.Add(new VectorizableTextQuery(searchText) { KNearestNeighborsCount = internalOptions.Top, Fields = { vectorPropertyName } }); + var vectorQueries = new List + { + new VectorizableTextQuery(searchText) { KNearestNeighborsCount = top, Fields = { vectorProperty.StorageName } } + }; #pragma warning disable CS0618 // VectorSearchFilter is obsolete // Build filter object. - var filter = internalOptions switch + var filter = options switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(legacyFilter, this._propertyReader.JsonPropertyNamesMap), - { Filter: Expression> newFilter } => new AzureAISearchFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + { OldFilter: VectorSearchFilter legacyFilter } => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(legacyFilter, this._model), + { Filter: Expression> newFilter } => new AzureAISearchFilterTranslator().Translate(newFilter, this._model), _ => null }; #pragma warning restore CS0618 @@ -414,9 +456,8 @@ public virtual Task> VectorizableTextSearchAsync(st var searchOptions = new SearchOptions { VectorSearch = new(), - Size = internalOptions.Top, - Skip = internalOptions.Skip, - IncludeTotalCount = internalOptions.IncludeTotalCount, + Size = top, + Skip = options.Skip, }; if (filter is not null) @@ -427,39 +468,49 @@ public virtual Task> VectorizableTextSearchAsync(st searchOptions.VectorSearch.Queries.AddRange(vectorQueries); // Filter out vector fields if requested. - if (!internalOptions.IncludeVectors) + if (!options.IncludeVectors) { - searchOptions.Select.Add(this._propertyReader.KeyPropertyJsonName); - searchOptions.Select.AddRange(this._propertyReader.DataPropertyJsonNames); + searchOptions.Select.Add(this._model.KeyProperty.StorageName); + + foreach (var dataProperty in this._model.DataProperties) + { + searchOptions.Select.Add(dataProperty.StorageName); + } } - return this.SearchAndMapToDataModelAsync(null, searchOptions, internalOptions.IncludeVectors, cancellationToken); + return this.SearchAndMapToDataModelAsync(null, searchOptions, options.IncludeVectors, cancellationToken); } /// - public Task> HybridSearchAsync(TVector vector, ICollection keywords, HybridSearchOptions? options = null, CancellationToken cancellationToken = default) + [Obsolete("Use SearchAsync")] + public IAsyncEnumerable> VectorizableTextSearchAsync(string searchText, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + => this.SearchAsync(searchText, top, options, cancellationToken); + + /// + public IAsyncEnumerable> HybridSearchAsync(TVector vector, ICollection keywords, int top, HybridSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(keywords); var floatVector = VerifyVectorParam(vector); + Verify.NotLessThan(top, 1); // Resolve options. - var internalOptions = options ?? s_defaultKeywordVectorizedHybridSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(new() { VectorProperty = internalOptions.VectorProperty }); - var vectorPropertyName = this._propertyReader.GetJsonPropertyName(vectorProperty.DataModelPropertyName); - var textDataProperty = this._propertyReader.GetFullTextDataPropertyOrSingle(internalOptions.AdditionalProperty); - var textDataPropertyName = this._propertyReader.GetJsonPropertyName(textDataProperty.DataModelPropertyName); + options ??= s_defaultKeywordVectorizedHybridSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(new() { VectorProperty = options.VectorProperty }); + var textDataProperty = this._model.GetFullTextDataPropertyOrSingle(options.AdditionalProperty); // Configure search settings. - var vectorQueries = new List(); - vectorQueries.Add(new VectorizedQuery(floatVector) { KNearestNeighborsCount = internalOptions.Top, Fields = { vectorPropertyName } }); + var vectorQueries = new List + { + new VectorizedQuery(floatVector) { KNearestNeighborsCount = top, Fields = { vectorProperty.StorageName } } + }; #pragma warning disable CS0618 // VectorSearchFilter is obsolete // Build filter object. - var filter = internalOptions switch + var filter = options switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(legacyFilter, this._propertyReader.JsonPropertyNamesMap), - { Filter: Expression> newFilter } => new AzureAISearchFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + { OldFilter: VectorSearchFilter legacyFilter } => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(legacyFilter, this._model), + { Filter: Expression> newFilter } => new AzureAISearchFilterTranslator().Translate(newFilter, this._model), _ => null }; #pragma warning restore CS0618 @@ -468,24 +519,42 @@ public Task> HybridSearchAsync(TVector vec var searchOptions = new SearchOptions { VectorSearch = new(), - Size = internalOptions.Top, - Skip = internalOptions.Skip, + Size = top, + Skip = options.Skip, Filter = filter, - IncludeTotalCount = internalOptions.IncludeTotalCount, + IncludeTotalCount = options.IncludeTotalCount, }; searchOptions.VectorSearch.Queries.AddRange(vectorQueries); - searchOptions.SearchFields.Add(textDataPropertyName); + searchOptions.SearchFields.Add(textDataProperty.StorageName); // Filter out vector fields if requested. - if (!internalOptions.IncludeVectors) + if (!options.IncludeVectors) { - searchOptions.Select.Add(this._propertyReader.KeyPropertyJsonName); - searchOptions.Select.AddRange(this._propertyReader.DataPropertyJsonNames); + searchOptions.Select.Add(this._model.KeyProperty.StorageName); + + foreach (var dataProperty in this._model.DataProperties) + { + searchOptions.Select.Add(dataProperty.StorageName); + } } var keywordsCombined = string.Join(" ", keywords); - return this.SearchAndMapToDataModelAsync(keywordsCombined, searchOptions, internalOptions.IncludeVectors, cancellationToken); + return this.SearchAndMapToDataModelAsync(keywordsCombined, searchOptions, options.IncludeVectors, cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType == typeof(SearchIndexClient) ? this._searchIndexClient : + serviceType == typeof(SearchClient) ? this._searchClient : + serviceType.IsInstanceOfType(this) ? this : + null; } /// @@ -497,19 +566,23 @@ public Task> HybridSearchAsync(TVector vec /// The to monitor for cancellation requests. The default is . /// The retrieved document, mapped to the consumer data model. private async Task GetDocumentAndMapToDataModelAsync( - string key, + TKey key, bool includeVectors, GetDocumentOptions innerOptions, CancellationToken cancellationToken) { const string OperationName = "GetDocument"; + var stringKey = this.GetStringKey(key); + // Use the user provided mapper. - if (this._mapper is not null) + if (this._dynamicMapper is not null) { + Debug.Assert(typeof(TRecord) == typeof(Dictionary)); + var jsonObject = await this.RunOperationAsync( OperationName, - () => GetDocumentWithNotFoundHandlingAsync(this._searchClient, key, innerOptions, cancellationToken)).ConfigureAwait(false); + () => GetDocumentWithNotFoundHandlingAsync(this._searchClient, stringKey, innerOptions, cancellationToken)).ConfigureAwait(false); if (jsonObject is null) { @@ -517,16 +590,17 @@ public Task> HybridSearchAsync(TVector vec } return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + AzureAISearchConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, OperationName, - () => this._mapper!.MapFromStorageToDataModel(jsonObject, new() { IncludeVectors = includeVectors })); + () => (TRecord)(object)this._dynamicMapper!.MapFromStorageToDataModel(jsonObject, new() { IncludeVectors = includeVectors })); } // Use the built in Azure AI Search mapper. return await this.RunOperationAsync( OperationName, - () => GetDocumentWithNotFoundHandlingAsync(this._searchClient, key, innerOptions, cancellationToken)).ConfigureAwait(false); + () => GetDocumentWithNotFoundHandlingAsync(this._searchClient, stringKey, innerOptions, cancellationToken)).ConfigureAwait(false); } /// @@ -537,29 +611,37 @@ public Task> HybridSearchAsync(TVector vec /// A value indicating whether to include vectors in the result or not. /// The to monitor for cancellation requests. The default is . /// The mapped search results. - private async Task> SearchAndMapToDataModelAsync( + private async IAsyncEnumerable> SearchAndMapToDataModelAsync( string? searchText, SearchOptions searchOptions, bool includeVectors, - CancellationToken cancellationToken) + [EnumeratorCancellation] CancellationToken cancellationToken) { const string OperationName = "Search"; // Execute search and map using the user provided mapper. - if (this._options.JsonObjectCustomMapper is not null) + if (this._dynamicMapper is not null) { + Debug.Assert(typeof(TRecord) == typeof(Dictionary)); + var jsonObjectResults = await this.RunOperationAsync( OperationName, () => this._searchClient.SearchAsync(searchText, searchOptions, cancellationToken)).ConfigureAwait(false); - var mappedJsonObjectResults = this.MapSearchResultsAsync(jsonObjectResults.Value.GetResultsAsync(), OperationName, includeVectors); - return new VectorSearchResults(mappedJsonObjectResults) { TotalCount = jsonObjectResults.Value.TotalCount }; + await foreach (var result in this.MapSearchResultsAsync(jsonObjectResults.Value.GetResultsAsync(), OperationName, includeVectors).ConfigureAwait(false)) + { + yield return result; + } + + yield break; } // Execute search and map using the built in Azure AI Search mapper. Response> results = await this.RunOperationAsync(OperationName, () => this._searchClient.SearchAsync(searchText, searchOptions, cancellationToken)).ConfigureAwait(false); - var mappedResults = this.MapSearchResultsAsync(results.Value.GetResultsAsync()); - return new VectorSearchResults(mappedResults) { TotalCount = results.Value.TotalCount }; + await foreach (var result in this.MapSearchResultsAsync(results.Value.GetResultsAsync()).ConfigureAwait(false)) + { + yield return result; + } } /// @@ -577,13 +659,16 @@ private Task> MapToStorageModelAndUploadDocumentA const string OperationName = "UploadDocuments"; // Use the user provided mapper. - if (this._mapper is not null) + if (this._dynamicMapper is not null) { + Debug.Assert(typeof(TRecord) == typeof(Dictionary)); + var jsonObjects = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + AzureAISearchConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, OperationName, - () => records.Select(this._mapper!.MapFromDataToStorageModel)); + () => records.Select(r => this._dynamicMapper!.MapFromDataToStorageModel((Dictionary)(object)r))); return this.RunOperationAsync( OperationName, @@ -597,7 +682,7 @@ private Task> MapToStorageModelAndUploadDocumentA } /// - /// Map the search results from to objects using the configured mapper type. + /// Map the search results from to objects using the configured mapper type. /// /// The search results to map. /// The name of the current operation for telemetry purposes. @@ -608,16 +693,17 @@ private async IAsyncEnumerable> MapSearchResultsAsyn await foreach (var result in results.ConfigureAwait(false)) { var document = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + AzureAISearchConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, operationName, - () => this._options.JsonObjectCustomMapper!.MapFromStorageToDataModel(result.Document, new() { IncludeVectors = includeVectors })); + () => (TRecord)(object)this._dynamicMapper!.MapFromStorageToDataModel(result.Document, new() { IncludeVectors = includeVectors })); yield return new VectorSearchResult(document, result.Score); } } /// - /// Map the search results from to objects. + /// Map the search results from to objects. /// /// The search results to map. /// The mapped results. @@ -639,8 +725,12 @@ private GetDocumentOptions ConvertGetDocumentOptions(GetRecordOptions? options) var innerOptions = new GetDocumentOptions(); if (options?.IncludeVectors is not true) { - innerOptions.SelectedFields.AddRange(this._propertyReader.KeyPropertyJsonNames); - innerOptions.SelectedFields.AddRange(this._propertyReader.DataPropertyJsonNames); + innerOptions.SelectedFields.Add(this._model.KeyProperty.StorageName); + + foreach (var dataProperty in this._model.DataProperties) + { + innerOptions.SelectedFields.Add(dataProperty.StorageName); + } } return innerOptions; @@ -688,7 +778,8 @@ private async Task RunOperationAsync(string operationName, Func> o { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = AzureAISearchConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, CollectionName = this._collectionName, OperationName = operationName }; @@ -697,7 +788,8 @@ private async Task RunOperationAsync(string operationName, Func> o { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = AzureAISearchConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, CollectionName = this._collectionName, OperationName = operationName }; @@ -715,4 +807,15 @@ private static ReadOnlyMemory VerifyVectorParam(TVector vector) return floatVector; } + + private string GetStringKey(TKey key) + { + Verify.NotNull(key); + + var stringKey = key as string ?? throw new UnreachableException("string key should have been validated during model building"); + + Verify.NotNullOrWhiteSpace(stringKey, nameof(key)); + + return stringKey; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollectionOptions.cs index 5d2ec9c9bb23..0b405f4710bd 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollectionOptions.cs @@ -1,14 +1,16 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Text.Json; using System.Text.Json.Nodes; using Azure.Search.Documents.Indexes; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; /// -/// Options when creating a . +/// Options when creating a . /// public sealed class AzureAISearchVectorStoreRecordCollectionOptions { @@ -18,6 +20,7 @@ public sealed class AzureAISearchVectorStoreRecordCollectionOptions /// /// If not set, the default mapper that is provided by the Azure AI Search client SDK will be used. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper? JsonObjectCustomMapper { get; init; } = null; /// @@ -36,4 +39,9 @@ public sealed class AzureAISearchVectorStoreRecordCollectionOptions /// to provide the same set of both here and when constructing the . /// public JsonSerializerOptions? JsonSerializerOptions { get; init; } = null; + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/Connectors.Memory.AzureAISearch.csproj b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/Connectors.Memory.AzureAISearch.csproj index 3a53cd12212f..f4036fe33e1e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/Connectors.Memory.AzureAISearch.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/Connectors.Memory.AzureAISearch.csproj @@ -3,13 +3,14 @@ Microsoft.SemanticKernel.Connectors.AzureAISearch Microsoft.SemanticKernel.Connectors.AzureAISearch - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 preview + @@ -19,15 +20,21 @@ + + + + + + + - diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/IAzureAISearchVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/IAzureAISearchVectorStoreRecordCollectionFactory.cs index 2c9def54ae18..14f094d659d9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/IAzureAISearchVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/IAzureAISearchVectorStoreRecordCollectionFactory.cs @@ -22,5 +22,6 @@ public interface IAzureAISearchVectorStoreRecordCollectionFactory /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. /// The new instance of . IVectorStoreRecordCollection CreateVectorStoreRecordCollection(SearchIndexClient searchIndexClient, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) - where TKey : notnull; + where TKey : notnull + where TRecord : notnull; } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConfig.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConfig.cs index 7cb62b601075..289eb2ff4240 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConfig.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConfig.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using Microsoft.SemanticKernel.Http; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; @@ -12,7 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; /// /// Initialize the with default values. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being phased out, use Microsoft.Extensions.VectorData and AzureMongoDBMongoDBVectorStore")] public class AzureCosmosDBMongoDBConfig(int dimensions) { private const string DefaultIndexName = "default_index"; diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConstants.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConstants.cs new file mode 100644 index 000000000000..b6a003c3e548 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConstants.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; + +internal static class AzureCosmosDBMongoDBConstants +{ + public const string VectorStoreSystemName = "azure.cosmosdbmongodb"; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs index 6c0b4e44e23b..94fbb845fc20 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs @@ -7,8 +7,8 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.ConnectorSupport.Filter; using MongoDB.Bson; namespace Microsoft.SemanticKernel.Connectors.MongoDB; @@ -17,17 +17,20 @@ namespace Microsoft.SemanticKernel.Connectors.MongoDB; // Information specific to vector search pre-filter: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter internal class AzureCosmosDBMongoDBFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; + private VectorStoreRecordModel _model = null!; private ParameterExpression _recordParameter = null!; - internal BsonDocument Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + internal BsonDocument Translate(LambdaExpression lambdaExpression, VectorStoreRecordModel model) { - this._storagePropertyNames = storagePropertyNames; + this._model = model; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - return this.Translate(lambdaExpression.Body); + var preprocessor = new FilterTranslationPreprocessor { InlineCapturedVariables = true }; + var preprocessedExpression = preprocessor.Visit(lambdaExpression.Body); + + return this.Translate(preprocessedExpression); } private BsonDocument Translate(Expression? node) @@ -46,9 +49,9 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual UnaryExpression { NodeType: ExpressionType.Not } not => this.TranslateNot(not), - // MemberExpression is generally handled within e.g. TranslateEqualityComparison; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) - MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _) - => this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))), + // Special handling for bool constant as the filter expression (r => r.Bool) + Expression when node.Type == typeof(bool) && this.TryBindProperty(node, out var property) + => this.GenerateEqualityComparison(property, value: true, ExpressionType.Equal), MethodCallExpression methodCall => this.TranslateMethodCall(methodCall), @@ -56,36 +59,37 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual }; private BsonDocument TranslateEqualityComparison(BinaryExpression binary) + => this.TryBindProperty(binary.Left, out var property) && binary.Right is ConstantExpression { Value: var rightConstant } + ? this.GenerateEqualityComparison(property, rightConstant, binary.NodeType) + : this.TryBindProperty(binary.Right, out property) && binary.Left is ConstantExpression { Value: var leftConstant } + ? this.GenerateEqualityComparison(property, leftConstant, binary.NodeType) + : throw new NotSupportedException("Invalid equality/comparison"); + + private BsonDocument GenerateEqualityComparison(VectorStoreRecordPropertyModel property, object? value, ExpressionType nodeType) { - if ((this.TryTranslateFieldAccess(binary.Left, out var storagePropertyName) && TryGetConstant(binary.Right, out var value)) - || (this.TryTranslateFieldAccess(binary.Right, out storagePropertyName) && TryGetConstant(binary.Left, out value))) + if (value is null) { - if (value is null) - { - throw new NotSupportedException("MongogDB does not support null checks in vector search pre-filters"); - } - - // Short form of equality (instead of $eq) - if (binary.NodeType is ExpressionType.Equal) - { - return new BsonDocument { [storagePropertyName] = BsonValue.Create(value) }; - } + throw new NotSupportedException("MongogDB does not support null checks in vector search pre-filters"); + } - var filterOperator = binary.NodeType switch - { - ExpressionType.NotEqual => "$ne", - ExpressionType.GreaterThan => "$gt", - ExpressionType.GreaterThanOrEqual => "$gte", - ExpressionType.LessThan => "$lt", - ExpressionType.LessThanOrEqual => "$lte", + // Short form of equality (instead of $eq) + if (nodeType is ExpressionType.Equal) + { + return new BsonDocument { [property.StorageName] = BsonValue.Create(value) }; + } - _ => throw new UnreachableException() - }; + var filterOperator = nodeType switch + { + ExpressionType.NotEqual => "$ne", + ExpressionType.GreaterThan => "$gt", + ExpressionType.GreaterThanOrEqual => "$gte", + ExpressionType.LessThan => "$lt", + ExpressionType.LessThanOrEqual => "$lte", - return new BsonDocument { [storagePropertyName] = new BsonDocument { [filterOperator] = BsonValue.Create(value) } }; - } + _ => throw new UnreachableException() + }; - throw new NotSupportedException("Invalid equality/comparison"); + return new BsonDocument { [property.StorageName] = new BsonDocument { [filterOperator] = BsonValue.Create(value) } }; } private BsonDocument TranslateAndOr(BinaryExpression andOr) @@ -130,9 +134,9 @@ private BsonDocument TranslateNot(UnaryExpression not) binary.Left, binary.Right)); - // Not over bool field (Filter => r => !r.Bool) - case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): - return this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(false))); + // Not over bool field (r => !r.Bool) + case var negated when negated.Type == typeof(bool) && this.TryBindProperty(negated, out var property): + return this.GenerateEqualityComparison(property, false, ExpressionType.Equal); } var operand = this.Translate(not.Operand); @@ -174,7 +178,7 @@ private BsonDocument TranslateContains(Expression source, Expression item) switch (source) { // Contains over array column (r => r.Strings.Contains("foo")) - case var _ when this.TryTranslateFieldAccess(source, out _): + case var _ when this.TryBindProperty(source, out _): throw new NotSupportedException("MongoDB does not support Contains within array fields ($elemMatch) in vector search pre-filters"); // Contains over inline enumerable @@ -183,7 +187,7 @@ private BsonDocument TranslateContains(Expression source, Expression item) for (var i = 0; i < newArray.Expressions.Count; i++) { - if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + if (newArray.Expressions[i] is not ConstantExpression { Value: var elementValue }) { throw new NotSupportedException("Invalid element in array"); } @@ -194,8 +198,7 @@ private BsonDocument TranslateContains(Expression source, Expression item) return ProcessInlineEnumerable(elements, item); // Contains over captured enumerable (we inline) - case var _ when TryGetConstant(source, out var constantEnumerable) - && constantEnumerable is IEnumerable enumerable and not string: + case ConstantExpression { Value: IEnumerable enumerable and not string }: return ProcessInlineEnumerable(enumerable, item); default: @@ -204,14 +207,14 @@ private BsonDocument TranslateContains(Expression source, Expression item) BsonDocument ProcessInlineEnumerable(IEnumerable elements, Expression item) { - if (!this.TryTranslateFieldAccess(item, out var storagePropertyName)) + if (!this.TryBindProperty(item, out var property)) { throw new NotSupportedException("Unsupported item type in Contains"); } return new BsonDocument { - [storagePropertyName] = new BsonDocument + [property.StorageName] = new BsonDocument { ["$in"] = new BsonArray(from object? element in elements select BsonValue.Create(element)) } @@ -219,40 +222,49 @@ BsonDocument ProcessInlineEnumerable(IEnumerable elements, Expression item) } } - private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out VectorStoreRecordPropertyModel? property) { - if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) - { - if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) - { - throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); - } + Type? convertedClrType = null; - return true; + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unary) + { + expression = unary.Operand; + convertedClrType = unary.Type; } - storagePropertyName = null; - return false; - } + var modelName = expression switch + { + // Regular member access for strongly-typed POCO binding (e.g. r => r.SomeInt == 8) + MemberExpression memberExpression when memberExpression.Expression == this._recordParameter + => memberExpression.Member.Name, - private static bool TryGetConstant(Expression expression, out object? constantValue) - { - switch (expression) + // Dictionary lookup for weakly-typed dynamic binding (e.g. r => r["SomeInt"] == 8) + MethodCallExpression + { + Method: { Name: "get_Item", DeclaringType: var declaringType }, + Arguments: [ConstantExpression { Value: string keyName }] + } methodCall when methodCall.Object == this._recordParameter && declaringType == typeof(Dictionary) + => keyName, + + _ => null + }; + + if (modelName is null) { - case ConstantExpression { Value: var v }: - constantValue = v; - return true; + property = null; + return false; + } - // This identifies compiler-generated closure types which contain captured variables. - case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): - constantValue = fieldInfo.GetValue(constant.Value); - return true; + if (!this._model.PropertyMap.TryGetValue(modelName, out property)) + { + throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name."); + } - default: - constantValue = null; - return false; + if (convertedClrType is not null && convertedClrType != property.Type) + { + throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convertedClrType.Name}', but its configured type is '{property.Type.Name}'."); } + + return true; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBKernelBuilderExtensions.cs index af73629568ec..9c3f0ea120e5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBKernelBuilderExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; using MongoDB.Driver; @@ -9,6 +10,7 @@ namespace Microsoft.SemanticKernel; /// /// Extension methods to register Azure CosmosDB MongoDB instances on the . /// +[Obsolete("The IKernelBuilder extensions are being obsoleted, call the appropriate function on the Services property of your IKernelBuilder")] public static class AzureCosmosDBMongoDBKernelBuilderExtensions { /// @@ -50,7 +52,7 @@ public static IKernelBuilder AddAzureCosmosDBMongoDBVectorStore( } /// - /// Register an Azure CosmosDB MongoDB and with the specified service ID + /// Register an Azure CosmosDB MongoDB and with the specified service ID /// and where the Azure CosmosDB MongoDB is retrieved from the dependency injection container. /// /// The type of the record. @@ -64,13 +66,14 @@ public static IKernelBuilder AddAzureCosmosDBMongoDBVectorStoreRecordCollection< string collectionName, AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddAzureCosmosDBMongoDBVectorStoreRecordCollection(collectionName, options, serviceId); return builder; } /// - /// Register an Azure CosmosDB MongoDB and with the specified service ID + /// Register an Azure CosmosDB MongoDB and with the specified service ID /// and where the Azure CosmosDB MongoDB is constructed using the provided and . /// /// The type of the record. @@ -88,6 +91,7 @@ public static IKernelBuilder AddAzureCosmosDBMongoDBVectorStoreRecordCollection< string databaseName, AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddAzureCosmosDBMongoDBVectorStoreRecordCollection(collectionName, connectionString, databaseName, options, serviceId); return builder; diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs index d9e181d95e5e..4bd5ac65a802 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Diagnostics.CodeAnalysis; using System.Linq; using Microsoft.SemanticKernel.Memory; using MongoDB.Bson; @@ -10,10 +9,12 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// A MongoDB memory record. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being phased out, use Microsoft.Extensions.VectorData and AzureMongoDBMongoDBVectorStore")] internal sealed class AzureCosmosDBMongoDBMemoryRecord { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecordMetadata.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecordMetadata.cs index a614ad0d8c87..095d8f62f17a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecordMetadata.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecordMetadata.cs @@ -1,16 +1,18 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using Microsoft.SemanticKernel.Memory; using MongoDB.Bson.Serialization.Attributes; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// A MongoDB memory record metadata. /// #pragma warning disable CA1815 // Override equals and operator equals on value types -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being phased out, use Microsoft.Extensions.VectorData and AzureMongoDBMongoDBVectorStore")] internal struct AzureCosmosDBMongoDBMemoryRecordMetadata #pragma warning restore CA1815 // Override equals and operator equals on value types { diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs index e91048e780d6..931c7fe3d792 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; using System.Runtime.CompilerServices; @@ -14,11 +13,13 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// An implementation of backed by a Azure CosmosDB Mongo vCore database. /// Get more details about Azure Cosmos Mongo vCore vector search https://learn.microsoft.com/en-us/azure/cosmos-db/mongodb/vcore/vector-search /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being phased out, use Microsoft.Extensions.VectorData and AzureMongoDBMongoDBVectorStore")] public class AzureCosmosDBMongoDBMemoryStore : IMemoryStore, IDisposable { private readonly MongoClient _mongoClient; diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBServiceCollectionExtensions.cs index f4f77082a271..58a55be23741 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBServiceCollectionExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; @@ -33,7 +34,10 @@ public static IServiceCollection AddAzureCosmosDBMongoDBVectorStore( (sp, obj) => { var database = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; return new AzureCosmosDBMongoDBVectorStore(database, options); }); @@ -70,7 +74,10 @@ public static IServiceCollection AddAzureCosmosDBMongoDBVectorStore( var mongoClient = new MongoClient(settings); var database = mongoClient.GetDatabase(databaseName); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; return new AzureCosmosDBMongoDBVectorStore(database, options); }); @@ -79,7 +86,7 @@ public static IServiceCollection AddAzureCosmosDBMongoDBVectorStore( } /// - /// Register an Azure CosmosDB MongoDB and with the specified service ID + /// Register an Azure CosmosDB MongoDB and with the specified service ID /// and where the Azure CosmosDB MongoDB is retrieved from the dependency injection container. /// /// The type of the record. @@ -93,15 +100,19 @@ public static IServiceCollection AddAzureCosmosDBMongoDBVectorStoreRecordCollect string collectionName, AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedTransient>( serviceId, (sp, obj) => { var database = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new AzureCosmosDBMongoDBVectorStoreRecordCollection(database, collectionName, selectedOptions); + return new AzureCosmosDBMongoDBVectorStoreRecordCollection(database, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -110,7 +121,7 @@ public static IServiceCollection AddAzureCosmosDBMongoDBVectorStoreRecordCollect } /// - /// Register an Azure CosmosDB MongoDB and with the specified service ID + /// Register an Azure CosmosDB MongoDB and with the specified service ID /// and where the Azure CosmosDB MongoDB is constructed using the provided and . /// /// The type of the record. @@ -128,6 +139,7 @@ public static IServiceCollection AddAzureCosmosDBMongoDBVectorStoreRecordCollect string databaseName, AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedSingleton>( serviceId, @@ -139,9 +151,12 @@ public static IServiceCollection AddAzureCosmosDBMongoDBVectorStoreRecordCollect var mongoClient = new MongoClient(settings); var database = mongoClient.GetDatabase(databaseName); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new AzureCosmosDBMongoDBVectorStoreRecordCollection(database, collectionName, selectedOptions); + return new AzureCosmosDBMongoDBVectorStoreRecordCollection(database, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -150,14 +165,14 @@ public static IServiceCollection AddAzureCosmosDBMongoDBVectorStoreRecordCollect } /// - /// Also register the with the given as a . + /// Also register the with the given as a . /// /// The type of the data model that the collection should contain. /// The service collection to register on. /// The service id that the registrations should use. - private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) where TRecord : notnull { - services.AddKeyedTransient>( + services.AddKeyedTransient>( serviceId, (sp, obj) => { diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStore.cs index 76dc9e8500a4..50d2295284ba 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStore.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.VectorData; using MongoDB.Driver; @@ -15,14 +16,20 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; /// /// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. /// -public class AzureCosmosDBMongoDBVectorStore : IVectorStore +public sealed class AzureCosmosDBMongoDBVectorStore : IVectorStore { + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; + /// that can be used to manage the collections in Azure CosmosDB MongoDB. private readonly IMongoDatabase _mongoDatabase; /// Optional configuration options for this class. private readonly AzureCosmosDBMongoDBVectorStoreOptions _options; + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreRecordDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreRecordKeyProperty("Key", typeof(string))] }; + /// /// Initializes a new instance of the class. /// @@ -34,11 +41,18 @@ public AzureCosmosDBMongoDBVectorStore(IMongoDatabase mongoDatabase, AzureCosmos this._mongoDatabase = mongoDatabase; this._options = options ?? new(); + + this._metadata = new() + { + VectorStoreSystemName = AzureCosmosDBMongoDBConstants.VectorStoreSystemName, + VectorStoreName = mongoDatabase.DatabaseNamespace?.DatabaseName + }; } /// - public virtual IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + where TRecord : notnull { #pragma warning disable CS0618 // IAzureCosmosDBMongoDBVectorStoreRecordCollectionFactory is obsolete if (this._options.VectorStoreCollectionFactory is not null) @@ -47,21 +61,20 @@ public virtual IVectorStoreRecordCollection GetCollection( + var recordCollection = new AzureCosmosDBMongoDBVectorStoreRecordCollection( this._mongoDatabase, name, - new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; + new() + { + VectorStoreRecordDefinition = vectorStoreRecordDefinition, + EmbeddingGenerator = this._options.EmbeddingGenerator + }) as IVectorStoreRecordCollection; return recordCollection!; } /// - public virtual async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { using var cursor = await this._mongoDatabase .ListCollectionNamesAsync(cancellationToken: cancellationToken) @@ -75,4 +88,31 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat } } } + + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType == typeof(IMongoDatabase) ? this._mongoDatabase : + serviceType.IsInstanceOfType(this) ? this : + null; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionCreateMapping.cs index 8bd883163870..0263af542c2e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionCreateMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionCreateMapping.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using MongoDB.Bson; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; @@ -16,13 +17,11 @@ internal static class AzureCosmosDBMongoDBVectorStoreCollectionCreateMapping /// Returns an array of indexes to create for vector properties. /// /// Collection of vector properties for index creation. - /// A dictionary that maps from a property name to the storage name. /// Collection of unique existing indexes to avoid creating duplicates. /// Number of clusters that the inverted file (IVF) index uses to group the vector data. /// The size of the dynamic candidate list for constructing the graph. public static BsonArray GetVectorIndexes( - IReadOnlyList vectorProperties, - Dictionary storagePropertyNames, + IReadOnlyList vectorProperties, HashSet uniqueIndexes, int numLists, int efConstruction) @@ -32,9 +31,10 @@ public static BsonArray GetVectorIndexes( // Create separate index for each vector property foreach (var property in vectorProperties) { + var storageName = property.StorageName; + // Use index name same as vector property name with underscore - var vectorPropertyName = storagePropertyNames[property.DataModelPropertyName]; - var indexName = $"{vectorPropertyName}_"; + var indexName = $"{storageName}_"; // If index already exists, proceed to the next vector property if (uniqueIndexes.Contains(indexName)) @@ -45,9 +45,9 @@ public static BsonArray GetVectorIndexes( // Otherwise, create a new index var searchOptions = new BsonDocument { - { "kind", GetIndexKind(property.IndexKind, vectorPropertyName) }, + { "kind", GetIndexKind(property.IndexKind, storageName) }, { "numLists", numLists }, - { "similarity", GetDistanceFunction(property.DistanceFunction, vectorPropertyName) }, + { "similarity", GetDistanceFunction(property.DistanceFunction, storageName) }, { "dimensions", property.Dimensions }, { "efConstruction", efConstruction } }; @@ -55,7 +55,7 @@ public static BsonArray GetVectorIndexes( var indexDocument = new BsonDocument { ["name"] = indexName, - ["key"] = new BsonDocument { [vectorPropertyName] = "cosmosSearch" }, + ["key"] = new BsonDocument { [storageName] = "cosmosSearch" }, ["cosmosSearchOptions"] = searchOptions }; @@ -69,11 +69,9 @@ public static BsonArray GetVectorIndexes( /// Returns an array of indexes to create for filterable data properties. /// /// Collection of data properties for index creation. - /// A dictionary that maps from a property name to the storage name. /// Collection of unique existing indexes to avoid creating duplicates. public static BsonArray GetFilterableDataIndexes( - IReadOnlyList dataProperties, - Dictionary storagePropertyNames, + IReadOnlyList dataProperties, HashSet uniqueIndexes) { var indexArray = new BsonArray(); @@ -81,11 +79,10 @@ public static BsonArray GetFilterableDataIndexes( // Create separate index for each data property foreach (var property in dataProperties) { - if (property.IsFilterable) + if (property.IsIndexed) { // Use index name same as data property name with underscore - var dataPropertyName = storagePropertyNames[property.DataModelPropertyName]; - var indexName = $"{dataPropertyName}_"; + var indexName = $"{property.StorageName}_"; // If index already exists, proceed to the next data property if (uniqueIndexes.Contains(indexName)) @@ -97,7 +94,7 @@ public static BsonArray GetFilterableDataIndexes( var indexDocument = new BsonDocument { ["name"] = indexName, - ["key"] = new BsonDocument { [dataPropertyName] = 1 } + ["key"] = new BsonDocument { [property.StorageName] = 1 } }; indexArray.Add(indexDocument); @@ -111,30 +108,22 @@ public static BsonArray GetFilterableDataIndexes( /// More information about Azure CosmosDB for MongoDB index kinds here: . /// private static string GetIndexKind(string? indexKind, string vectorPropertyName) - { - var vectorPropertyIndexKind = AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.GetVectorPropertyIndexKind(indexKind); - - return vectorPropertyIndexKind switch + => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.GetVectorPropertyIndexKind(indexKind) switch { IndexKind.Hnsw => "vector-hnsw", IndexKind.IvfFlat => "vector-ivf", _ => throw new InvalidOperationException($"Index kind '{indexKind}' on {nameof(VectorStoreRecordVectorProperty)} '{vectorPropertyName}' is not supported by the Azure CosmosDB for MongoDB VectorStore.") }; - } /// /// More information about Azure CosmosDB for MongoDB distance functions here: . /// private static string GetDistanceFunction(string? distanceFunction, string vectorPropertyName) - { - var vectorPropertyDistanceFunction = AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.GetVectorPropertyDistanceFunction(distanceFunction); - - return vectorPropertyDistanceFunction switch + => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.GetVectorPropertyDistanceFunction(distanceFunction) switch { DistanceFunction.CosineDistance => "COS", DistanceFunction.DotProductSimilarity => "IP", DistanceFunction.EuclideanDistance => "L2", _ => throw new InvalidOperationException($"Distance function '{distanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorPropertyName}' is not supported by the Azure CosmosDB for MongoDB VectorStore.") }; - } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs index 32377244112c..a78ce746f736 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.Linq; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.MongoDB; using MongoDB.Bson; @@ -25,12 +25,10 @@ internal static class AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping /// Build Azure CosmosDB MongoDB filter from the provided . /// /// The to build Azure CosmosDB MongoDB filter from. - /// A dictionary that maps from a property name to the storage name. + /// The model. /// Thrown when the provided filter type is unsupported. /// Thrown when property name specified in filter doesn't exist. - public static BsonDocument? BuildFilter( - VectorSearchFilter? vectorSearchFilter, - Dictionary storagePropertyNames) + public static BsonDocument? BuildFilter(VectorSearchFilter? vectorSearchFilter, VectorStoreRecordModel model) { const string EqualOperator = "$eq"; @@ -63,25 +61,27 @@ internal static class AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping nameof(EqualToFilterClause)])}"); } - if (!storagePropertyNames.TryGetValue(propertyName, out var storagePropertyName)) + if (!model.PropertyMap.TryGetValue(propertyName, out var property)) { throw new InvalidOperationException($"Property name '{propertyName}' provided as part of the filter clause is not a valid property name."); } - if (filter.Contains(storagePropertyName)) + var storageName = property.StorageName; + + if (filter.Contains(storageName)) { - if (filter[storagePropertyName] is BsonDocument document && document.Contains(filterOperator)) + if (filter[storageName] is BsonDocument document && document.Contains(filterOperator)) { throw new NotSupportedException( $"Filter with operator '{filterOperator}' is already added to '{propertyName}' property. " + "Multiple filters of the same type in the same property are not supported."); } - filter[storagePropertyName][filterOperator] = propertyValue; + filter[storageName][filterOperator] = propertyValue; } else { - filter[storagePropertyName] = new BsonDocument() { [filterOperator] = propertyValue }; + filter[storageName] = new BsonDocument() { [filterOperator] = propertyValue }; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreOptions.cs index 8e9b2cccbc6e..45e6363fa3ca 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; @@ -10,7 +11,12 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; public sealed class AzureCosmosDBMongoDBVectorStoreOptions { /// - /// An optional factory to use for constructing instances, if a custom record collection is required. + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } + + /// + /// An optional factory to use for constructing instances, if a custom record collection is required. /// [Obsolete("To control how collections are instantiated, extend your provider's IVectorStore implementation and override GetCollection()")] public IAzureCosmosDBMongoDBVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs index 13e31475447d..225d990fee08 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs @@ -2,13 +2,17 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using System.Reflection; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; using Microsoft.SemanticKernel.Connectors.MongoDB; using MongoDB.Bson; using MongoDB.Bson.Serialization.Attributes; @@ -20,13 +24,16 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; /// /// Service for storing and retrieving vector records, that uses Azure CosmosDB MongoDB as the underlying storage. /// +/// The data type of the record key. Can be either , or for dynamic mapping. /// The data model to use for adding, updating and retrieving data from storage. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class AzureCosmosDBMongoDBVectorStoreRecordCollection : IVectorStoreRecordCollection +public sealed class AzureCosmosDBMongoDBVectorStoreRecordCollection : IVectorStoreRecordCollection + where TKey : notnull + where TRecord : notnull #pragma warning restore CA1711 // Identifiers should not have incorrect suffix { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "AzureCosmosDBMongoDB"; + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; /// Property name to be used for search similarity score value. private const string ScorePropertyName = "similarityScore"; @@ -47,60 +54,58 @@ public class AzureCosmosDBMongoDBVectorStoreRecordCollection : IVectorS private readonly AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions _options; /// Interface for mapping between a storage model, and the consumer record data model. - private readonly IVectorStoreRecordMapper _mapper; + private readonly IMongoDBMapper _mapper; - /// A dictionary that maps from a property name to the storage name that should be used when serializing it for data and vector properties. - private readonly Dictionary _storagePropertyNames; - - /// Collection of vector storage property names. - private readonly List _vectorStoragePropertyNames; - - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; + /// The model for this collection. + private readonly VectorStoreRecordModel _model; /// - public string CollectionName { get; } + public string Name { get; } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// that can be used to manage the collections in Azure CosmosDB MongoDB. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// Optional configuration options for this class. public AzureCosmosDBMongoDBVectorStoreRecordCollection( IMongoDatabase mongoDatabase, - string collectionName, + string name, AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions? options = default) { // Verify. Verify.NotNull(mongoDatabase); - Verify.NotNullOrWhiteSpace(collectionName); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.BsonDocumentCustomMapper is not null, MongoDBConstants.SupportedKeyTypes); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + Verify.NotNullOrWhiteSpace(name); + + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(object)) + { + throw new NotSupportedException("Only string keys are supported (and object for dynamic mapping)"); + } // Assign. this._mongoDatabase = mongoDatabase; - this._mongoCollection = mongoDatabase.GetCollection(collectionName); - this.CollectionName = collectionName; + this._mongoCollection = mongoDatabase.GetCollection(name); + this.Name = name; this._options = options ?? new AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions(); - this._propertyReader = new VectorStoreRecordPropertyReader(typeof(TRecord), this._options.VectorStoreRecordDefinition, new() { RequiresAtLeastOneVector = false, SupportsMultipleKeys = false, SupportsMultipleVectors = true }); - - this._storagePropertyNames = GetStoragePropertyNames(this._propertyReader.Properties, typeof(TRecord)); - - // Use Mongo reserved key property name as storage key property name - this._storagePropertyNames[this._propertyReader.KeyPropertyName] = MongoDBConstants.MongoReservedKeyPropertyName; + this._model = new MongoDBModelBuilder().Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator); + this._mapper = typeof(TRecord) == typeof(Dictionary) + ? (new MongoDBDynamicDataModelMapper(this._model) as IMongoDBMapper)! + : new MongoDBVectorStoreRecordMapper(this._model); - this._vectorStoragePropertyNames = this._propertyReader.VectorProperties.Select(property => this._storagePropertyNames[property.DataModelPropertyName]).ToList(); - - this._mapper = this.InitializeMapper(); + this._collectionMetadata = new() + { + VectorStoreSystemName = AzureCosmosDBMongoDBConstants.VectorStoreSystemName, + VectorStoreName = mongoDatabase.DatabaseNamespace?.DatabaseName, + CollectionName = name + }; } /// - public virtual Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) => this.RunOperationAsync("ListCollectionNames", () => this.InternalCollectionExistsAsync(cancellationToken)); /// - public virtual async Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) { // The IMongoDatabase.CreateCollectionAsync "Creates a new collection if not already available". // To make sure that all the connectors are consistent, we throw when the collection exists. @@ -108,8 +113,9 @@ public virtual async Task CreateCollectionAsync(CancellationToken cancellationTo { throw new VectorStoreOperationException("Collection already exists.") { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = AzureCosmosDBMongoDBConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = "CreateCollection" }; } @@ -118,50 +124,56 @@ public virtual async Task CreateCollectionAsync(CancellationToken cancellationTo } /// - public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { await this.RunOperationAsync("CreateCollection", - () => this._mongoDatabase.CreateCollectionAsync(this.CollectionName, cancellationToken: cancellationToken)).ConfigureAwait(false); + () => this._mongoDatabase.CreateCollectionAsync(this.Name, cancellationToken: cancellationToken)).ConfigureAwait(false); await this.RunOperationAsync("CreateIndexes", - () => this.CreateIndexesAsync(this.CollectionName, cancellationToken: cancellationToken)).ConfigureAwait(false); + () => this.CreateIndexesAsync(this.Name, cancellationToken: cancellationToken)).ConfigureAwait(false); } /// - public virtual async Task DeleteAsync(string key, CancellationToken cancellationToken = default) + public async Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { - Verify.NotNullOrWhiteSpace(key); + var stringKey = this.GetStringKey(key); - await this.RunOperationAsync("DeleteOne", () => this._mongoCollection.DeleteOneAsync(this.GetFilterById(key), cancellationToken)) + await this.RunOperationAsync("DeleteOne", () => this._mongoCollection.DeleteOneAsync(this.GetFilterById(stringKey), cancellationToken)) .ConfigureAwait(false); } /// - public virtual async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public async Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { Verify.NotNull(keys); - await this.RunOperationAsync("DeleteMany", () => this._mongoCollection.DeleteManyAsync(this.GetFilterByIds(keys), cancellationToken)) + var stringKeys = keys is IEnumerable k ? k : keys.Cast(); + + await this.RunOperationAsync("DeleteMany", () => this._mongoCollection.DeleteManyAsync(this.GetFilterByIds(stringKeys), cancellationToken)) .ConfigureAwait(false); } /// - public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = default) - => this.RunOperationAsync("DropCollection", () => this._mongoDatabase.DropCollectionAsync(this.CollectionName, cancellationToken)); + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + => this.RunOperationAsync("DropCollection", () => this._mongoDatabase.DropCollectionAsync(this.Name, cancellationToken)); /// - public virtual async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { - Verify.NotNullOrWhiteSpace(key); - const string OperationName = "Find"; + var stringKey = this.GetStringKey(key); + var includeVectors = options?.IncludeVectors ?? false; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } var record = await this.RunOperationAsync(OperationName, async () => { using var cursor = await this - .FindAsync(this.GetFilterById(key), options, cancellationToken) + .FindAsync(this.GetFilterById(stringKey), options, cancellationToken) .ConfigureAwait(false); return await cursor.SingleOrDefaultAsync(cancellationToken).ConfigureAwait(false); @@ -173,15 +185,16 @@ public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = } return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + AzureCosmosDBMongoDBConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, () => this._mapper.MapFromStorageToDataModel(record, new() { IncludeVectors = includeVectors })); } /// - public virtual async IAsyncEnumerable GetBatchAsync( - IEnumerable keys, + public async IAsyncEnumerable GetAsync( + IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -189,8 +202,15 @@ public virtual async IAsyncEnumerable GetBatchAsync( const string OperationName = "Find"; + if (options?.IncludeVectors == true && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + var stringKeys = keys is IEnumerable k ? k : keys.Cast(); + using var cursor = await this - .FindAsync(this.GetFilterByIds(keys), options, cancellationToken) + .FindAsync(this.GetFilterByIds(stringKeys), options, cancellationToken) .ConfigureAwait(false); while (await cursor.MoveNextAsync(cancellationToken).ConfigureAwait(false)) @@ -200,8 +220,9 @@ public virtual async IAsyncEnumerable GetBatchAsync( if (record is not null) { yield return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + AzureCosmosDBMongoDBConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, () => this._mapper.MapFromStorageToDataModel(record, new())); } @@ -210,55 +231,148 @@ public virtual async IAsyncEnumerable GetBatchAsync( } /// - public virtual Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { Verify.NotNull(record); const string OperationName = "ReplaceOne"; + Embedding?[]? generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; + + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbedding, ReadOnlyMemory>(record, cancellationToken, out var floatTask)) + { + generatedEmbeddings ??= new Embedding?[vectorPropertyCount]; + generatedEmbeddings[i] = await floatTask.ConfigureAwait(false); + } + else if (vectorProperty.TryGenerateEmbedding, ReadOnlyMemory>(record, cancellationToken, out var doubleTask)) + { + generatedEmbeddings ??= new Embedding?[vectorPropertyCount]; + generatedEmbeddings[i] = await doubleTask.ConfigureAwait(false); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + var replaceOptions = new ReplaceOptions { IsUpsert = true }; var storageModel = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + AzureCosmosDBMongoDBConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, - () => this._mapper.MapFromDataToStorageModel(record)); + () => this._mapper.MapFromDataToStorageModel(record, generatedEmbeddings)); var key = storageModel[MongoDBConstants.MongoReservedKeyPropertyName].AsString; - return this.RunOperationAsync(OperationName, async () => + return await this.RunOperationAsync(OperationName, async () => { await this._mongoCollection .ReplaceOneAsync(this.GetFilterById(key), storageModel, replaceOptions, cancellationToken) .ConfigureAwait(false); - return key; - }); + return (TKey)(object)key; + }).ConfigureAwait(false); } /// - public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { Verify.NotNull(records); var tasks = records.Select(record => this.UpsertAsync(record, cancellationToken)); var results = await Task.WhenAll(tasks).ConfigureAwait(false); + return results.Where(r => r is not null).ToList(); + } + + #region Search - foreach (var result in results) + /// + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + MEVD.VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + switch (vectorProperty.EmbeddingGenerator) { - if (result is not null) + case IEmbeddingGenerator> generator: { - yield return result; + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; } + + case IEmbeddingGenerator> generator: + { + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + } + + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + + default: + throw new InvalidOperationException( + MongoDBConstants.SupportedVectorTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); } } /// - public virtual async Task> VectorizedSearchAsync( + public IAsyncEnumerable> SearchEmbeddingAsync( TVector vector, + int top, MEVD.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private async IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + MEVD.VectorSearchOptions options, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TVector : notnull { Verify.NotNull(vector); + Verify.NotLessThan(top, 1); Array vectorArray = vector switch { @@ -271,25 +385,24 @@ public virtual async Task> VectorizedSearchAsync).FullName])}") }; - var searchOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(searchOptions); - var vectorPropertyName = this._storagePropertyNames[vectorProperty.DataModelPropertyName]; + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } #pragma warning disable CS0618 // VectorSearchFilter is obsolete - var filter = searchOptions switch + var filter = options switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter( - legacyFilter, - this._storagePropertyNames), - { Filter: Expression> newFilter } => new AzureCosmosDBMongoDBFilterTranslator().Translate(newFilter, this._storagePropertyNames), + { OldFilter: VectorSearchFilter legacyFilter } => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter(legacyFilter, this._model), + { Filter: Expression> newFilter } => new AzureCosmosDBMongoDBFilterTranslator().Translate(newFilter, this._model), _ => null }; #pragma warning restore CS0618 // Constructing a query to fetch "skip + top" total items // to perform skip logic locally, since skip option is not part of API. - var itemsAmount = searchOptions.Skip + searchOptions.Top; + var itemsAmount = options.Skip + top; var vectorPropertyIndexKind = AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.GetVectorPropertyIndexKind(vectorProperty.IndexKind); @@ -297,17 +410,17 @@ public virtual async Task> VectorizedSearchAsync AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.GetSearchQueryForHnswIndex( vectorArray, - vectorPropertyName, + vectorProperty.StorageName, itemsAmount, this._options.EfSearch, filter), IndexKind.IvfFlat => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.GetSearchQueryForIvfIndex( vectorArray, - vectorPropertyName, + vectorProperty.StorageName, itemsAmount, filter), _ => throw new InvalidOperationException( - $"Index kind '{vectorProperty.IndexKind}' on {nameof(VectorStoreRecordVectorProperty)} '{vectorPropertyName}' is not supported by the Azure CosmosDB for MongoDB VectorStore. " + + $"Index kind '{vectorProperty.IndexKind}' on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.StorageName}' is not supported by the Azure CosmosDB for MongoDB VectorStore. " + $"Supported index kinds are: {string.Join(", ", [IndexKind.Hnsw, IndexKind.IvfFlat])}") }; @@ -321,7 +434,88 @@ public virtual async Task> VectorizedSearchAsync(pipeline, cancellationToken: cancellationToken) .ConfigureAwait(false); - return new VectorSearchResults(this.EnumerateAndMapSearchResultsAsync(cursor, searchOptions, cancellationToken)); + await foreach (var result in this.EnumerateAndMapSearchResultsAsync(cursor, options, cancellationToken).ConfigureAwait(false)) + { + yield return result; + } + } + + /// + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, MEVD.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); + + #endregion Search + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType == typeof(IMongoDatabase) ? this._mongoDatabase : + serviceType == typeof(IMongoCollection) ? this._mongoCollection : + serviceType.IsInstanceOfType(this) ? this : + null; + } + + /// + public async IAsyncEnumerable GetAsync(Expression> filter, int top, + GetFilteredRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); + + options ??= new(); + + // Translate the filter now, so if it fails, we throw immediately. + var translatedFilter = new AzureCosmosDBMongoDBFilterTranslator().Translate(filter, this._model); + + SortDefinition? sortDefinition = null; + if (options.OrderBy.Values.Count > 0) + { + sortDefinition = Builders.Sort.Combine( + options.OrderBy.Values.Select(pair => + { + var storageName = this._model.GetDataOrKeyProperty(pair.PropertySelector).StorageName; + + return pair.Ascending + ? Builders.Sort.Ascending(storageName) + : Builders.Sort.Descending(storageName); + })); + } + + using IAsyncCursor cursor = await this.RunOperationAsync( + "GetAsync", + async () => + { + return await this._mongoCollection.FindAsync(translatedFilter, + new() + { + Limit = top, + Skip = options.Skip, + Sort = sortDefinition + }, + cancellationToken: cancellationToken).ConfigureAwait(false); + }).ConfigureAwait(false); + + while (await cursor.MoveNextAsync(cancellationToken).ConfigureAwait(false)) + { + foreach (var response in cursor.Current) + { + var record = VectorStoreErrorHandler.RunModelConversion( + AzureCosmosDBMongoDBConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + "GetAsync", + () => this._mapper.MapFromStorageToDataModel(response, new() { IncludeVectors = options.IncludeVectors })); + + yield return record; + } + } } #region private @@ -335,15 +529,13 @@ private async Task CreateIndexesAsync(string collectionName, CancellationToken c var indexArray = new BsonArray(); indexArray.AddRange(AzureCosmosDBMongoDBVectorStoreCollectionCreateMapping.GetVectorIndexes( - this._propertyReader.VectorProperties, - this._storagePropertyNames, + this._model.VectorProperties, uniqueIndexes, this._options.NumLists, this._options.EfConstruction)); indexArray.AddRange(AzureCosmosDBMongoDBVectorStoreCollectionCreateMapping.GetFilterableDataIndexes( - this._propertyReader.DataProperties, - this._storagePropertyNames, + this._model.DataProperties, uniqueIndexes)); if (indexArray.Count > 0) @@ -365,13 +557,13 @@ private async Task> FindAsync(FilterDefinition 0) + if (!includeVectors && this._model.VectorProperties.Count > 0) { - foreach (var vectorPropertyName in this._vectorStoragePropertyNames) + foreach (var vectorProperty in this._model.VectorProperties) { projectionDefinition = projectionDefinition is not null ? - projectionDefinition.Exclude(vectorPropertyName) : - projectionBuilder.Exclude(vectorPropertyName); + projectionDefinition.Exclude(vectorProperty.StorageName) : + projectionBuilder.Exclude(vectorProperty.StorageName); } } @@ -399,8 +591,9 @@ private async IAsyncEnumerable> EnumerateAndMapSearc { var score = response[ScorePropertyName].AsDouble; var record = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + AzureCosmosDBMongoDBConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, () => this._mapper.MapFromStorageToDataModel(response[DocumentPropertyName].AsBsonDocument, new())); @@ -420,7 +613,7 @@ private FilterDefinition GetFilterByIds(IEnumerable ids) private async Task InternalCollectionExistsAsync(CancellationToken cancellationToken) { - var filter = new BsonDocument("name", this.CollectionName); + var filter = new BsonDocument("name", this.Name); var options = new ListCollectionNamesOptions { Filter = filter }; using var cursor = await this._mongoDatabase.ListCollectionNamesAsync(options, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -438,8 +631,9 @@ private async Task RunOperationAsync(string operationName, Func operation) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = AzureCosmosDBMongoDBConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } @@ -455,8 +649,9 @@ private async Task RunOperationAsync(string operationName, Func> o { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = AzureCosmosDBMongoDBConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } @@ -493,22 +688,15 @@ private static Dictionary GetStoragePropertyNames( return storagePropertyNames; } - /// - /// Returns custom mapper, generic data model mapper or default record mapper. - /// - private IVectorStoreRecordMapper InitializeMapper() + private string GetStringKey(TKey key) { - if (this._options.BsonDocumentCustomMapper is not null) - { - return this._options.BsonDocumentCustomMapper; - } + Verify.NotNull(key); - if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) - { - return (new MongoDBGenericDataModelMapper(this._propertyReader.RecordDefinition) as IVectorStoreRecordMapper)!; - } + var stringKey = key as string ?? throw new UnreachableException("string key should have been validated during model building"); + + Verify.NotNullOrWhiteSpace(stringKey, nameof(key)); - return new MongoDBVectorStoreRecordMapper(this._propertyReader); + return stringKey; } #endregion diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions.cs index d94c6c0956f1..0eee8b280525 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions.cs @@ -1,18 +1,21 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; using MongoDB.Bson; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; /// -/// Options when creating a . +/// Options when creating a . /// public sealed class AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions { /// /// Gets or sets an optional custom mapper to use when converting between the data model and the Azure CosmosDB MongoDB BSON object. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper? BsonDocumentCustomMapper { get; init; } = null; /// @@ -25,6 +28,11 @@ public sealed class AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } + /// /// This integer is the number of clusters that the inverted file (IVF) index uses to group the vector data. Default is 1. /// We recommend that numLists is set to documentCount/1000 for up to 1 million documents and to sqrt(documentCount) diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs index d6ae10c7bbb8..72cf0b2774ee 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Reflection; using MongoDB.Bson; using MongoDB.Bson.Serialization.Attributes; @@ -11,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; /// /// Similarity metric to use with the index. Possible options are COS (cosine distance), L2 (Euclidean distance), and IP (inner product). /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being phased out, use Microsoft.Extensions.VectorData and AzureMongoDBMongoDBVectorStore")] public enum AzureCosmosDBSimilarityType { /// @@ -33,7 +33,7 @@ public enum AzureCosmosDBSimilarityType Euclidean } -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being phased out, use Microsoft.Extensions.VectorData and AzureMongoDBMongoDBVectorStore")] internal static class AzureCosmosDBSimilarityTypeExtensions { public static string GetCustomName(this AzureCosmosDBSimilarityType type) diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs index 0bd827257304..a07fe1aebabc 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Reflection; using MongoDB.Bson.Serialization.Attributes; @@ -10,7 +10,7 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; /// /// Type of vector index to create. The options are vector-ivf and vector-hnsw. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being phased out, use Microsoft.Extensions.VectorData and AzureMongoDBMongoDBVectorStore")] public enum AzureCosmosDBVectorSearchType { /// @@ -26,7 +26,7 @@ public enum AzureCosmosDBVectorSearchType VectorHNSW } -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being phased out, use Microsoft.Extensions.VectorData and AzureMongoDBMongoDBVectorStore")] internal static class AzureCosmosDBVectorSearchTypeExtensions { public static string GetCustomName(this AzureCosmosDBVectorSearchType type) diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/Connectors.Memory.AzureCosmosDBMongoDB.csproj b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/Connectors.Memory.AzureCosmosDBMongoDB.csproj index 3822d58d8b27..8e7eda4105ac 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/Connectors.Memory.AzureCosmosDBMongoDB.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/Connectors.Memory.AzureCosmosDBMongoDB.csproj @@ -4,13 +4,14 @@ Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 preview + @@ -31,6 +32,13 @@ + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/IAzureCosmosDBMongoDBVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/IAzureCosmosDBMongoDBVectorStoreRecordCollectionFactory.cs index 5aeec3f3f4ff..e1ca5eebcbe9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/IAzureCosmosDBMongoDBVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/IAzureCosmosDBMongoDBVectorStoreRecordCollectionFactory.cs @@ -22,5 +22,6 @@ public interface IAzureCosmosDBMongoDBVectorStoreRecordCollectionFactory /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. /// The new instance of . IVectorStoreRecordCollection CreateVectorStoreRecordCollection(IMongoDatabase mongoDatabase, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) - where TKey : notnull; + where TKey : notnull + where TRecord : notnull; } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs index 6dbb0d440b45..f667488e7fd9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs @@ -4,6 +4,8 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; internal static class AzureCosmosDBNoSQLConstants { + internal const string VectorStoreSystemName = "azure.cosmosdbnosql"; + /// /// Reserved key property name in Azure CosmosDB NoSQL. /// diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLDynamicDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLDynamicDataModelMapper.cs new file mode 100644 index 000000000000..0d9fec10f740 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLDynamicDataModelMapper.cs @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using MEAI = Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; + +/// +/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Azure CosmosDB NoSQL. +/// +internal sealed class AzureCosmosDBNoSQLDynamicDataModelMapper(VectorStoreRecordModel model, JsonSerializerOptions jsonSerializerOptions) + : ICosmosNoSQLMapper> +{ + /// A default for serialization/deserialization of vector properties. + private static readonly JsonSerializerOptions s_vectorJsonSerializerOptions = new() + { + Converters = { new AzureCosmosDBNoSQLReadOnlyMemoryByteConverter() } + }; + + public JsonObject MapFromDataToStorageModel(Dictionary dataModel, MEAI.Embedding?[]? generatedEmbeddings) + { + Verify.NotNull(dataModel); + + var jsonObject = new JsonObject(); + + jsonObject[AzureCosmosDBNoSQLConstants.ReservedKeyPropertyName] = !dataModel.TryGetValue(model.KeyProperty.ModelName, out var keyValue) + ? throw new InvalidOperationException($"Missing value for key property '{model.KeyProperty.ModelName}") + : keyValue switch + { + string s => s, + null => throw new InvalidOperationException($"Key property '{model.KeyProperty.ModelName}' is null."), + _ => throw new InvalidCastException($"Key property '{model.KeyProperty.ModelName}' must be a string.") + }; + + foreach (var dataProperty in model.DataProperties) + { + if (dataModel.TryGetValue(dataProperty.StorageName, out var dataValue)) + { + jsonObject[dataProperty.StorageName] = dataValue is not null ? + JsonSerializer.SerializeToNode(dataValue, dataProperty.Type, jsonSerializerOptions) : + null; + } + } + + for (var i = 0; i < model.VectorProperties.Count; i++) + { + var property = model.VectorProperties[i]; + + if (generatedEmbeddings?[i] is null) + { + // No generated embedding, read the vector directly from the data model + if (dataModel.TryGetValue(property.ModelName, out var sourceValue)) + { + jsonObject.Add(property.StorageName, sourceValue is null + ? null + : JsonSerializer.SerializeToNode(sourceValue, property.Type, s_vectorJsonSerializerOptions)); + } + } + else + { + Debug.Assert(property.EmbeddingGenerator is not null); + var embedding = generatedEmbeddings[i]; + jsonObject.Add( + property.StorageName, + embedding switch + { + MEAI.Embedding e => JsonSerializer.SerializeToNode(e.Vector, s_vectorJsonSerializerOptions), + MEAI.Embedding e => JsonSerializer.SerializeToNode(e.Vector, s_vectorJsonSerializerOptions), + MEAI.Embedding e => JsonSerializer.SerializeToNode(e.Vector, s_vectorJsonSerializerOptions), + _ => throw new UnreachableException() + }); + } + } + + return jsonObject; + } + + public Dictionary MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options) + { + Verify.NotNull(storageModel); + + var result = new Dictionary(); + + // Loop through all known properties and map each from the storage model to the data model. + foreach (var property in model.Properties) + { + switch (property) + { + case VectorStoreRecordKeyPropertyModel keyProperty: + result[keyProperty.ModelName] = storageModel.TryGetPropertyValue(AzureCosmosDBNoSQLConstants.ReservedKeyPropertyName, out var keyValue) + ? keyValue?.GetValue() + : throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); + continue; + + case VectorStoreRecordDataPropertyModel dataProperty: + if (storageModel.TryGetPropertyValue(dataProperty.StorageName, out var dataValue)) + { + result.Add(property.ModelName, dataValue.Deserialize(property.Type, jsonSerializerOptions)); + } + continue; + + case VectorStoreRecordVectorPropertyModel vectorProperty: + if (options.IncludeVectors && storageModel.TryGetPropertyValue(vectorProperty.StorageName, out var vectorValue)) + { + result.Add(property.ModelName, vectorValue.Deserialize(property.Type, s_vectorJsonSerializerOptions)); + } + continue; + + default: + throw new UnreachableException(); + } + } + + return result; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLGenericDataModelMapper.cs deleted file mode 100644 index 2a52c2604a4d..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLGenericDataModelMapper.cs +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; -using System.Text.Json; -using System.Text.Json.Nodes; -using Microsoft.Extensions.VectorData; - -namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; - -/// -/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Azure CosmosDB NoSQL. -/// -internal sealed class AzureCosmosDBNoSQLGenericDataModelMapper : IVectorStoreRecordMapper, JsonObject> -{ - /// A default for serialization/deserialization of vector properties. - private static readonly JsonSerializerOptions s_vectorJsonSerializerOptions = new() - { - Converters = { new AzureCosmosDBNoSQLReadOnlyMemoryByteConverter() } - }; - - /// A for serialization/deserialization of data properties - private readonly JsonSerializerOptions _jsonSerializerOptions; - - /// The list of properties from the record definition. - private readonly IReadOnlyList _properties; - - /// A dictionary that maps from a property name to the storage name. - public readonly Dictionary _storagePropertyNames; - - public AzureCosmosDBNoSQLGenericDataModelMapper( - IReadOnlyList properties, - Dictionary storagePropertyNames, - JsonSerializerOptions jsonSerializerOptions) - { - Verify.NotNull(properties); - - this._properties = properties; - this._storagePropertyNames = storagePropertyNames; - this._jsonSerializerOptions = jsonSerializerOptions; - } - - public JsonObject MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - Verify.NotNull(dataModel); - - var jsonObject = new JsonObject(); - - // Loop through all known properties and map each from the data model to the storage model. - foreach (var property in this._properties) - { - var storagePropertyName = this._storagePropertyNames[property.DataModelPropertyName]; - - if (property is VectorStoreRecordKeyProperty keyProperty) - { - jsonObject[AzureCosmosDBNoSQLConstants.ReservedKeyPropertyName] = dataModel.Key; - } - else if (property is VectorStoreRecordDataProperty dataProperty) - { - if (dataModel.Data is not null && dataModel.Data.TryGetValue(dataProperty.DataModelPropertyName, out var dataValue)) - { - jsonObject[storagePropertyName] = dataValue is not null ? - JsonSerializer.SerializeToNode(dataValue, property.PropertyType, this._jsonSerializerOptions) : - null; - } - } - else if (property is VectorStoreRecordVectorProperty vectorProperty) - { - if (dataModel.Vectors is not null && dataModel.Vectors.TryGetValue(vectorProperty.DataModelPropertyName, out var vectorValue)) - { - jsonObject[storagePropertyName] = vectorValue is not null ? - JsonSerializer.SerializeToNode(vectorValue, property.PropertyType, s_vectorJsonSerializerOptions) : - null; - } - } - } - - return jsonObject; - } - - public VectorStoreGenericDataModel MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options) - { - Verify.NotNull(storageModel); - - // Create variables to store the response properties. - string? key = null; - var dataProperties = new Dictionary(); - var vectorProperties = new Dictionary(); - - // Loop through all known properties and map each from the storage model to the data model. - foreach (var property in this._properties) - { - var storagePropertyName = this._storagePropertyNames[property.DataModelPropertyName]; - - if (property is VectorStoreRecordKeyProperty keyProperty) - { - if (storageModel.TryGetPropertyValue(AzureCosmosDBNoSQLConstants.ReservedKeyPropertyName, out var keyValue)) - { - key = keyValue?.GetValue(); - } - } - else if (property is VectorStoreRecordDataProperty dataProperty) - { - if (storageModel.TryGetPropertyValue(storagePropertyName, out var dataValue)) - { - dataProperties.Add(property.DataModelPropertyName, dataValue.Deserialize(property.PropertyType, this._jsonSerializerOptions)); - } - } - else if (property is VectorStoreRecordVectorProperty vectorProperty && options.IncludeVectors) - { - if (storageModel.TryGetPropertyValue(storagePropertyName, out var vectorValue)) - { - vectorProperties.Add(property.DataModelPropertyName, vectorValue.Deserialize(property.PropertyType, s_vectorJsonSerializerOptions)); - } - } - } - - if (key is null) - { - throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); - } - - return new VectorStoreGenericDataModel(key) { Data = dataProperties, Vectors = vectorProperties }; - } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLKernelBuilderExtensions.cs index 12f7c0118538..f2b914078f14 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLKernelBuilderExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; @@ -9,6 +10,7 @@ namespace Microsoft.SemanticKernel; /// /// Extension methods to register Azure CosmosDB NoSQL instances on the . /// +[Obsolete("The IKernelBuilder extensions are being obsoleted, call the appropriate function on the Services property of your IKernelBuilder")] public static class AzureCosmosDBNoSQLKernelBuilderExtensions { /// @@ -55,7 +57,7 @@ public static IKernelBuilder AddAzureCosmosDBNoSQLVectorStore( } /// - /// Register an Azure CosmosDB NoSQL and with the specified service ID + /// Register an Azure CosmosDB NoSQL and with the specified service ID /// and where the Azure CosmosDB NoSQL is retrieved from the dependency injection container. /// /// The type of the record. @@ -69,13 +71,14 @@ public static IKernelBuilder AddAzureCosmosDBNoSQLVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddAzureCosmosDBNoSQLVectorStoreRecordCollection(collectionName, options, serviceId); return builder; } /// - /// Register an Azure CosmosDB NoSQL and with the specified service ID + /// Register an Azure CosmosDB NoSQL and with the specified service ID /// and where the Azure CosmosDB NoSQL is constructed using the provided and . /// /// The type of the record. @@ -93,6 +96,7 @@ public static IKernelBuilder AddAzureCosmosDBNoSQLVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddAzureCosmosDBNoSQLVectorStoreRecordCollection(collectionName, connectionString, databaseName, options, serviceId); return builder; diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs index ab898fad6c13..08a2948e8cff 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Text; @@ -18,11 +17,13 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// An implementation of backed by a Azure Cosmos DB database. /// Get more details about Azure Cosmos DB vector search https://learn.microsoft.com/en-us/azure/cosmos-db/ /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and AzureCosmosDBNoSQLVectorStore")] public class AzureCosmosDBNoSQLMemoryStore : IMemoryStore, IDisposable { private const string EmbeddingPath = "/embedding"; @@ -289,7 +290,7 @@ public async IAsyncEnumerable GetBatchAsync( var queryStart = $""" SELECT x.id,x.key,x.metadata,x.timestamp{(withEmbeddings ? ",x.embedding" : "")} FROM x - WHERE + WHERE """; // NOTE: Cosmos DB queries are limited to 512kB, so we'll break this into chunks // of around 500kB. We don't go all the way to 512kB so that we don't have to @@ -446,7 +447,7 @@ protected virtual void Dispose(bool disposing) /// [DebuggerDisplay("{GetDebuggerDisplay()}")] #pragma warning disable CA1812 // 'MemoryRecordWithSimilarityScore' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1812) -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and AzureCosmosDBNoSQLVectorStore")] internal sealed class MemoryRecordWithSimilarityScore( #pragma warning restore CA1812 MemoryRecordMetadata metadata, @@ -468,7 +469,7 @@ private string GetDebuggerDisplay() /// /// Creates a new record that also serializes an "id" property. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and AzureCosmosDBNoSQLVectorStore")] [DebuggerDisplay("{GetDebuggerDisplay()}")] internal sealed class MemoryRecordWithId : MemoryRecord { diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLServiceCollectionExtensions.cs index 1c70d360ee62..de7910ba078d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLServiceCollectionExtensions.cs @@ -2,6 +2,7 @@ using System.Text.Json; using Microsoft.Azure.Cosmos; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; @@ -34,7 +35,10 @@ public static IServiceCollection AddAzureCosmosDBNoSQLVectorStore( (sp, obj) => { var database = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; return new AzureCosmosDBNoSQLVectorStore(database, options); }); @@ -71,7 +75,10 @@ public static IServiceCollection AddAzureCosmosDBNoSQLVectorStore( }); var database = cosmosClient.GetDatabase(databaseName); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; return new AzureCosmosDBNoSQLVectorStore(database, options); }); @@ -80,7 +87,7 @@ public static IServiceCollection AddAzureCosmosDBNoSQLVectorStore( } /// - /// Register an Azure CosmosDB NoSQL and with the specified service ID + /// Register an Azure CosmosDB NoSQL and with the specified service ID /// and where the Azure CosmosDB NoSQL is retrieved from the dependency injection container. /// /// The type of the record. @@ -94,15 +101,19 @@ public static IServiceCollection AddAzureCosmosDBNoSQLVectorStoreRecordCollectio string collectionName, AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedTransient>( serviceId, (sp, obj) => { var database = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new AzureCosmosDBNoSQLVectorStoreRecordCollection(database, collectionName, selectedOptions); + return new AzureCosmosDBNoSQLVectorStoreRecordCollection(database, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -111,7 +122,7 @@ public static IServiceCollection AddAzureCosmosDBNoSQLVectorStoreRecordCollectio } /// - /// Register an Azure CosmosDB NoSQL and with the specified service ID + /// Register an Azure CosmosDB NoSQL and with the specified service ID /// and where the Azure CosmosDB NoSQL is constructed using the provided and . /// /// The type of the record. @@ -129,6 +140,7 @@ public static IServiceCollection AddAzureCosmosDBNoSQLVectorStoreRecordCollectio string databaseName, AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedSingleton>( serviceId, @@ -141,9 +153,12 @@ public static IServiceCollection AddAzureCosmosDBNoSQLVectorStoreRecordCollectio }); var database = cosmosClient.GetDatabase(databaseName); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new AzureCosmosDBNoSQLVectorStoreRecordCollection(database, collectionName, selectedOptions); + return new AzureCosmosDBNoSQLVectorStoreRecordCollection(database, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -152,14 +167,14 @@ public static IServiceCollection AddAzureCosmosDBNoSQLVectorStoreRecordCollectio } /// - /// Also register the with the given as a . + /// Also register the with the given as a . /// /// The type of the data model that the collection should contain. /// The service collection to register on. /// The service id that the registrations should use. - private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) where TRecord : notnull { - services.AddKeyedTransient>( + services.AddKeyedTransient>( serviceId, (sp, obj) => { diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStore.cs index 39320e0a8ae2..d74c4df364fc 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStore.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.VectorData; @@ -15,14 +16,20 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; /// /// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. /// -public class AzureCosmosDBNoSQLVectorStore : IVectorStore +public sealed class AzureCosmosDBNoSQLVectorStore : IVectorStore { + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; + /// that can be used to manage the collections in Azure CosmosDB NoSQL. private readonly Database _database; /// Optional configuration options for this class. private readonly AzureCosmosDBNoSQLVectorStoreOptions _options; + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreRecordDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreRecordKeyProperty("Key", typeof(string))] }; + /// /// Initializes a new instance of the class. /// @@ -34,11 +41,18 @@ public AzureCosmosDBNoSQLVectorStore(Database database, AzureCosmosDBNoSQLVector this._database = database; this._options = options ?? new(); + + this._metadata = new() + { + VectorStoreSystemName = AzureCosmosDBNoSQLConstants.VectorStoreSystemName, + VectorStoreName = database.Id + }; } /// - public virtual IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + where TRecord : notnull { #pragma warning disable CS0618 // IAzureCosmosDBNoSQLVectorStoreRecordCollectionFactory is obsolete if (this._options.VectorStoreCollectionFactory is not null) @@ -50,25 +64,21 @@ public virtual IVectorStoreRecordCollection GetCollection( + var recordCollection = new AzureCosmosDBNoSQLVectorStoreRecordCollection( this._database, name, new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition, - JsonSerializerOptions = this._options.JsonSerializerOptions + JsonSerializerOptions = this._options.JsonSerializerOptions, + EmbeddingGenerator = this._options.EmbeddingGenerator }) as IVectorStoreRecordCollection; return recordCollection!; } /// - public virtual async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { const string Query = "SELECT VALUE(c.id) FROM c"; @@ -84,4 +94,31 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat } } } + + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType == typeof(Database) ? this._database : + serviceType.IsInstanceOfType(this) ? this : + null; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs index d3ae19517db5..258e3a4755ac 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs @@ -7,6 +7,7 @@ using System.Text; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; @@ -25,8 +26,7 @@ internal static class AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder public static QueryDefinition BuildSearchQuery( TVector vector, ICollection? keywords, - List fields, - Dictionary storagePropertyNames, + VectorStoreRecordModel model, string vectorPropertyName, string? textPropertyName, string scorePropertyName, @@ -35,7 +35,8 @@ public static QueryDefinition BuildSearchQuery( #pragma warning restore CS0618 // Type or member is obsolete Expression>? filter, int top, - int skip) + int skip, + bool includeVectors) { Verify.NotNull(vector); @@ -45,7 +46,12 @@ public static QueryDefinition BuildSearchQuery( var tableVariableName = AzureCosmosDBNoSQLConstants.ContainerAlias; - var fieldsArgument = fields.Select(field => $"{tableVariableName}.{field}"); + IEnumerable projectionProperties = model.Properties; + if (!includeVectors) + { + projectionProperties = projectionProperties.Where(p => p is not VectorStoreRecordVectorPropertyModel); + } + var fieldsArgument = projectionProperties.Select(p => $"{tableVariableName}.{p.StorageName}"); var vectorDistanceArgument = $"VectorDistance({tableVariableName}.{vectorPropertyName}, {VectorVariableName})"; var vectorDistanceArgumentWithAlias = $"{vectorDistanceArgument} AS {scorePropertyName}"; @@ -63,8 +69,8 @@ public static QueryDefinition BuildSearchQuery( var (whereClause, filterParameters) = (OldFilter: oldFilter, Filter: filter) switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => BuildSearchFilter(legacyFilter, storagePropertyNames), - { Filter: Expression> newFilter } => new AzureCosmosDBNoSqlFilterTranslator().Translate(newFilter, storagePropertyNames), + { OldFilter: VectorSearchFilter legacyFilter } => BuildSearchFilter(legacyFilter, model), + { Filter: Expression> newFilter } => new AzureCosmosDBNoSqlFilterTranslator().Translate(newFilter, model), _ => (null, []) }; #pragma warning restore CS0618 // VectorSearchFilter is obsolete @@ -119,14 +125,73 @@ public static QueryDefinition BuildSearchQuery( return queryDefinition; } + internal static QueryDefinition BuildSearchQuery( + VectorStoreRecordModel model, + string whereClause, Dictionary filterParameters, + GetFilteredRecordOptions filterOptions, + int top) + { + var tableVariableName = AzureCosmosDBNoSQLConstants.ContainerAlias; + + IEnumerable projectionProperties = model.Properties; + if (!filterOptions.IncludeVectors) + { + projectionProperties = projectionProperties.Where(p => p is not VectorStoreRecordVectorPropertyModel); + } + + var fieldsArgument = projectionProperties.Select(field => $"{tableVariableName}.{field.StorageName}"); + + var selectClauseArguments = string.Join(SelectClauseDelimiter, [.. fieldsArgument]); + + // If Offset is not configured, use Top parameter instead of Limit/Offset + // since it's more optimized. + var topArgument = filterOptions.Skip == 0 ? $"TOP {top} " : string.Empty; + + var builder = new StringBuilder(); + + builder.AppendLine($"SELECT {topArgument}{selectClauseArguments}"); + builder.AppendLine($"FROM {tableVariableName}"); + builder.Append("WHERE ").AppendLine(whereClause); + + if (filterOptions.OrderBy.Values.Count > 0) + { + builder.Append("ORDER BY "); + + foreach (var sortInfo in filterOptions.OrderBy.Values) + { + builder.AppendFormat("{0}.{1} {2},", tableVariableName, + model.GetDataOrKeyProperty(sortInfo.PropertySelector).StorageName, + sortInfo.Ascending ? "ASC" : "DESC"); + } + + builder.Length--; // remove the last comma + builder.AppendLine(); + } + + if (string.IsNullOrEmpty(topArgument)) + { + builder.AppendLine($"OFFSET {filterOptions.Skip} LIMIT {top}"); + } + + var queryDefinition = new QueryDefinition(builder.ToString()); + + foreach (var queryParameter in filterParameters) + { + queryDefinition.WithParameter(queryParameter.Key, queryParameter.Value); + } + + return queryDefinition; + } + /// /// Builds to get items from Azure CosmosDB NoSQL. /// public static QueryDefinition BuildSelectQuery( + VectorStoreRecordModel model, string keyStoragePropertyName, string partitionKeyStoragePropertyName, List keys, - List fields) + bool includeVectors) { Verify.True(keys.Count > 0, "At least one key should be provided.", nameof(keys)); @@ -135,8 +200,14 @@ public static QueryDefinition BuildSelectQuery( var tableVariableName = AzureCosmosDBNoSQLConstants.ContainerAlias; - var selectClauseArguments = string.Join(SelectClauseDelimiter, - fields.Select(field => $"{tableVariableName}.{field}")); + IEnumerable projectionProperties = model.Properties; + if (!includeVectors) + { + projectionProperties = projectionProperties.Where(p => p is not VectorStoreRecordVectorPropertyModel); + } + var fields = projectionProperties.Select(field => field.StorageName); + + var selectClauseArguments = string.Join(SelectClauseDelimiter, fields.Select(field => $"{tableVariableName}.{field}")); var whereClauseArguments = string.Join(OrConditionDelimiter, keys.Select((key, index) => @@ -171,7 +242,7 @@ public static QueryDefinition BuildSelectQuery( #pragma warning disable CS0618 // VectorSearchFilter is obsolete private static (string WhereClause, Dictionary Parameters) BuildSearchFilter( VectorSearchFilter filter, - Dictionary storagePropertyNames) + VectorStoreRecordModel model) { const string EqualOperator = "="; const string ArrayContainsOperator = "ARRAY_CONTAINS"; @@ -197,13 +268,13 @@ private static (string WhereClause, Dictionary Parameters) Buil if (filterClause is EqualToFilterClause equalToFilterClause) { - var propertyName = GetStoragePropertyName(equalToFilterClause.FieldName, storagePropertyNames); + var propertyName = GetStoragePropertyName(equalToFilterClause.FieldName, model); whereClauseBuilder.Append($"{tableVariableName}.{propertyName} {EqualOperator} {queryParameterName}"); queryParameterValue = equalToFilterClause.Value; } else if (filterClause is AnyTagEqualToFilterClause anyTagEqualToFilterClause) { - var propertyName = GetStoragePropertyName(anyTagEqualToFilterClause.FieldName, storagePropertyNames); + var propertyName = GetStoragePropertyName(anyTagEqualToFilterClause.FieldName, model); whereClauseBuilder.Append($"{ArrayContainsOperator}({tableVariableName}.{propertyName}, {queryParameterName})"); queryParameterValue = anyTagEqualToFilterClause.Value; } @@ -223,14 +294,14 @@ private static (string WhereClause, Dictionary Parameters) Buil } #pragma warning restore CS0618 // VectorSearchFilter is obsolete - private static string GetStoragePropertyName(string propertyName, Dictionary storagePropertyNames) + private static string GetStoragePropertyName(string propertyName, VectorStoreRecordModel model) { - if (!storagePropertyNames.TryGetValue(propertyName, out var storagePropertyName)) + if (!model.PropertyMap.TryGetValue(propertyName, out var property)) { throw new InvalidOperationException($"Property name '{propertyName}' provided as part of the filter clause is not a valid property name."); } - return storagePropertyName; + return property.StorageName; } #endregion diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreOptions.cs index edbfe436f136..6120a3c26630 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreOptions.cs @@ -2,6 +2,7 @@ using System; using System.Text.Json; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; @@ -11,7 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; public sealed class AzureCosmosDBNoSQLVectorStoreOptions { /// - /// An optional factory to use for constructing instances, if a custom record collection is required. + /// An optional factory to use for constructing instances, if a custom record collection is required. /// [Obsolete("To control how collections are instantiated, extend your provider's IVectorStore implementation and override GetCollection()")] public IAzureCosmosDBNoSQLVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } @@ -20,4 +21,9 @@ public sealed class AzureCosmosDBNoSQLVectorStoreOptions /// Gets or sets the JSON serializer options to use when converting between the data model and the Azure CosmosDB NoSQL record. /// public JsonSerializerOptions? JsonSerializerOptions { get; init; } + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs index dd0e245c8004..9c566ee111cf 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs @@ -3,16 +3,22 @@ using System; using System.Collections.Generic; using System.Collections.ObjectModel; +using System.Diagnostics; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Nodes; using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Cosmos; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; using DistanceFunction = Microsoft.Azure.Cosmos.DistanceFunction; using IndexKind = Microsoft.Extensions.VectorData.IndexKind; +using MEAI = Microsoft.Extensions.AI; using SKDistanceFunction = Microsoft.Extensions.VectorData.DistanceFunction; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; @@ -20,54 +26,16 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; /// /// Service for storing and retrieving vector records, that uses Azure CosmosDB NoSQL as the underlying storage. /// +/// The data type of the record key. Can be either , or for dynamic mapping. /// The data model to use for adding, updating and retrieving data from storage. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class AzureCosmosDBNoSQLVectorStoreRecordCollection : - IVectorStoreRecordCollection, - IVectorStoreRecordCollection, - IKeywordHybridSearch +public sealed class AzureCosmosDBNoSQLVectorStoreRecordCollection : IVectorStoreRecordCollection, IKeywordHybridSearch + where TKey : notnull + where TRecord : notnull #pragma warning restore CA1711 // Identifiers should not have incorrect { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "AzureCosmosDBNoSQL"; - - /// A of types that a key on the provided model may have. - private static readonly HashSet s_supportedKeyTypes = - [ - typeof(string) - ]; - - /// A of types that data properties on the provided model may have. - private static readonly HashSet s_supportedDataTypes = - [ - typeof(bool), - typeof(bool?), - typeof(string), - typeof(int), - typeof(int?), - typeof(long), - typeof(long?), - typeof(float), - typeof(float?), - typeof(double), - typeof(double?), - typeof(DateTimeOffset), - typeof(DateTimeOffset?), - ]; - - /// A of types that vector properties on the provided model may have, based on enumeration. - private static readonly HashSet s_supportedVectorTypes = - [ - // Float32 - typeof(ReadOnlyMemory), - typeof(ReadOnlyMemory?), - // Uint8 - typeof(ReadOnlyMemory), - typeof(ReadOnlyMemory?), - // Int8 - typeof(ReadOnlyMemory), - typeof(ReadOnlyMemory?), - ]; + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; /// The default options for vector search. private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); @@ -81,99 +49,96 @@ public class AzureCosmosDBNoSQLVectorStoreRecordCollection : /// Optional configuration options for this class. private readonly AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions _options; - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; + /// The model for this collection. + private readonly VectorStoreRecordModel _model; - /// The storage names of all non vector fields on the current model. - private readonly List _nonVectorStoragePropertyNames = []; - - /// A dictionary that maps from a property name to the storage name that should be used when serializing it to json for data and vector properties. - private readonly Dictionary _storagePropertyNames = []; - - /// The storage name of the key field for the collections that this class is used with. - private readonly string _keyStoragePropertyName; - - /// The property name to use as partition key. - private readonly string _partitionKeyPropertyName; - - /// The storage property name to use as partition key. - private readonly string _partitionKeyStoragePropertyName; + // TODO: Refactor this into the model (Co) + /// The property to use as partition key. + private readonly VectorStoreRecordPropertyModel _partitionKeyProperty; /// The mapper to use when mapping between the consumer data model and the Azure CosmosDB NoSQL record. - private readonly IVectorStoreRecordMapper _mapper; + private readonly ICosmosNoSQLMapper _mapper; /// - public string CollectionName { get; } + public string Name { get; } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// that can be used to manage the collections in Azure CosmosDB NoSQL. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// Optional configuration options for this class. public AzureCosmosDBNoSQLVectorStoreRecordCollection( Database database, - string collectionName, + string name, AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions? options = default) { // Verify. Verify.NotNull(database); - Verify.NotNullOrWhiteSpace(collectionName); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.JsonObjectCustomMapper is not null, s_supportedKeyTypes); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + Verify.NotNullOrWhiteSpace(name); + + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(AzureCosmosDBNoSQLCompositeKey) && typeof(TKey) != typeof(object)) + { + throw new NotSupportedException($"Only {nameof(String)} and {nameof(AzureCosmosDBNoSQLCompositeKey)} keys are supported (and object for dynamic mapping)."); + } + + if (database.Client?.ClientOptions?.UseSystemTextJsonSerializerWithOptions is null) + { + throw new ArgumentException( + $"Property {nameof(CosmosClientOptions.UseSystemTextJsonSerializerWithOptions)} in CosmosClient.ClientOptions " + + $"is required to be configured for {nameof(AzureCosmosDBNoSQLVectorStoreRecordCollection)}."); + } // Assign. this._database = database; - this.CollectionName = collectionName; + this.Name = name; this._options = options ?? new(); var jsonSerializerOptions = this._options.JsonSerializerOptions ?? JsonSerializerOptions.Default; - this._propertyReader = new VectorStoreRecordPropertyReader(typeof(TRecord), this._options.VectorStoreRecordDefinition, new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true, - JsonSerializerOptions = jsonSerializerOptions - }); - - // Validate property types. - this._propertyReader.VerifyKeyProperties(s_supportedKeyTypes); - this._propertyReader.VerifyDataProperties(s_supportedDataTypes, supportEnumerable: true); - this._propertyReader.VerifyVectorProperties(s_supportedVectorTypes); - - // Get storage names and store for later use. - this._storagePropertyNames = this._propertyReader.JsonPropertyNamesMap.ToDictionary(x => x.Key, x => x.Value); + this._model = new AzureCosmosDBNoSQLVectorStoreModelBuilder() + .Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator, jsonSerializerOptions); // Assign mapper. - this._mapper = this.InitializeMapper(jsonSerializerOptions); - - // Use Azure CosmosDB NoSQL reserved key property name as storage key property name. - this._storagePropertyNames[this._propertyReader.KeyPropertyName] = AzureCosmosDBNoSQLConstants.ReservedKeyPropertyName; - this._keyStoragePropertyName = AzureCosmosDBNoSQLConstants.ReservedKeyPropertyName; + this._mapper = typeof(TRecord) == typeof(Dictionary) + ? (new AzureCosmosDBNoSQLDynamicDataModelMapper(this._model, jsonSerializerOptions) as ICosmosNoSQLMapper)! + : new AzureCosmosDBNoSQLVectorStoreRecordMapper(this._model, this._options.JsonSerializerOptions); - // If partition key is not provided, use key property as a partition key. - this._partitionKeyPropertyName = !string.IsNullOrWhiteSpace(this._options.PartitionKeyPropertyName) ? - this._options.PartitionKeyPropertyName! : - this._propertyReader.KeyPropertyName; + // Setup partition key property + if (this._options.PartitionKeyPropertyName is not null) + { + if (!this._model.PropertyMap.TryGetValue(this._options.PartitionKeyPropertyName, out var property)) + { + throw new ArgumentException($"Partition key property '{this._options.PartitionKeyPropertyName}' is not part of the record definition."); + } - VerifyPartitionKeyProperty(this._partitionKeyPropertyName, this._propertyReader.Properties); + if (property.Type != typeof(string)) + { + throw new ArgumentException("Partition key property must be string."); + } - this._partitionKeyStoragePropertyName = this._storagePropertyNames[this._partitionKeyPropertyName]; + this._partitionKeyProperty = property; + } + else + { + // If partition key is not provided, use key property as a partition key. + this._partitionKeyProperty = this._model.KeyProperty; + } - this._nonVectorStoragePropertyNames = this._propertyReader.DataProperties - .Cast() - .Concat([this._propertyReader.KeyProperty]) - .Select(x => this._storagePropertyNames[x.DataModelPropertyName]) - .ToList(); + this._collectionMetadata = new() + { + VectorStoreSystemName = AzureCosmosDBNoSQLConstants.VectorStoreSystemName, + VectorStoreName = database.Id, + CollectionName = name + }; } /// - public virtual Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) { return this.RunOperationAsync("GetContainerQueryIterator", async () => { const string Query = "SELECT VALUE(c.id) FROM c WHERE c.id = @collectionName"; - var queryDefinition = new QueryDefinition(Query).WithParameter("@collectionName", this.CollectionName); + var queryDefinition = new QueryDefinition(Query).WithParameter("@collectionName", this.Name); using var feedIterator = this._database.GetContainerQueryIterator(queryDefinition); @@ -192,14 +157,14 @@ public virtual Task CollectionExistsAsync(CancellationToken cancellationTo } /// - public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) { return this.RunOperationAsync("CreateContainer", () => this._database.CreateContainerAsync(this.GetContainerProperties(), cancellationToken: cancellationToken)); } /// - public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) { @@ -208,56 +173,93 @@ public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken can } /// - public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { - return this.RunOperationAsync("DeleteContainer", () => - this._database - .GetContainer(this.CollectionName) - .DeleteContainerAsync(cancellationToken: cancellationToken)); + try + { + await this._database + .GetContainer(this.Name) + .DeleteContainerAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.NotFound) + { + // Do nothing, since the container is already deleted. + } + catch (CosmosException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreSystemName = AzureCosmosDBNoSQLConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, + OperationName = "DeleteContainer" + }; + } } - #region Implementation of IVectorStoreRecordCollection + /// + public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) + => this.DeleteAsync([key], cancellationToken); /// - public virtual Task DeleteAsync(string key, CancellationToken cancellationToken = default) + public async Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { - // Use record key as partition key - var compositeKey = new AzureCosmosDBNoSQLCompositeKey(recordKey: key, partitionKey: key); + Verify.NotNull(keys); - return this.InternalDeleteAsync([compositeKey], cancellationToken); - } + var tasks = GetCompositeKeys(keys).Select(key => + { + Verify.NotNullOrWhiteSpace(key.RecordKey); + Verify.NotNullOrWhiteSpace(key.PartitionKey); - /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) - { - // Use record keys as partition keys - var compositeKeys = keys.Select(key => new AzureCosmosDBNoSQLCompositeKey(recordKey: key, partitionKey: key)); + return this.RunOperationAsync("DeleteItem", () => + this._database + .GetContainer(this.Name) + .DeleteItemAsync(key.RecordKey, new PartitionKey(key.PartitionKey), cancellationToken: cancellationToken)); + }); - return this.InternalDeleteAsync(compositeKeys, cancellationToken); + await Task.WhenAll(tasks).ConfigureAwait(false); } /// - public virtual async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { - // Use record key as partition key - var compositeKey = new AzureCosmosDBNoSQLCompositeKey(recordKey: key, partitionKey: key); - - return await this.InternalGetAsync([compositeKey], options, cancellationToken) + return await this.GetAsync([key], options, cancellationToken) .FirstOrDefaultAsync(cancellationToken) .ConfigureAwait(false); } /// - public virtual async IAsyncEnumerable GetBatchAsync( - IEnumerable keys, + public async IAsyncEnumerable GetAsync( + IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - // Use record keys as partition keys - var compositeKeys = keys.Select(key => new AzureCosmosDBNoSQLCompositeKey(recordKey: key, partitionKey: key)); + Verify.NotNull(keys); - await foreach (var record in this.InternalGetAsync(compositeKeys, options, cancellationToken).ConfigureAwait(false)) + const string OperationName = "GetItemQueryIterator"; + + var includeVectors = options?.IncludeVectors ?? false; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSelectQuery( + this._model, + this._model.KeyProperty.StorageName, + this._partitionKeyProperty.StorageName, + GetCompositeKeys(keys).ToList(), + includeVectors); + + await foreach (var jsonObject in this.GetItemsAsync(queryDefinition, cancellationToken).ConfigureAwait(false)) { + var record = VectorStoreErrorHandler.RunModelConversion( + AzureCosmosDBNoSQLConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + OperationName, + () => this._mapper.MapFromStorageToDataModel(jsonObject, new() { IncludeVectors = includeVectors })); + if (record is not null) { yield return record; @@ -266,181 +268,303 @@ public virtual async IAsyncEnumerable GetBatchAsync( } /// - public virtual async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { - var key = await this.InternalUpsertAsync(record, cancellationToken).ConfigureAwait(false); + Verify.NotNull(record); - return key.RecordKey; - } + const string OperationName = "UpsertItem"; - /// - public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Verify.NotNull(records); + MEAI.Embedding?[]? generatedEmbeddings = null; - var tasks = records.Select(record => this.InternalUpsertAsync(record, cancellationToken)); + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; - var keys = await Task.WhenAll(tasks).ConfigureAwait(false); + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } - foreach (var key in keys) - { - if (key is not null) + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbedding, ReadOnlyMemory>(record, cancellationToken, out var floatTask)) { - yield return key.RecordKey; + generatedEmbeddings ??= new MEAI.Embedding?[vectorPropertyCount]; + generatedEmbeddings[i] = await floatTask.ConfigureAwait(false); + } + else if (vectorProperty.TryGenerateEmbedding, ReadOnlyMemory>(record, cancellationToken, out var byteTask)) + { + generatedEmbeddings ??= new MEAI.Embedding?[vectorPropertyCount]; + generatedEmbeddings[i] = await byteTask.ConfigureAwait(false); + } + else if (vectorProperty.TryGenerateEmbedding, ReadOnlyMemory>(record, cancellationToken, out var sbyteTask)) + { + generatedEmbeddings ??= new MEAI.Embedding?[vectorPropertyCount]; + generatedEmbeddings[i] = await sbyteTask.ConfigureAwait(false); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of types '{typeof(Embedding).Name}', '{typeof(Embedding).Name}' or '{typeof(Embedding).Name}' for the given input type."); } } - } - #endregion + var jsonObject = VectorStoreErrorHandler.RunModelConversion( + AzureCosmosDBNoSQLConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + OperationName, + () => this._mapper.MapFromDataToStorageModel(record, generatedEmbeddings)); - #region Implementation of IVectorStoreRecordCollection + var keyValue = jsonObject.TryGetPropertyValue(this._model.KeyProperty.StorageName!, out var jsonKey) ? jsonKey?.ToString() : null; + var partitionKeyValue = jsonObject.TryGetPropertyValue(this._partitionKeyProperty.StorageName, out var jsonPartitionKey) ? jsonPartitionKey?.ToString() : null; - /// - public virtual async Task GetAsync(AzureCosmosDBNoSQLCompositeKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - return await this.InternalGetAsync([key], options, cancellationToken) - .FirstOrDefaultAsync(cancellationToken) - .ConfigureAwait(false); - } + if (string.IsNullOrWhiteSpace(keyValue)) + { + throw new VectorStoreOperationException($"Key property {this._model.KeyProperty.ModelName} is not initialized."); + } - /// - public virtual async IAsyncEnumerable GetBatchAsync( - IEnumerable keys, - GetRecordOptions? options = null, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - await foreach (var record in this.InternalGetAsync(keys, options, cancellationToken).ConfigureAwait(false)) + if (string.IsNullOrWhiteSpace(partitionKeyValue)) { - if (record is not null) - { - yield return record; - } + throw new VectorStoreOperationException($"Partition key property {this._partitionKeyProperty.ModelName} is not initialized."); } - } - /// - public virtual Task DeleteAsync(AzureCosmosDBNoSQLCompositeKey key, CancellationToken cancellationToken = default) - { - return this.InternalDeleteAsync([key], cancellationToken); - } + await this.RunOperationAsync(OperationName, () => + this._database + .GetContainer(this.Name) + .UpsertItemAsync(jsonObject, new PartitionKey(partitionKeyValue), cancellationToken: cancellationToken)) + .ConfigureAwait(false); - /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) - { - return this.InternalDeleteAsync(keys, cancellationToken); + return typeof(TKey) switch + { + var t when t == typeof(AzureCosmosDBNoSQLCompositeKey) || t == typeof(object) => (TKey)(object)new AzureCosmosDBNoSQLCompositeKey(keyValue!, partitionKeyValue!), + var t when t == typeof(string) => (TKey)(object)keyValue!, + _ => throw new UnreachableException() + }; } /// - Task IVectorStoreRecordCollection.UpsertAsync(TRecord record, CancellationToken cancellationToken) + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { - return this.InternalUpsertAsync(record, cancellationToken); + Verify.NotNull(records); + + // TODO: Do proper bulk upsert rather than parallel single inserts, #11350 + var tasks = records.Select(record => this.UpsertAsync(record, cancellationToken)); + var keys = await Task.WhenAll(tasks).ConfigureAwait(false); + return keys.Where(k => k is not null).ToList(); } + #region Search + /// - async IAsyncEnumerable IVectorStoreRecordCollection.UpsertBatchAsync( - IEnumerable records, - [EnumeratorCancellation] CancellationToken cancellationToken) + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull { - Verify.NotNull(records); + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); - var tasks = records.Select(record => this.InternalUpsertAsync(record, cancellationToken)); + switch (vectorProperty.EmbeddingGenerator) + { + case IEmbeddingGenerator> generator: + { + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); - var keys = await Task.WhenAll(tasks).ConfigureAwait(false); + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } - foreach (var key in keys) - { - if (key is not null) + yield break; + } + + case IEmbeddingGenerator> generator: { - yield return key; + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + } + + case IEmbeddingGenerator> generator: + { + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; } + + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + + default: + throw new InvalidOperationException( + AzureCosmosDBNoSQLVectorStoreModelBuilder.s_supportedVectorTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); } } /// - public virtual Task> VectorizedSearchAsync( + public IAsyncEnumerable> SearchEmbeddingAsync( TVector vector, + int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + VectorSearchOptions options, + CancellationToken cancellationToken = default) + where TVector : notnull { const string OperationName = "VectorizedSearch"; const string ScorePropertyName = "SimilarityScore"; this.VerifyVectorType(vector); + Verify.NotLessThan(top, 1); - var searchOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(searchOptions); - var vectorPropertyName = this._storagePropertyNames[vectorProperty.DataModelPropertyName]; - - var fields = new List(searchOptions.IncludeVectors ? this._storagePropertyNames.Values : this._nonVectorStoragePropertyNames); + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } #pragma warning disable CS0618 // Type or member is obsolete var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery( vector, null, - fields, - this._storagePropertyNames, - vectorPropertyName, + this._model, + vectorProperty.StorageName, null, ScorePropertyName, - searchOptions.OldFilter, - searchOptions.Filter, - searchOptions.Top, - searchOptions.Skip); + options.OldFilter, + options.Filter, + top, + options.Skip, + options.IncludeVectors); #pragma warning restore CS0618 // Type or member is obsolete var searchResults = this.GetItemsAsync(queryDefinition, cancellationToken); - var mappedResults = this.MapSearchResultsAsync( + return this.MapSearchResultsAsync( searchResults, ScorePropertyName, OperationName, - searchOptions.IncludeVectors, + options.IncludeVectors, cancellationToken); - return Task.FromResult(new VectorSearchResults(mappedResults)); } /// - public Task> HybridSearchAsync(TVector vector, ICollection keywords, HybridSearchOptions? options = null, CancellationToken cancellationToken = default) + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); + + #endregion Search + + /// + public async IAsyncEnumerable GetAsync(Expression> filter, int top, + GetFilteredRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); + + options ??= new(); + + var (whereClause, filterParameters) = new AzureCosmosDBNoSqlFilterTranslator().Translate(filter, this._model); + + var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery( + this._model, + whereClause, + filterParameters, + options, + top); + + var searchResults = this.GetItemsAsync(queryDefinition, cancellationToken); + + await foreach (var jsonObject in searchResults.ConfigureAwait(false)) + { + var record = VectorStoreErrorHandler.RunModelConversion( + AzureCosmosDBNoSQLConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + "GetAsync", + () => this._mapper.MapFromStorageToDataModel(jsonObject, new() { IncludeVectors = options.IncludeVectors })); + + yield return record; + } + } + + /// + public IAsyncEnumerable> HybridSearchAsync(TVector vector, ICollection keywords, int top, HybridSearchOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "VectorizedSearch"; const string ScorePropertyName = "SimilarityScore"; this.VerifyVectorType(vector); + Verify.NotLessThan(top, 1); - var searchOptions = options ?? s_defaultKeywordVectorizedHybridSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(new() { VectorProperty = searchOptions.VectorProperty }); - var vectorPropertyName = this._storagePropertyNames[vectorProperty.DataModelPropertyName]; - - var textProperty = this._propertyReader.GetFullTextDataPropertyOrSingle(searchOptions.AdditionalProperty); - var textPropertyName = this._storagePropertyNames[textProperty.DataModelPropertyName]; - - var fields = new List(searchOptions.IncludeVectors ? this._storagePropertyNames.Values : this._nonVectorStoragePropertyNames); + options ??= s_defaultKeywordVectorizedHybridSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(new() { VectorProperty = options.VectorProperty }); + var textProperty = this._model.GetFullTextDataPropertyOrSingle(options.AdditionalProperty); #pragma warning disable CS0618 // Type or member is obsolete var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery( vector, keywords, - fields, - this._storagePropertyNames, - vectorPropertyName, - textPropertyName, + this._model, + vectorProperty.StorageName, + textProperty.StorageName, ScorePropertyName, - searchOptions.OldFilter, - searchOptions.Filter, - searchOptions.Top, - searchOptions.Skip); + options.OldFilter, + options.Filter, + top, + options.Skip, + options.IncludeVectors); #pragma warning restore CS0618 // Type or member is obsolete var searchResults = this.GetItemsAsync(queryDefinition, cancellationToken); - var mappedResults = this.MapSearchResultsAsync( + return this.MapSearchResultsAsync( searchResults, ScorePropertyName, OperationName, - searchOptions.IncludeVectors, + options.IncludeVectors, cancellationToken); - return Task.FromResult(new VectorSearchResults(mappedResults)); } - #endregion + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType == typeof(Database) ? this._database : + serviceType.IsInstanceOfType(this) ? this : + null; + } #region private @@ -450,11 +574,11 @@ private void VerifyVectorType(TVector? vector) var vectorType = vector.GetType(); - if (!s_supportedVectorTypes.Contains(vectorType)) + if (!AzureCosmosDBNoSQLVectorStoreModelBuilder.s_supportedVectorTypes.Contains(vectorType)) { throw new NotSupportedException( $"The provided vector type {vectorType.FullName} is not supported by the Azure CosmosDB NoSQL connector. " + - $"Supported types are: {string.Join(", ", s_supportedVectorTypes.Select(l => l.FullName))}"); + $"Supported types are: {string.Join(", ", AzureCosmosDBNoSQLVectorStoreModelBuilder.s_supportedVectorTypes.Select(l => l.FullName))}"); } } @@ -464,33 +588,18 @@ private async Task RunOperationAsync(string operationName, Func> o { return await operation.Invoke().ConfigureAwait(false); } - catch (Exception ex) + catch (CosmosException ex) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = AzureCosmosDBNoSQLConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } } - private static void VerifyPartitionKeyProperty(string partitionKeyPropertyName, IReadOnlyList properties) - { - var partitionKeyProperty = properties - .FirstOrDefault(l => l.DataModelPropertyName.Equals(partitionKeyPropertyName, StringComparison.Ordinal)); - - if (partitionKeyProperty is null) - { - throw new ArgumentException("Partition key property must be part of record definition."); - } - - if (partitionKeyProperty.PropertyType != typeof(string)) - { - throw new ArgumentException("Partition key property must be string."); - } - } - /// /// Returns instance of with applied indexing policy. /// More information here: . @@ -498,7 +607,7 @@ private static void VerifyPartitionKeyProperty(string partitionKeyPropertyName, private ContainerProperties GetContainerProperties() { // Process Vector properties. - var embeddings = new Collection(); + var embeddings = new Collection(); var vectorIndexPaths = new Collection(); var indexingPolicy = new IndexingPolicy @@ -509,34 +618,27 @@ private ContainerProperties GetContainerProperties() if (this._options.IndexingMode == IndexingMode.None) { - return new ContainerProperties(this.CollectionName, partitionKeyPath: $"/{this._partitionKeyStoragePropertyName}") + return new ContainerProperties(this.Name, partitionKeyPath: $"/{this._partitionKeyProperty.StorageName}") { IndexingPolicy = indexingPolicy }; } - foreach (var property in this._propertyReader.VectorProperties) + foreach (var property in this._model.VectorProperties) { - var vectorPropertyName = this._storagePropertyNames[property.DataModelPropertyName]; - - if (property.Dimensions is not > 0) - { - throw new VectorStoreOperationException($"Property {nameof(property.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{property.DataModelPropertyName}' must be set to a positive integer to create a collection."); - } - - var path = $"/{vectorPropertyName}"; + var path = $"/{property.StorageName}"; - var embedding = new Embedding + var embedding = new Azure.Cosmos.Embedding { - DataType = GetDataType(property.PropertyType, vectorPropertyName), + DataType = GetDataType(property.EmbeddingType, property.StorageName), Dimensions = (int)property.Dimensions, - DistanceFunction = GetDistanceFunction(property.DistanceFunction, vectorPropertyName), + DistanceFunction = GetDistanceFunction(property.DistanceFunction, property.StorageName), Path = path }; var vectorIndexPath = new VectorIndexPath { - Type = GetIndexKind(property.IndexKind, vectorPropertyName), + Type = GetIndexKind(property.IndexKind, property.StorageName), Path = path }; @@ -550,17 +652,17 @@ private ContainerProperties GetContainerProperties() var vectorEmbeddingPolicy = new VectorEmbeddingPolicy(embeddings); // Process Data properties. - foreach (var property in this._propertyReader.DataProperties) + foreach (var property in this._model.DataProperties) { - if (property.IsFilterable || property.IsFullTextSearchable) + if (property.IsIndexed || property.IsFullTextIndexed) { - indexingPolicy.IncludedPaths.Add(new IncludedPath { Path = $"/{this._storagePropertyNames[property.DataModelPropertyName]}/?" }); + indexingPolicy.IncludedPaths.Add(new IncludedPath { Path = $"/{property.StorageName}/?" }); } - if (property.IsFullTextSearchable) + if (property.IsFullTextIndexed) { - indexingPolicy.FullTextIndexes.Add(new FullTextIndexPath { Path = $"/{this._storagePropertyNames[property.DataModelPropertyName]}" }); + indexingPolicy.FullTextIndexes.Add(new FullTextIndexPath { Path = $"/{property.StorageName}" }); // TODO: Switch to using language from a setting. - fullTextPolicy.FullTextPaths.Add(new FullTextPath { Path = $"/{this._storagePropertyNames[property.DataModelPropertyName]}", Language = "en-US" }); + fullTextPolicy.FullTextPaths.Add(new FullTextPath { Path = $"/{property.StorageName}", Language = "en-US" }); } } @@ -573,7 +675,7 @@ private ContainerProperties GetContainerProperties() indexingPolicy.ExcludedPaths.Add(new ExcludedPath { Path = $"{vectorIndexPath.Path}/*" }); } - return new ContainerProperties(this.CollectionName, partitionKeyPath: $"/{this._partitionKeyStoragePropertyName}") + return new ContainerProperties(this.Name, partitionKeyPath: $"/{this._partitionKeyProperty.StorageName}") { VectorEmbeddingPolicy = vectorEmbeddingPolicy, IndexingPolicy = indexingPolicy, @@ -585,21 +687,13 @@ private ContainerProperties GetContainerProperties() /// More information about Azure CosmosDB NoSQL index kinds here: . /// private static VectorIndexType GetIndexKind(string? indexKind, string vectorPropertyName) - { - if (string.IsNullOrWhiteSpace(indexKind)) - { - // Use default index kind. - return VectorIndexType.DiskANN; - } - - return indexKind switch + => indexKind switch { + IndexKind.DiskAnn or null => VectorIndexType.DiskANN, IndexKind.Flat => VectorIndexType.Flat, IndexKind.QuantizedFlat => VectorIndexType.QuantizedFlat, - IndexKind.DiskAnn => VectorIndexType.DiskANN, _ => throw new InvalidOperationException($"Index kind '{indexKind}' on {nameof(VectorStoreRecordVectorProperty)} '{vectorPropertyName}' is not supported by the Azure CosmosDB NoSQL VectorStore.") }; - } /// /// More information about Azure CosmosDB NoSQL distance functions here: . @@ -625,101 +719,18 @@ private static DistanceFunction GetDistanceFunction(string? distanceFunction, st /// Returns based on vector property type. /// private static VectorDataType GetDataType(Type vectorDataType, string vectorPropertyName) - { - return vectorDataType switch + => vectorDataType switch { Type type when type == typeof(ReadOnlyMemory) || type == typeof(ReadOnlyMemory?) => VectorDataType.Float32, Type type when type == typeof(ReadOnlyMemory) || type == typeof(ReadOnlyMemory?) => VectorDataType.Uint8, Type type when type == typeof(ReadOnlyMemory) || type == typeof(ReadOnlyMemory?) => VectorDataType.Int8, _ => throw new InvalidOperationException($"Data type '{vectorDataType}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorPropertyName}' is not supported by the Azure CosmosDB NoSQL VectorStore.") }; - } - - private async IAsyncEnumerable InternalGetAsync( - IEnumerable keys, - GetRecordOptions? options = null, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Verify.NotNull(keys); - - const string OperationName = "GetItemQueryIterator"; - - var includeVectors = options?.IncludeVectors ?? false; - var fields = new List(includeVectors ? this._storagePropertyNames.Values : this._nonVectorStoragePropertyNames); - var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSelectQuery( - this._keyStoragePropertyName, - this._partitionKeyStoragePropertyName, - keys.ToList(), - fields); - - await foreach (var jsonObject in this.GetItemsAsync(queryDefinition, cancellationToken).ConfigureAwait(false)) - { - yield return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, - OperationName, - () => this._mapper.MapFromStorageToDataModel(jsonObject, new() { IncludeVectors = includeVectors })); - } - } - - private async Task InternalUpsertAsync( - TRecord record, - CancellationToken cancellationToken) - { - Verify.NotNull(record); - - const string OperationName = "UpsertItem"; - - var jsonObject = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, - OperationName, - () => this._mapper.MapFromDataToStorageModel(record)); - - var keyValue = jsonObject.TryGetPropertyValue(this._keyStoragePropertyName, out var jsonKey) ? jsonKey?.ToString() : null; - var partitionKeyValue = jsonObject.TryGetPropertyValue(this._partitionKeyStoragePropertyName, out var jsonPartitionKey) ? jsonPartitionKey?.ToString() : null; - - if (string.IsNullOrWhiteSpace(keyValue)) - { - throw new VectorStoreOperationException($"Key property {this._propertyReader.KeyPropertyName} is not initialized."); - } - - if (string.IsNullOrWhiteSpace(partitionKeyValue)) - { - throw new VectorStoreOperationException($"Partition key property {this._partitionKeyPropertyName} is not initialized."); - } - - await this.RunOperationAsync(OperationName, () => - this._database - .GetContainer(this.CollectionName) - .UpsertItemAsync(jsonObject, new PartitionKey(partitionKeyValue), cancellationToken: cancellationToken)) - .ConfigureAwait(false); - - return new AzureCosmosDBNoSQLCompositeKey(keyValue!, partitionKeyValue!); - } - - private async Task InternalDeleteAsync(IEnumerable keys, CancellationToken cancellationToken) - { - Verify.NotNull(keys); - - var tasks = keys.Select(key => - { - Verify.NotNullOrWhiteSpace(key.RecordKey); - Verify.NotNullOrWhiteSpace(key.PartitionKey); - - return this.RunOperationAsync("DeleteItem", () => - this._database - .GetContainer(this.CollectionName) - .DeleteItemAsync(key.RecordKey, new PartitionKey(key.PartitionKey), cancellationToken: cancellationToken)); - }); - - await Task.WhenAll(tasks).ConfigureAwait(false); - } private async IAsyncEnumerable GetItemsAsync(QueryDefinition queryDefinition, [EnumeratorCancellation] CancellationToken cancellationToken) { var iterator = this._database - .GetContainer(this.CollectionName) + .GetContainer(this.Name) .GetItemQueryIterator(queryDefinition); while (iterator.HasMoreResults) @@ -751,8 +762,9 @@ private async IAsyncEnumerable> MapSearchResultsAsyn jsonObject.Remove(scorePropertyName); var record = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + AzureCosmosDBNoSQLConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, operationName, () => this._mapper.MapFromStorageToDataModel(jsonObject, new() { IncludeVectors = includeVectors })); @@ -760,27 +772,19 @@ private async IAsyncEnumerable> MapSearchResultsAsyn } } - /// - /// Returns custom mapper, generic data model mapper or default record mapper. - /// - private IVectorStoreRecordMapper InitializeMapper(JsonSerializerOptions jsonSerializerOptions) - { - if (this._options.JsonObjectCustomMapper is not null) - { - return this._options.JsonObjectCustomMapper; - } - - if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) + private static IEnumerable GetCompositeKeys(IEnumerable keys) + => keys switch { - var mapper = new AzureCosmosDBNoSQLGenericDataModelMapper(this._propertyReader.Properties, this._storagePropertyNames, jsonSerializerOptions); - return (mapper as IVectorStoreRecordMapper)!; - } - - return new AzureCosmosDBNoSQLVectorStoreRecordMapper( - this._storagePropertyNames[this._propertyReader.KeyPropertyName], - this._storagePropertyNames, - jsonSerializerOptions); - } + IEnumerable k => k, + IEnumerable k => k.Select(key => new AzureCosmosDBNoSQLCompositeKey(recordKey: key, partitionKey: key)), + IEnumerable k => k.Select(key => key switch + { + string s => new AzureCosmosDBNoSQLCompositeKey(recordKey: s, partitionKey: s), + AzureCosmosDBNoSQLCompositeKey ck => ck, + _ => throw new ArgumentException($"Invalid key type '{key.GetType().Name}'.") + }), + _ => throw new UnreachableException() + }; #endregion } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions.cs index 047cd2b56b6c..55c5ddf9aad6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions.cs @@ -1,14 +1,16 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Text.Json; using System.Text.Json.Nodes; using Microsoft.Azure.Cosmos; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; /// -/// Options when creating a . +/// Options when creating a . /// public sealed class AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions { @@ -18,6 +20,7 @@ public sealed class AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions /// If not set, the default mapper that is provided by the Azure CosmosDB NoSQL client SDK will be used. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper? JsonObjectCustomMapper { get; init; } = null; /// @@ -56,4 +59,9 @@ public sealed class AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions. /// public bool Automatic { get; init; } = true; + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordMapper.cs index 4f6da286d51b..4a2bede9820e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordMapper.cs @@ -1,9 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Collections.Generic; +using System.Diagnostics; using System.Text.Json; using System.Text.Json.Nodes; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using MEAI = Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; @@ -11,45 +14,51 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; /// Class for mapping between a json node stored in Azure CosmosDB NoSQL and the consumer data model. /// /// The consumer data model to map to or from. -internal sealed class AzureCosmosDBNoSQLVectorStoreRecordMapper : IVectorStoreRecordMapper +internal sealed class AzureCosmosDBNoSQLVectorStoreRecordMapper(VectorStoreRecordModel model, JsonSerializerOptions? jsonSerializerOptions) + : ICosmosNoSQLMapper { - /// The JSON serializer options to use when converting between the data model and the Azure CosmosDB NoSQL record. - private readonly JsonSerializerOptions _jsonSerializerOptions; + private readonly VectorStoreRecordKeyPropertyModel _keyProperty = model.KeyProperty; - /// The storage property name of the key field of consumer data model. - private readonly string _keyStoragePropertyName; - - /// A dictionary that maps from a property name to the storage name that should be used when serializing it to json for data and vector properties. - private readonly Dictionary _storagePropertyNames = []; - - public AzureCosmosDBNoSQLVectorStoreRecordMapper( - string keyStoragePropertyName, - Dictionary storagePropertyNames, - JsonSerializerOptions jsonSerializerOptions) + public JsonObject MapFromDataToStorageModel(TRecord dataModel, MEAI.Embedding?[]? generatedEmbeddings) { - Verify.NotNull(jsonSerializerOptions); + var jsonObject = JsonSerializer.SerializeToNode(dataModel, jsonSerializerOptions)!.AsObject(); - this._keyStoragePropertyName = keyStoragePropertyName; - this._storagePropertyNames = storagePropertyNames; - this._jsonSerializerOptions = jsonSerializerOptions; - } + // The key property in Azure CosmosDB NoSQL is always named 'id'. + // But the external JSON serializer used just above isn't aware of that, and will produce a JSON object with another name, taking into + // account e.g. naming policies. TemporaryStorageName gets populated in the model builder - containing that name - once VectorStoreModelBuildingOptions.ReservedKeyPropertyName is set + RenameJsonProperty(jsonObject, this._keyProperty.TemporaryStorageName!, AzureCosmosDBNoSQLConstants.ReservedKeyPropertyName); - public JsonObject MapFromDataToStorageModel(TRecord dataModel) - { - var jsonObject = JsonSerializer.SerializeToNode(dataModel, this._jsonSerializerOptions)!.AsObject(); - - // Key property in Azure CosmosDB NoSQL has a reserved name. - RenameJsonProperty(jsonObject, this._keyStoragePropertyName, AzureCosmosDBNoSQLConstants.ReservedKeyPropertyName); + // Go over the vector properties; those which have an embedding generator configured on them will have embedding generators, overwrite + // the value in the JSON object with that. + if (generatedEmbeddings is not null) + { + for (var i = 0; i < model.VectorProperties.Count; i++) + { + if (generatedEmbeddings[i] is not null) + { + var property = model.VectorProperties[i]; + Debug.Assert(property.EmbeddingGenerator is not null); + var embedding = generatedEmbeddings[i]; + jsonObject[property.StorageName] = embedding switch + { + Embedding e => JsonSerializer.SerializeToNode(e.Vector, jsonSerializerOptions), + Embedding e => JsonSerializer.SerializeToNode(e.Vector, jsonSerializerOptions), + Embedding e => JsonSerializer.SerializeToNode(e.Vector, jsonSerializerOptions), + _ => throw new UnreachableException() + }; + } + } + } return jsonObject; } public TRecord MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options) { - // Rename key property for valid deserialization. - RenameJsonProperty(storageModel, AzureCosmosDBNoSQLConstants.ReservedKeyPropertyName, this._keyStoragePropertyName); + // See above comment. + RenameJsonProperty(storageModel, AzureCosmosDBNoSQLConstants.ReservedKeyPropertyName, this._keyProperty.TemporaryStorageName!); - return storageModel.Deserialize(this._jsonSerializerOptions)!; + return storageModel.Deserialize(jsonSerializerOptions)!; } #region private diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs index e18f176c2ea7..84587a1a9cdf 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs @@ -9,27 +9,33 @@ using System.Reflection; using System.Runtime.CompilerServices; using System.Text; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.ConnectorSupport.Filter; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; internal class AzureCosmosDBNoSqlFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; + private VectorStoreRecordModel _model = null!; private ParameterExpression _recordParameter = null!; private readonly Dictionary _parameters = new(); private readonly StringBuilder _sql = new(); - internal (string WhereClause, Dictionary Parameters) Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + internal (string WhereClause, Dictionary Parameters) Translate(LambdaExpression lambdaExpression, VectorStoreRecordModel model) { Debug.Assert(this._sql.Length == 0); - this._storagePropertyNames = storagePropertyNames; + this._model = model; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; + var preprocessor = new FilterTranslationPreprocessor { InlineCapturedVariables = false }; + var preprocessedExpression = preprocessor.Visit(lambdaExpression); + this.Translate(lambdaExpression.Body); + return (this._sql.ToString(), this._parameters); } @@ -139,8 +145,8 @@ private void TranslateMember(MemberExpression memberExpression) { switch (memberExpression) { - case var _ when this.TryGetPropertyAccess(memberExpression, out var column): - this._sql.Append(AzureCosmosDBNoSQLConstants.ContainerAlias).Append("[\"").Append(column).Append("\"]"); + case var _ when this.TryBindProperty(memberExpression, out var property): + this.GeneratePropertyAccess(property); return; // Identify captured lambda variables, translate to Cosmos parameters (@foo, @bar...) @@ -188,6 +194,11 @@ private void TranslateMethodCall(MethodCallExpression methodCall) { switch (methodCall) { + // Dictionary access for dynamic mapping (r => r["SomeString"] == "foo") + case MethodCallExpression when this.TryBindProperty(methodCall, out var property): + this.GeneratePropertyAccess(property); + return; + // Enumerable.Contains() case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains when contains.Method.DeclaringType == typeof(Enumerable): @@ -243,25 +254,63 @@ private void TranslateUnary(UnaryExpression unary) this._sql.Append(')'); return; + // Handle convert over member access, for dynamic dictionary access (r => (int)r["SomeInt"] == 8) + case ExpressionType.Convert when this.TryBindProperty(unary.Operand, out var property) && unary.Type == property.Type: + this.GeneratePropertyAccess(property); + return; + default: throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); } } - private bool TryGetPropertyAccess(Expression expression, [NotNullWhen(true)] out string? column) + protected virtual void GeneratePropertyAccess(VectorStoreRecordPropertyModel property) + => this._sql.Append(AzureCosmosDBNoSQLConstants.ContainerAlias).Append("[\"").Append(property.StorageName).Append("\"]"); + + private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out VectorStoreRecordPropertyModel? property) { - if (expression is MemberExpression member && member.Expression == this._recordParameter) + Type? convertedClrType = null; + + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unary) + { + expression = unary.Operand; + convertedClrType = unary.Type; + } + + var modelName = expression switch { - if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + // Regular member access for strongly-typed POCO binding (e.g. r => r.SomeInt == 8) + MemberExpression memberExpression when memberExpression.Expression == this._recordParameter + => memberExpression.Member.Name, + + // Dictionary lookup for weakly-typed dynamic binding (e.g. r => r["SomeInt"] == 8) + MethodCallExpression { - throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); - } + Method: { Name: "get_Item", DeclaringType: var declaringType }, + Arguments: [ConstantExpression { Value: string keyName }] + } methodCall when methodCall.Object == this._recordParameter && declaringType == typeof(Dictionary) + => keyName, - return true; + _ => null + }; + + if (modelName is null) + { + property = null; + return false; } - column = null; - return false; + if (!this._model.PropertyMap.TryGetValue(modelName, out property)) + { + throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name."); + } + + if (convertedClrType is not null && convertedClrType != property.Type) + { + throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convertedClrType.Name}', but its configured type is '{property.Type.Name}'."); + } + + return true; } private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? value) diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlVectorStoreModelBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlVectorStoreModelBuilder.cs new file mode 100644 index 000000000000..66b5a6fd6970 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlVectorStoreModelBuilder.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData.ConnectorSupport; + +namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; + +internal class AzureCosmosDBNoSQLVectorStoreModelBuilder() : VectorStoreRecordJsonModelBuilder(s_modelBuildingOptions) +{ + private static readonly HashSet s_supportedDataTypes = + [ + typeof(bool), + typeof(string), + typeof(int), + typeof(long), + typeof(float), + typeof(double), + typeof(DateTimeOffset) + ]; + + internal static readonly HashSet s_supportedVectorTypes = + [ + // Float32 + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?), + + // Uint8 + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?), + + // Int8 + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?), + ]; + + private static readonly VectorStoreRecordModelBuildingOptions s_modelBuildingOptions = new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + UsesExternalSerializer = true, + + // TODO: Cosmos supports other key types (int, Guid...) + SupportedKeyPropertyTypes = [typeof(string)], + SupportedDataPropertyTypes = s_supportedDataTypes, + SupportedEnumerableDataPropertyElementTypes = s_supportedDataTypes, + SupportedVectorPropertyTypes = s_supportedVectorTypes, + + ReservedKeyStorageName = AzureCosmosDBNoSQLConstants.ReservedKeyPropertyName, + }; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/Connectors.Memory.AzureCosmosDBNoSQL.csproj b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/Connectors.Memory.AzureCosmosDBNoSQL.csproj index a3dbe540101c..606b8a2fe866 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/Connectors.Memory.AzureCosmosDBNoSQL.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/Connectors.Memory.AzureCosmosDBNoSQL.csproj @@ -4,7 +4,7 @@ Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 $(NoWarn);NU5104 preview @@ -12,6 +12,7 @@ + @@ -25,6 +26,12 @@ + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/IAzureCosmosDBNoSQLVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/IAzureCosmosDBNoSQLVectorStoreRecordCollectionFactory.cs index 8d51dbb555b0..626677bccc40 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/IAzureCosmosDBNoSQLVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/IAzureCosmosDBNoSQLVectorStoreRecordCollectionFactory.cs @@ -25,5 +25,6 @@ IVectorStoreRecordCollection CreateVectorStoreRecordCollection +{ + /// + /// Maps from the consumer record data model to the storage model. + /// + JsonObject MapFromDataToStorageModel(TRecord dataModel, MEAI.Embedding?[]? generatedEmbeddings); + + /// + /// Maps from the storage model to the consumer record data model. + /// + TRecord MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Chroma/Connectors.Memory.Chroma.csproj b/dotnet/src/Connectors/Connectors.Memory.Chroma/Connectors.Memory.Chroma.csproj index e89013694aae..96ed5812cfe4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Chroma/Connectors.Memory.Chroma.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Chroma/Connectors.Memory.Chroma.csproj @@ -4,7 +4,7 @@ Microsoft.SemanticKernel.Connectors.Chroma $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 alpha diff --git a/dotnet/src/Connectors/Connectors.Memory.Chroma/README.md b/dotnet/src/Connectors/Connectors.Memory.Chroma/README.md index 42fc0f468a6b..47b9037200c6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Chroma/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.Chroma/README.md @@ -21,8 +21,6 @@ docker-compose up -d --build 3. Use Semantic Kernel with Chroma, using server local endpoint `http://localhost:8000`: - > See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. - ```csharp const string endpoint = "http://localhost:8000"; diff --git a/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs index cad9bd1048c2..7bc56615cb0c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs @@ -6,25 +6,27 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; using System.Text; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.ConnectorSupport.Filter; namespace Microsoft.SemanticKernel.Connectors; +#pragma warning disable MEVD9001 // Microsoft.Extensions.VectorData experimental connector-facing APIs + internal abstract class SqlFilterTranslator { - private readonly IReadOnlyDictionary _storagePropertyNames; + private readonly VectorStoreRecordModel _model; private readonly LambdaExpression _lambdaExpression; private readonly ParameterExpression _recordParameter; protected readonly StringBuilder _sql; internal SqlFilterTranslator( - IReadOnlyDictionary storagePropertyNames, + VectorStoreRecordModel model, LambdaExpression lambdaExpression, StringBuilder? sql = null) { - this._storagePropertyNames = storagePropertyNames; + this._model = model; this._lambdaExpression = lambdaExpression; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; @@ -40,10 +42,13 @@ internal void Translate(bool appendWhere) this._sql.Append("WHERE "); } - this.Translate(this._lambdaExpression.Body, null); + var preprocessor = new FilterTranslationPreprocessor { TransformCapturedVariablesToQueryParameterExpressions = true }; + var preprocessedExpression = preprocessor.Visit(this._lambdaExpression.Body); + + this.Translate(preprocessedExpression, isSearchCondition: true); } - protected void Translate(Expression? node, Expression? parent) + protected void Translate(Expression? node, bool isSearchCondition = false) { switch (node) { @@ -55,16 +60,20 @@ protected void Translate(Expression? node, Expression? parent) this.TranslateConstant(constant.Value); return; + case QueryParameterExpression { Name: var name, Value: var value }: + this.TranslateQueryParameter(name, value); + return; + case MemberExpression member: - this.TranslateMember(member, parent); + this.TranslateMember(member, isSearchCondition); return; case MethodCallExpression methodCall: - this.TranslateMethodCall(methodCall); + this.TranslateMethodCall(methodCall, isSearchCondition); return; case UnaryExpression unary: - this.TranslateUnary(unary); + this.TranslateUnary(unary, isSearchCondition); return; default: @@ -79,29 +88,29 @@ protected void TranslateBinary(BinaryExpression binary) { case ExpressionType.Equal when IsNull(binary.Right): this._sql.Append('('); - this.Translate(binary.Left, binary); + this.Translate(binary.Left); this._sql.Append(" IS NULL)"); return; case ExpressionType.NotEqual when IsNull(binary.Right): this._sql.Append('('); - this.Translate(binary.Left, binary); + this.Translate(binary.Left); this._sql.Append(" IS NOT NULL)"); return; case ExpressionType.Equal when IsNull(binary.Left): this._sql.Append('('); - this.Translate(binary.Right, binary); + this.Translate(binary.Right); this._sql.Append(" IS NULL)"); return; case ExpressionType.NotEqual when IsNull(binary.Left): this._sql.Append('('); - this.Translate(binary.Right, binary); + this.Translate(binary.Right); this._sql.Append(" IS NOT NULL)"); return; } this._sql.Append('('); - this.Translate(binary.Left, binary); + this.Translate(binary.Left, isSearchCondition: binary.NodeType is ExpressionType.AndAlso or ExpressionType.OrElse); this._sql.Append(binary.NodeType switch { @@ -119,12 +128,12 @@ protected void TranslateBinary(BinaryExpression binary) _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) }); - this.Translate(binary.Right, binary); + this.Translate(binary.Right, isSearchCondition: binary.NodeType is ExpressionType.AndAlso or ExpressionType.OrElse); + this._sql.Append(')'); static bool IsNull(Expression expression) - => expression is ConstantExpression { Value: null } - || (TryGetCapturedValue(expression, out _, out var capturedValue) && capturedValue is null); + => expression is ConstantExpression { Value: null } or QueryParameterExpression { Value: null }; } protected virtual void TranslateConstant(object? value) @@ -169,36 +178,35 @@ protected virtual void TranslateConstant(object? value) } } - private void TranslateMember(MemberExpression memberExpression, Expression? parent) + private void TranslateMember(MemberExpression memberExpression, bool isSearchCondition) { - switch (memberExpression) + if (this.TryBindProperty(memberExpression, out var property)) { - case var _ when this.TryGetColumn(memberExpression, out var column): - this.TranslateColumn(column, memberExpression, parent); - return; - - case var _ when TryGetCapturedValue(memberExpression, out var name, out var value): - this.TranslateCapturedVariable(name, value); - return; - - default: - throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + this.GenerateColumn(property.StorageName, isSearchCondition); + return; } + + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); } - protected virtual void TranslateColumn(string column, MemberExpression memberExpression, Expression? parent) - => this._sql.Append('"').Append(column).Append('"'); + protected virtual void GenerateColumn(string column, bool isSearchCondition = false) + => this._sql.Append('"').Append(column.Replace("\"", "\"\"")).Append('"'); - protected abstract void TranslateCapturedVariable(string name, object? capturedValue); + protected abstract void TranslateQueryParameter(string name, object? value); - private void TranslateMethodCall(MethodCallExpression methodCall) + private void TranslateMethodCall(MethodCallExpression methodCall, bool isSearchCondition = false) { switch (methodCall) { + // Dictionary access for dynamic mapping (r => r["SomeString"] == "foo") + case MethodCallExpression when this.TryBindProperty(methodCall, out var property): + this.GenerateColumn(property.StorageName, isSearchCondition); + return; + // Enumerable.Contains() case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains when contains.Method.DeclaringType == typeof(Enumerable): - this.TranslateContains(source, item, methodCall); + this.TranslateContains(source, item); return; // List.Contains() @@ -212,7 +220,7 @@ private void TranslateMethodCall(MethodCallExpression methodCall) Object: Expression source, Arguments: [var item] } when declaringType.GetGenericTypeDefinition() == typeof(List<>): - this.TranslateContains(source, item, methodCall); + this.TranslateContains(source, item); return; default: @@ -220,18 +228,18 @@ private void TranslateMethodCall(MethodCallExpression methodCall) } } - private void TranslateContains(Expression source, Expression item, MethodCallExpression parent) + private void TranslateContains(Expression source, Expression item) { switch (source) { // Contains over array column (r => r.Strings.Contains("foo")) - case var _ when this.TryGetColumn(source, out _): - this.TranslateContainsOverArrayColumn(source, item, parent); + case var _ when this.TryBindProperty(source, out _): + this.TranslateContainsOverArrayColumn(source, item); return; // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) case NewArrayExpression newArray: - this.Translate(item, parent); + this.Translate(item); this._sql.Append(" IN ("); var isFirst = true; @@ -246,15 +254,15 @@ private void TranslateContains(Expression source, Expression item, MethodCallExp this._sql.Append(", "); } - this.Translate(element, parent); + this.Translate(element); } this._sql.Append(')'); return; // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) - case var _ when TryGetCapturedValue(source, out _, out var value): - this.TranslateContainsOverCapturedArray(source, item, parent, value); + case QueryParameterExpression { Value: var value }: + this.TranslateContainsOverParameterizedArray(source, item, value); return; default: @@ -262,11 +270,11 @@ private void TranslateContains(Expression source, Expression item, MethodCallExp } } - protected abstract void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent); + protected abstract void TranslateContainsOverArrayColumn(Expression source, Expression item); - protected abstract void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value); + protected abstract void TranslateContainsOverParameterizedArray(Expression source, Expression item, object? value); - private void TranslateUnary(UnaryExpression unary) + private void TranslateUnary(UnaryExpression unary, bool isSearchCondition) { switch (unary.NodeType) { @@ -283,44 +291,63 @@ private void TranslateUnary(UnaryExpression unary) } this._sql.Append("(NOT "); - this.Translate(unary.Operand, unary); + this.Translate(unary.Operand, isSearchCondition); this._sql.Append(')'); return; + // Handle convert over member access, for dynamic dictionary access (r => (int)r["SomeInt"] == 8) + case ExpressionType.Convert when this.TryBindProperty(unary.Operand, out var property) && unary.Type == property.Type: + this.GenerateColumn(property.StorageName, isSearchCondition); + return; + default: throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); } } - private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) + private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out VectorStoreRecordPropertyModel? property) { - if (expression is MemberExpression member && member.Expression == this._recordParameter) + Type? convertedClrType = null; + + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unary) { - if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + expression = unary.Operand; + convertedClrType = unary.Type; + } + + var modelName = expression switch + { + // Regular member access for strongly-typed POCO binding (e.g. r => r.SomeInt == 8) + MemberExpression memberExpression when memberExpression.Expression == this._recordParameter + => memberExpression.Member.Name, + + // Dictionary lookup for weakly-typed dynamic binding (e.g. r => r["SomeInt"] == 8) + MethodCallExpression { - throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); - } + Method: { Name: "get_Item", DeclaringType: var declaringType }, + Arguments: [ConstantExpression { Value: string keyName }] + } methodCall when methodCall.Object == this._recordParameter && declaringType == typeof(Dictionary) + => keyName, - return true; + _ => null + }; + + if (modelName is null) + { + property = null; + return false; } - column = null; - return false; - } + if (!this._model.PropertyMap.TryGetValue(modelName, out property)) + { + throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name."); + } - private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? value) - { - if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + if (convertedClrType is not null && convertedClrType != property.Type) { - name = fieldInfo.Name; - value = fieldInfo.GetValue(constant.Value); - return true; + throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convertedClrType.Name}', but its configured type is '{property.Type.Name}'."); } - name = null; - value = null; - return false; + return true; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.DuckDB/Connectors.Memory.DuckDB.csproj b/dotnet/src/Connectors/Connectors.Memory.DuckDB/Connectors.Memory.DuckDB.csproj index d793de68dc3a..bc40eced6ff8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.DuckDB/Connectors.Memory.DuckDB.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.DuckDB/Connectors.Memory.DuckDB.csproj @@ -4,8 +4,9 @@ Microsoft.SemanticKernel.Connectors.DuckDB $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 alpha + false diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/Connectors.Memory.InMemory.csproj b/dotnet/src/Connectors/Connectors.Memory.InMemory/Connectors.Memory.InMemory.csproj index 1815446a8f80..68455976f42f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/Connectors.Memory.InMemory.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/Connectors.Memory.InMemory.csproj @@ -3,13 +3,14 @@ Microsoft.SemanticKernel.Connectors.InMemory $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 preview + @@ -27,6 +28,12 @@ + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryConstants.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryConstants.cs new file mode 100644 index 000000000000..af318ef12622 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryConstants.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.InMemory; + +internal static class InMemoryConstants +{ + internal const string VectorStoreSystemName = "inmemory"; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryKernelBuilderExtensions.cs index 85311ceba4fb..ade5a79b5861 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryKernelBuilderExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.InMemory; @@ -8,6 +9,7 @@ namespace Microsoft.SemanticKernel; /// /// Extension methods to register Data services on the . /// +[Obsolete("The IKernelBuilder extensions are being obsoleted, call the appropriate function on the Services property of your IKernelBuilder")] public static class InMemoryKernelBuilderExtensions { /// @@ -18,12 +20,12 @@ public static class InMemoryKernelBuilderExtensions /// The kernel builder. public static IKernelBuilder AddInMemoryVectorStore(this IKernelBuilder builder, string? serviceId = default) { - builder.Services.AddInMemoryVectorStore(serviceId); + builder.Services.AddInMemoryVectorStore(serviceId: serviceId); return builder; } /// - /// Register an InMemory and with the specified service ID. + /// Register an InMemory and with the specified service ID. /// /// The type of the key. /// The type of the record. @@ -38,6 +40,7 @@ public static IKernelBuilder AddInMemoryVectorStoreRecordCollection? options = default, string? serviceId = default) where TKey : notnull + where TRecord : notnull { builder.Services.AddInMemoryVectorStoreRecordCollection(collectionName, options, serviceId); return builder; diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryModelBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryModelBuilder.cs new file mode 100644 index 000000000000..096ba01d467a --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryModelBuilder.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.VectorData.ConnectorSupport; + +namespace Microsoft.SemanticKernel.Connectors.InMemory; + +internal class InMemoryModelBuilder() : VectorStoreRecordModelBuilder(ValidationOptions) +{ + internal static readonly VectorStoreRecordModelBuildingOptions ValidationOptions = new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + + // Disable property type validation + SupportedKeyPropertyTypes = null, + SupportedDataPropertyTypes = null, + SupportedEnumerableDataPropertyElementTypes = null, + SupportedVectorPropertyTypes = [typeof(ReadOnlyMemory), typeof(ReadOnlyMemory?)] + }; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryServiceCollectionExtensions.cs index b541aad65b98..56f1c8769bfe 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryServiceCollectionExtensions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.InMemory; @@ -17,17 +18,30 @@ public static class InMemoryServiceCollectionExtensions /// Register an InMemory with the specified service ID. /// /// The to register the on. + /// Optional options to further configure the . /// An optional service id to use as the service key. /// The service collection. - public static IServiceCollection AddInMemoryVectorStore(this IServiceCollection services, string? serviceId = default) + public static IServiceCollection AddInMemoryVectorStore(this IServiceCollection services, InMemoryVectorStoreOptions? options = default, string? serviceId = default) { + services.AddKeyedTransient( + serviceId, + (sp, obj) => + { + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; + + return new InMemoryVectorStore(options); + }); + services.AddKeyedSingleton(serviceId); services.AddKeyedSingleton(serviceId, (sp, obj) => sp.GetRequiredKeyedService(serviceId)); return services; } /// - /// Register an InMemory and with the specified service ID. + /// Register an InMemory and with the specified service ID. /// /// The type of the key. /// The type of the record. @@ -42,16 +56,20 @@ public static IServiceCollection AddInMemoryVectorStoreRecordCollection? options = default, string? serviceId = default) where TKey : notnull + where TRecord : notnull { services.AddKeyedSingleton>( serviceId, (sp, obj) => { - var selectedOptions = options ?? sp.GetService>(); - return (new InMemoryVectorStoreRecordCollection(collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; + return (new InMemoryVectorStoreRecordCollection(collectionName, options) as IVectorStoreRecordCollection)!; }); - services.AddKeyedSingleton>( + services.AddKeyedSingleton>( serviceId, (sp, obj) => { diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorRecordWrapper.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorRecordWrapper.cs new file mode 100644 index 000000000000..856b1c1a2d9b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorRecordWrapper.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; + +namespace Microsoft.SemanticKernel.Connectors.InMemory; + +internal readonly struct InMemoryVectorRecordWrapper(TRecord record) +{ + public TRecord Record { get; } = record; + public Dictionary> EmbeddingGeneratedVectors { get; } = new(); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStore.cs index 2db7013b0d27..85b623551980 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStore.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.InMemory; @@ -14,8 +15,13 @@ namespace Microsoft.SemanticKernel.Connectors.InMemory; /// public sealed class InMemoryVectorStore : IVectorStore { + private readonly InMemoryVectorStoreOptions _options; + + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; + /// Internal storage for the record collection. - private readonly ConcurrentDictionary> _internalCollection; + private readonly ConcurrentDictionary> _internalCollections; /// The data type of each collection, to enforce a single type per collection. private readonly ConcurrentDictionary _internalCollectionTypes = new(); @@ -23,23 +29,22 @@ public sealed class InMemoryVectorStore : IVectorStore /// /// Initializes a new instance of the class. /// - public InMemoryVectorStore() + /// Optional configuration options for this class + public InMemoryVectorStore(InMemoryVectorStoreOptions? options = default) { - this._internalCollection = new(); - } + this._options = options ?? new InMemoryVectorStoreOptions(); + this._internalCollections = new(); - /// - /// Initializes a new instance of the class. - /// - /// Allows passing in the dictionary used for storage, for testing purposes. - internal InMemoryVectorStore(ConcurrentDictionary> internalCollection) - { - this._internalCollection = internalCollection; + this._metadata = new() + { + VectorStoreSystemName = InMemoryConstants.VectorStoreSystemName, + }; } /// public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + where TRecord : notnull { if (this._internalCollectionTypes.TryGetValue(name, out var existingCollectionDataType) && existingCollectionDataType != typeof(TRecord)) { @@ -47,16 +52,47 @@ public IVectorStoreRecordCollection GetCollection( } var collection = new InMemoryVectorStoreRecordCollection( - this._internalCollection, + this._internalCollections, this._internalCollectionTypes, name, - new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; + new() + { + VectorStoreRecordDefinition = vectorStoreRecordDefinition, + EmbeddingGenerator = this._options.EmbeddingGenerator + }) as IVectorStoreRecordCollection; return collection!; } /// public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) { - return this._internalCollection.Keys.ToAsyncEnumerable(); + return this._internalCollections.Keys.ToAsyncEnumerable(); + } + + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + return this._internalCollections.ContainsKey(name) ? Task.FromResult(true) : Task.FromResult(false); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + this._internalCollections.TryRemove(name, out _); + this._internalCollectionTypes.TryRemove(name, out _); + return Task.CompletedTask; + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType == typeof(ConcurrentDictionary>) ? this._internalCollections : + serviceType.IsInstanceOfType(this) ? this : + null; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs index 6b33671cef9f..33e0afda04b7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs @@ -92,15 +92,15 @@ public static float ConvertScore(float score, string? distanceFunction) /// /// Filter the provided records using the provided filter definition. /// - /// The filter definition to filter the with. - /// The records to filter. + /// The filter definition to filter the with. + /// The records to filter. /// The filtered records. /// Thrown when an unsupported filter clause is encountered. - public static IEnumerable FilterRecords(VectorSearchFilter filter, IEnumerable records) + public static IEnumerable> FilterRecords(VectorSearchFilter filter, IEnumerable> recordWrappers) { - return records.Where(record => + return recordWrappers.Where(wrapper => { - if (record is null) + if (wrapper.Record is null) { return false; } @@ -114,7 +114,7 @@ public static IEnumerable FilterRecords(VectorSearchFilter fil { if (clause is EqualToFilterClause equalToFilter) { - result = result && CheckEqualTo(record, equalToFilter); + result = result && CheckEqualTo(wrapper.Record, equalToFilter); if (result == false) { @@ -123,7 +123,7 @@ public static IEnumerable FilterRecords(VectorSearchFilter fil } else if (clause is AnyTagEqualToFilterClause anyTagEqualToFilter) { - result = result && CheckAnyTagEqualTo(record, anyTagEqualToFilter); + result = result && CheckAnyTagEqualTo(wrapper.Record, anyTagEqualToFilter); if (result == false) { diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreExtensions.cs index b81f3e16d062..7975f783d5d3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreExtensions.cs @@ -32,6 +32,7 @@ public static async Task SerializeCollectionAsJsonAsync( Stream stream, JsonSerializerOptions? jsonSerializerOptions = null) where TKey : notnull + where TRecord : notnull { // Get collection and verify that it exists. var collection = vectorStore.GetCollection(collectionName); @@ -59,6 +60,7 @@ public static async Task SerializeCollectionAsJsonAsync( this InMemoryVectorStore vectorStore, Stream stream) where TKey : notnull + where TRecord : notnull { IVectorStoreRecordCollection? collection = null; diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreOptions.cs new file mode 100644 index 000000000000..e3fb530e1821 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreOptions.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel.Connectors.InMemory; + +/// +/// Options when creating a . +/// +public sealed class InMemoryVectorStoreOptions +{ + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs index 6fbcdf2633bf..d4b4646b1671 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs @@ -3,12 +3,16 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; namespace Microsoft.SemanticKernel.Connectors.InMemory; @@ -21,13 +25,10 @@ namespace Microsoft.SemanticKernel.Connectors.InMemory; public sealed class InMemoryVectorStoreRecordCollection : IVectorStoreRecordCollection #pragma warning restore CA1711 // Identifiers should not have incorrect suffix where TKey : notnull + where TRecord : notnull { - /// A set of types that vectors on the provided model may have. - private static readonly HashSet s_supportedVectorTypes = - [ - typeof(ReadOnlyMemory), - typeof(ReadOnlyMemory?), - ]; + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; /// The default options for vector search. private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); @@ -44,11 +45,8 @@ public sealed class InMemoryVectorStoreRecordCollection : IVector /// The name of the collection that this will access. private readonly string _collectionName; - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; - - /// A dictionary of vector properties on the provided model, keyed by the property name. - private readonly Dictionary _vectorProperties; + /// The model for this collection. + private readonly VectorStoreRecordModel _model; /// An function to look up vectors from the records. private readonly InMemoryVectorStoreVectorResolver _vectorResolver; @@ -59,28 +57,52 @@ public sealed class InMemoryVectorStoreRecordCollection : IVector /// /// Initializes a new instance of the class. /// - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// Optional configuration options for this class. - public InMemoryVectorStoreRecordCollection(string collectionName, InMemoryVectorStoreRecordCollectionOptions? options = default) + public InMemoryVectorStoreRecordCollection(string name, InMemoryVectorStoreRecordCollectionOptions? options = default) { // Verify. - Verify.NotNullOrWhiteSpace(collectionName); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + Verify.NotNullOrWhiteSpace(name); // Assign. - this._collectionName = collectionName; + this._collectionName = name; this._internalCollections = new(); this._internalCollectionTypes = new(); this._options = options ?? new InMemoryVectorStoreRecordCollectionOptions(); - this._propertyReader = new VectorStoreRecordPropertyReader(typeof(TRecord), this._options.VectorStoreRecordDefinition, new() { RequiresAtLeastOneVector = false, SupportsMultipleKeys = false, SupportsMultipleVectors = true }); - // Validate property types. - this._propertyReader.VerifyVectorProperties(s_supportedVectorTypes); - this._vectorProperties = this._propertyReader.VectorProperties.ToDictionary(x => x.DataModelPropertyName); + this._model = new InMemoryModelBuilder() + .Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator); // Assign resolvers. - this._vectorResolver = CreateVectorResolver(this._options.VectorResolver, this._vectorProperties); - this._keyResolver = CreateKeyResolver(this._options.KeyResolver, this._propertyReader.KeyProperty); + // TODO: Make generic to avoid boxing +#pragma warning disable MEVD9000 // KeyResolver and VectorResolver are experimental + this._keyResolver = this._options.KeyResolver is null + ? record => (TKey)this._model.KeyProperty.GetValueAsObject(record)! + : this._options.KeyResolver; + + this._vectorResolver = this._options.VectorResolver is not null + ? this._options.VectorResolver + : (vectorPropertyName, record) => + { + if (!this._model.PropertyMap.TryGetValue(vectorPropertyName, out var property)) + { + throw new InvalidOperationException($"The collection does not have a vector field named '{vectorPropertyName}', so vector search is not possible."); + } + + if (property is not VectorStoreRecordVectorPropertyModel vectorProperty) + { + throw new InvalidOperationException($"The property '{vectorPropertyName}' isn't a vector property."); + } + + return property.GetValueAsObject(record); + }; +#pragma warning restore MEVD9000 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + + this._collectionMetadata = new() + { + VectorStoreSystemName = InMemoryConstants.VectorStoreSystemName, + CollectionName = name + }; } /// @@ -102,7 +124,7 @@ internal InMemoryVectorStoreRecordCollection( } /// - public string CollectionName => this._collectionName; + public string Name => this._collectionName; /// public Task CollectionExistsAsync(CancellationToken cancellationToken = default) @@ -122,8 +144,8 @@ public Task CreateCollectionAsync(CancellationToken cancellationToken = default) return Task.FromException(new VectorStoreOperationException("Collection already exists.") { - VectorStoreType = "InMemory", - CollectionName = this.CollectionName, + VectorStoreSystemName = InMemoryConstants.VectorStoreSystemName, + CollectionName = this.Name, OperationName = "CreateCollection" }); } @@ -141,25 +163,38 @@ public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellatio public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { this._internalCollections.TryRemove(this._collectionName, out _); + this._internalCollectionTypes.TryRemove(this._collectionName, out _); return Task.CompletedTask; } /// public Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { + if (options?.IncludeVectors == true && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + var collectionDictionary = this.GetCollectionDictionary(); if (collectionDictionary.TryGetValue(key, out var record)) { - return Task.FromResult((TRecord?)record); + return Task.FromResult(((InMemoryVectorRecordWrapper)record).Record); } return Task.FromResult(default); } /// - public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetAsync(IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + Verify.NotNull(keys); + + if (options?.IncludeVectors == true && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + foreach (var key in keys) { var record = await this.GetAsync(key, options, cancellationToken).ConfigureAwait(false); @@ -181,8 +216,10 @@ public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) } /// - public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { + Verify.NotNull(keys); + var collectionDictionary = this.GetCollectionDictionary(); foreach (var key in keys) @@ -194,183 +231,317 @@ public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellat } /// - public Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + { + var keys = await this.UpsertAsync([record], cancellationToken).ConfigureAwait(false); + + return keys.Single(); + } + + /// + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { - Verify.NotNull(record); + Verify.NotNull(records); + + IReadOnlyList? recordsList = null; + + // If an embedding generator is defined, invoke it once per property for all records. + IReadOnlyList?[]? generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; + + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + // We have a property with embedding generation; materialize the records' enumerable if needed, to + // prevent multiple enumeration. + if (recordsList is null) + { + recordsList = records is IReadOnlyList r ? r : records.ToList(); + + if (recordsList.Count == 0) + { + return []; + } + + records = recordsList; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbeddings, ReadOnlyMemory>(records, cancellationToken, out var floatTask)) + { + generatedEmbeddings ??= new IReadOnlyList?[vectorPropertyCount]; + generatedEmbeddings[i] = (IReadOnlyList>)await floatTask.ConfigureAwait(false); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + + var keys = new List(); var collectionDictionary = this.GetCollectionDictionary(); - var key = (TKey)this._keyResolver(record)!; - collectionDictionary.AddOrUpdate(key!, record, (key, currentValue) => record); + var recordIndex = 0; + foreach (var record in records) + { + var key = (TKey)this._keyResolver(record)!; + var wrappedRecord = new InMemoryVectorRecordWrapper(record); + + if (generatedEmbeddings is not null) + { + for (var i = 0; i < this._model.VectorProperties.Count; i++) + { + if (generatedEmbeddings![i] is IReadOnlyList propertyEmbeddings) + { + var property = this._model.VectorProperties[i]; + + wrappedRecord.EmbeddingGeneratedVectors[property.ModelName] = propertyEmbeddings[recordIndex] switch + { + Embedding e => e.Vector, + _ => throw new UnreachableException() + }; + } + } + } + + collectionDictionary.AddOrUpdate(key!, wrappedRecord, (key, currentValue) => wrappedRecord); + + keys.Add(key); + + recordIndex++; + } - return Task.FromResult(key!); + return keys; } + #region Search + /// - public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull { - foreach (var record in records) + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + switch (vectorProperty.EmbeddingGenerator) { - yield return await this.UpsertAsync(record, cancellationToken).ConfigureAwait(false); + case IEmbeddingGenerator> generator: + { + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + } + + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + + default: + throw new InvalidOperationException( + InMemoryModelBuilder.ValidationOptions.SupportedVectorPropertyTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); } } /// -#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - Need to satisfy the interface which returns IAsyncEnumerable - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) -#pragma warning restore CS1998 + public IAsyncEnumerable> SearchEmbeddingAsync( + TVector vector, + int top, + VectorSearchOptions? options = null, + CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + VectorSearchOptions options, + CancellationToken cancellationToken = default) + where TVector : notnull { Verify.NotNull(vector); + Verify.NotLessThan(top, 1); if (vector is not ReadOnlyMemory floatVector) { throw new NotSupportedException($"The provided vector type {vector.GetType().FullName} is not supported by the InMemory Vector Store."); } - // Resolve options and get requested vector property or first as default. - var internalOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(internalOptions); + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } #pragma warning disable CS0618 // VectorSearchFilter is obsolete // Filter records using the provided filter before doing the vector comparison. - var allValues = this.GetCollectionDictionary().Values.Cast(); - var filteredRecords = internalOptions switch + var allValues = this.GetCollectionDictionary().Values.Cast>(); + var filteredRecords = options switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), { OldFilter: VectorSearchFilter legacyFilter } => InMemoryVectorStoreCollectionSearchMapping.FilterRecords(legacyFilter, allValues), - { Filter: Expression> newFilter } => allValues.AsQueryable().Where(newFilter), + { Filter: Expression> newFilter } => allValues.AsQueryable().Where(this.ConvertFilter(newFilter)), _ => allValues }; #pragma warning restore CS0618 // VectorSearchFilter is obsolete // Compare each vector in the filtered results with the provided vector. - var results = filteredRecords.Select(record => + var results = filteredRecords.Select, (TRecord record, float score)?>(wrapper => { - var vectorObject = this._vectorResolver(vectorProperty.DataModelPropertyName!, record); - if (vectorObject is not ReadOnlyMemory dbVector) + ReadOnlyMemory vector; + + if (vectorProperty.EmbeddingGenerator is null) + { + var vectorObject = this._vectorResolver(vectorProperty.ModelName!, wrapper.Record); + if (vectorObject is not ReadOnlyMemory dbVector) + { + return null; + } + vector = dbVector; + } + else { - return null; + vector = wrapper.EmbeddingGeneratedVectors[vectorProperty.ModelName]; } - var score = InMemoryVectorStoreCollectionSearchMapping.CompareVectors(floatVector.Span, dbVector.Span, vectorProperty.DistanceFunction); + var score = InMemoryVectorStoreCollectionSearchMapping.CompareVectors(floatVector.Span, vector.Span, vectorProperty.DistanceFunction); var convertedscore = InMemoryVectorStoreCollectionSearchMapping.ConvertScore(score, vectorProperty.DistanceFunction); - return (record, convertedscore); + return (wrapper.Record, convertedscore); }); // Get the non-null results since any record with a null vector results in a null result. var nonNullResults = results.Where(x => x.HasValue).Select(x => x!.Value); - // Calculate the total results count if requested. - long? count = null; - if (internalOptions.IncludeTotalCount) - { - count = nonNullResults.Count(); - } - // Sort the results appropriately for the selected distance function and get the right page of results . var sortedScoredResults = InMemoryVectorStoreCollectionSearchMapping.ShouldSortDescending(vectorProperty.DistanceFunction) ? nonNullResults.OrderByDescending(x => x.score) : nonNullResults.OrderBy(x => x.score); - var resultsPage = sortedScoredResults.Skip(internalOptions.Skip).Take(internalOptions.Top); + var resultsPage = sortedScoredResults.Skip(options.Skip).Take(top); // Build the response. - var vectorSearchResultList = resultsPage.Select(x => new VectorSearchResult((TRecord)x.record, x.score)).ToAsyncEnumerable(); - return new VectorSearchResults(vectorSearchResultList) { TotalCount = count }; + return resultsPage.Select(x => new VectorSearchResult((TRecord)x.record, x.score)).ToAsyncEnumerable(); } - /// - /// Get the collection dictionary from the internal storage, throws if it does not exist. - /// - /// The retrieved collection dictionary. - internal ConcurrentDictionary GetCollectionDictionary() - { - if (!this._internalCollections.TryGetValue(this._collectionName, out var collectionDictionary)) - { - throw new VectorStoreOperationException($"Call to vector store failed. Collection '{this._collectionName}' does not exist."); - } + /// + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); - return collectionDictionary; + #endregion Search + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType == typeof(ConcurrentDictionary>) ? this._internalCollections : + serviceType.IsInstanceOfType(this) ? this : + null; } - /// - /// Pick / create a vector resolver that will read a vector from a record in the store based on the vector name. - /// 1. If an override resolver is provided, use that. - /// 2. If the record type is create a resolver that looks up the vector in its dictionary. - /// 3. Otherwise, create a resolver that assumes the vector is a property directly on the record and use the record definition to determine the name. - /// - /// The override vector resolver if one was provided. - /// A dictionary of vector properties from the record definition. - /// The . - private static InMemoryVectorStoreVectorResolver CreateVectorResolver(InMemoryVectorStoreVectorResolver? overrideVectorResolver, Dictionary vectorProperties) + /// + public IAsyncEnumerable GetAsync(Expression> filter, int top, + GetFilteredRecordOptions? options = null, CancellationToken cancellationToken = default) { - // Custom resolver. - if (overrideVectorResolver is not null) - { - return overrideVectorResolver; - } + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); - // Generic data model resolver. - if (typeof(TRecord).IsGenericType && typeof(TRecord).GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>)) - { - return (vectorName, record) => - { - var genericDataModelRecord = record as VectorStoreGenericDataModel; - var vectorsDictionary = genericDataModelRecord!.Vectors; - if (vectorsDictionary != null && vectorsDictionary.TryGetValue(vectorName, out var vector)) - { - return vector; - } + options ??= new(); - throw new InvalidOperationException($"The collection does not have a vector field named '{vectorName}', so vector search is not possible."); - }; + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); } - // Default resolver. - var vectorPropertiesInfo = vectorProperties.Values - .Select(x => x.DataModelPropertyName) - .Select(x => typeof(TRecord).GetProperty(x) ?? throw new ArgumentException($"Vector property '{x}' was not found on {typeof(TRecord).Name}")) - .ToDictionary(x => x.Name); + var records = this.GetCollectionDictionary() + .Values + .Cast>() + .Select(x => x.Record) + .AsQueryable() + .Where(filter); - return (vectorName, record) => + if (options.OrderBy.Values.Count > 0) { - if (vectorPropertiesInfo.TryGetValue(vectorName, out var vectorPropertyInfo)) + var first = options.OrderBy.Values[0]; + var sorted = first.Ascending + ? records.OrderBy(first.PropertySelector) + : records.OrderByDescending(first.PropertySelector); + + for (int i = 1; i < options.OrderBy.Values.Count; i++) { - return vectorPropertyInfo.GetValue(record); + var next = options.OrderBy.Values[i]; + sorted = next.Ascending + ? sorted.ThenBy(next.PropertySelector) + : sorted.ThenByDescending(next.PropertySelector); } - throw new InvalidOperationException($"The collection does not have a vector field named '{vectorName}', so vector search is not possible."); - }; + records = sorted; + } + + return records + .Skip(options.Skip) + .Take(top) + .ToAsyncEnumerable(); } /// - /// Pick / create a key resolver that will read a key from a record in the store. - /// 1. If an override resolver is provided, use that. - /// 2. If the record type is create a resolver that reads the Key property from it. - /// 3. Otherwise, create a resolver that assumes the key is a property directly on the record and use the record definition to determine the name. + /// Get the collection dictionary from the internal storage, throws if it does not exist. /// - /// The override key resolver if one was provided. - /// They key property from the record definition. - /// The . - private static InMemoryVectorStoreKeyResolver CreateKeyResolver(InMemoryVectorStoreKeyResolver? overrideKeyResolver, VectorStoreRecordKeyProperty keyProperty) + /// The retrieved collection dictionary. + internal ConcurrentDictionary GetCollectionDictionary() { - // Custom resolver. - if (overrideKeyResolver is not null) + if (!this._internalCollections.TryGetValue(this._collectionName, out var collectionDictionary)) { - return overrideKeyResolver; + throw new VectorStoreOperationException($"Call to vector store failed. Collection '{this._collectionName}' does not exist."); } - // Generic data model resolver. - if (typeof(TRecord).IsGenericType && typeof(TRecord).GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>)) - { - return (record) => - { - var genericDataModelRecord = record as VectorStoreGenericDataModel; - return genericDataModelRecord!.Key; - }; - } + return collectionDictionary; + } + + /// + /// The user provides a filter expression accepting a Record, but we internally store it wrapped in an InMemoryVectorRecordWrapper. + /// This method converts a filter expression accepting a Record to one accepting an InMemoryVectorRecordWrapper. + /// + private Expression, bool>> ConvertFilter(Expression> recordFilter) + { + var wrapperParameter = Expression.Parameter(typeof(InMemoryVectorRecordWrapper), "w"); + var replacement = Expression.Property(wrapperParameter, nameof(InMemoryVectorRecordWrapper.Record)); + + return Expression.Lambda, bool>>( + new ParameterReplacer(recordFilter.Parameters.Single(), replacement).Visit(recordFilter.Body), + wrapperParameter); + } - // Default resolver. - var keyPropertyInfo = typeof(TRecord).GetProperty(keyProperty.DataModelPropertyName) ?? throw new ArgumentException($"Key property {keyProperty.DataModelPropertyName} not found on {typeof(TRecord).Name}"); - return (record) => (TKey)keyPropertyInfo.GetValue(record)!; + private sealed class ParameterReplacer(ParameterExpression originalRecordParameter, Expression replacementExpression) : ExpressionVisitor + { + protected override Expression VisitParameter(ParameterExpression node) + => node == originalRecordParameter ? replacementExpression : base.VisitParameter(node); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollectionOptions.cs index 5e5dfc7e166a..b93a3caf66de 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollectionOptions.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.InMemory; @@ -22,6 +24,11 @@ public sealed class InMemoryVectorStoreRecordCollectionOptions /// public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } + /// /// An optional function that can be used to look up vectors from a record. /// @@ -30,6 +37,7 @@ public sealed class InMemoryVectorStoreRecordCollectionOptions /// using reflection. This delegate can be used to provide a custom implementation if /// the vector properties are located somewhere else on the record. /// + [Experimental("MEVD9000")] public InMemoryVectorStoreVectorResolver? VectorResolver { get; init; } = null; /// @@ -40,5 +48,6 @@ public sealed class InMemoryVectorStoreRecordCollectionOptions /// using reflection. This delegate can be used to provide a custom implementation if /// the key property is located somewhere else on the record. /// + [Experimental("MEVD9000")] public InMemoryVectorStoreKeyResolver? KeyResolver { get; init; } = null; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Kusto/Connectors.Memory.Kusto.csproj b/dotnet/src/Connectors/Connectors.Memory.Kusto/Connectors.Memory.Kusto.csproj index dddcbcd37c5f..3d9ba43628ab 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Kusto/Connectors.Memory.Kusto.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Kusto/Connectors.Memory.Kusto.csproj @@ -3,10 +3,11 @@ Microsoft.SemanticKernel.Connectors.Kusto Microsoft.SemanticKernel.Connectors.Kusto - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 alpha $(NoWarn);NU5104 + false diff --git a/dotnet/src/Connectors/Connectors.Memory.Kusto/README.md b/dotnet/src/Connectors/Connectors.Memory.Kusto/README.md index f7c276c7e9c3..53a014bcd93f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Kusto/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.Kusto/README.md @@ -1,13 +1,12 @@ # Microsoft.SemanticKernel.Connectors.Kusto -This connector uses [Azure Data Explorer (Kusto)](https://learn.microsoft.com/en-us/azure/data-explorer/) to implement Semantic Memory. +This connector uses [Azure Data Explorer (Kusto)](https://learn.microsoft.com/azure/data-explorer/) to implement Semantic Memory. ## Quick Start -1. Create a cluster and database in Azure Data Explorer (Kusto) - see https://learn.microsoft.com/en-us/azure/data-explorer/create-cluster-and-database?tabs=free +1. Create a cluster and database in Azure Data Explorer (Kusto) - see https://learn.microsoft.com/azure/data-explorer/create-cluster-and-database?tabs=free 2. To use Kusto as a semantic memory store, use the following code: - > See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. ```csharp using Kusto.Data; @@ -37,9 +36,9 @@ The function is called `series_cosine_similarity_fl` and is located in the `Func Kusto is an append-only store. This means that when a fact is updated, the old fact is not deleted. This isn't a problem for the semantic memory connector, as it always utilizes the most recent fact. -This is made possible by using the [arg_max](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/arg-max-aggfunction) aggregation function in conjunction with the [ingestion_time](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ingestiontimefunction) function. +This is made possible by using the [arg_max](https://learn.microsoft.com/azure/data-explorer/kusto/query/arg-max-aggfunction) aggregation function in conjunction with the [ingestion_time](https://learn.microsoft.com/azure/data-explorer/kusto/query/ingestiontimefunction) function. However, users manually querying the underlying table should be aware of this behavior. ### Authentication -Please note that the authentication used in the example above is not recommended for production use. You can find more details here: https://learn.microsoft.com/en-us/azure/data-explorer/kusto/api/connection-strings/kusto +Please note that the authentication used in the example above is not recommended for production use. You can find more details here: https://learn.microsoft.com/azure/data-explorer/kusto/api/connection-strings/kusto diff --git a/dotnet/src/Connectors/Connectors.Memory.Milvus/Connectors.Memory.Milvus.csproj b/dotnet/src/Connectors/Connectors.Memory.Milvus/Connectors.Memory.Milvus.csproj index 9df2ba3e4db3..07b6696cea8a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Milvus/Connectors.Memory.Milvus.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Milvus/Connectors.Memory.Milvus.csproj @@ -4,7 +4,7 @@ Microsoft.SemanticKernel.Connectors.Milvus $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 enable alpha diff --git a/dotnet/src/Connectors/Connectors.Memory.Milvus/README.md b/dotnet/src/Connectors/Connectors.Memory.Milvus/README.md index b4d8e71d5a2c..cbdb1f99f35c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Milvus/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.Milvus/README.md @@ -19,7 +19,6 @@ docker-compose up -d ``` 3. Use Semantic Kernel with Milvus, connecting to `localhost` with the default (gRPC) port of 1536: - > See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. ```csharp using MilvusMemoryStore memoryStore = new("localhost"); diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/Connectors.Memory.MongoDB.csproj b/dotnet/src/Connectors/Connectors.Memory.MongoDB/Connectors.Memory.MongoDB.csproj index b091931d6e9e..bc85cd441115 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/Connectors.Memory.MongoDB.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/Connectors.Memory.MongoDB.csproj @@ -4,13 +4,14 @@ Microsoft.SemanticKernel.Connectors.MongoDB $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 preview + @@ -27,6 +28,12 @@ + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/IMongoDBVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/IMongoDBVectorStoreRecordCollectionFactory.cs index 0726870eb56c..d395b17b30e6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/IMongoDBVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/IMongoDBVectorStoreRecordCollectionFactory.cs @@ -22,5 +22,6 @@ public interface IMongoDBVectorStoreRecordCollectionFactory /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. /// The new instance of . IVectorStoreRecordCollection CreateVectorStoreRecordCollection(IMongoDatabase mongoDatabase, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) - where TKey : notnull; + where TKey : notnull + where TRecord : notnull; } diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBFilterTranslator.cs index 202908de1c0b..0280aee116c3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBFilterTranslator.cs @@ -7,8 +7,8 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.ConnectorSupport.Filter; using MongoDB.Bson; namespace Microsoft.SemanticKernel.Connectors.MongoDB; @@ -17,17 +17,20 @@ namespace Microsoft.SemanticKernel.Connectors.MongoDB; // Information specific to vector search pre-filter: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter internal class MongoDBFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; + private VectorStoreRecordModel _model = null!; private ParameterExpression _recordParameter = null!; - internal BsonDocument Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + internal BsonDocument Translate(LambdaExpression lambdaExpression, VectorStoreRecordModel model) { - this._storagePropertyNames = storagePropertyNames; + this._model = model; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - return this.Translate(lambdaExpression.Body); + var preprocessor = new FilterTranslationPreprocessor { InlineCapturedVariables = true }; + var preprocessedExpression = preprocessor.Visit(lambdaExpression.Body); + + return this.Translate(preprocessedExpression); } private BsonDocument Translate(Expression? node) @@ -46,9 +49,9 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual UnaryExpression { NodeType: ExpressionType.Not } not => this.TranslateNot(not), - // MemberExpression is generally handled within e.g. TranslateEqualityComparison; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) - MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _) - => this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))), + // Special handling for bool constant as the filter expression (r => r.Bool) + Expression when node.Type == typeof(bool) && this.TryBindProperty(node, out var property) + => this.GenerateEqualityComparison(property, value: true, ExpressionType.Equal), MethodCallExpression methodCall => this.TranslateMethodCall(methodCall), @@ -56,36 +59,37 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual }; private BsonDocument TranslateEqualityComparison(BinaryExpression binary) + => this.TryBindProperty(binary.Left, out var property) && binary.Right is ConstantExpression { Value: var rightConstant } + ? this.GenerateEqualityComparison(property, rightConstant, binary.NodeType) + : this.TryBindProperty(binary.Right, out property) && binary.Left is ConstantExpression { Value: var leftConstant } + ? this.GenerateEqualityComparison(property, leftConstant, binary.NodeType) + : throw new NotSupportedException("Invalid equality/comparison"); + + private BsonDocument GenerateEqualityComparison(VectorStoreRecordPropertyModel property, object? value, ExpressionType nodeType) { - if ((this.TryTranslateFieldAccess(binary.Left, out var storagePropertyName) && TryGetConstant(binary.Right, out var value)) - || (this.TryTranslateFieldAccess(binary.Right, out storagePropertyName) && TryGetConstant(binary.Left, out value))) + if (value is null) { - if (value is null) - { - throw new NotSupportedException("MongogDB does not support null checks in vector search pre-filters"); - } - - // Short form of equality (instead of $eq) - if (binary.NodeType is ExpressionType.Equal) - { - return new BsonDocument { [storagePropertyName] = BsonValue.Create(value) }; - } + throw new NotSupportedException("MongogDB does not support null checks in vector search pre-filters"); + } - var filterOperator = binary.NodeType switch - { - ExpressionType.NotEqual => "$ne", - ExpressionType.GreaterThan => "$gt", - ExpressionType.GreaterThanOrEqual => "$gte", - ExpressionType.LessThan => "$lt", - ExpressionType.LessThanOrEqual => "$lte", + // Short form of equality (instead of $eq) + if (nodeType is ExpressionType.Equal) + { + return new BsonDocument { [property.StorageName] = BsonValue.Create(value) }; + } - _ => throw new UnreachableException() - }; + var filterOperator = nodeType switch + { + ExpressionType.NotEqual => "$ne", + ExpressionType.GreaterThan => "$gt", + ExpressionType.GreaterThanOrEqual => "$gte", + ExpressionType.LessThan => "$lt", + ExpressionType.LessThanOrEqual => "$lte", - return new BsonDocument { [storagePropertyName] = new BsonDocument { [filterOperator] = BsonValue.Create(value) } }; - } + _ => throw new UnreachableException() + }; - throw new NotSupportedException("Invalid equality/comparison"); + return new BsonDocument { [property.StorageName] = new BsonDocument { [filterOperator] = BsonValue.Create(value) } }; } private BsonDocument TranslateAndOr(BinaryExpression andOr) @@ -130,9 +134,9 @@ private BsonDocument TranslateNot(UnaryExpression not) binary.Left, binary.Right)); - // Not over bool field (Filter => r => !r.Bool) - case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): - return this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(false))); + // Not over bool field (r => !r.Bool) + case var negated when negated.Type == typeof(bool) && this.TryBindProperty(negated, out var property): + return this.GenerateEqualityComparison(property, false, ExpressionType.Equal); } var operand = this.Translate(not.Operand); @@ -174,7 +178,7 @@ private BsonDocument TranslateContains(Expression source, Expression item) switch (source) { // Contains over array column (r => r.Strings.Contains("foo")) - case var _ when this.TryTranslateFieldAccess(source, out _): + case var _ when this.TryBindProperty(source, out _): throw new NotSupportedException("MongoDB does not support Contains within array fields ($elemMatch) in vector search pre-filters"); // Contains over inline enumerable @@ -183,7 +187,7 @@ private BsonDocument TranslateContains(Expression source, Expression item) for (var i = 0; i < newArray.Expressions.Count; i++) { - if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + if (newArray.Expressions[i] is not ConstantExpression { Value: var elementValue }) { throw new NotSupportedException("Invalid element in array"); } @@ -193,9 +197,7 @@ private BsonDocument TranslateContains(Expression source, Expression item) return ProcessInlineEnumerable(elements, item); - // Contains over captured enumerable (we inline) - case var _ when TryGetConstant(source, out var constantEnumerable) - && constantEnumerable is IEnumerable enumerable and not string: + case ConstantExpression { Value: IEnumerable enumerable and not string }: return ProcessInlineEnumerable(enumerable, item); default: @@ -204,14 +206,14 @@ private BsonDocument TranslateContains(Expression source, Expression item) BsonDocument ProcessInlineEnumerable(IEnumerable elements, Expression item) { - if (!this.TryTranslateFieldAccess(item, out var storagePropertyName)) + if (!this.TryBindProperty(item, out var property)) { throw new NotSupportedException("Unsupported item type in Contains"); } return new BsonDocument { - [storagePropertyName] = new BsonDocument + [property.StorageName] = new BsonDocument { ["$in"] = new BsonArray(from object? element in elements select BsonValue.Create(element)) } @@ -219,40 +221,49 @@ BsonDocument ProcessInlineEnumerable(IEnumerable elements, Expression item) } } - private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out VectorStoreRecordPropertyModel? property) { - if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) - { - if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) - { - throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); - } + Type? convertedClrType = null; - return true; + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unary) + { + expression = unary.Operand; + convertedClrType = unary.Type; } - storagePropertyName = null; - return false; - } + var modelName = expression switch + { + // Regular member access for strongly-typed POCO binding (e.g. r => r.SomeInt == 8) + MemberExpression memberExpression when memberExpression.Expression == this._recordParameter + => memberExpression.Member.Name, - private static bool TryGetConstant(Expression expression, out object? constantValue) - { - switch (expression) + // Dictionary lookup for weakly-typed dynamic binding (e.g. r => r["SomeInt"] == 8) + MethodCallExpression + { + Method: { Name: "get_Item", DeclaringType: var declaringType }, + Arguments: [ConstantExpression { Value: string keyName }] + } methodCall when methodCall.Object == this._recordParameter && declaringType == typeof(Dictionary) + => keyName, + + _ => null + }; + + if (modelName is null) { - case ConstantExpression { Value: var v }: - constantValue = v; - return true; + property = null; + return false; + } - // This identifies compiler-generated closure types which contain captured variables. - case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): - constantValue = fieldInfo.GetValue(constant.Value); - return true; + if (!this._model.PropertyMap.TryGetValue(modelName, out property)) + { + throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name."); + } - default: - constantValue = null; - return false; + if (convertedClrType is not null && convertedClrType != property.Type) + { + throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convertedClrType.Name}', but its configured type is '{property.Type.Name}'."); } + + return true; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryEntry.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryEntry.cs index ef2c88da699c..066424dc1d83 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryEntry.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryEntry.cs @@ -1,17 +1,18 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Diagnostics.CodeAnalysis; using Microsoft.SemanticKernel.Memory; using MongoDB.Bson; using MongoDB.Bson.Serialization.Attributes; namespace Microsoft.SemanticKernel.Connectors.MongoDB; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// A MongoDB memory entry. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and MongoDBVectorStore")] public sealed class MongoDBMemoryEntry { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryRecordMetadata.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryRecordMetadata.cs index cf7e4f7894d5..f7a3f5b55548 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryRecordMetadata.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryRecordMetadata.cs @@ -1,16 +1,18 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using Microsoft.SemanticKernel.Memory; using MongoDB.Bson.Serialization.Attributes; namespace Microsoft.SemanticKernel.Connectors.MongoDB; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// A MongoDB record metadata. /// #pragma warning disable CA1815 // Override equals and operator equals on value types -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and MongoDBVectorStore")] public struct MongoDBMemoryRecordMetadata #pragma warning restore CA1815 // Override equals and operator equals on value types { diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryStore.cs index 3e81afd0efde..72bfe4deead2 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBMemoryStore.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -12,10 +11,12 @@ namespace Microsoft.SemanticKernel.Connectors.MongoDB; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// An implementation of backed by a MongoDB database. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and MongoDBVectorStore")] public class MongoDBMemoryStore : IMemoryStore, IDisposable { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBServiceCollectionExtensions.cs index b8e89aab82da..43c28c93974f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBServiceCollectionExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.MongoDB; @@ -33,7 +34,10 @@ public static IServiceCollection AddMongoDBVectorStore( (sp, obj) => { var database = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; return new MongoDBVectorStore(database, options); }); @@ -70,7 +74,10 @@ public static IServiceCollection AddMongoDBVectorStore( var mongoClient = new MongoClient(settings); var database = mongoClient.GetDatabase(databaseName); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; return new MongoDBVectorStore(database, options); }); @@ -79,7 +86,7 @@ public static IServiceCollection AddMongoDBVectorStore( } /// - /// Register a MongoDB and with the specified service ID + /// Register a MongoDB and with the specified service ID /// and where the MongoDB is retrieved from the dependency injection container. /// /// The type of the record. @@ -93,15 +100,19 @@ public static IServiceCollection AddMongoDBVectorStoreRecordCollection( string collectionName, MongoDBVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedTransient>( serviceId, (sp, obj) => { var database = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new MongoDBVectorStoreRecordCollection(database, collectionName, selectedOptions); + return new MongoDBVectorStoreRecordCollection(database, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -110,7 +121,7 @@ public static IServiceCollection AddMongoDBVectorStoreRecordCollection( } /// - /// Register a MongoDB and with the specified service ID + /// Register a MongoDB and with the specified service ID /// and where the MongoDB is constructed using the provided and . /// /// The type of the record. @@ -128,6 +139,7 @@ public static IServiceCollection AddMongoDBVectorStoreRecordCollection( string databaseName, MongoDBVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedSingleton>( serviceId, @@ -139,9 +151,12 @@ public static IServiceCollection AddMongoDBVectorStoreRecordCollection( var mongoClient = new MongoClient(settings); var database = mongoClient.GetDatabase(databaseName); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new MongoDBVectorStoreRecordCollection(database, collectionName, selectedOptions); + return new MongoDBVectorStoreRecordCollection(database, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -150,14 +165,14 @@ public static IServiceCollection AddMongoDBVectorStoreRecordCollection( } /// - /// Also register the with the given as a . + /// Also register the with the given as a . /// /// The type of the data model that the collection should contain. /// The service collection to register on. /// The service id that the registrations should use. - private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) where TRecord : notnull { - services.AddKeyedTransient>( + services.AddKeyedTransient>( serviceId, (sp, obj) => { diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStore.cs index 27169e3e9557..56429596cd03 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStore.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.VectorData; using MongoDB.Driver; @@ -15,14 +16,20 @@ namespace Microsoft.SemanticKernel.Connectors.MongoDB; /// /// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. /// -public class MongoDBVectorStore : IVectorStore +public sealed class MongoDBVectorStore : IVectorStore { + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; + /// that can be used to manage the collections in MongoDB. private readonly IMongoDatabase _mongoDatabase; /// Optional configuration options for this class. private readonly MongoDBVectorStoreOptions _options; + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreRecordDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreRecordKeyProperty("Key", typeof(string))] }; + /// /// Initializes a new instance of the class. /// @@ -34,11 +41,18 @@ public MongoDBVectorStore(IMongoDatabase mongoDatabase, MongoDBVectorStoreOption this._mongoDatabase = mongoDatabase; this._options = options ?? new(); + + this._metadata = new() + { + VectorStoreSystemName = MongoDBConstants.VectorStoreSystemName, + VectorStoreName = mongoDatabase.DatabaseNamespace?.DatabaseName + }; } /// - public virtual IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + where TRecord : notnull { #pragma warning disable CS0618 // IMongoDBVectorStoreRecordCollectionFactoryß is obsolete if (this._options.VectorStoreCollectionFactory is not null) @@ -47,21 +61,20 @@ public virtual IVectorStoreRecordCollection GetCollection( + var recordCollection = new MongoDBVectorStoreRecordCollection( this._mongoDatabase, name, - new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; + new() + { + VectorStoreRecordDefinition = vectorStoreRecordDefinition, + EmbeddingGenerator = this._options.EmbeddingGenerator + }) as IVectorStoreRecordCollection; return recordCollection!; } /// - public virtual async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { using var cursor = await this._mongoDatabase .ListCollectionNamesAsync(cancellationToken: cancellationToken) @@ -75,4 +88,31 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat } } } + + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType == typeof(IMongoDatabase) ? this._mongoDatabase : + serviceType.IsInstanceOfType(this) ? this : + null; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionCreateMapping.cs index 3d6b634a14e1..5eebe8a91001 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionCreateMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionCreateMapping.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using MongoDB.Bson; namespace Microsoft.SemanticKernel.Connectors.MongoDB; @@ -16,25 +17,19 @@ internal static class MongoDBVectorStoreCollectionCreateMapping /// Returns an array of indexes to create for vector properties. /// /// Collection of vector properties for index creation. - /// A dictionary that maps from a property name to the storage name. - public static BsonArray GetVectorIndexFields( - IReadOnlyList vectorProperties, - Dictionary storagePropertyNames) + public static BsonArray GetVectorIndexFields(IReadOnlyList vectorProperties) { var indexArray = new BsonArray(); // Create separate index for each vector property foreach (var property in vectorProperties) { - // Use index name same as vector property name with underscore - var vectorPropertyName = storagePropertyNames[property.DataModelPropertyName]; - var indexDocument = new BsonDocument { { "type", "vector" }, { "numDimensions", property.Dimensions }, - { "path", vectorPropertyName }, - { "similarity", GetDistanceFunction(property.DistanceFunction, vectorPropertyName) }, + { "path", property.StorageName }, + { "similarity", GetDistanceFunction(property.DistanceFunction, property.ModelName) }, }; indexArray.Add(indexDocument); @@ -47,25 +42,19 @@ public static BsonArray GetVectorIndexFields( /// Returns an array of indexes to create for filterable data properties. /// /// Collection of data properties for index creation. - /// A dictionary that maps from a property name to the storage name. - public static BsonArray GetFilterableDataIndexFields( - IReadOnlyList dataProperties, - Dictionary storagePropertyNames) + public static BsonArray GetFilterableDataIndexFields(IReadOnlyList dataProperties) { var indexArray = new BsonArray(); // Create separate index for each data property foreach (var property in dataProperties) { - if (property.IsFilterable) + if (property.IsIndexed) { - // Use index name same as data property name with underscore - var dataPropertyName = storagePropertyNames[property.DataModelPropertyName]; - var indexDocument = new BsonDocument { { "type", "filter" }, - { "path", dataPropertyName }, + { "path", property.StorageName }, }; indexArray.Add(indexDocument); @@ -79,23 +68,18 @@ public static BsonArray GetFilterableDataIndexFields( /// Returns a list of of fields to index for full text search data properties. /// /// Collection of data properties for index creation. - /// A dictionary that maps from a property name to the storage name. - public static List GetFullTextSearchableDataIndexFields( - IReadOnlyList dataProperties, - Dictionary storagePropertyNames) + public static List GetFullTextSearchableDataIndexFields(IReadOnlyList dataProperties) { var fieldElements = new List(); // Create separate index for each data property foreach (var property in dataProperties) { - if (property.IsFullTextSearchable) + if (property.IsFullTextIndexed) { - var dataPropertyName = storagePropertyNames[property.DataModelPropertyName]; - - fieldElements.Add(new BsonElement(dataPropertyName, new BsonArray() + fieldElements.Add(new BsonElement(property.StorageName, new BsonArray() { - new BsonDocument() { { "type", "string" }, } + new BsonDocument() { { "type", "string" } } })); } } @@ -107,15 +91,11 @@ public static List GetFullTextSearchableDataIndexFields( /// More information about MongoDB distance functions here: . /// private static string GetDistanceFunction(string? distanceFunction, string vectorPropertyName) - { - var vectorPropertyDistanceFunction = MongoDBVectorStoreCollectionSearchMapping.GetVectorPropertyDistanceFunction(distanceFunction); - - return vectorPropertyDistanceFunction switch + => distanceFunction switch { - DistanceFunction.CosineSimilarity => "cosine", + DistanceFunction.CosineSimilarity or null => "cosine", DistanceFunction.DotProductSimilarity => "dotProduct", DistanceFunction.EuclideanDistance => "euclidean", _ => throw new InvalidOperationException($"Distance function '{distanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorPropertyName}' is not supported by the MongoDB VectorStore.") }; - } } diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs index 8e0258f0aa21..75d617155c5a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using MongoDB.Bson; namespace Microsoft.SemanticKernel.Connectors.MongoDB; @@ -13,20 +14,17 @@ namespace Microsoft.SemanticKernel.Connectors.MongoDB; /// internal static class MongoDBVectorStoreCollectionSearchMapping { - /// Returns distance function specified on vector property or default. - public static string GetVectorPropertyDistanceFunction(string? distanceFunction) => !string.IsNullOrWhiteSpace(distanceFunction) ? distanceFunction! : DistanceFunction.CosineSimilarity; - #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// /// Build MongoDB filter from the provided . /// /// The to build MongoDB filter from. - /// A dictionary that maps from a property name to the storage name. + /// The model. /// Thrown when the provided filter type is unsupported. /// Thrown when property name specified in filter doesn't exist. public static BsonDocument? BuildLegacyFilter( VectorSearchFilter vectorSearchFilter, - Dictionary storagePropertyNames) + VectorStoreRecordModel model) { const string EqualOperator = "$eq"; @@ -59,25 +57,25 @@ internal static class MongoDBVectorStoreCollectionSearchMapping nameof(EqualToFilterClause)])}"); } - if (!storagePropertyNames.TryGetValue(propertyName, out var storagePropertyName)) + if (!model.PropertyMap.TryGetValue(propertyName, out var property)) { throw new InvalidOperationException($"Property name '{propertyName}' provided as part of the filter clause is not a valid property name."); } - if (filter.Contains(storagePropertyName)) + if (filter.Contains(property.StorageName)) { - if (filter[storagePropertyName] is BsonDocument document && document.Contains(filterOperator)) + if (filter[property.StorageName] is BsonDocument document && document.Contains(filterOperator)) { throw new NotSupportedException( $"Filter with operator '{filterOperator}' is already added to '{propertyName}' property. " + "Multiple filters of the same type in the same property are not supported."); } - filter[storagePropertyName][filterOperator] = propertyValue; + filter[property.StorageName][filterOperator] = propertyValue; } else { - filter[storagePropertyName] = new BsonDocument() { [filterOperator] = propertyValue }; + filter[property.StorageName] = new BsonDocument() { [filterOperator] = propertyValue }; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreOptions.cs index 3382019ea1f6..f53bc2fb6b92 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Connectors.MongoDB; @@ -10,7 +11,12 @@ namespace Microsoft.SemanticKernel.Connectors.MongoDB; public sealed class MongoDBVectorStoreOptions { /// - /// An optional factory to use for constructing instances, if a custom record collection is required. + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } + + /// + /// An optional factory to use for constructing instances, if a custom record collection is required. /// [Obsolete("To control how collections are instantiated, extend your provider's IVectorStore implementation and override GetCollection()")] public IMongoDBVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs index dc2aa163a803..062c22d936e8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs @@ -2,15 +2,17 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Linq.Expressions; -using System.Reflection; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; using MongoDB.Bson; -using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver; using MEVD = Microsoft.Extensions.VectorData; @@ -19,13 +21,16 @@ namespace Microsoft.SemanticKernel.Connectors.MongoDB; /// /// Service for storing and retrieving vector records, that uses MongoDB as the underlying storage. /// +/// The data type of the record key. Can be either , or for dynamic mapping. /// The data model to use for adding, updating and retrieving data from storage. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class MongoDBVectorStoreRecordCollection : IVectorStoreRecordCollection, IKeywordHybridSearch +public sealed class MongoDBVectorStoreRecordCollection : IVectorStoreRecordCollection, IKeywordHybridSearch + where TKey : notnull + where TRecord : notnull #pragma warning restore CA1711 // Identifiers should not have incorrect suffix { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "MongoDB"; + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; /// Property name to be used for search similarity score value. private const string ScorePropertyName = "similarityScore"; @@ -49,60 +54,58 @@ public class MongoDBVectorStoreRecordCollection : IVectorStoreRecordCol private readonly MongoDBVectorStoreRecordCollectionOptions _options; /// Interface for mapping between a storage model, and the consumer record data model. - private readonly IVectorStoreRecordMapper _mapper; + private readonly IMongoDBMapper _mapper; - /// A dictionary that maps from a property name to the storage name that should be used when serializing it for data and vector properties. - private readonly Dictionary _storagePropertyNames; - - /// Collection of vector storage property names. - private readonly List _vectorStoragePropertyNames; - - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; + /// The model for this collection. + private readonly VectorStoreRecordModel _model; /// - public string CollectionName { get; } + public string Name { get; } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// that can be used to manage the collections in MongoDB. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// Optional configuration options for this class. public MongoDBVectorStoreRecordCollection( IMongoDatabase mongoDatabase, - string collectionName, + string name, MongoDBVectorStoreRecordCollectionOptions? options = default) { // Verify. Verify.NotNull(mongoDatabase); - Verify.NotNullOrWhiteSpace(collectionName); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.BsonDocumentCustomMapper is not null, MongoDBConstants.SupportedKeyTypes); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + Verify.NotNullOrWhiteSpace(name); + + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(object)) + { + throw new NotSupportedException("Only string keys are supported (and object for dynamic mapping)"); + } // Assign. this._mongoDatabase = mongoDatabase; - this._mongoCollection = mongoDatabase.GetCollection(collectionName); - this.CollectionName = collectionName; + this._mongoCollection = mongoDatabase.GetCollection(name); + this.Name = name; this._options = options ?? new MongoDBVectorStoreRecordCollectionOptions(); - this._propertyReader = new VectorStoreRecordPropertyReader(typeof(TRecord), this._options.VectorStoreRecordDefinition, new() { RequiresAtLeastOneVector = false, SupportsMultipleKeys = false, SupportsMultipleVectors = true }); + this._model = new MongoDBModelBuilder().Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator); + this._mapper = typeof(TRecord) == typeof(Dictionary) + ? (new MongoDBDynamicDataModelMapper(this._model) as IMongoDBMapper)! + : new MongoDBVectorStoreRecordMapper(this._model); - this._storagePropertyNames = GetStoragePropertyNames(this._propertyReader.Properties, typeof(TRecord)); - - // Use Mongo reserved key property name as storage key property name - this._storagePropertyNames[this._propertyReader.KeyPropertyName] = MongoDBConstants.MongoReservedKeyPropertyName; - - this._vectorStoragePropertyNames = this._propertyReader.VectorProperties.Select(property => this._storagePropertyNames[property.DataModelPropertyName]).ToList(); - - this._mapper = this.InitializeMapper(); + this._collectionMetadata = new() + { + VectorStoreSystemName = MongoDBConstants.VectorStoreSystemName, + VectorStoreName = mongoDatabase.DatabaseNamespace?.DatabaseName, + CollectionName = name + }; } /// - public virtual Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) => this.RunOperationAsync("ListCollectionNames", () => this.InternalCollectionExistsAsync(cancellationToken)); /// - public virtual async Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) { // The IMongoDatabase.CreateCollectionAsync "Creates a new collection if not already available". // To make sure that all the connectors are consistent, we throw when the collection exists. @@ -110,8 +113,9 @@ public virtual async Task CreateCollectionAsync(CancellationToken cancellationTo { throw new VectorStoreOperationException("Collection already exists.") { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = MongoDBConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = "CreateCollection" }; } @@ -120,56 +124,62 @@ public virtual async Task CreateCollectionAsync(CancellationToken cancellationTo } /// - public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { // The IMongoDatabase.CreateCollectionAsync "Creates a new collection if not already available". // So for CreateCollectionIfNotExistsAsync, we don't perform an additional check. await this.RunOperationAsync("CreateCollection", - () => this._mongoDatabase.CreateCollectionAsync(this.CollectionName, cancellationToken: cancellationToken)).ConfigureAwait(false); + () => this._mongoDatabase.CreateCollectionAsync(this.Name, cancellationToken: cancellationToken)).ConfigureAwait(false); await this.RunOperationWithRetryAsync( "CreateIndexes", this._options.MaxRetries, this._options.DelayInMilliseconds, - () => this.CreateIndexesAsync(this.CollectionName, cancellationToken), + () => this.CreateIndexesAsync(this.Name, cancellationToken), cancellationToken).ConfigureAwait(false); } /// - public virtual async Task DeleteAsync(string key, CancellationToken cancellationToken = default) + public async Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { - Verify.NotNullOrWhiteSpace(key); + var stringKey = this.GetStringKey(key); - await this.RunOperationAsync("DeleteOne", () => this._mongoCollection.DeleteOneAsync(this.GetFilterById(key), cancellationToken)) + await this.RunOperationAsync("DeleteOne", () => this._mongoCollection.DeleteOneAsync(this.GetFilterById(stringKey), cancellationToken)) .ConfigureAwait(false); } /// - public virtual async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public async Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { Verify.NotNull(keys); - await this.RunOperationAsync("DeleteMany", () => this._mongoCollection.DeleteManyAsync(this.GetFilterByIds(keys), cancellationToken)) + var stringKeys = keys is IEnumerable k ? k : keys.Cast(); + + await this.RunOperationAsync("DeleteMany", () => this._mongoCollection.DeleteManyAsync(this.GetFilterByIds(stringKeys), cancellationToken)) .ConfigureAwait(false); } /// - public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = default) - => this.RunOperationAsync("DropCollection", () => this._mongoDatabase.DropCollectionAsync(this.CollectionName, cancellationToken)); + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + => this.RunOperationAsync("DropCollection", () => this._mongoDatabase.DropCollectionAsync(this.Name, cancellationToken)); /// - public virtual async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { - Verify.NotNullOrWhiteSpace(key); - const string OperationName = "Find"; + var stringKey = this.GetStringKey(key); + var includeVectors = options?.IncludeVectors ?? false; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } var record = await this.RunOperationAsync(OperationName, async () => { using var cursor = await this - .FindAsync(this.GetFilterById(key), options, cancellationToken) + .FindAsync(this.GetFilterById(stringKey), options, cancellationToken) .ConfigureAwait(false); return await cursor.SingleOrDefaultAsync(cancellationToken).ConfigureAwait(false); @@ -181,15 +191,16 @@ public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = } return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + MongoDBConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, () => this._mapper.MapFromStorageToDataModel(record, new() { IncludeVectors = includeVectors })); } /// - public virtual async IAsyncEnumerable GetBatchAsync( - IEnumerable keys, + public async IAsyncEnumerable GetAsync( + IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -197,8 +208,15 @@ public virtual async IAsyncEnumerable GetBatchAsync( const string OperationName = "Find"; + if (options?.IncludeVectors == true && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + var stringKeys = keys is IEnumerable k ? k : keys.Cast(); + using var cursor = await this - .FindAsync(this.GetFilterByIds(keys), options, cancellationToken) + .FindAsync(this.GetFilterByIds(stringKeys), options, cancellationToken) .ConfigureAwait(false); while (await cursor.MoveNextAsync(cancellationToken).ConfigureAwait(false)) @@ -208,8 +226,9 @@ public virtual async IAsyncEnumerable GetBatchAsync( if (record is not null) { yield return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + MongoDBConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, () => this._mapper.MapFromStorageToDataModel(record, new())); } @@ -218,80 +237,174 @@ public virtual async IAsyncEnumerable GetBatchAsync( } /// - public virtual Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { Verify.NotNull(record); const string OperationName = "ReplaceOne"; + Embedding?[]? generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; + + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbedding, ReadOnlyMemory>(record, cancellationToken, out var floatTask)) + { + generatedEmbeddings ??= new Embedding?[vectorPropertyCount]; + generatedEmbeddings[i] = await floatTask.ConfigureAwait(false); + } + else if (vectorProperty.TryGenerateEmbedding, ReadOnlyMemory>(record, cancellationToken, out var doubleTask)) + { + generatedEmbeddings ??= new Embedding?[vectorPropertyCount]; + generatedEmbeddings[i] = await doubleTask.ConfigureAwait(false); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + var replaceOptions = new ReplaceOptions { IsUpsert = true }; var storageModel = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + MongoDBConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, - () => this._mapper.MapFromDataToStorageModel(record)); + () => this._mapper.MapFromDataToStorageModel(record, generatedEmbeddings)); var key = storageModel[MongoDBConstants.MongoReservedKeyPropertyName].AsString; - return this.RunOperationAsync(OperationName, async () => + return await this.RunOperationAsync(OperationName, async () => { await this._mongoCollection .ReplaceOneAsync(this.GetFilterById(key), storageModel, replaceOptions, cancellationToken) .ConfigureAwait(false); - return key; - }); + return (TKey)(object)key; + }).ConfigureAwait(false); } /// - public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { Verify.NotNull(records); var tasks = records.Select(record => this.UpsertAsync(record, cancellationToken)); var results = await Task.WhenAll(tasks).ConfigureAwait(false); + return results.Where(r => r is not null).ToList(); + } + + #region Search - foreach (var result in results) + /// + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + MEVD.VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + switch (vectorProperty.EmbeddingGenerator) { - if (result is not null) + case IEmbeddingGenerator> generator: + { + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + } + + case IEmbeddingGenerator> generator: { - yield return result; + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; } + + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + + default: + throw new InvalidOperationException( + MongoDBConstants.SupportedVectorTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); } } /// - public virtual async Task> VectorizedSearchAsync( + public IAsyncEnumerable> SearchEmbeddingAsync( TVector vector, + int top, MEVD.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private async IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + MEVD.VectorSearchOptions options, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TVector : notnull { Array vectorArray = VerifyVectorParam(vector); + Verify.NotLessThan(top, 1); - var searchOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(searchOptions); - var vectorPropertyName = this._storagePropertyNames[vectorProperty.DataModelPropertyName]; + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } #pragma warning disable CS0618 // VectorSearchFilter is obsolete - var filter = searchOptions switch + var filter = options switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(legacyFilter, this._storagePropertyNames), - { Filter: Expression> newFilter } => new MongoDBFilterTranslator().Translate(newFilter, this._storagePropertyNames), + { OldFilter: VectorSearchFilter legacyFilter } => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(legacyFilter, this._model), + { Filter: Expression> newFilter } => new MongoDBFilterTranslator().Translate(newFilter, this._model), _ => null }; #pragma warning restore CS0618 // Constructing a query to fetch "skip + top" total items // to perform skip logic locally, since skip option is not part of API. - var itemsAmount = searchOptions.Skip + searchOptions.Top; + var itemsAmount = options.Skip + top; var numCandidates = this._options.NumCandidates ?? itemsAmount * MongoDBConstants.DefaultNumCandidatesRatio; var searchQuery = MongoDBVectorStoreCollectionSearchMapping.GetSearchQuery( vectorArray, this._options.VectorIndexName, - vectorPropertyName, + vectorProperty.StorageName, itemsAmount, numCandidates, filter); @@ -302,7 +415,7 @@ public virtual async Task> VectorizedSearchAsync> VectorizedSearchAsync(pipeline, cancellationToken: cancellationToken) .ConfigureAwait(false); - return new VectorSearchResults(this.EnumerateAndMapSearchResultsAsync(cursor, searchOptions.Skip, searchOptions.IncludeVectors, cancellationToken)); + return this.EnumerateAndMapSearchResultsAsync(cursor, options.Skip, options.IncludeVectors, cancellationToken); + }, + cancellationToken).ConfigureAwait(false); + + await foreach (var result in results.ConfigureAwait(false)) + { + yield return result; + } + } + + /// + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, MEVD.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); + + #endregion Search + + /// + public async IAsyncEnumerable GetAsync(Expression> filter, int top, + GetFilteredRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); + + options ??= new(); + + // Translate the filter now, so if it fails, we throw immediately. + var translatedFilter = new MongoDBFilterTranslator().Translate(filter, this._model); + SortDefinition? sortDefinition = null; + if (options.OrderBy.Values.Count > 0) + { + sortDefinition = Builders.Sort.Combine( + options.OrderBy.Values.Select(pair => + { + var storageName = this._model.GetDataOrKeyProperty(pair.PropertySelector).StorageName; + + return pair.Ascending + ? Builders.Sort.Ascending(storageName) + : Builders.Sort.Descending(storageName); + })); + } + + using IAsyncCursor cursor = await this.RunOperationWithRetryAsync( + "GetAsync", + this._options.MaxRetries, + this._options.DelayInMilliseconds, + async () => + { + return await this._mongoCollection.FindAsync(translatedFilter, + new() + { + Limit = top, + Skip = options.Skip, + Sort = sortDefinition + }, + cancellationToken: cancellationToken).ConfigureAwait(false); }, cancellationToken).ConfigureAwait(false); + + while (await cursor.MoveNextAsync(cancellationToken).ConfigureAwait(false)) + { + foreach (var response in cursor.Current) + { + var record = VectorStoreErrorHandler.RunModelConversion( + MongoDBConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + "GetAsync", + () => this._mapper.MapFromStorageToDataModel(response, new() { IncludeVectors = options.IncludeVectors })); + + yield return record; + } + } } /// - public async Task> HybridSearchAsync(TVector vector, ICollection keywords, HybridSearchOptions? options = null, CancellationToken cancellationToken = default) + public async IAsyncEnumerable> HybridSearchAsync(TVector vector, ICollection keywords, int top, HybridSearchOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Array vectorArray = VerifyVectorParam(vector); + Verify.NotLessThan(top, 1); - var searchOptions = options ?? s_defaultKeywordVectorizedHybridSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(new() { VectorProperty = searchOptions.VectorProperty }); - var vectorPropertyName = this._storagePropertyNames[vectorProperty.DataModelPropertyName]; - var textDataProperty = this._propertyReader.GetFullTextDataPropertyOrSingle(searchOptions.AdditionalProperty); - var textDataPropertyName = this._storagePropertyNames[textDataProperty.DataModelPropertyName]; + options ??= s_defaultKeywordVectorizedHybridSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(new() { VectorProperty = options.VectorProperty }); + var textDataProperty = this._model.GetFullTextDataPropertyOrSingle(options.AdditionalProperty); #pragma warning disable CS0618 // VectorSearchFilter is obsolete - var filter = searchOptions switch + var filter = options switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(legacyFilter, this._storagePropertyNames), - { Filter: Expression> newFilter } => new MongoDBFilterTranslator().Translate(newFilter, this._storagePropertyNames), + { OldFilter: VectorSearchFilter legacyFilter } => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(legacyFilter, this._model), + { Filter: Expression> newFilter } => new MongoDBFilterTranslator().Translate(newFilter, this._model), _ => null }; #pragma warning restore CS0618 // Constructing a query to fetch "skip + top" total items - // to perform skip logic locally, since skip option is not part of API. - var itemsAmount = searchOptions.Skip + searchOptions.Top; + // to perform skip logic locally, since skip option is not part of API. + var itemsAmount = options.Skip + top; var numCandidates = this._options.NumCandidates ?? itemsAmount * MongoDBConstants.DefaultNumCandidatesRatio; BsonDocument[] pipeline = MongoDBVectorStoreCollectionSearchMapping.GetHybridSearchPipeline( vectorArray, keywords, - this.CollectionName, + this.Name, this._options.VectorIndexName, this._options.FullTextSearchIndexName, - vectorPropertyName, - textDataPropertyName, + vectorProperty.StorageName, + textDataProperty.StorageName, ScorePropertyName, DocumentPropertyName, itemsAmount, numCandidates, filter); - return await this.RunOperationWithRetryAsync( + var results = await this.RunOperationWithRetryAsync( "KeywordVectorizedHybridSearch", this._options.MaxRetries, this._options.DelayInMilliseconds, @@ -368,9 +551,28 @@ public async Task> HybridSearchAsync(TVect .AggregateAsync(pipeline, cancellationToken: cancellationToken) .ConfigureAwait(false); - return new VectorSearchResults(this.EnumerateAndMapSearchResultsAsync(cursor, searchOptions.Skip, searchOptions.IncludeVectors, cancellationToken)); + return this.EnumerateAndMapSearchResultsAsync(cursor, options.Skip, options.IncludeVectors, cancellationToken); }, cancellationToken).ConfigureAwait(false); + + await foreach (var result in results.ConfigureAwait(false)) + { + yield return result; + } + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType == typeof(IMongoDatabase) ? this._mongoDatabase : + serviceType == typeof(IMongoCollection) ? this._mongoCollection : + serviceType.IsInstanceOfType(this) ? this : + null; } #region private @@ -387,13 +589,8 @@ private async Task CreateIndexesAsync(string collectionName, CancellationToken c { var fieldsArray = new BsonArray(); - fieldsArray.AddRange(MongoDBVectorStoreCollectionCreateMapping.GetVectorIndexFields( - this._propertyReader.VectorProperties, - this._storagePropertyNames)); - - fieldsArray.AddRange(MongoDBVectorStoreCollectionCreateMapping.GetFilterableDataIndexFields( - this._propertyReader.DataProperties, - this._storagePropertyNames)); + fieldsArray.AddRange(MongoDBVectorStoreCollectionCreateMapping.GetVectorIndexFields(this._model.VectorProperties)); + fieldsArray.AddRange(MongoDBVectorStoreCollectionCreateMapping.GetFilterableDataIndexFields(this._model.DataProperties)); if (fieldsArray.Count > 0) { @@ -411,9 +608,7 @@ private async Task CreateIndexesAsync(string collectionName, CancellationToken c { var fieldsDocument = new BsonDocument(); - fieldsDocument.AddRange(MongoDBVectorStoreCollectionCreateMapping.GetFullTextSearchableDataIndexFields( - this._propertyReader.DataProperties, - this._storagePropertyNames)); + fieldsDocument.AddRange(MongoDBVectorStoreCollectionCreateMapping.GetFullTextSearchableDataIndexFields(this._model.DataProperties)); if (fieldsDocument.ElementCount > 0) { @@ -455,13 +650,13 @@ private async Task> FindAsync(FilterDefinition 0) + if (!includeVectors) { - foreach (var vectorPropertyName in this._vectorStoragePropertyNames) + foreach (var vectorPropertyName in this._model.VectorProperties) { projectionDefinition = projectionDefinition is not null ? - projectionDefinition.Exclude(vectorPropertyName) : - projectionBuilder.Exclude(vectorPropertyName); + projectionDefinition.Exclude(vectorPropertyName.StorageName) : + projectionBuilder.Exclude(vectorPropertyName.StorageName); } } @@ -490,8 +685,9 @@ private async IAsyncEnumerable> EnumerateAndMapSearc { var score = response[ScorePropertyName].AsDouble; var record = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + MongoDBConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, () => this._mapper.MapFromStorageToDataModel(response[DocumentPropertyName].AsBsonDocument, new() { IncludeVectors = includeVectors })); @@ -511,7 +707,7 @@ private FilterDefinition GetFilterByIds(IEnumerable ids) private async Task InternalCollectionExistsAsync(CancellationToken cancellationToken) { - var filter = new BsonDocument("name", this.CollectionName); + var filter = new BsonDocument("name", this.Name); var options = new ListCollectionNamesOptions { Filter = filter }; using var cursor = await this._mongoDatabase.ListCollectionNamesAsync(options, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -529,8 +725,9 @@ private async Task RunOperationAsync(string operationName, Func operation) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = MongoDBConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } @@ -546,8 +743,9 @@ private async Task RunOperationAsync(string operationName, Func> o { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = MongoDBConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } @@ -577,8 +775,9 @@ private async Task RunOperationWithRetryAsync( { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = MongoDBConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } @@ -611,8 +810,9 @@ private async Task RunOperationWithRetryAsync( { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = MongoDBConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } @@ -624,55 +824,6 @@ private async Task RunOperationWithRetryAsync( throw new VectorStoreOperationException("Retry logic failed."); } - /// - /// Gets storage property names taking into account BSON serialization attributes. - /// - private static Dictionary GetStoragePropertyNames( - IReadOnlyList properties, - Type dataModel) - { - var storagePropertyNames = new Dictionary(); - - foreach (var property in properties) - { - var propertyInfo = dataModel.GetProperty(property.DataModelPropertyName); - string propertyName; - - if (propertyInfo != null) - { - var bsonElementAttribute = propertyInfo.GetCustomAttribute(); - - propertyName = bsonElementAttribute?.ElementName ?? property.DataModelPropertyName; - } - else - { - propertyName = property.DataModelPropertyName; - } - - storagePropertyNames[property.DataModelPropertyName] = propertyName; - } - - return storagePropertyNames; - } - - /// - /// Returns custom mapper, generic data model mapper or default record mapper. - /// - private IVectorStoreRecordMapper InitializeMapper() - { - if (this._options.BsonDocumentCustomMapper is not null) - { - return this._options.BsonDocumentCustomMapper; - } - - if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) - { - return (new MongoDBGenericDataModelMapper(this._propertyReader.RecordDefinition) as IVectorStoreRecordMapper)!; - } - - return new MongoDBVectorStoreRecordMapper(this._propertyReader); - } - private static Array VerifyVectorParam(TVector vector) { Verify.NotNull(vector); @@ -688,5 +839,17 @@ private static Array VerifyVectorParam(TVector vector) typeof(ReadOnlyMemory).FullName])}") }; } + + private string GetStringKey(TKey key) + { + Verify.NotNull(key); + + var stringKey = key as string ?? throw new UnreachableException("string key should have been validated during model building"); + + Verify.NotNullOrWhiteSpace(stringKey, nameof(key)); + + return stringKey; + } + #endregion } diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollectionOptions.cs index bc591f87cdc0..f2d9dee36f97 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollectionOptions.cs @@ -1,18 +1,21 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; using MongoDB.Bson; namespace Microsoft.SemanticKernel.Connectors.MongoDB; /// -/// Options when creating a . +/// Options when creating a . /// public sealed class MongoDBVectorStoreRecordCollectionOptions { /// /// Gets or sets an optional custom mapper to use when converting between the data model and the MongoDB BSON object. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper? BsonDocumentCustomMapper { get; init; } = null; /// @@ -25,6 +28,11 @@ public sealed class MongoDBVectorStoreRecordCollectionOptions /// public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } + /// /// Vector index name to use. If null, the default "vector_index" name will be used. /// diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/README.md b/dotnet/src/Connectors/Connectors.Memory.MongoDB/README.md index 4a6ddcda3483..923dd3d9252c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/README.md @@ -6,44 +6,8 @@ This connector uses [MongoDB Atlas Vector Search](https://www.mongodb.com/produc 1. Create [Atlas cluster](https://www.mongodb.com/docs/atlas/getting-started/) -2. Create a [collection](https://www.mongodb.com/docs/atlas/atlas-ui/collections/) +2. Create a Mongo DB Vector Store using instructions on the [Microsoft Learn site](https://learn.microsoft.com/semantic-kernel/concepts/vector-store-connectors/out-of-the-box-connectors/mongodb-connector). -3. Create [Vector Search Index](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for the collection. The index has to be defined on a field called `embedding`. For example: - -``` -{ - "type": "vectorSearch", - "fields": [ - { - "numDimensions": , - "path": "embedding", - "similarity": "euclidean | cosine | dotProduct", - "type": "vector" - } - ] -} -``` - -4. Create the MongoDB memory store - > See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. - -```csharp -var connectionString = "MONGODB ATLAS CONNECTION STRING" -MongoDBMemoryStore memoryStore = new(connectionString, "MyDatabase"); - -var embeddingGenerator = new OpenAITextEmbeddingGenerationService("text-embedding-ada-002", apiKey); - -SemanticTextMemory textMemory = new(memoryStore, embeddingGenerator); - -var memoryPlugin = kernel.ImportPluginFromObject(new TextMemoryPlugin(textMemory)); -``` +3. Use the [getting started instructions](https://learn.microsoft.com/semantic-kernel/concepts/vector-store-connectors/?pivots=programming-language-csharp#getting-started-with-vector-store-connectors) on the Microsoft Leearn site to learn more about using the vector store. > Guide to find the connection string: https://www.mongodb.com/docs/manual/reference/connection-string/ - -## Important Notes - -### Vector search indexes - -In this version, vector search index management is outside of `MongoDBMemoryStore` scope. -Creation and maintenance of the indexes have to be done by the user. Please note that deleting a collection -(`memoryStore.DeleteCollectionAsync`) will delete the index as well. diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Connectors.Memory.Pinecone.csproj b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Connectors.Memory.Pinecone.csproj index 3b51b03623ae..e085ae887d7e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Connectors.Memory.Pinecone.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Connectors.Memory.Pinecone.csproj @@ -4,13 +4,14 @@ Microsoft.SemanticKernel.Connectors.Pinecone $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 preview + @@ -20,6 +21,11 @@ + + + + + @@ -29,6 +35,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/ConfigureIndexRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/ConfigureIndexRequest.cs index cf6278a4154f..7a2a0c8a5f2c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/ConfigureIndexRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/ConfigureIndexRequest.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; using System.Text.Json.Serialization; @@ -10,7 +10,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// This operation specifies the pod type and number of replicas for an index. /// See https://docs.pinecone.io/reference/configure_index /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class ConfigureIndexRequest { public string IndexName { get; set; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DeleteIndexRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DeleteIndexRequest.cs index 63746ca62f88..e01c8e77e4b4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DeleteIndexRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DeleteIndexRequest.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// Deletes an index and all its data. /// See https://docs.pinecone.io/reference/delete_index /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class DeleteIndexRequest { public static DeleteIndexRequest Create(string indexName) diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DeleteRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DeleteRequest.cs index 3b45cbddbdec..82149dedd2cc 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DeleteRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DeleteRequest.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net.Http; using System.Text; @@ -13,7 +13,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// DeleteRequest /// See https://docs.pinecone.io/reference/delete_post /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class DeleteRequest { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DescribeIndexRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DescribeIndexRequest.cs index 9955adebb078..fa49d7c5cf8c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DescribeIndexRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DescribeIndexRequest.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// Get information about an index. /// See https://docs.pinecone.io/reference/describe_index /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class DescribeIndexRequest { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DescribeIndexStatsRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DescribeIndexStatsRequest.cs index 90481a5d8129..184c6f9914bf 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DescribeIndexStatsRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/DescribeIndexStatsRequest.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Net.Http; using System.Text.Json.Serialization; @@ -11,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// DescribeIndexStatsRequest /// See https://docs.pinecone.io/reference/describe_index_stats_post /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class DescribeIndexStatsRequest { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/FetchRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/FetchRequest.cs index da10d86c7d3c..caf1d1fb7c20 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/FetchRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/FetchRequest.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net.Http; using System.Text.Json.Serialization; @@ -12,7 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// FetchRequest /// See https://docs.pinecone.io/reference/fetch /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class FetchRequest { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/FetchResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/FetchResponse.cs index afa2534e65d8..f740ea341b81 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/FetchResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/FetchResponse.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text.Json.Serialization; @@ -13,7 +13,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// FetchResponse /// See https://docs.pinecone.io/reference/fetch /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class FetchResponse { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/ListIndexesRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/ListIndexesRequest.cs index 77919ddcf15a..9ed2d9bd0dc3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/ListIndexesRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/ListIndexesRequest.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// ListIndexesRequest /// See https://docs.pinecone.io/reference/list_indexes /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class ListIndexesRequest { public static ListIndexesRequest Create() diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/QueryRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/QueryRequest.cs index a75309d2c266..cadffd29c5c0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/QueryRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/QueryRequest.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Net.Http; using System.Text.Json.Serialization; @@ -12,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// QueryRequest /// See https://docs.pinecone.io/reference/query /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class QueryRequest { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/QueryResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/QueryResponse.cs index f7ede69bccad..5baee10108b6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/QueryResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/QueryResponse.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -12,7 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// QueryResponse /// See https://docs.pinecone.io/reference/query /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class QueryResponse { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpdateVectorRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpdateVectorRequest.cs index a8988b01f7eb..b3d2d8e124b5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpdateVectorRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpdateVectorRequest.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Net.Http; using System.Text.Json.Serialization; @@ -14,7 +13,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// If a set_metadata is included, the values of the fields specified in it will be added or overwrite the previous value. /// See https://docs.pinecone.io/reference/update /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class UpdateVectorRequest { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpsertRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpsertRequest.cs index f4a407d0ef66..8d4ad95213c3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpsertRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpsertRequest.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Net.Http; using System.Text.Json.Serialization; @@ -11,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// UpsertRequest /// See https://docs.pinecone.io/reference/upsert /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class UpsertRequest { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpsertResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpsertResponse.cs index 616e2746b5c8..066075b42883 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpsertResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Http/ApiSchema/UpsertResponse.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -11,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// UpsertResponse /// See https://docs.pinecone.io/reference/upsert /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class UpsertResponse { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeClient.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeClient.cs index d9642325b51c..1375353de9f0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeClient.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; @@ -11,7 +10,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Interface for a Pinecone client /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public interface IPineconeClient { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeMemoryStore.cs index c23c52b68760..b17b85fec853 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeMemoryStore.cs @@ -2,18 +2,19 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Connectors.Pinecone; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// Interface for Pinecone memory store that extends the memory store interface /// to add support for namespaces /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public interface IPineconeMemoryStore : IMemoryStore { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeVectorStoreRecordCollectionFactory.cs index 25b6efae42de..347412b3e861 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeVectorStoreRecordCollectionFactory.cs @@ -22,5 +22,6 @@ public interface IPineconeVectorStoreRecordCollectionFactory /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. /// The new instance of . IVectorStoreRecordCollection CreateVectorStoreRecordCollection(Sdk.PineconeClient pineconeClient, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) - where TKey : notnull; + where TKey : notnull + where TRecord : notnull; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexDefinition.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexDefinition.cs index b6924f0a3ea3..de0094d1b8b7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexDefinition.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexDefinition.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; using System.Text; using System.Text.Json.Serialization; @@ -11,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// Used to create a new index. /// See https://docs.pinecone.io/reference/create_index /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public class IndexDefinition { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexMetadataConfig.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexMetadataConfig.cs index 0152df016d9f..733542974bff 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexMetadataConfig.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexMetadataConfig.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; using Microsoft.SemanticKernel.Memory; @@ -10,7 +10,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Configuration for the behavior of Pinecone's internal metadata index. By default, all metadata is indexed; when metadata_config is present, only specified metadata fields are indexed. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public class MetadataIndexConfig { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexMetric.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexMetric.cs index 0cfb54e5bfc2..a9ebc98663e6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexMetric.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexMetric.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Runtime.Serialization; using System.Text.Json.Serialization; @@ -10,7 +10,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// The vector similarity metric of the index /// /// The vector similarity metric of the index -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] [JsonConverter(typeof(JsonStringEnumConverter))] public enum IndexMetric { diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexNamespaceStats.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexNamespaceStats.cs index 1099d27ace24..7f166ac997b1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexNamespaceStats.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexNamespaceStats.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -8,7 +8,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Index namespace parameters. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public class IndexNamespaceStats { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexState.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexState.cs index 6998fb950c0d..7d1af1db0c39 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexState.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexState.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Runtime.Serialization; using System.Text.Json.Serialization; @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// The current status of a index. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] [JsonConverter(typeof(JsonStringEnumConverter))] public enum IndexState { diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexStats.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexStats.cs index fcecab22f888..eb3b4c53e1be 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexStats.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexStats.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Index parameters. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public class IndexStats { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexStatus.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexStatus.cs index 80a805118ceb..8e027aa47417 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexStatus.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexStatus.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -8,7 +8,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Status of the index. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public class IndexStatus { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/OperationType.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/OperationType.cs index 2ab24223baa3..be2da0ef5306 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/OperationType.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/OperationType.cs @@ -1,10 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; namespace Microsoft.SemanticKernel.Connectors.Pinecone; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal enum OperationType { Upsert, diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/PineconeIndex.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/PineconeIndex.cs index a018d6163446..451401722b70 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/PineconeIndex.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/PineconeIndex.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -8,7 +8,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Index entity. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public sealed class PineconeIndex { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/PodType.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/PodType.cs index a1ca09720942..828a23952ec5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/PodType.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/PodType.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Reflection; using System.Runtime.Serialization; @@ -13,7 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Pod type of the index, see https://docs.pinecone.io/docs/indexes#pods-pod-types-and-pod-sizes. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] [JsonConverter(typeof(PodTypeJsonConverter))] public enum PodType { @@ -108,7 +107,7 @@ public enum PodType } #pragma warning disable CA1812 // Avoid uninstantiated internal classes -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] internal sealed class PodTypeJsonConverter : JsonConverter #pragma warning restore CA1812 { diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/Query.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/Query.cs index deed1cd706fe..382a12e76f62 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/Query.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/Query.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -10,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Query parameters for use in a query request. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public sealed class Query { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/SparseVectorData.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/SparseVectorData.cs index 93fdf06fa985..4b9f606b0714 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/SparseVectorData.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/SparseVectorData.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -10,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Represents a sparse vector data, which is a list of indices and a list of corresponding values, both of the same length. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public class SparseVectorData { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeClient.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeClient.cs index 169a90a6273c..35305269f1cb 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeClient.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net; using System.Net.Http; @@ -20,7 +19,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// A client for the Pinecone API /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public sealed class PineconeClient : IPineconeClient { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeConstants.cs new file mode 100644 index 000000000000..a8b4cfadfed2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeConstants.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.Pinecone; + +internal static class PineconeConstants +{ + internal const string VectorStoreSystemName = "pinecone"; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeDocument.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeDocument.cs index 147d7d96b741..8a619813dcf5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeDocument.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeDocument.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; @@ -13,7 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Pinecone Document entity. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public class PineconeDocument { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeDocumentExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeDocumentExtensions.cs index 5cba6c227717..8c22facbeb50 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeDocumentExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeDocumentExtensions.cs @@ -2,17 +2,18 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Text.Json; using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Connectors.Pinecone; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// Extensions for class. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public static class PineconeDocumentExtensions { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeFilterTranslator.cs index 54a7202eaa07..7725e76ed633 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeFilterTranslator.cs @@ -7,8 +7,8 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.ConnectorSupport.Filter; using Pinecone; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -20,17 +20,20 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; // as we sometimes need to extend the collection (with for example another condition). internal class PineconeFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; + private VectorStoreRecordModel _model = null!; private ParameterExpression _recordParameter = null!; - internal Metadata Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + internal Metadata Translate(LambdaExpression lambdaExpression, VectorStoreRecordModel model) { - this._storagePropertyNames = storagePropertyNames; + this._model = model; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - return this.Translate(lambdaExpression.Body); + var preprocessor = new FilterTranslationPreprocessor { InlineCapturedVariables = true }; + var preprocessedExpression = preprocessor.Visit(lambdaExpression.Body); + + return this.Translate(preprocessedExpression); } private Metadata Translate(Expression? node) @@ -49,9 +52,9 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual UnaryExpression { NodeType: ExpressionType.Not } not => this.TranslateNot(not), - // MemberExpression is generally handled within e.g. TranslateEqualityComparison; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) - MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _) - => this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))), + // Special handling for bool constant as the filter expression (r => r.Bool) + Expression when node.Type == typeof(bool) && this.TryBindProperty(node, out var property) + => this.GenerateEqualityComparison(property, true, ExpressionType.Equal), MethodCallExpression methodCall => this.TranslateMethodCall(methodCall), @@ -59,36 +62,37 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual }; private Metadata TranslateEqualityComparison(BinaryExpression binary) + => this.TryBindProperty(binary.Left, out var property) && binary.Right is ConstantExpression { Value: var rightConstant } + ? this.GenerateEqualityComparison(property, rightConstant, binary.NodeType) + : this.TryBindProperty(binary.Right, out property) && binary.Left is ConstantExpression { Value: var leftConstant } + ? this.GenerateEqualityComparison(property, leftConstant, binary.NodeType) + : throw new NotSupportedException("Invalid equality/comparison"); + + private Metadata GenerateEqualityComparison(VectorStoreRecordPropertyModel property, object? value, ExpressionType nodeType) { - if ((this.TryTranslateFieldAccess(binary.Left, out var storagePropertyName) && TryGetConstant(binary.Right, out var value)) - || (this.TryTranslateFieldAccess(binary.Right, out storagePropertyName) && TryGetConstant(binary.Left, out value))) + if (value is null) { - if (value is null) - { - throw new NotSupportedException("Pincone does not support null checks in vector search pre-filters"); - } - - // Short form of equality (instead of $eq) - if (binary.NodeType is ExpressionType.Equal) - { - return new Metadata { [storagePropertyName] = ToMetadata(value) }; - } + throw new NotSupportedException("Pincone does not support null checks in vector search pre-filters"); + } - var filterOperator = binary.NodeType switch - { - ExpressionType.NotEqual => "$ne", - ExpressionType.GreaterThan => "$gt", - ExpressionType.GreaterThanOrEqual => "$gte", - ExpressionType.LessThan => "$lt", - ExpressionType.LessThanOrEqual => "$lte", + // Short form of equality (instead of $eq) + if (nodeType is ExpressionType.Equal) + { + return new Metadata { [property.StorageName] = ToMetadata(value) }; + } - _ => throw new UnreachableException() - }; + var filterOperator = nodeType switch + { + ExpressionType.NotEqual => "$ne", + ExpressionType.GreaterThan => "$gt", + ExpressionType.GreaterThanOrEqual => "$gte", + ExpressionType.LessThan => "$lt", + ExpressionType.LessThanOrEqual => "$lte", - return new Metadata { [storagePropertyName] = new Metadata { [filterOperator] = ToMetadata(value) } }; - } + _ => throw new UnreachableException() + }; - throw new NotSupportedException("Invalid equality/comparison"); + return new Metadata { [property.StorageName] = new Metadata { [filterOperator] = ToMetadata(value) } }; } private Metadata TranslateAndOr(BinaryExpression andOr) @@ -134,8 +138,8 @@ private Metadata TranslateNot(UnaryExpression not) binary.Right)); // Not over bool field (Filter => r => !r.Bool) - case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): - return this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(false))); + case Expression when not.Operand.Type == typeof(bool) && this.TryBindProperty(not.Operand, out var property): + return this.GenerateEqualityComparison(property, false, ExpressionType.Equal); } var operand = this.Translate(not.Operand); @@ -177,7 +181,7 @@ private Metadata TranslateContains(Expression source, Expression item) switch (source) { // Contains over array column (r => r.Strings.Contains("foo")) - case var _ when this.TryTranslateFieldAccess(source, out _): + case var _ when this.TryBindProperty(source, out _): throw new NotSupportedException("Pinecone does not support Contains within array fields ($elemMatch) in vector search pre-filters"); // Contains over inline enumerable @@ -186,7 +190,7 @@ private Metadata TranslateContains(Expression source, Expression item) for (var i = 0; i < newArray.Expressions.Count; i++) { - if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + if (newArray.Expressions[i] is not ConstantExpression { Value: var elementValue }) { throw new NotSupportedException("Invalid element in array"); } @@ -196,9 +200,7 @@ private Metadata TranslateContains(Expression source, Expression item) return ProcessInlineEnumerable(elements, item); - // Contains over captured enumerable (we inline) - case var _ when TryGetConstant(source, out var constantEnumerable) - && constantEnumerable is IEnumerable enumerable and not string: + case ConstantExpression { Value: IEnumerable enumerable and not string }: return ProcessInlineEnumerable(enumerable, item); default: @@ -207,14 +209,14 @@ private Metadata TranslateContains(Expression source, Expression item) Metadata ProcessInlineEnumerable(IEnumerable elements, Expression item) { - if (!this.TryTranslateFieldAccess(item, out var storagePropertyName)) + if (!this.TryBindProperty(item, out var property)) { throw new NotSupportedException("Unsupported item type in Contains"); } return new Metadata { - [storagePropertyName] = new Metadata + [property.StorageName] = new Metadata { ["$in"] = new MetadataValue(elements.Cast().Select(ToMetadata).ToList()) } @@ -222,41 +224,50 @@ Metadata ProcessInlineEnumerable(IEnumerable elements, Expression item) } } - private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out VectorStoreRecordPropertyModel? property) { - if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) - { - if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) - { - throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); - } + Type? convertedClrType = null; - return true; + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unary) + { + expression = unary.Operand; + convertedClrType = unary.Type; } - storagePropertyName = null; - return false; - } + var modelName = expression switch + { + // Regular member access for strongly-typed POCO binding (e.g. r => r.SomeInt == 8) + MemberExpression memberExpression when memberExpression.Expression == this._recordParameter + => memberExpression.Member.Name, - private static bool TryGetConstant(Expression expression, out object? constantValue) - { - switch (expression) + // Dictionary lookup for weakly-typed dynamic binding (e.g. r => r["SomeInt"] == 8) + MethodCallExpression + { + Method: { Name: "get_Item", DeclaringType: var declaringType }, + Arguments: [ConstantExpression { Value: string keyName }] + } methodCall when methodCall.Object == this._recordParameter && declaringType == typeof(Dictionary) + => keyName, + + _ => null + }; + + if (modelName is null) { - case ConstantExpression { Value: var v }: - constantValue = v; - return true; + property = null; + return false; + } - // This identifies compiler-generated closure types which contain captured variables. - case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): - constantValue = fieldInfo.GetValue(constant.Value); - return true; + if (!this._model.PropertyMap.TryGetValue(modelName, out property)) + { + throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name."); + } - default: - constantValue = null; - return false; + if (convertedClrType is not null && convertedClrType != property.Type) + { + throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convertedClrType.Name}', but its configured type is '{property.Type.Name}'."); } + + return true; } private static MetadataValue? ToMetadata(object? value) diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeGenericDataModelMapper.cs deleted file mode 100644 index df783a230498..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeGenericDataModelMapper.cs +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using Microsoft.Extensions.VectorData; -using Pinecone; - -namespace Microsoft.SemanticKernel.Connectors.Pinecone; - -/// -/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Pinecone. -/// -internal sealed class PineconeGenericDataModelMapper : IVectorStoreRecordMapper, Vector> -{ - private readonly VectorStoreRecordPropertyReader _propertyReader; - - /// - /// Initializes a new instance of the class. - /// - /// A helper to access property information for the current data model and record definition. - public PineconeGenericDataModelMapper( - VectorStoreRecordPropertyReader propertyReader) - { - Verify.NotNull(propertyReader); - - // Validate property types. - propertyReader.VerifyKeyProperties(PineconeVectorStoreRecordFieldMapping.s_supportedKeyTypes); - propertyReader.VerifyDataProperties(PineconeVectorStoreRecordFieldMapping.s_supportedDataTypes, PineconeVectorStoreRecordFieldMapping.s_supportedEnumerableDataElementTypes); - propertyReader.VerifyVectorProperties(PineconeVectorStoreRecordFieldMapping.s_supportedVectorTypes); - - // Assign. - this._propertyReader = propertyReader; - } - - /// - public Vector MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - var metadata = new Metadata(); - - // Map data properties. - foreach (var dataProperty in this._propertyReader.DataProperties) - { - if (dataModel.Data.TryGetValue(dataProperty.DataModelPropertyName, out var propertyValue)) - { - var propertyStorageName = this._propertyReader.GetStoragePropertyName(dataProperty.DataModelPropertyName); - metadata[propertyStorageName] = propertyValue is not null - ? PineconeVectorStoreRecordFieldMapping.ConvertToMetadataValue(propertyValue) - : null; - } - } - - // Map vector property. - if (dataModel.Vectors.Count != 1) - { - throw new VectorStoreRecordMappingException($"Exactly one vector is supported by the Pinecone connector, but the provided data model contains {dataModel.Vectors.Count}."); - } - - if (!dataModel.Vectors.TryGetValue(this._propertyReader.FirstVectorPropertyName!, out var valuesObject) || valuesObject is not ReadOnlyMemory values) - { - throw new VectorStoreRecordMappingException($"Vector property '{this._propertyReader.FirstVectorPropertyName}' on provided record of type {nameof(VectorStoreGenericDataModel)} must be of type ReadOnlyMemory and not null."); - } - - // TODO: what about sparse values? - var result = new Vector - { - Id = dataModel.Key, - Values = values, - Metadata = metadata, - SparseValues = null - }; - - return result; - } - - /// - public VectorStoreGenericDataModel MapFromStorageToDataModel(Vector storageModel, StorageToDataModelMapperOptions options) - { - // Construct the data model. - var dataModel = new VectorStoreGenericDataModel(storageModel.Id); - - // Set Vector. - if (options?.IncludeVectors is true) - { - dataModel.Vectors.Add(this._propertyReader.FirstVectorPropertyName!, storageModel.Values); - } - - // Set Data. - if (storageModel.Metadata != null) - { - foreach (var dataProperty in this._propertyReader.DataProperties) - { - var propertyStorageName = this._propertyReader.GetStoragePropertyName(dataProperty.DataModelPropertyName); - if (storageModel.Metadata.TryGetValue(propertyStorageName, out var propertyValue)) - { - dataModel.Data[dataProperty.DataModelPropertyName] = - propertyValue is not null - ? PineconeVectorStoreRecordFieldMapping.ConvertFromMetadataValueToNativeType(propertyValue, dataProperty.PropertyType) - : null; - } - } - } - - return dataModel; - } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeKernelBuilderExtensions.cs index 50048c8dfa6f..694e5327ba9b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeKernelBuilderExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Pinecone; using Sdk = Pinecone; @@ -9,6 +10,7 @@ namespace Microsoft.SemanticKernel; /// /// Extension methods to register Pinecone instances on the . /// +[Obsolete("The IKernelBuilder extensions are being obsoleted, call the appropriate function on the Services property of your IKernelBuilder")] public static class PineconeKernelBuilderExtensions { /// @@ -39,13 +41,13 @@ public static IKernelBuilder AddPineconeVectorStore(this IKernelBuilder builder, } /// - /// Register a Pinecone and with the + /// Register a Pinecone and with the /// specified service ID and where is retrieved from the dependency injection container. /// /// The type of the data model that the collection should contain. /// The builder to register the on. - /// The name of the collection that this will access. - /// Optional configuration options to pass to the . + /// The name of the collection that this will access. + /// Optional configuration options to pass to the . /// An optional service id to use as the service key. /// The kernel builder. public static IKernelBuilder AddPineconeVectorStoreRecordCollection( @@ -53,20 +55,21 @@ public static IKernelBuilder AddPineconeVectorStoreRecordCollection( string collectionName, PineconeVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddPineconeVectorStoreRecordCollection(collectionName, options, serviceId); return builder; } /// - /// Register a Pinecone and with the + /// Register a Pinecone and with the /// provided and the specified service ID. /// /// The type of the data model that the collection should contain. /// The builder to register the on. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// The api key for Pinecone. - /// Optional configuration options to pass to the . + /// Optional configuration options to pass to the . /// An optional service id to use as the service key. /// The kernel builder. public static IKernelBuilder AddPineconeVectorStoreRecordCollection( @@ -75,6 +78,7 @@ public static IKernelBuilder AddPineconeVectorStoreRecordCollection( string apiKey, PineconeVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddPineconeVectorStoreRecordCollection(collectionName, apiKey, options, serviceId); return builder; diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeMemoryBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeMemoryBuilderExtensions.cs index cec2c1391560..be3b1c800297 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeMemoryBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeMemoryBuilderExtensions.cs @@ -1,16 +1,18 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; using Microsoft.SemanticKernel.Http; using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Connectors.Pinecone; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// Provides extension methods for the class to configure Pinecone connector. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public static class PineconeMemoryBuilderExtensions { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeMemoryStore.cs index 0d1118ded656..aacc9e1d0492 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeMemoryStore.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Threading; @@ -13,6 +12,8 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// An implementation of for Pinecone Vector database. /// @@ -23,7 +24,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// For that reason, we use the term "Index" in Pinecone to refer to what is a "Collection" in IMemoryStore. So, in the case of Pinecone, /// "Collection" is synonymous with "Index" when referring to IMemoryStore. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public class PineconeMemoryStore : IPineconeMemoryStore { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeServiceCollectionExtensions.cs index 5e7658eb923f..9b0e3a67a3d8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeServiceCollectionExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Pinecone; @@ -28,11 +29,12 @@ public static IServiceCollection AddPineconeVectorStore(this IServiceCollection (sp, obj) => { var pineconeClient = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new PineconeVectorStore( - pineconeClient, - selectedOptions); + return new PineconeVectorStore(pineconeClient, options); }); return services; @@ -53,24 +55,25 @@ public static IServiceCollection AddPineconeVectorStore(this IServiceCollection (sp, obj) => { var pineconeClient = new Sdk.PineconeClient(apiKey); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new PineconeVectorStore( - pineconeClient, - selectedOptions); + return new PineconeVectorStore(pineconeClient, options); }); return services; } /// - /// Register a Pinecone and with the + /// Register a Pinecone and with the /// specified service ID and where is retrieved from the dependency injection container. /// /// The type of the data model that the collection should contain. /// The to register the on. - /// The name of the collection that this will access. - /// Optional configuration options to pass to the . + /// The name of the collection that this will access. + /// Optional configuration options to pass to the . /// An optional service id to use as the service key. /// The service collection. public static IServiceCollection AddPineconeVectorStoreRecordCollection( @@ -78,6 +81,7 @@ public static IServiceCollection AddPineconeVectorStoreRecordCollection string collectionName, PineconeVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { // If we are not constructing the PineconeClient, add the IVectorStore as transient, since we // cannot make assumptions about how PineconeClient is being managed. @@ -86,12 +90,12 @@ public static IServiceCollection AddPineconeVectorStoreRecordCollection (sp, obj) => { var pineconeClient = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new PineconeVectorStoreRecordCollection( - pineconeClient, - collectionName, - selectedOptions); + return new PineconeVectorStoreRecordCollection(pineconeClient, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -100,14 +104,14 @@ public static IServiceCollection AddPineconeVectorStoreRecordCollection } /// - /// Register a Pinecone and with the + /// Register a Pinecone and with the /// provided and the specified service ID. /// /// The type of the data model that the collection should contain. /// The to register the on. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// The api key for Pinecone. - /// Optional configuration options to pass to the . + /// Optional configuration options to pass to the . /// An optional service id to use as the service key. /// The service collection. public static IServiceCollection AddPineconeVectorStoreRecordCollection( @@ -116,18 +120,19 @@ public static IServiceCollection AddPineconeVectorStoreRecordCollection string apiKey, PineconeVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedSingleton>( serviceId, (sp, obj) => { var pineconeClient = new Sdk.PineconeClient(apiKey); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new PineconeVectorStoreRecordCollection( - pineconeClient, - collectionName, - selectedOptions); + return new PineconeVectorStoreRecordCollection(pineconeClient, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -136,14 +141,14 @@ public static IServiceCollection AddPineconeVectorStoreRecordCollection } /// - /// Also register the with the given as a . + /// Also register the with the given as a . /// /// The type of the data model that the collection should contain. /// The service collection to register on. /// The service id that the registrations should use. - private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) where TRecord : notnull { - services.AddKeyedTransient>( + services.AddKeyedTransient>( serviceId, (sp, obj) => { diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeUtils.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeUtils.cs index 18d495399986..4058079097c2 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeUtils.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeUtils.cs @@ -3,7 +3,6 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.IO; using System.Text; using System.Text.Encodings.Web; @@ -16,7 +15,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Utils for Pinecone connector. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public static class PineconeUtils { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStore.cs index a072ea6e7336..03763a26601a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStore.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.VectorData; using Pinecone; using Sdk = Pinecone; @@ -16,13 +17,17 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. /// -public class PineconeVectorStore : IVectorStore +public sealed class PineconeVectorStore : IVectorStore { - private const string DatabaseName = "Pinecone"; - private readonly Sdk.PineconeClient _pineconeClient; private readonly PineconeVectorStoreOptions _options; + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; + + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreRecordDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreRecordKeyProperty("Key", typeof(string)), new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 1)] }; + /// /// Initializes a new instance of the class. /// @@ -34,11 +39,17 @@ public PineconeVectorStore(Sdk.PineconeClient pineconeClient, PineconeVectorStor this._pineconeClient = pineconeClient; this._options = options ?? new PineconeVectorStoreOptions(); + + this._metadata = new() + { + VectorStoreSystemName = PineconeConstants.VectorStoreSystemName + }; } /// - public virtual IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + where TRecord : notnull { #pragma warning disable CS0618 // IPineconeVectorStoreRecordCollectionFactory is obsolete if (this._options.VectorStoreCollectionFactory is not null) @@ -47,19 +58,18 @@ public virtual IVectorStoreRecordCollection GetCollection( + return (new PineconeVectorStoreRecordCollection( this._pineconeClient, name, - new PineconeVectorStoreRecordCollectionOptions() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection)!; + new PineconeVectorStoreRecordCollectionOptions() + { + VectorStoreRecordDefinition = vectorStoreRecordDefinition, + EmbeddingGenerator = this._options.EmbeddingGenerator + }) as IVectorStoreRecordCollection)!; } /// - public virtual async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { IndexList indexList; @@ -71,7 +81,8 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = PineconeConstants.VectorStoreSystemName, + VectorStoreName = this._metadata.VectorStoreName, OperationName = "ListCollections" }; } @@ -84,4 +95,31 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat } } } + + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType == typeof(Sdk.PineconeClient) ? this._pineconeClient : + serviceType.IsInstanceOfType(this) ? this : + null; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs index 8e633c76e47e..2b819060abfb 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Pinecone; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -17,10 +18,10 @@ internal static class PineconeVectorStoreCollectionSearchMapping /// Build a Pinecone from a set of filter clauses. /// /// The filter clauses to build the Pinecone from. - /// A mapping from property name to the name under which the property would be stored. + /// The model. /// The Pinecone . /// Thrown for invalid property names, value types or filter clause types. - public static Metadata BuildSearchFilter(IEnumerable? filterClauses, IReadOnlyDictionary storagePropertyNamesMap) + public static Metadata BuildSearchFilter(IEnumerable? filterClauses, VectorStoreRecordModel model) { var metadataMap = new Metadata(); @@ -33,7 +34,7 @@ public static Metadata BuildSearchFilter(IEnumerable? filterClause { if (filterClause is EqualToFilterClause equalToFilterClause) { - if (!storagePropertyNamesMap.TryGetValue(equalToFilterClause.FieldName, out var storagePropertyName)) + if (!model.PropertyMap.TryGetValue(equalToFilterClause.FieldName, out var property)) { throw new InvalidOperationException($"Property '{equalToFilterClause.FieldName}' is not a valid property name."); } @@ -49,7 +50,7 @@ public static Metadata BuildSearchFilter(IEnumerable? filterClause _ => throw new NotSupportedException($"Unsupported filter value type '{equalToFilterClause.Value.GetType().Name}'.") }; - metadataMap.Add(storagePropertyName, metadataValue); + metadataMap.Add(property.StorageName, metadataValue); } else { diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreOptions.cs index 310cce39d533..f27b8a54239f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -10,7 +11,12 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; public sealed class PineconeVectorStoreOptions { /// - /// An optional factory to use for constructing instances, if a custom record collection is required. + /// Gets or sets the default embedding generator for vector properties in this collection. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } + + /// + /// An optional factory to use for constructing instances, if a custom record collection is required. /// [Obsolete("To control how collections are instantiated, extend your provider's IVectorStore implementation and override GetCollection()")] public IPineconeVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs index 3da753575141..a4aab608c8ec 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs @@ -2,12 +2,16 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; using Pinecone; using Sdk = Pinecone; @@ -16,97 +20,87 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// /// Service for storing and retrieving vector records, that uses Pinecone as the underlying storage. /// +/// The data type of the record key. Can be either , or for dynamic mapping. /// The data model to use for adding, updating and retrieving data from storage. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class PineconeVectorStoreRecordCollection : IVectorStoreRecordCollection +public sealed class PineconeVectorStoreRecordCollection : IVectorStoreRecordCollection + where TKey : notnull + where TRecord : notnull #pragma warning restore CA1711 // Identifiers should not have incorrect suffix { - private const string DatabaseName = "Pinecone"; - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; + private readonly Sdk.PineconeClient _pineconeClient; private readonly PineconeVectorStoreRecordCollectionOptions _options; - private readonly VectorStoreRecordPropertyReader _propertyReader; - private readonly IVectorStoreRecordMapper _mapper; + private readonly VectorStoreRecordModel _model; + private readonly PineconeVectorStoreRecordMapper _mapper; private IndexClient? _indexClient; /// - public string CollectionName { get; } + public string Name { get; } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// Pinecone client that can be used to manage the collections and vectors in a Pinecone store. /// Optional configuration options for this class. /// Thrown if the is null. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// Thrown for any misconfigured options. - public PineconeVectorStoreRecordCollection(Sdk.PineconeClient pineconeClient, string collectionName, PineconeVectorStoreRecordCollectionOptions? options = null) + public PineconeVectorStoreRecordCollection(Sdk.PineconeClient pineconeClient, string name, PineconeVectorStoreRecordCollectionOptions? options = null) { Verify.NotNull(pineconeClient); - VerifyCollectionName(collectionName); + VerifyCollectionName(name); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.VectorCustomMapper is not null, PineconeVectorStoreRecordFieldMapping.s_supportedKeyTypes); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(object)) + { + throw new NotSupportedException("Only string keys are supported (and object for dynamic mapping)"); + } this._pineconeClient = pineconeClient; - this.CollectionName = collectionName; + this.Name = name; this._options = options ?? new PineconeVectorStoreRecordCollectionOptions(); - this._propertyReader = new VectorStoreRecordPropertyReader( - typeof(TRecord), - this._options.VectorStoreRecordDefinition, - new() - { - RequiresAtLeastOneVector = true, - SupportsMultipleKeys = false, - SupportsMultipleVectors = false, - }); + this._model = new VectorStoreRecordModelBuilder(PineconeVectorStoreRecordFieldMapping.ModelBuildingOptions) + .Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator); + this._mapper = new PineconeVectorStoreRecordMapper(this._model); - if (this._options.VectorCustomMapper is not null) - { - // Custom Mapper. - this._mapper = this._options.VectorCustomMapper; - } - else if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) + this._collectionMetadata = new() { - // Generic data model mapper. - this._mapper = (new PineconeGenericDataModelMapper(this._propertyReader) as IVectorStoreRecordMapper)!; - } - else - { - // Default Mapper. - this._mapper = new PineconeVectorStoreRecordMapper(this._propertyReader); - } + VectorStoreSystemName = PineconeConstants.VectorStoreSystemName, + CollectionName = name + }; } /// - public virtual Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) => this.RunCollectionOperationAsync( "CollectionExists", async () => { var collections = await this._pineconeClient.ListIndexesAsync(cancellationToken: cancellationToken).ConfigureAwait(false); - return collections.Indexes?.Any(x => x.Name == this.CollectionName) is true; + return collections.Indexes?.Any(x => x.Name == this.Name) is true; }); /// - public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) { // we already run through record property validation, so a single VectorStoreRecordVectorProperty is guaranteed. - var vectorProperty = this._propertyReader.VectorProperty!; + var vectorProperty = this._model.VectorProperty!; if (!string.IsNullOrEmpty(vectorProperty.IndexKind) && vectorProperty.IndexKind != "PGA") { throw new InvalidOperationException( - $"IndexKind of '{vectorProperty.IndexKind}' for property '{vectorProperty.DataModelPropertyName}' is not supported. Pinecone only supports 'PGA' (Pinecone Graph Algorithm), which is always enabled."); + $"IndexKind of '{vectorProperty.IndexKind}' for property '{vectorProperty.ModelName}' is not supported. Pinecone only supports 'PGA' (Pinecone Graph Algorithm), which is always enabled."); } CreateIndexRequest request = new() { - Name = this.CollectionName, - Dimension = vectorProperty.Dimensions ?? throw new InvalidOperationException($"Property {nameof(vectorProperty.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' must be set to a positive integer to create a collection."), + Name = this.Name, + Dimension = vectorProperty.Dimensions, Metric = MapDistanceFunction(vectorProperty), Spec = new ServerlessIndexSpec { @@ -123,7 +117,7 @@ public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = } /// - public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) { @@ -139,11 +133,11 @@ public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken can } /// - public virtual async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { try { - await this._pineconeClient.DeleteIndexAsync(this.CollectionName, cancellationToken: cancellationToken).ConfigureAwait(false); + await this._pineconeClient.DeleteIndexAsync(this.Name, cancellationToken: cancellationToken).ConfigureAwait(false); } catch (NotFoundError) { @@ -153,22 +147,26 @@ public virtual async Task DeleteCollectionAsync(CancellationToken cancellationTo { throw new VectorStoreOperationException("Call to vector store failed.", other) { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = PineconeConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = "DeleteCollection" }; } } /// - public virtual async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { - Verify.NotNull(key); + if (options?.IncludeVectors is true && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } Sdk.FetchRequest request = new() { Namespace = this._options.IndexNamespace, - Ids = [key] + Ids = [this.GetStringKey(key)] }; var response = await this.RunIndexOperationAsync( @@ -183,21 +181,35 @@ public virtual async Task DeleteCollectionAsync(CancellationToken cancellationTo StorageToDataModelMapperOptions mapperOptions = new() { IncludeVectors = options?.IncludeVectors is true }; return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + PineconeConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, "Get", () => this._mapper.MapFromStorageToDataModel(result, mapperOptions)); } /// - public virtual async IAsyncEnumerable GetBatchAsync( - IEnumerable keys, + public async IAsyncEnumerable GetAsync( + IEnumerable keys, GetRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Verify.NotNull(keys); - List keysList = keys.ToList(); + if (options?.IncludeVectors is true && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + +#pragma warning disable CA1851 // Bogus: Possible multiple enumerations of 'IEnumerable' collection + var keysList = keys switch + { + IEnumerable k => k.ToList(), + IEnumerable k => k.Cast().ToList(), + _ => throw new UnreachableException("string key should have been validated during model building") + }; +#pragma warning restore CA1851 + if (keysList.Count == 0) { yield break; @@ -219,8 +231,9 @@ public virtual async IAsyncEnumerable GetBatchAsync( StorageToDataModelMapperOptions mapperOptions = new() { IncludeVectors = options?.IncludeVectors is true }; var records = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + PineconeConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, "GetBatch", () => response.Vectors.Values.Select(x => this._mapper.MapFromStorageToDataModel(x, mapperOptions))); @@ -231,14 +244,12 @@ public virtual async IAsyncEnumerable GetBatchAsync( } /// - public virtual Task DeleteAsync(string key, CancellationToken cancellationToken = default) + public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { - Verify.NotNullOrWhiteSpace(key); - Sdk.DeleteRequest request = new() { Namespace = this._options.IndexNamespace, - Ids = [key] + Ids = [this.GetStringKey(key)] }; return this.RunIndexOperationAsync( @@ -247,11 +258,17 @@ public virtual Task DeleteAsync(string key, CancellationToken cancellationToken } /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { Verify.NotNull(keys); - List keysList = keys.ToList(); + var keysList = keys switch + { + IEnumerable k => k.ToList(), + IEnumerable k => k.Cast().ToList(), + _ => throw new UnreachableException("string key should have been validated during model building") + }; + if (keysList.Count == 0) { return Task.CompletedTask; @@ -269,15 +286,33 @@ public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken } /// - public virtual async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { Verify.NotNull(record); + // If an embedding generator is defined, invoke it once for all records. + Embedding? generatedEmbedding = null; + + Debug.Assert(this._model.VectorProperties.Count <= 1); + if (this._model.VectorProperties is [{ EmbeddingGenerator: not null } vectorProperty]) + { + if (vectorProperty.TryGenerateEmbedding, ReadOnlyMemory>(record, cancellationToken, out var task)) + { + generatedEmbedding = await task.ConfigureAwait(false); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + var vector = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + PineconeConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, "Upsert", - () => this._mapper.MapFromDataToStorageModel(record)); + () => this._mapper.MapFromDataToStorageModel(record, generatedEmbedding)); Sdk.UpsertRequest request = new() { @@ -289,23 +324,51 @@ await this.RunIndexOperationAsync( "Upsert", indexClient => indexClient.UpsertAsync(request, cancellationToken: cancellationToken)).ConfigureAwait(false); - return vector.Id; + return (TKey)(object)vector.Id; } /// - public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { Verify.NotNull(records); + // If an embedding generator is defined, invoke it once for all records. + GeneratedEmbeddings>? generatedEmbeddings = null; + + if (this._model.VectorProperties is [{ EmbeddingGenerator: not null } vectorProperty]) + { + var recordsList = records is IReadOnlyList r ? r : records.ToList(); + + if (recordsList.Count == 0) + { + return []; + } + + records = recordsList; + + if (vectorProperty.TryGenerateEmbeddings, ReadOnlyMemory>(records, cancellationToken, out var task)) + { + generatedEmbeddings = await task.ConfigureAwait(false); + + Debug.Assert(generatedEmbeddings.Count == recordsList.Count); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + var vectors = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + PineconeConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, "UpsertBatch", - () => records.Select(this._mapper.MapFromDataToStorageModel).ToList()); + () => records.Select((r, i) => this._mapper.MapFromDataToStorageModel(r, generatedEmbeddings?[i])).ToList()); if (vectors.Count == 0) { - yield break; + return []; } Sdk.UpsertRequest request = new() @@ -318,16 +381,70 @@ await this.RunIndexOperationAsync( "UpsertBatch", indexClient => indexClient.UpsertAsync(request, cancellationToken: cancellationToken)).ConfigureAwait(false); - foreach (var vector in vectors) + return vectors.Select(x => (TKey)(object)x.Id).ToList(); + } + + #region Search + + /// + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + switch (vectorProperty.EmbeddingGenerator) { - yield return vector.Id; + case IEmbeddingGenerator> generator: + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + + default: + throw new InvalidOperationException( + PineconeVectorStoreRecordFieldMapping.s_supportedVectorTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); } } /// - public virtual async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable> SearchEmbeddingAsync( + TVector vector, + int top, + VectorSearchOptions? options = null, + CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private async IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + VectorSearchOptions options, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TVector : notnull { Verify.NotNull(vector); + Verify.NotLessThan(top, 1); if (vector is not ReadOnlyMemory floatVector) { @@ -335,21 +452,24 @@ public virtual async Task> VectorizedSearchAsync).FullName}"); } - options ??= s_defaultVectorSearchOptions; + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } #pragma warning disable CS0618 // VectorSearchFilter is obsolete var filter = options switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => PineconeVectorStoreCollectionSearchMapping.BuildSearchFilter(options.OldFilter?.FilterClauses, this._propertyReader.StoragePropertyNamesMap), - { Filter: Expression> newFilter } => new PineconeFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + { OldFilter: VectorSearchFilter legacyFilter } => PineconeVectorStoreCollectionSearchMapping.BuildSearchFilter(options.OldFilter?.FilterClauses, this._model), + { Filter: Expression> newFilter } => new PineconeFilterTranslator().Translate(newFilter, this._model), _ => null }; #pragma warning restore CS0618 Sdk.QueryRequest request = new() { - TopK = (uint)(options.Top + options.Skip), + TopK = (uint)(top + options.Skip), Namespace = this._options.IndexNamespace, IncludeValues = options.IncludeVectors, IncludeMetadata = true, @@ -358,12 +478,12 @@ public virtual async Task> VectorizedSearchAsync indexClient.QueryAsync(request, cancellationToken: cancellationToken)).ConfigureAwait(false); if (response.Matches is null) { - return new VectorSearchResults(Array.Empty>().ToAsyncEnumerable()); + yield break; } // Pinecone does not provide a way to skip results, so we need to do it manually. @@ -372,19 +492,103 @@ public virtual async Task> VectorizedSearchAsync skippedResults.Select(x => new VectorSearchResult(this._mapper.MapFromStorageToDataModel(new Sdk.Vector() { Id = x.Id, Values = x.Values ?? Array.Empty(), Metadata = x.Metadata, SparseValues = x.SparseValues - }, mapperOptions), x.Score))) - .ToAsyncEnumerable(); + }, mapperOptions), x.Score))); + + foreach (var record in records) + { + yield return record; + } + } + + /// + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); + + #endregion Search + + /// + public async IAsyncEnumerable GetAsync(Expression> filter, int top, GetFilteredRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); + + if (options?.OrderBy.Values.Count > 0) + { + throw new NotSupportedException("Pinecone does not support ordering."); + } + + options ??= new(); + + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + Sdk.QueryRequest request = new() + { + TopK = (uint)(top + options.Skip), + Namespace = this._options.IndexNamespace, + IncludeValues = options.IncludeVectors, + IncludeMetadata = true, + // "Either 'vector' or 'ID' must be provided" + // Since we are doing a query, we don't have a vector to provide, so we fake one. + // When https://github.com/pinecone-io/pinecone-dotnet-client/issues/43 gets implemented, we need to switch. + Vector = new ReadOnlyMemory(new float[this._model.VectorProperty.Dimensions]), + Filter = new PineconeFilterTranslator().Translate(filter, this._model), + }; + + Sdk.QueryResponse response = await this.RunIndexOperationAsync( + "Get", + indexClient => indexClient.QueryAsync(request, cancellationToken: cancellationToken)).ConfigureAwait(false); + + if (response.Matches is null) + { + yield break; + } + + StorageToDataModelMapperOptions mapperOptions = new() { IncludeVectors = options.IncludeVectors is true }; + var records = VectorStoreErrorHandler.RunModelConversion( + PineconeConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + "Query", + () => response.Matches.Skip(options.Skip).Select(x => this._mapper.MapFromStorageToDataModel(new Sdk.Vector() + { + Id = x.Id, + Values = x.Values ?? Array.Empty(), + Metadata = x.Metadata, + SparseValues = x.SparseValues + }, mapperOptions))); + + foreach (var record in records) + { + yield return record; + } + } - return new(records); + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType == typeof(Sdk.PineconeClient) ? this._pineconeClient : + serviceType.IsInstanceOfType(this) ? this : + null; } private async Task RunIndexOperationAsync(string operationName, Func> operation) @@ -395,7 +599,7 @@ private async Task RunIndexOperationAsync(string operationName, Func RunIndexOperationAsync(string operationName, Func RunCollectionOperationAsync(string operationName, Func< { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = PineconeConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } @@ -438,13 +644,12 @@ private static ServerlessSpecCloud MapCloud(string serverlessIndexCloud) _ => throw new ArgumentException($"Invalid serverless index cloud: {serverlessIndexCloud}.", nameof(serverlessIndexCloud)) }; - private static CreateIndexRequestMetric MapDistanceFunction(VectorStoreRecordVectorProperty vectorProperty) + private static CreateIndexRequestMetric MapDistanceFunction(VectorStoreRecordVectorPropertyModel vectorProperty) => vectorProperty.DistanceFunction switch { - DistanceFunction.CosineSimilarity => CreateIndexRequestMetric.Cosine, + DistanceFunction.CosineSimilarity or null => CreateIndexRequestMetric.Cosine, DistanceFunction.DotProductSimilarity => CreateIndexRequestMetric.Dotproduct, DistanceFunction.EuclideanSquaredDistance => CreateIndexRequestMetric.Euclidean, - null => CreateIndexRequestMetric.Cosine, _ => throw new NotSupportedException($"Distance function '{vectorProperty.DistanceFunction}' is not supported.") }; @@ -461,4 +666,15 @@ private static void VerifyCollectionName(string collectionName) } } } + + private string GetStringKey(TKey key) + { + Verify.NotNull(key); + + var stringKey = key as string ?? throw new UnreachableException("string key should have been validated during model building"); + + Verify.NotNullOrWhiteSpace(stringKey, nameof(key)); + + return stringKey; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollectionOptions.cs index feb147a75763..4be187f80b16 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollectionOptions.cs @@ -1,18 +1,21 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; using Pinecone; namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// -/// Options when creating a . +/// Options when creating a . /// public sealed class PineconeVectorStoreRecordCollectionOptions { /// /// Gets or sets an optional custom mapper to use when converting between the data model and the Pinecone vector. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper? VectorCustomMapper { get; init; } = null; /// @@ -45,4 +48,9 @@ public sealed class PineconeVectorStoreRecordCollectionOptions /// This option is only used when creating a new Pinecone index. Default value is 'us-east-1'. /// public string ServerlessIndexRegion { get; init; } = "us-east-1"; + + /// + /// Gets or sets the default embedding generator for vector properties in this collection. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordFieldMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordFieldMapping.cs index 9573740f8580..0210ea23d467 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordFieldMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordFieldMapping.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Pinecone; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -13,33 +14,6 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// internal static class PineconeVectorStoreRecordFieldMapping { - /// A set of types that a key on the provided model may have. - public static readonly HashSet s_supportedKeyTypes = [typeof(string)]; - - /// A set of types that data properties on the provided model may have. - public static readonly HashSet s_supportedDataTypes = - [ - typeof(bool), - typeof(bool?), - typeof(string), - typeof(int), - typeof(int?), - typeof(long), - typeof(long?), - typeof(float), - typeof(float?), - typeof(double), - typeof(double?), - typeof(decimal), - typeof(decimal?), - ]; - - /// A set of types that enumerable data properties on the provided model may use as their element types. - public static readonly HashSet s_supportedEnumerableDataElementTypes = - [ - typeof(string) - ]; - /// A set of types that vectors on the provided model may have. public static readonly HashSet s_supportedVectorTypes = [ @@ -47,6 +21,30 @@ internal static class PineconeVectorStoreRecordFieldMapping typeof(ReadOnlyMemory?), ]; + public static readonly VectorStoreRecordModelBuildingOptions ModelBuildingOptions = new() + { + RequiresAtLeastOneVector = true, + SupportsMultipleKeys = false, + SupportsMultipleVectors = false, + + SupportedKeyPropertyTypes = [typeof(string)], + + SupportedDataPropertyTypes = + [ + typeof(bool), + typeof(string), + typeof(int), + typeof(long), + typeof(float), + typeof(double), + typeof(decimal) + ], + + SupportedEnumerableDataPropertyElementTypes = [typeof(string)], + + SupportedVectorPropertyTypes = s_supportedVectorTypes + }; + public static object? ConvertFromMetadataValueToNativeType(MetadataValue metadataValue, Type targetType) => metadataValue.Value switch { diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordMapper.cs index 1163c1a66bea..08abd5947201 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordMapper.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Pinecone; namespace Microsoft.SemanticKernel.Connectors.Pinecone; @@ -10,52 +12,32 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// Mapper between a Pinecone record and the consumer data model that uses json as an intermediary to allow supporting a wide range of models. /// /// The consumer data model to map to or from. -internal sealed class PineconeVectorStoreRecordMapper : IVectorStoreRecordMapper +internal sealed class PineconeVectorStoreRecordMapper(VectorStoreRecordModel model) { - private readonly VectorStoreRecordPropertyReader _propertyReader; - - /// - /// Initializes a new instance of the class. - /// - /// A helper to access property information for the current data model and record definition. - public PineconeVectorStoreRecordMapper( - VectorStoreRecordPropertyReader propertyReader) - { - // Validate property types. - propertyReader.VerifyHasParameterlessConstructor(); - propertyReader.VerifyKeyProperties(PineconeVectorStoreRecordFieldMapping.s_supportedKeyTypes); - propertyReader.VerifyDataProperties(PineconeVectorStoreRecordFieldMapping.s_supportedDataTypes, PineconeVectorStoreRecordFieldMapping.s_supportedEnumerableDataElementTypes); - propertyReader.VerifyVectorProperties(PineconeVectorStoreRecordFieldMapping.s_supportedVectorTypes); - - // Assign. - this._propertyReader = propertyReader; - } - /// - public Vector MapFromDataToStorageModel(TRecord dataModel) + public Vector MapFromDataToStorageModel(TRecord dataModel, Embedding? generatedEmbedding) { - var keyObject = this._propertyReader.KeyPropertyInfo.GetValue(dataModel); + var keyObject = model.KeyProperty.GetValueAsObject(dataModel!); if (keyObject is null) { - throw new VectorStoreRecordMappingException($"Key property {this._propertyReader.KeyPropertyName} on provided record of type {typeof(TRecord).FullName} may not be null."); + throw new VectorStoreRecordMappingException($"Key property '{model.KeyProperty.ModelName}' on provided record of type '{typeof(TRecord).Name}' may not be null."); } var metadata = new Metadata(); - foreach (var dataPropertyInfo in this._propertyReader.DataPropertiesInfo) + foreach (var property in model.DataProperties) { - var propertyName = this._propertyReader.GetStoragePropertyName(dataPropertyInfo.Name); - var propertyValue = dataPropertyInfo.GetValue(dataModel); - if (propertyValue != null) + if (property.GetValueAsObject(dataModel!) is { } value) { - metadata[propertyName] = PineconeVectorStoreRecordFieldMapping.ConvertToMetadataValue(propertyValue); + metadata[property.StorageName] = PineconeVectorStoreRecordFieldMapping.ConvertToMetadataValue(value); } } - var valuesObject = this._propertyReader.FirstVectorPropertyInfo!.GetValue(dataModel); - if (valuesObject is not ReadOnlyMemory values) + var values = (generatedEmbedding?.Vector ?? model.VectorProperty!.GetValueAsObject(dataModel!)) switch { - throw new VectorStoreRecordMappingException($"Vector property {this._propertyReader.FirstVectorPropertyName} on provided record of type {typeof(TRecord).FullName} may not be null."); - } + ReadOnlyMemory floats => floats, + null => throw new VectorStoreRecordMappingException($"Vector property '{model.VectorProperty.ModelName}' on provided record of type '{typeof(TRecord).Name}' may not be null."), + _ => throw new VectorStoreRecordMappingException($"Unsupported vector type '{model.VectorProperty.Type.Name}' for vector property '{model.VectorProperty.ModelName}' on provided record of type '{typeof(TRecord).Name}'.") + }; // TODO: what about sparse values? var result = new Vector @@ -72,29 +54,27 @@ public Vector MapFromDataToStorageModel(TRecord dataModel) /// public TRecord MapFromStorageToDataModel(Vector storageModel, StorageToDataModelMapperOptions options) { - // Construct the output record. - var outputRecord = (TRecord)this._propertyReader.ParameterLessConstructorInfo.Invoke(null); + var outputRecord = model.CreateRecord()!; - // Set Key. - this._propertyReader.KeyPropertyInfo.SetValue(outputRecord, storageModel.Id); + model.KeyProperty.SetValueAsObject(outputRecord, storageModel.Id); - // Set Vector. if (options?.IncludeVectors is true) { - this._propertyReader.FirstVectorPropertyInfo!.SetValue( + model.VectorProperty.SetValueAsObject( outputRecord, storageModel.Values); } - // Set Data. if (storageModel.Metadata != null) { - VectorStoreRecordMapping.SetValuesOnProperties( - outputRecord, - this._propertyReader.DataPropertiesInfo, - this._propertyReader.StoragePropertyNamesMap, - storageModel.Metadata, - PineconeVectorStoreRecordFieldMapping.ConvertFromMetadataValueToNativeType!); + foreach (var property in model.DataProperties) + { + property.SetValueAsObject( + outputRecord, + storageModel.Metadata.TryGetValue(property.StorageName, out var metadataValue) && metadataValue is not null + ? PineconeVectorStoreRecordFieldMapping.ConvertFromMetadataValueToNativeType(metadataValue, property.Type) + : null); + } } return outputRecord; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj index 4a97b2962a14..17398d17217e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj @@ -4,7 +4,7 @@ Microsoft.SemanticKernel.Connectors.Postgres $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 preview $(NoWarn);CS0436 @@ -12,6 +12,7 @@ + @@ -24,12 +25,15 @@ + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs index 1d13a847ba4b..11c5a504dcbc 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; using Pgvector; @@ -12,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// /// Interface for client managing postgres database operations for . /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PostgresVectorStore")] public interface IPostgresDbClient { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs index 0175243131cd..0581c6a6c134 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq.Expressions; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -35,22 +36,23 @@ internal interface IPostgresVectorStoreCollectionSqlBuilder /// /// The schema of the table. /// The name of the table. - /// The properties of the table. + /// The collection model. /// Specifies whether to include IF NOT EXISTS in the command. /// The built SQL command info. - PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, IReadOnlyList properties, bool ifNotExists = true); + PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, VectorStoreRecordModel model, bool ifNotExists = true); /// /// Builds a SQL command to create a vector index in the Postgres vector store. /// /// The schema of the table. /// The name of the table. - /// The name of the vector column. + /// The name of the column. /// The kind of index to create. /// The distance function to use for the index. + /// Specifies whether the column is a vector column. /// Specifies whether to include IF NOT EXISTS in the command. /// The built SQL command info. - PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction, bool ifNotExists); + PostgresSqlCommandInfo BuildCreateIndexCommand(string schema, string tableName, string columnName, string indexKind, string distanceFunction, bool isVector, bool ifNotExists); /// /// Builds a SQL command to drop a table in the Postgres vector store. @@ -85,22 +87,22 @@ internal interface IPostgresVectorStoreCollectionSqlBuilder /// /// The schema of the table. /// The name of the table. - /// The properties of the table. + /// The collection model. /// The key of the record to get. /// Specifies whether to include vectors in the record. /// The built SQL command info. - PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, IReadOnlyList properties, TKey key, bool includeVectors = false) where TKey : notnull; + PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, VectorStoreRecordModel model, TKey key, bool includeVectors = false) where TKey : notnull; /// /// Builds a SQL command to get a batch of records from the Postgres vector store. /// /// The schema of the table. /// The name of the table. - /// The properties of the table. + /// The collection model. /// The keys of the records to get. /// Specifies whether to include vectors in the records. /// The built SQL command info. - PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, IReadOnlyList properties, List keys, bool includeVectors = false) where TKey : notnull; + PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, VectorStoreRecordModel model, List keys, bool includeVectors = false) where TKey : notnull; /// /// Builds a SQL command to delete a record from the Postgres vector store. @@ -127,7 +129,7 @@ internal interface IPostgresVectorStoreCollectionSqlBuilder /// /// The schema of the table. /// The name of the table. - /// The property reader. + /// The collection model. /// The property which the vectors to compare are stored in. /// The vector to match. /// The filter conditions for the query. @@ -137,6 +139,6 @@ internal interface IPostgresVectorStoreCollectionSqlBuilder /// The maximum number of records to return. /// The built SQL command info. #pragma warning disable CS0618 // VectorSearchFilter is obsolete - PostgresSqlCommandInfo BuildGetNearestMatchCommand(string schema, string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, VectorSearchFilter? legacyFilter, Expression>? newFilter, int? skip, bool includeVectors, int limit); + PostgresSqlCommandInfo BuildGetNearestMatchCommand(string schema, string tableName, VectorStoreRecordModel model, VectorStoreRecordVectorPropertyModel vectorProperty, Vector vectorValue, VectorSearchFilter? legacyFilter, Expression>? newFilter, int? skip, bool includeVectors, int limit); #pragma warning restore CS0618 // VectorSearchFilter is obsolete } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs index 020aa46dbda6..157c4aacce61 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -6,6 +6,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Npgsql; using Pgvector; @@ -21,6 +22,11 @@ internal interface IPostgresVectorStoreDbClient /// NpgsqlDataSource DataSource { get; } + /// + /// The name of the database. + /// + string? DatabaseName { get; } + /// /// Check if a table exists. /// @@ -39,11 +45,11 @@ internal interface IPostgresVectorStoreDbClient /// Create a table. Also creates an index on vector columns if the table has vector properties defined. /// /// The name assigned to a table of entries. - /// The properties of the record definition that define the table. + /// The collection model. /// Specifies whether to include IF NOT EXISTS in the command. /// The to monitor for cancellation requests. The default is . /// - Task CreateTableAsync(string tableName, IReadOnlyList properties, bool ifNotExists = true, CancellationToken cancellationToken = default); + Task CreateTableAsync(string tableName, VectorStoreRecordModel model, bool ifNotExists = true, CancellationToken cancellationToken = default); /// /// Drop a table. @@ -77,11 +83,11 @@ internal interface IPostgresVectorStoreDbClient /// /// The name assigned to a table of entries. /// The key of the entry to get. - /// The properties to include in the entry. + /// The collection model. /// If true, the vectors will be included in the entry. /// The to monitor for cancellation requests. The default is . /// The row if the key is found, otherwise null. - Task?> GetAsync(string tableName, TKey key, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) + Task?> GetAsync(string tableName, TKey key, VectorStoreRecordModel model, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull; /// @@ -89,11 +95,11 @@ internal interface IPostgresVectorStoreDbClient /// /// The name assigned to a table of entries. /// The keys of the entries to get. - /// The properties of the table. + /// The collection model. /// If true, the vectors will be included in the entries. /// The to monitor for cancellation requests. The default is . /// The rows that match the given keys. - IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) + IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, VectorStoreRecordModel model, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull; /// @@ -120,18 +126,16 @@ internal interface IPostgresVectorStoreDbClient /// Gets the nearest matches to the . /// /// The name assigned to a table of entries. - /// The property reader. + /// The collection model. /// The vector property. /// The to compare the table's vector with. /// The maximum number of similarity results to return. - /// Optional conditions to filter the results. - /// Optional conditions to filter the results. - /// The number of entries to skip. - /// If true, the vectors will be returned in the entries. + /// The options that control the behavior of the search. /// The to monitor for cancellation requests. The default is . - /// An asynchronous stream of objects that the nearest matches to the . -#pragma warning disable CS0618 // VectorSearchFilter is obsolete - IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync(string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, - VectorSearchFilter? legacyFilter = default, Expression>? newFilter = default, int? skip = default, bool includeVectors = false, CancellationToken cancellationToken = default); -#pragma warning restore CS0618 // VectorSearchFilter is obsolete + /// An asynchronous stream of result objects that the nearest matches to the . + IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync(string tableName, VectorStoreRecordModel model, VectorStoreRecordVectorPropertyModel vectorProperty, Vector vectorValue, int limit, + VectorSearchOptions options, CancellationToken cancellationToken = default); + + IAsyncEnumerable> GetMatchingRecordsAsync(string tableName, VectorStoreRecordModel model, + Expression> filter, int top, GetFilteredRecordOptions options, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs index 58384ba767ac..58af348ac87d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs @@ -22,5 +22,6 @@ public interface IPostgresVectorStoreRecordCollectionFactory /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. /// The new instance of . IVectorStoreRecordCollection CreateVectorStoreRecordCollection(NpgsqlDataSource dataSource, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) - where TKey : notnull; + where TKey : notnull + where TRecord : notnull; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs index f8784890e83a..1ab2fc44211e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -3,66 +3,68 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Postgres; internal static class PostgresConstants { - /// The name of this database for telemetry purposes. - public const string DatabaseName = "Postgres"; + /// The name of this vector store for telemetry purposes. + public const string VectorStoreSystemName = "postgresql"; - /// A of types that a key on the provided model may have. - public static readonly HashSet SupportedKeyTypes = - [ - typeof(short), - typeof(int), - typeof(long), - typeof(string), - typeof(Guid), - ]; + /// Validation options. + public static readonly VectorStoreRecordModelBuildingOptions ModelBuildingOptions = new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, - /// A of types that data properties on the provided model may have. - public static readonly HashSet SupportedDataTypes = - [ - typeof(bool), - typeof(bool?), - typeof(short), - typeof(short?), - typeof(int), - typeof(int?), - typeof(long), - typeof(long?), - typeof(float), - typeof(float?), - typeof(double), - typeof(double?), - typeof(decimal), - typeof(decimal?), - typeof(string), - typeof(DateTime), - typeof(DateTime?), - typeof(DateTimeOffset), - typeof(DateTimeOffset?), - typeof(Guid), - typeof(Guid?), - typeof(byte[]), - ]; + SupportedKeyPropertyTypes = + [ + typeof(short), + typeof(int), + typeof(long), + typeof(string), + typeof(Guid) + ], - /// A of types that enumerable data properties on the provided model may use as their element types. - public static readonly HashSet SupportedEnumerableDataElementTypes = - [ - typeof(bool), - typeof(short), - typeof(int), - typeof(long), - typeof(float), - typeof(double), - typeof(decimal), - typeof(string), - typeof(DateTime), - typeof(DateTimeOffset), - typeof(Guid), - ]; + SupportedDataPropertyTypes = + [ + typeof(bool), + typeof(short), + typeof(int), + typeof(long), + typeof(float), + typeof(double), + typeof(decimal), + typeof(string), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(Guid), + typeof(byte[]), + ], + + SupportedEnumerableDataPropertyElementTypes = + [ + typeof(bool), + typeof(short), + typeof(int), + typeof(long), + typeof(float), + typeof(double), + typeof(decimal), + typeof(string), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(Guid), + ], + + SupportedVectorPropertyTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?) + ] + }; /// A of types that vector properties on the provided model may have. public static readonly HashSet SupportedVectorTypes = diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs index 20b384ec06f2..bcfd4443622c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Threading; @@ -17,7 +16,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// An implementation of a client for Postgres. This class is used to managing postgres database operations for . /// [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PostgresVectorStore")] public class PostgresDbClient : IPostgresDbClient { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs index b4b9707c1c99..52a230865065 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs @@ -2,6 +2,8 @@ using System.Collections.Generic; using System.Linq.Expressions; +using System.Text; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -11,42 +13,43 @@ internal sealed class PostgresFilterTranslator : SqlFilterTranslator private int _parameterIndex; internal PostgresFilterTranslator( - IReadOnlyDictionary storagePropertyNames, + VectorStoreRecordModel model, LambdaExpression lambdaExpression, - int startParamIndex) : base(storagePropertyNames, lambdaExpression, sql: null) + int startParamIndex, + StringBuilder? sql = null) : base(model, lambdaExpression, sql) { this._parameterIndex = startParamIndex; } internal List ParameterValues => this._parameterValues; - protected override void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent) + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item) { - this.Translate(source, parent); + this.Translate(source); this._sql.Append(" @> ARRAY["); - this.Translate(item, parent); + this.Translate(item); this._sql.Append(']'); } - protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value) + protected override void TranslateContainsOverParameterizedArray(Expression source, Expression item, object? value) { - this.Translate(item, parent); + this.Translate(item); this._sql.Append(" = ANY ("); - this.Translate(source, parent); + this.Translate(source); this._sql.Append(')'); } - protected override void TranslateCapturedVariable(string name, object? capturedValue) + protected override void TranslateQueryParameter(string name, object? value) { // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) - if (capturedValue is null) + if (value is null) { this._sql.Append("NULL"); } else { - this._parameterValues.Add(capturedValue); + this._parameterValues.Add(value); this._sql.Append('$').Append(this._parameterIndex++); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs deleted file mode 100644 index efdec538c772..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; -using Microsoft.Extensions.VectorData; - -namespace Microsoft.SemanticKernel.Connectors.Postgres; - -internal sealed class PostgresGenericDataModelMapper : IVectorStoreRecordMapper, Dictionary> - where TKey : notnull -{ - /// with helpers for reading vector store model properties and their attributes. - private readonly VectorStoreRecordPropertyReader _propertyReader; - - /// - /// Initializes a new instance of the class. - /// /// - /// with helpers for reading vector store model properties and their attributes. - public PostgresGenericDataModelMapper(VectorStoreRecordPropertyReader propertyReader) - { - Verify.NotNull(propertyReader); - - this._propertyReader = propertyReader; - - // Validate property types. - this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); - this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); - } - - public Dictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - var properties = new Dictionary - { - // Add key property - { this._propertyReader.KeyPropertyStoragePropertyName, dataModel.Key } - }; - - // Add data properties - if (dataModel.Data is not null) - { - foreach (var property in this._propertyReader.DataProperties) - { - if (dataModel.Data.TryGetValue(property.DataModelPropertyName, out var dataValue)) - { - properties.Add(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), dataValue); - } - } - } - - // Add vector properties - if (dataModel.Vectors is not null) - { - foreach (var property in this._propertyReader.VectorProperties) - { - if (dataModel.Vectors.TryGetValue(property.DataModelPropertyName, out var vectorValue)) - { - var result = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vectorValue); - properties.Add(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), result); - } - } - } - - return properties; - } - - public VectorStoreGenericDataModel MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) - { - TKey key; - var dataProperties = new Dictionary(); - var vectorProperties = new Dictionary(); - - // Process key property. - if (storageModel.TryGetValue(this._propertyReader.KeyPropertyStoragePropertyName, out var keyObject) && keyObject is not null) - { - key = (TKey)keyObject; - } - else - { - throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); - } - - // Process data properties. - foreach (var property in this._propertyReader.DataProperties) - { - if (storageModel.TryGetValue(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), out var dataValue)) - { - dataProperties.Add(property.DataModelPropertyName, dataValue); - } - } - - // Process vector properties - if (options.IncludeVectors) - { - foreach (var property in this._propertyReader.VectorProperties) - { - if (storageModel.TryGetValue(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), out var vectorValue)) - { - vectorProperties.Add(property.DataModelPropertyName, PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(vectorValue)); - } - } - } - - return new VectorStoreGenericDataModel(key) { Data = dataProperties, Vectors = vectorProperties }; - } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryBuilderExtensions.cs index ad04abe0b7de..64d6d3070d23 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryBuilderExtensions.cs @@ -1,15 +1,17 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using Microsoft.SemanticKernel.Memory; using Npgsql; namespace Microsoft.SemanticKernel.Connectors.Postgres; +#pragma warning disable SKEXP0001 + /// /// Provides extension methods for the class to configure Postgres connector. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PostgresVectorStore")] public static class PostgresMemoryBuilderExtensions { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryEntry.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryEntry.cs index 449856653dbb..8e80f9b49ef3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryEntry.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryEntry.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Diagnostics.CodeAnalysis; using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -9,7 +8,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// /// A postgres memory entry. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PostgresVectorStore")] public record struct PostgresMemoryEntry { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs index f981ba926d96..f6a59ed5e463 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -14,6 +13,8 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// An implementation of backed by a Postgres database with pgvector extension. /// @@ -21,7 +22,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// The embedded data is saved to the Postgres database specified in the constructor. /// Similarity search capability is provided through the pgvector extension. Use Postgres's "Table" to implement "Collection". /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PostgresVectorStore")] public class PostgresMemoryStore : IMemoryStore, IDisposable { internal const string DefaultSchema = "public"; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs index 983b8e7db443..bae79e63c4c4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Postgres; @@ -28,11 +29,12 @@ public static IServiceCollection AddPostgresVectorStore(this IServiceCollection (sp, obj) => { var dataSource = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new PostgresVectorStore( - dataSource, - selectedOptions); + return new PostgresVectorStore(dataSource, options); }); return services; @@ -64,18 +66,19 @@ public static IServiceCollection AddPostgresVectorStore(this IServiceCollection (sp, obj) => { var dataSource = sp.GetRequiredKeyedService(npgsqlServiceId); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new PostgresVectorStore( - dataSource, - selectedOptions); + return new PostgresVectorStore(dataSource, options); }); return services; } /// - /// Register a Postgres and with the specified service ID + /// Register a Postgres and with the specified service ID /// and where the NpgsqlDataSource is retrieved from the dependency injection container. /// /// The type of the key. @@ -91,15 +94,19 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection? options = default, string? serviceId = default) where TKey : notnull + where TRecord : notnull { services.AddKeyedTransient>( serviceId, (sp, obj) => { var dataSource = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return (new PostgresVectorStoreRecordCollection(dataSource, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + return (new PostgresVectorStoreRecordCollection(dataSource, collectionName, options) as IVectorStoreRecordCollection)!; }); AddVectorizedSearch(services, serviceId); @@ -108,7 +115,7 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection - /// Register a Postgres and with the specified service ID + /// Register a Postgres and with the specified service ID /// and where the NpgsqlDataSource is constructed using the provided parameters. /// /// The type of the key. @@ -126,6 +133,7 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection? options = default, string? serviceId = default) where TKey : notnull + where TRecord : notnull { string? npgsqlServiceId = serviceId == null ? default : $"{serviceId}_NpgsqlDataSource"; // Register NpgsqlDataSource to ensure proper disposal. @@ -144,6 +152,11 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection(npgsqlServiceId); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; + return (new PostgresVectorStoreRecordCollection(dataSource, collectionName, options) as IVectorStoreRecordCollection)!; }); @@ -153,7 +166,7 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection - /// Also register the with the given as a . + /// Also register the with the given as a . /// /// The type of the key. /// The type of the data model that the collection should contain. @@ -161,8 +174,9 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollectionThe service id that the registrations should use. private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) where TKey : notnull + where TRecord : notnull { - services.AddKeyedTransient>( + services.AddKeyedTransient>( serviceId, (sp, obj) => { diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlBuilder.cs similarity index 66% rename from dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs rename to dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlBuilder.cs index f661c09ebf44..e744d88efce3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlBuilder.cs @@ -6,6 +6,7 @@ using System.Linq.Expressions; using System.Text; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Npgsql; using NpgsqlTypes; using Pgvector; @@ -15,10 +16,10 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// /// Provides methods to build SQL commands for managing vector store collections in PostgreSQL. /// -internal class PostgresVectorStoreCollectionSqlBuilder : IPostgresVectorStoreCollectionSqlBuilder +internal static class PostgresSqlBuilder { /// - public PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName) + internal static PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName) { return new PostgresSqlCommandInfo( commandText: """ @@ -36,7 +37,7 @@ FROM information_schema.tables } /// - public PostgresSqlCommandInfo BuildGetTablesCommand(string schema) + internal static PostgresSqlCommandInfo BuildGetTablesCommand(string schema) { return new PostgresSqlCommandInfo( commandText: """ @@ -49,69 +50,34 @@ FROM information_schema.tables } /// - public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, IReadOnlyList properties, bool ifNotExists = true) + internal static PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, VectorStoreRecordModel model, bool ifNotExists = true) { if (string.IsNullOrWhiteSpace(tableName)) { throw new ArgumentException("Table name cannot be null or whitespace", nameof(tableName)); } - VectorStoreRecordKeyProperty? keyProperty = default; - List dataProperties = new(); - List vectorProperties = new(); - - foreach (var property in properties) - { - if (property is VectorStoreRecordKeyProperty keyProp) - { - if (keyProperty != null) - { - // Should be impossible, as property reader should have already validated that - // multiple key properties are not allowed. - throw new ArgumentException("Record definition cannot have more than one key property."); - } - keyProperty = keyProp; - } - else if (property is VectorStoreRecordDataProperty dataProp) - { - dataProperties.Add(dataProp); - } - else if (property is VectorStoreRecordVectorProperty vectorProp) - { - vectorProperties.Add(vectorProp); - } - else - { - throw new NotSupportedException($"Property type {property.GetType().Name} is not supported by this store."); - } - } - - if (keyProperty == null) - { - throw new ArgumentException("Record definition must have a key property."); - } - - var keyName = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; + var keyName = model.KeyProperty.StorageName; StringBuilder createTableCommand = new(); createTableCommand.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : "")}{schema}.\"{tableName}\" ("); // Add the key column - var keyPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(keyProperty.PropertyType); + var keyPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(model.KeyProperty.Type); createTableCommand.AppendLine($" \"{keyName}\" {keyPgTypeInfo.PgType} {(keyPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); // Add the data columns - foreach (var dataProperty in dataProperties) + foreach (var dataProperty in model.DataProperties) { - string columnName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; - var dataPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(dataProperty.PropertyType); + string columnName = dataProperty.StorageName; + var dataPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(dataProperty.Type); createTableCommand.AppendLine($" \"{columnName}\" {dataPgTypeInfo.PgType} {(dataPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); } // Add the vector columns - foreach (var vectorProperty in vectorProperties) + foreach (var vectorProperty in model.VectorProperties) { - string columnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + string columnName = vectorProperty.StorageName; var vectorPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPgVectorTypeName(vectorProperty); createTableCommand.AppendLine($" \"{columnName}\" {vectorPgTypeInfo.PgType} {(vectorPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); } @@ -124,8 +90,17 @@ public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tabl } /// - public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction, bool ifNotExists) + internal static PostgresSqlCommandInfo BuildCreateIndexCommand(string schema, string tableName, string columnName, string indexKind, string distanceFunction, bool isVector, bool ifNotExists) { + var indexName = $"{tableName}_{columnName}_index"; + + if (!isVector) + { + return new PostgresSqlCommandInfo(commandText: + $@"CREATE INDEX {(ifNotExists ? "IF NOT EXISTS " : "")}""{indexName}"" ON {schema}.""{tableName}"" (""{columnName}"");" + ); + } + // Only support creating HNSW index creation through the connector. var indexTypeName = indexKind switch { @@ -145,16 +120,14 @@ public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, strin _ => throw new NotSupportedException($"Distance function {distanceFunction} is not supported.") }; - var indexName = $"{tableName}_{vectorColumnName}_index"; - return new PostgresSqlCommandInfo( commandText: $@" - CREATE INDEX {(ifNotExists ? "IF NOT EXISTS " : "")} ""{indexName}"" ON {schema}.""{tableName}"" USING {indexTypeName} (""{vectorColumnName}"" {indexOps});" + CREATE INDEX {(ifNotExists ? "IF NOT EXISTS " : "")} ""{indexName}"" ON {schema}.""{tableName}"" USING {indexTypeName} (""{columnName}"" {indexOps});" ); } /// - public PostgresSqlCommandInfo BuildDropTableCommand(string schema, string tableName) + internal static PostgresSqlCommandInfo BuildDropTableCommand(string schema, string tableName) { return new PostgresSqlCommandInfo( commandText: $@"DROP TABLE IF EXISTS {schema}.""{tableName}""" @@ -162,7 +135,7 @@ public PostgresSqlCommandInfo BuildDropTableCommand(string schema, string tableN } /// - public PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, string keyColumn, Dictionary row) + internal static PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, string keyColumn, Dictionary row) { var columns = row.Keys.ToList(); var columnNames = string.Join(", ", columns.Select(k => $"\"{k}\"")); @@ -185,7 +158,7 @@ ON CONFLICT ("{keyColumn}") } /// - public PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tableName, string keyColumn, List> rows) + internal static PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tableName, string keyColumn, List> rows) { if (rows == null || rows.Count == 0) { @@ -232,62 +205,38 @@ ON CONFLICT ("{keyColumn}") } /// - public PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, IReadOnlyList properties, TKey key, bool includeVectors = false) + internal static PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, VectorStoreRecordModel model, TKey key, bool includeVectors = false) where TKey : notnull { List queryColumns = new(); - string? keyColumn = null; - foreach (var property in properties) + foreach (var property in model.Properties) { - if (property is VectorStoreRecordKeyProperty keyProperty) - { - if (keyColumn != null) - { - throw new ArgumentException("Record definition cannot have more than one key property."); - } - keyColumn = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; - queryColumns.Add($"\"{keyColumn}\""); - } - else if (property is VectorStoreRecordDataProperty dataProperty) - { - string columnName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; - queryColumns.Add($"\"{columnName}\""); - } - else if (property is VectorStoreRecordVectorProperty vectorProperty && includeVectors) - { - string columnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; - queryColumns.Add($"\"{columnName}\""); - } + queryColumns.Add($"\"{property.StorageName}\""); } - Verify.NotNull(keyColumn, "Record definition must have a key property."); - var queryColumnList = string.Join(", ", queryColumns); return new PostgresSqlCommandInfo( commandText: $""" SELECT {queryColumnList} FROM {schema}."{tableName}" -WHERE "{keyColumn}" = ${1}; +WHERE "{model.KeyProperty.StorageName}" = ${1}; """, parameters: [new NpgsqlParameter() { Value = key }] ); } /// - public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, IReadOnlyList properties, List keys, bool includeVectors = false) + internal static PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, VectorStoreRecordModel model, List keys, bool includeVectors = false) where TKey : notnull { NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); - var keyProperty = properties.OfType().FirstOrDefault() ?? throw new ArgumentException("Properties must contain a key property", nameof(properties)); - var keyColumn = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; - // Generate the column names - var columns = properties - .Where(p => includeVectors || p is not VectorStoreRecordVectorProperty) - .Select(p => p.StoragePropertyName ?? p.DataModelPropertyName) + var columns = model.Properties + .Where(p => includeVectors || p is not VectorStoreRecordVectorPropertyModel) + .Select(p => p.StorageName) .ToList(); var columnNames = string.Join(", ", columns.Select(c => $"\"{c}\"")); @@ -297,7 +246,7 @@ public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string t var commandText = $""" SELECT {columnNames} FROM {schema}."{tableName}" -WHERE "{keyColumn}" = ANY($1); +WHERE "{model.KeyProperty.StorageName}" = ANY($1); """; return new PostgresSqlCommandInfo(commandText) @@ -307,7 +256,7 @@ public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string t } /// - public PostgresSqlCommandInfo BuildDeleteCommand(string schema, string tableName, string keyColumn, TKey key) + internal static PostgresSqlCommandInfo BuildDeleteCommand(string schema, string tableName, string keyColumn, TKey key) { return new PostgresSqlCommandInfo( commandText: $""" @@ -319,7 +268,7 @@ DELETE FROM {schema}."{tableName}" } /// - public PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, string tableName, string keyColumn, List keys) + internal static PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, string tableName, string keyColumn, List keys) { NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); @@ -344,15 +293,11 @@ DELETE FROM {schema}."{tableName}" #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// - public PostgresSqlCommandInfo BuildGetNearestMatchCommand( - string schema, string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, + internal static PostgresSqlCommandInfo BuildGetNearestMatchCommand( + string schema, string tableName, VectorStoreRecordModel model, VectorStoreRecordVectorPropertyModel vectorProperty, Vector vectorValue, VectorSearchFilter? legacyFilter, Expression>? newFilter, int? skip, bool includeVectors, int limit) { - var columns = string.Join(" ,", - propertyReader.RecordDefinition.Properties - .Select(property => property.StoragePropertyName ?? property.DataModelPropertyName) - .Select(column => $"\"{column}\"") - ); + var columns = string.Join(" ,", model.Properties.Select(property => $"\"{property.StorageName}\"")); var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; var distanceOp = distanceFunction switch @@ -366,15 +311,15 @@ public PostgresSqlCommandInfo BuildGetNearestMatchCommand( _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") }; - var vectorColumn = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + var vectorColumn = vectorProperty.StorageName; // Start where clause params at 2, vector takes param 1. #pragma warning disable CS0618 // VectorSearchFilter is obsolete var (where, parameters) = (oldFilter: legacyFilter, newFilter) switch { (not null, not null) => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - (not null, null) => GenerateLegacyFilterWhereClause(schema, tableName, propertyReader.RecordDefinition.Properties, legacyFilter, startParamIndex: 2), - (null, not null) => GenerateNewFilterWhereClause(propertyReader, newFilter), + (not null, null) => GenerateLegacyFilterWhereClause(schema, tableName, model, legacyFilter, startParamIndex: 2), + (null, not null) => GenerateNewFilterWhereClause(model, newFilter, startParamIndex: 2), _ => (Clause: string.Empty, Parameters: []) }; #pragma warning restore CS0618 // VectorSearchFilter is obsolete @@ -416,15 +361,60 @@ ORDER BY {PostgresConstants.DistanceColumnName} }; } - internal static (string Clause, List Parameters) GenerateNewFilterWhereClause(VectorStoreRecordPropertyReader propertyReader, LambdaExpression newFilter) + internal static PostgresSqlCommandInfo BuildSelectWhereCommand( + string schema, string tableName, VectorStoreRecordModel model, + Expression> filter, int top, GetFilteredRecordOptions options) + { + StringBuilder query = new(200); + query.Append("SELECT "); + foreach (var property in model.Properties) + { + if (options.IncludeVectors || property is not VectorStoreRecordVectorPropertyModel) + { + query.AppendFormat("\"{0}\",", property.StorageName); + } + } + query.Length--; // Remove trailing comma + query.AppendLine(); + query.AppendFormat("FROM {0}.\"{1}\"", schema, tableName).AppendLine(); + + PostgresFilterTranslator translator = new(model, filter, startParamIndex: 1, query); + translator.Translate(appendWhere: true); + query.AppendLine(); + + if (options.OrderBy.Values.Count > 0) + { + query.Append("ORDER BY "); + + foreach (var sortInfo in options.OrderBy.Values) + { + query.AppendFormat("\"{0}\" {1},", + model.GetDataOrKeyProperty(sortInfo.PropertySelector).StorageName, + sortInfo.Ascending ? "ASC" : "DESC"); + } + + query.Length--; // remove the last comma + query.AppendLine(); + } + + query.AppendFormat("OFFSET {0}", options.Skip).AppendLine(); + query.AppendFormat("LIMIT {0}", top).AppendLine(); + + return new PostgresSqlCommandInfo(query.ToString()) + { + Parameters = translator.ParameterValues.Select(p => new NpgsqlParameter { Value = p }).ToList() + }; + } + + internal static (string Clause, List Parameters) GenerateNewFilterWhereClause(VectorStoreRecordModel model, LambdaExpression newFilter, int startParamIndex) { - PostgresFilterTranslator translator = new(propertyReader.StoragePropertyNamesMap, newFilter, startParamIndex: 2); + PostgresFilterTranslator translator = new(model, newFilter, startParamIndex); translator.Translate(appendWhere: true); return (translator.Clause.ToString(), translator.ParameterValues); } #pragma warning disable CS0618 // VectorSearchFilter is obsolete - internal static (string Clause, List Parameters) GenerateLegacyFilterWhereClause(string schema, string tableName, IReadOnlyList properties, VectorSearchFilter legacyFilter, int startParamIndex) + internal static (string Clause, List Parameters) GenerateLegacyFilterWhereClause(string schema, string tableName, VectorStoreRecordModel model, VectorSearchFilter legacyFilter, int startParamIndex) { var whereClause = new StringBuilder("WHERE "); var filterClauses = new List(); @@ -436,26 +426,24 @@ internal static (string Clause, List Parameters) GenerateLegacyFilterWhe { if (filterClause is EqualToFilterClause equalTo) { - var property = properties.FirstOrDefault(p => p.DataModelPropertyName == equalTo.FieldName); + var property = model.Properties.FirstOrDefault(p => p.ModelName == equalTo.FieldName); if (property == null) { throw new ArgumentException($"Property {equalTo.FieldName} not found in record definition."); } - var columnName = property.StoragePropertyName ?? property.DataModelPropertyName; - filterClauses.Add($"\"{columnName}\" = ${paramIndex}"); + filterClauses.Add($"\"{property.StorageName}\" = ${paramIndex}"); parameters.Add(equalTo.Value); paramIndex++; } else if (filterClause is AnyTagEqualToFilterClause anyTagEqualTo) { - var property = properties.FirstOrDefault(p => p.DataModelPropertyName == anyTagEqualTo.FieldName); + var property = model.Properties.FirstOrDefault(p => p.ModelName == anyTagEqualTo.FieldName); if (property == null) { throw new ArgumentException($"Property {anyTagEqualTo.FieldName} not found in record definition."); } - if (property.PropertyType != typeof(List)) + if (property.Type != typeof(List)) { throw new ArgumentException($"Property {anyTagEqualTo.FieldName} must be of type List to use AnyTagEqualTo filter."); } - var columnName = property.StoragePropertyName ?? property.DataModelPropertyName; - filterClauses.Add($"\"{columnName}\" @> ARRAY[${paramIndex}::TEXT]"); + filterClauses.Add($"\"{property.StorageName}\" @> ARRAY[${paramIndex}::TEXT]"); parameters.Add(anyTagEqualTo.Value); paramIndex++; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs index 0f61e692ae7f..18b72eb89376 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.VectorData; using Npgsql; @@ -11,12 +12,18 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// /// Represents a vector store implementation using PostgreSQL. /// -public class PostgresVectorStore : IVectorStore +public sealed class PostgresVectorStore : IVectorStore { private readonly IPostgresVectorStoreDbClient _postgresClient; private readonly NpgsqlDataSource? _dataSource; private readonly PostgresVectorStoreOptions _options; + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; + + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreRecordDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreRecordKeyProperty("Key", typeof(string))] }; + /// /// Initializes a new instance of the class. /// @@ -27,6 +34,12 @@ public PostgresVectorStore(NpgsqlDataSource dataSource, PostgresVectorStoreOptio this._dataSource = dataSource; this._options = options ?? new PostgresVectorStoreOptions(); this._postgresClient = new PostgresVectorStoreDbClient(this._dataSource, this._options.Schema); + + this._metadata = new() + { + VectorStoreSystemName = PostgresConstants.VectorStoreSystemName, + VectorStoreName = this._postgresClient.DatabaseName + }; } /// @@ -38,27 +51,29 @@ internal PostgresVectorStore(IPostgresVectorStoreDbClient postgresDbClient, Post { this._postgresClient = postgresDbClient; this._options = options ?? new PostgresVectorStoreOptions(); + + this._metadata = new() + { + VectorStoreSystemName = PostgresConstants.VectorStoreSystemName, + VectorStoreName = this._postgresClient.DatabaseName + }; } /// - public virtual IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) + public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) { - const string OperationName = "ListCollectionNames"; return PostgresVectorStoreUtils.WrapAsyncEnumerableAsync( this._postgresClient.GetTablesAsync(cancellationToken), - OperationName + "ListCollectionNames", + this._metadata.VectorStoreName ); } /// - public virtual IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + where TRecord : notnull { - if (!PostgresConstants.SupportedKeyTypes.Contains(typeof(TKey))) - { - throw new NotSupportedException($"Unsupported key type: {typeof(TKey)}"); - } - #pragma warning disable CS0618 // IPostgresVectorStoreRecordCollectionFactory is obsolete if (this._options.VectorStoreCollectionFactory is not null) { @@ -69,9 +84,41 @@ public virtual IVectorStoreRecordCollection GetCollection( this._postgresClient, name, - new PostgresVectorStoreRecordCollectionOptions() { Schema = this._options.Schema, VectorStoreRecordDefinition = vectorStoreRecordDefinition } + new PostgresVectorStoreRecordCollectionOptions() + { + Schema = this._options.Schema, + VectorStoreRecordDefinition = vectorStoreRecordDefinition, + EmbeddingGenerator = this._options.EmbeddingGenerator, + } ); return recordCollection as IVectorStoreRecordCollection ?? throw new InvalidOperationException("Failed to cast record collection."); } + + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType == typeof(NpgsqlDataSource) ? this._dataSource : + serviceType.IsInstanceOfType(this) ? this : + null; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index 07c228540038..4031f913233a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -8,6 +8,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Npgsql; using Pgvector; @@ -17,7 +18,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// An implementation of a client for Postgres. This class is used to managing postgres database operations. /// /// -/// Initializes a new instance of the class. +/// Initializes a new instance of the class. /// /// Postgres data source. /// Schema of collection tables. @@ -26,10 +27,12 @@ internal class PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string s { private readonly string _schema = schema; - private IPostgresVectorStoreCollectionSqlBuilder _sqlBuilder = new PostgresVectorStoreCollectionSqlBuilder(); + private readonly NpgsqlConnectionStringBuilder _connectionStringBuilder = new(dataSource.ConnectionString); public NpgsqlDataSource DataSource { get; } = dataSource; + public string? DatabaseName => this._connectionStringBuilder.Database; + /// public async Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default) { @@ -37,7 +40,7 @@ public async Task DoesTableExistsAsync(string tableName, CancellationToken await using (connection) { - var commandInfo = this._sqlBuilder.BuildDoesTableExistCommand(this._schema, tableName); + var commandInfo = PostgresSqlBuilder.BuildDoesTableExistCommand(this._schema, tableName); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) @@ -56,7 +59,7 @@ public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] Ca await using (connection) { - var commandInfo = this._sqlBuilder.BuildGetTablesCommand(this._schema); + var commandInfo = PostgresSqlBuilder.BuildGetTablesCommand(this._schema); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) @@ -67,22 +70,21 @@ public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] Ca } /// - public async Task CreateTableAsync(string tableName, IReadOnlyList properties, bool ifNotExists = true, CancellationToken cancellationToken = default) + public async Task CreateTableAsync(string tableName, VectorStoreRecordModel model, bool ifNotExists = true, CancellationToken cancellationToken = default) { // Prepare the SQL commands. - var commandInfo = this._sqlBuilder.BuildCreateTableCommand(this._schema, tableName, properties, ifNotExists); + var commandInfo = PostgresSqlBuilder.BuildCreateTableCommand(this._schema, tableName, model, ifNotExists); var createIndexCommands = - PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo(properties) + PostgresVectorStoreRecordPropertyMapping.GetIndexInfo(model.Properties) .Select(index => - this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, index.column, index.kind, index.function, ifNotExists) - ); + PostgresSqlBuilder.BuildCreateIndexCommand(this._schema, tableName, index.column, index.kind, index.function, index.isVector, ifNotExists)); // Execute the commands in a transaction. NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { -#if !NETSTANDARD2_0 +#if NET8_0_OR_GREATER var transaction = await connection.BeginTransactionAsync(cancellationToken).ConfigureAwait(false); await using (transaction) #else @@ -99,7 +101,7 @@ public async Task CreateTableAsync(string tableName, IReadOnlyList public async Task DeleteTableAsync(string tableName, CancellationToken cancellationToken = default) { - var commandInfo = this._sqlBuilder.BuildDropTableCommand(this._schema, tableName); + var commandInfo = PostgresSqlBuilder.BuildDropTableCommand(this._schema, tableName); await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } /// public async Task UpsertAsync(string tableName, Dictionary row, string keyColumn, CancellationToken cancellationToken = default) { - var commandInfo = this._sqlBuilder.BuildUpsertCommand(this._schema, tableName, keyColumn, row); + var commandInfo = PostgresSqlBuilder.BuildUpsertCommand(this._schema, tableName, keyColumn, row); await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } /// public async Task UpsertBatchAsync(string tableName, IEnumerable> rows, string keyColumn, CancellationToken cancellationToken = default) { - var commandInfo = this._sqlBuilder.BuildUpsertBatchCommand(this._schema, tableName, keyColumn, rows.ToList()); + var commandInfo = PostgresSqlBuilder.BuildUpsertBatchCommand(this._schema, tableName, keyColumn, rows.ToList()); await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } /// - public async Task?> GetAsync(string tableName, TKey key, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull + public async Task?> GetAsync(string tableName, TKey key, VectorStoreRecordModel model, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull { NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { - var commandInfo = this._sqlBuilder.BuildGetCommand(this._schema, tableName, properties, key, includeVectors); + var commandInfo = PostgresSqlBuilder.BuildGetCommand(this._schema, tableName, model, key, includeVectors); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - return this.GetRecord(dataReader, properties, includeVectors); + return this.GetRecord(dataReader, model, includeVectors); } return null; @@ -149,7 +151,7 @@ public async Task UpsertBatchAsync(string tableName, IEnumerable - public async IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, IReadOnlyList properties, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, VectorStoreRecordModel model, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) where TKey : notnull { Verify.NotNull(keys); @@ -164,12 +166,12 @@ public async Task UpsertBatchAsync(string tableName, IEnumerable public async Task DeleteAsync(string tableName, string keyColumn, TKey key, CancellationToken cancellationToken = default) { - var commandInfo = this._sqlBuilder.BuildDeleteCommand(this._schema, tableName, keyColumn, key); + var commandInfo = PostgresSqlBuilder.BuildDeleteCommand(this._schema, tableName, keyColumn, key); await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } /// -#pragma warning disable CS0618 // VectorSearchFilter is obsolete public async IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync( - string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, - VectorSearchFilter? legacyFilter = default, Expression>? newFilter = default, int? skip = default, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) -#pragma warning restore CS0618 // VectorSearchFilter is obsolete + string tableName, VectorStoreRecordModel model, VectorStoreRecordVectorPropertyModel vectorProperty, Vector vectorValue, int limit, + VectorSearchOptions options, [EnumeratorCancellation] CancellationToken cancellationToken = default) { NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { - var commandInfo = this._sqlBuilder.BuildGetNearestMatchCommand(this._schema, tableName, propertyReader, vectorProperty, vectorValue, legacyFilter, newFilter, skip, includeVectors, limit); + var commandInfo = PostgresSqlBuilder.BuildGetNearestMatchCommand(this._schema, tableName, model, vectorProperty, vectorValue, +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + options.OldFilter, +#pragma warning restore CS0618 // VectorSearchFilter is obsolete + options.Filter, options.Skip, options.IncludeVectors, limit); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { var distance = dataReader.GetDouble(dataReader.GetOrdinal(PostgresConstants.DistanceColumnName)); - yield return (Row: this.GetRecord(dataReader, propertyReader.RecordDefinition.Properties, includeVectors), Distance: distance); + yield return (Row: this.GetRecord(dataReader, model, options.IncludeVectors), Distance: distance); + } + } + } + + public async IAsyncEnumerable> GetMatchingRecordsAsync(string tableName, VectorStoreRecordModel model, + Expression> filter, int top, GetFilteredRecordOptions options, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = PostgresSqlBuilder.BuildSelectWhereCommand(this._schema, tableName, model, filter, top, options); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return this.GetRecord(dataReader, model, options.IncludeVectors); } } } @@ -212,42 +234,25 @@ public async Task DeleteBatchAsync(string tableName, string keyColumn, IEn return; } - var commandInfo = this._sqlBuilder.BuildDeleteBatchCommand(this._schema, tableName, keyColumn, listOfKeys); + var commandInfo = PostgresSqlBuilder.BuildDeleteBatchCommand(this._schema, tableName, keyColumn, listOfKeys); await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } - #region internal =============================================================================== - - /// - /// Sets the SQL builder for the client. - /// - /// - /// - /// This method is used for other Semnatic Kernel connectors that may need to override the default SQL - /// used by this client. - /// - internal void SetSqlBuilder(IPostgresVectorStoreCollectionSqlBuilder sqlBuilder) - { - this._sqlBuilder = sqlBuilder; - } - - #endregion - #region private ================================================================================ private Dictionary GetRecord( NpgsqlDataReader reader, - IEnumerable properties, + VectorStoreRecordModel model, bool includeVectors = false ) { var storageModel = new Dictionary(); - foreach (var property in properties) + foreach (var property in model.Properties) { - var isEmbedding = property is VectorStoreRecordVectorProperty; - var propertyName = property.StoragePropertyName ?? property.DataModelPropertyName; - var propertyType = property.PropertyType; + var isEmbedding = property is VectorStoreRecordVectorPropertyModel; + var propertyName = property.StorageName; + var propertyType = property.Type; var propertyValue = !isEmbedding || includeVectors ? PostgresVectorStoreRecordPropertyMapping.GetPropertyValue(reader, propertyName, propertyType) : null; storageModel.Add(propertyName, propertyValue); diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs index 5add40eed8ee..f96926143126 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -14,6 +15,11 @@ public sealed class PostgresVectorStoreOptions /// public string Schema { get; init; } = "public"; + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } + /// /// An optional factory to use for constructing instances, if a custom record collection is required. /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index ce619398bf99..2c3bda1627d6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -3,10 +3,14 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; using Npgsql; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -17,12 +21,16 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// The type of the key. /// The type of the record. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class PostgresVectorStoreRecordCollection : IVectorStoreRecordCollection +public sealed class PostgresVectorStoreRecordCollection : IVectorStoreRecordCollection #pragma warning restore CA1711 // Identifiers should not have incorrect suffix where TKey : notnull + where TRecord : notnull { /// - public string CollectionName { get; } + public string Name { get; } + + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; /// Postgres client that is used to interact with the database. private readonly IPostgresVectorStoreDbClient _client; @@ -30,11 +38,11 @@ public class PostgresVectorStoreRecordCollection : IVectorStoreRe // Optional configuration options for this class. private readonly PostgresVectorStoreRecordCollectionOptions _options; - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; + /// The model for this collection. + private readonly VectorStoreRecordModel _model; /// A mapper to use for converting between the data model and the Azure AI Search record. - private readonly IVectorStoreRecordMapper> _mapper; + private readonly PostgresVectorStoreRecordMapper _mapper; /// The default options for vector search. private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); @@ -43,10 +51,10 @@ public class PostgresVectorStoreRecordCollection : IVectorStoreRe /// Initializes a new instance of the class. /// /// The data source to use for connecting to the database. - /// The name of the collection. + /// The name of the collection. /// Optional configuration options for this class. - public PostgresVectorStoreRecordCollection(NpgsqlDataSource dataSource, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default) - : this(new PostgresVectorStoreDbClient(dataSource), collectionName, options) + public PostgresVectorStoreRecordCollection(NpgsqlDataSource dataSource, string name, PostgresVectorStoreRecordCollectionOptions? options = default) + : this(new PostgresVectorStoreDbClient(dataSource), name, options) { } @@ -54,67 +62,46 @@ public PostgresVectorStoreRecordCollection(NpgsqlDataSource dataSource, string c /// Initializes a new instance of the class. /// /// The client to use for interacting with the database. - /// The name of the collection. + /// The name of the collection. /// Optional configuration options for this class. /// /// This constructor is internal. It allows internal code to create an instance of this class with a custom client. /// - internal PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default) + internal PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string name, PostgresVectorStoreRecordCollectionOptions? options = default) { // Verify. Verify.NotNull(client); - Verify.NotNullOrWhiteSpace(collectionName); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.DictionaryCustomMapper is not null, PostgresConstants.SupportedKeyTypes); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + Verify.NotNullOrWhiteSpace(name); // Assign. this._client = client; - this.CollectionName = collectionName; + this.Name = name; this._options = options ?? new PostgresVectorStoreRecordCollectionOptions(); - this._propertyReader = new VectorStoreRecordPropertyReader( - typeof(TRecord), - this._options.VectorStoreRecordDefinition, - new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true, - }); - - // Validate property types. - this._propertyReader.VerifyKeyProperties(PostgresConstants.SupportedKeyTypes); - this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); - this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); - - // Resolve mapper. - // First, if someone has provided a custom mapper, use that. - // If they didn't provide a custom mapper, and the record type is the generic data model, use the built in mapper for that. - // Otherwise, use our own default mapper implementation for all other data models. - if (this._options.DictionaryCustomMapper is not null) - { - this._mapper = this._options.DictionaryCustomMapper; - } - else if (typeof(TRecord).IsGenericType && typeof(TRecord).GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>)) - { - this._mapper = (new PostgresGenericDataModelMapper(this._propertyReader) as IVectorStoreRecordMapper>)!; - } - else + + this._model = new VectorStoreRecordModelBuilder(PostgresConstants.ModelBuildingOptions) + .Build(typeof(TRecord), options?.VectorStoreRecordDefinition, options?.EmbeddingGenerator); + + this._mapper = new PostgresVectorStoreRecordMapper(this._model); + + this._collectionMetadata = new() { - this._mapper = new PostgresVectorStoreRecordMapper(this._propertyReader); - } + VectorStoreSystemName = PostgresConstants.VectorStoreSystemName, + VectorStoreName = this._client.DatabaseName, + CollectionName = name + }; } /// - public virtual Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) { const string OperationName = "DoesTableExists"; return this.RunOperationAsync(OperationName, () => - this._client.DoesTableExistsAsync(this.CollectionName, cancellationToken) + this._client.DoesTableExistsAsync(this.Name, cancellationToken) ); } /// - public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) { const string OperationName = "CreateCollection"; return this.RunOperationAsync(OperationName, () => @@ -123,7 +110,7 @@ public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = } /// - public virtual Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { const string OperationName = "CreateCollectionIfNotExists"; return this.RunOperationAsync(OperationName, () => @@ -132,68 +119,138 @@ public virtual Task CreateCollectionIfNotExistsAsync(CancellationToken cancellat } /// - public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { const string OperationName = "DeleteCollection"; return this.RunOperationAsync(OperationName, () => - this._client.DeleteTableAsync(this.CollectionName, cancellationToken) + this._client.DeleteTableAsync(this.Name, cancellationToken) ); } /// - public virtual Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { const string OperationName = "Upsert"; + IReadOnlyList?[]? generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; + + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbedding, ReadOnlyMemory>(record, cancellationToken, out var floatTask)) + { + generatedEmbeddings ??= new IReadOnlyList?[vectorPropertyCount]; + generatedEmbeddings[i] = [await floatTask.ConfigureAwait(false)]; + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + var storageModel = VectorStoreErrorHandler.RunModelConversion( - PostgresConstants.DatabaseName, - this.CollectionName, + PostgresConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, - () => this._mapper.MapFromDataToStorageModel(record)); + () => this._mapper.MapFromDataToStorageModel(record, recordIndex: 0, generatedEmbeddings)); Verify.NotNull(storageModel); - var keyObj = storageModel[this._propertyReader.KeyPropertyStoragePropertyName]; + var keyObj = storageModel[this._model.KeyProperty.StorageName]; Verify.NotNull(keyObj); TKey key = (TKey)keyObj!; - return this.RunOperationAsync(OperationName, async () => + return await this.RunOperationAsync(OperationName, async () => { - await this._client.UpsertAsync(this.CollectionName, storageModel, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); + await this._client.UpsertAsync(this.Name, storageModel, this._model.KeyProperty.StorageName, cancellationToken).ConfigureAwait(false); return key; - } - ); + }).ConfigureAwait(false); } /// - public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { Verify.NotNull(records); const string OperationName = "UpsertBatch"; - var storageModels = records.Select(record => VectorStoreErrorHandler.RunModelConversion( - PostgresConstants.DatabaseName, - this.CollectionName, + IReadOnlyList? recordsList = null; + + // If an embedding generator is defined, invoke it once per property for all records. + IReadOnlyList?[]? generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; + + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + + // We have a property with embedding generation; materialize the records' enumerable if needed, to + // prevent multiple enumeration. + if (recordsList is null) + { + recordsList = records is IReadOnlyList r ? r : records.ToList(); + + if (recordsList.Count == 0) + { + return []; + } + + records = recordsList; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbeddings, ReadOnlyMemory>(records, cancellationToken, out var floatTask)) + { + generatedEmbeddings ??= new IReadOnlyList?[vectorPropertyCount]; + generatedEmbeddings[i] = (IReadOnlyList>)await floatTask.ConfigureAwait(false); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + + var storageModels = VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, - () => this._mapper.MapFromDataToStorageModel(record))).ToList(); + () => records.Select((r, i) => this._mapper.MapFromDataToStorageModel(r, i, generatedEmbeddings)).ToList()); if (storageModels.Count == 0) { - yield break; + return []; } - var keys = storageModels.Select(model => model[this._propertyReader.KeyPropertyStoragePropertyName]!).ToList(); + var keys = storageModels.Select(model => model[this._model.KeyProperty.StorageName]!).ToList(); await this.RunOperationAsync(OperationName, () => - this._client.UpsertBatchAsync(this.CollectionName, storageModels, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken) + this._client.UpsertBatchAsync(this.Name, storageModels, this._model.KeyProperty.StorageName, cancellationToken) ).ConfigureAwait(false); - foreach (var key in keys) { yield return (TKey)key!; } + return keys.Select(key => (TKey)key!).ToList(); } /// - public virtual Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "Get"; @@ -201,21 +258,27 @@ await this.RunOperationAsync(OperationName, () => bool includeVectors = options?.IncludeVectors is true; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + return this.RunOperationAsync(OperationName, async () => { - var row = await this._client.GetAsync(this.CollectionName, key, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken).ConfigureAwait(false); + var row = await this._client.GetAsync(this.Name, key, this._model, includeVectors, cancellationToken).ConfigureAwait(false); if (row is null) { return default; } return VectorStoreErrorHandler.RunModelConversion( - PostgresConstants.DatabaseName, - this.CollectionName, + PostgresConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })); }); } /// - public virtual IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable GetAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "GetBatch"; @@ -223,47 +286,111 @@ public virtual IAsyncEnumerable GetBatchAsync(IEnumerable keys, G bool includeVectors = options?.IncludeVectors is true; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + return PostgresVectorStoreUtils.WrapAsyncEnumerableAsync( - this._client.GetBatchAsync(this.CollectionName, keys, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken) + this._client.GetBatchAsync(this.Name, keys, this._model, includeVectors, cancellationToken) .SelectAsync(row => VectorStoreErrorHandler.RunModelConversion( - PostgresConstants.DatabaseName, - this.CollectionName, + PostgresConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })), cancellationToken ), OperationName, - this.CollectionName + this._collectionMetadata.VectorStoreName, + this.Name ); } /// - public virtual Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) + public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { const string OperationName = "Delete"; return this.RunOperationAsync(OperationName, () => - this._client.DeleteAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, key, cancellationToken) + this._client.DeleteAsync(this.Name, this._model.KeyProperty.StorageName, key, cancellationToken) ); } /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { Verify.NotNull(keys); const string OperationName = "DeleteBatch"; return this.RunOperationAsync(OperationName, () => - this._client.DeleteBatchAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, keys, cancellationToken) + this._client.DeleteBatchAsync(this.Name, this._model.KeyProperty.StorageName, keys, cancellationToken) ); } + #region Search + /// - public virtual Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull { - const string OperationName = "VectorizedSearch"; + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + switch (vectorProperty.EmbeddingGenerator) + { + case IEmbeddingGenerator> generator: + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + + // TODO: Implement support for Half, binary, sparse embeddings (#11083) + + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + default: + throw new InvalidOperationException( + PostgresConstants.SupportedVectorTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); + } + } + + /// + public IAsyncEnumerable> SearchEmbeddingAsync( + TVector vector, + int top, + VectorSearchOptions? options = null, + CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + VectorSearchOptions options, + CancellationToken cancellationToken = default) + where TVector : notnull + { Verify.NotNull(vector); + Verify.NotLessThan(top, 1); var vectorType = vector.GetType(); @@ -274,8 +401,10 @@ public virtual Task> VectorizedSearchAsync $"Supported types are: {string.Join(", ", PostgresConstants.SupportedVectorTypes.Select(l => l.FullName))}"); } - var searchOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(searchOptions); + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } var pgVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); @@ -283,43 +412,80 @@ public virtual Task> VectorizedSearchAsync // Simulating skip/offset logic locally, since OFFSET can work only with LIMIT in combination // and LIMIT is not supported in vector search extension, instead of LIMIT - "k" parameter is used. - var limit = searchOptions.Top + searchOptions.Skip; + var limit = top + options.Skip; - return this.RunOperationAsync(OperationName, () => - { - var results = this._client.GetNearestMatchesAsync( - this.CollectionName, - this._propertyReader, - vectorProperty, - pgVector, - searchOptions.Top, -#pragma warning disable CS0618 // VectorSearchFilter is obsolete - searchOptions.OldFilter, -#pragma warning restore CS0618 // VectorSearchFilter is obsolete - searchOptions.Filter, - searchOptions.Skip, - searchOptions.IncludeVectors, - cancellationToken) - .SelectAsync(result => + StorageToDataModelMapperOptions mapperOptions = new() { IncludeVectors = options.IncludeVectors }; + + return PostgresVectorStoreUtils.WrapAsyncEnumerableAsync( + this._client.GetNearestMatchesAsync(this.Name, this._model, vectorProperty, pgVector, top, options, cancellationToken) + .SelectAsync(result => { var record = VectorStoreErrorHandler.RunModelConversion( - PostgresConstants.DatabaseName, - this.CollectionName, - OperationName, - () => this._mapper.MapFromStorageToDataModel( - result.Row, new StorageToDataModelMapperOptions() { IncludeVectors = searchOptions.IncludeVectors }) - ); + PostgresConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + operationName, + () => this._mapper.MapFromStorageToDataModel(result.Row, mapperOptions)); return new VectorSearchResult(record, result.Distance); - }, cancellationToken); + }, cancellationToken), + operationName, + this._collectionMetadata.VectorStoreName, + this.Name + ); + } - return Task.FromResult(new VectorSearchResults(results)); - }); + /// + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); + + #endregion Search + + /// + public IAsyncEnumerable GetAsync(Expression> filter, int top, + GetFilteredRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); + + options ??= new(); + + StorageToDataModelMapperOptions mapperOptions = new() { IncludeVectors = options.IncludeVectors }; + + return PostgresVectorStoreUtils.WrapAsyncEnumerableAsync( + this._client.GetMatchingRecordsAsync(this.Name, this._model, filter, top, options, cancellationToken) + .SelectAsync(dictionary => + { + return VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + "Get", + () => this._mapper.MapFromStorageToDataModel(dictionary, mapperOptions)); + }, cancellationToken), + "Get", + this._collectionMetadata.VectorStoreName, + this.Name); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType == typeof(NpgsqlDataSource) ? this._client.DataSource : + serviceType.IsInstanceOfType(this) ? this : + null; } private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken = default) { - return this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition.Properties, ifNotExists, cancellationToken); + return this._client.CreateTableAsync(this.Name, this._model, ifNotExists, cancellationToken); } private async Task RunOperationAsync(string operationName, Func operation) @@ -332,8 +498,9 @@ private async Task RunOperationAsync(string operationName, Func operation) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = PostgresConstants.DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = PostgresConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } @@ -349,8 +516,9 @@ private async Task RunOperationAsync(string operationName, Func> o { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = PostgresConstants.DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = PostgresConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs index 753713d21b3f..1e440a53878d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -21,6 +23,7 @@ public sealed class PostgresVectorStoreRecordCollectionOptions /// /// If not set, the default mapper will be used. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper>? DictionaryCustomMapper { get; init; } = null; /// @@ -32,4 +35,9 @@ public sealed class PostgresVectorStoreRecordCollectionOptions /// See , and . /// public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; + + /// + /// Gets or sets the default embedding generator for vector properties in this collection. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs index e656678413cc..7dc507badd09 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs @@ -2,7 +2,10 @@ using System; using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -10,52 +13,37 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// A mapper class that handles the conversion between data models and storage models for Postgres vector store. /// /// The type of the data model record. -internal sealed class PostgresVectorStoreRecordMapper : IVectorStoreRecordMapper> +internal sealed class PostgresVectorStoreRecordMapper(VectorStoreRecordModel model) + where TRecord : notnull { - /// with helpers for reading vector store model properties and their attributes. - private readonly VectorStoreRecordPropertyReader _propertyReader; - - /// - /// Initializes a new instance of the class. - /// - /// A that defines the schema of the data in the database. - public PostgresVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyReader) + public Dictionary MapFromDataToStorageModel(TRecord dataModel, int recordIndex, IReadOnlyList?[]? generatedEmbeddings) { - Verify.NotNull(propertyReader); + var keyProperty = model.KeyProperty; - this._propertyReader = propertyReader; - - this._propertyReader.VerifyHasParameterlessConstructor(); - - // Validate property types. - this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); - this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); - } - - public Dictionary MapFromDataToStorageModel(TRecord dataModel) - { var properties = new Dictionary { - // Add key property - { this._propertyReader.KeyPropertyStoragePropertyName, this._propertyReader.KeyPropertyInfo.GetValue(dataModel) } + { keyProperty.StorageName, keyProperty.GetValueAsObject(dataModel) } }; - // Add data properties - foreach (var property in this._propertyReader.DataPropertiesInfo) + foreach (var property in model.DataProperties) { - properties.Add( - this._propertyReader.GetStoragePropertyName(property.Name), - property.GetValue(dataModel) - ); + properties.Add(property.StorageName, property.GetValueAsObject(dataModel)); } - // Add vector properties - foreach (var property in this._propertyReader.VectorPropertiesInfo) + for (var i = 0; i < model.VectorProperties.Count; i++) { - var propertyValue = property.GetValue(dataModel); - var result = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(propertyValue); + var property = model.VectorProperties[i]; - properties.Add(this._propertyReader.GetStoragePropertyName(property.Name), result); + properties.Add( + property.StorageName, + PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel( + generatedEmbeddings?[i] is IReadOnlyList e + ? e[recordIndex] switch + { + Embedding fe => fe.Vector, + _ => throw new UnreachableException() + } + : (ReadOnlyMemory?)property.GetValueAsObject(dataModel!)!)); } return properties; @@ -63,36 +51,38 @@ public PostgresVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyR public TRecord MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) { - var record = (TRecord)this._propertyReader.ParameterLessConstructorInfo.Invoke(null); + var record = model.CreateRecord()!; - // Set key. - var keyPropertyValue = Convert.ChangeType( - storageModel[this._propertyReader.KeyPropertyStoragePropertyName], - this._propertyReader.KeyProperty.PropertyType); + var keyProperty = model.KeyProperty; + var keyPropertyValue = Convert.ChangeType(storageModel[keyProperty.StorageName], keyProperty.Type); + keyProperty.SetValueAsObject(record, keyPropertyValue); - this._propertyReader.KeyPropertyInfo.SetValue(record, keyPropertyValue); - - // Process data properties. - var dataPropertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( - this._propertyReader.DataPropertiesInfo, - this._propertyReader.StoragePropertyNamesMap, - storageModel); - - VectorStoreRecordMapping.SetPropertiesOnRecord(record, dataPropertiesInfoWithValues); + foreach (var dataProperty in model.DataProperties) + { + dataProperty.SetValueAsObject(record, storageModel[dataProperty.StorageName]); + } if (options.IncludeVectors) { - // Process vector properties. - var vectorPropertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( - this._propertyReader.VectorPropertiesInfo, - this._propertyReader.StoragePropertyNamesMap, - storageModel, - (object? vector, Type type) => + foreach (var vectorProperty in model.VectorProperties) + { + switch (storageModel[vectorProperty.StorageName]) { - return PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(vector); - }); + case Pgvector.Vector pgVector: + vectorProperty.SetValueAsObject(record, pgVector.Memory); + continue; + + // TODO: Implement support for Half, binary, sparse embeddings (#11083) + + // TODO: We currently allow round-tripping null for the vector property; this is not supported for most (?) dedicated databases; think about it. + case null: + vectorProperty.SetValueAsObject(record, null); + continue; - VectorStoreRecordMapping.SetPropertiesOnRecord(record, vectorPropertiesInfoWithValues); + case var value: + throw new InvalidOperationException($"Embedding vector read back from PostgreSQL is of type '{value.GetType().Name}' instead of the expected Pgvector.Vector type for property '{vectorProperty.ModelName}'."); + } + } } return record; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs index 5e8509236e31..761c1ea8f21b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -3,9 +3,11 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Runtime.InteropServices; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Npgsql; using NpgsqlTypes; using Pgvector; @@ -14,50 +16,20 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; internal static class PostgresVectorStoreRecordPropertyMapping { - internal static float[] GetOrCreateArray(ReadOnlyMemory memory) => - MemoryMarshal.TryGetArray(memory, out ArraySegment array) && - array.Count == array.Array!.Length ? - array.Array : - memory.ToArray(); - - public static Vector? MapVectorForStorageModel(TVector vector) - { - if (vector == null) + public static Vector? MapVectorForStorageModel(object? vector) + => vector switch { - return null; - } + ReadOnlyMemory floatMemory + => new Pgvector.Vector( + MemoryMarshal.TryGetArray(floatMemory, out ArraySegment segment) && + segment.Count == segment.Array!.Length ? segment.Array : floatMemory.ToArray()), - if (vector is ReadOnlyMemory floatMemory) - { - var vecArray = MemoryMarshal.TryGetArray(floatMemory, out ArraySegment array) && - array.Count == array.Array!.Length ? - array.Array : - floatMemory.ToArray(); - return new Vector(vecArray); - } + // TODO: Implement support for Half, binary, sparse embeddings (#11083) - throw new NotSupportedException($"Mapping for type {typeof(TVector).FullName} to a vector is not supported."); - } + null => null, - public static ReadOnlyMemory? MapVectorForDataModel(object? vector) - { - var pgVector = vector is Vector pgv ? pgv : null; - if (pgVector == null) { return null; } - var vecArray = pgVector.ToArray(); - return vecArray != null && vecArray.Length != 0 ? (ReadOnlyMemory)vecArray : null; - } - - public static TPropertyType? GetPropertyValue(NpgsqlDataReader reader, string propertyName) - { - int propertyIndex = reader.GetOrdinal(propertyName); - - if (reader.IsDBNull(propertyIndex)) - { - return default; - } - - return reader.GetFieldValue(propertyIndex); - } + var value => throw new NotSupportedException($"Mapping for type '{value.GetType().Name}' to a vector is not supported.") + }; public static object? GetPropertyValue(NpgsqlDataReader reader, string propertyName, Type propertyType) { @@ -164,14 +136,9 @@ public static (string PgType, bool IsNullable) GetPostgresTypeName(Type property /// /// The vector property. /// The PostgreSQL vector type name. - public static (string PgType, bool IsNullable) GetPgVectorTypeName(VectorStoreRecordVectorProperty vectorProperty) + public static (string PgType, bool IsNullable) GetPgVectorTypeName(VectorStoreRecordVectorPropertyModel vectorProperty) { - if (vectorProperty.Dimensions <= 0) - { - throw new ArgumentException("Vector property must have a positive number of dimensions."); - } - - return ($"VECTOR({vectorProperty.Dimensions})", Nullable.GetUnderlyingType(vectorProperty.PropertyType) != null); + return ($"VECTOR({vectorProperty.Dimensions})", Nullable.GetUnderlyingType(vectorProperty.EmbeddingType) != null); } public static NpgsqlParameter GetNpgsqlParameter(object? value) @@ -199,42 +166,59 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value) } /// - /// Returns information about vector indexes to create, validating that the dimensions of the vector are supported. + /// Returns information about indexes to create, validating that the dimensions of the vector are supported. /// /// The properties of the vector store record. - /// A list of tuples containing the column name, index kind, and distance function for each vector property. + /// A list of tuples containing the column name, index kind, and distance function for each property. /// /// The default index kind is "Flat", which prevents the creation of an index. /// - public static List<(string column, string kind, string function)> GetVectorIndexInfo(IReadOnlyList properties) + public static List<(string column, string kind, string function, bool isVector)> GetIndexInfo(IReadOnlyList properties) { - var vectorIndexesToCreate = new List<(string column, string kind, string function)>(); + var vectorIndexesToCreate = new List<(string column, string kind, string function, bool isVector)>(); foreach (var property in properties) { - if (property is VectorStoreRecordVectorProperty vectorProperty) + switch (property) { - var vectorColumnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; - var indexKind = vectorProperty.IndexKind ?? PostgresConstants.DefaultIndexKind; - var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; - - // Index kind of "Flat" to prevent the creation of an index. This is the default behavior. - // Otherwise, the index will be created with the specified index kind and distance function, if supported. - if (indexKind != IndexKind.Flat) - { - // Ensure the dimensionality of the vector is supported for indexing. - if (PostgresConstants.IndexMaxDimensions.TryGetValue(indexKind, out int maxDimensions) && vectorProperty.Dimensions > maxDimensions) + case VectorStoreRecordKeyPropertyModel: + // There is no need to create a separate index for the key property. + break; + + case VectorStoreRecordVectorPropertyModel vectorProperty: + var indexKind = vectorProperty.IndexKind ?? PostgresConstants.DefaultIndexKind; + var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; + + // Index kind of "Flat" to prevent the creation of an index. This is the default behavior. + // Otherwise, the index will be created with the specified index kind and distance function, if supported. + if (indexKind != IndexKind.Flat) { - throw new NotSupportedException( - $"The provided vector property {vectorProperty.DataModelPropertyName} has {vectorProperty.Dimensions} dimensions, " + - $"which is not supported by the {indexKind} index. The maximum number of dimensions supported by the {indexKind} index " + - $"is {maxDimensions}. Please reduce the number of dimensions or use a different index." - ); + // Ensure the dimensionality of the vector is supported for indexing. + if (PostgresConstants.IndexMaxDimensions.TryGetValue(indexKind, out int maxDimensions) && vectorProperty.Dimensions > maxDimensions) + { + throw new NotSupportedException( + $"The provided vector property {vectorProperty.ModelName} has {vectorProperty.Dimensions} dimensions, " + + $"which is not supported by the {indexKind} index. The maximum number of dimensions supported by the {indexKind} index " + + $"is {maxDimensions}. Please reduce the number of dimensions or use a different index." + ); + } + + vectorIndexesToCreate.Add((vectorProperty.StorageName, indexKind, distanceFunction, isVector: true)); } - vectorIndexesToCreate.Add((vectorColumnName, indexKind, distanceFunction)); - } + break; + + case VectorStoreRecordDataPropertyModel dataProperty: + if (dataProperty.IsIndexed) + { + vectorIndexesToCreate.Add((dataProperty.StorageName, "", "", isVector: false)); + } + break; + + default: + throw new UnreachableException(); } } + return vectorIndexesToCreate; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs index 27fa7181bdc5..87efc8e547b3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs @@ -17,17 +17,22 @@ internal static class PostgresVectorStoreUtils /// The type of the items in the async enumerable. /// The async enumerable to wrap. /// The name of the operation being performed. + /// The name of the vector store. /// The name of the collection being operated on. /// An async enumerable that will throw a if an exception is thrown while iterating over the original enumerator. - public static async IAsyncEnumerable WrapAsyncEnumerableAsync(IAsyncEnumerable asyncEnumerable, string operationName, string? collectionName = null) + public static async IAsyncEnumerable WrapAsyncEnumerableAsync( + IAsyncEnumerable asyncEnumerable, + string operationName, + string? vectorStoreName = null, + string? collectionName = null) { var enumerator = asyncEnumerable.ConfigureAwait(false).GetAsyncEnumerator(); - var nextResult = await GetNextAsync(enumerator, operationName, collectionName).ConfigureAwait(false); + var nextResult = await GetNextAsync(enumerator, operationName, vectorStoreName, collectionName).ConfigureAwait(false); while (nextResult.more) { yield return nextResult.item; - nextResult = await GetNextAsync(enumerator, operationName, collectionName).ConfigureAwait(false); + nextResult = await GetNextAsync(enumerator, operationName, vectorStoreName, collectionName).ConfigureAwait(false); } } @@ -37,20 +42,26 @@ public static async IAsyncEnumerable WrapAsyncEnumerableAsync(IAsyncEnumer /// /// The enumerator to get the next result from. /// The name of the operation being performed. + /// The name of the vector store. /// The name of the collection being operated on. /// A value indicating whether there are more results and the current string if true. - public static async Task<(T item, bool more)> GetNextAsync(ConfiguredCancelableAsyncEnumerable.Enumerator enumerator, string operationName, string? collectionName = null) + public static async Task<(T item, bool more)> GetNextAsync( + ConfiguredCancelableAsyncEnumerable.Enumerator enumerator, + string operationName, + string? vectorStoreName = null, + string? collectionName = null) { try { var more = await enumerator.MoveNextAsync(); return (enumerator.Current, more); } - catch (Exception ex) + catch (Exception ex) when (ex is not (NotSupportedException or ArgumentException)) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = PostgresConstants.DatabaseName, + VectorStoreSystemName = PostgresConstants.VectorStoreSystemName, + VectorStoreName = vectorStoreName, CollectionName = collectionName, OperationName = operationName }; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md index e9ed71109fbb..4d98642cb02b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md @@ -37,58 +37,6 @@ sk_demo=# CREATE EXTENSION vector; See [this sample](../../../samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs) for an example of using the vector store. -### Using PostgresMemoryStore +For more information on using Postgres as a vector store, see the [PostgresVectorStore](https://learn.microsoft.com/semantic-kernel/concepts/vector-store-connectors/out-of-the-box-connectors/postgres-connector) documentation. -> See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. - -```csharp -NpgsqlDataSourceBuilder dataSourceBuilder = new NpgsqlDataSourceBuilder("Host=localhost;Port=5432;Database=sk_demo;User Id=postgres;Password=mysecretpassword"); -dataSourceBuilder.UseVector(); -NpgsqlDataSource dataSource = dataSourceBuilder.Build(); - -var memoryWithPostgres = new MemoryBuilder() - .WithPostgresMemoryStore(dataSource, vectorSize: 1536/*, schema: "public" */) - .WithLoggerFactory(loggerFactory) - .WithOpenAITextEmbeddingGeneration("text-embedding-ada-002", apiKey) - .Build(); - -var memoryPlugin = kernel.ImportPluginFromObject(new TextMemoryPlugin(memoryWithPostgres)); -``` - -### Create Index - -> By default, pgvector performs exact nearest neighbor search, which provides perfect recall. - -> You can add an index to use approximate nearest neighbor search, which trades some recall for performance. Unlike typical indexes, you will see different results for queries after adding an approximate index. - -> Three keys to achieving good recall are: -> -> - Create the index after the table has some data -> - Choose an appropriate number of lists - a good place to start is rows / 1000 for up to 1M rows and sqrt(rows) for over 1M rows -> - When querying, specify an appropriate number of probes (higher is better for recall, lower is better for speed) - a good place to start is sqrt(lists) - -Please read [the documentation](https://github.com/pgvector/pgvector#indexing) for more information. - -Based on the data rows of your collection table, consider the following statement to create an index. - -```sql -DO $$ -DECLARE - collection TEXT; - c_count INTEGER; -BEGIN - SELECT 'REPLACE YOUR COLLECTION TABLE NAME' INTO collection; - - -- Get count of records in collection - EXECUTE format('SELECT count(*) FROM public.%I;', collection) INTO c_count; - - -- Create Index (https://github.com/pgvector/pgvector#indexing) - IF c_count > 10000000 THEN - EXECUTE format('CREATE INDEX %I ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', - collection || '_ix', collection, ROUND(sqrt(c_count))); - ELSIF c_count > 10000 THEN - EXECUTE format('CREATE INDEX %I ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', - collection || '_ix', collection, c_count / 1000); - END IF; -END $$; -``` +Use the [getting started instructions](https://learn.microsoft.com/semantic-kernel/concepts/vector-store-connectors/?pivots=programming-language-csharp#getting-started-with-vector-store-connectors) on the Microsoft Leearn site to learn more about using the vector store. diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Connectors.Memory.Qdrant.csproj b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Connectors.Memory.Qdrant.csproj index 93425c69fbe9..499c656ad41c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Connectors.Memory.Qdrant.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Connectors.Memory.Qdrant.csproj @@ -4,13 +4,14 @@ Microsoft.SemanticKernel.Connectors.Qdrant $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 preview + @@ -19,16 +20,22 @@ - + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/CreateCollectionRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/CreateCollectionRequest.cs index 6aaab2f26256..1f8e0945bc71 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/CreateCollectionRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/CreateCollectionRequest.cs @@ -1,13 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Diagnostics.CodeAnalysis; using System.Net.Http; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class CreateCollectionRequest { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteCollectionRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteCollectionRequest.cs index fd6df2fe945d..47ff25e408e6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteCollectionRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteCollectionRequest.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; namespace Microsoft.SemanticKernel.Connectors.Qdrant; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class DeleteCollectionRequest { public static DeleteCollectionRequest Create(string collectionName) diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteVectorsRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteVectorsRequest.cs index 6993168b84fe..80a1002e951d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteVectorsRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteVectorsRequest.cs @@ -1,13 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Net.Http; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class DeleteVectorsRequest { [JsonPropertyName("points")] diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteVectorsResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteVectorsResponse.cs index c227f407babf..fe0817bc78b1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteVectorsResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/DeleteVectorsResponse.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; namespace Microsoft.SemanticKernel.Connectors.Qdrant; @@ -8,6 +8,6 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// Empty qdrant response for requests that return nothing but status / error. /// #pragma warning disable CA1812 // Avoid uninstantiated internal classes. Justification: deserialized by QdrantVectorDbClient.DeleteVectorsByIdAsync & QdrantVectorDbClient.DeleteVectorByPayloadIdAsync -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class DeleteVectorsResponse : QdrantResponse; #pragma warning restore CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetCollectionRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetCollectionRequest.cs index 5648ae22212f..6c1b8f701414 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetCollectionRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetCollectionRequest.cs @@ -1,12 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class GetCollectionsRequest { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetVectorsRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetVectorsRequest.cs index 60015e496cb7..455ff62c1d5b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetVectorsRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetVectorsRequest.cs @@ -1,14 +1,14 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net.Http; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class GetVectorsRequest { /// @@ -52,7 +52,13 @@ public static GetVectorsRequest Create(string collectionName) public GetVectorsRequest WithPointId(string pointId) { +#if NET462 + var points = this.PointIds.ToList(); + points.Add(pointId); + this.PointIds = points; +#else this.PointIds = this.PointIds.Append(pointId); +#endif return this; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetVectorsResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetVectorsResponse.cs index 35c0584c73d7..20fb32e7b0db 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetVectorsResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/GetVectorsResponse.cs @@ -2,13 +2,12 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; #pragma warning disable CA1812 // Avoid uninstantiated internal classes: Used for Json Deserialization -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class GetVectorsResponse : QdrantResponse { internal sealed class Record diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/ListCollectionsRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/ListCollectionsRequest.cs index 831f9213f2d7..f577c42a0b8f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/ListCollectionsRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/ListCollectionsRequest.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; namespace Microsoft.SemanticKernel.Connectors.Qdrant; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class ListCollectionsRequest { public static ListCollectionsRequest Create() diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/ListCollectionsResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/ListCollectionsResponse.cs index 250c2b06698e..6a5dda325887 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/ListCollectionsResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/ListCollectionsResponse.cs @@ -1,13 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; #pragma warning disable CA1812 // Avoid uninstantiated internal classes: Used for Json Deserialization -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class ListCollectionsResponse : QdrantResponse { internal sealed class CollectionResult diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/NumberToStringConverter.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/NumberToStringConverter.cs index 5e1223fab1cc..5cb5502d8e7f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/NumberToStringConverter.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/NumberToStringConverter.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Text.Json; using System.Text.Json.Serialization; @@ -9,7 +8,7 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; #pragma warning disable CA1812 // Avoid uninstantiated internal classes: Used for Json Deserialization -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class NumberToStringConverter : JsonConverter { public override string Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/QdrantResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/QdrantResponse.cs index 16717281120d..380f152d45f0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/QdrantResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/QdrantResponse.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; @@ -8,7 +8,7 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// /// Base class for Qdrant response schema. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal abstract class QdrantResponse { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsRequest.cs index 6aaf2645eb34..f7109a8c8a85 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsRequest.cs @@ -2,13 +2,12 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Net.Http; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class SearchVectorsRequest { [JsonPropertyName("vector")] diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsResponse.cs index 81652d032caa..9a0414109aa1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsResponse.cs @@ -2,13 +2,12 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; #pragma warning disable CA1812 // Avoid uninstantiated internal classes: Used for Json Deserialization -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class SearchVectorsResponse : QdrantResponse { internal sealed class ScoredPoint diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/UpsertVectorRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/UpsertVectorRequest.cs index 4b03bbf047e2..6f6661c80902 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/UpsertVectorRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/UpsertVectorRequest.cs @@ -2,13 +2,12 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Net.Http; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class UpsertVectorRequest { public static UpsertVectorRequest Create(string collectionName) diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/UpsertVectorResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/UpsertVectorResponse.cs index 59f61f3ae94b..3caa14c65efe 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/UpsertVectorResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/UpsertVectorResponse.cs @@ -1,12 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; #pragma warning disable CA1812 // Avoid uninstantiated internal classes: Used for Json Deserialization -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] internal sealed class UpsertVectorResponse : QdrantResponse { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorDbClient.cs index aa9ad3f72190..414900d2f2b0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorDbClient.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; @@ -11,7 +10,7 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// /// Interface for a Qdrant vector database client. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] public interface IQdrantVectorDbClient { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorStoreRecordCollectionFactory.cs index 32dd7ed47d91..994e3629e81b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorStoreRecordCollectionFactory.cs @@ -22,5 +22,6 @@ public interface IQdrantVectorStoreRecordCollectionFactory /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. /// The new instance of . IVectorStoreRecordCollection CreateVectorStoreRecordCollection(QdrantClient qdrantClient, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) - where TKey : notnull; + where TKey : notnull + where TRecord : notnull; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/MockableQdrantClient.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/MockableQdrantClient.cs index 8575bc3bef7d..7461b574184e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/MockableQdrantClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/MockableQdrantClient.cs @@ -312,4 +312,23 @@ public virtual Task> QueryAsync( lookupFrom, timeout, cancellationToken); + + public virtual Task ScrollAsync( + string collectionName, + Filter filter, + WithVectorsSelector vectorsSelector, + uint limit = 10, + OrderBy? orderBy = null, + CancellationToken cancellationToken = default) + => this._qdrantClient.ScrollAsync( + collectionName, + filter, + limit, + offset: null, + payloadSelector: null, + vectorsSelector, + readConsistency: null, + shardKeySelector: null, + orderBy, + cancellationToken); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantConstants.cs new file mode 100644 index 000000000000..6e983cb76806 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantConstants.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +internal static class QdrantConstants +{ + internal const string VectorStoreSystemName = "qdrant"; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantDistanceType.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantDistanceType.cs index 16a3f58a7daf..3c948fc84c7c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantDistanceType.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantDistanceType.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Qdrant; @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// The vector distance type used by Qdrant. /// [JsonConverter(typeof(JsonStringEnumConverter))] -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] public enum QdrantDistanceType { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantFilterTranslator.cs index ffd0333f0867..bf88e98b5bc9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantFilterTranslator.cs @@ -7,9 +7,9 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; using Google.Protobuf.Collections; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.ConnectorSupport.Filter; using Qdrant.Client.Grpc; using Range = Qdrant.Client.Grpc.Range; @@ -17,17 +17,20 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; internal class QdrantFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; + private VectorStoreRecordModel _model = null!; private ParameterExpression _recordParameter = null!; - internal Filter Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + internal Filter Translate(LambdaExpression lambdaExpression, VectorStoreRecordModel model) { - this._storagePropertyNames = storagePropertyNames; + this._model = model; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - return this.Translate(lambdaExpression.Body); + var preprocessor = new FilterTranslationPreprocessor { InlineCapturedVariables = true }; + var preprocessedExpression = preprocessor.Visit(lambdaExpression.Body); + + return this.Translate(preprocessedExpression); } private Filter Translate(Expression? node) @@ -46,9 +49,9 @@ private Filter Translate(Expression? node) BinaryExpression { NodeType: ExpressionType.OrElse } orElse => this.TranslateOrElse(orElse.Left, orElse.Right), UnaryExpression { NodeType: ExpressionType.Not } not => this.TranslateNot(not.Operand), - // MemberExpression is generally handled within e.g. TranslateEqual; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) - MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _) - => this.TranslateEqual(member, Expression.Constant(true)), + // Special handling for bool constant as the filter expression (r => r.Bool) + Expression when node.Type == typeof(bool) && this.TryBindProperty(node, out var property) + => this.GenerateEqual(property.StorageName, value: true), MethodCallExpression methodCall => this.TranslateMethodCall(methodCall), @@ -56,53 +59,45 @@ private Filter Translate(Expression? node) }; private Filter TranslateEqual(Expression left, Expression right, bool negated = false) - { - return TryProcessEqual(left, right, out var result) - ? result - : TryProcessEqual(right, left, out result) - ? result - : throw new NotSupportedException("Equality expression not supported by Qdrant"); + => this.TryBindProperty(left, out var property) && right is ConstantExpression { Value: var rightConstant } + ? this.GenerateEqual(property.StorageName, rightConstant, negated) + : this.TryBindProperty(right, out property) && left is ConstantExpression { Value: var leftConstant } + ? this.GenerateEqual(property.StorageName, leftConstant, negated) + : throw new NotSupportedException("Invalid equality/comparison"); - bool TryProcessEqual(Expression first, Expression second, [NotNullWhen(true)] out Filter? result) - { - // TODO: Nullable - if (this.TryTranslateFieldAccess(first, out var storagePropertyName) - && TryGetConstant(second, out var constantValue)) + private Filter GenerateEqual(string propertyStorageName, object? value, bool negated = false) + { + var condition = value is null + ? new Condition { IsNull = new() { Key = propertyStorageName } } + : new Condition { - var condition = constantValue is null - ? new Condition { IsNull = new() { Key = storagePropertyName } } - : new Condition + Field = new FieldCondition + { + Key = propertyStorageName, + Match = value switch { - Field = new FieldCondition - { - Key = storagePropertyName, - Match = constantValue switch - { - string stringValue => new Match { Keyword = stringValue }, - int intValue => new Match { Integer = intValue }, - long longValue => new Match { Integer = longValue }, - bool boolValue => new Match { Boolean = boolValue }, - - _ => throw new InvalidOperationException($"Unsupported filter value type '{constantValue.GetType().Name}'.") - } - } - }; + string stringValue => new Match { Keyword = stringValue }, + int intValue => new Match { Integer = intValue }, + long longValue => new Match { Integer = longValue }, + bool boolValue => new Match { Boolean = boolValue }, - result = new Filter(); - if (negated) - { - result.MustNot.Add(condition); - } - else - { - result.Must.Add(condition); + _ => throw new InvalidOperationException($"Unsupported filter value type '{value.GetType().Name}'.") + } } - return true; - } + }; - result = null; - return false; + var result = new Filter(); + + if (negated) + { + result.MustNot.Add(condition); } + else + { + result.Must.Add(condition); + } + + return result; } private Filter TranslateComparison(BinaryExpression comparison) @@ -116,8 +111,7 @@ private Filter TranslateComparison(BinaryExpression comparison) bool TryProcessComparison(Expression first, Expression second, [NotNullWhen(true)] out Filter? result) { // TODO: Nullable - if (this.TryTranslateFieldAccess(first, out var storagePropertyName) - && TryGetConstant(second, out var constantValue)) + if (this.TryBindProperty(first, out var property) && second is ConstantExpression { Value: var constantValue }) { double doubleConstantValue = constantValue switch { @@ -132,7 +126,7 @@ bool TryProcessComparison(Expression first, Expression second, [NotNullWhen(true { Field = new FieldCondition { - Key = storagePropertyName, + Key = property.StorageName, Range = comparison.NodeType switch { ExpressionType.GreaterThan => new Range { Gt = doubleConstantValue }, @@ -279,7 +273,7 @@ private Filter TranslateContains(Expression source, Expression item) switch (source) { // Contains over field enumerable - case var _ when this.TryTranslateFieldAccess(source, out _): + case var _ when this.TryBindProperty(source, out _): // Oddly, in Qdrant, tag list contains is handled using a Match condition, just like equality. return this.TranslateEqual(source, item); @@ -289,9 +283,9 @@ private Filter TranslateContains(Expression source, Expression item) for (var i = 0; i < newArray.Expressions.Count; i++) { - if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + if (newArray.Expressions[i] is not ConstantExpression { Value: var elementValue }) { - throw new NotSupportedException("Invalid element in array"); + throw new NotSupportedException("Inline array elements must be constants"); } elements[i] = elementValue; @@ -299,9 +293,7 @@ private Filter TranslateContains(Expression source, Expression item) return ProcessInlineEnumerable(elements, item); - // Contains over captured enumerable (we inline) - case var _ when TryGetConstant(source, out var constantEnumerable) - && constantEnumerable is IEnumerable enumerable and not string: + case ConstantExpression { Value: IEnumerable enumerable and not string }: return ProcessInlineEnumerable(enumerable, item); default: @@ -310,77 +302,86 @@ private Filter TranslateContains(Expression source, Expression item) Filter ProcessInlineEnumerable(IEnumerable elements, Expression item) { - if (!this.TryTranslateFieldAccess(item, out var storagePropertyName)) + if (!this.TryBindProperty(item, out var property)) { throw new NotSupportedException("Unsupported item type in Contains"); } - if (item.Type == typeof(string)) + switch (property.Type) { - var strings = new RepeatedStrings(); + case var t when t == typeof(string): + var strings = new RepeatedStrings(); - foreach (var value in elements) - { - strings.Strings.Add(value is string or null - ? (string?)value - : throw new ArgumentException("Non-string element in string Contains array")); - } + foreach (var value in elements) + { + strings.Strings.Add(value is string or null + ? (string?)value + : throw new ArgumentException("Non-string element in string Contains array")); + } - return new Filter { Must = { new Condition { Field = new FieldCondition { Key = storagePropertyName, Match = new Match { Keywords = strings } } } } }; - } + return new Filter { Must = { new Condition { Field = new FieldCondition { Key = property.StorageName, Match = new Match { Keywords = strings } } } } }; - if (item.Type == typeof(int)) - { - var ints = new RepeatedIntegers(); + case var t when t == typeof(int): + var ints = new RepeatedIntegers(); - foreach (var value in elements) - { - ints.Integers.Add(value is int intValue - ? intValue - : throw new ArgumentException("Non-int element in string Contains array")); - } + foreach (var value in elements) + { + ints.Integers.Add(value is int intValue + ? intValue + : throw new ArgumentException("Non-int element in string Contains array")); + } - return new Filter { Must = { new Condition { Field = new FieldCondition { Key = storagePropertyName, Match = new Match { Integers = ints } } } } }; - } + return new Filter { Must = { new Condition { Field = new FieldCondition { Key = property.StorageName, Match = new Match { Integers = ints } } } } }; - throw new NotSupportedException("Contains only supported over array of ints or strings"); + default: + throw new NotSupportedException("Contains only supported over array of ints or strings"); + } } } - private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out VectorStoreRecordPropertyModel? property) { - if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) - { - if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) - { - throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); - } + Type? convertedClrType = null; - return true; + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unary) + { + expression = unary.Operand; + convertedClrType = unary.Type; } - storagePropertyName = null; - return false; - } + var modelName = expression switch + { + // Regular member access for strongly-typed POCO binding (e.g. r => r.SomeInt == 8) + MemberExpression memberExpression when memberExpression.Expression == this._recordParameter + => memberExpression.Member.Name, - private static bool TryGetConstant(Expression expression, out object? constantValue) - { - switch (expression) + // Dictionary lookup for weakly-typed dynamic binding (e.g. r => r["SomeInt"] == 8) + MethodCallExpression + { + Method: { Name: "get_Item", DeclaringType: var declaringType }, + Arguments: [ConstantExpression { Value: string keyName }] + } methodCall when methodCall.Object == this._recordParameter && declaringType == typeof(Dictionary) + => keyName, + + _ => null + }; + + if (modelName is null) { - case ConstantExpression { Value: var v }: - constantValue = v; - return true; + property = null; + return false; + } - // This identifies compiler-generated closure types which contain captured variables. - case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): - constantValue = fieldInfo.GetValue(constant.Value); - return true; + if (!this._model.PropertyMap.TryGetValue(modelName, out property)) + { + throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name."); + } - default: - constantValue = null; - return false; + if (convertedClrType is not null && convertedClrType != property.Type) + { + throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convertedClrType.Name}', but its configured type is '{property.Type.Name}'."); } + + return true; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantGenericDataModelMapper.cs deleted file mode 100644 index 5cce141d0223..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantGenericDataModelMapper.cs +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.Extensions.VectorData; -using Qdrant.Client.Grpc; - -namespace Microsoft.SemanticKernel.Connectors.Qdrant; - -/// -/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Qdrant. -/// -internal class QdrantGenericDataModelMapper : IVectorStoreRecordMapper, PointStruct>, IVectorStoreRecordMapper, PointStruct> -{ - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; - - /// A value indicating whether the vectors in the store are named, or whether there is just a single unnamed vector per qdrant point. - private readonly bool _hasNamedVectors; - - /// - /// Initializes a new instance of the class. - /// - /// A helper to access property information for the current data model and record definition. - /// A value indicating whether the vectors in the store are named, or whether there is just a single unnamed vector per qdrant point. - public QdrantGenericDataModelMapper( - VectorStoreRecordPropertyReader propertyReader, - bool hasNamedVectors) - { - Verify.NotNull(propertyReader); - - // Validate property types. - propertyReader.VerifyDataProperties(QdrantVectorStoreRecordFieldMapping.s_supportedDataTypes, supportEnumerable: true); - propertyReader.VerifyVectorProperties(QdrantVectorStoreRecordFieldMapping.s_supportedVectorTypes); - - // Assign. - this._propertyReader = propertyReader; - this._hasNamedVectors = hasNamedVectors; - } - - /// - public PointStruct MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - // Create point. - var pointStruct = new PointStruct - { - Id = new PointId { Num = dataModel.Key }, - Vectors = new Vectors(), - Payload = { }, - }; - - // Loop through all properties and map each from the data model to the storage model. - MapProperties( - this._propertyReader.Properties, - dataModel.Data, - dataModel.Vectors, - pointStruct, - this._hasNamedVectors); - - return pointStruct; - } - - /// - public VectorStoreGenericDataModel MapFromStorageToDataModel(PointStruct storageModel, StorageToDataModelMapperOptions options) - { - var dataModel = new VectorStoreGenericDataModel(storageModel.Id.Num); - MapProperties(this._propertyReader.Properties, storageModel, dataModel.Data, dataModel.Vectors, this._hasNamedVectors); - return dataModel; - } - - /// - PointStruct IVectorStoreRecordMapper, PointStruct>.MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - // Create point. - var pointStruct = new PointStruct - { - Id = new PointId { Uuid = dataModel.Key.ToString("D") }, - Vectors = new Vectors(), - Payload = { }, - }; - - // Loop through all properties and map each from the data model to the storage model. - MapProperties( - this._propertyReader.Properties, - dataModel.Data, - dataModel.Vectors, - pointStruct, - this._hasNamedVectors); - - return pointStruct; - } - - /// - VectorStoreGenericDataModel IVectorStoreRecordMapper, PointStruct>.MapFromStorageToDataModel(PointStruct storageModel, StorageToDataModelMapperOptions options) - { - var dataModel = new VectorStoreGenericDataModel(new Guid(storageModel.Id.Uuid)); - MapProperties(this._propertyReader.Properties, storageModel, dataModel.Data, dataModel.Vectors, this._hasNamedVectors); - return dataModel; - } - - /// - /// Map the payload and vector properties from the data model to the qdrant storage model. - /// - /// The list of properties to map. - /// The payload properties on the data model. - /// The vector properties on the data model. - /// The storage model to map to. - /// A value indicating whether qdrant is using named vectors for this collection. - /// Thrown if a vector on the data model is not a supported type. - private static void MapProperties(IEnumerable properties, Dictionary dataProperties, Dictionary vectorProperties, PointStruct pointStruct, bool hasNamedVectors) - { - if (hasNamedVectors) - { - pointStruct.Vectors.Vectors_ = new NamedVectors(); - } - - foreach (var property in properties) - { - if (property is VectorStoreRecordDataProperty dataProperty) - { - var storagePropertyName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; - - // Just skip this property if it's not in the data model. - if (!dataProperties.TryGetValue(dataProperty.DataModelPropertyName, out var propertyValue)) - { - continue; - } - - // Map. - pointStruct.Payload.Add(storagePropertyName, QdrantVectorStoreRecordFieldMapping.ConvertToGrpcFieldValue(propertyValue)); - } - else if (property is VectorStoreRecordVectorProperty vectorProperty) - { - var storagePropertyName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; - - // Just skip this property if it's not in the data model. - if (!vectorProperties.TryGetValue(vectorProperty.DataModelPropertyName, out var vector)) - { - continue; - } - - // Validate. - if (vector is not ReadOnlyMemory floatROM) - { - throw new VectorStoreRecordMappingException($"Vector property '{vectorProperty.DataModelPropertyName}' on provided record of type {nameof(VectorStoreGenericDataModel)} must be of type ReadOnlyMemory and not null."); - } - - // Map. - if (hasNamedVectors) - { - pointStruct.Vectors.Vectors_.Vectors.Add(storagePropertyName, floatROM.ToArray()); - } - else - { - pointStruct.Vectors.Vector = floatROM.ToArray(); - } - } - } - } - - /// - /// Map the payload and vector properties from the qdrant storage model to the data model. - /// - /// The list of properties to map. - /// The storage model to map from. - /// The payload properties on the data model. - /// The vector properties on the data model. - /// A value indicating whether qdrant is using named vectors for this collection. - public static void MapProperties(IEnumerable properties, PointStruct storageModel, Dictionary dataProperties, Dictionary vectorProperties, bool hasNamedVectors) - { - foreach (var property in properties) - { - if (property is VectorStoreRecordDataProperty dataProperty) - { - var storagePropertyName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; - - // Just skip this property if it's not in the storage model. - if (!storageModel.Payload.TryGetValue(storagePropertyName, out var propertyValue)) - { - continue; - } - - if (propertyValue.HasNullValue) - { - // Shortcut any null handling here so we don't have to check for it for each case. - dataProperties[dataProperty.DataModelPropertyName] = null; - } - else - { - var convertedValue = QdrantVectorStoreRecordFieldMapping.ConvertFromGrpcFieldValueToNativeType(propertyValue, dataProperty.PropertyType); - dataProperties[dataProperty.DataModelPropertyName] = convertedValue; - } - } - else if (property is VectorStoreRecordVectorProperty vectorProperty) - { - Vector? vector; - if (hasNamedVectors) - { - var storagePropertyName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; - - // Just skip this property if it's not in the storage model. - if (!storageModel.Vectors.Vectors_.Vectors.TryGetValue(storagePropertyName, out vector)) - { - continue; - } - } - else - { - vector = storageModel.Vectors.Vector; - } - - vectorProperties[vectorProperty.DataModelPropertyName] = new ReadOnlyMemory(vector.Data.ToArray()); - } - } - } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantKernelBuilderExtensions.cs index c8dd0b6070b9..ae78eac52689 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantKernelBuilderExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Qdrant; using Qdrant.Client; @@ -9,6 +10,7 @@ namespace Microsoft.SemanticKernel; /// /// Extension methods to register Qdrant instances on the . /// +[Obsolete("The IKernelBuilder extensions are being obsoleted, call the appropriate function on the Services property of your IKernelBuilder")] public static class QdrantKernelBuilderExtensions { /// @@ -41,7 +43,7 @@ public static IKernelBuilder AddQdrantVectorStore(this IKernelBuilder builder, s } /// - /// Register a Qdrant and with the specified service ID + /// Register a Qdrant and with the specified service ID /// and where the Qdrant is retrieved from the dependency injection container. /// /// The type of the key. @@ -57,13 +59,14 @@ public static IKernelBuilder AddQdrantVectorStoreRecordCollection QdrantVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) where TKey : notnull + where TRecord : notnull { builder.Services.AddQdrantVectorStoreRecordCollection(collectionName, options, serviceId); return builder; } /// - /// Register a Qdrant and with the specified service ID + /// Register a Qdrant and with the specified service ID /// and where the Qdrant is constructed using the provided parameters. /// /// The type of the key. @@ -87,6 +90,7 @@ public static IKernelBuilder AddQdrantVectorStoreRecordCollection QdrantVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) where TKey : notnull + where TRecord : notnull { builder.Services.AddQdrantVectorStoreRecordCollection(collectionName, host, port, https, apiKey, options, serviceId); return builder; diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryBuilderExtensions.cs index f4233c47a6c0..56989090334c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryBuilderExtensions.cs @@ -1,16 +1,18 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; using Microsoft.SemanticKernel.Http; using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Connectors.Qdrant; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// Provides extension methods for the class to configure Qdrant connector. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] public static class QdrantMemoryBuilderExtensions { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryStore.cs index fdd4a2eaff9b..f08b8a9d4241 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryStore.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net.Http; using System.Runtime.CompilerServices; @@ -14,13 +13,15 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// An implementation of for Qdrant Vector Database. /// /// The Embedding data is saved to a Qdrant Vector Database instance specified in the constructor by url and port. /// The embedding data persists between subsequent instances and has similarity search capability. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] public class QdrantMemoryStore : IMemoryStore { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantServiceCollectionExtensions.cs index 693d8d94fc3b..1abd7708030b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantServiceCollectionExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Qdrant; @@ -28,11 +29,12 @@ public static IServiceCollection AddQdrantVectorStore(this IServiceCollection se (sp, obj) => { var qdrantClient = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new QdrantVectorStore( - qdrantClient, - selectedOptions); + return new QdrantVectorStore(qdrantClient, options); }); return services; @@ -55,18 +57,19 @@ public static IServiceCollection AddQdrantVectorStore(this IServiceCollection se (sp, obj) => { var qdrantClient = new QdrantClient(host, port, https, apiKey); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new QdrantVectorStore( - qdrantClient, - selectedOptions); + return new QdrantVectorStore(qdrantClient, options); }); return services; } /// - /// Register a Qdrant and with the specified service ID + /// Register a Qdrant and with the specified service ID /// and where the Qdrant is retrieved from the dependency injection container. /// /// The type of the key. @@ -82,15 +85,19 @@ public static IServiceCollection AddQdrantVectorStoreRecordCollection? options = default, string? serviceId = default) where TKey : notnull + where TRecord : notnull { services.AddKeyedTransient>( serviceId, (sp, obj) => { var qdrantClient = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return (new QdrantVectorStoreRecordCollection(qdrantClient, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + return (new QdrantVectorStoreRecordCollection(qdrantClient, collectionName, options) as IVectorStoreRecordCollection)!; }); AddVectorizedSearch(services, serviceId); @@ -99,7 +106,7 @@ public static IServiceCollection AddQdrantVectorStoreRecordCollection - /// Register a Qdrant and with the specified service ID + /// Register a Qdrant and with the specified service ID /// and where the Qdrant is constructed using the provided parameters. /// /// The type of the key. @@ -123,15 +130,19 @@ public static IServiceCollection AddQdrantVectorStoreRecordCollection? options = default, string? serviceId = default) where TKey : notnull + where TRecord : notnull { services.AddKeyedSingleton>( serviceId, (sp, obj) => { var qdrantClient = new QdrantClient(host, port, https, apiKey); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return (new QdrantVectorStoreRecordCollection(qdrantClient, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + return (new QdrantVectorStoreRecordCollection(qdrantClient, collectionName, options) as IVectorStoreRecordCollection)!; }); AddVectorizedSearch(services, serviceId); @@ -140,7 +151,7 @@ public static IServiceCollection AddQdrantVectorStoreRecordCollection - /// Also register the with the given as a . + /// Also register the with the given as a . /// /// The type of the key. /// The type of the data model that the collection should contain. @@ -148,8 +159,9 @@ public static IServiceCollection AddQdrantVectorStoreRecordCollectionThe service id that the registrations should use. private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) where TKey : notnull + where TRecord : notnull { - services.AddKeyedTransient>( + services.AddKeyedTransient>( serviceId, (sp, obj) => { diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorDbClient.cs index 9e158551bf24..ef4a7642f298 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorDbClient.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net; using System.Net.Http; @@ -21,7 +20,7 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// connect, create, delete, and get embeddings data from a Qdrant Vector Database instance. /// #pragma warning disable CA1001 // Types that own disposable fields should be disposable. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] public sealed class QdrantVectorDbClient : IQdrantVectorDbClient #pragma warning restore CA1001 // Types that own disposable fields should be disposable. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. { diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorRecord.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorRecord.cs index 93ab5d24deb6..c7dc2189e4de 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorRecord.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorRecord.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Serialization; @@ -11,7 +10,7 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// /// A record structure used by Qdrant that contains an embedding and metadata. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and QdrantVectorStore")] public class QdrantVectorRecord { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStore.cs index bfac788a7cfd..dd55b3da663f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStore.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Grpc.Core; using Microsoft.Extensions.VectorData; using Qdrant.Client; @@ -16,10 +17,10 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// /// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. /// -public class QdrantVectorStore : IVectorStore +public sealed class QdrantVectorStore : IVectorStore { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "Qdrant"; + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; /// Qdrant client that can be used to manage the collections and points in a Qdrant store. private readonly MockableQdrantClient _qdrantClient; @@ -27,6 +28,9 @@ public class QdrantVectorStore : IVectorStore /// Optional configuration options for this class. private readonly QdrantVectorStoreOptions _options; + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreRecordDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreRecordKeyProperty("Key", typeof(ulong)), new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 1)] }; + /// /// Initializes a new instance of the class. /// @@ -48,11 +52,17 @@ internal QdrantVectorStore(MockableQdrantClient qdrantClient, QdrantVectorStoreO this._qdrantClient = qdrantClient; this._options = options ?? new QdrantVectorStoreOptions(); + + this._metadata = new() + { + VectorStoreSystemName = QdrantConstants.VectorStoreSystemName + }; } /// - public virtual IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + where TRecord : notnull { #pragma warning disable CS0618 // IQdrantVectorStoreRecordCollectionFactory is obsolete if (this._options.VectorStoreCollectionFactory is not null) @@ -61,22 +71,18 @@ public virtual IVectorStoreRecordCollection GetCollection(this._qdrantClient, name, new QdrantVectorStoreRecordCollectionOptions() + var recordCollection = new QdrantVectorStoreRecordCollection(this._qdrantClient, name, new QdrantVectorStoreRecordCollectionOptions() { HasNamedVectors = this._options.HasNamedVectors, - VectorStoreRecordDefinition = vectorStoreRecordDefinition + VectorStoreRecordDefinition = vectorStoreRecordDefinition, + EmbeddingGenerator = this._options.EmbeddingGenerator }); var castRecordCollection = recordCollection as IVectorStoreRecordCollection; return castRecordCollection!; } /// - public virtual async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { IReadOnlyList collections; @@ -88,7 +94,8 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = QdrantConstants.VectorStoreSystemName, + VectorStoreName = this._metadata.VectorStoreName, OperationName = "ListCollections" }; } @@ -98,4 +105,31 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat yield return collection; } } + + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType == typeof(QdrantClient) ? this._qdrantClient.QdrantClient : + serviceType.IsInstanceOfType(this) ? this : + null; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionCreateMapping.cs index d2b5d1c55cab..0c73352de662 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionCreateMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionCreateMapping.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Qdrant.Client.Grpc; namespace Microsoft.SemanticKernel.Connectors.Qdrant; @@ -53,16 +54,11 @@ internal static class QdrantVectorStoreCollectionCreateMapping /// The property to map. /// The mapped . /// Thrown if the property is missing information or has unsupported options specified. - public static VectorParams MapSingleVector(VectorStoreRecordVectorProperty vectorProperty) + public static VectorParams MapSingleVector(VectorStoreRecordVectorPropertyModel vectorProperty) { - if (vectorProperty!.Dimensions is not > 0) - { - throw new InvalidOperationException($"Property {nameof(vectorProperty.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' must be set to a positive integer to create a collection."); - } - if (vectorProperty!.IndexKind is not null && vectorProperty!.IndexKind != IndexKind.Hnsw) { - throw new InvalidOperationException($"Index kind '{vectorProperty!.IndexKind}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Qdrant VectorStore."); + throw new InvalidOperationException($"Index kind '{vectorProperty!.IndexKind}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.ModelName}' is not supported by the Qdrant VectorStore."); } return new VectorParams { Size = (ulong)vectorProperty.Dimensions, Distance = QdrantVectorStoreCollectionCreateMapping.GetSDKDistanceAlgorithm(vectorProperty) }; @@ -72,21 +68,16 @@ public static VectorParams MapSingleVector(VectorStoreRecordVectorProperty vecto /// Maps a collection of to a qdrant . /// /// The properties to map. - /// The mapping of property names to storage names. /// THe mapped . /// Thrown if the property is missing information or has unsupported options specified. - public static VectorParamsMap MapNamedVectors(IEnumerable vectorProperties, IReadOnlyDictionary storagePropertyNames) + public static VectorParamsMap MapNamedVectors(IEnumerable vectorProperties) { var vectorParamsMap = new VectorParamsMap(); foreach (var vectorProperty in vectorProperties) { - var storageName = storagePropertyNames[vectorProperty.DataModelPropertyName]; - // Add each vector property to the vectors map. - vectorParamsMap.Map.Add( - storageName, - MapSingleVector(vectorProperty)); + vectorParamsMap.Map.Add(vectorProperty.StorageName, MapSingleVector(vectorProperty)); } return vectorParamsMap; @@ -99,20 +90,14 @@ public static VectorParamsMap MapNamedVectors(IEnumerableThe vector property definition. /// The chosen . /// Thrown if a distance function is chosen that isn't supported by qdrant. - public static Distance GetSDKDistanceAlgorithm(VectorStoreRecordVectorProperty vectorProperty) - { - if (vectorProperty.DistanceFunction is null) + public static Distance GetSDKDistanceAlgorithm(VectorStoreRecordVectorPropertyModel vectorProperty) + => vectorProperty.DistanceFunction switch { - return Distance.Cosine; - } - - return vectorProperty.DistanceFunction switch - { - DistanceFunction.CosineSimilarity => Distance.Cosine, + DistanceFunction.CosineSimilarity or null => Distance.Cosine, DistanceFunction.DotProductSimilarity => Distance.Dot, DistanceFunction.EuclideanDistance => Distance.Euclid, DistanceFunction.ManhattanDistance => Distance.Manhattan, - _ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Qdrant VectorStore.") + + _ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.ModelName}' is not supported by the Qdrant VectorStore.") }; - } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs index 7151cc5fbf0a..3646127798e9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs @@ -1,9 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; -using System.Linq; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Qdrant.Client.Grpc; namespace Microsoft.SemanticKernel.Connectors.Qdrant; @@ -18,10 +17,10 @@ internal static class QdrantVectorStoreCollectionSearchMapping /// Build a Qdrant from the provided . /// /// The to build a Qdrant from. - /// A mapping of data model property names to the names under which they are stored. + /// The model. /// The Qdrant . /// Thrown when the provided filter contains unsupported types, values or unknown properties. - public static Filter BuildFromLegacyFilter(VectorSearchFilter basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) + public static Filter BuildFromLegacyFilter(VectorSearchFilter basicVectorSearchFilter, VectorStoreRecordModel model) { var filter = new Filter(); @@ -48,7 +47,7 @@ public static Filter BuildFromLegacyFilter(VectorSearchFilter basicVectorSearchF } // Get the storage name for the field. - if (!storagePropertyNames.TryGetValue(fieldName, out var storagePropertyName)) + if (!model.PropertyMap.TryGetValue(fieldName, out var property)) { throw new InvalidOperationException($"Property name '{fieldName}' provided as part of the filter clause is not a valid property name."); } @@ -62,7 +61,7 @@ public static Filter BuildFromLegacyFilter(VectorSearchFilter basicVectorSearchF Lte = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(dateTimeOffset), }; - filter.Must.Add(new Condition() { Field = new FieldCondition() { Key = storagePropertyName, DatetimeRange = range } }); + filter.Must.Add(new Condition() { Field = new FieldCondition() { Key = property.StorageName, DatetimeRange = range } }); continue; } @@ -76,7 +75,7 @@ public static Filter BuildFromLegacyFilter(VectorSearchFilter basicVectorSearchF _ => throw new InvalidOperationException($"Unsupported filter value type '{filterValue.GetType().Name}'.") }; - filter.Must.Add(new Condition() { Field = new FieldCondition() { Key = storagePropertyName, Match = match } }); + filter.Must.Add(new Condition() { Field = new FieldCondition() { Key = property.StorageName, Match = match } }); } return filter; @@ -90,50 +89,46 @@ public static Filter BuildFromLegacyFilter(VectorSearchFilter basicVectorSearchF /// The point to map to a . /// The mapper to perform the main mapping operation with. /// A value indicating whether to include vectors in the mapped result. - /// The name of the database system the operation is being run on. + /// The name of the vector store system the operation is being run on. + /// The name of the vector store the operation is being run on. /// The name of the collection the operation is being run on. /// The type of database operation being run. /// The mapped . - public static VectorSearchResult MapScoredPointToVectorSearchResult(ScoredPoint point, IVectorStoreRecordMapper mapper, bool includeVectors, string databaseSystemName, string collectionName, string operationName) + public static VectorSearchResult MapScoredPointToVectorSearchResult( + ScoredPoint point, + QdrantVectorStoreRecordMapper mapper, + bool includeVectors, + string vectorStoreSystemName, + string? vectorStoreName, + string collectionName, + string operationName) { - // Since the mapper doesn't know about scored points, we need to convert the scored point to a point struct first. - var pointStruct = new PointStruct - { - Id = point.Id, - Payload = { } - }; - - if (includeVectors) - { - pointStruct.Vectors = new(); - switch (point.Vectors.VectorsOptionsCase) - { - case VectorsOutput.VectorsOptionsOneofCase.Vector: - pointStruct.Vectors.Vector = point.Vectors.Vector.Data.ToArray(); - break; - case VectorsOutput.VectorsOptionsOneofCase.Vectors: - pointStruct.Vectors.Vectors_ = new(); - foreach (var v in point.Vectors.Vectors.Vectors) - { - // TODO: Refactor mapper to not require pre-mapping to pointstruct to avoid this ToArray conversion. - pointStruct.Vectors.Vectors_.Vectors.Add(v.Key, v.Value.Data.ToArray()); - } - break; - } - } - - foreach (KeyValuePair payloadEntry in point.Payload) - { - pointStruct.Payload.Add(payloadEntry.Key, payloadEntry.Value); - } - // Do the mapping with error handling. return new VectorSearchResult( VectorStoreErrorHandler.RunModelConversion( - databaseSystemName, + vectorStoreSystemName, + vectorStoreName, collectionName, operationName, - () => mapper.MapFromStorageToDataModel(pointStruct, new() { IncludeVectors = includeVectors })), + () => mapper.MapFromStorageToDataModel(point.Id, point.Payload, point.Vectors, new() { IncludeVectors = includeVectors })), point.Score); } + + internal static TRecord MapRetrievedPointToRecord( + RetrievedPoint point, + QdrantVectorStoreRecordMapper mapper, + bool includeVectors, + string vectorStoreSystemName, + string? vectorStoreName, + string collectionName, + string operationName) + { + // Do the mapping with error handling. + return VectorStoreErrorHandler.RunModelConversion( + vectorStoreSystemName, + vectorStoreName, + collectionName, + operationName, + () => mapper.MapFromStorageToDataModel(point.Id, point.Payload, point.Vectors, new() { IncludeVectors = includeVectors })); + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreOptions.cs index e7ce3f053970..97f59d5149a4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Connectors.Qdrant; @@ -16,7 +17,12 @@ public sealed class QdrantVectorStoreOptions public bool HasNamedVectors { get; set; } = false; /// - /// An optional factory to use for constructing instances, if a custom record collection is required. + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } + + /// + /// An optional factory to use for constructing instances, if a custom record collection is required. /// [Obsolete("To control how collections are instantiated, extend your provider's IVectorStore implementation and override GetCollection()")] public IQdrantVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs index 24cbd097f422..7673933d99af 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs @@ -1,14 +1,19 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Grpc.Core; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; using Qdrant.Client; using Qdrant.Client.Grpc; @@ -17,20 +22,16 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// /// Service for storing and retrieving vector records, that uses Qdrant as the underlying storage. /// +/// The data type of the record key. Can be either or , or for dynamic mapping. /// The data model to use for adding, updating and retrieving data from storage. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class QdrantVectorStoreRecordCollection : - IVectorStoreRecordCollection, - IVectorStoreRecordCollection, - IKeywordHybridSearch +public sealed class QdrantVectorStoreRecordCollection : IVectorStoreRecordCollection, IKeywordHybridSearch + where TKey : notnull + where TRecord : notnull #pragma warning restore CA1711 // Identifiers should not have incorrect suffix { - /// A set of types that a key on the provided model may have. - private static readonly HashSet s_supportedKeyTypes = - [ - typeof(ulong), - typeof(Guid) - ]; + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; /// The default options for vector search. private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); @@ -38,9 +39,6 @@ public class QdrantVectorStoreRecordCollection : /// The default options for hybrid vector search. private static readonly HybridSearchOptions s_defaultKeywordVectorizedHybridSearchOptions = new(); - /// The name of this database for telemetry purposes. - private const string DatabaseName = "Qdrant"; - /// The name of the upsert operation for telemetry purposes. private const string UpsertName = "Upsert"; @@ -50,91 +48,72 @@ public class QdrantVectorStoreRecordCollection : /// Qdrant client that can be used to manage the collections and points in a Qdrant store. private readonly MockableQdrantClient _qdrantClient; - /// The name of the collection that this will access. + /// The name of the collection that this will access. private readonly string _collectionName; /// Optional configuration options for this class. private readonly QdrantVectorStoreRecordCollectionOptions _options; - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; + /// The model for this collection. + private readonly VectorStoreRecordModel _model; /// A mapper to use for converting between qdrant point and consumer models. - private readonly IVectorStoreRecordMapper _mapper; + private readonly QdrantVectorStoreRecordMapper _mapper; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// Qdrant client that can be used to manage the collections and points in a Qdrant store. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// Optional configuration options for this class. /// Thrown if the is null. /// Thrown for any misconfigured options. - public QdrantVectorStoreRecordCollection(QdrantClient qdrantClient, string collectionName, QdrantVectorStoreRecordCollectionOptions? options = null) - : this(new MockableQdrantClient(qdrantClient), collectionName, options) + public QdrantVectorStoreRecordCollection(QdrantClient qdrantClient, string name, QdrantVectorStoreRecordCollectionOptions? options = null) + : this(new MockableQdrantClient(qdrantClient), name, options) { } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// Qdrant client that can be used to manage the collections and points in a Qdrant store. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// Optional configuration options for this class. /// Thrown if the is null. /// Thrown for any misconfigured options. - internal QdrantVectorStoreRecordCollection(MockableQdrantClient qdrantClient, string collectionName, QdrantVectorStoreRecordCollectionOptions? options = null) + internal QdrantVectorStoreRecordCollection(MockableQdrantClient qdrantClient, string name, QdrantVectorStoreRecordCollectionOptions? options = null) { // Verify. Verify.NotNull(qdrantClient); - Verify.NotNullOrWhiteSpace(collectionName); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.PointStructCustomMapper is not null, s_supportedKeyTypes); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + Verify.NotNullOrWhiteSpace(name); + + if (typeof(TKey) != typeof(ulong) && typeof(TKey) != typeof(Guid) && typeof(TKey) != typeof(object)) + { + throw new NotSupportedException("Only ulong and Guid keys are supported (and object for dynamic mapping)."); + } // Assign. this._qdrantClient = qdrantClient; - this._collectionName = collectionName; + this._collectionName = name; this._options = options ?? new QdrantVectorStoreRecordCollectionOptions(); - this._propertyReader = new VectorStoreRecordPropertyReader( - typeof(TRecord), - this._options.VectorStoreRecordDefinition, - new() - { - RequiresAtLeastOneVector = !this._options.HasNamedVectors, - SupportsMultipleKeys = false, - SupportsMultipleVectors = this._options.HasNamedVectors - }); - // Validate property types. - this._propertyReader.VerifyKeyProperties(s_supportedKeyTypes); + this._model = new VectorStoreRecordModelBuilder(QdrantVectorStoreRecordFieldMapping.GetModelBuildOptions(this._options.HasNamedVectors)) + .Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, options?.EmbeddingGenerator); - // Assign Mapper. - if (this._options.PointStructCustomMapper is not null) - { - // Custom Mapper. - this._mapper = this._options.PointStructCustomMapper; - } - else if (typeof(TRecord) == typeof(VectorStoreGenericDataModel) || typeof(TRecord) == typeof(VectorStoreGenericDataModel)) - { - // Generic data model mapper. - this._mapper = (IVectorStoreRecordMapper)new QdrantGenericDataModelMapper( - this._propertyReader, - this._options.HasNamedVectors); - } - else + this._mapper = new QdrantVectorStoreRecordMapper(this._model, this._options.HasNamedVectors); + + this._collectionMetadata = new() { - // Default Mapper. - this._mapper = new QdrantVectorStoreRecordMapper( - this._propertyReader, - this._options.HasNamedVectors); - } + VectorStoreSystemName = QdrantConstants.VectorStoreSystemName, + CollectionName = name + }; } /// - public string CollectionName => this._collectionName; + public string Name => this._collectionName; /// - public virtual Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) { return this.RunOperationAsync( "CollectionExists", @@ -142,12 +121,12 @@ public virtual Task CollectionExistsAsync(CancellationToken cancellationTo } /// - public virtual async Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) { if (!this._options.HasNamedVectors) { // If we are not using named vectors, we can only have one vector property. We can assume we have exactly one, since this is already verified in the constructor. - var singleVectorProperty = this._propertyReader.VectorProperty; + var singleVectorProperty = this._model.VectorProperty; // Map the single vector property to the qdrant config. var vectorParams = QdrantVectorStoreCollectionCreateMapping.MapSingleVector(singleVectorProperty!); @@ -163,10 +142,10 @@ await this.RunOperationAsync( else { // Since we are using named vectors, iterate over all vector properties. - var vectorProperties = this._propertyReader.VectorProperties; + var vectorProperties = this._model.VectorProperties; // Map the named vectors to the qdrant config. - var vectorParamsMap = QdrantVectorStoreCollectionCreateMapping.MapNamedVectors(vectorProperties, this._propertyReader.StoragePropertyNamesMap); + var vectorParamsMap = QdrantVectorStoreCollectionCreateMapping.MapNamedVectors(vectorProperties); // Create the collection with named vectors. await this.RunOperationAsync( @@ -178,57 +157,55 @@ await this.RunOperationAsync( } // Add indexes for each of the data properties that require filtering. - var dataProperties = this._propertyReader.DataProperties.Where(x => x.IsFilterable); + var dataProperties = this._model.DataProperties.Where(x => x.IsIndexed); foreach (var dataProperty in dataProperties) { - var storageFieldName = this._propertyReader.GetStoragePropertyName(dataProperty.DataModelPropertyName); - - if (QdrantVectorStoreCollectionCreateMapping.s_schemaTypeMap.TryGetValue(dataProperty.PropertyType!, out PayloadSchemaType schemaType)) + if (QdrantVectorStoreCollectionCreateMapping.s_schemaTypeMap.TryGetValue(dataProperty.Type, out PayloadSchemaType schemaType)) { // Do nothing since schemaType is already set. } - else if (VectorStoreRecordPropertyVerification.IsSupportedEnumerableType(dataProperty.PropertyType) && VectorStoreRecordPropertyVerification.GetCollectionElementType(dataProperty.PropertyType) == typeof(string)) + else if (VectorStoreRecordPropertyVerification.IsSupportedEnumerableType(dataProperty.Type) && VectorStoreRecordPropertyVerification.GetCollectionElementType(dataProperty.Type) == typeof(string)) { // For enumerable of strings, use keyword schema type, since this allows tag filtering. schemaType = PayloadSchemaType.Keyword; } else { - throw new InvalidOperationException($"Property {nameof(VectorStoreRecordDataProperty.IsFilterable)} on {nameof(VectorStoreRecordDataProperty)} '{dataProperty.DataModelPropertyName}' is set to true, but the property type is not supported for filtering. The Qdrant VectorStore supports filtering on {string.Join(", ", QdrantVectorStoreCollectionCreateMapping.s_schemaTypeMap.Keys.Select(x => x.Name))} properties only."); + // TODO: This should move to model validation + throw new InvalidOperationException($"Property {nameof(VectorStoreRecordDataProperty.IsIndexed)} on {nameof(VectorStoreRecordDataProperty)} '{dataProperty.ModelName}' is set to true, but the property type is not supported for filtering. The Qdrant VectorStore supports filtering on {string.Join(", ", QdrantVectorStoreCollectionCreateMapping.s_schemaTypeMap.Keys.Select(x => x.Name))} properties only."); } await this.RunOperationAsync( "CreatePayloadIndex", () => this._qdrantClient.CreatePayloadIndexAsync( this._collectionName, - storageFieldName, + dataProperty.StorageName, schemaType, cancellationToken: cancellationToken)).ConfigureAwait(false); } // Add indexes for each of the data properties that require full text search. - dataProperties = this._propertyReader.DataProperties.Where(x => x.IsFullTextSearchable); + dataProperties = this._model.DataProperties.Where(x => x.IsFullTextIndexed); foreach (var dataProperty in dataProperties) { - if (dataProperty.PropertyType != typeof(string)) + // TODO: This should move to model validation + if (dataProperty.Type != typeof(string)) { - throw new InvalidOperationException($"Property {nameof(dataProperty.IsFullTextSearchable)} on {nameof(VectorStoreRecordDataProperty)} '{dataProperty.DataModelPropertyName}' is set to true, but the property type is not a string. The Qdrant VectorStore supports {nameof(dataProperty.IsFullTextSearchable)} on string properties only."); + throw new InvalidOperationException($"Property {nameof(dataProperty.IsFullTextIndexed)} on {nameof(VectorStoreRecordDataProperty)} '{dataProperty.ModelName}' is set to true, but the property type is not a string. The Qdrant VectorStore supports {nameof(dataProperty.IsFullTextIndexed)} on string properties only."); } - var storageFieldName = this._propertyReader.GetStoragePropertyName(dataProperty.DataModelPropertyName); - await this.RunOperationAsync( "CreatePayloadIndex", () => this._qdrantClient.CreatePayloadIndexAsync( this._collectionName, - storageFieldName, + dataProperty.StorageName, PayloadSchemaType.Text, cancellationToken: cancellationToken)).ConfigureAwait(false); } } /// - public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) { @@ -237,7 +214,7 @@ public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken can } /// - public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) => this.RunOperationAsync("DeleteCollection", async () => { @@ -260,272 +237,356 @@ public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = }); /// - public virtual async Task GetAsync(ulong key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(key); - var retrievedPoints = await this.GetBatchAsync([key], options, cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false); + var retrievedPoints = await this.GetAsync([key], options, cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false); return retrievedPoints.FirstOrDefault(); } /// - public virtual async Task GetAsync(Guid key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetAsync( + IEnumerable keys, + GetRecordOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) { - Verify.NotNull(key); + const string OperationName = "Retrieve"; - var retrievedPoints = await this.GetBatchAsync([key], options, cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false); - return retrievedPoints.FirstOrDefault(); - } + Verify.NotNull(keys); - /// - public virtual IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, CancellationToken cancellationToken = default) - { - return this.GetBatchByPointIdAsync(keys, key => new PointId { Num = key }, options, cancellationToken); - } + // Create options. + var pointsIds = new List(); - /// - public virtual IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, CancellationToken cancellationToken = default) - { - return this.GetBatchByPointIdAsync(keys, key => new PointId { Uuid = key.ToString("D") }, options, cancellationToken); - } + Type? keyType = null; - /// - public virtual Task DeleteAsync(ulong key, CancellationToken cancellationToken = default) - { - Verify.NotNull(key); + foreach (var key in keys) + { + switch (key) + { + case ulong id: + if (keyType == typeof(Guid)) + { + throw new NotSupportedException("Mixing ulong and Guid keys is not supported"); + } - return this.RunOperationAsync( - DeleteName, - () => this._qdrantClient.DeleteAsync( + keyType = typeof(ulong); + pointsIds.Add(new PointId { Num = id }); + break; + + case Guid id: + if (keyType == typeof(ulong)) + { + throw new NotSupportedException("Mixing ulong and Guid keys is not supported"); + } + + pointsIds.Add(new PointId { Uuid = id.ToString("D") }); + keyType = typeof(Guid); + break; + + default: + throw new NotSupportedException($"The provided key type '{key.GetType().Name}' is not supported by Qdrant."); + } + } + + var includeVectors = options?.IncludeVectors ?? false; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + // Retrieve data points. + var retrievedPoints = await this.RunOperationAsync( + OperationName, + () => this._qdrantClient.RetrieveAsync(this._collectionName, pointsIds, true, includeVectors, cancellationToken: cancellationToken)).ConfigureAwait(false); + + // Convert the retrieved points to the target data model. + foreach (var retrievedPoint in retrievedPoints) + { + yield return VectorStoreErrorHandler.RunModelConversion( + QdrantConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, - key, - wait: true, - cancellationToken: cancellationToken)); + OperationName, + () => this._mapper.MapFromStorageToDataModel(retrievedPoint.Id, retrievedPoint.Payload, retrievedPoint.Vectors, new() { IncludeVectors = includeVectors })); + } } /// - public virtual Task DeleteAsync(Guid key, CancellationToken cancellationToken = default) + public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { Verify.NotNull(key); return this.RunOperationAsync( DeleteName, - () => this._qdrantClient.DeleteAsync( - this._collectionName, - key, - wait: true, - cancellationToken: cancellationToken)); + () => key switch + { + ulong id => this._qdrantClient.DeleteAsync(this._collectionName, id, wait: true, cancellationToken: cancellationToken), + Guid id => this._qdrantClient.DeleteAsync(this._collectionName, id, wait: true, cancellationToken: cancellationToken), + _ => throw new NotSupportedException($"The provided key type '{key.GetType().Name}' is not supported by Qdrant.") + }); } /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { Verify.NotNull(keys); - return this.RunOperationAsync( - DeleteName, - () => this._qdrantClient.DeleteAsync( - this._collectionName, - keys.ToList(), - wait: true, - cancellationToken: cancellationToken)); - } + IList? keyList = null; - /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) - { - Verify.NotNull(keys); + switch (keys) + { + case IEnumerable k: + keyList = k.ToList(); + break; + + case IEnumerable k: + keyList = k.ToList(); + break; + + case IEnumerable objectKeys: + { + // We need to cast the keys to a list of the same type as the first element. + List? guidKeys = null; + List? ulongKeys = null; + + var isFirst = true; + foreach (var key in objectKeys) + { + if (isFirst) + { + switch (key) + { + case ulong l: + ulongKeys = new List { l }; + keyList = ulongKeys; + break; + + case Guid g: + guidKeys = new List { g }; + keyList = guidKeys; + break; + + default: + throw new NotSupportedException($"The provided key type '{key.GetType().Name}' is not supported by Qdrant."); + } + + isFirst = false; + continue; + } + + switch (key) + { + case ulong u when ulongKeys is not null: + ulongKeys.Add(u); + continue; + + case Guid g when guidKeys is not null: + guidKeys.Add(g); + continue; + + case Guid or ulong: + throw new NotSupportedException("Mixing ulong and Guid keys is not supported"); + + default: + throw new NotSupportedException($"The provided key type '{key.GetType().Name}' is not supported by Qdrant."); + } + } + + break; + } + } + + if (keyList is { Count: 0 }) + { + return Task.CompletedTask; + } return this.RunOperationAsync( DeleteName, - () => this._qdrantClient.DeleteAsync( - this._collectionName, - keys.ToList(), - wait: true, - cancellationToken: cancellationToken)); - } - - /// - public virtual async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) - { - Verify.NotNull(record); + () => keyList switch + { + List keysList => this._qdrantClient.DeleteAsync( + this._collectionName, + keysList, + wait: true, + cancellationToken: cancellationToken), - // Create point from record. - var pointStruct = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this._collectionName, - UpsertName, - () => this._mapper.MapFromDataToStorageModel(record)); + List keysList => this._qdrantClient.DeleteAsync( + this._collectionName, + keysList, + wait: true, + cancellationToken: cancellationToken), - // Upsert. - await this.RunOperationAsync( - UpsertName, - () => this._qdrantClient.UpsertAsync(this._collectionName, [pointStruct], true, cancellationToken: cancellationToken)).ConfigureAwait(false); - return pointStruct.Id.Num; + _ => throw new UnreachableException() + }); } /// - async Task IVectorStoreRecordCollection.UpsertAsync(TRecord record, CancellationToken cancellationToken) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { Verify.NotNull(record); - // Create point from record. - var pointStruct = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this._collectionName, - UpsertName, - () => this._mapper.MapFromDataToStorageModel(record)); + var keys = await this.UpsertAsync([record], cancellationToken).ConfigureAwait(false); - // Upsert. - await this.RunOperationAsync( - UpsertName, - () => this._qdrantClient.UpsertAsync(this._collectionName, [pointStruct], true, cancellationToken: cancellationToken)).ConfigureAwait(false); - return Guid.Parse(pointStruct.Id.Uuid); + return keys.Single(); } /// - public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { Verify.NotNull(records); - // Create points from records. - var pointStructs = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this._collectionName, - UpsertName, - () => records.Select(this._mapper.MapFromDataToStorageModel).ToList()); + IReadOnlyList? recordsList = null; - // Upsert. - await this.RunOperationAsync( - UpsertName, - () => this._qdrantClient.UpsertAsync(this._collectionName, pointStructs, true, cancellationToken: cancellationToken)).ConfigureAwait(false); + // If an embedding generator is defined, invoke it once per property for all records. + GeneratedEmbeddings>?[]? generatedEmbeddings = null; - foreach (var pointStruct in pointStructs) + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) { - yield return pointStruct.Id.Num; - } - } + var vectorProperty = this._model.VectorProperties[i]; - /// - async IAsyncEnumerable IVectorStoreRecordCollection.UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken) - { - Verify.NotNull(records); + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + + if (recordsList is null) + { + recordsList = records is IReadOnlyList r ? r : records.ToList(); + + if (recordsList.Count == 0) + { + return []; + } + + records = recordsList; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbeddings, ReadOnlyMemory>(records, cancellationToken, out var task)) + { + generatedEmbeddings ??= new GeneratedEmbeddings>?[vectorPropertyCount]; + generatedEmbeddings[i] = await task.ConfigureAwait(false); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } // Create points from records. var pointStructs = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + QdrantConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, UpsertName, - () => records.Select(this._mapper.MapFromDataToStorageModel).ToList()); + () => records.Select((r, i) => this._mapper.MapFromDataToStorageModel(r, i, generatedEmbeddings)).ToList()); + + if (pointStructs is { Count: 0 }) + { + return Array.Empty(); + } // Upsert. await this.RunOperationAsync( UpsertName, () => this._qdrantClient.UpsertAsync(this._collectionName, pointStructs, true, cancellationToken: cancellationToken)).ConfigureAwait(false); - foreach (var pointStruct in pointStructs) - { - yield return Guid.Parse(pointStruct.Id.Uuid); - } + return pointStructs.Count == 0 + ? [] + : pointStructs[0].Id switch + { + { HasNum: true } => pointStructs.Select(pointStruct => (TKey)(object)pointStruct.Id.Num).ToList(), + { HasUuid: true } => pointStructs.Select(pointStruct => (TKey)(object)Guid.Parse(pointStruct.Id.Uuid)).ToList(), + _ => throw new UnreachableException("The Qdrant point ID is neither a number nor a UUID.") + }; } - /// - /// Get the requested records from the Qdrant store using the provided keys. - /// - /// The keys of the points to retrieve. - /// Function to convert the provided keys to point ids. - /// The retrieval options. - /// The to monitor for cancellation requests. The default is . - /// The retrieved points. - private async IAsyncEnumerable GetBatchByPointIdAsync( - IEnumerable keys, - Func keyConverter, - GetRecordOptions? options, - [EnumeratorCancellation] CancellationToken cancellationToken) - { - const string OperationName = "Retrieve"; - Verify.NotNull(keys); - - // Create options. - var pointsIds = keys.Select(key => keyConverter(key)).ToArray(); - var includeVectors = options?.IncludeVectors ?? false; + #region Search - // Retrieve data points. - var retrievedPoints = await this.RunOperationAsync( - OperationName, - () => this._qdrantClient.RetrieveAsync(this._collectionName, pointsIds, true, includeVectors, cancellationToken: cancellationToken)).ConfigureAwait(false); + /// + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); - // Convert the retrieved points to the target data model. - foreach (var retrievedPoint in retrievedPoints) + switch (vectorProperty.EmbeddingGenerator) { - var pointStruct = new PointStruct - { - Id = retrievedPoint.Id, - Payload = { } - }; + case IEmbeddingGenerator> generator: + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); - if (includeVectors) - { - pointStruct.Vectors = new(); - switch (retrievedPoint.Vectors.VectorsOptionsCase) + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) { - case VectorsOutput.VectorsOptionsOneofCase.Vector: - pointStruct.Vectors.Vector = retrievedPoint.Vectors.Vector.Data.ToArray(); - break; - case VectorsOutput.VectorsOptionsOneofCase.Vectors: - pointStruct.Vectors.Vectors_ = new(); - foreach (var v in retrievedPoint.Vectors.Vectors.Vectors) - { - // TODO: Refactor mapper to not require pre-mapping to pointstruct to avoid this ToArray conversion. - pointStruct.Vectors.Vectors_.Vectors.Add(v.Key, v.Value.Data.ToArray()); - } - break; + yield return record; } - } - foreach (KeyValuePair payloadEntry in retrievedPoint.Payload) - { - pointStruct.Payload.Add(payloadEntry.Key, payloadEntry.Value); - } + yield break; - yield return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this._collectionName, - OperationName, - () => this._mapper.MapFromStorageToDataModel(pointStruct, new() { IncludeVectors = includeVectors })); + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + + default: + throw new InvalidOperationException( + QdrantVectorStoreRecordFieldMapping.s_supportedVectorTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); } } /// - public virtual async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable> SearchEmbeddingAsync( + TVector vector, + int top, + VectorSearchOptions? options = null, + CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private async IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + VectorSearchOptions options, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TVector : notnull { var floatVector = VerifyVectorParam(vector); + Verify.NotLessThan(top, 1); - // Resolve options. - var internalOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(internalOptions); + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } #pragma warning disable CS0618 // Type or member is obsolete // Build filter object. - var filter = internalOptions switch + var filter = options switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(legacyFilter, this._propertyReader.StoragePropertyNamesMap), - { Filter: Expression> newFilter } => new QdrantFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + { OldFilter: VectorSearchFilter legacyFilter } => QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(legacyFilter, this._model), + { Filter: Expression> newFilter } => new QdrantFilterTranslator().Translate(newFilter, this._model), _ => new Filter() }; #pragma warning restore CS0618 // Type or member is obsolete - // Specify the vector name if named vectors are used. - string? vectorName = null; - if (this._options.HasNamedVectors) - { - vectorName = this._propertyReader.GetStoragePropertyName(vectorProperty.DataModelPropertyName); - } - // Specify whether to include vectors in the search results. var vectorsSelector = new WithVectorsSelector(); - vectorsSelector.Enable = internalOptions.IncludeVectors; + vectorsSelector.Enable = options.IncludeVectors; var query = new Query { @@ -534,14 +595,14 @@ public virtual async Task> VectorizedSearchAsync this._qdrantClient.QueryAsync( - this.CollectionName, + this.Name, query: query, - usingVector: vectorName, + usingVector: this._options.HasNamedVectors ? vectorProperty.StorageName : null, filter: filter, - limit: (ulong)internalOptions.Top, - offset: (ulong)internalOptions.Skip, + limit: (ulong)top, + offset: (ulong)options.Skip, vectorsSelector: vectorsSelector, cancellationToken: cancellationToken)).ConfigureAwait(false); @@ -549,47 +610,108 @@ public virtual async Task> VectorizedSearchAsync QdrantVectorStoreCollectionSearchMapping.MapScoredPointToVectorSearchResult( point, this._mapper, - internalOptions.IncludeVectors, - DatabaseName, + options.IncludeVectors, + QdrantConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, "Query")); - return new VectorSearchResults(mappedResults.ToAsyncEnumerable()); + foreach (var result in mappedResults) + { + yield return result; + } + } + + /// + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); + + #endregion Search + + /// + public async IAsyncEnumerable GetAsync(Expression> filter, int top, + GetFilteredRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); + + options ??= new(); + + var translatedFilter = new QdrantFilterTranslator().Translate(filter, this._model); + + // Specify whether to include vectors in the search results. + WithVectorsSelector vectorsSelector = new() { Enable = options.IncludeVectors }; + + var sortInfo = options.OrderBy.Values.Count switch + { + 0 => null, + 1 => options.OrderBy.Values[0], + _ => throw new NotSupportedException("Qdrant does not support ordering by more than one property.") + }; + + OrderBy? orderBy = null; + if (sortInfo is not null) + { + var orderByName = this._model.GetDataOrKeyProperty(sortInfo.PropertySelector).StorageName; + orderBy = new(orderByName) + { + Direction = sortInfo.Ascending ? global::Qdrant.Client.Grpc.Direction.Asc : global::Qdrant.Client.Grpc.Direction.Desc + }; + } + + var scrollResponse = await this.RunOperationAsync( + "Scroll", + () => this._qdrantClient.ScrollAsync( + this.Name, + translatedFilter, + vectorsSelector, + limit: (uint)(top + options.Skip), + orderBy, + cancellationToken: cancellationToken)).ConfigureAwait(false); + + var mappedResults = scrollResponse.Result.Skip(options.Skip).Select(point => QdrantVectorStoreCollectionSearchMapping.MapRetrievedPointToRecord( + point, + this._mapper, + options.IncludeVectors, + QdrantConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this._collectionName, + "Scroll")); + + foreach (var mappedResult in mappedResults) + { + yield return mappedResult; + } } /// - public async Task> HybridSearchAsync(TVector vector, ICollection keywords, HybridSearchOptions? options = null, CancellationToken cancellationToken = default) + public async IAsyncEnumerable> HybridSearchAsync(TVector vector, ICollection keywords, int top, HybridSearchOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var floatVector = VerifyVectorParam(vector); + Verify.NotLessThan(top, 1); // Resolve options. - var internalOptions = options ?? s_defaultKeywordVectorizedHybridSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(new() { VectorProperty = internalOptions.VectorProperty }); - var textDataProperty = this._propertyReader.GetFullTextDataPropertyOrSingle(internalOptions.AdditionalProperty); - var textDataPropertyName = this._propertyReader.GetStoragePropertyName(textDataProperty.DataModelPropertyName); + options ??= s_defaultKeywordVectorizedHybridSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(new() { VectorProperty = options.VectorProperty }); + var textDataProperty = this._model.GetFullTextDataPropertyOrSingle(options.AdditionalProperty); // Build filter object. #pragma warning disable CS0618 // Type or member is obsolete // Build filter object. - var filter = internalOptions switch + var filter = options switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(legacyFilter, this._propertyReader.StoragePropertyNamesMap), - { Filter: Expression> newFilter } => new QdrantFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + { OldFilter: VectorSearchFilter legacyFilter } => QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(legacyFilter, this._model), + { Filter: Expression> newFilter } => new QdrantFilterTranslator().Translate(newFilter, this._model), _ => new Filter() }; #pragma warning restore CS0618 // Type or member is obsolete - // Specify the vector name if named vectors are used. - string? vectorName = null; - if (this._options.HasNamedVectors) - { - vectorName = this._propertyReader.GetStoragePropertyName(vectorProperty.DataModelPropertyName); - } - // Specify whether to include vectors in the search results. var vectorsSelector = new WithVectorsSelector(); - vectorsSelector.Enable = internalOptions.IncludeVectors; + vectorsSelector.Enable = options.IncludeVectors; // Build the vector query. var vectorQuery = new PrefetchQuery @@ -603,7 +725,7 @@ public async Task> HybridSearchAsync(TVect if (this._options.HasNamedVectors) { - vectorQuery.Using = vectorName; + vectorQuery.Using = this._options.HasNamedVectors ? vectorProperty.StorageName : null; } // Build the keyword query. @@ -611,7 +733,7 @@ public async Task> HybridSearchAsync(TVect var keywordSubFilter = new Filter(); foreach (string keyword in keywords) { - keywordSubFilter.Should.Add(new Condition() { Field = new FieldCondition() { Key = textDataPropertyName, Match = new Match { Text = keyword } } }); + keywordSubFilter.Should.Add(new Condition() { Field = new FieldCondition() { Key = textDataProperty.StorageName, Match = new Match { Text = keyword } } }); } keywordFilter.Must.Add(new Condition() { Filter = keywordSubFilter }); var keywordQuery = new PrefetchQuery @@ -629,11 +751,11 @@ public async Task> HybridSearchAsync(TVect var points = await this.RunOperationAsync( "Query", () => this._qdrantClient.QueryAsync( - this.CollectionName, + this.Name, prefetch: new List() { vectorQuery, keywordQuery }, query: fusionQuery, - limit: (ulong)internalOptions.Top, - offset: (ulong)internalOptions.Skip, + limit: (ulong)top, + offset: (ulong)options.Skip, vectorsSelector: vectorsSelector, cancellationToken: cancellationToken)).ConfigureAwait(false); @@ -641,12 +763,29 @@ public async Task> HybridSearchAsync(TVect var mappedResults = points.Select(point => QdrantVectorStoreCollectionSearchMapping.MapScoredPointToVectorSearchResult( point, this._mapper, - internalOptions.IncludeVectors, - DatabaseName, + options.IncludeVectors, + QdrantConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, "Query")); - return new VectorSearchResults(mappedResults.ToAsyncEnumerable()); + foreach (var result in mappedResults) + { + yield return result; + } + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType == typeof(QdrantClient) ? this._qdrantClient.QdrantClient : + serviceType.IsInstanceOfType(this) ? this : + null; } /// @@ -665,7 +804,8 @@ private async Task RunOperationAsync(string operationName, Func operation) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = QdrantConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, CollectionName = this._collectionName, OperationName = operationName }; @@ -689,7 +829,8 @@ private async Task RunOperationAsync(string operationName, Func> o { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = QdrantConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, CollectionName = this._collectionName, OperationName = operationName }; diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollectionOptions.cs index bdb1a8658e59..78d633435a6f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollectionOptions.cs @@ -1,12 +1,14 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; using Qdrant.Client.Grpc; namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// -/// Options when creating a . +/// Options when creating a . /// public sealed class QdrantVectorStoreRecordCollectionOptions { @@ -22,6 +24,7 @@ public sealed class QdrantVectorStoreRecordCollectionOptions /// /// If not set, a default mapper that uses json as an intermediary to allow automatic mapping to a wide variety of types will be used. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper? PointStructCustomMapper { get; init; } = null; /// @@ -33,4 +36,9 @@ public sealed class QdrantVectorStoreRecordCollectionOptions /// See , and . /// public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; + + /// + /// Gets or sets the default embedding generator for vector properties in this collection. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordFieldMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordFieldMapping.cs index 7134d45f47d7..fae43c2f2de0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordFieldMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordFieldMapping.cs @@ -6,6 +6,7 @@ using System.Globalization; using System.Linq; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Qdrant.Client.Grpc; namespace Microsoft.SemanticKernel.Connectors.Qdrant; @@ -15,6 +16,19 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// internal static class QdrantVectorStoreRecordFieldMapping { + public static VectorStoreRecordModelBuildingOptions GetModelBuildOptions(bool hasNamedVectors) + => new() + { + RequiresAtLeastOneVector = !hasNamedVectors, + SupportsMultipleKeys = false, + SupportsMultipleVectors = hasNamedVectors, + + SupportedKeyPropertyTypes = [typeof(ulong), typeof(Guid)], + SupportedDataPropertyTypes = QdrantVectorStoreRecordFieldMapping.s_supportedDataTypes, + SupportedEnumerableDataPropertyElementTypes = QdrantVectorStoreRecordFieldMapping.s_supportedDataTypes, + SupportedVectorPropertyTypes = QdrantVectorStoreRecordFieldMapping.s_supportedVectorTypes + }; + /// A set of types that data properties on the provided model may have. public static readonly HashSet s_supportedDataTypes = [ @@ -24,13 +38,7 @@ internal static class QdrantVectorStoreRecordFieldMapping typeof(double), typeof(float), typeof(bool), - typeof(DateTimeOffset), - typeof(int?), - typeof(long?), - typeof(double?), - typeof(float?), - typeof(bool?), - typeof(DateTimeOffset?), + typeof(DateTimeOffset) ]; /// A set of types that vectors on the provided model may have. diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordMapper.cs index 4b2963c464d7..368e4510b094 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordMapper.cs @@ -1,8 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Diagnostics; using System.Linq; +using Google.Protobuf.Collections; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Qdrant.Client.Grpc; namespace Microsoft.SemanticKernel.Connectors.Qdrant; @@ -11,54 +15,26 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// Mapper between a Qdrant record and the consumer data model that uses json as an intermediary to allow supporting a wide range of models. /// /// The consumer data model to map to or from. -internal sealed class QdrantVectorStoreRecordMapper : IVectorStoreRecordMapper +internal sealed class QdrantVectorStoreRecordMapper(VectorStoreRecordModel model, bool hasNamedVectors) { - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; - - /// A value indicating whether the vectors in the store are named, or whether there is just a single unnamed vector per qdrant point. - private readonly bool _hasNamedVectors; - - /// - /// Initializes a new instance of the class. - /// - /// A helper to access property information for the current data model and record definition. - /// A value indicating whether the vectors in the store are named, or whether there is just a single unnamed vector per qdrant point. - public QdrantVectorStoreRecordMapper( - VectorStoreRecordPropertyReader propertyReader, - bool hasNamedVectors) - { - Verify.NotNull(propertyReader); - - // Validate property types. - propertyReader.VerifyHasParameterlessConstructor(); - propertyReader.VerifyDataProperties(QdrantVectorStoreRecordFieldMapping.s_supportedDataTypes, supportEnumerable: true); - propertyReader.VerifyVectorProperties(QdrantVectorStoreRecordFieldMapping.s_supportedVectorTypes); - - // Assign. - this._propertyReader = propertyReader; - this._hasNamedVectors = hasNamedVectors; - } - /// - public PointStruct MapFromDataToStorageModel(TRecord dataModel) + public PointStruct MapFromDataToStorageModel(TRecord dataModel, int recordIndex, GeneratedEmbeddings>?[]? generatedEmbeddings) { - PointId pointId; - var keyPropertyInfo = this._propertyReader.KeyPropertyInfo; - if (keyPropertyInfo.PropertyType == typeof(ulong)) - { - var key = keyPropertyInfo.GetValue(dataModel) as ulong? ?? throw new VectorStoreRecordMappingException($"Missing key property {keyPropertyInfo.Name} on provided record of type {typeof(TRecord).FullName}."); - pointId = new PointId { Num = key }; - } - else if (keyPropertyInfo.PropertyType == typeof(Guid)) - { - var key = keyPropertyInfo.GetValue(dataModel) as Guid? ?? throw new VectorStoreRecordMappingException($"Missing key property {keyPropertyInfo.Name} on provided record of type {typeof(TRecord).FullName}."); - pointId = new PointId { Uuid = key.ToString("D") }; - } - else + var keyProperty = model.KeyProperty; + + var pointId = keyProperty.Type switch { - throw new VectorStoreRecordMappingException($"Unsupported key type {keyPropertyInfo.PropertyType.FullName} for key property {keyPropertyInfo.Name} on provided record of type {typeof(TRecord).FullName}."); - } + var t when t == typeof(ulong) => new PointId + { + Num = (ulong?)keyProperty.GetValueAsObject(dataModel!) ?? throw new VectorStoreRecordMappingException($"Missing key property '{keyProperty.ModelName}' on provided record of type '{typeof(TRecord).Name}'.") + }, + + var t when t == typeof(Guid) => new PointId + { + Uuid = ((Guid?)keyProperty.GetValueAsObject(dataModel!))?.ToString("D") ?? throw new VectorStoreRecordMappingException($"Missing key property '{keyProperty.ModelName}' on provided record of type '{typeof(TRecord).Name}'.") + }, + _ => throw new VectorStoreRecordMappingException($"Unsupported key type '{keyProperty.Type.Name}' for key property '{keyProperty.ModelName}' on provided record of type '{typeof(TRecord).Name}'.") + }; // Create point. var pointStruct = new PointStruct @@ -69,26 +45,28 @@ public PointStruct MapFromDataToStorageModel(TRecord dataModel) }; // Add point payload. - foreach (var dataPropertyInfo in this._propertyReader.DataPropertiesInfo) + foreach (var property in model.DataProperties) { - var propertyName = this._propertyReader.GetStoragePropertyName(dataPropertyInfo.Name); - var propertyValue = dataPropertyInfo.GetValue(dataModel); - pointStruct.Payload.Add(propertyName, QdrantVectorStoreRecordFieldMapping.ConvertToGrpcFieldValue(propertyValue)); + var propertyValue = property.GetValueAsObject(dataModel!); + pointStruct.Payload.Add(property.StorageName, QdrantVectorStoreRecordFieldMapping.ConvertToGrpcFieldValue(propertyValue)); } // Add vectors. - if (this._hasNamedVectors) + if (hasNamedVectors) { var namedVectors = new NamedVectors(); - foreach (var vectorPropertyInfo in this._propertyReader.VectorPropertiesInfo) + + for (var i = 0; i < model.VectorProperties.Count; i++) { - var propertyName = this._propertyReader.GetStoragePropertyName(vectorPropertyInfo.Name); - var propertyValue = vectorPropertyInfo.GetValue(dataModel); - if (propertyValue is not null) - { - var castPropertyValue = (ReadOnlyMemory)propertyValue; - namedVectors.Vectors.Add(propertyName, castPropertyValue.ToArray()); - } + var property = model.VectorProperties[i]; + + namedVectors.Vectors.Add( + property.StorageName, + GetVector( + property, + generatedEmbeddings?[i] is GeneratedEmbeddings> e + ? e[recordIndex] + : property.GetValueAsObject(dataModel!))); } pointStruct.Vectors.Vectors_ = namedVectors; @@ -96,59 +74,71 @@ public PointStruct MapFromDataToStorageModel(TRecord dataModel) else { // We already verified in the constructor via FindProperties that there is exactly one vector property when not using named vectors. - var vectorPropertyInfo = this._propertyReader.FirstVectorPropertyInfo!; - if (vectorPropertyInfo.GetValue(dataModel) is ReadOnlyMemory floatROM) - { - pointStruct.Vectors.Vector = floatROM.ToArray(); - } - else - { - throw new VectorStoreRecordMappingException($"Vector property {vectorPropertyInfo.Name} on provided record of type {typeof(TRecord).FullName} may not be null when not using named vectors."); - } + Debug.Assert( + generatedEmbeddings is null || generatedEmbeddings.Length == 1 && generatedEmbeddings[0] is not null, + "There should be exactly one generated embedding when not using named vectors (single vector property)."); + pointStruct.Vectors.Vector = GetVector( + model.VectorProperty, + generatedEmbeddings is null + ? model.VectorProperty.GetValueAsObject(dataModel!) + : generatedEmbeddings[0]![recordIndex].Vector); } return pointStruct; + + Vector GetVector(VectorStoreRecordPropertyModel property, object? embedding) + => embedding switch + { + ReadOnlyMemory floatVector => floatVector.ToArray(), + null => throw new VectorStoreRecordMappingException($"Vector property '{property.ModelName}' on provided record of type '{typeof(TRecord).Name}' may not be null when not using named vectors."), + var unknownEmbedding => throw new VectorStoreRecordMappingException($"Vector property '{property.ModelName}' on provided record of type '{typeof(TRecord).Name}' has unsupported embedding type '{unknownEmbedding.GetType().Name}'.") + }; } /// - public TRecord MapFromStorageToDataModel(PointStruct storageModel, StorageToDataModelMapperOptions options) + public TRecord MapFromStorageToDataModel(PointId pointId, MapField payload, VectorsOutput vectorsOutput, StorageToDataModelMapperOptions options) { - // Get the key property name and value. - var keyPropertyValue = storageModel.Id.HasNum ? storageModel.Id.Num as object : new Guid(storageModel.Id.Uuid) as object; + var outputRecord = model.CreateRecord()!; - // Construct the output record. - var outputRecord = (TRecord)this._propertyReader.ParameterLessConstructorInfo.Invoke(null); - - // Set Key - this._propertyReader.KeyPropertyInfo.SetValue(outputRecord, keyPropertyValue); + // TODO: Set the following generically to avoid boxing + model.KeyProperty.SetValueAsObject(outputRecord, pointId switch + { + { HasNum: true } => pointId.Num, + { HasUuid: true } => Guid.Parse(pointId.Uuid), + _ => throw new UnreachableException() + }); // Set each vector property if embeddings are included in the point. if (options?.IncludeVectors is true) { - if (this._hasNamedVectors) + if (hasNamedVectors) { - VectorStoreRecordMapping.SetValuesOnProperties( - outputRecord, - this._propertyReader.VectorPropertiesInfo, - this._propertyReader.StoragePropertyNamesMap, - storageModel.Vectors.Vectors_.Vectors, - (Vector vector, Type targetType) => new ReadOnlyMemory(vector.Data.ToArray())); + var storageVectors = vectorsOutput.Vectors.Vectors; + + foreach (var vectorProperty in model.VectorProperties) + { + vectorProperty.SetValueAsObject( + outputRecord, + new ReadOnlyMemory(storageVectors[vectorProperty.StorageName].Data.ToArray())); + } } else { - this._propertyReader.FirstVectorPropertyInfo!.SetValue( + model.VectorProperty.SetValueAsObject( outputRecord, - new ReadOnlyMemory(storageModel.Vectors.Vector.Data.ToArray())); + new ReadOnlyMemory(vectorsOutput.Vector.Data.ToArray())); } } - // Set each data property. - VectorStoreRecordMapping.SetValuesOnProperties( - outputRecord, - this._propertyReader.DataPropertiesInfo, - this._propertyReader.StoragePropertyNamesMap, - storageModel.Payload, - QdrantVectorStoreRecordFieldMapping.ConvertFromGrpcFieldValueToNativeType); + foreach (var dataProperty in model.DataProperties) + { + if (payload.TryGetValue(dataProperty.StorageName, out var fieldValue)) + { + dataProperty.SetValueAsObject( + outputRecord, + QdrantVectorStoreRecordFieldMapping.ConvertFromGrpcFieldValueToNativeType(fieldValue, dataProperty.Type)); + } + } return outputRecord; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/Connectors.Memory.Redis.csproj b/dotnet/src/Connectors/Connectors.Memory.Redis/Connectors.Memory.Redis.csproj index 3f2fd4360fe8..3c9e83f89056 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/Connectors.Memory.Redis.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/Connectors.Memory.Redis.csproj @@ -4,13 +4,14 @@ Microsoft.SemanticKernel.Connectors.Redis $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 preview + @@ -19,15 +20,21 @@ - + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/IRedisJsonMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/IRedisJsonMapper.cs new file mode 100644 index 000000000000..0585c0055fb1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/IRedisJsonMapper.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text.Json.Nodes; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +internal interface IRedisJsonMapper +{ + /// + /// Maps from the consumer record data model to the storage model. + /// + (string Key, JsonNode Node) MapFromDataToStorageModel(TRecord dataModel, int recordIndex, IReadOnlyList?[]? generatedEmbeddings); + + /// + /// Maps from the storage model to the consumer record data model. + /// + TRecord MapFromStorageToDataModel((string Key, JsonNode Node) storageModel, StorageToDataModelMapperOptions options); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/IRedisVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/IRedisVectorStoreRecordCollectionFactory.cs index ea98a9a6308d..519ef2151eb2 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/IRedisVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/IRedisVectorStoreRecordCollectionFactory.cs @@ -22,5 +22,6 @@ public interface IRedisVectorStoreRecordCollectionFactory /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. /// The new instance of . IVectorStoreRecordCollection CreateVectorStoreRecordCollection(IDatabase database, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) - where TKey : notnull; + where TKey : notnull + where TRecord : notnull; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/README.md b/dotnet/src/Connectors/Connectors.Memory.Redis/README.md index 8acfd839a810..c0feab4eb169 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/README.md @@ -22,19 +22,6 @@ Ways to get RediSearch: docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest ``` -2. To use Redis as a semantic memory store: - > See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. +2. Create a Redis Vector Store using instructions on the [Microsoft Learn site](https://learn.microsoft.com/semantic-kernel/concepts/vector-store-connectors/out-of-the-box-connectors/redis-connector). -```csharp -// ConnectionMultiplexer should be a singleton instance in your application, please consider to dispose of it when your application shuts down. -// See https://stackexchange.github.io/StackExchange.Redis/Basics#basic-usage -ConnectionMultiplexer connectionMultiplexer = await ConnectionMultiplexer.ConnectAsync("localhost:6379"); -IDatabase database = connectionMultiplexer.GetDatabase(); -RedisMemoryStore memoryStore = new RedisMemoryStore(database, vectorSize: 1536); - -var embeddingGenerator = new OpenAITextEmbeddingGenerationService("text-embedding-ada-002", apiKey); - -SemanticTextMemory textMemory = new(memoryStore, embeddingGenerator); - -var memoryPlugin = kernel.ImportPluginFromObject(new TextMemoryPlugin(textMemory)); -``` +3. Use the [getting started instructions](https://learn.microsoft.com/semantic-kernel/concepts/vector-store-connectors/?pivots=programming-language-csharp#getting-started-with-vector-store-connectors) on the Microsoft Leearn site to learn more about using the vector store. diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisConstants.cs new file mode 100644 index 000000000000..8d3e442ff671 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisConstants.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +internal static class RedisConstants +{ + internal const string VectorStoreSystemName = "redis"; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs index ec5bcd73514f..92ecdd3fd798 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs @@ -6,28 +6,31 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; using System.Text; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.ConnectorSupport.Filter; namespace Microsoft.SemanticKernel.Connectors.Redis; internal class RedisFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; + private VectorStoreRecordModel _model = null!; private ParameterExpression _recordParameter = null!; private readonly StringBuilder _filter = new(); - internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + internal string Translate(LambdaExpression lambdaExpression, VectorStoreRecordModel model) { Debug.Assert(this._filter.Length == 0); - this._storagePropertyNames = storagePropertyNames; + this._model = model; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - this.Translate(lambdaExpression.Body); + var preprocessor = new FilterTranslationPreprocessor { InlineCapturedVariables = true }; + var preprocessedExpression = preprocessor.Visit(lambdaExpression.Body); + + this.Translate(preprocessedExpression); return this._filter.ToString(); } @@ -67,7 +70,7 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual return; // MemberExpression is generally handled within e.g. TranslateEqual; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) - case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): + case MemberExpression member when member.Type == typeof(bool) && this.TryBindProperty(member, out _): { this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))); return; @@ -92,8 +95,7 @@ private void TranslateEqualityComparison(BinaryExpression binary) bool TryProcessEqualityComparison(Expression first, Expression second) { // TODO: Nullable - if (this.TryTranslateFieldAccess(first, out var storagePropertyName) - && TryGetConstant(second, out var constantValue)) + if (this.TryBindProperty(first, out var property) && second is ConstantExpression { Value: var constantValue }) { // Numeric negation has a special syntax (!=), for the rest we nest in a NOT if (binary.NodeType is ExpressionType.NotEqual && constantValue is not int or long or float or double) @@ -103,17 +105,17 @@ bool TryProcessEqualityComparison(Expression first, Expression second) } // https://redis.io/docs/latest/develop/interact/search-and-query/query/exact-match - this._filter.Append('@').Append(storagePropertyName); + this._filter.Append('@').Append(property.StorageName); this._filter.Append( binary.NodeType switch { ExpressionType.Equal when constantValue is int or long or float or double => $" == {constantValue}", ExpressionType.Equal when constantValue is string stringValue -#if NETSTANDARD2_0 - => $$""":{"{{stringValue.Replace("\"", "\"\"")}}"}""", -#else +#if NET8_0_OR_GREATER => $$""":{"{{stringValue.Replace("\"", "\\\"", StringComparison.Ordinal)}}"}""", +#else + => $$""":{"{{stringValue.Replace("\"", "\"\"")}}"}""", #endif ExpressionType.Equal when constantValue is null => throw new NotSupportedException("Null value type not supported"), // TODO @@ -175,13 +177,11 @@ private void TranslateMethodCall(MethodCallExpression methodCall) private void TranslateContains(Expression source, Expression item) { // Contains over tag field - if (this.TryTranslateFieldAccess(source, out var storagePropertyName) - && TryGetConstant(item, out var itemConstant) - && itemConstant is string stringConstant) + if (this.TryBindProperty(source, out var property) && item is ConstantExpression { Value: string stringConstant }) { this._filter .Append('@') - .Append(storagePropertyName) + .Append(property.StorageName) .Append(":{") .Append(stringConstant) .Append('}'); @@ -191,40 +191,49 @@ private void TranslateContains(Expression source, Expression item) throw new NotSupportedException("Contains supported only over tag field"); } - private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out VectorStoreRecordPropertyModel? property) { - if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) - { - if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) - { - throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); - } + Type? convertedClrType = null; - return true; + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unary) + { + expression = unary.Operand; + convertedClrType = unary.Type; } - storagePropertyName = null; - return false; - } + var modelName = expression switch + { + // Regular member access for strongly-typed POCO binding (e.g. r => r.SomeInt == 8) + MemberExpression memberExpression when memberExpression.Expression == this._recordParameter + => memberExpression.Member.Name, - private static bool TryGetConstant(Expression expression, out object? constantValue) - { - switch (expression) + // Dictionary lookup for weakly-typed dynamic binding (e.g. r => r["SomeInt"] == 8) + MethodCallExpression + { + Method: { Name: "get_Item", DeclaringType: var declaringType }, + Arguments: [ConstantExpression { Value: string keyName }] + } methodCall when methodCall.Object == this._recordParameter && declaringType == typeof(Dictionary) + => keyName, + + _ => null + }; + + if (modelName is null) { - case ConstantExpression { Value: var v }: - constantValue = v; - return true; + property = null; + return false; + } - // This identifies compiler-generated closure types which contain captured variables. - case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): - constantValue = fieldInfo.GetValue(constant.Value); - return true; + if (!this._model.PropertyMap.TryGetValue(modelName, out property)) + { + throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name."); + } - default: - constantValue = null; - return false; + if (convertedClrType is not null && convertedClrType != property.Type) + { + throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convertedClrType.Name}', but its configured type is '{property.Type.Name}'."); } + + return true; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetGenericDataModelMapper.cs deleted file mode 100644 index c4676976db9d..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetGenericDataModelMapper.cs +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.InteropServices; -using Microsoft.Extensions.VectorData; -using StackExchange.Redis; - -namespace Microsoft.SemanticKernel.Connectors.Redis; - -/// -/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Redis when using hash sets. -/// -internal class RedisHashSetGenericDataModelMapper : IVectorStoreRecordMapper, (string Key, HashEntry[] HashEntries)> -{ - /// All the properties from the record definition. - private readonly IReadOnlyList _properties; - - /// - /// Initializes a new instance of the class. - /// - /// All the properties from the record definition. - public RedisHashSetGenericDataModelMapper(IReadOnlyList properties) - { - Verify.NotNull(properties); - this._properties = properties; - } - - /// - public (string Key, HashEntry[] HashEntries) MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - var hashEntries = new List(); - - foreach (var property in this._properties) - { - var storagePropertyName = property.StoragePropertyName ?? property.DataModelPropertyName; - var sourceDictionary = property is VectorStoreRecordDataProperty ? dataModel.Data : dataModel.Vectors; - - // Only map properties across that actually exist in the input. - if (sourceDictionary is null || !sourceDictionary.TryGetValue(property.DataModelPropertyName, out var sourceValue)) - { - continue; - } - - // Replicate null if the property exists but is null. - if (sourceValue is null) - { - hashEntries.Add(new HashEntry(storagePropertyName, RedisValue.Null)); - continue; - } - - // Map data Properties - if (property is VectorStoreRecordDataProperty dataProperty) - { - hashEntries.Add(new HashEntry(storagePropertyName, RedisValue.Unbox(sourceValue))); - } - // Map vector properties - else if (property is VectorStoreRecordVectorProperty vectorProperty) - { - if (sourceValue is ReadOnlyMemory rom) - { - hashEntries.Add(new HashEntry(storagePropertyName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(rom))); - } - else if (sourceValue is ReadOnlyMemory rod) - { - hashEntries.Add(new HashEntry(storagePropertyName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(rod))); - } - else - { - throw new VectorStoreRecordMappingException($"Unsupported vector type {sourceValue.GetType().Name} found on property ${vectorProperty.DataModelPropertyName}. Only float and double vectors are supported."); - } - } - } - - return (dataModel.Key, hashEntries.ToArray()); - } - - /// - public VectorStoreGenericDataModel MapFromStorageToDataModel((string Key, HashEntry[] HashEntries) storageModel, StorageToDataModelMapperOptions options) - { - var dataModel = new VectorStoreGenericDataModel(storageModel.Key); - - foreach (var property in this._properties) - { - var storagePropertyName = property.StoragePropertyName ?? property.DataModelPropertyName; - var targetDictionary = property is VectorStoreRecordDataProperty ? dataModel.Data : dataModel.Vectors; - var hashEntry = storageModel.HashEntries.FirstOrDefault(x => x.Name == storagePropertyName); - - // Only map properties across that actually exist in the input. - if (!hashEntry.Name.HasValue) - { - continue; - } - - // Replicate null if the property exists but is null. - if (hashEntry.Value.IsNull) - { - targetDictionary.Add(property.DataModelPropertyName, null); - continue; - } - - // Map data Properties - if (property is VectorStoreRecordDataProperty dataProperty) - { - var typeOrNullableType = Nullable.GetUnderlyingType(property.PropertyType) ?? property.PropertyType; - var convertedValue = Convert.ChangeType(hashEntry.Value, typeOrNullableType); - dataModel.Data.Add(dataProperty.DataModelPropertyName, convertedValue); - } - - // Map vector properties - else if (property is VectorStoreRecordVectorProperty vectorProperty) - { - if (property.PropertyType == typeof(ReadOnlyMemory) || property.PropertyType == typeof(ReadOnlyMemory?)) - { - var array = MemoryMarshal.Cast((byte[])hashEntry.Value!).ToArray(); - dataModel.Vectors.Add(vectorProperty.DataModelPropertyName, new ReadOnlyMemory(array)); - } - else if (property.PropertyType == typeof(ReadOnlyMemory) || property.PropertyType == typeof(ReadOnlyMemory?)) - { - var array = MemoryMarshal.Cast((byte[])hashEntry.Value!).ToArray(); - dataModel.Vectors.Add(vectorProperty.DataModelPropertyName, new ReadOnlyMemory(array)); - } - else - { - throw new VectorStoreRecordMappingException($"Unsupported vector type '{property.PropertyType.Name}' found on property '{property.DataModelPropertyName}'. Only float and double vectors are supported."); - } - } - } - - return dataModel; - } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs index a08fe1e86628..6a69967d42f3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs @@ -2,11 +2,16 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; using NRedisStack.RedisStackCommands; using NRedisStack.Search; using NRedisStack.Search.Literals.Enums; @@ -17,39 +22,16 @@ namespace Microsoft.SemanticKernel.Connectors.Redis; /// /// Service for storing and retrieving vector records, that uses Redis HashSets as the underlying storage. /// +/// The data type of the record key. Can be either , or for dynamic mapping. /// The data model to use for adding, updating and retrieving data from storage. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class RedisHashSetVectorStoreRecordCollection : IVectorStoreRecordCollection +public sealed class RedisHashSetVectorStoreRecordCollection : IVectorStoreRecordCollection + where TKey : notnull + where TRecord : notnull #pragma warning restore CA1711 // Identifiers should not have incorrect suffix { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "Redis"; - - /// A set of types that a key on the provided model may have. - private static readonly HashSet s_supportedKeyTypes = - [ - typeof(string) - ]; - - /// A set of types that data properties on the provided model may have. - private static readonly HashSet s_supportedDataTypes = - [ - typeof(string), - typeof(int), - typeof(uint), - typeof(long), - typeof(ulong), - typeof(double), - typeof(float), - typeof(bool), - typeof(int?), - typeof(uint?), - typeof(long?), - typeof(ulong?), - typeof(double?), - typeof(float?), - typeof(bool?) - ]; + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; /// A set of types that vectors on the provided model may have. private static readonly HashSet s_supportedVectorTypes = @@ -60,20 +42,45 @@ public class RedisHashSetVectorStoreRecordCollection : IVectorStoreReco typeof(ReadOnlyMemory?) ]; + internal static readonly VectorStoreRecordModelBuildingOptions ModelBuildingOptions = new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + + SupportedKeyPropertyTypes = [typeof(string)], + + SupportedDataPropertyTypes = + [ + typeof(string), + typeof(int), + typeof(uint), + typeof(long), + typeof(ulong), + typeof(double), + typeof(float), + typeof(bool) + ], + + SupportedEnumerableDataPropertyElementTypes = [], + + SupportedVectorPropertyTypes = s_supportedVectorTypes + }; + /// The default options for vector search. private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// The Redis database to read/write records from. private readonly IDatabase _database; - /// The name of the collection that this will access. + /// The name of the collection that this will access. private readonly string _collectionName; /// Optional configuration options for this class. private readonly RedisHashSetVectorStoreRecordCollectionOptions _options; - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; + /// The model. + private readonly VectorStoreRecordModel _model; /// An array of the names of all the data properties that are part of the Redis payload as RedisValue objects, i.e. all properties except the key and vector properties. private readonly RedisValue[] _dataStoragePropertyNameRedisValues; @@ -82,72 +89,53 @@ public class RedisHashSetVectorStoreRecordCollection : IVectorStoreReco private readonly string[] _dataStoragePropertyNamesWithScore; /// The mapper to use when mapping between the consumer data model and the Redis record. - private readonly IVectorStoreRecordMapper _mapper; + private readonly RedisHashSetVectorStoreRecordMapper _mapper; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The Redis database to read/write records from. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// Optional configuration options for this class. /// Throw when parameters are invalid. - public RedisHashSetVectorStoreRecordCollection(IDatabase database, string collectionName, RedisHashSetVectorStoreRecordCollectionOptions? options = null) + public RedisHashSetVectorStoreRecordCollection(IDatabase database, string name, RedisHashSetVectorStoreRecordCollectionOptions? options = null) { // Verify. Verify.NotNull(database); - Verify.NotNullOrWhiteSpace(collectionName); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.HashEntriesCustomMapper is not null, s_supportedKeyTypes); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + Verify.NotNullOrWhiteSpace(name); + + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(object)) + { + throw new NotSupportedException("Only string keys are supported (and object for dynamic mapping)."); + } // Assign. this._database = database; - this._collectionName = collectionName; + this._collectionName = name; this._options = options ?? new RedisHashSetVectorStoreRecordCollectionOptions(); - this._propertyReader = new VectorStoreRecordPropertyReader( - typeof(TRecord), - this._options.VectorStoreRecordDefinition, - new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true - }); - - // Validate property types. - this._propertyReader.VerifyKeyProperties(s_supportedKeyTypes); - this._propertyReader.VerifyDataProperties(s_supportedDataTypes, supportEnumerable: false); - this._propertyReader.VerifyVectorProperties(s_supportedVectorTypes); + this._model = new VectorStoreRecordModelBuilder(ModelBuildingOptions) + .Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator); // Lookup storage property names. - this._dataStoragePropertyNameRedisValues = this._propertyReader.DataPropertyStoragePropertyNames - .Select(RedisValue.Unbox) - .ToArray(); - - this._dataStoragePropertyNamesWithScore = [.. this._propertyReader.DataPropertyStoragePropertyNames, "vector_score"]; + this._dataStoragePropertyNameRedisValues = this._model.DataProperties.Select(p => RedisValue.Unbox(p.StorageName)).ToArray(); + this._dataStoragePropertyNamesWithScore = [.. this._model.DataProperties.Select(p => p.StorageName), "vector_score"]; // Assign Mapper. - if (this._options.HashEntriesCustomMapper is not null) - { - // Custom Mapper. - this._mapper = this._options.HashEntriesCustomMapper; - } - else if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) - { - // Generic data model mapper. - this._mapper = (IVectorStoreRecordMapper)new RedisHashSetGenericDataModelMapper(this._propertyReader.Properties); - } - else + this._mapper = new RedisHashSetVectorStoreRecordMapper(this._model); + + this._collectionMetadata = new() { - // Default Mapper. - this._mapper = new RedisHashSetVectorStoreRecordMapper(this._propertyReader); - } + VectorStoreSystemName = RedisConstants.VectorStoreSystemName, + VectorStoreName = database.Database.ToString(), + CollectionName = name + }; } /// - public string CollectionName => this._collectionName; + public string Name => this._collectionName; /// - public virtual async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) { try { @@ -162,7 +150,8 @@ public virtual async Task CollectionExistsAsync(CancellationToken cancella { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = RedisConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, CollectionName = this._collectionName, OperationName = "FT.INFO" }; @@ -170,10 +159,10 @@ public virtual async Task CollectionExistsAsync(CancellationToken cancella } /// - public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) { // Map the record definition to a schema. - var schema = RedisVectorStoreCollectionCreateMapping.MapToSchema(this._propertyReader.Properties, this._propertyReader.StoragePropertyNamesMap, useDollarPrefix: false); + var schema = RedisVectorStoreCollectionCreateMapping.MapToSchema(this._model.Properties, useDollarPrefix: false); // Create the index creation params. // Add the collection name and colon as the index prefix, which means that any record where the key is prefixed with this text will be indexed by this index @@ -186,7 +175,7 @@ public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = } /// - public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) { @@ -195,7 +184,7 @@ public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken can } /// - public virtual async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { try { @@ -217,13 +206,19 @@ await this.RunOperationAsync("FT.DROPINDEX", } /// - public virtual async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { - Verify.NotNullOrWhiteSpace(key); + var stringKey = this.GetStringKey(key); // Create Options - var maybePrefixedKey = this.PrefixKeyIfNeeded(key); + var maybePrefixedKey = this.PrefixKeyIfNeeded(stringKey); + var includeVectors = options?.IncludeVectors ?? false; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + var operationName = includeVectors ? "HGETALL" : "HMGET"; // Get the Redis value. @@ -251,17 +246,18 @@ await this.RunOperationAsync("FT.DROPINDEX", // Convert to the caller's data model. return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + RedisConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, operationName, () => { - return this._mapper.MapFromStorageToDataModel((key, retrievedHashEntries), new() { IncludeVectors = includeVectors }); + return this._mapper.MapFromStorageToDataModel((stringKey, retrievedHashEntries), new() { IncludeVectors = includeVectors }); }); } /// - public virtual async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetAsync(IEnumerable keys, GetRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Verify.NotNull(keys); @@ -278,12 +274,12 @@ public virtual async IAsyncEnumerable GetBatchAsync(IEnumerable } /// - public virtual Task DeleteAsync(string key, CancellationToken cancellationToken = default) + public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { - Verify.NotNullOrWhiteSpace(key); + var stringKey = this.GetStringKey(key); // Create Options - var maybePrefixedKey = this.PrefixKeyIfNeeded(key); + var maybePrefixedKey = this.PrefixKeyIfNeeded(stringKey); // Remove. return this.RunOperationAsync( @@ -293,7 +289,7 @@ public virtual Task DeleteAsync(string key, CancellationToken cancellationToken } /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { Verify.NotNull(keys); @@ -303,16 +299,37 @@ public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken } /// - public virtual async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + { + (_, var generatedEmbeddings) = await RedisVectorStoreRecordFieldMapping.ProcessEmbeddingsAsync(this._model, [record], cancellationToken).ConfigureAwait(false); + + return await this.UpsertCoreAsync(record, 0, generatedEmbeddings, cancellationToken).ConfigureAwait(false); + } + + /// + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) + { + Verify.NotNull(records); + + (records, var generatedEmbeddings) = await RedisVectorStoreRecordFieldMapping.ProcessEmbeddingsAsync(this._model, records, cancellationToken).ConfigureAwait(false); + + // Upsert records in parallel. + var tasks = records.Select((r, i) => this.UpsertCoreAsync(r, i, generatedEmbeddings, cancellationToken)); + var results = await Task.WhenAll(tasks).ConfigureAwait(false); + return results.Where(r => r is not null).ToList(); + } + + private async Task UpsertCoreAsync(TRecord record, int recordIndex, IReadOnlyList?[]? generatedEmbeddings, CancellationToken cancellationToken = default) { Verify.NotNull(record); // Map. var redisHashSetRecord = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + RedisConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, "HSET", - () => this._mapper.MapFromDataToStorageModel(record)); + () => this._mapper.MapFromDataToStorageModel(record, recordIndex, generatedEmbeddings)); // Upsert. var maybePrefixedKey = this.PrefixKeyIfNeeded(redisHashSetRecord.Key); @@ -324,42 +341,99 @@ await this.RunOperationAsync( maybePrefixedKey, redisHashSetRecord.HashEntries)).ConfigureAwait(false); - return redisHashSetRecord.Key; + return (TKey)(object)redisHashSetRecord.Key; } + #region Search + /// - public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull { - Verify.NotNull(records); + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); - // Upsert records in parallel. - var tasks = records.Select(x => this.UpsertAsync(x, cancellationToken)); - var results = await Task.WhenAll(tasks).ConfigureAwait(false); - foreach (var result in results) + switch (vectorProperty.EmbeddingGenerator) { - if (result is not null) + case IEmbeddingGenerator> generator: { - yield return result; + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + } + + case IEmbeddingGenerator> generator: + { + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; } + + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + + default: + throw new InvalidOperationException( + s_supportedVectorTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); } } /// - public virtual async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable> SearchEmbeddingAsync( + TVector vector, + int top, + VectorSearchOptions? options = null, + CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private async IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + VectorSearchOptions options, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TVector : notnull { Verify.NotNull(vector); + Verify.NotLessThan(top, 1); - var internalOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(internalOptions); + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } // Build query & search. - var selectFields = internalOptions.IncludeVectors ? null : this._dataStoragePropertyNamesWithScore; + var selectFields = options.IncludeVectors ? null : this._dataStoragePropertyNamesWithScore; byte[] vectorBytes = RedisVectorStoreCollectionSearchMapping.ValidateVectorAndConvertToBytes(vector, "HashSet"); var query = RedisVectorStoreCollectionSearchMapping.BuildQuery( vectorBytes, - internalOptions, - this._propertyReader.StoragePropertyNamesMap, - this._propertyReader.GetStoragePropertyName(vectorProperty.DataModelPropertyName), + top, + options, + this._model, + vectorProperty, selectFields); var results = await this.RunOperationAsync( "FT.SEARCH", @@ -370,30 +444,97 @@ public virtual async Task> VectorizedSearchAsync { - var retrievedHashEntries = this._propertyReader.DataPropertyStoragePropertyNames - .Concat(this._propertyReader.VectorPropertyStoragePropertyNames) + var retrievedHashEntries = this._model.DataProperties.Select(p => p.StorageName) + .Concat(this._model.VectorProperties.Select(p => p.StorageName)) .Select(propertyName => new HashEntry(propertyName, result[propertyName])) .ToArray(); // Convert to the caller's data model. var dataModel = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + RedisConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, "FT.SEARCH", () => { - return this._mapper.MapFromStorageToDataModel((this.RemoveKeyPrefixIfNeeded(result.Id), retrievedHashEntries), new() { IncludeVectors = internalOptions.IncludeVectors }); + return this._mapper.MapFromStorageToDataModel((this.RemoveKeyPrefixIfNeeded(result.Id), retrievedHashEntries), new() { IncludeVectors = options.IncludeVectors }); }); // Process the score of the result item. - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(internalOptions); + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); var distanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(vectorProperty); var score = RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(result["vector_score"].HasValue ? (float)result["vector_score"] : null, distanceFunction); return new VectorSearchResult(dataModel, score); }); - return new VectorSearchResults(mappedResults.ToAsyncEnumerable()); + foreach (var result in mappedResults) + { + yield return result; + } + } + + /// + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); + + #endregion Search + + /// + public async IAsyncEnumerable GetAsync(Expression> filter, int top, + GetFilteredRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); + + options ??= new(); + + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + Query query = RedisVectorStoreCollectionSearchMapping.BuildQuery(filter, top, options, this._model); + + var results = await this.RunOperationAsync( + "FT.SEARCH", + () => this._database + .FT() + .SearchAsync(this._collectionName, query)).ConfigureAwait(false); + + foreach (var document in results.Documents) + { + var retrievedHashEntries = this._model.DataProperties.Select(p => p.StorageName) + .Concat(this._model.VectorProperties.Select(p => p.StorageName)) + .Select(propertyName => new HashEntry(propertyName, document[propertyName])) + .ToArray(); + + // Convert to the caller's data model. + yield return VectorStoreErrorHandler.RunModelConversion( + RedisConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this._collectionName, + "FT.SEARCH", + () => + { + return this._mapper.MapFromStorageToDataModel((this.RemoveKeyPrefixIfNeeded(document.Id), retrievedHashEntries), new() { IncludeVectors = options.IncludeVectors }); + }); + } + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType == typeof(IDatabase) ? this._database : + serviceType.IsInstanceOfType(this) ? this : + null; } /// @@ -445,7 +586,8 @@ private async Task RunOperationAsync(string operationName, Func> o { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = RedisConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, CollectionName = this._collectionName, OperationName = operationName }; @@ -468,10 +610,22 @@ private async Task RunOperationAsync(string operationName, Func operation) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = RedisConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, CollectionName = this._collectionName, OperationName = operationName }; } } + + private string GetStringKey(TKey key) + { + Verify.NotNull(key); + + var stringKey = key as string ?? throw new UnreachableException("string key should have been validated during model building"); + + Verify.NotNullOrWhiteSpace(stringKey, nameof(key)); + + return stringKey; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollectionOptions.cs index 8d61c1fb74ea..c465e203cc1c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollectionOptions.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; using StackExchange.Redis; @@ -23,6 +25,7 @@ public sealed class RedisHashSetVectorStoreRecordCollectionOptions /// /// Gets or sets an optional custom mapper to use when converting between the data model and the Redis record. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper? HashEntriesCustomMapper { get; init; } = null; /// @@ -34,4 +37,9 @@ public sealed class RedisHashSetVectorStoreRecordCollectionOptions /// See , and . /// public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordMapper.cs index ceed34e41a05..5a01d33d53fa 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordMapper.cs @@ -4,7 +4,9 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using StackExchange.Redis; namespace Microsoft.SemanticKernel.Connectors.Redis; @@ -13,55 +15,50 @@ namespace Microsoft.SemanticKernel.Connectors.Redis; /// Class for mapping between a hashset stored in redis, and the consumer data model. /// /// The consumer data model to map to or from. -internal sealed class RedisHashSetVectorStoreRecordMapper : IVectorStoreRecordMapper +internal sealed class RedisHashSetVectorStoreRecordMapper(VectorStoreRecordModel model) { - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; - - /// - /// Initializes a new instance of the class. - /// - /// A helper to access property information for the current data model and record definition. - public RedisHashSetVectorStoreRecordMapper( - VectorStoreRecordPropertyReader propertyReader) - { - Verify.NotNull(propertyReader); - - propertyReader.VerifyHasParameterlessConstructor(); - - this._propertyReader = propertyReader; - } - /// - public (string Key, HashEntry[] HashEntries) MapFromDataToStorageModel(TConsumerDataModel dataModel) + public (string Key, HashEntry[] HashEntries) MapFromDataToStorageModel(TConsumerDataModel dataModel, int recordIndex, IReadOnlyList?[]? generatedEmbeddings) { - var keyValue = this._propertyReader.KeyPropertyInfo.GetValue(dataModel) as string ?? - throw new VectorStoreRecordMappingException($"Missing key property {this._propertyReader.KeyPropertyName} on provided record of type {typeof(TConsumerDataModel).FullName}."); + var keyValue = model.KeyProperty.GetValueAsObject(dataModel!) as string ?? + throw new VectorStoreRecordMappingException($"Missing key property {model.KeyProperty.ModelName} on provided record of type '{typeof(TConsumerDataModel).Name}'."); var hashEntries = new List(); - foreach (var property in this._propertyReader.DataPropertiesInfo) + foreach (var property in model.DataProperties) { - var storageName = this._propertyReader.GetStoragePropertyName(property.Name); - var value = property.GetValue(dataModel); - hashEntries.Add(new HashEntry(storageName, RedisValue.Unbox(value))); + var value = property.GetValueAsObject(dataModel!); + hashEntries.Add(new HashEntry(property.StorageName, RedisValue.Unbox(value))); } - foreach (var property in this._propertyReader.VectorPropertiesInfo) + for (var i = 0; i < model.VectorProperties.Count; i++) { - var storageName = this._propertyReader.GetStoragePropertyName(property.Name); - var value = property.GetValue(dataModel); + var property = model.VectorProperties[i]; + + var value = generatedEmbeddings?[i]?[recordIndex] ?? property.GetValueAsObject(dataModel!); + if (value is not null) { // Convert the vector to a byte array and store it in the hash entry. // We only support float and double vectors and we do checking in the // collection constructor to ensure that the model has no other vector types. - if (value is ReadOnlyMemory rom) - { - hashEntries.Add(new HashEntry(storageName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(rom))); - } - else if (value is ReadOnlyMemory rod) + switch (value) { - hashEntries.Add(new HashEntry(storageName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(rod))); + case ReadOnlyMemory rom: + hashEntries.Add(new HashEntry(property.StorageName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(rom))); + continue; + case ReadOnlyMemory rod: + hashEntries.Add(new HashEntry(property.StorageName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(rod))); + continue; + + case Embedding embedding: + hashEntries.Add(new HashEntry(property.StorageName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(embedding.Vector))); + continue; + case Embedding embedding: + hashEntries.Add(new HashEntry(property.StorageName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(embedding.Vector))); + continue; + + default: + throw new VectorStoreRecordMappingException($"Unsupported vector type '{value.GetType()}'. Only float and double vectors are supported."); } } } @@ -75,49 +72,51 @@ public TConsumerDataModel MapFromStorageToDataModel((string Key, HashEntry[] Has var hashEntriesDictionary = storageModel.HashEntries.ToDictionary(x => (string)x.Name!, x => x.Value); // Construct the output record. - var outputRecord = (TConsumerDataModel)this._propertyReader.ParameterLessConstructorInfo.Invoke(null); + var outputRecord = model.CreateRecord()!; // Set Key. - this._propertyReader.KeyPropertyInfo.SetValue(outputRecord, storageModel.Key); + model.KeyProperty.SetValueAsObject(outputRecord, storageModel.Key); // Set each vector property if embeddings should be returned. if (options?.IncludeVectors is true) { - VectorStoreRecordMapping.SetValuesOnProperties( - outputRecord, - this._propertyReader.VectorPropertiesInfo, - this._propertyReader.StoragePropertyNamesMap, - hashEntriesDictionary, - (RedisValue vector, Type targetType) => + foreach (var property in model.VectorProperties) + { + if (hashEntriesDictionary.TryGetValue(property.StorageName, out var vector)) { - if (targetType == typeof(ReadOnlyMemory) || targetType == typeof(ReadOnlyMemory?)) - { - var array = MemoryMarshal.Cast((byte[])vector!).ToArray(); - return new ReadOnlyMemory(array); - } - else if (targetType == typeof(ReadOnlyMemory) || targetType == typeof(ReadOnlyMemory?)) + if (vector.IsNull) { - var array = MemoryMarshal.Cast((byte[])vector!).ToArray(); - return new ReadOnlyMemory(array); + property.SetValueAsObject(outputRecord!, null); + continue; } - else + + property.SetValueAsObject(outputRecord!, property.Type switch { - throw new VectorStoreRecordMappingException($"Unsupported vector type '{targetType}'. Only float and double vectors are supported."); - } - }); + Type t when t == typeof(ReadOnlyMemory) || t == typeof(ReadOnlyMemory?) + => new ReadOnlyMemory(MemoryMarshal.Cast((byte[])vector!).ToArray()), + Type t when t == typeof(ReadOnlyMemory) || t == typeof(ReadOnlyMemory?) + => new ReadOnlyMemory(MemoryMarshal.Cast((byte[])vector!).ToArray()), + _ => throw new VectorStoreRecordMappingException($"Unsupported vector type '{property.Type}'. Only float and double vectors are supported.") + }); + } + } } - // Set each data property. - VectorStoreRecordMapping.SetValuesOnProperties( - outputRecord, - this._propertyReader.DataPropertiesInfo, - this._propertyReader.StoragePropertyNamesMap, - hashEntriesDictionary, - (RedisValue hashValue, Type targetType) => + foreach (var property in model.DataProperties) + { + if (hashEntriesDictionary.TryGetValue(property.StorageName, out var hashValue)) { - var typeOrNullableType = Nullable.GetUnderlyingType(targetType) ?? targetType; - return Convert.ChangeType(hashValue, typeOrNullableType); - }); + if (hashValue.IsNull) + { + property.SetValueAsObject(outputRecord!, null); + continue; + } + + var typeOrNullableType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; + var value = Convert.ChangeType(hashValue, typeOrNullableType); + property.SetValueAsObject(outputRecord!, value); + } + } return outputRecord; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonDynamicDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonDynamicDataModelMapper.cs new file mode 100644 index 000000000000..0c1cd9e61ae7 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonDynamicDataModelMapper.cs @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Redis when using JSON. +/// +internal class RedisJsonDynamicDataModelMapper(VectorStoreRecordModel model, JsonSerializerOptions jsonSerializerOptions) : IRedisJsonMapper> +{ + /// + public (string Key, JsonNode Node) MapFromDataToStorageModel(Dictionary dataModel, int recordIndex, IReadOnlyList?[]? generatedEmbeddings) + { + var jsonObject = new JsonObject(); + + // Key handled below, outside of the JsonNode + + foreach (var dataProperty in model.DataProperties) + { + if (dataModel.TryGetValue(dataProperty.ModelName, out var sourceValue)) + { + jsonObject.Add(dataProperty.StorageName, sourceValue is null + ? null + : JsonSerializer.SerializeToNode(sourceValue, dataProperty.Type, jsonSerializerOptions)); + } + } + + for (var i = 0; i < model.VectorProperties.Count; i++) + { + var property = model.VectorProperties[i]; + + if (generatedEmbeddings?[i] is IReadOnlyList propertyEmbedding) + { + Debug.Assert(property.EmbeddingGenerator is not null); + + jsonObject.Add( + property.StorageName, + propertyEmbedding[recordIndex] switch + { + Embedding e => JsonSerializer.SerializeToNode(e.Vector, jsonSerializerOptions), + Embedding e => JsonSerializer.SerializeToNode(e.Vector, jsonSerializerOptions), + _ => throw new UnreachableException() + }); + } + else + { + // No generated embedding, read the vector directly from the data model + if (dataModel.TryGetValue(property.ModelName, out var sourceValue)) + { + jsonObject.Add(property.StorageName, sourceValue is null + ? null + : JsonSerializer.SerializeToNode(sourceValue, property.Type, jsonSerializerOptions)); + } + } + } + + return ((string)dataModel[model.KeyProperty.ModelName]!, jsonObject); + } + + /// + public Dictionary MapFromStorageToDataModel((string Key, JsonNode Node) storageModel, StorageToDataModelMapperOptions options) + { + var dataModel = new Dictionary + { + [model.KeyProperty.ModelName] = storageModel.Key, + }; + + // The redis result can be either a single object or an array with a single object in the case where we are doing an MGET. + var jsonObject = storageModel.Node switch + { + JsonObject topLevelJsonObject => topLevelJsonObject, + JsonArray jsonArray and [JsonObject arrayEntryJsonObject] => arrayEntryJsonObject, + _ => throw new VectorStoreRecordMappingException($"Invalid data format for document with key '{storageModel.Key}'"), + }; + + // The key was handled above + + foreach (var dataProperty in model.DataProperties) + { + // Replicate null if the property exists but is null. + if (jsonObject.TryGetPropertyValue(dataProperty.StorageName, out var sourceValue)) + { + dataModel.Add(dataProperty.ModelName, sourceValue is null + ? null + : JsonSerializer.Deserialize(sourceValue, dataProperty.Type, jsonSerializerOptions)); + } + } + + foreach (var vectorProperty in model.VectorProperties) + { + // For vector properties which have embedding generation configured, we need to remove the embeddings before deserializing + // (we can't go back from an embedding to e.g. string). + // For other cases (no embedding generation), we leave the properties even if IncludeVectors is false. + if (vectorProperty.EmbeddingGenerator is not null) + { + continue; + } + + // Replicate null if the property exists but is null. + if (jsonObject.TryGetPropertyValue(vectorProperty.StorageName, out var sourceValue)) + { + dataModel.Add(vectorProperty.ModelName, sourceValue is null + ? null + : JsonSerializer.Deserialize(sourceValue, vectorProperty.Type, jsonSerializerOptions)); + } + } + + return dataModel; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonGenericDataModelMapper.cs deleted file mode 100644 index f499b0bfb4eb..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonGenericDataModelMapper.cs +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; -using System.Linq; -using System.Text.Json; -using System.Text.Json.Nodes; -using Microsoft.Extensions.VectorData; - -namespace Microsoft.SemanticKernel.Connectors.Redis; - -/// -/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Redis when using JSON. -/// -internal class RedisJsonGenericDataModelMapper : IVectorStoreRecordMapper, (string Key, JsonNode Node)> -{ - /// All the properties from the record definition. - private readonly IReadOnlyList _properties; - - /// The JSON serializer options to use when converting between the data model and the Redis record. - private readonly JsonSerializerOptions _jsonSerializerOptions; - - /// A dictionary that maps from a property name to the storage name that should be used when serializing it to json for data and vector properties. - public readonly Dictionary _storagePropertyNames; - - /// - /// Initializes a new instance of the class. - /// - /// All the properties from the record definition. - /// The JSON serializer options to use when converting between the data model and the Redis record. - public RedisJsonGenericDataModelMapper( - IReadOnlyList properties, - JsonSerializerOptions jsonSerializerOptions) - { - Verify.NotNull(properties); - Verify.NotNull(jsonSerializerOptions); - - this._properties = properties; - this._jsonSerializerOptions = jsonSerializerOptions; - - // Create a dictionary that maps from the data model property name to the storage property name. - this._storagePropertyNames = properties.Select(x => - { - if (x.StoragePropertyName is not null) - { - return new KeyValuePair( - x.DataModelPropertyName, - x.StoragePropertyName); - } - - if (jsonSerializerOptions.PropertyNamingPolicy is not null) - { - return new KeyValuePair( - x.DataModelPropertyName, - jsonSerializerOptions.PropertyNamingPolicy.ConvertName(x.DataModelPropertyName)); - } - - return new KeyValuePair( - x.DataModelPropertyName, - x.DataModelPropertyName); - }).ToDictionary(x => x.Key, x => x.Value); - } - - /// - public (string Key, JsonNode Node) MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - var jsonObject = new JsonObject(); - - foreach (var property in this._properties) - { - var storagePropertyName = this._storagePropertyNames[property.DataModelPropertyName]; - var sourceDictionary = property is VectorStoreRecordDataProperty ? dataModel.Data : dataModel.Vectors; - - // Only map properties across that actually exist in the input. - if (sourceDictionary is null || !sourceDictionary.TryGetValue(property.DataModelPropertyName, out var sourceValue)) - { - continue; - } - - // Replicate null if the property exists but is null. - if (sourceValue is null) - { - jsonObject.Add(storagePropertyName, null); - continue; - } - - jsonObject.Add(storagePropertyName, JsonSerializer.SerializeToNode(sourceValue, property.PropertyType)); - } - - return (dataModel.Key, jsonObject); - } - - /// - public VectorStoreGenericDataModel MapFromStorageToDataModel((string Key, JsonNode Node) storageModel, StorageToDataModelMapperOptions options) - { - var dataModel = new VectorStoreGenericDataModel(storageModel.Key); - - // The redis result can be either a single object or an array with a single object in the case where we are doing an MGET. - JsonObject jsonObject; - if (storageModel.Node is JsonObject topLevelJsonObject) - { - jsonObject = topLevelJsonObject; - } - else if (storageModel.Node is JsonArray jsonArray && jsonArray.Count == 1 && jsonArray[0] is JsonObject arrayEntryJsonObject) - { - jsonObject = arrayEntryJsonObject; - } - else - { - throw new VectorStoreRecordMappingException($"Invalid data format for document with key '{storageModel.Key}'"); - } - - foreach (var property in this._properties) - { - var storagePropertyName = this._storagePropertyNames[property.DataModelPropertyName]; - var targetDictionary = property is VectorStoreRecordDataProperty ? dataModel.Data : dataModel.Vectors; - - // Only map properties across that actually exist in the input. - if (!jsonObject.TryGetPropertyValue(storagePropertyName, out var sourceValue)) - { - continue; - } - - // Replicate null if the property exists but is null. - if (sourceValue is null) - { - targetDictionary.Add(property.DataModelPropertyName, null); - continue; - } - - // Map data and vector values. - if (property is VectorStoreRecordDataProperty || property is VectorStoreRecordVectorProperty) - { - targetDictionary.Add(property.DataModelPropertyName, JsonSerializer.Deserialize(sourceValue, property.PropertyType)); - } - } - - return dataModel; - } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs index af6a0a7d220f..af6d2bdb7c18 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs @@ -2,13 +2,18 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Nodes; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; using NRedisStack.Json.DataTypes; using NRedisStack.RedisStackCommands; using NRedisStack.Search; @@ -20,22 +25,18 @@ namespace Microsoft.SemanticKernel.Connectors.Redis; /// /// Service for storing and retrieving vector records, that uses Redis JSON as the underlying storage. /// +/// The data type of the record key. Can be either , or for dynamic mapping. /// The data model to use for adding, updating and retrieving data from storage. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class RedisJsonVectorStoreRecordCollection : IVectorStoreRecordCollection +public sealed class RedisJsonVectorStoreRecordCollection : IVectorStoreRecordCollection + where TKey : notnull + where TRecord : notnull #pragma warning restore CA1711 // Identifiers should not have incorrect suffix { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "Redis"; + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; - /// A set of types that a key on the provided model may have. - private static readonly HashSet s_supportedKeyTypes = - [ - typeof(string) - ]; - - /// A set of types that vectors on the provided model may have. - private static readonly HashSet s_supportedVectorTypes = + internal static readonly HashSet s_supportedVectorTypes = [ typeof(ReadOnlyMemory), typeof(ReadOnlyMemory), @@ -43,93 +44,94 @@ public class RedisJsonVectorStoreRecordCollection : IVectorStoreRecordC typeof(ReadOnlyMemory?) ]; + internal static readonly VectorStoreRecordModelBuildingOptions ModelBuildingOptions = new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + + SupportedKeyPropertyTypes = [typeof(string)], + SupportedDataPropertyTypes = null, // TODO: Validate data property types + SupportedEnumerableDataPropertyElementTypes = null, + SupportedVectorPropertyTypes = s_supportedVectorTypes, + + UsesExternalSerializer = true + }; + /// The default options for vector search. private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// The Redis database to read/write records from. private readonly IDatabase _database; - /// The name of the collection that this will access. + /// The name of the collection that this will access. private readonly string _collectionName; /// Optional configuration options for this class. private readonly RedisJsonVectorStoreRecordCollectionOptions _options; - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; + /// The model. + private readonly VectorStoreRecordModel _model; /// An array of the storage names of all the data properties that are part of the Redis payload, i.e. all properties except the key and vector properties. private readonly string[] _dataStoragePropertyNames; /// The mapper to use when mapping between the consumer data model and the Redis record. - private readonly IVectorStoreRecordMapper _mapper; + private readonly IRedisJsonMapper _mapper; /// The JSON serializer options to use when converting between the data model and the Redis record. private readonly JsonSerializerOptions _jsonSerializerOptions; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The Redis database to read/write records from. - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// Optional configuration options for this class. /// Throw when parameters are invalid. - public RedisJsonVectorStoreRecordCollection(IDatabase database, string collectionName, RedisJsonVectorStoreRecordCollectionOptions? options = null) + public RedisJsonVectorStoreRecordCollection(IDatabase database, string name, RedisJsonVectorStoreRecordCollectionOptions? options = null) { // Verify. Verify.NotNull(database); - Verify.NotNullOrWhiteSpace(collectionName); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.JsonNodeCustomMapper is not null, s_supportedKeyTypes); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + Verify.NotNullOrWhiteSpace(name); + + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(object)) + { + throw new NotSupportedException("Only string keys are supported (and object for dynamic mapping)."); + } + + var isDynamic = typeof(TRecord) == typeof(Dictionary); // Assign. this._database = database; - this._collectionName = collectionName; + this._collectionName = name; this._options = options ?? new RedisJsonVectorStoreRecordCollectionOptions(); this._jsonSerializerOptions = this._options.JsonSerializerOptions ?? JsonSerializerOptions.Default; - this._propertyReader = new VectorStoreRecordPropertyReader( - typeof(TRecord), - this._options.VectorStoreRecordDefinition, - new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true, - JsonSerializerOptions = this._jsonSerializerOptions - }); - - // Validate property types. - this._propertyReader.VerifyKeyProperties(s_supportedKeyTypes); - this._propertyReader.VerifyVectorProperties(s_supportedVectorTypes); + this._model = isDynamic ? + new VectorStoreRecordModelBuilder(ModelBuildingOptions).Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator) : + new VectorStoreRecordJsonModelBuilder(ModelBuildingOptions).Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator, this._jsonSerializerOptions); // Lookup storage property names. - this._dataStoragePropertyNames = this._propertyReader.DataPropertyJsonNames.ToArray(); + this._dataStoragePropertyNames = this._model.DataProperties.Select(p => p.StorageName).ToArray(); // Assign Mapper. - if (this._options.JsonNodeCustomMapper is not null) - { - // Custom Mapper. - this._mapper = this._options.JsonNodeCustomMapper; - } - else if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) - { - // Generic data model mapper. - this._mapper = (IVectorStoreRecordMapper)new RedisJsonGenericDataModelMapper( - this._propertyReader.Properties, - this._jsonSerializerOptions); - } - else + this._mapper = isDynamic + ? (IRedisJsonMapper)new RedisJsonDynamicDataModelMapper(this._model, this._jsonSerializerOptions) + : new RedisJsonVectorStoreRecordMapper(this._model, this._jsonSerializerOptions); + + this._collectionMetadata = new() { - // Default Mapper. - this._mapper = new RedisJsonVectorStoreRecordMapper(this._propertyReader.KeyPropertyJsonName, this._jsonSerializerOptions); - } + VectorStoreSystemName = RedisConstants.VectorStoreSystemName, + VectorStoreName = database.Database.ToString(), + CollectionName = name + }; } /// - public string CollectionName => this._collectionName; + public string Name => this._collectionName; /// - public virtual async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) { try { @@ -144,7 +146,8 @@ public virtual async Task CollectionExistsAsync(CancellationToken cancella { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = RedisConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, CollectionName = this._collectionName, OperationName = "FT.INFO" }; @@ -152,10 +155,10 @@ public virtual async Task CollectionExistsAsync(CancellationToken cancella } /// - public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) { // Map the record definition to a schema. - var schema = RedisVectorStoreCollectionCreateMapping.MapToSchema(this._propertyReader.Properties, this._propertyReader.JsonPropertyNamesMap, useDollarPrefix: true); + var schema = RedisVectorStoreCollectionCreateMapping.MapToSchema(this._model.Properties, useDollarPrefix: true); // Create the index creation params. // Add the collection name and colon as the index prefix, which means that any record where the key is prefixed with this text will be indexed by this index @@ -168,7 +171,7 @@ public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = } /// - public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) { @@ -177,7 +180,7 @@ public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken can } /// - public virtual async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { try { @@ -199,14 +202,19 @@ await this.RunOperationAsync("FT.DROPINDEX", } /// - public virtual async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { - Verify.NotNullOrWhiteSpace(key); + var stringKey = this.GetStringKey(key); // Create Options - var maybePrefixedKey = this.PrefixKeyIfNeeded(key); + var maybePrefixedKey = this.PrefixKeyIfNeeded(stringKey); var includeVectors = options?.IncludeVectors ?? false; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + // Get the Redis value. var redisResult = await this.RunOperationAsync( "GET", @@ -233,26 +241,39 @@ await this.RunOperationAsync("FT.DROPINDEX", // Convert to the caller's data model. return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + RedisConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, "GET", () => { var node = JsonSerializer.Deserialize(redisResultString, this._jsonSerializerOptions)!; - return this._mapper.MapFromStorageToDataModel((key, node), new() { IncludeVectors = includeVectors }); + return this._mapper.MapFromStorageToDataModel((stringKey, node), new() { IncludeVectors = includeVectors }); }); } /// - public virtual async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetAsync(IEnumerable keys, GetRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Verify.NotNull(keys); - var keysList = keys.ToList(); + +#pragma warning disable CA1851 // Possible multiple enumerations of 'IEnumerable' collection + var keysList = keys switch + { + IEnumerable k => k.ToList(), + IEnumerable k => k.Cast().ToList(), + _ => throw new UnreachableException() + }; +#pragma warning restore CA1851 // Possible multiple enumerations of 'IEnumerable' collection // Create Options var maybePrefixedKeys = keysList.Select(key => this.PrefixKeyIfNeeded(key)); var redisKeys = maybePrefixedKeys.Select(x => new RedisKey(x)).ToArray(); var includeVectors = options?.IncludeVectors ?? false; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } // Get the list of Redis results. var redisResults = await this.RunOperationAsync( @@ -282,7 +303,8 @@ public virtual async IAsyncEnumerable GetBatchAsync(IEnumerable // Convert to the caller's data model. yield return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + RedisConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, "MGET", () => @@ -294,12 +316,12 @@ public virtual async IAsyncEnumerable GetBatchAsync(IEnumerable } /// - public virtual Task DeleteAsync(string key, CancellationToken cancellationToken = default) + public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { - Verify.NotNullOrWhiteSpace(key); + var stringKey = this.GetStringKey(key); // Create Options - var maybePrefixedKey = this.PrefixKeyIfNeeded(key); + var maybePrefixedKey = this.PrefixKeyIfNeeded(stringKey); // Remove. return this.RunOperationAsync( @@ -310,7 +332,7 @@ public virtual Task DeleteAsync(string key, CancellationToken cancellationToken } /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { Verify.NotNull(keys); @@ -320,18 +342,21 @@ public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken } /// - public virtual async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { Verify.NotNull(record); // Map. + (_, var generatedEmbeddings) = await RedisVectorStoreRecordFieldMapping.ProcessEmbeddingsAsync(this._model, [record], cancellationToken).ConfigureAwait(false); + var redisJsonRecord = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + RedisConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, "SET", () => { - var mapResult = this._mapper.MapFromDataToStorageModel(record); + var mapResult = this._mapper.MapFromDataToStorageModel(record, recordIndex: 0, generatedEmbeddings); var serializedRecord = JsonSerializer.Serialize(mapResult.Node, this._jsonSerializerOptions); return new { Key = mapResult.Key, SerializedRecord = serializedRecord }; }); @@ -347,25 +372,31 @@ await this.RunOperationAsync( "$", redisJsonRecord.SerializedRecord)).ConfigureAwait(false); - return redisJsonRecord.Key; + return (TKey)(object)redisJsonRecord.Key; } /// - public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { Verify.NotNull(records); // Map. + (records, var generatedEmbeddings) = await RedisVectorStoreRecordFieldMapping.ProcessEmbeddingsAsync(this._model, records, cancellationToken).ConfigureAwait(false); + var redisRecords = new List<(string maybePrefixedKey, string originalKey, string serializedRecord)>(); + + var recordIndex = 0; + foreach (var record in records) { var redisJsonRecord = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + RedisConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, this._collectionName, "MSET", () => { - var mapResult = this._mapper.MapFromDataToStorageModel(record); + var mapResult = this._mapper.MapFromDataToStorageModel(record, recordIndex++, generatedEmbeddings); var serializedRecord = JsonSerializer.Serialize(mapResult.Node, this._jsonSerializerOptions); return new { Key = mapResult.Key, SerializedRecord = serializedRecord }; }); @@ -382,28 +413,98 @@ await this.RunOperationAsync( .JSON() .MSetAsync(keyPathValues)).ConfigureAwait(false); - // Return keys of upserted records. - foreach (var record in redisRecords) + return redisRecords.Select(x => (TKey)(object)x.originalKey).ToList(); + } + + #region Search + + /// + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + switch (vectorProperty.EmbeddingGenerator) { - yield return record.originalKey; + case IEmbeddingGenerator> generator: + { + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + } + + case IEmbeddingGenerator> generator: + { + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + } + + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + + default: + throw new InvalidOperationException( + s_supportedVectorTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); } } /// - public virtual async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable> SearchEmbeddingAsync( + TVector vector, + int top, + VectorSearchOptions? options = null, + CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private async IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + VectorSearchOptions options, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TVector : notnull { Verify.NotNull(vector); + Verify.NotLessThan(top, 1); - var internalOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(internalOptions); + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } // Build query & search. byte[] vectorBytes = RedisVectorStoreCollectionSearchMapping.ValidateVectorAndConvertToBytes(vector, "JSON"); var query = RedisVectorStoreCollectionSearchMapping.BuildQuery( vectorBytes, - internalOptions, - this._propertyReader.JsonPropertyNamesMap, - this._propertyReader.GetJsonPropertyName(vectorProperty.DataModelPropertyName), + top, + options, + this._model, + vectorProperty, null); var results = await this.RunOperationAsync( "FT.SEARCH", @@ -416,7 +517,8 @@ public virtual async Task> VectorizedSearchAsync @@ -424,18 +526,80 @@ public virtual async Task> VectorizedSearchAsync(redisResultString, this._jsonSerializerOptions)!; return this._mapper.MapFromStorageToDataModel( (this.RemoveKeyPrefixIfNeeded(result.Id), node), - new() { IncludeVectors = internalOptions.IncludeVectors }); + new() { IncludeVectors = options.IncludeVectors }); }); // Process the score of the result item. - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(internalOptions); + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); var distanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(vectorProperty); var score = RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(result["vector_score"].HasValue ? (float)result["vector_score"] : null, distanceFunction); return new VectorSearchResult(mappedRecord, score); }); - return new VectorSearchResults(mappedResults.ToAsyncEnumerable()); + foreach (var result in mappedResults) + { + yield return result; + } + } + + /// + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); + + #endregion Search + + /// + public async IAsyncEnumerable GetAsync(Expression> filter, int top, + GetFilteredRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); + + if (options?.IncludeVectors == true && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + Query query = RedisVectorStoreCollectionSearchMapping.BuildQuery(filter, top, options ??= new(), this._model); + + var results = await this.RunOperationAsync( + "FT.SEARCH", + () => this._database + .FT() + .SearchAsync(this._collectionName, query)).ConfigureAwait(false); + + foreach (var document in results.Documents) + { + var redisResultString = document["json"].ToString(); + yield return VectorStoreErrorHandler.RunModelConversion( + RedisConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this._collectionName, + "FT.SEARCH", + () => + { + var node = JsonSerializer.Deserialize(redisResultString, this._jsonSerializerOptions)!; + return this._mapper.MapFromStorageToDataModel( + (this.RemoveKeyPrefixIfNeeded(document.Id), node), + new() { IncludeVectors = options.IncludeVectors }); + }); + } + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType == typeof(IDatabase) ? this._database : + serviceType.IsInstanceOfType(this) ? this : + null; } /// @@ -486,7 +650,8 @@ private async Task RunOperationAsync(string operationName, Func operation) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = RedisConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, CollectionName = this._collectionName, OperationName = operationName }; @@ -510,10 +675,22 @@ private async Task RunOperationAsync(string operationName, Func> o { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = RedisConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, CollectionName = this._collectionName, OperationName = operationName }; } } + + private string GetStringKey(TKey key) + { + Verify.NotNull(key); + + var stringKey = key as string ?? throw new UnreachableException("string key should have been validated during model building"); + + Verify.NotNullOrWhiteSpace(stringKey, nameof(key)); + + return stringKey; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollectionOptions.cs index d5f8696fc30d..f9c1c4e08fb4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollectionOptions.cs @@ -1,13 +1,15 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Text.Json; using System.Text.Json.Nodes; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.Redis; /// -/// Options when creating a . +/// Options when creating a . /// public sealed class RedisJsonVectorStoreRecordCollectionOptions { @@ -27,6 +29,7 @@ public sealed class RedisJsonVectorStoreRecordCollectionOptions /// /// If not set, the default built in mapper will be used, which uses record attrigutes or the provided to map the record. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper? JsonNodeCustomMapper { get; init; } = null; /// @@ -43,4 +46,9 @@ public sealed class RedisJsonVectorStoreRecordCollectionOptions /// Gets or sets the JSON serializer options to use when converting between the data model and the Redis record. /// public JsonSerializerOptions? JsonSerializerOptions { get; init; } = null; + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordMapper.cs index a6ffc4bec208..a2a29acdb9c9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordMapper.cs @@ -1,8 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; using System.Text.Json; using System.Text.Json.Nodes; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Redis; @@ -10,75 +15,95 @@ namespace Microsoft.SemanticKernel.Connectors.Redis; /// Class for mapping between a json node stored in redis, and the consumer data model. /// /// The consumer data model to map to or from. -internal sealed class RedisJsonVectorStoreRecordMapper : IVectorStoreRecordMapper +internal sealed class RedisJsonVectorStoreRecordMapper( + VectorStoreRecordModel model, + JsonSerializerOptions jsonSerializerOptions) + : IRedisJsonMapper { - /// The name of the temporary json property that the key field will be serialized / parsed from. - private readonly string _keyFieldJsonPropertyName; - - /// The JSON serializer options to use when converting between the data model and the Redis record. - private readonly JsonSerializerOptions _jsonSerializerOptions; - - /// - /// Initializes a new instance of the class. - /// - /// The name of the key field on the model when serialized to json. - /// The JSON serializer options to use when converting between the data model and the Redis record. - public RedisJsonVectorStoreRecordMapper(string keyFieldJsonPropertyName, JsonSerializerOptions jsonSerializerOptions) - { - Verify.NotNullOrWhiteSpace(keyFieldJsonPropertyName); - Verify.NotNull(jsonSerializerOptions); - - this._keyFieldJsonPropertyName = keyFieldJsonPropertyName; - this._jsonSerializerOptions = jsonSerializerOptions; - } + /// The key property. + private readonly string _keyPropertyStorageName = model.KeyProperty.StorageName; /// - public (string Key, JsonNode Node) MapFromDataToStorageModel(TConsumerDataModel dataModel) + public (string Key, JsonNode Node) MapFromDataToStorageModel(TConsumerDataModel dataModel, int recordIndex, IReadOnlyList?[]? generatedEmbeddings) { // Convert the provided record into a JsonNode object and try to get the key field for it. // Since we already checked that the key field is a string in the constructor, and that it exists on the model, // the only edge case we have to be concerned about is if the key field is null. - var jsonNode = JsonSerializer.SerializeToNode(dataModel, this._jsonSerializerOptions); - if (jsonNode!.AsObject().TryGetPropertyValue(this._keyFieldJsonPropertyName, out var keyField) && keyField is JsonValue jsonValue) + var jsonNode = JsonSerializer.SerializeToNode(dataModel, jsonSerializerOptions)!.AsObject(); + + if (!(jsonNode.TryGetPropertyValue(this._keyPropertyStorageName, out var keyField) && keyField is JsonValue jsonValue)) { - // Remove the key field from the JSON object since we don't want to store it in the redis payload. - var keyValue = jsonValue.ToString(); - jsonNode.AsObject().Remove(this._keyFieldJsonPropertyName); + throw new VectorStoreRecordMappingException($"Missing key field '{this._keyPropertyStorageName}' on provided record of type {typeof(TConsumerDataModel).FullName}."); + } - return (keyValue, jsonNode); + // Remove the key field from the JSON object since we don't want to store it in the redis payload. + var keyValue = jsonValue.ToString(); + jsonNode.Remove(this._keyPropertyStorageName); + + // Go over the vector properties; those which have an embedding generator configured on them will have embedding generators, overwrite + // the value in the JSON object with that. + if (generatedEmbeddings is not null) + { + for (var i = 0; i < model.VectorProperties.Count; i++) + { + if (generatedEmbeddings[i] is IReadOnlyList propertyEmbeddings) + { + var property = model.VectorProperties[i]; + Debug.Assert(property.EmbeddingGenerator is not null); + jsonNode[property.StorageName] = propertyEmbeddings[recordIndex] switch + { + Embedding e => JsonSerializer.SerializeToNode(e.Vector, jsonSerializerOptions), + Embedding e => JsonSerializer.SerializeToNode(e.Vector, jsonSerializerOptions), + _ => throw new UnreachableException() + }; + } + } } - throw new VectorStoreRecordMappingException($"Missing key field {this._keyFieldJsonPropertyName} on provided record of type {typeof(TConsumerDataModel).FullName}."); + return (keyValue, jsonNode); } /// public TConsumerDataModel MapFromStorageToDataModel((string Key, JsonNode Node) storageModel, StorageToDataModelMapperOptions options) { - JsonObject jsonObject; - - // The redis result can be either a single object or an array with a single object in the case where we are doing an MGET. - if (storageModel.Node is JsonObject topLevelJsonObject) - { - jsonObject = topLevelJsonObject; - } - else if (storageModel.Node is JsonArray jsonArray && jsonArray.Count == 1 && jsonArray[0] is JsonObject arrayEntryJsonObject) - { - jsonObject = arrayEntryJsonObject; - } - else + // The redis result can have one of three different formats: + // 1. a single object + // 2. an array with a single object in the case where we are doing an MGET + // 3. a single value (string, number, etc.) in the case where there is only one property being requested because the model has only one property apart from the key + var jsonObject = storageModel.Node switch { - throw new VectorStoreRecordMappingException($"Invalid data format for document with key '{storageModel.Key}'"); - } + JsonObject topLevelJsonObject => topLevelJsonObject, + JsonArray and [JsonObject arrayEntryJsonObject] => arrayEntryJsonObject, + JsonValue when model.DataProperties.Count + model.VectorProperties.Count == 1 => new JsonObject + { + [model.DataProperties.Concat(model.VectorProperties).First().StorageName] = storageModel.Node + }, + _ => throw new VectorStoreRecordMappingException($"Invalid data format for document with key '{storageModel.Key}'") + }; // Check that the key field is not already present in the redis value. - if (jsonObject.ContainsKey(this._keyFieldJsonPropertyName)) + if (jsonObject.ContainsKey(this._keyPropertyStorageName)) { - throw new VectorStoreRecordMappingException($"Invalid data format for document with key '{storageModel.Key}'. Key property '{this._keyFieldJsonPropertyName}' is already present on retrieved object."); + throw new VectorStoreRecordMappingException($"Invalid data format for document with key '{storageModel.Key}'. Key property '{this._keyPropertyStorageName}' is already present on retrieved object."); } // Since the key is not stored in the redis value, add it back in before deserializing into the data model. - jsonObject.Add(this._keyFieldJsonPropertyName, storageModel.Key); + jsonObject.Add(this._keyPropertyStorageName, storageModel.Key); + + // For vector properties which have embedding generation configured, we need to remove the embeddings before deserializing + // (we can't go back from an embedding to e.g. string). + // For other cases (no embedding generation), we leave the properties even if IncludeVectors is false. + if (!options.IncludeVectors) + { + foreach (var vectorProperty in model.VectorProperties) + { + if (vectorProperty.EmbeddingGenerator is not null) + { + jsonObject.Remove(vectorProperty.StorageName); + } + } + } - return JsonSerializer.Deserialize(jsonObject, this._jsonSerializerOptions)!; + return JsonSerializer.Deserialize(jsonObject, jsonSerializerOptions)!; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisKernelBuilderExtensions.cs index 1f7ed194856f..6421c68928ae 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisKernelBuilderExtensions.cs @@ -53,6 +53,7 @@ public static IKernelBuilder AddRedisHashSetVectorStoreRecordCollection string collectionName, RedisHashSetVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddRedisHashSetVectorStoreRecordCollection(collectionName, options, serviceId); return builder; @@ -75,6 +76,7 @@ public static IKernelBuilder AddRedisHashSetVectorStoreRecordCollection string redisConnectionConfiguration, RedisHashSetVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddRedisHashSetVectorStoreRecordCollection(collectionName, redisConnectionConfiguration, options, serviceId); return builder; @@ -95,6 +97,7 @@ public static IKernelBuilder AddRedisJsonVectorStoreRecordCollection( string collectionName, RedisJsonVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddRedisJsonVectorStoreRecordCollection(collectionName, options, serviceId); return builder; @@ -117,6 +120,7 @@ public static IKernelBuilder AddRedisJsonVectorStoreRecordCollection( string redisConnectionConfiguration, RedisJsonVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddRedisJsonVectorStoreRecordCollection(collectionName, redisConnectionConfiguration, options, serviceId); return builder; diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisMemoryStore.cs index 29dfb78da922..a6a626419aa7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisMemoryStore.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -18,13 +17,15 @@ namespace Microsoft.SemanticKernel.Connectors.Redis; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// An implementation of for Redis. /// /// The embedded data is saved to the Redis server database specified in the constructor. /// Similarity search capability is provided through the RediSearch module. Use RediSearch's "Index" to implement "Collection". /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and RedisVectorStore")] public class RedisMemoryStore : IMemoryStore, IDisposable { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisServiceCollectionExtensions.cs index 778c1e75a88a..ed00293afd27 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisServiceCollectionExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Redis; @@ -28,11 +29,12 @@ public static IServiceCollection AddRedisVectorStore(this IServiceCollection ser (sp, obj) => { var database = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new RedisVectorStore( - database, - selectedOptions); + return new RedisVectorStore(database, options); }); return services; @@ -55,11 +57,12 @@ public static IServiceCollection AddRedisVectorStore(this IServiceCollection ser (sp, obj) => { var database = ConnectionMultiplexer.Connect(redisConnectionConfiguration).GetDatabase(); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new RedisVectorStore( - database, - selectedOptions); + return new RedisVectorStore(database, options); }); return services; @@ -80,15 +83,19 @@ public static IServiceCollection AddRedisHashSetVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedTransient>( serviceId, (sp, obj) => { var database = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new RedisHashSetVectorStoreRecordCollection(database, collectionName, selectedOptions); + return new RedisHashSetVectorStoreRecordCollection(database, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -113,15 +120,19 @@ public static IServiceCollection AddRedisHashSetVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedSingleton>( serviceId, (sp, obj) => { var database = ConnectionMultiplexer.Connect(redisConnectionConfiguration).GetDatabase(); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new RedisHashSetVectorStoreRecordCollection(database, collectionName, selectedOptions); + return new RedisHashSetVectorStoreRecordCollection(database, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -144,15 +155,19 @@ public static IServiceCollection AddRedisJsonVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedTransient>( serviceId, (sp, obj) => { var database = sp.GetRequiredService(); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new RedisJsonVectorStoreRecordCollection(database, collectionName, selectedOptions); + return new RedisJsonVectorStoreRecordCollection(database, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -177,15 +192,19 @@ public static IServiceCollection AddRedisJsonVectorStoreRecordCollection? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedSingleton>( serviceId, (sp, obj) => { var database = ConnectionMultiplexer.Connect(redisConnectionConfiguration).GetDatabase(); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new RedisJsonVectorStoreRecordCollection(database, collectionName, selectedOptions); + return new RedisJsonVectorStoreRecordCollection(database, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -194,14 +213,14 @@ public static IServiceCollection AddRedisJsonVectorStoreRecordCollection - /// Also register the with the given as a . + /// Also register the with the given as a . /// /// The type of the data model that the collection should contain. /// The service collection to register on. /// The service id that the registrations should use. - private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) where TRecord : notnull { - services.AddKeyedTransient>( + services.AddKeyedTransient>( serviceId, (sp, obj) => { diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorDistanceMetric.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorDistanceMetric.cs index 551d3e2e844d..96bcbd1bc917 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorDistanceMetric.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorDistanceMetric.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; namespace Microsoft.SemanticKernel.Connectors.Redis; @@ -8,7 +8,7 @@ namespace Microsoft.SemanticKernel.Connectors.Redis; /// Supported distance metrics are {L2, IP, COSINE}. The default value is "COSINE". /// /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and RedisVectorStore")] public enum VectorDistanceMetric { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStore.cs index 4966917d3990..156acb9bed66 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStore.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.VectorData; using NRedisStack.RedisStackCommands; using StackExchange.Redis; @@ -16,10 +17,10 @@ namespace Microsoft.SemanticKernel.Connectors.Redis; /// /// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. /// -public class RedisVectorStore : IVectorStore +public sealed class RedisVectorStore : IVectorStore { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "Redis"; + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; /// The redis database to read/write indices from. private readonly IDatabase _database; @@ -27,6 +28,9 @@ public class RedisVectorStore : IVectorStore /// Optional configuration options for this class. private readonly RedisVectorStoreOptions _options; + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreRecordDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreRecordKeyProperty("Key", typeof(string))] }; + /// /// Initializes a new instance of the class. /// @@ -38,11 +42,18 @@ public RedisVectorStore(IDatabase database, RedisVectorStoreOptions? options = d this._database = database; this._options = options ?? new RedisVectorStoreOptions(); + + this._metadata = new() + { + VectorStoreSystemName = RedisConstants.VectorStoreSystemName, + VectorStoreName = database.Database.ToString() + }; } /// - public virtual IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + where TRecord : notnull { #pragma warning disable CS0618 // IRedisVectorStoreRecordCollectionFactory is obsolete if (this._options.VectorStoreCollectionFactory is not null) @@ -51,25 +62,28 @@ public virtual IVectorStoreRecordCollection GetCollection(this._database, name, new RedisHashSetVectorStoreRecordCollectionOptions() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; + var recordCollection = new RedisHashSetVectorStoreRecordCollection(this._database, name, new RedisHashSetVectorStoreRecordCollectionOptions() + { + VectorStoreRecordDefinition = vectorStoreRecordDefinition, + EmbeddingGenerator = this._options.EmbeddingGenerator + }) as IVectorStoreRecordCollection; return recordCollection!; } else { - var recordCollection = new RedisJsonVectorStoreRecordCollection(this._database, name, new RedisJsonVectorStoreRecordCollectionOptions() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; + var recordCollection = new RedisJsonVectorStoreRecordCollection(this._database, name, new RedisJsonVectorStoreRecordCollectionOptions() + { + VectorStoreRecordDefinition = vectorStoreRecordDefinition, + EmbeddingGenerator = this._options.EmbeddingGenerator + }) as IVectorStoreRecordCollection; return recordCollection!; } } /// - public virtual async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { const string OperationName = ""; RedisResult[] listResult; @@ -82,7 +96,8 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreSystemName = RedisConstants.VectorStoreSystemName, + VectorStoreName = this._metadata.VectorStoreName, OperationName = OperationName }; } @@ -96,4 +111,31 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat } } } + + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType == typeof(IDatabase) ? this._metadata : + serviceType.IsInstanceOfType(this) ? this : + null; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionCreateMapping.cs index 8bde159e848a..6ee4e163eadf 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionCreateMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionCreateMapping.cs @@ -3,9 +3,11 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; using System.Globalization; using System.Linq; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using NRedisStack.Search; namespace Microsoft.SemanticKernel.Connectors.Redis; @@ -47,11 +49,10 @@ internal static class RedisVectorStoreCollectionCreateMapping /// Map from the given list of items to the Redis . /// /// The property definitions to map from. - /// A dictionary that maps from a property name to the storage name that should be used when serializing it to json for data and vector properties. /// A value indicating whether to include $. prefix for field names as required in JSON mode. /// The mapped Redis . /// Thrown if there are missing required or unsupported configuration options set. - public static Schema MapToSchema(IEnumerable properties, IReadOnlyDictionary storagePropertyNames, bool useDollarPrefix) + public static Schema MapToSchema(IEnumerable properties, bool useDollarPrefix) { var schema = new Schema(); var fieldNamePrefix = useDollarPrefix ? "$." : string.Empty; @@ -59,79 +60,68 @@ public static Schema MapToSchema(IEnumerable properti // Loop through all properties and create the index fields. foreach (var property in properties) { - // Key property. - if (property is VectorStoreRecordKeyProperty keyProperty) - { - // Do nothing, since key is not stored as part of the payload and therefore doesn't have to be added to the index. - continue; - } + var storageName = property.StorageName; - // Data property. - if (property is VectorStoreRecordDataProperty dataProperty && (dataProperty.IsFilterable || dataProperty.IsFullTextSearchable)) + switch (property) { - var storageName = storagePropertyNames[dataProperty.DataModelPropertyName]; + case VectorStoreRecordKeyPropertyModel keyProperty: + // Do nothing, since key is not stored as part of the payload and therefore doesn't have to be added to the index. + continue; - if (dataProperty.IsFilterable && dataProperty.IsFullTextSearchable) - { - throw new InvalidOperationException($"Property '{dataProperty.DataModelPropertyName}' has both {nameof(VectorStoreRecordDataProperty.IsFilterable)} and {nameof(VectorStoreRecordDataProperty.IsFullTextSearchable)} set to true, and this is not supported by the Redis VectorStore."); - } - - // Add full text search field index. - if (dataProperty.IsFullTextSearchable) - { - if (dataProperty.PropertyType == typeof(string) || (typeof(IEnumerable).IsAssignableFrom(dataProperty.PropertyType) && GetEnumerableType(dataProperty.PropertyType) == typeof(string))) + case VectorStoreRecordDataPropertyModel dataProperty when dataProperty.IsIndexed || dataProperty.IsFullTextIndexed: + if (dataProperty.IsIndexed && dataProperty.IsFullTextIndexed) { - schema.AddTextField(new FieldName($"{fieldNamePrefix}{storageName}", storageName)); + throw new InvalidOperationException($"Property '{dataProperty.ModelName}' has both {nameof(VectorStoreRecordDataProperty.IsIndexed)} and {nameof(VectorStoreRecordDataProperty.IsFullTextIndexed)} set to true, and this is not supported by the Redis VectorStore."); } - else - { - throw new InvalidOperationException($"Property {nameof(dataProperty.IsFullTextSearchable)} on {nameof(VectorStoreRecordDataProperty)} '{dataProperty.DataModelPropertyName}' is set to true, but the property type is not a string or IEnumerable. The Redis VectorStore supports {nameof(dataProperty.IsFullTextSearchable)} on string or IEnumerable properties only."); - } - } - // Add filter field index. - if (dataProperty.IsFilterable) - { - if (dataProperty.PropertyType == typeof(string)) - { - schema.AddTagField(new FieldName($"{fieldNamePrefix}{storageName}", storageName)); - } - else if (typeof(IEnumerable).IsAssignableFrom(dataProperty.PropertyType) && GetEnumerableType(dataProperty.PropertyType) == typeof(string)) + // Add full text search field index. + if (dataProperty.IsFullTextIndexed) { - schema.AddTagField(new FieldName($"{fieldNamePrefix}{storageName}.*", storageName)); + if (dataProperty.Type == typeof(string) || (typeof(IEnumerable).IsAssignableFrom(dataProperty.Type) && GetEnumerableType(dataProperty.Type) == typeof(string))) + { + schema.AddTextField(new FieldName($"{fieldNamePrefix}{storageName}", storageName)); + } + else + { + throw new InvalidOperationException($"Property {nameof(dataProperty.IsFullTextIndexed)} on {nameof(VectorStoreRecordDataProperty)} '{dataProperty.ModelName}' is set to true, but the property type is not a string or IEnumerable. The Redis VectorStore supports {nameof(dataProperty.IsFullTextIndexed)} on string or IEnumerable properties only."); + } } - else if (RedisVectorStoreCollectionCreateMapping.s_supportedFilterableNumericDataTypes.Contains(dataProperty.PropertyType)) - { - schema.AddNumericField(new FieldName($"{fieldNamePrefix}{storageName}", storageName)); - } - else + + // Add filter field index. + if (dataProperty.IsIndexed) { - throw new InvalidOperationException($"Property '{dataProperty.DataModelPropertyName}' is marked as {nameof(VectorStoreRecordDataProperty.IsFilterable)}, but the property type '{dataProperty.PropertyType}' is not supported. Only string, IEnumerable and numeric properties are supported for filtering by the Redis VectorStore."); + if (dataProperty.Type == typeof(string)) + { + schema.AddTagField(new FieldName($"{fieldNamePrefix}{storageName}", storageName)); + } + else if (typeof(IEnumerable).IsAssignableFrom(dataProperty.Type) && GetEnumerableType(dataProperty.Type) == typeof(string)) + { + schema.AddTagField(new FieldName($"{fieldNamePrefix}{storageName}.*", storageName)); + } + else if (RedisVectorStoreCollectionCreateMapping.s_supportedFilterableNumericDataTypes.Contains(dataProperty.Type)) + { + schema.AddNumericField(new FieldName($"{fieldNamePrefix}{storageName}", storageName)); + } + else + { + throw new InvalidOperationException($"Property '{dataProperty.ModelName}' is marked as {nameof(VectorStoreRecordDataProperty.IsIndexed)}, but the property type '{dataProperty.Type}' is not supported. Only string, IEnumerable and numeric properties are supported for filtering by the Redis VectorStore."); + } } - } - continue; - } + continue; - // Vector property. - if (property is VectorStoreRecordVectorProperty vectorProperty) - { - if (vectorProperty.Dimensions is not > 0) - { - throw new InvalidOperationException($"Property {nameof(vectorProperty.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' must be set to a positive integer to create a collection."); - } - - var storageName = storagePropertyNames[vectorProperty.DataModelPropertyName]; - var indexKind = GetSDKIndexKind(vectorProperty); - var vectorType = GetSDKVectorType(vectorProperty); - var dimensions = vectorProperty.Dimensions.Value.ToString(CultureInfo.InvariantCulture); - var distanceAlgorithm = GetSDKDistanceAlgorithm(vectorProperty); - schema.AddVectorField(new FieldName($"{fieldNamePrefix}{storageName}", storageName), indexKind, new Dictionary() - { - ["TYPE"] = vectorType, - ["DIM"] = dimensions, - ["DISTANCE_METRIC"] = distanceAlgorithm - }); + case VectorStoreRecordVectorPropertyModel vectorProperty: + var indexKind = GetSDKIndexKind(vectorProperty); + var vectorType = GetSDKVectorType(vectorProperty); + var dimensions = vectorProperty.Dimensions.ToString(CultureInfo.InvariantCulture); + var distanceAlgorithm = GetSDKDistanceAlgorithm(vectorProperty); + schema.AddVectorField(new FieldName($"{fieldNamePrefix}{storageName}", storageName), indexKind, new Dictionary() + { + ["TYPE"] = vectorType, + ["DIM"] = dimensions, + ["DISTANCE_METRIC"] = distanceAlgorithm + }); + continue; } } @@ -145,20 +135,13 @@ public static Schema MapToSchema(IEnumerable properti /// The vector property definition. /// The chosen . /// Thrown if a index type was chosen that isn't supported by Redis. - public static Schema.VectorField.VectorAlgo GetSDKIndexKind(VectorStoreRecordVectorProperty vectorProperty) - { - if (vectorProperty.IndexKind is null) - { - return Schema.VectorField.VectorAlgo.HNSW; - } - - return vectorProperty.IndexKind switch + public static Schema.VectorField.VectorAlgo GetSDKIndexKind(VectorStoreRecordVectorPropertyModel vectorProperty) + => vectorProperty.IndexKind switch { - IndexKind.Hnsw => Schema.VectorField.VectorAlgo.HNSW, + IndexKind.Hnsw or null => Schema.VectorField.VectorAlgo.HNSW, IndexKind.Flat => Schema.VectorField.VectorAlgo.FLAT, - _ => throw new InvalidOperationException($"Index kind '{vectorProperty.IndexKind}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Redis VectorStore.") + _ => throw new InvalidOperationException($"Index kind '{vectorProperty.IndexKind}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.ModelName}' is not supported by the Redis VectorStore.") }; - } /// /// Get the configured distance metric from the given . @@ -167,22 +150,15 @@ public static Schema.VectorField.VectorAlgo GetSDKIndexKind(VectorStoreRecordVec /// The vector property definition. /// The chosen distance metric. /// Thrown if a distance function is chosen that isn't supported by Redis. - public static string GetSDKDistanceAlgorithm(VectorStoreRecordVectorProperty vectorProperty) - { - if (vectorProperty.DistanceFunction is null) + public static string GetSDKDistanceAlgorithm(VectorStoreRecordVectorPropertyModel vectorProperty) + => vectorProperty.DistanceFunction switch { - return "COSINE"; - } - - return vectorProperty.DistanceFunction switch - { - DistanceFunction.CosineSimilarity => "COSINE", + DistanceFunction.CosineSimilarity or null => "COSINE", DistanceFunction.CosineDistance => "COSINE", DistanceFunction.DotProductSimilarity => "IP", DistanceFunction.EuclideanSquaredDistance => "L2", - _ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Redis VectorStore.") + _ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.ModelName}' is not supported by the Redis VectorStore.") }; - } /// /// Get the vector type to pass to the SDK based on the data type of the vector property. @@ -190,17 +166,16 @@ public static string GetSDKDistanceAlgorithm(VectorStoreRecordVectorProperty vec /// The vector property definition. /// The SDK required vector type. /// Thrown if the property data type is not supported by the connector. - public static string GetSDKVectorType(VectorStoreRecordVectorProperty vectorProperty) - { - return vectorProperty.PropertyType switch + public static string GetSDKVectorType(VectorStoreRecordVectorPropertyModel vectorProperty) + => vectorProperty.EmbeddingType switch { Type t when t == typeof(ReadOnlyMemory) => "FLOAT32", Type t when t == typeof(ReadOnlyMemory?) => "FLOAT32", Type t when t == typeof(ReadOnlyMemory) => "FLOAT64", Type t when t == typeof(ReadOnlyMemory?) => "FLOAT64", - _ => throw new InvalidOperationException($"Vector data type '{vectorProperty.PropertyType.FullName}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Redis VectorStore.") + null => throw new UnreachableException("null embedding type"), + _ => throw new InvalidOperationException($"Vector data type '{vectorProperty.Type.Name}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.ModelName}' is not supported by the Redis VectorStore.") }; - } /// /// Gets the type of object stored in the given enumerable type. diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs index b9d199eb3361..2140f3b24c48 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Runtime.InteropServices; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using NRedisStack.Search; namespace Microsoft.SemanticKernel.Connectors.Redis; @@ -24,49 +24,39 @@ internal static class RedisVectorStoreCollectionSearchMapping /// The vector converted to a byte array. /// Thrown if the vector type is not supported. public static byte[] ValidateVectorAndConvertToBytes(TVector vector, string connectorTypeName) - { - byte[] vectorBytes; - if (vector is ReadOnlyMemory floatVector) - { - vectorBytes = MemoryMarshal.AsBytes(floatVector.Span).ToArray(); - } - else if (vector is ReadOnlyMemory doubleVector) + => vector switch { - vectorBytes = MemoryMarshal.AsBytes(doubleVector.Span).ToArray(); - } - else - { - throw new NotSupportedException($"The provided vector type {vector?.GetType().FullName} is not supported by the Redis {connectorTypeName} connector."); - } - - return vectorBytes; - } + ReadOnlyMemory floatVector => MemoryMarshal.AsBytes(floatVector.Span).ToArray(), + ReadOnlyMemory doubleVector => MemoryMarshal.AsBytes(doubleVector.Span).ToArray(), + _ => throw new NotSupportedException($"The provided vector type {vector?.GetType().FullName} is not supported by the Redis {connectorTypeName} connector.") + }; /// /// Build a Redis object from the given vector and options. /// /// The vector to search the database with as a byte array. + /// The maximum number of elements to return. /// The options to configure the behavior of the search. - /// A mapping of data model property names to the names under which they are stored. - /// The storage name of the vector property. + /// The model. + /// The vector property. /// The set of fields to limit the results to. Null for all. /// The . - public static Query BuildQuery(byte[] vectorBytes, VectorSearchOptions options, IReadOnlyDictionary storagePropertyNames, string vectorStoragePropertyName, string[]? selectFields) + public static Query BuildQuery(byte[] vectorBytes, int top, VectorSearchOptions options, VectorStoreRecordModel model, VectorStoreRecordVectorPropertyModel vectorProperty, string[]? selectFields) { // Build search query. - var redisLimit = options.Top + options.Skip; + var redisLimit = top + options.Skip; #pragma warning disable CS0618 // Type or member is obsolete var filter = options switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => BuildLegacyFilter(legacyFilter, storagePropertyNames), - { Filter: Expression> newFilter } => new RedisFilterTranslator().Translate(newFilter, storagePropertyNames), + { OldFilter: VectorSearchFilter legacyFilter } => BuildLegacyFilter(legacyFilter, model), + { Filter: Expression> newFilter } => new RedisFilterTranslator().Translate(newFilter, model), _ => "*" }; #pragma warning restore CS0618 // Type or member is obsolete - var query = new Query($"{filter}=>[KNN {redisLimit} @{vectorStoragePropertyName} $embedding AS vector_score]") + var query = new Query($"{filter}=>[KNN {redisLimit} @{vectorProperty.StorageName} $embedding AS vector_score]") .AddParam("embedding", vectorBytes) .SetSortBy("vector_score") .Limit(options.Skip, redisLimit) @@ -81,21 +71,44 @@ public static Query BuildQuery(byte[] vectorBytes, VectorSearchOptions< return query; } + internal static Query BuildQuery(Expression> filter, int top, GetFilteredRecordOptions options, VectorStoreRecordModel model) + { + var translatedFilter = new RedisFilterTranslator().Translate(filter, model); + Query query = new Query(translatedFilter) + .Limit(options.Skip, top) + .Dialect(2); + + var sortInfo = options.OrderBy.Values.Count switch + { + 0 => null, + 1 => options.OrderBy.Values[0], + _ => throw new NotSupportedException("Redis does not support ordering by more than one property.") + }; + + if (sortInfo is not null) + { + string storageName = model.GetDataOrKeyProperty(sortInfo.PropertySelector).StorageName; + query = query.SetSortBy(field: storageName, ascending: sortInfo.Ascending); + } + + return query; + } + /// /// Build a redis filter string from the provided . /// /// The to build the Redis filter string from. - /// A mapping of data model property names to the names under which they are stored. + /// The model. /// The Redis filter string. /// Thrown when a provided filter value is not supported. #pragma warning disable CS0618 // Type or member is obsolete - public static string BuildLegacyFilter(VectorSearchFilter basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) + public static string BuildLegacyFilter(VectorSearchFilter basicVectorSearchFilter, VectorStoreRecordModel model) { var filterClauses = basicVectorSearchFilter.FilterClauses.Select(clause => { if (clause is EqualToFilterClause equalityFilterClause) { - var storagePropertyName = GetStoragePropertyName(storagePropertyNames, equalityFilterClause.FieldName); + var storagePropertyName = GetStoragePropertyName(model, equalityFilterClause.FieldName); return equalityFilterClause.Value switch { @@ -109,7 +122,7 @@ public static string BuildLegacyFilter(VectorSearchFilter basicVectorSearchFilte } else if (clause is AnyTagEqualToFilterClause tagListContainsClause) { - var storagePropertyName = GetStoragePropertyName(storagePropertyNames, tagListContainsClause.FieldName); + var storagePropertyName = GetStoragePropertyName(model, tagListContainsClause.FieldName); return $"@{storagePropertyName}:{{{tagListContainsClause.Value}}}"; } else @@ -128,7 +141,7 @@ public static string BuildLegacyFilter(VectorSearchFilter basicVectorSearchFilte /// /// The vector property to be used. /// The distance function for the vector we want to search. - public static string ResolveDistanceFunction(VectorStoreRecordVectorProperty vectorProperty) + public static string ResolveDistanceFunction(VectorStoreRecordVectorPropertyModel vectorProperty) => vectorProperty.DistanceFunction ?? DistanceFunction.CosineSimilarity; /// @@ -159,17 +172,17 @@ public static string ResolveDistanceFunction(VectorStoreRecordVectorProperty vec /// /// Gets the name of the name under which the property with the given name is stored. /// - /// A mapping of data model property names to the names under which they are stored. + /// The model. /// The name of the property in the data model. /// The name that the property os stored under. /// Thrown when the property name is not found. - private static string GetStoragePropertyName(IReadOnlyDictionary storagePropertyNames, string fieldName) + private static string GetStoragePropertyName(VectorStoreRecordModel model, string fieldName) { - if (!storagePropertyNames.TryGetValue(fieldName, out var storageFieldName)) + if (!model.PropertyMap.TryGetValue(fieldName, out var property)) { throw new InvalidOperationException($"Property name '{fieldName}' provided as part of the filter clause is not a valid property name."); } - return storageFieldName; + return property.StorageName; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreOptions.cs index c9af8554c231..c31580ec1730 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Connectors.Redis; @@ -10,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.Redis; public sealed class RedisVectorStoreOptions { /// - /// An optional factory to use for constructing instances, if a custom record collection is required. + /// An optional factory to use for constructing instances, if a custom record collection is required. /// [Obsolete("To control how collections are instantiated, extend your provider's IVectorStore implementation and override GetCollection()")] public IRedisVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } @@ -19,4 +20,9 @@ public sealed class RedisVectorStoreOptions /// Indicates the way in which data should be stored in redis. Default is . /// public RedisStorageType? StorageType { get; init; } = RedisStorageType.Json; + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreRecordFieldMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreRecordFieldMapping.cs index fd9d183330a4..40b9c9d0c120 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreRecordFieldMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreRecordFieldMapping.cs @@ -1,7 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; +using System.Linq; using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Redis; @@ -29,4 +35,61 @@ public static byte[] ConvertVectorToBytes(ReadOnlyMemory vector) { return MemoryMarshal.AsBytes(vector.Span).ToArray(); } + + internal static async ValueTask<(IEnumerable records, IReadOnlyList?[]?)> ProcessEmbeddingsAsync( + VectorStoreRecordModel model, + IEnumerable records, + CancellationToken cancellationToken) + where TRecord : notnull + { + IReadOnlyList? recordsList = null; + + // If an embedding generator is defined, invoke it once per property for all records. + IReadOnlyList?[]? generatedEmbeddings = null; + + var vectorPropertyCount = model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = model.VectorProperties[i]; + + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + + // We have a property with embedding generation; materialize the records' enumerable if needed, to + // prevent multiple enumeration. + if (recordsList is null) + { + recordsList = records is IReadOnlyList r ? r : records.ToList(); + + if (recordsList.Count == 0) + { + return (records, null); + } + + records = recordsList; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbeddings, ReadOnlyMemory>(records, cancellationToken, out var floatTask)) + { + generatedEmbeddings ??= new IReadOnlyList?[vectorPropertyCount]; + generatedEmbeddings[i] = await floatTask.ConfigureAwait(false); + } + else if (vectorProperty.TryGenerateEmbeddings, ReadOnlyMemory>(records, cancellationToken, out var doubleTask)) + { + generatedEmbeddings ??= new IReadOnlyList?[vectorPropertyCount]; + generatedEmbeddings[i] = await doubleTask.ConfigureAwait(false); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + + return (records, generatedEmbeddings); + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj b/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj index b188e9a2d2aa..5afb4f42b560 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj @@ -4,13 +4,14 @@ Microsoft.SemanticKernel.Connectors.SqlServer $(AssemblyName) - netstandard2.0;net8.0 + netstandard2.0;net8.0;net462 preview + @@ -23,10 +24,13 @@ + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs index 6690f1d564a4..43a58f791223 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs @@ -12,15 +12,14 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; internal static class ExceptionWrapper { - internal const string VectorStoreType = "SqlServer"; - internal static async Task WrapAsync( SqlConnection connection, SqlCommand command, Func> func, - CancellationToken cancellationToken, string operationName, - string? collectionName = null) + string? vectorStoreName = null, + string? collectionName = null, + CancellationToken cancellationToken = default) { if (connection.State != System.Data.ConnectionState.Open) { @@ -41,18 +40,20 @@ internal static async Task WrapAsync( throw new VectorStoreOperationException(ex.Message, ex) { - OperationName = operationName, - VectorStoreType = VectorStoreType, - CollectionName = collectionName + VectorStoreSystemName = SqlServerConstants.VectorStoreSystemName, + VectorStoreName = vectorStoreName, + CollectionName = collectionName, + OperationName = operationName }; } } internal static async Task WrapReadAsync( SqlDataReader reader, - CancellationToken cancellationToken, string operationName, - string? collectionName = null) + string? vectorStoreName = null, + string? collectionName = null, + CancellationToken cancellationToken = default) { try { @@ -62,9 +63,10 @@ internal static async Task WrapReadAsync( { throw new VectorStoreOperationException(ex.Message, ex) { - OperationName = operationName, - VectorStoreType = VectorStoreType, - CollectionName = collectionName + VectorStoreSystemName = SqlServerConstants.VectorStoreSystemName, + VectorStoreName = vectorStoreName, + CollectionName = collectionName, + OperationName = operationName }; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/GenericRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/GenericRecordMapper.cs deleted file mode 100644 index ff9c7851f4cb..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/GenericRecordMapper.cs +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using Microsoft.Extensions.VectorData; - -namespace Microsoft.SemanticKernel.Connectors.SqlServer; - -internal sealed class GenericRecordMapper : IVectorStoreRecordMapper, IDictionary> - where TKey : notnull -{ - private readonly VectorStoreRecordPropertyReader _propertyReader; - - internal GenericRecordMapper(VectorStoreRecordPropertyReader propertyReader) => this._propertyReader = propertyReader; - - public IDictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - Dictionary properties = new() - { - { SqlServerCommandBuilder.GetColumnName(this._propertyReader.KeyProperty), dataModel.Key } - }; - - foreach (var property in this._propertyReader.DataProperties) - { - string name = SqlServerCommandBuilder.GetColumnName(property); - if (dataModel.Data.TryGetValue(name, out var dataValue)) - { - properties.Add(name, dataValue); - } - } - - // Add vector properties - if (dataModel.Vectors is not null) - { - foreach (var property in this._propertyReader.VectorProperties) - { - string name = SqlServerCommandBuilder.GetColumnName(property); - if (dataModel.Vectors.TryGetValue(name, out var vectorValue)) - { - if (vectorValue is ReadOnlyMemory floats) - { - properties.Add(name, floats); - } - else if (vectorValue is not null) - { - throw new VectorStoreRecordMappingException($"Vector property '{name}' contained value of non supported type: '{vectorValue.GetType().FullName}'."); - } - } - } - } - - return properties; - } - - public VectorStoreGenericDataModel MapFromStorageToDataModel(IDictionary storageModel, StorageToDataModelMapperOptions options) - { - TKey key; - var dataProperties = new Dictionary(); - var vectorProperties = new Dictionary(); - - if (storageModel.TryGetValue(SqlServerCommandBuilder.GetColumnName(this._propertyReader.KeyProperty), out var keyObject) && keyObject is not null) - { - key = (TKey)keyObject; - } - else - { - throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); - } - - foreach (var property in this._propertyReader.DataProperties) - { - string name = SqlServerCommandBuilder.GetColumnName(property); - if (storageModel.TryGetValue(name, out var dataValue)) - { - dataProperties.Add(name, dataValue); - } - } - - if (options.IncludeVectors) - { - foreach (var property in this._propertyReader.VectorProperties) - { - string name = SqlServerCommandBuilder.GetColumnName(property); - if (storageModel.TryGetValue(name, out var vectorValue)) - { - vectorProperties.Add(name, vectorValue); - } - } - } - - return new(key) { Data = dataProperties, Vectors = vectorProperties }; - } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/ISqlServerClient.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/ISqlServerClient.cs index a457cddd3859..a351e929ab3a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/ISqlServerClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/ISqlServerClient.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; @@ -11,7 +10,7 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; /// /// Interface for client managing SQL Server or Azure SQL database operations. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and SqlServerVectorStore")] internal interface ISqlServerClient { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/README.md b/dotnet/src/Connectors/Connectors.Memory.SqlServer/README.md index bb78b9c9e4fa..b753e40879da 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/README.md @@ -36,7 +36,7 @@ using Microsoft.SemanticKernel.Connectors.OpenAI; using Microsoft.SemanticKernel.Connectors.SqlServer; using Microsoft.SemanticKernel.Memory; -#pragma warning disable SKEXP0001, SKEXP0010, SKEXP0020 +#pragma warning disable SKEXP0001, SKEXP0010 // Replace with your Azure OpenAI endpoint const string AzureOpenAIEndpoint = "https://.openai.azure.com/"; diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs index 240f2814e044..ed2cbbcf5df4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs @@ -2,37 +2,38 @@ using System; using System.Collections.Generic; -using System.Reflection; +using System.Diagnostics; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.SqlServer; -internal sealed class RecordMapper : IVectorStoreRecordMapper> +internal sealed class RecordMapper(VectorStoreRecordModel model) { - private readonly VectorStoreRecordPropertyReader _propertyReader; - - internal RecordMapper(VectorStoreRecordPropertyReader propertyReader) => this._propertyReader = propertyReader; - - public IDictionary MapFromDataToStorageModel(TRecord dataModel) + public IDictionary MapFromDataToStorageModel(TRecord dataModel, int recordIndex, IReadOnlyList?[]? generatedEmbeddings) { Dictionary map = new(StringComparer.Ordinal); - map[SqlServerCommandBuilder.GetColumnName(this._propertyReader.KeyProperty)] = this._propertyReader.KeyPropertyInfo.GetValue(dataModel); + map[model.KeyProperty.StorageName] = model.KeyProperty.GetValueAsObject(dataModel!); - var dataProperties = this._propertyReader.DataProperties; - var dataPropertiesInfo = this._propertyReader.DataPropertiesInfo; - for (int i = 0; i < dataProperties.Count; i++) + foreach (var property in model.DataProperties) { - object? value = dataPropertiesInfo[i].GetValue(dataModel); - map[SqlServerCommandBuilder.GetColumnName(dataProperties[i])] = value; + map[property.StorageName] = property.GetValueAsObject(dataModel!); } - var vectorProperties = this._propertyReader.VectorProperties; - var vectorPropertiesInfo = this._propertyReader.VectorPropertiesInfo; - for (int i = 0; i < vectorProperties.Count; i++) + + for (var i = 0; i < model.VectorProperties.Count; i++) { - // We restrict the vector properties to ReadOnlyMemory so the cast here is safe. - ReadOnlyMemory floats = (ReadOnlyMemory)vectorPropertiesInfo[i].GetValue(dataModel)!; - map[SqlServerCommandBuilder.GetColumnName(vectorProperties[i])] = floats; + var property = model.VectorProperties[i]; + + // We restrict the vector properties to ReadOnlyMemory in model validation + map[property.StorageName] = generatedEmbeddings?[i] is IReadOnlyList e + ? e[recordIndex] switch + { + Embedding fe => fe.Vector, + _ => throw new UnreachableException() + } + : (ReadOnlyMemory)property.GetValueAsObject(dataModel!)!; } return map; @@ -40,33 +41,32 @@ internal sealed class RecordMapper : IVectorStoreRecordMapper storageModel, StorageToDataModelMapperOptions options) { - TRecord record = Activator.CreateInstance()!; - SetValue(storageModel, record, this._propertyReader.KeyPropertyInfo, this._propertyReader.KeyProperty); - var data = this._propertyReader.DataProperties; - var dataInfo = this._propertyReader.DataPropertiesInfo; - for (int i = 0; i < data.Count; i++) + var record = model.CreateRecord()!; + + SetValue(storageModel, record, model.KeyProperty, storageModel[model.KeyProperty.StorageName]); + + foreach (var property in model.DataProperties) { - SetValue(storageModel, record, dataInfo[i], data[i]); + SetValue(storageModel, record, property, storageModel[property.StorageName]); } if (options.IncludeVectors) { - var vector = this._propertyReader.VectorProperties; - var vectorInfo = this._propertyReader.VectorPropertiesInfo; - for (int i = 0; i < vector.Count; i++) + foreach (var property in model.VectorProperties) { - object? value = storageModel[SqlServerCommandBuilder.GetColumnName(vector[i])]; + var value = storageModel[property.StorageName]; + if (value is not null) { if (value is ReadOnlyMemory floats) { - vectorInfo[i].SetValue(record, floats); + SetValue(storageModel, record, property, floats); } else { // When deserializing a string to a ReadOnlyMemory fails in SqlDataReaderDictionary, // we store the raw value so the user can handle the error in a custom mapper. - throw new VectorStoreRecordMappingException($"Failed to deserialize vector property '{vector[i].DataModelPropertyName}', it contained value '{value}'."); + throw new VectorStoreRecordMappingException($"Failed to deserialize vector property '{property.ModelName}', it contained value '{value}'."); } } } @@ -74,25 +74,15 @@ public TRecord MapFromStorageToDataModel(IDictionary storageMod return record; - static void SetValue(IDictionary storageModel, object record, PropertyInfo propertyInfo, VectorStoreRecordProperty property) + static void SetValue(IDictionary storageModel, object record, VectorStoreRecordPropertyModel property, object? value) { - // If we got here, there should be no column name mismatch (the query would fail). - object? value = storageModel[SqlServerCommandBuilder.GetColumnName(property)]; - - if (value is null) - { - // There is no need to call the reflection to set the null, - // as it's the default value of every .NET reference type field. - return; - } - try { - propertyInfo.SetValue(record, value); + property.SetValueAsObject(record, value); } catch (Exception ex) { - throw new VectorStoreRecordMappingException($"Failed to set value '{value}' on property '{propertyInfo.Name}' of type '{propertyInfo.PropertyType.FullName}'.", ex); + throw new VectorStoreRecordMappingException($"Failed to set value '{value}' on property '{property.ModelName}' of type '{property.Type.Name}'.", ex); } } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs index 414ff8de4afd..81179dfc4021 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Text.Json; using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.SqlServer; @@ -12,20 +13,12 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; /// This class is used to provide a dictionary-like interface to a . /// The goal is to avoid the need of allocating a new dictionary for each row read from the database. /// -internal sealed class SqlDataReaderDictionary : IDictionary +internal sealed class SqlDataReaderDictionary(SqlDataReader sqlDataReader, IReadOnlyList vectorProperties) + : IDictionary { - private readonly SqlDataReader _sqlDataReader; - private readonly IReadOnlyList _vectorPropertyStoragePropertyNames; - // This field will get instantiated lazily, only if needed by a custom mapper. private Dictionary? _dictionary; - internal SqlDataReaderDictionary(SqlDataReader sqlDataReader, IReadOnlyList vectorPropertyStoragePropertyNames) - { - this._sqlDataReader = sqlDataReader; - this._vectorPropertyStoragePropertyNames = vectorPropertyStoragePropertyNames; - } - private object? Unwrap(string storageName, object? value) { // Let's make sure our users don't need to learn what DBNull is. @@ -35,11 +28,11 @@ internal SqlDataReaderDictionary(SqlDataReader sqlDataReader, IReadOnlyList 0 && value is string text) + if (vectorProperties.Count > 0 && value is string text) { - for (int i = 0; i < this._vectorPropertyStoragePropertyNames.Count; i++) + for (int i = 0; i < vectorProperties.Count; i++) { - if (string.Equals(storageName, this._vectorPropertyStoragePropertyNames[i], StringComparison.Ordinal)) + if (string.Equals(storageName, vectorProperties[i].StorageName, StringComparison.Ordinal)) { try { @@ -71,7 +64,7 @@ internal SqlDataReaderDictionary(SqlDataReader sqlDataReader, IReadOnlyList this.Unwrap(key, this._sqlDataReader[key]); + get => this.Unwrap(key, sqlDataReader[key]); set => throw new InvalidOperationException(); } @@ -79,7 +72,7 @@ public object? this[string key] public ICollection Values => this.GetDictionary().Values; - public int Count => this._sqlDataReader.FieldCount; + public int Count => sqlDataReader.FieldCount; public bool IsReadOnly => true; @@ -96,7 +89,7 @@ public bool ContainsKey(string key) { try { - return this._sqlDataReader.GetOrdinal(key) >= 0; + return sqlDataReader.GetOrdinal(key) >= 0; } catch (IndexOutOfRangeException) { @@ -121,7 +114,7 @@ public bool TryGetValue(string key, out object? value) { try { - value = this.Unwrap(key, this._sqlDataReader[key]); + value = this.Unwrap(key, sqlDataReader[key]); return true; } catch (IndexOutOfRangeException) @@ -135,11 +128,11 @@ public bool TryGetValue(string key, out object? value) { if (this._dictionary is null) { - Dictionary dictionary = new(this._sqlDataReader.FieldCount, StringComparer.Ordinal); - for (int i = 0; i < this._sqlDataReader.FieldCount; i++) + Dictionary dictionary = new(sqlDataReader.FieldCount, StringComparer.Ordinal); + for (int i = 0; i < sqlDataReader.FieldCount; i++) { - string name = this._sqlDataReader.GetName(i); - dictionary.Add(name, this.Unwrap(name, this._sqlDataReader[i])); + string name = sqlDataReader.GetName(i); + dictionary.Add(name, this.Unwrap(name, sqlDataReader[i])); } this._dictionary = dictionary; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs index a6d6912a4b98..0263eaf39fbc 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs @@ -17,7 +17,7 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; /// Implementation of database client managing SQL Server or Azure SQL database operations. /// [SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and SqlServerVectorStore")] internal sealed class SqlServerClient : ISqlServerClient { private readonly SqlConnection _connection; diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 78d3ad51b998..cb56e9f318ca 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -2,10 +2,12 @@ using System; using System.Collections.Generic; +using System.Linq.Expressions; using System.Text; using System.Text.Json; using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; #pragma warning disable CA2100 // Review SQL queries for security vulnerabilities @@ -18,9 +20,7 @@ internal static SqlCommand CreateTable( string? schema, string tableName, bool ifNotExists, - VectorStoreRecordKeyProperty keyProperty, - IReadOnlyList dataProperties, - IReadOnlyList vectorProperties) + VectorStoreRecordModel model) { StringBuilder sb = new(200); if (ifNotExists) @@ -33,35 +33,48 @@ internal static SqlCommand CreateTable( sb.Append("CREATE TABLE "); sb.AppendTableName(schema, tableName); sb.AppendLine(" ("); - string keyColumnName = GetColumnName(keyProperty); - sb.AppendFormat("[{0}] {1} NOT NULL,", keyColumnName, Map(keyProperty)); + sb.AppendFormat("[{0}] {1} NOT NULL,", model.KeyProperty.StorageName, Map(model.KeyProperty)); sb.AppendLine(); - for (int i = 0; i < dataProperties.Count; i++) + + foreach (var property in model.DataProperties) { - sb.AppendFormat("[{0}] {1},", GetColumnName(dataProperties[i]), Map(dataProperties[i])); + sb.AppendFormat("[{0}] {1},", property.StorageName, Map(property)); sb.AppendLine(); } - for (int i = 0; i < vectorProperties.Count; i++) + + foreach (var property in model.VectorProperties) { - sb.AppendFormat("[{0}] VECTOR({1}),", GetColumnName(vectorProperties[i]), vectorProperties[i].Dimensions); + sb.AppendFormat("[{0}] VECTOR({1}),", property.StorageName, property.Dimensions); sb.AppendLine(); } - sb.AppendFormat("PRIMARY KEY ([{0}])", keyColumnName); + + sb.AppendFormat("PRIMARY KEY ([{0}])", model.KeyProperty.StorageName); sb.AppendLine(); sb.AppendLine(");"); // end the table definition - foreach (var vectorProperty in vectorProperties) + foreach (var dataProperty in model.DataProperties) + { + if (dataProperty.IsIndexed) + { + sb.AppendFormat("CREATE INDEX "); + sb.AppendIndexName(tableName, dataProperty.StorageName); + sb.AppendFormat(" ON ").AppendTableName(schema, tableName); + sb.AppendFormat("([{0}]);", dataProperty.StorageName); + sb.AppendLine(); + } + } + + foreach (var vectorProperty in model.VectorProperties) { switch (vectorProperty.IndexKind) { - case null: - case "": - case IndexKind.Flat: + case IndexKind.Flat or null or "": // TODO: Move to early validation break; default: throw new NotSupportedException($"Index kind {vectorProperty.IndexKind} is not supported."); } } + sb.Append("END;"); return connection.CreateCommand(sb); @@ -108,8 +121,7 @@ internal static SqlCommand MergeIntoSingle( SqlConnection connection, string? schema, string tableName, - VectorStoreRecordKeyProperty keyProperty, - IReadOnlyList properties, + VectorStoreRecordModel model, IDictionary record) { SqlCommand command = connection.CreateCommand(); @@ -119,23 +131,25 @@ internal static SqlCommand MergeIntoSingle( sb.AppendLine(" AS t"); sb.Append("USING (VALUES ("); int paramIndex = 0; - foreach (VectorStoreRecordProperty property in properties) + + foreach (var property in model.Properties) { sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); - command.AddParameter(property, paramName, record[GetColumnName(property)]); + command.AddParameter(property, paramName, record[property.StorageName]); } + sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis sb.Append(") AS s ("); - sb.AppendColumnNames(properties); + sb.AppendColumnNames(model.Properties); sb.AppendLine(")"); - sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine(); + sb.AppendFormat("ON (t.[{0}] = s.[{0}])", model.KeyProperty.StorageName).AppendLine(); sb.AppendLine("WHEN MATCHED THEN"); sb.Append("UPDATE SET "); - foreach (VectorStoreRecordProperty property in properties) + foreach (var property in model.Properties) { - if (property != keyProperty) // don't update the key + if (property is not VectorStoreRecordKeyPropertyModel) // don't update the key { - sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property)); + sb.AppendFormat("t.[{0}] = s.[{0}],", property.StorageName); } } --sb.Length; // remove the last comma @@ -144,12 +158,12 @@ internal static SqlCommand MergeIntoSingle( sb.Append("WHEN NOT MATCHED THEN"); sb.AppendLine(); sb.Append("INSERT ("); - sb.AppendColumnNames(properties); + sb.AppendColumnNames(model.Properties); sb.AppendLine(")"); sb.Append("VALUES ("); - sb.AppendColumnNames(properties, prefix: "s."); + sb.AppendColumnNames(model.Properties, prefix: "s."); sb.AppendLine(")"); - sb.AppendFormat("OUTPUT inserted.[{0}];", GetColumnName(keyProperty)); + sb.AppendFormat("OUTPUT inserted.[{0}];", model.KeyProperty.StorageName); command.CommandText = sb.ToString(); return command; @@ -159,13 +173,12 @@ internal static bool MergeIntoMany( SqlCommand command, string? schema, string tableName, - VectorStoreRecordKeyProperty keyProperty, - IReadOnlyList properties, + VectorStoreRecordModel model, IEnumerable> records) { StringBuilder sb = new(200); // The DECLARE statement creates a table variable to store the keys of the inserted rows. - sb.AppendFormat("DECLARE @InsertedKeys TABLE (KeyColumn {0});", Map(keyProperty)); + sb.AppendFormat("DECLARE @InsertedKeys TABLE (KeyColumn {0});", Map(model.KeyProperty)); sb.AppendLine(); // The MERGE statement performs the upsert operation and outputs the keys of the inserted rows into the table variable. sb.Append("MERGE INTO "); @@ -176,10 +189,10 @@ internal static bool MergeIntoMany( foreach (var record in records) { sb.Append('('); - foreach (VectorStoreRecordProperty property in properties) + foreach (var property in model.Properties) { sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); - command.AddParameter(property, paramName, record[GetColumnName(property)]); + command.AddParameter(property, paramName, record[property.StorageName]); } sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis sb.AppendLine(","); @@ -194,16 +207,16 @@ internal static bool MergeIntoMany( sb.Length -= (1 + Environment.NewLine.Length); // remove the last comma and newline sb.Append(") AS s ("); // s stands for source - sb.AppendColumnNames(properties); + sb.AppendColumnNames(model.Properties); sb.AppendLine(")"); - sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine(); + sb.AppendFormat("ON (t.[{0}] = s.[{0}])", model.KeyProperty.StorageName).AppendLine(); sb.AppendLine("WHEN MATCHED THEN"); sb.Append("UPDATE SET "); - foreach (VectorStoreRecordProperty property in properties) + foreach (var property in model.Properties) { - if (property != keyProperty) // don't update the key + if (property is not VectorStoreRecordKeyPropertyModel) // don't update the key { - sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property)); + sb.AppendFormat("t.[{0}] = s.[{0}],", property.StorageName); } } --sb.Length; // remove the last comma @@ -211,12 +224,12 @@ internal static bool MergeIntoMany( sb.Append("WHEN NOT MATCHED THEN"); sb.AppendLine(); sb.Append("INSERT ("); - sb.AppendColumnNames(properties); + sb.AppendColumnNames(model.Properties); sb.AppendLine(")"); sb.Append("VALUES ("); - sb.AppendColumnNames(properties, prefix: "s."); + sb.AppendColumnNames(model.Properties, prefix: "s."); sb.AppendLine(")"); - sb.AppendFormat("OUTPUT inserted.[{0}] INTO @InsertedKeys (KeyColumn);", GetColumnName(keyProperty)); + sb.AppendFormat("OUTPUT inserted.[{0}] INTO @InsertedKeys (KeyColumn);", model.KeyProperty.StorageName); sb.AppendLine(); // The SELECT statement returns the keys of the inserted rows. @@ -228,7 +241,7 @@ internal static bool MergeIntoMany( internal static SqlCommand DeleteSingle( SqlConnection connection, string? schema, string tableName, - VectorStoreRecordKeyProperty keyProperty, object key) + VectorStoreRecordKeyPropertyModel keyProperty, object key) { SqlCommand command = connection.CreateCommand(); @@ -236,7 +249,7 @@ internal static SqlCommand DeleteSingle( StringBuilder sb = new(100); sb.Append("DELETE FROM "); sb.AppendTableName(schema, tableName); - sb.AppendFormat(" WHERE [{0}] = ", GetColumnName(keyProperty)); + sb.AppendFormat(" WHERE [{0}] = ", keyProperty.StorageName); sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName); command.AddParameter(keyProperty, keyParamName, key); @@ -246,12 +259,12 @@ internal static SqlCommand DeleteSingle( internal static bool DeleteMany( SqlCommand command, string? schema, string tableName, - VectorStoreRecordKeyProperty keyProperty, IEnumerable keys) + VectorStoreRecordKeyPropertyModel keyProperty, IEnumerable keys) { StringBuilder sb = new(100); sb.Append("DELETE FROM "); sb.AppendTableName(schema, tableName); - sb.AppendFormat(" WHERE [{0}] IN (", GetColumnName(keyProperty)); + sb.AppendFormat(" WHERE [{0}] IN (", keyProperty.StorageName); sb.AppendKeyParameterList(keys, command, keyProperty, out bool emptyKeys); sb.Append(')'); // close the IN clause @@ -266,8 +279,7 @@ internal static bool DeleteMany( internal static SqlCommand SelectSingle( SqlConnection sqlConnection, string? schema, string collectionName, - VectorStoreRecordKeyProperty keyProperty, - IReadOnlyList properties, + VectorStoreRecordModel model, object key, bool includeVectors) { @@ -276,14 +288,14 @@ internal static SqlCommand SelectSingle( int paramIndex = 0; StringBuilder sb = new(200); sb.AppendFormat("SELECT "); - sb.AppendColumnNames(properties, includeVectors: includeVectors); + sb.AppendColumnNames(model.Properties, includeVectors: includeVectors); sb.AppendLine(); sb.Append("FROM "); sb.AppendTableName(schema, collectionName); sb.AppendLine(); - sb.AppendFormat("WHERE [{0}] = ", GetColumnName(keyProperty)); - sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName); - command.AddParameter(keyProperty, keyParamName, key); + sb.AppendFormat("WHERE [{0}] = ", model.KeyProperty.StorageName); + sb.AppendParameterName(model.KeyProperty, ref paramIndex, out string keyParamName); + command.AddParameter(model.KeyProperty, keyParamName, key); command.CommandText = sb.ToString(); return command; @@ -291,20 +303,19 @@ internal static SqlCommand SelectSingle( internal static bool SelectMany( SqlCommand command, string? schema, string tableName, - VectorStoreRecordKeyProperty keyProperty, - IReadOnlyList properties, + VectorStoreRecordModel model, IEnumerable keys, bool includeVectors) { StringBuilder sb = new(200); sb.AppendFormat("SELECT "); - sb.AppendColumnNames(properties, includeVectors: includeVectors); + sb.AppendColumnNames(model.Properties, includeVectors: includeVectors); sb.AppendLine(); sb.Append("FROM "); sb.AppendTableName(schema, tableName); sb.AppendLine(); - sb.AppendFormat("WHERE [{0}] IN (", GetColumnName(keyProperty)); - sb.AppendKeyParameterList(keys, command, keyProperty, out bool emptyKeys); + sb.AppendFormat("WHERE [{0}] IN (", model.KeyProperty.StorageName); + sb.AppendKeyParameterList(keys, command, model.KeyProperty, out bool emptyKeys); sb.Append(')'); // close the IN clause if (emptyKeys) @@ -318,9 +329,9 @@ internal static bool SelectMany( internal static SqlCommand SelectVector( SqlConnection connection, string? schema, string tableName, - VectorStoreRecordVectorProperty vectorProperty, - IReadOnlyList properties, - IReadOnlyDictionary storagePropertyNamesMap, + VectorStoreRecordVectorPropertyModel vectorProperty, + VectorStoreRecordModel model, + int top, VectorSearchOptions options, ReadOnlyMemory vector) { @@ -332,10 +343,10 @@ internal static SqlCommand SelectVector( StringBuilder sb = new(200); sb.AppendFormat("SELECT "); - sb.AppendColumnNames(properties, includeVectors: options.IncludeVectors); + sb.AppendColumnNames(model.Properties, includeVectors: options.IncludeVectors); sb.AppendLine(","); sb.AppendFormat("VECTOR_DISTANCE('{0}', {1}, CAST(@vector AS VECTOR({2}))) AS [score]", - distanceMetric, GetColumnName(vectorProperty), vector.Length); + distanceMetric, vectorProperty.StorageName, vector.Length); sb.AppendLine(); sb.Append("FROM "); sb.AppendTableName(schema, tableName); @@ -344,7 +355,7 @@ internal static SqlCommand SelectVector( { int startParamIndex = command.Parameters.Count; - SqlServerFilterTranslator translator = new(storagePropertyNamesMap, options.Filter, sb, startParamIndex: startParamIndex); + SqlServerFilterTranslator translator = new(model, options.Filter, sb, startParamIndex: startParamIndex); translator.Translate(appendWhere: true); List parameters = translator.ParameterValues; @@ -358,16 +369,72 @@ internal static SqlCommand SelectVector( sb.AppendLine(); // Negative Skip and Top values are rejected by the VectorSearchOptions property setters. // 0 is a legal value for OFFSET. - sb.AppendFormat("OFFSET {0} ROWS FETCH NEXT {1} ROWS ONLY;", options.Skip, options.Top); + sb.AppendFormat("OFFSET {0} ROWS FETCH NEXT {1} ROWS ONLY;", options.Skip, top); command.CommandText = sb.ToString(); return command; } - internal static string GetColumnName(VectorStoreRecordProperty property) - => property.StoragePropertyName ?? property.DataModelPropertyName; + internal static SqlCommand SelectWhere( + Expression> filter, + int top, + GetFilteredRecordOptions options, + SqlConnection connection, string? schema, string tableName, + VectorStoreRecordModel model) + { + SqlCommand command = connection.CreateCommand(); + + StringBuilder sb = new(200); + sb.AppendFormat("SELECT "); + sb.AppendColumnNames(model.Properties, includeVectors: options.IncludeVectors); + sb.AppendLine(); + sb.Append("FROM "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(); + if (filter is not null) + { + int startParamIndex = command.Parameters.Count; - internal static StringBuilder AppendParameterName(this StringBuilder sb, VectorStoreRecordProperty property, ref int paramIndex, out string parameterName) + SqlServerFilterTranslator translator = new(model, filter, sb, startParamIndex: startParamIndex); + translator.Translate(appendWhere: true); + List parameters = translator.ParameterValues; + + foreach (object parameter in parameters) + { + command.AddParameter(property: null, $"@_{startParamIndex++}", parameter); + } + sb.AppendLine(); + } + + if (options.OrderBy.Values.Count > 0) + { + sb.Append("ORDER BY "); + + foreach (var sortInfo in options.OrderBy.Values) + { + sb.AppendFormat("[{0}] {1},", + model.GetDataOrKeyProperty(sortInfo.PropertySelector).StorageName, + sortInfo.Ascending ? "ASC" : "DESC"); + } + + sb.Length--; // remove the last comma + sb.AppendLine(); + } + else + { + // no order by properties, but we need to add something for OFFSET and NEXT to work + sb.AppendLine("ORDER BY (SELECT 1)"); + } + + // Negative Skip and Top values are rejected by the GetFilteredRecordOptions property setters. + // 0 is a legal value for OFFSET. + sb.AppendFormat("OFFSET {0} ROWS FETCH NEXT {1} ROWS ONLY;", options.Skip, top); + + command.CommandText = sb.ToString(); + return command; + } + + internal static StringBuilder AppendParameterName(this StringBuilder sb, VectorStoreRecordPropertyModel property, ref int paramIndex, out string parameterName) { // In SQL Server, parameter names cannot be just a number like "@1". // Parameter names must start with an alphabetic character or an underscore @@ -376,10 +443,9 @@ internal static StringBuilder AppendParameterName(this StringBuilder sb, VectorS // is valid parameter name (it can contain whitespaces, or start with a number), // we just append the ASCII letters, stop on the first non-ASCII letter // and append the index. - string columnName = GetColumnName(property); int index = sb.Length; sb.Append('@'); - foreach (char character in columnName) + foreach (char character in property.StorageName) { // We don't call APIs like char.IsWhitespace as they are expensive // as they need to handle all Unicode characters. @@ -410,7 +476,7 @@ internal static StringBuilder AppendTableName(this StringBuilder sb, string? sch if (!string.IsNullOrEmpty(schema)) { sb.Append(schema); - sb.Replace("]", "]]", index, schema.Length); // replace the ] for schema + sb.Replace("]", "]]", index, schema!.Length); // replace the ] for schema sb.Append("].["); index = sb.Length; } @@ -423,14 +489,14 @@ internal static StringBuilder AppendTableName(this StringBuilder sb, string? sch } private static StringBuilder AppendColumnNames(this StringBuilder sb, - IEnumerable properties, + IEnumerable properties, string? prefix = null, bool includeVectors = true) { bool any = false; - foreach (VectorStoreRecordProperty property in properties) + foreach (var property in properties) { - if (!includeVectors && property is VectorStoreRecordVectorProperty) + if (!includeVectors && property is VectorStoreRecordVectorPropertyModel) { continue; } @@ -440,7 +506,7 @@ private static StringBuilder AppendColumnNames(this StringBuilder sb, sb.Append(prefix); } // Use square brackets to escape column names. - sb.AppendFormat("[{0}],", GetColumnName(property)); + sb.AppendFormat("[{0}],", property.StorageName); any = true; } @@ -453,7 +519,7 @@ private static StringBuilder AppendColumnNames(this StringBuilder sb, } private static StringBuilder AppendKeyParameterList(this StringBuilder sb, - IEnumerable keys, SqlCommand command, VectorStoreRecordKeyProperty keyProperty, out bool emptyKeys) + IEnumerable keys, SqlCommand command, VectorStoreRecordKeyPropertyModel keyProperty, out bool emptyKeys) { int keyIndex = 0; foreach (TKey key in keys) @@ -472,6 +538,37 @@ private static StringBuilder AppendKeyParameterList(this StringBuilder sb, return sb; } + private static StringBuilder AppendIndexName(this StringBuilder sb, string tableName, string columnName) + { + int length = sb.Length; + + // "Index names must start with a letter or an underscore (_)." + sb.Append("index"); + sb.Append('_'); + AppendAllowedOnly(tableName); + sb.Append('_'); + AppendAllowedOnly(columnName); + + if (sb.Length > length + SqlServerConstants.MaxIndexNameLength) + { + sb.Length = length + SqlServerConstants.MaxIndexNameLength; + } + + return sb; + + void AppendAllowedOnly(string value) + { + foreach (char c in value) + { + // Index names can include letters, numbers, and underscores. + if (char.IsLetterOrDigit(c) || c == '_') + { + sb.Append(c); + } + } + } + } + private static SqlCommand CreateCommand(this SqlConnection connection, StringBuilder sb) { SqlCommand command = connection.CreateCommand(); @@ -479,11 +576,11 @@ private static SqlCommand CreateCommand(this SqlConnection connection, StringBui return command; } - private static void AddParameter(this SqlCommand command, VectorStoreRecordProperty property, string name, object? value) + private static void AddParameter(this SqlCommand command, VectorStoreRecordPropertyModel? property, string name, object? value) { switch (value) { - case null when property.PropertyType == typeof(byte[]): + case null when property?.Type == typeof(byte[]): command.Parameters.Add(name, System.Data.SqlDbType.VarBinary).Value = DBNull.Value; break; case null: @@ -502,15 +599,15 @@ private static void AddParameter(this SqlCommand command, VectorStoreRecordPrope } } - private static string Map(VectorStoreRecordProperty property) => property.PropertyType switch + private static string Map(VectorStoreRecordPropertyModel property) => property.Type switch { Type t when t == typeof(byte) => "TINYINT", Type t when t == typeof(short) => "SMALLINT", Type t when t == typeof(int) => "INT", Type t when t == typeof(long) => "BIGINT", Type t when t == typeof(Guid) => "UNIQUEIDENTIFIER", - Type t when t == typeof(string) && property is VectorStoreRecordKeyProperty => "NVARCHAR(4000)", - Type t when t == typeof(string) && property is VectorStoreRecordDataProperty { IsFilterable: true } => "NVARCHAR(4000)", + Type t when t == typeof(string) && property is VectorStoreRecordKeyPropertyModel => "NVARCHAR(4000)", + Type t when t == typeof(string) && property is VectorStoreRecordDataPropertyModel { IsIndexed: true } => "NVARCHAR(4000)", Type t when t == typeof(string) => "NVARCHAR(MAX)", Type t when t == typeof(byte[]) => "VARBINARY(MAX)", Type t when t == typeof(bool) => "BIT", @@ -521,7 +618,7 @@ private static void AddParameter(this SqlCommand command, VectorStoreRecordPrope Type t when t == typeof(decimal) => "DECIMAL", Type t when t == typeof(double) => "FLOAT", Type t when t == typeof(float) => "REAL", - _ => throw new NotSupportedException($"Type {property.PropertyType} is not supported.") + _ => throw new NotSupportedException($"Type {property.Type} is not supported.") }; // Source: https://learn.microsoft.com/sql/t-sql/functions/vector-distance-transact-sql diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs index d8ce0f1354e7..072bfd58c689 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs @@ -2,48 +2,64 @@ using System; using System.Collections.Generic; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.SqlServer; internal static class SqlServerConstants { + internal const string VectorStoreSystemName = "microsoft.sql_server"; + // The actual number is actually higher (2_100), but we want to avoid any kind of "off by one" errors. internal const int MaxParameterCount = 2_000; - internal static readonly HashSet SupportedKeyTypes = - [ - typeof(int), // INT - typeof(long), // BIGINT - typeof(string), // VARCHAR - typeof(Guid), // UNIQUEIDENTIFIER - typeof(DateTime), // DATETIME2 - typeof(byte[]) // VARBINARY - ]; - - internal static readonly HashSet SupportedDataTypes = - [ - typeof(int), // INT - typeof(short), // SMALLINT - typeof(byte), // TINYINT - typeof(long), // BIGINT. - typeof(Guid), // UNIQUEIDENTIFIER. - typeof(string), // NVARCHAR - typeof(byte[]), // VARBINARY - typeof(bool), // BIT - typeof(DateTime), // DATETIME2 -#if NET - // We don't support mapping TimeSpan to TIME on purpose - // See https://github.com/microsoft/semantic-kernel/pull/10623#discussion_r1980350721 - typeof(TimeOnly), // TIME -#endif - typeof(decimal), // DECIMAL - typeof(double), // FLOAT - typeof(float), // REAL - ]; + internal const int MaxIndexNameLength = 128; internal static readonly HashSet SupportedVectorTypes = [ typeof(ReadOnlyMemory), // VECTOR typeof(ReadOnlyMemory?) ]; + + public static readonly VectorStoreRecordModelBuildingOptions ModelBuildingOptions = new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + + SupportedKeyPropertyTypes = + [ + typeof(int), // INT + typeof(long), // BIGINT + typeof(string), // VARCHAR + typeof(Guid), // UNIQUEIDENTIFIER + typeof(DateTime), // DATETIME2 + typeof(byte[]) // VARBINARY + ], + + SupportedDataPropertyTypes = + [ + typeof(int), // INT + typeof(short), // SMALLINT + typeof(byte), // TINYINT + typeof(long), // BIGINT. + typeof(Guid), // UNIQUEIDENTIFIER. + typeof(string), // NVARCHAR + typeof(byte[]), // VARBINARY + typeof(bool), // BIT + typeof(DateTime), // DATETIME2 +#if NET + // We don't support mapping TimeSpan to TIME on purpose + // See https://github.com/microsoft/semantic-kernel/pull/10623#discussion_r1980350721 + typeof(TimeOnly), // TIME +#endif + typeof(decimal), // DECIMAL + typeof(double), // FLOAT + typeof(float), // REAL + ], + + SupportedEnumerableDataPropertyElementTypes = [], + + SupportedVectorPropertyTypes = SupportedVectorTypes + }; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs index 3bd3b2f97e0b..ec819362072c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq.Expressions; using System.Text; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.SqlServer; @@ -14,11 +15,11 @@ internal sealed class SqlServerFilterTranslator : SqlFilterTranslator private int _parameterIndex; internal SqlServerFilterTranslator( - IReadOnlyDictionary storagePropertyNames, + VectorStoreRecordModel model, LambdaExpression lambdaExpression, StringBuilder sql, int startParamIndex) - : base(storagePropertyNames, lambdaExpression, sql) + : base(model, lambdaExpression, sql) { this._parameterIndex = startParamIndex; } @@ -44,34 +45,29 @@ protected override void TranslateConstant(object? value) } } - protected override void TranslateColumn(string column, MemberExpression memberExpression, Expression? parent) + protected override void GenerateColumn(string column, bool isSearchCondition = false) { + this._sql.Append('[').Append(column).Append(']'); + // "SELECT * FROM MyTable WHERE BooleanColumn;" is not supported. // "SELECT * FROM MyTable WHERE BooleanColumn = 1;" is supported. - if (memberExpression.Type == typeof(bool) - && (parent is null // Where(x => x.Bool) - || parent is UnaryExpression { NodeType: ExpressionType.Not } // Where(x => !x.Bool) - || parent is BinaryExpression { NodeType: ExpressionType.AndAlso or ExpressionType.OrElse })) // Where(x => x.Bool && other) - { - this.TranslateBinary(Expression.Equal(memberExpression, Expression.Constant(true))); - } - else + if (isSearchCondition) { - this._sql.Append('[').Append(column).Append(']'); + this._sql.Append(" = 1"); } } - protected override void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent) + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item) => throw new NotSupportedException("Unsupported Contains expression"); - protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value) + protected override void TranslateContainsOverParameterizedArray(Expression source, Expression item, object? value) { if (value is not IEnumerable elements) { throw new NotSupportedException("Unsupported Contains expression"); } - this.Translate(item, parent); + this.Translate(item); this._sql.Append(" IN ("); var isFirst = true; @@ -92,17 +88,17 @@ protected override void TranslateContainsOverCapturedArray(Expression source, Ex this._sql.Append(')'); } - protected override void TranslateCapturedVariable(string name, object? capturedValue) + protected override void TranslateQueryParameter(string name, object? value) { // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) - if (capturedValue is null) + if (value is null) { this._sql.Append("NULL"); } else { - this._parameterValues.Add(capturedValue); + this._parameterValues.Add(value); // SQL Server parameters can't start with a digit (but underscore is OK). this._sql.Append("@_").Append(this._parameterIndex++); } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryBuilderExtensions.cs index dcaf6dd22734..4b9f34c1a1eb 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryBuilderExtensions.cs @@ -1,14 +1,16 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Connectors.SqlServer; +#pragma warning disable SKEXP0001 + /// /// Provides extension methods for the class to configure SQL Server or Azure SQL connector. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and SqlServerVectorStore")] public static class SqlServerMemoryBuilderExtensions { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryEntry.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryEntry.cs index e88c1b91e994..41778af5960f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryEntry.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryEntry.cs @@ -1,14 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Diagnostics.CodeAnalysis; namespace Microsoft.SemanticKernel.Connectors.SqlServer; /// /// A SQL Server or Azure SQL memory entry. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and SqlServerVectorStore")] internal record struct SqlServerMemoryEntry { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryStore.cs index d5891dc2e96a..ca5a760a4005 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerMemoryStore.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -11,10 +10,12 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; +#pragma warning disable SKEXP0001 + /// /// An implementation of backed by a SQL Server or Azure SQL database. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and SqlServerVectorStore")] public class SqlServerMemoryStore : IMemoryStore, IDisposable { internal const string DefaultSchema = "dbo"; diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs index d9481ffc467d..943bd3fd21dc 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs @@ -1,8 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; @@ -16,6 +18,12 @@ public sealed class SqlServerVectorStore : IVectorStore private readonly string _connectionString; private readonly SqlServerVectorStoreOptions _options; + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; + + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreRecordDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreRecordKeyProperty("Key", typeof(string))] }; + /// /// Initializes a new instance of the class. /// @@ -29,12 +37,22 @@ public SqlServerVectorStore(string connectionString, SqlServerVectorStoreOptions // We need to create a copy, so any changes made to the option bag after // the ctor call do not affect this instance. this._options = options is not null - ? new() { Schema = options.Schema } + ? new() { Schema = options.Schema, EmbeddingGenerator = options.EmbeddingGenerator } : SqlServerVectorStoreOptions.Defaults; + + var connectionStringBuilder = new SqlConnectionStringBuilder(connectionString); + + this._metadata = new() + { + VectorStoreSystemName = SqlServerConstants.VectorStoreSystemName, + VectorStoreName = connectionStringBuilder.InitialCatalog + }; } /// - public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + where TKey : notnull + where TRecord : notnull { Verify.NotNull(name); @@ -44,7 +62,8 @@ public IVectorStoreRecordCollection GetCollection( new() { Schema = this._options.Schema, - RecordDefinition = vectorStoreRecordDefinition + RecordDefinition = vectorStoreRecordDefinition, + EmbeddingGenerator = this._options.EmbeddingGenerator }); } @@ -54,13 +73,47 @@ public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancel using SqlConnection connection = new(this._connectionString); using SqlCommand command = SqlServerCommandBuilder.SelectTableNames(connection, this._options.Schema); - using SqlDataReader reader = await ExceptionWrapper.WrapAsync(connection, command, + using SqlDataReader reader = await ExceptionWrapper.WrapAsync( + connection, + command, static (cmd, ct) => cmd.ExecuteReaderAsync(ct), - cancellationToken, "ListCollection").ConfigureAwait(false); + operationName: "ListCollectionNames", + vectorStoreName: this._metadata.VectorStoreName, + cancellationToken: cancellationToken).ConfigureAwait(false); - while (await ExceptionWrapper.WrapReadAsync(reader, cancellationToken, "ListCollection").ConfigureAwait(false)) + while (await ExceptionWrapper.WrapReadAsync( + reader, + operationName: "ListCollectionNames", + vectorStoreName: this._metadata.VectorStoreName, + cancellationToken: cancellationToken).ConfigureAwait(false)) { yield return reader.GetString(reader.GetOrdinal("table_name")); } } + + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType.IsInstanceOfType(this) ? this : + null; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs index a90b474a3d5f..9fbbd69bc8b6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.AI; + namespace Microsoft.SemanticKernel.Connectors.SqlServer; /// @@ -13,4 +15,9 @@ public sealed class SqlServerVectorStoreOptions /// Gets or sets the database schema. /// public string? Schema { get; init; } = null; + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 9b4ce3b29078..fb2617621475 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -3,11 +3,15 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.Data.SqlClient; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; namespace Microsoft.SemanticKernel.Connectors.SqlServer; @@ -17,15 +21,20 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; #pragma warning disable CA1711 // Identifiers should not have incorrect suffix (Collection) public sealed class SqlServerVectorStoreRecordCollection #pragma warning restore CA1711 - : IVectorStoreRecordCollection where TKey : notnull + : IVectorStoreRecordCollection + where TKey : notnull + where TRecord : notnull { + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); private static readonly SqlServerVectorStoreRecordCollectionOptions s_defaultOptions = new(); private readonly string _connectionString; private readonly SqlServerVectorStoreRecordCollectionOptions _options; - private readonly VectorStoreRecordPropertyReader _propertyReader; - private readonly IVectorStoreRecordMapper> _mapper; + private readonly VectorStoreRecordModel _model; + private readonly RecordMapper _mapper; /// /// Initializes a new instance of the class. @@ -41,71 +50,52 @@ public SqlServerVectorStoreRecordCollection( Verify.NotNullOrWhiteSpace(connectionString); Verify.NotNull(name); - VectorStoreRecordPropertyReader propertyReader = new(typeof(TRecord), - options?.RecordDefinition, - new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true, - }); - - if (VectorStoreRecordPropertyVerification.IsGenericDataModel(typeof(TRecord))) - { - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.Mapper is not null, SqlServerConstants.SupportedKeyTypes); - } - else - { - propertyReader.VerifyKeyProperties(SqlServerConstants.SupportedKeyTypes); - } - propertyReader.VerifyDataProperties(SqlServerConstants.SupportedDataTypes, supportEnumerable: false); - propertyReader.VerifyVectorProperties(SqlServerConstants.SupportedVectorTypes); + this._model = new VectorStoreRecordModelBuilder(SqlServerConstants.ModelBuildingOptions) + .Build(typeof(TRecord), options?.RecordDefinition, options?.EmbeddingGenerator); this._connectionString = connectionString; - this.CollectionName = name; + this.Name = name; // We need to create a copy, so any changes made to the option bag after // the ctor call do not affect this instance. - this._options = options is null ? s_defaultOptions + this._options = options is null + ? s_defaultOptions : new() { Schema = options.Schema, - Mapper = options.Mapper, RecordDefinition = options.RecordDefinition, }; - this._propertyReader = propertyReader; + this._mapper = new RecordMapper(this._model); - if (options is not null && options.Mapper is not null) - { - this._mapper = options.Mapper; - } - else if (typeof(TRecord).IsGenericType && typeof(TRecord).GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>)) - { - this._mapper = (new GenericRecordMapper(propertyReader) as IVectorStoreRecordMapper>)!; - } - else - { - propertyReader.VerifyHasParameterlessConstructor(); + var connectionStringBuilder = new SqlConnectionStringBuilder(connectionString); - this._mapper = new RecordMapper(propertyReader); - } + this._collectionMetadata = new() + { + VectorStoreSystemName = SqlServerConstants.VectorStoreSystemName, + VectorStoreName = connectionStringBuilder.InitialCatalog, + CollectionName = name + }; } /// - public string CollectionName { get; } + public string Name { get; } /// public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) { using SqlConnection connection = new(this._connectionString); using SqlCommand command = SqlServerCommandBuilder.SelectTableName( - connection, this._options.Schema, this.CollectionName); + connection, this._options.Schema, this.Name); return await ExceptionWrapper.WrapAsync(connection, command, static async (cmd, ct) => { using SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); return await reader.ReadAsync(ct).ConfigureAwait(false); - }, cancellationToken, "CollectionExists", this.CollectionName).ConfigureAwait(false); + }, + "CollectionExists", + this._collectionMetadata.VectorStoreName, + this.Name, + cancellationToken).ConfigureAwait(false); } /// @@ -118,27 +108,20 @@ public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken private async Task CreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken) { - foreach (var vectorProperty in this._propertyReader.VectorProperties) - { - if (vectorProperty.Dimensions is not > 0) - { - throw new InvalidOperationException($"Property {nameof(vectorProperty.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' must be set to a positive integer to create a collection."); - } - } - using SqlConnection connection = new(this._connectionString); using SqlCommand command = SqlServerCommandBuilder.CreateTable( connection, this._options.Schema, - this.CollectionName, + this.Name, ifNotExists, - this._propertyReader.KeyProperty, - this._propertyReader.DataProperties, - this._propertyReader.VectorProperties); + this._model); await ExceptionWrapper.WrapAsync(connection, command, static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), - cancellationToken, "CreateCollection", this.CollectionName).ConfigureAwait(false); + "CreateCollection", + this._collectionMetadata.VectorStoreName, + this.Name, + cancellationToken).ConfigureAwait(false); } /// @@ -146,11 +129,14 @@ public async Task DeleteCollectionAsync(CancellationToken cancellationToken = de { using SqlConnection connection = new(this._connectionString); using SqlCommand command = SqlServerCommandBuilder.DropTableIfExists( - connection, this._options.Schema, this.CollectionName); + connection, this._options.Schema, this.Name); await ExceptionWrapper.WrapAsync(connection, command, static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), - cancellationToken, "DeleteCollection", this.CollectionName).ConfigureAwait(false); + "DeleteCollection", + this._collectionMetadata.VectorStoreName, + this.Name, + cancellationToken).ConfigureAwait(false); } /// @@ -162,17 +148,20 @@ public async Task DeleteAsync(TKey key, CancellationToken cancellationToken = de using SqlCommand command = SqlServerCommandBuilder.DeleteSingle( connection, this._options.Schema, - this.CollectionName, - this._propertyReader.KeyProperty, + this.Name, + this._model.KeyProperty, key); await ExceptionWrapper.WrapAsync(connection, command, static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), - cancellationToken, "Delete", this.CollectionName).ConfigureAwait(false); + "Delete", + this._collectionMetadata.VectorStoreName, + this.Name, + cancellationToken).ConfigureAwait(false); } /// - public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public async Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { Verify.NotNull(keys); @@ -196,8 +185,8 @@ public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken can if (!SqlServerCommandBuilder.DeleteMany( command, this._options.Schema, - this.CollectionName, - this._propertyReader.KeyProperty, + this.Name, + this._model.KeyProperty, keys.Skip(taken).Take(SqlServerConstants.MaxParameterCount))) { break; // keys is empty, there is nothing to delete @@ -231,9 +220,10 @@ public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken can throw new VectorStoreOperationException(ex.Message, ex) { - OperationName = "DeleteBatch", - VectorStoreType = ExceptionWrapper.VectorStoreType, - CollectionName = this.CollectionName + VectorStoreSystemName = SqlServerConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, + OperationName = "DeleteBatch" }; } } @@ -245,13 +235,17 @@ public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken can bool includeVectors = options?.IncludeVectors is true; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + using SqlConnection connection = new(this._connectionString); using SqlCommand command = SqlServerCommandBuilder.SelectSingle( connection, this._options.Schema, - this.CollectionName, - this._propertyReader.KeyProperty, - this._propertyReader.Properties, + this.Name, + this._model, key, includeVectors); @@ -261,23 +255,32 @@ static async (cmd, ct) => SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); await reader.ReadAsync(ct).ConfigureAwait(false); return reader; - }, cancellationToken, "Get", this.CollectionName).ConfigureAwait(false); + }, + "Get", + this._collectionMetadata.VectorStoreName, + this.Name, + cancellationToken).ConfigureAwait(false); return reader.HasRows ? this._mapper.MapFromStorageToDataModel( - new SqlDataReaderDictionary(reader, this._propertyReader.VectorPropertyStoragePropertyNames), + new SqlDataReaderDictionary(reader, this._model.VectorProperties), new() { IncludeVectors = includeVectors }) : default; } /// - public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, + public async IAsyncEnumerable GetAsync(IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Verify.NotNull(keys); bool includeVectors = options?.IncludeVectors is true; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + using SqlConnection connection = new(this._connectionString); using SqlCommand command = connection.CreateCommand(); int taken = 0; @@ -292,9 +295,8 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get if (!SqlServerCommandBuilder.SelectMany( command, this._options.Schema, - this.CollectionName, - this._propertyReader.KeyProperty, - this._propertyReader.Properties, + this.Name, + this._model, keys.Skip(taken).Take(SqlServerConstants.MaxParameterCount), includeVectors)) { @@ -308,12 +310,20 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get using SqlDataReader reader = await ExceptionWrapper.WrapAsync(connection, command, static (cmd, ct) => cmd.ExecuteReaderAsync(ct), - cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false); - - while (await ExceptionWrapper.WrapReadAsync(reader, cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false)) + "GetBatch", + this._collectionMetadata.VectorStoreName, + this.Name, + cancellationToken).ConfigureAwait(false); + + while (await ExceptionWrapper.WrapReadAsync( + reader, + "GetBatch", + this._collectionMetadata.VectorStoreName, + this.Name, + cancellationToken).ConfigureAwait(false)) { yield return this._mapper.MapFromStorageToDataModel( - new SqlDataReaderDictionary(reader, this._propertyReader.VectorPropertyStoragePropertyNames), + new SqlDataReaderDictionary(reader, this._model.VectorProperties), new() { IncludeVectors = includeVectors }); } } while (command.Parameters.Count == SqlServerConstants.MaxParameterCount); @@ -324,14 +334,39 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancellati { Verify.NotNull(record); + IReadOnlyList?[]? generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; + + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbedding, ReadOnlyMemory>(record, cancellationToken, out var floatTask)) + { + generatedEmbeddings ??= new IReadOnlyList?[vectorPropertyCount]; + generatedEmbeddings[i] = [await floatTask.ConfigureAwait(false)]; + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + using SqlConnection connection = new(this._connectionString); using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle( connection, this._options.Schema, - this.CollectionName, - this._propertyReader.KeyProperty, - this._propertyReader.Properties, - this._mapper.MapFromDataToStorageModel(record)); + this.Name, + this._model, + this._mapper.MapFromDataToStorageModel(record, recordIndex: 0, generatedEmbeddings)); return await ExceptionWrapper.WrapAsync(connection, command, async static (cmd, ct) => @@ -339,20 +374,66 @@ async static (cmd, ct) => using SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); await reader.ReadAsync(ct).ConfigureAwait(false); return reader.GetFieldValue(0); - }, cancellationToken, "Upsert", this.CollectionName).ConfigureAwait(false); + }, + "Upsert", + this._collectionMetadata.VectorStoreName, + this.Name, + cancellationToken).ConfigureAwait(false); } /// - public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, - [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { Verify.NotNull(records); + IReadOnlyList? recordsList = null; + + // If an embedding generator is defined, invoke it once per property for all records. + IReadOnlyList?[]? generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; + + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + + // We have a property with embedding generation; materialize the records' enumerable if needed, to + // prevent multiple enumeration. + if (recordsList is null) + { + recordsList = records is IReadOnlyList r ? r : records.ToList(); + + if (recordsList.Count == 0) + { + return []; + } + + records = recordsList; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbeddings, ReadOnlyMemory>(records, cancellationToken, out var floatTask)) + { + generatedEmbeddings ??= new IReadOnlyList?[vectorPropertyCount]; + generatedEmbeddings[i] = (IReadOnlyList>)await floatTask.ConfigureAwait(false); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + using SqlConnection connection = new(this._connectionString); await connection.OpenAsync(cancellationToken).ConfigureAwait(false); using SqlTransaction transaction = connection.BeginTransaction(); - int parametersPerRecord = this._propertyReader.Properties.Count; + int parametersPerRecord = this._model.Properties.Count; int taken = 0; try @@ -369,12 +450,11 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record if (!SqlServerCommandBuilder.MergeIntoMany( command, this._options.Schema, - this.CollectionName, - this._propertyReader.KeyProperty, - this._propertyReader.Properties, + this.Name, + this._model, records.Skip(taken) .Take(SqlServerConstants.MaxParameterCount / parametersPerRecord) - .Select(this._mapper.MapFromDataToStorageModel))) + .Select((r, i) => this._mapper.MapFromDataToStorageModel(r, taken + i, generatedEmbeddings)))) { break; // records is empty } @@ -407,33 +487,79 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record throw new VectorStoreOperationException(ex.Message, ex) { - OperationName = "UpsertBatch", - VectorStoreType = ExceptionWrapper.VectorStoreType, - CollectionName = this.CollectionName + VectorStoreSystemName = SqlServerConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, + OperationName = "UpsertBatch" }; } - if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) - { - foreach (var record in records) - { - yield return ((VectorStoreGenericDataModel)(object)record!).Key; - } - } - else + var keyProperty = this._model.KeyProperty; + + return records.Select(r => (TKey)keyProperty.GetValueAsObject(r)!).ToList(); + } + + #region Search + + /// + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + switch (vectorProperty.EmbeddingGenerator) { - var keyProperty = this._propertyReader.KeyPropertyInfo; - foreach (var record in records) - { - yield return (TKey)keyProperty.GetValue(record)!; - } + case IEmbeddingGenerator> generator: + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + + default: + throw new InvalidOperationException( + SqlServerConstants.SupportedVectorTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); } } - /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + /// + public IAsyncEnumerable> SearchEmbeddingAsync( + TVector vector, + int top, + VectorSearchOptions? options = null, + CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + VectorSearchOptions options, + CancellationToken cancellationToken = default) + where TVector : notnull { Verify.NotNull(vector); + Verify.NotLessThan(top, 1); if (vector is not ReadOnlyMemory allowed) { @@ -442,36 +568,53 @@ public async Task> VectorizedSearchAsync(T $"Supported types are: {string.Join(", ", SqlServerConstants.SupportedVectorTypes.Select(l => l.FullName))}"); } #pragma warning disable CS0618 // Type or member is obsolete - else if (options is not null && options.OldFilter is not null) + else if (options.OldFilter is not null) #pragma warning restore CS0618 // Type or member is obsolete { throw new NotSupportedException("The obsolete Filter is not supported by the SQL Server connector, use NewFilter instead."); } - var searchOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(searchOptions); + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } #pragma warning disable CA2000 // Dispose objects before losing scope - // This connection will be disposed by the ReadVectorSearchResultsAsync + // Connection and command are going to be disposed by the ReadVectorSearchResultsAsync, // when the user is done with the results. SqlConnection connection = new(this._connectionString); -#pragma warning restore CA2000 // Dispose objects before losing scope - using SqlCommand command = SqlServerCommandBuilder.SelectVector( + SqlCommand command = SqlServerCommandBuilder.SelectVector( connection, this._options.Schema, - this.CollectionName, + this.Name, vectorProperty, - this._propertyReader.Properties, - this._propertyReader.StoragePropertyNamesMap, - searchOptions, + this._model, + top, + options, allowed); +#pragma warning restore CA2000 // Dispose objects before losing scope - return await ExceptionWrapper.WrapAsync(connection, command, - (cmd, ct) => - { - var results = this.ReadVectorSearchResultsAsync(connection, cmd, searchOptions.IncludeVectors, ct); - return Task.FromResult(new VectorSearchResults(results)); - }, cancellationToken, "VectorizedSearch", this.CollectionName).ConfigureAwait(false); + return this.ReadVectorSearchResultsAsync(connection, command, options.IncludeVectors, cancellationToken); + } + + /// + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); + + #endregion Search + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType.IsInstanceOfType(this) ? this : + null; } private async IAsyncEnumerable> ReadVectorSearchResultsAsync( @@ -483,11 +626,22 @@ private async IAsyncEnumerable> ReadVectorSearchResu try { StorageToDataModelMapperOptions options = new() { IncludeVectors = includeVectors }; - var vectorPropertyStoragePropertyNames = includeVectors ? this._propertyReader.VectorPropertyStoragePropertyNames : []; - using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + var vectorProperties = includeVectors ? this._model.VectorProperties : []; + + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(connection, command, + static (cmd, ct) => cmd.ExecuteReaderAsync(ct), + "VectorizedSearch", + this._collectionMetadata.VectorStoreName, + this.Name, + cancellationToken).ConfigureAwait(false); int scoreIndex = -1; - while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + while (await ExceptionWrapper.WrapReadAsync( + reader, + "VectorizedSearch", + this._collectionMetadata.VectorStoreName, + this.Name, + cancellationToken).ConfigureAwait(false)) { if (scoreIndex < 0) { @@ -495,13 +649,45 @@ private async IAsyncEnumerable> ReadVectorSearchResu } yield return new VectorSearchResult( - this._mapper.MapFromStorageToDataModel(new SqlDataReaderDictionary(reader, vectorPropertyStoragePropertyNames), options), + this._mapper.MapFromStorageToDataModel(new SqlDataReaderDictionary(reader, vectorProperties), options), reader.GetDouble(scoreIndex)); } } finally { + command.Dispose(); connection.Dispose(); } } + + /// + public async IAsyncEnumerable GetAsync(Expression> filter, int top, + GetFilteredRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); + + options ??= new(); + + using SqlConnection connection = new(this._connectionString); + using SqlCommand command = SqlServerCommandBuilder.SelectWhere( + filter, + top, + options, + connection, + this._options.Schema, + this.Name, + this._model); + + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(connection, command, + static (cmd, ct) => cmd.ExecuteReaderAsync(ct), + "GetAsync", this._collectionMetadata.VectorStoreName, this.Name, cancellationToken).ConfigureAwait(false); + + var vectorProperties = options.IncludeVectors ? this._model.VectorProperties : []; + StorageToDataModelMapperOptions mapperOptions = new() { IncludeVectors = options.IncludeVectors }; + while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return this._mapper.MapFromStorageToDataModel(new SqlDataReaderDictionary(reader, vectorProperties), mapperOptions); + } + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs index 6b21a5e35842..baec6bd86e7e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.SqlServer; @@ -21,6 +23,7 @@ public sealed class SqlServerVectorStoreRecordCollectionOptions /// /// If not set, the default mapper will be used. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper>? Mapper { get; init; } /// @@ -32,4 +35,9 @@ public sealed class SqlServerVectorStoreRecordCollectionOptions /// See , and . /// public VectorStoreRecordDefinition? RecordDefinition { get; init; } + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Conditions/SqliteWhereCondition.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Conditions/SqliteWhereCondition.cs index ea3f702a42b8..a31b7c5a5050 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Conditions/SqliteWhereCondition.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Conditions/SqliteWhereCondition.cs @@ -15,6 +15,6 @@ internal abstract class SqliteWhereCondition(string operand, List values public abstract string BuildQuery(List parameterNames); protected string GetOperand() => !string.IsNullOrWhiteSpace(this.TableName) ? - $"{this.TableName}.{this.Operand}" : - this.Operand; + $"\"{this.TableName}\".\"{this.Operand}\"" : + $"\"{this.Operand}\""; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj index fec218bfc49d..3e35c126dfc8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj @@ -4,13 +4,14 @@ Microsoft.SemanticKernel.Connectors.Sqlite $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 preview + @@ -23,11 +24,14 @@ + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Database.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Database.cs index c1e32e16a30d..9c3117cecd06 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Database.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Database.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Threading; @@ -10,7 +10,7 @@ namespace Microsoft.SemanticKernel.Connectors.Sqlite; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and SqliteVectorStore")] internal struct DatabaseEntry { public string Key { get; set; } @@ -22,7 +22,7 @@ internal struct DatabaseEntry public string? Timestamp { get; set; } } -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and SqliteVectorStore")] internal sealed class Database { private const string TableName = "SKMemoryTable"; diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs index 6310489ac118..b8fa2ba0cc53 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Data.Common; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.Sqlite; @@ -17,13 +16,14 @@ public interface ISqliteVectorStoreRecordCollectionFactory /// /// The data type of the record key. /// The data model to use for adding, updating and retrieving data from storage. - /// that will be used to manage the data in SQLite. + /// The connection string for the SQLite database represented by this . /// The name of the collection to connect to. /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. /// The new instance of . IVectorStoreRecordCollection CreateVectorStoreRecordCollection( - DbConnection connection, + string connectionString, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) - where TKey : notnull; + where TKey : notnull + where TRecord : notnull; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteColumn.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteColumn.cs index ae551cf65b2b..df9122f9c63d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteColumn.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteColumn.cs @@ -18,5 +18,7 @@ internal sealed class SqliteColumn( public bool IsPrimary { get; set; } = isPrimary; + public bool HasIndex { get; set; } + public Dictionary? Configuration { get; set; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteConstants.cs index 44e0c7a63026..9f94432d28e3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteConstants.cs @@ -2,53 +2,57 @@ using System; using System.Collections.Generic; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Sqlite; internal static class SqliteConstants { + internal const string VectorStoreSystemName = "sqlite"; + /// /// SQLite extension name for vector search. /// More information here: . /// public const string VectorSearchExtensionName = "vec0"; - /// A of types that a key on the provided model may have. - public static readonly HashSet SupportedKeyTypes = - [ - typeof(ulong), - typeof(string) - ]; - - /// A of types that data properties on the provided model may have. - public static readonly HashSet SupportedDataTypes = - [ - typeof(int), - typeof(int?), - typeof(long), - typeof(long?), - typeof(ulong), - typeof(ulong?), - typeof(short), - typeof(short?), - typeof(ushort), - typeof(ushort?), - typeof(string), - typeof(bool), - typeof(bool?), - typeof(float), - typeof(float?), - typeof(double), - typeof(double?), - typeof(decimal), - typeof(decimal?), - typeof(byte[]), - ]; - /// A of types that vector properties on the provided model may have. public static readonly HashSet SupportedVectorTypes = [ typeof(ReadOnlyMemory), typeof(ReadOnlyMemory?) ]; + + public static readonly VectorStoreRecordModelBuildingOptions ModelBuildingOptions = new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + + SupportedKeyPropertyTypes = + [ + typeof(ulong), + typeof(string) + ], + + SupportedDataPropertyTypes = + [ + typeof(int), + typeof(long), + typeof(ulong), + typeof(short), + typeof(ushort), + typeof(string), + typeof(bool), + typeof(float), + typeof(double), + typeof(decimal), + typeof(byte[]) + ], + + SupportedEnumerableDataPropertyElementTypes = [], + SupportedVectorPropertyTypes = SupportedVectorTypes, + + EscapeIdentifier = SqliteVectorStoreCollectionCommandBuilder.EscapeIdentifier + }; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs index 963c1184d274..a602b0542373 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs @@ -4,6 +4,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq.Expressions; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Sqlite; @@ -11,25 +12,25 @@ internal sealed class SqliteFilterTranslator : SqlFilterTranslator { private readonly Dictionary _parameters = new(); - internal SqliteFilterTranslator(IReadOnlyDictionary storagePropertyNames, - LambdaExpression lambdaExpression) : base(storagePropertyNames, lambdaExpression, sql: null) + internal SqliteFilterTranslator(VectorStoreRecordModel model, LambdaExpression lambdaExpression) + : base(model, lambdaExpression, sql: null) { } internal Dictionary Parameters => this._parameters; // TODO: support Contains over array fields (#10343) - protected override void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent) + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item) => throw new NotSupportedException("Unsupported Contains expression"); - protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value) + protected override void TranslateContainsOverParameterizedArray(Expression source, Expression item, object? value) { if (value is not IEnumerable elements) { throw new NotSupportedException("Unsupported Contains expression"); } - this.Translate(item, parent); + this.Translate(item); this._sql.Append(" IN ("); var isFirst = true; @@ -50,11 +51,11 @@ protected override void TranslateContainsOverCapturedArray(Expression source, Ex this._sql.Append(')'); } - protected override void TranslateCapturedVariable(string name, object? capturedValue) + protected override void TranslateQueryParameter(string name, object? value) { // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) - if (capturedValue is null) + if (value is null) { this._sql.Append("NULL"); } @@ -72,7 +73,7 @@ protected override void TranslateCapturedVariable(string name, object? capturedV } while (this._parameters.ContainsKey(name)); } - this._parameters.Add(name, capturedValue); + this._parameters.Add(name, value); this._sql.Append('@').Append(name); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteGenericDataModelMapper.cs deleted file mode 100644 index f6b59b2c926b..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteGenericDataModelMapper.cs +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using Microsoft.Extensions.VectorData; - -namespace Microsoft.SemanticKernel.Connectors.Sqlite; - -/// -/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within SQLite. -/// -internal sealed class SqliteGenericDataModelMapper : - IVectorStoreRecordMapper, Dictionary>, - IVectorStoreRecordMapper, Dictionary> -{ - /// with helpers for reading vector store model properties and their attributes. - private readonly VectorStoreRecordPropertyReader _propertyReader; - - /// - /// Initializes a new instance of the class. - /// - /// A that defines the schema of the data in the database. - public SqliteGenericDataModelMapper(VectorStoreRecordPropertyReader propertyReader) - { - Verify.NotNull(propertyReader); - - this._propertyReader = propertyReader; - - // Validate property types. - this._propertyReader.VerifyDataProperties(SqliteConstants.SupportedDataTypes, supportEnumerable: false); - this._propertyReader.VerifyVectorProperties(SqliteConstants.SupportedVectorTypes); - } - - #region Implementation of IVectorStoreRecordMapper, Dictionary> - - public Dictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - return this.InternalMapFromDataToStorageModel(dataModel); - } - - public VectorStoreGenericDataModel MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) - { - return this.InternalMapFromStorageToDataModel(storageModel, options); - } - - #endregion - - #region Implementation of IVectorStoreRecordMapper, Dictionary> - - public Dictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - return this.InternalMapFromDataToStorageModel(dataModel); - } - - VectorStoreGenericDataModel IVectorStoreRecordMapper, Dictionary>.MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) - { - return this.InternalMapFromStorageToDataModel(storageModel, options); - } - - #endregion - - #region private - - private Dictionary InternalMapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - where TKey : notnull - { - var properties = new Dictionary - { - // Add key property - { this._propertyReader.KeyPropertyStoragePropertyName, dataModel.Key } - }; - - // Add data properties - if (dataModel.Data is not null) - { - foreach (var property in this._propertyReader.DataProperties) - { - if (dataModel.Data.TryGetValue(property.DataModelPropertyName, out var dataValue)) - { - properties.Add(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), dataValue); - } - } - } - - // Add vector properties - if (dataModel.Vectors is not null) - { - foreach (var property in this._propertyReader.VectorProperties) - { - if (dataModel.Vectors.TryGetValue(property.DataModelPropertyName, out var vectorValue)) - { - object? result = null; - - if (vectorValue is not null) - { - var vector = (ReadOnlyMemory)vectorValue; - result = SqliteVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); - } - - properties.Add(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), result); - } - } - } - - return properties; - } - - private VectorStoreGenericDataModel InternalMapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) - where TKey : notnull - { - TKey key; - var dataProperties = new Dictionary(); - var vectorProperties = new Dictionary(); - - // Process key property. - if (storageModel.TryGetValue(this._propertyReader.KeyPropertyStoragePropertyName, out var keyObject) && keyObject is not null) - { - key = (TKey)keyObject; - } - else - { - throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); - } - - // Process data properties. - foreach (var property in this._propertyReader.DataProperties) - { - if (storageModel.TryGetValue(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), out var dataValue)) - { - dataProperties.Add(property.DataModelPropertyName, dataValue); - } - } - - // Process vector properties - if (options.IncludeVectors) - { - foreach (var property in this._propertyReader.VectorProperties) - { - if (storageModel.TryGetValue(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), out var vectorValue) && - vectorValue is byte[] vectorBytes) - { - var vector = SqliteVectorStoreRecordPropertyMapping.MapVectorForDataModel(vectorBytes); - vectorProperties.Add(property.DataModelPropertyName, vector); - } - } - } - - return new VectorStoreGenericDataModel(key) { Data = dataProperties, Vectors = vectorProperties }; - } - - #endregion -} diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteMemoryStore.cs index b8908818a049..7107020c0ddd 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteMemoryStore.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; using System.Numerics.Tensors; @@ -16,13 +15,15 @@ namespace Microsoft.SemanticKernel.Connectors.Sqlite; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// An implementation of backed by a SQLite database. /// /// The data is saved to a database file, specified in the constructor. /// The data persists between subsequent instances. Only one instance may access the file at a time. /// The caller is responsible for deleting the file. -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and SqliteVectorStore")] public class SqliteMemoryStore : IMemoryStore, IDisposable { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs index 9c962c0786d5..b3db0e1bc4b6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs @@ -1,7 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Data; +using System; using Microsoft.Data.Sqlite; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Sqlite; @@ -22,29 +23,12 @@ public static class SqliteServiceCollectionExtensions /// Optional options to further configure the . /// An optional service id to use as the service key. /// Service collection. + [Obsolete("Use AddSqliteVectorStore with connectionString instead.", error: true)] public static IServiceCollection AddSqliteVectorStore( this IServiceCollection services, SqliteVectorStoreOptions? options = default, string? serviceId = default) - { - services.AddKeyedTransient( - serviceId, - (sp, obj) => - { - var connection = sp.GetRequiredService(); - - if (connection.State != ConnectionState.Open) - { - connection.Open(); - } - - var selectedOptions = options ?? sp.GetService(); - - return new SqliteVectorStore(connection, options); - }); - - return services; - } + => throw new InvalidOperationException("Use AddSqliteVectorStore with connectionString instead."); /// /// Register a SQLite with the specified service ID. @@ -60,27 +44,12 @@ public static IServiceCollection AddSqliteVectorStore( string connectionString, SqliteVectorStoreOptions? options = default, string? serviceId = default) - { - services.AddKeyedTransient( + => services.AddKeyedSingleton( serviceId, - (sp, obj) => - { - var connection = new SqliteConnection(connectionString); - var extensionName = GetExtensionName(options?.VectorSearchExtensionName); - - connection.Open(); - - connection.LoadExtension(extensionName); - - var selectedOptions = options ?? sp.GetService(); - return new SqliteVectorStore(connection, options); - }); - - return services; - } + (sp, _) => new SqliteVectorStore(connectionString, options ?? sp.GetService() ?? new() { EmbeddingGenerator = sp.GetService() })); /// - /// Register a SQLite and with the specified service ID + /// Register a SQLite and with the specified service ID /// and where the SQLite is retrieved from the dependency injection container. /// In this case vector search extension loading should be handled manually. /// @@ -91,36 +60,17 @@ public static IServiceCollection AddSqliteVectorStore( /// Optional options to further configure the . /// An optional service id to use as the service key. /// Service collection. + [Obsolete("Use AddSqliteVectorStoreRecordCollection with connectionString instead.", error: true)] public static IServiceCollection AddSqliteVectorStoreRecordCollection( this IServiceCollection services, string collectionName, SqliteVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) where TKey : notnull - { - services.AddKeyedTransient>( - serviceId, - (sp, obj) => - { - var connection = sp.GetRequiredService(); - - if (connection.State != ConnectionState.Open) - { - connection.Open(); - } - - var selectedOptions = options ?? sp.GetService>(); - - return (new SqliteVectorStoreRecordCollection(connection, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; - }); - - AddVectorizedSearch(services, serviceId); - - return services; - } + => throw new InvalidOperationException("Use AddSqliteVectorStore with connectionString instead."); /// - /// Register a SQLite and with the specified service ID. + /// Register a SQLite and with the specified service ID. /// instance will be initialized, connection will be opened and vector search extension with be loaded. /// /// The type of the key. @@ -138,22 +88,19 @@ public static IServiceCollection AddSqliteVectorStoreRecordCollection? options = default, string? serviceId = default) where TKey : notnull + where TRecord : notnull { - services.AddKeyedTransient>( + services.AddKeyedSingleton>( serviceId, - (sp, obj) => - { - var connection = new SqliteConnection(connectionString); - var extensionName = GetExtensionName(options?.VectorSearchExtensionName); - - connection.Open(); - - connection.LoadExtension(extensionName); - - var selectedOptions = options ?? sp.GetService>(); - - return (new SqliteVectorStoreRecordCollection(connection, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; - }); + (sp, _) => ( + new SqliteVectorStoreRecordCollection( + connectionString, + collectionName, + options ?? sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }) + as IVectorStoreRecordCollection)!); AddVectorizedSearch(services, serviceId); @@ -161,28 +108,15 @@ public static IServiceCollection AddSqliteVectorStoreRecordCollection - /// Also register the with the given as a . + /// Also register the with the given as a . /// /// The type of the key. /// The type of the data model that the collection should contain. /// The service collection to register on. /// The service id that the registrations should use. - private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) where TRecord : notnull where TKey : notnull - { - services.AddKeyedTransient>( + => services.AddKeyedSingleton>( serviceId, - (sp, obj) => - { - return sp.GetRequiredKeyedService>(serviceId); - }); - } - - /// - /// Returns extension name for vector search. - /// - private static string GetExtensionName(string? extensionName) - { - return !string.IsNullOrWhiteSpace(extensionName) ? extensionName! : SqliteConstants.VectorSearchExtensionName; - } + (sp, _) => sp.GetRequiredKeyedService>(serviceId)); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs index 43b1a29b52d2..c820d7baf68c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs @@ -5,6 +5,7 @@ using System.Data.Common; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.Data.Sqlite; using Microsoft.Extensions.VectorData; @@ -16,68 +17,90 @@ namespace Microsoft.SemanticKernel.Connectors.Sqlite; /// /// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. /// -public class SqliteVectorStore : IVectorStore +public sealed class SqliteVectorStore : IVectorStore { - /// that will be used to manage the data in SQLite. - private readonly DbConnection _connection; + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; + + /// The connection string for the SQLite database represented by this . + private readonly string _connectionString; /// Optional configuration options for this class. private readonly SqliteVectorStoreOptions _options; + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreRecordDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreRecordKeyProperty("Key", typeof(string))] }; + /// /// Initializes a new instance of the class. /// - /// that will be used to manage the data in SQLite. + /// The connection string for the SQLite database represented by this . /// Optional configuration options for this class. - public SqliteVectorStore( - DbConnection connection, - SqliteVectorStoreOptions? options = default) + public SqliteVectorStore(string connectionString, SqliteVectorStoreOptions? options = default) { - Verify.NotNull(connection); + Verify.NotNull(connectionString); - this._connection = connection; + this._connectionString = connectionString; this._options = options ?? new(); + + var connectionStringBuilder = new SqliteConnectionStringBuilder(connectionString); + + this._metadata = new() + { + VectorStoreSystemName = SqliteConstants.VectorStoreSystemName, + VectorStoreName = connectionStringBuilder.DataSource + }; } + /// + /// Initializes a new instance of the class. + /// + /// that will be used to manage the data in SQLite. + /// Optional configuration options for this class. + [Obsolete("Use the constructor that accepts a connection string instead.", error: true)] + public SqliteVectorStore( + DbConnection connection, + SqliteVectorStoreOptions? options = default) + => throw new InvalidOperationException("Use the constructor that accepts a connection string instead."); + /// - public virtual IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + where TRecord : notnull { #pragma warning disable CS0618 // ISqliteVectorStoreRecordCollectionFactory is obsolete if (this._options.VectorStoreCollectionFactory is not null) { return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection( - this._connection, + this._connectionString, name, vectorStoreRecordDefinition); } #pragma warning restore CS0618 - if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(ulong)) - { - throw new NotSupportedException($"Only {nameof(String)} and {nameof(UInt64)} keys are supported."); - } - - var recordCollection = new SqliteVectorStoreRecordCollection( - this._connection, + var recordCollection = new SqliteVectorStoreRecordCollection( + this._connectionString, name, new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition, VectorSearchExtensionName = this._options.VectorSearchExtensionName, - VectorVirtualTableName = this._options.VectorVirtualTableName + VectorVirtualTableName = this._options.VectorVirtualTableName, + EmbeddingGenerator = this._options.EmbeddingGenerator }) as IVectorStoreRecordCollection; return recordCollection!; } /// - public virtual async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { const string TablePropertyName = "name"; const string Query = $"SELECT {TablePropertyName} FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"; - using var command = this._connection.CreateCommand(); + using var connection = new SqliteConnection(this._connectionString); + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + using var command = connection.CreateCommand(); command.CommandText = Query; @@ -89,4 +112,30 @@ public virtual async IAsyncEnumerable ListCollectionNamesAsync([Enumerat yield return reader.GetString(ordinal); } } + + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType.IsInstanceOfType(this) ? this : + null; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs index 701c61bb3236..8844e9494357 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs @@ -8,6 +8,8 @@ using System.Linq; using System.Text; using Microsoft.Data.Sqlite; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Sqlite; @@ -15,23 +17,13 @@ namespace Microsoft.SemanticKernel.Connectors.Sqlite; /// Command builder for queries in SQLite database. /// [SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "User input is passed using command parameters.")] -internal sealed class SqliteVectorStoreCollectionCommandBuilder +internal static class SqliteVectorStoreCollectionCommandBuilder { - /// that will be used to manage the data in SQLite. - private readonly DbConnection _connection; - - /// - /// Initializes a new instance of the class. - /// - /// that will be used to manage the data in SQLite. - public SqliteVectorStoreCollectionCommandBuilder(DbConnection connection) - { - Verify.NotNull(connection); + internal const string DistancePropertyName = "distance"; - this._connection = connection; - } + internal static string EscapeIdentifier(this string value) => value.Replace("'", "''").Replace("\"", "\"\""); - public DbCommand BuildTableCountCommand(string tableName) + public static DbCommand BuildTableCountCommand(SqliteConnection connection, string tableName) { Verify.NotNullOrWhiteSpace(tableName); @@ -40,7 +32,7 @@ public DbCommand BuildTableCountCommand(string tableName) var query = $"SELECT count(*) FROM {SystemTable} WHERE type='table' AND name={ParameterName};"; - var command = this._connection.CreateCommand(); + var command = connection.CreateCommand(); command.CommandText = query; @@ -49,23 +41,32 @@ public DbCommand BuildTableCountCommand(string tableName) return command; } - public DbCommand BuildCreateTableCommand(string tableName, IReadOnlyList columns, bool ifNotExists) + public static DbCommand BuildCreateTableCommand(SqliteConnection connection, string tableName, IReadOnlyList columns, bool ifNotExists) { var builder = new StringBuilder(); - builder.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}{tableName} ("); + builder.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}\"{tableName}\" ("); - builder.AppendLine(string.Join(",\n", columns.Select(GetColumnDefinition))); - builder.Append(");"); + builder.AppendLine(string.Join(",\n", columns.Select(column => GetColumnDefinition(column, quote: true)))); + builder.AppendLine(");"); + + foreach (var column in columns) + { + if (column.HasIndex) + { + builder.AppendLine($"CREATE INDEX {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}\"{tableName}_{column.Name}_index\" ON \"{tableName}\"(\"{column.Name}\");"); + } + } - var command = this._connection.CreateCommand(); + var command = connection.CreateCommand(); command.CommandText = builder.ToString(); return command; } - public DbCommand BuildCreateVirtualTableCommand( + public static DbCommand BuildCreateVirtualTableCommand( + SqliteConnection connection, string tableName, IReadOnlyList columns, bool ifNotExists, @@ -73,38 +74,41 @@ public DbCommand BuildCreateVirtualTableCommand( { var builder = new StringBuilder(); - builder.AppendLine($"CREATE VIRTUAL TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}{tableName} USING {extensionName}("); + builder.AppendLine($"CREATE VIRTUAL TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}\"{tableName}\" USING {extensionName}("); - builder.AppendLine(string.Join(",\n", columns.Select(GetColumnDefinition))); + // The vector extension is currently uncapable of handling quoted identifiers. + builder.AppendLine(string.Join(",\n", columns.Select(column => GetColumnDefinition(column, quote: false)))); builder.Append(");"); - var command = this._connection.CreateCommand(); + var command = connection.CreateCommand(); command.CommandText = builder.ToString(); return command; } - public DbCommand BuildDropTableCommand(string tableName) + public static DbCommand BuildDropTableCommand(SqliteConnection connection, string tableName) { - string query = $"DROP TABLE IF EXISTS [{tableName}];"; + string query = $"DROP TABLE IF EXISTS \"{tableName}\";"; - var command = this._connection.CreateCommand(); + var command = connection.CreateCommand(); command.CommandText = query; return command; } - public DbCommand BuildInsertCommand( + public static DbCommand BuildInsertCommand( + SqliteConnection connection, string tableName, string rowIdentifier, - IReadOnlyList columnNames, + IReadOnlyList properties, IReadOnlyList> records, + bool data, bool replaceIfExists = false) { var builder = new StringBuilder(); - var command = this._connection.CreateCommand(); + var command = connection.CreateCommand(); var replacePlaceholder = replaceIfExists ? " OR REPLACE" : string.Empty; @@ -113,11 +117,12 @@ public DbCommand BuildInsertCommand( var rowIdentifierParameterName = GetParameterName(rowIdentifier, recordIndex); var (columns, parameters, values) = GetQueryParts( - columnNames, + properties, records[recordIndex], - recordIndex); + recordIndex, + data); - builder.AppendLine($"INSERT{replacePlaceholder} INTO {tableName} ({string.Join(", ", columns)})"); + builder.AppendLine($"INSERT{replacePlaceholder} INTO \"{tableName}\" ({string.Join(", ", columns)})"); builder.AppendLine($"VALUES ({string.Join(", ", parameters)})"); builder.AppendLine($"RETURNING {rowIdentifier};"); @@ -132,71 +137,93 @@ public DbCommand BuildInsertCommand( return command; } - public DbCommand BuildSelectCommand( + public static DbCommand BuildSelectDataCommand( + SqliteConnection connection, string tableName, - IReadOnlyList columnNames, + VectorStoreRecordModel model, List conditions, - string? orderByPropertyName = null) + GetFilteredRecordOptions? filterOptions = null, + string? extraWhereFilter = null, + Dictionary? extraParameters = null, + int top = 0, + int skip = 0) { var builder = new StringBuilder(); - var (command, whereClause) = this.GetCommandWithWhereClause(conditions); + var (command, whereClause) = GetCommandWithWhereClause(connection, conditions, extraWhereFilter, extraParameters); - builder.AppendLine($"SELECT {string.Join(", ", columnNames)}"); - builder.AppendLine($"FROM {tableName}"); + builder.Append("SELECT "); + builder.AppendColumnNames(includeVectors: false, model.Properties); + builder.AppendLine($"FROM \"{tableName}\""); + builder.AppendWhereClause(whereClause); - AppendWhereClauseIfExists(builder, whereClause); - AppendOrderByIfExists(builder, orderByPropertyName); + if (filterOptions is not null) + { + builder.AppendOrderBy(model, filterOptions); + } + + builder.AppendLimits(top, skip); command.CommandText = builder.ToString(); return command; } - public DbCommand BuildSelectLeftJoinCommand( - string leftTable, - string rightTable, + public static DbCommand BuildSelectLeftJoinCommand( + SqliteConnection connection, + string vectorTableName, + string dataTableName, string joinColumnName, - IReadOnlyList leftTablePropertyNames, - IReadOnlyList rightTablePropertyNames, - List conditions, + VectorStoreRecordModel model, + IReadOnlyList conditions, + bool includeDistance, + GetFilteredRecordOptions? filterOptions = null, string? extraWhereFilter = null, Dictionary? extraParameters = null, - string? orderByPropertyName = null) + int top = 0, + int skip = 0) { var builder = new StringBuilder(); - List propertyNames = - [ - .. leftTablePropertyNames.Select(property => $"{leftTable}.{property}"), - .. rightTablePropertyNames.Select(property => $"{rightTable}.{property}"), - ]; + var (command, whereClause) = GetCommandWithWhereClause(connection, conditions, extraWhereFilter, extraParameters); - var (command, whereClause) = this.GetCommandWithWhereClause(conditions, extraWhereFilter, extraParameters); + builder.Append("SELECT "); + builder.AppendColumnNames(includeVectors: true, model.Properties, vectorTableName, dataTableName); + if (includeDistance) + { + builder.AppendLine($", \"{vectorTableName}\".\"{DistancePropertyName}\""); + } + builder.AppendLine($"FROM \"{vectorTableName}\""); + builder.AppendLine($"LEFT JOIN \"{dataTableName}\" ON \"{vectorTableName}\".\"{joinColumnName}\" = \"{dataTableName}\".\"{joinColumnName}\""); + builder.AppendWhereClause(whereClause); - builder.AppendLine($"SELECT {string.Join(", ", propertyNames)}"); - builder.AppendLine($"FROM {leftTable} "); - builder.AppendLine($"LEFT JOIN {rightTable} ON {leftTable}.{joinColumnName} = {rightTable}.{joinColumnName}"); + if (filterOptions is not null) + { + builder.AppendOrderBy(model, filterOptions, dataTableName); + } + else if (includeDistance) + { + builder.AppendLine($"ORDER BY \"{vectorTableName}\".\"{DistancePropertyName}\""); + } - AppendWhereClauseIfExists(builder, whereClause); - AppendOrderByIfExists(builder, orderByPropertyName); + builder.AppendLimits(top, skip); command.CommandText = builder.ToString(); return command; } - public DbCommand BuildDeleteCommand( + public static DbCommand BuildDeleteCommand( + SqliteConnection connection, string tableName, - List conditions) + IReadOnlyList conditions) { var builder = new StringBuilder(); - var (command, whereClause) = this.GetCommandWithWhereClause(conditions); - - builder.AppendLine($"DELETE FROM [{tableName}]"); + var (command, whereClause) = GetCommandWithWhereClause(connection, conditions); - AppendWhereClauseIfExists(builder, whereClause); + builder.AppendLine($"DELETE FROM \"{tableName}\""); + builder.AppendWhereClause(whereClause); command.CommandText = builder.ToString(); @@ -205,27 +232,92 @@ public DbCommand BuildDeleteCommand( #region private - private static void AppendWhereClauseIfExists(StringBuilder builder, string? whereClause) + private static StringBuilder AppendColumnNames(this StringBuilder builder, bool includeVectors, IReadOnlyList properties, + string? escapedVectorTableName = null, string? escapedDataTableName = null) { - if (!string.IsNullOrWhiteSpace(whereClause)) + foreach (var property in properties) { - builder.AppendLine($"WHERE {whereClause}"); + string? tableName = escapedDataTableName; + if (property is VectorStoreRecordVectorPropertyModel) + { + if (!includeVectors) + { + continue; + } + tableName = escapedVectorTableName; + } + + if (tableName is not null) + { + builder.AppendFormat("\"{0}\".\"{1}\",", tableName, property.StorageName); + } + else + { + builder.AppendFormat("\"{0}\",", property.StorageName); + } + } + + builder.Length--; // Remove the trailing comma + builder.AppendLine(); + return builder; + } + + private static StringBuilder AppendOrderBy(this StringBuilder builder, VectorStoreRecordModel model, + GetFilteredRecordOptions options, string? tableName = null) + { + if (options.OrderBy.Values.Count > 0) + { + builder.Append("ORDER BY "); + + foreach (var sortInfo in options.OrderBy.Values) + { + var storageName = model.GetDataOrKeyProperty(sortInfo.PropertySelector).StorageName; + + if (tableName is not null) + { + builder.AppendFormat("\"{0}\".", tableName); + } + + builder.AppendFormat("\"{0}\" {1},", storageName, sortInfo.Ascending ? "ASC" : "DESC"); + } + + builder.Length--; // remove the last comma + builder.AppendLine(); + } + + return builder; + } + + private static StringBuilder AppendLimits(this StringBuilder builder, int top, int skip) + { + if (top > 0) + { + builder.AppendFormat("LIMIT {0}", top).AppendLine(); + } + + if (skip > 0) + { + builder.AppendFormat("OFFSET {0}", skip).AppendLine(); } + + return builder; } - private static void AppendOrderByIfExists(StringBuilder builder, string? propertyName) + private static StringBuilder AppendWhereClause(this StringBuilder builder, string? whereClause) { - if (!string.IsNullOrWhiteSpace(propertyName)) + if (!string.IsNullOrWhiteSpace(whereClause)) { - builder.AppendLine($"ORDER BY {propertyName}"); + builder.AppendLine($"WHERE {whereClause}"); } + + return builder; } - private static string GetColumnDefinition(SqliteColumn column) + private static string GetColumnDefinition(SqliteColumn column, bool quote) { const string PrimaryKeyIdentifier = "PRIMARY KEY"; - List columnDefinitionParts = [column.Name, column.Type]; + List columnDefinitionParts = [quote ? $"\"{column.Name}\"" : column.Name, column.Type]; if (column.IsPrimary) { @@ -241,14 +333,15 @@ private static string GetColumnDefinition(SqliteColumn column) return string.Join(" ", columnDefinitionParts); } - private (DbCommand Command, string WhereClause) GetCommandWithWhereClause( - List conditions, + private static (DbCommand Command, string WhereClause) GetCommandWithWhereClause( + SqliteConnection connection, + IReadOnlyList conditions, string? extraWhereFilter = null, Dictionary? extraParameters = null) { const string WhereClauseOperator = " AND "; - var command = this._connection.CreateCommand(); + var command = connection.CreateCommand(); var whereClauseParts = new List(); foreach (var condition in conditions) @@ -279,7 +372,7 @@ private static string GetColumnDefinition(SqliteColumn column) whereClause += extraWhereFilter; Debug.Assert(extraParameters is not null, "extraParameters must be provided when extraWhereFilter is provided."); - foreach (var p in extraParameters) + foreach (var p in extraParameters!) { command.Parameters.Add(new SqliteParameter(p.Key, p.Value)); } @@ -289,19 +382,24 @@ private static string GetColumnDefinition(SqliteColumn column) } private static (List Columns, List ParameterNames, List ParameterValues) GetQueryParts( - IReadOnlyList propertyNames, + IReadOnlyList properties, Dictionary record, - int index) + int index, + bool data) { var columns = new List(); var parameterNames = new List(); var parameterValues = new List(); - foreach (var propertyName in propertyNames) + foreach (var property in properties) { - if (record.TryGetValue(propertyName, out var value)) + bool include = property is VectorStoreRecordKeyPropertyModel // The Key column is included in both Vector and Data tables. + || (data == property is VectorStoreRecordDataPropertyModel); // The Data column is included only in the Data table. + + string propertyName = property.StorageName; + if (include && record.TryGetValue(propertyName, out var value)) { - columns.Add(propertyName); + columns.Add($"\"{propertyName}\""); parameterNames.Add(GetParameterName(propertyName, index)); parameterValues.Add(value ?? DBNull.Value); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreOptions.cs index cac514677f07..63b96715c47d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Connectors.Sqlite; @@ -10,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.Sqlite; public sealed class SqliteVectorStoreOptions { /// - /// An optional factory to use for constructing instances, if a custom record collection is required. + /// An optional factory to use for constructing instances, if a custom record collection is required. /// [Obsolete("To control how collections are instantiated, extend your provider's IVectorStore implementation and override GetCollection()")] public ISqliteVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } @@ -29,4 +30,9 @@ public sealed class SqliteVectorStoreOptions /// If not provided, collection name with prefix "vec_" will be used as virtual table name. /// public string? VectorVirtualTableName { get; set; } = null; + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs index e2831eac6e53..b7f0ec49dae7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs @@ -4,58 +4,52 @@ using System.Collections.Generic; using System.Data.Common; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Data.Sqlite; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; namespace Microsoft.SemanticKernel.Connectors.Sqlite; /// /// Service for storing and retrieving vector records, that uses SQLite as the underlying storage. /// +/// The data type of the record key. Can be or , or for dynamic mapping. /// The data model to use for adding, updating and retrieving data from storage. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class SqliteVectorStoreRecordCollection : - IVectorStoreRecordCollection, - IVectorStoreRecordCollection +public sealed class SqliteVectorStoreRecordCollection : IVectorStoreRecordCollection + where TKey : notnull + where TRecord : notnull #pragma warning restore CA1711 // Identifiers should not have incorrect { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "SQLite"; + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; - /// that will be used to manage the data in SQLite. - private readonly DbConnection _connection; + /// The connection string for the SQLite database represented by this . + private readonly string _connectionString; /// Optional configuration options for this class. private readonly SqliteVectorStoreRecordCollectionOptions _options; /// The mapper to use when mapping between the consumer data model and the SQLite record. - private readonly IVectorStoreRecordMapper> _mapper; + private readonly SqliteVectorStoreRecordMapper _mapper; /// The default options for vector search. private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); - /// Command builder for queries in SQLite database. - private readonly SqliteVectorStoreCollectionCommandBuilder _commandBuilder; - - /// Contains helpers for reading vector store model properties and their attributes. - private readonly VectorStoreRecordPropertyReader _propertyReader; + /// The model for this collection. + private readonly VectorStoreRecordModel _model; /// Flag which indicates whether vector properties exist in the consumer data model. private readonly bool _vectorPropertiesExist; - /// Collection of properties to operate in SQLite data table. - private readonly Lazy> _dataTableProperties; - - /// Collection of properties to operate in SQLite vector table. - private readonly Lazy> _vectorTableProperties; - - /// Collection of property names to operate in SQLite data table. - private readonly Lazy> _dataTableStoragePropertyNames; - - /// Collection of property names to operate in SQLite vector table. - private readonly Lazy> _vectorTableStoragePropertyNames; + /// The storage name of the key property. + private readonly string _keyStorageName; /// Table name in SQLite for data properties. private readonly string _dataTableName; @@ -63,63 +57,68 @@ public class SqliteVectorStoreRecordCollection : /// Table name in SQLite for vector properties. private readonly string _vectorTableName; + /// The sqlite_vec extension name to use. + private readonly string _vectorSearchExtensionName; + /// - public string CollectionName { get; } + public string Name { get; } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// - /// that will be used to manage the data in SQLite. - /// The name of the collection/table that this will access. + /// The connection string for the SQLite database represented by this . + /// The name of the collection/table that this will access. /// Optional configuration options for this class. public SqliteVectorStoreRecordCollection( - DbConnection connection, - string collectionName, + string connectionString, + string name, SqliteVectorStoreRecordCollectionOptions? options = default) { // Verify. - Verify.NotNull(connection); - Verify.NotNullOrWhiteSpace(collectionName); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.DictionaryCustomMapper is not null, SqliteConstants.SupportedKeyTypes); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + Verify.NotNull(connectionString); + Verify.NotNullOrWhiteSpace(name); + + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(ulong) && typeof(TKey) != typeof(object)) + { + throw new NotSupportedException($"Only {nameof(String)} and {nameof(UInt64)} keys are supported (and object for dynamic mapping)."); + } // Assign. - this._connection = connection; - this.CollectionName = collectionName; + this._connectionString = connectionString; + this.Name = name; this._options = options ?? new(); + this._vectorSearchExtensionName = this._options.VectorSearchExtensionName ?? SqliteConstants.VectorSearchExtensionName; - this._dataTableName = this.CollectionName; - this._vectorTableName = GetVectorTableName(this._dataTableName, this._options); - - this._propertyReader = new VectorStoreRecordPropertyReader(typeof(TRecord), this._options.VectorStoreRecordDefinition, new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true - }); - - // Validate property types. - this._propertyReader.VerifyKeyProperties(SqliteConstants.SupportedKeyTypes); + // Escape both table names before exposing them to anything that may build SQL commands. + this._dataTableName = name.EscapeIdentifier(); + this._vectorTableName = GetVectorTableName(name, this._options).EscapeIdentifier(); - this._vectorPropertiesExist = this._propertyReader.VectorProperties.Count > 0; + this._model = new VectorStoreRecordModelBuilder(SqliteConstants.ModelBuildingOptions) + .Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator); - this._dataTableProperties = new(() => [this._propertyReader.KeyProperty, .. this._propertyReader.DataProperties]); - this._vectorTableProperties = new(() => [this._propertyReader.KeyProperty, .. this._propertyReader.VectorProperties]); + this._vectorPropertiesExist = this._model.VectorProperties.Count > 0; - this._dataTableStoragePropertyNames = new(() => [this._propertyReader.KeyPropertyStoragePropertyName, .. this._propertyReader.DataPropertyStoragePropertyNames]); - this._vectorTableStoragePropertyNames = new(() => [this._propertyReader.KeyPropertyStoragePropertyName, .. this._propertyReader.VectorPropertyStoragePropertyNames]); + // Populate some collections of properties + this._keyStorageName = this._model.KeyProperty.StorageName; + this._mapper = new SqliteVectorStoreRecordMapper(this._model); - this._mapper = this.InitializeMapper(); + var connectionStringBuilder = new SqliteConnectionStringBuilder(connectionString); - this._commandBuilder = new SqliteVectorStoreCollectionCommandBuilder(this._connection); + this._collectionMetadata = new() + { + VectorStoreSystemName = SqliteConstants.VectorStoreSystemName, + VectorStoreName = connectionStringBuilder.DataSource, + CollectionName = name + }; } /// - public virtual async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) { const string OperationName = "TableCount"; - using var command = this._commandBuilder.BuildTableCountCommand(this._dataTableName); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + using var command = SqliteVectorStoreCollectionCommandBuilder.BuildTableCountCommand(connection, this._dataTableName); var result = await this .RunOperationAsync(OperationName, () => command.ExecuteScalarAsync(cancellationToken)) @@ -131,34 +130,97 @@ public virtual async Task CollectionExistsAsync(CancellationToken cancella } /// - public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) { - return this.InternalCreateCollectionAsync(ifNotExists: false, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await this.InternalCreateCollectionAsync(connection, ifNotExists: false, cancellationToken) + .ConfigureAwait(false); } /// - public virtual Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { - return this.InternalCreateCollectionAsync(ifNotExists: true, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await this.InternalCreateCollectionAsync(connection, ifNotExists: true, cancellationToken) + .ConfigureAwait(false); } /// - public virtual async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { - await this.DropTableAsync(this._dataTableName, cancellationToken).ConfigureAwait(false); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + + await this.DropTableAsync(connection, this._dataTableName, cancellationToken).ConfigureAwait(false); if (this._vectorPropertiesExist) { - await this.DropTableAsync(this._vectorTableName, cancellationToken).ConfigureAwait(false); + await this.DropTableAsync(connection, this._vectorTableName, cancellationToken).ConfigureAwait(false); } } + #region Search + /// - public virtual Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + switch (vectorProperty.EmbeddingGenerator) + { + case IEmbeddingGenerator> generator: + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + + default: + throw new InvalidOperationException( + SqliteConstants.SupportedVectorTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); + } + } + + /// + public IAsyncEnumerable> SearchEmbeddingAsync( + TVector vector, + int top, + VectorSearchOptions? options = null, + CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + VectorSearchOptions options, + CancellationToken cancellationToken = default) + where TVector : notnull { const string LimitPropertyName = "k"; Verify.NotNull(vector); + Verify.NotLessThan(top, 1); var vectorType = vector.GetType(); if (!SqliteConstants.SupportedVectorTypes.Contains(vectorType)) @@ -168,18 +230,20 @@ public virtual Task> VectorizedSearchAsync $"Supported types are: {string.Join(", ", SqliteConstants.SupportedVectorTypes.Select(l => l.FullName))}"); } - var searchOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(searchOptions); + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } var mappedArray = SqliteVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); // Simulating skip/offset logic locally, since OFFSET can work only with LIMIT in combination // and LIMIT is not supported in vector search extension, instead of LIMIT - "k" parameter is used. - var limit = searchOptions.Top + searchOptions.Skip; + var limit = top + options.Skip; var conditions = new List() { - new SqliteWhereMatchCondition(this._propertyReader.GetStoragePropertyName(vectorProperty.DataModelPropertyName), mappedArray), + new SqliteWhereMatchCondition(vectorProperty.StorageName, mappedArray), new SqliteWhereEqualsCondition(LimitPropertyName, limit) }; @@ -187,122 +251,327 @@ public virtual Task> VectorizedSearchAsync string? extraWhereFilter = null; Dictionary? extraParameters = null; - if (searchOptions.OldFilter is not null) + if (options.OldFilter is not null) { - if (searchOptions.Filter is not null) + if (options.Filter is not null) { throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"); } // Old filter, we translate it to a list of SqliteWhereCondition, and merge these into the conditions we already have - var filterConditions = this.GetFilterConditions(searchOptions.OldFilter, this._dataTableName); + var filterConditions = this.GetFilterConditions(options.OldFilter, this._dataTableName); if (filterConditions is { Count: > 0 }) { conditions.AddRange(filterConditions); } } - else if (searchOptions.Filter is not null) + else if (options.Filter is not null) { - SqliteFilterTranslator translator = new(this._propertyReader.StoragePropertyNamesMap, searchOptions.Filter); + SqliteFilterTranslator translator = new(this._model, options.Filter); translator.Translate(appendWhere: false); extraWhereFilter = translator.Clause.ToString(); extraParameters = translator.Parameters; } #pragma warning restore CS0618 // VectorSearchFilter is obsolete - var vectorSearchResults = new VectorSearchResults(this.EnumerateAndMapSearchResultsAsync( + return this.EnumerateAndMapSearchResultsAsync( conditions, extraWhereFilter, extraParameters, - searchOptions, - cancellationToken)); - - return Task.FromResult(vectorSearchResults); + options, + cancellationToken); } - #region Implementation of IVectorStoreRecordCollection - /// - public virtual Task GetAsync(ulong key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - return this.InternalGetAsync(key, options, cancellationToken); - } + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); - /// - public virtual IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - return this.InternalGetBatchAsync(keys, options, cancellationToken); - } + #endregion Search /// - public virtual Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetAsync(Expression> filter, int top, GetFilteredRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - return this.InternalUpsertAsync(record, cancellationToken); - } + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); - /// - public virtual IAsyncEnumerable UpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken = default) - { - return this.InternalUpsertBatchAsync(records, cancellationToken); + options ??= new(); + + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + SqliteFilterTranslator translator = new(this._model, filter); + translator.Translate(appendWhere: false); + + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + DbCommand? command = null; + + if (options.IncludeVectors) + { + command = SqliteVectorStoreCollectionCommandBuilder.BuildSelectLeftJoinCommand( + connection, + this._vectorTableName, + this._dataTableName, + this._keyStorageName, + this._model, + conditions: [], + includeDistance: false, + filterOptions: options, + translator.Clause.ToString(), + translator.Parameters, + top: top, + skip: options.Skip); + } + else + { + command = SqliteVectorStoreCollectionCommandBuilder.BuildSelectDataCommand( + connection, + this._dataTableName, + this._model, + conditions: [], + filterOptions: options, + translator.Clause.ToString(), + translator.Parameters, + top: top, + skip: options.Skip); + } + + using (command) + { + StorageToDataModelMapperOptions mapperOptions = new() { IncludeVectors = options.IncludeVectors }; + using var reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return this.GetAndMapRecord( + "Get", + reader, + this._model.Properties, + mapperOptions); + } + } } /// - public virtual Task DeleteAsync(ulong key, CancellationToken cancellationToken = default) + public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { - return this.InternalDeleteAsync(key, cancellationToken); + Verify.NotNull(key); + + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + + var condition = new SqliteWhereEqualsCondition(this._keyStorageName, key) + { + TableName = this._dataTableName + }; + + return await this.InternalGetBatchAsync(connection, condition, options, cancellationToken) + .FirstOrDefaultAsync(cancellationToken) + .ConfigureAwait(false); } /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetAsync(IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - return this.InternalDeleteBatchAsync(keys, cancellationToken); - } + Verify.NotNull(keys); + var keysList = keys.Cast().ToList(); + if (keysList.Count == 0) + { + yield break; + } - #endregion + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); - #region Implementation of IVectorStoreRecordCollection + var condition = new SqliteWhereInCondition(this._keyStorageName, keysList) + { + TableName = this._dataTableName + }; - /// - public virtual Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - return this.InternalGetAsync(key, options, cancellationToken); + await foreach (var record in this.InternalGetBatchAsync(connection, condition, options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } } /// - public virtual IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { - return this.InternalGetBatchAsync(keys, options, cancellationToken); + Verify.NotNull(record); + + const string OperationName = "Upsert"; + + IReadOnlyList?[]? generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; + + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbedding, ReadOnlyMemory>(record, cancellationToken, out var floatTask)) + { + generatedEmbeddings ??= new IReadOnlyList?[vectorPropertyCount]; + generatedEmbeddings[i] = [await floatTask.ConfigureAwait(false)]; + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + + var storageModel = VectorStoreErrorHandler.RunModelConversion( + SqliteConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + OperationName, + () => this._mapper.MapFromDataToStorageModel(record, recordIndex: 0, generatedEmbeddings)); + + var key = storageModel[this._keyStorageName]; + + Verify.NotNull(key); + + var condition = new SqliteWhereEqualsCondition(this._keyStorageName, key); + + var upsertedRecordKeys = await this.InternalUpsertBatchAsync(connection, [storageModel], condition, cancellationToken) + .ConfigureAwait(false); + + return upsertedRecordKeys.Single() ?? throw new VectorStoreOperationException("Error occurred during upsert operation."); } /// - Task IVectorStoreRecordCollection.UpsertAsync(TRecord record, CancellationToken cancellationToken) + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { - return this.InternalUpsertAsync(record, cancellationToken); + Verify.NotNull(records); + + const string OperationName = "UpsertBatch"; + + IReadOnlyList? recordsList = null; + + // If an embedding generator is defined, invoke it once per property for all records. + IReadOnlyList?[]? generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; + + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + + // We have a property with embedding generation; materialize the records' enumerable if needed, to + // prevent multiple enumeration. + if (recordsList is null) + { + recordsList = records is IReadOnlyList r ? r : records.ToList(); + + if (recordsList.Count == 0) + { + return []; + } + + records = recordsList; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbeddings, ReadOnlyMemory>(records, cancellationToken, out var floatTask)) + { + generatedEmbeddings ??= new IReadOnlyList?[vectorPropertyCount]; + generatedEmbeddings[i] = (IReadOnlyList>)await floatTask.ConfigureAwait(false); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + + var storageModels = VectorStoreErrorHandler.RunModelConversion( + SqliteConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + OperationName, + () => records.Select((r, i) => this._mapper.MapFromDataToStorageModel(r, i, generatedEmbeddings)).ToList()); + + if (storageModels.Count == 0) + { + return []; + } + + var keys = storageModels.Select(model => model[this._keyStorageName]!).ToList(); + + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + var condition = new SqliteWhereInCondition(this._keyStorageName, keys); + + return await this.InternalUpsertBatchAsync(connection, storageModels, condition, cancellationToken).ConfigureAwait(false); } /// - IAsyncEnumerable IVectorStoreRecordCollection.UpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken) + public async Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { - return this.InternalUpsertBatchAsync(records, cancellationToken); + Verify.NotNull(key); + + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + + var condition = new SqliteWhereEqualsCondition(this._keyStorageName, key); + + await this.InternalDeleteBatchAsync(connection, condition, cancellationToken).ConfigureAwait(false); } /// - public virtual Task DeleteAsync(string key, CancellationToken cancellationToken = default) + public async Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { - return this.InternalDeleteAsync(key, cancellationToken); + Verify.NotNull(keys); + var keysList = keys.Cast().ToList(); + if (keysList.Count == 0) + { + return; + } + + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + + var condition = new SqliteWhereInCondition( + this._keyStorageName, + keysList); + + await this.InternalDeleteBatchAsync(connection, condition, cancellationToken).ConfigureAwait(false); } /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public object? GetService(Type serviceType, object? serviceKey = null) { - return this.InternalDeleteBatchAsync(keys, cancellationToken); - } + Verify.NotNull(serviceType); - #endregion + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType.IsInstanceOfType(this) ? this : + null; + } #region private + private async ValueTask GetConnectionAsync(CancellationToken cancellationToken = default) + { + var connection = new SqliteConnection(this._connectionString); + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + connection.LoadExtension(this._vectorSearchExtensionName); + return connection; + } + private async IAsyncEnumerable> EnumerateAndMapSearchResultsAsync( List conditions, string? extraWhereFilter, @@ -311,59 +580,45 @@ private async IAsyncEnumerable> EnumerateAndMapSearc [EnumeratorCancellation] CancellationToken cancellationToken) { const string OperationName = "VectorizedSearch"; - const string DistancePropertyName = "distance"; - - var leftTableProperties = new List { DistancePropertyName }; - List properties = [this._propertyReader.KeyProperty, .. this._propertyReader.DataProperties]; - - if (searchOptions.IncludeVectors) - { - leftTableProperties.AddRange(this._propertyReader.VectorPropertyStoragePropertyNames); - properties.AddRange(this._propertyReader.VectorProperties); - } - - using var command = this._commandBuilder.BuildSelectLeftJoinCommand( + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + using var command = SqliteVectorStoreCollectionCommandBuilder.BuildSelectLeftJoinCommand( + connection, this._vectorTableName, this._dataTableName, - this._propertyReader.KeyPropertyStoragePropertyName, - leftTableProperties, - this._dataTableStoragePropertyNames.Value, + this._keyStorageName, + this._model, conditions, - extraWhereFilter, - extraParameters, - DistancePropertyName); + includeDistance: true, + extraWhereFilter: extraWhereFilter, + extraParameters: extraParameters); using var reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + StorageToDataModelMapperOptions mapperOptions = new() { IncludeVectors = searchOptions.IncludeVectors }; for (var recordCounter = 0; await reader.ReadAsync(cancellationToken).ConfigureAwait(false); recordCounter++) { if (recordCounter >= searchOptions.Skip) { - var score = SqliteVectorStoreRecordPropertyMapping.GetPropertyValue(reader, DistancePropertyName); + var score = SqliteVectorStoreRecordPropertyMapping.GetPropertyValue(reader, SqliteVectorStoreCollectionCommandBuilder.DistancePropertyName); var record = this.GetAndMapRecord( OperationName, reader, - properties, - searchOptions.IncludeVectors); + this._model.Properties, + mapperOptions); yield return new VectorSearchResult(record, score); } } } - private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken) + private async Task InternalCreateCollectionAsync(SqliteConnection connection, bool ifNotExists, CancellationToken cancellationToken) { - List dataTableColumns = SqliteVectorStoreRecordPropertyMapping.GetColumns( - this._dataTableProperties.Value, - this._propertyReader.StoragePropertyNamesMap); + List dataTableColumns = SqliteVectorStoreRecordPropertyMapping.GetColumns(this._model.Properties, data: true); - List tasks = [this.CreateTableAsync( - this._dataTableName, - dataTableColumns, - ifNotExists, - cancellationToken)]; + await this.CreateTableAsync(connection, this._dataTableName, dataTableColumns, ifNotExists, cancellationToken) + .ConfigureAwait(false); if (this._vectorPropertiesExist) { @@ -371,85 +626,42 @@ private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken c this._options.VectorSearchExtensionName : SqliteConstants.VectorSearchExtensionName; - List vectorTableColumns = SqliteVectorStoreRecordPropertyMapping.GetColumns( - this._vectorTableProperties.Value, - this._propertyReader.StoragePropertyNamesMap); + List vectorTableColumns = SqliteVectorStoreRecordPropertyMapping.GetColumns(this._model.Properties, data: false); - tasks.Add(this.CreateVirtualTableAsync( - this._vectorTableName, - vectorTableColumns, - ifNotExists, - extensionName!, - cancellationToken)); + await this.CreateVirtualTableAsync(connection, this._vectorTableName, vectorTableColumns, ifNotExists, extensionName!, cancellationToken) + .ConfigureAwait(false); } - - return Task.WhenAll(tasks); } - private Task CreateTableAsync(string tableName, List columns, bool ifNotExists, CancellationToken cancellationToken) + private Task CreateTableAsync(SqliteConnection connection, string tableName, List columns, bool ifNotExists, CancellationToken cancellationToken) { const string OperationName = "CreateTable"; - using var command = this._commandBuilder.BuildCreateTableCommand(tableName, columns, ifNotExists); + using var command = SqliteVectorStoreCollectionCommandBuilder.BuildCreateTableCommand(connection, tableName, columns, ifNotExists); return this.RunOperationAsync(OperationName, () => command.ExecuteNonQueryAsync(cancellationToken)); } - private Task CreateVirtualTableAsync(string tableName, List columns, bool ifNotExists, string extensionName, CancellationToken cancellationToken) + private Task CreateVirtualTableAsync(SqliteConnection connection, string tableName, List columns, bool ifNotExists, string extensionName, CancellationToken cancellationToken) { const string OperationName = "CreateVirtualTable"; - using var command = this._commandBuilder.BuildCreateVirtualTableCommand(tableName, columns, ifNotExists, extensionName); + using var command = SqliteVectorStoreCollectionCommandBuilder.BuildCreateVirtualTableCommand(connection, tableName, columns, ifNotExists, extensionName); return this.RunOperationAsync(OperationName, () => command.ExecuteNonQueryAsync(cancellationToken)); } - private Task DropTableAsync(string tableName, CancellationToken cancellationToken) + private Task DropTableAsync(SqliteConnection connection, string tableName, CancellationToken cancellationToken) { const string OperationName = "DropTable"; - using var command = this._commandBuilder.BuildDropTableCommand(tableName); + using var command = SqliteVectorStoreCollectionCommandBuilder.BuildDropTableCommand(connection, tableName); return this.RunOperationAsync(OperationName, () => command.ExecuteNonQueryAsync(cancellationToken)); } - private async Task InternalGetAsync( - TKey key, - GetRecordOptions? options, - CancellationToken cancellationToken) - { - Verify.NotNull(key); - - var condition = new SqliteWhereEqualsCondition(this._propertyReader.KeyPropertyStoragePropertyName, key) - { - TableName = this._dataTableName - }; - - return await this.InternalGetBatchAsync(condition, options, cancellationToken) - .FirstOrDefaultAsync(cancellationToken) - .ConfigureAwait(false); - } - - private IAsyncEnumerable InternalGetBatchAsync( - IEnumerable keys, - GetRecordOptions? options, - CancellationToken cancellationToken) - { - Verify.NotNull(keys); - - var keysList = keys.Cast().ToList(); - - Verify.True(keysList.Count > 0, "Number of provided keys should be greater than zero."); - - var condition = new SqliteWhereInCondition(this._propertyReader.KeyPropertyStoragePropertyName, keysList) - { - TableName = this._dataTableName - }; - - return this.InternalGetBatchAsync(condition, options, cancellationToken); - } - private async IAsyncEnumerable InternalGetBatchAsync( + SqliteConnection connection, SqliteWhereCondition condition, GetRecordOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) @@ -459,30 +671,36 @@ private async IAsyncEnumerable InternalGetBatchAsync( bool includeVectors = options?.IncludeVectors is true && this._vectorPropertiesExist; DbCommand command; - List properties = [this._propertyReader.KeyProperty, .. this._propertyReader.DataProperties]; if (includeVectors) { - command = this._commandBuilder.BuildSelectLeftJoinCommand( - this._dataTableName, - this._vectorTableName, - this._propertyReader.KeyPropertyStoragePropertyName, - this._dataTableStoragePropertyNames.Value, - this._propertyReader.VectorPropertyStoragePropertyNames, - [condition]); + if (this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } - properties.AddRange(this._propertyReader.VectorProperties); + command = SqliteVectorStoreCollectionCommandBuilder.BuildSelectLeftJoinCommand( + connection, + this._vectorTableName, + this._dataTableName, + this._keyStorageName, + this._model, + [condition], + includeDistance: false); } else { - command = this._commandBuilder.BuildSelectCommand( + command = SqliteVectorStoreCollectionCommandBuilder.BuildSelectDataCommand( + connection, this._dataTableName, - this._dataTableStoragePropertyNames.Value, + this._model, [condition]); } using (command) { + StorageToDataModelMapperOptions mapperOptions = new() { IncludeVectors = includeVectors }; + using var reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) @@ -490,56 +708,17 @@ private async IAsyncEnumerable InternalGetBatchAsync( yield return this.GetAndMapRecord( OperationName, reader, - properties, - includeVectors); + this._model.Properties, + mapperOptions); } } } - private async Task InternalUpsertAsync(TRecord record, CancellationToken cancellationToken) - { - const string OperationName = "Upsert"; - - var storageModel = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, - OperationName, - () => this._mapper.MapFromDataToStorageModel(record)); - - var key = storageModel[this._propertyReader.KeyPropertyStoragePropertyName]; - - Verify.NotNull(key); - - var condition = new SqliteWhereEqualsCondition(this._propertyReader.KeyPropertyStoragePropertyName, key); - - var upsertedRecordKey = await this.InternalUpsertBatchAsync([storageModel], condition, cancellationToken) - .FirstOrDefaultAsync(cancellationToken) - .ConfigureAwait(false); - - return upsertedRecordKey ?? throw new VectorStoreOperationException("Error occurred during upsert operation."); - } - - private IAsyncEnumerable InternalUpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken) - { - const string OperationName = "UpsertBatch"; - - var storageModels = records.Select(record => VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, - OperationName, - () => this._mapper.MapFromDataToStorageModel(record))).ToList(); - - var keys = storageModels.Select(model => model[this._propertyReader.KeyPropertyStoragePropertyName]!).ToList(); - - var condition = new SqliteWhereInCondition(this._propertyReader.KeyPropertyStoragePropertyName, keys); - - return this.InternalUpsertBatchAsync(storageModels, condition, cancellationToken); - } - - private async IAsyncEnumerable InternalUpsertBatchAsync( + private async Task> InternalUpsertBatchAsync( + SqliteConnection connection, List> storageModels, SqliteWhereCondition condition, - [EnumeratorCancellation] CancellationToken cancellationToken) + CancellationToken cancellationToken) { Verify.NotNull(storageModels); Verify.True(storageModels.Count > 0, "Number of provided records should be greater than zero."); @@ -548,29 +727,35 @@ private async IAsyncEnumerable InternalUpsertBatchAsync( { // Deleting vector records first since current version of vector search extension // doesn't support Upsert operation, only Delete/Insert. - using var vectorDeleteCommand = this._commandBuilder.BuildDeleteCommand( + using var vectorDeleteCommand = SqliteVectorStoreCollectionCommandBuilder.BuildDeleteCommand( + connection, this._vectorTableName, [condition]); await vectorDeleteCommand.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - using var vectorInsertCommand = this._commandBuilder.BuildInsertCommand( + using var vectorInsertCommand = SqliteVectorStoreCollectionCommandBuilder.BuildInsertCommand( + connection, this._vectorTableName, - this._propertyReader.KeyPropertyStoragePropertyName, - this._vectorTableStoragePropertyNames.Value, - storageModels); + this._keyStorageName, + this._model.Properties, + storageModels, + data: false); await vectorInsertCommand.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } - using var dataCommand = this._commandBuilder.BuildInsertCommand( - this._dataTableName, - this._propertyReader.KeyPropertyStoragePropertyName, - this._dataTableStoragePropertyNames.Value, - storageModels, - replaceIfExists: true); + using var dataCommand = SqliteVectorStoreCollectionCommandBuilder.BuildInsertCommand( + connection, + this._dataTableName, + this._keyStorageName, + this._model.Properties, + storageModels, + data: true, + replaceIfExists: true); using var reader = await dataCommand.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + var keys = new List(); while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) { @@ -578,38 +763,16 @@ private async IAsyncEnumerable InternalUpsertBatchAsync( if (key is not null) { - yield return key; + keys.Add(key); } await reader.NextResultAsync(cancellationToken).ConfigureAwait(false); } - } - - private Task InternalDeleteAsync(TKey key, CancellationToken cancellationToken) - { - Verify.NotNull(key); - - var condition = new SqliteWhereEqualsCondition(this._propertyReader.KeyPropertyStoragePropertyName, key); - - return this.InternalDeleteBatchAsync(condition, cancellationToken); - } - - private Task InternalDeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken) - { - Verify.NotNull(keys); - - var keysList = keys.Cast().ToList(); - - Verify.True(keysList.Count > 0, "Number of provided keys should be greater than zero."); - - var condition = new SqliteWhereInCondition( - this._propertyReader.KeyPropertyStoragePropertyName, - keysList); - return this.InternalDeleteBatchAsync(condition, cancellationToken); + return keys; } - private Task InternalDeleteBatchAsync(SqliteWhereCondition condition, CancellationToken cancellationToken) + private Task InternalDeleteBatchAsync(SqliteConnection connection, SqliteWhereCondition condition, CancellationToken cancellationToken) { const string OperationName = "Delete"; @@ -617,14 +780,16 @@ private Task InternalDeleteBatchAsync(SqliteWhereCondition condition, Cancellati if (this._vectorPropertiesExist) { - using var vectorCommand = this._commandBuilder.BuildDeleteCommand( + using var vectorCommand = SqliteVectorStoreCollectionCommandBuilder.BuildDeleteCommand( + connection, this._vectorTableName, [condition]); tasks.Add(this.RunOperationAsync(OperationName, () => vectorCommand.ExecuteNonQueryAsync(cancellationToken))); } - using var dataCommand = this._commandBuilder.BuildDeleteCommand( + using var dataCommand = SqliteVectorStoreCollectionCommandBuilder.BuildDeleteCommand( + connection, this._dataTableName, [condition]); @@ -636,25 +801,26 @@ private Task InternalDeleteBatchAsync(SqliteWhereCondition condition, Cancellati private TRecord GetAndMapRecord( string operationName, DbDataReader reader, - List properties, - bool includeVectors) + IReadOnlyList properties, + StorageToDataModelMapperOptions options) { var storageModel = new Dictionary(); foreach (var property in properties) { - var propertyName = this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName); - var propertyType = property.PropertyType; - var propertyValue = SqliteVectorStoreRecordPropertyMapping.GetPropertyValue(reader, propertyName, propertyType); - - storageModel.Add(propertyName, propertyValue); + if (options.IncludeVectors || property is not VectorStoreRecordVectorPropertyModel) + { + var propertyValue = SqliteVectorStoreRecordPropertyMapping.GetPropertyValue(reader, property.StorageName, property.Type); + storageModel.Add(property.StorageName, propertyValue); + } } return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + SqliteConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, operationName, - () => this._mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors })); + () => this._mapper.MapFromStorageToDataModel(storageModel, options)); } private async Task RunOperationAsync(string operationName, Func> operation) @@ -667,30 +833,14 @@ private async Task RunOperationAsync(string operationName, Func> o { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = SqliteConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } } - private IVectorStoreRecordMapper> InitializeMapper() - { - if (this._options.DictionaryCustomMapper is not null) - { - return this._options.DictionaryCustomMapper; - } - - if (typeof(TRecord) == typeof(VectorStoreGenericDataModel) || - typeof(TRecord) == typeof(VectorStoreGenericDataModel)) - { - var mapper = new SqliteGenericDataModelMapper(this._propertyReader); - return (mapper as IVectorStoreRecordMapper>)!; - } - - return new SqliteVectorStoreRecordMapper(this._propertyReader); - } - #pragma warning disable CS0618 // VectorSearchFilter is obsolete private List? GetFilterConditions(VectorSearchFilter? filter, string? tableName = null) { @@ -707,12 +857,12 @@ private async Task RunOperationAsync(string operationName, Func> o { if (filterClause is EqualToFilterClause equalToFilterClause) { - if (!this._propertyReader.StoragePropertyNamesMap.TryGetValue(equalToFilterClause.FieldName, out var storagePropertyName)) + if (!this._model.PropertyMap.TryGetValue(equalToFilterClause.FieldName, out var property)) { throw new InvalidOperationException($"Property name '{equalToFilterClause.FieldName}' provided as part of the filter clause is not a valid property name."); } - conditions.Add(new SqliteWhereEqualsCondition(storagePropertyName, equalToFilterClause.Value) + conditions.Add(new SqliteWhereEqualsCondition(property.StorageName, equalToFilterClause.Value) { TableName = tableName }); diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollectionOptions.cs index 90c06511826b..c9dc17db4919 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollectionOptions.cs @@ -1,18 +1,21 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.Sqlite; /// -/// Options when creating a . +/// Options when creating a . /// public sealed class SqliteVectorStoreRecordCollectionOptions { /// /// Gets or sets an optional custom mapper to use when converting between the data model and the SQLite record. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper>? DictionaryCustomMapper { get; set; } /// @@ -37,4 +40,9 @@ public sealed class SqliteVectorStoreRecordCollectionOptions /// If not provided, collection name with prefix will be used as virtual table name. /// public string? VectorVirtualTableName { get; set; } = null; + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordMapper.cs index f08ed1992b01..91fa5528b5de 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordMapper.cs @@ -2,7 +2,9 @@ using System; using System.Collections.Generic; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Sqlite; @@ -10,55 +12,33 @@ namespace Microsoft.SemanticKernel.Connectors.Sqlite; /// Class for mapping between a dictionary and the consumer data model. /// /// The consumer data model to map to or from. -internal sealed class SqliteVectorStoreRecordMapper : IVectorStoreRecordMapper> +internal sealed class SqliteVectorStoreRecordMapper(VectorStoreRecordModel model) { - /// with helpers for reading vector store model properties and their attributes. - private readonly VectorStoreRecordPropertyReader _propertyReader; - - /// - /// Initializes a new instance of the class. - /// - /// A that defines the schema of the data in the database. - public SqliteVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyReader) - { - Verify.NotNull(propertyReader); - - this._propertyReader = propertyReader; - - this._propertyReader.VerifyHasParameterlessConstructor(); - - // Validate property types. - this._propertyReader.VerifyDataProperties(SqliteConstants.SupportedDataTypes, supportEnumerable: false); - this._propertyReader.VerifyVectorProperties(SqliteConstants.SupportedVectorTypes); - } - - public Dictionary MapFromDataToStorageModel(TRecord dataModel) + public Dictionary MapFromDataToStorageModel(TRecord dataModel, int recordIndex, IReadOnlyList?[]? generatedEmbeddings) { var properties = new Dictionary { - // Add key property - { this._propertyReader.KeyPropertyStoragePropertyName, this._propertyReader.KeyPropertyInfo.GetValue(dataModel) } + { model.KeyProperty.StorageName, model.KeyProperty.GetValueAsObject(dataModel!) } }; - // Add data properties - foreach (var property in this._propertyReader.DataPropertiesInfo) + foreach (var property in model.DataProperties) { - properties.Add(this._propertyReader.GetStoragePropertyName(property.Name), property.GetValue(dataModel)); + properties.Add(property.StorageName, property.GetValueAsObject(dataModel!)); } - // Add vector properties - foreach (var property in this._propertyReader.VectorPropertiesInfo) + for (var i = 0; i < model.VectorProperties.Count; i++) { - object? result = null; - var propertyValue = property.GetValue(dataModel); - - if (propertyValue is not null) - { - var vector = (ReadOnlyMemory)propertyValue; - result = SqliteVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); - } - - properties.Add(this._propertyReader.GetStoragePropertyName(property.Name), result); + var property = model.VectorProperties[i]; + var vector = generatedEmbeddings?[i] is IReadOnlyList e ? ((Embedding)e[recordIndex]).Vector : property.GetValueAsObject(dataModel!); + + properties.Add( + property.StorageName, + vector switch + { + ReadOnlyMemory floats => SqliteVectorStoreRecordPropertyMapping.MapVectorForStorageModel(floats), + null => null, + _ => throw new InvalidOperationException($"Retrieved value for vector property '{property.StorageName}' which is not a ReadOnlyMemory ('{vector?.GetType().Name}').") + }); } return properties; @@ -66,34 +46,27 @@ public SqliteVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyRea public TRecord MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) { - var record = (TRecord)this._propertyReader.ParameterLessConstructorInfo.Invoke(null); - - // Set key. - var keyPropertyValue = Convert.ChangeType( - storageModel[this._propertyReader.KeyPropertyStoragePropertyName], - this._propertyReader.KeyProperty.PropertyType); - - this._propertyReader.KeyPropertyInfo.SetValue(record, keyPropertyValue); + var record = model.CreateRecord()!; - // Process data properties. - var dataPropertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( - this._propertyReader.DataPropertiesInfo, - this._propertyReader.StoragePropertyNamesMap, - storageModel); + var keyPropertyValue = Convert.ChangeType(storageModel[model.KeyProperty.StorageName], model.KeyProperty.Type); + model.KeyProperty.SetValueAsObject(record, keyPropertyValue); - VectorStoreRecordMapping.SetPropertiesOnRecord(record, dataPropertiesInfoWithValues); + foreach (var property in model.DataProperties) + { + property.SetValueAsObject(record, storageModel[property.StorageName]); + } if (options.IncludeVectors) { - // Process vector properties. - var vectorPropertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( - this._propertyReader.VectorPropertiesInfo, - this._propertyReader.StoragePropertyNamesMap, - storageModel, - (object? vector, Type type) => vector is byte[] vectorBytes ? - SqliteVectorStoreRecordPropertyMapping.MapVectorForDataModel(vectorBytes) : null); + foreach (var property in model.VectorProperties) + { + if (storageModel[property.StorageName] is not byte[] vectorBytes) + { + throw new InvalidOperationException($"Retrieved value for vector property '{property.StorageName}' which is not a byte array ('{storageModel[property.StorageName]?.GetType().Name}')."); + } - VectorStoreRecordMapping.SetPropertiesOnRecord(record, vectorPropertiesInfoWithValues); + property.SetValueAsObject(record, SqliteVectorStoreRecordPropertyMapping.MapVectorForDataModel(vectorBytes)); + } } return record; diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordPropertyMapping.cs index e468d14c3e65..2dc0c7369a5e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordPropertyMapping.cs @@ -3,8 +3,10 @@ using System; using System.Collections.Generic; using System.Data.Common; +using System.Diagnostics; using System.Runtime.InteropServices; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Sqlite; @@ -35,9 +37,7 @@ public static ReadOnlyMemory MapVectorForDataModel(byte[] byteArray) return new ReadOnlyMemory(array); } - public static List GetColumns( - List properties, - IReadOnlyDictionary storagePropertyNames) + public static List GetColumns(IReadOnlyList properties, bool data) { const string DistanceMetricConfigurationName = "distance_metric"; @@ -45,28 +45,46 @@ public static List GetColumns( foreach (var property in properties) { - var isPrimary = property is VectorStoreRecordKeyProperty; - var propertyName = storagePropertyNames[property.DataModelPropertyName]; + var isPrimary = false; string propertyType; Dictionary? configuration = null; - if (property is VectorStoreRecordVectorProperty vectorProperty) + if (property is VectorStoreRecordVectorPropertyModel vectorProperty) { + if (data) + { + continue; + } + propertyType = GetStorageVectorPropertyType(vectorProperty); configuration = new() { - [DistanceMetricConfigurationName] = GetDistanceMetric(vectorProperty.DistanceFunction, vectorProperty.DataModelPropertyName) + [DistanceMetricConfigurationName] = GetDistanceMetric(vectorProperty) }; } + else if (property is VectorStoreRecordDataPropertyModel dataProperty) + { + if (!data) + { + continue; + } + + propertyType = GetStorageDataPropertyType(property); + } else { + // The Key column in included in both Vector and Data tables. + Debug.Assert(property is VectorStoreRecordKeyPropertyModel, "property is VectorStoreRecordKeyPropertyModel"); + propertyType = GetStorageDataPropertyType(property); + isPrimary = true; } - var column = new SqliteColumn(propertyName, propertyType, isPrimary) + var column = new SqliteColumn(property.StorageName, propertyType, isPrimary) { - Configuration = configuration + Configuration = configuration, + HasIndex = property is VectorStoreRecordDataPropertyModel { IsIndexed: true } }; columns.Add(column); @@ -116,9 +134,8 @@ public static List GetColumns( #region private - private static string GetStorageDataPropertyType(VectorStoreRecordProperty property) - { - return property.PropertyType switch + private static string GetStorageDataPropertyType(VectorStoreRecordPropertyModel property) + => property.Type switch { // Integer types Type t when t == typeof(int) || t == typeof(int?) => "INTEGER", @@ -142,34 +159,20 @@ private static string GetStorageDataPropertyType(VectorStoreRecordProperty prope Type t when t == typeof(byte[]) => "BLOB", // Default fallback for unknown types - _ => throw new NotSupportedException($"Property {property.DataModelPropertyName} has type {property.PropertyType.FullName}, which is not supported by SQLite connector.") + _ => throw new NotSupportedException($"Property '{property.ModelName}' has type '{property.Type.Name}', which is not supported by SQLite connector.") }; - } - - private static string GetDistanceMetric(string? distanceFunction, string vectorPropertyName) - { - const string Cosine = "cosine"; - const string L1 = "l1"; - const string L2 = "l2"; - if (string.IsNullOrWhiteSpace(distanceFunction)) + private static string GetDistanceMetric(VectorStoreRecordVectorPropertyModel vectorProperty) + => vectorProperty.DistanceFunction switch { - return Cosine; - } - - return distanceFunction switch - { - DistanceFunction.CosineDistance => Cosine, - DistanceFunction.ManhattanDistance => L1, - DistanceFunction.EuclideanDistance => L2, - _ => throw new NotSupportedException($"Distance function '{distanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorPropertyName}' is not supported by the SQLite connector.") + DistanceFunction.CosineDistance or null => "cosine", + DistanceFunction.ManhattanDistance => "l1", + DistanceFunction.EuclideanDistance => "l2", + _ => throw new NotSupportedException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.ModelName}' is not supported by the SQLite connector.") }; - } - private static string GetStorageVectorPropertyType(VectorStoreRecordVectorProperty vectorProperty) - { - return $"FLOAT[{vectorProperty.Dimensions}]"; - } + private static string GetStorageVectorPropertyType(VectorStoreRecordVectorPropertyModel vectorProperty) + => $"FLOAT[{vectorProperty.Dimensions}]"; #endregion } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Connectors.Memory.Weaviate.csproj b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Connectors.Memory.Weaviate.csproj index 26b63c694dff..1fb9a74fbe3c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Connectors.Memory.Weaviate.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Connectors.Memory.Weaviate.csproj @@ -4,13 +4,14 @@ Microsoft.SemanticKernel.Connectors.Weaviate $(AssemblyName) - net8.0;netstandard2.0 + net8.0;netstandard2.0;net462 preview + @@ -19,10 +20,16 @@ + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/BatchRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/BatchRequest.cs index 24441684ceb9..ff56e3bad8f8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/BatchRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/BatchRequest.cs @@ -1,13 +1,15 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Net.Http; using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Connectors.Weaviate; -[Experimental("SKEXP0020")] +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class BatchRequest { private readonly string _class; diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/BatchResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/BatchResponse.cs index af63e8c62e20..2aa03939c188 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/BatchResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/BatchResponse.cs @@ -1,14 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Weaviate; // ReSharper disable once ClassNeverInstantiated.Global #pragma warning disable CA1812 // 'BatchResponse' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class BatchResponse : WeaviateObject #pragma warning restore CA1812 // 'BatchResponse' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). { diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateClassSchemaRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateClassSchemaRequest.cs index 93cd60dfb0ff..9efaeafb7f75 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateClassSchemaRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateClassSchemaRequest.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; namespace Microsoft.SemanticKernel.Connectors.Weaviate; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class CreateClassSchemaRequest { private CreateClassSchemaRequest(string @class, string description) diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateClassSchemaResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateClassSchemaResponse.cs index 181d69db467e..d7b65b1393ac 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateClassSchemaResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateClassSchemaResponse.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; namespace Microsoft.SemanticKernel.Connectors.Weaviate; #pragma warning disable CA1812 // 'CreateClassSchemaResponse' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class CreateClassSchemaResponse #pragma warning restore CA1812 // 'CreateClassSchemaResponse' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). { diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateGraphRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateGraphRequest.cs index 301ee42170f3..c25da8ada8a0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateGraphRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/CreateGraphRequest.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; using System.Net.Http; @@ -9,7 +8,7 @@ namespace Microsoft.SemanticKernel.Connectors.Weaviate; // ReSharper disable once ClassCannotBeInstantiated -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class CreateGraphRequest { #pragma warning disable CS8618 diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/DeleteObjectRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/DeleteObjectRequest.cs index dfbdd158f819..6580e6e9f355 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/DeleteObjectRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/DeleteObjectRequest.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; namespace Microsoft.SemanticKernel.Connectors.Weaviate; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class DeleteObjectRequest { public string? Class { get; set; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/DeleteSchemaRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/DeleteSchemaRequest.cs index 73d7e2fae456..950e05e7e59d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/DeleteSchemaRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/DeleteSchemaRequest.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; namespace Microsoft.SemanticKernel.Connectors.Weaviate; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class DeleteSchemaRequest { private readonly string _class; diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetClassRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetClassRequest.cs index 01669e527ced..48de160feca0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetClassRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetClassRequest.cs @@ -1,12 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Weaviate; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class GetClassRequest { private GetClassRequest(string @class) diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetClassResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetClassResponse.cs index e176a1f1b619..e3aea355ba2a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetClassResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetClassResponse.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; namespace Microsoft.SemanticKernel.Connectors.Weaviate; #pragma warning disable CA1812 // 'GetClassResponse' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class GetClassResponse #pragma warning restore CA1812 // 'GetClassResponse' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). { diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetObjectRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetObjectRequest.cs index 4c4317e8a1ab..cc608e62f557 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetObjectRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetObjectRequest.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; namespace Microsoft.SemanticKernel.Connectors.Weaviate; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class GetObjectRequest { public string? Id { get; set; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetSchemaRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetSchemaRequest.cs index 3a4be14541eb..94d0e57c180c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetSchemaRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetSchemaRequest.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; namespace Microsoft.SemanticKernel.Connectors.Weaviate; -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class GetSchemaRequest { public static GetSchemaRequest Create() diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetSchemaResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetSchemaResponse.cs index 76620d603b5c..0503bb11aa22 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetSchemaResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GetSchemaResponse.cs @@ -1,12 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; namespace Microsoft.SemanticKernel.Connectors.Weaviate; #pragma warning disable CA1812 // 'GetSchemaResponse' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class GetSchemaResponse #pragma warning restore CA1812 // 'GetSchemaResponse' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). { diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GraphResponse.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GraphResponse.cs index e31c5645c7de..cf3c1073b64d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GraphResponse.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/GraphResponse.cs @@ -1,12 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Text.Json.Nodes; namespace Microsoft.SemanticKernel.Connectors.Weaviate; #pragma warning disable CA1812 // 'GraphResponse' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class GraphResponse #pragma warning restore CA1812 // 'GraphResponse' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). { diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/ObjectResponseResult.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/ObjectResponseResult.cs index 3ed41a0a0187..33f1bfffe6b4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/ObjectResponseResult.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/ApiSchema/ObjectResponseResult.cs @@ -1,13 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Text.Json.Nodes; namespace Microsoft.SemanticKernel.Connectors.Weaviate; // ReSharper disable once ClassNeverInstantiated.Global #pragma warning disable CA1812 // 'ObjectResponseResult' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class ObjectResponseResult #pragma warning restore CA1812 // 'ObjectResponseResult' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). { diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/JsonConverter/UnixSecondsDateTimeJsonConverter.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/JsonConverter/UnixSecondsDateTimeJsonConverter.cs index 457ce2114a6d..5c3db6795970 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/JsonConverter/UnixSecondsDateTimeJsonConverter.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/Http/JsonConverter/UnixSecondsDateTimeJsonConverter.cs @@ -1,14 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel.Connectors.Weaviate; #pragma warning disable CA1812 // 'UnixSecondsDateTimeJsonConverter' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] internal sealed class UnixSecondsDateTimeJsonConverter : JsonConverter #pragma warning restore CA1812 // 'UnixSecondsDateTimeJsonConverter' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). { diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/HttpV2/WeaviateGetCollectionsRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/HttpV2/WeaviateGetCollectionsRequest.cs index f31017ca8685..40012278a076 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/HttpV2/WeaviateGetCollectionsRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/HttpV2/WeaviateGetCollectionsRequest.cs @@ -10,6 +10,6 @@ internal sealed class WeaviateGetCollectionsRequest public HttpRequestMessage Build() { - return HttpRequest.CreateGetRequest(ApiRoute, this); + return HttpRequest.CreateGetRequest(ApiRoute); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/IWeaviateMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/IWeaviateMapper.cs new file mode 100644 index 000000000000..b0083d3073a5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/IWeaviateMapper.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text.Json.Nodes; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Weaviate; + +internal interface IWeaviateMapper +{ + /// + /// Maps from the consumer record data model to the storage model. + /// + JsonObject MapFromDataToStorageModel(TRecord dataModel, int recordIndex, IReadOnlyList?[]? generatedEmbeddings); + + /// + /// Maps from the storage model to the consumer record data model. + /// + TRecord MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/IWeaviateVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/IWeaviateVectorStoreRecordCollectionFactory.cs index 10210eb8fb82..36e4daf5075d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/IWeaviateVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/IWeaviateVectorStoreRecordCollectionFactory.cs @@ -25,5 +25,6 @@ IVectorStoreRecordCollection CreateVectorStoreRecordCollection Properties { get; set; } = []; + + [JsonPropertyName("vectorizer")] + public string Vectorizer { get; set; } = WeaviateConstants.DefaultVectorizer; + + [JsonPropertyName("vectorIndexType")] + public string? VectorIndexType { get; set; } + + [JsonPropertyName("vectorIndexConfig")] + public WeaviateCollectionSchemaVectorIndexConfig? VectorIndexConfig { get; set; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/ModelV2/WeaviateCollectionSchemaVectorConfig.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/ModelV2/WeaviateCollectionSchemaVectorConfig.cs index 75bd33471eb7..77830facd893 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/ModelV2/WeaviateCollectionSchemaVectorConfig.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/ModelV2/WeaviateCollectionSchemaVectorConfig.cs @@ -7,10 +7,8 @@ namespace Microsoft.SemanticKernel.Connectors.Weaviate; internal sealed class WeaviateCollectionSchemaVectorConfig { - private const string DefaultVectorizer = "none"; - [JsonPropertyName("vectorizer")] - public Dictionary Vectorizer { get; set; } = new() { [DefaultVectorizer] = null }; + public Dictionary Vectorizer { get; set; } = new() { [WeaviateConstants.DefaultVectorizer] = null }; [JsonPropertyName("vectorIndexType")] public string? VectorIndexType { get; set; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateConstants.cs index a260b4e9fc2c..f98d4a6304cd 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateConstants.cs @@ -4,6 +4,9 @@ namespace Microsoft.SemanticKernel.Connectors.Weaviate; internal sealed class WeaviateConstants { + /// The name of this vector store for telemetry purposes. + public const string VectorStoreSystemName = "weaviate"; + /// Reserved key property name in Weaviate. internal const string ReservedKeyPropertyName = "id"; @@ -13,6 +16,9 @@ internal sealed class WeaviateConstants /// Reserved vector property name in Weaviate. internal const string ReservedVectorPropertyName = "vectors"; + /// Reserved single vector property name in Weaviate. + internal const string ReservedSingleVectorPropertyName = "vector"; + /// Collection property name in Weaviate. internal const string CollectionPropertyName = "class"; @@ -24,4 +30,7 @@ internal sealed class WeaviateConstants /// Additional properties property name in Weaviate. internal const string AdditionalPropertiesPropertyName = "_additional"; + + /// Default vectorizer for vector properties in Weaviate. + internal const string DefaultVectorizer = "none"; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateDynamicDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateDynamicDataModelMapper.cs new file mode 100644 index 000000000000..5fdce5df16f8 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateDynamicDataModelMapper.cs @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; + +namespace Microsoft.SemanticKernel.Connectors.Weaviate; + +/// +/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Weaviate. +/// +internal sealed class WeaviateDynamicDataModelMapper : IWeaviateMapper> +{ + /// The name of the Weaviate collection. + private readonly string _collectionName; + + /// The model. + private readonly VectorStoreRecordModel _model; + + /// A for serialization/deserialization of record properties. + private readonly JsonSerializerOptions _jsonSerializerOptions; + + /// Gets a value indicating whether the vectors in the store are named and multiple vectors are supported, or whether there is just a single unnamed vector in Weaviate collection. + private readonly bool _hasNamedVectors; + + /// Gets a vector property named used in Weaviate collection. + private readonly string _vectorPropertyName; + + /// + /// Initializes a new instance of the class. + /// + /// The name of the Weaviate collection + /// Gets or sets a value indicating whether the vectors in the store are named and multiple vectors are supported, or whether there is just a single unnamed vector in Weaviate collection + /// The model + /// A for serialization/deserialization of record properties. + public WeaviateDynamicDataModelMapper( + string collectionName, + bool hasNamedVectors, + VectorStoreRecordModel model, + JsonSerializerOptions jsonSerializerOptions) + { + this._collectionName = collectionName; + this._hasNamedVectors = hasNamedVectors; + this._model = model; + this._jsonSerializerOptions = jsonSerializerOptions; + + this._vectorPropertyName = hasNamedVectors ? + WeaviateConstants.ReservedVectorPropertyName : + WeaviateConstants.ReservedSingleVectorPropertyName; + } + + public JsonObject MapFromDataToStorageModel(Dictionary dataModel, int recordIndex, IReadOnlyList?[]? generatedEmbeddings) + { + Verify.NotNull(dataModel); + + // Transform generic data model to Weaviate object model. + var keyNode = JsonSerializer.SerializeToNode(dataModel[this._model.KeyProperty.ModelName]); + + // Populate data properties. + var dataNode = new JsonObject(); + foreach (var property in this._model.DataProperties) + { + if (dataModel.TryGetValue(property.ModelName, out var dataValue)) + { + dataNode[property.StorageName] = dataValue is null + ? null + : JsonSerializer.SerializeToNode(dataValue, property.Type, this._jsonSerializerOptions); + } + } + + // Populate vector properties. + JsonNode? vectorNode = null; + + if (this._hasNamedVectors) + { + vectorNode = new JsonObject(); + + for (var i = 0; i < this._model.VectorProperties.Count; i++) + { + var property = this._model.VectorProperties[i]; + + var vectorValue = generatedEmbeddings?[i] switch + { + IReadOnlyList> e => e[recordIndex].Vector, + IReadOnlyList> e => e[recordIndex].Vector, + null => dataModel.TryGetValue(property.ModelName, out var v) ? v : null, + _ => throw new NotSupportedException($"Unsupported embedding type '{generatedEmbeddings?[i]?.GetType().Name}' for property '{property.ModelName}'.") + }; + + vectorNode[property.StorageName] = vectorValue is null + ? null + : JsonSerializer.SerializeToNode(vectorValue, property.EmbeddingType, this._jsonSerializerOptions); + } + } + else + { + var vectorValue = generatedEmbeddings?[0] is IReadOnlyList> e + ? e[recordIndex].Vector + : dataModel.TryGetValue(this._model.VectorProperty.ModelName, out var v) + ? v + : null; + + vectorNode = vectorValue is null + ? null + : JsonSerializer.SerializeToNode(vectorValue, this._model.VectorProperty.EmbeddingType, this._jsonSerializerOptions); + } + + return new JsonObject + { + { WeaviateConstants.CollectionPropertyName, JsonValue.Create(this._collectionName) }, + { WeaviateConstants.ReservedKeyPropertyName, keyNode }, + { WeaviateConstants.ReservedDataPropertyName, dataNode }, + { this._vectorPropertyName, vectorNode }, + }; + } + + public Dictionary MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options) + { + Verify.NotNull(storageModel); + + var result = new Dictionary(); + + // Create variables to store the response properties. + var key = storageModel[WeaviateConstants.ReservedKeyPropertyName]?.GetValue(); + + if (!key.HasValue) + { + throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); + } + + result[this._model.KeyProperty.ModelName] = key.Value; + + // Populate data properties. + foreach (var property in this._model.DataProperties) + { + var jsonObject = storageModel[WeaviateConstants.ReservedDataPropertyName] as JsonObject; + + if (jsonObject is not null && jsonObject.TryGetPropertyValue(property.StorageName, out var dataValue)) + { + result.Add(property.ModelName, dataValue.Deserialize(property.Type, this._jsonSerializerOptions)); + } + } + + // Populate vector properties. + if (options.IncludeVectors) + { + if (this._hasNamedVectors) + { + foreach (var property in this._model.VectorProperties) + { + var jsonObject = storageModel[WeaviateConstants.ReservedVectorPropertyName] as JsonObject; + + if (jsonObject is not null && jsonObject.TryGetPropertyValue(property.StorageName, out var vectorValue)) + { + result.Add(property.ModelName, vectorValue.Deserialize(property.Type, this._jsonSerializerOptions)); + } + } + } + else + { + var jsonNode = storageModel[WeaviateConstants.ReservedSingleVectorPropertyName]; + + if (jsonNode is not null) + { + result.Add(this._model.VectorProperty.ModelName, jsonNode.Deserialize(this._model.VectorProperty.Type, this._jsonSerializerOptions)); + } + } + } + + return result; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs index 2e4be5391159..87aa773617f3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs @@ -6,30 +6,33 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.ConnectorSupport.Filter; namespace Microsoft.SemanticKernel.Connectors.Weaviate; // https://weaviate.io/developers/weaviate/api/graphql/filters#filter-structure internal class WeaviateFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; + private VectorStoreRecordModel _model = null!; private ParameterExpression _recordParameter = null!; private readonly StringBuilder _filter = new(); - internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + internal string Translate(LambdaExpression lambdaExpression, VectorStoreRecordModel model) { Debug.Assert(this._filter.Length == 0); - this._storagePropertyNames = storagePropertyNames; + this._model = model; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - this.Translate(lambdaExpression.Body); + var preprocessor = new FilterTranslationPreprocessor { InlineCapturedVariables = true }; + var preprocessedExpression = preprocessor.Visit(lambdaExpression.Body); + + this.Translate(preprocessedExpression); return this._filter.ToString(); } @@ -66,7 +69,7 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual { switch (not.Operand) { - // Special handling for !(a == b) and !(a != b) + // Special handling for !(a == b) and !(a != b), transforming to a != b and a == b respectively. case BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary: this.TranslateEqualityComparison( Expression.MakeBinary( @@ -75,9 +78,9 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual binary.Right)); return; - // Not over bool field (Filter => r => !r.Bool) - case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): - this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(false))); + // Not over bool field (r => !r.Bool) + case var negated when negated.Type == typeof(bool) && this.TryBindProperty(negated, out var property): + this.GenerateEqualityComparison(property.StorageName, false, ExpressionType.Equal); return; default: @@ -85,9 +88,9 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual } } - // MemberExpression is generally handled within e.g. TranslateEqual; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) - case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): - this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))); + // Special handling for bool constant as the filter expression (r => r.Bool) + case Expression when node.Type == typeof(bool) && this.TryBindProperty(node, out var property): + this.GenerateEqualityComparison(property.StorageName, true, ExpressionType.Equal); return; case MethodCallExpression methodCall: @@ -101,75 +104,84 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual private void TranslateEqualityComparison(BinaryExpression binary) { - if ((this.TryTranslateFieldAccess(binary.Left, out var storagePropertyName) && TryGetConstant(binary.Right, out var value)) - || (this.TryTranslateFieldAccess(binary.Right, out storagePropertyName) && TryGetConstant(binary.Left, out value))) + if (this.TryBindProperty(binary.Left, out var property) && binary.Right is ConstantExpression { Value: var rightConstant }) { - // { path: ["intPropName"], operator: Equal, ValueInt: 8 } - this._filter - .Append("{ path: [\"") - .Append(JsonEncodedText.Encode(storagePropertyName)) - .Append("\"], operator: "); + this.GenerateEqualityComparison(property.StorageName, rightConstant, binary.NodeType); + return; + } - // Special handling for null comparisons - if (value is null) - { - if (binary.NodeType is ExpressionType.Equal or ExpressionType.NotEqual) - { - this._filter - .Append("IsNull, valueBoolean: ") - .Append(binary.NodeType is ExpressionType.Equal ? "true" : "false") - .Append(" }"); - return; - } + if (this.TryBindProperty(binary.Right, out property) && binary.Left is ConstantExpression { Value: var leftConstant }) + { + this.GenerateEqualityComparison(property.StorageName, leftConstant, binary.NodeType); + return; + } - throw new NotSupportedException("null value supported only with equality/inequality checks"); - } + throw new NotSupportedException("Invalid equality/comparison"); + } - // Operator - this._filter.Append(binary.NodeType switch + private void GenerateEqualityComparison(string propertyStorageName, object? value, ExpressionType nodeType) + { + // { path: ["intPropName"], operator: Equal, ValueInt: 8 } + this._filter + .Append("{ path: [\"") + .Append(JsonEncodedText.Encode(propertyStorageName)) + .Append("\"], operator: "); + + // Special handling for null comparisons + if (value is null) + { + if (nodeType is ExpressionType.Equal or ExpressionType.NotEqual) { - ExpressionType.Equal => "Equal", - ExpressionType.NotEqual => "NotEqual", + this._filter + .Append("IsNull, valueBoolean: ") + .Append(nodeType is ExpressionType.Equal ? "true" : "false") + .Append(" }"); + return; + } - ExpressionType.GreaterThan => "GreaterThan", - ExpressionType.GreaterThanOrEqual => "GreaterThanEqual", - ExpressionType.LessThan => "LessThan", - ExpressionType.LessThanOrEqual => "LessThanEqual", + throw new NotSupportedException("null value supported only with equality/inequality checks"); + } - _ => throw new UnreachableException() - }); + // Operator + this._filter.Append(nodeType switch + { + ExpressionType.Equal => "Equal", + ExpressionType.NotEqual => "NotEqual", - this._filter.Append(", "); + ExpressionType.GreaterThan => "GreaterThan", + ExpressionType.GreaterThanOrEqual => "GreaterThanEqual", + ExpressionType.LessThan => "LessThan", + ExpressionType.LessThanOrEqual => "LessThanEqual", - // FieldType - var type = value.GetType(); - if (Nullable.GetUnderlyingType(type) is Type underlying) - { - type = underlying; - } + _ => throw new UnreachableException() + }); - this._filter.Append(value.GetType() switch - { - Type t when t == typeof(int) || t == typeof(long) || t == typeof(short) || t == typeof(byte) => "valueInt", - Type t when t == typeof(bool) => "valueBoolean", - Type t when t == typeof(string) || t == typeof(Guid) => "valueText", - Type t when t == typeof(float) || t == typeof(double) || t == typeof(decimal) => "valueNumber", - Type t when t == typeof(DateTimeOffset) => "valueDate", + this._filter.Append(", "); - _ => throw new NotSupportedException($"Unsupported value type {type.FullName} in filter.") - }); + // FieldType + var type = value.GetType(); + if (Nullable.GetUnderlyingType(type) is Type underlying) + { + type = underlying; + } - this._filter.Append(": "); + this._filter.Append(value.GetType() switch + { + Type t when t == typeof(int) || t == typeof(long) || t == typeof(short) || t == typeof(byte) => "valueInt", + Type t when t == typeof(bool) => "valueBoolean", + Type t when t == typeof(string) || t == typeof(Guid) => "valueText", + Type t when t == typeof(float) || t == typeof(double) || t == typeof(decimal) => "valueNumber", + Type t when t == typeof(DateTimeOffset) => "valueDate", - // Value - this._filter.Append(JsonSerializer.Serialize(value)); + _ => throw new NotSupportedException($"Unsupported value type {type.FullName} in filter.") + }); - this._filter.Append('}'); + this._filter.Append(": "); - return; - } + // Value + this._filter.Append(JsonSerializer.Serialize(value)); - throw new NotSupportedException("Invalid equality/comparison"); + this._filter.Append('}'); } private void TranslateMethodCall(MethodCallExpression methodCall) @@ -205,13 +217,11 @@ private void TranslateContains(Expression source, Expression item) { // Contains over array // { path: ["stringArrayPropName"], operator: ContainsAny, valueText: ["foo"] } - if (this.TryTranslateFieldAccess(source, out var storagePropertyName) - && TryGetConstant(item, out var itemConstant) - && itemConstant is string stringConstant) + if (this.TryBindProperty(source, out var property) && item is ConstantExpression { Value: string stringConstant }) { this._filter .Append("{ path: [\"") - .Append(JsonEncodedText.Encode(storagePropertyName)) + .Append(JsonEncodedText.Encode(property.StorageName)) .Append("\"], operator: ContainsAny, valueText: [") .Append(JsonEncodedText.Encode(stringConstant)) .Append("]}"); @@ -221,40 +231,49 @@ private void TranslateContains(Expression source, Expression item) throw new NotSupportedException("Contains supported only over tag field"); } - private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out VectorStoreRecordPropertyModel? property) { - if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) - { - if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) - { - throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); - } + Type? convertedClrType = null; - return true; + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unary) + { + expression = unary.Operand; + convertedClrType = unary.Type; } - storagePropertyName = null; - return false; - } + var modelName = expression switch + { + // Regular member access for strongly-typed POCO binding (e.g. r => r.SomeInt == 8) + MemberExpression memberExpression when memberExpression.Expression == this._recordParameter + => memberExpression.Member.Name, - private static bool TryGetConstant(Expression expression, out object? constantValue) - { - switch (expression) + // Dictionary lookup for weakly-typed dynamic binding (e.g. r => r["SomeInt"] == 8) + MethodCallExpression + { + Method: { Name: "get_Item", DeclaringType: var declaringType }, + Arguments: [ConstantExpression { Value: string keyName }] + } methodCall when methodCall.Object == this._recordParameter && declaringType == typeof(Dictionary) + => keyName, + + _ => null + }; + + if (modelName is null) { - case ConstantExpression { Value: var v }: - constantValue = v; - return true; + property = null; + return false; + } - // This identifies compiler-generated closure types which contain captured variables. - case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): - constantValue = fieldInfo.GetValue(constant.Value); - return true; + if (!this._model.PropertyMap.TryGetValue(modelName, out property)) + { + throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name."); + } - default: - constantValue = null; - return false; + if (convertedClrType is not null && convertedClrType != property.Type) + { + throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convertedClrType.Name}', but its configured type is '{property.Type.Name}'."); } + + return true; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateGenericDataModelMapper.cs deleted file mode 100644 index 7e7640744d2d..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateGenericDataModelMapper.cs +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Text.Json; -using System.Text.Json.Nodes; -using Microsoft.Extensions.VectorData; - -namespace Microsoft.SemanticKernel.Connectors.Weaviate; - -/// -/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Weaviate. -/// -internal sealed class WeaviateGenericDataModelMapper : IVectorStoreRecordMapper, JsonObject> -{ - /// The name of the Weaviate collection. - private readonly string _collectionName; - - /// A property of record definition. - private readonly VectorStoreRecordKeyProperty _keyProperty; - - /// A collection of properties of record definition. - private readonly IReadOnlyList _dataProperties; - - /// A collection of properties of record definition. - private readonly IReadOnlyList _vectorProperties; - - /// A dictionary that maps from a property name to the storage name. - private readonly IReadOnlyDictionary _storagePropertyNames; - - /// A for serialization/deserialization of record properties. - private readonly JsonSerializerOptions _jsonSerializerOptions; - - /// - /// Initializes a new instance of the class. - /// - /// The name of the Weaviate collection - /// A property of record definition. - /// A collection of properties of record definition. - /// A collection of properties of record definition. - /// A dictionary that maps from a property name to the storage name. - /// A for serialization/deserialization of record properties. - public WeaviateGenericDataModelMapper( - string collectionName, - VectorStoreRecordKeyProperty keyProperty, - IReadOnlyList dataProperties, - IReadOnlyList vectorProperties, - IReadOnlyDictionary storagePropertyNames, - JsonSerializerOptions jsonSerializerOptions) - { - Verify.NotNullOrWhiteSpace(collectionName); - Verify.NotNull(keyProperty); - Verify.NotNull(dataProperties); - Verify.NotNull(vectorProperties); - Verify.NotNull(storagePropertyNames); - Verify.NotNull(jsonSerializerOptions); - - this._collectionName = collectionName; - this._keyProperty = keyProperty; - this._dataProperties = dataProperties; - this._vectorProperties = vectorProperties; - this._storagePropertyNames = storagePropertyNames; - this._jsonSerializerOptions = jsonSerializerOptions; - } - - public JsonObject MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - Verify.NotNull(dataModel); - - // Transform generic data model to Weaviate object model. - var weaviateObjectModel = new JsonObject - { - { WeaviateConstants.CollectionPropertyName, JsonValue.Create(this._collectionName) }, - { WeaviateConstants.ReservedKeyPropertyName, dataModel.Key }, - { WeaviateConstants.ReservedDataPropertyName, new JsonObject() }, - { WeaviateConstants.ReservedVectorPropertyName, new JsonObject() }, - }; - - // Populate data properties. - foreach (var property in this._dataProperties) - { - if (dataModel.Data is not null && dataModel.Data.TryGetValue(property.DataModelPropertyName, out var dataValue)) - { - var storagePropertyName = this._storagePropertyNames[property.DataModelPropertyName]; - - weaviateObjectModel[WeaviateConstants.ReservedDataPropertyName]![storagePropertyName] = dataValue is not null ? - JsonSerializer.SerializeToNode(dataValue, property.PropertyType, this._jsonSerializerOptions) : - null; - } - } - - // Populate vector properties. - foreach (var property in this._vectorProperties) - { - if (dataModel.Vectors is not null && dataModel.Vectors.TryGetValue(property.DataModelPropertyName, out var vectorValue)) - { - var storagePropertyName = this._storagePropertyNames[property.DataModelPropertyName]; - - weaviateObjectModel[WeaviateConstants.ReservedVectorPropertyName]![storagePropertyName] = vectorValue is not null ? - JsonSerializer.SerializeToNode(vectorValue, property.PropertyType, this._jsonSerializerOptions) : - null; - } - } - - return weaviateObjectModel; - } - - public VectorStoreGenericDataModel MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options) - { - Verify.NotNull(storageModel); - - // Create variables to store the response properties. - var key = storageModel[WeaviateConstants.ReservedKeyPropertyName]?.GetValue(); - - if (!key.HasValue) - { - throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); - } - - var dataProperties = new Dictionary(); - var vectorProperties = new Dictionary(); - - // Populate data properties. - foreach (var property in this._dataProperties) - { - var storagePropertyName = this._storagePropertyNames[property.DataModelPropertyName]; - var jsonObject = storageModel[WeaviateConstants.ReservedDataPropertyName] as JsonObject; - - if (jsonObject is not null && jsonObject.TryGetPropertyValue(storagePropertyName, out var dataValue)) - { - dataProperties.Add(property.DataModelPropertyName, dataValue.Deserialize(property.PropertyType, this._jsonSerializerOptions)); - } - } - - // Populate vector properties. - if (options.IncludeVectors) - { - foreach (var property in this._vectorProperties) - { - var storagePropertyName = this._storagePropertyNames[property.DataModelPropertyName]; - var jsonObject = storageModel[WeaviateConstants.ReservedVectorPropertyName] as JsonObject; - - if (jsonObject is not null && jsonObject.TryGetPropertyValue(storagePropertyName, out var vectorValue)) - { - vectorProperties.Add(property.DataModelPropertyName, vectorValue.Deserialize(property.PropertyType, this._jsonSerializerOptions)); - } - } - } - - return new VectorStoreGenericDataModel(key.Value) { Data = dataProperties, Vectors = vectorProperties }; - } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateKernelBuilderExtensions.cs index 45c320df959c..23f8b0881ee5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateKernelBuilderExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Net.Http; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Weaviate; @@ -9,6 +10,7 @@ namespace Microsoft.SemanticKernel; /// /// Extension methods to register Weaviate instances on the . /// +[Obsolete("The IKernelBuilder extensions are being obsoleted, call the appropriate function on the Services property of your IKernelBuilder")] public static class WeaviateKernelBuilderExtensions { /// @@ -34,7 +36,7 @@ public static IKernelBuilder AddWeaviateVectorStore( } /// - /// Register a Weaviate and with the specified service ID. + /// Register a Weaviate and with the specified service ID. /// /// The type of the record. /// The builder to register the on. @@ -53,6 +55,7 @@ public static IKernelBuilder AddWeaviateVectorStoreRecordCollection( HttpClient? httpClient = default, WeaviateVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { builder.Services.AddWeaviateVectorStoreRecordCollection(collectionName, httpClient, options, serviceId); return builder; diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateMemoryBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateMemoryBuilderExtensions.cs index 40795d21eb30..b5247d94058b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateMemoryBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateMemoryBuilderExtensions.cs @@ -1,16 +1,18 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System; using System.Net.Http; using Microsoft.SemanticKernel.Http; using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Connectors.Weaviate; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// Provides extension methods for the class to configure Weaviate connector. /// -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] public static class WeaviateMemoryBuilderExtensions { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateMemoryStore.cs index ca45d0b828f3..9382d081d481 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateMemoryStore.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; using System.Net; @@ -21,6 +20,8 @@ namespace Microsoft.SemanticKernel.Connectors.Weaviate; +#pragma warning disable SKEXP0001 // IMemoryStore is experimental (but we're obsoleting) + /// /// An implementation of for Weaviate. /// @@ -29,7 +30,7 @@ namespace Microsoft.SemanticKernel.Connectors.Weaviate; /// // ReSharper disable once ClassWithVirtualMembersNeverInherited.Global #pragma warning disable CA1001 // Types that own disposable fields should be disposable. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. -[Experimental("SKEXP0020")] +[Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and WeaviateVectorStore")] public partial class WeaviateMemoryStore : IMemoryStore #pragma warning restore CA1001 // Types that own disposable fields should be disposable. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. { diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateModelBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateModelBuilder.cs new file mode 100644 index 000000000000..fcc645826465 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateModelBuilder.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData.ConnectorSupport; + +namespace Microsoft.SemanticKernel.Connectors.Weaviate; + +internal class WeaviateModelBuilder(bool hasNamedVectors) : VectorStoreRecordJsonModelBuilder(GetModelBuildingOptions(hasNamedVectors)) +{ + private static VectorStoreRecordModelBuildingOptions GetModelBuildingOptions(bool hasNamedVectors) + { + return new() + { + RequiresAtLeastOneVector = !hasNamedVectors, + SupportsMultipleKeys = false, + SupportsMultipleVectors = hasNamedVectors, + + SupportedKeyPropertyTypes = [typeof(Guid)], + SupportedDataPropertyTypes = s_supportedDataTypes, + SupportedEnumerableDataPropertyElementTypes = s_supportedDataTypes, + SupportedVectorPropertyTypes = s_supportedVectorTypes, + + UsesExternalSerializer = true, + ReservedKeyStorageName = WeaviateConstants.ReservedKeyPropertyName + }; + } + + private static readonly HashSet s_supportedDataTypes = + [ + typeof(string), + typeof(bool), + typeof(bool?), + typeof(int), + typeof(int?), + typeof(long), + typeof(long?), + typeof(short), + typeof(short?), + typeof(byte), + typeof(byte?), + typeof(float), + typeof(float?), + typeof(double), + typeof(double?), + typeof(decimal), + typeof(decimal?), + typeof(DateTime), + typeof(DateTime?), + typeof(DateTimeOffset), + typeof(DateTimeOffset?), + typeof(Guid), + typeof(Guid?) + ]; + + internal static readonly HashSet s_supportedVectorTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?), + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?) + ]; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateServiceCollectionExtensions.cs index 7f6dfe48a404..f277b48ec34d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateServiceCollectionExtensions.cs @@ -2,6 +2,7 @@ using System; using System.Net.Http; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Weaviate; @@ -37,7 +38,10 @@ public static IServiceCollection AddWeaviateVectorStore( (sp, obj) => { var selectedHttpClient = HttpClientProvider.GetHttpClient(httpClient, sp); - var selectedOptions = options ?? sp.GetService(); + options ??= sp.GetService() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; return new WeaviateVectorStore(selectedHttpClient, options); }); @@ -45,7 +49,7 @@ public static IServiceCollection AddWeaviateVectorStore( } /// - /// Register a Weaviate and with the specified service ID. + /// Register a Weaviate and with the specified service ID. /// /// The type of the record. /// The to register the on. @@ -64,15 +68,19 @@ public static IServiceCollection AddWeaviateVectorStoreRecordCollection HttpClient? httpClient = default, WeaviateVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) + where TRecord : notnull { services.AddKeyedTransient>( serviceId, (sp, obj) => { var selectedHttpClient = HttpClientProvider.GetHttpClient(httpClient, sp); - var selectedOptions = options ?? sp.GetService>(); + options ??= sp.GetService>() ?? new() + { + EmbeddingGenerator = sp.GetService() + }; - return new WeaviateVectorStoreRecordCollection(selectedHttpClient, collectionName, selectedOptions); + return new WeaviateVectorStoreRecordCollection(selectedHttpClient, collectionName, options); }); AddVectorizedSearch(services, serviceId); @@ -81,14 +89,14 @@ public static IServiceCollection AddWeaviateVectorStoreRecordCollection } /// - /// Also register the with the given as a . + /// Also register the with the given as a . /// /// The type of the data model that the collection should contain. /// The service collection to register on. /// The service id that the registrations should use. - private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) where TRecord : notnull { - services.AddKeyedTransient>( + services.AddKeyedTransient>( serviceId, (sp, obj) => { diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStore.cs index 1c45d1e3ac65..98e8ca2a84f1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStore.cs @@ -6,8 +6,8 @@ using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Http; namespace Microsoft.SemanticKernel.Connectors.Weaviate; @@ -17,14 +17,20 @@ namespace Microsoft.SemanticKernel.Connectors.Weaviate; /// /// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. /// -public class WeaviateVectorStore : IVectorStore +public sealed class WeaviateVectorStore : IVectorStore { + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; + /// that is used to interact with Weaviate API. private readonly HttpClient _httpClient; /// Optional configuration options for this class. private readonly WeaviateVectorStoreOptions _options; + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreRecordDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreRecordKeyProperty("Key", typeof(Guid)), new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 1)] }; + /// /// Initializes a new instance of the class. /// @@ -40,12 +46,18 @@ public WeaviateVectorStore(HttpClient httpClient, WeaviateVectorStoreOptions? op this._httpClient = httpClient; this._options = options ?? new(); + + this._metadata = new() + { + VectorStoreSystemName = WeaviateConstants.VectorStoreSystemName + }; } /// /// The collection name must start with a capital letter and contain only ASCII letters and digits. - public virtual IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + where TRecord : notnull { #pragma warning disable CS0618 // IWeaviateVectorStoreRecordCollectionFactory is obsolete if (this._options.VectorStoreCollectionFactory is not null) @@ -57,39 +69,80 @@ public virtual IVectorStoreRecordCollection GetCollection( + var recordCollection = new WeaviateVectorStoreRecordCollection( this._httpClient, name, new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition, Endpoint = this._options.Endpoint, - ApiKey = this._options.ApiKey + ApiKey = this._options.ApiKey, + HasNamedVectors = this._options.HasNamedVectors, + EmbeddingGenerator = this._options.EmbeddingGenerator }) as IVectorStoreRecordCollection; - return recordCollection!; + return recordCollection; } /// - public virtual async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { using var request = new WeaviateGetCollectionsRequest().Build(); + WeaviateGetCollectionsResponse collectionsResponse; - var response = await this._httpClient.SendWithSuccessCheckAsync(request, cancellationToken).ConfigureAwait(false); - var responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false); - var collectionResponse = JsonSerializer.Deserialize(responseContent); + try + { + var httpResponse = await this._httpClient.SendAsync(request, HttpCompletionOption.ResponseContentRead, cancellationToken).ConfigureAwait(false); + + httpResponse.EnsureSuccessStatusCode(); + + var httpResponseContent = await httpResponse.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + + collectionsResponse = JsonSerializer.Deserialize(httpResponseContent)!; + } + catch (Exception e) + { + throw new VectorStoreOperationException("Call to vector store failed.", e) + { + VectorStoreSystemName = WeaviateConstants.VectorStoreSystemName, + VectorStoreName = this._metadata.VectorStoreName, + OperationName = "ListCollectionNames" + }; + } - if (collectionResponse?.Collections is not null) + if (collectionsResponse?.Collections is not null) { - foreach (var collection in collectionResponse.Collections) + foreach (var collection in collectionsResponse.Collections) { yield return collection.CollectionName; } } } + + /// + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetCollection>(name, s_generalPurposeDefinition); + return collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType == typeof(HttpClient) ? this._httpClient : + serviceType.IsInstanceOfType(this) ? this : + null; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreCollectionCreateMapping.cs index 13b944210b14..852339436432 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreCollectionCreateMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreCollectionCreateMapping.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Weaviate; @@ -18,42 +19,48 @@ internal static class WeaviateVectorStoreCollectionCreateMapping /// Maps record type properties to Weaviate collection schema for collection creation. /// /// The name of the vector store collection. - /// Collection of record data properties. - /// Collection of record vector properties. - /// A dictionary that maps from a property name to the storage name that should be used when serializing it to JSON for data and vector properties. + /// Gets a value indicating whether the vectors in the store are named and multiple vectors are supported, or whether there is just a single unnamed vector in Weaviate collection. + /// The model. /// Weaviate collection schema. - public static WeaviateCollectionSchema MapToSchema( - string collectionName, - IEnumerable dataProperties, - IEnumerable vectorProperties, - IReadOnlyDictionary storagePropertyNames) + public static WeaviateCollectionSchema MapToSchema(string collectionName, bool hasNamedVectors, VectorStoreRecordModel model) { var schema = new WeaviateCollectionSchema(collectionName); // Handle data properties. - foreach (var property in dataProperties) + foreach (var property in model.DataProperties) { schema.Properties.Add(new WeaviateCollectionSchemaProperty { - Name = storagePropertyNames[property.DataModelPropertyName], - DataType = [MapType(property.PropertyType)], - IndexFilterable = property.IsFilterable, - IndexSearchable = property.IsFullTextSearchable + Name = property.StorageName, + DataType = [MapType(property.Type)], + IndexFilterable = property.IsIndexed, + IndexSearchable = property.IsFullTextIndexed }); } // Handle vector properties. - foreach (var property in vectorProperties) + if (hasNamedVectors) { - var vectorPropertyName = storagePropertyNames[property.DataModelPropertyName]; - schema.VectorConfigurations.Add(vectorPropertyName, new WeaviateCollectionSchemaVectorConfig + foreach (var property in model.VectorProperties) { - VectorIndexType = MapIndexKind(property.IndexKind, vectorPropertyName), - VectorIndexConfig = new WeaviateCollectionSchemaVectorIndexConfig + schema.VectorConfigurations.Add(property.StorageName, new WeaviateCollectionSchemaVectorConfig { - Distance = MapDistanceFunction(property.DistanceFunction, vectorPropertyName) - } - }); + VectorIndexType = MapIndexKind(property.IndexKind, property.StorageName), + VectorIndexConfig = new WeaviateCollectionSchemaVectorIndexConfig + { + Distance = MapDistanceFunction(property.DistanceFunction, property.StorageName) + } + }); + } + } + else + { + var vectorProperty = model.VectorProperty; + schema.VectorIndexType = MapIndexKind(vectorProperty.IndexKind, vectorProperty.StorageName); + schema.VectorIndexConfig = new WeaviateCollectionSchemaVectorIndexConfig + { + Distance = MapDistanceFunction(vectorProperty.DistanceFunction, vectorProperty.StorageName) + }; } return schema; @@ -116,7 +123,7 @@ private static string MapDistanceFunction(string? distanceFunction, string vecto DistanceFunction.EuclideanSquaredDistance => EuclideanSquared, DistanceFunction.Hamming => Hamming, DistanceFunction.ManhattanDistance => Manhattan, - _ => throw new InvalidOperationException( + _ => throw new NotSupportedException( $"Distance function '{distanceFunction}' on {nameof(VectorStoreRecordVectorProperty)} '{vectorPropertyName}' is not supported by the Weaviate VectorStore. " + $"Supported distance functions: {string.Join(", ", DistanceFunction.CosineDistance, diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreCollectionSearchMapping.cs index 3842a3aded97..02ec36be81a1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreCollectionSearchMapping.cs @@ -13,7 +13,10 @@ internal static class WeaviateVectorStoreCollectionSearchMapping /// /// Maps vector search result to the format, which is processable by . /// - public static (JsonObject StorageModel, double? Score) MapSearchResult(JsonNode result, string scorePropertyName) + public static (JsonObject StorageModel, double? Score) MapSearchResult( + JsonNode result, + string scorePropertyName, + bool hasNamedVectors) { var additionalProperties = result[WeaviateConstants.AdditionalPropertiesPropertyName]; @@ -25,14 +28,18 @@ public static (JsonObject StorageModel, double? Score) MapSearchResult(JsonNode _ => null }; + var vectorPropertyName = hasNamedVectors ? + WeaviateConstants.ReservedVectorPropertyName : + WeaviateConstants.ReservedSingleVectorPropertyName; + var id = additionalProperties?[WeaviateConstants.ReservedKeyPropertyName]; - var vectors = additionalProperties?[WeaviateConstants.ReservedVectorPropertyName]; + var vectors = additionalProperties?[vectorPropertyName]; var storageModel = new JsonObject { { WeaviateConstants.ReservedKeyPropertyName, id?.DeepClone() }, { WeaviateConstants.ReservedDataPropertyName, result?.DeepClone() }, - { WeaviateConstants.ReservedVectorPropertyName, vectors?.DeepClone() }, + { vectorPropertyName, vectors?.DeepClone() }, }; return (storageModel, score); diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreOptions.cs index ae73e7989d82..c4b048c6fe4a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Connectors.Weaviate; @@ -10,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.Weaviate; public sealed class WeaviateVectorStoreOptions { /// - /// An optional factory to use for constructing instances, if a custom record collection is required. + /// An optional factory to use for constructing instances, if a custom record collection is required. /// [Obsolete("To control how collections are instantiated, extend your provider's IVectorStore implementation and override GetCollection()")] public IWeaviateVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } @@ -27,4 +28,16 @@ public sealed class WeaviateVectorStoreOptions /// This parameter is optional because authentication may be disabled in local clusters for testing purposes. /// public string? ApiKey { get; set; } = null; + + /// + /// Gets or sets a value indicating whether the vectors in the store are named and multiple vectors are supported, or whether there is just a single unnamed vector in Weaviate collection. + /// Defaults to multiple named vectors. + /// . + /// + public bool HasNamedVectors { get; set; } = true; + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs index 393b9a841cbb..91d76424dc0c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs @@ -2,7 +2,9 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using System.Linq.Expressions; using System.Net; using System.Net.Http; using System.Net.Http.Headers; @@ -11,64 +13,26 @@ using System.Text.Json.Nodes; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Http; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; namespace Microsoft.SemanticKernel.Connectors.Weaviate; /// /// Service for storing and retrieving vector records, that uses Weaviate as the underlying storage. /// +/// The data type of the record key. Can be either , or for dynamic mapping. /// The data model to use for adding, updating and retrieving data from storage. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class WeaviateVectorStoreRecordCollection : IVectorStoreRecordCollection, IKeywordHybridSearch +public sealed class WeaviateVectorStoreRecordCollection : IVectorStoreRecordCollection, IKeywordHybridSearch + where TKey : notnull + where TRecord : notnull #pragma warning restore CA1711 // Identifiers should not have incorrect suffix { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "Weaviate"; - - /// A set of types that a key on the provided model may have. - private static readonly HashSet s_supportedKeyTypes = - [ - typeof(Guid) - ]; - - /// A set of types that vectors on the provided model may have. - private static readonly HashSet s_supportedVectorTypes = - [ - typeof(ReadOnlyMemory), - typeof(ReadOnlyMemory?), - typeof(ReadOnlyMemory), - typeof(ReadOnlyMemory?) - ]; - - /// A set of types that data properties on the provided model may have. - private static readonly HashSet s_supportedDataTypes = - [ - typeof(string), - typeof(bool), - typeof(bool?), - typeof(int), - typeof(int?), - typeof(long), - typeof(long?), - typeof(short), - typeof(short?), - typeof(byte), - typeof(byte?), - typeof(float), - typeof(float?), - typeof(double), - typeof(double?), - typeof(decimal), - typeof(decimal?), - typeof(DateTime), - typeof(DateTime?), - typeof(DateTimeOffset), - typeof(DateTimeOffset?), - typeof(Guid), - typeof(Guid?) - ]; + /// Metadata about vector store record collection. + private readonly VectorStoreRecordCollectionMetadata _collectionMetadata; /// Default JSON serializer options. private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() @@ -93,11 +57,11 @@ public class WeaviateVectorStoreRecordCollection : IVectorStoreRecordCo /// Optional configuration options for this class. private readonly WeaviateVectorStoreRecordCollectionOptions _options; - /// A helper to access property information for the current data model and record definition. - private readonly VectorStoreRecordPropertyReader _propertyReader; + /// The model for this collection. + private readonly VectorStoreRecordModel _model; /// The mapper to use when mapping between the consumer data model and the Weaviate record. - private readonly IVectorStoreRecordMapper _mapper; + private readonly IWeaviateMapper _mapper; /// Weaviate endpoint. private readonly Uri _endpoint; @@ -106,64 +70,64 @@ public class WeaviateVectorStoreRecordCollection : IVectorStoreRecordCo private readonly string? _apiKey; /// - public string CollectionName { get; } + public string Name { get; } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// /// that is used to interact with Weaviate API. /// should point to remote or local cluster and API key can be configured via . /// It's also possible to provide these parameters via . /// - /// The name of the collection that this will access. + /// The name of the collection that this will access. /// Optional configuration options for this class. /// The collection name must start with a capital letter and contain only ASCII letters and digits. public WeaviateVectorStoreRecordCollection( HttpClient httpClient, - string collectionName, + string name, WeaviateVectorStoreRecordCollectionOptions? options = default) { // Verify. Verify.NotNull(httpClient); - VerifyCollectionName(collectionName); + VerifyCollectionName(name); + + if (typeof(TKey) != typeof(Guid) && typeof(TKey) != typeof(object)) + { + throw new NotSupportedException($"Only {nameof(Guid)} key is supported (and object for dynamic mapping)."); + } var endpoint = (options?.Endpoint ?? httpClient.BaseAddress) ?? throw new ArgumentException($"Weaviate endpoint should be provided via HttpClient.BaseAddress property or {nameof(WeaviateVectorStoreRecordCollectionOptions)} options parameter."); // Assign. this._httpClient = httpClient; this._endpoint = endpoint; - this.CollectionName = collectionName; + this.Name = name; this._options = options ?? new(); this._apiKey = this._options.ApiKey; - this._propertyReader = new VectorStoreRecordPropertyReader( - typeof(TRecord), - this._options.VectorStoreRecordDefinition, - new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true, - JsonSerializerOptions = s_jsonSerializerOptions - }); - - // Validate property types. - this._propertyReader.VerifyKeyProperties(s_supportedKeyTypes); - this._propertyReader.VerifyDataProperties(s_supportedDataTypes, supportEnumerable: true); - this._propertyReader.VerifyVectorProperties(s_supportedVectorTypes); + this._model = new WeaviateModelBuilder(this._options.HasNamedVectors) + .Build(typeof(TRecord), this._options.VectorStoreRecordDefinition, this._options.EmbeddingGenerator, s_jsonSerializerOptions); // Assign mapper. - this._mapper = this.InitializeMapper(); + this._mapper = typeof(TRecord) == typeof(Dictionary) + ? (new WeaviateDynamicDataModelMapper(this.Name, this._options.HasNamedVectors, this._model, s_jsonSerializerOptions) as IWeaviateMapper)! + : new WeaviateVectorStoreRecordMapper(this.Name, this._options.HasNamedVectors, this._model, s_jsonSerializerOptions); + + this._collectionMetadata = new() + { + VectorStoreSystemName = WeaviateConstants.VectorStoreSystemName, + CollectionName = name + }; } /// - public virtual Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) { const string OperationName = "GetCollectionSchema"; return this.RunOperationAsync(OperationName, async () => { - var request = new WeaviateGetCollectionSchemaRequest(this.CollectionName).Build(); + var request = new WeaviateGetCollectionSchemaRequest(this.Name).Build(); var response = await this .ExecuteRequestWithNotFoundHandlingAsync(request, cancellationToken) @@ -174,26 +138,25 @@ public virtual Task CollectionExistsAsync(CancellationToken cancellationTo } /// - public virtual Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) { const string OperationName = "CreateCollectionSchema"; + var schema = WeaviateVectorStoreCollectionCreateMapping.MapToSchema( + this.Name, + this._options.HasNamedVectors, + this._model); + return this.RunOperationAsync(OperationName, () => { - var schema = WeaviateVectorStoreCollectionCreateMapping.MapToSchema( - this.CollectionName, - this._propertyReader.DataProperties, - this._propertyReader.VectorProperties, - this._propertyReader.JsonPropertyNamesMap); - var request = new WeaviateCreateCollectionSchemaRequest(schema).Build(); - return this.ExecuteRequestAsync(request, cancellationToken); + return this.ExecuteRequestAsync(request, cancellationToken: cancellationToken); }); } /// - public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) { @@ -202,65 +165,87 @@ public virtual async Task CreateCollectionIfNotExistsAsync(CancellationToken can } /// - public virtual Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { const string OperationName = "DeleteCollectionSchema"; return this.RunOperationAsync(OperationName, () => { - var request = new WeaviateDeleteCollectionSchemaRequest(this.CollectionName).Build(); + var request = new WeaviateDeleteCollectionSchemaRequest(this.Name).Build(); - return this.ExecuteRequestAsync(request, cancellationToken); + return this.ExecuteRequestAsync(request, cancellationToken: cancellationToken); }); } /// - public virtual Task DeleteAsync(Guid key, CancellationToken cancellationToken = default) + public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { const string OperationName = "DeleteObject"; return this.RunOperationAsync(OperationName, () => { - var request = new WeaviateDeleteObjectRequest(this.CollectionName, key).Build(); + var guid = key switch + { + Guid g => g, + object o => (Guid)o, + _ => throw new UnreachableException("Guid key should have been validated during model building") + }; - return this.ExecuteRequestAsync(request, cancellationToken); + var request = new WeaviateDeleteObjectRequest(this.Name, guid).Build(); + + return this.ExecuteRequestAsync(request, cancellationToken: cancellationToken); }); } /// - public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) { const string OperationName = "DeleteObjectBatch"; const string ContainsAnyOperator = "ContainsAny"; + Verify.NotNull(keys); + + var stringKeys = keys.Select(key => key.ToString()).ToList(); + + if (stringKeys.Count == 0) + { + return Task.CompletedTask; + } + return this.RunOperationAsync(OperationName, () => { var match = new WeaviateQueryMatch { - CollectionName = this.CollectionName, + CollectionName = this.Name, WhereClause = new WeaviateQueryMatchWhereClause { Operator = ContainsAnyOperator, Path = [WeaviateConstants.ReservedKeyPropertyName], - Values = keys.Select(key => key.ToString()).ToList() + Values = stringKeys! } }; var request = new WeaviateDeleteObjectBatchRequest(match).Build(); - return this.ExecuteRequestAsync(request, cancellationToken); + return this.ExecuteRequestAsync(request, cancellationToken: cancellationToken); }); } /// - public virtual Task GetAsync(Guid key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "GetCollectionObject"; + var guid = key as Guid? ?? throw new InvalidCastException("Only Guid keys are supported"); + var includeVectors = options?.IncludeVectors is true; + if (includeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + return this.RunOperationAsync(OperationName, async () => { - var includeVectors = options?.IncludeVectors is true; - var request = new WeaviateGetCollectionObjectRequest(this.CollectionName, key, includeVectors).Build(); + using var request = new WeaviateGetCollectionObjectRequest(this.Name, guid, includeVectors).Build(); var jsonObject = await this.ExecuteRequestWithNotFoundHandlingAsync(request, cancellationToken).ConfigureAwait(false); @@ -270,19 +255,22 @@ public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken c } return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, + WeaviateConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, OperationName, () => this._mapper.MapFromStorageToDataModel(jsonObject!, new() { IncludeVectors = includeVectors })); }); } /// - public virtual async IAsyncEnumerable GetBatchAsync( - IEnumerable keys, + public async IAsyncEnumerable GetAsync( + IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + Verify.NotNull(keys); + var tasks = keys.Select(key => this.GetAsync(key, options, cancellationToken)); var records = await Task.WhenAll(tasks).ConfigureAwait(false); @@ -297,141 +285,311 @@ public virtual async IAsyncEnumerable GetBatchAsync( } /// - public virtual async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { - return await this.UpsertBatchAsync([record], cancellationToken) - .FirstOrDefaultAsync(cancellationToken) - .ConfigureAwait(false); + var keys = await this.UpsertAsync([record], cancellationToken).ConfigureAwait(false); + + return keys.Single(); } /// - public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) { const string OperationName = "UpsertCollectionObject"; - var responses = await this.RunOperationAsync(OperationName, async () => + Verify.NotNull(records); + + IReadOnlyList? recordsList = null; + + // If an embedding generator is defined, invoke it once per property for all records. + IReadOnlyList?[]? generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) { - var jsonObjects = records.Select(record => VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, - OperationName, - () => this._mapper.MapFromDataToStorageModel(record))).ToList(); + var vectorProperty = this._model.VectorProperties[i]; + + if (vectorProperty.EmbeddingGenerator is null) + { + continue; + } + + // We have a property with embedding generation; materialize the records' enumerable if needed, to + // prevent multiple enumeration. + if (recordsList is null) + { + recordsList = records is IReadOnlyList r ? r : records.ToList(); + if (recordsList.Count == 0) + { + return []; + } + + records = recordsList; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + if (vectorProperty.TryGenerateEmbeddings, ReadOnlyMemory>(records, cancellationToken, out var floatTask)) + { + generatedEmbeddings ??= new IReadOnlyList?[vectorPropertyCount]; + generatedEmbeddings[i] = (IReadOnlyList>)await floatTask.ConfigureAwait(false); + } + else if (vectorProperty.TryGenerateEmbeddings, ReadOnlyMemory>(records, cancellationToken, out var doubleTask)) + { + generatedEmbeddings ??= new IReadOnlyList?[vectorPropertyCount]; + generatedEmbeddings[i] = await doubleTask.ConfigureAwait(false); + } + else + { + throw new InvalidOperationException( + $"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding).Name}' for the given input type."); + } + } + + var jsonObjects = records.Select((record, i) => VectorStoreErrorHandler.RunModelConversion( + WeaviateConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + OperationName, + () => this._mapper.MapFromDataToStorageModel(record, i, generatedEmbeddings))).ToList(); + + if (jsonObjects.Count == 0) + { + return []; + } + + var responses = await this.RunOperationAsync(OperationName, async () => + { var request = new WeaviateUpsertCollectionObjectBatchRequest(jsonObjects).Build(); return await this.ExecuteRequestAsync>(request, cancellationToken).ConfigureAwait(false); }).ConfigureAwait(false); + var keys = new List(jsonObjects.Count); + if (responses is not null) { foreach (var response in responses) { if (response?.Result?.IsSuccess is true) { - yield return response.Id; + keys.Add((TKey)(object)response.Id); } } } + + return keys; } + #region Search + /// - public virtual async Task> VectorizedSearchAsync( + public async IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + switch (vectorProperty.EmbeddingGenerator) + { + case IEmbeddingGenerator> generator: + { + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + } + + case IEmbeddingGenerator> generator: + { + var embedding = await generator.GenerateEmbeddingAsync(value, new() { Dimensions = vectorProperty.Dimensions }, cancellationToken).ConfigureAwait(false); + + await foreach (var record in this.SearchCoreAsync(embedding.Vector, top, vectorProperty, operationName: "Search", options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + + yield break; + } + + case null: + throw new InvalidOperationException(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch); + + default: + throw new InvalidOperationException( + WeaviateModelBuilder.s_supportedVectorTypes.Contains(typeof(TInput)) + ? string.Format(VectorDataStrings.EmbeddingTypePassedToSearchAsync) + : string.Format(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType, typeof(TInput).Name, vectorProperty.EmbeddingGenerator.GetType().Name)); + } + } + + /// + public IAsyncEnumerable> SearchEmbeddingAsync( TVector vector, + int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + { + options ??= s_defaultVectorSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + return this.SearchCoreAsync(vector, top, vectorProperty, operationName: "SearchEmbedding", options, cancellationToken); + } + + private IAsyncEnumerable> SearchCoreAsync( + TVector vector, + int top, + VectorStoreRecordVectorPropertyModel vectorProperty, + string operationName, + VectorSearchOptions options, + CancellationToken cancellationToken = default) + where TVector : notnull { const string OperationName = "VectorSearch"; VerifyVectorParam(vector); + Verify.NotLessThan(top, 1); - var searchOptions = options ?? s_defaultVectorSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(searchOptions); - - var vectorPropertyName = this._propertyReader.GetJsonPropertyName(vectorProperty.DataModelPropertyName); - var fields = this._propertyReader.DataPropertyJsonNames; + if (options.IncludeVectors && this._model.VectorProperties.Any(p => p.EmbeddingGenerator is not null)) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } var query = WeaviateVectorStoreRecordCollectionQueryBuilder.BuildSearchQuery( vector, - this.CollectionName, - vectorPropertyName, - this._propertyReader.KeyPropertyName, + this.Name, + vectorProperty.StorageName, s_jsonSerializerOptions, - searchOptions, - this._propertyReader.JsonPropertyNamesMap, - this._propertyReader.VectorPropertyJsonNames, - this._propertyReader.DataPropertyJsonNames); + top, + options, + this._model, + this._options.HasNamedVectors); + + return this.ExecuteQueryAsync(query, options.IncludeVectors, WeaviateConstants.ScorePropertyName, OperationName, cancellationToken); + } + + /// + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + where TVector : notnull + => this.SearchEmbeddingAsync(vector, top, options, cancellationToken); + + #endregion Search - return await this.ExecuteQueryAsync(query, searchOptions.IncludeVectors, WeaviateConstants.ScorePropertyName, OperationName, cancellationToken).ConfigureAwait(false); + /// + public IAsyncEnumerable GetAsync(Expression> filter, int top, + GetFilteredRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(filter); + Verify.NotLessThan(top, 1); + + options ??= new(); + + var query = WeaviateVectorStoreRecordCollectionQueryBuilder.BuildQuery( + filter, + top, + options, + this.Name, + this._model, + this._options.HasNamedVectors); + + return this.ExecuteQueryAsync(query, options.IncludeVectors, WeaviateConstants.ScorePropertyName, "GetAsync", cancellationToken) + .SelectAsync(result => result.Record, cancellationToken: cancellationToken); } /// - public async Task> HybridSearchAsync(TVector vector, ICollection keywords, HybridSearchOptions? options = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable> HybridSearchAsync(TVector vector, ICollection keywords, int top, HybridSearchOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "HybridSearch"; VerifyVectorParam(vector); + Verify.NotLessThan(top, 1); - var searchOptions = options ?? s_defaultKeywordVectorizedHybridSearchOptions; - var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(new() { VectorProperty = searchOptions.VectorProperty }); - var textDataProperty = this._propertyReader.GetFullTextDataPropertyOrSingle(searchOptions.AdditionalProperty); - - var vectorPropertyName = this._propertyReader.GetJsonPropertyName(vectorProperty.DataModelPropertyName); - var textDataPropertyName = this._propertyReader.GetJsonPropertyName(textDataProperty.DataModelPropertyName); - var fields = this._propertyReader.DataPropertyJsonNames; + options ??= s_defaultKeywordVectorizedHybridSearchOptions; + var vectorProperty = this._model.GetVectorPropertyOrSingle(new() { VectorProperty = options.VectorProperty }); + var textDataProperty = this._model.GetFullTextDataPropertyOrSingle(options.AdditionalProperty); var query = WeaviateVectorStoreRecordCollectionQueryBuilder.BuildHybridSearchQuery( vector, + top, string.Join(" ", keywords), - this.CollectionName, - vectorPropertyName, - this._propertyReader.KeyPropertyName, - textDataPropertyName, + this.Name, + this._model, + vectorProperty, + textDataProperty, s_jsonSerializerOptions, - searchOptions, - this._propertyReader.JsonPropertyNamesMap, - this._propertyReader.VectorPropertyJsonNames, - this._propertyReader.DataPropertyJsonNames); + options, + this._options.HasNamedVectors); + + return this.ExecuteQueryAsync(query, options.IncludeVectors, WeaviateConstants.HybridScorePropertyName, OperationName, cancellationToken); + } - return await this.ExecuteQueryAsync(query, searchOptions.IncludeVectors, WeaviateConstants.HybridScorePropertyName, OperationName, cancellationToken).ConfigureAwait(false); + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreRecordCollectionMetadata) ? this._collectionMetadata : + serviceType == typeof(HttpClient) ? this._httpClient : + serviceType.IsInstanceOfType(this) ? this : + null; } #region private - private async Task> ExecuteQueryAsync(string query, bool includeVectors, string scorePropertyName, string operationName, CancellationToken cancellationToken) + private async IAsyncEnumerable> ExecuteQueryAsync(string query, bool includeVectors, string scorePropertyName, string operationName, [EnumeratorCancellation] CancellationToken cancellationToken) { using var request = new WeaviateVectorSearchRequest(query).Build(); var (responseModel, content) = await this.ExecuteRequestWithResponseContentAsync(request, cancellationToken).ConfigureAwait(false); - var collectionResults = responseModel?.Data?.GetOperation?[this.CollectionName]; + var collectionResults = responseModel?.Data?.GetOperation?[this.Name]; if (collectionResults is null) { throw new VectorStoreOperationException($"Error occurred during vector search. Response: {content}") { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = WeaviateConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } - var mappedResults = collectionResults.Where(x => x is not null).Select(result => + foreach (var result in collectionResults) { - var (storageModel, score) = WeaviateVectorStoreCollectionSearchMapping.MapSearchResult(result!, scorePropertyName); - - var record = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, - operationName, - () => this._mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors })); + if (result is not null) + { + var (storageModel, score) = WeaviateVectorStoreCollectionSearchMapping.MapSearchResult(result, scorePropertyName, this._options.HasNamedVectors); - return new VectorSearchResult(record, score); - }); + var record = VectorStoreErrorHandler.RunModelConversion( + WeaviateConstants.VectorStoreSystemName, + this._collectionMetadata.VectorStoreName, + this.Name, + operationName, + () => this._mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors })); - return new VectorSearchResults(mappedResults.ToAsyncEnumerable()); + yield return new VectorSearchResult(record, score); + } + } } - private Task ExecuteRequestAsync(HttpRequestMessage request, CancellationToken cancellationToken) + private async Task ExecuteRequestAsync( + HttpRequestMessage request, + bool ensureSuccessStatusCode = true, + CancellationToken cancellationToken = default) { request.RequestUri = new Uri(this._endpoint, request.RequestUri!); @@ -440,15 +598,24 @@ private Task ExecuteRequestAsync(HttpRequestMessage request request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", this._apiKey); } - return this._httpClient.SendWithSuccessCheckAsync(request, cancellationToken); + var response = await this._httpClient + .SendAsync(request, HttpCompletionOption.ResponseContentRead, cancellationToken) + .ConfigureAwait(false); + + if (ensureSuccessStatusCode) + { + response.EnsureSuccessStatusCode(); + } + + return response; } private async Task<(TResponse?, string)> ExecuteRequestWithResponseContentAsync(HttpRequestMessage request, CancellationToken cancellationToken) { - var response = await this.ExecuteRequestAsync(request, cancellationToken).ConfigureAwait(false); - - var responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false); + var response = await this.ExecuteRequestAsync(request, cancellationToken: cancellationToken).ConfigureAwait(false); + response.EnsureSuccessStatusCode(); + var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); var responseModel = JsonSerializer.Deserialize(responseContent, s_jsonSerializerOptions); return (responseModel, responseContent); @@ -463,14 +630,19 @@ private Task ExecuteRequestAsync(HttpRequestMessage request private async Task ExecuteRequestWithNotFoundHandlingAsync(HttpRequestMessage request, CancellationToken cancellationToken) { - try - { - return await this.ExecuteRequestAsync(request, cancellationToken).ConfigureAwait(false); - } - catch (HttpOperationException ex) when (ex.StatusCode == HttpStatusCode.NotFound) + var response = await this.ExecuteRequestAsync(request, ensureSuccessStatusCode: false, cancellationToken: cancellationToken).ConfigureAwait(false); + + if (response.StatusCode == HttpStatusCode.NotFound) { return default; } + + response.EnsureSuccessStatusCode(); + + var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + var responseModel = JsonSerializer.Deserialize(responseContent, s_jsonSerializerOptions); + + return responseModel; } private async Task RunOperationAsync(string operationName, Func> operation) @@ -483,56 +655,25 @@ private async Task RunOperationAsync(string operationName, Func> o { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, - CollectionName = this.CollectionName, + VectorStoreSystemName = WeaviateConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, OperationName = operationName }; } } - /// - /// Returns custom mapper, generic data model mapper or default record mapper. - /// - private IVectorStoreRecordMapper InitializeMapper() - { - if (this._options.JsonObjectCustomMapper is not null) - { - return this._options.JsonObjectCustomMapper; - } - - if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) - { - var mapper = new WeaviateGenericDataModelMapper( - this.CollectionName, - this._propertyReader.KeyProperty, - this._propertyReader.DataProperties, - this._propertyReader.VectorProperties, - this._propertyReader.JsonPropertyNamesMap, - s_jsonSerializerOptions); - - return (mapper as IVectorStoreRecordMapper)!; - } - - return new WeaviateVectorStoreRecordMapper( - this.CollectionName, - this._propertyReader.KeyProperty, - this._propertyReader.DataProperties, - this._propertyReader.VectorProperties, - this._propertyReader.JsonPropertyNamesMap, - s_jsonSerializerOptions); - } - private static void VerifyVectorParam(TVector vector) { Verify.NotNull(vector); var vectorType = vector.GetType(); - if (!s_supportedVectorTypes.Contains(vectorType)) + if (!WeaviateModelBuilder.s_supportedVectorTypes.Contains(vectorType)) { throw new NotSupportedException( $"The provided vector type {vectorType.FullName} is not supported by the Weaviate connector. " + - $"Supported types are: {string.Join(", ", s_supportedVectorTypes.Select(l => l.FullName))}"); + $"Supported types are: {string.Join(", ", WeaviateModelBuilder.s_supportedVectorTypes.Select(l => l.FullName))}"); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionOptions.cs index 9f812e489dcf..572203816faf 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionOptions.cs @@ -2,18 +2,20 @@ using System; using System.Text.Json.Nodes; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.Weaviate; /// -/// Options when creating a . +/// Options when creating a . /// public sealed class WeaviateVectorStoreRecordCollectionOptions { /// /// Gets or sets an optional custom mapper to use when converting between the data model and Weaviate record. /// + [Obsolete("Custom mappers are no longer supported.", error: true)] public IVectorStoreRecordMapper? JsonObjectCustomMapper { get; init; } = null; /// @@ -38,4 +40,16 @@ public sealed class WeaviateVectorStoreRecordCollectionOptions /// This parameter is optional because authentication may be disabled in local clusters for testing purposes. /// public string? ApiKey { get; set; } = null; + + /// + /// Gets or sets a value indicating whether the vectors in the store are named and multiple vectors are supported, or whether there is just a single unnamed vector in Weaviate collection. + /// Defaults to multiple named vectors. + /// . + /// + public bool HasNamedVectors { get; set; } = true; + + /// + /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs index 7ef8907e4969..c0dfd17fe124 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs @@ -6,6 +6,7 @@ using System.Linq.Expressions; using System.Text.Json; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Weaviate; @@ -22,27 +23,20 @@ public static string BuildSearchQuery( TVector vector, string collectionName, string vectorPropertyName, - string keyPropertyName, JsonSerializerOptions jsonSerializerOptions, + int top, VectorSearchOptions searchOptions, - IReadOnlyDictionary storagePropertyNames, - IReadOnlyList vectorPropertyStorageNames, - IReadOnlyList dataPropertyStorageNames) + VectorStoreRecordModel model, + bool hasNamedVectors) { - var vectorsQuery = searchOptions.IncludeVectors ? - $"vectors {{ {string.Join(" ", vectorPropertyStorageNames)} }}" : - string.Empty; + var vectorsQuery = GetVectorsPropertyQuery(searchOptions.IncludeVectors, hasNamedVectors, model); #pragma warning disable CS0618 // VectorSearchFilter is obsolete var filter = searchOptions switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => BuildLegacyFilter( - legacyFilter, - jsonSerializerOptions, - keyPropertyName, - storagePropertyNames), - { Filter: Expression> newFilter } => new WeaviateFilterTranslator().Translate(newFilter, storagePropertyNames), + { OldFilter: VectorSearchFilter legacyFilter } => BuildLegacyFilter(legacyFilter, jsonSerializerOptions, model), + { Filter: Expression> newFilter } => new WeaviateFilterTranslator().Translate(newFilter, model), _ => null }; #pragma warning restore CS0618 @@ -53,15 +47,59 @@ public static string BuildSearchQuery( { Get { {{collectionName}} ( - limit: {{searchOptions.Top}} + limit: {{top}} offset: {{searchOptions.Skip}} {{(filter is null ? "" : "where: " + filter)}} nearVector: { - targetVectors: ["{{vectorPropertyName}}"] + {{GetTargetVectorsQuery(hasNamedVectors, vectorPropertyName)}} vector: {{vectorArray}} } ) { - {{string.Join(" ", dataPropertyStorageNames)}} + {{string.Join(" ", model.DataProperties.Select(p => p.StorageName))}} + {{WeaviateConstants.AdditionalPropertiesPropertyName}} { + {{WeaviateConstants.ReservedKeyPropertyName}} + {{WeaviateConstants.ScorePropertyName}} + {{vectorsQuery}} + } + } + } + } + """; + } + + /// + /// Builds Weaviate search query. + /// More information here: . + /// + public static string BuildQuery( + Expression> filter, + int top, + GetFilteredRecordOptions queryOptions, + string collectionName, + VectorStoreRecordModel model, + bool hasNamedVectors) + { + var vectorsQuery = GetVectorsPropertyQuery(queryOptions.IncludeVectors, hasNamedVectors, model); + + var sortPaths = string.Join(",", queryOptions.OrderBy.Values.Select(sortInfo => + { + string sortPath = model.GetDataOrKeyProperty(sortInfo.PropertySelector).StorageName; + + return $$"""{ path: ["{{sortPath}}"], order: {{(sortInfo.Ascending ? "asc" : "desc")}} }"""; + })); + + var translatedFilter = new WeaviateFilterTranslator().Translate(filter, model); + + return $$""" + { + Get { + {{collectionName}} ( + limit: {{top}} + offset: {{queryOptions.Skip}} + where: {{translatedFilter}} + sort: [ {{sortPaths}} ] + ) { + {{string.Join(" ", model.DataProperties.Select(p => p.StorageName))}} {{WeaviateConstants.AdditionalPropertiesPropertyName}} { {{WeaviateConstants.ReservedKeyPropertyName}} {{WeaviateConstants.ScorePropertyName}} @@ -79,31 +117,24 @@ public static string BuildSearchQuery( /// public static string BuildHybridSearchQuery( TVector vector, + int top, string keywords, string collectionName, - string vectorPropertyName, - string keyPropertyName, - string textPropertyName, + VectorStoreRecordModel model, + VectorStoreRecordVectorPropertyModel vectorProperty, + VectorStoreRecordDataPropertyModel textProperty, JsonSerializerOptions jsonSerializerOptions, HybridSearchOptions searchOptions, - IReadOnlyDictionary storagePropertyNames, - IReadOnlyList vectorPropertyStorageNames, - IReadOnlyList dataPropertyStorageNames) + bool hasNamedVectors) { - var vectorsQuery = searchOptions.IncludeVectors ? - $"vectors {{ {string.Join(" ", vectorPropertyStorageNames)} }}" : - string.Empty; + var vectorsQuery = GetVectorsPropertyQuery(searchOptions.IncludeVectors, hasNamedVectors, model); #pragma warning disable CS0618 // VectorSearchFilter is obsolete var filter = searchOptions switch { { OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - { OldFilter: VectorSearchFilter legacyFilter } => BuildLegacyFilter( - legacyFilter, - jsonSerializerOptions, - keyPropertyName, - storagePropertyNames), - { Filter: Expression> newFilter } => new WeaviateFilterTranslator().Translate(newFilter, storagePropertyNames), + { OldFilter: VectorSearchFilter legacyFilter } => BuildLegacyFilter(legacyFilter, jsonSerializerOptions, model), + { Filter: Expression> newFilter } => new WeaviateFilterTranslator().Translate(newFilter, model), _ => null }; #pragma warning restore CS0618 @@ -114,18 +145,18 @@ public static string BuildHybridSearchQuery( { Get { {{collectionName}} ( - limit: {{searchOptions.Top}} + limit: {{top}} offset: {{searchOptions.Skip}} {{(filter is null ? "" : "where: " + filter)}} hybrid: { query: "{{keywords}}" - properties: ["{{textPropertyName}}"] - targetVectors: ["{{vectorPropertyName}}"] + properties: ["{{textProperty.StorageName}}"] + {{GetTargetVectorsQuery(hasNamedVectors, vectorProperty.StorageName)}} vector: {{vectorArray}} fusionType: rankedFusion } ) { - {{string.Join(" ", dataPropertyStorageNames)}} + {{string.Join(" ", model.DataProperties.Select(p => p.StorageName))}} {{WeaviateConstants.AdditionalPropertiesPropertyName}} { {{WeaviateConstants.ReservedKeyPropertyName}} {{WeaviateConstants.HybridScorePropertyName}} @@ -139,6 +170,23 @@ public static string BuildHybridSearchQuery( #region private + private static string GetTargetVectorsQuery(bool hasNamedVectors, string vectorPropertyName) + { + return hasNamedVectors ? $"targetVectors: [\"{vectorPropertyName}\"]" : string.Empty; + } + + private static string GetVectorsPropertyQuery( + bool includeVectors, + bool hasNamedVectors, + VectorStoreRecordModel model) + { + return includeVectors + ? hasNamedVectors + ? $"vectors {{ {string.Join(" ", model.VectorProperties.Select(p => p.StorageName))} }}" + : WeaviateConstants.ReservedSingleVectorPropertyName + : string.Empty; + } + #pragma warning disable CS0618 // Type or member is obsolete /// /// Builds filter for Weaviate search query. @@ -147,8 +195,7 @@ public static string BuildHybridSearchQuery( private static string BuildLegacyFilter( VectorSearchFilter? vectorSearchFilter, JsonSerializerOptions jsonSerializerOptions, - string keyPropertyName, - IReadOnlyDictionary storagePropertyNames) + VectorStoreRecordModel model) { const string EqualOperator = "Equal"; const string ContainsAnyOperator = "ContainsAny"; @@ -192,18 +239,12 @@ private static string BuildLegacyFilter( nameof(AnyTagEqualToFilterClause)])}"); } - string? storagePropertyName; - - if (propertyName.Equals(keyPropertyName, StringComparison.Ordinal)) - { - storagePropertyName = WeaviateConstants.ReservedKeyPropertyName; - } - else if (!storagePropertyNames.TryGetValue(propertyName, out storagePropertyName)) + if (!model.PropertyMap.TryGetValue(propertyName, out var property)) { throw new InvalidOperationException($"Property name '{propertyName}' provided as part of the filter clause is not a valid property name."); } - var operand = $$"""{ path: ["{{storagePropertyName}}"], operator: {{filterOperator}}, {{filterValueType}}: {{propertyValue}} }"""; + var operand = $$"""{ path: ["{{property.StorageName}}"], operator: {{filterOperator}}, {{filterValueType}}: {{propertyValue}} }"""; operands.Add(operand); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordMapper.cs index cb1f94a41eae..c19eb908a6ff 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordMapper.cs @@ -1,52 +1,42 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Text.Json; using System.Text.Json.Nodes; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.SemanticKernel.Connectors.Weaviate; -internal sealed class WeaviateVectorStoreRecordMapper : IVectorStoreRecordMapper +internal sealed class WeaviateVectorStoreRecordMapper : IWeaviateMapper { private readonly string _collectionName; - - private readonly string _keyProperty; - - private readonly IReadOnlyList _dataProperties; - - private readonly IReadOnlyList _vectorProperties; - - private readonly IReadOnlyDictionary _storagePropertyNames; - + private readonly bool _hasNamedVectors; + private readonly VectorStoreRecordModel _model; private readonly JsonSerializerOptions _jsonSerializerOptions; + private readonly string _vectorPropertyName; + public WeaviateVectorStoreRecordMapper( string collectionName, - VectorStoreRecordKeyProperty keyProperty, - IReadOnlyList dataProperties, - IReadOnlyList vectorProperties, - IReadOnlyDictionary storagePropertyNames, + bool hasNamedVectors, + VectorStoreRecordModel model, JsonSerializerOptions jsonSerializerOptions) { - Verify.NotNullOrWhiteSpace(collectionName); - Verify.NotNull(keyProperty); - Verify.NotNull(dataProperties); - Verify.NotNull(vectorProperties); - Verify.NotNull(storagePropertyNames); - Verify.NotNull(jsonSerializerOptions); - this._collectionName = collectionName; - this._storagePropertyNames = storagePropertyNames; + this._hasNamedVectors = hasNamedVectors; + this._model = model; this._jsonSerializerOptions = jsonSerializerOptions; - this._keyProperty = this._storagePropertyNames[keyProperty.DataModelPropertyName]; - this._dataProperties = dataProperties.Select(property => this._storagePropertyNames[property.DataModelPropertyName]).ToList(); - this._vectorProperties = vectorProperties.Select(property => this._storagePropertyNames[property.DataModelPropertyName]).ToList(); + this._vectorPropertyName = hasNamedVectors ? + WeaviateConstants.ReservedVectorPropertyName : + WeaviateConstants.ReservedSingleVectorPropertyName; } - public JsonObject MapFromDataToStorageModel(TRecord dataModel) + public JsonObject MapFromDataToStorageModel(TRecord dataModel, int recordIndex, IReadOnlyList?[]? generatedEmbeddings) { Verify.NotNull(dataModel); @@ -56,30 +46,73 @@ public JsonObject MapFromDataToStorageModel(TRecord dataModel) var weaviateObjectModel = new JsonObject { { WeaviateConstants.CollectionPropertyName, JsonValue.Create(this._collectionName) }, - { WeaviateConstants.ReservedKeyPropertyName, jsonNodeDataModel[this._keyProperty]!.DeepClone() }, + // The key property in Weaviate is always named 'id'. + // But the external JSON serializer used just above isn't aware of that, and will produce a JSON object with another name, taking into + // account e.g. naming policies. TemporaryStorageName gets populated in the model builder - containing that name - once VectorStoreModelBuildingOptions.ReservedKeyPropertyName is set + { WeaviateConstants.ReservedKeyPropertyName, jsonNodeDataModel[this._model.KeyProperty.TemporaryStorageName!]!.DeepClone() }, { WeaviateConstants.ReservedDataPropertyName, new JsonObject() }, - { WeaviateConstants.ReservedVectorPropertyName, new JsonObject() }, + { this._vectorPropertyName, new JsonObject() }, }; // Populate data properties. - foreach (var property in this._dataProperties) + foreach (var property in this._model.DataProperties) { - var node = jsonNodeDataModel[property]; + var node = jsonNodeDataModel[property.StorageName]; if (node is not null) { - weaviateObjectModel[WeaviateConstants.ReservedDataPropertyName]![property] = node.DeepClone(); + weaviateObjectModel[WeaviateConstants.ReservedDataPropertyName]![property.StorageName] = node.DeepClone(); } } // Populate vector properties. - foreach (var property in this._vectorProperties) + if (this._hasNamedVectors) { - var node = jsonNodeDataModel[property]; + for (var i = 0; i < this._model.VectorProperties.Count; i++) + { + var property = this._model.VectorProperties[i]; - if (node is not null) + if (generatedEmbeddings?[i] is IReadOnlyList e) + { + weaviateObjectModel[this._vectorPropertyName]![property.StorageName] = e[recordIndex] switch + { + Embedding fe => JsonValue.Create(fe.Vector.ToArray()), + Embedding de => JsonValue.Create(de.Vector.ToArray()), + _ => throw new UnreachableException() + }; + } + else + { + var node = jsonNodeDataModel[property.StorageName]; + + if (node is not null) + { + weaviateObjectModel[this._vectorPropertyName]![property.StorageName] = node.DeepClone(); + } + } + } + } + else + { + var property = this._model.VectorProperty; + + if (generatedEmbeddings?.Single() is IReadOnlyList e) { - weaviateObjectModel[WeaviateConstants.ReservedVectorPropertyName]![property] = node.DeepClone(); + weaviateObjectModel[this._vectorPropertyName] = e[recordIndex] switch + { + Embedding fe => JsonValue.Create(fe.Vector.ToArray()), + Embedding de => JsonValue.Create(de.Vector.ToArray()), + _ => throw new UnreachableException() + }; + } + else + { + var node = jsonNodeDataModel[property.StorageName]; + + if (node is not null) + { + weaviateObjectModel[this._vectorPropertyName] = node.DeepClone(); + } } } @@ -90,33 +123,50 @@ public TRecord MapFromStorageToDataModel(JsonObject storageModel, StorageToDataM { Verify.NotNull(storageModel); + // TemporaryStorageName gets populated in the model builder once VectorStoreModelBuildingOptions.ReservedKeyPropertyName is set + Debug.Assert(this._model.KeyProperty.TemporaryStorageName is not null); + // Transform Weaviate object model to data model. var jsonNodeDataModel = new JsonObject { - { this._keyProperty, storageModel[WeaviateConstants.ReservedKeyPropertyName]?.DeepClone() }, + // See comment above on TemporaryStorageName + { this._model.KeyProperty.TemporaryStorageName!, storageModel[WeaviateConstants.ReservedKeyPropertyName]?.DeepClone() }, }; // Populate data properties. - foreach (var property in this._dataProperties) + foreach (var property in this._model.DataProperties) { - var node = storageModel[WeaviateConstants.ReservedDataPropertyName]?[property]; + var node = storageModel[WeaviateConstants.ReservedDataPropertyName]?[property.StorageName]; if (node is not null) { - jsonNodeDataModel[property] = node.DeepClone(); + jsonNodeDataModel[property.StorageName] = node.DeepClone(); } } // Populate vector properties. if (options.IncludeVectors) { - foreach (var property in this._vectorProperties) + if (this._hasNamedVectors) + { + foreach (var property in this._model.VectorProperties) + { + var node = storageModel[this._vectorPropertyName]?[property.StorageName]; + + if (node is not null) + { + jsonNodeDataModel[property.StorageName] = node.DeepClone(); + } + } + } + else { - var node = storageModel[WeaviateConstants.ReservedVectorPropertyName]?[property]; + var vectorProperty = this._model.VectorProperty; + var node = storageModel[this._vectorPropertyName]; if (node is not null) { - jsonNodeDataModel[property] = node.DeepClone(); + jsonNodeDataModel[vectorProperty.StorageName] = node.DeepClone(); } } } diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/Connectors.MongoDB.UnitTests.csproj b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/Connectors.MongoDB.UnitTests.csproj index b8969e21943e..88918bc7e1df 100644 --- a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/Connectors.MongoDB.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/Connectors.MongoDB.UnitTests.csproj @@ -8,7 +8,8 @@ enable disable false - $(NoWarn);SKEXP0001,SKEXP0020,VSTHRD111,CA2007,CS1591 + $(NoWarn);SKEXP0001,VSTHRD111,CA2007,CS1591 + $(NoWarn);MEVD9001 diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBDynamicDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBDynamicDataModelMapperTests.cs new file mode 100644 index 000000000000..6534c3dfef0d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBDynamicDataModelMapperTests.cs @@ -0,0 +1,291 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.SemanticKernel.Connectors.MongoDB; +using MongoDB.Bson; +using Xunit; + +namespace SemanticKernel.Connectors.MongoDB.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class MongoDBDynamicDataModelMapperTests +{ + private static readonly VectorStoreRecordModel s_model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), + new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), + new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), + new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), + new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), + new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), + new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), + new VectorStoreRecordDataProperty("DecimalDataProp", typeof(decimal)), + new VectorStoreRecordDataProperty("NullableDecimalDataProp", typeof(decimal?)), + new VectorStoreRecordDataProperty("DateTimeDataProp", typeof(DateTime)), + new VectorStoreRecordDataProperty("NullableDateTimeDataProp", typeof(DateTime?)), + new VectorStoreRecordDataProperty("TagListDataProp", typeof(List)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?), 10), + new VectorStoreRecordVectorProperty("DoubleVector", typeof(ReadOnlyMemory), 10), + new VectorStoreRecordVectorProperty("NullableDoubleVector", typeof(ReadOnlyMemory?), 10) + ]); + + private static readonly float[] s_floatVector = [1.0f, 2.0f, 3.0f]; + private static readonly double[] s_doubleVector = [1.0f, 2.0f, 3.0f]; + private static readonly List s_taglist = ["tag1", "tag2"]; + + [Fact] + public void MapFromDataToStorageModelMapsAllSupportedTypes() + { + // Arrange + var sut = new MongoDBDynamicDataModelMapper(s_model); + var dataModel = new Dictionary + { + ["Key"] = "key", + + ["BoolDataProp"] = true, + ["NullableBoolDataProp"] = false, + ["StringDataProp"] = "string", + ["IntDataProp"] = 1, + ["NullableIntDataProp"] = 2, + ["LongDataProp"] = 3L, + ["NullableLongDataProp"] = 4L, + ["FloatDataProp"] = 5.0f, + ["NullableFloatDataProp"] = 6.0f, + ["DoubleDataProp"] = 7.0, + ["NullableDoubleDataProp"] = 8.0, + ["DecimalDataProp"] = 9.0m, + ["NullableDecimalDataProp"] = 10.0m, + ["DateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), + ["NullableDateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), + ["TagListDataProp"] = s_taglist, + + ["FloatVector"] = new ReadOnlyMemory(s_floatVector), + ["NullableFloatVector"] = new ReadOnlyMemory(s_floatVector), + ["DoubleVector"] = new ReadOnlyMemory(s_doubleVector), + ["NullableDoubleVector"] = new ReadOnlyMemory(s_doubleVector), + }; + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel, generatedEmbeddings: null); + + // Assert + Assert.Equal("key", storageModel["_id"]); + Assert.Equal(true, (bool?)storageModel["BoolDataProp"]); + Assert.Equal(false, (bool?)storageModel["NullableBoolDataProp"]); + Assert.Equal("string", (string?)storageModel["StringDataProp"]); + Assert.Equal(1, (int?)storageModel["IntDataProp"]); + Assert.Equal(2, (int?)storageModel["NullableIntDataProp"]); + Assert.Equal(3L, (long?)storageModel["LongDataProp"]); + Assert.Equal(4L, (long?)storageModel["NullableLongDataProp"]); + Assert.Equal(5.0f, (float?)storageModel["FloatDataProp"].AsDouble); + Assert.Equal(6.0f, (float?)storageModel["NullableFloatDataProp"].AsNullableDouble); + Assert.Equal(7.0, (double?)storageModel["DoubleDataProp"]); + Assert.Equal(8.0, (double?)storageModel["NullableDoubleDataProp"]); + Assert.Equal(9.0m, (decimal?)storageModel["DecimalDataProp"]); + Assert.Equal(10.0m, (decimal?)storageModel["NullableDecimalDataProp"]); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), storageModel["DateTimeDataProp"].ToUniversalTime()); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), storageModel["NullableDateTimeDataProp"].ToUniversalTime()); + Assert.Equal(s_taglist, storageModel["TagListDataProp"]!.AsBsonArray.Select(x => (string)x!).ToArray()); + Assert.Equal(s_floatVector, storageModel["FloatVector"]!.AsBsonArray.Select(x => (float)x.AsDouble!).ToArray()); + Assert.Equal(s_floatVector, storageModel["NullableFloatVector"]!.AsBsonArray.Select(x => (float)x.AsNullableDouble!).ToArray()); + Assert.Equal(s_doubleVector, storageModel["DoubleVector"]!.AsBsonArray.Select(x => (double)x!).ToArray()); + Assert.Equal(s_doubleVector, storageModel["NullableDoubleVector"]!.AsBsonArray.Select(x => (double)x!).ToArray()); + } + + [Fact] + public void MapFromDataToStorageModelMapsNullValues() + { + // Arrange + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?), 10) + ]); + + var dataModel = new Dictionary + { + ["Key"] = "key", + ["StringDataProp"] = null, + ["NullableIntDataProp"] = null, + ["NullableFloatVector"] = null + }; + + var sut = new MongoDBDynamicDataModelMapper(model); + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel, generatedEmbeddings: null); + + // Assert + Assert.Equal(BsonNull.Value, storageModel["StringDataProp"]); + Assert.Equal(BsonNull.Value, storageModel["NullableIntDataProp"]); + Assert.Empty(storageModel["NullableFloatVector"].AsBsonArray); + } + + [Fact] + public void MapFromStorageToDataModelMapsAllSupportedTypes() + { + // Arrange + var sut = new MongoDBDynamicDataModelMapper(s_model); + var storageModel = new BsonDocument + { + ["_id"] = "key", + ["BoolDataProp"] = true, + ["NullableBoolDataProp"] = false, + ["StringDataProp"] = "string", + ["IntDataProp"] = 1, + ["NullableIntDataProp"] = 2, + ["LongDataProp"] = 3L, + ["NullableLongDataProp"] = 4L, + ["FloatDataProp"] = 5.0f, + ["NullableFloatDataProp"] = 6.0f, + ["DoubleDataProp"] = 7.0, + ["NullableDoubleDataProp"] = 8.0, + ["DecimalDataProp"] = 9.0m, + ["NullableDecimalDataProp"] = 10.0m, + ["DateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), + ["NullableDateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), + ["TagListDataProp"] = BsonArray.Create(s_taglist), + ["FloatVector"] = BsonArray.Create(s_floatVector), + ["NullableFloatVector"] = BsonArray.Create(s_floatVector), + ["DoubleVector"] = BsonArray.Create(s_doubleVector), + ["NullableDoubleVector"] = BsonArray.Create(s_doubleVector) + }; + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal("key", dataModel["Key"]); + Assert.Equal(true, dataModel["BoolDataProp"]); + Assert.Equal(false, dataModel["NullableBoolDataProp"]); + Assert.Equal("string", dataModel["StringDataProp"]); + Assert.Equal(1, dataModel["IntDataProp"]); + Assert.Equal(2, dataModel["NullableIntDataProp"]); + Assert.Equal(3L, dataModel["LongDataProp"]); + Assert.Equal(4L, dataModel["NullableLongDataProp"]); + Assert.Equal(5.0f, dataModel["FloatDataProp"]); + Assert.Equal(6.0f, dataModel["NullableFloatDataProp"]); + Assert.Equal(7.0, dataModel["DoubleDataProp"]); + Assert.Equal(8.0, dataModel["NullableDoubleDataProp"]); + Assert.Equal(9.0m, dataModel["DecimalDataProp"]); + Assert.Equal(10.0m, dataModel["NullableDecimalDataProp"]); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), dataModel["DateTimeDataProp"]); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), dataModel["NullableDateTimeDataProp"]); + Assert.Equal(s_taglist, dataModel["TagListDataProp"]); + Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel["FloatVector"]!).ToArray()); + Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel["NullableFloatVector"]!)!.ToArray()); + Assert.Equal(s_doubleVector, ((ReadOnlyMemory)dataModel["DoubleVector"]!).ToArray()); + Assert.Equal(s_doubleVector, ((ReadOnlyMemory)dataModel["NullableDoubleVector"]!)!.ToArray()); + } + + [Fact] + public void MapFromStorageToDataModelMapsNullValues() + { + // Arrange + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?), 10) + ]); + + var storageModel = new BsonDocument + { + ["_id"] = "key", + ["StringDataProp"] = BsonNull.Value, + ["NullableIntDataProp"] = BsonNull.Value, + ["NullableFloatVector"] = BsonNull.Value + }; + + var sut = new MongoDBDynamicDataModelMapper(model); + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal("key", dataModel["Key"]); + Assert.Null(dataModel["StringDataProp"]); + Assert.Null(dataModel["NullableIntDataProp"]); + Assert.Null(dataModel["NullableFloatVector"]); + } + + [Fact] + public void MapFromStorageToDataModelThrowsForMissingKey() + { + // Arrange + var sut = new MongoDBDynamicDataModelMapper(s_model); + var storageModel = new BsonDocument(); + + // Act & Assert + var exception = Assert.Throws( + () => sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true })); + } + + [Fact] + public void MapFromDataToStorageModelSkipsMissingProperties() + { + // Arrange + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), + ]); + + var dataModel = new Dictionary { ["Key"] = "key" }; + var sut = new MongoDBDynamicDataModelMapper(model); + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel, generatedEmbeddings: null); + + // Assert + Assert.Equal("key", (string?)storageModel["_id"]); + Assert.False(storageModel.Contains("StringDataProp")); + Assert.False(storageModel.Contains("FloatVector")); + } + + [Fact] + public void MapFromStorageToDataModelSkipsMissingProperties() + { + // Arrange + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), + ]); + + var storageModel = new BsonDocument + { + ["_id"] = "key" + }; + + var sut = new MongoDBDynamicDataModelMapper(model); + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal("key", dataModel["Key"]); + Assert.False(dataModel.ContainsKey("StringDataProp")); + Assert.False(dataModel.ContainsKey("FloatVector")); + } + + private static VectorStoreRecordModel BuildModel(IReadOnlyList properties) + => new MongoDBModelBuilder().Build(typeof(Dictionary), new() { Properties = properties }, defaultEmbeddingGenerator: null); +} diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBGenericDataModelMapperTests.cs deleted file mode 100644 index 1e19af61a2f4..000000000000 --- a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBGenericDataModelMapperTests.cs +++ /dev/null @@ -1,310 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.MongoDB; -using MongoDB.Bson; -using Xunit; - -namespace SemanticKernel.Connectors.MongoDB.UnitTests; - -/// -/// Unit tests for class. -/// -public sealed class MongoDBGenericDataModelMapperTests -{ - private static readonly VectorStoreRecordDefinition s_vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), - new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), - new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), - new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), - new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), - new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), - new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), - new VectorStoreRecordDataProperty("DecimalDataProp", typeof(decimal)), - new VectorStoreRecordDataProperty("NullableDecimalDataProp", typeof(decimal?)), - new VectorStoreRecordDataProperty("DateTimeDataProp", typeof(DateTime)), - new VectorStoreRecordDataProperty("NullableDateTimeDataProp", typeof(DateTime?)), - new VectorStoreRecordDataProperty("TagListDataProp", typeof(List)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), - new VectorStoreRecordVectorProperty("DoubleVector", typeof(ReadOnlyMemory)), - new VectorStoreRecordVectorProperty("NullableDoubleVector", typeof(ReadOnlyMemory?)), - }, - }; - - private static readonly float[] s_floatVector = [1.0f, 2.0f, 3.0f]; - private static readonly double[] s_doubleVector = [1.0f, 2.0f, 3.0f]; - private static readonly List s_taglist = ["tag1", "tag2"]; - - [Fact] - public void MapFromDataToStorageModelMapsAllSupportedTypes() - { - // Arrange - var sut = new MongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); - var dataModel = new VectorStoreGenericDataModel("key") - { - Data = - { - ["BoolDataProp"] = true, - ["NullableBoolDataProp"] = false, - ["StringDataProp"] = "string", - ["IntDataProp"] = 1, - ["NullableIntDataProp"] = 2, - ["LongDataProp"] = 3L, - ["NullableLongDataProp"] = 4L, - ["FloatDataProp"] = 5.0f, - ["NullableFloatDataProp"] = 6.0f, - ["DoubleDataProp"] = 7.0, - ["NullableDoubleDataProp"] = 8.0, - ["DecimalDataProp"] = 9.0m, - ["NullableDecimalDataProp"] = 10.0m, - ["DateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), - ["NullableDateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), - ["TagListDataProp"] = s_taglist, - }, - Vectors = - { - ["FloatVector"] = new ReadOnlyMemory(s_floatVector), - ["NullableFloatVector"] = new ReadOnlyMemory(s_floatVector), - ["DoubleVector"] = new ReadOnlyMemory(s_doubleVector), - ["NullableDoubleVector"] = new ReadOnlyMemory(s_doubleVector), - }, - }; - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", storageModel["_id"]); - Assert.Equal(true, (bool?)storageModel["BoolDataProp"]); - Assert.Equal(false, (bool?)storageModel["NullableBoolDataProp"]); - Assert.Equal("string", (string?)storageModel["StringDataProp"]); - Assert.Equal(1, (int?)storageModel["IntDataProp"]); - Assert.Equal(2, (int?)storageModel["NullableIntDataProp"]); - Assert.Equal(3L, (long?)storageModel["LongDataProp"]); - Assert.Equal(4L, (long?)storageModel["NullableLongDataProp"]); - Assert.Equal(5.0f, (float?)storageModel["FloatDataProp"].AsDouble); - Assert.Equal(6.0f, (float?)storageModel["NullableFloatDataProp"].AsNullableDouble); - Assert.Equal(7.0, (double?)storageModel["DoubleDataProp"]); - Assert.Equal(8.0, (double?)storageModel["NullableDoubleDataProp"]); - Assert.Equal(9.0m, (decimal?)storageModel["DecimalDataProp"]); - Assert.Equal(10.0m, (decimal?)storageModel["NullableDecimalDataProp"]); - Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), storageModel["DateTimeDataProp"].ToUniversalTime()); - Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), storageModel["NullableDateTimeDataProp"].ToUniversalTime()); - Assert.Equal(s_taglist, storageModel["TagListDataProp"]!.AsBsonArray.Select(x => (string)x!).ToArray()); - Assert.Equal(s_floatVector, storageModel["FloatVector"]!.AsBsonArray.Select(x => (float)x.AsDouble!).ToArray()); - Assert.Equal(s_floatVector, storageModel["NullableFloatVector"]!.AsBsonArray.Select(x => (float)x.AsNullableDouble!).ToArray()); - Assert.Equal(s_doubleVector, storageModel["DoubleVector"]!.AsBsonArray.Select(x => (double)x!).ToArray()); - Assert.Equal(s_doubleVector, storageModel["NullableDoubleVector"]!.AsBsonArray.Select(x => (double)x!).ToArray()); - } - - [Fact] - public void MapFromDataToStorageModelMapsNullValues() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), - }, - }; - - var dataModel = new VectorStoreGenericDataModel("key") - { - Data = - { - ["StringDataProp"] = null, - ["NullableIntDataProp"] = null, - }, - Vectors = - { - ["NullableFloatVector"] = null, - }, - }; - - var sut = new MongoDBGenericDataModelMapper(vectorStoreRecordDefinition); - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal(BsonNull.Value, storageModel["StringDataProp"]); - Assert.Equal(BsonNull.Value, storageModel["NullableIntDataProp"]); - Assert.Empty(storageModel["NullableFloatVector"].AsBsonArray); - } - - [Fact] - public void MapFromStorageToDataModelMapsAllSupportedTypes() - { - // Arrange - var sut = new MongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); - var storageModel = new BsonDocument - { - ["_id"] = "key", - ["BoolDataProp"] = true, - ["NullableBoolDataProp"] = false, - ["StringDataProp"] = "string", - ["IntDataProp"] = 1, - ["NullableIntDataProp"] = 2, - ["LongDataProp"] = 3L, - ["NullableLongDataProp"] = 4L, - ["FloatDataProp"] = 5.0f, - ["NullableFloatDataProp"] = 6.0f, - ["DoubleDataProp"] = 7.0, - ["NullableDoubleDataProp"] = 8.0, - ["DecimalDataProp"] = 9.0m, - ["NullableDecimalDataProp"] = 10.0m, - ["DateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), - ["NullableDateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), - ["TagListDataProp"] = BsonArray.Create(s_taglist), - ["FloatVector"] = BsonArray.Create(s_floatVector), - ["NullableFloatVector"] = BsonArray.Create(s_floatVector), - ["DoubleVector"] = BsonArray.Create(s_doubleVector), - ["NullableDoubleVector"] = BsonArray.Create(s_doubleVector) - }; - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); - - // Assert - Assert.Equal("key", dataModel.Key); - Assert.Equal(true, dataModel.Data["BoolDataProp"]); - Assert.Equal(false, dataModel.Data["NullableBoolDataProp"]); - Assert.Equal("string", dataModel.Data["StringDataProp"]); - Assert.Equal(1, dataModel.Data["IntDataProp"]); - Assert.Equal(2, dataModel.Data["NullableIntDataProp"]); - Assert.Equal(3L, dataModel.Data["LongDataProp"]); - Assert.Equal(4L, dataModel.Data["NullableLongDataProp"]); - Assert.Equal(5.0f, dataModel.Data["FloatDataProp"]); - Assert.Equal(6.0f, dataModel.Data["NullableFloatDataProp"]); - Assert.Equal(7.0, dataModel.Data["DoubleDataProp"]); - Assert.Equal(8.0, dataModel.Data["NullableDoubleDataProp"]); - Assert.Equal(9.0m, dataModel.Data["DecimalDataProp"]); - Assert.Equal(10.0m, dataModel.Data["NullableDecimalDataProp"]); - Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), dataModel.Data["DateTimeDataProp"]); - Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), dataModel.Data["NullableDateTimeDataProp"]); - Assert.Equal(s_taglist, dataModel.Data["TagListDataProp"]); - Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel.Vectors["FloatVector"]!).ToArray()); - Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel.Vectors["NullableFloatVector"]!)!.ToArray()); - Assert.Equal(s_doubleVector, ((ReadOnlyMemory)dataModel.Vectors["DoubleVector"]!).ToArray()); - Assert.Equal(s_doubleVector, ((ReadOnlyMemory)dataModel.Vectors["NullableDoubleVector"]!)!.ToArray()); - } - - [Fact] - public void MapFromStorageToDataModelMapsNullValues() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), - }, - }; - - var storageModel = new BsonDocument - { - ["_id"] = "key", - ["StringDataProp"] = BsonNull.Value, - ["NullableIntDataProp"] = BsonNull.Value, - ["NullableFloatVector"] = BsonNull.Value - }; - - var sut = new MongoDBGenericDataModelMapper(vectorStoreRecordDefinition); - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); - - // Assert - Assert.Equal("key", dataModel.Key); - Assert.Null(dataModel.Data["StringDataProp"]); - Assert.Null(dataModel.Data["NullableIntDataProp"]); - Assert.Null(dataModel.Vectors["NullableFloatVector"]); - } - - [Fact] - public void MapFromStorageToDataModelThrowsForMissingKey() - { - // Arrange - var sut = new MongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); - var storageModel = new BsonDocument(); - - // Act & Assert - var exception = Assert.Throws( - () => sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true })); - } - - [Fact] - public void MapFromDataToStorageModelSkipsMissingProperties() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var dataModel = new VectorStoreGenericDataModel("key"); - var sut = new MongoDBGenericDataModelMapper(vectorStoreRecordDefinition); - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", (string?)storageModel["_id"]); - Assert.False(storageModel.Contains("StringDataProp")); - Assert.False(storageModel.Contains("FloatVector")); - } - - [Fact] - public void MapFromStorageToDataModelSkipsMissingProperties() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var storageModel = new BsonDocument - { - ["_id"] = "key" - }; - - var sut = new MongoDBGenericDataModelMapper(vectorStoreRecordDefinition); - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); - - // Assert - Assert.Equal("key", dataModel.Key); - Assert.False(dataModel.Data.ContainsKey("StringDataProp")); - Assert.False(dataModel.Vectors.ContainsKey("FloatVector")); - } -} diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBHotelModel.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBHotelModel.cs index 46374a5cc408..6313fa3cc0dd 100644 --- a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBHotelModel.cs +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBHotelModel.cs @@ -14,7 +14,7 @@ public class MongoDBHotelModel(string hotelId) public string HotelId { get; init; } = hotelId; /// A string metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string? HotelName { get; set; } /// An int metadata field. @@ -39,6 +39,6 @@ public class MongoDBHotelModel(string hotelId) public string? Description { get; set; } /// A vector field. - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineSimilarity)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineSimilarity)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBServiceCollectionExtensionsTests.cs index ac6f401583ac..1098ebfe1ee3 100644 --- a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBServiceCollectionExtensionsTests.cs @@ -82,11 +82,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs index cea02dee086c..c8c5fcf9f39b 100644 --- a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.MongoDB; using MongoDB.Bson; using Xunit; @@ -16,11 +17,18 @@ namespace SemanticKernel.Connectors.MongoDB.UnitTests; /// public sealed class MongoDBVectorStoreCollectionSearchMappingTests { - private readonly Dictionary _storagePropertyNames = new() - { - ["Property1"] = "property_1", - ["Property2"] = "property_2", - }; + private readonly VectorStoreRecordModel _model = new MongoDBModelBuilder() + .Build( + typeof(Dictionary), + new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Property1", typeof(string)) { StoragePropertyName = "property_1" }, + new VectorStoreRecordDataProperty("Property2", typeof(string)) { StoragePropertyName = "property_2" }, + ] + }, + defaultEmbeddingGenerator: null); [Fact] public void BuildFilterThrowsExceptionWithUnsupportedFilterClause() @@ -29,7 +37,7 @@ public void BuildFilterThrowsExceptionWithUnsupportedFilterClause() var vectorSearchFilter = new VectorSearchFilter().AnyTagEqualTo("NonExistentProperty", "TestValue"); // Act & Assert - Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._model)); } [Fact] @@ -39,7 +47,7 @@ public void BuildFilterThrowsExceptionWithNonExistentPropertyName() var vectorSearchFilter = new VectorSearchFilter().EqualTo("NonExistentProperty", "TestValue"); // Act & Assert - Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._model)); } [Fact] @@ -51,7 +59,7 @@ public void BuildFilterThrowsExceptionWithMultipleFilterClausesOfSameType() .EqualTo("Property1", "TestValue2"); // Act & Assert - Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._model)); } [Fact] @@ -62,8 +70,8 @@ public void BuilderFilterByDefaultReturnsValidFilter() var vectorSearchFilter = new VectorSearchFilter().EqualTo("Property1", "TestValue1"); // Act - var filter = MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames); + var filter = MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._model); - Assert.Equal(filter.ToJson(), expectedFilter.ToJson()); + Assert.Equal(expectedFilter.ToJson(), filter.ToJson()); } } diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs index ddf71955621e..89774107b140 100644 --- a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs @@ -19,7 +19,7 @@ namespace SemanticKernel.Connectors.MongoDB.UnitTests; /// -/// Unit tests for class. +/// Unit tests for class. /// public sealed class MongoDBVectorStoreRecordCollectionTests { @@ -37,7 +37,7 @@ public MongoDBVectorStoreRecordCollectionTests() public void ConstructorForModelWithoutKeyThrowsException() { // Act & Assert - var exception = Assert.Throws(() => new MongoDBVectorStoreRecordCollection(this._mockMongoDatabase.Object, "collection")); + var exception = Assert.Throws(() => new MongoDBVectorStoreRecordCollection(this._mockMongoDatabase.Object, "collection")); Assert.Contains("No key property found", exception.Message); } @@ -45,7 +45,7 @@ public void ConstructorForModelWithoutKeyThrowsException() public void ConstructorWithDeclarativeModelInitializesCollection() { // Act & Assert - var collection = new MongoDBVectorStoreRecordCollection( + var collection = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -62,7 +62,7 @@ public void ConstructorWithImperativeModelInitializesCollection() }; // Act - var collection = new MongoDBVectorStoreRecordCollection( + var collection = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection", new() { VectorStoreRecordDefinition = definition }); @@ -90,7 +90,7 @@ public async Task CollectionExistsReturnsValidResultAsync(List collectio .Setup(l => l.ListCollectionNamesAsync(It.IsAny(), It.IsAny())) .ReturnsAsync(mockCursor.Object); - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, collectionName); @@ -144,7 +144,7 @@ public async Task CreateCollectionInvokesValidMethodsAsync(bool indexExists, int .Setup(l => l.ListCollectionNamesAsync(It.IsAny(), It.IsAny())) .ReturnsAsync(mockCursor.Object); - var sut = new MongoDBVectorStoreRecordCollection(this._mockMongoDatabase.Object, CollectionName); + var sut = new MongoDBVectorStoreRecordCollection(this._mockMongoDatabase.Object, CollectionName); // Act await sut.CreateCollectionAsync(); @@ -207,7 +207,7 @@ public async Task CreateCollectionIfNotExistsInvokesValidMethodsAsync() .Setup(l => l.Indexes) .Returns(mockMongoIndexManager.Object); - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, CollectionName); @@ -231,7 +231,7 @@ public async Task DeleteInvokesValidMethodsAsync() // Arrange const string RecordKey = "key"; - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -255,7 +255,7 @@ public async Task DeleteBatchInvokesValidMethodsAsync() // Arrange List recordKeys = ["key1", "key2"]; - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -264,7 +264,7 @@ public async Task DeleteBatchInvokesValidMethodsAsync() var expectedDefinition = Builders.Filter.In(document => document["_id"].AsString, recordKeys); // Act - await sut.DeleteBatchAsync(recordKeys); + await sut.DeleteAsync(recordKeys); // Assert this._mockMongoCollection.Verify(l => l.DeleteManyAsync( @@ -279,7 +279,7 @@ public async Task DeleteCollectionInvokesValidMethodsAsync() // Arrange const string CollectionName = "collection"; - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, CollectionName); @@ -316,7 +316,7 @@ public async Task GetReturnsValidRecordAsync() It.IsAny())) .ReturnsAsync(mockCursor.Object); - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -354,12 +354,12 @@ public async Task GetBatchReturnsValidRecordAsync() It.IsAny())) .ReturnsAsync(mockCursor.Object); - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); // Act - var results = await sut.GetBatchAsync(["key1", "key2", "key3"]).ToListAsync(); + var results = await sut.GetAsync(["key1", "key2", "key3"]).ToListAsync(); // Assert Assert.NotNull(results[0]); @@ -385,7 +385,7 @@ public async Task UpsertReturnsRecordKeyAsync() var documentSerializer = serializerRegistry.GetSerializer(); var expectedDefinition = Builders.Filter.Eq(document => document["_id"], "key"); - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -413,12 +413,12 @@ public async Task UpsertBatchReturnsRecordKeysAsync() var hotel2 = new MongoDBHotelModel("key2") { HotelName = "Test Name 2" }; var hotel3 = new MongoDBHotelModel("key3") { HotelName = "Test Name 3" }; - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); // Act - var results = await sut.UpsertBatchAsync([hotel1, hotel2, hotel3]).ToListAsync(); + var results = await sut.UpsertAsync([hotel1, hotel2, hotel3]); // Assert Assert.NotNull(results); @@ -489,110 +489,32 @@ await this.TestUpsertWithModelAsync( expectedPropertyName: "bson_hotel_name"); } - [Fact] - public async Task UpsertWithCustomMapperWorksCorrectlyAsync() - { - // Arrange - var hotel = new MongoDBHotelModel("key") { HotelName = "Test Name" }; - - var mockMapper = new Mock>(); - - mockMapper - .Setup(l => l.MapFromDataToStorageModel(It.IsAny())) - .Returns(new BsonDocument { ["_id"] = "key", ["my_name"] = "Test Name" }); - - var sut = new MongoDBVectorStoreRecordCollection( - this._mockMongoDatabase.Object, - "collection", - new() { BsonDocumentCustomMapper = mockMapper.Object }); - - // Act - var result = await sut.UpsertAsync(hotel); - - // Assert - Assert.Equal("key", result); - - this._mockMongoCollection.Verify(l => l.ReplaceOneAsync( - It.IsAny>(), - It.Is(document => - document["_id"] == "key" && - document["my_name"] == "Test Name"), - It.IsAny(), - It.IsAny()), Times.Once()); - } - - [Fact] - public async Task GetWithCustomMapperWorksCorrectlyAsync() - { - // Arrange - const string RecordKey = "key"; - - var document = new BsonDocument { ["_id"] = RecordKey, ["my_name"] = "Test Name" }; - - var mockCursor = new Mock>(); - mockCursor - .Setup(l => l.MoveNextAsync(It.IsAny())) - .ReturnsAsync(true); - - mockCursor - .Setup(l => l.Current) - .Returns([document]); - - this._mockMongoCollection - .Setup(l => l.FindAsync( - It.IsAny>(), - It.IsAny>(), - It.IsAny())) - .ReturnsAsync(mockCursor.Object); - - var mockMapper = new Mock>(); - - mockMapper - .Setup(l => l.MapFromStorageToDataModel(It.IsAny(), It.IsAny())) - .Returns(new MongoDBHotelModel(RecordKey) { HotelName = "Name from mapper" }); - - var sut = new MongoDBVectorStoreRecordCollection( - this._mockMongoDatabase.Object, - "collection", - new() { BsonDocumentCustomMapper = mockMapper.Object }); - - // Act - var result = await sut.GetAsync(RecordKey); - - // Assert - Assert.NotNull(result); - Assert.Equal(RecordKey, result.HotelId); - Assert.Equal("Name from mapper", result.HotelName); - } - [Theory] - [MemberData(nameof(VectorizedSearchVectorTypeData))] - public async Task VectorizedSearchThrowsExceptionWithInvalidVectorTypeAsync(object vector, bool exceptionExpected) + [MemberData(nameof(SearchEmbeddingVectorTypeData))] + public async Task SearchEmbeddingThrowsExceptionWithInvalidVectorTypeAsync(object vector, bool exceptionExpected) { // Arrange this.MockCollectionForSearch(); - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); // Act & Assert if (exceptionExpected) { - await Assert.ThrowsAsync(async () => await sut.VectorizedSearchAsync(vector)); + await Assert.ThrowsAsync(async () => await sut.SearchEmbeddingAsync(vector, top: 3).ToListAsync()); } else { - var actual = await sut.VectorizedSearchAsync(vector); - - Assert.NotNull(actual); + Assert.NotNull(await sut.SearchEmbeddingAsync(vector, top: 3).FirstOrDefaultAsync()); } } [Theory] [InlineData("TestEmbedding1", "TestEmbedding1", 3, 3)] [InlineData("TestEmbedding2", "test_embedding_2", 4, 4)] - public async Task VectorizedSearchUsesValidQueryAsync( + public async Task SearchEmbeddingUsesValidQueryAsync( string? vectorPropertyName, string expectedVectorPropertyName, int actualTop, @@ -628,7 +550,7 @@ public async Task VectorizedSearchUsesValidQueryAsync( this.MockCollectionForSearch(); - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); @@ -640,14 +562,13 @@ public async Task VectorizedSearchUsesValidQueryAsync( }; // Act - var actual = await sut.VectorizedSearchAsync(vector, new() + var actual = await sut.SearchEmbeddingAsync(vector, top: actualTop, new() { VectorProperty = vectorSelector, - Top = actualTop, - }); + }).FirstOrDefaultAsync(); // Assert - Assert.NotNull(await actual.Results.FirstOrDefaultAsync()); + Assert.NotNull(actual); this._mockMongoCollection.Verify(l => l.AggregateAsync( It.Is>(pipeline => @@ -657,36 +578,35 @@ public async Task VectorizedSearchUsesValidQueryAsync( } [Fact] - public async Task VectorizedSearchThrowsExceptionWithNonExistentVectorPropertyNameAsync() + public async Task SearchEmbeddingThrowsExceptionWithNonExistentVectorPropertyNameAsync() { // Arrange this.MockCollectionForSearch(); - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); var options = new MEVD.VectorSearchOptions { VectorProperty = r => "non-existent-property" }; // Act & Assert - await Assert.ThrowsAsync(async () => await (await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f]), options)).Results.FirstOrDefaultAsync()); + await Assert.ThrowsAsync(async () => await sut.SearchEmbeddingAsync(new ReadOnlyMemory([1f, 2f, 3f]), top: 3, options).FirstOrDefaultAsync()); } [Fact] - public async Task VectorizedSearchReturnsRecordWithScoreAsync() + public async Task SearchEmbeddingReturnsRecordWithScoreAsync() { // Arrange this.MockCollectionForSearch(); - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection"); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f])); + var result = await sut.SearchEmbeddingAsync(new ReadOnlyMemory([1f, 2f, 3f]), top: 3).FirstOrDefaultAsync(); // Assert - var result = await actual.Results.FirstOrDefaultAsync(); Assert.NotNull(result); Assert.Equal("key", result.Record.HotelId); Assert.Equal("Test Name", result.Record.HotelName); @@ -705,7 +625,7 @@ public async Task VectorizedSearchReturnsRecordWithScoreAsync() { [], 1 } }; - public static TheoryData VectorizedSearchVectorTypeData => new() + public static TheoryData SearchEmbeddingVectorTypeData => new() { { new ReadOnlyMemory([1f, 2f, 3f]), false }, { new ReadOnlyMemory([1f, 2f, 3f]), false }, @@ -768,7 +688,7 @@ private async Task TestUpsertWithModelAsync( new() { VectorStoreRecordDefinition = definition } : null; - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( this._mockMongoDatabase.Object, "collection", options); @@ -857,11 +777,11 @@ private sealed class VectorSearchModel [VectorStoreRecordData] public string? HotelName { get; set; } - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance, IndexKind: IndexKind.IvfFlat, StoragePropertyName = "test_embedding_1")] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance, IndexKind = IndexKind.IvfFlat, StoragePropertyName = "test_embedding_1")] public ReadOnlyMemory TestEmbedding1 { get; set; } [BsonElement("test_embedding_2")] - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance, IndexKind: IndexKind.IvfFlat)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance, IndexKind = IndexKind.IvfFlat)] public ReadOnlyMemory TestEmbedding2 { get; set; } } #pragma warning restore CA1812 diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordMapperTests.cs index 65ccefcc6eee..ea8ab15172f0 100644 --- a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordMapperTests.cs @@ -28,11 +28,11 @@ public MongoDBVectorStoreRecordMapperTests() new VectorStoreRecordDataProperty("HotelName", typeof(string)), new VectorStoreRecordDataProperty("Tags", typeof(List)), new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), 10) ] }; - this._sut = new(new VectorStoreRecordPropertyReader(typeof(MongoDBHotelModel), definition, null)); + this._sut = new(new MongoDBModelBuilder().Build(typeof(MongoDBHotelModel), definition, defaultEmbeddingGenerator: null)); } [Fact] @@ -48,7 +48,7 @@ public void MapFromDataToStorageModelReturnsValidObject() }; // Act - var document = this._sut.MapFromDataToStorageModel(hotel); + var document = this._sut.MapFromDataToStorageModel(hotel, generatedEmbeddings: null); // Assert Assert.NotNull(document); diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Services/OpenAIFileService.cs b/dotnet/src/Connectors/Connectors.OpenAI/Services/OpenAIFileService.cs index 83a544920a62..e2f691713f22 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Services/OpenAIFileService.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Services/OpenAIFileService.cs @@ -121,11 +121,11 @@ public async Task GetFileContentAsync(string id, CancellationToke using (stream) { using var memoryStream = new MemoryStream(); -#if NETSTANDARD2_0 +#if NET8_0_OR_GREATER + await stream.CopyToAsync(memoryStream, cancellationToken).ConfigureAwait(false); +#else const int DefaultCopyBufferSize = 81920; await stream.CopyToAsync(memoryStream, DefaultCopyBufferSize, cancellationToken).ConfigureAwait(false); -#else - await stream.CopyToAsync(memoryStream, cancellationToken).ConfigureAwait(false); #endif return new(memoryStream.ToArray(), mimetype) diff --git a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/Connectors.Pinecone.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/Connectors.Pinecone.UnitTests.csproj index 56a1152f4a46..4e89355c5856 100644 --- a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/Connectors.Pinecone.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/Connectors.Pinecone.UnitTests.csproj @@ -8,7 +8,7 @@ enable disable false - $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0050 + $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0050 diff --git a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeClientTests.cs b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeClientTests.cs index 0426ccff1765..eaf0a040a100 100644 --- a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeClientTests.cs +++ b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeClientTests.cs @@ -18,6 +18,7 @@ public sealed class PineconeClientTests [InlineData("//bypass.com")] [InlineData("javascript:alert(1)")] [InlineData("data:text/html,")] + [Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public void ItThrowsOnEnvironmentUrlInjectionAttempt(string maliciousEnvironment) { // Arrange & Act & Assert @@ -37,6 +38,7 @@ public void ItThrowsOnEnvironmentUrlInjectionAttempt(string maliciousEnvironment [InlineData("asia-southeast-1-pncn")] [InlineData("eu-west-1-pncn")] [InlineData("northamerica-northeast1-pncn")] + [Obsolete("The IMemoryStore abstraction is being obsoleted, use Microsoft.Extensions.VectorData and PineconeVectorStore")] public void ItAcceptsValidEnvironmentNames(string validEnvironment) { // Arrange & Act & Assert diff --git a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeGenericDataModelMapperTests.cs deleted file mode 100644 index 0a96bb41cc3b..000000000000 --- a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeGenericDataModelMapperTests.cs +++ /dev/null @@ -1,341 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Pinecone; -using Pinecone; -using Xunit; - -namespace SemanticKernel.Connectors.Pinecone.UnitTests; - -/// -/// Contains tests for the class. -/// -public class PineconeGenericDataModelMapperTests -{ - private static readonly VectorStoreRecordDefinition s_singleVectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), - new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), - new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), - new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), - new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), - new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), - new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), - new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), - new VectorStoreRecordDataProperty("TagListDataProp", typeof(string[])), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - private static readonly float[] s_vector = new float[] { 1.0f, 2.0f, 3.0f }; - private static readonly string[] s_taglist = new string[] { "tag1", "tag2" }; - private const string TestKeyString = "testKey"; - - [Fact] - public void MapFromDataToStorageModelMapsAllSupportedTypes() - { - // Arrange. - var reader = new VectorStoreRecordPropertyReader( - typeof(VectorStoreGenericDataModel), - s_singleVectorStoreRecordDefinition, - new() { RequiresAtLeastOneVector = true, SupportsMultipleKeys = false, SupportsMultipleVectors = false }); - var sut = new PineconeGenericDataModelMapper(reader); - var dataModel = new VectorStoreGenericDataModel(TestKeyString) - { - Data = - { - ["StringDataProp"] = "string", - ["IntDataProp"] = 1, - ["NullableIntDataProp"] = 2, - ["LongDataProp"] = 3L, - ["NullableLongDataProp"] = 4L, - ["FloatDataProp"] = 5.0f, - ["NullableFloatDataProp"] = 6.0f, - ["DoubleDataProp"] = 7.0, - ["NullableDoubleDataProp"] = 8.0, - ["BoolDataProp"] = true, - ["NullableBoolDataProp"] = false, - ["TagListDataProp"] = s_taglist, - }, - Vectors = - { - ["FloatVector"] = new ReadOnlyMemory(s_vector), - }, - }; - - // Act. - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal(TestKeyString, storageModel.Id); - Assert.Equal("string", (string?)storageModel.Metadata!["StringDataProp"]!.Value); - // MetadataValue converts all numeric types to double. - Assert.Equal(1, (double?)storageModel.Metadata["IntDataProp"]!.Value); - Assert.Equal(2, (double?)storageModel.Metadata["NullableIntDataProp"]!.Value); - Assert.Equal(3L, (double?)storageModel.Metadata["LongDataProp"]!.Value); - Assert.Equal(4L, (double?)storageModel.Metadata["NullableLongDataProp"]!.Value); - Assert.Equal(5.0f, (double?)storageModel.Metadata["FloatDataProp"]!.Value); - Assert.Equal(6.0f, (double?)storageModel.Metadata["NullableFloatDataProp"]!.Value); - Assert.Equal(7.0, (double?)storageModel.Metadata["DoubleDataProp"]!.Value); - Assert.Equal(8.0, (double?)storageModel.Metadata["NullableDoubleDataProp"]!.Value); - Assert.Equal(true, (bool?)storageModel.Metadata["BoolDataProp"]!.Value); - Assert.Equal(false, (bool?)storageModel.Metadata["NullableBoolDataProp"]!.Value); - Assert.Equal(s_taglist, ((IEnumerable?)(storageModel.Metadata["TagListDataProp"]!.Value!)) - .Select(x => x.Value as string) - .ToArray()); - Assert.Equal(s_vector, storageModel.Values); - } - - [Fact] - public void MapFromDataToStorageModelMapsNullValues() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordDataProperty("NullableTagListDataProp", typeof(string[])), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var dataModel = new VectorStoreGenericDataModel(TestKeyString) - { - Data = - { - ["StringDataProp"] = null, - ["NullableIntDataProp"] = null, - ["NullableTagListDataProp"] = null, - }, - Vectors = - { - ["FloatVector"] = new ReadOnlyMemory(s_vector), - }, - }; - - var reader = new VectorStoreRecordPropertyReader( - typeof(VectorStoreGenericDataModel), - vectorStoreRecordDefinition, - new() { RequiresAtLeastOneVector = true, SupportsMultipleKeys = false, SupportsMultipleVectors = false }); - var sut = new PineconeGenericDataModelMapper(reader); - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal(TestKeyString, storageModel.Id); - Assert.Null(storageModel.Metadata!["StringDataProp"]); - Assert.Null(storageModel.Metadata["NullableIntDataProp"]); - Assert.Null(storageModel.Metadata["NullableTagListDataProp"]); - } - - [Fact] - public void MapFromStorageToDataModelMapsAllSupportedTypes() - { - // Arrange - var reader = new VectorStoreRecordPropertyReader( - typeof(VectorStoreGenericDataModel), - s_singleVectorStoreRecordDefinition, - new() { RequiresAtLeastOneVector = true, SupportsMultipleKeys = false, SupportsMultipleVectors = false }); - var sut = new PineconeGenericDataModelMapper(reader); - var storageModel = new Vector() - { - Id = TestKeyString, - Metadata = new Metadata() - { - ["StringDataProp"] = (MetadataValue)"string", - ["IntDataProp"] = (MetadataValue)1, - ["NullableIntDataProp"] = (MetadataValue)2, - ["LongDataProp"] = (MetadataValue)3L, - ["NullableLongDataProp"] = (MetadataValue)4L, - ["FloatDataProp"] = (MetadataValue)5.0f, - ["NullableFloatDataProp"] = (MetadataValue)6.0f, - ["DoubleDataProp"] = (MetadataValue)7.0, - ["NullableDoubleDataProp"] = (MetadataValue)8.0, - ["BoolDataProp"] = (MetadataValue)true, - ["NullableBoolDataProp"] = (MetadataValue)false, - ["TagListDataProp"] = (MetadataValue)new MetadataValue[] { "tag1", "tag2" } - }, - Values = new float[] { 1.0f, 2.0f, 3.0f } - }; - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = true }); - - // Assert - Assert.Equal(TestKeyString, dataModel.Key); - Assert.Equal("string", (string?)dataModel.Data["StringDataProp"]); - Assert.Equal(1, (int?)dataModel.Data["IntDataProp"]); - Assert.Equal(2, (int?)dataModel.Data["NullableIntDataProp"]); - Assert.Equal(3L, (long?)dataModel.Data["LongDataProp"]); - Assert.Equal(4L, (long?)dataModel.Data["NullableLongDataProp"]); - Assert.Equal(5.0f, (float?)dataModel.Data["FloatDataProp"]); - Assert.Equal(6.0f, (float?)dataModel.Data["NullableFloatDataProp"]); - Assert.Equal(7.0, (double?)dataModel.Data["DoubleDataProp"]); - Assert.Equal(8.0, (double?)dataModel.Data["NullableDoubleDataProp"]); - Assert.Equal(true, (bool?)dataModel.Data["BoolDataProp"]); - Assert.Equal(false, (bool?)dataModel.Data["NullableBoolDataProp"]); - Assert.Equal(s_taglist, (string[]?)dataModel.Data["TagListDataProp"]); - Assert.Equal(s_vector, ((ReadOnlyMemory?)dataModel.Vectors["FloatVector"])!.Value.ToArray()); - } - - [Fact] - public void MapFromStorageToDataModelMapsNullValues() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordDataProperty("NullableTagListDataProp", typeof(string[])), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var storageModel = new Vector() - { - Id = TestKeyString, - Metadata = new Metadata() - { - ["StringDataProp"] = null, - ["NullableIntDataProp"] = null, - ["NullableTagListDataProp"] = null, - }, - Values = new float[] { 1.0f, 2.0f, 3.0f } - }; - - var reader = new VectorStoreRecordPropertyReader( - typeof(VectorStoreGenericDataModel), - vectorStoreRecordDefinition, - new() { RequiresAtLeastOneVector = true, SupportsMultipleKeys = false, SupportsMultipleVectors = false }); - var sut = new PineconeGenericDataModelMapper(reader); - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = true }); - - // Assert - Assert.Equal(TestKeyString, dataModel.Key); - Assert.Null(dataModel.Data["StringDataProp"]); - Assert.Null(dataModel.Data["NullableIntDataProp"]); - Assert.Null(dataModel.Data["NullableTagListDataProp"]); - Assert.Equal(s_vector, ((ReadOnlyMemory?)dataModel.Vectors["FloatVector"])!.Value.ToArray()); - } - - [Fact] - public void MapFromDataToStorageModelThrowsForInvalidVectorType() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var reader = new VectorStoreRecordPropertyReader( - typeof(VectorStoreGenericDataModel), - vectorStoreRecordDefinition, - new() { RequiresAtLeastOneVector = true, SupportsMultipleKeys = false, SupportsMultipleVectors = false }); - var sut = new PineconeGenericDataModelMapper(reader); - - var dataModel = new VectorStoreGenericDataModel(TestKeyString) - { - Vectors = - { - ["FloatVector"] = "not a vector", - }, - }; - - // Act - var exception = Assert.Throws(() => sut.MapFromDataToStorageModel(dataModel)); - - // Assert - Assert.Equal("Vector property 'FloatVector' on provided record of type VectorStoreGenericDataModel must be of type ReadOnlyMemory and not null.", exception.Message); - } - - [Fact] - public void MapFromDataToStorageModelSkipsMissingProperties() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var reader = new VectorStoreRecordPropertyReader( - typeof(VectorStoreGenericDataModel), - vectorStoreRecordDefinition, - new() { RequiresAtLeastOneVector = true, SupportsMultipleKeys = false, SupportsMultipleVectors = false }); - var sut = new PineconeGenericDataModelMapper(reader); - - var dataModel = new VectorStoreGenericDataModel(TestKeyString) - { - Vectors = { ["FloatVector"] = new ReadOnlyMemory(s_vector) }, - }; - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal(TestKeyString, storageModel.Id); - Assert.False(storageModel.Metadata!.ContainsKey("StringDataProp")); - Assert.Equal(s_vector, storageModel.Values); - } - - [Fact] - public void MapFromStorageToDataModelSkipsMissingProperties() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var reader = new VectorStoreRecordPropertyReader( - typeof(VectorStoreGenericDataModel), - vectorStoreRecordDefinition, - new() { RequiresAtLeastOneVector = true, SupportsMultipleKeys = false, SupportsMultipleVectors = false }); - var sut = new PineconeGenericDataModelMapper(reader); - - var storageModel = new Vector() - { - Id = TestKeyString, - Values = new float[] { 1.0f, 2.0f, 3.0f } - }; - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = true }); - - // Assert - Assert.Equal(TestKeyString, dataModel.Key); - Assert.False(dataModel.Data.ContainsKey("StringDataProp")); - Assert.Equal(s_vector, ((ReadOnlyMemory?)dataModel.Vectors["FloatVector"])!.Value.ToArray()); - } -} diff --git a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeKernelBuilderExtensionsTests.cs index 21f7b6649da5..7bc63973812c 100644 --- a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeKernelBuilderExtensionsTests.cs @@ -84,11 +84,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = kernel.Services.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = kernel.Services.GetRequiredService>(); + var vectorizedSearch = kernel.Services.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeMemoryStoreTests.cs index 05e942e857da..85755510e3ff 100644 --- a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeMemoryStoreTests.cs @@ -14,6 +14,7 @@ namespace SemanticKernel.Connectors.Pinecone.UnitTests; +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class PineconeMemoryStoreTests { private readonly string _id = "Id"; diff --git a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeServiceCollectionExtensionsTests.cs index 736cc3e3839d..191fd89c52b9 100644 --- a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeServiceCollectionExtensionsTests.cs @@ -83,11 +83,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeUtilsTests.cs b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeUtilsTests.cs index 9f106c91124e..9c1a181cde92 100644 --- a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeUtilsTests.cs +++ b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeUtilsTests.cs @@ -11,6 +11,7 @@ namespace SemanticKernel.Connectors.Pinecone.UnitTests; +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class PineconeUtilsTests { [Fact] diff --git a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeVectorStoreRecordCollectionTests.cs index 0dc2620140f3..4def919f657b 100644 --- a/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Pinecone.UnitTests/PineconeVectorStoreRecordCollectionTests.cs @@ -4,14 +4,13 @@ using System.Collections.Generic; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Pinecone; -using Moq; using Xunit; using Sdk = Pinecone; namespace SemanticKernel.Connectors.Pinecone.UnitTests; /// -/// Contains tests for the class. +/// Contains tests for the class. /// public class PineconeVectorStoreRecordCollectionTests { @@ -30,18 +29,18 @@ public void CanCreateCollectionWithMismatchedDefinitionAndType() { Properties = new List { - new VectorStoreRecordKeyProperty("Id", typeof(string)), - new VectorStoreRecordDataProperty("Text", typeof(string)), - new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { Dimensions = 4 }, + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("OriginalNameData", typeof(string)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory?), 4), } }; var pineconeClient = new Sdk.PineconeClient("fake api key"); // Act. - var sut = new PineconeVectorStoreRecordCollection( + var sut = new PineconeVectorStoreRecordCollection( pineconeClient, TestCollectionName, - new() { VectorStoreRecordDefinition = definition, VectorCustomMapper = Mock.Of>() }); + new() { VectorStoreRecordDefinition = definition }); } public sealed class SinglePropsModel diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj index 5698a909022e..0f93fb8022e6 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj @@ -8,7 +8,8 @@ enable disable false - $(NoWarn);SKEXP0001,SKEXP0020,VSTHRD111,CA2007,CS1591 + $(NoWarn);SKEXP0001,VSTHRD111,CA2007,CS1591 + $(NoWarn);MEVD9001 diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs deleted file mode 100644 index d9e97fc6b855..000000000000 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Postgres; -using Pgvector; -using Xunit; - -namespace SemanticKernel.Connectors.Postgres.UnitTests; - -/// -/// Unit tests for class. -/// -public sealed class PostgresGenericDataModelMapperTests -{ - [Fact] - public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() - { - // Arrange - var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); - var dataModel = GetGenericDataModel("key"); - - var mapper = new PostgresGenericDataModelMapper(propertyReader); - - // Act - var result = mapper.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", result["Key"]); - Assert.Equal("Value1", result["StringProperty"]); - Assert.Equal(5, result["IntProperty"]); - - var vector = result["FloatVector"] as Vector; - - Assert.NotNull(vector); - Assert.True(vector.ToArray().Length > 0); - } - - [Fact] - public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() - { - // Arrange - var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); - var dataModel = GetGenericDataModel(1); - - var mapper = new PostgresGenericDataModelMapper(propertyReader); - - // Act - var result = mapper.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal(1, result["Key"]); - Assert.Equal("Value1", result["StringProperty"]); - Assert.Equal(5, result["IntProperty"]); - - var vector = result["FloatVector"] as Vector; - - Assert.NotNull(vector); - Assert.True(vector.ToArray().Length > 0); - } - - [Theory] - [InlineData(true)] - [InlineData(false)] - public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool includeVectors) - { - // Arrange - var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); - var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); - - var storageModel = new Dictionary - { - ["Key"] = "key", - ["StringProperty"] = "Value1", - ["IntProperty"] = 5, - ["FloatVector"] = storageVector - }; - - var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); - - var mapper = new PostgresGenericDataModelMapper(propertyReader); - - // Act - var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); - - // Assert - Assert.Equal("key", result.Key); - Assert.Equal("Value1", result.Data["StringProperty"]); - Assert.Equal(5, result.Data["IntProperty"]); - - if (includeVectors) - { - Assert.NotNull(result.Vectors["FloatVector"]); - Assert.Equal(vector.ToArray(), ((ReadOnlyMemory)result.Vectors["FloatVector"]!).ToArray()); - } - else - { - Assert.False(result.Vectors.ContainsKey("FloatVector")); - } - } - - [Theory] - [InlineData(true)] - [InlineData(false)] - public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool includeVectors) - { - // Arrange - var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); - var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); - - var storageModel = new Dictionary - { - ["Key"] = 1, - ["StringProperty"] = "Value1", - ["IntProperty"] = 5, - ["FloatVector"] = storageVector - }; - - var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); - - var mapper = new PostgresGenericDataModelMapper(propertyReader); - - // Act - var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); - - // Assert - Assert.Equal(1, result.Key); - Assert.Equal("Value1", result.Data["StringProperty"]); - Assert.Equal(5, result.Data["IntProperty"]); - - if (includeVectors) - { - Assert.NotNull(result.Vectors["FloatVector"]); - Assert.Equal(vector.ToArray(), ((ReadOnlyMemory)result.Vectors["FloatVector"]!).ToArray()); - } - else - { - Assert.False(result.Vectors.ContainsKey("FloatVector")); - } - } - - #region private - - private static VectorStoreRecordDefinition GetRecordDefinition() - { - return new VectorStoreRecordDefinition - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(TKey)), - new VectorStoreRecordDataProperty("StringProperty", typeof(string)), - new VectorStoreRecordDataProperty("IntProperty", typeof(int)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - } - }; - } - - private static VectorStoreGenericDataModel GetGenericDataModel(TKey key) - { - return new VectorStoreGenericDataModel(key) - { - Data = new() - { - ["StringProperty"] = "Value1", - ["IntProperty"] = 5 - }, - Vectors = new() - { - ["FloatVector"] = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]) - } - }; - } - - private static VectorStoreRecordPropertyReader GetPropertyReader(VectorStoreRecordDefinition definition) - { - return new VectorStoreRecordPropertyReader(typeof(TRecord), definition, new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true - }); - } - - #endregion -} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs index e8e84badf292..c50fd11567a9 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs @@ -45,7 +45,7 @@ public record PostgresHotel() public DateTimeOffset UpdatedAt { get; set; } = DateTimeOffset.UtcNow; /// A vector field. - [VectorStoreRecordVector(4, IndexKind.Hnsw, DistanceFunction.ManhattanDistance)] + [VectorStoreRecordVector(4, DistanceFunction = IndexKind.Hnsw, IndexKind = DistanceFunction.ManhattanDistance)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } #pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs index f667d86eee30..8bd047dab572 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs @@ -51,7 +51,7 @@ public void AddVectorStoreRecordCollectionRegistersClass() Assert.NotNull(collection); Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); Assert.IsType>(vectorizedSearch); } diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs index 60dd98f45e7a..8cc59d2aec83 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.Postgres; using Pgvector; using Xunit; @@ -26,8 +27,6 @@ public PostgresVectorStoreCollectionSqlBuilderTests(ITestOutputHelper output) [InlineData(false)] public void TestBuildCreateTableCommand(bool ifNotExists) { - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - var recordDefinition = new VectorStoreRecordDefinition() { Properties = [ @@ -38,12 +37,12 @@ public void TestBuildCreateTableCommand(bool ifNotExists) new VectorStoreRecordDataProperty("description", typeof(string)), new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, new VectorStoreRecordDataProperty("tags", typeof(List)), - new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory), 10) { Dimensions = 10, IndexKind = "hnsw", }, - new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?), 10) { Dimensions = 10, IndexKind = "hnsw", @@ -51,7 +50,9 @@ public void TestBuildCreateTableCommand(bool ifNotExists) ] }; - var cmdInfo = builder.BuildCreateTableCommand("public", "testcollection", recordDefinition.Properties, ifNotExists: ifNotExists); + var model = new VectorStoreRecordModelBuilder(PostgresConstants.ModelBuildingOptions).Build(typeof(Dictionary), recordDefinition, defaultEmbeddingGenerator: null); + + var cmdInfo = PostgresSqlBuilder.BuildCreateTableCommand("public", "testcollection", model, ifNotExists: ifNotExists); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("public.\"testcollection\" (", cmdInfo.CommandText); @@ -84,18 +85,16 @@ public void TestBuildCreateTableCommand(bool ifNotExists) [InlineData(IndexKind.Hnsw, DistanceFunction.CosineDistance, false)] public void TestBuildCreateIndexCommand(string indexKind, string distanceFunction, bool ifNotExists) { - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - var vectorColumn = "embedding1"; if (indexKind != IndexKind.Hnsw) { - Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction, ifNotExists)); - Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction, ifNotExists)); + Assert.Throws(() => PostgresSqlBuilder.BuildCreateIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction, true, ifNotExists)); + Assert.Throws(() => PostgresSqlBuilder.BuildCreateIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction, true, ifNotExists)); return; } - var cmdInfo = builder.BuildCreateVectorIndexCommand("public", "1testcollection", vectorColumn, indexKind, distanceFunction, ifNotExists); + var cmdInfo = PostgresSqlBuilder.BuildCreateIndexCommand("public", "1testcollection", vectorColumn, indexKind, distanceFunction, true, ifNotExists); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("CREATE INDEX ", cmdInfo.CommandText); @@ -133,12 +132,24 @@ public void TestBuildCreateIndexCommand(string indexKind, string distanceFunctio this._output.WriteLine(cmdInfo.CommandText); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public void TestBuildCreateNonVectorIndexCommand(bool ifNotExists) + { + var cmdInfo = PostgresSqlBuilder.BuildCreateIndexCommand("schema", "tableName", "columnName", indexKind: "", distanceFunction: "", isVector: false, ifNotExists); + + var expectedCommandText = ifNotExists + ? "CREATE INDEX IF NOT EXISTS \"tableName_columnName_index\" ON schema.\"tableName\" (\"columnName\");" + : "CREATE INDEX \"tableName_columnName_index\" ON schema.\"tableName\" (\"columnName\");"; + + Assert.Equal(expectedCommandText, cmdInfo.CommandText); + } + [Fact] public void TestBuildDropTableCommand() { - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - - var cmdInfo = builder.BuildDropTableCommand("public", "testcollection"); + var cmdInfo = PostgresSqlBuilder.BuildDropTableCommand("public", "testcollection"); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("DROP TABLE IF EXISTS public.\"testcollection\"", cmdInfo.CommandText); @@ -150,8 +161,6 @@ public void TestBuildDropTableCommand() [Fact] public void TestBuildUpsertCommand() { - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - var row = new Dictionary() { ["id"] = 123, @@ -166,7 +175,7 @@ public void TestBuildUpsertCommand() var keyColumn = "id"; - var cmdInfo = builder.BuildUpsertCommand("public", "testcollection", keyColumn, row); + var cmdInfo = PostgresSqlBuilder.BuildUpsertCommand("public", "testcollection", keyColumn, row); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("INSERT INTO public.\"testcollection\" (", cmdInfo.CommandText); @@ -191,8 +200,6 @@ public void TestBuildUpsertCommand() [Fact] public void TestBuildUpsertBatchCommand() { - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - var rows = new List>() { new() @@ -222,7 +229,7 @@ public void TestBuildUpsertBatchCommand() var keyColumn = "id"; var columnCount = rows.First().Count; - var cmdInfo = builder.BuildUpsertBatchCommand("public", "testcollection", keyColumn, rows); + var cmdInfo = PostgresSqlBuilder.BuildUpsertBatchCommand("public", "testcollection", keyColumn, rows); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("INSERT INTO public.\"testcollection\" (", cmdInfo.CommandText); @@ -251,8 +258,6 @@ public void TestBuildUpsertBatchCommand() public void TestBuildGetCommand() { // Arrange - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - var recordDefinition = new VectorStoreRecordDefinition() { Properties = [ @@ -263,23 +268,23 @@ public void TestBuildGetCommand() new VectorStoreRecordDataProperty("description", typeof(string)), new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, new VectorStoreRecordDataProperty("tags", typeof(List)), - new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory), 10) { - Dimensions = 10, IndexKind = "hnsw", }, - new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?), 10) { - Dimensions = 10, IndexKind = "hnsw", } ] }; + var model = new VectorStoreRecordModelBuilder(PostgresConstants.ModelBuildingOptions).Build(typeof(Dictionary), recordDefinition, defaultEmbeddingGenerator: null); + var key = 123; // Act - var cmdInfo = builder.BuildGetCommand("public", "testcollection", recordDefinition.Properties, key, includeVectors: true); + var cmdInfo = PostgresSqlBuilder.BuildGetCommand("public", "testcollection", model, key, includeVectors: true); // Assert Assert.Contains("SELECT", cmdInfo.CommandText); @@ -296,8 +301,6 @@ public void TestBuildGetCommand() public void TestBuildGetBatchCommand() { // Arrange - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - var recordDefinition = new VectorStoreRecordDefinition() { Properties = [ @@ -308,14 +311,12 @@ public void TestBuildGetBatchCommand() new VectorStoreRecordDataProperty("description", typeof(string)), new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, new VectorStoreRecordDataProperty("tags", typeof(List)), - new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory), 10) { - Dimensions = 10, IndexKind = "hnsw", }, - new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?), 10) { - Dimensions = 10, IndexKind = "hnsw", } ] @@ -323,8 +324,10 @@ public void TestBuildGetBatchCommand() var keys = new List { 123, 124 }; + var model = new VectorStoreRecordModelBuilder(PostgresConstants.ModelBuildingOptions).Build(typeof(Dictionary), recordDefinition, defaultEmbeddingGenerator: null); + // Act - var cmdInfo = builder.BuildGetBatchCommand("public", "testcollection", recordDefinition.Properties, keys, includeVectors: true); + var cmdInfo = PostgresSqlBuilder.BuildGetBatchCommand("public", "testcollection", model, keys, includeVectors: true); // Assert Assert.Contains("SELECT", cmdInfo.CommandText); @@ -344,12 +347,10 @@ public void TestBuildGetBatchCommand() public void TestBuildDeleteCommand() { // Arrange - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - var key = 123; // Act - var cmdInfo = builder.BuildDeleteCommand("public", "testcollection", "id", key); + var cmdInfo = PostgresSqlBuilder.BuildDeleteCommand("public", "testcollection", "id", key); // Assert Assert.Contains("DELETE", cmdInfo.CommandText); @@ -364,12 +365,10 @@ public void TestBuildDeleteCommand() public void TestBuildDeleteBatchCommand() { // Arrange - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - var keys = new List { 123, 124 }; // Act - var cmdInfo = builder.BuildDeleteBatchCommand("public", "testcollection", "id", keys); + var cmdInfo = PostgresSqlBuilder.BuildDeleteBatchCommand("public", "testcollection", "id", keys); // Assert Assert.Contains("DELETE", cmdInfo.CommandText); diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs index 0533ab28c3f3..6e071966facd 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs @@ -22,6 +22,7 @@ public class PostgresVectorStoreRecordCollectionTests public PostgresVectorStoreRecordCollectionTests() { this._postgresClientMock = new Mock(MockBehavior.Strict); + this._postgresClientMock.Setup(l => l.DatabaseName).Returns("TestDatabase"); } [Fact] @@ -32,20 +33,20 @@ public async Task CreatesCollectionForGenericModelAsync() { Properties = [ new VectorStoreRecordKeyProperty("HotelId", typeof(int)), - new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, - new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsIndexed = true, IsFullTextIndexed = true }, + new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsIndexed = true }, + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsIndexed = true, StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsIndexed = true }, new VectorStoreRecordDataProperty("Tags", typeof(List)), new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 100, DistanceFunction = DistanceFunction.ManhattanDistance } + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), 100) { DistanceFunction = DistanceFunction.ManhattanDistance } ] }; - var options = new PostgresVectorStoreRecordCollectionOptions>() + var options = new PostgresVectorStoreRecordCollectionOptions>() { VectorStoreRecordDefinition = recordDefinition }; - var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options); + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options); this._postgresClientMock.Setup(x => x.DoesTableExistsAsync(TestCollectionName, this._testCancellationToken)).ReturnsAsync(false); // Act @@ -63,16 +64,16 @@ public void ThrowsForUnsupportedType() { Properties = [ new VectorStoreRecordKeyProperty("HotelId", typeof(ulong)), - new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsIndexed = true, IsFullTextIndexed = true }, ] }; - var options = new PostgresVectorStoreRecordCollectionOptions>() + var options = new PostgresVectorStoreRecordCollectionOptions>() { VectorStoreRecordDefinition = recordDefinition }; // Act & Assert - Assert.Throws(() => new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options)); + Assert.Throws(() => new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options)); } [Fact] @@ -189,7 +190,7 @@ private sealed class TestRecord [VectorStoreRecordData] public string? Data { get; set; } - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance)] public ReadOnlyMemory? Vector { get; set; } } diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs index 11dfd2ecd564..c3791b5bff85 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.Postgres; using Pgvector; using Xunit; @@ -19,13 +20,13 @@ public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() { // Arrange var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); - var dataModel = GetDataModel("key"); + var model = GetModel>(definition); + var dataModel = GetRecord("key"); - var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + var mapper = new PostgresVectorStoreRecordMapper>(model); // Act - var result = mapper.MapFromDataToStorageModel(dataModel); + var result = mapper.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); // Assert Assert.Equal("key", result["Key"]); @@ -43,17 +44,17 @@ public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() { // Arrange - var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); - var dataModel = GetDataModel(1); + var definition = GetRecordDefinition(); + var propertyReader = GetModel>(definition); + var dataModel = GetRecord(1); - var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); // Act - var result = mapper.MapFromDataToStorageModel(dataModel); + var result = mapper.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); // Assert - Assert.Equal((ulong)1, result["Key"]); + Assert.Equal(1L, result["Key"]); Assert.Equal("Value1", result["StringProperty"]); Assert.Equal(5, result["IntProperty"]); Assert.Equal(new List { "Value2", "Value3" }, result["StringArray"]); @@ -67,7 +68,7 @@ public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() [Theory] [InlineData(true)] [InlineData(false)] - public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool includeVectors) + public void MapFromStorageToDataModelWithStringKeyReturnsValidDynamicModel(bool includeVectors) { // Arrange var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); @@ -83,7 +84,7 @@ public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool }; var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); + var propertyReader = GetModel>(definition); var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); @@ -110,7 +111,7 @@ public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool [Theory] [InlineData(true)] [InlineData(false)] - public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool includeVectors) + public void MapFromStorageToDataModelWithNumericKeyReturnsValidDynamicModel(bool includeVectors) { // Arrange var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); @@ -118,23 +119,23 @@ public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool var storageModel = new Dictionary { - ["Key"] = (ulong)1, + ["Key"] = 1L, ["StringProperty"] = "Value1", ["IntProperty"] = 5, ["StringArray"] = new List { "Value2", "Value3" }, ["FloatVector"] = storageVector }; - var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); + var definition = GetRecordDefinition(); + var propertyReader = GetModel>(definition); - var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); // Act var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); // Assert - Assert.Equal((ulong)1, result.Key); + Assert.Equal(1L, result.Key); Assert.Equal("Value1", result.StringProperty); Assert.Equal(5, result.IntProperty); Assert.Equal(new List { "Value2", "Value3" }, result.StringArray); @@ -162,12 +163,12 @@ private static VectorStoreRecordDefinition GetRecordDefinition() new VectorStoreRecordDataProperty("StringProperty", typeof(string)), new VectorStoreRecordDataProperty("IntProperty", typeof(int)), new VectorStoreRecordDataProperty("StringArray", typeof(IEnumerable)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), } }; } - private static TestRecord GetDataModel(TKey key) + private static TestRecord GetRecord(TKey key) { return new TestRecord { @@ -179,15 +180,8 @@ private static TestRecord GetDataModel(TKey key) }; } - private static VectorStoreRecordPropertyReader GetPropertyReader(VectorStoreRecordDefinition definition) - { - return new VectorStoreRecordPropertyReader(typeof(TRecord), definition, new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true - }); - } + private static VectorStoreRecordModel GetModel(VectorStoreRecordDefinition definition) + => new VectorStoreRecordModelBuilder(PostgresConstants.ModelBuildingOptions).Build(typeof(TRecord), definition, defaultEmbeddingGenerator: null); #pragma warning disable CA1812 private sealed class TestRecord @@ -204,7 +198,7 @@ private sealed class TestRecord [VectorStoreRecordData] public IEnumerable? StringArray { get; set; } - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance)] public ReadOnlyMemory? FloatVector { get; set; } } #pragma warning restore CA1812 diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs index 0631cc2c0df4..14c73e8b42de 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.Postgres; using Pgvector; using Xunit; @@ -38,21 +39,6 @@ public void MapVectorForStorageModelReturnsVector() Assert.True(storageModelVector.ToArray().Length > 0); } - [Fact] - public void MapVectorForDataModelReturnsReadOnlyMemory() - { - // Arrange - var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); - var pgVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); - - // Act - var dataModelVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(pgVector); - - // Assert - Assert.NotNull(dataModelVector); - Assert.Equal(vector.ToArray(), dataModelVector!.Value.ToArray()); - } - [Fact] public void GetPropertyValueReturnsCorrectValuesForLists() { @@ -101,32 +87,41 @@ public void GetPropertyValueReturnsCorrectNullableValue() } [Fact] - public void GetVectorIndexInfoReturnsCorrectValues() + public void GetIndexInfoReturnsCorrectValues() { // Arrange - List vectorProperties = [ - new VectorStoreRecordVectorProperty("vector1", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 1000 }, - new VectorStoreRecordVectorProperty("vector2", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Flat, Dimensions = 3000 }, - new VectorStoreRecordVectorProperty("vector3", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 900, DistanceFunction = DistanceFunction.ManhattanDistance }, + List vectorProperties = + [ + new VectorStoreRecordVectorPropertyModel("vector1", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 1000 }, + new VectorStoreRecordVectorPropertyModel("vector2", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Flat, Dimensions = 3000 }, + new VectorStoreRecordVectorPropertyModel("vector3", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 900, DistanceFunction = DistanceFunction.ManhattanDistance }, + new VectorStoreRecordDataPropertyModel("data1", typeof(string)) { IsIndexed = true }, + new VectorStoreRecordDataPropertyModel("data2", typeof(string)) { IsIndexed = false } ]; // Act - var indexInfo = PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo(vectorProperties); + var indexInfo = PostgresVectorStoreRecordPropertyMapping.GetIndexInfo(vectorProperties); // Assert - Assert.Equal(2, indexInfo.Count); - foreach (var (columnName, indexKind, distanceFunction) in indexInfo) + Assert.Equal(3, indexInfo.Count); + foreach (var (columnName, indexKind, distanceFunction, isVector) in indexInfo) { if (columnName == "vector1") { + Assert.True(isVector); Assert.Equal(IndexKind.Hnsw, indexKind); Assert.Equal(DistanceFunction.CosineDistance, distanceFunction); } else if (columnName == "vector3") { + Assert.True(isVector); Assert.Equal(IndexKind.Hnsw, indexKind); Assert.Equal(DistanceFunction.ManhattanDistance, distanceFunction); } + else if (columnName == "data1") + { + Assert.False(isVector); + } else { Assert.Fail("Unexpected column name"); @@ -139,9 +134,9 @@ public void GetVectorIndexInfoReturnsCorrectValues() public void GetVectorIndexInfoReturnsThrowsForInvalidDimensions(string indexKind, int dimensions) { // Arrange - var vectorProperty = new VectorStoreRecordVectorProperty("vector", typeof(ReadOnlyMemory?)) { IndexKind = indexKind, Dimensions = dimensions }; + var vectorProperty = new VectorStoreRecordVectorPropertyModel("vector", typeof(ReadOnlyMemory?)) { IndexKind = indexKind, Dimensions = dimensions }; // Act & Assert - Assert.Throws(() => PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo([vectorProperty])); + Assert.Throws(() => PostgresVectorStoreRecordPropertyMapping.GetIndexInfo([vectorProperty])); } } diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs index 33cfc005a7bc..8a89582fc0f8 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs @@ -26,6 +26,7 @@ public class PostgresVectorStoreTests public PostgresVectorStoreTests() { this._postgresClientMock = new Mock(MockBehavior.Strict); + this._postgresClientMock.Setup(l => l.DatabaseName).Returns("TestDatabase"); } [Fact] @@ -60,7 +61,10 @@ public void GetCollectionCallsFactoryIfProvided() var factoryMock = new Mock(MockBehavior.Strict); var collectionMock = new Mock>>(MockBehavior.Strict); var clientMock = new Mock(MockBehavior.Strict); + clientMock.Setup(x => x.DataSource).Returns(null); + clientMock.Setup(x => x.DatabaseName).Returns("TestDatabase"); + factoryMock .Setup(x => x.CreateVectorStoreRecordCollection>(It.IsAny(), TestCollectionName, null)) .Returns(collectionMock.Object); diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/Connectors.Qdrant.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/Connectors.Qdrant.UnitTests.csproj index 87782f3d2e8f..322eb096b3e2 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/Connectors.Qdrant.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/Connectors.Qdrant.UnitTests.csproj @@ -8,7 +8,8 @@ enable disable false - $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0050 + $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0050 + $(NoWarn);MEVD9001 diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantGenericDataModelMapperTests.cs deleted file mode 100644 index 9710bb3b0640..000000000000 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantGenericDataModelMapperTests.cs +++ /dev/null @@ -1,405 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.Extensions.VectorData; -using Qdrant.Client.Grpc; -using Xunit; - -namespace Microsoft.SemanticKernel.Connectors.Qdrant.UnitTests; - -/// -/// Contains tests for the class. -/// -public class QdrantGenericDataModelMapperTests -{ - private static readonly VectorStoreRecordDefinition s_singleVectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), - new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), - new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), - new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), - new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), - new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), - new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), - new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), - new VectorStoreRecordDataProperty("TagListDataProp", typeof(string[])), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - private static readonly VectorStoreRecordDefinition s_multiVectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), - new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), - new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), - new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), - new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), - new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), - new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), - new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), - new VectorStoreRecordDataProperty("TagListDataProp", typeof(string[])), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), - }, - }; - - private static readonly float[] s_vector1 = new float[] { 1.0f, 2.0f, 3.0f }; - private static readonly float[] s_vector2 = new float[] { 4.0f, 5.0f, 6.0f }; - private static readonly string[] s_taglist = new string[] { "tag1", "tag2" }; - private const string TestGuidKeyString = "11111111-1111-1111-1111-111111111111"; - private static readonly Guid s_testGuidKey = Guid.Parse(TestGuidKeyString); - - [Theory] - [InlineData(true)] - [InlineData(false)] - public void MapFromDataToStorageModelMapsAllSupportedTypes(bool hasNamedVectors) - { - // Arrange. - var reader = new VectorStoreRecordPropertyReader(typeof(VectorStoreGenericDataModel), hasNamedVectors ? s_multiVectorStoreRecordDefinition : s_singleVectorStoreRecordDefinition, null); - var sut = new QdrantGenericDataModelMapper(reader, hasNamedVectors); - var dataModel = new VectorStoreGenericDataModel(1ul) - { - Data = - { - ["StringDataProp"] = "string", - ["IntDataProp"] = 1, - ["NullableIntDataProp"] = 2, - ["LongDataProp"] = 3L, - ["NullableLongDataProp"] = 4L, - ["FloatDataProp"] = 5.0f, - ["NullableFloatDataProp"] = 6.0f, - ["DoubleDataProp"] = 7.0, - ["NullableDoubleDataProp"] = 8.0, - ["BoolDataProp"] = true, - ["NullableBoolDataProp"] = false, - ["TagListDataProp"] = s_taglist, - }, - Vectors = - { - ["FloatVector"] = new ReadOnlyMemory(s_vector1), - }, - }; - - if (hasNamedVectors) - { - dataModel.Vectors.Add("NullableFloatVector", new ReadOnlyMemory(s_vector2)); - } - - // Act. - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal(1ul, storageModel.Id.Num); - Assert.Equal("string", (string?)storageModel.Payload["StringDataProp"].StringValue); - Assert.Equal(1, (int?)storageModel.Payload["IntDataProp"].IntegerValue); - Assert.Equal(2, (int?)storageModel.Payload["NullableIntDataProp"].IntegerValue); - Assert.Equal(3L, (long?)storageModel.Payload["LongDataProp"].IntegerValue); - Assert.Equal(4L, (long?)storageModel.Payload["NullableLongDataProp"].IntegerValue); - Assert.Equal(5.0f, (float?)storageModel.Payload["FloatDataProp"].DoubleValue); - Assert.Equal(6.0f, (float?)storageModel.Payload["NullableFloatDataProp"].DoubleValue); - Assert.Equal(7.0, (double?)storageModel.Payload["DoubleDataProp"].DoubleValue); - Assert.Equal(8.0, (double?)storageModel.Payload["NullableDoubleDataProp"].DoubleValue); - Assert.Equal(true, (bool?)storageModel.Payload["BoolDataProp"].BoolValue); - Assert.Equal(false, (bool?)storageModel.Payload["NullableBoolDataProp"].BoolValue); - Assert.Equal(s_taglist, storageModel.Payload["TagListDataProp"].ListValue.Values.Select(x => x.StringValue).ToArray()); - - if (hasNamedVectors) - { - Assert.Equal(s_vector1, storageModel.Vectors.Vectors_.Vectors["FloatVector"].Data.ToArray()); - Assert.Equal(s_vector2, storageModel.Vectors.Vectors_.Vectors["NullableFloatVector"].Data.ToArray()); - } - else - { - Assert.Equal(s_vector1, storageModel.Vectors.Vector.Data.ToArray()); - } - } - - [Theory] - [InlineData(true)] - [InlineData(false)] - public void MapFromDataToStorageModelMapsNullValues(bool hasNamedVectors) - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(Guid)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordDataProperty("NullableTagListDataProp", typeof(string[])), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var dataModel = new VectorStoreGenericDataModel(s_testGuidKey) - { - Data = - { - ["StringDataProp"] = null, - ["NullableIntDataProp"] = null, - ["NullableTagListDataProp"] = null, - }, - Vectors = - { - ["FloatVector"] = new ReadOnlyMemory(s_vector1), - }, - }; - - var reader = new VectorStoreRecordPropertyReader(typeof(VectorStoreGenericDataModel), vectorStoreRecordDefinition, null); - var sut = (IVectorStoreRecordMapper, PointStruct>)new QdrantGenericDataModelMapper(reader, hasNamedVectors); - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal(TestGuidKeyString, storageModel.Id.Uuid); - Assert.True(storageModel.Payload["StringDataProp"].HasNullValue); - Assert.True(storageModel.Payload["NullableIntDataProp"].HasNullValue); - Assert.True(storageModel.Payload["NullableTagListDataProp"].HasNullValue); - } - - [Theory] - [InlineData(true)] - [InlineData(false)] - public void MapFromStorageToDataModelMapsAllSupportedTypes(bool hasNamedVectors) - { - // Arrange - var reader = new VectorStoreRecordPropertyReader(typeof(VectorStoreGenericDataModel), hasNamedVectors ? s_multiVectorStoreRecordDefinition : s_singleVectorStoreRecordDefinition, null); - var sut = new QdrantGenericDataModelMapper(reader, hasNamedVectors); - var storageModel = new PointStruct() - { - Id = new PointId() { Num = 1 }, - Payload = - { - ["StringDataProp"] = new Value() { StringValue = "string" }, - ["IntDataProp"] = new Value() { IntegerValue = 1 }, - ["NullableIntDataProp"] = new Value() { IntegerValue = 2 }, - ["LongDataProp"] = new Value() { IntegerValue = 3 }, - ["NullableLongDataProp"] = new Value() { IntegerValue = 4 }, - ["FloatDataProp"] = new Value() { DoubleValue = 5.0 }, - ["NullableFloatDataProp"] = new Value() { DoubleValue = 6.0 }, - ["DoubleDataProp"] = new Value() { DoubleValue = 7.0 }, - ["NullableDoubleDataProp"] = new Value() { DoubleValue = 8.0 }, - ["BoolDataProp"] = new Value() { BoolValue = true }, - ["NullableBoolDataProp"] = new Value() { BoolValue = false }, - ["TagListDataProp"] = new Value() - { - ListValue = new ListValue() - { - Values = - { - new Value() { StringValue = "tag1" }, - new Value() { StringValue = "tag2" }, - }, - }, - }, - }, - Vectors = new Vectors() - }; - - if (hasNamedVectors) - { - storageModel.Vectors.Vectors_ = new NamedVectors(); - storageModel.Vectors.Vectors_.Vectors.Add("FloatVector", new Vector() { Data = { 1.0f, 2.0f, 3.0f } }); - storageModel.Vectors.Vectors_.Vectors.Add("NullableFloatVector", new Vector() { Data = { 4.0f, 5.0f, 6.0f } }); - } - else - { - storageModel.Vectors.Vector = new Vector() { Data = { 1.0f, 2.0f, 3.0f } }; - } - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions()); - - // Assert - Assert.Equal(1ul, dataModel.Key); - Assert.Equal("string", (string?)dataModel.Data["StringDataProp"]); - Assert.Equal(1, (int?)dataModel.Data["IntDataProp"]); - Assert.Equal(2, (int?)dataModel.Data["NullableIntDataProp"]); - Assert.Equal(3L, (long?)dataModel.Data["LongDataProp"]); - Assert.Equal(4L, (long?)dataModel.Data["NullableLongDataProp"]); - Assert.Equal(5.0f, (float?)dataModel.Data["FloatDataProp"]); - Assert.Equal(6.0f, (float?)dataModel.Data["NullableFloatDataProp"]); - Assert.Equal(7.0, (double?)dataModel.Data["DoubleDataProp"]); - Assert.Equal(8.0, (double?)dataModel.Data["NullableDoubleDataProp"]); - Assert.Equal(true, (bool?)dataModel.Data["BoolDataProp"]); - Assert.Equal(false, (bool?)dataModel.Data["NullableBoolDataProp"]); - Assert.Equal(s_taglist, (string[]?)dataModel.Data["TagListDataProp"]); - - if (hasNamedVectors) - { - Assert.Equal(s_vector1, ((ReadOnlyMemory?)dataModel.Vectors["FloatVector"])!.Value.ToArray()); - Assert.Equal(s_vector2, ((ReadOnlyMemory?)dataModel.Vectors["NullableFloatVector"])!.Value.ToArray()); - } - else - { - Assert.Equal(s_vector1, ((ReadOnlyMemory?)dataModel.Vectors["FloatVector"])!.Value.ToArray()); - } - } - - [Theory] - [InlineData(true)] - [InlineData(false)] - public void MapFromStorageToDataModelMapsNullValues(bool hasNamedVectors) - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(Guid)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordDataProperty("NullableTagListDataProp", typeof(string[])), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var storageModel = new PointStruct() - { - Id = new PointId() { Uuid = TestGuidKeyString }, - Payload = - { - ["StringDataProp"] = new Value() { NullValue = new NullValue() }, - ["NullableIntDataProp"] = new Value() { NullValue = new NullValue() }, - ["NullableTagListDataProp"] = new Value() { NullValue = new NullValue() }, - }, - Vectors = new Vectors() - }; - - if (hasNamedVectors) - { - storageModel.Vectors.Vectors_ = new NamedVectors(); - storageModel.Vectors.Vectors_.Vectors.Add("FloatVector", new Vector() { Data = { 1.0f, 2.0f, 3.0f } }); - } - else - { - storageModel.Vectors.Vector = new Vector() { Data = { 1.0f, 2.0f, 3.0f } }; - } - - var reader = new VectorStoreRecordPropertyReader(typeof(VectorStoreGenericDataModel), vectorStoreRecordDefinition, null); - var sut = (IVectorStoreRecordMapper, PointStruct>)new QdrantGenericDataModelMapper(reader, hasNamedVectors); - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions()); - - // Assert - Assert.Equal(s_testGuidKey, dataModel.Key); - Assert.Null(dataModel.Data["StringDataProp"]); - Assert.Null(dataModel.Data["NullableIntDataProp"]); - Assert.Null(dataModel.Data["NullableTagListDataProp"]); - Assert.Equal(s_vector1, ((ReadOnlyMemory?)dataModel.Vectors["FloatVector"])!.Value.ToArray()); - } - - [Fact] - public void MapFromDataToStorageModelThrowsForInvalidVectorType() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(ulong)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var reader = new VectorStoreRecordPropertyReader(typeof(VectorStoreGenericDataModel), vectorStoreRecordDefinition, null); - var sut = new QdrantGenericDataModelMapper(reader, false); - - var dataModel = new VectorStoreGenericDataModel(1ul) - { - Vectors = - { - ["FloatVector"] = "not a vector", - }, - }; - - // Act - var exception = Assert.Throws(() => sut.MapFromDataToStorageModel(dataModel)); - - // Assert - Assert.Equal("Vector property 'FloatVector' on provided record of type VectorStoreGenericDataModel must be of type ReadOnlyMemory and not null.", exception.Message); - } - - [Fact] - public void MapFromDataToStorageModelSkipsMissingProperties() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(ulong)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var reader = new VectorStoreRecordPropertyReader(typeof(VectorStoreGenericDataModel), vectorStoreRecordDefinition, null); - var sut = new QdrantGenericDataModelMapper(reader, false); - - var dataModel = new VectorStoreGenericDataModel(1ul) - { - Vectors = { ["FloatVector"] = new ReadOnlyMemory(s_vector1) }, - }; - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal(1ul, storageModel.Id.Num); - Assert.False(storageModel.Payload.ContainsKey("StringDataProp")); - Assert.Equal(s_vector1, storageModel.Vectors.Vector.Data.ToArray()); - } - - [Fact] - public void MapFromStorageToDataModelSkipsMissingProperties() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(ulong)), - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - }, - }; - - var reader = new VectorStoreRecordPropertyReader(typeof(VectorStoreGenericDataModel), vectorStoreRecordDefinition, null); - var sut = new QdrantGenericDataModelMapper(reader, false); - - var storageModel = new PointStruct() - { - Id = new PointId() { Num = 1 }, - Vectors = new Vectors() - { - Vector = new Vector() { Data = { 1.0f, 2.0f, 3.0f } } - }, - }; - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = true }); - - // Assert - Assert.Equal(1ul, dataModel.Key); - Assert.False(dataModel.Data.ContainsKey("StringDataProp")); - Assert.Equal(s_vector1, ((ReadOnlyMemory?)dataModel.Vectors["FloatVector"])!.Value.ToArray()); - } -} diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantKernelBuilderExtensionsTests.cs index aa1d89f7b3f4..2bd0b26ccfd7 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantKernelBuilderExtensionsTests.cs @@ -104,11 +104,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = kernel.Services.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = kernel.Services.GetRequiredService>(); + var vectorizedSearch = kernel.Services.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryBuilderExtensionsTests.cs index 897a09087f09..f51abc6f2432 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryBuilderExtensionsTests.cs @@ -13,6 +13,7 @@ namespace SemanticKernel.Connectors.Qdrant.UnitTests; +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public sealed class QdrantMemoryBuilderExtensionsTests : IDisposable { private readonly HttpMessageHandlerStub _messageHandlerStub; diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests.cs index 6ae498561065..b0c7a448b668 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests.cs @@ -17,6 +17,7 @@ namespace SemanticKernel.Connectors.Qdrant.UnitTests; /// /// Tests for collection and upsert operations. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class QdrantMemoryStoreTests { private readonly string _id = "Id"; diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests2.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests2.cs index 8af2061c5d3a..8db0a7c0a839 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests2.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests2.cs @@ -16,6 +16,7 @@ namespace SemanticKernel.Connectors.Qdrant.UnitTests; /// /// Tests for Get and Remove operations. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class QdrantMemoryStoreTests2 { private readonly string _id = "Id"; diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests3.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests3.cs index ad7d54e2d5bb..7558098b2713 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests3.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantMemoryStoreTests3.cs @@ -20,6 +20,7 @@ namespace SemanticKernel.Connectors.Qdrant.UnitTests; /// /// Tests for Search operations. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class QdrantMemoryStoreTests3 { private readonly string _id = "Id"; diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantServiceCollectionExtensionsTests.cs index 96985961aa60..5e07567030f9 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantServiceCollectionExtensionsTests.cs @@ -104,11 +104,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorDbClientTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorDbClientTests.cs index 41a95178a588..bf3a52f7683b 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorDbClientTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorDbClientTests.cs @@ -8,6 +8,7 @@ namespace SemanticKernel.Connectors.Qdrant.UnitTests; +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public sealed class QdrantVectorDbClientTests : IDisposable { private readonly HttpMessageHandlerStub _messageHandlerStub; diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionCreateMappingTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionCreateMappingTests.cs index 9dbcec1c88b3..f5ec56fff9a7 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionCreateMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionCreateMappingTests.cs @@ -1,8 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Qdrant.Client.Grpc; using Xunit; @@ -17,7 +17,7 @@ public class QdrantVectorStoreCollectionCreateMappingTests public void MapSingleVectorCreatesVectorParams() { // Arrange. - var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 4, DistanceFunction = DistanceFunction.DotProductSimilarity }; + var vectorProperty = new VectorStoreRecordVectorPropertyModel("testvector", typeof(ReadOnlyMemory)) { Dimensions = 4, DistanceFunction = DistanceFunction.DotProductSimilarity }; // Act. var actual = QdrantVectorStoreCollectionCreateMapping.MapSingleVector(vectorProperty); @@ -32,7 +32,7 @@ public void MapSingleVectorCreatesVectorParams() public void MapSingleVectorDefaultsToCosine() { // Arrange. - var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 4 }; + var vectorProperty = new VectorStoreRecordVectorPropertyModel("testvector", typeof(ReadOnlyMemory)) { Dimensions = 4 }; // Act. var actual = QdrantVectorStoreCollectionCreateMapping.MapSingleVector(vectorProperty); @@ -45,19 +45,7 @@ public void MapSingleVectorDefaultsToCosine() public void MapSingleVectorThrowsForUnsupportedDistanceFunction() { // Arrange. - var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 4, DistanceFunction = DistanceFunction.CosineDistance }; - - // Act and assert. - Assert.Throws(() => QdrantVectorStoreCollectionCreateMapping.MapSingleVector(vectorProperty)); - } - - [Theory] - [InlineData(null)] - [InlineData(0)] - public void MapSingleVectorThrowsIfDimensionsIsInvalid(int? dimensions) - { - // Arrange. - var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = dimensions }; + var vectorProperty = new VectorStoreRecordVectorPropertyModel("testvector", typeof(ReadOnlyMemory)) { Dimensions = 4, DistanceFunction = DistanceFunction.CosineDistance }; // Act and assert. Assert.Throws(() => QdrantVectorStoreCollectionCreateMapping.MapSingleVector(vectorProperty)); @@ -67,20 +55,23 @@ public void MapSingleVectorThrowsIfDimensionsIsInvalid(int? dimensions) public void MapNamedVectorsCreatesVectorParamsMap() { // Arrange. - var vectorProperties = new VectorStoreRecordVectorProperty[] - { - new("testvector1", typeof(ReadOnlyMemory)) { Dimensions = 10, DistanceFunction = DistanceFunction.EuclideanDistance }, - new("testvector2", typeof(ReadOnlyMemory)) { Dimensions = 20 } - }; - - var storagePropertyNames = new Dictionary + var vectorProperties = new VectorStoreRecordVectorPropertyModel[] { - { "testvector1", "storage_testvector1" }, - { "testvector2", "storage_testvector2" } + new("testvector1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + DistanceFunction = DistanceFunction.EuclideanDistance, + StorageName = "storage_testvector1" + }, + new("testvector2", typeof(ReadOnlyMemory)) + { + Dimensions = 20, + StorageName = "storage_testvector2" + } }; // Act. - var actual = QdrantVectorStoreCollectionCreateMapping.MapNamedVectors(vectorProperties, storagePropertyNames); + var actual = QdrantVectorStoreCollectionCreateMapping.MapNamedVectors(vectorProperties); // Assert. Assert.NotNull(actual); diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs index 638bc2cbf861..c2bdb2f76e24 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs @@ -4,7 +4,7 @@ using System.Collections.Generic; using System.Linq; using Microsoft.Extensions.VectorData; -using Moq; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Qdrant.Client.Grpc; using Xunit; @@ -17,6 +17,21 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant.UnitTests; /// public class QdrantVectorStoreCollectionSearchMappingTests { + private readonly VectorStoreRecordModel _model = + new VectorStoreRecordModelBuilder(QdrantVectorStoreRecordFieldMapping.GetModelBuildOptions(hasNamedVectors: false)) + .Build( + typeof(Dictionary), + new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)) { StoragePropertyName = "storage_key" }, + new VectorStoreRecordDataProperty("FieldName", typeof(string)) { StoragePropertyName = "storage_FieldName" }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) { StoragePropertyName = "storage_vector" }, + ] + }, + defaultEmbeddingGenerator: null); + [Theory] [InlineData("string")] [InlineData("int")] @@ -37,7 +52,7 @@ public void BuildFilterMapsEqualityClause(string type) var filter = new VectorSearchFilter().EqualTo("FieldName", expected); // Act. - var actual = QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, new Dictionary() { { "FieldName", "storage_FieldName" } }); + var actual = QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, this._model); // Assert. Assert.Single(actual.Must); @@ -71,7 +86,7 @@ public void BuildFilterMapsTagContainsClause() var filter = new VectorSearchFilter().AnyTagEqualTo("FieldName", "Value"); // Act. - var actual = QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, new Dictionary() { { "FieldName", "storage_FieldName" } }); + var actual = QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, this._model); // Assert. Assert.Single(actual.Must); @@ -83,10 +98,10 @@ public void BuildFilterMapsTagContainsClause() public void BuildFilterThrowsForUnknownFieldName() { // Arrange. - var filter = new VectorSearchFilter().EqualTo("FieldName", "Value"); + var filter = new VectorSearchFilter().EqualTo("UnknownFieldName", "Value"); // Act and Assert. - Assert.Throws(() => QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, new Dictionary())); + Assert.Throws(() => QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, this._model)); } [Fact] @@ -103,11 +118,24 @@ public void MapScoredPointToVectorSearchResultMapsResults() Score = 0.5f }; - var mapperMock = new Mock>(MockBehavior.Strict); - mapperMock.Setup(x => x.MapFromStorageToDataModel(It.IsAny(), It.IsAny())).Returns(new DataModel { Id = 1, DataField = "data 1", Embedding = new float[] { 1, 2, 3 } }); + var model = new VectorStoreRecordModelBuilder(QdrantVectorStoreRecordFieldMapping.GetModelBuildOptions(hasNamedVectors: false)) + .Build( + typeof(DataModel), + new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Id", typeof(ulong)), + new VectorStoreRecordDataProperty("DataField", typeof(string)) { StoragePropertyName = "storage_DataField" }, + new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory), 10), + ] + }, + defaultEmbeddingGenerator: null); + + var mapper = new QdrantVectorStoreRecordMapper(model, hasNamedVectors: false); // Act. - var actual = QdrantVectorStoreCollectionSearchMapping.MapScoredPointToVectorSearchResult(scoredPoint, mapperMock.Object, true, "Qdrant", "mycollection", "query"); + var actual = QdrantVectorStoreCollectionSearchMapping.MapScoredPointToVectorSearchResult(scoredPoint, mapper, true, "Qdrant", "myvectorstore", "mycollection", "query"); // Assert. Assert.Equal(1ul, actual.Record.Id); diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs index 216828137953..5a47b0b80a37 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs @@ -14,7 +14,7 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant.UnitTests; /// -/// Contains tests for the class. +/// Contains tests for the class. /// public class QdrantVectorStoreRecordCollectionTests { @@ -39,7 +39,7 @@ public QdrantVectorStoreRecordCollectionTests() public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) { // Arrange. - var sut = new QdrantVectorStoreRecordCollection>(this._qdrantClientMock.Object, collectionName); + var sut = new QdrantVectorStoreRecordCollection>(this._qdrantClientMock.Object, collectionName); this._qdrantClientMock .Setup(x => x.CollectionExistsAsync( @@ -58,7 +58,7 @@ public async Task CollectionExistsReturnsCollectionStateAsync(string collectionN public async Task CanCreateCollectionAsync() { // Arrange. - var sut = new QdrantVectorStoreRecordCollection>(this._qdrantClientMock.Object, TestCollectionName); + var sut = new QdrantVectorStoreRecordCollection>(this._qdrantClientMock.Object, TestCollectionName); this._qdrantClientMock .Setup(x => x.CreateCollectionAsync( @@ -119,7 +119,7 @@ public async Task CanCreateCollectionAsync() public async Task CanDeleteCollectionAsync() { // Arrange. - var sut = new QdrantVectorStoreRecordCollection>(this._qdrantClientMock.Object, TestCollectionName); + var sut = new QdrantVectorStoreRecordCollection>(this._qdrantClientMock.Object, TestCollectionName); this._qdrantClientMock .Setup(x => x.DeleteCollectionAsync( @@ -226,7 +226,7 @@ public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition, bo this.SetupRetrieveMock(testRecordKeys.Select(x => CreateRetrievedPoint(hasNamedVectors, x)).ToList()); // Act. - var actual = await sut.GetBatchAsync( + var actual = await sut.GetAsync( testRecordKeys, new() { IncludeVectors = true }, this._testCancellationToken).ToListAsync(); @@ -253,52 +253,6 @@ public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition, bo Assert.Equal(testRecordKeys[1], actual[1].Key); } - [Fact] - public async Task CanGetRecordWithCustomMapperAsync() - { - // Arrange. - var retrievedPoint = CreateRetrievedPoint(true, UlongTestRecordKey1); - this.SetupRetrieveMock([retrievedPoint]); - - // Arrange mapper mock from PointStruct to data model. - var mapperMock = new Mock, PointStruct>>(MockBehavior.Strict); - mapperMock.Setup( - x => x.MapFromStorageToDataModel( - It.IsAny(), - It.IsAny())) - .Returns(CreateModel(UlongTestRecordKey1, true)); - - // Arrange target with custom mapper. - var sut = new QdrantVectorStoreRecordCollection>( - this._qdrantClientMock.Object, - TestCollectionName, - new() - { - HasNamedVectors = true, - PointStructCustomMapper = mapperMock.Object - }); - - // Act - var actual = await sut.GetAsync( - UlongTestRecordKey1, - new() { IncludeVectors = true }, - this._testCancellationToken); - - // Assert - Assert.NotNull(actual); - Assert.Equal(UlongTestRecordKey1, actual.Key); - Assert.Equal("data 1", actual.OriginalNameData); - Assert.Equal("data 1", actual.Data); - Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector!.Value.ToArray()); - - mapperMock - .Verify( - x => x.MapFromStorageToDataModel( - It.Is(x => x.Id.Num == UlongTestRecordKey1), - It.Is(x => x.IncludeVectors)), - Times.Once); - } - [Theory] [InlineData(true, true)] [InlineData(true, false)] @@ -369,7 +323,7 @@ public async Task CanDeleteManyUlongRecordsAsync(bool useDefinition, bool hasNam this.SetupDeleteMocks(); // Act - await sut.DeleteBatchAsync( + await sut.DeleteAsync( [UlongTestRecordKey1, UlongTestRecordKey2], cancellationToken: this._testCancellationToken); @@ -398,7 +352,7 @@ public async Task CanDeleteManyGuidRecordsAsync(bool useDefinition, bool hasName this.SetupDeleteMocks(); // Act - await sut.DeleteBatchAsync( + await sut.DeleteAsync( [s_guidTestRecordKey1, s_guidTestRecordKey2], cancellationToken: this._testCancellationToken); @@ -454,9 +408,9 @@ public async Task CanUpsertManyRecordsAsync(bool useDefinition, bool hasNa var models = testRecordKeys.Select(x => CreateModel(x, true)); // Act - var actual = await sut.UpsertBatchAsync( + var actual = await sut.UpsertAsync( models, - cancellationToken: this._testCancellationToken).ToListAsync(); + cancellationToken: this._testCancellationToken); // Assert Assert.NotNull(actual); @@ -479,46 +433,6 @@ public async Task CanUpsertManyRecordsAsync(bool useDefinition, bool hasNa Times.Once); } - [Fact] - public async Task CanUpsertRecordWithCustomMapperAsync() - { - // Arrange. - this.SetupUpsertMock(); - var pointStruct = new PointStruct - { - Id = new() { Num = UlongTestRecordKey1 }, - Payload = { ["OriginalNameData"] = "data 1", ["data_storage_name"] = "data 1" }, - Vectors = new[] { 1f, 2f, 3f, 4f } - }; - - // Arrange mapper mock from data model to PointStruct. - var mapperMock = new Mock, PointStruct>>(MockBehavior.Strict); - mapperMock - .Setup(x => x.MapFromDataToStorageModel(It.IsAny>())) - .Returns(pointStruct); - - // Arrange target with custom mapper. - var sut = new QdrantVectorStoreRecordCollection>( - this._qdrantClientMock.Object, - TestCollectionName, - new() - { - HasNamedVectors = false, - PointStructCustomMapper = mapperMock.Object - }); - - var model = CreateModel(UlongTestRecordKey1, true); - - // Act - await sut.UpsertAsync(model, this._testCancellationToken); - - // Assert - mapperMock - .Verify( - x => x.MapFromDataToStorageModel(It.Is>(x => x == model)), - Times.Once); - } - /// /// Tests that the collection can be created even if the definition and the type do not match. /// In this case, the expectation is that a custom mapper will be provided to map between the @@ -532,17 +446,17 @@ public void CanCreateCollectionWithMismatchedDefinitionAndType() { Properties = new List { - new VectorStoreRecordKeyProperty("Id", typeof(ulong)), - new VectorStoreRecordDataProperty("Text", typeof(string)), - new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { Dimensions = 4 }, + new VectorStoreRecordKeyProperty(nameof(SinglePropsModel.Key), typeof(ulong)), + new VectorStoreRecordDataProperty(nameof(SinglePropsModel.OriginalNameData), typeof(string)), + new VectorStoreRecordVectorProperty(nameof(SinglePropsModel.Vector), typeof(ReadOnlyMemory?), 4), } }; // Act. - var sut = new QdrantVectorStoreRecordCollection>( + var sut = new QdrantVectorStoreRecordCollection>( this._qdrantClientMock.Object, TestCollectionName, - new() { VectorStoreRecordDefinition = definition, PointStructCustomMapper = Mock.Of, PointStruct>>() }); + new() { VectorStoreRecordDefinition = definition }); } #pragma warning disable CS0618 // VectorSearchFilter is obsolete @@ -559,10 +473,11 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bo var filter = new VectorSearchFilter().EqualTo(nameof(SinglePropsModel.Data), "data 1"); // Act. - var actual = await sut.VectorizedSearchAsync( + var results = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new[] { 1f, 2f, 3f, 4f }), - new() { IncludeVectors = true, OldFilter = filter, Top = 5, Skip = 2 }, - this._testCancellationToken); + top: 5, + new() { IncludeVectors = true, OldFilter = filter, Skip = 2 }, + this._testCancellationToken).ToListAsync(); // Assert. this._qdrantClientMock @@ -586,7 +501,6 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bo this._testCancellationToken), Times.Once); - var results = await actual.Results.ToListAsync(); Assert.Single(results); Assert.Equal(testRecordKey, results.First().Record.Key); Assert.Equal("data 1", results.First().Record.OriginalNameData); @@ -768,7 +682,7 @@ private static ScoredPoint CreateScoredPoint(bool hasNamedVectors, TKey re private IVectorStoreRecordCollection> CreateRecordCollection(bool useDefinition, bool hasNamedVectors) where T : notnull { - var store = new QdrantVectorStoreRecordCollection>( + var store = new QdrantVectorStoreRecordCollection>( this._qdrantClientMock.Object, TestCollectionName, new() @@ -798,9 +712,9 @@ private static VectorStoreRecordDefinition CreateSinglePropsDefinition(Type keyT Properties = [ new VectorStoreRecordKeyProperty("Key", keyType), - new VectorStoreRecordDataProperty("OriginalNameData", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("Data", typeof(string)) { IsFilterable = true, StoragePropertyName = "data_storage_name" }, - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { StoragePropertyName = "vector_storage_name" } + new VectorStoreRecordDataProperty("OriginalNameData", typeof(string)) { IsIndexed = true, IsFullTextIndexed = true }, + new VectorStoreRecordDataProperty("Data", typeof(string)) { IsIndexed = true, StoragePropertyName = "data_storage_name" }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 4) { StoragePropertyName = "vector_storage_name" } ] }; } @@ -810,11 +724,11 @@ public sealed class SinglePropsModel [VectorStoreRecordKey] public required T Key { get; set; } - [VectorStoreRecordData(IsFilterable = true, IsFullTextSearchable = true)] + [VectorStoreRecordData(IsIndexed = true, IsFullTextIndexed = true)] public string OriginalNameData { get; set; } = string.Empty; [JsonPropertyName("ignored_data_json_name")] - [VectorStoreRecordData(IsFilterable = true, StoragePropertyName = "data_storage_name")] + [VectorStoreRecordData(IsIndexed = true, StoragePropertyName = "data_storage_name")] public string Data { get; set; } = string.Empty; [JsonPropertyName("ignored_vector_json_name")] diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordMapperTests.cs index 29fa57ddb5d9..c172129f83b7 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordMapperTests.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Text.Json.Serialization; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.Qdrant; using Qdrant.Client.Grpc; using Xunit; @@ -23,11 +24,12 @@ public void MapsSinglePropsFromDataToStorageModelWithUlong(bool hasNamedVectors) { // Arrange. var definition = CreateSinglePropsVectorStoreRecordDefinition(typeof(ulong)); - var reader = new VectorStoreRecordPropertyReader(typeof(SinglePropsModel), definition, null); - var sut = new QdrantVectorStoreRecordMapper>(reader, hasNamedVectors); + var model = new VectorStoreRecordModelBuilder(QdrantVectorStoreRecordFieldMapping.GetModelBuildOptions(hasNamedVectors)) + .Build(typeof(SinglePropsModel), definition, defaultEmbeddingGenerator: null); + var sut = new QdrantVectorStoreRecordMapper>(model, hasNamedVectors); // Act. - var actual = sut.MapFromDataToStorageModel(CreateSinglePropsModel(5ul)); + var actual = sut.MapFromDataToStorageModel(CreateSinglePropsModel(5ul), recordIndex: 0, generatedEmbeddings: null); // Assert. Assert.NotNull(actual); @@ -52,11 +54,12 @@ public void MapsSinglePropsFromDataToStorageModelWithGuid(bool hasNamedVectors) { // Arrange. var definition = CreateSinglePropsVectorStoreRecordDefinition(typeof(Guid)); - var reader = new VectorStoreRecordPropertyReader(typeof(SinglePropsModel), definition, null); - var sut = new QdrantVectorStoreRecordMapper>(reader, hasNamedVectors); + var model = new VectorStoreRecordModelBuilder(QdrantVectorStoreRecordFieldMapping.GetModelBuildOptions(hasNamedVectors)) + .Build(typeof(SinglePropsModel), definition, defaultEmbeddingGenerator: null); + var sut = new QdrantVectorStoreRecordMapper>(model, hasNamedVectors); // Act. - var actual = sut.MapFromDataToStorageModel(CreateSinglePropsModel(Guid.Parse("11111111-1111-1111-1111-111111111111"))); + var actual = sut.MapFromDataToStorageModel(CreateSinglePropsModel(Guid.Parse("11111111-1111-1111-1111-111111111111")), recordIndex: 0, generatedEmbeddings: null); // Assert. Assert.NotNull(actual); @@ -74,11 +77,13 @@ public void MapsSinglePropsFromStorageToDataModelWithUlong(bool hasNamedVectors, { // Arrange. var definition = CreateSinglePropsVectorStoreRecordDefinition(typeof(ulong)); - var reader = new VectorStoreRecordPropertyReader(typeof(SinglePropsModel), definition, null); - var sut = new QdrantVectorStoreRecordMapper>(reader, hasNamedVectors); + var model = new VectorStoreRecordModelBuilder(QdrantVectorStoreRecordFieldMapping.GetModelBuildOptions(hasNamedVectors)) + .Build(typeof(SinglePropsModel), definition, defaultEmbeddingGenerator: null); + var sut = new QdrantVectorStoreRecordMapper>(model, hasNamedVectors); // Act. - var actual = sut.MapFromStorageToDataModel(CreateSinglePropsPointStruct(5, hasNamedVectors), new() { IncludeVectors = includeVectors }); + var point = CreateSinglePropsPointStruct(5, hasNamedVectors); + var actual = sut.MapFromStorageToDataModel(point.Id, point.Payload, point.Vectors, new() { IncludeVectors = includeVectors }); // Assert. Assert.NotNull(actual); @@ -104,11 +109,13 @@ public void MapsSinglePropsFromStorageToDataModelWithGuid(bool hasNamedVectors, { // Arrange. var definition = CreateSinglePropsVectorStoreRecordDefinition(typeof(Guid)); - var reader = new VectorStoreRecordPropertyReader(typeof(SinglePropsModel), definition, null); - var sut = new QdrantVectorStoreRecordMapper>(reader, hasNamedVectors); + var model = new VectorStoreRecordModelBuilder(QdrantVectorStoreRecordFieldMapping.GetModelBuildOptions(hasNamedVectors)) + .Build(typeof(SinglePropsModel), definition, defaultEmbeddingGenerator: null); + var sut = new QdrantVectorStoreRecordMapper>(model, hasNamedVectors); // Act. - var actual = sut.MapFromStorageToDataModel(CreateSinglePropsPointStruct(Guid.Parse("11111111-1111-1111-1111-111111111111"), hasNamedVectors), new() { IncludeVectors = includeVectors }); + var point = CreateSinglePropsPointStruct(Guid.Parse("11111111-1111-1111-1111-111111111111"), hasNamedVectors); + var actual = sut.MapFromStorageToDataModel(point.Id, point.Payload, point.Vectors, new() { IncludeVectors = includeVectors }); // Assert. Assert.NotNull(actual); @@ -130,11 +137,13 @@ public void MapsMultiPropsFromDataToStorageModelWithUlong() { // Arrange. var definition = CreateMultiPropsVectorStoreRecordDefinition(typeof(ulong)); - var reader = new VectorStoreRecordPropertyReader(typeof(MultiPropsModel), definition, null); - var sut = new QdrantVectorStoreRecordMapper>(reader, true); + var model = new VectorStoreRecordModelBuilder(QdrantVectorStoreRecordFieldMapping.GetModelBuildOptions(hasNamedVectors: true)) + .Build(typeof(MultiPropsModel), definition, defaultEmbeddingGenerator: null); + + var sut = new QdrantVectorStoreRecordMapper>(model, hasNamedVectors: true); // Act. - var actual = sut.MapFromDataToStorageModel(CreateMultiPropsModel(5ul)); + var actual = sut.MapFromDataToStorageModel(CreateMultiPropsModel(5ul), recordIndex: 0, generatedEmbeddings: null); // Assert. Assert.NotNull(actual); @@ -157,11 +166,12 @@ public void MapsMultiPropsFromDataToStorageModelWithGuid() { // Arrange. var definition = CreateMultiPropsVectorStoreRecordDefinition(typeof(Guid)); - var reader = new VectorStoreRecordPropertyReader(typeof(MultiPropsModel), definition, null); - var sut = new QdrantVectorStoreRecordMapper>(reader, true); + var model = new VectorStoreRecordModelBuilder(QdrantVectorStoreRecordFieldMapping.GetModelBuildOptions(hasNamedVectors: true)) + .Build(typeof(MultiPropsModel), definition, defaultEmbeddingGenerator: null); + var sut = new QdrantVectorStoreRecordMapper>(model, hasNamedVectors: true); // Act. - var actual = sut.MapFromDataToStorageModel(CreateMultiPropsModel(Guid.Parse("11111111-1111-1111-1111-111111111111"))); + var actual = sut.MapFromDataToStorageModel(CreateMultiPropsModel(Guid.Parse("11111111-1111-1111-1111-111111111111")), recordIndex: 0, generatedEmbeddings: null); // Assert. Assert.NotNull(actual); @@ -186,11 +196,13 @@ public void MapsMultiPropsFromStorageToDataModelWithUlong(bool includeVectors) { // Arrange. var definition = CreateMultiPropsVectorStoreRecordDefinition(typeof(ulong)); - var reader = new VectorStoreRecordPropertyReader(typeof(MultiPropsModel), definition, null); - var sut = new QdrantVectorStoreRecordMapper>(reader, true); + var model = new VectorStoreRecordModelBuilder(QdrantVectorStoreRecordFieldMapping.GetModelBuildOptions(hasNamedVectors: true)) + .Build(typeof(MultiPropsModel), definition, defaultEmbeddingGenerator: null); + var sut = new QdrantVectorStoreRecordMapper>(model, hasNamedVectors: true); // Act. - var actual = sut.MapFromStorageToDataModel(CreateMultiPropsPointStruct(5), new() { IncludeVectors = includeVectors }); + var point = CreateMultiPropsPointStruct(5); + var actual = sut.MapFromStorageToDataModel(point.Id, point.Payload, point.Vectors, new() { IncludeVectors = includeVectors }); // Assert. Assert.NotNull(actual); @@ -223,11 +235,13 @@ public void MapsMultiPropsFromStorageToDataModelWithGuid(bool includeVectors) { // Arrange. var definition = CreateMultiPropsVectorStoreRecordDefinition(typeof(Guid)); - var reader = new VectorStoreRecordPropertyReader(typeof(MultiPropsModel), definition, null); - var sut = new QdrantVectorStoreRecordMapper>(reader, true); + var model = new VectorStoreRecordModelBuilder(QdrantVectorStoreRecordFieldMapping.GetModelBuildOptions(hasNamedVectors: true)) + .Build(typeof(MultiPropsModel), definition, defaultEmbeddingGenerator: null); + var sut = new QdrantVectorStoreRecordMapper>(model, hasNamedVectors: true); // Act. - var actual = sut.MapFromStorageToDataModel(CreateMultiPropsPointStruct(Guid.Parse("11111111-1111-1111-1111-111111111111")), new() { IncludeVectors = includeVectors }); + var point = CreateMultiPropsPointStruct(Guid.Parse("11111111-1111-1111-1111-111111111111")); + var actual = sut.MapFromStorageToDataModel(point.Id, point.Payload, point.Vectors, new() { IncludeVectors = includeVectors }); // Assert. Assert.NotNull(actual); @@ -283,55 +297,57 @@ private static MultiPropsModel CreateMultiPropsModel(TKey key) }; } - private static PointStruct CreateSinglePropsPointStruct(ulong id, bool hasNamedVectors) + private static RetrievedPoint CreateSinglePropsPointStruct(ulong id, bool hasNamedVectors) { - var pointStruct = new PointStruct(); + var pointStruct = new RetrievedPoint(); pointStruct.Id = new PointId() { Num = id }; AddDataToSinglePropsPointStruct(pointStruct, hasNamedVectors); return pointStruct; } - private static PointStruct CreateSinglePropsPointStruct(Guid id, bool hasNamedVectors) + private static RetrievedPoint CreateSinglePropsPointStruct(Guid id, bool hasNamedVectors) { - var pointStruct = new PointStruct(); + var pointStruct = new RetrievedPoint(); pointStruct.Id = new PointId() { Uuid = id.ToString() }; AddDataToSinglePropsPointStruct(pointStruct, hasNamedVectors); return pointStruct; } - private static void AddDataToSinglePropsPointStruct(PointStruct pointStruct, bool hasNamedVectors) + private static void AddDataToSinglePropsPointStruct(RetrievedPoint pointStruct, bool hasNamedVectors) { + var responseVector = VectorOutput.Parser.ParseJson("{ \"data\": [1, 2, 3, 4] }"); + pointStruct.Payload.Add("data", "data value"); if (hasNamedVectors) { - var namedVectors = new NamedVectors(); - namedVectors.Vectors.Add("vector", new[] { 1f, 2f, 3f, 4f }); - pointStruct.Vectors = new Vectors() { Vectors_ = namedVectors }; + var namedVectors = new NamedVectorsOutput(); + namedVectors.Vectors.Add("vector", responseVector); + pointStruct.Vectors = new VectorsOutput() { Vectors = namedVectors }; } else { - pointStruct.Vectors = new[] { 1f, 2f, 3f, 4f }; + pointStruct.Vectors = new VectorsOutput() { Vector = responseVector }; } } - private static PointStruct CreateMultiPropsPointStruct(ulong id) + private static RetrievedPoint CreateMultiPropsPointStruct(ulong id) { - var pointStruct = new PointStruct(); + var pointStruct = new RetrievedPoint(); pointStruct.Id = new PointId() { Num = id }; AddDataToMultiPropsPointStruct(pointStruct); return pointStruct; } - private static PointStruct CreateMultiPropsPointStruct(Guid id) + private static RetrievedPoint CreateMultiPropsPointStruct(Guid id) { - var pointStruct = new PointStruct(); + var pointStruct = new RetrievedPoint(); pointStruct.Id = new PointId() { Uuid = id.ToString() }; AddDataToMultiPropsPointStruct(pointStruct); return pointStruct; } - private static void AddDataToMultiPropsPointStruct(PointStruct pointStruct) + private static void AddDataToMultiPropsPointStruct(RetrievedPoint pointStruct) { pointStruct.Payload.Add("dataString", "data 1"); pointStruct.Payload.Add("dataInt", 5); @@ -348,10 +364,13 @@ private static void AddDataToMultiPropsPointStruct(PointStruct pointStruct) dataIntArray.Values.Add(4); pointStruct.Payload.Add("dataArrayInt", new Value { ListValue = dataIntArray }); - var namedVectors = new NamedVectors(); - namedVectors.Vectors.Add("vector1", new[] { 1f, 2f, 3f, 4f }); - namedVectors.Vectors.Add("vector2", new[] { 5f, 6f, 7f, 8f }); - pointStruct.Vectors = new Vectors() { Vectors_ = namedVectors }; + var responseVector1 = VectorOutput.Parser.ParseJson("{ \"data\": [1, 2, 3, 4] }"); + var responseVector2 = VectorOutput.Parser.ParseJson("{ \"data\": [5, 6, 7, 8] }"); + + var namedVectors = new NamedVectorsOutput(); + namedVectors.Vectors.Add("vector1", responseVector1); + namedVectors.Vectors.Add("vector2", responseVector2); + pointStruct.Vectors = new VectorsOutput() { Vectors = namedVectors }; } private static VectorStoreRecordDefinition CreateSinglePropsVectorStoreRecordDefinition(Type keyType) => new() @@ -360,7 +379,7 @@ private static void AddDataToMultiPropsPointStruct(PointStruct pointStruct) { new VectorStoreRecordKeyProperty("Key", keyType) { StoragePropertyName = "key" }, new VectorStoreRecordDataProperty("Data", typeof(string)) { StoragePropertyName = "data" }, - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { StoragePropertyName = "vector" }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) { StoragePropertyName = "vector" }, }, }; @@ -372,7 +391,7 @@ private sealed class SinglePropsModel [VectorStoreRecordData(StoragePropertyName = "data")] public string Data { get; set; } = string.Empty; - [VectorStoreRecordVector(StoragePropertyName = "vector")] + [VectorStoreRecordVector(10, StoragePropertyName = "vector")] public ReadOnlyMemory? Vector { get; set; } public string NotAnnotated { get; set; } = string.Empty; @@ -391,8 +410,8 @@ private sealed class SinglePropsModel new VectorStoreRecordDataProperty("DataBool", typeof(bool)) { StoragePropertyName = "dataBool" }, new VectorStoreRecordDataProperty("DataDateTimeOffset", typeof(DateTimeOffset)) { StoragePropertyName = "dataDateTimeOffset" }, new VectorStoreRecordDataProperty("DataArrayInt", typeof(List)) { StoragePropertyName = "dataArrayInt" }, - new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory)) { StoragePropertyName = "vector1" }, - new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory)) { StoragePropertyName = "vector2" }, + new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory), 10) { StoragePropertyName = "vector1" }, + new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory), 10) { StoragePropertyName = "vector2" }, }, }; @@ -426,10 +445,10 @@ private sealed class MultiPropsModel [VectorStoreRecordData(StoragePropertyName = "dataArrayInt")] public List? DataArrayInt { get; set; } - [VectorStoreRecordVector(StoragePropertyName = "vector1")] + [VectorStoreRecordVector(10, StoragePropertyName = "vector1")] public ReadOnlyMemory? Vector1 { get; set; } - [VectorStoreRecordVector(StoragePropertyName = "vector2")] + [VectorStoreRecordVector(10, StoragePropertyName = "vector2")] public ReadOnlyMemory? Vector2 { get; set; } public string NotAnnotated { get; set; } = string.Empty; diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreTests.cs index 9230b5f31fe0..19bc64df4334 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreTests.cs @@ -38,7 +38,7 @@ public void GetCollectionReturnsCollection() // Assert. Assert.NotNull(actual); - Assert.IsType>>(actual); + Assert.IsType>>(actual); } #pragma warning disable CS0618 // IQdrantVectorStoreRecordCollectionFactory is obsolete diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/Connectors.Redis.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Redis.UnitTests/Connectors.Redis.UnitTests.csproj index c54e1a3b5136..1593ec444e1b 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/Connectors.Redis.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/Connectors.Redis.UnitTests.csproj @@ -8,7 +8,8 @@ enable disable false - $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0050 + $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0050 + $(NoWarn);MEVD9001 diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetDynamicDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetDynamicDataModelMapperTests.cs new file mode 100644 index 000000000000..c34eb89794c2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetDynamicDataModelMapperTests.cs @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using StackExchange.Redis; +using Xunit; + +namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; + +/// +/// Contains dynamic mapping tests for the class. +/// +public class RedisHashSetDynamicDataModelMapperTests +{ + private static readonly VectorStoreRecordModel s_model = BuildModel(RedisHashSetVectorStoreMappingTestHelpers.s_vectorStoreRecordDefinition); + + private static readonly float[] s_floatVector = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + private static readonly double[] s_doubleVector = new double[] { 5.0d, 6.0d, 7.0d, 8.0d }; + + [Fact] + public void MapFromDataToStorageModelMapsAllSupportedTypes() + { + // Arrange. + var sut = new RedisHashSetVectorStoreRecordMapper>(s_model); + var dataModel = new Dictionary + { + ["Key"] = "key", + + ["StringData"] = "data 1", + ["IntData"] = 1, + ["UIntData"] = 2u, + ["LongData"] = 3L, + ["ULongData"] = 4ul, + ["DoubleData"] = 5.5d, + ["FloatData"] = 6.6f, + ["BoolData"] = true, + ["NullableIntData"] = 7, + ["NullableUIntData"] = 8u, + ["NullableLongData"] = 9L, + ["NullableULongData"] = 10ul, + ["NullableDoubleData"] = 11.1d, + ["NullableFloatData"] = 12.2f, + ["NullableBoolData"] = false, + + ["FloatVector"] = new ReadOnlyMemory(s_floatVector), + ["DoubleVector"] = new ReadOnlyMemory(s_doubleVector), + }; + + // Act. + var storageModel = sut.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); + + // Assert + Assert.Equal("key", storageModel.Key); + RedisHashSetVectorStoreMappingTestHelpers.VerifyHashSet(storageModel.HashEntries); + } + + [Fact] + public void MapFromDataToStorageModelMapsNullValues() + { + // Arrange + VectorStoreRecordModel model = BuildModel(new() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringData", typeof(string)) { StoragePropertyName = "storage_string_data" }, + new VectorStoreRecordDataProperty("NullableIntData", typeof(int?)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory?), 10), + }, + }); + + var dataModel = new Dictionary + { + ["Key"] = "key", + ["StringData"] = null, + ["NullableIntData"] = null, + ["FloatVector"] = null, + }; + + var sut = new RedisHashSetVectorStoreRecordMapper>(model); + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); + + // Assert + Assert.Equal("key", storageModel.Key); + + Assert.Equal("storage_string_data", storageModel.HashEntries[0].Name.ToString()); + Assert.True(storageModel.HashEntries[0].Value.IsNull); + + Assert.Equal("NullableIntData", storageModel.HashEntries[1].Name.ToString()); + Assert.True(storageModel.HashEntries[1].Value.IsNull); + } + + [Fact] + public void MapFromStorageToDataModelMapsAllSupportedTypes() + { + // Arrange. + var hashSet = RedisHashSetVectorStoreMappingTestHelpers.CreateHashSet(); + var sut = new RedisHashSetVectorStoreRecordMapper>(s_model); + + // Act. + var dataModel = sut.MapFromStorageToDataModel(("key", hashSet), new() { IncludeVectors = true }); + + // Assert. + Assert.Equal("key", dataModel["Key"]); + Assert.Equal("data 1", dataModel["StringData"]); + Assert.Equal(1, dataModel["IntData"]); + Assert.Equal(2u, dataModel["UIntData"]); + Assert.Equal(3L, dataModel["LongData"]); + Assert.Equal(4ul, dataModel["ULongData"]); + Assert.Equal(5.5d, dataModel["DoubleData"]); + Assert.Equal(6.6f, dataModel["FloatData"]); + Assert.True((bool)dataModel["BoolData"]!); + Assert.Equal(7, dataModel["NullableIntData"]); + Assert.Equal(8u, dataModel["NullableUIntData"]); + Assert.Equal(9L, dataModel["NullableLongData"]); + Assert.Equal(10ul, dataModel["NullableULongData"]); + Assert.Equal(11.1d, dataModel["NullableDoubleData"]); + Assert.Equal(12.2f, dataModel["NullableFloatData"]); + Assert.False((bool)dataModel["NullableBoolData"]!); + Assert.Equal(new float[] { 1, 2, 3, 4 }, ((ReadOnlyMemory)dataModel["FloatVector"]!).ToArray()); + Assert.Equal(new double[] { 5, 6, 7, 8 }, ((ReadOnlyMemory)dataModel["DoubleVector"]!).ToArray()); + } + + [Fact] + public void MapFromStorageToDataModelMapsNullValues() + { + // Arrange + var model = BuildModel(new() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringData", typeof(string)) { StoragePropertyName = "storage_string_data" }, + new VectorStoreRecordDataProperty("NullableIntData", typeof(int?)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory?), 10), + } + }); + + var hashSet = new HashEntry[] + { + new("storage_string_data", RedisValue.Null), + new("NullableIntData", RedisValue.Null), + new("FloatVector", RedisValue.Null), + }; + + var sut = new RedisHashSetVectorStoreRecordMapper>(model); + + // Act + var dataModel = sut.MapFromStorageToDataModel(("key", hashSet), new() { IncludeVectors = true }); + + // Assert + Assert.Equal("key", dataModel["Key"]); + Assert.Null(dataModel["StringData"]); + Assert.Null(dataModel["NullableIntData"]); + Assert.Null(dataModel["FloatVector"]); + } + + [Fact] + public void MapFromDataToStorageModelSkipsMissingProperties() + { + // Arrange. + var model = BuildModel(new() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringData", typeof(string)) { StoragePropertyName = "storage_string_data" }, + new VectorStoreRecordDataProperty("NullableIntData", typeof(int?)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory?), 10), + } + }); + + var sut = new RedisHashSetVectorStoreRecordMapper>(model); + var dataModel = new Dictionary { ["Key"] = "key" }; + + // Act. + var storageModel = sut.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); + + // Assert + Assert.Equal("key", storageModel.Key); + + Assert.Equal("storage_string_data", storageModel.HashEntries[0].Name.ToString()); + Assert.True(storageModel.HashEntries[0].Value.IsNull); + + Assert.Equal("NullableIntData", storageModel.HashEntries[1].Name.ToString()); + Assert.True(storageModel.HashEntries[1].Value.IsNull); + } + + [Fact] + public void MapFromStorageToDataModelSkipsMissingProperties() + { + // Arrange. + var hashSet = Array.Empty(); + + var sut = new RedisHashSetVectorStoreRecordMapper>(s_model); + + // Act. + var dataModel = sut.MapFromStorageToDataModel(("key", hashSet), new() { IncludeVectors = true }); + + // Assert. + Assert.Single(dataModel); + Assert.Equal("key", dataModel["Key"]); + } + + private static VectorStoreRecordModel BuildModel(VectorStoreRecordDefinition definition) + => new VectorStoreRecordModelBuilder(RedisHashSetVectorStoreRecordCollection>.ModelBuildingOptions) + .Build(typeof(Dictionary), definition, defaultEmbeddingGenerator: null); +} diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetGenericDataModelMapperTests.cs deleted file mode 100644 index ce0d0c9767d0..000000000000 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetGenericDataModelMapperTests.cs +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using Microsoft.Extensions.VectorData; -using StackExchange.Redis; -using Xunit; - -namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; - -/// -/// Contains tests for the class. -/// -public class RedisHashSetGenericDataModelMapperTests -{ - private static readonly float[] s_floatVector = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; - private static readonly double[] s_doubleVector = new double[] { 5.0d, 6.0d, 7.0d, 8.0d }; - - [Fact] - public void MapFromDataToStorageModelMapsAllSupportedTypes() - { - // Arrange. - var sut = new RedisHashSetGenericDataModelMapper(RedisHashSetVectorStoreMappingTestHelpers.s_vectorStoreRecordDefinition.Properties); - var dataModel = new VectorStoreGenericDataModel("key") - { - Data = - { - ["StringData"] = "data 1", - ["IntData"] = 1, - ["UIntData"] = 2u, - ["LongData"] = 3L, - ["ULongData"] = 4ul, - ["DoubleData"] = 5.5d, - ["FloatData"] = 6.6f, - ["BoolData"] = true, - ["NullableIntData"] = 7, - ["NullableUIntData"] = 8u, - ["NullableLongData"] = 9L, - ["NullableULongData"] = 10ul, - ["NullableDoubleData"] = 11.1d, - ["NullableFloatData"] = 12.2f, - ["NullableBoolData"] = false, - }, - Vectors = - { - ["FloatVector"] = new ReadOnlyMemory(s_floatVector), - ["DoubleVector"] = new ReadOnlyMemory(s_doubleVector), - }, - }; - - // Act. - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", storageModel.Key); - RedisHashSetVectorStoreMappingTestHelpers.VerifyHashSet(storageModel.HashEntries); - } - - [Fact] - public void MapFromDataToStorageModelMapsNullValues() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringData", typeof(string)), - new VectorStoreRecordDataProperty("NullableIntData", typeof(int?)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory?)), - }, - }; - - var dataModel = new VectorStoreGenericDataModel("key") - { - Data = - { - ["StringData"] = null, - ["NullableIntData"] = null, - }, - Vectors = - { - ["FloatVector"] = null, - }, - }; - - var sut = new RedisHashSetGenericDataModelMapper(RedisHashSetVectorStoreMappingTestHelpers.s_vectorStoreRecordDefinition.Properties); - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", storageModel.Key); - - Assert.Equal("storage_string_data", storageModel.HashEntries[0].Name.ToString()); - Assert.True(storageModel.HashEntries[0].Value.IsNull); - - Assert.Equal("NullableIntData", storageModel.HashEntries[1].Name.ToString()); - Assert.True(storageModel.HashEntries[1].Value.IsNull); - - Assert.Equal("FloatVector", storageModel.HashEntries[2].Name.ToString()); - Assert.True(storageModel.HashEntries[2].Value.IsNull); - } - - [Fact] - public void MapFromStorageToDataModelMapsAllSupportedTypes() - { - // Arrange. - var hashSet = RedisHashSetVectorStoreMappingTestHelpers.CreateHashSet(); - - var sut = new RedisHashSetGenericDataModelMapper(RedisHashSetVectorStoreMappingTestHelpers.s_vectorStoreRecordDefinition.Properties); - - // Act. - var dataModel = sut.MapFromStorageToDataModel(("key", hashSet), new() { IncludeVectors = true }); - - // Assert. - Assert.Equal("key", dataModel.Key); - Assert.Equal("data 1", dataModel.Data["StringData"]); - Assert.Equal(1, dataModel.Data["IntData"]); - Assert.Equal(2u, dataModel.Data["UIntData"]); - Assert.Equal(3L, dataModel.Data["LongData"]); - Assert.Equal(4ul, dataModel.Data["ULongData"]); - Assert.Equal(5.5d, dataModel.Data["DoubleData"]); - Assert.Equal(6.6f, dataModel.Data["FloatData"]); - Assert.True((bool)dataModel.Data["BoolData"]!); - Assert.Equal(7, dataModel.Data["NullableIntData"]); - Assert.Equal(8u, dataModel.Data["NullableUIntData"]); - Assert.Equal(9L, dataModel.Data["NullableLongData"]); - Assert.Equal(10ul, dataModel.Data["NullableULongData"]); - Assert.Equal(11.1d, dataModel.Data["NullableDoubleData"]); - Assert.Equal(12.2f, dataModel.Data["NullableFloatData"]); - Assert.False((bool)dataModel.Data["NullableBoolData"]!); - Assert.Equal(new float[] { 1, 2, 3, 4 }, ((ReadOnlyMemory)dataModel.Vectors["FloatVector"]!).ToArray()); - Assert.Equal(new double[] { 5, 6, 7, 8 }, ((ReadOnlyMemory)dataModel.Vectors["DoubleVector"]!).ToArray()); - } - - [Fact] - public void MapFromStorageToDataModelMapsNullValues() - { - // Arrange - VectorStoreRecordDefinition vectorStoreRecordDefinition = new() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringData", typeof(string)), - new VectorStoreRecordDataProperty("NullableIntData", typeof(int?)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory?)), - }, - }; - - var hashSet = new HashEntry[] - { - new("storage_string_data", RedisValue.Null), - new("NullableIntData", RedisValue.Null), - new("FloatVector", RedisValue.Null), - }; - - var sut = new RedisHashSetGenericDataModelMapper(RedisHashSetVectorStoreMappingTestHelpers.s_vectorStoreRecordDefinition.Properties); - - // Act - var dataModel = sut.MapFromStorageToDataModel(("key", hashSet), new() { IncludeVectors = true }); - - // Assert - Assert.Equal("key", dataModel.Key); - Assert.Null(dataModel.Data["StringData"]); - Assert.Null(dataModel.Data["NullableIntData"]); - Assert.Null(dataModel.Vectors["FloatVector"]); - } - - [Fact] - public void MapFromDataToStorageModelSkipsMissingProperties() - { - // Arrange. - var sut = new RedisHashSetGenericDataModelMapper(RedisHashSetVectorStoreMappingTestHelpers.s_vectorStoreRecordDefinition.Properties); - var dataModel = new VectorStoreGenericDataModel("key") - { - Data = { }, - Vectors = { }, - }; - - // Act. - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", storageModel.Key); - Assert.Empty(storageModel.HashEntries); - } - - [Fact] - public void MapFromStorageToDataModelSkipsMissingProperties() - { - // Arrange. - var hashSet = Array.Empty(); - - var sut = new RedisHashSetGenericDataModelMapper(RedisHashSetVectorStoreMappingTestHelpers.s_vectorStoreRecordDefinition.Properties); - - // Act. - var dataModel = sut.MapFromStorageToDataModel(("key", hashSet), new() { IncludeVectors = true }); - - // Assert. - Assert.Equal("key", dataModel.Key); - Assert.Empty(dataModel.Data); - Assert.Empty(dataModel.Vectors); - } -} diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreMappingTestHelpers.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreMappingTestHelpers.cs index 8b46f69b844b..5ed25e96dcc5 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreMappingTestHelpers.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreMappingTestHelpers.cs @@ -35,8 +35,8 @@ internal static class RedisHashSetVectorStoreMappingTestHelpers new VectorStoreRecordDataProperty("NullableDoubleData", typeof(double?)), new VectorStoreRecordDataProperty("NullableFloatData", typeof(float?)), new VectorStoreRecordDataProperty("NullableBoolData", typeof(bool?)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - new VectorStoreRecordVectorProperty("DoubleVector", typeof(ReadOnlyMemory)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), + new VectorStoreRecordVectorProperty("DoubleVector", typeof(ReadOnlyMemory), 10), } }; diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs index 117d3d1fcd4b..928c65480143 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs @@ -15,7 +15,7 @@ namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; /// -/// Contains tests for the class. +/// Contains tests for the class. /// public class RedisHashSetVectorStoreRecordCollectionTests { @@ -28,6 +28,7 @@ public class RedisHashSetVectorStoreRecordCollectionTests public RedisHashSetVectorStoreRecordCollectionTests() { this._redisDatabaseMock = new Mock(MockBehavior.Strict); + this._redisDatabaseMock.Setup(l => l.Database).Returns(0); var batchMock = new Mock(); this._redisDatabaseMock.Setup(x => x.CreateBatch(It.IsAny())).Returns(batchMock.Object); @@ -47,7 +48,7 @@ public async Task CollectionExistsReturnsCollectionStateAsync(string collectionN { SetupExecuteMock(this._redisDatabaseMock, new RedisServerException("Unknown index name")); } - var sut = new RedisHashSetVectorStoreRecordCollection( + var sut = new RedisHashSetVectorStoreRecordCollection( this._redisDatabaseMock.Object, collectionName); @@ -70,7 +71,7 @@ public async Task CanCreateCollectionAsync() { // Arrange. SetupExecuteMock(this._redisDatabaseMock, string.Empty); - var sut = new RedisHashSetVectorStoreRecordCollection(this._redisDatabaseMock.Object, TestCollectionName); + var sut = new RedisHashSetVectorStoreRecordCollection(this._redisDatabaseMock.Object, TestCollectionName); // Act. await sut.CreateCollectionAsync(); @@ -216,7 +217,7 @@ public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition) var sut = this.CreateRecordCollection(useDefinition); // Act - var actual = await sut.GetBatchAsync( + var actual = await sut.GetAsync( [TestRecordKey1, TestRecordKey2], new() { IncludeVectors = true }).ToListAsync(); @@ -236,55 +237,6 @@ public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition) Assert.Equal(new float[] { 5, 6, 7, 8 }, actual[1].Vector!.Value.ToArray()); } - [Fact] - public async Task CanGetRecordWithCustomMapperAsync() - { - // Arrange. - var hashEntries = new HashEntry[] - { - new("OriginalNameData", "data 1"), - new("data_storage_name", "data 1"), - new("vector_storage_name", MemoryMarshal.AsBytes(new ReadOnlySpan(new float[] { 1, 2, 3, 4 })).ToArray()) - }; - this._redisDatabaseMock.Setup(x => x.HashGetAllAsync(It.IsAny(), CommandFlags.None)).ReturnsAsync(hashEntries); - - // Arrange mapper mock from JsonNode to data model. - var mapperMock = new Mock>(MockBehavior.Strict); - mapperMock.Setup( - x => x.MapFromStorageToDataModel( - It.IsAny<(string key, HashEntry[] hashEntries)>(), - It.IsAny())) - .Returns(CreateModel(TestRecordKey1, true)); - - // Arrange target with custom mapper. - var sut = new RedisHashSetVectorStoreRecordCollection( - this._redisDatabaseMock.Object, - TestCollectionName, - new() - { - HashEntriesCustomMapper = mapperMock.Object - }); - - // Act - var actual = await sut.GetAsync( - TestRecordKey1, - new() { IncludeVectors = true }); - - // Assert - Assert.NotNull(actual); - Assert.Equal(TestRecordKey1, actual.Key); - Assert.Equal("data 1", actual.OriginalNameData); - Assert.Equal("data 1", actual.Data); - Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector!.Value.ToArray()); - - mapperMock - .Verify( - x => x.MapFromStorageToDataModel( - It.Is<(string key, HashEntry[] hashEntries)>(x => x.key == TestRecordKey1), - It.Is(x => x.IncludeVectors)), - Times.Once); - } - [Theory] [InlineData(true)] [InlineData(false)] @@ -311,7 +263,7 @@ public async Task CanDeleteManyRecordsWithVectorsAsync(bool useDefinition) var sut = this.CreateRecordCollection(useDefinition); // Act - await sut.DeleteBatchAsync([TestRecordKey1, TestRecordKey2]); + await sut.DeleteAsync([TestRecordKey1, TestRecordKey2]); // Assert this._redisDatabaseMock.Verify(x => x.KeyDeleteAsync(TestRecordKey1, CommandFlags.None), Times.Once); @@ -353,7 +305,7 @@ public async Task CanUpsertManyRecordsAsync(bool useDefinition) var model2 = CreateModel(TestRecordKey2, true); // Act - var actual = await sut.UpsertBatchAsync([model1, model2]).ToListAsync(); + var actual = await sut.UpsertAsync([model1, model2]); // Assert Assert.NotNull(actual); @@ -375,46 +327,6 @@ public async Task CanUpsertManyRecordsAsync(bool useDefinition) Times.Once); } - [Fact] - public async Task CanUpsertRecordWithCustomMapperAsync() - { - // Arrange. - this._redisDatabaseMock.Setup(x => x.HashSetAsync(It.IsAny(), It.IsAny(), CommandFlags.None)).Returns(Task.CompletedTask); - - // Arrange mapper mock from data model to JsonNode. - var mapperMock = new Mock>(MockBehavior.Strict); - var hashEntries = new HashEntry[] - { - new("OriginalNameData", "data 1"), - new("data_storage_name", "data 1"), - new("vector_storage_name", "[1,2,3,4]"), - new("NotAnnotated", RedisValue.Null) - }; - mapperMock - .Setup(x => x.MapFromDataToStorageModel(It.IsAny())) - .Returns((TestRecordKey1, hashEntries)); - - // Arrange target with custom mapper. - var sut = new RedisHashSetVectorStoreRecordCollection( - this._redisDatabaseMock.Object, - TestCollectionName, - new() - { - HashEntriesCustomMapper = mapperMock.Object - }); - - var model = CreateModel(TestRecordKey1, true); - - // Act - await sut.UpsertAsync(model); - - // Assert - mapperMock - .Verify( - x => x.MapFromDataToStorageModel(It.Is(x => x == model)), - Times.Once); - } - #pragma warning disable CS0618 // VectorSearchFilter is obsolete [Theory] [InlineData(true, true)] @@ -446,15 +358,15 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc var filter = new VectorSearchFilter().EqualTo(nameof(SinglePropsModel.Data), "data 1"); // Act. - var actual = await sut.VectorizedSearchAsync( + var results = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new[] { 1f, 2f, 3f, 4f }), + top: 5, new() { IncludeVectors = includeVectors, OldFilter = filter, - Top = 5, Skip = 2 - }); + }).ToListAsync(); // Assert. var expectedArgsPart1 = new object[] @@ -494,7 +406,6 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc It.Is(x => x.Where(y => !(y is byte[])).SequenceEqual(expectedArgs.Where(y => !(y is byte[]))))), Times.Once); - var results = await actual.Results.ToListAsync(); Assert.Single(results); Assert.Equal(TestRecordKey1, results.First().Record.Key); Assert.Equal(0.25d, results.First().Score); @@ -511,6 +422,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc } #pragma warning restore CS0618 // VectorSearchFilter is obsolete +#pragma warning disable CS0618 // IVectorStoreRecordMapper is obsolete /// /// Tests that the collection can be created even if the definition and the type do not match. /// In this case, the expectation is that a custom mapper will be provided to map between the @@ -524,22 +436,23 @@ public void CanCreateCollectionWithMismatchedDefinitionAndType() { Properties = new List { - new VectorStoreRecordKeyProperty("Id", typeof(string)), - new VectorStoreRecordDataProperty("Text", typeof(string)), - new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { Dimensions = 4 }, + new VectorStoreRecordKeyProperty(nameof(SinglePropsModel.Key), typeof(string)), + new VectorStoreRecordDataProperty(nameof(SinglePropsModel.OriginalNameData), typeof(string)), + new VectorStoreRecordVectorProperty(nameof(SinglePropsModel.Vector), typeof(ReadOnlyMemory?), 4), } }; // Act. - var sut = new RedisHashSetVectorStoreRecordCollection( + var sut = new RedisHashSetVectorStoreRecordCollection( this._redisDatabaseMock.Object, TestCollectionName, - new() { VectorStoreRecordDefinition = definition, HashEntriesCustomMapper = Mock.Of>() }); + new() { VectorStoreRecordDefinition = definition }); } +#pragma warning restore CS0618 - private RedisHashSetVectorStoreRecordCollection CreateRecordCollection(bool useDefinition) + private RedisHashSetVectorStoreRecordCollection CreateRecordCollection(bool useDefinition) { - return new RedisHashSetVectorStoreRecordCollection( + return new RedisHashSetVectorStoreRecordCollection( this._redisDatabaseMock.Object, TestCollectionName, new() @@ -618,7 +531,7 @@ private static SinglePropsModel CreateModel(string key, bool withVectors) new VectorStoreRecordKeyProperty("Key", typeof(string)), new VectorStoreRecordDataProperty("OriginalNameData", typeof(string)), new VectorStoreRecordDataProperty("Data", typeof(string)) { StoragePropertyName = "data_storage_name" }, - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { StoragePropertyName = "vector_storage_name", DistanceFunction = DistanceFunction.CosineDistance } + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) { StoragePropertyName = "vector_storage_name", DistanceFunction = DistanceFunction.CosineDistance } ] }; @@ -627,15 +540,15 @@ public sealed class SinglePropsModel [VectorStoreRecordKey] public string Key { get; set; } = string.Empty; - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string OriginalNameData { get; set; } = string.Empty; [JsonPropertyName("ignored_data_json_name")] - [VectorStoreRecordData(IsFilterable = true, StoragePropertyName = "data_storage_name")] + [VectorStoreRecordData(IsIndexed = true, StoragePropertyName = "data_storage_name")] public string Data { get; set; } = string.Empty; [JsonPropertyName("ignored_vector_json_name")] - [VectorStoreRecordVector(4, DistanceFunction.CosineDistance, StoragePropertyName = "vector_storage_name")] + [VectorStoreRecordVector(4, DistanceFunction = DistanceFunction.CosineDistance, StoragePropertyName = "vector_storage_name")] public ReadOnlyMemory? Vector { get; set; } public string? NotAnnotated { get; set; } diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordMapperTests.cs index 8eb570f15329..62a44c584cfa 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordMapperTests.cs @@ -2,6 +2,7 @@ using System; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.Redis; using Microsoft.SemanticKernel.Connectors.Redis.UnitTests; using Xunit; @@ -13,15 +14,18 @@ namespace SemanticKernel.Connectors.Redis.UnitTests; /// public sealed class RedisHashSetVectorStoreRecordMapperTests { + private static readonly VectorStoreRecordModel s_model + = new VectorStoreRecordModelBuilder(RedisHashSetVectorStoreRecordCollection.ModelBuildingOptions) + .Build(typeof(AllTypesModel), RedisHashSetVectorStoreMappingTestHelpers.s_vectorStoreRecordDefinition, defaultEmbeddingGenerator: null); + [Fact] public void MapsAllFieldsFromDataToStorageModel() { // Arrange. - var reader = new VectorStoreRecordPropertyReader(typeof(AllTypesModel), RedisHashSetVectorStoreMappingTestHelpers.s_vectorStoreRecordDefinition, null); - var sut = new RedisHashSetVectorStoreRecordMapper(reader); + var sut = new RedisHashSetVectorStoreRecordMapper(s_model); // Act. - var actual = sut.MapFromDataToStorageModel(CreateModel("test key")); + var actual = sut.MapFromDataToStorageModel(CreateModel("test key"), recordIndex: 0, generatedEmbeddings: null); // Assert. Assert.NotNull(actual.HashEntries); @@ -33,8 +37,7 @@ public void MapsAllFieldsFromDataToStorageModel() public void MapsAllFieldsFromStorageToDataModel() { // Arrange. - var reader = new VectorStoreRecordPropertyReader(typeof(AllTypesModel), RedisHashSetVectorStoreMappingTestHelpers.s_vectorStoreRecordDefinition, null); - var sut = new RedisHashSetVectorStoreRecordMapper(reader); + var sut = new RedisHashSetVectorStoreRecordMapper(s_model); // Act. var actual = sut.MapFromStorageToDataModel(("test key", RedisHashSetVectorStoreMappingTestHelpers.CreateHashSet()), new() { IncludeVectors = true }); @@ -138,10 +141,10 @@ private sealed class AllTypesModel [VectorStoreRecordData] public bool? NullableBoolData { get; set; } - [VectorStoreRecordVector] + [VectorStoreRecordVector(10)] public ReadOnlyMemory? FloatVector { get; set; } - [VectorStoreRecordVector] + [VectorStoreRecordVector(10)] public ReadOnlyMemory? DoubleVector { get; set; } public string NotAnnotated { get; set; } = string.Empty; diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonDynamicDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonDynamicDataModelMapperTests.cs new file mode 100644 index 000000000000..936f3d1c8865 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonDynamicDataModelMapperTests.cs @@ -0,0 +1,182 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Xunit; + +namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; + +/// +/// Contains tests for the class. +/// +public class RedisJsonDynamicDataModelMapperTests +{ + private static readonly float[] s_floatVector = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + + private static readonly VectorStoreRecordModel s_model + = new VectorStoreRecordJsonModelBuilder(RedisJsonVectorStoreRecordCollection>.ModelBuildingOptions) + .Build( + typeof(Dictionary), + new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringData", typeof(string)), + new VectorStoreRecordDataProperty("IntData", typeof(int)), + new VectorStoreRecordDataProperty("NullableIntData", typeof(int?)), + new VectorStoreRecordDataProperty("ComplexObjectData", typeof(ComplexObject)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), + ] + }, + defaultEmbeddingGenerator: null); + + [Fact] + public void MapFromDataToStorageModelMapsAllSupportedTypes() + { + // Arrange. + var sut = new RedisJsonDynamicDataModelMapper(s_model, JsonSerializerOptions.Default); + var dataModel = new Dictionary + { + ["Key"] = "key", + ["StringData"] = "data 1", + ["IntData"] = 1, + ["NullableIntData"] = 2, + ["ComplexObjectData"] = new ComplexObject { Prop1 = "prop 1", Prop2 = "prop 2" }, + ["FloatVector"] = new ReadOnlyMemory(s_floatVector) + }; + + // Act. + var storageModel = sut.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); + + // Assert + Assert.Equal("key", storageModel.Key); + Assert.Equal("data 1", (string)storageModel.Node["StringData"]!); + Assert.Equal(1, (int)storageModel.Node["IntData"]!); + Assert.Equal(2, (int?)storageModel.Node["NullableIntData"]!); + Assert.Equal("prop 1", (string)storageModel.Node["ComplexObjectData"]!.AsObject()["Prop1"]!); + Assert.Equal(new float[] { 1, 2, 3, 4 }, storageModel.Node["FloatVector"]?.AsArray().GetValues().ToArray()); + } + + [Fact] + public void MapFromDataToStorageModelMapsNullValues() + { + // Arrange. + var sut = new RedisJsonDynamicDataModelMapper(s_model, JsonSerializerOptions.Default); + var dataModel = new Dictionary + { + ["Key"] = "key", + ["StringData"] = null, + ["IntData"] = null, + ["NullableIntData"] = null, + ["ComplexObjectData"] = null, + ["FloatVector"] = null, + }; + + // Act. + var storageModel = sut.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); + + // Assert + Assert.Equal("key", storageModel.Key); + Assert.Null(storageModel.Node["storage_string_data"]); + Assert.Null(storageModel.Node["IntData"]); + Assert.Null(storageModel.Node["NullableIntData"]); + Assert.Null(storageModel.Node["ComplexObjectData"]); + Assert.Null(storageModel.Node["FloatVector"]); + } + + [Fact] + public void MapFromStorageToDataModelMapsAllSupportedTypes() + { + // Arrange. + var sut = new RedisJsonDynamicDataModelMapper(s_model, JsonSerializerOptions.Default); + var storageModel = new JsonObject + { + { "StringData", "data 1" }, + { "IntData", 1 }, + { "NullableIntData", 2 }, + { "ComplexObjectData", new JsonObject(new KeyValuePair[] { new("Prop1", JsonValue.Create("prop 1")), new("Prop2", JsonValue.Create("prop 2")) }) }, + { "FloatVector", new JsonArray(new[] { 1, 2, 3, 4 }.Select(x => JsonValue.Create(x)).ToArray()) } + }; + + // Act. + var dataModel = sut.MapFromStorageToDataModel(("key", storageModel), new() { IncludeVectors = true }); + + // Assert. + Assert.Equal("key", dataModel["Key"]); + Assert.Equal("data 1", dataModel["StringData"]); + Assert.Equal(1, dataModel["IntData"]); + Assert.Equal(2, dataModel["NullableIntData"]); + Assert.Equal("prop 1", ((ComplexObject)dataModel["ComplexObjectData"]!).Prop1); + Assert.Equal(new float[] { 1, 2, 3, 4 }, ((ReadOnlyMemory)dataModel["FloatVector"]!).ToArray()); + } + + [Fact] + public void MapFromStorageToDataModelMapsNullValues() + { + // Arrange. + var sut = new RedisJsonDynamicDataModelMapper(s_model, JsonSerializerOptions.Default); + var storageModel = new JsonObject + { + { "StringData", null }, + { "IntData", null }, + { "NullableIntData", null }, + { "ComplexObjectData", null }, + { "FloatVector", null } + }; + + // Act. + var dataModel = sut.MapFromStorageToDataModel(("key", storageModel), new() { IncludeVectors = true }); + + // Assert. + Assert.Equal("key", dataModel["Key"]); + Assert.Null(dataModel["StringData"]); + Assert.Null(dataModel["IntData"]); + Assert.Null(dataModel["NullableIntData"]); + Assert.Null(dataModel["ComplexObjectData"]); + Assert.Null(dataModel["FloatVector"]); + } + + [Fact] + public void MapFromDataToStorageModelSkipsMissingProperties() + { + // Arrange. + var sut = new RedisJsonDynamicDataModelMapper(s_model, JsonSerializerOptions.Default); + var dataModel = new Dictionary { ["Key"] = "key" }; + + // Act. + var storageModel = sut.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); + + // Assert + Assert.Equal("key", storageModel.Key); + Assert.Empty(storageModel.Node.AsObject()); + } + + [Fact] + public void MapFromStorageToDataModelSkipsMissingProperties() + { + // Arrange. + var storageModel = new JsonObject(); + + var sut = new RedisJsonDynamicDataModelMapper(s_model, JsonSerializerOptions.Default); + + // Act. + var dataModel = sut.MapFromStorageToDataModel(("key", storageModel), new() { IncludeVectors = true }); + + // Assert. + Assert.Equal("key", dataModel["Key"]); + Assert.Single(dataModel); + } + + private sealed class ComplexObject + { + public string Prop1 { get; set; } = string.Empty; + + public string Prop2 { get; set; } = string.Empty; + } +} diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonGenericDataModelMapperTests.cs deleted file mode 100644 index 779dddaffa94..000000000000 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonGenericDataModelMapperTests.cs +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text.Json; -using System.Text.Json.Nodes; -using Microsoft.Extensions.VectorData; -using Xunit; - -namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; - -/// -/// Contains tests for the class. -/// -public class RedisJsonGenericDataModelMapperTests -{ - private static readonly float[] s_floatVector = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; - - private static readonly VectorStoreRecordDefinition s_vectorStoreRecordDefinition = new() - { - Properties = new List() - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("StringData", typeof(string)) { StoragePropertyName = "storage_string_data" }, - new VectorStoreRecordDataProperty("IntData", typeof(int)), - new VectorStoreRecordDataProperty("NullableIntData", typeof(int?)), - new VectorStoreRecordDataProperty("ComplexObjectData", typeof(ComplexObject)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - } - }; - - [Fact] - public void MapFromDataToStorageModelMapsAllSupportedTypes() - { - // Arrange. - var sut = new RedisJsonGenericDataModelMapper(s_vectorStoreRecordDefinition.Properties, JsonSerializerOptions.Default); - var dataModel = new VectorStoreGenericDataModel("key") - { - Data = - { - ["StringData"] = "data 1", - ["IntData"] = 1, - ["NullableIntData"] = 2, - ["ComplexObjectData"] = new ComplexObject { Prop1 = "prop 1", Prop2 = "prop 2" }, - }, - Vectors = - { - ["FloatVector"] = new ReadOnlyMemory(s_floatVector), - }, - }; - - // Act. - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", storageModel.Key); - Assert.Equal("data 1", (string)storageModel.Node["storage_string_data"]!); - Assert.Equal(1, (int)storageModel.Node["IntData"]!); - Assert.Equal(2, (int?)storageModel.Node["NullableIntData"]!); - Assert.Equal("prop 1", (string)storageModel.Node["ComplexObjectData"]!.AsObject()["Prop1"]!); - Assert.Equal(new float[] { 1, 2, 3, 4 }, storageModel.Node["FloatVector"]?.AsArray().GetValues().ToArray()); - } - - [Fact] - public void MapFromDataToStorageModelMapsNullValues() - { - // Arrange. - var sut = new RedisJsonGenericDataModelMapper(s_vectorStoreRecordDefinition.Properties, JsonSerializerOptions.Default); - var dataModel = new VectorStoreGenericDataModel("key") - { - Data = - { - ["StringData"] = null, - ["IntData"] = null, - ["NullableIntData"] = null, - ["ComplexObjectData"] = null, - }, - Vectors = - { - ["FloatVector"] = null, - }, - }; - - // Act. - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", storageModel.Key); - Assert.Null(storageModel.Node["storage_string_data"]); - Assert.Null(storageModel.Node["IntData"]); - Assert.Null(storageModel.Node["NullableIntData"]); - Assert.Null(storageModel.Node["ComplexObjectData"]); - Assert.Null(storageModel.Node["FloatVector"]); - } - - [Fact] - public void MapFromStorageToDataModelMapsAllSupportedTypes() - { - // Arrange. - var sut = new RedisJsonGenericDataModelMapper(s_vectorStoreRecordDefinition.Properties, JsonSerializerOptions.Default); - var storageModel = new JsonObject(); - storageModel.Add("storage_string_data", "data 1"); - storageModel.Add("IntData", 1); - storageModel.Add("NullableIntData", 2); - storageModel.Add("ComplexObjectData", new JsonObject(new KeyValuePair[] { new("Prop1", JsonValue.Create("prop 1")), new("Prop2", JsonValue.Create("prop 2")) })); - storageModel.Add("FloatVector", new JsonArray(new[] { 1, 2, 3, 4 }.Select(x => JsonValue.Create(x)).ToArray())); - - // Act. - var dataModel = sut.MapFromStorageToDataModel(("key", storageModel), new() { IncludeVectors = true }); - - // Assert. - Assert.Equal("key", dataModel.Key); - Assert.Equal("data 1", dataModel.Data["StringData"]); - Assert.Equal(1, dataModel.Data["IntData"]); - Assert.Equal(2, dataModel.Data["NullableIntData"]); - Assert.Equal("prop 1", ((ComplexObject)dataModel.Data["ComplexObjectData"]!).Prop1); - Assert.Equal(new float[] { 1, 2, 3, 4 }, ((ReadOnlyMemory)dataModel.Vectors["FloatVector"]!).ToArray()); - } - - [Fact] - public void MapFromStorageToDataModelMapsNullValues() - { - // Arrange. - var sut = new RedisJsonGenericDataModelMapper(s_vectorStoreRecordDefinition.Properties, JsonSerializerOptions.Default); - var storageModel = new JsonObject(); - storageModel.Add("storage_string_data", null); - storageModel.Add("IntData", null); - storageModel.Add("NullableIntData", null); - storageModel.Add("ComplexObjectData", null); - storageModel.Add("FloatVector", null); - - // Act. - var dataModel = sut.MapFromStorageToDataModel(("key", storageModel), new() { IncludeVectors = true }); - - // Assert. - Assert.Equal("key", dataModel.Key); - Assert.Null(dataModel.Data["StringData"]); - Assert.Null(dataModel.Data["IntData"]); - Assert.Null(dataModel.Data["NullableIntData"]); - Assert.Null(dataModel.Data["ComplexObjectData"]); - Assert.Null(dataModel.Vectors["FloatVector"]); - } - - [Fact] - public void MapFromDataToStorageModelSkipsMissingProperties() - { - // Arrange. - var sut = new RedisJsonGenericDataModelMapper(s_vectorStoreRecordDefinition.Properties, JsonSerializerOptions.Default); - var dataModel = new VectorStoreGenericDataModel("key") - { - Data = { }, - Vectors = { }, - }; - - // Act. - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", storageModel.Key); - Assert.Empty(storageModel.Node.AsObject()); - } - - [Fact] - public void MapFromStorageToDataModelSkipsMissingProperties() - { - // Arrange. - var storageModel = new JsonObject(); - - var sut = new RedisJsonGenericDataModelMapper(s_vectorStoreRecordDefinition.Properties, JsonSerializerOptions.Default); - - // Act. - var dataModel = sut.MapFromStorageToDataModel(("key", storageModel), new() { IncludeVectors = true }); - - // Assert. - Assert.Equal("key", dataModel.Key); - Assert.Empty(dataModel.Data); - Assert.Empty(dataModel.Vectors); - } - - private sealed class ComplexObject - { - public string Prop1 { get; set; } = string.Empty; - - public string Prop2 { get; set; } = string.Empty; - } -} diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs index aa47dc512b8c..9b83f5d13cc1 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs @@ -5,7 +5,6 @@ using System.Linq; using System.Runtime.InteropServices; using System.Text.Json; -using System.Text.Json.Nodes; using System.Text.Json.Serialization; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; @@ -19,7 +18,7 @@ namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// -/// Contains tests for the class. +/// Contains tests for the class. /// public class RedisJsonVectorStoreRecordCollectionTests { @@ -32,6 +31,7 @@ public class RedisJsonVectorStoreRecordCollectionTests public RedisJsonVectorStoreRecordCollectionTests() { this._redisDatabaseMock = new Mock(MockBehavior.Strict); + this._redisDatabaseMock.Setup(l => l.Database).Returns(0); var batchMock = new Mock(); this._redisDatabaseMock.Setup(x => x.CreateBatch(It.IsAny())).Returns(batchMock.Object); @@ -51,7 +51,7 @@ public async Task CollectionExistsReturnsCollectionStateAsync(string collectionN { SetupExecuteMock(this._redisDatabaseMock, new RedisServerException("Unknown index name")); } - var sut = new RedisJsonVectorStoreRecordCollection( + var sut = new RedisJsonVectorStoreRecordCollection( this._redisDatabaseMock.Object, collectionName); @@ -229,7 +229,7 @@ public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition) var sut = this.CreateRecordCollection(useDefinition); // Act - var actual = await sut.GetBatchAsync( + var actual = await sut.GetAsync( [TestRecordKey1, TestRecordKey2], new() { IncludeVectors = true }).ToListAsync(); @@ -254,51 +254,6 @@ public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition) Assert.Equal(new float[] { 5, 6, 7, 8 }, actual[1].Vector1!.Value.ToArray()); } - [Fact] - public async Task CanGetRecordWithCustomMapperAsync() - { - // Arrange. - var redisResultString = """{ "data1_json_name": "data 1", "Data2": "data 2", "vector1_json_name": [1, 2, 3, 4], "Vector2": [1, 2, 3, 4] }"""; - SetupExecuteMock(this._redisDatabaseMock, redisResultString); - - // Arrange mapper mock from JsonNode to data model. - var mapperMock = new Mock>(MockBehavior.Strict); - mapperMock.Setup( - x => x.MapFromStorageToDataModel( - It.IsAny<(string key, JsonNode node)>(), - It.IsAny())) - .Returns(CreateModel(TestRecordKey1, true)); - - // Arrange target with custom mapper. - var sut = new RedisJsonVectorStoreRecordCollection( - this._redisDatabaseMock.Object, - TestCollectionName, - new() - { - JsonNodeCustomMapper = mapperMock.Object - }); - - // Act - var actual = await sut.GetAsync( - TestRecordKey1, - new() { IncludeVectors = true }); - - // Assert - Assert.NotNull(actual); - Assert.Equal(TestRecordKey1, actual.Key); - Assert.Equal("data 1", actual.Data1); - Assert.Equal("data 2", actual.Data2); - Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector1!.Value.ToArray()); - Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector2!.Value.ToArray()); - - mapperMock - .Verify( - x => x.MapFromStorageToDataModel( - It.Is<(string key, JsonNode node)>(x => x.key == TestRecordKey1), - It.Is(x => x.IncludeVectors)), - Times.Once); - } - [Theory] [InlineData(true)] [InlineData(false)] @@ -331,7 +286,7 @@ public async Task CanDeleteManyRecordsWithVectorsAsync(bool useDefinition) var sut = this.CreateRecordCollection(useDefinition); // Act - await sut.DeleteBatchAsync([TestRecordKey1, TestRecordKey2]); + await sut.DeleteAsync([TestRecordKey1, TestRecordKey2]); // Assert var expectedArgs1 = new object[] { TestRecordKey1 }; @@ -389,7 +344,7 @@ public async Task CanUpsertManyRecordsAsync(bool useDefinition) var model2 = CreateModel(TestRecordKey2, true); // Act - var actual = await sut.UpsertBatchAsync([model1, model2]).ToListAsync(); + var actual = await sut.UpsertAsync([model1, model2]); // Assert Assert.NotNull(actual); @@ -407,40 +362,6 @@ public async Task CanUpsertManyRecordsAsync(bool useDefinition) Times.Once); } - [Fact] - public async Task CanUpsertRecordWithCustomMapperAsync() - { - // Arrange. - SetupExecuteMock(this._redisDatabaseMock, "OK"); - - // Arrange mapper mock from data model to JsonNode. - var mapperMock = new Mock>(MockBehavior.Strict); - var jsonNode = """{"data1_json_name":"data 1","Data2": "data 2","vector1_json_name":[1,2,3,4],"Vector2":[1,2,3,4],"NotAnnotated":null}"""; - mapperMock - .Setup(x => x.MapFromDataToStorageModel(It.IsAny())) - .Returns((TestRecordKey1, JsonNode.Parse(jsonNode)!)); - - // Arrange target with custom mapper. - var sut = new RedisJsonVectorStoreRecordCollection( - this._redisDatabaseMock.Object, - TestCollectionName, - new() - { - JsonNodeCustomMapper = mapperMock.Object - }); - - var model = CreateModel(TestRecordKey1, true); - - // Act - await sut.UpsertAsync(model); - - // Assert - mapperMock - .Verify( - x => x.MapFromDataToStorageModel(It.Is(x => x == model)), - Times.Once); - } - [Theory] [InlineData(true)] [InlineData(false)] @@ -466,16 +387,16 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition) var filter = new VectorSearchFilter().EqualTo(nameof(MultiPropsModel.Data1), "data 1"); // Act. - var actual = await sut.VectorizedSearchAsync( + var results = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new[] { 1f, 2f, 3f, 4f }), + top: 5, new() { IncludeVectors = true, OldFilter = filter, VectorProperty = r => r.Vector1, - Top = 5, Skip = 2 - }); + }).ToListAsync(); // Assert. var expectedArgs = new object[] @@ -502,7 +423,6 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition) It.Is(x => x.Where(y => !(y is byte[])).SequenceEqual(expectedArgs.Where(y => !(y is byte[]))))), Times.Once); - var results = await actual.Results.ToListAsync(); Assert.Single(results); Assert.Equal(TestRecordKey1, results.First().Record.Key); Assert.Equal(0.25d, results.First().Score); @@ -512,35 +432,9 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition) Assert.Equal(new float[] { 1, 2, 3, 4 }, results.First().Record.Vector2!.Value.ToArray()); } - /// - /// Tests that the collection can be created even if the definition and the type do not match. - /// In this case, the expectation is that a custom mapper will be provided to map between the - /// schema as defined by the definition and the different data model. - /// - [Fact] - public void CanCreateCollectionWithMismatchedDefinitionAndType() - { - // Arrange. - var definition = new VectorStoreRecordDefinition() - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Id", typeof(string)), - new VectorStoreRecordDataProperty("Text", typeof(string)), - new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { Dimensions = 4 }, - } - }; - - // Act. - var sut = new RedisJsonVectorStoreRecordCollection( - this._redisDatabaseMock.Object, - TestCollectionName, - new() { VectorStoreRecordDefinition = definition, JsonNodeCustomMapper = Mock.Of>() }); - } - - private RedisJsonVectorStoreRecordCollection CreateRecordCollection(bool useDefinition, bool useCustomJsonSerializerOptions = false) + private RedisJsonVectorStoreRecordCollection CreateRecordCollection(bool useDefinition, bool useCustomJsonSerializerOptions = false) { - return new RedisJsonVectorStoreRecordCollection( + return new RedisJsonVectorStoreRecordCollection( this._redisDatabaseMock.Object, TestCollectionName, new() @@ -624,10 +518,10 @@ private static MultiPropsModel CreateModel(string key, bool withVectors) Properties = [ new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("Data1", typeof(string)) { IsFilterable = true, StoragePropertyName = "ignored_data1_storage_name" }, - new VectorStoreRecordDataProperty("Data2", typeof(string)) { IsFilterable = true }, - new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory)) { Dimensions = 4, DistanceFunction = DistanceFunction.CosineDistance, StoragePropertyName = "ignored_vector1_storage_name" }, - new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory)) { Dimensions = 4 } + new VectorStoreRecordDataProperty("Data1", typeof(string)) { IsIndexed = true, StoragePropertyName = "ignored_data1_storage_name" }, + new VectorStoreRecordDataProperty("Data2", typeof(string)) { IsIndexed = true }, + new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory), 4) { DistanceFunction = DistanceFunction.CosineDistance, StoragePropertyName = "ignored_vector1_storage_name" }, + new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory), 4) ] }; @@ -637,14 +531,14 @@ public sealed class MultiPropsModel public string Key { get; set; } = string.Empty; [JsonPropertyName("data1_json_name")] - [VectorStoreRecordData(IsFilterable = true, StoragePropertyName = "ignored_data1_storage_name")] + [VectorStoreRecordData(IsIndexed = true, StoragePropertyName = "ignored_data1_storage_name")] public string Data1 { get; set; } = string.Empty; - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string Data2 { get; set; } = string.Empty; [JsonPropertyName("vector1_json_name")] - [VectorStoreRecordVector(4, DistanceFunction.CosineDistance, StoragePropertyName = "ignored_vector1_storage_name")] + [VectorStoreRecordVector(4, DistanceFunction = DistanceFunction.CosineDistance, StoragePropertyName = "ignored_vector1_storage_name")] public ReadOnlyMemory? Vector1 { get; set; } [VectorStoreRecordVector(4)] diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordMapperTests.cs index fef62c68a530..3bc180a0b788 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordMapperTests.cs @@ -5,6 +5,7 @@ using System.Text.Json; using System.Text.Json.Nodes; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.Redis; using Xunit; @@ -19,10 +20,12 @@ public sealed class RedisJsonVectorStoreRecordMapperTests public void MapsAllFieldsFromDataToStorageModel() { // Arrange. - var sut = new RedisJsonVectorStoreRecordMapper("Key", JsonSerializerOptions.Default); + var model = new VectorStoreRecordJsonModelBuilder(RedisJsonVectorStoreRecordCollection.ModelBuildingOptions) + .Build(typeof(MultiPropsModel), vectorStoreRecordDefinition: null, defaultEmbeddingGenerator: null, JsonSerializerOptions.Default); + var sut = new RedisJsonVectorStoreRecordMapper(model, JsonSerializerOptions.Default); // Act. - var actual = sut.MapFromDataToStorageModel(CreateModel("test key")); + var actual = sut.MapFromDataToStorageModel(CreateModel("test key"), recordIndex: 0, generatedEmbeddings: null); // Assert. Assert.NotNull(actual.Node); @@ -38,10 +41,13 @@ public void MapsAllFieldsFromDataToStorageModel() public void MapsAllFieldsFromDataToStorageModelWithCustomSerializerOptions() { // Arrange. - var sut = new RedisJsonVectorStoreRecordMapper("key", new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + var jsonSerializerOptions = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + var model = new VectorStoreRecordJsonModelBuilder(RedisJsonVectorStoreRecordCollection.ModelBuildingOptions) + .Build(typeof(MultiPropsModel), vectorStoreRecordDefinition: null, defaultEmbeddingGenerator: null, jsonSerializerOptions); + var sut = new RedisJsonVectorStoreRecordMapper(model, jsonSerializerOptions); // Act. - var actual = sut.MapFromDataToStorageModel(CreateModel("test key")); + var actual = sut.MapFromDataToStorageModel(CreateModel("test key"), recordIndex: 0, generatedEmbeddings: null); // Assert. Assert.NotNull(actual.Node); @@ -57,7 +63,9 @@ public void MapsAllFieldsFromDataToStorageModelWithCustomSerializerOptions() public void MapsAllFieldsFromStorageToDataModel() { // Arrange. - var sut = new RedisJsonVectorStoreRecordMapper("Key", JsonSerializerOptions.Default); + var model = new VectorStoreRecordJsonModelBuilder(RedisJsonVectorStoreRecordCollection.ModelBuildingOptions) + .Build(typeof(MultiPropsModel), vectorStoreRecordDefinition: null, defaultEmbeddingGenerator: null, JsonSerializerOptions.Default); + var sut = new RedisJsonVectorStoreRecordMapper(model, JsonSerializerOptions.Default); // Act. var jsonObject = new JsonObject(); @@ -80,7 +88,10 @@ public void MapsAllFieldsFromStorageToDataModel() public void MapsAllFieldsFromStorageToDataModelWithCustomSerializerOptions() { // Arrange. - var sut = new RedisJsonVectorStoreRecordMapper("key", new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + var jsonSerializerOptions = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + var model = new VectorStoreRecordJsonModelBuilder(RedisJsonVectorStoreRecordCollection.ModelBuildingOptions) + .Build(typeof(MultiPropsModel), vectorStoreRecordDefinition: null, defaultEmbeddingGenerator: null, jsonSerializerOptions); + var sut = new RedisJsonVectorStoreRecordMapper(model, jsonSerializerOptions); // Act. var jsonObject = new JsonObject(); @@ -123,10 +134,10 @@ private sealed class MultiPropsModel [VectorStoreRecordData] public string Data2 { get; set; } = string.Empty; - [VectorStoreRecordVector] + [VectorStoreRecordVector(10)] public ReadOnlyMemory? Vector1 { get; set; } - [VectorStoreRecordVector] + [VectorStoreRecordVector(10)] public ReadOnlyMemory? Vector2 { get; set; } public string NotAnnotated { get; set; } = string.Empty; diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisKernelBuilderExtensionsTests.cs index d83ab4ca403b..a66beeb7183a 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisKernelBuilderExtensionsTests.cs @@ -69,20 +69,20 @@ private void AssertVectorStoreCreated() Assert.IsType(vectorStore); } - private void AssertHashSetVectorStoreRecordCollectionCreated() + private void AssertHashSetVectorStoreRecordCollectionCreated() where TRecord : notnull { var kernel = this._kernelBuilder.Build(); var collection = kernel.Services.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); } - private void AssertJsonVectorStoreRecordCollectionCreated() + private void AssertJsonVectorStoreRecordCollectionCreated() where TRecord : notnull { var kernel = this._kernelBuilder.Build(); var collection = kernel.Services.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisMemoryStoreTests.cs index 5c63e568a3a9..892c8c0d495f 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisMemoryStoreTests.cs @@ -20,6 +20,7 @@ namespace SemanticKernel.Connectors.Redis.UnitTests; /// /// Unit tests of . /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class RedisMemoryStoreTests { private readonly Mock _mockDatabase; diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisServiceCollectionExtensionsTests.cs index 8c6455b4a226..c4cc03b79d68 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisServiceCollectionExtensionsTests.cs @@ -69,20 +69,20 @@ private void AssertVectorStoreCreated() Assert.IsType(vectorStore); } - private void AssertHashSetVectorStoreRecordCollectionCreated() + private void AssertHashSetVectorStoreRecordCollectionCreated() where TRecord : notnull { var serviceProvider = this._serviceCollection.BuildServiceProvider(); var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); } - private void AssertJsonVectorStoreRecordCollectionCreated() + private void AssertJsonVectorStoreRecordCollectionCreated() where TRecord : notnull { var serviceProvider = this._serviceCollection.BuildServiceProvider(); var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionCreateMappingTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionCreateMappingTests.cs index ef3ba3447bad..c519ec4bf4bf 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionCreateMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionCreateMappingTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using NRedisStack.Search; using Xunit; using static NRedisStack.Search.Schema; @@ -20,39 +21,33 @@ public class RedisVectorStoreCollectionCreateMappingTests public void MapToSchemaCreatesSchema(bool useDollarPrefix) { // Arrange. - var properties = new VectorStoreRecordProperty[] - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - - new VectorStoreRecordDataProperty("FilterableString", typeof(string)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("FullTextSearchableString", typeof(string)) { IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("FilterableStringEnumerable", typeof(string[])) { IsFilterable = true }, - new VectorStoreRecordDataProperty("FullTextSearchableStringEnumerable", typeof(string[])) { IsFullTextSearchable = true }, - - new VectorStoreRecordDataProperty("FilterableInt", typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("FilterableNullableInt", typeof(int)) { IsFilterable = true }, - - new VectorStoreRecordDataProperty("NonFilterableString", typeof(string)), - - new VectorStoreRecordVectorProperty("VectorDefaultIndexingOptions", typeof(ReadOnlyMemory)) { Dimensions = 10 }, - new VectorStoreRecordVectorProperty("VectorSpecificIndexingOptions", typeof(ReadOnlyMemory)) { Dimensions = 20, IndexKind = IndexKind.Flat, DistanceFunction = DistanceFunction.EuclideanSquaredDistance }, - }; - - var storagePropertyNames = new Dictionary() - { - { "FilterableString", "FilterableString" }, - { "FullTextSearchableString", "FullTextSearchableString" }, - { "FilterableStringEnumerable", "FilterableStringEnumerable" }, - { "FullTextSearchableStringEnumerable", "FullTextSearchableStringEnumerable" }, - { "FilterableInt", "FilterableInt" }, - { "FilterableNullableInt", "FilterableNullableInt" }, - { "NonFilterableString", "NonFilterableString" }, - { "VectorDefaultIndexingOptions", "VectorDefaultIndexingOptions" }, - { "VectorSpecificIndexingOptions", "vector_specific_indexing_options" }, - }; + VectorStoreRecordPropertyModel[] properties = + [ + new VectorStoreRecordKeyPropertyModel("Key", typeof(string)), + + new VectorStoreRecordDataPropertyModel("FilterableString", typeof(string)) { IsIndexed = true }, + new VectorStoreRecordDataPropertyModel("FullTextSearchableString", typeof(string)) { IsFullTextIndexed = true }, + new VectorStoreRecordDataPropertyModel("FilterableStringEnumerable", typeof(string[])) { IsIndexed = true }, + new VectorStoreRecordDataPropertyModel("FullTextSearchableStringEnumerable", typeof(string[])) { IsFullTextIndexed = true }, + + new VectorStoreRecordDataPropertyModel("FilterableInt", typeof(int)) { IsIndexed = true }, + new VectorStoreRecordDataPropertyModel("FilterableNullableInt", typeof(int)) { IsIndexed = true }, + + new VectorStoreRecordDataPropertyModel("NonFilterableString", typeof(string)), + + new VectorStoreRecordVectorPropertyModel("VectorDefaultIndexingOptions", typeof(ReadOnlyMemory)) { Dimensions = 10, EmbeddingType = typeof(ReadOnlyMemory) }, + new VectorStoreRecordVectorPropertyModel("VectorSpecificIndexingOptions", typeof(ReadOnlyMemory)) + { + Dimensions = 20, + IndexKind = IndexKind.Flat, + DistanceFunction = DistanceFunction.EuclideanSquaredDistance, + StorageName = "vector_specific_indexing_options", + EmbeddingType = typeof(ReadOnlyMemory) + } + ]; // Act. - var schema = RedisVectorStoreCollectionCreateMapping.MapToSchema(properties, storagePropertyNames, useDollarPrefix); + var schema = RedisVectorStoreCollectionCreateMapping.MapToSchema(properties, useDollarPrefix); // Assert. Assert.NotNull(schema); @@ -103,24 +98,11 @@ public void MapToSchemaCreatesSchema(bool useDollarPrefix) Assert.Equal("L2", ((VectorField)schema.Fields[7]).Attributes!["DISTANCE_METRIC"]); } - [Theory] - [InlineData(null)] - [InlineData(0)] - public void MapToSchemaThrowsOnInvalidVectorDimensions(int? dimensions) - { - // Arrange. - var properties = new VectorStoreRecordProperty[] { new VectorStoreRecordVectorProperty("VectorProperty", typeof(ReadOnlyMemory)) { Dimensions = dimensions } }; - var storagePropertyNames = new Dictionary() { { "VectorProperty", "VectorProperty" } }; - - // Act and assert. - Assert.Throws(() => RedisVectorStoreCollectionCreateMapping.MapToSchema(properties, storagePropertyNames, true)); - } - [Fact] public void GetSDKIndexKindThrowsOnUnsupportedIndexKind() { // Arrange. - var vectorProperty = new VectorStoreRecordVectorProperty("VectorProperty", typeof(ReadOnlyMemory)) { IndexKind = "Unsupported" }; + var vectorProperty = new VectorStoreRecordVectorPropertyModel("VectorProperty", typeof(ReadOnlyMemory)) { IndexKind = "Unsupported" }; // Act and assert. Assert.Throws(() => RedisVectorStoreCollectionCreateMapping.GetSDKIndexKind(vectorProperty)); @@ -130,7 +112,7 @@ public void GetSDKIndexKindThrowsOnUnsupportedIndexKind() public void GetSDKDistanceAlgorithmThrowsOnUnsupportedDistanceFunction() { // Arrange. - var vectorProperty = new VectorStoreRecordVectorProperty("VectorProperty", typeof(ReadOnlyMemory)) { DistanceFunction = "Unsupported" }; + var vectorProperty = new VectorStoreRecordVectorPropertyModel("VectorProperty", typeof(ReadOnlyMemory)) { DistanceFunction = "Unsupported" }; // Act and assert. Assert.Throws(() => RedisVectorStoreCollectionCreateMapping.GetSDKDistanceAlgorithm(vectorProperty)); diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs index 087b707a4b7c..f75a00a86354 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs @@ -2,9 +2,9 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Runtime.InteropServices; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Xunit; namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; @@ -66,17 +66,18 @@ public void BuildQueryBuildsRedisQueryWithDefaults() // Arrange. var floatVector = new ReadOnlyMemory(new float[] { 1.0f, 2.0f, 3.0f }); var byteArray = MemoryMarshal.AsBytes(floatVector.Span).ToArray(); - var storagePropertyNames = new Dictionary() - { - { "Vector", "storage_Vector" }, - }; + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) + ]); // Act. - var query = RedisVectorStoreCollectionSearchMapping.BuildQuery(byteArray, new VectorSearchOptions(), storagePropertyNames, storagePropertyNames.Values.Single(), null); + var query = RedisVectorStoreCollectionSearchMapping.BuildQuery(byteArray, top: 3, new VectorSearchOptions(), model, model.VectorProperty, null); // Assert. Assert.NotNull(query); - Assert.Equal("*=>[KNN 3 @storage_Vector $embedding AS vector_score]", query.QueryString); + Assert.Equal("*=>[KNN 3 @Vector $embedding AS vector_score]", query.QueryString); Assert.Equal("vector_score", query.SortBy); Assert.True(query.WithScores); Assert.Equal(2, query.dialect); @@ -88,15 +89,16 @@ public void BuildQueryBuildsRedisQueryWithCustomVectorName() // Arrange. var floatVector = new ReadOnlyMemory(new float[] { 1.0f, 2.0f, 3.0f }); var byteArray = MemoryMarshal.AsBytes(floatVector.Span).ToArray(); - var vectorSearchOptions = new VectorSearchOptions { Top = 5, Skip = 3 }; - var storagePropertyNames = new Dictionary() - { - { "Vector", "storage_Vector" }, - }; + var vectorSearchOptions = new VectorSearchOptions { Skip = 3 }; + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) { StoragePropertyName = "storage_Vector" } + ]); var selectFields = new string[] { "storage_Field1", "storage_Field2" }; // Act. - var query = RedisVectorStoreCollectionSearchMapping.BuildQuery(byteArray, vectorSearchOptions, storagePropertyNames, storagePropertyNames.Values.Single(), selectFields); + var query = RedisVectorStoreCollectionSearchMapping.BuildQuery(byteArray, top: 5, vectorSearchOptions, model, model.VectorProperty, selectFields); // Assert. Assert.NotNull(query); @@ -124,13 +126,15 @@ public void BuildFilterBuildsEqualityFilter(string filterType) _ => throw new InvalidOperationException(), }; - var storagePropertyNames = new Dictionary() - { - { "Data1", "storage_Data1" }, - }; + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("Data1", typeof(string)) { StoragePropertyName = "storage_Data1" }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) + ]); // Act. - var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, storagePropertyNames); + var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, model); // Assert. switch (filterType) @@ -157,15 +161,17 @@ public void BuildFilterThrowsForInvalidValueType() { // Arrange. var basicVectorSearchFilter = new VectorSearchFilter().EqualTo("Data1", true); - var storagePropertyNames = new Dictionary() - { - { "Data1", "storage_Data1" }, - }; + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("Data1", typeof(string)) { StoragePropertyName = "storage_Data1" }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) + ]); // Act & Assert. Assert.Throws(() => { - var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, storagePropertyNames); + var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, model); }); } @@ -174,22 +180,24 @@ public void BuildFilterThrowsForUnknownFieldName() { // Arrange. var basicVectorSearchFilter = new VectorSearchFilter().EqualTo("UnknownData", "value"); - var storagePropertyNames = new Dictionary() - { - { "Data1", "storage_Data1" }, - }; + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("Data1", typeof(string)) { StoragePropertyName = "storage_Data1" }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) + ]); // Act & Assert. Assert.Throws(() => { - var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, storagePropertyNames); + var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, model); }); } [Fact] public void ResolveDistanceFunctionReturnsCosineSimilarityIfNoDistanceFunctionSpecified() { - var property = new VectorStoreRecordVectorProperty("Prop", typeof(ReadOnlyMemory)); + var property = new VectorStoreRecordVectorPropertyModel("Prop", typeof(ReadOnlyMemory)); // Act. var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(property); @@ -201,7 +209,7 @@ public void ResolveDistanceFunctionReturnsCosineSimilarityIfNoDistanceFunctionSp [Fact] public void ResolveDistanceFunctionReturnsDistanceFunctionFromProvidedProperty() { - var property = new VectorStoreRecordVectorProperty("Prop", typeof(ReadOnlyMemory)) { DistanceFunction = DistanceFunction.DotProductSimilarity }; + var property = new VectorStoreRecordVectorPropertyModel("Prop", typeof(ReadOnlyMemory)) { DistanceFunction = DistanceFunction.DotProductSimilarity }; // Act. var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(property); @@ -232,4 +240,11 @@ public void GetOutputScoreFromRedisScoreLeavesNonConsineSimilarityUntouched(stri #pragma warning disable CA1812 // An internal class that is apparently never instantiated. If so, remove the code from the assembly. private sealed class DummyType; #pragma warning restore CA1812 + + private static VectorStoreRecordModel BuildModel(List properties) + => new VectorStoreRecordModelBuilder(RedisHashSetVectorStoreRecordCollection.ModelBuildingOptions) + .Build( + typeof(Dictionary), + new() { Properties = properties }, + defaultEmbeddingGenerator: null); } diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreTests.cs index baf2564c81a2..e2347c4ea989 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreTests.cs @@ -22,6 +22,7 @@ public class RedisVectorStoreTests public RedisVectorStoreTests() { this._redisDatabaseMock = new Mock(MockBehavior.Strict); + this._redisDatabaseMock.Setup(l => l.Database).Returns(0); var batchMock = new Mock(); this._redisDatabaseMock.Setup(x => x.CreateBatch(It.IsAny())).Returns(batchMock.Object); @@ -38,7 +39,7 @@ public void GetCollectionReturnsJsonCollection() // Assert. Assert.NotNull(actual); - Assert.IsType>>(actual); + Assert.IsType>>(actual); } [Fact] @@ -52,7 +53,7 @@ public void GetCollectionReturnsHashSetCollection() // Assert. Assert.NotNull(actual); - Assert.IsType>>(actual); + Assert.IsType>>(actual); } #pragma warning disable CS0618 // IRedisVectorStoreRecordCollectionFactory is obsolete diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Connectors.Sqlite.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Connectors.Sqlite.UnitTests.csproj index 015df8f6e56d..6bc1914ba63b 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Connectors.Sqlite.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Connectors.Sqlite.UnitTests.csproj @@ -8,7 +8,8 @@ enable disable false - $(NoWarn);SKEXP0001,SKEXP0020,VSTHRD111,CA2007,CS1591 + $(NoWarn);SKEXP0001,VSTHRD111,CA2007,CS1591 + $(NoWarn);MEVD9001 diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDBConnection.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDBConnection.cs deleted file mode 100644 index 7c318e1ef413..000000000000 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDBConnection.cs +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Data; -using System.Data.Common; - -namespace SemanticKernel.Connectors.Sqlite.UnitTests; - -#pragma warning disable CS8618, CS8765 - -internal sealed class FakeDBConnection(DbCommand command) : DbConnection -{ - public override string ConnectionString { get; set; } - - public override string Database => throw new NotImplementedException(); - - public override string DataSource => throw new NotImplementedException(); - - public override string ServerVersion => throw new NotImplementedException(); - - public override ConnectionState State => throw new NotImplementedException(); - - public override void ChangeDatabase(string databaseName) => throw new NotImplementedException(); - - public override void Close() => throw new NotImplementedException(); - - public override void Open() => throw new NotImplementedException(); - - protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw new NotImplementedException(); - - protected override DbCommand CreateDbCommand() => command; -} diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbCommand.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbCommand.cs deleted file mode 100644 index df6062d9a4c1..000000000000 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbCommand.cs +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Data; -using System.Data.Common; - -namespace SemanticKernel.Connectors.Sqlite.UnitTests; - -#pragma warning disable CS8618, CS8765 - -internal sealed class FakeDbCommand( - DbDataReader? dataReader = null, - object? scalarResult = null) : DbCommand -{ - public int ExecuteNonQueryCallCount { get; private set; } = 0; - - private readonly FakeDbParameterCollection _parameterCollection = []; - - public override string CommandText { get; set; } - public override int CommandTimeout { get; set; } - public override CommandType CommandType { get; set; } - public override bool DesignTimeVisible { get; set; } - public override UpdateRowSource UpdatedRowSource { get; set; } - protected override DbConnection? DbConnection { get; set; } - - protected override DbParameterCollection DbParameterCollection => this._parameterCollection; - - protected override DbTransaction? DbTransaction { get; set; } - - public override void Cancel() => throw new NotImplementedException(); - - public override int ExecuteNonQuery() - { - this.ExecuteNonQueryCallCount++; - return 0; - } - - public override object? ExecuteScalar() => scalarResult; - - public override void Prepare() => throw new NotImplementedException(); - - protected override DbParameter CreateDbParameter() => throw new NotImplementedException(); - - protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => dataReader ?? throw new NotImplementedException(); -} diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbParameterCollection.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbParameterCollection.cs deleted file mode 100644 index 246b97a3360b..000000000000 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbParameterCollection.cs +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections; -using System.Collections.Generic; -using System.Data.Common; - -namespace SemanticKernel.Connectors.Sqlite.UnitTests; - -#pragma warning disable CA1812 - -internal sealed class FakeDbParameterCollection : DbParameterCollection -{ - private readonly List _parameters = []; - - public override int Count => this._parameters.Count; - - public override object SyncRoot => throw new NotImplementedException(); - - public override int Add(object value) - { - this._parameters.Add(value); - return default; - } - - public override void AddRange(Array values) - { - this._parameters.AddRange([.. values]); - } - - public override void Clear() - { - this._parameters.Clear(); - } - - public override bool Contains(object value) - { - return this._parameters.Contains(value); - } - - public override bool Contains(string value) - { - return this._parameters.Contains(value); - } - - public override void CopyTo(Array array, int index) - { - this._parameters.CopyTo([.. array], index); - } - - public override IEnumerator GetEnumerator() - { - return this._parameters.GetEnumerator(); - } - - public override int IndexOf(object value) - { - return this._parameters.IndexOf(value); - } - - public override int IndexOf(string parameterName) - { - return this._parameters.IndexOf(parameterName); - } - - public override void Insert(int index, object value) - { - this._parameters.Insert(index, value); - } - - public override void Remove(object value) - { - this._parameters.Remove(value); - } - - public override void RemoveAt(int index) - { - this._parameters.RemoveAt(index); - } - - public override void RemoveAt(string parameterName) - { - throw new NotImplementedException(); - } - - protected override DbParameter GetParameter(int index) - { - return (this._parameters[index] as DbParameter)!; - } - - protected override DbParameter GetParameter(string parameterName) - { - throw new NotImplementedException(); - } - - protected override void SetParameter(int index, DbParameter value) - { - this._parameters[index] = value; - } - - protected override void SetParameter(string parameterName, DbParameter value) - { - throw new NotImplementedException(); - } -} diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteConditionsTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteConditionsTests.cs index 7f02575e9b88..aab78c2150f4 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteConditionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteConditionsTests.cs @@ -22,9 +22,9 @@ public void SqliteWhereEqualsConditionWithoutParameterNamesThrowsException() } [Theory] - [InlineData(null, "Name = @Name0")] - [InlineData("", "Name = @Name0")] - [InlineData("TableName", "TableName.Name = @Name0")] + [InlineData(null, "\"Name\" = @Name0")] + [InlineData("", "\"Name\" = @Name0")] + [InlineData("TableName", "\"TableName\".\"Name\" = @Name0")] public void SqliteWhereEqualsConditionBuildsValidQuery(string? tableName, string expectedQuery) { // Arrange @@ -48,9 +48,9 @@ public void SqliteWhereInConditionWithoutParameterNamesThrowsException() } [Theory] - [InlineData(null, "Name IN (@Name0, @Name1)")] - [InlineData("", "Name IN (@Name0, @Name1)")] - [InlineData("TableName", "TableName.Name IN (@Name0, @Name1)")] + [InlineData(null, "\"Name\" IN (@Name0, @Name1)")] + [InlineData("", "\"Name\" IN (@Name0, @Name1)")] + [InlineData("TableName", "\"TableName\".\"Name\" IN (@Name0, @Name1)")] public void SqliteWhereInConditionBuildsValidQuery(string? tableName, string expectedQuery) { // Arrange @@ -74,9 +74,9 @@ public void SqliteWhereMatchConditionWithoutParameterNamesThrowsException() } [Theory] - [InlineData(null, "Name MATCH @Name0")] - [InlineData("", "Name MATCH @Name0")] - [InlineData("TableName", "TableName.Name MATCH @Name0")] + [InlineData(null, "\"Name\" MATCH @Name0")] + [InlineData("", "\"Name\" MATCH @Name0")] + [InlineData("TableName", "\"TableName\".\"Name\" MATCH @Name0")] public void SqliteWhereMatchConditionBuildsValidQuery(string? tableName, string expectedQuery) { // Arrange diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteGenericDataModelMapperTests.cs deleted file mode 100644 index 3985672bd60e..000000000000 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteGenericDataModelMapperTests.cs +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Sqlite; -using Xunit; - -namespace SemanticKernel.Connectors.Sqlite.UnitTests; - -/// -/// Unit tests for class. -/// -public sealed class SqliteGenericDataModelMapperTests -{ - [Fact] - public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() - { - // Arrange - var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); - var dataModel = GetGenericDataModel("key"); - - var mapper = new SqliteGenericDataModelMapper(propertyReader); - - // Act - var result = mapper.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal("key", result["Key"]); - Assert.Equal("Value1", result["StringProperty"]); - Assert.Equal(5, result["IntProperty"]); - - var vectorBytes = result["FloatVector"] as byte[]; - - Assert.NotNull(vectorBytes); - Assert.True(vectorBytes.Length > 0); - } - - [Fact] - public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() - { - // Arrange - var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); - var dataModel = GetGenericDataModel(1); - - var mapper = new SqliteGenericDataModelMapper(propertyReader); - - // Act - var result = mapper.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal((ulong)1, result["Key"]); - Assert.Equal("Value1", result["StringProperty"]); - Assert.Equal(5, result["IntProperty"]); - - var vectorBytes = result["FloatVector"] as byte[]; - - Assert.NotNull(vectorBytes); - Assert.True(vectorBytes.Length > 0); - } - - [Theory] - [InlineData(true)] - [InlineData(false)] - public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool includeVectors) - { - // Arrange - var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); - var storageVector = SqliteVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); - - var storageModel = new Dictionary - { - ["Key"] = "key", - ["StringProperty"] = "Value1", - ["IntProperty"] = 5, - ["FloatVector"] = storageVector - }; - - var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); - - var mapper = new SqliteGenericDataModelMapper(propertyReader); - - // Act - var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); - - // Assert - Assert.Equal("key", result.Key); - Assert.Equal("Value1", result.Data["StringProperty"]); - Assert.Equal(5, result.Data["IntProperty"]); - - if (includeVectors) - { - Assert.NotNull(result.Vectors["FloatVector"]); - Assert.Equal(vector.ToArray(), ((ReadOnlyMemory)result.Vectors["FloatVector"]!).ToArray()); - } - else - { - Assert.False(result.Vectors.ContainsKey("FloatVector")); - } - } - - [Theory] - [InlineData(true)] - [InlineData(false)] - public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool includeVectors) - { - // Arrange - var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); - var storageVector = SqliteVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); - - var storageModel = new Dictionary - { - ["Key"] = (ulong)1, - ["StringProperty"] = "Value1", - ["IntProperty"] = 5, - ["FloatVector"] = storageVector - }; - - var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); - - IVectorStoreRecordMapper, Dictionary> mapper = new SqliteGenericDataModelMapper(propertyReader); - - // Act - var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); - - // Assert - Assert.Equal((ulong)1, result.Key); - Assert.Equal("Value1", result.Data["StringProperty"]); - Assert.Equal(5, result.Data["IntProperty"]); - - if (includeVectors) - { - Assert.NotNull(result.Vectors["FloatVector"]); - Assert.Equal(vector.ToArray(), ((ReadOnlyMemory)result.Vectors["FloatVector"]!).ToArray()); - } - else - { - Assert.False(result.Vectors.ContainsKey("FloatVector")); - } - } - - #region private - - private static VectorStoreRecordDefinition GetRecordDefinition() - { - return new VectorStoreRecordDefinition - { - Properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(TKey)), - new VectorStoreRecordDataProperty("StringProperty", typeof(string)), - new VectorStoreRecordDataProperty("IntProperty", typeof(int)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - } - }; - } - - private static VectorStoreGenericDataModel GetGenericDataModel(TKey key) - { - return new VectorStoreGenericDataModel(key) - { - Data = new() - { - ["StringProperty"] = "Value1", - ["IntProperty"] = 5 - }, - Vectors = new() - { - ["FloatVector"] = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]) - } - }; - } - - private static VectorStoreRecordPropertyReader GetPropertyReader(VectorStoreRecordDefinition definition) - { - return new VectorStoreRecordPropertyReader(typeof(TRecord), definition, new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true - }); - } - - #endregion -} diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteHotel.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteHotel.cs index 8adb64a8bc88..0deb6f8ade9e 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteHotel.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteHotel.cs @@ -12,7 +12,7 @@ public class SqliteHotel() public TKey? HotelId { get; init; } /// A string metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string? HotelName { get; set; } /// An int metadata field. @@ -32,6 +32,6 @@ public class SqliteHotel() public string? Description { get; set; } /// A vector field. - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.EuclideanDistance)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.EuclideanDistance)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteMemoryStoreTests.cs index 5086709937ac..c231e711d6bf 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteMemoryStoreTests.cs @@ -16,6 +16,7 @@ namespace SemanticKernel.Connectors.Sqlite.UnitTests; /// Unit tests of . /// [Collection("Sequential")] +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public sealed class SqliteMemoryStoreTests : IDisposable { private const string DatabaseFile = "SqliteMemoryStoreTests.db"; diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteServiceCollectionExtensionsTests.cs index 69488cf4d8d4..236fd84616f2 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteServiceCollectionExtensionsTests.cs @@ -1,12 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Data; -using Microsoft.Data.Sqlite; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.Sqlite; -using Moq; using Xunit; namespace SemanticKernel.Connectors.Sqlite.UnitTests; @@ -18,23 +15,11 @@ public sealed class SqliteServiceCollectionExtensionsTests { private readonly IServiceCollection _serviceCollection = new ServiceCollection(); - [Theory] - [InlineData(ConnectionState.Open)] - [InlineData(ConnectionState.Closed)] - public void AddVectorStoreRegistersClass(ConnectionState connectionState) + [Fact] + public void AddVectorStoreRegistersClass() { - // Arrange - var expectedOpenCalls = connectionState == ConnectionState.Closed ? 1 : 0; - - var mockConnection = new Mock(); - - mockConnection.Setup(l => l.State).Returns(connectionState); - mockConnection.Setup(l => l.Open()); - - this._serviceCollection.AddTransient((_) => mockConnection.Object); - // Act - this._serviceCollection.AddSqliteVectorStore(); + this._serviceCollection.AddSqliteVectorStore("Data Source=:memory:"); var serviceProvider = this._serviceCollection.BuildServiceProvider(); var vectorStore = serviceProvider.GetRequiredService(); @@ -42,78 +27,42 @@ public void AddVectorStoreRegistersClass(ConnectionState connectionState) // Assert Assert.NotNull(vectorStore); Assert.IsType(vectorStore); - - mockConnection.Verify(l => l.Open(), Times.Exactly(expectedOpenCalls)); } - [Theory] - [InlineData(ConnectionState.Open)] - [InlineData(ConnectionState.Closed)] - public void AddVectorStoreRecordCollectionWithStringKeyRegistersClass(ConnectionState connectionState) + [Fact] + public void AddVectorStoreRecordCollectionWithStringKeyRegistersClass() { - // Arrange - var expectedOpenCalls = connectionState == ConnectionState.Closed ? 1 : 0; - - var mockConnection = new Mock(); - - mockConnection.SetupSequence(l => l.State) - .Returns(connectionState) - .Returns(ConnectionState.Open); - - mockConnection.Setup(l => l.Open()); - - this._serviceCollection.AddTransient((_) => mockConnection.Object); - // Act - this._serviceCollection.AddSqliteVectorStoreRecordCollection("testcollection"); + this._serviceCollection.AddSqliteVectorStoreRecordCollection("testcollection", "Data Source=:memory:"); var serviceProvider = this._serviceCollection.BuildServiceProvider(); // Assert var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); - - mockConnection.Verify(l => l.Open(), Times.Exactly(expectedOpenCalls)); + Assert.IsType>(vectorizedSearch); } - [Theory] - [InlineData(ConnectionState.Open)] - [InlineData(ConnectionState.Closed)] - public void AddVectorStoreRecordCollectionWithNumericKeyRegistersClass(ConnectionState connectionState) + [Fact] + public void AddVectorStoreRecordCollectionWithNumericKeyRegistersClass() { - // Arrange - var expectedOpenCalls = connectionState == ConnectionState.Closed ? 1 : 0; - - var mockConnection = new Mock(); - - mockConnection.SetupSequence(l => l.State) - .Returns(connectionState) - .Returns(ConnectionState.Open); - - mockConnection.Setup(l => l.Open()); - - this._serviceCollection.AddTransient((_) => mockConnection.Object); - // Act - this._serviceCollection.AddSqliteVectorStoreRecordCollection("testcollection"); + this._serviceCollection.AddSqliteVectorStoreRecordCollection("testcollection", "Data Source=:memory:"); var serviceProvider = this._serviceCollection.BuildServiceProvider(); // Assert var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); - - mockConnection.Verify(l => l.Open(), Times.Exactly(expectedOpenCalls)); + Assert.IsType>(vectorizedSearch); } #region private diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs index 314e1e162420..22948ed099bc 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs @@ -2,6 +2,9 @@ using System; using System.Collections.Generic; +using Microsoft.Data.Sqlite; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.Sqlite; using Xunit; @@ -12,15 +15,13 @@ namespace SemanticKernel.Connectors.Sqlite.UnitTests; /// public sealed class SqliteVectorStoreCollectionCommandBuilderTests : IDisposable { - private readonly FakeDbCommand _command; - private readonly FakeDBConnection _connection; - private readonly SqliteVectorStoreCollectionCommandBuilder _commandBuilder; + private readonly SqliteCommand _command; + private readonly SqliteConnection _connection; public SqliteVectorStoreCollectionCommandBuilderTests() { - this._command = new(); - this._connection = new(this._command); - this._commandBuilder = new(this._connection); + this._command = new() { Connection = this._connection }; + this._connection = new(); } [Fact] @@ -30,7 +31,7 @@ public void ItBuildsTableCountCommand() const string TableName = "TestTable"; // Act - var command = this._commandBuilder.BuildTableCountCommand(TableName); + var command = SqliteVectorStoreCollectionCommandBuilder.BuildTableCountCommand(this._connection, TableName); // Assert Assert.Equal("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=@tableName;", command.CommandText); @@ -53,7 +54,7 @@ public void ItBuildsCreateTableCommand(bool ifNotExists) }; // Act - var command = this._commandBuilder.BuildCreateTableCommand(TableName, columns, ifNotExists); + var command = SqliteVectorStoreCollectionCommandBuilder.BuildCreateTableCommand(this._connection, TableName, columns, ifNotExists); // Assert Assert.Contains("CREATE TABLE", command.CommandText); @@ -61,8 +62,8 @@ public void ItBuildsCreateTableCommand(bool ifNotExists) Assert.Equal(ifNotExists, command.CommandText.Contains("IF NOT EXISTS")); - Assert.Contains("Column1 Type1 PRIMARY KEY", command.CommandText); - Assert.Contains("Column2 Type2 distance_metric=l2", command.CommandText); + Assert.Contains("\"Column1\" Type1 PRIMARY KEY", command.CommandText); + Assert.Contains("\"Column2\" Type2 distance_metric=l2", command.CommandText); } [Theory] @@ -81,7 +82,7 @@ public void ItBuildsCreateVirtualTableCommand(bool ifNotExists) }; // Act - var command = this._commandBuilder.BuildCreateVirtualTableCommand(TableName, columns, ifNotExists, ExtensionName); + var command = SqliteVectorStoreCollectionCommandBuilder.BuildCreateVirtualTableCommand(this._connection, TableName, columns, ifNotExists, ExtensionName); // Assert Assert.Contains("CREATE VIRTUAL TABLE", command.CommandText); @@ -101,10 +102,10 @@ public void ItBuildsDropTableCommand() const string TableName = "TestTable"; // Act - var command = this._commandBuilder.BuildDropTableCommand(TableName); + var command = SqliteVectorStoreCollectionCommandBuilder.BuildDropTableCommand(this._connection, TableName); // Assert - Assert.Equal("DROP TABLE IF EXISTS [TestTable];", command.CommandText); + Assert.Equal("DROP TABLE IF EXISTS \"TestTable\";", command.CommandText); } [Theory] @@ -116,7 +117,13 @@ public void ItBuildsInsertCommand(bool replaceIfExists) const string TableName = "TestTable"; const string RowIdentifier = "Id"; - var columnNames = new List { "Id", "Name", "Age", "Address" }; + VectorStoreRecordPropertyModel[] properties = + [ + new VectorStoreRecordKeyPropertyModel("Id", typeof(string)), + new VectorStoreRecordDataPropertyModel("Name", typeof(string)), + new VectorStoreRecordDataPropertyModel("Age", typeof(int)), + new VectorStoreRecordDataPropertyModel("Address", typeof(string)), + ]; var records = new List> { new() { ["Id"] = "IdValue1", ["Name"] = "NameValue1", ["Age"] = "AgeValue1", ["Address"] = "AddressValue1" }, @@ -124,17 +131,19 @@ public void ItBuildsInsertCommand(bool replaceIfExists) }; // Act - var command = this._commandBuilder.BuildInsertCommand( + var command = SqliteVectorStoreCollectionCommandBuilder.BuildInsertCommand( + this._connection, TableName, RowIdentifier, - columnNames, + properties, records, + data: true, replaceIfExists); // Assert Assert.Equal(replaceIfExists, command.CommandText.Contains("OR REPLACE")); - Assert.Contains($"INTO {TableName} (Id, Name, Age, Address)", command.CommandText); + Assert.Contains($"INTO \"{TableName}\" (\"Id\", \"Name\", \"Age\", \"Address\")", command.CommandText); Assert.Contains("VALUES (@Id0, @Name0, @Age0, @Address0)", command.CommandText); Assert.Contains("VALUES (@Id1, @Name1, @Age1, @Address1)", command.CommandText); Assert.Contains("RETURNING Id", command.CommandText); @@ -173,24 +182,35 @@ public void ItBuildsSelectCommand(string? orderByPropertyName) // Arrange const string TableName = "TestTable"; - var columnNames = new List { "Id", "Name", "Age", "Address" }; + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Id", typeof(string)), + new VectorStoreRecordDataProperty("Name", typeof(string)), + new VectorStoreRecordDataProperty("Age", typeof(string)), + new VectorStoreRecordDataProperty("Address", typeof(string)), + ]); var conditions = new List { new SqliteWhereEqualsCondition("Name", "NameValue"), new SqliteWhereInCondition("Age", [10, 20, 30]), }; + GetFilteredRecordOptions> filterOptions = new(); + if (!string.IsNullOrWhiteSpace(orderByPropertyName)) + { + filterOptions.OrderBy.Ascending(record => record[orderByPropertyName]); + } // Act - var command = this._commandBuilder.BuildSelectCommand(TableName, columnNames, conditions, orderByPropertyName); + var command = SqliteVectorStoreCollectionCommandBuilder.BuildSelectDataCommand>(this._connection, TableName, model, conditions, filterOptions); // Assert - Assert.Contains("SELECT Id, Name, Age, Address", command.CommandText); - Assert.Contains($"FROM {TableName}", command.CommandText); + Assert.Contains("SELECT \"Id\",\"Name\",\"Age\",\"Address\"", command.CommandText); + Assert.Contains($"FROM \"{TableName}\"", command.CommandText); - Assert.Contains("Name = @Name0", command.CommandText); - Assert.Contains("Age IN (@Age0, @Age1, @Age2)", command.CommandText); + Assert.Contains("\"Name\" = @Name0", command.CommandText); + Assert.Contains("\"Age\" IN (@Age0, @Age1, @Age2)", command.CommandText); - Assert.Equal(!string.IsNullOrWhiteSpace(orderByPropertyName), command.CommandText.Contains($"ORDER BY {orderByPropertyName}")); + Assert.Equal(!string.IsNullOrWhiteSpace(orderByPropertyName), command.CommandText.Contains($"ORDER BY \"{orderByPropertyName}\"")); Assert.Equal("@Name0", command.Parameters[0].ParameterName); Assert.Equal("NameValue", command.Parameters[0].Value); @@ -212,41 +232,50 @@ public void ItBuildsSelectCommand(string? orderByPropertyName) public void ItBuildsSelectLeftJoinCommand(string? orderByPropertyName) { // Arrange - const string LeftTable = "LeftTable"; - const string RightTable = "RightTable"; + const string DataTable = "DataTable"; + const string VectorTable = "VectorTable"; const string JoinColumnName = "Id"; - var leftTablePropertyNames = new List { "Id", "Name" }; - var rightTablePropertyNames = new List { "Age", "Address" }; + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("Id", typeof(string)), + new VectorStoreRecordDataProperty("Name", typeof(string)), + new VectorStoreRecordVectorProperty("Age", typeof(ReadOnlyMemory), 10), + new VectorStoreRecordVectorProperty("Address", typeof(ReadOnlyMemory), 10), + ]); var conditions = new List { new SqliteWhereEqualsCondition("Name", "NameValue"), new SqliteWhereInCondition("Age", [10, 20, 30]), }; + GetFilteredRecordOptions> filterOptions = new(); + if (!string.IsNullOrWhiteSpace(orderByPropertyName)) + { + filterOptions.OrderBy.Ascending(record => record[orderByPropertyName]); + } // Act - var command = this._commandBuilder.BuildSelectLeftJoinCommand( - LeftTable, - RightTable, + var command = SqliteVectorStoreCollectionCommandBuilder.BuildSelectLeftJoinCommand( + this._connection, + VectorTable, + DataTable, JoinColumnName, - leftTablePropertyNames, - rightTablePropertyNames, + model, conditions, - extraWhereFilter: null, - extraParameters: null, - orderByPropertyName); + true, + filterOptions); // Assert - Assert.Contains("SELECT LeftTable.Id, LeftTable.Name, RightTable.Age, RightTable.Address", command.CommandText); - Assert.Contains("FROM LeftTable", command.CommandText); + Assert.Contains("SELECT \"DataTable\".\"Id\",\"DataTable\".\"Name\",\"VectorTable\".\"Age\",\"VectorTable\".\"Address\"", command.CommandText); + Assert.Contains("FROM \"VectorTable\"", command.CommandText); - Assert.Contains("LEFT JOIN RightTable ON LeftTable.Id = RightTable.Id", command.CommandText); + Assert.Contains("LEFT JOIN \"DataTable\" ON \"VectorTable\".\"Id\" = \"DataTable\".\"Id\"", command.CommandText); - Assert.Contains("Name = @Name0", command.CommandText); - Assert.Contains("Age IN (@Age0, @Age1, @Age2)", command.CommandText); + Assert.Contains("\"Name\" = @Name0", command.CommandText); + Assert.Contains("\"Age\" IN (@Age0, @Age1, @Age2)", command.CommandText); - Assert.Equal(!string.IsNullOrWhiteSpace(orderByPropertyName), command.CommandText.Contains($"ORDER BY {orderByPropertyName}")); + Assert.Equal(!string.IsNullOrWhiteSpace(orderByPropertyName), command.CommandText.Contains($"ORDER BY \"DataTable\".\"{orderByPropertyName}\"")); Assert.Equal("@Name0", command.Parameters[0].ParameterName); Assert.Equal("NameValue", command.Parameters[0].Value); @@ -274,13 +303,13 @@ public void ItBuildsDeleteCommand() }; // Act - var command = this._commandBuilder.BuildDeleteCommand(TableName, conditions); + var command = SqliteVectorStoreCollectionCommandBuilder.BuildDeleteCommand(this._connection, TableName, conditions); // Assert - Assert.Contains("DELETE FROM [TestTable]", command.CommandText); + Assert.Contains("DELETE FROM \"TestTable\"", command.CommandText); - Assert.Contains("Name = @Name0", command.CommandText); - Assert.Contains("Age IN (@Age0, @Age1, @Age2)", command.CommandText); + Assert.Contains("\"Name\" = @Name0", command.CommandText); + Assert.Contains("\"Age\" IN (@Age0, @Age1, @Age2)", command.CommandText); Assert.Equal("@Name0", command.Parameters[0].ParameterName); Assert.Equal("NameValue", command.Parameters[0].Value); @@ -300,4 +329,11 @@ public void Dispose() this._command.Dispose(); this._connection.Dispose(); } + + private static VectorStoreRecordModel BuildModel(List properties) + => new VectorStoreRecordModelBuilder(SqliteConstants.ModelBuildingOptions) + .Build( + typeof(Dictionary), + new() { Properties = properties }, + defaultEmbeddingGenerator: null); } diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs index 631bf6cebf3d..3c93b99e6dba 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs @@ -1,5 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +// TODO: Reimplement these as integration tests, #10464 + +#if DISABLED + using System; using System.Collections.Generic; using System.Data.Common; @@ -116,8 +120,7 @@ public async Task VectorizedSearchReturnsRecordAsync(bool includeVectors) var sut = new SqliteVectorStoreRecordCollection>(fakeConnection, "VectorizedSearch"); // Act - var results = await sut.VectorizedSearchAsync(expectedRecord.Vector, new() { IncludeVectors = includeVectors }); - var result = await results.Results.FirstOrDefaultAsync(); + var result = await sut.VectorizedSearchAsync(expectedRecord.Vector, new() { IncludeVectors = includeVectors }).FirstOrDefaultAsync(); // Assert Assert.NotNull(result); @@ -400,3 +403,5 @@ private sealed class TestRecordWithoutVectorProperty #endregion } + +#endif diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordMapperTests.cs index 705b5caa7204..4b023f7ba9d9 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordMapperTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.Sqlite; using Xunit; @@ -18,13 +19,13 @@ public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() { // Arrange var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); + var model = BuildModel(typeof(TestRecord), definition); var dataModel = GetDataModel("key"); - var mapper = new SqliteVectorStoreRecordMapper>(propertyReader); + var mapper = new SqliteVectorStoreRecordMapper>(model); // Act - var result = mapper.MapFromDataToStorageModel(dataModel); + var result = mapper.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); // Assert Assert.Equal("key", result["Key"]); @@ -42,13 +43,13 @@ public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() { // Arrange var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); + var model = BuildModel(typeof(TestRecord), definition); var dataModel = GetDataModel(1); - var mapper = new SqliteVectorStoreRecordMapper>(propertyReader); + var mapper = new SqliteVectorStoreRecordMapper>(model); // Act - var result = mapper.MapFromDataToStorageModel(dataModel); + var result = mapper.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); // Assert Assert.Equal((ulong)1, result["Key"]); @@ -64,7 +65,7 @@ public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() [Theory] [InlineData(true)] [InlineData(false)] - public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool includeVectors) + public void MapFromStorageToDataModelWithStringKeyReturnsValidDynamicModel(bool includeVectors) { // Arrange var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); @@ -79,9 +80,9 @@ public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool }; var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); + var model = BuildModel(typeof(TestRecord), definition); - var mapper = new SqliteVectorStoreRecordMapper>(propertyReader); + var mapper = new SqliteVectorStoreRecordMapper>(model); // Act var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); @@ -105,7 +106,7 @@ public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool [Theory] [InlineData(true)] [InlineData(false)] - public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool includeVectors) + public void MapFromStorageToDataModelWithNumericKeyReturnsValidDynamicModel(bool includeVectors) { // Arrange var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); @@ -120,9 +121,9 @@ public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool }; var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); + var model = BuildModel(typeof(TestRecord), definition); - var mapper = new SqliteVectorStoreRecordMapper>(propertyReader); + var mapper = new SqliteVectorStoreRecordMapper>(model); // Act var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); @@ -154,7 +155,7 @@ private static VectorStoreRecordDefinition GetRecordDefinition() new VectorStoreRecordKeyProperty("Key", typeof(TKey)), new VectorStoreRecordDataProperty("StringProperty", typeof(string)), new VectorStoreRecordDataProperty("IntProperty", typeof(int)), - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), } }; } @@ -170,15 +171,8 @@ private static TestRecord GetDataModel(TKey key) }; } - private static VectorStoreRecordPropertyReader GetPropertyReader(VectorStoreRecordDefinition definition) - { - return new VectorStoreRecordPropertyReader(typeof(TRecord), definition, new() - { - RequiresAtLeastOneVector = false, - SupportsMultipleKeys = false, - SupportsMultipleVectors = true - }); - } + private static VectorStoreRecordModel BuildModel(Type type, VectorStoreRecordDefinition definition) + => new VectorStoreRecordModelBuilder(SqliteConstants.ModelBuildingOptions).Build(type, definition, defaultEmbeddingGenerator: null); #pragma warning disable CA1812 private sealed class TestRecord @@ -192,7 +186,7 @@ private sealed class TestRecord [VectorStoreRecordData] public int? IntProperty { get; set; } - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance)] public ReadOnlyMemory? FloatVector { get; set; } } #pragma warning restore CA1812 diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordPropertyMappingTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordPropertyMappingTests.cs index 19ec51b2f1a2..b2036be6fcc7 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordPropertyMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordPropertyMappingTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.Sqlite; using Xunit; @@ -51,43 +52,50 @@ public void MapVectorForDataModelReturnsReadOnlyMemory() Assert.Equal(vector.Span.ToArray(), dataModelVector.Span.ToArray()); } - [Fact] - public void GetColumnsReturnsCollectionOfColumns() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void GetColumnsReturnsCollectionOfColumns(bool data) { // Arrange - var properties = new List() - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("Data", typeof(int)), - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { Dimensions = 4, DistanceFunction = DistanceFunction.ManhattanDistance }, - }; - - var storagePropertyNames = new Dictionary + var properties = new List() { - ["Key"] = "Key", - ["Data"] = "my_data", - ["Vector"] = "Vector" + new VectorStoreRecordKeyPropertyModel("Key", typeof(string)) { StorageName = "Key" }, + new VectorStoreRecordDataPropertyModel("Data", typeof(int)) { StorageName = "my_data", IsIndexed = true }, + new VectorStoreRecordVectorPropertyModel("Vector", typeof(ReadOnlyMemory)) + { + Dimensions = 4, + DistanceFunction = DistanceFunction.ManhattanDistance, + StorageName = "Vector" + } }; // Act - var columns = SqliteVectorStoreRecordPropertyMapping.GetColumns(properties, storagePropertyNames); + var columns = SqliteVectorStoreRecordPropertyMapping.GetColumns(properties, data: data); // Assert Assert.Equal("Key", columns[0].Name); Assert.Equal("TEXT", columns[0].Type); Assert.True(columns[0].IsPrimary); Assert.Null(columns[0].Configuration); + Assert.False(columns[0].HasIndex); - Assert.Equal("my_data", columns[1].Name); - Assert.Equal("INTEGER", columns[1].Type); - Assert.False(columns[1].IsPrimary); - Assert.Null(columns[1].Configuration); - - Assert.Equal("Vector", columns[2].Name); - Assert.Equal("FLOAT[4]", columns[2].Type); - Assert.False(columns[2].IsPrimary); - Assert.NotNull(columns[2].Configuration); - - Assert.Equal("l1", columns[2].Configuration!["distance_metric"]); + if (data) + { + Assert.Equal("my_data", columns[1].Name); + Assert.Equal("INTEGER", columns[1].Type); + Assert.False(columns[1].IsPrimary); + Assert.Null(columns[1].Configuration); + Assert.True(columns[1].HasIndex); + } + else + { + Assert.Equal("Vector", columns[1].Name); + Assert.Equal("FLOAT[4]", columns[1].Type); + Assert.False(columns[1].IsPrimary); + Assert.NotNull(columns[1].Configuration); + Assert.False(columns[1].HasIndex); + Assert.Equal("l1", columns[1].Configuration!["distance_metric"]); + } } } diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs index 44180405aaa3..74b27b4ef046 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs @@ -1,5 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +// TODO: Reimplement these as integration tests, #10464 + +#if DISABLED + using System; using System.Data.Common; using System.Linq; @@ -104,3 +108,5 @@ public async Task ListCollectionNamesReturnsCollectionNamesAsync() Assert.Contains("collection2", collections); } } + +#endif diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj index 8d27acc67d94..2573821640b1 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj @@ -8,7 +8,7 @@ enable disable false - $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0050,SKEXP0120 + $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0050,SKEXP0120 diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Chroma/ChromaMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Chroma/ChromaMemoryStoreTests.cs index fbbf445ef7e7..bf2da33d900c 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Chroma/ChromaMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Chroma/ChromaMemoryStoreTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net.Http; using System.Text.Json; @@ -18,6 +19,7 @@ namespace SemanticKernel.Connectors.UnitTests.Chroma; /// /// Unit tests for class. /// +[Experimental("SKEXP0020")] public sealed class ChromaMemoryStoreTests : IDisposable { private const string CollectionId = "fake-collection-id"; diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/DuckDB/DuckDBMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/DuckDB/DuckDBMemoryStoreTests.cs index e9e09599a1b4..b6c706734d30 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/DuckDB/DuckDBMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/DuckDB/DuckDBMemoryStoreTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Threading.Tasks; @@ -16,6 +17,7 @@ namespace SemanticKernel.Connectors.UnitTests.DuckDB; /// /// Unit tests of . /// +[Experimental("SKEXP0020")] [Collection("Sequential")] public class DuckDBMemoryStoreTests { diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Kusto/KustoMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Kusto/KustoMemoryStoreTests.cs index 7cdec0210775..e586b58fe0cc 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Kusto/KustoMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Kusto/KustoMemoryStoreTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Data; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -19,6 +20,7 @@ namespace SemanticKernel.Connectors.UnitTests.Kusto; /// /// Unit tests for class. /// +[Experimental("SKEXP0020")] public class KustoMemoryStoreTests { private const string CollectionName = "fake_collection"; diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/MongoDB/MongoDBMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/MongoDB/MongoDBMemoryStoreTests.cs index 4abfbf941498..a9e1589161c4 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/MongoDB/MongoDBMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/MongoDB/MongoDBMemoryStoreTests.cs @@ -17,6 +17,7 @@ namespace SemanticKernel.Connectors.UnitTests.MongoDB; /// /// Unit tests for class. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class MongoDBMemoryStoreTests { private const string CollectionName = "test-collection"; diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresMemoryStoreTests.cs index 928a30568ae6..e1f0f17df187 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresMemoryStoreTests.cs @@ -15,6 +15,7 @@ namespace SemanticKernel.Connectors.UnitTests.Postgres; /// /// Unit tests for class. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class PostgresMemoryStoreTests { private const string CollectionName = "fake-collection-name"; diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Sqlite/SqliteMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Sqlite/SqliteMemoryStoreTests.cs index e91a1794d2a8..0de180a013e8 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Sqlite/SqliteMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Sqlite/SqliteMemoryStoreTests.cs @@ -16,6 +16,7 @@ namespace SemanticKernel.Connectors.UnitTests.Sqlite; /// Unit tests of . /// [Collection("Sequential")] +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public sealed class SqliteMemoryStoreTests : IDisposable { private const string DatabaseFile = "SqliteMemoryStoreTests.db"; diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/Connectors.Weaviate.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/Connectors.Weaviate.UnitTests.csproj index ca442f3b3233..8312d6ba5b60 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/Connectors.Weaviate.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/Connectors.Weaviate.UnitTests.csproj @@ -8,7 +8,8 @@ enable disable false - $(NoWarn);SKEXP0001,SKEXP0020,VSTHRD111,CA2007 + $(NoWarn);SKEXP0001,VSTHRD111,CA2007 + $(NoWarn);MEVD9001 diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateDynamicDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateDynamicDataModelMapperTests.cs new file mode 100644 index 000000000000..772973fa47ac --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateDynamicDataModelMapperTests.cs @@ -0,0 +1,479 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.SemanticKernel.Connectors.Weaviate; +using Xunit; + +namespace SemanticKernel.Connectors.Weaviate.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class WeaviateDynamicDataModelMapperTests +{ + private const bool HasNamedVectors = true; + + private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + Converters = + { + new WeaviateDateTimeOffsetConverter(), + new WeaviateNullableDateTimeOffsetConverter() + } + }; + + private static readonly VectorStoreRecordModel s_model = new WeaviateModelBuilder(HasNamedVectors) + .Build( + typeof(Dictionary), + new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)), + + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), + new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), + new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), + new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), + new VectorStoreRecordDataProperty("ShortDataProp", typeof(short)), + new VectorStoreRecordDataProperty("NullableShortDataProp", typeof(short?)), + new VectorStoreRecordDataProperty("ByteDataProp", typeof(byte)), + new VectorStoreRecordDataProperty("NullableByteDataProp", typeof(byte?)), + new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), + new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), + new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), + new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), + new VectorStoreRecordDataProperty("DecimalDataProp", typeof(decimal)), + new VectorStoreRecordDataProperty("NullableDecimalDataProp", typeof(decimal?)), + new VectorStoreRecordDataProperty("DateTimeDataProp", typeof(DateTime)), + new VectorStoreRecordDataProperty("NullableDateTimeDataProp", typeof(DateTime?)), + new VectorStoreRecordDataProperty("DateTimeOffsetDataProp", typeof(DateTimeOffset)), + new VectorStoreRecordDataProperty("NullableDateTimeOffsetDataProp", typeof(DateTimeOffset?)), + new VectorStoreRecordDataProperty("GuidDataProp", typeof(Guid)), + new VectorStoreRecordDataProperty("NullableGuidDataProp", typeof(Guid?)), + new VectorStoreRecordDataProperty("TagListDataProp", typeof(List)), + + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?), 10), + new VectorStoreRecordVectorProperty("DoubleVector", typeof(ReadOnlyMemory), 10), + new VectorStoreRecordVectorProperty("NullableDoubleVector", typeof(ReadOnlyMemory?), 10) + ] + }, + defaultEmbeddingGenerator: null, + s_jsonSerializerOptions); + + private static readonly float[] s_floatVector = [1.0f, 2.0f, 3.0f]; + private static readonly double[] s_doubleVector = [1.0f, 2.0f, 3.0f]; + private static readonly List s_taglist = ["tag1", "tag2"]; + + [Fact] + public void MapFromDataToStorageModelMapsAllSupportedTypes() + { + // Arrange + var key = new Guid("55555555-5555-5555-5555-555555555555"); + var sut = new WeaviateDynamicDataModelMapper("Collection", HasNamedVectors, s_model, s_jsonSerializerOptions); + + var dataModel = new Dictionary + { + ["Key"] = key, + + ["StringDataProp"] = "string", + ["BoolDataProp"] = true, + ["NullableBoolDataProp"] = false, + ["IntDataProp"] = 1, + ["NullableIntDataProp"] = 2, + ["LongDataProp"] = 3L, + ["NullableLongDataProp"] = 4L, + ["ShortDataProp"] = (short)5, + ["NullableShortDataProp"] = (short)6, + ["ByteDataProp"] = (byte)7, + ["NullableByteDataProp"] = (byte)8, + ["FloatDataProp"] = 9.0f, + ["NullableFloatDataProp"] = 10.0f, + ["DoubleDataProp"] = 11.0, + ["NullableDoubleDataProp"] = 12.0, + ["DecimalDataProp"] = 13.99m, + ["NullableDecimalDataProp"] = 14.00m, + ["DateTimeDataProp"] = new DateTime(2021, 1, 1), + ["NullableDateTimeDataProp"] = new DateTime(2021, 1, 1), + ["DateTimeOffsetDataProp"] = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), + ["NullableDateTimeOffsetDataProp"] = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), + ["GuidDataProp"] = new Guid("11111111-1111-1111-1111-111111111111"), + ["NullableGuidDataProp"] = new Guid("22222222-2222-2222-2222-222222222222"), + ["TagListDataProp"] = s_taglist, + + ["FloatVector"] = new ReadOnlyMemory(s_floatVector), + ["NullableFloatVector"] = new ReadOnlyMemory(s_floatVector), + ["DoubleVector"] = new ReadOnlyMemory(s_doubleVector), + ["NullableDoubleVector"] = new ReadOnlyMemory(s_doubleVector), + } + ; + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); + + // Assert + Assert.Equal(key, (Guid?)storageModel["id"]); + Assert.Equal("Collection", (string?)storageModel["class"]); + Assert.Equal("string", (string?)storageModel["properties"]?["stringDataProp"]); + Assert.Equal(true, (bool?)storageModel["properties"]?["boolDataProp"]); + Assert.Equal(false, (bool?)storageModel["properties"]?["nullableBoolDataProp"]); + Assert.Equal(1, (int?)storageModel["properties"]?["intDataProp"]); + Assert.Equal(2, (int?)storageModel["properties"]?["nullableIntDataProp"]); + Assert.Equal(3L, (long?)storageModel["properties"]?["longDataProp"]); + Assert.Equal(4L, (long?)storageModel["properties"]?["nullableLongDataProp"]); + Assert.Equal((short)5, (short?)storageModel["properties"]?["shortDataProp"]); + Assert.Equal((short)6, (short?)storageModel["properties"]?["nullableShortDataProp"]); + Assert.Equal((byte)7, (byte?)storageModel["properties"]?["byteDataProp"]); + Assert.Equal((byte)8, (byte?)storageModel["properties"]?["nullableByteDataProp"]); + Assert.Equal(9.0f, (float?)storageModel["properties"]?["floatDataProp"]); + Assert.Equal(10.0f, (float?)storageModel["properties"]?["nullableFloatDataProp"]); + Assert.Equal(11.0, (double?)storageModel["properties"]?["doubleDataProp"]); + Assert.Equal(12.0, (double?)storageModel["properties"]?["nullableDoubleDataProp"]); + Assert.Equal(13.99m, (decimal?)storageModel["properties"]?["decimalDataProp"]); + Assert.Equal(14.00m, (decimal?)storageModel["properties"]?["nullableDecimalDataProp"]); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0), (DateTime?)storageModel["properties"]?["dateTimeDataProp"]); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0), (DateTime?)storageModel["properties"]?["nullableDateTimeDataProp"]); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), (DateTimeOffset?)storageModel["properties"]?["dateTimeOffsetDataProp"]); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), (DateTimeOffset?)storageModel["properties"]?["nullableDateTimeOffsetDataProp"]); + Assert.Equal(new Guid("11111111-1111-1111-1111-111111111111"), (Guid?)storageModel["properties"]?["guidDataProp"]); + Assert.Equal(new Guid("22222222-2222-2222-2222-222222222222"), (Guid?)storageModel["properties"]?["nullableGuidDataProp"]); + Assert.Equal(s_taglist, storageModel["properties"]?["tagListDataProp"]!.AsArray().GetValues().ToArray()); + Assert.Equal(s_floatVector, storageModel["vectors"]?["floatVector"]!.AsArray().GetValues().ToArray()); + Assert.Equal(s_floatVector, storageModel["vectors"]?["nullableFloatVector"]!.AsArray().GetValues().ToArray()); + Assert.Equal(s_doubleVector, storageModel["vectors"]?["doubleVector"]!.AsArray().GetValues().ToArray()); + Assert.Equal(s_doubleVector, storageModel["vectors"]?["nullableDoubleVector"]!.AsArray().GetValues().ToArray()); + } + + [Fact] + public void MapFromDataToStorageModelMapsNullValues() + { + // Arrange + var key = new Guid("55555555-5555-5555-5555-555555555555"); + var keyProperty = new VectorStoreRecordKeyProperty("Key", typeof(Guid)); + + var dataProperties = new List + { + new("StringDataProp", typeof(string)), + new("NullableIntDataProp", typeof(int?)), + }; + + var vectorProperties = new List + { + new("NullableFloatVector", typeof(ReadOnlyMemory?), 10) + }; + + var dataModel = new Dictionary + { + ["Key"] = key, + + ["StringDataProp"] = null, + ["NullableIntDataProp"] = null, + + ["NullableFloatVector"] = null + }; + + var sut = new WeaviateDynamicDataModelMapper("Collection", HasNamedVectors, s_model, s_jsonSerializerOptions); + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel, recordIndex: 0, generatedEmbeddings: null); + + // Assert + Assert.Null(storageModel["StringDataProp"]); + Assert.Null(storageModel["NullableIntDataProp"]); + Assert.Null(storageModel["NullableFloatVector"]); + } + + [Fact] + public void MapFromStorageToDataModelMapsAllSupportedTypes() + { + // Arrange + var key = new Guid("55555555-5555-5555-5555-555555555555"); + var sut = new WeaviateDynamicDataModelMapper("Collection", HasNamedVectors, s_model, s_jsonSerializerOptions); + + var storageModel = new JsonObject + { + ["id"] = key, + ["properties"] = new JsonObject + { + ["stringDataProp"] = "string", + ["boolDataProp"] = true, + ["nullableBoolDataProp"] = false, + ["intDataProp"] = 1, + ["nullableIntDataProp"] = 2, + ["longDataProp"] = 3L, + ["nullableLongDataProp"] = 4L, + ["shortDataProp"] = (short)5, + ["nullableShortDataProp"] = (short)6, + ["byteDataProp"] = (byte)7, + ["nullableByteDataProp"] = (byte)8, + ["floatDataProp"] = 9.0f, + ["nullableFloatDataProp"] = 10.0f, + ["doubleDataProp"] = 11.0, + ["nullableDoubleDataProp"] = 12.0, + ["decimalDataProp"] = 13.99m, + ["nullableDecimalDataProp"] = 14.00m, + ["dateTimeDataProp"] = new DateTime(2021, 1, 1), + ["nullableDateTimeDataProp"] = new DateTime(2021, 1, 1), + ["dateTimeOffsetDataProp"] = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), + ["nullableDateTimeOffsetDataProp"] = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), + ["guidDataProp"] = new Guid("11111111-1111-1111-1111-111111111111"), + ["nullableGuidDataProp"] = new Guid("22222222-2222-2222-2222-222222222222"), + ["tagListDataProp"] = new JsonArray(s_taglist.Select(l => (JsonValue)l).ToArray()) + }, + ["vectors"] = new JsonObject + { + ["floatVector"] = new JsonArray(s_floatVector.Select(l => (JsonValue)l).ToArray()), + ["nullableFloatVector"] = new JsonArray(s_floatVector.Select(l => (JsonValue)l).ToArray()), + ["doubleVector"] = new JsonArray(s_doubleVector.Select(l => (JsonValue)l).ToArray()), + ["nullableDoubleVector"] = new JsonArray(s_doubleVector.Select(l => (JsonValue)l).ToArray()), + } + }; + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal(key, dataModel["Key"]); + Assert.Equal("string", dataModel["StringDataProp"]); + Assert.Equal(true, dataModel["BoolDataProp"]); + Assert.Equal(false, dataModel["NullableBoolDataProp"]); + Assert.Equal(1, dataModel["IntDataProp"]); + Assert.Equal(2, dataModel["NullableIntDataProp"]); + Assert.Equal(3L, dataModel["LongDataProp"]); + Assert.Equal(4L, dataModel["NullableLongDataProp"]); + Assert.Equal((short)5, dataModel["ShortDataProp"]); + Assert.Equal((short)6, dataModel["NullableShortDataProp"]); + Assert.Equal((byte)7, dataModel["ByteDataProp"]); + Assert.Equal((byte)8, dataModel["NullableByteDataProp"]); + Assert.Equal(9.0f, dataModel["FloatDataProp"]); + Assert.Equal(10.0f, dataModel["NullableFloatDataProp"]); + Assert.Equal(11.0, dataModel["DoubleDataProp"]); + Assert.Equal(12.0, dataModel["NullableDoubleDataProp"]); + Assert.Equal(13.99m, dataModel["DecimalDataProp"]); + Assert.Equal(14.00m, dataModel["NullableDecimalDataProp"]); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0), dataModel["DateTimeDataProp"]); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0), dataModel["NullableDateTimeDataProp"]); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel["DateTimeOffsetDataProp"]); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel["NullableDateTimeOffsetDataProp"]); + Assert.Equal(new Guid("11111111-1111-1111-1111-111111111111"), dataModel["GuidDataProp"]); + Assert.Equal(new Guid("22222222-2222-2222-2222-222222222222"), dataModel["NullableGuidDataProp"]); + Assert.Equal(s_taglist, dataModel["TagListDataProp"]); + Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel["FloatVector"]!).ToArray()); + Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel["NullableFloatVector"]!)!.ToArray()); + Assert.Equal(s_doubleVector, ((ReadOnlyMemory)dataModel["DoubleVector"]!).ToArray()); + Assert.Equal(s_doubleVector, ((ReadOnlyMemory)dataModel["NullableDoubleVector"]!)!.ToArray()); + } + + [Fact] + public void MapFromStorageToDataModelMapsNullValues() + { + // Arrange + var key = new Guid("55555555-5555-5555-5555-555555555555"); + var keyProperty = new VectorStoreRecordKeyProperty("Key", typeof(Guid)); + + var dataProperties = new List + { + new("StringDataProp", typeof(string)), + new("NullableIntDataProp", typeof(int?)), + }; + + var vectorProperties = new List + { + new("NullableFloatVector", typeof(ReadOnlyMemory?), 10) + }; + + var storageModel = new JsonObject + { + ["id"] = key, + ["properties"] = new JsonObject + { + ["stringDataProp"] = null, + ["nullableIntDataProp"] = null, + }, + ["vectors"] = new JsonObject + { + ["nullableFloatVector"] = null + } + }; + + var sut = new WeaviateDynamicDataModelMapper("Collection", HasNamedVectors, s_model, s_jsonSerializerOptions); + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal(key, dataModel["Key"]); + Assert.Null(dataModel["StringDataProp"]); + Assert.Null(dataModel["NullableIntDataProp"]); + Assert.Null(dataModel["NullableFloatVector"]); + } + + [Fact] + public void MapFromStorageToDataModelThrowsForMissingKey() + { + // Arrange + var sut = new WeaviateDynamicDataModelMapper("Collection", HasNamedVectors, s_model, s_jsonSerializerOptions); + + var storageModel = new JsonObject(); + + // Act & Assert + var exception = Assert.Throws( + () => sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true })); + } + + [Fact] + public void MapFromDataToStorageModelSkipsMissingProperties() + { + // Arrange + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10) + ] + }; + + var model = new WeaviateModelBuilder(HasNamedVectors).Build(typeof(Dictionary), recordDefinition, defaultEmbeddingGenerator: null, s_jsonSerializerOptions); + + var key = new Guid("55555555-5555-5555-5555-555555555555"); + + var record = new Dictionary { ["Key"] = key }; + var sut = new WeaviateDynamicDataModelMapper("Collection", HasNamedVectors, model, s_jsonSerializerOptions); + + // Act + var storageModel = sut.MapFromDataToStorageModel(record, recordIndex: 0, generatedEmbeddings: null); + + // Assert + Assert.Equal(key, (Guid?)storageModel["id"]); + Assert.False(storageModel.ContainsKey("StringDataProp")); + Assert.False(storageModel.ContainsKey("FloatVector")); + } + + [Fact] + public void MapFromStorageToDataModelSkipsMissingProperties() + { + // Arrange + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10) + ] + }; + + var model = new WeaviateModelBuilder(HasNamedVectors).Build(typeof(Dictionary), recordDefinition, defaultEmbeddingGenerator: null, s_jsonSerializerOptions); + + var key = new Guid("55555555-5555-5555-5555-555555555555"); + + var sut = new WeaviateDynamicDataModelMapper("Collection", HasNamedVectors, model, s_jsonSerializerOptions); + + var storageModel = new JsonObject + { + ["id"] = key + }; + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal(key, dataModel["Key"]); + Assert.False(dataModel.ContainsKey("StringDataProp")); + Assert.False(dataModel.ContainsKey("FloatVector")); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromDataToStorageModelMapsNamedVectorsCorrectly(bool hasNamedVectors) + { + // Arrange + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 4) + ] + }; + + var model = new WeaviateModelBuilder(hasNamedVectors).Build(typeof(Dictionary), recordDefinition, defaultEmbeddingGenerator: null, s_jsonSerializerOptions); + + var key = new Guid("55555555-5555-5555-5555-555555555555"); + + var record = new Dictionary { ["Key"] = key, ["FloatVector"] = new ReadOnlyMemory(s_floatVector) }; + var sut = new WeaviateDynamicDataModelMapper("Collection", hasNamedVectors, model, s_jsonSerializerOptions); + + // Act + var storageModel = sut.MapFromDataToStorageModel(record, recordIndex: 0, generatedEmbeddings: null); + + // Assert + var vectorProperty = hasNamedVectors ? storageModel["vectors"]!["floatVector"] : storageModel["vector"]; + + Assert.Equal(key, (Guid?)storageModel["id"]); + Assert.Equal(s_floatVector, vectorProperty!.AsArray().GetValues().ToArray()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelMapsNamedVectorsCorrectly(bool hasNamedVectors) + { + // Arrange + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory), 4) + ] + }; + + var model = new WeaviateModelBuilder(hasNamedVectors).Build(typeof(Dictionary), recordDefinition, defaultEmbeddingGenerator: null, s_jsonSerializerOptions); + + var key = new Guid("55555555-5555-5555-5555-555555555555"); + + var sut = new WeaviateDynamicDataModelMapper("Collection", hasNamedVectors, model, s_jsonSerializerOptions); + + var storageModel = new JsonObject { ["id"] = key }; + + var vector = new JsonArray(s_floatVector.Select(l => (JsonValue)l).ToArray()); + + if (hasNamedVectors) + { + storageModel["vectors"] = new JsonObject + { + ["floatVector"] = vector + }; + } + else + { + storageModel["vector"] = vector; + } + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal(key, dataModel["Key"]); + Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel["FloatVector"]!).ToArray()); + } +} diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateGenericDataModelMapperTests.cs deleted file mode 100644 index 4eca8d8bf77f..000000000000 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateGenericDataModelMapperTests.cs +++ /dev/null @@ -1,445 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text.Json; -using System.Text.Json.Nodes; -using System.Text.Json.Serialization; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Weaviate; -using Xunit; - -namespace SemanticKernel.Connectors.Weaviate.UnitTests; - -/// -/// Unit tests for class. -/// -public sealed class WeaviateGenericDataModelMapperTests -{ - private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, - Converters = - { - new WeaviateDateTimeOffsetConverter(), - new WeaviateNullableDateTimeOffsetConverter() - } - }; - - private static readonly VectorStoreRecordKeyProperty s_keyProperty = new("Key", typeof(Guid)); - - private static readonly List s_dataProperties = new() - { - new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), - new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), - new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), - new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), - new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), - new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), - new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), - new VectorStoreRecordDataProperty("ShortDataProp", typeof(short)), - new VectorStoreRecordDataProperty("NullableShortDataProp", typeof(short?)), - new VectorStoreRecordDataProperty("ByteDataProp", typeof(byte)), - new VectorStoreRecordDataProperty("NullableByteDataProp", typeof(byte?)), - new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), - new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), - new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), - new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), - new VectorStoreRecordDataProperty("DecimalDataProp", typeof(decimal)), - new VectorStoreRecordDataProperty("NullableDecimalDataProp", typeof(decimal?)), - new VectorStoreRecordDataProperty("DateTimeDataProp", typeof(DateTime)), - new VectorStoreRecordDataProperty("NullableDateTimeDataProp", typeof(DateTime?)), - new VectorStoreRecordDataProperty("DateTimeOffsetDataProp", typeof(DateTimeOffset)), - new VectorStoreRecordDataProperty("NullableDateTimeOffsetDataProp", typeof(DateTimeOffset?)), - new VectorStoreRecordDataProperty("GuidDataProp", typeof(Guid)), - new VectorStoreRecordDataProperty("NullableGuidDataProp", typeof(Guid?)), - new VectorStoreRecordDataProperty("TagListDataProp", typeof(List)), - }; - - private static readonly List s_vectorProperties = new() - { - new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), - new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), - new VectorStoreRecordVectorProperty("DoubleVector", typeof(ReadOnlyMemory)), - new VectorStoreRecordVectorProperty("NullableDoubleVector", typeof(ReadOnlyMemory?)), - }; - - private static readonly Dictionary s_storagePropertyNames = s_dataProperties - .Select(l => l.DataModelPropertyName) - .Concat(s_vectorProperties.Select(l => l.DataModelPropertyName)) - .Concat([s_keyProperty.DataModelPropertyName]) - .ToDictionary(k => k, v => v); - - private static readonly float[] s_floatVector = [1.0f, 2.0f, 3.0f]; - private static readonly double[] s_doubleVector = [1.0f, 2.0f, 3.0f]; - private static readonly List s_taglist = ["tag1", "tag2"]; - - [Fact] - public void MapFromDataToStorageModelMapsAllSupportedTypes() - { - // Arrange - var key = new Guid("55555555-5555-5555-5555-555555555555"); - var sut = new WeaviateGenericDataModelMapper( - "Collection", - s_keyProperty, - s_dataProperties, - s_vectorProperties, - s_storagePropertyNames, - s_jsonSerializerOptions); - - var dataModel = new VectorStoreGenericDataModel(key) - { - Data = - { - ["StringDataProp"] = "string", - ["BoolDataProp"] = true, - ["NullableBoolDataProp"] = false, - ["IntDataProp"] = 1, - ["NullableIntDataProp"] = 2, - ["LongDataProp"] = 3L, - ["NullableLongDataProp"] = 4L, - ["ShortDataProp"] = (short)5, - ["NullableShortDataProp"] = (short)6, - ["ByteDataProp"] = (byte)7, - ["NullableByteDataProp"] = (byte)8, - ["FloatDataProp"] = 9.0f, - ["NullableFloatDataProp"] = 10.0f, - ["DoubleDataProp"] = 11.0, - ["NullableDoubleDataProp"] = 12.0, - ["DecimalDataProp"] = 13.99m, - ["NullableDecimalDataProp"] = 14.00m, - ["DateTimeDataProp"] = new DateTime(2021, 1, 1), - ["NullableDateTimeDataProp"] = new DateTime(2021, 1, 1), - ["DateTimeOffsetDataProp"] = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), - ["NullableDateTimeOffsetDataProp"] = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), - ["GuidDataProp"] = new Guid("11111111-1111-1111-1111-111111111111"), - ["NullableGuidDataProp"] = new Guid("22222222-2222-2222-2222-222222222222"), - ["TagListDataProp"] = s_taglist - }, - Vectors = - { - ["FloatVector"] = new ReadOnlyMemory(s_floatVector), - ["NullableFloatVector"] = new ReadOnlyMemory(s_floatVector), - ["DoubleVector"] = new ReadOnlyMemory(s_doubleVector), - ["NullableDoubleVector"] = new ReadOnlyMemory(s_doubleVector), - } - }; - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal(key, (Guid?)storageModel["id"]); - Assert.Equal("Collection", (string?)storageModel["class"]); - Assert.Equal("string", (string?)storageModel["properties"]?["StringDataProp"]); - Assert.Equal(true, (bool?)storageModel["properties"]?["BoolDataProp"]); - Assert.Equal(false, (bool?)storageModel["properties"]?["NullableBoolDataProp"]); - Assert.Equal(1, (int?)storageModel["properties"]?["IntDataProp"]); - Assert.Equal(2, (int?)storageModel["properties"]?["NullableIntDataProp"]); - Assert.Equal(3L, (long?)storageModel["properties"]?["LongDataProp"]); - Assert.Equal(4L, (long?)storageModel["properties"]?["NullableLongDataProp"]); - Assert.Equal((short)5, (short?)storageModel["properties"]?["ShortDataProp"]); - Assert.Equal((short)6, (short?)storageModel["properties"]?["NullableShortDataProp"]); - Assert.Equal((byte)7, (byte?)storageModel["properties"]?["ByteDataProp"]); - Assert.Equal((byte)8, (byte?)storageModel["properties"]?["NullableByteDataProp"]); - Assert.Equal(9.0f, (float?)storageModel["properties"]?["FloatDataProp"]); - Assert.Equal(10.0f, (float?)storageModel["properties"]?["NullableFloatDataProp"]); - Assert.Equal(11.0, (double?)storageModel["properties"]?["DoubleDataProp"]); - Assert.Equal(12.0, (double?)storageModel["properties"]?["NullableDoubleDataProp"]); - Assert.Equal(13.99m, (decimal?)storageModel["properties"]?["DecimalDataProp"]); - Assert.Equal(14.00m, (decimal?)storageModel["properties"]?["NullableDecimalDataProp"]); - Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0), (DateTime?)storageModel["properties"]?["DateTimeDataProp"]); - Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0), (DateTime?)storageModel["properties"]?["NullableDateTimeDataProp"]); - Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), (DateTimeOffset?)storageModel["properties"]?["DateTimeOffsetDataProp"]); - Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), (DateTimeOffset?)storageModel["properties"]?["NullableDateTimeOffsetDataProp"]); - Assert.Equal(new Guid("11111111-1111-1111-1111-111111111111"), (Guid?)storageModel["properties"]?["GuidDataProp"]); - Assert.Equal(new Guid("22222222-2222-2222-2222-222222222222"), (Guid?)storageModel["properties"]?["NullableGuidDataProp"]); - Assert.Equal(s_taglist, storageModel["properties"]?["TagListDataProp"]!.AsArray().GetValues().ToArray()); - Assert.Equal(s_floatVector, storageModel["vectors"]?["FloatVector"]!.AsArray().GetValues().ToArray()); - Assert.Equal(s_floatVector, storageModel["vectors"]?["NullableFloatVector"]!.AsArray().GetValues().ToArray()); - Assert.Equal(s_doubleVector, storageModel["vectors"]?["DoubleVector"]!.AsArray().GetValues().ToArray()); - Assert.Equal(s_doubleVector, storageModel["vectors"]?["NullableDoubleVector"]!.AsArray().GetValues().ToArray()); - } - - [Fact] - public void MapFromDataToStorageModelMapsNullValues() - { - // Arrange - var key = new Guid("55555555-5555-5555-5555-555555555555"); - var keyProperty = new VectorStoreRecordKeyProperty("Key", typeof(Guid)); - - var dataProperties = new List - { - new("StringDataProp", typeof(string)), - new("NullableIntDataProp", typeof(int?)), - }; - - var vectorProperties = new List - { - new("NullableFloatVector", typeof(ReadOnlyMemory?)) - }; - - var dataModel = new VectorStoreGenericDataModel(key) - { - Data = - { - ["StringDataProp"] = null, - ["NullableIntDataProp"] = null, - }, - Vectors = - { - ["NullableFloatVector"] = null, - }, - }; - - var sut = new WeaviateGenericDataModelMapper( - "Collection", - keyProperty, - dataProperties, - vectorProperties, - s_storagePropertyNames, - s_jsonSerializerOptions); - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Null(storageModel["StringDataProp"]); - Assert.Null(storageModel["NullableIntDataProp"]); - Assert.Null(storageModel["NullableFloatVector"]); - } - - [Fact] - public void MapFromStorageToDataModelMapsAllSupportedTypes() - { - // Arrange - var key = new Guid("55555555-5555-5555-5555-555555555555"); - var sut = new WeaviateGenericDataModelMapper( - "Collection", - s_keyProperty, - s_dataProperties, - s_vectorProperties, - s_storagePropertyNames, - s_jsonSerializerOptions); - - var storageModel = new JsonObject - { - ["id"] = key, - ["properties"] = new JsonObject - { - ["StringDataProp"] = "string", - ["BoolDataProp"] = true, - ["NullableBoolDataProp"] = false, - ["IntDataProp"] = 1, - ["NullableIntDataProp"] = 2, - ["LongDataProp"] = 3L, - ["NullableLongDataProp"] = 4L, - ["ShortDataProp"] = (short)5, - ["NullableShortDataProp"] = (short)6, - ["ByteDataProp"] = (byte)7, - ["NullableByteDataProp"] = (byte)8, - ["FloatDataProp"] = 9.0f, - ["NullableFloatDataProp"] = 10.0f, - ["DoubleDataProp"] = 11.0, - ["NullableDoubleDataProp"] = 12.0, - ["DecimalDataProp"] = 13.99m, - ["NullableDecimalDataProp"] = 14.00m, - ["DateTimeDataProp"] = new DateTime(2021, 1, 1), - ["NullableDateTimeDataProp"] = new DateTime(2021, 1, 1), - ["DateTimeOffsetDataProp"] = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), - ["NullableDateTimeOffsetDataProp"] = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), - ["GuidDataProp"] = new Guid("11111111-1111-1111-1111-111111111111"), - ["NullableGuidDataProp"] = new Guid("22222222-2222-2222-2222-222222222222"), - ["TagListDataProp"] = new JsonArray(s_taglist.Select(l => (JsonValue)l).ToArray()) - }, - ["vectors"] = new JsonObject - { - ["FloatVector"] = new JsonArray(s_floatVector.Select(l => (JsonValue)l).ToArray()), - ["NullableFloatVector"] = new JsonArray(s_floatVector.Select(l => (JsonValue)l).ToArray()), - ["DoubleVector"] = new JsonArray(s_doubleVector.Select(l => (JsonValue)l).ToArray()), - ["NullableDoubleVector"] = new JsonArray(s_doubleVector.Select(l => (JsonValue)l).ToArray()), - } - }; - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); - - // Assert - Assert.Equal(key, dataModel.Key); - Assert.Equal("string", dataModel.Data["StringDataProp"]); - Assert.Equal(true, dataModel.Data["BoolDataProp"]); - Assert.Equal(false, dataModel.Data["NullableBoolDataProp"]); - Assert.Equal(1, dataModel.Data["IntDataProp"]); - Assert.Equal(2, dataModel.Data["NullableIntDataProp"]); - Assert.Equal(3L, dataModel.Data["LongDataProp"]); - Assert.Equal(4L, dataModel.Data["NullableLongDataProp"]); - Assert.Equal((short)5, dataModel.Data["ShortDataProp"]); - Assert.Equal((short)6, dataModel.Data["NullableShortDataProp"]); - Assert.Equal((byte)7, dataModel.Data["ByteDataProp"]); - Assert.Equal((byte)8, dataModel.Data["NullableByteDataProp"]); - Assert.Equal(9.0f, dataModel.Data["FloatDataProp"]); - Assert.Equal(10.0f, dataModel.Data["NullableFloatDataProp"]); - Assert.Equal(11.0, dataModel.Data["DoubleDataProp"]); - Assert.Equal(12.0, dataModel.Data["NullableDoubleDataProp"]); - Assert.Equal(13.99m, dataModel.Data["DecimalDataProp"]); - Assert.Equal(14.00m, dataModel.Data["NullableDecimalDataProp"]); - Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0), dataModel.Data["DateTimeDataProp"]); - Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0), dataModel.Data["NullableDateTimeDataProp"]); - Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel.Data["DateTimeOffsetDataProp"]); - Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), dataModel.Data["NullableDateTimeOffsetDataProp"]); - Assert.Equal(new Guid("11111111-1111-1111-1111-111111111111"), dataModel.Data["GuidDataProp"]); - Assert.Equal(new Guid("22222222-2222-2222-2222-222222222222"), dataModel.Data["NullableGuidDataProp"]); - Assert.Equal(s_taglist, dataModel.Data["TagListDataProp"]); - Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel.Vectors["FloatVector"]!).ToArray()); - Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel.Vectors["NullableFloatVector"]!)!.ToArray()); - Assert.Equal(s_doubleVector, ((ReadOnlyMemory)dataModel.Vectors["DoubleVector"]!).ToArray()); - Assert.Equal(s_doubleVector, ((ReadOnlyMemory)dataModel.Vectors["NullableDoubleVector"]!)!.ToArray()); - } - - [Fact] - public void MapFromStorageToDataModelMapsNullValues() - { - // Arrange - var key = new Guid("55555555-5555-5555-5555-555555555555"); - var keyProperty = new VectorStoreRecordKeyProperty("Key", typeof(Guid)); - - var dataProperties = new List - { - new("StringDataProp", typeof(string)), - new("NullableIntDataProp", typeof(int?)), - }; - - var vectorProperties = new List - { - new("NullableFloatVector", typeof(ReadOnlyMemory?)) - }; - - var storageModel = new JsonObject - { - ["id"] = key, - ["properties"] = new JsonObject - { - ["StringDataProp"] = null, - ["NullableIntDataProp"] = null, - }, - ["vectors"] = new JsonObject - { - ["NullableFloatVector"] = null - } - }; - - var sut = new WeaviateGenericDataModelMapper( - "Collection", - s_keyProperty, - s_dataProperties, - s_vectorProperties, - s_storagePropertyNames, - s_jsonSerializerOptions); - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); - - // Assert - Assert.Equal(key, dataModel.Key); - Assert.Null(dataModel.Data["StringDataProp"]); - Assert.Null(dataModel.Data["NullableIntDataProp"]); - Assert.Null(dataModel.Vectors["NullableFloatVector"]); - } - - [Fact] - public void MapFromStorageToDataModelThrowsForMissingKey() - { - // Arrange - var sut = new WeaviateGenericDataModelMapper( - "Collection", - s_keyProperty, - s_dataProperties, - s_vectorProperties, - s_storagePropertyNames, - s_jsonSerializerOptions); - - var storageModel = new JsonObject(); - - // Act & Assert - var exception = Assert.Throws( - () => sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true })); - } - - [Fact] - public void MapFromDataToStorageModelSkipsMissingProperties() - { - // Arrange - var key = new Guid("55555555-5555-5555-5555-555555555555"); - var keyProperty = new VectorStoreRecordKeyProperty("Key", typeof(Guid)); - - var dataProperties = new List - { - new("StringDataProp", typeof(string)), - new("NullableIntDataProp", typeof(int?)), - }; - - var vectorProperties = new List - { - new("FloatVector", typeof(ReadOnlyMemory)) - }; - - var dataModel = new VectorStoreGenericDataModel(key); - var sut = new WeaviateGenericDataModelMapper( - "Collection", - keyProperty, - dataProperties, - vectorProperties, - s_storagePropertyNames, - s_jsonSerializerOptions); - - // Act - var storageModel = sut.MapFromDataToStorageModel(dataModel); - - // Assert - Assert.Equal(key, (Guid?)storageModel["id"]); - Assert.False(storageModel.ContainsKey("StringDataProp")); - Assert.False(storageModel.ContainsKey("FloatVector")); - } - - [Fact] - public void MapFromStorageToDataModelSkipsMissingProperties() - { - // Arrange - var key = new Guid("55555555-5555-5555-5555-555555555555"); - var keyProperty = new VectorStoreRecordKeyProperty("Key", typeof(Guid)); - - var dataProperties = new List - { - new("StringDataProp", typeof(string)), - new("NullableIntDataProp", typeof(int?)), - }; - - var vectorProperties = new List - { - new("FloatVector", typeof(ReadOnlyMemory)) - }; - - var sut = new WeaviateGenericDataModelMapper( - "Collection", - keyProperty, - dataProperties, - vectorProperties, - s_storagePropertyNames, - s_jsonSerializerOptions); - - var storageModel = new JsonObject - { - ["id"] = key - }; - - // Act - var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); - - // Assert - Assert.Equal(key, dataModel.Key); - Assert.False(dataModel.Data.ContainsKey("StringDataProp")); - Assert.False(dataModel.Vectors.ContainsKey("FloatVector")); - } -} diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateHotel.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateHotel.cs index 9ce781b39b8b..d57084ad5100 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateHotel.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateHotel.cs @@ -16,7 +16,7 @@ public sealed record WeaviateHotel public Guid HotelId { get; init; } /// A string metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string? HotelName { get; set; } /// An int metadata field. @@ -37,13 +37,13 @@ public sealed record WeaviateHotel public List Tags { get; set; } = []; /// A data field. - [VectorStoreRecordData(IsFullTextSearchable = true)] + [VectorStoreRecordData(IsFullTextIndexed = true)] public string Description { get; set; } [VectorStoreRecordData] public DateTimeOffset Timestamp { get; set; } /// A vector field. - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance, IndexKind: IndexKind.Hnsw)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance, IndexKind = IndexKind.Hnsw)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateKernelBuilderExtensionsTests.cs index 23b34cdbc2ba..5b2b6eefc582 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateKernelBuilderExtensionsTests.cs @@ -50,11 +50,11 @@ public void AddWeaviateVectorStoreRecordCollectionRegistersClass() // Assert var collection = kernel.Services.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = kernel.Services.GetRequiredService>(); + var vectorizedSearch = kernel.Services.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateMemoryBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateMemoryBuilderExtensionsTests.cs index d3c4a2a0c92f..172b84bc9196 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateMemoryBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateMemoryBuilderExtensionsTests.cs @@ -16,6 +16,7 @@ namespace SemanticKernel.Connectors.UnitTests.Weaviate; +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public sealed class WeaviateMemoryBuilderExtensionsTests : IDisposable { private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateMemoryStoreTests.cs index 97134f46818a..b1eea5a6c6b3 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateMemoryStoreTests.cs @@ -16,6 +16,7 @@ namespace SemanticKernel.Connectors.UnitTests.Weaviate; /// /// Unit tests for class. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public sealed class WeaviateMemoryStoreTests : IDisposable { private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateServiceCollectionExtensionsTests.cs index e33f735ebc4f..b63071e9eafb 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateServiceCollectionExtensionsTests.cs @@ -55,11 +55,11 @@ private void AssertVectorStoreRecordCollectionCreated() var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreCollectionCreateMappingTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreCollectionCreateMappingTests.cs index 30de049b9ec8..6a02f3db8c91 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreCollectionCreateMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreCollectionCreateMappingTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Text.Json; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Weaviate; using Xunit; @@ -13,23 +14,27 @@ namespace SemanticKernel.Connectors.Weaviate.UnitTests; /// public sealed class WeaviateVectorStoreCollectionCreateMappingTests { + private const bool HasNamedVectors = true; + [Fact] public void ItThrowsExceptionWithInvalidIndexKind() { // Arrange - var vectorProperties = new List - { - new("PropertyName", typeof(ReadOnlyMemory)) { IndexKind = "non-existent-index-kind" } - }; - - var storagePropertyNames = new Dictionary { ["PropertyName"] = "propertyName" }; + var model = new WeaviateModelBuilder(HasNamedVectors) + .Build( + typeof(Dictionary), + new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) { IndexKind = "non-existent-index-kind" } + ] + }, + defaultEmbeddingGenerator: null); // Act & Assert - Assert.Throws(() => WeaviateVectorStoreCollectionCreateMapping.MapToSchema( - collectionName: "CollectionName", - dataProperties: [], - vectorProperties: vectorProperties, - storagePropertyNames: storagePropertyNames)); + Assert.Throws(() => WeaviateVectorStoreCollectionCreateMapping.MapToSchema(collectionName: "CollectionName", HasNamedVectors, model)); } [Theory] @@ -39,43 +44,46 @@ public void ItThrowsExceptionWithInvalidIndexKind() public void ItReturnsCorrectSchemaWithValidIndexKind(string indexKind, string expectedIndexKind) { // Arrange - var vectorProperties = new List - { - new("PropertyName", typeof(ReadOnlyMemory)) { IndexKind = indexKind } - }; - - var storagePropertyNames = new Dictionary { ["PropertyName"] = "propertyName" }; + var model = new WeaviateModelBuilder(HasNamedVectors) + .Build( + typeof(Dictionary), + new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) { IndexKind = indexKind } + ] + }, + defaultEmbeddingGenerator: null); // Act - var schema = WeaviateVectorStoreCollectionCreateMapping.MapToSchema( - collectionName: "CollectionName", - dataProperties: [], - vectorProperties: vectorProperties, - storagePropertyNames: storagePropertyNames); - - var actualIndexKind = schema.VectorConfigurations["propertyName"].VectorIndexType; + var schema = WeaviateVectorStoreCollectionCreateMapping.MapToSchema(collectionName: "CollectionName", HasNamedVectors, model); + var actualIndexKind = schema.VectorConfigurations["Vector"].VectorIndexType; // Assert Assert.Equal(expectedIndexKind, actualIndexKind); } [Fact] - public void ItThrowsExceptionWithInvalidDistanceFunction() + public void ItThrowsExceptionWithUnsupportedDistanceFunction() { // Arrange - var vectorProperties = new List - { - new("PropertyName", typeof(ReadOnlyMemory)) { DistanceFunction = "non-existent-distance-function" } - }; - - var storagePropertyNames = new Dictionary { ["PropertyName"] = "propertyName" }; + var model = new WeaviateModelBuilder(HasNamedVectors) + .Build( + typeof(Dictionary), + new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) { DistanceFunction = "unsupported-distance-function" } + ] + }, + defaultEmbeddingGenerator: null); // Act & Assert - Assert.Throws(() => WeaviateVectorStoreCollectionCreateMapping.MapToSchema( - collectionName: "CollectionName", - dataProperties: [], - vectorProperties: vectorProperties, - storagePropertyNames: storagePropertyNames)); + Assert.Throws(() => WeaviateVectorStoreCollectionCreateMapping.MapToSchema(collectionName: "CollectionName", HasNamedVectors, model)); } [Theory] @@ -87,21 +95,23 @@ public void ItThrowsExceptionWithInvalidDistanceFunction() public void ItReturnsCorrectSchemaWithValidDistanceFunction(string distanceFunction, string expectedDistanceFunction) { // Arrange - var vectorProperties = new List - { - new("PropertyName", typeof(ReadOnlyMemory)) { DistanceFunction = distanceFunction } - }; - - var storagePropertyNames = new Dictionary { ["PropertyName"] = "propertyName" }; + var model = new WeaviateModelBuilder(HasNamedVectors) + .Build( + typeof(Dictionary), + new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) { DistanceFunction = distanceFunction } + ] + }, + defaultEmbeddingGenerator: null); // Act - var schema = WeaviateVectorStoreCollectionCreateMapping.MapToSchema( - collectionName: "CollectionName", - dataProperties: [], - vectorProperties: vectorProperties, - storagePropertyNames: storagePropertyNames); + var schema = WeaviateVectorStoreCollectionCreateMapping.MapToSchema(collectionName: "CollectionName", HasNamedVectors, model); - var actualDistanceFunction = schema.VectorConfigurations["propertyName"].VectorIndexConfig?.Distance; + var actualDistanceFunction = schema.VectorConfigurations["Vector"].VectorIndexConfig?.Distance; // Assert Assert.Equal(expectedDistanceFunction, actualDistanceFunction); @@ -154,24 +164,26 @@ public void ItReturnsCorrectSchemaWithValidDistanceFunction(string distanceFunct [InlineData(typeof(bool?), "boolean")] [InlineData(typeof(List), "boolean[]")] [InlineData(typeof(List), "boolean[]")] - [InlineData(typeof(object), "object")] - [InlineData(typeof(List), "object[]")] public void ItMapsPropertyCorrectly(Type propertyType, string expectedPropertyType) { // Arrange - var dataProperties = new List - { - new("PropertyName", propertyType) { IsFilterable = true, IsFullTextSearchable = true } - }; - - var storagePropertyNames = new Dictionary { ["PropertyName"] = "propertyName" }; + var model = new WeaviateModelBuilder(HasNamedVectors) + .Build( + typeof(Dictionary), + new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)), + new VectorStoreRecordDataProperty("PropertyName", propertyType) { IsIndexed = true, IsFullTextIndexed = true }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 10) + ] + }, + defaultEmbeddingGenerator: null, + new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); // Act - var schema = WeaviateVectorStoreCollectionCreateMapping.MapToSchema( - collectionName: "CollectionName", - dataProperties: dataProperties, - vectorProperties: [], - storagePropertyNames: storagePropertyNames); + var schema = WeaviateVectorStoreCollectionCreateMapping.MapToSchema(collectionName: "CollectionName", HasNamedVectors, model); var property = schema.Properties[0]; @@ -181,4 +193,49 @@ public void ItMapsPropertyCorrectly(Type propertyType, string expectedPropertyTy Assert.True(property.IndexSearchable); Assert.True(property.IndexFilterable); } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ItReturnsCorrectSchemaWithValidVectorConfiguration(bool hasNamedVectors) + { + // Arrange + var model = new WeaviateModelBuilder(hasNamedVectors) + .Build( + typeof(Dictionary), + new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(Guid)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 4) + { + DistanceFunction = DistanceFunction.CosineDistance, + IndexKind = IndexKind.Hnsw + } + ] + }, + defaultEmbeddingGenerator: null); + + // Act + var schema = WeaviateVectorStoreCollectionCreateMapping.MapToSchema(collectionName: "CollectionName", hasNamedVectors, model); + + // Assert + if (hasNamedVectors) + { + Assert.Null(schema.VectorIndexConfig?.Distance); + Assert.Null(schema.VectorIndexType); + Assert.True(schema.VectorConfigurations.ContainsKey("Vector")); + + Assert.Equal("cosine", schema.VectorConfigurations["Vector"].VectorIndexConfig?.Distance); + Assert.Equal("hnsw", schema.VectorConfigurations["Vector"].VectorIndexType); + } + else + { + Assert.False(schema.VectorConfigurations.ContainsKey("Vector")); + + Assert.Equal("cosine", schema.VectorIndexConfig?.Distance); + Assert.Equal("hnsw", schema.VectorIndexType); + } + } } diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreCollectionSearchMappingTests.cs index 35a00c0376fc..3016b0083e2c 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreCollectionSearchMappingTests.cs @@ -13,8 +13,10 @@ namespace SemanticKernel.Connectors.Weaviate.UnitTests; /// public sealed class WeaviateVectorStoreCollectionSearchMappingTests { - [Fact] - public void MapSearchResultByDefaultReturnsValidResult() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapSearchResultByDefaultReturnsValidResult(bool hasNamedVectors) { // Arrange var jsonObject = new JsonObject @@ -22,11 +24,7 @@ public void MapSearchResultByDefaultReturnsValidResult() ["_additional"] = new JsonObject { ["distance"] = 0.5, - ["id"] = "55555555-5555-5555-5555-555555555555", - ["vectors"] = new JsonObject - { - ["descriptionEmbedding"] = new JsonArray(new List { 30, 31, 32, 33 }.Select(l => (JsonNode)l).ToArray()) - } + ["id"] = "55555555-5555-5555-5555-555555555555" }, ["description"] = "This is a great hotel.", ["hotelCode"] = 42, @@ -37,14 +35,27 @@ public void MapSearchResultByDefaultReturnsValidResult() ["timestamp"] = "2024-08-28T10:11:12-07:00" }; + var vector = new JsonArray(new List { 30, 31, 32, 33 }.Select(l => (JsonNode)l).ToArray()); + + if (hasNamedVectors) + { + jsonObject["_additional"]!["vectors"] = new JsonObject + { + ["descriptionEmbedding"] = vector + }; + } + else + { + jsonObject["_additional"]!["vector"] = vector; + } + // Act - var (storageModel, score) = WeaviateVectorStoreCollectionSearchMapping.MapSearchResult(jsonObject, "distance"); + var (storageModel, score) = WeaviateVectorStoreCollectionSearchMapping.MapSearchResult(jsonObject, "distance", hasNamedVectors); // Assert Assert.Equal(0.5, score); Assert.Equal("55555555-5555-5555-5555-555555555555", storageModel["id"]!.GetValue()); - Assert.Equal([30f, 31f, 32f, 33f], storageModel["vectors"]!["descriptionEmbedding"]!.AsArray().Select(l => l!.GetValue())); Assert.Equal("This is a great hotel.", storageModel["properties"]!["description"]!.GetValue()); Assert.Equal(42, storageModel["properties"]!["hotelCode"]!.GetValue()); Assert.Equal(4.5, storageModel["properties"]!["hotelRating"]!.GetValue()); @@ -52,5 +63,9 @@ public void MapSearchResultByDefaultReturnsValidResult() Assert.True(storageModel["properties"]!["parking_is_included"]!.GetValue()); Assert.Equal(["t1", "t2"], storageModel["properties"]!["tags"]!.AsArray().Select(l => l!.GetValue())); Assert.Equal("2024-08-28T10:11:12-07:00", storageModel["properties"]!["timestamp"]!.GetValue()); + + var vectorProperty = hasNamedVectors ? storageModel["vectors"]!["descriptionEmbedding"] : storageModel["vector"]; + + Assert.Equal([30f, 31f, 32f, 33f], vectorProperty!.AsArray().Select(l => l!.GetValue())); } } diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs index 5a009649ab1b..3bb70feda9e1 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs @@ -5,6 +5,7 @@ using System.Text.Json; using System.Text.Json.Serialization; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.Weaviate; using Xunit; @@ -19,7 +20,6 @@ public sealed class WeaviateVectorStoreRecordCollectionQueryBuilderTests { private const string CollectionName = "Collection"; private const string VectorPropertyName = "descriptionEmbedding"; - private const string KeyPropertyName = "HotelId"; private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() { @@ -32,23 +32,28 @@ public sealed class WeaviateVectorStoreRecordCollectionQueryBuilderTests } }; - private readonly Dictionary _storagePropertyNames = new() - { - ["HotelId"] = "hotelId", - ["HotelName"] = "hotelName", - ["HotelCode"] = "hotelCode", - ["Tags"] = "tags", - ["DescriptionEmbedding"] = "descriptionEmbedding" - }; + private readonly VectorStoreRecordModel _model = new WeaviateModelBuilder(hasNamedVectors: true) + .Build( + typeof(Dictionary), + new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("HotelId", typeof(Guid)) { StoragePropertyName = "hotelId" }, + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { StoragePropertyName = "hotelName" }, + new VectorStoreRecordDataProperty("HotelCode", typeof(string)) { StoragePropertyName = "hotelCode" }, + new VectorStoreRecordDataProperty("Tags", typeof(string[])) { StoragePropertyName = "tags" }, + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory), 10) { StoragePropertyName = "descriptionEmbeddding" }, + ] + }, + defaultEmbeddingGenerator: null); private readonly ReadOnlyMemory _vector = new([31f, 32f, 33f, 34f]); - private readonly List _vectorPropertyStorageNames = ["descriptionEmbedding"]; - - private readonly List _dataPropertyStorageNames = ["hotelName", "hotelCode"]; - - [Fact] - public void BuildSearchQueryByDefaultReturnsValidQuery() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void BuildSearchQueryByDefaultReturnsValidQuery(bool hasNamedVectors) { // Arrange var expectedQuery = $$""" @@ -59,11 +64,11 @@ public void BuildSearchQueryByDefaultReturnsValidQuery() offset: 2 {{string.Empty}} nearVector: { - targetVectors: ["descriptionEmbedding"] + {{(hasNamedVectors ? "targetVectors: [\"descriptionEmbedding\"]" : string.Empty)}} vector: [31,32,33,34] } ) { - hotelName hotelCode + HotelName HotelCode Tags _additional { id distance @@ -77,7 +82,6 @@ hotelName hotelCode var searchOptions = new VectorSearchOptions { Skip = 2, - Top = 3, }; // Act @@ -85,12 +89,11 @@ hotelName hotelCode this._vector, CollectionName, VectorPropertyName, - KeyPropertyName, s_jsonSerializerOptions, + top: 3, searchOptions, - this._storagePropertyNames, - this._vectorPropertyStorageNames, - this._dataPropertyStorageNames); + this._model, + hasNamedVectors); // Assert Assert.Equal(expectedQuery, query); @@ -99,14 +102,15 @@ hotelName hotelCode Assert.DoesNotContain("where", query); } - [Fact] - public void BuildSearchQueryWithIncludedVectorsReturnsValidQuery() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void BuildSearchQueryWithIncludedVectorsReturnsValidQuery(bool hasNamedVectors) { // Arrange var searchOptions = new VectorSearchOptions { Skip = 2, - Top = 3, IncludeVectors = true }; @@ -115,28 +119,28 @@ public void BuildSearchQueryWithIncludedVectorsReturnsValidQuery() this._vector, CollectionName, VectorPropertyName, - KeyPropertyName, s_jsonSerializerOptions, + top: 3, searchOptions, - this._storagePropertyNames, - this._vectorPropertyStorageNames, - this._dataPropertyStorageNames); + this._model, + hasNamedVectors); // Assert - Assert.Contains("vectors { descriptionEmbedding }", query); + var vectorQuery = hasNamedVectors ? "vectors { DescriptionEmbedding }" : "vector"; + + Assert.Contains(vectorQuery, query); } [Fact] public void BuildSearchQueryWithFilterReturnsValidQuery() { // Arrange - const string ExpectedFirstSubquery = """{ path: ["hotelName"], operator: Equal, valueText: "Test Name" }"""; - const string ExpectedSecondSubquery = """{ path: ["tags"], operator: ContainsAny, valueText: ["t1"] }"""; + const string ExpectedFirstSubquery = """{ path: ["HotelName"], operator: Equal, valueText: "Test Name" }"""; + const string ExpectedSecondSubquery = """{ path: ["Tags"], operator: ContainsAny, valueText: ["t1"] }"""; var searchOptions = new VectorSearchOptions { Skip = 2, - Top = 3, OldFilter = new VectorSearchFilter() .EqualTo("HotelName", "Test Name") .AnyTagEqualTo("Tags", "t1") @@ -147,12 +151,11 @@ public void BuildSearchQueryWithFilterReturnsValidQuery() this._vector, CollectionName, VectorPropertyName, - KeyPropertyName, s_jsonSerializerOptions, + top: 3, searchOptions, - this._storagePropertyNames, - this._vectorPropertyStorageNames, - this._dataPropertyStorageNames); + this._model, + hasNamedVectors: true); // Assert Assert.Contains(ExpectedFirstSubquery, query); @@ -166,7 +169,6 @@ public void BuildSearchQueryWithInvalidFilterValueThrowsException() var searchOptions = new VectorSearchOptions { Skip = 2, - Top = 3, OldFilter = new VectorSearchFilter().EqualTo("HotelName", new TestFilterValue()) }; @@ -175,12 +177,11 @@ public void BuildSearchQueryWithInvalidFilterValueThrowsException() this._vector, CollectionName, VectorPropertyName, - KeyPropertyName, s_jsonSerializerOptions, + top: 3, searchOptions, - this._storagePropertyNames, - this._vectorPropertyStorageNames, - this._dataPropertyStorageNames)); + this._model, + hasNamedVectors: true)); } [Fact] @@ -190,7 +191,6 @@ public void BuildSearchQueryWithNonExistentPropertyInFilterThrowsException() var searchOptions = new VectorSearchOptions { Skip = 2, - Top = 3, OldFilter = new VectorSearchFilter().EqualTo("NonExistentProperty", "value") }; @@ -199,12 +199,11 @@ public void BuildSearchQueryWithNonExistentPropertyInFilterThrowsException() this._vector, CollectionName, VectorPropertyName, - KeyPropertyName, s_jsonSerializerOptions, + top: 3, searchOptions, - this._storagePropertyNames, - this._vectorPropertyStorageNames, - this._dataPropertyStorageNames)); + this._model, + hasNamedVectors: true)); } #region private diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs index 0b3e39cac291..99ad9d7e4660 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs @@ -10,13 +10,12 @@ using System.Threading.Tasks; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Weaviate; -using Moq; using Xunit; namespace SemanticKernel.Connectors.Weaviate.UnitTests; /// -/// Unit tests for class. +/// Unit tests for class. /// public sealed class WeaviateVectorStoreRecordCollectionTests : IDisposable { @@ -32,7 +31,7 @@ public WeaviateVectorStoreRecordCollectionTests() public void ConstructorForModelWithoutKeyThrowsException() { // Act & Assert - var exception = Assert.Throws(() => new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection")); + var exception = Assert.Throws(() => new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection")); Assert.Contains("No key property found", exception.Message); } @@ -43,7 +42,7 @@ public void ConstructorWithoutEndpointThrowsException() using var httpClient = new HttpClient(); // Act & Assert - var exception = Assert.Throws(() => new WeaviateVectorStoreRecordCollection(httpClient, "Collection")); + var exception = Assert.Throws(() => new WeaviateVectorStoreRecordCollection(httpClient, "Collection")); Assert.Contains("Weaviate endpoint should be provided", exception.Message); } @@ -51,7 +50,7 @@ public void ConstructorWithoutEndpointThrowsException() public void ConstructorWithDeclarativeModelInitializesCollection() { // Act & Assert - var collection = new WeaviateVectorStoreRecordCollection( + var collection = new WeaviateVectorStoreRecordCollection( this._mockHttpClient, "Collection"); @@ -68,7 +67,7 @@ public void ConstructorWithImperativeModelInitializesCollection() }; // Act - var collection = new WeaviateVectorStoreRecordCollection( + var collection = new WeaviateVectorStoreRecordCollection( this._mockHttpClient, "Collection", new() { VectorStoreRecordDefinition = definition }); @@ -84,7 +83,7 @@ public async Task CollectionExistsReturnsValidResultAsync(HttpResponseMessage re // Arrange this._messageHandlerStub.ResponseToReturn = responseMessage; - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); // Act var actualResult = await sut.CollectionExistsAsync(); @@ -113,7 +112,7 @@ public async Task CollectionExistsReturnsValidResultAsync(HttpResponseMessage re [InlineData("containsNonAsciią")] public void CollectionCtorRejectsInvalidNames(string collectionName) { - ArgumentException argumentException = Assert.Throws(() => new WeaviateVectorStoreRecordCollection(this._mockHttpClient, collectionName)); + ArgumentException argumentException = Assert.Throws(() => new WeaviateVectorStoreRecordCollection(this._mockHttpClient, collectionName)); Assert.Equal("collectionName", argumentException.ParamName); } @@ -122,7 +121,7 @@ public async Task CreateCollectionUsesValidCollectionSchemaAsync() { // Arrange const string CollectionName = "Collection"; - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName); // Act await sut.CreateCollectionAsync(); @@ -158,7 +157,7 @@ public async Task DeleteCollectionSendsValidRequestAsync() { // Arrange const string CollectionName = "Collection"; - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName); // Act await sut.DeleteCollectionAsync(); @@ -175,7 +174,7 @@ public async Task DeleteSendsValidRequestAsync() const string CollectionName = "Collection"; var id = new Guid("55555555-5555-5555-5555-555555555555"); - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName); // Act await sut.DeleteAsync(id); @@ -192,10 +191,10 @@ public async Task DeleteBatchUsesValidQueryMatchAsync() const string CollectionName = "Collection"; List ids = [new Guid("11111111-1111-1111-1111-111111111111"), new Guid("22222222-2222-2222-2222-222222222222")]; - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName); // Act - await sut.DeleteBatchAsync(ids); + await sut.DeleteAsync(ids); // Assert var request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); @@ -228,7 +227,7 @@ public async Task GetExistingRecordReturnsValidRecordAsync() Content = new StringContent(JsonSerializer.Serialize(jsonObject)) }; - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); // Act var result = await sut.GetAsync(id); @@ -258,10 +257,10 @@ public async Task GetExistingBatchRecordsReturnsValidRecordsAsync() this._messageHandlerStub.ResponseQueue.Enqueue(response1); this._messageHandlerStub.ResponseQueue.Enqueue(response2); - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); // Act - var results = await sut.GetBatchAsync([id1, id2]).ToListAsync(); + var results = await sut.GetAsync([id1, id2]).ToListAsync(); // Assert Assert.NotNull(results[0]); @@ -287,7 +286,7 @@ public async Task UpsertReturnsRecordKeyAsync() Content = new StringContent(JsonSerializer.Serialize(batchResponse)), }; - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); // Act var result = await sut.UpsertAsync(hotel); @@ -326,10 +325,10 @@ public async Task UpsertReturnsRecordKeysAsync() Content = new StringContent(JsonSerializer.Serialize(batchResponse)), }; - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); // Act - var results = await sut.UpsertBatchAsync([hotel1, hotel2]).ToListAsync(); + var results = await sut.UpsertAsync([hotel1, hotel2]); // Assert Assert.Contains(id1, results); @@ -349,85 +348,6 @@ public async Task UpsertReturnsRecordKeysAsync() Assert.Equal("Test Name 2", jsonObject2["properties"]?["hotelName"]?.GetValue()); } - [Fact] - public async Task UpsertWithCustomMapperWorksCorrectlyAsync() - { - // Arrange - var id = new Guid("11111111-1111-1111-1111-111111111111"); - var hotel = new WeaviateHotel { HotelId = id, HotelName = "Test Name" }; - - var jsonObject = new JsonObject { ["id"] = id.ToString(), ["properties"] = new JsonObject() }; - - jsonObject["properties"]!["hotel_name"] = "Test Name from Mapper"; - - var mockMapper = new Mock>(); - - mockMapper - .Setup(l => l.MapFromDataToStorageModel(It.IsAny())) - .Returns(jsonObject); - - var batchResponse = new List { new() { Id = id, Result = new() { Status = "Success" } } }; - - this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK) - { - Content = new StringContent(JsonSerializer.Serialize(batchResponse)), - }; - - var sut = new WeaviateVectorStoreRecordCollection( - this._mockHttpClient, - "Collection", - new() { JsonObjectCustomMapper = mockMapper.Object }); - - // Act - var result = await sut.UpsertAsync(hotel); - - // Assert - Assert.Equal(id, result); - - var request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); - - Assert.NotNull(request?.CollectionObjects); - - var requestObject = request.CollectionObjects[0]; - - Assert.Equal("11111111-1111-1111-1111-111111111111", requestObject["id"]?.GetValue()); - Assert.Equal("Test Name from Mapper", requestObject["properties"]?["hotel_name"]?.GetValue()); - } - - [Fact] - public async Task GetWithCustomMapperWorksCorrectlyAsync() - { - // Arrange - var id = new Guid("11111111-1111-1111-1111-111111111111"); - var jsonObject = new JsonObject { ["id"] = id.ToString(), ["properties"] = new JsonObject() }; - - jsonObject["properties"]!["hotelName"] = "Test Name"; - - this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK) - { - Content = new StringContent(JsonSerializer.Serialize(jsonObject)) - }; - - var mockMapper = new Mock>(); - - mockMapper - .Setup(l => l.MapFromStorageToDataModel(It.IsAny(), It.IsAny())) - .Returns(new WeaviateHotel { HotelId = id, HotelName = "Test Name from mapper" }); - - var sut = new WeaviateVectorStoreRecordCollection( - this._mockHttpClient, - "Collection", - new() { JsonObjectCustomMapper = mockMapper.Object }); - - // Act - var result = await sut.GetAsync(id); - - // Assert - Assert.NotNull(result); - Assert.Equal(id, result.HotelId); - Assert.Equal("Test Name from mapper", result.HotelName); - } - [Theory] [InlineData(true, "http://test-endpoint/schema", "Bearer fake-key")] [InlineData(false, "http://default-endpoint/schema", null)] @@ -440,7 +360,7 @@ public async Task ItUsesHttpClientParametersAsync(bool initializeOptions, string new WeaviateVectorStoreRecordCollectionOptions() { Endpoint = new Uri("http://test-endpoint"), ApiKey = "fake-key" } : null; - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName, options); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName, options); // Act await sut.CreateCollectionAsync(); @@ -456,10 +376,10 @@ public async Task ItUsesHttpClientParametersAsync(bool initializeOptions, string [Theory] [InlineData(true)] [InlineData(false)] - public async Task VectorizedSearchReturnsValidRecordAsync(bool includeVectors) + public async Task SearchEmbeddingReturnsValidRecordAsync(bool includeVectors) { // Arrange - const string CollectionName = "VectorizedSearchCollection"; + const string CollectionName = "SearchEmbeddingCollection"; var id = new Guid("55555555-5555-5555-5555-555555555555"); var vector = new ReadOnlyMemory([30f, 31f, 32f, 33f]); @@ -500,16 +420,15 @@ public async Task VectorizedSearchReturnsValidRecordAsync(bool includeVectors) Content = new StringContent(JsonSerializer.Serialize(jsonObject)) }; - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, CollectionName); // Act - var actual = await sut.VectorizedSearchAsync(vector, new() + var results = await sut.SearchEmbeddingAsync(vector, top: 3, new() { IncludeVectors = includeVectors - }); + }).ToListAsync(); // Assert - var results = await actual.Results.ToListAsync(); Assert.Single(results); var score = results[0].Score; @@ -538,28 +457,29 @@ public async Task VectorizedSearchReturnsValidRecordAsync(bool includeVectors) } [Fact] - public async Task VectorizedSearchWithUnsupportedVectorTypeThrowsExceptionAsync() + public async Task SearchEmbeddingWithUnsupportedVectorTypeThrowsExceptionAsync() { // Arrange - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); // Act & Assert await Assert.ThrowsAsync(async () => - await (await sut.VectorizedSearchAsync(new List([1, 2, 3]))).Results.ToListAsync()); + await sut.SearchEmbeddingAsync(new List([1, 2, 3]), top: 3).ToListAsync()); } [Fact] - public async Task VectorizedSearchWithNonExistentVectorPropertyNameThrowsExceptionAsync() + public async Task SearchEmbeddingWithNonExistentVectorPropertyNameThrowsExceptionAsync() { // Arrange - var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); + var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); // Act & Assert await Assert.ThrowsAsync(async () => - await (await sut.VectorizedSearchAsync( + await sut.SearchEmbeddingAsync( new ReadOnlyMemory([1f, 2f, 3f]), - new() { VectorProperty = r => "non-existent-property" })) - .Results.ToListAsync()); + top: 3, + new() { VectorProperty = r => "non-existent-property" }) + .ToListAsync()); } public void Dispose() diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordMapperTests.cs index 5f79925c2c48..f00d62b54255 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordMapperTests.cs @@ -28,40 +28,10 @@ public sealed class WeaviateVectorStoreRecordMapperTests } }; - private readonly WeaviateVectorStoreRecordMapper _sut; - - public WeaviateVectorStoreRecordMapperTests() - { - var storagePropertyNames = new Dictionary - { - ["HotelId"] = "hotelId", - ["HotelName"] = "hotelName", - ["Tags"] = "tags", - ["DescriptionEmbedding"] = "descriptionEmbedding", - }; - - var dataProperties = new List - { - new("HotelName", typeof(string)), - new("Tags", typeof(List)) - }; - - var vectorProperties = new List - { - new("DescriptionEmbedding", typeof(ReadOnlyMemory)) - }; - - this._sut = new WeaviateVectorStoreRecordMapper( - "CollectionName", - new VectorStoreRecordKeyProperty("HotelId", typeof(Guid)), - dataProperties, - vectorProperties, - storagePropertyNames, - s_jsonSerializerOptions); - } - - [Fact] - public void MapFromDataToStorageModelReturnsValidObject() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromDataToStorageModelReturnsValidObject(bool hasNamedVectors) { // Arrange var hotel = new WeaviateHotel @@ -72,8 +42,10 @@ public void MapFromDataToStorageModelReturnsValidObject() DescriptionEmbedding = new ReadOnlyMemory([1f, 2f, 3f]) }; + var sut = GetMapper(hasNamedVectors); + // Act - var document = this._sut.MapFromDataToStorageModel(hotel); + var document = sut.MapFromDataToStorageModel(hotel, recordIndex: 0, generatedEmbeddings: null); // Assert Assert.NotNull(document); @@ -81,11 +53,16 @@ public void MapFromDataToStorageModelReturnsValidObject() Assert.Equal("55555555-5555-5555-5555-555555555555", document["id"]!.GetValue()); Assert.Equal("Test Name", document["properties"]!["hotelName"]!.GetValue()); Assert.Equal(["tag1", "tag2"], document["properties"]!["tags"]!.AsArray().Select(l => l!.GetValue())); - Assert.Equal([1f, 2f, 3f], document["vectors"]!["descriptionEmbedding"]!.AsArray().Select(l => l!.GetValue())); + + var vectorNode = hasNamedVectors ? document["vectors"]!["descriptionEmbedding"] : document["vector"]; + + Assert.Equal([1f, 2f, 3f], vectorNode!.AsArray().Select(l => l!.GetValue())); } - [Fact] - public void MapFromStorageToDataModelReturnsValidObject() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelReturnsValidObject(bool hasNamedVectors) { // Arrange var document = new JsonObject @@ -97,10 +74,22 @@ public void MapFromStorageToDataModelReturnsValidObject() document["properties"]!["hotelName"] = "Test Name"; document["properties"]!["tags"] = new JsonArray(new List { "tag1", "tag2" }.Select(l => JsonValue.Create(l)).ToArray()); - document["vectors"]!["descriptionEmbedding"] = new JsonArray(new List { 1f, 2f, 3f }.Select(l => JsonValue.Create(l)).ToArray()); + + var vectorNode = new JsonArray(new List { 1f, 2f, 3f }.Select(l => JsonValue.Create(l)).ToArray()); + + if (hasNamedVectors) + { + document["vectors"]!["descriptionEmbedding"] = vectorNode; + } + else + { + document["vector"] = vectorNode; + } + + var sut = GetMapper(hasNamedVectors); // Act - var hotel = this._sut.MapFromStorageToDataModel(document, new() { IncludeVectors = true }); + var hotel = sut.MapFromStorageToDataModel(document, new() { IncludeVectors = true }); // Assert Assert.NotNull(hotel); @@ -110,4 +99,28 @@ public void MapFromStorageToDataModelReturnsValidObject() Assert.Equal(["tag1", "tag2"], hotel.Tags); Assert.True(new ReadOnlyMemory([1f, 2f, 3f]).Span.SequenceEqual(hotel.DescriptionEmbedding!.Value.Span)); } + + #region private + + private static WeaviateVectorStoreRecordMapper GetMapper(bool hasNamedVectors) => new( + "CollectionName", + hasNamedVectors, + new WeaviateModelBuilder(hasNamedVectors) + .Build( + typeof(Dictionary), + new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty("HotelId", typeof(Guid)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)), + new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory), 10) + ] + }, + defaultEmbeddingGenerator: null, + s_jsonSerializerOptions), + s_jsonSerializerOptions); + + #endregion } diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreTests.cs index 5a99f4c1ee20..26655edeb74f 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreTests.cs @@ -34,7 +34,7 @@ public void GetCollectionWithNotSupportedKeyThrowsException() var sut = new WeaviateVectorStore(this._mockHttpClient); // Act & Assert - Assert.Throws(() => sut.GetCollection("collection")); + Assert.Throws(() => sut.GetCollection("Collection")); } [Fact] diff --git a/dotnet/src/Connectors/Directory.Build.props b/dotnet/src/Connectors/Directory.Build.props new file mode 100644 index 000000000000..29561a81e526 --- /dev/null +++ b/dotnet/src/Connectors/Directory.Build.props @@ -0,0 +1,10 @@ + + + + + + $(NoWarn);MEVD9000,MEVD9001 + $(NoWarn);CA1863 + + + \ No newline at end of file diff --git a/dotnet/src/Connectors/VectorData.Abstractions/.editorconfig b/dotnet/src/Connectors/VectorData.Abstractions/.editorconfig new file mode 100644 index 000000000000..acb2cb62caf4 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/.editorconfig @@ -0,0 +1,3 @@ +# Suppress missing documentation warnings for generated code (strings) +[*.Designer.cs] +dotnet_diagnostic.CS1591.severity = none diff --git a/dotnet/src/Connectors/VectorData.Abstractions/AssemblyInfo.cs b/dotnet/src/Connectors/VectorData.Abstractions/AssemblyInfo.cs deleted file mode 100644 index cbb67c1c8afd..000000000000 --- a/dotnet/src/Connectors/VectorData.Abstractions/AssemblyInfo.cs +++ /dev/null @@ -1 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. diff --git a/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml b/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml index cd9bfbaa3ca7..66bc881859c6 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml +++ b/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml @@ -2,316 +2,526 @@ - CP0001 - T:Microsoft.Extensions.VectorData.DeleteRecordOptions + CP0002 + M:Microsoft.Extensions.VectorData.HybridSearchOptions`1.get_Top lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true - CP0001 - T:Microsoft.Extensions.VectorData.UpsertRecordOptions + CP0002 + M:Microsoft.Extensions.VectorData.HybridSearchOptions`1.set_Top(System.Int32) lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true - CP0001 - T:Microsoft.Extensions.VectorData.VectorSearchOptions + CP0002 + M:Microsoft.Extensions.VectorData.IKeywordHybridSearch`1.HybridSearchAsync``1(``0,System.Collections.Generic.ICollection{System.String},Microsoft.Extensions.VectorData.HybridSearchOptions{`0},System.Threading.CancellationToken) lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true - CP0001 - T:Microsoft.Extensions.VectorData.DeleteRecordOptions - lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll - lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true - CP0001 - T:Microsoft.Extensions.VectorData.UpsertRecordOptions - lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll - lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true - CP0001 - T:Microsoft.Extensions.VectorData.VectorSearchOptions - lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll - lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + CP0002 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteBatchAsync(System.Collections.Generic.IEnumerable{`0},System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true - CP0001 - T:Microsoft.Extensions.VectorData.DeleteRecordOptions - lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll - lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + CP0002 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.GetBatchAsync(System.Collections.Generic.IEnumerable{`0},Microsoft.Extensions.VectorData.GetRecordOptions,System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true - CP0001 - T:Microsoft.Extensions.VectorData.UpsertRecordOptions - lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll - lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + CP0002 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertBatchAsync(System.Collections.Generic.IEnumerable{`1},System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true - CP0001 - T:Microsoft.Extensions.VectorData.VectorSearchOptions - lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll - lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + CP0002 + M:Microsoft.Extensions.VectorData.VectorSearchOptions`1.get_Top + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.VectorSearchOptions`1.set_Top(System.Int32) lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.VectorStoreException.get_VectorStoreType lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.VectorStoreException.set_VectorStoreType(System.String) lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteBatchAsync(System.Collections.Generic.IEnumerable{`0},Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) - lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll - lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + M:Microsoft.Extensions.VectorData.HybridSearchOptions`1.get_Top + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertAsync(`1,Microsoft.Extensions.VectorData.UpsertRecordOptions,System.Threading.CancellationToken) - lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll - lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + M:Microsoft.Extensions.VectorData.HybridSearchOptions`1.set_Top(System.Int32) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertBatchAsync(System.Collections.Generic.IEnumerable{`1},Microsoft.Extensions.VectorData.UpsertRecordOptions,System.Threading.CancellationToken) - lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll - lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + M:Microsoft.Extensions.VectorData.IKeywordHybridSearch`1.HybridSearchAsync``1(``0,System.Collections.Generic.ICollection{System.String},Microsoft.Extensions.VectorData.HybridSearchOptions{`0},System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteBatchAsync(System.Collections.Generic.IEnumerable{`0},System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.GetBatchAsync(System.Collections.Generic.IEnumerable{`0},Microsoft.Extensions.VectorData.GetRecordOptions,System.Threading.CancellationToken) lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertBatchAsync(System.Collections.Generic.IEnumerable{`1},System.Threading.CancellationToken) lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.VectorSearchOptions`1.get_Top lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteBatchAsync(System.Collections.Generic.IEnumerable{`0},Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.VectorSearchOptions`1.set_Top(System.Int32) lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertAsync(`1,Microsoft.Extensions.VectorData.UpsertRecordOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.VectorStoreException.get_VectorStoreType lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertBatchAsync(System.Collections.Generic.IEnumerable{`1},Microsoft.Extensions.VectorData.UpsertRecordOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.VectorStoreException.set_VectorStoreType(System.String) lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.HybridSearchOptions`1.get_Top lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.HybridSearchOptions`1.set_Top(System.Int32) lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IKeywordHybridSearch`1.HybridSearchAsync``1(``0,System.Collections.Generic.ICollection{System.String},Microsoft.Extensions.VectorData.HybridSearchOptions{`0},System.Threading.CancellationToken) lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteBatchAsync(System.Collections.Generic.IEnumerable{`0},Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertAsync(`1,Microsoft.Extensions.VectorData.UpsertRecordOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0002 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertBatchAsync(System.Collections.Generic.IEnumerable{`1},Microsoft.Extensions.VectorData.UpsertRecordOptions,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteBatchAsync(System.Collections.Generic.IEnumerable{`0},System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.GetBatchAsync(System.Collections.Generic.IEnumerable{`0},Microsoft.Extensions.VectorData.GetRecordOptions,System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertBatchAsync(System.Collections.Generic.IEnumerable{`1},System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.VectorSearchOptions`1.get_Top + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.VectorSearchOptions`1.set_Top(System.Int32) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.VectorStoreException.get_VectorStoreType + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.VectorStoreException.set_VectorStoreType(System.String) lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IKeywordHybridSearch`1.GetService(System.Type,System.Object) lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IKeywordHybridSearch`1.HybridSearchAsync``1(``0,System.Collections.Generic.ICollection{System.String},System.Int32,Microsoft.Extensions.VectorData.HybridSearchOptions{`0},System.Threading.CancellationToken) lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.GetService(System.Type,System.Object) lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteBatchAsync(System.Collections.Generic.IEnumerable{`0},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,System.Int32,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertAsync(`1,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.GetService(System.Type,System.Object) lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertBatchAsync(System.Collections.Generic.IEnumerable{`1},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,System.Int32,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorStore.GetService(System.Type,System.Object) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(System.Collections.Generic.IEnumerable{`0},System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.GetAsync(System.Collections.Generic.IEnumerable{`0},Microsoft.Extensions.VectorData.GetRecordOptions,System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.GetAsync(System.Linq.Expressions.Expression{System.Func{`1,System.Boolean}},System.Int32,Microsoft.Extensions.VectorData.GetFilteredRecordOptions{`1},System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertAsync(System.Collections.Generic.IEnumerable{`1},System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IKeywordHybridSearch`1.GetService(System.Type,System.Object) lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IKeywordHybridSearch`1.HybridSearchAsync``1(``0,System.Collections.Generic.ICollection{System.String},System.Int32,Microsoft.Extensions.VectorData.HybridSearchOptions{`0},System.Threading.CancellationToken) lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.GetService(System.Type,System.Object) lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteBatchAsync(System.Collections.Generic.IEnumerable{`0},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,System.Int32,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertAsync(`1,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.GetService(System.Type,System.Object) lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertBatchAsync(System.Collections.Generic.IEnumerable{`1},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,System.Int32,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorStore.GetService(System.Type,System.Object) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(System.Collections.Generic.IEnumerable{`0},System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.GetAsync(System.Collections.Generic.IEnumerable{`0},Microsoft.Extensions.VectorData.GetRecordOptions,System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.GetAsync(System.Linq.Expressions.Expression{System.Func{`1,System.Boolean}},System.Int32,Microsoft.Extensions.VectorData.GetFilteredRecordOptions{`1},System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertAsync(System.Collections.Generic.IEnumerable{`1},System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IKeywordHybridSearch`1.GetService(System.Type,System.Object) lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IKeywordHybridSearch`1.HybridSearchAsync``1(``0,System.Collections.Generic.ICollection{System.String},System.Int32,Microsoft.Extensions.VectorData.HybridSearchOptions{`0},System.Threading.CancellationToken) lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.GetService(System.Type,System.Object) lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteBatchAsync(System.Collections.Generic.IEnumerable{`0},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,System.Int32,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertAsync(`1,System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.GetService(System.Type,System.Object) lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true CP0006 - M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertBatchAsync(System.Collections.Generic.IEnumerable{`1},System.Threading.CancellationToken) + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,System.Int32,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStore.GetService(System.Type,System.Object) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(System.Collections.Generic.IEnumerable{`0},System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.GetAsync(System.Collections.Generic.IEnumerable{`0},Microsoft.Extensions.VectorData.GetRecordOptions,System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.GetAsync(System.Linq.Expressions.Expression{System.Func{`1,System.Boolean}},System.Int32,Microsoft.Extensions.VectorData.GetFilteredRecordOptions{`1},System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.UpsertAsync(System.Collections.Generic.IEnumerable{`1},System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0021 + M:Microsoft.Extensions.VectorData.IVectorStore.GetCollection``2(System.String,Microsoft.Extensions.VectorData.VectorStoreRecordDefinition)``1:notnull + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0021 + T:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2``1:notnull + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0021 + M:Microsoft.Extensions.VectorData.IVectorStore.GetCollection``2(System.String,Microsoft.Extensions.VectorData.VectorStoreRecordDefinition)``1:notnull + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0021 + T:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2``1:notnull + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0021 + M:Microsoft.Extensions.VectorData.IVectorStore.GetCollection``2(System.String,Microsoft.Extensions.VectorData.VectorStoreRecordDefinition)``1:notnull + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0021 + T:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2``1:notnull lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true diff --git a/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/Filter/FilterTranslationPreprocessor.cs b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/Filter/FilterTranslationPreprocessor.cs new file mode 100644 index 000000000000..23fa6d776ff4 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/Filter/FilterTranslationPreprocessor.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; + +namespace Microsoft.Extensions.VectorData.ConnectorSupport.Filter; + +/// +/// A processor for user-provided filter expressions which performs various common transformations before actual translation takes place. +/// This is an internal support type meant for use by connectors only, and not for use by applications. +/// +[Experimental("MEVD9001")] +public class FilterTranslationPreprocessor : ExpressionVisitor +{ + /// + /// Whether to inline captured variables in the filter expression (when the database doesn't support parameters). + /// + public bool InlineCapturedVariables { get; init; } + + /// + /// Whether to transform captured variables in the filter expression to (when the database supports parameters). + /// + public bool TransformCapturedVariablesToQueryParameterExpressions { get; init; } + + /// + protected override Expression VisitMember(MemberExpression node) + { + // This identifies compiler-generated closure types which contain captured variables. + // Some databases - mostly relational ones - support out-of-band parameters which can be referenced via placeholders + // from the query itself. For those databases, we transform the captured variable to QueryParameterExpression (this simplifies things for those + // connectors, and centralizes the pattern matching in a single centralized place). + // For databases which don't support parameters, we simply inline the captured variable as a constant in the tree, so that translators don't + // even need to be aware of the captured variable. + // For all other databases, we simply inline the captured variable as a constant in the tree. + if (node is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + return (this.InlineCapturedVariables, this.TransformCapturedVariablesToQueryParameterExpressions) switch + { + (true, false) => Expression.Constant(fieldInfo.GetValue(constant.Value), node.Type), + (false, true) => new QueryParameterExpression(fieldInfo.Name, fieldInfo.GetValue(constant.Value), node.Type), + + (true, true) => throw new InvalidOperationException("InlineCapturedVariables and TransformCapturedVariablesToQueryParameterExpressions cannot both be true."), + (false, false) => base.VisitMember(node) + }; + } + + return base.VisitMember(node); + } +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/Filter/QueryParameterExpression.cs b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/Filter/QueryParameterExpression.cs new file mode 100644 index 000000000000..caa86b665b77 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/Filter/QueryParameterExpression.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Linq.Expressions; + +namespace Microsoft.Extensions.VectorData.ConnectorSupport.Filter; + +/// +/// An expression representation a query parameter (captured variable) in the filter expression. +/// +[Experimental("MEVD9001")] +public class QueryParameterExpression(string name, object? value, Type type) : Expression +{ + /// + /// The name of the parameter. + /// + public string Name { get; } = name; + + /// + /// The value of the parameter. + /// + public object? Value { get; } = value; + + /// + public override ExpressionType NodeType => ExpressionType.Extension; + + /// + public override Type Type => type; + + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) => this; +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordDataPropertyModel.cs b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordDataPropertyModel.cs new file mode 100644 index 000000000000..6296bad857e9 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordDataPropertyModel.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.VectorData.ConnectorSupport; + +/// +/// Represents a data property on a vector store record. +/// This is an internal support type meant for use by connectors only, and not for use by applications. +/// +[Experimental("MEVD9001")] +public class VectorStoreRecordDataPropertyModel(string modelName, Type type) : VectorStoreRecordPropertyModel(modelName, type) +{ + /// + /// Gets or sets a value indicating whether this data property is indexed. + /// + /// + /// The default is . + /// + public bool IsIndexed { get; set; } + + /// + /// Gets or sets a value indicating whether this data property is indexed for full-text search. + /// + /// + /// The default is . + /// + public bool IsFullTextIndexed { get; set; } + + /// + public override string ToString() + => $"{this.ModelName} (Data, {this.Type.Name})"; +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordJsonModelBuilder.cs b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordJsonModelBuilder.cs new file mode 100644 index 000000000000..e5d5f3881d99 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordJsonModelBuilder.cs @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Extensions.AI; + +namespace Microsoft.Extensions.VectorData.ConnectorSupport; + +/// +/// A model builder that performs logic specific to connectors which use System.Text.Json for serialization. +/// This is an internal support type meant for use by connectors only, and not for use by applications. +/// +[Experimental("MEVD9001")] +public class VectorStoreRecordJsonModelBuilder : VectorStoreRecordModelBuilder +{ + private JsonSerializerOptions _jsonSerializerOptions = JsonSerializerOptions.Default; + + /// + /// Constructs a new . + /// + public VectorStoreRecordJsonModelBuilder(VectorStoreRecordModelBuildingOptions options) + : base(options) + { + if (!options.UsesExternalSerializer) + { + throw new ArgumentNullException(nameof(options), $"{nameof(options.UsesExternalSerializer)} must be set when using this model builder."); + } + } + + /// + /// Builds and returns an from the given and . + /// + public virtual VectorStoreRecordModel Build( + Type type, + VectorStoreRecordDefinition? vectorStoreRecordDefinition, + IEmbeddingGenerator? defaultEmbeddingGenerator, + JsonSerializerOptions? jsonSerializerOptions) + { + if (jsonSerializerOptions is not null) + { + this._jsonSerializerOptions = jsonSerializerOptions; + } + + return this.Build(type, vectorStoreRecordDefinition, defaultEmbeddingGenerator); + } + + /// + protected override void Customize() + { + // This mimics the naming behavior of the System.Text.Json serializer, which we use for serialization/deserialization. + // The property storage names in the model must in sync with the serializer configuration, since the model is used e.g. for filtering + // even if serialization/deserialization doesn't use the model. + var namingPolicy = this._jsonSerializerOptions.PropertyNamingPolicy; + + foreach (var property in this.Properties) + { + var keyPropertyWithReservedName = this.Options.ReservedKeyStorageName is not null && property is VectorStoreRecordKeyPropertyModel; + string storageName; + + if (property.PropertyInfo?.GetCustomAttribute() is { } jsonPropertyNameAttribute) + { + if (keyPropertyWithReservedName && jsonPropertyNameAttribute.Name != this.Options.ReservedKeyStorageName) + { + throw new InvalidOperationException($"The key property for your connector must always have the reserved name '{this.Options.ReservedKeyStorageName}' and cannot be changed."); + } + + storageName = jsonPropertyNameAttribute.Name; + } + else if (namingPolicy is not null) + { + storageName = namingPolicy.ConvertName(property.ModelName); + } + else + { + storageName = property.ModelName; + } + + if (keyPropertyWithReservedName) + { + // Somewhat hacky: + // Some providers (Weaviate, Cosmos NoSQL) have a fixed, reserved storage name for keys (id), and at the same time use an external + // JSON serializer to serialize the entire user POCO. Since the serializer is unaware of the reserved storage name, it will produce + // a storage name as usual, based on the .NET property's name, possibly with a naming policy applied to it. The connector then needs + // to look that up and replace with the reserved name. + // So we store the policy-transformed name, as StorageName contains the reserved name. + property.TemporaryStorageName = storageName; + } + else + { + property.StorageName = storageName; + } + } + } +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordKeyPropertyModel.cs b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordKeyPropertyModel.cs new file mode 100644 index 000000000000..b791ac9d21e8 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordKeyPropertyModel.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.VectorData.ConnectorSupport; + +/// +/// Represents a key property on a vector store record. +/// This is an internal support type meant for use by connectors only, and not for use by applications. +/// +[Experimental("MEVD9001")] +public class VectorStoreRecordKeyPropertyModel(string modelName, Type type) : VectorStoreRecordPropertyModel(modelName, type) +{ + /// + public override string ToString() + => $"{this.ModelName} (Key, {this.Type.Name})"; +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordModel.cs b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordModel.cs new file mode 100644 index 000000000000..9df86e0663d6 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordModel.cs @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; + +namespace Microsoft.Extensions.VectorData.ConnectorSupport; + +/// +/// A model representing a record in a vector store collection. +/// This is an internal support type meant for use by connectors only, and not for use by applications. +/// +[Experimental("MEVD9001")] +public sealed class VectorStoreRecordModel +{ + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + private readonly Type _recordType; + + private VectorStoreRecordKeyPropertyModel? _singleKeyProperty; + private VectorStoreRecordVectorPropertyModel? _singleVectorProperty; + private VectorStoreRecordDataPropertyModel? _singleFullTextSearchProperty; + + /// + /// The key properties of the record. + /// + public IReadOnlyList KeyProperties { get; } + + /// + /// The data properties of the record. + /// + public IReadOnlyList DataProperties { get; } + + /// + /// The vector properties of the record. + /// + public IReadOnlyList VectorProperties { get; } + + /// + /// All properties of the record, of all types. + /// + public IReadOnlyList Properties { get; } + + /// + /// All properties of the record, of all types, indexed by their model name. + /// + public IReadOnlyDictionary PropertyMap { get; } + + internal VectorStoreRecordModel( + Type recordType, + IReadOnlyList keyProperties, + IReadOnlyList dataProperties, + IReadOnlyList vectorProperties, + IReadOnlyDictionary propertyMap) + { + this._recordType = recordType; + this.KeyProperties = keyProperties; + this.DataProperties = dataProperties; + this.VectorProperties = vectorProperties; + this.PropertyMap = propertyMap; + this.Properties = propertyMap.Values.ToList(); + } + + /// + /// Returns the single key property in the model, and throws if there are multiple key properties. + /// Suitable for connectors where validation is in place for single keys only (). + /// + public VectorStoreRecordKeyPropertyModel KeyProperty => this._singleKeyProperty ??= this.KeyProperties.Single(); + + /// + /// Returns the single vector property in the model, and throws if there are multiple vector properties. + /// Suitable for connectors where validation is in place for single vectors only (). + /// + public VectorStoreRecordVectorPropertyModel VectorProperty => this._singleVectorProperty ??= this.VectorProperties.Single(); + + /// + /// Instantiates a new record of the specified type. + /// + // TODO: the pattern of first instantiating via parameterless constructor and then populating the properties isn't compatible + // with read-only types, where properties have no setters. Supporting those would be problematic given the that different + // connectors have completely different representations of the data coming back from the database, and which needs to be + // populated. + public TRecord CreateRecord() + { + Debug.Assert(typeof(TRecord) == this._recordType, "Type mismatch between record type and model type."); + + return Activator.CreateInstance() ?? throw new InvalidOperationException($"Failed to instantiate record of type '{typeof(TRecord).Name}'."); + } + + /// + /// Get the vector property with the provided name if a name is provided, and fall back + /// to a vector property in the schema if not. If no name is provided and there is more + /// than one vector property, an exception will be thrown. + /// + /// The search options. + /// Thrown if the provided property name is not a valid vector property name. + public VectorStoreRecordVectorPropertyModel GetVectorPropertyOrSingle(VectorSearchOptions searchOptions) + { +#pragma warning disable CS0618 // Type or member is obsolete + string? vectorPropertyName = searchOptions.VectorPropertyName; +#pragma warning restore CS0618 // Type or member is obsolete + + // If vector property name is provided, try to find it in schema or throw an exception. + if (!string.IsNullOrWhiteSpace(vectorPropertyName)) + { + // Check vector properties by data model property name. + return this.VectorProperties.FirstOrDefault(p => p.ModelName == vectorPropertyName) + ?? throw new InvalidOperationException($"The {this._recordType.FullName} type does not have a vector property named '{vectorPropertyName}'."); + } + else if (searchOptions.VectorProperty is Expression> expression) + { + return this.GetMatchingProperty(expression, data: false); + } + + // If vector property name is not provided, check if there is a single vector property, or throw if there are no vectors or more than one. + return this._singleVectorProperty ??= this.VectorProperties switch + { + [var singleProperty] => singleProperty, + { Count: 0 } => throw new InvalidOperationException($"The '{this._recordType.Name}' type does not have any vector properties."), + _ => throw new InvalidOperationException($"The '{this._recordType.Name}' type has multiple vector properties, please specify your chosen property via options.") + }; + } + + /// + /// Get the text data property, that has full text search indexing enabled, with the provided name if a name is provided, and fall back + /// to a text data property in the schema if not. If no name is provided and there is more than one text data property with + /// full text search indexing enabled, an exception will be thrown. + /// + /// The full text search property selector. + /// Thrown if the provided property name is not a valid text data property name. + public VectorStoreRecordDataPropertyModel GetFullTextDataPropertyOrSingle(Expression>? expression) + { + if (expression is not null) + { + var property = this.GetMatchingProperty(expression, data: true); + + return property.IsFullTextIndexed + ? property + : throw new InvalidOperationException($"The property '{property.ModelName}' on '{this._recordType.Name}' must have full text search indexing enabled."); + } + + if (this._singleFullTextSearchProperty is null) + { + // If text data property name is not provided, check if a single full text indexed text property exists or throw otherwise. + var fullTextStringProperties = this.DataProperties + .Where(l => l.Type == typeof(string) && l.IsFullTextIndexed) + .ToList(); + + // If text data property name is not provided, check if a single full text indexed text property exists or throw otherwise. + this._singleFullTextSearchProperty = fullTextStringProperties switch + { + [var singleProperty] => singleProperty, + { Count: 0 } => throw new InvalidOperationException($"The '{this._recordType.Name}' type does not have any text data properties that have full text indexing enabled."), + _ => throw new InvalidOperationException($"The '{this._recordType.Name}' type has multiple text data properties that have full text indexing enabled, please specify your chosen property via options.") + }; + } + + return this._singleFullTextSearchProperty; + } + + /// + /// Get the data or key property selected by provided expression. + /// + /// The property selector. + /// Thrown if the provided property name is not a valid data or key property name. + public VectorStoreRecordPropertyModel GetDataOrKeyProperty(Expression> expression) + => this.GetMatchingProperty(expression, data: true); + + private TProperty GetMatchingProperty(Expression> expression, bool data) + where TProperty : VectorStoreRecordPropertyModel + { + var node = expression.Body; + + // First, unwrap any object convert node: r => (object)r.PropertyName becomes r => r.PropertyName + if (expression.Body is UnaryExpression { NodeType: ExpressionType.Convert } convert + && convert.Type == typeof(object)) + { + node = convert.Operand; + } + + var propertyName = node switch + { + // Simple member expression over the lambda parameter (r => r.PropertyName) + MemberExpression { Member: PropertyInfo clrProperty } member when member.Expression == expression.Parameters[0] + => clrProperty.Name, + + // Dictionary access over the lambda parameter, in dynamic mapping (r => r["PropertyName"]) + MethodCallExpression { Method.Name: "get_Item", Arguments: [var keyExpression] } methodCall + => keyExpression switch + { + ConstantExpression { Value: string text } => text, + MemberExpression field when TryGetCapturedValue(field, out object? capturedValue) && capturedValue is string text => text, + _ => throw new InvalidOperationException("Invalid dictionary key expression") + }, + + _ => throw new InvalidOperationException("Property selector lambda is invalid") + }; + + if (!this.PropertyMap.TryGetValue(propertyName, out var property)) + { + throw new InvalidOperationException($"Property '{propertyName}' could not be found."); + } + + return property is TProperty typedProperty + ? typedProperty + : throw new InvalidOperationException($"Property '{propertyName}' isn't of type '{typeof(TProperty).Name}'."); + + static bool TryGetCapturedValue(Expression expression, out object? capturedValue) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + capturedValue = fieldInfo.GetValue(constant.Value); + return true; + } + + capturedValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordModelBuilder.cs b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordModelBuilder.cs new file mode 100644 index 000000000000..dfd68ffb6466 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordModelBuilder.cs @@ -0,0 +1,600 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData.Properties; + +namespace Microsoft.Extensions.VectorData.ConnectorSupport; + +/// +/// Represents a builder for a . +/// This is an internal support type meant for use by connectors only, and not for use by applications. +/// +/// Note that this class is single-use only, and not thread-safe. +[Experimental("MEVD9001")] +public class VectorStoreRecordModelBuilder +{ + /// + /// Options for building the model. + /// + protected VectorStoreRecordModelBuildingOptions Options { get; } + + /// + /// The key properties of the record. + /// + protected List KeyProperties { get; } = []; + + /// + /// The data properties of the record. + /// + protected List DataProperties { get; } = []; + + /// + /// The vector properties of the record. + /// + protected List VectorProperties { get; } = []; + + /// + /// All properties of the record, of all types. + /// + protected IEnumerable Properties => this.PropertyMap.Values; + + /// + /// All properties of the record, of all types, indexed by their model name. + /// + protected Dictionary PropertyMap { get; } = new(); + + /// + /// The default embedding generator to use for vector properties, when none is specified at the property or collection level. + /// + protected IEmbeddingGenerator? DefaultEmbeddingGenerator { get; private set; } + + /// + /// Constructs a new . + /// + public VectorStoreRecordModelBuilder(VectorStoreRecordModelBuildingOptions options) + { + if (options.SupportsMultipleKeys && options.ReservedKeyStorageName is not null) + { + throw new ArgumentException($"{nameof(VectorStoreRecordModelBuildingOptions.ReservedKeyStorageName)} cannot be set when {nameof(VectorStoreRecordModelBuildingOptions.SupportsMultipleKeys)} is set."); + } + + this.Options = options; + } + + /// + /// Builds and returns an from the given and . + /// + [RequiresDynamicCode("Currently not compatible with NativeAOT code")] + [RequiresUnreferencedCode("Currently not compatible with trimming")] // TODO + public virtual VectorStoreRecordModel Build(Type type, VectorStoreRecordDefinition? vectorStoreRecordDefinition, IEmbeddingGenerator? defaultEmbeddingGenerator) + { + this.DefaultEmbeddingGenerator = defaultEmbeddingGenerator; + + var dynamicMapping = type == typeof(Dictionary); + + if (!dynamicMapping) + { + this.ProcessTypeProperties(type, vectorStoreRecordDefinition); + } + + if (vectorStoreRecordDefinition is null) + { + if (dynamicMapping) + { + throw new ArgumentException("Vector store record definition must be provided for dynamic mapping."); + } + } + else + { + this.ProcessRecordDefinition(vectorStoreRecordDefinition, dynamicMapping ? null : type); + } + + this.Customize(); + this.Validate(type); + + return new(type, this.KeyProperties, this.DataProperties, this.VectorProperties, this.PropertyMap); + } + + /// + /// As part of building the model, this method processes the properties of the given , + /// detecting and reading attributes that affect the model. Not called for dynamic mapping scenarios. + /// + // TODO: This traverses the CLR type's properties, making it incompatible with trimming (and NativeAOT). + // TODO: We could put [DynamicallyAccessedMembers] to preserve all properties, but that approach wouldn't + // TODO: work with hierarchical data models (#10957). + [RequiresUnreferencedCode("Traverses the CLR type's properties with reflection, so not compatible with trimming")] + protected virtual void ProcessTypeProperties(Type type, VectorStoreRecordDefinition? vectorStoreRecordDefinition) + { + // We want to allow the user-provided record definition to override anything configured via attributes + // (allowing the same CLR type + attributes to be used with different record definitions). + foreach (var clrProperty in type.GetProperties()) + { + VectorStoreRecordPropertyModel? property = null; + string? storageName = null; + + if (clrProperty.GetCustomAttribute() is { } keyAttribute) + { + var keyProperty = new VectorStoreRecordKeyPropertyModel(clrProperty.Name, clrProperty.PropertyType); + this.KeyProperties.Add(keyProperty); + storageName = keyAttribute.StoragePropertyName; + property = keyProperty; + } + + if (clrProperty.GetCustomAttribute() is { } dataAttribute) + { + if (property is not null) + { + // TODO: Test + throw new InvalidOperationException($"Property '{type.Name}.{clrProperty.Name}' has multiple of {nameof(VectorStoreRecordKeyAttribute)}, {nameof(VectorStoreRecordDataAttribute)} or {nameof(VectorStoreRecordVectorAttribute)}. Only one of these attributes can be specified on a property."); + } + + var dataProperty = new VectorStoreRecordDataPropertyModel(clrProperty.Name, clrProperty.PropertyType) + { + IsIndexed = dataAttribute.IsIndexed, + IsFullTextIndexed = dataAttribute.IsFullTextIndexed, + }; + + this.DataProperties.Add(dataProperty); + storageName = dataAttribute.StoragePropertyName; + property = dataProperty; + } + + if (clrProperty.GetCustomAttribute() is { } vectorAttribute) + { + if (property is not null) + { + throw new InvalidOperationException($"Only one of {nameof(VectorStoreRecordKeyAttribute)}, {nameof(VectorStoreRecordDataAttribute)} and {nameof(VectorStoreRecordVectorAttribute)} can be applied to a property."); + } + + // If a record definition exists for the property, we must instantiate it via that definition, as the user may be using + // a generic VectorStoreRecordVectorProperty for a custom input type. + var vectorProperty = vectorStoreRecordDefinition?.Properties.FirstOrDefault(p => p.DataModelPropertyName == clrProperty.Name) is VectorStoreRecordVectorProperty definitionVectorProperty + ? definitionVectorProperty.CreatePropertyModel() + : new VectorStoreRecordVectorPropertyModel(clrProperty.Name, clrProperty.PropertyType); + + vectorProperty.Dimensions = vectorAttribute.Dimensions; + vectorProperty.IndexKind = vectorAttribute.IndexKind; + vectorProperty.DistanceFunction = vectorAttribute.DistanceFunction; + + // If a default embedding generator is defined and the property type isn't an Embedding, we set up that embedding generator on the property. + // At this point we don't know the embedding type (it might get specified in the record definition, that's processed later). So we infer + // + // This also means that the property type is the input type (e.g. string, DataContent) rather than the embedding type. + // Since we need the property type to be the embedding type, we infer that from the generator. This allows users + // to just stick an IEmbeddingGenerator in DI, define a string property as their vector property, and as long as the embedding generator + // is compatible (supports string and ROM, assuming that's what the connector requires), everything just works. + // Note that inferring the embedding type from the IEmbeddingGenerator isn't trivial, involving both connector logic (around which embedding + // types are supported/preferred), as well as the vector property type (which knows about supported input types). + + if (this.DefaultEmbeddingGenerator is null || this.Options.SupportedVectorPropertyTypes.Contains(clrProperty.PropertyType)) + { + vectorProperty.EmbeddingType = clrProperty.PropertyType; + } + else + { + this.SetupEmbeddingGeneration(vectorProperty, this.DefaultEmbeddingGenerator, embeddingType: null); + } + + this.VectorProperties.Add(vectorProperty); + storageName = vectorAttribute.StoragePropertyName; + property = vectorProperty; + } + + if (property is null) + { + // No mapping attribute was found, ignore this property. + continue; + } + + this.SetPropertyStorageName(property, storageName, type); + + property.PropertyInfo = clrProperty; + this.PropertyMap.Add(clrProperty.Name, property); + } + } + + /// + /// As part of building the model, this method processes the given . + /// + protected virtual void ProcessRecordDefinition( + VectorStoreRecordDefinition vectorStoreRecordDefinition, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] Type? type) + { + foreach (VectorStoreRecordProperty definitionProperty in vectorStoreRecordDefinition.Properties) + { + if (!this.PropertyMap.TryGetValue(definitionProperty.DataModelPropertyName, out var property)) + { + // Property wasn't found attribute-annotated on the CLR type, so we need to add it. + + // TODO: Make the property CLR type optional - no need to specify it when using a CLR type. + switch (definitionProperty) + { + case VectorStoreRecordKeyProperty definitionKeyProperty: + var keyProperty = new VectorStoreRecordKeyPropertyModel(definitionKeyProperty.DataModelPropertyName, definitionKeyProperty.PropertyType); + this.KeyProperties.Add(keyProperty); + this.PropertyMap.Add(definitionKeyProperty.DataModelPropertyName, keyProperty); + property = keyProperty; + break; + case VectorStoreRecordDataProperty definitionDataProperty: + var dataProperty = new VectorStoreRecordDataPropertyModel(definitionDataProperty.DataModelPropertyName, definitionDataProperty.PropertyType); + this.DataProperties.Add(dataProperty); + this.PropertyMap.Add(definitionDataProperty.DataModelPropertyName, dataProperty); + property = dataProperty; + break; + case VectorStoreRecordVectorProperty definitionVectorProperty: + var vectorProperty = definitionVectorProperty.CreatePropertyModel(); + this.VectorProperties.Add(vectorProperty); + this.PropertyMap.Add(definitionVectorProperty.DataModelPropertyName, vectorProperty); + property = vectorProperty; + break; + default: + throw new ArgumentException($"Unknown type '{definitionProperty.GetType().FullName}' in vector store record definition."); + } + + if (type is not null) + { + // If we have a CLR type (POCO, not dynamic mapping), get the .NET property's type and make sure it matches the definition. + property.PropertyInfo = type.GetProperty(property.ModelName) + ?? throw new InvalidOperationException($"Property '{property.ModelName}' not found on CLR type '{type.FullName}'."); + + if (property.PropertyInfo.PropertyType != property.Type) + { + throw new InvalidOperationException($"Property '{property.ModelName}' has a different CLR type in the record definition ('{property.Type.Name}') and on the .NET property ('{property.PropertyInfo.PropertyType}')."); + } + } + } + + property.Type = definitionProperty.PropertyType; + this.SetPropertyStorageName(property, definitionProperty.StoragePropertyName, type); + + switch (definitionProperty) + { + case VectorStoreRecordKeyProperty definitionKeyProperty: + if (property is not VectorStoreRecordKeyPropertyModel keyPropertyModel) + { + throw new InvalidOperationException( + $"Property '{property.ModelName}' is present in the {nameof(VectorStoreRecordDefinition)} as a key property, but the .NET property on type '{type?.Name}' has an incompatible attribute."); + } + + break; + + case VectorStoreRecordDataProperty definitionDataProperty: + if (property is not VectorStoreRecordDataPropertyModel dataProperty) + { + throw new InvalidOperationException( + $"Property '{property.ModelName}' is present in the {nameof(VectorStoreRecordDefinition)} as a data property, but the .NET property on type '{type?.Name}' has an incompatible attribute."); + } + + dataProperty.IsIndexed = definitionDataProperty.IsIndexed; + dataProperty.IsFullTextIndexed = definitionDataProperty.IsFullTextIndexed; + + break; + + case VectorStoreRecordVectorProperty definitionVectorProperty: + if (property is not VectorStoreRecordVectorPropertyModel vectorProperty) + { + throw new InvalidOperationException( + $"Property '{property.ModelName}' is present in the {nameof(VectorStoreRecordDefinition)} as a vector property, but the .NET property on type '{type?.Name}' has an incompatible attribute."); + } + + vectorProperty.Dimensions = definitionVectorProperty.Dimensions; + + if (definitionVectorProperty.IndexKind is not null) + { + vectorProperty.IndexKind = definitionVectorProperty.IndexKind; + } + + if (definitionVectorProperty.DistanceFunction is not null) + { + vectorProperty.DistanceFunction = definitionVectorProperty.DistanceFunction; + } + + if (definitionVectorProperty.EmbeddingType is not null) + { + vectorProperty.EmbeddingType = definitionVectorProperty.EmbeddingType; + } + + // Check if embedding generation is configured, either on the property directly or via a default + IEmbeddingGenerator? embeddingGenerator = null; + + // Check if an embedding generator is defined specifically on the property. + if (definitionVectorProperty.EmbeddingGenerator is not null) + { + // If we have a property CLR type (POCO, not dynamic mapping) and it's an embedding type, throw as that's incompatible. + if (this.Options.SupportedVectorPropertyTypes.Contains(property.Type)) + { + throw new InvalidOperationException( + string.Format( + VectorDataStrings.EmbeddingPropertyTypeIncompatibleWithEmbeddingGenerator, + property.ModelName, + property.Type.Name)); + } + + embeddingGenerator = definitionVectorProperty.EmbeddingGenerator; + } + // If a default embedding generator is defined (at the collection or store level), configure that on the property, but only if the property type is not an embedding type. + // If the property type is an embedding type, just ignore the default embedding generator. + else if ((vectorStoreRecordDefinition.EmbeddingGenerator ?? this.DefaultEmbeddingGenerator) is IEmbeddingGenerator defaultEmbeddingGenerator + && !this.Options.SupportedVectorPropertyTypes.Contains(property.Type)) + { + embeddingGenerator = vectorStoreRecordDefinition.EmbeddingGenerator ?? this.DefaultEmbeddingGenerator; + } + + if (embeddingGenerator is null) + { + // No embedding generation - the embedding type and the property (model) type are the same. + vectorProperty.EmbeddingType = property.Type; + } + else + { + this.SetupEmbeddingGeneration(vectorProperty, embeddingGenerator, vectorProperty.EmbeddingType); + } + break; + + default: + throw new ArgumentException($"Unknown type '{definitionProperty.GetType().FullName}' in vector store record definition."); + } + } + } + + private void SetPropertyStorageName(VectorStoreRecordPropertyModel property, string? storageName, Type? type) + { + if (property is VectorStoreRecordKeyPropertyModel && this.Options.ReservedKeyStorageName is not null) + { + // If we have ReservedKeyStorageName, there can only be a single key property (validated in the constructor) + property.StorageName = this.Options.ReservedKeyStorageName; + return; + } + + if (storageName is null) + { + return; + } + + // If a custom serializer is used (e.g. JsonSerializer), it would ignore our own attributes/config, and + // our model needs to be in sync with the serializer's behavior (for e.g. storage names in filters). + // So we ignore the config here as well. + // TODO: Consider throwing here instead of ignoring + if (this.Options.UsesExternalSerializer && type != null) + { + return; + } + + property.StorageName = this.Options.EscapeIdentifier is not null + ? this.Options.EscapeIdentifier(storageName) + : storageName; + } + + /// + /// Attempts to setup embedding generation on the given vector property, with the given embedding generator and user-configured embedding type. + /// Can be overridden by connectors to provide support for other embedding types. + /// + protected virtual void SetupEmbeddingGeneration( + VectorStoreRecordVectorPropertyModel vectorProperty, + IEmbeddingGenerator embeddingGenerator, + Type? embeddingType) + { + if (!vectorProperty.TrySetupEmbeddingGeneration, ReadOnlyMemory>(embeddingGenerator, embeddingType)) + { + throw new InvalidOperationException( + string.Format( + VectorDataStrings.IncompatibleEmbeddingGenerator, + embeddingGenerator.GetType().Name, + string.Join(", ", vectorProperty.GetSupportedInputTypes().Select(t => t.Name)), + "ReadOnlyMemory")); + } + } + + /// + /// Extension hook for connectors to be able to customize the model. + /// + protected virtual void Customize() + { + } + + /// + /// Validates the model after all properties have been processed. + /// + protected virtual void Validate(Type type) + { + if (!this.Options.UsesExternalSerializer && type.GetConstructor(Type.EmptyTypes) is null) + { + throw new NotSupportedException($"Type '{type.Name}' must have a parameterless constructor."); + } + + if (!this.Options.SupportsMultipleKeys && this.KeyProperties.Count > 1) + { + throw new NotSupportedException($"Multiple key properties found on type '{type.Name}' or the provided {nameof(VectorStoreRecordDefinition)} while only one is supported."); + } + + if (this.KeyProperties.Count == 0) + { + throw new NotSupportedException($"No key property found on type '{type.Name}' or the provided {nameof(VectorStoreRecordDefinition)} while at least one is required."); + } + + if (this.Options.RequiresAtLeastOneVector && this.VectorProperties.Count == 0) + { + throw new NotSupportedException($"No vector property found on type '{type.Name}' or the provided {nameof(VectorStoreRecordDefinition)} while at least one is required."); + } + + if (!this.Options.SupportsMultipleVectors && this.VectorProperties.Count > 1) + { + throw new NotSupportedException($"Multiple vector properties found on type '{type.Name}' or the provided {nameof(VectorStoreRecordDefinition)} while only one is supported."); + } + + var storageNameMap = new Dictionary(); + + foreach (var property in this.PropertyMap.Values) + { + this.ValidateProperty(property); + + if (storageNameMap.TryGetValue(property.StorageName, out var otherproperty)) + { + throw new InvalidOperationException($"Property '{property.ModelName}' is being mapped to storage name '{property.StorageName}', but property '{otherproperty.ModelName}' is already mapped to the same storage name."); + } + + storageNameMap[property.StorageName] = property; + } + } + + /// + /// Validates a single property, performing validation on it. + /// + protected virtual void ValidateProperty(VectorStoreRecordPropertyModel propertyModel) + { + var type = propertyModel.Type; + + Debug.Assert(propertyModel.Type is not null); + + if (type.IsGenericType && Nullable.GetUnderlyingType(type) is Type underlyingType) + { + type = underlyingType; + } + + switch (propertyModel) + { + case VectorStoreRecordKeyPropertyModel keyProperty: + if (this.Options.SupportedKeyPropertyTypes is not null) + { + ValidatePropertyType(propertyModel.ModelName, type, "Key", this.Options.SupportedKeyPropertyTypes); + } + break; + + case VectorStoreRecordDataPropertyModel dataProperty: + if (this.Options.SupportedDataPropertyTypes is not null) + { + ValidatePropertyType(propertyModel.ModelName, type, "Data", this.Options.SupportedDataPropertyTypes, this.Options.SupportedEnumerableDataPropertyElementTypes); + } + break; + + case VectorStoreRecordVectorPropertyModel vectorProperty: + Debug.Assert(vectorProperty.EmbeddingGenerator is null ^ vectorProperty.Type != vectorProperty.EmbeddingType); + + if (!this.Options.SupportedVectorPropertyTypes.Contains(vectorProperty.EmbeddingType)) + { + throw new InvalidOperationException( + vectorProperty.EmbeddingGenerator is null + ? string.Format(VectorDataStrings.NonEmbeddingVectorPropertyWithoutEmbeddingGenerator, vectorProperty.ModelName, vectorProperty.EmbeddingType.Name) + : string.Format(VectorDataStrings.EmbeddingGeneratorWithInvalidEmbeddingType, vectorProperty.ModelName, vectorProperty.EmbeddingType.Name)); + } + + if (vectorProperty.Dimensions <= 0) + { + throw new InvalidOperationException($"Vector property '{propertyModel.ModelName}' must have a positive number of dimensions."); + } + + break; + + default: + throw new UnreachableException(); + } + } + + private static void ValidatePropertyType(string propertyName, Type propertyType, string propertyCategoryDescription, HashSet supportedTypes, HashSet? supportedEnumerableElementTypes = null) + { + // Add shortcut before testing all the more expensive scenarios. + if (supportedTypes.Contains(propertyType)) + { + return; + } + + // Check all collection scenarios and get stored type. + if (supportedEnumerableElementTypes?.Count > 0 && IsSupportedEnumerableType(propertyType)) + { + var typeToCheck = GetCollectionElementType(propertyType); + + if (!supportedEnumerableElementTypes.Contains(typeToCheck)) + { + var supportedEnumerableElementTypesString = string.Join(", ", supportedEnumerableElementTypes!.Select(t => t.FullName)); + throw new NotSupportedException($"Enumerable {propertyCategoryDescription} properties must have one of the supported element types: {supportedEnumerableElementTypesString}. Element type of the property '{propertyName}' is {typeToCheck.FullName}."); + } + } + else + { + // if we got here, we know the type is not supported + var supportedTypesString = string.Join(", ", supportedTypes.Select(t => t.FullName)); + var supportedEnumerableTypesString = supportedEnumerableElementTypes is { Count: > 0 } ? string.Join(", ", supportedEnumerableElementTypes.Select(t => t.FullName)) : null; + throw new NotSupportedException($""" + Property '{propertyName}' has unsupported type '{propertyType.Name}'. + {propertyCategoryDescription} properties must be one of the supported types: {supportedTypesString}{(supportedEnumerableElementTypes is null ? "" : ", or a collection type over: " + supportedEnumerableElementTypes)}. + """); + } + } + + private static bool IsSupportedEnumerableType(Type type) + { + if (type.IsArray || type == typeof(IEnumerable)) + { + return true; + } + +#if NET6_0_OR_GREATER + if (typeof(IList).IsAssignableFrom(type) && type.GetMemberWithSameMetadataDefinitionAs(s_objectGetDefaultConstructorInfo) != null) +#else + if (typeof(IList).IsAssignableFrom(type) && type.GetConstructor(Type.EmptyTypes) != null) +#endif + { + return true; + } + + if (type.IsGenericType) + { + var genericTypeDefinition = type.GetGenericTypeDefinition(); + if (genericTypeDefinition == typeof(ICollection<>) || + genericTypeDefinition == typeof(IEnumerable<>) || + genericTypeDefinition == typeof(IList<>) || + genericTypeDefinition == typeof(IReadOnlyCollection<>) || + genericTypeDefinition == typeof(IReadOnlyList<>)) + { + return true; + } + } + + return false; + } + + private static Type GetCollectionElementType(Type collectionType) + { + return collectionType switch + { + IEnumerable => typeof(object), + var enumerableType when GetGenericEnumerableInterface(enumerableType) is Type enumerableInterface => enumerableInterface.GetGenericArguments()[0], + var arrayType when arrayType.IsArray => arrayType.GetElementType()!, + _ => collectionType + }; + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern", + Justification = "The 'IEnumerable<>' Type must exist and so trimmer kept it. In which case " + + "It also kept it on any type which implements it. The below call to GetInterfaces " + + "may return fewer results when trimmed but it will return 'IEnumerable<>' " + + "if the type implemented it, even after trimming.")] + private static Type? GetGenericEnumerableInterface(Type type) + { + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>)) + { + return type; + } + + foreach (Type typeToCheck in type.GetInterfaces()) + { + if (typeToCheck.IsGenericType && typeToCheck.GetGenericTypeDefinition() == typeof(IEnumerable<>)) + { + return typeToCheck; + } + } + + return null; + } + +#if NET6_0_OR_GREATER + private static readonly ConstructorInfo s_objectGetDefaultConstructorInfo = typeof(object).GetConstructor(Type.EmptyTypes)!; +#endif +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordModelBuildingOptions.cs b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordModelBuildingOptions.cs new file mode 100644 index 000000000000..958d24241537 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordModelBuildingOptions.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.VectorData.ConnectorSupport; + +/// +/// Contains options affecting model building; passed to . +/// This is an internal support type meant for use by connectors only, and not for use by applications. +/// +[Experimental("MEVD9001")] +public sealed class VectorStoreRecordModelBuildingOptions +{ + /// + /// Whether multiple key properties are supported. + /// + public required bool SupportsMultipleKeys { get; init; } + + /// + /// Whether multiple vector properties are supported. + /// + public required bool SupportsMultipleVectors { get; init; } + + /// + /// Whether at least one vector property is required. + /// + public required bool RequiresAtLeastOneVector { get; init; } + + /// + /// The set of types that are supported as key properties. + /// + public required HashSet? SupportedKeyPropertyTypes { get; init; } + + /// + /// The set of types that are supported as data properties. + /// + public required HashSet? SupportedDataPropertyTypes { get; init; } + + /// + /// The set of element types that are supported within collection types in data properties. + /// + public required HashSet? SupportedEnumerableDataPropertyElementTypes { get; init; } + + /// + /// The set of types that are supported as vector properties. + /// + public required HashSet SupportedVectorPropertyTypes { get; init; } + + /// + /// Indicates that an external serializer will be used (e.g. System.Text.Json). + /// + public bool UsesExternalSerializer { get; init; } + + /// + /// Indicates that the database requires the key property to have a special, reserved name. + /// When set, the model builder will manage the key storage name, and users may not customize it. + /// + public string? ReservedKeyStorageName { get; init; } + + /// + /// A method for escaping storage names. + /// + public Func? EscapeIdentifier { get; init; } +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordPropertyModel.cs b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordPropertyModel.cs new file mode 100644 index 000000000000..4d9534c0d979 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordPropertyModel.cs @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace Microsoft.Extensions.VectorData.ConnectorSupport; + +/// +/// Represents a property on a vector store record. +/// This is an internal support type meant for use by connectors only, and not for use by applications. +/// +[Experimental("MEVD9001")] +public abstract class VectorStoreRecordPropertyModel(string modelName, Type type) +{ + private string? _storageName; + + /// + /// The model name of the property. If the property corresponds to a .NET property, this name is the name of that property. + /// + public string ModelName { get; set; } = modelName; + + /// + /// The storage name of the property. This is the name to which the property is mapped in the vector store. + /// + public string StorageName + { + get => this._storageName ?? this.ModelName; + set => this._storageName = value; + } + + // See comment in VectorStoreJsonModelBuilder + // TODO: Spend more time thinking about this, there may be a less hacky way to handle it. + + /// + /// A temporary storage name for the property, for use during the serialization process by certain connectors. + /// + [Experimental("MEVD9001")] + public string? TemporaryStorageName { get; set; } + + /// + /// The CLR type of the property. + /// + public Type Type { get; set; } = type; + + /// + /// The reflection for the .NET property. + /// when using dynamic mapping. + /// + public PropertyInfo? PropertyInfo { get; set; } + + /// + /// Reads the property from the given , returning the value as an . + /// + public virtual object? GetValueAsObject(object record) + { + if (this.PropertyInfo is null) + { + if (record is Dictionary dictionary) + { + return dictionary.TryGetValue(this.ModelName, out var value) + ? value + : null; + } + + throw new UnreachableException("Non-dynamic mapping but PropertyInfo is null."); + } + + // We have a CLR property (non-dynamic POCO mapping) + + // TODO: Implement compiled delegates for better performance, #11122 + // TODO: Implement source-generated accessors for NativeAOT, #10256 + + return this.PropertyInfo.GetValue(record); + } + + /// + /// Writes the property from the given , accepting the value to write as an . + /// s + public virtual void SetValueAsObject(object record, object? value) + { + if (this.PropertyInfo is null) + { + if (record.GetType() == typeof(Dictionary)) + { + var dictionary = (Dictionary)record; + dictionary[this.ModelName] = value; + return; + } + + throw new UnreachableException("Non-dynamic mapping but ClrProperty is null."); + } + + // We have a CLR property (non-dynamic POCO mapping) + + // TODO: Implement compiled delegates for better performance, #11122 + // TODO: Implement source-generated accessors for NativeAOT, #10256 + + // If the value is null, no need to set the property (it's the CLR default) + if (value is not null) + { + this.PropertyInfo.SetValue(record, value); + } + } + + // TODO: implement the generic accessors to avoid boxing, and make use of them in connectors +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordVectorPropertyModel.cs b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordVectorPropertyModel.cs new file mode 100644 index 000000000000..09a557664753 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordVectorPropertyModel.cs @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Microsoft.Extensions.VectorData.ConnectorSupport; + +/// +/// Represents a vector property on a vector store record. +/// This is an internal support type meant for use by connectors only, and not for use by applications. +/// +[Experimental("MEVD9001")] +public class VectorStoreRecordVectorPropertyModel(string modelName, Type type) : VectorStoreRecordPropertyModel(modelName, type) +{ + private int _dimensions; + + /// + /// The number of dimensions that the vector has. + /// + /// + /// This property is required when creating collections, but can be omitted if not using that functionality. + /// If not provided when trying to create a collection, create will fail. + /// + public int Dimensions + { + get => this._dimensions; + + set + { + if (value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "Dimensions must be greater than zero."); + } + + this._dimensions = value; + } + } + + /// + /// The kind of index to use. + /// + /// + /// The default varies by database type. See the documentation of your chosen database connector for more information. + /// + /// + public string? IndexKind { get; set; } + + /// + /// The distance function to use when comparing vectors. + /// + /// + /// The default varies by database type. See the documentation of your chosen database connector for more information. + /// + /// + public string? DistanceFunction { get; set; } + + /// + /// If is set, contains the type representing the embedding stored in the database. + /// Otherwise, this property is identical to . + /// + public Type EmbeddingType { get; set; } = null!; + + /// + /// The embedding generator to use for this property. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; set; } + + /// + /// Checks whether the configured on this property supports the given embedding type. + /// The implementation on this non-generic checks for + /// and as input types for . + /// + public virtual bool TrySetupEmbeddingGeneration(IEmbeddingGenerator embeddingGenerator, Type? embeddingType) + where TEmbedding : Embedding + { + // On the TInput side, this out-of-the-box/simple implementation supports string and DataContent only + // (users who want arbitrary TInput types need to use the generic subclass of this type). + // The TEmbedding side is provided by the connector via the generic type parameter to this method, as the connector controls/knows which embedding types are supported. + // Note that if the user has manually specified an embedding type (e.g. to choose Embedding rather than the default Embedding), that's provided via the embeddingType argument; + // we use that as a filter below. + switch (embeddingGenerator) + { + case IEmbeddingGenerator when this.Type == typeof(string) && (embeddingType is null || embeddingType == typeof(TUnwrappedEmbedding)): + case IEmbeddingGenerator when this.Type == typeof(DataContent) && (embeddingType is null || embeddingType == typeof(TUnwrappedEmbedding)): + this.EmbeddingGenerator = embeddingGenerator; + this.EmbeddingType = embeddingType ?? typeof(TUnwrappedEmbedding); + + return true; + + case null: + throw new UnreachableException("This method should only be called when an embedding generator is configured."); + default: + return false; + } + } + + /// + /// Attempts to generate an embedding of type from the vector property represented by this instance on the given , using + /// the configured . + /// + /// + /// + /// If supports the given , returns and sets to a + /// representing the embedding generation operation. If does not support the given , returns . + /// + /// + /// The implementation on this non-generic checks for + /// and as input types for . + /// + /// + public virtual bool TryGenerateEmbedding(TRecord record, CancellationToken cancellationToken, [NotNullWhen(true)] out Task? task) + where TRecord : notnull + where TEmbedding : Embedding + { + switch (this.EmbeddingGenerator) + { + case IEmbeddingGenerator generator when this.EmbeddingType == typeof(TUnwrappedEmbedding): + { + task = generator.GenerateEmbeddingAsync( + this.GetValueAsObject(record) is var value && value is string s + ? s + : throw new InvalidOperationException($"Property '{this.ModelName}' was configured with an embedding generator accepting a string, but {value?.GetType().Name ?? "null"} was provided."), + new() { Dimensions = this.Dimensions }, + cancellationToken); + return true; + } + + case IEmbeddingGenerator generator when this.EmbeddingType == typeof(TUnwrappedEmbedding): + { + task = generator.GenerateEmbeddingAsync( + this.GetValueAsObject(record) is var value && value is DataContent c + ? c + : throw new InvalidOperationException($"Property '{this.ModelName}' was configured with an embedding generator accepting a {nameof(DataContent)}, but {value?.GetType().Name ?? "null"} was provided."), + new() { Dimensions = this.Dimensions }, + cancellationToken); + return true; + } + + case null: + throw new UnreachableException("This method should only be called when an embedding generator is configured."); + + default: + task = null; + return false; + } + } + + /// + /// Attempts to generate embeddings of type from the vector property represented by this instance on the given , using + /// the configured . + /// + /// + /// + /// If supports the given , returns and sets to a + /// representing the embedding generation operation. If does not support the given , returns . + /// + /// + /// The implementation on this non-generic checks for + /// and as input types for . + /// + /// + public virtual bool TryGenerateEmbeddings(IEnumerable records, CancellationToken cancellationToken, [NotNullWhen(true)] out Task>? task) + where TRecord : notnull + where TEmbedding : Embedding + { + switch (this.EmbeddingGenerator) + { + case IEmbeddingGenerator generator when this.EmbeddingType == typeof(TUnwrappedEmbedding): + task = generator.GenerateAsync( + records.Select(r => this.GetValueAsObject(r) is var value && value is string s + ? s + : throw new InvalidOperationException($"Property '{this.ModelName}' was configured with an embedding generator accepting a string, but {value?.GetType().Name ?? "null"} was provided.")), + new() { Dimensions = this.Dimensions }, cancellationToken); + return true; + + case IEmbeddingGenerator generator when this.EmbeddingType == typeof(TUnwrappedEmbedding): + task = generator.GenerateAsync( + records.Select(r => this.GetValueAsObject(r) is var value && value is DataContent c + ? c + : throw new InvalidOperationException($"Property '{this.ModelName}' was configured with an embedding generator accepting a {nameof(DataContent)}, but {value?.GetType().Name ?? "null"} was provided.")), + new() { Dimensions = this.Dimensions }, cancellationToken); + return true; + + case null: + throw new UnreachableException("This method should only be called when an embedding generator is configured."); + + default: + task = null; + return false; + } + } + + /// + /// Returns the types of input that this property model supports. + /// + public virtual Type[] GetSupportedInputTypes() => [typeof(string), typeof(DataContent)]; + + /// + public override string ToString() + => $"{this.ModelName} (Vector, {this.Type.Name})"; +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordVectorPropertyModel{TInput}.cs b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordVectorPropertyModel{TInput}.cs new file mode 100644 index 000000000000..fdcb56c43e14 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/ConnectorSupport/VectorStoreRecordVectorPropertyModel{TInput}.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Microsoft.Extensions.VectorData.ConnectorSupport; + +/// +[Experimental("MEVD9001")] +public sealed class VectorStoreRecordVectorPropertyModel(string modelName) : VectorStoreRecordVectorPropertyModel(modelName, typeof(TInput)) +{ + /// + public override bool TrySetupEmbeddingGeneration(IEmbeddingGenerator embeddingGenerator, Type? embeddingType) + { + switch (embeddingGenerator) + { + case IEmbeddingGenerator when this.Type == typeof(TInput) && (embeddingType is null || embeddingType == typeof(TUnwrappedEmbedding)): + this.EmbeddingGenerator = embeddingGenerator; + this.EmbeddingType = embeddingType ?? typeof(TUnwrappedEmbedding); + + return true; + + case null: + throw new UnreachableException("This method should only be called when an embedding generator is configured."); + default: + return false; + } + } + + /// + public override bool TryGenerateEmbedding(TRecord record, CancellationToken cancellationToken, [NotNullWhen(true)] out Task? task) + { + switch (this.EmbeddingGenerator) + { + case IEmbeddingGenerator generator when this.EmbeddingType == typeof(TUnwrappedEmbedding): + task = generator.GenerateEmbeddingAsync( + this.GetValueAsObject(record) is var value && value is TInput s + ? s + : throw new InvalidOperationException($"Property '{this.ModelName}' was configured with an embedding generator accepting a {nameof(TInput)}, but {value?.GetType().Name ?? "null"} was provided."), + new() { Dimensions = this.Dimensions }, + cancellationToken); + return true; + + case null: + throw new UnreachableException("This method should only be called when an embedding generator is configured."); + + default: + task = null; + return false; + } + } + + /// + public override bool TryGenerateEmbeddings(IEnumerable records, CancellationToken cancellationToken, [NotNullWhen(true)] out Task>? task) + { + switch (this.EmbeddingGenerator) + { + case IEmbeddingGenerator generator when this.EmbeddingType == typeof(TUnwrappedEmbedding): + task = generator.GenerateAsync( + records.Select(r => this.GetValueAsObject(r) is var value && value is TInput s + ? s + : throw new InvalidOperationException($"Property '{this.ModelName}' was configured with an embedding generator accepting a string, but {value?.GetType().Name ?? "null"} was provided.")), + new() { Dimensions = this.Dimensions }, cancellationToken); + return true; + + case null: + throw new UnreachableException("This method should only be called when an embedding generator is configured."); + + default: + task = null; + return false; + } + } + + /// + public override Type[] GetSupportedInputTypes() => [typeof(TInput)]; +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/PACKAGE.md b/dotnet/src/Connectors/VectorData.Abstractions/PACKAGE.md index df87cb1b8586..da8da52c5eb5 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/PACKAGE.md +++ b/dotnet/src/Connectors/VectorData.Abstractions/PACKAGE.md @@ -4,7 +4,7 @@ Contains abstractions for accessing Vector Databases and Vector Indexes. ## Key Features -- Interfaces for Vector Database implementations which are provided in other packages including `Microsoft.SemanticKernel.Connectors.AzureAISearch`. +- Interfaces for Vector Database implementation. Vector Database implementations are provided separately in other packages, for example `Microsoft.SemanticKernel.Connectors.AzureAISearch`. ## How to Use @@ -22,14 +22,25 @@ The main types provided by this library are: ## Related Packages +Vector Database utilities: + +- `Microsoft.Extensions.VectorData` + +Vector Database implementations: + - `Microsoft.SemanticKernel.Connectors.AzureAISearch` - `Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB` - `Microsoft.SemanticKernel.Connectors.AzureCosmosNoSQL` +- `Microsoft.SemanticKernel.Connectors.InMemory` +- `Microsoft.SemanticKernel.Connectors.MongoDB` - `Microsoft.SemanticKernel.Connectors.Pinecone` +- `Microsoft.SemanticKernel.Connectors.Postgres` - `Microsoft.SemanticKernel.Connectors.Qdrant` - `Microsoft.SemanticKernel.Connectors.Redis` +- `Microsoft.SemanticKernel.Connectors.Sqlite` +- `Microsoft.SemanticKernel.Connectors.SqlServer` - `Microsoft.SemanticKernel.Connectors.Weaviate` ## Feedback & Contributing -Microsoft.Extensions.DependencyInjection.Abstractions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/microsoft/semantic-kernel). +Microsoft.Extensions.VectorData.Abstractions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/microsoft/semantic-kernel). diff --git a/dotnet/src/Connectors/VectorData.Abstractions/Properties/AssemblyInfo.cs b/dotnet/src/Connectors/VectorData.Abstractions/Properties/AssemblyInfo.cs new file mode 100644 index 000000000000..09647faa37af --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/Properties/AssemblyInfo.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +[assembly: System.Resources.NeutralResourcesLanguage("en-US")] diff --git a/dotnet/src/Connectors/VectorData.Abstractions/Properties/VectorDataStrings.Designer.cs b/dotnet/src/Connectors/VectorData.Abstractions/Properties/VectorDataStrings.Designer.cs new file mode 100644 index 000000000000..74eb4f806842 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/Properties/VectorDataStrings.Designer.cs @@ -0,0 +1,121 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.Extensions.VectorData.Properties +{ + using System; + + + [System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] + [System.Diagnostics.DebuggerNonUserCodeAttribute()] + [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + public class VectorDataStrings + { + + private static System.Resources.ResourceManager resourceMan; + + private static System.Globalization.CultureInfo resourceCulture; + + [System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal VectorDataStrings() + { + } + + [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] + public static System.Resources.ResourceManager ResourceManager + { + get + { + if (object.Equals(null, resourceMan)) + { + System.Resources.ResourceManager temp = new System.Resources.ResourceManager("Microsoft.Extensions.VectorData.Properties.VectorDataStrings", typeof(VectorDataStrings).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] + public static System.Globalization.CultureInfo Culture + { + get + { + return resourceCulture; + } + set + { + resourceCulture = value; + } + } + + public static string IncompatibleEmbeddingGenerator + { + get + { + return ResourceManager.GetString("IncompatibleEmbeddingGenerator", resourceCulture); + } + } + + public static string IncompatibleEmbeddingGeneratorWasConfiguredForInputType + { + get + { + return ResourceManager.GetString("IncompatibleEmbeddingGeneratorWasConfiguredForInputType", resourceCulture); + } + } + + public static string NoEmbeddingGeneratorWasConfiguredForSearch + { + get + { + return ResourceManager.GetString("NoEmbeddingGeneratorWasConfiguredForSearch", resourceCulture); + } + } + + public static string NonEmbeddingVectorPropertyWithoutEmbeddingGenerator + { + get + { + return ResourceManager.GetString("NonEmbeddingVectorPropertyWithoutEmbeddingGenerator", resourceCulture); + } + } + + public static string EmbeddingTypePassedToSearchAsync + { + get + { + return ResourceManager.GetString("EmbeddingTypePassedToSearchAsync", resourceCulture); + } + } + + public static string EmbeddingPropertyTypeIncompatibleWithEmbeddingGenerator + { + get + { + return ResourceManager.GetString("EmbeddingPropertyTypeIncompatibleWithEmbeddingGenerator", resourceCulture); + } + } + + public static string IncludeVectorsNotSupportedWithEmbeddingGeneration + { + get + { + return ResourceManager.GetString("IncludeVectorsNotSupportedWithEmbeddingGeneration", resourceCulture); + } + } + + public static string EmbeddingGeneratorWithInvalidEmbeddingType + { + get + { + return ResourceManager.GetString("EmbeddingGeneratorWithInvalidEmbeddingType", resourceCulture); + } + } + } +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/Properties/VectorDataStrings.resx b/dotnet/src/Connectors/VectorData.Abstractions/Properties/VectorDataStrings.resx new file mode 100644 index 000000000000..531eb0159b9e --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/Properties/VectorDataStrings.resx @@ -0,0 +1,45 @@ + + + + + + + + + + text/microsoft-resx + + + 1.3 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Embedding generator '{0}' is incompatible with the required input and output types. The property input type must be '{1}', and the output type must be '{2}'. + + + An input of type '{0}' was provided, but an incompatible embedding generator of type '{1}' was configured. + + + 'SearchAsync' requires an embedding generator to be configured. To pass an embedding directly, use 'SearchEmbeddingAsync', otherwise configure an embedding generator with your vector store connector. + + + Property '{0}' has non-Embedding type '{1}', but no embedding generator is configured. + + + 'SearchAsync' performs embedding generation, and does not accept Embedding types directly. To search for an Embedding directly, use 'SearchEmbeddingAsync'. + + + Property '{0}' has embedding type '{1}', but an embedding generator is configured on the property. Remove the embedding generator or change the property's .NET type to a non-embedding input type to the generator (e.g. string). + + + When an embedding generator is configured, `Include Vectors` cannot be enabled. + + + An embedding generator was configured on property '{0}', but output embedding type '{1}' isn't supported by the connector. + + diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordDataAttribute.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordDataAttribute.cs index 38302c7fecc8..8239bb55bf51 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordDataAttribute.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordDataAttribute.cs @@ -21,21 +21,39 @@ public sealed class VectorStoreRecordDataAttribute : Attribute /// /// The default is . /// + [Obsolete("This property is now obsolete and will have no affect if used. Please use IsIndexed instead", error: true)] public bool IsFilterable { get; init; } /// - /// Gets or sets a value indicating whether this data property is full-text searchable. + /// Gets or sets a value indicating whether this data property is full text searchable. /// /// /// The default is . /// + [Obsolete("This property is now obsolete and will have no affect if used. Please use IsFullTextIndexed instead", error: true)] public bool IsFullTextSearchable { get; init; } + /// + /// Gets or sets a value indicating whether this data property is indexed. + /// + /// + /// The default is . + /// + public bool IsIndexed { get; init; } + + /// + /// Gets or sets a value indicating whether this data property is indexed for full-text search. + /// + /// + /// The default is . + /// + public bool IsFullTextIndexed { get; init; } + /// /// Gets or sets an optional name to use for the property in storage, if different from the property name. /// /// /// For example, the property name might be "MyProperty" and the storage name might be "my_property". /// - public string? StoragePropertyName { get; set; } + public string? StoragePropertyName { get; init; } } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs index 318521355f1b..769c09802f15 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs @@ -19,5 +19,5 @@ public sealed class VectorStoreRecordKeyAttribute : Attribute /// /// For example, the property name might be "MyProperty" and the storage name might be "my_property". /// - public string? StoragePropertyName { get; set; } + public string? StoragePropertyName { get; init; } } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordVectorAttribute.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordVectorAttribute.cs index a69e50bd7029..229127df1ca5 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordVectorAttribute.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordVectorAttribute.cs @@ -16,6 +16,7 @@ public sealed class VectorStoreRecordVectorAttribute : Attribute /// /// Initializes a new instance of the class. /// + [Obsolete("This constructor is obsolete, since Dimensions is now a required parameter.", error: true)] public VectorStoreRecordVectorAttribute() { } @@ -26,6 +27,11 @@ public VectorStoreRecordVectorAttribute() /// The number of dimensions that the vector has. public VectorStoreRecordVectorAttribute(int Dimensions) { + if (Dimensions <= 0) + { + throw new ArgumentOutOfRangeException(nameof(Dimensions), "Dimensions must be greater than zero."); + } + this.Dimensions = Dimensions; } @@ -34,6 +40,7 @@ public VectorStoreRecordVectorAttribute(int Dimensions) /// /// The number of dimensions that the vector has. /// The distance function to use when comparing vectors. + [Obsolete("This constructor is obsolete. Use the constructor that takes Dimensions as a parameter and set the DistanceFunction property directly, e.g. [[VectorStoreRecordVector(Dimensions: 1536, DistanceFunction = DistanceFunction.CosineSimilarity)]]", error: true)] public VectorStoreRecordVectorAttribute(int Dimensions, string? DistanceFunction) { this.Dimensions = Dimensions; @@ -46,6 +53,7 @@ public VectorStoreRecordVectorAttribute(int Dimensions, string? DistanceFunction /// The number of dimensions that the vector has. /// The distance function to use when comparing vectors. /// The kind of index to use. + [Obsolete("This constructor is obsolete. Use the constructor that takes Dimensions as a parameter and set the DistanceFunction and IndexKind properties directly, e.g. [[VectorStoreRecordVector(Dimensions: 1536, DistanceFunction = DistanceFunction.CosineSimilarity, IndexKind = IndexKind.Flat)]]", error: true)] public VectorStoreRecordVectorAttribute(int Dimensions, string? DistanceFunction, string? IndexKind) { this.Dimensions = Dimensions; @@ -60,7 +68,7 @@ public VectorStoreRecordVectorAttribute(int Dimensions, string? DistanceFunction /// This property is required when creating collections, but can be omitted if not using that functionality. /// If not provided when trying to create a collection, create will fail. /// - public int? Dimensions { get; private set; } + public int Dimensions { get; private set; } /// /// Gets the kind of index to use. @@ -69,7 +77,9 @@ public VectorStoreRecordVectorAttribute(int Dimensions, string? DistanceFunction /// The default value varies by database type. See the documentation of your chosen database connector for more information. /// /// - public string? IndexKind { get; private set; } +#pragma warning disable CA1019 // Define accessors for attribute arguments: The constructor overload that contains this property is obsolete. + public string? IndexKind { get; init; } +#pragma warning restore CA1019 /// /// Gets the distance function to use when comparing vectors. @@ -78,7 +88,9 @@ public VectorStoreRecordVectorAttribute(int Dimensions, string? DistanceFunction /// The default value varies by database type. See the documentation of your chosen database connector for more information. /// /// - public string? DistanceFunction { get; private set; } +#pragma warning disable CA1019 // Define accessors for attribute arguments: The constructor overload that contains this property is obsolete. + public string? DistanceFunction { get; init; } +#pragma warning restore CA1019 /// /// Gets or sets an optional name to use for the property in storage, if different from the property name. @@ -86,5 +98,5 @@ public VectorStoreRecordVectorAttribute(int Dimensions, string? DistanceFunction /// /// For example, the property name might be "MyProperty" and the storage name might be "my_property". /// - public string? StoragePropertyName { get; set; } + public string? StoragePropertyName { get; init; } } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordDataProperty.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordDataProperty.cs index e3e5c22296b5..5cc543f73474 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordDataProperty.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordDataProperty.cs @@ -29,8 +29,8 @@ public VectorStoreRecordDataProperty(string propertyName, Type propertyType) public VectorStoreRecordDataProperty(VectorStoreRecordDataProperty source) : base(source) { - this.IsFilterable = source.IsFilterable; - this.IsFullTextSearchable = source.IsFullTextSearchable; + this.IsIndexed = source.IsIndexed; + this.IsFullTextIndexed = source.IsFullTextIndexed; } /// @@ -39,6 +39,7 @@ public VectorStoreRecordDataProperty(VectorStoreRecordDataProperty source) /// /// The default is . /// + [Obsolete("This property is now obsolete and will have no affect if used. Please use IsIndexed instead", error: true)] public bool IsFilterable { get; init; } /// @@ -47,5 +48,22 @@ public VectorStoreRecordDataProperty(VectorStoreRecordDataProperty source) /// /// The default is . /// + [Obsolete("This property is now obsolete and will have no affect if used. Please use IsFullTextIndexed instead", error: true)] public bool IsFullTextSearchable { get; init; } + + /// + /// Gets or sets a value indicating whether this data property is indexed. + /// + /// + /// The default is . + /// + public bool IsIndexed { get; init; } + + /// + /// Gets or sets a value indicating whether this data property is indexed for full-text search. + /// + /// + /// The default is . + /// + public bool IsFullTextIndexed { get; init; } } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordDefinition.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordDefinition.cs index d33d0fd4a145..7b43508d814c 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordDefinition.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordDefinition.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using Microsoft.Extensions.AI; namespace Microsoft.Extensions.VectorData; @@ -19,4 +20,9 @@ public sealed class VectorStoreRecordDefinition /// Gets or sets the list of properties that are stored in the record. /// public IReadOnlyList Properties { get; init; } = s_emptyFields; + + /// + /// Gets or sets the default embedding generator for vector properties in this collection. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordVectorProperty.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordVectorProperty.cs index 1d1791ed555f..d953eb489ee1 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordVectorProperty.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordVectorProperty.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData.ConnectorSupport; namespace Microsoft.Extensions.VectorData; @@ -10,18 +12,33 @@ namespace Microsoft.Extensions.VectorData; /// /// The characteristics defined here influence how the property is treated by the vector store. /// -public sealed class VectorStoreRecordVectorProperty : VectorStoreRecordProperty +public class VectorStoreRecordVectorProperty : VectorStoreRecordProperty { + private int _dimensions; + /// /// Initializes a new instance of the class. /// /// The name of the property. /// The type of the property. + [Obsolete("This constructor is obsolete, since dimensions is now a required parameter.", error: true)] public VectorStoreRecordVectorProperty(string propertyName, Type propertyType) : base(propertyName, propertyType) { } + /// + /// Initializes a new instance of the class. + /// + /// The name of the property. + /// The type of the property. + /// The number of dimensions that the vector has. + public VectorStoreRecordVectorProperty(string propertyName, Type propertyType, int dimensions) + : base(propertyName, propertyType) + { + this.Dimensions = dimensions; + } + /// /// Initializes a new instance of the class by cloning the given source. /// @@ -32,8 +49,19 @@ public VectorStoreRecordVectorProperty(VectorStoreRecordVectorProperty source) this.Dimensions = source.Dimensions; this.IndexKind = source.IndexKind; this.DistanceFunction = source.DistanceFunction; + this.EmbeddingGenerator = source.EmbeddingGenerator; + this.EmbeddingType = source.EmbeddingType; } + /// + /// Gets or sets the default embedding generator to use for this property. + /// + /// + /// If not set, embedding generation will be performed in the database, if supported by your connector. + /// If not supported, only pre-generated embeddings can be used, e.g. via . + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; init; } + /// /// Gets or sets the number of dimensions that the vector has. /// @@ -41,7 +69,20 @@ public VectorStoreRecordVectorProperty(VectorStoreRecordVectorProperty source) /// This property is required when creating collections, but can be omitted if not using that functionality. /// If not provided when trying to create a collection, create will fail. /// - public int? Dimensions { get; init; } + public int Dimensions + { + get => this._dimensions; + + init + { + if (value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "Dimensions must be greater than zero."); + } + + this._dimensions = value; + } + } /// /// Gets or sets the kind of index to use. @@ -60,4 +101,19 @@ public VectorStoreRecordVectorProperty(VectorStoreRecordVectorProperty source) /// /// public string? DistanceFunction { get; init; } + + /// + /// Gets or sets the desired embedding type (e.g. Embedding<Half>, for cases where the default (typically Embedding<float>) isn't suitable. + /// + public Type? EmbeddingType { get; init; } + + internal virtual VectorStoreRecordVectorPropertyModel CreatePropertyModel() + => new(this.DataModelPropertyName, this.PropertyType) + { + Dimensions = this.Dimensions, + IndexKind = this.IndexKind, + DistanceFunction = this.DistanceFunction, + EmbeddingGenerator = this.EmbeddingGenerator, + EmbeddingType = this.EmbeddingType! + }; } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordVectorProperty{TInput}.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordVectorProperty{TInput}.cs new file mode 100644 index 000000000000..ab20a7cee39b --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordVectorProperty{TInput}.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData.ConnectorSupport; + +namespace Microsoft.Extensions.VectorData; + +/// +/// Defines a vector property on a vector store record. +/// +/// +/// +/// The characteristics defined here influence how the property is treated by the vector store. +/// +/// +/// This generic version of only needs to be used when an is +/// configured on the property, and a custom .NET type is used as input (any type other than or ). +/// +/// +public class VectorStoreRecordVectorProperty : VectorStoreRecordVectorProperty +{ + /// + public VectorStoreRecordVectorProperty(string propertyName, int dimensions) + : base(propertyName, typeof(TInput), dimensions) + { + } + + /// + public VectorStoreRecordVectorProperty(VectorStoreRecordVectorProperty source) + : base(source) + { + } + + internal override VectorStoreRecordVectorPropertyModel CreatePropertyModel() + => new VectorStoreRecordVectorPropertyModel(this.DataModelPropertyName) + { + Dimensions = this.Dimensions, + IndexKind = this.IndexKind, + DistanceFunction = this.DistanceFunction, + EmbeddingGenerator = this.EmbeddingGenerator + }; +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordOptions/GetFilteredRecordOptions.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordOptions/GetFilteredRecordOptions.cs new file mode 100644 index 000000000000..6843061e369a --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordOptions/GetFilteredRecordOptions.cs @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; + +namespace Microsoft.Extensions.VectorData; + +/// +/// Defines options for filter search. +/// +/// Type of the record. +public sealed class GetFilteredRecordOptions +{ + private int _skip = 0; + + /// + /// Gets or sets the number of results to skip before returning results, that is, the index of the first result to return. + /// + /// Thrown when the value is less than 0. + public int Skip + { + get => this._skip; + init + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "Skip must be greater than or equal to 0."); + } + + this._skip = value; + } + } + + /// + /// Gets or sets the data property to order by. + /// + /// + /// If not provided, the order of returned results is non-deterministic. + /// + public OrderByDefinition OrderBy { get; } = new(); + + /// + /// Gets or sets a value indicating whether to include vectors in the retrieval result. + /// + public bool IncludeVectors { get; init; } = false; + + /// + /// A builder for sorting. + /// + // This type does not derive any collection in order to avoid Intellisense suggesting LINQ methods. + public sealed class OrderByDefinition + { + private readonly List _values = new(); + + /// + /// Gets the expressions to sort by. + /// + /// This property is intended to be consumed by the connectors to retrieve the configuration. + public IReadOnlyList Values => this._values; + + /// + /// Creates an ascending sort. + /// + public OrderByDefinition Ascending(Expression> propertySelector) + { + if (propertySelector is null) + { + throw new ArgumentNullException(nameof(propertySelector)); + } + + this._values.Add(new(propertySelector, true)); + return this; + } + + /// + /// Creates a descending sort. + /// + public OrderByDefinition Descending(Expression> propertySelector) + { + if (propertySelector is null) + { + throw new ArgumentNullException(nameof(propertySelector)); + } + + this._values.Add(new(propertySelector, false)); + return this; + } + + /// + /// Provides a way to define property ordering. + /// + /// This class is intended to be consumed by the connectors to retrieve the configuration. + public sealed class SortInfo + { + internal SortInfo(Expression> propertySelector, bool isAscending) + { + this.PropertySelector = propertySelector; + this.Ascending = isAscending; + } + + /// + /// The expression to select the property to sort by. + /// + public Expression> PropertySelector { get; } + + /// + /// True if the sort is ascending; otherwise, false. + /// + public bool Ascending { get; } + } + } +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordOptions/GetRecordOptions.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordOptions/GetRecordOptions.cs index e623cb676247..a6bf3d9f3b12 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordOptions/GetRecordOptions.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordOptions/GetRecordOptions.cs @@ -1,9 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; +using System.Threading; + namespace Microsoft.Extensions.VectorData; /// -/// Defines options for calling . +/// Defines options for calling +/// or . /// public class GetRecordOptions { diff --git a/dotnet/src/Connectors/VectorData.Abstractions/Throw.cs b/dotnet/src/Connectors/VectorData.Abstractions/Throw.cs new file mode 100644 index 000000000000..42682c708155 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/Throw.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.Extensions.VectorData; + +internal static class Throw +{ + /// Throws an exception indicating that a required service is not available. + public static InvalidOperationException CreateMissingServiceException(Type serviceType, object? serviceKey) => + new(serviceKey is null ? + $"No service of type '{serviceType}' is available." : + $"No service of type '{serviceType}' for the key '{serviceKey}' is available."); +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorData.Abstractions.csproj b/dotnet/src/Connectors/VectorData.Abstractions/VectorData.Abstractions.csproj index f1dc235aa5bd..5fd1d4d4c24f 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorData.Abstractions.csproj +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorData.Abstractions.csproj @@ -4,7 +4,8 @@ Microsoft.Extensions.VectorData.Abstractions Microsoft.Extensions.VectorData net8.0;netstandard2.0;net462 - true + + false @@ -13,7 +14,7 @@ 9.0.0-preview.1.25161.1 9.0.0.0 - 9.0.0-preview.1.25078.1 + 9.0.0-preview.1.25161.1 Microsoft.Extensions.VectorData.Abstractions $(AssemblyName) Abstractions for vector database access. @@ -30,8 +31,17 @@ Microsoft.Extensions.VectorData.IVectorStoreRecordCollection<TKey, TRecord> https://dot.net/ - + + + + + + + + + + @@ -40,7 +50,32 @@ Microsoft.Extensions.VectorData.IVectorStoreRecordCollection<TKey, TRecord> - - + + + + + + + + + + + + + + + + PublicResXFileCodeGenerator + VectorDataStrings.Designer.cs + + + + + + True + True + VectorDataStrings.resx + $(NoWarn);1591 + diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/HybridSearchOptions.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/HybridSearchOptions.cs index 0711cd0aba43..96c251a086f0 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/HybridSearchOptions.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/HybridSearchOptions.cs @@ -10,6 +10,8 @@ namespace Microsoft.Extensions.VectorData; /// public class HybridSearchOptions { + private int _skip = 0; + /// /// Gets or sets a search filter to use before doing the hybrid search. /// @@ -35,21 +37,29 @@ public class HybridSearchOptions /// /// Gets or sets the additional target property to do the text/keyword search on. - /// The property must have full text search enabled. - /// If not provided will look if there is a text property with full text search enabled, and + /// The property must have full text indexing enabled. + /// If not provided will look if there is a text property with full text indexing enabled, and /// will throw if either none or multiple exist. /// public Expression>? AdditionalProperty { get; init; } /// - /// Gets or sets the maximum number of results to return. + /// Gets or sets the number of results to skip before returning results, that is, the index of the first result to return. /// - public int Top { get; init; } = 3; + /// Thrown when the value is less than 0. + public int Skip + { + get => this._skip; + init + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "Skip must be greater than or equal to 0."); + } - /// - /// Gets or sets the number of results to skip before returning results, i.e. the index of the first result to return. - /// - public int Skip { get; init; } = 0; + this._skip = value; + } + } /// /// Gets or sets a value indicating whether to include vectors in the retrieval result. diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IKeywordHybridSearch.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IKeywordHybridSearch.cs index 53d2e062fcda..26f0f5f5a81c 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IKeywordHybridSearch.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IKeywordHybridSearch.cs @@ -1,8 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Threading; -using System.Threading.Tasks; namespace Microsoft.Extensions.VectorData; @@ -18,12 +19,27 @@ public interface IKeywordHybridSearch /// The type of the vector. /// The vector to search the store with. /// A collection of keywords to search the store with. + /// The maximum number of results to return. /// The options that control the behavior of the search. /// The to monitor for cancellation requests. The default is . /// The records found by the hybrid search, including their result scores. - Task> HybridSearchAsync( + IAsyncEnumerable> HybridSearchAsync( TVector vector, ICollection keywords, + int top, HybridSearchOptions? options = default, CancellationToken cancellationToken = default); + + /// Asks the for an object of the specified type . + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object, otherwise . + /// is . + /// + /// The purpose of this method is to allow for the retrieval of strongly-typed services that might be provided by the , + /// including itself or any services it might be wrapping. For example, to access the for the instance, + /// may be used to request it. + /// + [Experimental("MEVD9000")] + object? GetService(Type serviceType, object? serviceKey = null); } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorSearch.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorSearch.cs new file mode 100644 index 000000000000..a1385d9bab8c --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorSearch.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using Microsoft.Extensions.AI; + +namespace Microsoft.Extensions.VectorData; + +/// +/// An interface for performing vector searches on a vector store. +/// +/// The record data model to use for retrieving data from the store. +public interface IVectorSearch +{ + /// + /// Searches the vector store for records that are similar to given value. + /// + /// + /// When using this method, is converted to an embedding internally; depending on your database, you may need to configure an embedding generator. + /// + /// The type of the input value on which to perform the similarity search. + /// The value on which to perform the similarity search. + /// The maximum number of results to return. + /// The options that control the behavior of the search. + /// The to monitor for cancellation requests. The default is . + /// The records found by the vector search, including their result scores. + IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + CancellationToken cancellationToken = default) + where TInput : notnull; + + /// + /// Searches the vector store for records that are similar to given embedding. + /// + /// + /// This is a low-level method that requires embedding generation to be handled manually. + /// Consider configuring an and using to have embeddings generated automatically. + /// + /// The type of the vector. + /// The vector to search the store with. + /// The maximum number of results to return. + /// The options that control the behavior of the search. + /// The to monitor for cancellation requests. The default is . + /// The records found by the vector search, including their result scores. + // TODO: We may also want to consider allowing the user to pass Embedding, rather than just ReadOnlyMemory (#11701). + // TODO: However, if they have an Embedding, they likely got it from an IEmbeddingGenerator, at which point why not wire that up into MEVD and use SearchAsync? + // TODO: So this raw embedding API is likely more for users who already have a ReadOnlyMemory at hand and we don't want to force them to wrap it with Embedding. + IAsyncEnumerable> SearchEmbeddingAsync( + TVector vector, + int top, + VectorSearchOptions? options = default, + CancellationToken cancellationToken = default) + where TVector : notnull; + + /// Asks the for an object of the specified type . + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object, otherwise . + /// is . + /// + /// The purpose of this method is to allow for the retrieval of strongly-typed services that might be provided by the , + /// including itself or any services it might be wrapping. For example, to access the for the instance, + /// may be used to request it. + /// + [Experimental("MEVD9000")] + object? GetService(Type serviceType, object? serviceKey = null); +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs index 5368c5301828..02ff6e3b3afc 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Threading; -using System.Threading.Tasks; namespace Microsoft.Extensions.VectorData; @@ -9,17 +11,33 @@ namespace Microsoft.Extensions.VectorData; /// Contains a method for doing a vector search using text that will be vectorized downstream. /// /// The record data model to use for retrieving data from the store. +[Obsolete("Use IVectorStoreRecordCollection.SearchAsync instead")] public interface IVectorizableTextSearch { /// /// Searches the vector store for records that match the given text and filter. The text string will be vectorized downstream and used for the vector search. /// /// The text to search the store with. + /// The maximum number of results to return. /// The options that control the behavior of the search. /// The to monitor for cancellation requests. The default is . /// The records found by the vector search, including their result scores. - Task> VectorizableTextSearchAsync( + IAsyncEnumerable> VectorizableTextSearchAsync( string searchText, + int top, VectorSearchOptions? options = default, CancellationToken cancellationToken = default); + + /// Asks the for an object of the specified type . + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object, otherwise . + /// is . + /// + /// The purpose of this method is to allow for the retrieval of strongly-typed services that might be provided by the , + /// including itself or any services it might be wrapping. For example, to access the for the instance, + /// may be used to request it. + /// + [Experimental("MEVD9000")] + object? GetService(Type serviceType, object? serviceKey = null); } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs index b2a5a54194a6..1fef2039e4d2 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs @@ -1,7 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; using System.Threading; -using System.Threading.Tasks; namespace Microsoft.Extensions.VectorData; @@ -9,6 +10,7 @@ namespace Microsoft.Extensions.VectorData; /// Contains a method for doing a vector search using a vector. /// /// The record data model to use for retrieving data from the store. +[Obsolete("This interface is obsolete, use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call")] public interface IVectorizedSearch { /// @@ -16,11 +18,15 @@ public interface IVectorizedSearch /// /// The type of the vector. /// The vector to search the store with. + /// The maximum number of results to return. /// The options that control the behavior of the search. /// The to monitor for cancellation requests. The default is . /// The records found by the vector search, including their result scores. - Task> VectorizedSearchAsync( + [Obsolete("Use either SearchEmbeddingAsync to search directly on embeddings, or SearchAsync to handle embedding generation internally as part of the call.")] + IAsyncEnumerable> VectorizedSearchAsync( TVector vector, + int top, VectorSearchOptions? options = default, - CancellationToken cancellationToken = default); + CancellationToken cancellationToken = default) + where TVector : notnull; } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/KeywordHybridSearchExtensions.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/KeywordHybridSearchExtensions.cs new file mode 100644 index 000000000000..0e8435ae25c2 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/KeywordHybridSearchExtensions.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.VectorData; + +/// Provides a collection of static methods for extending instances. +[Experimental("MEVD9000")] +public static class KeywordHybridSearchExtensions +{ + /// + /// Asks the for an object of the specified type + /// and throw an exception if one isn't available. + /// + /// The record data model to use for retrieving data from the store. + /// The keyword hybrid search. + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// is . + /// No service of the requested type for the specified key is available. + public static object GetRequiredService(this IKeywordHybridSearch keywordHybridSearch, Type serviceType, object? serviceKey = null) + { + if (keywordHybridSearch is null) { throw new ArgumentNullException(nameof(keywordHybridSearch)); } + if (serviceType is null) { throw new ArgumentNullException(nameof(serviceType)); } + + return + keywordHybridSearch.GetService(serviceType, serviceKey) ?? + throw Throw.CreateMissingServiceException(serviceType, serviceKey); + } +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchExtensions.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchExtensions.cs new file mode 100644 index 000000000000..13f8e9960d0d --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchExtensions.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.VectorData; + +/// Provides a collection of static methods for extending instances. +[Experimental("MEVD9000")] +public static class VectorSearchExtensions +{ + /// + /// Asks the for an object of the specified type + /// and throw an exception if one isn't available. + /// + /// The record data model to use for retrieving data from the store. + /// The vector search. + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// is . + /// No service of the requested type for the specified key is available. + public static object GetRequiredService(this IVectorSearch vectorSearch, Type serviceType, object? serviceKey = null) + { + if (vectorSearch is null) { throw new ArgumentNullException(nameof(vectorSearch)); } + if (serviceType is null) { throw new ArgumentNullException(nameof(serviceType)); } + + return + vectorSearch.GetService(serviceType, serviceKey) ?? + throw Throw.CreateMissingServiceException(serviceType, serviceKey); + } +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs index 7f6cc16f5dfa..533ede18348d 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs @@ -10,7 +10,7 @@ namespace Microsoft.Extensions.VectorData; /// public class VectorSearchOptions { - private int _top = 3, _skip = 0; + private int _skip = 0; /// /// Gets or sets a search filter to use before doing the vector search. @@ -44,24 +44,6 @@ public class VectorSearchOptions /// public Expression>? VectorProperty { get; init; } - /// - /// Gets or sets the maximum number of results to return. - /// - /// Thrown when the value is less than 1. - public int Top - { - get => this._top; - init - { - if (value < 1) - { - throw new ArgumentOutOfRangeException(nameof(value), "Top must be greater than or equal to 1."); - } - - this._top = value; - } - } - /// /// Gets or sets the number of results to skip before returning results, that is, the index of the first result to return. /// @@ -95,5 +77,6 @@ public int Skip /// Not all vector search implementations support this option, in which case the total /// count will be null even if requested via this option. /// + [Obsolete("Total count is no longer included in the results.", error: true)] public bool IncludeTotalCount { get; init; } = false; } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchResults.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchResults.cs deleted file mode 100644 index 293315ee554a..000000000000 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchResults.cs +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; - -namespace Microsoft.Extensions.VectorData; - -/// -/// Contains the full list of search results for a vector search operation with metadata. -/// -/// The record data model to use for retrieving data from the store. -/// The list of records returned by the search operation. -public class VectorSearchResults(IAsyncEnumerable> results) -{ - /// - /// Gets or sets the total count of results found by the search operation, or null - /// if the count was not requested or cannot be computed. - /// - /// - /// This value represents the total number of results that are available for the current query and not the number of results being returned. - /// - public long? TotalCount { get; init; } - - /// - /// Gets or sets the metadata associated with the content. - /// - public IReadOnlyDictionary? Metadata { get; init; } - - /// - /// Gets the search results. - /// - public IAsyncEnumerable> Results { get; } = results; -} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStore.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStore.cs index 007dcf79da03..df4372e3ecbf 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStore.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStore.cs @@ -1,7 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Threading; +using System.Threading.Tasks; namespace Microsoft.Extensions.VectorData; @@ -9,7 +12,8 @@ namespace Microsoft.Extensions.VectorData; /// Defines an interface for accessing the list of collections in a vector store. /// /// -/// This interface can be used with collections of any schema type, but requires you to provide schema information when getting a collection. +/// This interface can be used with collections of any schema type, but requires you to provide schema information when getting a collection. +/// Unless otherwise documented, implementations of this interface can be expected to be thread-safe, and can be used concurrently from multiple threads. /// public interface IVectorStore { @@ -29,7 +33,8 @@ public interface IVectorStore /// /// IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) - where TKey : notnull; + where TKey : notnull + where TRecord : notnull; /// /// Retrieves the names of all the collections in the vector store. @@ -37,4 +42,33 @@ IVectorStoreRecordCollection GetCollection(string /// The to monitor for cancellation requests. The default is . /// The list of names of all the collections in the vector store. IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default); + + /// + /// Checks if the collection exists in the vector store. + /// + /// The name of the collection. + /// The to monitor for cancellation requests. The default is . + /// if the collection exists, otherwise. + Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default); + + /// + /// Deletes the collection from the vector store. + /// + /// The name of the collection. + /// The to monitor for cancellation requests. The default is . + /// A that completes when the collection has been deleted. + Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default); + + /// Asks the for an object of the specified type . + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object, otherwise . + /// is . + /// + /// The purpose of this method is to allow for the retrieval of strongly-typed services that might be provided by the , + /// including itself or any services it might be wrapping. For example, to access the for the instance, + /// may be used to request it. + /// + [Experimental("MEVD9000")] + object? GetService(Type serviceType, object? serviceKey = null); } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStoreRecordCollection.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStoreRecordCollection.cs index b8e410d4afd5..91a62b496c20 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStoreRecordCollection.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using System.Linq.Expressions; using System.Threading; using System.Threading.Tasks; @@ -11,15 +13,21 @@ namespace Microsoft.Extensions.VectorData; /// /// The data type of the record key. /// The record data model to use for adding, updating, and retrieving data from the store. -#pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public interface IVectorStoreRecordCollection : IVectorizedSearch -#pragma warning restore CA1711 // Identifiers should not have incorrect suffix +/// +/// Unless otherwise documented, implementations of this interface can be expected to be thread-safe, and can be used concurrently from multiple threads. +/// +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix (Collection) +#pragma warning disable CS0618 // IVectorizedSearch is obsolete +public interface IVectorStoreRecordCollection : IVectorSearch, IVectorizedSearch +#pragma warning restore CS0618 // IVectorizedSearch is obsolete +#pragma warning restore CA1711 where TKey : notnull + where TRecord : notnull { /// /// Gets the name of the collection. /// - string CollectionName { get; } + string Name { get; } /// /// Checks if the collection exists in the vector store. @@ -75,7 +83,7 @@ public interface IVectorStoreRecordCollection : IVectorizedSearch /// /// The command fails to execute for any reason. /// The mapping between the storage model and record data model fails. - IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, CancellationToken cancellationToken = default); + IAsyncEnumerable GetAsync(IEnumerable keys, GetRecordOptions? options = default, CancellationToken cancellationToken = default); /// /// Deletes a record from the vector store. Does not guarantee that the collection exists. @@ -98,7 +106,7 @@ public interface IVectorStoreRecordCollection : IVectorizedSearch /// If any record can't be deleted for any other reason, the operation throws. Some records might have already been deleted while others might not have, so the entire operation should be retried. /// /// The command fails to execute for any reason other than that a record does not exist. - Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default); + Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default); /// /// Upserts a record into the vector store. Does not guarantee that the collection exists. @@ -107,23 +115,44 @@ public interface IVectorStoreRecordCollection : IVectorizedSearch /// /// The record to upsert. /// The to monitor for cancellation requests. The default is . - /// The unique identifier for the record. + /// The key for the records, to be used when keys are generated in the database. /// The command fails to execute for any reason. /// The mapping between the storage model and record data model fails. Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default); /// - /// Upserts a group of records into the vector store. Does not guarantee that the collection exists. + /// Upserts a batch of records into the vector store. Does not guarantee that the collection exists. /// If the record already exists, it is updated. /// If the record does not exist, it is created. /// /// The records to upsert. /// The to monitor for cancellation requests. The default is . - /// The unique identifiers for the records. + /// The keys for the records, to be used when keys are generated in the database. /// - /// Upserts are made in a single request or in a single parallel batch depending on the available store functionality. + /// + /// The exact method of upserting the batch is implementation-specific and can vary based on database support; some databases support batch upserts via a single, efficient + /// request, while in other cases the implementation might send multiple upserts in parallel. + /// + /// + /// Similarly, the error behavior can vary across databases: where possible, the batch will be upserted atomically, so that any errors cause the entire batch to be rolled + /// back. Where not supported, some records may be upserted while others are not. If key properties are set by the user, then the entire upsert operation is idempotent, + /// and can simply be retried again if an error occurs. However, if store-generated keys are in use, the upsert operation is no longer idempotent; in that case, if the + /// database doesn't guarantee atomicity, retrying could cause duplicate records to be created. + /// /// /// The command fails to execute for any reason. /// The mapping between the storage model and record data model fails. - IAsyncEnumerable UpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken = default); + Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default); + + /// + /// Gets matching records from the vector store. Does not guarantee that the collection exists. + /// + /// The predicate to filter the records. + /// The maximum number of results to return. + /// Options for retrieving the records. + /// The to monitor for cancellation requests. The default is . + /// The records matching given predicate. + /// The command fails to execute for any reason. + /// The mapping between the storage model and record data model fails. + IAsyncEnumerable GetAsync(Expression> filter, int top, GetFilteredRecordOptions? options = null, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStoreRecordMapper.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStoreRecordMapper.cs index 3bac47a89121..e61cfe4be48a 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/IVectorStoreRecordMapper.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; + namespace Microsoft.Extensions.VectorData; /// @@ -7,6 +9,7 @@ namespace Microsoft.Extensions.VectorData; /// /// The consumer record data model to map to or from. /// The storage model to map to or from. +[Obsolete("Custom mappers are no longer supported.", error: true)] public interface IVectorStoreRecordMapper { /// diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreException.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreException.cs index dc0f5bd1d1b5..42e18181fe1b 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreException.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreException.cs @@ -33,10 +33,18 @@ protected VectorStoreException(string? message, Exception? innerException) : bas { } + /// The name of the vector store system. + /// + /// Where possible, this maps to the "db.system.name" attribute defined in the + /// OpenTelemetry Semantic Conventions for database calls and systems, see . + /// Example: redis, sqlite, mysql. + /// + public string? VectorStoreSystemName { get; init; } + /// - /// Gets or sets the type of vector store that the failing operation was performed on. + /// The name of the vector store (database). /// - public string? VectorStoreType { get; init; } + public string? VectorStoreName { get; init; } /// /// Gets or sets the name of the vector store collection that the failing operation was performed on. diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreExtensions.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreExtensions.cs new file mode 100644 index 000000000000..9d50678cc118 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreExtensions.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.VectorData; + +/// Provides a collection of static methods for extending instances. +[Experimental("MEVD9000")] +public static class VectorStoreExtensions +{ + /// + /// Asks the for an object of the specified type + /// and throw an exception if one isn't available. + /// + /// The record data model to use for retrieving data from the store. + /// The vector store. + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// is . + /// No service of the requested type for the specified key is available. + public static object GetRequiredService(this IVectorStore vectorStore, Type serviceType, object? serviceKey = null) + { + if (vectorStore is null) { throw new ArgumentNullException(nameof(vectorStore)); } + if (serviceType is null) { throw new ArgumentNullException(nameof(serviceType)); } + + return + vectorStore.GetService(serviceType, serviceKey) ?? + throw Throw.CreateMissingServiceException(serviceType, serviceKey); + } +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreMetadata.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreMetadata.cs new file mode 100644 index 000000000000..f89884736cb6 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreMetadata.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.VectorData; + +/// Provides metadata about an . +[Experimental("MEVD9000")] +public class VectorStoreMetadata +{ + /// The name of the vector store system. + /// + /// Where possible, this maps to the "db.system.name" attribute defined in the + /// OpenTelemetry Semantic Conventions for database calls and systems, see . + /// Example: redis, sqlite, mysql. + /// + public string? VectorStoreSystemName { get; init; } + + /// + /// The name of the vector store (database). + /// + public string? VectorStoreName { get; init; } +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreRecordCollectionMetadata.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreRecordCollectionMetadata.cs new file mode 100644 index 000000000000..b2ea092878c3 --- /dev/null +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorStorage/VectorStoreRecordCollectionMetadata.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.VectorData; + +/// Provides metadata about an . +[Experimental("MEVD9000")] +public class VectorStoreRecordCollectionMetadata +{ + /// The name of the vector store system. + /// + /// Where possible, this maps to the "db.system.name" attribute defined in the + /// OpenTelemetry Semantic Conventions for database calls and systems, see . + /// Example: redis, sqlite, mysql. + /// + public string? VectorStoreSystemName { get; init; } + + /// + /// The name of the vector store (database). + /// + public string? VectorStoreName { get; init; } + + /// + /// The name of a collection (table, container) within the vector store (database). + /// + public string? CollectionName { get; init; } +} diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorStoreGenericDataModel.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorStoreGenericDataModel.cs index 6ab9ee119e55..04570c6a816a 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorStoreGenericDataModel.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorStoreGenericDataModel.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; namespace Microsoft.Extensions.VectorData; @@ -8,13 +9,30 @@ namespace Microsoft.Extensions.VectorData; /// Represents a generic data model that can be used to store and retrieve any data from a vector store. /// /// The data type of the record key. -/// The key of the record. -public sealed class VectorStoreGenericDataModel(TKey key) +[Obsolete($"{nameof(VectorStoreGenericDataModel)} has been replaced by Dictionary", error: true)] +public sealed class VectorStoreGenericDataModel { + /// + /// Constructs a new . + /// +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + public VectorStoreGenericDataModel() +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + { + } + + /// + /// Constructs a new . + /// + public VectorStoreGenericDataModel(TKey key) + { + this.Key = key; + } + /// /// Gets or sets the key of the record. /// - public TKey Key { get; set; } = key; + public TKey Key { get; set; } /// /// Gets or sets a dictionary of data items stored in the record. diff --git a/dotnet/src/Connectors/VectorData.UnitTests/VectorData.UnitTests.csproj b/dotnet/src/Connectors/VectorData.UnitTests/VectorData.UnitTests.csproj new file mode 100644 index 000000000000..4dae2d3e5f1f --- /dev/null +++ b/dotnet/src/Connectors/VectorData.UnitTests/VectorData.UnitTests.csproj @@ -0,0 +1,45 @@ + + + + VectorData.UnitTests + VectorData.UnitTests + net8.0 + true + enable + disable + false + $(NoWarn);VSTHRD111,CA2007,CS8618 + $(NoWarn);MEVD9001 + + + $(NoWarn);CA1515 + $(NoWarn);CA1707 + $(NoWarn);CA1716 + $(NoWarn);CA1720 + $(NoWarn);CA1721 + $(NoWarn);CA1861 + $(NoWarn);CA1863 + $(NoWarn);CA2007;VSTHRD111 + $(NoWarn);CS1591 + $(NoWarn);IDE1006 + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + diff --git a/dotnet/src/Connectors/VectorData.UnitTests/VectorStoreRecordModelBuilderTests.cs b/dotnet/src/Connectors/VectorData.UnitTests/VectorStoreRecordModelBuilderTests.cs new file mode 100644 index 000000000000..38c42cbd57ea --- /dev/null +++ b/dotnet/src/Connectors/VectorData.UnitTests/VectorStoreRecordModelBuilderTests.cs @@ -0,0 +1,335 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using Microsoft.Extensions.VectorData.Properties; +using Xunit; + +namespace VectorData.UnitTests; + +#pragma warning disable CA2000 // Dispose objects before losing scope + +public class VectorStoreRecordModelBuilderTests +{ + [Fact] + public void Default_embedding_generator_without_record_definition() + { + using var embeddingGenerator = new FakeEmbeddingGenerator>(); + var model = new CustomModelBuilder().Build(typeof(RecordWithStringVectorProperty), vectorStoreRecordDefinition: null, embeddingGenerator); + + // The embedding's .NET type (Embedding) is inferred from the embedding generator. + Assert.Same(embeddingGenerator, model.VectorProperty.EmbeddingGenerator); + Assert.Same(typeof(string), model.VectorProperty.Type); + Assert.Same(typeof(ReadOnlyMemory), model.VectorProperty.EmbeddingType); + } + + [Fact] + public void Default_embedding_generator_with_clr_type_and_record_definition() + { + using var embeddingGenerator = new FakeEmbeddingGenerator>(); + + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(RecordWithEmbeddingVectorProperty.Id), typeof(int)), + new VectorStoreRecordDataProperty(nameof(RecordWithEmbeddingVectorProperty.Name), typeof(string)), + new VectorStoreRecordVectorProperty(nameof(RecordWithEmbeddingVectorProperty.Embedding), typeof(string), dimensions: 3) + { + // The following configures the property to be ReadOnlyMemory (non-default embedding type for this connector) + EmbeddingType = typeof(ReadOnlyMemory) + } + ] + }; + + var model = new CustomModelBuilder().Build(typeof(RecordWithStringVectorProperty), recordDefinition, embeddingGenerator); + + // The embedding's .NET type (Embedding) is inferred from the embedding generator. + Assert.Same(embeddingGenerator, model.VectorProperty.EmbeddingGenerator); + Assert.Same(typeof(string), model.VectorProperty.Type); + Assert.Same(typeof(ReadOnlyMemory), model.VectorProperty.EmbeddingType); + } + + [Fact] + public void Default_embedding_generator_with_dynamic() + { + using var embeddingGenerator = new FakeEmbeddingGenerator>(); + + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(RecordWithEmbeddingVectorProperty.Id), typeof(int)), + new VectorStoreRecordDataProperty(nameof(RecordWithEmbeddingVectorProperty.Name), typeof(string)), + new VectorStoreRecordVectorProperty(nameof(RecordWithEmbeddingVectorProperty.Embedding), typeof(string), dimensions: 3) + ] + }; + + var model = new CustomModelBuilder().Build(typeof(Dictionary), recordDefinition, embeddingGenerator); + + // The embedding's .NET type (Embedding) is inferred from the embedding generator. + Assert.Same(embeddingGenerator, model.VectorProperty.EmbeddingGenerator); + Assert.Same(typeof(string), model.VectorProperty.Type); + Assert.Same(typeof(ReadOnlyMemory), model.VectorProperty.EmbeddingType); + } + + [Fact] + public void Default_embedding_generator_with_dynamic_and_non_default_EmbeddingType() + { + using var embeddingGenerator = new FakeEmbeddingGenerator>(); + + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(RecordWithEmbeddingVectorProperty.Id), typeof(int)), + new VectorStoreRecordDataProperty(nameof(RecordWithEmbeddingVectorProperty.Name), typeof(string)), + new VectorStoreRecordVectorProperty(nameof(RecordWithEmbeddingVectorProperty.Embedding), typeof(string), dimensions: 3) + { + EmbeddingType = typeof(ReadOnlyMemory) + } + ] + }; + + var model = new CustomModelBuilder().Build(typeof(Dictionary), recordDefinition, embeddingGenerator); + + Assert.Same(embeddingGenerator, model.VectorProperty.EmbeddingGenerator); + Assert.Same(typeof(string), model.VectorProperty.Type); + Assert.Same(typeof(ReadOnlyMemory), model.VectorProperty.EmbeddingType); + } + + [Fact] + public void Property_embedding_generator_takes_precedence_over_default_generator() + { + using var propertyEmbeddingGenerator = new FakeEmbeddingGenerator>(); + using var defaultEmbeddingGenerator = new FakeEmbeddingGenerator>(); + + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(RecordWithEmbeddingVectorProperty.Id), typeof(int)), + new VectorStoreRecordDataProperty(nameof(RecordWithEmbeddingVectorProperty.Name), typeof(string)), + new VectorStoreRecordVectorProperty(nameof(RecordWithEmbeddingVectorProperty.Embedding), typeof(string), dimensions: 3) + { + EmbeddingGenerator = propertyEmbeddingGenerator + } + ] + }; + + var model = new CustomModelBuilder().Build(typeof(Dictionary), recordDefinition, defaultEmbeddingGenerator); + + Assert.Same(propertyEmbeddingGenerator, model.VectorProperty.EmbeddingGenerator); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Embedding_property_type_with_default_embedding_generator_ignores_generator(bool dynamic) + { + using var embeddingGenerator = new FakeEmbeddingGenerator>(); + + var model = dynamic + ? new CustomModelBuilder().Build( + typeof(Dictionary), + new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(RecordWithEmbeddingVectorProperty.Id), typeof(int)), + new VectorStoreRecordDataProperty(nameof(RecordWithEmbeddingVectorProperty.Name), typeof(string)), + new VectorStoreRecordVectorProperty(nameof(RecordWithEmbeddingVectorProperty.Embedding), typeof(ReadOnlyMemory), dimensions: 3) + ] + }, + embeddingGenerator) + : new CustomModelBuilder().Build(typeof(RecordWithEmbeddingVectorProperty), vectorStoreRecordDefinition: null, embeddingGenerator); + + Assert.Null(model.VectorProperty.EmbeddingGenerator); + Assert.Same(typeof(ReadOnlyMemory), model.VectorProperty.Type); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Custom_input_type(bool dynamic) + { + using var embeddingGenerator = new FakeEmbeddingGenerator>(); + + // TODO: Allow custom input type without a record definition (i.e. generic attribute) + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(RecordWithEmbeddingVectorProperty.Id), typeof(int)), + new VectorStoreRecordDataProperty(nameof(RecordWithEmbeddingVectorProperty.Name), typeof(string)), + new VectorStoreRecordVectorProperty(nameof(RecordWithEmbeddingVectorProperty.Embedding), dimensions: 3) + ] + }; + + var model = dynamic + ? new CustomModelBuilder().Build(typeof(Dictionary), recordDefinition, embeddingGenerator) + : new CustomModelBuilder().Build(typeof(RecordWithCustomerVectorProperty), recordDefinition, embeddingGenerator); + + var vectorProperty = model.VectorProperty; + + Assert.Same(embeddingGenerator, vectorProperty.EmbeddingGenerator); + Assert.Same(typeof(Customer), vectorProperty.Type); + Assert.Same(typeof(ReadOnlyMemory), vectorProperty.EmbeddingType); + } + + [Fact] + public void Incompatible_embedding_on_embedding_generator_throws() + { + // Embedding is not a supported embedding type by the connector + using var embeddingGenerator = new FakeEmbeddingGenerator>(); + + var exception = Assert.Throws(() => + new CustomModelBuilder().Build(typeof(RecordWithStringVectorProperty), vectorStoreRecordDefinition: null, embeddingGenerator)); + + Assert.Equal($"Embedding generator '{typeof(FakeEmbeddingGenerator<,>).Name}' is incompatible with the required input and output types. The property input type must be 'String, DataContent', and the output type must be 'ReadOnlyMemory, ReadOnlyMemory'.", exception.Message); + } + + [Fact] + public void Incompatible_input_on_embedding_generator_throws() + { + // int is not a supported input type for the embedding generator + using var embeddingGenerator = new FakeEmbeddingGenerator>(); + + var exception = Assert.Throws(() => + new CustomModelBuilder().Build(typeof(RecordWithStringVectorProperty), vectorStoreRecordDefinition: null, embeddingGenerator)); + + Assert.Equal($"Embedding generator '{typeof(FakeEmbeddingGenerator<,>).Name}' is incompatible with the required input and output types. The property input type must be 'String, DataContent', and the output type must be 'ReadOnlyMemory, ReadOnlyMemory'.", exception.Message); + } + + [Fact] + public void Non_embedding_vector_property_without_embedding_generator_throws() + { + var exception = Assert.Throws(() => + new CustomModelBuilder().Build(typeof(RecordWithStringVectorProperty), vectorStoreRecordDefinition: null, defaultEmbeddingGenerator: null)); + + Assert.Equal($"Property '{nameof(RecordWithStringVectorProperty.Embedding)}' has non-Embedding type 'String', but no embedding generator is configured.", exception.Message); + } + + [Fact] + public void Embedding_property_type_with_property_embedding_generator_throws() + { + using var embeddingGenerator = new FakeEmbeddingGenerator>(); + + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(RecordWithEmbeddingVectorProperty.Id), typeof(int)), + new VectorStoreRecordDataProperty(nameof(RecordWithEmbeddingVectorProperty.Name), typeof(string)), + new VectorStoreRecordVectorProperty(nameof(RecordWithEmbeddingVectorProperty.Embedding), typeof(ReadOnlyMemory), dimensions: 3) + { + EmbeddingGenerator = embeddingGenerator + } + ] + }; + + var exception = Assert.Throws(() => + new CustomModelBuilder().Build(typeof(RecordWithEmbeddingVectorProperty), recordDefinition, embeddingGenerator)); + + Assert.Equal( + $"Property '{nameof(RecordWithEmbeddingVectorProperty.Embedding)}' has embedding type 'ReadOnlyMemory`1', but an embedding generator is configured on the property. Remove the embedding generator or change the property's .NET type to a non-embedding input type to the generator (e.g. string).", + exception.Message); + } + + public class RecordWithStringVectorProperty + { + [VectorStoreRecordKey] + public int Id { get; set; } + + [VectorStoreRecordData] + public string Name { get; set; } + + [VectorStoreRecordVector(Dimensions: 3)] + public string Embedding { get; set; } + } + + public class RecordWithEmbeddingVectorProperty + { + [VectorStoreRecordKey] + public int Id { get; set; } + + [VectorStoreRecordData] + public string Name { get; set; } + + [VectorStoreRecordVector(Dimensions: 3)] + public ReadOnlyMemory Embedding { get; set; } + } + + public class RecordWithCustomerVectorProperty + { + [VectorStoreRecordKey] + public int Id { get; set; } + + [VectorStoreRecordData] + public string Name { get; set; } + + [VectorStoreRecordVector(Dimensions: 3)] + public Customer Embedding { get; set; } + } + + public class Customer + { + public string FirstName { get; set; } + public string LastName { get; set; } + } + + private sealed class CustomModelBuilder(VectorStoreRecordModelBuildingOptions? options = null) + : VectorStoreRecordModelBuilder(options ?? s_defaultOptions) + { + private static readonly VectorStoreRecordModelBuildingOptions s_defaultOptions = new() + { + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + RequiresAtLeastOneVector = false, + + SupportedKeyPropertyTypes = [typeof(string), typeof(int)], + SupportedDataPropertyTypes = [typeof(string), typeof(int)], + SupportedEnumerableDataPropertyElementTypes = [typeof(string), typeof(int)], + SupportedVectorPropertyTypes = [typeof(ReadOnlyMemory), typeof(ReadOnlyMemory)] + }; + + protected override void SetupEmbeddingGeneration( + VectorStoreRecordVectorPropertyModel vectorProperty, + IEmbeddingGenerator embeddingGenerator, + Type? embeddingType) + { + if (!vectorProperty.TrySetupEmbeddingGeneration, ReadOnlyMemory>(embeddingGenerator, embeddingType) + && !vectorProperty.TrySetupEmbeddingGeneration, ReadOnlyMemory>(embeddingGenerator, embeddingType)) + { + throw new InvalidOperationException( + string.Format( + VectorDataStrings.IncompatibleEmbeddingGenerator, + embeddingGenerator.GetType().Name, + string.Join(", ", vectorProperty.GetSupportedInputTypes().Select(t => t.Name)), + "ReadOnlyMemory, ReadOnlyMemory")); + } + } + } + + private sealed class FakeEmbeddingGenerator : IEmbeddingGenerator + where TEmbedding : Embedding + { + public Task> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default) + => throw new UnreachableException(); + + public object? GetService(Type serviceType, object? serviceKey = null) + => throw new UnreachableException(); + + public void Dispose() { } + } +} diff --git a/dotnet/src/Experimental/Process.IntegrationTestHost.Dapr/Process.IntegrationTestHost.Dapr.csproj b/dotnet/src/Experimental/Process.IntegrationTestHost.Dapr/Process.IntegrationTestHost.Dapr.csproj index 91277c4692ad..9ce38a616a68 100644 --- a/dotnet/src/Experimental/Process.IntegrationTestHost.Dapr/Process.IntegrationTestHost.Dapr.csproj +++ b/dotnet/src/Experimental/Process.IntegrationTestHost.Dapr/Process.IntegrationTestHost.Dapr.csproj @@ -7,7 +7,7 @@ enable enable false - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 b7762d10-e29b-4bb1-8b74-b6d69a667dd4 true diff --git a/dotnet/src/Experimental/Process.IntegrationTestRunner.Dapr/Process.IntegrationTestRunner.Dapr.csproj b/dotnet/src/Experimental/Process.IntegrationTestRunner.Dapr/Process.IntegrationTestRunner.Dapr.csproj index 2d35183b3648..1f3c09de232e 100644 --- a/dotnet/src/Experimental/Process.IntegrationTestRunner.Dapr/Process.IntegrationTestRunner.Dapr.csproj +++ b/dotnet/src/Experimental/Process.IntegrationTestRunner.Dapr/Process.IntegrationTestRunner.Dapr.csproj @@ -7,7 +7,7 @@ enable enable false - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 b7762d10-e29b-4bb1-8b74-b6d69a667dd4 true diff --git a/dotnet/src/Functions/Functions.Grpc/Extensions/GrpcKernelExtensions.cs b/dotnet/src/Functions/Functions.Grpc/Extensions/GrpcKernelExtensions.cs index 20f928cb7bcb..2b2b7488db6a 100644 --- a/dotnet/src/Functions/Functions.Grpc/Extensions/GrpcKernelExtensions.cs +++ b/dotnet/src/Functions/Functions.Grpc/Extensions/GrpcKernelExtensions.cs @@ -88,7 +88,7 @@ public static KernelPlugin CreatePluginFromGrpcDirectory( { const string ProtoFile = "grpc.proto"; - Verify.ValidPluginName(pluginDirectoryName, kernel.Plugins); + KernelVerify.ValidPluginName(pluginDirectoryName, kernel.Plugins); var pluginDir = Path.Combine(parentDirectory, pluginDirectoryName); Verify.DirectoryExists(pluginDir); @@ -151,7 +151,7 @@ public static KernelPlugin CreatePluginFromGrpc( string pluginName) { Verify.NotNull(kernel); - Verify.ValidPluginName(pluginName, kernel.Plugins); + KernelVerify.ValidPluginName(pluginName, kernel.Plugins); // Parse var parser = new ProtoDocumentParser(); diff --git a/dotnet/src/Functions/Functions.OpenApi.Extensions/Extensions/ApiManifestKernelExtensions.cs b/dotnet/src/Functions/Functions.OpenApi.Extensions/Extensions/ApiManifestKernelExtensions.cs index 6165caf81466..cfbc240fc77d 100644 --- a/dotnet/src/Functions/Functions.OpenApi.Extensions/Extensions/ApiManifestKernelExtensions.cs +++ b/dotnet/src/Functions/Functions.OpenApi.Extensions/Extensions/ApiManifestKernelExtensions.cs @@ -100,7 +100,7 @@ public static async Task CreatePluginFromApiManifestAsync( CancellationToken cancellationToken = default) { Verify.NotNull(kernel); - Verify.ValidPluginName(pluginName, kernel.Plugins); + KernelVerify.ValidPluginName(pluginName, kernel.Plugins); #pragma warning disable CA2000 // Dispose objects before losing scope. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. var httpClient = HttpClientProvider.GetHttpClient(pluginParameters?.HttpClient ?? kernel.Services.GetService()); diff --git a/dotnet/src/Functions/Functions.OpenApi.Extensions/Extensions/CopilotAgentPluginKernelExtensions.cs b/dotnet/src/Functions/Functions.OpenApi.Extensions/Extensions/CopilotAgentPluginKernelExtensions.cs index fcea1ef3a387..b20b8968c90b 100644 --- a/dotnet/src/Functions/Functions.OpenApi.Extensions/Extensions/CopilotAgentPluginKernelExtensions.cs +++ b/dotnet/src/Functions/Functions.OpenApi.Extensions/Extensions/CopilotAgentPluginKernelExtensions.cs @@ -62,7 +62,7 @@ public static async Task CreatePluginFromCopilotAgentPluginAsync( CancellationToken cancellationToken = default) { Verify.NotNull(kernel); - Verify.ValidPluginName(pluginName, kernel.Plugins); + KernelVerify.ValidPluginName(pluginName, kernel.Plugins); #pragma warning disable CA2000 // Dispose objects before losing scope. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. var httpClient = HttpClientProvider.GetHttpClient(pluginParameters?.HttpClient ?? kernel.Services.GetService()); diff --git a/dotnet/src/Functions/Functions.OpenApi/Extensions/OpenApiKernelExtensions.cs b/dotnet/src/Functions/Functions.OpenApi/Extensions/OpenApiKernelExtensions.cs index 2e7fb3d2214f..93709fc09a77 100644 --- a/dotnet/src/Functions/Functions.OpenApi/Extensions/OpenApiKernelExtensions.cs +++ b/dotnet/src/Functions/Functions.OpenApi/Extensions/OpenApiKernelExtensions.cs @@ -117,7 +117,7 @@ public static async Task CreatePluginFromOpenApiAsync( CancellationToken cancellationToken = default) { Verify.NotNull(kernel); - Verify.ValidPluginName(pluginName, kernel.Plugins); + KernelVerify.ValidPluginName(pluginName, kernel.Plugins); #pragma warning disable CA2000 // Dispose objects before losing scope. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. var httpClient = HttpClientProvider.GetHttpClient(executionParameters?.HttpClient ?? kernel.Services.GetService()); @@ -156,7 +156,7 @@ public static async Task CreatePluginFromOpenApiAsync( CancellationToken cancellationToken = default) { Verify.NotNull(kernel); - Verify.ValidPluginName(pluginName, kernel.Plugins); + KernelVerify.ValidPluginName(pluginName, kernel.Plugins); #pragma warning disable CA2000 // Dispose objects before losing scope. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. var httpClient = HttpClientProvider.GetHttpClient(executionParameters?.HttpClient ?? kernel.Services.GetService()); @@ -199,7 +199,7 @@ public static async Task CreatePluginFromOpenApiAsync( CancellationToken cancellationToken = default) { Verify.NotNull(kernel); - Verify.ValidPluginName(pluginName, kernel.Plugins); + KernelVerify.ValidPluginName(pluginName, kernel.Plugins); #pragma warning disable CA2000 // Dispose objects before losing scope. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. var httpClient = HttpClientProvider.GetHttpClient(executionParameters?.HttpClient ?? kernel.Services.GetService()); @@ -233,7 +233,7 @@ public static KernelPlugin CreatePluginFromOpenApi( CancellationToken cancellationToken = default) { Verify.NotNull(kernel); - Verify.ValidPluginName(pluginName, kernel.Plugins); + KernelVerify.ValidPluginName(pluginName, kernel.Plugins); #pragma warning disable CA2000 // Dispose objects before losing scope. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. var httpClient = HttpClientProvider.GetHttpClient(executionParameters?.HttpClient ?? kernel.Services.GetService()); diff --git a/dotnet/src/Functions/Functions.OpenApi/OpenApiKernelPluginFactory.cs b/dotnet/src/Functions/Functions.OpenApi/OpenApiKernelPluginFactory.cs index 52e1f726fc39..e2a7bc9dc5cb 100644 --- a/dotnet/src/Functions/Functions.OpenApi/OpenApiKernelPluginFactory.cs +++ b/dotnet/src/Functions/Functions.OpenApi/OpenApiKernelPluginFactory.cs @@ -36,7 +36,7 @@ public static async Task CreateFromOpenApiAsync( OpenApiFunctionExecutionParameters? executionParameters = null, CancellationToken cancellationToken = default) { - Verify.ValidPluginName(pluginName); + KernelVerify.ValidPluginName(pluginName); #pragma warning disable CA2000 // Dispose objects before losing scope. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. var httpClient = HttpClientProvider.GetHttpClient(executionParameters?.HttpClient); @@ -73,7 +73,7 @@ public static async Task CreateFromOpenApiAsync( OpenApiFunctionExecutionParameters? executionParameters = null, CancellationToken cancellationToken = default) { - Verify.ValidPluginName(pluginName); + KernelVerify.ValidPluginName(pluginName); #pragma warning disable CA2000 // Dispose objects before losing scope. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. var httpClient = HttpClientProvider.GetHttpClient(executionParameters?.HttpClient); @@ -114,7 +114,7 @@ public static async Task CreateFromOpenApiAsync( OpenApiFunctionExecutionParameters? executionParameters = null, CancellationToken cancellationToken = default) { - Verify.ValidPluginName(pluginName); + KernelVerify.ValidPluginName(pluginName); #pragma warning disable CA2000 // Dispose objects before losing scope. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. var httpClient = HttpClientProvider.GetHttpClient(executionParameters?.HttpClient); @@ -143,7 +143,7 @@ public static KernelPlugin CreateFromOpenApi( RestApiSpecification specification, OpenApiFunctionExecutionParameters? executionParameters = null) { - Verify.ValidPluginName(pluginName); + KernelVerify.ValidPluginName(pluginName); #pragma warning disable CA2000 // Dispose objects before losing scope. No need to dispose the Http client here. It can either be an internal client using NonDisposableHttpClientHandler or an external client managed by the calling code, which should handle its disposal. var httpClient = HttpClientProvider.GetHttpClient(executionParameters?.HttpClient); @@ -391,7 +391,7 @@ private static string ConvertOperationIdToValidFunctionName(string operationId, { try { - Verify.ValidFunctionName(operationId); + KernelVerify.ValidFunctionName(operationId); return operationId; } catch (ArgumentException) diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchHotel.cs index 3f979fe2b828..430838d7e6ed 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchHotel.cs @@ -17,7 +17,7 @@ public class AzureAISearchHotel public string HotelId { get; set; } [SearchableField(IsFilterable = true, IsSortable = true)] - [VectorStoreRecordData(IsFilterable = true, IsFullTextSearchable = true)] + [VectorStoreRecordData(IsIndexed = true, IsFullTextIndexed = true)] public string HotelName { get; set; } [SearchableField(AnalyzerName = LexicalAnalyzerName.Values.EnLucene)] @@ -28,18 +28,18 @@ public class AzureAISearchHotel public ReadOnlyMemory? DescriptionEmbedding { get; set; } [SearchableField(IsFilterable = true, IsFacetable = true)] - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] #pragma warning disable CA1819 // Properties should not return arrays public string[] Tags { get; set; } #pragma warning restore CA1819 // Properties should not return arrays [JsonPropertyName("parking_is_included")] [SimpleField(IsFilterable = true, IsSortable = true, IsFacetable = true)] - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public bool? ParkingIncluded { get; set; } [SimpleField(IsFilterable = true, IsSortable = true, IsFacetable = true)] - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public DateTimeOffset? LastRenovationDate { get; set; } [SimpleField(IsFilterable = true, IsSortable = true, IsFacetable = true)] diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchTextSearchTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchTextSearchTests.cs index 115ae9aabff5..2a4932f98e6f 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchTextSearchTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchTextSearchTests.cs @@ -85,7 +85,11 @@ public override Task CreateTextSearchAsync() var stringMapper = new HotelTextSearchStringMapper(); var resultMapper = new HotelTextSearchResultMapper(); + // TODO: Once OpenAITextEmbeddingGenerationService implements MEAI's IEmbeddingGenerator (#10811), configure it with the AzureAISearchVectorStore above instead of passing it here. +#pragma warning disable CS0618 // VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete var result = new VectorStoreTextSearch(vectorSearch, this.EmbeddingGenerator!, stringMapper, resultMapper); +#pragma warning restore CS0618 + return Task.FromResult(result); } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreFixture.cs index 6a973333ca3f..3ff4493ee79b 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreFixture.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Text.RegularExpressions; using System.Threading.Tasks; using Azure; @@ -70,12 +71,12 @@ public AzureAISearchVectorStoreFixture() Properties = new List { new VectorStoreRecordKeyProperty("HotelId", typeof(string)), - new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsIndexed = true, IsFullTextIndexed = true }, new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 1536 }, - new VectorStoreRecordDataProperty("Tags", typeof(string[])) { IsFilterable = true }, - new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool?)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, - new VectorStoreRecordDataProperty("LastRenovationDate", typeof(DateTimeOffset?)) { IsFilterable = true }, + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), 1536), + new VectorStoreRecordDataProperty("Tags", typeof(string[])) { IsIndexed = true }, + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool?)) { IsIndexed = true, StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("LastRenovationDate", typeof(DateTimeOffset?)) { IsIndexed = true }, new VectorStoreRecordDataProperty("Rating", typeof(double?)) } }; @@ -114,6 +115,11 @@ public AzureAISearchVectorStoreFixture() /// public ITextEmbeddingGenerationService EmbeddingGenerator { get; private set; } + /// + /// Gets the embedding used for all test documents that the collection is seeded with. + /// + public ReadOnlyMemory Embedding { get; private set; } + /// /// Create / Recreate index and upload documents before test run. /// @@ -122,7 +128,7 @@ public async Task InitializeAsync() { await AzureAISearchVectorStoreFixture.DeleteIndexIfExistsAsync(this._testIndexName, this.SearchIndexClient); await AzureAISearchVectorStoreFixture.CreateIndexAsync(this._testIndexName, this.SearchIndexClient); - await AzureAISearchVectorStoreFixture.UploadDocumentsAsync(this.SearchIndexClient.GetSearchClient(this._testIndexName), this.EmbeddingGenerator); + await this.UploadDocumentsAsync(this.SearchIndexClient.GetSearchClient(this._testIndexName), this.EmbeddingGenerator); } /// @@ -193,9 +199,9 @@ public static async Task CreateIndexAsync(string indexName, SearchIndexClient ad /// /// The client to use for uploading the documents. /// An instance of to generate embeddings. - public static async Task UploadDocumentsAsync(SearchClient searchClient, ITextEmbeddingGenerationService embeddingGenerator) + public async Task UploadDocumentsAsync(SearchClient searchClient, ITextEmbeddingGenerationService embeddingGenerator) { - var embedding = await embeddingGenerator.GenerateEmbeddingAsync("This is a great hotel"); + this.Embedding = await embeddingGenerator.GenerateEmbeddingAsync("This is a great hotel"); IndexDocumentsBatch batch = IndexDocumentsBatch.Create( IndexDocumentsAction.Upload( @@ -204,7 +210,7 @@ public static async Task UploadDocumentsAsync(SearchClient searchClient, ITextEm HotelId = "BaseSet-1", HotelName = "Hotel 1", Description = "This is a great hotel", - DescriptionEmbedding = embedding, + DescriptionEmbedding = this.Embedding, Tags = new[] { "pool", "air conditioning", "concierge" }, ParkingIncluded = false, LastRenovationDate = new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), @@ -216,7 +222,7 @@ public static async Task UploadDocumentsAsync(SearchClient searchClient, ITextEm HotelId = "BaseSet-2", HotelName = "Hotel 2", Description = "This is a great hotel", - DescriptionEmbedding = embedding, + DescriptionEmbedding = this.Embedding, Tags = new[] { "pool", "free wifi", "concierge" }, ParkingIncluded = false, LastRenovationDate = new DateTimeOffset(1979, 2, 18, 0, 0, 0, TimeSpan.Zero), @@ -228,7 +234,7 @@ public static async Task UploadDocumentsAsync(SearchClient searchClient, ITextEm HotelId = "BaseSet-3", HotelName = "Hotel 3", Description = "This is a great hotel", - DescriptionEmbedding = embedding, + DescriptionEmbedding = this.Embedding, Tags = new[] { "air conditioning", "bar", "continental breakfast" }, ParkingIncluded = true, LastRenovationDate = new DateTimeOffset(2015, 9, 20, 0, 0, 0, TimeSpan.Zero), @@ -240,7 +246,7 @@ public static async Task UploadDocumentsAsync(SearchClient searchClient, ITextEm HotelId = "BaseSet-4", HotelName = "Hotel 4", Description = "This is a great hotel", - DescriptionEmbedding = embedding, + DescriptionEmbedding = this.Embedding, Tags = new[] { "concierge", "view", "24-hour front desk service" }, ParkingIncluded = true, LastRenovationDate = new DateTimeOffset(1960, 2, 06, 0, 0, 0, TimeSpan.Zero), diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs index 8fa45147398b..eeb7b46e7ab2 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs @@ -1,14 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Linq; -using System.Text.Json.Nodes; using System.Threading.Tasks; using Azure; using Azure.Search.Documents.Indexes; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.AzureAISearch; -using Microsoft.SemanticKernel.Embeddings; using Xunit; using Xunit.Abstractions; @@ -17,7 +16,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureAISearch; #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// -/// Integration tests for class. +/// Integration tests for class. /// Tests work with an Azure AI Search Instance. /// [Collection("AzureAISearchVectorStoreCollection")] @@ -33,7 +32,7 @@ public async Task CollectionExistsReturnsCollectionStateAsync(bool expectedExist { // Arrange. var collectionName = expectedExists ? fixture.TestIndexName : "nonexistentcollection"; - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, collectionName); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, collectionName); // Act. var actual = await sut.CollectionExistsAsync(); @@ -48,28 +47,29 @@ public async Task CollectionExistsReturnsCollectionStateAsync(bool expectedExist public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDefinition) { // Arrange - var hotel = await this.CreateTestHotelAsync("Upsert-1"); + var hotel = this.CreateTestHotel("Upsert-1"); var testCollectionName = $"{fixture.TestIndexName}-createtest"; var options = new AzureAISearchVectorStoreRecordCollectionOptions { VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null }; - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, testCollectionName, options); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, testCollectionName, options); await sut.DeleteCollectionAsync(); // Act await sut.CreateCollectionAsync(); var upsertResult = await sut.UpsertAsync(hotel); - var getResult = await sut.GetAsync("Upsert-1"); - var embedding = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("A great hotel"); - var actual = await sut.VectorizedSearchAsync( + var getResult = await sut.GetAsync("Upsert-1", new() { IncludeVectors = true }); + var embedding = fixture.Embedding; + var searchResults = await sut.VectorizedSearchAsync( embedding, + top: 3, new() { IncludeVectors = true, OldFilter = new VectorSearchFilter().EqualTo("HotelName", "MyHotel Upsert-1") - }); + }).ToListAsync(); // Assert var collectionExistResult = await sut.CollectionExistsAsync(); @@ -89,7 +89,6 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe Assert.Equal(hotel.LastRenovationDate, getResult.LastRenovationDate); Assert.Equal(hotel.Rating, getResult.Rating); - var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); var searchResultRecord = searchResults.First().Record; Assert.Equal(hotel.HotelName, searchResultRecord.HotelName); @@ -113,7 +112,7 @@ public async Task ItCanDeleteCollectionAsync() // Arrange var tempCollectionName = fixture.TestIndexName + "-delete"; await AzureAISearchVectorStoreFixture.CreateIndexAsync(tempCollectionName, fixture.SearchIndexClient); - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, tempCollectionName); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, tempCollectionName); // Act await sut.DeleteCollectionAsync(); @@ -132,12 +131,12 @@ public async Task ItCanUpsertDocumentToVectorStoreAsync(bool useRecordDefinition { VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null }; - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName, options); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName, options); // Act - var hotel = await this.CreateTestHotelAsync("Upsert-1"); + var hotel = this.CreateTestHotel("Upsert-1"); var upsertResult = await sut.UpsertAsync(hotel); - var getResult = await sut.GetAsync("Upsert-1"); + var getResult = await sut.GetAsync("Upsert-1", new() { IncludeVectors = true }); // Assert Assert.NotNull(upsertResult); @@ -162,27 +161,26 @@ public async Task ItCanUpsertDocumentToVectorStoreAsync(bool useRecordDefinition public async Task ItCanUpsertManyDocumentsToVectorStoreAsync() { // Arrange - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); // Act - var results = sut.UpsertBatchAsync( + var results = await sut.UpsertAsync( [ - await this.CreateTestHotelAsync("UpsertMany-1"), - await this.CreateTestHotelAsync("UpsertMany-2"), - await this.CreateTestHotelAsync("UpsertMany-3"), + this.CreateTestHotel("UpsertMany-1"), + this.CreateTestHotel("UpsertMany-2"), + this.CreateTestHotel("UpsertMany-3"), ]); // Assert Assert.NotNull(results); - var resultsList = await results.ToListAsync(); - Assert.Equal(3, resultsList.Count); - Assert.Contains("UpsertMany-1", resultsList); - Assert.Contains("UpsertMany-2", resultsList); - Assert.Contains("UpsertMany-3", resultsList); + Assert.Equal(3, results.Count); + Assert.Contains("UpsertMany-1", results); + Assert.Contains("UpsertMany-2", results); + Assert.Contains("UpsertMany-3", results); // Output - foreach (var result in resultsList) + foreach (var result in results) { output.WriteLine(result); } @@ -200,7 +198,7 @@ public async Task ItCanGetDocumentFromVectorStoreAsync(bool includeVectors, bool { VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null }; - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName, options); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName, options); // Act var getResult = await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = includeVectors }); @@ -213,7 +211,7 @@ public async Task ItCanGetDocumentFromVectorStoreAsync(bool includeVectors, bool Assert.Equal(includeVectors, getResult.DescriptionEmbedding != null); if (includeVectors) { - var embedding = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("This is a great hotel"); + var embedding = fixture.Embedding; Assert.Equal(embedding, getResult.DescriptionEmbedding!.Value.ToArray()); } else @@ -233,11 +231,11 @@ public async Task ItCanGetDocumentFromVectorStoreAsync(bool includeVectors, bool public async Task ItCanGetManyDocumentsFromVectorStoreAsync() { // Arrange - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); // Act // Also include one non-existing key to test that the operation does not fail for these and returns only the found ones. - var hotels = sut.GetBatchAsync(["BaseSet-1", "BaseSet-2", "BaseSet-3", "BaseSet-5", "BaseSet-4"], new GetRecordOptions { IncludeVectors = true }); + var hotels = sut.GetAsync(["BaseSet-1", "BaseSet-2", "BaseSet-3", "BaseSet-5", "BaseSet-4"], new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(hotels); @@ -261,8 +259,8 @@ public async Task ItCanRemoveDocumentFromVectorStoreAsync(bool useRecordDefiniti { VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null }; - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); - await sut.UpsertAsync(await this.CreateTestHotelAsync("Remove-1")); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + await sut.UpsertAsync(this.CreateTestHotel("Remove-1")); // Act await sut.DeleteAsync("Remove-1"); @@ -277,14 +275,14 @@ public async Task ItCanRemoveDocumentFromVectorStoreAsync(bool useRecordDefiniti public async Task ItCanRemoveManyDocumentsFromVectorStoreAsync() { // Arrange - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); - await sut.UpsertAsync(await this.CreateTestHotelAsync("RemoveMany-1")); - await sut.UpsertAsync(await this.CreateTestHotelAsync("RemoveMany-2")); - await sut.UpsertAsync(await this.CreateTestHotelAsync("RemoveMany-3")); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + await sut.UpsertAsync(this.CreateTestHotel("RemoveMany-1")); + await sut.UpsertAsync(this.CreateTestHotel("RemoveMany-2")); + await sut.UpsertAsync(this.CreateTestHotel("RemoveMany-3")); // Act // Also include a non-existing key to test that the operation does not fail for these. - await sut.DeleteBatchAsync(["RemoveMany-1", "RemoveMany-2", "RemoveMany-3", "RemoveMany-4"]); + await sut.DeleteAsync(["RemoveMany-1", "RemoveMany-2", "RemoveMany-3", "RemoveMany-4"]); // Assert Assert.Null(await sut.GetAsync("RemoveMany-1", new GetRecordOptions { IncludeVectors = true })); @@ -296,7 +294,7 @@ public async Task ItCanRemoveManyDocumentsFromVectorStoreAsync() public async Task ItReturnsNullWhenGettingNonExistentRecordAsync() { // Arrange - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); // Act & Assert Assert.Null(await sut.GetAsync("BaseSet-5", new GetRecordOptions { IncludeVectors = true })); @@ -307,7 +305,7 @@ public async Task ItThrowsOperationExceptionForFailedConnectionAsync() { // Arrange var searchIndexClient = new SearchIndexClient(new Uri("https://localhost:12345"), new AzureKeyCredential("12345")); - var sut = new AzureAISearchVectorStoreRecordCollection(searchIndexClient, fixture.TestIndexName); + var sut = new AzureAISearchVectorStoreRecordCollection(searchIndexClient, fixture.TestIndexName); // Act & Assert await Assert.ThrowsAsync(async () => await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = true })); @@ -318,44 +316,33 @@ public async Task ItThrowsOperationExceptionForFailedAuthenticationAsync() { // Arrange var searchIndexClient = new SearchIndexClient(new Uri(fixture.Config.ServiceUrl), new AzureKeyCredential("12345")); - var sut = new AzureAISearchVectorStoreRecordCollection(searchIndexClient, fixture.TestIndexName); + var sut = new AzureAISearchVectorStoreRecordCollection(searchIndexClient, fixture.TestIndexName); // Act & Assert await Assert.ThrowsAsync(async () => await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = true })); } - [Fact(Skip = SkipReason)] - public async Task ItThrowsMappingExceptionForFailedMapperAsync() - { - // Arrange - var options = new AzureAISearchVectorStoreRecordCollectionOptions { JsonObjectCustomMapper = new FailingMapper() }; - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName, options); - - // Act & Assert - await Assert.ThrowsAsync(async () => await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = true })); - } - [Theory(Skip = SkipReason)] [InlineData("equality", true)] [InlineData("tagContains", false)] public async Task ItCanSearchWithVectorAndFiltersAsync(string option, bool includeVectors) { // Arrange. - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); // Act. var filter = option == "equality" ? new VectorSearchFilter().EqualTo("HotelName", "Hotel 3") : new VectorSearchFilter().AnyTagEqualTo("Tags", "bar"); - var actual = await sut.VectorizedSearchAsync( - await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("A great hotel"), + var searchResults = await sut.VectorizedSearchAsync( + fixture.Embedding, + top: 3, new() { IncludeVectors = includeVectors, VectorProperty = r => r.DescriptionEmbedding, OldFilter = filter, - }); + }).ToListAsync(); // Assert. - var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); var searchResult = searchResults.First(); Assert.Equal("BaseSet-3", searchResult.Record.HotelId); @@ -368,7 +355,7 @@ await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("A great hotel"), if (includeVectors) { Assert.NotNull(searchResult.Record.DescriptionEmbedding); - var embedding = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("This is a great hotel"); + var embedding = fixture.Embedding; Assert.Equal(embedding, searchResult.Record.DescriptionEmbedding!.Value.ToArray()); } else @@ -381,100 +368,84 @@ await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("A great hotel"), public async Task ItCanSearchWithVectorizableTextAndFiltersAsync() { // Arrange. - var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); // Act. var filter = new VectorSearchFilter().EqualTo("HotelName", "Hotel 3"); - var actual = await sut.VectorizableTextSearchAsync( + var searchResults = await sut.VectorizableTextSearchAsync( "A hotel with great views.", + top: 3, new() { VectorProperty = r => r.DescriptionEmbedding, OldFilter = filter, - }); + }).ToListAsync(); // Assert. - var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); } [Fact(Skip = SkipReason)] - public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + public async Task ItCanUpsertAndRetrieveUsingTheDynamicMapperAsync() { // Arrange - var options = new AzureAISearchVectorStoreRecordCollectionOptions> + var options = new AzureAISearchVectorStoreRecordCollectionOptions> { VectorStoreRecordDefinition = fixture.VectorStoreRecordDefinition }; - var sut = new AzureAISearchVectorStoreRecordCollection>(fixture.SearchIndexClient, fixture.TestIndexName, options); + var sut = new AzureAISearchVectorStoreRecordCollection>(fixture.SearchIndexClient, fixture.TestIndexName, options); // Act var baseSetGetResult = await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = true }); - var baseSetEmbedding = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("This is a great hotel"); - var genericMapperEmbedding = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("This is a generic mapper hotel"); - var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel("GenericMapper-1") + var baseSetEmbedding = fixture.Embedding; + var dynamicMapperEmbedding = fixture.Embedding; + var upsertResult = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "Description", "This is a generic mapper hotel" }, - { "Tags", new string[] { "generic" } }, - { "ParkingIncluded", false }, - { "LastRenovationDate", new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero) }, - { "Rating", 3.6d } - }, - Vectors = - { - { "DescriptionEmbedding", genericMapperEmbedding } - } + ["HotelId"] = "DynamicMapper-1", + + ["HotelName"] = "Dynamic Mapper Hotel", + ["Description"] = "This is a dynamic mapper hotel", + ["Tags"] = new string[] { "dynamic" }, + ["ParkingIncluded"] = false, + ["LastRenovationDate"] = new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), + ["Rating"] = 3.6d, + + ["DescriptionEmbedding"] = dynamicMapperEmbedding }); - var localGetResult = await sut.GetAsync("GenericMapper-1", new GetRecordOptions { IncludeVectors = true }); + var localGetResult = await sut.GetAsync("DynamicMapper-1", new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(baseSetGetResult); - Assert.Equal("Hotel 1", baseSetGetResult.Data["HotelName"]); - Assert.Equal("This is a great hotel", baseSetGetResult.Data["Description"]); - Assert.Equal(new[] { "pool", "air conditioning", "concierge" }, baseSetGetResult.Data["Tags"]); - Assert.False((bool?)baseSetGetResult.Data["ParkingIncluded"]); - Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), baseSetGetResult.Data["LastRenovationDate"]); - Assert.Equal(3.6d, baseSetGetResult.Data["Rating"]); - Assert.Equal(baseSetEmbedding, (ReadOnlyMemory)baseSetGetResult.Vectors["DescriptionEmbedding"]!); + Assert.Equal("Hotel 1", baseSetGetResult["HotelName"]); + Assert.Equal("This is a great hotel", baseSetGetResult["Description"]); + Assert.Equal(new[] { "pool", "air conditioning", "concierge" }, baseSetGetResult["Tags"]); + Assert.False((bool?)baseSetGetResult["ParkingIncluded"]); + Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), baseSetGetResult["LastRenovationDate"]); + Assert.Equal(3.6d, baseSetGetResult["Rating"]); + Assert.Equal(baseSetEmbedding, (ReadOnlyMemory)baseSetGetResult["DescriptionEmbedding"]!); Assert.NotNull(upsertResult); - Assert.Equal("GenericMapper-1", upsertResult); + Assert.Equal("DynamicMapper-1", upsertResult); Assert.NotNull(localGetResult); - Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); - Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); - Assert.Equal(new[] { "generic" }, localGetResult.Data["Tags"]); - Assert.False((bool?)localGetResult.Data["ParkingIncluded"]); - Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), localGetResult.Data["LastRenovationDate"]); - Assert.Equal(3.6d, localGetResult.Data["Rating"]); - Assert.Equal(genericMapperEmbedding, (ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!); + Assert.Equal("Dynamic Mapper Hotel", localGetResult["HotelName"]); + Assert.Equal("This is a dynamic mapper hotel", localGetResult["Description"]); + Assert.Equal(new[] { "dynamic" }, localGetResult["Tags"]); + Assert.False((bool?)localGetResult["ParkingIncluded"]); + Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), localGetResult["LastRenovationDate"]); + Assert.Equal(3.6d, localGetResult["Rating"]); + Assert.Equal(dynamicMapperEmbedding, (ReadOnlyMemory)localGetResult["DescriptionEmbedding"]!); } - private async Task CreateTestHotelAsync(string hotelId) => new() + private AzureAISearchHotel CreateTestHotel(string hotelId) => new() { HotelId = hotelId, HotelName = $"MyHotel {hotelId}", Description = "My Hotel is great.", - DescriptionEmbedding = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("My hotel is great"), + DescriptionEmbedding = fixture.Embedding, Tags = ["pool", "air conditioning", "concierge"], ParkingIncluded = true, LastRenovationDate = new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), Rating = 3.6 }; - - private sealed class FailingMapper : IVectorStoreRecordMapper - { - public JsonObject MapFromDataToStorageModel(AzureAISearchHotel dataModel) - { - throw new NotImplementedException(); - } - - public AzureAISearchHotel MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options) - { - throw new NotImplementedException(); - } - } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBHotel.cs index 7a8830ea2842..bf933707041c 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBHotel.cs @@ -15,7 +15,7 @@ public class AzureCosmosDBMongoDBHotel public string HotelId { get; init; } /// A string metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string? HotelName { get; set; } /// An int metadata field. @@ -43,6 +43,6 @@ public class AzureCosmosDBMongoDBHotel public DateTime Timestamp { get; set; } /// A vector field. - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance, IndexKind: IndexKind.IvfFlat)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance, IndexKind = IndexKind.IvfFlat)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStoreTests.cs index cc0d1238b95a..c0282c454ee8 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStoreTests.cs @@ -13,6 +13,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.AzureCosmosDBMongoDB; /// /// Integration tests of . /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class AzureCosmosDBMongoDBMemoryStoreTests : IClassFixture { private const string? SkipReason = "Azure CosmosDB Mongo vCore cluster is required"; diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStoreTestsFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStoreTestsFixture.cs index 6854e7e7fdf8..d1fc295e3b1d 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStoreTestsFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStoreTestsFixture.cs @@ -10,6 +10,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.AzureCosmosDBMongoDB; +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class AzureCosmosDBMongoDBMemoryStoreTestsFixture : IAsyncLifetime { public AzureCosmosDBMongoDBMemoryStore MemoryStore { get; } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreFixture.cs index a56f8b41399c..10d52bce99e8 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreFixture.cs @@ -32,10 +32,11 @@ public AzureCosmosDBMongoDBVectorStoreFixture() .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) .AddJsonFile( path: "testsettings.development.json", - optional: false, + optional: true, reloadOnChange: true ) .AddEnvironmentVariables() + .AddUserSecrets() .Build(); var connectionString = GetConnectionString(configuration); @@ -55,7 +56,7 @@ public AzureCosmosDBMongoDBVectorStoreFixture() new VectorStoreRecordDataProperty("Tags", typeof(List)), new VectorStoreRecordDataProperty("Timestamp", typeof(DateTime)), new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.IvfFlat, DistanceFunction = DistanceFunction.CosineDistance } + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), 4) { IndexKind = IndexKind.IvfFlat, DistanceFunction = DistanceFunction.CosineDistance } ] }; } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs index f873991177d3..b52a7fc56b9f 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs @@ -25,7 +25,7 @@ public class AzureCosmosDBMongoDBVectorStoreRecordCollectionTests(AzureCosmosDBM public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) { // Arrange - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, collectionName); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, collectionName); // Act var actual = await sut.CollectionExistsAsync(); @@ -38,13 +38,21 @@ public async Task CollectionExistsReturnsCollectionStateAsync(string collectionN public async Task ItCanCreateCollectionAsync() { // Arrange - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "sk-test-create-collection"); - // Act - await sut.CreateCollectionAsync(); + try + { + // Act + await sut.CreateCollectionAsync(); - // Assert - Assert.True(await sut.CollectionExistsAsync()); + // Assert + Assert.True(await sut.CollectionExistsAsync()); + } + finally + { + // Clean up + await sut.DeleteCollectionAsync(); + } } [Theory(Skip = SkipReason)] @@ -65,7 +73,7 @@ public async Task ItCanCreateCollectionUpsertAndGetAsync(bool includeVectors, bo VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null }; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, collectionName); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, collectionName); var record = this.CreateTestHotel(HotelId); @@ -108,7 +116,7 @@ public async Task ItCanDeleteCollectionAsync() const string TempCollectionName = "temp-test"; await fixture.MongoDatabase.CreateCollectionAsync(TempCollectionName); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, TempCollectionName); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, TempCollectionName); Assert.True(await sut.CollectionExistsAsync()); @@ -124,7 +132,7 @@ public async Task ItCanGetAndDeleteRecordAsync() { // Arrange const string HotelId = "55555555-5555-5555-5555-555555555555"; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); var record = this.CreateTestHotel(HotelId); @@ -151,14 +159,14 @@ public async Task ItCanGetAndDeleteBatchAsync() const string HotelId2 = "22222222-2222-2222-2222-222222222222"; const string HotelId3 = "33333333-3333-3333-3333-333333333333"; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); var record1 = this.CreateTestHotel(HotelId1); var record2 = this.CreateTestHotel(HotelId2); var record3 = this.CreateTestHotel(HotelId3); - var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); - var getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + var upsertResults = await sut.UpsertAsync([record1, record2, record3]); + var getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); Assert.Equal([HotelId1, HotelId2, HotelId3], upsertResults); @@ -167,9 +175,9 @@ public async Task ItCanGetAndDeleteBatchAsync() Assert.NotNull(getResults.First(l => l.HotelId == HotelId3)); // Act - await sut.DeleteBatchAsync([HotelId1, HotelId2, HotelId3]); + await sut.DeleteAsync([HotelId1, HotelId2, HotelId3]); - getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); // Assert Assert.Empty(getResults); @@ -180,7 +188,7 @@ public async Task ItCanUpsertRecordAsync() { // Arrange const string HotelId = "55555555-5555-5555-5555-555555555555"; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); var record = this.CreateTestHotel(HotelId); @@ -218,7 +226,7 @@ public async Task UpsertWithModelWorksCorrectlyAsync() var model = new TestModel { Id = "key", HotelName = "Test Name" }; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( fixture.MongoDatabase, fixture.TestCollection, new() { VectorStoreRecordDefinition = definition }); @@ -241,7 +249,7 @@ public async Task UpsertWithVectorStoreModelWorksCorrectlyAsync() // Arrange var model = new VectorStoreTestModel { HotelId = "key", HotelName = "Test Name" }; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); // Act var upsertResult = await sut.UpsertAsync(model); @@ -270,7 +278,7 @@ public async Task UpsertWithBsonModelWorksCorrectlyAsync() var model = new BsonTestModel { Id = "key", HotelName = "Test Name" }; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection( fixture.MongoDatabase, fixture.TestCollection, new() { VectorStoreRecordDefinition = definition }); @@ -293,7 +301,7 @@ public async Task UpsertWithBsonVectorStoreModelWorksCorrectlyAsync() // Arrange var model = new BsonVectorStoreTestModel { HotelId = "key", HotelName = "Test Name" }; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); // Act var upsertResult = await sut.UpsertAsync(model); @@ -313,7 +321,7 @@ public async Task UpsertWithBsonVectorStoreWithNameModelWorksCorrectlyAsync() // Arrange var model = new BsonVectorStoreWithNameTestModel { Id = "key", HotelName = "Test Name" }; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); // Act var upsertResult = await sut.UpsertAsync(model); @@ -336,17 +344,16 @@ public async Task VectorizedSearchReturnsValidResultsByDefaultAsync() var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearch"); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearch"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f])); + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 3).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var ids = searchResults.Select(l => l.Record.HotelId).ToList(); Assert.Equal("key1", ids[0]); @@ -367,21 +374,19 @@ public async Task VectorizedSearchReturnsValidResultsWithOffsetAsync() var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearchWithOffset"); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearchWithOffset"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 2, new() { - Top = 2, Skip = 2 - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var ids = searchResults.Select(l => l.Record.HotelId).ToList(); Assert.Equal("key3", ids[0]); @@ -400,20 +405,19 @@ public async Task VectorizedSearchReturnsValidResultsWithFilterAsync() var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearchWithOffset"); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearchWithOffset"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 3, new() { OldFilter = new VectorSearchFilter().EqualTo(nameof(AzureCosmosDBMongoDBHotel.HotelName), "My Hotel key2") - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var ids = searchResults.Select(l => l.Record.HotelId).ToList(); Assert.Equal("key2", ids[0]); @@ -424,48 +428,45 @@ public async Task VectorizedSearchReturnsValidResultsWithFilterAsync() } [Fact(Skip = SkipReason)] - public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + public async Task ItCanUpsertAndRetrieveUsingTheDynamicMapperAsync() { // Arrange - var options = new AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions> + var options = new AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions> { VectorStoreRecordDefinition = fixture.HotelVectorStoreRecordDefinition }; - var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection>(fixture.MongoDatabase, fixture.TestCollection, options); + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection>(fixture.MongoDatabase, fixture.TestCollection, options); // Act - var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel("GenericMapper-1") + var upsertResult = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "Description", "This is a generic mapper hotel" }, - { "Tags", new string[] { "generic" } }, - { "ParkingIncluded", false }, - { "Timestamp", new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime() }, - { "HotelRating", 3.6f } - }, - Vectors = - { - { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } - } + ["HotelId"] = "DynamicMapper-1", + + ["HotelName"] = "Dynamic Mapper Hotel", + ["Description"] = "This is a dynamic mapper hotel", + ["Tags"] = new string[] { "dynamic" }, + ["ParkingIncluded"] = false, + ["Timestamp"] = new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime(), + ["HotelRating"] = 3.6f, + + ["DescriptionEmbedding"] = new ReadOnlyMemory([30f, 31f, 32f, 33f]) }); - var localGetResult = await sut.GetAsync("GenericMapper-1", new GetRecordOptions { IncludeVectors = true }); + var localGetResult = await sut.GetAsync("DynamicMapper-1", new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(upsertResult); - Assert.Equal("GenericMapper-1", upsertResult); + Assert.Equal("DynamicMapper-1", upsertResult); Assert.NotNull(localGetResult); - Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); - Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); - Assert.Equal(new[] { "generic" }, localGetResult.Data["Tags"]); - Assert.False((bool?)localGetResult.Data["ParkingIncluded"]); - Assert.Equal(new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime(), localGetResult.Data["Timestamp"]); - Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); - Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + Assert.Equal("Dynamic Mapper Hotel", localGetResult["HotelName"]); + Assert.Equal("This is a dynamic mapper hotel", localGetResult["Description"]); + Assert.Equal(new[] { "dynamic" }, localGetResult["Tags"]); + Assert.False((bool?)localGetResult["ParkingIncluded"]); + Assert.Equal(new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime(), localGetResult["Timestamp"]); + Assert.Equal(3.6f, localGetResult["HotelRating"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult["DescriptionEmbedding"]!).ToArray()); } #region private diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLHotel.cs index e7d353486504..49b1ac8da6b2 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLHotel.cs @@ -16,11 +16,11 @@ public record AzureCosmosDBNoSQLHotel() public string HotelId { get; init; } /// A string metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string? HotelName { get; set; } /// An int metadata field. - [VectorStoreRecordData(IsFullTextSearchable = true)] + [VectorStoreRecordData(IsFullTextIndexed = true)] public int HotelCode { get; set; } /// A float metadata field. @@ -45,6 +45,6 @@ public record AzureCosmosDBNoSQLHotel() public DateTimeOffset Timestamp { get; set; } /// A vector field. - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineSimilarity, IndexKind: IndexKind.Flat)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineSimilarity, IndexKind = IndexKind.Flat)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs index e75116e34893..fcd910889785 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs @@ -18,6 +18,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.AzureCosmosDBNoSQL; /// /// Integration tests of . /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class AzureCosmosDBNoSQLMemoryStoreTests : IClassFixture { private const string? SkipReason = "Azure Cosmos DB Account with Vector indexing enabled required"; diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs index 7e6f376a8684..e8bbecd47533 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs @@ -9,6 +9,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.AzureCosmosDBNoSQL; +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class AzureCosmosDBNoSQLMemoryStoreTestsFixture : IAsyncLifetime { public AzureCosmosDBNoSQLMemoryStore MemoryStore { get; } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreFixture.cs index 1af8bbbe6863..a5b2fddc729e 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreFixture.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Net.Http; using System.Text.Json; using System.Threading.Tasks; using Microsoft.Azure.Cosmos; @@ -27,7 +28,14 @@ public AzureCosmosDBNoSQLVectorStoreFixture() throw new ArgumentNullException($"{connectionString} string is not configured"); } - var options = new CosmosClientOptions { UseSystemTextJsonSerializerWithOptions = JsonSerializerOptions.Default }; + var options = new CosmosClientOptions + { + UseSystemTextJsonSerializerWithOptions = JsonSerializerOptions.Default, + ConnectionMode = ConnectionMode.Gateway, +#pragma warning disable CA5400 // HttpClient may be created without enabling CheckCertificateRevocationList + HttpClientFactory = () => new HttpClient(new HttpClientHandler { ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator }) +#pragma warning restore CA5400 // HttpClient may be created without enabling CheckCertificateRevocationList + }; this._cosmosClient = new CosmosClient(connectionString, options); } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs index 546b957c68ae..462818320dd8 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs @@ -7,27 +7,28 @@ using Microsoft.Azure.Cosmos; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; +using SemanticKernel.IntegrationTests.Connectors.Memory.Xunit; using Xunit; using DistanceFunction = Microsoft.Extensions.VectorData.DistanceFunction; using IndexKind = Microsoft.Extensions.VectorData.IndexKind; namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureCosmosDBNoSQL; +#pragma warning disable CA1859 // Use concrete types when possible for improved performance #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// -/// Integration tests for class. +/// Integration tests for class. /// [Collection("AzureCosmosDBNoSQLVectorStoreCollection")] +[AzureCosmosDBNoSQLConnectionStringSetCondition] public sealed class AzureCosmosDBNoSQLVectorStoreRecordCollectionTests(AzureCosmosDBNoSQLVectorStoreFixture fixture) { - private const string? SkipReason = "Azure CosmosDB NoSQL cluster is required"; - - [Fact(Skip = SkipReason)] + [VectorStoreFact] public async Task ItCanCreateCollectionAsync() { // Arrange - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "test-create-collection"); + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "test-create-collection"); // Act await sut.CreateCollectionAsync(); @@ -36,13 +37,13 @@ public async Task ItCanCreateCollectionAsync() Assert.True(await sut.CollectionExistsAsync()); } - [Theory(Skip = SkipReason)] + [VectorStoreTheory] [InlineData("sk-test-hotels", true)] [InlineData("nonexistentcollection", false)] public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) { // Arrange - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, collectionName); + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, collectionName); if (expectedExists) { @@ -56,7 +57,7 @@ public async Task CollectionExistsReturnsCollectionStateAsync(string collectionN Assert.Equal(expectedExists, actual); } - [Theory(Skip = SkipReason)] + [VectorStoreTheory] [InlineData(true, true)] [InlineData(true, false)] [InlineData(false, true)] @@ -75,7 +76,7 @@ public async Task ItCanCreateCollectionUpsertAndGetAsync(bool includeVectors, bo VectorStoreRecordDefinition = useRecordDefinition ? this.GetTestHotelRecordDefinition() : null }; - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, collectionName); + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, collectionName); var record = this.CreateTestHotel(HotelId); @@ -111,14 +112,14 @@ public async Task ItCanCreateCollectionUpsertAndGetAsync(bool includeVectors, bo } } - [Fact(Skip = SkipReason)] + [VectorStoreFact] public async Task ItCanDeleteCollectionAsync() { // Arrange const string TempCollectionName = "test-delete-collection"; await fixture.Database!.CreateContainerAsync(new ContainerProperties(TempCollectionName, "/id")); - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, TempCollectionName); + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, TempCollectionName); Assert.True(await sut.CollectionExistsAsync()); @@ -129,7 +130,7 @@ public async Task ItCanDeleteCollectionAsync() Assert.False(await sut.CollectionExistsAsync()); } - [Theory(Skip = SkipReason)] + [VectorStoreTheory] [InlineData("consistent-mode-collection", IndexingMode.Consistent)] [InlineData("lazy-mode-collection", IndexingMode.Lazy)] [InlineData("none-mode-collection", IndexingMode.None)] @@ -137,7 +138,7 @@ public async Task ItCanGetAndDeleteRecordAsync(string collectionName, IndexingMo { // Arrange const string HotelId = "55555555-5555-5555-5555-555555555555"; - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection( fixture.Database!, collectionName, new() { IndexingMode = indexingMode, Automatic = indexingMode != IndexingMode.None }); @@ -161,7 +162,7 @@ public async Task ItCanGetAndDeleteRecordAsync(string collectionName, IndexingMo Assert.Null(getResult); } - [Fact(Skip = SkipReason)] + [VectorStoreFact] public async Task ItCanGetAndDeleteRecordWithPartitionKeyAsync() { // Arrange @@ -169,7 +170,7 @@ public async Task ItCanGetAndDeleteRecordWithPartitionKeyAsync() const string HotelName = "Test Hotel Name"; IVectorStoreRecordCollection sut = - new AzureCosmosDBNoSQLVectorStoreRecordCollection( + new AzureCosmosDBNoSQLVectorStoreRecordCollection( fixture.Database!, "delete-with-partition-key", new() { PartitionKeyPropertyName = "HotelName" }); @@ -196,7 +197,7 @@ public async Task ItCanGetAndDeleteRecordWithPartitionKeyAsync() Assert.Null(getResult); } - [Fact(Skip = SkipReason)] + [VectorStoreFact] public async Task ItCanGetAndDeleteBatchAsync() { // Arrange @@ -204,7 +205,7 @@ public async Task ItCanGetAndDeleteBatchAsync() const string HotelId2 = "22222222-2222-2222-2222-222222222222"; const string HotelId3 = "33333333-3333-3333-3333-333333333333"; - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "get-and-delete-batch"); + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "get-and-delete-batch"); await sut.CreateCollectionAsync(); @@ -212,8 +213,8 @@ public async Task ItCanGetAndDeleteBatchAsync() var record2 = this.CreateTestHotel(HotelId2); var record3 = this.CreateTestHotel(HotelId3); - var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); - var getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + var upsertResults = await sut.UpsertAsync([record1, record2, record3]); + var getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); Assert.Equal([HotelId1, HotelId2, HotelId3], upsertResults); @@ -222,20 +223,20 @@ public async Task ItCanGetAndDeleteBatchAsync() Assert.NotNull(getResults.First(l => l.HotelId == HotelId3)); // Act - await sut.DeleteBatchAsync([HotelId1, HotelId2, HotelId3]); + await sut.DeleteAsync([HotelId1, HotelId2, HotelId3]); - getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); // Assert Assert.Empty(getResults); } - [Fact(Skip = SkipReason)] + [VectorStoreFact] public async Task ItCanUpsertRecordAsync() { // Arrange const string HotelId = "55555555-5555-5555-5555-555555555555"; - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "upsert-record"); + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "upsert-record"); await sut.CreateCollectionAsync(); @@ -260,7 +261,7 @@ public async Task ItCanUpsertRecordAsync() Assert.Equal(10, getResult.HotelRating); } - [Fact(Skip = SkipReason)] + [VectorStoreFact] public async Task VectorizedSearchReturnsValidResultsByDefaultAsync() { // Arrange @@ -269,17 +270,16 @@ public async Task VectorizedSearchReturnsValidResultsByDefaultAsync() var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "vector-search-default"); + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "vector-search-default"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f])); + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 3).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var ids = searchResults.Select(l => l.Record.HotelId).ToList(); Assert.Equal("key1", ids[0]); @@ -291,7 +291,7 @@ public async Task VectorizedSearchReturnsValidResultsByDefaultAsync() Assert.Equal(1, searchResults.First(l => l.Record.HotelId == "key1").Score); } - [Fact(Skip = SkipReason)] + [VectorStoreFact] public async Task VectorizedSearchReturnsValidResultsWithOffsetAsync() { // Arrange @@ -300,21 +300,19 @@ public async Task VectorizedSearchReturnsValidResultsWithOffsetAsync() var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "vector-search-with-offset"); + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "vector-search-with-offset"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 2, new() { - Top = 2, Skip = 2 - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var ids = searchResults.Select(l => l.Record.HotelId).ToList(); Assert.Equal("key3", ids[0]); @@ -324,7 +322,7 @@ public async Task VectorizedSearchReturnsValidResultsWithOffsetAsync() Assert.DoesNotContain("key2", ids); } - [Theory(Skip = SkipReason)] + [VectorStoreTheory] [MemberData(nameof(VectorizedSearchWithFilterData))] public async Task VectorizedSearchReturnsValidResultsWithFilterAsync(VectorSearchFilter filter, List expectedIds) { @@ -334,72 +332,69 @@ public async Task VectorizedSearchReturnsValidResultsWithFilterAsync(VectorSearc var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "vector-search-with-filter"); + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection(fixture.Database!, "vector-search-with-filter"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 4, new() { OldFilter = filter, - Top = 4, - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var actualIds = searchResults.Select(l => l.Record.HotelId).ToList(); Assert.Equal(expectedIds, actualIds); } - [Fact(Skip = SkipReason)] - public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + [VectorStoreFact] + public async Task ItCanUpsertAndRetrieveUsingTheDynamicMapperAsync() { // Arrange const string HotelId = "55555555-5555-5555-5555-555555555555"; - var options = new AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions> + var options = new AzureCosmosDBNoSQLVectorStoreRecordCollectionOptions> { VectorStoreRecordDefinition = this.GetTestHotelRecordDefinition() }; - var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection>(fixture.Database!, "generic-mapper", options); + var sut = new AzureCosmosDBNoSQLVectorStoreRecordCollection>(fixture.Database!, "dynamic-mapper", options); await sut.CreateCollectionAsync(); // Act - var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) + var upsertResult = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "Description", "This is a generic mapper hotel" }, - { "Tags", new List { "generic" } }, - { "parking_is_included", false }, - { "Timestamp", new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero) }, - { "HotelRating", 3.6f } - }, - Vectors = - { - { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } - } + ["HotelId"] = HotelId, + + ["HotelName"] = "Dynamic Mapper Hotel", + ["Description"] = "This is a dynamic mapper hotel", + ["Tags"] = new List { "dynamic" }, + ["parking_is_included"] = false, + ["Timestamp"] = new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), + ["HotelRating"] = 3.6f, + + ["DescriptionEmbedding"] = new ReadOnlyMemory([30f, 31f, 32f, 33f]) }); var localGetResult = await sut.GetAsync(HotelId, new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(upsertResult); - Assert.Equal(HotelId, upsertResult); + var upsertCompositeKey = (AzureCosmosDBNoSQLCompositeKey)upsertResult; + Assert.Equal(HotelId, upsertCompositeKey.PartitionKey); + Assert.Equal(HotelId, upsertCompositeKey.RecordKey); Assert.NotNull(localGetResult); - Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); - Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); - Assert.Equal(new List { "generic" }, localGetResult.Data["Tags"]); - Assert.False((bool?)localGetResult.Data["parking_is_included"]); - Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), localGetResult.Data["Timestamp"]); - Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); - Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + Assert.Equal("Dynamic Mapper Hotel", localGetResult["HotelName"]); + Assert.Equal("This is a dynamic mapper hotel", localGetResult["Description"]); + Assert.Equal(new List { "dynamic" }, localGetResult["Tags"]); + Assert.False((bool?)localGetResult["parking_is_included"]); + Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), localGetResult["Timestamp"]); + Assert.Equal(3.6f, localGetResult["HotelRating"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult["DescriptionEmbedding"]!).ToArray()); } public static TheoryData> VectorizedSearchWithFilterData => new() @@ -463,7 +458,7 @@ private VectorStoreRecordDefinition GetTestHotelRecordDefinition() new VectorStoreRecordDataProperty("Tags", typeof(List)), new VectorStoreRecordDataProperty("Description", typeof(string)), new VectorStoreRecordDataProperty("Timestamp", typeof(DateTimeOffset)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.Flat, DistanceFunction = DistanceFunction.CosineSimilarity } + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), 4) { IndexKind = IndexKind.Flat, DistanceFunction = DistanceFunction.CosineSimilarity } ] }; } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreTests.cs index 78c87350c23d..da92d728b31b 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreTests.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; -using SemanticKernel.IntegrationTests.Connectors.Memory.Xunit; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureCosmosDBNoSQL; @@ -10,7 +9,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureCosmosDBNoSQL; /// Integration tests for . /// [Collection("AzureCosmosDBNoSQLVectorStoreCollection")] -[DisableVectorStoreTests(Skip = "Azure CosmosDB NoSQL cluster is required")] +[AzureCosmosDBNoSQLConnectionStringSetCondition] public sealed class AzureCosmosDBNoSQLVectorStoreTests(AzureCosmosDBNoSQLVectorStoreFixture fixture) : BaseVectorStoreTests(new AzureCosmosDBNoSQLVectorStore(fixture.Database!)) { diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreRecordCollectionTests.cs index cf29e88625ab..74590628bdba 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreRecordCollectionTests.cs @@ -24,7 +24,7 @@ public abstract class BaseVectorStoreRecordCollectionTests protected abstract HashSet GetSupportedDistanceFunctions(); - protected abstract IVectorStoreRecordCollection GetTargetRecordCollection(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition); + protected abstract IVectorStoreRecordCollection GetTargetRecordCollection(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition) where TRecord : notnull; protected virtual int DelayAfterIndexCreateInMilliseconds { get; } = 0; @@ -93,14 +93,13 @@ public async Task VectorSearchShouldReturnExpectedScoresAsync(string distanceFun Vector = orthogonalVector, }; - await sut.UpsertBatchAsync([baseRecord, oppositeRecord, orthogonalRecord]).ToListAsync(); + await sut.UpsertAsync([baseRecord, oppositeRecord, orthogonalRecord]); await Task.Delay(this.DelayAfterUploadInMilliseconds); // Act - var searchResult = await sut.VectorizedSearchAsync(baseVector); + var results = await sut.SearchEmbeddingAsync(baseVector, top: 3).ToListAsync(); // Assert - var results = await searchResult.Results.ToListAsync(); Assert.Equal(3, results.Count); Assert.Equal(keyDictionary[resultOrder[0]], results[0].Record.Key); @@ -123,7 +122,7 @@ private static VectorStoreRecordDefinition CreateKeyWithVectorRecordDefinition(i Properties = [ new VectorStoreRecordKeyProperty("Key", typeof(TKey)), - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { Dimensions = vectorDimensions, DistanceFunction = distanceFunction }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), vectorDimensions) { DistanceFunction = distanceFunction }, ], }; diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreTests.cs index 3a9aff375be3..607a847e9028 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreTests.cs @@ -14,6 +14,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory; /// public abstract class BaseVectorStoreTests(IVectorStore vectorStore) where TKey : notnull + where TRecord : notnull { protected virtual IEnumerable CollectionNames => ["listcollectionnames1", "listcollectionnames2", "listcollectionnames3"]; diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Chroma/ChromaMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Chroma/ChromaMemoryStoreTests.cs index d337641ad071..770400778817 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Chroma/ChromaMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Chroma/ChromaMemoryStoreTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net.Http; using System.Threading.Tasks; @@ -16,6 +17,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Chroma; /// Integration tests for class. /// Tests work with local Chroma server. To setup the server, see dotnet/src/Connectors/Connectors.Memory.Chroma/README.md. /// +[Experimental("SKEXP0020")] public sealed class ChromaMemoryStoreTests : IDisposable { // If null, all tests will be enabled diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/InMemory/InMemoryVectorStoreTextSearchTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/InMemory/InMemoryVectorStoreTextSearchTests.cs index 27d585041e24..2bfc57eb56fc 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/InMemory/InMemoryVectorStoreTextSearchTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/InMemory/InMemoryVectorStoreTextSearchTests.cs @@ -50,7 +50,10 @@ static DataModel CreateRecord(int index, string text, ReadOnlyMemory embe var stringMapper = new DataModelTextSearchStringMapper(); var resultMapper = new DataModelTextSearchResultMapper(); + // TODO: Once OpenAITextEmbeddingGenerationService implements MEAI's IEmbeddingGenerator (#10811), configure it with the InMemoryVectorStore above instead of passing it here. +#pragma warning disable CS0618 // VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete return new VectorStoreTextSearch(vectorSearch, this.EmbeddingGenerator!, stringMapper, resultMapper); +#pragma warning restore CS0618 } /// diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Milvus/MilvusMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Milvus/MilvusMemoryStoreTests.cs index 5fba220a3ad4..4ee21728816a 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Milvus/MilvusMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Milvus/MilvusMemoryStoreTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading.Tasks; using Microsoft.SemanticKernel.Connectors.Milvus; @@ -11,6 +12,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Milvus; +[Experimental("SKEXP0020")] public class MilvusMemoryStoreTests(MilvusFixture milvusFixture) : IClassFixture, IAsyncLifetime { private const string CollectionName = "test"; diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBHotel.cs index b3adb2e723a1..a5b3fd3a09e9 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBHotel.cs @@ -15,7 +15,7 @@ public class MongoDBHotel public string HotelId { get; init; } /// A string metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string? HotelName { get; set; } /// An int metadata field. @@ -43,6 +43,6 @@ public class MongoDBHotel public DateTime Timestamp { get; set; } /// A vector field. - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineSimilarity, IndexKind: IndexKind.IvfFlat)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineSimilarity, IndexKind = IndexKind.IvfFlat)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBMemoryStoreTests.cs index 6f4c834ecf7c..f744484316ef 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBMemoryStoreTests.cs @@ -13,6 +13,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; /// /// Integration tests of . /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class MongoDBMemoryStoreTests(MongoDBMemoryStoreTestsFixture fixture) : IClassFixture { // If null, all tests will be enabled diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBMemoryStoreTestsFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBMemoryStoreTestsFixture.cs index f96acb8fd77b..ec678690d09f 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBMemoryStoreTestsFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBMemoryStoreTestsFixture.cs @@ -12,6 +12,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class MongoDBMemoryStoreTestsFixture : IAsyncLifetime { #pragma warning disable CA1859 // Use concrete types when possible for improved performance diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreFixture.cs index 3d975dffbdf3..edb37e83509c 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreFixture.cs @@ -3,42 +3,41 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; -using Docker.DotNet; -using Docker.DotNet.Models; using Microsoft.Extensions.VectorData; using MongoDB.Driver; +using Testcontainers.MongoDb; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + public class MongoDBVectorStoreFixture : IAsyncLifetime { + private readonly MongoDbContainer _container = new MongoDbBuilder() + .WithImage("mongodb/mongodb-atlas-local:7.0.6") + .Build(); + private readonly List _testCollections = ["sk-test-hotels", "sk-test-contacts", "sk-test-addresses"]; /// Main test collection for tests. public string TestCollection => this._testCollections[0]; /// that can be used to manage the collections in MongoDB. - public IMongoDatabase MongoDatabase { get; } + public IMongoDatabase MongoDatabase { get; private set; } /// Gets the manually created vector store record definition for MongoDB test model. public VectorStoreRecordDefinition HotelVectorStoreRecordDefinition { get; private set; } - /// The id of the MongoDB container that we are testing with. - private string? _containerId = null; - - /// The Docker client we are using to create a MongoDB container with. - private readonly DockerClient _client; - - /// - /// Initializes a new instance of the class. - /// - public MongoDBVectorStoreFixture() + public async Task InitializeAsync() { - using var dockerClientConfiguration = new DockerClientConfiguration(); - this._client = dockerClientConfiguration.CreateClient(); + await this._container.StartAsync(); - var mongoClient = new MongoClient("mongodb://localhost:27017/?directConnection=true"); + var mongoClient = new MongoClient(new MongoClientSettings + { + Server = new MongoServerAddress(this._container.Hostname, this._container.GetMappedPublicPort(MongoDbBuilder.MongoDbPort)), + DirectConnection = true, + }); this.MongoDatabase = mongoClient.GetDatabase("test"); @@ -54,14 +53,9 @@ public MongoDBVectorStoreFixture() new VectorStoreRecordDataProperty("Tags", typeof(List)), new VectorStoreRecordDataProperty("Timestamp", typeof(DateTime)), new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.IvfFlat, DistanceFunction = DistanceFunction.CosineDistance } + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), 4) { IndexKind = IndexKind.IvfFlat, DistanceFunction = DistanceFunction.CosineSimilarity } ] }; - } - - public async Task InitializeAsync() - { - this._containerId = await SetupMongoDBContainerAsync(this._client); foreach (var collection in this._testCollections) { @@ -81,52 +75,6 @@ public async Task DisposeAsync() } } - if (this._containerId != null) - { - await this._client.Containers.StopContainerAsync(this._containerId, new ContainerStopParameters()); - await this._client.Containers.RemoveContainerAsync(this._containerId, new ContainerRemoveParameters()); - } + await this._container.StopAsync(); } - - #region private - - private static async Task SetupMongoDBContainerAsync(DockerClient client) - { - const string Image = "mongodb/mongodb-atlas-local"; - const string Tag = "latest"; - - await client.Images.CreateImageAsync( - new ImagesCreateParameters - { - FromImage = Image, - Tag = Tag, - }, - null, - new Progress()); - - var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() - { - Image = $"{Image}:{Tag}", - HostConfig = new HostConfig() - { - PortBindings = new Dictionary> - { - { "27017", new List { new() { HostPort = "27017" } } }, - }, - PublishAllPorts = true - }, - ExposedPorts = new Dictionary - { - { "27017", default }, - }, - }); - - await client.Containers.StartContainerAsync( - container.ID, - new ContainerStartParameters()); - - return container.ID; - } - - #endregion } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs index c8cab7cb477e..a5ce7239ac78 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs @@ -8,6 +8,8 @@ using Microsoft.SemanticKernel.Connectors.MongoDB; using MongoDB.Bson; using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver; +using xRetry; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; @@ -18,15 +20,15 @@ namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; public class MongoDBVectorStoreRecordCollectionTests(MongoDBVectorStoreFixture fixture) { // If null, all tests will be enabled - private const string? SkipReason = "The tests are for manual verification."; + private const string? SkipReason = null; - [Theory(Skip = SkipReason)] + [RetryTheory(typeof(MongoCommandException), Skip = SkipReason)] [InlineData("sk-test-hotels", true)] [InlineData("nonexistentcollection", false)] public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) { // Arrange - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, collectionName); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, collectionName); // Act var actual = await sut.CollectionExistsAsync(); @@ -35,20 +37,22 @@ public async Task CollectionExistsReturnsCollectionStateAsync(string collectionN Assert.Equal(expectedExists, actual); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task ItCanCreateCollectionAsync() { // Arrange - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var newCollection = Guid.NewGuid().ToString(); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, newCollection); // Act await sut.CreateCollectionAsync(); // Assert Assert.True(await sut.CollectionExistsAsync()); + await sut.DeleteCollectionAsync(); } - [Theory(Skip = SkipReason)] + [RetryTheory(typeof(MongoCommandException), Skip = SkipReason)] [InlineData(true, true)] [InlineData(true, false)] [InlineData(false, true)] @@ -59,6 +63,7 @@ public async Task ItCanCreateCollectionUpsertAndGetAsync(bool includeVectors, bo const string HotelId = "55555555-5555-5555-5555-555555555555"; var collectionNamePostfix = useRecordDefinition ? "with-definition" : "with-type"; + collectionNamePostfix += includeVectors ? "-with-vectors" : "-without-vectors"; var collectionName = $"collection-{collectionNamePostfix}"; var options = new MongoDBVectorStoreRecordCollectionOptions @@ -66,7 +71,7 @@ public async Task ItCanCreateCollectionUpsertAndGetAsync(bool includeVectors, bo VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null }; - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, collectionName); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, collectionName, options); var record = this.CreateTestHotel(HotelId); @@ -102,14 +107,14 @@ public async Task ItCanCreateCollectionUpsertAndGetAsync(bool includeVectors, bo } } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task ItCanDeleteCollectionAsync() { // Arrange const string TempCollectionName = "temp-test"; await fixture.MongoDatabase.CreateCollectionAsync(TempCollectionName); - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, TempCollectionName); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, TempCollectionName); Assert.True(await sut.CollectionExistsAsync()); @@ -120,12 +125,12 @@ public async Task ItCanDeleteCollectionAsync() Assert.False(await sut.CollectionExistsAsync()); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task ItCanGetAndDeleteRecordAsync() { // Arrange const string HotelId = "55555555-5555-5555-5555-555555555555"; - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); var record = this.CreateTestHotel(HotelId); @@ -144,7 +149,7 @@ public async Task ItCanGetAndDeleteRecordAsync() Assert.Null(getResult); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task ItCanGetAndDeleteBatchAsync() { // Arrange @@ -152,14 +157,14 @@ public async Task ItCanGetAndDeleteBatchAsync() const string HotelId2 = "22222222-2222-2222-2222-222222222222"; const string HotelId3 = "33333333-3333-3333-3333-333333333333"; - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); var record1 = this.CreateTestHotel(HotelId1); var record2 = this.CreateTestHotel(HotelId2); var record3 = this.CreateTestHotel(HotelId3); - var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); - var getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + var upsertResults = await sut.UpsertAsync([record1, record2, record3]); + var getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); Assert.Equal([HotelId1, HotelId2, HotelId3], upsertResults); @@ -168,20 +173,20 @@ public async Task ItCanGetAndDeleteBatchAsync() Assert.NotNull(getResults.First(l => l.HotelId == HotelId3)); // Act - await sut.DeleteBatchAsync([HotelId1, HotelId2, HotelId3]); + await sut.DeleteAsync([HotelId1, HotelId2, HotelId3]); - getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); // Assert Assert.Empty(getResults); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task ItCanUpsertRecordAsync() { // Arrange const string HotelId = "55555555-5555-5555-5555-555555555555"; - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); var record = this.CreateTestHotel(HotelId); @@ -204,7 +209,7 @@ public async Task ItCanUpsertRecordAsync() Assert.Equal(10, getResult.HotelRating); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task UpsertWithModelWorksCorrectlyAsync() { // Arrange @@ -219,7 +224,7 @@ public async Task UpsertWithModelWorksCorrectlyAsync() var model = new TestModel { Id = "key", HotelName = "Test Name" }; - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( fixture.MongoDatabase, fixture.TestCollection, new() { VectorStoreRecordDefinition = definition }); @@ -236,13 +241,13 @@ public async Task UpsertWithModelWorksCorrectlyAsync() Assert.Equal("Test Name", getResult.HotelName); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task UpsertWithVectorStoreModelWorksCorrectlyAsync() { // Arrange var model = new VectorStoreTestModel { HotelId = "key", HotelName = "Test Name" }; - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); // Act var upsertResult = await sut.UpsertAsync(model); @@ -256,7 +261,7 @@ public async Task UpsertWithVectorStoreModelWorksCorrectlyAsync() Assert.Equal("Test Name", getResult.HotelName); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task UpsertWithBsonModelWorksCorrectlyAsync() { // Arrange @@ -271,7 +276,7 @@ public async Task UpsertWithBsonModelWorksCorrectlyAsync() var model = new BsonTestModel { Id = "key", HotelName = "Test Name" }; - var sut = new MongoDBVectorStoreRecordCollection( + var sut = new MongoDBVectorStoreRecordCollection( fixture.MongoDatabase, fixture.TestCollection, new() { VectorStoreRecordDefinition = definition }); @@ -288,13 +293,13 @@ public async Task UpsertWithBsonModelWorksCorrectlyAsync() Assert.Equal("Test Name", getResult.HotelName); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task UpsertWithBsonVectorStoreModelWorksCorrectlyAsync() { // Arrange var model = new BsonVectorStoreTestModel { HotelId = "key", HotelName = "Test Name" }; - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); // Act var upsertResult = await sut.UpsertAsync(model); @@ -308,13 +313,13 @@ public async Task UpsertWithBsonVectorStoreModelWorksCorrectlyAsync() Assert.Equal("Test Name", getResult.HotelName); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task UpsertWithBsonVectorStoreWithNameModelWorksCorrectlyAsync() { // Arrange var model = new BsonVectorStoreWithNameTestModel { Id = "key", HotelName = "Test Name" }; - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); // Act var upsertResult = await sut.UpsertAsync(model); @@ -328,7 +333,7 @@ public async Task UpsertWithBsonVectorStoreWithNameModelWorksCorrectlyAsync() Assert.Equal("Test Name", getResult.HotelName); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task VectorizedSearchReturnsValidResultsByDefaultAsync() { // Arrange @@ -337,17 +342,16 @@ public async Task VectorizedSearchReturnsValidResultsByDefaultAsync() var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearch"); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearch"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f])); + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 3).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var ids = searchResults.Select(l => l.Record.HotelId).ToList(); Assert.Equal("key1", ids[0]); @@ -359,7 +363,7 @@ public async Task VectorizedSearchReturnsValidResultsByDefaultAsync() Assert.Equal(1, searchResults.First(l => l.Record.HotelId == "key1").Score); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task VectorizedSearchReturnsValidResultsWithOffsetAsync() { // Arrange @@ -368,21 +372,19 @@ public async Task VectorizedSearchReturnsValidResultsWithOffsetAsync() var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearchWithOffset"); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearchWithOffset"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 2, new() { - Top = 2, Skip = 2 - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var ids = searchResults.Select(l => l.Record.HotelId).ToList(); Assert.Equal("key3", ids[0]); @@ -392,7 +394,7 @@ public async Task VectorizedSearchReturnsValidResultsWithOffsetAsync() Assert.DoesNotContain("key2", ids); } - [Fact(Skip = SkipReason)] + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] public async Task VectorizedSearchReturnsValidResultsWithFilterAsync() { // Arrange @@ -401,20 +403,19 @@ public async Task VectorizedSearchReturnsValidResultsWithFilterAsync() var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearchWithOffset"); + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearchWithOffset"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 3, new() { OldFilter = new VectorSearchFilter().EqualTo(nameof(MongoDBHotel.HotelName), "My Hotel key2") - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var ids = searchResults.Select(l => l.Record.HotelId).ToList(); Assert.Equal("key2", ids[0]); @@ -424,49 +425,46 @@ public async Task VectorizedSearchReturnsValidResultsWithFilterAsync() Assert.DoesNotContain("key4", ids); } - [Fact(Skip = SkipReason)] - public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + [RetryFact(typeof(MongoCommandException), Skip = SkipReason)] + public async Task ItCanUpsertAndRetrieveUsingTheDynamicMapperAsync() { // Arrange - var options = new MongoDBVectorStoreRecordCollectionOptions> + var options = new MongoDBVectorStoreRecordCollectionOptions> { VectorStoreRecordDefinition = fixture.HotelVectorStoreRecordDefinition }; - var sut = new MongoDBVectorStoreRecordCollection>(fixture.MongoDatabase, fixture.TestCollection, options); + var sut = new MongoDBVectorStoreRecordCollection>(fixture.MongoDatabase, fixture.TestCollection, options); // Act - var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel("GenericMapper-1") + var upsertResult = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "Description", "This is a generic mapper hotel" }, - { "Tags", new string[] { "generic" } }, - { "ParkingIncluded", false }, - { "Timestamp", new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime() }, - { "HotelRating", 3.6f } - }, - Vectors = - { - { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } - } + ["HotelId"] = "DynamicMapper-1", + + ["HotelName"] = "Dynamic Mapper Hotel", + ["Description"] = "This is a dynamic mapper hotel", + ["Tags"] = new string[] { "dynamic" }, + ["ParkingIncluded"] = false, + ["Timestamp"] = new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime(), + ["HotelRating"] = 3.6f, + + ["DescriptionEmbedding"] = new ReadOnlyMemory([30f, 31f, 32f, 33f]) }); - var localGetResult = await sut.GetAsync("GenericMapper-1", new GetRecordOptions { IncludeVectors = true }); + var localGetResult = await sut.GetAsync("DynamicMapper-1", new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(upsertResult); - Assert.Equal("GenericMapper-1", upsertResult); + Assert.Equal("DynamicMapper-1", upsertResult); Assert.NotNull(localGetResult); - Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); - Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); - Assert.Equal(new[] { "generic" }, localGetResult.Data["Tags"]); - Assert.False((bool?)localGetResult.Data["ParkingIncluded"]); - Assert.Equal(new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime(), localGetResult.Data["Timestamp"]); - Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); - Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + Assert.Equal("Dynamic Mapper Hotel", localGetResult["HotelName"]); + Assert.Equal("This is a dynamic mapper hotel", localGetResult["Description"]); + Assert.Equal(new[] { "dynamic" }, localGetResult["Tags"]); + Assert.False((bool?)localGetResult["ParkingIncluded"]); + Assert.Equal(new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime(), localGetResult["Timestamp"]); + Assert.Equal(3.6f, localGetResult["HotelRating"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult["DescriptionEmbedding"]!).ToArray()); } #region private diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreTests.cs index fd6cd229091d..7673800714d1 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreTests.cs @@ -2,13 +2,11 @@ using Microsoft.SemanticKernel.Connectors.MongoDB; using SemanticKernel.IntegrationTests.Connectors.Memory; -using SemanticKernel.IntegrationTests.Connectors.Memory.Xunit; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; [Collection("MongoDBVectorStoreCollection")] -[DisableVectorStoreTests(Skip = "The tests are for manual verification.")] public class MongoDBVectorStoreTests(MongoDBVectorStoreFixture fixture) : BaseVectorStoreTests(new MongoDBVectorStore(fixture.MongoDatabase)) { diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs index 48a8f5f36a41..1ab2619aa869 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs @@ -44,7 +44,7 @@ public record PostgresHotel() public string Description { get; set; } /// A vector field. - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.EuclideanDistance, IndexKind: IndexKind.Hnsw)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.EuclideanDistance, IndexKind = IndexKind.Hnsw)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } public DateTime CreatedAt { get; set; } = DateTime.UtcNow; diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs index 71474ff0ebc6..a9fe0532b895 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs @@ -16,6 +16,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; /// /// Integration tests of . /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class PostgresMemoryStoreTests : IAsyncLifetime { // If null, all tests will be enabled diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index a80519c85a57..8a2200ae6bc3 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -192,8 +192,8 @@ public async Task ItCanGetUpsertDeleteBatchAsync() var record2 = new PostgresHotel { HotelId = HotelId2, HotelName = "Hotel 2", HotelCode = 1, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag3"] }; var record3 = new PostgresHotel { HotelId = HotelId3, HotelName = "Hotel 3", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag4"] }; - var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); - var getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + var upsertResults = await sut.UpsertAsync([record1, record2, record3]); + var getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); Assert.Equal([HotelId1, HotelId2, HotelId3], upsertResults); @@ -202,9 +202,9 @@ public async Task ItCanGetUpsertDeleteBatchAsync() Assert.NotNull(getResults.First(l => l.HotelId == HotelId3)); // Act - await sut.DeleteBatchAsync([HotelId1, HotelId2, HotelId3]); + await sut.DeleteAsync([HotelId1, HotelId2, HotelId3]); - getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); // Assert Assert.Empty(getResults); @@ -281,31 +281,28 @@ public async Task ItCanReadManuallyInsertedRecordAsync() } [Fact] - public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + public async Task ItCanUpsertAndRetrieveUsingTheDynamicMapperAsync() { const int HotelId = 5; - var sut = fixture.GetCollection>("GenericMapperWithNumericKey", GetVectorStoreRecordDefinition()); + var sut = fixture.GetCollection>("DynamicMapperWithNumericKey", GetVectorStoreRecordDefinition()); await sut.CreateCollectionAsync(); var record = new PostgresHotel { HotelId = (int)HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; // Act - var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) + var upsertResult = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "Description", "This is a generic mapper hotel" }, - { "HotelCode", 1 }, - { "ParkingIncluded", true }, - { "HotelRating", 3.6f } - }, - Vectors = - { - { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } - } + ["HotelId"] = HotelId, + + ["HotelName"] = "Dynamic Mapper Hotel", + ["Description"] = "This is a dynamic mapper hotel", + ["HotelCode"] = 1, + ["ParkingIncluded"] = true, + ["HotelRating"] = 3.6f, + + ["DescriptionEmbedding"] = new ReadOnlyMemory([30f, 31f, 32f, 33f]) }); var localGetResult = await sut.GetAsync(HotelId, new GetRecordOptions { IncludeVectors = true }); @@ -314,35 +311,32 @@ public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() Assert.Equal(HotelId, upsertResult); Assert.NotNull(localGetResult); - Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); - Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); - Assert.True((bool?)localGetResult.Data["ParkingIncluded"]); - Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); - Assert.Equal([30f, 31f, 32f, 33f], ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + Assert.Equal("Dynamic Mapper Hotel", localGetResult["HotelName"]); + Assert.Equal("This is a dynamic mapper hotel", localGetResult["Description"]); + Assert.True((bool?)localGetResult["ParkingIncluded"]); + Assert.Equal(3.6f, localGetResult["HotelRating"]); + Assert.Equal([30f, 31f, 32f, 33f], ((ReadOnlyMemory)localGetResult["DescriptionEmbedding"]!).ToArray()); // Act - update with null embeddings // Act - var upsertResult2 = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) + var upsertResult2 = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "Description", "This is a generic mapper hotel" }, - { "HotelCode", 1 }, - { "ParkingIncluded", true }, - { "HotelRating", 3.6f } - }, - Vectors = - { - { "DescriptionEmbedding", null } - } + ["HotelId"] = HotelId, + + ["HotelName"] = "Dynamic Mapper Hotel", + ["Description"] = "This is a dynamic mapper hotel", + ["HotelCode"] = 1, + ["ParkingIncluded"] = true, + ["HotelRating"] = 3.6f, + + ["DescriptionEmbedding"] = null }); var localGetResult2 = await sut.GetAsync(HotelId, new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(localGetResult2); - Assert.Null(localGetResult2.Vectors["DescriptionEmbedding"]); + Assert.Null(localGetResult2["DescriptionEmbedding"]); } [Theory] @@ -364,15 +358,13 @@ public async Task VectorizedSearchReturnsValidResultsByDefaultAsync(bool include await sut.CreateCollectionAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([0.9f, 0.1f, 0.5f, 0.8f]), new() + var results = await sut.VectorizedSearchAsync(new ReadOnlyMemory([0.9f, 0.1f, 0.5f, 0.8f]), top: 3, new() { IncludeVectors = includeVectors - }); - - var results = await searchResults.Results.ToListAsync(); + }).ToListAsync(); // Assert var ids = results.Select(l => l.Record.HotelId).ToList(); @@ -402,19 +394,16 @@ public async Task VectorizedSearchWithEqualToFilterReturnsValidResultsAsync() await sut.CreateCollectionAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 29f, 28f, 27f]), new() + var results = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 29f, 28f, 27f]), top: 5, new() { IncludeVectors = false, - Top = 5, OldFilter = new([ new EqualToFilterClause("HotelRating", 2.5f) ]) - }); - - var results = await searchResults.Results.ToListAsync(); + }).ToListAsync(); // Assert var ids = results.Select(l => l.Record.HotelId).ToList(); @@ -435,19 +424,16 @@ public async Task VectorizedSearchWithAnyTagFilterReturnsValidResultsAsync() await sut.CreateCollectionAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 29f, 28f, 27f]), new() + var results = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 29f, 28f, 27f]), top: 5, new() { IncludeVectors = false, - Top = 5, OldFilter = new([ new AnyTagEqualToFilterClause("Tags", "tag2") ]) - }); - - var results = await searchResults.Results.ToListAsync(); + }).ToListAsync(); // Assert var ids = results.Select(l => l.Record.HotelId).ToList(); @@ -523,7 +509,7 @@ public async Task ItCanUpsertAndGetEnumerableTypesAsync() new VectorStoreRecordDataProperty("Tags", typeof(List)), new VectorStoreRecordDataProperty("ListInts", typeof(List)), new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.Hnsw, DistanceFunction = distanceFunction } + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), 4) { IndexKind = IndexKind.Hnsw, DistanceFunction = distanceFunction } ] }; @@ -561,7 +547,7 @@ private sealed class RecordWithEnumerables [VectorStoreRecordKey] public int Id { get; set; } - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance)] public ReadOnlyMemory? Embedding { get; set; } [VectorStoreRecordData] diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/CommonQdrantVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/CommonQdrantVectorStoreRecordCollectionTests.cs index 3bafffc6a3bb..ebdeb14f4d02 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/CommonQdrantVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/CommonQdrantVectorStoreRecordCollectionTests.cs @@ -21,7 +21,7 @@ public class CommonQdrantVectorStoreRecordCollectionTests(QdrantVectorStoreFixtu protected override IVectorStoreRecordCollection GetTargetRecordCollection(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition) { - return new QdrantVectorStoreRecordCollection(fixture.QdrantClient, recordCollectionName, new() + return new QdrantVectorStoreRecordCollection(fixture.QdrantClient, recordCollectionName, new() { HasNamedVectors = true, VectorStoreRecordDefinition = vectorStoreRecordDefinition diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantTextSearchTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantTextSearchTests.cs index fcf164bfc449..8f5d4a9a4b39 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantTextSearchTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantTextSearchTests.cs @@ -30,11 +30,15 @@ public override Task CreateTextSearchAsync() HasNamedVectors = true, VectorStoreRecordDefinition = fixture.HotelVectorStoreRecordDefinition, }; - var vectorSearch = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "namedVectorsHotels", options); + var vectorSearch = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "namedVectorsHotels", options); var stringMapper = new HotelInfoTextSearchStringMapper(); var resultMapper = new HotelInfoTextSearchResultMapper(); + // TODO: Once OpenAITextEmbeddingGenerationService implements MEAI's IEmbeddingGenerator (#10811), configure it with the AzureAISearchVectorStore above instead of passing it here. +#pragma warning disable CS0618 // VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete var result = new VectorStoreTextSearch(vectorSearch, this.EmbeddingGenerator!, stringMapper, resultMapper); +#pragma warning restore CS0618 + return Task.FromResult(result); } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreFixture.cs index 0c6bc64bd7d9..512b55873323 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreFixture.cs @@ -53,14 +53,14 @@ public QdrantVectorStoreFixture() Properties = new List { new VectorStoreRecordKeyProperty("HotelId", typeof(ulong)), - new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, - new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("LastRenovationDate", typeof(DateTimeOffset)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("Tags", typeof(List)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsIndexed = true, IsFullTextIndexed = true }, + new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsIndexed = true }, + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsIndexed = true, StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsIndexed = true }, + new VectorStoreRecordDataProperty("LastRenovationDate", typeof(DateTimeOffset)) { IsIndexed = true }, + new VectorStoreRecordDataProperty("Tags", typeof(List)) { IsIndexed = true }, new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = VectorDimensions, DistanceFunction = DistanceFunction.ManhattanDistance } + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), VectorDimensions) { DistanceFunction = DistanceFunction.ManhattanDistance } } }; this.HotelWithGuidIdVectorStoreRecordDefinition = new VectorStoreRecordDefinition @@ -68,9 +68,9 @@ public QdrantVectorStoreFixture() Properties = new List { new VectorStoreRecordKeyProperty("HotelId", typeof(Guid)), - new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsIndexed = true, IsFullTextIndexed = true }, new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = VectorDimensions, DistanceFunction = DistanceFunction.ManhattanDistance } + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), VectorDimensions) { DistanceFunction = DistanceFunction.ManhattanDistance } } }; AzureOpenAIConfiguration? embeddingsConfig = s_configuration.GetSection("AzureOpenAIEmbeddings").Get(); @@ -319,26 +319,26 @@ public record HotelInfo() public ulong HotelId { get; init; } /// A string metadata field. - [VectorStoreRecordData(IsFilterable = true, IsFullTextSearchable = true)] + [VectorStoreRecordData(IsIndexed = true, IsFullTextIndexed = true)] public string? HotelName { get; set; } /// An int metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public int HotelCode { get; set; } /// A float metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public float? HotelRating { get; set; } /// A bool metadata field. - [VectorStoreRecordData(IsFilterable = true, StoragePropertyName = "parking_is_included")] + [VectorStoreRecordData(IsIndexed = true, StoragePropertyName = "parking_is_included")] public bool ParkingIncluded { get; set; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public List Tags { get; set; } = new List(); - /// A DateTimeOffset metadata field. - [VectorStoreRecordData(IsFilterable = true)] + /// A datetime metadata field. + [VectorStoreRecordData(IsIndexed = true)] public DateTimeOffset? LastRenovationDate { get; set; } /// A data field. @@ -346,7 +346,7 @@ public record HotelInfo() public string Description { get; set; } /// A vector field. - [VectorStoreRecordVector(VectorDimensions, DistanceFunction.ManhattanDistance, IndexKind.Hnsw)] + [VectorStoreRecordVector(VectorDimensions, DistanceFunction = DistanceFunction.ManhattanDistance, IndexKind = IndexKind.Hnsw)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } @@ -361,7 +361,7 @@ public record HotelInfoWithGuidId() public Guid HotelId { get; init; } /// A string metadata field. - [VectorStoreRecordData(IsFilterable = true, IsFullTextSearchable = true)] + [VectorStoreRecordData(IsIndexed = true, IsFullTextIndexed = true)] public string? HotelName { get; set; } /// A data field. @@ -369,7 +369,7 @@ public record HotelInfoWithGuidId() public string Description { get; set; } /// A vector field. - [VectorStoreRecordVector(VectorDimensions, DistanceFunction.ManhattanDistance, IndexKind.Hnsw)] + [VectorStoreRecordVector(VectorDimensions, DistanceFunction = DistanceFunction.ManhattanDistance, IndexKind = IndexKind.Hnsw)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs index 4b4a3529a7e4..30e8d5bfcd01 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs @@ -15,10 +15,11 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Qdrant; +#pragma warning disable CA1859 // Use concrete types when possible for improved performance #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// -/// Contains tests for the class. +/// Contains tests for the class. /// /// Used for logging. /// Qdrant setup and teardown. @@ -31,7 +32,7 @@ public sealed class QdrantVectorStoreRecordCollectionTests(ITestOutputHelper out public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) { // Arrange. - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName); + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName); // Act. var actual = await sut.CollectionExistsAsync(); @@ -57,7 +58,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool hasNamedVec HasNamedVectors = hasNamedVectors, VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null }; - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, testCollectionName, options); + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, testCollectionName, options); var record = await this.CreateTestHotelAsync(30, fixture.EmbeddingGenerator); @@ -66,9 +67,10 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool hasNamedVec var upsertResult = await sut.UpsertAsync(record); var getResult = await sut.GetAsync(30, new GetRecordOptions { IncludeVectors = true }); var vector = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("A great hotel"); - var actual = await sut.VectorizedSearchAsync( + var searchResults = await sut.VectorizedSearchAsync( vector, - new() { OldFilter = new VectorSearchFilter().EqualTo("HotelCode", 30).AnyTagEqualTo("Tags", "t2") }); + top: 3, + new() { OldFilter = new VectorSearchFilter().EqualTo("HotelCode", 30).AnyTagEqualTo("Tags", "t2") }).ToListAsync(); // Assert var collectionExistResult = await sut.CollectionExistsAsync(); @@ -85,7 +87,6 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool hasNamedVec Assert.Equal(record.Tags.ToArray(), getResult?.Tags.ToArray()); Assert.Equal(record.Description, getResult?.Description); - var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); var searchResultRecord = searchResults.First().Record; Assert.Equal(record.HotelId, searchResultRecord?.HotelId); @@ -112,7 +113,7 @@ await fixture.QdrantClient.CreateCollectionAsync( tempCollectionName, new VectorParams { Size = 4, Distance = Distance.Cosine }); - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, tempCollectionName); + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, tempCollectionName); // Act await sut.DeleteCollectionAsync(); @@ -134,7 +135,7 @@ public async Task ItCanUpsertDocumentToVectorStoreAsync(bool useRecordDefinition HasNamedVectors = hasNamedVectors, VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null }; - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); var record = await this.CreateTestHotelAsync(20, fixture.EmbeddingGenerator); @@ -165,7 +166,7 @@ public async Task ItCanUpsertAndRemoveDocumentWithGuidIdToVectorStoreAsync() { // Arrange. var options = new QdrantVectorStoreRecordCollectionOptions { HasNamedVectors = false }; - IVectorStoreRecordCollection sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "singleVectorGuidIdHotels", options); + IVectorStoreRecordCollection sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "singleVectorGuidIdHotels", options); var record = new HotelInfoWithGuidId { @@ -213,7 +214,7 @@ public async Task ItCanGetDocumentFromVectorStoreAsync(bool useRecordDefinition, HasNamedVectors = hasNamedVectors, VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null }; - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); // Act. var getResult = await sut.GetAsync(11, new GetRecordOptions { IncludeVectors = withEmbeddings }); @@ -255,7 +256,7 @@ public async Task ItCanGetDocumentWithGuidIdFromVectorStoreAsync(bool useRecordD HasNamedVectors = false, VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelWithGuidIdVectorStoreRecordDefinition : null }; - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "singleVectorGuidIdHotels", options); + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "singleVectorGuidIdHotels", options); // Act. var getResult = await sut.GetAsync(Guid.Parse("11111111-1111-1111-1111-111111111111"), new GetRecordOptions { IncludeVectors = withEmbeddings }); @@ -282,11 +283,11 @@ public async Task ItCanGetManyDocumentsFromVectorStoreAsync() { // Arrange var options = new QdrantVectorStoreRecordCollectionOptions { HasNamedVectors = true }; - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "namedVectorsHotels", options); + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "namedVectorsHotels", options); // Act // Also include one non-existing key to test that the operation does not fail for these and returns only the found ones. - var hotels = sut.GetBatchAsync([11, 15, 12], new GetRecordOptions { IncludeVectors = true }); + var hotels = sut.GetAsync([11, 15, 12], new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(hotels); @@ -313,7 +314,7 @@ public async Task ItCanRemoveDocumentFromVectorStoreAsync(bool useRecordDefiniti HasNamedVectors = hasNamedVectors, VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null }; - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); await sut.UpsertAsync(await this.CreateTestHotelAsync(20, fixture.EmbeddingGenerator)); @@ -339,13 +340,13 @@ public async Task ItCanRemoveManyDocumentsFromVectorStoreAsync(bool useRecordDef HasNamedVectors = hasNamedVectors, VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null }; - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); await sut.UpsertAsync(await this.CreateTestHotelAsync(20, fixture.EmbeddingGenerator)); // Act. // Also delete a non-existing key to test that the operation does not fail for these. - await sut.DeleteBatchAsync([20, 21]); + await sut.DeleteAsync([20, 21]); // Assert. Assert.Null(await sut.GetAsync(20)); @@ -356,23 +357,12 @@ public async Task ItReturnsNullWhenGettingNonExistentRecordAsync() { // Arrange var options = new QdrantVectorStoreRecordCollectionOptions { HasNamedVectors = false }; - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "singleVectorHotels", options); + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "singleVectorHotels", options); // Act & Assert Assert.Null(await sut.GetAsync(15, new GetRecordOptions { IncludeVectors = true })); } - [Fact] - public async Task ItThrowsMappingExceptionForFailedMapperAsync() - { - // Arrange - var options = new QdrantVectorStoreRecordCollectionOptions { PointStructCustomMapper = new FailingMapper() }; - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "singleVectorHotels", options); - - // Act & Assert - await Assert.ThrowsAsync(async () => await sut.GetAsync(11, new GetRecordOptions { IncludeVectors = true })); - } - [Theory] [InlineData(true, "singleVectorHotels", false, "equality")] [InlineData(false, "singleVectorHotels", false, "equality")] @@ -390,20 +380,20 @@ public async Task ItCanSearchWithFilterAsync(bool useRecordDefinition, string co HasNamedVectors = hasNamedVectors, VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null }; - var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); // Act. var vector = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("A great hotel"); var filter = filterType == "equality" ? new VectorSearchFilter().EqualTo("HotelName", "My Hotel 13").EqualTo("LastRenovationDate", new DateTimeOffset(2020, 02, 01, 0, 0, 0, TimeSpan.Zero)) : new VectorSearchFilter().AnyTagEqualTo("Tags", "t13.2"); - var actual = await sut.VectorizedSearchAsync( + var searchResults = await sut.VectorizedSearchAsync( vector, + top: 3, new() { OldFilter = filter - }); + }).ToListAsync(); // Assert. - var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); var searchResultRecord = searchResults.First().Record; @@ -416,59 +406,56 @@ public async Task ItCanSearchWithFilterAsync(bool useRecordDefinition, string co } [Fact] - public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + public async Task ItCanUpsertAndRetrieveUsingTheDynamicMapperAsync() { // Arrange - var options = new QdrantVectorStoreRecordCollectionOptions> + var options = new QdrantVectorStoreRecordCollectionOptions> { VectorStoreRecordDefinition = fixture.HotelVectorStoreRecordDefinition }; - var sut = new QdrantVectorStoreRecordCollection>(fixture.QdrantClient, "singleVectorHotels", options); + var sut = new QdrantVectorStoreRecordCollection>(fixture.QdrantClient, "singleVectorHotels", options); // Act - var baseSetGetResult = await sut.GetAsync(11, new GetRecordOptions { IncludeVectors = true }); - var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel(40) + var baseSetGetResult = await sut.GetAsync(11ul, new GetRecordOptions { IncludeVectors = true }); + var upsertResult = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "HotelCode", 40 }, - { "ParkingIncluded", false }, - { "HotelRating", 3.6d }, - { "Tags", new string[] { "generic" } }, - { "Description", "This is a generic mapper hotel" }, - }, - Vectors = - { - { "DescriptionEmbedding", await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("This is a generic mapper hotel") } - } + ["HotelId"] = 40ul, + + ["HotelName"] = "Dynamic Mapper Hotel", + ["HotelCode"] = 40, + ["ParkingIncluded"] = false, + ["HotelRating"] = 3.6d, + ["Tags"] = new string[] { "dynamic" }, + ["Description"] = "This is a dynamic mapper hotel", + + ["DescriptionEmbedding"] = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("This is a dynamic mapper hotel") }); - var localGetResult = await sut.GetAsync(40, new GetRecordOptions { IncludeVectors = true }); + var localGetResult = await sut.GetAsync(40ul, new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(baseSetGetResult); - Assert.Equal(11ul, baseSetGetResult.Key); - Assert.Equal("My Hotel 11", baseSetGetResult.Data["HotelName"]); - Assert.Equal(11, baseSetGetResult.Data["HotelCode"]); - Assert.True((bool)baseSetGetResult.Data["ParkingIncluded"]!); - Assert.Equal(4.5f, baseSetGetResult.Data["HotelRating"]); - Assert.Equal(new[] { "t11.1", "t11.2" }, ((List)baseSetGetResult.Data["Tags"]!).ToArray()); - Assert.Equal("This is a great hotel.", baseSetGetResult.Data["Description"]); - Assert.NotNull(baseSetGetResult.Vectors["DescriptionEmbedding"]); - Assert.IsType>(baseSetGetResult.Vectors["DescriptionEmbedding"]); + Assert.Equal(11ul, baseSetGetResult["HotelId"]); + Assert.Equal("My Hotel 11", baseSetGetResult["HotelName"]); + Assert.Equal(11, baseSetGetResult["HotelCode"]); + Assert.True((bool)baseSetGetResult["ParkingIncluded"]!); + Assert.Equal(4.5f, baseSetGetResult["HotelRating"]); + Assert.Equal(new[] { "t11.1", "t11.2" }, ((List)baseSetGetResult["Tags"]!).ToArray()); + Assert.Equal("This is a great hotel.", baseSetGetResult["Description"]); + Assert.NotNull(baseSetGetResult["DescriptionEmbedding"]); + Assert.IsType>(baseSetGetResult["DescriptionEmbedding"]); Assert.Equal(40ul, upsertResult); Assert.NotNull(localGetResult); - Assert.Equal(40ul, localGetResult.Key); - Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); - Assert.Equal(40, localGetResult.Data["HotelCode"]); - Assert.False((bool)localGetResult.Data["ParkingIncluded"]!); - Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); - Assert.Equal(new[] { "generic" }, ((List)localGetResult.Data["Tags"]!).ToArray()); - Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); - Assert.NotNull(localGetResult.Vectors["DescriptionEmbedding"]); - Assert.IsType>(localGetResult.Vectors["DescriptionEmbedding"]); + Assert.Equal(40ul, localGetResult["HotelId"]); + Assert.Equal("Dynamic Mapper Hotel", localGetResult["HotelName"]); + Assert.Equal(40, localGetResult["HotelCode"]); + Assert.False((bool)localGetResult["ParkingIncluded"]!); + Assert.Equal(3.6f, localGetResult["HotelRating"]); + Assert.Equal(new[] { "dynamic" }, ((List)localGetResult["Tags"]!).ToArray()); + Assert.Equal("This is a dynamic mapper hotel", localGetResult["Description"]); + Assert.NotNull(localGetResult["DescriptionEmbedding"]); + Assert.IsType>(localGetResult["DescriptionEmbedding"]); } private async Task CreateTestHotelAsync(uint hotelId, ITextEmbeddingGenerationService embeddingGenerator) @@ -486,17 +473,4 @@ private async Task CreateTestHotelAsync(uint hotelId, ITextEmbeddingG DescriptionEmbedding = await embeddingGenerator.GenerateEmbeddingAsync("This is a great hotel."), }; } - - private sealed class FailingMapper : IVectorStoreRecordMapper - { - public PointStruct MapFromDataToStorageModel(HotelInfo dataModel) - { - throw new NotImplementedException(); - } - - public HotelInfo MapFromStorageToDataModel(PointStruct storageModel, StorageToDataModelMapperOptions options) - { - throw new NotImplementedException(); - } - } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreTests.cs index 39551054e4bb..a66f1d563b40 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreTests.cs @@ -20,7 +20,7 @@ public async Task ItPassesSettingsFromVectorStoreToCollectionAsync() var collectionFromVS = sut.GetCollection("SettingsPassedCollection"); await collectionFromVS.CreateCollectionIfNotExistsAsync(); - var directCollection = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "SettingsPassedCollection", new() { HasNamedVectors = true }); + var directCollection = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "SettingsPassedCollection", new() { HasNamedVectors = true }); await directCollection.UpsertAsync(new QdrantVectorStoreFixture.HotelInfo { HotelId = 1ul, diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/CommonRedisHashsetVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/CommonRedisHashsetVectorStoreRecordCollectionTests.cs index bfdabddb041a..d16c63998b36 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/CommonRedisHashsetVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/CommonRedisHashsetVectorStoreRecordCollectionTests.cs @@ -23,7 +23,7 @@ public class CommonRedisHashsetVectorStoreRecordCollectionTests(RedisVectorStore protected override IVectorStoreRecordCollection GetTargetRecordCollection(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition) { - return new RedisHashSetVectorStoreRecordCollection(fixture.Database, recordCollectionName + "hashset", new() + return new RedisHashSetVectorStoreRecordCollection(fixture.Database, recordCollectionName + "hashset", new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/CommonRedisJsonVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/CommonRedisJsonVectorStoreRecordCollectionTests.cs index ba32545c8373..2f79dcce7ef9 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/CommonRedisJsonVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/CommonRedisJsonVectorStoreRecordCollectionTests.cs @@ -23,7 +23,7 @@ public class CommonRedisJsonVectorStoreRecordCollectionTests(RedisVectorStoreFix protected override IVectorStoreRecordCollection GetTargetRecordCollection(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition) { - return new RedisJsonVectorStoreRecordCollection(fixture.Database, recordCollectionName + "json", new() + return new RedisJsonVectorStoreRecordCollection(fixture.Database, recordCollectionName + "json", new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs index 91723c852047..8b64c881b786 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs @@ -1,13 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Redis; using NRedisStack.RedisStackCommands; using NRedisStack.Search; -using StackExchange.Redis; using Xunit; using Xunit.Abstractions; @@ -16,7 +16,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis; #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// -/// Contains tests for the class. +/// Contains tests for the class. /// /// Used for logging. /// Redis setup and teardown. @@ -34,7 +34,7 @@ public sealed class RedisHashSetVectorStoreRecordCollectionTests(ITestOutputHelp public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) { // Arrange. - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, collectionName); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, collectionName); // Act. var actual = await sut.CollectionExistsAsync(); @@ -58,16 +58,17 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = useRecordDefinition ? fixture.BasicVectorStoreRecordDefinition : null }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, testCollectionName, options); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, testCollectionName, options); // Act await sut.CreateCollectionAsync(); var upsertResult = await sut.UpsertAsync(record); var getResult = await sut.GetAsync("HUpsert-1", new GetRecordOptions { IncludeVectors = true }); - var actual = await sut + var searchResults = await sut .VectorizedSearchAsync( new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }), - new() { OldFilter = new VectorSearchFilter().EqualTo("HotelCode", 1), IncludeVectors = true }); + top: 3, + new() { OldFilter = new VectorSearchFilter().EqualTo("HotelCode", 1), IncludeVectors = true }).ToListAsync(); // Assert var collectionExistResult = await sut.CollectionExistsAsync(); @@ -83,7 +84,6 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe Assert.Equal(record.Description, getResult?.Description); Assert.Equal(record.DescriptionEmbedding?.ToArray(), getResult?.DescriptionEmbedding?.ToArray()); - var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); Assert.Equal(1, searchResults.First().Score); var searchResultRecord = searchResults.First().Record; @@ -112,7 +112,7 @@ public async Task ItCanDeleteCollectionAsync() createParams.AddPrefix(tempCollectionName); await fixture.Database.FT().CreateAsync(tempCollectionName, createParams, schema); - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, tempCollectionName); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, tempCollectionName); // Act await sut.DeleteCollectionAsync(); @@ -132,7 +132,7 @@ public async Task ItCanUpsertDocumentToVectorStoreAsync(bool useRecordDefinition PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = useRecordDefinition ? fixture.BasicVectorStoreRecordDefinition : null }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); var record = CreateTestHotel("HUpsert-2", 2); // Act. @@ -165,10 +165,10 @@ public async Task ItCanUpsertManyDocumentsToVectorStoreAsync(bool useRecordDefin PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = useRecordDefinition ? fixture.BasicVectorStoreRecordDefinition : null }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); // Act. - var results = sut.UpsertBatchAsync( + var results = sut.UpsertAsync( [ CreateTestHotel("HUpsertMany-1", 1), CreateTestHotel("HUpsertMany-2", 2), @@ -177,7 +177,7 @@ public async Task ItCanUpsertManyDocumentsToVectorStoreAsync(bool useRecordDefin // Assert. Assert.NotNull(results); - var resultsList = await results.ToListAsync(); + var resultsList = await results; Assert.Equal(3, resultsList.Count); Assert.Contains("HUpsertMany-1", resultsList); @@ -204,7 +204,7 @@ public async Task ItCanGetDocumentFromVectorStoreAsync(bool includeVectors, bool PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = useRecordDefinition ? fixture.BasicVectorStoreRecordDefinition : null }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); // Act. var getResult = await sut.GetAsync("HBaseSet-1", new GetRecordOptions { IncludeVectors = includeVectors }); @@ -234,11 +234,11 @@ public async Task ItCanGetManyDocumentsFromVectorStoreAsync() { // Arrange var options = new RedisHashSetVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); // Act // Also include one non-existing key to test that the operation does not fail for these and returns only the found ones. - var hotels = sut.GetBatchAsync(["HBaseSet-1", "HBaseSet-5", "HBaseSet-2"], new GetRecordOptions { IncludeVectors = true }); + var hotels = sut.GetAsync(["HBaseSet-1", "HBaseSet-5", "HBaseSet-2"], new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(hotels); @@ -263,7 +263,7 @@ public async Task ItCanRemoveDocumentFromVectorStoreAsync(bool useRecordDefiniti PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = useRecordDefinition ? fixture.BasicVectorStoreRecordDefinition : null }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); var record = new RedisBasicFloat32Hotel { HotelId = "HRemove-1", @@ -289,14 +289,14 @@ public async Task ItCanRemoveManyDocumentsFromVectorStoreAsync() { // Arrange var options = new RedisHashSetVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); await sut.UpsertAsync(CreateTestHotel("HRemoveMany-1", 1)); await sut.UpsertAsync(CreateTestHotel("HRemoveMany-2", 2)); await sut.UpsertAsync(CreateTestHotel("HRemoveMany-3", 3)); // Act // Also include a non-existing key to test that the operation does not fail for these. - await sut.DeleteBatchAsync(["HRemoveMany-1", "HRemoveMany-2", "HRemoveMany-3", "HRemoveMany-4"]); + await sut.DeleteAsync(["HRemoveMany-1", "HRemoveMany-2", "HRemoveMany-3", "HRemoveMany-4"]); // Assert Assert.Null(await sut.GetAsync("HRemoveMany-1", new GetRecordOptions { IncludeVectors = true })); @@ -311,21 +311,21 @@ public async Task ItCanSearchWithFloat32VectorAndFilterAsync(string filterType, { // Arrange var options = new RedisHashSetVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); var vector = new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }); var filter = filterType == "equality" ? new VectorSearchFilter().EqualTo("HotelCode", 1) : new VectorSearchFilter().EqualTo("HotelName", "My Hotel 1"); // Act - var actual = await sut.VectorizedSearchAsync( + var searchResults = await sut.VectorizedSearchAsync( vector, + top: 3, new() { IncludeVectors = includeVectors, OldFilter = filter - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); Assert.Equal(1, searchResults.First().Score); var searchResult = searchResults.First().Record; @@ -350,7 +350,7 @@ public async Task ItCanSearchWithFloat32VectorAndTopSkipAsync() { // Arrange var options = new RedisHashSetVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName + "TopSkip", options); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName + "TopSkip", options); await sut.CreateCollectionIfNotExistsAsync(); await sut.UpsertAsync(new RedisBasicFloat32Hotel { HotelId = "HTopSkip_1", HotelName = "1", Description = "Nice hotel", DescriptionEmbedding = new ReadOnlyMemory([1.0f, 1.0f, 1.0f, 1.0f]) }); await sut.UpsertAsync(new RedisBasicFloat32Hotel { HotelId = "HTopSkip_2", HotelName = "2", Description = "Nice hotel", DescriptionEmbedding = new ReadOnlyMemory([1.0f, 1.0f, 1.0f, 2.0f]) }); @@ -360,16 +360,15 @@ public async Task ItCanSearchWithFloat32VectorAndTopSkipAsync() var vector = new ReadOnlyMemory([1.0f, 1.0f, 1.0f, 1.0f]); // Act - var actual = await sut.VectorizedSearchAsync( + var searchResults = await sut.VectorizedSearchAsync( vector, + top: 3, new() { - Top = 3, Skip = 2 - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); Assert.Equal(3, searchResults.Count); Assert.True(searchResults.Select(x => x.Record.HotelId).SequenceEqual(["HTopSkip_3", "HTopSkip_4", "HTopSkip_5"])); } @@ -381,7 +380,7 @@ public async Task ItCanSearchWithFloat64VectorAsync(bool includeVectors) { // Arrange var options = new RedisHashSetVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName + "Float64", options); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName + "Float64", options); await sut.CreateCollectionIfNotExistsAsync(); await sut.UpsertAsync(new RedisBasicFloat64Hotel { HotelId = "HFloat64_1", HotelName = "1", Description = "Nice hotel", DescriptionEmbedding = new ReadOnlyMemory([1.0d, 1.1d, 1.2d, 1.3d]) }); await sut.UpsertAsync(new RedisBasicFloat64Hotel { HotelId = "HFloat64_2", HotelName = "2", Description = "Nice hotel", DescriptionEmbedding = new ReadOnlyMemory([2.0d, 2.1d, 2.2d, 2.3d]) }); @@ -390,16 +389,15 @@ public async Task ItCanSearchWithFloat64VectorAsync(bool includeVectors) var vector = new ReadOnlyMemory([2.0d, 2.1d, 2.2d, 2.3d]); // Act - var actual = await sut.VectorizedSearchAsync( + var searchResults = await sut.VectorizedSearchAsync( vector, + top: 1, new() { IncludeVectors = includeVectors, - Top = 1 - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); var searchResult = searchResults.First().Record; Assert.Equal("HFloat64_2", searchResult?.HotelId); @@ -420,78 +418,60 @@ public async Task ItReturnsNullWhenGettingNonExistentRecordAsync() { // Arrange var options = new RedisHashSetVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); // Act & Assert Assert.Null(await sut.GetAsync("HBaseSet-5", new GetRecordOptions { IncludeVectors = true })); } [Fact(Skip = SkipReason)] - public async Task ItThrowsMappingExceptionForFailedMapperAsync() + public async Task ItCanUpsertAndRetrieveUsingTheDynamicMapperAsync() { // Arrange - var options = new RedisHashSetVectorStoreRecordCollectionOptions - { - PrefixCollectionNameToKeyNames = true, - HashEntriesCustomMapper = new FailingMapper() - }; - var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); - - // Act & Assert - await Assert.ThrowsAsync(async () => await sut.GetAsync("HBaseSet-1", new GetRecordOptions { IncludeVectors = true })); - } - - [Fact(Skip = SkipReason)] - public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() - { - // Arrange - var options = new RedisHashSetVectorStoreRecordCollectionOptions> + var options = new RedisHashSetVectorStoreRecordCollectionOptions> { PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = fixture.BasicVectorStoreRecordDefinition }; - var sut = new RedisHashSetVectorStoreRecordCollection>(fixture.Database, TestCollectionName, options); + var sut = new RedisHashSetVectorStoreRecordCollection>(fixture.Database, TestCollectionName, options); // Act var baseSetGetResult = await sut.GetAsync("HBaseSet-1", new GetRecordOptions { IncludeVectors = true }); - var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel("HGenericMapper-1") + var upsertResult = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "HotelCode", 40 }, - { "ParkingIncluded", true }, - { "Rating", 3.6d }, - { "Description", "This is a generic mapper hotel" }, - }, - Vectors = - { - { "DescriptionEmbedding", new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }) } - } + ["HotelId"] = "HDynamicMapper-1", + + ["HotelName"] = "Dynamic Mapper Hotel", + ["HotelCode"] = 40, + ["ParkingIncluded"] = true, + ["Rating"] = 3.6d, + ["Description"] = "This is a dynamic mapper hotel", + + ["DescriptionEmbedding"] = new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }) }); - var localGetResult = await sut.GetAsync("HGenericMapper-1", new GetRecordOptions { IncludeVectors = true }); + var localGetResult = await sut.GetAsync("HDynamicMapper-1", new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(baseSetGetResult); - Assert.Equal("HBaseSet-1", baseSetGetResult.Key); - Assert.Equal("My Hotel 1", baseSetGetResult.Data["HotelName"]); - Assert.Equal(1, baseSetGetResult.Data["HotelCode"]); - Assert.True((bool)baseSetGetResult.Data["ParkingIncluded"]!); - Assert.Equal(3.6d, baseSetGetResult.Data["Rating"]); - Assert.Equal("This is a great hotel.", baseSetGetResult.Data["Description"]); - Assert.NotNull(baseSetGetResult.Vectors["DescriptionEmbedding"]); - Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)baseSetGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + Assert.Equal("HBaseSet-1", baseSetGetResult["HotelId"]); + Assert.Equal("My Hotel 1", baseSetGetResult["HotelName"]); + Assert.Equal(1, baseSetGetResult["HotelCode"]); + Assert.True((bool)baseSetGetResult["ParkingIncluded"]!); + Assert.Equal(3.6d, baseSetGetResult["Rating"]); + Assert.Equal("This is a great hotel.", baseSetGetResult["Description"]); + Assert.NotNull(baseSetGetResult["DescriptionEmbedding"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)baseSetGetResult["DescriptionEmbedding"]!).ToArray()); Assert.Equal("HGenericMapper-1", upsertResult); Assert.NotNull(localGetResult); - Assert.Equal("HGenericMapper-1", localGetResult.Key); - Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); - Assert.Equal(40, localGetResult.Data["HotelCode"]); - Assert.True((bool)localGetResult.Data["ParkingIncluded"]!); - Assert.Equal(3.6d, localGetResult.Data["Rating"]); - Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); - Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + Assert.Equal("HDynamicMapper-1", localGetResult["HotelId"]); + Assert.Equal("Dynamic Mapper Hotel", localGetResult["HotelName"]); + Assert.Equal(40, localGetResult["HotelCode"]); + Assert.True((bool)localGetResult["ParkingIncluded"]!); + Assert.Equal(3.6d, localGetResult["Rating"]); + Assert.Equal("This is a dynamic mapper hotel", localGetResult["Description"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult["DescriptionEmbedding"]!).ToArray()); } private static RedisBasicFloat32Hotel CreateTestHotel(string hotelId, int hotelCode) @@ -508,17 +488,4 @@ private static RedisBasicFloat32Hotel CreateTestHotel(string hotelId, int hotelC }; return record; } - - private sealed class FailingMapper : IVectorStoreRecordMapper - { - public (string Key, HashEntry[] HashEntries) MapFromDataToStorageModel(RedisBasicFloat32Hotel dataModel) - { - throw new NotImplementedException(); - } - - public RedisBasicFloat32Hotel MapFromStorageToDataModel((string Key, HashEntry[] HashEntries) storageModel, StorageToDataModelMapperOptions options) - { - throw new NotImplementedException(); - } - } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHotel.cs index 87dc5c2fb89b..5a0dbb64459f 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHotel.cs @@ -16,23 +16,23 @@ public class RedisHotel [VectorStoreRecordKey] public string HotelId { get; init; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string HotelName { get; init; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public int HotelCode { get; init; } - [VectorStoreRecordData(IsFullTextSearchable = true)] + [VectorStoreRecordData(IsFullTextIndexed = true)] public string Description { get; init; } [VectorStoreRecordVector(4)] public ReadOnlyMemory? DescriptionEmbedding { get; init; } #pragma warning disable CA1819 // Properties should not return arrays - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string[] Tags { get; init; } - [VectorStoreRecordData(IsFullTextSearchable = true)] + [VectorStoreRecordData(IsFullTextIndexed = true)] public string[] FTSTags { get; init; } #pragma warning restore CA1819 // Properties should not return arrays @@ -67,13 +67,13 @@ public class RedisBasicHotel [VectorStoreRecordKey] public string HotelId { get; init; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string HotelName { get; init; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public int HotelCode { get; init; } - [VectorStoreRecordData(IsFullTextSearchable = true)] + [VectorStoreRecordData(IsFullTextIndexed = true)] public string Description { get; init; } [VectorStoreRecordVector(4)] diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs index 266948738ef6..2bb8bce7c1de 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs @@ -1,8 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Linq; -using System.Text.Json.Nodes; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Redis; @@ -16,7 +16,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis; #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// -/// Contains tests for the class. +/// Contains tests for the class. /// /// Used for logging. /// Redis setup and teardown. @@ -24,7 +24,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis; public sealed class RedisJsonVectorStoreRecordCollectionTests(ITestOutputHelper output, RedisVectorStoreFixture fixture) { // If null, all tests will be enabled - private const string SkipReason = "Redis tests fail intermittently on build server"; + private const string SkipReason = null; private const string TestCollectionName = "jsonhotels"; @@ -34,7 +34,7 @@ public sealed class RedisJsonVectorStoreRecordCollectionTests(ITestOutputHelper public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) { // Arrange. - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, collectionName); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, collectionName); // Act. var actual = await sut.CollectionExistsAsync(); @@ -58,15 +58,16 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, testCollectionName, options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, testCollectionName, options); // Act await sut.CreateCollectionAsync(); var upsertResult = await sut.UpsertAsync(record); var getResult = await sut.GetAsync("Upsert-10", new GetRecordOptions { IncludeVectors = true }); - var actual = await sut.VectorizedSearchAsync( + var searchResults = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }), - new() { OldFilter = new VectorSearchFilter().EqualTo("HotelCode", 10) }); + top: 3, + new() { OldFilter = new VectorSearchFilter().EqualTo("HotelCode", 10) }).ToListAsync(); // Assert var collectionExistResult = await sut.CollectionExistsAsync(); @@ -87,7 +88,6 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe Assert.Equal(record.Description, getResult?.Description); Assert.Equal(record.DescriptionEmbedding?.ToArray(), getResult?.DescriptionEmbedding?.ToArray()); - var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); Assert.Equal(1, searchResults.First().Score); var searchResultRecord = searchResults.First().Record; @@ -121,7 +121,7 @@ public async Task ItCanDeleteCollectionAsync() createParams.AddPrefix(tempCollectionName); await fixture.Database.FT().CreateAsync(tempCollectionName, createParams, schema); - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, tempCollectionName); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, tempCollectionName); // Act await sut.DeleteCollectionAsync(); @@ -141,7 +141,7 @@ public async Task ItCanUpsertDocumentToVectorStoreAsync(bool useRecordDefinition PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); RedisHotel record = CreateTestHotel("Upsert-2", 2); // Act. @@ -179,10 +179,10 @@ public async Task ItCanUpsertManyDocumentsToVectorStoreAsync(bool useRecordDefin PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); // Act. - var results = sut.UpsertBatchAsync( + var results = sut.UpsertAsync( [ CreateTestHotel("UpsertMany-1", 1), CreateTestHotel("UpsertMany-2", 2), @@ -191,7 +191,7 @@ public async Task ItCanUpsertManyDocumentsToVectorStoreAsync(bool useRecordDefin // Assert. Assert.NotNull(results); - var resultsList = await results.ToListAsync(); + var resultsList = await results; Assert.Equal(3, resultsList.Count); Assert.Contains("UpsertMany-1", resultsList); @@ -218,7 +218,7 @@ public async Task ItCanGetDocumentFromVectorStoreAsync(bool includeVectors, bool PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); // Act. var getResult = await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = includeVectors }); @@ -252,11 +252,11 @@ public async Task ItCanGetManyDocumentsFromVectorStoreAsync() { // Arrange var options = new RedisJsonVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); // Act // Also include one non-existing key to test that the operation does not fail for these and returns only the found ones. - var hotels = sut.GetBatchAsync(["BaseSet-1", "BaseSet-5", "BaseSet-2"], new GetRecordOptions { IncludeVectors = true }); + var hotels = sut.GetAsync(["BaseSet-1", "BaseSet-5", "BaseSet-2"], new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(hotels); @@ -275,7 +275,7 @@ public async Task ItFailsToGetDocumentsWithInvalidSchemaAsync() { // Arrange. var options = new RedisJsonVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); // Act & Assert. await Assert.ThrowsAsync(async () => await sut.GetAsync("BaseSet-4-Invalid", new GetRecordOptions { IncludeVectors = true })); @@ -292,7 +292,7 @@ public async Task ItCanRemoveDocumentFromVectorStoreAsync(bool useRecordDefiniti PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); var address = new RedisHotelAddress { City = "Seattle", Country = "USA" }; var record = new RedisHotel { @@ -319,14 +319,14 @@ public async Task ItCanRemoveManyDocumentsFromVectorStoreAsync() { // Arrange var options = new RedisJsonVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); await sut.UpsertAsync(CreateTestHotel("RemoveMany-1", 1)); await sut.UpsertAsync(CreateTestHotel("RemoveMany-2", 2)); await sut.UpsertAsync(CreateTestHotel("RemoveMany-3", 3)); // Act // Also include a non-existing key to test that the operation does not fail for these. - await sut.DeleteBatchAsync(["RemoveMany-1", "RemoveMany-2", "RemoveMany-3", "RemoveMany-4"]); + await sut.DeleteAsync(["RemoveMany-1", "RemoveMany-2", "RemoveMany-3", "RemoveMany-4"]); // Assert Assert.Null(await sut.GetAsync("RemoveMany-1", new GetRecordOptions { IncludeVectors = true })); @@ -341,17 +341,17 @@ public async Task ItCanSearchWithFloat32VectorAndFilterAsync(string filterType) { // Arrange var options = new RedisJsonVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); var vector = new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }); var filter = filterType == "equality" ? new VectorSearchFilter().EqualTo("HotelCode", 1) : new VectorSearchFilter().AnyTagEqualTo("Tags", "pool"); // Act - var actual = await sut.VectorizedSearchAsync( + var searchResults = await sut.VectorizedSearchAsync( vector, - new() { IncludeVectors = true, OldFilter = filter }); + top: 3, + new() { IncludeVectors = true, OldFilter = filter }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); Assert.Equal(1, searchResults.First().Score); var searchResult = searchResults.First().Record; @@ -374,7 +374,7 @@ public async Task ItCanSearchWithFloat32VectorAndTopSkipAsync() { // Arrange var options = new RedisJsonVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName + "TopSkip", options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName + "TopSkip", options); await sut.CreateCollectionIfNotExistsAsync(); await sut.UpsertAsync(new RedisBasicFloat32Hotel { HotelId = "TopSkip_1", HotelName = "1", Description = "Nice hotel", DescriptionEmbedding = new ReadOnlyMemory([1.0f, 1.0f, 1.0f, 1.0f]) }); await sut.UpsertAsync(new RedisBasicFloat32Hotel { HotelId = "TopSkip_2", HotelName = "2", Description = "Nice hotel", DescriptionEmbedding = new ReadOnlyMemory([1.0f, 1.0f, 1.0f, 2.0f]) }); @@ -384,16 +384,15 @@ public async Task ItCanSearchWithFloat32VectorAndTopSkipAsync() var vector = new ReadOnlyMemory([1.0f, 1.0f, 1.0f, 1.0f]); // Act - var actual = await sut.VectorizedSearchAsync( + var searchResults = await sut.VectorizedSearchAsync( vector, + top: 3, new() { - Top = 3, Skip = 2 - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); Assert.Equal(3, searchResults.Count); Assert.True(searchResults.Select(x => x.Record.HotelId).SequenceEqual(["TopSkip_3", "TopSkip_4", "TopSkip_5"])); } @@ -405,7 +404,7 @@ public async Task ItCanSearchWithFloat64VectorAsync(bool includeVectors) { // Arrange var options = new RedisJsonVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName + "Float64", options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName + "Float64", options); await sut.CreateCollectionIfNotExistsAsync(); await sut.UpsertAsync(new RedisBasicFloat64Hotel { HotelId = "Float64_1", HotelName = "1", Description = "Nice hotel", DescriptionEmbedding = new ReadOnlyMemory([1.0d, 1.1d, 1.2d, 1.3d]) }); await sut.UpsertAsync(new RedisBasicFloat64Hotel { HotelId = "Float64_2", HotelName = "2", Description = "Nice hotel", DescriptionEmbedding = new ReadOnlyMemory([2.0d, 2.1d, 2.2d, 2.3d]) }); @@ -414,16 +413,15 @@ public async Task ItCanSearchWithFloat64VectorAsync(bool includeVectors) var vector = new ReadOnlyMemory([2.0d, 2.1d, 2.2d, 2.3d]); // Act - var actual = await sut.VectorizedSearchAsync( + var searchResults = await sut.VectorizedSearchAsync( vector, + top: 1, new() { IncludeVectors = includeVectors, - Top = 1 - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); Assert.Single(searchResults); var searchResult = searchResults.First().Record; Assert.Equal("Float64_2", searchResult?.HotelId); @@ -440,90 +438,71 @@ public async Task ItReturnsNullWhenGettingNonExistentRecordAsync() { // Arrange var options = new RedisJsonVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); // Act & Assert Assert.Null(await sut.GetAsync("BaseSet-5", new GetRecordOptions { IncludeVectors = true })); } [Fact(Skip = SkipReason)] - public async Task ItThrowsMappingExceptionForFailedMapperAsync() + public async Task ItCanUpsertAndRetrieveUsingTheDynamicMapperAsync() { // Arrange - var options = new RedisJsonVectorStoreRecordCollectionOptions - { - PrefixCollectionNameToKeyNames = true, - JsonNodeCustomMapper = new FailingMapper() - }; - var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); - - // Act & Assert - await Assert.ThrowsAsync(async () => await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = true })); - } - - [Fact(Skip = SkipReason)] - public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() - { - // Arrange - var options = new RedisJsonVectorStoreRecordCollectionOptions> + var options = new RedisJsonVectorStoreRecordCollectionOptions> { PrefixCollectionNameToKeyNames = true, VectorStoreRecordDefinition = fixture.VectorStoreRecordDefinition }; - var sut = new RedisJsonVectorStoreRecordCollection>(fixture.Database, TestCollectionName, options); + var sut = new RedisJsonVectorStoreRecordCollection>(fixture.Database, TestCollectionName, options); // Act var baseSetGetResult = await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = true }); - var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel("GenericMapper-1") + var upsertResult = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "HotelCode", 1 }, - { "Tags", new[] { "generic 1", "generic 2" } }, - { "FTSTags", new[] { "generic 1", "generic 2" } }, - { "ParkingIncluded", true }, - { "LastRenovationDate", new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero) }, - { "Rating", 3.6 }, - { "Address", new RedisHotelAddress { City = "Seattle", Country = "USA" } }, - { "Description", "This is a generic mapper hotel" }, - { "DescriptionEmbedding", new[] { 30f, 31f, 32f, 33f } } - }, - Vectors = - { - { "DescriptionEmbedding", new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }) } - } + ["HotelId"] = "DynamicMapper-1", + + ["HotelName"] = "Dynamic Mapper Hotel", + ["HotelCode"] = 1, + ["Tags"] = new[] { "dynamic 1", "dynamic 2" }, + ["FTSTags"] = new[] { "dynamic 1", "dynamic 2" }, + ["ParkingIncluded"] = true, + ["LastRenovationDate"] = new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), + ["Rating"] = 3.6, + ["Address"] = new RedisHotelAddress { City = "Seattle", Country = "USA" }, + ["Description"] = "This is a dynamic mapper hotel", + + ["DescriptionEmbedding"] = new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }) }); - var localGetResult = await sut.GetAsync("GenericMapper-1", new GetRecordOptions { IncludeVectors = true }); + var localGetResult = await sut.GetAsync("DynamicMapper-1", new GetRecordOptions { IncludeVectors = true }); // Assert Assert.NotNull(baseSetGetResult); - Assert.Equal("BaseSet-1", baseSetGetResult.Key); - Assert.Equal("My Hotel 1", baseSetGetResult.Data["HotelName"]); - Assert.Equal(1, baseSetGetResult.Data["HotelCode"]); - Assert.Equal(new[] { "pool", "air conditioning", "concierge" }, baseSetGetResult.Data["Tags"]); - Assert.Equal(new[] { "pool", "air conditioning", "concierge" }, baseSetGetResult.Data["FTSTags"]); - Assert.True((bool)baseSetGetResult.Data["ParkingIncluded"]!); - Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), baseSetGetResult.Data["LastRenovationDate"]); - Assert.Equal(3.6, baseSetGetResult.Data["Rating"]); - Assert.Equal("Seattle", ((RedisHotelAddress)baseSetGetResult.Data["Address"]!).City); - Assert.Equal("This is a great hotel.", baseSetGetResult.Data["Description"]); - Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)baseSetGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); - - Assert.Equal("GenericMapper-1", upsertResult); + Assert.Equal("BaseSet-1", baseSetGetResult["HotelId"]); + Assert.Equal("My Hotel 1", baseSetGetResult["HotelName"]); + Assert.Equal(1, baseSetGetResult["HotelCode"]); + Assert.Equal(new[] { "pool", "air conditioning", "concierge" }, baseSetGetResult["Tags"]); + Assert.Equal(new[] { "pool", "air conditioning", "concierge" }, baseSetGetResult["FTSTags"]); + Assert.True((bool)baseSetGetResult["ParkingIncluded"]!); + Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), baseSetGetResult["LastRenovationDate"]); + Assert.Equal(3.6, baseSetGetResult["Rating"]); + Assert.Equal("Seattle", ((RedisHotelAddress)baseSetGetResult["Address"]!).City); + Assert.Equal("This is a great hotel.", baseSetGetResult["Description"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)baseSetGetResult["DescriptionEmbedding"]!).ToArray()); + + Assert.Equal("DynamicMapper-1", upsertResult); Assert.NotNull(localGetResult); - Assert.Equal("GenericMapper-1", localGetResult.Key); - Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); - Assert.Equal(1, localGetResult.Data["HotelCode"]); - Assert.Equal(new[] { "generic 1", "generic 2" }, localGetResult.Data["Tags"]); - Assert.Equal(new[] { "generic 1", "generic 2" }, localGetResult.Data["FTSTags"]); - Assert.True((bool)localGetResult.Data["ParkingIncluded"]!); - Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), localGetResult.Data["LastRenovationDate"]); - Assert.Equal(3.6d, localGetResult.Data["Rating"]); - Assert.Equal("Seattle", ((RedisHotelAddress)localGetResult.Data["Address"]!).City); - Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); - Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + Assert.Equal("DynamicMapper-1", localGetResult["HotelId"]); + Assert.Equal("Dynamic Mapper Hotel", localGetResult["HotelName"]); + Assert.Equal(1, localGetResult["HotelCode"]); + Assert.Equal(new[] { "dynamic 1", "dynamic 2" }, localGetResult["Tags"]); + Assert.Equal(new[] { "dynamic 1", "dynamic 2" }, localGetResult["FTSTags"]); + Assert.True((bool)localGetResult["ParkingIncluded"]!); + Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), localGetResult["LastRenovationDate"]); + Assert.Equal(3.6d, localGetResult["Rating"]); + Assert.Equal("Seattle", ((RedisHotelAddress)localGetResult["Address"]!).City); + Assert.Equal("This is a dynamic mapper hotel", localGetResult["Description"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult["DescriptionEmbedding"]!).ToArray()); } private static RedisHotel CreateTestHotel(string hotelId, int hotelCode) @@ -545,17 +524,4 @@ private static RedisHotel CreateTestHotel(string hotelId, int hotelCode) }; return record; } - - private sealed class FailingMapper : IVectorStoreRecordMapper - { - public (string Key, JsonNode Node) MapFromDataToStorageModel(RedisHotel dataModel) - { - throw new NotImplementedException(); - } - - public RedisHotel MapFromStorageToDataModel((string Key, JsonNode Node) storageModel, StorageToDataModelMapperOptions options) - { - throw new NotImplementedException(); - } - } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreFixture.cs index bec643a13d5b..febee4bed7da 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreFixture.cs @@ -40,12 +40,12 @@ public RedisVectorStoreFixture() Properties = new List { new VectorStoreRecordKeyProperty("HotelId", typeof(string)), - new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("Description", typeof(string)) { IsFullTextSearchable = true }, - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4 }, - new VectorStoreRecordDataProperty("Tags", typeof(string[])) { IsFilterable = true }, - new VectorStoreRecordDataProperty("FTSTags", typeof(string[])) { IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsIndexed = true }, + new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsIndexed = true }, + new VectorStoreRecordDataProperty("Description", typeof(string)) { IsFullTextIndexed = true }, + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), 4), + new VectorStoreRecordDataProperty("Tags", typeof(string[])) { IsIndexed = true }, + new VectorStoreRecordDataProperty("FTSTags", typeof(string[])) { IsFullTextIndexed = true }, new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName = "parking_is_included" }, new VectorStoreRecordDataProperty("LastRenovationDate", typeof(DateTimeOffset)), new VectorStoreRecordDataProperty("Rating", typeof(double)), @@ -57,10 +57,10 @@ public RedisVectorStoreFixture() Properties = new List { new VectorStoreRecordKeyProperty("HotelId", typeof(string)), - new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("Description", typeof(string)) { IsFullTextSearchable = true }, - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4 }, + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsIndexed = true }, + new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsIndexed = true }, + new VectorStoreRecordDataProperty("Description", typeof(string)) { IsFullTextIndexed = true }, + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), 4), new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName = "parking_is_included" }, new VectorStoreRecordDataProperty("Rating", typeof(double)), } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteHotel.cs index 761b0ce9631f..d7db1e61a9d7 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteHotel.cs @@ -12,7 +12,7 @@ public record SqliteHotel() public TKey? HotelId { get; init; } /// A string metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string? HotelName { get; set; } /// An int metadata field. @@ -32,6 +32,6 @@ public record SqliteHotel() public string? Description { get; set; } /// A vector field. - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.EuclideanDistance)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.EuclideanDistance)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs index 2e3e6b32fe52..4726c2c029dc 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs @@ -1,7 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Data; -using Microsoft.Data.Sqlite; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel; @@ -44,11 +42,11 @@ public void AddVectorStoreRecordCollectionWithStringKeyAndSqliteConnectionRegist // Assert var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); + Assert.IsType>(vectorizedSearch); } [Fact(Skip = SkipReason)] @@ -62,34 +60,11 @@ public void AddVectorStoreRecordCollectionWithNumericKeyAndSqliteConnectionRegis // Assert var collection = serviceProvider.GetRequiredService>(); Assert.NotNull(collection); - Assert.IsType>(collection); + Assert.IsType>(collection); - var vectorizedSearch = serviceProvider.GetRequiredService>(); + var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); - } - - [Fact(Skip = SkipReason)] - public void ItClosesConnectionWhenDIServiceIsDisposed() - { - // Act - using var connection = new SqliteConnection("Data Source=:memory:"); - - this._serviceCollection.AddTransient(_ => connection); - - this._serviceCollection.AddSqliteVectorStore(); - - var serviceProvider = this._serviceCollection.BuildServiceProvider(); - - using (var scope = serviceProvider.CreateScope()) - { - scope.ServiceProvider.GetRequiredService(); - - Assert.Equal(ConnectionState.Open, connection.State); - } - - // Assert - Assert.Equal(ConnectionState.Closed, connection.State); + Assert.IsType>(vectorizedSearch); } #region private diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs index 6f07f20ddf67..1a8128a8cf03 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs @@ -1,50 +1,29 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Threading.Tasks; -using Microsoft.Data.Sqlite; +using System.IO; using Microsoft.SemanticKernel.Connectors.Sqlite; -using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.Memory.Sqlite; -public class SqliteVectorStoreFixture : IAsyncLifetime, IDisposable +public class SqliteVectorStoreFixture : IDisposable { - /// - /// SQLite extension name for vector search. - /// More information here: . - /// - private const string VectorSearchExtensionName = "vec0"; + private readonly string _databasePath = Path.GetTempFileName(); - public SqliteConnection Connection { get; } + public string ConnectionString => $"Data Source={this._databasePath}"; - public SqliteVectorStoreFixture() - { - this.Connection = new SqliteConnection("Data Source=:memory:"); - } - - public SqliteVectorStoreRecordCollection GetCollection( + public SqliteVectorStoreRecordCollection GetCollection( string collectionName, SqliteVectorStoreRecordCollectionOptions? options = default) + where TKey : notnull + where TRecord : notnull { - return new SqliteVectorStoreRecordCollection( - this.Connection, + return new SqliteVectorStoreRecordCollection( + this.ConnectionString, collectionName, options); } - public Task DisposeAsync() - { - return Task.CompletedTask; - } - - public async Task InitializeAsync() - { - await this.Connection.OpenAsync(); - - this.Connection.LoadExtension(VectorSearchExtensionName); - } - public void Dispose() { this.Dispose(true); @@ -55,7 +34,7 @@ protected virtual void Dispose(bool disposing) { if (disposing) { - this.Connection.Dispose(); + File.Delete(this._databasePath); } } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs index f799fd26eaa8..d5324fc4ecf0 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs @@ -1,19 +1,22 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; using System.Threading.Tasks; +using Microsoft.Data.Sqlite; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Sqlite; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.Memory.Sqlite; +#pragma warning disable CA1859 // Use concrete types when possible for improved performance #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// -/// Integration tests for class. +/// Integration tests for class. /// [Collection("SqliteVectorStoreCollection")] public sealed class SqliteVectorStoreRecordCollectionTests(SqliteVectorStoreFixture fixture) @@ -26,7 +29,7 @@ public sealed class SqliteVectorStoreRecordCollectionTests(SqliteVectorStoreFixt public async Task CollectionExistsReturnsCollectionStateAsync(bool createCollection) { // Arrange - var sut = fixture.GetCollection>("CollectionExists"); + var sut = fixture.GetCollection>("CollectionExists"); if (createCollection) { @@ -44,7 +47,7 @@ public async Task CollectionExistsReturnsCollectionStateAsync(bool createCollect public async Task ItCanCreateCollectionAsync() { // Arrange - var sut = fixture.GetCollection>("CreateCollection"); + var sut = fixture.GetCollection>("CreateCollection"); // Act await sut.CreateCollectionAsync(); @@ -57,7 +60,7 @@ public async Task ItCanCreateCollectionAsync() public async Task ItCanCreateCollectionForSupportedDistanceFunctionsAsync() { // Arrange - var sut = fixture.GetCollection("CreateCollectionForSupportedDistanceFunctions"); + var sut = fixture.GetCollection("CreateCollectionForSupportedDistanceFunctions"); // Act await sut.CreateCollectionAsync(); @@ -70,7 +73,7 @@ public async Task ItCanCreateCollectionForSupportedDistanceFunctionsAsync() public async Task ItCanDeleteCollectionAsync() { // Arrange - var sut = fixture.GetCollection>("DeleteCollection"); + var sut = fixture.GetCollection>("DeleteCollection"); await sut.CreateCollectionAsync(); @@ -101,7 +104,7 @@ public async Task ItCanCreateCollectionUpsertAndGetAsync(bool includeVectors, bo VectorStoreRecordDefinition = useRecordDefinition ? GetVectorStoreRecordDefinition() : null }; - var sut = fixture.GetCollection>("DeleteCollection", options); + var sut = fixture.GetCollection>("DeleteCollection", options); var record = CreateTestHotel(HotelId); @@ -140,7 +143,7 @@ public async Task ItCanGetAndDeleteRecordAsync() { // Arrange const ulong HotelId = 5; - var sut = fixture.GetCollection>("DeleteRecord"); + var sut = fixture.GetCollection>("DeleteRecord"); await sut.CreateCollectionAsync(); @@ -169,7 +172,7 @@ public async Task ItCanGetUpsertDeleteBatchWithNumericKeyAsync() const ulong HotelId2 = 2; const ulong HotelId3 = 3; - var sut = fixture.GetCollection>("GetUpsertDeleteBatchWithNumericKey"); + var sut = fixture.GetCollection>("GetUpsertDeleteBatchWithNumericKey"); await sut.CreateCollectionAsync(); @@ -177,8 +180,8 @@ public async Task ItCanGetUpsertDeleteBatchWithNumericKeyAsync() var record2 = CreateTestHotel(HotelId2); var record3 = CreateTestHotel(HotelId3); - var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); - var getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + var upsertResults = await sut.UpsertAsync([record1, record2, record3]); + var getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); Assert.Equal([HotelId1, HotelId2, HotelId3], upsertResults); @@ -187,9 +190,9 @@ public async Task ItCanGetUpsertDeleteBatchWithNumericKeyAsync() Assert.NotNull(getResults.First(l => l.HotelId == HotelId3)); // Act - await sut.DeleteBatchAsync([HotelId1, HotelId2, HotelId3]); + await sut.DeleteAsync([HotelId1, HotelId2, HotelId3]); - getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); // Assert Assert.Empty(getResults); @@ -203,7 +206,7 @@ public async Task ItCanGetUpsertDeleteBatchWithStringKeyAsync() const string HotelId2 = "22222222-2222-2222-2222-222222222222"; const string HotelId3 = "33333333-3333-3333-3333-333333333333"; - var sut = fixture.GetCollection>("GetUpsertDeleteBatchWithStringKey") as IVectorStoreRecordCollection>; + var sut = fixture.GetCollection>("GetUpsertDeleteBatchWithStringKey") as IVectorStoreRecordCollection>; await sut.CreateCollectionAsync(); @@ -211,8 +214,8 @@ public async Task ItCanGetUpsertDeleteBatchWithStringKeyAsync() var record2 = CreateTestHotel(HotelId2); var record3 = CreateTestHotel(HotelId3); - var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); - var getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + var upsertResults = await sut.UpsertAsync([record1, record2, record3]); + var getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); Assert.Equal([HotelId1, HotelId2, HotelId3], upsertResults); @@ -221,9 +224,9 @@ public async Task ItCanGetUpsertDeleteBatchWithStringKeyAsync() Assert.NotNull(getResults.First(l => l.HotelId == HotelId3)); // Act - await sut.DeleteBatchAsync([HotelId1, HotelId2, HotelId3]); + await sut.DeleteAsync([HotelId1, HotelId2, HotelId3]); - getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + getResults = await sut.GetAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); // Assert Assert.Empty(getResults); @@ -239,13 +242,15 @@ public async Task ItCanGetExistingRecordAsync(bool includeVectors) var collectionName = $"Collection{collectionNamePostfix}"; const ulong HotelId = 5; - var sut = fixture.GetCollection>(collectionName); + var sut = fixture.GetCollection>(collectionName); await sut.CreateCollectionAsync(); var record = CreateTestHotel(HotelId); - var commandData = fixture.Connection.CreateCommand(); + using var connection = new SqliteConnection(fixture.ConnectionString); + await connection.OpenAsync(); + var commandData = connection.CreateCommand(); commandData.CommandText = $"INSERT INTO {collectionName} " + @@ -262,7 +267,7 @@ public async Task ItCanGetExistingRecordAsync(bool includeVectors) if (includeVectors) { - var commandVector = fixture.Connection.CreateCommand(); + var commandVector = connection.CreateCommand(); commandVector.CommandText = $"INSERT INTO vec_{collectionName} " + @@ -303,7 +308,7 @@ public async Task ItCanUpsertExistingRecordAsync() { // Arrange const ulong HotelId = 5; - var sut = fixture.GetCollection>("UpsertRecord"); + var sut = fixture.GetCollection>("UpsertRecord"); await sut.CreateCollectionAsync(); @@ -343,19 +348,17 @@ public async Task VectorizedSearchReturnsValidResultsByDefaultAsync(bool include var hotel3 = CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = fixture.GetCollection>("VectorizedSearch"); + var sut = fixture.GetCollection>("VectorizedSearch"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var results = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 3, new() { IncludeVectors = includeVectors - }); - - var results = await searchResults.Results.ToListAsync(); + }).ToListAsync(); // Assert var ids = results.Select(l => l.Record.HotelId).ToList(); @@ -380,20 +383,17 @@ public async Task VectorizedSearchReturnsValidResultsWithOffsetAsync() var hotel3 = CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = fixture.GetCollection>("VectorizedSearchWithOffset"); + var sut = fixture.GetCollection>("VectorizedSearchWithOffset"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var results = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 2, new() { - Top = 2, Skip = 2 - }); - - var results = await searchResults.Results.ToListAsync(); + }).ToListAsync(); // Assert var ids = results.Select(l => l.Record.HotelId).ToList(); @@ -414,19 +414,17 @@ public async Task VectorizedSearchReturnsValidResultsWithFilterAsync() var hotel3 = CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = fixture.GetCollection>("VectorizedSearchWithFilter"); + var sut = fixture.GetCollection>("VectorizedSearchWithFilter"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var results = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 3, new() { OldFilter = new VectorSearchFilter().EqualTo(nameof(SqliteHotel.HotelName), "My Hotel key2") - }); - - var results = await searchResults.Results.ToListAsync(); + }).ToListAsync(); // Assert var ids = results.Select(l => l.Record.HotelId).ToList(); @@ -439,81 +437,75 @@ public async Task VectorizedSearchReturnsValidResultsWithFilterAsync() } [Fact(Skip = SkipReason)] - public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperWithNumericKeyAsync() + public async Task ItCanUpsertAndRetrieveUsingTheDynamicMapperWithNumericKeyAsync() { - const ulong HotelId = 5; + const long HotelId = 5; - var options = new SqliteVectorStoreRecordCollectionOptions> + var options = new SqliteVectorStoreRecordCollectionOptions> { VectorStoreRecordDefinition = GetVectorStoreRecordDefinition() }; - var sut = fixture.GetCollection>("GenericMapperWithNumericKey", options); + var sut = fixture.GetCollection>("DynamicMapperWithNumericKey", options); await sut.CreateCollectionAsync(); var record = CreateTestHotel(HotelId); // Act - var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) + var upsertResult = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "Description", "This is a generic mapper hotel" }, - { "ParkingIncluded", true }, - { "HotelRating", 3.6f } - }, - Vectors = - { - { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } - } + ["HotelId"] = HotelId, + + ["HotelName"] = "Dynamic Mapper Hotel", + ["Description"] = "This is a dynamic mapper hotel", + ["ParkingIncluded"] = true, + ["HotelRating"] = 3.6f, + + ["DescriptionEmbedding"] = new ReadOnlyMemory([30f, 31f, 32f, 33f]) }); var localGetResult = await sut.GetAsync(HotelId, new GetRecordOptions { IncludeVectors = true }); // Assert - Assert.Equal(HotelId, upsertResult); + Assert.Equal(HotelId, (long)upsertResult); Assert.NotNull(localGetResult); - Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); - Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); - Assert.True((bool?)localGetResult.Data["ParkingIncluded"]); - Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); - Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + Assert.Equal("Dynamic Mapper Hotel", localGetResult["HotelName"]); + Assert.Equal("This is a dynamic mapper hotel", localGetResult["Description"]); + Assert.True((bool?)localGetResult["ParkingIncluded"]); + Assert.Equal(3.6f, localGetResult["HotelRating"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult["DescriptionEmbedding"]!).ToArray()); } [Fact(Skip = SkipReason)] - public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperWithStringKeyAsync() + public async Task ItCanUpsertAndRetrieveUsingTheDynamicMapperWithStringKeyAsync() { const string HotelId = "key"; - var options = new SqliteVectorStoreRecordCollectionOptions> + var options = new SqliteVectorStoreRecordCollectionOptions> { VectorStoreRecordDefinition = GetVectorStoreRecordDefinition() }; - var sut = fixture.GetCollection>("GenericMapperWithStringKey", options) - as IVectorStoreRecordCollection>; + var sut = fixture.GetCollection>("DynamicMapperWithStringKey", options) + as IVectorStoreRecordCollection>; await sut.CreateCollectionAsync(); var record = CreateTestHotel(HotelId); // Act - var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) + var upsertResult = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "Description", "This is a generic mapper hotel" }, - { "ParkingIncluded", true }, - { "HotelRating", 3.6f } - }, - Vectors = - { - { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } - } + ["HotelId"] = HotelId, + + ["HotelName"] = "Dynamic Mapper Hotel", + ["Description"] = "This is a dynamic mapper hotel", + ["ParkingIncluded"] = true, + ["HotelRating"] = 3.6f, + + ["DescriptionEmbedding"] = new ReadOnlyMemory([30f, 31f, 32f, 33f]) }); var localGetResult = await sut.GetAsync(HotelId, new GetRecordOptions { IncludeVectors = true }); @@ -522,11 +514,11 @@ public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperWithStringKeyAsync( Assert.Equal(HotelId, upsertResult); Assert.NotNull(localGetResult); - Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); - Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); - Assert.True((bool?)localGetResult.Data["ParkingIncluded"]); - Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); - Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + Assert.Equal("Dynamic Mapper Hotel", localGetResult["HotelName"]); + Assert.Equal("This is a dynamic mapper hotel", localGetResult["Description"]); + Assert.True((bool?)localGetResult["ParkingIncluded"]); + Assert.Equal(3.6f, localGetResult["HotelRating"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult["DescriptionEmbedding"]!).ToArray()); } #region @@ -541,7 +533,7 @@ public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperWithStringKeyAsync( new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName = "parking_is_included" }, new VectorStoreRecordDataProperty("HotelRating", typeof(float)), new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.IvfFlat, DistanceFunction = DistanceFunction.CosineDistance } + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), 4) { IndexKind = IndexKind.IvfFlat, DistanceFunction = DistanceFunction.CosineDistance } ] }; @@ -576,13 +568,13 @@ private sealed class RecordWithSupportedDistanceFunctions [VectorStoreRecordKey] public ulong Id { get; set; } - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance)] public ReadOnlyMemory? Embedding1 { get; set; } - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.EuclideanDistance)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.EuclideanDistance)] public ReadOnlyMemory? Embedding2 { get; set; } - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.ManhattanDistance)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.ManhattanDistance)] public ReadOnlyMemory? Embedding3 { get; set; } } #pragma warning restore CA1812 diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs index 8a173250f7fe..6eca22778b02 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs @@ -17,7 +17,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Sqlite; [Collection("SqliteVectorStoreCollection")] [DisableVectorStoreTests(Skip = "SQLite vector search extension is required")] public sealed class SqliteVectorStoreTests(SqliteVectorStoreFixture fixture) - : BaseVectorStoreTests>(new SqliteVectorStore(fixture.Connection!)) + : BaseVectorStoreTests>(new SqliteVectorStore(fixture.ConnectionString)) { [VectorStoreFact] public async Task ItCanGetAListOfExistingCollectionNamesWhenRegisteredWithDIAsync() @@ -25,7 +25,7 @@ public async Task ItCanGetAListOfExistingCollectionNamesWhenRegisteredWithDIAsyn // Arrange var serviceCollection = new ServiceCollection(); - serviceCollection.AddSqliteVectorStore(connectionString: "Data Source=:memory:"); + serviceCollection.AddSqliteVectorStore(fixture.ConnectionString); var provider = serviceCollection.BuildServiceProvider(); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/CommonWeaviateVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/CommonWeaviateVectorStoreRecordCollectionTests.cs index 1398f7d48f27..be63bac6ea58 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/CommonWeaviateVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/CommonWeaviateVectorStoreRecordCollectionTests.cs @@ -28,7 +28,7 @@ protected override IVectorStoreRecordCollection GetTargetRecordCo var recordCollectionNameChars = recordCollectionName.ToCharArray(); recordCollectionNameChars[0] = char.ToUpperInvariant(recordCollectionNameChars[0]); - return new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, new string(recordCollectionNameChars), new() + return new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, new string(recordCollectionNameChars), new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateHotel.cs index bfcd78c9a51c..1338eceef6fe 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateHotel.cs @@ -16,34 +16,34 @@ public sealed record WeaviateHotel public Guid HotelId { get; init; } /// A string metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public string? HotelName { get; set; } /// An int metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public int HotelCode { get; set; } /// A float metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public float? HotelRating { get; set; } /// A bool metadata field. [JsonPropertyName("parking_is_included")] - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public bool ParkingIncluded { get; set; } /// An array metadata field. - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public List Tags { get; set; } = []; /// A data field. - [VectorStoreRecordData(IsFullTextSearchable = true, IsFilterable = true)] + [VectorStoreRecordData(IsFullTextIndexed = true, IsIndexed = true)] public string Description { get; set; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public DateTimeOffset Timestamp { get; set; } /// A vector field. - [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance, IndexKind: IndexKind.Hnsw)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction = DistanceFunction.CosineDistance, IndexKind = IndexKind.Hnsw)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateMemoryStoreTests.cs index b88795e9a3d6..bd366a06dfc3 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateMemoryStoreTests.cs @@ -17,6 +17,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Weaviate; /// The Weaviate instance API key is set in the Docker Container as "my-secret-key". /// [Collection("Sequential")] +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public sealed class WeaviateMemoryStoreTests : IDisposable { // If null, all tests will be enabled diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs index 494967b21fc7..fe39b2c3ce62 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs @@ -19,7 +19,7 @@ public sealed class WeaviateVectorStoreRecordCollectionTests(WeaviateVectorStore public async Task ItCanCreateCollectionAsync() { // Arrange - var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "TestCreateCollection"); + var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "TestCreateCollection"); // Act await sut.CreateCollectionAsync(); @@ -34,7 +34,7 @@ public async Task ItCanCreateCollectionAsync() public async Task ItCanCheckIfCollectionExistsAsync(string collectionName, bool collectionExists) { // Arrange - var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, collectionName); + var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, collectionName); if (collectionExists) { @@ -63,7 +63,7 @@ public async Task ItCanUpsertAndGetRecordAsync(string collectionName, bool inclu VectorStoreRecordDefinition = useRecordDefinition ? this.GetTestHotelRecordDefinition() : null }; - var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, collectionName, options); + var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, collectionName, options); var record = this.CreateTestHotel(hotelId); @@ -104,7 +104,7 @@ public async Task ItCanDeleteCollectionAsync() // Arrange const string CollectionName = "TestDeleteCollection"; - var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, CollectionName); + var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, CollectionName); await sut.CreateCollectionAsync(); @@ -123,7 +123,7 @@ public async Task ItCanDeleteRecordAsync() // Arrange var hotelId = new Guid("55555555-5555-5555-5555-555555555555"); - var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "TestDeleteRecord"); + var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "TestDeleteRecord"); await sut.CreateCollectionAsync(); @@ -152,7 +152,7 @@ public async Task ItCanUpsertAndGetAndDeleteBatchAsync() var hotelId2 = new Guid("22222222-2222-2222-2222-222222222222"); var hotelId3 = new Guid("33333333-3333-3333-3333-333333333333"); - var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "TestBatch"); + var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "TestBatch"); await sut.CreateCollectionAsync(); @@ -160,8 +160,8 @@ public async Task ItCanUpsertAndGetAndDeleteBatchAsync() var record2 = this.CreateTestHotel(hotelId2); var record3 = this.CreateTestHotel(hotelId3); - var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); - var getResults = await sut.GetBatchAsync([hotelId1, hotelId2, hotelId3]).ToListAsync(); + var upsertResults = await sut.UpsertAsync([record1, record2, record3]); + var getResults = await sut.GetAsync([hotelId1, hotelId2, hotelId3]).ToListAsync(); Assert.Equal([hotelId1, hotelId2, hotelId3], upsertResults); @@ -170,9 +170,9 @@ public async Task ItCanUpsertAndGetAndDeleteBatchAsync() Assert.NotNull(getResults.First(l => l.HotelId == hotelId3)); // Act - await sut.DeleteBatchAsync([hotelId1, hotelId2, hotelId3]); + await sut.DeleteAsync([hotelId1, hotelId2, hotelId3]); - getResults = await sut.GetBatchAsync([hotelId1, hotelId2, hotelId3]).ToListAsync(); + getResults = await sut.GetAsync([hotelId1, hotelId2, hotelId3]).ToListAsync(); // Assert Assert.Empty(getResults); @@ -183,7 +183,7 @@ public async Task ItCanUpsertRecordAsync() { // Arrange var hotelId = new Guid("55555555-5555-5555-5555-555555555555"); - var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "TestUpsert"); + var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "TestUpsert"); await sut.CreateCollectionAsync(); @@ -219,20 +219,19 @@ public async Task VectorizedSearchReturnsValidResultsByDefaultAsync(bool include var hotel3 = this.CreateTestHotel(hotelId: new Guid("33333333-3333-3333-3333-333333333333"), embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: new Guid("44444444-4444-4444-4444-444444444444"), embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "VectorSearchDefault"); + var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "VectorSearchDefault"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 3, new() { IncludeVectors = includeVectors - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var ids = searchResults.Select(l => l.Record.HotelId.ToString()).ToList(); Assert.Equal("11111111-1111-1111-1111-111111111111", ids[0]); @@ -257,21 +256,19 @@ public async Task VectorizedSearchReturnsValidResultsWithOffsetAsync() var hotel3 = this.CreateTestHotel(hotelId: new Guid("33333333-3333-3333-3333-333333333333"), embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: new Guid("44444444-4444-4444-4444-444444444444"), embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "VectorSearchWithOffset"); + var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "VectorSearchWithOffset"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 2, new() { - Top = 2, Skip = 2 - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var ids = searchResults.Select(l => l.Record.HotelId.ToString()).ToList(); Assert.Equal("33333333-3333-3333-3333-333333333333", ids[0]); @@ -291,21 +288,19 @@ public async Task VectorizedSearchReturnsValidResultsWithFilterAsync(VectorSearc var hotel3 = this.CreateTestHotel(hotelId: new Guid("33333333-3333-3333-3333-333333333333"), embedding: new[] { 20f, 20f, 20f, 20f }); var hotel4 = this.CreateTestHotel(hotelId: new Guid("44444444-4444-4444-4444-444444444444"), embedding: new[] { -1000f, -1000f, -1000f, -1000f }); - var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "VectorSearchWithFilter"); + var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "VectorSearchWithFilter"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), top: 4, new() { OldFilter = filter, - Top = 4, - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var actualIds = searchResults.Select(l => l.Record.HotelId.ToString()).ToList(); Assert.Equal(expectedIds, actualIds); @@ -336,21 +331,19 @@ public async Task VectorizedSearchReturnsValidResultsWithFilterAndDifferentDataT Timestamp = new DateTime(2024, 9, 22, 15, 59, 42) }; - var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "VectorSearchWithFilterAndDataTypes"); + var sut = new WeaviateVectorStoreRecordCollection(fixture.HttpClient!, "VectorSearchWithFilterAndDataTypes"); await sut.CreateCollectionIfNotExistsAsync(); - await sut.UpsertBatchAsync([hotel4, hotel2, hotel5, hotel3, hotel1]).ToListAsync(); + await sut.UpsertAsync([hotel4, hotel2, hotel5, hotel3, hotel1]); // Act - var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([40f, 40f, 40f, 40f]), new() + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([40f, 40f, 40f, 40f]), top: 4, new() { OldFilter = filter, - Top = 4, - }); + }).ToListAsync(); // Assert - var searchResults = await actual.Results.ToListAsync(); var actualIds = searchResults.Select(l => l.Record.HotelId.ToString()).ToList(); Assert.Single(actualIds); @@ -359,35 +352,32 @@ public async Task VectorizedSearchReturnsValidResultsWithFilterAndDifferentDataT } [Fact] - public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + public async Task ItCanUpsertAndRetrieveUsingDynamicMappingAsync() { // Arrange var hotelId = new Guid("55555555-5555-5555-5555-555555555555"); - var options = new WeaviateVectorStoreRecordCollectionOptions> + var options = new WeaviateVectorStoreRecordCollectionOptions> { VectorStoreRecordDefinition = this.GetTestHotelRecordDefinition() }; - var sut = new WeaviateVectorStoreRecordCollection>(fixture.HttpClient!, "TestGenericMapper", options); + var sut = new WeaviateVectorStoreRecordCollection>(fixture.HttpClient!, "TestDynamicMapper", options); await sut.CreateCollectionAsync(); // Act - var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel(hotelId) + var upsertResult = await sut.UpsertAsync(new Dictionary { - Data = - { - { "HotelName", "Generic Mapper Hotel" }, - { "Description", "This is a generic mapper hotel" }, - { "Tags", new List { "generic" } }, - { "parking_is_included", false }, - { "Timestamp", new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero) }, - { "HotelRating", 3.6f } - }, - Vectors = - { - { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } - } + ["HotelId"] = hotelId, + + ["HotelName"] = "Dynamic Mapper Hotel", + ["Description"] = "This is a dynamic mapper hotel", + ["Tags"] = new List { "dynamic" }, + ["ParkingIncluded"] = false, + ["Timestamp"] = new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), + ["HotelRating"] = 3.6f, + + ["DescriptionEmbedding"] = new ReadOnlyMemory([30f, 31f, 32f, 33f]) }); var localGetResult = await sut.GetAsync(hotelId, new GetRecordOptions { IncludeVectors = true }); @@ -396,13 +386,13 @@ public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() Assert.Equal(hotelId, upsertResult); Assert.NotNull(localGetResult); - Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); - Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); - Assert.Equal(new List { "generic" }, localGetResult.Data["Tags"]); - Assert.False((bool?)localGetResult.Data["parking_is_included"]); - Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), localGetResult.Data["Timestamp"]); - Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); - Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + Assert.Equal("Dynamic Mapper Hotel", localGetResult["HotelName"]); + Assert.Equal("This is a dynamic mapper hotel", localGetResult["Description"]); + Assert.Equal(new List { "dynamic" }, localGetResult["Tags"]); + Assert.False((bool?)localGetResult["ParkingIncluded"]); + Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), localGetResult["Timestamp"]); + Assert.Equal(3.6f, localGetResult["HotelRating"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult["DescriptionEmbedding"]!).ToArray()); } public static TheoryData> VectorizedSearchWithFilterData => new() @@ -474,12 +464,12 @@ private VectorStoreRecordDefinition GetTestHotelRecordDefinition() new VectorStoreRecordKeyProperty("HotelId", typeof(Guid)), new VectorStoreRecordDataProperty("HotelName", typeof(string)), new VectorStoreRecordDataProperty("HotelCode", typeof(int)), - new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName = "parking_is_included" }, new VectorStoreRecordDataProperty("HotelRating", typeof(float)), new VectorStoreRecordDataProperty("Tags", typeof(List)), new VectorStoreRecordDataProperty("Description", typeof(string)), new VectorStoreRecordDataProperty("Timestamp", typeof(DateTimeOffset)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.Hnsw, DistanceFunction = DistanceFunction.CosineDistance } + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?), 4) { IndexKind = IndexKind.Hnsw, DistanceFunction = DistanceFunction.CosineDistance } ] }; } diff --git a/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs b/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs index 143c61f69e5f..3a1afcc4303e 100644 --- a/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs +++ b/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs @@ -41,6 +41,7 @@ public static async Task> AddRecords ITextEmbeddingGenerationService embeddingGenerationService, CreateRecordFromString createRecord) where TKey : notnull + where TRecord : notnull { var lines = await File.ReadAllLinesAsync("./TestData/semantic-kernel-info.txt"); @@ -96,20 +97,6 @@ public Task>> GenerateEmbeddingsAsync(IList } } - /// - /// Decorator for a that generates embeddings for text search queries. - /// - protected sealed class VectorizedSearchWrapper(IVectorizedSearch vectorizedSearch, ITextEmbeddingGenerationService textEmbeddingGeneration) : IVectorizableTextSearch - { - /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) - { - var vectorizedQuery = await textEmbeddingGeneration!.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); - - return await vectorizedSearch.VectorizedSearchAsync(vectorizedQuery, options, cancellationToken); - } - } - /// /// Sample model class that represents a record entry. /// @@ -130,7 +117,7 @@ protected sealed class DataModel [VectorStoreRecordData] public required string Link { get; init; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public required string Tag { get; init; } [VectorStoreRecordVector(1536)] diff --git a/dotnet/src/IntegrationTests/Data/VectorStoreExtensions.cs b/dotnet/src/IntegrationTests/Data/VectorStoreExtensions.cs index 9f730ce5d259..dea981c16d9f 100644 --- a/dotnet/src/IntegrationTests/Data/VectorStoreExtensions.cs +++ b/dotnet/src/IntegrationTests/Data/VectorStoreExtensions.cs @@ -48,6 +48,7 @@ internal static async Task> CreateCo ITextEmbeddingGenerationService embeddingGenerationService, CreateRecordFromString createRecord) where TKey : notnull + where TRecord : notnull { // Get and create collection if it doesn't exist. var collection = vectorStore.GetCollection(collectionName); @@ -83,6 +84,7 @@ internal static async Task> CreateCo ITextEmbeddingGenerationService embeddingGenerationService, CreateRecordFromTextSearchResult createRecord) where TKey : notnull + where TRecord : notnull { // Get and create collection if it doesn't exist. var collection = vectorStore.GetCollection(collectionName); diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index 284f7d46a978..b53f678c5d4b 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -5,7 +5,7 @@ net8.0 true false - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110,OPENAI001 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110,OPENAI001,MEVD9000 b7762d10-e29b-4bb1-8b74-b6d69a667dd4 @@ -65,6 +65,7 @@ all + diff --git a/dotnet/src/IntegrationTests/README.md b/dotnet/src/IntegrationTests/README.md index 90a65b9d7531..cc84d8609e57 100644 --- a/dotnet/src/IntegrationTests/README.md +++ b/dotnet/src/IntegrationTests/README.md @@ -4,7 +4,7 @@ 1. **Azure OpenAI**: go to the [Azure OpenAI Quickstart](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/quickstart) 1. Deploy the following models: - 1. `dall-e-3` DALL-E 3 generates images and is used in Text to Image tests. + 1. `dall-e-3` DALL-E 3 generates images and is used in Text to Image tests. 1. `tts` TTS is a model that converts text to natural sounding speech and is used in Text to Audio tests. 1. `whisper` The Whisper models are trained for speech recognition and translation tasks and is used in Audio to Text tests. 1. `text-embedding-ada-002` Text Embedding Ada 002 is used in Text Embedding tests. diff --git a/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/IMongoDBMapper.cs b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/IMongoDBMapper.cs new file mode 100644 index 000000000000..c59c97a69ce9 --- /dev/null +++ b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/IMongoDBMapper.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using MongoDB.Bson; + +internal interface IMongoDBMapper +{ + /// + /// Maps from the consumer record data model to the storage model. + /// + BsonDocument MapFromDataToStorageModel(TRecord dataModel, Embedding?[]? generatedEmbeddings); + + /// + /// Maps from the storage model to the consumer record data model. + /// + TRecord MapFromStorageToDataModel(BsonDocument storageModel, StorageToDataModelMapperOptions options); +} diff --git a/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBConstants.cs b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBConstants.cs index 7acd839dd0e3..279da36b7895 100644 --- a/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBConstants.cs +++ b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBConstants.cs @@ -13,6 +13,8 @@ namespace Microsoft.SemanticKernel.Connectors.MongoDB; [ExcludeFromCodeCoverage] internal static class MongoDBConstants { + internal const string VectorStoreSystemName = "mongodb"; + /// Default ratio of number of nearest neighbors to number of documents to return. internal const int DefaultNumCandidatesRatio = 10; @@ -44,20 +46,13 @@ internal static class MongoDBConstants internal static readonly HashSet SupportedDataTypes = [ typeof(bool), - typeof(bool?), typeof(string), typeof(int), - typeof(int?), typeof(long), - typeof(long?), typeof(float), - typeof(float?), typeof(double), - typeof(double?), typeof(decimal), - typeof(decimal?), typeof(DateTime), - typeof(DateTime?), ]; /// A containing the supported vector types. diff --git a/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBDynamicDataModelMapper.cs b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBDynamicDataModelMapper.cs new file mode 100644 index 000000000000..e1507273810f --- /dev/null +++ b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBDynamicDataModelMapper.cs @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using MongoDB.Bson; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +/// +/// A mapper that maps between the dynamic data model and the model that the data is stored under, within MongoDB. +/// +[ExcludeFromCodeCoverage] +#pragma warning disable CS0618 // IVectorStoreRecordMapper is obsolete +internal sealed class MongoDBDynamicDataModelMapper(VectorStoreRecordModel model) : IMongoDBMapper> +#pragma warning restore CS0618 +{ + /// + public BsonDocument MapFromDataToStorageModel(Dictionary dataModel, Embedding?[]? generatedEmbeddings) + { + Verify.NotNull(dataModel); + + var document = new BsonDocument(); + + document[MongoDBConstants.MongoReservedKeyPropertyName] = !dataModel.TryGetValue(model.KeyProperty.ModelName, out var keyValue) + ? throw new KeyNotFoundException($"Missing value for key property '{model.KeyProperty.ModelName}") + : keyValue switch + { + string s => s, + null => throw new InvalidOperationException($"Key property '{model.KeyProperty.ModelName}' is null."), + _ => throw new InvalidCastException($"Key property '{model.KeyProperty.ModelName}' must be a string.") + }; + + document[MongoDBConstants.MongoReservedKeyPropertyName] = (string)(dataModel[model.KeyProperty.ModelName] + ?? throw new InvalidOperationException($"Key property '{model.KeyProperty.ModelName}' is null.")); + + foreach (var property in model.DataProperties) + { + if (dataModel.TryGetValue(property.ModelName, out var dataValue)) + { + document[property.StorageName] = BsonValue.Create(dataValue); + } + } + + for (var i = 0; i < model.VectorProperties.Count; i++) + { + var property = model.VectorProperties[i]; + + if (generatedEmbeddings?[i] is null) + { + // No generated embedding, read the vector directly from the data model + if (dataModel.TryGetValue(property.ModelName, out var vectorValue)) + { + document[property.StorageName] = BsonArray.Create(GetVectorArray(vectorValue)); + } + } + else + { + Debug.Assert(property.EmbeddingGenerator is not null); + var embedding = generatedEmbeddings[i]; + document[property.StorageName] = embedding switch + { + Embedding e => BsonArray.Create(e.Vector.ToArray()), + Embedding e => BsonArray.Create(e.Vector.ToArray()), + _ => throw new UnreachableException() + }; + } + } + + return document; + } + + /// + public Dictionary MapFromStorageToDataModel(BsonDocument storageModel, StorageToDataModelMapperOptions options) + { + Verify.NotNull(storageModel); + + var result = new Dictionary(); + + // Loop through all known properties and map each from the storage model to the data model. + foreach (var property in model.Properties) + { + switch (property) + { + case VectorStoreRecordKeyPropertyModel keyProperty: + result[keyProperty.ModelName] = storageModel.TryGetValue(MongoDBConstants.MongoReservedKeyPropertyName, out var keyValue) + ? keyValue.AsString + : throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); + continue; + + case VectorStoreRecordDataPropertyModel dataProperty: + if (storageModel.TryGetValue(dataProperty.StorageName, out var dataValue)) + { + result.Add(dataProperty.ModelName, GetDataPropertyValue(property.ModelName, property.Type, dataValue)); + } + continue; + + case VectorStoreRecordVectorPropertyModel vectorProperty: + if (options.IncludeVectors && storageModel.TryGetValue(vectorProperty.StorageName, out var vectorValue)) + { + result.Add(vectorProperty.ModelName, GetVectorPropertyValue(property.ModelName, property.Type, vectorValue)); + } + continue; + + default: + throw new UnreachableException(); + } + } + + return result; + } + + #region private + + private static object? GetDataPropertyValue(string propertyName, Type propertyType, BsonValue value) + { + if (value.IsBsonNull) + { + return null; + } + + return propertyType switch + { + Type t when t == typeof(bool) => value.AsBoolean, + Type t when t == typeof(bool?) => value.AsNullableBoolean, + Type t when t == typeof(string) => value.AsString, + Type t when t == typeof(int) => value.AsInt32, + Type t when t == typeof(int?) => value.AsNullableInt32, + Type t when t == typeof(long) => value.AsInt64, + Type t when t == typeof(long?) => value.AsNullableInt64, + Type t when t == typeof(float) => ((float)value.AsDouble), + Type t when t == typeof(float?) => ((float?)value.AsNullableDouble), + Type t when t == typeof(double) => value.AsDouble, + Type t when t == typeof(double?) => value.AsNullableDouble, + Type t when t == typeof(decimal) => value.AsDecimal, + Type t when t == typeof(decimal?) => value.AsNullableDecimal, + Type t when t == typeof(DateTime) => value.ToUniversalTime(), + Type t when t == typeof(DateTime?) => value.ToNullableUniversalTime(), + Type t when typeof(IEnumerable).IsAssignableFrom(t) => value.AsBsonArray.Select( + item => GetDataPropertyValue(propertyName, VectorStoreRecordPropertyVerification.GetCollectionElementType(t), item)), + _ => throw new NotSupportedException($"Mapping for property {propertyName} with type {propertyType.FullName} is not supported in dynamic data model.") + }; + } + + private static object? GetVectorPropertyValue(string propertyName, Type propertyType, BsonValue value) + { + if (value.IsBsonNull) + { + return null; + } + + return propertyType switch + { + Type t when t == typeof(ReadOnlyMemory) || t == typeof(ReadOnlyMemory?) => + new ReadOnlyMemory(value.AsBsonArray.Select(item => (float)item.AsDouble).ToArray()), + Type t when t == typeof(ReadOnlyMemory) || t == typeof(ReadOnlyMemory?) => + new ReadOnlyMemory(value.AsBsonArray.Select(item => item.AsDouble).ToArray()), + _ => throw new NotSupportedException($"Mapping for property {propertyName} with type {propertyType.FullName} is not supported in dynamic data model.") + }; + } + + private static object GetVectorArray(object? vector) + { + if (vector is null) + { + return Array.Empty(); + } + + return vector switch + { + ReadOnlyMemory memoryFloat => memoryFloat.ToArray(), + ReadOnlyMemory memoryDouble => memoryDouble.ToArray(), + _ => throw new NotSupportedException($"Mapping for type {vector.GetType().FullName} is not supported in dynamic data model.") + }; + } + + #endregion +} diff --git a/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBGenericDataModelMapper.cs b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBGenericDataModelMapper.cs deleted file mode 100644 index 8ec0dffb935c..000000000000 --- a/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBGenericDataModelMapper.cs +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using System.Linq; -using Microsoft.Extensions.VectorData; -using MongoDB.Bson; - -namespace Microsoft.SemanticKernel.Connectors.MongoDB; - -/// -/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within MongoDB. -/// -[ExcludeFromCodeCoverage] -internal sealed class MongoDBGenericDataModelMapper : IVectorStoreRecordMapper, BsonDocument> -{ - /// A that defines the schema of the data in the database. - private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition; - - /// - /// Initializes a new instance of the class. - /// - /// A that defines the schema of the data in the database. - public MongoDBGenericDataModelMapper(VectorStoreRecordDefinition vectorStoreRecordDefinition) - { - Verify.NotNull(vectorStoreRecordDefinition); - - this._vectorStoreRecordDefinition = vectorStoreRecordDefinition; - } - - /// - public BsonDocument MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - Verify.NotNull(dataModel); - - var document = new BsonDocument(); - - // Loop through all known properties and map each from the data model to the storage model. - foreach (var property in this._vectorStoreRecordDefinition.Properties) - { - var storagePropertyName = property.StoragePropertyName ?? property.DataModelPropertyName; - - if (property is VectorStoreRecordKeyProperty keyProperty) - { - document[MongoDBConstants.MongoReservedKeyPropertyName] = dataModel.Key; - } - else if (property is VectorStoreRecordDataProperty dataProperty) - { - if (dataModel.Data is not null && dataModel.Data.TryGetValue(dataProperty.DataModelPropertyName, out var dataValue)) - { - document[storagePropertyName] = BsonValue.Create(dataValue); - } - } - else if (property is VectorStoreRecordVectorProperty vectorProperty) - { - if (dataModel.Vectors is not null && dataModel.Vectors.TryGetValue(vectorProperty.DataModelPropertyName, out var vectorValue)) - { - document[storagePropertyName] = BsonArray.Create(GetVectorArray(vectorValue)); - } - } - } - - return document; - } - - /// - public VectorStoreGenericDataModel MapFromStorageToDataModel(BsonDocument storageModel, StorageToDataModelMapperOptions options) - { - Verify.NotNull(storageModel); - - // Create variables to store the response properties. - string? key = null; - var dataProperties = new Dictionary(); - var vectorProperties = new Dictionary(); - - // Loop through all known properties and map each from the storage model to the data model. - foreach (var property in this._vectorStoreRecordDefinition.Properties) - { - var storagePropertyName = property.StoragePropertyName ?? property.DataModelPropertyName; - - if (property is VectorStoreRecordKeyProperty keyProperty) - { - if (storageModel.TryGetValue(MongoDBConstants.MongoReservedKeyPropertyName, out var keyValue)) - { - key = keyValue.AsString; - } - } - else if (property is VectorStoreRecordDataProperty dataProperty) - { - if (!storageModel.TryGetValue(storagePropertyName, out var dataValue)) - { - continue; - } - - dataProperties.Add(dataProperty.DataModelPropertyName, GetDataPropertyValue(property.DataModelPropertyName, property.PropertyType, dataValue)); - } - else if (property is VectorStoreRecordVectorProperty vectorProperty && options.IncludeVectors) - { - if (!storageModel.TryGetValue(storagePropertyName, out var vectorValue)) - { - continue; - } - - vectorProperties.Add(vectorProperty.DataModelPropertyName, GetVectorPropertyValue(property.DataModelPropertyName, property.PropertyType, vectorValue)); - } - } - - if (key is null) - { - throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); - } - - return new VectorStoreGenericDataModel(key) { Data = dataProperties, Vectors = vectorProperties }; - } - - #region private - - private static object? GetDataPropertyValue(string propertyName, Type propertyType, BsonValue value) - { - if (value.IsBsonNull) - { - return null; - } - - return propertyType switch - { - Type t when t == typeof(bool) => value.AsBoolean, - Type t when t == typeof(bool?) => value.AsNullableBoolean, - Type t when t == typeof(string) => value.AsString, - Type t when t == typeof(int) => value.AsInt32, - Type t when t == typeof(int?) => value.AsNullableInt32, - Type t when t == typeof(long) => value.AsInt64, - Type t when t == typeof(long?) => value.AsNullableInt64, - Type t when t == typeof(float) => ((float)value.AsDouble), - Type t when t == typeof(float?) => ((float?)value.AsNullableDouble), - Type t when t == typeof(double) => value.AsDouble, - Type t when t == typeof(double?) => value.AsNullableDouble, - Type t when t == typeof(decimal) => value.AsDecimal, - Type t when t == typeof(decimal?) => value.AsNullableDecimal, - Type t when t == typeof(DateTime) => value.ToUniversalTime(), - Type t when t == typeof(DateTime?) => value.ToNullableUniversalTime(), - Type t when typeof(IEnumerable).IsAssignableFrom(t) => value.AsBsonArray.Select( - item => GetDataPropertyValue(propertyName, VectorStoreRecordPropertyVerification.GetCollectionElementType(t), item)), - _ => throw new NotSupportedException($"Mapping for property {propertyName} with type {propertyType.FullName} is not supported in generic data model.") - }; - } - - private static object? GetVectorPropertyValue(string propertyName, Type propertyType, BsonValue value) - { - if (value.IsBsonNull) - { - return null; - } - - return propertyType switch - { - Type t when t == typeof(ReadOnlyMemory) || t == typeof(ReadOnlyMemory?) => - new ReadOnlyMemory(value.AsBsonArray.Select(item => (float)item.AsDouble).ToArray()), - Type t when t == typeof(ReadOnlyMemory) || t == typeof(ReadOnlyMemory?) => - new ReadOnlyMemory(value.AsBsonArray.Select(item => item.AsDouble).ToArray()), - _ => throw new NotSupportedException($"Mapping for property {propertyName} with type {propertyType.FullName} is not supported in generic data model.") - }; - } - - private static object GetVectorArray(object? vector) - { - if (vector is null) - { - return Array.Empty(); - } - - return vector switch - { - ReadOnlyMemory memoryFloat => memoryFloat.ToArray(), - ReadOnlyMemory memoryDouble => memoryDouble.ToArray(), - _ => throw new NotSupportedException($"Mapping for type {vector.GetType().FullName} is not supported in generic data model.") - }; - } - - #endregion -} diff --git a/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBModelBuilder.cs b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBModelBuilder.cs new file mode 100644 index 000000000000..447b0d0ee939 --- /dev/null +++ b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBModelBuilder.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Reflection; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; +using MongoDB.Bson.Serialization.Attributes; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +/// +/// Customized MongoDB model builder that adds specialized configuration of property storage names +/// (Mongo's reserve key property name and [BsonElement]). +/// +internal class MongoDBModelBuilder() : VectorStoreRecordModelBuilder(s_validationOptions) +{ + private static readonly VectorStoreRecordModelBuildingOptions s_validationOptions = new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + UsesExternalSerializer = true, + + SupportedKeyPropertyTypes = MongoDBConstants.SupportedKeyTypes, + SupportedDataPropertyTypes = MongoDBConstants.SupportedDataTypes, + SupportedEnumerableDataPropertyElementTypes = MongoDBConstants.SupportedDataTypes, + SupportedVectorPropertyTypes = MongoDBConstants.SupportedVectorTypes + }; + + protected override void ProcessTypeProperties(Type type, VectorStoreRecordDefinition? vectorStoreRecordDefinition) + { + base.ProcessTypeProperties(type, vectorStoreRecordDefinition); + + foreach (var property in this.Properties) + { + if (property.PropertyInfo?.GetCustomAttribute() is { } bsonElementAttribute) + { + property.StorageName = bsonElementAttribute.ElementName; + } + } + } +} diff --git a/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBVectorStoreRecordMapper.cs b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBVectorStoreRecordMapper.cs index 2ddb4f594fd7..2d2b6a237229 100644 --- a/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBVectorStoreRecordMapper.cs +++ b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBVectorStoreRecordMapper.cs @@ -1,9 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Reflection; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using MongoDB.Bson; using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Attributes; @@ -12,26 +15,29 @@ namespace Microsoft.SemanticKernel.Connectors.MongoDB; [ExcludeFromCodeCoverage] -internal sealed class MongoDBVectorStoreRecordMapper : IVectorStoreRecordMapper +#pragma warning disable CS0618 // IVectorStoreRecordMapper is obsolete +internal sealed class MongoDBVectorStoreRecordMapper : IMongoDBMapper +#pragma warning restore CS0618 { + private readonly VectorStoreRecordModel _model; + /// A key property info of the data model. - private readonly PropertyInfo _keyProperty; + private readonly PropertyInfo? _keyClrProperty; /// A key property name of the data model. - private readonly string _keyPropertyName; + private readonly string _keyPropertyModelName; /// /// Initializes a new instance of the class. /// - /// A helper to access property information for the current data model and record definition. - public MongoDBVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyReader) + /// The model. + public MongoDBVectorStoreRecordMapper(VectorStoreRecordModel model) { - propertyReader.VerifyKeyProperties(MongoDBConstants.SupportedKeyTypes); - propertyReader.VerifyDataProperties(MongoDBConstants.SupportedDataTypes, supportEnumerable: true); - propertyReader.VerifyVectorProperties(MongoDBConstants.SupportedVectorTypes); + this._model = model; - this._keyPropertyName = propertyReader.KeyPropertyName; - this._keyProperty = propertyReader.KeyPropertyInfo; + var keyProperty = model.KeyProperty; + this._keyPropertyModelName = keyProperty.ModelName; + this._keyClrProperty = keyProperty.PropertyInfo; var conventionPack = new ConventionPack { @@ -44,34 +50,69 @@ public MongoDBVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyRe type => type == typeof(TRecord)); } - public BsonDocument MapFromDataToStorageModel(TRecord dataModel) + public BsonDocument MapFromDataToStorageModel(TRecord dataModel, Embedding?[]? generatedEmbeddings) { var document = dataModel.ToBsonDocument(); // Handle key property mapping due to reserved key name in Mongo. if (!document.Contains(MongoDBConstants.MongoReservedKeyPropertyName)) { - var value = document[this._keyPropertyName]; + var value = document[this._keyPropertyModelName]; - document.Remove(this._keyPropertyName); + document.Remove(this._keyPropertyModelName); document[MongoDBConstants.MongoReservedKeyPropertyName] = value; } + // Go over the vector properties; those which have an embedding generator configured on them will have embedding generators, overwrite + // the value in the JSON object with that. + if (generatedEmbeddings is not null) + { + for (var i = 0; i < this._model.VectorProperties.Count; i++) + { + if (generatedEmbeddings[i] is not null) + { + var property = this._model.VectorProperties[i]; + Debug.Assert(property.EmbeddingGenerator is not null); + var embedding = generatedEmbeddings[i]; + document[property.StorageName] = embedding switch + { + Embedding e => BsonArray.Create(e.Vector.ToArray()), + Embedding e => BsonArray.Create(e.Vector.ToArray()), + _ => throw new UnreachableException() + }; + } + } + } + return document; } public TRecord MapFromStorageToDataModel(BsonDocument storageModel, StorageToDataModelMapperOptions options) { // Handle key property mapping due to reserved key name in Mongo. - if (!this._keyPropertyName.Equals(MongoDBConstants.DataModelReservedKeyPropertyName, StringComparison.OrdinalIgnoreCase) && - this._keyProperty.GetCustomAttribute() is null) + if (!this._keyPropertyModelName.Equals(MongoDBConstants.DataModelReservedKeyPropertyName, StringComparison.OrdinalIgnoreCase) && + this._keyClrProperty?.GetCustomAttribute() is null) { var value = storageModel[MongoDBConstants.MongoReservedKeyPropertyName]; storageModel.Remove(MongoDBConstants.MongoReservedKeyPropertyName); - storageModel[this._keyPropertyName] = value; + storageModel[this._keyPropertyModelName] = value; + } + + // For vector properties which have embedding generation configured, we need to remove the embeddings before deserializing + // (we can't go back from an embedding to e.g. string). + // For other cases (no embedding generation), we leave the properties even if IncludeVectors is false. + if (!options.IncludeVectors) + { + foreach (var vectorProperty in this._model.VectorProperties) + { + if (vectorProperty.EmbeddingGenerator is not null) + { + storageModel.Remove(vectorProperty.StorageName); + } + } } return BsonSerializer.Deserialize(storageModel); diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreErrorHandler.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreErrorHandler.cs index 714befadc810..2a57108caead 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreErrorHandler.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreErrorHandler.cs @@ -16,13 +16,19 @@ internal static class VectorStoreErrorHandler /// Run the given model conversion and wrap any exceptions with . /// /// The response type of the operation. - /// The name of the database system the operation is being run on. + /// The name of the vector store system the operation is being run on. + /// The name of the vector store the operation is being run on. /// The name of the collection the operation is being run on. /// The type of database operation being run. /// The operation to run. /// The result of the operation. [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static T RunModelConversion(string databaseSystemName, string collectionName, string operationName, Func operation) + public static T RunModelConversion( + string vectorStoreSystemName, + string? vectorStoreName, + string collectionName, + string operationName, + Func operation) { try { @@ -32,7 +38,8 @@ public static T RunModelConversion(string databaseSystemName, string collecti { throw new VectorStoreRecordMappingException("Failed to convert vector store record.", ex) { - VectorStoreType = databaseSystemName, + VectorStoreSystemName = vectorStoreSystemName, + VectorStoreName = vectorStoreName, CollectionName = collectionName, OperationName = operationName }; diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordMapping.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordMapping.cs index f5b39e396171..cd78fddb4be4 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordMapping.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordMapping.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; -using System.Reflection; namespace Microsoft.Extensions.VectorData; @@ -15,83 +14,6 @@ namespace Microsoft.Extensions.VectorData; [ExcludeFromCodeCoverage] internal static class VectorStoreRecordMapping { - /// - /// Loop through the list of objects and for each one look up the storage name - /// in the and check if the value exists in the . - /// If so, set the value on the record object. - /// - /// The type of the storage properties. - /// The type of the target object. - /// The target object to set the property values on. - /// objects listing the properties on the data model to get values for. - /// Storage property names keyed by data property names. - /// A dictionary of storage values by storage property name. - /// An optional function to convert the storage property values to data property values. - public static void SetValuesOnProperties( - TRecord record, - IEnumerable dataModelPropertiesInfo, - IReadOnlyDictionary dataModelToStorageNameMapping, - IReadOnlyDictionary storageValues, - Func? storageValueConverter = null) - { - var propertiesInfoWithValues = BuildPropertiesInfoWithValues( - dataModelPropertiesInfo, - dataModelToStorageNameMapping, - storageValues, - storageValueConverter); - - SetPropertiesOnRecord(record, propertiesInfoWithValues); - } - - /// - /// Build a list of properties with their values from the given data model properties and storage values. - /// - /// The type of the storage properties. - /// objects listing the properties on the data model to get values for. - /// Storage property names keyed by data property names. - /// A dictionary of storage values by storage property name. - /// An optional function to convert the storage property values to data property values. - /// The list of data property objects and their values. - public static IEnumerable> BuildPropertiesInfoWithValues( - IEnumerable dataModelPropertiesInfo, - IReadOnlyDictionary dataModelToStorageNameMapping, - IReadOnlyDictionary storageValues, - Func? storageValueConverter = null) - { - foreach (var propertyInfo in dataModelPropertiesInfo) - { - if (dataModelToStorageNameMapping.TryGetValue(propertyInfo.Name, out var storageName) && - storageValues.TryGetValue(storageName, out var storageValue)) - { - if (storageValueConverter is not null) - { - var convertedStorageValue = storageValueConverter(storageValue, propertyInfo.PropertyType); - yield return new KeyValuePair(propertyInfo, convertedStorageValue); - } - else - { - yield return new KeyValuePair(propertyInfo, (object?)storageValue); - } - } - } - } - - /// - /// Set the given list of properties with their values on the given object. - /// - /// The type of the target object. - /// The target object to set the property values on. - /// A list of properties and their values to set. - public static void SetPropertiesOnRecord( - TRecord record, - IEnumerable> propertiesInfoWithValues) - { - foreach (var propertyInfoWithValue in propertiesInfoWithValues) - { - propertyInfoWithValue.Key.SetValue(record, propertyInfoWithValue.Value); - } - } - /// /// Create an enumerable of the required type from the input enumerable. /// diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs deleted file mode 100644 index d259a1ac0f4f..000000000000 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs +++ /dev/null @@ -1,806 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text.Json; -using System.Text.Json.Serialization; - -namespace Microsoft.Extensions.VectorData; - -/// -/// Contains helpers for reading vector store model properties and their attributes. -/// -[ExcludeFromCodeCoverage] -#pragma warning disable CA1812 // Used in some projects but not all, so need to suppress to avoid warnings in those it's not used in. -internal sealed class VectorStoreRecordPropertyReader -#pragma warning restore CA1812 -{ - /// The of the data model. - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicConstructors)] - private readonly Type _dataModelType; - - /// A definition of the current storage model. - private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition; - - /// Options for configuring the behavior of this class. - private readonly VectorStoreRecordPropertyReaderOptions _options; - - /// The key properties from the definition. - private readonly List _keyProperties; - - /// The data properties from the definition. - private readonly List _dataProperties; - - /// The vector properties from the definition. - private readonly List _vectorProperties; - - /// The of the parameterless constructor from the data model if one exists. - private readonly Lazy _parameterlessConstructorInfo; - - /// The key objects from the data model. - private List? _keyPropertiesInfo; - - /// The data objects from the data model. - private List? _dataPropertiesInfo; - - /// The vector objects from the data model. - private List? _vectorPropertiesInfo; - - /// A lazy initialized map of data model property names to the names under which they are stored in the data store. - private readonly Lazy> _storagePropertyNamesMap; - - /// A lazy initialized list of storage names of key properties. - private readonly Lazy> _keyPropertyStoragePropertyNames; - - /// A lazy initialized list of storage names of data properties. - private readonly Lazy> _dataPropertyStoragePropertyNames; - - /// A lazy initialized list of storage names of vector properties. - private readonly Lazy> _vectorPropertyStoragePropertyNames; - - /// A lazy initialized map of data model property names to the names they will have if serialized to JSON. - private readonly Lazy> _jsonPropertyNamesMap; - - /// A lazy initialized list of json names of key properties. - private readonly Lazy> _keyPropertyJsonNames; - - /// A lazy initialized list of json names of data properties. - private readonly Lazy> _dataPropertyJsonNames; - - /// A lazy initialized list of json names of vector properties. - private readonly Lazy> _vectorPropertyJsonNames; - - public VectorStoreRecordPropertyReader( - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicConstructors)] Type dataModelType, - VectorStoreRecordDefinition? vectorStoreRecordDefinition, - VectorStoreRecordPropertyReaderOptions? options) - { - this._dataModelType = dataModelType; - this._options = options ?? new VectorStoreRecordPropertyReaderOptions(); - - // If a definition is provided, use it. Otherwise, create one from the type. - if (vectorStoreRecordDefinition is not null) - { - // Here we received a definition, which gives us all of the information we need. - // Some mappers though need to set properties on the data model using reflection - // so we may still need to find the PropertyInfo objects on the data model later if required. - this._vectorStoreRecordDefinition = vectorStoreRecordDefinition; - } - else - { - // Here we didn't receive a definition, so we need to derive the information from - // the data model. Since we may need the PropertyInfo objects later to read or write - // property values on the data model, we save them for later in case we need them. - var propertiesInfo = FindPropertiesInfo(dataModelType); - this._vectorStoreRecordDefinition = CreateVectorStoreRecordDefinitionFromType(propertiesInfo); - - this._keyPropertiesInfo = propertiesInfo.KeyProperties; - this._dataPropertiesInfo = propertiesInfo.DataProperties; - this._vectorPropertiesInfo = propertiesInfo.VectorProperties; - } - - // Verify the definition to make sure it does not have too many or too few of each property type. - (this._keyProperties, this._dataProperties, this._vectorProperties) = SplitDefinitionAndVerify( - dataModelType.Name, - this._vectorStoreRecordDefinition, - this._options.SupportsMultipleKeys, - this._options.SupportsMultipleVectors, - this._options.RequiresAtLeastOneVector); - - // Setup lazy initializers. - this._storagePropertyNamesMap = new Lazy>(() => - { - return BuildPropertyNameToStorageNameMap((this._keyProperties, this._dataProperties, this._vectorProperties)); - }); - - this._parameterlessConstructorInfo = new Lazy(() => - { - var constructor = dataModelType.GetConstructor(Type.EmptyTypes); - if (constructor == null) - { - throw new ArgumentException($"Type {dataModelType.FullName} must have a parameterless constructor."); - } - - return constructor; - }); - - this._keyPropertyStoragePropertyNames = new Lazy>(() => - { - var storagePropertyNames = this._storagePropertyNamesMap.Value; - return this._keyProperties.Select(x => storagePropertyNames[x.DataModelPropertyName]).ToList(); - }); - - this._dataPropertyStoragePropertyNames = new Lazy>(() => - { - var storagePropertyNames = this._storagePropertyNamesMap.Value; - return this._dataProperties.Select(x => storagePropertyNames[x.DataModelPropertyName]).ToList(); - }); - - this._vectorPropertyStoragePropertyNames = new Lazy>(() => - { - var storagePropertyNames = this._storagePropertyNamesMap.Value; - return this._vectorProperties.Select(x => storagePropertyNames[x.DataModelPropertyName]).ToList(); - }); - - this._jsonPropertyNamesMap = new Lazy>(() => - { - return BuildPropertyNameToJsonPropertyNameMap( - (this._keyProperties, this._dataProperties, this._vectorProperties), - dataModelType, - this._options?.JsonSerializerOptions); - }); - - this._keyPropertyJsonNames = new Lazy>(() => - { - var jsonPropertyNamesMap = this._jsonPropertyNamesMap.Value; - return this._keyProperties.Select(x => jsonPropertyNamesMap[x.DataModelPropertyName]).ToList(); - }); - - this._dataPropertyJsonNames = new Lazy>(() => - { - var jsonPropertyNamesMap = this._jsonPropertyNamesMap.Value; - return this._dataProperties.Select(x => jsonPropertyNamesMap[x.DataModelPropertyName]).ToList(); - }); - - this._vectorPropertyJsonNames = new Lazy>(() => - { - var jsonPropertyNamesMap = this._jsonPropertyNamesMap.Value; - return this._vectorProperties.Select(x => jsonPropertyNamesMap[x.DataModelPropertyName]).ToList(); - }); - } - - /// Gets the record definition of the current storage model. - public VectorStoreRecordDefinition RecordDefinition => this._vectorStoreRecordDefinition; - - /// Gets the list of properties from the record definition. - public IReadOnlyList Properties => this._vectorStoreRecordDefinition.Properties; - - /// Gets the first object from the record definition that was provided or that was generated from the data model. - public VectorStoreRecordKeyProperty KeyProperty => this._keyProperties[0]; - - /// Gets all objects from the record definition that was provided or that was generated from the data model. - public IReadOnlyList KeyProperties => this._keyProperties; - - /// Gets all objects from the record definition that was provided or that was generated from the data model. - public IReadOnlyList DataProperties => this._dataProperties; - - /// Gets the first objects from the record definition that was provided or that was generated from the data model. - public VectorStoreRecordVectorProperty? VectorProperty => this._vectorProperties.Count > 0 ? this._vectorProperties[0] : null; - - /// Gets all objects from the record definition that was provided or that was generated from the data model. - public IReadOnlyList VectorProperties => this._vectorProperties; - - /// Gets the parameterless constructor if one exists, throws otherwise. - public ConstructorInfo ParameterLessConstructorInfo => this._parameterlessConstructorInfo.Value; - - /// Gets the first key property info object. - public PropertyInfo KeyPropertyInfo - { - get - { - this.LoadPropertyInfoIfNeeded(); - return this._keyPropertiesInfo![0]; - } - } - - /// Gets the key property info objects. - public IReadOnlyList KeyPropertiesInfo - { - get - { - this.LoadPropertyInfoIfNeeded(); - return this._keyPropertiesInfo!; - } - } - - /// Gets the data property info objects. - public IReadOnlyList DataPropertiesInfo - { - get - { - this.LoadPropertyInfoIfNeeded(); - return this._dataPropertiesInfo!; - } - } - - /// Gets the vector property info objects. - public IReadOnlyList VectorPropertiesInfo - { - get - { - this.LoadPropertyInfoIfNeeded(); - return this._vectorPropertiesInfo!; - } - } - - /// Gets the name of the first vector property in the definition or null if there are no vectors. - public string? FirstVectorPropertyName => this._vectorProperties.FirstOrDefault()?.DataModelPropertyName; - - /// Gets the first vector PropertyInfo object in the data model or null if there are no vectors. - public PropertyInfo? FirstVectorPropertyInfo => this.VectorPropertiesInfo.Count > 0 ? this.VectorPropertiesInfo[0] : null; - - /// Gets the property name of the first key property in the definition. - public string KeyPropertyName => this._keyProperties[0].DataModelPropertyName; - - /// Gets the storage name of the first key property in the definition. - public string KeyPropertyStoragePropertyName => this._keyPropertyStoragePropertyNames.Value[0]; - - /// Gets the storage names of all the properties in the definition. - public IReadOnlyDictionary StoragePropertyNamesMap => this._storagePropertyNamesMap.Value; - - /// Gets the storage names of the key properties in the definition. - public IReadOnlyList KeyPropertyStoragePropertyNames => this._keyPropertyStoragePropertyNames.Value; - - /// Gets the storage names of the data properties in the definition. - public IReadOnlyList DataPropertyStoragePropertyNames => this._dataPropertyStoragePropertyNames.Value; - - /// Gets the storage name of the first vector property in the definition or null if there are no vectors. - public string? FirstVectorPropertyStoragePropertyName => this.FirstVectorPropertyName == null ? null : this.StoragePropertyNamesMap[this.FirstVectorPropertyName]; - - /// Gets the storage names of the vector properties in the definition. - public IReadOnlyList VectorPropertyStoragePropertyNames => this._vectorPropertyStoragePropertyNames.Value; - - /// Gets the json name of the first key property in the definition. - public string KeyPropertyJsonName => this.KeyPropertyJsonNames[0]; - - /// Gets the json names of the key properties in the definition. - public IReadOnlyList KeyPropertyJsonNames => this._keyPropertyJsonNames.Value; - - /// Gets the json names of the data properties in the definition. - public IReadOnlyList DataPropertyJsonNames => this._dataPropertyJsonNames.Value; - - /// Gets the json name of the first vector property in the definition or null if there are no vectors. - public string? FirstVectorPropertyJsonName => this.FirstVectorPropertyName == null ? null : this.JsonPropertyNamesMap[this.FirstVectorPropertyName]; - - /// Gets the json names of the vector properties in the definition. - public IReadOnlyList VectorPropertyJsonNames => this._vectorPropertyJsonNames.Value; - - /// A map of data model property names to the names they will have if serialized to JSON. - public IReadOnlyDictionary JsonPropertyNamesMap => this._jsonPropertyNamesMap.Value; - - /// Verify that the data model has a parameterless constructor. - public void VerifyHasParameterlessConstructor() - { - var constructorInfo = this._parameterlessConstructorInfo.Value; - } - - /// Verify that the types of the key properties fall within the provided set. - /// The list of supported types. - public void VerifyKeyProperties(HashSet supportedTypes) - { - VectorStoreRecordPropertyVerification.VerifyPropertyTypes(this._keyProperties, supportedTypes, "Key"); - } - - /// Verify that the types of the data properties fall within the provided set. - /// The list of supported types. - /// A value indicating whether enumerable types are supported where the element type is one of the supported types. - public void VerifyDataProperties(HashSet supportedTypes, bool supportEnumerable) - { - VectorStoreRecordPropertyVerification.VerifyPropertyTypes(this._dataProperties, supportedTypes, "Data", supportEnumerable); - } - - /// Verify that the types of the data properties fall within the provided set. - /// The list of supported types. - /// A value indicating whether enumerable types are supported where the element type is one of the supported types. - public void VerifyDataProperties(HashSet supportedTypes, HashSet supportedEnumerableElementTypes) - { - VectorStoreRecordPropertyVerification.VerifyPropertyTypes(this._dataProperties, supportedTypes, supportedEnumerableElementTypes, "Data"); - } - - /// Verify that the types of the vector properties fall within the provided set. - /// The list of supported types. - public void VerifyVectorProperties(HashSet supportedTypes) - { - VectorStoreRecordPropertyVerification.VerifyPropertyTypes(this._vectorProperties, supportedTypes, "Vector"); - } - - /// - /// Get the storage property name for the given data model property name. - /// - /// The data model property name for which to get the storage property name. - /// The storage property name. - public string GetStoragePropertyName(string dataModelPropertyName) - { - return this._storagePropertyNamesMap.Value[dataModelPropertyName]; - } - - /// - /// Get the name under which a property will be stored if serialized to JSON - /// - /// The data model property name for which to get the JSON name. - /// The JSON name. - public string GetJsonPropertyName(string dataModelPropertyName) - { - return this._jsonPropertyNamesMap.Value[dataModelPropertyName]; - } - - /// - /// Get the vector property with the provided name if a name is provided, and fall back - /// to a vector property in the schema if not. If no name is provided and there is more - /// than one vector property, an exception will be thrown. - /// - /// The search options. - /// Thrown if the provided property name is not a valid vector property name. - public VectorStoreRecordVectorProperty GetVectorPropertyOrSingle(VectorSearchOptions? searchOptions) - { - if (searchOptions is not null) - { -#pragma warning disable CS0618 // Type or member is obsolete - string? vectorPropertyName = searchOptions.VectorPropertyName; -#pragma warning restore CS0618 // Type or member is obsolete - - // If vector property name is provided, try to find it in schema or throw an exception. - if (!string.IsNullOrWhiteSpace(vectorPropertyName)) - { - // Check vector properties by data model property name. - return this.VectorProperties.FirstOrDefault(l => l.DataModelPropertyName.Equals(vectorPropertyName, StringComparison.Ordinal)) - ?? throw new InvalidOperationException($"The {this._dataModelType.FullName} type does not have a vector property named '{vectorPropertyName}'."); - } - else if (searchOptions.VectorProperty is Expression> expression) - { - // VectorPropertiesInfo is not available for VectorStoreGenericDataModel. - IReadOnlyList infos = typeof(TRecord).IsGenericType && typeof(TRecord).GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>) - ? [] : this.VectorPropertiesInfo; - - return GetMatchingProperty(expression, infos, this.VectorProperties); - } - } - - // If vector property name is not provided, check if there is a single vector property, or throw if there are no vectors or more than one. - if (this.VectorProperty is null) - { - throw new InvalidOperationException($"The {this._dataModelType.FullName} type does not have any vector properties."); - } - - if (this.VectorProperties.Count > 1) - { - throw new InvalidOperationException($"The {this._dataModelType.FullName} type has multiple vector properties, please specify your chosen property via options."); - } - - return this.VectorProperty; - } - - /// - /// Get the text data property, that has full text search indexing enabled, with the provided name if a name is provided, and fall back - /// to a text data property in the schema if not. If no name is provided and there is more than one text data property with - /// full text search indexing enabled, an exception will be thrown. - /// - /// The full text search property selector. - /// Thrown if the provided property name is not a valid text data property name. - public VectorStoreRecordDataProperty GetFullTextDataPropertyOrSingle(Expression>? expression) - { - if (expression is not null) - { - // DataPropertiesInfo is not available for VectorStoreGenericDataModel. - IReadOnlyList infos = typeof(TRecord).IsGenericType && typeof(TRecord).GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>) - ? [] : this.DataPropertiesInfo; - - var dataProperty = GetMatchingProperty(expression, this.DataPropertiesInfo, this.DataProperties); - return dataProperty.IsFullTextSearchable - ? dataProperty - : throw new InvalidOperationException($"The text data property named '{dataProperty.DataModelPropertyName}' on the {this._dataModelType.FullName} type must have full text search enabled."); - } - - // If text data property name is not provided, check if a single full text searchable text property exists or throw otherwise. - var fullTextStringProperties = this.DataProperties - .Where(l => l.PropertyType == typeof(string) && l.IsFullTextSearchable) - .ToList(); - - if (fullTextStringProperties.Count == 0) - { - throw new InvalidOperationException($"The {this._dataModelType.FullName} type does not have any text data properties that have full text search enabled."); - } - - if (fullTextStringProperties.Count > 1) - { - throw new InvalidOperationException($"The {this._dataModelType.FullName} type has multiple text data properties that have full text search enabled, please specify your chosen property via options."); - } - - return fullTextStringProperties[0]; - } - - private static TProperty GetMatchingProperty(Expression> expression, - IReadOnlyList propertyInfos, IReadOnlyList properties) - where TProperty : VectorStoreRecordProperty - { - bool data = typeof(TProperty) == typeof(VectorStoreRecordDataProperty); - string expectedGenericModelPropertyName = data - ? nameof(VectorStoreGenericDataModel.Data) - : nameof(VectorStoreGenericDataModel.Vectors); - - MemberExpression? member = expression.Body as MemberExpression; - // (TRecord r) => r.PropertyName is translated into - // (TRecord r) => (object)r.PropertyName for properties that return struct like ReadOnlyMemory. - if (member is null && expression.Body is UnaryExpression unary - && unary.Operand.NodeType == ExpressionType.MemberAccess) - { - member = unary.Operand as MemberExpression; - } - - if (member is not null - && expression.Parameters.Count == 1 - && member.Expression == expression.Parameters[0] - && member.Member is PropertyInfo property) - { - for (int i = 0; i < propertyInfos.Count; i++) - { - if (propertyInfos[i] == property) - { - return properties[i]; - } - } - - throw new InvalidOperationException($"The property {property.Name} of {typeof(TRecord).FullName} is not a {(data ? "Data" : "Vector")} property."); - } - // (VectorStoreGenericDataModel r) => r.Vectors["PropertyName"] - else if (expression.Body is MethodCallExpression methodCall - // It's a Func, object> - && expression.Type.IsGenericType - && expression.Type.GenericTypeArguments.Length == 2 - && expression.Type.GenericTypeArguments[0].IsGenericType - && expression.Type.GenericTypeArguments[0].GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>) - // It's accessing VectorStoreGenericDataModel.Vectors (or Data) - && methodCall.Object is MemberExpression memberAccess - && memberAccess.Member.Name == expectedGenericModelPropertyName - // and has a single argument - && methodCall.Arguments.Count == 1) - { - string name = methodCall.Arguments[0] switch - { - ConstantExpression constant when constant.Value is string text => text, - MemberExpression field when TryGetCapturedValue(field, out object? capturedValue) && capturedValue is string text => text, - _ => throw new InvalidOperationException($"The value of the provided {(data ? "Additional" : "Vector")}Property option is not a valid expression.") - }; - - return properties.FirstOrDefault(l => l.DataModelPropertyName.Equals(name, StringComparison.Ordinal)) - ?? throw new InvalidOperationException($"The {typeof(TRecord).FullName} type does not have a vector property named '{name}'."); - } - - throw new InvalidOperationException($"The value of the provided {(data ? "Additional" : "Vector")}Property option is not a valid expression."); - - static bool TryGetCapturedValue(Expression expression, out object? capturedValue) - { - if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) - { - capturedValue = fieldInfo.GetValue(constant.Value); - return true; - } - - capturedValue = null; - return false; - } - } - - /// - /// Check if we have previously loaded the objects from the data model and if not, load them. - /// - private void LoadPropertyInfoIfNeeded() - { - if (this._keyPropertiesInfo != null) - { - return; - } - - // If we previously built the definition from the data model, the PropertyInfo objects - // from the data model would already be saved. If we didn't though, there could be a mismatch - // between what is defined in the definition and what is in the data model. Therefore, this - // method will throw if any property in the definition is not on the data model. - var propertiesInfo = FindPropertiesInfo(this._dataModelType, this._vectorStoreRecordDefinition); - - this._keyPropertiesInfo = propertiesInfo.KeyProperties; - this._dataPropertiesInfo = propertiesInfo.DataProperties; - this._vectorPropertiesInfo = propertiesInfo.VectorProperties; - } - - /// - /// Split the given into key, data and vector properties and verify that we have the expected numbers of each type. - /// - /// The name of the type that the definition relates to. - /// The to split. - /// A value indicating whether multiple key properties are supported. - /// A value indicating whether multiple vectors are supported. - /// A value indicating whether we need at least one vector. - /// The properties on the split into key, data and vector groupings. - /// Thrown if there are any validation failures with the provided . - private static (List KeyProperties, List DataProperties, List VectorProperties) SplitDefinitionAndVerify( - string typeName, - VectorStoreRecordDefinition definition, - bool supportsMultipleKeys, - bool supportsMultipleVectors, - bool requiresAtLeastOneVector) - { - var keyProperties = definition.Properties.OfType().ToList(); - var dataProperties = definition.Properties.OfType().ToList(); - var vectorProperties = definition.Properties.OfType().ToList(); - - if (keyProperties.Count > 1 && !supportsMultipleKeys) - { - throw new ArgumentException($"Multiple key properties found on type {typeName} or the provided {nameof(VectorStoreRecordDefinition)}."); - } - - if (keyProperties.Count == 0) - { - throw new ArgumentException($"No key property found on type {typeName} or the provided {nameof(VectorStoreRecordDefinition)}."); - } - - if (requiresAtLeastOneVector && vectorProperties.Count == 0) - { - throw new ArgumentException($"No vector property found on type {typeName} or the provided {nameof(VectorStoreRecordDefinition)} while at least one is required."); - } - - if (!supportsMultipleVectors && vectorProperties.Count > 1) - { - throw new ArgumentException($"Multiple vector properties found on type {typeName} or the provided {nameof(VectorStoreRecordDefinition)} while only one is supported."); - } - - return (keyProperties, dataProperties, vectorProperties); - } - - /// - /// Find the properties with , and attributes - /// and verify that they exist and that we have the expected numbers of each type. - /// Return those properties in separate categories. - /// - /// The data model to find the properties on. - /// The categorized properties. - private static (List KeyProperties, List DataProperties, List VectorProperties) FindPropertiesInfo([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] Type type) - { - List keyProperties = new(); - List dataProperties = new(); - List vectorProperties = new(); - - foreach (var property in type.GetProperties()) - { - // Get Key property. - if (property.GetCustomAttribute() is not null) - { - keyProperties.Add(property); - } - - // Get data properties. - if (property.GetCustomAttribute() is not null) - { - dataProperties.Add(property); - } - - // Get Vector properties. - if (property.GetCustomAttribute() is not null) - { - vectorProperties.Add(property); - } - } - - return (keyProperties, dataProperties, vectorProperties); - } - - /// - /// Find the properties listed in the on the and verify - /// that they exist. - /// Return those properties in separate categories. - /// - /// The data model to find the properties on. - /// The property configuration. - /// The categorized properties. - public static (List KeyProperties, List DataProperties, List VectorProperties) FindPropertiesInfo([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] Type type, VectorStoreRecordDefinition vectorStoreRecordDefinition) - { - List keyProperties = new(); - List dataProperties = new(); - List vectorProperties = new(); - - foreach (VectorStoreRecordProperty property in vectorStoreRecordDefinition.Properties) - { - // Key. - if (property is VectorStoreRecordKeyProperty keyPropertyInfo) - { - var keyProperty = type.GetProperty(keyPropertyInfo.DataModelPropertyName); - if (keyProperty == null) - { - throw new ArgumentException($"Key property '{keyPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}."); - } - - keyProperties.Add(keyProperty); - } - // Data. - else if (property is VectorStoreRecordDataProperty dataPropertyInfo) - { - var dataProperty = type.GetProperty(dataPropertyInfo.DataModelPropertyName); - if (dataProperty == null) - { - throw new ArgumentException($"Data property '{dataPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}."); - } - - dataProperties.Add(dataProperty); - } - // Vector. - else if (property is VectorStoreRecordVectorProperty vectorPropertyInfo) - { - var vectorProperty = type.GetProperty(vectorPropertyInfo.DataModelPropertyName); - if (vectorProperty == null) - { - throw new ArgumentException($"Vector property '{vectorPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}."); - } - - vectorProperties.Add(vectorProperty); - } - else - { - throw new ArgumentException($"Unknown property type '{property.GetType().FullName}' in vector store record definition."); - } - } - - return (keyProperties, dataProperties, vectorProperties); - } - - /// - /// Create a by reading the attributes on the provided objects. - /// - /// objects to build a from. - /// The based on the given objects. - private static VectorStoreRecordDefinition CreateVectorStoreRecordDefinitionFromType((List KeyProperties, List DataProperties, List VectorProperties) propertiesInfo) - { - var definitionProperties = new List(); - - // Key properties. - foreach (var keyProperty in propertiesInfo.KeyProperties) - { - var keyAttribute = keyProperty.GetCustomAttribute(); - if (keyAttribute is not null) - { - definitionProperties.Add(new VectorStoreRecordKeyProperty(keyProperty.Name, keyProperty.PropertyType) - { - StoragePropertyName = keyAttribute.StoragePropertyName - }); - } - } - - // Data properties. - foreach (var dataProperty in propertiesInfo.DataProperties) - { - var dataAttribute = dataProperty.GetCustomAttribute(); - if (dataAttribute is not null) - { - definitionProperties.Add(new VectorStoreRecordDataProperty(dataProperty.Name, dataProperty.PropertyType) - { - IsFilterable = dataAttribute.IsFilterable, - IsFullTextSearchable = dataAttribute.IsFullTextSearchable, - StoragePropertyName = dataAttribute.StoragePropertyName - }); - } - } - - // Vector properties. - foreach (var vectorProperty in propertiesInfo.VectorProperties) - { - var vectorAttribute = vectorProperty.GetCustomAttribute(); - if (vectorAttribute is not null) - { - definitionProperties.Add(new VectorStoreRecordVectorProperty(vectorProperty.Name, vectorProperty.PropertyType) - { - Dimensions = vectorAttribute.Dimensions, - IndexKind = vectorAttribute.IndexKind, - DistanceFunction = vectorAttribute.DistanceFunction, - StoragePropertyName = vectorAttribute.StoragePropertyName - }); - } - } - - return new VectorStoreRecordDefinition { Properties = definitionProperties }; - } - - /// - /// Build a map of property names to the names under which they should be saved in storage, for the given properties. - /// - /// The properties to build the map for. - /// The map from property names to the names under which they should be saved in storage. - private static Dictionary BuildPropertyNameToStorageNameMap((List keyProperties, List dataProperties, List vectorProperties) properties) - { - var storagePropertyNameMap = new Dictionary(); - - foreach (var keyProperty in properties.keyProperties) - { - storagePropertyNameMap.Add(keyProperty.DataModelPropertyName, keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName); - } - - foreach (var dataProperty in properties.dataProperties) - { - storagePropertyNameMap.Add(dataProperty.DataModelPropertyName, dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName); - } - - foreach (var vectorProperty in properties.vectorProperties) - { - storagePropertyNameMap.Add(vectorProperty.DataModelPropertyName, vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName); - } - - return storagePropertyNameMap; - } - - /// - /// Build a map of property names to the names that they would have if serialized to JSON. - /// - /// The properties to build the map for. - /// The data model type that the property belongs to. - /// The options used for JSON serialization. - /// The map from property names to the names that they would have if serialized to JSON. - private static Dictionary BuildPropertyNameToJsonPropertyNameMap( - (List keyProperties, List dataProperties, List vectorProperties) properties, - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] Type dataModel, - JsonSerializerOptions? options) - { - var jsonPropertyNameMap = new Dictionary(); - - foreach (var keyProperty in properties.keyProperties) - { - jsonPropertyNameMap.Add(keyProperty.DataModelPropertyName, GetJsonPropertyName(keyProperty, dataModel, options)); - } - - foreach (var dataProperty in properties.dataProperties) - { - jsonPropertyNameMap.Add(dataProperty.DataModelPropertyName, GetJsonPropertyName(dataProperty, dataModel, options)); - } - - foreach (var vectorProperty in properties.vectorProperties) - { - jsonPropertyNameMap.Add(vectorProperty.DataModelPropertyName, GetJsonPropertyName(vectorProperty, dataModel, options)); - } - - return jsonPropertyNameMap; - } - - /// - /// Get the JSON property name of a property by using the if available, otherwise - /// using the if available, otherwise falling back to the property name. - /// The provided may not actually contain the property, e.g. when the user has a data model that - /// doesn't resemble the stored data and where they are using a custom mapper. - /// - /// The property to retrieve a JSON name for. - /// The data model type that the property belongs to. - /// The options used for JSON serialization. - /// The JSON property name. - private static string GetJsonPropertyName(VectorStoreRecordProperty property, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] Type dataModel, JsonSerializerOptions? options) - { - var propertyInfo = dataModel.GetProperty(property.DataModelPropertyName); - - if (propertyInfo != null) - { - var jsonPropertyNameAttribute = propertyInfo.GetCustomAttribute(); - if (jsonPropertyNameAttribute is not null) - { - return jsonPropertyNameAttribute.Name; - } - } - - if (options?.PropertyNamingPolicy is not null) - { - return options.PropertyNamingPolicy.ConvertName(property.DataModelPropertyName); - } - - return property.DataModelPropertyName; - } -} diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReaderOptions.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReaderOptions.cs deleted file mode 100644 index 7404106d1a27..000000000000 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReaderOptions.cs +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Diagnostics.CodeAnalysis; -using System.Text.Json; - -namespace Microsoft.Extensions.VectorData; - -/// -/// Contains options for . -/// -[ExcludeFromCodeCoverage] -internal sealed class VectorStoreRecordPropertyReaderOptions -{ - /// - /// Gets or sets a value indicating whether the connector/db supports multiple key properties. - /// - public bool SupportsMultipleKeys { get; set; } = false; - - /// - /// Gets or sets a value indicating whether the connector/db supports multiple vector properties. - /// - public bool SupportsMultipleVectors { get; set; } = true; - - /// - /// Gets or sets a value indicating whether the connector/db requires at least one vector property. - /// - public bool RequiresAtLeastOneVector { get; set; } = false; - - /// - /// Gets or sets the json serializer options that the connector might be using for storage serialization. - /// - public JsonSerializerOptions? JsonSerializerOptions { get; set; } -} diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs index 08337bd0f138..719b5f88fcf6 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs @@ -191,65 +191,6 @@ var enumerableType when GetGenericEnumerableInterface(enumerableType) is Type en return null; } - internal static bool IsGenericDataModel(Type recordType) - => recordType.IsGenericType && recordType.GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>); - - /// - /// Checks that if the provided is a that the key type is supported by the default mappers. - /// If not supported, a custom mapper must be supplied, otherwise an exception is thrown. - /// - /// The type of the record data model used by the connector. - /// A value indicating whether a custom mapper was supplied to the connector - /// The list of key types supported by the default mappers. - /// Thrown if the key type of the is not supported by the default mappers and a custom mapper was not supplied. - public static void VerifyGenericDataModelKeyType(Type recordType, bool customMapperSupplied, IEnumerable allowedKeyTypes) - { - // If we are not dealing with a generic data model, no need to check anything else. - if (!IsGenericDataModel(recordType)) - { - return; - } - - // If the key type is supported, we are good. - var keyType = recordType.GetGenericArguments()[0]; - if (allowedKeyTypes.Contains(keyType)) - { - return; - } - - // If the key type is not supported out of the box, but a custom mapper was supplied, we are good. - if (customMapperSupplied) - { - return; - } - - throw new ArgumentException($"The key type '{keyType.FullName}' of data model '{nameof(VectorStoreGenericDataModel)}' is not supported by the default mappers. " + - $"Only the following key types are supported: {string.Join(", ", allowedKeyTypes)}. Please provide your own mapper to map to your chosen key type."); - } - - /// - /// Checks that if the provided is a that a is also provided. - /// - /// The type of the record data model used by the connector. - /// A value indicating whether a record definition was supplied to the connector. - /// Thrown if a is not provided when using . - public static void VerifyGenericDataModelDefinitionSupplied(Type recordType, bool recordDefinitionSupplied) - { - // If we are not dealing with a generic data model, no need to check anything else. - if (!recordType.IsGenericType || recordType.GetGenericTypeDefinition() != typeof(VectorStoreGenericDataModel<>)) - { - return; - } - - // If we are dealing with a generic data model, and a record definition was supplied, we are good. - if (recordDefinitionSupplied) - { - return; - } - - throw new ArgumentException($"A {nameof(VectorStoreRecordDefinition)} must be provided when using '{nameof(VectorStoreGenericDataModel)}'."); - } - #if NET6_0_OR_GREATER private static readonly ConstructorInfo s_objectGetDefaultConstructorInfo = typeof(object).GetConstructor(Type.EmptyTypes)!; #endif diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/CompilerServicesAttributes.cs b/dotnet/src/InternalUtilities/src/Diagnostics/CompilerServicesAttributes.cs index bba0ffc78584..f72f48d1c65f 100644 --- a/dotnet/src/InternalUtilities/src/Diagnostics/CompilerServicesAttributes.cs +++ b/dotnet/src/InternalUtilities/src/Diagnostics/CompilerServicesAttributes.cs @@ -7,6 +7,7 @@ #if !NETCOREAPP #pragma warning disable IDE0005 // Using directive is unnecessary. +using System.ComponentModel; using System.Diagnostics.CodeAnalysis; namespace System.Runtime.CompilerServices; @@ -23,4 +24,38 @@ public CallerArgumentExpressionAttribute(string parameterName) public string ParameterName { get; } } +/// Specifies that a type has required members or that a member is required. +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = false)] +[EditorBrowsable(EditorBrowsableState.Never)] +internal sealed class RequiredMemberAttribute : Attribute; + +[AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)] +internal sealed class CompilerFeatureRequiredAttribute : Attribute +{ + public CompilerFeatureRequiredAttribute(string featureName) + { + this.FeatureName = featureName; + } + + /// + /// The name of the compiler feature. + /// + public string FeatureName { get; } + + /// + /// If true, the compiler can choose to allow access to the location where this attribute is applied if it does not understand . + /// + public bool IsOptional { get; init; } + + /// + /// The used for the ref structs C# feature. + /// + public const string RefStructs = nameof(RefStructs); + + /// + /// The used for the required members C# feature. + /// + public const string RequiredMembers = nameof(RequiredMembers); +} + #endif diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/IsExternalInit.cs b/dotnet/src/InternalUtilities/src/Diagnostics/IsExternalInit.cs index 7bd800e1dd6f..bf1c27afb2ab 100644 --- a/dotnet/src/InternalUtilities/src/Diagnostics/IsExternalInit.cs +++ b/dotnet/src/InternalUtilities/src/Diagnostics/IsExternalInit.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +#if !NET8_0_OR_GREATER + namespace System.Runtime.CompilerServices; /// @@ -7,3 +9,5 @@ namespace System.Runtime.CompilerServices; /// This class should not be used by developers in source code. /// internal static class IsExternalInit; + +#endif diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/KernelVerify.cs b/dotnet/src/InternalUtilities/src/Diagnostics/KernelVerify.cs new file mode 100644 index 000000000000..1cb1c96ae181 --- /dev/null +++ b/dotnet/src/InternalUtilities/src/Diagnostics/KernelVerify.cs @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Text.RegularExpressions; + +namespace Microsoft.SemanticKernel; + +[ExcludeFromCodeCoverage] +internal static partial class KernelVerify +{ +#if NET + [GeneratedRegex("^[0-9A-Za-z_]*$")] + private static partial Regex AsciiLettersDigitsUnderscoresRegex(); +#else + private static Regex AsciiLettersDigitsUnderscoresRegex() => s_asciiLettersDigitsUnderscoresRegex; + private static readonly Regex s_asciiLettersDigitsUnderscoresRegex = new("^[0-9A-Za-z_]*$", RegexOptions.Compiled); +#endif + + internal static void ValidPluginName([NotNull] string? pluginName, IReadOnlyKernelPluginCollection? plugins = null, [CallerArgumentExpression(nameof(pluginName))] string? paramName = null) + { + Verify.NotNullOrWhiteSpace(pluginName); + if (!AsciiLettersDigitsUnderscoresRegex().IsMatch(pluginName)) + { + Verify.ThrowArgumentInvalidName("plugin name", pluginName, paramName); + } + + if (plugins is not null && plugins.Contains(pluginName)) + { + throw new ArgumentException($"A plugin with the name '{pluginName}' already exists."); + } + } + + internal static void ValidFunctionName([NotNull] string? functionName, [CallerArgumentExpression(nameof(functionName))] string? paramName = null) + { + Verify.NotNullOrWhiteSpace(functionName); + if (!AsciiLettersDigitsUnderscoresRegex().IsMatch(functionName)) + { + Verify.ThrowArgumentInvalidName("function name", functionName, paramName); + } + } + + /// + /// Make sure every function parameter name is unique + /// + /// List of parameters + internal static void ParametersUniqueness(IReadOnlyList parameters) + { + int count = parameters.Count; + if (count > 0) + { + var seen = new HashSet(StringComparer.OrdinalIgnoreCase); + for (int i = 0; i < count; i++) + { + KernelParameterMetadata p = parameters[i]; + if (string.IsNullOrWhiteSpace(p.Name)) + { + string paramName = $"{nameof(parameters)}[{i}].{p.Name}"; + if (p.Name is null) + { + Verify.ThrowArgumentNullException(paramName); + } + else + { + Verify.ThrowArgumentWhiteSpaceException(paramName); + } + } + + if (!seen.Add(p.Name)) + { + throw new ArgumentException($"The function has two or more parameters with the same name '{p.Name}'"); + } + } + } + } +} diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/LoggingExtensions.cs b/dotnet/src/InternalUtilities/src/Diagnostics/LoggingExtensions.cs new file mode 100644 index 000000000000..8fa8c4a4125c --- /dev/null +++ b/dotnet/src/InternalUtilities/src/Diagnostics/LoggingExtensions.cs @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Microsoft.SemanticKernel.Diagnostics; + +[ExcludeFromCodeCoverage] +internal static partial class LoggingExtensions +{ + internal static async Task RunWithLoggingAsync( + ILogger logger, + string operationName, + Func operation) + { + logger.LogInvoked(operationName); + + try + { + await operation().ConfigureAwait(false); + + logger.LogCompleted(operationName); + } + catch (OperationCanceledException) + { + logger.LogInvocationCanceled(operationName); + throw; + } + catch (Exception ex) + { + logger.LogInvocationFailed(operationName, ex); + throw; + } + } + + internal static async Task RunWithLoggingAsync( + ILogger logger, + string operationName, + Func> operation) + { + logger.LogInvoked(operationName); + + try + { + var result = await operation().ConfigureAwait(false); + + logger.LogCompleted(operationName); + + return result; + } + catch (OperationCanceledException) + { + logger.LogInvocationCanceled(operationName); + throw; + } + catch (Exception ex) + { + logger.LogInvocationFailed(operationName, ex); + throw; + } + } + + internal static async IAsyncEnumerable RunWithLoggingAsync( + ILogger logger, + string operationName, + Func> operation, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + logger.LogInvoked(operationName); + + IAsyncEnumerator enumerator; + + try + { + enumerator = operation().GetAsyncEnumerator(cancellationToken); + } + catch (OperationCanceledException) + { + logger.LogInvocationCanceled(operationName); + throw; + } + catch (Exception ex) + { + logger.LogInvocationFailed(operationName, ex); + throw; + } + + try + { + while (true) + { + try + { + if (!await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + break; + } + } + catch (OperationCanceledException) + { + logger.LogInvocationCanceled(operationName); + throw; + } + catch (Exception ex) + { + logger.LogInvocationFailed(operationName, ex); + throw; + } + + yield return enumerator.Current; + } + + logger.LogCompleted(operationName); + } + finally + { + await enumerator.DisposeAsync().ConfigureAwait(false); + } + } + + [LoggerMessage(LogLevel.Debug, "{OperationName} invoked.")] + private static partial void LogInvoked(this ILogger logger, string operationName); + + [LoggerMessage(LogLevel.Debug, "{OperationName} completed.")] + private static partial void LogCompleted(this ILogger logger, string operationName); + + [LoggerMessage(LogLevel.Debug, "{OperationName} canceled.")] + private static partial void LogInvocationCanceled(this ILogger logger, string operationName); + + [LoggerMessage(LogLevel.Error, "{OperationName} failed.")] + private static partial void LogInvocationFailed(this ILogger logger, string operationName, Exception exception); +} diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/UnreachableException.cs b/dotnet/src/InternalUtilities/src/Diagnostics/UnreachableException.cs index 616073f54705..1e2d9f9b0b02 100644 --- a/dotnet/src/InternalUtilities/src/Diagnostics/UnreachableException.cs +++ b/dotnet/src/InternalUtilities/src/Diagnostics/UnreachableException.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -#if NETSTANDARD2_0 +#if !NET8_0_OR_GREATER // Polyfill for using UnreachableException with .NET Standard 2.0 diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/Verify.cs b/dotnet/src/InternalUtilities/src/Diagnostics/Verify.cs index c792d50d13e0..da5bf57dd103 100644 --- a/dotnet/src/InternalUtilities/src/Diagnostics/Verify.cs +++ b/dotnet/src/InternalUtilities/src/Diagnostics/Verify.cs @@ -185,7 +185,7 @@ internal static void ParametersUniqueness(IReadOnlyList #endif [DoesNotReturn] - private static void ThrowArgumentInvalidName(string kind, string name, string? paramName) => + internal static void ThrowArgumentInvalidName(string kind, string name, string? paramName) => throw new ArgumentException($"A {kind} can contain only ASCII letters, digits, and underscores: '{name}' is not a valid name.", paramName); [DoesNotReturn] @@ -242,4 +242,12 @@ internal static void ValidHostnameSegment(string hostNameSegment, [CallerArgumen throw new ArgumentException($"The location '{hostNameSegment}' is not valid. Location must start and end with alphanumeric characters and can contain hyphens and underscores.", paramName); } } + + internal static void NotLessThan(int value, int limit, [CallerArgumentExpression(nameof(value))] string? paramName = null) + { + if (value < limit) + { + throw new ArgumentOutOfRangeException(paramName, $"{paramName} must be greater than or equal to {limit}."); + } + } } diff --git a/dotnet/src/InternalUtilities/src/Http/HttpContentExtensions.cs b/dotnet/src/InternalUtilities/src/Http/HttpContentExtensions.cs index 51d9acf0509d..dd7ac895b984 100644 --- a/dotnet/src/InternalUtilities/src/Http/HttpContentExtensions.cs +++ b/dotnet/src/InternalUtilities/src/Http/HttpContentExtensions.cs @@ -24,11 +24,7 @@ public static async Task ReadAsStringWithExceptionMappingAsync(this Http { try { -#if NET5_0_OR_GREATER return await httpContent.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); -#else - return await httpContent.ReadAsStringAsync().ConfigureAwait(false); -#endif } catch (HttpRequestException ex) { @@ -46,11 +42,7 @@ public static async Task ReadAsStreamAndTranslateExceptionAsync(this Htt { try { -#if NET5_0_OR_GREATER return await httpContent.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); -#else - return await httpContent.ReadAsStreamAsync().ConfigureAwait(false); -#endif } catch (HttpRequestException ex) { @@ -68,11 +60,7 @@ public static async Task ReadAsByteArrayAndTranslateExceptionAsync(this { try { -#if NET5_0_OR_GREATER return await httpContent.ReadAsByteArrayAsync(cancellationToken).ConfigureAwait(false); -#else - return await httpContent.ReadAsByteArrayAsync().ConfigureAwait(false); -#endif } catch (HttpRequestException ex) { diff --git a/dotnet/src/InternalUtilities/src/Http/HttpContentPolyfills.cs b/dotnet/src/InternalUtilities/src/Http/HttpContentPolyfills.cs new file mode 100644 index 000000000000..aea223102298 --- /dev/null +++ b/dotnet/src/InternalUtilities/src/Http/HttpContentPolyfills.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +#if !NET5_0_OR_GREATER + +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Http; + +[ExcludeFromCodeCoverage] +internal static class HttpContentPolyfills +{ + internal static Task ReadAsStringAsync(this HttpContent httpContent, CancellationToken cancellationToken) + => httpContent.ReadAsStringAsync(); + + internal static Task ReadAsStreamAsync(this HttpContent httpContent, CancellationToken cancellationToken) + => httpContent.ReadAsStreamAsync(); + + internal static Task ReadAsByteArrayAsync(this HttpContent httpContent, CancellationToken cancellationToken) + => httpContent.ReadAsByteArrayAsync(); +} + +#endif diff --git a/dotnet/src/InternalUtilities/src/RestrictedInternalUtilities.props b/dotnet/src/InternalUtilities/src/RestrictedInternalUtilities.props new file mode 100644 index 000000000000..f4304f16a9be --- /dev/null +++ b/dotnet/src/InternalUtilities/src/RestrictedInternalUtilities.props @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/dotnet/src/InternalUtilities/src/System/EmptyKeyedServiceProvider.cs b/dotnet/src/InternalUtilities/src/System/EmptyKeyedServiceProvider.cs new file mode 100644 index 000000000000..c7aaf6b4fd3b --- /dev/null +++ b/dotnet/src/InternalUtilities/src/System/EmptyKeyedServiceProvider.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.Extensions.DependencyInjection; + +/// Provides an implementation of that contains no services. +internal sealed class EmptyKeyedServiceProvider : IKeyedServiceProvider +{ + /// Gets a singleton instance of . + public static EmptyKeyedServiceProvider Instance { get; } = new(); + + /// + public object? GetService(Type serviceType) => null; + + /// + public object? GetKeyedService(Type serviceType, object? serviceKey) => null; + + /// + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) => + this.GetKeyedService(serviceType, serviceKey) ?? + throw new InvalidOperationException($"No service for type '{serviceType}' and key '{serviceKey}' has been registered."); +} diff --git a/dotnet/src/InternalUtilities/src/System/IndexRange.cs b/dotnet/src/InternalUtilities/src/System/IndexRange.cs index 439e6e844fb6..32c6c9c12538 100644 --- a/dotnet/src/InternalUtilities/src/System/IndexRange.cs +++ b/dotnet/src/InternalUtilities/src/System/IndexRange.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -#if NETSTANDARD2_0 +#if !NET8_0_OR_GREATER // Polyfill for using Index and Range with .NET Standard 2.0 (see https://www.meziantou.net/how-to-use-csharp-8-indices-and-ranges-in-dotnet-standard-2-0-and-dotn.htm) diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs index c0a76d0e0c2c..4188bc6b6994 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs @@ -161,7 +161,7 @@ internal KernelFunction(string name, string description, IReadOnlyList parameters, KernelReturnParameterMetadata? returnParameter = null, Dictionary? executionSettings = null, ReadOnlyDictionary? additionalMetadata = null) { Verify.NotNull(name); - Verify.ParametersUniqueness(parameters); + KernelVerify.ParametersUniqueness(parameters); this.Metadata = new KernelFunctionMetadata(name) { @@ -197,7 +197,7 @@ internal KernelFunction(string name, string? pluginName, string description, IRe internal KernelFunction(string name, string? pluginName, string description, IReadOnlyList parameters, JsonSerializerOptions jsonSerializerOptions, KernelReturnParameterMetadata? returnParameter = null, Dictionary? executionSettings = null, ReadOnlyDictionary? additionalMetadata = null) { Verify.NotNull(name); - Verify.ParametersUniqueness(parameters); + KernelVerify.ParametersUniqueness(parameters); Verify.NotNull(jsonSerializerOptions); this.Metadata = new KernelFunctionMetadata(name) diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionMetadata.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionMetadata.cs index cae651f74fea..034eeb72833a 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionMetadata.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionMetadata.cs @@ -58,7 +58,7 @@ public string Name init { Verify.NotNull(value); - Verify.ValidFunctionName(value); + KernelVerify.ValidFunctionName(value); this._name = value; } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelPlugin.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelPlugin.cs index c86faaf03065..1b6aab3c87a3 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelPlugin.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelPlugin.cs @@ -29,7 +29,7 @@ public abstract class KernelPlugin : IEnumerable /// is an invalid plugin name. protected KernelPlugin(string name, string? description = null) { - Verify.ValidPluginName(name); + KernelVerify.ValidPluginName(name); this.Name = name; this.Description = !string.IsNullOrWhiteSpace(description) ? description! : ""; diff --git a/dotnet/src/SemanticKernel.Abstractions/Services/EmptyServiceProvider.cs b/dotnet/src/SemanticKernel.Abstractions/Services/EmptyServiceProvider.cs index 08305ca9df83..ff676289a399 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Services/EmptyServiceProvider.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Services/EmptyServiceProvider.cs @@ -53,9 +53,9 @@ private static Array CreateArray(Type elementType, int length) } private static bool VerifyAotCompatibility => -#if NETFRAMEWORK || NETSTANDARD2_0 - false; -#else +#if NET8_0_OR_GREATER !System.Runtime.CompilerServices.RuntimeFeature.IsDynamicCodeSupported; +#else + false; #endif } diff --git a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs index b39976adbebf..7c6c8f8af33c 100644 --- a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs +++ b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs @@ -4,25 +4,42 @@ namespace SemanticKernel.AotTests.UnitTests.Search; -internal sealed class MockVectorizableTextSearch : IVectorizableTextSearch +internal sealed class MockVectorizableTextSearch : IVectorSearch { private readonly IAsyncEnumerable> _searchResults; public MockVectorizableTextSearch(IEnumerable> searchResults) { - this._searchResults = ToAsyncEnumerable(searchResults); + this._searchResults = searchResults.ToAsyncEnumerable(); } - public Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable> SearchAsync( + TInput value, + int top, + VectorSearchOptions? options = default, + CancellationToken cancellationToken = default) + where TInput : notnull { - return Task.FromResult(new VectorSearchResults(this._searchResults)); + return this._searchResults; } - private static async IAsyncEnumerable> ToAsyncEnumerable(IEnumerable> searchResults) + public IAsyncEnumerable> SearchEmbeddingAsync( + TVector vector, + int top, + VectorSearchOptions? options = default, + CancellationToken cancellationToken = default) + where TVector : notnull { - foreach (var result in searchResults) - { - yield return result; - } + return this._searchResults; + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + ArgumentNullException.ThrowIfNull(serviceType); + + return + serviceKey is null && serviceType.IsInstanceOfType(this) ? this : + null; } } diff --git a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/VectorStoreTextSearchTests.cs b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/VectorStoreTextSearchTests.cs index eee8ae4db55e..e06d4b3bf741 100644 --- a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/VectorStoreTextSearchTests.cs +++ b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/VectorStoreTextSearchTests.cs @@ -46,7 +46,7 @@ public static async Task AddVectorStoreTextSearch() }; var vectorizableTextSearch = new MockVectorizableTextSearch(testData); var serviceCollection = new ServiceCollection(); - serviceCollection.AddSingleton>(vectorizableTextSearch); + serviceCollection.AddSingleton>(vectorizableTextSearch); // Act serviceCollection.AddVectorStoreTextSearch(); diff --git a/dotnet/src/SemanticKernel.Core/CompatibilitySuppressions.xml b/dotnet/src/SemanticKernel.Core/CompatibilitySuppressions.xml new file mode 100644 index 000000000000..97293e39d3cc --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/CompatibilitySuppressions.xml @@ -0,0 +1,60 @@ + + + + + CP0002 + M:Microsoft.SemanticKernel.Data.VectorStoreTextSearch`1.#ctor(Microsoft.Extensions.VectorData.IVectorizableTextSearch{`0},Microsoft.SemanticKernel.Data.ITextSearchStringMapper,Microsoft.SemanticKernel.Data.ITextSearchResultMapper,Microsoft.SemanticKernel.Data.VectorStoreTextSearchOptions) + lib/net8.0/Microsoft.SemanticKernel.Core.dll + lib/net8.0/Microsoft.SemanticKernel.Core.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.Data.VectorStoreTextSearch`1.#ctor(Microsoft.Extensions.VectorData.IVectorizableTextSearch{`0},Microsoft.SemanticKernel.Data.MapFromResultToString,Microsoft.SemanticKernel.Data.MapFromResultToTextSearchResult,Microsoft.SemanticKernel.Data.VectorStoreTextSearchOptions) + lib/net8.0/Microsoft.SemanticKernel.Core.dll + lib/net8.0/Microsoft.SemanticKernel.Core.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.Data.VectorStoreTextSearch`1.#ctor(Microsoft.Extensions.VectorData.IVectorizedSearch{`0},Microsoft.SemanticKernel.Embeddings.ITextEmbeddingGenerationService,Microsoft.SemanticKernel.Data.ITextSearchStringMapper,Microsoft.SemanticKernel.Data.ITextSearchResultMapper,Microsoft.SemanticKernel.Data.VectorStoreTextSearchOptions) + lib/net8.0/Microsoft.SemanticKernel.Core.dll + lib/net8.0/Microsoft.SemanticKernel.Core.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.Data.VectorStoreTextSearch`1.#ctor(Microsoft.Extensions.VectorData.IVectorizedSearch{`0},Microsoft.SemanticKernel.Embeddings.ITextEmbeddingGenerationService,Microsoft.SemanticKernel.Data.MapFromResultToString,Microsoft.SemanticKernel.Data.MapFromResultToTextSearchResult,Microsoft.SemanticKernel.Data.VectorStoreTextSearchOptions) + lib/net8.0/Microsoft.SemanticKernel.Core.dll + lib/net8.0/Microsoft.SemanticKernel.Core.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.Data.VectorStoreTextSearch`1.#ctor(Microsoft.Extensions.VectorData.IVectorizableTextSearch{`0},Microsoft.SemanticKernel.Data.ITextSearchStringMapper,Microsoft.SemanticKernel.Data.ITextSearchResultMapper,Microsoft.SemanticKernel.Data.VectorStoreTextSearchOptions) + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.Data.VectorStoreTextSearch`1.#ctor(Microsoft.Extensions.VectorData.IVectorizableTextSearch{`0},Microsoft.SemanticKernel.Data.MapFromResultToString,Microsoft.SemanticKernel.Data.MapFromResultToTextSearchResult,Microsoft.SemanticKernel.Data.VectorStoreTextSearchOptions) + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.Data.VectorStoreTextSearch`1.#ctor(Microsoft.Extensions.VectorData.IVectorizedSearch{`0},Microsoft.SemanticKernel.Embeddings.ITextEmbeddingGenerationService,Microsoft.SemanticKernel.Data.ITextSearchStringMapper,Microsoft.SemanticKernel.Data.ITextSearchResultMapper,Microsoft.SemanticKernel.Data.VectorStoreTextSearchOptions) + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.Data.VectorStoreTextSearch`1.#ctor(Microsoft.Extensions.VectorData.IVectorizedSearch{`0},Microsoft.SemanticKernel.Embeddings.ITextEmbeddingGenerationService,Microsoft.SemanticKernel.Data.MapFromResultToString,Microsoft.SemanticKernel.Data.MapFromResultToTextSearchResult,Microsoft.SemanticKernel.Data.VectorStoreTextSearchOptions) + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + true + + \ No newline at end of file diff --git a/dotnet/src/SemanticKernel.Core/Data/TextSearch/TextSearchServiceCollectionExtensions.cs b/dotnet/src/SemanticKernel.Core/Data/TextSearch/TextSearchServiceCollectionExtensions.cs index c36c50bafa10..4f163b69c323 100644 --- a/dotnet/src/SemanticKernel.Core/Data/TextSearch/TextSearchServiceCollectionExtensions.cs +++ b/dotnet/src/SemanticKernel.Core/Data/TextSearch/TextSearchServiceCollectionExtensions.cs @@ -39,29 +39,11 @@ public static class TextSearchServiceCollectionExtensions resultMapper ??= sp.GetService(); options ??= sp.GetService(); - var vectorizableTextSearch = sp.GetService>(); - if (vectorizableTextSearch is not null) - { - return new VectorStoreTextSearch( - vectorizableTextSearch, - stringMapper, - resultMapper, - options); - } - - var vectorizedSearch = sp.GetService>(); - var generationService = sp.GetService(); - if (vectorizedSearch is not null && generationService is not null) - { - return new VectorStoreTextSearch( - vectorizedSearch, - generationService, - stringMapper, - resultMapper, - options); - } + var vectorSearch = sp.GetService>(); - throw new InvalidOperationException("No IVectorizableTextSearch or IVectorizedSearch and ITextEmbeddingGenerationService registered."); + return vectorSearch is null + ? throw new InvalidOperationException("No IVectorSearch registered.") + : new VectorStoreTextSearch(vectorSearch, stringMapper, resultMapper, options); }); return services; @@ -71,14 +53,14 @@ public static class TextSearchServiceCollectionExtensions /// Register a instance with the specified service ID. /// /// The to register the on. - /// Service id of the to use. + /// Service id of the to use. /// instance that can map a TRecord to a /// instance that can map a TRecord to a /// Options used to construct an instance of /// An optional service id to use as the service key. public static IServiceCollection AddVectorStoreTextSearch<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] TRecord>( this IServiceCollection services, - string vectorizableTextSearchServiceId, + string vectorSearchServiceId, ITextSearchStringMapper? stringMapper = null, ITextSearchResultMapper? resultMapper = null, VectorStoreTextSearchOptions? options = null, @@ -95,17 +77,17 @@ public static class TextSearchServiceCollectionExtensions resultMapper ??= sp.GetService(); options ??= sp.GetService(); - var vectorizableTextSearch = sp.GetKeyedService>(vectorizableTextSearchServiceId); - if (vectorizableTextSearch is not null) + var vectorSearch = sp.GetKeyedService>(vectorSearchServiceId); + if (vectorSearch is not null) { return new VectorStoreTextSearch( - vectorizableTextSearch, + vectorSearch, stringMapper, resultMapper, options); } - throw new InvalidOperationException($"No IVectorizableTextSearch for service id {vectorizableTextSearchServiceId} registered."); + throw new InvalidOperationException($"No IVectorSearch for service id {vectorSearchServiceId} registered."); }); return services; @@ -115,15 +97,16 @@ public static class TextSearchServiceCollectionExtensions /// Register a instance with the specified service ID. /// /// The to register the on. - /// Service id of the to use. + /// Service id of the to use. /// Service id of the to use. /// instance that can map a TRecord to a /// instance that can map a TRecord to a /// Options used to construct an instance of /// An optional service id to use as the service key. + [Obsolete("Use the overload which doesn't accept a textEmbeddingGenerationServiceId, and configure an IEmbeddingGenerator instead with the collection represented by vectorSearchServiceId.")] public static IServiceCollection AddVectorStoreTextSearch<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] TRecord>( this IServiceCollection services, - string vectorizedSearchServiceId, + string vectorSearchServiceId, string textEmbeddingGenerationServiceId, ITextSearchStringMapper? stringMapper = null, ITextSearchResultMapper? resultMapper = null, @@ -141,10 +124,10 @@ public static class TextSearchServiceCollectionExtensions resultMapper ??= sp.GetService(); options ??= sp.GetService(); - var vectorizedSearch = sp.GetKeyedService>(vectorizedSearchServiceId); + var vectorizedSearch = sp.GetKeyedService>(vectorSearchServiceId); if (vectorizedSearch is null) { - throw new InvalidOperationException($"No IVectorizedSearch for service id {vectorizedSearchServiceId} registered."); + throw new InvalidOperationException($"No IVectorizedSearch for service id {vectorSearchServiceId} registered."); } var generationService = sp.GetKeyedService(textEmbeddingGenerationServiceId); diff --git a/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs b/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs index 68ae09c883d5..6d06c880462a 100644 --- a/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs +++ b/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs @@ -20,22 +20,23 @@ public sealed class VectorStoreTextSearch<[DynamicallyAccessedMembers(Dynamicall { /// /// Create an instance of the with the - /// provided for performing searches and + /// provided for performing searches and /// for generating vectors from the text search query. /// - /// instance used to perform the search. + /// instance used to perform the search. /// instance used to create a vector from the text query. /// instance that can map a TRecord to a /// instance that can map a TRecord to a /// Options used to construct an instance of + [Obsolete("Use the constructor without an ITextEmbeddingGenerationService and pass a vectorSearch configured to perform embedding generation with IEmbeddingGenerator")] public VectorStoreTextSearch( - IVectorizedSearch vectorizedSearch, + IVectorSearch vectorSearch, ITextEmbeddingGenerationService textEmbeddingGeneration, MapFromResultToString stringMapper, MapFromResultToTextSearchResult resultMapper, VectorStoreTextSearchOptions? options = null) : this( - vectorizedSearch, + vectorSearch, textEmbeddingGeneration, stringMapper is null ? null : new TextSearchStringMapper(stringMapper), resultMapper is null ? null : new TextSearchResultMapper(resultMapper), @@ -45,25 +46,26 @@ public VectorStoreTextSearch( /// /// Create an instance of the with the - /// provided for performing searches and + /// provided for performing searches and /// for generating vectors from the text search query. /// - /// instance used to perform the search. + /// instance used to perform the search. /// instance used to create a vector from the text query. /// instance that can map a TRecord to a /// instance that can map a TRecord to a /// Options used to construct an instance of + [Obsolete("Use the constructor without an ITextEmbeddingGenerationService and pass a vectorSearch configured to perform embedding generation with IEmbeddingGenerator")] public VectorStoreTextSearch( - IVectorizedSearch vectorizedSearch, + IVectorSearch vectorSearch, ITextEmbeddingGenerationService textEmbeddingGeneration, ITextSearchStringMapper? stringMapper = null, ITextSearchResultMapper? resultMapper = null, VectorStoreTextSearchOptions? options = null) { - Verify.NotNull(vectorizedSearch); + Verify.NotNull(vectorSearch); Verify.NotNull(textEmbeddingGeneration); - this._vectorizedSearch = vectorizedSearch; + this._vectorSearch = vectorSearch; this._textEmbeddingGeneration = textEmbeddingGeneration; this._propertyReader = new Lazy(() => new TextSearchResultPropertyReader(typeof(TRecord))); this._stringMapper = stringMapper ?? this.CreateTextSearchStringMapper(); @@ -75,17 +77,17 @@ public VectorStoreTextSearch( /// provided for performing searches and /// for generating vectors from the text search query. /// - /// instance used to perform the text search. + /// instance used to perform the text search. /// instance that can map a TRecord to a /// instance that can map a TRecord to a /// Options used to construct an instance of public VectorStoreTextSearch( - IVectorizableTextSearch vectorizableTextSearch, + IVectorSearch vectorSearch, MapFromResultToString stringMapper, MapFromResultToTextSearchResult resultMapper, VectorStoreTextSearchOptions? options = null) : this( - vectorizableTextSearch, + vectorSearch, new TextSearchStringMapper(stringMapper), new TextSearchResultMapper(resultMapper), options) @@ -97,52 +99,51 @@ public VectorStoreTextSearch( /// provided for performing searches and /// for generating vectors from the text search query. /// - /// instance used to perform the text search. + /// instance used to perform the text search. /// instance that can map a TRecord to a /// instance that can map a TRecord to a /// Options used to construct an instance of public VectorStoreTextSearch( - IVectorizableTextSearch vectorizableTextSearch, + IVectorSearch vectorSearch, ITextSearchStringMapper? stringMapper = null, ITextSearchResultMapper? resultMapper = null, VectorStoreTextSearchOptions? options = null) { - Verify.NotNull(vectorizableTextSearch); + Verify.NotNull(vectorSearch); - this._vectorizableTextSearch = vectorizableTextSearch; + this._vectorSearch = vectorSearch; this._propertyReader = new Lazy(() => new TextSearchResultPropertyReader(typeof(TRecord))); this._stringMapper = stringMapper ?? this.CreateTextSearchStringMapper(); this._resultMapper = resultMapper ?? this.CreateTextSearchResultMapper(); } /// - public async Task> SearchAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + public Task> SearchAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) { - VectorSearchResults searchResponse = await this.ExecuteVectorSearchAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); + var searchResponse = this.ExecuteVectorSearchAsync(query, searchOptions, cancellationToken); - return new KernelSearchResults(this.GetResultsAsStringAsync(searchResponse.Results, cancellationToken), searchResponse.TotalCount, searchResponse.Metadata); + return Task.FromResult(new KernelSearchResults(this.GetResultsAsStringAsync(searchResponse, cancellationToken))); } /// - public async Task> GetTextSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + public Task> GetTextSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) { - VectorSearchResults searchResponse = await this.ExecuteVectorSearchAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); + var searchResponse = this.ExecuteVectorSearchAsync(query, searchOptions, cancellationToken); - return new KernelSearchResults(this.GetResultsAsTextSearchResultAsync(searchResponse.Results, cancellationToken), searchResponse.TotalCount, searchResponse.Metadata); + return Task.FromResult(new KernelSearchResults(this.GetResultsAsTextSearchResultAsync(searchResponse, cancellationToken))); } /// - public async Task> GetSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + public Task> GetSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) { - VectorSearchResults searchResponse = await this.ExecuteVectorSearchAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); + var searchResponse = this.ExecuteVectorSearchAsync(query, searchOptions, cancellationToken); - return new KernelSearchResults(this.GetResultsAsRecordAsync(searchResponse.Results, cancellationToken), searchResponse.TotalCount, searchResponse.Metadata); + return Task.FromResult(new KernelSearchResults(this.GetResultsAsRecordAsync(searchResponse, cancellationToken))); } #region private - private readonly IVectorizedSearch? _vectorizedSearch; private readonly ITextEmbeddingGenerationService? _textEmbeddingGeneration; - private readonly IVectorizableTextSearch? _vectorizableTextSearch; + private readonly IVectorSearch? _vectorSearch; private readonly ITextSearchStringMapper _stringMapper; private readonly ITextSearchResultMapper _resultMapper; private readonly Lazy _propertyReader; @@ -194,7 +195,7 @@ private TextSearchStringMapper CreateTextSearchStringMapper() /// What to search for. /// Search options. /// The to monitor for cancellation requests. The default is . - private async Task> ExecuteVectorSearchAsync(string query, TextSearchOptions? searchOptions, CancellationToken cancellationToken) + private async IAsyncEnumerable> ExecuteVectorSearchAsync(string query, TextSearchOptions? searchOptions, [EnumeratorCancellation] CancellationToken cancellationToken) { searchOptions ??= new TextSearchOptions(); var vectorSearchOptions = new VectorSearchOptions @@ -203,17 +204,24 @@ private async Task> ExecuteVectorSearchAsync(string OldFilter = searchOptions.Filter?.FilterClauses is not null ? new VectorSearchFilter(searchOptions.Filter.FilterClauses) : null, #pragma warning restore CS0618 // VectorSearchFilter is obsolete Skip = searchOptions.Skip, - Top = searchOptions.Top, }; - if (this._vectorizedSearch is not null) + if (this._textEmbeddingGeneration is not null) { var vectorizedQuery = await this._textEmbeddingGeneration!.GenerateEmbeddingAsync(query, cancellationToken: cancellationToken).ConfigureAwait(false); - return await this._vectorizedSearch.VectorizedSearchAsync(vectorizedQuery, vectorSearchOptions, cancellationToken).ConfigureAwait(false); + await foreach (var result in this._vectorSearch!.SearchEmbeddingAsync(vectorizedQuery, searchOptions.Top, vectorSearchOptions, cancellationToken).ConfigureAwait(false)) + { + yield return result; + } + + yield break; } - return await this._vectorizableTextSearch!.VectorizableTextSearchAsync(query, vectorSearchOptions, cancellationToken).ConfigureAwait(false); + await foreach (var result in this._vectorSearch!.SearchAsync(query, searchOptions.Top, vectorSearchOptions, cancellationToken).ConfigureAwait(false)) + { + yield return result; + } } /// diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromMethod.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromMethod.cs index 49991ad39c5a..2c212df12ef8 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromMethod.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromMethod.cs @@ -467,7 +467,7 @@ private KernelFunctionFromMethod( ReadOnlyDictionary? additionalMetadata = null) : base(functionName, pluginName, description, parameters, returnParameter, additionalMetadata: additionalMetadata) { - Verify.ValidFunctionName(functionName); + KernelVerify.ValidFunctionName(functionName); this._function = implementationFunc; this.UnderlyingMethod = method; @@ -485,7 +485,7 @@ private KernelFunctionFromMethod( ReadOnlyDictionary? additionalMetadata = null) : base(functionName, pluginName, description, parameters, jsonSerializerOptions, returnParameter, additionalMetadata: additionalMetadata) { - Verify.ValidFunctionName(functionName); + KernelVerify.ValidFunctionName(functionName); this._function = implementationFunc; this.UnderlyingMethod = method; @@ -525,7 +525,7 @@ private static MethodDetails GetMethodDetails(string? functionName, MethodInfo m } } - Verify.ValidFunctionName(functionName); + KernelVerify.ValidFunctionName(functionName); // Build up a list of KernelParameterMetadata for the parameters we expect to be populated // from arguments. Some arguments are populated specially, not from arguments, and thus @@ -546,7 +546,7 @@ private static MethodDetails GetMethodDetails(string? functionName, MethodInfo m } // Check for param names conflict - Verify.ParametersUniqueness(argParameterViews); + KernelVerify.ParametersUniqueness(argParameterViews); // Get the return type and a marshaling func for the return value. (Type returnType, Func> returnFunc) = GetReturnValueMarshalerDelegate(method); diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelPluginFactory.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelPluginFactory.cs index c1d9180479b6..23f8ea8ba7c0 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelPluginFactory.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelPluginFactory.cs @@ -235,7 +235,7 @@ static void AppendWithoutArity(StringBuilder builder, string name) Verify.NotNull(target); pluginName ??= CreatePluginName(target.GetType()); - Verify.ValidPluginName(pluginName); + KernelVerify.ValidPluginName(pluginName); MethodInfo[] methods = target.GetType().GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static); diff --git a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj index 2a5d5d03d961..094d79e6052f 100644 --- a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj +++ b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj @@ -48,8 +48,6 @@ - - diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/TextSearchServiceCollectionExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/TextSearchServiceCollectionExtensionsTests.cs index c890bd0f29f6..04e51b6bbb45 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Data/TextSearchServiceCollectionExtensionsTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Data/TextSearchServiceCollectionExtensionsTests.cs @@ -10,45 +10,23 @@ using Xunit; namespace SemanticKernel.UnitTests.Data; + public class TextSearchServiceCollectionExtensionsTests : VectorStoreTextSearchTestBase { [Fact] - public void AddVectorStoreTextSearchWithIVectorizableTextSearch() + public void AddVectorStoreTextSearch() { // Arrange - var services = new ServiceCollection(); - var vectorStore = new InMemoryVectorStore(); - var vectorSearch = vectorStore.GetCollection("records"); - var stringMapper = new DataModelTextSearchStringMapper(); - var resultMapper = new DataModelTextSearchResultMapper(); - var vectorizableTextSearch = new VectorizedSearchWrapper(vectorSearch, new MockTextEmbeddingGenerationService()); - - // Act - services.AddSingleton>(vectorizableTextSearch); - services.AddSingleton(stringMapper); - services.AddSingleton(resultMapper); - services.AddVectorStoreTextSearch(); - - // Assert - var serviceProvider = services.BuildServiceProvider(); - var result = serviceProvider.GetRequiredService>(); - Assert.NotNull(result); - } + using var embeddingGenerator = new MockTextEmbeddingGenerator(); - [Fact] - public void AddVectorStoreTextSearchWithIVectorizedSearch() - { - // Arrange var services = new ServiceCollection(); - var vectorStore = new InMemoryVectorStore(); - var vectorSearch = vectorStore.GetCollection("records"); + var vectorStore = new InMemoryVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + var collection = vectorStore.GetCollection("records"); var stringMapper = new DataModelTextSearchStringMapper(); var resultMapper = new DataModelTextSearchResultMapper(); - var textGeneration = new MockTextEmbeddingGenerationService(); // Act - services.AddSingleton>(vectorSearch); - services.AddSingleton(textGeneration); + services.AddSingleton>(collection); services.AddSingleton(stringMapper); services.AddSingleton(resultMapper); services.AddVectorStoreTextSearch(); @@ -60,36 +38,17 @@ public void AddVectorStoreTextSearchWithIVectorizedSearch() } [Fact] - public void AddVectorStoreTextSearchWithIVectorizableTextSearchAndNoMappers() + public void AddVectorStoreTextSearchWithNoMappers() { // Arrange - var services = new ServiceCollection(); - var vectorStore = new InMemoryVectorStore(); - var vectorSearch = vectorStore.GetCollection("records"); - var vectorizableTextSearch = new VectorizedSearchWrapper(vectorSearch, new MockTextEmbeddingGenerationService()); + using var embeddingGenerator = new MockTextEmbeddingGenerator(); - // Act - services.AddSingleton>(vectorizableTextSearch); - services.AddVectorStoreTextSearch(); - - // Assert - var serviceProvider = services.BuildServiceProvider(); - var result = serviceProvider.GetRequiredService>(); - Assert.NotNull(result); - } - - [Fact] - public void AddVectorStoreTextSearchWithIVectorizedSearchAndNoMappers() - { - // Arrange var services = new ServiceCollection(); - var vectorStore = new InMemoryVectorStore(); - var vectorSearch = vectorStore.GetCollection("records"); - var textGeneration = new MockTextEmbeddingGenerationService(); + var vectorStore = new InMemoryVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + var collection = vectorStore.GetCollection("records"); // Act - services.AddSingleton>(vectorSearch); - services.AddSingleton(textGeneration); + services.AddSingleton>(collection); services.AddVectorStoreTextSearch(); // Assert @@ -99,17 +58,18 @@ public void AddVectorStoreTextSearchWithIVectorizedSearchAndNoMappers() } [Fact] - public void AddVectorStoreTextSearchWithKeyedIVectorizableTextSearch() + public void AddVectorStoreTextSearchWithKeyedIVectorSearch() { // Arrange + using var embeddingGenerator = new MockTextEmbeddingGenerator(); + var services = new ServiceCollection(); - var vectorStore = new InMemoryVectorStore(); - var vectorSearch = vectorStore.GetCollection("records"); - var vectorizableTextSearch1 = new VectorizedSearchWrapper(vectorSearch, new MockTextEmbeddingGenerationService()); + var vectorStore = new InMemoryVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + var collection = vectorStore.GetCollection("records"); // Act - services.AddKeyedSingleton>("vts1", vectorizableTextSearch1); - services.AddVectorStoreTextSearch("vts1"); + services.AddKeyedSingleton>("vs1", collection); + services.AddVectorStoreTextSearch("vs1"); // Assert var serviceProvider = services.BuildServiceProvider(); @@ -118,62 +78,64 @@ public void AddVectorStoreTextSearchWithKeyedIVectorizableTextSearch() } [Fact] - public void AddVectorStoreTextSearchFailsMissingKeyedVectorizableTextSearch() + public void AddVectorStoreTextSearchFailsMissingKeyedIVectorSearch() { // Arrange + using var embeddingGenerator = new MockTextEmbeddingGenerator(); + var services = new ServiceCollection(); - var vectorStore = new InMemoryVectorStore(); - var vectorSearch = vectorStore.GetCollection("records"); - var vectorizableTextSearch1 = new VectorizedSearchWrapper(vectorSearch, new MockTextEmbeddingGenerationService()); + var vectorStore = new InMemoryVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + var collection = vectorStore.GetCollection("records"); // Act - services.AddKeyedSingleton>("vts1", vectorizableTextSearch1); - services.AddVectorStoreTextSearch("vts2"); + services.AddKeyedSingleton>("vs1", collection); + services.AddVectorStoreTextSearch("vs2"); // Assert var serviceProvider = services.BuildServiceProvider(); Assert.Throws(() => serviceProvider.GetRequiredService>()); } +#pragma warning disable CS0618 // Type or member is obsolete [Fact] - public void AddVectorStoreTextSearchWithKeyedIVectorizedSearch() + public void AddVectorStoreTextSearchWithKeyedVectorSearchAndEmbeddingGenerationService() { // Arrange var services = new ServiceCollection(); var vectorStore = new InMemoryVectorStore(); - var vectorSearch = vectorStore.GetCollection("records"); - var textGeneration = new MockTextEmbeddingGenerationService(); + var collection = vectorStore.GetCollection("records"); + using var generator = new MockTextEmbeddingGenerator(); // Act - services.AddKeyedSingleton>("vs1", vectorSearch); - services.AddKeyedSingleton("tegs1", textGeneration); + services.AddKeyedSingleton>("vs1", collection); + services.AddKeyedSingleton("tegs1", generator); - services.AddVectorStoreTextSearch("vs1", "tegs1"); + services.AddVectorStoreTextSearch("vs1", "tegs1"); // Assert var serviceProvider = services.BuildServiceProvider(); - var result = serviceProvider.GetRequiredService>(); + var result = serviceProvider.GetRequiredService>(); Assert.NotNull(result); } [Fact] - public void AddVectorStoreTextSearchFailsMissingKeyedVectorizedSearch() + public void AddVectorStoreTextSearchFailsMissingKeyedVectorSearch() { // Arrange var services = new ServiceCollection(); var vectorStore = new InMemoryVectorStore(); - var vectorSearch = vectorStore.GetCollection("records"); - var textGeneration = new MockTextEmbeddingGenerationService(); + var collection = vectorStore.GetCollection("records"); + using var textGeneration = new MockTextEmbeddingGenerator(); // Act - services.AddKeyedSingleton>("vs1", vectorSearch); + services.AddKeyedSingleton>("vs1", collection); services.AddKeyedSingleton("tegs1", textGeneration); - services.AddVectorStoreTextSearch("vs2", "tegs1"); + services.AddVectorStoreTextSearch("vs2", "tegs1"); // Assert var serviceProvider = services.BuildServiceProvider(); - Assert.Throws(() => serviceProvider.GetRequiredService>()); + Assert.Throws(() => serviceProvider.GetRequiredService>()); } [Fact] @@ -182,17 +144,18 @@ public void AddVectorStoreTextSearchFailsMissingKeyedTextEmbeddingGenerationServ // Arrange var services = new ServiceCollection(); var vectorStore = new InMemoryVectorStore(); - var vectorSearch = vectorStore.GetCollection("records"); - var textGeneration = new MockTextEmbeddingGenerationService(); + var vectorSearch = vectorStore.GetCollection("records"); + using var textGeneration = new MockTextEmbeddingGenerator(); // Act - services.AddKeyedSingleton>("vs1", vectorSearch); + services.AddKeyedSingleton>("vs1", vectorSearch); services.AddKeyedSingleton("tegs1", textGeneration); - services.AddVectorStoreTextSearch("vs1", "tegs2"); + services.AddVectorStoreTextSearch("vs1", "tegs2"); // Assert var serviceProvider = services.BuildServiceProvider(); - Assert.Throws(() => serviceProvider.GetRequiredService>()); + Assert.Throws(() => serviceProvider.GetRequiredService>()); } +#pragma warning restore CS0618 // Type or member is obsolete } diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordMappingTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordMappingTests.cs index 121d7ac38d07..68e5caa11643 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordMappingTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordMappingTests.cs @@ -3,8 +3,6 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Linq; -using System.Reflection; using Microsoft.Extensions.VectorData; using Xunit; @@ -12,95 +10,6 @@ namespace SemanticKernel.UnitTests.Data; public class VectorStoreRecordMappingTests { - [Fact] - public void BuildPropertiesInfoWithValuesShouldBuildPropertiesInfo() - { - // Arrange. - var dataModelPropertiesInfo = new[] - { - typeof(DataModel).GetProperty(nameof(DataModel.Key))!, - typeof(DataModel).GetProperty(nameof(DataModel.Data))! - }; - var dataModelToStorageNameMapping = new Dictionary - { - { nameof(DataModel.Key), "key" }, - { nameof(DataModel.Data), "data" }, - }; - var storageValues = new Dictionary - { - { "key", "key value" }, - { "data", "data value" }, - }; - - // Act. - var propertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( - dataModelPropertiesInfo, - dataModelToStorageNameMapping, - storageValues); - - // Assert. - var propertiesInfoWithValuesArray = propertiesInfoWithValues.ToArray(); - Assert.Equal(2, propertiesInfoWithValuesArray.Length); - Assert.Equal(dataModelPropertiesInfo[0], propertiesInfoWithValuesArray[0].Key); - Assert.Equal("key value", propertiesInfoWithValuesArray[0].Value); - Assert.Equal(dataModelPropertiesInfo[1], propertiesInfoWithValuesArray[1].Key); - Assert.Equal("data value", propertiesInfoWithValuesArray[1].Value); - } - - [Fact] - public void BuildPropertiesInfoWithValuesShouldUseValueMapperIfProvided() - { - // Arrange. - var dataModelPropertiesInfo = new[] - { - typeof(DataModel).GetProperty(nameof(DataModel.Key))!, - typeof(DataModel).GetProperty(nameof(DataModel.Data))! - }; - var dataModelToStorageNameMapping = new Dictionary - { - { nameof(DataModel.Key), "key" }, - { nameof(DataModel.Data), "data" }, - }; - var storageValues = new Dictionary - { - { "key", 10 }, - { "data", 20 }, - }; - - // Act. - var propertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( - dataModelPropertiesInfo, - dataModelToStorageNameMapping, - storageValues, - (int value, Type type) => value.ToString()); - - // Assert. - var propertiesInfoWithValuesArray = propertiesInfoWithValues.ToArray(); - Assert.Equal(2, propertiesInfoWithValuesArray.Length); - Assert.Equal(dataModelPropertiesInfo[0], propertiesInfoWithValuesArray[0].Key); - Assert.Equal("10", propertiesInfoWithValuesArray[0].Value); - Assert.Equal(dataModelPropertiesInfo[1], propertiesInfoWithValuesArray[1].Key); - Assert.Equal("20", propertiesInfoWithValuesArray[1].Value); - } - - [Fact] - public void SetPropertiesOnRecordShouldSetProperties() - { - // Arrange. - var record = new DataModel(); - - // Act. - VectorStoreRecordMapping.SetPropertiesOnRecord(record, new[] - { - new KeyValuePair(typeof(DataModel).GetProperty(nameof(DataModel.Key))!, "key value"), - new KeyValuePair(typeof(DataModel).GetProperty(nameof(DataModel.Data))!, "data value"), - }); - - // Assert. - Assert.Equal("key value", record.Key); - Assert.Equal("data value", record.Data); - } - [Theory] [InlineData(typeof(List))] [InlineData(typeof(ICollection))] @@ -160,10 +69,4 @@ public void CreateEnumerableThrowsForUnsupportedType(Type expectedType) // Act & Assert. Assert.Throws(() => VectorStoreRecordMapping.CreateEnumerable(input, expectedType)); } - - private sealed class DataModel - { - public string Key { get; set; } = string.Empty; - public string Data { get; set; } = string.Empty; - } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordPropertyReaderTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordPropertyReaderTests.cs deleted file mode 100644 index bbaabdd3d844..000000000000 --- a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordPropertyReaderTests.cs +++ /dev/null @@ -1,814 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Text.Json; -using System.Text.Json.Serialization; -using Microsoft.Extensions.VectorData; -using Xunit; - -namespace SemanticKernel.UnitTests.Data; - -public class VectorStoreRecordPropertyReaderTests -{ - [Theory] - [MemberData(nameof(NoKeyTypeAndDefinitionCombos))] - public void ConstructorFailsForNoKey(Type type, VectorStoreRecordDefinition? definition) - { - // Act & Assert. - var exception = Assert.Throws(() => new VectorStoreRecordPropertyReader(type, definition, null)); - Assert.Equal("No key property found on type NoKeyModel or the provided VectorStoreRecordDefinition.", exception.Message); - } - - [Theory] - [MemberData(nameof(MultiKeysTypeAndDefinitionCombos))] - public void ConstructorSucceedsForSupportedMultiKeys(Type type, VectorStoreRecordDefinition? definition) - { - // Act & Assert. - var sut = new VectorStoreRecordPropertyReader(type, definition, new VectorStoreRecordPropertyReaderOptions { SupportsMultipleKeys = true }); - } - - [Theory] - [MemberData(nameof(MultiKeysTypeAndDefinitionCombos))] - public void ConstructorFailsForUnsupportedMultiKeys(Type type, VectorStoreRecordDefinition? definition) - { - // Act & Assert. - var exception = Assert.Throws(() => new VectorStoreRecordPropertyReader(type, definition, new VectorStoreRecordPropertyReaderOptions { SupportsMultipleKeys = false })); - Assert.Equal("Multiple key properties found on type MultiKeysModel or the provided VectorStoreRecordDefinition.", exception.Message); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void ConstructorSucceedsForSupportedMultiVectors(Type type, VectorStoreRecordDefinition? definition) - { - // Act & Assert. - var sut = new VectorStoreRecordPropertyReader(type, definition, new VectorStoreRecordPropertyReaderOptions { SupportsMultipleVectors = true }); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void ConstructorFailsForUnsupportedMultiVectors(Type type, VectorStoreRecordDefinition? definition) - { - // Act & Assert. - var exception = Assert.Throws(() => new VectorStoreRecordPropertyReader(type, definition, new VectorStoreRecordPropertyReaderOptions { SupportsMultipleVectors = false })); - Assert.Equal("Multiple vector properties found on type MultiPropsModel or the provided VectorStoreRecordDefinition while only one is supported.", exception.Message); - } - - [Theory] - [MemberData(nameof(NoVectorsTypeAndDefinitionCombos))] - public void ConstructorFailsForUnsupportedNoVectors(Type type, VectorStoreRecordDefinition? definition) - { - // Act & Assert. - var exception = Assert.Throws(() => new VectorStoreRecordPropertyReader(type, definition, new VectorStoreRecordPropertyReaderOptions { RequiresAtLeastOneVector = true })); - Assert.Equal("No vector property found on type NoVectorModel or the provided VectorStoreRecordDefinition while at least one is required.", exception.Message); - } - - [Theory] - [MemberData(nameof(TypeAndDefinitionCombos))] - public void CanGetDefinition(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.RecordDefinition; - - // Assert. - Assert.NotNull(actual); - } - - [Theory] - [MemberData(nameof(TypeAndDefinitionCombos))] - public void CanGetKeyPropertyInfo(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.KeyPropertyInfo; - - // Assert. - Assert.NotNull(actual); - Assert.Equal("Key", actual.Name); - Assert.Equal(typeof(string), actual.PropertyType); - } - - [Theory] - [MemberData(nameof(TypeAndDefinitionCombos))] - public void CanGetKeyPropertiesInfo(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.KeyPropertiesInfo; - - // Assert. - Assert.NotNull(actual); - Assert.Single(actual); - Assert.Equal("Key", actual[0].Name); - Assert.Equal(typeof(string), actual[0].PropertyType); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void CanGetDataPropertiesInfo(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.DataPropertiesInfo; - - // Assert. - Assert.NotNull(actual); - Assert.Equal(2, actual.Count); - Assert.Equal("Data1", actual[0].Name); - Assert.Equal(typeof(string), actual[0].PropertyType); - Assert.Equal("Data2", actual[1].Name); - Assert.Equal(typeof(string), actual[1].PropertyType); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void CanGetVectorPropertiesInfo(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.VectorPropertiesInfo; - - // Assert. - Assert.NotNull(actual); - Assert.Equal(2, actual.Count); - Assert.Equal("Vector1", actual[0].Name); - Assert.Equal(typeof(ReadOnlyMemory), actual[0].PropertyType); - Assert.Equal("Vector2", actual[1].Name); - Assert.Equal(typeof(ReadOnlyMemory), actual[1].PropertyType); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void CanGetFirstVectorPropertyName(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.FirstVectorPropertyName; - - // Assert. - Assert.Equal("Vector1", actual); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void CanGetFirstVectorPropertyInfo(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.FirstVectorPropertyInfo; - - // Assert. - Assert.NotNull(actual); - Assert.Equal("Vector1", actual.Name); - Assert.Equal(typeof(ReadOnlyMemory), actual.PropertyType); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void CanGetKeyPropertyName(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.KeyPropertyName; - - // Assert. - Assert.Equal("Key", actual); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void CanGetKeyPropertyStoragePropertyNameWithoutOverride(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.KeyPropertyStoragePropertyName; - - // Assert. - Assert.Equal("Key", actual); - } - - [Theory] - [MemberData(nameof(StorageNamesPropsTypeAndDefinitionCombos))] - public void CanGetKeyPropertyStoragePropertyNameWithOverride(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.KeyPropertyStoragePropertyName; - - // Assert. - Assert.Equal("storage_key", actual); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void CanGetDataPropertyStoragePropertyNameWithOverrideMix(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.DataPropertyStoragePropertyNames; - - // Assert. - Assert.Equal("Data1", actual[0]); - Assert.Equal("storage_data2", actual[1]); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void CanGetVectorPropertyStoragePropertyNameWithOverrideMix(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.VectorPropertyStoragePropertyNames; - - // Assert. - Assert.Equal("Vector1", actual[0]); - Assert.Equal("storage_vector2", actual[1]); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void CanGetKeyPropertyJsonNameWithoutOverride(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.KeyPropertyJsonName; - - // Assert. - Assert.Equal("Key", actual); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void CanGetKeyPropertyJsonNameWithSerializerSettings(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, new() - { - JsonSerializerOptions = new JsonSerializerOptions() - { - PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseUpper - } - }); - - // Act. - var actual = sut.KeyPropertyJsonName; - - // Assert. - Assert.Equal("KEY", actual); - } - - [Theory] - [MemberData(nameof(StorageNamesPropsTypeAndDefinitionCombos))] - public void CanGetKeyPropertyJsonNameWithOverride(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.KeyPropertyJsonName; - - // Assert. - Assert.Equal("json_key", actual); - } - - [Theory] - [MemberData(nameof(StorageNamesPropsTypeAndDefinitionCombos))] - public void CanGetDataPropertyJsonNameWithOverride(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.DataPropertyJsonNames; - - // Assert. - Assert.NotNull(actual); - Assert.Equal(2, actual.Count); - Assert.Equal("json_data1", actual[0]); - Assert.Equal("json_data2", actual[1]); - } - - [Theory] - [MemberData(nameof(StorageNamesPropsTypeAndDefinitionCombos))] - public void CanGetVectorPropertyJsonNameWithOverride(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act. - var actual = sut.VectorPropertyJsonNames; - - // Assert. - Assert.NotNull(actual); - Assert.Single(actual); - Assert.Equal("json_vector", actual[0]); - } - - [Theory] - [MemberData(nameof(TypeAndDefinitionCombos))] - public void VerifyKeyPropertiesPassesForAllowedTypes(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - var allowedTypes = new HashSet { typeof(string), typeof(int) }; - - // Act. - sut.VerifyKeyProperties(allowedTypes); - } - - [Theory] - [MemberData(nameof(TypeAndDefinitionCombos))] - public void VerifyKeyPropertiesFailsForDisallowedTypes(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - var allowedTypes = new HashSet { typeof(long) }; - - // Act. - var exception = Assert.Throws(() => sut.VerifyKeyProperties(allowedTypes)); - Assert.Equal("Key properties must be one of the supported types: System.Int64. Type of the property 'Key' is System.String.", exception.Message); - } - - [Theory] - [MemberData(nameof(EnumerablePropsTypeAndDefinitionCombos))] - public void VerifyDataPropertiesPassesForAllowedEnumerableTypes(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - var allowedTypes = new HashSet { typeof(string), typeof(int) }; - - // Act. - sut.VerifyDataProperties(allowedTypes, true); - } - - [Theory] - [MemberData(nameof(EnumerablePropsTypeAndDefinitionCombos))] - public void VerifyDataPropertiesFailsForDisallowedEnumerableTypes(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - var allowedTypes = new HashSet { typeof(string), typeof(int) }; - - // Act. - var exception = Assert.Throws(() => sut.VerifyDataProperties(allowedTypes, false)); - Assert.Equal("Data properties must be one of the supported types: System.String, System.Int32. Type of the property 'EnumerableData' is System.Collections.Generic.IEnumerable`1[[System.String, System.Private.CoreLib, Version=8.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e]].", exception.Message); - } - - [Theory] - [MemberData(nameof(EnumerablePropsTypeAndDefinitionCombos))] - public void VerifyVectorPropertiesPassesForAllowedEnumerableTypes(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - var allowedTypes = new HashSet { typeof(ReadOnlyMemory) }; - - // Act. - sut.VerifyVectorProperties(allowedTypes); - } - - [Theory] - [MemberData(nameof(EnumerablePropsTypeAndDefinitionCombos))] - public void VerifyVectorPropertiesFailsForDisallowedEnumerableTypes(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - var allowedTypes = new HashSet { typeof(ReadOnlyMemory) }; - - // Act. - var exception = Assert.Throws(() => sut.VerifyVectorProperties(allowedTypes)); - Assert.Equal("Vector properties must be one of the supported types: System.ReadOnlyMemory`1[[System.Double, System.Private.CoreLib, Version=8.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e]]. Type of the property 'Vector' is System.ReadOnlyMemory`1[[System.Single, System.Private.CoreLib, Version=8.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e]].", exception.Message); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void GetStoragePropertyNameReturnsStorageNameWithFallback(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act & Assert. - Assert.Equal("Data1", sut.GetStoragePropertyName("Data1")); - Assert.Equal("storage_data2", sut.GetStoragePropertyName("Data2")); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void GetJsonPropertyNameReturnsJsonWithFallback(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act & Assert. - Assert.Equal("Data1", sut.GetJsonPropertyName("Data1")); - Assert.Equal("json_data2", sut.GetJsonPropertyName("Data2")); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void GetVectorPropertyOrSingleReturnsRequestedVectorAndThrowsForInvalidVector(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - var validVector = new VectorSearchOptions() { VectorProperty = r => r.Vector2 }; - var invalidVector = new VectorSearchOptions() { VectorProperty = r => r.Data2 }; - - // Act & Assert. - Assert.Equal("Vector2", sut.GetVectorPropertyOrSingle(validVector).DataModelPropertyName); - Assert.Throws(() => sut.GetVectorPropertyOrSingle(invalidVector)); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void GetVectorPropertyOrSingleThrowsForMultipleVectors(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act & Assert. - Assert.Throws(() => sut.GetVectorPropertyOrSingle(null)); - } - - [Theory] - [MemberData(nameof(NoVectorsTypeAndDefinitionCombos))] - public void GetVectorPropertyOrSingleThrowsForNoVectors(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act & Assert. - Assert.Throws(() => sut.GetVectorPropertyOrSingle(null)); - } - - [Fact] - public void GetVectorPropertyOrSingleReturnsRequestedGenericDataModelVectorWhenUsingConst() - { - const string TheConst = "FloatVector"; - VectorStoreRecordPropertyReader sut = CreateReaderForGenericModel(TheConst); - VectorSearchOptions> expectedConst = new() - { - VectorProperty = r => r.Vectors[TheConst] - }; - VectorSearchOptions> wrongConst = new() - { - VectorProperty = r => r.Vectors["Different"] - }; - - Assert.Equal(TheConst, sut.GetVectorPropertyOrSingle(expectedConst).DataModelPropertyName); - Assert.Throws(() => sut.GetVectorPropertyOrSingle(wrongConst)); - } - - [Fact] - public void GetVectorPropertyOrSingleReturnsRequestedGenericDataModelVectorWhenUsingVariable() - { - string theVariable = "FloatVector"; - string theWrongVariable = "Different"; - VectorStoreRecordPropertyReader sut = CreateReaderForGenericModel(theVariable); - VectorSearchOptions> expectedVariable = new() - { - VectorProperty = r => r.Vectors[theVariable] - }; - VectorSearchOptions> wrongVariable = new() - { - VectorProperty = r => r.Vectors[theWrongVariable] - }; - - Assert.Equal(theVariable, sut.GetVectorPropertyOrSingle(expectedVariable).DataModelPropertyName); - Assert.Throws(() => sut.GetVectorPropertyOrSingle(wrongVariable)); - } - - [Theory] - [InlineData("FloatVector", "Different")] - // it's a Theory just for the need of testing a method expected being captured by the lambda property selector - public void GetVectorPropertyOrSingleReturnsRequestedGenericDataModelVectorWhenUsingArgument(string expected, string wrong) - { - VectorStoreRecordPropertyReader sut = CreateReaderForGenericModel(expected); - VectorSearchOptions> expectedArgument = new() - { - VectorProperty = r => r.Vectors[expected] - }; - VectorSearchOptions> wrongArgument = new() - { - VectorProperty = r => r.Vectors[wrong] - }; - - Assert.Equal("FloatVector", sut.GetVectorPropertyOrSingle(expectedArgument).DataModelPropertyName); - Assert.Throws(() => sut.GetVectorPropertyOrSingle(wrongArgument)); - } - - private static VectorStoreRecordPropertyReader CreateReaderForGenericModel(string vectorPropertyName) - { - VectorStoreGenericDataModel genericRecord = new("key") - { - Data = - { - ["Text"] = "data" - }, - Vectors = - { - [vectorPropertyName] = new ReadOnlyMemory([-1, -1, -1, -1]) - } - }; - VectorStoreRecordDefinition definition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty("Key", typeof(Guid)), - new VectorStoreRecordDataProperty("Text", typeof(string)), - new VectorStoreRecordVectorProperty(vectorPropertyName, typeof(ReadOnlyMemory)), - ] - }; - - return new(genericRecord.GetType(), definition, null); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void GetFullTextDataPropertyOrOnlyReturnsRequestedPropOrOnlyTextDataPropAndThrowsForInvalidProp(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act & Assert. - Assert.Equal("Data1", sut.GetFullTextDataPropertyOrSingle(r => r.Data1).DataModelPropertyName); - Assert.Equal("Data1", sut.GetFullTextDataPropertyOrSingle(null).DataModelPropertyName); - Assert.Throws(() => sut.GetFullTextDataPropertyOrSingle(r => r.Vector1)); - Assert.Throws(() => sut.GetFullTextDataPropertyOrSingle(r => "DoesNotExist")); - } - - [Theory] - [MemberData(nameof(NoVectorsTypeAndDefinitionCombos))] - public void GetFullTextDataPropertyOrOnlyThrowsForNoTextDataProps(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act & Assert. - Assert.Throws(() => sut.GetFullTextDataPropertyOrSingle(null)); - } - - [Theory] - [MemberData(nameof(MultiPropsTypeAndDefinitionCombos))] - public void GetFullTextDataPropertyOrOnlyThrowsForNonFullTextSearchProp(Type type, VectorStoreRecordDefinition? definition) - { - // Arrange. - var sut = new VectorStoreRecordPropertyReader(type, definition, null); - - // Act & Assert. - Assert.Throws(() => sut.GetFullTextDataPropertyOrSingle(r => r.Data2)); - } - - [Fact] - public void GetFullTextDataPropertyOrOnlyThrowsForMultipleMatchingProps() - { - // Arrange. - var properties = new List - { - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("Data1", typeof(string)) { IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("Data2", typeof(string)) { IsFullTextSearchable = true } - }; - var definition = new VectorStoreRecordDefinition - { - Properties = properties - }; - var sut = new VectorStoreRecordPropertyReader(typeof(object), definition, null); - - // Act & Assert. - Assert.Throws(() => sut.GetFullTextDataPropertyOrSingle(null)); - } - - public static IEnumerable NoKeyTypeAndDefinitionCombos() - { - yield return new object?[] { typeof(NoKeyModel), s_noKeyDefinition }; - yield return new object?[] { typeof(NoKeyModel), null }; - } - - public static IEnumerable NoVectorsTypeAndDefinitionCombos() - { - yield return new object?[] { typeof(NoVectorModel), s_noVectorDefinition }; - yield return new object?[] { typeof(NoVectorModel), null }; - } - - public static IEnumerable MultiKeysTypeAndDefinitionCombos() - { - yield return new object?[] { typeof(MultiKeysModel), s_multiKeysDefinition }; - yield return new object?[] { typeof(MultiKeysModel), null }; - } - - public static IEnumerable TypeAndDefinitionCombos() - { - yield return new object?[] { typeof(SinglePropsModel), s_singlePropsDefinition }; - yield return new object?[] { typeof(SinglePropsModel), null }; - yield return new object?[] { typeof(MultiPropsModel), s_multiPropsDefinition }; - yield return new object?[] { typeof(MultiPropsModel), null }; - yield return new object?[] { typeof(EnumerablePropsModel), s_enumerablePropsDefinition }; - yield return new object?[] { typeof(EnumerablePropsModel), null }; - } - - public static IEnumerable MultiPropsTypeAndDefinitionCombos() - { - yield return new object?[] { typeof(MultiPropsModel), s_multiPropsDefinition }; - yield return new object?[] { typeof(MultiPropsModel), null }; - } - - public static IEnumerable StorageNamesPropsTypeAndDefinitionCombos() - { - yield return new object?[] { typeof(StorageNamesPropsModel), s_storageNamesPropsDefinition }; - yield return new object?[] { typeof(StorageNamesPropsModel), null }; - } - - public static IEnumerable EnumerablePropsTypeAndDefinitionCombos() - { - yield return new object?[] { typeof(EnumerablePropsModel), s_enumerablePropsDefinition }; - yield return new object?[] { typeof(EnumerablePropsModel), null }; - } - -#pragma warning disable CA1812 // Invalid unused classes error, since I am using these for testing purposes above. - - private sealed class NoKeyModel - { - } - - private static readonly VectorStoreRecordDefinition s_noKeyDefinition = new(); - - private sealed class NoVectorModel - { - [VectorStoreRecordKey] - public string Key { get; set; } = string.Empty; - } - - private static readonly VectorStoreRecordDefinition s_noVectorDefinition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty("Key", typeof(string)) - ] - }; - - private sealed class MultiKeysModel - { - [VectorStoreRecordKey] - public string Key1 { get; set; } = string.Empty; - - [VectorStoreRecordKey] - public string Key2 { get; set; } = string.Empty; - } - - private static readonly VectorStoreRecordDefinition s_multiKeysDefinition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty("Key1", typeof(string)), - new VectorStoreRecordKeyProperty("Key2", typeof(string)) - ] - }; - - private sealed class SinglePropsModel - { - [VectorStoreRecordKey] - public string Key { get; set; } = string.Empty; - - [VectorStoreRecordData] - public string Data { get; set; } = string.Empty; - - [VectorStoreRecordVector] - public ReadOnlyMemory Vector { get; set; } - - public string NotAnnotated { get; set; } = string.Empty; - } - - private static readonly VectorStoreRecordDefinition s_singlePropsDefinition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("Data", typeof(string)), - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) - ] - }; - - private sealed class MultiPropsModel - { - [VectorStoreRecordKey] - public string Key { get; set; } = string.Empty; - - [VectorStoreRecordData(IsFilterable = true, IsFullTextSearchable = true)] - public string Data1 { get; set; } = string.Empty; - - [VectorStoreRecordData(StoragePropertyName = "storage_data2")] - [JsonPropertyName("json_data2")] - public string Data2 { get; set; } = string.Empty; - - [VectorStoreRecordVector(4, DistanceFunction.DotProductSimilarity, IndexKind.Flat)] - public ReadOnlyMemory Vector1 { get; set; } - - [VectorStoreRecordVector(StoragePropertyName = "storage_vector2")] - [JsonPropertyName("json_vector2")] - public ReadOnlyMemory Vector2 { get; set; } - - public string NotAnnotated { get; set; } = string.Empty; - } - - private static readonly VectorStoreRecordDefinition s_multiPropsDefinition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("Data1", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("Data2", typeof(string)) { StoragePropertyName = "storage_data2" }, - new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory)) { Dimensions = 4, IndexKind = IndexKind.Flat, DistanceFunction = DistanceFunction.DotProductSimilarity }, - new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory)) { StoragePropertyName = "storage_vector2" } - ] - }; - - private sealed class EnumerablePropsModel - { - [VectorStoreRecordKey] - public string Key { get; set; } = string.Empty; - - [VectorStoreRecordData] - public IEnumerable EnumerableData { get; set; } = new List(); - - [VectorStoreRecordData] - public string[] ArrayData { get; set; } = Array.Empty(); - - [VectorStoreRecordData] - public List ListData { get; set; } = new List(); - - [VectorStoreRecordVector] - public ReadOnlyMemory Vector { get; set; } - - public string NotAnnotated { get; set; } = string.Empty; - } - - private static readonly VectorStoreRecordDefinition s_enumerablePropsDefinition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("EnumerableData", typeof(IEnumerable)), - new VectorStoreRecordDataProperty("ArrayData", typeof(string[])), - new VectorStoreRecordDataProperty("ListData", typeof(List)), - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) - ] - }; - - private sealed class StorageNamesPropsModel - { - [VectorStoreRecordKey(StoragePropertyName = "storage_key")] - [JsonPropertyName("json_key")] - public string Key { get; set; } = string.Empty; - - [VectorStoreRecordData(StoragePropertyName = "storage_data1")] - [JsonPropertyName("json_data1")] - public string Data1 { get; set; } = string.Empty; - - [VectorStoreRecordData(StoragePropertyName = "storage_data2")] - [JsonPropertyName("json_data2")] - public string Data2 { get; set; } = string.Empty; - - [VectorStoreRecordVector(StoragePropertyName = "storage_vector")] - [JsonPropertyName("json_vector")] - public ReadOnlyMemory Vector { get; set; } - } - - private static readonly VectorStoreRecordDefinition s_storageNamesPropsDefinition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty("Key", typeof(string)) { StoragePropertyName = "storage_key" }, - new VectorStoreRecordDataProperty("Data1", typeof(string)) { StoragePropertyName = "storage_data1" }, - new VectorStoreRecordDataProperty("Data2", typeof(string)) { StoragePropertyName = "storage_data2" }, - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { StoragePropertyName = "storage_vector" } - ] - }; - -#pragma warning restore CA1812 -} diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordPropertyVerificationTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordPropertyVerificationTests.cs deleted file mode 100644 index 9e18965d8015..000000000000 --- a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordPropertyVerificationTests.cs +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections; -using System.Collections.Generic; -using System.Linq; -using Microsoft.Extensions.VectorData; -using Xunit; - -namespace SemanticKernel.UnitTests.Data; - -public class VectorStoreRecordPropertyVerificationTests -{ - [Fact] - public void VerifyPropertyTypesPassForAllowedTypes() - { - // Arrange. - var reader = new VectorStoreRecordPropertyReader(typeof(SinglePropsModel), null, null); - - // Act. - VectorStoreRecordPropertyVerification.VerifyPropertyTypes(reader.DataProperties, [typeof(string)], "Data"); - VectorStoreRecordPropertyVerification.VerifyPropertyTypes(this._singlePropsDefinition.Properties.OfType(), [typeof(string)], "Data"); - } - - [Fact] - public void VerifyPropertyTypesPassForAllowedEnumerableTypes() - { - // Arrange. - var reader = new VectorStoreRecordPropertyReader(typeof(EnumerablePropsModel), null, null); - - // Act. - VectorStoreRecordPropertyVerification.VerifyPropertyTypes(reader.DataProperties, [typeof(string)], "Data", supportEnumerable: true); - VectorStoreRecordPropertyVerification.VerifyPropertyTypes(this._enumerablePropsDefinition.Properties.OfType(), [typeof(string)], "Data", supportEnumerable: true); - } - - [Fact] - public void VerifyPropertyTypesFailsForDisallowedTypes() - { - // Arrange. - var reader = new VectorStoreRecordPropertyReader(typeof(SinglePropsModel), null, null); - - // Act. - var ex1 = Assert.Throws(() => VectorStoreRecordPropertyVerification.VerifyPropertyTypes(reader.DataProperties, [typeof(int), typeof(float)], "Data")); - var ex2 = Assert.Throws(() => VectorStoreRecordPropertyVerification.VerifyPropertyTypes(this._singlePropsDefinition.Properties.OfType(), [typeof(int), typeof(float)], "Data")); - - // Assert. - Assert.Equal("Data properties must be one of the supported types: System.Int32, System.Single. Type of the property 'Data' is System.String.", ex1.Message); - Assert.Equal("Data properties must be one of the supported types: System.Int32, System.Single. Type of the property 'Data' is System.String.", ex2.Message); - } - - [Theory] - [InlineData(typeof(SinglePropsModel), false, new Type[] { typeof(string) }, false)] - [InlineData(typeof(VectorStoreGenericDataModel), false, new Type[] { typeof(string), typeof(ulong) }, false)] - [InlineData(typeof(VectorStoreGenericDataModel), true, new Type[] { typeof(string), typeof(ulong) }, false)] - [InlineData(typeof(VectorStoreGenericDataModel), false, new Type[] { typeof(string), typeof(ulong) }, true)] - public void VerifyGenericDataModelKeyTypeThrowsOnlyForUnsupportedKeyTypeWithoutCustomMapper(Type recordType, bool customMapperSupplied, IEnumerable allowedKeyTypes, bool shouldThrow) - { - if (shouldThrow) - { - var ex = Assert.Throws(() => VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(recordType, customMapperSupplied, allowedKeyTypes)); - Assert.Equal("The key type 'System.Int32' of data model 'VectorStoreGenericDataModel' is not supported by the default mappers. Only the following key types are supported: System.String, System.UInt64. Please provide your own mapper to map to your chosen key type.", ex.Message); - } - else - { - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(recordType, customMapperSupplied, allowedKeyTypes); - } - } - - [Theory] - [InlineData(typeof(SinglePropsModel), false, false)] - [InlineData(typeof(VectorStoreGenericDataModel), true, false)] - [InlineData(typeof(VectorStoreGenericDataModel), false, true)] - public void VerifyGenericDataModelDefinitionSuppliedThrowsOnlyForMissingDefinition(Type recordType, bool definitionSupplied, bool shouldThrow) - { - if (shouldThrow) - { - var ex = Assert.Throws(() => VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(recordType, definitionSupplied)); - Assert.Equal("A VectorStoreRecordDefinition must be provided when using 'VectorStoreGenericDataModel'.", ex.Message); - } - else - { - VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(recordType, definitionSupplied); - } - } - - [Theory] - [InlineData(typeof(List), true)] - [InlineData(typeof(ICollection), true)] - [InlineData(typeof(IEnumerable), true)] - [InlineData(typeof(IList), true)] - [InlineData(typeof(IReadOnlyCollection), true)] - [InlineData(typeof(IReadOnlyList), true)] - [InlineData(typeof(string[]), true)] - [InlineData(typeof(IEnumerable), true)] - [InlineData(typeof(ArrayList), true)] - [InlineData(typeof(string), false)] - [InlineData(typeof(HashSet), false)] - [InlineData(typeof(ISet), false)] - [InlineData(typeof(Dictionary), false)] - [InlineData(typeof(Stack), false)] - [InlineData(typeof(Queue), false)] - public void IsSupportedEnumerableTypeReturnsCorrectAnswerForEachType(Type type, bool expected) - { - // Act. - var actual = VectorStoreRecordPropertyVerification.IsSupportedEnumerableType(type); - - // Assert. - Assert.Equal(expected, actual); - } - -#pragma warning disable CA1812 // Invalid unused classes error, since I am using these for testing purposes above. - - private sealed class SinglePropsModel - { - [VectorStoreRecordKey] - public string Key { get; set; } = string.Empty; - - [VectorStoreRecordData] - public string Data { get; set; } = string.Empty; - - [VectorStoreRecordVector] - public ReadOnlyMemory Vector { get; set; } - - public string NotAnnotated { get; set; } = string.Empty; - } - - private readonly VectorStoreRecordDefinition _singlePropsDefinition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("Data", typeof(string)), - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) - ] - }; - - private sealed class EnumerablePropsModel - { - [VectorStoreRecordKey] - public string Key { get; set; } = string.Empty; - - [VectorStoreRecordData] - public IEnumerable EnumerableData { get; set; } = new List(); - - [VectorStoreRecordData] - public string[] ArrayData { get; set; } = Array.Empty(); - - [VectorStoreRecordData] - public List ListData { get; set; } = new List(); - - [VectorStoreRecordVector] - public ReadOnlyMemory Vector { get; set; } - - public string NotAnnotated { get; set; } = string.Empty; - } - - private readonly VectorStoreRecordDefinition _enumerablePropsDefinition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty("Key", typeof(string)), - new VectorStoreRecordDataProperty("EnumerableData", typeof(IEnumerable)), - new VectorStoreRecordDataProperty("ArrayData", typeof(string[])), - new VectorStoreRecordDataProperty("ListData", typeof(List)), - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) - ] - }; - -#pragma warning restore CA1812 -} diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs index c01fe06eddf4..c6feb2e0d047 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs @@ -5,6 +5,7 @@ using System.Collections.ObjectModel; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.InMemory; @@ -22,31 +23,31 @@ public class VectorStoreTextSearchTestBase /// /// Create a from a . /// - public static async Task> CreateVectorStoreTextSearchFromVectorizedSearchAsync() + [Obsolete("VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete")] + public static async Task> CreateVectorStoreTextSearchWithEmbeddingGenerationServiceAsync() { var vectorStore = new InMemoryVectorStore(); - var vectorSearch = vectorStore.GetCollection("records"); + var vectorSearch = vectorStore.GetCollection("records"); var stringMapper = new DataModelTextSearchStringMapper(); var resultMapper = new DataModelTextSearchResultMapper(); - var embeddingService = new MockTextEmbeddingGenerationService(); + using var embeddingService = new MockTextEmbeddingGenerator(); await AddRecordsAsync(vectorSearch, embeddingService); - var sut = new VectorStoreTextSearch(vectorSearch, embeddingService, stringMapper, resultMapper); + var sut = new VectorStoreTextSearch(vectorSearch, embeddingService, stringMapper, resultMapper); return sut; } /// /// Create a from a . /// - public static async Task> CreateVectorStoreTextSearchFromVectorizableTextSearchAsync() + public static async Task> CreateVectorStoreTextSearchAsync() { - var vectorStore = new InMemoryVectorStore(); + using var embeddingGenerator = new MockTextEmbeddingGenerator(); + var vectorStore = new InMemoryVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); var vectorSearch = vectorStore.GetCollection("records"); var stringMapper = new DataModelTextSearchStringMapper(); var resultMapper = new DataModelTextSearchResultMapper(); - var embeddingService = new MockTextEmbeddingGenerationService(); - await AddRecordsAsync(vectorSearch, embeddingService); - var vectorizableTextSearch = new VectorizedSearchWrapper(vectorSearch, new MockTextEmbeddingGenerationService()); - var sut = new VectorStoreTextSearch(vectorizableTextSearch, stringMapper, resultMapper); + await AddRecordsAsync(vectorSearch); + var sut = new VectorStoreTextSearch(vectorSearch, stringMapper, resultMapper); return sut; } @@ -55,13 +56,34 @@ public static async Task> CreateVectorStoreText /// public static async Task AddRecordsAsync( IVectorStoreRecordCollection recordCollection, - ITextEmbeddingGenerationService embeddingService, int? count = 10) { await recordCollection.CreateCollectionIfNotExistsAsync(); for (var i = 0; i < count; i++) { DataModel dataModel = new() + { + Key = Guid.NewGuid(), + Text = $"Record {i}", + Tag = i % 2 == 0 ? "Even" : "Odd", + Embedding = $"Record {i}" + }; + await recordCollection.UpsertAsync(dataModel); + } + } + + /// + /// Add sample records to the vector store record collection. + /// + public static async Task AddRecordsAsync( + IVectorStoreRecordCollection recordCollection, + ITextEmbeddingGenerationService embeddingService, + int? count = 10) + { + await recordCollection.CreateCollectionIfNotExistsAsync(); + for (var i = 0; i < count; i++) + { + DataModelWithRawEmbedding dataModel = new() { Key = Guid.NewGuid(), Text = $"Record {i}", @@ -79,13 +101,12 @@ public sealed class DataModelTextSearchStringMapper : ITextSearchStringMapper { /// public string MapFromResultToString(object result) - { - if (result is DataModel dataModel) + => result switch { - return dataModel.Text; - } - throw new ArgumentException("Invalid result type."); - } + DataModel dataModel => dataModel.Text, + DataModelWithRawEmbedding dataModelWithRawEmbedding => dataModelWithRawEmbedding.Text, + _ => throw new ArgumentException("Invalid result type.") + }; } /// @@ -95,20 +116,26 @@ public sealed class DataModelTextSearchResultMapper : ITextSearchResultMapper { /// public TextSearchResult MapFromResultToTextSearchResult(object result) - { - if (result is DataModel dataModel) + => result switch { - return new TextSearchResult(value: dataModel.Text) { Name = dataModel.Key.ToString() }; - } - throw new ArgumentException("Invalid result type."); - } + DataModel dataModel => new TextSearchResult(value: dataModel.Text) { Name = dataModel.Key.ToString() }, + DataModelWithRawEmbedding dataModelWithRawEmbedding => new TextSearchResult(value: dataModelWithRawEmbedding.Text) { Name = dataModelWithRawEmbedding.Key.ToString() }, + _ => throw new ArgumentException("Invalid result type.") + }; } /// /// Mock implementation of . /// - public sealed class MockTextEmbeddingGenerationService : ITextEmbeddingGenerationService + public sealed class MockTextEmbeddingGenerator : IEmbeddingGenerator>, ITextEmbeddingGenerationService { + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + => Task.FromResult(new GeneratedEmbeddings>([new(new float[] { 0, 1, 2, 3 })])); + + public void Dispose() { } + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + /// public IReadOnlyDictionary Attributes { get; } = ReadOnlyDictionary.Empty; @@ -121,16 +148,27 @@ public Task>> GenerateEmbeddingsAsync(IList } /// - /// Decorator for a that generates embeddings for text search queries. + /// Sample model class that represents a record entry. /// - public sealed class VectorizedSearchWrapper(IVectorizedSearch vectorizedSearch, ITextEmbeddingGenerationService textEmbeddingGeneration) : IVectorizableTextSearch + /// + /// Note that each property is decorated with an attribute that specifies how the property should be treated by the vector store. + /// This allows us to create a collection in the vector store and upsert and retrieve instances of this class without any further configuration. + /// +#pragma warning disable CA1812 // Avoid uninstantiated internal classes + public sealed class DataModel +#pragma warning restore CA1812 // Avoid uninstantiated internal classes { - /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) - { - var vectorizedQuery = await textEmbeddingGeneration!.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); - return await vectorizedSearch.VectorizedSearchAsync(vectorizedQuery, options, cancellationToken); - } + [VectorStoreRecordKey] + public Guid Key { get; init; } + + [VectorStoreRecordData] + public required string Text { get; init; } + + [VectorStoreRecordData(IsIndexed = true)] + public required string Tag { get; init; } + + [VectorStoreRecordVector(1536)] + public string? Embedding { get; init; } } /// @@ -141,7 +179,7 @@ public async Task> VectorizableTextSearchAsync(stri /// This allows us to create a collection in the vector store and upsert and retrieve instances of this class without any further configuration. /// #pragma warning disable CA1812 // Avoid uninstantiated internal classes - public sealed class DataModel + public sealed class DataModelWithRawEmbedding #pragma warning restore CA1812 // Avoid uninstantiated internal classes { [VectorStoreRecordKey] @@ -150,7 +188,7 @@ public sealed class DataModel [VectorStoreRecordData] public required string Text { get; init; } - [VectorStoreRecordData(IsFilterable = true)] + [VectorStoreRecordData(IsIndexed = true)] public required string Tag { get; init; } [VectorStoreRecordVector(1536)] diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTests.cs index dcb2d310eda7..9737c52b2de0 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTests.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. + using System; using System.Linq; using System.Threading.Tasks; @@ -9,44 +10,46 @@ namespace SemanticKernel.UnitTests.Data; public class VectorStoreTextSearchTests : VectorStoreTextSearchTestBase { +#pragma warning disable CS0618 // VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete [Fact] - public void CanCreateVectorStoreTextSearchWithIVectorizedSearch() + public void CanCreateVectorStoreTextSearchWithEmbeddingGenerationService() { // Arrange. var vectorStore = new InMemoryVectorStore(); - var vectorSearch = vectorStore.GetCollection("records"); + var vectorSearch = vectorStore.GetCollection("records"); var stringMapper = new DataModelTextSearchStringMapper(); var resultMapper = new DataModelTextSearchResultMapper(); + using var embeddingGenerationService = new MockTextEmbeddingGenerator(); // Act. - var sut = new VectorStoreTextSearch(vectorSearch, new MockTextEmbeddingGenerationService(), stringMapper, resultMapper); + var sut = new VectorStoreTextSearch(vectorSearch, embeddingGenerationService, stringMapper, resultMapper); // Assert. Assert.NotNull(sut); } +#pragma warning restore CS0618 [Fact] - public void CanCreateVectorStoreTextSearchWithIVectorizableTextSearch() + public void CanCreateVectorStoreTextSearchWithIVectorSearch() { // Arrange. - var vectorStore = new InMemoryVectorStore(); + var vectorStore = new InMemoryVectorStore(new() { EmbeddingGenerator = new MockTextEmbeddingGenerator() }); var vectorSearch = vectorStore.GetCollection("records"); - var vectorizableTextSearch = new VectorizedSearchWrapper(vectorSearch, new MockTextEmbeddingGenerationService()); var stringMapper = new DataModelTextSearchStringMapper(); var resultMapper = new DataModelTextSearchResultMapper(); // Act. - var sut = new VectorStoreTextSearch(vectorizableTextSearch, stringMapper, resultMapper); + var sut = new VectorStoreTextSearch(vectorSearch, stringMapper, resultMapper); // Assert. Assert.NotNull(sut); } [Fact] - public async Task CanSearchWithVectorizedSearchAsync() + public async Task CanSearchAsync() { // Arrange. - var sut = await CreateVectorStoreTextSearchFromVectorizedSearchAsync(); + var sut = await CreateVectorStoreTextSearchAsync(); // Act. KernelSearchResults searchResults = await sut.SearchAsync("What is the Semantic Kernel?", new() { Top = 2, Skip = 0 }); @@ -56,10 +59,10 @@ public async Task CanSearchWithVectorizedSearchAsync() } [Fact] - public async Task CanGetTextSearchResultsWithVectorizedSearchAsync() + public async Task CanGetTextSearchResultsAsync() { // Arrange. - var sut = await CreateVectorStoreTextSearchFromVectorizedSearchAsync(); + var sut = await CreateVectorStoreTextSearchAsync(); // Act. KernelSearchResults searchResults = await sut.GetTextSearchResultsAsync("What is the Semantic Kernel?", new() { Top = 2, Skip = 0 }); @@ -69,10 +72,10 @@ public async Task CanGetTextSearchResultsWithVectorizedSearchAsync() } [Fact] - public async Task CanGetSearchResultsWithVectorizedSearchAsync() + public async Task CanGetSearchResultAsync() { // Arrange. - var sut = await CreateVectorStoreTextSearchFromVectorizedSearchAsync(); + var sut = await CreateVectorStoreTextSearchAsync(); // Act. KernelSearchResults searchResults = await sut.GetSearchResultsAsync("What is the Semantic Kernel?", new() { Top = 2, Skip = 0 }); @@ -81,11 +84,12 @@ public async Task CanGetSearchResultsWithVectorizedSearchAsync() Assert.Equal(2, results.Count); } +#pragma warning disable CS0618 // VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete [Fact] - public async Task CanSearchWithVectorizableTextSearchAsync() + public async Task CanSearchWithEmbeddingGenerationServiceAsync() { // Arrange. - var sut = await CreateVectorStoreTextSearchFromVectorizableTextSearchAsync(); + var sut = await CreateVectorStoreTextSearchWithEmbeddingGenerationServiceAsync(); // Act. KernelSearchResults searchResults = await sut.SearchAsync("What is the Semantic Kernel?", new() { Top = 2, Skip = 0 }); @@ -95,10 +99,10 @@ public async Task CanSearchWithVectorizableTextSearchAsync() } [Fact] - public async Task CanGetTextSearchResultsWithVectorizableTextSearchAsync() + public async Task CanGetTextSearchResultsWithEmbeddingGenerationServiceAsync() { // Arrange. - var sut = await CreateVectorStoreTextSearchFromVectorizableTextSearchAsync(); + var sut = await CreateVectorStoreTextSearchWithEmbeddingGenerationServiceAsync(); // Act. KernelSearchResults searchResults = await sut.GetTextSearchResultsAsync("What is the Semantic Kernel?", new() { Top = 2, Skip = 0 }); @@ -108,10 +112,10 @@ public async Task CanGetTextSearchResultsWithVectorizableTextSearchAsync() } [Fact] - public async Task CanGetSearchResultsWithVectorizableTextSearchAsync() + public async Task CanGetSearchResultsWithEmbeddingGenerationServiceAsync() { // Arrange. - var sut = await CreateVectorStoreTextSearchFromVectorizableTextSearchAsync(); + var sut = await CreateVectorStoreTextSearchWithEmbeddingGenerationServiceAsync(); // Act. KernelSearchResults searchResults = await sut.GetSearchResultsAsync("What is the Semantic Kernel?", new() { Top = 2, Skip = 0 }); @@ -119,12 +123,13 @@ public async Task CanGetSearchResultsWithVectorizableTextSearchAsync() Assert.Equal(2, results.Count); } +#pragma warning restore CS0618 // VectorStoreTextSearch with ITextEmbeddingGenerationService is obsolete [Fact] public async Task CanFilterGetSearchResultsWithVectorizedSearchAsync() { // Arrange. - var sut = await CreateVectorStoreTextSearchFromVectorizedSearchAsync(); + var sut = await CreateVectorStoreTextSearchAsync(); TextSearchFilter evenFilter = new(); evenFilter.Equality("Tag", "Even"); TextSearchFilter oddFilter = new(); diff --git a/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj b/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj index 8580c9a173ab..de5aa0586a38 100644 --- a/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj +++ b/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj @@ -6,7 +6,7 @@ net8.0 true false - $(NoWarn);CA2007,CA1861,IDE1006,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0050,SKEXP0110,SKEXP0120 + $(NoWarn);CA2007,CA1861,IDE1006,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0050,SKEXP0110,SKEXP0120,MEVD9000 diff --git a/dotnet/src/SemanticKernel.UnitTests/Utilities/FakeLogger.cs b/dotnet/src/SemanticKernel.UnitTests/Utilities/FakeLogger.cs new file mode 100644 index 000000000000..d9c44bf6a560 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Utilities/FakeLogger.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Logging; + +namespace SemanticKernel.UnitTests.Utilities; + +public class FakeLogger : ILogger +{ + public List<(LogLevel Level, string Message, Exception? Exception)> Logs { get; } = new(); + + public IDisposable? BeginScope(TState state) where TState : notnull => null; + + public bool IsEnabled(LogLevel logLevel) => true; + + public void Log( + LogLevel logLevel, + EventId eventId, + TState state, + Exception? exception, + Func formatter) + { + var message = formatter(state, exception); + this.Logs.Add((logLevel, message, exception)); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Utilities/LoggingExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Utilities/LoggingExtensionsTests.cs new file mode 100644 index 000000000000..8a6e09c013c8 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Utilities/LoggingExtensionsTests.cs @@ -0,0 +1,239 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.Diagnostics; +using Xunit; + +namespace SemanticKernel.UnitTests.Utilities; + +public class LoggingExtensionsTests +{ + [Fact] + public async Task RunWithLoggingVoidLogsSuccess() + { + // Arrange + var logger = new FakeLogger(); + static Task Operation() => Task.CompletedTask; + + // Act + await LoggingExtensions.RunWithLoggingAsync(logger, "TestOperation", Operation); + + // Assert + var logs = logger.Logs; + Assert.Equal(2, logs.Count); + Assert.Equal(LogLevel.Debug, logs[0].Level); + Assert.Equal("TestOperation invoked.", logs[0].Message); + Assert.Null(logs[0].Exception); + Assert.Equal(LogLevel.Debug, logs[1].Level); + Assert.Equal("TestOperation completed.", logs[1].Message); + Assert.Null(logs[1].Exception); + } + + [Fact] + public async Task RunWithLoggingVoidLogsException() + { + // Arrange + var logger = new FakeLogger(); + static Task Operation() => throw new InvalidOperationException("Test error"); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + LoggingExtensions.RunWithLoggingAsync(logger, "TestOperation", Operation)); + + Assert.Equal("Test error", exception.Message); + + var logs = logger.Logs; + Assert.Equal(2, logs.Count); + Assert.Equal(LogLevel.Debug, logs[0].Level); + Assert.Equal("TestOperation invoked.", logs[0].Message); + Assert.Null(logs[0].Exception); + Assert.Equal(LogLevel.Error, logs[1].Level); + Assert.Equal("TestOperation failed.", logs[1].Message); + Assert.Equal("Test error", logs[1].Exception?.Message); + } + + [Fact] + public async Task RunWithLoggingVoidLogsCancellation() + { + // Arrange + var logger = new FakeLogger(); + using var cts = new CancellationTokenSource(); + Task Operation() => Task.FromCanceled(cts.Token); + cts.Cancel(); + + // Act & Assert + await Assert.ThrowsAsync(() => + LoggingExtensions.RunWithLoggingAsync(logger, "TestOperation", Operation)); + + var logs = logger.Logs; + Assert.Equal(2, logs.Count); + Assert.Equal(LogLevel.Debug, logs[0].Level); + Assert.Equal("TestOperation invoked.", logs[0].Message); + Assert.Null(logs[0].Exception); + Assert.Equal(LogLevel.Debug, logs[1].Level); + Assert.Equal("TestOperation canceled.", logs[1].Message); + Assert.Null(logs[1].Exception); + } + + [Fact] + public async Task RunWithLoggingWithResultReturnsValue() + { + // Arrange + var logger = new FakeLogger(); + static Task Operation() => Task.FromResult(42); + + // Act + var result = await LoggingExtensions.RunWithLoggingAsync(logger, "TestOperation", Operation); + + // Assert + Assert.Equal(42, result); + + var logs = logger.Logs; + + Assert.Equal(2, logs.Count); + Assert.Equal(LogLevel.Debug, logs[0].Level); + Assert.Equal("TestOperation invoked.", logs[0].Message); + Assert.Null(logs[0].Exception); + Assert.Equal(LogLevel.Debug, logs[1].Level); + Assert.Equal("TestOperation completed.", logs[1].Message); + Assert.Null(logs[1].Exception); + } + + [Fact] + public async Task RunWithLoggingWithResultLogsException() + { + // Arrange + var logger = new FakeLogger(); + static Task Operation() => throw new InvalidOperationException("Test error"); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + LoggingExtensions.RunWithLoggingAsync(logger, "TestOperation", Operation)); + + Assert.Equal("Test error", exception.Message); + + var logs = logger.Logs; + Assert.Equal(2, logs.Count); + Assert.Equal(LogLevel.Debug, logs[0].Level); + Assert.Equal("TestOperation invoked.", logs[0].Message); + Assert.Null(logs[0].Exception); + Assert.Equal(LogLevel.Error, logs[1].Level); + Assert.Equal("TestOperation failed.", logs[1].Message); + Assert.Equal("Test error", logs[1].Exception?.Message); + } + + [Fact] + public async Task RunWithLoggingEnumerableYieldsValues() + { + // Arrange + var logger = new FakeLogger(); + static async IAsyncEnumerable Operation() + { + yield return 1; + yield return 2; + await Task.CompletedTask; // Ensure async behavior + } + + // Act + var results = new List(); + await foreach (var item in LoggingExtensions.RunWithLoggingAsync(logger, "TestOperation", Operation, default)) + { + results.Add(item); + } + + // Assert + Assert.Equal(new[] { 1, 2 }, results); + + var logs = logger.Logs; + + Assert.Equal(2, logs.Count); + Assert.Equal(LogLevel.Debug, logs[0].Level); + Assert.Equal("TestOperation invoked.", logs[0].Message); + Assert.Null(logs[0].Exception); + Assert.Equal(LogLevel.Debug, logs[1].Level); + Assert.Equal("TestOperation completed.", logs[1].Message); + Assert.Null(logs[1].Exception); + } + + [Fact] + public async Task RunWithLoggingEnumerableLogsException() + { + // Arrange + var logger = new FakeLogger(); + static async IAsyncEnumerable Operation() + { + yield return 1; + await Task.CompletedTask; + throw new InvalidOperationException("Test error"); + } + + // Act & Assert + var results = new List(); + var exception = await Assert.ThrowsAsync(async () => + { + await foreach (var item in LoggingExtensions.RunWithLoggingAsync(logger, "TestOperation", Operation, default)) + { + results.Add(item); + } + }); + + Assert.Equal("Test error", exception.Message); + Assert.Equal(new[] { 1 }, results); + + var logs = logger.Logs; + + Assert.Equal(2, logs.Count); + Assert.Equal(LogLevel.Debug, logs[0].Level); + Assert.Equal("TestOperation invoked.", logs[0].Message); + Assert.Null(logs[0].Exception); + Assert.Equal(LogLevel.Error, logs[1].Level); + Assert.Equal("TestOperation failed.", logs[1].Message); + Assert.Equal("Test error", logs[1].Exception?.Message); + } + + [Fact] + public async Task RunWithLoggingEnumerableLogsCancellation() + { + // Arrange + var logger = new FakeLogger(); + using var cts = new CancellationTokenSource(); + static async IAsyncEnumerable Operation([EnumeratorCancellation] CancellationToken token) + { + yield return 1; + await Task.Delay(10, token); // Simulate async work + yield return 2; + } + cts.Cancel(); + + // Act & Assert + var results = new List(); + var exception = await Assert.ThrowsAsync(async () => + { + await foreach (var item in LoggingExtensions.RunWithLoggingAsync( + logger, + "TestOperation", + () => Operation(cts.Token), + cts.Token)) + { + results.Add(item); + } + }); + + Assert.Equal(new[] { 1 }, results); // Should yield first value before cancellation + + var logs = logger.Logs; + + Assert.Equal(2, logs.Count); + Assert.Equal(LogLevel.Debug, logs[0].Level); + Assert.Equal("TestOperation invoked.", logs[0].Message); + Assert.Null(logs[0].Exception); + Assert.Equal(LogLevel.Debug, logs[1].Level); + Assert.Equal("TestOperation canceled.", logs[1].Message); + Assert.Null(logs[1].Exception); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj index 688796758267..0aa90f87ee3e 100644 --- a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj @@ -16,6 +16,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive all + diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchBatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchBatchConformanceTests.cs new file mode 100644 index 000000000000..4da84cc7e99a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchBatchConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using AzureAISearchIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace AzureAISearchIntegrationTests.CRUD; + +public class AzureAISearchBatchConformanceTests(AzureAISearchSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchNoDataConformanceTests.cs new file mode 100644 index 000000000000..69861413facc --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchNoDataConformanceTests.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.RegularExpressions; +using AzureAISearchIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace AzureAISearchIntegrationTests.CRUD; + +public class AzureAISearchNoDataConformanceTests(AzureAISearchNoDataConformanceTests.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ +#pragma warning disable CA1308 // Normalize strings to uppercase + private static readonly string _testIndexPostfix = new Regex("[^a-zA-Z0-9]").Replace(Environment.MachineName.ToLowerInvariant(), ""); +#pragma warning restore CA1308 // Normalize strings to uppercase + + public new class Fixture : NoDataConformanceTests.Fixture + { + public override string CollectionName => "nodata-" + _testIndexPostfix; + + public override TestStore TestStore => AzureAISearchTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchNoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchNoVectorConformanceTests.cs new file mode 100644 index 000000000000..a48f648d079f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchNoVectorConformanceTests.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.RegularExpressions; +using AzureAISearchIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace AzureAISearchIntegrationTests.CRUD; + +public class AzureAISearchNoVectorConformanceTests(AzureAISearchNoVectorConformanceTests.Fixture fixture) + : NoVectorConformanceTests(fixture), IClassFixture +{ +#pragma warning disable CA1308 // Normalize strings to uppercase + private static readonly string _testIndexPostfix = new Regex("[^a-zA-Z0-9]").Replace(Environment.MachineName.ToLowerInvariant(), ""); +#pragma warning restore CA1308 // Normalize strings to uppercase + + public new class Fixture : NoVectorConformanceTests.Fixture + { + public override string CollectionName => "novector-" + _testIndexPostfix; + + public override TestStore TestStore => AzureAISearchTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchRecordConformanceTests.cs new file mode 100644 index 000000000000..49b8b88b4b5f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/CRUD/AzureAISearchRecordConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using AzureAISearchIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace AzureAISearchIntegrationTests.CRUD; + +public class AzureAISearchRecordConformanceTests(AzureAISearchSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicFilterTests.cs index 6a7e8a1df408..eb6de0cf988c 100644 --- a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicFilterTests.cs @@ -19,6 +19,6 @@ public override Task Contains_over_inline_int_array() public override TestStore TestStore => AzureAISearchTestStore.Instance; // Azure AI search only supports lowercase letters, digits or dashes. - protected override string CollectionName => "filter-tests"; + public override string CollectionName => "filter-tests"; } } diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicQueryTests.cs new file mode 100644 index 000000000000..38dd343a8aa6 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicQueryTests.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +using AzureAISearchIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace AzureAISearchIntegrationTests.Filter; + +public class AzureAISearchBasicQueryTests(AzureAISearchBasicQueryTests.Fixture fixture) + : BasicQueryTests(fixture), IClassFixture +{ + // Azure AI Search only supports search.in() over strings + public override Task Contains_over_inline_int_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); + + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => AzureAISearchTestStore.Instance; + + // Azure AI search only supports lowercase letters, digits or dashes. + public override string CollectionName => "query-tests"; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/HybridSearch/AzureAISearchKeywordVectorizedHybridSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/HybridSearch/AzureAISearchKeywordVectorizedHybridSearchTests.cs index 3860489b9471..edde0f48405a 100644 --- a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/HybridSearch/AzureAISearchKeywordVectorizedHybridSearchTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/HybridSearch/AzureAISearchKeywordVectorizedHybridSearchTests.cs @@ -28,7 +28,7 @@ public class AzureAISearchKeywordVectorizedHybridSearchTests( public override TestStore TestStore => AzureAISearchTestStore.Instance; // Azure AI search only supports lowercase letters, digits or dashes. - protected override string CollectionName => "vecstring-hybrid-search-" + _testIndexPostfix; + public override string CollectionName => "vecstring-hybrid-search-" + _testIndexPostfix; } public new class MultiTextFixture : KeywordVectorizedHybridSearchComplianceTests.MultiTextFixture @@ -36,6 +36,6 @@ public class AzureAISearchKeywordVectorizedHybridSearchTests( public override TestStore TestStore => AzureAISearchTestStore.Instance; // Azure AI search only supports lowercase letters, digits or dashes. - protected override string CollectionName => "multitext-hybrid-search-" + _testIndexPostfix; + public override string CollectionName => "multitext-hybrid-search-" + _testIndexPostfix; } } diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchSimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchSimpleModelFixture.cs new file mode 100644 index 000000000000..fd26563901ff --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchSimpleModelFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace AzureAISearchIntegrationTests.Support; + +public class AzureAISearchSimpleModelFixture : SimpleModelFixture +{ + public override TestStore TestStore => AzureAISearchTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestEnvironment.cs index 27e905656870..cad8632f873b 100644 --- a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestEnvironment.cs +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestEnvironment.cs @@ -10,7 +10,7 @@ internal static class AzureAISearchTestEnvironment { public static readonly string? ServiceUrl, ApiKey; - public static bool IsConnectionInfoDefined => ServiceUrl is not null && ApiKey is not null; + public static bool IsConnectionInfoDefined => ServiceUrl is not null; static AzureAISearchTestEnvironment() { diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestStore.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestStore.cs index 791005d55c9a..4d7a34fec866 100644 --- a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestStore.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Linq.Expressions; using Azure; +using Azure.Identity; using Azure.Search.Documents.Indexes; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.AzureAISearch; @@ -32,14 +34,31 @@ protected override Task StartAsync() { (string? serviceUrl, string? apiKey) = (AzureAISearchTestEnvironment.ServiceUrl, AzureAISearchTestEnvironment.ApiKey); - if (string.IsNullOrWhiteSpace(serviceUrl) || string.IsNullOrWhiteSpace(apiKey)) + if (string.IsNullOrWhiteSpace(serviceUrl)) { - throw new InvalidOperationException("Service URL and API key are not configured, set AzureAISearch:ServiceUrl and AzureAISearch:ApiKey"); + throw new InvalidOperationException("Service URL is not configured, set AzureAISearch:ServiceUrl (and AzureAISearch:ApiKey if you want)"); } - this._client = new SearchIndexClient(new Uri(serviceUrl), new AzureKeyCredential(apiKey)); + this._client = string.IsNullOrWhiteSpace(apiKey) + ? new SearchIndexClient(new Uri(serviceUrl), new DefaultAzureCredential()) + : new SearchIndexClient(new Uri(serviceUrl), new AzureKeyCredential(apiKey)); + this._defaultVectorStore = new(this._client); return Task.CompletedTask; } + + public override async Task WaitForDataAsync( + IVectorStoreRecordCollection collection, + int recordCount, + Expression>? filter = null, + int vectorSize = 3) + { + await base.WaitForDataAsync(collection, recordCount, filter, vectorSize); + + // There seems to be some asynchronicity/race condition specific to Azure AI Search which isn't taken care + // of by the generic retry loop in the base implementation. + // TODO: Investigate this and remove + await Task.Delay(TimeSpan.FromMilliseconds(1000)); + } } diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchUrlRequiredAttribute.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchUrlRequiredAttribute.cs index 1b30639bc1be..9b2fb5a9b223 100644 --- a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchUrlRequiredAttribute.cs +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchUrlRequiredAttribute.cs @@ -12,7 +12,7 @@ public sealed class AzureAISearchUrlRequiredAttribute : Attribute, ITestConditio { public ValueTask IsMetAsync() => new(AzureAISearchTestEnvironment.IsConnectionInfoDefined); - public string Skip { get; set; } = "Service URL and API key are not configured, set AzureAISearch:ServiceUrl and AzureAISearch:ApiKey."; + public string Skip { get; set; } = "Service URL is not configured, set AzureAISearch:ServiceUrl (and AzureAISearch:ApiKey if you don't use managed identity)."; public string SkipReason => this.Skip; diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBBatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBBatchConformanceTests.cs new file mode 100644 index 000000000000..bf5dbe318eac --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBBatchConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosMongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace CosmosMongoDBIntegrationTests.CRUD; + +public class CosmosMongoDBBatchConformanceTests(CosmosMongoDBSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBNoDataConformanceTests.cs new file mode 100644 index 000000000000..1cdb1ed6d4cc --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBNoDataConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosMongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace CosmosMongoDBIntegrationTests.CRUD; + +public class CosmosMongoDBNoDataConformanceTests(CosmosMongoDBNoDataConformanceTests.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => CosmosMongoDBTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBNoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBNoVectorConformanceTests.cs new file mode 100644 index 000000000000..134707f35575 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBNoVectorConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosMongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace CosmosMongoDBIntegrationTests.CRUD; + +public class CosmosMongoDBNoVectorConformanceTests(CosmosMongoDBNoVectorConformanceTests.Fixture fixture) + : NoVectorConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoVectorConformanceTests.Fixture + { + public override TestStore TestStore => CosmosMongoDBTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBRecordConformanceTests.cs new file mode 100644 index 000000000000..a7f89cfc05c6 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CRUD/CosmosMongoDBRecordConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosMongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace CosmosMongoDBIntegrationTests.CRUD; + +public class CosmosMongoDBRecordConformanceTests(CosmosMongoDBSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj index 59a720d7dddd..0e4200084d4e 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj @@ -7,6 +7,7 @@ true false CosmosMongoDBIntegrationTests + b7762d10-e29b-4bb1-8b74-b6d69a667dd4 @@ -19,6 +20,7 @@ + diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..45e6c28d68b4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoEmbeddingGenerationTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosMongoDBIntegrationTests.Support; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace CosmosMongoDBIntegrationTests; + +public class CosmosMongoEmbeddingGenerationTests(CosmosMongoEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => CosmosMongoDBTestStore.Instance; + + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => CosmosMongoDBTestStore.Instance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + services => services + .AddSingleton(CosmosMongoDBTestStore.Instance.Database) + .AddAzureCosmosDBMongoDBVectorStore() + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services + .AddSingleton(CosmosMongoDBTestStore.Instance.Database) + .AddAzureCosmosDBMongoDBVectorStoreRecordCollection(this.CollectionName) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs index 50dfc677ad00..cce45cffd2ec 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs @@ -14,7 +14,9 @@ public class CosmosMongoBasicFilterTests(CosmosMongoBasicFilterTests.Fixture fix // Specialized MongoDB syntax for NOT over Contains ($nin) [ConditionalFact] public virtual Task Not_over_Contains() - => this.TestFilterAsync(r => !new[] { 8, 10 }.Contains(r.Int)); + => this.TestFilterAsync( + r => !new[] { 8, 10 }.Contains(r.Int), + r => !new[] { 8, 10 }.Contains((int)r["Int"]!)); #region Null checking diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicQueryTests.cs new file mode 100644 index 000000000000..6852f8efa085 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicQueryTests.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosMongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace CosmosMongoDBIntegrationTests.Filter; + +public class CosmosMongoBasicQueryTests(CosmosMongoBasicQueryTests.Fixture fixture) + : BasicQueryTests(fixture), IClassFixture +{ + // Specialized MongoDB syntax for NOT over Contains ($nin) + [ConditionalFact] + public virtual Task Not_over_Contains() + => this.TestFilterAsync( + r => !new[] { 8, 10 }.Contains(r.Int), + r => !new[] { 8, 10 }.Contains((int)r["Int"]!)); + + // MongoDB currently doesn't support null checking ({ "Foo" : null }) in vector search pre-filters + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + // MongoDB currently doesn't support NOT in vector search pre-filters + // (https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter) + public override Task Not_over_And() + => Assert.ThrowsAsync(() => base.Not_over_And()); + + public override Task Not_over_Or() + => Assert.ThrowsAsync(() => base.Not_over_Or()); + + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => CosmosMongoDBTestStore.Instance; + + protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.IvfFlat; + protected override string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBSimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBSimpleModelFixture.cs new file mode 100644 index 000000000000..42d6a8dbf3a9 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBSimpleModelFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace CosmosMongoDBIntegrationTests.Support; + +public class CosmosMongoDBSimpleModelFixture : SimpleModelFixture +{ + public override TestStore TestStore => CosmosMongoDBTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestEnvironment.cs index df6550d05237..faf467122f2f 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestEnvironment.cs +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestEnvironment.cs @@ -18,6 +18,7 @@ static CosmosMongoDBTestEnvironment() .AddJsonFile(path: "testsettings.json", optional: true) .AddJsonFile(path: "testsettings.development.json", optional: true) .AddEnvironmentVariables() + .AddUserSecrets() .Build(); ConnectionString = configuration["AzureCosmosDBMongoDB:ConnectionString"]; diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs index fba1d18c8a7f..f7727ee324fb 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs @@ -21,6 +21,10 @@ public sealed class CosmosMongoDBTestStore : TestStore public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Call InitializeAsync() first"); + public override string DefaultIndexKind => Microsoft.Extensions.VectorData.IndexKind.IvfFlat; + + public override string DefaultDistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; + public AzureCosmosDBMongoDBVectorStore GetVectorStore(AzureCosmosDBMongoDBVectorStoreOptions options) => new(this.Database, options); diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CRUD/CosmosNoSQLNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CRUD/CosmosNoSQLNoDataConformanceTests.cs new file mode 100644 index 000000000000..9c2f4659a6bb --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CRUD/CosmosNoSQLNoDataConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosNoSQLIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace CosmosNoSQLIntegrationTests.CRUD; + +public class CosmosNoSQLNoDataConformanceTests(CosmosNoSQLNoDataConformanceTests.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => CosmosNoSQLTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CRUD/CosmosNoSQLNoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CRUD/CosmosNoSQLNoVectorConformanceTests.cs new file mode 100644 index 000000000000..b6581771ae05 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CRUD/CosmosNoSQLNoVectorConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosNoSQLIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace CosmosNoSQLIntegrationTests.CRUD; + +public class CosmosNoSQLNoVectorConformanceTests(CosmosNoSQLNoVectorConformanceTests.Fixture fixture) + : NoVectorConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoVectorConformanceTests.Fixture + { + public override TestStore TestStore => CosmosNoSQLTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..3d8be02c4552 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLEmbeddingGenerationTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosNoSQLIntegrationTests.Support; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace CosmosNoSQLIntegrationTests; + +public class CosmosNoSQLEmbeddingGenerationTests(CosmosNoSQLEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => CosmosNoSQLTestStore.Instance; + + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => CosmosNoSQLTestStore.Instance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + services => services + .AddSingleton(CosmosNoSQLTestStore.Instance.Database) + .AddAzureCosmosDBNoSQLVectorStore() + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services + .AddSingleton(CosmosNoSQLTestStore.Instance.Database) + .AddAzureCosmosDBNoSQLVectorStoreRecordCollection(this.CollectionName) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicFilterTests.cs index 4058ea8674a7..266145d86485 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicFilterTests.cs @@ -12,6 +12,6 @@ public class CosmosNoSQLBasicFilterTests(CosmosNoSQLBasicFilterTests.Fixture fix { public new class Fixture : BasicFilterTests.Fixture { - public override TestStore TestStore => CosmosNoSqlTestStore.Instance; + public override TestStore TestStore => CosmosNoSQLTestStore.Instance; } } diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicQueryTests.cs new file mode 100644 index 000000000000..bd1bd9d9fd9a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicQueryTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosNoSQLIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace CosmosNoSQLIntegrationTests.Filter; + +public class CosmosNoSQLBasicQueryTests(CosmosNoSQLBasicQueryTests.Fixture fixture) + : BasicQueryTests(fixture), IClassFixture +{ + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => CosmosNoSQLTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/HybridSearch/CosmosNoSQLKeywordVectorizedHybridSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/HybridSearch/CosmosNoSQLKeywordVectorizedHybridSearchTests.cs index 24935b4ffc2d..081df134257f 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/HybridSearch/CosmosNoSQLKeywordVectorizedHybridSearchTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/HybridSearch/CosmosNoSQLKeywordVectorizedHybridSearchTests.cs @@ -16,11 +16,11 @@ public class CosmosNoSQLKeywordVectorizedHybridSearchTests( { public new class VectorAndStringFixture : KeywordVectorizedHybridSearchComplianceTests.VectorAndStringFixture { - public override TestStore TestStore => CosmosNoSqlTestStore.Instance; + public override TestStore TestStore => CosmosNoSQLTestStore.Instance; } public new class MultiTextFixture : KeywordVectorizedHybridSearchComplianceTests.MultiTextFixture { - public override TestStore TestStore => CosmosNoSqlTestStore.Instance; + public override TestStore TestStore => CosmosNoSQLTestStore.Instance; } } diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLFixture.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLFixture.cs index 9ddaad05be85..f1823af21bb8 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLFixture.cs @@ -6,5 +6,5 @@ namespace CosmosNoSQLIntegrationTests.Support; public class CosmosNoSQLFixture : VectorStoreFixture { - public override TestStore TestStore => CosmosNoSqlTestStore.Instance; + public override TestStore TestStore => CosmosNoSQLTestStore.Instance; } diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs index 7e3269ba2a27..fe72f99d8695 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs @@ -13,14 +13,16 @@ namespace CosmosNoSQLIntegrationTests.Support; #pragma warning disable CA1001 // Type owns disposable fields (_connection) but is not disposable -internal sealed class CosmosNoSqlTestStore : TestStore +internal sealed class CosmosNoSQLTestStore : TestStore { - public static CosmosNoSqlTestStore Instance { get; } = new(); + public static CosmosNoSQLTestStore Instance { get; } = new(); private CosmosClient? _client; private Database? _database; private AzureCosmosDBNoSQLVectorStore? _defaultVectorStore; + public override string DefaultIndexKind => Microsoft.Extensions.VectorData.IndexKind.Flat; + public CosmosClient Client => this._client ?? throw new InvalidOperationException("Call InitializeAsync() first"); @@ -33,7 +35,7 @@ public override IVectorStore DefaultVectorStore public AzureCosmosDBNoSQLVectorStore GetVectorStore(AzureCosmosDBNoSQLVectorStoreOptions options) => new(this.Database, options); - private CosmosNoSqlTestStore() + private CosmosNoSQLTestStore() { } diff --git a/dotnet/src/VectorDataIntegrationTests/Directory.Build.props b/dotnet/src/VectorDataIntegrationTests/Directory.Build.props index f5d133b5fd9f..eacdeec35e93 100644 --- a/dotnet/src/VectorDataIntegrationTests/Directory.Build.props +++ b/dotnet/src/VectorDataIntegrationTests/Directory.Build.props @@ -6,7 +6,9 @@ $(NoWarn);CA1707 $(NoWarn);CA1716 $(NoWarn);CA1720 + $(NoWarn);CA1721 $(NoWarn);CA1861 + $(NoWarn);CA1863 $(NoWarn);CA2007;VSTHRD111 $(NoWarn);CS1591 $(NoWarn);IDE1006 diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryBatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryBatchConformanceTests.cs new file mode 100644 index 000000000000..f778fbb7154d --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryBatchConformanceTests.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using InMemoryIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace InMemoryIntegrationTests.CRUD; + +public class InMemoryBatchConformanceTests(InMemorySimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture +{ + // InMemory always returns the vectors (IncludeVectors = false isn't respected) + public override async Task GetBatchAsync_WithoutVectors() + { + var expectedRecords = fixture.TestData.Take(2); // the last two records can get deleted by other tests + var ids = expectedRecords.Select(record => record.Id); + + var received = await fixture.Collection.GetAsync(ids, new() { IncludeVectors = false }).ToArrayAsync(); + + foreach (var record in expectedRecords) + { + record.AssertEqual(this.GetRecord(received, record.Id), includeVectors: true, fixture.TestStore.VectorsComparable); + } + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryDynamicRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryDynamicRecordConformanceTests.cs new file mode 100644 index 000000000000..940d8686e8ea --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryDynamicRecordConformanceTests.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +using InMemoryIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace InMemoryIntegrationTests.CRUD; + +public class InMemoryDynamicRecordConformanceTests(InMemoryDynamicDataModelFixture fixture) + : DynamicDataModelConformanceTests(fixture), IClassFixture +{ + // InMemory always returns the vectors (IncludeVectors = false isn't respected) + public override async Task GetAsync_WithoutVectors() + { + var expectedRecord = fixture.TestData[0]; + + var received = await fixture.Collection.GetAsync( + (int)expectedRecord[DynamicDataModelFixture.KeyPropertyName]!, + new() { IncludeVectors = false }); + + AssertEquivalent(expectedRecord, received, includeVectors: true, fixture.TestStore.VectorsComparable); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryNoDataConformanceTests.cs new file mode 100644 index 000000000000..cbd60656cecd --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryNoDataConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using InMemoryIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace InMemoryIntegrationTests.CRUD; + +public class InMemoryNoDataConformanceTests(InMemoryNoDataConformanceTests.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => InMemoryTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryNoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryNoVectorConformanceTests.cs new file mode 100644 index 000000000000..7c8c759c8d3c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryNoVectorConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using InMemoryIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace InMemoryIntegrationTests.CRUD; + +public class InMemoryNoVectorConformanceTests(InMemoryNoVectorConformanceTests.Fixture fixture) + : NoVectorConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoVectorConformanceTests.Fixture + { + public override TestStore TestStore => InMemoryTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryRecordConformanceTests.cs new file mode 100644 index 000000000000..02534141ec76 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/CRUD/InMemoryRecordConformanceTests.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. + +using InMemoryIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace InMemoryIntegrationTests.CRUD; + +public class InMemoryRecordConformanceTests(InMemorySimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture +{ + // InMemory always returns the vectors (IncludeVectors = false isn't respected) + public override async Task GetAsync_WithoutVectors() + { + var expectedRecord = fixture.TestData[0]; + var received = await fixture.Collection.GetAsync(expectedRecord.Id, new() { IncludeVectors = false }); + + expectedRecord.AssertEqual(received, includeVectors: true, fixture.TestStore.VectorsComparable); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicFilterTests.cs index 198178aae1a1..b103836840df 100644 --- a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicFilterTests.cs @@ -5,7 +5,7 @@ using VectorDataSpecificationTests.Support; using Xunit; -namespace PostgresIntegrationTests.Filter; +namespace InMemoryIntegrationTests.Filter; public class InMemoryBasicFilterTests(InMemoryBasicFilterTests.Fixture fixture) : BasicFilterTests(fixture), IClassFixture @@ -13,5 +13,13 @@ public class InMemoryBasicFilterTests(InMemoryBasicFilterTests.Fixture fixture) public new class Fixture : BasicFilterTests.Fixture { public override TestStore TestStore => InMemoryTestStore.Instance; + + // BaseFilterTests attempts to create two InMemoryVectorStoreRecordCollection with different .NET types: + // 1. One for strongly-typed mapping (TRecord=FilterRecord) + // 2. One for dynamic mapping (TRecord=Dictionary) + // Unfortunately, InMemoryVectorStore does not allow mapping the same collection name to different types; + // at the same time, it simply evaluates all filtering via .NET AsQueryable(), so actual test coverage + // isn't very important here. So we disable the dynamic tests. + public override bool TestDynamic => false; } } diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicQueryTests.cs new file mode 100644 index 000000000000..432fbdff3fbd --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicQueryTests.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using InMemoryIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace InMemoryIntegrationTests.Filter; + +public class InMemoryBasicQueryTests(InMemoryBasicQueryTests.Fixture fixture) + : BasicQueryTests(fixture), IClassFixture +{ + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => InMemoryTestStore.Instance; + + // BaseFilterTests attempts to create two InMemoryVectorStoreRecordCollection with different .NET types: + // 1. One for strongly-typed mapping (TRecord=FilterRecord) + // 2. One for dynamic mapping (TRecord=Dictionary) + // Unfortunately, InMemoryVectorStore does not allow mapping the same collection name to different types; + // at the same time, it simply evaluates all filtering via .NET AsQueryable(), so actual test coverage + // isn't very important here. So we disable the dynamic tests. + public override bool TestDynamic => false; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..6d1cedeee544 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryEmbeddingGenerationTests.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. + +using InMemoryIntegrationTests.Support; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace InMemoryIntegrationTests; + +public class InMemoryEmbeddingGenerationTests(InMemoryEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + // InMemory doesn't allowing accessing the same collection via different .NET types (it's unique in this). + // The following dynamic tests attempt to access the fixture collection - which is created with Record - via + // Dictionary. + public override Task SearchAsync_with_property_generator_dynamic() => Task.CompletedTask; + public override Task UpsertAsync_dynamic() => Task.CompletedTask; + public override Task UpsertAsync_batch_dynamic() => Task.CompletedTask; + + // The same applies to the custom type test: + public override Task SearchAsync_with_custom_input_type() => Task.CompletedTask; + + // The test relies on creating a new InMemoryVectorStore configured with a store-default generator, but with InMemory that store + // doesn't share the seeded data with the fixture store (since each InMemoryVectorStore has its own private data). + // Test coverage is already largely sufficient via the property and collection tests. + public override Task SearchAsync_with_store_generator() => Task.CompletedTask; + + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => InMemoryTestStore.Instance; + + // Note that with InMemory specifically, we can't create a vector store with an embedding generator, since it wouldn't share the seeded data with the fixture store. + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => InMemoryTestStore.Instance.DefaultVectorStore; + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + // The InMemory DI methods register a new vector store instance, which doesn't share the collection seeded by the + // fixture and the test fails. + // services => services.AddInMemoryVectorStore() + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + // The InMemory DI methods register a new vector store instance, which doesn't share the collection seeded by the + // fixture and the test fails. + // services => services.AddInMemoryVectorStoreRecordCollection(this.CollectionName) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj index f77fff8de939..1f5b8383e120 100644 --- a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj @@ -10,17 +10,17 @@ - + runtime; build; native; contentfiles; analyzers; buildtransitive all - + - - + + diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryDynamicDataModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryDynamicDataModelFixture.cs new file mode 100644 index 000000000000..07a05aad98f8 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryDynamicDataModelFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace InMemoryIntegrationTests.Support; + +public class InMemoryDynamicDataModelFixture : DynamicDataModelFixture +{ + public override TestStore TestStore => InMemoryTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemorySimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemorySimpleModelFixture.cs new file mode 100644 index 000000000000..8c32c4cc1306 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemorySimpleModelFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace InMemoryIntegrationTests.Support; + +public class InMemorySimpleModelFixture : SimpleModelFixture +{ + public override TestStore TestStore => InMemoryTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryTestStore.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryTestStore.cs index 246d5166c831..81f44e3339df 100644 --- a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryTestStore.cs @@ -10,9 +10,12 @@ internal sealed class InMemoryTestStore : TestStore { public static InMemoryTestStore Instance { get; } = new(); - private InMemoryVectorStore _vectorStore = new(); + private InMemoryVectorStore _defaultVectorStore = new(); - public override IVectorStore DefaultVectorStore => this._vectorStore; + public override IVectorStore DefaultVectorStore => this._defaultVectorStore; + + public InMemoryVectorStore GetVectorStore(InMemoryVectorStoreOptions options) + => new(new() { EmbeddingGenerator = options.EmbeddingGenerator }); private InMemoryTestStore() { @@ -20,7 +23,7 @@ private InMemoryTestStore() protected override Task StartAsync() { - this._vectorStore = new InMemoryVectorStore(); + this._defaultVectorStore = new InMemoryVectorStore(); return Task.CompletedTask; } diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBBatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBBatchConformanceTests.cs new file mode 100644 index 000000000000..67ea1ee7e2c4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBBatchConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using MongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace MongoDBIntegrationTests.CRUD; + +public class MongoDBBatchConformanceTests(MongoDBSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBNoDataConformanceTests.cs new file mode 100644 index 000000000000..97a11c6b4624 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBNoDataConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using MongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace MongoDBIntegrationTests.CRUD; + +public class MongoDBNoDataConformanceTests(MongoDBNoDataConformanceTests.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => MongoDBTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBNoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBNoVectorConformanceTests.cs new file mode 100644 index 000000000000..f4a597f6bcaa --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBNoVectorConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using MongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace MongoDBIntegrationTests.CRUD; + +public class MongoDBNoVectorConformanceTests(MongoDBNoVectorConformanceTests.Fixture fixture) + : NoVectorConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoVectorConformanceTests.Fixture + { + public override TestStore TestStore => MongoDBTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBRecordConformanceTests.cs new file mode 100644 index 000000000000..22d642b7c16a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/CRUD/MongoDBRecordConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using MongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace MongoDBIntegrationTests.CRUD; + +public class MongoDBRecordConformanceTests(MongoDBSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs index 885c1503f5f7..da5ed5f46b8c 100644 --- a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs @@ -14,7 +14,9 @@ public class MongoDBBasicFilterTests(MongoDBBasicFilterTests.Fixture fixture) // Specialized MongoDB syntax for NOT over Contains ($nin) [ConditionalFact] public virtual Task Not_over_Contains() - => this.TestFilterAsync(r => !new[] { 8, 10 }.Contains(r.Int)); + => this.TestFilterAsync( + r => !new[] { 8, 10 }.Contains(r.Int), + r => !new[] { 8, 10 }.Contains((int)r["Int"]!)); #region Null checking diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicQueryTests.cs new file mode 100644 index 000000000000..5e5d138a35b6 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicQueryTests.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. + +using MongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace MongoDBIntegrationTests.Filter; + +public class MongoDBBasicQueryTests(MongoDBBasicQueryTests.Fixture fixture) + : BasicQueryTests(fixture), IClassFixture +{ + // Specialized MongoDB syntax for NOT over Contains ($nin) + [ConditionalFact] + public virtual Task Not_over_Contains() + => this.TestFilterAsync( + r => !new[] { 8, 10 }.Contains(r.Int), + r => !new[] { 8, 10 }.Contains((int)r["Int"]!)); + + // MongoDB currently doesn't support null checking ({ "Foo" : null }) in vector search pre-filters + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + // MongoDB currently doesn't support NOT in vector search pre-filters + // (https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter) + public override Task Not_over_And() + => Assert.ThrowsAsync(() => base.Not_over_And()); + + public override Task Not_over_Or() + => Assert.ThrowsAsync(() => base.Not_over_Or()); + + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => MongoDBTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..aa91f172d3b8 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBEmbeddingGenerationTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using MongoDBIntegrationTests.Support; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace MongoDBIntegrationTests; + +public class MongoDBEmbeddingGenerationTests(MongoDBEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => MongoDBTestStore.Instance; + + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => MongoDBTestStore.Instance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + services => services + .AddSingleton(MongoDBTestStore.Instance.Database) + .AddMongoDBVectorStore() + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services + .AddSingleton(MongoDBTestStore.Instance.Database) + .AddMongoDBVectorStoreRecordCollection(this.CollectionName) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBSimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBSimpleModelFixture.cs new file mode 100644 index 000000000000..143f5497d7cf --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBSimpleModelFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace MongoDBIntegrationTests.Support; + +public class MongoDBSimpleModelFixture : SimpleModelFixture +{ + public override TestStore TestStore => MongoDBTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeAllSupportedTypesTests.cs b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeAllSupportedTypesTests.cs index e1dbf9efa81e..ddb702b22104 100644 --- a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeAllSupportedTypesTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeAllSupportedTypesTests.cs @@ -61,9 +61,9 @@ public async Task AllTypesBatchGetAsync() } ]; - await collection.UpsertBatchAsync(records).ToArrayAsync(); + await collection.UpsertAsync(records); - var allTypes = await collection.GetBatchAsync(records.Select(r => r.Id), new GetRecordOptions { IncludeVectors = true }).ToListAsync(); + var allTypes = await collection.GetAsync(records.Select(r => r.Id), new GetRecordOptions { IncludeVectors = true }).ToListAsync(); var allTypes1 = allTypes.Single(x => x.Id == records[0].Id); var allTypes2 = allTypes.Single(x => x.Id == records[1].Id); diff --git a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeDynamicDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeDynamicDataModelConformanceTests.cs new file mode 100644 index 000000000000..029a18b0f4c5 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeDynamicDataModelConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PineconeIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace PineconeIntegrationTests.CRUD; + +public class PineconeDynamicDataModelConformanceTests(PineconeDynamicDataModelFixture fixture) + : DynamicDataModelConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeGenericDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeGenericDataModelConformanceTests.cs deleted file mode 100644 index d18cdb99b38f..000000000000 --- a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeGenericDataModelConformanceTests.cs +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using PineconeIntegrationTests.Support; -using VectorDataSpecificationTests.CRUD; -using Xunit; - -namespace PineconeIntegrationTests.CRUD; - -public class PineconeGenericDataModelConformanceTests(PineconeGenericDataModelFixture fixture) - : GenericDataModelConformanceTests(fixture), IClassFixture -{ -} diff --git a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeNoDataConformanceTests.cs new file mode 100644 index 000000000000..d987ebc26907 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/CRUD/PineconeNoDataConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PineconeIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace PineconeIntegrationTests.CRUD; + +public class PineconeNoDataConformanceTests(PineconeNoDataConformanceTests.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => PineconeTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Filter/PineconeBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Filter/PineconeBasicFilterTests.cs index 095b0d03ebd0..b6a38a0ff09e 100644 --- a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Filter/PineconeBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Filter/PineconeBasicFilterTests.cs @@ -8,13 +8,17 @@ namespace PineconeIntegrationTests.Filter; +#pragma warning disable CS8605 // Unboxing a possibly null value. + public class PineconeBasicFilterTests(PineconeBasicFilterTests.Fixture fixture) : BasicFilterTests(fixture), IClassFixture { // Specialized Pinecone syntax for NOT over Contains ($nin) [ConditionalFact] public virtual Task Not_over_Contains() - => this.TestFilterAsync(r => !new[] { 8, 10 }.Contains(r.Int)); + => this.TestFilterAsync( + r => !new[] { 8, 10 }.Contains(r.Int), + r => !new[] { 8, 10 }.Contains((int)r["Int"])); #region Null checking @@ -65,6 +69,6 @@ public override Task Legacy_AnyTagEqualTo_List() public override TestStore TestStore => PineconeTestStore.Instance; // https://docs.pinecone.io/troubleshooting/restrictions-on-index-names - protected override string CollectionName => "filter-tests"; + public override string CollectionName => "filter-tests"; } } diff --git a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Filter/PineconeBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Filter/PineconeBasicQueryTests.cs new file mode 100644 index 000000000000..2ebf1e17b451 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Filter/PineconeBasicQueryTests.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PineconeIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace PineconeIntegrationTests.Filter; + +#pragma warning disable CS8605 // Unboxing a possibly null value. + +public class PineconeBasicQueryTests(PineconeBasicQueryTests.Fixture fixture) + : BasicQueryTests(fixture), IClassFixture +{ + // Specialized Pinecone syntax for NOT over Contains ($nin) + [ConditionalFact] + public virtual Task Not_over_Contains() + => this.TestFilterAsync( + r => !new[] { 8, 10 }.Contains(r.Int), + r => !new[] { 8, 10 }.Contains((int)r["Int"])); + + // Pinecone currently doesn't support null checking ({ "Foo" : null }) in vector search pre-filters + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + // Pinecone currently doesn't support NOT in vector search pre-filters + // (https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter) + public override Task Not_over_And() + => Assert.ThrowsAsync(() => base.Not_over_And()); + + public override Task Not_over_Or() + => Assert.ThrowsAsync(() => base.Not_over_Or()); + + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => PineconeTestStore.Instance; + + // https://docs.pinecone.io/troubleshooting/restrictions-on-index-names + public override string CollectionName => "query-tests"; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/PineconeEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/PineconeEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..642260212a2d --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/PineconeEmbeddingGenerationTests.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.Properties; +using Microsoft.SemanticKernel; +using PineconeIntegrationTests.Support; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace PineconeIntegrationTests; + +public class PineconeEmbeddingGenerationTests(PineconeEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + // Overriding since Pinecone requires collection names to only contain ASCII lowercase letters, digits and dashes. + public override async Task SearchAsync_without_generator_throws() + { + // The database doesn't support embedding generation, and no client-side generator has been configured at any level, + // so SearchAsync should throw. + var collection = fixture.GetCollection(fixture.TestStore.DefaultVectorStore, fixture.CollectionName + "-without-generator"); + + var exception = await Assert.ThrowsAsync(() => collection.SearchAsync("foo", top: 1).ToListAsync().AsTask()); + + Assert.Equal(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch, exception.Message); + } + + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => PineconeTestStore.Instance; + + // https://docs.pinecone.io/troubleshooting/restrictions-on-index-names + public override string CollectionName => "embedding-generation-tests"; + + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => PineconeTestStore.Instance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + services => services + .AddSingleton(PineconeTestStore.Instance.Client) + .AddPineconeVectorStore() + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services + .AddSingleton(PineconeTestStore.Instance.Client) + .AddPineconeVectorStoreRecordCollection(this.CollectionName) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/PineconeIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/PineconeIntegrationTests.csproj index bc92e1816858..3d1612762740 100644 --- a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/PineconeIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/PineconeIntegrationTests.csproj @@ -9,7 +9,7 @@ false true - $(NoWarn);CA2007,SKEXP0001,SKEXP0020,VSTHRD111;CS1685 + $(NoWarn);CA2007,SKEXP0001,VSTHRD111;CS1685 b7762d10-e29b-4bb1-8b74-b6d69a667dd4 diff --git a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeAllTypes.cs b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeAllTypes.cs index 54d98f72c251..be73146024a4 100644 --- a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeAllTypes.cs +++ b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeAllTypes.cs @@ -48,7 +48,7 @@ public record PineconeAllTypes() [VectorStoreRecordData] public List? NullableStringList { get; set; } - [VectorStoreRecordVector(Dimensions: 8, DistanceFunction: DistanceFunction.DotProductSimilarity)] + [VectorStoreRecordVector(Dimensions: 8, DistanceFunction = DistanceFunction.DotProductSimilarity)] public ReadOnlyMemory? Embedding { get; set; } internal void AssertEqual(PineconeAllTypes other) @@ -95,7 +95,7 @@ internal static VectorStoreRecordDefinition GetRecordDefinition() new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.NullableStringArray), typeof(string[])), new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.StringList), typeof(List)), new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.NullableStringList), typeof(List)), - new VectorStoreRecordVectorProperty(nameof(PineconeAllTypes.Embedding), typeof(ReadOnlyMemory?)) { Dimensions = 8, DistanceFunction = Microsoft.Extensions.VectorData.DistanceFunction.DotProductSimilarity } + new VectorStoreRecordVectorProperty(nameof(PineconeAllTypes.Embedding), typeof(ReadOnlyMemory?), 8) { DistanceFunction = Microsoft.Extensions.VectorData.DistanceFunction.DotProductSimilarity } ] }; } diff --git a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeGenericDataModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeDynamicDataModelFixture.cs similarity index 78% rename from dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeGenericDataModelFixture.cs rename to dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeDynamicDataModelFixture.cs index 91768966c9ff..5d3f5577fb56 100644 --- a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeGenericDataModelFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeDynamicDataModelFixture.cs @@ -4,7 +4,7 @@ namespace PineconeIntegrationTests.Support; -public class PineconeGenericDataModelFixture : GenericDataModelFixture +public class PineconeDynamicDataModelFixture : DynamicDataModelFixture { public override TestStore TestStore => PineconeTestStore.Instance; } diff --git a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeTestStore.cs b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeTestStore.cs index 40d3e221e777..f97a1d0488f2 100644 --- a/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/PineconeIntegrationTests/Support/PineconeTestStore.cs @@ -35,6 +35,9 @@ internal sealed class PineconeTestStore : TestStore public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + public PineconeVectorStore GetVectorStore(PineconeVectorStoreOptions options) + => new(this.Client, options); + // Pinecone does not support distance functions other than PGA which is always enabled. public override string DefaultIndexKind => ""; diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresDynamicDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresDynamicDataModelConformanceTests.cs new file mode 100644 index 000000000000..d23adfeb48cf --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresDynamicDataModelConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace PostgresIntegrationTests.CRUD; + +public class PostgresDynamicDataModelConformanceTests(PostgresDynamicDataModelFixture fixture) + : DynamicDataModelConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresGenericDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresGenericDataModelConformanceTests.cs deleted file mode 100644 index 98451084af94..000000000000 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresGenericDataModelConformanceTests.cs +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using PostgresIntegrationTests.Support; -using VectorDataSpecificationTests.CRUD; -using Xunit; - -namespace PostgresIntegrationTests.CRUD; - -public class PostgresGenericDataModelConformanceTests(PostgresGenericDataModelFixture fixture) - : GenericDataModelConformanceTests(fixture), IClassFixture -{ -} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresNoDataConformanceTests.cs new file mode 100644 index 000000000000..ad136ca85d23 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresNoDataConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace PostgresIntegrationTests.CRUD; + +public class PostgresNoDataConformanceTests(PostgresNoDataConformanceTests.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => PostgresTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresNoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresNoVectorConformanceTests.cs new file mode 100644 index 000000000000..93f6a0fff133 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresNoVectorConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace PostgresIntegrationTests.CRUD; + +public class PostgresNoVectorConformanceTests(PostgresNoVectorConformanceTests.Fixture fixture) + : NoVectorConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoVectorConformanceTests.Fixture + { + public override TestStore TestStore => PostgresTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicFilterTests.cs index 955d920cbde6..223fd51b6ffe 100644 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicFilterTests.cs @@ -8,6 +8,8 @@ namespace PostgresIntegrationTests.Filter; +#pragma warning disable CS0252 // Possible unintended reference comparison; left hand side needs cast + public class PostgresBasicFilterTests(PostgresBasicFilterTests.Fixture fixture) : BasicFilterTests(fixture), IClassFixture { @@ -15,18 +17,22 @@ public override async Task Not_over_Or() { // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. - await Assert.ThrowsAsync(() => base.Not_over_Or()); + await Assert.ThrowsAsync(() => base.Not_over_Or()); // Compensate by adding a null check: - await this.TestFilterAsync(r => r.String != null && !(r.Int == 8 || r.String == "foo")); + await this.TestFilterAsync( + r => r.String != null && !(r.Int == 8 || r.String == "foo"), + r => r["String"] != null && !((int)r["Int"]! == 8 || r["String"] == "foo")); } public override async Task NotEqual_with_string() { // As above, null semantics + negation - await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); - await this.TestFilterAsync(r => r.String != null && r.String != "foo"); + await this.TestFilterAsync( + r => r.String != null && r.String != "foo", + r => r["String"] != null && r["String"] != "foo"); } [Obsolete("Legacy filter support")] diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicQueryTests.cs new file mode 100644 index 000000000000..a3dcaf2295ca --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicQueryTests.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using Xunit; +using Xunit.Sdk; + +namespace PostgresIntegrationTests.Filter; + +#pragma warning disable CS0252 // Possible unintended reference comparison; left hand side needs cast + +public class PostgresBasicQueryTests(PostgresBasicQueryTests.Fixture fixture) + : BasicQueryTests(fixture), IClassFixture +{ + public override async Task Not_over_Or() + { + // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) + // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. + await Assert.ThrowsAsync(() => base.Not_over_Or()); + + // Compensate by adding a null check: + await this.TestFilterAsync( + r => r.String != null && !(r.Int == 8 || r.String == "foo"), + r => r["String"] != null && !((int)r["Int"]! == 8 || r["String"] == "foo")); + } + + public override async Task NotEqual_with_string() + { + // As above, null semantics + negation + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + + await this.TestFilterAsync( + r => r.String != null && r.String != "foo", + r => r["String"] != null && r["String"] != "foo"); + } + + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => PostgresTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..a28e06e425b1 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresEmbeddingGenerationTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace PostgresIntegrationTests; + +public class PostgresEmbeddingGenerationTests(PostgresEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => PostgresTestStore.Instance; + + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => PostgresTestStore.Instance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + services => services + .AddSingleton(PostgresTestStore.Instance.DataSource) + .AddPostgresVectorStore() + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services + .AddSingleton(PostgresTestStore.Instance.DataSource) + .AddPostgresVectorStoreRecordCollection(this.CollectionName) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj index 0a039793dc49..d3eb1b72853c 100644 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj @@ -17,6 +17,9 @@ + + + diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresGenericDataModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresDynamicDataModelFixture.cs similarity index 78% rename from dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresGenericDataModelFixture.cs rename to dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresDynamicDataModelFixture.cs index c5b9a96b405f..dc1698a280e1 100644 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresGenericDataModelFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresDynamicDataModelFixture.cs @@ -4,7 +4,7 @@ namespace PostgresIntegrationTests.Support; -public class PostgresGenericDataModelFixture : GenericDataModelFixture +public class PostgresDynamicDataModelFixture : DynamicDataModelFixture { public override TestStore TestStore => PostgresTestStore.Instance; } diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs index 1d4c540c216a..7289d91e6bb1 100644 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs @@ -8,8 +8,6 @@ namespace PostgresIntegrationTests.Support; -#pragma warning disable SKEXP0020 - internal sealed class PostgresTestStore : TestStore { public static PostgresTestStore Instance { get; } = new(); diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantBatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantBatchConformanceTests.cs new file mode 100644 index 000000000000..ccabbb8697f4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantBatchConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using QdrantIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace QdrantIntegrationTests.CRUD; + +public class QdrantBatchConformanceTests_NamedVectors(QdrantNamedSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture +{ +} + +public class QdrantBatchConformanceTests_UnnamedVector(QdrantUnnamedSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantDynamicDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantDynamicDataModelConformanceTests.cs new file mode 100644 index 000000000000..37db1c56e0cc --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantDynamicDataModelConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using QdrantIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace QdrantIntegrationTests.CRUD; + +public class QdrantDynamicDataModelConformanceTests_NamedVectors(QdrantNamedDynamicDataModelFixture fixture) + : DynamicDataModelConformanceTests(fixture), IClassFixture +{ +} + +public class QdrantDynamicDataModelConformanceTests_UnnamedVector(QdrantUnnamedDynamicDataModelFixture fixture) + : DynamicDataModelConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantNoDataConformanceTests.cs new file mode 100644 index 000000000000..61437f987b83 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantNoDataConformanceTests.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. + +using QdrantIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace QdrantIntegrationTests.CRUD; + +public class QdrantNoDataConformanceTests_NamedVectors(QdrantNoDataConformanceTests_NamedVectors.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => QdrantTestStore.NamedVectorsInstance; + } +} + +public class QdrantNoDataConformanceTests_UnnamedVectors(QdrantNoDataConformanceTests_UnnamedVectors.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => QdrantTestStore.UnnamedVectorInstance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantRecordConformanceTests.cs new file mode 100644 index 000000000000..210b980fc6b7 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/CRUD/QdrantRecordConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using QdrantIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace QdrantIntegrationTests.CRUD; + +public class QdrantRecordConformanceTests_NamedVectors(QdrantNamedSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture +{ +} + +public class QdrantRecordConformanceTests_UnnamedVectors(QdrantUnnamedSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Collections/QdrantCollectionConformanceTests_UnnamedVector.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Collections/QdrantCollectionConformanceTests.cs similarity index 67% rename from dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Collections/QdrantCollectionConformanceTests_UnnamedVector.cs rename to dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Collections/QdrantCollectionConformanceTests.cs index 5471d83c8996..518331a721f2 100644 --- a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Collections/QdrantCollectionConformanceTests_UnnamedVector.cs +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Collections/QdrantCollectionConformanceTests.cs @@ -6,6 +6,11 @@ namespace QdrantIntegrationTests.Collections; +public class QdrantCollectionConformanceTests_NamedVectors(QdrantNamedVectorsFixture fixture) + : CollectionConformanceTests(fixture), IClassFixture +{ +} + public class QdrantCollectionConformanceTests_UnnamedVector(QdrantUnnamedVectorFixture fixture) : CollectionConformanceTests(fixture), IClassFixture { diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Collections/QdrantCollectionConformanceTests_NamedVectors.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Collections/QdrantCollectionConformanceTests_NamedVectors.cs deleted file mode 100644 index 7f4d0f138907..000000000000 --- a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Collections/QdrantCollectionConformanceTests_NamedVectors.cs +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using QdrantIntegrationTests.Support; -using VectorDataSpecificationTests.Collections; -using Xunit; - -namespace QdrantIntegrationTests.Collections; - -public class QdrantCollectionConformanceTests_NamedVectors(QdrantNamedVectorsFixture fixture) - : CollectionConformanceTests(fixture), IClassFixture -{ -} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicFilterTests.cs index 2ba3b454231b..bc2e23af9688 100644 --- a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicFilterTests.cs @@ -13,8 +13,5 @@ public class QdrantBasicFilterTests(QdrantBasicFilterTests.Fixture fixture) public new class Fixture : BasicFilterTests.Fixture { public override TestStore TestStore => QdrantTestStore.NamedVectorsInstance; - - // Qdrant doesn't support the default Flat index kind - protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; } } diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicQueryTests.cs new file mode 100644 index 000000000000..bb6d77864f31 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicQueryTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using QdrantIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace QdrantIntegrationTests.Filter; + +public class QdrantBasicQueryTests(QdrantBasicQueryTests.Fixture fixture) + : BasicQueryTests(fixture), IClassFixture +{ + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => QdrantTestStore.NamedVectorsInstance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/HybridSearch/QdrantKeywordVectorizedHybridSearchTests_NamedVectors.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/HybridSearch/QdrantKeywordVectorizedHybridSearchTests_NamedVectors.cs index 86a878167626..dc48a3916d82 100644 --- a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/HybridSearch/QdrantKeywordVectorizedHybridSearchTests_NamedVectors.cs +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/HybridSearch/QdrantKeywordVectorizedHybridSearchTests_NamedVectors.cs @@ -17,16 +17,10 @@ public class QdrantKeywordVectorizedHybridSearchTests_NamedVectors( public new class VectorAndStringFixture : KeywordVectorizedHybridSearchComplianceTests.VectorAndStringFixture { public override TestStore TestStore => QdrantTestStore.NamedVectorsInstance; - - // Qdrant doesn't support the default Flat index kind - protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; } public new class MultiTextFixture : KeywordVectorizedHybridSearchComplianceTests.MultiTextFixture { public override TestStore TestStore => QdrantTestStore.NamedVectorsInstance; - - // Qdrant doesn't support the default Flat index kind - protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; } } diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/HybridSearch/QdrantKeywordVectorizedHybridSearchTests_UnnamedVectors.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/HybridSearch/QdrantKeywordVectorizedHybridSearchTests_UnnamedVectors.cs index e9492cd7ef21..4d3ff6f4b320 100644 --- a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/HybridSearch/QdrantKeywordVectorizedHybridSearchTests_UnnamedVectors.cs +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/HybridSearch/QdrantKeywordVectorizedHybridSearchTests_UnnamedVectors.cs @@ -17,16 +17,10 @@ public class QdrantKeywordVectorizedHybridSearchTests_UnnamedVectors( public new class VectorAndStringFixture : KeywordVectorizedHybridSearchComplianceTests.VectorAndStringFixture { public override TestStore TestStore => QdrantTestStore.UnnamedVectorInstance; - - // Qdrant doesn't support the default Flat index kind - protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; } public new class MultiTextFixture : KeywordVectorizedHybridSearchComplianceTests.MultiTextFixture { public override TestStore TestStore => QdrantTestStore.UnnamedVectorInstance; - - // Qdrant doesn't support the default Flat index kind - protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; } } diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..841bb1f06f2d --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantEmbeddingGenerationTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using QdrantIntegrationTests.Support; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace QdrantIntegrationTests; + +public class QdrantEmbeddingGenerationTests(QdrantEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => QdrantTestStore.UnnamedVectorInstance; + + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => QdrantTestStore.UnnamedVectorInstance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + services => services + .AddSingleton(QdrantTestStore.UnnamedVectorInstance.Client) + .AddQdrantVectorStore() + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services + .AddSingleton(QdrantTestStore.UnnamedVectorInstance.Client) + .AddQdrantVectorStoreRecordCollection(this.CollectionName) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantNamedDynamicDataModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantNamedDynamicDataModelFixture.cs new file mode 100644 index 000000000000..851f76450d1b --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantNamedDynamicDataModelFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace QdrantIntegrationTests.Support; + +public class QdrantNamedDynamicDataModelFixture : DynamicDataModelFixture +{ + public override TestStore TestStore => QdrantTestStore.NamedVectorsInstance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantNamedSimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantNamedSimpleModelFixture.cs new file mode 100644 index 000000000000..77a241f94b4e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantNamedSimpleModelFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace QdrantIntegrationTests.Support; + +public class QdrantNamedSimpleModelFixture : SimpleModelFixture +{ + public override TestStore TestStore => QdrantTestStore.NamedVectorsInstance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs index c7148c291de4..0d48a54df7ed 100644 --- a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs @@ -15,6 +15,9 @@ internal sealed class QdrantTestStore : TestStore public static QdrantTestStore NamedVectorsInstance { get; } = new(hasNamedVectors: true); public static QdrantTestStore UnnamedVectorInstance { get; } = new(hasNamedVectors: false); + // Qdrant doesn't support the default Flat index kind + public override string DefaultIndexKind => IndexKind.Hnsw; + private readonly QdrantContainer _container = new QdrantBuilder().Build(); private readonly bool _hasNamedVectors; private QdrantClient? _client; @@ -29,6 +32,14 @@ public QdrantVectorStore GetVectorStore(QdrantVectorStoreOptions options) private QdrantTestStore(bool hasNamedVectors) => this._hasNamedVectors = hasNamedVectors; + /// + /// Qdrant normalizes vectors on upsert, so we cannot compare + /// what we upserted and what we retrieve, we can only check + /// that a vector was returned. + /// https://github.com/qdrant/qdrant-client/discussions/727 + /// + public override bool VectorsComparable => false; + protected override async Task StartAsync() { await this._container.StartAsync(); diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantUnnamedDynamicDataModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantUnnamedDynamicDataModelFixture.cs new file mode 100644 index 000000000000..6771c223356e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantUnnamedDynamicDataModelFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace QdrantIntegrationTests.Support; + +public class QdrantUnnamedDynamicDataModelFixture : DynamicDataModelFixture +{ + public override TestStore TestStore => QdrantTestStore.UnnamedVectorInstance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantUnnamedSimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantUnnamedSimpleModelFixture.cs new file mode 100644 index 000000000000..64159ea93901 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantUnnamedSimpleModelFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace QdrantIntegrationTests.Support; + +public class QdrantUnnamedSimpleModelFixture : SimpleModelFixture +{ + public override TestStore TestStore => QdrantTestStore.UnnamedVectorInstance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisGenericDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisGenericDataModelConformanceTests.cs deleted file mode 100644 index 8806430fb9a0..000000000000 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisGenericDataModelConformanceTests.cs +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using RedisIntegrationTests.Support; -using VectorDataSpecificationTests.CRUD; -using Xunit; - -namespace RedisIntegrationTests.CRUD; - -public class RedisGenericDataModelConformanceTests(RedisGenericDataModelFixture fixture) - : GenericDataModelConformanceTests(fixture), IClassFixture -{ -} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetDynamicDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetDynamicDataModelConformanceTests.cs new file mode 100644 index 000000000000..b323b02624b0 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetDynamicDataModelConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace RedisIntegrationTests.CRUD; + +public class RedisHashSetDynamicDataModelConformanceTests(RedisHashSetDynamicDataModelFixture fixture) + : DynamicDataModelConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetNoDataConformanceTests.cs new file mode 100644 index 000000000000..35be99c1e93f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetNoDataConformanceTests.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft. All rights reserved. + +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace RedisIntegrationTests.CRUD; + +public class RedisHashSetNoDataConformanceTests(RedisHashSetNoDataConformanceTests.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + [ConditionalFact] + public override async Task GetAsyncReturnsInsertedRecord_WithoutVectors() + { + var expectedRecord = fixture.TestData[0]; + + // When using HashSets there is no way to distinguish between no fields being returned and + // the record not existing. + Assert.Null(await fixture.Collection.GetAsync(expectedRecord.Id, new() { IncludeVectors = false })); + } + + [ConditionalFact(Skip = "When using HashSets there is no way to distinguish between no fields being returned and the record not existing so this test isn't useful.")] + public override Task UpsertAsyncCanInsertNewRecord_WithoutVectors() + { + return base.UpsertAsyncCanInsertNewRecord_WithoutVectors(); + } + + [ConditionalFact(Skip = "When using HashSets there is no way to distinguish between no fields being returned and the record not existing so this test isn't useful.")] + public override Task UpsertAsyncCanUpdateExistingRecord_WithoutVectors() + { + return base.UpsertAsyncCanUpdateExistingRecord_WithoutVectors(); + } + + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => RedisTestStore.HashSetInstance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetNoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetNoVectorConformanceTests.cs new file mode 100644 index 000000000000..dddd33af25e9 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetNoVectorConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace RedisIntegrationTests.CRUD; + +public class RedisHashSetNoVectorConformanceTests(RedisHashSetNoVectorConformanceTests.Fixture fixture) + : NoVectorConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoVectorConformanceTests.Fixture + { + public override TestStore TestStore => RedisTestStore.HashSetInstance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetRecordConformanceTests.cs new file mode 100644 index 000000000000..157af94969c1 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisHashSetRecordConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace RedisIntegrationTests.CRUD; + +public class RedisHashSetRecordConformanceTests(RedisHashSetSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonDynamicDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonDynamicDataModelConformanceTests.cs new file mode 100644 index 000000000000..e91579a4e053 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonDynamicDataModelConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace RedisIntegrationTests.CRUD; + +public class RedisJsonDynamicDataModelConformanceTests(RedisJsonDynamicDataModelFixture fixture) + : DynamicDataModelConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonNoDataConformanceTests.cs new file mode 100644 index 000000000000..255e093efad1 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonNoDataConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace RedisIntegrationTests.CRUD; + +public class RedisJsonNoDataConformanceTests(RedisJsonNoDataConformanceTests.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => RedisTestStore.JsonInstance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonNoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonNoVectorConformanceTests.cs new file mode 100644 index 000000000000..3d3828244bd9 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonNoVectorConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace RedisIntegrationTests.CRUD; + +public class RedisJsonNoVectorConformanceTests(RedisJsonNoVectorConformanceTests.Fixture fixture) + : NoVectorConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoVectorConformanceTests.Fixture + { + public override TestStore TestStore => RedisTestStore.JsonInstance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonRecordConformanceTests.cs similarity index 70% rename from dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisRecordConformanceTests.cs rename to dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonRecordConformanceTests.cs index cab8188524fd..138cec84071d 100644 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisRecordConformanceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/CRUD/RedisJsonRecordConformanceTests.cs @@ -6,7 +6,7 @@ namespace RedisIntegrationTests.CRUD; -public class RedisRecordConformanceTests(RedisSimpleModelFixture fixture) - : RecordConformanceTests(fixture), IClassFixture +public class RedisJsonRecordConformanceTests(RedisJsonSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture { } diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisCollectionConformanceTests_HashSet.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisHashSetCollectionConformanceTests.cs similarity index 77% rename from dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisCollectionConformanceTests_HashSet.cs rename to dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisHashSetCollectionConformanceTests.cs index a3b7c411d8e4..0ebda7416993 100644 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisCollectionConformanceTests_HashSet.cs +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisHashSetCollectionConformanceTests.cs @@ -6,7 +6,7 @@ namespace RedisIntegrationTests.Collections; -public class RedisCollectionConformanceTests_HashSet(RedisHashSetFixture fixture) +public class RedisHashSetCollectionConformanceTests(RedisHashSetFixture fixture) : CollectionConformanceTests(fixture), IClassFixture { } diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisCollectionConformanceTests_Json.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisJsonCollectionConformanceTests.cs similarity index 78% rename from dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisCollectionConformanceTests_Json.cs rename to dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisJsonCollectionConformanceTests.cs index 97d28ef6d17e..ed3fdc36db0b 100644 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisCollectionConformanceTests_Json.cs +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Collections/RedisJsonCollectionConformanceTests.cs @@ -6,7 +6,7 @@ namespace RedisIntegrationTests.Collections; -public class RedisCollectionConformanceTests_Json(RedisJsonFixture fixture) +public class RedisJsonCollectionConformanceTests(RedisJsonFixture fixture) : CollectionConformanceTests(fixture), IClassFixture { } diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs index 7ec4f834a5f0..437048200f62 100644 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs @@ -72,17 +72,17 @@ public class RedisJsonCollectionBasicFilterTests(RedisJsonCollectionBasicFilterT { public override TestStore TestStore => RedisTestStore.JsonInstance; - protected override string CollectionName => "JsonCollectionFilterTests"; + public override string CollectionName => "JsonCollectionFilterTests"; // Override to remove the bool property, which isn't (currently) supported on Redis/JSON - protected override VectorStoreRecordDefinition GetRecordDefinition() + public override VectorStoreRecordDefinition GetRecordDefinition() => new() { Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(bool)).ToList() }; - protected override IVectorStoreRecordCollection CreateCollection() - => new RedisJsonVectorStoreRecordCollection( + protected override IVectorStoreRecordCollection GetCollection() + => new RedisJsonVectorStoreRecordCollection( RedisTestStore.JsonInstance.Database, this.CollectionName, new() { VectorStoreRecordDefinition = this.GetRecordDefinition() }); @@ -124,10 +124,10 @@ public override Task Legacy_AnyTagEqualTo_List() { public override TestStore TestStore => RedisTestStore.HashSetInstance; - protected override string CollectionName => "HashSetCollectionFilterTests"; + public override string CollectionName => "HashSetCollectionFilterTests"; // Override to remove the bool property, which isn't (currently) supported on Redis - protected override VectorStoreRecordDefinition GetRecordDefinition() + public override VectorStoreRecordDefinition GetRecordDefinition() => new() { Properties = base.GetRecordDefinition().Properties.Where(p => @@ -136,8 +136,8 @@ protected override VectorStoreRecordDefinition GetRecordDefinition() p.PropertyType != typeof(List)).ToList() }; - protected override IVectorStoreRecordCollection CreateCollection() - => new RedisHashSetVectorStoreRecordCollection( + protected override IVectorStoreRecordCollection GetCollection() + => new RedisHashSetVectorStoreRecordCollection( RedisTestStore.HashSetInstance.Database, this.CollectionName, new() { VectorStoreRecordDefinition = this.GetRecordDefinition() }); diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicQueryTests.cs new file mode 100644 index 000000000000..a75cb113a3d4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicQueryTests.cs @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Redis; +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using Xunit; +using Xunit.Sdk; + +namespace RedisIntegrationTests.Filter; + +public abstract class RedisBasicQueryTests(BasicQueryTests.QueryFixture fixture) + : BasicQueryTests(fixture) +{ + #region Equality with null + + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + #endregion + + #region Bool + + public override Task Bool() + => Assert.ThrowsAsync(() => base.Bool()); + + public override Task Not_over_bool() + => Assert.ThrowsAsync(() => base.Not_over_bool()); + + public override Task Bool_And_Bool() + => Assert.ThrowsAsync(() => base.Bool_And_Bool()); + + public override Task Bool_Or_Not_Bool() + => Assert.ThrowsAsync(() => base.Bool_Or_Not_Bool()); + + public override Task Not_over_bool_And_Comparison() + => Assert.ThrowsAsync(() => base.Not_over_bool_And_Comparison()); + + #endregion + + #region Contains + + public override Task Contains_over_inline_int_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); + + public override Task Contains_over_inline_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_string_array()); + + public override Task Contains_over_inline_string_array_with_weird_chars() + => Assert.ThrowsAsync(() => base.Contains_over_inline_string_array_with_weird_chars()); + + public override Task Contains_over_captured_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_captured_string_array()); + + #endregion +} + +public class RedisJsonCollectionBasicQueryTests(RedisJsonCollectionBasicQueryTests.Fixture fixture) + : RedisBasicQueryTests(fixture), IClassFixture +{ + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => RedisTestStore.JsonInstance; + + public override string CollectionName => "JsonCollectionQueryTests"; + + // Override to remove the bool property, which isn't (currently) supported on Redis/JSON + public override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(bool)).ToList() + }; + + protected override IVectorStoreRecordCollection GetCollection() + => new RedisJsonVectorStoreRecordCollection( + RedisTestStore.JsonInstance.Database, + this.CollectionName, + new() { VectorStoreRecordDefinition = this.GetRecordDefinition() }); + } +} + +public class RedisHashSetCollectionBasicQueryTests(RedisHashSetCollectionBasicQueryTests.Fixture fixture) + : RedisBasicQueryTests(fixture), IClassFixture +{ + // Null values are not supported in Redis HashSet + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + // Array fields not supported on Redis HashSet + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => RedisTestStore.HashSetInstance; + + public override string CollectionName => "HashSetCollectionQueryTests"; + + // Override to remove the bool property, which isn't (currently) supported on Redis + public override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => + p.PropertyType != typeof(bool) && + p.PropertyType != typeof(string[]) && + p.PropertyType != typeof(List)).ToList() + }; + + protected override IVectorStoreRecordCollection GetCollection() + => new RedisHashSetVectorStoreRecordCollection( + RedisTestStore.HashSetInstance.Database, + this.CollectionName, + new() { VectorStoreRecordDefinition = this.GetRecordDefinition() }); + + protected override List BuildTestData() + { + var testData = base.BuildTestData(); + + foreach (var record in testData) + { + // Null values are not supported in Redis hashsets + record.String ??= string.Empty; + } + + return testData; + } + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisHashSetEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisHashSetEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..924e26ab4866 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisHashSetEmbeddingGenerationTests.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Redis; +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace RedisIntegrationTests; + +public class RedisHashSetEmbeddingGenerationTests(RedisHashSetEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => RedisTestStore.HashSetInstance; + + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => RedisTestStore.HashSetInstance.GetVectorStore(new() { StorageType = RedisStorageType.HashSet, EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + // TODO: This doesn't work because if a RedisVectorStoreOptions is provided (and it needs to be for HashSet), the embedding generator + // isn't looked up in DI. The options are also immutable so we can't inject an embedding generator into them. + // services => services + // .AddSingleton(RedisTestStore.HashSetInstance.Database) + // .AddRedisVectorStore(new RedisVectorStoreOptions() { StorageType = RedisStorageType.HashSet}) + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services + .AddSingleton(RedisTestStore.HashSetInstance.Database) + .AddRedisHashSetVectorStoreRecordCollection(this.CollectionName) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisJsonEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisJsonEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..af52948ad43b --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisJsonEmbeddingGenerationTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace RedisIntegrationTests; + +public class RedisJsonEmbeddingGenerationTests(RedisJsonEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => RedisTestStore.JsonInstance; + + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => RedisTestStore.JsonInstance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + services => services + .AddSingleton(RedisTestStore.JsonInstance.Database) + .AddRedisVectorStore() + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services + .AddSingleton(RedisTestStore.JsonInstance.Database) + .AddRedisJsonVectorStoreRecordCollection(this.CollectionName) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisHashSetDynamicDataModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisHashSetDynamicDataModelFixture.cs new file mode 100644 index 000000000000..c495c111f122 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisHashSetDynamicDataModelFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace RedisIntegrationTests.Support; + +public class RedisHashSetDynamicDataModelFixture : DynamicDataModelFixture +{ + public override TestStore TestStore => RedisTestStore.HashSetInstance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisHashSetSimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisHashSetSimpleModelFixture.cs new file mode 100644 index 000000000000..b8bcf29ead8e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisHashSetSimpleModelFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace RedisIntegrationTests.Support; + +public class RedisHashSetSimpleModelFixture : SimpleModelFixture +{ + public override TestStore TestStore => RedisTestStore.HashSetInstance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisGenericDataModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisJsonDynamicDataModelFixture.cs similarity index 72% rename from dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisGenericDataModelFixture.cs rename to dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisJsonDynamicDataModelFixture.cs index 3a63d1d77f76..1023cab93ce0 100644 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisGenericDataModelFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisJsonDynamicDataModelFixture.cs @@ -4,7 +4,7 @@ namespace RedisIntegrationTests.Support; -public class RedisGenericDataModelFixture : GenericDataModelFixture +public class RedisJsonDynamicDataModelFixture : DynamicDataModelFixture { public override TestStore TestStore => RedisTestStore.JsonInstance; } diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisSimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisJsonSimpleModelFixture.cs similarity index 75% rename from dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisSimpleModelFixture.cs rename to dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisJsonSimpleModelFixture.cs index f91aefd9055c..480e00aad6df 100644 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisSimpleModelFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisJsonSimpleModelFixture.cs @@ -4,7 +4,7 @@ namespace RedisIntegrationTests.Support; -public class RedisSimpleModelFixture : SimpleModelFixture +public class RedisJsonSimpleModelFixture : SimpleModelFixture { public override TestStore TestStore => RedisTestStore.JsonInstance; } diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs index 6ee6f058da46..5744dd6e53c4 100644 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs @@ -15,6 +15,8 @@ internal sealed class RedisTestStore : TestStore private readonly RedisContainer _container = new RedisBuilder() .WithImage("redis/redis-stack") + .WithPortBinding(6379, assignRandomHostPort: true) + .WithPortBinding(8001, assignRandomHostPort: true) .Build(); private readonly RedisStorageType _storageType; diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBatchConformanceTests.cs index c2d71d49281b..bb2536fafd5e 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBatchConformanceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBatchConformanceTests.cs @@ -25,50 +25,54 @@ public Task CanSplitBatchToAccountForMaxParameterLimit_WithoutVectors() private async Task CanSplitBatchToAccountForMaxParameterLimit(bool includeVectors) { var collection = fixture.Collection; - SimpleModel[] inserted = Enumerable.Range(0, SqlServerMaxParameters + 1).Select(i => new SimpleModel() + SimpleRecord[] inserted = Enumerable.Range(0, SqlServerMaxParameters + 1).Select(i => new SimpleRecord() { Id = fixture.GenerateNextKey(), Number = 100 + i, Text = i.ToString(), - Floats = Enumerable.Range(0, SimpleModel.DimensionCount).Select(j => (float)(i + j)).ToArray() + Floats = Enumerable.Range(0, SimpleRecord.DimensionCount).Select(j => (float)(i + j)).ToArray() }).ToArray(); var keys = inserted.Select(record => record.Id).ToArray(); - Assert.Empty(await collection.GetBatchAsync(keys).ToArrayAsync()); - var receivedKeys = await collection.UpsertBatchAsync(inserted).ToArrayAsync(); + Assert.Empty(await collection.GetAsync(keys).ToArrayAsync()); + var receivedKeys = await collection.UpsertAsync(inserted); Assert.Equal(keys.ToHashSet(), receivedKeys.ToHashSet()); // .ToHashSet() to ignore order - var received = await collection.GetBatchAsync(keys, new() { IncludeVectors = includeVectors }).ToArrayAsync(); + var received = await collection.GetAsync(keys, new() { IncludeVectors = includeVectors }).ToArrayAsync(); foreach (var record in inserted) { - record.AssertEqual(this.GetRecord(received, record.Id), includeVectors); + record.AssertEqual(this.GetRecord(received, record.Id), includeVectors, fixture.TestStore.VectorsComparable); } - await collection.DeleteBatchAsync(keys); - Assert.Empty(await collection.GetBatchAsync(keys).ToArrayAsync()); + await collection.DeleteAsync(keys); + Assert.Empty(await collection.GetAsync(keys).ToArrayAsync()); } [ConditionalFact] public async Task UpsertBatchIsAtomic() { var collection = fixture.Collection; - SimpleModel[] inserted = Enumerable.Range(0, SqlServerMaxParameters + 1).Select(i => new SimpleModel() + SimpleRecord[] inserted = Enumerable.Range(0, SqlServerMaxParameters + 1).Select(i => new SimpleRecord() { // The last Id is set to NULL, so it must not be inserted and the whole batch should fail Id = i < SqlServerMaxParameters ? fixture.GenerateNextKey() : null!, Number = 100 + i, Text = i.ToString(), - Floats = Enumerable.Range(0, SimpleModel.DimensionCount).Select(j => (float)(i + j)).ToArray() + Floats = Enumerable.Range(0, SimpleRecord.DimensionCount).Select(j => (float)(i + j)).ToArray() }).ToArray(); var keys = inserted.Select(record => record.Id).Where(key => key is not null).ToArray(); - Assert.Empty(await collection.GetBatchAsync(keys).ToArrayAsync()); + Assert.Empty(await collection.GetAsync(keys).ToArrayAsync()); - VectorStoreOperationException ex = await Assert.ThrowsAsync(() => collection.UpsertBatchAsync(inserted).ToArrayAsync().AsTask()); + VectorStoreOperationException ex = await Assert.ThrowsAsync(() => collection.UpsertAsync(inserted)); Assert.Equal("UpsertBatch", ex.OperationName); - Assert.Equal(collection.CollectionName, ex.CollectionName); + + var metadata = collection.GetService(typeof(VectorStoreRecordCollectionMetadata)) as VectorStoreRecordCollectionMetadata; + + Assert.NotNull(metadata?.CollectionName); + Assert.Equal(metadata.CollectionName, ex.CollectionName); // Make sure that no records were inserted! - Assert.Empty(await collection.GetBatchAsync(keys).ToArrayAsync()); + Assert.Empty(await collection.GetAsync(keys).ToArrayAsync()); } } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerDynamicDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerDynamicDataModelConformanceTests.cs new file mode 100644 index 000000000000..3a9e9c3e0672 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerDynamicDataModelConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace SqlServerIntegrationTests.CRUD; + +public class SqlServerDynamicDataModelConformanceTests(SqlServerDynamicDataModelFixture fixture) + : DynamicDataModelConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerGenericDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerGenericDataModelConformanceTests.cs deleted file mode 100644 index d3f67389e764..000000000000 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerGenericDataModelConformanceTests.cs +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using SqlServerIntegrationTests.Support; -using VectorDataSpecificationTests.CRUD; -using Xunit; - -namespace SqlServerIntegrationTests.CRUD; - -public class SqlServerGenericDataModelConformanceTests(SqlServerGenericDataModelFixture fixture) - : GenericDataModelConformanceTests(fixture), IClassFixture -{ -} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerNoDataConformanceTests.cs new file mode 100644 index 000000000000..e3303861968a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerNoDataConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace SqlServerIntegrationTests.CRUD; + +public class SqlServerNoDataConformanceTests(SqlServerNoDataConformanceTests.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerNoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerNoVectorConformanceTests.cs new file mode 100644 index 000000000000..7b70e75d7b70 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerNoVectorConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace SqlServerIntegrationTests.CRUD; + +public class SqlServerNoVectorConformanceTests(SqlServerNoVectorConformanceTests.Fixture fixture) + : NoVectorConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoVectorConformanceTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs index 3bae6cc48552..209b95a73d45 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs @@ -9,6 +9,8 @@ namespace SqlServerIntegrationTests.Filter; +#pragma warning disable CS0252 // Possible unintended reference comparison; left hand side needs cast + public class SqlServerBasicFilterTests(SqlServerBasicFilterTests.Fixture fixture) : BasicFilterTests(fixture), IClassFixture { @@ -16,18 +18,22 @@ public override async Task Not_over_Or() { // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. - await Assert.ThrowsAsync(() => base.Not_over_Or()); + await Assert.ThrowsAsync(() => base.Not_over_Or()); // Compensate by adding a null check: - await this.TestFilterAsync(r => r.String != null && !(r.Int == 8 || r.String == "foo")); + await this.TestFilterAsync( + r => r.String != null && !(r.Int == 8 || r.String == "foo"), + r => r["String"] != null && !((int)r["Int"]! == 8 || r["String"] == "foo")); } public override async Task NotEqual_with_string() { // As above, null semantics + negation - await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); - await this.TestFilterAsync(r => r.String != null && r.String != "foo"); + await this.TestFilterAsync( + r => r.String != null && r.String != "foo", + r => r["String"] != null && r["String"] != "foo"); } public override Task Contains_over_field_string_array() @@ -58,10 +64,10 @@ public override Task Contains_over_field_string_List() public override TestStore TestStore => SqlServerTestStore.Instance; - protected override string CollectionName => s_uniqueName; + public override string CollectionName => s_uniqueName; // Override to remove the string collection properties, which aren't (currently) supported on SqlServer - protected override VectorStoreRecordDefinition GetRecordDefinition() + public override VectorStoreRecordDefinition GetRecordDefinition() => new() { Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(string[]) && p.PropertyType != typeof(List)).ToList() diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicQueryTests.cs new file mode 100644 index 000000000000..519e8f0e40f6 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicQueryTests.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using Xunit; +using Xunit.Sdk; + +namespace SqlServerIntegrationTests.Filter; + +#pragma warning disable CS0252 // Possible unintended reference comparison; left hand side needs cast + +public class SqlServerBasicQueryTests(SqlServerBasicQueryTests.Fixture fixture) + : BasicQueryTests(fixture), IClassFixture +{ + public override async Task Not_over_Or() + { + // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) + // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. + await Assert.ThrowsAsync(() => base.Not_over_Or()); + + // Compensate by adding a null check: + await this.TestFilterAsync( + r => r.String != null && !(r.Int == 8 || r.String == "foo"), + r => r["String"] != null && !((int)r["Int"]! == 8 || r["String"] == "foo")); + } + + public override async Task NotEqual_with_string() + { + // As above, null semantics + negation + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + + await this.TestFilterAsync( + r => r.String != null && r.String != "foo", + r => r["String"] != null && r["String"] != "foo"); + } + + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + public new class Fixture : BasicQueryTests.QueryFixture + { + private static readonly string s_uniqueName = Guid.NewGuid().ToString(); + + public override TestStore TestStore => SqlServerTestStore.Instance; + + public override string CollectionName => s_uniqueName; + + // Override to remove the string collection properties, which aren't (currently) supported on SqlServer + public override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(string[]) && p.PropertyType != typeof(List)).ToList() + }; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index c8ed9a0cdda1..dda2cf5252fe 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -3,6 +3,7 @@ using System.Text; using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ConnectorSupport; using Microsoft.SemanticKernel.Connectors.SqlServer; using Xunit; @@ -34,7 +35,7 @@ public void AppendParameterName(string propertyName, string expectedPrefix) { StringBuilder builder = new(); StringBuilder expectedBuilder = new(); - VectorStoreRecordKeyProperty keyProperty = new(propertyName, typeof(string)); + VectorStoreRecordKeyPropertyModel keyProperty = new(propertyName, typeof(string)); int paramIndex = 0; // we need a dedicated variable to ensure that AppendParameterName increments the index for (int i = 0; i < 10; i++) @@ -107,23 +108,17 @@ FROM INFORMATION_SCHEMA.TABLES [InlineData(false)] public void CreateTable(bool ifNotExists) { - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); - VectorStoreRecordDataProperty[] dataProperties = - [ - new VectorStoreRecordDataProperty("simpleName", typeof(string)), - new VectorStoreRecordDataProperty("with space", typeof(int)) - ]; - VectorStoreRecordVectorProperty[] vectorProperties = - [ - new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) - { - Dimensions = 10 - } - ]; + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("simpleName", typeof(string)), + new VectorStoreRecordDataProperty("with space", typeof(int)) { IsIndexed = true }, + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory), 10) + ]); + using SqlConnection connection = CreateConnection(); - using SqlCommand command = SqlServerCommandBuilder.CreateTable(connection, "schema", "table", - ifNotExists, keyProperty, dataProperties, vectorProperties); + using SqlCommand command = SqlServerCommandBuilder.CreateTable(connection, "schema", "table", ifNotExists, model); string expectedCommand = """ @@ -135,6 +130,7 @@ [simpleName] NVARCHAR(MAX), [embedding] VECTOR(10), PRIMARY KEY ([id]) ); + CREATE INDEX index_table_withspace ON [schema].[table]([with space]); END; """; if (ifNotExists) @@ -142,27 +138,22 @@ PRIMARY KEY ([id]) expectedCommand = "IF OBJECT_ID(N'[schema].[table]', N'U') IS NULL" + Environment.NewLine + expectedCommand; } - AssertEqualIgnoreNewLines(expectedCommand, command.CommandText); + Assert.Equal(expectedCommand, command.CommandText, ignoreLineEndingDifferences: true); } [Fact] public void MergeIntoSingle() { - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); - VectorStoreRecordProperty[] properties = - [ - keyProperty, - new VectorStoreRecordDataProperty("simpleString", typeof(string)), - new VectorStoreRecordDataProperty("simpleInt", typeof(int)), - new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) - { - Dimensions = 10 - } - ]; + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("simpleString", typeof(string)), + new VectorStoreRecordDataProperty("simpleInt", typeof(int)), + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory), 10) + ]); using SqlConnection connection = CreateConnection(); - using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle(connection, "schema", "table", - keyProperty, properties, + using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle(connection, "schema", "table", model, new Dictionary { { "id", null }, @@ -184,7 +175,7 @@ WHEN NOT MATCHED THEN OUTPUT inserted.[id]; """"; - AssertEqualIgnoreNewLines(expectedCommand, command.CommandText); + Assert.Equal(expectedCommand, command.CommandText, ignoreLineEndingDifferences: true); Assert.Equal("@id_0", command.Parameters[0].ParameterName); Assert.Equal(DBNull.Value, command.Parameters[0].Value); Assert.Equal("@simpleString_1", command.Parameters[1].ParameterName); @@ -198,17 +189,14 @@ WHEN NOT MATCHED THEN [Fact] public void MergeIntoMany() { - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); - VectorStoreRecordProperty[] properties = - [ - keyProperty, - new VectorStoreRecordDataProperty("simpleString", typeof(string)), - new VectorStoreRecordDataProperty("simpleInt", typeof(int)), - new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) - { - Dimensions = 10 - } - ]; + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("simpleString", typeof(string)), + new VectorStoreRecordDataProperty("simpleInt", typeof(int)), + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory), 10) + ]); + Dictionary[] records = [ new Dictionary @@ -230,8 +218,7 @@ public void MergeIntoMany() using SqlConnection connection = CreateConnection(); using SqlCommand command = connection.CreateCommand(); - Assert.True(SqlServerCommandBuilder.MergeIntoMany(command, "schema", "table", - keyProperty, properties, records)); + Assert.True(SqlServerCommandBuilder.MergeIntoMany(command, "schema", "table", model, records)); string expectedCommand = """" @@ -250,7 +237,7 @@ WHEN NOT MATCHED THEN SELECT KeyColumn FROM @InsertedKeys; """"; - AssertEqualIgnoreNewLines(expectedCommand, command.CommandText); + Assert.Equal(expectedCommand, command.CommandText, ignoreLineEndingDifferences: true); for (int i = 0; i < records.Length; i++) { @@ -268,7 +255,7 @@ WHEN NOT MATCHED THEN [Fact] public void DeleteSingle() { - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordKeyPropertyModel keyProperty = new("id", typeof(long)); using SqlConnection connection = CreateConnection(); using SqlCommand command = SqlServerCommandBuilder.DeleteSingle(connection, @@ -283,7 +270,7 @@ public void DeleteSingle() public void DeleteMany() { string[] keys = ["key1", "key2"]; - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(string)); + VectorStoreRecordKeyPropertyModel keyProperty = new("id", typeof(string)); using SqlConnection connection = CreateConnection(); using SqlCommand command = connection.CreateCommand(); @@ -300,27 +287,24 @@ public void DeleteMany() [Fact] public void SelectSingle() { - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); - VectorStoreRecordProperty[] properties = [ - keyProperty, - new VectorStoreRecordDataProperty("name", typeof(string)), - new VectorStoreRecordDataProperty("age", typeof(int)), - new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) - { - Dimensions = 10 - } - ]; + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("age", typeof(int)), + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory), 10) + ]); + using SqlConnection connection = CreateConnection(); - using SqlCommand command = SqlServerCommandBuilder.SelectSingle(connection, - "schema", "tableName", keyProperty, properties, 123L, includeVectors: true); + using SqlCommand command = SqlServerCommandBuilder.SelectSingle(connection, "schema", "tableName", model, 123L, includeVectors: true); - AssertEqualIgnoreNewLines( + Assert.Equal( """"" SELECT [id],[name],[age],[embedding] FROM [schema].[tableName] WHERE [id] = @id_0 - """"", command.CommandText); + """"", command.CommandText, ignoreLineEndingDifferences: true); Assert.Equal(123L, command.Parameters[0].Value); Assert.Equal("@id_0", command.Parameters[0].ParameterName); } @@ -328,29 +312,27 @@ FROM [schema].[tableName] [Fact] public void SelectMany() { - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); - VectorStoreRecordProperty[] properties = [ - keyProperty, + var model = BuildModel( + [ + new VectorStoreRecordKeyProperty("id", typeof(long)), new VectorStoreRecordDataProperty("name", typeof(string)), new VectorStoreRecordDataProperty("age", typeof(int)), - new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) - { - Dimensions = 10 - } - ]; + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory), 10) + ]); + long[] keys = [123L, 456L, 789L]; using SqlConnection connection = CreateConnection(); using SqlCommand command = connection.CreateCommand(); Assert.True(SqlServerCommandBuilder.SelectMany(command, - "schema", "tableName", keyProperty, properties, keys, includeVectors: true)); + "schema", "tableName", model, keys, includeVectors: true)); - AssertEqualIgnoreNewLines( + Assert.Equal( """"" SELECT [id],[name],[age],[embedding] FROM [schema].[tableName] WHERE [id] IN (@id_0,@id_1,@id_2) - """"", command.CommandText); + """"", command.CommandText, ignoreLineEndingDifferences: true); for (int i = 0; i < keys.Length; i++) { Assert.Equal(keys[i], command.Parameters[i].Value); @@ -358,13 +340,14 @@ WHERE [id] IN (@id_0,@id_1,@id_2) } } - // This repo is configured with eol=lf, so the expected string should always use \n - // as long given IDE does not use \r\n. - // The actual string may use \r\n, so we just normalize both. - private static void AssertEqualIgnoreNewLines(string expected, string actual) - => Assert.Equal(expected.Replace("\r\n", "\n"), actual.Replace("\r\n", "\n")); - // We create a connection using a fake connection string just to be able to create the SqlCommand. private static SqlConnection CreateConnection() => new("Server=localhost;Database=master;Integrated Security=True;"); + + private static VectorStoreRecordModel BuildModel(List properties) + => new VectorStoreRecordModelBuilder(SqlServerConstants.ModelBuildingOptions) + .Build( + typeof(Dictionary), + new() { Properties = properties }, + defaultEmbeddingGenerator: null); } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..ea72466716dd --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerEmbeddingGenerationTests.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace SqlServerIntegrationTests; + +public class SqlServerEmbeddingGenerationTests(SqlServerEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => SqlServerTestStore.Instance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + // TODO: Implement DI registration for SqlServer (https://github.com/microsoft/semantic-kernel/issues/10948) + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + // TODO: Implement DI registration for SqlServer (https://github.com/microsoft/semantic-kernel/issues/10948) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj index 4752d82818dc..e81066f2d839 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj @@ -8,8 +8,10 @@ false true - $(NoWarn);CA2007,SKEXP0001,SKEXP0020,VSTHRD111;CS1685 + $(NoWarn);CA2007,SKEXP0001,VSTHRD111;CS1685 b7762d10-e29b-4bb1-8b74-b6d69a667dd4 + + $(NoWarn);MEVD9000,MEVD9001 diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs index 23e714ff60bd..f1b23dcfed79 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs @@ -11,6 +11,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.SqlServer; /// /// Unit tests for class. /// +[Obsolete("The IMemoryStore abstraction is being obsoleted")] public class SqlServerMemoryStoreTests : IAsyncLifetime { private const string? SkipReason = "Configure SQL Server or Azure SQL connection string and then set this to 'null'."; @@ -339,7 +340,7 @@ private async Task CleanupDatabaseAsync() await connection.OpenAsync(); cmd.CommandText = $""" DECLARE tables_cursor CURSOR FOR - SELECT table_name + SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = '{SchemaName}' diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index 084159af79ce..563473def484 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -2,7 +2,6 @@ using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.SqlServer; using SqlServerIntegrationTests.Support; using VectorDataSpecificationTests.Xunit; using Xunit; @@ -82,18 +81,18 @@ public async Task RecordCRUD() received = await collection.GetAsync(updated.Id, new() { IncludeVectors = true }); AssertEquality(updated, received); - VectorSearchResult vectorSearchResult = await (await collection.VectorizedSearchAsync(inserted.Floats, new() + VectorSearchResult vectorSearchResult = await (collection.SearchEmbeddingAsync(inserted.Floats, top: 3, new() { VectorProperty = r => r.Floats, IncludeVectors = true - })).Results.SingleAsync(); + })).SingleAsync(); AssertEquality(updated, vectorSearchResult.Record); - vectorSearchResult = await (await collection.VectorizedSearchAsync(inserted.Floats, new() + vectorSearchResult = await (collection.SearchEmbeddingAsync(inserted.Floats, top: 3, new() { VectorProperty = r => r.Floats, IncludeVectors = false - })).Results.SingleAsync(); + })).SingleAsync(); // Make sure the vectors are not included in the result. Assert.Equal(0, vectorSearchResult.Record.Floats.Length); @@ -150,58 +149,6 @@ public async Task WrongModels() } } - [ConditionalFact] - public async Task CustomMapper() - { - string collectionName = GetUniqueCollectionName(); - TestModelMapper mapper = new(); - SqlServerVectorStoreRecordCollectionOptions options = new() - { - Mapper = mapper - }; - SqlServerVectorStoreRecordCollection collection = new(SqlServerTestEnvironment.ConnectionString!, collectionName, options); - - try - { - await collection.CreateCollectionIfNotExistsAsync(); - - TestModel inserted = new() - { - Id = "MyId", - Number = 100, - Floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray() - }; - string key = await collection.UpsertAsync(inserted); - Assert.Equal(inserted.Id, key); - Assert.True(mapper.MapFromDataToStorageModel_WasCalled); - Assert.False(mapper.MapFromStorageToDataModel_WasCalled); - - TestModel? received = await collection.GetAsync(inserted.Id, new() { IncludeVectors = true }); - AssertEquality(inserted, received); - Assert.True(mapper.MapFromStorageToDataModel_WasCalled); - - TestModel updated = new() - { - Id = inserted.Id, - Number = inserted.Number + 200, // change one property - Floats = inserted.Floats - }; - key = await collection.UpsertAsync(updated); - Assert.Equal(inserted.Id, key); - - received = await collection.GetAsync(updated.Id, new() { IncludeVectors = true }); - AssertEquality(updated, received); - - await collection.DeleteAsync(inserted.Id); - - Assert.Null(await collection.GetAsync(inserted.Id)); - } - finally - { - await collection.DeleteCollectionAsync(); - } - } - [ConditionalFact] public async Task BatchCRUD() { @@ -220,13 +167,13 @@ public async Task BatchCRUD() Floats = Enumerable.Range(0, 10).Select(j => (float)(i + j)).ToArray() }).ToArray(); - string[] keys = await collection.UpsertBatchAsync(inserted).ToArrayAsync(); + var keys = await collection.UpsertAsync(inserted); for (int i = 0; i < inserted.Length; i++) { Assert.Equal(inserted[i].Id, keys[i]); } - TestModel[] received = await collection.GetBatchAsync(keys, new() { IncludeVectors = true }).ToArrayAsync(); + TestModel[] received = await collection.GetAsync(keys, new() { IncludeVectors = true }).ToArrayAsync(); for (int i = 0; i < inserted.Length; i++) { AssertEquality(inserted[i], received[i]); @@ -239,21 +186,21 @@ public async Task BatchCRUD() Floats = i.Floats }).ToArray(); - keys = await collection.UpsertBatchAsync(updated).ToArrayAsync(); + keys = await collection.UpsertAsync(updated); for (int i = 0; i < updated.Length; i++) { Assert.Equal(updated[i].Id, keys[i]); } - received = await collection.GetBatchAsync(keys, new() { IncludeVectors = true }).ToArrayAsync(); + received = await collection.GetAsync(keys, new() { IncludeVectors = true }).ToArrayAsync(); for (int i = 0; i < updated.Length; i++) { AssertEquality(updated[i], received[i]); } - await collection.DeleteBatchAsync(keys); + await collection.DeleteAsync(keys); - Assert.False(await collection.GetBatchAsync(keys).AnyAsync()); + Assert.False(await collection.GetAsync(keys).AnyAsync()); } finally { @@ -466,37 +413,4 @@ public sealed class FancyTestModel [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] public ReadOnlyMemory Floats { get; set; } } - - private sealed class TestModelMapper : IVectorStoreRecordMapper> - { - internal bool MapFromDataToStorageModel_WasCalled { get; set; } - internal bool MapFromStorageToDataModel_WasCalled { get; set; } - - public IDictionary MapFromDataToStorageModel(TestModel dataModel) - { - this.MapFromDataToStorageModel_WasCalled = true; - - return new Dictionary() - { - { "key", dataModel.Id }, - { "text", dataModel.Text }, - { "column", dataModel.Number }, - // Please note that we are not dealing with JSON directly here. - { "embedding", dataModel.Floats } - }; - } - - public TestModel MapFromStorageToDataModel(IDictionary storageModel, StorageToDataModelMapperOptions options) - { - this.MapFromStorageToDataModel_WasCalled = true; - - return new() - { - Id = (string)storageModel["key"]!, - Text = (string?)storageModel["text"], - Number = (int)storageModel["column"]!, - Floats = (ReadOnlyMemory)storageModel["embedding"]! - }; - } - } } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerGenericDataModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerDynamicDataModelFixture.cs similarity index 78% rename from dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerGenericDataModelFixture.cs rename to dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerDynamicDataModelFixture.cs index d3be9dbe419d..0ff725729c49 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerGenericDataModelFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerDynamicDataModelFixture.cs @@ -4,7 +4,7 @@ namespace SqlServerIntegrationTests.Support; -public class SqlServerGenericDataModelFixture : GenericDataModelFixture +public class SqlServerDynamicDataModelFixture : DynamicDataModelFixture { public override TestStore TestStore => SqlServerTestStore.Instance; } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestEnvironment.cs index 043f4882e640..584a289f2902 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestEnvironment.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestEnvironment.cs @@ -17,7 +17,7 @@ internal static class SqlServerTestEnvironment .AddJsonFile(path: "testsettings.json", optional: true) .AddJsonFile(path: "testsettings.development.json", optional: true) .AddEnvironmentVariables() - .AddUserSecrets() + .AddUserSecrets() .Build(); return configuration.GetSection("SqlServer")["ConnectionString"]; diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs index 421bf7621d7f..b99303b6ef6c 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs @@ -8,23 +8,31 @@ namespace SqlServerIntegrationTests.Support; public sealed class SqlServerTestStore : TestStore { + private string? _connectionString; + public string ConnectionString => this._connectionString ?? throw new InvalidOperationException("Not initialized"); + public static readonly SqlServerTestStore Instance = new(); public override IVectorStore DefaultVectorStore - => this._connectedStore ?? throw new InvalidOperationException("Not initialized"); + => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + + public SqlServerVectorStore GetVectorStore(SqlServerVectorStoreOptions options) + => new(this.ConnectionString, options); public override string DefaultDistanceFunction => DistanceFunction.CosineDistance; - private SqlServerVectorStore? _connectedStore; + private SqlServerVectorStore? _defaultVectorStore; protected override Task StartAsync() { - if (string.IsNullOrWhiteSpace(SqlServerTestEnvironment.ConnectionString)) + this._connectionString = SqlServerTestEnvironment.ConnectionString; + + if (string.IsNullOrWhiteSpace(this._connectionString)) { throw new InvalidOperationException("Connection string is not configured, set the SqlServer:ConnectionString environment variable"); } - this._connectedStore = new(SqlServerTestEnvironment.ConnectionString); + this._defaultVectorStore = new(this._connectionString); return Task.CompletedTask; } diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteBatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteBatchConformanceTests.cs new file mode 100644 index 000000000000..21893736060e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteBatchConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqliteIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace SqliteIntegrationTests.CRUD; + +public class SqliteBatchConformanceTests_string(SqliteSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture> +{ +} + +public class SqliteBatchConformanceTests_ulong(SqliteSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture> +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteNoDataConformanceTests.cs new file mode 100644 index 000000000000..934cbd8a4032 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteNoDataConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqliteIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace SqliteIntegrationTests.CRUD; + +public class SqliteNoDataConformanceTests(SqliteNoDataConformanceTests.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => SqliteTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteNoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteNoVectorConformanceTests.cs new file mode 100644 index 000000000000..1b0e800e8ca7 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteNoVectorConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqliteIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace SqliteIntegrationTests.CRUD; + +public class SqliteNoVectorConformanceTests(SqliteNoVectorConformanceTests.Fixture fixture) + : NoVectorConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoVectorConformanceTests.Fixture + { + public override TestStore TestStore => SqliteTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteRecordConformanceTests.cs new file mode 100644 index 000000000000..2d5f95f8592c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteRecordConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqliteIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace SqliteIntegrationTests.CRUD; + +public class SqliteRecordConformanceTests_string(SqliteSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture> +{ +} + +public class SqliteRecordConformanceTests_ulong(SqliteSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture> +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicFilterTests.cs index 10570cc109c5..996b69847455 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicFilterTests.cs @@ -9,6 +9,8 @@ namespace SqliteIntegrationTests.Filter; +#pragma warning disable CS0252 // Possible unintended reference comparison; left hand side needs cast + public class SqliteBasicFilterTests(SqliteBasicFilterTests.Fixture fixture) : BasicFilterTests(fixture), IClassFixture { @@ -16,18 +18,22 @@ public override async Task Not_over_Or() { // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. - await Assert.ThrowsAsync(() => base.Not_over_Or()); + await Assert.ThrowsAsync(() => base.Not_over_Or()); // Compensate by adding a null check: - await this.TestFilterAsync(r => r.String != null && !(r.Int == 8 || r.String == "foo")); + await this.TestFilterAsync( + r => r.String != null && !(r.Int == 8 || r.String == "foo"), + r => r["String"] != null && !((int)r["Int"]! == 8 || r["String"] == "foo")); } public override async Task NotEqual_with_string() { // As above, null semantics + negation - await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); - await this.TestFilterAsync(r => r.String != null && r.String != "foo"); + await this.TestFilterAsync( + r => r.String != null && r.String != "foo", + r => r["String"] != null && r["String"] != "foo"); } // Array fields not (currently) supported on SQLite (see #10343) @@ -51,10 +57,8 @@ public override Task Legacy_AnyTagEqualTo_List() { public override TestStore TestStore => SqliteTestStore.Instance; - protected override string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; - // Override to remove the string array property, which isn't (currently) supported on SQLite - protected override VectorStoreRecordDefinition GetRecordDefinition() + public override VectorStoreRecordDefinition GetRecordDefinition() => new() { Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(string[]) && p.PropertyType != typeof(List)).ToList() diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicQueryTests.cs new file mode 100644 index 000000000000..ff73aa802c95 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicQueryTests.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using SqliteIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using Xunit; +using Xunit.Sdk; + +namespace SqliteIntegrationTests.Filter; + +#pragma warning disable CS0252 // Possible unintended reference comparison; left hand side needs cast + +public class SqliteBasicQueryTests(SqliteBasicQueryTests.Fixture fixture) + : BasicQueryTests(fixture), IClassFixture +{ + public override async Task Not_over_Or() + { + // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) + // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. + await Assert.ThrowsAsync(() => base.Not_over_Or()); + + // Compensate by adding a null check: + await this.TestFilterAsync( + r => r.String != null && !(r.Int == 8 || r.String == "foo"), + r => r["String"] != null && !((int)r["Int"]! == 8 || r["String"] == "foo")); + } + + public override async Task NotEqual_with_string() + { + // As above, null semantics + negation + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + + await this.TestFilterAsync( + r => r.String != null && r.String != "foo", + r => r["String"] != null && r["String"] != "foo"); + } + + // Array fields not (currently) supported on SQLite (see #10343) + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + // List fields not (currently) supported on SQLite (see #10343) + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => SqliteTestStore.Instance; + + // Override to remove the string array property, which isn't (currently) supported on SQLite + public override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(string[]) && p.PropertyType != typeof(List)).ToList() + }; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..7a7e2a2ba067 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteEmbeddingGenerationTests.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using SqliteIntegrationTests.Support; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace SqliteIntegrationTests; + +public class SqliteEmbeddingGenerationTests(SqliteEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => SqliteTestStore.Instance; + + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => SqliteTestStore.Instance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + services => services.AddSqliteVectorStore(SqliteTestStore.Instance.ConnectionString) + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services.AddSqliteVectorStoreRecordCollection(this.CollectionName, SqliteTestStore.Instance.ConnectionString) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteSimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteSimpleModelFixture.cs new file mode 100644 index 000000000000..70550525aa74 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteSimpleModelFixture.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace SqliteIntegrationTests.Support; + +public class SqliteSimpleModelFixture : SimpleModelFixture + where TKey : notnull +{ + public override TestStore TestStore => SqliteTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs index 526eeac3b2d8..bed7ed7c8de0 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs @@ -1,55 +1,45 @@ // Copyright (c) Microsoft. All rights reserved. -using Microsoft.Data.Sqlite; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Sqlite; using VectorDataSpecificationTests.Support; namespace SqliteIntegrationTests.Support; -#pragma warning disable CA1001 // Type owns disposable fields (_connection) but is not disposable - internal sealed class SqliteTestStore : TestStore { - public static SqliteTestStore Instance { get; } = new(); + private string? _databasePath; + + private string? _connectionString; + public string ConnectionString => this._connectionString ?? throw new InvalidOperationException("Not initialized"); - private SqliteConnection? _connection; - public SqliteConnection Connection - => this._connection ?? throw new InvalidOperationException("Call InitializeAsync() first"); + public static SqliteTestStore Instance { get; } = new(); private SqliteVectorStore? _defaultVectorStore; public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Call InitializeAsync() first"); + public override string DefaultDistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; + + public SqliteVectorStore GetVectorStore(SqliteVectorStoreOptions options) + => new(this.ConnectionString, options); + private SqliteTestStore() { } - protected override async Task StartAsync() + protected override Task StartAsync() { - this._connection = new SqliteConnection("Data Source=:memory:"); - - await this.Connection.OpenAsync(); - - if (!SqliteTestEnvironment.TryLoadSqliteVec(this.Connection)) - { - this.Connection.Dispose(); - - // Note that we ignore sqlite_vec loading failures; the tests are decorated with [SqliteVecRequired], which causes - // them to be skipped if sqlite_vec isn't installed (better than an exception triggering failure here) - } - - this._defaultVectorStore = new SqliteVectorStore(this.Connection); + this._databasePath = Path.GetTempFileName(); + this._connectionString = $"Data Source={this._databasePath}"; + this._defaultVectorStore = new SqliteVectorStore(this._connectionString); + return Task.CompletedTask; } -#if NET8_0_OR_GREATER - protected override async Task StopAsync() - => await this.Connection.DisposeAsync(); -#else protected override Task StopAsync() { - this.Connection.Dispose(); + File.Delete(this._databasePath!); + this._databasePath = null; return Task.CompletedTask; } -#endif } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BatchConformanceTests.cs index b8fe0a30afe4..4b642de36488 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BatchConformanceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BatchConformanceTests.cs @@ -10,24 +10,24 @@ namespace VectorDataSpecificationTests.CRUD; public abstract class BatchConformanceTests(SimpleModelFixture fixture) where TKey : notnull { [ConditionalFact] - public async Task GetBatchAsyncThrowsArgumentNullExceptionForNullKeys() + public virtual async Task GetBatchAsyncThrowsArgumentNullExceptionForNullKeys() { - ArgumentNullException ex = await Assert.ThrowsAsync(() => fixture.Collection.GetBatchAsync(keys: null!).ToArrayAsync().AsTask()); + ArgumentNullException ex = await Assert.ThrowsAsync(() => fixture.Collection.GetAsync(keys: null!).ToArrayAsync().AsTask()); Assert.Equal("keys", ex.ParamName); } [ConditionalFact] - public async Task GetBatchAsyncDoesNotThrowForEmptyBatch() + public virtual async Task GetBatchAsyncDoesNotThrowForEmptyBatch() { - Assert.Empty(await fixture.Collection.GetBatchAsync([]).ToArrayAsync()); + Assert.Empty(await fixture.Collection.GetAsync([]).ToArrayAsync()); } [ConditionalFact] - public Task GetBatchAsyncReturnsInsertedRecords_WithVectors() + public virtual Task GetBatchAsync_WithVectors() => this.GetBatchAsyncReturnsInsertedRecords(includeVectors: true); [ConditionalFact] - public Task GetBatchAsyncReturnsInsertedRecords_WithoutVectors() + public virtual Task GetBatchAsync_WithoutVectors() => this.GetBatchAsyncReturnsInsertedRecords(includeVectors: false); private async Task GetBatchAsyncReturnsInsertedRecords(bool includeVectors) @@ -35,78 +35,64 @@ private async Task GetBatchAsyncReturnsInsertedRecords(bool includeVectors) var expectedRecords = fixture.TestData.Take(2); // the last two records can get deleted by other tests var ids = expectedRecords.Select(record => record.Id); - var received = await fixture.Collection.GetBatchAsync(ids, new() { IncludeVectors = includeVectors }).ToArrayAsync(); + var received = await fixture.Collection.GetAsync(ids, new() { IncludeVectors = includeVectors }).ToArrayAsync(); foreach (var record in expectedRecords) { - record.AssertEqual(this.GetRecord(received, record.Id), includeVectors); + record.AssertEqual(this.GetRecord(received, record.Id), includeVectors, fixture.TestStore.VectorsComparable); } } [ConditionalFact] - public async Task UpsertBatchAsyncThrowsArgumentNullExceptionForNullBatch() + public virtual async Task UpsertBatchAsyncThrowsArgumentNullExceptionForNullBatch() { - ArgumentNullException ex = await Assert.ThrowsAsync(() => fixture.Collection.UpsertBatchAsync(records: null!).ToArrayAsync().AsTask()); + ArgumentNullException ex = await Assert.ThrowsAsync(() => fixture.Collection.UpsertAsync(records: null!)); Assert.Equal("records", ex.ParamName); } [ConditionalFact] - public async Task UpsertBatchAsyncDoesNotThrowForEmptyBatch() + public virtual async Task UpsertBatchAsyncDoesNotThrowForEmptyBatch() { - Assert.Empty(await fixture.Collection.UpsertBatchAsync([]).ToArrayAsync()); + Assert.Empty(await fixture.Collection.UpsertAsync([])); } [ConditionalFact] - public Task UpsertBatchAsyncCanInsertNewRecord_WithVectors() - => this.UpsertBatchAsyncCanInsertNewRecords(includeVectors: true); - - [ConditionalFact] - public Task UpsertBatchAsyncCanInsertNewRecord_WithoutVectors() - => this.UpsertBatchAsyncCanInsertNewRecords(includeVectors: false); - - private async Task UpsertBatchAsyncCanInsertNewRecords(bool includeVectors) + public virtual async Task UpsertBatchAsyncCanInsertNewRecord() { var collection = fixture.Collection; - SimpleModel[] inserted = Enumerable.Range(0, 10).Select(i => new SimpleModel() + SimpleRecord[] inserted = Enumerable.Range(0, 10).Select(i => new SimpleRecord() { Id = fixture.GenerateNextKey(), Number = 100 + i, Text = i.ToString(), - Floats = Enumerable.Range(0, SimpleModel.DimensionCount).Select(j => (float)(i + j)).ToArray() + Floats = Enumerable.Range(0, SimpleRecord.DimensionCount).Select(j => (float)(i + j)).ToArray() }).ToArray(); var keys = inserted.Select(record => record.Id).ToArray(); - Assert.Empty(await collection.GetBatchAsync(keys).ToArrayAsync()); - var receivedKeys = await collection.UpsertBatchAsync(inserted).ToArrayAsync(); + Assert.Empty(await collection.GetAsync(keys).ToArrayAsync()); + var receivedKeys = await collection.UpsertAsync(inserted); Assert.Equal(keys.ToHashSet(), receivedKeys.ToHashSet()); // .ToHashSet() to ignore order - var received = await collection.GetBatchAsync(keys, new() { IncludeVectors = includeVectors }).ToArrayAsync(); + var received = await collection.GetAsync(keys, new() { IncludeVectors = true }).ToArrayAsync(); foreach (var record in inserted) { - record.AssertEqual(this.GetRecord(received, record.Id), includeVectors); + record.AssertEqual(this.GetRecord(received, record.Id), includeVectors: true, fixture.TestStore.VectorsComparable); } } [ConditionalFact] - public Task UpsertBatchAsyncCanUpdateExistingRecords_WithVectors() - => this.UpsertBatchAsyncCanUpdateExistingRecords(includeVectors: true); - - [ConditionalFact] - public Task UpsertBatchAsyncCanUpdateExistingRecords_WithoutVectors() - => this.UpsertBatchAsyncCanUpdateExistingRecords(includeVectors: false); - - private async Task UpsertBatchAsyncCanUpdateExistingRecords(bool includeVectors) + public virtual async Task UpsertBatchAsyncCanUpdateExistingRecords() { - SimpleModel[] inserted = Enumerable.Range(0, 10).Select(i => new SimpleModel() + SimpleRecord[] inserted = Enumerable.Range(0, 10).Select(i => new SimpleRecord() { Id = fixture.GenerateNextKey(), Number = 100 + i, Text = i.ToString(), - Floats = Enumerable.Range(0, SimpleModel.DimensionCount).Select(j => (float)(i + j)).ToArray() + Floats = Enumerable.Range(0, SimpleRecord.DimensionCount).Select(j => (float)(i + j)).ToArray() }).ToArray(); - await fixture.Collection.UpsertBatchAsync(inserted).ToArrayAsync(); + await fixture.Collection.UpsertAsync(inserted); - SimpleModel[] updated = inserted.Select(i => new SimpleModel() + SimpleRecord[] updated = inserted.Select(i => new SimpleRecord() { Id = i.Id, Text = i.Text + "updated", @@ -114,39 +100,32 @@ private async Task UpsertBatchAsyncCanUpdateExistingRecords(bool includeVectors) Floats = i.Floats }).ToArray(); - var keys = await fixture.Collection.UpsertBatchAsync(updated).ToArrayAsync(); + var keys = await fixture.Collection.UpsertAsync(updated); Assert.Equal( updated.Select(r => r.Id).OrderBy(id => id).ToArray(), keys.OrderBy(id => id).ToArray()); - var received = await fixture.Collection.GetBatchAsync(keys, new() { IncludeVectors = includeVectors }).ToArrayAsync(); + var received = await fixture.Collection.GetAsync(keys, new() { IncludeVectors = true }).ToArrayAsync(); foreach (var record in updated) { - record.AssertEqual(this.GetRecord(received, record.Id), includeVectors); + record.AssertEqual(this.GetRecord(received, record.Id), includeVectors: true, fixture.TestStore.VectorsComparable); } } [ConditionalFact] - public Task UpsertCanBothInsertAndUpdateRecordsFromTheSameBatch_WithVectors() - => this.UpsertCanBothInsertAndUpdateRecordsFromTheSameBatch(includeVectors: true); - - [ConditionalFact] - public Task UpsertCanBothInsertAndUpdateRecordsFromTheSameBatch_WithoutVectors() - => this.UpsertCanBothInsertAndUpdateRecordsFromTheSameBatch(includeVectors: false); - - private async Task UpsertCanBothInsertAndUpdateRecordsFromTheSameBatch(bool includeVectors) + public virtual async Task UpsertCanBothInsertAndUpdateRecordsFromTheSameBatch() { - SimpleModel[] records = Enumerable.Range(0, 10).Select(i => new SimpleModel() + SimpleRecord[] records = Enumerable.Range(0, 10).Select(i => new SimpleRecord() { Id = fixture.GenerateNextKey(), Number = 100 + i, Text = i.ToString(), - Floats = Enumerable.Range(0, SimpleModel.DimensionCount).Select(j => (float)(i + j)).ToArray() + Floats = Enumerable.Range(0, SimpleRecord.DimensionCount).Select(j => (float)(i + j)).ToArray() }).ToArray(); // We take first half of the records and insert them. - SimpleModel[] firstHalf = records.Take(records.Length / 2).ToArray(); - TKey[] insertedKeys = await fixture.Collection.UpsertBatchAsync(firstHalf).ToArrayAsync(); + SimpleRecord[] firstHalf = records.Take(records.Length / 2).ToArray(); + var insertedKeys = await fixture.Collection.UpsertAsync(firstHalf); Assert.Equal( firstHalf.Select(r => r.Id).OrderBy(id => id).ToArray(), insertedKeys.OrderBy(id => id).ToArray()); @@ -159,43 +138,43 @@ private async Task UpsertCanBothInsertAndUpdateRecordsFromTheSameBatch(bool incl } // And now we upsert all the records (the first half is an update, the second is an insert). - TKey[] mixedKeys = await fixture.Collection.UpsertBatchAsync(records).ToArrayAsync(); + var mixedKeys = await fixture.Collection.UpsertAsync(records); Assert.Equal( records.Select(r => r.Id).OrderBy(id => id).ToArray(), mixedKeys.OrderBy(id => id).ToArray()); - var received = await fixture.Collection.GetBatchAsync(mixedKeys, new() { IncludeVectors = includeVectors }).ToArrayAsync(); + var received = await fixture.Collection.GetAsync(mixedKeys, new() { IncludeVectors = true }).ToArrayAsync(); foreach (var record in records) { - record.AssertEqual(this.GetRecord(received, record.Id), includeVectors); + record.AssertEqual(this.GetRecord(received, record.Id), includeVectors: true, fixture.TestStore.VectorsComparable); } } [ConditionalFact] - public async Task DeleteBatchAsyncDoesNotThrowForEmptyBatch() + public virtual async Task DeleteBatchAsyncDoesNotThrowForEmptyBatch() { - await fixture.Collection.DeleteBatchAsync([]); + await fixture.Collection.DeleteAsync([]); } [ConditionalFact] - public async Task DeleteBatchAsyncThrowsArgumentNullExceptionForNullKeys() + public virtual async Task DeleteBatchAsyncThrowsArgumentNullExceptionForNullKeys() { - ArgumentNullException ex = await Assert.ThrowsAsync(() => fixture.Collection.DeleteBatchAsync(keys: null!)); + ArgumentNullException ex = await Assert.ThrowsAsync(() => fixture.Collection.DeleteAsync(keys: null!)); Assert.Equal("keys", ex.ParamName); } [ConditionalFact] - public async Task DeleteBatchAsyncDeletesTheRecords() + public virtual async Task DeleteBatchAsyncDeletesTheRecords() { TKey[] idsToRemove = [fixture.TestData[2].Id, fixture.TestData[3].Id]; - Assert.NotEmpty(await fixture.Collection.GetBatchAsync(idsToRemove).ToArrayAsync()); - await fixture.Collection.DeleteBatchAsync(idsToRemove); - Assert.Empty(await fixture.Collection.GetBatchAsync(idsToRemove).ToArrayAsync()); + Assert.NotEmpty(await fixture.Collection.GetAsync(idsToRemove).ToArrayAsync()); + await fixture.Collection.DeleteAsync(idsToRemove); + Assert.Empty(await fixture.Collection.GetAsync(idsToRemove).ToArrayAsync()); } // The order of records in the received array is not guaranteed // to match the order of keys in the requested keys array. - protected SimpleModel GetRecord(SimpleModel[] received, TKey key) + protected SimpleRecord GetRecord(SimpleRecord[] received, TKey key) => received.Single(r => r.Id!.Equals(key)); } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/DynamicDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/DynamicDataModelConformanceTests.cs new file mode 100644 index 000000000000..2fddb4dd699c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/DynamicDataModelConformanceTests.cs @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests.CRUD; + +public abstract class DynamicDataModelConformanceTests(DynamicDataModelFixture fixture) + where TKey : notnull +{ + [ConditionalFact] + public virtual async Task GetAsyncThrowsArgumentNullExceptionForNullKey() + { + // Skip this test for value type keys + if (default(TKey) is not null) + { + return; + } + + ArgumentNullException ex = await Assert.ThrowsAsync(() => fixture.Collection.GetAsync((TKey)default!)); + Assert.Equal("key", ex.ParamName); + } + + [ConditionalFact] + public virtual async Task GetAsyncReturnsNullForNonExistingKey() + { + TKey key = fixture.GenerateNextKey(); + + Assert.Null(await fixture.Collection.GetAsync(key)); + } + + [ConditionalFact] + public virtual Task GetAsync_WithVectors() + => this.GetAsyncReturnsInsertedRecord(includeVectors: true); + + [ConditionalFact] + public virtual Task GetAsync_WithoutVectors() + => this.GetAsyncReturnsInsertedRecord(includeVectors: false); + + private async Task GetAsyncReturnsInsertedRecord(bool includeVectors) + { + var expectedRecord = fixture.TestData[0]; + + var received = await fixture.Collection.GetAsync( + (TKey)expectedRecord[DynamicDataModelFixture.KeyPropertyName]!, + new() { IncludeVectors = includeVectors }); + + AssertEquivalent(expectedRecord, received, includeVectors, fixture.TestStore.VectorsComparable); + } + + [ConditionalFact] + public virtual async Task UpsertAsyncCanInsertNewRecord() + { + var collection = fixture.Collection; + TKey expectedKey = fixture.GenerateNextKey(); + var inserted = new Dictionary + { + [DynamicDataModelFixture.KeyPropertyName] = expectedKey, + [DynamicDataModelFixture.StringPropertyName] = "some", + [DynamicDataModelFixture.IntegerPropertyName] = 123, + [DynamicDataModelFixture.EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.1f, DynamicDataModelFixture.DimensionCount).ToArray()) + }; + + Assert.Null(await collection.GetAsync(expectedKey)); + var key = await collection.UpsertAsync(inserted); + Assert.Equal(expectedKey, key); + + var received = await collection.GetAsync(expectedKey, new() { IncludeVectors = true }); + AssertEquivalent(inserted, received, includeVectors: true, fixture.TestStore.VectorsComparable); + } + + [ConditionalFact] + public virtual async Task UpsertAsyncCanUpdateExistingRecord() + { + var collection = fixture.Collection; + var existingRecord = fixture.TestData[1]; + var updated = new Dictionary + { + [DynamicDataModelFixture.KeyPropertyName] = existingRecord[DynamicDataModelFixture.KeyPropertyName], + [DynamicDataModelFixture.StringPropertyName] = "different", + [DynamicDataModelFixture.IntegerPropertyName] = 456, + [DynamicDataModelFixture.EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.7f, DynamicDataModelFixture.DimensionCount).ToArray()) + }; + + Assert.NotNull(await collection.GetAsync((TKey)existingRecord[DynamicDataModelFixture.KeyPropertyName]!)); + var key = await collection.UpsertAsync(updated); + Assert.Equal(existingRecord[DynamicDataModelFixture.KeyPropertyName], key); + + var received = await collection.GetAsync((TKey)existingRecord[DynamicDataModelFixture.KeyPropertyName]!, new() { IncludeVectors = true }); + AssertEquivalent(updated, received, includeVectors: true, fixture.TestStore.VectorsComparable); + } + + [ConditionalFact] + public virtual async Task DeleteAsyncDoesNotThrowForNonExistingKey() + { + TKey key = fixture.GenerateNextKey(); + + await fixture.Collection.DeleteAsync(key); + } + + [ConditionalFact] + public async Task DeleteAsyncDeletesTheRecord() + { + var recordToRemove = fixture.TestData[2]; + + Assert.NotNull(await fixture.Collection.GetAsync((TKey)recordToRemove[DynamicDataModelFixture.KeyPropertyName]!)); + await fixture.Collection.DeleteAsync((TKey)recordToRemove[DynamicDataModelFixture.KeyPropertyName]!); + Assert.Null(await fixture.Collection.GetAsync((TKey)recordToRemove[DynamicDataModelFixture.KeyPropertyName]!)); + } + + protected static void AssertEquivalent(Dictionary expected, Dictionary? actual, bool includeVectors, bool compareVectors) + { + Assert.NotNull(actual); + Assert.Equal(expected[DynamicDataModelFixture.KeyPropertyName], actual[DynamicDataModelFixture.KeyPropertyName]); + + Assert.Equal(expected[DynamicDataModelFixture.StringPropertyName], actual[DynamicDataModelFixture.StringPropertyName]); + Assert.Equal(expected[DynamicDataModelFixture.IntegerPropertyName], actual[DynamicDataModelFixture.IntegerPropertyName]); + + if (includeVectors) + { + Assert.Equal( + ((ReadOnlyMemory)expected[DynamicDataModelFixture.EmbeddingPropertyName]!).Length, + ((ReadOnlyMemory)actual[DynamicDataModelFixture.EmbeddingPropertyName]!).Length); + + if (compareVectors) + { + Assert.Equal( + ((ReadOnlyMemory)expected[DynamicDataModelFixture.EmbeddingPropertyName]!).ToArray(), + ((ReadOnlyMemory)actual[DynamicDataModelFixture.EmbeddingPropertyName]!).ToArray()); + } + } + else + { + Assert.False(actual.ContainsKey(DynamicDataModelFixture.EmbeddingPropertyName)); + } + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/GenericDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/GenericDataModelConformanceTests.cs deleted file mode 100644 index dc905f82aea8..000000000000 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/GenericDataModelConformanceTests.cs +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Microsoft.Extensions.VectorData; -using VectorDataSpecificationTests.Support; -using VectorDataSpecificationTests.Xunit; -using Xunit; - -namespace VectorDataSpecificationTests.CRUD; - -public abstract class GenericDataModelConformanceTests(GenericDataModelFixture fixture) where TKey : notnull -{ - [ConditionalFact] - public async Task GetAsyncThrowsArgumentNullExceptionForNullKey() - { - ArgumentNullException ex = await Assert.ThrowsAsync(() => fixture.Collection.GetAsync(default!)); - Assert.Equal("key", ex.ParamName); - } - - [ConditionalFact] - public async Task GetAsyncReturnsNullForNonExistingKey() - { - TKey key = fixture.GenerateNextKey(); - - Assert.Null(await fixture.Collection.GetAsync(key)); - } - - [ConditionalFact] - public Task GetAsyncReturnsInsertedRecord_WithVectors() - => this.GetAsyncReturnsInsertedRecord(includeVectors: true); - - [ConditionalFact] - public Task GetAsyncReturnsInsertedRecord_WithoutVectors() - => this.GetAsyncReturnsInsertedRecord(includeVectors: false); - - private async Task GetAsyncReturnsInsertedRecord(bool includeVectors) - { - var expectedRecord = fixture.TestData[0]; - - var received = await fixture.Collection.GetAsync(expectedRecord.Key, new() { IncludeVectors = includeVectors }); - - AssertEqual(expectedRecord, received, includeVectors); - } - - [ConditionalFact] - public Task UpsertAsyncCanInsertNewRecord_WithVectors() - => this.UpsertAsyncCanInsertNewRecord(includeVectors: true); - - [ConditionalFact] - public Task UpsertAsyncCanInsertNewRecord_WithoutVectors() - => this.UpsertAsyncCanInsertNewRecord(includeVectors: false); - - private async Task UpsertAsyncCanInsertNewRecord(bool includeVectors) - { - var collection = fixture.Collection; - TKey expectedKey = fixture.GenerateNextKey(); - VectorStoreGenericDataModel inserted = new(expectedKey) - { - Data = - { - [GenericDataModelFixture.StringPropertyName] = "some", - [GenericDataModelFixture.IntegerPropertyName] = 123 - }, - Vectors = - { - [GenericDataModelFixture.EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.1f, GenericDataModelFixture.DimensionCount).ToArray()) - } - }; - - Assert.Null(await collection.GetAsync(expectedKey)); - TKey key = await collection.UpsertAsync(inserted); - Assert.Equal(expectedKey, key); - - var received = await collection.GetAsync(expectedKey, new() { IncludeVectors = includeVectors }); - AssertEqual(inserted, received, includeVectors); - } - - [ConditionalFact] - public Task UpsertAsyncCanUpdateExistingRecord_WithVectors() - => this.UpsertAsyncCanUpdateExistingRecord(includeVectors: true); - - [ConditionalFact] - public Task UpsertAsyncCanUpdateExistingRecord__WithoutVectors() - => this.UpsertAsyncCanUpdateExistingRecord(includeVectors: false); - - private async Task UpsertAsyncCanUpdateExistingRecord(bool includeVectors) - { - var collection = fixture.Collection; - var existingRecord = fixture.TestData[1]; - VectorStoreGenericDataModel updated = new(existingRecord.Key) - { - Data = - { - [GenericDataModelFixture.StringPropertyName] = "different", - [GenericDataModelFixture.IntegerPropertyName] = 456 - }, - Vectors = - { - [GenericDataModelFixture.EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.7f, GenericDataModelFixture.DimensionCount).ToArray()) - } - }; - - Assert.NotNull(await collection.GetAsync(existingRecord.Key)); - TKey key = await collection.UpsertAsync(updated); - Assert.Equal(existingRecord.Key, key); - - var received = await collection.GetAsync(existingRecord.Key, new() { IncludeVectors = includeVectors }); - AssertEqual(updated, received, includeVectors); - } - - [ConditionalFact] - public async Task DeleteAsyncDoesNotThrowForNonExistingKey() - { - TKey key = fixture.GenerateNextKey(); - - await fixture.Collection.DeleteAsync(key); - } - - [ConditionalFact] - public async Task DeleteAsyncDeletesTheRecord() - { - var recordToRemove = fixture.TestData[2]; - - Assert.NotNull(await fixture.Collection.GetAsync(recordToRemove.Key)); - await fixture.Collection.DeleteAsync(recordToRemove.Key); - Assert.Null(await fixture.Collection.GetAsync(recordToRemove.Key)); - } - - private static void AssertEqual(VectorStoreGenericDataModel expected, VectorStoreGenericDataModel? actual, bool includeVectors) - { - Assert.NotNull(actual); - Assert.Equal(expected.Key, actual.Key); - foreach (var pair in expected.Data) - { - Assert.Equal(pair.Value, actual.Data[pair.Key]); - } - - if (includeVectors) - { - Assert.Equal( - ((ReadOnlyMemory)expected.Vectors[GenericDataModelFixture.EmbeddingPropertyName]!).ToArray(), - ((ReadOnlyMemory)actual.Vectors[GenericDataModelFixture.EmbeddingPropertyName]!).ToArray()); - } - else - { - Assert.Empty(actual.Vectors); - } - } -} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/NoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/NoDataConformanceTests.cs new file mode 100644 index 000000000000..2eca14714b0e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/NoDataConformanceTests.cs @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests.CRUD; + +/// +/// Tests CRUD operations using a model without data fields. +/// +public class NoDataConformanceTests(NoDataConformanceTests.Fixture fixture) where TKey : notnull +{ + [ConditionalFact] + public virtual Task GetAsyncReturnsInsertedRecord_WithVectors() + => this.GetAsyncReturnsInsertedRecord(includeVectors: true); + + [ConditionalFact] + public virtual Task GetAsyncReturnsInsertedRecord_WithoutVectors() + => this.GetAsyncReturnsInsertedRecord(includeVectors: false); + + private async Task GetAsyncReturnsInsertedRecord(bool includeVectors) + { + var expectedRecord = fixture.TestData[0]; + + var received = await fixture.Collection.GetAsync(expectedRecord.Id, new() { IncludeVectors = includeVectors }); + + expectedRecord.AssertEqual(received, includeVectors, fixture.TestStore.VectorsComparable); + } + + [ConditionalFact] + public virtual Task UpsertAsyncCanInsertNewRecord_WithVectors() + => this.UpsertAsyncCanInsertNewRecord(includeVectors: true); + + [ConditionalFact] + public virtual Task UpsertAsyncCanInsertNewRecord_WithoutVectors() + => this.UpsertAsyncCanInsertNewRecord(includeVectors: false); + + private async Task UpsertAsyncCanInsertNewRecord(bool includeVectors) + { + var collection = fixture.Collection; + TKey expectedKey = fixture.GenerateNextKey(); + NoDataRecord inserted = new() + { + Id = expectedKey, + Floats = new ReadOnlyMemory(Enumerable.Repeat(0.1f, NoDataRecord.DimensionCount).ToArray()) + }; + + Assert.Null(await collection.GetAsync(expectedKey)); + TKey key = await collection.UpsertAsync(inserted); + Assert.Equal(expectedKey, key); + + var received = await collection.GetAsync(expectedKey, new() { IncludeVectors = includeVectors }); + inserted.AssertEqual(received, includeVectors, fixture.TestStore.VectorsComparable); + } + + [ConditionalFact] + public virtual Task UpsertAsyncCanUpdateExistingRecord_WithVectors() + => this.UpsertAsyncCanUpdateExistingRecord(includeVectors: true); + + [ConditionalFact] + public virtual Task UpsertAsyncCanUpdateExistingRecord_WithoutVectors() + => this.UpsertAsyncCanUpdateExistingRecord(includeVectors: false); + + private async Task UpsertAsyncCanUpdateExistingRecord(bool includeVectors) + { + var collection = fixture.Collection; + var existingRecord = fixture.TestData[1]; + NoDataRecord updated = new() + { + Id = existingRecord.Id, + Floats = new ReadOnlyMemory(Enumerable.Repeat(0.25f, NoDataRecord.DimensionCount).ToArray()) + }; + + Assert.NotNull(await collection.GetAsync(existingRecord.Id, new() { IncludeVectors = true })); + TKey key = await collection.UpsertAsync(updated); + Assert.Equal(existingRecord.Id, key); + + var received = await collection.GetAsync(existingRecord.Id, new() { IncludeVectors = includeVectors }); + updated.AssertEqual(received, includeVectors, fixture.TestStore.VectorsComparable); + } + + [ConditionalFact] + public virtual async Task DeleteAsyncDeletesTheRecord() + { + var recordToRemove = fixture.TestData[2]; + + Assert.NotNull(await fixture.Collection.GetAsync(recordToRemove.Id, new() { IncludeVectors = true })); + await fixture.Collection.DeleteAsync(recordToRemove.Id); + Assert.Null(await fixture.Collection.GetAsync(recordToRemove.Id)); + } + + /// + /// This class is for testing databases that support having no data fields. + /// + public sealed class NoDataRecord + { + public const int DimensionCount = 3; + + [VectorStoreRecordKey(StoragePropertyName = "key")] + public TKey Id { get; set; } = default!; + + [VectorStoreRecordVector(DimensionCount, StoragePropertyName = "embedding")] + public ReadOnlyMemory Floats { get; set; } + + public void AssertEqual(NoDataRecord? other, bool includeVectors, bool compareVectors) + { + Assert.NotNull(other); + Assert.Equal(this.Id, other.Id); + + if (includeVectors) + { + Assert.Equal(this.Floats.Span.Length, other.Floats.Span.Length); + + if (compareVectors) + { + Assert.True(this.Floats.Span.SequenceEqual(other.Floats.Span)); + } + } + } + } + + /// + /// Provides data and configuration for a model without data fields. + /// + public abstract class Fixture : VectorStoreCollectionFixture + { + protected override List BuildTestData() => + [ + new() + { + Id = this.GenerateNextKey(), + Floats = new ReadOnlyMemory(Enumerable.Repeat(0.1f, NoDataRecord.DimensionCount).ToArray()) + }, + new() + { + Id = this.GenerateNextKey(), + Floats = new ReadOnlyMemory(Enumerable.Repeat(0.2f, NoDataRecord.DimensionCount).ToArray()) + }, + new() + { + Id = this.GenerateNextKey(), + Floats = new ReadOnlyMemory(Enumerable.Repeat(0.3f, NoDataRecord.DimensionCount).ToArray()) + }, + new() + { + Id = this.GenerateNextKey(), + Floats = new ReadOnlyMemory(Enumerable.Repeat(0.4f, NoDataRecord.DimensionCount).ToArray()) + } + ]; + + public override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(NoDataRecord.Id), typeof(TKey)) { StoragePropertyName = "key" }, + new VectorStoreRecordVectorProperty(nameof(NoDataRecord.Floats), typeof(ReadOnlyMemory), NoDataRecord.DimensionCount) + { + StoragePropertyName = "embedding", + IndexKind = this.IndexKind, + } + ] + }; + + protected override async Task WaitForDataAsync() + { + for (var i = 0; i < 20; i++) + { + var getOptions = new GetRecordOptions { IncludeVectors = true }; + var results = await this.Collection.GetAsync([this.TestData[0].Id, this.TestData[1].Id, this.TestData[2].Id, this.TestData[3].Id], getOptions).ToArrayAsync(); + if (results.Length == 4 && results.All(r => r != null)) + { + return; + } + + await Task.Delay(TimeSpan.FromMilliseconds(100)); + } + + throw new InvalidOperationException("Data did not appear in the collection within the expected time."); + } + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/NoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/NoVectorConformanceTests.cs new file mode 100644 index 000000000000..4d923d71304e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/NoVectorConformanceTests.cs @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests.CRUD; + +/// +/// Tests CRUD operations using a model without a vector. +/// This is only supported by a subset of databases so only extend if applicable for your database. +/// +public class NoVectorConformanceTests(NoVectorConformanceTests.Fixture fixture) where TKey : notnull +{ + [ConditionalFact] + public Task GetAsyncReturnsInsertedRecord_WithVectors() + => this.GetAsyncReturnsInsertedRecord(includeVectors: true); + + [ConditionalFact] + public Task GetAsyncReturnsInsertedRecord_WithoutVectors() + => this.GetAsyncReturnsInsertedRecord(includeVectors: false); + + private async Task GetAsyncReturnsInsertedRecord(bool includeVectors) + { + var expectedRecord = fixture.TestData[0]; + + var received = await fixture.Collection.GetAsync(expectedRecord.Id, new() { IncludeVectors = includeVectors }); + + expectedRecord.AssertEqual(received); + } + + [ConditionalFact] + public Task UpsertAsyncCanInsertNewRecord_WithVectors() + => this.UpsertAsyncCanInsertNewRecord(includeVectors: true); + + [ConditionalFact] + public Task UpsertAsyncCanInsertNewRecord_WithoutVectors() + => this.UpsertAsyncCanInsertNewRecord(includeVectors: false); + + private async Task UpsertAsyncCanInsertNewRecord(bool includeVectors) + { + var collection = fixture.Collection; + TKey expectedKey = fixture.GenerateNextKey(); + NoVectorRecord inserted = new() + { + Id = expectedKey, + Text = "some" + }; + + Assert.Null(await collection.GetAsync(expectedKey)); + TKey key = await collection.UpsertAsync(inserted); + Assert.Equal(expectedKey, key); + + var received = await collection.GetAsync(expectedKey, new() { IncludeVectors = includeVectors }); + inserted.AssertEqual(received); + } + + [ConditionalFact] + public Task UpsertAsyncCanUpdateExistingRecord_WithVectors() + => this.UpsertAsyncCanUpdateExistingRecord(includeVectors: true); + + [ConditionalFact] + public Task UpsertAsyncCanUpdateExistingRecord__WithoutVectors() + => this.UpsertAsyncCanUpdateExistingRecord(includeVectors: false); + + private async Task UpsertAsyncCanUpdateExistingRecord(bool includeVectors) + { + var collection = fixture.Collection; + var existingRecord = fixture.TestData[1]; + NoVectorRecord updated = new() + { + Id = existingRecord.Id, + Text = "updated" + }; + + Assert.NotNull(await collection.GetAsync(existingRecord.Id)); + TKey key = await collection.UpsertAsync(updated); + Assert.Equal(existingRecord.Id, key); + + var received = await collection.GetAsync(existingRecord.Id, new() { IncludeVectors = includeVectors }); + updated.AssertEqual(received); + } + + [ConditionalFact] + public async Task DeleteAsyncDeletesTheRecord() + { + var recordToRemove = fixture.TestData[2]; + + Assert.NotNull(await fixture.Collection.GetAsync(recordToRemove.Id)); + await fixture.Collection.DeleteAsync(recordToRemove.Id); + Assert.Null(await fixture.Collection.GetAsync(recordToRemove.Id)); + } + + /// + /// This class is for testing databases that support having no vector. + /// Not all DBs support this. + /// + public sealed class NoVectorRecord + { + public const int DimensionCount = 3; + + [VectorStoreRecordKey(StoragePropertyName = "key")] + public TKey Id { get; set; } = default!; + + [VectorStoreRecordData(StoragePropertyName = "text")] + public string? Text { get; set; } + + public void AssertEqual(NoVectorRecord? other) + { + Assert.NotNull(other); + Assert.Equal(this.Id, other.Id); + Assert.Equal(this.Text, other.Text); + } + } + + /// + /// Provides data and configuration for a model without a vector, which is supported by some connectors. + /// + public abstract class Fixture : VectorStoreCollectionFixture + { + protected override List BuildTestData() => + [ + new() + { + Id = this.GenerateNextKey(), + Text = "UsedByGetTests", + }, + new() + { + Id = this.GenerateNextKey(), + Text = "UsedByUpdateTests", + }, + new() + { + Id = this.GenerateNextKey(), + Text = "UsedByDeleteTests", + }, + new() + { + Id = this.GenerateNextKey(), + Text = "UsedByDeleteBatchTests", + } + ]; + + public override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(NoVectorRecord.Id), typeof(TKey)) { StoragePropertyName = "key" }, + new VectorStoreRecordDataProperty(nameof(NoVectorRecord.Text), typeof(string)) { IsIndexed = true, StoragePropertyName = "text" }, + ] + }; + + protected override async Task WaitForDataAsync() + { + for (var i = 0; i < 20; i++) + { + var results = await this.Collection.GetAsync([this.TestData[0].Id, this.TestData[1].Id, this.TestData[2].Id, this.TestData[3].Id]).ToArrayAsync(); + if (results.Length == 4 && results.All(r => r != null)) + { + return; + } + + await Task.Delay(TimeSpan.FromMilliseconds(100)); + } + + throw new InvalidOperationException("Data did not appear in the collection within the expected time."); + } + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/RecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/RecordConformanceTests.cs index 5a3d0d0081ea..a1354caa4efe 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/RecordConformanceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/RecordConformanceTests.cs @@ -10,14 +10,20 @@ namespace VectorDataSpecificationTests.CRUD; public class RecordConformanceTests(SimpleModelFixture fixture) where TKey : notnull { [ConditionalFact] - public async Task GetAsyncThrowsArgumentNullExceptionForNullKey() + public virtual async Task GetAsyncThrowsArgumentNullExceptionForNullKey() { - ArgumentNullException ex = await Assert.ThrowsAsync(() => fixture.Collection.GetAsync(default!)); + // Skip this test for value type keys + if (default(TKey) is not null) + { + return; + } + + ArgumentNullException ex = await Assert.ThrowsAsync(() => fixture.Collection.GetAsync((TKey)default!)); Assert.Equal("key", ex.ParamName); } [ConditionalFact] - public async Task GetAsyncReturnsNullForNonExistingKey() + public virtual async Task GetAsyncReturnsNullForNonExistingKey() { TKey key = fixture.GenerateNextKey(); @@ -25,11 +31,11 @@ public async Task GetAsyncReturnsNullForNonExistingKey() } [ConditionalFact] - public Task GetAsyncReturnsInsertedRecord_WithVectors() + public virtual Task GetAsync_WithVectors() => this.GetAsyncReturnsInsertedRecord(includeVectors: true); [ConditionalFact] - public Task GetAsyncReturnsInsertedRecord_WithoutVectors() + public virtual Task GetAsync_WithoutVectors() => this.GetAsyncReturnsInsertedRecord(includeVectors: false); private async Task GetAsyncReturnsInsertedRecord(bool includeVectors) @@ -38,67 +44,53 @@ private async Task GetAsyncReturnsInsertedRecord(bool includeVectors) var received = await fixture.Collection.GetAsync(expectedRecord.Id, new() { IncludeVectors = includeVectors }); - expectedRecord.AssertEqual(received, includeVectors); + expectedRecord.AssertEqual(received, includeVectors, fixture.TestStore.VectorsComparable); } [ConditionalFact] - public Task UpsertAsyncCanInsertNewRecord_WithVectors() - => this.UpsertAsyncCanInsertNewRecord(includeVectors: true); - - [ConditionalFact] - public Task UpsertAsyncCanInsertNewRecord_WithoutVectors() - => this.UpsertAsyncCanInsertNewRecord(includeVectors: false); - - private async Task UpsertAsyncCanInsertNewRecord(bool includeVectors) + public virtual async Task UpsertAsyncCanInsertNewRecord() { var collection = fixture.Collection; TKey expectedKey = fixture.GenerateNextKey(); - SimpleModel inserted = new() + SimpleRecord inserted = new() { Id = expectedKey, Text = "some", Number = 123, - Floats = new ReadOnlyMemory(Enumerable.Repeat(0.1f, SimpleModel.DimensionCount).ToArray()) + Floats = new ReadOnlyMemory(Enumerable.Repeat(0.1f, SimpleRecord.DimensionCount).ToArray()) }; Assert.Null(await collection.GetAsync(expectedKey)); TKey key = await collection.UpsertAsync(inserted); Assert.Equal(expectedKey, key); - var received = await collection.GetAsync(expectedKey, new() { IncludeVectors = includeVectors }); - inserted.AssertEqual(received, includeVectors); + var received = await collection.GetAsync(expectedKey, new() { IncludeVectors = true }); + inserted.AssertEqual(received, includeVectors: true, fixture.TestStore.VectorsComparable); } [ConditionalFact] - public Task UpsertAsyncCanUpdateExistingRecord_WithVectors() - => this.UpsertAsyncCanUpdateExistingRecord(includeVectors: true); - - [ConditionalFact] - public Task UpsertAsyncCanUpdateExistingRecord__WithoutVectors() - => this.UpsertAsyncCanUpdateExistingRecord(includeVectors: false); - - private async Task UpsertAsyncCanUpdateExistingRecord(bool includeVectors) + public virtual async Task UpsertAsyncCanUpdateExistingRecord() { var collection = fixture.Collection; var existingRecord = fixture.TestData[1]; - SimpleModel updated = new() + SimpleRecord updated = new() { Id = existingRecord.Id, Text = "updated", Number = 456, - Floats = new ReadOnlyMemory(Enumerable.Repeat(0.2f, SimpleModel.DimensionCount).ToArray()) + Floats = new ReadOnlyMemory(Enumerable.Repeat(0.25f, SimpleRecord.DimensionCount).ToArray()) }; Assert.NotNull(await collection.GetAsync(existingRecord.Id)); TKey key = await collection.UpsertAsync(updated); Assert.Equal(existingRecord.Id, key); - var received = await collection.GetAsync(existingRecord.Id, new() { IncludeVectors = includeVectors }); - updated.AssertEqual(received, includeVectors); + var received = await collection.GetAsync(existingRecord.Id, new() { IncludeVectors = true }); + updated.AssertEqual(received, includeVectors: true, fixture.TestStore.VectorsComparable); } [ConditionalFact] - public async Task DeleteAsyncDoesNotThrowForNonExistingKey() + public virtual async Task DeleteAsyncDoesNotThrowForNonExistingKey() { TKey key = fixture.GenerateNextKey(); @@ -106,7 +98,7 @@ public async Task DeleteAsyncDoesNotThrowForNonExistingKey() } [ConditionalFact] - public async Task DeleteAsyncDeletesTheRecord() + public virtual async Task DeleteAsyncDeletesTheRecord() { var recordToRemove = fixture.TestData[2]; diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Collections/CollectionConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Collections/CollectionConformanceTests.cs index 16f8679df842..c031add9a1d4 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Collections/CollectionConformanceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Collections/CollectionConformanceTests.cs @@ -8,85 +8,42 @@ namespace VectorDataSpecificationTests.Collections; -public abstract class CollectionConformanceTests(VectorStoreFixture fixture) where TKey : notnull +public abstract class CollectionConformanceTests(VectorStoreFixture fixture) + where TKey : notnull { [ConditionalFact] - public Task DeleteCollectionDoesNotThrowForNonExistingCollection() - => this.DeleteNonExistingCollection>(); - - [ConditionalFact] - public Task DeleteCollectionDoesNotThrowForNonExistingCollection_GenericDataModel() - => this.DeleteNonExistingCollection>(); - - [ConditionalFact] - public Task CreateCollectionCreatesTheCollection() - => this.CreateCollection>(); - - [ConditionalFact] - public Task CreateCollectionCreatesTheCollection_GenericDataModel() - => this.CreateCollection>(); - - [ConditionalFact] - public Task CreateCollectionIfNotExistsCalledMoreThanOnceDoesNotThrow() - => this.CreateCollectionIfNotExistsMoreThanOnce>(); - - [ConditionalFact] - public Task CreateCollectionIfNotExistsCalledMoreThanOnceDoesNotThrow_GenericDataModel() - => this.CreateCollectionIfNotExistsMoreThanOnce>(); - - [ConditionalFact] - public Task CreateCollectionCalledMoreThanOnceThrowsVectorStoreOperationException() - => this.CreateCollectionMoreThanOnce>(); - - [ConditionalFact] - public Task CreateCollectionCalledMoreThanOnceThrowsVectorStoreOperationException_GenericDataModel() - => this.CreateCollectionMoreThanOnce>(); - - private async Task> GetNonExistingCollectionAsync() + public async Task VectorStoreDeleteCollectionDeletesExistingCollection() { - var collectionName = fixture.GetUniqueCollectionName(); - VectorStoreRecordDefinition? definition = null; - if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) - { - definition = new() - { - Properties = - [ - new VectorStoreRecordKeyProperty(nameof(VectorStoreGenericDataModel.Key), typeof(TKey)), - new VectorStoreRecordDataProperty("string", typeof(string)), - new VectorStoreRecordDataProperty("integer", typeof(int)), - new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) - { - Dimensions = 10 - } - ] - }; - } + // Arrange. + var collection = await this.GetNonExistingCollectionAsync>(); + await collection.CreateCollectionAsync(); + Assert.True(await collection.CollectionExistsAsync()); - var collection = fixture.TestStore.DefaultVectorStore.GetCollection(collectionName, definition); + // Act. + await fixture.TestStore.DefaultVectorStore.DeleteCollectionAsync(collection.Name); + // Assert. Assert.False(await collection.CollectionExistsAsync()); - - return collection; } - private async Task DeleteNonExistingCollection() + [ConditionalFact] + public async Task VectorStoreDeleteCollectionDoesNotThrowForNonExistingCollection() { - var collection = await this.GetNonExistingCollectionAsync(); - - await collection.DeleteCollectionAsync(); + await fixture.TestStore.DefaultVectorStore.DeleteCollectionAsync(fixture.GetUniqueCollectionName()); } - private async Task CreateCollection() + [ConditionalFact] + public async Task VectorStoreCollectionExistsReturnsTrueForExistingCollection() { - var collection = await this.GetNonExistingCollectionAsync(); - - await collection.CreateCollectionAsync(); + // Arrange. + var collection = await this.GetNonExistingCollectionAsync>(); try { - Assert.True(await collection.CollectionExistsAsync()); - Assert.True(await fixture.TestStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collection.CollectionName)); + await collection.CreateCollectionAsync(); + + // Act & Assert. + Assert.True(await fixture.TestStore.DefaultVectorStore.CollectionExistsAsync(collection.Name)); } finally { @@ -94,43 +51,142 @@ private async Task CreateCollection() } } - private async Task CreateCollectionIfNotExistsMoreThanOnce() + [ConditionalFact] + public async Task VectorStoreCollectionExistsReturnsFalseForNonExistingCollection() { - var collection = await this.GetNonExistingCollectionAsync(); + Assert.False(await fixture.TestStore.DefaultVectorStore.CollectionExistsAsync(fixture.GetUniqueCollectionName())); + } - await collection.CreateCollectionIfNotExistsAsync(); + [ConditionalTheory] + [MemberData(nameof(UseDynamicMappingData))] + public Task DeleteCollectionDoesNotThrowForNonExistingCollection(bool useDynamicMapping) + { + return useDynamicMapping ? Core>() : Core>(); - try + async Task Core() where TRecord : notnull { - Assert.True(await collection.CollectionExistsAsync()); - Assert.True(await fixture.TestStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collection.CollectionName)); + var collection = await this.GetNonExistingCollectionAsync(); - await collection.CreateCollectionIfNotExistsAsync(); - } - finally - { await collection.DeleteCollectionAsync(); } } - private async Task CreateCollectionMoreThanOnce() + [ConditionalTheory] + [MemberData(nameof(UseDynamicMappingData))] + public Task CreateCollectionCreatesTheCollection(bool useDynamicMapping) { - var collection = await this.GetNonExistingCollectionAsync(); + return useDynamicMapping ? Core>() : Core>(); - await collection.CreateCollectionAsync(); + async Task Core() where TRecord : notnull + { + var collection = await this.GetNonExistingCollectionAsync(); - try + await collection.CreateCollectionAsync(); + + try + { + Assert.True(await collection.CollectionExistsAsync()); + + var collectionMetadata = collection.GetService(typeof(VectorStoreRecordCollectionMetadata)) as VectorStoreRecordCollectionMetadata; + + Assert.NotNull(collectionMetadata); + Assert.NotNull(collectionMetadata.VectorStoreSystemName); + Assert.NotNull(collectionMetadata.CollectionName); + + Assert.True(await fixture.TestStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collectionMetadata.CollectionName)); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } + } + + [ConditionalTheory] + [MemberData(nameof(UseDynamicMappingData))] + public Task CreateCollectionIfNotExistsCalledMoreThanOnceDoesNotThrow(bool useDynamicMapping) + { + return useDynamicMapping ? Core>() : Core>(); + + async Task Core() where TRecord : notnull { - Assert.True(await collection.CollectionExistsAsync()); - Assert.True(await fixture.TestStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collection.CollectionName)); + var collection = await this.GetNonExistingCollectionAsync(); await collection.CreateCollectionIfNotExistsAsync(); - await Assert.ThrowsAsync(() => collection.CreateCollectionAsync()); + try + { + Assert.True(await collection.CollectionExistsAsync()); + + var collectionMetadata = collection.GetService(typeof(VectorStoreRecordCollectionMetadata)) as VectorStoreRecordCollectionMetadata; + + Assert.NotNull(collectionMetadata); + Assert.NotNull(collectionMetadata.VectorStoreSystemName); + Assert.NotNull(collectionMetadata.CollectionName); + + Assert.True(await fixture.TestStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collectionMetadata.CollectionName)); + + await collection.CreateCollectionIfNotExistsAsync(); + } + finally + { + await collection.DeleteCollectionAsync(); + } } - finally + } + + [ConditionalTheory] + [MemberData(nameof(UseDynamicMappingData))] + public Task CreateCollectionCalledMoreThanOnceThrowsVectorStoreOperationException(bool useDynamicMapping) + { + return useDynamicMapping ? Core>() : Core>(); + + async Task Core() where TRecord : notnull { - await collection.DeleteCollectionAsync(); + var collection = await this.GetNonExistingCollectionAsync(); + + await collection.CreateCollectionAsync(); + + try + { + Assert.True(await collection.CollectionExistsAsync()); + + var collectionMetadata = collection.GetService(typeof(VectorStoreRecordCollectionMetadata)) as VectorStoreRecordCollectionMetadata; + + Assert.NotNull(collectionMetadata); + Assert.NotNull(collectionMetadata.VectorStoreSystemName); + Assert.NotNull(collectionMetadata.CollectionName); + + Assert.True(await fixture.TestStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collectionMetadata.CollectionName)); + + await collection.CreateCollectionIfNotExistsAsync(); + + await Assert.ThrowsAsync(() => collection.CreateCollectionAsync()); + } + finally + { + await collection.DeleteCollectionAsync(); + } } } + + protected virtual async Task> GetNonExistingCollectionAsync() where TRecord : notnull + { + var definition = new VectorStoreRecordDefinition() + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(SimpleRecord.Id), typeof(TKey)) { StoragePropertyName = "key" }, + new VectorStoreRecordDataProperty(nameof(SimpleRecord.Text), typeof(string)) { StoragePropertyName = "text" }, + new VectorStoreRecordDataProperty(nameof(SimpleRecord.Number), typeof(int)) { StoragePropertyName = "number" }, + new VectorStoreRecordVectorProperty(nameof(SimpleRecord.Floats), typeof(ReadOnlyMemory), 10) { IndexKind = fixture.TestStore.DefaultIndexKind } + ] + }; + + var collection = fixture.TestStore.DefaultVectorStore.GetCollection(fixture.GetUniqueCollectionName(), definition); + await collection.DeleteCollectionAsync(); + return collection; + } + + public static readonly IEnumerable UseDynamicMappingData = [[false], [true]]; } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/EmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/EmbeddingGenerationTests.cs new file mode 100644 index 000000000000..18824613be0d --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/EmbeddingGenerationTests.cs @@ -0,0 +1,516 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.Properties; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests; + +#pragma warning disable CA1819 // Properties should not return arrays +#pragma warning disable CA2000 // Don't actually need to dispose FakeEmbeddingGenerator +#pragma warning disable CS8605 // Unboxing a possibly null value. + +public abstract class EmbeddingGenerationTests(EmbeddingGenerationTests.Fixture fixture) + where TKey : notnull +{ + #region Search + + [ConditionalFact] + public virtual async Task SearchAsync_with_property_generator() + { + // Property level: embedding generators are defined at all levels. The property generator should take precedence. + var collection = this.GetCollection(storeGenerator: true, collectionGenerator: true, propertyGenerator: true); + + var result = await collection.SearchAsync("[1, 1, 0]", top: 1).SingleAsync(); + + Assert.Equal("Property ([1, 1, 3])", result.Record.Text); + } + + [ConditionalFact] + public virtual async Task SearchAsync_with_property_generator_dynamic() + { + // Property level: embedding generators are defined at all levels. The property generator should take precedence. + var collection = this.GetCollection>(storeGenerator: true, collectionGenerator: true, propertyGenerator: true); + + var result = await collection.SearchAsync("[1, 1, 0]", top: 1).SingleAsync(); + + Assert.Equal("Property ([1, 1, 3])", result.Record[nameof(Record.Text)]); + } + + [ConditionalFact] + public virtual async Task SearchAsync_with_collection_generator() + { + // Collection level: embedding generators are defined at the collection and store level - the collection generator should take precedence. + var collection = this.GetCollection(storeGenerator: true, collectionGenerator: true, propertyGenerator: false); + + var result = await collection.SearchAsync("[1, 1, 0]", top: 1).SingleAsync(); + + Assert.Equal("Collection ([1, 1, 2])", result.Record.Text); + } + + [ConditionalFact] + public virtual async Task SearchAsync_with_store_generator() + { + // Store level: an embedding generator is defined at the store level only. + var collection = this.GetCollection(storeGenerator: true, collectionGenerator: false, propertyGenerator: false); + + var result = await collection.SearchAsync("[1, 1, 0]", top: 1).SingleAsync(); + + Assert.Equal("Store ([1, 1, 1])", result.Record.Text); + } + + [ConditionalFact] + public virtual async Task SearchAsync_with_store_dependency_injection() + { + foreach (var registrationDelegate in fixture.DependencyInjectionStoreRegistrationDelegates) + { + IServiceCollection serviceCollection = new ServiceCollection(); + + serviceCollection.AddSingleton(new FakeEmbeddingGenerator(replaceLast: 1)); + registrationDelegate(serviceCollection); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + var vectorStore = serviceProvider.GetRequiredService(); + var collection = vectorStore.GetCollection(fixture.CollectionName, fixture.GetRecordDefinition()); + + var result = await collection.SearchAsync("[1, 1, 0]", top: 1).SingleAsync(); + + Assert.Equal("Store ([1, 1, 1])", result.Record.Text); + } + } + + [ConditionalFact] + public virtual async Task SearchAsync_with_collection_dependency_injection() + { + foreach (var registrationDelegate in fixture.DependencyInjectionCollectionRegistrationDelegates) + { + IServiceCollection serviceCollection = new ServiceCollection(); + + serviceCollection.AddSingleton(new FakeEmbeddingGenerator(replaceLast: 1)); + registrationDelegate(serviceCollection); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + var collection = serviceProvider.GetRequiredService>(); + + var result = await collection.SearchAsync("[1, 1, 0]", top: 1).SingleAsync(); + + Assert.Equal("Store ([1, 1, 1])", result.Record.Text); + } + } + + [ConditionalFact] + public virtual async Task SearchAsync_with_custom_input_type() + { + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = fixture.GetRecordDefinition().Properties + .Select(p => p is VectorStoreRecordVectorProperty vectorProperty + ? new VectorStoreRecordVectorProperty(nameof(Record.Embedding), dimensions: 3) + { + DistanceFunction = fixture.DefaultDistanceFunction, + IndexKind = fixture.DefaultIndexKind + } + : p) + .ToList() + }; + + var collection = fixture.GetCollection( + fixture.CreateVectorStore(new FakeCustomerEmbeddingGenerator([1, 1, 1])), + fixture.CollectionName, + recordDefinition); + + var result = await collection.SearchAsync(new Customer(), top: 1).SingleAsync(); + + Assert.Equal("Store ([1, 1, 1])", result.Record.Text); + } + + [ConditionalFact] + public virtual async Task SearchAsync_without_generator_throws() + { + // The database doesn't support embedding generation, and no client-side generator has been configured at any level, + // so SearchAsync should throw. + var collection = fixture.GetCollection(fixture.TestStore.DefaultVectorStore, fixture.CollectionName + "WithoutGenerator"); + + var exception = await Assert.ThrowsAsync(() => collection.SearchAsync("foo", top: 1).ToListAsync().AsTask()); + + Assert.Equal(VectorDataStrings.NoEmbeddingGeneratorWasConfiguredForSearch, exception.Message); + } + + public class RawRecord + { + [VectorStoreRecordKey] + public TKey Key { get; set; } = default!; + [VectorStoreRecordVector(Dimensions: 3)] + public ReadOnlyMemory Embedding { get; set; } + } + + [ConditionalFact] + public virtual async Task SearchAsync_with_embedding_argument_throws() + { + var collection = this.GetCollection(storeGenerator: true, collectionGenerator: true, propertyGenerator: true); + + var exception = await Assert.ThrowsAsync(() => collection.SearchAsync(new ReadOnlyMemory([1, 2, 3]), top: 1).ToListAsync().AsTask()); + + Assert.Equal(VectorDataStrings.EmbeddingTypePassedToSearchAsync, exception.Message); + } + + [ConditionalFact] + public virtual async Task SearchAsync_with_incompatible_generator_throws() + { + var collection = this.GetCollection(storeGenerator: true, collectionGenerator: true, propertyGenerator: true); + + // We have a generator configured for string, not int. + var exception = await Assert.ThrowsAsync(() => collection.SearchAsync(8, top: 1).ToListAsync().AsTask()); + + Assert.Equal($"An input of type 'Int32' was provided, but an incompatible embedding generator of type '{nameof(FakeEmbeddingGenerator)}' was configured.", exception.Message); + } + + #endregion Search + + #region Upsert + + [ConditionalFact] + public virtual async Task UpsertAsync() + { + var counter = fixture.GenerateNextCounter(); + + var record = new Record + { + Key = fixture.GenerateNextKey(), + Embedding = "[100, 1, 0]", + Counter = counter, + Text = nameof(UpsertAsync) + }; + + // Property level: embedding generators are defined at all levels. The property generator should take precedence. + var collection = this.GetCollection(storeGenerator: true, collectionGenerator: true, propertyGenerator: true); + + await collection.UpsertAsync(record).ConfigureAwait(false); + + await fixture.TestStore.WaitForDataAsync(collection, 1, filter: r => r.Counter == counter); + + var result = await collection.SearchEmbeddingAsync(new ReadOnlyMemory([100, 1, 3]), top: 1).SingleAsync(); + Assert.Equal(counter, result.Record.Counter); + } + + [ConditionalFact] + public virtual async Task UpsertAsync_dynamic() + { + var counter = fixture.GenerateNextCounter(); + + var record = new Dictionary + { + [nameof(Record.Key)] = fixture.GenerateNextKey(), + [nameof(Record.Embedding)] = "[200, 1, 0]", + [nameof(Record.Counter)] = counter, + [nameof(Record.Text)] = nameof(UpsertAsync_dynamic) + }; + + // Property level: embedding generators are defined at all levels. The property generator should take precedence. + var collection = this.GetCollection>(storeGenerator: true, collectionGenerator: true, propertyGenerator: true); + + await collection.UpsertAsync(record).ConfigureAwait(false); + + await fixture.TestStore.WaitForDataAsync(collection, 1, filter: r => (int)r[nameof(Record.Counter)] == counter); + + var result = await collection.SearchEmbeddingAsync(new ReadOnlyMemory([200, 1, 3]), top: 1).SingleAsync(); + Assert.Equal(counter, result.Record[nameof(Record.Counter)]); + } + + [ConditionalFact] + public virtual async Task UpsertAsync_batch() + { + var (counter1, counter2) = (fixture.GenerateNextCounter(), fixture.GenerateNextCounter()); + + Record[] records = + [ + new() + { + Key = fixture.GenerateNextKey(), + Embedding = "[300, 1, 0]", + Counter = counter1, + Text = nameof(UpsertAsync_batch) + "1" + }, + new() + { + Key = fixture.GenerateNextKey(), + Embedding = "[400, 1, 0]", + Counter = counter2, + Text = nameof(UpsertAsync_batch) + "2" + } + ]; + + var collection = this.GetCollection(storeGenerator: true, collectionGenerator: true, propertyGenerator: true); + + await collection.UpsertAsync(records).ConfigureAwait(false); + + await fixture.TestStore.WaitForDataAsync(collection, 2, filter: r => (int)r.Counter == counter1 || (int)r.Counter == counter2); + + var result = await collection.SearchEmbeddingAsync(new ReadOnlyMemory([300, 1, 3]), top: 1).SingleAsync(); + Assert.Equal(counter1, result.Record.Counter); + + result = await collection.SearchEmbeddingAsync(new ReadOnlyMemory([400, 1, 3]), top: 1).SingleAsync(); + Assert.Equal(counter2, result.Record.Counter); + } + + [ConditionalFact] + public virtual async Task UpsertAsync_batch_dynamic() + { + var (counter1, counter2) = (fixture.GenerateNextCounter(), fixture.GenerateNextCounter()); + + Dictionary[] records = + [ + new() + { + [nameof(Record.Key)] = fixture.GenerateNextKey(), + [nameof(Record.Embedding)] = "[500, 1, 0]", + [nameof(Record.Counter)] = counter1, + [nameof(Record.Text)] = nameof(UpsertAsync_batch_dynamic) + "1" + }, + new() + { + [nameof(Record.Key)] = fixture.GenerateNextKey(), + [nameof(Record.Embedding)] = "[600, 1, 0]", + [nameof(Record.Counter)] = counter2, + [nameof(Record.Text)] = nameof(UpsertAsync_batch_dynamic) + "2" + } + ]; + + var collection = this.GetCollection>(storeGenerator: true, collectionGenerator: true, propertyGenerator: true); + + await collection.UpsertAsync(records).ConfigureAwait(false); + + await fixture.TestStore.WaitForDataAsync(collection, 2, filter: r => (int)r[nameof(Record.Counter)] == counter1 || (int)r[nameof(Record.Counter)] == counter2); + + var result = await collection.SearchEmbeddingAsync(new ReadOnlyMemory([500, 1, 3]), top: 1).SingleAsync(); + Assert.Equal(counter1, result.Record[nameof(Record.Counter)]); + + result = await collection.SearchEmbeddingAsync(new ReadOnlyMemory([600, 1, 3]), top: 1).SingleAsync(); + Assert.Equal(counter2, result.Record[nameof(Record.Counter)]); + } + + #endregion Upsert + + #region IncludeVectors + + [ConditionalFact] + public virtual async Task SearchAsync_with_IncludeVectors_throws() + { + var collection = this.GetCollection(storeGenerator: true, collectionGenerator: true, propertyGenerator: true); + + var exception = await Assert.ThrowsAsync(() => collection.SearchAsync("[1, 0, 0]", top: 1, new() { IncludeVectors = true }).ToListAsync().AsTask()); + + Assert.Equal("When an embedding generator is configured, `Include Vectors` cannot be enabled.", exception.Message); + } + + [ConditionalFact] + public virtual async Task GetAsync_with_IncludeVectors_throws() + { + var collection = this.GetCollection(storeGenerator: true, collectionGenerator: true, propertyGenerator: true); + + var exception = await Assert.ThrowsAsync(() => collection.GetAsync(fixture.TestData[0].Key, new() { IncludeVectors = true })); + + Assert.Equal("When an embedding generator is configured, `Include Vectors` cannot be enabled.", exception.Message); + } + + [ConditionalFact] + public virtual async Task GetAsync_enumerable_with_IncludeVectors_throws() + { + var collection = this.GetCollection(storeGenerator: true, collectionGenerator: true, propertyGenerator: true); + + var exception = await Assert.ThrowsAsync(() => + collection.GetAsync( + [fixture.TestData[0].Key, fixture.TestData[1].Key], + new() { IncludeVectors = true }) + .ToListAsync().AsTask()); + + Assert.Equal("When an embedding generator is configured, `Include Vectors` cannot be enabled.", exception.Message); + } + + #endregion IncludeVectors + + #region Support + + public class Record + { + public TKey Key { get; set; } = default!; + public string? Embedding { get; set; } + + public int Counter { get; set; } + public string? Text { get; set; } + } + + public class RecordWithAttributes + { + [VectorStoreRecordKey] + public TKey Key { get; set; } = default!; + + [VectorStoreRecordVector(Dimensions: 3)] + public string? Embedding { get; set; } + + [VectorStoreRecordData(IsIndexed = true)] + public int Counter { get; set; } + + [VectorStoreRecordData] + public string? Text { get; set; } + } + + public class RecordWithCustomerVectorProperty + { + public TKey Key { get; set; } = default!; + public Customer? Embedding { get; set; } + + public int Counter { get; set; } + public string? Text { get; set; } + } + + public class Customer + { + public string? FirstName { get; set; } + public string? LastName { get; set; } + } + + private IVectorStoreRecordCollection GetCollection( + bool storeGenerator = false, + bool collectionGenerator = false, + bool propertyGenerator = false) + where TRecord : notnull + { + var properties = fixture.GetRecordDefinition().Properties; + + properties = properties + .Select(p => p is VectorStoreRecordVectorProperty vectorProperty && propertyGenerator + ? new VectorStoreRecordVectorProperty(vectorProperty) { EmbeddingGenerator = new FakeEmbeddingGenerator(replaceLast: 3) } + : p) + .ToList(); + + var recordDefinition = new VectorStoreRecordDefinition + { + EmbeddingGenerator = collectionGenerator ? new FakeEmbeddingGenerator(replaceLast: 2) : null, + Properties = properties + }; + + return fixture.GetCollection( + fixture.CreateVectorStore(storeGenerator ? new FakeEmbeddingGenerator(replaceLast: 1) : null), + fixture.CollectionName, + recordDefinition); + } + + public abstract class Fixture : VectorStoreCollectionFixture + { + private int _counter; + + public override string CollectionName => "EmbeddingGenerationTests"; + + public override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(Record.Key), typeof(TKey)), + new VectorStoreRecordVectorProperty(nameof(Record.Embedding), typeof(string), dimensions: 3) + { + DistanceFunction = this.DefaultDistanceFunction, + IndexKind = this.DefaultIndexKind + }, + + new VectorStoreRecordDataProperty(nameof(Record.Counter), typeof(int)) { IsIndexed = true }, + new VectorStoreRecordDataProperty(nameof(Record.Text), typeof(string)) + ], + EmbeddingGenerator = new FakeEmbeddingGenerator() + }; + + protected override List BuildTestData() => + [ + new() + { + Key = this.GenerateNextKey(), + Embedding = "[1, 1, 1]", + Counter = this.GenerateNextCounter(), + Text = "Store ([1, 1, 1])" + }, + new() + { + Key = this.GenerateNextKey(), + Embedding = "[1, 1, 2]", + Counter = this.GenerateNextCounter(), + Text = "Collection ([1, 1, 2])" + }, + new() + { + Key = this.GenerateNextKey(), + Embedding = "[1, 1, 3]", + Counter = this.GenerateNextCounter(), + Text = "Property ([1, 1, 3])" + } + ]; + + public virtual IVectorStoreRecordCollection GetCollection( + IVectorStore vectorStore, + string collectionName, + VectorStoreRecordDefinition? recordDefinition = null) + where TRecord : notnull + => vectorStore.GetCollection(collectionName, recordDefinition); + + public abstract IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator = null); + + public abstract Func[] DependencyInjectionStoreRegistrationDelegates { get; } + public abstract Func[] DependencyInjectionCollectionRegistrationDelegates { get; } + + public virtual int GenerateNextCounter() + => Interlocked.Increment(ref this._counter); + } + + private sealed class FakeEmbeddingGenerator(int? replaceLast = null) : IEmbeddingGenerator> + { + public Task>> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default) + { + var results = new GeneratedEmbeddings>(); + + foreach (var value in values) + { + var vector = value.TrimStart('[').TrimEnd(']').Split(',').Select(s => float.Parse(s.Trim())).ToArray(); + + if (replaceLast is not null) + { + vector[vector.Length - 1] = replaceLast.Value; + } + + results.Add(new Embedding(vector)); + } + + return Task.FromResult(results); + } + + public object? GetService(Type serviceType, object? serviceKey = null) + => null; + + public void Dispose() + { + } + } + + private sealed class FakeCustomerEmbeddingGenerator(float[] embedding) : IEmbeddingGenerator> + { + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + => Task.FromResult(new GeneratedEmbeddings> { new(embedding) }); + + public object? GetService(Type serviceType, object? serviceKey = null) + => null; + + public void Dispose() + { + } + } + + #endregion Support +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs index dd03c1b1bda7..f1d06ee74aaf 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs @@ -6,6 +6,10 @@ using VectorDataSpecificationTests.Xunit; using Xunit; +#pragma warning disable CS8605 // Unboxing a possibly null value. +#pragma warning disable CS0252 // Possible unintended reference comparison; left hand side needs cast +#pragma warning disable RCS1098 // Constant values should be placed on right side of comparisons + namespace VectorDataSpecificationTests.Filter; public abstract class BasicFilterTests(BasicFilterTests.Fixture fixture) @@ -15,71 +19,103 @@ public abstract class BasicFilterTests(BasicFilterTests.Fixture fixt [ConditionalFact] public virtual Task Equal_with_int() - => this.TestFilterAsync(r => r.Int == 8); + => this.TestFilterAsync( + r => r.Int == 8, + r => (int)r["Int"] == 8); [ConditionalFact] public virtual Task Equal_with_string() - => this.TestFilterAsync(r => r.String == "foo"); + => this.TestFilterAsync( + r => r.String == "foo", + r => r["String"] == "foo"); [ConditionalFact] public virtual Task Equal_with_string_containing_special_characters() - => this.TestFilterAsync(r => r.String == """with some special"characters'and\stuff"""); + => this.TestFilterAsync( + r => r.String == """with some special"characters'and\stuff""", + r => r["String"] == """with some special"characters'and\stuff"""); [ConditionalFact] public virtual Task Equal_with_string_is_not_Contains() - => this.TestFilterAsync(r => r.String == "some", expectZeroResults: true); + => this.TestFilterAsync( + r => r.String == "some", + r => r["String"] == "some", + expectZeroResults: true); [ConditionalFact] public virtual Task Equal_reversed() - => this.TestFilterAsync(r => 8 == r.Int); + => this.TestFilterAsync( + r => 8 == r.Int, + r => 8 == (int)r["Int"]); [ConditionalFact] public virtual Task Equal_with_null_reference_type() - => this.TestFilterAsync(r => r.String == null); + => this.TestFilterAsync( + r => r.String == null, + r => r["String"] == null); [ConditionalFact] public virtual Task Equal_with_null_captured() { string? s = null; - return this.TestFilterAsync(r => r.String == s); + return this.TestFilterAsync( + r => r.String == s, + r => r["String"] == s); } [ConditionalFact] public virtual Task NotEqual_with_int() - => this.TestFilterAsync(r => r.Int != 8); + => this.TestFilterAsync( + r => r.Int != 8, + r => (int)r["Int"] != 8); [ConditionalFact] public virtual Task NotEqual_with_string() - => this.TestFilterAsync(r => r.String != "foo"); + => this.TestFilterAsync( + r => r.String != "foo", + r => r["String"] != "foo"); [ConditionalFact] public virtual Task NotEqual_reversed() - => this.TestFilterAsync(r => r.Int != 8); + => this.TestFilterAsync( + r => r.Int != 8, + r => (int)r["Int"] != 8); [ConditionalFact] public virtual Task NotEqual_with_null_reference_type() - => this.TestFilterAsync(r => r.String != null); + => this.TestFilterAsync( + r => r.String != null, + r => r["String"] != null); [ConditionalFact] public virtual Task NotEqual_with_null_captured() { string? s = null; - return this.TestFilterAsync(r => r.String != s); + return this.TestFilterAsync( + r => r.String != s, + r => r["String"] != s); } [ConditionalFact] public virtual Task Bool() - => this.TestFilterAsync(r => r.Bool); + => this.TestFilterAsync( + r => r.Bool, + r => (bool)r["Bool"]); [ConditionalFact] public virtual Task Bool_And_Bool() - => this.TestFilterAsync(r => r.Bool && r.Bool); + => this.TestFilterAsync( + r => r.Bool && r.Bool, + r => (bool)r["Bool"] && (bool)r["Bool"]); [ConditionalFact] public virtual Task Bool_Or_Not_Bool() - => this.TestFilterAsync(r => r.Bool || !r.Bool, expectAllResults: true); + => this.TestFilterAsync( + r => r.Bool || !r.Bool, + r => (bool)r["Bool"] || !(bool)r["Bool"], + expectAllResults: true); #endregion Equality @@ -87,19 +123,27 @@ public virtual Task Bool_Or_Not_Bool() [ConditionalFact] public virtual Task GreaterThan_with_int() - => this.TestFilterAsync(r => r.Int > 9); + => this.TestFilterAsync( + r => r.Int > 9, + r => (int)r["Int"] > 9); [ConditionalFact] public virtual Task GreaterThanOrEqual_with_int() - => this.TestFilterAsync(r => r.Int >= 9); + => this.TestFilterAsync( + r => r.Int >= 9, + r => (int)r["Int"] >= 9); [ConditionalFact] public virtual Task LessThan_with_int() - => this.TestFilterAsync(r => r.Int < 10); + => this.TestFilterAsync( + r => r.Int < 10, + r => (int)r["Int"] < 10); [ConditionalFact] public virtual Task LessThanOrEqual_with_int() - => this.TestFilterAsync(r => r.Int <= 10); + => this.TestFilterAsync( + r => r.Int <= 10, + r => (int)r["Int"] <= 10); #endregion Comparison @@ -107,49 +151,71 @@ public virtual Task LessThanOrEqual_with_int() [ConditionalFact] public virtual Task And() - => this.TestFilterAsync(r => r.Int == 8 && r.String == "foo"); + => this.TestFilterAsync( + r => r.Int == 8 && r.String == "foo", + r => (int)r["Int"] == 8 && r["String"] == "foo"); [ConditionalFact] public virtual Task Or() - => this.TestFilterAsync(r => r.Int == 8 || r.String == "foo"); + => this.TestFilterAsync( + r => r.Int == 8 || r.String == "foo", + r => (int)r["Int"] == 8 || r["String"] == "foo"); [ConditionalFact] public virtual Task And_within_And() - => this.TestFilterAsync(r => (r.Int == 8 && r.String == "foo") && r.Int2 == 80); + => this.TestFilterAsync( + r => (r.Int == 8 && r.String == "foo") && r.Int2 == 80, + r => ((int)r["Int"] == 8 && r["String"] == "foo") && (int)r["Int2"] == 80); [ConditionalFact] public virtual Task And_within_Or() - => this.TestFilterAsync(r => (r.Int == 8 && r.String == "foo") || r.Int2 == 100); + => this.TestFilterAsync( + r => (r.Int == 8 && r.String == "foo") || r.Int2 == 100, + r => ((int)r["Int"] == 8 && r["String"] == "foo") || (int)r["Int2"] == 100); [ConditionalFact] public virtual Task Or_within_And() - => this.TestFilterAsync(r => (r.Int == 8 || r.Int == 9) && r.String == "foo"); + => this.TestFilterAsync( + r => (r.Int == 8 || r.Int == 9) && r.String == "foo", + r => ((int)r["Int"] == 8 || (int)r["Int"] == 9) && r["String"] == "foo"); [ConditionalFact] public virtual Task Not_over_Equal() // ReSharper disable once NegativeEqualityExpression - => this.TestFilterAsync(r => !(r.Int == 8)); + => this.TestFilterAsync( + r => !(r.Int == 8), + r => !((int)r["Int"] == 8)); [ConditionalFact] public virtual Task Not_over_NotEqual() // ReSharper disable once NegativeEqualityExpression - => this.TestFilterAsync(r => !(r.Int != 8)); + => this.TestFilterAsync( + r => !(r.Int != 8), + r => !((int)r["Int"] != 8)); [ConditionalFact] public virtual Task Not_over_And() - => this.TestFilterAsync(r => !(r.Int == 8 && r.String == "foo")); + => this.TestFilterAsync( + r => !(r.Int == 8 && r.String == "foo"), + r => !((int)r["Int"] == 8 && r["String"] == "foo")); [ConditionalFact] public virtual Task Not_over_Or() - => this.TestFilterAsync(r => !(r.Int == 8 || r.String == "foo")); + => this.TestFilterAsync( + r => !(r.Int == 8 || r.String == "foo"), + r => !((int)r["Int"] == 8 || r["String"] == "foo")); [ConditionalFact] public virtual Task Not_over_bool() - => this.TestFilterAsync(r => !r.Bool); + => this.TestFilterAsync( + r => !r.Bool, + r => !(bool)r["Bool"]); [ConditionalFact] public virtual Task Not_over_bool_And_Comparison() - => this.TestFilterAsync(r => !r.Bool && r.Int != int.MaxValue); + => this.TestFilterAsync( + r => !r.Bool && r.Int != int.MaxValue, + r => !(bool)r["Bool"] && (int)r["Int"] != int.MaxValue); #endregion Logical operators @@ -157,30 +223,42 @@ public virtual Task Not_over_bool_And_Comparison() [ConditionalFact] public virtual Task Contains_over_field_string_array() - => this.TestFilterAsync(r => r.StringArray.Contains("x")); + => this.TestFilterAsync( + r => r.StringArray.Contains("x"), + r => ((string[])r["StringArray"]!).Contains("x")); [ConditionalFact] public virtual Task Contains_over_field_string_List() - => this.TestFilterAsync(r => r.StringList.Contains("x")); + => this.TestFilterAsync( + r => r.StringList.Contains("x"), + r => ((List)r["StringList"]!).Contains("x")); [ConditionalFact] public virtual Task Contains_over_inline_int_array() - => this.TestFilterAsync(r => new[] { 8, 10 }.Contains(r.Int)); + => this.TestFilterAsync( + r => new[] { 8, 10 }.Contains(r.Int), + r => new[] { 8, 10 }.Contains((int)r["Int"])); [ConditionalFact] public virtual Task Contains_over_inline_string_array() - => this.TestFilterAsync(r => new[] { "foo", "baz", "unknown" }.Contains(r.String)); + => this.TestFilterAsync( + r => new[] { "foo", "baz", "unknown" }.Contains(r.String), + r => new[] { "foo", "baz", "unknown" }.Contains(r["String"])); [ConditionalFact] public virtual Task Contains_over_inline_string_array_with_weird_chars() - => this.TestFilterAsync(r => new[] { "foo", "baz", "un , ' \"" }.Contains(r.String)); + => this.TestFilterAsync( + r => new[] { "foo", "baz", "un , ' \"" }.Contains(r.String), + r => new[] { "foo", "baz", "un , ' \"" }.Contains(r["String"])); [ConditionalFact] public virtual Task Contains_over_captured_string_array() { var array = new[] { "foo", "baz", "unknown" }; - return this.TestFilterAsync(r => array.Contains(r.String)); + return this.TestFilterAsync( + r => array.Contains(r.String), + r => array.Contains(r["String"])); } #endregion Contains @@ -191,7 +269,9 @@ public virtual Task Captured_variable() // ReSharper disable once ConvertToConstant.Local var i = 8; - return this.TestFilterAsync(r => r.Int == i); + return this.TestFilterAsync( + r => r.Int == i, + r => (int)r["Int"] == i); } #region Legacy filter support @@ -226,8 +306,25 @@ public virtual Task Legacy_AnyTagEqualTo_List() #endregion Legacy filter support + protected virtual async Task> GetRecords( + Expression> filter, int top, ReadOnlyMemory vector) + => await fixture.Collection.SearchEmbeddingAsync( + vector, + top: top, + new() { Filter = filter }) + .Select(r => r.Record).OrderBy(r => r.Key).ToListAsync(); + + protected virtual async Task>> GetDynamicRecords( + Expression, bool>> dynamicFilter, int top, ReadOnlyMemory vector) + => await fixture.DynamicCollection.SearchEmbeddingAsync( + vector, + top: top, + new() { Filter = dynamicFilter }) + .Select(r => r.Record).OrderBy(r => r[nameof(FilterRecord.Key)]).ToListAsync(); + protected virtual async Task TestFilterAsync( Expression> filter, + Expression, bool>> dynamicFilter, bool expectZeroResults = false, bool expectAllResults = false) { @@ -243,20 +340,34 @@ protected virtual async Task TestFilterAsync( Assert.Fail("The test returns all results, and so is unreliable"); } - var results = await fixture.Collection.VectorizedSearchAsync( - new ReadOnlyMemory([1, 2, 3]), - new() - { - Filter = filter, - Top = fixture.TestData.Count - }); + // Execute the query against the vector store, once using the strongly typed filter + // and once using the dynamic filter + var actual = await this.GetRecords(filter, fixture.TestData.Count, new ReadOnlyMemory([1, 2, 3])); - var actual = await results.Results.Select(r => r.Record).OrderBy(r => r.Key).ToListAsync(); + if (actual.Count != expected.Count) + { + Assert.Fail($"Expected {expected.Count} results, but got {actual.Count}"); + } - Assert.Equal(expected, actual, (e, a) => - e.Int == a.Int && - e.String == a.String && - e.Int2 == a.Int2); + foreach (var (e, a) in expected.Zip(actual, (e, a) => (e, a))) + { + fixture.AssertEqualFilterRecord(e, a); + } + + if (fixture.TestDynamic) + { + var dynamicActual = await this.GetDynamicRecords(dynamicFilter, fixture.TestData.Count, new ReadOnlyMemory([1, 2, 3])); + + if (dynamicActual.Count != expected.Count) + { + Assert.Fail($"Expected {expected.Count} results, but got {actual.Count}"); + } + + foreach (var (e, a) in expected.Zip(dynamicActual, (e, a) => (e, a))) + { + fixture.AssertEqualDynamic(e, a); + } + } } [Obsolete("Legacy filter support")] @@ -278,20 +389,19 @@ protected virtual async Task TestLegacyFilterAsync( Assert.Fail("The test returns all results, and so is unreliable"); } - var results = await fixture.Collection.VectorizedSearchAsync( - new ReadOnlyMemory([1, 2, 3]), - new() - { - OldFilter = legacyFilter, - Top = fixture.TestData.Count - }); - - var actual = await results.Results.Select(r => r.Record).OrderBy(r => r.Key).ToListAsync(); + var actual = await fixture.Collection.VectorizedSearchAsync( + new ReadOnlyMemory([1, 2, 3]), + top: fixture.TestData.Count, + new() + { + OldFilter = legacyFilter + }) + .Select(r => r.Record).OrderBy(r => r.Key).ToListAsync(); - Assert.Equal(expected, actual, (e, a) => - e.Int == a.Int && - e.String == a.String && - e.Int2 == a.Int2); + foreach (var (e, a) in expected.Zip(actual, (e, a) => (e, a))) + { + fixture.AssertEqualFilterRecord(e, a); + } } #pragma warning disable CS1819 // Properties should not return arrays @@ -313,35 +423,50 @@ public class FilterRecord public abstract class Fixture : VectorStoreCollectionFixture { - protected override string CollectionName => "FilterTests"; + public override string CollectionName => "FilterTests"; + + protected virtual ReadOnlyMemory GetVector(int count) + // All records have the same vector - this fixture is about testing criteria filtering only + // Derived types may override this to provide different vectors for different records. + => new(Enumerable.Range(1, count).Select(i => (float)i).ToArray()); + + public virtual IVectorStoreRecordCollection> DynamicCollection { get; protected set; } = null!; + + public virtual bool TestDynamic => true; + + public override async Task InitializeAsync() + { + await base.InitializeAsync(); - protected override VectorStoreRecordDefinition GetRecordDefinition() + if (this.TestDynamic) + { + this.DynamicCollection = this.TestStore.DefaultVectorStore.GetCollection>(this.CollectionName, this.GetRecordDefinition()); + } + } + + public override VectorStoreRecordDefinition GetRecordDefinition() => new() { Properties = [ new VectorStoreRecordKeyProperty(nameof(FilterRecord.Key), typeof(TKey)), - new VectorStoreRecordVectorProperty(nameof(FilterRecord.Vector), typeof(ReadOnlyMemory?)) + new VectorStoreRecordVectorProperty(nameof(FilterRecord.Vector), typeof(ReadOnlyMemory?), 3) { - Dimensions = 3, DistanceFunction = this.DistanceFunction, IndexKind = this.IndexKind }, - new VectorStoreRecordDataProperty(nameof(FilterRecord.Int), typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty(nameof(FilterRecord.String), typeof(string)) { IsFilterable = true }, - new VectorStoreRecordDataProperty(nameof(FilterRecord.Bool), typeof(bool)) { IsFilterable = true }, - new VectorStoreRecordDataProperty(nameof(FilterRecord.Int2), typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty(nameof(FilterRecord.StringArray), typeof(string[])) { IsFilterable = true }, - new VectorStoreRecordDataProperty(nameof(FilterRecord.StringList), typeof(List)) { IsFilterable = true } + new VectorStoreRecordDataProperty(nameof(FilterRecord.Int), typeof(int)) { IsIndexed = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.String), typeof(string)) { IsIndexed = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.Bool), typeof(bool)) { IsIndexed = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.Int2), typeof(int)) { IsIndexed = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.StringArray), typeof(string[])) { IsIndexed = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.StringList), typeof(List)) { IsIndexed = true } ] }; protected override List BuildTestData() { - // All records have the same vector - this fixture is about testing criteria filtering only - var vector = new ReadOnlyMemory([1, 2, 3]); - return [ new() @@ -353,7 +478,7 @@ protected override List BuildTestData() Int2 = 80, StringArray = ["x", "y"], StringList = ["x", "y"], - Vector = vector + Vector = this.GetVector(3) }, new() { @@ -364,7 +489,7 @@ protected override List BuildTestData() Int2 = 90, StringArray = ["a", "b"], StringList = ["a", "b"], - Vector = vector + Vector = this.GetVector(3) }, new() { @@ -375,7 +500,7 @@ protected override List BuildTestData() Int2 = 9, StringArray = ["x"], StringList = ["x"], - Vector = vector + Vector = this.GetVector(3) }, new() { @@ -386,7 +511,7 @@ protected override List BuildTestData() Int2 = 100, StringArray = ["x", "y", "z"], StringList = ["x", "y", "z"], - Vector = vector + Vector = this.GetVector(3) }, new() { @@ -397,11 +522,61 @@ protected override List BuildTestData() Int2 = 101, StringArray = ["y", "z"], StringList = ["y", "z"], - Vector = vector + Vector = this.GetVector(3) } ]; } + public virtual void AssertEqualFilterRecord(FilterRecord x, FilterRecord y) + { + var definitionProperties = this.GetRecordDefinition().Properties; + + Assert.Equal(x.Key, y.Key); + Assert.Equal(x.Int, y.Int); + Assert.Equal(x.String, y.String); + Assert.Equal(x.Int2, y.Int2); + + if (definitionProperties.Any(p => p.DataModelPropertyName == nameof(FilterRecord.Bool))) + { + Assert.Equal(x.Bool, y.Bool); + } + + if (definitionProperties.Any(p => p.DataModelPropertyName == nameof(FilterRecord.StringArray))) + { + Assert.Equivalent(x.StringArray, y.StringArray); + } + + if (definitionProperties.Any(p => p.DataModelPropertyName == nameof(FilterRecord.StringList))) + { + Assert.Equivalent(x.StringList, y.StringList); + } + } + + public virtual void AssertEqualDynamic(FilterRecord x, Dictionary y) + { + var definitionProperties = this.GetRecordDefinition().Properties; + + Assert.Equal(x.Key, y["Key"]); + Assert.Equal(x.Int, y["Int"]); + Assert.Equal(x.String, y["String"]); + Assert.Equal(x.Int2, y["Int2"]); + + if (definitionProperties.Any(p => p.DataModelPropertyName == nameof(FilterRecord.Bool))) + { + Assert.Equal(x.Bool, y["Bool"]); + } + + if (definitionProperties.Any(p => p.DataModelPropertyName == nameof(FilterRecord.StringArray))) + { + Assert.Equivalent(x.StringArray, y["StringArray"]); + } + + if (definitionProperties.Any(p => p.DataModelPropertyName == nameof(FilterRecord.StringList))) + { + Assert.Equivalent(x.StringList, y["StringList"]); + } + } + // In some databases (Azure AI Search), the data shows up but the filtering index isn't yet updated, // so filtered searches show empty results. Add a filter to the seed data check below. protected override Task WaitForDataAsync() diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicQueryTests.cs new file mode 100644 index 000000000000..1a8c6649cd87 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicQueryTests.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq.Expressions; + +namespace VectorDataSpecificationTests.Filter; + +public abstract class BasicQueryTests(BasicQueryTests.QueryFixture fixture) + : BasicFilterTests(fixture) where TKey : notnull +{ + protected override async Task> GetRecords(Expression> filter, int top, ReadOnlyMemory vector) + => (await fixture.Collection.GetAsync(filter, top).ToListAsync()).OrderBy(r => r.Key).ToList(); + + protected override async Task>> GetDynamicRecords(Expression, bool>> dynamicFilter, int top, ReadOnlyMemory vector) + => (await fixture.DynamicCollection.GetAsync(dynamicFilter, top).ToListAsync()).OrderBy(r => r[nameof(FilterRecord.Key)]!).ToList(); + + [Obsolete("Not used by derived types")] + public sealed override Task Legacy_And() => Task.CompletedTask; + + [Obsolete("Not used by derived types")] + public sealed override Task Legacy_equality() => Task.CompletedTask; + + [Obsolete("Not used by derived types")] + public sealed override Task Legacy_AnyTagEqualTo_array() => Task.CompletedTask; + + [Obsolete("Not used by derived types")] + public sealed override Task Legacy_AnyTagEqualTo_List() => Task.CompletedTask; + + public abstract class QueryFixture : BasicFilterTests.Fixture + { + private static readonly Random s_random = new(); + + public override string CollectionName => "QueryTests"; + + /// + /// Use random vectors to make sure that the values don't matter for GetAsync. + /// + protected override ReadOnlyMemory GetVector(int count) +#pragma warning disable CA5394 // Do not use insecure randomness + => new(Enumerable.Range(0, count).Select(_ => (float)s_random.NextDouble()).ToArray()); +#pragma warning restore CA5394 // Do not use insecure randomness + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/HybridSearch/KeywordVectorizedHybridSearchComplianceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/HybridSearch/KeywordVectorizedHybridSearchComplianceTests.cs index c25bb065ba74..b89a65ad6bfa 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/HybridSearch/KeywordVectorizedHybridSearchComplianceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/HybridSearch/KeywordVectorizedHybridSearchComplianceTests.cs @@ -29,10 +29,8 @@ public async Task SearchShouldReturnExpectedResultsAsync() // Act // All records have the same vector, but the third contains Grapes, so searching for // Grapes should return the third record first. - var searchResult = await hybridSearch!.HybridSearchAsync(vector, ["Grapes"]); - + var results = await hybridSearch!.HybridSearchAsync(vector, ["Grapes"], top: 3).ToListAsync(); // Assert - var results = await searchResult.Results.ToListAsync(); Assert.Equal(3, results.Count); Assert.Equal(3, results[0].Record.Code); @@ -55,10 +53,9 @@ public async Task SearchWithFilterShouldReturnExpectedResultsAsync() OldFilter = new VectorSearchFilter().EqualTo("Code", 1) }; #pragma warning restore CS0618 // Type or member is obsolete - var searchResult = await hybridSearch!.HybridSearchAsync(vector, ["Oranges"], options); + var results = await hybridSearch!.HybridSearchAsync(vector, ["Oranges"], top: 3, options).ToListAsync(); // Assert - var results = await searchResult.Results.ToListAsync(); Assert.Single(results); Assert.Equal(1, results[0].Record.Code); @@ -75,10 +72,9 @@ public async Task SearchWithTopShouldReturnExpectedResultsAsync() // Act // All records have the same vector, but the second contains Oranges, so the // second should be returned first. - var searchResult = await hybridSearch!.HybridSearchAsync(vector, ["Oranges"], new() { Top = 1 }); + var results = await hybridSearch!.HybridSearchAsync(vector, ["Oranges"], top: 1).ToListAsync(); // Assert - var results = await searchResult.Results.ToListAsync(); Assert.Single(results); Assert.Equal(2, results[0].Record.Code); @@ -95,10 +91,9 @@ public async Task SearchWithSkipShouldReturnExpectedResultsAsync() // Act // All records have the same vector, but the first and third contain healthy, // so when skipping the first two results, we should get the second record. - var searchResult = await hybridSearch!.HybridSearchAsync(vector, ["healthy"], new() { Skip = 2 }); + var results = await hybridSearch!.HybridSearchAsync(vector, ["healthy"], top: 3, new() { Skip = 2 }).ToListAsync(); // Assert - var results = await searchResult.Results.ToListAsync(); Assert.Single(results); Assert.Equal(2, results[0].Record.Code); @@ -113,10 +108,9 @@ public async Task SearchWithMultipleKeywordsShouldRankMatchedKeywordsHigherAsync var vector = new ReadOnlyMemory([1, 0, 0, 0]); // Act - var searchResult = await hybridSearch!.HybridSearchAsync(vector, ["tangy", "nourishing"]); + var results = await hybridSearch!.HybridSearchAsync(vector, ["tangy", "nourishing"], top: 3).ToListAsync(); // Assert - var results = await searchResult.Results.ToListAsync(); Assert.Equal(3, results.Count); Assert.True(results[0].Record.Code.Equals(1) || results[0].Record.Code.Equals(2)); @@ -133,17 +127,15 @@ public async Task SearchWithMultiTextRecordSearchesRequestedFieldAsync() var vector = new ReadOnlyMemory([1, 0, 0, 0]); // Act - var searchResult1 = await hybridSearch!.HybridSearchAsync(vector, ["Apples"], new() { AdditionalProperty = r => r.Text2 }); - var searchResult2 = await hybridSearch!.HybridSearchAsync(vector, ["Oranges"], new() { AdditionalProperty = r => r.Text2 }); + var results1 = await hybridSearch!.HybridSearchAsync(vector, ["Apples"], top: 3, new() { AdditionalProperty = r => r.Text2 }).ToListAsync(); + var results2 = await hybridSearch!.HybridSearchAsync(vector, ["Oranges"], top: 3, new() { AdditionalProperty = r => r.Text2 }).ToListAsync(); // Assert - var results1 = await searchResult1.Results.ToListAsync(); Assert.Equal(2, results1.Count); Assert.Equal(2, results1[0].Record.Code); Assert.Equal(1, results1[1].Record.Code); - var results2 = await searchResult2.Results.ToListAsync(); Assert.Equal(2, results2.Count); Assert.Equal(1, results2[0].Record.Code); @@ -176,17 +168,17 @@ public sealed class MultiTextStringRecord public abstract class VectorAndStringFixture : VectorStoreCollectionFixture> { - protected override string CollectionName => "KeywordHybridSearch" + this.GetUniqueCollectionName(); + public override string CollectionName => "KeywordHybridSearch" + this.GetUniqueCollectionName(); - protected override VectorStoreRecordDefinition GetRecordDefinition() + public override VectorStoreRecordDefinition GetRecordDefinition() => new() { Properties = new List() { new VectorStoreRecordKeyProperty("Key", typeof(TKey)), - new VectorStoreRecordDataProperty("Text", typeof(string)) { IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("Code", typeof(int)) { IsFilterable = true }, - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { Dimensions = 4, IndexKind = this.IndexKind }, + new VectorStoreRecordDataProperty("Text", typeof(string)) { IsFullTextIndexed = true }, + new VectorStoreRecordDataProperty("Code", typeof(int)) { IsIndexed = true }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 4) { IndexKind = this.IndexKind }, } }; @@ -229,18 +221,18 @@ protected override Task WaitForDataAsync() public abstract class MultiTextFixture : VectorStoreCollectionFixture> { - protected override string CollectionName => "KeywordHybridSearch" + this.GetUniqueCollectionName(); + public override string CollectionName => "KeywordHybridSearch" + this.GetUniqueCollectionName(); - protected override VectorStoreRecordDefinition GetRecordDefinition() + public override VectorStoreRecordDefinition GetRecordDefinition() => new() { Properties = new List() { new VectorStoreRecordKeyProperty("Key", typeof(TKey)), - new VectorStoreRecordDataProperty("Text1", typeof(string)) { IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("Text2", typeof(string)) { IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("Code", typeof(int)) { IsFilterable = true }, - new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { Dimensions = 4, IndexKind = this.IndexKind }, + new VectorStoreRecordDataProperty("Text1", typeof(string)) { IsFullTextIndexed = true }, + new VectorStoreRecordDataProperty("Text2", typeof(string)) { IsFullTextIndexed = true }, + new VectorStoreRecordDataProperty("Code", typeof(int)) { IsIndexed = true }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory), 4) { IndexKind = this.IndexKind }, } }; diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleModel.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleRecord.cs similarity index 77% rename from dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleModel.cs rename to dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleRecord.cs index 13a47e386516..57f15a00e9a4 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleModel.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleRecord.cs @@ -10,7 +10,7 @@ namespace VectorDataSpecificationTests.Models; /// a key, int, string and an embedding. /// /// TKey is a generic parameter because different connectors support different key types. -public sealed class SimpleModel +public sealed class SimpleRecord { public const int DimensionCount = 3; @@ -26,7 +26,7 @@ public sealed class SimpleModel [VectorStoreRecordVector(Dimensions: DimensionCount, StoragePropertyName = "embedding")] public ReadOnlyMemory Floats { get; set; } - public void AssertEqual(SimpleModel? other, bool includeVectors) + public void AssertEqual(SimpleRecord? other, bool includeVectors, bool compareVectors) { Assert.NotNull(other); Assert.Equal(this.Id, other.Id); @@ -35,7 +35,12 @@ public void AssertEqual(SimpleModel? other, bool includeVectors) if (includeVectors) { - Assert.Equal(this.Floats.ToArray(), other.Floats.ToArray()); + Assert.Equal(this.Floats.Span.Length, other.Floats.Span.Length); + + if (compareVectors) + { + Assert.Equal(this.Floats.ToArray(), other.Floats.ToArray()); + } } else { diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/DynamicDataModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/DynamicDataModelFixture.cs new file mode 100644 index 000000000000..089b253b79f9 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/DynamicDataModelFixture.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; + +namespace VectorDataSpecificationTests.Support; + +public abstract class DynamicDataModelFixture : VectorStoreCollectionFixture> +{ + public const string KeyPropertyName = "key"; + public const string StringPropertyName = "text"; + public const string IntegerPropertyName = "integer"; + public const string EmbeddingPropertyName = "embedding"; + public const int DimensionCount = 3; + + public override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty(KeyPropertyName, typeof(TKey)), + new VectorStoreRecordDataProperty(StringPropertyName, typeof(string)), + new VectorStoreRecordDataProperty(IntegerPropertyName, typeof(int)), + new VectorStoreRecordVectorProperty(EmbeddingPropertyName, typeof(ReadOnlyMemory), DimensionCount) + ] + }; + + protected override List> BuildTestData() => + [ + new() + { + [KeyPropertyName] = this.GenerateNextKey(), + [StringPropertyName] = "first", + [IntegerPropertyName] = 1, + [EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.1f, DimensionCount).ToArray()) + }, + new() + { + [KeyPropertyName] = this.GenerateNextKey(), + [StringPropertyName] = "second", + [IntegerPropertyName] = 2, + [EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.2f, DimensionCount).ToArray()) + }, + new() + { + [KeyPropertyName] = this.GenerateNextKey(), + [StringPropertyName] = "third", + [IntegerPropertyName] = 3, + [EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.3f, DimensionCount).ToArray()) + }, + new() + { + [KeyPropertyName] = this.GenerateNextKey(), + [StringPropertyName] = "fourth", + [IntegerPropertyName] = 4, + [EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.4f, DimensionCount).ToArray()) + } + ]; +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/GenericDataModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/GenericDataModelFixture.cs deleted file mode 100644 index 333ec1cdfea8..000000000000 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/GenericDataModelFixture.cs +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Microsoft.Extensions.VectorData; - -namespace VectorDataSpecificationTests.Support; - -public abstract class GenericDataModelFixture : VectorStoreCollectionFixture> - where TKey : notnull -{ - public const string KeyPropertyName = "key"; - public const string StringPropertyName = "text"; - public const string IntegerPropertyName = "integer"; - public const string EmbeddingPropertyName = "embedding"; - public const int DimensionCount = 3; - - protected override VectorStoreRecordDefinition GetRecordDefinition() - => new() - { - Properties = - [ - new VectorStoreRecordKeyProperty(KeyPropertyName, typeof(TKey)), - new VectorStoreRecordDataProperty(StringPropertyName, typeof(string)), - new VectorStoreRecordDataProperty(IntegerPropertyName, typeof(int)), - new VectorStoreRecordVectorProperty(EmbeddingPropertyName, typeof(ReadOnlyMemory)) - { - Dimensions = DimensionCount - } - ] - }; - - protected override List> BuildTestData() => - [ - new(this.GenerateNextKey()) - { - Data = - { - [StringPropertyName] = "first", - [IntegerPropertyName] = 1 - }, - Vectors = - { - [EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.1f, DimensionCount).ToArray()) - } - }, - new(this.GenerateNextKey()) - { - Data = - { - [StringPropertyName] = "second", - [IntegerPropertyName] = 2 - }, - Vectors = - { - [EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.2f, DimensionCount).ToArray()) - } - }, - new(this.GenerateNextKey()) - { - Data = - { - [StringPropertyName] = "third", - [IntegerPropertyName] = 3 - }, - Vectors = - { - [EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.3f, DimensionCount).ToArray()) - } - }, - new(this.GenerateNextKey()) - { - Data = - { - [StringPropertyName] = "fourth", - [IntegerPropertyName] = 4 - }, - Vectors = - { - [EmbeddingPropertyName] = new ReadOnlyMemory(Enumerable.Repeat(0.4f, DimensionCount).ToArray()) - } - } - ]; -} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/SimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/SimpleModelFixture.cs index b5c688c01835..a3091cdc995e 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/SimpleModelFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/SimpleModelFixture.cs @@ -5,56 +5,55 @@ namespace VectorDataSpecificationTests.Support; -public abstract class SimpleModelFixture : VectorStoreCollectionFixture> +public abstract class SimpleModelFixture : VectorStoreCollectionFixture> where TKey : notnull { - protected override List> BuildTestData() => + protected override List> BuildTestData() => [ new() { Id = this.GenerateNextKey(), Number = 1, Text = "UsedByGetTests", - Floats = Enumerable.Repeat(0.1f, SimpleModel.DimensionCount).ToArray() + Floats = Enumerable.Repeat(0.1f, SimpleRecord.DimensionCount).ToArray() }, new() { Id = this.GenerateNextKey(), Number = 2, Text = "UsedByUpdateTests", - Floats = Enumerable.Repeat(0.2f, SimpleModel.DimensionCount).ToArray() + Floats = Enumerable.Repeat(0.2f, SimpleRecord.DimensionCount).ToArray() }, new() { Id = this.GenerateNextKey(), Number = 3, Text = "UsedByDeleteTests", - Floats = Enumerable.Repeat(0.3f, SimpleModel.DimensionCount).ToArray() + Floats = Enumerable.Repeat(0.3f, SimpleRecord.DimensionCount).ToArray() }, new() { Id = this.GenerateNextKey(), Number = 4, Text = "UsedByDeleteBatchTests", - Floats = Enumerable.Repeat(0.4f, SimpleModel.DimensionCount).ToArray() + Floats = Enumerable.Repeat(0.4f, SimpleRecord.DimensionCount).ToArray() } ]; - protected override VectorStoreRecordDefinition GetRecordDefinition() + public override VectorStoreRecordDefinition GetRecordDefinition() => new() { Properties = [ - new VectorStoreRecordKeyProperty(nameof(SimpleModel.Id), typeof(TKey)), - new VectorStoreRecordVectorProperty(nameof(SimpleModel.Floats), typeof(ReadOnlyMemory?)) + new VectorStoreRecordKeyProperty(nameof(SimpleRecord.Id), typeof(TKey)), + new VectorStoreRecordVectorProperty(nameof(SimpleRecord.Floats), typeof(ReadOnlyMemory?), SimpleRecord.DimensionCount) { - Dimensions = SimpleModel.DimensionCount, DistanceFunction = this.DistanceFunction, IndexKind = this.IndexKind }, - new VectorStoreRecordDataProperty(nameof(SimpleModel.Number), typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty(nameof(SimpleModel.Text), typeof(string)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(SimpleRecord.Number), typeof(int)) { IsIndexed = true }, + new VectorStoreRecordDataProperty(nameof(SimpleRecord.Text), typeof(string)) { IsIndexed = true }, ] }; } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/TestStore.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/TestStore.cs index cd92e070abf8..b8fd45a2ae34 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/TestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/TestStore.cs @@ -13,6 +13,11 @@ public abstract class TestStore private readonly SemaphoreSlim _lock = new(1, 1); private int _referenceCount; + /// + /// Some databases modify vectors on upsert, e.g. normalizing them, so vectors + /// returned cannot be compared with the original ones. + /// + public virtual bool VectorsComparable => true; public virtual string DefaultDistanceFunction => DistanceFunction.CosineSimilarity; public virtual string DefaultIndexKind => IndexKind.Flat; @@ -75,6 +80,7 @@ public virtual async Task WaitForDataAsync( Expression>? filter = null, int vectorSize = 3) where TKey : notnull + where TRecord : notnull { var vector = new float[vectorSize]; for (var i = 0; i < vectorSize; i++) @@ -84,20 +90,19 @@ public virtual async Task WaitForDataAsync( for (var i = 0; i < 20; i++) { - var results = await collection.VectorizedSearchAsync( + var results = collection.SearchEmbeddingAsync( new ReadOnlyMemory(vector), - new() - { - Top = recordCount, - // In some databases (Azure AI Search), the data shows up but the filtering index isn't yet updated, - // so filtered searches show empty results. Add a filter to the seed data check below. - Filter = filter - }); - var count = await results.Results.CountAsync(); + top: 1000, // TODO: this should be recordCount, but see #11655 + new() { Filter = filter }); + var count = await results.CountAsync(); if (count == recordCount) { return; } + if (count > recordCount) + { + throw new InvalidOperationException($"Expected at most {recordCount} records, but found {count}."); + } await Task.Delay(TimeSpan.FromMilliseconds(100)); } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/VectorStoreCollectionFixture.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/VectorStoreCollectionFixture.cs index 9ae5703056f2..c76f75d46d47 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/VectorStoreCollectionFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/VectorStoreCollectionFixture.cs @@ -4,30 +4,33 @@ namespace VectorDataSpecificationTests.Support; +#pragma warning disable CA1721 // Property names should not match get methods + /// /// A test fixture that sets up a single collection in the test vector store, with a specific record definition /// and test data. /// public abstract class VectorStoreCollectionFixture : VectorStoreFixture where TKey : notnull + where TRecord : notnull { private List? _testData; - protected abstract VectorStoreRecordDefinition GetRecordDefinition(); + public abstract VectorStoreRecordDefinition GetRecordDefinition(); protected abstract List BuildTestData(); - protected virtual string CollectionName => Guid.NewGuid().ToString(); + public virtual string CollectionName => Guid.NewGuid().ToString(); protected virtual string DistanceFunction => this.TestStore.DefaultDistanceFunction; protected virtual string IndexKind => this.TestStore.DefaultIndexKind; - protected virtual IVectorStoreRecordCollection CreateCollection() + protected virtual IVectorStoreRecordCollection GetCollection() => this.TestStore.DefaultVectorStore.GetCollection(this.CollectionName, this.GetRecordDefinition()); public override async Task InitializeAsync() { await base.InitializeAsync(); - this.Collection = this.CreateCollection(); + this.Collection = this.GetCollection(); if (await this.Collection.CollectionExistsAsync()) { @@ -44,11 +47,7 @@ public override async Task InitializeAsync() protected virtual async Task SeedAsync() { - // TODO: UpsertBatchAsync returns IAsyncEnumerable (to support server-generated keys?), but this makes it quite hard to use: - await foreach (var _ in this.Collection.UpsertBatchAsync(this.TestData)) - { - } - + await this.Collection.UpsertAsync(this.TestData); await this.WaitForDataAsync(); } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj index 77fc8e90dbb2..382f8eb16e78 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj @@ -6,11 +6,14 @@ enable false VectorDataSpecificationTests + $(NoWarn);MEVD9000 + + diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/VectorSearchDistanceFunctionComplianceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/VectorSearchDistanceFunctionComplianceTests.cs index 16c6a5f46c7b..0e2021c60f1f 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/VectorSearchDistanceFunctionComplianceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/VectorSearchDistanceFunctionComplianceTests.cs @@ -101,28 +101,28 @@ protected async Task SimpleSearch(string distanceFunction, double expectedExactM try { - await collection.UpsertBatchAsync(insertedRecords).ToArrayAsync(); + await collection.UpsertAsync(insertedRecords); - var searchResult = await collection.VectorizedSearchAsync(baseVector); - var results = await searchResult.Results.ToListAsync(); + var searchResult = collection.SearchEmbeddingAsync(baseVector, top: 3); + var results = await searchResult.ToListAsync(); VerifySearchResults(expectedRecords, expectedScores, results, includeVectors: false); - searchResult = await collection.VectorizedSearchAsync(baseVector, new() { IncludeVectors = true }); - results = await searchResult.Results.ToListAsync(); + searchResult = collection.SearchEmbeddingAsync(baseVector, top: 3, new() { IncludeVectors = true }); + results = await searchResult.ToListAsync(); VerifySearchResults(expectedRecords, expectedScores, results, includeVectors: true); for (int skip = 0; skip <= insertedRecords.Count; skip++) { for (int top = Math.Max(1, skip); top <= insertedRecords.Count; top++) { - searchResult = await collection.VectorizedSearchAsync(baseVector, + searchResult = collection.SearchEmbeddingAsync(baseVector, + top: top, new() { Skip = skip, - Top = top, IncludeVectors = true }); - results = await searchResult.Results.ToListAsync(); + results = await searchResult.ToListAsync(); VerifySearchResults( expectedRecords.Skip(skip).Take(top).ToArray(), @@ -165,14 +165,13 @@ private VectorStoreRecordDefinition GetRecordDefinition(string distanceFunction) Properties = [ new VectorStoreRecordKeyProperty(nameof(SearchRecord.Key), typeof(TKey)), - new VectorStoreRecordVectorProperty(nameof(SearchRecord.Vector), typeof(ReadOnlyMemory)) + new VectorStoreRecordVectorProperty(nameof(SearchRecord.Vector), typeof(ReadOnlyMemory), 4) { - Dimensions = 4, DistanceFunction = distanceFunction, IndexKind = this.IndexKind }, - new VectorStoreRecordDataProperty(nameof(SearchRecord.Int), typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty(nameof(SearchRecord.String), typeof(string)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(SearchRecord.Int), typeof(int)) { IsIndexed = true }, + new VectorStoreRecordDataProperty(nameof(SearchRecord.String), typeof(string)) { IsIndexed = true }, ] }; diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryAttribute.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryAttribute.cs index 529f42ef1310..6e14179e7e99 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryAttribute.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryAttribute.cs @@ -6,5 +6,5 @@ namespace VectorDataSpecificationTests.Xunit; [AttributeUsage(AttributeTargets.Method)] -[XunitTestCaseDiscoverer("VectorDataSpecificationTests.Xunit.VectorStoreFactDiscoverer", "VectorDataIntegrationTests")] +[XunitTestCaseDiscoverer("VectorDataSpecificationTests.Xunit.ConditionalTheoryDiscoverer", "VectorDataIntegrationTests")] public sealed class ConditionalTheoryAttribute : TheoryAttribute; diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryDiscoverer.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryDiscoverer.cs new file mode 100644 index 000000000000..ade08a828148 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryDiscoverer.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +/// +/// Used dynamically from . +/// Make sure to update that class if you move this type. +/// +public class ConditionalTheoryDiscoverer(IMessageSink messageSink) : TheoryDiscoverer(messageSink) +{ + protected override IEnumerable CreateTestCasesForTheory( + ITestFrameworkDiscoveryOptions discoveryOptions, + ITestMethod testMethod, + IAttributeInfo theoryAttribute) + { + yield return new ConditionalTheoryTestCase( + this.DiagnosticMessageSink, + discoveryOptions.MethodDisplayOrDefault(), + discoveryOptions.MethodDisplayOptionsOrDefault(), + testMethod); + } + + protected override IEnumerable CreateTestCasesForDataRow( + ITestFrameworkDiscoveryOptions discoveryOptions, + ITestMethod testMethod, + IAttributeInfo theoryAttribute, + object[] dataRow) + { + yield return new ConditionalFactTestCase( + this.DiagnosticMessageSink, + discoveryOptions.MethodDisplayOrDefault(), + discoveryOptions.MethodDisplayOptionsOrDefault(), + testMethod, + dataRow); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryTestCase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryTestCase.cs new file mode 100644 index 000000000000..f96ec3d9d691 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryTestCase.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +public sealed class ConditionalTheoryTestCase : XunitTheoryTestCase +{ + [Obsolete("Called by the de-serializer; should only be called by deriving classes for de-serialization purposes")] + public ConditionalTheoryTestCase() + { + } + + public ConditionalTheoryTestCase( + IMessageSink diagnosticMessageSink, + TestMethodDisplay defaultMethodDisplay, + TestMethodDisplayOptions defaultMethodDisplayOptions, + ITestMethod testMethod) + : base(diagnosticMessageSink, defaultMethodDisplay, defaultMethodDisplayOptions, testMethod) + { + } + + public override async Task RunAsync( + IMessageSink diagnosticMessageSink, + IMessageBus messageBus, + object[] constructorArguments, + ExceptionAggregator aggregator, + CancellationTokenSource cancellationTokenSource) + => await XunitTestCaseExtensions.TrySkipAsync(this, messageBus) + ? new RunSummary { Total = 1, Skipped = 1 } + : await base.RunAsync( + diagnosticMessageSink, + messageBus, + constructorArguments, + aggregator, + cancellationTokenSource); +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateBatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateBatchConformanceTests.cs new file mode 100644 index 000000000000..4222608ab7c0 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateBatchConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.CRUD; +using WeaviateIntegrationTests.Support; +using Xunit; + +namespace WeaviateIntegrationTests.CRUD; + +public class WeaviateBatchConformanceTests_NamedVectors(WeaviateSimpleModelNamedVectorsFixture fixture) + : BatchConformanceTests(fixture), IClassFixture +{ +} + +public class WeaviateBatchConformanceTests_UnnamedVector(WeaviateSimpleModelUnnamedVectorFixture fixture) + : BatchConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateDynamicRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateDynamicRecordConformanceTests.cs new file mode 100644 index 000000000000..62825fac4ab1 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateDynamicRecordConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.CRUD; +using WeaviateIntegrationTests.Support; +using Xunit; + +namespace WeaviateIntegrationTests.CRUD; + +public class WeaviateDynamicRecordConformanceTests_NamedVectors(WeaviateDynamicDataModelNamedVectorsFixture fixture) + : DynamicDataModelConformanceTests(fixture), IClassFixture +{ +} + +public class WeaviateDynamicRecordConformanceTests_UnnamedVector(WeaviateDynamicDataModelUnnamedVectorFixture fixture) + : DynamicDataModelConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateNoDataConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateNoDataConformanceTests.cs new file mode 100644 index 000000000000..e8a6deceb1e8 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateNoDataConformanceTests.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using WeaviateIntegrationTests.Support; +using Xunit; + +namespace WeaviateIntegrationTests.CRUD; + +public class WeaviateNoDataConformanceTests_NamedVectors(WeaviateNoDataConformanceTests_NamedVectors.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; + + /// + /// Weaviate collections must start with an uppercase letter. + /// + public override string CollectionName => "NoDataNamedCollection"; + } +} + +public class WeaviateNoDataConformanceTests_UnnamedVector(WeaviateNoDataConformanceTests_UnnamedVector.Fixture fixture) + : NoDataConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoDataConformanceTests.Fixture + { + public override TestStore TestStore => WeaviateTestStore.UnnamedVectorInstance; + + /// + /// Weaviate collections must start with an uppercase letter. + /// + public override string CollectionName => "NoDataUnnamedCollection"; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateNoVectorConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateNoVectorConformanceTests.cs new file mode 100644 index 000000000000..b2fb42f176df --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateNoVectorConformanceTests.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.CRUD; +using VectorDataSpecificationTests.Support; +using WeaviateIntegrationTests.Support; +using Xunit; + +namespace WeaviateIntegrationTests.CRUD; + +public class WeaviateNoVectorConformanceTests_NamedVectors(WeaviateNoVectorConformanceTests_NamedVectors.Fixture fixture) + : NoVectorConformanceTests(fixture), IClassFixture +{ + public new class Fixture : NoVectorConformanceTests.Fixture + { + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; + + /// + /// Weaviate collections must start with an uppercase letter. + /// + public override string CollectionName => "NoVectorNamedCollection"; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateRecordConformanceTests.cs new file mode 100644 index 000000000000..c2ad732eb59d --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/CRUD/WeaviateRecordConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.CRUD; +using WeaviateIntegrationTests.Support; +using Xunit; + +namespace WeaviateIntegrationTests.CRUD; + +public class WeaviateRecordConformanceTests_NamedVectors(WeaviateSimpleModelNamedVectorsFixture fixture) + : RecordConformanceTests(fixture), IClassFixture +{ +} + +public class WeaviateRecordConformanceTests_UnnamedVector(WeaviateSimpleModelUnnamedVectorFixture fixture) + : RecordConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Collections/WeaviateCollectionConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Collections/WeaviateCollectionConformanceTests.cs index e839b02ad942..3c817890ac16 100644 --- a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Collections/WeaviateCollectionConformanceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Collections/WeaviateCollectionConformanceTests.cs @@ -6,7 +6,12 @@ namespace WeaviateIntegrationTests.Collections; -public class WeaviateCollectionConformanceTests(WeaviateFixture fixture) - : CollectionConformanceTests(fixture), IClassFixture +public class WeaviateCollectionConformanceTests_NamedVectors(WeaviateSimpleModelNamedVectorsFixture fixture) + : CollectionConformanceTests(fixture), IClassFixture +{ +} + +public class WeaviateCollectionConformanceTests_UnnamedVector(WeaviateSimpleModelUnnamedVectorFixture fixture) + : CollectionConformanceTests(fixture), IClassFixture { } diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs index 6238ca6d9b6a..b4dbf228ee37 100644 --- a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs @@ -61,12 +61,10 @@ public override Task Contains_over_inline_string_array_with_weird_chars() // In Weaviate, string equality on multi-word textual properties depends on tokenization // (https://weaviate.io/developers/weaviate/api/graphql/filters#multi-word-queries-in-equal-filters) public override Task Equal_with_string_is_not_Contains() - => Assert.ThrowsAsync(() => base.Equal_with_string_is_not_Contains()); + => Assert.ThrowsAsync(() => base.Equal_with_string_is_not_Contains()); public new class Fixture : BasicFilterTests.Fixture { - public override TestStore TestStore => WeaviateTestStore.Instance; - - protected override string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; } } diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicQueryTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicQueryTests.cs new file mode 100644 index 000000000000..fefa13f83515 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicQueryTests.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using WeaviateIntegrationTests.Support; +using Xunit; +using Xunit.Sdk; + +namespace WeaviateIntegrationTests.Filter; + +public class WeaviateBasicQueryTests(WeaviateBasicQueryTests.Fixture fixture) + : BasicQueryTests(fixture), IClassFixture +{ + #region Filter by null + + // Null-state indexing needs to be set up, but that's not supported yet (#10358). + // We could interact with Weaviate directly (not via the abstraction) to do this. + + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); + + #endregion + + #region Not + + // Weaviate currently doesn't support NOT (https://github.com/weaviate/weaviate/issues/3683) + public override Task Not_over_And() + => Assert.ThrowsAsync(() => base.Not_over_And()); + + public override Task Not_over_Or() + => Assert.ThrowsAsync(() => base.Not_over_Or()); + + #endregion + + #region Unsupported Contains scenarios + + public override Task Contains_over_captured_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_captured_string_array()); + + public override Task Contains_over_inline_int_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); + + public override Task Contains_over_inline_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); + + public override Task Contains_over_inline_string_array_with_weird_chars() + => Assert.ThrowsAsync(() => base.Contains_over_inline_string_array_with_weird_chars()); + + #endregion + + // In Weaviate, string equality on multi-word textual properties depends on tokenization + // (https://weaviate.io/developers/weaviate/api/graphql/filters#multi-word-queries-in-equal-filters) + public override Task Equal_with_string_is_not_Contains() + => Assert.ThrowsAsync(() => base.Equal_with_string_is_not_Contains()); + + public new class Fixture : BasicQueryTests.QueryFixture + { + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/HybridSearch/WeaviateKeywordVectorizedHybridSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/HybridSearch/WeaviateKeywordVectorizedHybridSearchTests.cs index 30d6bc0516f5..b5a262c4c47c 100644 --- a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/HybridSearch/WeaviateKeywordVectorizedHybridSearchTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/HybridSearch/WeaviateKeywordVectorizedHybridSearchTests.cs @@ -7,28 +7,46 @@ namespace WeaviateIntegrationTests.HybridSearch; -public class WeaviateKeywordVectorizedHybridSearchTests( - WeaviateKeywordVectorizedHybridSearchTests.VectorAndStringFixture vectorAndStringFixture, - WeaviateKeywordVectorizedHybridSearchTests.MultiTextFixture multiTextFixture) +public class WeaviateKeywordVectorizedHybridSearchTests_NamedVectors( + WeaviateKeywordVectorizedHybridSearchTests_NamedVectors.VectorAndStringFixture vectorAndStringFixture, + WeaviateKeywordVectorizedHybridSearchTests_NamedVectors.MultiTextFixture multiTextFixture) : KeywordVectorizedHybridSearchComplianceTests(vectorAndStringFixture, multiTextFixture), - IClassFixture, - IClassFixture + IClassFixture, + IClassFixture { public new class VectorAndStringFixture : KeywordVectorizedHybridSearchComplianceTests.VectorAndStringFixture { - public override TestStore TestStore => WeaviateTestStore.Instance; + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; - protected override string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; - - protected override string CollectionName => "VectorAndStringHybridSearch"; + public override string CollectionName => "VectorAndStringHybridSearch"; } public new class MultiTextFixture : KeywordVectorizedHybridSearchComplianceTests.MultiTextFixture { - public override TestStore TestStore => WeaviateTestStore.Instance; + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; + + public override string CollectionName => "MultiTextHybridSearch"; + } +} - protected override string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; +public class WeaviateKeywordVectorizedHybridSearchTests_UnnamedVector( + WeaviateKeywordVectorizedHybridSearchTests_UnnamedVector.VectorAndStringFixture vectorAndStringFixture, + WeaviateKeywordVectorizedHybridSearchTests_UnnamedVector.MultiTextFixture multiTextFixture) + : KeywordVectorizedHybridSearchComplianceTests(vectorAndStringFixture, multiTextFixture), + IClassFixture, + IClassFixture +{ + public new class VectorAndStringFixture : KeywordVectorizedHybridSearchComplianceTests.VectorAndStringFixture + { + public override TestStore TestStore => WeaviateTestStore.UnnamedVectorInstance; + + public override string CollectionName => "VectorAndStringHybridSearch"; + } + + public new class MultiTextFixture : KeywordVectorizedHybridSearchComplianceTests.MultiTextFixture + { + public override TestStore TestStore => WeaviateTestStore.UnnamedVectorInstance; - protected override string CollectionName => "MultiTextHybridSearch"; + public override string CollectionName => "MultiTextHybridSearch"; } } diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateBuilder.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateBuilder.cs index 1745a902a348..831f05734d6b 100644 --- a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateBuilder.cs +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateBuilder.cs @@ -8,7 +8,7 @@ namespace WeaviateIntegrationTests.Support.TestContainer; public sealed class WeaviateBuilder : ContainerBuilder { - public const string WeaviateImage = "semitechnologies/weaviate:1.26.4"; + public const string WeaviateImage = "semitechnologies/weaviate:1.28.12"; public const ushort WeaviateHttpPort = 8080; public const ushort WeaviateGrpcPort = 50051; diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateDynamicDataModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateDynamicDataModelFixture.cs new file mode 100644 index 000000000000..874b771b5a8f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateDynamicDataModelFixture.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace WeaviateIntegrationTests.Support; + +public class WeaviateDynamicDataModelFixture : DynamicDataModelFixture +{ + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; + + // Weaviate requires the name to start with a capital letter and not contain any chars other than a-Z and 0-9. + // Source: https://weaviate.io/developers/weaviate/starter-guides/managing-collections#collection--property-names + public override string CollectionName => this.GetUniqueCollectionName(); + + public override string GetUniqueCollectionName() => $"A{Guid.NewGuid():N}"; +} + +public class WeaviateDynamicDataModelNamedVectorsFixture : WeaviateDynamicDataModelFixture +{ + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; +} + +public class WeaviateDynamicDataModelUnnamedVectorFixture : WeaviateDynamicDataModelFixture +{ + public override TestStore TestStore => WeaviateTestStore.UnnamedVectorInstance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateFixture.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateFixture.cs deleted file mode 100644 index ac3b64f89006..000000000000 --- a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateFixture.cs +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using VectorDataSpecificationTests.Support; - -namespace WeaviateIntegrationTests.Support; - -public class WeaviateFixture : VectorStoreFixture -{ - public override TestStore TestStore => WeaviateTestStore.Instance; - - // Weaviate requires the name to start with a capital letter and not contain any chars other than a-Z and 0-9. - // Source: https://weaviate.io/developers/weaviate/starter-guides/managing-collections#collection--property-names - public override string GetUniqueCollectionName() => $"A{Guid.NewGuid():N}"; -} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateSimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateSimpleModelFixture.cs new file mode 100644 index 000000000000..829172f8503f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateSimpleModelFixture.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace WeaviateIntegrationTests.Support; + +public class WeaviateSimpleModelFixture : SimpleModelFixture +{ + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; + + // Weaviate requires the name to start with a capital letter and not contain any chars other than a-Z and 0-9. + // Source: https://weaviate.io/developers/weaviate/starter-guides/managing-collections#collection--property-names + public override string CollectionName => this.GetUniqueCollectionName(); + + public override string GetUniqueCollectionName() => $"A{Guid.NewGuid():N}"; +} + +public class WeaviateSimpleModelNamedVectorsFixture : WeaviateSimpleModelFixture +{ + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; +} + +public class WeaviateSimpleModelUnnamedVectorFixture : WeaviateSimpleModelFixture +{ + public override TestStore TestStore => WeaviateTestStore.UnnamedVectorInstance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs index d112a2abfe49..8ddea21255d4 100644 --- a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs @@ -12,9 +12,13 @@ namespace WeaviateIntegrationTests.Support; public sealed class WeaviateTestStore : TestStore { - public static WeaviateTestStore Instance { get; } = new(); + public static WeaviateTestStore NamedVectorsInstance { get; } = new(hasNamedVectors: true); + public static WeaviateTestStore UnnamedVectorInstance { get; } = new(hasNamedVectors: false); + + public override string DefaultDistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; private readonly WeaviateContainer _container = new WeaviateBuilder().Build(); + private readonly bool _hasNamedVectors; public HttpClient? _httpClient { get; private set; } private WeaviateVectorStore? _defaultVectorStore; @@ -25,15 +29,13 @@ public sealed class WeaviateTestStore : TestStore public WeaviateVectorStore GetVectorStore(WeaviateVectorStoreOptions options) => new(this.Client, options); - private WeaviateTestStore() - { - } + private WeaviateTestStore(bool hasNamedVectors) => this._hasNamedVectors = hasNamedVectors; protected override async Task StartAsync() { await this._container.StartAsync(); this._httpClient = new HttpClient { BaseAddress = new Uri($"http://localhost:{this._container.GetMappedPublicPort(WeaviateBuilder.WeaviateHttpPort)}/v1/") }; - this._defaultVectorStore = new(this._httpClient); + this._defaultVectorStore = new(this._httpClient, new() { HasNamedVectors = this._hasNamedVectors }); } protected override Task StopAsync() diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/VectorSearch/WeaviateVectorSearchDistanceFunctionComplianceTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/VectorSearch/WeaviateVectorSearchDistanceFunctionComplianceTests.cs new file mode 100644 index 000000000000..0c6a5aadd390 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/VectorSearch/WeaviateVectorSearchDistanceFunctionComplianceTests.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.VectorSearch; +using WeaviateIntegrationTests.Support; +using Xunit; + +namespace WeaviateIntegrationTests.VectorSearch; + +public class WeaviateVectorSearchDistanceFunctionComplianceTests_NamedVectors(WeaviateSimpleModelNamedVectorsFixture fixture) + : VectorSearchDistanceFunctionComplianceTests(fixture), IClassFixture +{ + public override Task CosineSimilarity() => Assert.ThrowsAsync(base.CosineSimilarity); + + public override Task DotProductSimilarity() => Assert.ThrowsAsync(base.DotProductSimilarity); + + public override Task EuclideanDistance() => Assert.ThrowsAsync(base.EuclideanDistance); + + /// + /// Tests vector search using , computing -(u · v) as a distance metric per Weaviate's convention. + /// Expects scores of -1 (exact match), 1 (opposite), and 0 (orthogonal), sorted ascending ([0, 2, 1]), with lower scores indicating closer matches. + /// . + /// + public override Task NegativeDotProductSimilarity() => this.SimpleSearch(DistanceFunction.NegativeDotProductSimilarity, -1, 1, 0, [0, 2, 1]); +} + +public class WeaviateVectorSearchDistanceFunctionComplianceTests_UnnamedVector(WeaviateDynamicDataModelNamedVectorsFixture fixture) + : VectorSearchDistanceFunctionComplianceTests(fixture), IClassFixture +{ + public override Task CosineSimilarity() => Assert.ThrowsAsync(base.CosineSimilarity); + + public override Task DotProductSimilarity() => Assert.ThrowsAsync(base.DotProductSimilarity); + + public override Task EuclideanDistance() => Assert.ThrowsAsync(base.EuclideanDistance); + + /// + /// Tests vector search using , computing -(u · v) as a distance metric per Weaviate's convention. + /// Expects scores of -1 (exact match), 1 (opposite), and 0 (orthogonal), sorted ascending ([0, 2, 1]), with lower scores indicating closer matches. + /// . + /// + public override Task NegativeDotProductSimilarity() => this.SimpleSearch(DistanceFunction.NegativeDotProductSimilarity, -1, 1, 0, [0, 2, 1]); +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateEmbeddingGenerationTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..44f8e96c5f95 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateEmbeddingGenerationTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using VectorDataSpecificationTests; +using VectorDataSpecificationTests.Support; +using WeaviateIntegrationTests.Support; +using Xunit; + +namespace WeaviateIntegrationTests; + +public class WeaviateEmbeddingGenerationTests(WeaviateEmbeddingGenerationTests.Fixture fixture) + : EmbeddingGenerationTests(fixture), IClassFixture +{ + public new class Fixture : EmbeddingGenerationTests.Fixture + { + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; + + public override IVectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => WeaviateTestStore.NamedVectorsInstance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + services => services + .AddSingleton(WeaviateTestStore.NamedVectorsInstance.Client) + .AddWeaviateVectorStore() + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services + .AddSingleton(WeaviateTestStore.NamedVectorsInstance.Client) + .AddWeaviateVectorStoreRecordCollection(this.CollectionName) + ]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj index eb98407f35ee..486583668bb8 100644 --- a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj @@ -10,18 +10,18 @@ - + runtime; build; native; contentfiles; analyzers; buildtransitive all - - + + - - + +