diff --git a/pkg/raw/client_types.go b/pkg/raw/client_types.go
index 0e154593..0c8acefb 100644
--- a/pkg/raw/client_types.go
+++ b/pkg/raw/client_types.go
@@ -74,6 +74,12 @@ func (r *ClientConfiguration) RabbitmqBrokers() broker {
 	return r.rabbitmqBroker
 }
 
+// ClientHeartbeat returns the configured client heartbeat. The stream locator
+// converts it to a duration as a number of seconds.
+func (r *ClientConfiguration) ClientHeartbeat() uint32 {
+	return r.clientHeartbeat
+}
+
 func (r *ClientConfiguration) SetClientMaxFrameSize(clientMaxFrameSize uint32) {
 	r.clientMaxFrameSize = clientMaxFrameSize
 }
diff --git a/pkg/stream/heartbeater.go b/pkg/stream/heartbeater.go
new file mode 100644
index 00000000..69171dd1
--- /dev/null
+++ b/pkg/stream/heartbeater.go
@@ -0,0 +1,157 @@
+package stream
+
+import (
+	"sync"
+	"time"
+
+	"github.com/rabbitmq/rabbitmq-stream-go-client/v2/pkg/raw"
+	"golang.org/x/exp/slog"
+)
+
+// heartBeater sends heartbeats through a raw client: one every tickDuration
+// and one in reply to each heartbeat notified by the client. It can be started
+// again after stop, which the locator does when it reconnects.
+type heartBeater struct {
+	logger       *slog.Logger
+	client       raw.Clienter
+	tickDuration time.Duration
+
+	// mu guards ticker, done and receiveCh. The locator calls reset from its
+	// operations and stop/start from its shutdown handler, which may overlap.
+	mu        sync.Mutex
+	ticker    *time.Ticker
+	done      *DoneChan
+	receiveCh <-chan *raw.Heartbeat
+}
+
+// DoneChan is a signalling channel that tolerates being closed more than once.
+// Receivers wait on C; closers must call GracefulClose rather than close(C).
+type DoneChan struct {
+	C      chan struct{}
+	closed bool
+	mutex  sync.Mutex
+}
+
+// NewDoneChan returns a DoneChan whose channel C is open.
+func NewDoneChan() *DoneChan {
+	return &DoneChan{C: make(chan struct{})}
+}
+
+// GracefulClose closes the DoneChan only if the Done chan is not already closed.
+func (dc *DoneChan) GracefulClose() {
+	dc.mutex.Lock()
+	defer dc.mutex.Unlock()
+
+	if !dc.closed {
+		close(dc.C)
+		dc.closed = true
+	}
+}
+
+// NewHeartBeater returns a heartBeater that is not running yet; call start to
+// run it. A non-positive duration disables the periodic heartbeat, in which
+// case only heartbeats notified by the client are answered. A nil logger is
+// replaced by slog.Default().
+func NewHeartBeater(duration time.Duration, client raw.Clienter, logger *slog.Logger) *heartBeater {
+	if logger == nil {
+		logger = slog.Default()
+	}
+	return &heartBeater{
+		logger:       logger,
+		client:       client,
+		tickDuration: duration,
+		done:         NewDoneChan(),
+	}
+}
+
+// start launches the goroutine that sends heartbeats. Any previous run is
+// stopped first, so start may be called again after stop or after a reconnect.
+func (hb *heartBeater) start() {
+	hb.mu.Lock()
+	defer hb.mu.Unlock()
+
+	// End a previous run, if any, so that its goroutine and ticker do not leak.
+	hb.stopLocked()
+
+	// time.NewTicker panics on a non-positive duration. Leaving tick nil makes
+	// its select case block forever, which disables the periodic heartbeat.
+	var tick <-chan time.Time
+	hb.ticker = nil
+	if hb.tickDuration > 0 {
+		hb.ticker = time.NewTicker(hb.tickDuration)
+		tick = hb.ticker.C
+	}
+	// stop closes done for good, so every run needs its own DoneChan.
+	hb.done = NewDoneChan()
+	hb.receiveCh = hb.client.NotifyHeartbeat()
+
+	go hb.run(hb.done.C, tick, hb.receiveCh)
+}
+
+// run sends a heartbeat on every tick and on every heartbeat received on recv,
+// until done is closed. It works on the channels of one run only, so a later
+// start cannot swap them underneath it.
+func (hb *heartBeater) run(done <-chan struct{}, tick <-chan time.Time, recv <-chan *raw.Heartbeat) {
+	for {
+		select {
+		case <-done:
+			return
+		case <-tick:
+			hb.send()
+		case _, ok := <-recv:
+			if !ok {
+				// A closed channel is always ready to receive: drop it from
+				// the select instead of sending heartbeats in a busy loop.
+				recv = nil
+				continue
+			}
+			hb.send()
+		}
+	}
+}
+
+// reset restarts the tick interval, postponing the next periodic heartbeat.
+// It does nothing on a nil heartBeater, before start, or when the periodic
+// heartbeat is disabled.
+func (hb *heartBeater) reset() {
+	// This nil check is mainly for tests.
+	if hb == nil {
+		return
+	}
+	hb.mu.Lock()
+	defer hb.mu.Unlock()
+
+	// Ticker.Reset panics on a non-positive duration.
+	if hb.ticker == nil || hb.tickDuration <= 0 {
+		return
+	}
+	hb.ticker.Reset(hb.tickDuration)
+}
+
+// stop ends the current run. It is safe to call before start and more than once.
+func (hb *heartBeater) stop() {
+	if hb == nil {
+		return
+	}
+	hb.mu.Lock()
+	defer hb.mu.Unlock()
+	hb.stopLocked()
+}
+
+// stopLocked stops the ticker and closes done. The caller must hold hb.mu.
+func (hb *heartBeater) stopLocked() {
+	if hb.ticker != nil {
+		hb.ticker.Stop()
+	}
+	if hb.done != nil {
+		hb.done.GracefulClose()
+	}
+}
+
+// send sends one heartbeat and logs the error, if any; it never returns it.
+func (hb *heartBeater) send() {
+	err := hb.client.SendHeartbeat()
+	if err != nil {
+		hb.logger.Error("error sending heartbeat", "error", err)
+	}
+}
diff --git a/pkg/stream/heartbeater_test.go b/pkg/stream/heartbeater_test.go
new file mode 100644
index 00000000..e22c04a2
--- /dev/null
+++ b/pkg/stream/heartbeater_test.go
@@ -0,0 +1,118 @@
+package stream
+
+import (
+	"time"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+	"github.com/rabbitmq/rabbitmq-stream-go-client/v2/pkg/raw"
+	"go.uber.org/mock/gomock"
+)
+
+var _ = Describe("Heartbeater", func() {
+
+	var (
+		hb            *heartBeater
+		mockCtrl      *gomock.Controller
+		mockRawClient *MockRawClient
+	)
+
+	BeforeEach(func() {
+		mockCtrl = gomock.NewController(GinkgoT())
+		mockRawClient = NewMockRawClient(mockCtrl)
+		hb = NewHeartBeater(time.Millisecond*100, mockRawClient, nil)
+	})
+
+	AfterEach(func() {
+		// stop the heartbeater so its goroutine cannot call the mock after the spec has finished
+		hb.stop()
+	})
+
+	It("can configure the tick duration", func() {
+		Expect(hb.tickDuration).To(BeNumerically("==", 100000000))
+	})
+
+	It("when the tick duration expires, it sends a heartbeat", func() {
+		// setup the mock
+		mockRawClient.EXPECT().SendHeartbeat()
+		mockRawClient.EXPECT().NotifyHeartbeat().
+			DoAndReturn(func() <-chan *raw.Heartbeat {
+				return make(chan *raw.Heartbeat)
+			})
+		hb.start()
+		// wait until the mock gets called
+		// the mock will fail the test if SendHeartbeat is not called
+		<-time.After(time.Millisecond * 150)
+	})
+
+	It("sends a heartbeat when it receives one from the server", func(ctx SpecContext) {
+		var receiveCh chan *raw.Heartbeat
+		mockRawClient.EXPECT().NotifyHeartbeat().
+			DoAndReturn(func() <-chan *raw.Heartbeat {
+				receiveCh = make(chan *raw.Heartbeat)
+				return receiveCh
+			})
+		mockRawClient.EXPECT().SendHeartbeat()
+
+		hb.start()
+
+		select {
+		case <-ctx.Done():
+			Fail("failed in setup: did not send a heartbeat notification")
+		case receiveCh <- &raw.Heartbeat{}:
+		}
+
+		// wait until the mock gets called
+		// the mock will fail the test if SendHeartbeat is not called
+		<-time.After(time.Millisecond * 50)
+		hb.stop()
+	}, SpecTimeout(100*time.Millisecond))
+
+	It("stops the heartbeater gracefully", func() {
+		// TODO use the gleak detector
+		mockRawClient.EXPECT().NotifyHeartbeat().
+			DoAndReturn(func() <-chan *raw.Heartbeat {
+				return make(chan *raw.Heartbeat)
+			})
+
+		hb.tickDuration = time.Minute // we do not want to receive a tick
+		hb.start()
+		Expect(hb.done.C).ToNot(BeClosed())
+		Expect(hb.ticker.C).ToNot(BeClosed())
+
+		hb.stop()
+		Expect(hb.done.C).To(BeClosed())
+		Consistently(hb.ticker.C, "100ms").ShouldNot(Receive())
+
+		By("not panicking on subsequent close")
+		hb.stop()
+		// TODO investigate using gleak and asserts that heartbeater go routine have not leaked
+		// tried this before, but could not make the test go red, even after leaking the heartbeater routine
+	})
+
+	It("can be stopped before it is started", func() {
+		hb.stop()
+	})
+
+	It("does not panic when the tick duration is zero", func() {
+		mockRawClient.EXPECT().NotifyHeartbeat()
+
+		hb.tickDuration = 0
+		hb.start()
+		hb.reset()
+	})
+
+	It("sends heartbeats again after a restart", func() {
+		mockRawClient.EXPECT().NotifyHeartbeat().Times(2)
+		mockRawClient.EXPECT().SendHeartbeat().MinTimes(1)
+
+		hb.tickDuration = time.Millisecond * 20
+		hb.start()
+		hb.stop()
+
+		hb.start()
+		Expect(hb.done.C).ToNot(BeClosed())
+		// wait until the mock gets called
+		<-time.After(time.Millisecond * 50)
+	})
+})
diff --git a/pkg/stream/locator.go b/pkg/stream/locator.go
index 8ed838ac..b6a06ec9 100644
--- a/pkg/stream/locator.go
+++ b/pkg/stream/locator.go
@@ -2,12 +2,13 @@ package stream
 
 import (
 	"context"
-	"github.com/rabbitmq/rabbitmq-stream-go-client/v2/pkg/raw"
-	"golang.org/x/exp/slog"
-	"golang.org/x/mod/semver"
 	"net"
 	"sync"
 	"time"
+
+	"github.com/rabbitmq/rabbitmq-stream-go-client/v2/pkg/raw"
+	"golang.org/x/exp/slog"
+	"golang.org/x/mod/semver"
 )
 
 const (
@@ -24,11 +25,11 @@ type locator struct {
 	clientClose          <-chan error
 	backOffPolicy        func(int) time.Duration
 	addressResolver      net.Addr // TODO: placeholder for address resolver
-
+	heartbeater          *heartBeater
 }
 
 func newLocator(c raw.ClientConfiguration, logger *slog.Logger) *locator {
-	return &locator{
+	loc := &locator{
 		log: logger.
 			WithGroup("locator").
 			With(
@@ -43,7 +44,10 @@ func newLocator(c raw.ClientConfiguration, logger *slog.Logger) *locator {
 		isSet:                false,
 		addressResolver:      nil,
 		shutdownNotification: make(chan struct{}),
+		heartbeater:          NewHeartBeater(time.Second*time.Duration(c.ClientHeartbeat()), nil, logger),
 	}
+
+	return loc
 }
 
 func (l *locator) connect(ctx context.Context) error {
@@ -64,6 +68,13 @@ func (l *locator) connect(ctx context.Context) error {
-		return l.client.ExchangeCommandVersions(ctx)
+		if err := l.client.ExchangeCommandVersions(ctx); err != nil {
+			return err
+		}
 	}
 
+	// Reached on every successful connection, including the branch above:
+	// returning from it directly would leave the heartbeater never started.
+	l.heartbeater.client = client
+	l.heartbeater.start()
+
 	return nil
 }
 
@@ -95,6 +106,7 @@ func (l *locator) shutdownHandler() {
 			// TODO: maybe add a 'ok' safeguard here?
 			log.Debug("unexpected locator disconnection, trying to reconnect", slog.Any("error", err))
 			l.Lock()
+			l.heartbeater.stop()
 			for i := 0; i < 100; i++ {
 				dialCtx, cancel := context.WithTimeout(raw.NewContextWithLogger(context.Background(), *log), DefaultTimeout)
 				c, e := raw.DialConfig(dialCtx, &l.rawClientConf)
@@ -116,6 +128,8 @@ func (l *locator) shutdownHandler() {
 
 				l.client = c
 				l.clientClose = c.NotifyConnectionClosed()
+				l.heartbeater.client = c
+				l.heartbeater.start()
 
 				log.Debug("locator reconnected")
 
@@ -171,6 +185,8 @@ func (l *locator) locatorOperation(op locatorOperationFn, args ...any) (result [
 		l.log.Debug("error in locator operation", slog.Any("error", lastErr), slog.Int("attempt", attempt))
 		attempt++
 	}
+	// the locator just used the connection: restart the heartbeat interval
+	l.heartbeater.reset()
 	return result
 }
 
@@ -204,6 +220,7 @@ func (l *locator) operationQueryOffset(args ...any) []any {
 	offset, err := l.client.QueryOffset(ctx, reference, stream)
 	return []any{offset, err}
 }
+
 func (l *locator) operationPartitions(args ...any) []any {
 	ctx := args[0].(context.Context)
 	superstream := args[1].(string)
diff --git a/pkg/stream/locator_test.go b/pkg/stream/locator_test.go
index c269b021..03ce30ca 100644
--- a/pkg/stream/locator_test.go
+++ b/pkg/stream/locator_test.go
@@ -3,13 +3,14 @@ package stream
 import (
 	"context"
 	"errors"
 	"time"
 
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 	"github.com/onsi/gomega/gbytes"
 	"github.com/rabbitmq/rabbitmq-stream-go-client/v2/pkg/raw"
+	"go.uber.org/mock/gomock"
 	"golang.org/x/exp/slog"
 )
 
 var _ = Describe("Locator", func() {
@@ -119,6 +120,57 @@ var _ = Describe("Locator", func() {
 				Eventually(logBuffer).Within(time.Second).Should(gbytes.Say("context deadline exceeded"))
 			})
 		})
+
+	})
+
+	Describe("heartbeats", func() {
+		var (
+			mockCtrl      *gomock.Controller
+			mockRawClient *MockRawClient
+			discardLogger = slog.New(discardHandler{})
+			backOffPolicy = func(_ int) time.Duration {
+				return time.Millisecond * 10
+			}
+		)
+
+		BeforeEach(func() {
+			mockCtrl = gomock.NewController(GinkgoT())
+			mockRawClient = NewMockRawClient(mockCtrl)
+		})
+
+		It("resets heartbeat ticker on locator operation", func() {
+			mockRawClient.EXPECT().NotifyHeartbeat()
+
+			hb := NewHeartBeater(time.Second, mockRawClient, discardLogger)
+			loc := &locator{
+				log:                  discardLogger,
+				shutdownNotification: make(chan struct{}),
+				rawClientConf:        raw.ClientConfiguration{},
+				client:               nil,
+				isSet:                true,
+				clientClose:          nil,
+				backOffPolicy:        backOffPolicy,
+				heartbeater:          hb,
+			}
+
+			done := make(chan struct{})
+			hb.start()
+
+			go func() {
+				defer GinkgoRecover()
+				Consistently(loc.heartbeater.ticker.C, "1010ms").ShouldNot(Receive())
+				close(done)
+			}()
+
+			// do a locator op
+			loc.locatorOperation(func(_ *locator, _ ...any) (result []any) {
+				<-time.After(time.Millisecond * 50)
+				return []any{nil}
+			})
+			<-done
+			// stop the heartbeater: the next tick would call the mock after the spec has finished
+			hb.stop()
+		})
 	})
 
 	Describe("Utils", func() {
diff --git a/pkg/stream/mock_raw_client_test.go b/pkg/stream/mock_raw_client_test.go
index d1068fd2..212594eb 100644
--- a/pkg/stream/mock_raw_client_test.go
+++ b/pkg/stream/mock_raw_client_test.go
@@ -6,6 +6,7 @@ package stream
 
 import (
 	context "context"
+	"github.com/onsi/ginkgo/v2"
 	reflect "reflect"
 
 	common "github.com/rabbitmq/rabbitmq-stream-go-client/v2/pkg/common"
@@ -379,6 +380,7 @@ func (mr *MockRawClientMockRecorder) Send(ctx, publisherId, messages interface{}
 
 // SendHeartbeat mocks base method.
 func (m *MockRawClient) SendHeartbeat() error {
+	defer ginkgo.GinkgoRecover()
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "SendHeartbeat")
 	ret0, _ := ret[0].(error)