Skip to content

.Net: IVectorStore implementation for Azure SQL #10623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
160e08e
add new test project, move the existing tests to it
adamsitnik Feb 11, 2025
186fda2
port existing tests to Testcontainers.MsSql and re-enable them
adamsitnik Feb 11, 2025
22495b9
Revert "port existing tests to Testcontainers.MsSql and re-enable them"
adamsitnik Feb 12, 2025
3c3fd4a
Merge remote-tracking branch 'upstream/feature-vector-data-preb1' int…
adamsitnik Feb 12, 2025
179f56e
implement the tests using the new pattern, provide implementation tha…
adamsitnik Feb 12, 2025
486d028
implement collection removal, existence check and creation
adamsitnik Feb 14, 2025
a41cac4
implement record insert and update (upsert)
adamsitnik Feb 16, 2025
6dfb04a
implement delete operations
adamsitnik Feb 17, 2025
e905409
GetAsync and GetBatchAsync
adamsitnik Feb 18, 2025
7c212da
refactor
adamsitnik Feb 18, 2025
8d71b11
implement UpsertBatchAsync
adamsitnik Feb 18, 2025
7f18352
implement SelectTableNames, read the code again and add TODOs for thi…
adamsitnik Feb 19, 2025
32605da
ensure that parameter names are always valid
adamsitnik Feb 19, 2025
b4a73ee
add some comments
adamsitnik Feb 19, 2025
e8584be
support storing more types, support auto-generated keys
adamsitnik Feb 19, 2025
f397f3f
simplify: don't use a dedicated query for inserting a single record
adamsitnik Feb 19, 2025
ffc4b14
Merge remote-tracking branch 'upstream/feature-vector-data-preb1' int…
adamsitnik Feb 20, 2025
7c8d2dc
vector search
adamsitnik Feb 20, 2025
9e5ef1c
implement filtering by reusing a lot of code implemented by @roji
adamsitnik Feb 20, 2025
080811f
reduce code duplication
adamsitnik Feb 20, 2025
c17021e
skip some tests, some polishing
adamsitnik Feb 20, 2025
4669e91
remove a comment added by Copilot
adamsitnik Feb 20, 2025
ba0486f
Update dotnet/src/Connectors/VectorData.Abstractions/RecordAttributes…
adamsitnik Feb 24, 2025
5bdaa8e
address code review feedback:
adamsitnik Feb 24, 2025
1902c0b
address remaining feedback:
adamsitnik Feb 25, 2025
3081305
implement IndexKind support for SqlServer and fix it for PostgreSQL:
adamsitnik Feb 26, 2025
5b843aa
fix the build
adamsitnik Feb 26, 2025
8bb8aea
throw for null inputs, do nothing for empty ones
adamsitnik Feb 26, 2025
f76b573
address code review feedback:
adamsitnik Feb 28, 2025
c40f341
Update dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition…
adamsitnik Feb 28, 2025
2fe49c0
Apply suggestions from code review
adamsitnik Mar 6, 2025
4639a17
Merge remote-tracking branch 'upstream/feature-vector-data-preb1' int…
adamsitnik Mar 6, 2025
0bdca76
address code review feedback:
adamsitnik Mar 6, 2025
88419c3
Merge remote-tracking branch 'upstream/feature-vector-data-preb1' int…
adamsitnik Mar 6, 2025
9ed18ab
remove AutoGenerate
adamsitnik Mar 7, 2025
c42c6cb
Apply suggestions from code review
adamsitnik Mar 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions dotnet/SK-dotnet.sln
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "CosmosMongoDBIntegrationTests", "src\VectorDataIntegrationTests\CosmosMongoDBIntegrationTests\CosmosMongoDBIntegrationTests.csproj", "{11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AzureAISearchIntegrationTests", "src\VectorDataIntegrationTests\AzureAISearchIntegrationTests\AzureAISearchIntegrationTests.csproj", "{06181F0F-A375-43AE-B45F-73CBCFC30C14}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Agents.AzureAI", "src\Agents\AzureAI\Agents.AzureAI.csproj", "{EA35F1B5-9148-4189-BE34-5E00AED56D65}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Plugins.AI", "src\Plugins\Plugins.AI\Plugins.AI.csproj", "{0C64EC81-8116-4388-87AD-BA14D4B59974}"
Expand Down Expand Up @@ -491,6 +492,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Agents.Bedrock", "src\Agent
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ModelContextProtocol", "samples\Demos\ModelContextProtocol\ModelContextProtocol.csproj", "{B16AC373-3DA8-4505-9510-110347CD635D}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SqlServerIntegrationTests", "src\VectorDataIntegrationTests\SqlServerIntegrationTests\SqlServerIntegrationTests.csproj", "{A5E6193C-8431-4C6E-B674-682CB41EAA0C}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -1356,6 +1359,12 @@ Global
{B16AC373-3DA8-4505-9510-110347CD635D}.Publish|Any CPU.Build.0 = Debug|Any CPU
{B16AC373-3DA8-4505-9510-110347CD635D}.Release|Any CPU.ActiveCfg = Release|Any CPU
{B16AC373-3DA8-4505-9510-110347CD635D}.Release|Any CPU.Build.0 = Release|Any CPU
{A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
{A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Publish|Any CPU.Build.0 = Debug|Any CPU
{A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A5E6193C-8431-4C6E-B674-682CB41EAA0C}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -1541,6 +1550,7 @@ Global
{DAD5FC6A-8CA0-43AC-87E1-032DFBD6B02A} = {3F260A77-B6C9-97FD-1304-4B34DA936CF4}
{8C658E1E-83C8-4127-B8BF-27A638A45DDD} = {6823CD5E-2ABE-41EB-B865-F86EC13F0CF9}
{B16AC373-3DA8-4505-9510-110347CD635D} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
{A5E6193C-8431-4C6E-B674-682CB41EAA0C} = {4F381919-F1BE-47D8-8558-3187ED04A84F}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83}
Expand Down
326 changes: 326 additions & 0 deletions dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;

namespace Microsoft.SemanticKernel.Connectors;

internal abstract class SqlFilterTranslator
{
private readonly IReadOnlyDictionary<string, string> _storagePropertyNames;
private readonly LambdaExpression _lambdaExpression;
private readonly ParameterExpression _recordParameter;
protected readonly StringBuilder _sql;

internal SqlFilterTranslator(
IReadOnlyDictionary<string, string> storagePropertyNames,
LambdaExpression lambdaExpression,
StringBuilder? sql = null)
{
this._storagePropertyNames = storagePropertyNames;
this._lambdaExpression = lambdaExpression;
Debug.Assert(lambdaExpression.Parameters.Count == 1);
this._recordParameter = lambdaExpression.Parameters[0];
this._sql = sql ?? new();
}

internal StringBuilder Clause => this._sql;

internal void Translate(bool appendWhere)
{
if (appendWhere)
{
this._sql.Append("WHERE ");
}

this.Translate(this._lambdaExpression.Body, null);
}

protected void Translate(Expression? node, Expression? parent)
{
switch (node)
{
case BinaryExpression binary:
this.TranslateBinary(binary);
return;

case ConstantExpression constant:
this.TranslateConstant(constant.Value);
return;

case MemberExpression member:
this.TranslateMember(member, parent);
return;

case MethodCallExpression methodCall:
this.TranslateMethodCall(methodCall);
return;

case UnaryExpression unary:
this.TranslateUnary(unary);
return;

default:
throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType);
}
}

protected void TranslateBinary(BinaryExpression binary)
{
// Special handling for null comparisons
switch (binary.NodeType)
{
case ExpressionType.Equal when IsNull(binary.Right):
this._sql.Append('(');
this.Translate(binary.Left, binary);
this._sql.Append(" IS NULL)");
return;
case ExpressionType.NotEqual when IsNull(binary.Right):
this._sql.Append('(');
this.Translate(binary.Left, binary);
this._sql.Append(" IS NOT NULL)");
return;

case ExpressionType.Equal when IsNull(binary.Left):
this._sql.Append('(');
this.Translate(binary.Right, binary);
this._sql.Append(" IS NULL)");
return;
case ExpressionType.NotEqual when IsNull(binary.Left):
this._sql.Append('(');
this.Translate(binary.Right, binary);
this._sql.Append(" IS NOT NULL)");
return;
}

this._sql.Append('(');
this.Translate(binary.Left, binary);

this._sql.Append(binary.NodeType switch
{
ExpressionType.Equal => " = ",
ExpressionType.NotEqual => " <> ",

ExpressionType.GreaterThan => " > ",
ExpressionType.GreaterThanOrEqual => " >= ",
ExpressionType.LessThan => " < ",
ExpressionType.LessThanOrEqual => " <= ",

ExpressionType.AndAlso => " AND ",
ExpressionType.OrElse => " OR ",

_ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType)
});

this.Translate(binary.Right, binary);
this._sql.Append(')');

static bool IsNull(Expression expression)
=> expression is ConstantExpression { Value: null }
|| (TryGetCapturedValue(expression, out _, out var capturedValue) && capturedValue is null);
}

