Skip to content

.Net: Enable CreateFromType/Object to work with closed generics #6218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ private KernelFunctionFromMethod(

private static MethodDetails GetMethodDetails(string? functionName, MethodInfo method, object? target)
{
ThrowForInvalidSignatureIf(method.IsGenericMethodDefinition, method, "Generic methods are not supported");
ThrowForInvalidSignatureIf(method.ContainsGenericParameters, method, "Open generic methods are not supported");

if (functionName is null)
{
Expand Down Expand Up @@ -795,7 +795,7 @@ input is byte ||
/// <summary>
/// Remove characters from method name that are valid in metadata but invalid for SK.
/// </summary>
private static string SanitizeMetadataName(string methodName) =>
internal static string SanitizeMetadataName(string methodName) =>
InvalidNameCharsRegex().Replace(methodName, "_");

/// <summary>Regex that flags any character other than ASCII digits or letters or the underscore.</summary>
Expand Down
33 changes: 31 additions & 2 deletions dotnet/src/SemanticKernel.Core/Functions/KernelPluginFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Reflection;
using System.Text.RegularExpressions;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

Expand All @@ -12,7 +13,7 @@ namespace Microsoft.SemanticKernel;
/// <summary>
/// Provides static factory methods for creating commonly-used plugin implementations.
/// </summary>
public static class KernelPluginFactory
public static partial class KernelPluginFactory
{
/// <summary>Creates a plugin that wraps a new instance of the specified type <typeparamref name="T"/>.</summary>
/// <typeparam name="T">Specifies the type of the object to wrap.</typeparam>
Expand Down Expand Up @@ -49,7 +50,7 @@ public static KernelPlugin CreateFromObject(object target, string? pluginName =
{
Verify.NotNull(target);

pluginName ??= target.GetType().Name;
pluginName ??= CreatePluginName(target.GetType());
Verify.ValidPluginName(pluginName);

MethodInfo[] methods = target.GetType().GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static);
Expand Down Expand Up @@ -101,4 +102,32 @@ public static KernelPlugin CreateFromFunctions(string pluginName, IEnumerable<Ke
/// <exception cref="ArgumentException"><paramref name="functions"/> contains two functions with the same name.</exception>
public static KernelPlugin CreateFromFunctions(string pluginName, string? description = null, IEnumerable<KernelFunction>? functions = null) =>
new DefaultKernelPlugin(pluginName, description, functions);

/// <summary>Creates a name for a plugin based on its type name.</summary>
private static string CreatePluginName(Type type)
{
string name = type.ToString();

// Remove the namespace
if (type.Namespace is string ns &&
name.StartsWith(ns, StringComparison.Ordinal) &&
name.Length > ns.Length + 1 &&
name[ns.Length] == '.')
{
name = name.Substring(ns.Length + 1);
}

// Replace invalid characters
name = InvalidPluginNameCharactersRegex().Replace(name, "_");

return name;
}

#if NET
[GeneratedRegex("[^0-9A-Za-z_]")]
private static partial Regex InvalidPluginNameCharactersRegex();
#else
private static Regex InvalidPluginNameCharactersRegex() => s_invalidPluginNameCharactersRegex;
private static readonly Regex s_invalidPluginNameCharactersRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled);
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,24 @@ async Task ExecuteAsync(string done)
Assert.Empty(result.ToString());
}

[Fact]
public async Task ItCanImportClosedGenericsAsync()
{
await Validate(KernelPluginFactory.CreateFromObject(new GenericPlugin<int>()));
await Validate(KernelPluginFactory.CreateFromType<GenericPlugin<int>>());

async Task Validate(KernelPlugin plugin)
{
Assert.Equal("KernelFunctionFromMethodTests2_GenericPlugin_1_System_Int32_", plugin.Name);
Assert.Equal(3, plugin.FunctionCount);
foreach (KernelFunction function in plugin)
{
FunctionResult result = await function.InvokeAsync(new(), new() { { "input", 42 } });
Assert.Equal(42, result.Value);
}
}
}

[Fact]
public async Task ItCanImportMethodFunctionsWithExternalReferencesAsync()
{
Expand Down Expand Up @@ -449,4 +467,16 @@ public string WithPrimitives(
return string.Empty;
}
}

private sealed class GenericPlugin<T>
{
[KernelFunction]
public int GetValue1(int input) => input;

[KernelFunction]
public T GetValue2(T input) => input;

[KernelFunction]
public Task<T> GetValue3Async(T input) => Task.FromResult(input);
}
}
Loading