diff --git a/pkg/stream/client.go b/pkg/stream/client.go index a6c0f0ae..6b17ce79 100644 --- a/pkg/stream/client.go +++ b/pkg/stream/client.go @@ -471,13 +471,13 @@ func (c *Client) closeHartBeat() { } func (c *Client) Close() error { - c.closeHartBeat() - for _, p := range c.coordinator.Producers() { - err := c.coordinator.RemoveProducerById(p.(*Producer).id, Event{ + c.coordinator.Producers().Range(func(_, p any) bool { + producer := p.(*Producer) + err := c.coordinator.RemoveProducerById(producer.id, Event{ Command: CommandClose, - StreamName: p.(*Producer).GetStreamName(), - Name: p.(*Producer).GetName(), + StreamName: producer.GetStreamName(), + Name: producer.GetName(), Reason: SocketClosed, Err: nil, }) @@ -485,23 +485,27 @@ func (c *Client) Close() error { if err != nil { logs.LogWarn("error removing producer: %s", err) } - } - for _, cs := range c.coordinator.GetConsumers() { - if cs != nil { - err := c.coordinator.RemoveConsumerById(cs.(*Consumer).ID, Event{ - Command: CommandClose, - StreamName: cs.(*Consumer).GetStreamName(), - Name: cs.(*Consumer).GetName(), - Reason: SocketClosed, - Err: nil, - }) + return true + }) - if err != nil { - logs.LogWarn("error removing consumer: %s", err) - } + c.coordinator.Consumers().Range(func(_, cs any) bool { + consumer := cs.(*Consumer) + err := c.coordinator.RemoveConsumerById(consumer.ID, Event{ + Command: CommandClose, + StreamName: consumer.GetStreamName(), + Name: consumer.GetName(), + Reason: SocketClosed, + Err: nil, + }) + + if err != nil { + logs.LogWarn("error removing consumer: %s", err) } - } + + return true + }) + if c.getSocket().isOpen() { res := c.coordinator.NewResponse(CommandClose) diff --git a/pkg/stream/coordinator.go b/pkg/stream/coordinator.go index 18d952ea..c2c99b8e 100644 --- a/pkg/stream/coordinator.go +++ b/pkg/stream/coordinator.go @@ -11,8 +11,8 @@ import ( type Coordinator struct { counter int - producers map[interface{}]interface{} - consumers map[interface{}]interface{} + producers *sync.Map + consumers *sync.Map responses map[interface{}]interface{} nextItemProducer uint8 nextItemConsumer uint8 @@ -43,8 +43,8 @@ type Response struct { func NewCoordinator() *Coordinator { return &Coordinator{mutex: &sync.Mutex{}, - producers: make(map[interface{}]interface{}), - consumers: make(map[interface{}]interface{}), + producers: &sync.Map{}, + consumers: &sync.Map{}, responses: make(map[interface{}]interface{})} } @@ -77,7 +77,7 @@ func (coordinator *Coordinator) NewProducer( confirmMutex: &sync.Mutex{}, onClose: cleanUp, } - coordinator.producers[lastId] = producer + coordinator.producers.Store(lastId, producer) return producer, err } @@ -89,11 +89,8 @@ func (coordinator *Coordinator) RemoveConsumerById(id interface{}, reason Event) return consumer.close(reason) } -func (coordinator *Coordinator) GetConsumers() map[interface{}]interface{} { - coordinator.mutex.Lock() - defer coordinator.mutex.Unlock() +func (coordinator *Coordinator) Consumers() *sync.Map { return coordinator.consumers - } func (coordinator *Coordinator) RemoveProducerById(id uint8, reason Event) error { @@ -117,7 +114,7 @@ func (coordinator *Coordinator) RemoveResponseById(id interface{}) error { } func (coordinator *Coordinator) ProducersCount() int { - return coordinator.count(coordinator.producers) + return coordinator.countSyncMap(coordinator.producers) } // response @@ -198,28 +195,25 @@ func (coordinator *Coordinator) NewConsumer(messagesHandler MessagesHandler, onClose: cleanUp, } - coordinator.consumers[lastId] = item + coordinator.consumers.Store(lastId, item) + return item } func (coordinator *Coordinator) GetConsumerById(id interface{}) (*Consumer, error) { - v, err := coordinator.getById(id, coordinator.consumers) - if err != nil { - return nil, err + if consumer, exists := coordinator.consumers.Load(id); exists { + return consumer.(*Consumer), nil } - return v.(*Consumer), err + + return nil, errors.New("item #{id} not found ") } func (coordinator *Coordinator) ExtractConsumerById(id interface{}) (*Consumer, error) { - coordinator.mutex.Lock() - defer coordinator.mutex.Unlock() - if coordinator.consumers[id] == nil { - return nil, errors.New("item #{id} not found ") + if consumer, exists := coordinator.consumers.LoadAndDelete(id); exists { + return consumer.(*Consumer), nil } - consumer := coordinator.consumers[id].(*Consumer) - coordinator.consumers[id] = nil - delete(coordinator.consumers, id) - return consumer, nil + + return nil, errors.New("item #{id} not found ") } func (coordinator *Coordinator) GetResponseById(id uint32) (*Response, error) { @@ -231,31 +225,26 @@ func (coordinator *Coordinator) GetResponseById(id uint32) (*Response, error) { } func (coordinator *Coordinator) ConsumersCount() int { - return coordinator.count(coordinator.consumers) + return coordinator.countSyncMap(coordinator.consumers) } func (coordinator *Coordinator) GetProducerById(id interface{}) (*Producer, error) { - v, err := coordinator.getById(id, coordinator.producers) - if err != nil { - return nil, err + if producer, exists := coordinator.producers.Load(id); exists { + return producer.(*Producer), nil } - return v.(*Producer), err + + return nil, errors.New("item #{id} not found ") } func (coordinator *Coordinator) ExtractProducerById(id interface{}) (*Producer, error) { - coordinator.mutex.Lock() - defer coordinator.mutex.Unlock() - if coordinator.producers[id] == nil { - return nil, errors.New("item #{id} not found ") + if producer, exists := coordinator.producers.LoadAndDelete(id); exists { + return producer.(*Producer), nil } - producer := coordinator.producers[id].(*Producer) - coordinator.producers[id] = nil - delete(coordinator.producers, id) - return producer, nil + + return nil, errors.New("item #{id} not found ") } // general functions - func (coordinator *Coordinator) getById(id interface{}, refmap map[interface{}]interface{}) (interface{}, error) { coordinator.mutex.Lock() defer coordinator.mutex.Unlock() @@ -276,11 +265,16 @@ func (coordinator *Coordinator) removeById(id interface{}, refmap map[interface{ return nil } -func (coordinator *Coordinator) count(refmap map[interface{}]interface{}) int { - coordinator.mutex.Lock() - defer coordinator.mutex.Unlock() - return len(refmap) +func (coordinator *Coordinator) countSyncMap(refmap *sync.Map) int { + count := 0 + refmap.Range(func(_, _ interface{}) bool { + count++ + return true + }) + + return count } + func (coordinator *Coordinator) getNextProducerItem() (uint8, error) { if coordinator.nextItemProducer >= ^uint8(0) { return coordinator.reuseFreeId(coordinator.producers) @@ -299,11 +293,11 @@ func (coordinator *Coordinator) getNextConsumerItem() (uint8, error) { return res, nil } -func (coordinator *Coordinator) reuseFreeId(refMap map[interface{}]interface{}) (byte, error) { +func (coordinator *Coordinator) reuseFreeId(refMap *sync.Map) (byte, error) { maxValue := int(^uint8(0)) var result byte for i := 0; i < maxValue; i++ { - if refMap[byte(i)] == nil { + if _, exists := refMap.Load(byte(i)); !exists { return byte(i), nil } result++ @@ -314,8 +308,20 @@ func (coordinator *Coordinator) reuseFreeId(refMap map[interface{}]interface{}) return result, nil } -func (coordinator *Coordinator) Producers() map[interface{}]interface{} { - coordinator.mutex.Lock() - defer coordinator.mutex.Unlock() +func (coordinator *Coordinator) Producers() *sync.Map { return coordinator.producers } + +func (coordinator *Coordinator) Close() { + coordinator.producers.Range(func(_, producer interface{}) bool { + _ = producer.(*Producer).Close() + + return true + }) + + coordinator.consumers.Range(func(_, consumer interface{}) bool { + _ = consumer.(*Consumer).Close() + + return true + }) +} diff --git a/pkg/stream/environment.go b/pkg/stream/environment.go index 3d777a5e..9d886010 100644 --- a/pkg/stream/environment.go +++ b/pkg/stream/environment.go @@ -510,42 +510,43 @@ func (cc *environmentCoordinator) maybeCleanClients() { } func (c *Client) maybeCleanProducers(streamName string) { - c.mutex.Lock() - for pidx, producer := range c.coordinator.Producers() { - if producer.(*Producer).GetStreamName() == streamName { + c.coordinator.Producers().Range(func(pidx, p any) bool { + producer := p.(*Producer) + if producer.GetStreamName() == streamName { err := c.coordinator.RemoveProducerById(pidx.(uint8), Event{ Command: CommandMetadataUpdate, StreamName: streamName, - Name: producer.(*Producer).GetName(), + Name: producer.GetName(), Reason: MetaDataUpdate, Err: nil, }) if err != nil { - return + return false } } - } - c.mutex.Unlock() + return true + }) } func (c *Client) maybeCleanConsumers(streamName string) { - c.mutex.Lock() - for pidx, consumer := range c.coordinator.consumers { - if consumer.(*Consumer).options.streamName == streamName { + c.coordinator.Consumers().Range(func(pidx, cs any) bool { + consumer := cs.(*Consumer) + if consumer.options.streamName == streamName { err := c.coordinator.RemoveConsumerById(pidx.(uint8), Event{ Command: CommandMetadataUpdate, StreamName: streamName, - Name: consumer.(*Consumer).GetName(), + Name: consumer.GetName(), Reason: MetaDataUpdate, Err: nil, }) if err != nil { - return + return false } } - } - c.mutex.Unlock() + + return true + }) } func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCPParameters, saslConfiguration *SaslConfiguration, streamName string, options *ProducerOptions, rpcTimeout time.Duration, cleanUp func()) (*Producer, error) { @@ -643,15 +644,9 @@ func (cc *environmentCoordinator) newConsumer(connectionName string, leader *Bro } func (cc *environmentCoordinator) Close() error { - cc.clientsPerContext.Range(func(key, value any) bool { - client := value.(*Client) - for i := range client.coordinator.producers { - _ = client.coordinator.producers[i].(*Producer).Close() - } - for i := range client.coordinator.consumers { - _ = client.coordinator.consumers[i].(*Consumer).Close() - } + value.(*Client).coordinator.Close() + return true }) diff --git a/pkg/stream/super_stream_producer_test.go b/pkg/stream/super_stream_producer_test.go index 3c79eb89..d6052923 100644 --- a/pkg/stream/super_stream_producer_test.go +++ b/pkg/stream/super_stream_producer_test.go @@ -476,4 +476,61 @@ var _ = Describe("Super Stream Producer", Label("super-stream-producer"), func() Expect(env.Close()).NotTo(HaveOccurred()) }) + It("should reconnect to the same partition after a close event", func() { + const partitionsCount = 3 + env, err := NewEnvironment(nil) + Expect(err).NotTo(HaveOccurred()) + + var superStream = fmt.Sprintf("reconnect-test-super-stream-%d", time.Now().Unix()) + Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(partitionsCount))).NotTo(HaveOccurred()) + + superProducer, err := newSuperStreamProducer(env, superStream, &SuperStreamProducerOptions{ + RoutingStrategy: NewHashRoutingStrategy(func(msg message.StreamMessage) string { + return msg.GetApplicationProperties()["routingKey"].(string) + }), + }) + Expect(err).To(BeNil()) + Expect(superProducer).NotTo(BeNil()) + Expect(superProducer.init()).NotTo(HaveOccurred()) + producers := superProducer.getProducers() + Expect(producers).To(HaveLen(partitionsCount)) + partitionToClose := producers[0].GetStreamName() + + // Declare synchronization helpers and listeners + partitionCloseEvent := make(chan bool) + + // Listen for the partition close event and try to reconnect + go func(ch <-chan PPartitionClose) { + for event := range ch { + err := event.Context.ConnectPartition(event.Partition) + Expect(err).To(BeNil()) + + partitionCloseEvent <- true + + break + + } + }(superProducer.NotifyPartitionClose(1)) + + // Imitates metadataUpdateFrameHandler - it can happen when stream members are changed. + go func() { + client, ok := env.producers.getCoordinators()["localhost:5552"].clientsPerContext.Load(1) + Expect(ok).To(BeTrue()) + client.(*Client).maybeCleanProducers(partitionToClose) + }() + + // Wait for the partition close event + Eventually(partitionCloseEvent).WithTimeout(5 * time.Second).WithPolling(100 * time.Millisecond).Should(Receive()) + + // Verify that the partition was successfully reconnected + Expect(superProducer.getProducers()).To(HaveLen(partitionsCount)) + reconnectedProducer := superProducer.getProducer(partitionToClose) + Expect(reconnectedProducer).NotTo(BeNil()) + + // Clean up + Expect(superProducer.Close()).NotTo(HaveOccurred()) + Expect(env.DeleteSuperStream(superStream)).NotTo(HaveOccurred()) + Expect(env.Close()).NotTo(HaveOccurred()) + }) + })