From 483c3c386161cb63c8f9182928c2622b9d1f6423 Mon Sep 17 00:00:00 2001 From: Gabriele Santomaggio Date: Tue, 21 Jan 2025 10:28:31 +0100 Subject: [PATCH] replace map with syncmap Signed-off-by: Gabriele Santomaggio --- pkg/stream/environment.go | 90 ++++++++++++++++++++-------------- pkg/stream/environment_test.go | 57 +++++++++++++++++++++ 2 files changed, 109 insertions(+), 38 deletions(-) diff --git a/pkg/stream/environment.go b/pkg/stream/environment.go index 06499cc6..17903f07 100644 --- a/pkg/stream/environment.go +++ b/pkg/stream/environment.go @@ -470,32 +470,41 @@ func (envOptions *EnvironmentOptions) SetRPCTimeout(timeout time.Duration) *Envi type environmentCoordinator struct { mutex *sync.Mutex - mutexContext *sync.RWMutex - clientsPerContext map[int]*Client + clientsPerContext sync.Map maxItemsForClient int nextId int } func (cc *environmentCoordinator) isProducerListFull(clientsPerContextId int) bool { - return cc.clientsPerContext[clientsPerContextId].coordinator. - ProducersCount() >= cc.maxItemsForClient + client, ok := cc.clientsPerContext.Load(clientsPerContextId) + if !ok { + logs.LogError("client not found") + return false + } + return client.(*Client).coordinator.ProducersCount() >= cc.maxItemsForClient + } func (cc *environmentCoordinator) isConsumerListFull(clientsPerContextId int) bool { - return cc.clientsPerContext[clientsPerContextId].coordinator. - ConsumersCount() >= cc.maxItemsForClient + client, ok := cc.clientsPerContext.Load(clientsPerContextId) + if !ok { + logs.LogError("client not found") + return false + } + return client.(*Client).coordinator.ConsumersCount() >= cc.maxItemsForClient } func (cc *environmentCoordinator) maybeCleanClients() { cc.mutex.Lock() defer cc.mutex.Unlock() - cc.mutexContext.Lock() - defer cc.mutexContext.Unlock() - for i, client := range cc.clientsPerContext { + + cc.clientsPerContext.Range(func(key, value any) bool { + client := value.(*Client) if !client.socket.isOpen() { - delete(cc.clientsPerContext, i) + cc.clientsPerContext.Delete(key) } - } + return true + }) } func (c *Client) maybeCleanProducers(streamName string) { @@ -541,15 +550,16 @@ func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCP options *ProducerOptions, rpcTimeout time.Duration) (*Producer, error) { cc.mutex.Lock() defer cc.mutex.Unlock() - cc.mutexContext.Lock() - defer cc.mutexContext.Unlock() var clientResult *Client - for i, client := range cc.clientsPerContext { - if !cc.isProducerListFull(i) { - clientResult = client - break + + cc.clientsPerContext.Range(func(key, value any) bool { + if !cc.isProducerListFull(key.(int)) { + clientResult = value.(*Client) + return false } - } + return true + }) + clientProvidedName := "go-stream-producer" if options != nil && options.ClientProvidedName != "" { clientProvidedName = options.ClientProvidedName @@ -593,7 +603,8 @@ func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCP func (cc *environmentCoordinator) newClientForProducer(connectionName string, leader *Broker, tcpParameters *TCPParameters, saslConfiguration *SaslConfiguration, rpcTimeOut time.Duration) *Client { clientResult := newClient(connectionName, leader, tcpParameters, saslConfiguration, rpcTimeOut) cc.nextId++ - cc.clientsPerContext[cc.nextId] = clientResult + + cc.clientsPerContext.Store(cc.nextId, clientResult) return clientResult } @@ -602,20 +613,20 @@ func (cc *environmentCoordinator) newConsumer(connectionName string, leader *Bro options *ConsumerOptions, rpcTimeout time.Duration) (*Consumer, error) { cc.mutex.Lock() defer cc.mutex.Unlock() - cc.mutexContext.Lock() - defer cc.mutexContext.Unlock() var clientResult *Client - for i, client := range cc.clientsPerContext { - if !cc.isConsumerListFull(i) { - clientResult = client - break + + cc.clientsPerContext.Range(func(key, value any) bool { + if !cc.isConsumerListFull(key.(int)) { + clientResult = value.(*Client) + return false } - } + return true + }) if clientResult == nil { clientResult = newClient(connectionName, leader, tcpParameters, saslConfiguration, rpcTimeout) cc.nextId++ - cc.clientsPerContext[cc.nextId] = clientResult + cc.clientsPerContext.Store(cc.nextId, clientResult) } // try to reconnect in case the socket is closed err := clientResult.connect() @@ -632,23 +643,28 @@ func (cc *environmentCoordinator) newConsumer(connectionName string, leader *Bro } func (cc *environmentCoordinator) Close() error { - cc.mutexContext.Lock() - defer cc.mutexContext.Unlock() - for _, client := range cc.clientsPerContext { + + cc.clientsPerContext.Range(func(key, value any) bool { + client := value.(*Client) for i := range client.coordinator.producers { _ = client.coordinator.producers[i].(*Producer).Close() } for i := range client.coordinator.consumers { _ = client.coordinator.consumers[i].(*Consumer).Close() } - } + return true + }) + return nil } func (cc *environmentCoordinator) getClientsPerContext() map[int]*Client { - cc.mutexContext.Lock() - defer cc.mutexContext.Unlock() - return cc.clientsPerContext + clients := map[int]*Client{} + cc.clientsPerContext.Range(func(key, value any) bool { + clients[key.(int)] = value.(*Client) + return true + }) + return clients } type producersEnvironment struct { @@ -677,10 +693,9 @@ func (ps *producersEnvironment) newProducer(clientLocator *Client, streamName st coordinatorKey := leader.hostPort() if ps.producersCoordinator[coordinatorKey] == nil { ps.producersCoordinator[coordinatorKey] = &environmentCoordinator{ - clientsPerContext: map[int]*Client{}, + clientsPerContext: sync.Map{}, mutex: &sync.Mutex{}, maxItemsForClient: ps.maxItemsForClient, - mutexContext: &sync.RWMutex{}, nextId: 0, } } @@ -742,10 +757,9 @@ func (ps *consumersEnvironment) NewSubscriber(clientLocator *Client, streamName coordinatorKey := consumerBroker.hostPort() if ps.consumersCoordinator[coordinatorKey] == nil { ps.consumersCoordinator[coordinatorKey] = &environmentCoordinator{ - clientsPerContext: map[int]*Client{}, + clientsPerContext: sync.Map{}, mutex: &sync.Mutex{}, maxItemsForClient: ps.maxItemsForClient, - mutexContext: &sync.RWMutex{}, nextId: 0, } } diff --git a/pkg/stream/environment_test.go b/pkg/stream/environment_test.go index 7177dfc7..55f12386 100644 --- a/pkg/stream/environment_test.go +++ b/pkg/stream/environment_test.go @@ -2,6 +2,7 @@ package stream import ( "crypto/tls" + "github.com/rabbitmq/rabbitmq-stream-go-client/pkg/amqp" "sync" "time" @@ -419,4 +420,60 @@ var _ = Describe("Environment test", func() { Expect(env.Close()).NotTo(HaveOccurred()) }) + It("close env should close all the producers and consumers ", func() { + env, err := NewEnvironment(NewEnvironmentOptions(). + SetMaxConsumersPerClient(2). + SetMaxConsumersPerClient(3)) + + Expect(err).NotTo(HaveOccurred()) + streamName := uuid.New().String() + Expect(env.DeclareStream(streamName, nil)).NotTo(HaveOccurred()) + for i := 0; i < 5; i++ { + _, err := env.NewProducer(streamName, nil) + Expect(err).NotTo(HaveOccurred()) + } + + for i := 0; i < 5; i++ { + _, err := env.NewConsumer(streamName, func(consumerContext ConsumerContext, message *amqp.Message) { + + }, nil) + Expect(err).NotTo(HaveOccurred()) + } + + // count element sync map + count := 0 + env.consumers.getCoordinators()["localhost:5552"].clientsPerContext.Range(func(key, value any) bool { + Expect(value).NotTo(BeNil()) + count++ + return true + }) + + Expect(count).To(Equal(2)) + + Expect(env.Close()).NotTo(HaveOccurred()) + + // count element sync map + + Eventually(func() int { + count = 0 + env.producers.getCoordinators()["localhost:5552"].clientsPerContext.Range(func(key, value any) bool { + Expect(value).To(BeNil()) + count++ + return true + }) + return count + }, "5s", "1s").Should(Equal(0)) + + Eventually(func() int { + count = 0 + env.consumers.getCoordinators()["localhost:5552"].clientsPerContext.Range(func(key, value any) bool { + Expect(value).To(BeNil()) + count++ + return true + }) + return count + }, "5s", "1s").Should(Equal(0)) + + }) + })