Skip to content

Handle socket errors #404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 108 additions & 132 deletions pkg/stream/aggregation.go

Large diffs are not rendered by default.

24 changes: 13 additions & 11 deletions pkg/stream/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package stream
import (
"bufio"
"bytes"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("Compression algorithms", func() {

var entries *subEntries

BeforeEach(func() {
Expand All @@ -35,46 +35,48 @@ var _ = Describe("Compression algorithms", func() {
})

It("NONE", func() {
compressNONE{}.Compress(entries)
err := compressNONE{}.Compress(entries)
Expect(err).NotTo(HaveOccurred())
Expect(entries.totalSizeInBytes).To(Equal(entries.items[0].sizeInBytes))
Expect(entries.totalSizeInBytes).To(Equal(entries.items[0].unCompressedSize))

})

It("GZIP", func() {
gzip := compressGZIP{}
gzip.Compress(entries)
err := gzip.Compress(entries)
Expect(err).NotTo(HaveOccurred())
verifyCompression(gzip, entries)

})

It("SNAPPY", func() {
snappy := compressSnappy{}
snappy.Compress(entries)
err := snappy.Compress(entries)
Expect(err).NotTo(HaveOccurred())
verifyCompression(snappy, entries)
})

It("LZ4", func() {
lz4 := compressLZ4{}
lz4.Compress(entries)
err := lz4.Compress(entries)
Expect(err).NotTo(HaveOccurred())
verifyCompression(lz4, entries)
})

It("ZSTD", func() {
zstd := compressZSTD{}
zstd.Compress(entries)
err := zstd.Compress(entries)
Expect(err).NotTo(HaveOccurred())
verifyCompression(zstd, entries)
})

})

func verifyCompression(algo iCompress, subEntries *subEntries) {

Expect(subEntries.totalSizeInBytes).To(SatisfyAll(BeNumerically("<", subEntries.items[0].unCompressedSize)))
Expect(subEntries.totalSizeInBytes).To(Equal(subEntries.items[0].sizeInBytes))

bufferReader := bytes.NewReader(subEntries.items[0].dataInBytes)
algo.UnCompress(bufio.NewReader(bufferReader),
_, err := algo.UnCompress(bufio.NewReader(bufferReader),
uint32(subEntries.totalSizeInBytes), uint32(subEntries.items[0].unCompressedSize))

Expect(err).NotTo(HaveOccurred())
}
89 changes: 57 additions & 32 deletions pkg/stream/buffer_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ func writeULong(inputBuff *bytes.Buffer, value uint64) {
inputBuff.Write(buff)
}

func writeBLong(inputBuff *bufio.Writer, value int64) {
writeBULong(inputBuff, uint64(value))
func writeBLong(inputBuff *bufio.Writer, value int64) error {
return writeBULong(inputBuff, uint64(value))
}
func writeBULong(inputBuff *bufio.Writer, value uint64) {

func writeBULong(inputBuff *bufio.Writer, value uint64) error {
var buff = make([]byte, 8)
binary.BigEndian.PutUint64(buff, value)
inputBuff.Write(buff)
_, err := inputBuff.Write(buff)
return err
}

func writeShort(inputBuff *bytes.Buffer, value int16) {
Expand All @@ -35,18 +37,24 @@ func writeUShort(inputBuff *bytes.Buffer, value uint16) {
inputBuff.Write(buff)
}

func writeBShort(inputBuff *bufio.Writer, value int16) {
writeBUShort(inputBuff, uint16(value))
func writeBShort(inputBuff *bufio.Writer, value int16) error {
return writeBUShort(inputBuff, uint16(value))
}
func writeBUShort(inputBuff *bufio.Writer, value uint16) {
func writeBUShort(inputBuff *bufio.Writer, value uint16) error {
var buff = make([]byte, 2)
binary.BigEndian.PutUint16(buff, value)
inputBuff.Write(buff)
_, err := inputBuff.Write(buff)
return err
}

func writeBString(inputBuff *bufio.Writer, value string) {
writeBUShort(inputBuff, uint16(len(value)))
inputBuff.Write([]byte(value))
func writeBString(inputBuff *bufio.Writer, value string) error {
err := writeBUShort(inputBuff, uint16(len(value)))
if err != nil {
return err
}

_, err = inputBuff.Write([]byte(value))
return err
}

func writeByte(inputBuff *bytes.Buffer, value byte) {
Expand All @@ -55,10 +63,11 @@ func writeByte(inputBuff *bytes.Buffer, value byte) {
inputBuff.Write(buff)
}

func writeBByte(inputBuff *bufio.Writer, value byte) {
func writeBByte(inputBuff *bufio.Writer, value byte) error {
var buff = make([]byte, 1)
buff[0] = value
inputBuff.Write(buff)
_, err := inputBuff.Write(buff)
return err
}

func writeInt(inputBuff *bytes.Buffer, value int) {
Expand All @@ -70,21 +79,17 @@ func writeUInt(inputBuff *bytes.Buffer, value uint32) {
inputBuff.Write(buff)
}

func writeBInt(inputBuff *bufio.Writer, value int) {
writeBUInt(inputBuff, uint32(value))
func writeBInt(inputBuff *bufio.Writer, value int) error {
return writeBUInt(inputBuff, uint32(value))
}

func writeBUInt(inputBuff *bufio.Writer, value uint32) {
func writeBUInt(inputBuff *bufio.Writer, value uint32) error {
var buff = make([]byte, 4)
binary.BigEndian.PutUint32(buff, value)
inputBuff.Write(buff)
_, err := inputBuff.Write(buff)
return err
}

func bytesFromInt(value uint32) []byte {
var buff = make([]byte, 4)
binary.BigEndian.PutUint32(buff, value)
return buff
}
func writeString(inputBuff *bytes.Buffer, value string) {
writeUShort(inputBuff, uint16(len(value)))
inputBuff.Write([]byte(value))
Expand All @@ -110,7 +115,7 @@ func writeBytes(inputBuff *bytes.Buffer, value []byte) {
inputBuff.Write(value)
}

// writeProtocolHeader protocol utils functions
// writeProtocolHeader protocol utils functions
func writeProtocolHeader(inputBuff *bytes.Buffer,
length int, command uint16,
correlationId ...int) {
Expand All @@ -126,20 +131,30 @@ func writeProtocolHeader(inputBuff *bytes.Buffer,

func writeBProtocolHeader(inputBuff *bufio.Writer,
length int, command int16,
correlationId ...int) {
writeBProtocolHeaderVersion(inputBuff, length, command, version1, correlationId...)
correlationId ...int) error {
return writeBProtocolHeaderVersion(inputBuff, length, command, version1, correlationId...)
}

func writeBProtocolHeaderVersion(inputBuff *bufio.Writer,
length int, command int16, version int16,
correlationId ...int) {
func writeBProtocolHeaderVersion(inputBuff *bufio.Writer, length int, command int16,
version int16, correlationId ...int) error {

if err := writeBInt(inputBuff, length); err != nil {
return err
}
if err := writeBShort(inputBuff, command); err != nil {
return err
}
if err := writeBShort(inputBuff, version); err != nil {
return err
}

writeBInt(inputBuff, length)
writeBShort(inputBuff, command)
writeBShort(inputBuff, version)
if len(correlationId) > 0 {
writeBInt(inputBuff, correlationId[0])
if err := writeBInt(inputBuff, correlationId[0]); err != nil {
return err
}
}

return nil
}

func sizeOfStringArray(array []string) int {
Expand All @@ -157,3 +172,13 @@ func sizeOfMapStringString(mapString map[string]string) int {
}
return size
}

func bytesLenghPrefixed(msg []byte) []byte {
size := len(msg)
buff := make([]byte, 4+size)

binary.BigEndian.PutUint32(buff, uint32(size))
copy(buff[4:], msg)

return buff
}
3 changes: 1 addition & 2 deletions pkg/stream/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ func (c *Client) closeHartBeat() {

}

func (c *Client) Close() error {
func (c *Client) Close() {
c.closeHartBeat()
c.coordinator.Producers().Range(func(_, p any) bool {
producer := p.(*Producer)
Expand Down Expand Up @@ -522,7 +522,6 @@ func (c *Client) Close() error {
_ = c.coordinator.RemoveResponseById(res.correlationid)
}
c.getSocket().shutdown(nil)
return nil
}

func (c *Client) DeclarePublisher(streamName string, options *ProducerOptions) (*Producer, error) {
Expand Down
20 changes: 10 additions & 10 deletions pkg/stream/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package stream
import (
"bytes"
"fmt"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/amqp"
logs "github.com/rabbitmq/rabbitmq-stream-go-client/pkg/logs"
"sync"
"time"

"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/amqp"
logs "github.com/rabbitmq/rabbitmq-stream-go-client/pkg/logs"
)

type Consumer struct {
Expand Down Expand Up @@ -316,26 +317,27 @@ func (c *Client) credit(subscriptionId byte, credit int16) {
}

func (consumer *Consumer) Close() error {

if consumer.getStatus() == closed {
return AlreadyClosed
}
return consumer.close(Event{

consumer.close(Event{
Command: CommandUnsubscribe,
StreamName: consumer.GetStreamName(),
Name: consumer.GetName(),
Reason: UnSubscribe,
Err: nil,
})
}

func (consumer *Consumer) close(reason Event) error {
return nil
}

func (consumer *Consumer) close(reason Event) {
if consumer.options == nil {
// the config is usually set. this check is just to avoid panic and to make some test
// easier to write
logs.LogDebug("consumer options is nil, the close will be ignored")
return nil
return
}

consumer.cacheStoreOffset()
Expand Down Expand Up @@ -371,14 +373,12 @@ func (consumer *Consumer) close(reason Event) error {
_, _ = consumer.options.client.coordinator.ExtractConsumerById(consumer.ID)

if consumer.options != nil && consumer.options.client.coordinator.ConsumersCount() == 0 {
_ = consumer.options.client.Close()
consumer.options.client.Close()
}

if consumer.onClose != nil {
consumer.onClose()
}

return nil
}

func (consumer *Consumer) cacheStoreOffset() {
Expand Down
8 changes: 5 additions & 3 deletions pkg/stream/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package stream

import (
"fmt"
"github.com/pkg/errors"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/amqp"
"strconv"
"sync"
"time"

"github.com/pkg/errors"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/amqp"
)

type Coordinator struct {
Expand Down Expand Up @@ -86,7 +87,8 @@ func (coordinator *Coordinator) RemoveConsumerById(id interface{}, reason Event)
if err != nil {
return err
}
return consumer.close(reason)
consumer.close(reason)
return nil

}
func (coordinator *Coordinator) Consumers() *sync.Map {
Expand Down
14 changes: 3 additions & 11 deletions pkg/stream/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,7 @@ func NewEnvironment(options *EnvironmentOptions) (*Environment, error) {

client := newClient("go-stream-locator", nil,
options.TCPParameters, options.SaslConfiguration, options.RPCTimeout)
defer func(client *Client) {
err := client.Close()
if err != nil {
return
}
}(client)
defer client.Close()

// we put a limit to the heartbeat.
// it doesn't make sense to have a heartbeat less than 3 seconds
Expand Down Expand Up @@ -275,7 +270,7 @@ func (env *Environment) Close() error {
_ = env.producers.close()
_ = env.consumers.close()
if env.locator.client != nil {
_ = env.locator.client.Close()
env.locator.client.Close()
}
env.closed = true
return nil
Expand Down Expand Up @@ -581,10 +576,7 @@ func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCP
logs.LogDebug("connectionProperties host %s doesn't match with the advertised_host %s, advertised_port %s .. retry",
clientResult.connectionProperties.host,
leader.advHost, leader.advPort)
err := clientResult.Close()
if err != nil {
return nil, err
}
clientResult.Close()
clientResult = cc.newClientForProducer(clientProvidedName, leader, tcpParameters, saslConfiguration, rpcTimeout)
err = clientResult.connect()
if err != nil {
Expand Down
Loading