From 28b304830391adff681609ff6cf6437c8f8ac837 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Sun, 23 Mar 2025 11:48:41 +0000 Subject: [PATCH] Make AgentThread.Create protected and expose on concrete implementations where possible --- dotnet/src/Agents/Abstractions/Agent.cs | 2 -- dotnet/src/Agents/Abstractions/AgentThread.cs | 2 +- .../src/Agents/AzureAI/AzureAIAgentThread.cs | 10 ++++++++++ .../src/Agents/Core/ChatHistoryAgentThread.cs | 10 ++++++++++ .../OpenAI/OpenAIAssistantAgentThread.cs | 10 ++++++++++ .../Agents/UnitTests/Core/AgentThreadTests.cs | 5 +++++ .../CommonInterfaceConformance/AgentFixture.cs | 2 ++ .../AgentThreadConformance/AgentThreadTests.cs | 18 ++++-------------- .../ChatCompletionAgentThreadTests.cs | 7 ------- .../AzureAIAgentFixture.cs | 14 ++++++++++++++ .../ChatCompletionAgentFixture.cs | 9 ++++++--- .../OpenAIAssistantAgentFixture.cs | 14 ++++++++++++++ 12 files changed, 76 insertions(+), 27 deletions(-) diff --git a/dotnet/src/Agents/Abstractions/Agent.cs b/dotnet/src/Agents/Abstractions/Agent.cs index 1dc1461f122f..0931827f94ed 100644 --- a/dotnet/src/Agents/Abstractions/Agent.cs +++ b/dotnet/src/Agents/Abstractions/Agent.cs @@ -198,8 +198,6 @@ protected async Task EnsureThreadExistsWithMessageAsync /// The to monitor for cancellation requests. The default is . /// A task that completes when the thread has been created. - public virtual async Task CreateAsync(CancellationToken cancellationToken = default) + protected virtual async Task CreateAsync(CancellationToken cancellationToken = default) { if (this.IsDeleted) { diff --git a/dotnet/src/Agents/AzureAI/AzureAIAgentThread.cs b/dotnet/src/Agents/AzureAI/AzureAIAgentThread.cs index 39bbcce4fbba..ffdd07b421ac 100644 --- a/dotnet/src/Agents/AzureAI/AzureAIAgentThread.cs +++ b/dotnet/src/Agents/AzureAI/AzureAIAgentThread.cs @@ -61,6 +61,16 @@ public AzureAIAgentThread(AgentsClient client, string id) this.Id = id; } + /// + /// Creates the thread and returns the thread id. + /// + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the thread has been created. + public new Task CreateAsync(CancellationToken cancellationToken = default) + { + return base.CreateAsync(cancellationToken); + } + /// protected async override Task CreateInternalAsync(CancellationToken cancellationToken) { diff --git a/dotnet/src/Agents/Core/ChatHistoryAgentThread.cs b/dotnet/src/Agents/Core/ChatHistoryAgentThread.cs index 313bdff31fcb..836464a2920d 100644 --- a/dotnet/src/Agents/Core/ChatHistoryAgentThread.cs +++ b/dotnet/src/Agents/Core/ChatHistoryAgentThread.cs @@ -36,6 +36,16 @@ public ChatHistoryAgentThread(ChatHistory chatHistory, string? id = null) this.Id = id ?? Guid.NewGuid().ToString("N"); } + /// + /// Creates the thread and returns the thread id. + /// + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the thread has been created. + public new Task CreateAsync(CancellationToken cancellationToken = default) + { + return base.CreateAsync(cancellationToken); + } + /// protected override Task CreateInternalAsync(CancellationToken cancellationToken) { diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgentThread.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgentThread.cs index 97952a8f1fed..15a9c1159c07 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgentThread.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgentThread.cs @@ -47,6 +47,16 @@ public OpenAIAssistantAgentThread(AssistantClient client, string id) this.Id = id; } + /// + /// Creates the thread and returns the thread id. + /// + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the thread has been created. + public new Task CreateAsync(CancellationToken cancellationToken = default) + { + return base.CreateAsync(cancellationToken); + } + /// protected async override Task CreateInternalAsync(CancellationToken cancellationToken) { diff --git a/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs b/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs index 0f22756037bb..c8e0c1884a87 100644 --- a/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs +++ b/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs @@ -147,6 +147,11 @@ private sealed class TestAgentThread : AgentThread public int DeleteInternalAsyncCount { get; private set; } public int OnNewMessageInternalAsyncCount { get; private set; } + public new Task CreateAsync(CancellationToken cancellationToken = default) + { + return base.CreateAsync(cancellationToken); + } + protected override Task CreateInternalAsync(CancellationToken cancellationToken) { this.CreateInternalAsyncCount++; diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs index 8a2a498897e5..8be11475493c 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs @@ -17,6 +17,8 @@ public abstract class AgentFixture : IAsyncLifetime public abstract AgentThread AgentThread { get; } + public abstract AgentThread CreatedAgentThread { get; } + public abstract AgentThread ServiceFailingAgentThread { get; } public abstract AgentThread CreatedServiceFailingAgentThread { get; } diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/AgentThreadTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/AgentThreadTests.cs index 075f20c56eb0..4b30e142fee0 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/AgentThreadTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/AgentThreadTests.cs @@ -20,20 +20,16 @@ public abstract class AgentThreadTests(Func createAgentFixture) : [Fact] public virtual async Task DeletingThreadTwiceDoesNotThrowAsync() { - await this.Fixture.AgentThread.CreateAsync(); - - await this.Fixture.AgentThread.DeleteAsync(); - await this.Fixture.AgentThread.DeleteAsync(); + await this.Fixture.CreatedAgentThread.DeleteAsync(); + await this.Fixture.CreatedAgentThread.DeleteAsync(); } [Fact] public virtual async Task UsingThreadAfterDeleteThrowsAsync() { - await this.Fixture.AgentThread.CreateAsync(); - await this.Fixture.AgentThread.DeleteAsync(); + await this.Fixture.CreatedAgentThread.DeleteAsync(); - await Assert.ThrowsAsync(async () => await this.Fixture.AgentThread.CreateAsync()); - await Assert.ThrowsAsync(async () => await this.Fixture.AgentThread.OnNewMessageAsync(new ChatMessageContent(AuthorRole.User, "Hi"))); + await Assert.ThrowsAsync(async () => await this.Fixture.CreatedAgentThread.OnNewMessageAsync(new ChatMessageContent(AuthorRole.User, "Hi"))); } [Fact] @@ -49,12 +45,6 @@ public virtual async Task UsingThreadbeforeCreateCreatesAsync() Assert.NotNull(this.Fixture.AgentThread.Id); } - [Fact] - public virtual async Task CreateThreadWithServiceFailureThrowsAgentOperationExceptionAsync() - { - await Assert.ThrowsAsync(async () => await this.Fixture.ServiceFailingAgentThread.CreateAsync()); - } - [Fact] public virtual async Task DeleteThreadWithServiceFailureThrowsAgentOperationExceptionAsync() { diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/ChatCompletionAgentThreadTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/ChatCompletionAgentThreadTests.cs index fde4c3fd751b..9981a38ab41d 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/ChatCompletionAgentThreadTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/ChatCompletionAgentThreadTests.cs @@ -7,13 +7,6 @@ namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.Agen public class ChatCompletionAgentThreadTests() : AgentThreadTests(() => new ChatCompletionAgentFixture()) { - [Fact] - public override Task CreateThreadWithServiceFailureThrowsAgentOperationExceptionAsync() - { - // Test not applicable since there is no service to fail. - return Task.CompletedTask; - } - [Fact] public override Task DeleteThreadWithServiceFailureThrowsAgentOperationExceptionAsync() { diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs index 018ffccb1918..819e604b2e64 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs @@ -26,6 +26,7 @@ public class AzureAIAgentFixture : AgentFixture private AAIP.Agent? _aiAgent; private AzureAIAgent? _agent; private AzureAIAgentThread? _thread; + private AzureAIAgentThread? _createdThread; private AzureAIAgentThread? _serviceFailingAgentThread; private AzureAIAgentThread? _createdServiceFailingAgentThread; @@ -33,6 +34,8 @@ public class AzureAIAgentFixture : AgentFixture public override AgentThread AgentThread => this._thread!; + public override AgentThread CreatedAgentThread => this._createdThread!; + public override AgentThread ServiceFailingAgentThread => this._serviceFailingAgentThread!; public override AgentThread CreatedServiceFailingAgentThread => this._createdServiceFailingAgentThread!; @@ -65,6 +68,14 @@ public override async Task DisposeAsync() } } + try + { + await this._agentsClient!.DeleteThreadAsync(this._createdThread!.Id); + } + catch (RequestFailedException ex) when (ex.Status == 404) + { + } + try { await this._agentsClient!.DeleteThreadAsync(this._createdServiceFailingAgentThread!.Id); @@ -95,6 +106,9 @@ await this._agentsClient.CreateAgentAsync( this._agent = new AzureAIAgent(this._aiAgent, this._agentsClient) { Kernel = kernel }; this._thread = new AzureAIAgentThread(this._agentsClient); + this._createdThread = new AzureAIAgentThread(this._agentsClient); + await this._createdThread.CreateAsync(); + var serviceFailingClient = AzureAIAgent.CreateAzureAIClient("swedencentral.api.azureml.ms;;;", new AzureCliCredential()); this._serviceFailingAgentThread = new AzureAIAgentThread(serviceFailingClient.GetAgentsClient()); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs index 8e5e14ecbeee..c7fa8dbcede3 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs @@ -24,11 +24,14 @@ public class ChatCompletionAgentFixture : AgentFixture private ChatCompletionAgent? _agent; private ChatHistoryAgentThread? _thread; + private ChatHistoryAgentThread? _createdThread; public override Agent Agent => this._agent!; public override AgentThread AgentThread => this._thread!; + public override AgentThread CreatedAgentThread => this._createdThread!; + public override AgentThread ServiceFailingAgentThread => null!; public override AgentThread CreatedServiceFailingAgentThread => null!; @@ -53,7 +56,7 @@ public override Task DeleteThread(AgentThread thread) return Task.CompletedTask; } - public override Task InitializeAsync() + public async override Task InitializeAsync() { AzureOpenAIConfiguration configuration = this._configuration.GetSection("AzureOpenAI").Get()!; @@ -70,7 +73,7 @@ public override Task InitializeAsync() Instructions = "You are a helpful assistant.", }; this._thread = new ChatHistoryAgentThread(); - - return Task.CompletedTask; + this._createdThread = new ChatHistoryAgentThread(); + await this._createdThread.CreateAsync(); } } diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs index 32bcb43adb1b..25b5b28b60e0 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs @@ -30,6 +30,7 @@ public class OpenAIAssistantAgentFixture : AgentFixture private Assistant? _assistant; private OpenAIAssistantAgent? _agent; private OpenAIAssistantAgentThread? _thread; + private OpenAIAssistantAgentThread? _createdThread; private OpenAIAssistantAgentThread? _serviceFailingAgentThread; private OpenAIAssistantAgentThread? _createdServiceFailingAgentThread; @@ -37,6 +38,8 @@ public class OpenAIAssistantAgentFixture : AgentFixture public override AgentThread AgentThread => this._thread!; + public override AgentThread CreatedAgentThread => this._createdThread!; + public override AgentThread ServiceFailingAgentThread => this._serviceFailingAgentThread!; public override AgentThread CreatedServiceFailingAgentThread => this._createdServiceFailingAgentThread!; @@ -64,6 +67,14 @@ public override async Task DisposeAsync() } } + try + { + await this._assistantClient!.DeleteThreadAsync(this._createdThread!.Id); + } + catch (ClientResultException ex) when (ex.Status == 404) + { + } + await this._assistantClient!.DeleteAssistantAsync(this._assistant!.Id); } @@ -90,6 +101,9 @@ await this._assistantClient.CreateAssistantAsync( this._agent = new OpenAIAssistantAgent(this._assistant, this._assistantClient) { Kernel = kernel }; this._thread = new OpenAIAssistantAgentThread(this._assistantClient); + this._createdThread = new OpenAIAssistantAgentThread(this._assistantClient); + await this._createdThread.CreateAsync(); + var serviceFailingClient = OpenAIAssistantAgent.CreateAzureOpenAIClient(new AzureCliCredential(), new Uri("https://localhost/failingserviceclient")); this._serviceFailingAgentThread = new OpenAIAssistantAgentThread(serviceFailingClient.GetAssistantClient());