diff --git a/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs new file mode 100644 index 000000000000..2f3cbb7181b1 --- /dev/null +++ b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +using Microsoft.SemanticKernel.Connectors.Redis; +using Microsoft.SemanticKernel.Memory; + +namespace Caching; + +/// <summary> +/// This example shows how to achieve Semantic Caching with Filters. +/// <see cref="IPromptRenderFilter"/> is used to get rendered prompt and check in cache if similar prompt was already answered. +/// If there is a record in cache, then previously cached answer will be returned to the user instead of making a call to LLM. +/// If there is no record in cache, a call to LLM will be performed, and result will be cached together with rendered prompt. +/// <see cref="IFunctionInvocationFilter"/> is used to update cache with rendered prompt and related LLM result. +/// </summary> +public class SemanticCachingWithFilters(ITestOutputHelper output) : BaseTest(output) +{ + /// + /// Similarity/relevance score, from 0 to 1, where 1 means exact match. + /// It's possible to change this value during testing to see how caching logic will behave. + /// + private const double SimilarityScore = 0.9; + + /// + /// Executing similar requests two times using in-memory caching store to compare execution time and results. + /// Second execution is faster, because the result is returned from cache. 
+ /// + [Fact] + public async Task InMemoryCacheAsync() + { + var kernel = GetKernelWithCache(_ => new VolatileMemoryStore()); + + var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?"); + var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?"); + + Console.WriteLine($"Result 1: {result1}"); + Console.WriteLine($"Result 2: {result2}"); + + /* + Output: + First run: What's the tallest building in New York? + Elapsed Time: 00:00:03.828 + Second run: What is the highest building in New York City? + Elapsed Time: 00:00:00.541 + Result 1: The tallest building in New York is One World Trade Center, also known as Freedom Tower.It stands at 1,776 feet(541.3 meters) tall, including its spire. + Result 2: The tallest building in New York is One World Trade Center, also known as Freedom Tower.It stands at 1,776 feet(541.3 meters) tall, including its spire. + */ + } + + /// + /// Executing similar requests two times using Redis caching store to compare execution time and results. + /// Second execution is faster, because the result is returned from cache. + /// How to run Redis on Docker locally: https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/docker/ + /// + [Fact] + public async Task RedisCacheAsync() + { + var kernel = GetKernelWithCache(_ => new RedisMemoryStore("localhost:6379", vectorSize: 1536)); + + var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?"); + var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?"); + + Console.WriteLine($"Result 1: {result1}"); + Console.WriteLine($"Result 2: {result2}"); + + /* + First run: What's the tallest building in New York? + Elapsed Time: 00:00:03.674 + Second run: What is the highest building in New York City? 
+ Elapsed Time: 00:00:00.292 + Result 1: The tallest building in New York is One World Trade Center, also known as Freedom Tower. It stands at 1,776 feet (541 meters) tall, including its spire. + Result 2: The tallest building in New York is One World Trade Center, also known as Freedom Tower. It stands at 1,776 feet (541 meters) tall, including its spire. + */ + } + + /// + /// Executing similar requests two times using Azure Cosmos DB for MongoDB caching store to compare execution time and results. + /// Second execution is faster, because the result is returned from cache. + /// How to setup Azure Cosmos DB for MongoDB cluster: https://learn.microsoft.com/en-gb/azure/cosmos-db/mongodb/vcore/quickstart-portal + /// + [Fact] + public async Task AzureCosmosDBMongoDBCacheAsync() + { + var kernel = GetKernelWithCache(_ => new AzureCosmosDBMongoDBMemoryStore( + TestConfiguration.AzureCosmosDbMongoDb.ConnectionString, + TestConfiguration.AzureCosmosDbMongoDb.DatabaseName, + new() + { + Kind = AzureCosmosDBVectorSearchType.VectorIVF, + Similarity = AzureCosmosDBSimilarityType.Cosine, + Dimensions = 1536 + })); + + var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?"); + var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?"); + + Console.WriteLine($"Result 1: {result1}"); + Console.WriteLine($"Result 2: {result2}"); + + /* + First run: What's the tallest building in New York? + Elapsed Time: 00:00:05.485 + Second run: What is the highest building in New York City? + Elapsed Time: 00:00:00.389 + Result 1: The tallest building in New York is One World Trade Center, also known as Freedom Tower, which stands at 1,776 feet (541.3 meters) tall. + Result 2: The tallest building in New York is One World Trade Center, also known as Freedom Tower, which stands at 1,776 feet (541.3 meters) tall. 
+ */ + } + + #region Configuration + + /// <summary> + /// Returns <see cref="Kernel"/> instance with required registered services; <paramref name="cacheFactory"/> creates the memory store that is used as cache. + /// </summary> + private Kernel GetKernelWithCache(Func<IServiceProvider, IMemoryStore> cacheFactory) + { + var builder = Kernel.CreateBuilder(); + + // Add Azure OpenAI chat completion service + builder.AddAzureOpenAIChatCompletion( + TestConfiguration.AzureOpenAI.ChatDeploymentName, + TestConfiguration.AzureOpenAI.Endpoint, + TestConfiguration.AzureOpenAI.ApiKey); + + // Add Azure OpenAI text embedding generation service + builder.AddAzureOpenAITextEmbeddingGeneration( + TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + TestConfiguration.AzureOpenAIEmbeddings.ApiKey); + + // Add memory store for caching purposes (e.g. in-memory, Redis, Azure Cosmos DB) + builder.Services.AddSingleton(cacheFactory); + + // Add text memory service that will be used to generate embeddings and query/store data. + builder.Services.AddSingleton<ISemanticTextMemory, SemanticTextMemory>(); + + // Add prompt render filter to query cache and check if rendered prompt was already answered. + builder.Services.AddSingleton<IPromptRenderFilter, PromptCacheFilter>(); + + // Add function invocation filter to cache rendered prompts and LLM results. + builder.Services.AddSingleton<IFunctionInvocationFilter, FunctionCacheFilter>(); + + return builder.Build(); + } + + #endregion + + #region Cache Filters + + /// + /// Base class for filters that contains common constant values. + /// + public class CacheBaseFilter + { + /// + /// Collection/table name in cache to use. + /// + protected const string CollectionName = "llm_responses"; + + /// + /// Metadata key in function result for cache record id, which is used to overwrite previously cached response. + /// + protected const string RecordIdKey = "CacheRecordId"; + } + + /// + /// Filter which is executed during prompt rendering operation. 
+ /// + public sealed class PromptCacheFilter(ISemanticTextMemory semanticTextMemory) : CacheBaseFilter, IPromptRenderFilter + { + public async Task OnPromptRenderAsync(PromptRenderContext context, Func<PromptRenderContext, Task> next) + { + // Trigger prompt rendering operation + await next(context); + + // Get rendered prompt + var prompt = context.RenderedPrompt!; + + // Search for similar prompts in cache with provided similarity/relevance score + var searchResult = await semanticTextMemory.SearchAsync( + CollectionName, + prompt, + limit: 1, + minRelevanceScore: SimilarityScore).FirstOrDefaultAsync(); + + // If result exists, return it. + if (searchResult is not null) + { + // Override function result. This will prevent calling LLM and will return result immediately. + context.Result = new FunctionResult(context.Function, searchResult.Metadata.AdditionalMetadata) + { + Metadata = new Dictionary<string, object?> { [RecordIdKey] = searchResult.Metadata.Id } + }; + } + } + } + + /// <summary> + /// Filter which is executed during function invocation. + /// </summary> + public sealed class FunctionCacheFilter(ISemanticTextMemory semanticTextMemory) : CacheBaseFilter, IFunctionInvocationFilter + { + public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next) + { + // Trigger function invocation + await next(context); + + // Get function invocation result + var result = context.Result; + + // If there was any rendered prompt, cache it together with LLM result for future calls. + if (!string.IsNullOrEmpty(context.Result.RenderedPrompt)) + { + // Reuse the cache record id when the result came from cache; otherwise (no metadata, no such key, or a non-string value) generate a new id so it is never null. + var recordId = (context.Result.Metadata?.GetValueOrDefault(RecordIdKey) as string) ?? Guid.NewGuid().ToString(); + + // Cache rendered prompt and LLM result. 
+ await semanticTextMemory.SaveInformationAsync( + CollectionName, + context.Result.RenderedPrompt, + recordId!, + additionalMetadata: result.ToString()); + } + } + } + + #endregion + + #region Execution + + /// <summary> + /// Helper method to invoke prompt and measure execution time for comparison; elapsed time is printed with fixed three-digit milliseconds. + /// </summary> + private async Task<FunctionResult> ExecuteAsync(Kernel kernel, string title, string prompt) + { + Console.WriteLine($"{title}: {prompt}"); + + var stopwatch = Stopwatch.StartNew(); + + var result = await kernel.InvokePromptAsync(prompt); + + stopwatch.Stop(); + + Console.WriteLine($@"Elapsed Time: {stopwatch.Elapsed:hh\:mm\:ss\.fff}"); + + return result; + } + + #endregion +} diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index 891eea16c400..e4be32a502f8 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -48,6 +48,7 @@ + diff --git a/dotnet/samples/Concepts/README.md b/dotnet/samples/Concepts/README.md index 75b46663a2f6..828abd9ef290 100644 --- a/dotnet/samples/Concepts/README.md +++ b/dotnet/samples/Concepts/README.md @@ -26,6 +26,10 @@ Down below you can find the code snippets that demonstrate the usage of many Sem - [Gemini_FunctionCalling](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/AutoFunctionCalling/Gemini_FunctionCalling.cs) - [OpenAI_FunctionCalling](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/AutoFunctionCalling/OpenAI_FunctionCalling.cs) +## Caching - Examples of caching implementations + +- [SemanticCachingWithFilters](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs) + ## ChatCompletion - Examples using [`ChatCompletion`](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/IChatCompletionService.cs) messaging capable service with models - 
[AzureOpenAIWithData_ChatCompletion](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/ChatCompletion/AzureOpenAIWithData_ChatCompletion.cs) diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs index ae93aeb5193f..7a54a02a8d74 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs @@ -58,6 +58,9 @@ public AzureCosmosDBMongoDBMemoryRecord(MemoryRecord memoryRecord) /// public static MemoryRecord ToMemoryRecord(BsonDocument doc, bool withEmbedding) { + BsonValue? timestamp = doc["timestamp"]; + DateTimeOffset? recordTimestamp = timestamp is BsonNull ? null : timestamp.ToUniversalTime(); + return new( BsonSerializer .Deserialize( @@ -68,10 +71,8 @@ public static MemoryRecord ToMemoryRecord(BsonDocument doc, bool withEmbedding) ? doc["embedding"].AsBsonArray.Select(x => (float)x.AsDouble).ToArray() : null, doc["_id"].AsString, - doc["timestamp"]?.ToUniversalTime() + recordTimestamp ); - - // return result; } /// @@ -83,7 +84,7 @@ public MemoryRecord ToMemoryRecord(bool withEmbedding) this.Metadata.ToMemoryRecordMetadata(), withEmbedding ? 
this.Embedding : null, this.Id, - this.Timestamp?.ToLocalTime() + this.Timestamp ); } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs index b9d0b203e7b1..be8a82165e9e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs @@ -147,6 +147,8 @@ public async Task UpsertAsync( CancellationToken cancellationToken = default ) { + record.Key = record.Metadata.Id; + var replaceOptions = new ReplaceOptions() { IsUpsert = true }; var result = await this.GetCollection(collectionName) @@ -340,9 +342,9 @@ private BsonDocument GetIndexDefinitionVectorIVF(string collectionName) "cosmosSearchOptions", new BsonDocument { - { "kind", this._config.Kind }, + { "kind", this._config.Kind.GetCustomName() }, { "numLists", this._config.NumLists }, - { "similarity", this._config.Similarity }, + { "similarity", this._config.Similarity.GetCustomName() }, { "dimensions", this._config.Dimensions } } } @@ -372,10 +374,10 @@ private BsonDocument GetIndexDefinitionVectorHNSW(string collectionName) "cosmosSearchOptions", new BsonDocument { - { "kind", this._config.Kind }, + { "kind", this._config.Kind.GetCustomName() }, { "m", this._config.NumberOfConnections }, { "efConstruction", this._config.EfConstruction }, - { "similarity", this._config.Similarity }, + { "similarity", this._config.Similarity.GetCustomName() }, { "dimensions", this._config.Dimensions } } } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs index cb7b92bdb467..96925d086e3e 100644 --- 
a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Text.Json.Serialization; +using System.Reflection; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; // ReSharper disable InconsistentNaming namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; @@ -13,18 +15,34 @@ public enum AzureCosmosDBSimilarityType /// /// Cosine similarity /// - [JsonPropertyName("COS")] + [BsonElement("COS")] Cosine, /// /// Inner Product similarity /// - [JsonPropertyName("IP")] + [BsonElement("IP")] InnerProduct, /// /// Euclidean similarity /// - [JsonPropertyName("L2")] + [BsonElement("L2")] Euclidean } + +/// <summary> +/// Extension methods for <see cref="AzureCosmosDBSimilarityType"/>. +/// </summary> +internal static class AzureCosmosDBSimilarityTypeExtensions +{ + /// <summary> + /// Returns the <see cref="BsonElementAttribute.ElementName"/> declared on the enum member, or the member name when no attribute is found. + /// </summary> + public static string GetCustomName(this AzureCosmosDBSimilarityType type) + { + // GetField returns null for a value that is not a declared member (e.g. a cast integer), so guard before reading the attribute. + var attribute = typeof(AzureCosmosDBSimilarityType).GetField(type.ToString())?.GetCustomAttribute<BsonElementAttribute>(); + return attribute?.ElementName ?? type.ToString(); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs index c676e5612fef..bf5597131150 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. 
-using System.Text.Json.Serialization; +using System.Reflection; +using MongoDB.Bson.Serialization.Attributes; // ReSharper disable InconsistentNaming namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; @@ -13,12 +14,28 @@ public enum AzureCosmosDBVectorSearchType /// /// vector-ivf is available on all cluster tiers /// - [JsonPropertyName("vector_ivf")] + [BsonElement("vector-ivf")] VectorIVF, /// /// vector-hnsw is available on M40 cluster tiers and higher. /// - [JsonPropertyName("vector_hnsw")] + [BsonElement("vector-hnsw")] VectorHNSW } + +/// <summary> +/// Extension methods for <see cref="AzureCosmosDBVectorSearchType"/>. +/// </summary> +internal static class AzureCosmosDBVectorSearchTypeExtensions +{ + /// <summary> + /// Returns the <see cref="BsonElementAttribute.ElementName"/> declared on the enum member, or the member name when no attribute is found. + /// </summary> + public static string GetCustomName(this AzureCosmosDBVectorSearchType type) + { + // GetField returns null for a value that is not a declared member (e.g. a cast integer), so guard before reading the attribute. + var attribute = typeof(AzureCosmosDBVectorSearchType).GetField(type.ToString())?.GetCustomAttribute<BsonElementAttribute>(); + return attribute?.ElementName ?? type.ToString(); + } +} diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs index d7c08c6344cf..508af88ca0d5 100644 --- a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs +++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs @@ -42,6 +42,7 @@ public static void Initialize(IConfigurationRoot configRoot) public static MsGraphConfiguration MSGraph => LoadSection(); public static GoogleAIConfig GoogleAI => LoadSection(); public static VertexAIConfig VertexAI => LoadSection(); + public static AzureCosmosDbMongoDbConfig AzureCosmosDbMongoDb => LoadSection(); private static T LoadSection([CallerMemberName] string? caller = null) { @@ -211,6 +212,12 @@ public class GeminiConfig } } + public class AzureCosmosDbMongoDbConfig + { + public string ConnectionString { get; set; } + public string DatabaseName { get; set; } + } + /// /// Graph API connector configuration model. 
/// diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs index 79402ceac836..a1e449642071 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs @@ -62,4 +62,10 @@ public string? RenderedPrompt this._renderedPrompt = value; } } + + /// + /// Gets or sets the result of the function's invocation. + /// Setting to a non-null value will skip function invocation and return the result. + /// + public FunctionResult? Result { get; set; } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs index 0ebba8bca441..62cc5d343d01 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Globalization; namespace Microsoft.SemanticKernel; @@ -41,6 +42,7 @@ public FunctionResult(FunctionResult result, object? value = null) this.Value = value ?? result.Value; this.Culture = result.Culture; this.Metadata = result.Metadata; + this.RenderedPrompt = result.RenderedPrompt; } /// @@ -67,6 +69,12 @@ public FunctionResult(FunctionResult result, object? value = null) /// public Type? ValueType => this.Value?.GetType(); + /// + /// Gets the prompt used during function invocation if any was rendered. + /// + [Experimental("SKEXP0001")] + public string? RenderedPrompt { get; internal set; } + /// /// Returns function result value. 
/// diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs index daf8bf2075a7..690a3d605cf4 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs @@ -87,7 +87,7 @@ public static MemoryRecord ReferenceRecord( /// Source content embedding. /// Optional string for saving custom metadata. /// Optional existing database key. - /// optional timestamp. + /// Optional timestamp. /// Memory record public static MemoryRecord LocalRecord( string id, diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs index ff2b16578038..f0340b710873 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs @@ -115,7 +115,7 @@ public static KernelFunction Create( logger: loggerFactory?.CreateLogger(typeof(KernelFunctionFactory)) ?? NullLogger.Instance); } - /// j + /// protected override async ValueTask InvokeCoreAsync( Kernel kernel, KernelArguments arguments, @@ -132,18 +132,25 @@ protected override async ValueTask InvokeCoreAsync( } #pragma warning restore CS0612 // Events are deprecated + // Return function result if it was set in prompt filter. 
+ if (result.FunctionResult is not null) + { + result.FunctionResult.RenderedPrompt = result.RenderedPrompt; + return result.FunctionResult; + } + if (result.AIService is IChatCompletionService chatCompletion) { var chatContent = await chatCompletion.GetChatMessageContentAsync(result.RenderedPrompt, result.ExecutionSettings, kernel, cancellationToken).ConfigureAwait(false); this.CaptureUsageDetails(chatContent.ModelId, chatContent.Metadata, this._logger); - return new FunctionResult(this, chatContent, kernel.Culture, chatContent.Metadata); + return new FunctionResult(this, chatContent, kernel.Culture, chatContent.Metadata) { RenderedPrompt = result.RenderedPrompt }; } if (result.AIService is ITextGenerationService textGeneration) { var textContent = await textGeneration.GetTextContentWithDefaultParserAsync(result.RenderedPrompt, result.ExecutionSettings, kernel, cancellationToken).ConfigureAwait(false); this.CaptureUsageDetails(textContent.ModelId, textContent.Metadata, this._logger); - return new FunctionResult(this, textContent, kernel.Culture, textContent.Metadata); + return new FunctionResult(this, textContent, kernel.Culture, textContent.Metadata) { RenderedPrompt = result.RenderedPrompt }; } // The service selector didn't find an appropriate service. This should only happen with a poorly implemented selector. @@ -375,6 +382,7 @@ private async Task RenderPromptAsync(Kernel kernel, Kerne { ExecutionSettings = executionSettings, RenderedEventArgs = renderedEventArgs, + FunctionResult = renderingContext.Result }; } diff --git a/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs b/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs index 765585be9960..7aee48fc130b 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs @@ -15,6 +15,8 @@ internal sealed class PromptRenderingResult public PromptExecutionSettings? 
ExecutionSettings { get; set; } + public FunctionResult? FunctionResult { get; set; } + #pragma warning disable CS0618 // Events are deprecated public PromptRenderedEventArgs? RenderedEventArgs { get; set; } #pragma warning restore CS0618 // Events are deprecated diff --git a/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs b/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs index a584d9f4cf1d..09819aea796d 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs @@ -46,7 +46,11 @@ public async Task SaveInformationAsync( { var embedding = await this._embeddingGenerator.GenerateEmbeddingAsync(text, kernel, cancellationToken).ConfigureAwait(false); MemoryRecord data = MemoryRecord.LocalRecord( - id: id, text: text, description: description, additionalMetadata: additionalMetadata, embedding: embedding); + id: id, + text: text, + description: description, + additionalMetadata: additionalMetadata, + embedding: embedding); if (!(await this._storage.DoesCollectionExistAsync(collection, cancellationToken).ConfigureAwait(false))) { @@ -116,17 +120,20 @@ public async IAsyncEnumerable SearchAsync( { ReadOnlyMemory queryEmbedding = await this._embeddingGenerator.GenerateEmbeddingAsync(query, kernel, cancellationToken).ConfigureAwait(false); - IAsyncEnumerable<(MemoryRecord, double)> results = this._storage.GetNearestMatchesAsync( - collectionName: collection, - embedding: queryEmbedding, - limit: limit, - minRelevanceScore: minRelevanceScore, - withEmbeddings: withEmbeddings, - cancellationToken: cancellationToken); - - await foreach ((MemoryRecord, double) result in results.WithCancellation(cancellationToken).ConfigureAwait(false)) + if ((await this._storage.DoesCollectionExistAsync(collection, cancellationToken).ConfigureAwait(false))) { - yield return MemoryQueryResult.FromMemoryRecord(result.Item1, result.Item2); + IAsyncEnumerable<(MemoryRecord, double)> results = 
this._storage.GetNearestMatchesAsync( + collectionName: collection, + embedding: queryEmbedding, + limit: limit, + minRelevanceScore: minRelevanceScore, + withEmbeddings: withEmbeddings, + cancellationToken: cancellationToken); + + await foreach ((MemoryRecord, double) result in results.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return MemoryQueryResult.FromMemoryRecord(result.Item1, result.Item2); + } } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs b/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs index eff697278997..020008070387 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs @@ -236,4 +236,32 @@ public async Task PostInvocationPromptFilterSkippingWorksCorrectlyAsync() // Assert mockTextGeneration.Verify(m => m.GetTextContentsAsync("", It.IsAny(), It.IsAny(), It.IsAny()), Times.Once()); } + + [Fact] + public async Task PromptFilterCanOverrideFunctionResultAsync() + { + // Arrange + var mockTextGeneration = this.GetMockTextGeneration(); + var function = KernelFunctionFactory.CreateFromPrompt("Prompt"); + + var kernel = this.GetKernelWithFilters(textGenerationService: mockTextGeneration.Object, + onPromptRender: async (context, next) => + { + await next(context); + + context.Result = new FunctionResult(context.Function, "Result from prompt filter"); + }, + onFunctionInvocation: async (context, next) => + { + await next(context); + }); + + // Act + var result = await kernel.InvokeAsync(function); + + // Assert + mockTextGeneration.Verify(m => m.GetTextContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); + + Assert.Equal("Result from prompt filter", result.ToString()); + } }