diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 6d0f08aa9b5b..bad51cba9c8e 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -462,6 +462,7 @@ EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "CosmosMongoDBIntegrationTests", "src\VectorDataIntegrationTests\CosmosMongoDBIntegrationTests\CosmosMongoDBIntegrationTests.csproj", "{11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AzureAISearchIntegrationTests", "src\VectorDataIntegrationTests\AzureAISearchIntegrationTests\AzureAISearchIntegrationTests.csproj", "{06181F0F-A375-43AE-B45F-73CBCFC30C14}" +EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Agents.AzureAI", "src\Agents\AzureAI\Agents.AzureAI.csproj", "{EA35F1B5-9148-4189-BE34-5E00AED56D65}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Plugins.AI", "src\Plugins\Plugins.AI\Plugins.AI.csproj", "{0C64EC81-8116-4388-87AD-BA14D4B59974}" @@ -491,6 +492,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Agents.Bedrock", "src\Agent EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ModelContextProtocol", "samples\Demos\ModelContextProtocol\ModelContextProtocol.csproj", "{B16AC373-3DA8-4505-9510-110347CD635D}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SqlServerIntegrationTests", "src\VectorDataIntegrationTests\SqlServerIntegrationTests\SqlServerIntegrationTests.csproj", "{A5E6193C-8431-4C6E-B674-682CB41EAA0C}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -1356,6 +1359,12 @@ Global {B16AC373-3DA8-4505-9510-110347CD635D}.Publish|Any CPU.Build.0 = Debug|Any CPU {B16AC373-3DA8-4505-9510-110347CD635D}.Release|Any CPU.ActiveCfg = Release|Any CPU {B16AC373-3DA8-4505-9510-110347CD635D}.Release|Any CPU.Build.0 = Release|Any CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Publish|Any CPU.Build.0 = Debug|Any CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1541,6 +1550,7 @@ Global {DAD5FC6A-8CA0-43AC-87E1-032DFBD6B02A} = {3F260A77-B6C9-97FD-1304-4B34DA936CF4} {8C658E1E-83C8-4127-B8BF-27A638A45DDD} = {6823CD5E-2ABE-41EB-B865-F86EC13F0CF9} {B16AC373-3DA8-4505-9510-110347CD635D} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} + {A5E6193C-8431-4C6E-B674-682CB41EAA0C} = {4F381919-F1BE-47D8-8558-3187ED04A84F} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} diff --git a/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs new file mode 100644 index 000000000000..cad9bd1048c2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs @@ -0,0 +1,326 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors; + +internal abstract class SqlFilterTranslator +{ + private readonly IReadOnlyDictionary _storagePropertyNames; + private readonly LambdaExpression _lambdaExpression; + private readonly ParameterExpression _recordParameter; + protected readonly StringBuilder _sql; + + internal SqlFilterTranslator( + IReadOnlyDictionary storagePropertyNames, + LambdaExpression lambdaExpression, + StringBuilder? sql = null) + { + this._storagePropertyNames = storagePropertyNames; + this._lambdaExpression = lambdaExpression; + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + this._sql = sql ?? new(); + } + + internal StringBuilder Clause => this._sql; + + internal void Translate(bool appendWhere) + { + if (appendWhere) + { + this._sql.Append("WHERE "); + } + + this.Translate(this._lambdaExpression.Body, null); + } + + protected void Translate(Expression? node, Expression? parent) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant.Value); + return; + + case MemberExpression member: + this.TranslateMember(member, parent); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + protected void TranslateBinary(BinaryExpression binary) + { + // Special handling for null comparisons + switch (binary.NodeType) + { + case ExpressionType.Equal when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left, binary); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left, binary); + this._sql.Append(" IS NOT NULL)"); + return; + + case ExpressionType.Equal when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right, binary); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right, binary); + this._sql.Append(" IS NOT NULL)"); + return; + } + + this._sql.Append('('); + this.Translate(binary.Left, binary); + + this._sql.Append(binary.NodeType switch + { + ExpressionType.Equal => " = ", + ExpressionType.NotEqual => " <> ", + + ExpressionType.GreaterThan => " > ", + ExpressionType.GreaterThanOrEqual => " >= ", + ExpressionType.LessThan => " < ", + ExpressionType.LessThanOrEqual => " <= ", + + ExpressionType.AndAlso => " AND ", + ExpressionType.OrElse => " OR ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right, binary); + this._sql.Append(')'); + + static bool IsNull(Expression expression) + => expression is ConstantExpression { Value: null } + || (TryGetCapturedValue(expression, out _, out var capturedValue) && capturedValue is null); + } + + protected virtual void TranslateConstant(object? value) + { + // TODO: Nullable + switch (value) + { + case byte b: + this._sql.Append(b); + return; + case short s: + this._sql.Append(s); + return; + case int i: + this._sql.Append(i); + return; + case long l: + this._sql.Append(l); + return; + + case string s: + this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); + return; + case bool b: + this._sql.Append(b ? "TRUE" : "FALSE"); + return; + case Guid g: + this._sql.Append('\'').Append(g.ToString()).Append('\''); + return; + + case DateTime dateTime: + case DateTimeOffset dateTimeOffset: + case Array: + throw new NotImplementedException(); + + case null: + this._sql.Append("NULL"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression, Expression? parent) + { + switch (memberExpression) + { + case var _ when this.TryGetColumn(memberExpression, out var column): + this.TranslateColumn(column, memberExpression, parent); + return; + + case var _ when TryGetCapturedValue(memberExpression, out var name, out var value): + this.TranslateCapturedVariable(name, value); + return; + + default: + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + } + + protected virtual void TranslateColumn(string column, MemberExpression memberExpression, Expression? parent) + => this._sql.Append('"').Append(column).Append('"'); + + protected abstract void TranslateCapturedVariable(string name, object? capturedValue); + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item, methodCall); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item, methodCall); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item, MethodCallExpression parent) + { + switch (source) + { + // Contains over array column (r => r.Strings.Contains("foo")) + case var _ when this.TryGetColumn(source, out _): + this.TranslateContainsOverArrayColumn(source, item, parent); + return; + + // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) + case NewArrayExpression newArray: + this.Translate(item, parent); + this._sql.Append(" IN ("); + + var isFirst = true; + foreach (var element in newArray.Expressions) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.Translate(element, parent); + } + + this._sql.Append(')'); + return; + + // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) + case var _ when TryGetCapturedValue(source, out _, out var value): + this.TranslateContainsOverCapturedArray(source, item, parent, value); + return; + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + } + + protected abstract void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent); + + protected abstract void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value); + + private void TranslateUnary(UnaryExpression unary) + { + switch (unary.NodeType) + { + case ExpressionType.Not: + // Special handling for !(a == b) and !(a != b) + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._sql.Append("(NOT "); + this.Translate(unary.Operand, unary); + this._sql.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } + + private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) + { + if (expression is MemberExpression member && member.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + { + throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + column = null; + return false; + } + + private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? value) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + name = fieldInfo.Name; + value = fieldInfo.GetValue(constant.Value); + return true; + } + + name = null; + value = null; + return false; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj index b1904c6cc1cd..03b36f7525b1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj @@ -19,6 +19,10 @@ Postgres(with pgvector extension) connector for Semantic Kernel plugins and semantic memory + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs index 3c864cc6537f..0175243131cd 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -48,8 +48,9 @@ internal interface IPostgresVectorStoreCollectionSqlBuilder /// The name of the vector column. /// The kind of index to create. /// The distance function to use for the index. + /// Specifies whether to include IF NOT EXISTS in the command. /// The built SQL command info. - PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction); + PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction, bool ifNotExists); /// /// Builds a SQL command to drop a table in the Postgres vector store. diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs index 6c68527da5c1..b4b9707c1c99 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs @@ -1,332 +1,53 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text; namespace Microsoft.SemanticKernel.Connectors.Postgres; -internal class PostgresFilterTranslator +internal sealed class PostgresFilterTranslator : SqlFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; - private ParameterExpression _recordParameter = null!; - private readonly List _parameterValues = new(); private int _parameterIndex; - private readonly StringBuilder _sql = new(); - - internal (string Clause, List Parameters) Translate( + internal PostgresFilterTranslator( IReadOnlyDictionary storagePropertyNames, LambdaExpression lambdaExpression, - int startParamIndex) + int startParamIndex) : base(storagePropertyNames, lambdaExpression, sql: null) { - Debug.Assert(this._sql.Length == 0); - - this._storagePropertyNames = storagePropertyNames; - this._parameterIndex = startParamIndex; - - Debug.Assert(lambdaExpression.Parameters.Count == 1); - this._recordParameter = lambdaExpression.Parameters[0]; - - this._sql.Append("WHERE "); - this.Translate(lambdaExpression.Body); - return (this._sql.ToString(), this._parameterValues); } - private void Translate(Expression? node) - { - switch (node) - { - case BinaryExpression binary: - this.TranslateBinary(binary); - return; - - case ConstantExpression constant: - this.TranslateConstant(constant); - return; - - case MemberExpression member: - this.TranslateMember(member); - return; + internal List ParameterValues => this._parameterValues; - case MethodCallExpression methodCall: - this.TranslateMethodCall(methodCall); - return; - - case UnaryExpression unary: - this.TranslateUnary(unary); - return; - - default: - throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); - } - } - - private void TranslateBinary(BinaryExpression binary) + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent) { - // Special handling for null comparisons - switch (binary.NodeType) - { - case ExpressionType.Equal when IsNull(binary.Right): - this._sql.Append('('); - this.Translate(binary.Left); - this._sql.Append(" IS NULL)"); - return; - case ExpressionType.NotEqual when IsNull(binary.Right): - this._sql.Append('('); - this.Translate(binary.Left); - this._sql.Append(" IS NOT NULL)"); - return; - - case ExpressionType.Equal when IsNull(binary.Left): - this._sql.Append('('); - this.Translate(binary.Right); - this._sql.Append(" IS NULL)"); - return; - case ExpressionType.NotEqual when IsNull(binary.Left): - this._sql.Append('('); - this.Translate(binary.Right); - this._sql.Append(" IS NOT NULL)"); - return; - } - - this._sql.Append('('); - this.Translate(binary.Left); - - this._sql.Append(binary.NodeType switch - { - ExpressionType.Equal => " = ", - ExpressionType.NotEqual => " <> ", - - ExpressionType.GreaterThan => " > ", - ExpressionType.GreaterThanOrEqual => " >= ", - ExpressionType.LessThan => " < ", - ExpressionType.LessThanOrEqual => " <= ", - - ExpressionType.AndAlso => " AND ", - ExpressionType.OrElse => " OR ", - - _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) - }); - - this.Translate(binary.Right); - this._sql.Append(')'); - - static bool IsNull(Expression expression) - => expression is ConstantExpression { Value: null } - || (TryGetCapturedValue(expression, out var capturedValue) && capturedValue is null); - } - - private void TranslateConstant(ConstantExpression constant) - { - // TODO: Nullable - switch (constant.Value) - { - case byte b: - this._sql.Append(b); - return; - case short s: - this._sql.Append(s); - return; - case int i: - this._sql.Append(i); - return; - case long l: - this._sql.Append(l); - return; - - case string s: - this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); - return; - case bool b: - this._sql.Append(b ? "TRUE" : "FALSE"); - return; - case Guid g: - this._sql.Append('\'').Append(g.ToString()).Append('\''); - return; - - case DateTime: - case DateTimeOffset: - throw new NotImplementedException(); - - case Array: - throw new NotImplementedException(); - - case null: - this._sql.Append("NULL"); - return; - - default: - throw new NotSupportedException("Unsupported constant type: " + constant.Value.GetType().Name); - } + this.Translate(source, parent); + this._sql.Append(" @> ARRAY["); + this.Translate(item, parent); + this._sql.Append(']'); } - private void TranslateMember(MemberExpression memberExpression) + protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value) { - switch (memberExpression) - { - case var _ when this.TryGetColumn(memberExpression, out var column): - this._sql.Append('"').Append(column).Append('"'); - return; - - // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) - case var _ when TryGetCapturedValue(memberExpression, out var capturedValue): - // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, - // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) - if (capturedValue is null) - { - this._sql.Append("NULL"); - } - else - { - this._parameterValues.Add(capturedValue); - this._sql.Append('$').Append(this._parameterIndex++); - } - return; - - default: - throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); - } - } - - private void TranslateMethodCall(MethodCallExpression methodCall) - { - switch (methodCall) - { - // Enumerable.Contains() - case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains - when contains.Method.DeclaringType == typeof(Enumerable): - this.TranslateContains(source, item); - return; - - // List.Contains() - case - { - Method: - { - Name: nameof(Enumerable.Contains), - DeclaringType: { IsGenericType: true } declaringType - }, - Object: Expression source, - Arguments: [var item] - } when declaringType.GetGenericTypeDefinition() == typeof(List<>): - this.TranslateContains(source, item); - return; - - default: - throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); - } - } - - private void TranslateContains(Expression source, Expression item) - { - switch (source) - { - // Contains over array column (r => r.Strings.Contains("foo")) - case var _ when this.TryGetColumn(source, out _): - this.Translate(source); - this._sql.Append(" @> ARRAY["); - this.Translate(item); - this._sql.Append(']'); - return; - - // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) - case NewArrayExpression newArray: - this.Translate(item); - this._sql.Append(" IN ("); - - var isFirst = true; - foreach (var element in newArray.Expressions) - { - if (isFirst) - { - isFirst = false; - } - else - { - this._sql.Append(", "); - } - - this.Translate(element); - } - - this._sql.Append(')'); - return; - - // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) - case var _ when TryGetCapturedValue(source, out _): - this.Translate(item); - this._sql.Append(" = ANY ("); - this.Translate(source); - this._sql.Append(')'); - return; - - default: - throw new NotSupportedException("Unsupported Contains expression"); - } - } - - private void TranslateUnary(UnaryExpression unary) - { - switch (unary.NodeType) - { - case ExpressionType.Not: - // Special handling for !(a == b) and !(a != b) - if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) - { - this.TranslateBinary( - Expression.MakeBinary( - binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, - binary.Left, - binary.Right)); - return; - } - - this._sql.Append("(NOT "); - this.Translate(unary.Operand); - this._sql.Append(')'); - return; - - default: - throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); - } + this.Translate(item, parent); + this._sql.Append(" = ANY ("); + this.Translate(source, parent); + this._sql.Append(')'); } - private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) + protected override void TranslateCapturedVariable(string name, object? capturedValue) { - if (expression is MemberExpression member && member.Expression == this._recordParameter) + // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (capturedValue is null) { - if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) - { - throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); - } - - return true; + this._sql.Append("NULL"); } - - column = null; - return false; - } - - private static bool TryGetCapturedValue(Expression expression, out object? capturedValue) - { - if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + else { - capturedValue = fieldInfo.GetValue(constant.Value); - return true; + this._parameterValues.Add(capturedValue); + this._sql.Append('$').Append(this._parameterIndex++); } - - capturedValue = null; - return false; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index 521dc5633cb0..f661c09ebf44 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -124,7 +124,7 @@ public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tabl } /// - public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction) + public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction, bool ifNotExists) { // Only support creating HNSW index creation through the connector. var indexTypeName = indexKind switch @@ -149,7 +149,7 @@ public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, strin return new PostgresSqlCommandInfo( commandText: $@" - CREATE INDEX ""{indexName}"" ON {schema}.""{tableName}"" USING {indexTypeName} (""{vectorColumnName}"" {indexOps});" + CREATE INDEX {(ifNotExists ? "IF NOT EXISTS " : "")} ""{indexName}"" ON {schema}.""{tableName}"" USING {indexTypeName} (""{vectorColumnName}"" {indexOps});" ); } @@ -281,11 +281,6 @@ public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string t { NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); - if (keys == null || keys.Count == 0) - { - throw new ArgumentException("Keys cannot be null or empty", nameof(keys)); - } - var keyProperty = properties.OfType().FirstOrDefault() ?? throw new ArgumentException("Properties must contain a key property", nameof(properties)); var keyColumn = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; @@ -327,10 +322,6 @@ DELETE FROM {schema}."{tableName}" public PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, string tableName, string keyColumn, List keys) { NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); - if (keys == null || keys.Count == 0) - { - throw new ArgumentException("Keys cannot be null or empty", nameof(keys)); - } for (int i = 0; i < keys.Count; i++) { @@ -383,7 +374,7 @@ public PostgresSqlCommandInfo BuildGetNearestMatchCommand( { (not null, not null) => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), (not null, null) => GenerateLegacyFilterWhereClause(schema, tableName, propertyReader.RecordDefinition.Properties, legacyFilter, startParamIndex: 2), - (null, not null) => new PostgresFilterTranslator().Translate(propertyReader.StoragePropertyNamesMap, newFilter, startParamIndex: 2), + (null, not null) => GenerateNewFilterWhereClause(propertyReader, newFilter), _ => (Clause: string.Empty, Parameters: []) }; #pragma warning restore CS0618 // VectorSearchFilter is obsolete @@ -424,6 +415,14 @@ ORDER BY {PostgresConstants.DistanceColumnName} Parameters = [new NpgsqlParameter { Value = vectorValue }, .. parameters.Select(p => new NpgsqlParameter { Value = p })] }; } + + internal static (string Clause, List Parameters) GenerateNewFilterWhereClause(VectorStoreRecordPropertyReader propertyReader, LambdaExpression newFilter) + { + PostgresFilterTranslator translator = new(propertyReader.StoragePropertyNamesMap, newFilter, startParamIndex: 2); + translator.Translate(appendWhere: true); + return (translator.Clause.ToString(), translator.ParameterValues); + } + #pragma warning disable CS0618 // VectorSearchFilter is obsolete internal static (string Clause, List Parameters) GenerateLegacyFilterWhereClause(string schema, string tableName, IReadOnlyList properties, VectorSearchFilter legacyFilter, int startParamIndex) { diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index b97b24708b25..07c228540038 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -74,7 +74,7 @@ public async Task CreateTableAsync(string tableName, IReadOnlyList - this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, index.column, index.kind, index.function) + this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, index.column, index.kind, index.function, ifNotExists) ); // Execute the commands in a transaction. @@ -152,11 +152,19 @@ public async Task UpsertBatchAsync(string tableName, IEnumerable> GetBatchAsync(string tableName, IEnumerable keys, IReadOnlyList properties, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) where TKey : notnull { + Verify.NotNull(keys); + + List listOfKeys = keys.ToList(); + if (listOfKeys.Count == 0) + { + yield break; + } + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { - var commandInfo = this._sqlBuilder.BuildGetBatchCommand(this._schema, tableName, properties, keys.ToList(), includeVectors); + var commandInfo = this._sqlBuilder.BuildGetBatchCommand(this._schema, tableName, properties, listOfKeys, includeVectors); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) @@ -198,7 +206,13 @@ public async Task DeleteAsync(string tableName, string keyColumn, TKey key /// public async Task DeleteBatchAsync(string tableName, string keyColumn, IEnumerable keys, CancellationToken cancellationToken = default) { - var commandInfo = this._sqlBuilder.BuildDeleteBatchCommand(this._schema, tableName, keyColumn, keys.ToList()); + var listOfKeys = keys.ToList(); + if (listOfKeys.Count == 0) + { + return; + } + + var commandInfo = this._sqlBuilder.BuildDeleteBatchCommand(this._schema, tableName, keyColumn, listOfKeys); await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 6a87d2454179..5db73b801275 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -168,6 +168,8 @@ public virtual Task UpsertAsync(TRecord record, CancellationToken cancella /// public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + Verify.NotNull(records); + const string OperationName = "UpsertBatch"; var storageModels = records.Select(record => VectorStoreErrorHandler.RunModelConversion( @@ -176,6 +178,11 @@ public virtual async IAsyncEnumerable UpsertBatchAsync(IEnumerable this._mapper.MapFromDataToStorageModel(record))).ToList(); + if (storageModels.Count == 0) + { + yield break; + } + var keys = storageModels.Select(model => model[this._propertyReader.KeyPropertyStoragePropertyName]!).ToList(); await this.RunOperationAsync(OperationName, () => @@ -243,6 +250,8 @@ public virtual Task DeleteAsync(TKey key, CancellationToken cancellationToken = /// public virtual Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) { + Verify.NotNull(keys); + const string OperationName = "DeleteBatch"; return this.RunOperationAsync(OperationName, () => this._client.DeleteBatchAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, keys, cancellationToken) @@ -319,7 +328,7 @@ private async Task RunOperationAsync(string operationName, Func operation) { await operation.Invoke().ConfigureAwait(false); } - catch (Exception ex) + catch (Exception ex) when (ex is not NotSupportedException) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { @@ -336,7 +345,7 @@ private async Task RunOperationAsync(string operationName, Func> o { return await operation.Invoke().ConfigureAwait(false); } - catch (Exception ex) + catch (Exception ex) when (ex is not NotSupportedException) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj b/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj index ba73f9641bd9..045fd37fc3cf 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/Connectors.Memory.SqlServer.csproj @@ -4,7 +4,7 @@ Microsoft.SemanticKernel.Connectors.SqlServer $(AssemblyName) - netstandard2.0 + netstandard2.0;net8.0 alpha @@ -18,6 +18,10 @@ SQL Server connector for Semantic Kernel plugins and semantic memory + + + + @@ -26,4 +30,8 @@ + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs new file mode 100644 index 000000000000..452887ea7dd1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +#pragma warning disable CA1068 // CancellationToken parameters must come last + +internal static class ExceptionWrapper +{ + private const string VectorStoreType = "SqlServer"; + + internal static async Task WrapAsync( + SqlConnection connection, + SqlCommand command, + Func> func, + CancellationToken cancellationToken, + string operationName, + string? collectionName = null) + { + if (connection.State != System.Data.ConnectionState.Open) + { + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + } + + try + { + return await func(command, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException(ex.Message, ex) + { + OperationName = operationName, + VectorStoreType = VectorStoreType, + CollectionName = collectionName + }; + } + } + + internal static async Task WrapReadAsync( + SqlDataReader reader, + CancellationToken cancellationToken, + string operationName, + string? collectionName = null) + { + try + { + return await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException(ex.Message, ex) + { + OperationName = operationName, + VectorStoreType = VectorStoreType, + CollectionName = collectionName + }; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/GenericRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/GenericRecordMapper.cs new file mode 100644 index 000000000000..ff9c7851f4cb --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/GenericRecordMapper.cs @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal sealed class GenericRecordMapper : IVectorStoreRecordMapper, IDictionary> + where TKey : notnull +{ + private readonly VectorStoreRecordPropertyReader _propertyReader; + + internal GenericRecordMapper(VectorStoreRecordPropertyReader propertyReader) => this._propertyReader = propertyReader; + + public IDictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) + { + Dictionary properties = new() + { + { SqlServerCommandBuilder.GetColumnName(this._propertyReader.KeyProperty), dataModel.Key } + }; + + foreach (var property in this._propertyReader.DataProperties) + { + string name = SqlServerCommandBuilder.GetColumnName(property); + if (dataModel.Data.TryGetValue(name, out var dataValue)) + { + properties.Add(name, dataValue); + } + } + + // Add vector properties + if (dataModel.Vectors is not null) + { + foreach (var property in this._propertyReader.VectorProperties) + { + string name = SqlServerCommandBuilder.GetColumnName(property); + if (dataModel.Vectors.TryGetValue(name, out var vectorValue)) + { + if (vectorValue is ReadOnlyMemory floats) + { + properties.Add(name, floats); + } + else if (vectorValue is not null) + { + throw new VectorStoreRecordMappingException($"Vector property '{name}' contained value of non supported type: '{vectorValue.GetType().FullName}'."); + } + } + } + } + + return properties; + } + + public VectorStoreGenericDataModel MapFromStorageToDataModel(IDictionary storageModel, StorageToDataModelMapperOptions options) + { + TKey key; + var dataProperties = new Dictionary(); + var vectorProperties = new Dictionary(); + + if (storageModel.TryGetValue(SqlServerCommandBuilder.GetColumnName(this._propertyReader.KeyProperty), out var keyObject) && keyObject is not null) + { + key = (TKey)keyObject; + } + else + { + throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); + } + + foreach (var property in this._propertyReader.DataProperties) + { + string name = SqlServerCommandBuilder.GetColumnName(property); + if (storageModel.TryGetValue(name, out var dataValue)) + { + dataProperties.Add(name, dataValue); + } + } + + if (options.IncludeVectors) + { + foreach (var property in this._propertyReader.VectorProperties) + { + string name = SqlServerCommandBuilder.GetColumnName(property); + if (storageModel.TryGetValue(name, out var vectorValue)) + { + vectorProperties.Add(name, vectorValue); + } + } + } + + return new(key) { Data = dataProperties, Vectors = vectorProperties }; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs new file mode 100644 index 000000000000..240f2814e044 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/RecordMapper.cs @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Reflection; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal sealed class RecordMapper : IVectorStoreRecordMapper> +{ + private readonly VectorStoreRecordPropertyReader _propertyReader; + + internal RecordMapper(VectorStoreRecordPropertyReader propertyReader) => this._propertyReader = propertyReader; + + public IDictionary MapFromDataToStorageModel(TRecord dataModel) + { + Dictionary map = new(StringComparer.Ordinal); + + map[SqlServerCommandBuilder.GetColumnName(this._propertyReader.KeyProperty)] = this._propertyReader.KeyPropertyInfo.GetValue(dataModel); + + var dataProperties = this._propertyReader.DataProperties; + var dataPropertiesInfo = this._propertyReader.DataPropertiesInfo; + for (int i = 0; i < dataProperties.Count; i++) + { + object? value = dataPropertiesInfo[i].GetValue(dataModel); + map[SqlServerCommandBuilder.GetColumnName(dataProperties[i])] = value; + } + var vectorProperties = this._propertyReader.VectorProperties; + var vectorPropertiesInfo = this._propertyReader.VectorPropertiesInfo; + for (int i = 0; i < vectorProperties.Count; i++) + { + // We restrict the vector properties to ReadOnlyMemory so the cast here is safe. + ReadOnlyMemory floats = (ReadOnlyMemory)vectorPropertiesInfo[i].GetValue(dataModel)!; + map[SqlServerCommandBuilder.GetColumnName(vectorProperties[i])] = floats; + } + + return map; + } + + public TRecord MapFromStorageToDataModel(IDictionary storageModel, StorageToDataModelMapperOptions options) + { + TRecord record = Activator.CreateInstance()!; + SetValue(storageModel, record, this._propertyReader.KeyPropertyInfo, this._propertyReader.KeyProperty); + var data = this._propertyReader.DataProperties; + var dataInfo = this._propertyReader.DataPropertiesInfo; + for (int i = 0; i < data.Count; i++) + { + SetValue(storageModel, record, dataInfo[i], data[i]); + } + + if (options.IncludeVectors) + { + var vector = this._propertyReader.VectorProperties; + var vectorInfo = this._propertyReader.VectorPropertiesInfo; + for (int i = 0; i < vector.Count; i++) + { + object? value = storageModel[SqlServerCommandBuilder.GetColumnName(vector[i])]; + if (value is not null) + { + if (value is ReadOnlyMemory floats) + { + vectorInfo[i].SetValue(record, floats); + } + else + { + // When deserializing a string to a ReadOnlyMemory fails in SqlDataReaderDictionary, + // we store the raw value so the user can handle the error in a custom mapper. + throw new VectorStoreRecordMappingException($"Failed to deserialize vector property '{vector[i].DataModelPropertyName}', it contained value '{value}'."); + } + } + } + } + + return record; + + static void SetValue(IDictionary storageModel, object record, PropertyInfo propertyInfo, VectorStoreRecordProperty property) + { + // If we got here, there should be no column name mismatch (the query would fail). + object? value = storageModel[SqlServerCommandBuilder.GetColumnName(property)]; + + if (value is null) + { + // There is no need to call the reflection to set the null, + // as it's the default value of every .NET reference type field. + return; + } + + try + { + propertyInfo.SetValue(record, value); + } + catch (Exception ex) + { + throw new VectorStoreRecordMappingException($"Failed to set value '{value}' on property '{propertyInfo.Name}' of type '{propertyInfo.PropertyType.FullName}'.", ex); + } + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs new file mode 100644 index 000000000000..414ff8de4afd --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text.Json; +using Microsoft.Data.SqlClient; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +/// +/// This class is used to provide a dictionary-like interface to a . +/// The goal is to avoid the need of allocating a new dictionary for each row read from the database. +/// +internal sealed class SqlDataReaderDictionary : IDictionary +{ + private readonly SqlDataReader _sqlDataReader; + private readonly IReadOnlyList _vectorPropertyStoragePropertyNames; + + // This field will get instantiated lazily, only if needed by a custom mapper. + private Dictionary? _dictionary; + + internal SqlDataReaderDictionary(SqlDataReader sqlDataReader, IReadOnlyList vectorPropertyStoragePropertyNames) + { + this._sqlDataReader = sqlDataReader; + this._vectorPropertyStoragePropertyNames = vectorPropertyStoragePropertyNames; + } + + private object? Unwrap(string storageName, object? value) + { + // Let's make sure our users don't need to learn what DBNull is. + if (value is DBNull) + { + return null; + } + + // If the value is a vector, we need to deserialize it. + if (this._vectorPropertyStoragePropertyNames.Count > 0 && value is string text) + { + for (int i = 0; i < this._vectorPropertyStoragePropertyNames.Count; i++) + { + if (string.Equals(storageName, this._vectorPropertyStoragePropertyNames[i], StringComparison.Ordinal)) + { + try + { + return JsonSerializer.Deserialize>(text); + } + catch (JsonException) + { + // This may fail if the user has stored a non-float array in the database + // (or serialized it in a different way). + // We need to return the raw value, so the user can handle the error in a custom mapper. + return text; + } + } + } + } + +#if NET + // The SqlClient accepts TimeOnly as parameters, but returns them as TimeSpan. + // Since we don't support TimeSpan, we can convert it back to TimeOnly. + if (value is TimeSpan timeSpan) + { + return new TimeOnly(timeSpan.Ticks); + } +#endif + + return value; + } + + // This is the only method used by the default mapper. + public object? this[string key] + { + get => this.Unwrap(key, this._sqlDataReader[key]); + set => throw new InvalidOperationException(); + } + + public ICollection Keys => this.GetDictionary().Keys; + + public ICollection Values => this.GetDictionary().Values; + + public int Count => this._sqlDataReader.FieldCount; + + public bool IsReadOnly => true; + + public void Add(string key, object? value) => throw new InvalidOperationException(); + + public void Add(KeyValuePair item) => throw new InvalidOperationException(); + + public void Clear() => throw new InvalidOperationException(); + + public bool Contains(KeyValuePair item) + => this.TryGetValue(item.Key, out var value) && Equals(value, item.Value); + + public bool ContainsKey(string key) + { + try + { + return this._sqlDataReader.GetOrdinal(key) >= 0; + } + catch (IndexOutOfRangeException) + { + return false; + } + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + => ((ICollection>)this.GetDictionary()).CopyTo(array, arrayIndex); + + public IEnumerator> GetEnumerator() + => this.GetDictionary().GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() + => this.GetDictionary().GetEnumerator(); + + public bool Remove(string key) => throw new InvalidOperationException(); + + public bool Remove(KeyValuePair item) => throw new InvalidOperationException(); + + public bool TryGetValue(string key, out object? value) + { + try + { + value = this.Unwrap(key, this._sqlDataReader[key]); + return true; + } + catch (IndexOutOfRangeException) + { + value = default; + return false; + } + } + + private Dictionary GetDictionary() + { + if (this._dictionary is null) + { + Dictionary dictionary = new(this._sqlDataReader.FieldCount, StringComparer.Ordinal); + for (int i = 0; i < this._sqlDataReader.FieldCount; i++) + { + string name = this._sqlDataReader.GetName(i); + dictionary.Add(name, this.Unwrap(name, this._sqlDataReader[i])); + } + this._dictionary = dictionary; + } + return this._dictionary; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs index 4a1225f0a46f..42f65ef55cf8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerClient.cs @@ -80,14 +80,7 @@ public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] Ca { using (await this.OpenConnectionAsync(cancellationToken).ConfigureAwait(false)) { - using var cmd = this._connection.CreateCommand(); - cmd.CommandText = """ - SELECT TABLE_NAME - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_TYPE = 'BASE TABLE' - AND TABLE_SCHEMA = @schema - """; - cmd.Parameters.AddWithValue("@schema", this._schema); + using var cmd = SqlServerCommandBuilder.SelectTableNames(this._connection, this._schema); using var reader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) { @@ -101,16 +94,7 @@ public async Task DoesTableExistsAsync(string tableName, CancellationToken { using (await this.OpenConnectionAsync(cancellationToken).ConfigureAwait(false)) { - using var cmd = this._connection.CreateCommand(); - cmd.CommandText = """ - SELECT TABLE_NAME - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_TYPE = 'BASE TABLE' - AND TABLE_SCHEMA = @schema - AND TABLE_NAME = @tableName - """; - cmd.Parameters.AddWithValue("@schema", this._schema); - cmd.Parameters.AddWithValue("@tableName", tableName); + using var cmd = SqlServerCommandBuilder.SelectTableName(this._connection, this._schema, tableName); using var reader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); return await reader.ReadAsync(cancellationToken).ConfigureAwait(false); } @@ -121,11 +105,7 @@ public async Task DeleteTableAsync(string tableName, CancellationToken cancellat { using (await this.OpenConnectionAsync(cancellationToken).ConfigureAwait(false)) { - using var cmd = this._connection.CreateCommand(); - var fullTableName = this.GetSanitizedFullTableName(tableName); - cmd.CommandText = $""" - DROP TABLE IF EXISTS {fullTableName} - """; + using var cmd = SqlServerCommandBuilder.DropTableIfExists(this._connection, this._schema, tableName); await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs new file mode 100644 index 000000000000..17b23d7d7477 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -0,0 +1,547 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Text.Json; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; + +#pragma warning disable CA2100 // Review SQL queries for security vulnerabilities + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal static class SqlServerCommandBuilder +{ + internal static SqlCommand CreateTable( + SqlConnection connection, + string? schema, + string tableName, + bool ifNotExists, + VectorStoreRecordKeyProperty keyProperty, + IReadOnlyList dataProperties, + IReadOnlyList vectorProperties) + { + StringBuilder sb = new(200); + if (ifNotExists) + { + sb.Append("IF OBJECT_ID(N'"); + sb.AppendTableName(schema, tableName); + sb.AppendLine("', N'U') IS NULL"); + } + sb.AppendLine("BEGIN"); + sb.Append("CREATE TABLE "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(" ("); + string keyColumnName = GetColumnName(keyProperty); + var keyMapping = Map(keyProperty.PropertyType); + sb.AppendFormat("[{0}] {1} NOT NULL,", keyColumnName, keyMapping.sqlName); + sb.AppendLine(); + for (int i = 0; i < dataProperties.Count; i++) + { + sb.AppendFormat("[{0}] {1},", GetColumnName(dataProperties[i]), Map(dataProperties[i].PropertyType).sqlName); + sb.AppendLine(); + } + for (int i = 0; i < vectorProperties.Count; i++) + { + sb.AppendFormat("[{0}] VECTOR({1}),", GetColumnName(vectorProperties[i]), vectorProperties[i].Dimensions); + sb.AppendLine(); + } + sb.AppendFormat("PRIMARY KEY ([{0}])", keyColumnName); + sb.AppendLine(); + sb.AppendLine(");"); // end the table definition + + foreach (var vectorProperty in vectorProperties) + { + switch (vectorProperty.IndexKind) + { + case null: + case "": + case IndexKind.Flat: + break; + default: + throw new NotSupportedException($"Index kind {vectorProperty.IndexKind} is not supported."); + } + } + sb.Append("END;"); + + return connection.CreateCommand(sb); + } + + internal static SqlCommand DropTableIfExists(SqlConnection connection, string? schema, string tableName) + { + StringBuilder sb = new(50); + sb.Append("DROP TABLE IF EXISTS "); + sb.AppendTableName(schema, tableName); + + return connection.CreateCommand(sb); + } + + internal static SqlCommand SelectTableName(SqlConnection connection, string? schema, string tableName) + { + SqlCommand command = connection.CreateCommand(); + command.CommandText = """ + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND (@schema is NULL or TABLE_SCHEMA = @schema) + AND TABLE_NAME = @tableName + """; + command.Parameters.AddWithValue("@schema", string.IsNullOrEmpty(schema) ? DBNull.Value : schema); + command.Parameters.AddWithValue("@tableName", tableName); // the name is not escaped by us, just provided as parameter + return command; + } + + internal static SqlCommand SelectTableNames(SqlConnection connection, string? schema) + { + SqlCommand command = connection.CreateCommand(); + command.CommandText = """ + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND (@schema is NULL or TABLE_SCHEMA = @schema) + """; + command.Parameters.AddWithValue("@schema", string.IsNullOrEmpty(schema) ? DBNull.Value : schema); + return command; + } + + internal static SqlCommand MergeIntoSingle( + SqlConnection connection, + string? schema, + string tableName, + VectorStoreRecordKeyProperty keyProperty, + IReadOnlyList properties, + IDictionary record) + { + SqlCommand command = connection.CreateCommand(); + StringBuilder sb = new(200); + sb.Append("MERGE INTO "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(" AS t"); + sb.Append("USING (VALUES ("); + int paramIndex = 0; + foreach (VectorStoreRecordProperty property in properties) + { + sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); + command.AddParameter(property, paramName, record[GetColumnName(property)]); + } + sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis + sb.Append(") AS s ("); + sb.AppendColumnNames(properties); + sb.AppendLine(")"); + sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine(); + sb.AppendLine("WHEN MATCHED THEN"); + sb.Append("UPDATE SET "); + foreach (VectorStoreRecordProperty property in properties) + { + if (property != keyProperty) // don't update the key + { + sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property)); + } + } + --sb.Length; // remove the last comma + sb.AppendLine(); + + sb.Append("WHEN NOT MATCHED THEN"); + sb.AppendLine(); + sb.Append("INSERT ("); + sb.AppendColumnNames(properties); + sb.AppendLine(")"); + sb.Append("VALUES ("); + sb.AppendColumnNames(properties, prefix: "s."); + sb.AppendLine(")"); + sb.AppendFormat("OUTPUT inserted.[{0}];", GetColumnName(keyProperty)); + + command.CommandText = sb.ToString(); + return command; + } + + internal static SqlCommand? MergeIntoMany( + SqlConnection connection, + string? schema, + string tableName, + VectorStoreRecordKeyProperty keyProperty, + IReadOnlyList properties, + IEnumerable> records) + { + SqlCommand command = connection.CreateCommand(); + + StringBuilder sb = new(200); + // The DECLARE statement creates a table variable to store the keys of the inserted rows. + sb.AppendFormat("DECLARE @InsertedKeys TABLE (KeyColumn {0});", Map(keyProperty.PropertyType).sqlName); + sb.AppendLine(); + // The MERGE statement performs the upsert operation and outputs the keys of the inserted rows into the table variable. + sb.Append("MERGE INTO "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(" AS t"); // t stands for target + sb.AppendLine("USING (VALUES"); + int rowIndex = 0, paramIndex = 0; + foreach (var record in records) + { + sb.Append('('); + foreach (VectorStoreRecordProperty property in properties) + { + sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(','); + command.AddParameter(property, paramName, record[GetColumnName(property)]); + } + sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis + sb.AppendLine(","); + rowIndex++; + } + + if (rowIndex == 0) + { + return null; // there is nothing to do! + } + + sb.Length -= (1 + Environment.NewLine.Length); // remove the last comma and newline + + sb.Append(") AS s ("); // s stands for source + sb.AppendColumnNames(properties); + sb.AppendLine(")"); + sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine(); + sb.AppendLine("WHEN MATCHED THEN"); + sb.Append("UPDATE SET "); + foreach (VectorStoreRecordProperty property in properties) + { + if (property != keyProperty) // don't update the key + { + sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property)); + } + } + --sb.Length; // remove the last comma + sb.AppendLine(); + sb.Append("WHEN NOT MATCHED THEN"); + sb.AppendLine(); + sb.Append("INSERT ("); + sb.AppendColumnNames(properties); + sb.AppendLine(")"); + sb.Append("VALUES ("); + sb.AppendColumnNames(properties, prefix: "s."); + sb.AppendLine(")"); + sb.AppendFormat("OUTPUT inserted.[{0}] INTO @InsertedKeys (KeyColumn);", GetColumnName(keyProperty)); + sb.AppendLine(); + + // The SELECT statement returns the keys of the inserted rows. + sb.Append("SELECT KeyColumn FROM @InsertedKeys;"); + + command.CommandText = sb.ToString(); + return command; + } + + internal static SqlCommand DeleteSingle( + SqlConnection connection, string? schema, string tableName, + VectorStoreRecordKeyProperty keyProperty, object key) + { + SqlCommand command = connection.CreateCommand(); + + int paramIndex = 0; + StringBuilder sb = new(100); + sb.Append("DELETE FROM "); + sb.AppendTableName(schema, tableName); + sb.AppendFormat(" WHERE [{0}] = ", GetColumnName(keyProperty)); + sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName); + command.AddParameter(keyProperty, keyParamName, key); + + command.CommandText = sb.ToString(); + return command; + } + + internal static SqlCommand? DeleteMany( + SqlConnection connection, string? schema, string tableName, + VectorStoreRecordKeyProperty keyProperty, IEnumerable keys) + { + SqlCommand command = connection.CreateCommand(); + + StringBuilder sb = new(100); + sb.Append("DELETE FROM "); + sb.AppendTableName(schema, tableName); + sb.AppendFormat(" WHERE [{0}] IN (", GetColumnName(keyProperty)); + sb.AppendKeyParameterList(keys, command, keyProperty, out bool emptyKeys); + sb.Append(')'); // close the IN clause + + if (emptyKeys) + { + return null; // there is nothing to do! + } + + command.CommandText = sb.ToString(); + return command; + } + + internal static SqlCommand SelectSingle( + SqlConnection sqlConnection, string? schema, string collectionName, + VectorStoreRecordKeyProperty keyProperty, + IReadOnlyList properties, + object key, + bool includeVectors) + { + SqlCommand command = sqlConnection.CreateCommand(); + + int paramIndex = 0; + StringBuilder sb = new(200); + sb.AppendFormat("SELECT "); + sb.AppendColumnNames(properties, includeVectors: includeVectors); + sb.AppendLine(); + sb.Append("FROM "); + sb.AppendTableName(schema, collectionName); + sb.AppendLine(); + sb.AppendFormat("WHERE [{0}] = ", GetColumnName(keyProperty)); + sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName); + command.AddParameter(keyProperty, keyParamName, key); + + command.CommandText = sb.ToString(); + return command; + } + + internal static SqlCommand? SelectMany( + SqlConnection connection, string? schema, string tableName, + VectorStoreRecordKeyProperty keyProperty, + IReadOnlyList properties, + IEnumerable keys, + bool includeVectors) + { + SqlCommand command = connection.CreateCommand(); + + StringBuilder sb = new(200); + sb.AppendFormat("SELECT "); + sb.AppendColumnNames(properties, includeVectors: includeVectors); + sb.AppendLine(); + sb.Append("FROM "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(); + sb.AppendFormat("WHERE [{0}] IN (", GetColumnName(keyProperty)); + sb.AppendKeyParameterList(keys, command, keyProperty, out bool emptyKeys); + sb.Append(')'); // close the IN clause + + if (emptyKeys) + { + return null; // there is nothing to do! + } + + command.CommandText = sb.ToString(); + return command; + } + + internal static SqlCommand SelectVector( + SqlConnection connection, string? schema, string tableName, + VectorStoreRecordVectorProperty vectorProperty, + IReadOnlyList properties, + IReadOnlyDictionary storagePropertyNamesMap, + VectorSearchOptions options, + ReadOnlyMemory vector) + { + string distanceFunction = vectorProperty.DistanceFunction ?? DistanceFunction.CosineDistance; + (string distanceMetric, string sorting) = MapDistanceFunction(distanceFunction); + + SqlCommand command = connection.CreateCommand(); + command.Parameters.AddWithValue("@vector", JsonSerializer.Serialize(vector)); + + StringBuilder sb = new(200); + sb.AppendFormat("SELECT "); + sb.AppendColumnNames(properties, includeVectors: options.IncludeVectors); + sb.AppendLine(","); + sb.AppendFormat("VECTOR_DISTANCE('{0}', {1}, CAST(@vector AS VECTOR({2}))) AS [score]", + distanceMetric, GetColumnName(vectorProperty), vector.Length); + sb.AppendLine(); + sb.Append("FROM "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(); + if (options.Filter is not null) + { + int startParamIndex = command.Parameters.Count; + + SqlServerFilterTranslator translator = new(storagePropertyNamesMap, options.Filter, sb, startParamIndex: startParamIndex); + translator.Translate(appendWhere: true); + List parameters = translator.ParameterValues; + + foreach (object parameter in parameters) + { + command.AddParameter(vectorProperty, $"@_{startParamIndex++}", parameter); + } + sb.AppendLine(); + } + sb.AppendFormat("ORDER BY [score] {0}", sorting); + sb.AppendLine(); + // Negative Skip and Top values are rejected by the VectorSearchOptions property setters. + // 0 is a legal value for OFFSET. + sb.AppendFormat("OFFSET {0} ROWS FETCH NEXT {1} ROWS ONLY;", options.Skip, options.Top); + + command.CommandText = sb.ToString(); + return command; + } + + internal static string GetColumnName(VectorStoreRecordProperty property) + => property.StoragePropertyName ?? property.DataModelPropertyName; + + internal static StringBuilder AppendParameterName(this StringBuilder sb, VectorStoreRecordProperty property, ref int paramIndex, out string parameterName) + { + // In SQL Server, parameter names cannot be just a number like "@1". + // Parameter names must start with an alphabetic character or an underscore + // and can be followed by alphanumeric characters or underscores. + // Since we can't guarantee that the value returned by StoragePropertyName and DataModelPropertyName + // is valid parameter name (it can contain whitespaces, or start with a number), + // we just append the ASCII letters, stop on the first non-ASCII letter + // and append the index. + string columnName = GetColumnName(property); + int index = sb.Length; + sb.Append('@'); + foreach (char character in columnName) + { + // We don't call APIs like char.IsWhitespace as they are expensive + // as they need to handle all Unicode characters. + if (!((character is >= 'a' and <= 'z') || (character is >= 'A' and <= 'Z'))) + { + break; + } + sb.Append(character); + } + // In case the column name is empty or does not start with ASCII letters, + // we provide the underscore as a prefix (allowed). + sb.Append('_'); + // To ensure the generated parameter id is unique, we append the index. + sb.Append(paramIndex++); + parameterName = sb.ToString(index, sb.Length - index); + + return sb; + } + + internal static StringBuilder AppendTableName(this StringBuilder sb, string? schema, string tableName) + { + // If the column name contains a ], then escape it by doubling it. + // "Name with [brackets]" becomes [Name with [brackets]]]. + + sb.Append('['); + int index = sb.Length; // store the index, so we replace ] only for the appended part + + if (!string.IsNullOrEmpty(schema)) + { + sb.Append(schema); + sb.Replace("]", "]]", index, schema.Length); // replace the ] for schema + sb.Append("].["); + index = sb.Length; + } + + sb.Append(tableName); + sb.Replace("]", "]]", index, tableName.Length); + sb.Append(']'); + + return sb; + } + + private static StringBuilder AppendColumnNames(this StringBuilder sb, + IEnumerable properties, + string? prefix = null, + bool includeVectors = true) + { + bool any = false; + foreach (VectorStoreRecordProperty property in properties) + { + if (!includeVectors && property is VectorStoreRecordVectorProperty) + { + continue; + } + + if (prefix is not null) + { + sb.Append(prefix); + } + // Use square brackets to escape column names. + sb.AppendFormat("[{0}],", GetColumnName(property)); + any = true; + } + + if (any) + { + --sb.Length; // remove the last comma + } + + return sb; + } + + private static StringBuilder AppendKeyParameterList(this StringBuilder sb, + IEnumerable keys, SqlCommand command, VectorStoreRecordKeyProperty keyProperty, out bool emptyKeys) + { + int keyIndex = 0; + foreach (TKey key in keys) + { + // The caller ensures that keys collection is not null. + // We need to ensure that none of the keys is null. + Verify.NotNull(key); + + sb.AppendParameterName(keyProperty, ref keyIndex, out string keyParamName); + sb.Append(','); + command.AddParameter(keyProperty, keyParamName, key); + } + + emptyKeys = keyIndex == 0; + sb.Length--; // remove the last comma + return sb; + } + + private static SqlCommand CreateCommand(this SqlConnection connection, StringBuilder sb) + { + SqlCommand command = connection.CreateCommand(); + command.CommandText = sb.ToString(); + return command; + } + + private static void AddParameter(this SqlCommand command, VectorStoreRecordProperty property, string name, object? value) + { + switch (value) + { + case null when property.PropertyType == typeof(byte[]): + command.Parameters.Add(name, System.Data.SqlDbType.VarBinary).Value = DBNull.Value; + break; + case null: + case ReadOnlyMemory vector when vector.Length == 0: + command.Parameters.AddWithValue(name, DBNull.Value); + break; + case byte[] buffer: + command.Parameters.Add(name, System.Data.SqlDbType.VarBinary).Value = buffer; + break; + case ReadOnlyMemory vector: + command.Parameters.AddWithValue(name, JsonSerializer.Serialize(vector)); + break; + default: + command.Parameters.AddWithValue(name, value); + break; + } + } + + private static (string sqlName, string? autoGenerate) Map(Type type) + { + return type switch + { + Type t when t == typeof(byte) => ("TINYINT", null), + Type t when t == typeof(short) => ("SMALLINT", null), + Type t when t == typeof(int) => ("INT", "IDENTITY(1,1)"), + Type t when t == typeof(long) => ("BIGINT", "IDENTITY(1,1)"), + Type t when t == typeof(Guid) => ("UNIQUEIDENTIFIER", "DEFAULT NEWSEQUENTIALID()"), + Type t when t == typeof(string) => ("NVARCHAR(255)", null), + Type t when t == typeof(byte[]) => ("VARBINARY(MAX)", null), + Type t when t == typeof(bool) => ("BIT", null), + Type t when t == typeof(DateTime) => ("DATETIME2", null), +#if NET + Type t when t == typeof(TimeOnly) => ("TIME", null), +#endif + Type t when t == typeof(decimal) => ("DECIMAL", null), + Type t when t == typeof(double) => ("FLOAT", null), + Type t when t == typeof(float) => ("REAL", null), + _ => throw new NotSupportedException($"Type {type} is not supported.") + }; + } + + // Source: https://learn.microsoft.com/sql/t-sql/functions/vector-distance-transact-sql + private static (string distanceMetric, string sorting) MapDistanceFunction(string name) => name switch + { + // A value of 0 indicates that the vectors are identical in direction (cosine similarity of 1), + // while a value of 1 indicates that the vectors are orthogonal (cosine similarity of 0). + DistanceFunction.CosineDistance => ("COSINE", "ASC"), + // A value of 0 indicates that the vectors are identical, while larger values indicate greater dissimilarity. + DistanceFunction.EuclideanDistance => ("EUCLIDEAN", "ASC"), + // A value closer to 0 indicates higher similarity, while more negative values indicate greater dissimilarity. + DistanceFunction.NegativeDotProductSimilarity => ("DOT", "DESC"), + _ => throw new NotSupportedException($"Distance function {name} is not supported.") + }; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs new file mode 100644 index 000000000000..6b81cbac1ef6 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerConstants.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal static class SqlServerConstants +{ + internal static readonly HashSet SupportedKeyTypes = + [ + typeof(int), // INT + typeof(long), // BIGINT + typeof(string), // VARCHAR + typeof(Guid), // UNIQUEIDENTIFIER + typeof(DateTime), // DATETIME2 + typeof(byte[]) // VARBINARY + ]; + + internal static readonly HashSet SupportedDataTypes = + [ + typeof(int), // INT + typeof(short), // SMALLINT + typeof(byte), // TINYINT + typeof(long), // BIGINT. + typeof(Guid), // UNIQUEIDENTIFIER. + typeof(string), // NVARCHAR + typeof(byte[]), // VARBINARY + typeof(bool), // BIT + typeof(DateTime), // DATETIME2 +#if NET + // We don't support mapping TimeSpan to TIME on purpose + // See https://github.com/microsoft/semantic-kernel/pull/10623#discussion_r1980350721 + typeof(TimeOnly), // TIME +#endif + typeof(decimal), // DECIMAL + typeof(double), // FLOAT + typeof(float), // REAL + ]; + + internal static readonly HashSet SupportedVectorTypes = + [ + typeof(ReadOnlyMemory), // VECTOR + typeof(ReadOnlyMemory?) + ]; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs new file mode 100644 index 000000000000..3bd3b2f97e0b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerFilterTranslator.cs @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +internal sealed class SqlServerFilterTranslator : SqlFilterTranslator +{ + private readonly List _parameterValues = new(); + private int _parameterIndex; + + internal SqlServerFilterTranslator( + IReadOnlyDictionary storagePropertyNames, + LambdaExpression lambdaExpression, + StringBuilder sql, + int startParamIndex) + : base(storagePropertyNames, lambdaExpression, sql) + { + this._parameterIndex = startParamIndex; + } + + internal List ParameterValues => this._parameterValues; + + protected override void TranslateConstant(object? value) + { + switch (value) + { + case bool boolValue: + this._sql.Append(boolValue ? "1" : "0"); + return; + case DateTime dateTime: + this._sql.AppendFormat("'{0:yyyy-MM-dd HH:mm:ss}'", dateTime); + return; + case DateTimeOffset dateTimeOffset: + this._sql.AppendFormat("'{0:yyy-MM-dd HH:mm:ss zzz}'", dateTimeOffset); + return; + default: + base.TranslateConstant(value); + break; + } + } + + protected override void TranslateColumn(string column, MemberExpression memberExpression, Expression? parent) + { + // "SELECT * FROM MyTable WHERE BooleanColumn;" is not supported. + // "SELECT * FROM MyTable WHERE BooleanColumn = 1;" is supported. + if (memberExpression.Type == typeof(bool) + && (parent is null // Where(x => x.Bool) + || parent is UnaryExpression { NodeType: ExpressionType.Not } // Where(x => !x.Bool) + || parent is BinaryExpression { NodeType: ExpressionType.AndAlso or ExpressionType.OrElse })) // Where(x => x.Bool && other) + { + this.TranslateBinary(Expression.Equal(memberExpression, Expression.Constant(true))); + } + else + { + this._sql.Append('[').Append(column).Append(']'); + } + } + + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent) + => throw new NotSupportedException("Unsupported Contains expression"); + + protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value) + { + if (value is not IEnumerable elements) + { + throw new NotSupportedException("Unsupported Contains expression"); + } + + this.Translate(item, parent); + this._sql.Append(" IN ("); + + var isFirst = true; + foreach (var element in elements) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.TranslateConstant(element); + } + + this._sql.Append(')'); + } + + protected override void TranslateCapturedVariable(string name, object? capturedValue) + { + // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (capturedValue is null) + { + this._sql.Append("NULL"); + } + else + { + this._parameterValues.Add(capturedValue); + // SQL Server parameters can't start with a digit (but underscore is OK). + this._sql.Append("@_").Append(this._parameterIndex++); + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs new file mode 100644 index 000000000000..754f4380160c --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +/// +/// An implementation of backed by a SQL Server or Azure SQL database. +/// +public sealed class SqlServerVectorStore : IVectorStore, IDisposable +{ + private readonly SqlConnection _connection; + private readonly SqlServerVectorStoreOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// Database connection. + /// Optional configuration options. + public SqlServerVectorStore(SqlConnection connection, SqlServerVectorStoreOptions? options = null) + { + this._connection = connection; + // We need to create a copy, so any changes made to the option bag after + // the ctor call do not affect this instance. + this._options = options is not null + ? new() { Schema = options.Schema } + : SqlServerVectorStoreOptions.Defaults; + } + + /// + public void Dispose() => this._connection.Dispose(); + + /// + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull + { + Verify.NotNull(name); + + return new SqlServerVectorStoreRecordCollection( + this._connection, + name, + new() + { + Schema = this._options.Schema, + RecordDefinition = vectorStoreRecordDefinition + }); + } + + /// + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + using SqlCommand command = SqlServerCommandBuilder.SelectTableNames(this._connection, this._options.Schema); + + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._connection, command, + static (cmd, ct) => cmd.ExecuteReaderAsync(ct), + cancellationToken, "ListCollection").ConfigureAwait(false); + + while (await ExceptionWrapper.WrapReadAsync(reader, cancellationToken, "ListCollection").ConfigureAwait(false)) + { + yield return reader.GetString(reader.GetOrdinal("table_name")); + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs new file mode 100644 index 000000000000..a90b474a3d5f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreOptions.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +/// +/// Options for creating a . +/// +public sealed class SqlServerVectorStoreOptions +{ + internal static readonly SqlServerVectorStoreOptions Defaults = new(); + + /// + /// Gets or sets the database schema. + /// + public string? Schema { get; init; } = null; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..d2a87a34bf6d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs @@ -0,0 +1,369 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +/// +/// An implementation of backed by a SQL Server or Azure SQL database. +/// +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix (Collection) +public sealed class SqlServerVectorStoreRecordCollection +#pragma warning restore CA1711 + : IVectorStoreRecordCollection where TKey : notnull +{ + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly SqlServerVectorStoreRecordCollectionOptions s_defaultOptions = new(); + + private readonly SqlConnection _sqlConnection; + private readonly SqlServerVectorStoreRecordCollectionOptions _options; + private readonly VectorStoreRecordPropertyReader _propertyReader; + private readonly IVectorStoreRecordMapper> _mapper; + + /// + /// Initializes a new instance of the class. + /// + /// Database connection. + /// The name of the collection. + /// Optional configuration options. + public SqlServerVectorStoreRecordCollection( + SqlConnection connection, + string name, + SqlServerVectorStoreRecordCollectionOptions? options = null) + { + Verify.NotNull(connection); + Verify.NotNull(name); + + VectorStoreRecordPropertyReader propertyReader = new(typeof(TRecord), + options?.RecordDefinition, + new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + }); + + if (VectorStoreRecordPropertyVerification.IsGenericDataModel(typeof(TRecord))) + { + VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.Mapper is not null, SqlServerConstants.SupportedKeyTypes); + } + else + { + propertyReader.VerifyKeyProperties(SqlServerConstants.SupportedKeyTypes); + } + propertyReader.VerifyDataProperties(SqlServerConstants.SupportedDataTypes, supportEnumerable: false); + propertyReader.VerifyVectorProperties(SqlServerConstants.SupportedVectorTypes); + + this._sqlConnection = connection; + this.CollectionName = name; + // We need to create a copy, so any changes made to the option bag after + // the ctor call do not affect this instance. + this._options = options is null ? s_defaultOptions + : new() + { + Schema = options.Schema, + Mapper = options.Mapper, + RecordDefinition = options.RecordDefinition, + }; + this._propertyReader = propertyReader; + + if (options is not null && options.Mapper is not null) + { + this._mapper = options.Mapper; + } + else if (typeof(TRecord).IsGenericType && typeof(TRecord).GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>)) + { + this._mapper = (new GenericRecordMapper(propertyReader) as IVectorStoreRecordMapper>)!; + } + else + { + propertyReader.VerifyHasParameterlessConstructor(); + + this._mapper = new RecordMapper(propertyReader); + } + } + + /// + public string CollectionName { get; } + + /// + public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + using SqlCommand command = SqlServerCommandBuilder.SelectTableName( + this._sqlConnection, this._options.Schema, this.CollectionName); + + return await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static async (cmd, ct) => + { + using SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); + return await reader.ReadAsync(ct).ConfigureAwait(false); + }, cancellationToken, "CollectionExists", this.CollectionName).ConfigureAwait(false); + } + + /// + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + => this.CreateCollectionAsync(ifNotExists: false, cancellationToken); + + /// + public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + => this.CreateCollectionAsync(ifNotExists: true, cancellationToken); + + private async Task CreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken) + { + foreach (var vectorProperty in this._propertyReader.VectorProperties) + { + if (vectorProperty.Dimensions is not > 0) + { + throw new InvalidOperationException($"Property {nameof(vectorProperty.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' must be set to a positive integer to create a collection."); + } + } + + using SqlCommand command = SqlServerCommandBuilder.CreateTable( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + ifNotExists, + this._propertyReader.KeyProperty, + this._propertyReader.DataProperties, + this._propertyReader.VectorProperties); + + await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), + cancellationToken, "CreateCollection", this.CollectionName).ConfigureAwait(false); + } + + /// + public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + using SqlCommand command = SqlServerCommandBuilder.DropTableIfExists( + this._sqlConnection, this._options.Schema, this.CollectionName); + + await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), + cancellationToken, "DeleteCollection", this.CollectionName).ConfigureAwait(false); + } + + /// + public async Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) + { + Verify.NotNull(key); + + using SqlCommand command = SqlServerCommandBuilder.DeleteSingle( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + this._propertyReader.KeyProperty, + key); + + await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), + cancellationToken, "Delete", this.CollectionName).ConfigureAwait(false); + } + + /// + public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + using SqlCommand? command = SqlServerCommandBuilder.DeleteMany( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + this._propertyReader.KeyProperty, + keys); + + if (command is null) + { + return; // keys is empty, there is nothing to delete + } + + await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct), + cancellationToken, "DeleteBatch", this.CollectionName).ConfigureAwait(false); + } + + /// + public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(key); + + bool includeVectors = options?.IncludeVectors is true; + + using SqlCommand command = SqlServerCommandBuilder.SelectSingle( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + this._propertyReader.KeyProperty, + this._propertyReader.Properties, + key, + includeVectors); + + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static async (cmd, ct) => + { + SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); + await reader.ReadAsync(ct).ConfigureAwait(false); + return reader; + }, cancellationToken, "Get", this.CollectionName).ConfigureAwait(false); + + return reader.HasRows + ? this._mapper.MapFromStorageToDataModel( + new SqlDataReaderDictionary(reader, this._propertyReader.VectorPropertyStoragePropertyNames), + new() { IncludeVectors = includeVectors }) + : default; + } + + /// + public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + bool includeVectors = options?.IncludeVectors is true; + + using SqlCommand? command = SqlServerCommandBuilder.SelectMany( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + this._propertyReader.KeyProperty, + this._propertyReader.Properties, + keys, + includeVectors); + + if (command is null) + { + yield break; // keys is empty + } + + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteReaderAsync(ct), + cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false); + + while (await ExceptionWrapper.WrapReadAsync(reader, cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false)) + { + yield return this._mapper.MapFromStorageToDataModel( + new SqlDataReaderDictionary(reader, this._propertyReader.VectorPropertyStoragePropertyNames), + new() { IncludeVectors = includeVectors }); + } + } + + /// + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + { + Verify.NotNull(record); + + using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + this._propertyReader.KeyProperty, + this._propertyReader.Properties, + this._mapper.MapFromDataToStorageModel(record)); + + return await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + async static (cmd, ct) => + { + using SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); + await reader.ReadAsync(ct).ConfigureAwait(false); + return reader.GetFieldValue(0); + }, cancellationToken, "Upsert", this.CollectionName).ConfigureAwait(false); + } + + /// + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(records); + + using SqlCommand? command = SqlServerCommandBuilder.MergeIntoMany( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + this._propertyReader.KeyProperty, + this._propertyReader.Properties, + records.Select(record => this._mapper.MapFromDataToStorageModel(record))); + + if (command is null) + { + yield break; // records is empty + } + + using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + static (cmd, ct) => cmd.ExecuteReaderAsync(ct), + cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false); + + while (await ExceptionWrapper.WrapReadAsync(reader, cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false)) + { + yield return reader.GetFieldValue(0); + } + } + + /// + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(vector); + + if (vector is not ReadOnlyMemory allowed) + { + throw new NotSupportedException( + $"The provided vector type {vector.GetType().FullName} is not supported by the SQL Server connector. " + + $"Supported types are: {string.Join(", ", SqlServerConstants.SupportedVectorTypes.Select(l => l.FullName))}"); + } +#pragma warning disable CS0618 // Type or member is obsolete + else if (options is not null && options.OldFilter is not null) +#pragma warning restore CS0618 // Type or member is obsolete + { + throw new NotSupportedException("The obsolete Filter is not supported by the SQL Server connector, use NewFilter instead."); + } + + var searchOptions = options ?? s_defaultVectorSearchOptions; + var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(searchOptions.VectorPropertyName); + + using SqlCommand command = SqlServerCommandBuilder.SelectVector( + this._sqlConnection, + this._options.Schema, + this.CollectionName, + vectorProperty, + this._propertyReader.Properties, + this._propertyReader.StoragePropertyNamesMap, + searchOptions, + allowed); + + return await ExceptionWrapper.WrapAsync(this._sqlConnection, command, + (cmd, ct) => + { + var results = this.ReadVectorSearchResultsAsync(cmd, searchOptions.IncludeVectors, ct); + return Task.FromResult(new VectorSearchResults(results)); + }, cancellationToken, "VectorizedSearch", this.CollectionName).ConfigureAwait(false); + } + + private async IAsyncEnumerable> ReadVectorSearchResultsAsync( + SqlCommand command, + bool includeVectors, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + StorageToDataModelMapperOptions options = new() { IncludeVectors = includeVectors }; + var vectorPropertyStoragePropertyNames = includeVectors ? this._propertyReader.VectorPropertyStoragePropertyNames : []; + using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + + int scoreIndex = -1; + while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + if (scoreIndex < 0) + { + scoreIndex = reader.GetOrdinal("score"); + } + + yield return new VectorSearchResult( + this._mapper.MapFromStorageToDataModel(new SqlDataReaderDictionary(reader, vectorPropertyStoragePropertyNames), options), + reader.GetDouble(scoreIndex)); + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..6b21a5e35842 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +/// +/// Options when creating a . +/// +public sealed class SqlServerVectorStoreRecordCollectionOptions +{ + /// + /// Gets or sets the database schema. + /// + public string? Schema { get; init; } + + /// + /// Gets or sets an optional custom mapper to use when converting between the data model and the SQL Server record. + /// + /// + /// If not set, the default mapper will be used. + /// + public IVectorStoreRecordMapper>? Mapper { get; init; } + + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? RecordDefinition { get; init; } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj index 2b17f3e0bbe3..fec218bfc49d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Connectors.Memory.Sqlite.csproj @@ -18,6 +18,10 @@ SQLite connector for Semantic Kernel plugins and semantic memory + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs index 2cb6b16fc8cd..963c1184d274 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs @@ -3,357 +3,77 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text; namespace Microsoft.SemanticKernel.Connectors.Sqlite; -internal class SqliteFilterTranslator +internal sealed class SqliteFilterTranslator : SqlFilterTranslator { - private IReadOnlyDictionary _storagePropertyNames = null!; - private ParameterExpression _recordParameter = null!; - private readonly Dictionary _parameters = new(); - private readonly StringBuilder _sql = new(); - - internal (string Clause, Dictionary) Translate(IReadOnlyDictionary storagePropertyNames, LambdaExpression lambdaExpression) - { - Debug.Assert(this._sql.Length == 0); - - this._storagePropertyNames = storagePropertyNames; - - Debug.Assert(lambdaExpression.Parameters.Count == 1); - this._recordParameter = lambdaExpression.Parameters[0]; - - this.Translate(lambdaExpression.Body); - return (this._sql.ToString(), this._parameters); - } - - private void Translate(Expression? node) - { - switch (node) - { - case BinaryExpression binary: - this.TranslateBinary(binary); - return; - - case ConstantExpression constant: - this.TranslateConstant(constant); - return; - - case MemberExpression member: - this.TranslateMember(member); - return; - - case MethodCallExpression methodCall: - this.TranslateMethodCall(methodCall); - return; - - case UnaryExpression unary: - this.TranslateUnary(unary); - return; - - default: - throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); - } - } - - private void TranslateBinary(BinaryExpression binary) + internal SqliteFilterTranslator(IReadOnlyDictionary storagePropertyNames, + LambdaExpression lambdaExpression) : base(storagePropertyNames, lambdaExpression, sql: null) { - // Special handling for null comparisons - switch (binary.NodeType) - { - case ExpressionType.Equal when IsNull(binary.Right): - this._sql.Append('('); - this.Translate(binary.Left); - this._sql.Append(" IS NULL)"); - return; - case ExpressionType.NotEqual when IsNull(binary.Right): - this._sql.Append('('); - this.Translate(binary.Left); - this._sql.Append(" IS NOT NULL)"); - return; - - case ExpressionType.Equal when IsNull(binary.Left): - this._sql.Append('('); - this.Translate(binary.Right); - this._sql.Append(" IS NULL)"); - return; - case ExpressionType.NotEqual when IsNull(binary.Left): - this._sql.Append('('); - this.Translate(binary.Right); - this._sql.Append(" IS NOT NULL)"); - return; - } - - this._sql.Append('('); - this.Translate(binary.Left); - - this._sql.Append(binary.NodeType switch - { - ExpressionType.Equal => " = ", - ExpressionType.NotEqual => " <> ", - - ExpressionType.GreaterThan => " > ", - ExpressionType.GreaterThanOrEqual => " >= ", - ExpressionType.LessThan => " < ", - ExpressionType.LessThanOrEqual => " <= ", - - ExpressionType.AndAlso => " AND ", - ExpressionType.OrElse => " OR ", - - _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) - }); - - this.Translate(binary.Right); - this._sql.Append(')'); - - static bool IsNull(Expression expression) - => expression is ConstantExpression { Value: null } - || (TryGetCapturedValue(expression, out _, out var capturedValue) && capturedValue is null); } - private void TranslateConstant(ConstantExpression constant) - => this.GenerateLiteral(constant.Value); + internal Dictionary Parameters => this._parameters; - private void GenerateLiteral(object? value) - { - // TODO: Nullable - switch (value) - { - case byte b: - this._sql.Append(b); - return; - case short s: - this._sql.Append(s); - return; - case int i: - this._sql.Append(i); - return; - case long l: - this._sql.Append(l); - return; - - case string s: - this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); - return; - case bool b: - this._sql.Append(b ? "TRUE" : "FALSE"); - return; - case Guid g: - this._sql.Append('\'').Append(g.ToString()).Append('\''); - return; - - case DateTime: - case DateTimeOffset: - throw new NotImplementedException(); - - case Array: - throw new NotImplementedException(); + // TODO: support Contains over array fields (#10343) + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent) + => throw new NotSupportedException("Unsupported Contains expression"); - case null: - this._sql.Append("NULL"); - return; - - default: - throw new NotSupportedException("Unsupported constant type: " + value.GetType().Name); - } - } - - private void TranslateMember(MemberExpression memberExpression) + protected override void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value) { - switch (memberExpression) + if (value is not IEnumerable elements) { - case var _ when this.TryGetColumn(memberExpression, out var column): - this._sql.Append('"').Append(column).Append('"'); - return; - - // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) - case var _ when TryGetCapturedValue(memberExpression, out var name, out var value): - // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, - // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) - if (value is null) - { - this._sql.Append("NULL"); - } - else - { - // Duplicate parameter name, create a new parameter with a different name - // TODO: Share the same parameter when it references the same captured value - if (this._parameters.ContainsKey(name)) - { - var baseName = name; - var i = 0; - do - { - name = baseName + (i++); - } while (this._parameters.ContainsKey(name)); - } - - this._parameters.Add(name, value); - this._sql.Append('@').Append(name); - } - return; - - default: - throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + throw new NotSupportedException("Unsupported Contains expression"); } - } - - private void TranslateMethodCall(MethodCallExpression methodCall) - { - switch (methodCall) - { - // Enumerable.Contains() - case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains - when contains.Method.DeclaringType == typeof(Enumerable): - this.TranslateContains(source, item); - return; - - // List.Contains() - case - { - Method: - { - Name: nameof(Enumerable.Contains), - DeclaringType: { IsGenericType: true } declaringType - }, - Object: Expression source, - Arguments: [var item] - } when declaringType.GetGenericTypeDefinition() == typeof(List<>): - this.TranslateContains(source, item); - return; - default: - throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); - } - } + this.Translate(item, parent); + this._sql.Append(" IN ("); - private void TranslateContains(Expression source, Expression item) - { - switch (source) + var isFirst = true; + foreach (var element in elements) { - // TODO: support Contains over array fields (#10343) - // Contains over array column (r => r.Strings.Contains("foo")) - case var _ when this.TryGetColumn(source, out _): - goto default; - - // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) - case NewArrayExpression newArray: + if (isFirst) { - this.Translate(item); - this._sql.Append(" IN ("); - - var isFirst = true; - foreach (var element in newArray.Expressions) - { - if (isFirst) - { - isFirst = false; - } - else - { - this._sql.Append(", "); - } - - this.Translate(element); - } - - this._sql.Append(')'); - return; + isFirst = false; } - - // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) - case var _ when TryGetCapturedValue(source, out _, out var value) && value is IEnumerable elements: + else { - this.Translate(item); - this._sql.Append(" IN ("); - - var isFirst = true; - foreach (var element in elements) - { - if (isFirst) - { - isFirst = false; - } - else - { - this._sql.Append(", "); - } - - this.GenerateLiteral(element); - } - - this._sql.Append(')'); - return; + this._sql.Append(", "); } - default: - throw new NotSupportedException("Unsupported Contains expression"); + this.TranslateConstant(element); } + + this._sql.Append(')'); } - private void TranslateUnary(UnaryExpression unary) + protected override void TranslateCapturedVariable(string name, object? capturedValue) { - switch (unary.NodeType) + // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (capturedValue is null) { - case ExpressionType.Not: - // Special handling for !(a == b) and !(a != b) - if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) - { - this.TranslateBinary( - Expression.MakeBinary( - binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, - binary.Left, - binary.Right)); - return; - } - - this._sql.Append("(NOT "); - this.Translate(unary.Operand); - this._sql.Append(')'); - return; - - default: - throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + this._sql.Append("NULL"); } - } - - private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) - { - if (expression is MemberExpression member && member.Expression == this._recordParameter) + else { - if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + // Duplicate parameter name, create a new parameter with a different name + // TODO: Share the same parameter when it references the same captured value + if (this._parameters.ContainsKey(name)) { - throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + var baseName = name; + var i = 0; + do + { + name = baseName + (i++); + } while (this._parameters.ContainsKey(name)); } - return true; - } - - column = null; - return false; - } - - private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? value) - { - if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } - && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) - && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) - { - name = fieldInfo.Name; - value = fieldInfo.GetValue(constant.Value); - return true; + this._parameters.Add(name, capturedValue); + this._sql.Append('@').Append(name); } - - name = null; - value = null; - return false; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs index 6cbd1a27d474..835073cf9c59 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs @@ -204,7 +204,10 @@ public virtual Task> VectorizedSearchAsync } else if (searchOptions.Filter is not null) { - (extraWhereFilter, extraParameters) = new SqliteFilterTranslator().Translate(this._propertyReader.StoragePropertyNamesMap, searchOptions.Filter); + SqliteFilterTranslator translator = new(this._propertyReader.StoragePropertyNamesMap, searchOptions.Filter); + translator.Translate(appendWhere: false); + extraWhereFilter = translator.Clause.ToString(); + extraParameters = translator.Parameters; } #pragma warning restore CS0618 // VectorSearchFilter is obsolete diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs index e1958f934c5d..60dd98f45e7a 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -76,10 +76,13 @@ public void TestBuildCreateTableCommand(bool ifNotExists) } [Theory] - [InlineData(IndexKind.Hnsw, DistanceFunction.EuclideanDistance)] - [InlineData(IndexKind.IvfFlat, DistanceFunction.DotProductSimilarity)] - [InlineData(IndexKind.Hnsw, DistanceFunction.CosineDistance)] - public void TestBuildCreateIndexCommand(string indexKind, string distanceFunction) + [InlineData(IndexKind.Hnsw, DistanceFunction.EuclideanDistance, true)] + [InlineData(IndexKind.Hnsw, DistanceFunction.EuclideanDistance, false)] + [InlineData(IndexKind.IvfFlat, DistanceFunction.DotProductSimilarity, true)] + [InlineData(IndexKind.IvfFlat, DistanceFunction.DotProductSimilarity, false)] + [InlineData(IndexKind.Hnsw, DistanceFunction.CosineDistance, true)] + [InlineData(IndexKind.Hnsw, DistanceFunction.CosineDistance, false)] + public void TestBuildCreateIndexCommand(string indexKind, string distanceFunction, bool ifNotExists) { var builder = new PostgresVectorStoreCollectionSqlBuilder(); @@ -87,15 +90,28 @@ public void TestBuildCreateIndexCommand(string indexKind, string distanceFunctio if (indexKind != IndexKind.Hnsw) { - Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction)); + Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction, ifNotExists)); + Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction, ifNotExists)); return; } - var cmdInfo = builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction); + var cmdInfo = builder.BuildCreateVectorIndexCommand("public", "1testcollection", vectorColumn, indexKind, distanceFunction, ifNotExists); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("CREATE INDEX ", cmdInfo.CommandText); - Assert.Contains("ON public.\"testcollection\" USING hnsw (\"embedding1\" ", cmdInfo.CommandText); + // Make sure ifNotExists is respected + if (ifNotExists) + { + Assert.Contains("CREATE INDEX IF NOT EXISTS", cmdInfo.CommandText); + } + else + { + Assert.DoesNotContain("CREATE INDEX IF NOT EXISTS", cmdInfo.CommandText); + } + // Make sure the name is escaped, so names starting with a digit are OK. + Assert.Contains($"\"1testcollection_{vectorColumn}_index\"", cmdInfo.CommandText); + + Assert.Contains("ON public.\"1testcollection\" USING hnsw (\"embedding1\" ", cmdInfo.CommandText); if (distanceFunction == null) { // Check for distance function defaults to cosine distance diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs index 4973d6e637cb..81b6f0124c30 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/VectorStoreRecordKeyProperty.cs @@ -30,4 +30,9 @@ public VectorStoreRecordKeyProperty(VectorStoreRecordKeyProperty source) : base(source) { } + + /// + /// Gets or sets a value indicating whether the key should be auto-generated by the vector store. + /// + public bool AutoGenerate { get; init; } } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs index 54db3c74ca21..7976184f8ebf 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs @@ -10,6 +10,8 @@ namespace Microsoft.Extensions.VectorData; /// public class VectorSearchOptions { + private int _top = 3, _skip = 0; + /// /// Gets or sets a search filter to use before doing the vector search. /// @@ -34,12 +36,38 @@ public class VectorSearchOptions /// /// Gets or sets the maximum number of results to return. /// - public int Top { get; init; } = 3; + /// Thrown when the value is less than 1. + public int Top + { + get => this._top; + init + { + if (value < 1) + { + throw new ArgumentOutOfRangeException(nameof(value), "Top must be greater than or equal to 1."); + } + + this._top = value; + } + } /// /// Gets or sets the number of results to skip before returning results, that is, the index of the first result to return. /// - public int Skip { get; init; } = 0; + /// Thrown when the value is less than 0. + public int Skip + { + get => this._skip; + init + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "Skip must be greater than or equal to 0."); + } + + this._skip = value; + } + } /// /// Gets or sets a value indicating whether to include vectors in the retrieval result. diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index 06b2e839116b..6a75257b503a 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -41,6 +41,7 @@ + @@ -89,7 +90,6 @@ - diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs index b85800e6a244..08337bd0f138 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyVerification.cs @@ -191,6 +191,9 @@ var enumerableType when GetGenericEnumerableInterface(enumerableType) is Type en return null; } + internal static bool IsGenericDataModel(Type recordType) + => recordType.IsGenericType && recordType.GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>); + /// /// Checks that if the provided is a that the key type is supported by the default mappers. /// If not supported, a custom mapper must be supplied, otherwise an exception is thrown. @@ -202,7 +205,7 @@ var enumerableType when GetGenericEnumerableInterface(enumerableType) is Type en public static void VerifyGenericDataModelKeyType(Type recordType, bool customMapperSupplied, IEnumerable allowedKeyTypes) { // If we are not dealing with a generic data model, no need to check anything else. - if (!recordType.IsGenericType || recordType.GetGenericTypeDefinition() != typeof(VectorStoreGenericDataModel<>)) + if (!IsGenericDataModel(recordType)) { return; } diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBatchConformanceTests.cs new file mode 100644 index 000000000000..b798bab8e437 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresBatchConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace PostgresIntegrationTests.CRUD; + +public class PostgresBatchConformanceTests(PostgresFixture fixture) + : BatchConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresGenericDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresGenericDataModelConformanceTests.cs new file mode 100644 index 000000000000..1a72c1b59e01 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/CRUD/PostgresGenericDataModelConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace PostgresIntegrationTests.CRUD; + +public class PostgresGenericDataModelConformanceTests(PostgresFixture fixture) + : GenericDataModelConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresFixture.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresFixture.cs new file mode 100644 index 000000000000..6c8ce87ad984 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace PostgresIntegrationTests.Support; + +public class PostgresFixture : VectorStoreFixture +{ + public override TestStore TestStore => PostgresTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests.cs new file mode 100644 index 000000000000..97767626c5cf --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.VectorSearch; +using Xunit; + +namespace PostgresIntegrationTests.VectorSearch; + +public class PostgresVectorSearchDistanceFunctionComplianceTests(PostgresFixture fixture) : VectorSearchDistanceFunctionComplianceTests(fixture), IClassFixture +{ + public override Task EuclideanSquaredDistance() => Assert.ThrowsAsync(base.EuclideanSquaredDistance); + + public override Task Hamming() => Assert.ThrowsAsync(base.Hamming); + + public override Task NegativeDotProductSimilarity() => Assert.ThrowsAsync(base.NegativeDotProductSimilarity); +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests_Hnsw.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests_Hnsw.cs new file mode 100644 index 000000000000..2daf5cc958c2 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/VectorSearch/PostgresVectorSearchDistanceFunctionComplianceTests_Hnsw.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; + +namespace PostgresIntegrationTests.VectorSearch; + +public class PostgresVectorSearchDistanceFunctionComplianceTests_Hnsw(PostgresFixture fixture) : PostgresVectorSearchDistanceFunctionComplianceTests(fixture) +{ + protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBatchConformanceTests.cs new file mode 100644 index 000000000000..1e8ee17dd6f4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerBatchConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace SqlServerIntegrationTests.CRUD; + +public class SqlServerBatchConformanceTests(SqlServerFixture fixture) + : BatchConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerGenericDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerGenericDataModelConformanceTests.cs new file mode 100644 index 000000000000..5b98a7d46a11 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/CRUD/SqlServerGenericDataModelConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace SqlServerIntegrationTests.CRUD; + +public class SqlServerGenericDataModelConformanceTests(SqlServerFixture fixture) + : GenericDataModelConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs new file mode 100644 index 000000000000..3bae6cc48552 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Filter/SqlServerBasicFilterTests.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using Xunit; +using Xunit.Sdk; + +namespace SqlServerIntegrationTests.Filter; + +public class SqlServerBasicFilterTests(SqlServerBasicFilterTests.Fixture fixture) + : BasicFilterTests(fixture), IClassFixture +{ + public override async Task Not_over_Or() + { + // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) + // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. + await Assert.ThrowsAsync(() => base.Not_over_Or()); + + // Compensate by adding a null check: + await this.TestFilterAsync(r => r.String != null && !(r.Int == 8 || r.String == "foo")); + } + + public override async Task NotEqual_with_string() + { + // As above, null semantics + negation + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + + await this.TestFilterAsync(r => r.String != null && r.String != "foo"); + } + + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + [Fact(Skip = "Not supported")] + [Obsolete("Legacy filters are not supported")] + public override Task Legacy_And() => throw new NotSupportedException(); + + [Fact(Skip = "Not supported")] + [Obsolete("Legacy filters are not supported")] + public override Task Legacy_AnyTagEqualTo_array() => throw new NotSupportedException(); + + [Fact(Skip = "Not supported")] + [Obsolete("Legacy filters are not supported")] + public override Task Legacy_AnyTagEqualTo_List() => throw new NotSupportedException(); + + [Fact(Skip = "Not supported")] + [Obsolete("Legacy filters are not supported")] + public override Task Legacy_equality() => throw new NotSupportedException(); + + public new class Fixture : BasicFilterTests.Fixture + { + private static readonly string s_uniqueName = Guid.NewGuid().ToString(); + + public override TestStore TestStore => SqlServerTestStore.Instance; + + protected override string CollectionName => s_uniqueName; + + // Override to remove the string collection properties, which aren't (currently) supported on SqlServer + protected override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(string[]) && p.PropertyType != typeof(List)).ToList() + }; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Properties/AssemblyAttributes.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Properties/AssemblyAttributes.cs new file mode 100644 index 000000000000..8f36e2be3f06 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Properties/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +[assembly: SqlServerIntegrationTests.Support.SqlServerConnectionStringRequiredAttribute] diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs new file mode 100644 index 000000000000..0d421b6ba314 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerCommandBuilderTests.cs @@ -0,0 +1,367 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.SqlServer; +using Xunit; + +namespace SqlServerIntegrationTests; + +public class SqlServerCommandBuilderTests +{ + [Theory] + [InlineData("schema", "name", "[schema].[name]")] + [InlineData(null, "name", "[name]")] + [InlineData("schema", "[brackets]", "[schema].[[brackets]]]")] + [InlineData(null, "[needsEscaping]", "[[needsEscaping]]]")] + [InlineData("needs]escaping", "[brackets]", "[needs]]escaping].[[brackets]]]")] + public void AppendTableName(string? schema, string table, string expectedFullName) + { + StringBuilder result = new(); + + SqlServerCommandBuilder.AppendTableName(result, schema, table); + + Assert.Equal(expectedFullName, result.ToString()); + } + + [Theory] + [InlineData("name", "@name_")] // typical name + [InlineData("na me", "@na_")] // contains a whitespace, an illegal parameter name character + [InlineData("123", "@_")] // starts with a digit, also not allowed + [InlineData("ĄŻŚĆ_doesNotStartWithAscii", "@_")] // starts with a non-ASCII character + public void AppendParameterName(string propertyName, string expectedPrefix) + { + StringBuilder builder = new(); + StringBuilder expectedBuilder = new(); + VectorStoreRecordKeyProperty keyProperty = new(propertyName, typeof(string)); + + int paramIndex = 0; // we need a dedicated variable to ensure that AppendParameterName increments the index + for (int i = 0; i < 10; i++) + { + Assert.Equal(paramIndex, i); + SqlServerCommandBuilder.AppendParameterName(builder, keyProperty, ref paramIndex, out string parameterName); + Assert.Equal($"{expectedPrefix}{i}", parameterName); + expectedBuilder.Append(parameterName); + } + + Assert.Equal(expectedBuilder.ToString(), builder.ToString()); + } + + [Theory] + [InlineData("schema", "simpleName", "[simpleName]")] + [InlineData("schema", "[needsEscaping]", "[[needsEscaping]]]")] + public void DropTable(string schema, string table, string expectedTable) + { + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.DropTableIfExists(connection, schema, table); + + Assert.Equal($"DROP TABLE IF EXISTS [{schema}].{expectedTable}", command.CommandText); + } + + [Theory] + [InlineData("schema", "simpleName")] + [InlineData("schema", "[needsEscaping]")] + public void SelectTableName(string schema, string table) + { + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.SelectTableName(connection, schema, table); + + Assert.Equal( + """ + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND (@schema is NULL or TABLE_SCHEMA = @schema) + AND TABLE_NAME = @tableName + """ + , command.CommandText); + Assert.Equal(schema, command.Parameters[0].Value); + Assert.Equal(table, command.Parameters[1].Value); + } + + [Fact] + public void SelectTableNames() + { + const string SchemaName = "theSchemaName"; + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.SelectTableNames(connection, SchemaName); + + Assert.Equal( + """ + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND (@schema is NULL or TABLE_SCHEMA = @schema) + """ + , command.CommandText); + Assert.Equal(SchemaName, command.Parameters[0].Value); + Assert.Equal("@schema", command.Parameters[0].ParameterName); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CreateTable(bool ifNotExists) + { + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordDataProperty[] dataProperties = + [ + new VectorStoreRecordDataProperty("simpleName", typeof(string)), + new VectorStoreRecordDataProperty("with space", typeof(int)) + ]; + VectorStoreRecordVectorProperty[] vectorProperties = + [ + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) + { + Dimensions = 10 + } + ]; + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.CreateTable(connection, "schema", "table", + ifNotExists, keyProperty, dataProperties, vectorProperties); + + string expectedCommand = + """ + BEGIN + CREATE TABLE [schema].[table] ( + [id] BIGINT NOT NULL, + [simpleName] NVARCHAR(255), + [with space] INT, + [embedding] VECTOR(10), + PRIMARY KEY ([id]) + ); + END; + """; + if (ifNotExists) + { + expectedCommand = "IF OBJECT_ID(N'[schema].[table]', N'U') IS NULL" + Environment.NewLine + expectedCommand; + } + + AssertEqualIgnoreNewLines(expectedCommand, command.CommandText); + } + + [Fact] + public void MergeIntoSingle() + { + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordProperty[] properties = + [ + keyProperty, + new VectorStoreRecordDataProperty("simpleString", typeof(string)), + new VectorStoreRecordDataProperty("simpleInt", typeof(int)), + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) + { + Dimensions = 10 + } + ]; + + using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle(connection, "schema", "table", + keyProperty, properties, + new Dictionary + { + { "id", null }, + { "simpleString", "nameValue" }, + { "simpleInt", 134 }, + { "embedding", "{ 10.0 }" } + }); + + string expectedCommand = + """" + MERGE INTO [schema].[table] AS t + USING (VALUES (@id_0,@simpleString_1,@simpleInt_2,@embedding_3)) AS s ([id],[simpleString],[simpleInt],[embedding]) + ON (t.[id] = s.[id]) + WHEN MATCHED THEN + UPDATE SET t.[simpleString] = s.[simpleString],t.[simpleInt] = s.[simpleInt],t.[embedding] = s.[embedding] + WHEN NOT MATCHED THEN + INSERT ([id],[simpleString],[simpleInt],[embedding]) + VALUES (s.[id],s.[simpleString],s.[simpleInt],s.[embedding]) + OUTPUT inserted.[id]; + """"; + + AssertEqualIgnoreNewLines(expectedCommand, command.CommandText); + Assert.Equal("@id_0", command.Parameters[0].ParameterName); + Assert.Equal(DBNull.Value, command.Parameters[0].Value); + Assert.Equal("@simpleString_1", command.Parameters[1].ParameterName); + Assert.Equal("nameValue", command.Parameters[1].Value); + Assert.Equal("@simpleInt_2", command.Parameters[2].ParameterName); + Assert.Equal(134, command.Parameters[2].Value); + Assert.Equal("@embedding_3", command.Parameters[3].ParameterName); + Assert.Equal("{ 10.0 }", command.Parameters[3].Value); + } + + [Fact] + public void MergeIntoMany() + { + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordProperty[] properties = + [ + keyProperty, + new VectorStoreRecordDataProperty("simpleString", typeof(string)), + new VectorStoreRecordDataProperty("simpleInt", typeof(int)), + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) + { + Dimensions = 10 + } + ]; + Dictionary[] records = + [ + new Dictionary + { + { "id", 0L }, + { "simpleString", "nameValue0" }, + { "simpleInt", 134 }, + { "embedding", "{ 10.0 }" } + }, + new Dictionary + { + { "id", 1L }, + { "simpleString", "nameValue1" }, + { "simpleInt", 135 }, + { "embedding", "{ 11.0 }" } + } + ]; + + using SqlConnection connection = CreateConnection(); + using SqlCommand command = SqlServerCommandBuilder.MergeIntoMany(connection, "schema", "table", + keyProperty, properties, records)!; + + string expectedCommand = + """" + DECLARE @InsertedKeys TABLE (KeyColumn BIGINT); + MERGE INTO [schema].[table] AS t + USING (VALUES + (@id_0,@simpleString_1,@simpleInt_2,@embedding_3), + (@id_4,@simpleString_5,@simpleInt_6,@embedding_7)) AS s ([id],[simpleString],[simpleInt],[embedding]) + ON (t.[id] = s.[id]) + WHEN MATCHED THEN + UPDATE SET t.[simpleString] = s.[simpleString],t.[simpleInt] = s.[simpleInt],t.[embedding] = s.[embedding] + WHEN NOT MATCHED THEN + INSERT ([id],[simpleString],[simpleInt],[embedding]) + VALUES (s.[id],s.[simpleString],s.[simpleInt],s.[embedding]) + OUTPUT inserted.[id] INTO @InsertedKeys (KeyColumn); + SELECT KeyColumn FROM @InsertedKeys; + """"; + + AssertEqualIgnoreNewLines(expectedCommand, command.CommandText); + + for (int i = 0; i < records.Length; i++) + { + Assert.Equal($"@id_{4 * i + 0}", command.Parameters[4 * i + 0].ParameterName); + Assert.Equal((long)i, command.Parameters[4 * i + 0].Value); + Assert.Equal($"@simpleString_{4 * i + 1}", command.Parameters[4 * i + 1].ParameterName); + Assert.Equal($"nameValue{i}", command.Parameters[4 * i + 1].Value); + Assert.Equal($"@simpleInt_{4 * i + 2}", command.Parameters[4 * i + 2].ParameterName); + Assert.Equal(134 + i, command.Parameters[4 * i + 2].Value); + Assert.Equal($"@embedding_{4 * i + 3}", command.Parameters[4 * i + 3].ParameterName); + Assert.Equal($"{{ 1{i}.0 }}", command.Parameters[4 * i + 3].Value); + } + } + + [Fact] + public void DeleteSingle() + { + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.DeleteSingle(connection, + "schema", "tableName", keyProperty, 123L); + + Assert.Equal("DELETE FROM [schema].[tableName] WHERE [id] = @id_0", command.CommandText); + Assert.Equal(123L, command.Parameters[0].Value); + Assert.Equal("@id_0", command.Parameters[0].ParameterName); + } + + [Fact] + public void DeleteMany() + { + string[] keys = ["key1", "key2"]; + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(string)); + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.DeleteMany(connection, + "schema", "tableName", keyProperty, keys)!; + + Assert.Equal("DELETE FROM [schema].[tableName] WHERE [id] IN (@id_0,@id_1)", command.CommandText); + for (int i = 0; i < keys.Length; i++) + { + Assert.Equal(keys[i], command.Parameters[i].Value); + Assert.Equal($"@id_{i}", command.Parameters[i].ParameterName); + } + } + + [Fact] + public void SelectSingle() + { + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordProperty[] properties = [ + keyProperty, + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("age", typeof(int)), + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) + { + Dimensions = 10 + } + ]; + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.SelectSingle(connection, + "schema", "tableName", keyProperty, properties, 123L, includeVectors: true); + + AssertEqualIgnoreNewLines( + """"" + SELECT [id],[name],[age],[embedding] + FROM [schema].[tableName] + WHERE [id] = @id_0 + """"", command.CommandText); + Assert.Equal(123L, command.Parameters[0].Value); + Assert.Equal("@id_0", command.Parameters[0].ParameterName); + } + + [Fact] + public void SelectMany() + { + VectorStoreRecordKeyProperty keyProperty = new("id", typeof(long)); + VectorStoreRecordProperty[] properties = [ + keyProperty, + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("age", typeof(int)), + new VectorStoreRecordVectorProperty("embedding", typeof(ReadOnlyMemory)) + { + Dimensions = 10 + } + ]; + long[] keys = [123L, 456L, 789L]; + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.SelectMany(connection, + "schema", "tableName", keyProperty, properties, keys, includeVectors: true)!; + + AssertEqualIgnoreNewLines( + """"" + SELECT [id],[name],[age],[embedding] + FROM [schema].[tableName] + WHERE [id] IN (@id_0,@id_1,@id_2) + """"", command.CommandText); + for (int i = 0; i < keys.Length; i++) + { + Assert.Equal(keys[i], command.Parameters[i].Value); + Assert.Equal($"@id_{i}", command.Parameters[i].ParameterName); + } + } + + // This repo is configured with eol=lf, so the expected string should always use \n + // as long given IDE does not use \r\n. + // The actual string may use \r\n, so we just normalize both. + private static void AssertEqualIgnoreNewLines(string expected, string actual) + => Assert.Equal(expected.Replace("\r\n", "\n"), actual.Replace("\r\n", "\n")); + + // We create a connection using a fake connection string just to be able to create the SqlCommand. + private static SqlConnection CreateConnection() + => new("Server=localhost;Database=master;Integrated Security=True;"); +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj new file mode 100644 index 000000000000..4752d82818dc --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerIntegrationTests.csproj @@ -0,0 +1,47 @@ + + + + net8.0;net472 + enable + enable + + false + true + + $(NoWarn);CA2007,SKEXP0001,SKEXP0020,VSTHRD111;CS1685 + b7762d10-e29b-4bb1-8b74-b6d69a667dd4 + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + + + + + + + + + Always + + + Always + + + + diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/SqlServer/SqlServerMemoryStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs similarity index 94% rename from dotnet/src/IntegrationTests/Connectors/Memory/SqlServer/SqlServerMemoryStoreTests.cs rename to dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs index 32c0f6742546..23e714ff60bd 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/SqlServer/SqlServerMemoryStoreTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerMemoryStoreTests.cs @@ -1,9 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; using Microsoft.Data.SqlClient; using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel.Connectors.SqlServer; @@ -36,14 +32,8 @@ public async Task InitializeAsync() .AddUserSecrets() .Build(); - var connectionString = configuration["SqlServer:ConnectionString"]; - - if (string.IsNullOrWhiteSpace(connectionString)) - { - throw new ArgumentException("SqlServer memory connection string is not configured."); - } - - this._connectionString = connectionString; + this._connectionString = configuration["SqlServer:ConnectionString"] + ?? throw new ArgumentException("SqlServer memory connection string is not configured."); await this.CleanupDatabaseAsync(); await this.InitializeDatabaseAsync(); @@ -79,7 +69,7 @@ public async Task GetCollectionsAsync() await this.Store.CreateCollectionAsync("collection1"); await this.Store.CreateCollectionAsync("collection2"); - var collections = await this.Store.GetCollectionsAsync().ToListAsync(); + var collections = await this.Store.GetCollectionsAsync().ToArrayAsync(); Assert.Contains("collection1", collections); Assert.Contains("collection2", collections); } @@ -212,8 +202,8 @@ public async Task GetNearestMatchesAsync(bool withEmbeddings) await this.Store.CreateCollectionAsync(DefaultCollectionName); await this.InsertSampleDataAsync(); - List<(MemoryRecord Record, double SimilarityScore)> results = - await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2, withEmbeddings: withEmbeddings).ToListAsync(); + (MemoryRecord Record, double SimilarityScore)[] results = + await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2, withEmbeddings: withEmbeddings).ToArrayAsync(); Assert.All(results, t => Assert.True(t.SimilarityScore > 0)); @@ -254,13 +244,13 @@ public async Task GetNearestMatchesWithMinRelevanceScoreAsync() await this.Store.CreateCollectionAsync(DefaultCollectionName); await this.InsertSampleDataAsync(); - List<(MemoryRecord Record, double SimilarityScore)> results = - await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2).ToListAsync(); + (MemoryRecord Record, double SimilarityScore)[] results = + await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2).ToArrayAsync(); var firstId = results[0].Record.Metadata.Id; var firstSimilarityScore = results[0].SimilarityScore; - results = await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2, minRelevanceScore: firstSimilarityScore + 0.0001).ToListAsync(); + results = await this.Store.GetNearestMatchesAsync(DefaultCollectionName, new[] { 5f, 6f, 7f, 8f, 9f }, limit: 2, minRelevanceScore: firstSimilarityScore + 0.0001).ToArrayAsync(); Assert.DoesNotContain(firstId, results.Select(r => r.Record.Metadata.Id)); } @@ -324,18 +314,29 @@ private async Task> InsertSampleDataAsync() private async Task InitializeDatabaseAsync() { +#if NET // IAsyncDisposable is not present in Full Framework await using var connection = new SqlConnection(this._connectionString); - await connection.OpenAsync(); await using var cmd = connection.CreateCommand(); +#else + using var connection = new SqlConnection(this._connectionString); + using var cmd = connection.CreateCommand(); +#endif + + await connection.OpenAsync(); cmd.CommandText = $"CREATE SCHEMA {SchemaName}"; await cmd.ExecuteNonQueryAsync(); } private async Task CleanupDatabaseAsync() { +#if NET await using var connection = new SqlConnection(this._connectionString); - await connection.OpenAsync(); await using var cmd = connection.CreateCommand(); +#else + using var connection = new SqlConnection(this._connectionString); + using var cmd = connection.CreateCommand(); +#endif + await connection.OpenAsync(); cmd.CommandText = $""" DECLARE tables_cursor CURSOR FOR SELECT table_name diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs new file mode 100644 index 000000000000..bb658f486c87 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs @@ -0,0 +1,503 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.SqlServer; +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace SqlServerIntegrationTests; + +public class SqlServerVectorStoreTests(SqlServerFixture fixture) : IClassFixture +{ + // this test may be once executed by multiple users against a shared db instance + private static string GetUniqueCollectionName() => Guid.NewGuid().ToString(); + + [ConditionalFact] + public async Task CollectionCRUD() + { + string collectionName = GetUniqueCollectionName(); + var testStore = fixture.TestStore; + var collection = testStore.DefaultVectorStore.GetCollection(collectionName); + + try + { + Assert.False(await collection.CollectionExistsAsync()); + + Assert.False(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collectionName)); + + await collection.CreateCollectionAsync(); + + Assert.True(await collection.CollectionExistsAsync()); + Assert.True(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collectionName)); + + await collection.CreateCollectionIfNotExistsAsync(); + + Assert.True(await collection.CollectionExistsAsync()); + + await collection.DeleteCollectionAsync(); + + Assert.False(await collection.CollectionExistsAsync()); + Assert.False(await testStore.DefaultVectorStore.ListCollectionNamesAsync().ContainsAsync(collectionName)); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } + + [ConditionalFact] + public async Task RecordCRUD() + { + string collectionName = GetUniqueCollectionName(); + var testStore = fixture.TestStore; + var collection = testStore.DefaultVectorStore.GetCollection(collectionName); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + TestModel inserted = new() + { + Id = "MyId", + Number = 100, + Floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray() + }; + string key = await collection.UpsertAsync(inserted); + Assert.Equal(inserted.Id, key); + + TestModel? received = await collection.GetAsync(inserted.Id, new() { IncludeVectors = true }); + AssertEquality(inserted, received); + + TestModel updated = new() + { + Id = inserted.Id, + Number = inserted.Number + 200, // change one property + Floats = inserted.Floats + }; + key = await collection.UpsertAsync(updated); + Assert.Equal(inserted.Id, key); + + received = await collection.GetAsync(updated.Id, new() { IncludeVectors = true }); + AssertEquality(updated, received); + + VectorSearchResult vectorSearchResult = await (await collection.VectorizedSearchAsync(inserted.Floats, new() + { + VectorPropertyName = nameof(TestModel.Floats), + IncludeVectors = true + })).Results.SingleAsync(); + AssertEquality(updated, vectorSearchResult.Record); + + vectorSearchResult = await (await collection.VectorizedSearchAsync(inserted.Floats, new() + { + VectorPropertyName = nameof(TestModel.Floats), + IncludeVectors = false + })).Results.SingleAsync(); + // Make sure the vectors are not included in the result. + Assert.Equal(0, vectorSearchResult.Record.Floats.Length); + + await collection.DeleteAsync(inserted.Id); + + Assert.Null(await collection.GetAsync(inserted.Id)); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } + + [ConditionalFact] + public async Task WrongModels() + { + string collectionName = GetUniqueCollectionName(); + var testStore = fixture.TestStore; + var collection = testStore.DefaultVectorStore.GetCollection(collectionName); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + TestModel inserted = new() + { + Id = "MyId", + Text = "NotAnInt", + Number = 100, + Floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray() + }; + Assert.Equal(inserted.Id, await collection.UpsertAsync(inserted)); + + // Let's use a model with different storage names to trigger an SQL exception + // which should be mapped to VectorStoreOperationException. + var differentNamesCollection = testStore.DefaultVectorStore.GetCollection(collectionName); + VectorStoreOperationException operationEx = await Assert.ThrowsAsync(() => differentNamesCollection.GetAsync(inserted.Id)); + Assert.IsType(operationEx.InnerException); + + // Let's use a model with the same storage names, but different types + // to trigger a mapping exception (casting a string to an int). + var sameNameDifferentModelCollection = testStore.DefaultVectorStore.GetCollection(collectionName); + VectorStoreRecordMappingException mappingEx = await Assert.ThrowsAsync(() => sameNameDifferentModelCollection.GetAsync(inserted.Id)); + Assert.IsType(mappingEx.InnerException); + + // Let's use a model with the same storage names, but different types + // to trigger a mapping exception (deserializing a string to Memory). + var invalidJsonCollection = testStore.DefaultVectorStore.GetCollection(collectionName); + await Assert.ThrowsAsync(() => invalidJsonCollection.GetAsync(inserted.Id, new() { IncludeVectors = true })); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } + + [ConditionalFact] + public async Task CustomMapper() + { + string collectionName = GetUniqueCollectionName(); + TestModelMapper mapper = new(); + SqlServerVectorStoreRecordCollectionOptions options = new() + { + Mapper = mapper + }; + using SqlConnection connection = new(SqlServerTestEnvironment.ConnectionString); + SqlServerVectorStoreRecordCollection collection = new(connection, collectionName, options); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + TestModel inserted = new() + { + Id = "MyId", + Number = 100, + Floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray() + }; + string key = await collection.UpsertAsync(inserted); + Assert.Equal(inserted.Id, key); + Assert.True(mapper.MapFromDataToStorageModel_WasCalled); + Assert.False(mapper.MapFromStorageToDataModel_WasCalled); + + TestModel? received = await collection.GetAsync(inserted.Id, new() { IncludeVectors = true }); + AssertEquality(inserted, received); + Assert.True(mapper.MapFromStorageToDataModel_WasCalled); + + TestModel updated = new() + { + Id = inserted.Id, + Number = inserted.Number + 200, // change one property + Floats = inserted.Floats + }; + key = await collection.UpsertAsync(updated); + Assert.Equal(inserted.Id, key); + + received = await collection.GetAsync(updated.Id, new() { IncludeVectors = true }); + AssertEquality(updated, received); + + await collection.DeleteAsync(inserted.Id); + + Assert.Null(await collection.GetAsync(inserted.Id)); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } + + [ConditionalFact] + public async Task BatchCRUD() + { + string collectionName = GetUniqueCollectionName(); + var testStore = fixture.TestStore; + var collection = testStore.DefaultVectorStore.GetCollection(collectionName); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + TestModel[] inserted = Enumerable.Range(0, 10).Select(i => new TestModel() + { + Id = $"MyId{i}", + Number = 100 + i, + Floats = Enumerable.Range(0, 10).Select(j => (float)(i + j)).ToArray() + }).ToArray(); + + string[] keys = await collection.UpsertBatchAsync(inserted).ToArrayAsync(); + for (int i = 0; i < inserted.Length; i++) + { + Assert.Equal(inserted[i].Id, keys[i]); + } + + TestModel[] received = await collection.GetBatchAsync(keys, new() { IncludeVectors = true }).ToArrayAsync(); + for (int i = 0; i < inserted.Length; i++) + { + AssertEquality(inserted[i], received[i]); + } + + TestModel[] updated = inserted.Select(i => new TestModel() + { + Id = i.Id, + Number = i.Number + 200, // change one property + Floats = i.Floats + }).ToArray(); + + keys = await collection.UpsertBatchAsync(updated).ToArrayAsync(); + for (int i = 0; i < updated.Length; i++) + { + Assert.Equal(updated[i].Id, keys[i]); + } + + received = await collection.GetBatchAsync(keys, new() { IncludeVectors = true }).ToArrayAsync(); + for (int i = 0; i < updated.Length; i++) + { + AssertEquality(updated[i], received[i]); + } + + await collection.DeleteBatchAsync(keys); + + Assert.False(await collection.GetBatchAsync(keys).AnyAsync()); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } + + private static void AssertEquality(TestModel inserted, TestModel? received) + { + Assert.NotNull(received); + Assert.Equal(inserted.Number, received.Number); + Assert.Equal(inserted.Id, received.Id); + Assert.Equal(inserted.Floats.ToArray(), received.Floats.ToArray()); + Assert.Null(received.Text); // testing DBNull code path + } + + public sealed class TestModel + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public string? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "text")] + public string? Text { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "column")] + public int Number { get; set; } + + [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] + public ReadOnlyMemory Floats { get; set; } + } + + public sealed class SameStorageNameButDifferentType + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public string? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "text")] + public int Number { get; set; } + } + + public sealed class SameStorageNameButInvalidVector + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public string? Id { get; set; } + + [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "text")] + public ReadOnlyMemory Floats { get; set; } + } + + public sealed class DifferentStorageNames + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public string? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "text2")] + public string? Text { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "column2")] + public int Number { get; set; } + + [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding2")] + public ReadOnlyMemory Floats { get; set; } + } + +#if NETFRAMEWORK + [ConditionalFact] + public void TimeSpanIsNotSupported() + { + string collectionName = GetUniqueCollectionName(); + var testStore = fixture.TestStore; + + Assert.Throws(() => testStore.DefaultVectorStore.GetCollection(collectionName)); + } +#else + [ConditionalFact] + public async Task TimeOnlyIsSupported() + { + string collectionName = GetUniqueCollectionName(); + var testStore = fixture.TestStore; + + var collection = testStore.DefaultVectorStore.GetCollection(collectionName); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + TimeModel inserted = new() + { + Id = "MyId", + Time = new TimeOnly(12, 34, 56) + }; + string key = await collection.UpsertAsync(inserted); + Assert.Equal(inserted.Id, key); + + TimeModel? received = await collection.GetAsync(inserted.Id, new() { IncludeVectors = true }); + Assert.NotNull(received); + Assert.Equal(inserted.Time, received.Time); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } +#endif + + public sealed class TimeModel + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public string? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "time")] +#if NETFRAMEWORK + public TimeSpan Time { get; set; } +#else + public TimeOnly Time { get; set; } +#endif + } + + [ConditionalFact] + public Task CanUseFancyModels_Int() => this.CanUseFancyModels(); + + [ConditionalFact] + public Task CanUseFancyModels_Long() => this.CanUseFancyModels(); + + [ConditionalFact] + public Task CanUseFancyModels_Guid() => this.CanUseFancyModels(); + + private async Task CanUseFancyModels() where TKey : notnull + { + string collectionName = GetUniqueCollectionName(); + var testStore = fixture.TestStore; + var collection = testStore.DefaultVectorStore.GetCollection>(collectionName); + + try + { + await collection.CreateCollectionIfNotExistsAsync(); + + FancyTestModel inserted = new() + { + Id = testStore.GenerateKey(1), + Number8 = byte.MaxValue, + Number16 = short.MaxValue, + Number32 = int.MaxValue, + Number64 = long.MaxValue, + Floats = Enumerable.Range(0, 10).Select(i => (float)i).ToArray(), + Bytes = [1, 2, 3], + }; + TKey key = await collection.UpsertAsync(inserted); + Assert.NotEqual(default, key); + + FancyTestModel? received = await collection.GetAsync(key, new() { IncludeVectors = true }); + AssertEquality(inserted, received, key); + + FancyTestModel updated = new() + { + Id = key, + Number16 = short.MinValue, // change one property + Floats = inserted.Floats + }; + key = await collection.UpsertAsync(updated); + Assert.Equal(updated.Id, key); + + received = await collection.GetAsync(updated.Id, new() { IncludeVectors = true }); + AssertEquality(updated, received, key); + + await collection.DeleteAsync(key); + + Assert.Null(await collection.GetAsync(key)); + } + finally + { + await collection.DeleteCollectionAsync(); + } + + void AssertEquality(FancyTestModel expected, FancyTestModel? received, TKey expectedKey) + { + Assert.NotNull(received); + Assert.Equal(expectedKey, received.Id); + Assert.Equal(expected.Number8, received.Number8); + Assert.Equal(expected.Number16, received.Number16); + Assert.Equal(expected.Number32, received.Number32); + Assert.Equal(expected.Number64, received.Number64); + Assert.Equal(expected.Floats.ToArray(), received.Floats.ToArray()); + Assert.Equal(expected.Bytes, received.Bytes); + } + } + + public sealed class FancyTestModel + { + [VectorStoreRecordKey(StoragePropertyName = "key")] + public TKey? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "byte")] + public byte Number8 { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "short")] + public short Number16 { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "int")] + public int Number32 { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "long")] + public long Number64 { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "bytes")] +#pragma warning disable CA1819 // Properties should not return arrays + public byte[]? Bytes { get; set; } +#pragma warning restore CA1819 // Properties should not return arrays + + [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] + public ReadOnlyMemory Floats { get; set; } + } + + private sealed class TestModelMapper : IVectorStoreRecordMapper> + { + internal bool MapFromDataToStorageModel_WasCalled { get; set; } + internal bool MapFromStorageToDataModel_WasCalled { get; set; } + + public IDictionary MapFromDataToStorageModel(TestModel dataModel) + { + this.MapFromDataToStorageModel_WasCalled = true; + + return new Dictionary() + { + { "key", dataModel.Id }, + { "text", dataModel.Text }, + { "column", dataModel.Number }, + // Please note that we are not dealing with JSON directly here. + { "embedding", dataModel.Floats } + }; + } + + public TestModel MapFromStorageToDataModel(IDictionary storageModel, StorageToDataModelMapperOptions options) + { + this.MapFromStorageToDataModel_WasCalled = true; + + return new() + { + Id = (string)storageModel["key"]!, + Text = (string?)storageModel["text"], + Number = (int)storageModel["column"]!, + Floats = (ReadOnlyMemory)storageModel["embedding"]! + }; + } + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerConnectionStringRequiredAttribute.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerConnectionStringRequiredAttribute.cs new file mode 100644 index 000000000000..80885df9e18c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerConnectionStringRequiredAttribute.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Xunit; + +namespace SqlServerIntegrationTests.Support; + +/// +/// Checks whether the connection string for Sql Server is provided, and skips the test(s) otherwise. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)] +public sealed class SqlServerConnectionStringRequiredAttribute : Attribute, ITestCondition +{ + public ValueTask IsMetAsync() => new(SqlServerTestEnvironment.IsConnectionStringDefined); + + public string Skip { get; set; } = "ConnectionString is not configured, set SqlServer:ConnectionString."; + + public string SkipReason => this.Skip; +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerFixture.cs new file mode 100644 index 000000000000..dabf7b40609e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace SqlServerIntegrationTests.Support; + +public class SqlServerFixture : VectorStoreFixture +{ + public override TestStore TestStore => SqlServerTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestEnvironment.cs new file mode 100644 index 000000000000..043f4882e640 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestEnvironment.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel.Connectors.SqlServer; + +namespace SqlServerIntegrationTests.Support; + +internal static class SqlServerTestEnvironment +{ + public static readonly string? ConnectionString = GetConnectionString(); + + public static bool IsConnectionStringDefined => !string.IsNullOrEmpty(ConnectionString); + + private static string? GetConnectionString() + { + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true) + .AddJsonFile(path: "testsettings.development.json", optional: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + return configuration.GetSection("SqlServer")["ConnectionString"]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs new file mode 100644 index 000000000000..93a329b2438a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/Support/SqlServerTestStore.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.SqlServer; +using VectorDataSpecificationTests.Support; + +namespace SqlServerIntegrationTests.Support; + +public sealed class SqlServerTestStore : TestStore, IDisposable +{ + public static readonly SqlServerTestStore Instance = new(); + + public override IVectorStore DefaultVectorStore + => this._connectedStore ?? throw new InvalidOperationException("Not initialized"); + + public override string DefaultDistanceFunction => DistanceFunction.CosineDistance; + + private SqlServerVectorStore? _connectedStore; + + protected override async Task StartAsync() + { + if (string.IsNullOrWhiteSpace(SqlServerTestEnvironment.ConnectionString)) + { + throw new InvalidOperationException("Connection string is not configured, set the SqlServer:ConnectionString environment variable"); + } + +#pragma warning disable CA2000 // Dispose objects before losing scope + SqlConnection connection = new(SqlServerTestEnvironment.ConnectionString); +#pragma warning restore CA2000 // Dispose objects before losing scope + await connection.OpenAsync(); + + this._connectedStore = new(connection); + } + + public void Dispose() => this._connectedStore?.Dispose(); +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests.cs new file mode 100644 index 000000000000..b1564100eb84 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServerIntegrationTests.Support; +using VectorDataSpecificationTests.VectorSearch; +using Xunit; + +namespace SqlServerIntegrationTests.VectorSearch; + +public class SqlServerVectorSearchDistanceFunctionComplianceTests(SqlServerFixture fixture) + : VectorSearchDistanceFunctionComplianceTests(fixture), IClassFixture +{ + public override Task CosineSimilarity() => Assert.ThrowsAsync(base.CosineSimilarity); + + public override Task DotProductSimilarity() => Assert.ThrowsAsync(base.DotProductSimilarity); + + public override Task EuclideanSquaredDistance() => Assert.ThrowsAsync(base.EuclideanSquaredDistance); + + public override Task Hamming() => Assert.ThrowsAsync(base.Hamming); + + public override Task ManhattanDistance() => Assert.ThrowsAsync(base.ManhattanDistance); +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests_Hnsw.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests_Hnsw.cs new file mode 100644 index 000000000000..fe771d73278f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/VectorSearch/SqlServerVectorSearchDistanceFunctionComplianceTests_Hnsw.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServerIntegrationTests.Support; +using Xunit; + +namespace SqlServerIntegrationTests.VectorSearch; + +public class SqlServerVectorSearchDistanceFunctionComplianceTests_Hnsw(SqlServerFixture fixture) + : SqlServerVectorSearchDistanceFunctionComplianceTests(fixture) +{ + // Creating such a collection is not supported. + protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; + + public override async Task CosineDistance() + { + NotSupportedException ex = await Assert.ThrowsAsync(() => base.CosineDistance()); + Assert.Equal($"Index kind {this.IndexKind} is not supported.", ex.Message); + } + + public override async Task EuclideanDistance() + { + NotSupportedException ex = await Assert.ThrowsAsync(() => base.EuclideanDistance()); + Assert.Equal($"Index kind {this.IndexKind} is not supported.", ex.Message); + } + + public override async Task NegativeDotProductSimilarity() + { + NotSupportedException ex = await Assert.ThrowsAsync(() => base.NegativeDotProductSimilarity()); + Assert.Equal($"Index kind {this.IndexKind} is not supported.", ex.Message); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BatchConformanceTests.cs new file mode 100644 index 000000000000..ace837591a74 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/BatchConformanceTests.cs @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Models; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests.CRUD; + +public abstract class BatchConformanceTests(VectorStoreFixture fixture) + : ConformanceTestsBase>(fixture) where TKey : notnull +{ + [ConditionalFact] + public async Task UpsertBatchAsync_EmptyBatch_DoesNotThrow() + { + await this.ExecuteAsync(async collection => + { + Assert.Empty(await collection.UpsertBatchAsync([]).ToArrayAsync()); + }); + } + + [ConditionalFact] + public async Task DeleteBatchAsync_EmptyBatch_DoesNotThrow() + { + await this.ExecuteAsync(async collection => + { + await collection.DeleteBatchAsync([]); + }); + } + + [ConditionalFact] + public async Task GetBatchAsync_EmptyBatch_DoesNotThrow() + { + await this.ExecuteAsync(async collection => + { + Assert.Empty(await collection.GetBatchAsync([]).ToArrayAsync()); + }); + } + + [ConditionalFact] + public async Task UpsertBatchAsync_NullBatch_ThrowsArgumentNullException() + { + await this.ExecuteAsync(async collection => + { + ArgumentNullException ex = await Assert.ThrowsAsync(() => collection.UpsertBatchAsync(records: null!).ToArrayAsync().AsTask()); + Assert.Equal("records", ex.ParamName); + }); + } + + [ConditionalFact] + public async Task DeleteBatchAsync_NullKeys_ThrowsArgumentNullException() + { + await this.ExecuteAsync(async collection => + { + ArgumentNullException ex = await Assert.ThrowsAsync(() => collection.DeleteBatchAsync(keys: null!)); + Assert.Equal("keys", ex.ParamName); + }); + } + + [ConditionalFact] + public async Task GetBatchAsync_NullKeys_ThrowsArgumentNullException() + { + await this.ExecuteAsync(async collection => + { + ArgumentNullException ex = await Assert.ThrowsAsync(() => collection.GetBatchAsync(keys: null!).ToArrayAsync().AsTask()); + Assert.Equal("keys", ex.ParamName); + }); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/ConformanceTestsBase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/ConformanceTestsBase.cs new file mode 100644 index 000000000000..21a6c95f8986 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/ConformanceTestsBase.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Support; + +namespace VectorDataSpecificationTests.CRUD; + +// TKey is a generic parameter because different connectors support different key types. +public abstract class ConformanceTestsBase(VectorStoreFixture fixture) where TKey : notnull +{ + protected VectorStoreFixture Fixture { get; } = fixture; + + protected virtual string GetUniqueCollectionName() => Guid.NewGuid().ToString(); + + protected virtual VectorStoreRecordDefinition? GetRecordDefinition() => null; + + protected async Task ExecuteAsync(Func, Task> test) + { + string collectionName = this.GetUniqueCollectionName(); + var collection = this.Fixture.TestStore.DefaultVectorStore.GetCollection(collectionName, + this.GetRecordDefinition()); + + await collection.CreateCollectionAsync(); + + try + { + await test(collection); + } + finally + { + await collection.DeleteCollectionAsync(); + } + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/GenericDataModelConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/GenericDataModelConformanceTests.cs new file mode 100644 index 000000000000..91ac166aafd4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/GenericDataModelConformanceTests.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests.CRUD; + +public abstract class GenericDataModelConformanceTests(VectorStoreFixture fixture) + : ConformanceTestsBase>(fixture) where TKey : notnull +{ + private const string KeyPropertyName = "key"; + private const string StringPropertyName = "text"; + private const string IntegerPropertyName = "integer"; + private const string EmbeddingPropertyName = "embedding"; + private const int DimensionCount = 10; + + protected override VectorStoreRecordDefinition? GetRecordDefinition() + => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty(KeyPropertyName, typeof(TKey)), + new VectorStoreRecordDataProperty(StringPropertyName, typeof(string)), + new VectorStoreRecordDataProperty(IntegerPropertyName, typeof(int)), + new VectorStoreRecordVectorProperty(EmbeddingPropertyName, typeof(ReadOnlyMemory)) + { + Dimensions = DimensionCount + } + ] + }; + + [ConditionalFact] + public async Task CanInsertUpdateAndDelete() + { + await this.ExecuteAsync(async collection => + { + VectorStoreGenericDataModel inserted = new(key: this.Fixture.GenerateNextKey()); + inserted.Data.Add(StringPropertyName, "some"); + inserted.Data.Add(IntegerPropertyName, 123); + inserted.Vectors.Add(EmbeddingPropertyName, new ReadOnlyMemory(Enumerable.Repeat(0.1f, DimensionCount).ToArray())); + + TKey key = await collection.UpsertAsync(inserted); + Assert.Equal(inserted.Key, key); + + VectorStoreGenericDataModel? received = await collection.GetAsync(key, new() { IncludeVectors = true }); + Assert.NotNull(received); + + Assert.Equal(received.Key, key); + foreach (var pair in inserted.Data) + { + Assert.Equal(pair.Value, received.Data[pair.Key]); + } + + Assert.Equal( + ((ReadOnlyMemory)inserted.Vectors[EmbeddingPropertyName]!).ToArray(), + ((ReadOnlyMemory)received.Vectors[EmbeddingPropertyName]!).ToArray()); + + await collection.DeleteAsync(key); + + received = await collection.GetAsync(key); + Assert.Null(received); + }); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs index 0f87d2ae7c5d..dd03c1b1bda7 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTests.cs @@ -73,6 +73,14 @@ public virtual Task NotEqual_with_null_captured() public virtual Task Bool() => this.TestFilterAsync(r => r.Bool); + [ConditionalFact] + public virtual Task Bool_And_Bool() + => this.TestFilterAsync(r => r.Bool && r.Bool); + + [ConditionalFact] + public virtual Task Bool_Or_Not_Bool() + => this.TestFilterAsync(r => r.Bool || !r.Bool, expectAllResults: true); + #endregion Equality #region Comparison @@ -139,6 +147,10 @@ public virtual Task Not_over_Or() public virtual Task Not_over_bool() => this.TestFilterAsync(r => !r.Bool); + [ConditionalFact] + public virtual Task Not_over_bool_And_Comparison() + => this.TestFilterAsync(r => !r.Bool && r.Int != int.MaxValue); + #endregion Logical operators #region Contains diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleModel.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleModel.cs new file mode 100644 index 000000000000..0646f0fe2f1f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Models/SimpleModel.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; + +namespace VectorDataSpecificationTests.Models; + +/// +/// This class represents bare minimum that each connector should support: +/// a key, int, string and an embedding. +/// +/// TKey is a generic parameter because different connectors support different key types. +public sealed class SimpleModel +{ + [VectorStoreRecordKey(StoragePropertyName = "key")] + public TKey? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "text")] + public string? Text { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "number")] + public int Number { get; set; } + + [VectorStoreRecordVector(Dimensions: 10, StoragePropertyName = "embedding")] + public ReadOnlyMemory Floats { get; set; } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/VectorSearchDistanceFunctionComplianceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/VectorSearchDistanceFunctionComplianceTests.cs new file mode 100644 index 000000000000..285c93c23e92 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorSearch/VectorSearchDistanceFunctionComplianceTests.cs @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Support; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests.VectorSearch; + +public abstract class VectorSearchDistanceFunctionComplianceTests(VectorStoreFixture fixture) + where TKey : notnull +{ + [ConditionalFact] + public virtual Task CosineDistance() + => this.SimpleSearch(DistanceFunction.CosineDistance, 0, 2, 1, [0, 2, 1]); + + [ConditionalFact] + public virtual Task CosineSimilarity() + => this.SimpleSearch(DistanceFunction.CosineSimilarity, 1, -1, 0, [0, 2, 1]); + + [ConditionalFact] + public virtual Task DotProductSimilarity() + => this.SimpleSearch(DistanceFunction.DotProductSimilarity, 1, -1, 0, [0, 2, 1]); + + [ConditionalFact] + public virtual Task NegativeDotProductSimilarity() + => this.SimpleSearch(DistanceFunction.NegativeDotProductSimilarity, -1, 1, 0, [1, 2, 0]); + + [ConditionalFact] + public virtual Task EuclideanDistance() + => this.SimpleSearch(DistanceFunction.EuclideanDistance, 0, 2, 1.73, [0, 2, 1]); + + [ConditionalFact] + public virtual Task EuclideanSquaredDistance() + => this.SimpleSearch(DistanceFunction.EuclideanSquaredDistance, 0, 4, 3, [0, 2, 1]); + + [ConditionalFact] + public virtual Task Hamming() + => this.SimpleSearch(DistanceFunction.Hamming, 0, 1, 3, [0, 1, 2]); + + [ConditionalFact] + public virtual Task ManhattanDistance() + => this.SimpleSearch(DistanceFunction.ManhattanDistance, 0, 2, 3, [0, 1, 2]); + + protected virtual string? IndexKind => null; + + protected async Task SimpleSearch(string distanceFunction, double expectedExactMatchScore, + double expectedOppositeScore, double expectedOrthogonalScore, int[] resultOrder) + { + ReadOnlyMemory baseVector = new([1, 0, 0, 0]); + ReadOnlyMemory oppositeVector = new([-1, 0, 0, 0]); + ReadOnlyMemory orthogonalVector = new([0f, -1f, -1f, 0f]); + + double[] scoreDictionary = [expectedExactMatchScore, expectedOppositeScore, expectedOrthogonalScore]; + + List records = + [ + new() + { + Key = fixture.GenerateNextKey(), + Int = 8, + Vector = baseVector, + }, + new() + { + Key = fixture.GenerateNextKey(), + Int = 9, + String = "bar", + Vector = oppositeVector, + }, + new() + { + Key = fixture.GenerateNextKey(), + Int = 9, + String = "foo", + Vector = orthogonalVector, + } + ]; + + // The record definition describes the distance function, + // so we need a dedicated collection per test. + string uniqueCollectionName = Guid.NewGuid().ToString(); + var collection = fixture.TestStore.DefaultVectorStore.GetCollection( + uniqueCollectionName, this.GetRecordDefinition(distanceFunction)); + + await collection.CreateCollectionAsync(); + + await collection.CreateCollectionIfNotExistsAsync(); // just to make sure it's idempotent + + try + { + await collection.UpsertBatchAsync(records).ToArrayAsync(); + + var searchResult = await collection.VectorizedSearchAsync(baseVector); + var results = await searchResult.Results.ToListAsync(); + VerifySearchResults(resultOrder, scoreDictionary, records, results, includeVectors: false); + + searchResult = await collection.VectorizedSearchAsync(baseVector, new() { IncludeVectors = true }); + results = await searchResult.Results.ToListAsync(); + VerifySearchResults(resultOrder, scoreDictionary, records, results, includeVectors: true); + } + finally + { + await collection.DeleteCollectionAsync(); + } + + static void VerifySearchResults(int[] resultOrder, double[] scoreDictionary, List records, + List> results, bool includeVectors) + { + Assert.Equal(records.Count, results.Count); + for (int i = 0; i < results.Count; i++) + { + Assert.Equal(records[resultOrder[i]].Key, results[i].Record.Key); + Assert.Equal(records[resultOrder[i]].Int, results[i].Record.Int); + Assert.Equal(records[resultOrder[i]].String, results[i].Record.String); + Assert.Equal(Math.Round(scoreDictionary[resultOrder[i]], 2), Math.Round(results[i].Score!.Value, 2)); + + if (includeVectors) + { + Assert.Equal(records[resultOrder[i]].Vector.ToArray(), results[i].Record.Vector.ToArray()); + } + else + { + Assert.Equal(0, results[i].Record.Vector.Length); + } + } + } + } + + private VectorStoreRecordDefinition GetRecordDefinition(string distanceFunction) + => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(SearchRecord.Key), typeof(TKey)), + new VectorStoreRecordVectorProperty(nameof(SearchRecord.Vector), typeof(ReadOnlyMemory)) + { + Dimensions = 4, + DistanceFunction = distanceFunction, + IndexKind = this.IndexKind + }, + new VectorStoreRecordDataProperty(nameof(SearchRecord.Int), typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(SearchRecord.String), typeof(string)) { IsFilterable = true }, + ] + }; + + public class SearchRecord + { + public TKey Key { get; set; } = default!; + public ReadOnlyMemory Vector { get; set; } + + public int Int { get; set; } + public string? String { get; set; } + } +}