protected virtual void TranslateConstant(object? value)
{
// TODO: Nullable
switch (value)
{
case byte b:
this._sql.Append(b);
return;
case short s:
this._sql.Append(s);
return;
case int i:
this._sql.Append(i);
return;
case long l:
this._sql.Append(l);
return;

case string s:
this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\'');
return;
case bool b:
this._sql.Append(b ? "TRUE" : "FALSE");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the representation of bool varies considerably across databases, this probably shouldn't be here (only in the overrides of this method in connector subtypes).

return;
case Guid g:
this._sql.Append('\'').Append(g.ToString()).Append('\'');
return;

case DateTime dateTime:
case DateTimeOffset dateTimeOffset:
case Array:
throw new NotImplementedException();

case null:
this._sql.Append("NULL");
return;

default:
throw new NotSupportedException("Unsupported constant type: " + value.GetType().Name);
}
}

private void TranslateMember(MemberExpression memberExpression, Expression? parent)
{
switch (memberExpression)
{
case var _ when this.TryGetColumn(memberExpression, out var column):
this.TranslateColumn(column, memberExpression, parent);
return;

case var _ when TryGetCapturedValue(memberExpression, out var name, out var value):
this.TranslateCapturedVariable(name, value);
return;

default:
throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported");
}
}

