Skip to content

.Net: Make AgentThread.Create protected and expose on concrete implementations where possible #11133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions dotnet/src/Agents/Abstractions/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,6 @@ protected async Task<TThreadType> EnsureThreadExistsWithMessageAsync<TThreadType
throw new KernelException($"{this.GetType().Name} currently only supports agent threads of type {nameof(TThreadType)}.");
}

await thread.CreateAsync(cancellationToken).ConfigureAwait(false);

// Notify the thread that new messages are available.
foreach (var message in messages)
{
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Agents/Abstractions/AgentThread.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public abstract class AgentThread
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that completes when the thread has been created.</returns>
public virtual async Task CreateAsync(CancellationToken cancellationToken = default)
protected virtual async Task CreateAsync(CancellationToken cancellationToken = default)
{
if (this.IsDeleted)
{
Expand Down
10 changes: 10 additions & 0 deletions dotnet/src/Agents/AzureAI/AzureAIAgentThread.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ public AzureAIAgentThread(AgentsClient client, string id)
this.Id = id;
}

/// <summary>
/// Creates the thread and returns the thread id.
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that completes when the thread has been created.</returns>
public new Task CreateAsync(CancellationToken cancellationToken = default)
{
return base.CreateAsync(cancellationToken);
}

/// <inheritdoc />
protected async override Task<string?> CreateInternalAsync(CancellationToken cancellationToken)
{
Expand Down
10 changes: 10 additions & 0 deletions dotnet/src/Agents/Core/ChatHistoryAgentThread.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ public ChatHistoryAgentThread(ChatHistory chatHistory, string? id = null)
this.Id = id ?? Guid.NewGuid().ToString("N");
}

/// <summary>
/// Creates the thread and returns the thread id.
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that completes when the thread has been created.</returns>
public new Task CreateAsync(CancellationToken cancellationToken = default)
{
return base.CreateAsync(cancellationToken);
}

/// <inheritdoc />
protected override Task<string?> CreateInternalAsync(CancellationToken cancellationToken)
{
Expand Down
10 changes: 10 additions & 0 deletions dotnet/src/Agents/OpenAI/OpenAIAssistantAgentThread.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ public OpenAIAssistantAgentThread(AssistantClient client, string id)
this.Id = id;
}

/// <summary>
/// Creates the thread and returns the thread id.
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that completes when the thread has been created.</returns>
public new Task CreateAsync(CancellationToken cancellationToken = default)
{
return base.CreateAsync(cancellationToken);
}

/// <inheritdoc />
protected async override Task<string?> CreateInternalAsync(CancellationToken cancellationToken)
{
Expand Down
5 changes: 5 additions & 0 deletions dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ private sealed class TestAgentThread : AgentThread
public int DeleteInternalAsyncCount { get; private set; }
public int OnNewMessageInternalAsyncCount { get; private set; }

public new Task CreateAsync(CancellationToken cancellationToken = default)
{
return base.CreateAsync(cancellationToken);
}

protected override Task<string?> CreateInternalAsync(CancellationToken cancellationToken)
{
this.CreateInternalAsyncCount++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public abstract class AgentFixture : IAsyncLifetime

public abstract AgentThread AgentThread { get; }

public abstract AgentThread CreatedAgentThread { get; }

public abstract AgentThread ServiceFailingAgentThread { get; }

public abstract AgentThread CreatedServiceFailingAgentThread { get; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,16 @@ public abstract class AgentThreadTests(Func<AgentFixture> createAgentFixture) :
[Fact]
public virtual async Task DeletingThreadTwiceDoesNotThrowAsync()
{
await this.Fixture.AgentThread.CreateAsync();

await this.Fixture.AgentThread.DeleteAsync();
await this.Fixture.AgentThread.DeleteAsync();
await this.Fixture.CreatedAgentThread.DeleteAsync();
await this.Fixture.CreatedAgentThread.DeleteAsync();
}

[Fact]
public virtual async Task UsingThreadAfterDeleteThrowsAsync()
{
await this.Fixture.AgentThread.CreateAsync();
await this.Fixture.AgentThread.DeleteAsync();
await this.Fixture.CreatedAgentThread.DeleteAsync();

await Assert.ThrowsAsync<InvalidOperationException>(async () => await this.Fixture.AgentThread.CreateAsync());
await Assert.ThrowsAsync<InvalidOperationException>(async () => await this.Fixture.AgentThread.OnNewMessageAsync(new ChatMessageContent(AuthorRole.User, "Hi")));
await Assert.ThrowsAsync<InvalidOperationException>(async () => await this.Fixture.CreatedAgentThread.OnNewMessageAsync(new ChatMessageContent(AuthorRole.User, "Hi")));
}

[Fact]
Expand All @@ -49,12 +45,6 @@ public virtual async Task UsingThreadbeforeCreateCreatesAsync()
Assert.NotNull(this.Fixture.AgentThread.Id);
}

[Fact]
public virtual async Task CreateThreadWithServiceFailureThrowsAgentOperationExceptionAsync()
{
await Assert.ThrowsAsync<AgentThreadOperationException>(async () => await this.Fixture.ServiceFailingAgentThread.CreateAsync());
}

[Fact]
public virtual async Task DeleteThreadWithServiceFailureThrowsAgentOperationExceptionAsync()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,6 @@ namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.Agen

public class ChatCompletionAgentThreadTests() : AgentThreadTests(() => new ChatCompletionAgentFixture())
{
[Fact]
public override Task CreateThreadWithServiceFailureThrowsAgentOperationExceptionAsync()
{
// Test not applicable since there is no service to fail.
return Task.CompletedTask;
}

[Fact]
public override Task DeleteThreadWithServiceFailureThrowsAgentOperationExceptionAsync()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@ public class AzureAIAgentFixture : AgentFixture
private AAIP.Agent? _aiAgent;
private AzureAIAgent? _agent;
private AzureAIAgentThread? _thread;
private AzureAIAgentThread? _createdThread;
private AzureAIAgentThread? _serviceFailingAgentThread;
private AzureAIAgentThread? _createdServiceFailingAgentThread;

public override Agent Agent => this._agent!;

public override AgentThread AgentThread => this._thread!;

public override AgentThread CreatedAgentThread => this._createdThread!;

public override AgentThread ServiceFailingAgentThread => this._serviceFailingAgentThread!;

public override AgentThread CreatedServiceFailingAgentThread => this._createdServiceFailingAgentThread!;
Expand Down Expand Up @@ -65,6 +68,14 @@ public override async Task DisposeAsync()
}
}

try
{
await this._agentsClient!.DeleteThreadAsync(this._createdThread!.Id);
}
catch (RequestFailedException ex) when (ex.Status == 404)
{
}

try
{
await this._agentsClient!.DeleteThreadAsync(this._createdServiceFailingAgentThread!.Id);
Expand Down Expand Up @@ -95,6 +106,9 @@ await this._agentsClient.CreateAgentAsync(
this._agent = new AzureAIAgent(this._aiAgent, this._agentsClient) { Kernel = kernel };
this._thread = new AzureAIAgentThread(this._agentsClient);

this._createdThread = new AzureAIAgentThread(this._agentsClient);
await this._createdThread.CreateAsync();

var serviceFailingClient = AzureAIAgent.CreateAzureAIClient("swedencentral.api.azureml.ms;<subscription_id>;<resource_group_name>;<project_name>", new AzureCliCredential());
this._serviceFailingAgentThread = new AzureAIAgentThread(serviceFailingClient.GetAgentsClient());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ public class ChatCompletionAgentFixture : AgentFixture

private ChatCompletionAgent? _agent;
private ChatHistoryAgentThread? _thread;
private ChatHistoryAgentThread? _createdThread;

public override Agent Agent => this._agent!;

public override AgentThread AgentThread => this._thread!;

public override AgentThread CreatedAgentThread => this._createdThread!;

public override AgentThread ServiceFailingAgentThread => null!;

public override AgentThread CreatedServiceFailingAgentThread => null!;
Expand All @@ -53,7 +56,7 @@ public override Task DeleteThread(AgentThread thread)
return Task.CompletedTask;
}

public override Task InitializeAsync()
public async override Task InitializeAsync()
{
AzureOpenAIConfiguration configuration = this._configuration.GetSection("AzureOpenAI").Get<AzureOpenAIConfiguration>()!;

Expand All @@ -70,7 +73,7 @@ public override Task InitializeAsync()
Instructions = "You are a helpful assistant.",
};
this._thread = new ChatHistoryAgentThread();

return Task.CompletedTask;
this._createdThread = new ChatHistoryAgentThread();
await this._createdThread.CreateAsync();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ public class OpenAIAssistantAgentFixture : AgentFixture
private Assistant? _assistant;
private OpenAIAssistantAgent? _agent;
private OpenAIAssistantAgentThread? _thread;
private OpenAIAssistantAgentThread? _createdThread;
private OpenAIAssistantAgentThread? _serviceFailingAgentThread;
private OpenAIAssistantAgentThread? _createdServiceFailingAgentThread;

public override Agent Agent => this._agent!;

public override AgentThread AgentThread => this._thread!;

public override AgentThread CreatedAgentThread => this._createdThread!;

public override AgentThread ServiceFailingAgentThread => this._serviceFailingAgentThread!;

public override AgentThread CreatedServiceFailingAgentThread => this._createdServiceFailingAgentThread!;
Expand Down Expand Up @@ -64,6 +67,14 @@ public override async Task DisposeAsync()
}
}

try
{
await this._assistantClient!.DeleteThreadAsync(this._createdThread!.Id);
}
catch (ClientResultException ex) when (ex.Status == 404)
{
}

await this._assistantClient!.DeleteAssistantAsync(this._assistant!.Id);
}

Expand All @@ -90,6 +101,9 @@ await this._assistantClient.CreateAssistantAsync(
this._agent = new OpenAIAssistantAgent(this._assistant, this._assistantClient) { Kernel = kernel };
this._thread = new OpenAIAssistantAgentThread(this._assistantClient);

this._createdThread = new OpenAIAssistantAgentThread(this._assistantClient);
await this._createdThread.CreateAsync();

var serviceFailingClient = OpenAIAssistantAgent.CreateAzureOpenAIClient(new AzureCliCredential(), new Uri("https://localhost/failingserviceclient"));
this._serviceFailingAgentThread = new OpenAIAssistantAgentThread(serviceFailingClient.GetAssistantClient());

Expand Down
Loading