From 5efad9dbffab20985df3031111128426c742957b Mon Sep 17 00:00:00 2001 From: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com> Date: Thu, 27 Feb 2025 14:55:41 +0000 Subject: [PATCH 1/3] Add support for all Azure AI tool types --- .../Extensions/AgentDefinitionExtensions.cs | 133 ++++++- .../AgentToolDefinitionExtensions.cs | 184 ++++++++++ .../Yaml/AzureAIKernelAgentYamlTests.cs | 346 ++++++++++++++++++ 3 files changed, 658 insertions(+), 5 deletions(-) create mode 100644 dotnet/src/Agents/AzureAI/Extensions/AgentToolDefinitionExtensions.cs create mode 100644 dotnet/src/Agents/UnitTests/Yaml/AzureAIKernelAgentYamlTests.cs diff --git a/dotnet/src/Agents/AzureAI/Extensions/AgentDefinitionExtensions.cs b/dotnet/src/Agents/AzureAI/Extensions/AgentDefinitionExtensions.cs index aeef38652fb1..b60f4449387d 100644 --- a/dotnet/src/Agents/AzureAI/Extensions/AgentDefinitionExtensions.cs +++ b/dotnet/src/Agents/AzureAI/Extensions/AgentDefinitionExtensions.cs @@ -11,6 +11,29 @@ namespace Microsoft.SemanticKernel.Agents.AzureAI; /// internal static class AgentDefinitionExtensions { + private const string AzureAISearchType = "azure_ai_search"; + private const string AzureFunctionType = "azure_function"; + private const string BingGroundingType = "bing_grounding"; + private const string CodeInterpreterType = "code_interpreter"; + private const string FileSearchType = "file_search"; + private const string FunctionType = "function"; + private const string MicrosoftFabricType = "fabric_aiskill"; + private const string OpenApiType = "openapi"; + private const string SharepointGroundingType = "sharepoint_grounding"; + + private static readonly string[] s_validToolTypes = new string[] + { + AzureAISearchType, + AzureFunctionType, + BingGroundingType, + CodeInterpreterType, + FileSearchType, + FunctionType, + MicrosoftFabricType, + OpenApiType, + SharepointGroundingType + }; + /// /// Return the Azure AI tool definitions which corresponds with the provided . /// @@ -18,15 +41,22 @@ internal static class AgentDefinitionExtensions /// public static IEnumerable GetAzureToolDefinitions(this AgentDefinition agentDefinition) { - return agentDefinition.Tools?.Select(tool => + return agentDefinition.Tools.Select(tool => { return tool.Type switch { - "code_interpreter" => new CodeInterpreterToolDefinition(), - "file_search" => new FileSearchToolDefinition(), - _ => throw new NotSupportedException($"Unable to create Azure AI tool definition because of unsupported tool type: {tool.Type}"), + AzureAISearchType => CreateAzureAISearchToolDefinition(tool), + AzureFunctionType => CreateAzureFunctionToolDefinition(tool), + BingGroundingType => CreateBingGroundingToolDefinition(tool), + CodeInterpreterType => CreateCodeInterpreterToolDefinition(tool), + FileSearchType => CreateFileSearchToolDefinition(tool), + FunctionType => CreateFunctionToolDefinition(tool), + MicrosoftFabricType => CreateMicrosoftFabricToolDefinition(tool), + OpenApiType => CreateOpenApiToolDefinition(tool), + SharepointGroundingType => CreateSharepointGroundingToolDefinition(tool), + _ => throw new NotSupportedException($"Unable to create Azure AI tool definition because of unsupported tool type: {tool.Type}, supported tool types are: {string.Join(",", s_validToolTypes)}"), }; - }) ?? []; + }); } /// @@ -40,4 +70,97 @@ public static IEnumerable GetAzureToolDefinitions(this AgentDefi // TODO: Implement return null; } + + #region private + private static AzureAISearchToolDefinition CreateAzureAISearchToolDefinition(AgentToolDefinition tool) + { + Verify.NotNull(tool); + + return new AzureAISearchToolDefinition(); + } + + private static AzureFunctionToolDefinition CreateAzureFunctionToolDefinition(AgentToolDefinition tool) + { + Verify.NotNull(tool); + Verify.NotNull(tool.Name); + Verify.NotNull(tool.Description); + + string name = tool.Name; + string description = tool.Description; + AzureFunctionBinding inputBinding = tool.GetInputBinding(); + AzureFunctionBinding outputBinding = tool.GetOutputBinding(); + BinaryData parameters = tool.GetParameters(); + + return new AzureFunctionToolDefinition(name, description, inputBinding, outputBinding, parameters); + } + + private static BingGroundingToolDefinition CreateBingGroundingToolDefinition(AgentToolDefinition tool) + { + Verify.NotNull(tool); + + ToolConnectionList bingGrounding = tool.GetToolConnectionList(); + + return new BingGroundingToolDefinition(bingGrounding); + } + + private static CodeInterpreterToolDefinition CreateCodeInterpreterToolDefinition(AgentToolDefinition tool) + { + return new CodeInterpreterToolDefinition(); + } + + private static FileSearchToolDefinition CreateFileSearchToolDefinition(AgentToolDefinition tool) + { + Verify.NotNull(tool); + + return new FileSearchToolDefinition() + { + FileSearch = tool.GetFileSearchToolDefinitionDetails() + }; + } + + private static FunctionToolDefinition CreateFunctionToolDefinition(AgentToolDefinition tool) + { + Verify.NotNull(tool); + Verify.NotNull(tool.Name); + Verify.NotNull(tool.Description); + + string name = tool.Name; + string description = tool.Description; + BinaryData parameters = tool.GetParameters(); + + return new FunctionToolDefinition(name, description, parameters); + } + + private static MicrosoftFabricToolDefinition CreateMicrosoftFabricToolDefinition(AgentToolDefinition tool) + { + Verify.NotNull(tool); + + ToolConnectionList fabricAiskill = tool.GetToolConnectionList(); + + return new MicrosoftFabricToolDefinition(fabricAiskill); + } + + private static OpenApiToolDefinition CreateOpenApiToolDefinition(AgentToolDefinition tool) + { + Verify.NotNull(tool); + Verify.NotNull(tool.Name); + Verify.NotNull(tool.Description); + + string name = tool.Name; + string description = tool.Description; + BinaryData spec = tool.GetSpecification(); + OpenApiAuthDetails auth = tool.GetOpenApiAuthDetails(); + + return new OpenApiToolDefinition(name, description, spec, auth); + } + + private static SharepointToolDefinition CreateSharepointGroundingToolDefinition(AgentToolDefinition tool) + { + Verify.NotNull(tool); + + ToolConnectionList sharepointGrounding = tool.GetToolConnectionList(); + + return new SharepointToolDefinition(sharepointGrounding); + } + #endregion } diff --git a/dotnet/src/Agents/AzureAI/Extensions/AgentToolDefinitionExtensions.cs b/dotnet/src/Agents/AzureAI/Extensions/AgentToolDefinitionExtensions.cs new file mode 100644 index 000000000000..e334d83c1eed --- /dev/null +++ b/dotnet/src/Agents/AzureAI/Extensions/AgentToolDefinitionExtensions.cs @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using Azure.AI.Projects; + +namespace Microsoft.SemanticKernel.Agents.AzureAI; + +/// +/// Provides extension methods for . +/// +internal static class AgentToolDefinitionExtensions +{ + internal static AzureFunctionBinding GetInputBinding(this AgentToolDefinition agentToolDefinition) + { + return agentToolDefinition.GetAzureFunctionBinding("input_binding"); + } + + internal static AzureFunctionBinding GetOutputBinding(this AgentToolDefinition agentToolDefinition) + { + return agentToolDefinition.GetAzureFunctionBinding("output_binding"); + } + + internal static BinaryData GetParameters(this AgentToolDefinition agentToolDefinition) + { + Verify.NotNull(agentToolDefinition.Configuration); + + var parameters = agentToolDefinition.GetConfiguration>("parameters"); + + // TODO Needswork + return parameters is not null ? new BinaryData(parameters) : s_noParams; + } + + internal static FileSearchToolDefinitionDetails GetFileSearchToolDefinitionDetails(this AgentToolDefinition agentToolDefinition) + { + var details = new FileSearchToolDefinitionDetails() + { + MaxNumResults = agentToolDefinition.GetConfiguration("max_num_results") + }; + + FileSearchRankingOptions? rankingOptions = agentToolDefinition.GetFileSearchRankingOptions(); + if (rankingOptions is not null) + { + details.RankingOptions = rankingOptions; + } + + return details; + } + + internal static ToolConnectionList GetToolConnectionList(this AgentToolDefinition agentToolDefinition) + { + Verify.NotNull(agentToolDefinition.Configuration); + + var toolConnections = agentToolDefinition.GetToolConnections(); + + var toolConnectionList = new ToolConnectionList(); + if (toolConnections is not null) + { + toolConnectionList.ConnectionList.AddRange(toolConnections); + } + return toolConnectionList; + } + + internal static BinaryData GetSpecification(this AgentToolDefinition agentToolDefinition) + { + Verify.NotNull(agentToolDefinition.Configuration); + + var specification = agentToolDefinition.GetRequiredConfiguration>("specification"); + + return new BinaryData(specification); + } + + internal static OpenApiAuthDetails GetOpenApiAuthDetails(this AgentToolDefinition agentToolDefinition) + { + Verify.NotNull(agentToolDefinition.Configuration); + + var connectionId = agentToolDefinition.GetConfiguration("connection_id"); + if (!string.IsNullOrEmpty(connectionId)) + { + return new OpenApiConnectionAuthDetails(new OpenApiConnectionSecurityScheme(connectionId)); + } + + var audience = agentToolDefinition.GetConfiguration("audience"); + if (!string.IsNullOrEmpty(audience)) + { + return new OpenApiManagedAuthDetails(new OpenApiManagedSecurityScheme(audience)); + } + + return new OpenApiAnonymousAuthDetails(); + } + + private static AzureFunctionBinding GetAzureFunctionBinding(this AgentToolDefinition agentToolDefinition, string bindingType) + { + Verify.NotNull(agentToolDefinition.Configuration); + + var binding = agentToolDefinition.GetRequiredConfiguration>(bindingType); + if (!binding.TryGetValue("storage_service_endpoint", out var endpointValue) || endpointValue is not string storageServiceEndpoint) + { + throw new ArgumentException($"The configuration key '{bindingType}.storage_service_endpoint' is required."); + } + if (!binding.TryGetValue("queue_name", out var nameValue) || nameValue is not string queueName) + { + throw new ArgumentException($"The configuration key '{bindingType}.queue_name' is required."); + } + + return new AzureFunctionBinding(new AzureFunctionStorageQueue(storageServiceEndpoint, queueName)); + } + + private static FileSearchRankingOptions? GetFileSearchRankingOptions(this AgentToolDefinition agentToolDefinition) + { + string? ranker = agentToolDefinition.GetConfiguration("ranker"); + float? scoreThreshold = agentToolDefinition.GetConfiguration("score_threshold"); + + if (ranker is not null && scoreThreshold is not null) + { + return new FileSearchRankingOptions(ranker, (float)scoreThreshold!); + } + + return null; + } + + private static List GetToolConnections(this AgentToolDefinition agentToolDefinition) + { + Verify.NotNull(agentToolDefinition.Configuration); + + var toolConnections = agentToolDefinition.GetRequiredConfiguration>("tool_connections"); + + return toolConnections.Select(connectionId => new ToolConnection(connectionId.ToString())).ToList(); + } + + private static T GetRequiredConfiguration(this AgentToolDefinition agentToolDefinition, string key) + { + Verify.NotNull(agentToolDefinition); + Verify.NotNull(agentToolDefinition.Configuration); + Verify.NotNull(key); + + if (agentToolDefinition.Configuration?.TryGetValue(key, out var value) ?? false) + { + if (value == null) + { + throw new ArgumentNullException($"The configuration key '{key}' must be a non null value."); + } + + try + { + return (T)Convert.ChangeType(value, typeof(T)); + } + catch (InvalidCastException) + { + throw new InvalidCastException($"The configuration key '{key}' value must be of type '{typeof(T)}' but is '{value.GetType()}'."); + } + } + + throw new ArgumentException($"The configuration key '{key}' is required."); + } + + private static T? GetConfiguration(this AgentToolDefinition agentToolDefinition, string key) + { + Verify.NotNull(agentToolDefinition); + Verify.NotNull(key); + + if (agentToolDefinition.Configuration?.TryGetValue(key, out var value) ?? false) + { + if (value == null) + { + return default; + } + + try + { + return (T?)Convert.ChangeType(value, typeof(T)); + } + catch (InvalidCastException) + { + throw new InvalidCastException($"The configuration key '{key}' value must be of type '{typeof(T?)}' but is '{value.GetType()}'."); + } + } + + return default; + } + + private static readonly BinaryData s_noParams = BinaryData.FromObjectAsJson(new { type = "object", properties = new { } }); +} diff --git a/dotnet/src/Agents/UnitTests/Yaml/AzureAIKernelAgentYamlTests.cs b/dotnet/src/Agents/UnitTests/Yaml/AzureAIKernelAgentYamlTests.cs new file mode 100644 index 000000000000..f83c0758bd18 --- /dev/null +++ b/dotnet/src/Agents/UnitTests/Yaml/AzureAIKernelAgentYamlTests.cs @@ -0,0 +1,346 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.AI.Projects; +using Azure.Core.Pipeline; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.AzureAI; +using SemanticKernel.Agents.UnitTests.AzureAI.Definition; +using Xunit; + +namespace SemanticKernel.Agents.UnitTests.Yaml; + +/// +/// Unit tests for with . +/// +public class AzureAIKernelAgentYamlTests : IDisposable +{ + private readonly HttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + private readonly Kernel _kernel; + + /// + /// Initializes a new instance of the class. + /// + public AzureAIKernelAgentYamlTests() + { + this._messageHandlerStub = new HttpMessageHandlerStub(); + this._httpClient = new HttpClient(this._messageHandlerStub, disposeHandler: false); + + var builder = Kernel.CreateBuilder(); + + // Add Azure AI agents client + var client = new AIProjectClient( + "endpoint;subscription_id;resource_group_name;project_name", + new FakeTokenCredential(), + new AIProjectClientOptions() + { Transport = new HttpClientTransport(this._httpClient) }); + builder.Services.AddSingleton(client); + + this._kernel = builder.Build(); + } + + /// + public void Dispose() + { + GC.SuppressFinalize(this); + this._messageHandlerStub.Dispose(); + this._httpClient.Dispose(); + } + + /// + /// Verify the request includes a tool of the specified when creating an Azure AI agent. + /// + [Theory] + [InlineData("code_interpreter")] + [InlineData("azure_ai_search")] + public async Task VerifyRequestIncludesToolAsync(string type) + { + // Arrange + var text = + $""" + type: azureai_agent + name: AzureAIAgent + description: AzureAIAgent Description + instructions: AzureAIAgent Instructions + model: + id: gpt-4o-mini + tools: + - type: {type} + """; + AzureAIAgentFactory factory = new(); + this.SetupResponse(HttpStatusCode.OK, AzureAIAgentFactoryTests.AzureAIAgentResponse); + + // Act + var agent = await factory.CreateAgentFromYamlAsync(text, this._kernel); + + // Assert + Assert.NotNull(agent); + var requestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!); + Assert.NotNull(requestContent); + var requestJson = JsonSerializer.Deserialize(requestContent); + Assert.Equal(1, requestJson.GetProperty("tools").GetArrayLength()); + Assert.Equal(type, requestJson.GetProperty("tools")[0].GetProperty("type").GetString()); + } + + /// + /// Verify the request includes an Azure Function tool when creating an Azure AI agent. + /// + [Fact] + public async Task VerifyRequestIncludesAzureFunctionAsync() + { + // Arrange + var text = + """ + type: azureai_agent + name: AzureAIAgent + description: AzureAIAgent Description + instructions: AzureAIAgent Instructions + model: + id: gpt-4o-mini + tools: + - type: azure_function + name: function1 + description: function1 description + input_binding: + storage_service_endpoint: https://storage_service_endpoint + queue_name: queue_name + output_binding: + storage_service_endpoint: https://storage_service_endpoint + queue_name: queue_name + parameters: + - name: param1 + type: string + description: param1 description + - name: param2 + type: string + description: param2 description + """; + AzureAIAgentFactory factory = new(); + this.SetupResponse(HttpStatusCode.OK, AzureAIAgentFactoryTests.AzureAIAgentResponse); + + // Act + var agent = await factory.CreateAgentFromYamlAsync(text, this._kernel); + + // Assert + Assert.NotNull(agent); + var requestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!); + Assert.NotNull(requestContent); + var requestJson = JsonSerializer.Deserialize(requestContent); + Assert.Equal(1, requestJson.GetProperty("tools").GetArrayLength()); + Assert.Equal("azure_function", requestJson.GetProperty("tools")[0].GetProperty("type").GetString()); + } + + /// + /// Verify the request includes a Function when creating an Azure AI agent. + /// + [Fact] + public async Task VerifyRequestIncludesFunctionAsync() + { + // Arrange + var text = + """ + type: azureai_agent + name: AzureAIAgent + description: AzureAIAgent Description + instructions: AzureAIAgent Instructions + model: + id: gpt-4o-mini + tools: + - type: function + name: function1 + description: function1 description + parameters: + - name: param1 + type: string + description: param1 description + - name: param2 + type: string + description: param2 description + """; + AzureAIAgentFactory factory = new(); + this.SetupResponse(HttpStatusCode.OK, AzureAIAgentFactoryTests.AzureAIAgentResponse); + + // Act + var agent = await factory.CreateAgentFromYamlAsync(text, this._kernel); + + // Assert + Assert.NotNull(agent); + var requestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!); + Assert.NotNull(requestContent); + var requestJson = JsonSerializer.Deserialize(requestContent); + Assert.Equal(1, requestJson.GetProperty("tools").GetArrayLength()); + Assert.Equal("function", requestJson.GetProperty("tools")[0].GetProperty("type").GetString()); + } + + /// + /// Verify the request includes a Bing Grounding tool when creating an Azure AI agent. + /// + [Fact] + public async Task VerifyRequestIncludesBingGroundingAsync() + { + // Arrange + var text = + """ + type: azureai_agent + name: AzureAIAgent + description: AzureAIAgent Description + instructions: AzureAIAgent Instructions + model: + id: gpt-4o-mini + tools: + - type: bing_grounding + tool_connections: + - connection_string + """; + AzureAIAgentFactory factory = new(); + this.SetupResponse(HttpStatusCode.OK, AzureAIAgentFactoryTests.AzureAIAgentResponse); + + // Act + var agent = await factory.CreateAgentFromYamlAsync(text, this._kernel); + + // Assert + Assert.NotNull(agent); + var requestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!); + Assert.NotNull(requestContent); + var requestJson = JsonSerializer.Deserialize(requestContent); + Assert.Equal(1, requestJson.GetProperty("tools").GetArrayLength()); + Assert.Equal("bing_grounding", requestJson.GetProperty("tools")[0].GetProperty("type").GetString()); + } + + /// + /// Verify the request includes a Microsoft Fabric tool when creating an Azure AI agent. + /// + [Fact] + public async Task VerifyRequestIncludesMicrosoftFabricAsync() + { + // Arrange + var text = + """ + type: azureai_agent + name: AzureAIAgent + description: AzureAIAgent Description + instructions: AzureAIAgent Instructions + model: + id: gpt-4o-mini + tools: + - type: fabric_aiskill + tool_connections: + - connection_string + """; + AzureAIAgentFactory factory = new(); + this.SetupResponse(HttpStatusCode.OK, AzureAIAgentFactoryTests.AzureAIAgentResponse); + + // Act + var agent = await factory.CreateAgentFromYamlAsync(text, this._kernel); + + // Assert + Assert.NotNull(agent); + var requestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!); + Assert.NotNull(requestContent); + var requestJson = JsonSerializer.Deserialize(requestContent); + Assert.Equal(1, requestJson.GetProperty("tools").GetArrayLength()); + Assert.Equal("fabric_aiskill", requestJson.GetProperty("tools")[0].GetProperty("type").GetString()); + } + + /// + /// Verify the request includes a Open API tool when creating an Azure AI agent. + /// + [Fact] + public async Task VerifyRequestIncludesOpenAPIAsync() + { + // Arrange + var text = + """ + type: azureai_agent + name: AzureAIAgent + description: AzureAIAgent Description + instructions: AzureAIAgent Instructions + model: + id: gpt-4o-mini + tools: + - type: openapi + name: function1 + description: function1 description + specification: {"openapi":"3.1.0","info":{"title":"Get Weather Data","description":"Retrieves current weather data for a location based on wttr.in.","version":"v1.0.0"},"servers":[{"url":"https://wttr.in"}],"auth":[],"paths":{"/{location}":{"get":{"description":"Get weather information for a specific location","operationId":"GetCurrentWeather","parameters":[{"name":"location","in":"path","description":"City or location to retrieve the weather for","required":true,"schema":{"type":"string"}},{"name":"format","in":"query","description":"Always use j1 value for this parameter","required":true,"schema":{"type":"string","default":"j1"}}],"responses":{"200":{"description":"Successful response","content":{"text/plain":{"schema":{"type":"string"}}}},"404":{"description":"Location not found"}},"deprecated":false}}},"components":{"schemes":{}}} + - type: openapi + name: function2 + description: function2 description + specification: {"openapi":"3.1.0","info":{"title":"Get Weather Data","description":"Retrieves current weather data for a location based on wttr.in.","version":"v1.0.0"},"servers":[{"url":"https://wttr.in"}],"auth":[],"paths":{"/{location}":{"get":{"description":"Get weather information for a specific location","operationId":"GetCurrentWeather","parameters":[{"name":"location","in":"path","description":"City or location to retrieve the weather for","required":true,"schema":{"type":"string"}},{"name":"format","in":"query","description":"Always use j1 value for this parameter","required":true,"schema":{"type":"string","default":"j1"}}],"responses":{"200":{"description":"Successful response","content":{"text/plain":{"schema":{"type":"string"}}}},"404":{"description":"Location not found"}},"deprecated":false}}},"components":{"schemes":{}}} + authentication: + connection_id: connection_id + - type: openapi + name: function3 + description: function3 description + specification: {"openapi":"3.1.0","info":{"title":"Get Weather Data","description":"Retrieves current weather data for a location based on wttr.in.","version":"v1.0.0"},"servers":[{"url":"https://wttr.in"}],"auth":[],"paths":{"/{location}":{"get":{"description":"Get weather information for a specific location","operationId":"GetCurrentWeather","parameters":[{"name":"location","in":"path","description":"City or location to retrieve the weather for","required":true,"schema":{"type":"string"}},{"name":"format","in":"query","description":"Always use j1 value for this parameter","required":true,"schema":{"type":"string","default":"j1"}}],"responses":{"200":{"description":"Successful response","content":{"text/plain":{"schema":{"type":"string"}}}},"404":{"description":"Location not found"}},"deprecated":false}}},"components":{"schemes":{}}} + authentication: + audience: audience + """; + AzureAIAgentFactory factory = new(); + this.SetupResponse(HttpStatusCode.OK, AzureAIAgentFactoryTests.AzureAIAgentResponse); + + // Act + var agent = await factory.CreateAgentFromYamlAsync(text, this._kernel); + + // Assert + Assert.NotNull(agent); + var requestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!); + Assert.NotNull(requestContent); + var requestJson = JsonSerializer.Deserialize(requestContent); + Assert.Equal(3, requestJson.GetProperty("tools").GetArrayLength()); + Assert.Equal("openapi", requestJson.GetProperty("tools")[0].GetProperty("type").GetString()); + Assert.Equal("openapi", requestJson.GetProperty("tools")[1].GetProperty("type").GetString()); + Assert.Equal("openapi", requestJson.GetProperty("tools")[2].GetProperty("type").GetString()); + } + + /// + /// Verify the request includes a Sharepoint tool when creating an Azure AI agent. + /// + [Fact] + public async Task VerifyRequestIncludesSharepointGroundingAsync() + { + // Arrange + var text = + """ + type: azureai_agent + name: AzureAIAgent + description: AzureAIAgent Description + instructions: AzureAIAgent Instructions + model: + id: gpt-4o-mini + tools: + - type: sharepoint_grounding + tool_connections: + - connection_string + """; + AzureAIAgentFactory factory = new(); + this.SetupResponse(HttpStatusCode.OK, AzureAIAgentFactoryTests.AzureAIAgentResponse); + + // Act + var agent = await factory.CreateAgentFromYamlAsync(text, this._kernel); + + // Assert + Assert.NotNull(agent); + var requestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!); + Assert.NotNull(requestContent); + var requestJson = JsonSerializer.Deserialize(requestContent); + Assert.Equal(1, requestJson.GetProperty("tools").GetArrayLength()); + Assert.Equal("sharepoint_grounding", requestJson.GetProperty("tools")[0].GetProperty("type").GetString()); + } + + #region private + private void SetupResponse(HttpStatusCode statusCode, string response) => +#pragma warning disable CA2000 // Dispose objects before losing scope + this._messageHandlerStub.ResponseQueue.Enqueue(new(statusCode) + { + Content = new StringContent(response) + }); + #endregion +} From 20fe322b6c97c07f413e1f0e9786d26c7488c68b Mon Sep 17 00:00:00 2001 From: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com> Date: Thu, 27 Feb 2025 19:49:54 +0000 Subject: [PATCH 2/3] Fix warning --- .../Agents/AzureAI/Extensions/AgentDefinitionExtensions.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Agents/AzureAI/Extensions/AgentDefinitionExtensions.cs b/dotnet/src/Agents/AzureAI/Extensions/AgentDefinitionExtensions.cs index b60f4449387d..330108a9ac01 100644 --- a/dotnet/src/Agents/AzureAI/Extensions/AgentDefinitionExtensions.cs +++ b/dotnet/src/Agents/AzureAI/Extensions/AgentDefinitionExtensions.cs @@ -41,7 +41,7 @@ internal static class AgentDefinitionExtensions /// public static IEnumerable GetAzureToolDefinitions(this AgentDefinition agentDefinition) { - return agentDefinition.Tools.Select(tool => + return agentDefinition.Tools?.Select(tool => { return tool.Type switch { @@ -56,7 +56,7 @@ public static IEnumerable GetAzureToolDefinitions(this AgentDefi SharepointGroundingType => CreateSharepointGroundingToolDefinition(tool), _ => throw new NotSupportedException($"Unable to create Azure AI tool definition because of unsupported tool type: {tool.Type}, supported tool types are: {string.Join(",", s_validToolTypes)}"), }; - }); + }) ?? []; } /// From 026d51fc4c27bcd92c739e5de8e3cbf0e840992f Mon Sep 17 00:00:00 2001 From: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com> Date: Fri, 28 Feb 2025 11:04:38 +0000 Subject: [PATCH 3/3] Address code review feedback --- .../UnitTests/Yaml/AzureAIKernelAgentYamlTests.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dotnet/src/Agents/UnitTests/Yaml/AzureAIKernelAgentYamlTests.cs b/dotnet/src/Agents/UnitTests/Yaml/AzureAIKernelAgentYamlTests.cs index f83c0758bd18..01e23ae743ff 100644 --- a/dotnet/src/Agents/UnitTests/Yaml/AzureAIKernelAgentYamlTests.cs +++ b/dotnet/src/Agents/UnitTests/Yaml/AzureAIKernelAgentYamlTests.cs @@ -337,10 +337,10 @@ public async Task VerifyRequestIncludesSharepointGroundingAsync() #region private private void SetupResponse(HttpStatusCode statusCode, string response) => -#pragma warning disable CA2000 // Dispose objects before losing scope - this._messageHandlerStub.ResponseQueue.Enqueue(new(statusCode) - { - Content = new StringContent(response) - }); + this._messageHandlerStub.ResponseToReturn = + new HttpResponseMessage(statusCode) + { + Content = new StringContent(response) + }; #endregion }