replace map with syncmap #377

Merged · 1 commit · Jan 21, 2025
90 changes: 52 additions & 38 deletions pkg/stream/environment.go
@@ -470,32 +470,41 @@ func (envOptions *EnvironmentOptions) SetRPCTimeout(timeout time.Duration) *Envi

 type environmentCoordinator struct {
 	mutex             *sync.Mutex
-	mutexContext      *sync.RWMutex
-	clientsPerContext map[int]*Client
+	clientsPerContext sync.Map
 	maxItemsForClient int
 	nextId            int
 }
 
 func (cc *environmentCoordinator) isProducerListFull(clientsPerContextId int) bool {
-	return cc.clientsPerContext[clientsPerContextId].coordinator.
-		ProducersCount() >= cc.maxItemsForClient
+	client, ok := cc.clientsPerContext.Load(clientsPerContextId)
+	if !ok {
+		logs.LogError("client not found")
+		return false
+	}
+	return client.(*Client).coordinator.ProducersCount() >= cc.maxItemsForClient
 }
 
 func (cc *environmentCoordinator) isConsumerListFull(clientsPerContextId int) bool {
-	return cc.clientsPerContext[clientsPerContextId].coordinator.
-		ConsumersCount() >= cc.maxItemsForClient
+	client, ok := cc.clientsPerContext.Load(clientsPerContextId)
+	if !ok {
+		logs.LogError("client not found")
+		return false
+	}
+	return client.(*Client).coordinator.ConsumersCount() >= cc.maxItemsForClient
 }
 
 func (cc *environmentCoordinator) maybeCleanClients() {
 	cc.mutex.Lock()
 	defer cc.mutex.Unlock()
-	cc.mutexContext.Lock()
-	defer cc.mutexContext.Unlock()
-	for i, client := range cc.clientsPerContext {
+	cc.clientsPerContext.Range(func(key, value any) bool {
+		client := value.(*Client)
 		if !client.socket.isOpen() {
-			delete(cc.clientsPerContext, i)
+			cc.clientsPerContext.Delete(key)
 		}
-	}
+		return true
+	})
 }
 
 func (c *Client) maybeCleanProducers(streamName string) {
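The hunk above replaces the RWMutex-guarded map[int]*Client with a sync.Map, whose zero value is empty and ready to use; the coordinator's plain mutex remains to serialize writers such as the nextId++ increment. Below is a minimal, self-contained sketch of the same Store/Load/Delete/Range usage, where the registry and client types are illustrative stand-ins rather than part of this codebase:

package main

import (
	"fmt"
	"sync"
)

type client struct{ open bool }

// registry is a hypothetical stand-in for environmentCoordinator.
type registry struct {
	clients sync.Map // holds int -> *client; the zero value is ready to use
	nextId  int      // guarded by an external mutex in the real coordinator
}

func (r *registry) add(c *client) int {
	r.nextId++
	r.clients.Store(r.nextId, c)
	return r.nextId
}

func (r *registry) get(id int) (*client, bool) {
	v, ok := r.clients.Load(id)
	if !ok {
		return nil, false
	}
	return v.(*client), true // values come back as any and need a type assertion
}

// clean removes closed clients, mirroring maybeCleanClients above.
func (r *registry) clean() {
	r.clients.Range(func(key, value any) bool {
		if !value.(*client).open {
			r.clients.Delete(key)
		}
		return true // keep visiting every entry
	})
}

func main() {
	r := &registry{}
	id := r.add(&client{open: false})
	r.clean()
	_, ok := r.get(id)
	fmt.Println(ok) // prints false: the closed client was removed
}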
@@ -541,15 +550,16 @@ func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCP
 	options *ProducerOptions, rpcTimeout time.Duration) (*Producer, error) {
 	cc.mutex.Lock()
 	defer cc.mutex.Unlock()
-	cc.mutexContext.Lock()
-	defer cc.mutexContext.Unlock()
 	var clientResult *Client
-	for i, client := range cc.clientsPerContext {
-		if !cc.isProducerListFull(i) {
-			clientResult = client
-			break
+	cc.clientsPerContext.Range(func(key, value any) bool {
+		if !cc.isProducerListFull(key.(int)) {
+			clientResult = value.(*Client)
+			return false
 		}
-	}
+		return true
+	})
 
 	clientProvidedName := "go-stream-producer"
 	if options != nil && options.ClientProvidedName != "" {
 		clientProvidedName = options.ClientProvidedName
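The selection loop above relies on Range's early-exit contract: returning false from the callback stops the iteration, standing in for the old break. Iteration order over a sync.Map is unspecified, so "first non-full client" means first visited, much as with the randomized ordering of the map range it replaces. A small sketch of the idiom under those assumptions (the findFirst helper is ours, not part of this PR):

package sketch

import "sync"

// findFirst returns the first visited value whose key satisfies pred;
// returning false from the Range callback stops the iteration early.
func findFirst(m *sync.Map, pred func(key int) bool) (any, bool) {
	var found any
	m.Range(func(key, value any) bool {
		if pred(key.(int)) {
			found = value
			return false // stop: we have a match
		}
		return true // keep scanning
	})
	return found, found != nil
}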
@@ -593,7 +603,8 @@ func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCP
 func (cc *environmentCoordinator) newClientForProducer(connectionName string, leader *Broker, tcpParameters *TCPParameters, saslConfiguration *SaslConfiguration, rpcTimeOut time.Duration) *Client {
 	clientResult := newClient(connectionName, leader, tcpParameters, saslConfiguration, rpcTimeOut)
 	cc.nextId++
-	cc.clientsPerContext[cc.nextId] = clientResult
+	cc.clientsPerContext.Store(cc.nextId, clientResult)
 	return clientResult
 }

@@ -602,20 +613,20 @@ func (cc *environmentCoordinator) newConsumer(connectionName string, leader *Bro
 	options *ConsumerOptions, rpcTimeout time.Duration) (*Consumer, error) {
 	cc.mutex.Lock()
 	defer cc.mutex.Unlock()
-	cc.mutexContext.Lock()
-	defer cc.mutexContext.Unlock()
 	var clientResult *Client
-	for i, client := range cc.clientsPerContext {
-		if !cc.isConsumerListFull(i) {
-			clientResult = client
-			break
+	cc.clientsPerContext.Range(func(key, value any) bool {
+		if !cc.isConsumerListFull(key.(int)) {
+			clientResult = value.(*Client)
+			return false
 		}
-	}
+		return true
+	})
 
 	if clientResult == nil {
 		clientResult = newClient(connectionName, leader, tcpParameters, saslConfiguration, rpcTimeout)
 		cc.nextId++
-		cc.clientsPerContext[cc.nextId] = clientResult
+		cc.clientsPerContext.Store(cc.nextId, clientResult)
 	}
 	// try to reconnect in case the socket is closed
 	err := clientResult.connect()
@@ -632,23 +643,28 @@ func (cc *environmentCoordinator) newConsumer(connectionName string, leader *Bro
 }
 
 func (cc *environmentCoordinator) Close() error {
-	cc.mutexContext.Lock()
-	defer cc.mutexContext.Unlock()
-	for _, client := range cc.clientsPerContext {
+	cc.clientsPerContext.Range(func(key, value any) bool {
+		client := value.(*Client)
 		for i := range client.coordinator.producers {
 			_ = client.coordinator.producers[i].(*Producer).Close()
 		}
 		for i := range client.coordinator.consumers {
 			_ = client.coordinator.consumers[i].(*Consumer).Close()
 		}
-	}
+		return true
+	})
+
 	return nil
 }
 
 func (cc *environmentCoordinator) getClientsPerContext() map[int]*Client {
-	cc.mutexContext.Lock()
-	defer cc.mutexContext.Unlock()
-	return cc.clientsPerContext
+	clients := map[int]*Client{}
+	cc.clientsPerContext.Range(func(key, value any) bool {
+		clients[key.(int)] = value.(*Client)
+		return true
+	})
+	return clients
 }
 
 type producersEnvironment struct {
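getClientsPerContext now hands back a point-in-time copy assembled with Range instead of the live map: callers may index and iterate it freely without synchronization, but Store and Delete calls made after the copy is taken are not reflected in it. The same snapshot idiom as a generic helper, assuming Go 1.18+ type parameters (the toMap name is ours):

package sketch

import "sync"

// toMap snapshots a sync.Map into an ordinary map. The copy is safe for the
// caller to use without locks, but it goes stale as soon as the source mutates.
func toMap[K comparable, V any](m *sync.Map) map[K]V {
	out := map[K]V{}
	m.Range(func(key, value any) bool {
		out[key.(K)] = value.(V)
		return true
	})
	return out
}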
@@ -677,10 +693,9 @@ func (ps *producersEnvironment) newProducer(clientLocator *Client, streamName st
 	coordinatorKey := leader.hostPort()
 	if ps.producersCoordinator[coordinatorKey] == nil {
 		ps.producersCoordinator[coordinatorKey] = &environmentCoordinator{
-			clientsPerContext: map[int]*Client{},
+			clientsPerContext: sync.Map{},
 			mutex:             &sync.Mutex{},
 			maxItemsForClient: ps.maxItemsForClient,
-			mutexContext:      &sync.RWMutex{},
 			nextId:            0,
 		}
 	}
@@ -742,10 +757,9 @@ func (ps *consumersEnvironment) NewSubscriber(clientLocator *Client, streamName
 	coordinatorKey := consumerBroker.hostPort()
 	if ps.consumersCoordinator[coordinatorKey] == nil {
 		ps.consumersCoordinator[coordinatorKey] = &environmentCoordinator{
-			clientsPerContext: map[int]*Client{},
+			clientsPerContext: sync.Map{},
 			mutex:             &sync.Mutex{},
 			maxItemsForClient: ps.maxItemsForClient,
-			mutexContext:      &sync.RWMutex{},
 			nextId:            0,
 		}
 	}
57 changes: 57 additions & 0 deletions pkg/stream/environment_test.go
@@ -2,6 +2,7 @@ package stream

 import (
 	"crypto/tls"
+	"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/amqp"
 	"sync"
 	"time"

@@ -419,4 +420,60 @@ var _ = Describe("Environment test", func() {
 		Expect(env.Close()).NotTo(HaveOccurred())
 	})
 
+	It("close env should close all the producers and consumers", func() {
+		env, err := NewEnvironment(NewEnvironmentOptions().
+			SetMaxProducersPerClient(2).
+			SetMaxConsumersPerClient(3))
+
+		Expect(err).NotTo(HaveOccurred())
+		streamName := uuid.New().String()
+		Expect(env.DeclareStream(streamName, nil)).NotTo(HaveOccurred())
+		for i := 0; i < 5; i++ {
+			_, err := env.NewProducer(streamName, nil)
+			Expect(err).NotTo(HaveOccurred())
+		}
+
+		for i := 0; i < 5; i++ {
+			_, err := env.NewConsumer(streamName, func(consumerContext ConsumerContext, message *amqp.Message) {
+			}, nil)
+			Expect(err).NotTo(HaveOccurred())
+		}
+
+		// count the entries in the sync.Map (it has no Len method)
+		count := 0
+		env.consumers.getCoordinators()["localhost:5552"].clientsPerContext.Range(func(key, value any) bool {
+			Expect(value).NotTo(BeNil())
+			count++
+			return true
+		})
+
+		Expect(count).To(Equal(2))
+
+		Expect(env.Close()).NotTo(HaveOccurred())
+
+		// after Close, both coordinators should eventually drain to zero entries
+		Eventually(func() int {
+			count = 0
+			env.producers.getCoordinators()["localhost:5552"].clientsPerContext.Range(func(_, _ any) bool {
+				count++
+				return true
+			})
+			return count
+		}, "5s", "1s").Should(Equal(0))
+
+		Eventually(func() int {
+			count = 0
+			env.consumers.getCoordinators()["localhost:5552"].clientsPerContext.Range(func(_, _ any) bool {
+				count++
+				return true
+			})
+			return count
+		}, "5s", "1s").Should(Equal(0))
+	})
+
 })
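sync.Map intentionally offers no Len method, since any count could be stale the moment it returns; that is why the test above counts entries by ranging over the map. The same idea as a reusable helper (the syncMapLen name is ours, not part of the test suite):

package sketch

import "sync"

// syncMapLen counts entries by visiting each one, since sync.Map has no Len.
func syncMapLen(m *sync.Map) int {
	n := 0
	m.Range(func(_, _ any) bool {
		n++
		return true
	})
	return n
}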