diff --git a/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs
new file mode 100644
index 000000000000..2f3cbb7181b1
--- /dev/null
+++ b/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs
@@ -0,0 +1,253 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Diagnostics;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB;
+using Microsoft.SemanticKernel.Connectors.Redis;
+using Microsoft.SemanticKernel.Memory;
+
+namespace Caching;
+
+/// <summary>
+/// This example shows how to achieve Semantic Caching with Filters.
+/// <see cref="IPromptRenderFilter"/> is used to get the rendered prompt and check the cache for a similar prompt that was already answered.
+/// If there is a record in the cache, the previously cached answer is returned to the user instead of making a call to the LLM.
+/// If there is no record in the cache, a call to the LLM is performed, and the result is cached together with the rendered prompt.
+/// <see cref="IFunctionInvocationFilter"/> is used to update the cache with the rendered prompt and the related LLM result.
+/// </summary>
+public class SemanticCachingWithFilters(ITestOutputHelper output) : BaseTest(output)
+{
+ /// <summary>
+ /// Minimum similarity/relevance score (passed as <c>minRelevanceScore</c> to the cache search), from 0 to 1, where 1 means exact match.
+ /// It's possible to change this value during testing to see how caching logic will behave.
+ /// </summary>
+ private const double SimilarityScore = 0.9;
+
+    /// <summary>
+    /// Sends two semantically similar questions through a kernel that caches answers in the in-memory store.
+    /// The second question is answered from the cache, so it completes faster than the first one.
+    /// </summary>
+    [Fact]
+    public async Task InMemoryCacheAsync()
+    {
+        Kernel cachingKernel = GetKernelWithCache(_ => new VolatileMemoryStore());
+
+        var firstAnswer = await ExecuteAsync(cachingKernel, title: "First run", prompt: "What's the tallest building in New York?");
+        var secondAnswer = await ExecuteAsync(cachingKernel, title: "Second run", prompt: "What is the highest building in New York City?");
+
+        Console.WriteLine($"Result 1: {firstAnswer}");
+        Console.WriteLine($"Result 2: {secondAnswer}");
+
+        /*
+        Output:
+        First run: What's the tallest building in New York?
+        Elapsed Time: 00:00:03.828
+        Second run: What is the highest building in New York City?
+        Elapsed Time: 00:00:00.541
+        Result 1: The tallest building in New York is One World Trade Center, also known as Freedom Tower.It stands at 1,776 feet(541.3 meters) tall, including its spire.
+        Result 2: The tallest building in New York is One World Trade Center, also known as Freedom Tower.It stands at 1,776 feet(541.3 meters) tall, including its spire.
+        */
+    }
+
+    /// <summary>
+    /// Sends two semantically similar questions through a kernel that caches answers in a Redis store.
+    /// The second question is answered from the cache, so it completes faster than the first one.
+    /// How to run Redis on Docker locally: https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/docker/
+    /// </summary>
+    [Fact]
+    public async Task RedisCacheAsync()
+    {
+        Kernel cachingKernel = GetKernelWithCache(_ => new RedisMemoryStore("localhost:6379", vectorSize: 1536));
+
+        var firstAnswer = await ExecuteAsync(cachingKernel, title: "First run", prompt: "What's the tallest building in New York?");
+        var secondAnswer = await ExecuteAsync(cachingKernel, title: "Second run", prompt: "What is the highest building in New York City?");
+
+        Console.WriteLine($"Result 1: {firstAnswer}");
+        Console.WriteLine($"Result 2: {secondAnswer}");
+
+        /*
+        First run: What's the tallest building in New York?
+        Elapsed Time: 00:00:03.674
+        Second run: What is the highest building in New York City?
+        Elapsed Time: 00:00:00.292
+        Result 1: The tallest building in New York is One World Trade Center, also known as Freedom Tower. It stands at 1,776 feet (541 meters) tall, including its spire.
+        Result 2: The tallest building in New York is One World Trade Center, also known as Freedom Tower. It stands at 1,776 feet (541 meters) tall, including its spire.
+        */
+    }
+
+    /// <summary>
+    /// Sends two semantically similar questions through a kernel that caches answers in Azure Cosmos DB for MongoDB.
+    /// The second question is answered from the cache, so it completes faster than the first one.
+    /// How to setup Azure Cosmos DB for MongoDB cluster: https://learn.microsoft.com/en-gb/azure/cosmos-db/mongodb/vcore/quickstart-portal
+    /// </summary>
+    [Fact]
+    public async Task AzureCosmosDBMongoDBCacheAsync()
+    {
+        Kernel cachingKernel = GetKernelWithCache(_ => new AzureCosmosDBMongoDBMemoryStore(
+            TestConfiguration.AzureCosmosDbMongoDb.ConnectionString,
+            TestConfiguration.AzureCosmosDbMongoDb.DatabaseName,
+            new()
+            {
+                Kind = AzureCosmosDBVectorSearchType.VectorIVF,
+                Similarity = AzureCosmosDBSimilarityType.Cosine,
+                Dimensions = 1536
+            }));
+
+        var firstAnswer = await ExecuteAsync(cachingKernel, title: "First run", prompt: "What's the tallest building in New York?");
+        var secondAnswer = await ExecuteAsync(cachingKernel, title: "Second run", prompt: "What is the highest building in New York City?");
+
+        Console.WriteLine($"Result 1: {firstAnswer}");
+        Console.WriteLine($"Result 2: {secondAnswer}");
+
+        /*
+        First run: What's the tallest building in New York?
+        Elapsed Time: 00:00:05.485
+        Second run: What is the highest building in New York City?
+        Elapsed Time: 00:00:00.389
+        Result 1: The tallest building in New York is One World Trade Center, also known as Freedom Tower, which stands at 1,776 feet (541.3 meters) tall.
+        Result 2: The tallest building in New York is One World Trade Center, also known as Freedom Tower, which stands at 1,776 feet (541.3 meters) tall.
+        */
+    }
+
+ #region Configuration
+
+    /// <summary>Returns <see cref="Kernel"/> instance with chat completion, embedding generation, cache store, text memory and both cache filters registered.</summary>
+    /// <param name="cacheFactory">Factory that creates the memory store used as the cache (e.g. in-memory, Redis, Azure Cosmos DB).</param>
+    /// <returns>The built kernel.</returns>
+    private Kernel GetKernelWithCache(Func<IServiceProvider, IMemoryStore> cacheFactory)
+    {
+        var builder = Kernel.CreateBuilder();
+
+        // Add Azure OpenAI chat completion service
+        builder.AddAzureOpenAIChatCompletion(
+            TestConfiguration.AzureOpenAI.ChatDeploymentName,
+            TestConfiguration.AzureOpenAI.Endpoint,
+            TestConfiguration.AzureOpenAI.ApiKey);
+
+        // Add Azure OpenAI text embedding generation service
+        builder.AddAzureOpenAITextEmbeddingGeneration(
+            TestConfiguration.AzureOpenAIEmbeddings.DeploymentName,
+            TestConfiguration.AzureOpenAIEmbeddings.Endpoint,
+            TestConfiguration.AzureOpenAIEmbeddings.ApiKey);
+
+        // Add memory store for caching purposes (e.g. in-memory, Redis, Azure Cosmos DB)
+        builder.Services.AddSingleton<IMemoryStore>(cacheFactory);
+
+        // Add text memory service that will be used to generate embeddings and query/store data.
+        builder.Services.AddSingleton<ISemanticTextMemory, SemanticTextMemory>();
+
+        // Add prompt render filter to query cache and check if rendered prompt was already answered.
+        builder.Services.AddSingleton<IPromptRenderFilter, PromptCacheFilter>();
+
+        // Add function invocation filter to cache rendered prompts and LLM results.
+        builder.Services.AddSingleton<IFunctionInvocationFilter, FunctionCacheFilter>();
+
+        return builder.Build();
+    }
+
+ #endregion
+
+ #region Cache Filters
+
+ /// <summary>
+ /// Base class for filters that contains common constant values.
+ /// </summary>
+ public class CacheBaseFilter
+ {
+ /// <summary>
+ /// Collection/table name in cache to use.
+ /// </summary>
+ protected const string CollectionName = "llm_responses";
+
+ /// <summary>
+ /// Metadata key in function result for cache record id, which is used to overwrite previously cached response.
+ /// </summary>
+ protected const string RecordIdKey = "CacheRecordId";
+ }
+
+    /// <summary>
+    /// Filter which is executed during prompt rendering operation. It looks the rendered prompt up in the cache and, on a hit, sets the result so that the LLM is not called.
+    /// </summary>
+    public sealed class PromptCacheFilter(ISemanticTextMemory semanticTextMemory) : CacheBaseFilter, IPromptRenderFilter
+    {
+        public async Task OnPromptRenderAsync(PromptRenderContext context, Func<PromptRenderContext, Task> next)
+        {
+            // Trigger prompt rendering operation
+            await next(context);
+
+            // Get rendered prompt; without any rendered text there is nothing to look up, so the invocation proceeds as usual.
+            if (context.RenderedPrompt is not { Length: > 0 } prompt) { return; }
+
+            // Search for similar prompts in cache with provided similarity/relevance score
+            var searchResult = await semanticTextMemory.SearchAsync(
+                CollectionName,
+                prompt,
+                limit: 1,
+                minRelevanceScore: SimilarityScore).FirstOrDefaultAsync();
+
+            // If result exists, return it.
+            if (searchResult is not null)
+            {
+                // Override function result. This will prevent calling LLM and will return result immediately.
+                context.Result = new FunctionResult(context.Function, searchResult.Metadata.AdditionalMetadata)
+                {
+                    Metadata = new Dictionary<string, object?> { [RecordIdKey] = searchResult.Metadata.Id }
+                };
+            }
+        }
+    }
+
+    /// <summary>
+    /// Filter which is executed during function invocation. After the function ran, it stores the rendered prompt together with the result in the cache.
+    /// </summary>
+    public sealed class FunctionCacheFilter(ISemanticTextMemory semanticTextMemory) : CacheBaseFilter, IFunctionInvocationFilter
+    {
+        public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
+        {
+            // Trigger function invocation
+            await next(context);
+
+            // Get function invocation result
+            var result = context.Result;
+
+            // If there was any rendered prompt, cache it together with LLM result for future calls.
+            if (!string.IsNullOrEmpty(result.RenderedPrompt))
+            {
+                // Reuse the cache record id when the result came from the cache; otherwise (no metadata, no id, or a non-string id) generate a new one.
+                var recordId = result.Metadata?.GetValueOrDefault(RecordIdKey) as string ?? Guid.NewGuid().ToString();
+
+                // Cache rendered prompt and LLM result.
+                await semanticTextMemory.SaveInformationAsync(
+                    CollectionName,
+                    result.RenderedPrompt,
+                    recordId,
+                    additionalMetadata: result.ToString());
+            }
+        }
+    }
+
+ #endregion
+
+ #region Execution
+
+    /// <summary>
+    /// Invokes <paramref name="prompt"/> on <paramref name="kernel"/>, printing <paramref name="title"/> with the prompt first and the elapsed time afterwards; returns the invocation result.
+    /// </summary>
+    private async Task<FunctionResult> ExecuteAsync(Kernel kernel, string title, string prompt)
+    {
+        Console.WriteLine($"{title}: {prompt}");
+
+        var stopwatch = Stopwatch.StartNew();
+
+        var result = await kernel.InvokePromptAsync(prompt);
+
+        stopwatch.Stop();
+
+        Console.WriteLine($@"Elapsed Time: {stopwatch.Elapsed:hh\:mm\:ss\.FFF}");
+
+        return result;
+    }
+
+ #endregion
+}
diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj
index 891eea16c400..e4be32a502f8 100644
--- a/dotnet/samples/Concepts/Concepts.csproj
+++ b/dotnet/samples/Concepts/Concepts.csproj
@@ -48,6 +48,7 @@
+
diff --git a/dotnet/samples/Concepts/README.md b/dotnet/samples/Concepts/README.md
index 75b46663a2f6..828abd9ef290 100644
--- a/dotnet/samples/Concepts/README.md
+++ b/dotnet/samples/Concepts/README.md
@@ -26,6 +26,10 @@ Down below you can find the code snippets that demonstrate the usage of many Sem
- [Gemini_FunctionCalling](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/AutoFunctionCalling/Gemini_FunctionCalling.cs)
- [OpenAI_FunctionCalling](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/AutoFunctionCalling/OpenAI_FunctionCalling.cs)
+## Caching - Examples of caching implementations
+
+- [SemanticCachingWithFilters](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs)
+
## ChatCompletion - Examples using [`ChatCompletion`](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/IChatCompletionService.cs) messaging capable service with models
- [AzureOpenAIWithData_ChatCompletion](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/ChatCompletion/AzureOpenAIWithData_ChatCompletion.cs)
diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs
index ae93aeb5193f..7a54a02a8d74 100644
--- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs
+++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryRecord.cs
@@ -58,6 +58,9 @@ public AzureCosmosDBMongoDBMemoryRecord(MemoryRecord memoryRecord)
///
public static MemoryRecord ToMemoryRecord(BsonDocument doc, bool withEmbedding)
{
+ BsonValue? timestamp = doc["timestamp"];
+ DateTimeOffset? recordTimestamp = timestamp is BsonNull ? null : timestamp.ToUniversalTime();
+
return new(
BsonSerializer
.Deserialize(
@@ -68,10 +71,8 @@ public static MemoryRecord ToMemoryRecord(BsonDocument doc, bool withEmbedding)
? doc["embedding"].AsBsonArray.Select(x => (float)x.AsDouble).ToArray()
: null,
doc["_id"].AsString,
- doc["timestamp"]?.ToUniversalTime()
+ recordTimestamp
);
-
- // return result;
}
///
@@ -83,7 +84,7 @@ public MemoryRecord ToMemoryRecord(bool withEmbedding)
this.Metadata.ToMemoryRecordMetadata(),
withEmbedding ? this.Embedding : null,
this.Id,
- this.Timestamp?.ToLocalTime()
+ this.Timestamp
);
}
}
diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs
index b9d0b203e7b1..be8a82165e9e 100644
--- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs
+++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBMemoryStore.cs
@@ -147,6 +147,8 @@ public async Task UpsertAsync(
CancellationToken cancellationToken = default
)
{
+ record.Key = record.Metadata.Id;
+
var replaceOptions = new ReplaceOptions() { IsUpsert = true };
var result = await this.GetCollection(collectionName)
@@ -340,9 +342,9 @@ private BsonDocument GetIndexDefinitionVectorIVF(string collectionName)
"cosmosSearchOptions",
new BsonDocument
{
- { "kind", this._config.Kind },
+ { "kind", this._config.Kind.GetCustomName() },
{ "numLists", this._config.NumLists },
- { "similarity", this._config.Similarity },
+ { "similarity", this._config.Similarity.GetCustomName() },
{ "dimensions", this._config.Dimensions }
}
}
@@ -372,10 +374,10 @@ private BsonDocument GetIndexDefinitionVectorHNSW(string collectionName)
"cosmosSearchOptions",
new BsonDocument
{
- { "kind", this._config.Kind },
+ { "kind", this._config.Kind.GetCustomName() },
{ "m", this._config.NumberOfConnections },
{ "efConstruction", this._config.EfConstruction },
- { "similarity", this._config.Similarity },
+ { "similarity", this._config.Similarity.GetCustomName() },
{ "dimensions", this._config.Dimensions }
}
}
diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs
index cb7b92bdb467..96925d086e3e 100644
--- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs
+++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBSimilarityType.cs
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.
-using System.Text.Json.Serialization;
+using System.Reflection;
+using MongoDB.Bson;
+using MongoDB.Bson.Serialization.Attributes;
// ReSharper disable InconsistentNaming
namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB;
@@ -13,18 +15,27 @@ public enum AzureCosmosDBSimilarityType
///
/// Cosine similarity
///
- [JsonPropertyName("COS")]
+ [BsonElement("COS")]
Cosine,
///
/// Inner Product similarity
///
- [JsonPropertyName("IP")]
+ [BsonElement("IP")]
InnerProduct,
///
/// Euclidean similarity
///
- [JsonPropertyName("L2")]
+ [BsonElement("L2")]
Euclidean
}
+
+/// <summary>Helpers for <see cref="AzureCosmosDBSimilarityType"/>.</summary>
+internal static class AzureCosmosDBSimilarityTypeExtensions
+{
+    /// <summary>Gets the name declared through <see cref="BsonElementAttribute"/> on the enum member (e.g. "COS").</summary>
+    /// <returns>The attribute's element name, or <c>type.ToString()</c> when the value has no matching field or no attribute.</returns>
+    public static string GetCustomName(this AzureCosmosDBSimilarityType type) =>
+        typeof(AzureCosmosDBSimilarityType).GetField(type.ToString())?.GetCustomAttribute<BsonElementAttribute>()?.ElementName ?? type.ToString();
+}
diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs
index c676e5612fef..bf5597131150 100644
--- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs
+++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBVectorSearchType.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
-using System.Text.Json.Serialization;
+using System.Reflection;
+using MongoDB.Bson.Serialization.Attributes;
// ReSharper disable InconsistentNaming
namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB;
@@ -13,12 +14,21 @@ public enum AzureCosmosDBVectorSearchType
///
/// vector-ivf is available on all cluster tiers
///
- [JsonPropertyName("vector_ivf")]
+ [BsonElement("vector-ivf")]
VectorIVF,
///
/// vector-hnsw is available on M40 cluster tiers and higher.
///
- [JsonPropertyName("vector_hnsw")]
+ [BsonElement("vector-hnsw")]
VectorHNSW
}
+
+/// <summary>Helpers for <see cref="AzureCosmosDBVectorSearchType"/>.</summary>
+internal static class AzureCosmosDBVectorSearchTypeExtensions
+{
+    /// <summary>Gets the name declared through <see cref="BsonElementAttribute"/> on the enum member (e.g. "vector-ivf").</summary>
+    /// <returns>The attribute's element name, or <c>type.ToString()</c> when the value has no matching field or no attribute.</returns>
+    public static string GetCustomName(this AzureCosmosDBVectorSearchType type) =>
+        typeof(AzureCosmosDBVectorSearchType).GetField(type.ToString())?.GetCustomAttribute<BsonElementAttribute>()?.ElementName ?? type.ToString();
+}
diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs
index d7c08c6344cf..508af88ca0d5 100644
--- a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs
+++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs
@@ -42,6 +42,7 @@ public static void Initialize(IConfigurationRoot configRoot)
public static MsGraphConfiguration MSGraph => LoadSection();
public static GoogleAIConfig GoogleAI => LoadSection();
public static VertexAIConfig VertexAI => LoadSection();
+ public static AzureCosmosDbMongoDbConfig AzureCosmosDbMongoDb => LoadSection();
private static T LoadSection([CallerMemberName] string? caller = null)
{
@@ -211,6 +212,12 @@ public class GeminiConfig
}
}
+ public class AzureCosmosDbMongoDbConfig // Configuration model exposed as TestConfiguration.AzureCosmosDbMongoDb.
+ {
+ public string ConnectionString { get; set; } // Read by the semantic caching sample and passed to the AzureCosmosDBMongoDBMemoryStore constructor.
+ public string DatabaseName { get; set; } // Read by the semantic caching sample and passed to the AzureCosmosDBMongoDBMemoryStore constructor.
+ }
+
///
/// Graph API connector configuration model.
///
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs
index 79402ceac836..a1e449642071 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs
@@ -62,4 +62,10 @@ public string? RenderedPrompt
this._renderedPrompt = value;
}
}
+
+ /// <summary>
+ /// Gets or sets the result of the function's invocation.
+ /// Setting to a non-null value will skip function invocation and return the result.
+ /// </summary>
+ public FunctionResult? Result { get; set; }
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs
index 0ebba8bca441..62cc5d343d01 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs
@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
using System.Globalization;
namespace Microsoft.SemanticKernel;
@@ -41,6 +42,7 @@ public FunctionResult(FunctionResult result, object? value = null)
this.Value = value ?? result.Value;
this.Culture = result.Culture;
this.Metadata = result.Metadata;
+ this.RenderedPrompt = result.RenderedPrompt;
}
///
@@ -67,6 +69,12 @@ public FunctionResult(FunctionResult result, object? value = null)
///
public Type? ValueType => this.Value?.GetType();
+ /// <summary>
+ /// Gets the prompt used during function invocation if any was rendered.
+ /// </summary>
+ [Experimental("SKEXP0001")]
+ public string? RenderedPrompt { get; internal set; }
+
///
/// Returns function result value.
///
diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs
index daf8bf2075a7..690a3d605cf4 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Memory/MemoryRecord.cs
@@ -87,7 +87,7 @@ public static MemoryRecord ReferenceRecord(
/// Source content embedding.
/// Optional string for saving custom metadata.
/// Optional existing database key.
- /// optional timestamp.
+ /// Optional timestamp.
/// Memory record
public static MemoryRecord LocalRecord(
string id,
diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs
index ff2b16578038..f0340b710873 100644
--- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs
+++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs
@@ -115,7 +115,7 @@ public static KernelFunction Create(
logger: loggerFactory?.CreateLogger(typeof(KernelFunctionFactory)) ?? NullLogger.Instance);
}
- /// j
+ ///
protected override async ValueTask InvokeCoreAsync(
Kernel kernel,
KernelArguments arguments,
@@ -132,18 +132,25 @@ protected override async ValueTask InvokeCoreAsync(
}
#pragma warning restore CS0612 // Events are deprecated
+ // Return function result if it was set in prompt filter.
+ if (result.FunctionResult is not null)
+ {
+ result.FunctionResult.RenderedPrompt = result.RenderedPrompt;
+ return result.FunctionResult;
+ }
+
if (result.AIService is IChatCompletionService chatCompletion)
{
var chatContent = await chatCompletion.GetChatMessageContentAsync(result.RenderedPrompt, result.ExecutionSettings, kernel, cancellationToken).ConfigureAwait(false);
this.CaptureUsageDetails(chatContent.ModelId, chatContent.Metadata, this._logger);
- return new FunctionResult(this, chatContent, kernel.Culture, chatContent.Metadata);
+ return new FunctionResult(this, chatContent, kernel.Culture, chatContent.Metadata) { RenderedPrompt = result.RenderedPrompt };
}
if (result.AIService is ITextGenerationService textGeneration)
{
var textContent = await textGeneration.GetTextContentWithDefaultParserAsync(result.RenderedPrompt, result.ExecutionSettings, kernel, cancellationToken).ConfigureAwait(false);
this.CaptureUsageDetails(textContent.ModelId, textContent.Metadata, this._logger);
- return new FunctionResult(this, textContent, kernel.Culture, textContent.Metadata);
+ return new FunctionResult(this, textContent, kernel.Culture, textContent.Metadata) { RenderedPrompt = result.RenderedPrompt };
}
// The service selector didn't find an appropriate service. This should only happen with a poorly implemented selector.
@@ -375,6 +382,7 @@ private async Task RenderPromptAsync(Kernel kernel, Kerne
{
ExecutionSettings = executionSettings,
RenderedEventArgs = renderedEventArgs,
+ FunctionResult = renderingContext.Result
};
}
diff --git a/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs b/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs
index 765585be9960..7aee48fc130b 100644
--- a/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs
+++ b/dotnet/src/SemanticKernel.Core/Functions/PromptRenderingResult.cs
@@ -15,6 +15,8 @@ internal sealed class PromptRenderingResult
public PromptExecutionSettings? ExecutionSettings { get; set; }
+ public FunctionResult? FunctionResult { get; set; }
+
#pragma warning disable CS0618 // Events are deprecated
public PromptRenderedEventArgs? RenderedEventArgs { get; set; }
#pragma warning restore CS0618 // Events are deprecated
diff --git a/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs b/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs
index a584d9f4cf1d..09819aea796d 100644
--- a/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs
+++ b/dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs
@@ -46,7 +46,11 @@ public async Task SaveInformationAsync(
{
var embedding = await this._embeddingGenerator.GenerateEmbeddingAsync(text, kernel, cancellationToken).ConfigureAwait(false);
MemoryRecord data = MemoryRecord.LocalRecord(
- id: id, text: text, description: description, additionalMetadata: additionalMetadata, embedding: embedding);
+ id: id,
+ text: text,
+ description: description,
+ additionalMetadata: additionalMetadata,
+ embedding: embedding);
if (!(await this._storage.DoesCollectionExistAsync(collection, cancellationToken).ConfigureAwait(false)))
{
@@ -116,17 +120,20 @@ public async IAsyncEnumerable SearchAsync(
{
ReadOnlyMemory queryEmbedding = await this._embeddingGenerator.GenerateEmbeddingAsync(query, kernel, cancellationToken).ConfigureAwait(false);
- IAsyncEnumerable<(MemoryRecord, double)> results = this._storage.GetNearestMatchesAsync(
- collectionName: collection,
- embedding: queryEmbedding,
- limit: limit,
- minRelevanceScore: minRelevanceScore,
- withEmbeddings: withEmbeddings,
- cancellationToken: cancellationToken);
-
- await foreach ((MemoryRecord, double) result in results.WithCancellation(cancellationToken).ConfigureAwait(false))
+ if ((await this._storage.DoesCollectionExistAsync(collection, cancellationToken).ConfigureAwait(false)))
{
- yield return MemoryQueryResult.FromMemoryRecord(result.Item1, result.Item2);
+ IAsyncEnumerable<(MemoryRecord, double)> results = this._storage.GetNearestMatchesAsync(
+ collectionName: collection,
+ embedding: queryEmbedding,
+ limit: limit,
+ minRelevanceScore: minRelevanceScore,
+ withEmbeddings: withEmbeddings,
+ cancellationToken: cancellationToken);
+
+ await foreach ((MemoryRecord, double) result in results.WithCancellation(cancellationToken).ConfigureAwait(false))
+ {
+ yield return MemoryQueryResult.FromMemoryRecord(result.Item1, result.Item2);
+ }
}
}
diff --git a/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs b/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs
index eff697278997..020008070387 100644
--- a/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs
+++ b/dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs
@@ -236,4 +236,32 @@ public async Task PostInvocationPromptFilterSkippingWorksCorrectlyAsync()
// Assert
mockTextGeneration.Verify(m => m.GetTextContentsAsync("", It.IsAny(), It.IsAny(), It.IsAny()), Times.Once());
}
+
+    [Fact]
+    public async Task PromptFilterCanOverrideFunctionResultAsync()
+    {
+        // Arrange
+        var mockTextGeneration = this.GetMockTextGeneration();
+        var function = KernelFunctionFactory.CreateFromPrompt("Prompt");
+
+        var kernel = this.GetKernelWithFilters(textGenerationService: mockTextGeneration.Object,
+            onPromptRender: async (context, next) =>
+            {
+                await next(context);
+
+                context.Result = new FunctionResult(context.Function, "Result from prompt filter");
+            },
+            onFunctionInvocation: async (context, next) =>
+            {
+                await next(context);
+            });
+
+        // Act
+        var result = await kernel.InvokeAsync(function);
+
+        // Assert: the text generation service must never be called because the prompt filter supplied the result.
+        mockTextGeneration.Verify(m => m.GetTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()), Times.Never());
+
+        Assert.Equal("Result from prompt filter", result.ToString());
+    }
}