From 160e08e5b83b56b5dd9e14692881ed54f9960f1d Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Tue, 11 Feb 2025 16:37:08 +0100 Subject: [PATCH 01/32] add new test project, move the existing tests to it --- dotnet/SK-dotnet.sln | 21 +++++++---- .../IntegrationTests/IntegrationTests.csproj | 1 - .../SqlServerIntegrationTests.csproj | 36 +++++++++++++++++++ .../SqlServerMemoryStoreTests.cs | 0 4 files changed, 51 insertions(+), 7 deletions(-) create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj rename dotnet/src/{IntegrationTests/Connectors/Memory/SqlServer => VectorDataIntegrationTests/SqlServerIntegrationTests}/SqlServerMemoryStoreTests.cs (100%) diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 0a711f84f5f3..d1bbee3fc126 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -439,6 +439,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "sk-chatgpt-azure-function", EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "kernel-functions-generator", "samples\Demos\CreateChatGptPlugin\MathPlugin\kernel-functions-generator\kernel-functions-generator.csproj", "{78785CB1-66CF-4895-D7E5-A440DD84BE86}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SqlServerIntegrationTests", "src\VectorDataIntegrationTests\SqlServerIntegrationTests\SqlServerIntegrationTests.csproj", "{A5E6193C-8431-4C6E-B674-682CB41EAA0C}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -1076,12 +1078,6 @@ Global {6F591D05-5F7F-4211-9042-42D8BCE60415}.Publish|Any CPU.Build.0 = Debug|Any CPU {6F591D05-5F7F-4211-9042-42D8BCE60415}.Release|Any CPU.ActiveCfg = Release|Any CPU {6F591D05-5F7F-4211-9042-42D8BCE60415}.Release|Any CPU.Build.0 = Release|Any CPU - {232E1153-6366-4175-A982-D66B30AAD610}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {232E1153-6366-4175-A982-D66B30AAD610}.Debug|Any CPU.Build.0 = Debug|Any CPU - 
{232E1153-6366-4175-A982-D66B30AAD610}.Publish|Any CPU.ActiveCfg = Debug|Any CPU - {232E1153-6366-4175-A982-D66B30AAD610}.Publish|Any CPU.Build.0 = Debug|Any CPU - {232E1153-6366-4175-A982-D66B30AAD610}.Release|Any CPU.ActiveCfg = Release|Any CPU - {232E1153-6366-4175-A982-D66B30AAD610}.Release|Any CPU.Build.0 = Release|Any CPU {E82B640C-1704-430D-8D71-FD8ED3695468}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {E82B640C-1704-430D-8D71-FD8ED3695468}.Debug|Any CPU.Build.0 = Debug|Any CPU {E82B640C-1704-430D-8D71-FD8ED3695468}.Publish|Any CPU.ActiveCfg = Debug|Any CPU @@ -1100,6 +1096,12 @@ Global {39EAB599-742F-417D-AF80-95F90376BB18}.Publish|Any CPU.Build.0 = Publish|Any CPU {39EAB599-742F-417D-AF80-95F90376BB18}.Release|Any CPU.ActiveCfg = Release|Any CPU {39EAB599-742F-417D-AF80-95F90376BB18}.Release|Any CPU.Build.0 = Release|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Debug|Any CPU.Build.0 = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Publish|Any CPU.Build.0 = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Release|Any CPU.ActiveCfg = Release|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Release|Any CPU.Build.0 = Release|Any CPU {DAC54048-A39A-4739-8307-EA5A291F2EA0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {DAC54048-A39A-4739-8307-EA5A291F2EA0}.Debug|Any CPU.Build.0 = Debug|Any CPU {DAC54048-A39A-4739-8307-EA5A291F2EA0}.Publish|Any CPU.ActiveCfg = Debug|Any CPU @@ -1172,6 +1174,12 @@ Global {78785CB1-66CF-4895-D7E5-A440DD84BE86}.Publish|Any CPU.Build.0 = Debug|Any CPU {78785CB1-66CF-4895-D7E5-A440DD84BE86}.Release|Any CPU.ActiveCfg = Release|Any CPU {78785CB1-66CF-4895-D7E5-A440DD84BE86}.Release|Any CPU.Build.0 = Release|Any CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Debug|Any CPU.Build.0 = Debug|Any 
CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Publish|Any CPU.Build.0 = Debug|Any CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1333,6 +1341,7 @@ Global {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {2EB6E4C2-606D-B638-2E08-49EA2061C428} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} {78785CB1-66CF-4895-D7E5-A440DD84BE86} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} + {A5E6193C-8431-4C6E-B674-682CB41EAA0C} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index e24215b583d6..26cfaa1949ae 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -88,7 +88,6 @@ - diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj new file mode 100644 index 000000000000..fe58fcd104ae --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj @@ -0,0 +1,36 @@ + + + + net8.0 + enable + enable + + false + true + + $(NoWarn);CA2007,SKEXP0001,SKEXP0020,VSTHRD111 + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + + + + + + diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/SqlServer/SqlServerMemoryStoreTests.cs 
b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs similarity index 100% rename from dotnet/src/IntegrationTests/Connectors/Memory/SqlServer/SqlServerMemoryStoreTests.cs rename to dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs From 186fda296d6e26e94b7c5c32b803cde9b6eac1c5 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Tue, 11 Feb 2025 17:04:47 +0100 Subject: [PATCH 02/32] port existing tests to Testcontainers.MsSql and re-enable them --- dotnet/Directory.Packages.props | 1 + .../SqlServerContainerFixture.cs | 35 ++++++++++++++ .../SqlServerIntegrationTests.csproj | 3 +- .../SqlServerMemoryStoreTests.cs | 46 ++++++++----------- 4 files changed, 58 insertions(+), 27 deletions(-) create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerContainerFixture.cs diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index e93dc3df49a2..7fb4400c9c7c 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -120,6 +120,7 @@ + diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerContainerFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerContainerFixture.cs new file mode 100644 index 000000000000..36f00526303d --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerContainerFixture.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Testcontainers.MsSql; +using Xunit; + +namespace SqlServerIntegrationTests; + +public sealed class SqlServerContainerFixture : IAsyncLifetime +{ + private MsSqlContainer? Container { get; set; } + + public string GetConnectionString() => this.Container?.GetConnectionString() ?? 
+ throw new InvalidOperationException("The test container was not initialized."); + + public async Task DisposeAsync() + { + if (this.Container is not null) + { + await this.Container.DisposeAsync(); + } + } + + public async Task InitializeAsync() => this.Container = await CreateContainerAsync(); + + private static async Task CreateContainerAsync() + { + var container = new MsSqlBuilder() + .WithImage("mcr.microsoft.com/mssql/server:2022-latest") + .Build(); + + await container.StartAsync(); + + return container; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj index fe58fcd104ae..1db2d3d7aa58 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj @@ -8,7 +8,7 @@ false true - $(NoWarn);CA2007,SKEXP0001,SKEXP0020,VSTHRD111 + $(NoWarn);CS1591,CA1861,CA2007,SKEXP0001,SKEXP0020,VSTHRD111 @@ -27,6 +27,7 @@ + diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs index 32c0f6742546..4135ad436d81 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs @@ -1,13 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. 
-using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; using Microsoft.Data.SqlClient; using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel.Connectors.SqlServer; using Microsoft.SemanticKernel.Memory; +using SqlServerIntegrationTests; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.SqlServer; @@ -15,10 +12,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.SqlServer; /// /// Unit tests for class. /// -public class SqlServerMemoryStoreTests : IAsyncLifetime +public class SqlServerMemoryStoreTests(SqlServerContainerFixture fixture) : IClassFixture, IAsyncLifetime { - private const string? SkipReason = "Configure SQL Server or Azure SQL connection string and then set this to 'null'."; - //private const string? SkipReason = null; private const string SchemaName = "sk_it"; private const string DefaultCollectionName = "test"; private const int TestEmbeddingDimensionsCount = 5; @@ -27,16 +22,13 @@ public class SqlServerMemoryStoreTests : IAsyncLifetime private SqlServerMemoryStore Store { get; set; } = null!; + private SqlServerContainerFixture Fixture { get; } = fixture; + public async Task InitializeAsync() { - var configuration = new ConfigurationBuilder() - .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) - .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) - .AddEnvironmentVariables() - .AddUserSecrets() - .Build(); + await this.Fixture.InitializeAsync(); - var connectionString = configuration["SqlServer:ConnectionString"]; + string connectionString = this.Fixture.GetConnectionString(); if (string.IsNullOrWhiteSpace(connectionString)) { @@ -54,9 +46,11 @@ public async Task InitializeAsync() public async Task DisposeAsync() { await this.CleanupDatabaseAsync(); + + await this.Fixture.DisposeAsync(); } - [Fact(Skip = SkipReason)] + [Fact] public async Task CreateCollectionAsync() { Assert.False(await 
this.Store.DoesCollectionExistAsync(DefaultCollectionName)); @@ -65,7 +59,7 @@ public async Task CreateCollectionAsync() Assert.True(await this.Store.DoesCollectionExistAsync(DefaultCollectionName)); } - [Fact(Skip = SkipReason)] + [Fact] public async Task DropCollectionAsync() { await this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -73,7 +67,7 @@ public async Task DropCollectionAsync() Assert.False(await this.Store.DoesCollectionExistAsync(DefaultCollectionName)); } - [Fact(Skip = SkipReason)] + [Fact] public async Task GetCollectionsAsync() { await this.Store.CreateCollectionAsync("collection1"); @@ -84,7 +78,7 @@ public async Task GetCollectionsAsync() Assert.Contains("collection2", collections); } - [Fact(Skip = SkipReason)] + [Fact] public async Task UpsertAsync() { await this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -104,7 +98,7 @@ public async Task UpsertAsync() Assert.Equal("Some id", id); } - [Theory(Skip = SkipReason)] + [Theory] [InlineData(true)] [InlineData(false)] public async Task GetAsync(bool withEmbeddings) @@ -128,7 +122,7 @@ public async Task GetAsync(bool withEmbeddings) record.Embedding.ToArray()); } - [Fact(Skip = SkipReason)] + [Fact] public async Task UpsertBatchAsync() { await this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -139,7 +133,7 @@ public async Task UpsertBatchAsync() id => Assert.Equal("Some other id", id)); } - [Theory(Skip = SkipReason)] + [Theory] [InlineData(true)] [InlineData(false)] public async Task GetBatchAsync(bool withEmbeddings) @@ -180,7 +174,7 @@ public async Task GetBatchAsync(bool withEmbeddings) }); } - [Fact(Skip = SkipReason)] + [Fact] public async Task RemoveAsync() { await this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -191,7 +185,7 @@ public async Task RemoveAsync() Assert.Null(await this.Store.GetAsync(DefaultCollectionName, "Some id")); } - [Fact(Skip = SkipReason)] + [Fact] public async Task RemoveBatchAsync() { await 
this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -204,7 +198,7 @@ public async Task RemoveBatchAsync() Assert.Null(await this.Store.GetAsync(DefaultCollectionName, "Some other id")); } - [Theory(Skip = SkipReason)] + [Theory] [InlineData(true)] [InlineData(false)] public async Task GetNearestMatchesAsync(bool withEmbeddings) @@ -248,7 +242,7 @@ public async Task GetNearestMatchesAsync(bool withEmbeddings) }); } - [Fact(Skip = SkipReason)] + [Fact] public async Task GetNearestMatchesWithMinRelevanceScoreAsync() { await this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -265,7 +259,7 @@ public async Task GetNearestMatchesWithMinRelevanceScoreAsync() Assert.DoesNotContain(firstId, results.Select(r => r.Record.Metadata.Id)); } - [Theory(Skip = SkipReason)] + [Theory] [InlineData(true)] [InlineData(false)] public async Task GetNearestMatchAsync(bool withEmbeddings) From 22495b9dc6340acbaa507b875326160e03050515 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 12 Feb 2025 14:34:10 +0100 Subject: [PATCH 03/32] Revert "port existing tests to Testcontainers.MsSql and re-enable them" This reverts commit 186fda296d6e26e94b7c5c32b803cde9b6eac1c5. 
--- dotnet/Directory.Packages.props | 1 - .../SqlServerContainerFixture.cs | 35 -------------- .../SqlServerIntegrationTests.csproj | 3 +- .../SqlServerMemoryStoreTests.cs | 46 +++++++++++-------- 4 files changed, 27 insertions(+), 58 deletions(-) delete mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerContainerFixture.cs diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 7fb4400c9c7c..e93dc3df49a2 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -120,7 +120,6 @@ - diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerContainerFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerContainerFixture.cs deleted file mode 100644 index 36f00526303d..000000000000 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerContainerFixture.cs +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Testcontainers.MsSql; -using Xunit; - -namespace SqlServerIntegrationTests; - -public sealed class SqlServerContainerFixture : IAsyncLifetime -{ - private MsSqlContainer? Container { get; set; } - - public string GetConnectionString() => this.Container?.GetConnectionString() ?? 
- throw new InvalidOperationException("The test container was not initialized."); - - public async Task DisposeAsync() - { - if (this.Container is not null) - { - await this.Container.DisposeAsync(); - } - } - - public async Task InitializeAsync() => this.Container = await CreateContainerAsync(); - - private static async Task CreateContainerAsync() - { - var container = new MsSqlBuilder() - .WithImage("mcr.microsoft.com/mssql/server:2022-latest") - .Build(); - - await container.StartAsync(); - - return container; - } -} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj index 1db2d3d7aa58..fe58fcd104ae 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj @@ -8,7 +8,7 @@ false true - $(NoWarn);CS1591,CA1861,CA2007,SKEXP0001,SKEXP0020,VSTHRD111 + $(NoWarn);CA2007,SKEXP0001,SKEXP0020,VSTHRD111 @@ -27,7 +27,6 @@ - diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs index 4135ad436d81..32c0f6742546 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs @@ -1,10 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. 
+using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; using Microsoft.Data.SqlClient; using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel.Connectors.SqlServer; using Microsoft.SemanticKernel.Memory; -using SqlServerIntegrationTests; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.SqlServer; @@ -12,8 +15,10 @@ namespace SemanticKernel.IntegrationTests.Connectors.SqlServer; /// /// Unit tests for class. /// -public class SqlServerMemoryStoreTests(SqlServerContainerFixture fixture) : IClassFixture, IAsyncLifetime +public class SqlServerMemoryStoreTests : IAsyncLifetime { + private const string? SkipReason = "Configure SQL Server or Azure SQL connection string and then set this to 'null'."; + //private const string? SkipReason = null; private const string SchemaName = "sk_it"; private const string DefaultCollectionName = "test"; private const int TestEmbeddingDimensionsCount = 5; @@ -22,13 +27,16 @@ public class SqlServerMemoryStoreTests(SqlServerContainerFixture fixture) : ICla private SqlServerMemoryStore Store { get; set; } = null!; - private SqlServerContainerFixture Fixture { get; } = fixture; - public async Task InitializeAsync() { - await this.Fixture.InitializeAsync(); + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); - string connectionString = this.Fixture.GetConnectionString(); + var connectionString = configuration["SqlServer:ConnectionString"]; if (string.IsNullOrWhiteSpace(connectionString)) { @@ -46,11 +54,9 @@ public async Task InitializeAsync() public async Task DisposeAsync() { await this.CleanupDatabaseAsync(); - - await this.Fixture.DisposeAsync(); } - [Fact] + [Fact(Skip = SkipReason)] public async Task CreateCollectionAsync() { 
Assert.False(await this.Store.DoesCollectionExistAsync(DefaultCollectionName)); @@ -59,7 +65,7 @@ public async Task CreateCollectionAsync() Assert.True(await this.Store.DoesCollectionExistAsync(DefaultCollectionName)); } - [Fact] + [Fact(Skip = SkipReason)] public async Task DropCollectionAsync() { await this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -67,7 +73,7 @@ public async Task DropCollectionAsync() Assert.False(await this.Store.DoesCollectionExistAsync(DefaultCollectionName)); } - [Fact] + [Fact(Skip = SkipReason)] public async Task GetCollectionsAsync() { await this.Store.CreateCollectionAsync("collection1"); @@ -78,7 +84,7 @@ public async Task GetCollectionsAsync() Assert.Contains("collection2", collections); } - [Fact] + [Fact(Skip = SkipReason)] public async Task UpsertAsync() { await this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -98,7 +104,7 @@ public async Task UpsertAsync() Assert.Equal("Some id", id); } - [Theory] + [Theory(Skip = SkipReason)] [InlineData(true)] [InlineData(false)] public async Task GetAsync(bool withEmbeddings) @@ -122,7 +128,7 @@ public async Task GetAsync(bool withEmbeddings) record.Embedding.ToArray()); } - [Fact] + [Fact(Skip = SkipReason)] public async Task UpsertBatchAsync() { await this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -133,7 +139,7 @@ public async Task UpsertBatchAsync() id => Assert.Equal("Some other id", id)); } - [Theory] + [Theory(Skip = SkipReason)] [InlineData(true)] [InlineData(false)] public async Task GetBatchAsync(bool withEmbeddings) @@ -174,7 +180,7 @@ public async Task GetBatchAsync(bool withEmbeddings) }); } - [Fact] + [Fact(Skip = SkipReason)] public async Task RemoveAsync() { await this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -185,7 +191,7 @@ public async Task RemoveAsync() Assert.Null(await this.Store.GetAsync(DefaultCollectionName, "Some id")); } - [Fact] + [Fact(Skip = SkipReason)] public async Task RemoveBatchAsync() { await 
this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -198,7 +204,7 @@ public async Task RemoveBatchAsync() Assert.Null(await this.Store.GetAsync(DefaultCollectionName, "Some other id")); } - [Theory] + [Theory(Skip = SkipReason)] [InlineData(true)] [InlineData(false)] public async Task GetNearestMatchesAsync(bool withEmbeddings) @@ -242,7 +248,7 @@ public async Task GetNearestMatchesAsync(bool withEmbeddings) }); } - [Fact] + [Fact(Skip = SkipReason)] public async Task GetNearestMatchesWithMinRelevanceScoreAsync() { await this.Store.CreateCollectionAsync(DefaultCollectionName); @@ -259,7 +265,7 @@ public async Task GetNearestMatchesWithMinRelevanceScoreAsync() Assert.DoesNotContain(firstId, results.Select(r => r.Record.Metadata.Id)); } - [Theory] + [Theory(Skip = SkipReason)] [InlineData(true)] [InlineData(false)] public async Task GetNearestMatchAsync(bool withEmbeddings) From 179f56e8ee7075a9334dedb2df39332c96f764c9 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 12 Feb 2025 17:03:20 +0100 Subject: [PATCH 04/32] implement the tests using the new pattern, provide implementation that throws NotImplementedException --- .../SqlServerVectorStore.cs | 54 +++++++++++++++++++ .../Filter/SqlServerBasicFilterTests.cs | 10 ++++ .../Filter/SqlServerFilterFixture.cs | 12 +++++ .../Properties/AssemblyAttributes.cs | 3 ++ .../SqlServerIntegrationTests.csproj | 10 ++++ ...ServerConnectionStringRequiredAttribute.cs | 18 +++++++ .../Support/SqlServerTestEnvironment.cs | 25 +++++++++ .../Support/SqlServerTestStore.cs | 31 +++++++++++ 8 files changed, 163 insertions(+) create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerFilterFixture.cs create mode 100644 
dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Properties/AssemblyAttributes.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerConnectionStringRequiredAttribute.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestEnvironment.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs new file mode 100644 index 000000000000..0d73b806b6da --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +/// +/// An implementation of backed by a SQL Server or Azure SQL database. +/// +public sealed class SqlServerVectorStore : IVectorStore, IDisposable +{ + internal const string DefaultSchema = "dbo"; + internal const int DefaultEmbeddingDimensionsCount = 1536; + + private readonly SqlConnection _connection; + private readonly SqlServerClient _sqlServerClient; + + /// + /// Initializes a new instance of the class. + /// + /// Database connection. + /// Database schema of collection tables. + /// Number of dimensions that stored embeddings will use + public SqlServerVectorStore(SqlConnection connection, string schema = DefaultSchema, int embeddingDimensionsCount = DefaultEmbeddingDimensionsCount) + { + // TODO adsitnik: design: + // 1. Do we need a ctor that takes the connection string and creates a connection? + // What is the story with pooling for the SqlConnection type? + // Does it maintain a private instance pool? 
Or a static one? + // 2. Should we introduce an option bag for the schema and embeddingDimensionsCount? + // This would allow us to add more options in the future without breaking the API. + this._connection = connection; + this._sqlServerClient = new SqlServerClient(connection, schema, embeddingDimensionsCount); + } + + /// + public void Dispose() => this._connection.Dispose(); + + /// + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + { + Verify.NotNull(name); + + throw new System.NotImplementedException(); + } + + /// + public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) + => this._sqlServerClient.GetTablesAsync(cancellationToken); +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs new file mode 100644 index 000000000000..1c5552a2ec13 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; + +namespace SqlServerIntegrationTests.Filter; + +public class SqlServerBasicFilterTests(SqlServerFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerFilterFixture.cs new file mode 100644 index 000000000000..d467f3f2d6a2 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerFilterFixture.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace SqlServerIntegrationTests.Filter; + +public class SqlServerFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => SqlServerTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Properties/AssemblyAttributes.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Properties/AssemblyAttributes.cs new file mode 100644 index 000000000000..8f36e2be3f06 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Properties/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +[assembly: SqlServerIntegrationTests.Support.SqlServerConnectionStringRequiredAttribute] diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj index fe58fcd104ae..d39fecdb6d82 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj @@ -31,6 +31,16 @@ + + + + + + Always + + + Always + diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerConnectionStringRequiredAttribute.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerConnectionStringRequiredAttribute.cs new file mode 100644 index 000000000000..80885df9e18c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerConnectionStringRequiredAttribute.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using VectorDataSpecificationTests.Xunit; + +namespace SqlServerIntegrationTests.Support; + +/// +/// Checks whether the connection string for Sql Server is provided, and skips the test(s) otherwise. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)] +public sealed class SqlServerConnectionStringRequiredAttribute : Attribute, ITestCondition +{ + public ValueTask IsMetAsync() => new(SqlServerTestEnvironment.IsConnectionStringDefined); + + public string Skip { get; set; } = "ConnectionString is not configured, set SqlServer:ConnectionString."; + + public string SkipReason => this.Skip; +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestEnvironment.cs new file mode 100644 index 000000000000..043f4882e640 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestEnvironment.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel.Connectors.SqlServer; + +namespace SqlServerIntegrationTests.Support; + +internal static class SqlServerTestEnvironment +{ + public static readonly string? ConnectionString = GetConnectionString(); + + public static bool IsConnectionStringDefined => !string.IsNullOrEmpty(ConnectionString); + + private static string? 
GetConnectionString() + { + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true) + .AddJsonFile(path: "testsettings.development.json", optional: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + return configuration.GetSection("SqlServer")["ConnectionString"]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs new file mode 100644 index 000000000000..45ec63622e9f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.SqlServer; +using VectorDataSpecificationTests.Support; + +namespace SqlServerIntegrationTests.Support; + +public sealed class SqlServerTestStore : TestStore +{ + public static readonly SqlServerTestStore Instance = new(); + + public override IVectorStore DefaultVectorStore + => this._connectedStore ?? throw new InvalidOperationException("Not initialized"); + + private SqlServerVectorStore? 
_connectedStore; + + protected override async Task StartAsync() + { + if (string.IsNullOrWhiteSpace(SqlServerTestEnvironment.ConnectionString)) + { + throw new InvalidOperationException("Connection string is not configured, set the SqlServer:ConnectionString environment variable"); + } + + SqlConnection connection = new(SqlServerTestEnvironment.ConnectionString); + await connection.OpenAsync(); + + this._connectedStore = new(connection); + } +} From 486d0281dfac5e5b4e888360c01d01930ffa219f Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Fri, 14 Feb 2025 10:32:07 +0100 Subject: [PATCH 05/32] implement collection removal, existence check and creation --- .../Connectors.Memory.SqlServer.csproj | 4 + .../SqlServerClient.cs | 17 +-- .../SqlServerCommandBuilder.cs | 109 +++++++++++++++++ .../SqlServerVectorStore.cs | 99 +++++++++++++--- .../SqlServerVectorStoreOptions.cs | 21 ++++ .../SqlServerVectorStoreRecordCollection.cs | 112 ++++++++++++++++++ .../SqlServerCommandBuilderTests.cs | 106 +++++++++++++++++ .../SqlServerVectorStoreTests.cs | 50 ++++++++ 8 files changed, 489 insertions(+), 29 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj b/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj index ba73f9641bd9..457d1f4a8d93 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj +++ 
b/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj @@ -26,4 +26,8 @@ + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs index 4a1225f0a46f..f8a90eb75873 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs @@ -101,16 +101,7 @@ public async Task DoesTableExistsAsync(string tableName, CancellationToken { using (await this.OpenConnectionAsync(cancellationToken).ConfigureAwait(false)) { - using var cmd = this._connection.CreateCommand(); - cmd.CommandText = """ - SELECT TABLE_NAME - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_TYPE = 'BASE TABLE' - AND TABLE_SCHEMA = @schema - AND TABLE_NAME = @tableName - """; - cmd.Parameters.AddWithValue("@schema", this._schema); - cmd.Parameters.AddWithValue("@tableName", tableName); + using var cmd = SqlServerCommandBuilder.SelectTableName(this._connection, this._schema, tableName); using var reader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); return await reader.ReadAsync(cancellationToken).ConfigureAwait(false); } @@ -121,11 +112,7 @@ public async Task DeleteTableAsync(string tableName, CancellationToken cancellat { using (await this.OpenConnectionAsync(cancellationToken).ConfigureAwait(false)) { - using var cmd = this._connection.CreateCommand(); - var fullTableName = this.GetSanitizedFullTableName(tableName); - cmd.CommandText = $""" - DROP TABLE IF EXISTS {fullTableName} - """; + using var cmd = SqlServerCommandBuilder.DropTable(this._connection, this._schema, tableName); await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs new file mode 100644 index 
000000000000..31164dde0dd1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -0,0 +1,109 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; + +#pragma warning disable CA2100 // Review SQL queries for security vulnerabilities + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal static class SqlServerCommandBuilder +{ + internal static string GetSanitizedFullTableName(string schema, string tableName) + { + // If the column name contains a ], then escape it by doubling it. + // "Name with [brackets]" becomes [Name with [brackets]]]. + + StringBuilder sb = new(tableName.Length + schema.Length + 5); + sb.Append('['); + sb.Append(schema); + sb.Replace("]", "]]"); // replace the ] for schema + sb.Append("].["); + int index = sb.Length; // store the index, so we don't replace ] for schema twice + sb.Append(tableName); + sb.Replace("]", "]]", index, tableName.Length); + sb.Append(']'); + + return sb.ToString(); + } + + internal static SqlCommand CreateTable( + SqlConnection connection, + SqlServerVectorStoreOptions options, + string tableName, + bool ifNotExists, + VectorStoreRecordKeyProperty keyProperty, + IReadOnlyList dataProperties, + IReadOnlyList vectorProperties) + { + SqlCommand command = connection.CreateCommand(); + string fullTableName = GetSanitizedFullTableName(options.Schema, tableName); + + StringBuilder sb = new(200); + if (ifNotExists) + { + sb.AppendFormat("IF OBJECT_ID(N'{0}', N'U') IS NULL", fullTableName).AppendLine(); + } + sb.AppendFormat("CREATE TABLE {0} (", fullTableName).AppendLine(); + // Use square brackets to escape column names. 
+ string keyColumnName = GetColumnName(keyProperty); + sb.AppendFormat("[{0}] {1} NOT NULL,", keyColumnName, Map(keyProperty.PropertyType).sqlName).AppendLine(); + for (int i = 0; i < dataProperties.Count; i++) + { + (string sqlName, bool isNullable) = Map(dataProperties[i].PropertyType); + sb.AppendFormat(isNullable ? "[{0}] {1}," : "[{0}] {1} NOT NULL,", GetColumnName(dataProperties[i]), sqlName); + sb.AppendLine(); + } + for (int i = 0; i < vectorProperties.Count; i++) + { + sb.AppendFormat("[{0}] VECTOR({1}),", GetColumnName(vectorProperties[i]), vectorProperties[i].Dimensions); + sb.AppendLine(); + } + sb.AppendFormat("PRIMARY KEY NONCLUSTERED ([{0}])", keyColumnName).AppendLine(); + sb.Append(')'); // end the table definition + command.CommandText = sb.ToString(); + return command; + + static string GetColumnName(VectorStoreRecordProperty property) => property.StoragePropertyName ?? property.DataModelPropertyName; + + static (string sqlName, bool isNullable) Map(Type type) => type switch + { + Type t when t == typeof(int) => ("INT", false), + Type t when t == typeof(long) => ("BIGINT", false), + Type t when t == typeof(Guid) => ("UNIQUEIDENTIFIER", false), + Type t when t == typeof(string) => ("NVARCHAR(255) COLLATE Latin1_General_100_BIN2", true), + Type t when t == typeof(byte[]) => ("VARBINARY(MAX)", true), + Type t when t == typeof(bool) => ("BIT", false), + Type t when t == typeof(DateTime) => ("DATETIME", false), + Type t when t == typeof(TimeSpan) => ("TIME", false), + Type t when t == typeof(decimal) => ("DECIMAL", false), + Type t when t == typeof(double) => ("FLOAT", false), + Type t when t == typeof(float) => ("REAL", false), + _ => throw new NotSupportedException($"Type {type} is not supported.") + }; + } + + internal static SqlCommand DropTable(SqlConnection connection, string schema, string tableName) + { + SqlCommand command = connection.CreateCommand(); + string fullTableName = GetSanitizedFullTableName(schema, tableName); + 
command.CommandText = $"DROP TABLE IF EXISTS {fullTableName}"; + return command; + } + + internal static SqlCommand SelectTableName(SqlConnection connection, string schema, string tableName) + { + SqlCommand command = connection.CreateCommand(); + command.CommandText = """ + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND TABLE_SCHEMA = @schema + AND TABLE_NAME = @tableName + """; + command.Parameters.AddWithValue("@schema", schema); + command.Parameters.AddWithValue("@tableName", tableName); // the name is not escaped by us, just provided as parameter + return command; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs index 0d73b806b6da..2a6d90101987 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Threading; using Microsoft.Data.SqlClient; @@ -13,42 +14,112 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; /// public sealed class SqlServerVectorStore : IVectorStore, IDisposable { - internal const string DefaultSchema = "dbo"; - internal const int DefaultEmbeddingDimensionsCount = 1536; + private static readonly ConcurrentDictionary s_propertyReaders = new(); + + private static readonly HashSet s_supportedKeyTypes = + [ + typeof(int), // INT + typeof(long), // BIGINT + typeof(string), // VARCHAR + typeof(Guid), // UNIQUEIDENTIFIER + // TODO adsitnik: do we want to support DATETIME (DateTime) and VARBINARY (byte[])? + ]; + + private static readonly HashSet s_supportedDataTypes = + [ + typeof(int), // INT + typeof(long), // BIGINT. + typeof(Guid), // UNIQUEIDENTIFIER. 
+ typeof(string), // NVARCHAR + typeof(byte[]), //VARBINARY + typeof(bool), // BIT + typeof(DateTime), // DATETIME + typeof(TimeSpan), // TIME + typeof(decimal), // DECIMAL + typeof(double), // FLOAT + typeof(float) // REAL + ]; + + private static readonly HashSet s_supportedVectorTypes = + [ + typeof(ReadOnlyMemory), // VECTOR + typeof(ReadOnlyMemory?) + ]; private readonly SqlConnection _connection; - private readonly SqlServerClient _sqlServerClient; + private readonly SqlServerVectorStoreOptions _options; /// /// Initializes a new instance of the class. /// /// Database connection. - /// Database schema of collection tables. - /// Number of dimensions that stored embeddings will use - public SqlServerVectorStore(SqlConnection connection, string schema = DefaultSchema, int embeddingDimensionsCount = DefaultEmbeddingDimensionsCount) + /// Optional configuration options. + public SqlServerVectorStore(SqlConnection connection, SqlServerVectorStoreOptions? options = null) { // TODO adsitnik: design: - // 1. Do we need a ctor that takes the connection string and creates a connection? - // What is the story with pooling for the SqlConnection type? - // Does it maintain a private instance pool? Or a static one? - // 2. Should we introduce an option bag for the schema and embeddingDimensionsCount? - // This would allow us to add more options in the future without breaking the API. + // Do we need a ctor that takes the connection string and creates a connection? + // What is the story with pooling for the SqlConnection type? + // Does it maintain a private instance pool? Or a static one? this._connection = connection; - this._sqlServerClient = new SqlServerClient(connection, schema, embeddingDimensionsCount); + // We need to create a copy, so any changes made to the option bag after + // the ctor call do not affect this instance. + this._options = options is not null + ? 
new() { Schema = options.Schema, EmbeddingDimensionsCount = options.EmbeddingDimensionsCount } + : SqlServerVectorStoreOptions.Defaults; } /// public void Dispose() => this._connection.Dispose(); + // TODO: adsitnik: design + // I find the creation process uniutive: the IVectorStoreRecordCollection.Create + // method does take only table name as an arugment, the metadata needs to be provided + // a step before that by passing the VectorStoreRecordDefinition to the GetCollection method. + // I would expect VectorStoreRecordDefinition to be argument of the CreateCollectionAsync. + // Also, please consider another problem: + // On Monday, I pass two arguments to GetCollection: + // a name: "theName" + // and a definition: "theDefinition" that consists of two properties + // When I call CreateCollectionAsync, it gets created. + // On Tuesday, I pass the same name, but a different definition: three properties. + // Now CollectionExistsAsync returns true, despite the properties mismatch?! /// public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull { Verify.NotNull(name); - throw new System.NotImplementedException(); + if (!s_propertyReaders.TryGetValue(typeof(TRecord), out VectorStoreRecordPropertyReader? propertyReader)) + { + propertyReader = new(typeof(TRecord), + // TODO adsitnik: should we cache the property reader when user has provided the VectorStoreRecordDefinition? + vectorStoreRecordDefinition, + new() + { + RequiresAtLeastOneVector = false, + // TODO adsitnik: design: can TKey represent a composite key (PRIMARY KEY)? 
+ SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + }); + + propertyReader.VerifyKeyProperties(s_supportedKeyTypes); + // TODO adsitnik: get the list of supported ienumerable types + propertyReader.VerifyDataProperties(s_supportedDataTypes, supportEnumerable: true); + propertyReader.VerifyVectorProperties(s_supportedVectorTypes); + + // Add to the cache once we have verified the record definition. + s_propertyReaders.TryAdd(typeof(TRecord), propertyReader); + } + + return new SqlServerVectorStoreRecordCollection( + this._connection, + name, + this._options, + propertyReader); } /// public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) - => this._sqlServerClient.GetTablesAsync(cancellationToken); + { + throw new NotImplementedException(); + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs new file mode 100644 index 000000000000..ab341e6cbec7 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +/// +/// Options for creating a . +/// +public sealed class SqlServerVectorStoreOptions +{ + internal static readonly SqlServerVectorStoreOptions Defaults = new(); + + /// + /// Gets or sets the database schema. + /// + public string Schema { get; init; } = "dbo"; + + /// + /// Number of dimensions that stored embeddings will use. 
+ /// + public int EmbeddingDimensionsCount { get; init; } = 1536; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..81310eaf5151 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal sealed class SqlServerVectorStoreRecordCollection : IVectorStoreRecordCollection + where TKey : notnull +{ + private readonly SqlConnection _sqlConnection; + private readonly SqlServerVectorStoreOptions _options; + private readonly VectorStoreRecordPropertyReader _propertyReader; + + internal SqlServerVectorStoreRecordCollection(SqlConnection sqlConnection, string name, SqlServerVectorStoreOptions options, VectorStoreRecordPropertyReader propertyReader) + { + this._sqlConnection = sqlConnection; + this.CollectionName = name; + this._options = options; + this._propertyReader = propertyReader; + } + + public string CollectionName { get; } + + public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); + + using SqlCommand command = SqlServerCommandBuilder.SelectTableName(this._sqlConnection, this._options.Schema, this.CollectionName); + using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + return await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + } + + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + => 
this.CreateCollectionAsync(ifNotExists: false, cancellationToken); + + // TODO adsitnik: design: We typically don't provide such methods in BCL. + // 1. I totally see why we want to provide it, we just need to make sure it's the right thing to do. + // 2. An alternative would be to make CreateCollectionAsync a nop when the collection already exists + // or extend it with an optional boolan parameter that would control the behavior. + public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + => this.CreateCollectionAsync(ifNotExists: true, cancellationToken); + + private async Task CreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken) + { + await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); + + using SqlCommand command = SqlServerCommandBuilder.CreateTable( + this._sqlConnection, + this._options, + this.CollectionName, + ifNotExists, + this._propertyReader.KeyProperty, + this._propertyReader.DataProperties, + this._propertyReader.VectorProperties); + + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + + public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); + + using SqlCommand cmd = SqlServerCommandBuilder.DropTable(this._sqlConnection, this._options.Schema, this.CollectionName); + + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + + private Task EnsureConnectionIsOpenedAsync(CancellationToken cancellationToken) + => this._sqlConnection.State == System.Data.ConnectionState.Open + ? 
Task.CompletedTask + : this._sqlConnection.OpenAsync(cancellationToken); + + public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public IAsyncEnumerable UpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? 
options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs new file mode 100644 index 000000000000..43a50a59d6db --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -0,0 +1,106 @@ +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.SqlServer; +using Xunit; + +namespace SqlServerIntegrationTests; + +public class SqlServerCommandBuilderTests +{ + [Theory] + [InlineData("schema", "name", "[schema].[name]")] + [InlineData("schema", "[brackets]", "[schema].[[brackets]]]")] + [InlineData("needs]escaping", "[brackets]", "[needs]]escaping].[[brackets]]]")] + public void GetSanitizedFullTableName(string schema, string table, string expectedFullName) + { + string result = SqlServerCommandBuilder.GetSanitizedFullTableName(schema, table); + Assert.Equal(expectedFullName, result); + } + + [Theory] + [InlineData("schema", "simpleName", "[simpleName]")] + [InlineData("schema", "[needsEscaping]", "[[needsEscaping]]]")] + public void DropTable(string schema, string table, string expectedTable) + { + using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.DropTable(connection, schema, table); + + Assert.Equal($"DROP TABLE IF EXISTS [{schema}].{expectedTable}", command.CommandText); + } + + [Theory] + [InlineData("schema", "simpleName")] + [InlineData("schema", "[needsEscaping]")] + public void SelectTableName(string schema, string table) + { + using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.SelectTableName(connection, schema, table); + + Assert.Equal( + """ + SELECT TABLE_NAME + FROM 
INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND TABLE_SCHEMA = @schema + AND TABLE_NAME = @tableName + """ + , command.CommandText); + + Assert.Equal(schema, command.Parameters[0].Value); + Assert.Equal(table, command.Parameters[1].Value); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CreateTable(bool ifNotExists) + { + SqlServerVectorStoreOptions options = new() + { + Schema = "schema" + }; + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordDataProperty[] dataProperties = + [ + new VectorStoreRecordDataProperty("simpleName", typeof(string)), + new VectorStoreRecordDataProperty("with space", typeof(int)) + ]; + VectorStoreRecordVectorProperty[] vectorProperties = + [ + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10 + } + ]; + + using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.CreateTable(connection, options, "table", + ifNotExists, keyProperty, dataProperties, vectorProperties); + + string expectedCommand = + """ + CREATE TABLE [schema].[table] ( + [id] BIGINT NOT NULL, + [simpleName] NVARCHAR(255) COLLATE Latin1_General_100_BIN2, + [with space] INT NOT NULL, + [embedding1] VECTOR(10), + PRIMARY KEY NONCLUSTERED ([id]) + ) + """; + if (ifNotExists) + { + expectedCommand = "IF OBJECT_ID(N'[schema].[table]', N'U') IS NULL\n" + expectedCommand; + } + + if (OperatingSystem.IsWindows()) + { + expectedCommand = expectedCommand.Replace("\n", "\r\n"); + } + + Assert.Equal(expectedCommand, command.CommandText); + } + + // We create a connection using a fake connection string just to be able to create the SqlCommand. 
+ private static SqlConnection CreateConnection() + => new("Server=localhost;Database=master;Integrated Security=True;"); +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs new file mode 100644 index 000000000000..f9c3163559d4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -0,0 +1,50 @@ +using Microsoft.Extensions.VectorData; +using SqlServerIntegrationTests.Support; +using Xunit; + +namespace SqlServerIntegrationTests; + +public class SqlServerVectorStoreTests +{ + [Fact] + public async Task CanCreateAndDeleteTheCollections() + { + SqlServerTestStore testStore = new(); + + await testStore.ReferenceCountingStartAsync(); + + var collection = testStore.DefaultVectorStore.GetCollection("collection"); + + try + { + Assert.False(await collection.CollectionExistsAsync()); + + await collection.CreateCollectionAsync(); + + Assert.True(await collection.CollectionExistsAsync()); + + await collection.CreateCollectionIfNotExistsAsync(); + + Assert.True(await collection.CollectionExistsAsync()); + + await collection.DeleteCollectionAsync(); + + Assert.False(await collection.CollectionExistsAsync()); + } + finally + { + await collection.DeleteCollectionAsync(); + + await testStore.ReferenceCountingStopAsync(); + } + } + + public sealed class TestModel + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public string Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "column")] + public int Number { get; set; } + } +} From a41cac468deef67eed12dc79cd68f99c8c71f249 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Sun, 16 Feb 2025 23:10:09 +0100 Subject: [PATCH 06/32] implement record insert and update (upsert) --- .../SqlServerCommandBuilder.cs | 104 +++++++++++++++- .../SqlServerVectorStore.cs | 3 + .../SqlServerVectorStoreOptions.cs | 2 + 
.../SqlServerVectorStoreRecordCollection.cs | 67 ++++++++++- .../VectorStoreRecordKeyAttribute.cs | 2 + .../SqlServerCommandBuilderTests.cs | 111 +++++++++++++++++- .../SqlServerVectorStoreTests.cs | 29 +++++ 7 files changed, 310 insertions(+), 8 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 31164dde0dd1..252befdaff66 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -1,10 +1,12 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; #pragma warning disable CA2100 // Review SQL queries for security vulnerabilities +#pragma warning disable CA1851 // Possible multiple enumerations of IEnumerable namespace Microsoft.SemanticKernel.Connectors.SqlServer; @@ -65,8 +67,6 @@ internal static SqlCommand CreateTable( command.CommandText = sb.ToString(); return command; - static string GetColumnName(VectorStoreRecordProperty property) => property.StoragePropertyName ?? 
property.DataModelPropertyName; - static (string sqlName, bool isNullable) Map(Type type) => type switch { Type t when t == typeof(int) => ("INT", false), @@ -106,4 +106,104 @@ FROM INFORMATION_SCHEMA.TABLES command.Parameters.AddWithValue("@tableName", tableName); // the name is not escaped by us, just provided as parameter return command; } + + internal static SqlCommand InsertInto( + SqlConnection connection, + SqlServerVectorStoreOptions options, + string tableName, + VectorStoreRecordKeyProperty keyProperty, + IReadOnlyList dataProperties, + IReadOnlyList vectorProperties, + Dictionary record) + { + SqlCommand command = connection.CreateCommand(); + string fullTableName = GetSanitizedFullTableName(options.Schema, tableName); + StringBuilder sb = new(200); + sb.AppendFormat("INSERT INTO {0} (", fullTableName); + // Use square brackets to escape column names. + foreach (VectorStoreRecordProperty property in dataProperties.Concat(vectorProperties)) + { + sb.AppendFormat("[{0}],", GetColumnName(property)); + } + sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis + sb.AppendLine(); + sb.AppendFormat("OUTPUT inserted.[{0}]", GetColumnName(keyProperty)); + sb.AppendLine(); + sb.Append("VALUES ("); + foreach (VectorStoreRecordProperty property in dataProperties.Concat(vectorProperties)) + { + int index = sb.Length; + sb.AppendFormat("@{0},", GetColumnName(property)); + string paramName = sb.ToString(index, sb.Length - index - 1); // 1 is for the comma + command.Parameters.AddWithValue(paramName, record[property.DataModelPropertyName] ?? 
(object)DBNull.Value); + } + sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis + sb.Append(';'); + + command.CommandText = sb.ToString(); + return command; + } + + internal static SqlCommand MergeInto( + SqlConnection connection, + SqlServerVectorStoreOptions options, + string tableName, + VectorStoreRecordKeyProperty keyProperty, + IReadOnlyList dataProperties, + IReadOnlyList vectorProperties, + Dictionary record) + { + SqlCommand command = connection.CreateCommand(); + string fullTableName = GetSanitizedFullTableName(options.Schema, tableName); + StringBuilder sb = new(200); + sb.AppendFormat("MERGE INTO {0} AS t", fullTableName).AppendLine(); + sb.Append("USING (VALUES ("); + var allProperties = new VectorStoreRecordProperty[] { keyProperty }.Concat(dataProperties).Concat(vectorProperties); + foreach (VectorStoreRecordProperty property in allProperties) + { + int index = sb.Length; + sb.AppendFormat("@{0},", GetColumnName(property)); + string paramName = sb.ToString(index, sb.Length - index - 1); // 1 is for the comma + command.Parameters.AddWithValue(paramName, record[property.DataModelPropertyName] ?? 
(object)DBNull.Value); + } + sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis + sb.AppendFormat(") AS s ("); + foreach (VectorStoreRecordProperty property in allProperties) + { + sb.AppendFormat("[{0}],", GetColumnName(property)); + } + sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis + sb.AppendLine(); + sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine(); + sb.AppendLine("WHEN MATCHED THEN"); + sb.Append("UPDATE SET "); + foreach (VectorStoreRecordProperty property in dataProperties.Concat(vectorProperties)) + { + sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property)); + } + --sb.Length; // remove the last comma + sb.AppendLine(); + sb.Append("WHEN NOT MATCHED THEN"); + sb.AppendLine(); + sb.Append("INSERT ("); + foreach (VectorStoreRecordProperty property in allProperties) + { + sb.AppendFormat("[{0}],", GetColumnName(property)); + } + sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis + sb.AppendLine(); + sb.Append("VALUES ("); + foreach (VectorStoreRecordProperty property in allProperties) + { + sb.AppendFormat("s.[{0}],", GetColumnName(property)); + } + sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis + sb.Append(';'); + + command.CommandText = sb.ToString(); + return command; + } + + private static string GetColumnName(VectorStoreRecordProperty property) + => property.StoragePropertyName ?? 
property.DataModelPropertyName; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs index 2a6d90101987..c88c01da8890 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs @@ -12,6 +12,9 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; /// /// An implementation of backed by a SQL Server or Azure SQL database. /// +// TODO adsitnik: design: the interface is not generic, so I am not sure how the users can customize the +// mapping between the record and the table. Am I missing something? +// The interface I am talking about: public IVectorStoreRecordMapper>. public sealed class SqlServerVectorStore : IVectorStore, IDisposable { private static readonly ConcurrentDictionary s_propertyReaders = new(); diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs index ab341e6cbec7..653ec9fb1c18 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs @@ -17,5 +17,7 @@ public sealed class SqlServerVectorStoreOptions /// /// Number of dimensions that stored embeddings will use. /// + // TODO: adsitnik: this design most likely won't need this setting, + // as it up to the TRecrod to define the dimensions. 
public int EmbeddingDimensionsCount { get; init; } = 1536; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 81310eaf5151..d072ca0920a2 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Data.SqlClient; @@ -95,9 +96,46 @@ public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecord throw new NotImplementedException(); } - public Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + Verify.NotNull(record); + + await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); + + TKey? key = (TKey)this._propertyReader.KeyPropertyInfo.GetValue(record); + Dictionary map = Map(record, this._propertyReader, key); + using SqlCommand command = key is null + // When the key is null, we are inserting a new record. + ? 
SqlServerCommandBuilder.InsertInto( + this._sqlConnection, + this._options, + this.CollectionName, + this._propertyReader.KeyProperty, + this._propertyReader.DataProperties, + this._propertyReader.VectorProperties, + map) + : SqlServerCommandBuilder.MergeInto( + this._sqlConnection, + this._options, + this.CollectionName, + this._propertyReader.KeyProperty, + this._propertyReader.DataProperties, + this._propertyReader.VectorProperties, + map); + + if (key is not null) + { + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + return key; + } + + if (typeof(int) == typeof(TKey)) + { + return (TKey)(object)await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + + SqlDataReader sqlDataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + return await sqlDataReader.GetFieldValueAsync(0, cancellationToken).ConfigureAwait(false); } public IAsyncEnumerable UpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken = default) @@ -109,4 +147,29 @@ public Task> VectorizedSearchAsync(TVector { throw new NotImplementedException(); } + + private static Dictionary Map(TRecord record, VectorStoreRecordPropertyReader propertyReader, TKey key) + { + Dictionary map = new(StringComparer.Ordinal); + map[propertyReader.KeyProperty.DataModelPropertyName] = key; + + var dataProperties = propertyReader.DataProperties; + for (int i = 0; i < dataProperties.Count; i++) + { + map[dataProperties[i].DataModelPropertyName] = propertyReader.DataPropertiesInfo[i].GetValue(record); + } + var vectorProperties = propertyReader.VectorProperties; + for (int i = 0; i < vectorProperties.Count; i++) + { + // We restrict the vector properties to ReadOnlyMemory so the cast here is safe. 
+ ReadOnlyMemory floats = (ReadOnlyMemory)propertyReader.VectorPropertiesInfo[i].GetValue(record); + // We know that SqlServer supports JSON serialization, so we can serialize the vector as JSON now, + // so the SqlServerCommandBuilder does not need to worry about that. + // TODO adsitnik perf: we could remove the dependency to System.Text.Json + // by using a hand-written serializer. + map[vectorProperties[i].DataModelPropertyName] = JsonSerializer.Serialize(floats); + } + + return map; + } } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs index 871794872adc..a2f27488fd63 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs @@ -11,6 +11,8 @@ namespace Microsoft.Extensions.VectorData; /// The characteristics defined here will influence how the property is treated by the vector store. /// [AttributeUsage(AttributeTargets.Property, AllowMultiple = false)] +// TODO adsitnik design: this attribute does not allow us to tell the DB to insert the key +// and upsert expects us to handle such scenario. 
public sealed class VectorStoreRecordKeyAttribute : Attribute { /// diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 43a50a59d6db..7c4de285c8d6 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -92,14 +92,117 @@ PRIMARY KEY NONCLUSTERED ([id]) expectedCommand = "IF OBJECT_ID(N'[schema].[table]', N'U') IS NULL\n" + expectedCommand; } - if (OperatingSystem.IsWindows()) + Assert.Equal(HandleNewLines(expectedCommand), command.CommandText); + } + + [Fact] + public void InsertInto() + { + SqlServerVectorStoreOptions options = new() { - expectedCommand = expectedCommand.Replace("\n", "\r\n"); - } + Schema = "schema" + }; + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordDataProperty[] dataProperties = + [ + new VectorStoreRecordDataProperty("simpleString", typeof(string)), + new VectorStoreRecordDataProperty("simpleInt", typeof(int)) + ]; + VectorStoreRecordVectorProperty[] vectorProperties = + [ + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10 + } + ]; + + using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.InsertInto(connection, options, "table", + keyProperty, dataProperties, vectorProperties, + new Dictionary + { + { "id", null }, + { "simpleString", "nameValue" }, + { "simpleInt", 134 }, + { "embedding1", "{ 10.0 }" } + }); + + string expectedCommand = + """ + INSERT INTO [schema].[table] ([simpleString],[simpleInt],[embedding1]) + OUTPUT inserted.[id] + VALUES (@simpleString,@simpleInt,@embedding1); + """; - Assert.Equal(expectedCommand, command.CommandText); + 
Assert.Equal(HandleNewLines(expectedCommand), command.CommandText); + Assert.Equal("@simpleString", command.Parameters[0].ParameterName); + Assert.Equal("nameValue", command.Parameters[0].Value); + Assert.Equal("@simpleInt", command.Parameters[1].ParameterName); + Assert.Equal(134, command.Parameters[1].Value); + Assert.Equal("@embedding1", command.Parameters[2].ParameterName); + Assert.Equal("{ 10.0 }", command.Parameters[2].Value); } + [Fact] + public void MergeInto() + { + SqlServerVectorStoreOptions options = new() + { + Schema = "schema" + }; + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordDataProperty[] dataProperties = + [ + new VectorStoreRecordDataProperty("simpleString", typeof(string)), + new VectorStoreRecordDataProperty("simpleInt", typeof(int)) + ]; + VectorStoreRecordVectorProperty[] vectorProperties = + [ + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10 + } + ]; + + using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.MergeInto(connection, options, "table", + keyProperty, dataProperties, vectorProperties, + new Dictionary + { + { "id", null }, + { "simpleString", "nameValue" }, + { "simpleInt", 134 }, + { "embedding1", "{ 10.0 }" } + }); + + string expectedCommand = + """" + MERGE INTO [schema].[table] AS t + USING (VALUES (@id,@simpleString,@simpleInt,@embedding1)) AS s ([id],[simpleString],[simpleInt],[embedding1]) + ON (t.[id] = s.[id]) + WHEN MATCHED THEN + UPDATE SET t.[simpleString] = s.[simpleString],t.[simpleInt] = s.[simpleInt],t.[embedding1] = s.[embedding1] + WHEN NOT MATCHED THEN + INSERT ([id],[simpleString],[simpleInt],[embedding1]) + VALUES (s.[id],s.[simpleString],s.[simpleInt],s.[embedding1]); + """"; + + Assert.Equal(HandleNewLines(expectedCommand), command.CommandText); + Assert.Equal("@id", command.Parameters[0].ParameterName); + Assert.Equal(DBNull.Value, command.Parameters[0].Value); + 
Assert.Equal("@simpleString", command.Parameters[1].ParameterName); + Assert.Equal("nameValue", command.Parameters[1].Value); + Assert.Equal("@simpleInt", command.Parameters[2].ParameterName); + Assert.Equal(134, command.Parameters[2].Value); + Assert.Equal("@embedding1", command.Parameters[3].ParameterName); + Assert.Equal("{ 10.0 }", command.Parameters[3].Value); + } + + private static string HandleNewLines(string expectedCommand) + => OperatingSystem.IsWindows() + ? expectedCommand.Replace("\n", "\r\n") + : expectedCommand; + // We create a connection using a fake connection string just to be able to create the SqlCommand. private static SqlConnection CreateConnection() => new("Server=localhost;Database=master;Integrated Security=True;"); diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index f9c3163559d4..05b0edb530fa 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -39,6 +39,35 @@ public async Task CanCreateAndDeleteTheCollections() } } + [Fact] + public async Task CanInsertRecord() + { + SqlServerTestStore testStore = new(); + + await testStore.ReferenceCountingStartAsync(); + + var collection = testStore.DefaultVectorStore.GetCollection("other"); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + string key = await collection.UpsertAsync(new TestModel() + { + Id = "MyId", + Number = 100 + }); + + Assert.Equal("MyId", key); + } + finally + { + await collection.DeleteCollectionAsync(); + + await testStore.ReferenceCountingStopAsync(); + } + } + public sealed class TestModel { [VectorStoreRecordKey(StoragePropertyName = "key")] From 6dfb04a45a4d4a8a18ca19ca1c2de1bac01d45fc Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Mon, 17 Feb 
2025 15:08:07 +0100 Subject: [PATCH 07/32] implement delete operations --- .../SqlServerCommandBuilder.cs | 57 +++++++++++++++++++ .../SqlServerVectorStoreRecordCollection.cs | 34 +++++++++-- .../SqlServerCommandBuilderTests.cs | 46 ++++++++++++++- .../SqlServerVectorStoreTests.cs | 16 +++++- 4 files changed, 142 insertions(+), 11 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 252befdaff66..9a102d500f02 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -204,6 +204,63 @@ internal static SqlCommand MergeInto( return command; } + internal static SqlCommand DeleteSingle( + SqlConnection connection, string schema, string tableName, + VectorStoreRecordKeyProperty keyProperty, object key) + { + SqlCommand command = connection.CreateCommand(); + string fullTableName = GetSanitizedFullTableName(schema, tableName); + string keyParamName = $"@{GetColumnName(keyProperty)}"; + command.CommandText = + $"""" + DELETE + FROM {fullTableName} + WHERE [{GetColumnName(keyProperty)}] = {keyParamName} + """"; + command.Parameters.AddWithValue(keyParamName, key); + return command; + } + + + internal static SqlCommand DeleteMany( + SqlConnection connection, string schema, string tableName, + VectorStoreRecordKeyProperty keyProperty, IEnumerable keys) + { + SqlCommand command = connection.CreateCommand(); + string fullTableName = GetSanitizedFullTableName(schema, tableName); + string keyParamName = $"@{GetColumnName(keyProperty)}"; + + StringBuilder keyParams = new(); + int keyIndex = 0; + foreach (TKey key in keys) + { + // The caller ensures that keys collection is not null. + // We need to ensure that none of the keys is null. 
+ Verify.NotNull(key); + int index = keyParams.Length; + keyParams.AppendFormat("@k{0},", keyIndex++); + string keyParam = keyParams.ToString(index, keyParams.Length - index - 1); // 1 is for the comma + command.Parameters.AddWithValue(keyParam, key); + } + + if (keyParams.Length == 0) + { + // TODO adsitnik design: should we throw or simply do nothing? + throw new ArgumentException("The value cannot be empty.", nameof(keys)); + } + + keyParams.Length--; // remove the last comma + + command.CommandText = + $"""" + DELETE + FROM {fullTableName} + WHERE [{GetColumnName(keyProperty)}] IN ({keyParams}) + """"; + + return command; + } + private static string GetColumnName(VectorStoreRecordProperty property) => property.StoragePropertyName ?? property.DataModelPropertyName; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index d072ca0920a2..3f86f656c47c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -66,9 +66,9 @@ public async Task DeleteCollectionAsync(CancellationToken cancellationToken = de { await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand cmd = SqlServerCommandBuilder.DropTable(this._sqlConnection, this._options.Schema, this.CollectionName); + using SqlCommand command = SqlServerCommandBuilder.DropTable(this._sqlConnection, this._options.Schema, this.CollectionName); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } private Task EnsureConnectionIsOpenedAsync(CancellationToken cancellationToken) @@ -76,14 +76,36 @@ private Task EnsureConnectionIsOpenedAsync(CancellationToken cancellationToken) ? 
Task.CompletedTask : this._sqlConnection.OpenAsync(cancellationToken); - public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) + public async Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + Verify.NotNull(key); + + await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); + + using SqlCommand command = SqlServerCommandBuilder.DeleteSingle( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + this._propertyReader.KeyProperty, + key); + + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } - public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + Verify.NotNull(keys); + + await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); + + using SqlCommand command = SqlServerCommandBuilder.DeleteMany( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + this._propertyReader.KeyProperty, + keys); + + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } public Task GetAsync(TKey key, GetRecordOptions? 
options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 7c4de285c8d6..1faf82120574 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -119,7 +119,7 @@ public void InsertInto() using SqlConnection connection = CreateConnection(); using SqlCommand command = SqlServerCommandBuilder.InsertInto(connection, options, "table", keyProperty, dataProperties, vectorProperties, - new Dictionary + new Dictionary { { "id", null }, { "simpleString", "nameValue" }, @@ -167,7 +167,7 @@ public void MergeInto() using SqlConnection connection = CreateConnection(); using SqlCommand command = SqlServerCommandBuilder.MergeInto(connection, options, "table", keyProperty, dataProperties, vectorProperties, - new Dictionary + new Dictionary { { "id", null }, { "simpleString", "nameValue" }, @@ -198,6 +198,48 @@ WHEN NOT MATCHED THEN Assert.Equal("{ 10.0 }", command.Parameters[3].Value); } + [Fact] + public void DeleteSingle() + { + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.DeleteSingle(connection, + "schema", "tableName", keyProperty, 123L); + + Assert.Equal( + """"" + DELETE + FROM [schema].[tableName] + WHERE [id] = @id + """"", command.CommandText); + Assert.Equal(123L, command.Parameters[0].Value); + Assert.Equal("@id", command.Parameters[0].ParameterName); + } + + [Fact] + public void DeleteMany() + { + string[] keys = ["key1", "key2"]; + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(string)); + using SqlConnection connection = CreateConnection(); + + using 
SqlCommand command = SqlServerCommandBuilder.DeleteMany(connection, + "schema", "tableName", keyProperty, keys); + + Assert.Equal( + """"" + DELETE + FROM [schema].[tableName] + WHERE [id] IN (@k0,@k1) + """"", command.CommandText); + for (int i = 0; i < keys.Length; i++) + { + Assert.Equal(keys[i], command.Parameters[i].Value); + Assert.Equal($"@k{i}", command.Parameters[i].ParameterName); + } + } + private static string HandleNewLines(string expectedCommand) => OperatingSystem.IsWindows() ? expectedCommand.Replace("\n", "\r\n") diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index 05b0edb530fa..db47fa5038f9 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -39,8 +39,10 @@ public async Task CanCreateAndDeleteTheCollections() } } - [Fact] - public async Task CanInsertRecord() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanInsertAndDeleteRecord(bool deleteBatch) { SqlServerTestStore testStore = new(); @@ -57,8 +59,16 @@ public async Task CanInsertRecord() Id = "MyId", Number = 100 }); - Assert.Equal("MyId", key); + + if (deleteBatch) + { + await collection.DeleteBatchAsync(["MyId"]); + } + else + { + await collection.DeleteAsync("MyId"); + } } finally { From e90540920103f885b25152e33cc20403961e8902 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Tue, 18 Feb 2025 12:31:18 +0100 Subject: [PATCH 08/32] GetAsync and GetBatchAsync --- .../SqlServerCommandBuilder.cs | 86 ++++++++++++++++--- .../SqlServerVectorStoreRecordCollection.cs | 72 +++++++++++++++- .../SqlServerCommandBuilderTests.cs | 60 +++++++++++++ .../SqlServerVectorStoreTests.cs | 30 ++++++- 4 files changed, 230 insertions(+), 18 deletions(-) diff --git 
a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 9a102d500f02..61d7875758af 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -221,15 +221,85 @@ internal static SqlCommand DeleteSingle( return command; } - internal static SqlCommand DeleteMany( SqlConnection connection, string schema, string tableName, VectorStoreRecordKeyProperty keyProperty, IEnumerable keys) { SqlCommand command = connection.CreateCommand(); string fullTableName = GetSanitizedFullTableName(schema, tableName); + StringBuilder keyParams = CreateKeyParameterList(keys, command); + + command.CommandText = + $"""" + DELETE + FROM {fullTableName} + WHERE [{GetColumnName(keyProperty)}] IN ({keyParams}) + """"; + + return command; + } + + internal static SqlCommand SelectSingle( + SqlConnection sqlConnection, string schema, string collectionName, + VectorStoreRecordKeyProperty keyProperty, + IReadOnlyList properties, + object key) + { + SqlCommand command = sqlConnection.CreateCommand(); + string fullTableName = GetSanitizedFullTableName(schema, collectionName); string keyParamName = $"@{GetColumnName(keyProperty)}"; + command.Parameters.AddWithValue(keyParamName, key); + + StringBuilder sb = new(200); + sb.AppendFormat("SELECT "); + AppendColumnNames(properties, sb); + sb.AppendLine(); + sb.AppendFormat("FROM {0}", fullTableName); + sb.AppendLine(); + sb.AppendFormat("WHERE [{0}] = {1}", GetColumnName(keyProperty), keyParamName); + command.CommandText = sb.ToString(); + + return command; + } + + internal static SqlCommand SelectMany( + SqlConnection connection, string schema, string tableName, + VectorStoreRecordKeyProperty keyProperty, + IReadOnlyList properties, + IEnumerable keys) + { + SqlCommand command = connection.CreateCommand(); + string fullTableName = 
GetSanitizedFullTableName(schema, tableName); + StringBuilder keyParams = CreateKeyParameterList(keys, command); + StringBuilder sb = new(200); + sb.AppendFormat("SELECT "); + AppendColumnNames(properties, sb); + sb.AppendLine(); + sb.AppendFormat("FROM {0}", fullTableName); + sb.AppendLine(); + sb.AppendFormat("WHERE [{0}] IN ({1})", GetColumnName(keyProperty), keyParams); + + command.CommandText = sb.ToString(); + + return command; + } + + private static void AppendColumnNames(IReadOnlyList properties, StringBuilder sb) + { + foreach (VectorStoreRecordProperty property in properties) + { + sb.AppendFormat("[{0}],", GetColumnName(property)); + } + + if (properties.Count > 0) + { + --sb.Length; // remove the last comma + } + } + + private static StringBuilder CreateKeyParameterList(IEnumerable keys, SqlCommand command) + { StringBuilder keyParams = new(); int keyIndex = 0; foreach (TKey key in keys) @@ -245,22 +315,14 @@ internal static SqlCommand DeleteMany( if (keyParams.Length == 0) { - // TODO adsitnik design: should we throw or simply do nothing? + // TODO adsitnik clarify: should we throw or simply do nothing? throw new ArgumentException("The value cannot be empty.", nameof(keys)); } keyParams.Length--; // remove the last comma - - command.CommandText = - $"""" - DELETE - FROM {fullTableName} - WHERE [{GetColumnName(keyProperty)}] IN ({keyParams}) - """"; - - return command; + return keyParams; } - private static string GetColumnName(VectorStoreRecordProperty property) + internal static string GetColumnName(VectorStoreRecordProperty property) => property.StoragePropertyName ?? 
property.DataModelPropertyName; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 3f86f656c47c..06281ade072e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -108,14 +108,49 @@ public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken can await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } - public Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + Verify.NotNull(key); + + await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); + + using SqlCommand command = SqlServerCommandBuilder.SelectSingle( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + this._propertyReader.KeyProperty, + this._propertyReader.Properties, + key); + + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + + using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + return await reader.ReadAsync(cancellationToken).ConfigureAwait(false) + ? Map(reader, this._propertyReader) + : default; } - public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? 
options = null, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + Verify.NotNull(keys); + + await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); + + using SqlCommand command = SqlServerCommandBuilder.SelectMany( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + this._propertyReader.KeyProperty, + this._propertyReader.Properties, + keys); + + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + + using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return Map(reader, this._propertyReader); + } } public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) @@ -194,4 +229,33 @@ public Task> VectorizedSearchAsync(TVector return map; } + + private static TRecord Map(SqlDataReader reader, VectorStoreRecordPropertyReader propertyReader) + { + TRecord record = Activator.CreateInstance()!; + propertyReader.KeyPropertyInfo.SetValue(record, reader[SqlServerCommandBuilder.GetColumnName(propertyReader.KeyProperty)]); + var data = propertyReader.DataProperties; + var dataInfo = propertyReader.DataPropertiesInfo; + for (int i = 0; i < data.Count; i++) + { + object value = reader[SqlServerCommandBuilder.GetColumnName(data[i])]; + if (value is not DBNull) + { + dataInfo[i].SetValue(record, value); + } + } + var vector = propertyReader.VectorProperties; + var vectorInfo = propertyReader.VectorPropertiesInfo; + for (int i = 0; i < vector.Count; i++) + { + object value = reader[SqlServerCommandBuilder.GetColumnName(vector[i])]; + if (value is not DBNull) + { + // We know that it has to be a ReadOnlyMemory because that's what we serialized. 
+ ReadOnlyMemory embedding = JsonSerializer.Deserialize>((string)value); + vectorInfo[i].SetValue(record, embedding); + } + } + return record; + } } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 1faf82120574..620a538b33f9 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -240,6 +240,66 @@ WHERE [id] IN (@k0,@k1) } } + [Fact] + public void SelectSingle() + { + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordProperty[] properties = [ + keyProperty, + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("age", typeof(int)), + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) + { + Dimensions = 10 + } + ]; + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.SelectSingle(connection, + "schema", "tableName", keyProperty, properties, 123L); + + Assert.Equal(HandleNewLines( + """"" + SELECT [id],[name],[age],[embedding] + FROM [schema].[tableName] + WHERE [id] = @id + """""), command.CommandText); + Assert.Equal(123L, command.Parameters[0].Value); + Assert.Equal("@id", command.Parameters[0].ParameterName); + } + + [Fact] + public void SelectMany() + { + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordProperty[] properties = [ + keyProperty, + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("age", typeof(int)), + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) + { + Dimensions = 10 + } + ]; + long[] keys = [123L, 456L, 789L]; + using SqlConnection connection = CreateConnection(); + 
+ using SqlCommand command = SqlServerCommandBuilder.SelectMany(connection, + "schema", "tableName", keyProperty, properties, keys); + + Assert.Equal(HandleNewLines( + """"" + SELECT [id],[name],[age],[embedding] + FROM [schema].[tableName] + WHERE [id] IN (@k0,@k1,@k2) + """""), command.CommandText); + for (int i = 0; i < keys.Length; i++) + { + Assert.Equal(keys[i], command.Parameters[i].Value); + Assert.Equal($"@k{i}", command.Parameters[i].ParameterName); + } + } + private static string HandleNewLines(string expectedCommand) => OperatingSystem.IsWindows() ? expectedCommand.Replace("\n", "\r\n") diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index db47fa5038f9..abb33e260e19 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -1,4 +1,5 @@ -using Microsoft.Extensions.VectorData; +using System.Linq; +using Microsoft.Extensions.VectorData; using SqlServerIntegrationTests.Support; using Xunit; @@ -54,13 +55,29 @@ public async Task CanInsertAndDeleteRecord(bool deleteBatch) { await collection.CreateCollectionIfNotExistsAsync(); + ReadOnlyMemory floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray(); string key = await collection.UpsertAsync(new TestModel() { Id = "MyId", - Number = 100 + Number = 100, + Floats = floats }); Assert.Equal("MyId", key); + TestModel? 
record = await collection.GetAsync("MyId"); + Assert.NotNull(record); + Assert.Equal(100, record.Number); + Assert.Equal("MyId", record.Id); + Assert.Equal(floats, record.Floats); + Assert.Null(record.Text); + + record = await collection.GetBatchAsync(["MyId"]).SingleAsync(); + Assert.NotNull(record); + Assert.Equal(100, record.Number); + Assert.Equal("MyId", record.Id); + Assert.Equal(floats, record.Floats); + Assert.Null(record.Text); + if (deleteBatch) { await collection.DeleteBatchAsync(["MyId"]); @@ -69,6 +86,9 @@ public async Task CanInsertAndDeleteRecord(bool deleteBatch) { await collection.DeleteAsync("MyId"); } + + Assert.Null(await collection.GetAsync("MyId")); + Assert.False(await collection.GetBatchAsync(["MyId"]).AnyAsync()); } finally { @@ -83,7 +103,13 @@ public sealed class TestModel [VectorStoreRecordKey(StoragePropertyName = "key")] public string Id { get; set; } + [VectorStoreRecordData(StoragePropertyName = "text")] + public string? Text { get; set; } + [VectorStoreRecordData(StoragePropertyName = "column")] public int Number { get; set; } + + [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] + public ReadOnlyMemory Floats { get; set; } } } From 7c212da28d195778a27ae149791295cffda6b455 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Tue, 18 Feb 2025 12:59:07 +0100 Subject: [PATCH 09/32] refactor --- .../SqlServerCommandBuilder.cs | 224 ++++++++++-------- .../SqlServerVectorStoreRecordCollection.cs | 3 +- .../SqlServerCommandBuilderTests.cs | 24 +- 3 files changed, 127 insertions(+), 124 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 61d7875758af..5974b0364f75 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -1,6 +1,7 @@ using System; using 
System.Collections.Generic; using System.Linq; +using System.Net.NetworkInformation; using System.Text; using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; @@ -12,24 +13,6 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; internal static class SqlServerCommandBuilder { - internal static string GetSanitizedFullTableName(string schema, string tableName) - { - // If the column name contains a ], then escape it by doubling it. - // "Name with [brackets]" becomes [Name with [brackets]]]. - - StringBuilder sb = new(tableName.Length + schema.Length + 5); - sb.Append('['); - sb.Append(schema); - sb.Replace("]", "]]"); // replace the ] for schema - sb.Append("].["); - int index = sb.Length; // store the index, so we don't replace ] for schema twice - sb.Append(tableName); - sb.Replace("]", "]]", index, tableName.Length); - sb.Append(']'); - - return sb.ToString(); - } - internal static SqlCommand CreateTable( SqlConnection connection, SqlServerVectorStoreOptions options, @@ -39,18 +22,20 @@ internal static SqlCommand CreateTable( IReadOnlyList dataProperties, IReadOnlyList vectorProperties) { - SqlCommand command = connection.CreateCommand(); - string fullTableName = GetSanitizedFullTableName(options.Schema, tableName); - StringBuilder sb = new(200); if (ifNotExists) { - sb.AppendFormat("IF OBJECT_ID(N'{0}', N'U') IS NULL", fullTableName).AppendLine(); + sb.Append("IF OBJECT_ID(N'"); + sb.AppendTableName(options.Schema, tableName); + sb.AppendLine("', N'U') IS NULL"); } - sb.AppendFormat("CREATE TABLE {0} (", fullTableName).AppendLine(); + sb.Append("CREATE TABLE "); + sb.AppendTableName(options.Schema, tableName); + sb.AppendLine(" ("); // Use square brackets to escape column names. 
string keyColumnName = GetColumnName(keyProperty); - sb.AppendFormat("[{0}] {1} NOT NULL,", keyColumnName, Map(keyProperty.PropertyType).sqlName).AppendLine(); + sb.AppendFormat("[{0}] {1} NOT NULL,", keyColumnName, Map(keyProperty.PropertyType).sqlName); + sb.AppendLine(); for (int i = 0; i < dataProperties.Count; i++) { (string sqlName, bool isNullable) = Map(dataProperties[i].PropertyType); @@ -62,10 +47,11 @@ internal static SqlCommand CreateTable( sb.AppendFormat("[{0}] VECTOR({1}),", GetColumnName(vectorProperties[i]), vectorProperties[i].Dimensions); sb.AppendLine(); } - sb.AppendFormat("PRIMARY KEY NONCLUSTERED ([{0}])", keyColumnName).AppendLine(); + sb.AppendFormat("PRIMARY KEY NONCLUSTERED ([{0}])", keyColumnName); + sb.AppendLine(); sb.Append(')'); // end the table definition - command.CommandText = sb.ToString(); - return command; + + return connection.CreateCommand(sb); static (string sqlName, bool isNullable) Map(Type type) => type switch { @@ -86,10 +72,11 @@ internal static SqlCommand CreateTable( internal static SqlCommand DropTable(SqlConnection connection, string schema, string tableName) { - SqlCommand command = connection.CreateCommand(); - string fullTableName = GetSanitizedFullTableName(schema, tableName); - command.CommandText = $"DROP TABLE IF EXISTS {fullTableName}"; - return command; + StringBuilder sb = new(50); + sb.Append("DROP TABLE IF EXISTS "); + sb.AppendTableName(schema, tableName); + + return connection.CreateCommand(sb); } internal static SqlCommand SelectTableName(SqlConnection connection, string schema, string tableName) @@ -117,20 +104,18 @@ internal static SqlCommand InsertInto( Dictionary record) { SqlCommand command = connection.CreateCommand(); - string fullTableName = GetSanitizedFullTableName(options.Schema, tableName); + StringBuilder sb = new(200); - sb.AppendFormat("INSERT INTO {0} (", fullTableName); - // Use square brackets to escape column names. 
- foreach (VectorStoreRecordProperty property in dataProperties.Concat(vectorProperties)) - { - sb.AppendFormat("[{0}],", GetColumnName(property)); - } - sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis - sb.AppendLine(); + sb.Append("INSERT INTO "); + sb.AppendTableName(options.Schema, tableName); + sb.Append(" ("); + var nonKeyProperties = dataProperties.Concat(vectorProperties); + sb.AppendColumnNames(nonKeyProperties); + sb.AppendLine(")"); sb.AppendFormat("OUTPUT inserted.[{0}]", GetColumnName(keyProperty)); sb.AppendLine(); sb.Append("VALUES ("); - foreach (VectorStoreRecordProperty property in dataProperties.Concat(vectorProperties)) + foreach (VectorStoreRecordProperty property in nonKeyProperties) { int index = sb.Length; sb.AppendFormat("@{0},", GetColumnName(property)); @@ -149,17 +134,16 @@ internal static SqlCommand MergeInto( SqlServerVectorStoreOptions options, string tableName, VectorStoreRecordKeyProperty keyProperty, - IReadOnlyList dataProperties, - IReadOnlyList vectorProperties, + IReadOnlyList properties, Dictionary record) { SqlCommand command = connection.CreateCommand(); - string fullTableName = GetSanitizedFullTableName(options.Schema, tableName); StringBuilder sb = new(200); - sb.AppendFormat("MERGE INTO {0} AS t", fullTableName).AppendLine(); + sb.Append("MERGE INTO "); + sb.AppendTableName(options.Schema, tableName); + sb.AppendLine(" AS t"); sb.Append("USING (VALUES ("); - var allProperties = new VectorStoreRecordProperty[] { keyProperty }.Concat(dataProperties).Concat(vectorProperties); - foreach (VectorStoreRecordProperty property in allProperties) + foreach (VectorStoreRecordProperty property in properties) { int index = sb.Length; sb.AppendFormat("@{0},", GetColumnName(property)); @@ -168,37 +152,28 @@ internal static SqlCommand MergeInto( } sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis sb.AppendFormat(") AS s ("); - foreach (VectorStoreRecordProperty property in 
allProperties) - { - sb.AppendFormat("[{0}],", GetColumnName(property)); - } - sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis - sb.AppendLine(); + sb.AppendColumnNames(properties); + sb.AppendLine(")"); sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine(); sb.AppendLine("WHEN MATCHED THEN"); sb.Append("UPDATE SET "); - foreach (VectorStoreRecordProperty property in dataProperties.Concat(vectorProperties)) + foreach (VectorStoreRecordProperty property in properties) { - sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property)); + if (property != keyProperty) // don't update the key + { + sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property)); + } } --sb.Length; // remove the last comma sb.AppendLine(); sb.Append("WHEN NOT MATCHED THEN"); sb.AppendLine(); sb.Append("INSERT ("); - foreach (VectorStoreRecordProperty property in allProperties) - { - sb.AppendFormat("[{0}],", GetColumnName(property)); - } - sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis - sb.AppendLine(); + sb.AppendColumnNames(properties); + sb.AppendLine(")"); sb.Append("VALUES ("); - foreach (VectorStoreRecordProperty property in allProperties) - { - sb.AppendFormat("s.[{0}],", GetColumnName(property)); - } - sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis - sb.Append(';'); + sb.AppendColumnNames(properties, prefix: "s."); + sb.Append(");"); command.CommandText = sb.ToString(); return command; @@ -209,15 +184,16 @@ internal static SqlCommand DeleteSingle( VectorStoreRecordKeyProperty keyProperty, object key) { SqlCommand command = connection.CreateCommand(); - string fullTableName = GetSanitizedFullTableName(schema, tableName); + string keyParamName = $"@{GetColumnName(keyProperty)}"; - command.CommandText = - $"""" - DELETE - FROM {fullTableName} - WHERE [{GetColumnName(keyProperty)}] = {keyParamName} - """"; command.Parameters.AddWithValue(keyParamName, key); + + 
StringBuilder sb = new(100); + sb.Append("DELETE FROM "); + sb.AppendTableName(schema, tableName); + sb.AppendFormat(" WHERE [{0}] = {1}", GetColumnName(keyProperty), keyParamName); + + command.CommandText = sb.ToString(); return command; } @@ -226,16 +202,15 @@ internal static SqlCommand DeleteMany( VectorStoreRecordKeyProperty keyProperty, IEnumerable keys) { SqlCommand command = connection.CreateCommand(); - string fullTableName = GetSanitizedFullTableName(schema, tableName); - StringBuilder keyParams = CreateKeyParameterList(keys, command); - command.CommandText = - $"""" - DELETE - FROM {fullTableName} - WHERE [{GetColumnName(keyProperty)}] IN ({keyParams}) - """"; + StringBuilder sb = new(100); + sb.Append("DELETE FROM "); + sb.AppendTableName(schema, tableName); + sb.AppendFormat(" WHERE [{0}] IN (", GetColumnName(keyProperty)); + sb.AppendKeyParameterList(keys, command); + sb.Append(')'); // close the IN clause + command.CommandText = sb.ToString(); return command; } @@ -246,19 +221,20 @@ internal static SqlCommand SelectSingle( object key) { SqlCommand command = sqlConnection.CreateCommand(); - string fullTableName = GetSanitizedFullTableName(schema, collectionName); + string keyParamName = $"@{GetColumnName(keyProperty)}"; command.Parameters.AddWithValue(keyParamName, key); StringBuilder sb = new(200); sb.AppendFormat("SELECT "); - AppendColumnNames(properties, sb); + sb.AppendColumnNames(properties); sb.AppendLine(); - sb.AppendFormat("FROM {0}", fullTableName); + sb.Append("FROM "); + sb.AppendTableName(schema, collectionName); sb.AppendLine(); sb.AppendFormat("WHERE [{0}] = {1}", GetColumnName(keyProperty), keyParamName); - command.CommandText = sb.ToString(); + command.CommandText = sb.ToString(); return command; } @@ -269,60 +245,100 @@ internal static SqlCommand SelectMany( IEnumerable keys) { SqlCommand command = connection.CreateCommand(); - string fullTableName = GetSanitizedFullTableName(schema, tableName); - StringBuilder keyParams = 
CreateKeyParameterList(keys, command); StringBuilder sb = new(200); sb.AppendFormat("SELECT "); - AppendColumnNames(properties, sb); + sb.AppendColumnNames(properties); sb.AppendLine(); - sb.AppendFormat("FROM {0}", fullTableName); + sb.Append("FROM "); + sb.AppendTableName(schema, tableName); sb.AppendLine(); - sb.AppendFormat("WHERE [{0}] IN ({1})", GetColumnName(keyProperty), keyParams); + sb.AppendFormat("WHERE [{0}] IN (", GetColumnName(keyProperty)); + sb.AppendKeyParameterList(keys, command); + sb.Append(')'); // close the IN clause command.CommandText = sb.ToString(); - return command; } - private static void AppendColumnNames(IReadOnlyList properties, StringBuilder sb) + internal static string GetColumnName(VectorStoreRecordProperty property) + => property.StoragePropertyName ?? property.DataModelPropertyName; + + // If possible, prefer AppendTableName over this method (it's exposed only for testing purposes). + internal static string GetSanitizedFullTableName(string schema, string tableName) + => new StringBuilder(1 + schema.Length + 3 + tableName.Length + 1) // [schema].[table] + .AppendTableName(schema, tableName) + .ToString(); + + private static StringBuilder AppendTableName(this StringBuilder sb, string schema, string tableName) { + // If the column name contains a ], then escape it by doubling it. + // "Name with [brackets]" becomes [Name with [brackets]]]. + + sb.Append('['); + int index = sb.Length; // store the index, so we replace ] only for schema + sb.Append(schema); + sb.Replace("]", "]]", index, schema.Length); // replace the ] for schema + sb.Append("].["); + index = sb.Length; + sb.Append(tableName); + sb.Replace("]", "]]", index, tableName.Length); + sb.Append(']'); + + return sb; + } + + private static StringBuilder AppendColumnNames(this StringBuilder sb, + IEnumerable properties, + string? 
prefix = null) + { + bool any = false; foreach (VectorStoreRecordProperty property in properties) { + if (prefix is not null) + { + sb.Append(prefix); + } sb.AppendFormat("[{0}],", GetColumnName(property)); + any = true; } - if (properties.Count > 0) + if (any) { --sb.Length; // remove the last comma } + + return sb; } - private static StringBuilder CreateKeyParameterList(IEnumerable keys, SqlCommand command) + private static StringBuilder AppendKeyParameterList(this StringBuilder sb, IEnumerable keys, SqlCommand command) { - StringBuilder keyParams = new(); int keyIndex = 0; foreach (TKey key in keys) { // The caller ensures that keys collection is not null. // We need to ensure that none of the keys is null. Verify.NotNull(key); - int index = keyParams.Length; - keyParams.AppendFormat("@k{0},", keyIndex++); - string keyParam = keyParams.ToString(index, keyParams.Length - index - 1); // 1 is for the comma + int index = sb.Length; + sb.AppendFormat("@k{0},", keyIndex++); + string keyParam = sb.ToString(index, sb.Length - index - 1); // 1 is for the comma command.Parameters.AddWithValue(keyParam, key); } - if (keyParams.Length == 0) + if (keyIndex == 0) { // TODO adsitnik clarify: should we throw or simply do nothing? throw new ArgumentException("The value cannot be empty.", nameof(keys)); } - keyParams.Length--; // remove the last comma - return keyParams; + sb.Length--; // remove the last comma + return sb; } - internal static string GetColumnName(VectorStoreRecordProperty property) - => property.StoragePropertyName ?? 
property.DataModelPropertyName; + private static SqlCommand CreateCommand(this SqlConnection connection, StringBuilder sb) + { + SqlCommand command = connection.CreateCommand(); + command.CommandText = sb.ToString(); + return command; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 06281ade072e..d3430500a6a7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -176,8 +176,7 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancellati this._options, this.CollectionName, this._propertyReader.KeyProperty, - this._propertyReader.DataProperties, - this._propertyReader.VectorProperties, + this._propertyReader.Properties, map); if (key is not null) diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 620a538b33f9..9a700fee79ab 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -151,13 +151,11 @@ public void MergeInto() Schema = "schema" }; VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); - VectorStoreRecordDataProperty[] dataProperties = + VectorStoreRecordProperty[] properties = [ + keyProperty, new VectorStoreRecordDataProperty("simpleString", typeof(string)), - new VectorStoreRecordDataProperty("simpleInt", typeof(int)) - ]; - VectorStoreRecordVectorProperty[] vectorProperties = - [ + new VectorStoreRecordDataProperty("simpleInt", typeof(int)), new VectorStoreRecordVectorProperty("embedding1", 
typeof(ReadOnlyMemory)) { Dimensions = 10 @@ -166,7 +164,7 @@ public void MergeInto() using SqlConnection connection = CreateConnection(); using SqlCommand command = SqlServerCommandBuilder.MergeInto(connection, options, "table", - keyProperty, dataProperties, vectorProperties, + keyProperty, properties, new Dictionary { { "id", null }, @@ -207,12 +205,7 @@ public void DeleteSingle() using SqlCommand command = SqlServerCommandBuilder.DeleteSingle(connection, "schema", "tableName", keyProperty, 123L); - Assert.Equal( - """"" - DELETE - FROM [schema].[tableName] - WHERE [id] = @id - """"", command.CommandText); + Assert.Equal("DELETE FROM [schema].[tableName] WHERE [id] = @id", command.CommandText); Assert.Equal(123L, command.Parameters[0].Value); Assert.Equal("@id", command.Parameters[0].ParameterName); } @@ -227,12 +220,7 @@ public void DeleteMany() using SqlCommand command = SqlServerCommandBuilder.DeleteMany(connection, "schema", "tableName", keyProperty, keys); - Assert.Equal( - """"" - DELETE - FROM [schema].[tableName] - WHERE [id] IN (@k0,@k1) - """"", command.CommandText); + Assert.Equal("DELETE FROM [schema].[tableName] WHERE [id] IN (@k0,@k1)", command.CommandText); for (int i = 0; i < keys.Length; i++) { Assert.Equal(keys[i], command.Parameters[i].Value); From 8d71b11e21b30a04d522828a1fb91e48f026802f Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Tue, 18 Feb 2025 15:10:45 +0100 Subject: [PATCH 10/32] implement UpsertBatchAsync --- .../SqlServerCommandBuilder.cs | 120 +++++++++++++---- .../SqlServerVectorStoreRecordCollection.cs | 58 ++++++--- .../SqlServerCommandBuilderTests.cs | 95 ++++++++++++-- .../SqlServerVectorStoreTests.cs | 123 +++++++++++++----- 4 files changed, 313 insertions(+), 83 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 5974b0364f75..1d056ad8c7fa 100644 --- 
a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -52,22 +52,6 @@ internal static SqlCommand CreateTable( sb.Append(')'); // end the table definition return connection.CreateCommand(sb); - - static (string sqlName, bool isNullable) Map(Type type) => type switch - { - Type t when t == typeof(int) => ("INT", false), - Type t when t == typeof(long) => ("BIGINT", false), - Type t when t == typeof(Guid) => ("UNIQUEIDENTIFIER", false), - Type t when t == typeof(string) => ("NVARCHAR(255) COLLATE Latin1_General_100_BIN2", true), - Type t when t == typeof(byte[]) => ("VARBINARY(MAX)", true), - Type t when t == typeof(bool) => ("BIT", false), - Type t when t == typeof(DateTime) => ("DATETIME", false), - Type t when t == typeof(TimeSpan) => ("TIME", false), - Type t when t == typeof(decimal) => ("DECIMAL", false), - Type t when t == typeof(double) => ("FLOAT", false), - Type t when t == typeof(float) => ("REAL", false), - _ => throw new NotSupportedException($"Type {type} is not supported.") - }; } internal static SqlCommand DropTable(SqlConnection connection, string schema, string tableName) @@ -129,7 +113,7 @@ internal static SqlCommand InsertInto( return command; } - internal static SqlCommand MergeInto( + internal static SqlCommand MergeIntoSingle( SqlConnection connection, SqlServerVectorStoreOptions options, string tableName, @@ -151,7 +135,7 @@ internal static SqlCommand MergeInto( command.Parameters.AddWithValue(paramName, record[property.DataModelPropertyName] ?? 
(object)DBNull.Value); } sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis - sb.AppendFormat(") AS s ("); + sb.Append(") AS s ("); sb.AppendColumnNames(properties); sb.AppendLine(")"); sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine(); @@ -179,6 +163,82 @@ internal static SqlCommand MergeInto( return command; } + internal static SqlCommand MergeIntoMany( + SqlConnection connection, + SqlServerVectorStoreOptions options, + string tableName, + VectorStoreRecordKeyProperty keyProperty, + IReadOnlyList properties, + IEnumerable> records) + { + SqlCommand command = connection.CreateCommand(); + + StringBuilder sb = new(200); + // The DECLARE statement creates a table variable to store the keys of the inserted rows. + sb.AppendFormat("DECLARE @InsertedKeys TABLE (KeyColumn {0});", Map(keyProperty.PropertyType).sqlName); + sb.AppendLine(); + // The MERGE statement performs the upsert operation and outputs the keys of the inserted rows into the table variable. + sb.Append("MERGE INTO "); + sb.AppendTableName(options.Schema, tableName); + sb.AppendLine(" AS t"); // t stands for target + sb.AppendLine("USING (VALUES"); + int rowIndex = 0; + foreach (var record in records) + { + sb.Append('('); + foreach (VectorStoreRecordProperty property in properties) + { + int index = sb.Length; + sb.AppendFormat("@{0}_{1},", GetColumnName(property), rowIndex); + string paramName = sb.ToString(index, sb.Length - index - 1); // 1 is for the comma + command.Parameters.AddWithValue(paramName, record[property.DataModelPropertyName] ?? (object)DBNull.Value); + } + sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis + sb.AppendLine(","); + rowIndex++; + } + + if (rowIndex == 0) + { + // TODO adsitnik clarify: should we throw or simply do nothing? 
+ throw new ArgumentException("The value cannot be empty.", nameof(records)); + } + + sb.Length -= (1 + Environment.NewLine.Length); // remove the last comma and newline + + sb.Append(") AS s ("); // s stands for source + sb.AppendColumnNames(properties); + sb.AppendLine(")"); + sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine(); + sb.AppendLine("WHEN MATCHED THEN"); + sb.Append("UPDATE SET "); + foreach (VectorStoreRecordProperty property in properties) + { + if (property != keyProperty) // don't update the key + { + sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property)); + } + } + --sb.Length; // remove the last comma + sb.AppendLine(); + sb.Append("WHEN NOT MATCHED THEN"); + sb.AppendLine(); + sb.Append("INSERT ("); + sb.AppendColumnNames(properties); + sb.AppendLine(")"); + sb.Append("VALUES ("); + sb.AppendColumnNames(properties, prefix: "s."); + sb.AppendLine(")"); + sb.AppendFormat("OUTPUT inserted.[{0}] INTO @InsertedKeys (KeyColumn);", GetColumnName(keyProperty)); + sb.AppendLine(); + + // The SELECT statement returns the keys of the inserted rows. + sb.Append("SELECT KeyColumn FROM @InsertedKeys;"); + + command.CommandText = sb.ToString(); + return command; + } + internal static SqlCommand DeleteSingle( SqlConnection connection, string schema, string tableName, VectorStoreRecordKeyProperty keyProperty, object key) @@ -264,13 +324,7 @@ internal static SqlCommand SelectMany( internal static string GetColumnName(VectorStoreRecordProperty property) => property.StoragePropertyName ?? property.DataModelPropertyName; - // If possible, prefer AppendTableName over this method (it's exposed only for testing purposes). 
- internal static string GetSanitizedFullTableName(string schema, string tableName) - => new StringBuilder(1 + schema.Length + 3 + tableName.Length + 1) // [schema].[table] - .AppendTableName(schema, tableName) - .ToString(); - - private static StringBuilder AppendTableName(this StringBuilder sb, string schema, string tableName) + internal static StringBuilder AppendTableName(this StringBuilder sb, string schema, string tableName) { // If the column name contains a ], then escape it by doubling it. // "Name with [brackets]" becomes [Name with [brackets]]]. @@ -341,4 +395,20 @@ private static SqlCommand CreateCommand(this SqlConnection connection, StringBui command.CommandText = sb.ToString(); return command; } + + private static (string sqlName, bool isNullable) Map(Type type) => type switch + { + Type t when t == typeof(int) => ("INT", false), + Type t when t == typeof(long) => ("BIGINT", false), + Type t when t == typeof(Guid) => ("UNIQUEIDENTIFIER", false), + Type t when t == typeof(string) => ("NVARCHAR(255) COLLATE Latin1_General_100_BIN2", true), + Type t when t == typeof(byte[]) => ("VARBINARY(MAX)", true), + Type t when t == typeof(bool) => ("BIT", false), + Type t when t == typeof(DateTime) => ("DATETIME", false), + Type t when t == typeof(TimeSpan) => ("TIME", false), + Type t when t == typeof(decimal) => ("DECIMAL", false), + Type t when t == typeof(double) => ("FLOAT", false), + Type t when t == typeof(float) => ("REAL", false), + _ => throw new NotSupportedException($"Type {type} is not supported.") + }; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index d3430500a6a7..17df10995f8c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -2,6 +2,8 @@ using 
System; using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -31,8 +33,10 @@ public async Task CollectionExistsAsync(CancellationToken cancellationToke { await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.SelectTableName(this._sqlConnection, this._options.Schema, this.CollectionName); + using SqlCommand command = SqlServerCommandBuilder.SelectTableName( + this._sqlConnection, this._options.Schema, this.CollectionName); using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + return await reader.ReadAsync(cancellationToken).ConfigureAwait(false); } @@ -42,7 +46,7 @@ public Task CreateCollectionAsync(CancellationToken cancellationToken = default) // TODO adsitnik: design: We typically don't provide such methods in BCL. // 1. I totally see why we want to provide it, we just need to make sure it's the right thing to do. // 2. An alternative would be to make CreateCollectionAsync a nop when the collection already exists - // or extend it with an optional boolan parameter that would control the behavior. + // or extend it with an optional boolean parameter that would control the behavior. 
public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) => this.CreateCollectionAsync(ifNotExists: true, cancellationToken); @@ -66,16 +70,12 @@ public async Task DeleteCollectionAsync(CancellationToken cancellationToken = de { await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.DropTable(this._sqlConnection, this._options.Schema, this.CollectionName); + using SqlCommand command = SqlServerCommandBuilder.DropTable( + this._sqlConnection, this._options.Schema, this.CollectionName); await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } - private Task EnsureConnectionIsOpenedAsync(CancellationToken cancellationToken) - => this._sqlConnection.State == System.Data.ConnectionState.Open - ? Task.CompletedTask - : this._sqlConnection.OpenAsync(cancellationToken); - public async Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { Verify.NotNull(key); @@ -122,15 +122,15 @@ public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken can this._propertyReader.Properties, key); - await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + return await reader.ReadAsync(cancellationToken).ConfigureAwait(false) ? Map(reader, this._propertyReader) : default; } - public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? 
options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) { Verify.NotNull(keys); @@ -144,8 +144,6 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get this._propertyReader.Properties, keys); - await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) { @@ -160,7 +158,7 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancellati await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); TKey? key = (TKey)this._propertyReader.KeyPropertyInfo.GetValue(record); - Dictionary map = Map(record, this._propertyReader, key); + Dictionary map = Map(record, this._propertyReader); using SqlCommand command = key is null // When the key is null, we are inserting a new record. ? SqlServerCommandBuilder.InsertInto( @@ -171,7 +169,7 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancellati this._propertyReader.DataProperties, this._propertyReader.VectorProperties, map) - : SqlServerCommandBuilder.MergeInto( + : SqlServerCommandBuilder.MergeIntoSingle( this._sqlConnection, this._options, this.CollectionName, @@ -194,9 +192,26 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancellati return await sqlDataReader.GetFieldValueAsync(0, cancellationToken).ConfigureAwait(false); } - public IAsyncEnumerable UpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken = default) + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, + [EnumeratorCancellation] CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + Verify.NotNull(records); + + await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); + + using SqlCommand command = SqlServerCommandBuilder.MergeIntoMany( + 
this._sqlConnection, + this._options, + this.CollectionName, + this._propertyReader.KeyProperty, + this._propertyReader.Properties, + records.Select(record => Map(record, this._propertyReader))); + + using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return reader.GetFieldValue(0); + } } public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) @@ -204,10 +219,15 @@ public Task> VectorizedSearchAsync(TVector throw new NotImplementedException(); } - private static Dictionary Map(TRecord record, VectorStoreRecordPropertyReader propertyReader, TKey key) + private Task EnsureConnectionIsOpenedAsync(CancellationToken cancellationToken) + => this._sqlConnection.State == System.Data.ConnectionState.Open + ? Task.CompletedTask + : this._sqlConnection.OpenAsync(cancellationToken); + + private static Dictionary Map(TRecord record, VectorStoreRecordPropertyReader propertyReader) { Dictionary map = new(StringComparer.Ordinal); - map[propertyReader.KeyProperty.DataModelPropertyName] = key; + map[propertyReader.KeyProperty.DataModelPropertyName] = propertyReader.KeyPropertyInfo.GetValue(record); var dataProperties = propertyReader.DataProperties; for (int i = 0; i < dataProperties.Count; i++) diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 9a700fee79ab..bc1881600b4c 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -1,4 +1,5 @@ -using Microsoft.Data.SqlClient; +using System.Text; +using Microsoft.Data.SqlClient; using 
Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.SqlServer; using Xunit; @@ -11,10 +12,13 @@ public class SqlServerCommandBuilderTests [InlineData("schema", "name", "[schema].[name]")] [InlineData("schema", "[brackets]", "[schema].[[brackets]]]")] [InlineData("needs]escaping", "[brackets]", "[needs]]escaping].[[brackets]]]")] - public void GetSanitizedFullTableName(string schema, string table, string expectedFullName) + public void AppendTableName(string schema, string table, string expectedFullName) { - string result = SqlServerCommandBuilder.GetSanitizedFullTableName(schema, table); - Assert.Equal(expectedFullName, result); + StringBuilder result = new(); + + SqlServerCommandBuilder.AppendTableName(result, schema, table); + + Assert.Equal(expectedFullName, result.ToString()); } [Theory] @@ -23,6 +27,7 @@ public void GetSanitizedFullTableName(string schema, string table, string expect public void DropTable(string schema, string table, string expectedTable) { using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.DropTable(connection, schema, table); Assert.Equal($"DROP TABLE IF EXISTS [{schema}].{expectedTable}", command.CommandText); @@ -34,6 +39,7 @@ public void DropTable(string schema, string table, string expectedTable) public void SelectTableName(string schema, string table) { using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.SelectTableName(connection, schema, table); Assert.Equal( @@ -45,7 +51,6 @@ FROM INFORMATION_SCHEMA.TABLES AND TABLE_NAME = @tableName """ , command.CommandText); - Assert.Equal(schema, command.Parameters[0].Value); Assert.Equal(table, command.Parameters[1].Value); } @@ -72,8 +77,8 @@ public void CreateTable(bool ifNotExists) Dimensions = 10 } ]; - using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.CreateTable(connection, options, "table", ifNotExists, 
keyProperty, dataProperties, vectorProperties); @@ -115,8 +120,8 @@ public void InsertInto() Dimensions = 10 } ]; - using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.InsertInto(connection, options, "table", keyProperty, dataProperties, vectorProperties, new Dictionary @@ -144,7 +149,7 @@ OUTPUT inserted.[id] } [Fact] - public void MergeInto() + public void MergeIntoSingle() { SqlServerVectorStoreOptions options = new() { @@ -163,7 +168,7 @@ public void MergeInto() ]; using SqlConnection connection = CreateConnection(); - using SqlCommand command = SqlServerCommandBuilder.MergeInto(connection, options, "table", + using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle(connection, options, "table", keyProperty, properties, new Dictionary { @@ -196,6 +201,78 @@ WHEN NOT MATCHED THEN Assert.Equal("{ 10.0 }", command.Parameters[3].Value); } + [Fact] + public void MergeIntoMany() + { + SqlServerVectorStoreOptions options = new() + { + Schema = "schema" + }; + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordProperty[] properties = + [ + keyProperty, + new VectorStoreRecordDataProperty("simpleString", typeof(string)), + new VectorStoreRecordDataProperty("simpleInt", typeof(int)), + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) + { + Dimensions = 10 + } + ]; + Dictionary[] records = + [ + new Dictionary + { + { "id", 0L }, + { "simpleString", "nameValue0" }, + { "simpleInt", 134 }, + { "embedding", "{ 10.0 }" } + }, + new Dictionary + { + { "id", 1L }, + { "simpleString", "nameValue1" }, + { "simpleInt", 135 }, + { "embedding", "{ 11.0 }" } + } + ]; + + using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.MergeIntoMany(connection, options, "table", + keyProperty, properties, records); + + string expectedCommand = + """" + DECLARE @InsertedKeys TABLE (KeyColumn BIGINT); + MERGE INTO 
[schema].[table] AS t + USING (VALUES + (@id_0,@simpleString_0,@simpleInt_0,@embedding_0), + (@id_1,@simpleString_1,@simpleInt_1,@embedding_1)) AS s ([id],[simpleString],[simpleInt],[embedding]) + ON (t.[id] = s.[id]) + WHEN MATCHED THEN + UPDATE SET t.[simpleString] = s.[simpleString],t.[simpleInt] = s.[simpleInt],t.[embedding] = s.[embedding] + WHEN NOT MATCHED THEN + INSERT ([id],[simpleString],[simpleInt],[embedding]) + VALUES (s.[id],s.[simpleString],s.[simpleInt],s.[embedding]) + OUTPUT inserted.[id] INTO @InsertedKeys (KeyColumn); + SELECT KeyColumn FROM @InsertedKeys; + """"; + + Assert.Equal(expectedCommand, command.CommandText); + + for (int i = 0; i < records.Length; i++) + { + Assert.Equal($"@id_{i}", command.Parameters[4 * i].ParameterName); + Assert.Equal((long)i, command.Parameters[4 * i].Value); + Assert.Equal($"@simpleString_{i}", command.Parameters[4 * i + 1].ParameterName); + Assert.Equal($"nameValue{i}", command.Parameters[4 * i + 1].Value); + Assert.Equal($"@simpleInt_{i}", command.Parameters[4 * i + 2].ParameterName); + Assert.Equal(134 + i, command.Parameters[4 * i + 2].Value); + Assert.Equal($"@embedding_{i}", command.Parameters[4 * i + 3].ParameterName); + Assert.Equal($"{{ 1{i}.0 }}", command.Parameters[4 * i + 3].Value); + } + } + [Fact] public void DeleteSingle() { diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index abb33e260e19..bbe0ff0cb27a 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -40,10 +40,8 @@ public async Task CanCreateAndDeleteTheCollections() } } - [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task CanInsertAndDeleteRecord(bool deleteBatch) + [Fact] + public async Task RecordCRUD() { 
SqlServerTestStore testStore = new(); @@ -55,40 +53,96 @@ public async Task CanInsertAndDeleteRecord(bool deleteBatch) { await collection.CreateCollectionIfNotExistsAsync(); - ReadOnlyMemory floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray(); - string key = await collection.UpsertAsync(new TestModel() + TestModel inserted = new() { Id = "MyId", Number = 100, - Floats = floats - }); - Assert.Equal("MyId", key); - - TestModel? record = await collection.GetAsync("MyId"); - Assert.NotNull(record); - Assert.Equal(100, record.Number); - Assert.Equal("MyId", record.Id); - Assert.Equal(floats, record.Floats); - Assert.Null(record.Text); - - record = await collection.GetBatchAsync(["MyId"]).SingleAsync(); - Assert.NotNull(record); - Assert.Equal(100, record.Number); - Assert.Equal("MyId", record.Id); - Assert.Equal(floats, record.Floats); - Assert.Null(record.Text); - - if (deleteBatch) + Floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray() + }; + string key = await collection.UpsertAsync(inserted); + Assert.Equal(inserted.Id, key); + + TestModel? 
received = await collection.GetAsync(inserted.Id); + AssertEquality(inserted, received); + + TestModel updated = new() { - await collection.DeleteBatchAsync(["MyId"]); + Id = inserted.Id, + Number = inserted.Number + 200, // change one property + Floats = inserted.Floats + }; + key = await collection.UpsertAsync(updated); + Assert.Equal(inserted.Id, key); + + received = await collection.GetAsync(updated.Id); + AssertEquality(updated, received); + + await collection.DeleteAsync(inserted.Id); + + Assert.Null(await collection.GetAsync(inserted.Id)); + } + finally + { + await collection.DeleteCollectionAsync(); + + await testStore.ReferenceCountingStopAsync(); + } + } + + [Fact] + public async Task BatchCRUD() + { + SqlServerTestStore testStore = new(); + + await testStore.ReferenceCountingStartAsync(); + + var collection = testStore.DefaultVectorStore.GetCollection("other"); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + TestModel[] inserted = Enumerable.Range(0, 10).Select(i => new TestModel() + { + Id = $"MyId{i}", + Number = 100 + i, + Floats = Enumerable.Range(0, 10).Select(j => (float)(i + j)).ToArray() + }).ToArray(); + + string[] keys = await collection.UpsertBatchAsync(inserted).ToArrayAsync(); + for (int i = 0; i < inserted.Length; i++) + { + Assert.Equal(inserted[i].Id, keys[i]); } - else + + TestModel[] received = await collection.GetBatchAsync(keys).ToArrayAsync(); + for (int i = 0; i < inserted.Length; i++) { - await collection.DeleteAsync("MyId"); + AssertEquality(inserted[i], received[i]); } - Assert.Null(await collection.GetAsync("MyId")); - Assert.False(await collection.GetBatchAsync(["MyId"]).AnyAsync()); + TestModel[] updated = inserted.Select(i => new TestModel() + { + Id = i.Id, + Number = i.Number + 200, // change one property + Floats = i.Floats + }).ToArray(); + + keys = await collection.UpsertBatchAsync(updated).ToArrayAsync(); + for (int i = 0; i < updated.Length; i++) + { + Assert.Equal(updated[i].Id, keys[i]); + } + 
+ received = await collection.GetBatchAsync(keys).ToArrayAsync(); + for (int i = 0; i < updated.Length; i++) + { + AssertEquality(updated[i], received[i]); + } + + await collection.DeleteBatchAsync(keys); + + Assert.False(await collection.GetBatchAsync(keys).AnyAsync()); } finally { @@ -98,6 +152,15 @@ record = await collection.GetBatchAsync(["MyId"]).SingleAsync(); } } + private static void AssertEquality(TestModel inserted, TestModel? received) + { + Assert.NotNull(received); + Assert.Equal(inserted.Number, received.Number); + Assert.Equal(inserted.Id, received.Id); + Assert.Equal(inserted.Floats, received.Floats); + Assert.Null(received.Text); // testing DBNull code path + } + public sealed class TestModel { [VectorStoreRecordKey(StoragePropertyName = "key")] From 7f18352ed3336735bc78c2c711f2f94d55eafe48 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 19 Feb 2025 09:45:52 +0100 Subject: [PATCH 11/32] implement SelectTableNames, read the code again and add TODOs for things that need to be addressed --- .../SqlServerClient.cs | 11 ++-------- .../SqlServerCommandBuilder.cs | 16 +++++++++++++- .../SqlServerVectorStore.cs | 10 +++++++-- .../SqlServerVectorStoreRecordCollection.cs | 21 +++++++----------- .../SqlServerCommandBuilderTests.cs | 22 ++++++++++++++++++- .../SqlServerVectorStoreTests.cs | 12 ++++++---- 6 files changed, 62 insertions(+), 30 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs index f8a90eb75873..42f65ef55cf8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs @@ -80,14 +80,7 @@ public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] Ca { using (await this.OpenConnectionAsync(cancellationToken).ConfigureAwait(false)) { - using var cmd = this._connection.CreateCommand(); - cmd.CommandText = """ - SELECT 
TABLE_NAME - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_TYPE = 'BASE TABLE' - AND TABLE_SCHEMA = @schema - """; - cmd.Parameters.AddWithValue("@schema", this._schema); + using var cmd = SqlServerCommandBuilder.SelectTableNames(this._connection, this._schema); using var reader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) { @@ -112,7 +105,7 @@ public async Task DeleteTableAsync(string tableName, CancellationToken cancellat { using (await this.OpenConnectionAsync(cancellationToken).ConfigureAwait(false)) { - using var cmd = SqlServerCommandBuilder.DropTable(this._connection, this._schema, tableName); + using var cmd = SqlServerCommandBuilder.DropTableIfExists(this._connection, this._schema, tableName); await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 1d056ad8c7fa..25187c0eb8fb 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -44,6 +44,7 @@ internal static SqlCommand CreateTable( } for (int i = 0; i < vectorProperties.Count; i++) { + // TODO adsitnik design: should we require Dimensions to be always provided in explicit way or use some default? 
sb.AppendFormat("[{0}] VECTOR({1}),", GetColumnName(vectorProperties[i]), vectorProperties[i].Dimensions); sb.AppendLine(); } @@ -54,7 +55,7 @@ internal static SqlCommand CreateTable( return connection.CreateCommand(sb); } - internal static SqlCommand DropTable(SqlConnection connection, string schema, string tableName) + internal static SqlCommand DropTableIfExists(SqlConnection connection, string schema, string tableName) { StringBuilder sb = new(50); sb.Append("DROP TABLE IF EXISTS "); @@ -78,6 +79,19 @@ FROM INFORMATION_SCHEMA.TABLES return command; } + internal static SqlCommand SelectTableNames(SqlConnection connection, string schema) + { + SqlCommand command = connection.CreateCommand(); + command.CommandText = """ + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND TABLE_SCHEMA = @schema + """; + command.Parameters.AddWithValue("@schema", schema); + return command; + } + internal static SqlCommand InsertInto( SqlConnection connection, SqlServerVectorStoreOptions options, diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs index c88c01da8890..2b005a790228 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Runtime.CompilerServices; using System.Threading; using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; @@ -121,8 +122,13 @@ public IVectorStoreRecordCollection GetCollection( } /// - public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + using SqlCommand 
cmd = SqlServerCommandBuilder.SelectTableNames(this._connection, this._options.Schema); + using SqlDataReader reader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return reader.GetString(reader.GetOrdinal("table_name")); + } } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 17df10995f8c..6c12a3ab1ce2 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -70,7 +70,7 @@ public async Task DeleteCollectionAsync(CancellationToken cancellationToken = de { await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.DropTable( + using SqlCommand command = SqlServerCommandBuilder.DropTableIfExists( this._sqlConnection, this._options.Schema, this.CollectionName); await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); @@ -183,12 +183,7 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancellati return key; } - if (typeof(int) == typeof(TKey)) - { - return (TKey)(object)await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } - - SqlDataReader sqlDataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + using SqlDataReader sqlDataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); return await sqlDataReader.GetFieldValueAsync(0, cancellationToken).ConfigureAwait(false); } @@ -200,12 +195,12 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); using SqlCommand command = 
SqlServerCommandBuilder.MergeIntoMany( - this._sqlConnection, - this._options, - this.CollectionName, - this._propertyReader.KeyProperty, - this._propertyReader.Properties, - records.Select(record => Map(record, this._propertyReader))); + this._sqlConnection, + this._options, + this.CollectionName, + this._propertyReader.KeyProperty, + this._propertyReader.Properties, + records.Select(record => Map(record, this._propertyReader))); using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index bc1881600b4c..9de50e872141 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -28,7 +28,7 @@ public void DropTable(string schema, string table, string expectedTable) { using SqlConnection connection = CreateConnection(); - using SqlCommand command = SqlServerCommandBuilder.DropTable(connection, schema, table); + using SqlCommand command = SqlServerCommandBuilder.DropTableIfExists(connection, schema, table); Assert.Equal($"DROP TABLE IF EXISTS [{schema}].{expectedTable}", command.CommandText); } @@ -55,6 +55,26 @@ FROM INFORMATION_SCHEMA.TABLES Assert.Equal(table, command.Parameters[1].Value); } + [Fact] + public void SelectTableNames() + { + const string SchemaName = "theSchemaName"; + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.SelectTableNames(connection, SchemaName); + + Assert.Equal( + """ + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND TABLE_SCHEMA = @schema + """ + , 
command.CommandText); + Assert.Equal(SchemaName, command.Parameters[0].Value); + Assert.Equal("@schema", command.Parameters[0].ParameterName); + } + [Theory] [InlineData(true)] [InlineData(false)] diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index bbe0ff0cb27a..de219028b836 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -1,5 +1,4 @@ -using System.Linq; -using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData; using SqlServerIntegrationTests.Support; using Xunit; @@ -8,21 +7,25 @@ namespace SqlServerIntegrationTests; public class SqlServerVectorStoreTests { [Fact] - public async Task CanCreateAndDeleteTheCollections() + public async Task CollectionCRUD() { + const string CollectionName = "collection"; SqlServerTestStore testStore = new(); await testStore.ReferenceCountingStartAsync(); - var collection = testStore.DefaultVectorStore.GetCollection("collection"); + var collection = testStore.DefaultVectorStore.GetCollection(CollectionName); try { Assert.False(await collection.CollectionExistsAsync()); + Assert.False(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(CollectionName)); + await collection.CreateCollectionAsync(); Assert.True(await collection.CollectionExistsAsync()); + Assert.True(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(CollectionName)); await collection.CreateCollectionIfNotExistsAsync(); @@ -31,6 +34,7 @@ public async Task CanCreateAndDeleteTheCollections() await collection.DeleteCollectionAsync(); Assert.False(await collection.CollectionExistsAsync()); + Assert.False(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(CollectionName)); 
} finally { From 32605da37a9dad39114b0b75e9119ee649914ad5 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 19 Feb 2025 10:35:07 +0100 Subject: [PATCH 12/32] ensure that parameter names are always valid --- .../SqlServerCommandBuilder.cs | 83 +++++++++++------ .../SqlServerCommandBuilderTests.cs | 91 ++++++++++++------- 2 files changed, 112 insertions(+), 62 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 25187c0eb8fb..1617a3a4369a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -1,13 +1,11 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Net.NetworkInformation; using System.Text; using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; #pragma warning disable CA2100 // Review SQL queries for security vulnerabilities -#pragma warning disable CA1851 // Possible multiple enumerations of IEnumerable namespace Microsoft.SemanticKernel.Connectors.SqlServer; @@ -32,7 +30,6 @@ internal static SqlCommand CreateTable( sb.Append("CREATE TABLE "); sb.AppendTableName(options.Schema, tableName); sb.AppendLine(" ("); - // Use square brackets to escape column names. 
string keyColumnName = GetColumnName(keyProperty); sb.AppendFormat("[{0}] {1} NOT NULL,", keyColumnName, Map(keyProperty.PropertyType).sqlName); sb.AppendLine(); @@ -113,11 +110,10 @@ internal static SqlCommand InsertInto( sb.AppendFormat("OUTPUT inserted.[{0}]", GetColumnName(keyProperty)); sb.AppendLine(); sb.Append("VALUES ("); + int paramIndex = 0; foreach (VectorStoreRecordProperty property in nonKeyProperties) { - int index = sb.Length; - sb.AppendFormat("@{0},", GetColumnName(property)); - string paramName = sb.ToString(index, sb.Length - index - 1); // 1 is for the comma + sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); command.Parameters.AddWithValue(paramName, record[property.DataModelPropertyName] ?? (object)DBNull.Value); } sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis @@ -141,11 +137,10 @@ internal static SqlCommand MergeIntoSingle( sb.AppendTableName(options.Schema, tableName); sb.AppendLine(" AS t"); sb.Append("USING (VALUES ("); + int paramIndex = 0; foreach (VectorStoreRecordProperty property in properties) { - int index = sb.Length; - sb.AppendFormat("@{0},", GetColumnName(property)); - string paramName = sb.ToString(index, sb.Length - index - 1); // 1 is for the comma + sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); command.Parameters.AddWithValue(paramName, record[property.DataModelPropertyName] ?? 
(object)DBNull.Value); } sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis @@ -196,15 +191,13 @@ internal static SqlCommand MergeIntoMany( sb.AppendTableName(options.Schema, tableName); sb.AppendLine(" AS t"); // t stands for target sb.AppendLine("USING (VALUES"); - int rowIndex = 0; + int rowIndex = 0, paramIndex = 0; foreach (var record in records) { sb.Append('('); foreach (VectorStoreRecordProperty property in properties) { - int index = sb.Length; - sb.AppendFormat("@{0}_{1},", GetColumnName(property), rowIndex); - string paramName = sb.ToString(index, sb.Length - index - 1); // 1 is for the comma + sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); command.Parameters.AddWithValue(paramName, record[property.DataModelPropertyName] ?? (object)DBNull.Value); } sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis @@ -259,13 +252,13 @@ internal static SqlCommand DeleteSingle( { SqlCommand command = connection.CreateCommand(); - string keyParamName = $"@{GetColumnName(keyProperty)}"; - command.Parameters.AddWithValue(keyParamName, key); - + int paramIndex = 0; StringBuilder sb = new(100); sb.Append("DELETE FROM "); sb.AppendTableName(schema, tableName); - sb.AppendFormat(" WHERE [{0}] = {1}", GetColumnName(keyProperty), keyParamName); + sb.AppendFormat(" WHERE [{0}] = ", GetColumnName(keyProperty)); + sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName); + command.Parameters.AddWithValue(keyParamName, key); command.CommandText = sb.ToString(); return command; @@ -281,7 +274,7 @@ internal static SqlCommand DeleteMany( sb.Append("DELETE FROM "); sb.AppendTableName(schema, tableName); sb.AppendFormat(" WHERE [{0}] IN (", GetColumnName(keyProperty)); - sb.AppendKeyParameterList(keys, command); + sb.AppendKeyParameterList(keys, command, keyProperty); sb.Append(')'); // close the IN clause command.CommandText = sb.ToString(); @@ -296,9 +289,7 @@ internal static 
SqlCommand SelectSingle( { SqlCommand command = sqlConnection.CreateCommand(); - string keyParamName = $"@{GetColumnName(keyProperty)}"; - command.Parameters.AddWithValue(keyParamName, key); - + int paramIndex = 0; StringBuilder sb = new(200); sb.AppendFormat("SELECT "); sb.AppendColumnNames(properties); @@ -306,7 +297,9 @@ internal static SqlCommand SelectSingle( sb.Append("FROM "); sb.AppendTableName(schema, collectionName); sb.AppendLine(); - sb.AppendFormat("WHERE [{0}] = {1}", GetColumnName(keyProperty), keyParamName); + sb.AppendFormat("WHERE [{0}] = ", GetColumnName(keyProperty)); + sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName); + command.Parameters.AddWithValue(keyParamName, key); command.CommandText = sb.ToString(); return command; @@ -328,7 +321,7 @@ internal static SqlCommand SelectMany( sb.AppendTableName(schema, tableName); sb.AppendLine(); sb.AppendFormat("WHERE [{0}] IN (", GetColumnName(keyProperty)); - sb.AppendKeyParameterList(keys, command); + sb.AppendKeyParameterList(keys, command, keyProperty); sb.Append(')'); // close the IN clause command.CommandText = sb.ToString(); @@ -338,6 +331,38 @@ internal static SqlCommand SelectMany( internal static string GetColumnName(VectorStoreRecordProperty property) => property.StoragePropertyName ?? property.DataModelPropertyName; + internal static StringBuilder AppendParameterName(this StringBuilder sb, VectorStoreRecordProperty property, ref int paramIndex, out string parameterName) + { + // In SQL Server, parameter names cannot be just a number like "@1". + // Parameter names must start with an alphabetic character or an underscore + // and can be followed by alphanumeric characters or underscores. 
+ // Since we can't guarantee that the value returned by StoragePropertyName and DataModelPropertyName + // is valid parameter name (it can contain whitespaces, or start with a number), + // we just append the ASCII letters, stop on the first non-ASCII letter + // and append the index. + string columnName = GetColumnName(property); + int index = sb.Length; + sb.Append('@'); + foreach (char character in columnName) + { + // We don't call APIs like char.IsWhitespace as they are expensive + // as they need to handle all Unicode characters. + if (!((character is >= 'a' and <= 'z') || (character is >= 'A' and <= 'Z'))) + { + break; + } + sb.Append(character); + } + // In case the column name is empty or does not start with ASCII letters, + // we provide the underscore as a prefix (allowed). + sb.Append('_'); + // To ensure the generated parameter id is unique, we append the index. + sb.Append(paramIndex++); + parameterName = sb.ToString(index, sb.Length - index); + + return sb; + } + internal static StringBuilder AppendTableName(this StringBuilder sb, string schema, string tableName) { // If the column name contains a ], then escape it by doubling it. @@ -367,6 +392,7 @@ private static StringBuilder AppendColumnNames(this StringBuilder sb, { sb.Append(prefix); } + // Use square brackets to escape column names. sb.AppendFormat("[{0}],", GetColumnName(property)); any = true; } @@ -379,7 +405,8 @@ private static StringBuilder AppendColumnNames(this StringBuilder sb, return sb; } - private static StringBuilder AppendKeyParameterList(this StringBuilder sb, IEnumerable keys, SqlCommand command) + private static StringBuilder AppendKeyParameterList(this StringBuilder sb, + IEnumerable keys, SqlCommand command, VectorStoreRecordKeyProperty keyProperty) { int keyIndex = 0; foreach (TKey key in keys) @@ -387,10 +414,10 @@ private static StringBuilder AppendKeyParameterList(this StringBuilder sb, // The caller ensures that keys collection is not null. 
// We need to ensure that none of the keys is null. Verify.NotNull(key); - int index = sb.Length; - sb.AppendFormat("@k{0},", keyIndex++); - string keyParam = sb.ToString(index, sb.Length - index - 1); // 1 is for the comma - command.Parameters.AddWithValue(keyParam, key); + + sb.AppendParameterName(keyProperty, ref keyIndex, out string keyParamName); + sb.Append(','); + command.Parameters.AddWithValue(keyParamName, key); } if (keyIndex == 0) diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 9de50e872141..9663132bfa41 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -21,6 +21,29 @@ public void AppendTableName(string schema, string table, string expectedFullName Assert.Equal(expectedFullName, result.ToString()); } + [Theory] + [InlineData("name", "@name_")] // typical name + [InlineData("na me", "@na_")] // contains a whitespace, an illegal parameter name character + [InlineData("123", "@_")] // starts with a digit, also not allowed + [InlineData("ĄŻŚĆ_doesNotStartWithAscii", "@_")] // starts with a non-ASCII character + public void AppendParameterName(string propertyName, string expectedPrefix) + { + StringBuilder builder = new(); + StringBuilder expectedBuilder = new(); + VectorStoreRecordKeyProperty keyProperty = new(propertyName, typeof(string)); + + int paramIndex = 0; // we need a dedicated variable to ensure that AppendParameterName increments the index + for (int i = 0; i < 10; i++) + { + Assert.Equal(paramIndex, i); + SqlServerCommandBuilder.AppendParameterName(builder, keyProperty, ref paramIndex, out string parameterName); + Assert.Equal($"{expectedPrefix}{i}", parameterName); + expectedBuilder.Append(parameterName); + } + + 
Assert.Equal(expectedBuilder.ToString(), builder.ToString()); + } + [Theory] [InlineData("schema", "simpleName", "[simpleName]")] [InlineData("schema", "[needsEscaping]", "[[needsEscaping]]]")] @@ -92,7 +115,7 @@ public void CreateTable(bool ifNotExists) ]; VectorStoreRecordVectorProperty[] vectorProperties = [ - new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) { Dimensions = 10 } @@ -108,7 +131,7 @@ CREATE TABLE [schema].[table] ( [id] BIGINT NOT NULL, [simpleName] NVARCHAR(255) COLLATE Latin1_General_100_BIN2, [with space] INT NOT NULL, - [embedding1] VECTOR(10), + [embedding] VECTOR(10), PRIMARY KEY NONCLUSTERED ([id]) ) """; @@ -135,7 +158,7 @@ public void InsertInto() ]; VectorStoreRecordVectorProperty[] vectorProperties = [ - new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) { Dimensions = 10 } @@ -149,22 +172,22 @@ public void InsertInto() { "id", null }, { "simpleString", "nameValue" }, { "simpleInt", 134 }, - { "embedding1", "{ 10.0 }" } + { "embedding", "{ 10.0 }" } }); string expectedCommand = """ - INSERT INTO [schema].[table] ([simpleString],[simpleInt],[embedding1]) + INSERT INTO [schema].[table] ([simpleString],[simpleInt],[embedding]) OUTPUT inserted.[id] - VALUES (@simpleString,@simpleInt,@embedding1); + VALUES (@simpleString_0,@simpleInt_1,@embedding_2); """; Assert.Equal(HandleNewLines(expectedCommand), command.CommandText); - Assert.Equal("@simpleString", command.Parameters[0].ParameterName); + Assert.Equal("@simpleString_0", command.Parameters[0].ParameterName); Assert.Equal("nameValue", command.Parameters[0].Value); - Assert.Equal("@simpleInt", command.Parameters[1].ParameterName); + Assert.Equal("@simpleInt_1", command.Parameters[1].ParameterName); Assert.Equal(134, command.Parameters[1].Value); - Assert.Equal("@embedding1", 
command.Parameters[2].ParameterName); + Assert.Equal("@embedding_2", command.Parameters[2].ParameterName); Assert.Equal("{ 10.0 }", command.Parameters[2].Value); } @@ -181,7 +204,7 @@ public void MergeIntoSingle() keyProperty, new VectorStoreRecordDataProperty("simpleString", typeof(string)), new VectorStoreRecordDataProperty("simpleInt", typeof(int)), - new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) { Dimensions = 10 } @@ -195,29 +218,29 @@ public void MergeIntoSingle() { "id", null }, { "simpleString", "nameValue" }, { "simpleInt", 134 }, - { "embedding1", "{ 10.0 }" } + { "embedding", "{ 10.0 }" } }); string expectedCommand = """" MERGE INTO [schema].[table] AS t - USING (VALUES (@id,@simpleString,@simpleInt,@embedding1)) AS s ([id],[simpleString],[simpleInt],[embedding1]) + USING (VALUES (@id_0,@simpleString_1,@simpleInt_2,@embedding_3)) AS s ([id],[simpleString],[simpleInt],[embedding]) ON (t.[id] = s.[id]) WHEN MATCHED THEN - UPDATE SET t.[simpleString] = s.[simpleString],t.[simpleInt] = s.[simpleInt],t.[embedding1] = s.[embedding1] + UPDATE SET t.[simpleString] = s.[simpleString],t.[simpleInt] = s.[simpleInt],t.[embedding] = s.[embedding] WHEN NOT MATCHED THEN - INSERT ([id],[simpleString],[simpleInt],[embedding1]) - VALUES (s.[id],s.[simpleString],s.[simpleInt],s.[embedding1]); + INSERT ([id],[simpleString],[simpleInt],[embedding]) + VALUES (s.[id],s.[simpleString],s.[simpleInt],s.[embedding]); """"; Assert.Equal(HandleNewLines(expectedCommand), command.CommandText); - Assert.Equal("@id", command.Parameters[0].ParameterName); + Assert.Equal("@id_0", command.Parameters[0].ParameterName); Assert.Equal(DBNull.Value, command.Parameters[0].Value); - Assert.Equal("@simpleString", command.Parameters[1].ParameterName); + Assert.Equal("@simpleString_1", command.Parameters[1].ParameterName); Assert.Equal("nameValue", command.Parameters[1].Value); - 
Assert.Equal("@simpleInt", command.Parameters[2].ParameterName); + Assert.Equal("@simpleInt_2", command.Parameters[2].ParameterName); Assert.Equal(134, command.Parameters[2].Value); - Assert.Equal("@embedding1", command.Parameters[3].ParameterName); + Assert.Equal("@embedding_3", command.Parameters[3].ParameterName); Assert.Equal("{ 10.0 }", command.Parameters[3].Value); } @@ -266,8 +289,8 @@ public void MergeIntoMany() DECLARE @InsertedKeys TABLE (KeyColumn BIGINT); MERGE INTO [schema].[table] AS t USING (VALUES - (@id_0,@simpleString_0,@simpleInt_0,@embedding_0), - (@id_1,@simpleString_1,@simpleInt_1,@embedding_1)) AS s ([id],[simpleString],[simpleInt],[embedding]) + (@id_0,@simpleString_1,@simpleInt_2,@embedding_3), + (@id_4,@simpleString_5,@simpleInt_6,@embedding_7)) AS s ([id],[simpleString],[simpleInt],[embedding]) ON (t.[id] = s.[id]) WHEN MATCHED THEN UPDATE SET t.[simpleString] = s.[simpleString],t.[simpleInt] = s.[simpleInt],t.[embedding] = s.[embedding] @@ -282,13 +305,13 @@ WHEN NOT MATCHED THEN for (int i = 0; i < records.Length; i++) { - Assert.Equal($"@id_{i}", command.Parameters[4 * i].ParameterName); - Assert.Equal((long)i, command.Parameters[4 * i].Value); - Assert.Equal($"@simpleString_{i}", command.Parameters[4 * i + 1].ParameterName); + Assert.Equal($"@id_{4 * i + 0}", command.Parameters[4 * i + 0].ParameterName); + Assert.Equal((long)i, command.Parameters[4 * i + 0].Value); + Assert.Equal($"@simpleString_{4 * i + 1}", command.Parameters[4 * i + 1].ParameterName); Assert.Equal($"nameValue{i}", command.Parameters[4 * i + 1].Value); - Assert.Equal($"@simpleInt_{i}", command.Parameters[4 * i + 2].ParameterName); + Assert.Equal($"@simpleInt_{4 * i + 2}", command.Parameters[4 * i + 2].ParameterName); Assert.Equal(134 + i, command.Parameters[4 * i + 2].Value); - Assert.Equal($"@embedding_{i}", command.Parameters[4 * i + 3].ParameterName); + Assert.Equal($"@embedding_{4 * i + 3}", command.Parameters[4 * i + 3].ParameterName); Assert.Equal($"{{ 1{i}.0 
}}", command.Parameters[4 * i + 3].Value); } } @@ -302,9 +325,9 @@ public void DeleteSingle() using SqlCommand command = SqlServerCommandBuilder.DeleteSingle(connection, "schema", "tableName", keyProperty, 123L); - Assert.Equal("DELETE FROM [schema].[tableName] WHERE [id] = @id", command.CommandText); + Assert.Equal("DELETE FROM [schema].[tableName] WHERE [id] = @id_0", command.CommandText); Assert.Equal(123L, command.Parameters[0].Value); - Assert.Equal("@id", command.Parameters[0].ParameterName); + Assert.Equal("@id_0", command.Parameters[0].ParameterName); } [Fact] @@ -317,11 +340,11 @@ public void DeleteMany() using SqlCommand command = SqlServerCommandBuilder.DeleteMany(connection, "schema", "tableName", keyProperty, keys); - Assert.Equal("DELETE FROM [schema].[tableName] WHERE [id] IN (@k0,@k1)", command.CommandText); + Assert.Equal("DELETE FROM [schema].[tableName] WHERE [id] IN (@id_0,@id_1)", command.CommandText); for (int i = 0; i < keys.Length; i++) { Assert.Equal(keys[i], command.Parameters[i].Value); - Assert.Equal($"@k{i}", command.Parameters[i].ParameterName); + Assert.Equal($"@id_{i}", command.Parameters[i].ParameterName); } } @@ -347,10 +370,10 @@ public void SelectSingle() """"" SELECT [id],[name],[age],[embedding] FROM [schema].[tableName] - WHERE [id] = @id + WHERE [id] = @id_0 """""), command.CommandText); Assert.Equal(123L, command.Parameters[0].Value); - Assert.Equal("@id", command.Parameters[0].ParameterName); + Assert.Equal("@id_0", command.Parameters[0].ParameterName); } [Fact] @@ -376,12 +399,12 @@ public void SelectMany() """"" SELECT [id],[name],[age],[embedding] FROM [schema].[tableName] - WHERE [id] IN (@k0,@k1,@k2) + WHERE [id] IN (@id_0,@id_1,@id_2) """""), command.CommandText); for (int i = 0; i < keys.Length; i++) { Assert.Equal(keys[i], command.Parameters[i].Value); - Assert.Equal($"@k{i}", command.Parameters[i].ParameterName); + Assert.Equal($"@id_{i}", command.Parameters[i].ParameterName); } } From 
b4a73ee325ec4cb5f5bc2381f6a4ddb2189e850c Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 19 Feb 2025 10:48:36 +0100 Subject: [PATCH 13/32] add some comments --- .../Connectors.Memory.SqlServer/SqlServerVectorStore.cs | 2 +- .../SqlServerVectorStoreRecordCollection.cs | 1 + .../src/Data/VectorStoreRecordPropertyReader.cs | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs index 2b005a790228..81d7234895a2 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs @@ -76,7 +76,7 @@ public SqlServerVectorStore(SqlConnection connection, SqlServerVectorStoreOption public void Dispose() => this._connection.Dispose(); // TODO: adsitnik: design - // I find the creation process uniutive: the IVectorStoreRecordCollection.Create + // I find the creation process not intuitive: the IVectorStoreRecordCollection.Create // method does take only table name as an arugment, the metadata needs to be provided // a step before that by passing the VectorStoreRecordDefinition to the GetCollection method. // I would expect VectorStoreRecordDefinition to be argument of the CreateCollectionAsync. diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 6c12a3ab1ce2..7ad6766ada45 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -47,6 +47,7 @@ public Task CreateCollectionAsync(CancellationToken cancellationToken = default) // 1. I totally see why we want to provide it, we just need to make sure it's the right thing to do. // 2. 
An alternative would be to make CreateCollectionAsync a nop when the collection already exists // or extend it with an optional boolean parameter that would control the behavior. + // 3. We may need it to avoid TOCTOU issues. public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) => this.CreateCollectionAsync(ifNotExists: true, cancellationToken); diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs index 48287af8b963..522ad0bbdbdd 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs @@ -118,6 +118,7 @@ public VectorStoreRecordPropertyReader( this._parameterlessConstructorInfo = new Lazy(() => { + // TODO adsitnik: design: why don't we requrie TRecord to be always : new()? var constructor = dataModelType.GetConstructor(Type.EmptyTypes); if (constructor == null) { From e8584bec1655df1b84dc44089ea185a939c97ecf Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 19 Feb 2025 11:58:31 +0100 Subject: [PATCH 14/32] support storing more types, support auto-generated keys --- .../SqlServerCommandBuilder.cs | 85 ++++++++++---- .../SqlServerVectorStore.cs | 21 +++- .../SqlServerVectorStoreRecordCollection.cs | 66 ++++++++--- .../VectorStoreRecordKeyAttribute.cs | 7 +- .../VectorStoreRecordKeyProperty.cs | 9 +- .../Data/VectorStoreRecordPropertyReader.cs | 2 +- .../SqlServerCommandBuilderTests.cs | 6 +- .../SqlServerVectorStoreTests.cs | 107 ++++++++++++++++++ 8 files changed, 248 insertions(+), 55 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 1617a3a4369a..5e857e279342 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ 
b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Data; using System.Linq; using System.Text; using Microsoft.Data.SqlClient; @@ -31,12 +32,12 @@ internal static SqlCommand CreateTable( sb.AppendTableName(options.Schema, tableName); sb.AppendLine(" ("); string keyColumnName = GetColumnName(keyProperty); - sb.AppendFormat("[{0}] {1} NOT NULL,", keyColumnName, Map(keyProperty.PropertyType).sqlName); + var keyMapping = Map(keyProperty.PropertyType); + sb.AppendFormat("[{0}] {1} {2},", keyColumnName, keyMapping.sqlName, keyProperty.AutoGenerate ? keyMapping.autoGenerate : "NOT NULL"); sb.AppendLine(); for (int i = 0; i < dataProperties.Count; i++) { - (string sqlName, bool isNullable) = Map(dataProperties[i].PropertyType); - sb.AppendFormat(isNullable ? "[{0}] {1}," : "[{0}] {1} NOT NULL,", GetColumnName(dataProperties[i]), sqlName); + sb.AppendFormat("[{0}] {1},", GetColumnName(dataProperties[i]), Map(dataProperties[i].PropertyType).sqlName); sb.AppendLine(); } for (int i = 0; i < vectorProperties.Count; i++) @@ -114,7 +115,7 @@ internal static SqlCommand InsertInto( foreach (VectorStoreRecordProperty property in nonKeyProperties) { sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); - command.Parameters.AddWithValue(paramName, record[property.DataModelPropertyName] ?? (object)DBNull.Value); + command.AddParameter(property, paramName, record[property.DataModelPropertyName]); } sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis sb.Append(';'); @@ -141,7 +142,7 @@ internal static SqlCommand MergeIntoSingle( foreach (VectorStoreRecordProperty property in properties) { sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); - command.Parameters.AddWithValue(paramName, record[property.DataModelPropertyName] ?? 
(object)DBNull.Value); + command.AddParameter(property, paramName, record[property.DataModelPropertyName]); } sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis sb.Append(") AS s ("); @@ -159,13 +160,18 @@ internal static SqlCommand MergeIntoSingle( } --sb.Length; // remove the last comma sb.AppendLine(); + + // We must not try to insert the key if it is auto-generated. + var propertiesToInsert = keyProperty.AutoGenerate + ? properties.Where(p => p != keyProperty) + : properties; sb.Append("WHEN NOT MATCHED THEN"); sb.AppendLine(); sb.Append("INSERT ("); - sb.AppendColumnNames(properties); + sb.AppendColumnNames(propertiesToInsert); sb.AppendLine(")"); sb.Append("VALUES ("); - sb.AppendColumnNames(properties, prefix: "s."); + sb.AppendColumnNames(propertiesToInsert, prefix: "s."); sb.Append(");"); command.CommandText = sb.ToString(); @@ -198,7 +204,7 @@ internal static SqlCommand MergeIntoMany( foreach (VectorStoreRecordProperty property in properties) { sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); - command.Parameters.AddWithValue(paramName, record[property.DataModelPropertyName] ?? 
(object)DBNull.Value); + command.AddParameter(property, paramName, record[property.DataModelPropertyName]); } sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis sb.AppendLine(","); @@ -258,7 +264,7 @@ internal static SqlCommand DeleteSingle( sb.AppendTableName(schema, tableName); sb.AppendFormat(" WHERE [{0}] = ", GetColumnName(keyProperty)); sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName); - command.Parameters.AddWithValue(keyParamName, key); + command.AddParameter(keyProperty, keyParamName, key); command.CommandText = sb.ToString(); return command; @@ -299,7 +305,7 @@ internal static SqlCommand SelectSingle( sb.AppendLine(); sb.AppendFormat("WHERE [{0}] = ", GetColumnName(keyProperty)); sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName); - command.Parameters.AddWithValue(keyParamName, key); + command.AddParameter(keyProperty, keyParamName, key); command.CommandText = sb.ToString(); return command; @@ -417,7 +423,7 @@ private static StringBuilder AppendKeyParameterList(this StringBuilder sb, sb.AppendParameterName(keyProperty, ref keyIndex, out string keyParamName); sb.Append(','); - command.Parameters.AddWithValue(keyParamName, key); + command.AddParameter(keyProperty, keyParamName, key); } if (keyIndex == 0) @@ -437,19 +443,48 @@ private static SqlCommand CreateCommand(this SqlConnection connection, StringBui return command; } - private static (string sqlName, bool isNullable) Map(Type type) => type switch + private static void AddParameter(this SqlCommand command, VectorStoreRecordProperty property, string name, object? 
value) + { + switch (value) + { + case null when property.PropertyType == typeof(byte[]): + command.Parameters.Add(name, System.Data.SqlDbType.VarBinary).Value = DBNull.Value; + break; + case null: + command.Parameters.AddWithValue(name, DBNull.Value); + break; + case byte[] buffer: + command.Parameters.Add(name, System.Data.SqlDbType.VarBinary).Value = buffer; + break; + default: + command.Parameters.AddWithValue(name, value); + break; + } + } + + private static (string sqlName, string? autoGenerate) Map(Type type) { - Type t when t == typeof(int) => ("INT", false), - Type t when t == typeof(long) => ("BIGINT", false), - Type t when t == typeof(Guid) => ("UNIQUEIDENTIFIER", false), - Type t when t == typeof(string) => ("NVARCHAR(255) COLLATE Latin1_General_100_BIN2", true), - Type t when t == typeof(byte[]) => ("VARBINARY(MAX)", true), - Type t when t == typeof(bool) => ("BIT", false), - Type t when t == typeof(DateTime) => ("DATETIME", false), - Type t when t == typeof(TimeSpan) => ("TIME", false), - Type t when t == typeof(decimal) => ("DECIMAL", false), - Type t when t == typeof(double) => ("FLOAT", false), - Type t when t == typeof(float) => ("REAL", false), - _ => throw new NotSupportedException($"Type {type} is not supported.") - }; + const string NVARCHAR = "NVARCHAR(255) COLLATE Latin1_General_100_BIN2"; + return type switch + { + Type t when t == typeof(byte) => ("TINYINT", null), + Type t when t == typeof(short) => ("SMALLINT", null), + Type t when t == typeof(int) => ("INT", "IDENTITY(1,1)"), + Type t when t == typeof(long) => ("BIGINT", "IDENTITY(1,1)"), + // TODO adsitnik: discuss using NEWID() vs NEWSEQUENTIALID(). 
+ Type t when t == typeof(Guid) => ("UNIQUEIDENTIFIER", "DEFAULT NEWID()"), + Type t when t == typeof(string) => (NVARCHAR, null), + Type t when t == typeof(byte[]) => ("VARBINARY(MAX)", null), + Type t when t == typeof(bool) => ("BIT", null), + Type t when t == typeof(DateTime) => ("DATETIME", null), + Type t when t == typeof(TimeSpan) => ("TIME", null), + Type t when t == typeof(decimal) => ("DECIMAL", null), + Type t when t == typeof(double) => ("FLOAT", null), + Type t when t == typeof(float) => ("REAL", null), + // Collections don't have good native support, we store them as JSON + Type t when t == typeof(string[]) => (NVARCHAR, null), + Type t when t == typeof(List) => (NVARCHAR, null), + _ => throw new NotSupportedException($"Type {type} is not supported.") + }; + } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs index 81d7234895a2..b864563fc8e2 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs @@ -26,12 +26,15 @@ public sealed class SqlServerVectorStore : IVectorStore, IDisposable typeof(long), // BIGINT typeof(string), // VARCHAR typeof(Guid), // UNIQUEIDENTIFIER - // TODO adsitnik: do we want to support DATETIME (DateTime) and VARBINARY (byte[])? + typeof(DateTime), // DATETIME + typeof(byte[]) // VARBINARY ]; private static readonly HashSet s_supportedDataTypes = [ typeof(int), // INT + typeof(short), // SMALLINT + typeof(byte), // TINYINT typeof(long), // BIGINT. typeof(Guid), // UNIQUEIDENTIFIER. 
typeof(string), // NVARCHAR @@ -41,10 +44,12 @@ public sealed class SqlServerVectorStore : IVectorStore, IDisposable typeof(TimeSpan), // TIME typeof(decimal), // DECIMAL typeof(double), // FLOAT - typeof(float) // REAL + typeof(float), // REAL, + typeof(string[]), // NVARCHAR accessed as JSON + typeof(List) // NVARCHAR accessed as JSON ]; - private static readonly HashSet s_supportedVectorTypes = + internal static readonly HashSet s_supportedVectorTypes = [ typeof(ReadOnlyMemory), // VECTOR typeof(ReadOnlyMemory?) @@ -106,10 +111,16 @@ public IVectorStoreRecordCollection GetCollection( }); propertyReader.VerifyKeyProperties(s_supportedKeyTypes); - // TODO adsitnik: get the list of supported ienumerable types - propertyReader.VerifyDataProperties(s_supportedDataTypes, supportEnumerable: true); + propertyReader.VerifyDataProperties(s_supportedDataTypes, supportEnumerable: false); propertyReader.VerifyVectorProperties(s_supportedVectorTypes); + if (propertyReader.KeyProperty.AutoGenerate + && !(typeof(TKey) == typeof(int) || typeof(TKey) == typeof(long) || typeof(TKey) == typeof(Guid))) + { + // SQL Server does not support auto-generated keys for types other than int, long, and Guid. + throw new ArgumentException("Key property cannot be auto-generated."); + } + // Add to the cache once we have verified the record definition. s_propertyReaders.TryAdd(typeof(TRecord), propertyReader); } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 7ad6766ada45..4d82269807d7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -160,32 +160,34 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancellati TKey? 
key = (TKey)this._propertyReader.KeyPropertyInfo.GetValue(record); Dictionary map = Map(record, this._propertyReader); - using SqlCommand command = key is null - // When the key is null, we are inserting a new record. - ? SqlServerCommandBuilder.InsertInto( + + if (key is null || key.Equals(default(TKey))) + { + // When the key was not provided, we are inserting a new record. + using SqlCommand insertCommand = SqlServerCommandBuilder.InsertInto( this._sqlConnection, this._options, this.CollectionName, this._propertyReader.KeyProperty, this._propertyReader.DataProperties, this._propertyReader.VectorProperties, - map) - : SqlServerCommandBuilder.MergeIntoSingle( - this._sqlConnection, - this._options, - this.CollectionName, - this._propertyReader.KeyProperty, - this._propertyReader.Properties, map); - if (key is not null) - { - await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - return key; + using SqlDataReader reader = await insertCommand.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + return reader.GetFieldValue(0); } - using SqlDataReader sqlDataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - return await sqlDataReader.GetFieldValueAsync(0, cancellationToken).ConfigureAwait(false); + using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle( + this._sqlConnection, + this._options, + this.CollectionName, + this._propertyReader.KeyProperty, + this._propertyReader.Properties, + map); + + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + return key; } public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, @@ -228,7 +230,16 @@ private Task EnsureConnectionIsOpenedAsync(CancellationToken cancellationToken) var dataProperties = propertyReader.DataProperties; for (int i = 0; i < dataProperties.Count; i++) { - map[dataProperties[i].DataModelPropertyName] = 
propertyReader.DataPropertiesInfo[i].GetValue(record); + object value = propertyReader.DataPropertiesInfo[i].GetValue(record); + // SQL Server does not support arrays, so we need to serialize them to JSON. + object? mappedValue = value switch + { + string[] array => JsonSerializer.Serialize(array), + List list => JsonSerializer.Serialize(list), + _ => value + }; + + map[dataProperties[i].DataModelPropertyName] = mappedValue; } var vectorProperties = propertyReader.VectorProperties; for (int i = 0; i < vectorProperties.Count; i++) @@ -254,11 +265,30 @@ private static TRecord Map(SqlDataReader reader, VectorStoreRecordPropertyReader for (int i = 0; i < data.Count; i++) { object value = reader[SqlServerCommandBuilder.GetColumnName(data[i])]; - if (value is not DBNull) + if (value is DBNull) + { + // There is no need to call the reflection to set the null, + // as it's the default value of every .NET reference type field. + continue; + } + + if (value is not string text) { dataInfo[i].SetValue(record, value); } + else + { + // SQL Server does not support arrays, so we need to deserialize them from JSON. + object? 
mappedValue = data[i].PropertyType switch + { + Type t when t == typeof(string[]) => JsonSerializer.Deserialize(text), + Type t when t == typeof(List) => JsonSerializer.Deserialize>(text), + _ => text + }; + dataInfo[i].SetValue(record, mappedValue); + } } + var vector = propertyReader.VectorProperties; var vectorInfo = propertyReader.VectorPropertiesInfo; for (int i = 0; i < vector.Count; i++) diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs index a2f27488fd63..4bad37f8ccb7 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs @@ -11,8 +11,6 @@ namespace Microsoft.Extensions.VectorData; /// The characteristics defined here will influence how the property is treated by the vector store. /// [AttributeUsage(AttributeTargets.Property, AllowMultiple = false)] -// TODO adsitnik design: this attribute does not allow us to tell the DB to insert the key -// and upsert expects us to handle such scenario. public sealed class VectorStoreRecordKeyAttribute : Attribute { /// @@ -20,4 +18,9 @@ public sealed class VectorStoreRecordKeyAttribute : Attribute /// E.g. the property name might be "MyProperty" but the storage name might be "my_property". /// public string? StoragePropertyName { get; set; } + + /// + /// Gets or sets whether the key should be auto-generated by the vector store. 
+ /// + public bool AutoGenerate { get; set; } } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs index 92b8260b19d8..5fa216165d8d 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs @@ -17,9 +17,11 @@ public sealed class VectorStoreRecordKeyProperty : VectorStoreRecordProperty /// /// The name of the property. /// The type of the property. - public VectorStoreRecordKeyProperty(string propertyName, Type propertyType) + /// Whether the key should be auto-generated by the vector store. + public VectorStoreRecordKeyProperty(string propertyName, Type propertyType, bool autoGenerate = false) : base(propertyName, propertyType) { + this.AutoGenerate = autoGenerate; } /// @@ -30,4 +32,9 @@ public VectorStoreRecordKeyProperty(VectorStoreRecordKeyProperty source) : base(source) { } + + /// + /// Gets a value indicating whether the key should be auto-generated by the vector store. 
+ /// + public bool AutoGenerate { get; } } diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs index 522ad0bbdbdd..c833093462ca 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs @@ -513,7 +513,7 @@ private static VectorStoreRecordDefinition CreateVectorStoreRecordDefinitionFrom var keyAttribute = keyProperty.GetCustomAttribute(); if (keyAttribute is not null) { - definitionProperties.Add(new VectorStoreRecordKeyProperty(keyProperty.Name, keyProperty.PropertyType) + definitionProperties.Add(new VectorStoreRecordKeyProperty(keyProperty.Name, keyProperty.PropertyType, keyAttribute.AutoGenerate) { StoragePropertyName = keyAttribute.StoragePropertyName }); diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 9663132bfa41..54d13dc31a34 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -107,7 +107,7 @@ public void CreateTable(bool ifNotExists) { Schema = "schema" }; - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long), autoGenerate: true); VectorStoreRecordDataProperty[] dataProperties = [ new VectorStoreRecordDataProperty("simpleName", typeof(string)), @@ -128,9 +128,9 @@ public void CreateTable(bool ifNotExists) string expectedCommand = """ CREATE TABLE [schema].[table] ( - [id] BIGINT NOT NULL, + [id] BIGINT IDENTITY(1,1), [simpleName] NVARCHAR(255) COLLATE Latin1_General_100_BIN2, - [with space] INT NOT NULL, + [with space] INT, [embedding] 
VECTOR(10), PRIMARY KEY NONCLUSTERED ([id]) ) diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index de219028b836..1d45c65ee18f 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -179,4 +179,111 @@ public sealed class TestModel [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] public ReadOnlyMemory Floats { get; set; } } + + [Fact] + public Task CanUseFancyModels_Int() => this.CanUseFancyModels(); + + [Fact] + public Task CanUseFancyModels_Long() => this.CanUseFancyModels(); + + [Fact] + public Task CanUseFancyModels_Guid() => this.CanUseFancyModels(); + + private async Task CanUseFancyModels() where TKey : notnull + { + SqlServerTestStore testStore = new(); + + await testStore.ReferenceCountingStartAsync(); + + var collection = testStore.DefaultVectorStore.GetCollection>("other"); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + FancyTestModel inserted = new() + { + // We let the DB assign Id! + Number8 = byte.MaxValue, + Number16 = short.MaxValue, + Number32 = int.MaxValue, + Number64 = long.MaxValue, + Floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray(), + Bytes = [1, 2, 3], + ArrayOfStrings = ["a", "b", "c"], + ListOfStrings = ["d", "e", "f"] + }; + TKey key = await collection.UpsertAsync(inserted); + Assert.NotEqual(default(TKey), key); // key should be assigned by the DB (auto-increment) + + FancyTestModel? 
received = await collection.GetAsync(key); + AssertEquality(inserted, received, key); + + FancyTestModel updated = new() + { + Id = key, + Number16 = short.MinValue, // change one property + Floats = inserted.Floats + }; + key = await collection.UpsertAsync(updated); + Assert.Equal(updated.Id, key); + + received = await collection.GetAsync(updated.Id); + AssertEquality(updated, received, key); + + await collection.DeleteAsync(inserted.Id); + + Assert.Null(await collection.GetAsync(inserted.Id)); + } + finally + { + await collection.DeleteCollectionAsync(); + + await testStore.ReferenceCountingStopAsync(); + } + + void AssertEquality(FancyTestModel expected, FancyTestModel? received, TKey expectedKey) + { + Assert.NotNull(received); + Assert.Equal(expectedKey, received.Id); + Assert.Equal(expected.Number8, received.Number8); + Assert.Equal(expected.Number16, received.Number16); + Assert.Equal(expected.Number32, received.Number32); + Assert.Equal(expected.Number64, received.Number64); + Assert.Equal(expected.Floats, received.Floats); + Assert.Equal(expected.Bytes, received.Bytes); + Assert.Equal(expected.ArrayOfStrings, received.ArrayOfStrings); + Assert.Equal(expected.ListOfStrings, received.ListOfStrings); + } + } + + public sealed class FancyTestModel where TKey : notnull + { + [VectorStoreRecordKey(StoragePropertyName = "key", AutoGenerate = true)] + public TKey Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "byte")] + public byte Number8 { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "short")] + public short Number16 { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "int")] + public int Number32 { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "long")] + public long Number64 { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "bytes")] + public byte[]? Bytes { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "array_of_strings")] + public string[]? 
ArrayOfStrings { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "list_of_strings")] + public List? ListOfStrings { get; set; } + + [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] + public ReadOnlyMemory Floats { get; set; } + } } From f397f3f7e13ef087e23efff32d9a87016f2055f5 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 19 Feb 2025 17:00:00 +0100 Subject: [PATCH 15/32] simplify: don't use a dedicated query for inserting a single record --- .../SqlServerCommandBuilder.cs | 37 +------------ .../SqlServerVectorStoreRecordCollection.cs | 27 ++------- .../SqlServerCommandBuilderTests.cs | 55 ++----------------- 3 files changed, 10 insertions(+), 109 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 5e857e279342..c53f275ee65f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -90,40 +90,6 @@ FROM INFORMATION_SCHEMA.TABLES return command; } - internal static SqlCommand InsertInto( - SqlConnection connection, - SqlServerVectorStoreOptions options, - string tableName, - VectorStoreRecordKeyProperty keyProperty, - IReadOnlyList dataProperties, - IReadOnlyList vectorProperties, - Dictionary record) - { - SqlCommand command = connection.CreateCommand(); - - StringBuilder sb = new(200); - sb.Append("INSERT INTO "); - sb.AppendTableName(options.Schema, tableName); - sb.Append(" ("); - var nonKeyProperties = dataProperties.Concat(vectorProperties); - sb.AppendColumnNames(nonKeyProperties); - sb.AppendLine(")"); - sb.AppendFormat("OUTPUT inserted.[{0}]", GetColumnName(keyProperty)); - sb.AppendLine(); - sb.Append("VALUES ("); - int paramIndex = 0; - foreach (VectorStoreRecordProperty property in nonKeyProperties) - { - sb.AppendParameterName(property, ref paramIndex, out 
string paramName).Append(','); - command.AddParameter(property, paramName, record[property.DataModelPropertyName]); - } - sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis - sb.Append(';'); - - command.CommandText = sb.ToString(); - return command; - } - internal static SqlCommand MergeIntoSingle( SqlConnection connection, SqlServerVectorStoreOptions options, @@ -172,7 +138,8 @@ internal static SqlCommand MergeIntoSingle( sb.AppendLine(")"); sb.Append("VALUES ("); sb.AppendColumnNames(propertiesToInsert, prefix: "s."); - sb.Append(");"); + sb.AppendLine(")"); + sb.AppendFormat("OUTPUT inserted.[{0}];", GetColumnName(keyProperty)); command.CommandText = sb.ToString(); return command; diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 4d82269807d7..9bb2c01fd87e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -158,36 +158,17 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancellati await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - TKey? key = (TKey)this._propertyReader.KeyPropertyInfo.GetValue(record); - Dictionary map = Map(record, this._propertyReader); - - if (key is null || key.Equals(default(TKey))) - { - // When the key was not provided, we are inserting a new record. 
- using SqlCommand insertCommand = SqlServerCommandBuilder.InsertInto( - this._sqlConnection, - this._options, - this.CollectionName, - this._propertyReader.KeyProperty, - this._propertyReader.DataProperties, - this._propertyReader.VectorProperties, - map); - - using SqlDataReader reader = await insertCommand.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - await reader.ReadAsync(cancellationToken).ConfigureAwait(false); - return reader.GetFieldValue(0); - } - using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle( this._sqlConnection, this._options, this.CollectionName, this._propertyReader.KeyProperty, this._propertyReader.Properties, - map); + Map(record, this._propertyReader)); - await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - return key; + using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + return reader.GetFieldValue(0); } public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 54d13dc31a34..4e84a0314e65 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -143,54 +143,6 @@ PRIMARY KEY NONCLUSTERED ([id]) Assert.Equal(HandleNewLines(expectedCommand), command.CommandText); } - [Fact] - public void InsertInto() - { - SqlServerVectorStoreOptions options = new() - { - Schema = "schema" - }; - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); - VectorStoreRecordDataProperty[] dataProperties = - [ - new VectorStoreRecordDataProperty("simpleString", typeof(string)), - new 
VectorStoreRecordDataProperty("simpleInt", typeof(int)) - ]; - VectorStoreRecordVectorProperty[] vectorProperties = - [ - new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) - { - Dimensions = 10 - } - ]; - using SqlConnection connection = CreateConnection(); - - using SqlCommand command = SqlServerCommandBuilder.InsertInto(connection, options, "table", - keyProperty, dataProperties, vectorProperties, - new Dictionary - { - { "id", null }, - { "simpleString", "nameValue" }, - { "simpleInt", 134 }, - { "embedding", "{ 10.0 }" } - }); - - string expectedCommand = - """ - INSERT INTO [schema].[table] ([simpleString],[simpleInt],[embedding]) - OUTPUT inserted.[id] - VALUES (@simpleString_0,@simpleInt_1,@embedding_2); - """; - - Assert.Equal(HandleNewLines(expectedCommand), command.CommandText); - Assert.Equal("@simpleString_0", command.Parameters[0].ParameterName); - Assert.Equal("nameValue", command.Parameters[0].Value); - Assert.Equal("@simpleInt_1", command.Parameters[1].ParameterName); - Assert.Equal(134, command.Parameters[1].Value); - Assert.Equal("@embedding_2", command.Parameters[2].ParameterName); - Assert.Equal("{ 10.0 }", command.Parameters[2].Value); - } - [Fact] public void MergeIntoSingle() { @@ -198,7 +150,7 @@ public void MergeIntoSingle() { Schema = "schema" }; - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long), autoGenerate: true); VectorStoreRecordProperty[] properties = [ keyProperty, @@ -229,8 +181,9 @@ MERGE INTO [schema].[table] AS t WHEN MATCHED THEN UPDATE SET t.[simpleString] = s.[simpleString],t.[simpleInt] = s.[simpleInt],t.[embedding] = s.[embedding] WHEN NOT MATCHED THEN - INSERT ([id],[simpleString],[simpleInt],[embedding]) - VALUES (s.[id],s.[simpleString],s.[simpleInt],s.[embedding]); + INSERT ([simpleString],[simpleInt],[embedding]) + VALUES (s.[simpleString],s.[simpleInt],s.[embedding]) + OUTPUT inserted.[id]; """"; 
Assert.Equal(HandleNewLines(expectedCommand), command.CommandText); From 7c8d2dcf3b747609af2e4121dc56dd36d839480a Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Thu, 20 Feb 2025 10:19:51 +0100 Subject: [PATCH 16/32] vector search --- .../PostgresVectorStoreRecordCollection.cs | 34 +------------ .../SqlServerCommandBuilder.cs | 45 ++++++++++++++++- .../SqlServerVectorStoreRecordCollection.cs | 50 ++++++++++++++++++- .../VectorSearch/VectorSearchOptions.cs | 32 +++++++++++- .../Data/VectorStoreRecordPropertyReader.cs | 28 +++++++++++ .../SqlServerCommandBuilderTests.cs | 4 +- .../SqlServerVectorStoreTests.cs | 14 ++++-- .../Support/SqlServerTestStore.cs | 2 + 8 files changed, 165 insertions(+), 44 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index bea84dce1b06..664a2d41a917 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -266,12 +266,7 @@ public virtual Task> VectorizedSearchAsync } var searchOptions = options ?? 
s_defaultVectorSearchOptions; - var vectorProperty = this.GetVectorPropertyForSearch(searchOptions.VectorPropertyName); - - if (vectorProperty is null) - { - throw new InvalidOperationException("The collection does not have any vector properties, so vector search is not possible."); - } + var vectorProperty = this._propertyReader.GetVectorPropertyForSearch(searchOptions.VectorPropertyName); var pgVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); @@ -318,33 +313,6 @@ private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken c return this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition.Properties, ifNotExists, cancellationToken); } - /// - /// Get vector property to use for a search by using the storage name for the field name from options - /// if available, and falling back to the first vector property in if not. - /// - /// The vector field name. - /// Thrown if the provided field name is not a valid field name. - private VectorStoreRecordVectorProperty? GetVectorPropertyForSearch(string? vectorFieldName) - { - // If vector property name is provided in options, try to find it in schema or throw an exception. - if (!string.IsNullOrWhiteSpace(vectorFieldName)) - { - // Check vector properties by data model property name. - var vectorProperty = this._propertyReader.VectorProperties - .FirstOrDefault(l => l.DataModelPropertyName.Equals(vectorFieldName, StringComparison.Ordinal)); - - if (vectorProperty is not null) - { - return vectorProperty; - } - - throw new InvalidOperationException($"The {typeof(TRecord).FullName} type does not have a vector property named '{vectorFieldName}'."); - } - - // If vector property is not provided in options, return first vector property from schema. 
- return this._propertyReader.VectorProperty; - } - private async Task RunOperationAsync(string operationName, Func operation) { try diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index c53f275ee65f..75260a55aa00 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -1,8 +1,10 @@ -using System; +// Copyright (c) Microsoft. All rights reserved. + +using System; using System.Collections.Generic; -using System.Data; using System.Linq; using System.Text; +using System.Text.Json; using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; @@ -301,6 +303,45 @@ internal static SqlCommand SelectMany( return command; } + internal static SqlCommand SelectVector( + SqlConnection connection, string schema, string tableName, + VectorStoreRecordVectorProperty vectorProperty, + IReadOnlyList properties, + VectorSearchOptions options, + ReadOnlyMemory vector) + { + string distanceFunction = vectorProperty.DistanceFunction ?? 
DistanceFunction.CosineDistance; + // Source: https://learn.microsoft.com/sql/t-sql/functions/vector-distance-transact-sql + string distanceMetric = distanceFunction switch + { + DistanceFunction.CosineDistance => "cosine", + DistanceFunction.EuclideanDistance => "euclidean", + DistanceFunction.DotProductSimilarity => "dot", + _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") + }; + + SqlCommand command = connection.CreateCommand(); + command.Parameters.AddWithValue("@vector", JsonSerializer.Serialize(vector)); + + StringBuilder sb = new(200); + sb.AppendFormat("SELECT "); + sb.AppendColumnNames(properties); + sb.AppendLine(","); + sb.AppendFormat("1 - VECTOR_DISTANCE('{0}', {1}, CAST(@vector AS VECTOR({2}))) AS [score]", + distanceMetric, GetColumnName(vectorProperty), vector.Length); + sb.AppendLine(); + sb.Append("FROM "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(); + sb.AppendLine("ORDER BY [score] DESC"); + // Negative Skip and Top values are rejected by the VectorSearchOptions property setters. + // 0 is a legal value for OFFSET. + sb.AppendFormat("OFFSET {0} ROWS FETCH NEXT {1} ROWS ONLY;", options.Skip, options.Top); + + command.CommandText = sb.ToString(); + return command; + } + internal static string GetColumnName(VectorStoreRecordProperty property) => property.StoragePropertyName ?? 
property.DataModelPropertyName; diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 9bb2c01fd87e..93a7a93e7b8a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -15,6 +15,8 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; internal sealed class SqlServerVectorStoreRecordCollection : IVectorStoreRecordCollection where TKey : notnull { + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private readonly SqlConnection _sqlConnection; private readonly SqlServerVectorStoreOptions _options; private readonly VectorStoreRecordPropertyReader _propertyReader; @@ -195,7 +197,53 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + Verify.NotNull(vector); + + if (vector is not ReadOnlyMemory allowed) + { + throw new NotSupportedException( + $"The provided vector type {vector.GetType().FullName} is not supported by the SQL Server connector. " + + $"Supported types are: {string.Join(", ", SqlServerVectorStore.s_supportedVectorTypes.Select(l => l.FullName))}"); + } + + var searchOptions = options ?? 
s_defaultVectorSearchOptions; + var vectorProperty = this._propertyReader.GetVectorPropertyForSearch(searchOptions.VectorPropertyName); + + var results = this.ReadVectorSearchResultsAsync(allowed, vectorProperty, searchOptions, cancellationToken); + return Task.FromResult(new VectorSearchResults(results)); + } + + private async IAsyncEnumerable> ReadVectorSearchResultsAsync( + ReadOnlyMemory vector, + VectorStoreRecordVectorProperty vectorProperty, + VectorSearchOptions searchOptions, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); + + using SqlCommand command = SqlServerCommandBuilder.SelectVector( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + vectorProperty, + this._propertyReader.Properties, + searchOptions, + vector); + + using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + + int scoreIndex = -1; + while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + if (scoreIndex < 0) + { + scoreIndex = reader.GetOrdinal("score"); + } + + yield return new VectorSearchResult( + Map(reader, this._propertyReader), + reader.GetDouble(scoreIndex)); + } } private Task EnsureConnectionIsOpenedAsync(CancellationToken cancellationToken) diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs index 65d9c6e157c2..7bfed5a8a3af 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs @@ -10,6 +10,8 @@ namespace Microsoft.Extensions.VectorData; /// public class VectorSearchOptions { + private int _top = 3, _skip = 0; + /// /// Gets or sets a search filter to use before doing the vector search. 
/// @@ -31,12 +33,38 @@ public class VectorSearchOptions /// /// Gets or sets the maximum number of results to return. /// - public int Top { get; init; } = 3; + /// Thrown when the value is less than 1. + public int Top + { + get => this._top; + init + { + if (value < 1) + { + throw new ArgumentOutOfRangeException(nameof(value), "Top must be greater than or equal to 1."); + } + + this._top = value; + } + } /// /// Gets or sets the number of results to skip before returning results, i.e. the index of the first result to return. /// - public int Skip { get; init; } = 0; + /// Thrown when the value is less than 0. + public int Skip + { + get => this._skip; + init + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "Skip must be greater than or equal to 0."); + } + + this._skip = value; + } + } /// /// Gets or sets a value indicating whether to include vectors in the retrieval result. diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs index c833093462ca..22ca6c113406 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs @@ -498,6 +498,34 @@ public static (List KeyProperties, List DataProperti return (keyProperties, dataProperties, vectorProperties); } + /// + /// Get vector property to use for a search by using the storage name for the field name from options + /// if available, and falling back to the first vector property in if not. + /// + /// The vector field name. + /// Thrown if the provided field name is not a valid field name or when no vector property was defined. + public VectorStoreRecordVectorProperty GetVectorPropertyForSearch(string? vectorFieldName) + { + // If vector property name is provided in options, try to find it in schema or throw an exception. 
+ if (!string.IsNullOrWhiteSpace(vectorFieldName)) + { + // Check vector properties by data model property name. + var vectorProperty = this.VectorProperties + .FirstOrDefault(l => l.DataModelPropertyName.Equals(vectorFieldName, StringComparison.Ordinal)); + + if (vectorProperty is not null) + { + return vectorProperty; + } + + throw new InvalidOperationException($"The {this._dataModelType.FullName} type does not have a vector property named '{vectorFieldName}'."); + } + + // If vector property is not provided in options, return first vector property from schema. + return this.VectorProperty + ?? throw new InvalidOperationException("The collection does not have any vector properties, so vector search is not possible."); ; + } + /// /// Create a by reading the attributes on the provided objects. /// diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 4e84a0314e65..5a8574990310 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -1,4 +1,6 @@ -using System.Text; +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Text; using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.SqlServer; diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index 1d45c65ee18f..f22b9cd70c59 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -1,4 +1,6 @@ -using Microsoft.Extensions.VectorData; +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; using SqlServerIntegrationTests.Support; using Xunit; @@ -168,7 +170,7 @@ private static void AssertEquality(TestModel inserted, TestModel? received) public sealed class TestModel { [VectorStoreRecordKey(StoragePropertyName = "key")] - public string Id { get; set; } + public string? Id { get; set; } [VectorStoreRecordData(StoragePropertyName = "text")] public string? Text { get; set; } @@ -214,7 +216,7 @@ private async Task CanUseFancyModels() where TKey : notnull ListOfStrings = ["d", "e", "f"] }; TKey key = await collection.UpsertAsync(inserted); - Assert.NotEqual(default(TKey), key); // key should be assigned by the DB (auto-increment) + Assert.NotEqual(default, key); // key should be assigned by the DB (auto-increment) FancyTestModel? received = await collection.GetAsync(key); AssertEquality(inserted, received, key); @@ -257,10 +259,10 @@ void AssertEquality(FancyTestModel expected, FancyTestModel? receive } } - public sealed class FancyTestModel where TKey : notnull + public sealed class FancyTestModel { [VectorStoreRecordKey(StoragePropertyName = "key", AutoGenerate = true)] - public TKey Id { get; set; } + public TKey? 
Id { get; set; } [VectorStoreRecordData(StoragePropertyName = "byte")] public byte Number8 { get; set; } @@ -275,10 +277,12 @@ public sealed class FancyTestModel where TKey : notnull public long Number64 { get; set; } [VectorStoreRecordData(StoragePropertyName = "bytes")] +#pragma warning disable CA1819 // Properties should not return arrays public byte[]? Bytes { get; set; } [VectorStoreRecordData(StoragePropertyName = "array_of_strings")] public string[]? ArrayOfStrings { get; set; } +#pragma warning restore CA1819 // Properties should not return arrays [VectorStoreRecordData(StoragePropertyName = "list_of_strings")] public List? ListOfStrings { get; set; } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs index 45ec63622e9f..e3bfceb54bc1 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs @@ -14,6 +14,8 @@ public sealed class SqlServerTestStore : TestStore public override IVectorStore DefaultVectorStore => this._connectedStore ?? throw new InvalidOperationException("Not initialized"); + public override string DefaultDistanceFunction => DistanceFunction.CosineDistance; + private SqlServerVectorStore? 
_connectedStore; protected override async Task StartAsync() From 9e5ef1c9fab682425cbfa5759c6dedba3f5ebc02 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Thu, 20 Feb 2025 14:20:03 +0100 Subject: [PATCH 17/32] implement filtering by reusing a lot of code implemented by @roji --- .../SqlServerCommandBuilder.cs | 15 + .../SqlServerFilterTranslator.cs | 269 ++++++++++++++++++ .../SqlServerVectorStoreRecordCollection.cs | 7 + 3 files changed, 291 insertions(+) create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 75260a55aa00..3c38ca2af44d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -307,6 +307,7 @@ internal static SqlCommand SelectVector( SqlConnection connection, string schema, string tableName, VectorStoreRecordVectorProperty vectorProperty, IReadOnlyList properties, + IReadOnlyDictionary storagePropertyNamesMap, VectorSearchOptions options, ReadOnlyMemory vector) { @@ -333,6 +334,20 @@ internal static SqlCommand SelectVector( sb.Append("FROM "); sb.AppendTableName(schema, tableName); sb.AppendLine(); + if (options.NewFilter is not null) + { + int startParamIndex = command.Parameters.Count; + + List parameters = new SqlServerFilterTranslator(sb, schema).Translate( + storagePropertyNamesMap, + options.NewFilter, + startParamIndex); + + foreach (object parameter in parameters) + { + command.AddParameter(vectorProperty, $"@_{startParamIndex++}", parameter); + } + } sb.AppendLine("ORDER BY [score] DESC"); // Negative Skip and Top values are rejected by the VectorSearchOptions property setters. // 0 is a legal value for OFFSET. 
diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs new file mode 100644 index 000000000000..c6c43a8ca186 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs @@ -0,0 +1,269 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal class SqlServerFilterTranslator(StringBuilder sql, string schema) +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + private readonly List _parameterValues = new(); + private int _parameterIndex; + + private readonly StringBuilder _sql = sql; + private readonly string _schema = schema; + + internal List Translate( + IReadOnlyDictionary storagePropertyNames, + LambdaExpression lambdaExpression, + int startParamIndex) + { + this._storagePropertyNames = storagePropertyNames; + this._parameterIndex = startParamIndex; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this._sql.Append("WHERE "); + this.Translate(lambdaExpression.Body); + return this._parameterValues; + } + + private void Translate(Expression? 
node) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant); + return; + + case MemberExpression member: + this.TranslateMember(member); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + private void TranslateBinary(BinaryExpression binary) + { + // Special handling for null comparisons + switch (binary.NodeType) + { + case ExpressionType.Equal when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NOT NULL)"); + return; + + case ExpressionType.Equal when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NOT NULL)"); + return; + } + + this._sql.Append('('); + this.Translate(binary.Left); + + // SQL Server uses the same values as PostgreSQL. 
+ this._sql.Append(binary.NodeType switch + { + ExpressionType.Equal => " = ", + ExpressionType.NotEqual => " <> ", + + ExpressionType.GreaterThan => " > ", + ExpressionType.GreaterThanOrEqual => " >= ", + ExpressionType.LessThan => " < ", + ExpressionType.LessThanOrEqual => " <= ", + + ExpressionType.AndAlso => " AND ", + ExpressionType.OrElse => " OR ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right); + this._sql.Append(')'); + + static bool IsNull(Expression expression) + => expression is ConstantExpression { Value: null } + || (TryGetCapturedValue(expression, out var capturedValue) && capturedValue is null); + } + + private void TranslateConstant(ConstantExpression constant) + { + // TODO: Nullable + switch (constant.Value) + { + case byte b: + this._sql.Append(b); + return; + case short s: + this._sql.Append(s); + return; + case int i: + this._sql.Append(i); + return; + case long l: + this._sql.Append(l); + return; + + case string s: + this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); + return; + case bool b: + // SQL Server uses 1 and 0 to represent booleans. + this._sql.Append(b ? 
"1" : "0"); + return; + case Guid g: + this._sql.Append('\'').Append(g.ToString()).Append('\''); + return; + + case DateTime dateTime: + // SQL Server supports DateTimes in this particular format + this._sql.AppendFormat("'{0:yyyy-MM-dd HH:mm:ss}'", dateTime); + return; + + case DateTimeOffset dateTimeOffset: + // SQL Server supports DateTimeOffsets in this particular format + this._sql.AppendFormat("'{0:yyy-MM-dd HH:mm:ss zzz}'", dateTimeOffset); + return; + + case Array: + throw new NotImplementedException(); + + case null: + this._sql.Append("NULL"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + constant.Value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression) + { + switch (memberExpression) + { + case var _ when this.TryGetColumn(memberExpression, out var column): + // SQL Server uses square brackets to escape column names and requires the schema to be provided. + //this._sql.AppendTableName(this._schema, column); + this._sql.Append('"').Append(column).Append('"'); + return; + + // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) + case var _ when TryGetCapturedValue(memberExpression, out var capturedValue): + // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (capturedValue is null) + { + this._sql.Append("NULL"); + } + else + { + this._parameterValues.Add(capturedValue); + // SQL Server paramters can't start with a digit. 
+ this._sql.Append("@_").Append(this._parameterIndex++); + } + return; + + default: + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + + private void TranslateUnary(UnaryExpression unary) + { + switch (unary.NodeType) + { + case ExpressionType.Not: + // Special handling for !(a == b) and !(a != b) + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._sql.Append("(NOT "); + this.Translate(unary.Operand); + this._sql.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } + + private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) + { + if (expression is MemberExpression member && member.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + { + throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + column = null; + return false; + } + + private static bool TryGetCapturedValue(Expression expression, out object? 
capturedValue) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + capturedValue = fieldInfo.GetValue(constant.Value); + return true; + } + + capturedValue = null; + return false; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 93a7a93e7b8a..73f043c23d9d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -205,6 +205,12 @@ public Task> VectorizedSearchAsync(TVector $"The provided vector type {vector.GetType().FullName} is not supported by the SQL Server connector. " + $"Supported types are: {string.Join(", ", SqlServerVectorStore.s_supportedVectorTypes.Select(l => l.FullName))}"); } +#pragma warning disable CS0618 // Type or member is obsolete + else if (options is not null && options.Filter is not null) +#pragma warning restore CS0618 // Type or member is obsolete + { + throw new NotSupportedException("The obsolete Filter is not supported by the SQL Server connector, use NewFilter instead."); + } var searchOptions = options ?? 
s_defaultVectorSearchOptions; var vectorProperty = this._propertyReader.GetVectorPropertyForSearch(searchOptions.VectorPropertyName); @@ -227,6 +233,7 @@ private async IAsyncEnumerable> ReadVectorSearchResu this.CollectionName, vectorProperty, this._propertyReader.Properties, + this._propertyReader.StoragePropertyNamesMap, searchOptions, vector); From 080811f89f308512b0f8be17ce273bb001007d29 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Thu, 20 Feb 2025 14:28:42 +0100 Subject: [PATCH 18/32] reduce code duplication --- .../SqlFilterTranslator.cs | 327 ++++++++++++++++ .../Connectors.Memory.Postgres.csproj | 4 + .../PostgresFilterTranslator.cs | 328 ++-------------- ...PostgresVectorStoreCollectionSqlBuilder.cs | 11 +- .../Connectors.Memory.SqlServer.csproj | 4 + .../SqlServerCommandBuilder.cs | 8 +- .../SqlServerFilterTranslator.cs | 270 ++----------- .../Connectors.Memory.Sqlite.csproj | 4 + .../SqliteFilterTranslator.cs | 362 +++--------------- .../SqliteVectorStoreRecordCollection.cs | 5 +- 10 files changed, 467 insertions(+), 856 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs new file mode 100644 index 000000000000..ac708302b879 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs @@ -0,0 +1,327 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors; + +internal partial class SqlFilterTranslator +{ + private readonly IReadOnlyDictionary _storagePropertyNames; + private readonly LambdaExpression _lambdaExpression; + private readonly ParameterExpression _recordParameter; + private readonly StringBuilder _sql; + + internal SqlFilterTranslator( + IReadOnlyDictionary storagePropertyNames, + LambdaExpression lambdaExpression, + StringBuilder? sql = null) + { + this._storagePropertyNames = storagePropertyNames; + this._lambdaExpression = lambdaExpression; + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + this._sql = sql ?? new(); + } + + internal StringBuilder Clause => this._sql; + + internal void Translate(bool appendWhere) + { + if (appendWhere) + { + this._sql.Append("WHERE "); + } + + this.Translate(this._lambdaExpression.Body); + } + + private void Translate(Expression? 
node) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant); + return; + + case MemberExpression member: + this.TranslateMember(member); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + private void TranslateBinary(BinaryExpression binary) + { + // Special handling for null comparisons + switch (binary.NodeType) + { + case ExpressionType.Equal when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NOT NULL)"); + return; + + case ExpressionType.Equal when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NOT NULL)"); + return; + } + + this._sql.Append('('); + this.Translate(binary.Left); + + this._sql.Append(binary.NodeType switch + { + ExpressionType.Equal => " = ", + ExpressionType.NotEqual => " <> ", + + ExpressionType.GreaterThan => " > ", + ExpressionType.GreaterThanOrEqual => " >= ", + ExpressionType.LessThan => " < ", + ExpressionType.LessThanOrEqual => " <= ", + + ExpressionType.AndAlso => " AND ", + ExpressionType.OrElse => " OR ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right); + this._sql.Append(')'); + + static bool IsNull(Expression expression) + => expression is ConstantExpression { Value: null } + 
|| (TryGetCapturedValue(expression, out _, out var capturedValue) && capturedValue is null); + } + + private void TranslateConstant(ConstantExpression constant) + => this.GenerateLiteral(constant.Value); + + private void GenerateLiteral(object? value) + { + // TODO: Nullable + switch (value) + { + case byte b: + this._sql.Append(b); + return; + case short s: + this._sql.Append(s); + return; + case int i: + this._sql.Append(i); + return; + case long l: + this._sql.Append(l); + return; + + case string s: + this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); + return; + case bool b: + this.GenerateLiteral(b); + return; + case Guid g: + this._sql.Append('\'').Append(g.ToString()).Append('\''); + return; + + case DateTime dateTime: + this.GenerateLiteral(dateTime); + return; + + case DateTimeOffset dateTimeOffset: + this.GenerateLiteral(dateTimeOffset); + return; + + case Array: + throw new NotImplementedException(); + + case null: + this._sql.Append("NULL"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression) + { + switch (memberExpression) + { + case var _ when this.TryGetColumn(memberExpression, out var column): + this._sql.Append('"').Append(column).Append('"'); + return; + + // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) 
+ case var _ when TryGetCapturedValue(memberExpression, out var name, out var value): + this.TranslateLambdaVariables(name, value); + return; + + default: + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over array column (r => r.Strings.Contains("foo")) + case var _ when this.TryGetColumn(source, out _): + this.TranslateContainsOverArrayColumn(source, item); + return; + + // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) + case NewArrayExpression newArray: + this.Translate(item); + this._sql.Append(" IN ("); + + var isFirst = true; + foreach (var element in newArray.Expressions) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.Translate(element); + } + + this._sql.Append(')'); + return; + + // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) + case var _ when TryGetCapturedValue(source, out _, out var value): + 
this.TranslateContainsOverCapturedArray(source, item, value); + return; + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + } + + private void TranslateUnary(UnaryExpression unary) + { + switch (unary.NodeType) + { + case ExpressionType.Not: + // Special handling for !(a == b) and !(a != b) + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._sql.Append("(NOT "); + this.Translate(unary.Operand); + this._sql.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } + + private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) + { + if (expression is MemberExpression member && member.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + { + throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + column = null; + return false; + } + + private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? 
value) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + name = fieldInfo.Name; + value = fieldInfo.GetValue(constant.Value); + return true; + } + + name = null; + value = null; + return false; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj index b1904c6cc1cd..03b36f7525b1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj @@ -19,6 +19,10 @@ Postgres(with pgvector extension) connector for Semantic Kernel plugins and semantic memory + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs index 6c68527da5c1..a181ebc3ecbb 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs @@ -2,331 +2,59 @@ using System; using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text; -namespace Microsoft.SemanticKernel.Connectors.Postgres; +namespace Microsoft.SemanticKernel.Connectors; -internal class PostgresFilterTranslator +internal partial class SqlFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; - private ParameterExpression _recordParameter = null!; - private readonly List _parameterValues = new(); private int _parameterIndex; - private 
readonly StringBuilder _sql = new(); + internal List ParameterValues => this._parameterValues; - internal (string Clause, List Parameters) Translate( - IReadOnlyDictionary storagePropertyNames, - LambdaExpression lambdaExpression, - int startParamIndex) + internal void Initialize(int startParamIndex) { - Debug.Assert(this._sql.Length == 0); - - this._storagePropertyNames = storagePropertyNames; - this._parameterIndex = startParamIndex; - - Debug.Assert(lambdaExpression.Parameters.Count == 1); - this._recordParameter = lambdaExpression.Parameters[0]; - - this._sql.Append("WHERE "); - this.Translate(lambdaExpression.Body); - return (this._sql.ToString(), this._parameterValues); - } - - private void Translate(Expression? node) - { - switch (node) - { - case BinaryExpression binary: - this.TranslateBinary(binary); - return; - - case ConstantExpression constant: - this.TranslateConstant(constant); - return; - - case MemberExpression member: - this.TranslateMember(member); - return; - - case MethodCallExpression methodCall: - this.TranslateMethodCall(methodCall); - return; - - case UnaryExpression unary: - this.TranslateUnary(unary); - return; - - default: - throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); - } - } - - private void TranslateBinary(BinaryExpression binary) - { - // Special handling for null comparisons - switch (binary.NodeType) - { - case ExpressionType.Equal when IsNull(binary.Right): - this._sql.Append('('); - this.Translate(binary.Left); - this._sql.Append(" IS NULL)"); - return; - case ExpressionType.NotEqual when IsNull(binary.Right): - this._sql.Append('('); - this.Translate(binary.Left); - this._sql.Append(" IS NOT NULL)"); - return; - - case ExpressionType.Equal when IsNull(binary.Left): - this._sql.Append('('); - this.Translate(binary.Right); - this._sql.Append(" IS NULL)"); - return; - case ExpressionType.NotEqual when IsNull(binary.Left): - this._sql.Append('('); - this.Translate(binary.Right); - 
this._sql.Append(" IS NOT NULL)"); - return; - } - - this._sql.Append('('); - this.Translate(binary.Left); - - this._sql.Append(binary.NodeType switch - { - ExpressionType.Equal => " = ", - ExpressionType.NotEqual => " <> ", - - ExpressionType.GreaterThan => " > ", - ExpressionType.GreaterThanOrEqual => " >= ", - ExpressionType.LessThan => " < ", - ExpressionType.LessThanOrEqual => " <= ", - - ExpressionType.AndAlso => " AND ", - ExpressionType.OrElse => " OR ", - - _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) - }); - - this.Translate(binary.Right); - this._sql.Append(')'); - - static bool IsNull(Expression expression) - => expression is ConstantExpression { Value: null } - || (TryGetCapturedValue(expression, out var capturedValue) && capturedValue is null); - } - - private void TranslateConstant(ConstantExpression constant) - { - // TODO: Nullable - switch (constant.Value) - { - case byte b: - this._sql.Append(b); - return; - case short s: - this._sql.Append(s); - return; - case int i: - this._sql.Append(i); - return; - case long l: - this._sql.Append(l); - return; - - case string s: - this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); - return; - case bool b: - this._sql.Append(b ? 
"TRUE" : "FALSE"); - return; - case Guid g: - this._sql.Append('\'').Append(g.ToString()).Append('\''); - return; - - case DateTime: - case DateTimeOffset: - throw new NotImplementedException(); - - case Array: - throw new NotImplementedException(); - - case null: - this._sql.Append("NULL"); - return; - - default: - throw new NotSupportedException("Unsupported constant type: " + constant.Value.GetType().Name); - } } - private void TranslateMember(MemberExpression memberExpression) - { - switch (memberExpression) - { - case var _ when this.TryGetColumn(memberExpression, out var column): - this._sql.Append('"').Append(column).Append('"'); - return; + private void GenerateLiteral(bool value) + => this._sql.Append(value ? "TRUE" : "FALSE"); - // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) - case var _ when TryGetCapturedValue(memberExpression, out var capturedValue): - // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, - // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) - if (capturedValue is null) - { - this._sql.Append("NULL"); - } - else - { - this._parameterValues.Add(capturedValue); - this._sql.Append('$').Append(this._parameterIndex++); - } - return; + private void GenerateLiteral(DateTime dateTime) + => throw new NotImplementedException(); - default: - throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); - } - } + private void GenerateLiteral(DateTimeOffset dateTimeOffset) + => throw new NotImplementedException(); - private void TranslateMethodCall(MethodCallExpression methodCall) + private void TranslateContainsOverArrayColumn(Expression source, Expression item) { - switch (methodCall) - { - // Enumerable.Contains() - case { Method.Name: nameof(Enumerable.Contains), Arguments: [var 
source, var item] } contains - when contains.Method.DeclaringType == typeof(Enumerable): - this.TranslateContains(source, item); - return; - - // List.Contains() - case - { - Method: - { - Name: nameof(Enumerable.Contains), - DeclaringType: { IsGenericType: true } declaringType - }, - Object: Expression source, - Arguments: [var item] - } when declaringType.GetGenericTypeDefinition() == typeof(List<>): - this.TranslateContains(source, item); - return; - - default: - throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); - } + this.Translate(source); + this._sql.Append(" @> ARRAY["); + this.Translate(item); + this._sql.Append(']'); } - private void TranslateContains(Expression source, Expression item) + private void TranslateContainsOverCapturedArray(Expression source, Expression item, object? _) { - switch (source) - { - // Contains over array column (r => r.Strings.Contains("foo")) - case var _ when this.TryGetColumn(source, out _): - this.Translate(source); - this._sql.Append(" @> ARRAY["); - this.Translate(item); - this._sql.Append(']'); - return; - - // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) - case NewArrayExpression newArray: - this.Translate(item); - this._sql.Append(" IN ("); - - var isFirst = true; - foreach (var element in newArray.Expressions) - { - if (isFirst) - { - isFirst = false; - } - else - { - this._sql.Append(", "); - } - - this.Translate(element); - } - - this._sql.Append(')'); - return; - - // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) - case var _ when TryGetCapturedValue(source, out _): - this.Translate(item); - this._sql.Append(" = ANY ("); - this.Translate(source); - this._sql.Append(')'); - return; - - default: - throw new NotSupportedException("Unsupported Contains expression"); - } - } - - private void TranslateUnary(UnaryExpression unary) - { - switch (unary.NodeType) - { - case ExpressionType.Not: 
- // Special handling for !(a == b) and !(a != b) - if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) - { - this.TranslateBinary( - Expression.MakeBinary( - binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, - binary.Left, - binary.Right)); - return; - } - - this._sql.Append("(NOT "); - this.Translate(unary.Operand); - this._sql.Append(')'); - return; - - default: - throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); - } + this.Translate(item); + this._sql.Append(" = ANY ("); + this.Translate(source); + this._sql.Append(')'); } - private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) + private void TranslateLambdaVariables(string _, object? capturedValue) { - if (expression is MemberExpression member && member.Expression == this._recordParameter) + // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (capturedValue is null) { - if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) - { - throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); - } - - return true; + this._sql.Append("NULL"); } - - column = null; - return false; - } - - private static bool TryGetCapturedValue(Expression expression, out object? 
capturedValue) - { - if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + else { - capturedValue = fieldInfo.GetValue(constant.Value); - return true; + this._parameterValues.Add(capturedValue); + this._sql.Append('$').Append(this._parameterIndex++); } - - capturedValue = null; - return false; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index 364c564703e4..b42ce6add8b6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -383,7 +383,7 @@ public PostgresSqlCommandInfo BuildGetNearestMatchCommand( { (not null, not null) => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), (not null, null) => GenerateLegacyFilterWhereClause(schema, tableName, propertyReader.RecordDefinition.Properties, legacyFilter, startParamIndex: 2), - (null, not null) => new PostgresFilterTranslator().Translate(propertyReader.StoragePropertyNamesMap, newFilter, startParamIndex: 2), + (null, not null) => GenerateNewFilterWhereClause(propertyReader, newFilter), _ => (Clause: string.Empty, Parameters: []) }; #pragma warning restore CS0618 // VectorSearchFilter is obsolete @@ -424,6 +424,15 @@ ORDER BY {PostgresConstants.DistanceColumnName} Parameters = [new NpgsqlParameter { Value = vectorValue }, .. 
parameters.Select(p => new NpgsqlParameter { Value = p })] }; } + + internal static (string Clause, List Parameters) GenerateNewFilterWhereClause(VectorStoreRecordPropertyReader propertyReader, LambdaExpression newFilter) + { + SqlFilterTranslator translator = new(propertyReader.StoragePropertyNamesMap, newFilter); + translator.Initialize(startParamIndex: 2); + translator.Translate(appendWhere: true); + return (translator.Clause.ToString(), translator.ParameterValues); + } + #pragma warning disable CS0618 // VectorSearchFilter is obsolete internal static (string Clause, List Parameters) GenerateLegacyFilterWhereClause(string schema, string tableName, IReadOnlyList properties, VectorSearchFilter legacyFilter, int startParamIndex) { diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj b/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj index 457d1f4a8d93..90c088464efc 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj @@ -18,6 +18,10 @@ SQL Server connector for Semantic Kernel plugins and semantic memory + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 3c38ca2af44d..fd986ce5f086 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -338,10 +338,10 @@ internal static SqlCommand SelectVector( { int startParamIndex = command.Parameters.Count; - List parameters = new SqlServerFilterTranslator(sb, schema).Translate( - storagePropertyNamesMap, - options.NewFilter, - startParamIndex); + SqlFilterTranslator translator = new(storagePropertyNamesMap, options.NewFilter, sb); + 
translator.Initialize(startParamIndex: startParamIndex); + translator.Translate(appendWhere: true); + List parameters = translator.ParameterValues; foreach (object parameter in parameters) { diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs index c6c43a8ca186..379b91317b0a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs @@ -1,269 +1,77 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections; using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text; -namespace Microsoft.SemanticKernel.Connectors.SqlServer; +namespace Microsoft.SemanticKernel.Connectors; -internal class SqlServerFilterTranslator(StringBuilder sql, string schema) +internal partial class SqlFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; - private ParameterExpression _recordParameter = null!; - private readonly List _parameterValues = new(); private int _parameterIndex; - private readonly StringBuilder _sql = sql; - private readonly string _schema = schema; + internal List ParameterValues => this._parameterValues; - internal List Translate( - IReadOnlyDictionary storagePropertyNames, - LambdaExpression lambdaExpression, - int startParamIndex) + internal void Initialize(int startParamIndex) { - this._storagePropertyNames = storagePropertyNames; this._parameterIndex = startParamIndex; - - Debug.Assert(lambdaExpression.Parameters.Count == 1); - this._recordParameter = lambdaExpression.Parameters[0]; - - this._sql.Append("WHERE "); - this.Translate(lambdaExpression.Body); - return this._parameterValues; - } - - 
private void Translate(Expression? node) - { - switch (node) - { - case BinaryExpression binary: - this.TranslateBinary(binary); - return; - - case ConstantExpression constant: - this.TranslateConstant(constant); - return; - - case MemberExpression member: - this.TranslateMember(member); - return; - - case MethodCallExpression methodCall: - this.TranslateMethodCall(methodCall); - return; - - case UnaryExpression unary: - this.TranslateUnary(unary); - return; - - default: - throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); - } - } - - private void TranslateBinary(BinaryExpression binary) - { - // Special handling for null comparisons - switch (binary.NodeType) - { - case ExpressionType.Equal when IsNull(binary.Right): - this._sql.Append('('); - this.Translate(binary.Left); - this._sql.Append(" IS NULL)"); - return; - case ExpressionType.NotEqual when IsNull(binary.Right): - this._sql.Append('('); - this.Translate(binary.Left); - this._sql.Append(" IS NOT NULL)"); - return; - - case ExpressionType.Equal when IsNull(binary.Left): - this._sql.Append('('); - this.Translate(binary.Right); - this._sql.Append(" IS NULL)"); - return; - case ExpressionType.NotEqual when IsNull(binary.Left): - this._sql.Append('('); - this.Translate(binary.Right); - this._sql.Append(" IS NOT NULL)"); - return; - } - - this._sql.Append('('); - this.Translate(binary.Left); - - // SQL Server uses the same values as PostgreSQL. 
- this._sql.Append(binary.NodeType switch - { - ExpressionType.Equal => " = ", - ExpressionType.NotEqual => " <> ", - - ExpressionType.GreaterThan => " > ", - ExpressionType.GreaterThanOrEqual => " >= ", - ExpressionType.LessThan => " < ", - ExpressionType.LessThanOrEqual => " <= ", - - ExpressionType.AndAlso => " AND ", - ExpressionType.OrElse => " OR ", - - _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) - }); - - this.Translate(binary.Right); - this._sql.Append(')'); - - static bool IsNull(Expression expression) - => expression is ConstantExpression { Value: null } - || (TryGetCapturedValue(expression, out var capturedValue) && capturedValue is null); } - private void TranslateConstant(ConstantExpression constant) - { - // TODO: Nullable - switch (constant.Value) - { - case byte b: - this._sql.Append(b); - return; - case short s: - this._sql.Append(s); - return; - case int i: - this._sql.Append(i); - return; - case long l: - this._sql.Append(l); - return; - - case string s: - this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); - return; - case bool b: - // SQL Server uses 1 and 0 to represent booleans. - this._sql.Append(b ? "1" : "0"); - return; - case Guid g: - this._sql.Append('\'').Append(g.ToString()).Append('\''); - return; + private void GenerateLiteral(bool value) + => this._sql.Append(value ? 
"1" : "0"); - case DateTime dateTime: - // SQL Server supports DateTimes in this particular format - this._sql.AppendFormat("'{0:yyyy-MM-dd HH:mm:ss}'", dateTime); - return; + private void GenerateLiteral(DateTime dateTime) + => this._sql.AppendFormat("'{0:yyyy-MM-dd HH:mm:ss}'", dateTime); - case DateTimeOffset dateTimeOffset: - // SQL Server supports DateTimeOffsets in this particular format - this._sql.AppendFormat("'{0:yyy-MM-dd HH:mm:ss zzz}'", dateTimeOffset); - return; + private void GenerateLiteral(DateTimeOffset dateTimeOffset) + => this._sql.AppendFormat("'{0:yyy-MM-dd HH:mm:ss zzz}'", dateTimeOffset); - case Array: - throw new NotImplementedException(); - - case null: - this._sql.Append("NULL"); - return; - - default: - throw new NotSupportedException("Unsupported constant type: " + constant.Value.GetType().Name); - } - } + private void TranslateContainsOverArrayColumn(Expression source, Expression item) + => throw new NotSupportedException("Unsupported Contains expression"); - private void TranslateMember(MemberExpression memberExpression) + private void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value) { - switch (memberExpression) + if (value is not IEnumerable elements) { - case var _ when this.TryGetColumn(memberExpression, out var column): - // SQL Server uses square brackets to escape column names and requires the schema to be provided. - //this._sql.AppendTableName(this._schema, column); - this._sql.Append('"').Append(column).Append('"'); - return; - - // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) 
- case var _ when TryGetCapturedValue(memberExpression, out var capturedValue): - // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, - // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) - if (capturedValue is null) - { - this._sql.Append("NULL"); - } - else - { - this._parameterValues.Add(capturedValue); - // SQL Server paramters can't start with a digit. - this._sql.Append("@_").Append(this._parameterIndex++); - } - return; - - default: - throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + throw new NotSupportedException("Unsupported Contains expression"); } - } - private void TranslateMethodCall(MethodCallExpression methodCall) - { - throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); - } + this.Translate(item); + this._sql.Append(" IN ("); - private void TranslateUnary(UnaryExpression unary) - { - switch (unary.NodeType) + var isFirst = true; + foreach (var element in elements) { - case ExpressionType.Not: - // Special handling for !(a == b) and !(a != b) - if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) - { - this.TranslateBinary( - Expression.MakeBinary( - binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, - binary.Left, - binary.Right)); - return; - } - - this._sql.Append("(NOT "); - this.Translate(unary.Operand); - this._sql.Append(')'); - return; - - default: - throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); - } - } - - private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? 
column) - { - if (expression is MemberExpression member && member.Expression == this._recordParameter) - { - if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + if (isFirst) { - throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + isFirst = false; + } + else + { + this._sql.Append(", "); } - return true; + this.GenerateLiteral(element); } - column = null; - return false; + this._sql.Append(')'); } - private static bool TryGetCapturedValue(Expression expression, out object? capturedValue) + private void TranslateLambdaVariables(string _, object? capturedValue) { - if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (capturedValue is null) { - capturedValue = fieldInfo.GetValue(constant.Value); - return true; + this._sql.Append("NULL"); + } + else + { + this._parameterValues.Add(capturedValue); + // SQL Server paramters can't start with a digit (but underscore is OK). 
+ this._sql.Append("@_").Append(this._parameterIndex++); } - - capturedValue = null; - return false; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj index 2b17f3e0bbe3..fec218bfc49d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj @@ -18,6 +18,10 @@ SQLite connector for Semantic Kernel plugins and semantic memory + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs index 2cb6b16fc8cd..07bc1894fd2e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs @@ -3,357 +3,81 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text; -namespace Microsoft.SemanticKernel.Connectors.Sqlite; +namespace Microsoft.SemanticKernel.Connectors; -internal class SqliteFilterTranslator +internal partial class SqlFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; - private ParameterExpression _recordParameter = null!; - private readonly Dictionary _parameters = new(); - private readonly StringBuilder _sql = new(); - - internal (string Clause, Dictionary) Translate(IReadOnlyDictionary storagePropertyNames, LambdaExpression lambdaExpression) - { - Debug.Assert(this._sql.Length == 0); - - this._storagePropertyNames = storagePropertyNames; - - Debug.Assert(lambdaExpression.Parameters.Count == 1); - this._recordParameter = lambdaExpression.Parameters[0]; - - 
this.Translate(lambdaExpression.Body); - return (this._sql.ToString(), this._parameters); - } - - private void Translate(Expression? node) - { - switch (node) - { - case BinaryExpression binary: - this.TranslateBinary(binary); - return; - - case ConstantExpression constant: - this.TranslateConstant(constant); - return; - - case MemberExpression member: - this.TranslateMember(member); - return; - - case MethodCallExpression methodCall: - this.TranslateMethodCall(methodCall); - return; - - case UnaryExpression unary: - this.TranslateUnary(unary); - return; - - default: - throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); - } - } - - private void TranslateBinary(BinaryExpression binary) - { - // Special handling for null comparisons - switch (binary.NodeType) - { - case ExpressionType.Equal when IsNull(binary.Right): - this._sql.Append('('); - this.Translate(binary.Left); - this._sql.Append(" IS NULL)"); - return; - case ExpressionType.NotEqual when IsNull(binary.Right): - this._sql.Append('('); - this.Translate(binary.Left); - this._sql.Append(" IS NOT NULL)"); - return; - - case ExpressionType.Equal when IsNull(binary.Left): - this._sql.Append('('); - this.Translate(binary.Right); - this._sql.Append(" IS NULL)"); - return; - case ExpressionType.NotEqual when IsNull(binary.Left): - this._sql.Append('('); - this.Translate(binary.Right); - this._sql.Append(" IS NOT NULL)"); - return; - } - - this._sql.Append('('); - this.Translate(binary.Left); - - this._sql.Append(binary.NodeType switch - { - ExpressionType.Equal => " = ", - ExpressionType.NotEqual => " <> ", - - ExpressionType.GreaterThan => " > ", - ExpressionType.GreaterThanOrEqual => " >= ", - ExpressionType.LessThan => " < ", - ExpressionType.LessThanOrEqual => " <= ", - - ExpressionType.AndAlso => " AND ", - ExpressionType.OrElse => " OR ", - - _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) - }); - - 
this.Translate(binary.Right); - this._sql.Append(')'); - - static bool IsNull(Expression expression) - => expression is ConstantExpression { Value: null } - || (TryGetCapturedValue(expression, out _, out var capturedValue) && capturedValue is null); - } - - private void TranslateConstant(ConstantExpression constant) - => this.GenerateLiteral(constant.Value); - - private void GenerateLiteral(object? value) - { - // TODO: Nullable - switch (value) - { - case byte b: - this._sql.Append(b); - return; - case short s: - this._sql.Append(s); - return; - case int i: - this._sql.Append(i); - return; - case long l: - this._sql.Append(l); - return; - - case string s: - this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); - return; - case bool b: - this._sql.Append(b ? "TRUE" : "FALSE"); - return; - case Guid g: - this._sql.Append('\'').Append(g.ToString()).Append('\''); - return; + internal Dictionary Parameters => this._parameters; - case DateTime: - case DateTimeOffset: - throw new NotImplementedException(); + private void GenerateLiteral(bool value) + => this._sql.Append(value ? "TRUE" : "FALSE"); - case Array: - throw new NotImplementedException(); + private void GenerateLiteral(DateTime dateTime) + => throw new NotImplementedException(); - case null: - this._sql.Append("NULL"); - return; + private void GenerateLiteral(DateTimeOffset dateTimeOffset) + => throw new NotImplementedException(); - default: - throw new NotSupportedException("Unsupported constant type: " + value.GetType().Name); - } - } + // TODO: support Contains over array fields (#10343) + private void TranslateContainsOverArrayColumn(Expression source, Expression item) + => throw new NotSupportedException("Unsupported Contains expression"); - private void TranslateMember(MemberExpression memberExpression) + private void TranslateContainsOverCapturedArray(Expression source, Expression item, object? 
value) { - switch (memberExpression) + if (value is not IEnumerable elements) { - case var _ when this.TryGetColumn(memberExpression, out var column): - this._sql.Append('"').Append(column).Append('"'); - return; - - // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) - case var _ when TryGetCapturedValue(memberExpression, out var name, out var value): - // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, - // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) - if (value is null) - { - this._sql.Append("NULL"); - } - else - { - // Duplicate parameter name, create a new parameter with a different name - // TODO: Share the same parameter when it references the same captured value - if (this._parameters.ContainsKey(name)) - { - var baseName = name; - var i = 0; - do - { - name = baseName + (i++); - } while (this._parameters.ContainsKey(name)); - } - - this._parameters.Add(name, value); - this._sql.Append('@').Append(name); - } - return; - - default: - throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + throw new NotSupportedException("Unsupported Contains expression"); } - } - - private void TranslateMethodCall(MethodCallExpression methodCall) - { - switch (methodCall) - { - // Enumerable.Contains() - case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains - when contains.Method.DeclaringType == typeof(Enumerable): - this.TranslateContains(source, item); - return; - - // List.Contains() - case - { - Method: - { - Name: nameof(Enumerable.Contains), - DeclaringType: { IsGenericType: true } declaringType - }, - Object: Expression source, - Arguments: [var item] - } when declaringType.GetGenericTypeDefinition() == typeof(List<>): - this.TranslateContains(source, 
item); - return; - default: - throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); - } - } + this.Translate(item); + this._sql.Append(" IN ("); - private void TranslateContains(Expression source, Expression item) - { - switch (source) + var isFirst = true; + foreach (var element in elements) { - // TODO: support Contains over array fields (#10343) - // Contains over array column (r => r.Strings.Contains("foo")) - case var _ when this.TryGetColumn(source, out _): - goto default; - - // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) - case NewArrayExpression newArray: + if (isFirst) { - this.Translate(item); - this._sql.Append(" IN ("); - - var isFirst = true; - foreach (var element in newArray.Expressions) - { - if (isFirst) - { - isFirst = false; - } - else - { - this._sql.Append(", "); - } - - this.Translate(element); - } - - this._sql.Append(')'); - return; + isFirst = false; } - - // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) - case var _ when TryGetCapturedValue(source, out _, out var value) && value is IEnumerable elements: + else { - this.Translate(item); - this._sql.Append(" IN ("); - - var isFirst = true; - foreach (var element in elements) - { - if (isFirst) - { - isFirst = false; - } - else - { - this._sql.Append(", "); - } - - this.GenerateLiteral(element); - } - - this._sql.Append(')'); - return; + this._sql.Append(", "); } - default: - throw new NotSupportedException("Unsupported Contains expression"); + this.GenerateLiteral(element); } + + this._sql.Append(')'); } - private void TranslateUnary(UnaryExpression unary) + private void TranslateLambdaVariables(string name, object? 
value) { - switch (unary.NodeType) + // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (value is null) { - case ExpressionType.Not: - // Special handling for !(a == b) and !(a != b) - if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) - { - this.TranslateBinary( - Expression.MakeBinary( - binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, - binary.Left, - binary.Right)); - return; - } - - this._sql.Append("(NOT "); - this.Translate(unary.Operand); - this._sql.Append(')'); - return; - - default: - throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + this._sql.Append("NULL"); } - } - - private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) - { - if (expression is MemberExpression member && member.Expression == this._recordParameter) + else { - if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + // Duplicate parameter name, create a new parameter with a different name + // TODO: Share the same parameter when it references the same captured value + if (this._parameters.ContainsKey(name)) { - throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + var baseName = name; + var i = 0; + do + { + name = baseName + (i++); + } while (this._parameters.ContainsKey(name)); } - return true; - } - - column = null; - return false; - } - - private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? 
value) - { - if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) - { - name = fieldInfo.Name; - value = fieldInfo.GetValue(constant.Value); - return true; + this._parameters.Add(name, value); + this._sql.Append('@').Append(name); } - - name = null; - value = null; - return false; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs index e3c2431491c3..8383dd4eceaa 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs @@ -210,7 +210,10 @@ public virtual Task> VectorizedSearchAsync } else if (searchOptions.NewFilter is not null) { - (extraWhereFilter, extraParameters) = new SqliteFilterTranslator().Translate(this._propertyReader.StoragePropertyNamesMap, searchOptions.NewFilter); + SqlFilterTranslator translator = new(this._propertyReader.StoragePropertyNamesMap, searchOptions.NewFilter); + translator.Translate(appendWhere: false); + extraWhereFilter = translator.Clause.ToString(); + extraParameters = translator.Parameters; } #pragma warning restore CS0618 // VectorSearchFilter is obsolete From c17021e72185376ccebb17b6c845000cdd211f40 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Thu, 20 Feb 2025 15:52:43 +0100 Subject: [PATCH 19/32] skip some tests, some polishing The ToListAsync change is just a workaround https://github.com/dotnet/runtime/issues/79782#issuecomment-2625927473 --- .../SqlServerCommandBuilder.cs | 1 + .../Filter/SqlServerBasicFilterTests.cs | 50 +++++++++++++++++++ .../SqlServerMemoryStoreTests.cs | 12 ++--- 3 files changed, 57 insertions(+), 6 deletions(-) diff --git 
a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index fd986ce5f086..4307a064a2c8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -347,6 +347,7 @@ internal static SqlCommand SelectVector( { command.AddParameter(vectorProperty, $"@_{startParamIndex++}", parameter); } + sb.AppendLine(); } sb.AppendLine("ORDER BY [score] DESC"); // Negative Skip and Top values are rejected by the VectorSearchOptions property setters. diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs index e5d63757437e..cbc4f3b2977b 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs @@ -4,12 +4,62 @@ using VectorDataSpecificationTests.Filter; using VectorDataSpecificationTests.Support; using Xunit; +using Xunit.Sdk; namespace SqlServerIntegrationTests.Filter; public class SqlServerBasicFilterTests(SqlServerBasicFilterTests.Fixture fixture) : BasicFilterTests(fixture), IClassFixture { + // SQL Server doesn't support the null semantics that the default implementation relies on + // "SELECT * FROM MyTable WHERE BooleanColumn = 1;" is fine + // "SELECT * FROM MyTable WHERE BooleanColumn;" is not + // TODO adsitnik: get it to work anyway + public override Task Bool() => this.TestFilterAsync(r => r.Bool == true); + + // Same as above, "WHERE NOT BooleanColumn" is not supported + public override Task Not_over_bool() => this.TestFilterAsync(r => r.Bool == false); + + public override async Task Not_over_Or() + { + // Test sends: WHERE (NOT (("Int" = 8) 
OR ("String" = 'foo'))) + // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. + await Assert.ThrowsAsync(() => base.Not_over_Or()); + + // Compensate by adding a null check: + await this.TestFilterAsync(r => r.String != null && !(r.Int == 8 || r.String == "foo")); + } + + public override async Task NotEqual_with_string() + { + // As above, null semantics + negation + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + + await this.TestFilterAsync(r => r.String != null && r.String != "foo"); + } + + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + [Fact(Skip = "Not supported")] + [Obsolete("Legacy filters are not supported")] + public override Task Legacy_And() => throw new NotSupportedException(); + + [Fact(Skip = "Not supported")] + [Obsolete("Legacy filters are not supported")] + public override Task Legacy_AnyTagEqualTo_array() => throw new NotSupportedException(); + + [Fact(Skip = "Not supported")] + [Obsolete("Legacy filters are not supported")] + public override Task Legacy_AnyTagEqualTo_List() => throw new NotSupportedException(); + + [Fact(Skip = "Not supported")] + [Obsolete("Legacy filters are not supported")] + public override Task Legacy_equality() => throw new NotSupportedException(); + public new class Fixture : BasicFilterTests.Fixture { public override TestStore TestStore => SqlServerTestStore.Instance; diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs index 32c0f6742546..2c419c5f173d 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs +++ 
b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs @@ -79,7 +79,7 @@ public async Task GetCollectionsAsync() await this.Store.CreateCollectionAsync("collection1"); await this.Store.CreateCollectionAsync("collection2"); - var collections = await this.Store.GetCollectionsAsync().ToListAsync(); + var collections = await this.Store.GetCollectionsAsync().ToArrayAsync(); Assert.Contains("collection1", collections); Assert.Contains("collection2", collections); } @@ -212,8 +212,8 @@ public async Task GetNearestMatchesAsync(bool withEmbeddings) await this.Store.CreateCollectionAsync(DefaultCollectionName); await this.InsertSampleDataAsync(); - List<(MemoryRecord Record, double SimilarityScore)> results = - await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2, withEmbeddings: withEmbeddings).ToListAsync(); + (MemoryRecord Record, double SimilarityScore)[] results = + await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2, withEmbeddings: withEmbeddings).ToArrayAsync(); Assert.All(results, t => Assert.True(t.SimilarityScore > 0)); @@ -254,13 +254,13 @@ public async Task GetNearestMatchesWithMinRelevanceScoreAsync() await this.Store.CreateCollectionAsync(DefaultCollectionName); await this.InsertSampleDataAsync(); - List<(MemoryRecord Record, double SimilarityScore)> results = - await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2).ToListAsync(); + (MemoryRecord Record, double SimilarityScore)[] results = + await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2).ToArrayAsync(); var firstId = results[0].Record.Metadata.Id; var firstSimilarityScore = results[0].SimilarityScore; - results = await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2, minRelevanceScore: firstSimilarityScore + 
0.0001).ToListAsync(); + results = await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2, minRelevanceScore: firstSimilarityScore + 0.0001).ToArrayAsync(); Assert.DoesNotContain(firstId, results.Select(r => r.Record.Metadata.Id)); } From 4669e91648afaf6de24cfbc53f207967543eba30 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Thu, 20 Feb 2025 16:54:00 +0100 Subject: [PATCH 20/32] remove a comment added by Copilot --- .../Filter/SqlServerBasicFilterTests.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs index cbc4f3b2977b..ff9fb0376a64 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs @@ -11,7 +11,6 @@ namespace SqlServerIntegrationTests.Filter; public class SqlServerBasicFilterTests(SqlServerBasicFilterTests.Fixture fixture) : BasicFilterTests(fixture), IClassFixture { - // SQL Server doesn't support the null semantics that the default implementation relies on // "SELECT * FROM MyTable WHERE BooleanColumn = 1;" is fine // "SELECT * FROM MyTable WHERE BooleanColumn;" is not // TODO adsitnik: get it to work anyway From ba0486fd044b1d303bbf0598ade676fdb1e1d510 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Mon, 24 Feb 2025 16:35:08 +0100 Subject: [PATCH 21/32] Update dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs Co-authored-by: westey <164392973+westey-m@users.noreply.github.com> --- .../RecordAttributes/VectorStoreRecordKeyAttribute.cs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs 
b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs index 4bad37f8ccb7..7577547f50ed 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs @@ -22,5 +22,13 @@ public sealed class VectorStoreRecordKeyAttribute : Attribute /// /// Gets or sets whether the key should be auto-generated by the vector store. /// + /// + /// The default is . + /// + /// + /// If set to , you must set the key property on any record that you pass to . + /// If set to , the key property may be left null on any record that you pass to + /// and a generated key will be returned. + /// public bool AutoGenerate { get; set; } } From 5bdaa8e80a89f056a26ad97cc6dc2e6e62835376 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Mon, 24 Feb 2025 15:16:39 +0100 Subject: [PATCH 22/32] address code review feedback: - use abstract instead of partial class - make Dimensions mandatory - remove support of string collections - map "dot" to Negative Dot - make SqlServerVectorStoreRecordCollection public - make AutoGenerate init rather than readonly property - add net472 target for the test project, make it compile - use unique collection names so multiple users can run the same tests against shared db instance - verify that parameterless ctor is present for TRecord - make sure that exceptions are wrapped --- .../SqlFilterTranslator.cs | 22 +- .../PostgresFilterTranslator.cs | 28 +- ...PostgresVectorStoreCollectionSqlBuilder.cs | 3 +- .../ExceptionWrapper.cs | 67 +++++ .../SqlServerCommandBuilder.cs | 8 +- .../SqlServerConstants.cs | 44 +++ .../SqlServerFilterTranslator.cs | 28 +- .../SqlServerVectorStore.cs | 97 +------ .../SqlServerVectorStoreOptions.cs | 9 +- .../SqlServerVectorStoreRecordCollection.cs | 262 +++++++++++------- .../SqliteFilterTranslator.cs | 27 +- .../SqliteVectorStoreRecordCollection.cs | 2 +- 
.../VectorStoreRecordKeyAttribute.cs | 4 +- .../VectorStoreRecordKeyProperty.cs | 6 +- .../Data/VectorStoreRecordPropertyReader.cs | 5 +- .../Filter/SqlServerBasicFilterTests.cs | 19 +- .../SqlServerCommandBuilderTests.cs | 29 +- .../SqlServerIntegrationTests.csproj | 3 +- .../SqlServerMemoryStoreTests.cs | 29 +- .../SqlServerVectorStoreTests.cs | 124 +++++++-- 20 files changed, 506 insertions(+), 310 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs index ac708302b879..c4b9f146201c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs @@ -12,12 +12,12 @@ namespace Microsoft.SemanticKernel.Connectors; -internal partial class SqlFilterTranslator +internal abstract class SqlFilterTranslator { private readonly IReadOnlyDictionary _storagePropertyNames; private readonly LambdaExpression _lambdaExpression; private readonly ParameterExpression _recordParameter; - private readonly StringBuilder _sql; + protected readonly StringBuilder _sql; internal SqlFilterTranslator( IReadOnlyDictionary storagePropertyNames, @@ -43,7 +43,7 @@ internal void Translate(bool appendWhere) this.Translate(this._lambdaExpression.Body); } - private void Translate(Expression? node) + protected void Translate(Expression? node) { switch (node) { @@ -130,7 +130,7 @@ static bool IsNull(Expression expression) private void TranslateConstant(ConstantExpression constant) => this.GenerateLiteral(constant.Value); - private void GenerateLiteral(object? value) + protected void GenerateLiteral(object? value) { // TODO: Nullable switch (value) @@ -178,6 +178,14 @@ private void GenerateLiteral(object? 
value) } } + protected abstract void GenerateLiteral(bool value); + + protected virtual void GenerateLiteral(DateTime dateTime) + => throw new NotImplementedException(); + + protected virtual void GenerateLiteral(DateTimeOffset dateTimeOffset) + => throw new NotImplementedException(); + private void TranslateMember(MemberExpression memberExpression) { switch (memberExpression) @@ -196,6 +204,8 @@ private void TranslateMember(MemberExpression memberExpression) } } + protected abstract void TranslateLambdaVariables(string name, object? capturedValue); + private void TranslateMethodCall(MethodCallExpression methodCall) { switch (methodCall) @@ -267,6 +277,10 @@ private void TranslateContains(Expression source, Expression item) } } + protected abstract void TranslateContainsOverArrayColumn(Expression source, Expression item); + + protected abstract void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value); + private void TranslateUnary(UnaryExpression unary) { switch (unary.NodeType) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs index a181ebc3ecbb..1780e176cbc4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs @@ -1,33 +1,29 @@ // Copyright (c) Microsoft. All rights reserved. 
-using System; using System.Collections.Generic; using System.Linq.Expressions; -namespace Microsoft.SemanticKernel.Connectors; +namespace Microsoft.SemanticKernel.Connectors.Postgres; -internal partial class SqlFilterTranslator +internal sealed class PostgresFilterTranslator : SqlFilterTranslator { private readonly List _parameterValues = new(); private int _parameterIndex; - internal List ParameterValues => this._parameterValues; - - internal void Initialize(int startParamIndex) + internal PostgresFilterTranslator( + IReadOnlyDictionary storagePropertyNames, + LambdaExpression lambdaExpression, + int startParamIndex) : base(storagePropertyNames, lambdaExpression, sql: null) { this._parameterIndex = startParamIndex; } - private void GenerateLiteral(bool value) - => this._sql.Append(value ? "TRUE" : "FALSE"); - - private void GenerateLiteral(DateTime dateTime) - => throw new NotImplementedException(); + internal List ParameterValues => this._parameterValues; - private void GenerateLiteral(DateTimeOffset dateTimeOffset) - => throw new NotImplementedException(); + protected override void GenerateLiteral(bool value) + => this._sql.Append(value ? "TRUE" : "FALSE"); - private void TranslateContainsOverArrayColumn(Expression source, Expression item) + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item) { this.Translate(source); this._sql.Append(" @> ARRAY["); @@ -35,7 +31,7 @@ private void TranslateContainsOverArrayColumn(Expression source, Expression item this._sql.Append(']'); } - private void TranslateContainsOverCapturedArray(Expression source, Expression item, object? _) + protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value) { this.Translate(item); this._sql.Append(" = ANY ("); @@ -43,7 +39,7 @@ private void TranslateContainsOverCapturedArray(Expression source, Expression it this._sql.Append(')'); } - private void TranslateLambdaVariables(string _, object? 
capturedValue) + protected override void TranslateLambdaVariables(string name, object? capturedValue) { // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index b42ce6add8b6..71d1448c85cc 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -427,8 +427,7 @@ ORDER BY {PostgresConstants.DistanceColumnName} internal static (string Clause, List Parameters) GenerateNewFilterWhereClause(VectorStoreRecordPropertyReader propertyReader, LambdaExpression newFilter) { - SqlFilterTranslator translator = new(propertyReader.StoragePropertyNamesMap, newFilter); - translator.Initialize(startParamIndex: 2); + PostgresFilterTranslator translator = new(propertyReader.StoragePropertyNamesMap, newFilter, startParamIndex: 2); translator.Translate(appendWhere: true); return (translator.Clause.ToString(), translator.ParameterValues); } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs new file mode 100644 index 000000000000..4a5fa37c0829 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal static class ExceptionWrapper +{ + private const string VectorStoreType = "SqlServer"; + + internal static async Task WrapAsync( + SqlConnection connection, + SqlCommand command, + Func> func, + CancellationToken cancellationToken, + string operationName, + string? collectionName = null) + { + if (connection.State != System.Data.ConnectionState.Open) + { + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + } + + try + { + return await func(command, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException(ex.Message, ex) + { + OperationName = operationName, + VectorStoreType = VectorStoreType, + CollectionName = collectionName + }; + } + } + + internal static async Task WrapReadAsync( + SqlDataReader reader, + CancellationToken cancellationToken, + string operationName, + string? 
collectionName = null) + { + try + { + return await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException(ex.Message, ex) + { + OperationName = operationName, + VectorStoreType = VectorStoreType, + CollectionName = collectionName + }; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 4307a064a2c8..17e180f10a6f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -317,7 +317,7 @@ internal static SqlCommand SelectVector( { DistanceFunction.CosineDistance => "cosine", DistanceFunction.EuclideanDistance => "euclidean", - DistanceFunction.DotProductSimilarity => "dot", + DistanceFunction.NegativeDotProductSimilarity => "dot", _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") }; @@ -338,8 +338,7 @@ internal static SqlCommand SelectVector( { int startParamIndex = command.Parameters.Count; - SqlFilterTranslator translator = new(storagePropertyNamesMap, options.NewFilter, sb); - translator.Initialize(startParamIndex: startParamIndex); + SqlServerFilterTranslator translator = new(storagePropertyNamesMap, options.NewFilter, sb, startParamIndex: startParamIndex); translator.Translate(appendWhere: true); List parameters = translator.ParameterValues; @@ -505,9 +504,6 @@ private static (string sqlName, string? 
autoGenerate) Map(Type type) Type t when t == typeof(decimal) => ("DECIMAL", null), Type t when t == typeof(double) => ("FLOAT", null), Type t when t == typeof(float) => ("REAL", null), - // Collections don't have good native support, we store them as JSON - Type t when t == typeof(string[]) => (NVARCHAR, null), - Type t when t == typeof(List) => (NVARCHAR, null), _ => throw new NotSupportedException($"Type {type} is not supported.") }; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs new file mode 100644 index 000000000000..b8842144258c --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal static class SqlServerConstants +{ + internal const string Schema = "dbo"; + + internal static readonly HashSet SupportedKeyTypes = + [ + typeof(int), // INT + typeof(long), // BIGINT + typeof(string), // VARCHAR + typeof(Guid), // UNIQUEIDENTIFIER + typeof(DateTime), // DATETIME + typeof(byte[]) // VARBINARY + ]; + + internal static readonly HashSet SupportedDataTypes = + [ + typeof(int), // INT + typeof(short), // SMALLINT + typeof(byte), // TINYINT + typeof(long), // BIGINT. + typeof(Guid), // UNIQUEIDENTIFIER. + typeof(string), // NVARCHAR + typeof(byte[]), //VARBINARY + typeof(bool), // BIT + typeof(DateTime), // DATETIME + typeof(TimeSpan), // TIME + typeof(decimal), // DECIMAL + typeof(double), // FLOAT + typeof(float), // REAL + ]; + + internal static readonly HashSet SupportedVectorTypes = + [ + typeof(ReadOnlyMemory), // VECTOR + typeof(ReadOnlyMemory?) 
+ ]; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs index 379b91317b0a..d46e1147d963 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs @@ -4,34 +4,40 @@ using System.Collections; using System.Collections.Generic; using System.Linq.Expressions; +using System.Text; -namespace Microsoft.SemanticKernel.Connectors; +namespace Microsoft.SemanticKernel.Connectors.SqlServer; -internal partial class SqlFilterTranslator +internal sealed class SqlServerFilterTranslator : SqlFilterTranslator { private readonly List _parameterValues = new(); private int _parameterIndex; - internal List ParameterValues => this._parameterValues; - - internal void Initialize(int startParamIndex) + internal SqlServerFilterTranslator( + IReadOnlyDictionary storagePropertyNames, + LambdaExpression lambdaExpression, + StringBuilder sql, + int startParamIndex) + : base(storagePropertyNames, lambdaExpression, sql) { this._parameterIndex = startParamIndex; } - private void GenerateLiteral(bool value) + internal List ParameterValues => this._parameterValues; + + protected override void GenerateLiteral(bool value) => this._sql.Append(value ? 
"1" : "0"); - private void GenerateLiteral(DateTime dateTime) + protected override void GenerateLiteral(DateTime dateTime) => this._sql.AppendFormat("'{0:yyyy-MM-dd HH:mm:ss}'", dateTime); - private void GenerateLiteral(DateTimeOffset dateTimeOffset) + protected override void GenerateLiteral(DateTimeOffset dateTimeOffset) => this._sql.AppendFormat("'{0:yyy-MM-dd HH:mm:ss zzz}'", dateTimeOffset); - private void TranslateContainsOverArrayColumn(Expression source, Expression item) + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item) => throw new NotSupportedException("Unsupported Contains expression"); - private void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value) + protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value) { if (value is not IEnumerable elements) { @@ -59,7 +65,7 @@ private void TranslateContainsOverCapturedArray(Expression source, Expression it this._sql.Append(')'); } - private void TranslateLambdaVariables(string _, object? capturedValue) + protected override void TranslateLambdaVariables(string name, object? capturedValue) { // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs index b864563fc8e2..45e53faf69b7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. 
using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; @@ -13,48 +12,8 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; /// /// An implementation of backed by a SQL Server or Azure SQL database. /// -// TODO adsitnik: design: the interface is not generic, so I am not sure how the users can customize the -// mapping between the record and the table. Am I missing something? -// The interface I am talking about: public IVectorStoreRecordMapper>. public sealed class SqlServerVectorStore : IVectorStore, IDisposable { - private static readonly ConcurrentDictionary s_propertyReaders = new(); - - private static readonly HashSet s_supportedKeyTypes = - [ - typeof(int), // INT - typeof(long), // BIGINT - typeof(string), // VARCHAR - typeof(Guid), // UNIQUEIDENTIFIER - typeof(DateTime), // DATETIME - typeof(byte[]) // VARBINARY - ]; - - private static readonly HashSet s_supportedDataTypes = - [ - typeof(int), // INT - typeof(short), // SMALLINT - typeof(byte), // TINYINT - typeof(long), // BIGINT. - typeof(Guid), // UNIQUEIDENTIFIER. - typeof(string), // NVARCHAR - typeof(byte[]), //VARBINARY - typeof(bool), // BIT - typeof(DateTime), // DATETIME - typeof(TimeSpan), // TIME - typeof(decimal), // DECIMAL - typeof(double), // FLOAT - typeof(float), // REAL, - typeof(string[]), // NVARCHAR accessed as JSON - typeof(List) // NVARCHAR accessed as JSON - ]; - - internal static readonly HashSet s_supportedVectorTypes = - [ - typeof(ReadOnlyMemory), // VECTOR - typeof(ReadOnlyMemory?) - ]; - private readonly SqlConnection _connection; private readonly SqlServerVectorStoreOptions _options; @@ -73,71 +32,35 @@ public SqlServerVectorStore(SqlConnection connection, SqlServerVectorStoreOption // We need to create a copy, so any changes made to the option bag after // the ctor call do not affect this instance. this._options = options is not null - ? 
new() { Schema = options.Schema, EmbeddingDimensionsCount = options.EmbeddingDimensionsCount } + ? new() { Schema = options.Schema } : SqlServerVectorStoreOptions.Defaults; } /// public void Dispose() => this._connection.Dispose(); - // TODO: adsitnik: design - // I find the creation process not intuitive: the IVectorStoreRecordCollection.Create - // method does take only table name as an arugment, the metadata needs to be provided - // a step before that by passing the VectorStoreRecordDefinition to the GetCollection method. - // I would expect VectorStoreRecordDefinition to be argument of the CreateCollectionAsync. - // Also, please consider another problem: - // On Monday, I pass two arguments to GetCollection: - // a name: "theName" - // and a definition: "theDefinition" that consists of two properties - // When I call CreateCollectionAsync, it gets created. - // On Tuesday, I pass the same name, but a different definition: three properties. - // Now CollectionExistsAsync returns true, despite the properties mismatch?! /// public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull { Verify.NotNull(name); - if (!s_propertyReaders.TryGetValue(typeof(TRecord), out VectorStoreRecordPropertyReader? propertyReader)) - { - propertyReader = new(typeof(TRecord), - // TODO adsitnik: should we cache the property reader when user has provided the VectorStoreRecordDefinition? - vectorStoreRecordDefinition, - new() - { - RequiresAtLeastOneVector = false, - // TODO adsitnik: design: can TKey represent a composite key (PRIMARY KEY)? 
- SupportsMultipleKeys = false, - SupportsMultipleVectors = true, - }); - - propertyReader.VerifyKeyProperties(s_supportedKeyTypes); - propertyReader.VerifyDataProperties(s_supportedDataTypes, supportEnumerable: false); - propertyReader.VerifyVectorProperties(s_supportedVectorTypes); - - if (propertyReader.KeyProperty.AutoGenerate - && !(typeof(TKey) == typeof(int) || typeof(TKey) == typeof(long) || typeof(TKey) == typeof(Guid))) - { - // SQL Server does not support auto-generated keys for types other than int, long, and Guid. - throw new ArgumentException("Key property cannot be auto-generated."); - } - - // Add to the cache once we have verified the record definition. - s_propertyReaders.TryAdd(typeof(TRecord), propertyReader); - } - return new SqlServerVectorStoreRecordCollection( this._connection, name, - this._options, - propertyReader); + vectorStoreRecordDefinition, + this._options); } /// public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { - using SqlCommand cmd = SqlServerCommandBuilder.SelectTableNames(this._connection, this._options.Schema); - using SqlDataReader reader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + using SqlCommand command = SqlServerCommandBuilder.SelectTableNames(this._connection, this._options.Schema); + + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._connection, command, + static (cmd, ct) => cmd.ExecuteReaderAsync(ct), + cancellationToken, "ListCollection").ConfigureAwait(false); + + while (await ExceptionWrapper.WrapReadAsync(reader, cancellationToken, "ListCollection").ConfigureAwait(false)) { yield return reader.GetString(reader.GetOrdinal("table_name")); } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs 
index 653ec9fb1c18..a1d809c0face 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs @@ -12,12 +12,5 @@ public sealed class SqlServerVectorStoreOptions /// /// Gets or sets the database schema. /// - public string Schema { get; init; } = "dbo"; - - /// - /// Number of dimensions that stored embeddings will use. - /// - // TODO: adsitnik: this design most likely won't need this setting, - // as it up to the TRecrod to define the dimensions. - public int EmbeddingDimensionsCount { get; init; } = 1536; + public string Schema { get; init; } = SqlServerConstants.Schema; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 73f043c23d9d..19eb9058e96b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading; @@ -12,7 +13,10 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; -internal sealed class SqlServerVectorStoreRecordCollection : IVectorStoreRecordCollection +/// +/// An implementation of backed by a SQL Server or Azure SQL database. 
+/// +public sealed class SqlServerVectorStoreRecordCollection : IVectorStoreRecordCollection where TKey : notnull { private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); @@ -21,41 +25,87 @@ internal sealed class SqlServerVectorStoreRecordCollection : IVec private readonly SqlServerVectorStoreOptions _options; private readonly VectorStoreRecordPropertyReader _propertyReader; - internal SqlServerVectorStoreRecordCollection(SqlConnection sqlConnection, string name, SqlServerVectorStoreOptions options, VectorStoreRecordPropertyReader propertyReader) + /// + /// Initializes a new instance of the class. + /// + /// Database connection. + /// The name of the collection. + /// Optional record definition. + /// Optional configuration options. + public SqlServerVectorStoreRecordCollection( + SqlConnection connection, + string name, + VectorStoreRecordDefinition? vectorStoreRecordDefinition = null, + SqlServerVectorStoreOptions? vectorStoreOptions = null) { - this._sqlConnection = sqlConnection; + Verify.NotNull(connection); + Verify.NotNull(name); + + VectorStoreRecordPropertyReader propertyReader = new(typeof(TRecord), + vectorStoreRecordDefinition, + new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + }); + + propertyReader.VerifyHasParameterlessConstructor(); + propertyReader.VerifyKeyProperties(SqlServerConstants.SupportedKeyTypes); + propertyReader.VerifyDataProperties(SqlServerConstants.SupportedDataTypes, supportEnumerable: false); + propertyReader.VerifyVectorProperties(SqlServerConstants.SupportedVectorTypes); + + if (propertyReader.KeyProperty.AutoGenerate + && !(typeof(TKey) == typeof(int) || typeof(TKey) == typeof(long) || typeof(TKey) == typeof(Guid))) + { + // SQL Server does not support auto-generated keys for types other than int, long, and Guid. 
+ throw new ArgumentException("Key property cannot be auto-generated."); + } + + this._sqlConnection = connection; this.CollectionName = name; - this._options = options; + // We need to create a copy, so any changes made to the option bag after + // the ctor call do not affect this instance. + this._options = vectorStoreOptions is not null + ? new() { Schema = vectorStoreOptions.Schema } + : SqlServerVectorStoreOptions.Defaults; this._propertyReader = propertyReader; } + /// public string CollectionName { get; } + /// public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) { - await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.SelectTableName( this._sqlConnection, this._options.Schema, this.CollectionName); - using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - return await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + return await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static async (cmd, ct) => + { + using SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); + return await reader.ReadAsync(ct).ConfigureAwait(false); + }, cancellationToken, "CollectionExists", this.CollectionName).ConfigureAwait(false); } + /// public Task CreateCollectionAsync(CancellationToken cancellationToken = default) => this.CreateCollectionAsync(ifNotExists: false, cancellationToken); - // TODO adsitnik: design: We typically don't provide such methods in BCL. - // 1. I totally see why we want to provide it, we just need to make sure it's the right thing to do. - // 2. An alternative would be to make CreateCollectionAsync a nop when the collection already exists - // or extend it with an optional boolean parameter that would control the behavior. - // 3. We may need it to avoid TOCTOU issues. 
+ /// public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) => this.CreateCollectionAsync(ifNotExists: true, cancellationToken); private async Task CreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken) { - await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); + foreach (var vectorProperty in this._propertyReader.VectorProperties) + { + if (vectorProperty.Dimensions is not > 0) + { + throw new InvalidOperationException($"Property {nameof(vectorProperty.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' must be set to a positive integer to create a collection."); + } + } using SqlCommand command = SqlServerCommandBuilder.CreateTable( this._sqlConnection, @@ -66,25 +116,27 @@ private async Task CreateCollectionAsync(bool ifNotExists, CancellationToken can this._propertyReader.DataProperties, this._propertyReader.VectorProperties); - await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), + cancellationToken, "CreateCollection", this.CollectionName).ConfigureAwait(false); } + /// public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { - await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.DropTableIfExists( this._sqlConnection, this._options.Schema, this.CollectionName); - await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), + cancellationToken, "DeleteCollection", this.CollectionName).ConfigureAwait(false); } + /// public async Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { Verify.NotNull(key); - await 
this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.DeleteSingle( this._sqlConnection, this._options.Schema, @@ -92,15 +144,16 @@ public async Task DeleteAsync(TKey key, CancellationToken cancellationToken = de this._propertyReader.KeyProperty, key); - await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), + cancellationToken, "Delete", this.CollectionName).ConfigureAwait(false); } + /// public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) { Verify.NotNull(keys); - await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.DeleteMany( this._sqlConnection, this._options.Schema, @@ -108,15 +161,16 @@ public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken can this._propertyReader.KeyProperty, keys); - await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), + cancellationToken, "DeleteBatch", this.CollectionName).ConfigureAwait(false); } + /// public async Task GetAsync(TKey key, GetRecordOptions? 
options = null, CancellationToken cancellationToken = default) { Verify.NotNull(key); - await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.SelectSingle( this._sqlConnection, this._options.Schema, @@ -125,20 +179,23 @@ public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken can this._propertyReader.Properties, key); - using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static async (cmd, ct) => + { + SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); + await reader.ReadAsync(ct).ConfigureAwait(false); + return reader; + }, cancellationToken, "Get", this.CollectionName).ConfigureAwait(false); - return await reader.ReadAsync(cancellationToken).ConfigureAwait(false) - ? Map(reader, this._propertyReader) - : default; + return reader.HasRows ? Map(reader, this._propertyReader) : default; } + /// public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? 
options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Verify.NotNull(keys); - await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.SelectMany( this._sqlConnection, this._options.Schema, @@ -147,19 +204,21 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get this._propertyReader.Properties, keys); - using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteReaderAsync(ct), + cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false); + + while (await ExceptionWrapper.WrapReadAsync(reader, cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false)) { yield return Map(reader, this._propertyReader); } } + /// public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { Verify.NotNull(record); - await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle( this._sqlConnection, this._options, @@ -168,18 +227,21 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancellati this._propertyReader.Properties, Map(record, this._propertyReader)); - using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - await reader.ReadAsync(cancellationToken).ConfigureAwait(false); - return reader.GetFieldValue(0); + return await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + async static (cmd, ct) => + { + using SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); + await reader.ReadAsync(ct).ConfigureAwait(false); + return reader.GetFieldValue(0); + }, 
cancellationToken, "Upsert", this.CollectionName).ConfigureAwait(false); } + /// public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Verify.NotNull(records); - await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.MergeIntoMany( this._sqlConnection, this._options, @@ -188,14 +250,18 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record this._propertyReader.Properties, records.Select(record => Map(record, this._propertyReader))); - using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteReaderAsync(ct), + cancellationToken, "UpsertBatch", this.CollectionName).ConfigureAwait(false); + + while (await ExceptionWrapper.WrapReadAsync(reader, cancellationToken, "UpsertBatch", this.CollectionName).ConfigureAwait(false)) { yield return reader.GetFieldValue(0); } } - public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + /// + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -203,7 +269,7 @@ public Task> VectorizedSearchAsync(TVector { throw new NotSupportedException( $"The provided vector type {vector.GetType().FullName} is not supported by the SQL Server connector. 
" + - $"Supported types are: {string.Join(", ", SqlServerVectorStore.s_supportedVectorTypes.Select(l => l.FullName))}"); + $"Supported types are: {string.Join(", ", SqlServerConstants.SupportedVectorTypes.Select(l => l.FullName))}"); } #pragma warning disable CS0618 // Type or member is obsolete else if (options is not null && options.Filter is not null) @@ -215,18 +281,6 @@ public Task> VectorizedSearchAsync(TVector var searchOptions = options ?? s_defaultVectorSearchOptions; var vectorProperty = this._propertyReader.GetVectorPropertyForSearch(searchOptions.VectorPropertyName); - var results = this.ReadVectorSearchResultsAsync(allowed, vectorProperty, searchOptions, cancellationToken); - return Task.FromResult(new VectorSearchResults(results)); - } - - private async IAsyncEnumerable> ReadVectorSearchResultsAsync( - ReadOnlyMemory vector, - VectorStoreRecordVectorProperty vectorProperty, - VectorSearchOptions searchOptions, - [EnumeratorCancellation] CancellationToken cancellationToken) - { - await this.EnsureConnectionIsOpenedAsync(cancellationToken).ConfigureAwait(false); - using SqlCommand command = SqlServerCommandBuilder.SelectVector( this._sqlConnection, this._options.Schema, @@ -235,8 +289,20 @@ private async IAsyncEnumerable> ReadVectorSearchResu this._propertyReader.Properties, this._propertyReader.StoragePropertyNamesMap, searchOptions, - vector); + allowed); + return await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + (cmd, ct) => + { + var results = this.ReadVectorSearchResultsAsync(cmd, ct); + return Task.FromResult(new VectorSearchResults(results)); + }, cancellationToken, "VectorizedSearch", this.CollectionName).ConfigureAwait(false); + } + + private async IAsyncEnumerable> ReadVectorSearchResultsAsync( + SqlCommand command, + [EnumeratorCancellation] CancellationToken cancellationToken) + { using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); int scoreIndex = -1; @@ -253,11 +319,6 @@ 
private async IAsyncEnumerable> ReadVectorSearchResu } } - private Task EnsureConnectionIsOpenedAsync(CancellationToken cancellationToken) - => this._sqlConnection.State == System.Data.ConnectionState.Open - ? Task.CompletedTask - : this._sqlConnection.OpenAsync(cancellationToken); - private static Dictionary Map(TRecord record, VectorStoreRecordPropertyReader propertyReader) { Dictionary map = new(StringComparer.Ordinal); @@ -267,15 +328,7 @@ private Task EnsureConnectionIsOpenedAsync(CancellationToken cancellationToken) for (int i = 0; i < dataProperties.Count; i++) { object value = propertyReader.DataPropertiesInfo[i].GetValue(record); - // SQL Server does not support arrays, so we need to serialize them to JSON. - object? mappedValue = value switch - { - string[] array => JsonSerializer.Serialize(array), - List list => JsonSerializer.Serialize(list), - _ => value - }; - - map[dataProperties[i].DataModelPropertyName] = mappedValue; + map[dataProperties[i].DataModelPropertyName] = value; } var vectorProperties = propertyReader.VectorProperties; for (int i = 0; i < vectorProperties.Count; i++) @@ -284,8 +337,6 @@ private Task EnsureConnectionIsOpenedAsync(CancellationToken cancellationToken) ReadOnlyMemory floats = (ReadOnlyMemory)propertyReader.VectorPropertiesInfo[i].GetValue(record); // We know that SqlServer supports JSON serialization, so we can serialize the vector as JSON now, // so the SqlServerCommandBuilder does not need to worry about that. - // TODO adsitnik perf: we could remove the dependency to System.Text.Json - // by using a hand-written serializer. 
map[vectorProperties[i].DataModelPropertyName] = JsonSerializer.Serialize(floats); } @@ -295,34 +346,12 @@ private Task EnsureConnectionIsOpenedAsync(CancellationToken cancellationToken) private static TRecord Map(SqlDataReader reader, VectorStoreRecordPropertyReader propertyReader) { TRecord record = Activator.CreateInstance()!; - propertyReader.KeyPropertyInfo.SetValue(record, reader[SqlServerCommandBuilder.GetColumnName(propertyReader.KeyProperty)]); + SetValue(reader, record, propertyReader.KeyPropertyInfo, propertyReader.KeyProperty); var data = propertyReader.DataProperties; var dataInfo = propertyReader.DataPropertiesInfo; for (int i = 0; i < data.Count; i++) { - object value = reader[SqlServerCommandBuilder.GetColumnName(data[i])]; - if (value is DBNull) - { - // There is no need to call the reflection to set the null, - // as it's the default value of every .NET reference type field. - continue; - } - - if (value is not string text) - { - dataInfo[i].SetValue(record, value); - } - else - { - // SQL Server does not support arrays, so we need to deserialize them from JSON. - object? mappedValue = data[i].PropertyType switch - { - Type t when t == typeof(string[]) => JsonSerializer.Deserialize(text), - Type t when t == typeof(List) => JsonSerializer.Deserialize>(text), - _ => text - }; - dataInfo[i].SetValue(record, mappedValue); - } + SetValue(reader, record, dataInfo[i], data[i]); } var vector = propertyReader.VectorProperties; @@ -332,11 +361,44 @@ private static TRecord Map(SqlDataReader reader, VectorStoreRecordPropertyReader object value = reader[SqlServerCommandBuilder.GetColumnName(vector[i])]; if (value is not DBNull) { - // We know that it has to be a ReadOnlyMemory because that's what we serialized. - ReadOnlyMemory embedding = JsonSerializer.Deserialize>((string)value); + ReadOnlyMemory? embedding = null; + + try + { + // This may fail if the user has stored a non-float array in the database + // (or serialized it in a different way). 
+ embedding = JsonSerializer.Deserialize>((string)value); + } + catch (Exception ex) + { + throw new VectorStoreRecordMappingException($"Failed to deserialize vector property '{vector[i].DataModelPropertyName}', it contained value '{value}'.", ex); + } + vectorInfo[i].SetValue(record, embedding); } } return record; + + static void SetValue(SqlDataReader reader, object record, PropertyInfo propertyInfo, VectorStoreRecordProperty property) + { + // If we got here, there should be no column name mismatch (the query would fail). + object value = reader[SqlServerCommandBuilder.GetColumnName(property)]; + + if (value is DBNull) + { + // There is no need to call the reflection to set the null, + // as it's the default value of every .NET reference type field. + return; + } + + try + { + propertyInfo.SetValue(record, value); + } + catch (Exception ex) + { + throw new VectorStoreRecordMappingException($"Failed to set value '{value}' on property '{propertyInfo.Name}' of type '{propertyInfo.PropertyType.FullName}'.", ex); + } + } } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs index 07bc1894fd2e..8489301ad1f8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs @@ -5,28 +5,27 @@ using System.Collections.Generic; using System.Linq.Expressions; -namespace Microsoft.SemanticKernel.Connectors; +namespace Microsoft.SemanticKernel.Connectors.Sqlite; -internal partial class SqlFilterTranslator +internal sealed class SqliteFilterTranslator : SqlFilterTranslator { private readonly Dictionary _parameters = new(); + internal SqliteFilterTranslator(IReadOnlyDictionary storagePropertyNames, + LambdaExpression lambdaExpression) : base(storagePropertyNames, lambdaExpression, sql: null) + { + } + internal Dictionary Parameters => this._parameters; - private void 
GenerateLiteral(bool value) + protected override void GenerateLiteral(bool value) => this._sql.Append(value ? "TRUE" : "FALSE"); - private void GenerateLiteral(DateTime dateTime) - => throw new NotImplementedException(); - - private void GenerateLiteral(DateTimeOffset dateTimeOffset) - => throw new NotImplementedException(); - // TODO: support Contains over array fields (#10343) - private void TranslateContainsOverArrayColumn(Expression source, Expression item) + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item) => throw new NotSupportedException("Unsupported Contains expression"); - private void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value) + protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value) { if (value is not IEnumerable elements) { @@ -54,11 +53,11 @@ private void TranslateContainsOverCapturedArray(Expression source, Expression it this._sql.Append(')'); } - private void TranslateLambdaVariables(string name, object? value) + protected override void TranslateLambdaVariables(string name, object? capturedValue) { // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) - if (value is null) + if (capturedValue is null) { this._sql.Append("NULL"); } @@ -76,7 +75,7 @@ private void TranslateLambdaVariables(string name, object? 
value) } while (this._parameters.ContainsKey(name)); } - this._parameters.Add(name, value); + this._parameters.Add(name, capturedValue); this._sql.Append('@').Append(name); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs index 8383dd4eceaa..e3c97a149e2c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs @@ -210,7 +210,7 @@ public virtual Task> VectorizedSearchAsync } else if (searchOptions.NewFilter is not null) { - SqlFilterTranslator translator = new(this._propertyReader.StoragePropertyNamesMap, searchOptions.NewFilter); + SqliteFilterTranslator translator = new(this._propertyReader.StoragePropertyNamesMap, searchOptions.NewFilter); translator.Translate(appendWhere: false); extraWhereFilter = translator.Clause.ToString(); extraParameters = translator.Parameters; diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs index 7577547f50ed..21382344b445 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs @@ -26,8 +26,8 @@ public sealed class VectorStoreRecordKeyAttribute : Attribute /// The default is . /// /// - /// If set to , you must set the key property on any record that you pass to . - /// If set to , the key property may be left null on any record that you pass to + /// If set to , you must set the key property on any record that you pass to . + /// If set to , the key property may be left null on any record that you pass to /// and a generated key will be returned. 
/// public bool AutoGenerate { get; set; } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs index 5fa216165d8d..2ea0b97e2f20 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs @@ -17,11 +17,9 @@ public sealed class VectorStoreRecordKeyProperty : VectorStoreRecordProperty /// /// The name of the property. /// The type of the property. - /// Whether the key should be auto-generated by the vector store. - public VectorStoreRecordKeyProperty(string propertyName, Type propertyType, bool autoGenerate = false) + public VectorStoreRecordKeyProperty(string propertyName, Type propertyType) : base(propertyName, propertyType) { - this.AutoGenerate = autoGenerate; } /// @@ -36,5 +34,5 @@ public VectorStoreRecordKeyProperty(VectorStoreRecordKeyProperty source) /// /// Gets a value indicating whether the key should be auto-generated by the vector store. 
/// - public bool AutoGenerate { get; } + public bool AutoGenerate { get; init; } } diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs index 22ca6c113406..7b64c494e883 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs @@ -541,9 +541,10 @@ private static VectorStoreRecordDefinition CreateVectorStoreRecordDefinitionFrom var keyAttribute = keyProperty.GetCustomAttribute(); if (keyAttribute is not null) { - definitionProperties.Add(new VectorStoreRecordKeyProperty(keyProperty.Name, keyProperty.PropertyType, keyAttribute.AutoGenerate) + definitionProperties.Add(new VectorStoreRecordKeyProperty(keyProperty.Name, keyProperty.PropertyType) { - StoragePropertyName = keyAttribute.StoragePropertyName + StoragePropertyName = keyAttribute.StoragePropertyName, + AutoGenerate = keyAttribute.AutoGenerate }); } } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs index ff9fb0376a64..ce9c6b4bde0a 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. 
+using Microsoft.Extensions.VectorData; using SqlServerIntegrationTests.Support; using VectorDataSpecificationTests.Filter; using VectorDataSpecificationTests.Support; @@ -38,10 +39,10 @@ public override async Task NotEqual_with_string() } public override Task Contains_over_field_string_array() - => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); public override Task Contains_over_field_string_List() - => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); [Fact(Skip = "Not supported")] [Obsolete("Legacy filters are not supported")] @@ -62,5 +63,19 @@ public override Task Contains_over_field_string_List() public new class Fixture : BasicFilterTests.Fixture { public override TestStore TestStore => SqlServerTestStore.Instance; + + protected override string CollectionName +#if NET // make sure different TFMs use different collection names (as they may run in parralel and cause trouble) + => "FilterTests-core"; +#else + => "FilterTests-framework"; +#endif + + // Override to remove the string collection properties, which aren't (currently) supported on SqlServer + protected override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(string[]) && p.PropertyType != typeof(List)).ToList() + }; } } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 5a8574990310..2c04ffcb09b5 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -109,7 +109,10 @@ public void CreateTable(bool 
ifNotExists) { Schema = "schema" }; - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long), autoGenerate: true); + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)) + { + AutoGenerate = true + }; VectorStoreRecordDataProperty[] dataProperties = [ new VectorStoreRecordDataProperty("simpleName", typeof(string)), @@ -139,10 +142,10 @@ PRIMARY KEY NONCLUSTERED ([id]) """; if (ifNotExists) { - expectedCommand = "IF OBJECT_ID(N'[schema].[table]', N'U') IS NULL\n" + expectedCommand; + expectedCommand = "IF OBJECT_ID(N'[schema].[table]', N'U') IS NULL" + Environment.NewLine + expectedCommand; } - Assert.Equal(HandleNewLines(expectedCommand), command.CommandText); + Assert.Equal(expectedCommand, command.CommandText); } [Fact] @@ -152,7 +155,10 @@ public void MergeIntoSingle() { Schema = "schema" }; - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long), autoGenerate: true); + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)) + { + AutoGenerate = true + }; VectorStoreRecordProperty[] properties = [ keyProperty, @@ -188,7 +194,7 @@ WHEN NOT MATCHED THEN OUTPUT inserted.[id]; """"; - Assert.Equal(HandleNewLines(expectedCommand), command.CommandText); + Assert.Equal(expectedCommand, command.CommandText); Assert.Equal("@id_0", command.Parameters[0].ParameterName); Assert.Equal(DBNull.Value, command.Parameters[0].Value); Assert.Equal("@simpleString_1", command.Parameters[1].ParameterName); @@ -321,12 +327,12 @@ public void SelectSingle() using SqlCommand command = SqlServerCommandBuilder.SelectSingle(connection, "schema", "tableName", keyProperty, properties, 123L); - Assert.Equal(HandleNewLines( + Assert.Equal( """"" SELECT [id],[name],[age],[embedding] FROM [schema].[tableName] WHERE [id] = @id_0 - """""), command.CommandText); + """"", command.CommandText); Assert.Equal(123L, command.Parameters[0].Value); Assert.Equal("@id_0", command.Parameters[0].ParameterName); } @@ -350,12 +356,12 @@ public void SelectMany() using 
SqlCommand command = SqlServerCommandBuilder.SelectMany(connection, "schema", "tableName", keyProperty, properties, keys); - Assert.Equal(HandleNewLines( + Assert.Equal( """"" SELECT [id],[name],[age],[embedding] FROM [schema].[tableName] WHERE [id] IN (@id_0,@id_1,@id_2) - """""), command.CommandText); + """"", command.CommandText); for (int i = 0; i < keys.Length; i++) { Assert.Equal(keys[i], command.Parameters[i].Value); @@ -363,11 +369,6 @@ WHERE [id] IN (@id_0,@id_1,@id_2) } } - private static string HandleNewLines(string expectedCommand) - => OperatingSystem.IsWindows() - ? expectedCommand.Replace("\n", "\r\n") - : expectedCommand; - // We create a connection using a fake connection string just to be able to create the SqlCommand. private static SqlConnection CreateConnection() => new("Server=localhost;Database=master;Integrated Security=True;"); diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj index d39fecdb6d82..54d75b8ebc6a 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable @@ -9,6 +9,7 @@ true $(NoWarn);CA2007,SKEXP0001,SKEXP0020,VSTHRD111 + b7762d10-e29b-4bb1-8b74-b6d69a667dd4 diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs index 2c419c5f173d..23e714ff60bd 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs @@ -1,9 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. 
-using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; using Microsoft.Data.SqlClient; using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel.Connectors.SqlServer; @@ -36,14 +32,8 @@ public async Task InitializeAsync() .AddUserSecrets() .Build(); - var connectionString = configuration["SqlServer:ConnectionString"]; - - if (string.IsNullOrWhiteSpace(connectionString)) - { - throw new ArgumentException("SqlServer memory connection string is not configured."); - } - - this._connectionString = connectionString; + this._connectionString = configuration["SqlServer:ConnectionString"] + ?? throw new ArgumentException("SqlServer memory connection string is not configured."); await this.CleanupDatabaseAsync(); await this.InitializeDatabaseAsync(); @@ -324,18 +314,29 @@ private async Task> InsertSampleDataAsync() private async Task InitializeDatabaseAsync() { +#if NET // IAsyncDisposable is not present in Full Framework await using var connection = new SqlConnection(this._connectionString); - await connection.OpenAsync(); await using var cmd = connection.CreateCommand(); +#else + using var connection = new SqlConnection(this._connectionString); + using var cmd = connection.CreateCommand(); +#endif + + await connection.OpenAsync(); cmd.CommandText = $"CREATE SCHEMA {SchemaName}"; await cmd.ExecuteNonQueryAsync(); } private async Task CleanupDatabaseAsync() { +#if NET await using var connection = new SqlConnection(this._connectionString); - await connection.OpenAsync(); await using var cmd = connection.CreateCommand(); +#else + using var connection = new SqlConnection(this._connectionString); + using var cmd = connection.CreateCommand(); +#endif + await connection.OpenAsync(); cmd.CommandText = $""" DECLARE tables_cursor CURSOR FOR SELECT table_name diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs 
b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index f22b9cd70c59..2a1617eb0b4c 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Text.Json; +using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; using SqlServerIntegrationTests.Support; using Xunit; @@ -8,26 +10,29 @@ namespace SqlServerIntegrationTests; public class SqlServerVectorStoreTests { + // this test may be once executed by multiple users against a shared db instance + private static string GetUniqueCollectionName() => Guid.NewGuid().ToString(); + [Fact] public async Task CollectionCRUD() { - const string CollectionName = "collection"; + string collectionName = GetUniqueCollectionName(); SqlServerTestStore testStore = new(); await testStore.ReferenceCountingStartAsync(); - var collection = testStore.DefaultVectorStore.GetCollection(CollectionName); + var collection = testStore.DefaultVectorStore.GetCollection(collectionName); try { Assert.False(await collection.CollectionExistsAsync()); - Assert.False(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(CollectionName)); + Assert.False(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collectionName)); await collection.CreateCollectionAsync(); Assert.True(await collection.CollectionExistsAsync()); - Assert.True(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(CollectionName)); + Assert.True(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collectionName)); await collection.CreateCollectionIfNotExistsAsync(); @@ -36,7 +41,7 @@ public async Task CollectionCRUD() await collection.DeleteCollectionAsync(); Assert.False(await collection.CollectionExistsAsync()); - 
Assert.False(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(CollectionName)); + Assert.False(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collectionName)); } finally { @@ -49,11 +54,12 @@ public async Task CollectionCRUD() [Fact] public async Task RecordCRUD() { + string collectionName = GetUniqueCollectionName(); SqlServerTestStore testStore = new(); await testStore.ReferenceCountingStartAsync(); - var collection = testStore.DefaultVectorStore.GetCollection("other"); + var collection = testStore.DefaultVectorStore.GetCollection(collectionName); try { @@ -95,14 +101,64 @@ public async Task RecordCRUD() } } + [Fact] + public async Task WrongModels() + { + string collectionName = GetUniqueCollectionName(); + SqlServerTestStore testStore = new(); + + await testStore.ReferenceCountingStartAsync(); + + var collection = testStore.DefaultVectorStore.GetCollection(collectionName); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + TestModel inserted = new() + { + Id = "MyId", + Text = "NotAnInt", + Number = 100, + Floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray() + }; + Assert.Equal(inserted.Id, await collection.UpsertAsync(inserted)); + + // Let's use a model with different storage names to trigger an SQL exception + // which should be mapped to VectorStoreOperationException. + var differentNamesCollection = testStore.DefaultVectorStore.GetCollection(collectionName); + VectorStoreOperationException operationEx = await Assert.ThrowsAsync(() => differentNamesCollection.GetAsync(inserted.Id)); + Assert.IsType(operationEx.InnerException); + + // Let's use a model with the same storage names, but different types + // to trigger a mapping exception (casting a string to an int). 
+ var sameNameDifferentModelCollection = testStore.DefaultVectorStore.GetCollection(collectionName); + VectorStoreRecordMappingException mappingEx = await Assert.ThrowsAsync(() => sameNameDifferentModelCollection.GetAsync(inserted.Id)); + Assert.IsType(mappingEx.InnerException); + + // Let's use a model with the same storage names, but different types + // to trigger a mapping exception (deserializing a string to Memory). + var invalidJsonCollection = testStore.DefaultVectorStore.GetCollection(collectionName); + mappingEx = await Assert.ThrowsAsync(() => invalidJsonCollection.GetAsync(inserted.Id)); + Assert.IsType(mappingEx.InnerException); + } + finally + { + await collection.DeleteCollectionAsync(); + + await testStore.ReferenceCountingStopAsync(); + } + } + [Fact] public async Task BatchCRUD() { + string collectionName = GetUniqueCollectionName(); SqlServerTestStore testStore = new(); await testStore.ReferenceCountingStartAsync(); - var collection = testStore.DefaultVectorStore.GetCollection("other"); + var collection = testStore.DefaultVectorStore.GetCollection(collectionName); try { @@ -163,7 +219,7 @@ private static void AssertEquality(TestModel inserted, TestModel? received) Assert.NotNull(received); Assert.Equal(inserted.Number, received.Number); Assert.Equal(inserted.Id, received.Id); - Assert.Equal(inserted.Floats, received.Floats); + Assert.Equal(inserted.Floats.ToArray(), received.Floats.ToArray()); Assert.Null(received.Text); // testing DBNull code path } @@ -182,6 +238,39 @@ public sealed class TestModel public ReadOnlyMemory Floats { get; set; } } + public sealed class SameStorageNameButDifferentType + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public string? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "text")] + public int Number { get; set; } + } + + public sealed class SameStorageNameButInvalidVector + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public string? 
Id { get; set; } + + [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "text")] + public ReadOnlyMemory Floats { get; set; } + } + + public sealed class DifferentStorageNames + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public string? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "text2")] + public string? Text { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "column2")] + public int Number { get; set; } + + [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding2")] + public ReadOnlyMemory Floats { get; set; } + } + [Fact] public Task CanUseFancyModels_Int() => this.CanUseFancyModels(); @@ -193,11 +282,12 @@ public sealed class TestModel private async Task CanUseFancyModels() where TKey : notnull { + string collectionName = GetUniqueCollectionName(); SqlServerTestStore testStore = new(); await testStore.ReferenceCountingStartAsync(); - var collection = testStore.DefaultVectorStore.GetCollection>("other"); + var collection = testStore.DefaultVectorStore.GetCollection>(collectionName); try { @@ -212,8 +302,6 @@ private async Task CanUseFancyModels() where TKey : notnull Number64 = long.MaxValue, Floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray(), Bytes = [1, 2, 3], - ArrayOfStrings = ["a", "b", "c"], - ListOfStrings = ["d", "e", "f"] }; TKey key = await collection.UpsertAsync(inserted); Assert.NotEqual(default, key); // key should be assigned by the DB (auto-increment) @@ -233,9 +321,9 @@ private async Task CanUseFancyModels() where TKey : notnull received = await collection.GetAsync(updated.Id); AssertEquality(updated, received, key); - await collection.DeleteAsync(inserted.Id); + await collection.DeleteAsync(key); - Assert.Null(await collection.GetAsync(inserted.Id)); + Assert.Null(await collection.GetAsync(key)); } finally { @@ -252,10 +340,8 @@ void AssertEquality(FancyTestModel expected, FancyTestModel? 
receive Assert.Equal(expected.Number16, received.Number16); Assert.Equal(expected.Number32, received.Number32); Assert.Equal(expected.Number64, received.Number64); - Assert.Equal(expected.Floats, received.Floats); + Assert.Equal(expected.Floats.ToArray(), received.Floats.ToArray()); Assert.Equal(expected.Bytes, received.Bytes); - Assert.Equal(expected.ArrayOfStrings, received.ArrayOfStrings); - Assert.Equal(expected.ListOfStrings, received.ListOfStrings); } } @@ -279,14 +365,8 @@ public sealed class FancyTestModel [VectorStoreRecordData(StoragePropertyName = "bytes")] #pragma warning disable CA1819 // Properties should not return arrays public byte[]? Bytes { get; set; } - - [VectorStoreRecordData(StoragePropertyName = "array_of_strings")] - public string[]? ArrayOfStrings { get; set; } #pragma warning restore CA1819 // Properties should not return arrays - [VectorStoreRecordData(StoragePropertyName = "list_of_strings")] - public List? ListOfStrings { get; set; } - [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] public ReadOnlyMemory Floats { get; set; } } From 1902c0b56a18752828b6daca3f57eedb063040e0 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Tue, 25 Feb 2025 14:11:45 +0100 Subject: [PATCH 23/32] address remaining feedback: - add support for IVectorStoreRecordMapper - respect IncludeVectors - add VectorSearch tests (based on existing tests) - fix typos --- .../RecordMapper.cs | 100 +++++++++++ .../SqlDataReaderDictionary.cs | 140 +++++++++++++++ .../SqlServerCommandBuilder.cs | 56 +++--- .../SqlServerConstants.cs | 7 + .../SqlServerFilterTranslator.cs | 2 +- .../SqlServerVectorStore.cs | 11 +- .../SqlServerVectorStoreRecordCollection.cs | 159 ++++++------------ ...erverVectorStoreRecordCollectionOptions.cs | 48 ++++++ .../Data/VectorStoreRecordPropertyReader.cs | 1 - .../VectorStoreRecordPropertyVerification.cs | 5 +- .../PostgresBasicVectorSearchTests.cs | 23 +++ .../Filter/SqlServerBasicFilterTests.cs | 2 +- 
.../SqlServerCommandBuilderTests.cs | 34 ++-- .../SqlServerVectorStoreTests.cs | 105 +++++++++++- .../SqlServerBasicVectorSearchTests.cs | 32 ++++ .../VectorSearch/BasicVectorSearchTests.cs | 148 ++++++++++++++++ 16 files changed, 708 insertions(+), 165 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/BasicVectorSearchTests.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs new file mode 100644 index 000000000000..0703e2b536f8 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Text.Json; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal sealed class RecordMapper : IVectorStoreRecordMapper> +{ + private readonly VectorStoreRecordPropertyReader _propertyReader; + + internal RecordMapper(VectorStoreRecordPropertyReader propertyReader) => this._propertyReader = propertyReader; + + public IDictionary MapFromDataToStorageModel(TRecord dataModel) + { + Dictionary map = new(StringComparer.Ordinal); + + map[SqlServerCommandBuilder.GetColumnName(this._propertyReader.KeyProperty)] = this._propertyReader.KeyPropertyInfo.GetValue(dataModel); + + var dataProperties = this._propertyReader.DataProperties; + var dataPropertiesInfo = this._propertyReader.DataPropertiesInfo; + for (int i = 0; i < dataProperties.Count; i++) + { + object? value = dataPropertiesInfo[i].GetValue(dataModel); + map[SqlServerCommandBuilder.GetColumnName(dataProperties[i])] = value; + } + var vectorProperties = this._propertyReader.VectorProperties; + var vectorPropertiesInfo = this._propertyReader.VectorPropertiesInfo; + for (int i = 0; i < vectorProperties.Count; i++) + { + // We restrict the vector properties to ReadOnlyMemory so the cast here is safe. 
+ ReadOnlyMemory floats = (ReadOnlyMemory)vectorPropertiesInfo[i].GetValue(dataModel); + map[SqlServerCommandBuilder.GetColumnName(vectorProperties[i])] = floats; + } + + return map; + } + + public TRecord MapFromStorageToDataModel(IDictionary storageModel, StorageToDataModelMapperOptions options) + { + TRecord record = Activator.CreateInstance()!; + SetValue(storageModel, record, this._propertyReader.KeyPropertyInfo, this._propertyReader.KeyProperty); + var data = this._propertyReader.DataProperties; + var dataInfo = this._propertyReader.DataPropertiesInfo; + for (int i = 0; i < data.Count; i++) + { + SetValue(storageModel, record, dataInfo[i], data[i]); + } + + if (options.IncludeVectors) + { + var vector = this._propertyReader.VectorProperties; + var vectorInfo = this._propertyReader.VectorPropertiesInfo; + for (int i = 0; i < vector.Count; i++) + { + object? value = storageModel[SqlServerCommandBuilder.GetColumnName(vector[i])]; + if (value is not null) + { + if (value is ReadOnlyMemory floats) + { + vectorInfo[i].SetValue(record, floats); + } + else + { + // When deserializing a string to a ReadOnlyMemory fails in SqlDataReaderDictionary, + // we store the raw value so the user can handle the error in a custom mapper. + throw new VectorStoreRecordMappingException($"Failed to deserialize vector property '{vector[i].DataModelPropertyName}', it contained value '{value}'."); + } + } + } + } + + return record; + + static void SetValue(IDictionary storageModel, object record, PropertyInfo propertyInfo, VectorStoreRecordProperty property) + { + // If we got here, there should be no column name mismatch (the query would fail). + object? value = storageModel[SqlServerCommandBuilder.GetColumnName(property)]; + + if (value is null) + { + // There is no need to call the reflection to set the null, + // as it's the default value of every .NET reference type field. 
+ return; + } + + try + { + propertyInfo.SetValue(record, value); + } + catch (Exception ex) + { + throw new VectorStoreRecordMappingException($"Failed to set value '{value}' on property '{propertyInfo.Name}' of type '{propertyInfo.PropertyType.FullName}'.", ex); + } + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs new file mode 100644 index 000000000000..62e8637e26fd --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text.Json; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +/// +/// This class is used to provide a dictionary-like interface to a . +/// The goal is to avoid the need of allocating a new dictionary for each row read from the database. +/// +internal sealed class SqlDataReaderDictionary : IDictionary +{ + private readonly SqlDataReader _sqlDataReader; + private readonly IReadOnlyList _vectorPropertyStoragePropertyNames; + + // This field will get instantiated lazily, only if needed by a custom mapper. + private Dictionary? _dictionary; + + internal SqlDataReaderDictionary(SqlDataReader sqlDataReader, IReadOnlyList vectorPropertyStoragePropertyNames) + { + this._sqlDataReader = sqlDataReader; + this._vectorPropertyStoragePropertyNames = vectorPropertyStoragePropertyNames; + } + + private object? Unwrap(string storageName, object? value) + { + // Let's make sure our users don't need to learn what DBNull is. + if (value is DBNull) + { + return null; + } + + // If the value is a vector, we need to deserialize it. 
+ if (this._vectorPropertyStoragePropertyNames.Count > 0 && value is string text) + { + for (int i = 0; i < this._vectorPropertyStoragePropertyNames.Count; i++) + { + if (string.Equals(storageName, this._vectorPropertyStoragePropertyNames[i], StringComparison.Ordinal)) + { + try + { + return JsonSerializer.Deserialize>(text); + } + catch (JsonException) + { + // This may fail if the user has stored a non-float array in the database + // (or serialized it in a different way). + // We need to return the raw value, so the user can handle the error in a custom mapper. + return text; + } + } + } + } + + return value; + } + + // This is the only method used by the default mapper. + public object? this[string key] + { + get => this.Unwrap(key, _sqlDataReader[key]); + set => throw new InvalidOperationException(); + } + + public ICollection Keys => GetDictionary().Keys; + + public ICollection Values => GetDictionary().Values; + + public int Count => this._sqlDataReader.FieldCount; + + public bool IsReadOnly => true; + + public void Add(string key, object? 
value) => throw new InvalidOperationException(); + + public void Add(KeyValuePair item) => throw new InvalidOperationException(); + + public void Clear() => throw new InvalidOperationException(); + + public bool Contains(KeyValuePair item) + => TryGetValue(item.Key, out var value) && Equals(value, item.Value); + + public bool ContainsKey(string key) + { + try + { + return this._sqlDataReader.GetOrdinal(key) >= 0; + } + catch (IndexOutOfRangeException) + { + return false; + } + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + => ((ICollection>)GetDictionary()).CopyTo(array, arrayIndex); + + public IEnumerator> GetEnumerator() + => GetDictionary().GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() + => GetDictionary().GetEnumerator(); + + public bool Remove(string key) => throw new InvalidOperationException(); + + public bool Remove(KeyValuePair item) => throw new InvalidOperationException(); + + public bool TryGetValue(string key, out object? value) + { + try + { + value = this.Unwrap(key, this._sqlDataReader[key]); + return true; + } + catch (IndexOutOfRangeException) + { + value = default; + return false; + } + } + + private Dictionary GetDictionary() + { + if (this._dictionary is null) + { + Dictionary dictionary = new(this._sqlDataReader.FieldCount, StringComparer.Ordinal); + for (int i = 0; i < this._sqlDataReader.FieldCount; i++) + { + string name = this._sqlDataReader.GetName(i); + dictionary.Add(name, this.Unwrap(name, this._sqlDataReader[i])); + } + this._dictionary = dictionary; + } + return this._dictionary; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 17e180f10a6f..ca1a9bc32cdf 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -16,7 +16,7 @@ internal static class 
SqlServerCommandBuilder { internal static SqlCommand CreateTable( SqlConnection connection, - SqlServerVectorStoreOptions options, + string schema, string tableName, bool ifNotExists, VectorStoreRecordKeyProperty keyProperty, @@ -27,11 +27,11 @@ internal static SqlCommand CreateTable( if (ifNotExists) { sb.Append("IF OBJECT_ID(N'"); - sb.AppendTableName(options.Schema, tableName); + sb.AppendTableName(schema, tableName); sb.AppendLine("', N'U') IS NULL"); } sb.Append("CREATE TABLE "); - sb.AppendTableName(options.Schema, tableName); + sb.AppendTableName(schema, tableName); sb.AppendLine(" ("); string keyColumnName = GetColumnName(keyProperty); var keyMapping = Map(keyProperty.PropertyType); @@ -44,7 +44,6 @@ internal static SqlCommand CreateTable( } for (int i = 0; i < vectorProperties.Count; i++) { - // TODO adsitnik design: should we require Dimensions to be always provided in explicit way or use some default? sb.AppendFormat("[{0}] VECTOR({1}),", GetColumnName(vectorProperties[i]), vectorProperties[i].Dimensions); sb.AppendLine(); } @@ -94,23 +93,23 @@ FROM INFORMATION_SCHEMA.TABLES internal static SqlCommand MergeIntoSingle( SqlConnection connection, - SqlServerVectorStoreOptions options, + string schema, string tableName, VectorStoreRecordKeyProperty keyProperty, IReadOnlyList properties, - Dictionary record) + IDictionary record) { SqlCommand command = connection.CreateCommand(); StringBuilder sb = new(200); sb.Append("MERGE INTO "); - sb.AppendTableName(options.Schema, tableName); + sb.AppendTableName(schema, tableName); sb.AppendLine(" AS t"); sb.Append("USING (VALUES ("); int paramIndex = 0; foreach (VectorStoreRecordProperty property in properties) { sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); - command.AddParameter(property, paramName, record[property.DataModelPropertyName]); + command.AddParameter(property, paramName, record[GetColumnName(property)]); } sb[sb.Length - 1] = ')'; // replace the last comma with a 
closing parenthesis sb.Append(") AS s ("); @@ -149,11 +148,11 @@ internal static SqlCommand MergeIntoSingle( internal static SqlCommand MergeIntoMany( SqlConnection connection, - SqlServerVectorStoreOptions options, + string schema, string tableName, VectorStoreRecordKeyProperty keyProperty, IReadOnlyList properties, - IEnumerable> records) + IEnumerable> records) { SqlCommand command = connection.CreateCommand(); @@ -163,7 +162,7 @@ internal static SqlCommand MergeIntoMany( sb.AppendLine(); // The MERGE statement performs the upsert operation and outputs the keys of the inserted rows into the table variable. sb.Append("MERGE INTO "); - sb.AppendTableName(options.Schema, tableName); + sb.AppendTableName(schema, tableName); sb.AppendLine(" AS t"); // t stands for target sb.AppendLine("USING (VALUES"); int rowIndex = 0, paramIndex = 0; @@ -173,7 +172,7 @@ internal static SqlCommand MergeIntoMany( foreach (VectorStoreRecordProperty property in properties) { sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); - command.AddParameter(property, paramName, record[property.DataModelPropertyName]); + command.AddParameter(property, paramName, record[GetColumnName(property)]); } sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis sb.AppendLine(","); @@ -182,7 +181,6 @@ internal static SqlCommand MergeIntoMany( if (rowIndex == 0) { - // TODO adsitnik clarify: should we throw or simply do nothing? throw new ArgumentException("The value cannot be empty.", nameof(records)); } @@ -313,11 +311,15 @@ internal static SqlCommand SelectVector( { string distanceFunction = vectorProperty.DistanceFunction ?? 
DistanceFunction.CosineDistance; // Source: https://learn.microsoft.com/sql/t-sql/functions/vector-distance-transact-sql - string distanceMetric = distanceFunction switch + (string distanceMetric, string sorting) = distanceFunction switch { - DistanceFunction.CosineDistance => "cosine", - DistanceFunction.EuclideanDistance => "euclidean", - DistanceFunction.NegativeDotProductSimilarity => "dot", + // A value of 0 indicates that the vectors are identical in direction (cosine similarity of 1), + // while a value of 1 indicates that the vectors are orthogonal (cosine similarity of 0). + DistanceFunction.CosineDistance => ("cosine", "ASC"), + // A value of 0 indicates that the vectors are identical, while larger values indicate greater dissimilarity. + DistanceFunction.EuclideanDistance => ("euclidean", "ASC"), + // A value closer to 0 indicates higher similarity, while more negative values indicate greater dissimilarity. + DistanceFunction.NegativeDotProductSimilarity => ("dot", "DESC"), _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") }; @@ -326,9 +328,9 @@ internal static SqlCommand SelectVector( StringBuilder sb = new(200); sb.AppendFormat("SELECT "); - sb.AppendColumnNames(properties); + sb.AppendColumnNames(properties, includeVectors: options.IncludeVectors); sb.AppendLine(","); - sb.AppendFormat("1 - VECTOR_DISTANCE('{0}', {1}, CAST(@vector AS VECTOR({2}))) AS [score]", + sb.AppendFormat("VECTOR_DISTANCE('{0}', {1}, CAST(@vector AS VECTOR({2}))) AS [score]", distanceMetric, GetColumnName(vectorProperty), vector.Length); sb.AppendLine(); sb.Append("FROM "); @@ -348,7 +350,8 @@ internal static SqlCommand SelectVector( } sb.AppendLine(); } - sb.AppendLine("ORDER BY [score] DESC"); + sb.AppendFormat("ORDER BY [score] {0}", sorting); + sb.AppendLine(); // Negative Skip and Top values are rejected by the VectorSearchOptions property setters. // 0 is a legal value for OFFSET. 
sb.AppendFormat("OFFSET {0} ROWS FETCH NEXT {1} ROWS ONLY;", options.Skip, options.Top); @@ -412,11 +415,17 @@ internal static StringBuilder AppendTableName(this StringBuilder sb, string sche private static StringBuilder AppendColumnNames(this StringBuilder sb, IEnumerable properties, - string? prefix = null) + string? prefix = null, + bool includeVectors = true) { bool any = false; foreach (VectorStoreRecordProperty property in properties) { + if (!includeVectors && property is VectorStoreRecordVectorProperty) + { + continue; + } + if (prefix is not null) { sb.Append(prefix); @@ -451,7 +460,6 @@ private static StringBuilder AppendKeyParameterList(this StringBuilder sb, if (keyIndex == 0) { - // TODO adsitnik clarify: should we throw or simply do nothing? throw new ArgumentException("The value cannot be empty.", nameof(keys)); } @@ -474,11 +482,15 @@ private static void AddParameter(this SqlCommand command, VectorStoreRecordPrope command.Parameters.Add(name, System.Data.SqlDbType.VarBinary).Value = DBNull.Value; break; case null: + case ReadOnlyMemory vector when vector.Length == 0: command.Parameters.AddWithValue(name, DBNull.Value); break; case byte[] buffer: command.Parameters.Add(name, System.Data.SqlDbType.VarBinary).Value = buffer; break; + case ReadOnlyMemory vector: + command.Parameters.AddWithValue(name, JsonSerializer.Serialize(vector)); + break; default: command.Parameters.AddWithValue(name, value); break; diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs index b8842144258c..339445795551 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs @@ -19,6 +19,13 @@ internal static class SqlServerConstants typeof(byte[]) // VARBINARY ]; + internal static readonly HashSet SupportedAutoGenerateKeyTypes = + [ + typeof(int), // IDENTITY + 
typeof(long), // IDENTITY + typeof(Guid) // NEWID + ]; + internal static readonly HashSet SupportedDataTypes = [ typeof(int), // INT diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs index d46e1147d963..b51e2d0b588a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs @@ -76,7 +76,7 @@ protected override void TranslateLambdaVariables(string name, object? capturedVa else { this._parameterValues.Add(capturedValue); - // SQL Server paramters can't start with a digit (but underscore is OK). + // SQL Server parameters can't start with a digit (but underscore is OK). this._sql.Append("@_").Append(this._parameterIndex++); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs index 45e53faf69b7..754f4380160c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs @@ -24,10 +24,6 @@ public sealed class SqlServerVectorStore : IVectorStore, IDisposable /// Optional configuration options. public SqlServerVectorStore(SqlConnection connection, SqlServerVectorStoreOptions? options = null) { - // TODO adsitnik: design: - // Do we need a ctor that takes the connection string and creates a connection? - // What is the story with pooling for the SqlConnection type? - // Does it maintain a private instance pool? Or a static one? this._connection = connection; // We need to create a copy, so any changes made to the option bag after // the ctor call do not affect this instance. 
@@ -47,8 +43,11 @@ public IVectorStoreRecordCollection GetCollection( return new SqlServerVectorStoreRecordCollection( this._connection, name, - vectorStoreRecordDefinition, - this._options); + new() + { + Schema = this._options.Schema, + RecordDefinition = vectorStoreRecordDefinition + }); } /// diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 19eb9058e96b..1e4bd541e4f2 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -3,9 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Reflection; using System.Runtime.CompilerServices; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Data.SqlClient; @@ -20,29 +18,29 @@ public sealed class SqlServerVectorStoreRecordCollection : IVecto where TKey : notnull { private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly SqlServerVectorStoreRecordCollectionOptions s_defaultOptions = new(); private readonly SqlConnection _sqlConnection; - private readonly SqlServerVectorStoreOptions _options; + private readonly SqlServerVectorStoreRecordCollectionOptions _options; private readonly VectorStoreRecordPropertyReader _propertyReader; + private readonly IVectorStoreRecordMapper> _mapper; /// /// Initializes a new instance of the class. /// /// Database connection. /// The name of the collection. - /// Optional record definition. - /// Optional configuration options. + /// Optional configuration options. public SqlServerVectorStoreRecordCollection( SqlConnection connection, string name, - VectorStoreRecordDefinition? vectorStoreRecordDefinition = null, - SqlServerVectorStoreOptions? 
vectorStoreOptions = null) + SqlServerVectorStoreRecordCollectionOptions? options = null) { Verify.NotNull(connection); Verify.NotNull(name); VectorStoreRecordPropertyReader propertyReader = new(typeof(TRecord), - vectorStoreRecordDefinition, + options?.RecordDefinition, new() { RequiresAtLeastOneVector = false, @@ -50,26 +48,39 @@ public SqlServerVectorStoreRecordCollection( SupportsMultipleVectors = true, }); - propertyReader.VerifyHasParameterlessConstructor(); - propertyReader.VerifyKeyProperties(SqlServerConstants.SupportedKeyTypes); - propertyReader.VerifyDataProperties(SqlServerConstants.SupportedDataTypes, supportEnumerable: false); - propertyReader.VerifyVectorProperties(SqlServerConstants.SupportedVectorTypes); + if (options is null || options.Mapper is null) + { + propertyReader.VerifyHasParameterlessConstructor(); + } + + HashSet supportedKeyTypes = propertyReader.KeyProperty.AutoGenerate + ? SqlServerConstants.SupportedAutoGenerateKeyTypes + : SqlServerConstants.SupportedKeyTypes; - if (propertyReader.KeyProperty.AutoGenerate - && !(typeof(TKey) == typeof(int) || typeof(TKey) == typeof(long) || typeof(TKey) == typeof(Guid))) + if (VectorStoreRecordPropertyVerification.IsGenericDataModel(typeof(TRecord))) { - // SQL Server does not support auto-generated keys for types other than int, long, and Guid. - throw new ArgumentException("Key property cannot be auto-generated."); + VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.Mapper is not null, supportedKeyTypes); } + else + { + propertyReader.VerifyKeyProperties(supportedKeyTypes); + } + propertyReader.VerifyDataProperties(SqlServerConstants.SupportedDataTypes, supportEnumerable: false); + propertyReader.VerifyVectorProperties(SqlServerConstants.SupportedVectorTypes); this._sqlConnection = connection; this.CollectionName = name; // We need to create a copy, so any changes made to the option bag after // the ctor call do not affect this instance. 
- this._options = vectorStoreOptions is not null - ? new() { Schema = vectorStoreOptions.Schema } - : SqlServerVectorStoreOptions.Defaults; + this._options = options is null ? s_defaultOptions + : new() + { + Schema = options.Schema, + Mapper = options.Mapper, + RecordDefinition = options.RecordDefinition, + }; this._propertyReader = propertyReader; + this._mapper = options?.Mapper ?? new RecordMapper(propertyReader); } /// @@ -109,7 +120,7 @@ private async Task CreateCollectionAsync(bool ifNotExists, CancellationToken can using SqlCommand command = SqlServerCommandBuilder.CreateTable( this._sqlConnection, - this._options, + this._options.Schema, this.CollectionName, ifNotExists, this._propertyReader.KeyProperty, @@ -187,7 +198,11 @@ static async (cmd, ct) => return reader; }, cancellationToken, "Get", this.CollectionName).ConfigureAwait(false); - return reader.HasRows ? Map(reader, this._propertyReader) : default; + return reader.HasRows + ? this._mapper.MapFromStorageToDataModel( + new SqlDataReaderDictionary(reader, this._propertyReader.VectorPropertyStoragePropertyNames), + new() { IncludeVectors = true }) + : default; } /// @@ -210,7 +225,9 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get while (await ExceptionWrapper.WrapReadAsync(reader, cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false)) { - yield return Map(reader, this._propertyReader); + yield return this._mapper.MapFromStorageToDataModel( + new SqlDataReaderDictionary(reader, this._propertyReader.VectorPropertyStoragePropertyNames), + new() { IncludeVectors = true }); } } @@ -221,11 +238,11 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancellati using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle( this._sqlConnection, - this._options, + this._options.Schema, this.CollectionName, this._propertyReader.KeyProperty, this._propertyReader.Properties, - Map(record, this._propertyReader)); + 
this._mapper.MapFromDataToStorageModel(record)); return await ExceptionWrapper.WrapAsync(this._sqlConnection, command, async static (cmd, ct) => @@ -244,11 +261,11 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record using SqlCommand command = SqlServerCommandBuilder.MergeIntoMany( this._sqlConnection, - this._options, + this._options.Schema, this.CollectionName, this._propertyReader.KeyProperty, this._propertyReader.Properties, - records.Select(record => Map(record, this._propertyReader))); + records.Select(record => this._mapper.MapFromDataToStorageModel(record))); using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command, static (cmd, ct) => cmd.ExecuteReaderAsync(ct), @@ -294,15 +311,18 @@ public async Task> VectorizedSearchAsync(T return await ExceptionWrapper.WrapAsync(this._sqlConnection, command, (cmd, ct) => { - var results = this.ReadVectorSearchResultsAsync(cmd, ct); + var results = this.ReadVectorSearchResultsAsync(cmd, searchOptions.IncludeVectors, ct); return Task.FromResult(new VectorSearchResults(results)); }, cancellationToken, "VectorizedSearch", this.CollectionName).ConfigureAwait(false); } private async IAsyncEnumerable> ReadVectorSearchResultsAsync( SqlCommand command, + bool includeVectors, [EnumeratorCancellation] CancellationToken cancellationToken) { + StorageToDataModelMapperOptions options = new() { IncludeVectors = includeVectors }; + var vectorPropertyStoragePropertyNames = includeVectors ? 
this._propertyReader.VectorPropertyStoragePropertyNames : []; using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); int scoreIndex = -1; @@ -314,91 +334,8 @@ private async IAsyncEnumerable> ReadVectorSearchResu } yield return new VectorSearchResult( - Map(reader, this._propertyReader), + this._mapper.MapFromStorageToDataModel(new SqlDataReaderDictionary(reader, vectorPropertyStoragePropertyNames), options), reader.GetDouble(scoreIndex)); } } - - private static Dictionary Map(TRecord record, VectorStoreRecordPropertyReader propertyReader) - { - Dictionary map = new(StringComparer.Ordinal); - map[propertyReader.KeyProperty.DataModelPropertyName] = propertyReader.KeyPropertyInfo.GetValue(record); - - var dataProperties = propertyReader.DataProperties; - for (int i = 0; i < dataProperties.Count; i++) - { - object value = propertyReader.DataPropertiesInfo[i].GetValue(record); - map[dataProperties[i].DataModelPropertyName] = value; - } - var vectorProperties = propertyReader.VectorProperties; - for (int i = 0; i < vectorProperties.Count; i++) - { - // We restrict the vector properties to ReadOnlyMemory so the cast here is safe. - ReadOnlyMemory floats = (ReadOnlyMemory)propertyReader.VectorPropertiesInfo[i].GetValue(record); - // We know that SqlServer supports JSON serialization, so we can serialize the vector as JSON now, - // so the SqlServerCommandBuilder does not need to worry about that. 
- map[vectorProperties[i].DataModelPropertyName] = JsonSerializer.Serialize(floats); - } - - return map; - } - - private static TRecord Map(SqlDataReader reader, VectorStoreRecordPropertyReader propertyReader) - { - TRecord record = Activator.CreateInstance()!; - SetValue(reader, record, propertyReader.KeyPropertyInfo, propertyReader.KeyProperty); - var data = propertyReader.DataProperties; - var dataInfo = propertyReader.DataPropertiesInfo; - for (int i = 0; i < data.Count; i++) - { - SetValue(reader, record, dataInfo[i], data[i]); - } - - var vector = propertyReader.VectorProperties; - var vectorInfo = propertyReader.VectorPropertiesInfo; - for (int i = 0; i < vector.Count; i++) - { - object value = reader[SqlServerCommandBuilder.GetColumnName(vector[i])]; - if (value is not DBNull) - { - ReadOnlyMemory? embedding = null; - - try - { - // This may fail if the user has stored a non-float array in the database - // (or serialized it in a different way). - embedding = JsonSerializer.Deserialize>((string)value); - } - catch (Exception ex) - { - throw new VectorStoreRecordMappingException($"Failed to deserialize vector property '{vector[i].DataModelPropertyName}', it contained value '{value}'.", ex); - } - - vectorInfo[i].SetValue(record, embedding); - } - } - return record; - - static void SetValue(SqlDataReader reader, object record, PropertyInfo propertyInfo, VectorStoreRecordProperty property) - { - // If we got here, there should be no column name mismatch (the query would fail). - object value = reader[SqlServerCommandBuilder.GetColumnName(property)]; - - if (value is DBNull) - { - // There is no need to call the reflection to set the null, - // as it's the default value of every .NET reference type field. 
- return; - } - - try - { - propertyInfo.SetValue(record, value); - } - catch (Exception ex) - { - throw new VectorStoreRecordMappingException($"Failed to set value '{value}' on property '{propertyInfo.Name}' of type '{propertyInfo.PropertyType.FullName}'.", ex); - } - } - } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..adb1bd359d70 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +/// +/// Options when creating a . +/// +public sealed class SqlServerVectorStoreRecordCollectionOptions +{ + private string _schema = SqlServerConstants.Schema; + + /// + /// Gets or sets the database schema. + /// + /// when provided schema is empty or composed entirely of whitespace. + public string Schema + { + get => this._schema; + init + { + Verify.NotNullOrWhiteSpace(value); + + this._schema = value; + } + } + + /// + /// Gets or sets an optional custom mapper to use when converting between the data model and the SQL Server record. + /// + /// + /// If not set, the default mapper will be used. + /// + public IVectorStoreRecordMapper>? Mapper { get; init; } + + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? 
RecordDefinition { get; init; } +} diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs index 7b64c494e883..9ebda11442c3 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs @@ -118,7 +118,6 @@ public VectorStoreRecordPropertyReader( this._parameterlessConstructorInfo = new Lazy(() => { - // TODO adsitnik: design: why don't we requrie TRecord to be always : new()? var constructor = dataModelType.GetConstructor(Type.EmptyTypes); if (constructor == null) { diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs index b85800e6a244..08337bd0f138 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs @@ -191,6 +191,9 @@ var enumerableType when GetGenericEnumerableInterface(enumerableType) is Type en return null; } + internal static bool IsGenericDataModel(Type recordType) + => recordType.IsGenericType && recordType.GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>); + /// /// Checks that if the provided is a that the key type is supported by the default mappers. /// If not supported, a custom mapper must be supplied, otherwise an exception is thrown. @@ -202,7 +205,7 @@ var enumerableType when GetGenericEnumerableInterface(enumerableType) is Type en public static void VerifyGenericDataModelKeyType(Type recordType, bool customMapperSupplied, IEnumerable allowedKeyTypes) { // If we are not dealing with a generic data model, no need to check anything else. 
- if (!recordType.IsGenericType || recordType.GetGenericTypeDefinition() != typeof(VectorStoreGenericDataModel<>)) + if (!IsGenericDataModel(recordType)) { return; } diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs new file mode 100644 index 000000000000..823f421d0f1e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.VectorSearch; +using Xunit; + +namespace PostgresIntegrationTests.VectorSearch; + +public class PostgresBasicVectorSearchTests(PostgresBasicVectorSearchTests.Fixture fixture) + : BasicVectorSearchTests(fixture), IClassFixture +{ + public override Task EuclideanSquaredDistance() => Assert.ThrowsAsync(() => base.EuclideanSquaredDistance()); + + public override Task Hamming() => Assert.ThrowsAsync(() => base.Hamming()); + + public override Task NegativeDotProductSimilarity() => Assert.ThrowsAsync(() => base.NegativeDotProductSimilarity()); + + public new class Fixture : VectorStoreFixture + { + public override TestStore TestStore => PostgresTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs index ce9c6b4bde0a..2cdbbeab82cc 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs @@ -65,7 +65,7 @@ public override Task Contains_over_field_string_List() public 
override TestStore TestStore => SqlServerTestStore.Instance; protected override string CollectionName -#if NET // make sure different TFMs use different collection names (as they may run in parralel and cause trouble) +#if NET // make sure different TFMs use different collection names (as they may run in parallel and cause trouble) => "FilterTests-core"; #else => "FilterTests-framework"; diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 2c04ffcb09b5..58de196220a7 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -105,10 +105,6 @@ FROM INFORMATION_SCHEMA.TABLES [InlineData(false)] public void CreateTable(bool ifNotExists) { - SqlServerVectorStoreOptions options = new() - { - Schema = "schema" - }; VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)) { AutoGenerate = true @@ -127,7 +123,7 @@ public void CreateTable(bool ifNotExists) ]; using SqlConnection connection = CreateConnection(); - using SqlCommand command = SqlServerCommandBuilder.CreateTable(connection, options, "table", + using SqlCommand command = SqlServerCommandBuilder.CreateTable(connection, "schema", "table", ifNotExists, keyProperty, dataProperties, vectorProperties); string expectedCommand = @@ -145,16 +141,12 @@ PRIMARY KEY NONCLUSTERED ([id]) expectedCommand = "IF OBJECT_ID(N'[schema].[table]', N'U') IS NULL" + Environment.NewLine + expectedCommand; } - Assert.Equal(expectedCommand, command.CommandText); + AssertEqualIgnoreNewLines(expectedCommand, command.CommandText); } [Fact] public void MergeIntoSingle() { - SqlServerVectorStoreOptions options = new() - { - Schema = "schema" - }; VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)) { 
AutoGenerate = true @@ -171,7 +163,7 @@ public void MergeIntoSingle() ]; using SqlConnection connection = CreateConnection(); - using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle(connection, options, "table", + using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle(connection, "schema", "table", keyProperty, properties, new Dictionary { @@ -194,7 +186,7 @@ WHEN NOT MATCHED THEN OUTPUT inserted.[id]; """"; - Assert.Equal(expectedCommand, command.CommandText); + AssertEqualIgnoreNewLines(expectedCommand, command.CommandText); Assert.Equal("@id_0", command.Parameters[0].ParameterName); Assert.Equal(DBNull.Value, command.Parameters[0].Value); Assert.Equal("@simpleString_1", command.Parameters[1].ParameterName); @@ -208,10 +200,6 @@ WHEN NOT MATCHED THEN [Fact] public void MergeIntoMany() { - SqlServerVectorStoreOptions options = new() - { - Schema = "schema" - }; VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); VectorStoreRecordProperty[] properties = [ @@ -242,7 +230,7 @@ public void MergeIntoMany() ]; using SqlConnection connection = CreateConnection(); - using SqlCommand command = SqlServerCommandBuilder.MergeIntoMany(connection, options, "table", + using SqlCommand command = SqlServerCommandBuilder.MergeIntoMany(connection, "schema", "table", keyProperty, properties, records); string expectedCommand = @@ -262,7 +250,7 @@ WHEN NOT MATCHED THEN SELECT KeyColumn FROM @InsertedKeys; """"; - Assert.Equal(expectedCommand, command.CommandText); + AssertEqualIgnoreNewLines(expectedCommand, command.CommandText); for (int i = 0; i < records.Length; i++) { @@ -327,7 +315,7 @@ public void SelectSingle() using SqlCommand command = SqlServerCommandBuilder.SelectSingle(connection, "schema", "tableName", keyProperty, properties, 123L); - Assert.Equal( + AssertEqualIgnoreNewLines( """"" SELECT [id],[name],[age],[embedding] FROM [schema].[tableName] @@ -356,7 +344,7 @@ public void SelectMany() using SqlCommand command = 
SqlServerCommandBuilder.SelectMany(connection, "schema", "tableName", keyProperty, properties, keys); - Assert.Equal( + AssertEqualIgnoreNewLines( """"" SELECT [id],[name],[age],[embedding] FROM [schema].[tableName] @@ -369,6 +357,12 @@ WHERE [id] IN (@id_0,@id_1,@id_2) } } + // This repo is configured with eol=lf, so the expected string should always use \n + // as long given IDE does not use \r\n. + // The actual string may use \r\n, so we just normalize both. + private static void AssertEqualIgnoreNewLines(string expected, string actual) + => Assert.Equal(expected.Replace("\r\n", "\n"), actual.Replace("\r\n", "\n")); + // We create a connection using a fake connection string just to be able to create the SqlCommand. private static SqlConnection CreateConnection() => new("Server=localhost;Database=master;Integrated Security=True;"); diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index 2a1617eb0b4c..a3c72b9a848f 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -3,6 +3,7 @@ using System.Text.Json; using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.SqlServer; using SqlServerIntegrationTests.Support; using Xunit; @@ -89,6 +90,21 @@ public async Task RecordCRUD() received = await collection.GetAsync(updated.Id); AssertEquality(updated, received); + VectorSearchResult vectorSearchResult = await (await collection.VectorizedSearchAsync(inserted.Floats, new() + { + VectorPropertyName = nameof(TestModel.Floats), + IncludeVectors = true + })).Results.SingleAsync(); + AssertEquality(updated, vectorSearchResult.Record); + + vectorSearchResult = await (await collection.VectorizedSearchAsync(inserted.Floats, 
new() + { + VectorPropertyName = nameof(TestModel.Floats), + IncludeVectors = false + })).Results.SingleAsync(); + // Make sure the vectors are not included in the result. + Assert.Equal(0, vectorSearchResult.Record.Floats.Length); + await collection.DeleteAsync(inserted.Id); Assert.Null(await collection.GetAsync(inserted.Id)); @@ -139,8 +155,7 @@ public async Task WrongModels() // Let's use a model with the same storage names, but different types // to trigger a mapping exception (deserializing a string to Memory). var invalidJsonCollection = testStore.DefaultVectorStore.GetCollection(collectionName); - mappingEx = await Assert.ThrowsAsync(() => invalidJsonCollection.GetAsync(inserted.Id)); - Assert.IsType(mappingEx.InnerException); + await Assert.ThrowsAsync(() => invalidJsonCollection.GetAsync(inserted.Id)); } finally { @@ -150,6 +165,59 @@ public async Task WrongModels() } } + [Fact] + public async Task CustomMapper() + { + string collectionName = GetUniqueCollectionName(); + TestModelMapper mapper = new(); + SqlServerVectorStoreRecordCollectionOptions options = new() + { + Mapper = mapper + }; + using SqlConnection connection = new(SqlServerTestEnvironment.ConnectionString); + SqlServerVectorStoreRecordCollection collection = new(connection, collectionName, options); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + TestModel inserted = new() + { + Id = "MyId", + Number = 100, + Floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray() + }; + string key = await collection.UpsertAsync(inserted); + Assert.Equal(inserted.Id, key); + Assert.True(mapper.MapFromDataToStorageModel_WasCalled); + Assert.False(mapper.MapFromStorageToDataModel_WasCalled); + + TestModel? 
received = await collection.GetAsync(inserted.Id); + AssertEquality(inserted, received); + Assert.True(mapper.MapFromStorageToDataModel_WasCalled); + + TestModel updated = new() + { + Id = inserted.Id, + Number = inserted.Number + 200, // change one property + Floats = inserted.Floats + }; + key = await collection.UpsertAsync(updated); + Assert.Equal(inserted.Id, key); + + received = await collection.GetAsync(updated.Id); + AssertEquality(updated, received); + + await collection.DeleteAsync(inserted.Id); + + Assert.Null(await collection.GetAsync(inserted.Id)); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } + [Fact] public async Task BatchCRUD() { @@ -370,4 +438,37 @@ public sealed class FancyTestModel [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] public ReadOnlyMemory Floats { get; set; } } + + private sealed class TestModelMapper : IVectorStoreRecordMapper> + { + internal bool MapFromDataToStorageModel_WasCalled { get; set; } + internal bool MapFromStorageToDataModel_WasCalled { get; set; } + + public IDictionary MapFromDataToStorageModel(TestModel dataModel) + { + MapFromDataToStorageModel_WasCalled = true; + + return new Dictionary() + { + { "key", dataModel.Id }, + { "text", dataModel.Text }, + { "column", dataModel.Number }, + // Please note that we are not dealing with JSON directly here. + { "embedding", dataModel.Floats } + }; + } + + public TestModel MapFromStorageToDataModel(IDictionary storageModel, StorageToDataModelMapperOptions options) + { + MapFromStorageToDataModel_WasCalled = true; + + return new() + { + Id = (string)storageModel["key"]!, + Text = (string?)storageModel["text"], + Number = (int)storageModel["column"]!, + Floats = (ReadOnlyMemory)storageModel["embedding"]! 
+ }; + } + } } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs new file mode 100644 index 000000000000..c6221e9e5f34 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.VectorSearch; +using Xunit; + +namespace SqlServerIntegrationTests.VectorSearch; + +public class SqlServerBasicVectorSearchTests(SqlServerBasicVectorSearchTests.Fixture fixture) + : BasicVectorSearchTests(fixture), IClassFixture +{ + public override Task CosineSimilarity() => Assert.ThrowsAsync(() => base.CosineSimilarity()); + + public override Task DotProductSimilarity() => Assert.ThrowsAsync(() => base.DotProductSimilarity()); + + public override Task EuclideanSquaredDistance() => Assert.ThrowsAsync(() => base.EuclideanSquaredDistance()); + + public override Task Hamming() => Assert.ThrowsAsync(() => base.Hamming()); + + public override Task ManhattanDistance() => Assert.ThrowsAsync(() => base.ManhattanDistance()); + + public new class Fixture : VectorStoreFixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/BasicVectorSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/BasicVectorSearchTests.cs new file mode 100644 index 000000000000..2daf579b45f3 --- /dev/null +++ 
b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/BasicVectorSearchTests.cs @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests.VectorSearch; + +public abstract class BasicVectorSearchTests(VectorStoreFixture fixture) + where TKey : notnull +{ + [ConditionalFact] + public virtual Task CosineDistance() + => this.SimpleSearch(DistanceFunction.CosineDistance, 0, 2, 1, [0, 2, 1]); + + [ConditionalFact] + public virtual Task CosineSimilarity() + => this.SimpleSearch(DistanceFunction.CosineSimilarity, 1, -1, 0, [0, 2, 1]); + + [ConditionalFact] + public virtual Task DotProductSimilarity() + => this.SimpleSearch(DistanceFunction.DotProductSimilarity, 1, -1, 0, [0, 2, 1]); + + [ConditionalFact] + public virtual Task NegativeDotProductSimilarity() + => this.SimpleSearch(DistanceFunction.NegativeDotProductSimilarity, -1, 1, 0, [1, 2, 0]); + + [ConditionalFact] + public virtual Task EuclideanDistance() + => this.SimpleSearch(DistanceFunction.EuclideanDistance, 0, 2, 1.73, [0, 2, 1]); + + [ConditionalFact] + public virtual Task EuclideanSquaredDistance() + => this.SimpleSearch(DistanceFunction.EuclideanSquaredDistance, 0, 4, 3, [0, 2, 1]); + + [ConditionalFact] + public virtual Task Hamming() + => this.SimpleSearch(DistanceFunction.Hamming, 0, 1, 3, [0, 1, 2]); + + [ConditionalFact] + public virtual Task ManhattanDistance() + => this.SimpleSearch(DistanceFunction.ManhattanDistance, 0, 2, 3, [0, 1, 2]); + + protected async Task SimpleSearch(string distanceFunction, double expectedExactMatchScore, + double expectedOppositeScore, double expectedOrthogonalScore, int[] resultOrder) + { + ReadOnlyMemory baseVector = new([1, 0, 0, 0]); + ReadOnlyMemory oppositeVector = new([-1, 0, 0, 0]); + ReadOnlyMemory orthogonalVector = new([0f, -1f, -1f, 0f]); + + 
double[] scoreDictionary = [expectedExactMatchScore, expectedOppositeScore, expectedOrthogonalScore]; + + List records = + [ + new() + { + Key = fixture.GenerateNextKey(), + Int = 8, + Vector = baseVector, + }, + new() + { + Key = fixture.GenerateNextKey(), + Int = 9, + String = "bar", + Vector = oppositeVector, + }, + new() + { + Key = fixture.GenerateNextKey(), + Int = 9, + String = "foo", + Vector = orthogonalVector, + } + ]; + + // The record definition describes the distance function, + // so we need a dedicated collection per test. + string uniqueCollectionName = Guid.NewGuid().ToString(); + var collection = fixture.TestStore.DefaultVectorStore.GetCollection( + uniqueCollectionName, GetRecordDefinition(distanceFunction)); + + await collection.CreateCollectionAsync(); + + try + { + await collection.UpsertBatchAsync(records).ToArrayAsync(); + + var searchResult = await collection.VectorizedSearchAsync(baseVector); + var results = await searchResult.Results.ToListAsync(); + VerifySearchResults(resultOrder, scoreDictionary, records, results, includeVectors: false); + + searchResult = await collection.VectorizedSearchAsync(baseVector, new() { IncludeVectors = true}); + results = await searchResult.Results.ToListAsync(); + VerifySearchResults(resultOrder, scoreDictionary, records, results, includeVectors: true); + } + finally + { + collection.DeleteCollectionAsync(); + } + + static void VerifySearchResults(int[] resultOrder, double[] scoreDictionary, List records, + List> results, bool includeVectors) + { + Assert.Equal(records.Count, results.Count); + for (int i = 0; i < results.Count; i++) + { + Assert.Equal(records[resultOrder[i]].Key, results[i].Record.Key); + Assert.Equal(Math.Round(scoreDictionary[resultOrder[i]], 2), Math.Round(results[i].Score!.Value, 2)); + + if (includeVectors) + { + Assert.Equal(records[resultOrder[i]].Vector.ToArray(), results[i].Record.Vector.ToArray()); + } + else + { + Assert.Equal(0, results[i].Record.Vector.Length); + } + } + } + } 
+ + private VectorStoreRecordDefinition GetRecordDefinition(string distanceFunction) + => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(SearchRecord.Key), typeof(TKey)), + new VectorStoreRecordVectorProperty(nameof(SearchRecord.Vector), typeof(ReadOnlyMemory)) + { + Dimensions = 4, + DistanceFunction = distanceFunction, + }, + new VectorStoreRecordDataProperty(nameof(SearchRecord.Int), typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(SearchRecord.String), typeof(string)) { IsFilterable = true }, + ] + }; + + public class SearchRecord + { + public TKey Key { get; set; } = default!; + public ReadOnlyMemory Vector { get; set; } + + public int Int { get; set; } + public string? String { get; set; } + } +} From 30813054522ea34c16ba519d30343a9cac39612e Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 26 Feb 2025 10:46:43 +0100 Subject: [PATCH 24/32] implement IndexKind support for SqlServer and fix it for PostgreSQL: - escape the index name (by quoting it) - don't create index if it already exists as part of CreateCollectionIfNotExistsAsync - don't wrap NotSupportedException with VectorStoreOperationException when the user tried to create a query that is not supported like non-supported distance function for given index) --- ...PostgresVectorStoreCollectionSqlBuilder.cs | 3 +- ...PostgresVectorStoreCollectionSqlBuilder.cs | 4 +- .../PostgresVectorStoreDbClient.cs | 2 +- .../PostgresVectorStoreRecordCollection.cs | 4 +- .../ExceptionWrapper.cs | 6 +-- .../RecordMapper.cs | 1 - .../SqlDataReaderDictionary.cs | 15 +++--- .../SqlServerCommandBuilder.cs | 41 +++++++++----- .../SqlServerVectorStoreRecordCollection.cs | 6 ++- ...resVectorStoreCollectionSqlBuilderTests.cs | 30 ++++++++--- .../Support/PostgresFixture.cs | 10 ++++ .../PostgresBasicVectorSearchTests.cs | 8 +-- .../PostgresBasicVectorSearchTests_Hnsw.cs | 10 ++++ .../SqlServerCommandBuilderTests.cs | 4 +- .../SqlServerVectorStoreTests.cs | 54 
++++++------------- .../Support/SqlServerFixture.cs | 15 ++++++ .../SqlServerBasicVectorSearchTests.cs | 24 +++------ .../SqlServerBasicVectorSearchTests_Hnsw.cs | 31 +++++++++++ .../VectorSearch/BasicVectorSearchTests.cs | 13 +++-- 19 files changed, 174 insertions(+), 107 deletions(-) create mode 100644 dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests_Hnsw.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests_Hnsw.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs index 3c864cc6537f..0175243131cd 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -48,8 +48,9 @@ internal interface IPostgresVectorStoreCollectionSqlBuilder /// The name of the vector column. /// The kind of index to create. /// The distance function to use for the index. + /// Specifies whether to include IF NOT EXISTS in the command. /// The built SQL command info. - PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction); + PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction, bool ifNotExists); /// /// Builds a SQL command to drop a table in the Postgres vector store. 
diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index 71d1448c85cc..74d320167c6c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -124,7 +124,7 @@ public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tabl } /// - public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction) + public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction, bool ifNotExists) { // Only support creating HNSW index creation through the connector. var indexTypeName = indexKind switch @@ -149,7 +149,7 @@ public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, strin return new PostgresSqlCommandInfo( commandText: $@" - CREATE INDEX {indexName} ON {schema}.""{tableName}"" USING {indexTypeName} (""{vectorColumnName}"" {indexOps});" + CREATE INDEX {(ifNotExists ? 
"IF NOT EXISTS " : "")} ""{indexName}"" ON {schema}.""{tableName}"" USING {indexTypeName} (""{vectorColumnName}"" {indexOps});" ); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index b97b24708b25..7a2ee0604274 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -74,7 +74,7 @@ public async Task CreateTableAsync(string tableName, IReadOnlyList - this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, index.column, index.kind, index.function) + this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, index.column, index.kind, index.function, ifNotExists) ); // Execute the commands in a transaction. diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 664a2d41a917..466c66b90349 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -319,7 +319,7 @@ private async Task RunOperationAsync(string operationName, Func operation) { await operation.Invoke().ConfigureAwait(false); } - catch (Exception ex) + catch (Exception ex) when (ex is not NotSupportedException) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { @@ -336,7 +336,7 @@ private async Task RunOperationAsync(string operationName, Func> o { return await operation.Invoke().ConfigureAwait(false); } - catch (Exception ex) + catch (Exception ex) when (ex is not NotSupportedException) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { diff --git 
a/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs index 4a5fa37c0829..452887ea7dd1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs @@ -1,10 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Data.SqlClient; @@ -12,6 +8,8 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; +#pragma warning disable CA1068 // CancellationToken parameters must come last + internal static class ExceptionWrapper { private const string VectorStoreType = "SqlServer"; diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs index 0703e2b536f8..b2a5cb6d2cec 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Reflection; -using System.Text.Json; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.SqlServer; diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs index 62e8637e26fd..3c10c8f31d46 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.Text.Json; using Microsoft.Data.SqlClient; -using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.SqlServer; @@ -63,13 
+62,13 @@ internal SqlDataReaderDictionary(SqlDataReader sqlDataReader, IReadOnlyList this.Unwrap(key, _sqlDataReader[key]); + get => this.Unwrap(key, this._sqlDataReader[key]); set => throw new InvalidOperationException(); } - public ICollection Keys => GetDictionary().Keys; + public ICollection Keys => this.GetDictionary().Keys; - public ICollection Values => GetDictionary().Values; + public ICollection Values => this.GetDictionary().Values; public int Count => this._sqlDataReader.FieldCount; @@ -82,7 +81,7 @@ public object? this[string key] public void Clear() => throw new InvalidOperationException(); public bool Contains(KeyValuePair item) - => TryGetValue(item.Key, out var value) && Equals(value, item.Value); + => this.TryGetValue(item.Key, out var value) && Equals(value, item.Value); public bool ContainsKey(string key) { @@ -97,13 +96,13 @@ public bool ContainsKey(string key) } public void CopyTo(KeyValuePair[] array, int arrayIndex) - => ((ICollection>)GetDictionary()).CopyTo(array, arrayIndex); + => ((ICollection>)this.GetDictionary()).CopyTo(array, arrayIndex); public IEnumerator> GetEnumerator() - => GetDictionary().GetEnumerator(); + => this.GetDictionary().GetEnumerator(); IEnumerator IEnumerable.GetEnumerator() - => GetDictionary().GetEnumerator(); + => this.GetDictionary().GetEnumerator(); public bool Remove(string key) => throw new InvalidOperationException(); diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index ca1a9bc32cdf..cb11984aba6e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -30,6 +30,7 @@ internal static SqlCommand CreateTable( sb.AppendTableName(schema, tableName); sb.AppendLine("', N'U') IS NULL"); } + sb.AppendLine("BEGIN"); sb.Append("CREATE TABLE "); sb.AppendTableName(schema, tableName); 
sb.AppendLine(" ("); @@ -49,7 +50,21 @@ internal static SqlCommand CreateTable( } sb.AppendFormat("PRIMARY KEY NONCLUSTERED ([{0}])", keyColumnName); sb.AppendLine(); - sb.Append(')'); // end the table definition + sb.AppendLine(");"); // end the table definition + + foreach (var vectorProperty in vectorProperties) + { + switch (vectorProperty.IndexKind) + { + case null: + case "": + case IndexKind.Flat: + break; + default: + throw new NotSupportedException($"Index kind {vectorProperty.IndexKind} is not supported."); + } + } + sb.Append("END;"); return connection.CreateCommand(sb); } @@ -311,17 +326,7 @@ internal static SqlCommand SelectVector( { string distanceFunction = vectorProperty.DistanceFunction ?? DistanceFunction.CosineDistance; // Source: https://learn.microsoft.com/sql/t-sql/functions/vector-distance-transact-sql - (string distanceMetric, string sorting) = distanceFunction switch - { - // A value of 0 indicates that the vectors are identical in direction (cosine similarity of 1), - // while a value of 1 indicates that the vectors are orthogonal (cosine similarity of 0). - DistanceFunction.CosineDistance => ("cosine", "ASC"), - // A value of 0 indicates that the vectors are identical, while larger values indicate greater dissimilarity. - DistanceFunction.EuclideanDistance => ("euclidean", "ASC"), - // A value closer to 0 indicates higher similarity, while more negative values indicate greater dissimilarity. - DistanceFunction.NegativeDotProductSimilarity => ("dot", "DESC"), - _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") - }; + (string distanceMetric, string sorting) = MapDistanceFunction(distanceFunction); SqlCommand command = connection.CreateCommand(); command.Parameters.AddWithValue("@vector", JsonSerializer.Serialize(vector)); @@ -519,4 +524,16 @@ private static (string sqlName, string? 
autoGenerate) Map(Type type) _ => throw new NotSupportedException($"Type {type} is not supported.") }; } + + private static (string distanceMetric, string sorting) MapDistanceFunction(string name) => name switch + { + // A value of 0 indicates that the vectors are identical in direction (cosine similarity of 1), + // while a value of 1 indicates that the vectors are orthogonal (cosine similarity of 0). + DistanceFunction.CosineDistance => ("COSINE", "ASC"), + // A value of 0 indicates that the vectors are identical, while larger values indicate greater dissimilarity. + DistanceFunction.EuclideanDistance => ("EUCLIDEAN", "ASC"), + // A value closer to 0 indicates higher similarity, while more negative values indicate greater dissimilarity. + DistanceFunction.NegativeDotProductSimilarity => ("DOT", "DESC"), + _ => throw new NotSupportedException($"Distance function {name} is not supported.") + }; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 1e4bd541e4f2..e96201ccc03a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -14,8 +14,10 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; /// /// An implementation of backed by a SQL Server or Azure SQL database. 
/// -public sealed class SqlServerVectorStoreRecordCollection : IVectorStoreRecordCollection - where TKey : notnull +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix (Collection) +public sealed class SqlServerVectorStoreRecordCollection +#pragma warning restore CA1711 + : IVectorStoreRecordCollection where TKey : notnull { private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); private static readonly SqlServerVectorStoreRecordCollectionOptions s_defaultOptions = new(); diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs index e1958f934c5d..60dd98f45e7a 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -76,10 +76,13 @@ public void TestBuildCreateTableCommand(bool ifNotExists) } [Theory] - [InlineData(IndexKind.Hnsw, DistanceFunction.EuclideanDistance)] - [InlineData(IndexKind.IvfFlat, DistanceFunction.DotProductSimilarity)] - [InlineData(IndexKind.Hnsw, DistanceFunction.CosineDistance)] - public void TestBuildCreateIndexCommand(string indexKind, string distanceFunction) + [InlineData(IndexKind.Hnsw, DistanceFunction.EuclideanDistance, true)] + [InlineData(IndexKind.Hnsw, DistanceFunction.EuclideanDistance, false)] + [InlineData(IndexKind.IvfFlat, DistanceFunction.DotProductSimilarity, true)] + [InlineData(IndexKind.IvfFlat, DistanceFunction.DotProductSimilarity, false)] + [InlineData(IndexKind.Hnsw, DistanceFunction.CosineDistance, true)] + [InlineData(IndexKind.Hnsw, DistanceFunction.CosineDistance, false)] + public void TestBuildCreateIndexCommand(string indexKind, string distanceFunction, bool ifNotExists) { var builder = new PostgresVectorStoreCollectionSqlBuilder(); @@ -87,15 
+90,28 @@ public void TestBuildCreateIndexCommand(string indexKind, string distanceFunctio if (indexKind != IndexKind.Hnsw) { - Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction)); + Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction, ifNotExists)); + Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction, ifNotExists)); return; } - var cmdInfo = builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction); + var cmdInfo = builder.BuildCreateVectorIndexCommand("public", "1testcollection", vectorColumn, indexKind, distanceFunction, ifNotExists); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("CREATE INDEX ", cmdInfo.CommandText); - Assert.Contains("ON public.\"testcollection\" USING hnsw (\"embedding1\" ", cmdInfo.CommandText); + // Make sure ifNotExists is respected + if (ifNotExists) + { + Assert.Contains("CREATE INDEX IF NOT EXISTS", cmdInfo.CommandText); + } + else + { + Assert.DoesNotContain("CREATE INDEX IF NOT EXISTS", cmdInfo.CommandText); + } + // Make sure the name is escaped, so names starting with a digit are OK. 
+ Assert.Contains($"\"1testcollection_{vectorColumn}_index\"", cmdInfo.CommandText); + + Assert.Contains("ON public.\"1testcollection\" USING hnsw (\"embedding1\" ", cmdInfo.CommandText); if (distanceFunction == null) { // Check for distance function defaults to cosine distance diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresFixture.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresFixture.cs new file mode 100644 index 000000000000..6c8ce87ad984 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace PostgresIntegrationTests.Support; + +public class PostgresFixture : VectorStoreFixture +{ + public override TestStore TestStore => PostgresTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs index 823f421d0f1e..60b1449380f3 100644 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs @@ -7,17 +7,11 @@ namespace PostgresIntegrationTests.VectorSearch; -public class PostgresBasicVectorSearchTests(PostgresBasicVectorSearchTests.Fixture fixture) - : BasicVectorSearchTests(fixture), IClassFixture +public class PostgresBasicVectorSearchTests(PostgresFixture fixture) : BasicVectorSearchTests(fixture), IClassFixture { public override Task EuclideanSquaredDistance() => Assert.ThrowsAsync(() => base.EuclideanSquaredDistance()); public override Task Hamming() => Assert.ThrowsAsync(() => base.Hamming()); public override Task 
NegativeDotProductSimilarity() => Assert.ThrowsAsync(() => base.NegativeDotProductSimilarity()); - - public new class Fixture : VectorStoreFixture - { - public override TestStore TestStore => PostgresTestStore.Instance; - } } diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests_Hnsw.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests_Hnsw.cs new file mode 100644 index 000000000000..81bad383df99 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests_Hnsw.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; + +namespace PostgresIntegrationTests.VectorSearch; + +public class PostgresBasicVectorSearchTests_Hnsw(PostgresFixture fixture) : PostgresBasicVectorSearchTests(fixture) +{ + protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 58de196220a7..06039d56af8a 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -128,13 +128,15 @@ public void CreateTable(bool ifNotExists) string expectedCommand = """ + BEGIN CREATE TABLE [schema].[table] ( [id] BIGINT IDENTITY(1,1), [simpleName] NVARCHAR(255) COLLATE Latin1_General_100_BIN2, [with space] INT, [embedding] VECTOR(10), PRIMARY KEY NONCLUSTERED ([id]) - ) + ); + END; """; if (ifNotExists) { diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs 
b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index a3c72b9a848f..983a382a043d 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -5,23 +5,21 @@ using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.SqlServer; using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.Xunit; using Xunit; namespace SqlServerIntegrationTests; -public class SqlServerVectorStoreTests +public class SqlServerVectorStoreTests(SqlServerFixture fixture) : IClassFixture { // this test may be once executed by multiple users against a shared db instance private static string GetUniqueCollectionName() => Guid.NewGuid().ToString(); - [Fact] + [ConditionalFact] public async Task CollectionCRUD() { string collectionName = GetUniqueCollectionName(); - SqlServerTestStore testStore = new(); - - await testStore.ReferenceCountingStartAsync(); - + var testStore = fixture.TestStore; var collection = testStore.DefaultVectorStore.GetCollection(collectionName); try @@ -47,19 +45,14 @@ public async Task CollectionCRUD() finally { await collection.DeleteCollectionAsync(); - - await testStore.ReferenceCountingStopAsync(); } } - [Fact] + [ConditionalFact] public async Task RecordCRUD() { string collectionName = GetUniqueCollectionName(); - SqlServerTestStore testStore = new(); - - await testStore.ReferenceCountingStartAsync(); - + var testStore = fixture.TestStore; var collection = testStore.DefaultVectorStore.GetCollection(collectionName); try @@ -112,19 +105,14 @@ public async Task RecordCRUD() finally { await collection.DeleteCollectionAsync(); - - await testStore.ReferenceCountingStopAsync(); } } - [Fact] + [ConditionalFact] public async Task WrongModels() { string collectionName = GetUniqueCollectionName(); - SqlServerTestStore testStore = new(); - - await 
testStore.ReferenceCountingStartAsync(); - + var testStore = fixture.TestStore; var collection = testStore.DefaultVectorStore.GetCollection(collectionName); try @@ -160,12 +148,10 @@ public async Task WrongModels() finally { await collection.DeleteCollectionAsync(); - - await testStore.ReferenceCountingStopAsync(); } } - [Fact] + [ConditionalFact] public async Task CustomMapper() { string collectionName = GetUniqueCollectionName(); @@ -218,14 +204,11 @@ public async Task CustomMapper() } } - [Fact] + [ConditionalFact] public async Task BatchCRUD() { string collectionName = GetUniqueCollectionName(); - SqlServerTestStore testStore = new(); - - await testStore.ReferenceCountingStartAsync(); - + var testStore = fixture.TestStore; var collection = testStore.DefaultVectorStore.GetCollection(collectionName); try @@ -277,8 +260,6 @@ public async Task BatchCRUD() finally { await collection.DeleteCollectionAsync(); - - await testStore.ReferenceCountingStopAsync(); } } @@ -339,22 +320,19 @@ public sealed class DifferentStorageNames public ReadOnlyMemory Floats { get; set; } } - [Fact] + [ConditionalFact] public Task CanUseFancyModels_Int() => this.CanUseFancyModels(); - [Fact] + [ConditionalFact] public Task CanUseFancyModels_Long() => this.CanUseFancyModels(); - [Fact] + [ConditionalFact] public Task CanUseFancyModels_Guid() => this.CanUseFancyModels(); private async Task CanUseFancyModels() where TKey : notnull { string collectionName = GetUniqueCollectionName(); - SqlServerTestStore testStore = new(); - - await testStore.ReferenceCountingStartAsync(); - + var testStore = fixture.TestStore; var collection = testStore.DefaultVectorStore.GetCollection>(collectionName); try @@ -396,8 +374,6 @@ private async Task CanUseFancyModels() where TKey : notnull finally { await collection.DeleteCollectionAsync(); - - await testStore.ReferenceCountingStopAsync(); } void AssertEquality(FancyTestModel expected, FancyTestModel? 
received, TKey expectedKey) diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerFixture.cs new file mode 100644 index 000000000000..63bdb2360d06 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerFixture.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using VectorDataSpecificationTests.Support; + +namespace SqlServerIntegrationTests.Support; + +public class SqlServerFixture : VectorStoreFixture +{ + public override TestStore TestStore => SqlServerTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs index c6221e9e5f34..c101e8667c4e 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs @@ -1,10 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. 
-using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; using SqlServerIntegrationTests.Support; using VectorDataSpecificationTests.Support; using VectorDataSpecificationTests.VectorSearch; @@ -12,21 +7,16 @@ namespace SqlServerIntegrationTests.VectorSearch; -public class SqlServerBasicVectorSearchTests(SqlServerBasicVectorSearchTests.Fixture fixture) - : BasicVectorSearchTests(fixture), IClassFixture +public class SqlServerBasicVectorSearchTests(SqlServerFixture fixture) + : BasicVectorSearchTests(fixture), IClassFixture { - public override Task CosineSimilarity() => Assert.ThrowsAsync(() => base.CosineSimilarity()); + public override Task CosineSimilarity() => Assert.ThrowsAsync(base.CosineSimilarity); - public override Task DotProductSimilarity() => Assert.ThrowsAsync(() => base.DotProductSimilarity()); + public override Task DotProductSimilarity() => Assert.ThrowsAsync(base.DotProductSimilarity); - public override Task EuclideanSquaredDistance() => Assert.ThrowsAsync(() => base.EuclideanSquaredDistance()); + public override Task EuclideanSquaredDistance() => Assert.ThrowsAsync(base.EuclideanSquaredDistance); - public override Task Hamming() => Assert.ThrowsAsync(() => base.Hamming()); + public override Task Hamming() => Assert.ThrowsAsync(base.Hamming); - public override Task ManhattanDistance() => Assert.ThrowsAsync(() => base.ManhattanDistance()); - - public new class Fixture : VectorStoreFixture - { - public override TestStore TestStore => SqlServerTestStore.Instance; - } + public override Task ManhattanDistance() => Assert.ThrowsAsync(base.ManhattanDistance); } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests_Hnsw.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests_Hnsw.cs new file mode 100644 index 000000000000..1f37ce71945f --- /dev/null +++ 
b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests_Hnsw.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServerIntegrationTests.Support; +using Xunit; + +namespace SqlServerIntegrationTests.VectorSearch; + +public class SqlServerBasicVectorSearchTests_Hnsw(SqlServerFixture fixture) + : SqlServerBasicVectorSearchTests(fixture) +{ + // Creating such a collection is not supported. + protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; + + public override async Task CosineDistance() + { + NotSupportedException ex = await Assert.ThrowsAsync(() => base.CosineDistance()); + Assert.Equal($"Index kind {this.IndexKind} is not supported.", ex.Message); + } + + public override async Task EuclideanDistance() + { + NotSupportedException ex = await Assert.ThrowsAsync(() => base.EuclideanDistance()); + Assert.Equal($"Index kind {this.IndexKind} is not supported.", ex.Message); + } + + public override async Task NegativeDotProductSimilarity() + { + NotSupportedException ex = await Assert.ThrowsAsync(() => base.NegativeDotProductSimilarity()); + Assert.Equal($"Index kind {this.IndexKind} is not supported.", ex.Message); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/BasicVectorSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/BasicVectorSearchTests.cs index 2daf579b45f3..f2079c791b59 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/BasicVectorSearchTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/BasicVectorSearchTests.cs @@ -42,6 +42,8 @@ public virtual Task Hamming() public virtual Task ManhattanDistance() => this.SimpleSearch(DistanceFunction.ManhattanDistance, 0, 2, 3, [0, 1, 2]); + protected virtual string? 
IndexKind => null; + protected async Task SimpleSearch(string distanceFunction, double expectedExactMatchScore, double expectedOppositeScore, double expectedOrthogonalScore, int[] resultOrder) { @@ -79,10 +81,12 @@ protected async Task SimpleSearch(string distanceFunction, double expectedExactM // so we need a dedicated collection per test. string uniqueCollectionName = Guid.NewGuid().ToString(); var collection = fixture.TestStore.DefaultVectorStore.GetCollection( - uniqueCollectionName, GetRecordDefinition(distanceFunction)); + uniqueCollectionName, this.GetRecordDefinition(distanceFunction)); await collection.CreateCollectionAsync(); + await collection.CreateCollectionIfNotExistsAsync(); // just to make sure it's idempotent + try { await collection.UpsertBatchAsync(records).ToArrayAsync(); @@ -91,13 +95,13 @@ protected async Task SimpleSearch(string distanceFunction, double expectedExactM var results = await searchResult.Results.ToListAsync(); VerifySearchResults(resultOrder, scoreDictionary, records, results, includeVectors: false); - searchResult = await collection.VectorizedSearchAsync(baseVector, new() { IncludeVectors = true}); + searchResult = await collection.VectorizedSearchAsync(baseVector, new() { IncludeVectors = true }); results = await searchResult.Results.ToListAsync(); VerifySearchResults(resultOrder, scoreDictionary, records, results, includeVectors: true); } finally { - collection.DeleteCollectionAsync(); + await collection.DeleteCollectionAsync(); } static void VerifySearchResults(int[] resultOrder, double[] scoreDictionary, List records, @@ -107,6 +111,8 @@ static void VerifySearchResults(int[] resultOrder, double[] scoreDictionary, Lis for (int i = 0; i < results.Count; i++) { Assert.Equal(records[resultOrder[i]].Key, results[i].Record.Key); + Assert.Equal(records[resultOrder[i]].Int, results[i].Record.Int); + Assert.Equal(records[resultOrder[i]].String, results[i].Record.String); Assert.Equal(Math.Round(scoreDictionary[resultOrder[i]], 2), 
Math.Round(results[i].Score!.Value, 2)); if (includeVectors) @@ -131,6 +137,7 @@ private VectorStoreRecordDefinition GetRecordDefinition(string distanceFunction) { Dimensions = 4, DistanceFunction = distanceFunction, + IndexKind = this.IndexKind }, new VectorStoreRecordDataProperty(nameof(SearchRecord.Int), typeof(int)) { IsFilterable = true }, new VectorStoreRecordDataProperty(nameof(SearchRecord.String), typeof(string)) { IsFilterable = true }, From 5b843aaac632b880ffc3fcd9aab16e6f17f14272 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 26 Feb 2025 15:43:16 +0100 Subject: [PATCH 25/32] fix the build --- dotnet/src/IntegrationTests/IntegrationTests.csproj | 1 + .../VectorSearch/PostgresBasicVectorSearchTests.cs | 1 - .../SqlServerIntegrationTests.csproj | 2 +- .../SqlServerVectorStoreTests.cs | 5 ++--- .../Support/SqlServerFixture.cs | 5 ----- .../Support/SqlServerTestStore.cs | 11 +++++++---- .../VectorSearch/SqlServerBasicVectorSearchTests.cs | 1 - 7 files changed, 11 insertions(+), 15 deletions(-) diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index 26cfaa1949ae..ff59cf5f3136 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -41,6 +41,7 @@ + diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs index 60b1449380f3..5f3f51494cf4 100644 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. 
using PostgresIntegrationTests.Support; -using VectorDataSpecificationTests.Support; using VectorDataSpecificationTests.VectorSearch; using Xunit; diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj index 54d75b8ebc6a..4752d82818dc 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj @@ -8,7 +8,7 @@ false true - $(NoWarn);CA2007,SKEXP0001,SKEXP0020,VSTHRD111 + $(NoWarn);CA2007,SKEXP0001,SKEXP0020,VSTHRD111;CS1685 b7762d10-e29b-4bb1-8b74-b6d69a667dd4 diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index 983a382a043d..d5d10e8a9e79 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. 
-using System.Text.Json; using Microsoft.Data.SqlClient; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.SqlServer; @@ -422,7 +421,7 @@ private sealed class TestModelMapper : IVectorStoreRecordMapper MapFromDataToStorageModel(TestModel dataModel) { - MapFromDataToStorageModel_WasCalled = true; + this.MapFromDataToStorageModel_WasCalled = true; return new Dictionary() { @@ -436,7 +435,7 @@ private sealed class TestModelMapper : IVectorStoreRecordMapper storageModel, StorageToDataModelMapperOptions options) { - MapFromStorageToDataModel_WasCalled = true; + this.MapFromStorageToDataModel_WasCalled = true; return new() { diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerFixture.cs index 63bdb2360d06..dabf7b40609e 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerFixture.cs @@ -1,10 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. 
-using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; using VectorDataSpecificationTests.Support; namespace SqlServerIntegrationTests.Support; diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs index e3bfceb54bc1..211bc006d5d1 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs @@ -7,7 +7,7 @@ namespace SqlServerIntegrationTests.Support; -public sealed class SqlServerTestStore : TestStore +public sealed class SqlServerTestStore : TestStore, IDisposable { public static readonly SqlServerTestStore Instance = new(); @@ -18,16 +18,19 @@ public override IVectorStore DefaultVectorStore private SqlServerVectorStore? _connectedStore; - protected override async Task StartAsync() + protected override Task StartAsync() { if (string.IsNullOrWhiteSpace(SqlServerTestEnvironment.ConnectionString)) { throw new InvalidOperationException("Connection string is not configured, set the SqlServer:ConnectionString environment variable"); } +#pragma warning disable CA2000 // Dispose objects before losing scope SqlConnection connection = new(SqlServerTestEnvironment.ConnectionString); - await connection.OpenAsync(); - +#pragma warning restore CA2000 // Dispose objects before losing scope this._connectedStore = new(connection); + return connection.OpenAsync(); } + + public void Dispose() => this._connectedStore?.Dispose(); } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs index c101e8667c4e..c321d4b51858 100644 --- 
a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using SqlServerIntegrationTests.Support; -using VectorDataSpecificationTests.Support; using VectorDataSpecificationTests.VectorSearch; using Xunit; From 8bb8aeae459ebe3431801afa002b98c09d877d08 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 26 Feb 2025 17:36:01 +0100 Subject: [PATCH 26/32] throw for null inputs, do nothing for empty ones --- ...PostgresVectorStoreCollectionSqlBuilder.cs | 9 -- .../PostgresVectorStoreDbClient.cs | 18 +++- .../PostgresVectorStoreRecordCollection.cs | 9 ++ .../SqlServerCommandBuilder.cs | 30 +++--- .../SqlServerVectorStoreRecordCollection.cs | 21 +++- .../CRUD/PostgresBasicConformanceTests.cs | 11 ++ .../CRUD/SqlServerBasicConformanceTests.cs | 11 ++ .../SqlServerCommandBuilderTests.cs | 6 +- .../CRUD/BasicConformanceTests.cs | 102 ++++++++++++++++++ 9 files changed, 188 insertions(+), 29 deletions(-) create mode 100644 dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBasicConformanceTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBasicConformanceTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BasicConformanceTests.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index 74d320167c6c..aa6b1ce1415c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -281,11 +281,6 @@ public PostgresSqlCommandInfo 
BuildGetBatchCommand(string schema, string t { NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); - if (keys == null || keys.Count == 0) - { - throw new ArgumentException("Keys cannot be null or empty", nameof(keys)); - } - var keyProperty = properties.OfType().FirstOrDefault() ?? throw new ArgumentException("Properties must contain a key property", nameof(properties)); var keyColumn = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; @@ -327,10 +322,6 @@ DELETE FROM {schema}."{tableName}" public PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, string tableName, string keyColumn, List keys) { NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); - if (keys == null || keys.Count == 0) - { - throw new ArgumentException("Keys cannot be null or empty", nameof(keys)); - } for (int i = 0; i < keys.Count; i++) { diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index 7a2ee0604274..07c228540038 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -152,11 +152,19 @@ public async Task UpsertBatchAsync(string tableName, IEnumerable> GetBatchAsync(string tableName, IEnumerable keys, IReadOnlyList properties, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) where TKey : notnull { + Verify.NotNull(keys); + + List listOfKeys = keys.ToList(); + if (listOfKeys.Count == 0) + { + yield break; + } + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await 
using (connection) { - var commandInfo = this._sqlBuilder.BuildGetBatchCommand(this._schema, tableName, properties, keys.ToList(), includeVectors); + var commandInfo = this._sqlBuilder.BuildGetBatchCommand(this._schema, tableName, properties, listOfKeys, includeVectors); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) @@ -198,7 +206,13 @@ public async Task DeleteAsync(string tableName, string keyColumn, TKey key /// public async Task DeleteBatchAsync(string tableName, string keyColumn, IEnumerable keys, CancellationToken cancellationToken = default) { - var commandInfo = this._sqlBuilder.BuildDeleteBatchCommand(this._schema, tableName, keyColumn, keys.ToList()); + var listOfKeys = keys.ToList(); + if (listOfKeys.Count == 0) + { + return; + } + + var commandInfo = this._sqlBuilder.BuildDeleteBatchCommand(this._schema, tableName, keyColumn, listOfKeys); await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 466c66b90349..66ebd32b5762 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -168,6 +168,8 @@ public virtual Task UpsertAsync(TRecord record, CancellationToken cancella /// public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + Verify.NotNull(records); + const string OperationName = "UpsertBatch"; var storageModels = records.Select(record => VectorStoreErrorHandler.RunModelConversion( @@ -176,6 
+178,11 @@ public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable this._mapper.MapFromDataToStorageModel(record))).ToList(); + if (storageModels.Count == 0) + { + yield break; + } + var keys = storageModels.Select(model => model[this._propertyReader.KeyPropertyStoragePropertyName]!).ToList(); await this.RunOperationAsync(OperationName, () => @@ -243,6 +250,8 @@ public virtual Task DeleteAsync(TKey key, CancellationToken cancellationToken = /// public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) { + Verify.NotNull(keys); + const string OperationName = "DeleteBatch"; return this.RunOperationAsync(OperationName, () => this._client.DeleteBatchAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, keys, cancellationToken) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index cb11984aba6e..5915324f5148 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -161,7 +161,7 @@ internal static SqlCommand MergeIntoSingle( return command; } - internal static SqlCommand MergeIntoMany( + internal static SqlCommand? MergeIntoMany( SqlConnection connection, string schema, string tableName, @@ -196,7 +196,7 @@ internal static SqlCommand MergeIntoMany( if (rowIndex == 0) { - throw new ArgumentException("The value cannot be empty.", nameof(records)); + return null; // there is nothing to do! } sb.Length -= (1 + Environment.NewLine.Length); // remove the last comma and newline @@ -252,7 +252,7 @@ internal static SqlCommand DeleteSingle( return command; } - internal static SqlCommand DeleteMany( + internal static SqlCommand? 
DeleteMany( SqlConnection connection, string schema, string tableName, VectorStoreRecordKeyProperty keyProperty, IEnumerable keys) { @@ -262,9 +262,14 @@ internal static SqlCommand DeleteMany( sb.Append("DELETE FROM "); sb.AppendTableName(schema, tableName); sb.AppendFormat(" WHERE [{0}] IN (", GetColumnName(keyProperty)); - sb.AppendKeyParameterList(keys, command, keyProperty); + sb.AppendKeyParameterList(keys, command, keyProperty, out bool emptyKeys); sb.Append(')'); // close the IN clause + if (emptyKeys) + { + return null; // there is nothing to do! + } + command.CommandText = sb.ToString(); return command; } @@ -293,7 +298,7 @@ internal static SqlCommand SelectSingle( return command; } - internal static SqlCommand SelectMany( + internal static SqlCommand? SelectMany( SqlConnection connection, string schema, string tableName, VectorStoreRecordKeyProperty keyProperty, IReadOnlyList properties, @@ -309,9 +314,14 @@ internal static SqlCommand SelectMany( sb.AppendTableName(schema, tableName); sb.AppendLine(); sb.AppendFormat("WHERE [{0}] IN (", GetColumnName(keyProperty)); - sb.AppendKeyParameterList(keys, command, keyProperty); + sb.AppendKeyParameterList(keys, command, keyProperty, out bool emptyKeys); sb.Append(')'); // close the IN clause + if (emptyKeys) + { + return null; // there is nothing to do! 
+ } + command.CommandText = sb.ToString(); return command; } @@ -449,7 +459,7 @@ private static StringBuilder AppendColumnNames(this StringBuilder sb, } private static StringBuilder AppendKeyParameterList(this StringBuilder sb, - IEnumerable keys, SqlCommand command, VectorStoreRecordKeyProperty keyProperty) + IEnumerable keys, SqlCommand command, VectorStoreRecordKeyProperty keyProperty, out bool emptyKeys) { int keyIndex = 0; foreach (TKey key in keys) @@ -463,11 +473,7 @@ private static StringBuilder AppendKeyParameterList(this StringBuilder sb, command.AddParameter(keyProperty, keyParamName, key); } - if (keyIndex == 0) - { - throw new ArgumentException("The value cannot be empty.", nameof(keys)); - } - + emptyKeys = keyIndex == 0; sb.Length--; // remove the last comma return sb; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index e96201ccc03a..5ca6bafedbd1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -167,13 +167,18 @@ public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken can { Verify.NotNull(keys); - using SqlCommand command = SqlServerCommandBuilder.DeleteMany( + using SqlCommand? 
command = SqlServerCommandBuilder.DeleteMany( this._sqlConnection, this._options.Schema, this.CollectionName, this._propertyReader.KeyProperty, keys); + if (command is null) + { + return; // keys is empty, there is nothing to delete + } + await ExceptionWrapper.WrapAsync(this._sqlConnection, command, static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), cancellationToken, "DeleteBatch", this.CollectionName).ConfigureAwait(false); @@ -213,7 +218,7 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get { Verify.NotNull(keys); - using SqlCommand command = SqlServerCommandBuilder.SelectMany( + using SqlCommand? command = SqlServerCommandBuilder.SelectMany( this._sqlConnection, this._options.Schema, this.CollectionName, @@ -221,6 +226,11 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get this._propertyReader.Properties, keys); + if (command is null) + { + yield break; // keys is empty + } + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command, static (cmd, ct) => cmd.ExecuteReaderAsync(ct), cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false); @@ -261,7 +271,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record { Verify.NotNull(records); - using SqlCommand command = SqlServerCommandBuilder.MergeIntoMany( + using SqlCommand? 
command = SqlServerCommandBuilder.MergeIntoMany( this._sqlConnection, this._options.Schema, this.CollectionName, @@ -269,6 +279,11 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record this._propertyReader.Properties, records.Select(record => this._mapper.MapFromDataToStorageModel(record))); + if (command is null) + { + yield break; // records is empty + } + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command, static (cmd, ct) => cmd.ExecuteReaderAsync(ct), cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false); diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBasicConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBasicConformanceTests.cs new file mode 100644 index 000000000000..687bd680f382 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBasicConformanceTests.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace PostgresIntegrationTests.CRUD; + +public class PostgresBasicConformanceTests(PostgresFixture fixture) : BasicConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBasicConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBasicConformanceTests.cs new file mode 100644 index 000000000000..c06fd9bcc9f6 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBasicConformanceTests.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace SqlServerIntegrationTests.CRUD; + +public class SqlServerBasicConformanceTests(SqlServerFixture fixture) : BasicConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 06039d56af8a..08866cc74051 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -233,7 +233,7 @@ public void MergeIntoMany() using SqlConnection connection = CreateConnection(); using SqlCommand command = SqlServerCommandBuilder.MergeIntoMany(connection, "schema", "table", - keyProperty, properties, records); + keyProperty, properties, records)!; string expectedCommand = """" @@ -289,7 +289,7 @@ public void DeleteMany() using SqlConnection connection = CreateConnection(); using SqlCommand command = SqlServerCommandBuilder.DeleteMany(connection, - "schema", "tableName", keyProperty, keys); + "schema", "tableName", keyProperty, keys)!; Assert.Equal("DELETE FROM [schema].[tableName] WHERE [id] IN (@id_0,@id_1)", command.CommandText); for (int i = 0; i < keys.Length; i++) @@ -344,7 +344,7 @@ public void SelectMany() using SqlConnection connection = CreateConnection(); using SqlCommand command = SqlServerCommandBuilder.SelectMany(connection, - "schema", "tableName", keyProperty, properties, keys); + "schema", "tableName", keyProperty, properties, keys)!; AssertEqualIgnoreNewLines( """"" diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BasicConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BasicConformanceTests.cs new file mode 100644 index 
000000000000..f2edc990bfad --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BasicConformanceTests.cs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests.CRUD; + +public abstract class BasicConformanceTests(VectorStoreFixture fixture) +{ + protected virtual string GetUniqueCollectionName() => Guid.NewGuid().ToString(); + + [ConditionalFact] + public async Task UpsertBatchAsync_EmptyBatch_DoesNotThrow() + { + await this.ExecuteAsync(async collection => + { + Assert.Empty(await collection.UpsertBatchAsync([]).ToArrayAsync()); + }); + } + + [ConditionalFact] + public async Task DeleteBatchAsync_EmptyBatch_DoesNotThrow() + { + await this.ExecuteAsync(async collection => + { + await collection.DeleteBatchAsync([]); + }); + } + + [ConditionalFact] + public async Task GetBatchAsync_EmptyBatch_DoesNotThrow() + { + await this.ExecuteAsync(async collection => + { + Assert.Empty(await collection.GetBatchAsync([]).ToArrayAsync()); + }); + } + + [ConditionalFact] + public async Task UpsertBatchAsync_NullBatch_ThrowsArgumentNullException() + { + await this.ExecuteAsync(async collection => + { + ArgumentNullException ex = await Assert.ThrowsAsync(() => collection.UpsertBatchAsync(records: null!).ToArrayAsync().AsTask()); + Assert.Equal("records", ex.ParamName); + }); + } + + [ConditionalFact] + public async Task DeleteBatchAsync_NullKeys_ThrowsArgumentNullException() + { + await this.ExecuteAsync(async collection => + { + ArgumentNullException ex = await Assert.ThrowsAsync(() => collection.DeleteBatchAsync(keys: null!)); + Assert.Equal("keys", ex.ParamName); + }); + } + + [ConditionalFact] + public async Task GetBatchAsync_NullKeys_ThrowsArgumentNullException() + { + await this.ExecuteAsync(async collection => + { + ArgumentNullException ex = 
await Assert.ThrowsAsync(() => collection.GetBatchAsync(keys: null!).ToArrayAsync().AsTask()); + Assert.Equal("keys", ex.ParamName); + }); + } + + private async Task ExecuteAsync(Func, Task> test) + { + string collectionName = this.GetUniqueCollectionName(); + var collection = fixture.TestStore.DefaultVectorStore.GetCollection(collectionName); + + await collection.CreateCollectionAsync(); + + try + { + await test(collection); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } + + public sealed class TestModel + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public string? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "text")] + public string? Text { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "number")] + public int Number { get; set; } + + [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] + public ReadOnlyMemory Floats { get; set; } + } +} From f76b573c8045103f1b12eec3227207f4cce2e0f9 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Fri, 28 Feb 2025 13:08:33 +0100 Subject: [PATCH 27/32] address code review feedback: - add support for VectorStoreGenericDataModel - respect IncludeVectors in Get* methods - rename BasicVectorSearchTests to VectorSearchDistanceFunctionComplianceTests --- .../GenericRecordMapper.cs | 93 +++++++++++++++++++ .../SqlServerCommandBuilder.cs | 10 +- .../SqlServerVectorStoreRecordCollection.cs | 35 +++++-- ...ts.cs => PostgresBatchConformanceTests.cs} | 3 +- ...ostgresGenericDataModelConformanceTests.cs | 12 +++ ...rSearchDistanceFunctionComplianceTests.cs} | 8 +- ...chDistanceFunctionComplianceTests_Hnsw.cs} | 2 +- ...s.cs => SqlServerBatchConformanceTests.cs} | 3 +- ...lServerGenericDataModelConformanceTests.cs | 12 +++ .../SqlServerCommandBuilderTests.cs | 4 +- .../SqlServerVectorStoreTests.cs | 18 ++-- .../Support/SqlServerTestStore.cs | 5 +- ...rSearchDistanceFunctionComplianceTests.cs} | 4 +- ...chDistanceFunctionComplianceTests_Hnsw.cs} | 4 +- 
...manceTests.cs => BatchConformanceTests.cs} | 39 +------- .../CRUD/ConformanceTestsBase.cs | 34 +++++++ .../CRUD/GenericDataModelConformanceTests.cs | 66 +++++++++++++ .../Models/SimpleModel.cs | 25 +++++ ...rSearchDistanceFunctionComplianceTests.cs} | 2 +- 19 files changed, 304 insertions(+), 75 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/GenericRecordMapper.cs rename dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/{PostgresBasicConformanceTests.cs => PostgresBatchConformanceTests.cs} (57%) create mode 100644 dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresGenericDataModelConformanceTests.cs rename dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/{PostgresBasicVectorSearchTests.cs => PostgresVectorSearchDistanceFunctionComplianceTests.cs} (53%) rename dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/{PostgresBasicVectorSearchTests_Hnsw.cs => PostgresVectorSearchDistanceFunctionComplianceTests_Hnsw.cs} (60%) rename dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/{SqlServerBasicConformanceTests.cs => SqlServerBatchConformanceTests.cs} (56%) create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerGenericDataModelConformanceTests.cs rename dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/{SqlServerBasicVectorSearchTests.cs => SqlServerVectorSearchDistanceFunctionComplianceTests.cs} (80%) rename dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/{SqlServerBasicVectorSearchTests_Hnsw.cs => SqlServerVectorSearchDistanceFunctionComplianceTests_Hnsw.cs} (87%) rename dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/{BasicConformanceTests.cs => BatchConformanceTests.cs} (63%) create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/ConformanceTestsBase.cs create mode 
100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/GenericDataModelConformanceTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleModel.cs rename dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/{BasicVectorSearchTests.cs => VectorSearchDistanceFunctionComplianceTests.cs} (98%) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/GenericRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/GenericRecordMapper.cs new file mode 100644 index 000000000000..ff9c7851f4cb --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/GenericRecordMapper.cs @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal sealed class GenericRecordMapper : IVectorStoreRecordMapper, IDictionary> + where TKey : notnull +{ + private readonly VectorStoreRecordPropertyReader _propertyReader; + + internal GenericRecordMapper(VectorStoreRecordPropertyReader propertyReader) => this._propertyReader = propertyReader; + + public IDictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) + { + Dictionary properties = new() + { + { SqlServerCommandBuilder.GetColumnName(this._propertyReader.KeyProperty), dataModel.Key } + }; + + foreach (var property in this._propertyReader.DataProperties) + { + string name = SqlServerCommandBuilder.GetColumnName(property); + if (dataModel.Data.TryGetValue(name, out var dataValue)) + { + properties.Add(name, dataValue); + } + } + + // Add vector properties + if (dataModel.Vectors is not null) + { + foreach (var property in this._propertyReader.VectorProperties) + { + string name = SqlServerCommandBuilder.GetColumnName(property); + if (dataModel.Vectors.TryGetValue(name, out var vectorValue)) + { + if (vectorValue is ReadOnlyMemory 
floats) + { + properties.Add(name, floats); + } + else if (vectorValue is not null) + { + throw new VectorStoreRecordMappingException($"Vector property '{name}' contained value of non supported type: '{vectorValue.GetType().FullName}'."); + } + } + } + } + + return properties; + } + + public VectorStoreGenericDataModel MapFromStorageToDataModel(IDictionary storageModel, StorageToDataModelMapperOptions options) + { + TKey key; + var dataProperties = new Dictionary(); + var vectorProperties = new Dictionary(); + + if (storageModel.TryGetValue(SqlServerCommandBuilder.GetColumnName(this._propertyReader.KeyProperty), out var keyObject) && keyObject is not null) + { + key = (TKey)keyObject; + } + else + { + throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); + } + + foreach (var property in this._propertyReader.DataProperties) + { + string name = SqlServerCommandBuilder.GetColumnName(property); + if (storageModel.TryGetValue(name, out var dataValue)) + { + dataProperties.Add(name, dataValue); + } + } + + if (options.IncludeVectors) + { + foreach (var property in this._propertyReader.VectorProperties) + { + string name = SqlServerCommandBuilder.GetColumnName(property); + if (storageModel.TryGetValue(name, out var vectorValue)) + { + vectorProperties.Add(name, vectorValue); + } + } + } + + return new(key) { Data = dataProperties, Vectors = vectorProperties }; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 5915324f5148..916f895f8429 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -278,14 +278,15 @@ internal static SqlCommand SelectSingle( SqlConnection sqlConnection, string schema, string collectionName, VectorStoreRecordKeyProperty keyProperty, 
IReadOnlyList properties, - object key) + object key, + bool includeVectors) { SqlCommand command = sqlConnection.CreateCommand(); int paramIndex = 0; StringBuilder sb = new(200); sb.AppendFormat("SELECT "); - sb.AppendColumnNames(properties); + sb.AppendColumnNames(properties, includeVectors: includeVectors); sb.AppendLine(); sb.Append("FROM "); sb.AppendTableName(schema, collectionName); @@ -302,13 +303,14 @@ internal static SqlCommand SelectSingle( SqlConnection connection, string schema, string tableName, VectorStoreRecordKeyProperty keyProperty, IReadOnlyList properties, - IEnumerable keys) + IEnumerable keys, + bool includeVectors) { SqlCommand command = connection.CreateCommand(); StringBuilder sb = new(200); sb.AppendFormat("SELECT "); - sb.AppendColumnNames(properties); + sb.AppendColumnNames(properties, includeVectors: includeVectors); sb.AppendLine(); sb.Append("FROM "); sb.AppendTableName(schema, tableName); diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index 5ca6bafedbd1..f8c7f714bad5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -50,11 +50,6 @@ public SqlServerVectorStoreRecordCollection( SupportsMultipleVectors = true, }); - if (options is null || options.Mapper is null) - { - propertyReader.VerifyHasParameterlessConstructor(); - } - HashSet supportedKeyTypes = propertyReader.KeyProperty.AutoGenerate ? SqlServerConstants.SupportedAutoGenerateKeyTypes : SqlServerConstants.SupportedKeyTypes; @@ -82,7 +77,21 @@ public SqlServerVectorStoreRecordCollection( RecordDefinition = options.RecordDefinition, }; this._propertyReader = propertyReader; - this._mapper = options?.Mapper ?? 
new RecordMapper(propertyReader); + + if (options is not null && options.Mapper is not null) + { + this._mapper = options.Mapper; + } + else if (typeof(TRecord).IsGenericType && typeof(TRecord).GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>)) + { + this._mapper = (new GenericRecordMapper(propertyReader) as IVectorStoreRecordMapper>)!; + } + else + { + propertyReader.VerifyHasParameterlessConstructor(); + + this._mapper = new RecordMapper(propertyReader); + } } /// @@ -189,13 +198,16 @@ await ExceptionWrapper.WrapAsync(this._sqlConnection, command, { Verify.NotNull(key); + bool includeVectors = options?.IncludeVectors is true; + using SqlCommand command = SqlServerCommandBuilder.SelectSingle( this._sqlConnection, this._options.Schema, this.CollectionName, this._propertyReader.KeyProperty, this._propertyReader.Properties, - key); + key, + includeVectors); using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command, static async (cmd, ct) => @@ -208,7 +220,7 @@ static async (cmd, ct) => return reader.HasRows ? this._mapper.MapFromStorageToDataModel( new SqlDataReaderDictionary(reader, this._propertyReader.VectorPropertyStoragePropertyNames), - new() { IncludeVectors = true }) + new() { IncludeVectors = includeVectors }) : default; } @@ -218,13 +230,16 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get { Verify.NotNull(keys); + bool includeVectors = options?.IncludeVectors is true; + using SqlCommand? 
command = SqlServerCommandBuilder.SelectMany( this._sqlConnection, this._options.Schema, this.CollectionName, this._propertyReader.KeyProperty, this._propertyReader.Properties, - keys); + keys, + includeVectors); if (command is null) { @@ -239,7 +254,7 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get { yield return this._mapper.MapFromStorageToDataModel( new SqlDataReaderDictionary(reader, this._propertyReader.VectorPropertyStoragePropertyNames), - new() { IncludeVectors = true }); + new() { IncludeVectors = includeVectors }); } } diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBasicConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBatchConformanceTests.cs similarity index 57% rename from dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBasicConformanceTests.cs rename to dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBatchConformanceTests.cs index 687bd680f382..b798bab8e437 100644 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBasicConformanceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBatchConformanceTests.cs @@ -6,6 +6,7 @@ namespace PostgresIntegrationTests.CRUD; -public class PostgresBasicConformanceTests(PostgresFixture fixture) : BasicConformanceTests(fixture), IClassFixture +public class PostgresBatchConformanceTests(PostgresFixture fixture) + : BatchConformanceTests(fixture), IClassFixture { } diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresGenericDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresGenericDataModelConformanceTests.cs new file mode 100644 index 000000000000..1a72c1b59e01 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresGenericDataModelConformanceTests.cs @@ 
-0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace PostgresIntegrationTests.CRUD; + +public class PostgresGenericDataModelConformanceTests(PostgresFixture fixture) + : GenericDataModelConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests.cs similarity index 53% rename from dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs rename to dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests.cs index 5f3f51494cf4..97767626c5cf 100644 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests.cs @@ -6,11 +6,11 @@ namespace PostgresIntegrationTests.VectorSearch; -public class PostgresBasicVectorSearchTests(PostgresFixture fixture) : BasicVectorSearchTests(fixture), IClassFixture +public class PostgresVectorSearchDistanceFunctionComplianceTests(PostgresFixture fixture) : VectorSearchDistanceFunctionComplianceTests(fixture), IClassFixture { - public override Task EuclideanSquaredDistance() => Assert.ThrowsAsync(() => base.EuclideanSquaredDistance()); + public override Task EuclideanSquaredDistance() => Assert.ThrowsAsync(base.EuclideanSquaredDistance); - public override Task Hamming() => Assert.ThrowsAsync(() => base.Hamming()); + public override Task Hamming() => Assert.ThrowsAsync(base.Hamming); - public override Task NegativeDotProductSimilarity() => Assert.ThrowsAsync(() => 
base.NegativeDotProductSimilarity()); + public override Task NegativeDotProductSimilarity() => Assert.ThrowsAsync(base.NegativeDotProductSimilarity); } diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests_Hnsw.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests_Hnsw.cs similarity index 60% rename from dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests_Hnsw.cs rename to dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests_Hnsw.cs index 81bad383df99..2daf5cc958c2 100644 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresBasicVectorSearchTests_Hnsw.cs +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests_Hnsw.cs @@ -4,7 +4,7 @@ namespace PostgresIntegrationTests.VectorSearch; -public class PostgresBasicVectorSearchTests_Hnsw(PostgresFixture fixture) : PostgresBasicVectorSearchTests(fixture) +public class PostgresVectorSearchDistanceFunctionComplianceTests_Hnsw(PostgresFixture fixture) : PostgresVectorSearchDistanceFunctionComplianceTests(fixture) { protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBasicConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBatchConformanceTests.cs similarity index 56% rename from dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBasicConformanceTests.cs rename to dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBatchConformanceTests.cs index c06fd9bcc9f6..1e8ee17dd6f4 100644 --- 
a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBasicConformanceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBatchConformanceTests.cs @@ -6,6 +6,7 @@ namespace SqlServerIntegrationTests.CRUD; -public class SqlServerBasicConformanceTests(SqlServerFixture fixture) : BasicConformanceTests(fixture), IClassFixture +public class SqlServerBatchConformanceTests(SqlServerFixture fixture) + : BatchConformanceTests(fixture), IClassFixture { } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerGenericDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerGenericDataModelConformanceTests.cs new file mode 100644 index 000000000000..5b98a7d46a11 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerGenericDataModelConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace SqlServerIntegrationTests.CRUD; + +public class SqlServerGenericDataModelConformanceTests(SqlServerFixture fixture) + : GenericDataModelConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 08866cc74051..c1d9c0210c38 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -315,7 +315,7 @@ public void SelectSingle() using SqlConnection connection = CreateConnection(); using SqlCommand command = SqlServerCommandBuilder.SelectSingle(connection, - "schema", "tableName", keyProperty, properties, 123L); + "schema", "tableName", keyProperty, properties, 123L, includeVectors: true); AssertEqualIgnoreNewLines( """"" @@ -344,7 +344,7 @@ public void SelectMany() using SqlConnection connection = CreateConnection(); using SqlCommand command = SqlServerCommandBuilder.SelectMany(connection, - "schema", "tableName", keyProperty, properties, keys)!; + "schema", "tableName", keyProperty, properties, keys, includeVectors: true)!; AssertEqualIgnoreNewLines( """"" diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index d5d10e8a9e79..131398bf48db 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -67,7 +67,7 @@ public async Task RecordCRUD() string key = await collection.UpsertAsync(inserted); 
Assert.Equal(inserted.Id, key); - TestModel? received = await collection.GetAsync(inserted.Id); + TestModel? received = await collection.GetAsync(inserted.Id, new() { IncludeVectors = true }); AssertEquality(inserted, received); TestModel updated = new() @@ -79,7 +79,7 @@ public async Task RecordCRUD() key = await collection.UpsertAsync(updated); Assert.Equal(inserted.Id, key); - received = await collection.GetAsync(updated.Id); + received = await collection.GetAsync(updated.Id, new() { IncludeVectors = true }); AssertEquality(updated, received); VectorSearchResult vectorSearchResult = await (await collection.VectorizedSearchAsync(inserted.Floats, new() @@ -142,7 +142,7 @@ public async Task WrongModels() // Let's use a model with the same storage names, but different types // to trigger a mapping exception (deserializing a string to Memory). var invalidJsonCollection = testStore.DefaultVectorStore.GetCollection(collectionName); - await Assert.ThrowsAsync(() => invalidJsonCollection.GetAsync(inserted.Id)); + await Assert.ThrowsAsync(() => invalidJsonCollection.GetAsync(inserted.Id, new() { IncludeVectors = true })); } finally { @@ -177,7 +177,7 @@ public async Task CustomMapper() Assert.True(mapper.MapFromDataToStorageModel_WasCalled); Assert.False(mapper.MapFromStorageToDataModel_WasCalled); - TestModel? received = await collection.GetAsync(inserted.Id); + TestModel? 
received = await collection.GetAsync(inserted.Id, new() { IncludeVectors = true }); AssertEquality(inserted, received); Assert.True(mapper.MapFromStorageToDataModel_WasCalled); @@ -190,7 +190,7 @@ public async Task CustomMapper() key = await collection.UpsertAsync(updated); Assert.Equal(inserted.Id, key); - received = await collection.GetAsync(updated.Id); + received = await collection.GetAsync(updated.Id, new() { IncludeVectors = true }); AssertEquality(updated, received); await collection.DeleteAsync(inserted.Id); @@ -227,7 +227,7 @@ public async Task BatchCRUD() Assert.Equal(inserted[i].Id, keys[i]); } - TestModel[] received = await collection.GetBatchAsync(keys).ToArrayAsync(); + TestModel[] received = await collection.GetBatchAsync(keys, new() { IncludeVectors = true }).ToArrayAsync(); for (int i = 0; i < inserted.Length; i++) { AssertEquality(inserted[i], received[i]); @@ -246,7 +246,7 @@ public async Task BatchCRUD() Assert.Equal(updated[i].Id, keys[i]); } - received = await collection.GetBatchAsync(keys).ToArrayAsync(); + received = await collection.GetBatchAsync(keys, new() { IncludeVectors = true }).ToArrayAsync(); for (int i = 0; i < updated.Length; i++) { AssertEquality(updated[i], received[i]); @@ -351,7 +351,7 @@ private async Task CanUseFancyModels() where TKey : notnull TKey key = await collection.UpsertAsync(inserted); Assert.NotEqual(default, key); // key should be assigned by the DB (auto-increment) - FancyTestModel? received = await collection.GetAsync(key); + FancyTestModel? 
received = await collection.GetAsync(key, new() { IncludeVectors = true }); AssertEquality(inserted, received, key); FancyTestModel updated = new() @@ -363,7 +363,7 @@ private async Task CanUseFancyModels() where TKey : notnull key = await collection.UpsertAsync(updated); Assert.Equal(updated.Id, key); - received = await collection.GetAsync(updated.Id); + received = await collection.GetAsync(updated.Id, new() { IncludeVectors = true }); AssertEquality(updated, received, key); await collection.DeleteAsync(key); diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs index 211bc006d5d1..93a329b2438a 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs @@ -18,7 +18,7 @@ public override IVectorStore DefaultVectorStore private SqlServerVectorStore? 
_connectedStore; - protected override Task StartAsync() + protected override async Task StartAsync() { if (string.IsNullOrWhiteSpace(SqlServerTestEnvironment.ConnectionString)) { @@ -28,8 +28,9 @@ protected override Task StartAsync() #pragma warning disable CA2000 // Dispose objects before losing scope SqlConnection connection = new(SqlServerTestEnvironment.ConnectionString); #pragma warning restore CA2000 // Dispose objects before losing scope + await connection.OpenAsync(); + this._connectedStore = new(connection); - return connection.OpenAsync(); } public void Dispose() => this._connectedStore?.Dispose(); diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests.cs similarity index 80% rename from dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs rename to dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests.cs index c321d4b51858..b1564100eb84 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests.cs @@ -6,8 +6,8 @@ namespace SqlServerIntegrationTests.VectorSearch; -public class SqlServerBasicVectorSearchTests(SqlServerFixture fixture) - : BasicVectorSearchTests(fixture), IClassFixture +public class SqlServerVectorSearchDistanceFunctionComplianceTests(SqlServerFixture fixture) + : VectorSearchDistanceFunctionComplianceTests(fixture), IClassFixture { public override Task CosineSimilarity() => Assert.ThrowsAsync(base.CosineSimilarity); diff --git 
a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests_Hnsw.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests_Hnsw.cs similarity index 87% rename from dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests_Hnsw.cs rename to dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests_Hnsw.cs index 1f37ce71945f..fe771d73278f 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerBasicVectorSearchTests_Hnsw.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests_Hnsw.cs @@ -5,8 +5,8 @@ namespace SqlServerIntegrationTests.VectorSearch; -public class SqlServerBasicVectorSearchTests_Hnsw(SqlServerFixture fixture) - : SqlServerBasicVectorSearchTests(fixture) +public class SqlServerVectorSearchDistanceFunctionComplianceTests_Hnsw(SqlServerFixture fixture) + : SqlServerVectorSearchDistanceFunctionComplianceTests(fixture) { // Creating such a collection is not supported. 
protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BasicConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BatchConformanceTests.cs similarity index 63% rename from dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BasicConformanceTests.cs rename to dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BatchConformanceTests.cs index f2edc990bfad..ace837591a74 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BasicConformanceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BatchConformanceTests.cs @@ -1,16 +1,15 @@ // Copyright (c) Microsoft. All rights reserved. -using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Models; using VectorDataSpecificationTests.Support; using VectorDataSpecificationTests.Xunit; using Xunit; namespace VectorDataSpecificationTests.CRUD; -public abstract class BasicConformanceTests(VectorStoreFixture fixture) +public abstract class BatchConformanceTests(VectorStoreFixture fixture) + : ConformanceTestsBase>(fixture) where TKey : notnull { - protected virtual string GetUniqueCollectionName() => Guid.NewGuid().ToString(); - [ConditionalFact] public async Task UpsertBatchAsync_EmptyBatch_DoesNotThrow() { @@ -67,36 +66,4 @@ await this.ExecuteAsync(async collection => Assert.Equal("keys", ex.ParamName); }); } - - private async Task ExecuteAsync(Func, Task> test) - { - string collectionName = this.GetUniqueCollectionName(); - var collection = fixture.TestStore.DefaultVectorStore.GetCollection(collectionName); - - await collection.CreateCollectionAsync(); - - try - { - await test(collection); - } - finally - { - await collection.DeleteCollectionAsync(); - } - } - - public sealed class TestModel - { - [VectorStoreRecordKey(StoragePropertyName = "key")] - public 
string? Id { get; set; } - - [VectorStoreRecordData(StoragePropertyName = "text")] - public string? Text { get; set; } - - [VectorStoreRecordData(StoragePropertyName = "number")] - public int Number { get; set; } - - [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] - public ReadOnlyMemory Floats { get; set; } - } } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/ConformanceTestsBase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/ConformanceTestsBase.cs new file mode 100644 index 000000000000..21a6c95f8986 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/ConformanceTestsBase.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Support; + +namespace VectorDataSpecificationTests.CRUD; + +// TKey is a generic parameter because different connectors support different key types. +public abstract class ConformanceTestsBase(VectorStoreFixture fixture) where TKey : notnull +{ + protected VectorStoreFixture Fixture { get; } = fixture; + + protected virtual string GetUniqueCollectionName() => Guid.NewGuid().ToString(); + + protected virtual VectorStoreRecordDefinition? 
GetRecordDefinition() => null; + + protected async Task ExecuteAsync(Func, Task> test) + { + string collectionName = this.GetUniqueCollectionName(); + var collection = this.Fixture.TestStore.DefaultVectorStore.GetCollection(collectionName, + this.GetRecordDefinition()); + + await collection.CreateCollectionAsync(); + + try + { + await test(collection); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/GenericDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/GenericDataModelConformanceTests.cs new file mode 100644 index 000000000000..91ac166aafd4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/GenericDataModelConformanceTests.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests.CRUD; + +public abstract class GenericDataModelConformanceTests(VectorStoreFixture fixture) + : ConformanceTestsBase>(fixture) where TKey : notnull +{ + private const string KeyPropertyName = "key"; + private const string StringPropertyName = "text"; + private const string IntegerPropertyName = "integer"; + private const string EmbeddingPropertyName = "embedding"; + private const int DimensionCount = 10; + + protected override VectorStoreRecordDefinition? 
GetRecordDefinition() + => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty(KeyPropertyName, typeof(TKey)), + new VectorStoreRecordDataProperty(StringPropertyName, typeof(string)), + new VectorStoreRecordDataProperty(IntegerPropertyName, typeof(int)), + new VectorStoreRecordVectorProperty(EmbeddingPropertyName, typeof(ReadOnlyMemory)) + { + Dimensions = DimensionCount + } + ] + }; + + [ConditionalFact] + public async Task CanInsertUpdateAndDelete() + { + await this.ExecuteAsync(async collection => + { + VectorStoreGenericDataModel inserted = new(key: this.Fixture.GenerateNextKey()); + inserted.Data.Add(StringPropertyName, "some"); + inserted.Data.Add(IntegerPropertyName, 123); + inserted.Vectors.Add(EmbeddingPropertyName, new ReadOnlyMemory(Enumerable.Repeat(0.1f, DimensionCount).ToArray())); + + TKey key = await collection.UpsertAsync(inserted); + Assert.Equal(inserted.Key, key); + + VectorStoreGenericDataModel? received = await collection.GetAsync(key, new() { IncludeVectors = true }); + Assert.NotNull(received); + + Assert.Equal(received.Key, key); + foreach (var pair in inserted.Data) + { + Assert.Equal(pair.Value, received.Data[pair.Key]); + } + + Assert.Equal( + ((ReadOnlyMemory)inserted.Vectors[EmbeddingPropertyName]!).ToArray(), + ((ReadOnlyMemory)received.Vectors[EmbeddingPropertyName]!).ToArray()); + + await collection.DeleteAsync(key); + + received = await collection.GetAsync(key); + Assert.Null(received); + }); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleModel.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleModel.cs new file mode 100644 index 000000000000..0646f0fe2f1f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleModel.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using Microsoft.Extensions.VectorData; + +namespace VectorDataSpecificationTests.Models; + +/// +/// This class represents bare minimum that each connector should support: +/// a key, int, string and an embedding. +/// +/// TKey is a generic parameter because different connectors support different key types. +public sealed class SimpleModel +{ + [VectorStoreRecordKey(StoragePropertyName = "key")] + public TKey? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "text")] + public string? Text { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "number")] + public int Number { get; set; } + + [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] + public ReadOnlyMemory Floats { get; set; } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/BasicVectorSearchTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/VectorSearchDistanceFunctionComplianceTests.cs similarity index 98% rename from dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/BasicVectorSearchTests.cs rename to dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/VectorSearchDistanceFunctionComplianceTests.cs index f2079c791b59..285c93c23e92 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/BasicVectorSearchTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/VectorSearchDistanceFunctionComplianceTests.cs @@ -7,7 +7,7 @@ namespace VectorDataSpecificationTests.VectorSearch; -public abstract class BasicVectorSearchTests(VectorStoreFixture fixture) +public abstract class VectorSearchDistanceFunctionComplianceTests(VectorStoreFixture fixture) where TKey : notnull { [ConditionalFact] From c40f341d958c288a4ec971031dd0e36588f8efa4 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Fri, 28 Feb 2025 17:27:34 +0100 Subject: [PATCH 28/32] Update 
dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs Co-authored-by: westey <164392973+westey-m@users.noreply.github.com> --- .../RecordDefinition/VectorStoreRecordKeyProperty.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs index 2ea0b97e2f20..1c223c40ca07 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs @@ -32,7 +32,7 @@ public VectorStoreRecordKeyProperty(VectorStoreRecordKeyProperty source) } /// - /// Gets a value indicating whether the key should be auto-generated by the vector store. + /// Gets or sets a value indicating whether the key should be auto-generated by the vector store. /// public bool AutoGenerate { get; init; } } From 2fe49c0f1f32327eb5f57d1621c3b2af4b953277 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Thu, 6 Mar 2025 11:17:39 +0100 Subject: [PATCH 29/32] Apply suggestions from code review Co-authored-by: Shay Rojansky --- .../Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 916f895f8429..ba367fd6f420 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -520,11 +520,11 @@ private static (string sqlName, string? 
autoGenerate) Map(Type type) Type t when t == typeof(int) => ("INT", "IDENTITY(1,1)"), Type t when t == typeof(long) => ("BIGINT", "IDENTITY(1,1)"), // TODO adsitnik: discuss using NEWID() vs NEWSEQUENTIALID(). - Type t when t == typeof(Guid) => ("UNIQUEIDENTIFIER", "DEFAULT NEWID()"), + Type t when t == typeof(Guid) => ("UNIQUEIDENTIFIER", "DEFAULT NEWSEQUENTIALID()"), Type t when t == typeof(string) => (NVARCHAR, null), Type t when t == typeof(byte[]) => ("VARBINARY(MAX)", null), Type t when t == typeof(bool) => ("BIT", null), - Type t when t == typeof(DateTime) => ("DATETIME", null), + Type t when t == typeof(DateTime) => ("DATETIME2", null), Type t when t == typeof(TimeSpan) => ("TIME", null), Type t when t == typeof(decimal) => ("DECIMAL", null), Type t when t == typeof(double) => ("FLOAT", null), From 0bdca763f801c7e86c56ef98aa6df5c4df7d13ce Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Thu, 6 Mar 2025 12:21:38 +0100 Subject: [PATCH 30/32] address code review feedback: - remove NONCLUSTERED used PRIMARY KEY - remove COLLATE Latin1_General_100_BIN2 for VARCHAR columns - remove TimeSpan => TIME mapping - add support for TimeOnly (TIME) - make Schema optional - handle Where(x => x.Bool) - use square brackets for escaping column names in WHERE clause --- .../SqlFilterTranslator.cs | 75 ++++++++----------- .../PostgresFilterTranslator.cs | 17 ++--- .../Connectors.Memory.SqlServer.csproj | 2 +- .../RecordMapper.cs | 2 +- .../SqlDataReaderDictionary.cs | 9 +++ .../SqlServerCommandBuilder.cs | 59 ++++++++------- .../SqlServerConstants.cs | 12 +-- .../SqlServerFilterTranslator.cs | 51 ++++++++++--- .../SqlServerVectorStoreOptions.cs | 2 +- .../SqlServerVectorStoreRecordCollection.cs | 2 +- ...erverVectorStoreRecordCollectionOptions.cs | 15 +--- .../SqliteFilterTranslator.cs | 13 ++-- .../Filter/SqlServerBasicFilterTests.cs | 17 +---- .../SqlServerCommandBuilderTests.cs | 12 +-- .../SqlServerVectorStoreTests.cs | 54 +++++++++++++ .../Filter/BasicFilterTests.cs | 12 
+++ 16 files changed, 210 insertions(+), 144 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs index c4b9f146201c..cad9bd1048c2 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs @@ -40,10 +40,10 @@ internal void Translate(bool appendWhere) this._sql.Append("WHERE "); } - this.Translate(this._lambdaExpression.Body); + this.Translate(this._lambdaExpression.Body, null); } - protected void Translate(Expression? node) + protected void Translate(Expression? node, Expression? parent) { switch (node) { @@ -52,11 +52,11 @@ protected void Translate(Expression? node) return; case ConstantExpression constant: - this.TranslateConstant(constant); + this.TranslateConstant(constant.Value); return; case MemberExpression member: - this.TranslateMember(member); + this.TranslateMember(member, parent); return; case MethodCallExpression methodCall: @@ -72,36 +72,36 @@ protected void Translate(Expression? 
node) } } - private void TranslateBinary(BinaryExpression binary) + protected void TranslateBinary(BinaryExpression binary) { // Special handling for null comparisons switch (binary.NodeType) { case ExpressionType.Equal when IsNull(binary.Right): this._sql.Append('('); - this.Translate(binary.Left); + this.Translate(binary.Left, binary); this._sql.Append(" IS NULL)"); return; case ExpressionType.NotEqual when IsNull(binary.Right): this._sql.Append('('); - this.Translate(binary.Left); + this.Translate(binary.Left, binary); this._sql.Append(" IS NOT NULL)"); return; case ExpressionType.Equal when IsNull(binary.Left): this._sql.Append('('); - this.Translate(binary.Right); + this.Translate(binary.Right, binary); this._sql.Append(" IS NULL)"); return; case ExpressionType.NotEqual when IsNull(binary.Left): this._sql.Append('('); - this.Translate(binary.Right); + this.Translate(binary.Right, binary); this._sql.Append(" IS NOT NULL)"); return; } this._sql.Append('('); - this.Translate(binary.Left); + this.Translate(binary.Left, binary); this._sql.Append(binary.NodeType switch { @@ -119,7 +119,7 @@ private void TranslateBinary(BinaryExpression binary) _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) }); - this.Translate(binary.Right); + this.Translate(binary.Right, binary); this._sql.Append(')'); static bool IsNull(Expression expression) @@ -127,10 +127,7 @@ static bool IsNull(Expression expression) || (TryGetCapturedValue(expression, out _, out var capturedValue) && capturedValue is null); } - private void TranslateConstant(ConstantExpression constant) - => this.GenerateLiteral(constant.Value); - - protected void GenerateLiteral(object? value) + protected virtual void TranslateConstant(object? value) { // TODO: Nullable switch (value) @@ -152,20 +149,14 @@ protected void GenerateLiteral(object? 
value) this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); return; case bool b: - this.GenerateLiteral(b); + this._sql.Append(b ? "TRUE" : "FALSE"); return; case Guid g: this._sql.Append('\'').Append(g.ToString()).Append('\''); return; case DateTime dateTime: - this.GenerateLiteral(dateTime); - return; - case DateTimeOffset dateTimeOffset: - this.GenerateLiteral(dateTimeOffset); - return; - case Array: throw new NotImplementedException(); @@ -178,25 +169,16 @@ protected void GenerateLiteral(object? value) } } - protected abstract void GenerateLiteral(bool value); - - protected virtual void GenerateLiteral(DateTime dateTime) - => throw new NotImplementedException(); - - protected virtual void GenerateLiteral(DateTimeOffset dateTimeOffset) - => throw new NotImplementedException(); - - private void TranslateMember(MemberExpression memberExpression) + private void TranslateMember(MemberExpression memberExpression, Expression? parent) { switch (memberExpression) { case var _ when this.TryGetColumn(memberExpression, out var column): - this._sql.Append('"').Append(column).Append('"'); + this.TranslateColumn(column, memberExpression, parent); return; - // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) case var _ when TryGetCapturedValue(memberExpression, out var name, out var value): - this.TranslateLambdaVariables(name, value); + this.TranslateCapturedVariable(name, value); return; default: @@ -204,7 +186,10 @@ private void TranslateMember(MemberExpression memberExpression) } } - protected abstract void TranslateLambdaVariables(string name, object? capturedValue); + protected virtual void TranslateColumn(string column, MemberExpression memberExpression, Expression? parent) + => this._sql.Append('"').Append(column).Append('"'); + + protected abstract void TranslateCapturedVariable(string name, object? 
capturedValue); private void TranslateMethodCall(MethodCallExpression methodCall) { @@ -213,7 +198,7 @@ private void TranslateMethodCall(MethodCallExpression methodCall) // Enumerable.Contains() case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains when contains.Method.DeclaringType == typeof(Enumerable): - this.TranslateContains(source, item); + this.TranslateContains(source, item, methodCall); return; // List.Contains() @@ -227,7 +212,7 @@ private void TranslateMethodCall(MethodCallExpression methodCall) Object: Expression source, Arguments: [var item] } when declaringType.GetGenericTypeDefinition() == typeof(List<>): - this.TranslateContains(source, item); + this.TranslateContains(source, item, methodCall); return; default: @@ -235,18 +220,18 @@ private void TranslateMethodCall(MethodCallExpression methodCall) } } - private void TranslateContains(Expression source, Expression item) + private void TranslateContains(Expression source, Expression item, MethodCallExpression parent) { switch (source) { // Contains over array column (r => r.Strings.Contains("foo")) case var _ when this.TryGetColumn(source, out _): - this.TranslateContainsOverArrayColumn(source, item); + this.TranslateContainsOverArrayColumn(source, item, parent); return; // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) case NewArrayExpression newArray: - this.Translate(item); + this.Translate(item, parent); this._sql.Append(" IN ("); var isFirst = true; @@ -261,7 +246,7 @@ private void TranslateContains(Expression source, Expression item) this._sql.Append(", "); } - this.Translate(element); + this.Translate(element, parent); } this._sql.Append(')'); @@ -269,7 +254,7 @@ private void TranslateContains(Expression source, Expression item) // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) case var _ when TryGetCapturedValue(source, out _, out var value): - this.TranslateContainsOverCapturedArray(source, item, 
value); + this.TranslateContainsOverCapturedArray(source, item, parent, value); return; default: @@ -277,9 +262,9 @@ private void TranslateContains(Expression source, Expression item) } } - protected abstract void TranslateContainsOverArrayColumn(Expression source, Expression item); + protected abstract void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent); - protected abstract void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value); + protected abstract void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value); private void TranslateUnary(UnaryExpression unary) { @@ -298,7 +283,7 @@ private void TranslateUnary(UnaryExpression unary) } this._sql.Append("(NOT "); - this.Translate(unary.Operand); + this.Translate(unary.Operand, unary); this._sql.Append(')'); return; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs index 1780e176cbc4..b4b9707c1c99 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs @@ -20,26 +20,23 @@ internal PostgresFilterTranslator( internal List ParameterValues => this._parameterValues; - protected override void GenerateLiteral(bool value) - => this._sql.Append(value ? 
"TRUE" : "FALSE"); - - protected override void TranslateContainsOverArrayColumn(Expression source, Expression item) + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent) { - this.Translate(source); + this.Translate(source, parent); this._sql.Append(" @> ARRAY["); - this.Translate(item); + this.Translate(item, parent); this._sql.Append(']'); } - protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value) + protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value) { - this.Translate(item); + this.Translate(item, parent); this._sql.Append(" = ANY ("); - this.Translate(source); + this.Translate(source, parent); this._sql.Append(')'); } - protected override void TranslateLambdaVariables(string name, object? capturedValue) + protected override void TranslateCapturedVariable(string name, object? capturedValue) { // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj b/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj index 90c088464efc..045fd37fc3cf 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj @@ -4,7 +4,7 @@ Microsoft.SemanticKernel.Connectors.SqlServer $(AssemblyName) - netstandard2.0 + netstandard2.0;net8.0 alpha diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs index b2a5cb6d2cec..240f2814e044 100644 --- 
a/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs @@ -31,7 +31,7 @@ internal sealed class RecordMapper : IVectorStoreRecordMapper so the cast here is safe. - ReadOnlyMemory floats = (ReadOnlyMemory)vectorPropertiesInfo[i].GetValue(dataModel); + ReadOnlyMemory floats = (ReadOnlyMemory)vectorPropertiesInfo[i].GetValue(dataModel)!; map[SqlServerCommandBuilder.GetColumnName(vectorProperties[i])] = floats; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs index 3c10c8f31d46..414ff8de4afd 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs @@ -56,6 +56,15 @@ internal SqlDataReaderDictionary(SqlDataReader sqlDataReader, IReadOnlyList properties, @@ -163,7 +163,7 @@ internal static SqlCommand MergeIntoSingle( internal static SqlCommand? MergeIntoMany( SqlConnection connection, - string schema, + string? schema, string tableName, VectorStoreRecordKeyProperty keyProperty, IReadOnlyList properties, @@ -235,7 +235,7 @@ internal static SqlCommand MergeIntoSingle( } internal static SqlCommand DeleteSingle( - SqlConnection connection, string schema, string tableName, + SqlConnection connection, string? schema, string tableName, VectorStoreRecordKeyProperty keyProperty, object key) { SqlCommand command = connection.CreateCommand(); @@ -253,7 +253,7 @@ internal static SqlCommand DeleteSingle( } internal static SqlCommand? DeleteMany( - SqlConnection connection, string schema, string tableName, + SqlConnection connection, string? 
schema, string tableName, VectorStoreRecordKeyProperty keyProperty, IEnumerable keys) { SqlCommand command = connection.CreateCommand(); @@ -275,7 +275,7 @@ internal static SqlCommand DeleteSingle( } internal static SqlCommand SelectSingle( - SqlConnection sqlConnection, string schema, string collectionName, + SqlConnection sqlConnection, string? schema, string collectionName, VectorStoreRecordKeyProperty keyProperty, IReadOnlyList properties, object key, @@ -300,7 +300,7 @@ internal static SqlCommand SelectSingle( } internal static SqlCommand? SelectMany( - SqlConnection connection, string schema, string tableName, + SqlConnection connection, string? schema, string tableName, VectorStoreRecordKeyProperty keyProperty, IReadOnlyList properties, IEnumerable keys, @@ -329,7 +329,7 @@ internal static SqlCommand SelectSingle( } internal static SqlCommand SelectVector( - SqlConnection connection, string schema, string tableName, + SqlConnection connection, string? schema, string tableName, VectorStoreRecordVectorProperty vectorProperty, IReadOnlyList properties, IReadOnlyDictionary storagePropertyNamesMap, @@ -337,7 +337,6 @@ internal static SqlCommand SelectVector( ReadOnlyMemory vector) { string distanceFunction = vectorProperty.DistanceFunction ?? DistanceFunction.CosineDistance; - // Source: https://learn.microsoft.com/sql/t-sql/functions/vector-distance-transact-sql (string distanceMetric, string sorting) = MapDistanceFunction(distanceFunction); SqlCommand command = connection.CreateCommand(); @@ -412,17 +411,22 @@ internal static StringBuilder AppendParameterName(this StringBuilder sb, VectorS return sb; } - internal static StringBuilder AppendTableName(this StringBuilder sb, string schema, string tableName) + internal static StringBuilder AppendTableName(this StringBuilder sb, string? schema, string tableName) { // If the column name contains a ], then escape it by doubling it. // "Name with [brackets]" becomes [Name with [brackets]]]. 
sb.Append('['); - int index = sb.Length; // store the index, so we replace ] only for schema - sb.Append(schema); - sb.Replace("]", "]]", index, schema.Length); // replace the ] for schema - sb.Append("].["); - index = sb.Length; + int index = sb.Length; // store the index, so we replace ] only for the appended part + + if (!string.IsNullOrEmpty(schema)) + { + sb.Append(schema); + sb.Replace("]", "]]", index, schema.Length); // replace the ] for schema + sb.Append("].["); + index = sb.Length; + } + sb.Append(tableName); sb.Replace("]", "]]", index, tableName.Length); sb.Append(']'); @@ -512,20 +516,20 @@ private static void AddParameter(this SqlCommand command, VectorStoreRecordPrope private static (string sqlName, string? autoGenerate) Map(Type type) { - const string NVARCHAR = "NVARCHAR(255) COLLATE Latin1_General_100_BIN2"; return type switch { Type t when t == typeof(byte) => ("TINYINT", null), Type t when t == typeof(short) => ("SMALLINT", null), Type t when t == typeof(int) => ("INT", "IDENTITY(1,1)"), Type t when t == typeof(long) => ("BIGINT", "IDENTITY(1,1)"), - // TODO adsitnik: discuss using NEWID() vs NEWSEQUENTIALID(). Type t when t == typeof(Guid) => ("UNIQUEIDENTIFIER", "DEFAULT NEWSEQUENTIALID()"), - Type t when t == typeof(string) => (NVARCHAR, null), + Type t when t == typeof(string) => ("NVARCHAR(255)", null), Type t when t == typeof(byte[]) => ("VARBINARY(MAX)", null), Type t when t == typeof(bool) => ("BIT", null), Type t when t == typeof(DateTime) => ("DATETIME2", null), - Type t when t == typeof(TimeSpan) => ("TIME", null), +#if NET + Type t when t == typeof(TimeOnly) => ("TIME", null), +#endif Type t when t == typeof(decimal) => ("DECIMAL", null), Type t when t == typeof(double) => ("FLOAT", null), Type t when t == typeof(float) => ("REAL", null), @@ -533,6 +537,7 @@ private static (string sqlName, string? 
autoGenerate) Map(Type type) }; } + // Source: https://learn.microsoft.com/sql/t-sql/functions/vector-distance-transact-sql private static (string distanceMetric, string sorting) MapDistanceFunction(string name) => name switch { // A value of 0 indicates that the vectors are identical in direction (cosine similarity of 1), diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs index 339445795551..f5d0c608c547 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs @@ -7,15 +7,13 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; internal static class SqlServerConstants { - internal const string Schema = "dbo"; - internal static readonly HashSet SupportedKeyTypes = [ typeof(int), // INT typeof(long), // BIGINT typeof(string), // VARCHAR typeof(Guid), // UNIQUEIDENTIFIER - typeof(DateTime), // DATETIME + typeof(DateTime), // DATETIME2 typeof(byte[]) // VARBINARY ]; @@ -36,8 +34,12 @@ internal static class SqlServerConstants typeof(string), // NVARCHAR typeof(byte[]), //VARBINARY typeof(bool), // BIT - typeof(DateTime), // DATETIME - typeof(TimeSpan), // TIME + typeof(DateTime), // DATETIME2 +#if NET + // We don't support mapping TimeSpan to TIME on purpose + // See https://github.com/microsoft/semantic-kernel/pull/10623#discussion_r1980350721 + typeof(TimeOnly), // TIME +#endif typeof(decimal), // DECIMAL typeof(double), // FLOAT typeof(float), // REAL diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs index b51e2d0b588a..9d7862abe85e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs @@ -25,26 +25,53 @@ 
internal SqlServerFilterTranslator( internal List ParameterValues => this._parameterValues; - protected override void GenerateLiteral(bool value) - => this._sql.Append(value ? "1" : "0"); - - protected override void GenerateLiteral(DateTime dateTime) - => this._sql.AppendFormat("'{0:yyyy-MM-dd HH:mm:ss}'", dateTime); + protected override void TranslateConstant(object? value) + { + switch (value) + { + case bool boolValue: + this._sql.Append(boolValue ? "1" : "0"); + return; + case DateTime dateTime: + this._sql.AppendFormat("'{0:yyyy-MM-dd HH:mm:ss}'", dateTime); + return; + case DateTimeOffset dateTimeOffset: + this._sql.AppendFormat("'{0:yyy-MM-dd HH:mm:ss zzz}'", dateTimeOffset); + return; + default: + base.TranslateConstant(value); + break; + } + } - protected override void GenerateLiteral(DateTimeOffset dateTimeOffset) - => this._sql.AppendFormat("'{0:yyy-MM-dd HH:mm:ss zzz}'", dateTimeOffset); + protected override void TranslateColumn(string column, MemberExpression memberExpression, Expression? parent) + { + // "SELECT * FROM MyTable WHERE BooleanColumn;" is not supported. + // "SELECT * FROM MyTable WHERE BooleanColumn = 1;" is supported. 
+ if (memberExpression.Type == typeof(bool) + && (parent is null // Where(x => x.Bool) + || parent is UnaryExpression { NodeType: ExpressionType.Not } // Where(x => !x.Bool) + || parent is BinaryExpression { NodeType: ExpressionType.AndAlso or ExpressionType.OrElse })) // Where(x => x.Bool && other) + { + this.TranslateBinary(Expression.MakeBinary(ExpressionType.Equal, memberExpression, Expression.Constant(true))); + } + else + { + this._sql.Append('[').Append(column).Append(']'); + } + } - protected override void TranslateContainsOverArrayColumn(Expression source, Expression item) + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent) => throw new NotSupportedException("Unsupported Contains expression"); - protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value) + protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value) { if (value is not IEnumerable elements) { throw new NotSupportedException("Unsupported Contains expression"); } - this.Translate(item); + this.Translate(item, parent); this._sql.Append(" IN ("); var isFirst = true; @@ -59,13 +86,13 @@ protected override void TranslateContainsOverCapturedArray(Expression source, Ex this._sql.Append(", "); } - this.GenerateLiteral(element); + this.TranslateConstant(element); } this._sql.Append(')'); } - protected override void TranslateLambdaVariables(string name, object? capturedValue) + protected override void TranslateCapturedVariable(string name, object? 
capturedValue) { // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs index a1d809c0face..a90b474a3d5f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs @@ -12,5 +12,5 @@ public sealed class SqlServerVectorStoreOptions /// /// Gets or sets the database schema. /// - public string Schema { get; init; } = SqlServerConstants.Schema; + public string? Schema { get; init; } = null; } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index f8c7f714bad5..21ccfbfe1396 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -321,7 +321,7 @@ public async Task> VectorizedSearchAsync(T $"Supported types are: {string.Join(", ", SqlServerConstants.SupportedVectorTypes.Select(l => l.FullName))}"); } #pragma warning disable CS0618 // Type or member is obsolete - else if (options is not null && options.Filter is not null) + else if (options is not null && options.OldFilter is not null) #pragma warning restore CS0618 // Type or member is obsolete { throw new NotSupportedException("The obsolete Filter is not supported by the SQL Server connector, use NewFilter instead."); diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs 
b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs index adb1bd359d70..6b21a5e35842 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; @@ -11,22 +10,10 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer; /// public sealed class SqlServerVectorStoreRecordCollectionOptions { - private string _schema = SqlServerConstants.Schema; - /// /// Gets or sets the database schema. /// - /// when provided schema is empty or composed entirely of whitespace. - public string Schema - { - get => this._schema; - init - { - Verify.NotNullOrWhiteSpace(value); - - this._schema = value; - } - } + public string? Schema { get; init; } /// /// Gets or sets an optional custom mapper to use when converting between the data model and the SQL Server record. diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs index 8489301ad1f8..963c1184d274 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs @@ -18,21 +18,18 @@ internal SqliteFilterTranslator(IReadOnlyDictionary storagePrope internal Dictionary Parameters => this._parameters; - protected override void GenerateLiteral(bool value) - => this._sql.Append(value ? 
"TRUE" : "FALSE"); - // TODO: support Contains over array fields (#10343) - protected override void TranslateContainsOverArrayColumn(Expression source, Expression item) + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent) => throw new NotSupportedException("Unsupported Contains expression"); - protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value) + protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value) { if (value is not IEnumerable elements) { throw new NotSupportedException("Unsupported Contains expression"); } - this.Translate(item); + this.Translate(item, parent); this._sql.Append(" IN ("); var isFirst = true; @@ -47,13 +44,13 @@ protected override void TranslateContainsOverCapturedArray(Expression source, Ex this._sql.Append(", "); } - this.GenerateLiteral(element); + this.TranslateConstant(element); } this._sql.Append(')'); } - protected override void TranslateLambdaVariables(string name, object? capturedValue) + protected override void TranslateCapturedVariable(string name, object? 
capturedValue) { // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs index 2cdbbeab82cc..3bae6cc48552 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs @@ -12,14 +12,6 @@ namespace SqlServerIntegrationTests.Filter; public class SqlServerBasicFilterTests(SqlServerBasicFilterTests.Fixture fixture) : BasicFilterTests(fixture), IClassFixture { - // "SELECT * FROM MyTable WHERE BooleanColumn = 1;" is fine - // "SELECT * FROM MyTable WHERE BooleanColumn;" is not - // TODO adsitnik: get it to work anyway - public override Task Bool() => this.TestFilterAsync(r => r.Bool == true); - - // Same as above, "WHERE NOT BooleanColumn" is not supported - public override Task Not_over_bool() => this.TestFilterAsync(r => r.Bool == false); - public override async Task Not_over_Or() { // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) @@ -62,14 +54,11 @@ public override Task Contains_over_field_string_List() public new class Fixture : BasicFilterTests.Fixture { + private static readonly string s_uniqueName = Guid.NewGuid().ToString(); + public override TestStore TestStore => SqlServerTestStore.Instance; - protected override string CollectionName -#if NET // make sure different TFMs use different collection names (as they may run in parallel and cause trouble) - => "FilterTests-core"; -#else - => "FilterTests-framework"; -#endif + protected override string CollectionName => s_uniqueName; // Override to remove the string 
collection properties, which aren't (currently) supported on SqlServer protected override VectorStoreRecordDefinition GetRecordDefinition() diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index c1d9c0210c38..3939b3cec1a7 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -12,9 +12,11 @@ public class SqlServerCommandBuilderTests { [Theory] [InlineData("schema", "name", "[schema].[name]")] + [InlineData(null, "name", "[name]")] [InlineData("schema", "[brackets]", "[schema].[[brackets]]]")] + [InlineData(null, "[needsEscaping]", "[[needsEscaping]]]")] [InlineData("needs]escaping", "[brackets]", "[needs]]escaping].[[brackets]]]")] - public void AppendTableName(string schema, string table, string expectedFullName) + public void AppendTableName(string? 
schema, string table, string expectedFullName) { StringBuilder result = new(); @@ -72,7 +74,7 @@ public void SelectTableName(string schema, string table) SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE' - AND TABLE_SCHEMA = @schema + AND (@schema is NULL or TABLE_SCHEMA = @schema) AND TABLE_NAME = @tableName """ , command.CommandText); @@ -93,7 +95,7 @@ public void SelectTableNames() SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE' - AND TABLE_SCHEMA = @schema + AND (@schema is NULL or TABLE_SCHEMA = @schema) """ , command.CommandText); Assert.Equal(SchemaName, command.Parameters[0].Value); @@ -131,10 +133,10 @@ public void CreateTable(bool ifNotExists) BEGIN CREATE TABLE [schema].[table] ( [id] BIGINT IDENTITY(1,1), - [simpleName] NVARCHAR(255) COLLATE Latin1_General_100_BIN2, + [simpleName] NVARCHAR(255), [with space] INT, [embedding] VECTOR(10), - PRIMARY KEY NONCLUSTERED ([id]) + PRIMARY KEY ([id]) ); END; """; diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index 131398bf48db..1e015e34e650 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -319,6 +319,60 @@ public sealed class DifferentStorageNames public ReadOnlyMemory Floats { get; set; } } +#if NETFRAMEWORK + [ConditionalFact] + public void TimeSpanIsNotSupported() + { + string collectionName = GetUniqueCollectionName(); + var testStore = fixture.TestStore; + + Assert.Throws(() => testStore.DefaultVectorStore.GetCollection(collectionName)); + } +#else + [ConditionalFact] + public async Task TimeOnlyIsSupported() + { + string collectionName = GetUniqueCollectionName(); + var testStore = fixture.TestStore; + + var collection = 
testStore.DefaultVectorStore.GetCollection(collectionName); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + TimeModel inserted = new() + { + Id = "MyId", + Time = new TimeOnly(12, 34, 56) + }; + string key = await collection.UpsertAsync(inserted); + Assert.Equal(inserted.Id, key); + + TimeModel? received = await collection.GetAsync(inserted.Id, new() { IncludeVectors = true }); + Assert.NotNull(received); + Assert.Equal(inserted.Time, received.Time); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } +#endif + + public sealed class TimeModel + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public string? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "time")] +#if NETFRAMEWORK + public TimeSpan Time { get; set; } +#else + public TimeOnly Time { get; set; } +#endif + } + [ConditionalFact] public Task CanUseFancyModels_Int() => this.CanUseFancyModels(); diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs index 0f87d2ae7c5d..dd03c1b1bda7 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs @@ -73,6 +73,14 @@ public virtual Task NotEqual_with_null_captured() public virtual Task Bool() => this.TestFilterAsync(r => r.Bool); + [ConditionalFact] + public virtual Task Bool_And_Bool() + => this.TestFilterAsync(r => r.Bool && r.Bool); + + [ConditionalFact] + public virtual Task Bool_Or_Not_Bool() + => this.TestFilterAsync(r => r.Bool || !r.Bool, expectAllResults: true); + #endregion Equality #region Comparison @@ -139,6 +147,10 @@ public virtual Task Not_over_Or() public virtual Task Not_over_bool() => this.TestFilterAsync(r => !r.Bool); + [ConditionalFact] + public virtual Task Not_over_bool_And_Comparison() + => 
this.TestFilterAsync(r => !r.Bool && r.Int != int.MaxValue); + #endregion Logical operators #region Contains From 9ed18abbf8fc355ea2562c71ff51daa369daa95e Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Fri, 7 Mar 2025 11:23:34 +0100 Subject: [PATCH 31/32] remove AutoGenerate --- .../SqlServerCommandBuilder.cs | 11 +++-------- .../SqlServerConstants.cs | 7 ------- .../SqlServerVectorStoreRecordCollection.cs | 8 ++------ .../VectorStoreRecordKeyAttribute.cs | 13 ------------- .../src/Data/VectorStoreRecordPropertyReader.cs | 1 - .../SqlServerCommandBuilderTests.cs | 16 +++++----------- .../SqlServerVectorStoreTests.cs | 6 +++--- 7 files changed, 13 insertions(+), 49 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index 852a421fdf36..17b23d7d7477 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Text; using System.Text.Json; using Microsoft.Data.SqlClient; @@ -36,7 +35,7 @@ internal static SqlCommand CreateTable( sb.AppendLine(" ("); string keyColumnName = GetColumnName(keyProperty); var keyMapping = Map(keyProperty.PropertyType); - sb.AppendFormat("[{0}] {1} {2},", keyColumnName, keyMapping.sqlName, keyProperty.AutoGenerate ? keyMapping.autoGenerate : "NOT NULL"); + sb.AppendFormat("[{0}] {1} NOT NULL,", keyColumnName, keyMapping.sqlName); sb.AppendLine(); for (int i = 0; i < dataProperties.Count; i++) { @@ -143,17 +142,13 @@ internal static SqlCommand MergeIntoSingle( --sb.Length; // remove the last comma sb.AppendLine(); - // We must not try to insert the key if it is auto-generated. - var propertiesToInsert = keyProperty.AutoGenerate - ? 
properties.Where(p => p != keyProperty) - : properties; sb.Append("WHEN NOT MATCHED THEN"); sb.AppendLine(); sb.Append("INSERT ("); - sb.AppendColumnNames(propertiesToInsert); + sb.AppendColumnNames(properties); sb.AppendLine(")"); sb.Append("VALUES ("); - sb.AppendColumnNames(propertiesToInsert, prefix: "s."); + sb.AppendColumnNames(properties, prefix: "s."); sb.AppendLine(")"); sb.AppendFormat("OUTPUT inserted.[{0}];", GetColumnName(keyProperty)); diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs index f5d0c608c547..672bfa2d0cfa 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs @@ -17,13 +17,6 @@ internal static class SqlServerConstants typeof(byte[]) // VARBINARY ]; - internal static readonly HashSet SupportedAutoGenerateKeyTypes = - [ - typeof(int), // IDENTITY - typeof(long), // IDENTITY - typeof(Guid) // NEWID - ]; - internal static readonly HashSet SupportedDataTypes = [ typeof(int), // INT diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs index ab260f0faea7..d2a87a34bf6d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -50,17 +50,13 @@ public SqlServerVectorStoreRecordCollection( SupportsMultipleVectors = true, }); - HashSet supportedKeyTypes = propertyReader.KeyProperty.AutoGenerate - ? 
SqlServerConstants.SupportedAutoGenerateKeyTypes - : SqlServerConstants.SupportedKeyTypes; - if (VectorStoreRecordPropertyVerification.IsGenericDataModel(typeof(TRecord))) { - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.Mapper is not null, supportedKeyTypes); + VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.Mapper is not null, SqlServerConstants.SupportedKeyTypes); } else { - propertyReader.VerifyKeyProperties(supportedKeyTypes); + propertyReader.VerifyKeyProperties(SqlServerConstants.SupportedKeyTypes); } propertyReader.VerifyDataProperties(SqlServerConstants.SupportedDataTypes, supportEnumerable: false); propertyReader.VerifyVectorProperties(SqlServerConstants.SupportedVectorTypes); diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs index 7c5eb63817b1..318521355f1b 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes/VectorStoreRecordKeyAttribute.cs @@ -20,17 +20,4 @@ public sealed class VectorStoreRecordKeyAttribute : Attribute /// For example, the property name might be "MyProperty" and the storage name might be "my_property". /// public string? StoragePropertyName { get; set; } - - /// - /// Gets or sets whether the key should be auto-generated by the vector store. - /// - /// - /// The default is . - /// - /// - /// If set to , you must set the key property on any record that you pass to . - /// If set to , the key property may be left null on any record that you pass to - /// and a generated key will be returned. 
- /// - public bool AutoGenerate { get; set; } } diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs index becbc8c26e67..545996e44dbb 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs @@ -600,7 +600,6 @@ private static VectorStoreRecordDefinition CreateVectorStoreRecordDefinitionFrom definitionProperties.Add(new VectorStoreRecordKeyProperty(keyProperty.Name, keyProperty.PropertyType) { StoragePropertyName = keyAttribute.StoragePropertyName, - AutoGenerate = keyAttribute.AutoGenerate }); } } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs index 3939b3cec1a7..0d421b6ba314 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -107,10 +107,7 @@ FROM INFORMATION_SCHEMA.TABLES [InlineData(false)] public void CreateTable(bool ifNotExists) { - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)) - { - AutoGenerate = true - }; + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); VectorStoreRecordDataProperty[] dataProperties = [ new VectorStoreRecordDataProperty("simpleName", typeof(string)), @@ -132,7 +129,7 @@ public void CreateTable(bool ifNotExists) """ BEGIN CREATE TABLE [schema].[table] ( - [id] BIGINT IDENTITY(1,1), + [id] BIGINT NOT NULL, [simpleName] NVARCHAR(255), [with space] INT, [embedding] VECTOR(10), @@ -151,10 +148,7 @@ PRIMARY KEY ([id]) [Fact] public void MergeIntoSingle() { - VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)) - { - AutoGenerate = true - }; + 
VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); VectorStoreRecordProperty[] properties = [ keyProperty, @@ -185,8 +179,8 @@ MERGE INTO [schema].[table] AS t WHEN MATCHED THEN UPDATE SET t.[simpleString] = s.[simpleString],t.[simpleInt] = s.[simpleInt],t.[embedding] = s.[embedding] WHEN NOT MATCHED THEN - INSERT ([simpleString],[simpleInt],[embedding]) - VALUES (s.[simpleString],s.[simpleInt],s.[embedding]) + INSERT ([id],[simpleString],[simpleInt],[embedding]) + VALUES (s.[id],s.[simpleString],s.[simpleInt],s.[embedding]) OUTPUT inserted.[id]; """"; diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs index 1e015e34e650..bb658f486c87 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -394,7 +394,7 @@ private async Task CanUseFancyModels() where TKey : notnull FancyTestModel inserted = new() { - // We let the DB assign Id! + Id = testStore.GenerateKey(1), Number8 = byte.MaxValue, Number16 = short.MaxValue, Number32 = int.MaxValue, @@ -403,7 +403,7 @@ private async Task CanUseFancyModels() where TKey : notnull Bytes = [1, 2, 3], }; TKey key = await collection.UpsertAsync(inserted); - Assert.NotEqual(default, key); // key should be assigned by the DB (auto-increment) + Assert.NotEqual(default, key); FancyTestModel? received = await collection.GetAsync(key, new() { IncludeVectors = true }); AssertEquality(inserted, received, key); @@ -444,7 +444,7 @@ void AssertEquality(FancyTestModel expected, FancyTestModel? receive public sealed class FancyTestModel { - [VectorStoreRecordKey(StoragePropertyName = "key", AutoGenerate = true)] + [VectorStoreRecordKey(StoragePropertyName = "key")] public TKey? 
Id { get; set; } [VectorStoreRecordData(StoragePropertyName = "byte")] From c42c6cbeebb747aa97d745e95f2b0bcb1cc974cf Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Fri, 7 Mar 2025 18:01:44 +0100 Subject: [PATCH 32/32] Apply suggestions from code review Co-authored-by: Shay Rojansky --- .../Connectors.Memory.SqlServer/SqlServerConstants.cs | 2 +- .../Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs | 2 +- .../src/Data/VectorStoreRecordPropertyReader.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs index 672bfa2d0cfa..6b81cbac1ef6 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs @@ -25,7 +25,7 @@ internal static class SqlServerConstants typeof(long), // BIGINT. typeof(Guid), // UNIQUEIDENTIFIER. typeof(string), // NVARCHAR - typeof(byte[]), //VARBINARY + typeof(byte[]), // VARBINARY typeof(bool), // BIT typeof(DateTime), // DATETIME2 #if NET diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs index 9d7862abe85e..3bd3b2f97e0b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs @@ -53,7 +53,7 @@ protected override void TranslateColumn(string column, MemberExpression memberEx || parent is UnaryExpression { NodeType: ExpressionType.Not } // Where(x => !x.Bool) || parent is BinaryExpression { NodeType: ExpressionType.AndAlso or ExpressionType.OrElse })) // Where(x => x.Bool && other) { - this.TranslateBinary(Expression.MakeBinary(ExpressionType.Equal, memberExpression, Expression.Constant(true))); + 
this.TranslateBinary(Expression.Equal(memberExpression, Expression.Constant(true))); } else { diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs index 545996e44dbb..15047fe23b91 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs @@ -599,7 +599,7 @@ private static VectorStoreRecordDefinition CreateVectorStoreRecordDefinitionFrom { definitionProperties.Add(new VectorStoreRecordKeyProperty(keyProperty.Name, keyProperty.PropertyType) { - StoragePropertyName = keyAttribute.StoragePropertyName, + StoragePropertyName = keyAttribute.StoragePropertyName }); } }