protected virtual void TranslateColumn(string column, MemberExpression memberExpression, Expression? parent)
=> this._sql.Append('"').Append(column).Append('"');

protected abstract void TranslateCapturedVariable(string name, object? capturedValue);

private void TranslateMethodCall(MethodCallExpression methodCall)
{
switch (methodCall)
{
// Enumerable.Contains()
case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains
when contains.Method.DeclaringType == typeof(Enumerable):
this.TranslateContains(source, item, methodCall);
return;

// List.Contains()
case
{
Method:
{
Name: nameof(Enumerable.Contains),
DeclaringType: { IsGenericType: true } declaringType
},
Object: Expression source,
Arguments: [var item]
} when declaringType.GetGenericTypeDefinition() == typeof(List<>):
this.TranslateContains(source, item, methodCall);
return;

default:
throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}");
}
}

private void TranslateContains(Expression source, Expression item, MethodCallExpression parent)
{
switch (source)
{
// Contains over array column (r => r.Strings.Contains("foo"))
case var _ when this.TryGetColumn(source, out _):
this.TranslateContainsOverArrayColumn(source, item, parent);
return;

// Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String))
case NewArrayExpression newArray:
this.Translate(item, parent);
this._sql.Append(" IN (");

var isFirst = true;
foreach (var element in newArray.Expressions)
{
if (isFirst)
{
isFirst = false;
}
else
{
this._sql.Append(", ");
}

this.Translate(element, parent);
}

this._sql.Append(')');
return;

// Contains over captured array (r => arrayLocalVariable.Contains(r.String))
case var _ when TryGetCapturedValue(source, out _, out var value):
this.TranslateContainsOverCapturedArray(source, item, parent, value);
return;

default:
throw new NotSupportedException("Unsupported Contains expression");
}
}

protected abstract void TranslateContainsOverArrayColumn(Expression source, Expression item, MethodCallExpression parent);

protected abstract void TranslateContainsOverCapturedArray(Expression source, Expression item, MethodCallExpression parent, object? value);

private void TranslateUnary(UnaryExpression unary)
{
switch (unary.NodeType)
{
case ExpressionType.Not:
// Special handling for !(a == b) and !(a != b)
if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary)
{
this.TranslateBinary(
Expression.MakeBinary(
binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal,
binary.Left,
binary.Right));
return;
}

this._sql.Append("(NOT ");
this.Translate(unary.Operand, unary);
this._sql.Append(')');
return;

default:
throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType);
}
}

private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column)
{
if (expression is MemberExpression member && member.Expression == this._recordParameter)
{
if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column))
{
throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name.");
}

return true;
}

column = null;
return false;
}

private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? value)
{
if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo }
&& constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate)
&& Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true))
{
name = fieldInfo.Name;
value = fieldInfo.GetValue(constant.Value);
return true;
}

name = null;
value = null;
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
<Description>Postgres(with pgvector extension) connector for Semantic Kernel plugins and semantic memory</Description>
</PropertyGroup>

<ItemGroup>
<Compile Include="..\Connectors.Memory.Common\SqlFilterTranslator.cs" Link="SqlFilterTranslator.cs" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" />
<PackageReference Include="Npgsql" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ internal interface IPostgresVectorStoreCollectionSqlBuilder
/// <param name="vectorColumnName">The name of the vector column.</param>
/// <param name="indexKind">The kind of index to create.</param>
/// <param name="distanceFunction">The distance function to use for the index.</param>
/// <param name="ifNotExists">Specifies whether to include IF NOT EXISTS in the command.</param>
/// <returns>The built SQL command info.</returns>
PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction);
PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction, bool ifNotExists);

/// <summary>
/// Builds a SQL command to drop a table in the Postgres vector store.
Expand Down
Loading
Loading