diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 0a78b2c0332f..a9562b94bca6 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -84,7 +84,7 @@ - + diff --git a/dotnet/src/Extensions/PromptTemplates.Liquid/LiquidPromptTemplate.cs b/dotnet/src/Extensions/PromptTemplates.Liquid/LiquidPromptTemplate.cs index abb2b47aef4b..0e9193f290d7 100644 --- a/dotnet/src/Extensions/PromptTemplates.Liquid/LiquidPromptTemplate.cs +++ b/dotnet/src/Extensions/PromptTemplates.Liquid/LiquidPromptTemplate.cs @@ -2,14 +2,13 @@ using System; using System.Collections.Generic; -using System.Diagnostics; using System.Text; using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using System.Web; -using Scriban; -using Scriban.Syntax; +using Fluid; +using Fluid.Ast; namespace Microsoft.SemanticKernel.PromptTemplates.Liquid; @@ -18,12 +17,18 @@ namespace Microsoft.SemanticKernel.PromptTemplates.Liquid; /// internal sealed partial class LiquidPromptTemplate : IPromptTemplate { + private static readonly FluidParser s_parser = new(); + private static readonly TemplateOptions s_templateOptions = new() + { + MemberAccessStrategy = new UnsafeMemberAccessStrategy() { MemberNameStrategy = MemberNameStrategies.SnakeCase }, + }; + private const string ReservedString = ":"; private const string ColonString = ":"; private const char LineEnding = '\n'; private readonly PromptTemplateConfig _config; private readonly bool _allowDangerouslySetContent; - private readonly Template _liquidTemplate; + private readonly IFluidTemplate _liquidTemplate; private readonly Dictionary _inputVariables; #if NET @@ -55,12 +60,12 @@ public LiquidPromptTemplate(PromptTemplateConfig config, bool allowDangerouslySe // Parse the template now so we can check for errors, understand variable usage, and // avoid having to parse on each render. - this._liquidTemplate = Template.ParseLiquid(config.Template); - if (this._liquidTemplate.HasErrors) + if (!s_parser.TryParse(config.Template, out this._liquidTemplate, out string error)) { - throw new ArgumentException($"The template could not be parsed:{Environment.NewLine}{string.Join(Environment.NewLine, this._liquidTemplate.Messages)}"); + throw new ArgumentException(error is not null ? + $"The template could not be parsed:{Environment.NewLine}{error}" : + "The template could not be parsed."); } - Debug.Assert(this._liquidTemplate.Page is not null); // Ideally the prompty author would have explicitly specified input variables. If they specified any, // assume they specified them all. If they didn't, heuristically try to find the variables, looking for @@ -92,7 +97,7 @@ public async Task RenderAsync(Kernel kernel, KernelArguments? arguments { Verify.NotNull(kernel); cancellationToken.ThrowIfCancellationRequested(); - var variables = this.GetVariables(arguments); + var variables = this.GetTemplateContext(arguments); var renderedResult = this._liquidTemplate.Render(variables); // parse chat history @@ -154,9 +159,9 @@ private string ReplaceReservedStringBackToColonIfNeeded(string text) /// /// Gets the variables for the prompt template, including setting any default values from the prompt config. /// - private Dictionary GetVariables(KernelArguments? arguments) + private TemplateContext GetTemplateContext(KernelArguments? arguments) { - var result = new Dictionary(); + var ctx = new TemplateContext(s_templateOptions); foreach (var p in this._config.InputVariables) { @@ -165,7 +170,7 @@ private string ReplaceReservedStringBackToColonIfNeeded(string text) continue; } - result[p.Name] = p.Default; + ctx.SetValue(p.Name, p.Default); } if (arguments is not null) @@ -177,17 +182,17 @@ private string ReplaceReservedStringBackToColonIfNeeded(string text) var value = (object)kvp.Value; if (this.ShouldReplaceColonToReservedString(this._config, kvp.Key, kvp.Value)) { - result[kvp.Key] = value.ToString()?.Replace(ColonString, ReservedString); + ctx.SetValue(kvp.Key, value.ToString()?.Replace(ColonString, ReservedString)); } else { - result[kvp.Key] = value; + ctx.SetValue(kvp.Key, value); } } } } - return result; + return ctx; } private bool ShouldReplaceColonToReservedString(PromptTemplateConfig promptTemplateConfig, string propertyName, object? propertyValue) @@ -209,20 +214,23 @@ private bool ShouldReplaceColonToReservedString(PromptTemplateConfig promptTempl } /// - /// Visitor for looking for variables that are only + /// Visitor for looking for variables that are only /// ever read and appear to represent very simple strings. If any variables - /// other than that are found, none are returned. + /// other than that are found, none are returned. This only handles very basic + /// cases where the template doesn't contain any more complicated constructs; + /// the heuristic can be improved over time. /// - private sealed class SimpleVariablesVisitor : ScriptVisitor + private sealed class SimpleVariablesVisitor : AstVisitor { private readonly HashSet _variables = new(StringComparer.OrdinalIgnoreCase); + private readonly Stack _statementStack = new(); private bool _valid = true; - public static HashSet InferInputs(Template template) + public static HashSet InferInputs(IFluidTemplate template) { var visitor = new SimpleVariablesVisitor(); - template.Page.Accept(visitor); + visitor.VisitTemplate(template); if (!visitor._valid) { visitor._variables.Clear(); @@ -231,27 +239,51 @@ public static HashSet InferInputs(Template template) return visitor._variables; } - public override void Visit(ScriptVariableGlobal node) + public override Statement Visit(Statement statement) + { + if (!this._valid) + { + return statement; + } + + this._statementStack.Push(statement); + try + { + return base.Visit(statement); + } + finally + { + this._statementStack.Pop(); + } + } + + protected override Expression VisitMemberExpression(MemberExpression memberExpression) { - if (this._valid) + if (memberExpression.Segments.Count == 1 && memberExpression.Segments[0] is IdentifierSegment id) { - switch (node.Parent) + bool isValid = true; + + if (this._statementStack.Count > 0) { - case ScriptAssignExpression assign when ReferenceEquals(assign.Target, node): - case ScriptForStatement forLoop: - case ScriptMemberExpression member: - // Unsupported use found; bail. - this._valid = false; - return; - - default: - // Reading from a simple variable. - this._variables.Add(node.Name); - break; + switch (this._statementStack.Peek()) + { + case ForStatement: + case AssignStatement assign when string.Equals(id.Identifier, assign.Identifier, StringComparison.OrdinalIgnoreCase): + isValid = false; + break; + } } - base.DefaultVisit(node); + if (isValid) + { + this._variables.Add(id.Identifier); + return base.VisitMemberExpression(memberExpression); + } } + + // Found something unsupported. Bail. + this._valid = false; + return memberExpression; } } } diff --git a/dotnet/src/Extensions/PromptTemplates.Liquid/PromptTemplates.Liquid.csproj b/dotnet/src/Extensions/PromptTemplates.Liquid/PromptTemplates.Liquid.csproj index 632202ce2e4e..1a8827cbbb09 100644 --- a/dotnet/src/Extensions/PromptTemplates.Liquid/PromptTemplates.Liquid.csproj +++ b/dotnet/src/Extensions/PromptTemplates.Liquid/PromptTemplates.Liquid.csproj @@ -23,6 +23,6 @@ - + \ No newline at end of file