From e3a01103cc4f5b79fb5a3afbd6dcdadabeeb7bba Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Mon, 6 May 2024 21:18:21 -0700 Subject: [PATCH 1/7] Added caching example with in-memory store --- .../Caching/SemanticCachingWithFilters.cs | 139 ++++++++++++++++++ .../Filters/Prompt/PromptRenderContext.cs | 6 + .../Functions/FunctionResult.cs | 8 + .../Functions/KernelFunctionFromPrompt.cs | 13 +- .../Functions/PromptRenderingResult.cs | 2 + .../Filters/PromptRenderFilterTests.cs | 28 ++++ 6 files changed, 193 insertions(+), 3 deletions(-) create mode 100644 dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs diff --git a/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs new file mode 100644 index 000000000000..364664efa15d --- /dev/null +++ b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Memory; + +namespace Caching; + +public class SemanticCachingWithFilters(ITestOutputHelper output) : BaseTest(output) +{ + private const double SimilarityScore = 0.9; + + [Fact] + public async Task InMemoryCacheAsync() + { + var kernel = GetKernel(_ => new VolatileMemoryStore()); + + Console.WriteLine("First run:"); + var result1 = await ExecuteAsync(() => kernel.InvokePromptAsync("What's the tallest building in New York?")); + + Console.WriteLine("Second run:"); + var result2 = await ExecuteAsync(() => kernel.InvokePromptAsync("What is the highest building in New York City?")); + + Console.WriteLine($"Result 1: {result1}"); + Console.WriteLine($"Result 2: {result2}"); + } + + [Fact] + public async Task RedisCacheAsync() + { + + } + + [Fact] + public async Task AzureCosmosDBMongoDBCacheAsync() + { + + } + + #region Configuration + + private Kernel GetKernel(Func cacheFactory) + { + var builder = Kernel.CreateBuilder(); + + builder.AddAzureOpenAIChatCompletion( + TestConfiguration.AzureOpenAI.ChatDeploymentName, + TestConfiguration.AzureOpenAI.Endpoint, + TestConfiguration.AzureOpenAI.ApiKey); + + builder.AddAzureOpenAITextEmbeddingGeneration( + TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + TestConfiguration.AzureOpenAIEmbeddings.ApiKey); + + builder.Services.AddSingleton(cacheFactory); + builder.Services.AddSingleton(); + + builder.Services.AddSingleton(); + builder.Services.AddSingleton( + (sp) => new PromptCacheFilter(sp.GetRequiredService(), SimilarityScore)); + + return builder.Build(); + } + + #endregion + + #region Caching filters + + public class CacheBaseFilter + { + protected const string CacheCollectionName = "llm_responses"; + + protected const string IsCachedResultKey = "IsCachedResult"; + } + + public sealed class PromptCacheFilter(ISemanticTextMemory semanticTextMemory, double minRelevanceScore) : CacheBaseFilter, IPromptRenderFilter + { + public async Task OnPromptRenderAsync(PromptRenderContext context, Func next) + { + await next(context); + + var prompt = context.RenderedPrompt!; + + var searchResult = await semanticTextMemory.SearchAsync( + CacheCollectionName, + prompt, + limit: 1, + minRelevanceScore: minRelevanceScore).FirstOrDefaultAsync(); + + if (searchResult is not null) + { + context.Result = new FunctionResult(context.Function, searchResult.Metadata.AdditionalMetadata) + { + Metadata = new Dictionary { [IsCachedResultKey] = true } + }; + } + } + } + + public sealed class FunctionCacheFilter(ISemanticTextMemory semanticTextMemory) : CacheBaseFilter, IFunctionInvocationFilter + { + public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func next) + { + await next(context); + + var result = context.Result; + + if (!string.IsNullOrEmpty(context.Result.RenderedPrompt)) + { + await semanticTextMemory.SaveInformationAsync( + CacheCollectionName, + context.Result.RenderedPrompt, + Guid.NewGuid().ToString(), + additionalMetadata: result.ToString()); + } + } + } + + #endregion + + #region Benchmarking + + private async Task ExecuteAsync(Func> action) + { + var stopwatch = Stopwatch.StartNew(); + + var result = await action(); + + stopwatch.Stop(); + + Console.WriteLine($@"Elapsed Time: {stopwatch.Elapsed:hh\:mm\:ss\.FFF}"); + + return result; + } + + #endregion +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs index 79402ceac836..a1e449642071 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs @@ -62,4 +62,10 @@ public string? RenderedPrompt this._renderedPrompt = value; } } + + /// + /// Gets or sets the result of the function's invocation. + /// Setting to a non-null value will skip function invocation and return the result. + /// + public FunctionResult? Result { get; set; } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs index 0ebba8bca441..62cc5d343d01 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Globalization; namespace Microsoft.SemanticKernel; @@ -41,6 +42,7 @@ public FunctionResult(FunctionResult result, object? value = null) this.Value = value ?? result.Value; this.Culture = result.Culture; this.Metadata = result.Metadata; + this.RenderedPrompt = result.RenderedPrompt; } /// @@ -67,6 +69,12 @@ public FunctionResult(FunctionResult result, object? value = null) /// public Type? ValueType => this.Value?.GetType(); + /// + /// Gets the prompt used during function invocation if any was rendered. + /// + [Experimental("SKEXP0001")] + public string? RenderedPrompt { get; internal set; } + /// /// Returns function result value. /// diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs index ff2b16578038..818f54e7ae2a 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs @@ -115,7 +115,7 @@ public static KernelFunction Create( logger: loggerFactory?.CreateLogger(typeof(KernelFunctionFactory)) ?? NullLogger.Instance); } - /// j + /// protected override async ValueTask InvokeCoreAsync( Kernel kernel, KernelArguments arguments, @@ -132,18 +132,24 @@ protected override async ValueTask InvokeCoreAsync( } #pragma warning restore CS0612 // Events are deprecated + // Return function result if it was set in prompt filter. + if (result.FunctionResult is not null) + { + return new(result.FunctionResult) { RenderedPrompt = result.RenderedPrompt }; + } + if (result.AIService is IChatCompletionService chatCompletion) { var chatContent = await chatCompletion.GetChatMessageContentAsync(result.RenderedPrompt, result.ExecutionSettings, kernel, cancellationToken).ConfigureAwait(false); this.CaptureUsageDetails(chatContent.ModelId, chatContent.Metadata, this._logger); - return new FunctionResult(this, chatContent, kernel.Culture, chatContent.Metadata); + return new FunctionResult(this, chatContent, kernel.Culture, chatContent.Metadata) { RenderedPrompt = result.RenderedPrompt }; } if (result.AIService is ITextGenerationService textGeneration) { var textContent = await textGeneration.GetTextContentWithDefaultParserAsync(result.RenderedPrompt, result.ExecutionSettings, kernel, cancellationToken).ConfigureAwait(false); this.CaptureUsageDetails(textContent.ModelId, textContent.Metadata, this._logger); - return new FunctionResult(this, textContent, kernel.Culture, textContent.Metadata); + return new FunctionResult(this, textContent, kernel.Culture, textContent.Metadata) { RenderedPrompt = result.RenderedPrompt }; } // The service selector didn't find an appropriate service. This should only happen with a poorly implemented selector. @@ -375,6 +381,7 @@ private async Task RenderPromptAsync(Kernel kernel, Kerne { ExecutionSettings = executionSettings, RenderedEventArgs = renderedEventArgs, + FunctionResult = renderingContext.Result }; } diff --git a/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs b/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs index 765585be9960..7aee48fc130b 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs @@ -15,6 +15,8 @@ internal sealed class PromptRenderingResult public PromptExecutionSettings? ExecutionSettings { get; set; } + public FunctionResult? FunctionResult { get; set; } + #pragma warning disable CS0618 // Events are deprecated public PromptRenderedEventArgs? RenderedEventArgs { get; set; } #pragma warning restore CS0618 // Events are deprecated diff --git a/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs b/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs index eff697278997..020008070387 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs @@ -236,4 +236,32 @@ public async Task PostInvocationPromptFilterSkippingWorksCorrectlyAsync() // Assert mockTextGeneration.Verify(m => m.GetTextContentsAsync("", It.IsAny(), It.IsAny(), It.IsAny()), Times.Once()); } + + [Fact] + public async Task PromptFilterCanOverrideFunctionResultAsync() + { + // Arrange + var mockTextGeneration = this.GetMockTextGeneration(); + var function = KernelFunctionFactory.CreateFromPrompt("Prompt"); + + var kernel = this.GetKernelWithFilters(textGenerationService: mockTextGeneration.Object, + onPromptRender: async (context, next) => + { + await next(context); + + context.Result = new FunctionResult(context.Function, "Result from prompt filter"); + }, + onFunctionInvocation: async (context, next) => + { + await next(context); + }); + + // Act + var result = await kernel.InvokeAsync(function); + + // Assert + mockTextGeneration.Verify(m => m.GetTextContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); + + Assert.Equal("Result from prompt filter", result.ToString()); + } } From 7635783f67cef7e8849c5b8b2211fd7324a3446c Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Tue, 7 May 2024 14:16:47 -0700 Subject: [PATCH 2/7] Added caching example for Redis and Cosmos DB MongoDB --- .../Caching/SemanticCachingWithFilters.cs | 52 ++++++++++++++----- dotnet/samples/Concepts/Concepts.csproj | 1 + .../AzureCosmosDBMongoDBMemoryRecord.cs | 9 ++-- .../AzureCosmosDBMongoDBMemoryStore.cs | 10 ++-- .../AzureCosmosDBSimilarityType.cs | 19 +++++-- .../AzureCosmosDBVectorSearchType.cs | 16 ++++-- .../InternalUtilities/TestConfiguration.cs | 7 +++ .../Memory/ISemanticTextMemory.cs | 3 ++ .../Memory/MemoryRecord.cs | 2 +- .../Memory/NullMemory.cs | 2 + .../Memory/SemanticTextMemory.cs | 31 +++++++---- 11 files changed, 111 insertions(+), 41 deletions(-) diff --git a/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs index 364664efa15d..037e042570cc 100644 --- a/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs +++ b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs @@ -3,6 +3,8 @@ using System.Diagnostics; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +using Microsoft.SemanticKernel.Connectors.Redis; using Microsoft.SemanticKernel.Memory; namespace Caching; @@ -14,13 +16,10 @@ public class SemanticCachingWithFilters(ITestOutputHelper output) : BaseTest(out [Fact] public async Task InMemoryCacheAsync() { - var kernel = GetKernel(_ => new VolatileMemoryStore()); + var kernel = GetKernelWithCache(_ => new VolatileMemoryStore()); - Console.WriteLine("First run:"); - var result1 = await ExecuteAsync(() => kernel.InvokePromptAsync("What's the tallest building in New York?")); - - Console.WriteLine("Second run:"); - var result2 = await ExecuteAsync(() => kernel.InvokePromptAsync("What is the highest building in New York City?")); + var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?"); + var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?"); Console.WriteLine($"Result 1: {result1}"); Console.WriteLine($"Result 2: {result2}"); @@ -29,18 +28,38 @@ public async Task InMemoryCacheAsync() [Fact] public async Task RedisCacheAsync() { + var kernel = GetKernelWithCache(_ => new RedisMemoryStore("localhost:6379", vectorSize: 1536)); + + var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?"); + var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?"); + Console.WriteLine($"Result 1: {result1}"); + Console.WriteLine($"Result 2: {result2}"); } [Fact] public async Task AzureCosmosDBMongoDBCacheAsync() { + var kernel = GetKernelWithCache(_ => new AzureCosmosDBMongoDBMemoryStore( + TestConfiguration.AzureCosmosDbMongoDb.ConnectionString, + TestConfiguration.AzureCosmosDbMongoDb.DatabaseName, + new() + { + Kind = AzureCosmosDBVectorSearchType.VectorIVF, + Similarity = AzureCosmosDBSimilarityType.Cosine, + Dimensions = 1536 + })); + var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?"); + var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?"); + + Console.WriteLine($"Result 1: {result1}"); + Console.WriteLine($"Result 2: {result2}"); } #region Configuration - private Kernel GetKernel(Func cacheFactory) + private Kernel GetKernelWithCache(Func cacheFactory) { var builder = Kernel.CreateBuilder(); @@ -72,7 +91,7 @@ public class CacheBaseFilter { protected const string CacheCollectionName = "llm_responses"; - protected const string IsCachedResultKey = "IsCachedResult"; + protected const string CacheRecordIdKey = "CacheRecordId"; } public sealed class PromptCacheFilter(ISemanticTextMemory semanticTextMemory, double minRelevanceScore) : CacheBaseFilter, IPromptRenderFilter @@ -93,7 +112,7 @@ public async Task OnPromptRenderAsync(PromptRenderContext context, Func { [IsCachedResultKey] = true } + Metadata = new Dictionary { [CacheRecordIdKey] = searchResult.Metadata.Id } }; } } @@ -109,24 +128,29 @@ public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, F if (!string.IsNullOrEmpty(context.Result.RenderedPrompt)) { + var recordId = context.Result.Metadata?.GetValueOrDefault(CacheRecordIdKey, Guid.NewGuid().ToString()) as string; + await semanticTextMemory.SaveInformationAsync( CacheCollectionName, context.Result.RenderedPrompt, - Guid.NewGuid().ToString(), - additionalMetadata: result.ToString()); + recordId!, + additionalMetadata: result.ToString(), + timestamp: DateTimeOffset.UtcNow); } } } #endregion - #region Benchmarking + #region Execution - private async Task ExecuteAsync(Func> action) + private async Task ExecuteAsync(Kernel kernel, string title, string prompt) { + Console.WriteLine($"{title}: {prompt}"); + var stopwatch = Stopwatch.StartNew(); - var result = await action(); + var result = await kernel.InvokePromptAsync(prompt); stopwatch.Stop(); diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index 891eea16c400..e4be32a502f8 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -48,6 +48,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs index ae93aeb5193f..7a54a02a8d74 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs @@ -58,6 +58,9 @@ public AzureCosmosDBMongoDBMemoryRecord(MemoryRecord memoryRecord) /// public static MemoryRecord ToMemoryRecord(BsonDocument doc, bool withEmbedding) { + BsonValue? timestamp = doc["timestamp"]; + DateTimeOffset? recordTimestamp = timestamp is BsonNull ? null : timestamp.ToUniversalTime(); + return new( BsonSerializer .Deserialize( @@ -68,10 +71,8 @@ public static MemoryRecord ToMemoryRecord(BsonDocument doc, bool withEmbedding) ? doc["embedding"].AsBsonArray.Select(x => (float)x.AsDouble).ToArray() : null, doc["_id"].AsString, - doc["timestamp"]?.ToUniversalTime() + recordTimestamp ); - - // return result; } /// @@ -83,7 +84,7 @@ public MemoryRecord ToMemoryRecord(bool withEmbedding) this.Metadata.ToMemoryRecordMetadata(), withEmbedding ? this.Embedding : null, this.Id, - this.Timestamp?.ToLocalTime() + this.Timestamp ); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs index b9d0b203e7b1..be8a82165e9e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs @@ -147,6 +147,8 @@ public async Task UpsertAsync( CancellationToken cancellationToken = default ) { + record.Key = record.Metadata.Id; + var replaceOptions = new ReplaceOptions() { IsUpsert = true }; var result = await this.GetCollection(collectionName) @@ -340,9 +342,9 @@ private BsonDocument GetIndexDefinitionVectorIVF(string collectionName) "cosmosSearchOptions", new BsonDocument { - { "kind", this._config.Kind }, + { "kind", this._config.Kind.GetCustomName() }, { "numLists", this._config.NumLists }, - { "similarity", this._config.Similarity }, + { "similarity", this._config.Similarity.GetCustomName() }, { "dimensions", this._config.Dimensions } } } @@ -372,10 +374,10 @@ private BsonDocument GetIndexDefinitionVectorHNSW(string collectionName) "cosmosSearchOptions", new BsonDocument { - { "kind", this._config.Kind }, + { "kind", this._config.Kind.GetCustomName() }, { "m", this._config.NumberOfConnections }, { "efConstruction", this._config.EfConstruction }, - { "similarity", this._config.Similarity }, + { "similarity", this._config.Similarity.GetCustomName() }, { "dimensions", this._config.Dimensions } } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs index cb7b92bdb467..96925d086e3e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Text.Json.Serialization; +using System.Reflection; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; // ReSharper disable InconsistentNaming namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; @@ -13,18 +15,27 @@ public enum AzureCosmosDBSimilarityType /// /// Cosine similarity /// - [JsonPropertyName("COS")] + [BsonElement("COS")] Cosine, /// /// Inner Product similarity /// - [JsonPropertyName("IP")] + [BsonElement("IP")] InnerProduct, /// /// Euclidean similarity /// - [JsonPropertyName("L2")] + [BsonElement("L2")] Euclidean } + +internal static class AzureCosmosDBSimilarityTypeExtensions +{ + public static string GetCustomName(this AzureCosmosDBSimilarityType type) + { + var attribute = type.GetType().GetField(type.ToString()).GetCustomAttribute(); + return attribute?.ElementName ?? type.ToString(); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs index c676e5612fef..bf5597131150 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Text.Json.Serialization; +using System.Reflection; +using MongoDB.Bson.Serialization.Attributes; // ReSharper disable InconsistentNaming namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; @@ -13,12 +14,21 @@ public enum AzureCosmosDBVectorSearchType /// /// vector-ivf is available on all cluster tiers /// - [JsonPropertyName("vector_ivf")] + [BsonElement("vector-ivf")] VectorIVF, /// /// vector-hnsw is available on M40 cluster tiers and higher. /// - [JsonPropertyName("vector_hnsw")] + [BsonElement("vector-hnsw")] VectorHNSW } + +internal static class AzureCosmosDBVectorSearchTypeExtensions +{ + public static string GetCustomName(this AzureCosmosDBVectorSearchType type) + { + var attribute = type.GetType().GetField(type.ToString()).GetCustomAttribute(); + return attribute?.ElementName ?? type.ToString(); + } +} diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs index d7c08c6344cf..508af88ca0d5 100644 --- a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs +++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs @@ -42,6 +42,7 @@ public static void Initialize(IConfigurationRoot configRoot) public static MsGraphConfiguration MSGraph => LoadSection(); public static GoogleAIConfig GoogleAI => LoadSection(); public static VertexAIConfig VertexAI => LoadSection(); + public static AzureCosmosDbMongoDbConfig AzureCosmosDbMongoDb => LoadSection(); private static T LoadSection([CallerMemberName] string? caller = null) { @@ -211,6 +212,12 @@ public class GeminiConfig } } + public class AzureCosmosDbMongoDbConfig + { + public string ConnectionString { get; set; } + public string DatabaseName { get; set; } + } + /// /// Graph API connector configuration model. /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs index d587fc56778b..7a475658e5e0 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Threading; @@ -21,6 +22,7 @@ public interface ISemanticTextMemory /// Unique identifier. /// Optional description. /// Optional string for saving custom metadata. + /// Optional timestamp. /// The containing services, plugins, and other state for use throughout the operation. /// The to monitor for cancellation requests. The default is . /// Unique identifier of the saved memory record. @@ -30,6 +32,7 @@ public Task SaveInformationAsync( string id, string? description = null, string? additionalMetadata = null, + DateTimeOffset? timestamp = null, Kernel? kernel = null, CancellationToken cancellationToken = default); diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs index daf8bf2075a7..690a3d605cf4 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs @@ -87,7 +87,7 @@ public static MemoryRecord ReferenceRecord( /// Source content embedding. /// Optional string for saving custom metadata. /// Optional existing database key. - /// optional timestamp. + /// Optional timestamp. /// Memory record public static MemoryRecord LocalRecord( string id, diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/NullMemory.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/NullMemory.cs index 1bbf72e429a8..9e01f2addc51 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/NullMemory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/NullMemory.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; @@ -28,6 +29,7 @@ public Task SaveInformationAsync( string id, string? description = null, string? additionalMetadata = null, + DateTimeOffset? timestamp = null, Kernel? kernel = null, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs b/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs index a584d9f4cf1d..6210048240c4 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs @@ -41,12 +41,18 @@ public async Task SaveInformationAsync( string id, string? description = null, string? additionalMetadata = null, + DateTimeOffset? timestamp = null, Kernel? kernel = null, CancellationToken cancellationToken = default) { var embedding = await this._embeddingGenerator.GenerateEmbeddingAsync(text, kernel, cancellationToken).ConfigureAwait(false); MemoryRecord data = MemoryRecord.LocalRecord( - id: id, text: text, description: description, additionalMetadata: additionalMetadata, embedding: embedding); + id: id, + text: text, + description: description, + additionalMetadata: additionalMetadata, + embedding: embedding, + timestamp: timestamp); if (!(await this._storage.DoesCollectionExistAsync(collection, cancellationToken).ConfigureAwait(false))) { @@ -116,17 +122,20 @@ public async IAsyncEnumerable SearchAsync( { ReadOnlyMemory queryEmbedding = await this._embeddingGenerator.GenerateEmbeddingAsync(query, kernel, cancellationToken).ConfigureAwait(false); - IAsyncEnumerable<(MemoryRecord, double)> results = this._storage.GetNearestMatchesAsync( - collectionName: collection, - embedding: queryEmbedding, - limit: limit, - minRelevanceScore: minRelevanceScore, - withEmbeddings: withEmbeddings, - cancellationToken: cancellationToken); - - await foreach ((MemoryRecord, double) result in results.WithCancellation(cancellationToken).ConfigureAwait(false)) + if ((await this._storage.DoesCollectionExistAsync(collection, cancellationToken).ConfigureAwait(false))) { - yield return MemoryQueryResult.FromMemoryRecord(result.Item1, result.Item2); + IAsyncEnumerable<(MemoryRecord, double)> results = this._storage.GetNearestMatchesAsync( + collectionName: collection, + embedding: queryEmbedding, + limit: limit, + minRelevanceScore: minRelevanceScore, + withEmbeddings: withEmbeddings, + cancellationToken: cancellationToken); + + await foreach ((MemoryRecord, double) result in results.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return MemoryQueryResult.FromMemoryRecord(result.Item1, result.Item2); + } } } From a0f26ab6b368b53c083bfebc5a6b0d7ea3e63f56 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Tue, 7 May 2024 16:56:12 -0700 Subject: [PATCH 3/7] Small improvements and comments --- .../Caching/SemanticCachingWithFilters.cs | 115 ++++++++++++++++-- 1 file changed, 103 insertions(+), 12 deletions(-) diff --git a/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs index 037e042570cc..57f43a710d32 100644 --- a/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs +++ b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs @@ -9,10 +9,25 @@ namespace Caching; +/// +/// This example shows how to achieve Semantic Caching with Filters. +/// is used to get rendered prompt and check in cache if similar prompt was already answered. +/// If there is a record in cache, then previously cached answer will be returned to the user instead of making a call to LLM. +/// If there is no record in cache, a call to LLM will be performed, and result will be cached together with rendered prompt. +/// is used to update cache with rendered prompt and related LLM result. +/// public class SemanticCachingWithFilters(ITestOutputHelper output) : BaseTest(output) { + /// + /// Similarity/relevance score, from 0 to 1, where 1 means exact match. + /// It's possible to change this value during testing to see how caching logic will behave. + /// private const double SimilarityScore = 0.9; + /// + /// Executing similar requests two times using in-memory caching store to compare execution time and results. + /// Second execution is faster, because the result is returned from cache. + /// [Fact] public async Task InMemoryCacheAsync() { @@ -23,8 +38,23 @@ public async Task InMemoryCacheAsync() Console.WriteLine($"Result 1: {result1}"); Console.WriteLine($"Result 2: {result2}"); + + /* + Output: + First run: What's the tallest building in New York? + Elapsed Time: 00:00:03.828 + Second run: What is the highest building in New York City? + Elapsed Time: 00:00:00.541 + Result 1: The tallest building in New York is One World Trade Center, also known as Freedom Tower.It stands at 1,776 feet(541.3 meters) tall, including its spire. + Result 2: The tallest building in New York is One World Trade Center, also known as Freedom Tower.It stands at 1,776 feet(541.3 meters) tall, including its spire. + */ } + /// + /// Executing similar requests two times using Redis caching store to compare execution time and results. + /// Second execution is faster, because the result is returned from cache. + /// How to run Redis on Docker locally: https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/docker/ + /// [Fact] public async Task RedisCacheAsync() { @@ -35,8 +65,22 @@ public async Task RedisCacheAsync() Console.WriteLine($"Result 1: {result1}"); Console.WriteLine($"Result 2: {result2}"); + + /* + First run: What's the tallest building in New York? + Elapsed Time: 00:00:03.674 + Second run: What is the highest building in New York City? + Elapsed Time: 00:00:00.292 + Result 1: The tallest building in New York is One World Trade Center, also known as Freedom Tower. It stands at 1,776 feet (541 meters) tall, including its spire. + Result 2: The tallest building in New York is One World Trade Center, also known as Freedom Tower. It stands at 1,776 feet (541 meters) tall, including its spire. + */ } + /// + /// Executing similar requests two times using Azure Cosmos DB for MongoDB caching store to compare execution time and results. + /// Second execution is faster, because the result is returned from cache. + /// How to setup Azure Cosmos DB for MongoDB cluster: https://learn.microsoft.com/en-gb/azure/cosmos-db/mongodb/vcore/quickstart-portal + /// [Fact] public async Task AzureCosmosDBMongoDBCacheAsync() { @@ -55,83 +99,127 @@ public async Task AzureCosmosDBMongoDBCacheAsync() Console.WriteLine($"Result 1: {result1}"); Console.WriteLine($"Result 2: {result2}"); + + /* + First run: What's the tallest building in New York? + Elapsed Time: 00:00:05.485 + Second run: What is the highest building in New York City? + Elapsed Time: 00:00:00.389 + Result 1: The tallest building in New York is One World Trade Center, also known as Freedom Tower, which stands at 1,776 feet (541.3 meters) tall. + Result 2: The tallest building in New York is One World Trade Center, also known as Freedom Tower, which stands at 1,776 feet (541.3 meters) tall. + */ } #region Configuration + /// + /// Returns instance with required registered services. + /// private Kernel GetKernelWithCache(Func cacheFactory) { var builder = Kernel.CreateBuilder(); + // Add Azure OpenAI chat completion service builder.AddAzureOpenAIChatCompletion( TestConfiguration.AzureOpenAI.ChatDeploymentName, TestConfiguration.AzureOpenAI.Endpoint, TestConfiguration.AzureOpenAI.ApiKey); + // Add Azure OpenAI text embedding generation service builder.AddAzureOpenAITextEmbeddingGeneration( TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, TestConfiguration.AzureOpenAIEmbeddings.Endpoint, TestConfiguration.AzureOpenAIEmbeddings.ApiKey); + // Add memory store for caching purposes (e.g. in-memory, Redis, Azure Cosmos DB) builder.Services.AddSingleton(cacheFactory); + + // Add text memory service that will be used to generate embeddings and query/store data. builder.Services.AddSingleton(); + // Add prompt render filter to query cache and check if rendered prompt was already answered. + builder.Services.AddSingleton(); + + // Add function invocation filter to cache rendered prompts and LLM results. builder.Services.AddSingleton(); - builder.Services.AddSingleton( - (sp) => new PromptCacheFilter(sp.GetRequiredService(), SimilarityScore)); return builder.Build(); } #endregion - #region Caching filters + #region Cache Filters + /// + /// Base class for filters that contains common constant values. + /// public class CacheBaseFilter { - protected const string CacheCollectionName = "llm_responses"; - - protected const string CacheRecordIdKey = "CacheRecordId"; + /// + /// Collection/table name in cache to use. + /// + protected const string CollectionName = "llm_responses"; + + /// + /// Metadata key in function result for cache record id, which is used to overwrite previously cached response. + /// + protected const string RecordIdKey = "CacheRecordId"; } - public sealed class PromptCacheFilter(ISemanticTextMemory semanticTextMemory, double minRelevanceScore) : CacheBaseFilter, IPromptRenderFilter + /// + /// Filter which is executed during prompt rendering operation. + /// + public sealed class PromptCacheFilter(ISemanticTextMemory semanticTextMemory) : CacheBaseFilter, IPromptRenderFilter { public async Task OnPromptRenderAsync(PromptRenderContext context, Func next) { + // Trigger prompt rendering operation await next(context); + // Get rendered prompt var prompt = context.RenderedPrompt!; + // Search for similar prompts in cache with provided similarity/relevance score var searchResult = await semanticTextMemory.SearchAsync( - CacheCollectionName, + CollectionName, prompt, limit: 1, - minRelevanceScore: minRelevanceScore).FirstOrDefaultAsync(); + minRelevanceScore: SimilarityScore).FirstOrDefaultAsync(); + // If result exists, return it. if (searchResult is not null) { + // Override function result. This will prevent calling LLM and will return result immediately. context.Result = new FunctionResult(context.Function, searchResult.Metadata.AdditionalMetadata) { - Metadata = new Dictionary { [CacheRecordIdKey] = searchResult.Metadata.Id } + Metadata = new Dictionary { [RecordIdKey] = searchResult.Metadata.Id } }; } } } + /// + /// Filter which is executed during function invocation. + /// public sealed class FunctionCacheFilter(ISemanticTextMemory semanticTextMemory) : CacheBaseFilter, IFunctionInvocationFilter { public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func next) { + // Trigger function invocation await next(context); + // Get function invocation result var result = context.Result; + // If there was any rendered prompt, cache it together with LLM result for future calls. if (!string.IsNullOrEmpty(context.Result.RenderedPrompt)) { - var recordId = context.Result.Metadata?.GetValueOrDefault(CacheRecordIdKey, Guid.NewGuid().ToString()) as string; + // Get cache record id if result was cached previously or generate new id. + var recordId = context.Result.Metadata?.GetValueOrDefault(RecordIdKey, Guid.NewGuid().ToString()) as string; + // Cache rendered prompt, LLM result and timestamp. await semanticTextMemory.SaveInformationAsync( - CacheCollectionName, + CollectionName, context.Result.RenderedPrompt, recordId!, additionalMetadata: result.ToString(), @@ -144,6 +232,9 @@ await semanticTextMemory.SaveInformationAsync( #region Execution + /// + /// Helper method to invoke prompt and measure execution time for comparison. + /// private async Task ExecuteAsync(Kernel kernel, string title, string prompt) { Console.WriteLine($"{title}: {prompt}"); From 7012cf9ba2da36e542d75e460c851ceb6c964463 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Tue, 7 May 2024 17:25:31 -0700 Subject: [PATCH 4/7] Small improvement --- .../SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs index 818f54e7ae2a..f0340b710873 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs @@ -135,7 +135,8 @@ protected override async ValueTask InvokeCoreAsync( // Return function result if it was set in prompt filter. if (result.FunctionResult is not null) { - return new(result.FunctionResult) { RenderedPrompt = result.RenderedPrompt }; + result.FunctionResult.RenderedPrompt = result.RenderedPrompt; + return result.FunctionResult; } if (result.AIService is IChatCompletionService chatCompletion) From 8b2591e15712afd2ab864f5cf7bebd0e5dce0cda Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Tue, 7 May 2024 17:31:38 -0700 Subject: [PATCH 5/7] Reverted changes in SemanticTextMemory --- .../samples/Concepts/Caching/SemanticCachingWithFilters.cs | 5 ++--- .../Memory/ISemanticTextMemory.cs | 2 -- dotnet/src/SemanticKernel.Abstractions/Memory/NullMemory.cs | 2 -- dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs | 4 +--- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs index 57f43a710d32..2f3cbb7181b1 100644 --- a/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs +++ b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs @@ -217,13 +217,12 @@ public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, F // Get cache record id if result was cached previously or generate new id. var recordId = context.Result.Metadata?.GetValueOrDefault(RecordIdKey, Guid.NewGuid().ToString()) as string; - // Cache rendered prompt, LLM result and timestamp. + // Cache rendered prompt and LLM result. await semanticTextMemory.SaveInformationAsync( CollectionName, context.Result.RenderedPrompt, recordId!, - additionalMetadata: result.ToString(), - timestamp: DateTimeOffset.UtcNow); + additionalMetadata: result.ToString()); } } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs index 7a475658e5e0..fad2ade9e881 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs @@ -22,7 +22,6 @@ public interface ISemanticTextMemory /// Unique identifier. /// Optional description. /// Optional string for saving custom metadata. - /// Optional timestamp. /// The containing services, plugins, and other state for use throughout the operation. /// The to monitor for cancellation requests. The default is . /// Unique identifier of the saved memory record. @@ -32,7 +31,6 @@ public Task SaveInformationAsync( string id, string? description = null, string? additionalMetadata = null, - DateTimeOffset? timestamp = null, Kernel? kernel = null, CancellationToken cancellationToken = default); diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/NullMemory.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/NullMemory.cs index 9e01f2addc51..1bbf72e429a8 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/NullMemory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/NullMemory.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; @@ -29,7 +28,6 @@ public Task SaveInformationAsync( string id, string? description = null, string? additionalMetadata = null, - DateTimeOffset? timestamp = null, Kernel? kernel = null, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs b/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs index 6210048240c4..09819aea796d 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs @@ -41,7 +41,6 @@ public async Task SaveInformationAsync( string id, string? description = null, string? additionalMetadata = null, - DateTimeOffset? timestamp = null, Kernel? kernel = null, CancellationToken cancellationToken = default) { @@ -51,8 +50,7 @@ public async Task SaveInformationAsync( text: text, description: description, additionalMetadata: additionalMetadata, - embedding: embedding, - timestamp: timestamp); + embedding: embedding); if (!(await this._storage.DoesCollectionExistAsync(collection, cancellationToken).ConfigureAwait(false))) { From 05c859c2e7e94f6fc2de4d94ce4b9e4f11b7d925 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Tue, 7 May 2024 17:34:56 -0700 Subject: [PATCH 6/7] Fixed formatting --- .../SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs index fad2ade9e881..d587fc56778b 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Threading; From b1436ff0f4df8ce66796e2d5f57dcd34b1ac1e76 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Wed, 8 May 2024 06:59:36 -0700 Subject: [PATCH 7/7] Updated README --- dotnet/samples/Concepts/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dotnet/samples/Concepts/README.md b/dotnet/samples/Concepts/README.md index 75b46663a2f6..828abd9ef290 100644 --- a/dotnet/samples/Concepts/README.md +++ b/dotnet/samples/Concepts/README.md @@ -26,6 +26,10 @@ Down below you can find the code snippets that demonstrate the usage of many Sem - [Gemini_FunctionCalling](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/AutoFunctionCalling/Gemini_FunctionCalling.cs) - [OpenAI_FunctionCalling](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/AutoFunctionCalling/OpenAI_FunctionCalling.cs) +## Caching - Examples of caching implementations + +- [SemanticCachingWithFilters](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs) + ## ChatCompletion - Examples using [`ChatCompletion`](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/IChatCompletionService.cs) messaging capable service with models - [AzureOpenAIWithData_ChatCompletion](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/ChatCompletion/AzureOpenAIWithData_ChatCompletion.cs)