Skip to content

Commit 1987b54

Browse files
committed
Make startEventSources run only once using sync.Once, and add a unit test.
1 parent 57acc77 commit 1987b54

File tree

2 files changed

+91
-62
lines changed

2 files changed

+91
-62
lines changed

pkg/internal/controller/controller.go

Lines changed: 61 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ type Controller[request comparable] struct {
8989
// startWatches maintains a list of sources, handlers, and predicates to start when the controller is started.
9090
startWatches []source.TypedSource[request]
9191

92-
// didStartEventSources is used to indicate whether the event sources have been started.
93-
didStartEventSources atomic.Bool
92+
// didStartEventSourcesOnce ensures that the event sources are started at most once; later calls to startEventSources wait for the first call to finish and then return nil.
93+
didStartEventSourcesOnce sync.Once
9494

9595
// didEventSourcesFinishSyncSuccessfully is used to indicate whether the event sources have finished
9696
// successfully. It stores a *bool where
@@ -289,70 +289,69 @@ func (c *Controller[request]) Start(ctx context.Context) error {
289289
// startEventSources launches all the sources registered with this controller and waits
290290
// for them to sync. It returns an error if any of the sources fail to start or sync.
291291
func (c *Controller[request]) startEventSources(ctx context.Context) error {
292-
// CAS returns false if value is already true, so early exit since another goroutine must have
293-
// called startEventSources previously
294-
if !c.didStartEventSources.CompareAndSwap(false, true) {
295-
c.LogConstructor(nil).Info("Skipping starting event sources since they were previously started")
296-
return nil
297-
}
298-
299-
errGroup := &errgroup.Group{}
300-
for _, watch := range c.startWatches {
301-
log := c.LogConstructor(nil)
302-
_, ok := watch.(interface {
303-
String() string
304-
})
305-
if !ok {
306-
log = log.WithValues("source", fmt.Sprintf("%T", watch))
307-
} else {
308-
log = log.WithValues("source", fmt.Sprintf("%s", watch))
309-
}
310-
didStartSyncingSource := &atomic.Bool{}
311-
errGroup.Go(func() error {
312-
// Use a timeout for starting and syncing the source to avoid silently
313-
// blocking startup indefinitely if it doesn't come up.
314-
sourceStartCtx, cancel := context.WithTimeout(ctx, c.CacheSyncTimeout)
315-
defer cancel()
316-
317-
sourceStartErrChan := make(chan error, 1) // Buffer chan to not leak goroutine if we time out
318-
go func() {
319-
defer close(sourceStartErrChan)
320-
log.Info("Starting EventSource")
321-
if err := watch.Start(ctx, c.Queue); err != nil {
322-
sourceStartErrChan <- err
323-
return
324-
}
325-
syncingSource, ok := watch.(source.TypedSyncingSource[request])
326-
if !ok {
327-
return
328-
}
329-
didStartSyncingSource.Store(true)
330-
if err := syncingSource.WaitForSync(sourceStartCtx); err != nil {
331-
err := fmt.Errorf("failed to wait for %s caches to sync %v: %w", c.Name, syncingSource, err)
332-
log.Error(err, "Could not wait for Cache to sync")
333-
sourceStartErrChan <- err
292+
var retErr error
293+
294+
c.didStartEventSourcesOnce.Do(func() {
295+
errGroup := &errgroup.Group{}
296+
for _, watch := range c.startWatches {
297+
log := c.LogConstructor(nil)
298+
_, ok := watch.(interface {
299+
String() string
300+
})
301+
if !ok {
302+
log = log.WithValues("source", fmt.Sprintf("%T", watch))
303+
} else {
304+
log = log.WithValues("source", fmt.Sprintf("%s", watch))
305+
}
306+
didStartSyncingSource := &atomic.Bool{}
307+
errGroup.Go(func() error {
308+
// Use a timeout for starting and syncing the source to avoid silently
309+
// blocking startup indefinitely if it doesn't come up.
310+
sourceStartCtx, cancel := context.WithTimeout(ctx, c.CacheSyncTimeout)
311+
defer cancel()
312+
313+
sourceStartErrChan := make(chan error, 1) // Buffer chan to not leak goroutine if we time out
314+
go func() {
315+
defer close(sourceStartErrChan)
316+
log.Info("Starting EventSource")
317+
if err := watch.Start(ctx, c.Queue); err != nil {
318+
sourceStartErrChan <- err
319+
return
320+
}
321+
syncingSource, ok := watch.(source.TypedSyncingSource[request])
322+
if !ok {
323+
return
324+
}
325+
didStartSyncingSource.Store(true)
326+
if err := syncingSource.WaitForSync(sourceStartCtx); err != nil {
327+
err := fmt.Errorf("failed to wait for %s caches to sync %v: %w", c.Name, syncingSource, err)
328+
log.Error(err, "Could not wait for Cache to sync")
329+
sourceStartErrChan <- err
330+
}
331+
}()
332+
333+
select {
334+
case err := <-sourceStartErrChan:
335+
return err
336+
case <-sourceStartCtx.Done():
337+
if didStartSyncingSource.Load() { // We are racing with WaitForSync, wait for it to let it tell us what happened
338+
return <-sourceStartErrChan
339+
}
340+
if ctx.Err() != nil { // Don't return an error if the root context got cancelled
341+
return nil
342+
}
343+
return fmt.Errorf("timed out waiting for source %s to Start. Please ensure that its Start() method is non-blocking", watch)
334344
}
335-
}()
345+
})
346+
}
347+
err := errGroup.Wait()
336348

337-
select {
338-
case err := <-sourceStartErrChan:
339-
return err
340-
case <-sourceStartCtx.Done():
341-
if didStartSyncingSource.Load() { // We are racing with WaitForSync, wait for it to let it tell us what happened
342-
return <-sourceStartErrChan
343-
}
344-
if ctx.Err() != nil { // Don't return an error if the root context got cancelled
345-
return nil
346-
}
347-
return fmt.Errorf("timed out waiting for source %s to Start. Please ensure that its Start() method is non-blocking", watch)
348-
}
349-
})
350-
}
351-
err := errGroup.Wait()
349+
c.didEventSourcesFinishSyncSuccessfully.Store(ptr.To(err == nil))
352350

353-
c.didEventSourcesFinishSyncSuccessfully.Store(ptr.To(err == nil))
351+
retErr = err
352+
})
354353

355-
return err
354+
return retErr
356355
}
357356

358357
// processNextWorkItem will read a single work item off the workqueue and

pkg/internal/controller/controller_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,36 @@ var _ = Describe("controller", func() {
502502
Expect(err).To(HaveOccurred())
503503
Expect(err.Error()).To(ContainSubstring("timed out waiting for source"))
504504
})
505+
506+
It("should only start sources once when called multiple times", func() {
507+
ctx, cancel := context.WithCancel(context.Background())
508+
defer cancel()
509+
510+
ctrl.CacheSyncTimeout = 1 * time.Millisecond
511+
512+
var startCount atomic.Int32
513+
src := source.Func(func(ctx context.Context, _ workqueue.TypedRateLimitingInterface[reconcile.Request]) error {
514+
startCount.Add(1)
515+
return nil
516+
})
517+
518+
ctrl.startWatches = []source.TypedSource[reconcile.Request]{src}
519+
520+
By("Calling startEventSources multiple times in parallel")
521+
var wg sync.WaitGroup
522+
for i := 1; i <= 5; i++ {
523+
wg.Add(1)
524+
go func() {
525+
defer wg.Done()
526+
err := ctrl.startEventSources(ctx)
527+
// Every call should return nil: the first call starts the source successfully and the remaining calls are no-ops.
528+
Expect(err).NotTo(HaveOccurred())
529+
}()
530+
}
531+
532+
wg.Wait()
533+
Expect(startCount.Load()).To(Equal(int32(1)), "Source should only be started once even when called multiple times")
534+
})
505535
})
506536

507537
Describe("Processing queue items from a Controller", func() {

0 commit comments

Comments
 (0)