Skip to content

refactor trpc-a2a-go to latest a2a specification #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 20, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 189 additions & 83 deletions client/client.go

Large diffs are not rendered by default.

91 changes: 47 additions & 44 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import (
func TestA2AClient_SendTask(t *testing.T) {
taskID := "client-task-1"
params := protocol.SendTaskParams{
ID: taskID,
RPCID: taskID,
ID: taskID,
Message: protocol.Message{
Role: protocol.MessageRoleUser,
Parts: []protocol.Part{protocol.NewTextPart("Client test input")},
Expand Down Expand Up @@ -104,11 +105,11 @@ func TestA2AClient_SendTask(t *testing.T) {
}

// TestA2AClient_StreamTask tests the StreamTask client method for SSE.
// It covers success, HTTP errors, and non-SSE response scenarios.
func TestA2AClient_StreamTask(t *testing.T) {
taskID := "client-task-sse-1"
params := protocol.SendTaskParams{
ID: taskID,
RPCID: taskID,
SessionID: &taskID,
Message: protocol.Message{
Role: protocol.MessageRoleUser,
Parts: []protocol.Part{protocol.NewTextPart("Client SSE test")},
Expand All @@ -119,26 +120,25 @@ func TestA2AClient_StreamTask(t *testing.T) {

expectedRequest := &jsonrpc.Request{
Message: jsonrpc.Message{JSONRPC: "2.0", ID: taskID},
Method: "tasks/sendSubscribe",
Method: "message/stream",
Params: json.RawMessage(paramsBytes),
}

t.Run("StreamTask Success", func(t *testing.T) {
// Prepare mock SSE stream data.
sseEvent1Data, _ := json.Marshal(protocol.TaskStatusUpdateEvent{
ID: taskID,
TaskID: taskID,
Status: protocol.TaskStatus{State: protocol.TaskStateWorking},
Final: false,
Final: &[]bool{false}[0],
})
sseEvent2Data, _ := json.Marshal(protocol.TaskArtifactUpdateEvent{
ID: taskID,
Artifact: protocol.Artifact{Index: 0, Parts: []protocol.Part{protocol.NewTextPart("SSE data")}},
Final: false,
TaskID: taskID,
Artifact: protocol.Artifact{ArtifactID: "0", Parts: []protocol.Part{protocol.NewTextPart("SSE data")}},
})
sseEvent3Data, _ := json.Marshal(protocol.TaskStatusUpdateEvent{
ID: taskID,
TaskID: taskID,
Status: protocol.TaskStatus{State: protocol.TaskStateCompleted},
Final: true,
Final: &[]bool{true}[0],
})

// Format the mock SSE stream string.
Expand All @@ -156,7 +156,7 @@ func TestA2AClient_StreamTask(t *testing.T) {

mockHandler := createMockServerHandler(
t,
"tasks/sendSubscribe",
"message/stream",
expectedRequest,
sseStream,
http.StatusOK,
Expand Down Expand Up @@ -194,23 +194,23 @@ func TestA2AClient_StreamTask(t *testing.T) {

// Assert the content and order of received events.
require.Len(t, receivedEvents, 3, "Should receive exactly 3 events")
_, ok1 := receivedEvents[0].(protocol.TaskStatusUpdateEvent)
_, ok2 := receivedEvents[1].(protocol.TaskArtifactUpdateEvent)
_, ok3 := receivedEvents[2].(protocol.TaskStatusUpdateEvent)
_, ok1 := receivedEvents[0].(*protocol.TaskStatusUpdateEvent)
_, ok2 := receivedEvents[1].(*protocol.TaskArtifactUpdateEvent)
_, ok3 := receivedEvents[2].(*protocol.TaskStatusUpdateEvent)
assert.True(t, ok1 && ok2 && ok3, "Received event types mismatch expected sequence")
assert.Equal(t, protocol.TaskStateWorking,
receivedEvents[0].(protocol.TaskStatusUpdateEvent).Status.State, "First event state mismatch")
receivedEvents[0].(*protocol.TaskStatusUpdateEvent).Status.State, "First event state mismatch")
assert.False(t, receivedEvents[0].IsFinal(), "First event should not be final")
assert.False(t, receivedEvents[1].IsFinal(), "Second event should not be final")
assert.Equal(t, protocol.TaskStateCompleted,
receivedEvents[2].(protocol.TaskStatusUpdateEvent).Status.State, "Third event state mismatch")
receivedEvents[2].(*protocol.TaskStatusUpdateEvent).Status.State, "Third event state mismatch")
assert.True(t, receivedEvents[2].IsFinal(), "Last event should be final")
})

t.Run("StreamTask HTTP Error", func(t *testing.T) {
// Prepare mock server HTTP error response.
mockHandler := createMockServerHandler(
t, "tasks/sendSubscribe", expectedRequest, "Not Found", http.StatusNotFound, nil,
t, "message/stream", expectedRequest, "Not Found", http.StatusNotFound, nil,
)
server := httptest.NewServer(mockHandler)
defer server.Close()
Expand All @@ -228,12 +228,9 @@ func TestA2AClient_StreamTask(t *testing.T) {
})

t.Run("StreamTask Non-SSE Response", func(t *testing.T) {
// Prepare mock server returning JSON instead of SSE.
// Prepare mock server response without proper SSE headers.
mockHandler := createMockServerHandler(
t, "tasks/sendSubscribe", expectedRequest,
fmt.Sprintf(`{"jsonrpc":"2.0","id":"%s","result":"not an sse stream"}`, taskID),
http.StatusOK,
map[string]string{"Content-Type": "application/json"}, // Incorrect Content-Type for SSE.
t, "message/stream", expectedRequest, "Not an SSE response", http.StatusOK, nil,
)
server := httptest.NewServer(mockHandler)
defer server.Close()
Expand All @@ -245,25 +242,26 @@ func TestA2AClient_StreamTask(t *testing.T) {
eventChan, err := client.StreamTask(context.Background(), params)

// Assertions.
require.Error(t, err, "StreamTask should return an error on non-SSE response")
require.Error(t, err, "StreamTask should return an error for non-SSE response")
assert.Nil(t, eventChan, "Event channel should be nil on error")
assert.Contains(t, err.Error(), "server did not respond with Content-Type 'text/event-stream'")
assert.Contains(t, err.Error(), "did not respond with Content-Type 'text/event-stream'")
})
}

// TestA2AClient_GetTasks tests the GetTasks client method covering success and error scenarios.
func TestA2AClient_GetTasks(t *testing.T) {
taskID := "client-get-task-1"
params := protocol.TaskQueryParams{
ID: taskID,
RPCID: taskID,
ID: taskID,
}
paramsBytes, err := json.Marshal(params)
require.NoError(t, err)

expectedRequest := &jsonrpc.Request{
Message: jsonrpc.Message{JSONRPC: "2.0", ID: taskID},
Method: "tasks/get",
Params: json.RawMessage(paramsBytes),
Params: paramsBytes,
}

t.Run("GetTasks Success", func(t *testing.T) {
Expand Down Expand Up @@ -349,15 +347,16 @@ func TestA2AClient_GetTasks(t *testing.T) {
func TestA2AClient_CancelTasks(t *testing.T) {
taskID := "client-cancel-task-1"
params := protocol.TaskIDParams{
ID: taskID,
RPCID: taskID,
ID: taskID,
}
paramsBytes, err := json.Marshal(params)
require.NoError(t, err)

expectedRequest := &jsonrpc.Request{
Message: jsonrpc.Message{JSONRPC: "2.0", ID: taskID},
Method: "tasks/cancel",
Params: json.RawMessage(paramsBytes),
Params: paramsBytes,
}

t.Run("CancelTasks Success", func(t *testing.T) {
Expand Down Expand Up @@ -435,24 +434,28 @@ func TestA2AClient_CancelTasks(t *testing.T) {
func TestA2AClient_SetPushNotification(t *testing.T) {
taskID := "client-push-task-1"
params := protocol.TaskPushNotificationConfig{
ID: taskID,
RPCID: taskID,
TaskID: taskID,
PushNotificationConfig: protocol.PushNotificationConfig{
URL: "https://example.com/webhook",
Authentication: &protocol.AuthenticationInfo{
Schemes: []string{"bearer"},
},
},
}
paramsBytes, err := json.Marshal(params)
require.NoError(t, err)

expectedRequest := &jsonrpc.Request{
Message: jsonrpc.Message{JSONRPC: "2.0", ID: taskID},
Method: "tasks/pushNotification/set",
Params: json.RawMessage(paramsBytes),
Method: "tasks/pushNotificationConfig/set",
Params: paramsBytes,
}

t.Run("SetPushNotification Success", func(t *testing.T) {
// Prepare mock server response
respConfig := protocol.TaskPushNotificationConfig{
ID: taskID,
TaskID: taskID,
PushNotificationConfig: protocol.PushNotificationConfig{
URL: "https://example.com/webhook",
},
Expand All @@ -463,7 +466,7 @@ func TestA2AClient_SetPushNotification(t *testing.T) {

mockHandler := createMockServerHandler(
t,
"tasks/pushNotification/set",
"tasks/pushNotificationConfig/set",
expectedRequest,
respBody,
http.StatusOK,
Expand All @@ -482,7 +485,7 @@ func TestA2AClient_SetPushNotification(t *testing.T) {
require.NoError(t, err, "SetPushNotification should not return an error on success")
require.NotNil(t, result, "Result should not be nil on success")

assert.Equal(t, taskID, result.ID)
assert.Equal(t, taskID, result.TaskID)
assert.Equal(t, "https://example.com/webhook", result.PushNotificationConfig.URL)
})

Expand All @@ -500,7 +503,7 @@ func TestA2AClient_SetPushNotification(t *testing.T) {

mockHandler := createMockServerHandler(
t,
"tasks/pushNotification/set",
"tasks/pushNotificationConfig/set",
expectedRequest,
errorResp,
http.StatusOK,
Expand All @@ -526,21 +529,21 @@ func TestA2AClient_SetPushNotification(t *testing.T) {
func TestA2AClient_GetPushNotification(t *testing.T) {
taskID := "client-push-get-1"
params := protocol.TaskIDParams{
ID: taskID,
RPCID: taskID,
ID: taskID,
}
paramsBytes, err := json.Marshal(params)
require.NoError(t, err)

expectedRequest := &jsonrpc.Request{
Message: jsonrpc.Message{JSONRPC: "2.0", ID: taskID},
Method: "tasks/pushNotification/get",
Params: json.RawMessage(paramsBytes),
Method: "tasks/pushNotificationConfig/get",
Params: paramsBytes,
}

t.Run("GetPushNotification Success", func(t *testing.T) {
// Prepare mock server response
respConfig := protocol.TaskPushNotificationConfig{
ID: taskID,
TaskID: taskID,
PushNotificationConfig: protocol.PushNotificationConfig{
URL: "https://example.com/webhook",
Authentication: &protocol.AuthenticationInfo{
Expand All @@ -554,7 +557,7 @@ func TestA2AClient_GetPushNotification(t *testing.T) {

mockHandler := createMockServerHandler(
t,
"tasks/pushNotification/get",
"tasks/pushNotificationConfig/get",
expectedRequest,
respBody,
http.StatusOK,
Expand All @@ -573,7 +576,7 @@ func TestA2AClient_GetPushNotification(t *testing.T) {
require.NoError(t, err, "GetPushNotification should not return an error on success")
require.NotNil(t, result, "Result should not be nil on success")

assert.Equal(t, taskID, result.ID)
assert.Equal(t, taskID, result.TaskID)
assert.Equal(t, "https://example.com/webhook", result.PushNotificationConfig.URL)
require.NotNil(t, result.PushNotificationConfig.Authentication)
assert.Contains(t, result.PushNotificationConfig.Authentication.Schemes, "bearer")
Expand All @@ -593,7 +596,7 @@ func TestA2AClient_GetPushNotification(t *testing.T) {

mockHandler := createMockServerHandler(
t,
"tasks/pushNotification/get",
"tasks/pushNotificationConfig/get",
expectedRequest,
errorResp,
http.StatusOK,
Expand Down
61 changes: 40 additions & 21 deletions examples/auth/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,43 +83,62 @@ func main() {
textPart := protocol.NewTextPart(config.TaskMessage)
message := protocol.NewMessage(protocol.MessageRoleUser, []protocol.Part{textPart})

// Prepare task parameters
taskParams := protocol.SendTaskParams{
ID: config.TaskID,
// Prepare message parameters
params := protocol.SendMessageParams{
Message: message,
}

// Add session ID if provided
// Add context ID if session ID is provided
if config.SessionID != "" {
taskParams.SessionID = &config.SessionID
// In the new protocol, we use contextID instead of sessionID
params.Message.ContextID = &config.SessionID
}

// Send the task
// Send the message
ctx, cancel := context.WithTimeout(context.Background(), config.Timeout)
defer cancel()

task, err := a2aClient.SendTasks(ctx, taskParams)
result, err := a2aClient.SendMessage(ctx, params)

if err != nil {
log.Fatalf("Failed to send task: %v", err)
log.Fatalf("Failed to send message: %v", err)
}

fmt.Printf("Task ID: %s, Status: %s\n", task.ID, task.Status.State)
if task.SessionID != nil {
fmt.Printf("Session ID: %s\n", *task.SessionID)
}
// Handle the response based on its type
switch response := result.Result.(type) {
case *protocol.Message:
fmt.Printf("Message Response: %s\n", response.MessageID)
if response.ContextID != nil {
fmt.Printf("Context ID: %s\n", *response.ContextID)
}
// Print message parts
for _, part := range response.Parts {
if textPart, ok := part.(*protocol.TextPart); ok {
fmt.Printf("Response: %s\n", textPart.Text)
}
}

// For demonstration purposes, get the task status
taskQuery := protocol.TaskQueryParams{
ID: task.ID,
}
case *protocol.Task:
fmt.Printf("Task ID: %s, Status: %s\n", response.ID, response.Status.State)
if response.ContextID != "" {
fmt.Printf("Context ID: %s\n", response.ContextID)
}

updatedTask, err := a2aClient.GetTasks(ctx, taskQuery)
if err != nil {
log.Fatalf("Failed to get task: %v", err)
}
// For demonstration purposes, get the task status
taskQuery := protocol.TaskQueryParams{
ID: response.ID,
}

updatedTask, err := a2aClient.GetTasks(ctx, taskQuery)
if err != nil {
log.Fatalf("Failed to get task: %v", err)
}

fmt.Printf("Updated task status: %s\n", updatedTask.Status.State)
fmt.Printf("Updated task status: %s\n", updatedTask.Status.State)

default:
fmt.Printf("Unknown response type: %T\n", response)
}
}

// parseFlags parses command-line flags and returns a Config.
Expand Down
Loading
Loading