diff --git a/README.md b/README.md index 89796f4..24c4577 100644 --- a/README.md +++ b/README.md @@ -148,9 +148,9 @@ go run main.go --auth jwt --message "Custom message" --session-id "session123" ## Creating Your Own Agent -### 1. Implement the TaskProcessor Interface +### 1. Implement the MessageProcessor Interface -This interface defines how your agent processes incoming tasks: +This interface defines how your agent processes incoming messages: ```go import ( @@ -160,23 +160,49 @@ import ( "trpc.group/trpc-go/trpc-a2a-go/taskmanager" ) -// Implement the TaskProcessor interface -type myTaskProcessor struct { - // Optional: add your custom fields +// Implement the MessageProcessor interface +type myMessageProcessor struct { + // Add your custom fields here } -func (p *myTaskProcessor) Process( +func (p *myMessageProcessor) ProcessMessage( ctx context.Context, - taskID string, message protocol.Message, - handle taskmanager.TaskHandle, -) error { - // 1. Extract input data from message - // 2. Process data, generate results - // 3. Use handle to update task status and add artifacts - - // Processing complete, return nil for success - return nil + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { + // Extract text from the incoming message + text := extractTextFromMessage(message) + + // Process the text (example: reverse it) + result := reverseString(text) + + // Return a simple response message + responseMessage := protocol.NewMessage( + protocol.MessageRoleAgent, + []protocol.Part{protocol.NewTextPart("Processed: " + result)}, + ) + + return &taskmanager.MessageProcessingResult{ + Result: &responseMessage, + }, nil +} + +func extractTextFromMessage(message protocol.Message) string { + for _, part := range message.Parts { + if textPart, ok := part.(*protocol.TextPart); ok { + return textPart.Text + } + } + return "" +} + +func reverseString(s string) string { + runes := []rune(s) + for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 { + runes[i], runes[j] = runes[j], runes[i] + } + return string(runes) } ``` @@ -195,6 +221,11 @@ func stringPtr(s string) *string { return &s } +// Helper function to create bool pointers +func boolPtr(b bool) *bool { + return &b +} + agentCard := server.AgentCard{ Name: "My Agent", Description: stringPtr("Agent description"), @@ -204,20 +235,17 @@ agentCard := server.AgentCard{ Name: "Provider name", }, Capabilities: server.AgentCapabilities{ - Streaming: true, - StateTransitionHistory: true, + Streaming: boolPtr(true), }, - DefaultInputModes: []string{string(protocol.PartTypeText)}, - DefaultOutputModes: []string{string(protocol.PartTypeText)}, + DefaultInputModes: []string{protocol.KindText}, + DefaultOutputModes: []string{protocol.KindText}, Skills: []server.AgentSkill{ { - ID: "my_skill", - Name: "Skill name", - Description: stringPtr("Skill description"), - Tags: []string{"tag1", "tag2"}, - Examples: []string{"Example input"}, - InputModes: []string{string(protocol.PartTypeText)}, - OutputModes: []string{string(protocol.PartTypeText)}, + ID: "text_processing", + Name: "Text Processing", + Description: stringPtr("Process and transform text input"), + InputModes: []string{protocol.KindText}, + OutputModes: []string{protocol.KindText}, }, }, } @@ -236,7 +264,7 @@ import ( ) // Create the task processor -processor := &myTaskProcessor{} +processor := &myMessageProcessor{} // Create task manager, inject processor taskManager, err := taskmanager.NewMemoryTaskManager(processor) diff --git a/client/client.go b/client/client.go index 28c5ffe..051aa0e 100644 --- a/client/client.go +++ b/client/client.go @@ -43,8 +43,7 @@ type A2AClient struct { // NewA2AClient creates a new A2A client targeting the specified agentURL. // The agentURL should be the base endpoint for the agent (e.g., "http://localhost:8080/"). -// Options can be provided to configure the client, such as setting a custom -// http.Client or timeout. +// Options can be provided to configure the client, such as setting a custom http.Client or timeout. // Returns an error if the agentURL is invalid. func NewA2AClient(agentURL string, opts ...Option) (*A2AClient, error) { if !strings.HasSuffix(agentURL, "/") { @@ -70,12 +69,15 @@ func NewA2AClient(agentURL string, opts ...Option) (*A2AClient, error) { } // SendTasks sends a message using the tasks/send method. +// deprecated: use SendMessage instead // It returns the initial task state received from the agent. func (c *A2AClient) SendTasks( ctx context.Context, params protocol.SendTaskParams, ) (*protocol.Task, error) { - request := jsonrpc.NewRequest(protocol.MethodTasksSend, params.ID) + log.Info("SendTasks is deprecated in a2a specification, use SendMessage instead") + + request := jsonrpc.NewRequest(protocol.MethodTasksSend, params.RPCID) paramsBytes, err := json.Marshal(params) if err != nil { return nil, fmt.Errorf("a2aClient.SendTasks: failed to marshal params: %w", err) @@ -90,12 +92,27 @@ func (c *A2AClient) SendTasks( return task, nil } -// GetTasks retrieves the status of a task using the tasks_get method. -func (c *A2AClient) GetTasks( +// SendMessage sends a message using the message/send method. +func (c *A2AClient) SendMessage( ctx context.Context, - params protocol.TaskQueryParams, -) (*protocol.Task, error) { - request := jsonrpc.NewRequest(protocol.MethodTasksGet, params.ID) + params protocol.SendMessageParams, +) (*protocol.MessageResult, error) { + request := jsonrpc.NewRequest(protocol.MethodMessageSend, params.RPCID) + paramsBytes, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("a2aClient.SendMessage: failed to marshal params: %w", err) + } + request.Params = paramsBytes + message, err := c.doRequestAndDecodeMessage(ctx, request) + if err != nil { + return nil, fmt.Errorf("a2aClient.SendMessage: %w", err) + } + return message, nil +} + +// GetTasks retrieves the status of a task using the tasks_get method. +func (c *A2AClient) GetTasks(ctx context.Context, params protocol.TaskQueryParams) (*protocol.Task, error) { + request := jsonrpc.NewRequest(protocol.MethodTasksGet, params.RPCID) paramsBytes, err := json.Marshal(params) if err != nil { return nil, fmt.Errorf("a2aClient.GetTasks: failed to marshal params: %w", err) @@ -114,7 +131,7 @@ func (c *A2AClient) CancelTasks( ctx context.Context, params protocol.TaskIDParams, ) (*protocol.Task, error) { - request := jsonrpc.NewRequest(protocol.MethodTasksCancel, params.ID) + request := jsonrpc.NewRequest(protocol.MethodTasksCancel, params.RPCID) paramsBytes, err := json.Marshal(params) if err != nil { return nil, fmt.Errorf("a2aClient.CancelTasks: failed to marshal params: %w", err) @@ -128,33 +145,65 @@ func (c *A2AClient) CancelTasks( } // StreamTask sends a message using tasks_sendSubscribe and returns a channel for receiving SSE events. +// deprecated: use StreamMessage instead // It handles setting up the SSE connection and parsing events. // The returned channel will be closed when the stream ends (task completion, error, or context cancellation). func (c *A2AClient) StreamTask( ctx context.Context, params protocol.SendTaskParams, ) (<-chan protocol.TaskEvent, error) { + log.Info("StreamTask is deprecated in a2a specification, use StreamMessage instead") // Create the JSON-RPC request. - request := jsonrpc.NewRequest(protocol.MethodTasksSendSubscribe, params.ID) paramsBytes, err := json.Marshal(params) if err != nil { return nil, fmt.Errorf("a2aClient.StreamTask: failed to marshal params: %w", err) } + resp, err := c.sendA2AStreamRequest(ctx, params.RPCID, paramsBytes) + if err != nil { + return nil, fmt.Errorf("a2aClient.StreamTask: failed to build stream request: %w", err) + } + // Create the channel to send events back to the caller. + eventsChan := make(chan protocol.TaskEvent, 10) // Buffered channel. + // Start a goroutine to read from the SSE stream. + go processSSEStream(ctx, resp, params.ID, eventsChan) + return eventsChan, nil +} + +// StreamMessage sends a message using message/streamSubscribe and returns a channel for receiving SSE events. +// It handles setting up the SSE connection and parsing events. +// The returned channel will be closed when the stream ends (task completion, error, or context cancellation). +func (c *A2AClient) StreamMessage( + ctx context.Context, + params protocol.SendMessageParams, +) (<-chan protocol.StreamingMessageEvent, error) { + // Create the JSON-RPC request. + paramsBytes, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("a2aClient.StreamMessage: failed to marshal params: %w", err) + } + resp, err := c.sendA2AStreamRequest(ctx, params.RPCID, paramsBytes) + if err != nil { + return nil, fmt.Errorf("a2aClient.StreamMessage: failed to build stream request: %w", err) + } + eventsChan := make(chan protocol.StreamingMessageEvent, 10) // Buffered channel. + // Start a goroutine to read from the SSE stream. + go processSSEStream(ctx, resp, params.RPCID, eventsChan) + return eventsChan, nil +} + +func (c *A2AClient) sendA2AStreamRequest(ctx context.Context, id string, paramsBytes []byte) (*http.Response, error) { + // Create the JSON-RPC request. + request := jsonrpc.NewRequest(protocol.MethodMessageStream, id) request.Params = paramsBytes reqBody, err := json.Marshal(request) if err != nil { - return nil, fmt.Errorf("a2aClient.StreamTask: failed to marshal request body: %w", err) + return nil, fmt.Errorf("a2aClient.sendA2AStreamRequest: failed to marshal request body: %w", err) } // Construct the target URL. targetURL := c.baseURL.String() - req, err := http.NewRequestWithContext( - ctx, - http.MethodPost, - targetURL, - bytes.NewReader(reqBody), - ) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(reqBody)) if err != nil { - return nil, fmt.Errorf("a2aClient.StreamTask: failed to create http request: %w", err) + return nil, fmt.Errorf("a2aClient.sendA2AStreamRequest: failed to create http request: %w", err) } // Set headers, including Accept for event stream. req.Header.Set("Content-Type", "application/json; charset=utf-8") @@ -163,13 +212,13 @@ func (c *A2AClient) StreamTask( req.Header.Set("User-Agent", c.userAgent) } log.Debugf("A2A Client Stream Request -> Method: %s, ID: %v, URL: %s", request.Method, request.ID, targetURL) - // Make the initial request to establish the stream. + resp, err := c.httpReqHandler.Handle(ctx, c.httpClient, req) if err != nil { - return nil, fmt.Errorf("a2aClient.StreamTask: http request failed: %w", err) + return nil, fmt.Errorf("a2aClient.sendA2AStreamRequest: http request failed: %w", err) } if resp == nil || resp.Body == nil { - return nil, fmt.Errorf("a2aClient.StreamTask: unexpected nil response") + return nil, fmt.Errorf("a2aClient.sendA2AStreamRequest: unexpected nil response") } // Check for non-success HTTP status codes. // For SSE, a successful setup should result in 200 OK. @@ -178,7 +227,7 @@ func (c *A2AClient) StreamTask( bodyBytes, _ := io.ReadAll(resp.Body) resp.Body.Close() return nil, fmt.Errorf( - "a2aClient.StreamTask: unexpected http status %d establishing stream: %s", + "a2aClient.sendA2AStreamRequest: unexpected http status %d establishing stream: %s", resp.StatusCode, string(bodyBytes), ) } @@ -186,51 +235,48 @@ func (c *A2AClient) StreamTask( if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { resp.Body.Close() return nil, fmt.Errorf( - "a2aClient.StreamTask: server did not respond with Content-Type 'text/event-stream', got %s", + "a2aClient.sendA2AStreamRequest: server did not respond with Content-Type 'text/event-stream', got %s", resp.Header.Get("Content-Type"), ) } - log.Debugf("A2A Client Stream Response <- Status: %d, ID: %v. Stream established.", resp.StatusCode, request.ID) - // Create the channel to send events back to the caller. - eventsChan := make(chan protocol.TaskEvent, 10) // Buffered channel. - // Start a goroutine to read from the SSE stream. - go c.processSSEStream(ctx, resp, params.ID, eventsChan) - return eventsChan, nil + log.Debugf("A2A Client Stream Response <- Status: %d, ID: %v. Stream established.", resp.StatusCode, id) + return resp, nil } // processSSEStream reads Server-Sent Events from the response body and sends them // onto the provided channel. It handles closing the channel and response body. // Runs in its own goroutine. -func (c *A2AClient) processSSEStream( +func processSSEStream[T interface{}]( ctx context.Context, resp *http.Response, - taskID string, - eventsChan chan<- protocol.TaskEvent, + reqID string, + eventsChan chan<- T, ) { // Ensure resources are cleaned up when the goroutine exits. defer resp.Body.Close() defer close(eventsChan) + reader := sse.NewEventReader(resp.Body) - log.Debugf("SSE Processor started for task %s", taskID) + log.Debugf("SSE Processor started for request %s", reqID) for { select { case <-ctx.Done(): // Context canceled (e.g., timeout or manual cancellation by caller). - log.Debugf("SSE context canceled for task %s: %v", taskID, ctx.Err()) + log.Debugf("SSE context canceled for request %s: %v", reqID, ctx.Err()) return default: // Read the next event from the stream. eventBytes, eventType, err := reader.ReadEvent() if err != nil { if err == io.EOF { - log.Debugf("SSE stream ended cleanly (EOF) for task %s", taskID) + log.Debugf("SSE stream ended cleanly (EOF) for request %s", reqID) } else if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "connection reset by peer") { // Client disconnected normally - log.Debugf("Client disconnected from SSE stream for task %s", taskID) + log.Debugf("Client disconnected from SSE stream for request %s", reqID) } else { // Log unexpected errors (like network issues or parsing problems) - log.Errorf("Error reading SSE stream for task %s: %v", taskID, err) + log.Errorf("Error reading SSE stream for request %s: %v", reqID, err) } return // Stop processing on any error or EOF. } @@ -241,8 +287,8 @@ func (c *A2AClient) processSSEStream( // Handle close event immediately before any other processing. if eventType == protocol.EventClose { log.Debugf( - "Received explicit '%s' event from server for task %s. Data: %s", - protocol.EventClose, taskID, string(eventBytes), + "Received explicit '%s' event from server for request %s. Data: %s", + protocol.EventClose, reqID, string(eventBytes), ) return // Exit immediately, do not process any more events } @@ -253,16 +299,10 @@ func (c *A2AClient) processSSEStream( // If this is a valid JSON-RPC response, extract the result for further processing if jsonRPCErr == nil && jsonRPCResponse.JSONRPC == jsonrpc.Version { - log.Debugf( - "Received JSON-RPC wrapped event for task %s. Type: %s", - taskID, eventType, - ) + log.Debugf("Received JSON-RPC wrapped event for request %s. Type: %s", reqID, eventType) // Check for errors in the JSON-RPC response if jsonRPCResponse.Error != nil { - log.Errorf( - "JSON-RPC error in SSE event for task %s: %v", - taskID, jsonRPCResponse.Error, - ) + log.Errorf("JSON-RPC error in SSE event for request %s: %v", reqID, jsonRPCResponse.Error) continue // Skip events with JSON-RPC errors } // Use the result field directly for further processing @@ -270,45 +310,24 @@ func (c *A2AClient) processSSEStream( } // Deserialize the event data based on the event type from SSE. - var taskEvent protocol.TaskEvent - switch eventType { - case protocol.EventTaskStatusUpdate: - var statusEvent protocol.TaskStatusUpdateEvent - if err := json.Unmarshal(eventBytes, &statusEvent); err != nil { - log.Errorf( - "Error unmarshaling TaskStatusUpdateEvent for task %s: %v. Data: %s", - taskID, err, string(eventBytes), - ) - continue // Skip malformed event. - } - taskEvent = statusEvent - case protocol.EventTaskArtifactUpdate: - var artifactEvent protocol.TaskArtifactUpdateEvent - if err := json.Unmarshal(eventBytes, &artifactEvent); err != nil { - log.Errorf( - "Error unmarshaling TaskArtifactUpdateEvent for task %s: %v. Data: %s", - taskID, err, string(eventBytes), - ) - continue // Skip malformed event. - } - taskEvent = artifactEvent - default: - log.Warnf( - "Received unknown SSE event type '%s' for task %s. Data: %s", - eventType, taskID, string(eventBytes), - ) - continue // Skip unknown event types. + event, err := unmarshalSSEEvent[T](eventBytes, eventType) + if err != nil { + log.Errorf("Error unmarshaling event for request:%s data:%s, error:%v", reqID, string(eventBytes), err) + continue } + + log.Debugf("Received event for task %s: %v", reqID, event) + // Send the deserialized event to the caller's channel. // Use a select to avoid blocking if the caller isn't reading fast enough // or if the context was canceled concurrently. select { - case eventsChan <- taskEvent: + case eventsChan <- event: // Event sent successfully. case <-ctx.Done(): log.Debugf( "SSE context canceled while sending event for task %s: %v", - taskID, ctx.Err(), + reqID, ctx.Err(), ) return // Stop processing. } @@ -316,6 +335,73 @@ func (c *A2AClient) processSSEStream( } } +func unmarshalSSEEvent[T interface{}](eventBytes []byte, eventType string) (T, error) { + // Check if T is StreamingMessageEvent type - use V2 for new message API + var result T + switch any(result).(type) { + case protocol.StreamingMessageEvent: + return unmarshalSSEEventV2[T](eventBytes, eventType) + default: + // For backward compatibility with old task APIs, use V1 + if len(eventBytes) > 0 { + return unmarshalSSEEventV1[T](eventBytes, eventType) + } + return unmarshalSSEEventV2[T](eventBytes, eventType) + } +} + +func unmarshalSSEEventV2[T interface{}](eventBytes []byte, _ string) (T, error) { + var result T + if err := json.Unmarshal(eventBytes, &result); err != nil { + return result, fmt.Errorf("failed to unmarshal event: %w", err) + } + return result, nil +} + +// todo: remove with StreamTask +func unmarshalSSEEventV1[T interface{}](eventBytes []byte, eventType string) (T, error) { + var result T + + // First try to unmarshal as StreamingMessageEvent + var streamEvent protocol.StreamingMessageEvent + if err := json.Unmarshal(eventBytes, &streamEvent); err == nil { + // If it's a StreamingMessageEvent, extract the Result + if taskEvent, ok := streamEvent.Result.(protocol.TaskEvent); ok { + if converted, ok := taskEvent.(T); ok { + return converted, nil + } + } + // Try to convert Result directly + if converted, ok := streamEvent.Result.(T); ok { + return converted, nil + } + } + + // Fallback to direct unmarshaling based on event type + var event interface{} + switch eventType { + case protocol.EventStatusUpdate: + statusEvent := &protocol.TaskStatusUpdateEvent{} + if err := json.Unmarshal(eventBytes, statusEvent); err != nil { + return result, fmt.Errorf("failed to unmarshal TaskStatusUpdateEvent: %w", err) + } + event = statusEvent + case protocol.EventArtifactUpdate: + artifactEvent := &protocol.TaskArtifactUpdateEvent{} + if err := json.Unmarshal(eventBytes, artifactEvent); err != nil { + return result, fmt.Errorf("failed to unmarshal TaskArtifactUpdateEvent: %w", err) + } + event = artifactEvent + default: + return result, fmt.Errorf("unknown SSE event type: %s", eventType) + } + converted, ok := event.(T) + if !ok { + return result, fmt.Errorf("failed to convert event to %T", result) + } + return converted, nil +} + func (c *A2AClient) doRequestAndDecodeTask( ctx context.Context, request *jsonrpc.Request, @@ -345,14 +431,35 @@ func (c *A2AClient) doRequestAndDecodeTask( return task, nil } +func (c *A2AClient) doRequestAndDecodeMessage( + ctx context.Context, + request *jsonrpc.Request, +) (*protocol.MessageResult, error) { + fullResponse, err := c.doRequest(ctx, request) + if err != nil { + return nil, err + } + if fullResponse.Error != nil { + return nil, fullResponse.Error + } + if len(fullResponse.Result) == 0 { + return nil, fmt.Errorf("rpc response missing required 'result' field for id %v", request.ID) + } + messageResp := &protocol.MessageResult{} + if err := json.Unmarshal(fullResponse.Result, messageResp); err != nil { + return nil, fmt.Errorf( + "failed to unmarshal rpc result: %w. Raw result: %s", err, string(fullResponse.Result), + ) + } + return messageResp, nil +} + // doRequest performs the HTTP POST request for a JSON-RPC call. // It handles request marshaling, setting headers, sending the request, // checking the HTTP status, and decoding the base JSON response structure. // It does NOT specifically handle the 'result' or 'error' fields, leaving that // to the caller or doRequestAndDecodeResult. -func (c *A2AClient) doRequest( - ctx context.Context, request *jsonrpc.Request, -) (*jsonrpc.RawResponse, error) { +func (c *A2AClient) doRequest(ctx context.Context, request *jsonrpc.Request) (*jsonrpc.RawResponse, error) { reqBody, err := json.Marshal(request) if err != nil { // Use a more specific error message prefix. @@ -421,7 +528,7 @@ func (c *A2AClient) SetPushNotification( ctx context.Context, params protocol.TaskPushNotificationConfig, ) (*protocol.TaskPushNotificationConfig, error) { - request := jsonrpc.NewRequest(protocol.MethodTasksPushNotificationSet, params.ID) + request := jsonrpc.NewRequest(protocol.MethodTasksPushNotificationConfigSet, params.RPCID) paramsBytes, err := json.Marshal(params) if err != nil { return nil, fmt.Errorf("a2aClient.SetPushNotification: failed to marshal params: %w", err) @@ -461,7 +568,7 @@ func (c *A2AClient) GetPushNotification( ctx context.Context, params protocol.TaskIDParams, ) (*protocol.TaskPushNotificationConfig, error) { - request := jsonrpc.NewRequest(protocol.MethodTasksPushNotificationGet, params.ID) + request := jsonrpc.NewRequest(protocol.MethodTasksPushNotificationConfigGet, params.RPCID) paramsBytes, err := json.Marshal(params) if err != nil { return nil, fmt.Errorf("a2aClient.GetPushNotification: failed to marshal params: %w", err) @@ -498,7 +605,6 @@ func (c *A2AClient) GetPushNotification( // httpRequestHandler is the HTTP request handler for a2a client. type httpRequestHandler struct { - handler HTTPReqHandler } // Handle is the HTTP request handler for a2a client. diff --git a/client/client_test.go b/client/client_test.go index 27e0099..162f291 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -28,7 +28,8 @@ import ( func TestA2AClient_SendTask(t *testing.T) { taskID := "client-task-1" params := protocol.SendTaskParams{ - ID: taskID, + RPCID: taskID, + ID: taskID, Message: protocol.Message{ Role: protocol.MessageRoleUser, Parts: []protocol.Part{protocol.NewTextPart("Client test input")}, @@ -104,11 +105,11 @@ func TestA2AClient_SendTask(t *testing.T) { } // TestA2AClient_StreamTask tests the StreamTask client method for SSE. -// It covers success, HTTP errors, and non-SSE response scenarios. func TestA2AClient_StreamTask(t *testing.T) { taskID := "client-task-sse-1" params := protocol.SendTaskParams{ - ID: taskID, + RPCID: taskID, + SessionID: &taskID, Message: protocol.Message{ Role: protocol.MessageRoleUser, Parts: []protocol.Part{protocol.NewTextPart("Client SSE test")}, @@ -119,26 +120,25 @@ func TestA2AClient_StreamTask(t *testing.T) { expectedRequest := &jsonrpc.Request{ Message: jsonrpc.Message{JSONRPC: "2.0", ID: taskID}, - Method: "tasks/sendSubscribe", + Method: "message/stream", Params: json.RawMessage(paramsBytes), } t.Run("StreamTask Success", func(t *testing.T) { // Prepare mock SSE stream data. sseEvent1Data, _ := json.Marshal(protocol.TaskStatusUpdateEvent{ - ID: taskID, + TaskID: taskID, Status: protocol.TaskStatus{State: protocol.TaskStateWorking}, - Final: false, + Final: &[]bool{false}[0], }) sseEvent2Data, _ := json.Marshal(protocol.TaskArtifactUpdateEvent{ - ID: taskID, - Artifact: protocol.Artifact{Index: 0, Parts: []protocol.Part{protocol.NewTextPart("SSE data")}}, - Final: false, + TaskID: taskID, + Artifact: protocol.Artifact{ArtifactID: "0", Parts: []protocol.Part{protocol.NewTextPart("SSE data")}}, }) sseEvent3Data, _ := json.Marshal(protocol.TaskStatusUpdateEvent{ - ID: taskID, + TaskID: taskID, Status: protocol.TaskStatus{State: protocol.TaskStateCompleted}, - Final: true, + Final: &[]bool{true}[0], }) // Format the mock SSE stream string. @@ -156,7 +156,7 @@ func TestA2AClient_StreamTask(t *testing.T) { mockHandler := createMockServerHandler( t, - "tasks/sendSubscribe", + "message/stream", expectedRequest, sseStream, http.StatusOK, @@ -194,23 +194,23 @@ func TestA2AClient_StreamTask(t *testing.T) { // Assert the content and order of received events. require.Len(t, receivedEvents, 3, "Should receive exactly 3 events") - _, ok1 := receivedEvents[0].(protocol.TaskStatusUpdateEvent) - _, ok2 := receivedEvents[1].(protocol.TaskArtifactUpdateEvent) - _, ok3 := receivedEvents[2].(protocol.TaskStatusUpdateEvent) + _, ok1 := receivedEvents[0].(*protocol.TaskStatusUpdateEvent) + _, ok2 := receivedEvents[1].(*protocol.TaskArtifactUpdateEvent) + _, ok3 := receivedEvents[2].(*protocol.TaskStatusUpdateEvent) assert.True(t, ok1 && ok2 && ok3, "Received event types mismatch expected sequence") assert.Equal(t, protocol.TaskStateWorking, - receivedEvents[0].(protocol.TaskStatusUpdateEvent).Status.State, "First event state mismatch") + receivedEvents[0].(*protocol.TaskStatusUpdateEvent).Status.State, "First event state mismatch") assert.False(t, receivedEvents[0].IsFinal(), "First event should not be final") assert.False(t, receivedEvents[1].IsFinal(), "Second event should not be final") assert.Equal(t, protocol.TaskStateCompleted, - receivedEvents[2].(protocol.TaskStatusUpdateEvent).Status.State, "Third event state mismatch") + receivedEvents[2].(*protocol.TaskStatusUpdateEvent).Status.State, "Third event state mismatch") assert.True(t, receivedEvents[2].IsFinal(), "Last event should be final") }) t.Run("StreamTask HTTP Error", func(t *testing.T) { // Prepare mock server HTTP error response. mockHandler := createMockServerHandler( - t, "tasks/sendSubscribe", expectedRequest, "Not Found", http.StatusNotFound, nil, + t, "message/stream", expectedRequest, "Not Found", http.StatusNotFound, nil, ) server := httptest.NewServer(mockHandler) defer server.Close() @@ -228,12 +228,9 @@ func TestA2AClient_StreamTask(t *testing.T) { }) t.Run("StreamTask Non-SSE Response", func(t *testing.T) { - // Prepare mock server returning JSON instead of SSE. + // Prepare mock server response without proper SSE headers. mockHandler := createMockServerHandler( - t, "tasks/sendSubscribe", expectedRequest, - fmt.Sprintf(`{"jsonrpc":"2.0","id":"%s","result":"not an sse stream"}`, taskID), - http.StatusOK, - map[string]string{"Content-Type": "application/json"}, // Incorrect Content-Type for SSE. + t, "message/stream", expectedRequest, "Not an SSE response", http.StatusOK, nil, ) server := httptest.NewServer(mockHandler) defer server.Close() @@ -245,9 +242,9 @@ func TestA2AClient_StreamTask(t *testing.T) { eventChan, err := client.StreamTask(context.Background(), params) // Assertions. - require.Error(t, err, "StreamTask should return an error on non-SSE response") + require.Error(t, err, "StreamTask should return an error for non-SSE response") assert.Nil(t, eventChan, "Event channel should be nil on error") - assert.Contains(t, err.Error(), "server did not respond with Content-Type 'text/event-stream'") + assert.Contains(t, err.Error(), "did not respond with Content-Type 'text/event-stream'") }) } @@ -255,7 +252,8 @@ func TestA2AClient_StreamTask(t *testing.T) { func TestA2AClient_GetTasks(t *testing.T) { taskID := "client-get-task-1" params := protocol.TaskQueryParams{ - ID: taskID, + RPCID: taskID, + ID: taskID, } paramsBytes, err := json.Marshal(params) require.NoError(t, err) @@ -263,7 +261,7 @@ func TestA2AClient_GetTasks(t *testing.T) { expectedRequest := &jsonrpc.Request{ Message: jsonrpc.Message{JSONRPC: "2.0", ID: taskID}, Method: "tasks/get", - Params: json.RawMessage(paramsBytes), + Params: paramsBytes, } t.Run("GetTasks Success", func(t *testing.T) { @@ -349,7 +347,8 @@ func TestA2AClient_GetTasks(t *testing.T) { func TestA2AClient_CancelTasks(t *testing.T) { taskID := "client-cancel-task-1" params := protocol.TaskIDParams{ - ID: taskID, + RPCID: taskID, + ID: taskID, } paramsBytes, err := json.Marshal(params) require.NoError(t, err) @@ -357,7 +356,7 @@ func TestA2AClient_CancelTasks(t *testing.T) { expectedRequest := &jsonrpc.Request{ Message: jsonrpc.Message{JSONRPC: "2.0", ID: taskID}, Method: "tasks/cancel", - Params: json.RawMessage(paramsBytes), + Params: paramsBytes, } t.Run("CancelTasks Success", func(t *testing.T) { @@ -435,9 +434,13 @@ func TestA2AClient_CancelTasks(t *testing.T) { func TestA2AClient_SetPushNotification(t *testing.T) { taskID := "client-push-task-1" params := protocol.TaskPushNotificationConfig{ - ID: taskID, + RPCID: taskID, + TaskID: taskID, PushNotificationConfig: protocol.PushNotificationConfig{ URL: "https://example.com/webhook", + Authentication: &protocol.AuthenticationInfo{ + Schemes: []string{"bearer"}, + }, }, } paramsBytes, err := json.Marshal(params) @@ -445,14 +448,14 @@ func TestA2AClient_SetPushNotification(t *testing.T) { expectedRequest := &jsonrpc.Request{ Message: jsonrpc.Message{JSONRPC: "2.0", ID: taskID}, - Method: "tasks/pushNotification/set", - Params: json.RawMessage(paramsBytes), + Method: "tasks/pushNotificationConfig/set", + Params: paramsBytes, } t.Run("SetPushNotification Success", func(t *testing.T) { // Prepare mock server response respConfig := protocol.TaskPushNotificationConfig{ - ID: taskID, + TaskID: taskID, PushNotificationConfig: protocol.PushNotificationConfig{ URL: "https://example.com/webhook", }, @@ -463,7 +466,7 @@ func TestA2AClient_SetPushNotification(t *testing.T) { mockHandler := createMockServerHandler( t, - "tasks/pushNotification/set", + "tasks/pushNotificationConfig/set", expectedRequest, respBody, http.StatusOK, @@ -482,7 +485,7 @@ func TestA2AClient_SetPushNotification(t *testing.T) { require.NoError(t, err, "SetPushNotification should not return an error on success") require.NotNil(t, result, "Result should not be nil on success") - assert.Equal(t, taskID, result.ID) + assert.Equal(t, taskID, result.TaskID) assert.Equal(t, "https://example.com/webhook", result.PushNotificationConfig.URL) }) @@ -500,7 +503,7 @@ func TestA2AClient_SetPushNotification(t *testing.T) { mockHandler := createMockServerHandler( t, - "tasks/pushNotification/set", + "tasks/pushNotificationConfig/set", expectedRequest, errorResp, http.StatusOK, @@ -526,21 +529,21 @@ func TestA2AClient_SetPushNotification(t *testing.T) { func TestA2AClient_GetPushNotification(t *testing.T) { taskID := "client-push-get-1" params := protocol.TaskIDParams{ - ID: taskID, + RPCID: taskID, + ID: taskID, } paramsBytes, err := json.Marshal(params) require.NoError(t, err) expectedRequest := &jsonrpc.Request{ Message: jsonrpc.Message{JSONRPC: "2.0", ID: taskID}, - Method: "tasks/pushNotification/get", - Params: json.RawMessage(paramsBytes), + Method: "tasks/pushNotificationConfig/get", + Params: paramsBytes, } t.Run("GetPushNotification Success", func(t *testing.T) { - // Prepare mock server response respConfig := protocol.TaskPushNotificationConfig{ - ID: taskID, + TaskID: taskID, PushNotificationConfig: protocol.PushNotificationConfig{ URL: "https://example.com/webhook", Authentication: &protocol.AuthenticationInfo{ @@ -554,7 +557,7 @@ func TestA2AClient_GetPushNotification(t *testing.T) { mockHandler := createMockServerHandler( t, - "tasks/pushNotification/get", + "tasks/pushNotificationConfig/get", expectedRequest, respBody, http.StatusOK, @@ -573,7 +576,7 @@ func TestA2AClient_GetPushNotification(t *testing.T) { require.NoError(t, err, "GetPushNotification should not return an error on success") require.NotNil(t, result, "Result should not be nil on success") - assert.Equal(t, taskID, result.ID) + assert.Equal(t, taskID, result.TaskID) assert.Equal(t, "https://example.com/webhook", result.PushNotificationConfig.URL) require.NotNil(t, result.PushNotificationConfig.Authentication) assert.Contains(t, result.PushNotificationConfig.Authentication.Schemes, "bearer") @@ -593,7 +596,7 @@ func TestA2AClient_GetPushNotification(t *testing.T) { mockHandler := createMockServerHandler( t, - "tasks/pushNotification/get", + "tasks/pushNotificationConfig/get", expectedRequest, errorResp, http.StatusOK, diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go index 410d2b0..2781095 100644 --- a/examples/auth/client/main.go +++ b/examples/auth/client/main.go @@ -83,43 +83,62 @@ func main() { textPart := protocol.NewTextPart(config.TaskMessage) message := protocol.NewMessage(protocol.MessageRoleUser, []protocol.Part{textPart}) - // Prepare task parameters - taskParams := protocol.SendTaskParams{ - ID: config.TaskID, + // Prepare message parameters + params := protocol.SendMessageParams{ Message: message, } - // Add session ID if provided + // Add context ID if session ID is provided if config.SessionID != "" { - taskParams.SessionID = &config.SessionID + // In the new protocol, we use contextID instead of sessionID + params.Message.ContextID = &config.SessionID } - // Send the task + // Send the message ctx, cancel := context.WithTimeout(context.Background(), config.Timeout) defer cancel() - task, err := a2aClient.SendTasks(ctx, taskParams) + result, err := a2aClient.SendMessage(ctx, params) if err != nil { - log.Fatalf("Failed to send task: %v", err) + log.Fatalf("Failed to send message: %v", err) } - fmt.Printf("Task ID: %s, Status: %s\n", task.ID, task.Status.State) - if task.SessionID != nil { - fmt.Printf("Session ID: %s\n", *task.SessionID) - } + // Handle the response based on its type + switch response := result.Result.(type) { + case *protocol.Message: + fmt.Printf("Message Response: %s\n", response.MessageID) + if response.ContextID != nil { + fmt.Printf("Context ID: %s\n", *response.ContextID) + } + // Print message parts + for _, part := range response.Parts { + if textPart, ok := part.(*protocol.TextPart); ok { + fmt.Printf("Response: %s\n", textPart.Text) + } + } - // For demonstration purposes, get the task status - taskQuery := protocol.TaskQueryParams{ - ID: task.ID, - } + case *protocol.Task: + fmt.Printf("Task ID: %s, Status: %s\n", response.ID, response.Status.State) + if response.ContextID != "" { + fmt.Printf("Context ID: %s\n", response.ContextID) + } - updatedTask, err := a2aClient.GetTasks(ctx, taskQuery) - if err != nil { - log.Fatalf("Failed to get task: %v", err) - } + // For demonstration purposes, get the task status + taskQuery := protocol.TaskQueryParams{ + ID: response.ID, + } + + updatedTask, err := a2aClient.GetTasks(ctx, taskQuery) + if err != nil { + log.Fatalf("Failed to get task: %v", err) + } - fmt.Printf("Updated task status: %s\n", updatedTask.Status.State) + fmt.Printf("Updated task status: %s\n", updatedTask.Status.State) + + default: + fmt.Printf("Unknown response type: %T\n", response) + } } // parseFlags parses command-line flags and returns a Config. diff --git a/examples/auth/server/main.go b/examples/auth/server/main.go index ca97775..b1d7648 100644 --- a/examples/auth/server/main.go +++ b/examples/auth/server/main.go @@ -61,7 +61,7 @@ func main() { } // Create a simple echo processor for demonstration purposes - processor := &echoProcessor{} + processor := &echoMessageProcessor{} // Create a real task manager with our processor taskManager, err := taskmanager.NewMemoryTaskManager(processor) @@ -113,22 +113,48 @@ func main() { agentCard := server.AgentCard{ Name: "A2A Server with Authentication", - Description: addressableStr("A demonstration server with JWT and API key authentication"), + Description: "A demonstration server with JWT and API key authentication", URL: fmt.Sprintf("http://localhost:%d", config.Port), Provider: &server.AgentProvider{ Organization: "Example Provider", }, Version: "1.0.0", Capabilities: server.AgentCapabilities{ - Streaming: true, - PushNotifications: true, + Streaming: boolPtr(true), + PushNotifications: boolPtr(true), + StateTransitionHistory: boolPtr(true), }, - Authentication: &protocol.AuthenticationInfo{ - Schemes: []string{authType}, - Credentials: &config.APIKeyHeader, + SecuritySchemes: map[string]server.SecurityScheme{ + "apiKey": { + Type: "apiKey", + Description: stringPtr("API key authentication"), + Name: stringPtr(config.APIKeyHeader), + In: securitySchemeInPtr(server.SecuritySchemeInHeader), + }, + "jwt": { + Type: "http", + Description: stringPtr("JWT Bearer token authentication"), + Scheme: stringPtr("bearer"), + BearerFormat: stringPtr("JWT"), + }, + }, + Security: []map[string][]string{ + {"apiKey": {}}, + {"jwt": {}}, }, DefaultInputModes: []string{"text"}, DefaultOutputModes: []string{"text"}, + Skills: []server.AgentSkill{ + { + ID: "echo", + Name: "Echo Service", + Description: stringPtr("Echoes back the input text with authentication"), + Tags: []string{"text", "echo", "auth"}, + Examples: []string{"Hello, world!"}, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, + }, + }, } // Create the server with authentication @@ -265,16 +291,16 @@ func printExampleCommands(port int, token string, enableOAuth bool, tokenEndpoin log.Printf("Using JWT authentication:") log.Printf("curl -X POST http://localhost:%d -H 'Content-Type: application/json' "+ "-H 'Authorization: Bearer %s' "+ - "-d '{\"jsonrpc\":\"2.0\",\"method\":\"tasks/send\",\"id\":1,"+ - "\"params\":{\"id\":\"task1\",\"message\":{\"role\":\"user\","+ + "-d '{\"jsonrpc\":\"2.0\",\"method\":\"message/send\",\"id\":1,"+ + "\"params\":{\"message\":{\"role\":\"user\","+ "\"parts\":[{\"type\":\"text\",\"text\":\"Hello, world!\"}]}}}'", port, token) // API key example log.Printf("\nUsing API key authentication:") log.Printf("curl -X POST http://localhost:%d -H 'Content-Type: application/json' "+ "-H 'X-API-Key: test-api-key' "+ - "-d '{\"jsonrpc\":\"2.0\",\"method\":\"tasks/send\",\"id\":1,"+ - "\"params\":{\"id\":\"task1\",\"message\":{\"role\":\"user\","+ + "-d '{\"jsonrpc\":\"2.0\",\"method\":\"message/send\",\"id\":1,"+ + "\"params\":{\"message\":{\"role\":\"user\","+ "\"parts\":[{\"type\":\"text\",\"text\":\"Hello, world!\"}]}}}'", port) // OAuth2 example if enabled @@ -286,8 +312,8 @@ func printExampleCommands(port int, token string, enableOAuth bool, tokenEndpoin log.Printf("\nStep 2: Use the token with the A2A API:") log.Printf("curl -X POST http://localhost:%d -H 'Content-Type: application/json' "+ "-H 'Authorization: Bearer ' "+ - "-d '{\"jsonrpc\":\"2.0\",\"method\":\"tasks/send\",\"id\":1,"+ - "\"params\":{\"id\":\"task1\",\"message\":{\"role\":\"user\","+ + "-d '{\"jsonrpc\":\"2.0\",\"method\":\"message/send\",\"id\":1,"+ + "\"params\":{\"message\":{\"role\":\"user\","+ "\"parts\":[{\"type\":\"text\",\"text\":\"Hello, world!\"}]}}}'", port) } @@ -296,37 +322,35 @@ func printExampleCommands(port int, token string, enableOAuth bool, tokenEndpoin log.Printf("curl http://localhost:%d/.well-known/agent.json", port) } -// echoProcessor is a simple processor that echoes user messages -type echoProcessor struct{} +// echoMessageProcessor is a simple processor that echoes user messages +type echoMessageProcessor struct{} -func (p *echoProcessor) Process( +func (p *echoMessageProcessor) ProcessMessage( ctx context.Context, - taskID string, - msg protocol.Message, - handle taskmanager.TaskHandle, -) error { + message protocol.Message, + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { // Create a concatenated string of all text parts var responseText string - for _, part := range msg.Parts { - if textPart, ok := part.(protocol.TextPart); ok { + for _, part := range message.Parts { + if textPart, ok := part.(*protocol.TextPart); ok { responseText += textPart.Text + " " } } // Create response message - responseMsg := &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{ + responseMsg := protocol.NewMessage( + protocol.MessageRoleAgent, + []protocol.Part{ protocol.NewTextPart(fmt.Sprintf("Echo: %s", responseText)), }, - } - - // Update the task status to completed with our response - if err := handle.UpdateStatus(protocol.TaskStateCompleted, responseMsg); err != nil { - return fmt.Errorf("failed to update task status: %w", err) - } + ) - return nil + // Return the response message directly + return &taskmanager.MessageProcessingResult{ + Result: &responseMsg, + }, nil } func addressableStr(s string) *string { @@ -440,3 +464,15 @@ func generateToken(clientID string, scopes []string) TokenResponse { ClientID: clientID, } } + +func boolPtr(b bool) *bool { + return &b +} + +func stringPtr(s string) *string { + return &s +} + +func securitySchemeInPtr(in server.SecuritySchemeIn) *server.SecuritySchemeIn { + return &in +} diff --git a/examples/basic/client/main.go b/examples/basic/client/main.go index e225ff2..88a1088 100644 --- a/examples/basic/client/main.go +++ b/examples/basic/client/main.go @@ -20,8 +20,6 @@ import ( "strings" "time" - "github.com/google/uuid" - "trpc.group/trpc-go/trpc-a2a-go/client" "trpc.group/trpc-go/trpc-a2a-go/protocol" "trpc.group/trpc-go/trpc-a2a-go/server" @@ -32,7 +30,7 @@ type Config struct { AgentURL string Timeout time.Duration ForceNoStreaming bool - SessionID string + ContextID string UseTasksGet bool HistoryLength int ServerPort int @@ -43,7 +41,7 @@ type Config struct { const ( cmdExit = "exit" cmdHelp = "help" - cmdSession = "session" + cmdContext = "context" cmdMode = "mode" cmdCancel = "cancel" cmdGet = "get" @@ -89,16 +87,16 @@ func parseFlags() Config { flag.StringVar(&config.AgentURL, "agent", "http://localhost:8080/", "Target A2A agent URL") flag.DurationVar(&config.Timeout, "timeout", 60*time.Second, "Request timeout (e.g., 30s, 1m)") flag.BoolVar(&config.ForceNoStreaming, "no-stream", false, "Disable streaming mode") - flag.StringVar(&config.SessionID, "session", "", "Use specific session ID (empty = generate new)") + flag.StringVar(&config.ContextID, "context", "", "Use specific context ID (empty = generate new)") flag.BoolVar(&config.UseTasksGet, "use-tasks-get", true, "Use tasks/get to fetch final state") flag.IntVar(&config.HistoryLength, "history", 0, "Number of history messages to request (0 = none)") flag.IntVar(&config.ServerPort, "port", 8090, "Port for push notification server") flag.StringVar(&config.ServerHost, "host", "localhost", "Host for push notification server") flag.Parse() - // Generate a session ID if not provided - if config.SessionID == "" { - config.SessionID = generateSessionID() + // Generate a context ID if not provided + if config.ContextID == "" { + config.ContextID = protocol.GenerateContextID() } return config @@ -151,8 +149,8 @@ func fetchAgentCard(baseURL string) (*server.AgentCard, error) { func displayAgentCapabilities(card *server.AgentCard) { fmt.Println("Agent Capabilities:") fmt.Printf(" Name: %s\n", card.Name) - if card.Description != nil { - fmt.Printf(" Description: %s\n", *card.Description) + if card.Description != "" { + fmt.Printf(" Description: %s\n", card.Description) } fmt.Printf(" Version: %s\n", card.Version) @@ -161,10 +159,23 @@ func displayAgentCapabilities(card *server.AgentCard) { fmt.Printf(" Provider: %s\n", card.Provider.Organization) } - // Print capabilities - fmt.Printf(" Streaming: %t\n", card.Capabilities.Streaming) - fmt.Printf(" Push Notifications: %t\n", card.Capabilities.PushNotifications) - fmt.Printf(" State Transition History: %t\n", card.Capabilities.StateTransitionHistory) + // Print capabilities - handle new *bool types + streaming := false + if card.Capabilities.Streaming != nil { + streaming = *card.Capabilities.Streaming + } + pushNotifications := false + if card.Capabilities.PushNotifications != nil { + pushNotifications = *card.Capabilities.PushNotifications + } + stateHistory := false + if card.Capabilities.StateTransitionHistory != nil { + stateHistory = *card.Capabilities.StateTransitionHistory + } + + fmt.Printf(" Streaming: %t\n", streaming) + fmt.Printf(" Push Notifications: %t\n", pushNotifications) + fmt.Printf(" State Transition History: %t\n", stateHistory) // Print input/output modes fmt.Printf(" Input Modes: %s\n", strings.Join(card.DefaultInputModes, ", ")) @@ -193,7 +204,7 @@ func displayAgentCapabilities(card *server.AgentCard) { // displayWelcomeMessage prints the welcome message with connection details. func displayWelcomeMessage(config Config) { log.Printf("Connecting to agent: %s (Timeout: %v)", config.AgentURL, config.Timeout) - fmt.Printf("Session ID: %s\n", config.SessionID) + fmt.Printf("Context ID: %s\n", config.ContextID) fmt.Printf("Streaming mode: %v\n", !config.ForceNoStreaming) fmt.Println("Enter text to send to the agent. Type 'help' for commands or 'exit' to quit.") fmt.Println(strings.Repeat("-", 60)) @@ -202,19 +213,74 @@ func displayWelcomeMessage(config Config) { // runInteractiveSession runs the main interactive session loop. func runInteractiveSession(a2aClient *client.A2AClient, config Config) { reader := bufio.NewReader(os.Stdin) - sessionID := config.SessionID + contextID := config.ContextID var lastTaskID string var lastTaskState protocol.TaskState var useStreaming = !config.ForceNoStreaming - for { - // Display prompt with indicator if we're continuing a task - if lastTaskState == protocol.TaskStateInputRequired { - fmt.Print("[Continuing task - input required] > ") - } else { - fmt.Print("> ") + // Check if input is from a pipe/redirect or interactive terminal + stat, err := os.Stdin.Stat() + isInteractive := err == nil && (stat.Mode()&os.ModeCharDevice) != 0 + + if !isInteractive { + // Non-interactive mode: process all piped input at once + log.Println("Running in non-interactive mode (piped input)") + scanner := bufio.NewScanner(os.Stdin) + inputs := []string{} + + // Read all inputs first + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line != "" { + inputs = append(inputs, line) + } } + // Process each input + for i, input := range inputs { + log.Printf("Processing input %d/%d: %s", i+1, len(inputs), input) + + // Process built-in commands + if cmdResult := processCommand( + a2aClient, + input, + &config, + &contextID, + &useStreaming, + lastTaskID, + ); cmdResult { + lastTaskState = "" + continue + } + + // Process the user input and handle the agent interaction + taskID := processUserInput(a2aClient, input, contextID, config, useStreaming) + + // Update the last task ID and check task state if a task was created + if taskID != "" { + lastTaskID = taskID + + // Get the current task state to check if it's input-required + ctx, cancel := context.WithTimeout(context.Background(), config.Timeout) + task, err := a2aClient.GetTasks(ctx, protocol.TaskQueryParams{ID: taskID}) + cancel() + + if err == nil && task != nil { + lastTaskState = task.Status.State + } else { + lastTaskState = "" + } + } + } + return + } + + // Interactive mode: continuous loop + log.Println("Running in interactive mode") + for { + // Display prompt + fmt.Print("> ") + input, readErr := reader.ReadString('\n') if readErr != nil { @@ -236,7 +302,7 @@ func runInteractiveSession(a2aClient *client.A2AClient, config Config) { a2aClient, input, &config, - &sessionID, + &contextID, &useStreaming, lastTaskID, ); cmdResult { @@ -246,32 +312,28 @@ func runInteractiveSession(a2aClient *client.A2AClient, config Config) { } // Process the user input and handle the agent interaction - // If we're in input-required state, use the existing task ID - var taskID string - if lastTaskState == protocol.TaskStateInputRequired { - taskID = processUserInput(a2aClient, input, sessionID, config, useStreaming, lastTaskID) - } else { - taskID = processUserInput(a2aClient, input, sessionID, config, useStreaming, "") - } + taskID := processUserInput(a2aClient, input, contextID, config, useStreaming) - // Update the last task ID and check task state - lastTaskID = taskID + // Update the last task ID and check task state if a task was created + if taskID != "" { + lastTaskID = taskID - // Get the current task state to check if it's input-required - ctx, cancel := context.WithTimeout(context.Background(), config.Timeout) - task, err := a2aClient.GetTasks(ctx, protocol.TaskQueryParams{ID: taskID}) - cancel() + // Get the current task state to check if it's input-required + ctx, cancel := context.WithTimeout(context.Background(), config.Timeout) + task, err := a2aClient.GetTasks(ctx, protocol.TaskQueryParams{ID: taskID}) + cancel() - if err == nil && task != nil { - lastTaskState = task.Status.State + if err == nil && task != nil { + lastTaskState = task.Status.State - // Display a message if input is required - if lastTaskState == protocol.TaskStateInputRequired { - fmt.Println(strings.Repeat("-", 60)) - fmt.Println("[Additional input required to complete this task. Continue typing.]") + // Display a message if input is required + if lastTaskState == protocol.TaskStateInputRequired { + fmt.Println(strings.Repeat("-", 60)) + fmt.Println("[Additional input required to complete this task. Continue typing.]") + } + } else { + lastTaskState = "" } - } else { - lastTaskState = "" } } } @@ -281,7 +343,7 @@ func processCommand( a2aClient *client.A2AClient, input string, config *Config, - sessionID *string, + contextID *string, useStreaming *bool, lastTaskID string, ) bool { @@ -302,15 +364,15 @@ func processCommand( displayHelpMessage() return true - case cmdSession: + case cmdContext: if len(parts) > 1 { - // Set new session ID - *sessionID = parts[1] - fmt.Printf("Session ID set to: %s\n", *sessionID) + // Set new context ID + *contextID = parts[1] + fmt.Printf("Context ID set to: %s\n", *contextID) } else { - // Generate new session ID - *sessionID = generateSessionID() - fmt.Printf("Generated new session ID: %s\n", *sessionID) + // Generate new context ID + *contextID = protocol.GenerateContextID() + fmt.Printf("Generated new context ID: %s\n", *contextID) } return true @@ -408,8 +470,9 @@ func processCommand( return true case "new": - // Force start a new task (ignore current input-required state) - fmt.Println("Starting a new task on next input.") + // Force start a new context + *contextID = protocol.GenerateContextID() + fmt.Printf("Starting a new context: %s\n", *contextID) return true case cmdServer: @@ -450,7 +513,7 @@ func displayHelpMessage() { fmt.Println("Available commands:") fmt.Println(" help - Show this help message") fmt.Println(" exit - Exit the program") - fmt.Println(" session [id] - Set or generate a new session ID") + fmt.Println(" context [id] - Set or generate a new context ID") fmt.Println(" mode [stream|sync] - Set interaction mode (streaming or standard)") fmt.Println(" cancel [task-id] - Cancel a task (uses last task ID if not specified)") fmt.Println(" get [task-id] [history] - Get task details (uses last task ID if not specified)") @@ -459,10 +522,10 @@ func displayHelpMessage() { fmt.Println(" getpush - Get push notification configuration for a task") fmt.Println(" server start - Start push notification server") fmt.Println(" server stop - Stop push notification server") - fmt.Println(" new - Start a new task (ignore current input-required state)") + fmt.Println(" new - Start a new context") fmt.Println("") fmt.Println("For normal interaction, just type your message and press Enter.") - fmt.Println("When a task requires additional input, your next message will continue the same task.") + fmt.Println("Messages in the same context will maintain conversation history.") fmt.Println(strings.Repeat("-", 60)) } @@ -470,57 +533,43 @@ func displayHelpMessage() { func processUserInput( a2aClient *client.A2AClient, input, - sessionID string, + contextID string, config Config, useStreaming bool, - existingTaskID string, ) string { - // Generate unique task ID or use existing one if provided - var taskID string - if existingTaskID != "" { - taskID = existingTaskID - } else { - taskID = generateTaskID() - } + // Create message with context + message := protocol.NewMessageWithContext( + protocol.MessageRoleUser, + []protocol.Part{protocol.NewTextPart(input)}, + nil, // taskID + &contextID, + ) - // Create message and parameters. - params := createTaskParams(taskID, sessionID, input, config.HistoryLength) + // Create message parameters + params := createMessageParams(message, config.HistoryLength) // Send the request and process the response based on mode + var taskID string if useStreaming && !config.ForceNoStreaming { - handleStreamingInteraction(a2aClient, params, taskID, config) + taskID = handleStreamingInteraction(a2aClient, params, config) } else { - handleStandardInteraction(a2aClient, params, taskID, config) + taskID = handleStandardInteraction(a2aClient, params, config) } return taskID } -// generateSessionID creates a new unique session ID. -func generateSessionID() string { - return fmt.Sprintf("cli-session-%d-%s", time.Now().Unix(), uuid.New().String()) -} - -// generateTaskID creates a new unique task ID. -func generateTaskID() string { - return fmt.Sprintf("cli-task-%d-%s", time.Now().UnixNano(), uuid.New().String()) -} - -// createTaskParams creates the parameters for sending a task. -func createTaskParams(taskID, sessionID, input string, historyLength int) protocol.SendTaskParams { - message := protocol.NewMessage( - protocol.MessageRoleUser, - []protocol.Part{protocol.NewTextPart(input)}, - ) - - params := protocol.SendTaskParams{ - ID: taskID, - SessionID: &sessionID, - Message: message, +// createMessageParams creates the parameters for sending a message. +func createMessageParams(message protocol.Message, historyLength int) protocol.SendMessageParams { + params := protocol.SendMessageParams{ + Message: message, } + // Add configuration if needed if historyLength > 0 { - params.HistoryLength = &historyLength + params.Configuration = &protocol.SendMessageConfiguration{ + HistoryLength: &historyLength, + } } return params @@ -529,120 +578,127 @@ func createTaskParams(taskID, sessionID, input string, historyLength int) protoc // handleStreamingInteraction sends a streaming request to the agent and processes the response. func handleStreamingInteraction( a2aClient *client.A2AClient, - params protocol.SendTaskParams, - taskID string, + params protocol.SendMessageParams, config Config, -) { +) string { // Create context for the stream. ctx, cancel := context.WithTimeout(context.Background(), config.Timeout*2) defer cancel() - log.Printf("Sending stream request for task %s (Session: %s)...", taskID, *params.SessionID) - eventChan, streamErr := a2aClient.StreamTask(ctx, params) + log.Printf("Sending stream request for message %s (Context: %s)...", params.Message.MessageID, *params.Message.ContextID) + eventChan, streamErr := a2aClient.StreamMessage(ctx, params) if streamErr != nil { - log.Printf("ERROR: StreamTask request failed: %v", streamErr) + log.Printf("ERROR: StreamMessage request failed: %v", streamErr) fmt.Println(strings.Repeat("-", 60)) - return + return "" } // Process the stream response. - finalTaskState, finalArtifacts := processStreamResponse(ctx, eventChan) + taskID := processStreamResponse(ctx, eventChan) - // Get the final task state if configured to do so. - if config.UseTasksGet { - getFinalTaskState(a2aClient, taskID, config.Timeout, finalTaskState, finalArtifacts, config.HistoryLength) - } - - log.Printf("Stream processing finished for task %s", taskID) + log.Printf("Stream processing finished for message %s", params.Message.MessageID) fmt.Println(strings.Repeat("-", 60)) + + return taskID } // handleStandardInteraction sends a standard (non-streaming) request to the agent. func handleStandardInteraction( a2aClient *client.A2AClient, - params protocol.SendTaskParams, - taskID string, + params protocol.SendMessageParams, config Config, -) { +) string { // Create context for the request. ctx, cancel := context.WithTimeout(context.Background(), config.Timeout) defer cancel() - log.Printf("Sending standard request for task %s (Session: %s)...", taskID, *params.SessionID) + log.Printf("Sending standard request for message %s (Context: %s)...", params.Message.MessageID, *params.Message.ContextID) - // Send the task - task, err := a2aClient.SendTasks(ctx, params) + // Send the message + result, err := a2aClient.SendMessage(ctx, params) if err != nil { - log.Printf("ERROR: SendTasks request failed: %v", err) + log.Printf("ERROR: SendMessage request failed: %v", err) fmt.Println(strings.Repeat("-", 60)) - return + return "" } // Process the response fmt.Println("\n<< Agent Response:") fmt.Println(strings.Repeat("-", 10)) - // Display task state - fmt.Printf(" State: %s (%s)\n", task.Status.State, formatTimestamp(task.Status.Timestamp)) + var taskID string + switch response := result.Result.(type) { + case *protocol.Message: + fmt.Println(" Message Response:") + printMessage(*response) - // Display message if present - if task.Status.Message != nil { - fmt.Println(" Message:") - printMessage(*task.Status.Message) - } + case *protocol.Task: + taskID = response.ID + // Display task state + fmt.Printf(" Task %s State: %s (%s)\n", response.ID, response.Status.State, formatTimestamp(response.Status.Timestamp)) - // Display artifacts if present - if len(task.Artifacts) > 0 { - fmt.Println(" Artifacts:") - for i, artifact := range task.Artifacts { - name := fmt.Sprintf("Artifact #%d", i+1) - if artifact.Name != nil { - name = *artifact.Name + // Display message if present + if response.Status.Message != nil { + fmt.Println(" Message:") + printMessage(*response.Status.Message) + } + + // Display artifacts if present + if len(response.Artifacts) > 0 { + fmt.Println(" Artifacts:") + for i, artifact := range response.Artifacts { + name := fmt.Sprintf("Artifact #%d", i+1) + if artifact.Name != nil { + name = *artifact.Name + } + fmt.Printf(" [%s]\n", name) + printParts(artifact.Parts) } - fmt.Printf(" [%s]\n", name) - printParts(artifact.Parts) } - } - // Display history if present - if task.History != nil && len(task.History) > 0 { - fmt.Println(" History:") - for i, msg := range task.History { - role := "User" - if msg.Role == protocol.MessageRoleAgent { - role = "Agent" + // Display history if present + if response.History != nil && len(response.History) > 0 { + fmt.Println(" History:") + for i, msg := range response.History { + role := "User" + if msg.Role == protocol.MessageRoleAgent { + role = "Agent" + } + fmt.Printf(" [%d] %s:\n", i+1, role) + printParts(msg.Parts) } - fmt.Printf(" [%d] %s:\n", i+1, role) - printParts(msg.Parts) } - } - // Add special handling for input-required state - if task.Status.State == protocol.TaskStateInputRequired { - fmt.Println(" [Additional input required]") + // Add special handling for input-required state + if response.Status.State == protocol.TaskStateInputRequired { + fmt.Println(" [Additional input required]") + } + + default: + fmt.Printf(" Unknown response type: %T\n", response) } fmt.Println(strings.Repeat("-", 60)) + return taskID } // processStreamResponse processes the stream of events from the agent. func processStreamResponse( - ctx context.Context, eventChan <-chan protocol.TaskEvent, -) (protocol.TaskState, []protocol.Artifact) { + ctx context.Context, eventChan <-chan protocol.StreamingMessageEvent, +) string { fmt.Println("\n<< Agent Response Stream:") fmt.Println(strings.Repeat("-", 10)) - var finalTaskState protocol.TaskState - finalArtifacts := []protocol.Artifact{} + var taskID string for { select { case <-ctx.Done(): // Context timed out or was cancelled log.Printf("ERROR: Context timeout or cancellation while waiting for stream events: %v", ctx.Err()) - return finalTaskState, finalArtifacts + return taskID case event, ok := <-eventChan: if !ok { @@ -651,27 +707,35 @@ func processStreamResponse( if ctx.Err() != nil { log.Printf("Context error after stream close: %v", ctx.Err()) } - return finalTaskState, finalArtifacts + return taskID } // Process the received event based on its type - switch e := event.(type) { - case protocol.TaskStatusUpdateEvent: - fmt.Printf(" [Status Update: %s (%s)]\n", e.Status.State, formatTimestamp(e.Status.Timestamp)) + switch e := event.Result.(type) { + case *protocol.Message: + fmt.Println(" [Message Response:]") + printMessage(*e) + + case *protocol.Task: + taskID = e.ID + fmt.Printf(" [Task %s State: %s (%s)]\n", e.ID, e.Status.State, formatTimestamp(e.Status.Timestamp)) if e.Status.Message != nil { printMessage(*e.Status.Message) } - // Update the task state - finalTaskState = e.Status.State + case *protocol.TaskStatusUpdateEvent: + taskID = e.TaskID + fmt.Printf(" [Status Update: %s (%s)]\n", e.Status.State, formatTimestamp(e.Status.Timestamp)) + if e.Status.Message != nil { + printMessage(*e.Status.Message) + } // Handle final states and input-required state if e.Status.State == protocol.TaskStateInputRequired { - // This is not a final state, but we need to store it fmt.Println(" [Additional input required]") - return finalTaskState, finalArtifacts - } else if e.IsFinal() { - log.Printf("Final status received: %s", finalTaskState) + return taskID + } else if e.Final != nil && *e.Final { + log.Printf("Final status received: %s", e.Status.State) // Print a message indicating task completion state if e.Status.State == protocol.TaskStateCompleted { @@ -681,59 +745,27 @@ func processStreamResponse( } else if e.Status.State == protocol.TaskStateCanceled { fmt.Println(" [Task was canceled]") } - return finalTaskState, finalArtifacts + return taskID } - case protocol.TaskArtifactUpdateEvent: + case *protocol.TaskArtifactUpdateEvent: + taskID = e.TaskID // Get the artifact name or use a default name := getArtifactName(e.Artifact) - // Show if this is an append operation - if e.Artifact.Append != nil && *e.Artifact.Append { - fmt.Printf(" [Artifact Update: %s (Appending)]\n", name) - } else { - fmt.Printf(" [Artifact Update: %s]\n", name) - } + fmt.Printf(" [Artifact Update: %s]\n", name) // Print the artifact parts printParts(e.Artifact.Parts) - // Handle artifact storage for return value - if e.Artifact.Append != nil && *e.Artifact.Append && len(finalArtifacts) > 0 { - // Find existing artifact with same index to append to - for i, art := range finalArtifacts { - if art.Index == e.Artifact.Index { - // Append parts - combinedParts := append(art.Parts, e.Artifact.Parts...) - finalArtifacts[i].Parts = combinedParts - - // Update other fields if needed - if e.Artifact.Name != nil { - finalArtifacts[i].Name = e.Artifact.Name - } - if e.Artifact.Description != nil { - finalArtifacts[i].Description = e.Artifact.Description - } - if e.Artifact.LastChunk != nil { - finalArtifacts[i].LastChunk = e.Artifact.LastChunk - } - - // Break after updating - break - } - } - } else { - finalArtifacts = append(finalArtifacts, e.Artifact) - } - // For artifact updates, we note it's the final artifact, // but we don't exit yet - per A2A spec, we should wait for the final status update - if e.IsFinal() { - log.Printf("Final artifact received for index %d", e.Artifact.Index) + if e.LastChunk != nil && *e.LastChunk { + log.Printf("Final artifact received with ID %s", e.Artifact.ArtifactID) } default: - log.Printf("Warning: Received unknown event type: %T\n", event) + log.Printf("Warning: Received unknown event type: %T\n", event.Result) } } } @@ -744,44 +776,7 @@ func getArtifactName(artifact protocol.Artifact) string { if artifact.Name != nil { return *artifact.Name } - return fmt.Sprintf("Artifact #%d", artifact.Index+1) -} - -// getFinalTaskState fetches and displays the final task state. -func getFinalTaskState( - a2aClient *client.A2AClient, - taskID string, - timeout time.Duration, - streamState protocol.TaskState, - streamArtifacts []protocol.Artifact, - historyLength int, -) { - finalCtx, finalCancel := context.WithTimeout(context.Background(), timeout) - defer finalCancel() - - params := protocol.TaskQueryParams{ID: taskID} - if historyLength > 0 { - params.HistoryLength = &historyLength - } - - finalTask, getErr := a2aClient.GetTasks(finalCtx, params) - - fmt.Println(strings.Repeat("-", 10)) - fmt.Println("<< Final Result (from GetTask):") - - if getErr != nil { - log.Printf("ERROR: Failed to get final task state for %s: %v", taskID, getErr) - fmt.Printf(" State: %s (from stream)\n", streamState) - return - } - - if finalTask == nil { - log.Printf("WARNING: TasksGet for %s returned nil task without error.", taskID) - fmt.Printf(" State: %s (from stream)\n", streamState) - return - } - - displayFinalTaskState(finalTask) + return fmt.Sprintf("Artifact %s", artifact.ArtifactID) } // cancelTask attempts to cancel a running task. @@ -878,87 +873,26 @@ func printParts(parts []protocol.Part) { } } -// printPart handles the printing logic based on the concrete part type. -// It includes fallbacks for map[string]interface{} representations. +// printPart prints a single part with proper indentation. func printPart(part interface{}) { - indent := " " // Indentation for nested content. - - // Handle direct types from taskmanager first (preferred). + const indent = " " switch p := part.(type) { + case *protocol.TextPart: + fmt.Println(indent + p.Text) case protocol.TextPart: fmt.Println(indent + p.Text) - case protocol.FilePart: - printFilePart(p, indent) - case protocol.DataPart: - printDataPart(p, indent) case map[string]interface{}: - printMapPart(p, indent) - default: - fmt.Printf("%s[Unknown Part Type: %T]\n", indent, p) - } -} - -// printFilePart prints a file part. -func printFilePart(p protocol.FilePart, indent string) { - name := "(unnamed file)" - if p.File.Name != nil { - name = *p.File.Name - } - mime := "(unknown type)" - if p.File.MimeType != nil { - mime = *p.File.MimeType - } - fmt.Printf("%s[File: %s (%s)]\n", indent, name, mime) - if p.File.URI != nil { - fmt.Printf("%s URI: %s\n", indent, *p.File.URI) - } - if p.File.Bytes != nil { - fmt.Printf("%s Bytes: %d bytes\n", indent, len(*p.File.Bytes)) - } -} - -// printDataPart prints a data part. -func printDataPart(p protocol.DataPart, indent string) { - fmt.Printf("%s[Structured Data]\n", indent) - dataContent, err := json.MarshalIndent(p.Data, indent, " ") - if err == nil { - fmt.Printf("%s%s\n", indent, string(dataContent)) - } else { - fmt.Printf("%s Error marshaling data: %v\n", indent, err) - fmt.Printf("%s Raw: %+v\n", indent, p.Data) - } -} - -// printMapPart prints a part represented as a map. -func printMapPart(p map[string]interface{}, indent string) { - if typeStr, ok := p["type"].(string); ok { - switch typeStr { - case string(protocol.PartTypeText): + // Handle parts that come as maps (from JSON) + if typeStr, ok := p["type"].(string); ok && typeStr == "text" { if text, ok := p["text"].(string); ok { fmt.Println(indent + text) } - case string(protocol.PartTypeFile): - fmt.Printf("%s[File (from map)]\n", indent) - fileData, err := json.MarshalIndent(p["file"], indent, " ") - if err == nil { - fmt.Printf("%s%s\n", indent, string(fileData)) - } else if p["file"] != nil { - fmt.Printf("%s %+v\n", indent, p["file"]) - } - case string(protocol.PartTypeData): - fmt.Printf("%s[Structured Data (from map)]\n", indent) - dataContent, err := json.MarshalIndent(p["data"], indent, " ") - if err == nil { - fmt.Printf("%s%s\n", indent, string(dataContent)) - } else if p["data"] != nil { - fmt.Printf("%s %+v\n", indent, p["data"]) - } - default: - fmt.Printf("%s[Unknown map part type: %s]\n", indent, typeStr) + } else { + // For other types, just print the map + fmt.Printf("%s[Structured Content: %+v]\n", indent, p) } - } else { - mapData, _ := json.MarshalIndent(p, indent, " ") - fmt.Printf("%s[Unknown map structure]:\n%s%s\n", indent, indent, string(mapData)) + default: + fmt.Printf("%s[Unknown Part Type: %T]\n", indent, p) } } @@ -1134,9 +1068,9 @@ func setPushNotification( pushConfig.Token = *token } - // Create the task push notification configuration + // Create the task push notification configuration using TaskID field taskPushConfig := protocol.TaskPushNotificationConfig{ - ID: taskID, + TaskID: taskID, PushNotificationConfig: pushConfig, } @@ -1150,7 +1084,7 @@ func setPushNotification( // Display success fmt.Println("Push notification set successfully:") - fmt.Printf(" Task ID: %s\n", result.ID) + fmt.Printf(" Task ID: %s\n", result.TaskID) fmt.Printf(" URL: %s\n", result.PushNotificationConfig.URL) if result.PushNotificationConfig.Token != "" { fmt.Printf(" Token: %s\n", result.PushNotificationConfig.Token) @@ -1179,7 +1113,7 @@ func getPushNotification(a2aClient *client.A2AClient, taskID string, timeout tim // Display the push notification configuration fmt.Println("Push notification configuration:") - fmt.Printf(" Task ID: %s\n", result.ID) + fmt.Printf(" Task ID: %s\n", result.TaskID) fmt.Printf(" URL: %s\n", result.PushNotificationConfig.URL) if result.PushNotificationConfig.Token != "" { fmt.Printf(" Token: %s\n", result.PushNotificationConfig.Token) diff --git a/examples/basic/server/main.go b/examples/basic/server/main.go index e0ed9b1..926a9b8 100644 --- a/examples/basic/server/main.go +++ b/examples/basic/server/main.go @@ -14,7 +14,6 @@ import ( "flag" "fmt" "io" - "log" "net/http" "os" "os/signal" @@ -22,6 +21,7 @@ import ( "syscall" "time" + "trpc.group/trpc-go/trpc-a2a-go/log" "trpc.group/trpc-go/trpc-a2a-go/protocol" "trpc.group/trpc-go/trpc-a2a-go/server" "trpc.group/trpc-go/trpc-a2a-go/taskmanager" @@ -32,21 +32,12 @@ const ( modeReverse = "reverse" modeUppercase = "uppercase" modeLowercase = "lowercase" - modeWordCount = "count" - modeMultiStep = "multi" + modeCount = "count" modeHelp = "help" + modeMultiStep = "multi" modeInputExample = "example" ) -// basicTaskProcessor implements the taskmanager.TaskProcessor interface -type basicTaskProcessor struct { - // Map to track multi-turn conversations - multiTurnSessions map[string]multiTurnSession - - // Flag to determine if we should use streaming mode - useStreaming bool -} - // multiTurnSession tracks state for a multi-turn interaction type multiTurnSession struct { stage int @@ -55,14 +46,23 @@ type multiTurnSession struct { complete bool } -// Process implements the taskmanager.TaskProcessor interface -func (p *basicTaskProcessor) Process( +// basicMessageProcessor implements the taskmanager.MessageProcessor interface +type basicMessageProcessor struct { + // Flag to determine if we should use streaming mode + useStreaming bool + + // Added for multi-turn session handling - now keyed by contextID instead of taskID + multiTurnSessions map[string]multiTurnSession +} + +// ProcessMessage implements the taskmanager.MessageProcessor interface +func (p *basicMessageProcessor) ProcessMessage( ctx context.Context, - taskID string, message protocol.Message, - handle taskmanager.TaskHandle, -) error { - log.Printf("Processing task %s...", taskID) + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { + log.Infof("Processing basic message with ID: %s", message.MessageID) // Initialize multi-turn sessions map if not already initialized if p.multiTurnSessions == nil { @@ -73,400 +73,174 @@ func (p *basicTaskProcessor) Process( text := extractText(message) if text == "" { errMsg := "input message must contain text" - log.Printf("Task %s failed: %s", taskID, errMsg) - // Update status to Failed via handle - failedMessage := protocol.NewMessage( + log.Errorf("Message processing failed: %s", errMsg) + + // Return error message directly + errorMessage := protocol.NewMessage( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(errMsg)}, ) - _ = handle.UpdateStatus(protocol.TaskStateFailed, &failedMessage) - return fmt.Errorf(errMsg) + return &taskmanager.MessageProcessingResult{ + Result: &errorMessage, + }, nil } - // Check for continuation of a multi-turn session - session, exists := p.multiTurnSessions[taskID] - - if exists && !session.complete { - return p.handleMultiTurnSession(ctx, taskID, text, handle, session) + // Get context ID for session management + contextID := handle.GetContextID() + if contextID == "" { + // No context ID available, treat as simple command processing + return p.processSimpleCommand(text) } - // New interaction - determine mode and process accordingly - return p.handleNewInteraction(ctx, taskID, text, handle) -} - -// handleMultiTurnSession processes the next step in a multi-turn interaction -func (p *basicTaskProcessor) handleMultiTurnSession( - ctx context.Context, - taskID string, - text string, - handle taskmanager.TaskHandle, - session multiTurnSession, -) error { - // Update session with new input - switch session.stage { - case 1: - // First response received - this is the mode - session.mode = strings.ToLower(strings.TrimSpace(text)) - session.stage = 2 - - // Ask for the text to process - msg := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart("Please enter the text you want to process:")}, - ) - - if err := handle.UpdateStatus(protocol.TaskStateInputRequired, &msg); err != nil { - return fmt.Errorf("failed to update task status: %w", err) + if options.Streaming { + // Streaming mode - use task-based processing with full features + taskID, err := handle.BuildTask(nil, &contextID) + if err != nil { + return nil, fmt.Errorf("failed to build task: %w", err) } - // Store updated session - p.multiTurnSessions[taskID] = session - return nil + log.Infof("Created streaming task %s for processing", taskID) - case 2: - // Second response received - this is the text to process - session.text = text - session.stage = 3 - session.complete = true - - // Process the text based on the selected mode - result := p.processTextWithMode(session.text, session.mode) - - // Create the completed message - finalMsg := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart(result)}, - ) - - // Check if this is a streaming request - isStreaming := handle.IsStreamingRequest() - if !isStreaming && p.useStreaming { - // Fall back to the flag for backward compatibility - isStreaming = true - } - - // Use streaming if enabled - if isStreaming { - // Send intermediate status update - inProgressMsg := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart("Processing your request...")}, - ) - if err := handle.UpdateStatus(protocol.TaskStateWorking, &inProgressMsg); err != nil { - log.Printf("Error sending intermediate status: %v", err) - } - - // Simulate processing delay - time.Sleep(500 * time.Millisecond) + // Subscribe to the task for streaming events + subscriber, err := handle.SubScribeTask(&taskID) + if err != nil { + return nil, fmt.Errorf("failed to subscribe to task: %w", err) } - // Update task status to completed - if err := handle.UpdateStatus(protocol.TaskStateCompleted, &finalMsg); err != nil { - return fmt.Errorf("failed to update final task status: %w", err) - } + // Process asynchronously with full multi-turn support + go p.processMessageAsync(ctx, text, contextID, taskID, handle) - // Add the artifact - artifact := protocol.Artifact{ - Name: stringPtr("Processed Text"), - Description: stringPtr(fmt.Sprintf("Text processed with mode: %s", session.mode)), - Index: 0, - Parts: []protocol.Part{protocol.NewTextPart(result)}, - LastChunk: boolPtr(true), - } - - if err := handle.AddArtifact(artifact); err != nil { - log.Printf("Error adding artifact for task %s: %v", taskID, err) - } - - // Update session in map - p.multiTurnSessions[taskID] = session - return nil + return &taskmanager.MessageProcessingResult{ + StreamingEvents: subscriber, + }, nil } + // Non-streaming mode - check for multi-turn interactions + session, exists := p.multiTurnSessions[contextID] - return fmt.Errorf("unexpected stage in multi-turn session: %d", session.stage) -} - -// handleNewInteraction processes a new interaction -func (p *basicTaskProcessor) handleNewInteraction( - ctx context.Context, - taskID string, - text string, - handle taskmanager.TaskHandle, -) error { - // Check for cancellation via context - if err := ctx.Err(); err != nil { - log.Printf("Task %s cancelled during processing: %v", taskID, err) - _ = handle.UpdateStatus(protocol.TaskStateCanceled, nil) - return err + if exists && !session.complete { + // Continue existing multi-turn session + return p.processMultiTurnSession(ctx, text, contextID, handle, session, options) } - // Parse the first word as the command + // New interaction - check if it requires multi-turn handling parts := strings.SplitN(text, " ", 2) command := strings.ToLower(parts[0]) - // Handle multi-step mode - if command == modeMultiStep { - session := multiTurnSession{ - stage: 1, - complete: false, + if command == modeMultiStep || command == modeInputExample { + // Create task for multi-turn interaction even in non-streaming mode + taskID, err := handle.BuildTask(nil, &contextID) + if err != nil { + return nil, fmt.Errorf("failed to build task: %w", err) } - // Store the session - p.multiTurnSessions[taskID] = session + go p.processMessageAsync(ctx, text, contextID, taskID, handle) - // Ask for the processing mode - msg := protocol.NewMessage( + // Return a message indicating async processing + responseMessage := protocol.NewMessage( protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart( - "This is a multi-step interaction. Please select a processing mode:\n" + - "- reverse: Reverses the text\n" + - "- uppercase: Converts text to uppercase\n" + - "- lowercase: Converts text to lowercase\n" + - "- count: Counts words and characters")}, + []protocol.Part{protocol.NewTextPart("Multi-turn interaction started. Please continue the conversation.")}, ) - - if err := handle.UpdateStatus(protocol.TaskStateInputRequired, &msg); err != nil { - return fmt.Errorf("failed to update task status: %w", err) - } - - return nil + return &taskmanager.MessageProcessingResult{ + Result: &responseMessage, + }, nil } - // Handle example input-required state - if command == modeInputExample { - msg := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart("Please provide more information to continue:")}, - ) - - if err := handle.UpdateStatus(protocol.TaskStateInputRequired, &msg); err != nil { - return fmt.Errorf("failed to update task status: %w", err) - } + // Simple command processing + return p.processSimpleCommand(text) +} - // Create a session for the follow-up - session := multiTurnSession{ - stage: 2, // Skip to stage 2 (text input) - mode: modeReverse, // Default to reverse mode - complete: false, - } - p.multiTurnSessions[taskID] = session - return nil - } +// processSimpleCommand handles direct command processing without tasks +func (p *basicMessageProcessor) processSimpleCommand(text string) (*taskmanager.MessageProcessingResult, error) { + parts := strings.SplitN(text, " ", 2) + command := strings.ToLower(parts[0]) - // For direct processing (non-multi-turn), extract the rest as content var content string if len(parts) > 1 { content = parts[1] - } else { - content = "" // No content provided, command only - } - - // Check if request is streaming using the new interface method - // Fall back to the flag if needed for backward compatibility - isStreaming := handle.IsStreamingRequest() - if !isStreaming && p.useStreaming { - // If the request itself isn't streaming but the processor is configured to use streaming, - // use streaming mode anyway (for backward compatibility) - isStreaming = true - } - - // Process in streaming or non-streaming mode based on the request type - if isStreaming { - return p.processWithStreaming(ctx, taskID, command, content, handle) - } else { - return p.processDirectly(ctx, taskID, command, content, handle) - } -} - -// processWithStreaming handles the processing with intermediate updates -func (p *basicTaskProcessor) processWithStreaming( - ctx context.Context, - taskID string, - command string, - content string, - handle taskmanager.TaskHandle, -) error { - // Send initial "working" status - workingMsg := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart("Processing your request...")}, - ) - - if err := handle.UpdateStatus(protocol.TaskStateWorking, &workingMsg); err != nil { - return fmt.Errorf("failed to update working status: %w", err) } - // Simulate processing delay - time.Sleep(500 * time.Millisecond) - // Process the content based on command result := p.processTextWithMode(content, command) - - // Add artifact in chunks to demonstrate streaming - // Split at a complete line to avoid breaking words/sentences - splitIndex := findSplitIndex(result) - - // First chunk - artifact1 := protocol.Artifact{ - Name: stringPtr("Processed Text (Part 1)"), - Description: stringPtr("First part of the processed text"), - Index: 0, - Parts: []protocol.Part{protocol.NewTextPart(result[:splitIndex])}, - LastChunk: boolPtr(false), - } - - if err := handle.AddArtifact(artifact1); err != nil { - log.Printf("Error adding first artifact chunk for task %s: %v", taskID, err) - } - - // Small delay to demonstrate streaming - time.Sleep(300 * time.Millisecond) - - // Second chunk (appends to first) - artifact2 := protocol.Artifact{ - Name: stringPtr("Processed Text (Complete)"), - Description: stringPtr("Complete processed text"), - Index: 0, // Same index as first chunk - Append: boolPtr(true), - Parts: []protocol.Part{protocol.NewTextPart(result[splitIndex:])}, - LastChunk: boolPtr(true), - } - - if err := handle.AddArtifact(artifact2); err != nil { - log.Printf("Error adding second artifact chunk for task %s: %v", taskID, err) - } - - // Create final message - finalMsg := protocol.NewMessage( + responseMessage := protocol.NewMessage( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(result)}, ) - // Update task to completed - if err := handle.UpdateStatus(protocol.TaskStateCompleted, &finalMsg); err != nil { - return fmt.Errorf("failed to update final status: %w", err) - } - return nil -} - -// findSplitIndex finds a good place to split text without breaking words or lines -func findSplitIndex(text string) int { - // If text is short, don't split - if len(text) < 20 { - return len(text) - } - // Start at halfway - splitIndex := len(text) / 2 - - // Try to find a newline near the middle - for i := splitIndex; i < len(text); i++ { - if text[i] == '\n' { - return i + 1 // Include the newline in the first chunk - } - } - - // If no newline, try to find space near the middle - for i := splitIndex; i < len(text); i++ { - if text[i] == ' ' { - return i + 1 // Include the space in the first chunk - } - } - - // If no good natural boundary found, just use halfway point - return splitIndex + return &taskmanager.MessageProcessingResult{ + Result: &responseMessage, + }, nil } -// processDirectly handles the processing without intermediate updates -func (p *basicTaskProcessor) processDirectly( +// processMultiTurnSession handles continuing a multi-turn session in non-streaming mode +func (p *basicMessageProcessor) processMultiTurnSession( ctx context.Context, - taskID string, - command string, - content string, - handle taskmanager.TaskHandle, -) error { - // Process the content based on command - result := p.processTextWithMode(content, command) - - // Create the response message - finalMsg := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart(result)}, - ) - - // Update task to completed - if err := handle.UpdateStatus(protocol.TaskStateCompleted, &finalMsg); err != nil { - return fmt.Errorf("failed to update final status: %w", err) - } - - // Add artifact - artifact := protocol.Artifact{ - Name: stringPtr("Processed Text"), - Description: stringPtr(fmt.Sprintf("Text processed with mode: %s", command)), - Index: 0, - Parts: []protocol.Part{protocol.NewTextPart(result)}, - LastChunk: boolPtr(true), + text string, + contextID string, + handle taskmanager.TaskHandler, + session multiTurnSession, + options taskmanager.ProcessOptions, +) (*taskmanager.MessageProcessingResult, error) { + // Create task for multi-turn processing + taskID, err := handle.BuildTask(nil, &contextID) + if err != nil { + return nil, fmt.Errorf("failed to build task: %w", err) } - if err := handle.AddArtifact(artifact); err != nil { - log.Printf("Error adding artifact for task %s: %v", taskID, err) - } + go p.handleMultiTurnSessionAsync(ctx, taskID, text, contextID, handle, session) - return nil + // Return message indicating processing + responseMessage := protocol.NewMessage( + protocol.MessageRoleAgent, + []protocol.Part{protocol.NewTextPart("Processing your multi-turn request...")}, + ) + return &taskmanager.MessageProcessingResult{ + Result: &responseMessage, + }, nil } // processTextWithMode processes text with the specified mode -func (p *basicTaskProcessor) processTextWithMode(text, mode string) string { - switch strings.ToLower(mode) { +func (p *basicMessageProcessor) processTextWithMode(text, mode string) string { + switch mode { case modeReverse: - return fmt.Sprintf("Input: %s\nReversed: %s", text, reverseString(text)) + // Simple text reversal + runes := []rune(text) + for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 { + runes[i], runes[j] = runes[j], runes[i] + } + return fmt.Sprintf("Reversed: %s", string(runes)) case modeUppercase: - return fmt.Sprintf("Input: %s\nUppercase: %s", text, strings.ToUpper(text)) + return strings.ToUpper(text) case modeLowercase: - return fmt.Sprintf("Input: %s\nLowercase: %s", text, strings.ToLower(text)) - case modeWordCount: + return strings.ToLower(text) + case modeCount: words := len(strings.Fields(text)) chars := len(text) - return fmt.Sprintf("Input: %s\nWord count: %d\nCharacter count: %d", text, words, chars) + return fmt.Sprintf("Word count: %d, Character count: %d", words, chars) case modeHelp: - return "Available commands (type one of these):\n" + - "- reverse : Reverses the input text\n Example: reverse hello world\n\n" + - "- uppercase : Converts text to uppercase\n Example: uppercase hello world\n\n" + - "- lowercase : Converts text to lowercase\n Example: lowercase HELLO WORLD\n\n" + - "- count : Counts words and characters\n Example: count this is a test\n\n" + - "- multi: Start a multi-step interaction\n\n" + - "- example: Demonstrates input-required state\n\n" + - "- help: Shows this help message" + return `Available commands: +- reverse : Reverse the given text +- uppercase : Convert text to uppercase +- lowercase : Convert text to lowercase +- count : Count words and characters +- multi-step: Start a multi-turn interaction +- input-example: Example of input-required state +- help: Show this help message + +Example: reverse hello world` default: - // Default to help if command not recognized - if text == "" { - return "Please provide a command. Type 'help' for available commands.\n\nExample commands:\n" + - "- reverse hello world\n" + - "- uppercase hello\n" + - "- count these words" - } - return fmt.Sprintf("Unknown command '%s'. Type 'help' for available commands.\n\nAssuming 'reverse' mode:\nReversed: %s", - mode, reverseString(text)) + return fmt.Sprintf("Unknown mode '%s'. Use 'help' for available commands.", mode) } } -// extractText extracts the first text part from a message +// extractText extracts text content from a message func extractText(message protocol.Message) string { + var parts []string for _, part := range message.Parts { - // Type assert to the concrete TextPart type - if p, ok := part.(protocol.TextPart); ok { - return p.Text + if textPart, ok := part.(*protocol.TextPart); ok { + parts = append(parts, textPart.Text) } } - return "" -} - -// reverseString reverses a UTF-8 encoded string -func reverseString(s string) string { - runes := []rune(s) - for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 { - runes[i], runes[j] = runes[j], runes[i] - } - return string(runes) + return strings.Join(parts, " ") } // main is the entry point for the server @@ -491,23 +265,29 @@ func main() { // Assuming HTTP for simplicity, HTTPS is recommended for production serverURL := fmt.Sprintf("http://%s/", address) + // Description based on streaming capability + description += " with streaming support" + if !forceNoStream { + description += " and push notifications" + } + // Create the agent card using types from the server package agentCard := server.AgentCard{ Name: "Text Processing Agent", - Description: stringPtr(description), + Description: description, URL: serverURL, Version: "2.0.0", // Updated version Provider: &server.AgentProvider{ Organization: "tRPC-A2A-go Examples", }, Capabilities: server.AgentCapabilities{ - Streaming: !forceNoStream, // Support streaming based on flag - PushNotifications: true, // Enable push notifications - StateTransitionHistory: true, // MemoryTaskManager stores history + Streaming: boolPtr(!forceNoStream), // Support streaming based on flag + PushNotifications: boolPtr(true), // Enable push notifications + StateTransitionHistory: boolPtr(true), // MemoryTaskManager stores history }, // Support text input/output - DefaultInputModes: []string{string(protocol.PartTypeText)}, - DefaultOutputModes: []string{string(protocol.PartTypeText)}, + DefaultInputModes: []string{"text"}, + DefaultOutputModes: []string{"text"}, Skills: []server.AgentSkill{ { ID: "text_processor_reverse", @@ -515,8 +295,8 @@ func main() { Description: stringPtr("Input: reverse hello\nOutput: Reversed: olleh"), Tags: []string{"text", "reverse"}, Examples: []string{"reverse hello world", "reverse The quick brown fox"}, - InputModes: []string{string(protocol.PartTypeText)}, - OutputModes: []string{string(protocol.PartTypeText)}, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, { ID: "text_processor_uppercase", @@ -524,8 +304,8 @@ func main() { Description: stringPtr("Input: uppercase hello world\nOutput: HELLO WORLD"), Tags: []string{"text", "uppercase"}, Examples: []string{"uppercase hello world", "uppercase Example text"}, - InputModes: []string{string(protocol.PartTypeText)}, - OutputModes: []string{string(protocol.PartTypeText)}, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, { ID: "text_processor_lowercase", @@ -533,8 +313,8 @@ func main() { Description: stringPtr("Input: lowercase HELLO\nOutput: hello"), Tags: []string{"text", "lowercase"}, Examples: []string{"lowercase HELLO WORLD", "lowercase TEXT"}, - InputModes: []string{string(protocol.PartTypeText)}, - OutputModes: []string{string(protocol.PartTypeText)}, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, { ID: "text_processor_count", @@ -542,17 +322,17 @@ func main() { Description: stringPtr("Input: count hello world\nOutput: Word count: 2, Character count: 11"), Tags: []string{"text", "count"}, Examples: []string{"count The quick brown fox", "count hello world"}, - InputModes: []string{string(protocol.PartTypeText)}, - OutputModes: []string{string(protocol.PartTypeText)}, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, { - ID: "text_processor_multi", - Name: "Multi-step Process", - Description: stringPtr("Input: multi\nOutput: Interactive conversation requesting processing mode then text"), - Tags: []string{"text", "interactive"}, - Examples: []string{"multi"}, - InputModes: []string{string(protocol.PartTypeText)}, - OutputModes: []string{string(protocol.PartTypeText)}, + ID: "text_processor_multistep", + Name: "Multi-Step Processor", + Description: stringPtr("Input: multi-step\nStarts an interactive multi-turn conversation"), + Tags: []string{"interactive", "multi-turn"}, + Examples: []string{"multi-step"}, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, { ID: "text_processor_help", @@ -560,16 +340,16 @@ func main() { Description: stringPtr("Input: help\nOutput: List of available commands and usage"), Tags: []string{"help"}, Examples: []string{"help"}, - InputModes: []string{string(protocol.PartTypeText)}, - OutputModes: []string{string(protocol.PartTypeText)}, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, }, } - // Create the TaskProcessor (agent logic) - processor := &basicTaskProcessor{ - multiTurnSessions: make(map[string]multiTurnSession), + // Create the MessageProcessor (agent logic) + processor := &basicMessageProcessor{ useStreaming: !forceNoStream, + multiTurnSessions: make(map[string]multiTurnSession), } // Create the base TaskManager with built-in push notification storage support @@ -594,8 +374,8 @@ func main() { // Start the server in a separate goroutine go func() { // Use log.Printf for informational message, not Fatal - log.Printf("Text Processing Agent server starting on %s (CORS enabled: %t, Streaming: %t, Push: enabled)", - address, !noCORS, !forceNoStream) + log.Infof("Text Processing Agent server starting on %s (CORS enabled: %t, Streaming: %t, Push: %t)", + address, !noCORS, !forceNoStream, true) if err := srv.Start(address); err != nil { // Fatalf will exit the program if the server fails to start log.Fatalf("Server failed to start: %v", err) @@ -604,7 +384,7 @@ func main() { // Wait for an interrupt or termination signal <-sigChan - log.Println("Shutdown signal received, initiating graceful shutdown...") + log.Info("Shutdown signal received, initiating graceful shutdown...") // Create a context with a timeout for graceful shutdown // Allow 10 seconds for existing requests to finish @@ -613,10 +393,10 @@ func main() { // Attempt to stop the server gracefully if err := srv.Stop(ctx); err != nil { - log.Fatalf("Server shutdown failed: %v", err) + log.Errorf("Server shutdown failed: %v", err) + } else { + log.Info("Server exited gracefully.") } - - log.Println("Server exited gracefully.") } // Helper function to create a string pointer @@ -643,21 +423,53 @@ func newPushNotificationSender(base taskmanager.TaskManager) *pushNotificationSe } } -// OnSendTask overrides the TaskManager.OnSendTask method to add webhook notification -func (p *pushNotificationSender) OnSendTask( +// OnSendMessage overrides to add webhook notification +func (p *pushNotificationSender) OnSendMessage( ctx context.Context, - params protocol.SendTaskParams, -) (*protocol.Task, error) { + params protocol.SendMessageParams, +) (*protocol.MessageResult, error) { // Call the underlying implementation - task, err := p.TaskManager.OnSendTask(ctx, params) - if err == nil && task != nil { - // Send push notification if task was created successfully - go p.maybeSendStatusPushNotification(ctx, task.ID, task.Status.State) + result, err := p.TaskManager.OnSendMessage(ctx, params) + if err == nil && result != nil { + // Check if result is a task and send notification if needed + if task, ok := result.Result.(*protocol.Task); ok { + go p.maybeSendStatusPushNotification(ctx, task.ID, task.Status.State) + } } - return task, err + return result, err } -// OnCancelTask overrides the TaskManager.OnCancelTask method to add webhook notification +// OnSendMessageStream overrides to add webhook notification +func (p *pushNotificationSender) OnSendMessageStream( + ctx context.Context, + params protocol.SendMessageParams, +) (<-chan protocol.StreamingMessageEvent, error) { + // Call the underlying implementation + eventChan, err := p.TaskManager.OnSendMessageStream(ctx, params) + if err != nil { + return nil, err + } + + // Create a wrapper channel to monitor events and send notifications + wrappedChan := make(chan protocol.StreamingMessageEvent) + + go func() { + defer close(wrappedChan) + for event := range eventChan { + // Forward the event + wrappedChan <- event + + // Check if it's a task status update and send notification + if statusEvent, ok := event.Result.(*protocol.TaskStatusUpdateEvent); ok { + go p.maybeSendStatusPushNotification(ctx, statusEvent.TaskID, statusEvent.Status.State) + } + } + }() + + return wrappedChan, nil +} + +// OnCancelTask overrides to add webhook notification (keep existing deprecated method support) func (p *pushNotificationSender) OnCancelTask( ctx context.Context, params protocol.TaskIDParams, @@ -704,14 +516,14 @@ func (p *pushNotificationSender) sendPushNotification( // Create JSON payload jsonData, err := json.Marshal(payload) if err != nil { - log.Printf("Error marshaling push notification: %v", err) + log.Errorf("Error marshaling push notification: %v", err) return } // Create HTTP request req, err := http.NewRequest(http.MethodPost, config.URL, bytes.NewBuffer(jsonData)) if err != nil { - log.Printf("Error creating push notification request: %v", err) + log.Errorf("Error creating push notification request: %v", err) return } @@ -728,7 +540,7 @@ func (p *pushNotificationSender) sendPushNotification( client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Do(req) if err != nil { - log.Printf("Error sending push notification: %v", err) + log.Errorf("Error sending push notification: %v", err) return } defer resp.Body.Close() @@ -736,9 +548,232 @@ func (p *pushNotificationSender) sendPushNotification( // Check response if resp.StatusCode >= 400 { body, _ := io.ReadAll(resp.Body) - log.Printf("Push notification failed with status %d: %s", resp.StatusCode, string(body)) + log.Errorf("Push notification failed with status %d: %s", resp.StatusCode, string(body)) + return + } + + log.Infof("Push notification sent successfully to %s", config.URL) +} + +// processMessageAsync handles message processing asynchronously with full task management +func (p *basicMessageProcessor) processMessageAsync( + ctx context.Context, + text string, + contextID string, + taskID string, + handle taskmanager.TaskHandler, +) { + // Update task to working state + err := handle.UpdateTaskState(&taskID, protocol.TaskStateWorking, nil) + if err != nil { + log.Errorf("Failed to update task state: %v", err) return } - log.Printf("Push notification sent successfully to %s", config.URL) + // Check for continuation of a multi-turn session using contextID + session, exists := p.multiTurnSessions[contextID] + + if exists && !session.complete { + p.handleMultiTurnSessionAsync(ctx, taskID, text, contextID, handle, session) + return + } + + // New interaction - determine mode and process accordingly + p.handleNewInteractionAsync(ctx, taskID, text, contextID, handle) +} + +// handleMultiTurnSessionAsync processes the next step in a multi-turn interaction +func (p *basicMessageProcessor) handleMultiTurnSessionAsync( + ctx context.Context, + taskID string, + text string, + contextID string, + handle taskmanager.TaskHandler, + session multiTurnSession, +) { + // Update session with new input + switch session.stage { + case 1: + // First response received - this is the mode + session.mode = strings.ToLower(strings.TrimSpace(text)) + session.stage = 2 + + // Ask for the text to process + msg := protocol.NewMessage( + protocol.MessageRoleAgent, + []protocol.Part{protocol.NewTextPart("Please enter the text you want to process:")}, + ) + + // Update task to input-required state + err := handle.UpdateTaskState(&taskID, protocol.TaskStateInputRequired, &msg) + if err != nil { + log.Errorf("Failed to update task status: %v", err) + return + } + + // Store updated session using contextID + p.multiTurnSessions[contextID] = session + + case 2: + // Second response received - this is the text to process + session.text = text + session.stage = 3 + session.complete = true + + // Process the text based on the selected mode + result := p.processTextWithMode(session.text, session.mode) + + // Create the completed message + finalMsg := protocol.NewMessage( + protocol.MessageRoleAgent, + []protocol.Part{protocol.NewTextPart(result)}, + ) + + // Add artifact with the processed result + artifact := protocol.Artifact{ + ArtifactID: "processed-text-" + taskID, + Name: stringPtr("Processed Text"), + Description: stringPtr(fmt.Sprintf("Text processed with mode: %s", session.mode)), + Parts: []protocol.Part{protocol.NewTextPart(result)}, + Metadata: map[string]interface{}{ + "operation": session.mode, + "originalText": session.text, + "processedAt": time.Now().UTC().Format(time.RFC3339), + "sessionStage": session.stage, + "contextID": contextID, + }, + } + + // Add artifact to task + if err := handle.AddArtifact(&taskID, artifact, true, false); err != nil { + log.Errorf("Failed to add artifact: %v", err) + } + + // Update task to completed state + err := handle.UpdateTaskState(&taskID, protocol.TaskStateCompleted, &finalMsg) + if err != nil { + log.Errorf("Failed to complete task: %v", err) + } + + // Update session in map + p.multiTurnSessions[contextID] = session + } +} + +// handleNewInteractionAsync processes a new interaction +func (p *basicMessageProcessor) handleNewInteractionAsync( + ctx context.Context, + taskID string, + text string, + contextID string, + handle taskmanager.TaskHandler, +) { + // Check for cancellation via context + if err := ctx.Err(); err != nil { + log.Errorf("Task %s cancelled during processing: %v", taskID, err) + _ = handle.UpdateTaskState(&taskID, protocol.TaskStateCanceled, nil) + return + } + + // Parse the first word as the command + parts := strings.SplitN(text, " ", 2) + command := strings.ToLower(parts[0]) + + // Handle multi-step mode + if command == modeMultiStep { + session := multiTurnSession{ + stage: 1, + complete: false, + } + + // Store the session using contextID + p.multiTurnSessions[contextID] = session + + // Ask for the processing mode + msg := protocol.NewMessage( + protocol.MessageRoleAgent, + []protocol.Part{protocol.NewTextPart( + "This is a multi-step interaction. Please select a processing mode:\n" + + "- reverse: Reverses the text\n" + + "- uppercase: Converts text to uppercase\n" + + "- lowercase: Converts text to lowercase\n" + + "- count: Counts words and characters")}, + ) + + // Update task to input-required state + err := handle.UpdateTaskState(&taskID, protocol.TaskStateInputRequired, &msg) + if err != nil { + log.Errorf("Failed to update task status: %v", err) + } + return + } + + // Handle example input-required state + if command == modeInputExample { + msg := protocol.NewMessage( + protocol.MessageRoleAgent, + []protocol.Part{protocol.NewTextPart("Please provide more information to continue:")}, + ) + + // Update task to input-required state + err := handle.UpdateTaskState(&taskID, protocol.TaskStateInputRequired, &msg) + if err != nil { + log.Errorf("Failed to update task status: %v", err) + } + + // Create a session for the follow-up + session := multiTurnSession{ + stage: 2, // Skip to stage 2 (text input) + mode: modeReverse, // Default to reverse mode + complete: false, + } + p.multiTurnSessions[contextID] = session + return + } + + // For direct processing (non-multi-turn), extract the rest as content + var content string + if len(parts) > 1 { + content = parts[1] + } else { + content = "" // No content provided, command only + } + + // Simulate processing delay for demonstration + time.Sleep(500 * time.Millisecond) + + // Process the content based on command + result := p.processTextWithMode(content, command) + + // Create artifact with the processed result + artifact := protocol.Artifact{ + ArtifactID: "processed-text-" + taskID, + Name: stringPtr("Processed Text"), + Description: stringPtr(fmt.Sprintf("Text processed with mode: %s", command)), + Parts: []protocol.Part{protocol.NewTextPart(result)}, + Metadata: map[string]interface{}{ + "operation": command, + "originalText": content, + "processedAt": time.Now().UTC().Format(time.RFC3339), + "directMode": true, + "contextID": contextID, + }, + } + + // Add artifact to task + if err := handle.AddArtifact(&taskID, artifact, true, false); err != nil { + log.Errorf("Failed to add artifact: %v", err) + } + + // Create final message + finalMsg := protocol.NewMessage( + protocol.MessageRoleAgent, + []protocol.Part{protocol.NewTextPart(result)}, + ) + + // Update task to completed state + err := handle.UpdateTaskState(&taskID, protocol.TaskStateCompleted, &finalMsg) + if err != nil { + log.Errorf("Failed to complete task: %v", err) + } } diff --git a/examples/go.mod b/examples/go.mod index d70a3f1..abdf1c3 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -5,14 +5,17 @@ go 1.23.0 toolchain go1.23.7 require ( - github.com/google/uuid v1.3.1 + github.com/google/uuid v1.6.0 github.com/lestrrat-go/jwx/v2 v2.1.4 + github.com/redis/go-redis/v9 v9.10.0 golang.org/x/oauth2 v0.29.0 trpc.group/trpc-go/trpc-a2a-go v0.0.0 ) require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/lestrrat-go/blackmagic v1.0.2 // indirect diff --git a/examples/go.sum b/examples/go.sum index 6f47863..9311cfc 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -1,16 +1,24 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= -github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= @@ -25,6 +33,8 @@ github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNB github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.10.0 h1:FxwK3eV8p/CQa0Ch276C7u2d0eNC9kCmAYQ7mCXCzVs= +github.com/redis/go-redis/v9 v9.10.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/examples/jwks/client/main.go b/examples/jwks/client/main.go index b71b607..9149c80 100644 --- a/examples/jwks/client/main.go +++ b/examples/jwks/client/main.go @@ -584,129 +584,126 @@ func startWebhookServer(cfg *Config, handler http.Handler) { }() } -func main() { - // Parse command line flags - cfg := Config{ - ServerHost: defaultServerHost, - ServerPort: defaultServerPort, - WebhookHost: defaultWebhookHost, - WebhookPort: defaultWebhookPort, - WebhookPath: defaultWebhookPath, +// sendTaskToServer demonstrates sending a task using the new message API +func sendTaskToServer(ctx context.Context, a2aClient *client.A2AClient, content string) (string, error) { + log.Infof("Sending message with content: %s", content) + + // Create a message with the content + message := protocol.NewMessage( + protocol.MessageRoleUser, + []protocol.Part{protocol.NewTextPart(content)}, + ) + + // Create send message parameters + params := protocol.SendMessageParams{ + Message: message, } + // Send the message + result, err := a2aClient.SendMessage(ctx, params) + if err != nil { + return "", fmt.Errorf("failed to send message: %w", err) + } + + // Check the result type + switch response := result.Result.(type) { + case *protocol.Message: + log.Infof("Received direct message response") + return "", nil // No task ID for direct message responses + + case *protocol.Task: + log.Infof("Task created: %s (State: %s)", response.ID, response.Status.State) + return response.ID, nil + + default: + return "", fmt.Errorf("unexpected response type: %T", response) + } +} + +func main() { + // Parse command line flags + var cfg Config flag.StringVar(&cfg.ServerHost, "server-host", defaultServerHost, "A2A server host") flag.IntVar(&cfg.ServerPort, "server-port", defaultServerPort, "A2A server port") flag.StringVar(&cfg.WebhookHost, "webhook-host", defaultWebhookHost, "Webhook server host") flag.IntVar(&cfg.WebhookPort, "webhook-port", defaultWebhookPort, "Webhook server port") - flag.StringVar(&cfg.WebhookPath, "webhook-path", defaultWebhookPath, "Webhook endpoint path") + flag.StringVar(&cfg.WebhookPath, "webhook-path", defaultWebhookPath, "Webhook path") + flag.StringVar(&cfg.JWKSEndpoint, "jwks-endpoint", "", "JWKS endpoint (default: derived from server)") flag.Parse() - // Set JWKS endpoint URL - cfg.JWKSEndpoint = fmt.Sprintf("http://%s:%d/.well-known/jwks.json", - cfg.ServerHost, cfg.ServerPort) - - // Construct webhook URL - webhookURL := fmt.Sprintf("http://%s:%d%s", - cfg.WebhookHost, cfg.WebhookPort, cfg.WebhookPath) - - // Create webhook handler - handler := NewWebhookHandler(cfg.JWKSEndpoint) + // Set default JWKS endpoint if not provided + if cfg.JWKSEndpoint == "" { + cfg.JWKSEndpoint = fmt.Sprintf("http://%s:%d/.well-known/jwks.json", cfg.ServerHost, cfg.ServerPort) + } - // Start webhook server - startWebhookServer(&cfg, handler) + log.Infof("Starting JWKS client with config: %+v", cfg) - // Create A2A client - a2aClient, err := client.NewA2AClient( - fmt.Sprintf("http://%s:%d/", cfg.ServerHost, cfg.ServerPort), - client.WithTimeout(30*time.Second), - ) + // Create an A2A client + serverURL := fmt.Sprintf("http://%s:%d/", cfg.ServerHost, cfg.ServerPort) + a2aClient, err := client.NewA2AClient(serverURL) if err != nil { log.Fatalf("Failed to create A2A client: %v", err) } - // Generate task ID - taskID := fmt.Sprintf("task-%d", time.Now().Unix()) - log.Infof("Task ID: %s", taskID) + // Create webhook handler that will receive push notifications + webhookHandler := NewWebhookHandler(cfg.JWKSEndpoint) - // Create task payload - payload := map[string]interface{}{ - "content": "Test task with push notification", - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - log.Fatalf("Failed to marshal payload: %v", err) - } - - // Start tracking this task - handler.TrackTask(taskID) - - // Create task parameters - params := protocol.SendTaskParams{ - ID: taskID, - Message: protocol.NewMessage( - protocol.MessageRoleUser, - []protocol.Part{ - protocol.NewTextPart(string(payloadBytes)), - }, - ), - } - - // Step 1: Set up push notification configuration BEFORE sending the task - pushConfig := protocol.TaskPushNotificationConfig{ - ID: taskID, - PushNotificationConfig: protocol.PushNotificationConfig{ - URL: webhookURL, - // Explicitly set up JWT authentication - Authentication: &protocol.AuthenticationInfo{ - Schemes: []string{"bearer"}, - }, - // Include metadata to help with JWT auth setup - Metadata: map[string]interface{}{ - "jwksUrl": cfg.JWKSEndpoint, - }, - }, - } - - // Step 1.1: Register for push notifications - log.Infof("1. Registering for push notifications at: %s", webhookURL) - _, err = a2aClient.SetPushNotification(context.Background(), pushConfig) - if err != nil { - log.Fatalf("Failed to set push notification: %v", err) - } - log.Infof(" ✓ Successfully registered for push notifications") + // Start webhook server + webhookURL := fmt.Sprintf("http://%s:%d%s", cfg.WebhookHost, cfg.WebhookPort, cfg.WebhookPath) + startWebhookServer(&cfg, webhookHandler) - // Step 2: Send the task (using non-streaming API) - log.Infof("2. Sending task to %s:%d using non-streaming API (tasks/send)...", cfg.ServerHost, cfg.ServerPort) - task, err := a2aClient.SendTasks(context.Background(), params) - if err != nil { - log.Fatalf("Failed to send task: %v", err) - } - log.Infof(" ✓ Task sent successfully, initial status: %s", task.Status.State) + // Demonstrate sending multiple tasks + for i := 1; i <= 3; i++ { + content := fmt.Sprintf("Task %d: Process this message asynchronously", i) - // Step 3: Client can do other work while waiting for push notification - log.Infof("3. Task is being processed asynchronously on the server") - log.Infof(" Client is free to do other work or disconnect") - log.Infof(" Waiting for push notifications via webhook...") + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + taskID, err := sendTaskToServer(ctx, a2aClient, content) + cancel() - // Poll for task status updates periodically to demonstrate client activity - go func() { - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() + if err != nil { + log.Errorf("Failed to send task %d: %v", i, err) + continue + } - for range ticker.C { - status := handler.GetTaskStatus(taskID) - if status == "completed" || status == "failed" || status == "canceled" { - log.Infof("4. Task is complete! Final status: %s (received via push notification)", status) - return + if taskID != "" { + // Track the task in our webhook handler + webhookHandler.TrackTask(taskID) + + // Set push notification for the task + pushConfig := protocol.TaskPushNotificationConfig{ + TaskID: taskID, + PushNotificationConfig: protocol.PushNotificationConfig{ + URL: webhookURL, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + _, err = a2aClient.SetPushNotification(ctx, pushConfig) + cancel() + + if err != nil { + log.Errorf("Failed to set push notification for task %s: %v", taskID, err) + } else { + log.Infof("Push notification set for task %s", taskID) } - log.Infof(" ... client still waiting for push notification, current tracked status: %s", status) } - }() - // Wait for termination signal + // Add a small delay between tasks + time.Sleep(1 * time.Second) + } + + log.Infof("All tasks sent. Waiting for push notifications...") + + // Keep the client running to receive notifications sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - <-sigCh - log.Infof("Shutting down...") + select { + case sig := <-sigCh: + log.Infof("Received signal %v, shutting down...", sig) + case <-time.After(2 * time.Minute): + log.Infof("Timeout reached, shutting down...") + } + + log.Infof("Client shutting down.") } diff --git a/examples/jwks/server/main.go b/examples/jwks/server/main.go index 8f4a2a8..7eeb6da 100644 --- a/examples/jwks/server/main.go +++ b/examples/jwks/server/main.go @@ -21,6 +21,7 @@ import ( "fmt" "time" + "github.com/google/uuid" "trpc.group/trpc-go/trpc-a2a-go/auth" "trpc.group/trpc-go/trpc-a2a-go/log" "trpc.group/trpc-go/trpc-a2a-go/protocol" @@ -33,77 +34,135 @@ const ( defaultNotifyHost = "localhost" ) -// pushNotificationTaskProcessor is a task processor that sends push notifications. -// It implements both the taskmanager.TaskProcessor interface for task processing -// and the PushNotificationProcessor interface for receiving the authenticator. -type pushNotificationTaskProcessor struct { +// pushNotificationMessageProcessor is a message processor that sends push notifications. +// It implements the taskmanager.MessageProcessor interface for message processing +// and handles push notification functionality. +type pushNotificationMessageProcessor struct { notifyHost string manager *pushNotificationTaskManager } -// Process implements the TaskProcessor interface. -// This method should return quickly, only setting the task to "submitted" state -// and then processing the task asynchronously. -func (p *pushNotificationTaskProcessor) Process( +// ProcessMessage implements the MessageProcessor interface. +// This method processes messages and can handle both streaming and non-streaming modes. +func (p *pushNotificationMessageProcessor) ProcessMessage( ctx context.Context, - taskID string, message protocol.Message, - handle taskmanager.TaskHandle, -) error { - log.Infof("Task received: %s", taskID) + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { + log.Infof("Message processing started") // Extract task payload from the message parts var payload map[string]interface{} + var textContent string + if len(message.Parts) > 0 { - if textPart, ok := message.Parts[0].(protocol.TextPart); ok { + if textPart, ok := message.Parts[0].(*protocol.TextPart); ok { + textContent = textPart.Text + // Try to parse as JSON, but if it fails, treat as plain text if err := json.Unmarshal([]byte(textPart.Text), &payload); err != nil { - log.Errorf("Failed to unmarshal payload text: %v", err) - // Continue with empty payload - payload = make(map[string]interface{}) + log.Infof("Message content is plain text, not JSON: %s", textContent) + // Create a simple payload with the text content + payload = map[string]interface{}{ + "content": textContent, + "type": "text", + } + } else { + log.Infof("Message content parsed as JSON successfully") } } } - // Update status to working - if err := handle.UpdateStatus(protocol.TaskStateWorking, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{ - protocol.NewTextPart("Task queued for processing..."), - }, - }); err != nil { - return fmt.Errorf("failed to update task status: %v", err) + if payload == nil { + payload = map[string]interface{}{ + "content": "empty message", + "type": "text", + } + } + + // For non-streaming processing, return direct result + if !options.Streaming { + return p.processDirectly(ctx, payload) + } + + // For streaming processing, create a task and process asynchronously + taskID, err := handle.BuildTask(nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to build task: %w", err) + } + + // Subscribe to the task for streaming events + subscriber, err := handle.SubScribeTask(&taskID) + if err != nil { + return nil, fmt.Errorf("failed to subscribe to task: %w", err) } // Start asynchronous processing - go p.processTaskAsync(ctx, taskID, payload, handle) + go p.processTaskAsync(ctx, taskID, payload, subscriber) - return nil + return &taskmanager.MessageProcessingResult{ + StreamingEvents: subscriber, + }, nil } -// OnTaskStatusUpdate implements the TaskProcessor interface. -func (p *pushNotificationTaskProcessor) OnTaskStatusUpdate( +// processDirectly handles immediate processing for non-streaming requests +func (p *pushNotificationMessageProcessor) processDirectly( ctx context.Context, - taskID string, - state protocol.TaskState, - message *protocol.Message, -) error { - log.Infof("Updating task status for task: %s with status: %s", taskID, state) - if state == protocol.TaskStateCompleted || - state == protocol.TaskStateFailed || state == protocol.TaskStateCanceled { - p.manager.sendPushNotification(ctx, taskID, string(state)) - } - return nil + payload map[string]interface{}, +) (*taskmanager.MessageProcessingResult, error) { + // Process the task immediately + completeMsg := "Task completed" + if content, ok := payload["content"].(string); ok { + completeMsg = fmt.Sprintf("Task completed: %s", content) + } + + responseMessage := protocol.NewMessage( + protocol.MessageRoleAgent, + []protocol.Part{protocol.NewTextPart(completeMsg)}, + ) + + return &taskmanager.MessageProcessingResult{ + Result: &responseMessage, + }, nil } // processTaskAsync handles the actual task processing in a separate goroutine. -func (p *pushNotificationTaskProcessor) processTaskAsync( +func (p *pushNotificationMessageProcessor) processTaskAsync( ctx context.Context, taskID string, payload map[string]interface{}, - handle taskmanager.TaskHandle, + subscriber taskmanager.TaskSubscriber, ) { + defer func() { + if subscriber != nil { + subscriber.Close() + } + }() + log.Infof("Starting async processing of task: %s", taskID) + // Send working status + workingEvent := protocol.StreamingMessageEvent{ + Result: &protocol.TaskStatusUpdateEvent{ + TaskID: taskID, + ContextID: "", // We'll need to get this from the task + Kind: "status-update", + Status: protocol.TaskStatus{ + State: protocol.TaskStateWorking, + Message: &protocol.Message{ + MessageID: uuid.New().String(), + Kind: "message", + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{protocol.NewTextPart("Task queued for processing...")}, + }, + }, + }, + } + err := subscriber.Send(workingEvent) + if err != nil { + log.Errorf("Failed to send working event: %v", err) + } + // Process the task (simulating work) time.Sleep(5 * time.Second) // Longer processing time to demonstrate async behavior @@ -113,20 +172,31 @@ func (p *pushNotificationTaskProcessor) processTaskAsync( completeMsg = fmt.Sprintf("Task completed: %s", content) } - // Complete the task - // When we call UpdateStatus with a terminal state (like completed), - // the task manager automatically: - // 1. Updates the task status in memory - // 2. Sends a push notification to the registered webhook URL (if enabled) - if err := handle.UpdateStatus(protocol.TaskStateCompleted, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{ - protocol.NewTextPart(completeMsg), + // Send completion status + completedEvent := protocol.StreamingMessageEvent{ + Result: &protocol.TaskStatusUpdateEvent{ + TaskID: taskID, + ContextID: "", // We'll need to get this from the task + Kind: "status-update", + Status: protocol.TaskStatus{ + State: protocol.TaskStateCompleted, + Message: &protocol.Message{ + MessageID: uuid.New().String(), + Kind: "message", + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{protocol.NewTextPart(completeMsg)}, + }, + }, + Final: boolPtr(true), }, - }); err != nil { - log.Errorf("Failed to update task status: %v", err) - return } + err = subscriber.Send(completedEvent) + if err != nil { + log.Errorf("Failed to send completed event: %v", err) + } + + // Send push notification + p.manager.sendPushNotification(ctx, taskID, string(protocol.TaskStateCompleted)) log.Infof("Task completed asynchronously: %s", taskID) } @@ -136,14 +206,7 @@ type pushNotificationTaskManager struct { authenticator *auth.PushNotificationAuthenticator } -func (m *pushNotificationTaskManager) OnSendTask(ctx context.Context, request protocol.SendTaskParams) (*protocol.Task, error) { - task, err := m.TaskManager.OnSendTask(ctx, request) - if err != nil { - return nil, err - } - return task, nil -} - +// sendPushNotification sends a push notification for a completed task func (m *pushNotificationTaskManager) sendPushNotification(ctx context.Context, taskID, status string) { log.Infof("Sending push notification for task: %s with status: %s", taskID, status) // Get push config from task manager @@ -159,7 +222,7 @@ func (m *pushNotificationTaskManager) sendPushNotification(ctx context.Context, } // Send push notification - if err := m.authenticator.SendPushNotification(ctx, pushConfig.URL, map[string]interface{}{ + if err := m.authenticator.SendPushNotification(ctx, pushConfig.PushNotificationConfig.URL, map[string]interface{}{ "task_id": taskID, "status": status, "timestamp": time.Now().Format(time.RFC3339), @@ -181,16 +244,27 @@ func main() { // Create agent card agentCard := server.AgentCard{ Name: "Push Notification Example", - Description: strPtr("A2A server example with push notification support"), + Description: "A2A server example with push notification support", URL: fmt.Sprintf("http://localhost:%d/", *port), Version: "1.0.0", Capabilities: server.AgentCapabilities{ - Streaming: true, - PushNotifications: true, - StateTransitionHistory: true, + Streaming: boolPtr(true), + PushNotifications: boolPtr(true), + StateTransitionHistory: boolPtr(true), }, DefaultInputModes: []string{"text"}, DefaultOutputModes: []string{"text"}, + Skills: []server.AgentSkill{ + { + ID: "push_notification_task", + Name: "Push Notification Task", + Description: strPtr("Processes tasks with push notification support"), + Tags: []string{"push", "notification", "async"}, + Examples: []string{`{"content": "Hello, world!"}`}, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, + }, + }, } authenticator := auth.NewPushNotificationAuthenticator() @@ -199,7 +273,7 @@ func main() { } // Create task processor - processor := &pushNotificationTaskProcessor{ + processor := &pushNotificationMessageProcessor{ notifyHost: *notifyHost, } // Create task manager @@ -241,3 +315,8 @@ func main() { func strPtr(s string) *string { return &s } + +// Helper function to create bool pointer +func boolPtr(b bool) *bool { + return &b +} diff --git a/examples/multi/cli/main.go b/examples/multi/cli/main.go index 0afc128..abaa6cc 100644 --- a/examples/multi/cli/main.go +++ b/examples/multi/cli/main.go @@ -1,3 +1,9 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + package main import ( diff --git a/examples/multi/creative/main.go b/examples/multi/creative/main.go index dae3a75..8dd9662 100644 --- a/examples/multi/creative/main.go +++ b/examples/multi/creative/main.go @@ -1,3 +1,9 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + package main import ( @@ -9,6 +15,7 @@ import ( "strings" "syscall" + "github.com/google/uuid" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/googleai" "trpc.group/trpc-go/trpc-a2a-go/log" @@ -48,7 +55,7 @@ func (c *conversationCache) GetHistory(sessionID string) []string { return []string{} } -// creativeWritingProcessor implements the taskmanager.TaskProcessor interface +// creativeWritingProcessor implements the taskmanager.MessageProcessor interface type creativeWritingProcessor struct { llm llms.Model cache *conversationCache @@ -80,40 +87,35 @@ func getAPIKey() string { return apiKey } -// Process implements the taskmanager.TaskProcessor interface -func (p *creativeWritingProcessor) Process( +// ProcessMessage implements the taskmanager.MessageProcessor interface +func (p *creativeWritingProcessor) ProcessMessage( ctx context.Context, - taskID string, message protocol.Message, - handle taskmanager.TaskHandle, -) error { + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { // Extract text from the incoming message prompt := extractText(message) if prompt == "" { errMsg := "input message must contain text." - log.Error("Task %s failed: %s", taskID, errMsg) + log.Error("Message processing failed: %s", errMsg) - // Update status to Failed via handle - failedMessage := protocol.NewMessage( + // Return error message directly + errorMessage := protocol.NewMessage( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(errMsg)}, ) - _ = handle.UpdateStatus(protocol.TaskStateFailed, &failedMessage) - return fmt.Errorf(errMsg) + return &taskmanager.MessageProcessingResult{ + Result: &errorMessage, + }, nil } - log.Info("Processing creative writing task %s with prompt: %s", taskID, prompt) - - // Get session ID from task metadata or use taskID as fallback - sessionID := taskID + log.Info("Processing creative writing message with prompt: %s", prompt) - // Update to in-progress status - progressMessage := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart("Crafting your creative response...")}, - ) - if err := handle.UpdateStatus(protocol.TaskStateWorking, &progressMessage); err != nil { - log.Error("Failed to update task status: %v", err) + // Get session ID from message context or generate one + sessionID := handle.GetContextID() + if sessionID == "" { + sessionID = uuid.New().String() } // Build the context from conversation history @@ -139,13 +141,15 @@ func (p *creativeWritingProcessor) Process( response, err := llms.GenerateFromSinglePrompt(ctx, p.llm, finalPrompt) if err != nil { errorMsg := fmt.Sprintf("Failed to generate response: %v", err) - log.Error("Task %s failed: %s", taskID, errorMsg) + log.Error("Message processing failed: %s", errorMsg) errorMessage := protocol.NewMessage( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(errorMsg)}, ) - return handle.UpdateStatus(protocol.TaskStateFailed, &errorMessage) + return &taskmanager.MessageProcessingResult{ + Result: &errorMessage, + }, nil } // Save prompt and response to conversation history @@ -158,31 +162,15 @@ func (p *creativeWritingProcessor) Process( []protocol.Part{protocol.NewTextPart(response)}, ) - // Update task status to completed - if err := handle.UpdateStatus(protocol.TaskStateCompleted, &responseMessage); err != nil { - return fmt.Errorf("failed to update task status: %w", err) - } - - // Add response as an artifact - artifact := protocol.Artifact{ - Name: stringPtr("Creative Writing Response"), - Description: stringPtr(prompt), - Index: 0, - Parts: []protocol.Part{protocol.NewTextPart(response)}, - LastChunk: boolPtr(true), - } - - if err := handle.AddArtifact(artifact); err != nil { - log.Error("Error adding artifact for task %s: %v", taskID, err) - } - - return nil + return &taskmanager.MessageProcessingResult{ + Result: &responseMessage, + }, nil } // extractText extracts the text content from a message func extractText(message protocol.Message) string { for _, part := range message.Parts { - if textPart, ok := part.(protocol.TextPart); ok { + if textPart, ok := part.(*protocol.TextPart); ok { return textPart.Text } } @@ -202,27 +190,30 @@ func boolPtr(b bool) *bool { func getAgentCard() server.AgentCard { return server.AgentCard{ Name: "Creative Writing Agent", - Description: stringPtr("An agent that generates creative writing based on prompts using Google Gemini."), + Description: "An agent that generates creative writing based on prompts using Google Gemini.", URL: "http://localhost:8082", Version: "1.0.0", Capabilities: server.AgentCapabilities{ - Streaming: false, - PushNotifications: false, - StateTransitionHistory: true, + Streaming: boolPtr(false), + PushNotifications: boolPtr(false), + StateTransitionHistory: boolPtr(true), }, - DefaultInputModes: []string{string(protocol.PartTypeText)}, - DefaultOutputModes: []string{string(protocol.PartTypeText)}, + DefaultInputModes: []string{"text"}, + DefaultOutputModes: []string{"text"}, Skills: []server.AgentSkill{ { ID: "creative_writing", Name: "Creative Writing", Description: stringPtr("Creates engaging creative text based on user prompts."), + Tags: []string{"creative", "writing", "llm"}, Examples: []string{ "Write a short story about a space explorer", "Compose a poem about autumn leaves", "Create a funny dialogue between a cat and a dog", "Write a brief fantasy adventure about a magical forest", }, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, }, } diff --git a/examples/multi/exchange/main.go b/examples/multi/exchange/main.go index 179281e..e9571f9 100644 --- a/examples/multi/exchange/main.go +++ b/examples/multi/exchange/main.go @@ -1,3 +1,9 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + package main import ( @@ -14,18 +20,18 @@ import ( "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/googleai" - "trpc.group/trpc-go/trpc-a2a-go/protocol" "trpc.group/trpc-go/trpc-a2a-go/log" + "trpc.group/trpc-go/trpc-a2a-go/protocol" "trpc.group/trpc-go/trpc-a2a-go/server" "trpc.group/trpc-go/trpc-a2a-go/taskmanager" ) -// exchangeProcessor implements the taskmanager.TaskProcessor interface. +// exchangeProcessor implements the taskmanager.MessageProcessor interface type exchangeProcessor struct { llm llms.Model } -// newExchangeProcessor creates a new exchange processor with LangChain. +// newExchangeProcessor creates a new exchange processor with LangChain func newExchangeProcessor() (*exchangeProcessor, error) { // Initialize Google Gemini model llm, err := googleai.New( @@ -45,29 +51,30 @@ func getAPIKey() string { return os.Getenv("GOOGLE_API_KEY") } -// Process implements the taskmanager.TaskProcessor interface. -func (p *exchangeProcessor) Process( +// ProcessMessage implements the taskmanager.MessageProcessor interface +func (p *exchangeProcessor) ProcessMessage( ctx context.Context, - taskID string, message protocol.Message, - handle taskmanager.TaskHandle, -) error { + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { // Extract text from the incoming message. query := extractText(message) if query == "" { errMsg := "input message must contain text." - log.Error("Task %s failed: %s", taskID, errMsg) + log.Error("Message processing failed: %s", errMsg) - // Update status to Failed via handle. - failedMessage := protocol.NewMessage( + // Return error message directly + errorMessage := protocol.NewMessage( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(errMsg)}, ) - _ = handle.UpdateStatus(protocol.TaskStateFailed, &failedMessage) - return fmt.Errorf(errMsg) + return &taskmanager.MessageProcessingResult{ + Result: &errorMessage, + }, nil } - log.Info("Processing exchange task %s with query: %s", taskID, query) + log.Info("Processing exchange request with query: %s", query) // First attempt to use the LLM to enhance understanding prompt := fmt.Sprintf( @@ -130,7 +137,9 @@ func (p *exchangeProcessor) Process( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(completion)}, ) - return handle.UpdateStatus(protocol.TaskStateCompleted, &responseMessage) + return &taskmanager.MessageProcessingResult{ + Result: &responseMessage, + }, nil } } @@ -142,7 +151,9 @@ func (p *exchangeProcessor) Process( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(fmt.Sprintf("Error processing request: %v", err))}, ) - return handle.UpdateStatus(protocol.TaskStateFailed, &errorMessage) + return &taskmanager.MessageProcessingResult{ + Result: &errorMessage, + }, nil } // Format response with some explanation @@ -151,30 +162,15 @@ func (p *exchangeProcessor) Process( log.Info("Responding with: %s", finalResponse) - // Update task status to completed with the response + // Create response message responseMessage := protocol.NewMessage( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(finalResponse)}, ) - if err := handle.UpdateStatus(protocol.TaskStateCompleted, &responseMessage); err != nil { - return fmt.Errorf("failed to update task status: %w", err) - } - - // Add the exchange rate data as an artifact - artifact := protocol.Artifact{ - Name: stringPtr("Exchange Rate Data"), - Description: stringPtr(fmt.Sprintf("Exchange rate from %s to %s", fromCurrency, toCurrency)), - Index: 0, - Parts: []protocol.Part{protocol.NewTextPart(result)}, - LastChunk: boolPtr(true), - } - - if err := handle.AddArtifact(artifact); err != nil { - log.Error("Error adding artifact for task %s: %v", taskID, err) - } - - return nil + return &taskmanager.MessageProcessingResult{ + Result: &responseMessage, + }, nil } // parseExchangeQuery attempts to parse a natural language query to extract currency info. @@ -294,30 +290,33 @@ func boolPtr(b bool) *bool { return &b } -// getAgentCard returns the agent's metadata. +// getAgentCard returns the agent's metadata func getAgentCard() server.AgentCard { return server.AgentCard{ - Name: "Currency Exchange Agent", - Description: stringPtr("An agent that can fetch and display currency exchange rates."), - URL: "http://localhost:8081", + Name: "Exchange Rate Agent", + Description: "An agent that can fetch and display currency exchange rates.", + URL: "http://localhost:8084", Version: "1.0.0", Capabilities: server.AgentCapabilities{ - Streaming: false, - PushNotifications: false, - StateTransitionHistory: true, + Streaming: boolPtr(false), + PushNotifications: boolPtr(false), + StateTransitionHistory: boolPtr(true), }, - DefaultInputModes: []string{string(protocol.PartTypeText)}, - DefaultOutputModes: []string{string(protocol.PartTypeText)}, + DefaultInputModes: []string{"text"}, + DefaultOutputModes: []string{"text"}, Skills: []server.AgentSkill{ { - ID: "exchange_rate", + ID: "exchange_rates", Name: "Currency Exchange Rates", - Description: stringPtr("Gets the current or historical exchange rates between currencies."), + Description: stringPtr("Fetches current or historical currency exchange rates"), + Tags: []string{"currency", "exchange", "rates", "finance"}, Examples: []string{ - "What is the exchange rate from USD to EUR?", + "What's the USD to EUR exchange rate?", "Convert 100 USD to JPY", - "What was the rate of GBP to USD on 2023-01-01?", + "EUR to GBP rate for 2023-10-15", }, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, }, } diff --git a/examples/multi/reimbursement/main.go b/examples/multi/reimbursement/main.go index 697bcce..8d20562 100644 --- a/examples/multi/reimbursement/main.go +++ b/examples/multi/reimbursement/main.go @@ -1,3 +1,9 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + package main import ( @@ -5,15 +11,12 @@ import ( "encoding/json" "flag" "fmt" - "math/rand" "os" "os/signal" "strconv" "strings" "syscall" - "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/llms/googleai" "trpc.group/trpc-go/trpc-a2a-go/log" "trpc.group/trpc-go/trpc-a2a-go/protocol" "trpc.group/trpc-go/trpc-a2a-go/server" @@ -23,222 +26,143 @@ import ( // Store request IDs for demonstration purposes var requestIDs = make(map[string]bool) -// reimbursementProcessor implements the taskmanager.TaskProcessor interface -type reimbursementProcessor struct { - llm llms.Model -} - -// newReimbursementProcessor creates a new reimbursement processor with LangChain -func newReimbursementProcessor() (*reimbursementProcessor, error) { - // Initialize Google Gemini model - llm, err := googleai.New( - context.Background(), - googleai.WithAPIKey(getAPIKey()), - ) - if err != nil { - return nil, fmt.Errorf("failed to initialize Gemini model: %w", err) - } - - return &reimbursementProcessor{ - llm: llm, - }, nil -} - -func getAPIKey() string { - return os.Getenv("GOOGLE_API_KEY") -} +// reimbursementProcessor implements the taskmanager.MessageProcessor interface +type reimbursementProcessor struct{} -// Process implements the taskmanager.TaskProcessor interface -func (p *reimbursementProcessor) Process( +// ProcessMessage implements the taskmanager.MessageProcessor interface +func (p *reimbursementProcessor) ProcessMessage( ctx context.Context, - taskID string, message protocol.Message, - handle taskmanager.TaskHandle, -) error { - // Extract text from the incoming message - query := extractText(message) - if query == "" { + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { + // Extract text from the incoming message. + text := extractText(message) + if text == "" { errMsg := "input message must contain text." - log.Error("Task %s failed: %s", taskID, errMsg) + log.Error("Message processing failed: %s", errMsg) - // Update status to Failed via handle - failedMessage := protocol.NewMessage( + // Return error message directly + errorMessage := protocol.NewMessage( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(errMsg)}, ) - _ = handle.UpdateStatus(protocol.TaskStateFailed, &failedMessage) - return fmt.Errorf(errMsg) + return &taskmanager.MessageProcessingResult{ + Result: &errorMessage, + }, nil } - log.Info("Processing reimbursement task %s with query: %s", taskID, query) + log.Info("Processing reimbursement request: %s", text) - // Check if this is a form submission - if strings.Contains(query, "request_id") && strings.Contains(query, "date") && - strings.Contains(query, "amount") && strings.Contains(query, "purpose") { - return p.handleFormSubmission(ctx, taskID, query, handle) - } + // Try to extract reimbursement details from natural language + date, amount, purpose := extractReimbursementDetails(text) - // Otherwise, this is a new request - create a form - return p.handleNewRequest(ctx, taskID, query, handle) -} + // Also try to extract structured form data if available + formData := extractFormData(text) -// handleNewRequest processes a new reimbursement request by creating a form -func (p *reimbursementProcessor) handleNewRequest( - ctx context.Context, - taskID string, - query string, - handle taskmanager.TaskHandle, -) error { - // Try to extract date, amount, and purpose from the query - date, amount, purpose := extractReimbursementDetails(query) - - // Generate a random request ID - requestID := fmt.Sprintf("request_id_%d", rand.Intn(9000000)+1000000) - requestIDs[requestID] = true - - // Create a form request - formRequest := map[string]interface{}{ - "request_id": requestID, - "date": date, - "amount": amount, - "purpose": purpose, - } + // Build a complete reimbursement record + reimbursement := make(map[string]interface{}) - // Create form response - formDict := map[string]interface{}{ - "type": "form", - "form": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "date": map[string]interface{}{ - "type": "string", - "format": "date", - "description": "Date of expense", - "title": "Date", - }, - "amount": map[string]interface{}{ - "type": "string", - "format": "number", - "description": "Amount of expense", - "title": "Amount", - }, - "purpose": map[string]interface{}{ - "type": "string", - "description": "Purpose of expense", - "title": "Purpose", - }, - "request_id": map[string]interface{}{ - "type": "string", - "description": "Request id", - "title": "Request ID", - }, - }, - "required": []string{"request_id", "date", "amount", "purpose"}, - }, - "form_data": formRequest, - "instructions": "Please fill out this reimbursement request form with all required information.", + // Use extracted natural language data + if date != "" { + reimbursement["date"] = date } - - // Convert to JSON - formJSON, err := json.MarshalIndent(formDict, "", " ") - if err != nil { - return fmt.Errorf("failed to create form JSON: %w", err) + if amount != "" { + reimbursement["amount"] = amount } - - // Create response message - responseMessage := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart(string(formJSON))}, - ) - - // Update task status to in-progress with the form - if err := handle.UpdateStatus(protocol.TaskStateWorking, &responseMessage); err != nil { - return fmt.Errorf("failed to update task status: %w", err) + if purpose != "" { + reimbursement["purpose"] = purpose } - return nil -} - -// handleFormSubmission processes a submitted reimbursement form -func (p *reimbursementProcessor) handleFormSubmission( - ctx context.Context, - taskID string, - formData string, - handle taskmanager.TaskHandle, -) error { - // Try to parse the form data - var parsedForm map[string]interface{} - if err := json.Unmarshal([]byte(formData), &parsedForm); err != nil { - // If can't parse as JSON, try to extract form data from the text - parsedForm = extractFormData(formData) + // Override with structured form data if available + for key, value := range formData { + reimbursement[key] = value } - // Validate the form - missingFields := validateForm(parsedForm) - if len(missingFields) > 0 { - // Form is incomplete - ask for the missing fields - errorMsg := fmt.Sprintf( - "Your reimbursement request is missing required information: %s. Please provide all required information.", - strings.Join(missingFields, ", "), - ) - errorMessage := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart(errorMsg)}, - ) - return handle.UpdateStatus(protocol.TaskStateWorking, &errorMessage) + // Generate request ID if not provided + if _, exists := reimbursement["request_id"]; !exists { + reimbursement["request_id"] = "REIMB-" + generateReferenceID() } - // Check if request ID is valid - requestID, _ := parsedForm["request_id"].(string) - if !requestIDs[requestID] { - errorMsg := fmt.Sprintf("Error: Invalid request_id: %s", requestID) - errorMessage := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart(errorMsg)}, - ) - return handle.UpdateStatus(protocol.TaskStateFailed, &errorMessage) - } + // Validate the reimbursement request + missing := validateForm(reimbursement) + + var result string + if len(missing) > 0 { + // Request is incomplete - ask for missing information + result = fmt.Sprintf("Your reimbursement request is missing the following required information: %s.\n\n"+ + "Please provide:\n", strings.Join(missing, ", ")) + + for _, field := range missing { + switch field { + case "request_id": + result += "- Request ID (will be auto-generated if not provided)\n" + case "date": + result += "- Date of expense (YYYY-MM-DD format)\n" + case "amount": + result += "- Amount ($XX.XX format)\n" + case "purpose": + result += "- Purpose/reason for the expense\n" + } + } - // Process the reimbursement - amount, _ := parsedForm["amount"].(string) - purpose, _ := parsedForm["purpose"].(string) - date, _ := parsedForm["date"].(string) - - // Generate response - response := fmt.Sprintf( - "Your reimbursement request has been approved!\n\n"+ - "Request ID: %s\n"+ - "Date: %s\n"+ - "Amount: %s\n"+ - "Purpose: %s\n\n"+ - "Status: approved", - requestID, date, amount, purpose, - ) + result += "\nYou can provide information in natural language or structured format like:\n" + result += "Date: 2023-10-15\nAmount: $50.00\nPurpose: Business lunch with client" + } else { + // Request is complete - process it + requestID := reimbursement["request_id"].(string) + requestIDs[requestID] = true + + result = fmt.Sprintf("✅ Reimbursement request processed successfully!\n\n"+ + "Request Details:\n"+ + "- Request ID: %s\n"+ + "- Date: %s\n"+ + "- Amount: %s\n"+ + "- Purpose: %s\n\n"+ + "Status: Approved\n"+ + "Processing Time: 2-3 business days\n"+ + "You will receive an email confirmation shortly.", + requestID, + reimbursement["date"], + reimbursement["amount"], + reimbursement["purpose"]) + } - // Mark the task as completed + // Create response message responseMessage := protocol.NewMessage( protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart(response)}, + []protocol.Part{protocol.NewTextPart(result)}, ) - if err := handle.UpdateStatus(protocol.TaskStateCompleted, &responseMessage); err != nil { - return fmt.Errorf("failed to update task status: %w", err) + // Create result with potential artifact for completed requests + processingResult := &taskmanager.MessageProcessingResult{ + Result: &responseMessage, } - // Add the reimbursement data as an artifact - artifact := protocol.Artifact{ - Name: stringPtr("Reimbursement Request"), - Description: stringPtr(fmt.Sprintf("Reimbursement request for %s", amount)), - Index: 0, - Parts: []protocol.Part{protocol.NewTextPart(response)}, - LastChunk: boolPtr(true), - } + // Add artifact for completed reimbursement requests + if len(missing) == 0 { + // Build task to get artifact support + task, err := handle.BuildTask(nil, nil) + if err == nil { + // Create reimbursement details artifact + reimbursementJSON, _ := json.Marshal(reimbursement) + artifact := protocol.Artifact{ + ArtifactID: fmt.Sprintf("reimb-%s", reimbursement["request_id"]), + Name: stringPtr("Reimbursement Details"), + Description: stringPtr(fmt.Sprintf("Processed reimbursement request %s", reimbursement["request_id"])), + Parts: []protocol.Part{protocol.NewTextPart(string(reimbursementJSON))}, + } - if err := handle.AddArtifact(artifact); err != nil { - log.Error("Error adding artifact for task %s: %v", taskID, err) + _ = handle.AddArtifact(&task.Task.ID, artifact, true, false) + } } - return nil + return processingResult, nil +} + +// newReimbursementProcessor creates a new reimbursement processor +func newReimbursementProcessor() (*reimbursementProcessor, error) { + return &reimbursementProcessor{}, nil } // extractReimbursementDetails attempts to extract date, amount, and purpose from the text @@ -342,13 +266,18 @@ func validateForm(form map[string]interface{}) []string { // extractText extracts the text content from a message func extractText(message protocol.Message) string { for _, part := range message.Parts { - if textPart, ok := part.(protocol.TextPart); ok { + if textPart, ok := part.(*protocol.TextPart); ok { return textPart.Text } } return "" } +// generateReferenceID generates a simple reference ID for demonstration +func generateReferenceID() string { + return fmt.Sprintf("%d", len(requestIDs)+1) +} + // Helper functions func stringPtr(s string) *string { return &s @@ -362,26 +291,29 @@ func boolPtr(b bool) *bool { func getAgentCard() server.AgentCard { return server.AgentCard{ Name: "Reimbursement Agent", - Description: stringPtr("An agent that processes employee reimbursement requests."), + Description: "An agent that processes employee reimbursement requests.", URL: "http://localhost:8083", Version: "1.0.0", Capabilities: server.AgentCapabilities{ - Streaming: false, - PushNotifications: false, - StateTransitionHistory: true, + Streaming: boolPtr(false), + PushNotifications: boolPtr(false), + StateTransitionHistory: boolPtr(true), }, - DefaultInputModes: []string{string(protocol.PartTypeText)}, - DefaultOutputModes: []string{string(protocol.PartTypeText)}, + DefaultInputModes: []string{"text"}, + DefaultOutputModes: []string{"text"}, Skills: []server.AgentSkill{ { ID: "reimbursement", Name: "Process Reimbursements", Description: stringPtr("Creates and processes expense reimbursement requests."), + Tags: []string{"expense", "reimbursement", "finance"}, Examples: []string{ "I need to get reimbursed for my business lunch.", "Process my reimbursement for $50 for office supplies.", "Submit a reimbursement request for my travel expenses on 2023-10-15.", }, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, }, } diff --git a/examples/multi/root/main.go b/examples/multi/root/main.go index cbda9e9..83b8ada 100644 --- a/examples/multi/root/main.go +++ b/examples/multi/root/main.go @@ -1,3 +1,9 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + package main import ( @@ -16,7 +22,7 @@ import ( "trpc.group/trpc-go/trpc-a2a-go/taskmanager" ) -// rootAgentProcessor implements the taskmanager.TaskProcessor interface. +// rootAgentProcessor implements the taskmanager.MessageProcessor interface. type rootAgentProcessor struct { // LLM client for decision making llm *googleai.GoogleAI @@ -26,29 +32,30 @@ type rootAgentProcessor struct { reimbursementClient *client.A2AClient } -// Process implements the taskmanager.TaskProcessor interface to process incoming tasks. -func (p *rootAgentProcessor) Process( +// ProcessMessage implements the taskmanager.MessageProcessor interface +func (p *rootAgentProcessor) ProcessMessage( ctx context.Context, - taskID string, message protocol.Message, - handle taskmanager.TaskHandle, -) error { + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { // Extract text from the incoming message text := extractText(message) if text == "" { errMsg := "input message must contain text" - log.Error("Task %s failed: %s", taskID, errMsg) + log.Error("Message processing failed: %s", errMsg) - // Update status to Failed via handle - failedMessage := protocol.NewMessage( + // Return error message directly + errorMessage := protocol.NewMessage( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(errMsg)}, ) - _ = handle.UpdateStatus(protocol.TaskStateFailed, &failedMessage) - return fmt.Errorf(errMsg) + return &taskmanager.MessageProcessingResult{ + Result: &errorMessage, + }, nil } - log.Info("RootAgent received new task: %s", text) + log.Info("RootAgent received new request: %s", text) // Use Gemini or rule-based routing to decide which subagent to route the task to subagent, err := p.routeTaskToSubagent(ctx, text) @@ -59,8 +66,9 @@ func (p *rootAgentProcessor) Process( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(errMsg)}, ) - _ = handle.UpdateStatus(protocol.TaskStateFailed, &errMessage) - return err + return &taskmanager.MessageProcessingResult{ + Result: &errMessage, + }, nil } var result string @@ -90,8 +98,9 @@ func (p *rootAgentProcessor) Process( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(errMsg)}, ) - _ = handle.UpdateStatus(protocol.TaskStateFailed, &errMessage) - return err + return &taskmanager.MessageProcessingResult{ + Result: &errMessage, + }, nil } // Create response message @@ -100,12 +109,9 @@ func (p *rootAgentProcessor) Process( []protocol.Part{protocol.NewTextPart(result)}, ) - // Update task status to completed - if err := handle.UpdateStatus(protocol.TaskStateCompleted, &responseMessage); err != nil { - return fmt.Errorf("failed to update task status: %w", err) - } - - return nil + return &taskmanager.MessageProcessingResult{ + Result: &responseMessage, + }, nil } // routeTaskToSubagent uses the LLM to decide which subagent should handle the task. @@ -232,7 +238,7 @@ func (p *rootAgentProcessor) callReimbursementAgent(ctx context.Context, text st // extractText extracts the text content from a message. func extractText(message protocol.Message) string { for _, part := range message.Parts { - if textPart, ok := part.(protocol.TextPart); ok { + if textPart, ok := part.(*protocol.TextPart); ok { return textPart.Text } } @@ -243,24 +249,29 @@ func extractText(message protocol.Message) string { func getAgentCard() server.AgentCard { return server.AgentCard{ Name: "Multi-Agent Router", - Description: stringPtr("An agent that routes tasks to appropriate subagents."), + Description: "An agent that routes tasks to appropriate subagents.", URL: "http://localhost:8080", Version: "1.0.0", Capabilities: server.AgentCapabilities{ - Streaming: false, - PushNotifications: false, - StateTransitionHistory: true, + Streaming: boolPtr(false), + PushNotifications: boolPtr(false), + StateTransitionHistory: boolPtr(true), }, + DefaultInputModes: []string{"text"}, + DefaultOutputModes: []string{"text"}, Skills: []server.AgentSkill{ { ID: "route", Name: "Task Routing", Description: stringPtr("Routes tasks to the appropriate specialized agent."), + Tags: []string{"routing", "multi-agent", "orchestration"}, Examples: []string{ "Write a poem about autumn", "What's the exchange rate from USD to EUR?", "I need to get reimbursed for a $50 business lunch", }, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, }, } @@ -271,6 +282,11 @@ func stringPtr(s string) *string { return &s } +// boolPtr is a helper function to get a pointer to a bool. +func boolPtr(b bool) *bool { + return &b +} + func main() { port := flag.Int("port", 8080, "Port to listen on for the root agent") creativeAgentURL := flag.String("creative-url", "http://localhost:8082", "URL for the creative writing agent") diff --git a/examples/simple/client/main.go b/examples/simple/client/main.go index 8a52574..2f61eba 100644 --- a/examples/simple/client/main.go +++ b/examples/simple/client/main.go @@ -10,7 +10,6 @@ package main import ( "context" "flag" - "fmt" "log" "time" @@ -36,95 +35,134 @@ func main() { // Display connection information. log.Printf("Connecting to agent: %s (Timeout: %v)", *agentURL, *timeout) - // Create a new unique task ID. - taskID := uuid.New().String() + // Create a new unique message ID and context ID. + contextID := uuid.New().String() + log.Printf("Context ID: %s", contextID) - // Create a new session ID. - sessionID := uuid.New().String() - log.Printf("Session ID: %s", sessionID) - - // Create the message to send. - userMessage := protocol.NewMessage( + // Create the message to send using the new constructor. + userMessage := protocol.NewMessageWithContext( protocol.MessageRoleUser, []protocol.Part{protocol.NewTextPart(*message)}, + nil, // taskID + &contextID, ) - // Create task parameters. - params := protocol.SendTaskParams{ - ID: taskID, - SessionID: &sessionID, - Message: userMessage, + // Create message parameters using the new SendMessageParams structure. + params := protocol.SendMessageParams{ + Message: userMessage, + Configuration: &protocol.SendMessageConfiguration{ + Blocking: boolPtr(true), // Wait for completion + }, } - log.Printf("Sending task %s with message: %s", taskID, *message) + log.Printf("Sending message with content: %s", *message) - // Send task to the agent. + // Send message to the agent using the new message API. ctx, cancel := context.WithTimeout(context.Background(), *timeout) defer cancel() - task, err := a2aClient.SendTasks(ctx, params) + messageResult, err := a2aClient.SendMessage(ctx, params) if err != nil { - log.Fatalf("Failed to send task: %v", err) + log.Fatalf("Failed to send message: %v", err) } - // Display the initial task response. - log.Printf("Task %s initial state: %s", taskID, task.Status.State) - - // Wait for the task to complete if it's not already done. - if task.Status.State != protocol.TaskStateCompleted && - task.Status.State != protocol.TaskStateFailed && - task.Status.State != protocol.TaskStateCanceled { - - log.Printf("Task %s is %s, fetching final state...", taskID, task.Status.State) - - // Get the task's final state. - queryParams := protocol.TaskQueryParams{ - ID: taskID, - } - - // Give the server some time to process. - time.Sleep(500 * time.Millisecond) - - task, err = a2aClient.GetTasks(ctx, queryParams) - if err != nil { - log.Fatalf("Failed to get task status: %v", err) + // Display the result. + log.Printf("Message sent successfully") + + // Handle the result based on its type + switch result := messageResult.Result.(type) { + case *protocol.Message: + log.Printf("Received message response:") + printMessage(*result) + case *protocol.Task: + log.Printf("Received task response - ID: %s, State: %s", result.ID, result.Status.State) + + // If task is not completed, wait and check again + if result.Status.State != protocol.TaskStateCompleted && + result.Status.State != protocol.TaskStateFailed && + result.Status.State != protocol.TaskStateCanceled { + + log.Printf("Task %s is %s, fetching final state...", result.ID, result.Status.State) + + // Get the task's final state. + queryParams := protocol.TaskQueryParams{ + ID: result.ID, + } + + // Give the server some time to process. + time.Sleep(500 * time.Millisecond) + + task, err := a2aClient.GetTasks(ctx, queryParams) + if err != nil { + log.Fatalf("Failed to get task status: %v", err) + } + + log.Printf("Task %s final state: %s", task.ID, task.Status.State) + printTaskResult(task) + } else { + printTaskResult(result) } + default: + log.Printf("Received unknown result type: %T", result) } +} - // Display the final task state. - log.Printf("Task %s final state: %s", taskID, task.Status.State) +// printMessage prints the contents of a message. +func printMessage(message protocol.Message) { + log.Printf("Message ID: %s", message.MessageID) + if message.ContextID != nil { + log.Printf("Context ID: %s", *message.ContextID) + } + log.Printf("Role: %s", message.Role) + + log.Printf("Message parts:") + for i, part := range message.Parts { + switch p := part.(type) { + case *protocol.TextPart: + log.Printf(" Part %d (text): %s", i+1, p.Text) + case *protocol.FilePart: + log.Printf(" Part %d (file): [file content]", i+1) + case *protocol.DataPart: + log.Printf(" Part %d (data): %+v", i+1, p.Data) + default: + log.Printf(" Part %d (unknown): %+v", i+1, part) + } + } +} - // Display the response message if available. +// printTaskResult prints the contents of a task result. +func printTaskResult(task *protocol.Task) { if task.Status.Message != nil { - fmt.Println("\nAgent response:") - for _, part := range task.Status.Message.Parts { - if textPart, ok := part.(protocol.TextPart); ok { - fmt.Println(textPart.Text) - } - } + log.Printf("Task result message:") + printMessage(*task.Status.Message) } - // Display any artifacts. + // Print artifacts if any if len(task.Artifacts) > 0 { - fmt.Println("\nArtifacts:") + log.Printf("Task artifacts:") for i, artifact := range task.Artifacts { - // Display artifact name and description if available. + name := "Unnamed" if artifact.Name != nil { - fmt.Printf("%d. %s", i+1, *artifact.Name) - if artifact.Description != nil { - fmt.Printf(" - %s", *artifact.Description) - } - fmt.Println() - } else { - fmt.Printf("%d. Artifact #%d\n", i+1, i+1) + name = *artifact.Name } - - // Display artifact content. - for _, part := range artifact.Parts { - if textPart, ok := part.(protocol.TextPart); ok { - fmt.Printf(" %s\n", textPart.Text) + log.Printf(" Artifact %d: %s", i+1, name) + for j, part := range artifact.Parts { + switch p := part.(type) { + case *protocol.TextPart: + log.Printf(" Part %d (text): %s", j+1, p.Text) + case *protocol.FilePart: + log.Printf(" Part %d (file): [file content]", j+1) + case *protocol.DataPart: + log.Printf(" Part %d (data): %+v", j+1, p.Data) + default: + log.Printf(" Part %d (unknown): %+v", j+1, part) } } } } -} \ No newline at end of file +} + +// boolPtr returns a pointer to a boolean value. +func boolPtr(b bool) *bool { + return &b +} diff --git a/examples/simple/server/main.go b/examples/simple/server/main.go index d509c10..aefa6e5 100644 --- a/examples/simple/server/main.go +++ b/examples/simple/server/main.go @@ -11,77 +11,155 @@ import ( "context" "flag" "fmt" - "log" "os" "os/signal" "syscall" + "github.com/google/uuid" + + "trpc.group/trpc-go/trpc-a2a-go/log" "trpc.group/trpc-go/trpc-a2a-go/protocol" "trpc.group/trpc-go/trpc-a2a-go/server" "trpc.group/trpc-go/trpc-a2a-go/taskmanager" ) -// simpleTaskProcessor implements the taskmanager.TaskProcessor interface. -type simpleTaskProcessor struct{} +// simpleMessageProcessor implements the taskmanager.MessageProcessor interface. +type simpleMessageProcessor struct{} -// Process implements the taskmanager.TaskProcessor interface. -func (p *simpleTaskProcessor) Process( +// ProcessMessage implements the taskmanager.MessageProcessor interface. +func (p *simpleMessageProcessor) ProcessMessage( ctx context.Context, - taskID string, message protocol.Message, - handle taskmanager.TaskHandle, -) error { + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { // Extract text from the incoming message. text := extractText(message) if text == "" { errMsg := "input message must contain text." - log.Printf("Task %s failed: %s", taskID, errMsg) + log.Errorf("Message processing failed: %s", errMsg) - // Update status to Failed via handle. - failedMessage := protocol.NewMessage( + // Return error message directly + errorMessage := protocol.NewMessage( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(errMsg)}, ) - _ = handle.UpdateStatus(protocol.TaskStateFailed, &failedMessage) - return fmt.Errorf(errMsg) + + return &taskmanager.MessageProcessingResult{ + Result: &errorMessage, + }, nil } - log.Printf("Processing task %s with input: %s", taskID, text) + log.Infof("Processing message with input: %s", text) // Process the input text (in this simple example, we'll just reverse it). result := reverseString(text) - // Create response message. - responseMessage := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart(fmt.Sprintf("Processed result: %s", result))}, - ) + // For non-streaming processing, we can return either a Message or Task + if !options.Streaming { + // Return a direct message response + responseMessage := protocol.NewMessage( + protocol.MessageRoleAgent, + []protocol.Part{protocol.NewTextPart(fmt.Sprintf("Processed result: %s", result))}, + ) - // Update task status to completed. - if err := handle.UpdateStatus(protocol.TaskStateCompleted, &responseMessage); err != nil { - return fmt.Errorf("failed to update task status: %w", err) + return &taskmanager.MessageProcessingResult{ + Result: &responseMessage, + }, nil } - // Add the processed text as an artifact. - artifact := protocol.Artifact{ - Name: stringPtr("Reversed Text"), - Description: stringPtr("The input text reversed"), - Index: 0, - Parts: []protocol.Part{protocol.NewTextPart(result)}, - LastChunk: boolPtr(true), + // For streaming processing, create a task and subscribe to it + taskID, err := handle.BuildTask(nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to build task: %w", err) } - if err := handle.AddArtifact(artifact); err != nil { - log.Printf("Error adding artifact for task %s: %v", taskID, err) + // Subscribe to the task for streaming events + subscriber, err := handle.SubScribeTask(&taskID) + if err != nil { + return nil, fmt.Errorf("failed to subscribe to task: %w", err) } - return nil + // Start processing in a goroutine + go func() { + defer func() { + if subscriber != nil { + subscriber.Close() + } + }() + + // Send task status update - working + workingEvent := protocol.StreamingMessageEvent{ + Result: &protocol.TaskStatusUpdateEvent{ + TaskID: taskID, + ContextID: "", + Kind: "status-update", + Status: protocol.TaskStatus{ + State: protocol.TaskStateWorking, + }, + }, + } + err := subscriber.Send(workingEvent) + if err != nil { + log.Errorf("Failed to send working event: %v", err) + } + + // Create response message + responseMessage := protocol.NewMessage( + protocol.MessageRoleAgent, + []protocol.Part{protocol.NewTextPart(fmt.Sprintf("Processed result: %s", result))}, + ) + + // Send task completion + completedEvent := protocol.StreamingMessageEvent{ + Result: &protocol.TaskStatusUpdateEvent{ + TaskID: taskID, + ContextID: "", + Kind: "status-update", + Status: protocol.TaskStatus{ + State: protocol.TaskStateCompleted, + Message: &responseMessage, + }, + Final: boolPtr(true), + }, + } + err = subscriber.Send(completedEvent) + if err != nil { + log.Errorf("Failed to send completed event: %v", err) + } + + // Add artifact + artifact := protocol.Artifact{ + ArtifactID: uuid.New().String(), + Name: stringPtr("Reversed Text"), + Description: stringPtr("The input text reversed"), + Parts: []protocol.Part{protocol.NewTextPart(result)}, + } + + artifactEvent := protocol.StreamingMessageEvent{ + Result: &protocol.TaskArtifactUpdateEvent{ + TaskID: taskID, + ContextID: "", + Kind: "artifact-update", + Artifact: artifact, + LastChunk: boolPtr(true), + }, + } + err = subscriber.Send(artifactEvent) + if err != nil { + log.Errorf("Failed to send artifact event: %v", err) + } + }() + + return &taskmanager.MessageProcessingResult{ + StreamingEvents: subscriber, + }, nil } // extractText extracts the text content from a message. func extractText(message protocol.Message) string { for _, part := range message.Parts { - if textPart, ok := part.(protocol.TextPart); ok { + if textPart, ok := part.(*protocol.TextPart); ok { return textPart.Text } } @@ -116,18 +194,19 @@ func main() { // Create the agent card. agentCard := server.AgentCard{ Name: "Simple A2A Example Server", - Description: stringPtr("A simple example A2A server that reverses text"), + Description: "A simple example A2A server that reverses text", URL: fmt.Sprintf("http://%s:%d/", *host, *port), Version: "1.0.0", Provider: &server.AgentProvider{ Organization: "tRPC-A2A-Go Examples", }, Capabilities: server.AgentCapabilities{ - Streaming: false, - StateTransitionHistory: true, + Streaming: boolPtr(true), + PushNotifications: boolPtr(false), + StateTransitionHistory: boolPtr(true), }, - DefaultInputModes: []string{string(protocol.PartTypeText)}, - DefaultOutputModes: []string{string(protocol.PartTypeText)}, + DefaultInputModes: []string{"text"}, + DefaultOutputModes: []string{"text"}, Skills: []server.AgentSkill{ { ID: "text_reversal", @@ -135,14 +214,14 @@ func main() { Description: stringPtr("Reverses the input text"), Tags: []string{"text", "processing"}, Examples: []string{"Hello, world!"}, - InputModes: []string{string(protocol.PartTypeText)}, - OutputModes: []string{string(protocol.PartTypeText)}, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, }, } - // Create the task processor. - processor := &simpleTaskProcessor{} + // Create the message processor. + processor := &simpleMessageProcessor{} // Create task manager and inject processor. taskManager, err := taskmanager.NewMemoryTaskManager(processor) @@ -163,7 +242,7 @@ func main() { // Start the server in a goroutine. go func() { serverAddr := fmt.Sprintf("%s:%d", *host, *port) - log.Printf("Starting server on %s...", serverAddr) + log.Infof("Starting server on %s...", serverAddr) if err := srv.Start(serverAddr); err != nil { log.Fatalf("Server failed: %v", err) } @@ -171,5 +250,5 @@ func main() { // Wait for termination signal. sig := <-sigChan - log.Printf("Received signal %v, shutting down...", sig) + log.Infof("Received signal %v, shutting down...", sig) } diff --git a/examples/streaming/client/main.go b/examples/streaming/client/main.go index ed37421..a46100b 100644 --- a/examples/streaming/client/main.go +++ b/examples/streaming/client/main.go @@ -144,7 +144,11 @@ func checkStreamingSupport(serverURL string) (bool, error) { return false, fmt.Errorf("error parsing agent card: %w", err) } - return agentCard.Capabilities.Streaming, nil + // Handle the new *bool type for Streaming capability + if agentCard.Capabilities.Streaming != nil { + return *agentCard.Capabilities.Streaming, nil + } + return false, nil } // processStreamEvents handles events received from a streaming task @@ -169,15 +173,15 @@ func processStreamEvents(ctx context.Context, streamChan <-chan protocol.TaskEve // Process the received event switch e := event.(type) { - case protocol.TaskStatusUpdateEvent: - log.Infof("Received Status Update - TaskID: %s, State: %s, Final: %t", e.ID, e.Status.State, e.Final) + case *protocol.TaskStatusUpdateEvent: + log.Infof("Received Status Update - TaskID: %s, State: %s, Final: %v", e.TaskID, e.Status.State, e.Final) if e.Status.Message != nil { log.Infof(" Status Message: Role=%s, Parts=%+v", e.Status.Message.Role, e.Status.Message.Parts) } // Exit when we receive a final status update (indicating a terminal state) // Per A2A spec, this should be the definitive way to know the task is complete - if e.IsFinal() { + if e.Final != nil && *e.Final { if e.Status.State == protocol.TaskStateCompleted { log.Info("Task completed successfully.") } else if e.Status.State == protocol.TaskStateFailed { @@ -188,14 +192,13 @@ func processStreamEvents(ctx context.Context, streamChan <-chan protocol.TaskEve log.Info("Received final status update, exiting.") return } - case protocol.TaskArtifactUpdateEvent: - log.Infof("Received Artifact Update - TaskID: %s, Index: %d, Append: %v, LastChunk: %v", - e.ID, e.Artifact.Index, e.Artifact.Append, e.Artifact.LastChunk) + case *protocol.TaskArtifactUpdateEvent: + log.Infof("Received Artifact Update - TaskID: %s, ArtifactID: %s", e.TaskID, e.Artifact.ArtifactID) log.Infof(" Artifact Parts: %+v", e.Artifact.Parts) // For artifact updates, we note it's the final artifact, // but we don't exit yet - per A2A spec, we should wait for the final status update - if e.IsFinal() { + if e.LastChunk != nil && *e.LastChunk { log.Info("Received final artifact update, waiting for final status.") } default: @@ -210,5 +213,5 @@ func getArtifactName(artifact protocol.Artifact) string { if artifact.Name != nil { return *artifact.Name } - return fmt.Sprintf("Unnamed artifact (index: %d)", artifact.Index) + return fmt.Sprintf("Unnamed artifact (ArtifactID: %s)", artifact.ArtifactID) } diff --git a/examples/streaming/server/main.go b/examples/streaming/server/main.go index 0e85386..9232089 100644 --- a/examples/streaming/server/main.go +++ b/examples/streaming/server/main.go @@ -4,14 +4,15 @@ // // trpc-a2a-go is licensed under the Apache License Version 2.0. -// Package main implements a streaming server for the A2A protocol. +// Package main implements a streaming A2A server example. +// This example demonstrates how to process tasks with streaming responses, +// breaking large content into chunks and sending them progressively. package main import ( "context" "flag" "fmt" - "log" "os" "os/signal" "sort" @@ -19,192 +20,260 @@ import ( "syscall" "time" + "github.com/google/uuid" + + "trpc.group/trpc-go/trpc-a2a-go/log" "trpc.group/trpc-go/trpc-a2a-go/protocol" "trpc.group/trpc-go/trpc-a2a-go/server" "trpc.group/trpc-go/trpc-a2a-go/taskmanager" ) -// streamingTaskProcessor implements the TaskProcessor interface for streaming responses. +// streamingMessageProcessor implements the MessageProcessor interface for streaming responses. // This processor breaks the input text into chunks and sends them back as a stream. -type streamingTaskProcessor struct{} +type streamingMessageProcessor struct{} -// Process implements the core streaming logic. +// ProcessMessage implements the MessageProcessor interface. // It breaks the input text into chunks and sends them back incrementally. -func (p *streamingTaskProcessor) Process( +func (p *streamingMessageProcessor) ProcessMessage( ctx context.Context, - taskID string, message protocol.Message, - handle taskmanager.TaskHandle, -) error { - log.Printf("Processing streaming task %s...", taskID) + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { + log.Infof("Processing streaming message...") // Extract text from the incoming message. text := extractText(message) if text == "" { errMsg := "input message must contain text" - log.Printf("Task %s failed: %s", taskID, errMsg) + log.Errorf("Message processing failed: %s", errMsg) - // Update status to Failed via handle. - failedMessage := protocol.NewMessage( + // Return error message directly + errorMessage := protocol.NewMessage( protocol.MessageRoleAgent, []protocol.Part{protocol.NewTextPart(errMsg)}, ) - _ = handle.UpdateStatus(protocol.TaskStateFailed, &failedMessage) - return fmt.Errorf(errMsg) - } - // Check if this is a streaming request - isStreaming := handle.IsStreamingRequest() + return &taskmanager.MessageProcessingResult{ + Result: &errorMessage, + }, nil + } - // If not streaming, use a simplified process flow with fewer updates - if !isStreaming { - log.Printf("Task %s using non-streaming mode", taskID) - return p.processNonStreaming(ctx, taskID, text, handle) + // For non-streaming processing, use simplified flow + if !options.Streaming { + log.Infof("Using non-streaming mode") + return p.processNonStreaming(ctx, text, handle) } // Continue with streaming process - log.Printf("Task %s using streaming mode", taskID) + log.Infof("Using streaming mode") - // Update status to Working with an initial message - initialMessage := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart("Starting to process your streaming data...")}, - ) - if err := handle.UpdateStatus(protocol.TaskStateWorking, &initialMessage); err != nil { - log.Printf("Error updating initial status for task %s: %v", taskID, err) - return err + // Create a task for streaming + taskID, err := handle.BuildTask(nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to build task: %w", err) } - // Split the text into chunks to simulate streaming processing - chunks := splitTextIntoChunks(text, 5) // Split into chunks of about 5 characters - totalChunks := len(chunks) - - // Process each chunk with a small delay to simulate real-time processing - for i, chunk := range chunks { - // Check for cancellation - if err := ctx.Err(); err != nil { - log.Printf("Task %s cancelled during streaming: %v", taskID, err) - _ = handle.UpdateStatus(protocol.TaskStateCanceled, nil) - return err + // Subscribe to the task for streaming events + subscriber, err := handle.SubScribeTask(&taskID) + if err != nil { + return nil, fmt.Errorf("failed to subscribe to task: %w", err) + } + + // Start streaming processing in a goroutine + go func() { + defer func() { + if subscriber != nil { + subscriber.Close() + } + }() + + contextID := handle.GetContextID() + // Send initial working status + workingEvent := protocol.StreamingMessageEvent{ + Result: &protocol.TaskStatusUpdateEvent{ + TaskID: taskID, + ContextID: contextID, + Kind: "status-update", + Status: protocol.TaskStatus{ + State: protocol.TaskStateWorking, + Message: &protocol.Message{ + MessageID: uuid.New().String(), + Kind: "message", + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{protocol.NewTextPart("Starting to process your streaming data...")}, + }, + }, + }, + } + err = subscriber.Send(workingEvent) + if err != nil { + log.Errorf("Failed to send working event: %v", err) } - // Process the chunk (in this example, just reverse it) - processedChunk := reverseString(chunk) + // Split the text into chunks to simulate streaming processing + chunks := splitTextIntoChunks(text, 5) // Split into chunks of about 5 characters + totalChunks := len(chunks) + + // Process each chunk with a small delay to simulate real-time processing + for i, chunk := range chunks { + // Check for cancellation + if err := ctx.Err(); err != nil { + log.Errorf("Task %s cancelled during streaming: %v", taskID, err) + cancelEvent := protocol.StreamingMessageEvent{ + Result: &protocol.TaskStatusUpdateEvent{ + TaskID: taskID, + ContextID: contextID, + Kind: "status-update", + Status: protocol.TaskStatus{ + State: protocol.TaskStateCanceled, + }, + Final: boolPtr(true), + }, + } + err = subscriber.Send(cancelEvent) + if err != nil { + log.Errorf("Failed to send cancel event: %v", err) + } + return + } - // Create a progress update message - progressMsg := fmt.Sprintf("Processing chunk %d of %d: %s -> %s", - i+1, totalChunks, chunk, processedChunk) - statusMsg := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart(progressMsg)}, - ) + // Process the chunk (in this example, just reverse it) + processedChunk := reverseString(chunk) + + // Create a progress update message + progressMsg := fmt.Sprintf("Processing chunk %d of %d: %s -> %s", + i+1, totalChunks, chunk, processedChunk) + + // Send progress status update + progressEvent := protocol.StreamingMessageEvent{ + Result: &protocol.TaskStatusUpdateEvent{ + TaskID: taskID, + ContextID: contextID, + Kind: "status-update", + Status: protocol.TaskStatus{ + State: protocol.TaskStateWorking, + Message: &protocol.Message{ + MessageID: uuid.New().String(), + Kind: "message", + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{protocol.NewTextPart(progressMsg)}, + }, + }, + }, + } + err = subscriber.Send(progressEvent) + if err != nil { + log.Errorf("Failed to send progress event: %v", err) + } - // Update status to show progress - if err := handle.UpdateStatus(protocol.TaskStateWorking, &statusMsg); err != nil { - log.Printf("Error updating progress status for task %s: %v", taskID, err) - // Continue processing despite update error - } + // Create an artifact for this chunk + isLastChunk := (i == totalChunks-1) + chunkArtifact := protocol.Artifact{ + ArtifactID: uuid.New().String(), + Name: stringPtr(fmt.Sprintf("Chunk %d of %d", i+1, totalChunks)), + Description: stringPtr("Streaming chunk of processed data"), + Parts: []protocol.Part{protocol.NewTextPart(processedChunk)}, + } - // Create an artifact for this chunk - isLastChunk := (i == totalChunks-1) - chunkArtifact := protocol.Artifact{ - Name: stringPtr(fmt.Sprintf("Chunk %d of %d", i+1, totalChunks)), - Description: stringPtr("Streaming chunk of processed data"), - Index: i, - Parts: []protocol.Part{protocol.NewTextPart(processedChunk)}, - Append: boolPtr(i > 0), // Append after the first chunk - LastChunk: boolPtr(isLastChunk), // Mark the last chunk - } + // Send artifact update event + artifactEvent := protocol.StreamingMessageEvent{ + Result: &protocol.TaskArtifactUpdateEvent{ + TaskID: taskID, + ContextID: contextID, + Kind: "artifact-update", + Artifact: chunkArtifact, + Append: boolPtr(i > 0), // Append after the first chunk + LastChunk: boolPtr(isLastChunk), // Mark the last chunk + }, + } + err = subscriber.Send(artifactEvent) + if err != nil { + log.Errorf("Failed to send artifact event: %v", err) + } - // Add the artifact - if err := handle.AddArtifact(chunkArtifact); err != nil { - log.Printf("Error adding artifact for task %s chunk %d: %v", taskID, i+1, err) - // Continue processing despite artifact error + // Simulate processing time + select { + case <-ctx.Done(): + log.Infof("Task %s cancelled during delay: %v", taskID, ctx.Err()) + cancelEvent := protocol.StreamingMessageEvent{ + Result: &protocol.TaskStatusUpdateEvent{ + TaskID: taskID, + ContextID: contextID, + Kind: "status-update", + Status: protocol.TaskStatus{ + State: protocol.TaskStateCanceled, + }, + Final: boolPtr(true), + }, + } + err = subscriber.Send(cancelEvent) + if err != nil { + log.Errorf("Failed to send cancel event: %v", err) + } + return + case <-time.After(500 * time.Millisecond): // Simulate work with delay + // Continue processing + } } - // Simulate processing time - select { - case <-ctx.Done(): - log.Printf("Task %s cancelled during delay: %v", taskID, ctx.Err()) - _ = handle.UpdateStatus(protocol.TaskStateCanceled, nil) - return ctx.Err() - case <-time.After(500 * time.Millisecond): // Simulate work with delay - // Continue processing + // Final completion status update + completeEvent := protocol.StreamingMessageEvent{ + Result: &protocol.TaskStatusUpdateEvent{ + TaskID: taskID, + ContextID: contextID, + Kind: "status-update", + Status: protocol.TaskStatus{ + State: protocol.TaskStateCompleted, + Message: &protocol.Message{ + MessageID: uuid.New().String(), + Kind: "message", + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{protocol.NewTextPart(fmt.Sprintf("Completed processing all %d chunks successfully!", totalChunks))}, + }, + }, + Final: boolPtr(true), + }, + } + err = subscriber.Send(completeEvent) + if err != nil { + log.Errorf("Failed to send complete event: %v", err) } - } - // Final completion status update - completeMessage := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{ - protocol.NewTextPart( - fmt.Sprintf("Completed processing all %d chunks successfully!", totalChunks))}, - ) - if err := handle.UpdateStatus(protocol.TaskStateCompleted, &completeMessage); err != nil { - log.Printf("Error updating final status for task %s: %v", taskID, err) - return fmt.Errorf("failed to update final task status: %w", err) - } + log.Infof("Task %s streaming completed successfully.", taskID) + }() - log.Printf("Task %s streaming completed successfully.", taskID) - return nil + return &taskmanager.MessageProcessingResult{ + StreamingEvents: subscriber, + }, nil } // processNonStreaming handles processing for non-streaming requests // It processes the entire text at once and returns a single result -func (p *streamingTaskProcessor) processNonStreaming( +func (p *streamingMessageProcessor) processNonStreaming( ctx context.Context, - taskID string, text string, - handle taskmanager.TaskHandle, -) error { - // Update status to Working with an initial message - initialMessage := protocol.NewMessage( - protocol.MessageRoleAgent, - []protocol.Part{protocol.NewTextPart("Processing your text...")}, - ) - if err := handle.UpdateStatus(protocol.TaskStateWorking, &initialMessage); err != nil { - log.Printf("Error updating initial status for task %s: %v", taskID, err) - return err - } - + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { // Process the entire text at once processedText := reverseString(text) - // Create a single artifact with the result - artifact := protocol.Artifact{ - Name: stringPtr("Processed Text"), - Description: stringPtr("Complete processed text"), - Index: 0, - Parts: []protocol.Part{protocol.NewTextPart(processedText)}, - LastChunk: boolPtr(true), - } - - // Add the artifact - if err := handle.AddArtifact(artifact); err != nil { - log.Printf("Error adding artifact for task %s: %v", taskID, err) - } - - // Final completion status update - completeMessage := protocol.NewMessage( + // Return a direct message response + responseMessage := protocol.NewMessage( protocol.MessageRoleAgent, - []protocol.Part{ - protocol.NewTextPart( - fmt.Sprintf("Processing complete. Input: %s -> Output: %s", text, processedText))}, + []protocol.Part{protocol.NewTextPart(fmt.Sprintf("Processing complete. Input: %s -> Output: %s", text, processedText))}, ) - if err := handle.UpdateStatus(protocol.TaskStateCompleted, &completeMessage); err != nil { - log.Printf("Error updating final status for task %s: %v", taskID, err) - return fmt.Errorf("failed to update final task status: %w", err) - } - log.Printf("Task %s non-streaming completed successfully.", taskID) - return nil + return &taskmanager.MessageProcessingResult{ + Result: &responseMessage, + }, nil } // extractText extracts the first text part from a message. func extractText(message protocol.Message) string { for _, part := range message.Parts { // Type assert to the concrete TextPart type. - if p, ok := part.(protocol.TextPart); ok { + if p, ok := part.(*protocol.TextPart); ok { return p.Text } } @@ -351,21 +420,21 @@ func main() { serverURL := fmt.Sprintf("http://%s/", address) // Create the agent card - description := "A2A streaming example server that processes text in chunks" agentCard := server.AgentCard{ Name: "Streaming Text Processor", - Description: &description, + Description: "A2A streaming example server that processes text in chunks", URL: serverURL, Version: "1.0.0", Provider: &server.AgentProvider{ Organization: "tRPC-A2A-go Examples", }, Capabilities: server.AgentCapabilities{ - Streaming: true, - StateTransitionHistory: true, + Streaming: boolPtr(true), + PushNotifications: boolPtr(false), + StateTransitionHistory: boolPtr(true), }, - DefaultInputModes: []string{string(protocol.PartTypeText)}, - DefaultOutputModes: []string{string(protocol.PartTypeText)}, + DefaultInputModes: []string{"text"}, + DefaultOutputModes: []string{"text"}, Skills: []server.AgentSkill{ { ID: "streaming_processor", @@ -377,14 +446,14 @@ func main() { "Lorem ipsum dolor sit amet", "This demonstrates streaming capabilities", }, - InputModes: []string{string(protocol.PartTypeText)}, - OutputModes: []string{string(protocol.PartTypeText)}, + InputModes: []string{"text"}, + OutputModes: []string{"text"}, }, }, } - // Create the TaskProcessor (streaming logic) - processor := &streamingTaskProcessor{} + // Create the MessageProcessor (streaming logic) + processor := &streamingMessageProcessor{} // Create the TaskManager, injecting the processor taskManager, err := taskmanager.NewMemoryTaskManager(processor) @@ -404,7 +473,7 @@ func main() { // Start the server in a goroutine go func() { - log.Printf("Starting streaming server on %s...", address) + log.Infof("Starting streaming server on %s...", address) if err := srv.Start(address); err != nil { log.Fatalf("Server error: %v", err) } @@ -412,7 +481,7 @@ func main() { // Wait for shutdown signal sig := <-sigChan - log.Printf("Received signal %v, shutting down server...", sig) + log.Infof("Received signal %v, shutting down server...", sig) // Graceful shutdown ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -422,7 +491,7 @@ func main() { log.Fatalf("Error during server shutdown: %v", err) } - log.Println("Server shutdown complete") + log.Infof("Server shutdown complete") } // Helper functions to create pointers diff --git a/go.mod b/go.mod index 6e76234..8972da0 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.23.7 require ( github.com/golang-jwt/jwt/v5 v5.2.2 + github.com/google/uuid v1.6.0 github.com/lestrrat-go/jwx/v2 v2.1.4 github.com/stretchr/testify v1.10.0 go.uber.org/zap v1.27.0 diff --git a/go.sum b/go.sum index c42a261..87fad39 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,8 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= diff --git a/internal/jsonrpc/jsonrpc_test.go b/internal/jsonrpc/jsonrpc_test.go index 5f552ab..3ee842e 100644 --- a/internal/jsonrpc/jsonrpc_test.go +++ b/internal/jsonrpc/jsonrpc_test.go @@ -226,10 +226,10 @@ func TestNewRequest(t *testing.T) { expectJSON: `{"jsonrpc":"2.0","id":123,"method":"get/data"}`, }, { - name: "Null ID", + name: "Auto-generated ID", method: "notify/update", id: nil, - expectJSON: `{"jsonrpc":"2.0","method":"notify/update"}`, + expectJSON: ``, // Will be checked separately since ID is auto-generated }, { name: "Complex Method Path", @@ -246,6 +246,15 @@ func TestNewRequest(t *testing.T) { // Verify individual fields assert.Equal(t, Version, req.JSONRPC, "JSONRPC version should be set correctly") assert.Equal(t, tc.method, req.Method, "Method should match input") + + // Special handling for auto-generated ID test case + if tc.name == "Auto-generated ID" { + assert.NotNil(t, req.ID, "ID should be auto-generated when nil is passed") + assert.NotEmpty(t, req.ID, "Auto-generated ID should not be empty") + // Skip JSON comparison for auto-generated ID since it's unpredictable + return + } + assert.Equal(t, tc.id, req.ID, "ID should match input") assert.Nil(t, req.Params, "Params should be nil by default") diff --git a/internal/jsonrpc/request.go b/internal/jsonrpc/request.go index e5d0680..c2e3f80 100644 --- a/internal/jsonrpc/request.go +++ b/internal/jsonrpc/request.go @@ -6,7 +6,11 @@ package jsonrpc -import "encoding/json" +import ( + "encoding/json" + + "trpc.group/trpc-go/trpc-a2a-go/protocol" +) // Request represents a JSON-RPC request object. type Request struct { @@ -20,11 +24,17 @@ type Request struct { } // NewRequest creates a new JSON-RPC request with the given method and ID. +// If id is nil, a new ID will be automatically generated since A2A protocol +// requires responses for all requests. func NewRequest(method string, id interface{}) *Request { + if id == nil { + // A2A protocol doesn't use notifications - all requests need responses + id = protocol.GenerateRPCID() + } return &Request{ Message: Message{ - JSONRPC: Version, ID: id, + JSONRPC: Version, }, Method: method, } diff --git a/internal/sse/sse.go b/internal/sse/sse.go index 057f88e..4de4ed5 100644 --- a/internal/sse/sse.go +++ b/internal/sse/sse.go @@ -21,7 +21,7 @@ import ( // CloseEventData represents the data payload for a close event. // Used when formatting SSE messages indicating stream closure. type CloseEventData struct { - TaskID string `json:"taskId"` + ID string `json:"taskId"` Reason string `json:"reason"` } diff --git a/internal/sse/sse_test.go b/internal/sse/sse_test.go index 4308ff8..720b906 100644 --- a/internal/sse/sse_test.go +++ b/internal/sse/sse_test.go @@ -1,3 +1,9 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + package sse import ( @@ -141,7 +147,7 @@ func TestFormatEvent(t *testing.T) { { name: "struct data", eventType: "close", - data: CloseEventData{TaskID: "123", Reason: "completed"}, + data: CloseEventData{ID: "123", Reason: "completed"}, expected: "event: close\ndata: {\"taskId\":\"123\",\"reason\":\"completed\"}\n\n", }, { @@ -215,7 +221,7 @@ func TestReadEventSequence(t *testing.T) { func TestCloseEventDataMarshaling(t *testing.T) { closeData := CloseEventData{ - TaskID: "task123", + ID: "task123", Reason: "test completed", } @@ -229,8 +235,8 @@ func TestCloseEventDataMarshaling(t *testing.T) { t.Fatalf("Failed to unmarshal CloseEventData: %v.", err) } - if unmarshaled.TaskID != closeData.TaskID { - t.Errorf("Expected TaskID %q, got %q.", closeData.TaskID, unmarshaled.TaskID) + if unmarshaled.ID != closeData.ID { + t.Errorf("Expected TaskID %q, got %q.", closeData.ID, unmarshaled.ID) } if unmarshaled.Reason != closeData.Reason { diff --git a/protocol/protocol.go b/protocol/protocol.go index 32bf6a0..c0b8712 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -7,21 +7,42 @@ // Package protocol defines constants and potentially shared types for the A2A protocol itself. package protocol -// A2A RPC Method Names define the standard method strings used in the A2A protocol's Task Service. +// Method names for A2A 0.2.2 specification. const ( - MethodTasksSend = "tasks/send" - MethodTasksSendSubscribe = "tasks/sendSubscribe" - MethodTasksGet = "tasks/get" - MethodTasksCancel = "tasks/cancel" - MethodTasksPushNotificationSet = "tasks/pushNotification/set" - MethodTasksPushNotificationGet = "tasks/pushNotification/get" - MethodTasksResubscribe = "tasks/resubscribe" + // MethodMessageSend corresponds to the 'message/send' RPC method. + MethodMessageSend = "message/send" + // MethodMessageStream corresponds to the 'message/stream' RPC method. + MethodMessageStream = "message/stream" + // MethodTasksGet corresponds to the 'tasks/get' RPC method. + MethodTasksGet = "tasks/get" + // MethodTasksCancel corresponds to the 'tasks/cancel' RPC method. + MethodTasksCancel = "tasks/cancel" + // MethodTasksPushNotificationConfigSet corresponds to the 'tasks/pushNotificationConfig/set' RPC method. + MethodTasksPushNotificationConfigSet = "tasks/pushNotificationConfig/set" + // MethodTasksPushNotificationConfigGet corresponds to the 'tasks/pushNotificationConfig/get' RPC method. + MethodTasksPushNotificationConfigGet = "tasks/pushNotificationConfig/get" + // MethodTasksResubscribe corresponds to the 'tasks/resubscribe' RPC method. + MethodTasksResubscribe = "tasks/resubscribe" + // MethodAgentAuthenticatedExtendedCard corresponds to the 'agent/authenticatedExtendedCard' HTTP GET endpoint. + MethodAgentAuthenticatedExtendedCard = "agent/authenticatedExtendedCard" + + // deprecated methods + MethodTasksSend = "tasks/send" // Deprecated: use MethodMessageSend + // deprecated methods + MethodTasksSendSubscribe = "tasks/sendSubscribe" // Deprecated: use MethodMessageStream + // deprecated methods + MethodTasksPushNotificationSet = "tasks/pushNotification/set" // Deprecated: use MethodTasksPushNotificationConfigSet + // deprecated methods + MethodTasksPushNotificationGet = "tasks/pushNotification/get" // Deprecated: use MethodTasksPushNotificationConfigGet ) // A2A SSE Event Types define the standard event type strings used in A2A SSE streams. const ( - EventTaskStatusUpdate = "task_status_update" - EventTaskArtifactUpdate = "task_artifact_update" + EventStatusUpdate = "task_status_update" + EventArtifactUpdate = "task_artifact_update" + EventTask = "task" + EventMessage = "message" + // EventClose is used internally by this implementation's server to signal stream closure. // Note: This might not be part of the formal A2A spec but is used in server logic. EventClose = "close" diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go index 8c6b739..e9758c5 100644 --- a/protocol/protocol_test.go +++ b/protocol/protocol_test.go @@ -38,9 +38,9 @@ func TestMethodConstants(t *testing.T) { // and maintain their expected values. func TestEventTypeConstants(t *testing.T) { // Test SSE event type constants - assert.Equal(t, "task_status_update", protocol.EventTaskStatusUpdate, + assert.Equal(t, "task_status_update", protocol.EventStatusUpdate, "EventTaskStatusUpdate should be 'task_status_update'") - assert.Equal(t, "task_artifact_update", protocol.EventTaskArtifactUpdate, + assert.Equal(t, "task_artifact_update", protocol.EventArtifactUpdate, "EventTaskArtifactUpdate should be 'task_artifact_update'") assert.Equal(t, "close", protocol.EventClose, "EventClose should be 'close'") } @@ -63,11 +63,11 @@ func TestConstantRelationships(t *testing.T) { "Push notification set and get methods should be distinct") // Test that event types are distinct - assert.True(t, protocol.EventTaskStatusUpdate != protocol.EventTaskArtifactUpdate, + assert.True(t, protocol.EventStatusUpdate != protocol.EventArtifactUpdate, "Status and artifact event types should be distinct") - assert.True(t, protocol.EventTaskStatusUpdate != protocol.EventClose, + assert.True(t, protocol.EventStatusUpdate != protocol.EventClose, "Status update and close event types should be distinct") - assert.True(t, protocol.EventTaskArtifactUpdate != protocol.EventClose, + assert.True(t, protocol.EventArtifactUpdate != protocol.EventClose, "Artifact update and close event types should be distinct") // Test that HTTP endpoint paths are distinct diff --git a/protocol/types.go b/protocol/types.go index b295ecb..bff04c4 100644 --- a/protocol/types.go +++ b/protocol/types.go @@ -11,6 +11,8 @@ import ( "encoding/json" "fmt" "time" + + "github.com/google/uuid" ) // TaskState represents the lifecycle state of a task. @@ -31,10 +33,102 @@ const ( TaskStateCanceled TaskState = "canceled" // TaskStateFailed is the state when the task failed during processing. TaskStateFailed TaskState = "failed" + // TaskStateRejected is the state when the task was rejected by the agent. + TaskStateRejected TaskState = "rejected" + // TaskStateAuthRequired is the state when the task requires authentication before processing. + TaskStateAuthRequired TaskState = "auth-required" // TaskStateUnknown is the state when the task is in an unknown or indeterminate state. TaskStateUnknown TaskState = "unknown" ) +// Event is an interface that represents the kind of the struct. +type Event interface { + GetKind() string +} + +// GetKind returns the kind of the result. +func (m *Message) GetKind() string { return KindMessage } + +// GetKind returns the kind of the task. +func (r *Task) GetKind() string { return KindTask } + +// GetKind returns the kind of the task status update event. +func (r *TaskStatusUpdateEvent) GetKind() string { return KindTaskStatusUpdate } + +// GetKind returns the kind of the task artifact update event. +func (r *TaskArtifactUpdateEvent) GetKind() string { return KindTaskArtifactUpdate } + +// GenerateMessageID generates a new unique message ID. +func GenerateMessageID() string { + id := uuid.New() + return "msg-" + id.String() +} + +// GenerateContextID generates a new unique context ID for a task. +func GenerateContextID() string { + id := uuid.New() + return "ctx-" + id.String() +} + +// GenerateTaskID generates a new unique task ID. +func GenerateTaskID() string { + id := uuid.New() + return "task-" + id.String() +} + +// GenerateArtifactID generates a new unique artifact ID. +func GenerateArtifactID() string { + id := uuid.New() + return "artifact-" + id.String() +} + +// GenerateRPCID generates a new unique RPC ID. +func GenerateRPCID() string { + id := uuid.New() + return id.String() +} + +// UnaryMessageResult is an interface representing a result of SendMessage. +// It only supports Message or Task. +type UnaryMessageResult interface { + unaryMessageResultMarker() + GetKind() string +} + +func (Message) unaryMessageResultMarker() {} +func (Task) unaryMessageResultMarker() {} + +// Part is an interface representing a segment of a message (text, file, or data). +// It uses an unexported method to ensure only defined part types implement it. +// See A2A Spec section on Message Parts. +// Exported interface. +type Part interface { + partMarker() // Internal marker method. + GetKind() string +} + +func (TextPart) partMarker() {} +func (FilePart) partMarker() {} +func (DataPart) partMarker() {} + +// Kind constants define the possible kinds of the struct. +const ( + // KindMessage is the kind of the message. + KindMessage = "message" + // KindTask is the kind of the task. + KindTask = "task" + // KindTaskStatusUpdate is the kind of the task status update event. + KindTaskStatusUpdate = "status_update" + // KindTaskArtifactUpdate is the kind of the task artifact update event. + KindTaskArtifactUpdate = "artifact_update" + // KindData is the kind of the data. + KindData = "data" + // KindFile is the kind of the file. + KindFile = "file" + // KindText is the kind of the text. + KindText = "text" +) + // MessageRole indicates the originator of a message (user or agent). // See A2A Spec section on Messages. type MessageRole string @@ -47,74 +141,138 @@ const ( MessageRoleAgent MessageRole = "agent" ) -// PartType indicates the type of content within a message part. -// See A2A Spec section on Message Parts. -type PartType string - -// PartType constants define the supported types for message parts. -const ( - // PartTypeText is for simple text content. - PartTypeText PartType = "text" - // PartTypeFile is for file references (path or URL). - PartTypeFile PartType = "file" - // PartTypeData is for raw binary data. - PartTypeData PartType = "data" -) - -// FileContent represents file data, either directly embedded or via URI. -// Corresponds to the 'file' structure in A2A Message Parts. -type FileContent struct { - // Name is the optional filename. - Name *string `json:"name,omitempty"` - // MimeType is the optional MIME type. - MimeType *string `json:"mimeType,omitempty"` - // Bytes is the optional base64-encoded content. - Bytes *string `json:"bytes,omitempty"` - // URI is the optional URI pointing to the content. - URI *string `json:"uri,omitempty"` +// FileUnion represents the union type for file content. +// It contains either FileWithBytes or FileWithURI. +type FileUnion interface { + fileUnionMarker() } +func (f *FileWithBytes) fileUnionMarker() {} +func (f *FileWithURI) fileUnionMarker() {} + // TextPart represents a text segment within a message. // Corresponds to the 'text' part type in A2A Message Parts. type TextPart struct { - // Type is the type of the part. - Type PartType `json:"type"` + // Kind is the type of the part. + Kind string `json:"kind"` // Text is the text content. Text string `json:"text"` // Metadata is the optional metadata. Metadata map[string]interface{} `json:"metadata,omitempty"` } +// GetKind returns the kind of the part. +func (t TextPart) GetKind() string { + return KindText +} + // FilePart represents a file included in a message. // Corresponds to the 'file' part type in A2A Message Parts. type FilePart struct { - // Type is the type of the part. - Type PartType `json:"type"` + // Kind is the type of the part. + Kind string `json:"kind"` // File is the file content. - File FileContent `json:"file"` + File FileUnion `json:"file"` // Metadata is the optional metadata. Metadata map[string]interface{} `json:"metadata,omitempty"` } +// GetKind returns the kind of the part. +func (f FilePart) GetKind() string { + return KindFile +} + +// UnmarshalJSON implements custom unmarshalling logic for FilePart +// to handle the polymorphic FileUnion interface. +func (f *FilePart) UnmarshalJSON(data []byte) error { + type Alias FilePart // Alias to avoid recursion. + temp := &struct { + File json.RawMessage `json:"file"` // Unmarshal file into RawMessage first. + *Alias + }{ + Alias: (*Alias)(f), + } + if err := json.Unmarshal(data, &temp); err != nil { + return fmt.Errorf("failed to unmarshal file part base: %w", err) + } + + // Now determine the concrete type of FileUnion based on fields present + var fileContent map[string]interface{} + if err := json.Unmarshal(temp.File, &fileContent); err != nil { + return fmt.Errorf("failed to unmarshal file content: %w", err) + } + + // Check if it has "bytes" field (FileWithBytes) or "uri" field (FileWithURI) + if _, hasBytes := fileContent["bytes"]; hasBytes { + var fileWithBytes FileWithBytes + if err := json.Unmarshal(temp.File, &fileWithBytes); err != nil { + return fmt.Errorf("failed to unmarshal FileWithBytes: %w", err) + } + f.File = &fileWithBytes + } else if _, hasURI := fileContent["uri"]; hasURI { + var fileWithURI FileWithURI + if err := json.Unmarshal(temp.File, &fileWithURI); err != nil { + return fmt.Errorf("failed to unmarshal FileWithURI: %w", err) + } + f.File = &fileWithURI + } else { + return fmt.Errorf("unknown file type: must have either 'bytes' or 'uri' field") + } + + return nil +} + // DataPart represents arbitrary structured data (JSON) within a message. // Corresponds to the 'data' part type in A2A Message Parts. type DataPart struct { - // Type is the type of the part. - Type PartType `json:"type"` + // Kind is the type of the part. + Kind string `json:"kind"` // Data is the actual data payload. Data interface{} `json:"data"` // Metadata is the optional metadata. Metadata map[string]interface{} `json:"metadata,omitempty"` } -// AuthenticationInfo represents authentication details for external services. -type AuthenticationInfo struct { - // Schemes is a list of authentication schemes supported. - Schemes []string `json:"schemes"` +// GetKind returns the kind of the part. +func (d DataPart) GetKind() string { + return KindData +} + +// FileWithBytes represents file data with embedded content. +// This is one variant of the File union type in A2A 0.2.2. +type FileWithBytes struct { + // Name is the optional filename. + Name *string `json:"name,omitempty"` + // MimeType is the optional MIME type. + MimeType *string `json:"mimeType,omitempty"` + // Bytes is the required base64-encoded content. + Bytes string `json:"bytes"` +} + +// FileWithURI represents file data with URI reference. +// This is one variant of the File union type in A2A 0.2.2. +type FileWithURI struct { + // Name is the optional filename. + Name *string `json:"name,omitempty"` + // MimeType is the optional MIME type. + MimeType *string `json:"mimeType,omitempty"` + // URI is the required URI pointing to the content. + URI string `json:"uri"` +} + +// PushNotificationAuthenticationInfo represents authentication details for external services. +// Renamed from AuthenticationInfo for A2A 0.2.2 specification compliance. +type PushNotificationAuthenticationInfo struct { // Credentials are the actual authentication credentials. Credentials *string `json:"credentials,omitempty"` + // Schemes is a list of authentication schemes supported. + Schemes []string `json:"schemes"` } +// AuthenticationInfo is deprecated, use PushNotificationAuthenticationInfo instead. +// TODO: Remove in next major version. +type AuthenticationInfo = PushNotificationAuthenticationInfo + // OAuth2AuthInfo contains OAuth2-specific authentication details. type OAuth2AuthInfo struct { // ClientID is the OAuth2 client ID. @@ -161,48 +319,51 @@ type APIKeyAuthInfo struct { // PushNotificationConfig represents the configuration for task push notifications. type PushNotificationConfig struct { - // URL is the endpoint where notifications should be sent. - URL string `json:"url"` - // Token is an optional authentication token. - Token string `json:"token,omitempty"` // Authentication contains optional authentication details. Authentication *AuthenticationInfo `json:"authentication,omitempty"` - // Metadata is optional additional configuration data. + // Push Notification ID, created by server to support multiple push notification. + ID string `json:"id,omitempty"` + // Token is an optional authentication token. + Token string `json:"token,omitempty"` + // URL is the endpoint where notifications should be sent. + URL string `json:"url"` + // Metadata is the optional metadata. Metadata map[string]interface{} `json:"metadata,omitempty"` } // TaskPushNotificationConfig associates a task ID with push notification settings. type TaskPushNotificationConfig struct { - // ID is the unique task identifier. - ID string `json:"id"` + // RPCID is the ID of json-rpc. + RPCID string `json:"-"` // PushNotificationConfig contains the notification settings. PushNotificationConfig PushNotificationConfig `json:"pushNotificationConfig"` - // Metadata is optional additional configuration data. + // TaskID is the unique task identifier. + TaskID string `json:"taskId"` + // Metadata is the optional metadata. Metadata map[string]interface{} `json:"metadata,omitempty"` } -// Part is an interface representing a segment of a message (text, file, or data). -// It uses an unexported method to ensure only defined part types implement it. -// See A2A Spec section on Message Parts. -// Exported interface. -type Part interface { - partMarker() // Internal marker method. -} - -// partMarker implementations for concrete types (unexported methods). -func (TextPart) partMarker() {} -func (FilePart) partMarker() {} -func (DataPart) partMarker() {} - // Message represents a single exchange between a user and an agent. // See A2A Spec section on Messages. type Message struct { - // Role is the sender of the message. - Role MessageRole `json:"role"` - // Parts is the content parts (must implement Part). - Parts []Part `json:"parts"` + // ContextID is the optional context identifier for the message. + ContextID *string `json:"contextId,omitempty"` + // Extensions is the optional list of extension URIs. + Extensions []string `json:"extensions,omitempty"` + // Kind is the type discriminator for this message (always "message"). + Kind string `json:"kind"` + // MessageID is the unique identifier for this message. + MessageID string `json:"messageId"` // Metadata is the optional metadata. Metadata map[string]interface{} `json:"metadata,omitempty"` + // Parts is the content parts (must implement Part). + Parts []Part `json:"parts"` + // ReferenceTaskIDs is the optional list of referenced task IDs. + ReferenceTaskIDs []string `json:"referenceTaskIds,omitempty"` + // Role is the sender of the message. + Role MessageRole `json:"role"` + // TaskID is the optional task identifier this message belongs to. + TaskID *string `json:"taskId,omitempty"` } // UnmarshalJSON implements custom unmarshalling logic for Message @@ -230,23 +391,21 @@ func (m *Message) UnmarshalJSON(data []byte) error { return nil } -// Artifact represents an output generated by a task, potentially chunked for streaming. -// See A2A Spec section on Artifacts. +// Artifact represents an output generated by a task. +// See A2A Spec 0.2.2 section on Artifacts. type Artifact struct { + // ArtifactID is the unique identifier for the artifact. + ArtifactID string `json:"artifactId"` // Name is the name of the artifact. Name *string `json:"name,omitempty"` // Description is the description of the artifact. Description *string `json:"description,omitempty"` // Parts is the content parts of the artifact. Parts []Part `json:"parts"` - // Index is the index for ordering streamed artifacts. - Index int `json:"index"` - // Append is a hint for the client to append data (streaming). - Append *bool `json:"append,omitempty"` - // LastChunk is a flag indicating if this is the final chunk of an artifact stream. - LastChunk *bool `json:"lastChunk,omitempty"` // Metadata is optional metadata for the artifact. Metadata map[string]interface{} `json:"metadata,omitempty"` + // Extensions is the optional list of extension URIs. + Extensions []string `json:"extensions,omitempty"` } // UnmarshalJSON implements custom unmarshalling logic for Artifact @@ -275,40 +434,40 @@ func (a *Artifact) UnmarshalJSON(data []byte) error { } // unmarshalPart determines the concrete type of a Part from raw JSON -// based on the "type" field and unmarshals into that concrete type. +// based on the "kind" field and unmarshals into that concrete type. // Internal helper function. func unmarshalPart(rawPart json.RawMessage) (Part, error) { - // Peek at the type field to determine the concrete type. + // First, determine the type by unmarshalling just the "kind" field. var typeDetect struct { - Type PartType `json:"type"` + Kind string `json:"kind"` } if err := json.Unmarshal(rawPart, &typeDetect); err != nil { - return nil, fmt.Errorf("cannot detect part type: %w. Data: %s", err, string(rawPart)) + return nil, fmt.Errorf("failed to detect part type: %w", err) } // Unmarshal into the correct concrete type. - switch typeDetect.Type { - case PartTypeText: + switch typeDetect.Kind { + case KindText: var p TextPart if err := json.Unmarshal(rawPart, &p); err != nil { return nil, fmt.Errorf("failed to unmarshal TextPart: %w", err) } - return p, nil - case PartTypeFile: + return &p, nil + case KindFile: var p FilePart if err := json.Unmarshal(rawPart, &p); err != nil { return nil, fmt.Errorf("failed to unmarshal FilePart: %w", err) } - return p, nil - case PartTypeData: + return &p, nil + case KindData: var p DataPart if err := json.Unmarshal(rawPart, &p); err != nil { return nil, fmt.Errorf("failed to unmarshal DataPart: %w", err) } - return p, nil + return &p, nil default: // If we need to handle unknown part types gracefully (e.g., store raw JSON), // we would add that logic here. For now, treat as an error. - return nil, fmt.Errorf("unsupported part type: %s", typeDetect.Type) + return nil, fmt.Errorf("unsupported part kind: %s", typeDetect.Kind) } } @@ -326,21 +485,24 @@ type TaskStatus struct { // Task represents a unit of work being processed by the agent. // See A2A Spec section on Tasks. type Task struct { - // ID is the unique task identifier. - ID string `json:"id"` - // SessionID is the optional session identifier for grouping tasks. - SessionID *string `json:"sessionId,omitempty"` - // Status is the current task status. - Status TaskStatus `json:"status"` // Artifacts is the accumulated artifacts generated by the task. Artifacts []Artifact `json:"artifacts,omitempty"` + // ContextID is the unique context identifier for the task. + ContextID string `json:"contextId"` // History is the history of messages exchanged for this task (if tracked). History []Message `json:"history,omitempty"` + // ID is the unique task identifier. + ID string `json:"id"` + // Kind is the event type discriminator (always "task"). + Kind string `json:"kind"` // Metadata is the optional metadata associated with the task. Metadata map[string]interface{} `json:"metadata,omitempty"` + // Status is the current task status. + Status TaskStatus `json:"status"` } // TaskEvent is an interface for events published during task execution (streaming). +// Deprecated, use StreamingMessageEvent instead. // It uses an unexported method to ensure only defined event types implement it. // See A2A Spec section on Streaming and Events. // Exported interface. @@ -351,50 +513,172 @@ type TaskEvent interface { } // TaskStatusUpdateEvent indicates a change in the task's lifecycle state. -// Corresponds to the 'task_status_update' event in A2A Spec. +// Corresponds to the 'task_status_update' event in A2A Spec 0.2.2. type TaskStatusUpdateEvent struct { - // ID is the ID of the task. - ID string `json:"id"` - // Status is the new status. - Status TaskStatus `json:"status"` - // Final is a flag indicating if this is the terminal status event. - Final bool `json:"final"` + // ContextID is the context ID of the task. + ContextID string `json:"contextId"` + // Final is a flag indicating if this is the final event for the task. + Final *bool `json:"final"` + // Kind is the event type discriminator. + Kind string `json:"kind"` // Metadata is the optional metadata. Metadata map[string]interface{} `json:"metadata,omitempty"` + // Status is the new status. + Status TaskStatus `json:"status"` + // TaskID is the ID of the task. + TaskID string `json:"taskId"` } // eventMarker implementation (unexported method). func (TaskStatusUpdateEvent) eventMarker() {} -// IsFinal implements TaskEvent. -func (e TaskStatusUpdateEvent) IsFinal() bool { - return e.Final +// IsFinal returns true if this is a final event. +func (r *TaskStatusUpdateEvent) IsFinal() bool { + return r.Final != nil && *r.Final } // TaskArtifactUpdateEvent indicates a new or updated artifact chunk. -// Corresponds to the 'task_artifact_update' event in A2A Spec. +// Corresponds to the 'task_artifact_update' event in A2A Spec 0.2.2. type TaskArtifactUpdateEvent struct { - // ID is the ID of the task. - ID string `json:"id"` + // Append is a hint for the client to append data (streaming). + Append *bool `json:"append,omitempty"` // Artifact is the artifact data. Artifact Artifact `json:"artifact"` - // Final is a flag indicating if this is the final event for the task (usually linked to Artifact.LastChunk). - Final bool `json:"final"` + // ContextID is the context ID of the task. + ContextID string `json:"contextId"` + // Kind is the event type discriminator. + Kind string `json:"kind"` + // LastChunk is a flag indicating if this is the final chunk of an artifact stream. + LastChunk *bool `json:"lastChunk,omitempty"` // Metadata is optional metadata. Metadata map[string]interface{} `json:"metadata,omitempty"` + // TaskID is the ID of the task. + TaskID string `json:"taskId"` } // eventMarker implementation (unexported method). func (TaskArtifactUpdateEvent) eventMarker() {} -// IsFinal implements TaskEvent. -func (e TaskArtifactUpdateEvent) IsFinal() bool { - return e.Final +// IsFinal returns true if this is the final artifact event. +func (r TaskArtifactUpdateEvent) IsFinal() bool { + return r.LastChunk != nil && *r.LastChunk +} + +// NewTask creates a new Task with initial state (Submitted). +func NewTask(id string, contextID string) *Task { + return &Task{ + ID: id, + ContextID: contextID, + Kind: KindTask, + Status: TaskStatus{ + State: TaskStateSubmitted, + Timestamp: time.Now().UTC().Format(time.RFC3339), + }, + Metadata: make(map[string]interface{}), + } +} + +// NewMessage creates a new Message with the specified role and parts. +func NewMessage(role MessageRole, parts []Part) Message { + messageID := GenerateMessageID() + return Message{ + Role: role, + Parts: parts, + MessageID: messageID, + Kind: KindMessage, + } +} + +// NewMessageWithContext creates a new Message with context information. +func NewMessageWithContext(role MessageRole, parts []Part, taskID, contextID *string) Message { + messageID := GenerateMessageID() + return Message{ + Role: role, + Parts: parts, + MessageID: messageID, + TaskID: taskID, + ContextID: contextID, + Kind: "message", + } +} + +// NewTextPart creates a new TextPart containing the given text. +func NewTextPart(text string) TextPart { + return TextPart{ + Kind: KindText, + Text: text, + } +} + +// NewFilePartWithBytes creates a new FilePart with embedded bytes content. +func NewFilePartWithBytes(name, mimeType string, bytes string) FilePart { + return FilePart{ + Kind: KindFile, + File: &FileWithBytes{ + Name: &name, + MimeType: &mimeType, + Bytes: bytes, + }, + } +} + +// NewFilePartWithURI creates a new FilePart with URI reference. +func NewFilePartWithURI(name, mimeType string, uri string) FilePart { + return FilePart{ + Kind: KindFile, + File: &FileWithURI{ + Name: &name, + MimeType: &mimeType, + URI: uri, + }, + } +} + +// NewDataPart creates a new DataPart with the given data. +func NewDataPart(data interface{}) DataPart { + return DataPart{ + Kind: KindData, + Data: data, + } +} + +// NewArtifactWithID creates a new Artifact with a generated ID. +func NewArtifactWithID(name, description *string, parts []Part) *Artifact { + artifactID := GenerateArtifactID() + return &Artifact{ + ArtifactID: artifactID, + Name: name, + Description: description, + Parts: parts, + Metadata: make(map[string]interface{}), + } +} + +// NewTaskStatusUpdateEvent creates a new TaskStatusUpdateEvent. +func NewTaskStatusUpdateEvent(taskID, contextID string, status TaskStatus, final bool) TaskStatusUpdateEvent { + return TaskStatusUpdateEvent{ + TaskID: taskID, + ContextID: contextID, + Kind: KindTaskStatusUpdate, + Status: status, + } +} + +// NewTaskArtifactUpdateEvent creates a new TaskArtifactUpdateEvent. +func NewTaskArtifactUpdateEvent(taskID, contextID string, artifact Artifact, final bool) TaskArtifactUpdateEvent { + return TaskArtifactUpdateEvent{ + TaskID: taskID, + ContextID: contextID, + Kind: KindTaskArtifactUpdate, + Artifact: artifact, + } } // SendTaskParams defines the parameters for the tasks_send and tasks_sendSubscribe RPC methods. // See A2A Spec section on RPC Methods. type SendTaskParams struct { + // RPCID is the ID of json-rpc. + RPCID string `json:"-"` // ID is the ID of the task. ID string `json:"id"` // SessionID is the optional session ID. @@ -412,6 +696,8 @@ type SendTaskParams struct { // TaskQueryParams defines the parameters for the tasks_get RPC method. // See A2A Spec section on RPC Methods. type TaskQueryParams struct { + // RPCID is the ID of json-rpc. + RPCID string `json:"-"` // ID is the ID of the task. ID string `json:"id"` // HistoryLength is the requested message history length. @@ -423,39 +709,168 @@ type TaskQueryParams struct { // TaskIDParams defines parameters for methods needing only a task ID (e.g., tasks_cancel). // See A2A Spec section on RPC Methods. type TaskIDParams struct { - // ID is the ID of the task. + // RPCID is the ID of json-rpc. + RPCID string `json:"-"` + // ID is task ID. ID string `json:"id"` // Metadata is the optional metadata. Metadata map[string]interface{} `json:"metadata,omitempty"` } -// --- Factory Functions --- +// SendMessageParams defines the parameters for the message/send and message/stream RPC methods. +// See A2A Spec 0.2.2 section on RPC Methods. +type SendMessageParams struct { + // RPCID is the ID of json-rpc. + RPCID string `json:"-"` + // Configuration contains optional sending configuration. + Configuration *SendMessageConfiguration `json:"configuration,omitempty"` + // Message is the message to send. + Message Message `json:"message"` + // Metadata is optional metadata. + Metadata map[string]interface{} `json:"metadata,omitempty"` +} -// NewTask creates a new Task with initial state (Submitted). -func NewTask(id string, sessionID *string) *Task { - return &Task{ - ID: id, - SessionID: sessionID, - Status: TaskStatus{ - State: TaskStateSubmitted, - Timestamp: time.Now().UTC().Format(time.RFC3339), - }, - Metadata: make(map[string]interface{}), +// SendMessageConfiguration defines optional configuration for message sending. +type SendMessageConfiguration struct { + // PushNotificationConfig contains optional push notification settings. + PushNotificationConfig *PushNotificationConfig `json:"pushNotificationConfig,omitempty"` + // HistoryLength is the requested history length in response. + HistoryLength *int `json:"historyLength,omitempty"` + // Blocking indicates whether to wait for task completion (message/send only). + Blocking *bool `json:"blocking,omitempty"` +} + +// MessageResult represents the union type response for Message/Task. +type MessageResult struct { + Result UnaryMessageResult +} + +// UnmarshalJSON implements custom unmarshalling logic for MessageResult +func (r *MessageResult) UnmarshalJSON(data []byte) error { + // First, detect the kind of the result + var kindOnly struct { + Kind string `json:"kind"` + } + if err := json.Unmarshal(data, &kindOnly); err != nil { + return fmt.Errorf("failed to unmarshal result kind: %w", err) + } + + // Now unmarshal to the correct concrete type based on kind + switch kindOnly.Kind { + case KindMessage: + r.Result = &Message{} + if err := json.Unmarshal(data, r.Result); err != nil { + return fmt.Errorf("failed to unmarshal message: %w", err) + } + return nil + case KindTask: + r.Result = &Task{} + if err := json.Unmarshal(data, r.Result); err != nil { + return fmt.Errorf("failed to unmarshal task: %w", err) + } + return nil + default: + return fmt.Errorf("unsupported result kind: %s", kindOnly.Kind) } } -// NewMessage creates a new Message with the specified role and parts. -func NewMessage(role MessageRole, parts []Part) Message { - return Message{ - Role: role, - Parts: parts, +// MarshalJSON implements custom marshalling logic for MessageResult +func (r *MessageResult) MarshalJSON() ([]byte, error) { + switch r.Result.GetKind() { + case KindMessage: + return json.Marshal(r.Result) + case KindTask: + return json.Marshal(r.Result) + default: + return nil, fmt.Errorf("unsupported result kind: %s", r.Result.GetKind()) } } -// NewTextPart creates a new TextPart containing the given text. -func NewTextPart(text string) TextPart { - return TextPart{ - Type: PartTypeText, - Text: text, +// SendStreamingMessageParams defines the parameters for the message/stream RPC method. +type SendStreamingMessageParams struct { + // ID is the ID of the message. + ID string `json:"-"` + // Configuration contains optional sending configuration. + Configuration *SendMessageConfiguration `json:"configuration,omitempty"` + // Message is the message to send. + Message Message `json:"message"` + // Metadata is optional metadata. + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// StreamingMessageEvent represents the result of a streaming message operation. +// StreamingMessageEvent is the union type of Message/Task/TaskStatusUpdate/TaskArtifactUpdate. +type StreamingMessageEvent struct { + // Result is the final result of the streaming operation. + Result Event +} + +// UnmarshalJSON implements custom unmarshalling logic for StreamingMessageEvent +func (r *StreamingMessageEvent) UnmarshalJSON(data []byte) error { + // First, try to detect if this is wrapped in a Result field + type StreamingMessageEventRaw struct { + Result json.RawMessage `json:"Result"` + } + + var raw StreamingMessageEventRaw + var actualData []byte + + // Try to unmarshal as wrapped structure first + if err := json.Unmarshal(data, &raw); err == nil && len(raw.Result) > 0 { + // It's wrapped, use the Result field + actualData = raw.Result + } else { + // It's not wrapped, use the data directly + actualData = data + } + + // Parse the actual data to get the kind + var kindOnly struct { + Kind string `json:"kind"` + } + if err := json.Unmarshal(actualData, &kindOnly); err != nil { + return fmt.Errorf("failed to unmarshal result kind: %w", err) + } + + // Now unmarshal to the correct concrete type based on kind + switch kindOnly.Kind { + case KindMessage: + r.Result = &Message{} + if err := json.Unmarshal(actualData, r.Result); err != nil { + return fmt.Errorf("failed to unmarshal message: %w", err) + } + return nil + case KindTask: + r.Result = &Task{} + if err := json.Unmarshal(actualData, r.Result); err != nil { + return fmt.Errorf("failed to unmarshal task: %w", err) + } + return nil + case KindTaskStatusUpdate: + r.Result = &TaskStatusUpdateEvent{} + if err := json.Unmarshal(actualData, r.Result); err != nil { + return fmt.Errorf("failed to unmarshal task status update event: %w", err) + } + return nil + case KindTaskArtifactUpdate: + r.Result = &TaskArtifactUpdateEvent{} + if err := json.Unmarshal(actualData, r.Result); err != nil { + return fmt.Errorf("failed to unmarshal task artifact update event: %w", err) + } + return nil + default: + return fmt.Errorf("unsupported result kind: %s", kindOnly.Kind) + } +} + +// MarshalJSON implements custom marshalling logic for StreamingMessageResult +func (r *StreamingMessageEvent) MarshalJSON() ([]byte, error) { + switch r.Result.GetKind() { + case KindMessage: + return json.Marshal(r.Result) + case KindTask: + return json.Marshal(r.Result) + default: + return nil, fmt.Errorf("unsupported result kind: %s", r.Result.GetKind()) } } diff --git a/protocol/types_test.go b/protocol/types_test.go index 5c9a424..d239836 100644 --- a/protocol/types_test.go +++ b/protocol/types_test.go @@ -8,6 +8,8 @@ package protocol import ( "encoding/json" + "fmt" + "strings" "testing" "time" @@ -20,21 +22,26 @@ func boolPtr(b bool) *bool { return &b } +// Helper function to create string pointers +func stringPtr(s string) *string { + return &s +} + // Test marshalling of concrete types func TestPartConcreteType_MarshalJSON(t *testing.T) { t.Run("TextPart", func(t *testing.T) { part := NewTextPart("Hello") jsonData, err := json.Marshal(part) require.NoError(t, err) - assert.JSONEq(t, `{"type":"text","text":"Hello"}`, string(jsonData)) + assert.JSONEq(t, `{"kind":"text","text":"Hello"}`, string(jsonData)) }) t.Run("DataPart", func(t *testing.T) { payload := map[string]int{"a": 1} - part := DataPart{Type: PartTypeData, Data: payload} + part := DataPart{Kind: KindData, Data: payload} jsonData, err := json.Marshal(part) require.NoError(t, err) - assert.JSONEq(t, `{"type":"data","data":{"a":1}}`, string(jsonData)) + assert.JSONEq(t, `{"kind":"data","data":{"a":1}}`, string(jsonData)) }) // FilePart marshalling test can be added if needed @@ -46,9 +53,9 @@ func TestPartIfaceUnmarshalJSONCont(t *testing.T) { jsonData := `{ "role": "user", "parts": [ - {"type":"text", "text":"part1"}, - {"type":"data", "data":{"val": true}}, - {"type":"text", "text":"part3"} + {"kind": "text", "text": "part1"}, + {"kind": "data", "data": {"val": true}}, + {"kind": "text", "text": "part3"} ] }` @@ -57,14 +64,14 @@ func TestPartIfaceUnmarshalJSONCont(t *testing.T) { require.NoError(t, err, "Unmarshal into Message should succeed via custom UnmarshalJSON") require.Len(t, msg.Parts, 3) - require.IsType(t, TextPart{}, msg.Parts[0]) - assert.Equal(t, "part1", msg.Parts[0].(TextPart).Text) + require.IsType(t, &TextPart{}, msg.Parts[0]) + assert.Equal(t, "part1", msg.Parts[0].(*TextPart).Text) - require.IsType(t, DataPart{}, msg.Parts[1]) - assert.Equal(t, map[string]interface{}{"val": true}, msg.Parts[1].(DataPart).Data) + require.IsType(t, &DataPart{}, msg.Parts[1]) + assert.Equal(t, map[string]interface{}{"val": true}, msg.Parts[1].(*DataPart).Data) - require.IsType(t, TextPart{}, msg.Parts[2]) - assert.Equal(t, "part3", msg.Parts[2].(TextPart).Text) + require.IsType(t, &TextPart{}, msg.Parts[2]) + assert.Equal(t, "part3", msg.Parts[2].(*TextPart).Text) } // Removed TestArtifactPart_MarshalUnmarshalJSON as ArtifactPart type doesn't exist @@ -73,7 +80,7 @@ func TestPartIfaceUnmarshalJSONCont(t *testing.T) { // Removed TestMessage_MarshalUnmarshalJSON as it's covered by TestPartIface_UnmarshalJSON_Container func TestTaskEvent_IsFinal(t *testing.T) { - // Use boolPtr helper for LastChunk + // Use boolPtr helper for pointer boolean values tests := []struct { name string event TaskEvent @@ -81,37 +88,47 @@ func TestTaskEvent_IsFinal(t *testing.T) { }{ { name: "StatusUpdate Submitted", - event: TaskStatusUpdateEvent{Final: false, Status: TaskStatus{State: TaskStateSubmitted}}, + event: &TaskStatusUpdateEvent{Final: boolPtr(false), Status: TaskStatus{State: TaskStateSubmitted}}, expected: false, }, { name: "StatusUpdate Working", - event: TaskStatusUpdateEvent{Final: false, Status: TaskStatus{State: TaskStateWorking}}, + event: &TaskStatusUpdateEvent{Final: boolPtr(false), Status: TaskStatus{State: TaskStateWorking}}, expected: false, }, { name: "StatusUpdate Completed", - event: TaskStatusUpdateEvent{Final: true, Status: TaskStatus{State: TaskStateCompleted}}, + event: &TaskStatusUpdateEvent{Final: boolPtr(true), Status: TaskStatus{State: TaskStateCompleted}}, expected: true, }, { name: "StatusUpdate Failed", - event: TaskStatusUpdateEvent{Final: true, Status: TaskStatus{State: TaskStateFailed}}, + event: &TaskStatusUpdateEvent{Final: boolPtr(true), Status: TaskStatus{State: TaskStateFailed}}, expected: true, }, { name: "StatusUpdate Canceled", - event: TaskStatusUpdateEvent{Final: true, Status: TaskStatus{State: TaskStateCanceled}}, + event: &TaskStatusUpdateEvent{Final: boolPtr(true), Status: TaskStatus{State: TaskStateCanceled}}, expected: true, }, { - name: "ArtifactUpdate Not Last Chunk", - event: TaskArtifactUpdateEvent{Final: false, Artifact: Artifact{LastChunk: boolPtr(false)}}, + name: "StatusUpdate Rejected", + event: &TaskStatusUpdateEvent{Final: boolPtr(true), Status: TaskStatus{State: TaskStateRejected}}, + expected: true, + }, + { + name: "StatusUpdate AuthRequired", + event: &TaskStatusUpdateEvent{Final: boolPtr(false), Status: TaskStatus{State: TaskStateAuthRequired}}, + expected: false, + }, + { + name: "ArtifactUpdate Not Final", + event: TaskArtifactUpdateEvent{}, expected: false, }, { - name: "ArtifactUpdate Last Chunk", - event: TaskArtifactUpdateEvent{Final: true, Artifact: Artifact{LastChunk: boolPtr(true)}}, + name: "ArtifactUpdate Final", + event: TaskArtifactUpdateEvent{LastChunk: boolPtr(true)}, expected: true, }, } @@ -134,6 +151,8 @@ func TestTaskState(t *testing.T) { {TaskStateCompleted, "completed"}, {TaskStateCanceled, "canceled"}, {TaskStateFailed, "failed"}, + {TaskStateRejected, "rejected"}, + {TaskStateAuthRequired, "auth-required"}, {TaskStateUnknown, "unknown"}, {TaskStateInputRequired, "input-required"}, } @@ -149,22 +168,19 @@ func TestTaskState(t *testing.T) { func TestMessageJSON(t *testing.T) { // Create text part textPart := TextPart{ - Type: PartTypeText, + Kind: "text", Text: "Hello, world!", } // Create file part - filePart := FilePart{ - Type: PartTypeFile, - File: FileContent{ - Name: stringPtr("example.txt"), - MimeType: stringPtr("text/plain"), - Bytes: stringPtr("RmlsZSBjb250ZW50"), // Base64 encoded "File content" - }, - } + filePart := NewFilePartWithBytes( + "example.txt", + "text/plain", + "File content", + ) // Create a message with these parts - original := NewMessage(MessageRoleUser, []Part{textPart, filePart}) + original := NewMessage(MessageRoleUser, []Part{&textPart, &filePart}) // Marshal the message data, err := json.Marshal(original) @@ -182,72 +198,58 @@ func TestMessageJSON(t *testing.T) { require.Len(t, decoded.Parts, 2) // Check the text part - textPartDecoded, ok := decoded.Parts[0].(TextPart) + textPartDecoded, ok := decoded.Parts[0].(*TextPart) require.True(t, ok, "First part should be a TextPart") - assert.Equal(t, PartTypeText, textPartDecoded.Type) + assert.Equal(t, "text", textPartDecoded.Kind) assert.Equal(t, "Hello, world!", textPartDecoded.Text) // Check the file part - filePartDecoded, ok := decoded.Parts[1].(FilePart) + filePartDecoded, ok := decoded.Parts[1].(*FilePart) require.True(t, ok, "Second part should be a FilePart") - assert.Equal(t, PartTypeFile, filePartDecoded.Type) - assert.Equal(t, "example.txt", *filePartDecoded.File.Name) - assert.Equal(t, "text/plain", *filePartDecoded.File.MimeType) - assert.Equal(t, "RmlsZSBjb250ZW50", *filePartDecoded.File.Bytes) + assert.Equal(t, KindFile, filePartDecoded.Kind) + fileWithBytes, ok := filePartDecoded.File.(*FileWithBytes) + assert.True(t, ok, "File part should be a FileWithBytes") + assert.Equal(t, "example.txt", *fileWithBytes.Name) + assert.Equal(t, "text/plain", *fileWithBytes.MimeType) + assert.Equal(t, "File content", fileWithBytes.Bytes) } // TestPartValidation tests validation of different part types func TestPartValidation(t *testing.T) { // Test TextPart t.Run("TextPart", func(t *testing.T) { - textPart := TextPart{ - Type: PartTypeText, - Text: "Valid text content", - } - assert.True(t, isValidPart(textPart)) + textPart := NewTextPart("Valid text content") + assert.True(t, isValidPart(&textPart)) // Invalid part type - invalidPart := TextPart{ - Type: "invalid", - Text: "Text with invalid part type", - } - assert.False(t, isValidPart(invalidPart)) + invalidPart := NewTextPart("") + assert.False(t, isValidPart(&invalidPart)) }) // Test FilePart t.Run("FilePart", func(t *testing.T) { - validFilePart := FilePart{ - Type: PartTypeFile, - File: FileContent{ - Name: stringPtr("file.txt"), - MimeType: stringPtr("text/plain"), - Bytes: stringPtr("SGVsbG8="), // Base64 "Hello" - }, - } - assert.True(t, isValidPart(validFilePart)) + validFilePart := NewFilePartWithBytes("file.txt", "text/plain", "SGVsbG8=") + assert.True(t, isValidPart(&validFilePart)) // Invalid part: missing required file info - invalidFilePart := FilePart{ - Type: PartTypeFile, - File: FileContent{}, // Empty file content - } - assert.False(t, isValidPart(invalidFilePart)) + invalidFilePart := NewFilePartWithBytes("file.txt", "text/plain", "") + assert.False(t, isValidPart(&invalidFilePart)) }) // Test DataPart t.Run("DataPart", func(t *testing.T) { validDataPart := DataPart{ - Type: PartTypeData, + Kind: KindData, Data: map[string]interface{}{"key": "value"}, } - assert.True(t, isValidPart(validDataPart)) + assert.True(t, isValidPart(&validDataPart)) // Invalid part: nil data invalidDataPart := DataPart{ - Type: PartTypeData, + Kind: KindData, Data: nil, } - assert.False(t, isValidPart(invalidDataPart)) + assert.False(t, isValidPart(&invalidDataPart)) }) } @@ -255,12 +257,18 @@ func TestPartValidation(t *testing.T) { // This is a simplified validation just for testing func isValidPart(part Part) bool { switch p := part.(type) { - case TextPart: - return p.Type == PartTypeText && p.Text != "" - case FilePart: - return p.Type == PartTypeFile && (p.File.Name != nil || p.File.URI != nil || p.File.Bytes != nil) - case DataPart: - return p.Type == PartTypeData && p.Data != nil + case *TextPart: + return p.Text != "" + case *FilePart: + if fileWithBytes, ok := p.File.(*FileWithBytes); ok { + return fileWithBytes.Bytes != "" + } + if fileWithURI, ok := p.File.(*FileWithURI); ok { + return fileWithURI.URI != "" + } + return false + case *DataPart: + return p.Data != nil default: return false } @@ -270,18 +278,16 @@ func isValidPart(part Part) bool { func TestArtifact(t *testing.T) { // Create a simple text part textPart := TextPart{ - Type: PartTypeText, + Kind: KindText, Text: "Artifact content", } - // Create an artifact - artifact := Artifact{ - Name: stringPtr("Test Artifact"), - Description: stringPtr("This is a test artifact"), - Parts: []Part{textPart}, - Index: 1, - LastChunk: boolPtr(true), - } + // Create an artifact with generated ID + artifact := NewArtifactWithID( + stringPtr("Test Artifact"), + stringPtr("This is a test artifact"), + []Part{&textPart}, + ) // Validate the artifact assert.NotNil(t, artifact.Name) @@ -289,9 +295,7 @@ func TestArtifact(t *testing.T) { assert.NotNil(t, artifact.Description) assert.Equal(t, "This is a test artifact", *artifact.Description) assert.Len(t, artifact.Parts, 1) - assert.Equal(t, 1, artifact.Index) - assert.NotNil(t, artifact.LastChunk) - assert.True(t, *artifact.LastChunk) + assert.NotEmpty(t, artifact.ArtifactID) // Test JSON marshaling data, err := json.Marshal(artifact) @@ -305,14 +309,13 @@ func TestArtifact(t *testing.T) { // Verify the decoded artifact assert.Equal(t, *artifact.Name, *decoded.Name) assert.Equal(t, *artifact.Description, *decoded.Description) - assert.Equal(t, artifact.Index, decoded.Index) - assert.Equal(t, *artifact.LastChunk, *decoded.LastChunk) + assert.Equal(t, artifact.ArtifactID, decoded.ArtifactID) // Check the part require.Len(t, decoded.Parts, 1) - decodedPart, ok := decoded.Parts[0].(TextPart) + decodedPart, ok := decoded.Parts[0].(*TextPart) require.True(t, ok, "Should decode as a TextPart") - assert.Equal(t, PartTypeText, decodedPart.Type) + assert.Equal(t, "text", decodedPart.Kind) assert.Equal(t, "Artifact content", decodedPart.Text) } @@ -331,29 +334,23 @@ func TestTaskStatus(t *testing.T) { assert.Equal(t, now, status.Timestamp) // Test adding a message to the status - message := NewMessage(MessageRoleAgent, []Part{ - TextPart{ - Type: PartTypeText, - Text: "Task completed successfully", - }, - }) + textPart := TextPart{ + Kind: "text", + Text: "Task completed successfully", + } + message := NewMessage(MessageRoleAgent, []Part{&textPart}) status.Message = &message assert.NotNil(t, status.Message) assert.Equal(t, MessageRoleAgent, status.Message.Role) } -// Helper functions for testing -func stringPtr(s string) *string { - return &s -} - // TestMarkerFunctions tests the marker functions for parts and events func TestMarkerFunctions(t *testing.T) { // Test part marker functions - textPart := TextPart{Type: PartTypeText, Text: "Test"} - filePart := FilePart{Type: PartTypeFile} - dataPart := DataPart{Type: PartTypeData, Data: map[string]string{"key": "value"}} + textPart := TextPart{Kind: "text", Text: "Test"} + filePart := FilePart{Kind: KindFile} + dataPart := DataPart{Kind: KindData, Data: map[string]string{"key": "value"}} // This test simply ensures the marker functions exist and don't panic // They're just marker methods with no behavior @@ -362,8 +359,8 @@ func TestMarkerFunctions(t *testing.T) { dataPart.partMarker() // Test event marker functions - statusEvent := TaskStatusUpdateEvent{ID: "test", Status: TaskStatus{State: TaskStateCompleted}} - artifactEvent := TaskArtifactUpdateEvent{ID: "test"} + statusEvent := TaskStatusUpdateEvent{TaskID: "test", Status: TaskStatus{State: TaskStateCompleted}} + artifactEvent := TaskArtifactUpdateEvent{TaskID: "test"} // This test simply ensures the marker functions exist and don't panic statusEvent.eventMarker() @@ -375,23 +372,199 @@ func TestMarkerFunctions(t *testing.T) { // TestNewTask tests the NewTask factory function func TestNewTask(t *testing.T) { - // Test with just ID + // Test with contextID taskID := "test-task" - task := NewTask(taskID, nil) + contextID := "test-context-123" + task := NewTask(taskID, contextID) assert.Equal(t, taskID, task.ID) - assert.Nil(t, task.SessionID) + assert.Equal(t, contextID, task.ContextID) assert.Equal(t, TaskStateSubmitted, task.Status.State) assert.NotEmpty(t, task.Status.Timestamp) assert.NotNil(t, task.Metadata) assert.Empty(t, task.Metadata) assert.Nil(t, task.Artifacts) +} - // Test with session ID - sessionID := "test-session" - task = NewTask(taskID, &sessionID) +// TestGenerateContextID tests the context ID generation function +func TestGenerateContextID(t *testing.T) { + contextID1 := GenerateContextID() + assert.NotEmpty(t, contextID1) + assert.True(t, strings.HasPrefix(contextID1, "ctx-")) + assert.Len(t, contextID1, 40) // "ctx-" + UUID (36 chars) = 40 chars - assert.Equal(t, taskID, task.ID) - assert.Equal(t, &sessionID, task.SessionID) - assert.Equal(t, *task.SessionID, sessionID) + contextID2 := GenerateContextID() + assert.NotEmpty(t, contextID2) + assert.NotEqual(t, contextID1, contextID2) // Should be unique +} + +// TestGenerateMessageID tests the message ID generation function +func TestGenerateMessageID(t *testing.T) { + messageID1 := GenerateMessageID() + assert.NotEmpty(t, messageID1) + assert.True(t, strings.HasPrefix(messageID1, "msg-")) + assert.Len(t, messageID1, 40) // "msg-" + UUID (36 chars) + + messageID2 := GenerateMessageID() + assert.NotEmpty(t, messageID2) + assert.NotEqual(t, messageID1, messageID2) // Should be unique +} + +// TestNewMessage tests the NewMessage factory functions +func TestNewMessage(t *testing.T) { + // Test basic NewMessage + textPart := NewTextPart("Hello") + parts := []Part{&textPart} + message := NewMessage(MessageRoleUser, parts) + + assert.Equal(t, MessageRoleUser, message.Role) + assert.Equal(t, parts, message.Parts) + assert.NotEmpty(t, message.MessageID) + assert.True(t, strings.HasPrefix(message.MessageID, "msg-")) + assert.Equal(t, "message", message.Kind) + assert.Nil(t, message.TaskID) + assert.Nil(t, message.ContextID) + + // Test NewMessageWithContext + taskID := "task-123" + contextID := "ctx-456" + messageWithContext := NewMessageWithContext(MessageRoleAgent, parts, &taskID, &contextID) + + assert.Equal(t, MessageRoleAgent, messageWithContext.Role) + assert.Equal(t, parts, messageWithContext.Parts) + assert.NotEmpty(t, messageWithContext.MessageID) + assert.Equal(t, "message", messageWithContext.Kind) + require.NotNil(t, messageWithContext.TaskID) + assert.Equal(t, taskID, *messageWithContext.TaskID) + require.NotNil(t, messageWithContext.ContextID) + assert.Equal(t, contextID, *messageWithContext.ContextID) +} + +func TestPartDeserialization(t *testing.T) { + // Test for TextPart + t.Run("TextPart", func(t *testing.T) { + jsonData := `{"kind":"text","text":"Hello","metadata":{"foo":"bar"}}` + parts, err := unmarshalPartsFromJSON([]byte(fmt.Sprintf("[%s]", jsonData))) + assert.NoError(t, err) + assert.Len(t, parts, 1) + textPartDecoded, ok := parts[0].(*TextPart) + assert.True(t, ok) + assert.Equal(t, "text", textPartDecoded.Kind) + assert.Equal(t, "Hello", textPartDecoded.Text) + assert.Equal(t, map[string]interface{}{"foo": "bar"}, textPartDecoded.Metadata) + }) + + // Test for FilePart + t.Run("FilePart", func(t *testing.T) { + jsonData := `{"kind":"file","file":{"name":"example.txt","mimeType":"text/plain","bytes":"SGVsbG8gV29ybGQ="}}` + parts, err := unmarshalPartsFromJSON([]byte(fmt.Sprintf("[%s]", jsonData))) + assert.NoError(t, err) + assert.Len(t, parts, 1) + + filePartDecoded, ok := parts[0].(*FilePart) + assert.True(t, ok) + assert.Equal(t, KindFile, filePartDecoded.Kind) + + // Access FileUnion properly by type assertion + fileWithBytes, ok := filePartDecoded.File.(*FileWithBytes) + assert.True(t, ok) + assert.Equal(t, "example.txt", *fileWithBytes.Name) + assert.Equal(t, "text/plain", *fileWithBytes.MimeType) + assert.Equal(t, "SGVsbG8gV29ybGQ=", fileWithBytes.Bytes) + }) + + // Test for DataPart + t.Run("DataPart", func(t *testing.T) { + jsonData := `{"kind":"data","data":{"key":"value","number":42}}` + parts, err := unmarshalPartsFromJSON([]byte(fmt.Sprintf("[%s]", jsonData))) + assert.NoError(t, err) + assert.Len(t, parts, 1) + + dataPartDecoded, ok := parts[0].(*DataPart) + assert.True(t, ok) + assert.Equal(t, KindData, dataPartDecoded.Kind) + + // Access the data as a map + dataMap, ok := dataPartDecoded.Data.(map[string]interface{}) + assert.True(t, ok) + assert.Equal(t, "value", dataMap["key"]) + assert.Equal(t, float64(42), dataMap["number"]) // JSON numbers are float64 + }) +} + +func TestMessage_MarshalJSON(t *testing.T) { + textPart := TextPart{ + Kind: "text", + Text: "Hello", + } + + message := Message{ + Role: MessageRoleUser, + Parts: []Part{&textPart}, + } + + jsonData, err := json.Marshal(message) + require.NoError(t, err) + + var decodedMessage Message + err = json.Unmarshal(jsonData, &decodedMessage) + require.NoError(t, err) + + // Verify that the unmarshaled Part is a concrete TextPart + require.Len(t, decodedMessage.Parts, 1) + textPartFromDecoded, ok := decodedMessage.Parts[0].(*TextPart) + require.True(t, ok) + assert.Equal(t, "text", textPartFromDecoded.Kind) + assert.Equal(t, "Hello", textPartFromDecoded.Text) +} + +func TestMessage_UnmarshalJSON(t *testing.T) { + jsonData := `{ + "role": "user", + "parts": [ + {"kind": "text", "text": "Hello"}, + {"kind": "file", "file": {"name": "test.txt", "mimeType": "text/plain", "bytes": "SGVsbG8="}} + ] +}` + + var message Message + err := json.Unmarshal([]byte(jsonData), &message) + require.NoError(t, err) + + assert.Equal(t, MessageRoleUser, message.Role) + require.Len(t, message.Parts, 2) + + // Check first part (TextPart) + textPart, ok := message.Parts[0].(*TextPart) + assert.True(t, ok) + assert.Equal(t, "text", textPart.Kind) + assert.Equal(t, "Hello", textPart.Text) + + // Check second part (FilePart) + filePart, ok := message.Parts[1].(*FilePart) + assert.True(t, ok) + assert.Equal(t, KindFile, filePart.Kind) + + // Access FileUnion properly by type assertion + fileWithBytes, ok := filePart.File.(*FileWithBytes) + assert.True(t, ok) + assert.Equal(t, "test.txt", *fileWithBytes.Name) +} + +// Helper function to unmarshal JSON array of parts +func unmarshalPartsFromJSON(data []byte) ([]Part, error) { + var rawParts []json.RawMessage + if err := json.Unmarshal(data, &rawParts); err != nil { + return nil, err + } + + parts := make([]Part, len(rawParts)) + for i, rawPart := range rawParts { + part, err := unmarshalPart(rawPart) + if err != nil { + return nil, err + } + parts[i] = part + } + return parts, nil } diff --git a/server/server.go b/server/server.go index 630068a..11858a5 100644 --- a/server/server.go +++ b/server/server.go @@ -25,6 +25,8 @@ import ( "trpc.group/trpc-go/trpc-a2a-go/taskmanager" ) +var errUnknownEvent = errors.New("unknown event type") + // A2AServer implements the HTTP server for the A2A protocol. // It handles agent card requests and routes JSON-RPC calls to the TaskManager. type A2AServer struct { @@ -142,7 +144,7 @@ func (s *A2AServer) Handler() http.Handler { // Corresponds to GET /.well-known/agent.json in A2A Spec. func (s *A2AServer) handleAgentCard(w http.ResponseWriter, r *http.Request) { if s.corsEnabled { - s.setCORSHeaders(w) + setCORSHeaders(w) } if r.Method != http.MethodGet { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) @@ -161,7 +163,7 @@ func (s *A2AServer) handleAgentCard(w http.ResponseWriter, r *http.Request) { func (s *A2AServer) handleJSONRPC(w http.ResponseWriter, r *http.Request) { // --- CORS Handling --- if s.corsEnabled { - s.setCORSHeaders(w) + setCORSHeaders(w) // Handle browser preflight requests. if r.Method == http.MethodOptions { w.WriteHeader(http.StatusOK) @@ -246,20 +248,31 @@ func (s *A2AServer) routeJSONRPCMethod(ctx context.Context, w http.ResponseWrite log.Infof("Received JSON-RPC request (ID: %v, Method: %s)", request.ID, request.Method) switch request.Method { - case protocol.MethodTasksSend: // A2A Spec: tasks/send - s.handleTasksSend(ctx, w, request) - case protocol.MethodTasksSendSubscribe: // A2A Spec: tasks/sendSubscribe - s.handleTasksSendSubscribe(ctx, w, request) + + case protocol.MethodMessageSend: // A2A Spec: message/send + s.handleMessageSend(ctx, w, request) + case protocol.MethodMessageStream: // A2A Spec: message/stream + s.handleMessageStream(ctx, w, request) + case protocol.MethodTasksPushNotificationConfigGet: // A2A Spec: tasks/pushNotification/config/get + s.handleTasksPushNotificationGet(ctx, w, request) + case protocol.MethodTasksPushNotificationConfigSet: // A2A Spec: tasks/pushNotification/config/set + s.handleTasksPushNotificationSet(ctx, w, request) case protocol.MethodTasksGet: // A2A Spec: tasks/get s.handleTasksGet(ctx, w, request) case protocol.MethodTasksCancel: // A2A Spec: tasks/cancel s.handleTasksCancel(ctx, w, request) - case protocol.MethodTasksPushNotificationSet: // A2A Spec: tasks/pushNotification/set - s.handleTasksPushNotificationSet(ctx, w, request) - case protocol.MethodTasksPushNotificationGet: // A2A Spec: tasks/pushNotification/get - s.handleTasksPushNotificationGet(ctx, w, request) case protocol.MethodTasksResubscribe: // A2A Spec: tasks/resubscribe s.handleTasksResubscribe(ctx, w, request) + + // deprecated methods: + case protocol.MethodTasksSend: // A2A Spec: message/send + s.handleTasksSend(ctx, w, request) + case protocol.MethodTasksSendSubscribe: // A2A Spec: message/sendSubscribe + s.handleTasksSendSubscribe(ctx, w, request) + case protocol.MethodTasksPushNotificationGet: // A2A Spec: tasks/pushNotification/get + s.handleTasksPushNotificationGet(ctx, w, request) + case protocol.MethodTasksPushNotificationSet: // A2A Spec: tasks/pushNotification/config/set + s.handleTasksPushNotificationSet(ctx, w, request) default: log.Warnf("Method not found: %s (Request ID: %v)", request.Method, request.ID) s.writeJSONRPCError(w, request.ID, @@ -283,6 +296,17 @@ func (s *A2AServer) handleTasksSend(ctx context.Context, w http.ResponseWriter, s.writeJSONRPCError(w, request.ID, err) return } + + // Validate required fields + if params.ID == "" { + s.writeJSONRPCError(w, request.ID, jsonrpc.ErrInvalidParams("task ID is required")) + return + } + if params.Message.Role == "" || len(params.Message.Parts) == 0 { + s.writeJSONRPCError(w, request.ID, jsonrpc.ErrInvalidParams("message with at least one part is required")) + return + } + // Delegate to the task manager. task, err := s.taskManager.OnSendTask(ctx, params) if err != nil { @@ -346,89 +370,6 @@ func (s *A2AServer) handleTasksCancel(ctx context.Context, w http.ResponseWriter s.writeJSONRPCResponse(w, request.ID, task) } -// handleSSEStream handles an SSE stream for a task, including setup and event forwarding. -// It sets the appropriate headers, logs connection status, and forwards events to the client. -func (s *A2AServer) handleSSEStream( - ctx context.Context, - w http.ResponseWriter, - flusher http.Flusher, - eventsChan <-chan protocol.TaskEvent, - taskID string, - requestID interface{}, - isResubscribe bool, -) { - // Set headers for SSE. - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - if s.corsEnabled { - s.setCORSHeaders(w) - } - - // Indicate successful subscription setup. - w.WriteHeader(http.StatusOK) - flusher.Flush() // Send headers immediately. - - // Log appropriate message based on whether this is a new subscription or resubscribe - if isResubscribe { - log.Infof("SSE stream reopened for task %s (Request ID: %v)", taskID, requestID) - } else { - log.Infof("SSE stream opened for task %s (Request ID: %v)", taskID, requestID) - } - - // Use request context to detect client disconnection. - clientClosed := ctx.Done() - - // --- Event Forwarding Loop --- - for { - select { - case event, ok := <-eventsChan: - if !ok { - // Channel closed by task manager (task finished or error). - log.Infof("SSE stream closing for task %s (event channel closed by manager)", taskID) - // Send a final SSE event indicating closure. - closeData := sse.CloseEventData{ - TaskID: taskID, - Reason: "task ended", - } - // Use JSON-RPC format for the close event - if err := sse.FormatJSONRPCEvent(w, protocol.EventClose, requestID, closeData); err != nil { - log.Errorf("Error writing SSE JSON-RPC close event for task %s: %v", taskID, err) - } else { - flusher.Flush() - } - return // End the handler. - } - - // Determine event type string for SSE. - var eventType string - switch event.(type) { - case protocol.TaskStatusUpdateEvent: - eventType = protocol.EventTaskStatusUpdate - case protocol.TaskArtifactUpdateEvent: - eventType = protocol.EventTaskArtifactUpdate - default: - log.Warnf("Unknown event type received for task %s: %T. Skipping.", taskID, event) - continue // Skip unknown event types - } - - // Write the event to the SSE stream using JSON-RPC format. - if err := sse.FormatJSONRPCEvent(w, eventType, requestID, event); err != nil { - // Error writing, likely client disconnected. - log.Errorf("Error writing SSE JSON-RPC event for task %s (client likely disconnected): %v. "+ - "Closing stream.", taskID, err) - return // Exit the handler. - } - // Flush the buffer to ensure the event is sent immediately. - flusher.Flush() - case <-clientClosed: - // Client disconnected (request context canceled). - log.Infof("SSE client disconnected for task %s (Request ID: %v). Closing stream.", taskID, requestID) - return // Exit the handler. - } - } -} - // handleTasksSendSubscribe handles the tasks_sendSubscribe method using Server-Sent Events (SSE). func (s *A2AServer) handleTasksSendSubscribe(ctx context.Context, w http.ResponseWriter, request jsonrpc.Request) { var params protocol.SendTaskParams @@ -470,7 +411,7 @@ func (s *A2AServer) handleTasksSendSubscribe(ctx context.Context, w http.Respons } // Use the helper function to handle the SSE stream - s.handleSSEStream(ctx, w, flusher, eventsChan, params.ID, request.ID, false) + handleSSEStream(ctx, s.corsEnabled, w, flusher, eventsChan, request.ID.(string), false) } // writeJSONRPCResponse encodes and writes a successful JSON-RPC response. @@ -516,7 +457,7 @@ func (s *A2AServer) writeJSONRPCError(w http.ResponseWriter, id interface{}, err // setCORSHeaders adds permissive CORS headers for development/testing. // WARNING: This is insecure for production. Configure origins explicitly. -func (s *A2AServer) setCORSHeaders(w http.ResponseWriter) { +func setCORSHeaders(w http.ResponseWriter) { w.Header().Set("Access-Control-Allow-Origin", "*") // INSECURE w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") @@ -534,7 +475,7 @@ func (s *A2AServer) handleTasksPushNotificationSet( return } // Validate required fields. - if params.ID == "" { + if params.TaskID == "" { s.writeJSONRPCError(w, request.ID, jsonrpc.ErrInvalidParams("task ID is required")) return } @@ -579,7 +520,7 @@ func (s *A2AServer) handleTasksPushNotificationSet( // Delegate to the task manager. result, err := s.taskManager.OnPushNotificationSet(ctx, params) if err != nil { - log.Errorf("Error calling OnPushNotificationSet for task %s: %v", params.ID, err) + log.Errorf("Error calling OnPushNotificationSet for task %s: %v", params.TaskID, err) // Check if the error is already a JSONRPCError. if rpcErr, ok := err.(*jsonrpc.Error); ok { s.writeJSONRPCError(w, request.ID, rpcErr) @@ -684,5 +625,169 @@ func (s *A2AServer) handleTasksResubscribe(ctx context.Context, w http.ResponseW } // Use the helper function to handle the SSE stream - s.handleSSEStream(ctx, w, flusher, eventsChan, params.ID, request.ID, true) + handleSSEStream[protocol.StreamingMessageEvent](ctx, s.corsEnabled, w, flusher, eventsChan, request.ID.(string), true) +} + +// handleMessageSend handles the message_send method. +func (s *A2AServer) handleMessageSend(ctx context.Context, w http.ResponseWriter, request jsonrpc.Request) { + var params protocol.SendMessageParams + if err := s.unmarshalParams(request.Params, ¶ms); err != nil { + s.writeJSONRPCError(w, request.ID, err) + return + } + // Delegate to the task manager. + message, err := s.taskManager.OnSendMessage(ctx, params) + if err != nil { + log.Errorf("Error calling OnSendMessage for message %s: %v", params.RPCID, err) + // Check if it's already a JSON-RPC error + if rpcErr, ok := err.(*jsonrpc.Error); ok { + s.writeJSONRPCError(w, request.ID, rpcErr) + } else { + // Otherwise, wrap as internal error + s.writeJSONRPCError(w, request.ID, + jsonrpc.ErrInternalError(fmt.Sprintf("message processing failed: %v", err))) + } + return + } + s.writeJSONRPCResponse(w, request.ID, message) +} + +// handleMessageStream handles the message_stream method using Server-Sent Events (SSE). +func (s *A2AServer) handleMessageStream(ctx context.Context, w http.ResponseWriter, request jsonrpc.Request) { + var params protocol.SendMessageParams + if err := s.unmarshalParams(request.Params, ¶ms); err != nil { + s.writeJSONRPCError(w, request.ID, err) + return + } + + if params.Message.Role == "" || len(params.Message.Parts) == 0 { + s.writeJSONRPCError(w, request.ID, jsonrpc.ErrInvalidParams("message with at least one part is required")) + return + } + + // Check if client supports SSE. + flusher, ok := w.(http.Flusher) + if !ok { + log.Error("Streaming is not supported by the underlying http responseWriter") + s.writeJSONRPCError(w, request.ID, jsonrpc.ErrInternalError("server does not support streaming")) + return + } + + // Get the event channel from the task manager. + eventsChan, err := s.taskManager.OnSendMessageStream(ctx, params) + if err != nil { + log.Errorf("Error calling OnSendMessageStream for message %s: %v", params.RPCID, err) + s.writeJSONRPCError(w, request.ID, + jsonrpc.ErrInternalError(fmt.Sprintf("failed to subscribe to message events: %v", err))) + return + } + + // Use the helper function to handle the SSE stream + handleSSEStream(ctx, s.corsEnabled, w, flusher, eventsChan, request.ID.(string), false) +} + +// handleSSEStream handles an SSE stream for a task, including setup and event forwarding. +// It sets the appropriate headers, logs connection status, and forwards events to the client. +func handleSSEStream[T interface{}]( + ctx context.Context, + corsEnabled bool, + w http.ResponseWriter, + flusher http.Flusher, + eventsChan <-chan T, + rpcID string, + isResubscribe bool, +) { + // Set headers for SSE. + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + if corsEnabled { + setCORSHeaders(w) + } + + // Indicate successful subscription setup. + w.WriteHeader(http.StatusOK) + flusher.Flush() // Send headers immediately. + + // Log appropriate message based on whether this is a new subscription or resubscribe + if isResubscribe { + log.Infof("SSE stream reopened for request ID: %v)", rpcID) + } else { + log.Infof("SSE stream opened for request ID: %v)", rpcID) + } + + // Use request context to detect client disconnection. + clientClosed := ctx.Done() + + // --- Event Forwarding Loop --- + for { + select { + case event, ok := <-eventsChan: + if !ok { + // Channel closed by task manager (task finished or error). + log.Infof("SSE stream closing request ID: %s", rpcID) + // Send a final SSE event indicating closure. + closeData := sse.CloseEventData{ + ID: rpcID, + Reason: "task ended", + } + // Use JSON-RPC format for the close event + if err := sse.FormatJSONRPCEvent(w, protocol.EventClose, rpcID, closeData); err != nil { + log.Errorf("Error writing SSE JSON-RPC close event for request ID: %s: %v", rpcID, err) + } else { + flusher.Flush() + } + return // End the handler. + } + if err := sendSSEEvent(w, rpcID, flusher, event); err != nil { + if err == errUnknownEvent { + log.Warnf("Unknown event type received for request ID: %s: %T. Skipping.", rpcID, event) + continue + } + log.Errorf("Error writing SSE JSON-RPC event for request ID: %s (client likely disconnected): %v", rpcID, err) + return + } + // Flush the buffer to ensure the event is sent immediately. + flusher.Flush() + case <-clientClosed: + // Client disconnected (request context canceled). + log.Infof("SSE client disconnected for request ID: %s. Closing stream.", rpcID) + return // Exit the handler. + } + } +} + +func sendSSEEvent(w http.ResponseWriter, rpcID string, flusher http.Flusher, event interface{}) error { + // Determine event type string for SSE. + var eventType string + var actualEvent protocol.Event + + // Handle StreamingMessageEvent by extracting the inner Result + if streamEvent, ok := event.(protocol.StreamingMessageEvent); ok { + actualEvent = streamEvent.Result + } else if directEvent, ok := event.(protocol.Event); ok { + actualEvent = directEvent + } else { + return errUnknownEvent + } + + switch actualEvent.(type) { + case *protocol.TaskStatusUpdateEvent: + eventType = protocol.EventStatusUpdate + case *protocol.TaskArtifactUpdateEvent: + eventType = protocol.EventArtifactUpdate + case *protocol.Message: + eventType = protocol.EventMessage + case *protocol.Task: + eventType = protocol.EventTask + default: + return errUnknownEvent + } + + // For StreamMessage API, we need to send the event wrapped in StreamingMessageEvent + // Write the event to the SSE stream using JSON-RPC format. + if err := sse.FormatJSONRPCEvent(w, eventType, rpcID, event); err != nil { + return err + } + return nil } diff --git a/server/server_handlers_test.go b/server/server_handlers_test.go index 49f631e..bcc8fb4 100644 --- a/server/server_handlers_test.go +++ b/server/server_handlers_test.go @@ -33,10 +33,6 @@ func stringPtr(s string) *string { return &s } -func boolPtr(b bool) *bool { - return &b -} - // setupTestServer creates a test server with the given task manager and options. // Returns the test server and the A2A server for use in tests. func setupTestServer(t *testing.T, tm taskmanager.TaskManager, opts ...Option) (*httptest.Server, *A2AServer) { @@ -131,7 +127,7 @@ func testJSONRPCErrorResponse(t *testing.T, server *httptest.Server, method stri } // verifyPushNotificationConfig verifies a push notification configuration response -func verifyPushNotificationConfig(t *testing.T, resp *jsonrpc.Response, expectedID, expectedURL string) { +func verifyPushNotificationConfig(t *testing.T, resp *jsonrpc.Response, expectedTaskID, expectedURL string) { t.Helper() // Verify response basics @@ -146,7 +142,7 @@ func verifyPushNotificationConfig(t *testing.T, resp *jsonrpc.Response, expected err = json.Unmarshal(resultBytes, &config) require.NoError(t, err) - assert.Equal(t, expectedID, config.ID) + assert.Equal(t, expectedTaskID, config.TaskID) assert.Equal(t, expectedURL, config.PushNotificationConfig.URL) } @@ -308,7 +304,7 @@ func TestA2AServer_PushNotifications(t *testing.T) { t.Run("PushNotification_Set", func(t *testing.T) { // Configure mock task manager mockTM.pushNotificationSetResponse = &protocol.TaskPushNotificationConfig{ - ID: "test-push-task", + TaskID: "test-push-task", PushNotificationConfig: protocol.PushNotificationConfig{ URL: "https://example.com/webhook", }, @@ -317,7 +313,7 @@ func TestA2AServer_PushNotifications(t *testing.T) { // Create request params := protocol.TaskPushNotificationConfig{ - ID: "test-push-task", + TaskID: "test-push-task", PushNotificationConfig: protocol.PushNotificationConfig{ URL: "https://example.com/webhook", }, @@ -339,7 +335,7 @@ func TestA2AServer_PushNotifications(t *testing.T) { t.Run("PushNotification_Get", func(t *testing.T) { // Configure mock task manager mockTM.pushNotificationGetResponse = &protocol.TaskPushNotificationConfig{ - ID: "test-push-task", + TaskID: "test-push-task", PushNotificationConfig: protocol.PushNotificationConfig{ URL: "https://example.com/webhook", }, @@ -400,16 +396,23 @@ func TestA2AServer_Resubscribe(t *testing.T) { t.Run("Resubscribe_Success", func(t *testing.T) { // Configure mock events workingEvent := protocol.TaskStatusUpdateEvent{ - ID: "resubscribe-task", - Status: protocol.TaskStatus{State: protocol.TaskStateWorking}, - Final: false, + TaskID: "resubscribe-task", + ContextID: "test-context", + Kind: protocol.KindTaskStatusUpdate, + Status: protocol.TaskStatus{State: protocol.TaskStateWorking}, } + finalPtr := true completedEvent := protocol.TaskStatusUpdateEvent{ - ID: "resubscribe-task", - Status: protocol.TaskStatus{State: protocol.TaskStateCompleted}, - Final: true, + TaskID: "resubscribe-task", + ContextID: "test-context", + Kind: protocol.KindTaskStatusUpdate, + Status: protocol.TaskStatus{State: protocol.TaskStateCompleted}, + Final: &finalPtr, + } + mockTM.SubscribeEvents = []protocol.StreamingMessageEvent{ + {Result: &workingEvent}, + {Result: &completedEvent}, } - mockTM.SubscribeEvents = []protocol.TaskEvent{workingEvent, completedEvent} mockTM.SubscribeError = nil // Add task to mock task manager to ensure it exists @@ -529,23 +532,28 @@ func TestA2AServer_HandleTasksSendSubscribe(t *testing.T) { t.Run("Successful Subscription", func(t *testing.T) { // Create SSE event to be sent statusUpdate := protocol.TaskStatusUpdateEvent{ - ID: "test-sub-task", - Status: protocol.TaskStatus{State: protocol.TaskStateWorking}, - Final: false, + TaskID: "test-sub-task", + ContextID: "test-context", + Kind: protocol.KindTaskStatusUpdate, + Status: protocol.TaskStatus{State: protocol.TaskStateWorking}, } artifactUpdate := protocol.TaskArtifactUpdateEvent{ - ID: "test-sub-task", + TaskID: "test-sub-task", + ContextID: "test-context", + Kind: protocol.KindTaskArtifactUpdate, Artifact: protocol.Artifact{ - Name: stringPtr("test-artifact"), - LastChunk: boolPtr(true), - Parts: []protocol.Part{protocol.NewTextPart("Artifact content")}, + ArtifactID: "test-artifact-1", + Name: stringPtr("test-artifact"), + Parts: []protocol.Part{protocol.NewTextPart("Artifact content")}, }, - Final: false, } + finalPtr := true finalUpdate := protocol.TaskStatusUpdateEvent{ - ID: "test-sub-task", - Status: protocol.TaskStatus{State: protocol.TaskStateCompleted}, - Final: true, + TaskID: "test-sub-task", + ContextID: "test-context", + Kind: protocol.KindTaskStatusUpdate, + Status: protocol.TaskStatus{State: protocol.TaskStateCompleted}, + Final: &finalPtr, } params := protocol.SendTaskParams{ @@ -592,10 +600,10 @@ func TestA2AServer_HandleTasksSendSubscribe(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) - // Send test events through channel - events <- statusUpdate - events <- artifactUpdate - events <- finalUpdate + // Send test events through channel - use pointers for all events + events <- &statusUpdate + events <- &artifactUpdate // Fix: use pointer + events <- &finalUpdate // Read the SSE events reader := bufio.NewReader(resp.Body) diff --git a/server/server_test.go b/server/server_test.go index b3dda6a..2f09d5c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -33,13 +33,14 @@ import ( func defaultAgentCard() AgentCard { // Corrected based on types.go definition desc := "Agent used for server testing." + streaming := true return AgentCard{ Name: "Test Agent", - Description: &desc, + Description: desc, URL: "http://localhost/test-agent", // Example URL Version: "test-agent-v0.1.0", Capabilities: AgentCapabilities{ - Streaming: true, + Streaming: &streaming, }, DefaultInputModes: []string{"text"}, DefaultOutputModes: []string{"text", "artifact"}, @@ -263,22 +264,28 @@ func TestA2ASrv_HandleTasksSendSub_SSE(t *testing.T) { // Configure mock events event1 := protocol.TaskStatusUpdateEvent{ - ID: taskID, + TaskID: taskID, Status: protocol.TaskStatus{State: protocol.TaskStateWorking}, } event2 := protocol.TaskArtifactUpdateEvent{ - ID: taskID, + TaskID: taskID, Artifact: protocol.Artifact{ - Index: 0, - Parts: []protocol.Part{protocol.NewTextPart("Intermediate result")}, + ArtifactID: "test-artifact-1", + Parts: []protocol.Part{protocol.NewTextPart("Intermediate result")}, }, } + final := true event3 := protocol.TaskStatusUpdateEvent{ - ID: taskID, + TaskID: taskID, Status: protocol.TaskStatus{State: protocol.TaskStateCompleted}, - Final: true, + Final: &final, + } + // Wrap events in StreamingMessageEvent + mockTM.SubscribeEvents = []protocol.StreamingMessageEvent{ + {Result: &event1}, + {Result: &event2}, + {Result: &event3}, } - mockTM.SubscribeEvents = []protocol.TaskEvent{event1, event2, event3} mockTM.SubscribeError = nil // Prepare SSE request @@ -309,7 +316,7 @@ func TestA2ASrv_HandleTasksSendSub_SSE(t *testing.T) { // Read and verify SSE events reader := sse.NewEventReader(resp.Body) // Use the client's SSE reader - receivedEvents := []protocol.TaskEvent{} + receivedEvents := []protocol.StreamingMessageEvent{} ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -340,30 +347,30 @@ func TestA2ASrv_HandleTasksSendSub_SSE(t *testing.T) { } eventBytes := jsonRPCResponse.Result - var event protocol.TaskEvent + var event *protocol.StreamingMessageEvent switch eventType { case "task_status_update": var statusEvent protocol.TaskStatusUpdateEvent if err := json.Unmarshal(eventBytes, &statusEvent); err != nil { t.Fatalf("Failed to unmarshal task_status_update: %v. Data: %s", err, string(eventBytes)) } - event = statusEvent + event = &protocol.StreamingMessageEvent{Result: &statusEvent} case "task_artifact_update": var artifactEvent protocol.TaskArtifactUpdateEvent if err := json.Unmarshal(eventBytes, &artifactEvent); err != nil { t.Fatalf("Failed to unmarshal task_artifact_update: %v. Data: %s", err, string(eventBytes)) } - event = artifactEvent + event = &protocol.StreamingMessageEvent{Result: &artifactEvent} case "close": // Handle potential close event t.Logf("Received close event: %s", string(data)) - break + return default: t.Logf("Skipping unknown event type: %s", eventType) continue } if event != nil { - receivedEvents = append(receivedEvents, event) + receivedEvents = append(receivedEvents, *event) } // Check context cancellation (e.g., test timeout) @@ -374,12 +381,12 @@ func TestA2ASrv_HandleTasksSendSub_SSE(t *testing.T) { require.Greater(t, len(receivedEvents), 0, "Should have received at least one event") var lastStatusEvent protocol.TaskStatusUpdateEvent for i := len(receivedEvents) - 1; i >= 0; i-- { - if statusEvent, ok := receivedEvents[i].(protocol.TaskStatusUpdateEvent); ok { - lastStatusEvent = statusEvent + if statusEvent, ok := receivedEvents[i].Result.(*protocol.TaskStatusUpdateEvent); ok { + lastStatusEvent = *statusEvent break } } - require.NotEmpty(t, lastStatusEvent.ID, "Should have received at least one status update event") + require.NotEmpty(t, lastStatusEvent.TaskID, "Should have received at least one status update event") assert.Equal(t, protocol.TaskStateCompleted, lastStatusEvent.Status.State, "State of last status event should be 'completed'") } @@ -388,6 +395,8 @@ func getCurrentTimestamp() string { return time.Now().UTC().Format(time.RFC3339) } +var _ taskmanager.TaskManager = (*mockTaskManager)(nil) + // mockTaskManager implements the taskmanager.TaskManager interface for testing. type mockTaskManager struct { mu sync.Mutex @@ -401,10 +410,10 @@ type mockTaskManager struct { GetError error CancelResponse *protocol.Task CancelError error - SubscribeEvents []protocol.TaskEvent // Events to send for subscription + SubscribeEvents []protocol.StreamingMessageEvent // Updated to use StreamingMessageEvent SubscribeError error - // Additional fields for tests + // Additional fields for tests (deprecated) SendTaskSubscribeStream chan protocol.TaskEvent SendTaskSubscribeError error @@ -413,6 +422,12 @@ type mockTaskManager struct { pushNotificationSetError error pushNotificationGetResponse *protocol.TaskPushNotificationConfig pushNotificationGetError error + + // New message handling fields + sendMessageResponse *protocol.MessageResult + sendMessageError error + sendMessageStreamEvents []protocol.StreamingMessageEvent + sendMessageStreamError error } // newMockTaskManager creates a new MockTaskManager for testing. @@ -422,7 +437,76 @@ func newMockTaskManager() *mockTaskManager { } } -// OnSendTask implements the TaskManager interface. +// OnSendMessage implements the TaskManager interface. +func (m *mockTaskManager) OnSendMessage( + ctx context.Context, + request protocol.SendMessageParams, +) (*protocol.MessageResult, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.sendMessageError != nil { + return nil, m.sendMessageError + } + + if m.sendMessageResponse != nil { + return m.sendMessageResponse, nil + } + + // Default behavior: create a simple message response + return &protocol.MessageResult{ + Result: &request.Message, + }, nil +} + +// OnSendMessageStream implements the TaskManager interface. +func (m *mockTaskManager) OnSendMessageStream( + ctx context.Context, + request protocol.SendMessageParams, +) (<-chan protocol.StreamingMessageEvent, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.sendMessageStreamError != nil { + return nil, m.sendMessageStreamError + } + + // Create a channel and send events + eventCh := make(chan protocol.StreamingMessageEvent, len(m.sendMessageStreamEvents)+1) + + // Send configured events in background + if len(m.sendMessageStreamEvents) > 0 { + go func() { + defer close(eventCh) + for _, event := range m.sendMessageStreamEvents { + select { + case <-ctx.Done(): + return + case eventCh <- event: + // Continue sending events + } + } + }() + } else { + // Default behavior: send the message back as a streaming event + go func() { + defer close(eventCh) + event := protocol.StreamingMessageEvent{ + Result: &request.Message, + } + select { + case <-ctx.Done(): + return + case eventCh <- event: + // Event sent + } + }() + } + + return eventCh, nil +} + +// OnSendTask implements the TaskManager interface (deprecated). func (m *mockTaskManager) OnSendTask( ctx context.Context, params protocol.SendTaskParams, @@ -430,33 +514,23 @@ func (m *mockTaskManager) OnSendTask( m.mu.Lock() defer m.mu.Unlock() - // Return configured error if set if m.SendError != nil { return nil, m.SendError } - // Validate required fields - if params.ID == "" { - return nil, jsonrpc.ErrInvalidParams("task ID is required") - } - - if len(params.Message.Parts) == 0 { - return nil, jsonrpc.ErrInvalidParams("message must have at least one part") - } - - // Return configured response if set if m.SendResponse != nil { - // Store for later retrieval - m.tasks[m.SendResponse.ID] = m.SendResponse return m.SendResponse, nil } - // Default behavior: create a simple task - task := protocol.NewTask(params.ID, params.SessionID) - now := getCurrentTimestamp() + // Create a new task + var contextID string + if params.Message.ContextID != nil { + contextID = *params.Message.ContextID + } + task := protocol.NewTask(params.ID, contextID) task.Status = protocol.TaskStatus{ State: protocol.TaskStateSubmitted, - Timestamp: now, + Timestamp: getCurrentTimestamp(), } // Store for later retrieval @@ -514,7 +588,7 @@ func (m *mockTaskManager) OnCancelTask( return task, nil } -// OnSendTaskSubscribe implements the TaskManager interface. +// OnSendTaskSubscribe implements the TaskManager interface (deprecated). func (m *mockTaskManager) OnSendTaskSubscribe( ctx context.Context, params protocol.SendTaskParams, ) (<-chan protocol.TaskEvent, error) { @@ -536,7 +610,11 @@ func (m *mockTaskManager) OnSendTaskSubscribe( } // Create a task like OnSendTask would - task := protocol.NewTask(params.ID, params.SessionID) + var contextID string + if params.Message.ContextID != nil { + contextID = *params.Message.ContextID + } + task := protocol.NewTask(params.ID, contextID) task.Status = protocol.TaskStatus{ State: protocol.TaskStateSubmitted, Timestamp: getCurrentTimestamp(), @@ -551,59 +629,75 @@ func (m *mockTaskManager) OnSendTaskSubscribe( // Send configured events in background if len(m.SubscribeEvents) > 0 { go func() { - for _, event := range m.SubscribeEvents { + defer close(eventCh) + for _, streamEvent := range m.SubscribeEvents { + // Convert StreamingMessageEvent to TaskEvent for backward compatibility + var taskEvent protocol.TaskEvent + switch e := streamEvent.Result.(type) { + case *protocol.TaskStatusUpdateEvent: + taskEvent = e + case *protocol.TaskArtifactUpdateEvent: + taskEvent = e + default: + // Skip unsupported event types + continue + } + select { case <-ctx.Done(): - close(eventCh) return - case eventCh <- event: + case eventCh <- taskEvent: // If this is the final event, close the channel - if event.IsFinal() { - close(eventCh) + if taskEvent.IsFinal() { return } } } - // If we didn't have a final event, close the channel anyway - close(eventCh) }() } else { // No events configured, send a default working and completed status go func() { + defer close(eventCh) + // Working status + var eventContextID string + if params.Message.ContextID != nil { + eventContextID = *params.Message.ContextID + } workingEvent := protocol.TaskStatusUpdateEvent{ - ID: params.ID, + TaskID: params.ID, + ContextID: eventContextID, + Kind: protocol.KindTaskStatusUpdate, Status: protocol.TaskStatus{ State: protocol.TaskStateWorking, Timestamp: getCurrentTimestamp(), }, - Final: false, } // Completed status + final := true completedEvent := protocol.TaskStatusUpdateEvent{ - ID: params.ID, + TaskID: params.ID, + ContextID: eventContextID, + Kind: protocol.KindTaskStatusUpdate, + Final: &final, Status: protocol.TaskStatus{ State: protocol.TaskStateCompleted, Timestamp: getCurrentTimestamp(), }, - Final: true, } select { case <-ctx.Done(): - close(eventCh) return - case eventCh <- workingEvent: + case eventCh <- &workingEvent: // Continue } select { case <-ctx.Done(): - close(eventCh) return - case eventCh <- completedEvent: - close(eventCh) + case eventCh <- &completedEvent: return } }() @@ -629,8 +723,9 @@ func (m *mockTaskManager) OnPushNotificationSet( // Default implementation if response not configured return &protocol.TaskPushNotificationConfig{ - ID: params.ID, + RPCID: params.RPCID, PushNotificationConfig: params.PushNotificationConfig, + TaskID: params.TaskID, }, nil } @@ -656,7 +751,7 @@ func (m *mockTaskManager) OnPushNotificationGet( // OnResubscribe implements the TaskManager interface for resubscribing to task events. func (m *mockTaskManager) OnResubscribe( ctx context.Context, params protocol.TaskIDParams, -) (<-chan protocol.TaskEvent, error) { +) (<-chan protocol.StreamingMessageEvent, error) { m.mu.Lock() defer m.mu.Unlock() @@ -671,45 +766,42 @@ func (m *mockTaskManager) OnResubscribe( } // Create a channel and send events - eventCh := make(chan protocol.TaskEvent, len(m.SubscribeEvents)+1) + eventCh := make(chan protocol.StreamingMessageEvent, len(m.SubscribeEvents)+1) // Send configured events in background if len(m.SubscribeEvents) > 0 { go func() { - for _, event := range m.SubscribeEvents { + defer close(eventCh) + for _, streamEvent := range m.SubscribeEvents { select { case <-ctx.Done(): - close(eventCh) return - case eventCh <- event: - // If this is the final event, close the channel - if event.IsFinal() { - close(eventCh) - return - } + case eventCh <- streamEvent: + // Continue sending events } } - // If we didn't have a final event, close the channel anyway - close(eventCh) }() } else { // No events configured, send a default completed status go func() { - completedEvent := protocol.TaskStatusUpdateEvent{ - ID: params.ID, + defer close(eventCh) + completedEvent := &protocol.TaskStatusUpdateEvent{ + TaskID: params.ID, Status: protocol.TaskStatus{ State: protocol.TaskStateCompleted, Timestamp: getCurrentTimestamp(), }, - Final: true, + Kind: protocol.KindTaskStatusUpdate, + } + + streamEvent := protocol.StreamingMessageEvent{ + Result: completedEvent, } select { case <-ctx.Done(): - close(eventCh) return - case eventCh <- completedEvent: - close(eventCh) + case eventCh <- streamEvent: return } }() @@ -744,16 +836,19 @@ func (m *mockTaskManager) ProcessTask( return task, nil } -// mockProcessor is a mock implementation of taskmanager.Processor +// mockProcessor is a mock implementation of taskmanager.MessageProcessor type mockProcessor struct{} -func (m *mockProcessor) Process( +func (m *mockProcessor) ProcessMessage( ctx context.Context, - taskID string, message protocol.Message, - handle taskmanager.TaskHandle, -) error { - return nil + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { + // Simple echo processor for testing + return &taskmanager.MessageProcessingResult{ + Result: &message, + }, nil } // Test for push notification authenticator integration diff --git a/server/types.go b/server/types.go index 4cd7a1d..44227df 100644 --- a/server/types.go +++ b/server/types.go @@ -7,18 +7,99 @@ // Package server contains the A2A server implementation and related types. package server -import ( - "trpc.group/trpc-go/trpc-a2a-go/protocol" +// SecurityScheme represents an authentication scheme supported by the agent. +// Based on A2A 0.2.2 specification. +type SecurityScheme struct { + // Type is the type of security scheme. + Type SecuritySchemeType `json:"type"` + // Description is an optional description of the scheme. + Description *string `json:"description,omitempty"` + // Name is the name of the header/query parameter (for apiKey). + Name *string `json:"name,omitempty"` + // In specifies where to include the key (for apiKey). + In *SecuritySchemeIn `json:"in,omitempty"` + // Scheme is the HTTP authentication scheme (for http type). + Scheme *string `json:"scheme,omitempty"` + // BearerFormat is the hint for the bearer token format. + BearerFormat *string `json:"bearerFormat,omitempty"` + // Flows contains OAuth2 flow definitions. + Flows *OAuthFlows `json:"flows,omitempty"` + // OpenIDConnectURL is the OpenID Connect URL. + OpenIDConnectURL *string `json:"openIdConnectUrl,omitempty"` +} + +// SecuritySchemeType represents the type of security scheme. +type SecuritySchemeType string + +const ( + // SecuritySchemeTypeAPIKey represents API key authentication. + SecuritySchemeTypeAPIKey SecuritySchemeType = "apiKey" + // SecuritySchemeTypeHTTP represents HTTP authentication. + SecuritySchemeTypeHTTP SecuritySchemeType = "http" + // SecuritySchemeTypeOAuth2 represents OAuth2 authentication. + SecuritySchemeTypeOAuth2 SecuritySchemeType = "oauth2" + // SecuritySchemeTypeOpenIDConnect represents OpenID Connect authentication. + SecuritySchemeTypeOpenIDConnect SecuritySchemeType = "openIdConnect" ) +// SecuritySchemeIn represents where to include the security credentials. +type SecuritySchemeIn string + +const ( + // SecuritySchemeInQuery indicates the credential should be in query parameters. + SecuritySchemeInQuery SecuritySchemeIn = "query" + // SecuritySchemeInHeader indicates the credential should be in HTTP headers. + SecuritySchemeInHeader SecuritySchemeIn = "header" + // SecuritySchemeInCookie indicates the credential should be in cookies. + SecuritySchemeInCookie SecuritySchemeIn = "cookie" +) + +// OAuthFlows represents OAuth2 flow configurations. +type OAuthFlows struct { + // AuthorizationCode contains authorization code flow configuration. + AuthorizationCode *OAuthFlow `json:"authorizationCode,omitempty"` + // Implicit contains implicit flow configuration. + Implicit *OAuthFlow `json:"implicit,omitempty"` + // Password contains password flow configuration. + Password *OAuthFlow `json:"password,omitempty"` + // ClientCredentials contains client credentials flow configuration. + ClientCredentials *OAuthFlow `json:"clientCredentials,omitempty"` +} + +// OAuthFlow represents a single OAuth2 flow configuration. +type OAuthFlow struct { + // AuthorizationURL is the authorization URL (for authorizationCode and implicit). + AuthorizationURL *string `json:"authorizationUrl,omitempty"` + // TokenURL is the token URL. + TokenURL string `json:"tokenUrl"` + // RefreshURL is the refresh URL. + RefreshURL *string `json:"refreshUrl,omitempty"` + // Scopes are the available scopes. + Scopes map[string]string `json:"scopes,omitempty"` +} + +// AgentExtension represents an agent extension. +type AgentExtension struct { + // URI is the extension URI. + URI string `json:"uri"` + // Required indicates if the extension is required. + Required *bool `json:"required,omitempty"` + // Description is an optional description. + Description *string `json:"description,omitempty"` + // Params contains extension-specific parameters. + Params map[string]interface{} `json:"params,omitempty"` +} + // AgentCapabilities defines the capabilities supported by an agent. type AgentCapabilities struct { // Streaming is a flag indicating if the agent supports streaming responses. - Streaming bool `json:"streaming"` + Streaming *bool `json:"streaming,omitempty"` // PushNotifications is a flag indicating if the agent can push notifications. - PushNotifications bool `json:"pushNotifications"` + PushNotifications *bool `json:"pushNotifications,omitempty"` // StateTransitionHistory is a flag indicating if the agent can provide task history. - StateTransitionHistory bool `json:"stateTransitionHistory"` + StateTransitionHistory *bool `json:"stateTransitionHistory,omitempty"` + // Extensions are the supported agent extensions. + Extensions []AgentExtension `json:"extensions,omitempty"` } // AgentSkill describes a specific capability or function of the agent. @@ -58,28 +139,34 @@ type AgentAuthentication struct { } // AgentCard is the metadata structure describing an A2A agent. -// This is typically returned by the agent_get_card method. +// Updated for A2A 0.2.2 specification compliance. type AgentCard struct { // Name is the name of the agent. Name string `json:"name"` - // Description is an optional description of the agent. - Description *string `json:"description,omitempty"` + // Description is the description of the agent (required in 0.2.2). + Description string `json:"description"` // URL is the endpoint URL where the agent is hosted. URL string `json:"url"` // Provider is an optional provider information. Provider *AgentProvider `json:"provider,omitempty"` + // IconURL is an optional URL to the agent's icon. + IconURL *string `json:"iconUrl,omitempty"` // Version is the agent version string. Version string `json:"version"` // DocumentationURL is an optional link to documentation. DocumentationURL *string `json:"documentationUrl,omitempty"` // Capabilities are the declared capabilities of the agent. Capabilities AgentCapabilities `json:"capabilities"` - // Authentication is an optional authentication details. - Authentication *protocol.AuthenticationInfo `json:"authentication,omitempty"` - // DefaultInputModes are the default input modes if not specified per skill. - DefaultInputModes []string `json:"defaultInputModes,omitempty"` - // DefaultOutputModes are the default output modes if not specified per skill. - DefaultOutputModes []string `json:"defaultOutputModes,omitempty"` - // Skills are optional list of specific skills. + // SecuritySchemes define the security schemes supported by the agent. + SecuritySchemes map[string]SecurityScheme `json:"securitySchemes,omitempty"` + // Security defines the security requirements for the agent. + Security []map[string][]string `json:"security,omitempty"` + // DefaultInputModes are the default input modes (required in 0.2.2). + DefaultInputModes []string `json:"defaultInputModes"` + // DefaultOutputModes are the default output modes (required in 0.2.2). + DefaultOutputModes []string `json:"defaultOutputModes"` + // Skills are the list of specific skills (required in 0.2.2). Skills []AgentSkill `json:"skills"` + // SupportsAuthenticatedExtendedCard indicates if the agent supports authenticated extended card. + SupportsAuthenticatedExtendedCard *bool `json:"supportsAuthenticatedExtendedCard,omitempty"` } diff --git a/taskmanager/errors.go b/taskmanager/errors.go index a30b78b..78cdfb7 100644 --- a/taskmanager/errors.go +++ b/taskmanager/errors.go @@ -14,15 +14,81 @@ import ( "trpc.group/trpc-go/trpc-a2a-go/protocol" ) -// Custom JSON-RPC error codes specific to the TaskManager. +// JSON-RPC standard error codes const ( - ErrCodeTaskNotFound int = -32001 // Custom server error code range. - ErrCodeTaskFinal int = -32002 - ErrCodePushNotificationNotConfigured int = -32003 + ErrCodeJSONParse int = -32700 // Invalid JSON was received by the server + ErrCodeInvalidRequest int = -32600 // The JSON sent is not a valid Request object + ErrCodeMethodNotFound int = -32601 // The method does not exist or is not available + ErrCodeInvalidParams int = -32602 // Invalid method parameter(s) + ErrCodeInternalError int = -32603 // Internal JSON-RPC error ) +// Custom JSON-RPC error codes specific to the A2A specification +const ( + ErrCodeTaskNotFound int = -32001 // Task not found + ErrCodeTaskNotCancelable int = -32002 // Task cannot be canceled + ErrCodePushNotificationNotSupported int = -32003 // Push Notification is not supported + ErrCodeUnsupportedOperation int = -32004 // This operation is not supported + ErrCodeContentTypeNotSupported int = -32005 // Incompatible content types + ErrCodeInvalidAgentResponse int = -32006 // Invalid agent response +) + +// ErrCodeTaskFinal is deprecated: Use ErrCodeTaskNotCancelable instead +const ErrCodeTaskFinal int = -32002 + +// ErrCodePushNotificationNotConfigured is deprecated: Use ErrCodePushNotificationNotSupported instead +const ErrCodePushNotificationNotConfigured int = -32003 + +// Standard JSON-RPC error functions + +// ErrJSONParse creates a JSON-RPC error for invalid JSON payload. +func ErrJSONParse(details string) *jsonrpc.Error { + return &jsonrpc.Error{ + Code: ErrCodeJSONParse, + Message: "Invalid JSON payload", + Data: details, + } +} + +// ErrInvalidRequest creates a JSON-RPC error for invalid request object. +func ErrInvalidRequest(details string) *jsonrpc.Error { + return &jsonrpc.Error{ + Code: ErrCodeInvalidRequest, + Message: "Request payload validation error", + Data: details, + } +} + +// ErrMethodNotFound creates a JSON-RPC error for method not found. +func ErrMethodNotFound(method string) *jsonrpc.Error { + return &jsonrpc.Error{ + Code: ErrCodeMethodNotFound, + Message: "Method not found", + Data: fmt.Sprintf("Method '%s' not found", method), + } +} + +// ErrInvalidParams creates a JSON-RPC error for invalid parameters. +func ErrInvalidParams(details string) *jsonrpc.Error { + return &jsonrpc.Error{ + Code: ErrCodeInvalidParams, + Message: "Invalid parameters", + Data: details, + } +} + +// ErrInternalError creates a JSON-RPC error for internal server error. +func ErrInternalError(details string) *jsonrpc.Error { + return &jsonrpc.Error{ + Code: ErrCodeInternalError, + Message: "Internal error", + Data: details, + } +} + +// A2A specific error functions + // ErrTaskNotFound creates a JSON-RPC error for task not found. -// Exported function. func ErrTaskNotFound(taskID string) *jsonrpc.Error { return &jsonrpc.Error{ Code: ErrCodeTaskNotFound, @@ -31,20 +97,63 @@ func ErrTaskNotFound(taskID string) *jsonrpc.Error { } } +// ErrTaskNotCancelable creates a JSON-RPC error for task that cannot be canceled. +func ErrTaskNotCancelable(taskID string, state protocol.TaskState) *jsonrpc.Error { + return &jsonrpc.Error{ + Code: ErrCodeTaskNotCancelable, + Message: "Task cannot be canceled", + Data: fmt.Sprintf("Task '%s' is in state '%s' and cannot be canceled", taskID, state), + } +} + +// ErrPushNotificationNotSupported creates a JSON-RPC error for unsupported push notifications. +func ErrPushNotificationNotSupported() *jsonrpc.Error { + return &jsonrpc.Error{ + Code: ErrCodePushNotificationNotSupported, + Message: "Push Notification is not supported", + Data: "This agent does not support push notifications", + } +} + +// ErrUnsupportedOperation creates a JSON-RPC error for unsupported operations. +func ErrUnsupportedOperation(operation string) *jsonrpc.Error { + return &jsonrpc.Error{ + Code: ErrCodeUnsupportedOperation, + Message: "This operation is not supported", + Data: fmt.Sprintf("Operation '%s' is not supported by this agent", operation), + } +} + +// ErrContentTypeNotSupported creates a JSON-RPC error for incompatible content types. +func ErrContentTypeNotSupported(contentType string) *jsonrpc.Error { + return &jsonrpc.Error{ + Code: ErrCodeContentTypeNotSupported, + Message: "Incompatible content types", + Data: fmt.Sprintf("Content type '%s' is not supported", contentType), + } +} + +// ErrInvalidAgentResponse creates a JSON-RPC error for invalid agent response. +func ErrInvalidAgentResponse(details string) *jsonrpc.Error { + return &jsonrpc.Error{ + Code: ErrCodeInvalidAgentResponse, + Message: "Invalid agent response", + Data: details, + } +} + +// Deprecated functions for backward compatibility + // ErrTaskFinalState creates a JSON-RPC error for attempting an operation on a task // that is already in a final state (completed, failed, cancelled). -// Exported function. +// Deprecated: Use ErrTaskNotCancelable instead. func ErrTaskFinalState(taskID string, state protocol.TaskState) *jsonrpc.Error { - return &jsonrpc.Error{ - Code: ErrCodeTaskFinal, - Message: "Task is in final state", - Data: fmt.Sprintf("Task '%s' is already in final state: %s", taskID, state), - } + return ErrTaskNotCancelable(taskID, state) } // ErrPushNotificationNotConfigured creates a JSON-RPC error for when push notifications // haven't been configured for a task. -// Exported function. +// Deprecated: Use ErrPushNotificationNotSupported instead. func ErrPushNotificationNotConfigured(taskID string) *jsonrpc.Error { return &jsonrpc.Error{ Code: ErrCodePushNotificationNotConfigured, diff --git a/taskmanager/interface.go b/taskmanager/interface.go index f0635bb..d0ac122 100644 --- a/taskmanager/interface.go +++ b/taskmanager/interface.go @@ -13,84 +13,218 @@ import ( "trpc.group/trpc-go/trpc-a2a-go/protocol" ) -// TaskHandle provides methods for the agent logic (TaskProcessor) to interact +// ProcessOptions contains configuration options for processing messages +type ProcessOptions struct { + // Blocking indicates whether this is a blocking request + // If true, the user should wait for processing completion before returning the final result + // If false, the user can immediately return the initial state and update later through other means + Blocking bool + + // HistoryLength indicates the length of historical messages requested by the client + HistoryLength int + + // PushNotificationConfig contains push notification configuration + PushNotificationConfig *protocol.PushNotificationConfig + + // Streaming indicates whether this is a streaming request + // If true, the user should return event streams through the StreamingEvents channel + // If false, the user should return a single result through Result + Streaming bool +} + +// CancellableTask is a task that can be cancelled +type CancellableTask interface { + // Task returns the original task. + Task() *protocol.Task + + // Cancel cancels the task. + Cancel() +} + +// TaskHandler provides methods for the agent logic (MessageProcessor) to interact // with the task manager during processing. It encapsulates the necessary callbacks. -type TaskHandle interface { - // UpdateStatus updates the task's state and optional message. - // Returns an error if the task cannot be found or updated. - UpdateStatus(state protocol.TaskState, msg *protocol.Message) error - - // AddArtifact adds a new artifact to the task. - // Returns an error if the task cannot be found or updated. - AddArtifact(artifact protocol.Artifact) error - - // IsStreamingRequest returns true if the task was initiated via a streaming request - // (OnSendTaskSubscribe) rather than a synchronous request (OnSendTask). - // This allows the TaskProcessor to adapt its behavior based on the request type. - IsStreamingRequest() bool - - // GetSessionID returns the session ID for the task. - // If the task is not associated with a session, it returns nil. - GetSessionID() *string +type TaskHandler interface { + // BuildTask creates a new task and returns the task ID. + BuildTask(specificTaskID *string, contextID *string) (string, error) + + // UpdateTaskState updates the task's state and returns the updated task ID. + UpdateTaskState(taskID *string, state protocol.TaskState, message *protocol.Message) error + + // AddArtifact adds an artifact to the specified task. + AddArtifact(taskID *string, artifact protocol.Artifact, isFinal bool, needMoreData bool) error + + // SubScribeTask subscribes to the task and returns the task subscriber. + SubScribeTask(taskID *string) (TaskSubscriber, error) + + // GetTask returns the task by taskID. Returns an error if the task cannot be found. + GetTask(taskID *string) (CancellableTask, error) + + // CleanTask cleans up the task from storage. + CleanTask(taskID *string) error + + // GetMessageHistory returns the conversation history for the current context. + GetMessageHistory() []protocol.Message + + // GetContextID returns the context ID of the current message, if any. + GetContextID() string } -// TaskProcessor defines the interface for the core agent logic that processes a task. -// Implementations of this interface are injected into a TaskManager. -type TaskProcessor interface { - // Process executes the specific logic for a task. - // It receives the task ID, the initial message, and a TaskHandle for callbacks. - // It should use handle.Context() to check for cancellation. - // It should report progress and results via handle.UpdateStatus and handle.AddArtifact. - // Returning an error indicates the processing failed fundamentally. - Process(ctx context.Context, taskID string, initialMsg protocol.Message, handle TaskHandle) error +// MessageProcessingResult represents the result of processing a message. +type MessageProcessingResult struct { + // Result can be Message or Task + // When Streaming=false, use this field + // The framework will automatically handle whether to wait for the task to complete based on ProcessOptions.Blocking + Result protocol.UnaryMessageResult + + // StreamingEvents streaming event tunnel + // When Streaming=true, use this field + // Message、Task、TaskStatusUpdateEvent、TaskArtifactUpdateEvent is allowed to sent. + StreamingEvents TaskSubscriber } -// TaskProcessorWithStatusUpdate is an optional interface that can be implemented by TaskProcessor +// TaskSubscriber is a subscriber for a task +type TaskSubscriber interface { + // Send sends an event to the task subscriber, could be blocked if the channel is full + Send(event protocol.StreamingMessageEvent) error + + // Channel returns the channel of the task subscriber + Channel() <-chan protocol.StreamingMessageEvent + + // Closed returns true if the task subscriber is closed + Closed() bool + + // Close close the task subscriber + Close() +} + +// MessageProcessor defines the interface for processing A2A messages. +// This interface should be implemented by users to define their agent's behavior. +type MessageProcessor interface { + // ProcessMessage processes an incoming message and returns the result. + // + // Processing modes: + // 1. Non-streaming (options.Streaming=false): + // - Return MessageProcessingResult.Result (Message or Task) + // - The framework directly returns the user's result (if it's a non-final Task state, a reminder log will be printed) + // + // 2. Streaming (options.Streaming=true): + // - Return MessageProcessingResult.StreamingEvents channel + // - Multiple types of events can be sent through the channel: + // * protocol.Message - direct message reply + // * protocol.Task - task status + // * protocol.TaskStatusUpdateEvent - task status update + // * protocol.TaskArtifactUpdateEvent - artifact update + // - Users are responsible for closing the channel to end streaming transmission + // + // Parameters: + // - ctx: Request context + // - message: The incoming message to process + // - options: Processing options including blocking, streaming, history length, etc. + // - taskHandler: Task handler for accessing context, history, and task operations + // + // Returns: + // - MessageProcessingResult: Contains the result or streaming channel + // - error: Any error that occurred during processing + ProcessMessage( + ctx context.Context, + message protocol.Message, + options ProcessOptions, + taskHandler TaskHandler, + ) (*MessageProcessingResult, error) +} + +// MessageProcessorWithStatusUpdate is an optional interface that can be implemented by TaskProcessor // to receive notifications when the task status changes. -type TaskProcessorWithStatusUpdate interface { - TaskProcessor +type MessageProcessorWithStatusUpdate interface { + MessageProcessor // OnTaskStatusUpdate is called when the task status changes. // It receives the task ID, the new state, and the optional message. // It should return an error if the status update fails. - OnTaskStatusUpdate(ctx context.Context, taskID string, state protocol.TaskState, message *protocol.Message) error + OnTaskStatusUpdate( + ctx context.Context, + taskID string, + state protocol.TaskState, + message *protocol.Message, + ) error } // TaskManager defines the interface for managing A2A task lifecycles based on the protocol. // Implementations handle task creation, updates, retrieval, cancellation, and events, -// delegating the actual processing logic to an injected TaskProcessor. +// delegating the actual processing logic to an injected MessageProcessor. // This interface corresponds to the Task Service defined in the A2A Specification. // Exported interface. type TaskManager interface { - // OnSendTask handles a request corresponding to the 'tasks/send' RPC method. - // It creates and potentially starts processing a new task via the TaskProcessor. - // It returns the initial state of the task, possibly reflecting immediate processing results. - OnSendTask(ctx context.Context, request protocol.SendTaskParams) (*protocol.Task, error) - // OnSendTaskSubscribe handles a request corresponding to the 'tasks/sendSubscribe' RPC method. - // It creates a new task and returns a channel for receiving TaskEvent updates (streaming). - // It initiates asynchronous processing via the TaskProcessor. - // The channel will be closed when the task reaches a final state or an error occurs during setup/processing. - OnSendTaskSubscribe(ctx context.Context, request protocol.SendTaskParams) (<-chan protocol.TaskEvent, error) + // OnSendMessage handles a request corresponding to the 'message/send' RPC method. + // It creates and potentially starts processing a new message via the MessageProcessor. + // It returns the initial state of the message, possibly reflecting immediate processing results. + OnSendMessage( + ctx context.Context, + request protocol.SendMessageParams, + ) (*protocol.MessageResult, error) + + // OnSendMessageStream handles a request corresponding to the 'message/stream' RPC method. + // It creates a new message and returns a channel for receiving MessageEvent updates (streaming). + // It initiates asynchronous processing via the MessageProcessor. + // The channel will be closed when the message reaches a final state or an error occurs during setup/processing. + OnSendMessageStream( + ctx context.Context, + request protocol.SendMessageParams, + ) (<-chan protocol.StreamingMessageEvent, error) // OnGetTask handles a request corresponding to the 'tasks/get' RPC method. // It retrieves the current state of an existing task. - OnGetTask(ctx context.Context, params protocol.TaskQueryParams) (*protocol.Task, error) + OnGetTask( + ctx context.Context, + params protocol.TaskQueryParams, + ) (*protocol.Task, error) // OnCancelTask handles a request corresponding to the 'tasks/cancel' RPC method. // It requests the cancellation of an ongoing task. - // This typically involves canceling the context passed to the TaskProcessor. + // This typically involves canceling the context passed to the MessageProcessor. // It returns the task state after the cancellation attempt. - OnCancelTask(ctx context.Context, params protocol.TaskIDParams) (*protocol.Task, error) + OnCancelTask( + ctx context.Context, + params protocol.TaskIDParams, + ) (*protocol.Task, error) // OnPushNotificationSet handles a request corresponding to the 'tasks/pushNotification/set' RPC method. // It configures push notifications for a specific task. - OnPushNotificationSet(ctx context.Context, params protocol.TaskPushNotificationConfig) (*protocol.TaskPushNotificationConfig, error) + OnPushNotificationSet( + ctx context.Context, + params protocol.TaskPushNotificationConfig, + ) (*protocol.TaskPushNotificationConfig, error) // OnPushNotificationGet handles a request corresponding to the 'tasks/pushNotification/get' RPC method. // It retrieves the current push notification configuration for a task. - OnPushNotificationGet(ctx context.Context, params protocol.TaskIDParams) (*protocol.TaskPushNotificationConfig, error) + OnPushNotificationGet( + ctx context.Context, + params protocol.TaskIDParams, + ) (*protocol.TaskPushNotificationConfig, error) // OnResubscribe handles a request corresponding to the 'tasks/resubscribe' RPC method. // It reestablishes an SSE stream for an existing task. - OnResubscribe(ctx context.Context, params protocol.TaskIDParams) (<-chan protocol.TaskEvent, error) + OnResubscribe( + ctx context.Context, + params protocol.TaskIDParams, + ) (<-chan protocol.StreamingMessageEvent, error) + + // deprecated + // OnSendTask handles a request corresponding to the 'tasks/send' RPC method. + // It creates and potentially starts processing a new task via the MessageProcessor. + // It returns the initial state of the task, possibly reflecting immediate processing results. + OnSendTask( + ctx context.Context, + request protocol.SendTaskParams, + ) (*protocol.Task, error) + + // deprecated + // OnSendTaskSubscribe handles a request corresponding to the 'tasks/sendSubscribe' RPC method. + // It creates a new task and returns a channel for receiving TaskEvent updates (streaming). + // It initiates asynchronous processing via the MessageProcessor. + // The channel will be closed when the task reaches a final state or an error occurs during setup/processing. + OnSendTaskSubscribe( + ctx context.Context, + request protocol.SendTaskParams, + ) (<-chan protocol.TaskEvent, error) } diff --git a/taskmanager/memory.go b/taskmanager/memory.go index eb3bfb5..c41afcc 100644 --- a/taskmanager/memory.go +++ b/taskmanager/memory.go @@ -4,611 +4,752 @@ // // trpc-a2a-go is licensed under the Apache License Version 2.0. -// Package taskmanager provides task management interfaces, types, and implementations. +// Package messageprocessor provides implementations for processing A2A messages. + package taskmanager import ( "context" - "errors" "fmt" - "strings" "sync" + "sync/atomic" "time" "trpc.group/trpc-go/trpc-a2a-go/log" "trpc.group/trpc-go/trpc-a2a-go/protocol" ) -// MemoryTaskManager provides a concrete, memory-based implementation of the -// TaskManager interface. It manages tasks, messages, and subscribers in memory. -// It requires a TaskProcessor to handle the actual agent logic. -// It is safe for concurrent use. +const defaultMaxHistoryLength = 100 +const defaultCleanupInterval = 30 * time.Second +const defaultConversationTTL = 1 * time.Hour +const defaultTaskSubscriberBufferSize = 10 + +// ConversationHistory stores conversation history information +type ConversationHistory struct { + // MessageIDs is the list of message IDs, ordered by time + MessageIDs []string + // LastAccessTime is the last access time + LastAccessTime time.Time +} + +// MemoryCancellableTask is a task that can be cancelled +type MemoryCancellableTask struct { + task protocol.Task + cancelFunc context.CancelFunc + ctx context.Context +} + +// NewCancellableTask creates a new cancellable task +func NewCancellableTask(task protocol.Task) *MemoryCancellableTask { + cancelCtx, cancel := context.WithCancel(context.Background()) + return &MemoryCancellableTask{ + task: task, + cancelFunc: cancel, + ctx: cancelCtx, + } +} + +// Cancel cancels the task +func (t *MemoryCancellableTask) Cancel() { + t.cancelFunc() +} + +// Task returns the task +func (t *MemoryCancellableTask) Task() *protocol.Task { + return &t.task +} + +// MemoryTaskSubscriber is a subscriber for a task +type MemoryTaskSubscriber struct { + taskID string + eventQueue chan protocol.StreamingMessageEvent + lastAccessTime time.Time + closed atomic.Bool + mu sync.RWMutex +} + +// NewMemoryTaskSubscriber creates a new task subscriber with specified buffer length +func NewMemoryTaskSubscriber(taskID string, length int) *MemoryTaskSubscriber { + if length <= 0 { + length = defaultTaskSubscriberBufferSize // default buffer size + } + + eventQueue := make(chan protocol.StreamingMessageEvent, length) + + return &MemoryTaskSubscriber{ + taskID: taskID, + eventQueue: eventQueue, + lastAccessTime: time.Now(), + closed: atomic.Bool{}, + } +} + +// Close closes the task subscriber +func (s *MemoryTaskSubscriber) Close() { + s.mu.Lock() + defer s.mu.Unlock() + if !s.closed.Load() { + s.closed.Store(true) + close(s.eventQueue) + } +} + +// Channel returns the channel of the task subscriber +func (s *MemoryTaskSubscriber) Channel() <-chan protocol.StreamingMessageEvent { + return s.eventQueue +} + +// Closed returns true if the task subscriber is closed +func (s *MemoryTaskSubscriber) Closed() bool { + return s.closed.Load() +} + +// Send sends an event to the task subscriber +func (s *MemoryTaskSubscriber) Send(event protocol.StreamingMessageEvent) error { + if s.Closed() { + return fmt.Errorf("task subscriber is closed") + } + + s.mu.RLock() + defer s.mu.RUnlock() + if s.Closed() { + return fmt.Errorf("task subscriber is closed") + } + + s.lastAccessTime = time.Now() + + // Use select with default to avoid blocking + select { + case s.eventQueue <- event: + return nil + default: + return fmt.Errorf("event queue is full or closed") + } +} + +// GetLastAccessTime returns the last access time +func (s *MemoryTaskSubscriber) GetLastAccessTime() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.lastAccessTime +} + +// MemoryTaskManager is the implementation of the MemoryTaskManager interface type MemoryTaskManager struct { - // Processor is the agent logic processor. - Processor TaskProcessor - // Tasks is a map of task IDs to tasks. - Tasks map[string]*protocol.Task - // TasksMutex is a mutex for the Tasks map. - TasksMutex sync.RWMutex - // Messages is a map of task IDs to message history. - Messages map[string][]protocol.Message - // MessagesMutex is a mutex for the Messages map. - MessagesMutex sync.RWMutex - // Subscribers is a map of task IDs to subscriber channels. - Subscribers map[string][]chan<- protocol.TaskEvent - // SubMutex is a mutex for the Subscribers map. - SubMutex sync.RWMutex - // Contexts is a map of task IDs to cancellation functions. - Contexts map[string]context.CancelFunc - // ContextsMutex is a mutex for the Contexts map. - ContextsMutex sync.RWMutex - // PushNotifications is a map of task IDs to push notification configurations. - PushNotifications map[string]protocol.PushNotificationConfig - // PushNotificationsMutex is a mutex for the PushNotifications map. - PushNotificationsMutex sync.RWMutex -} - -// NewMemoryTaskManager creates a new instance with the provided TaskProcessor. -func NewMemoryTaskManager(processor TaskProcessor) (*MemoryTaskManager, error) { + // mu protects the following fields + mu sync.RWMutex + + // Processor is the user-provided message Processor + Processor MessageProcessor + + // Messages stores all Messages, indexed by messageID + // key: messageID, value: Message + Messages map[string]protocol.Message + + // Conversations stores the message history of each conversation, indexed by contextID + // key: contextID, value: ConversationHistory + Conversations map[string]*ConversationHistory + + // conversationMu protects the Conversations field + conversationMu sync.RWMutex + + // Tasks stores the task information, indexed by taskID + // key: taskID, value: Task + Tasks map[string]*MemoryCancellableTask + + // taskMu protects the Tasks field + taskMu sync.RWMutex + + // Subscribers stores the task subscribers + // key: taskID, value: TaskSubscriber list + // supports all event types: Message, Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent + Subscribers map[string][]*MemoryTaskSubscriber + + // PushNotifications stores the push notification configurations + // key: taskID, value: push notification configuration + PushNotifications map[string]protocol.TaskPushNotificationConfig + + // configuration options + maxHistoryLength int // max history message count +} + +// NewMemoryTaskManager creates a new MemoryTaskManager instance +func NewMemoryTaskManager(processor MessageProcessor, opts ...MemoryTaskManagerOption) (*MemoryTaskManager, error) { if processor == nil { - return nil, errors.New("task processor cannot be nil") + return nil, fmt.Errorf("processor cannot be nil") } - return &MemoryTaskManager{ + + // Apply default options + options := DefaultMemoryTaskManagerOptions() + + // Apply user options + for _, opt := range opts { + opt(options) + } + + manager := &MemoryTaskManager{ Processor: processor, - Tasks: make(map[string]*protocol.Task), - Messages: make(map[string][]protocol.Message), - Subscribers: make(map[string][]chan<- protocol.TaskEvent), - Contexts: make(map[string]context.CancelFunc), - PushNotifications: make(map[string]protocol.PushNotificationConfig), - }, nil + Messages: make(map[string]protocol.Message), + Conversations: make(map[string]*ConversationHistory), + Tasks: make(map[string]*MemoryCancellableTask), + Subscribers: make(map[string][]*MemoryTaskSubscriber), + PushNotifications: make(map[string]protocol.TaskPushNotificationConfig), + maxHistoryLength: options.MaxHistoryLength, + } + + // Start cleanup goroutine if enabled + if options.EnableCleanup { + go func() { + ticker := time.NewTicker(options.CleanupInterval) + defer ticker.Stop() + + for range ticker.C { + manager.CleanExpiredConversations(options.ConversationTTL) + } + }() + } + + return manager, nil } -// processTaskWithProcessor handles the common task processing logic. -// It creates a taskHandle, sets initial status, and calls the processor. -func (m *MemoryTaskManager) processTaskWithProcessor( +// ============================================================================= +// TaskManager interface implementation +// ============================================================================= + +// OnSendMessage handles the message/tasks request +func (m *MemoryTaskManager) OnSendMessage( ctx context.Context, - taskID string, - message protocol.Message, -) error { - handle := &memoryTaskHandle{ - taskID: taskID, - manager: m, + request protocol.SendMessageParams, +) (*protocol.MessageResult, error) { + log.Debugf("MemoryTaskManager: OnSendMessage for message %s", request.Message.MessageID) + + // process the request message + m.processRequestMessage(&request.Message) + + // process Configuration + options := m.processConfiguration(request.Configuration, request.Metadata) + options.Streaming = false // non-streaming processing + + // create MessageHandle + handle := &memoryTaskHandler{ + manager: m, + messageID: request.Message.MessageID, + ctx: ctx, } - // Set initial status to Working before calling Process - if err := m.UpdateTaskStatus(ctx, taskID, protocol.TaskStateWorking, nil); err != nil { - log.Errorf("Error setting initial Working status for task %s: %v", taskID, err) - return fmt.Errorf("failed to set initial working status: %w", err) + // call the user's message processor + result, err := m.Processor.ProcessMessage(ctx, request.Message, options, handle) + if err != nil { + return nil, fmt.Errorf("message processing failed: %w", err) } - // Delegate the actual processing to the injected processor - if err := m.Processor.Process(ctx, taskID, message, handle); err != nil { - log.Errorf("Processor failed for task %s: %v", taskID, err) - errMsg := &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart(err.Error())}, - } - // Log update error while still handling the processor error - if updateErr := m.UpdateTaskStatus( - ctx, - taskID, - protocol.TaskStateFailed, - errMsg, - ); updateErr != nil { - log.Errorf("Failed to update task %s status to failed: %v", taskID, updateErr) + if result == nil { + return nil, fmt.Errorf("processor returned nil result") + } + + // check if the user returned StreamingEvents for non-streaming request + if result.StreamingEvents != nil { + log.Infof("User returned StreamingEvents for non-streaming request, ignoring") + } + + if result.Result == nil { + return nil, fmt.Errorf("processor returned nil result for non-streaming request") + } + + switch result.Result.(type) { + case *protocol.Task: + case *protocol.Message: + default: + return nil, fmt.Errorf("processor returned unsupported result type %T for SendMessage request", result.Result) + } + + if message, ok := result.Result.(*protocol.Message); ok { + var contextID string + if request.Message.ContextID != nil { + contextID = *request.Message.ContextID } - return err + m.processReplyMessage(contextID, message) } - return nil + return &protocol.MessageResult{Result: result.Result}, nil } -// startTaskSubscribe starts processing a task in a goroutine that sends events to subscribers. -// It returns immediately, with the processing continuing asynchronously. -func (m *MemoryTaskManager) startTaskSubscribe( +// OnSendMessageStream handles message/stream requests +func (m *MemoryTaskManager) OnSendMessageStream( ctx context.Context, - taskID string, - message protocol.Message, -) { - // Create a handle for the processor to interact with the task - handle := &memoryTaskHandle{ - taskID: taskID, - manager: m, + request protocol.SendMessageParams, +) (<-chan protocol.StreamingMessageEvent, error) { + log.Debugf("MemoryTaskManager: OnSendMessageStream for message %s", request.Message.MessageID) + + m.processRequestMessage(&request.Message) + + // Process Configuration + options := m.processConfiguration(request.Configuration, request.Metadata) + options.Streaming = true // streaming mode + + // Create streaming MessageHandle + handle := &memoryTaskHandler{ + manager: m, + messageID: request.Message.MessageID, + ctx: ctx, } - log.Debugf("SSE Processor started for task %s", taskID) - - // Start the processor in a goroutine - go func() { - var err error - if err = m.Processor.Process(ctx, taskID, message, handle); err != nil { - log.Errorf("Processor failed for task %s in subscribe: %v", taskID, err) - if ctx.Err() != context.Canceled { - // Only update to failed if not already cancelled - errMsg := &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart(err.Error())}, - } - if updateErr := m.UpdateTaskStatus(ctx, taskID, protocol.TaskStateFailed, errMsg); updateErr != nil { - log.Errorf("Failed to update task %s status to failed: %v", taskID, updateErr) - } - } + // Call user's message processor + result, err := m.Processor.ProcessMessage(ctx, request.Message, options, handle) + if err != nil { + return nil, fmt.Errorf("message processing failed: %w", err) + } + + if result == nil || result.StreamingEvents == nil { + return nil, fmt.Errorf("processor returned nil result") + } + + return result.StreamingEvents.Channel(), nil +} + +// OnGetTask handles the tasks/get request +func (m *MemoryTaskManager) OnGetTask(ctx context.Context, params protocol.TaskQueryParams) (*protocol.Task, error) { + m.taskMu.RLock() + defer m.taskMu.RUnlock() + + task, exists := m.Tasks[params.ID] + if !exists { + return nil, fmt.Errorf("task not found: %s", params.ID) + } + + // return a copy of the task + taskCopy := *task.Task() + + // if the request contains history length, fill the message history + if params.HistoryLength != nil && *params.HistoryLength > 0 { + if taskCopy.ContextID != "" { + history := m.getConversationHistory(taskCopy.ContextID, *params.HistoryLength) + taskCopy.History = history } + } + + return &taskCopy, nil +} + +// OnCancelTask handles the tasks/cancel request +func (m *MemoryTaskManager) OnCancelTask(ctx context.Context, params protocol.TaskIDParams) (*protocol.Task, error) { + m.taskMu.Lock() + task, exists := m.Tasks[params.ID] + if !exists { + m.taskMu.Unlock() + return nil, fmt.Errorf("task not found: %s", params.ID) + } + + taskCopy := *task.Task() + m.taskMu.Unlock() - // Clean up the context regardless of how we finish - m.ContextsMutex.Lock() - delete(m.Contexts, taskID) - m.ContextsMutex.Unlock() + handle := &memoryTaskHandler{ + manager: m, + ctx: ctx, + } + handle.CleanTask(¶ms.ID) + taskCopy.Status.State = protocol.TaskStateCanceled + taskCopy.Status.Timestamp = time.Now().UTC().Format(time.RFC3339) - log.Debugf("Processor finished for task %s in subscribe (Error: %v). Goroutine exiting.", taskID, err) - }() + return &taskCopy, nil } -// OnSendTask handles the creation or retrieval of a task and initiates synchronous processing. -// It implements the TaskManager interface. -func (m *MemoryTaskManager) OnSendTask(ctx context.Context, params protocol.SendTaskParams) (*protocol.Task, error) { - _ = m.upsertTask(params) // Get or create task entry. Ignore return. - m.storeMessage(params.ID, params.Message) // Store the initial user message. +// OnPushNotificationSet handles tasks/pushNotificationConfig/set requests +func (m *MemoryTaskManager) OnPushNotificationSet( + ctx context.Context, + params protocol.TaskPushNotificationConfig, +) (*protocol.TaskPushNotificationConfig, error) { + m.mu.Lock() + defer m.mu.Unlock() - // Create a cancellable context for this specific task processing - taskCtx, cancel := context.WithCancel(ctx) - defer cancel() // Ensure context is cancelled eventually + // Store push notification configuration + m.PushNotifications[params.TaskID] = params + log.Debugf("MemoryTaskManager: Push notification config set for task %s", params.TaskID) + return ¶ms, nil +} - // Process the task - err := m.processTaskWithProcessor(taskCtx, params.ID, params.Message) +// OnPushNotificationGet handles tasks/pushNotificationConfig/get requests +func (m *MemoryTaskManager) OnPushNotificationGet( + ctx context.Context, + params protocol.TaskIDParams, +) (*protocol.TaskPushNotificationConfig, error) { + m.mu.RLock() + defer m.mu.RUnlock() - // Return the latest task state after processing - finalTask, e := m.getTaskInternal(params.ID) - if e != nil { - log.Errorf("Failed to get task %s after processing: %v", params.ID, e) + config, exists := m.PushNotifications[params.ID] + if !exists { + return nil, fmt.Errorf("push notification config not found for task: %s", params.ID) } - // Do not include e in the return value - return finalTask, err + return &config, nil } -// OnSendTaskSubscribe handles a tasks/sendSubscribe request with streaming response. -// It creates or updates a task based on the parameters, then returns a channel for status updates. -// The channel will receive events until the task completes, fails, is cancelled, or the context expires. -func (m *MemoryTaskManager) OnSendTaskSubscribe( +// OnResubscribe handles tasks/resubscribe requests +func (m *MemoryTaskManager) OnResubscribe( ctx context.Context, - params protocol.SendTaskParams, -) (<-chan protocol.TaskEvent, error) { - // Create a new task or update an existing one - task := m.upsertTask(params) - // Store the message that came with the request - m.storeMessage(params.ID, params.Message) - - // Create event channel for this specific subscriber - eventChan := make(chan protocol.TaskEvent, 10) // Buffered to prevent blocking sends - m.addSubscriber(params.ID, eventChan) - - // Create a cancellable context for the processor - processorCtx, cancel := context.WithCancel(ctx) - - // Store the cancel function - m.ContextsMutex.Lock() - m.Contexts[params.ID] = cancel - m.ContextsMutex.Unlock() - - // Set initial state if new (submitted -> working) - // This will generate the first event for subscribers - if task.Status.State == protocol.TaskStateSubmitted { - if err := m.UpdateTaskStatus(ctx, params.ID, protocol.TaskStateWorking, nil); err != nil { - m.removeSubscriber(params.ID, eventChan) - close(eventChan) - return nil, err - } + params protocol.TaskIDParams, +) (<-chan protocol.StreamingMessageEvent, error) { + m.taskMu.Lock() + defer m.taskMu.Unlock() + + // Check if task exists + _, exists := m.Tasks[params.ID] + if !exists { + return nil, fmt.Errorf("task not found: %s", params.ID) } - // Start the processor in a goroutine - m.startTaskSubscribe(processorCtx, params.ID, params.Message) + subscriber := NewMemoryTaskSubscriber(params.ID, defaultTaskSubscriberBufferSize) + + // Add to subscribers list + if _, exists := m.Subscribers[params.ID]; !exists { + m.Subscribers[params.ID] = make([]*MemoryTaskSubscriber, 0) + } + m.Subscribers[params.ID] = append(m.Subscribers[params.ID], subscriber) - // Return the channel for events - return eventChan, nil + return subscriber.eventQueue, nil } -// getTaskInternal retrieves the task without locking (caller must handle locks). -// Returns nil if not found. -func (m *MemoryTaskManager) getTaskInternal(taskID string) (*protocol.Task, error) { - return m.getTaskWithValidation(taskID) +// OnSendTask deprecated method empty implementation +func (m *MemoryTaskManager) OnSendTask(ctx context.Context, request protocol.SendTaskParams) (*protocol.Task, error) { + return nil, fmt.Errorf("OnSendTask is deprecated, use OnSendMessage instead") } -// OnGetTask retrieves the current state of a task, including optional message history. -// It implements the TaskManager interface. -func (m *MemoryTaskManager) OnGetTask(ctx context.Context, params protocol.TaskQueryParams) (*protocol.Task, error) { - task, err := m.getTaskWithValidation(params.ID) - if err != nil { - return nil, err // Already an ErrTaskNotFound or similar. - } - // Add message history if requested. - if params.HistoryLength != nil { - // historyLength == 0 means "get all history" - // historyLength > 0 means "get that many most recent messages" - // historyLength == nil means "don't include history" - m.MessagesMutex.RLock() - messages, historyExists := m.Messages[params.ID] - m.MessagesMutex.RUnlock() - if historyExists { - historyLen := len(messages) - requestedLen := *params.HistoryLength - var startIndex int - if requestedLen > 0 && requestedLen < historyLen { - startIndex = historyLen - requestedLen - } - // Make a copy of the history slice. - task.History = make([]protocol.Message, len(messages[startIndex:])) - copy(task.History, messages[startIndex:]) - // --> Add logging here to check type immediately after copy. - if len(task.History) > 0 && len(task.History[0].Parts) > 0 { - log.Debugf("DEBUG: Type in task.History[0].Parts[0] inside OnGetTask: %T", task.History[0].Parts[0]) +// OnSendTaskSubscribe deprecated method empty implementation +func (m *MemoryTaskManager) OnSendTaskSubscribe(ctx context.Context, request protocol.SendTaskParams) (<-chan protocol.TaskEvent, error) { + return nil, fmt.Errorf("OnSendTaskSubscribe is deprecated, use OnSendMessageStream instead") +} + +// ============================================================================= +// Internal helper methods +// ============================================================================= + +// storeMessage stores messages +func (m *MemoryTaskManager) storeMessage(message protocol.Message) { + m.conversationMu.Lock() + defer m.conversationMu.Unlock() + + // Store the message + m.Messages[message.MessageID] = message + + // If the message has a contextID, add it to conversation history + if message.ContextID != nil { + contextID := *message.ContextID + if _, exists := m.Conversations[contextID]; !exists { + m.Conversations[contextID] = &ConversationHistory{ + MessageIDs: make([]string, 0), + LastAccessTime: time.Now(), } - return task, nil + } + + // Add message ID to conversation history + m.Conversations[contextID].MessageIDs = append(m.Conversations[contextID].MessageIDs, message.MessageID) + // Update last access time + m.Conversations[contextID].LastAccessTime = time.Now() + + // Limit history length + if len(m.Conversations[contextID].MessageIDs) > m.maxHistoryLength { + // Remove the oldest message + removedMsgID := m.Conversations[contextID].MessageIDs[0] + m.Conversations[contextID].MessageIDs = m.Conversations[contextID].MessageIDs[1:] + // Delete old message from message storage + delete(m.Messages, removedMsgID) } } - task.History = nil // Ensure history is nil if not requested. - return task, nil } -// OnCancelTask attempts to cancel an ongoing task. -// It implements the TaskManager interface. -func (m *MemoryTaskManager) OnCancelTask(ctx context.Context, params protocol.TaskIDParams) (*protocol.Task, error) { - m.TasksMutex.Lock() - task, exists := m.Tasks[params.ID] - if !exists { - m.TasksMutex.Unlock() - return nil, ErrTaskNotFound(params.ID) - } - m.TasksMutex.Unlock() // Release lock before potential context cancel / status update. - // Check if task is already in a final state (read lock again). - m.TasksMutex.RLock() - alreadyFinal := isFinalState(task.Status.State) - m.TasksMutex.RUnlock() - if alreadyFinal { - return task, ErrTaskFinalState(params.ID, task.Status.State) - } - // Find and call the context cancel func stored for this taskID. - var cancelFound bool - m.ContextsMutex.Lock() - cancel, exists := m.Contexts[params.ID] - if exists { - cancel() // Call the cancel function. - cancelFound = true - // Don't delete the context here - let the processor goroutine clean up. - } - m.ContextsMutex.Unlock() - if !cancelFound { - log.Warnf("Warning: No cancellation function found for task %s", params.ID) - } - // Create a cancellation message. - cancelMsg := &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{ - protocol.NewTextPart(fmt.Sprintf("Task %s was canceled by user request", params.ID)), - }, - } - // Update state to Cancelled. - if err := m.UpdateTaskStatus(ctx, params.ID, protocol.TaskStateCanceled, cancelMsg); err != nil { - log.Errorf("Error updating status to Cancelled for task %s: %v", params.ID, err) - return nil, err - } - // Fetch the updated task state to return. - updatedTask, err := m.getTaskInternal(params.ID) - if err != nil { - return nil, fmt.Errorf("failed to get task %s after cancellation update: %w", params.ID, err) +// getMessageHistory gets message history +func (m *MemoryTaskManager) getMessageHistory(contextID string) []protocol.Message { + var history []protocol.Message + if contextID == "" { + return history } - return updatedTask, nil -} -// UpdateTaskStatus updates the task's state and notifies any subscribers. -// Returns an error if the task does not exist. -// Exported method (used by memoryTaskHandle). -func (m *MemoryTaskManager) UpdateTaskStatus( - ctx context.Context, - taskID string, - state protocol.TaskState, - message *protocol.Message, -) error { - m.TasksMutex.Lock() - task, exists := m.Tasks[taskID] - if !exists { - m.TasksMutex.Unlock() - log.Warnf("Warning: UpdateTaskStatus called for non-existent task %s", taskID) - return ErrTaskNotFound(taskID) - } - // Update status fields. - task.Status = protocol.TaskStatus{ - State: state, - Message: message, - Timestamp: time.Now().UTC().Format(time.RFC3339), - } - // Create a copy for notification before unlocking. - taskCopy := *task - m.TasksMutex.Unlock() // Unlock before potentially blocking on channel send. - if processor, ok := m.Processor.(TaskProcessorWithStatusUpdate); ok { - if err := processor.OnTaskStatusUpdate(ctx, taskID, state, message); err != nil { - log.Errorf("Error updating status for task %s: %v", taskID, err) + // Need to protect access to both conversations and messages + m.mu.Lock() + defer m.mu.Unlock() + + if conversation, exists := m.Conversations[contextID]; exists { + // Update last access time + conversation.LastAccessTime = time.Now() + + history = make([]protocol.Message, 0, len(conversation.MessageIDs)) + for _, msgID := range conversation.MessageIDs { + if msg, exists := m.Messages[msgID]; exists { + history = append(history, msg) + } } } - // Store the message in history if provided - if message != nil { - // Convert TaskStatus Message (which is a pointer) to a Message value for history - m.storeMessage(taskID, *message) - } - // Notify subscribers outside the lock. - m.notifySubscribers(taskID, protocol.TaskStatusUpdateEvent{ - ID: taskID, - Status: taskCopy.Status, - Final: isFinalState(state), - }) - return nil -} - -// AddArtifact adds an artifact to the task and notifies subscribers. -// Returns an error if the task does not exist. -// Exported method (used by memoryTaskHandle). -func (m *MemoryTaskManager) AddArtifact(taskID string, artifact protocol.Artifact) error { - m.TasksMutex.Lock() - task, exists := m.Tasks[taskID] - if !exists { - m.TasksMutex.Unlock() - log.Warnf("Warning: AddArtifact called for non-existent task %s", taskID) - return ErrTaskNotFound(taskID) - } - // Append the artifact. - if task.Artifacts == nil { - task.Artifacts = make([]protocol.Artifact, 0, 1) - } - task.Artifacts = append(task.Artifacts, artifact) - // Create copies for notification before unlocking. - m.TasksMutex.Unlock() // Unlock before potentially blocking on channel send. - // Notify subscribers outside the lock. - finalEvent := artifact.LastChunk != nil && *artifact.LastChunk - m.notifySubscribers(taskID, protocol.TaskArtifactUpdateEvent{ - ID: taskID, - Artifact: artifact, - Final: finalEvent, - }) - return nil -} - -// --- Internal Helper Methods (Unexported) --- - -// upsertTask creates a new task or updates metadata if it already exists. -// Assumes locks are handled by the caller if needed, but acquires its own lock. -func (m *MemoryTaskManager) upsertTask(params protocol.SendTaskParams) *protocol.Task { - m.TasksMutex.Lock() - defer m.TasksMutex.Unlock() - task, exists := m.Tasks[params.ID] - if !exists { - task = protocol.NewTask(params.ID, params.SessionID) - m.Tasks[params.ID] = task - log.Infof("Created new task %s (Session: %v)", params.ID, params.SessionID) - } else { - log.Debugf("Updating existing task %s", params.ID) - } - // Update metadata if provided. - if params.Metadata != nil { - if task.Metadata == nil { - task.Metadata = make(map[string]interface{}) + return history +} + +// getConversationHistory gets conversation history of specified length +func (m *MemoryTaskManager) getConversationHistory(contextID string, length int) []protocol.Message { + m.conversationMu.RLock() + defer m.conversationMu.RUnlock() + + var history []protocol.Message + + if conversation, exists := m.Conversations[contextID]; exists { + // Update last access time + conversation.LastAccessTime = time.Now() + + start := 0 + if len(conversation.MessageIDs) > length { + start = len(conversation.MessageIDs) - length } - for k, v := range params.Metadata { - task.Metadata[k] = v + + for i := start; i < len(conversation.MessageIDs); i++ { + if msg, exists := m.Messages[conversation.MessageIDs[i]]; exists { + history = append(history, msg) + } } } - return task + + return history +} + +// isFinalState checks if it's a final state +func isFinalState(state protocol.TaskState) bool { + return state == protocol.TaskStateCompleted || + state == protocol.TaskStateFailed || + state == protocol.TaskStateCanceled || + state == protocol.TaskStateRejected } -// storeMessage adds a message to the task's history. -// Assumes locks are handled by the caller if needed, but acquires its own lock. -func (m *MemoryTaskManager) storeMessage(taskID string, message protocol.Message) { - m.MessagesMutex.Lock() - defer m.MessagesMutex.Unlock() - if _, exists := m.Messages[taskID]; !exists { - m.Messages[taskID] = make([]protocol.Message, 0, 1) // Initialize with capacity. +// ============================================================================= +// Configuration related types and helper methods +// ============================================================================= + +// processConfiguration processes and normalizes Configuration +func (m *MemoryTaskManager) processConfiguration(config *protocol.SendMessageConfiguration, metadata map[string]interface{}) ProcessOptions { + result := ProcessOptions{ + Blocking: false, + HistoryLength: 0, + } + + if config == nil { + return result } - // Create a copy of the message to store, ensuring history isolation. - messageCopy := protocol.Message{ - Role: message.Role, - Metadata: message.Metadata, // Shallow copy of map is usually fine. + + // Process Blocking configuration + if config.Blocking != nil { + result.Blocking = *config.Blocking } - if message.Parts != nil { - // Copy the slice of parts (shallow copy of interface values is correct). - messageCopy.Parts = make([]protocol.Part, len(message.Parts)) - copy(messageCopy.Parts, message.Parts) + + // Process HistoryLength configuration + if config.HistoryLength != nil && *config.HistoryLength > 0 { + result.HistoryLength = *config.HistoryLength + } + + // Process PushNotificationConfig + if config.PushNotificationConfig != nil { + result.PushNotificationConfig = config.PushNotificationConfig } - m.Messages[taskID] = append(m.Messages[taskID], messageCopy) + + return result } -// addSubscriber adds a channel to the list of subscribers for a task. -func (m *MemoryTaskManager) addSubscriber(taskID string, ch chan<- protocol.TaskEvent) { - m.SubMutex.Lock() - defer m.SubMutex.Unlock() - if _, exists := m.Subscribers[taskID]; !exists { - m.Subscribers[taskID] = make([]chan<- protocol.TaskEvent, 0, 1) +func (m *MemoryTaskManager) processRequestMessage(message *protocol.Message) { + if message.MessageID == "" { + message.MessageID = protocol.GenerateMessageID() + } + if message.ContextID != nil { + m.storeMessage(*message) } - m.Subscribers[taskID] = append(m.Subscribers[taskID], ch) - log.Debugf("Added subscriber for task %s", taskID) } -// removeSubscriber removes a specific channel from the list of subscribers for a task. -func (m *MemoryTaskManager) removeSubscriber(taskID string, ch chan<- protocol.TaskEvent) { - m.SubMutex.Lock() - defer m.SubMutex.Unlock() - channels, exists := m.Subscribers[taskID] - if !exists { - return // No subscribers for this task. +func (m *MemoryTaskManager) processReplyMessage(ctxID string, message *protocol.Message) { + message.ContextID = &ctxID + message.Role = protocol.MessageRoleAgent + if message.MessageID == "" { + message.MessageID = protocol.GenerateMessageID() } - // Filter out the channel to remove. - var newChannels []chan<- protocol.TaskEvent - for _, existingCh := range channels { - if existingCh != ch { - newChannels = append(newChannels, existingCh) - } + + // if contextID is not nil, store the conversation history + if message.ContextID != nil { + m.storeMessage(*message) } - if len(newChannels) == 0 { - delete(m.Subscribers, taskID) // No more subscribers. - } else { - m.Subscribers[taskID] = newChannels +} + +func (m *MemoryTaskManager) checkTaskExists(taskID string) bool { + m.taskMu.RLock() + defer m.taskMu.RUnlock() + _, exists := m.Tasks[taskID] + return exists +} + +func (m *MemoryTaskManager) getTask(taskID string) (*MemoryCancellableTask, error) { + m.taskMu.RLock() + defer m.taskMu.RUnlock() + task, exists := m.Tasks[taskID] + if !exists { + return nil, fmt.Errorf("task not found: %s", taskID) } - log.Debugf("Removed subscriber for task %s", taskID) + return task, nil } -// notifySubscribers sends an event to all current subscribers of a task. -func (m *MemoryTaskManager) notifySubscribers(taskID string, event protocol.TaskEvent) { - m.SubMutex.RLock() +// notifySubscribers notifies all subscribers of the task +func (m *MemoryTaskManager) notifySubscribers(taskID string, event protocol.StreamingMessageEvent) { + m.taskMu.RLock() subs, exists := m.Subscribers[taskID] if !exists || len(subs) == 0 { - m.SubMutex.RUnlock() - return // No subscribers to notify. + m.taskMu.RUnlock() + return } - // Copy the slice of channels under read lock. - subsCopy := make([]chan<- protocol.TaskEvent, len(subs)) + + subsCopy := make([]*MemoryTaskSubscriber, len(subs)) copy(subsCopy, subs) - m.SubMutex.RUnlock() - log.Debugf("Notifying %d subscribers for task %s (Event Type: %T, Final: %t)", - len(subsCopy), taskID, event, event.IsFinal()) - // Send events outside the lock. - for _, ch := range subsCopy { - // Use a select with a default case for a non-blocking send. - // This prevents one slow/blocked subscriber from delaying others. - select { - case ch <- event: - // Event sent successfully. - default: - // Channel buffer is full or channel is closed. - log.Warnf("Warning: Dropping event for task %s subscriber - channel full or closed.", taskID) + m.taskMu.RUnlock() + + log.Debugf("Notifying %d subscribers for task %s (Event Type: %T)", len(subsCopy), taskID, event.Result) + + var failedSubscribers []*MemoryTaskSubscriber + + for _, sub := range subsCopy { + if sub.Closed() { + log.Debugf("Subscriber for task %s is already closed, marking for removal", taskID) + failedSubscribers = append(failedSubscribers, sub) + continue + } + + err := sub.Send(event) + if err != nil { + log.Warnf("Failed to send event to subscriber for task %s: %v", taskID, err) + failedSubscribers = append(failedSubscribers, sub) } } -} -// OnPushNotificationSet implements TaskManager.OnPushNotificationSet. -// It sets push notification configuration for a task. -func (m *MemoryTaskManager) OnPushNotificationSet( - ctx context.Context, - params protocol.TaskPushNotificationConfig, -) (*protocol.TaskPushNotificationConfig, error) { - // Store the push notification configuration. - m.PushNotificationsMutex.Lock() - m.PushNotifications[params.ID] = params.PushNotificationConfig - m.PushNotificationsMutex.Unlock() - log.Infof("Set push notification for task %s to URL: %s", params.ID, params.PushNotificationConfig.URL) - // Return the stored configuration as confirmation. - return ¶ms, nil + // Clean up failed or closed subscribers + if len(failedSubscribers) > 0 { + m.cleanupFailedSubscribers(taskID, failedSubscribers) + } } -// OnPushNotificationGet implements TaskManager.OnPushNotificationGet. -// It retrieves the push notification configuration for a task. -func (m *MemoryTaskManager) OnPushNotificationGet( - ctx context.Context, params protocol.TaskIDParams, -) (*protocol.TaskPushNotificationConfig, error) { - m.TasksMutex.RLock() - _, exists := m.Tasks[params.ID] - m.TasksMutex.RUnlock() +// cleanupFailedSubscribers cleans up failed or closed subscribers +func (m *MemoryTaskManager) cleanupFailedSubscribers(taskID string, failedSubscribers []*MemoryTaskSubscriber) { + m.taskMu.Lock() + defer m.taskMu.Unlock() + + subs, exists := m.Subscribers[taskID] if !exists { - return nil, ErrTaskNotFound(params.ID) + return } - // Retrieve the push notification configuration. - m.PushNotificationsMutex.RLock() - config, exists := m.PushNotifications[params.ID] - m.PushNotificationsMutex.RUnlock() - if !exists { - // Task exists but has no push notification config. - return nil, ErrPushNotificationNotConfigured(params.ID) + + // Filter out failed subscribers + filteredSubs := make([]*MemoryTaskSubscriber, 0, len(subs)) + removedCount := 0 + + for _, sub := range subs { + shouldRemove := false + for _, failedSub := range failedSubscribers { + if sub == failedSub { + shouldRemove = true + removedCount++ + break + } + } + if !shouldRemove { + filteredSubs = append(filteredSubs, sub) + } } - result := &protocol.TaskPushNotificationConfig{ - ID: params.ID, - PushNotificationConfig: config, + + if removedCount > 0 { + m.Subscribers[taskID] = filteredSubs + log.Debugf("Removed %d failed subscribers for task %s", removedCount, taskID) + + // If there are no subscribers left, delete the entire entry + if len(filteredSubs) == 0 { + delete(m.Subscribers, taskID) + } } - return result, nil } -// OnResubscribe implements TaskManager.OnResubscribe. -// It allows a client to reestablish an SSE stream for an existing task. -func (m *MemoryTaskManager) OnResubscribe(ctx context.Context, params protocol.TaskIDParams) (<-chan protocol.TaskEvent, error) { - m.TasksMutex.RLock() - task, exists := m.Tasks[params.ID] - m.TasksMutex.RUnlock() - if !exists { - return nil, ErrTaskNotFound(params.ID) +// addSubscriber adds a subscriber +func (m *MemoryTaskManager) addSubscriber(taskID string, sub *MemoryTaskSubscriber) { + m.taskMu.Lock() + defer m.taskMu.Unlock() + + if _, exists := m.Subscribers[taskID]; !exists { + m.Subscribers[taskID] = make([]*MemoryTaskSubscriber, 0) } - // Create a channel for events. - eventChan := make(chan protocol.TaskEvent) - // For tasks in final state, just send a status update event and close. - if isFinalState(task.Status.State) { - go func() { - // Send a task status update event. - event := protocol.TaskStatusUpdateEvent{ - ID: task.ID, - Status: task.Status, - Final: true, - } - select { - case eventChan <- event: - // Successfully sent final status. - log.Debugf("Sent final status to resubscribed client for task %s: %s", task.ID, task.Status.State) - case <-ctx.Done(): - // Context was canceled. - log.Debugf("Context done before sending status for task %s", task.ID) - } - // Close the channel to signal no more events. - close(eventChan) - }() - return eventChan, nil - } - // For tasks still in progress, add this as a subscriber. - m.addSubscriber(params.ID, eventChan) - // Ensure we remove the subscriber when the context is canceled. - go func() { - <-ctx.Done() - m.removeSubscriber(params.ID, eventChan) - // Don't close the channel here - that should happen in the task processing goroutine. - }() - // Send the current status as the first event. - go func() { - event := protocol.TaskStatusUpdateEvent{ - ID: task.ID, - Status: task.Status, - Final: isFinalState(task.Status.State), - } - select { - case eventChan <- event: - // Successfully sent initial status. - log.Debugf("Sent initial status to resubscribed client for task %s: %s", task.ID, task.Status.State) - case <-ctx.Done(): - // Context was canceled. - log.Debugf("Context done before sending initial status for task %s", task.ID) - m.removeSubscriber(params.ID, eventChan) - close(eventChan) - } - }() - return eventChan, nil + m.Subscribers[taskID] = append(m.Subscribers[taskID], sub) } -// processError checks the error type and returns the appropriate task manager error. -// If the error already has the right format, it returns it directly. -func processError(err error) error { - if err == nil { - return nil +// cleanSubscribers cleans up subscribers +func (m *MemoryTaskManager) cleanSubscribers(taskID string) { + m.taskMu.Lock() + defer m.taskMu.Unlock() + for _, sub := range m.Subscribers[taskID] { + sub.Close() } - // Check if it's a task not found error message - if strings.Contains(strings.ToLower(err.Error()), "not found") { - return ErrTaskNotFound(err.Error()) + delete(m.Subscribers, taskID) +} + +// CleanExpiredConversations cleans up expired conversation history +// maxAge: the maximum lifetime of the conversation, conversations not accessed beyond this time will be cleaned up +func (m *MemoryTaskManager) CleanExpiredConversations(maxAge time.Duration) int { + m.mu.Lock() + defer m.mu.Unlock() + + now := time.Now() + expiredContexts := make([]string, 0) + expiredMessageIDs := make([]string, 0) + + // Find expired conversations + for contextID, conversation := range m.Conversations { + if now.Sub(conversation.LastAccessTime) > maxAge { + expiredContexts = append(expiredContexts, contextID) + expiredMessageIDs = append(expiredMessageIDs, conversation.MessageIDs...) + } + } + + // Delete expired conversations + for _, contextID := range expiredContexts { + delete(m.Conversations, contextID) } - // For other errors, return a generic internal error - return err + + // Delete messages from expired conversations + for _, messageID := range expiredMessageIDs { + delete(m.Messages, messageID) + } + + if len(expiredContexts) > 0 { + log.Debugf("Cleaned %d expired conversations, removed %d messages", + len(expiredContexts), len(expiredMessageIDs)) + } + + return len(expiredContexts) } -// getTaskWithValidation gets a task and validates it exists. -// Returns task and nil if found, nil and error if not found. -func (m *MemoryTaskManager) getTaskWithValidation(taskID string) (*protocol.Task, error) { - m.TasksMutex.RLock() - task, exists := m.Tasks[taskID] - m.TasksMutex.RUnlock() +// GetConversationStats gets conversation statistics +func (m *MemoryTaskManager) GetConversationStats() map[string]interface{} { + m.mu.RLock() + defer m.mu.RUnlock() - if !exists { - return nil, ErrTaskNotFound(taskID) + totalConversations := len(m.Conversations) + totalMessages := len(m.Messages) + + oldestAccess := time.Now() + newestAccess := time.Time{} + + for _, conversation := range m.Conversations { + if conversation.LastAccessTime.Before(oldestAccess) { + oldestAccess = conversation.LastAccessTime + } + if conversation.LastAccessTime.After(newestAccess) { + newestAccess = conversation.LastAccessTime + } } - taskCopy := *task // Return a copy. - return &taskCopy, nil + + stats := map[string]interface{}{ + "total_conversations": totalConversations, + "total_messages": totalMessages, + } + + if totalConversations > 0 { + stats["oldest_access"] = oldestAccess + stats["newest_access"] = newestAccess + } + + return stats } diff --git a/taskmanager/memory_handler.go b/taskmanager/memory_handler.go new file mode 100644 index 0000000..5bb7272 --- /dev/null +++ b/taskmanager/memory_handler.go @@ -0,0 +1,253 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + +package taskmanager + +import ( + "context" + "fmt" + "time" + + "trpc.group/trpc-go/trpc-a2a-go/log" + "trpc.group/trpc-go/trpc-a2a-go/protocol" +) + +// ============================================================================= +// MessageHandle Implementation +// ============================================================================= + +// memoryTaskHandler implements TaskHandler interface +type memoryTaskHandler struct { + manager *MemoryTaskManager + messageID string + ctx context.Context +} + +var _ TaskHandler = (*memoryTaskHandler)(nil) + +// UpdateTaskState updates task state +func (h *memoryTaskHandler) UpdateTaskState( + taskID *string, + state protocol.TaskState, + message *protocol.Message, +) error { + if taskID == nil || *taskID == "" { + return fmt.Errorf("taskID cannot be nil or empty") + } + + h.manager.taskMu.Lock() + task, exists := h.manager.Tasks[*taskID] + if !exists { + h.manager.taskMu.Unlock() + log.Warnf("UpdateTaskState called for non-existent task %s", *taskID) + return fmt.Errorf("task not found: %s", *taskID) + } + + originalTask := task.Task() + originalTask.Status = protocol.TaskStatus{ + State: state, + Message: message, + Timestamp: time.Now().UTC().Format(time.RFC3339), + } + h.manager.taskMu.Unlock() + + log.Debugf("Updated task %s state to %s", *taskID, state) + + // notify subscribers + finalState := isFinalState(state) + event := &protocol.TaskStatusUpdateEvent{ + TaskID: *taskID, + ContextID: originalTask.ContextID, + Status: originalTask.Status, + Kind: protocol.KindTaskStatusUpdate, + Final: &finalState, + } + streamEvent := protocol.StreamingMessageEvent{Result: event} + h.manager.notifySubscribers(*taskID, streamEvent) + return nil +} + +// SubScribeTask subscribes to the task +func (h *memoryTaskHandler) SubScribeTask(taskID *string) (TaskSubscriber, error) { + if taskID == nil || *taskID == "" { + return nil, fmt.Errorf("taskID cannot be nil or empty") + } + if !h.manager.checkTaskExists(*taskID) { + return nil, fmt.Errorf("task not found: %s", *taskID) + } + subscriber := NewMemoryTaskSubscriber(*taskID, defaultTaskSubscriberBufferSize) + h.manager.addSubscriber(*taskID, subscriber) + return subscriber, nil +} + +// AddArtifact adds artifact to specified task +func (h *memoryTaskHandler) AddArtifact( + taskID *string, + artifact protocol.Artifact, + isFinal bool, + needMoreData bool, +) error { + if taskID == nil || *taskID == "" { + return fmt.Errorf("taskID cannot be nil or empty") + } + + h.manager.taskMu.Lock() + task, exists := h.manager.Tasks[*taskID] + if !exists { + h.manager.taskMu.Unlock() + return fmt.Errorf("task not found: %s", *taskID) + } + task.Task().Artifacts = append(task.Task().Artifacts, artifact) + h.manager.taskMu.Unlock() + + log.Debugf("Added artifact %s to task %s", artifact.ArtifactID, *taskID) + + // notify subscribers + event := &protocol.TaskArtifactUpdateEvent{ + TaskID: *taskID, + ContextID: task.Task().ContextID, + Artifact: artifact, + Kind: protocol.KindTaskArtifactUpdate, + LastChunk: &isFinal, + Append: &needMoreData, + } + streamEvent := protocol.StreamingMessageEvent{Result: event} + h.manager.notifySubscribers(*taskID, streamEvent) + + return nil +} + +// GetTask gets task +func (h *memoryTaskHandler) GetTask(taskID *string) (CancellableTask, error) { + if taskID == nil || *taskID == "" { + return nil, fmt.Errorf("taskID cannot be nil or empty") + } + + h.manager.taskMu.RLock() + defer h.manager.taskMu.RUnlock() + + task, err := h.manager.getTask(*taskID) + if err != nil { + return nil, err + } + + // return task copy to avoid external modification + taskCopy := *task.Task() + if taskCopy.Artifacts != nil { + taskCopy.Artifacts = make([]protocol.Artifact, len(task.Task().Artifacts)) + copy(taskCopy.Artifacts, task.Task().Artifacts) + } + if taskCopy.History != nil { + taskCopy.History = make([]protocol.Message, len(task.Task().History)) + copy(taskCopy.History, task.Task().History) + } + + return &MemoryCancellableTask{ + task: taskCopy, + cancelFunc: task.cancelFunc, + ctx: task.ctx, + }, nil +} + +// GetContextID gets context ID +func (h *memoryTaskHandler) GetContextID() string { + h.manager.conversationMu.RLock() + defer h.manager.conversationMu.RUnlock() + + if msg, exists := h.manager.Messages[h.messageID]; exists && msg.ContextID != nil { + return *msg.ContextID + } + return "" +} + +// GetMessageHistory gets message history +func (h *memoryTaskHandler) GetMessageHistory() []protocol.Message { + h.manager.conversationMu.RLock() + defer h.manager.conversationMu.RUnlock() + + if msg, exists := h.manager.Messages[h.messageID]; exists && msg.ContextID != nil { + return h.manager.getMessageHistory(*msg.ContextID) + } + return []protocol.Message{} +} + +// BuildTask creates a new task and returns task object +func (h *memoryTaskHandler) BuildTask(specificTaskID *string, contextID *string) (string, error) { + h.manager.taskMu.Lock() + defer h.manager.taskMu.Unlock() + + // if no taskID provided, generate one + var actualTaskID string + if specificTaskID == nil || *specificTaskID == "" { + actualTaskID = protocol.GenerateTaskID() + } else { + actualTaskID = *specificTaskID + } + + // Check if task already exists to avoid duplicate WithCancel calls + if _, exists := h.manager.Tasks[actualTaskID]; exists { + log.Warnf("Task %s already exists, returning existing task", actualTaskID) + return "", fmt.Errorf("task already exists: %s", actualTaskID) + } + + var actualContextID string + if contextID == nil || *contextID == "" { + actualContextID = "" + } else { + actualContextID = *contextID + } + + // create new task + task := protocol.Task{ + ID: actualTaskID, + ContextID: actualContextID, + Kind: protocol.KindTask, + Status: protocol.TaskStatus{ + State: protocol.TaskStateSubmitted, + Timestamp: time.Now().UTC().Format(time.RFC3339), + }, + Artifacts: make([]protocol.Artifact, 0), + History: make([]protocol.Message, 0), + Metadata: make(map[string]interface{}), + } + + cancellableTask := NewCancellableTask(task) + + // store task + h.manager.Tasks[actualTaskID] = cancellableTask + + log.Debugf("Created new task %s with context %s", actualTaskID, actualContextID) + + return actualTaskID, nil +} + +// CancelTask cancels the task. +func (h *memoryTaskHandler) CleanTask(taskID *string) error { + if taskID == nil || *taskID == "" { + return fmt.Errorf("taskID cannot be nil or empty") + } + + h.manager.taskMu.Lock() + task, exists := h.manager.Tasks[*taskID] + if !exists { + h.manager.taskMu.Unlock() + return fmt.Errorf("task not found: %s", *taskID) + } + + // Cancel the task and remove from Tasks map while holding the lock + task.Cancel() + delete(h.manager.Tasks, *taskID) + + // Clean up subscribers while holding the lock to avoid another lock acquisition + for _, sub := range h.manager.Subscribers[*taskID] { + sub.Close() + } + delete(h.manager.Subscribers, *taskID) + + h.manager.taskMu.Unlock() + + return nil +} diff --git a/taskmanager/memory_handler_test.go b/taskmanager/memory_handler_test.go new file mode 100644 index 0000000..8a8fd79 --- /dev/null +++ b/taskmanager/memory_handler_test.go @@ -0,0 +1,357 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + +package taskmanager + +import ( + "context" + "testing" + + "trpc.group/trpc-go/trpc-a2a-go/protocol" +) + +// setupTestHandler creates a test handler for use in tests +func setupTestHandler(t *testing.T) (*memoryTaskHandler, *MemoryTaskManager) { + processor := &MockMessageProcessor{} + manager, err := NewMemoryTaskManager(processor) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + ctx := context.Background() + message := protocol.Message{ + Role: protocol.MessageRoleUser, + Parts: []protocol.Part{ + protocol.NewTextPart("Test message"), + }, + } + + manager.storeMessage(message) + + handler := &memoryTaskHandler{ + manager: manager, + messageID: message.MessageID, + ctx: ctx, + } + + return handler, manager +} + +func TestMemoryTaskHandler_BuildTask(t *testing.T) { + handler, _ := setupTestHandler(t) + + // Test building a task + taskID, err := handler.BuildTask(nil, nil) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if taskID == "" { + t.Error("Expected task ID to be set") + } + + // Test building task with custom ID + customID := "custom-task-id" + customTaskID, err := handler.BuildTask(&customID, nil) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if customTaskID != customID { + t.Errorf("Expected task ID %s, got %s", customID, customTaskID) + } +} + +func TestMemoryTaskHandler_UpdateTaskState(t *testing.T) { + handler, _ := setupTestHandler(t) + + // First create a task + taskID, err := handler.BuildTask(nil, nil) + if err != nil { + t.Fatalf("Failed to create task: %v", err) + } + + // Update task state + err = handler.UpdateTaskState(&taskID, protocol.TaskStateWorking, nil) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Get the task to verify state was updated + updatedTask, err := handler.GetTask(&taskID) + if err != nil { + t.Errorf("Failed to get updated task: %v", err) + } + + if updatedTask.Task().Status.State != protocol.TaskStateWorking { + t.Errorf("Expected state %s, got %s", protocol.TaskStateWorking, updatedTask.Task().Status.State) + } +} + +func TestMemoryTaskHandler_AddArtifact(t *testing.T) { + handler, _ := setupTestHandler(t) + + // First create a task + taskID, err := handler.BuildTask(nil, nil) + if err != nil { + t.Fatalf("Failed to create task: %v", err) + } + + // Add artifact + artifact := protocol.Artifact{ + ArtifactID: "test-artifact", + Parts: []protocol.Part{ + protocol.NewTextPart("Artifact content"), + }, + } + + err = handler.AddArtifact(&taskID, artifact, false, true) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Verify artifact was added + retrievedTask, err := handler.GetTask(&taskID) + if err != nil { + t.Errorf("Failed to get task: %v", err) + } + + if len(retrievedTask.Task().Artifacts) == 0 { + t.Error("Expected artifact to be added") + } + + if retrievedTask.Task().Artifacts[0].ArtifactID != artifact.ArtifactID { + t.Errorf("Expected artifact ID %s, got %s", artifact.ArtifactID, retrievedTask.Task().Artifacts[0].ArtifactID) + } +} + +func TestMemoryTaskHandler_SubScribeTask(t *testing.T) { + handler, _ := setupTestHandler(t) + + // First create a task + taskID, err := handler.BuildTask(nil, nil) + if err != nil { + t.Fatalf("Failed to create task: %v", err) + } + + // Subscribe to task + subscriber, err := handler.SubScribeTask(&taskID) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if subscriber == nil { + t.Error("Expected subscriber but got nil") + } + + // Clean up + subscriber.Close() +} + +func TestMemoryTaskHandler_GetTask(t *testing.T) { + handler, _ := setupTestHandler(t) + + // First create a task + taskID, err := handler.BuildTask(nil, nil) + if err != nil { + t.Fatalf("Failed to create task: %v", err) + } + + // Get the task + retrievedTask, err := handler.GetTask(&taskID) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if retrievedTask == nil { + t.Error("Expected task but got nil") + } + + if retrievedTask.Task().ID != taskID { + t.Errorf("Expected task ID %s, got %s", taskID, retrievedTask.Task().ID) + } +} + +func TestMemoryTaskHandler_CleanTask(t *testing.T) { + handler, _ := setupTestHandler(t) + + // First create a task + taskID, err := handler.BuildTask(nil, nil) + if err != nil { + t.Fatalf("Failed to create task: %v", err) + } + + // Clean the task + err = handler.CleanTask(&taskID) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Verify task was cleaned (should be deleted, not just canceled) + _, err = handler.GetTask(&taskID) + if err == nil { + t.Error("Expected error when getting cleaned task, but got none") + } +} + +func TestMemoryTaskHandler_GetMessageHistory(t *testing.T) { + _, manager := setupTestHandler(t) + + // Create a message with context + contextID := "test-context" + contextMessage := protocol.Message{ + Role: protocol.MessageRoleUser, + ContextID: &contextID, + Parts: []protocol.Part{ + protocol.NewTextPart("Context message"), + }, + } + + manager.storeMessage(contextMessage) + + // Create handler with context message + contextHandler := &memoryTaskHandler{ + manager: manager, + messageID: contextMessage.MessageID, + ctx: context.Background(), + } + + // Get message history + history := contextHandler.GetMessageHistory() + + if len(history) == 0 { + t.Error("Expected message history but got empty") + } + + // Should contain the context message + found := false + for _, msg := range history { + if msg.MessageID == contextMessage.MessageID { + found = true + break + } + } + + if !found { + t.Error("Expected to find context message in history") + } +} + +func TestMemoryTaskHandler_GetContextID(t *testing.T) { + _, manager := setupTestHandler(t) + + // Create a message with context + contextID := "test-context-id" + contextMessage := protocol.Message{ + Role: protocol.MessageRoleUser, + ContextID: &contextID, + Parts: []protocol.Part{ + protocol.NewTextPart("Context message"), + }, + } + + manager.storeMessage(contextMessage) + + // Create handler with context message + contextHandler := &memoryTaskHandler{ + manager: manager, + messageID: contextMessage.MessageID, + ctx: context.Background(), + } + + // Get context ID + retrievedContextID := contextHandler.GetContextID() + + if retrievedContextID != contextID { + t.Errorf("Expected context ID %s, got %s", contextID, retrievedContextID) + } +} + +func TestTaskHandlerErrors(t *testing.T) { + processor := &MockMessageProcessor{} + manager, err := NewMemoryTaskManager(processor) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + ctx := context.Background() + handler := &memoryTaskHandler{ + manager: manager, + messageID: "non-existent-message", + ctx: ctx, + } + + t.Run("UpdateTaskState_NonExistentTask", func(t *testing.T) { + nonExistentTaskID := "non-existent-task" + err := handler.UpdateTaskState(&nonExistentTaskID, protocol.TaskStateWorking, nil) + if err == nil { + t.Error("Expected error for non-existent task") + } + }) + + t.Run("AddArtifact_NonExistentTask", func(t *testing.T) { + nonExistentTaskID := "non-existent-task" + artifact := protocol.Artifact{ + ArtifactID: "test-artifact", + } + + err := handler.AddArtifact(&nonExistentTaskID, artifact, false, true) + if err == nil { + t.Error("Expected error for non-existent task") + } + }) + + t.Run("SubScribeTask_NonExistentTask", func(t *testing.T) { + nonExistentTaskID := "non-existent-task" + _, err := handler.SubScribeTask(&nonExistentTaskID) + if err == nil { + t.Error("Expected error for non-existent task") + } + }) + + t.Run("GetTask_NonExistentTask", func(t *testing.T) { + nonExistentTaskID := "non-existent-task" + _, err := handler.GetTask(&nonExistentTaskID) + if err == nil { + t.Error("Expected error for non-existent task") + } + }) + + t.Run("CleanTask_NonExistentTask", func(t *testing.T) { + nonExistentTaskID := "non-existent-task" + err := handler.CleanTask(&nonExistentTaskID) + if err == nil { + t.Error("Expected error for non-existent task") + } + }) + + t.Run("NilTaskID", func(t *testing.T) { + err := handler.UpdateTaskState(nil, protocol.TaskStateWorking, nil) + if err == nil { + t.Error("Expected error for nil task ID") + } + + err = handler.AddArtifact(nil, protocol.Artifact{}, false, true) + if err == nil { + t.Error("Expected error for nil task ID") + } + + _, err = handler.SubScribeTask(nil) + if err == nil { + t.Error("Expected error for nil task ID") + } + + _, err = handler.GetTask(nil) + if err == nil { + t.Error("Expected error for nil task ID") + } + + err = handler.CleanTask(nil) + if err == nil { + t.Error("Expected error for nil task ID") + } + }) +} diff --git a/taskmanager/memory_options.go b/taskmanager/memory_options.go new file mode 100644 index 0000000..7541a51 --- /dev/null +++ b/taskmanager/memory_options.go @@ -0,0 +1,62 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + +// Package taskmanager provides configuration options for MemoryTaskManager. +package taskmanager + +import ( + "time" +) + +// MemoryTaskManagerOptions contains configuration options for MemoryTaskManager. +type MemoryTaskManagerOptions struct { + // MaxHistoryLength is the maximum number of messages to keep in conversation history. + MaxHistoryLength int + + // ConversationTTL is the maximum lifetime of conversations. + ConversationTTL time.Duration + + // CleanupInterval is the interval for cleanup checks. + CleanupInterval time.Duration + + // EnableCleanup enables automatic cleanup of expired conversations. + EnableCleanup bool +} + +// DefaultMemoryTaskManagerOptions returns the default configuration options. +func DefaultMemoryTaskManagerOptions() *MemoryTaskManagerOptions { + return &MemoryTaskManagerOptions{ + MaxHistoryLength: defaultMaxHistoryLength, + ConversationTTL: defaultConversationTTL, + CleanupInterval: defaultCleanupInterval, + EnableCleanup: true, + } +} + +// MemoryTaskManagerOption defines a function type for configuring MemoryTaskManager. +type MemoryTaskManagerOption func(*MemoryTaskManagerOptions) + +// WithMaxHistoryLength sets the maximum number of messages to keep in conversation history. +func WithMaxHistoryLength(length int) MemoryTaskManagerOption { + return func(opts *MemoryTaskManagerOptions) { + if length > 0 { + opts.MaxHistoryLength = length + } + } +} + +// WithConversationTTL sets the conversation TTL, enabling automatic cleanup. +// ttl: the maximum lifetime of the conversation +// cleanupInterval: the interval time for cleanup check +func WithConversationTTL(ttl, cleanupInterval time.Duration) MemoryTaskManagerOption { + return func(opts *MemoryTaskManagerOptions) { + if ttl > 0 && cleanupInterval > 0 { + opts.ConversationTTL = ttl + opts.CleanupInterval = cleanupInterval + opts.EnableCleanup = true + } + } +} diff --git a/taskmanager/memory_test.go b/taskmanager/memory_test.go index e648175..5f72543 100644 --- a/taskmanager/memory_test.go +++ b/taskmanager/memory_test.go @@ -8,1193 +8,697 @@ package taskmanager import ( "context" - "fmt" - "sync" "testing" "time" - // "trpc.group/trpc-go/trpc-a2a-go/internal/jsonrpc" // Removed unused import - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "trpc.group/trpc-go/trpc-a2a-go/internal/jsonrpc" "trpc.group/trpc-go/trpc-a2a-go/protocol" ) -// mockProcessor is a simple TaskProcessor for testing. -type mockProcessor struct { - processFunc func(ctx context.Context, taskID string, msg protocol.Message, handle TaskHandle) error - mu sync.Mutex - callCount int - lastTaskID string - lastMessage protocol.Message +// MockMessageProcessor implements MessageProcessor for testing +type MockMessageProcessor struct { + ProcessMessageFunc func(ctx context.Context, message protocol.Message, options ProcessOptions, handle TaskHandler) (*MessageProcessingResult, error) } -// Process implements TaskProcessor. -func (p *mockProcessor) Process(ctx context.Context, taskID string, msg protocol.Message, handle TaskHandle) error { - p.mu.Lock() - p.callCount++ - p.lastTaskID = taskID - p.lastMessage = msg - customFunc := p.processFunc - p.mu.Unlock() - - if customFunc != nil { - return customFunc(ctx, taskID, msg, handle) +func (m *MockMessageProcessor) ProcessMessage(ctx context.Context, message protocol.Message, options ProcessOptions, handle TaskHandler) (*MessageProcessingResult, error) { + if m.ProcessMessageFunc != nil { + return m.ProcessMessageFunc(ctx, message, options, handle) } - // Default behavior: complete successfully after a short delay - // to allow subscribers to attach. - time.Sleep(10 * time.Millisecond) - - // Check if context is done before attempting to complete the task - if ctx.Err() != nil { - return ctx.Err() - } - - // Set the status to Working first to test transition - err := handle.UpdateStatus(protocol.TaskStateWorking, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart("Mock Working...")}, - }) - if err != nil { - return err - } - - // Small delay between updates - time.Sleep(5 * time.Millisecond) - - // Check context again before final update - if ctx.Err() != nil { - return ctx.Err() + // Default implementation: echo the message + response := &protocol.Message{ + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{ + protocol.NewTextPart("Echo: " + getTextFromMessage(message)), + }, } - // Mark task as completed and send final message - err = handle.UpdateStatus(protocol.TaskStateCompleted, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart("Mock Success")}, - }) - - return err + return &MessageProcessingResult{ + Result: response, + }, nil } -func TestNewMemoryTaskManager(t *testing.T) { - processor := &mockProcessor{} - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) - assert.NotNil(t, tm) - assert.NotNil(t, tm.Tasks) - assert.NotNil(t, tm.Messages) - assert.NotNil(t, tm.Subscribers) - assert.Equal(t, processor, tm.Processor) - - // Test error case - tm, err = NewMemoryTaskManager(nil) - assert.Error(t, err) - assert.Nil(t, tm) -} - -// assertTextPart is a helper function that asserts that a Part is a TextPart -// and contains the expected text. Returns the TextPart for further assertions. -func assertTextPart(t *testing.T, part protocol.Part, expectedText string) protocol.TextPart { - t.Helper() - textPart, ok := part.(protocol.TextPart) - require.True(t, ok, "Expected part to be TextPart") - if expectedText != "" { - assert.Contains(t, textPart.Text, expectedText, "TextPart should contain expected text") +// Helper function to extract text from message +func getTextFromMessage(message protocol.Message) string { + for _, part := range message.Parts { + if textPart, ok := part.(*protocol.TextPart); ok { + return textPart.Text + } } - return textPart + return "" } -// assertTaskStatus is a helper function that asserts a task has the expected state. -func assertTaskStatus(t *testing.T, task *protocol.Task, expectedID string, expectedState protocol.TaskState) { - t.Helper() - require.NotNil(t, task, "Task should not be nil") - assert.Equal(t, expectedID, task.ID, "Task ID should match") - assert.Equal(t, expectedState, task.Status.State, "Task state should match expected") -} - -// createTestTask creates a standard test task with the given ID and message text. -func createTestTask(id, messageText string) protocol.SendTaskParams { - return protocol.SendTaskParams{ - ID: id, - Message: protocol.Message{ - Role: protocol.MessageRoleUser, - Parts: []protocol.Part{protocol.NewTextPart(messageText)}, +func TestNewMemoryTaskManager(t *testing.T) { + tests := []struct { + name string + processor MessageProcessor + options []MemoryTaskManagerOption + wantErr bool + }{ + { + name: "valid processor", + processor: &MockMessageProcessor{}, + wantErr: false, + }, + { + name: "nil processor", + processor: nil, + wantErr: true, }, } -} -// Helper function to collect task events from a channel until completion or timeout -func collectTaskEvents(t *testing.T, eventChan <-chan protocol.TaskEvent, targetState protocol.TaskState, timeoutDuration time.Duration) []protocol.TaskEvent { - // Collect events with timeout - events := []protocol.TaskEvent{} - timeout := time.After(timeoutDuration) // Safety timeout - done := false - for !done { - select { - case event, ok := <-eventChan: - if !ok { - done = true // Channel closed - break - } - events = append(events, event) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager, err := NewMemoryTaskManager(tt.processor, tt.options...) - // If we receive a final event matching our target state, we can exit early - if statusEvent, ok := event.(protocol.TaskStatusUpdateEvent); ok && - statusEvent.Final && statusEvent.Status.State == targetState { - // We got what we need, break out - t.Logf("Received final %s event, breaking early", targetState) - done = true + if tt.wantErr { + if err == nil { + t.Error("Expected error but got none") + } + return } - case <-timeout: - t.Logf("Test timed out waiting for events, proceeding with test using %d collected events", len(events)) - done = true - } - } - return events -} -// TestMemTaskManager_OnSendTask_Sync tests the synchronous OnSendTask method. -func TestMemTaskManager_OnSendTask_Sync(t *testing.T) { - processor := &mockProcessor{} - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) - - taskID := "test-sync-1" - params := createTestTask(taskID, "Sync Task") - - task, err := tm.OnSendTask(context.Background(), params) - require.NoError(t, err) - assertTaskStatus(t, task, taskID, protocol.TaskStateCompleted) // Default mock behavior - - // Verify message content - require.NotNil(t, task.Status.Message) - require.NotEmpty(t, task.Status.Message.Parts) - assertTextPart(t, task.Status.Message.Parts[0], "Mock Success") - - processor.mu.Lock() - assert.Equal(t, 1, processor.callCount) - assert.Equal(t, taskID, processor.lastTaskID) - processor.mu.Unlock() - - // Check stored message - there should be at least 3 messages now: - // 1. User message - // 2. Working status from mock processor - // 3. Completed status from mock processor - tm.MessagesMutex.RLock() - history, ok := tm.Messages[taskID] - tm.MessagesMutex.RUnlock() - require.True(t, ok) - require.GreaterOrEqual(t, len(history), 3, "Should have at least 3 messages in history") - - // Check first message is from user - assert.Equal(t, protocol.MessageRoleUser, history[0].Role) - textPartHistory := assertTextPart(t, history[0].Parts[0], "Sync Task") - assert.Equal(t, "Sync Task", textPartHistory.Text) - - // Check last message has completion status - lastMsg := history[len(history)-1] - assert.Equal(t, protocol.MessageRoleAgent, lastMsg.Role) - assertTextPart(t, lastMsg.Parts[0], "Mock Success") - - // Test processor error case - processor.processFunc = func(ctx context.Context, taskID string, msg protocol.Message, handle TaskHandle) error { - return fmt.Errorf("processor error") - } - taskID = "test-sync-err" - params.ID = taskID - task, err = tm.OnSendTask(context.Background(), params) - require.Error(t, err) - assertTaskStatus(t, task, taskID, protocol.TaskStateFailed) - - // Verify error message - require.NotNil(t, task.Status.Message) - require.NotEmpty(t, task.Status.Message.Parts) - assertTextPart(t, task.Status.Message.Parts[0], "processor error") -} - -func TestOnSendTaskSubAsync(t *testing.T) { - // Create processor with custom logic for this test - processor := &mockProcessor{ - processFunc: func(ctx context.Context, taskID string, msg protocol.Message, handle TaskHandle) error { - // Explicitly set working state - err := handle.UpdateStatus(protocol.TaskStateWorking, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart("Mock working...")}, - }) if err != nil { - return err + t.Errorf("Unexpected error: %v", err) + return } - // Short delay - time.Sleep(5 * time.Millisecond) + if manager == nil { + t.Error("Expected manager but got nil") + return + } - // Check context - if ctx.Err() != nil { - return ctx.Err() + if manager.Processor != tt.processor { + t.Error("Processor not set correctly") } - // Set completed state - return handle.UpdateStatus(protocol.TaskStateCompleted, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart("Mock Success")}, - }) - }, + if len(tt.options) > 0 && manager.maxHistoryLength != 50 { + t.Errorf("Expected MaxHistoryLength=50, got %d", manager.maxHistoryLength) + } + }) } +} - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) - - taskID := "test-async-1" - params := createTestTask(taskID, "Async Task") - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - eventChan, err := tm.OnSendTaskSubscribe(ctx, params) - require.NoError(t, err) - require.NotNil(t, eventChan) - - // Use helper function to collect events - events := collectTaskEvents(t, eventChan, protocol.TaskStateCompleted, 3*time.Second) - - require.NotEmpty(t, events, "Should have received at least one event") - - // Find the Working event - var foundWorking bool - for _, event := range events { - if statusEvent, ok := event.(protocol.TaskStatusUpdateEvent); ok && - statusEvent.Status.State == protocol.TaskStateWorking { - foundWorking = true - break - } +func TestMemoryTaskManager_OnSendMessage(t *testing.T) { + processor := &MockMessageProcessor{} + manager, err := NewMemoryTaskManager(processor) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + ctx := context.Background() + + tests := []struct { + name string + request protocol.SendMessageParams + wantErr bool + }{ + { + name: "valid message", + request: protocol.SendMessageParams{ + Message: protocol.Message{ + Role: protocol.MessageRoleUser, + Parts: []protocol.Part{ + protocol.NewTextPart("Hello"), + }, + }, + }, + wantErr: false, + }, + { + name: "message with context", + request: protocol.SendMessageParams{ + Message: protocol.Message{ + Role: protocol.MessageRoleUser, + ContextID: stringPtr("test-context"), + Parts: []protocol.Part{ + protocol.NewTextPart("Hello with context"), + }, + }, + }, + wantErr: false, + }, } - assert.True(t, foundWorking, "Should have received a Working state event") - // Find the Completed event (final event) - var foundCompleted bool - for _, event := range events { - if statusEvent, ok := event.(protocol.TaskStatusUpdateEvent); ok && - statusEvent.Status.State == protocol.TaskStateCompleted && statusEvent.Final { - foundCompleted = true - - // Validate the completed event message - require.NotNil(t, statusEvent.Status.Message) - require.NotEmpty(t, statusEvent.Status.Message.Parts) - assertTextPart(t, statusEvent.Status.Message.Parts[0], "Mock Success") - break - } - } - assert.True(t, foundCompleted, "Should have received a Completed state event") - - // Double check task state via OnGetTask - task, err := tm.OnGetTask(context.Background(), protocol.TaskQueryParams{ID: taskID}) - require.NoError(t, err) - assertTaskStatus(t, task, taskID, protocol.TaskStateCompleted) - - // Check processor was called - processor.mu.Lock() - assert.Equal(t, 1, processor.callCount) - assert.Equal(t, taskID, processor.lastTaskID) - processor.mu.Unlock() - - // Check stored message - tm.MessagesMutex.RLock() - history, ok := tm.Messages[taskID] - tm.MessagesMutex.RUnlock() - require.True(t, ok) - assert.GreaterOrEqual(t, len(history), 2, "Should have at least 2 messages in history") - - // Check first message (from user) - require.NotEmpty(t, history[0].Parts) - textPartHistAsync := assertTextPart(t, history[0].Parts[0], "Async Task") - assert.Equal(t, "Async Task", textPartHistAsync.Text) -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := manager.OnSendMessage(ctx, tt.request) -func TestMemTaskMgr_OnSendTaskSub_Error(t *testing.T) { - errMsg := "async processor error" - processor := &mockProcessor{ - processFunc: func(ctx context.Context, taskID string, msg protocol.Message, handle TaskHandle) error { - // Simulate some work before failing - err := handle.UpdateStatus(protocol.TaskStateWorking, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart("Working...")}, - }) - if err != nil { - return err + if tt.wantErr { + if err == nil { + t.Error("Expected error but got none") + } + return } - time.Sleep(5 * time.Millisecond) - // Check context - if ctx.Err() != nil { - return ctx.Err() + if err != nil { + t.Errorf("Unexpected error: %v", err) + return } - // Return error to simulate failure - return fmt.Errorf(errMsg) - }, - } - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) - - taskID := "test-async-err-1" - params := createTestTask(taskID, "Async Fail Task") - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - eventChan, err := tm.OnSendTaskSubscribe(ctx, params) - require.NoError(t, err) - require.NotNil(t, eventChan) + if result == nil { + t.Error("Expected result but got nil") + return + } - // Use helper function to collect events - events := collectTaskEvents(t, eventChan, protocol.TaskStateFailed, 3*time.Second) + // Check that message was stored + if result.Result == nil { + t.Error("Expected result but got nil") + return + } - require.NotEmpty(t, events, "Should have received at least one event") + // Check if result is a message + if message, ok := result.Result.(*protocol.Message); ok { + if message.MessageID == "" { + t.Error("Expected message ID to be set") + } - // Find the Working event - var foundWorking bool - for _, event := range events { - if statusEvent, ok := event.(protocol.TaskStatusUpdateEvent); ok && - statusEvent.Status.State == protocol.TaskStateWorking { - foundWorking = true - break - } - } - assert.True(t, foundWorking, "Should have received a Working state event") + // Check that message is in storage + manager.mu.RLock() + _, exists := manager.Messages[message.MessageID] + manager.mu.RUnlock() - // Find the Failed event (final event) - var foundFailed bool - for _, event := range events { - if statusEvent, ok := event.(protocol.TaskStatusUpdateEvent); ok && - statusEvent.Status.State == protocol.TaskStateFailed && statusEvent.Final { - foundFailed = true - - // Validate the failed event message - require.NotNil(t, statusEvent.Status.Message) - require.NotEmpty(t, statusEvent.Status.Message.Parts) - assertTextPart(t, statusEvent.Status.Message.Parts[0], errMsg) - break - } + if !exists { + t.Error("Message not found in storage") + } + } + }) } - assert.True(t, foundFailed, "Should have received a Failed state event") - - // Double check task state via OnGetTask - task, err := tm.OnGetTask(context.Background(), protocol.TaskQueryParams{ID: taskID}) - require.NoError(t, err) - assertTaskStatus(t, task, taskID, protocol.TaskStateFailed) } -func TestMemoryTaskManager_OnGetTask(t *testing.T) { - processor := &mockProcessor{} - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) - - taskID := "test-get-1" - // Explicitly create TextPart first - helloPart := protocol.NewTextPart("Hello") - _, okDirect := interface{}(helloPart).(protocol.TextPart) // Check value type - require.True(t, okDirect, "helloPart should be assertable to TextPart") - - userMsg := protocol.Message{Role: protocol.MessageRoleUser, Parts: []protocol.Part{helloPart}} - _, okInSliceImmediate := userMsg.Parts[0].(protocol.TextPart) // Check value type - require.True(t, okInSliceImmediate, "Part in userMsg slice should be assertable immediately to TextPart") - - params := protocol.SendTaskParams{ - ID: taskID, - Message: userMsg, - Metadata: map[string]interface{}{"meta1": "value1"}, - } - - // Send a task to create it - _, err = tm.OnSendTask(context.Background(), params) - require.NoError(t, err) - - // Get the task without history - getParams := protocol.TaskQueryParams{ID: taskID} - task, err := tm.OnGetTask(context.Background(), getParams) - require.NoError(t, err) - require.NotNil(t, task) - assert.Equal(t, taskID, task.ID) - assert.Equal(t, protocol.TaskStateCompleted, task.Status.State) // From mock processor - assert.Equal(t, "value1", task.Metadata["meta1"]) - assert.Nil(t, task.History, "History should be nil when not requested") - - // Get the task with history - histLen := 1 - getParams.HistoryLength = &histLen - task, err = tm.OnGetTask(context.Background(), getParams) - require.NoError(t, err) - require.NotNil(t, task) - require.NotNil(t, task.History, "History should not be nil when requested") - require.Len(t, task.History, 1) - // Use reflection to get the actual type and compare the text contents - historyPart := task.History[0].Parts[0] - require.NotNil(t, historyPart, "History part should not be nil") - - // Get the text content regardless of whether it's value or pointer - var historyText string - var historyPartTypeOK bool - - // Try both value and pointer type assertions - if textPart, ok := historyPart.(protocol.TextPart); ok { - historyText = textPart.Text - historyPartTypeOK = true - t.Logf("Found TextPart value type in history") - } else if textPartPtr, ok := historyPart.(*protocol.TextPart); ok { - historyText = textPartPtr.Text - historyPartTypeOK = true - t.Logf("Found *TextPart pointer type in history") - } else { - t.Logf("Expected TextPart or *TextPart but got %T", historyPart) - } - - // Accept either TextPart or *TextPart, the important part is the text content - require.True(t, historyPartTypeOK, "History part was not TextPart or *TextPart") +func TestMemoryTaskManager_OnSendMessageStream(t *testing.T) { + processor := &MockMessageProcessor{ + ProcessMessageFunc: func(ctx context.Context, message protocol.Message, options ProcessOptions, handle TaskHandler) (*MessageProcessingResult, error) { + // Create a task for streaming + taskID, err := handle.BuildTask(nil, message.ContextID) + if err != nil { + return nil, err + } - assert.Equal(t, "Mock Success", historyText) // Compare history text with the expected last message from the agent + subscriber, err := handle.SubScribeTask(&taskID) + if err != nil { + return nil, err + } - // Get non-existent task - getParams.ID = "non-existent-task" - task, err = tm.OnGetTask(context.Background(), getParams) - require.Error(t, err) + // Simulate async processing + go func() { + defer subscriber.Close() - // Check error type by asserting to JSONRPCError and comparing code - if rpcErr, ok := err.(*jsonrpc.Error); ok { - assert.Equal(t, ErrCodeTaskNotFound, rpcErr.Code) - assert.Equal(t, "Task not found", rpcErr.Message) - } else { - t.Errorf("Expected *jsonrpc.JSONRPCError but got %T", err) - } - assert.Nil(t, task) -} + // Send initial status update + handle.UpdateTaskState(&taskID, protocol.TaskStateWorking, nil) -func TestMemoryTaskManager_OnCancelTask(t *testing.T) { - // Setup a processor with a delayed execution to allow cancellation during processing - processor := &mockProcessor{ - processFunc: func(ctx context.Context, taskID string, msg protocol.Message, handle TaskHandle) error { - // Create a channel to track if context cancellation is received - done := make(chan struct{}) - canceled := make(chan struct{}) - - // Start a goroutine that will block until either context is cancelled or timeout - go func() { - select { - case <-ctx.Done(): - close(canceled) - case <-time.After(100 * time.Millisecond): // Reduced timeout to make test faster - // Should not reach here if cancellation works properly + // Complete task + finalMessage := &protocol.Message{ + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{ + protocol.NewTextPart("Streaming completed"), + }, } - close(done) + handle.UpdateTaskState(&taskID, protocol.TaskStateCompleted, finalMessage) }() - // Wait for the goroutine to complete - <-done - - // Check if cancellation was received - select { - case <-canceled: - return ctx.Err() // Return the context error (context.Canceled) - default: - return nil // Successfully completed without cancellation - } + return &MessageProcessingResult{ + StreamingEvents: subscriber, + }, nil }, } - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) - - // Create a task - taskID := "test-cancel-task" - params := protocol.SendTaskParams{ - ID: taskID, - Message: protocol.Message{Role: protocol.MessageRoleUser, Parts: []protocol.Part{protocol.NewTextPart("Test task for cancellation")}}, + manager, err := NewMemoryTaskManager(processor) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) } - // Start task with subscription to monitor events - eventChan, err := tm.OnSendTaskSubscribe(context.Background(), params) - require.NoError(t, err) - - // Give task a moment to start processing - time.Sleep(50 * time.Millisecond) - - // Verify task is in working state before cancellation - task, err := tm.OnGetTask(context.Background(), protocol.TaskQueryParams{ID: taskID}) - require.NoError(t, err) - assert.Equal(t, protocol.TaskStateWorking, task.Status.State) - - // Now cancel the task - cancelParams := protocol.TaskIDParams{ID: taskID} - canceledTask, err := tm.OnCancelTask(context.Background(), cancelParams) - require.NoError(t, err) - require.NotNil(t, canceledTask) + ctx := context.Background() + request := protocol.SendMessageParams{ + Message: protocol.Message{ + Role: protocol.MessageRoleUser, + Parts: []protocol.Part{ + protocol.NewTextPart("Stream test"), + }, + }, + } - // Wait a little bit for the cancellation to fully propagate if needed - if canceledTask.Status.State != protocol.TaskStateCanceled { - // Poll for a short time until the task shows as canceled - deadline := time.Now().Add(100 * time.Millisecond) - for time.Now().Before(deadline) { - canceledTask, err = tm.OnGetTask(context.Background(), protocol.TaskQueryParams{ID: taskID}) - require.NoError(t, err) - if canceledTask.Status.State == protocol.TaskStateCanceled { - break - } - time.Sleep(10 * time.Millisecond) - } + eventChan, err := manager.OnSendMessageStream(ctx, request) + if err != nil { + t.Fatalf("Unexpected error: %v", err) } - assert.Equal(t, protocol.TaskStateCanceled, canceledTask.Status.State) + if eventChan == nil { + t.Fatal("Expected event channel but got nil") + } - // Collect events with timeout - var lastEvent protocol.TaskEvent - eventsCollected := false - timeout := time.After(1 * time.Second) + // Collect events with shorter timeout + var events []protocol.StreamingMessageEvent + timeout := time.After(500 * time.Millisecond) + eventCount := 0 - for !eventsCollected { + for { select { case event, ok := <-eventChan: if !ok { - eventsCollected = true // Channel closed - break + // Channel closed, test completed + goto CheckEvents } - lastEvent = event - // If we see the final canceled event, don't wait for channel close - if statusEvent, ok := event.(protocol.TaskStatusUpdateEvent); ok && - statusEvent.Status.State == protocol.TaskStateCanceled && statusEvent.Final { - eventsCollected = true + events = append(events, event) + eventCount++ + + // Stop after receiving some events to avoid infinite loop + if eventCount >= 10 { + goto CheckEvents } + case <-timeout: - t.Logf("Timeout waiting for event channel to close, proceeding with test") - eventsCollected = true + // Don't fail on timeout, just check what we got + goto CheckEvents } } - // Verify last event indicates cancellation if we got events - if lastEvent != nil { - statusEvent, ok := lastEvent.(protocol.TaskStatusUpdateEvent) - require.True(t, ok, "Expected TaskStatusUpdateEvent") - assert.Equal(t, taskID, statusEvent.ID) - assert.Equal(t, protocol.TaskStateCanceled, statusEvent.Status.State) - assert.True(t, statusEvent.Final) +CheckEvents: + if len(events) == 0 { + t.Error("Expected at least one event") + return } - // Test cancelling a non-existent task - _, err = tm.OnCancelTask(context.Background(), protocol.TaskIDParams{ID: "non-existent-task"}) - assert.Error(t, err) - // Check if the error is a jsonrpc.Error with the TaskNotFound error code - jsonRPCErr, ok := err.(*jsonrpc.Error) - assert.True(t, ok, "Expected jsonrpc.Error") - assert.Equal(t, ErrCodeTaskNotFound, jsonRPCErr.Code) - - // Test cancelling an already cancelled task - againCanceledTask, err := tm.OnCancelTask(context.Background(), cancelParams) - // Updated expectation: Task is already in final state, can't cancel again - // This returns the task but with an error indicating it's already in final state - assert.Error(t, err) - jsonRPCErr, ok = err.(*jsonrpc.Error) - assert.True(t, ok, "Expected jsonrpc.Error") - assert.Equal(t, ErrCodeTaskFinal, jsonRPCErr.Code) - assert.Equal(t, protocol.TaskStateCanceled, againCanceledTask.Status.State) - - // Test cancelling a completed task (should return task without error) - completedTaskID := "completed-task" - completedParams := protocol.SendTaskParams{ - ID: completedTaskID, - Message: protocol.Message{Role: protocol.MessageRoleUser, Parts: []protocol.Part{protocol.NewTextPart("Completed task")}}, + t.Logf("Received %d events", len(events)) + + // Should have received some events + hasStatusUpdate := false + for _, event := range events { + if event.Result != nil { + hasStatusUpdate = true + break + } } - // Use the basic mock processor behavior (completes task quickly) - processor.processFunc = nil - _, err = tm.OnSendTask(context.Background(), completedParams) - require.NoError(t, err) - - // Verify task is in completed state - completedTask, err := tm.OnGetTask(context.Background(), protocol.TaskQueryParams{ID: completedTaskID}) - require.NoError(t, err) - assert.Equal(t, protocol.TaskStateCompleted, completedTask.Status.State) - - // Try to cancel the completed task - againCompletedTask, err := tm.OnCancelTask(context.Background(), protocol.TaskIDParams{ID: completedTaskID}) - // Update expectation: can't cancel a task that's already completed - assert.Error(t, err) - jsonRPCErr, ok = err.(*jsonrpc.Error) - assert.True(t, ok, "Expected jsonrpc.Error for canceling completed task") - assert.Equal(t, ErrCodeTaskFinal, jsonRPCErr.Code) - assert.Equal(t, protocol.TaskStateCompleted, againCompletedTask.Status.State, - "Already completed task should remain in completed state after cancel attempt") + if !hasStatusUpdate { + t.Error("Expected at least one status update event") + } } -// --- Test Helpers --- - -func TestIsFinalState(t *testing.T) { - assert.True(t, isFinalState(protocol.TaskStateCompleted)) - assert.True(t, isFinalState(protocol.TaskStateFailed)) - assert.True(t, isFinalState(protocol.TaskStateCanceled)) - assert.False(t, isFinalState(protocol.TaskStateWorking)) - assert.False(t, isFinalState(protocol.TaskStateSubmitted)) // Check defined non-final state. - assert.False(t, isFinalState(protocol.TaskStateInputRequired)) // Check defined non-final state. - assert.False(t, isFinalState(protocol.TaskState("other"))) -} +func TestMemoryTaskManager_OnGetTask(t *testing.T) { + processor := &MockMessageProcessor{ + ProcessMessageFunc: func(ctx context.Context, message protocol.Message, options ProcessOptions, handle TaskHandler) (*MessageProcessingResult, error) { + // Create a task for testing + taskID, err := handle.BuildTask(nil, message.ContextID) + if err != nil { + return nil, err + } -func TestMemTaskManagerPushNotif(t *testing.T) { - processor := &mockProcessor{} - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) + // Get the actual task object + task, err := handle.GetTask(&taskID) + if err != nil { + return nil, err + } - // Create a task first - taskID := "push-notification-task" - params := protocol.SendTaskParams{ - ID: taskID, - Message: protocol.Message{Role: protocol.MessageRoleUser, Parts: []protocol.Part{protocol.NewTextPart("Test task for push notifications")}}, + return &MessageProcessingResult{ + Result: task.Task(), // Return protocol.Task, not CancellableTask + }, nil + }, + } + manager, err := NewMemoryTaskManager(processor) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) } - // Start the task - _, err = tm.OnSendTask(context.Background(), params) - require.NoError(t, err) + ctx := context.Background() - // Test setting push notification config - pushConfig := protocol.TaskPushNotificationConfig{ - ID: taskID, - PushNotificationConfig: protocol.PushNotificationConfig{ - URL: "https://example.com/webhook", - Token: "test-token", - Authentication: &protocol.AuthenticationInfo{ - Schemes: []string{"Bearer"}, - Credentials: stringPtr("Bearer test-token"), - }, - Metadata: map[string]interface{}{ - "priority": "high", + // First create a task by sending a message + request := protocol.SendMessageParams{ + Message: protocol.Message{ + Role: protocol.MessageRoleUser, + Parts: []protocol.Part{ + protocol.NewTextPart("Test"), }, }, } - // Test OnPushNotificationSet - resultConfig, err := tm.OnPushNotificationSet(context.Background(), pushConfig) - require.NoError(t, err) - require.NotNil(t, resultConfig) - assert.Equal(t, taskID, resultConfig.ID) - assert.Equal(t, "https://example.com/webhook", resultConfig.PushNotificationConfig.URL) - assert.Equal(t, "test-token", resultConfig.PushNotificationConfig.Token) - require.NotNil(t, resultConfig.PushNotificationConfig.Authentication) - assert.Equal(t, []string{"Bearer"}, resultConfig.PushNotificationConfig.Authentication.Schemes) - require.NotNil(t, resultConfig.PushNotificationConfig.Authentication.Credentials) - assert.Equal(t, "Bearer test-token", *resultConfig.PushNotificationConfig.Authentication.Credentials) - assert.Equal(t, "high", resultConfig.PushNotificationConfig.Metadata["priority"]) - - // Test OnPushNotificationGet - getParams := protocol.TaskIDParams{ID: taskID} - fetchedConfig, err := tm.OnPushNotificationGet(context.Background(), getParams) - require.NoError(t, err) - require.NotNil(t, fetchedConfig) - assert.Equal(t, taskID, fetchedConfig.ID) - assert.Equal(t, "https://example.com/webhook", fetchedConfig.PushNotificationConfig.URL) - assert.Equal(t, "test-token", fetchedConfig.PushNotificationConfig.Token) - require.NotNil(t, fetchedConfig.PushNotificationConfig.Authentication) - assert.Equal(t, []string{"Bearer"}, fetchedConfig.PushNotificationConfig.Authentication.Schemes) - require.NotNil(t, fetchedConfig.PushNotificationConfig.Authentication.Credentials) - assert.Equal(t, "Bearer test-token", *fetchedConfig.PushNotificationConfig.Authentication.Credentials) - require.NotNil(t, fetchedConfig.PushNotificationConfig.Metadata) - assert.Equal(t, "high", fetchedConfig.PushNotificationConfig.Metadata["priority"]) - - // Test setting push notification for non-existent task - nonExistentConfig := protocol.TaskPushNotificationConfig{ - ID: "non-existent-task", - PushNotificationConfig: protocol.PushNotificationConfig{ - URL: "https://example.com/webhook", - }, - } - _, err = tm.OnPushNotificationSet(context.Background(), nonExistentConfig) - assert.NoError(t, err) - - // Test getting push notification for non-existent task - _, err = tm.OnPushNotificationGet(context.Background(), protocol.TaskIDParams{ID: "non-existent-task"}) - assert.Error(t, err) - jsonRPCErr, ok := err.(*jsonrpc.Error) - assert.True(t, ok, "Expected jsonrpc.Error") - assert.Equal(t, ErrCodeTaskNotFound, jsonRPCErr.Code) - - // Test getting push notification for task without config - // Create a new task without push notification config - newTaskID := "task-without-push-config" - newParams := protocol.SendTaskParams{ - ID: newTaskID, - Message: protocol.Message{Role: protocol.MessageRoleUser, Parts: []protocol.Part{protocol.NewTextPart("Task without push config")}}, + result, err := manager.OnSendMessage(ctx, request) + if err != nil { + t.Fatalf("Failed to send message: %v", err) } - _, err = tm.OnSendTask(context.Background(), newParams) - require.NoError(t, err) - - // Try to get push notification config - _, err = tm.OnPushNotificationGet(context.Background(), protocol.TaskIDParams{ID: newTaskID}) - assert.Error(t, err) - jsonRPCErr, ok = err.(*jsonrpc.Error) - assert.True(t, ok, "Expected jsonrpc.Error") - assert.Equal(t, ErrCodePushNotificationNotConfigured, jsonRPCErr.Code) - - // Update the push notification config - updatedConfig := protocol.TaskPushNotificationConfig{ - ID: taskID, - PushNotificationConfig: protocol.PushNotificationConfig{ - URL: "https://updated-example.com/webhook", - Token: "updated-token", - Authentication: &protocol.AuthenticationInfo{ - Schemes: []string{"Bearer"}, - Credentials: stringPtr("Bearer updated-token"), + + var existingTaskID string + if task, ok := result.Result.(*protocol.Task); ok { + existingTaskID = task.ID + } else { + t.Fatalf("Expected task result but got %T", result.Result) + } + + tests := []struct { + name string + params protocol.TaskQueryParams + wantErr bool + validate func(*testing.T, *protocol.Task, error) + }{ + { + name: "get existing task", + params: protocol.TaskQueryParams{ + ID: existingTaskID, + }, + wantErr: false, + validate: func(t *testing.T, task *protocol.Task, err error) { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if task == nil { + t.Error("Expected task but got nil") + } + if task != nil && task.ID != existingTaskID { + t.Errorf("Expected task ID %s, got %s", existingTaskID, task.ID) + } + }, + }, + { + name: "get non-existent task", + params: protocol.TaskQueryParams{ + ID: "non-existent-task", + }, + wantErr: true, + validate: func(t *testing.T, task *protocol.Task, err error) { + if err == nil { + t.Error("Expected error for non-existent task") + } + if task != nil { + t.Error("Expected nil task for error case") + } + }, + }, + { + name: "empty task ID", + params: protocol.TaskQueryParams{ + ID: "", + }, + wantErr: true, + validate: func(t *testing.T, task *protocol.Task, err error) { + if err == nil { + t.Error("Expected error for empty task ID") + } }, }, } - updatedResult, err := tm.OnPushNotificationSet(context.Background(), updatedConfig) - require.NoError(t, err) - assert.Equal(t, "https://updated-example.com/webhook", updatedResult.PushNotificationConfig.URL) - assert.Equal(t, "updated-token", updatedResult.PushNotificationConfig.Token) - require.NotNil(t, updatedResult.PushNotificationConfig.Authentication) - assert.Equal(t, []string{"Bearer"}, updatedResult.PushNotificationConfig.Authentication.Schemes) - require.NotNil(t, updatedResult.PushNotificationConfig.Authentication.Credentials) - assert.Equal(t, "Bearer updated-token", *updatedResult.PushNotificationConfig.Authentication.Credentials) - - // Fetch again to verify update - fetchedUpdatedConfig, err := tm.OnPushNotificationGet(context.Background(), getParams) - require.NoError(t, err) - assert.Equal(t, "https://updated-example.com/webhook", fetchedUpdatedConfig.PushNotificationConfig.URL) - assert.Equal(t, "updated-token", fetchedUpdatedConfig.PushNotificationConfig.Token) - require.NotNil(t, fetchedUpdatedConfig.PushNotificationConfig.Authentication) - assert.Equal(t, []string{"Bearer"}, fetchedUpdatedConfig.PushNotificationConfig.Authentication.Schemes) - require.NotNil(t, fetchedUpdatedConfig.PushNotificationConfig.Authentication.Credentials) - assert.Equal(t, "Bearer updated-token", *fetchedUpdatedConfig.PushNotificationConfig.Authentication.Credentials) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + getTask, err := manager.OnGetTask(ctx, tt.params) + tt.validate(t, getTask, err) + }) + } } -func TestMemoryTaskManager_OnResubscribe(t *testing.T) { - // Create a processor that will take longer to complete so we can test resubscribe - processor := &mockProcessor{ - processFunc: func(ctx context.Context, taskID string, msg protocol.Message, handle TaskHandle) error { - // Set initial status and send an intermediate message - err := handle.UpdateStatus(protocol.TaskStateWorking, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart("Working on task...")}, - }) +func TestMemoryTaskManager_OnCancelTask(t *testing.T) { + processor := &MockMessageProcessor{ + ProcessMessageFunc: func(ctx context.Context, message protocol.Message, options ProcessOptions, handle TaskHandler) (*MessageProcessingResult, error) { + // Create a task for testing cancellation + taskID, err := handle.BuildTask(nil, message.ContextID) if err != nil { - return err + return nil, err } - // Sleep a short time to simulate work - time.Sleep(50 * time.Millisecond) - - // Check if the context was cancelled - if ctx.Err() != nil { - return ctx.Err() + // Get the actual task object + task, err := handle.GetTask(&taskID) + if err != nil { + return nil, err } - // Complete the task - err = handle.UpdateStatus(protocol.TaskStateCompleted, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart("Task completed!")}, - }) - return err + return &MessageProcessingResult{ + Result: task.Task(), + }, nil }, } - - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) - - // Create a task - taskID := "resubscribe-task" - params := protocol.SendTaskParams{ - ID: taskID, - Message: protocol.Message{Role: protocol.MessageRoleUser, Parts: []protocol.Part{protocol.NewTextPart("Test task for resubscribe")}}, + manager, err := NewMemoryTaskManager(processor) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) } - // Start task with subscription - originalEventChan, err := tm.OnSendTaskSubscribe(context.Background(), params) - require.NoError(t, err) - require.NotNil(t, originalEventChan) - - // Wait for task to start processing (wait for Working state) - var receivedWorkingEvent bool - for event := range originalEventChan { - statusEvent, ok := event.(protocol.TaskStatusUpdateEvent) - if ok && statusEvent.Status.State == protocol.TaskStateWorking { - receivedWorkingEvent = true - break - } - if event.IsFinal() { - break - } - } - assert.True(t, receivedWorkingEvent, "Should have received Working state event") - - // Now simulate a client disconnect and reconnect by resubscribing - resubscribeParams := protocol.TaskIDParams{ID: taskID} - resubscribeEventChan, err := tm.OnResubscribe(context.Background(), resubscribeParams) - require.NoError(t, err) - require.NotNil(t, resubscribeEventChan, "Should get a valid event channel from resubscribe") - - // Read events from the resubscribe channel until we get a final event - var gotFinalEvent bool - var statusUpdateEvent protocol.TaskStatusUpdateEvent - - for event := range resubscribeEventChan { - if event.IsFinal() { - // Try to type assert it to a status update event - statusUpdate, ok := event.(protocol.TaskStatusUpdateEvent) - if ok { - statusUpdateEvent = statusUpdate - gotFinalEvent = true - } - break - } - } + ctx := context.Background() - // There should be a final event - assert.True(t, gotFinalEvent, "Should have received a final event") - assert.Equal(t, protocol.TaskStateCompleted, statusUpdateEvent.Status.State) - assert.True(t, statusUpdateEvent.Final) - - // Test resubscribing to a non-existent task - _, err = tm.OnResubscribe(context.Background(), protocol.TaskIDParams{ID: "non-existent-task"}) - assert.Error(t, err) - jsonRPCErr, ok := err.(*jsonrpc.Error) - assert.True(t, ok, "Expected jsonrpc.Error") - assert.Equal(t, ErrCodeTaskNotFound, jsonRPCErr.Code) - - // Test resubscribing to an already completed task - // Should get a channel with the final event and then close - completedTaskID := "completed-resubscribe-task" - completedParams := protocol.SendTaskParams{ - ID: completedTaskID, - Message: protocol.Message{Role: protocol.MessageRoleUser, Parts: []protocol.Part{protocol.NewTextPart("Already completed task")}}, + // Create a task first + request := protocol.SendMessageParams{ + Message: protocol.Message{ + Role: protocol.MessageRoleUser, + Parts: []protocol.Part{ + protocol.NewTextPart("Test"), + }, + }, } - _, err = tm.OnSendTask(context.Background(), completedParams) - require.NoError(t, err) - - // The task should be completed now, attempt to resubscribe - completedResubChan, err := tm.OnResubscribe( - context.Background(), protocol.TaskIDParams{ID: completedTaskID}, - ) - require.NoError(t, err) - require.NotNil(t, completedResubChan) - - // Read all events from the channel - completedEvents := []protocol.TaskEvent{} - for event := range completedResubChan { - completedEvents = append(completedEvents, event) + result, err := manager.OnSendMessage(ctx, request) + if err != nil { + t.Fatalf("Failed to send message: %v", err) } - // Should have received a single event with the final status - require.Len(t, completedEvents, 1, "Should get exactly one event for a completed task") - completedStatusEvent, ok := completedEvents[0].(protocol.TaskStatusUpdateEvent) - require.True(t, ok, "Event should be a TaskStatusUpdateEvent") - assert.Equal(t, protocol.TaskStateCompleted, completedStatusEvent.Status.State) - assert.True(t, completedStatusEvent.Final) -} - -// Helper function to create string pointers -func stringPtr(s string) *string { - return &s -} - -// TestAddArtifact tests the AddArtifact method of TaskHandle and MemoryTaskManager -func TestAddArtifact(t *testing.T) { - processor := &mockProcessor{} - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) - - taskID := "test-artifact-task" - params := createTestTask(taskID, "Test with artifacts") - - // Helper function to create bool pointers - boolPtr := func(b bool) *bool { - return &b + // Extract task from result + var taskID string + if task, ok := result.Result.(*protocol.Task); ok { + taskID = task.ID + } else { + t.Fatal("Expected task result but got different type") } - // Create a task and get its handle - processor.processFunc = func(ctx context.Context, taskID string, msg protocol.Message, handle TaskHandle) error { - // Test adding an artifact to the task - textPart := protocol.NewTextPart("Artifact content") - err := handle.AddArtifact(protocol.Artifact{ - Name: stringPtr("test-artifact"), - Description: stringPtr("A test artifact"), - Parts: []protocol.Part{textPart}, - LastChunk: boolPtr(true), - }) - assert.NoError(t, err, "Adding artifact should succeed") - - // Test adding a streaming artifact (multiple chunks) - firstChunkPart := protocol.NewTextPart("First chunk") - err = handle.AddArtifact(protocol.Artifact{ - Name: stringPtr("streaming-artifact"), - Description: stringPtr("A streaming artifact"), - Parts: []protocol.Part{firstChunkPart}, - LastChunk: boolPtr(false), - }) - assert.NoError(t, err, "Adding first chunk should succeed") - - lastChunkPart := protocol.NewTextPart("Last chunk") - err = handle.AddArtifact(protocol.Artifact{ - Name: stringPtr("streaming-artifact"), - Description: stringPtr("A streaming artifact"), - Parts: []protocol.Part{lastChunkPart}, - LastChunk: boolPtr(true), - }) - assert.NoError(t, err, "Adding last chunk should succeed") - - return handle.UpdateStatus(protocol.TaskStateCompleted, nil) + // Cancel the task + cancelParams := protocol.TaskIDParams{ + ID: taskID, } - // Run the task - task, err := tm.OnSendTask(context.Background(), params) - require.NoError(t, err) - assert.Equal(t, protocol.TaskStateCompleted, task.Status.State) - - // Verify the artifacts are present in the task - // NOTE: All artifacts are stored, not just the final ones - require.Len(t, task.Artifacts, 3, "Task should have 3 artifacts") - - // Verify first artifact - assert.Equal(t, "test-artifact", *task.Artifacts[0].Name) - assert.Equal(t, "A test artifact", *task.Artifacts[0].Description) - assert.True(t, *task.Artifacts[0].LastChunk) - require.Len(t, task.Artifacts[0].Parts, 1) - textPart, ok := task.Artifacts[0].Parts[0].(protocol.TextPart) - require.True(t, ok) - assert.Equal(t, "Artifact content", textPart.Text) - - // Verify first streaming chunk - assert.Equal(t, "streaming-artifact", *task.Artifacts[1].Name) - assert.Equal(t, "A streaming artifact", *task.Artifacts[1].Description) - assert.False(t, *task.Artifacts[1].LastChunk) - - // Verify last streaming chunk - assert.Equal(t, "streaming-artifact", *task.Artifacts[2].Name) - assert.Equal(t, "A streaming artifact", *task.Artifacts[2].Description) - assert.True(t, *task.Artifacts[2].LastChunk) - require.Len(t, task.Artifacts[2].Parts, 1) - textPart, ok = task.Artifacts[2].Parts[0].(protocol.TextPart) - require.True(t, ok) - assert.Equal(t, "Last chunk", textPart.Text) - - // Test error case: add artifact to non-existent task - memTask := &MemoryTaskManager{ - Tasks: make(map[string]*protocol.Task), - Messages: make(map[string][]protocol.Message), - Subscribers: make(map[string][]chan<- protocol.TaskEvent), - Processor: processor, + canceledTask, err := manager.OnCancelTask(ctx, cancelParams) + if err != nil { + t.Errorf("Unexpected error: %v", err) } - handle := &memoryTaskHandle{ - taskID: "non-existent-task", - manager: memTask, + if canceledTask == nil { + t.Error("Expected canceled task but got nil") } - err = handle.AddArtifact(protocol.Artifact{ - Name: stringPtr("test-artifact"), - Parts: []protocol.Part{protocol.NewTextPart("Test content")}, - LastChunk: boolPtr(true), - }) - assert.Error(t, err, "Adding artifact to non-existent task should fail") - assert.Contains(t, err.Error(), "not found", "Error should indicate task not found") + if canceledTask.Status.State != protocol.TaskStateCanceled { + t.Errorf("Expected task state to be canceled, got %s", canceledTask.Status.State) + } } -// TestIsStreamingRequest tests the IsStreamingRequest method of TaskHandle -func TestIsStreamingRequest(t *testing.T) { - processor := &mockProcessor{} - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) - - // Test with a streaming task (OnSendTaskSubscribe) - taskID := "test-streaming-task" - params := createTestTask(taskID, "Streaming task") - - // Create a complex processor to test IsStreamingRequest - processor.processFunc = func(ctx context.Context, taskID string, msg protocol.Message, handle TaskHandle) error { - // Check if this is a streaming request - isStreaming := handle.IsStreamingRequest() - assert.True(t, isStreaming, "Task should be identified as streaming") +func TestMemoryTaskManager_PushNotifications(t *testing.T) { + processor := &MockMessageProcessor{} + manager, err := NewMemoryTaskManager(processor) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + ctx := context.Background() + + tests := []struct { + name string + action string // "set" or "get" + taskID string + config *protocol.TaskPushNotificationConfig + getParams *protocol.TaskIDParams + wantErr bool + validate func(*testing.T, interface{}, error) + }{ + { + name: "set push notification", + action: "set", + taskID: "test-task-id", + config: &protocol.TaskPushNotificationConfig{ + TaskID: "test-task-id", + PushNotificationConfig: protocol.PushNotificationConfig{ + URL: "https://example.com/webhook", + Token: "Bearer token", + }, + }, + wantErr: false, + validate: func(t *testing.T, result interface{}, err error) { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if result == nil { + t.Error("Expected set result but got nil") + } + }, + }, + { + name: "get push notification", + action: "get", + taskID: "test-task-id", + getParams: &protocol.TaskIDParams{ + ID: "test-task-id", + }, + wantErr: false, + validate: func(t *testing.T, result interface{}, err error) { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if result == nil { + t.Error("Expected get result but got nil") + return + } - // Update status and finish - return handle.UpdateStatus(protocol.TaskStateCompleted, nil) + if getResult, ok := result.(*protocol.TaskPushNotificationConfig); ok { + expectedURL := "https://example.com/webhook" + if getResult.PushNotificationConfig.URL != expectedURL { + t.Errorf("Expected URL %s, got %s", expectedURL, getResult.PushNotificationConfig.URL) + } + } else { + t.Errorf("Expected TaskPushNotificationConfig, got %T", result) + } + }, + }, + { + name: "get non-existent push notification", + action: "get", + taskID: "non-existent-task", + getParams: &protocol.TaskIDParams{ + ID: "non-existent-task", + }, + wantErr: true, + validate: func(t *testing.T, result interface{}, err error) { + if err == nil { + t.Error("Expected error for non-existent task") + } + }, + }, } - // Start a streaming task - eventsChan, err := tm.OnSendTaskSubscribe(context.Background(), params) - require.NoError(t, err) - - // Collect events to avoid blocking - go func() { - for range eventsChan { - // Just consume events - } - }() - - // Let the task complete - time.Sleep(50 * time.Millisecond) - - // Test with a non-streaming task (OnSendTask) - taskID = "test-nonstreaming-task" - params = createTestTask(taskID, "Non-streaming task") - - // Update processor for the second task - processor.processFunc = func(ctx context.Context, taskID string, msg protocol.Message, handle TaskHandle) error { - // Check if this is a streaming request - isStreaming := handle.IsStreamingRequest() - assert.False(t, isStreaming, "Task should not be identified as streaming") - - // Update status and finish - return handle.UpdateStatus(protocol.TaskStateCompleted, nil) + // First set up a push notification for the get test + setupConfig := protocol.TaskPushNotificationConfig{ + TaskID: "test-task-id", + PushNotificationConfig: protocol.PushNotificationConfig{ + URL: "https://example.com/webhook", + Token: "Bearer token", + }, + } + _, err = manager.OnPushNotificationSet(ctx, setupConfig) + if err != nil { + t.Fatalf("Failed to set up push notification: %v", err) } - // Start a non-streaming task - _, err = tm.OnSendTask(context.Background(), params) - require.NoError(t, err) -} - -// TestProcessError tests the processError helper function -func TestProcessError(t *testing.T) { - // Access to unexported processError function through reflection - tm, err := NewMemoryTaskManager(&mockProcessor{}) - require.NoError(t, err) - - // Custom error type for testing - errTaskNotFound := ErrTaskNotFound("task-id") - - // Test with a pre-defined error - result := tm.processError(errTaskNotFound) - assert.Equal(t, errTaskNotFound.Error(), result.Error(), "Pre-defined errors should be returned as-is") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var result interface{} + var err error - // Test with a random error - randomErr := fmt.Errorf("random error") - result = tm.processError(randomErr) - assert.ErrorIs(t, result, randomErr, "Random errors should be wrapped") - assert.Contains(t, result.Error(), "random error", "Error message should be preserved") -} + switch tt.action { + case "set": + if tt.config != nil { + result, err = manager.OnPushNotificationSet(ctx, *tt.config) + } + case "get": + if tt.getParams != nil { + result, err = manager.OnPushNotificationGet(ctx, *tt.getParams) + } + default: + t.Fatalf("Unknown action: %s", tt.action) + } -// Helper extension of MemoryTaskManager to expose processError -func (m *MemoryTaskManager) processError(err error) error { - // Call the unexported processError through composition - return processError(err) + tt.validate(t, result, err) + }) + } } -// TestRemoveSubscriber tests the removeSubscriber internal function -func TestRemoveSubscriber(t *testing.T) { - processor := &mockProcessor{} - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) - - taskID := "test-subscriber-task" - - // Create some test channels - ch1 := make(chan protocol.TaskEvent, 5) - ch2 := make(chan protocol.TaskEvent, 5) - ch3 := make(chan protocol.TaskEvent, 5) - - // Add subscribers - tm.SubMutex.Lock() - tm.Subscribers[taskID] = []chan<- protocol.TaskEvent{ch1, ch2, ch3} - tm.SubMutex.Unlock() - - // Test removing a subscriber - tm.removeSubscriber(taskID, ch2) - - // Verify ch2 was removed - tm.SubMutex.RLock() - subscribers := tm.Subscribers[taskID] - tm.SubMutex.RUnlock() - - require.Len(t, subscribers, 2) - // We can't do direct channel comparisons, so check the length instead - // and verify the first and last elements - assert.Equal(t, 2, len(subscribers)) - - // Test removing another subscriber - tm.removeSubscriber(taskID, ch1) - - // Verify ch1 was removed - tm.SubMutex.RLock() - subscribers = tm.Subscribers[taskID] - tm.SubMutex.RUnlock() - - require.Len(t, subscribers, 1) - // Since we removed ch1 and ch2, only ch3 should remain - - // Test removing the last subscriber - tm.removeSubscriber(taskID, ch3) - - // Verify the task ID is no longer in the subscribers map - tm.SubMutex.RLock() - _, exists := tm.Subscribers[taskID] - tm.SubMutex.RUnlock() - - assert.False(t, exists, "Task ID should be removed from subscribers map when last subscriber is removed") - - // Test removing from a non-existent task - // This should not panic - tm.removeSubscriber("non-existent-task", ch1) - - // Test removing a non-existent channel - nonExistentCh := make(chan protocol.TaskEvent) - tm.SubMutex.Lock() - tm.Subscribers[taskID] = []chan<- protocol.TaskEvent{ch1} - tm.SubMutex.Unlock() +func TestTaskSubscriber(t *testing.T) { + tests := []struct { + name string + taskID string + capacity int + setup func(*MemoryTaskSubscriber) // Setup function to perform actions + validate func(*testing.T, *MemoryTaskSubscriber) // Validation function + }{ + { + name: "create subscriber", + taskID: "test-task", + capacity: 5, + setup: func(s *MemoryTaskSubscriber) {}, + validate: func(t *testing.T, s *MemoryTaskSubscriber) { + if s.taskID != "test-task" { + t.Errorf("Expected task ID %s, got %s", "test-task", s.taskID) + } + if s.Closed() { + t.Error("Expected subscriber to be open") + } + }, + }, + { + name: "send and receive event", + taskID: "test-task-2", + capacity: 5, + setup: func(s *MemoryTaskSubscriber) { + event := protocol.StreamingMessageEvent{ + Result: &protocol.Message{ + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{ + protocol.NewTextPart("Test event"), + }, + }, + } + err := s.Send(event) + if err != nil { + t.Errorf("Unexpected error sending event: %v", err) + } + }, + validate: func(t *testing.T, s *MemoryTaskSubscriber) { + select { + case receivedEvent := <-s.eventQueue: + if receivedEvent.Result == nil { + t.Error("Expected event result but got nil") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timeout waiting for event") + } + }, + }, + { + name: "close subscriber", + taskID: "test-task-3", + capacity: 5, + setup: func(s *MemoryTaskSubscriber) { + s.Close() + }, + validate: func(t *testing.T, s *MemoryTaskSubscriber) { + if !s.Closed() { + t.Error("Expected subscriber to be closed") + } - tm.removeSubscriber(taskID, nonExistentCh) + // Test sending to closed subscriber + event := protocol.StreamingMessageEvent{ + Result: &protocol.Message{Role: protocol.MessageRoleAgent}, + } + err := s.Send(event) + if err == nil { + t.Error("Expected error when sending to closed subscriber") + } + }, + }, + } - // Verify ch1 is still there - tm.SubMutex.RLock() - subscribers = tm.Subscribers[taskID] - tm.SubMutex.RUnlock() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + subscriber := NewMemoryTaskSubscriber(tt.taskID, tt.capacity) - require.Len(t, subscribers, 1) + tt.setup(subscriber) + tt.validate(t, subscriber) + }) + } } -// Test push notification functionality -func TestMemoryTaskManager_PushNotifications(t *testing.T) { - processor := &mockProcessor{} - tm, err := NewMemoryTaskManager(processor) - require.NoError(t, err) - - // Cast to the concrete type - memTM := tm - - // Set up push notification config - taskID := "test-task-123" - url := "http://example.com/webhook" +func TestCancellableTask(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) - // Add task first (since we need a task to register notifications) - task := &protocol.Task{ - ID: taskID, - Status: protocol.TaskStatus{ - State: "pending", + task := &MemoryCancellableTask{ + task: protocol.Task{ + ID: "test-task", + Status: protocol.TaskStatus{State: protocol.TaskStateSubmitted}, }, + cancelFunc: cancel, + ctx: ctx, } - memTM.TasksMutex.Lock() - memTM.Tasks[taskID] = task - memTM.TasksMutex.Unlock() - - // Set up push notification directly - config := protocol.PushNotificationConfig{ - URL: url, - Authentication: &protocol.AuthenticationInfo{ - Schemes: []string{"bearer"}, - }, - Metadata: map[string]interface{}{ - "jwksUrl": "http://example.com/jwks", - }, + + // Test cancellation + task.Cancel() + + select { + case <-ctx.Done(): + // Expected + case <-time.After(100 * time.Millisecond): + t.Error("Expected context to be canceled") } +} - // Register push notification directly - memTM.PushNotificationsMutex.Lock() - memTM.PushNotifications[taskID] = config - memTM.PushNotificationsMutex.Unlock() - - // Verify the notification was registered - memTM.PushNotificationsMutex.RLock() - notification, exists := memTM.PushNotifications[taskID] - memTM.PushNotificationsMutex.RUnlock() - - assert.True(t, exists) - assert.Equal(t, url, notification.URL) - assert.NotNil(t, notification.Authentication) - assert.Equal(t, "bearer", notification.Authentication.Schemes[0]) - assert.Equal(t, "http://example.com/jwks", notification.Metadata["jwksUrl"]) +func stringPtr(s string) *string { + return &s } diff --git a/taskmanager/redis/README.md b/taskmanager/redis/README.md deleted file mode 100644 index d576c10..0000000 --- a/taskmanager/redis/README.md +++ /dev/null @@ -1,139 +0,0 @@ -# Redis Task Manager for A2A - -This package provides a Redis-based implementation of the A2A TaskManager interface, allowing for persistent storage of tasks and messages using Redis. - -## Features - -- Persistent storage of tasks and task history -- Support for all TaskManager operations (send task, subscribe, cancel, etc.) -- Configurable key expiration time -- Compatible with Redis clusters, sentinel, and standalone configurations -- Thread-safe implementation -- Graceful cleanup of resources - -## Requirements - -- Go 1.21 or later -- Redis 6.0 or later (recommended) -- github.com/redis/go-redis/v9 library - -## Installation - -```bash -go get trpc.group/trpc-go/trpc-a2a-go/taskmanager/redis -``` - -## Usage - -### Basic Usage - -```go -import ( - "context" - "log" - "time" - - "github.com/redis/go-redis/v9" - "trpc.group/trpc-go/trpc-a2a-go/taskmanager" - redismgr "trpc.group/trpc-go/trpc-a2a-go/taskmanager/redis" -) - -func main() { - // Create your task processor implementation. - processor := &MyTaskProcessor{} - - // Configure Redis connection. - redisOptions := &redis.UniversalOptions{ - Addrs: []string{"localhost:6379"}, - Password: "", // no password - DB: 0, // use default DB - } - - // Create Redis task manager. - manager, err := redismgr.NewRedisTaskManager(processor, redismgr.Options{ - RedisOptions: redisOptions, - }) - if err != nil { - log.Fatalf("Failed to create Redis task manager: %v", err) - } - defer manager.Close() - - // Use the task manager... -} -``` - -### Configuring Key Expiration - -By default, task and message data in Redis will expire after 30 days. You can customize this: - -```go -// Set custom expiration time. -expiration := 7 * 24 * time.Hour // 7 days - -manager, err := redismgr.NewRedisTaskManager(processor, redismgr.Options{ - RedisOptions: redisOptions, - Expiration: &expiration, -}) -``` - -### Using with Redis Cluster - -```go -redisOptions := &redis.UniversalOptions{ - Addrs: []string{ - "redis-node-1:6379", - "redis-node-2:6379", - "redis-node-3:6379", - }, - RouteByLatency: true, -} - -manager, err := redismgr.NewRedisTaskManager(processor, redismgr.Options{ - RedisOptions: redisOptions, -}) -``` - -### Using with Redis Sentinel - -```go -redisOptions := &redis.UniversalOptions{ - Addrs: []string{"sentinel-1:26379", "sentinel-2:26379"}, - MasterName: "mymaster", -} - -manager, err := redismgr.NewRedisTaskManager(processor, redismgr.Options{ - RedisOptions: redisOptions, -}) -``` - -## Implementation Details - -### Redis Key Prefixes - -The implementation uses the following key patterns in Redis: - -- `task:ID` - Stores the serialized Task object -- `msg:ID` - Stores the message history as a Redis list -- `push:ID` - Stores push notification configuration - -### Task Subscribers - -While tasks and messages are stored in Redis, subscribers for streaming updates are maintained in memory. If your application requires distributed subscription handling, consider implementing a custom solution using Redis Pub/Sub. - -## Testing - -The package includes comprehensive tests that use an in-memory Redis server for testing. To run the tests: - -```bash -go test -v -``` - -For end-to-end testing, the package uses [miniredis](https://github.com/alicebob/miniredis), which provides a fully featured in-memory Redis implementation perfect for testing without external dependencies. - -## Full Example - -See the [example directory](./example) for a complete working example. - -## License - -This package is part of the A2A Go implementation and follows the same license. diff --git a/taskmanager/redis/example/README.md b/taskmanager/redis/example/README.md index 0791320..8741f5d 100644 --- a/taskmanager/redis/example/README.md +++ b/taskmanager/redis/example/README.md @@ -1,115 +1,229 @@ -# A2A Redis Task Manager Example +# Redis TaskManager Example - Text Case Converter -This example demonstrates how to create a complete A2A (Application-to-Application) flow using the Redis task manager. It includes both a server and client implementation that use the official A2A Go packages. +A simple example demonstrating Redis TaskManager with a **Text to Lowercase Converter** service. -## Prerequisites - -- Go 1.21 or later -- Redis 6.0 or later running locally (or accessible through network) +## Overview -## Running the Server +- **Server**: Converts text to lowercase using Redis for storage +- **Client**: Tests both streaming and non-streaming modes with enhanced visual feedback -The server implements a simple task processor that can receive tasks, process them, and return results. It uses Redis for task storage and state management. +## Prerequisites -```bash -# Start Redis if it's not already running -# For example, using Docker: -docker run --name redis -p 6379:6379 -d redis +1. **Go 1.23.0+** +2. **Redis Server** running on localhost:6379 -# Run the server with default settings (Redis on localhost:6379) -cd server -go run main.go +### Quick Redis Setup -# Or with custom settings -go run main.go -port=8080 -redis=localhost:6379 -redis-pass="" -redis-db=0 +**Using Docker:** +```bash +docker run -d -p 6379:6379 redis:7-alpine ``` -The server supports the following command-line arguments: - -- `-port`: The HTTP port to listen on (default: 8080) -- `-redis`: Redis server address (default: localhost:6379) -- `-redis-pass`: Redis password (default: "") -- `-redis-db`: Redis database number (default: 0) +**Using Docker Compose:** +```bash +cd examples/redis +docker-compose up -d +``` -## Running the Client +## Running the Example -The client provides a command-line interface to interact with the server. It supports sending tasks, retrieving task status, and streaming task updates. +### 1. Start Redis +Make sure Redis is running: ```bash -cd client -go run main.go -op=send -message="Hello, world!" +redis-cli ping +# Should return: PONG +``` -# Get task status -go run main.go -op=get -task= +### 2. Start the Server -# Cancel a task -go run main.go -op=cancel -task= +```bash +cd taskmanager/redis/example/server +go run main.go +``` -# Stream task updates -go run main.go -op=stream -message="Hello, streaming!" +You should see: +``` +Connected to Redis at localhost:6379 successfully +Starting Text Case Converter server on :8080 ``` -The client supports the following operations: +With custom parameters: +```bash +# Custom Redis address +go run main.go --redis_addr localhost:6380 + +# Custom server address +go run main.go --addr :9000 -- `send`: Create and send a new task, then poll until completion -- `get`: Retrieve the status of an existing task -- `cancel`: Cancel an in-progress task -- `stream`: Create a task and stream updates until completion +# Both custom +go run main.go --redis_addr redis.example.com:6379 --addr :8080 -Command-line arguments: +# Show help +go run main.go --help +``` -- `-server`: The A2A server URL (default: http://localhost:8080) -- `-message`: The message to send (default: "Hello, world!") -- `-op`: The operation to perform (default: "send") -- `-task`: The task ID (required for get and cancel operations) -- `-idkey`: An idempotency key for task creation (optional) +### 3. Run the Client -## Understanding the Code +In another terminal: +```bash +cd taskmanager/redis/example/client +go run main.go +``` -### Server +With custom parameters: +```bash +# Custom text input +go run main.go --text "Hello World! CONVERT THIS TEXT" -The server implementation: +# Custom server URL +go run main.go --addr http://localhost:9000/ -1. Uses the `redismgr.TaskManager` for persistent task storage in Redis -2. Implements a custom `DemoTaskProcessor` for task processing logic -3. Creates an official A2A server with appropriate configuration -4. Handles tasks via the A2A protocol endpoints +# Only test streaming mode (enhanced visual effects) +go run main.go --streaming --verbose -### Client +# Only test non-streaming mode +go run main.go --non-streaming -The client implementation: +# Enable verbose output for detailed information +go run main.go --verbose -1. Uses the official `client.A2AClient` to communicate with the server -2. Supports various operations through the command-line interface -3. Formats and displays task state and artifacts -4. Demonstrates both synchronous and streaming interaction +# Show help +go run main.go --help +``` -## Example Flow +## Sample Output -1. Start the server: - ```bash - cd server - go run main.go - ``` +### Non-streaming Mode +``` +=== Text Case Converter Client === +Server: http://localhost:8080/ +Input text: 'Hello World! THIS IS A TEST MESSAGE.' + +Test 1: Non-streaming conversion +→ Sending non-streaming request... +✓ Processing time: 45.123ms +📄 Result 1: 'hello world! this is a test message.' +``` -2. Send a task from the client: - ```bash - cd client - go run main.go -op=send -message="Process this message" - ``` +### Streaming Mode (Enhanced Display) +``` +Test 2: Streaming conversion with task updates +-> Starting streaming request... +[STREAMING] Processing events: +[TASK] ID: msg-a1b2c3d4... +[WORKING] Task State: working (Event #1) + [MESSAGE] [STARTING] Initializing text conversion process... +[WORKING] Task State: working (Event #2) + [MESSAGE] [ANALYZING] Processing input text (37 characters)... +[WORKING] Task State: working (Event #3) + [MESSAGE] [PROCESSING] Converting text to lowercase... +[WORKING] Task State: working (Event #4) + [MESSAGE] [ARTIFACT] Creating result artifact... +[ARTIFACT] ID: processed-text-msg-a1b2c3d4... + [NAME] Text to Lowercase + [DESC] Convert any text to lowercase + [CONTENT] 'hello world! this is a test message.' + [METADATA] + operation: text_to_lower + originalText: Hello World! THIS IS A TEST MESSAGE. + originalLength: 37 + resultLength: 37 + processedAt: 2025-01-02T10:30:45Z + processingTime: 1.7s +[SUCCESS] Task State: completed + [MESSAGE] [COMPLETED] Text processing finished! Original: 'Hello World! THIS IS A TEST MESSAGE.' -> Lowercase: 'hello world! this is a test message.' +[FINISHED] Task completed! (Total time: 2.1s) +[COMPLETED] Stream finished (5 events, 2.1s total) +``` -3. The client will display task status updates, including the final result and any artifacts produced. +## What This Demonstrates -## Error Handling +- **Redis Storage**: Messages, tasks, and artifacts persisted in Redis +- **Task Management**: Real-time task state transitions with multiple processing steps +- **Streaming Events**: Live updates via Server-Sent Events with clear text indicators +- **Artifact Creation**: Generated content with rich metadata +- **Command Line Interface**: Modern flag-based parameter handling +- **Code Standards**: Clean code structure with constants and proper naming conventions -The example demonstrates error handling in several ways: +## Command Line Options -- If you include the word "error" in your message, the server will return a simulated error -- Connection errors between client and server are properly reported -- Task cancellation is supported through the cancel operation +**Server Options:** +```bash +go run main.go [OPTIONS] + +Options: + --redis_addr string Redis server address (default "localhost:6379") + --addr string Server listen address (default ":8080") + --help Show help message + --version Show version information + +Examples: + go run main.go # Use default settings + go run main.go --redis_addr localhost:6380 # Custom Redis port + go run main.go --addr :9000 # Custom server port + go run main.go --redis_addr redis.example.com:6379 --addr :8080 +``` -## Next Steps +**Client Options:** +```bash +go run main.go [OPTIONS] + +Options: + --addr string Server URL (default "http://localhost:8080/") + --text string Input text to process (default "Hello World! THIS IS A TEST MESSAGE.") + --streaming Only test streaming mode + --non-streaming Only test non-streaming mode + --verbose Enable verbose output + --help Show help message + --version Show version information + +Examples: + go run main.go # Use default settings + go run main.go --text "Custom Text" # Custom input text + go run main.go --addr http://localhost:9000/ # Custom server URL + go run main.go --streaming # Only test streaming mode + go run main.go --non-streaming # Only test non-streaming mode + go run main.go --verbose # Enable verbose output +``` -- Modify the `DemoTaskProcessor` to implement your own task processing logic -- Configure the Redis task manager for production use (authentication, clustering, etc.) -- Integrate the server into your own application \ No newline at end of file +## Enhanced Features + +### Visual Improvements +- **Text Indicators**: Clear prefixed labels for different event types +- **Progress Animation**: Text-based progress indicators during processing (verbose mode) +- **Structured Output**: Consistent formatting with bracketed prefixes +- **Detailed Metadata**: Rich information about processing steps (verbose mode) + +### Streaming Enhancements +- **Multi-step Processing**: Server shows detailed progress through multiple phases: + 1. [STARTING] Initialize process + 2. [ANALYZING] Process input + 3. [PROCESSING] Convert text + 4. [ARTIFACT] Create artifact + 5. [COMPLETED] Finish processing +- **Rich Artifacts**: Detailed metadata including processing time and statistics +- **Event Counting**: Track number of streaming events received + +### Code Quality Improvements +- **Constants Usage**: All hardcoded strings replaced with named constants +- **Function Separation**: Logical separation of concerns in processing functions +- **Type Safety**: Proper handling of pointer types and interfaces +- **Error Handling**: Comprehensive error checking and reporting + +## Troubleshooting + +**Redis Connection Failed:** +- Check if Redis is running: `redis-cli ping` +- Verify the Redis address: `--redis_addr localhost:6379` + +**Server Port in Use:** +- Use different port: `--addr :8081` + +**Client Connection Failed:** +- Ensure server is running +- Check server URL: `--addr http://localhost:8080/` + +**Streaming Effects Not Visible:** +- Use `--streaming --verbose` for maximum visual effect +- Ensure you're testing the streaming mode only \ No newline at end of file diff --git a/taskmanager/redis/example/client/main.go b/taskmanager/redis/example/client/main.go index 492b3e7..f3220e7 100644 --- a/taskmanager/redis/example/client/main.go +++ b/taskmanager/redis/example/client/main.go @@ -4,7 +4,8 @@ // // trpc-a2a-go is licensed under the Apache License Version 2.0. -// Package main provides a simple A2A client to interact with the server. +// Package main provides a Redis TaskManager example client that demonstrates +// how to interact with the Redis-based task manager server. package main import ( @@ -12,6 +13,7 @@ import ( "flag" "fmt" "log" + "os" "strings" "time" @@ -19,293 +21,363 @@ import ( "trpc.group/trpc-go/trpc-a2a-go/protocol" ) -// printTaskDetails prints the details of a task to the console. -func printTaskDetails(task *protocol.Task) { - log.Printf("Task ID: %s", task.ID) - log.Printf("State: %s", task.Status.State) - log.Printf("Timestamp: %s", task.Status.Timestamp) - log.Printf("Message history: %d items", len(task.History)) - - if len(task.Artifacts) > 0 { - log.Printf("Artifacts:") - for i, artifact := range task.Artifacts { - log.Printf(" [%d] %s", i, artifactToString(artifact)) - } - } +const ( + // Default configuration values + defaultServerURL = "http://localhost:8080/" + defaultInputText = "Hello World! THIS IS A TEST MESSAGE." + + // Client information + clientName = "Text Case Converter Client" + clientVersion = "1.0.0" + + // Timing constants + clientTimeout = 30 * time.Second + progressInterval = 500 * time.Millisecond + + // Output prefixes + prefixSuccess = "[SUCCESS]" + prefixResult = "[RESULT]" + prefixWarning = "[WARNING]" + prefixStreaming = "[STREAMING]" + prefixProcessing = "[PROCESSING]" + prefixCompleted = "[COMPLETED]" + prefixTask = "[TASK]" + prefixMessage = "[MESSAGE]" + prefixFinished = "[FINISHED]" + prefixArtifact = "[ARTIFACT]" + prefixName = "[NAME]" + prefixDesc = "[DESC]" + prefixContent = "[CONTENT]" + prefixMetadata = "[METADATA]" + prefixUnknown = "[UNKNOWN]" + prefixTimeout = "[TIMEOUT]" +) - if task.Status.Message != nil { - log.Printf("Status message: %s", messageToString(*task.Status.Message)) - } -} +var ( + // Progress indicator characters + progressChars = []string{".", "..", "...", "...."} +) -// messageToString converts a message to a string representation. -func messageToString(msg protocol.Message) string { - if len(msg.Parts) == 0 { - return "[empty message]" +func main() { + // Parse command line arguments + var serverURL = flag.String("addr", defaultServerURL, "Server URL") + var inputText = flag.String("text", defaultInputText, "Input text to process") + var streamingOnly = flag.Bool("streaming", false, "Only test streaming mode") + var nonStreamingOnly = flag.Bool("non-streaming", false, "Only test non-streaming mode") + var verbose = flag.Bool("verbose", false, "Enable verbose output") + var help = flag.Bool("help", false, "Show help message") + var version = flag.Bool("version", false, "Show version information") + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "%s - Redis TaskManager Example\n\n", clientName) + fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "Options:\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nExamples:\n") + fmt.Fprintf(os.Stderr, " %s # Use default settings\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s --text \"Custom Text\" # Custom input text\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s --addr http://localhost:9000/ # Custom server URL\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s --streaming # Only test streaming mode\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s --non-streaming # Only test non-streaming mode\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s --verbose # Enable verbose output\n", os.Args[0]) } - var result strings.Builder - for _, part := range msg.Parts { - if textPart, ok := part.(protocol.TextPart); ok { - result.WriteString(textPart.Text) - } else { - result.WriteString("[non-text part]") - } - } - return result.String() -} + flag.Parse() -// artifactToString converts an artifact to a string representation. -func artifactToString(artifact protocol.Artifact) string { - var parts []string - for _, part := range artifact.Parts { - if textPart, ok := part.(protocol.TextPart); ok { - parts = append(parts, textPart.Text) - } else { - parts = append(parts, "[non-text part]") - } + if *help { + flag.Usage() + os.Exit(0) } - name := "unnamed" - if artifact.Name != nil { - name = *artifact.Name + if *version { + fmt.Printf("%s %s\n", clientName, clientVersion) + fmt.Println("Redis TaskManager Example") + os.Exit(0) } - lastChunk := false - if artifact.LastChunk != nil { - lastChunk = *artifact.LastChunk + // Validate flags + if *streamingOnly && *nonStreamingOnly { + log.Fatal("Cannot specify both --streaming and --non-streaming flags") } - return fmt.Sprintf("%s (index: %d, last chunk: %v): %v", - name, artifact.Index, lastChunk, parts) -} - -// pollTaskUntilDone polls a task until it reaches a final state. -func pollTaskUntilDone(ctx context.Context, a2aClient *client.A2AClient, taskID string) (*protocol.Task, error) { - ticker := time.NewTicker(500 * time.Millisecond) - defer ticker.Stop() + // Create A2A client + a2aClient, err := client.NewA2AClient( + *serverURL, + client.WithTimeout(clientTimeout), + ) + if err != nil { + log.Fatalf("Failed to create A2A client: %v", err) + } - for { - select { - case <-ticker.C: - task, err := a2aClient.GetTasks(ctx, protocol.TaskQueryParams{ID: taskID}) - if err != nil { - return nil, fmt.Errorf("failed to get task: %w", err) - } + ctx := context.Background() - // Check if the task is done - if task.Status.State == protocol.TaskStateCompleted || - task.Status.State == protocol.TaskStateFailed || - task.Status.State == protocol.TaskStateCanceled { - return task, nil - } + // Display what we're going to do + fmt.Printf("=== %s ===\n", clientName) + fmt.Printf("Server: %s\n", *serverURL) + fmt.Printf("Input text: '%s'\n", *inputText) + if *verbose { + fmt.Printf("Verbose mode: enabled\n") + } + fmt.Println() - log.Printf("Task %s state: %s", taskID, task.Status.State) + // Test based on flags + if !*streamingOnly { + fmt.Println("Test 1: Non-streaming conversion") + runNonStreamingDemo(ctx, a2aClient, *inputText, *verbose) - case <-ctx.Done(): - return nil, ctx.Err() + if !*nonStreamingOnly { + fmt.Println("\n" + strings.Repeat("=", 50) + "\n") } } -} -// generateTaskID creates a task ID if none is provided. -func generateTaskID(providedID string) string { - if providedID != "" { - return providedID + if !*nonStreamingOnly { + fmt.Println("Test 2: Streaming conversion with task updates") + runStreamingDemo(ctx, a2aClient, *inputText, *verbose) } - id := fmt.Sprintf("task-%d", time.Now().UnixNano()) - log.Printf("Generated task ID: %s", id) - return id } -// createSendTaskParams creates parameters for sending a task. -func createSendTaskParams(taskID, message, idempotencyKey string) protocol.SendTaskParams { - params := protocol.SendTaskParams{ - ID: taskID, - Message: protocol.Message{ - Role: protocol.MessageRoleUser, - Parts: []protocol.Part{ - protocol.NewTextPart(message), - }, +func runNonStreamingDemo(ctx context.Context, client *client.A2AClient, inputText string, verbose bool) { + message := protocol.Message{ + Role: protocol.MessageRoleUser, + Parts: []protocol.Part{ + protocol.NewTextPart(inputText), }, } - // Add idempotency key if provided - if idempotencyKey != "" { - if params.Metadata == nil { - params.Metadata = make(map[string]interface{}) - } - params.Metadata["idempotency_key"] = idempotencyKey + params := protocol.SendMessageParams{ + Message: message, + Configuration: &protocol.SendMessageConfiguration{ + Blocking: boolPtr(true), // Non-streaming + }, } - return params -} + if verbose { + fmt.Printf("-> Sending non-streaming request...\n") + } -// handleSendOperation handles the "send" operation. -func handleSendOperation(ctx context.Context, a2aClient *client.A2AClient, taskID, message, idempotencyKey string) { - id := generateTaskID(taskID) - sendParams := createSendTaskParams(id, message, idempotencyKey) + start := time.Now() + result, err := client.SendMessage(ctx, params) + duration := time.Since(start) - // Send the task - task, err := a2aClient.SendTasks(ctx, sendParams) if err != nil { - log.Fatalf("Failed to send task: %v", err) + log.Printf("Error: %v", err) + return } - log.Printf("Task created: %s", task.ID) - log.Printf("Initial state: %s", task.Status.State) + fmt.Printf("%s Processing time: %v\n", prefixSuccess, duration) - // Poll for task completion - log.Println("Polling for task completion...") - finalTask, err := pollTaskUntilDone(ctx, a2aClient, task.ID) - if err != nil { - log.Fatalf("Failed to poll task: %v", err) + if response, ok := result.Result.(*protocol.Message); ok { + for i, part := range response.Parts { + if textPart, ok := part.(*protocol.TextPart); ok { + fmt.Printf("%s %d: '%s'\n", prefixResult, i+1, textPart.Text) + } + } + } else { + fmt.Printf("%s Unexpected result type: %T\n", prefixWarning, result.Result) } - - printFinalTaskStatus(finalTask) } -// printFinalTaskStatus prints the final status of a completed task. -func printFinalTaskStatus(task *protocol.Task) { - log.Printf("Task %s final state: %s", task.ID, task.Status.State) - - // Print the artifacts - if len(task.Artifacts) > 0 { - for i, artifact := range task.Artifacts { - log.Printf("Artifact %d: %s", i, artifactToString(artifact)) - } +func runStreamingDemo(ctx context.Context, client *client.A2AClient, inputText string, verbose bool) { + eventChan, err := startStreamingRequest(ctx, client, inputText, verbose) + if err != nil { + log.Printf("Failed to start streaming: %v", err) + return } - // Print the final message if present - if task.Status.Message != nil { - log.Printf("Final message: %s", messageToString(*task.Status.Message)) + streamProcessor := &streamEventProcessor{ + verbose: verbose, + startTime: time.Now(), + timeout: time.After(clientTimeout), + progressTicker: time.NewTicker(progressInterval), } + defer streamProcessor.progressTicker.Stop() + + streamProcessor.processEvents(eventChan) } -// handleGetOperation handles the "get" operation. -func handleGetOperation(ctx context.Context, a2aClient *client.A2AClient, taskID string) { - if taskID == "" { - log.Fatalf("Task ID is required for get operation") +// startStreamingRequest initiates a streaming request and returns the event channel +func startStreamingRequest(ctx context.Context, client *client.A2AClient, inputText string, verbose bool) (<-chan protocol.StreamingMessageEvent, error) { + message := protocol.Message{ + Role: protocol.MessageRoleUser, + Parts: []protocol.Part{ + protocol.NewTextPart(inputText), + }, } - // Create query parameters - params := protocol.TaskQueryParams{ - ID: taskID, - HistoryLength: nil, // Get default history length + params := protocol.SendMessageParams{ + Message: message, + Configuration: &protocol.SendMessageConfiguration{ + Blocking: boolPtr(false), // Streaming + }, } - // Get the task - task, err := a2aClient.GetTasks(ctx, params) - if err != nil { - log.Fatalf("Failed to get task: %v", err) + if verbose { + fmt.Printf("-> Starting streaming request...\n") } - // Print task details - printTaskDetails(task) + return client.StreamMessage(ctx, params) } -// handleCancelOperation handles the "cancel" operation. -func handleCancelOperation(ctx context.Context, a2aClient *client.A2AClient, taskID string) { - if taskID == "" { - log.Fatalf("Task ID is required for cancel operation") - } +// streamEventProcessor handles processing of streaming events +type streamEventProcessor struct { + verbose bool + startTime time.Time + timeout <-chan time.Time + progressTicker *time.Ticker + taskID string + eventCount int + progressIndex int +} + +// processEvents processes streaming events from the event channel +func (p *streamEventProcessor) processEvents(eventChan <-chan protocol.StreamingMessageEvent) { + fmt.Printf("%s Processing events:\n", prefixStreaming) - // Create cancel parameters - params := protocol.TaskIDParams{ - ID: taskID, + for { + select { + case <-p.progressTicker.C: + p.handleProgressTick() + + case event, ok := <-eventChan: + if !ok { + p.handleStreamComplete() + return + } + p.handleStreamEvent(event) + + case <-p.timeout: + fmt.Printf("%s Waiting for events\n", prefixTimeout) + return + } } +} - // Cancel the task - task, err := a2aClient.CancelTasks(ctx, params) - if err != nil { - log.Fatalf("Failed to cancel task: %v", err) +// handleProgressTick updates the progress indicator +func (p *streamEventProcessor) handleProgressTick() { + if p.eventCount > 0 && p.verbose { + p.progressIndex = (p.progressIndex + 1) % len(progressChars) + fmt.Printf("\r%s %s", prefixProcessing, progressChars[p.progressIndex]) } +} - log.Printf("Task %s canceled, final state: %s", task.ID, task.Status.State) +// handleStreamComplete handles stream completion +func (p *streamEventProcessor) handleStreamComplete() { + duration := time.Since(p.startTime) + if p.verbose { + fmt.Printf("\r") + } + fmt.Printf("%s Stream finished (%d events, %v total)\n", prefixCompleted, p.eventCount, duration) } -// processStreamEvents processes events from the streaming API. -func processStreamEvents(eventChan <-chan protocol.TaskEvent) { - for event := range eventChan { - switch e := event.(type) { - case protocol.TaskStatusUpdateEvent: - handleStatusUpdateEvent(e) - case protocol.TaskArtifactUpdateEvent: - handleArtifactUpdateEvent(e) - default: - log.Printf("Unknown event type: %T", event) - } +// handleStreamEvent processes a single stream event +func (p *streamEventProcessor) handleStreamEvent(event protocol.StreamingMessageEvent) { + p.eventCount++ + + // Clear progress indicator + if p.verbose && p.eventCount > 1 { + fmt.Printf("\r") + } + + switch result := event.Result.(type) { + case *protocol.TaskStatusUpdateEvent: + p.handleTaskStatusEvent(result) + case *protocol.TaskArtifactUpdateEvent: + p.handleTaskArtifactEvent(result) + default: + fmt.Printf("%s Event type: %T\n", prefixUnknown, result) } } -// handleStatusUpdateEvent processes a status update event. -func handleStatusUpdateEvent(event protocol.TaskStatusUpdateEvent) { - log.Printf("Status update: %s", event.Status.State) - if event.Final { - log.Printf("Final status: %s", event.Status.State) - if event.Status.Message != nil { - log.Printf("Final message: %s", messageToString(*event.Status.Message)) +// handleTaskStatusEvent processes task status update events +func (p *streamEventProcessor) handleTaskStatusEvent(event *protocol.TaskStatusUpdateEvent) { + if p.taskID == "" { + p.taskID = event.TaskID + if p.verbose { + fmt.Printf("%s ID: %s\n", prefixTask, event.TaskID) } } -} -// handleArtifactUpdateEvent processes an artifact update event. -func handleArtifactUpdateEvent(event protocol.TaskArtifactUpdateEvent) { - name := "unnamed" - if event.Artifact.Name != nil { - name = *event.Artifact.Name + statusPrefix := getStatusPrefix(event.Status.State) + fmt.Printf("[%s] Task State: %s", statusPrefix, event.Status.State) + + if p.verbose { + fmt.Printf(" (Event #%d)", p.eventCount) } - log.Printf("Artifact: %s (Index: %d)", name, event.Artifact.Index) - log.Printf(" Content: %s", artifactToString(event.Artifact)) - if event.Final { - log.Printf(" (This is the final artifact)") + fmt.Println() + + p.displayTaskMessage(event.Status.Message) + + if event.IsFinal() { + duration := time.Since(p.startTime) + fmt.Printf("%s Task completed! (Total time: %v)\n", prefixFinished, duration) } } -// handleStreamOperation handles the "stream" operation. -func handleStreamOperation(ctx context.Context, a2aClient *client.A2AClient, taskID, message string) { - id := generateTaskID(taskID) - sendParams := createSendTaskParams(id, message, "") +// handleTaskArtifactEvent processes task artifact update events +func (p *streamEventProcessor) handleTaskArtifactEvent(event *protocol.TaskArtifactUpdateEvent) { + fmt.Printf("%s ID: %s\n", prefixArtifact, event.Artifact.ArtifactID) - // Use the streaming API (SendTaskSubscribe) - eventChan, err := a2aClient.StreamTask(ctx, sendParams) - if err != nil { - log.Fatalf("Failed to start streaming task: %v", err) + if event.Artifact.Name != nil { + fmt.Printf(" %s %s\n", prefixName, *event.Artifact.Name) } - log.Printf("Started streaming task %s", sendParams.ID) + if event.Artifact.Description != nil { + fmt.Printf(" %s %s\n", prefixDesc, *event.Artifact.Description) + } - // Process events - processStreamEvents(eventChan) + p.displayArtifactContent(event.Artifact.Parts) - log.Printf("Streaming completed for task %s", sendParams.ID) + if p.verbose { + p.displayArtifactMetadata(event.Artifact.Metadata) + } } -func main() { - serverURL := flag.String("server", "http://localhost:8080", "A2A server URL") - message := flag.String("message", "Hello, world!", "Message to send") - operation := flag.String("op", "send", "Operation: send, get, cancel, or stream") - taskID := flag.String("task", "", "Task ID (required for get and cancel operations)") - idempotencyKey := flag.String("idkey", "", "Idempotency key for task creation (generated if not provided)") - flag.Parse() +// displayTaskMessage displays task status message if present +func (p *streamEventProcessor) displayTaskMessage(message *protocol.Message) { + if message != nil { + for _, part := range message.Parts { + if textPart, ok := part.(*protocol.TextPart); ok { + fmt.Printf(" %s %s\n", prefixMessage, textPart.Text) + } + } + } +} - // Create A2A client - a2aClient, err := client.NewA2AClient(*serverURL) - if err != nil { - log.Fatalf("Failed to create A2A client: %v", err) +// displayArtifactContent displays artifact content parts +func (p *streamEventProcessor) displayArtifactContent(parts []protocol.Part) { + for _, part := range parts { + if textPart, ok := part.(*protocol.TextPart); ok { + fmt.Printf(" %s '%s'\n", prefixContent, textPart.Text) + } } - ctx := context.Background() +} + +// displayArtifactMetadata displays artifact metadata if available +func (p *streamEventProcessor) displayArtifactMetadata(metadata map[string]interface{}) { + if len(metadata) > 0 { + fmt.Printf(" %s\n", prefixMetadata) + for key, value := range metadata { + fmt.Printf(" %s: %v\n", key, value) + } + } +} - switch *operation { - case "send": - handleSendOperation(ctx, a2aClient, *taskID, *message, *idempotencyKey) - case "get": - handleGetOperation(ctx, a2aClient, *taskID) - case "cancel": - handleCancelOperation(ctx, a2aClient, *taskID) - case "stream": - handleStreamOperation(ctx, a2aClient, *taskID, *message) +// getStatusPrefix returns an appropriate prefix for the task status +func getStatusPrefix(state protocol.TaskState) string { + switch state { + case protocol.TaskStateWorking: + return "WORKING" + case protocol.TaskStateCompleted: + return "SUCCESS" + case protocol.TaskStateFailed: + return "FAILED" + case protocol.TaskStateCanceled: + return "CANCELLED" default: - log.Fatalf("Unknown operation: %s", *operation) + return "STATUS" } } + +func boolPtr(b bool) *bool { + return &b +} diff --git a/taskmanager/redis/example/server/main.go b/taskmanager/redis/example/server/main.go index 4a72412..3eca064 100644 --- a/taskmanager/redis/example/server/main.go +++ b/taskmanager/redis/example/server/main.go @@ -4,7 +4,8 @@ // // trpc-a2a-go is licensed under the Apache License Version 2.0. -// Package main provides a simple A2A server using the Redis task manager. +// Package main provides a Redis TaskManager example server that demonstrates +// how to use the Redis-based task manager for processing text conversion tasks. package main import ( @@ -12,145 +13,355 @@ import ( "flag" "fmt" "log" + "os" "strings" "time" "github.com/redis/go-redis/v9" - "trpc.group/trpc-go/trpc-a2a-go/protocol" "trpc.group/trpc-go/trpc-a2a-go/server" "trpc.group/trpc-go/trpc-a2a-go/taskmanager" - redismgr "trpc.group/trpc-go/trpc-a2a-go/taskmanager/redis" + redisTaskManager "trpc.group/trpc-go/trpc-a2a-go/taskmanager/redis" ) -// DemoTaskProcessor implements TaskProcessor for our demo server. -type DemoTaskProcessor struct{} +const ( + // Default configuration values + defaultRedisAddress = "localhost:6379" + defaultServerAddress = ":8080" + + // Processing step timing constants + analysisDelay = 500 * time.Millisecond + processingDelay = 700 * time.Millisecond + conversionDelay = 500 * time.Millisecond + artifactCreationDelay = 300 * time.Millisecond + + // Processing step messages + msgStarting = "[STARTING] Initializing text conversion process..." + msgAnalyzing = "[ANALYZING] Processing input text (%d characters)..." + msgProcessing = "[PROCESSING] Converting text to lowercase..." + msgArtifact = "[ARTIFACT] Creating result artifact..." + msgCompleted = "[COMPLETED] Text processing finished! Original: '%s' -> Lowercase: '%s'" + + // Server information + serverName = "Text Case Converter" + serverDescription = "A simple agent that converts text to lowercase using Redis storage" + serverVersion = "1.0.0" + organizationName = "Redis TaskManager Example" + + // Skill information + skillID = "text_to_lower" + skillName = "Text to Lowercase" + skillDescription = "Convert any text to lowercase" +) -// Process implements the task processing logic. -func (p *DemoTaskProcessor) Process( +var ( + // Skill configuration + skillTags = []string{"text", "conversion", "lowercase"} + skillExamples = []string{"Hello World!", "THIS IS UPPERCASE", "MiXeD cAsE tExT"} + inputOutputModes = []string{"text"} +) + +// ToLowerProcessor implements a simple text processing service that converts text to lowercase +type ToLowerProcessor struct{} + +// ProcessMessage processes incoming messages by converting text to lowercase. +// It supports both streaming and non-streaming modes of operation. +func (p *ToLowerProcessor) ProcessMessage( ctx context.Context, + message protocol.Message, + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { + log.Printf("Processing message: %s", message.MessageID) + + // Extract text from message parts + inputText := extractTextFromMessage(message) + if inputText == "" { + return &taskmanager.MessageProcessingResult{ + Result: &protocol.Message{ + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{ + protocol.NewTextPart("Error: No text found in message"), + }, + }, + }, nil + } + + if options.Streaming { + return p.processStreamingMode(inputText, message.ContextID, handle) + } + + return p.processNonStreamingMode(inputText), nil +} + +// extractTextFromMessage extracts text content from message parts +func extractTextFromMessage(message protocol.Message) string { + var inputText string + for _, part := range message.Parts { + if textPart, ok := part.(*protocol.TextPart); ok { + inputText += textPart.Text + } + } + return inputText +} + +// processStreamingMode handles streaming processing with task updates +func (p *ToLowerProcessor) processStreamingMode( + inputText string, + contextID *string, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { + // Build task for streaming mode + taskID, err := handle.BuildTask(nil, contextID) + if err != nil { + return nil, fmt.Errorf("failed to build task: %w", err) + } + + // Subscribe to the task + subscriber, err := handle.SubScribeTask(&taskID) + if err != nil { + return nil, fmt.Errorf("failed to subscribe to task: %w", err) + } + + // Process asynchronously + go p.processTextAsync(inputText, taskID, handle) + + return &taskmanager.MessageProcessingResult{ + StreamingEvents: subscriber, + }, nil +} + +// processNonStreamingMode handles direct processing without streaming +func (p *ToLowerProcessor) processNonStreamingMode(inputText string) *taskmanager.MessageProcessingResult { + result := strings.ToLower(inputText) + + response := &protocol.Message{ + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{ + protocol.NewTextPart(result), + }, + } + + return &taskmanager.MessageProcessingResult{ + Result: response, + } +} + +func (p *ToLowerProcessor) processTextAsync( + inputText string, taskID string, - initialMsg protocol.Message, - handle taskmanager.TaskHandle, -) error { - // First, update the status to show we're working - if err := handle.UpdateStatus(protocol.TaskStateWorking, nil); err != nil { - return fmt.Errorf("failed to update status to working: %w", err) - } - - // Extract the user message - var userMessage string - if len(initialMsg.Parts) > 0 { - if textPart, ok := initialMsg.Parts[0].(protocol.TextPart); ok { - userMessage = textPart.Text + handle taskmanager.TaskHandler, +) { + defer func() { + err := handle.CleanTask(&taskID) + if err != nil { + log.Printf("Failed to clean task: %v", err) } + }() + + // Step 1: Starting processing + err := handle.UpdateTaskState(&taskID, protocol.TaskStateWorking, &protocol.Message{ + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{ + protocol.NewTextPart(msgStarting), + }, + }) + if err != nil { + log.Printf("Failed to update task state: %v", err) + return } - // Log the task - log.Printf("Processing task %s: %s", taskID, userMessage) + // Simulate analysis phase + time.Sleep(analysisDelay) - // Simulate some work - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(1 * time.Second): - // Processed + // Step 2: Analysis phase + err = handle.UpdateTaskState(&taskID, protocol.TaskStateWorking, &protocol.Message{ + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{ + protocol.NewTextPart(fmt.Sprintf(msgAnalyzing, len(inputText))), + }, + }) + if err != nil { + log.Printf("Failed to update task state: %v", err) + return } - // Create a response based on the user message - response := fmt.Sprintf("Processed: %s", userMessage) - if strings.Contains(strings.ToLower(userMessage), "error") { - return fmt.Errorf("simulated error requested in message") + // Simulate processing phase + time.Sleep(processingDelay) + + // Step 3: Processing phase + err = handle.UpdateTaskState(&taskID, protocol.TaskStateWorking, &protocol.Message{ + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{ + protocol.NewTextPart(msgProcessing), + }, + }) + if err != nil { + log.Printf("Failed to update task state: %v", err) + return } - // Add an artifact + // Simulate actual processing + time.Sleep(conversionDelay) + + // Process the text + result := strings.ToLower(inputText) + + // Step 4: Creating artifact + err = handle.UpdateTaskState(&taskID, protocol.TaskStateWorking, &protocol.Message{ + Role: protocol.MessageRoleAgent, + Parts: []protocol.Part{ + protocol.NewTextPart(msgArtifact), + }, + }) + if err != nil { + log.Printf("Failed to update task state: %v", err) + return + } + + // Create an artifact with the processed text artifact := protocol.Artifact{ - Name: strPtr("result"), - Parts: []protocol.Part{protocol.NewTextPart(response)}, - Index: 0, + ArtifactID: protocol.GenerateArtifactID(), + Name: stringPtr(skillName), + Description: stringPtr(skillDescription), + Parts: []protocol.Part{ + protocol.NewTextPart(result), + }, + Metadata: map[string]interface{}{ + "operation": skillID, + "originalText": inputText, + "originalLength": len(inputText), + "resultLength": len(result), + "processedAt": time.Now().UTC().Format(time.RFC3339), + "processingTime": "1.7s", + }, } - lastChunk := true - artifact.LastChunk = &lastChunk - if err := handle.AddArtifact(artifact); err != nil { - return fmt.Errorf("failed to add artifact: %w", err) + // Add artifact to task + if err := handle.AddArtifact(&taskID, artifact, true, false); err != nil { + log.Printf("Failed to add artifact: %v", err) } - // Complete with a success message - successMsg := &protocol.Message{ + // Small delay to show artifact creation + time.Sleep(artifactCreationDelay) + + // Send final status with result message + finalMessage := &protocol.Message{ Role: protocol.MessageRoleAgent, Parts: []protocol.Part{ - protocol.NewTextPart(fmt.Sprintf("Task completed: %s", userMessage)), + protocol.NewTextPart(fmt.Sprintf(msgCompleted, inputText, result)), }, } - if err := handle.UpdateStatus(protocol.TaskStateCompleted, successMsg); err != nil { - return fmt.Errorf("failed to update final status: %w", err) - } - return nil + // Update task to completed state + err = handle.UpdateTaskState(&taskID, protocol.TaskStateCompleted, finalMessage) + if err != nil { + log.Printf("Failed to complete task: %v", err) + } } -// Helper function to create string pointers -func strPtr(s string) *string { +func stringPtr(s string) *string { return &s } +func boolPtr(b bool) *bool { + return &b +} + func main() { // Parse command line flags - port := flag.Int("port", 8080, "Server port") - redisAddr := flag.String("redis", "localhost:6379", "Redis server address") - redisPassword := flag.String("redis-pass", "", "Redis password") - redisDB := flag.Int("redis-db", 0, "Redis database") + var redisAddr = flag.String("redis_addr", defaultRedisAddress, "Redis server address") + var serverAddr = flag.String("addr", defaultServerAddress, "Server listen address (e.g., :8080 or localhost:8080)") + var help = flag.Bool("help", false, "Show help message") + var version = flag.Bool("version", false, "Show version information") + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Text Case Converter Server - Redis TaskManager Example\n\n") + fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "Options:\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nExamples:\n") + fmt.Fprintf(os.Stderr, " %s # Use default settings\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s --redis_addr localhost:6380 # Custom Redis port\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s --addr :9000 # Custom server port\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s --redis_addr redis.example.com:6379 --addr :8080\n", os.Args[0]) + } + flag.Parse() - log.Printf("Starting A2A server with Redis task manager on port %d", *port) - log.Printf("Using Redis at %s (DB: %d)", *redisAddr, *redisDB) + if *help { + flag.Usage() + os.Exit(0) + } + + if *version { + fmt.Println("Text Case Converter Server v1.0.0") + fmt.Println("Redis TaskManager Example") + os.Exit(0) + } - // Create a task processor - processor := &DemoTaskProcessor{} + // Create Redis client + rdb := redis.NewClient(&redis.Options{ + Addr: *redisAddr, + Password: "", // no password + DB: 0, // default DB + }) - // Configure Redis client - redisOptions := &redis.UniversalOptions{ - Addrs: []string{*redisAddr}, - Password: *redisPassword, - DB: *redisDB, + // Test Redis connection + ctx := context.Background() + if err := rdb.Ping(ctx).Err(); err != nil { + log.Fatalf("Failed to connect to Redis at %s: %v", *redisAddr, err) } + log.Printf("Connected to Redis at %s successfully", *redisAddr) - // Create Redis task manager - manager, err := redismgr.NewRedisTaskManager( - redis.NewUniversalClient(redisOptions), - processor, - ) + // Create the toLower processor + processor := &ToLowerProcessor{} + + // Create Redis TaskManager + taskManager, err := redisTaskManager.NewTaskManager(rdb, processor) if err != nil { - log.Fatalf("Failed to create Redis task manager: %v", err) + log.Fatalf("Failed to create Redis TaskManager: %v", err) } - defer manager.Close() - - // Define the agent card with server metadata - description := "A simple A2A demo server using Redis task manager" - serverURL := fmt.Sprintf("http://localhost:%d", *port) - version := "1.0.0" + defer taskManager.Close() + // Create agent card agentCard := server.AgentCard{ - Name: "Redis Task Manager Demo", - Description: &description, - URL: serverURL, - Version: version, + Name: serverName, + Description: serverDescription, + URL: fmt.Sprintf("http://localhost%s/", *serverAddr), + Version: serverVersion, + Provider: &server.AgentProvider{ + Organization: organizationName, + }, Capabilities: server.AgentCapabilities{ - Streaming: true, - PushNotifications: false, - StateTransitionHistory: true, + Streaming: boolPtr(true), + PushNotifications: boolPtr(false), + }, + DefaultInputModes: inputOutputModes, + DefaultOutputModes: inputOutputModes, + Skills: []server.AgentSkill{ + { + ID: skillID, + Name: skillName, + Description: stringPtr(skillDescription), + Tags: skillTags, + Examples: skillExamples, + InputModes: inputOutputModes, + OutputModes: inputOutputModes, + }, }, - DefaultInputModes: []string{string(protocol.PartTypeText)}, - DefaultOutputModes: []string{string(protocol.PartTypeText)}, - Skills: []server.AgentSkill{}, // No specific skills } - // Create A2A server using the official server package - a2aServer, err := server.NewA2AServer(agentCard, manager, server.WithCORSEnabled(true)) + // Create HTTP server + agentServer, err := server.NewA2AServer(agentCard, taskManager) if err != nil { log.Fatalf("Failed to create A2A server: %v", err) } - - a2aServer.Start(fmt.Sprintf(":%d", *port)) + log.Printf("Starting Text Case Converter server on %s", *serverAddr) + log.Printf("Redis backend: %s", *redisAddr) + log.Printf("Try sending text like: 'Hello World!' and it will be converted to 'hello world!'") + err = agentServer.Start(*serverAddr) + if err != nil { + log.Fatalf("Failed to start A2A server: %v", err) + } } diff --git a/taskmanager/redis/go.mod b/taskmanager/redis/go.mod index bf29142..9dab19e 100644 --- a/taskmanager/redis/go.mod +++ b/taskmanager/redis/go.mod @@ -2,38 +2,30 @@ module trpc.group/trpc-go/trpc-a2a-go/taskmanager/redis go 1.23.0 -toolchain go1.23.7 - replace trpc.group/trpc-go/trpc-a2a-go => ../../ require ( - github.com/alicebob/miniredis/v2 v2.31.1 - github.com/redis/go-redis/v9 v9.7.3 - github.com/stretchr/testify v1.10.0 - trpc.group/trpc-go/trpc-a2a-go v0.0.0-00010101000000-000000000000 + github.com/redis/go-redis/v9 v9.10.0 + trpc.group/trpc-go/trpc-a2a-go v0.0.3 ) require ( - github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect github.com/lestrrat-go/httprc v1.0.6 // indirect github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/jwx/v2 v2.1.4 // indirect github.com/lestrrat-go/option v1.0.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/segmentio/asm v1.2.0 // indirect - github.com/yuin/gopher-lua v1.1.0 // indirect go.uber.org/multierr v1.10.0 // indirect go.uber.org/zap v1.27.0 // indirect golang.org/x/crypto v0.35.0 // indirect golang.org/x/oauth2 v0.29.0 // indirect golang.org/x/sys v0.30.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/taskmanager/redis/go.sum b/taskmanager/redis/go.sum index fd1a255..9311cfc 100644 --- a/taskmanager/redis/go.sum +++ b/taskmanager/redis/go.sum @@ -1,18 +1,9 @@ -github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0= -github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= -github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= -github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= -github.com/alicebob/miniredis/v2 v2.31.1 h1:7XAt0uUg3DtwEKW5ZAGa+K7FZV2DdKQo5K/6TTnfX8Y= -github.com/alicebob/miniredis/v2 v2.31.1/go.mod h1:UB/T2Uztp7MlFSDakaX1sTXUv5CASoprx0wulRT6HBg= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -24,9 +15,10 @@ github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= @@ -41,8 +33,8 @@ github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNB github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= -github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= +github.com/redis/go-redis/v9 v9.10.0 h1:FxwK3eV8p/CQa0Ch276C7u2d0eNC9kCmAYQ7mCXCzVs= +github.com/redis/go-redis/v9 v9.10.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -50,8 +42,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/yuin/gopher-lua v1.1.0 h1:BojcDhfyDWgU2f2TOzYK/g5p2gxMrku8oupLDqlnSqE= -github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= @@ -62,10 +52,8 @@ golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98= golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= -golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/taskmanager/redis/options.go b/taskmanager/redis/options.go deleted file mode 100644 index 13f792e..0000000 --- a/taskmanager/redis/options.go +++ /dev/null @@ -1,21 +0,0 @@ -// Tencent is pleased to support the open source community by making trpc-a2a-go available. -// -// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. -// -// trpc-a2a-go is licensed under the Apache License Version 2.0. - -package redis - -import ( - "time" -) - -// Option is a function that configures the RedisTaskManager. -type Option func(*TaskManager) - -// WithExpiration sets the expiration time for Redis keys. -func WithExpiration(expiration time.Duration) Option { - return func(o *TaskManager) { - o.expiration = expiration - } -} diff --git a/taskmanager/redis/push_notification.go b/taskmanager/redis/push_notification.go deleted file mode 100644 index f33a91b..0000000 --- a/taskmanager/redis/push_notification.go +++ /dev/null @@ -1,165 +0,0 @@ -// Tencent is pleased to support the open source community by making trpc-a2a-go available. -// -// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. -// -// trpc-a2a-go is licensed under the Apache License Version 2.0. - -// Package redis provides a Redis-based implementation of the A2A TaskManager interface. -package redis - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" - - "trpc.group/trpc-go/trpc-a2a-go/auth" - "trpc.group/trpc-go/trpc-a2a-go/log" - "trpc.group/trpc-go/trpc-a2a-go/protocol" -) - -// getPushNotificationConfig retrieves a push notification configuration for a task. -func (m *TaskManager) getPushNotificationConfig( - ctx context.Context, taskID string, -) (*protocol.PushNotificationConfig, error) { - key := pushNotificationPrefix + taskID - val, err := m.client.Get(ctx, key).Result() - if err != nil { - if err.Error() == "redis: nil" { - // No push notification configured for this task. - return nil, nil - } - return nil, err - } - var config protocol.PushNotificationConfig - if err := json.Unmarshal([]byte(val), &config); err != nil { - return nil, fmt.Errorf("failed to unmarshal push notification config: %w", err) - } - return &config, nil -} - -// sendPushNotification sends a notification to the registered webhook URL. -func (m *TaskManager) sendPushNotification( - ctx context.Context, taskID string, event protocol.TaskEvent, -) error { - // Get the notification config. - config, err := m.getPushNotificationConfig(ctx, taskID) - if err != nil { - return fmt.Errorf("failed to get push notification config: %w", err) - } - if config == nil { - // No push notification configured, nothing to do. - return nil - } - - // Prepare the notification payload. - eventType := "" - if _, isStatus := event.(protocol.TaskStatusUpdateEvent); isStatus { - eventType = protocol.EventTaskStatusUpdate - } else if _, isArtifact := event.(protocol.TaskArtifactUpdateEvent); isArtifact { - eventType = protocol.EventTaskArtifactUpdate - } else { - return fmt.Errorf("unsupported event type: %T", event) - } - - notification := map[string]interface{}{ - "jsonrpc": "2.0", - "method": "tasks/notifyEvent", - "params": map[string]interface{}{ - "id": taskID, - "eventType": eventType, - "event": event, - }, - } - - // Marshal the notification to JSON. - body, err := json.Marshal(notification) - if err != nil { - return fmt.Errorf("failed to marshal notification: %w", err) - } - - // Create HTTP request. - req, err := http.NewRequestWithContext( - ctx, http.MethodPost, config.URL, bytes.NewReader(body), - ) - if err != nil { - return fmt.Errorf("failed to create notification request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - // Add authentication if configured. - if config.Authentication != nil { - // Check for JWT authentication using the "bearer" scheme. - for _, scheme := range config.Authentication.Schemes { - if strings.EqualFold(scheme, "bearer") { - // Check if we have a JWKs URL in the metadata for JWT auth - if config.Metadata != nil { - if jwksURL, ok := config.Metadata["jwksUrl"].(string); ok && jwksURL != "" { - // Create a JWT token for the notification - if authHeader, err := m.createJWTAuthHeader(body, jwksURL); err == nil { - req.Header.Set("Authorization", authHeader) - break - } else { - log.Errorf("Failed to create JWT auth header: %v", err) - } - } - } - } - } - - // Add token if provided. - if config.Token != "" { - req.Header.Set("Authorization", "Bearer "+config.Token) - } - } - - // Send the notification. - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to send notification: %w", err) - } - defer resp.Body.Close() - - // Check for success. - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("notification failed with status %d: %s", resp.StatusCode, string(body)) - } - - return nil -} - -// createJWTAuthHeader creates a JWT authorization header for push notifications. -// It uses the jwksURL to configure a client to fetch the JWKs when needed. -func (m *TaskManager) createJWTAuthHeader(payload []byte, jwksURL string) (string, error) { - // Initialize the push auth helper if needed - m.pushAuthMu.Lock() - defer m.pushAuthMu.Unlock() - - if m.pushAuth == nil { - // This is the first request, so initialize the push auth helper - if err := m.initializePushAuth(jwksURL); err != nil { - return "", err - } - } - - // Create the authorization header - return m.pushAuth.CreateAuthorizationHeader(payload) -} - -// Initialize the push notification authenticator. -func (m *TaskManager) initializePushAuth(jwksURL string) error { - // Create a new authenticator and generate a key pair - m.pushAuth = auth.NewPushNotificationAuthenticator() - if err := m.pushAuth.GenerateKeyPair(); err != nil { - m.pushAuth = nil - return fmt.Errorf("failed to generate key pair: %w", err) - } - - return nil -} diff --git a/taskmanager/redis/redis.go b/taskmanager/redis/redis.go deleted file mode 100644 index 3633460..0000000 --- a/taskmanager/redis/redis.go +++ /dev/null @@ -1,733 +0,0 @@ -// Tencent is pleased to support the open source community by making trpc-a2a-go available. -// -// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. -// -// trpc-a2a-go is licensed under the Apache License Version 2.0. - -// Package redis provides a Redis-based implementation of the A2A TaskManager interface. -package redis - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "sync" - "time" - - "github.com/redis/go-redis/v9" - - "trpc.group/trpc-go/trpc-a2a-go/auth" - "trpc.group/trpc-go/trpc-a2a-go/log" - "trpc.group/trpc-go/trpc-a2a-go/protocol" - "trpc.group/trpc-go/trpc-a2a-go/taskmanager" -) - -const ( - // Key prefixes for Redis storage. - taskPrefix = "task:" - messagePrefix = "msg:" - pushNotificationPrefix = "push:" - subscriberPrefix = "sub:" - - // Default expiration time for Redis keys (30 days). - defaultExpiration = 30 * 24 * time.Hour -) - -// TaskManager provides a concrete, Redis-based implementation of the -// TaskManager interface. It persists tasks and messages in Redis. -// It requires a TaskProcessor to handle the actual agent logic. -// It is safe for concurrent use. -type TaskManager struct { - // processor is the agent logic processor. - processor taskmanager.TaskProcessor - // client is the Redis client. - client redis.UniversalClient - // expiration is the time after which Redis keys expire. - expiration time.Duration - - // subMu is a mutex for the Subscribers map. - subMu sync.RWMutex - // subscribers is a map of task IDs to subscriber channels. - subscribers map[string][]chan<- protocol.TaskEvent - - // cancelMu is a mutex for the cancels map. - cancelMu sync.RWMutex - // cancels is a map of task IDs to cancellation functions. - cancels map[string]context.CancelFunc - - // pushAuth is the push notification authenticator. - pushAuth *auth.PushNotificationAuthenticator - // pushAuthMu is a mutex for the pushAuth field. - pushAuthMu sync.Mutex -} - -// NewRedisTaskManager creates a new Redis-based TaskManager with the provided options. -func NewRedisTaskManager( - client redis.UniversalClient, - processor taskmanager.TaskProcessor, - opts ...Option, -) (*TaskManager, error) { - if processor == nil { - return nil, errors.New("task processor cannot be nil") - } - // Test connection. - if err := client.Ping(context.Background()).Err(); err != nil { - return nil, fmt.Errorf("failed to connect to Redis: %w", err) - } - expiration := defaultExpiration - manager := &TaskManager{ - processor: processor, - client: client, - expiration: expiration, - subscribers: make(map[string][]chan<- protocol.TaskEvent), - cancels: make(map[string]context.CancelFunc), - } - for _, opt := range opts { - opt(manager) - } - return manager, nil -} - -// redisTaskHandle implements the TaskHandle interface for Redis. -type redisTaskHandle struct { - taskID string - manager *TaskManager -} - -// UpdateStatus implements TaskHandle. -func (h *redisTaskHandle) UpdateStatus(state protocol.TaskState, msg *protocol.Message) error { - return h.manager.UpdateTaskStatus(h.taskID, state, msg) -} - -// AddArtifact implements TaskHandle -func (h *redisTaskHandle) AddArtifact(artifact protocol.Artifact) error { - return h.manager.AddArtifact(h.taskID, artifact) -} - -// IsStreamingRequest implements TaskHandle. -// It returns true if there are active subscribers for this task, -// indicating it was initiated with OnSendTaskSubscribe rather than OnSendTask. -func (h *redisTaskHandle) IsStreamingRequest() bool { - h.manager.subMu.RLock() - defer h.manager.subMu.RUnlock() - - subscribers, exists := h.manager.subscribers[h.taskID] - return exists && len(subscribers) > 0 -} - -// GetSessionID implements TaskHandle. -func (h *redisTaskHandle) GetSessionID() *string { - h.manager.subMu.RLock() - defer h.manager.subMu.RUnlock() - - task, err := h.manager.getTaskInternal(context.Background(), h.taskID) - if err != nil { - log.Errorf("Error getting session ID for task %s: %v", h.taskID, err) - return nil - } - - return task.SessionID -} - -// OnSendTask handles the creation or retrieval of a task and initiates synchronous processing. -func (m *TaskManager) OnSendTask(ctx context.Context, params protocol.SendTaskParams) (*protocol.Task, error) { - // Create or update task - _ = m.upsertTask(ctx, params) - // Store the initial message - m.storeMessage(ctx, params.ID, params.Message) - // Create a cancellable context for this specific task processing - taskCtx, cancel := context.WithCancel(ctx) - defer cancel() // Ensure context is cancelled eventually. - handle := &redisTaskHandle{ - taskID: params.ID, - manager: m, - } - // Set initial status to Working *before* calling Process. - if err := m.UpdateTaskStatus(params.ID, protocol.TaskStateWorking, nil); err != nil { - log.Errorf("Error setting initial Working status for task %s: %v", params.ID, err) - // Return the task state as it exists, but also the error. - latestTask, _ := m.getTaskInternal(ctx, params.ID) // Ignore get error for now. - return latestTask, fmt.Errorf("failed to set initial working status: %w", err) - } - // Delegate the actual processing to the injected processor (synchronously). - var processorErr error - if processorErr = m.processor.Process(taskCtx, params.ID, params.Message, handle); processorErr != nil { - log.Errorf("Processor failed for task %s: %v", params.ID, processorErr) - errMsg := &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart(processorErr.Error())}, - } - // Log update error while still handling the processor error. - if updateErr := m.UpdateTaskStatus(params.ID, protocol.TaskStateFailed, errMsg); updateErr != nil { - log.Errorf("Failed to update task %s status to failed: %v", params.ID, updateErr) - } - } - // Return the latest task state after processing. - finalTask, getErr := m.getTaskInternal(ctx, params.ID) - if getErr != nil { - // Only return an error if we couldn't retrieve the task state. - return nil, fmt.Errorf("failed to get final task state: %w", getErr) - } - // Don't return the processor error, as it's already reflected in the task state. - return finalTask, nil -} - -// OnSendTaskSubscribe creates a new task and returns a channel for receiving TaskEvent updates. -func (m *TaskManager) OnSendTaskSubscribe( - ctx context.Context, - params protocol.SendTaskParams, -) (<-chan protocol.TaskEvent, error) { - // Create a new task or update an existing one. - task := m.upsertTask(ctx, params) - // Store the message that came with the request. - m.storeMessage(ctx, params.ID, params.Message) - // Create event channel for this specific subscriber. - eventChan := make(chan protocol.TaskEvent, 10) // Buffered to prevent blocking sends. - m.addSubscriber(params.ID, eventChan) - // Create a cancellable context for the processor. - processorCtx, cancel := context.WithCancel(ctx) - // Store the cancel function. - m.cancelMu.Lock() - m.cancels[params.ID] = cancel - m.cancelMu.Unlock() - // Set initial state if new (submitted -> working). - // This will generate the first event for subscribers. - if task.Status.State == protocol.TaskStateSubmitted { - if err := m.UpdateTaskStatus(params.ID, protocol.TaskStateWorking, nil); err != nil { - m.removeSubscriber(params.ID, eventChan) - close(eventChan) - return nil, err - } - } - // Start the processor in a goroutine. - go func() { - // Create a handle for the processor to interact with the task. - handle := &redisTaskHandle{ - taskID: params.ID, - manager: m, - } - log.Debugf("SSE Processor started for task %s", params.ID) - var err error - if err = m.processor.Process(processorCtx, params.ID, params.Message, handle); err != nil { - log.Errorf("Processor failed for task %s in subscribe: %v", params.ID, err) - if processorCtx.Err() != context.Canceled { - // Only update to failed if not already cancelled. - errMsg := &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart(err.Error())}, - } - if updateErr := m.UpdateTaskStatus( - params.ID, - protocol.TaskStateFailed, - errMsg, - ); updateErr != nil { - log.Errorf("Failed to update task %s status to failed: %v", params.ID, updateErr) - } - } - } - // Clean up the context regardless of how we finish. - m.cancelMu.Lock() - delete(m.cancels, params.ID) - m.cancelMu.Unlock() - log.Debugf("Processor finished for task %s in subscribe (Error: %v). Goroutine exiting.", params.ID, err) - // Close event channel and clean up subscriber. - log.Debugf("Closing event channel and removing subscriber for task %s.", params.ID) - m.removeSubscriber(params.ID, eventChan) - close(eventChan) - }() - // Return the channel for events. - return eventChan, nil -} - -// OnGetTask retrieves the current state of a task. -func (m *TaskManager) OnGetTask( - ctx context.Context, - params protocol.TaskQueryParams, -) (*protocol.Task, error) { - task, err := m.getTaskInternal(ctx, params.ID) - if err != nil { - return nil, err - } - // Optionally include message history if requested. - if params.HistoryLength != nil && *params.HistoryLength > 0 { - history, err := m.getMessageHistory(ctx, params.ID, *params.HistoryLength) - if err != nil { - log.Warnf("Failed to retrieve message history for task %s: %v", params.ID, err) - // Continue without history rather than failing the whole request - } else { - task.History = history - } - } - return task, nil -} - -// OnCancelTask requests the cancellation of an ongoing task. -func (m *TaskManager) OnCancelTask( - ctx context.Context, - params protocol.TaskIDParams, -) (*protocol.Task, error) { - task, err := m.getTaskInternal(ctx, params.ID) - if err != nil { - return nil, err - } - // Check if task is already in a final state. - if isFinalState(task.Status.State) { - return task, taskmanager.ErrTaskFinalState(params.ID, task.Status.State) - } - var cancelFound bool - m.cancelMu.Lock() - cancel, exists := m.cancels[params.ID] - if exists { - cancel() // Call the cancel function. - cancelFound = true - // Don't delete the context here - let the processor goroutine clean up. - } - m.cancelMu.Unlock() - // If no cancellation function was found, log a warning. - if !cancelFound { - log.Warnf("Warning: No cancellation function found for task %s", params.ID) - } - // Create a cancellation message. - cancelMsg := &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{ - protocol.NewTextPart(fmt.Sprintf("Task %s was canceled by user request", params.ID)), - }, - } - // Update state to Cancelled. - if err := m.UpdateTaskStatus(params.ID, protocol.TaskStateCanceled, cancelMsg); err != nil { - log.Errorf("Error updating status to Cancelled for task %s: %v", params.ID, err) - return nil, err - } - // Fetch the updated task state to return. - updatedTask, err := m.getTaskInternal(ctx, params.ID) - if err != nil { - return nil, fmt.Errorf("failed to get task %s after cancellation update: %w", params.ID, err) - } - - return updatedTask, nil -} - -// OnPushNotificationSet configures push notifications for a specific task -func (m *TaskManager) OnPushNotificationSet( - ctx context.Context, - params protocol.TaskPushNotificationConfig, -) (*protocol.TaskPushNotificationConfig, error) { - // Check if task exists. - _, err := m.getTaskInternal(ctx, params.ID) - if err != nil { - return nil, err - } - // Store the push notification configuration. - pushKey := pushNotificationPrefix + params.ID - configBytes, err := json.Marshal(params.PushNotificationConfig) - if err != nil { - return nil, fmt.Errorf("failed to serialize push notification config: %w", err) - } - if err := m.client.Set(ctx, pushKey, configBytes, m.expiration).Err(); err != nil { - return nil, fmt.Errorf("failed to store push notification config: %w", err) - } - log.Infof("Set push notification for task %s to URL: %s", params.ID, params.PushNotificationConfig.URL) - // Return the stored configuration as confirmation. - return ¶ms, nil -} - -// OnPushNotificationGet retrieves the push notification configuration for a task. -func (m *TaskManager) OnPushNotificationGet( - ctx context.Context, - params protocol.TaskIDParams, -) (*protocol.TaskPushNotificationConfig, error) { - // Check if task exists. - _, err := m.getTaskInternal(ctx, params.ID) - if err != nil { - return nil, err - } - // Retrieve the push notification configuration. - pushKey := pushNotificationPrefix + params.ID - configBytes, err := m.client.Get(ctx, pushKey).Bytes() - if err != nil { - if err == redis.Nil { - return nil, taskmanager.ErrPushNotificationNotConfigured(params.ID) - } - return nil, fmt.Errorf("failed to retrieve push notification config: %w", err) - } - var config protocol.PushNotificationConfig - if err := json.Unmarshal(configBytes, &config); err != nil { - return nil, fmt.Errorf("failed to deserialize push notification config: %w", err) - } - result := &protocol.TaskPushNotificationConfig{ - ID: params.ID, - PushNotificationConfig: config, - } - return result, nil -} - -// OnResubscribe reestablishes an SSE stream for an existing task. -func (m *TaskManager) OnResubscribe( - ctx context.Context, - params protocol.TaskIDParams, -) (<-chan protocol.TaskEvent, error) { - task, err := m.getTaskInternal(ctx, params.ID) - if err != nil { - return nil, err - } - // Create a channel for events. - eventChan := make(chan protocol.TaskEvent) - // For tasks in final state, just send a status update event and close. - if isFinalState(task.Status.State) { - go func() { - // Send a task status update event - event := protocol.TaskStatusUpdateEvent{ - ID: task.ID, - Status: task.Status, - Final: true, - } - select { - case eventChan <- event: - // Successfully sent final status. - log.Debugf("Sent final status to resubscribed client for task %s: %s", task.ID, task.Status.State) - case <-ctx.Done(): - // Context was canceled. - log.Debugf("Context done before sending status for task %s", task.ID) - } - // Close the channel to signal no more events. - close(eventChan) - }() - return eventChan, nil - } - // For tasks still in progress, add this as a subscriber. - m.addSubscriber(params.ID, eventChan) - // Ensure we remove the subscriber when the context is canceled. - go func() { - <-ctx.Done() - m.removeSubscriber(params.ID, eventChan) - // Don't close the channel here - that should happen in the task processing goroutine. - }() - // Send the current status as the first event. - go func() { - event := protocol.TaskStatusUpdateEvent{ - ID: task.ID, - Status: task.Status, - Final: isFinalState(task.Status.State), - } - select { - case eventChan <- event: - // Successfully sent initial status. - log.Debugf("Sent initial status to resubscribed client for task %s: %s", task.ID, task.Status.State) - case <-ctx.Done(): - // Context was canceled. - log.Debugf("Context done before sending initial status for task %s", task.ID) - m.removeSubscriber(params.ID, eventChan) - close(eventChan) - } - }() - return eventChan, nil -} - -// UpdateTaskStatus updates the task's state and notifies subscribers. -func (m *TaskManager) UpdateTaskStatus( - taskID string, - state protocol.TaskState, - message *protocol.Message, -) error { - ctx := context.Background() - task, err := m.getTaskInternal(ctx, taskID) - if err != nil { - log.Warnf("Warning: UpdateTaskStatus called for non-existent task %s", taskID) - return err - } - // Update status fields. - task.Status = protocol.TaskStatus{ - State: state, - Message: message, - Timestamp: time.Now().UTC().Format(time.RFC3339), - } - // Store updated task. - taskKey := taskPrefix + taskID - taskBytes, err := json.Marshal(task) - if err != nil { - return fmt.Errorf("failed to serialize task: %w", err) - } - if err := m.client.Set(ctx, taskKey, taskBytes, m.expiration).Err(); err != nil { - return fmt.Errorf("failed to update task status: %w", err) - } - // Store the message in history if provided. - if message != nil { - m.storeMessage(ctx, taskID, *message) - } - // Notify subscribers. - m.notifySubscribers(taskID, protocol.TaskStatusUpdateEvent{ - ID: taskID, - Status: task.Status, - Final: isFinalState(state), - }) - return nil -} - -// AddArtifact adds an artifact to the task and notifies subscribers. -func (m *TaskManager) AddArtifact(taskID string, artifact protocol.Artifact) error { - ctx := context.Background() - task, err := m.getTaskInternal(ctx, taskID) - if err != nil { - log.Warnf("Warning: AddArtifact called for non-existent task %s", taskID) - return err - } - // Append the artifact. - if task.Artifacts == nil { - task.Artifacts = make([]protocol.Artifact, 0, 1) - } - task.Artifacts = append(task.Artifacts, artifact) - // Store updated task - taskKey := taskPrefix + taskID - taskBytes, err := json.Marshal(task) - if err != nil { - return fmt.Errorf("failed to serialize task: %w", err) - } - if err := m.client.Set(ctx, taskKey, taskBytes, m.expiration).Err(); err != nil { - return fmt.Errorf("failed to update task artifacts: %w", err) - } - // Notify subscribers. - finalEvent := artifact.LastChunk != nil && *artifact.LastChunk - m.notifySubscribers(taskID, protocol.TaskArtifactUpdateEvent{ - ID: taskID, - Artifact: artifact, - Final: finalEvent, - }) - return nil -} - -// --- Internal Helper Methods --- - -// isFinalState checks if a TaskState represents a terminal state. -func isFinalState(state protocol.TaskState) bool { - return state == protocol.TaskStateCompleted || - state == protocol.TaskStateFailed || - state == protocol.TaskStateCanceled -} - -// getTaskInternal retrieves a task from Redis. -func (m *TaskManager) getTaskInternal(ctx context.Context, taskID string) (*protocol.Task, error) { - taskKey := taskPrefix + taskID - taskBytes, err := m.client.Get(ctx, taskKey).Bytes() - if err != nil { - if err == redis.Nil { - return nil, taskmanager.ErrTaskNotFound(taskID) - } - return nil, fmt.Errorf("failed to retrieve task from Redis: %w", err) - } - var task protocol.Task - if err := json.Unmarshal(taskBytes, &task); err != nil { - return nil, fmt.Errorf("failed to deserialize task: %w", err) - } - return &task, nil -} - -// upsertTask creates a new task or updates metadata if it already exists. -func (m *TaskManager) upsertTask(ctx context.Context, params protocol.SendTaskParams) *protocol.Task { - taskKey := taskPrefix + params.ID - // Try to get existing task. - existingTaskBytes, err := m.client.Get(ctx, taskKey).Bytes() - var task *protocol.Task - if err == nil { - // Task exists, deserialize it. - task = &protocol.Task{} - if err := json.Unmarshal(existingTaskBytes, task); err != nil { - log.Errorf("Failed to deserialize existing task %s: %v", params.ID, err) - // Fall back to creating a new task. - task = protocol.NewTask(params.ID, params.SessionID) - } - log.Debugf("Updating existing task %s", params.ID) - } else if err == redis.Nil { - // Task doesn't exist, create new one. - task = protocol.NewTask(params.ID, params.SessionID) - log.Infof("Created new task %s (Session: %v)", params.ID, params.SessionID) - } else { - // Redis error. - log.Errorf("Redis error when retrieving task %s: %v", params.ID, err) - // Fall back to creating a new task. - task = protocol.NewTask(params.ID, params.SessionID) - } - // Update metadata if provided. - if params.Metadata != nil { - if task.Metadata == nil { - task.Metadata = make(map[string]interface{}) - } - for k, v := range params.Metadata { - task.Metadata[k] = v - } - } - // Store the task. - taskBytes, err := json.Marshal(task) - if err != nil { - log.Errorf("Failed to serialize task %s: %v", params.ID, err) - return task - } - if err := m.client.Set(ctx, taskKey, taskBytes, m.expiration).Err(); err != nil { - log.Errorf("Failed to store task %s in Redis: %v", params.ID, err) - } - return task -} - -// storeMessage adds a message to the task's history in Redis. -func (m *TaskManager) storeMessage(ctx context.Context, taskID string, message protocol.Message) { - messagesKey := messagePrefix + taskID - // Create a copy of the message to store. - messageCopy := protocol.Message{ - Role: message.Role, - Metadata: message.Metadata, - } - if message.Parts != nil { - messageCopy.Parts = make([]protocol.Part, len(message.Parts)) - copy(messageCopy.Parts, message.Parts) - } - // Serialize the message. - messageBytes, err := json.Marshal(messageCopy) - if err != nil { - log.Errorf("Failed to serialize message for task %s: %v", taskID, err) - return - } - // Add the message to a Redis list. - if err := m.client.RPush(ctx, messagesKey, messageBytes).Err(); err != nil { - log.Errorf("Failed to store message for task %s in Redis: %v", taskID, err) - return - } - // Set expiration on the message list. - m.client.Expire(ctx, messagesKey, m.expiration) -} - -// getMessageHistory retrieves message history for a task. -func (m *TaskManager) getMessageHistory( - ctx context.Context, - taskID string, - limit int, -) ([]protocol.Message, error) { - messagesKey := messagePrefix + taskID - // Get the message count. - count, err := m.client.LLen(ctx, messagesKey).Result() - if err != nil { - if err == redis.Nil { - return nil, nil // No messages found. - } - return nil, fmt.Errorf("failed to get message count: %w", err) - } - // Calculate range for LRANGE (get the latest messages). - start := int64(0) - if count > int64(limit) { - start = count - int64(limit) - } - // Get messages. - messagesBytesRaw, err := m.client.LRange(ctx, messagesKey, start, count-1).Result() - if err != nil { - return nil, fmt.Errorf("failed to retrieve messages: %w", err) - } - // Deserialize messages. - messages := make([]protocol.Message, 0, len(messagesBytesRaw)) - for _, msgBytes := range messagesBytesRaw { - var msg protocol.Message - if err := json.Unmarshal([]byte(msgBytes), &msg); err != nil { - log.Errorf("Failed to deserialize message for task %s: %v", taskID, err) - continue // Skip invalid messages. - } - messages = append(messages, msg) - } - return messages, nil -} - -// addSubscriber adds a channel to the list of subscribers for a task. -func (m *TaskManager) addSubscriber(taskID string, ch chan<- protocol.TaskEvent) { - m.subMu.Lock() - defer m.subMu.Unlock() - // If the task has no subscribers, create a new list. - if _, exists := m.subscribers[taskID]; !exists { - m.subscribers[taskID] = make([]chan<- protocol.TaskEvent, 0, 1) - } - // Add the new subscriber. - m.subscribers[taskID] = append(m.subscribers[taskID], ch) - log.Debugf("Added subscriber for task %s", taskID) -} - -// removeSubscriber removes a specific channel from the list of subscribers for a task. -func (m *TaskManager) removeSubscriber(taskID string, ch chan<- protocol.TaskEvent) { - m.subMu.Lock() - defer m.subMu.Unlock() - channels, exists := m.subscribers[taskID] - if !exists { - return // No subscribers for this task. - } - // Filter out the channel to remove. - var newChannels []chan<- protocol.TaskEvent - for _, existingCh := range channels { - if existingCh != ch { - newChannels = append(newChannels, existingCh) - } - } - // If there are no subscribers, remove the task from the map. - if len(newChannels) == 0 { - delete(m.subscribers, taskID) - } else { - m.subscribers[taskID] = newChannels - } - log.Debugf("Removed subscriber for task %s", taskID) -} - -// notifySubscribers sends an event to all current subscribers of a task. -func (m *TaskManager) notifySubscribers(taskID string, event protocol.TaskEvent) { - m.subMu.RLock() - subs, exists := m.subscribers[taskID] - if !exists || len(subs) == 0 { - m.subMu.RUnlock() - return // No subscribers to notify. - } - // Copy the slice of channels under read lock - subsCopy := make([]chan<- protocol.TaskEvent, len(subs)) - copy(subsCopy, subs) - m.subMu.RUnlock() - log.Debugf("Notifying %d subscribers for task %s (Event Type: %T, Final: %t)", - len(subsCopy), taskID, event, event.IsFinal()) - // Send events outside the lock. - for _, ch := range subsCopy { - // Use a select with a default case for a non-blocking send. - select { - case ch <- event: - // Event sent successfully. - default: - // Channel buffer is full or channel is closed. - log.Warnf("Warning: Dropping event for task %s subscriber - channel full or closed.", taskID) - } - } -} - -// Close closes the Redis client and cleans up resources. -func (m *TaskManager) Close() error { - // Cancel all active contexts. - m.cancelMu.Lock() - for _, cancel := range m.cancels { - cancel() - } - m.cancels = make(map[string]context.CancelFunc) - m.cancelMu.Unlock() - - // Close all subscriber channels. - m.subMu.Lock() - for taskID, channels := range m.subscribers { - for _, ch := range channels { - // Try to notify of closing but don't block - select { - case ch <- protocol.TaskStatusUpdateEvent{ - ID: taskID, - Status: protocol.TaskStatus{ - State: protocol.TaskStateUnknown, - Timestamp: time.Now().UTC().Format(time.RFC3339), - }, - Final: true, - }: - default: - } - } - } - m.subscribers = make(map[string][]chan<- protocol.TaskEvent) - m.subMu.Unlock() - // Close the Redis client. - return m.client.Close() -} diff --git a/taskmanager/redis/redis_manager.go b/taskmanager/redis/redis_manager.go new file mode 100644 index 0000000..04c5dce --- /dev/null +++ b/taskmanager/redis/redis_manager.go @@ -0,0 +1,673 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + +// Package redis provides a Redis-based implementation of the A2A TaskManager interface. +package redis + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "github.com/redis/go-redis/v9" + "trpc.group/trpc-go/trpc-a2a-go/log" + "trpc.group/trpc-go/trpc-a2a-go/protocol" + "trpc.group/trpc-go/trpc-a2a-go/taskmanager" +) + +const ( + // Key prefixes for Redis storage. + messagePrefix = "msg:" + conversationPrefix = "conv:" + taskPrefix = "task:" + pushNotificationPrefix = "push:" + subscriberPrefix = "sub:" + + // Default expiration time for Redis keys (30 days). + defaultExpiration = 1 * time.Hour + + // Default configuration values. + defaultMaxHistoryLength = 100 + defaultTaskSubscriberBufferSize = 10 +) + +// TaskManager provides a concrete, Redis-based implementation of the +// TaskManager interface. It persists messages, conversations, and tasks in Redis. +// It requires a MessageProcessor to handle the actual agent logic. +// It is safe for concurrent use. +type TaskManager struct { + // processor is the user-provided message processor. + processor taskmanager.MessageProcessor + // client is the Redis client. + client *redis.Client + // expiration is the time after which Redis keys expire. + expiration time.Duration + + // subMu is a mutex for the subscribers map. + subMu sync.RWMutex + // subscribers is a map of task IDs to subscriber channels. + subscribers map[string][]*TaskSubscriber + + // cancelMu is a mutex for the cancels map. + cancelMu sync.RWMutex + // cancels is a map of task IDs to cancellation functions. + cancels map[string]context.CancelFunc + + // Configuration options. + maxHistoryLength int // max history message count +} + +// NewTaskManager creates a new Redis-based TaskManager with the provided options. +func NewTaskManager( + client *redis.Client, + processor taskmanager.MessageProcessor, + opts ...TaskManagerOption, +) (*TaskManager, error) { + if processor == nil { + return nil, errors.New("processor cannot be nil") + } + if client == nil { + return nil, errors.New("redis client cannot be nil") + } + + // Test connection. + if err := client.Ping(context.Background()).Err(); err != nil { + return nil, fmt.Errorf("failed to connect to Redis: %w", err) + } + + // Apply default options + options := DefaultRedisTaskManagerOptions() + + // Apply user options + for _, opt := range opts { + opt(options) + } + + // Use expiration time from options + expiration := options.ExpireTime + + manager := &TaskManager{ + processor: processor, + client: client, + expiration: expiration, + subscribers: make(map[string][]*TaskSubscriber), + cancels: make(map[string]context.CancelFunc), + maxHistoryLength: options.MaxHistoryLength, + } + + return manager, nil +} + +// OnSendMessage handles the message/send request. +func (m *TaskManager) OnSendMessage( + ctx context.Context, + request protocol.SendMessageParams, +) (*protocol.MessageResult, error) { + log.Debugf("RedisTaskManager: OnSendMessage for message %s", request.Message.MessageID) + + // Process the request message. + m.processRequestMessage(&request.Message) + + // Process configuration. + options := m.processConfiguration(request.Configuration, request.Metadata) + options.Streaming = false // non-streaming processing + + // Create MessageHandle. + handle := &taskHandler{ + manager: m, + messageID: request.Message.MessageID, + ctx: ctx, + } + + // Call the user's message processor. + result, err := m.processor.ProcessMessage(ctx, request.Message, options, handle) + if err != nil { + return nil, fmt.Errorf("message processing failed: %w", err) + } + + if result == nil { + return nil, fmt.Errorf("processor returned nil result") + } + + // Check if the user returned StreamingEvents for non-streaming request. + if result.StreamingEvents != nil { + log.Infof("User returned StreamingEvents for non-streaming request, ignoring") + } + + if result.Result == nil { + return nil, fmt.Errorf("processor returned nil result for non-streaming request") + } + + switch result.Result.(type) { + case *protocol.Task: + case *protocol.Message: + default: + return nil, fmt.Errorf("processor returned unsupported result type %T for SendMessage request", result.Result) + } + + if message, ok := result.Result.(*protocol.Message); ok { + var contextID string + if request.Message.ContextID != nil { + contextID = *request.Message.ContextID + } + m.processReplyMessage(contextID, message) + } + + return &protocol.MessageResult{Result: result.Result}, nil +} + +// OnSendMessageStream handles message/stream requests. +func (m *TaskManager) OnSendMessageStream( + ctx context.Context, + request protocol.SendMessageParams, +) (<-chan protocol.StreamingMessageEvent, error) { + log.Debugf("RedisTaskManager: OnSendMessageStream for message %s", request.Message.MessageID) + + m.processRequestMessage(&request.Message) + + // Process configuration. + options := m.processConfiguration(request.Configuration, request.Metadata) + options.Streaming = true // streaming mode + + // Create streaming MessageHandle. + handle := &taskHandler{ + manager: m, + messageID: request.Message.MessageID, + ctx: ctx, + } + + // Call user's message processor. + result, err := m.processor.ProcessMessage(ctx, request.Message, options, handle) + if err != nil { + return nil, fmt.Errorf("message processing failed: %w", err) + } + + if result == nil || result.StreamingEvents == nil { + return nil, fmt.Errorf("processor returned nil result") + } + + return result.StreamingEvents.Channel(), nil +} + +// OnGetTask handles the tasks/get request. +func (m *TaskManager) OnGetTask( + ctx context.Context, + params protocol.TaskQueryParams, +) (*protocol.Task, error) { + task, err := m.getTaskInternal(ctx, params.ID) + if err != nil { + return nil, err + } + + // If the request contains history length, fill the message history. + if params.HistoryLength != nil && *params.HistoryLength > 0 { + if task.ContextID != "" { + history, err := m.getConversationHistory(ctx, task.ContextID, *params.HistoryLength) + if err != nil { + log.Warnf("Failed to retrieve message history for task %s: %v", params.ID, err) + // Continue without history rather than failing the whole request. + } else { + task.History = history + } + } + } + + return task, nil +} + +// OnCancelTask handles the tasks/cancel request. +func (m *TaskManager) OnCancelTask( + ctx context.Context, + params protocol.TaskIDParams, +) (*protocol.Task, error) { + task, err := m.getTaskInternal(ctx, params.ID) + if err != nil { + return nil, err + } + + // Check if task is already in a final state. + if isFinalState(task.Status.State) { + return task, fmt.Errorf("task %s is already in final state: %s", params.ID, task.Status.State) + } + + var cancelFound bool + m.cancelMu.Lock() + cancel, exists := m.cancels[params.ID] + if exists { + cancel() // Call the cancel function. + cancelFound = true + // Don't delete the context here - let the processor goroutine clean up. + } + m.cancelMu.Unlock() + + // If no cancellation function was found, log a warning. + if !cancelFound { + log.Warnf("Warning: No cancellation function found for task %s", params.ID) + } + + // Update task state to Cancelled. + task.Status.State = protocol.TaskStateCanceled + task.Status.Timestamp = time.Now().UTC().Format(time.RFC3339) + + // Store updated task. + if err := m.storeTask(ctx, task); err != nil { + log.Errorf("Error storing cancelled task %s: %v", params.ID, err) + return nil, err + } + + // Clean up subscribers. + m.cleanSubscribers(params.ID) + + return task, nil +} + +// OnPushNotificationSet handles tasks/pushNotificationConfig/set requests. +func (m *TaskManager) OnPushNotificationSet( + ctx context.Context, + params protocol.TaskPushNotificationConfig, +) (*protocol.TaskPushNotificationConfig, error) { + // Check if task exists. + _, err := m.getTaskInternal(ctx, params.TaskID) + if err != nil { + return nil, err + } + + // Store the push notification configuration. + pushKey := pushNotificationPrefix + params.TaskID + configBytes, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("failed to serialize push notification config: %w", err) + } + + if err := m.client.Set(ctx, pushKey, configBytes, m.expiration).Err(); err != nil { + return nil, fmt.Errorf("failed to store push notification config: %w", err) + } + + log.Debugf("RedisTaskManager: Push notification config set for task %s", params.TaskID) + return ¶ms, nil +} + +// OnPushNotificationGet handles tasks/pushNotificationConfig/get requests. +func (m *TaskManager) OnPushNotificationGet( + ctx context.Context, + params protocol.TaskIDParams, +) (*protocol.TaskPushNotificationConfig, error) { + // Check if task exists. + _, err := m.getTaskInternal(ctx, params.ID) + if err != nil { + return nil, err + } + + // Retrieve the push notification configuration. + pushKey := pushNotificationPrefix + params.ID + configBytes, err := m.client.Get(ctx, pushKey).Bytes() + if err != nil { + return nil, fmt.Errorf("push notification config not found for task: %s", params.ID) + } + + var config protocol.TaskPushNotificationConfig + if err := json.Unmarshal(configBytes, &config); err != nil { + return nil, fmt.Errorf("failed to deserialize push notification config: %w", err) + } + + return &config, nil +} + +// OnResubscribe handles tasks/resubscribe requests. +func (m *TaskManager) OnResubscribe( + ctx context.Context, + params protocol.TaskIDParams, +) (<-chan protocol.StreamingMessageEvent, error) { + // Check if task exists. + _, err := m.getTaskInternal(ctx, params.ID) + if err != nil { + return nil, err + } + + subscriber := NewTaskSubscriber(params.ID, defaultTaskSubscriberBufferSize) + + // Add to subscribers list. + m.addSubscriber(params.ID, subscriber) + + return subscriber.Channel(), nil +} + +// OnSendTask deprecated method empty implementation. +func (m *TaskManager) OnSendTask(ctx context.Context, request protocol.SendTaskParams) (*protocol.Task, error) { + return nil, fmt.Errorf("OnSendTask is deprecated, use OnSendMessage instead") +} + +// OnSendTaskSubscribe deprecated method empty implementation. +func (m *TaskManager) OnSendTaskSubscribe(ctx context.Context, request protocol.SendTaskParams) (<-chan protocol.TaskEvent, error) { + return nil, fmt.Errorf("OnSendTaskSubscribe is deprecated, use OnSendMessageStream instead") +} + +// ============================================================================= +// Internal helper methods +// ============================================================================= + +// processConfiguration processes and normalizes configuration. +func (m *TaskManager) processConfiguration(config *protocol.SendMessageConfiguration, metadata map[string]interface{}) taskmanager.ProcessOptions { + result := taskmanager.ProcessOptions{ + Blocking: false, + HistoryLength: 0, + } + + if config == nil { + return result + } + + // Process Blocking configuration. + if config.Blocking != nil { + result.Blocking = *config.Blocking + } + + // Process HistoryLength configuration. + if config.HistoryLength != nil && *config.HistoryLength > 0 { + result.HistoryLength = *config.HistoryLength + } + + // Process PushNotificationConfig. + if config.PushNotificationConfig != nil { + result.PushNotificationConfig = config.PushNotificationConfig + } + + return result +} + +// processRequestMessage processes and stores the request message. +func (m *TaskManager) processRequestMessage(message *protocol.Message) { + if message.MessageID == "" { + message.MessageID = protocol.GenerateMessageID() + } + if message.ContextID != nil { + m.storeMessage(context.Background(), *message) + } +} + +// processReplyMessage processes and stores the reply message. +func (m *TaskManager) processReplyMessage(ctxID string, message *protocol.Message) { + message.ContextID = &ctxID + message.Role = protocol.MessageRoleAgent + if message.MessageID == "" { + message.MessageID = m.generateMessageID() + } + + // If contextID is not nil, store the conversation history. + if message.ContextID != nil { + m.storeMessage(context.Background(), *message) + } +} + +// generateMessageID generates a message ID. +func (m *TaskManager) generateMessageID() string { + return protocol.GenerateMessageID() +} + +// storeMessage stores a message in Redis and updates conversation history. +func (m *TaskManager) storeMessage(ctx context.Context, message protocol.Message) { + // Store the message. + msgKey := messagePrefix + message.MessageID + msgBytes, err := json.Marshal(message) + if err != nil { + log.Errorf("Failed to serialize message %s: %v", message.MessageID, err) + return + } + + if err := m.client.Set(ctx, msgKey, msgBytes, m.expiration).Err(); err != nil { + log.Errorf("Failed to store message %s in Redis: %v", message.MessageID, err) + return + } + + // If the message has a contextID, add it to conversation history. + if message.ContextID != nil { + contextID := *message.ContextID + convKey := conversationPrefix + contextID + + // Add message ID to conversation history using Redis list. + if err := m.client.RPush(ctx, convKey, message.MessageID).Err(); err != nil { + log.Errorf("Failed to add message %s to conversation %s: %v", message.MessageID, contextID, err) + return + } + + // Set expiration on the conversation list. + m.client.Expire(ctx, convKey, m.expiration).Err() + + // Limit history length by trimming the list. + if err := m.client.LTrim(ctx, convKey, -int64(m.maxHistoryLength), -1).Err(); err != nil { + log.Errorf("Failed to trim conversation %s: %v", contextID, err) + } + } +} + +// getConversationHistory retrieves conversation history for a context. +func (m *TaskManager) getConversationHistory( + ctx context.Context, + contextID string, + length int, +) ([]protocol.Message, error) { + if contextID == "" { + return nil, nil + } + + convKey := conversationPrefix + contextID + + // Get the message count. + count, err := m.client.LLen(ctx, convKey).Result() + if err != nil { + return nil, nil // No messages found. + } + + // Calculate range for LRANGE (get the latest messages). + start := int64(0) + if count > int64(length) { + start = count - int64(length) + } + + // Get message IDs. + messageIDs, err := m.client.LRange(ctx, convKey, start, count-1).Result() + if err != nil { + return nil, fmt.Errorf("failed to retrieve message IDs: %w", err) + } + + // Retrieve messages. + messages := make([]protocol.Message, 0, len(messageIDs)) + for _, msgID := range messageIDs { + msgKey := messagePrefix + msgID + msgBytes, err := m.client.Get(ctx, msgKey).Bytes() + if err != nil { + log.Warnf("Message %s not found in Redis", msgID) + continue // Skip missing messages. + } + + var msg protocol.Message + if err := json.Unmarshal(msgBytes, &msg); err != nil { + log.Errorf("Failed to deserialize message %s: %v", msgID, err) + continue // Skip invalid messages. + } + + messages = append(messages, msg) + } + + return messages, nil +} + +// getTaskInternal retrieves a task from Redis. +func (m *TaskManager) getTaskInternal(ctx context.Context, taskID string) (*protocol.Task, error) { + taskKey := taskPrefix + taskID + taskBytes, err := m.client.Get(ctx, taskKey).Bytes() + if err != nil { + return nil, fmt.Errorf("task not found: %s", taskID) + } + + var task protocol.Task + if err := json.Unmarshal(taskBytes, &task); err != nil { + return nil, fmt.Errorf("failed to deserialize task: %w", err) + } + + return &task, nil +} + +// storeTask stores a task in Redis. +func (m *TaskManager) storeTask(ctx context.Context, task *protocol.Task) error { + taskKey := taskPrefix + task.ID + taskBytes, err := json.Marshal(task) + if err != nil { + return fmt.Errorf("failed to serialize task: %w", err) + } + + if err := m.client.Set(ctx, taskKey, taskBytes, m.expiration).Err(); err != nil { + return fmt.Errorf("failed to store task: %w", err) + } + + return nil +} + +// deleteTask deletes a task from Redis. +func (m *TaskManager) deleteTask(ctx context.Context, taskID string) error { + taskKey := taskPrefix + taskID + if err := m.client.Del(ctx, taskKey).Err(); err != nil { + return fmt.Errorf("failed to delete task: %w", err) + } + return nil +} + +// isFinalState checks if a TaskState represents a terminal state. +func isFinalState(state protocol.TaskState) bool { + return state == protocol.TaskStateCompleted || + state == protocol.TaskStateFailed || + state == protocol.TaskStateCanceled || + state == protocol.TaskStateRejected +} + +// addSubscriber adds a subscriber to the list. +func (m *TaskManager) addSubscriber(taskID string, sub *TaskSubscriber) { + m.subMu.Lock() + defer m.subMu.Unlock() + + if _, exists := m.subscribers[taskID]; !exists { + m.subscribers[taskID] = make([]*TaskSubscriber, 0) + } + m.subscribers[taskID] = append(m.subscribers[taskID], sub) + log.Debugf("Added subscriber for task %s", taskID) +} + +// cleanSubscribers cleans up all subscribers for a task. +func (m *TaskManager) cleanSubscribers(taskID string) { + m.subMu.Lock() + defer m.subMu.Unlock() + + if subs, exists := m.subscribers[taskID]; exists { + for _, sub := range subs { + sub.Close() + } + delete(m.subscribers, taskID) + log.Debugf("Cleaned subscribers for task %s", taskID) + } +} + +// notifySubscribers notifies all subscribers of a task. +func (m *TaskManager) notifySubscribers(taskID string, event protocol.StreamingMessageEvent) { + m.subMu.RLock() + subs, exists := m.subscribers[taskID] + if !exists || len(subs) == 0 { + m.subMu.RUnlock() + return + } + + subsCopy := make([]*TaskSubscriber, len(subs)) + copy(subsCopy, subs) + m.subMu.RUnlock() + + log.Debugf("Notifying %d subscribers for task %s (Event Type: %T)", len(subsCopy), taskID, event.Result) + + var failedSubscribers []*TaskSubscriber + + for _, sub := range subsCopy { + if sub.Closed() { + log.Debugf("Subscriber for task %s is already closed, marking for removal", taskID) + failedSubscribers = append(failedSubscribers, sub) + continue + } + + err := sub.Send(event) + if err != nil { + log.Warnf("Failed to send event to subscriber for task %s: %v", taskID, err) + failedSubscribers = append(failedSubscribers, sub) + } + } + + // Clean up failed or closed subscribers. + if len(failedSubscribers) > 0 { + m.cleanupFailedSubscribers(taskID, failedSubscribers) + } +} + +// cleanupFailedSubscribers cleans up failed or closed subscribers. +func (m *TaskManager) cleanupFailedSubscribers(taskID string, failedSubscribers []*TaskSubscriber) { + m.subMu.Lock() + defer m.subMu.Unlock() + + subs, exists := m.subscribers[taskID] + if !exists { + return + } + + // Filter out failed subscribers. + filteredSubs := make([]*TaskSubscriber, 0, len(subs)) + removedCount := 0 + + for _, sub := range subs { + shouldRemove := false + for _, failedSub := range failedSubscribers { + if sub == failedSub { + shouldRemove = true + removedCount++ + break + } + } + if !shouldRemove { + filteredSubs = append(filteredSubs, sub) + } + } + + if removedCount > 0 { + m.subscribers[taskID] = filteredSubs + log.Debugf("Removed %d failed subscribers for task %s", removedCount, taskID) + + // If there are no subscribers left, delete the entire entry. + if len(filteredSubs) == 0 { + delete(m.subscribers, taskID) + } + } +} + +// Close closes the Redis client and cleans up resources. +func (m *TaskManager) Close() error { + // Cancel all active contexts. + m.cancelMu.Lock() + for _, cancel := range m.cancels { + cancel() + } + m.cancels = make(map[string]context.CancelFunc) + m.cancelMu.Unlock() + + // Close all subscriber channels. + m.subMu.Lock() + for _, subscribers := range m.subscribers { + for _, sub := range subscribers { + sub.Close() + } + } + m.subscribers = make(map[string][]*TaskSubscriber) + m.subMu.Unlock() + + // Close the Redis client. + return m.client.Close() +} diff --git a/taskmanager/redis/redis_options.go b/taskmanager/redis/redis_options.go new file mode 100644 index 0000000..d28d889 --- /dev/null +++ b/taskmanager/redis/redis_options.go @@ -0,0 +1,63 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + +// Package redis provides configuration options for RedisTaskManager. +package redis + +import ( + "time" +) + +// TaskManagerOptions contains configuration options for RedisTaskManager. +type TaskManagerOptions struct { + // ExpireTime is the time after which Redis keys expire. + ExpireTime time.Duration + + // MaxHistoryLength is the maximum number of messages to keep in conversation history. + MaxHistoryLength int + + // TaskSubscriberBufferSize is the buffer size for task subscriber channels. + TaskSubscriberBufferSize int +} + +// DefaultRedisTaskManagerOptions returns the default configuration options. +func DefaultRedisTaskManagerOptions() *TaskManagerOptions { + return &TaskManagerOptions{ + ExpireTime: defaultExpiration, + MaxHistoryLength: defaultMaxHistoryLength, + TaskSubscriberBufferSize: defaultTaskSubscriberBufferSize, + } +} + +// TaskManagerOption defines a function type for configuring RedisTaskManager. +type TaskManagerOption func(*TaskManagerOptions) + +// WithExpireTime sets the expiration time for Redis keys. +func WithExpireTime(expireTime time.Duration) TaskManagerOption { + return func(opts *TaskManagerOptions) { + if expireTime > 0 { + opts.ExpireTime = expireTime + } + } +} + +// WithMaxHistoryLength sets the maximum number of messages to keep in conversation history. +func WithMaxHistoryLength(length int) TaskManagerOption { + return func(opts *TaskManagerOptions) { + if length > 0 { + opts.MaxHistoryLength = length + } + } +} + +// WithTaskSubscriberBufferSize sets the buffer size for task subscriber channels. +func WithTaskSubscriberBufferSize(size int) TaskManagerOption { + return func(opts *TaskManagerOptions) { + if size > 0 { + opts.TaskSubscriberBufferSize = size + } + } +} diff --git a/taskmanager/redis/redis_task_handle.go b/taskmanager/redis/redis_task_handle.go new file mode 100644 index 0000000..a3a161a --- /dev/null +++ b/taskmanager/redis/redis_task_handle.go @@ -0,0 +1,284 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + +// Package redis provides a Redis-based implementation of the A2A TaskManager interface. +package redis + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "trpc.group/trpc-go/trpc-a2a-go/log" + "trpc.group/trpc-go/trpc-a2a-go/protocol" + "trpc.group/trpc-go/trpc-a2a-go/taskmanager" +) + +// taskHandler implements TaskHandler interface for Redis. +type taskHandler struct { + manager *TaskManager + messageID string + ctx context.Context +} + +var _ taskmanager.TaskHandler = (*taskHandler)(nil) + +// BuildTask creates a new task and returns the task ID. +func (h *taskHandler) BuildTask(specificTaskID *string, contextID *string) (string, error) { + // If no taskID provided, generate one. + var actualTaskID string + if specificTaskID == nil || *specificTaskID == "" { + actualTaskID = protocol.GenerateTaskID() + } else { + actualTaskID = *specificTaskID + } + + // Check if task already exists. + _, err := h.manager.getTaskInternal(h.ctx, actualTaskID) + if err == nil { + // Task exists, return the existing task ID. + log.Warnf("Task %s already exists, returning existing task ID", actualTaskID) + return actualTaskID, nil + } + + var actualContextID string + if contextID == nil || *contextID == "" { + actualContextID = "" + } else { + actualContextID = *contextID + } + + // Create new context for cancellation. + _, cancel := context.WithCancel(context.Background()) + + // Create new task. + task := &protocol.Task{ + ID: actualTaskID, + ContextID: actualContextID, + Kind: protocol.KindTask, + Status: protocol.TaskStatus{ + State: protocol.TaskStateSubmitted, + Timestamp: time.Now().UTC().Format(time.RFC3339), + }, + Artifacts: make([]protocol.Artifact, 0), + History: make([]protocol.Message, 0), + Metadata: make(map[string]interface{}), + } + + // Store task in Redis. + if err := h.manager.storeTask(h.ctx, task); err != nil { + cancel() // Clean up the context. + return "", fmt.Errorf("failed to store task: %w", err) + } + + // Store the cancel function. + h.manager.cancelMu.Lock() + h.manager.cancels[actualTaskID] = cancel + h.manager.cancelMu.Unlock() + + log.Debugf("Created new task %s with context %s", actualTaskID, actualContextID) + + return actualTaskID, nil +} + +// UpdateTaskState updates the task's state and returns an error if failed. +func (h *taskHandler) UpdateTaskState( + taskID *string, + state protocol.TaskState, + message *protocol.Message, +) error { + if taskID == nil || *taskID == "" { + return fmt.Errorf("taskID cannot be nil or empty") + } + + task, err := h.manager.getTaskInternal(h.ctx, *taskID) + if err != nil { + log.Warnf("UpdateTaskState called for non-existent task %s", *taskID) + return fmt.Errorf("task not found: %s", *taskID) + } + + // Update task status. + task.Status = protocol.TaskStatus{ + State: state, + Message: message, + Timestamp: time.Now().UTC().Format(time.RFC3339), + } + + // Store updated task. + if err := h.manager.storeTask(h.ctx, task); err != nil { + return fmt.Errorf("failed to update task status: %w", err) + } + + log.Debugf("Updated task %s state to %s", *taskID, state) + + // Notify subscribers. + finalState := isFinalState(state) + event := &protocol.TaskStatusUpdateEvent{ + TaskID: *taskID, + ContextID: task.ContextID, + Status: task.Status, + Kind: protocol.KindTaskStatusUpdate, + Final: &finalState, + } + streamEvent := protocol.StreamingMessageEvent{Result: event} + h.manager.notifySubscribers(*taskID, streamEvent) + + if finalState { + // Clean up resources for final states. + h.manager.cleanSubscribers(*taskID) + h.manager.cancelMu.Lock() + if cancel, exists := h.manager.cancels[*taskID]; exists { + cancel() + delete(h.manager.cancels, *taskID) + } + h.manager.cancelMu.Unlock() + } + + return nil +} + +// AddArtifact adds an artifact to the specified task. +func (h *taskHandler) AddArtifact(taskID *string, artifact protocol.Artifact, isFinal bool, needMoreData bool) error { + if taskID == nil || *taskID == "" { + return fmt.Errorf("taskID cannot be nil or empty") + } + + task, err := h.manager.getTaskInternal(h.ctx, *taskID) + if err != nil { + return fmt.Errorf("task not found: %s", *taskID) + } + + // Append the artifact. + if task.Artifacts == nil { + task.Artifacts = make([]protocol.Artifact, 0, 1) + } + task.Artifacts = append(task.Artifacts, artifact) + + // Store updated task. + if err := h.manager.storeTask(h.ctx, task); err != nil { + return fmt.Errorf("failed to update task artifacts: %w", err) + } + + log.Debugf("Added artifact %s to task %s", artifact.ArtifactID, *taskID) + + // Notify subscribers. + event := &protocol.TaskArtifactUpdateEvent{ + TaskID: *taskID, + ContextID: task.ContextID, + Artifact: artifact, + Kind: protocol.KindTaskArtifactUpdate, + LastChunk: &isFinal, + Append: &needMoreData, + } + streamEvent := protocol.StreamingMessageEvent{Result: event} + h.manager.notifySubscribers(*taskID, streamEvent) + + return nil +} + +// SubScribeTask subscribes to the task and returns a TaskSubscriber. +func (h *taskHandler) SubScribeTask(taskID *string) (taskmanager.TaskSubscriber, error) { + if taskID == nil || *taskID == "" { + return nil, fmt.Errorf("taskID cannot be nil or empty") + } + + // Check if task exists. + _, err := h.manager.getTaskInternal(h.ctx, *taskID) + if err != nil { + return nil, fmt.Errorf("task not found: %s", *taskID) + } + + subscriber := NewTaskSubscriber(*taskID, defaultTaskSubscriberBufferSize) + h.manager.addSubscriber(*taskID, subscriber) + return subscriber, nil +} + +// GetTask returns the task by taskID as a CancellableTask. +func (h *taskHandler) GetTask(taskID *string) (taskmanager.CancellableTask, error) { + if taskID == nil || *taskID == "" { + return nil, fmt.Errorf("taskID cannot be nil or empty") + } + + task, err := h.manager.getTaskInternal(h.ctx, *taskID) + if err != nil { + return nil, err + } + + // Get the cancel function for the task if it exists. + h.manager.cancelMu.RLock() + cancel, exists := h.manager.cancels[*taskID] + h.manager.cancelMu.RUnlock() + + if !exists { + // Create a no-op cancel function if one doesn't exist. + cancel = func() {} + } + + return NewRedisCancellableTask(task, cancel), nil +} + +// CleanTask deletes the task and cleans up all associated resources. +func (h *taskHandler) CleanTask(taskID *string) error { + if taskID == nil || *taskID == "" { + return fmt.Errorf("taskID cannot be nil or empty") + } + + // Get the task first to verify it exists. + _, err := h.manager.getTaskInternal(h.ctx, *taskID) + if err != nil { + return fmt.Errorf("task not found: %s", *taskID) + } + + // Cancel the task context. + h.manager.cancelMu.Lock() + if cancel, exists := h.manager.cancels[*taskID]; exists { + cancel() + delete(h.manager.cancels, *taskID) + } + h.manager.cancelMu.Unlock() + + // Clean up subscribers. + h.manager.cleanSubscribers(*taskID) + + // Delete the task from Redis. + return h.manager.deleteTask(h.ctx, *taskID) +} + +// GetContextID returns the context ID of the current message, if any. +func (h *taskHandler) GetContextID() string { + msgKey := messagePrefix + h.messageID + msgBytes, err := h.manager.client.Get(h.ctx, msgKey).Bytes() + if err != nil { + return "" + } + + var msg protocol.Message + if err := json.Unmarshal(msgBytes, &msg); err != nil { + return "" + } + + if msg.ContextID != nil { + return *msg.ContextID + } + return "" +} + +// GetMessageHistory returns the conversation history for the current context. +func (h *taskHandler) GetMessageHistory() []protocol.Message { + contextID := h.GetContextID() + if contextID == "" { + return []protocol.Message{} + } + + history, err := h.manager.getConversationHistory(h.ctx, contextID, h.manager.maxHistoryLength) + if err != nil { + log.Errorf("Failed to get message history for context %s: %v", contextID, err) + return []protocol.Message{} + } + + return history +} diff --git a/taskmanager/redis/redis_test.go b/taskmanager/redis/redis_test.go deleted file mode 100644 index d3d70fc..0000000 --- a/taskmanager/redis/redis_test.go +++ /dev/null @@ -1,551 +0,0 @@ -// Tencent is pleased to support the open source community by making trpc-a2a-go available. -// -// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. -// -// trpc-a2a-go is licensed under the Apache License Version 2.0. -package redis - -import ( - "context" - "fmt" - "sync" - "testing" - "time" - - "github.com/alicebob/miniredis/v2" - "github.com/redis/go-redis/v9" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "trpc.group/trpc-go/trpc-a2a-go/protocol" - "trpc.group/trpc-go/trpc-a2a-go/taskmanager" -) - -// testProcessor is a test implementation of the TaskProcessor interface -// It simulates specific behaviors based on the input message. -type testProcessor struct { - waitTime time.Duration // Used to control how long the processor takes - mu sync.Mutex - tasks map[string]bool -} - -func newTestProcessor() *testProcessor { - return &testProcessor{ - waitTime: 50 * time.Millisecond, - tasks: make(map[string]bool), - } -} - -// Process implements TaskProcessor -func (p *testProcessor) Process( - ctx context.Context, - taskID string, - initialMsg protocol.Message, - handle taskmanager.TaskHandle, -) error { - p.mu.Lock() - p.tasks[taskID] = true - p.mu.Unlock() - - // Extract command from the message text - var command string - var messageText string - if len(initialMsg.Parts) > 0 { - if textPart, ok := initialMsg.Parts[0].(protocol.TextPart); ok { - messageText = textPart.Text - // Extract command if format is "command:message" - for _, prefix := range []string{"sleep:", "fail:", "cancel:", "artifacts:", "input-required:"} { - if len(messageText) > len(prefix) && messageText[:len(prefix)] == prefix { - command = prefix[:len(prefix)-1] - messageText = messageText[len(prefix):] - break - } - } - } - } - - // Simulate slow processing if requested - if command == "sleep" { - sleepDuration := p.waitTime - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(sleepDuration): - // Continue processing - } - } - - // Simulate a failure if requested - if command == "fail" { - handle.UpdateStatus(protocol.TaskStateWorking, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart("Processing started but will fail")}, - }) - return fmt.Errorf("task failed as requested: %s", messageText) - } - - // Check for cancellation request during processing - if command == "cancel" { - for i := 0; i < 10; i++ { - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(100 * time.Millisecond): - // Update status to show we're still working - handle.UpdateStatus(protocol.TaskStateWorking, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart(fmt.Sprintf("Still working... %d", i))}, - }) - } - } - } - - // Simulate returning multiple artifacts if requested - if command == "artifacts" { - for i := 0; i < 3; i++ { - artifact := protocol.Artifact{ - Name: stringPtr(fmt.Sprintf("artifact-%d", i)), - Parts: []protocol.Part{ - protocol.NewTextPart(fmt.Sprintf("Artifact content %d: %s", i, messageText)), - }, - Index: i, - } - isLast := i == 2 - artifact.LastChunk = &isLast - if err := handle.AddArtifact(artifact); err != nil { - return err - } - time.Sleep(20 * time.Millisecond) // Small delay between artifacts - } - } - - // Simulate requiring input if requested - if command == "input-required" { - return handle.UpdateStatus(protocol.TaskStateInputRequired, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart("Please provide more input: " + messageText)}, - }) - } - - // Default processing for normal tasks - return handle.UpdateStatus(protocol.TaskStateCompleted, &protocol.Message{ - Role: protocol.MessageRoleAgent, - Parts: []protocol.Part{protocol.NewTextPart("Task completed successfully: " + messageText)}, - }) -} - -func stringPtr(s string) *string { - return &s -} - -// setupRedisTest creates an in-memory Redis server and returns a configured TaskManager -func setupRedisTest(t *testing.T) (*TaskManager, *miniredis.Miniredis) { - // Create a new in-memory Redis server - mr, err := miniredis.Run() - require.NoError(t, err, "Failed to create miniredis server") - - // Create Redis client options pointing to miniredis - opts := &redis.UniversalOptions{ - Addrs: []string{mr.Addr()}, - } - - // Create a test processor - processor := newTestProcessor() - - // Create the Redis task manager - expiration := 1 * time.Hour - client := redis.NewUniversalClient(opts) - manager, err := NewRedisTaskManager(client, processor, WithExpiration(expiration)) - require.NoError(t, err, "Failed to create Redis task manager") - - return manager, mr -} - -// Test basic task processing with direct task submission -func TestE2E_BasicTaskProcessing(t *testing.T) { - manager, mr := setupRedisTest(t) - defer mr.Close() - defer manager.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - // Create a task - taskParams := protocol.SendTaskParams{ - ID: "test-task-1", - Message: protocol.Message{ - Role: protocol.MessageRoleUser, - Parts: []protocol.Part{protocol.NewTextPart("Hello, this is a test task")}, - }, - } - - // Submit the task (simulates client -> server -> task manager flow) - task, err := manager.OnSendTask(ctx, taskParams) - require.NoError(t, err, "Failed to send task") - assert.Equal(t, protocol.TaskStateCompleted, task.Status.State, "Task should be completed") - - // Verify the task was stored in Redis - taskKey := "task:" + taskParams.ID - assert.True(t, mr.Exists(taskKey), "Task should be stored in Redis") - - // Verify message history was stored in Redis - messageKey := "msg:" + taskParams.ID - assert.True(t, mr.Exists(messageKey), "Message history should be stored in Redis") - - // Retrieve the task to verify content - retrievedTask, err := manager.OnGetTask(ctx, protocol.TaskQueryParams{ - ID: taskParams.ID, - HistoryLength: intPtr(10), - }) - require.NoError(t, err, "Failed to retrieve task") - - // Verify the final state and response message - assert.Equal(t, protocol.TaskStateCompleted, retrievedTask.Status.State) - require.NotNil(t, retrievedTask.Status.Message, "Response message should not be nil") - assert.Equal(t, protocol.MessageRoleAgent, retrievedTask.Status.Message.Role) - require.Greater(t, len(retrievedTask.Status.Message.Parts), 0, "Response should have message parts") - - // Verify message history - assert.Greater(t, len(retrievedTask.History), 0, "Task should have message history") -} - -// Test task processing with streaming/subscription -func TestE2E_TaskSubscription(t *testing.T) { - manager, mr := setupRedisTest(t) - defer mr.Close() - defer manager.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - // Create a task with artifact generation - taskParams := protocol.SendTaskParams{ - ID: "test-subscribe-task", - Message: protocol.Message{ - Role: protocol.MessageRoleUser, - Parts: []protocol.Part{protocol.NewTextPart("artifacts:test artifacts")}, - }, - } - - // Subscribe to task events (simulates client -> server streaming connection) - eventChan, err := manager.OnSendTaskSubscribe(ctx, taskParams) - require.NoError(t, err, "Failed to subscribe to task") - - // Collect events - var statusEvents []protocol.TaskStatusUpdateEvent - var artifactEvents []protocol.TaskArtifactUpdateEvent - var lastEvent protocol.TaskEvent - - // Process events until the channel is closed or timeout occurs - timeout := time.After(3 * time.Second) -eventLoop: - for { - select { - case event, ok := <-eventChan: - if !ok { - // Channel closed, processing complete - break eventLoop - } - lastEvent = event - - // Collect events by type for verification - switch e := event.(type) { - case protocol.TaskStatusUpdateEvent: - statusEvents = append(statusEvents, e) - case protocol.TaskArtifactUpdateEvent: - artifactEvents = append(artifactEvents, e) - } - case <-timeout: - t.Fatal("Test timed out waiting for events") - break eventLoop - } - } - - // Verify we received status updates - assert.GreaterOrEqual(t, len(statusEvents), 2, "Should receive at least initial and final status updates") - assert.Equal(t, taskParams.ID, statusEvents[0].ID, "Status event should have correct task ID") - - // Verify we received artifact events (the test processor sends 3) - assert.Equal(t, 3, len(artifactEvents), "Should receive 3 artifact events") - - // Verify the last event was final - assert.True(t, lastEvent.IsFinal(), "Last event should be final") - - // Retrieve the task to verify final state - retrievedTask, err := manager.OnGetTask(ctx, protocol.TaskQueryParams{ - ID: taskParams.ID, - }) - require.NoError(t, err, "Failed to retrieve task") - assert.Equal(t, protocol.TaskStateCompleted, retrievedTask.Status.State) - assert.Equal(t, 3, len(retrievedTask.Artifacts), "Task should have 3 artifacts") -} - -// Test task cancellation -func TestE2E_TaskCancellation(t *testing.T) { - manager, mr := setupRedisTest(t) - defer mr.Close() - defer manager.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - // Create a long-running task that can be cancelled - taskParams := protocol.SendTaskParams{ - ID: "test-cancel-task", - Message: protocol.Message{ - Role: protocol.MessageRoleUser, - Parts: []protocol.Part{protocol.NewTextPart("cancel:task to be cancelled")}, - }, - } - - // Start task with subscription - eventChan, err := manager.OnSendTaskSubscribe(ctx, taskParams) - require.NoError(t, err, "Failed to subscribe to task") - - // Make sure task has started (wait for working state event) - var taskStarted bool - timeout := time.After(2 * time.Second) - for !taskStarted { - select { - case event, ok := <-eventChan: - if !ok { - t.Fatal("Event channel closed before task started") - } - if statusEvent, ok := event.(protocol.TaskStatusUpdateEvent); ok { - if statusEvent.Status.State == protocol.TaskStateWorking { - taskStarted = true - } - } - case <-timeout: - t.Fatal("Timed out waiting for task to start") - } - } - - // Cancel the task once it's started - cancelledTask, err := manager.OnCancelTask(ctx, protocol.TaskIDParams{ID: taskParams.ID}) - require.NoError(t, err, "Failed to cancel task") - assert.Equal(t, protocol.TaskStateCanceled, cancelledTask.Status.State, "Task should be in cancelled state") - - // Wait for final event or timeout - var finalEventReceived bool - timeout = time.After(2 * time.Second) - for !finalEventReceived { - select { - case event, ok := <-eventChan: - if !ok { - // Channel closed normally - finalEventReceived = true - break - } - if event.IsFinal() { - finalEventReceived = true - } - case <-timeout: - t.Fatal("Timed out waiting for final event after cancellation") - } - } - - // Verify the final state - retrievedTask, err := manager.OnGetTask(ctx, protocol.TaskQueryParams{ID: taskParams.ID}) - require.NoError(t, err, "Failed to retrieve task") - assert.Equal(t, protocol.TaskStateCanceled, retrievedTask.Status.State, "Task should remain in cancelled state") -} - -// Test task resubscription -func TestE2E_TaskResubscribe(t *testing.T) { - manager, mr := setupRedisTest(t) - defer mr.Close() - defer manager.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - // Create and complete a task first - taskParams := protocol.SendTaskParams{ - ID: "test-resubscribe-task", - Message: protocol.Message{ - Role: protocol.MessageRoleUser, - Parts: []protocol.Part{protocol.NewTextPart("Hello, this is a test task")}, - }, - } - - // Complete the task - task, err := manager.OnSendTask(ctx, taskParams) - require.NoError(t, err, "Failed to send task") - assert.Equal(t, protocol.TaskStateCompleted, task.Status.State) - - // Now try to resubscribe to the completed task - resubEventChan, err := manager.OnResubscribe( - ctx, protocol.TaskIDParams{ID: taskParams.ID}, - ) - require.NoError(t, err, "Failed to resubscribe to task") - - // We should get a single final status event and then the channel should close - event, ok := <-resubEventChan - require.True(t, ok, "Should receive one event") - assert.True(t, event.IsFinal(), "Event should be final") - - statusEvent, ok := event.(protocol.TaskStatusUpdateEvent) - require.True(t, ok, "Event should be a status update") - assert.Equal(t, protocol.TaskStateCompleted, statusEvent.Status.State) - - // Ensure channel closes - _, ok = <-resubEventChan - assert.False(t, ok, "Channel should be closed after final event") -} - -// Test task push notification configuration -func TestE2E_PushNotifications(t *testing.T) { - manager, mr := setupRedisTest(t) - defer mr.Close() - defer manager.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - // Create a task - taskParams := protocol.SendTaskParams{ - ID: "test-push-notification", - Message: protocol.Message{ - Role: protocol.MessageRoleUser, - Parts: []protocol.Part{protocol.NewTextPart("Task with push notifications")}, - }, - } - - // Send the task - _, err := manager.OnSendTask(ctx, taskParams) - require.NoError(t, err, "Failed to send task") - - // Configure push notifications for the task - pushConfig := protocol.TaskPushNotificationConfig{ - ID: taskParams.ID, - PushNotificationConfig: protocol.PushNotificationConfig{ - URL: "https://example.com/webhook", - Token: "test-token", - }, - } - - // Set push notification config - resultConfig, err := manager.OnPushNotificationSet(ctx, pushConfig) - require.NoError(t, err, "Failed to set push notification config") - assert.Equal(t, pushConfig.ID, resultConfig.ID) - assert.Equal(t, pushConfig.PushNotificationConfig.URL, resultConfig.PushNotificationConfig.URL) - - // Verify push notification key in Redis - pushKey := "push:" + taskParams.ID - assert.True(t, mr.Exists(pushKey), "Push notification config should be stored in Redis") - - // Get push notification config - retrievedConfig, err := manager.OnPushNotificationGet( - ctx, protocol.TaskIDParams{ID: taskParams.ID}, - ) - require.NoError(t, err, "Failed to get push notification config") - assert.Equal(t, pushConfig.PushNotificationConfig.URL, retrievedConfig.PushNotificationConfig.URL) - assert.Equal(t, pushConfig.PushNotificationConfig.Token, retrievedConfig.PushNotificationConfig.Token) -} - -// Test error handling for non-existent tasks -func TestE2E_ErrorHandling(t *testing.T) { - manager, mr := setupRedisTest(t) - defer mr.Close() - defer manager.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - nonExistentTaskID := "non-existent-task" - - // Try to get a non-existent task - _, err := manager.OnGetTask(ctx, protocol.TaskQueryParams{ID: nonExistentTaskID}) - assert.Error(t, err, "Getting non-existent task should return an error") - - // Try to cancel a non-existent task - _, err = manager.OnCancelTask(ctx, protocol.TaskIDParams{ID: nonExistentTaskID}) - assert.Error(t, err, "Cancelling non-existent task should return an error") - - // Try to get push notifications for a non-existent task - _, err = manager.OnPushNotificationGet(ctx, protocol.TaskIDParams{ID: nonExistentTaskID}) - assert.Error(t, err, "Getting push notifications for non-existent task should return an error") - - // Try to set push notifications for a non-existent task - _, err = manager.OnPushNotificationSet(ctx, protocol.TaskPushNotificationConfig{ - ID: nonExistentTaskID, - PushNotificationConfig: protocol.PushNotificationConfig{ - URL: "https://example.com/webhook", - }, - }) - assert.Error(t, err, "Setting push notifications for non-existent task should return an error") - - // Try to resubscribe to a non-existent task - _, err = manager.OnResubscribe(ctx, protocol.TaskIDParams{ID: nonExistentTaskID}) - assert.Error(t, err, "Resubscribing to non-existent task should return an error") -} - -// Test handling of task failure -func TestE2E_TaskFailure(t *testing.T) { - manager, mr := setupRedisTest(t) - defer mr.Close() - defer manager.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - // Create a task that will fail - taskParams := protocol.SendTaskParams{ - ID: "test-failure-task", - Message: protocol.Message{ - Role: protocol.MessageRoleUser, - Parts: []protocol.Part{protocol.NewTextPart("fail:intentional failure")}, - }, - } - - // Send the task - task, err := manager.OnSendTask(ctx, taskParams) - // The task should succeed but the processor should report failure status - require.NoError(t, err, "OnSendTask should not return an error even if task processing fails") - assert.Equal(t, protocol.TaskStateFailed, task.Status.State, "Task should be in failed state") - - // Verify the failure was stored - retrievedTask, err := manager.OnGetTask( - ctx, protocol.TaskQueryParams{ID: taskParams.ID}, - ) - require.NoError(t, err, "Failed to retrieve task") - assert.Equal(t, protocol.TaskStateFailed, retrievedTask.Status.State) - require.NotNil(t, retrievedTask.Status.Message, "Failure message should be available") -} - -// Test handling of input-required state -func TestE2E_InputRequired(t *testing.T) { - manager, mr := setupRedisTest(t) - defer mr.Close() - defer manager.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - // Create a task that will require input - taskParams := protocol.SendTaskParams{ - ID: "test-input-required", - Message: protocol.Message{ - Role: protocol.MessageRoleUser, - Parts: []protocol.Part{protocol.NewTextPart("input-required:more data needed")}, - }, - } - - // Send the task - task, err := manager.OnSendTask(ctx, taskParams) - require.NoError(t, err, "Failed to send task") - assert.Equal(t, protocol.TaskStateInputRequired, task.Status.State, "Task should be in input-required state") - - // Verify the state was stored - retrievedTask, err := manager.OnGetTask( - ctx, protocol.TaskQueryParams{ID: taskParams.ID}, - ) - require.NoError(t, err, "Failed to retrieve task") - assert.Equal(t, protocol.TaskStateInputRequired, retrievedTask.Status.State) - require.NotNil(t, retrievedTask.Status.Message, "Input request message should be available") -} - -func intPtr(i int) *int { - return &i -} diff --git a/taskmanager/redis/redis_types.go b/taskmanager/redis/redis_types.go new file mode 100644 index 0000000..8391562 --- /dev/null +++ b/taskmanager/redis/redis_types.go @@ -0,0 +1,131 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + +// Package redis provides Redis-specific implementations of taskmanager interfaces. +package redis + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "trpc.group/trpc-go/trpc-a2a-go/protocol" + "trpc.group/trpc-go/trpc-a2a-go/taskmanager" +) + +// CancellableTask implements the CancellableTask interface for Redis storage. +type CancellableTask struct { + task *protocol.Task + cancelFunc context.CancelFunc + mu sync.RWMutex +} + +// NewRedisCancellableTask creates a new Redis-based cancellable task. +func NewRedisCancellableTask(task *protocol.Task, cancelFunc context.CancelFunc) *CancellableTask { + return &CancellableTask{ + task: task, + cancelFunc: cancelFunc, + } +} + +// Task returns the protocol task. +func (t *CancellableTask) Task() *protocol.Task { + t.mu.RLock() + defer t.mu.RUnlock() + return t.task +} + +// Cancel cancels the task by calling the cancel function. +func (t *CancellableTask) Cancel() { + if t.cancelFunc != nil { + t.cancelFunc() + } +} + +// TaskSubscriber implements the TaskSubscriber interface for Redis storage. +type TaskSubscriber struct { + taskID string + eventQueue chan protocol.StreamingMessageEvent + closed atomic.Bool + mu sync.RWMutex + lastAccess time.Time +} + +// NewTaskSubscriber creates a new Redis-based task subscriber. +func NewTaskSubscriber(taskID string, bufferSize int) *TaskSubscriber { + if bufferSize <= 0 { + bufferSize = defaultTaskSubscriberBufferSize + } + + return &TaskSubscriber{ + taskID: taskID, + eventQueue: make(chan protocol.StreamingMessageEvent, bufferSize), + lastAccess: time.Now(), + } +} + +// Send sends an event to the subscriber's event queue. +func (s *TaskSubscriber) Send(event protocol.StreamingMessageEvent) error { + if s.Closed() { + return fmt.Errorf("task subscriber for task %s is closed", s.taskID) + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.Closed() { + return fmt.Errorf("task subscriber for task %s is closed", s.taskID) + } + + s.lastAccess = time.Now() + + // Use select with default to avoid blocking + select { + case s.eventQueue <- event: + return nil + default: + return fmt.Errorf("event queue is full for task %s", s.taskID) + } +} + +// Channel returns the event channel for receiving streaming events. +func (s *TaskSubscriber) Channel() <-chan protocol.StreamingMessageEvent { + return s.eventQueue +} + +// Closed returns true if the subscriber is closed. +func (s *TaskSubscriber) Closed() bool { + return s.closed.Load() +} + +// Close closes the subscriber and its event channel. +func (s *TaskSubscriber) Close() { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.closed.Load() { + s.closed.Store(true) + close(s.eventQueue) + } +} + +// GetTaskID returns the task ID this subscriber is associated with. +func (s *TaskSubscriber) GetTaskID() string { + return s.taskID +} + +// GetLastAccessTime returns the last access time of the subscriber. +func (s *TaskSubscriber) GetLastAccessTime() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.lastAccess +} + +// Ensure our types implement the required interfaces +var _ taskmanager.CancellableTask = (*CancellableTask)(nil) +var _ taskmanager.TaskSubscriber = (*TaskSubscriber)(nil) diff --git a/taskmanager/redis/redis_types_test.go b/taskmanager/redis/redis_types_test.go new file mode 100644 index 0000000..76d0d63 --- /dev/null +++ b/taskmanager/redis/redis_types_test.go @@ -0,0 +1,139 @@ +// Tencent is pleased to support the open source community by making trpc-a2a-go available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// trpc-a2a-go is licensed under the Apache License Version 2.0. + +package redis + +import ( + "context" + "testing" + "time" + + "trpc.group/trpc-go/trpc-a2a-go/protocol" + "trpc.group/trpc-go/trpc-a2a-go/taskmanager" +) + +func TestRedisCancellableTask(t *testing.T) { + // Create a test task + task := &protocol.Task{ + ID: "test-task-1", + Status: protocol.TaskStatus{ + State: protocol.TaskStateSubmitted, + Timestamp: time.Now().UTC().Format(time.RFC3339), + }, + } + + // Create a cancellable context + ctx, cancel := context.WithCancel(context.Background()) + + // Create our Redis cancellable task + cancellableTask := NewRedisCancellableTask(task, cancel) + + // Verify it implements the interface + var _ taskmanager.CancellableTask = cancellableTask + + // Test Task() method + retrievedTask := cancellableTask.Task() + if retrievedTask.ID != "test-task-1" { + t.Errorf("Expected task ID 'test-task-1', got '%s'", retrievedTask.ID) + } + + // Test Cancel() method + cancellableTask.Cancel() + + // Verify context was cancelled + select { + case <-ctx.Done(): + // Expected + case <-time.After(100 * time.Millisecond): + t.Error("Expected context to be cancelled") + } +} + +func TestRedisTaskSubscriber(t *testing.T) { + taskID := "test-task-2" + bufferSize := 5 + + // Create subscriber + subscriber := NewTaskSubscriber(taskID, bufferSize) + + // Verify it implements the interface + var _ taskmanager.TaskSubscriber = subscriber + + // Test basic properties + if subscriber.GetTaskID() != taskID { + t.Errorf("Expected task ID '%s', got '%s'", taskID, subscriber.GetTaskID()) + } + + if subscriber.Closed() { + t.Error("Subscriber should not be closed initially") + } + + // Test sending events + event := protocol.StreamingMessageEvent{ + Result: &protocol.TaskStatusUpdateEvent{ + TaskID: taskID, + Status: protocol.TaskStatus{ + State: protocol.TaskStateSubmitted, + Timestamp: time.Now().UTC().Format(time.RFC3339), + }, + }, + } + + err := subscriber.Send(event) + if err != nil { + t.Errorf("Unexpected error sending event: %v", err) + } + + // Test receiving events + select { + case receivedEvent := <-subscriber.Channel(): + if receivedEvent.Result == nil { + t.Error("Expected event result, got nil") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timeout waiting for event") + } + + // Test closing + subscriber.Close() + if !subscriber.Closed() { + t.Error("Subscriber should be closed after Close()") + } + + // Test sending to closed subscriber + err = subscriber.Send(event) + if err == nil { + t.Error("Expected error when sending to closed subscriber") + } +} + +func TestRedisTaskSubscriberBufferFull(t *testing.T) { + taskID := "test-task-3" + bufferSize := 2 + + subscriber := NewTaskSubscriber(taskID, bufferSize) + defer subscriber.Close() + + event := protocol.StreamingMessageEvent{ + Result: &protocol.TaskStatusUpdateEvent{ + TaskID: taskID, + }, + } + + // Fill the buffer + for i := 0; i < bufferSize; i++ { + err := subscriber.Send(event) + if err != nil { + t.Errorf("Unexpected error sending event %d: %v", i, err) + } + } + + // Next send should fail due to full buffer + err := subscriber.Send(event) + if err == nil { + t.Error("Expected error when buffer is full") + } +} diff --git a/taskmanager/task.go b/taskmanager/task.go deleted file mode 100644 index a7b12bc..0000000 --- a/taskmanager/task.go +++ /dev/null @@ -1,64 +0,0 @@ -// Tencent is pleased to support the open source community by making trpc-a2a-go available. -// -// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. -// -// trpc-a2a-go is licensed under the Apache License Version 2.0. - -package taskmanager - -import ( - "context" - - "trpc.group/trpc-go/trpc-a2a-go/protocol" -) - -// memoryTaskHandle implements the TaskHandle interface, providing callbacks -// for a specific task being processed by a TaskProcessor. -// It holds a reference back to the MemoryTaskManager. -type memoryTaskHandle struct { - taskID string - manager *MemoryTaskManager -} - -// UpdateStatus implements TaskHandle. -func (h *memoryTaskHandle) UpdateStatus( - state protocol.TaskState, - msg *protocol.Message, -) error { - return h.manager.UpdateTaskStatus(context.Background(), h.taskID, state, msg) -} - -// AddArtifact implements TaskHandle. -func (h *memoryTaskHandle) AddArtifact(artifact protocol.Artifact) error { - return h.manager.AddArtifact(h.taskID, artifact) -} - -// IsStreamingRequest checks if this task was initiated with a streaming request (OnSendTaskSubscribe). -// It returns true if there are active subscribers for this task, indicating it was initiated -// with OnSendTaskSubscribe rather than OnSendTask. -func (h *memoryTaskHandle) IsStreamingRequest() bool { - h.manager.SubMutex.RLock() - defer h.manager.SubMutex.RUnlock() - - subscribers, exists := h.manager.Subscribers[h.taskID] - return exists && len(subscribers) > 0 -} - -// GetSessionID implements TaskHandle. -func (h *memoryTaskHandle) GetSessionID() *string { - h.manager.SubMutex.RLock() - defer h.manager.SubMutex.RUnlock() - - task, exists := h.manager.Tasks[h.taskID] - if !exists { - return nil - } - - return task.SessionID -} - -// isFinalState checks if a TaskState represents a terminal state. -// Not exported as it's an internal helper. -func isFinalState(state protocol.TaskState) bool { - return state == protocol.TaskStateCompleted || state == protocol.TaskStateFailed || state == protocol.TaskStateCanceled -} diff --git a/tests/e2e_auth_test.go b/tests/e2e_auth_test.go index 672ec4c..888086f 100644 --- a/tests/e2e_auth_test.go +++ b/tests/e2e_auth_test.go @@ -256,7 +256,7 @@ func TestPushNotificationAuthentication(t *testing.T) { URL: "http://localhost:8080", Version: "1.0.0", Capabilities: server.AgentCapabilities{ - PushNotifications: true, + PushNotifications: &[]bool{true}[0], }, DefaultInputModes: []string{"text"}, DefaultOutputModes: []string{"text"}, @@ -447,13 +447,10 @@ func setupAuthServer(t *testing.T, provider auth.Provider) (taskmanager.TaskMana URL: "http://localhost:8080", Version: "1.0.0", Capabilities: server.AgentCapabilities{ - Streaming: true, + Streaming: &[]bool{true}[0], }, - Authentication: &protocol.AuthenticationInfo{ - Schemes: []string{"apiKey", "jwt"}, - }, - DefaultInputModes: []string{"text"}, - DefaultOutputModes: []string{"text"}, + DefaultInputModes: []string{protocol.KindText}, + DefaultOutputModes: []string{protocol.KindText}, } a2aServer, err := server.NewA2AServer( @@ -469,17 +466,21 @@ func setupAuthServer(t *testing.T, provider auth.Provider) (taskmanager.TaskMana // mockTaskManager is a simple implementation of the TaskManager interface. type mockTaskManager struct { - processor taskmanager.TaskProcessor + processor taskmanager.MessageProcessor tasks map[string]*protocol.Task pushConfigs map[string]protocol.PushNotificationConfig + messages map[string]*protocol.Message } +var _ taskmanager.TaskManager = (*mockTaskManager)(nil) + // newMockTaskManager creates a new mock task manager. -func newMockTaskManager(processor taskmanager.TaskProcessor) *mockTaskManager { +func newMockTaskManager(processor taskmanager.MessageProcessor) *mockTaskManager { return &mockTaskManager{ processor: processor, tasks: make(map[string]*protocol.Task), pushConfigs: make(map[string]protocol.PushNotificationConfig), + messages: make(map[string]*protocol.Message), } } @@ -492,11 +493,62 @@ func (m *mockTaskManager) Task(id string) (*protocol.Task, error) { return task, nil } -// OnSendTask handles sending a task. +// OnSendMessage handles a request corresponding to the 'message/send' RPC method. +func (m *mockTaskManager) OnSendMessage( + ctx context.Context, request protocol.SendMessageParams, +) (*protocol.MessageResult, error) { + // Store the message + m.messages[request.Message.MessageID] = &request.Message + + // Create a simple message result + return &protocol.MessageResult{ + Result: &request.Message, + }, nil +} + +// OnSendMessageStream handles a request corresponding to the 'message/stream' RPC method. +func (m *mockTaskManager) OnSendMessageStream( + ctx context.Context, request protocol.SendMessageParams, +) (<-chan protocol.StreamingMessageEvent, error) { + // Store the message + m.messages[request.Message.MessageID] = &request.Message + + // Create a channel for events + eventCh := make(chan protocol.StreamingMessageEvent, 10) + + // For a mock implementation, just send one event with the message and close the channel + go func() { + defer close(eventCh) + + // Send the message result + event := protocol.StreamingMessageEvent{ + Result: &request.Message, + } + + // Try to send the event, but don't block forever + select { + case eventCh <- event: + // Event sent successfully + case <-ctx.Done(): + // Context was canceled + } + }() + + return eventCh, nil +} + +// OnSendTask handles sending a task. (deprecated) func (m *mockTaskManager) OnSendTask( ctx context.Context, params protocol.SendTaskParams, ) (*protocol.Task, error) { - task := protocol.NewTask(params.ID, params.SessionID) + var contextID string + if params.SessionID != nil { + contextID = *params.SessionID + } else { + contextID = "" + } + + task := protocol.NewTask(params.ID, contextID) m.tasks[params.ID] = task handle := &mockTaskHandle{ @@ -509,7 +561,9 @@ func (m *mockTaskManager) OnSendTask( } if m.processor != nil { - if err := m.processor.Process(ctx, params.ID, params.Message, handle); err != nil { + // Note: This is a simplified mock implementation + // Real implementation would call processor with proper parameters + if err := handle.UpdateStatus(protocol.TaskStateCompleted, nil); err != nil { return task, err } } else { @@ -529,27 +583,16 @@ func (m *mockTaskManager) OnGetTask( return m.Task(params.ID) } -// OnListTasks handles listing tasks. -func (m *mockTaskManager) OnListTasks( - ctx context.Context, params protocol.TaskQueryParams, -) ([]*protocol.Task, error) { - var tasks []*protocol.Task - for _, task := range m.tasks { - tasks = append(tasks, task) - } - return tasks, nil -} - // OnPushNotificationSet sets a push notification configuration for a task. func (m *mockTaskManager) OnPushNotificationSet( ctx context.Context, params protocol.TaskPushNotificationConfig, ) (*protocol.TaskPushNotificationConfig, error) { - _, err := m.Task(params.ID) + _, err := m.Task(params.TaskID) if err != nil { return nil, err } - m.pushConfigs[params.ID] = params.PushNotificationConfig + m.pushConfigs[params.TaskID] = params.PushNotificationConfig return ¶ms, nil } @@ -568,7 +611,7 @@ func (m *mockTaskManager) OnPushNotificationGet( } return &protocol.TaskPushNotificationConfig{ - ID: params.ID, + TaskID: params.ID, PushNotificationConfig: config, }, nil } @@ -576,34 +619,34 @@ func (m *mockTaskManager) OnPushNotificationGet( // OnResubscribe handles resubscribing to a task. func (m *mockTaskManager) OnResubscribe( ctx context.Context, params protocol.TaskIDParams, -) (<-chan protocol.TaskEvent, error) { +) (<-chan protocol.StreamingMessageEvent, error) { task, err := m.Task(params.ID) if err != nil { return nil, err } // Create a channel for events - eventCh := make(chan protocol.TaskEvent) + eventCh := make(chan protocol.StreamingMessageEvent) // For a mock implementation, just send one event with the current status and close the channel go func() { + defer close(eventCh) + // Send the current task status + final := true event := protocol.TaskStatusUpdateEvent{ - ID: task.ID, + TaskID: task.ID, Status: task.Status, - Final: true, + Final: &final, } // Try to send the event, but don't block forever select { - case eventCh <- event: + case eventCh <- protocol.StreamingMessageEvent{Result: &event}: // Event sent successfully case <-ctx.Done(): // Context was canceled } - - // Close the channel - close(eventCh) }() return eventCh, nil @@ -630,14 +673,7 @@ func (m *mockTaskManager) OnCancelTask( return m.tasks[params.ID], nil } -// OnSubscribeTaskUpdates handles subscribing to task updates. -func (m *mockTaskManager) OnSubscribeTaskUpdates( - ctx context.Context, params protocol.TaskQueryParams, eventCh chan<- protocol.TaskEvent, -) error { - return fmt.Errorf("streaming not implemented in mock task manager") -} - -// OnSendTaskSubscribe handles sending a task and subscribing to updates. +// OnSendTaskSubscribe handles sending a task and subscribing to updates. (deprecated) func (m *mockTaskManager) OnSendTaskSubscribe( ctx context.Context, params protocol.SendTaskParams, ) (<-chan protocol.TaskEvent, error) { @@ -653,11 +689,14 @@ func (m *mockTaskManager) OnSendTaskSubscribe( // For a mock implementation, just send one event with the current status and close the channel go func() { + defer close(eventCh) + // Send the current task status - event := protocol.TaskStatusUpdateEvent{ - ID: task.ID, + final := true + event := &protocol.TaskStatusUpdateEvent{ + TaskID: task.ID, Status: task.Status, - Final: true, + Final: &final, } // Try to send the event, but don't block forever @@ -667,9 +706,6 @@ func (m *mockTaskManager) OnSendTaskSubscribe( case <-ctx.Done(): // Context was canceled } - - // Close the channel - close(eventCh) }() return eventCh, nil @@ -724,13 +760,13 @@ func (h *mockTaskHandle) IsStreamingRequest() bool { return false } -// GetSessionID implements the TaskHandle interface. -func (h *mockTaskHandle) GetSessionID() *string { +// GetContextID implements the TaskHandle interface. +func (h *mockTaskHandle) GetContextID() *string { task, err := h.manager.Task(h.taskID) if err != nil { return nil } - return task.SessionID + return &task.ContextID } // AddResponse adds a response to a task. @@ -751,14 +787,16 @@ func (h *mockTaskHandle) AddResponse(response protocol.Message) error { // echoProcessor is a simple task processor that echoes messages. type echoProcessor struct{} -// Process simply echoes the received message. -func (p *echoProcessor) Process( - ctx context.Context, taskID string, msg protocol.Message, handle taskmanager.TaskHandle, -) error { +var _ taskmanager.MessageProcessor = (*echoProcessor)(nil) + +// ProcessMessage simply echoes the received message. +func (p *echoProcessor) ProcessMessage( + ctx context.Context, msg protocol.Message, opts taskmanager.ProcessOptions, handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { // Create a response that echoes back the message textPart, ok := msg.Parts[0].(protocol.TextPart) if !ok { - return fmt.Errorf("expected TextPart, got %T", msg.Parts[0]) + return nil, fmt.Errorf("expected TextPart, got %T", msg.Parts[0]) } response := protocol.Message{ @@ -768,6 +806,7 @@ func (p *echoProcessor) Process( }, } - // Mark the task as done with the response - return handle.UpdateStatus(protocol.TaskStateCompleted, &response) + return &taskmanager.MessageProcessingResult{ + Result: &response, + }, nil } diff --git a/tests/e2e_test.go b/tests/e2e_test.go index 97f16ce..f451aca 100644 --- a/tests/e2e_test.go +++ b/tests/e2e_test.go @@ -34,11 +34,6 @@ func boolPtr(b bool) *bool { return &b } -// intPtr is a helper to get a pointer to an int. -func intPtr(i int) *int { - return &i -} - // testReverseString is a helper that reverses a string. func testReverseString(s string) string { runes := []rune(s) @@ -48,49 +43,81 @@ func testReverseString(s string) string { return string(runes) } -// testStreamingProcessor implements taskmanager.TaskProcessor for streaming E2E tests. +// testProcessor implements taskmanager.MessageProcessor for streaming E2E tests. // It reverses the input text and sends it back chunk by chunk via status updates // and a final artifact. -type testStreamingProcessor struct{} +type testProcessor struct{} + +var _ taskmanager.MessageProcessor = (*testProcessor)(nil) -// Process implements taskmanager.TaskProcessor for streaming. -func (p *testStreamingProcessor) Process( +// ProcessMessage implements taskmanager.MessageProcessor for streaming. +func (p *testProcessor) ProcessMessage( ctx context.Context, - taskID string, - msg protocol.Message, - handle taskmanager.TaskHandle, -) error { - log.Printf("[testStreamingProcessor] Processing task %s", taskID) + message protocol.Message, + options taskmanager.ProcessOptions, + handle taskmanager.TaskHandler, +) (*taskmanager.MessageProcessingResult, error) { + // Extract input text from the message + inputText := getTextPartContent(message.Parts) + if inputText == "" { + return nil, fmt.Errorf("no text content found in message") + } - var inputText string - for _, part := range msg.Parts { - if textPart, ok := part.(protocol.TextPart); ok { // Check for value type - inputText = textPart.Text - break - } + // Create a task + taskID, err := handle.BuildTask(message.TaskID, message.ContextID) + if err != nil { + return nil, fmt.Errorf("failed to build task: %w", err) } - if inputText == "" { - // If no text part found, try legacy TextPart check just in case - for _, part := range msg.Parts { - if textPart, ok := part.(protocol.TextPart); ok { - log.Printf("WARNING: Found legacy TextPart (value) in message for task %s", taskID) - inputText = textPart.Text - break - } - } - if inputText == "" { - return fmt.Errorf("no text input found in message for task %s", taskID) + if options.Streaming { + // For streaming requests, process in background and return StreamingEvents + subscriber, err := handle.SubScribeTask(stringPtr(taskID)) + if err != nil { + return nil, fmt.Errorf("failed to subscribe to task: %w", err) } + + // Process task in background + go func() { + if err := p.processTask(taskID, message, inputText, subscriber, handle); err != nil { + log.Printf("[testStreamingProcessor] Error processing task: %v", err) + } + }() + + return &taskmanager.MessageProcessingResult{ + StreamingEvents: subscriber, + }, nil + } + // For non-streaming requests, process synchronously and return Result + // Process the task synchronously without auto-cleanup + if err := p.processTask(taskID, message, inputText, nil, handle); err != nil { + return nil, fmt.Errorf("failed to process task: %w", err) } + // Get the final task state + finalTask, err := handle.GetTask(stringPtr(taskID)) + if err != nil { + return nil, fmt.Errorf("failed to get final task: %w", err) + } + + return &taskmanager.MessageProcessingResult{ + Result: finalTask.Task(), + }, nil +} + +func (p *testProcessor) processTask( + taskID string, + message protocol.Message, + inputText string, + subscriber taskmanager.TaskSubscriber, + handle taskmanager.TaskHandler, +) error { reversedText := testReverseString(inputText) log.Printf("[testStreamingProcessor] Input: '%s', Reversed: '%s'", inputText, reversedText) - // Send intermediate 'Working' status updates (chunked). + // Send intermediate 'Working' status updates (chunked) chunkSize := 3 for i := 0; i < len(reversedText); i += chunkSize { - time.Sleep(20 * time.Millisecond) // Simulate work per chunk. + time.Sleep(20 * time.Millisecond) // Simulate work per chunk end := i + chunkSize if end > len(reversedText) { end = len(reversedText) @@ -99,120 +126,116 @@ func (p *testStreamingProcessor) Process( statusMsg := &protocol.Message{ Role: protocol.MessageRoleAgent, Parts: []protocol.Part{ - protocol.NewTextPart(fmt.Sprintf("Processing chunk: %s", chunk)), // Returns TextPart + protocol.NewTextPart(fmt.Sprintf("Processing chunk: %s", chunk)), }, } - if err := handle.UpdateStatus(protocol.TaskStateWorking, statusMsg); err != nil { + + // Will notify the subscriber automatically + if err := handle.UpdateTaskState(stringPtr(taskID), protocol.TaskStateWorking, statusMsg); err != nil { log.Printf("[testStreamingProcessor] Error sending working status chunk: %v", err) - return err // Propagate error if sending status fails. + return err } } - // Send the final artifact containing the full reversed text. + // Send the final artifact containing the full reversed text finalArtifact := protocol.Artifact{ Name: stringPtr("Processed Text"), Description: stringPtr("The reversed input text."), Parts: []protocol.Part{ - protocol.NewTextPart(reversedText), // Returns TextPart + protocol.NewTextPart(reversedText), }, - Index: 0, - LastChunk: boolPtr(true), } - if err := handle.AddArtifact(finalArtifact); err != nil { + + if err := handle.AddArtifact(stringPtr(taskID), finalArtifact, true, false); err != nil { log.Printf("[testStreamingProcessor] Error sending artifact: %v", err) - return err // Propagate error if sending artifact fails. + return err } - // Send final 'Completed' status. + // Send final 'Completed' status completionMsg := &protocol.Message{ Role: protocol.MessageRoleAgent, Parts: []protocol.Part{ protocol.NewTextPart( fmt.Sprintf("Task %s completed successfully. Result: %s", taskID, reversedText), - ), // Include the reversed text in the final message + ), }, } - if err := handle.UpdateStatus(protocol.TaskStateCompleted, completionMsg); err != nil { + + if err := handle.UpdateTaskState(stringPtr(taskID), protocol.TaskStateCompleted, completionMsg); err != nil { log.Printf("[testStreamingProcessor] Error sending completed status: %v", err) - return err // Propagate error if sending final status fails. + return err } log.Printf("[testStreamingProcessor] Finished processing task %s", taskID) - return nil // Success. + return nil } -// --- Test Task Manager Definition --- - // testBasicTaskManager is a simple TaskManager for basic tests. type testBasicTaskManager struct { - // Embed MemoryTaskManager to get basic functionality. *taskmanager.MemoryTaskManager - // Add specific fields if needed for testing. } // newTestBasicTaskManager creates an instance for testing. func newTestBasicTaskManager(t *testing.T) *testBasicTaskManager { - processor := &testStreamingProcessor{} + processor := &testProcessor{} memTm, err := taskmanager.NewMemoryTaskManager(processor) require.NoError(t, err, "Failed to create MemoryTaskManager for testBasicTaskManager") return &testBasicTaskManager{ - MemoryTaskManager: memTm, // Correctly assign the embedded field. + MemoryTaskManager: memTm, } } -// Helper to mimic unexported upsertTask for testing purposes. -// Needs access to baseTm's fields. -func (m *testBasicTaskManager) testUpsertTask(params protocol.SendTaskParams) *protocol.Task { - // NOTE: This accesses fields of baseTm directly, which relies on them being accessible - // (e.g., within the same logical package structure if tests were internal, or if fields were exported). - // For this test setup, we assume accessibility for direct map manipulation. - m.TasksMutex.Lock() - defer m.TasksMutex.Unlock() - - task, exists := m.Tasks[params.ID] - if !exists { - task = protocol.NewTask(params.ID, params.SessionID) - m.Tasks[params.ID] = task - log.Printf("[Test TM Helper] Created new task %s via testUpsertTask", params.ID) - } else { - log.Printf("[Test TM Helper] Updating task %s via testUpsertTask", params.ID) - } +// OnSendMessage delegates to the composed MemoryTaskManager. +func (m *testBasicTaskManager) OnSendMessage( + ctx context.Context, + params protocol.SendMessageParams, +) (*protocol.MessageResult, error) { + log.Printf("[Test TM Wrapper] OnSendMessage called for %s, delegating to base.", params.Message.MessageID) + return m.MemoryTaskManager.OnSendMessage(ctx, params) +} - // Update metadata if provided. - if params.Metadata != nil { - if task.Metadata == nil { - task.Metadata = make(map[string]interface{}) - } - for k, v := range params.Metadata { - task.Metadata[k] = v - } - } - return task +// OnSendMessageStream delegates to the composed MemoryTaskManager. +func (m *testBasicTaskManager) OnSendMessageStream( + ctx context.Context, + params protocol.SendMessageParams, +) (<-chan protocol.StreamingMessageEvent, error) { + log.Printf("[Test TM Wrapper] OnSendMessageStream called for %s, delegating to base.", params.Message.MessageID) + return m.MemoryTaskManager.OnSendMessageStream(ctx, params) } -// Helper to mimic unexported storeMessage for testing purposes. -// Needs access to baseTm's fields. -func (m *testBasicTaskManager) testStoreMessage(taskID string, message protocol.Message) { - // NOTE: This accesses fields of baseTm directly. - m.MessagesMutex.Lock() - defer m.MessagesMutex.Unlock() +// OnResubscribe delegates to the composed MemoryTaskManager. +func (m *testBasicTaskManager) OnResubscribe( + ctx context.Context, + params protocol.TaskIDParams, +) (<-chan protocol.StreamingMessageEvent, error) { + log.Printf("[Test TM Wrapper] OnResubscribe called for %s, delegating to base.", params.ID) + return m.MemoryTaskManager.OnResubscribe(ctx, params) +} - if _, exists := m.Messages[taskID]; !exists { - m.Messages[taskID] = make([]protocol.Message, 0, 1) - } - m.Messages[taskID] = append(m.Messages[taskID], message) - log.Printf("[Test TM Helper] Stored message for task %s via testStoreMessage", taskID) +// OnPushNotificationSet delegates to the composed MemoryTaskManager. +func (m *testBasicTaskManager) OnPushNotificationSet( + ctx context.Context, + params protocol.TaskPushNotificationConfig, +) (*protocol.TaskPushNotificationConfig, error) { + log.Printf("[Test TM Wrapper] OnPushNotificationSet called for %s, delegating to base.", params.TaskID) + return m.MemoryTaskManager.OnPushNotificationSet(ctx, params) } -// --- TaskManager Interface Implementation for testBasicTaskManager --- +// OnPushNotificationGet delegates to the composed MemoryTaskManager. +func (m *testBasicTaskManager) OnPushNotificationGet( + ctx context.Context, + params protocol.TaskIDParams, +) (*protocol.TaskPushNotificationConfig, error) { + log.Printf("[Test TM Wrapper] OnPushNotificationGet called for %s, delegating to base.", params.ID) + return m.MemoryTaskManager.OnPushNotificationGet(ctx, params) +} -// OnSendTask delegates to the composed MemoryTaskManager. +// OnSendTask delegates to the composed MemoryTaskManager (deprecated). func (m *testBasicTaskManager) OnSendTask( ctx context.Context, params protocol.SendTaskParams, ) (*protocol.Task, error) { log.Printf("[Test TM Wrapper] OnSendTask called for %s, delegating to base.", params.ID) - // In a real compositional wrapper, you might add pre/post logic here. return m.MemoryTaskManager.OnSendTask(ctx, params) } @@ -234,34 +257,15 @@ func (m *testBasicTaskManager) OnCancelTask( return m.MemoryTaskManager.OnCancelTask(ctx, params) } -// OnSendTaskSubscribe delegates to the composed MemoryTaskManager. -// The baseTm's method now handles the goroutine and processor invocation correctly. +// OnSendTaskSubscribe delegates to the composed MemoryTaskManager (deprecated). func (m *testBasicTaskManager) OnSendTaskSubscribe( ctx context.Context, params protocol.SendTaskParams, ) (<-chan protocol.TaskEvent, error) { log.Printf("[Test TM Wrapper] OnSendTaskSubscribe called for %s, delegating to base.", params.ID) - // Call the embedded MemoryTaskManager's method directly return m.MemoryTaskManager.OnSendTaskSubscribe(ctx, params) } -// ProcessTask is required by the TaskManager interface but should not be called directly -// in this compositional setup. The actual processing is done by the injected processor. -func (m *testBasicTaskManager) ProcessTask( - ctx context.Context, - taskID string, - message protocol.Message, -) (*protocol.Task, error) { - log.Printf("WARNING: testBasicTaskManager.ProcessTask called directly for %s. "+ - "This should not happen in the compositional setup.", taskID) - // Return an error or a specific state to indicate this shouldn't be called. - return nil, fmt.Errorf("testBasicTaskManager.ProcessTask should not be called directly") -} - -// --- End-to-End Test Function --- - -// --- Common Test Utilities --- - // testHelper contains common utilities and setup for e2e tests. type testHelper struct { t *testing.T @@ -274,7 +278,7 @@ type testHelper struct { } // newTestHelper creates a new test helper with a running server and client. -func newTestHelper(t *testing.T, processor taskmanager.TaskProcessor) *testHelper { +func newTestHelper(t *testing.T, processor taskmanager.MessageProcessor) *testHelper { // Create task manager var tm taskmanager.TaskManager if processor != nil { @@ -346,55 +350,53 @@ func createDefaultTestAgentCard() server.AgentCard { desc := "A test agent for E2E tests" return server.AgentCard{ Name: "Test Agent", - Description: &desc, + Description: desc, Capabilities: server.AgentCapabilities{ - Streaming: true, - StateTransitionHistory: true, + Streaming: boolPtr(true), + StateTransitionHistory: boolPtr(true), }, - DefaultInputModes: []string{string(protocol.PartTypeText)}, - DefaultOutputModes: []string{string(protocol.PartTypeText)}, + DefaultInputModes: []string{string(protocol.KindText)}, + DefaultOutputModes: []string{string(protocol.KindText)}, } } -// sendTestMessage sends a test message and returns the task. -func (h *testHelper) sendTestMessage(taskID string, text string) (*protocol.Task, error) { - return h.client.SendTasks(context.Background(), protocol.SendTaskParams{ - ID: taskID, - Message: protocol.Message{ - Role: protocol.MessageRoleUser, - Parts: []protocol.Part{ - protocol.NewTextPart(text), - }, - }, - }) -} - -// isFinalState checks if the given task state is a terminal state. -func isFinalState(state protocol.TaskState) bool { - return state == protocol.TaskStateCompleted || - state == protocol.TaskStateFailed || - state == protocol.TaskStateCanceled -} - -// waitForTaskCompletion waits for a task to reach a final state. -func waitForTaskCompletion(ctx context.Context, client *client.A2AClient, taskID string) (*protocol.Task, error) { - for { +// collectAllStreamingEvents collects all events from a streaming message event channel until it's closed. +func collectAllStreamingEvents(eventChan <-chan protocol.StreamingMessageEvent) []protocol.StreamingMessageEvent { + var events []protocol.StreamingMessageEvent + timeout := time.After(3 * time.Second) // Safety timeout + done := false + for !done { select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - task, err := client.GetTasks(ctx, protocol.TaskQueryParams{ID: taskID}) - if err != nil { - return nil, err + case event, ok := <-eventChan: + if !ok { + done = true // Channel closed + break } + events = append(events, event) - if isFinalState(task.Status.State) { - return task, nil + // Check if this is a final event + if result, ok := event.Result.(*protocol.TaskStatusUpdateEvent); ok { + if result.IsFinal() { + // Wait just a tiny bit more to see if there are any trailing events + time.Sleep(50 * time.Millisecond) + // Try to drain one more event non-blocking + select { + case lastEvent, ok := <-eventChan: + if ok { + events = append(events, lastEvent) + } + default: + // No more events available + } + return events + } } - - time.Sleep(50 * time.Millisecond) + case <-timeout: + // If we timeout, just return whatever events we've collected so far + return events } } + return events } // collectAllTaskEvents collects all events from a task event channel until it's closed. @@ -412,8 +414,6 @@ func collectAllTaskEvents(eventChan <-chan protocol.TaskEvent) []protocol.TaskEv events = append(events, event) // Check if this is a final event (completed, failed, canceled) - // If we received a final event, we can consider the collection complete - // even if the channel isn't formally closed if event.IsFinal() { // Wait just a tiny bit more to see if there are any trailing events time.Sleep(50 * time.Millisecond) @@ -439,7 +439,7 @@ func collectAllTaskEvents(eventChan <-chan protocol.TaskEvent) []protocol.TaskEv // getTextPartContent extracts text content from parts of a message. func getTextPartContent(parts []protocol.Part) string { for _, part := range parts { - if textPart, ok := part.(protocol.TextPart); ok { + if textPart, ok := part.(*protocol.TextPart); ok { return textPart.Text } } @@ -448,25 +448,172 @@ func getTextPartContent(parts []protocol.Part) string { // --- Test Functions --- -// TestE2E_BasicAgent_Streaming tests the streaming functionality. -func TestE2E_BasicAgent_Streaming(t *testing.T) { - helper := newTestHelper(t, &testStreamingProcessor{}) +// TestE2E_MessageAPI_Streaming tests the streaming functionality using the new message API. +func TestE2E_MessageAPI_Streaming(t *testing.T) { + helper := newTestHelper(t, &testProcessor{}) + defer helper.cleanup() + + // Test data + inputText := "Hello world!" + + // Generate context ID and task ID + contextID := protocol.GenerateContextID() + taskID := protocol.GenerateMessageID() + + // Create message using the NewMessageWithContext constructor + message := protocol.NewMessageWithContext( + protocol.MessageRoleUser, + []protocol.Part{ + protocol.NewTextPart(inputText), + }, + &taskID, + &contextID, + ) + + // Subscribe to streaming message events using the new API + eventChan, err := helper.client.StreamMessage( + context.Background(), + protocol.SendMessageParams{ + Message: message, + }, + ) + require.NoError(t, err) + + // Collect all events + events := collectAllStreamingEvents(eventChan) + + // Verify we received events + require.NotEmpty(t, events, "Should have received events") + + // Verify the events we received + hasWorkingStatus := false + hasArtifact := false + hasCompletedStatus := false + + for _, event := range events { + switch result := event.Result.(type) { + case *protocol.TaskStatusUpdateEvent: + if result.Status.State == protocol.TaskStateWorking { + hasWorkingStatus = true + require.NotNil(t, result.Status.Message, "Working status should have a message") + require.NotEmpty(t, result.Status.Message.Parts, "Working status message should have parts") + textPart, ok := result.Status.Message.Parts[0].(*protocol.TextPart) + require.True(t, ok, "Working status message should have text part") + require.Contains(t, textPart.Text, "Processing chunk:", "Working status should contain processing info") + } else if result.Status.State == protocol.TaskStateCompleted { + hasCompletedStatus = true + require.NotNil(t, result.Status.Message, "Completed status should have a message") + require.NotEmpty(t, result.Status.Message.Parts, "Completed status message should have parts") + textPart, ok := result.Status.Message.Parts[0].(*protocol.TextPart) + require.True(t, ok, "Completed status message should have text part") + require.Contains(t, textPart.Text, "completed successfully", "Completed status should contain success info") + require.Contains(t, textPart.Text, "!dlrow olleH", "Completed status should contain reversed text") + } + case *protocol.TaskArtifactUpdateEvent: + hasArtifact = true + require.NotNil(t, result.Artifact.Name, "Artifact should have a name") + require.Equal(t, "Processed Text", *result.Artifact.Name, "Artifact name should match") + require.NotEmpty(t, result.Artifact.Parts, "Artifact should have parts") + textPart, ok := result.Artifact.Parts[0].(*protocol.TextPart) + require.True(t, ok, "Artifact should have text part") + require.Equal(t, "!dlrow olleH", textPart.Text, "Artifact should contain reversed text") + } + } + + // Verify we got all expected event types + require.True(t, hasWorkingStatus, "Should have received working status updates") + require.True(t, hasArtifact, "Should have received artifact update") + require.True(t, hasCompletedStatus, "Should have received completed status") + + t.Logf("Successfully received %d events", len(events)) +} + +// TestE2E_MessageAPI_NonStreaming tests the non-streaming functionality using the new message API. +func TestE2E_MessageAPI_NonStreaming(t *testing.T) { + helper := newTestHelper(t, &testProcessor{}) + defer helper.cleanup() + + // Test data + inputText := "Hello world!" + + // Generate context ID and task ID + contextID := protocol.GenerateContextID() + taskID := protocol.GenerateMessageID() + + // Create message using the NewMessageWithContext constructor + message := protocol.NewMessageWithContext( + protocol.MessageRoleUser, + []protocol.Part{ + protocol.NewTextPart(inputText), + }, + &taskID, + &contextID, + ) + + // Send message using the new non-streaming API + result, err := helper.client.SendMessage( + context.Background(), + protocol.SendMessageParams{ + Message: message, + }, + ) + require.NoError(t, err) + + // Verify the result contains a task + task, ok := result.Result.(*protocol.Task) + require.True(t, ok, "Result should contain a task") + require.NotNil(t, task, "Task should not be nil") + + // Wait a bit for the task to complete + time.Sleep(500 * time.Millisecond) + + // Get the final task state + finalTask, err := helper.client.GetTasks( + context.Background(), + protocol.TaskQueryParams{ID: task.ID}, + ) + require.NoError(t, err) + require.Equal(t, protocol.TaskStateCompleted, finalTask.Status.State) + + // Verify artifacts + require.NotEmpty(t, finalTask.Artifacts, "Task should have artifacts") + require.Equal(t, 1, len(finalTask.Artifacts), "Task should have 1 artifact") + + // Verify artifact content + artifact := finalTask.Artifacts[0] + require.NotNil(t, artifact.Parts, "Artifact should have parts") + require.Equal(t, 1, len(artifact.Parts), "Artifact should have 1 part") + + // Check the reversed text + reversedText := getTextPartContent(artifact.Parts) + expectedText := testReverseString(inputText) + require.Equal(t, expectedText, reversedText, "Artifact should contain reversed text") +} + +// TestE2E_TaskAPI_Streaming tests the streaming functionality using the legacy task API. +func TestE2E_TaskAPI_Streaming(t *testing.T) { + helper := newTestHelper(t, &testProcessor{}) defer helper.cleanup() // Test code taskID := "test-streaming-1" inputText := "Hello world!" + // Generate context ID and task ID + contextID := protocol.GenerateContextID() + // Subscribe to task events eventChan, err := helper.client.StreamTask( context.Background(), protocol.SendTaskParams{ ID: taskID, - Message: protocol.Message{ - Role: protocol.MessageRoleUser, - Parts: []protocol.Part{ + Message: protocol.NewMessageWithContext( + protocol.MessageRoleUser, + []protocol.Part{ protocol.NewTextPart(inputText), }, - }, + &taskID, + &contextID, + ), }) require.NoError(t, err) @@ -478,8 +625,11 @@ func TestE2E_BasicAgent_Streaming(t *testing.T) { // Check final state lastEvent := events[len(events)-1] - require.Equal(t, protocol.TaskStateCompleted, - lastEvent.(protocol.TaskStatusUpdateEvent).Status.State, "Task should be completed") + if statusEvent, ok := lastEvent.(*protocol.TaskStatusUpdateEvent); ok { + require.Equal(t, protocol.TaskStateCompleted, statusEvent.Status.State, "Task should be completed") + } else { + t.Fatalf("Last event should be a TaskStatusUpdateEvent") + } // Verify task result task, err := helper.client.GetTasks(