diff --git a/pkg/stream/client_test.go b/pkg/stream/client_test.go index f7a9f017..3356f579 100644 --- a/pkg/stream/client_test.go +++ b/pkg/stream/client_test.go @@ -213,4 +213,21 @@ var _ = Describe("Streaming testEnvironment", func() { Expect(res).To(BeNil()) }) + It("Client.handleGenericResponse handles timeout and missing response gracefully", func() { + cli := newClient("connName", nil, nil, nil, defaultSocketCallTimeout) + + // Simulate timeout: create a response and remove it immediately + res := cli.coordinator.NewResponse(commandDeclarePublisher, "Simulated Test") + err := cli.coordinator.RemoveResponseById(res.correlationid) + Expect(err).To(BeNil()) + + // Simulate receiving a response for the removed correlation ID + readerProtocol := &ReaderProtocol{ + CorrelationId: uint32(res.correlationid), + ResponseCode: responseCodeStreamNotAvailable, + } + cli.handleGenericResponse(readerProtocol, bufio.NewReader(bytes.NewBuffer([]byte{}))) + + }) + }) diff --git a/pkg/stream/coordinator.go b/pkg/stream/coordinator.go index c2c99b8e..1183b1fe 100644 --- a/pkg/stream/coordinator.go +++ b/pkg/stream/coordinator.go @@ -221,7 +221,7 @@ func (coordinator *Coordinator) GetResponseById(id uint32) (*Response, error) { if err != nil { return nil, err } - return v.(*Response), err + return v.(*Response), nil } func (coordinator *Coordinator) ConsumersCount() int { diff --git a/pkg/stream/server_frame.go b/pkg/stream/server_frame.go index a5d48412..c0953529 100644 --- a/pkg/stream/server_frame.go +++ b/pkg/stream/server_frame.go @@ -146,7 +146,7 @@ func (c *Client) handleResponse() { } } -func (c *Client) handleSaslHandshakeResponse(streamingRes *ReaderProtocol, r *bufio.Reader) interface{} { +func (c *Client) handleSaslHandshakeResponse(streamingRes *ReaderProtocol, r *bufio.Reader) { streamingRes.CorrelationId, _ = readUInt(r) streamingRes.ResponseCode = uShortExtractResponseCode(readUShort(r)) mechanismsCount, _ := readUInt(r) @@ -158,12 +158,11 @@ func (c *Client) handleSaslHandshakeResponse(streamingRes *ReaderProtocol, r *bu res, err := c.coordinator.GetResponseById(streamingRes.CorrelationId) if err != nil { - // TODO handle response - return err + logErrorCommand(err, "handleSaslHandshakeResponse") + return } - res.data <- mechanisms - return mechanisms + res.data <- mechanisms } func (c *Client) handlePeerProperties(readProtocol *ReaderProtocol, r *bufio.Reader) { @@ -178,7 +177,11 @@ func (c *Client) handlePeerProperties(readProtocol *ReaderProtocol, r *bufio.Rea serverProperties[key] = value } res, err := c.coordinator.GetResponseById(readProtocol.CorrelationId) - logErrorCommand(err, "handlePeerProperties") + if err != nil { + logErrorCommand(err, "handlePeerProperties") + return + } + res.code <- Code{id: readProtocol.ResponseCode} res.data <- serverProperties @@ -210,7 +213,11 @@ func (c *Client) handleGenericResponse(readProtocol *ReaderProtocol, r *bufio.Re readProtocol.CorrelationId, _ = readUInt(r) readProtocol.ResponseCode = uShortExtractResponseCode(readUShort(r)) res, err := c.coordinator.GetResponseById(readProtocol.CorrelationId) - logErrorCommand(err, "handleGenericResponse") + if err != nil { + logErrorCommand(err, "handleGenericResponse") + return + } + res.code <- Code{id: readProtocol.ResponseCode} } @@ -237,7 +244,11 @@ func (c *Client) commandOpen(readProtocol *ReaderProtocol, r *bufio.Reader) { } res, err := c.coordinator.GetResponseById(readProtocol.CorrelationId) - logErrorCommand(err, "commandOpen") + if err != nil { + logErrorCommand(err, "commandOpen") + return + } + res.code <- Code{id: readProtocol.ResponseCode} res.data <- clientProperties @@ -277,7 +288,11 @@ func (c *Client) queryPublisherSequenceFrameHandler(readProtocol *ReaderProtocol readProtocol.ResponseCode = uShortExtractResponseCode(readUShort(r)) sequence := readInt64(r) res, err := c.coordinator.GetResponseById(readProtocol.CorrelationId) - logErrorCommand(err, "queryPublisherSequenceFrameHandler") + if err != nil { + logErrorCommand(err, "queryPublisherSequenceFrameHandler") + return + } + res.code <- Code{id: readProtocol.ResponseCode} res.data <- sequence } @@ -458,7 +473,11 @@ func (c *Client) queryOffsetFrameHandler(readProtocol *ReaderProtocol, c.handleGenericResponse(readProtocol, r) offset := readInt64(r) res, err := c.coordinator.GetResponseById(readProtocol.CorrelationId) - logErrorCommand(err, "queryOffsetFrameHandler") + if err != nil { + logErrorCommand(err, "queryOffsetFrameHandler") + return + } + res.data <- offset } @@ -516,7 +535,11 @@ func (c *Client) streamStatusFrameHandler(readProtocol *ReaderProtocol, streamStatus[key] = value } res, err := c.coordinator.GetResponseById(readProtocol.CorrelationId) - logErrorCommand(err, "streamStatusFrameHandler") + if err != nil { + logErrorCommand(err, "streamStatusFrameHandler") + return + } + res.code <- Code{id: readProtocol.ResponseCode} res.data <- streamStatus @@ -553,7 +576,10 @@ func (c *Client) metadataFrameHandler(readProtocol *ReaderProtocol, } res, err := c.coordinator.GetResponseById(readProtocol.CorrelationId) - logErrorCommand(err, "metadataFrameHandler") + if err != nil { + logErrorCommand(err, "metadataFrameHandler") + return + } res.code <- Code{id: readProtocol.ResponseCode} res.data <- streamsMetadata @@ -612,7 +638,11 @@ func (c *Client) handleQueryPartitions(readProtocol *ReaderProtocol, r *bufio.Re partitions = append(partitions, partition) } res, err := c.coordinator.GetResponseById(readProtocol.CorrelationId) - logErrorCommand(err, "handleQueryPartitions") + if err != nil { + logErrorCommand(err, "handleQueryPartitions") + return + } + res.code <- Code{id: readProtocol.ResponseCode} res.data <- partitions } @@ -629,7 +659,11 @@ func (c *Client) handleQueryRoute(readProtocol *ReaderProtocol, r *bufio.Reader) } res, err := c.coordinator.GetResponseById(readProtocol.CorrelationId) - logErrorCommand(err, "handleQueryRoute") + if err != nil { + logErrorCommand(err, "handleQueryRoute") + return + } + res.code <- Code{id: readProtocol.ResponseCode} res.data <- routes } @@ -646,7 +680,11 @@ func (c *Client) handleExchangeVersionResponse(readProtocol *ReaderProtocol, r * commands = append(commands, newCommandVersionResponse(minVersion, maxVersion, commandKey)) } res, err := c.coordinator.GetResponseById(readProtocol.CorrelationId) - logErrorCommand(err, "handleExchangeVersionResponse") + if err != nil { + logErrorCommand(err, "handleExchangeVersionResponse") + return + } + res.code <- Code{id: readProtocol.ResponseCode} res.data <- commands }