From cb6361b8c7eb733a1b7848ffd77fd319d71d7645 Mon Sep 17 00:00:00 2001
From: Dan Bason
Date: Mon, 28 Jul 2025 15:40:23 +1200
Subject: [PATCH] Watch the tls certs for changes update the served certs

Signed-off-by: Dan Bason
---
 go.mod                                        |   2 +-
 internal/server/certloader/certloader.go      |  95 +++++++++++++++
 internal/server/certwatcher/certwatcher.go    | 114 ++++++++++++++++++
 .../server/certwatcher/certwatcher_test.go    |  91 ++++++++++++++
 .../server/certwatcher/directorywatcher.go    |  70 +++++++++++
 .../certwatcher/directorywatcher_test.go      |  69 +++++++++++
 pkg/server/server.go                          |  18 ++-
 7 files changed, 456 insertions(+), 3 deletions(-)
 create mode 100644 internal/server/certloader/certloader.go
 create mode 100644 internal/server/certwatcher/certwatcher.go
 create mode 100644 internal/server/certwatcher/certwatcher_test.go
 create mode 100644 internal/server/certwatcher/directorywatcher.go
 create mode 100644 internal/server/certwatcher/directorywatcher_test.go

diff --git a/go.mod b/go.mod
index 33349d79a4..9641bc5d38 100644
--- a/go.mod
+++ b/go.mod
@@ -26,6 +26,7 @@ require (
 	github.com/expr-lang/expr v1.17.6
 	github.com/fatih/structtag v1.2.0
 	github.com/fluxcd/pkg/kustomize v1.22.0
+	github.com/fsnotify/fsnotify v1.9.0
 	github.com/go-git/go-git/v5 v5.16.2
 	github.com/go-logr/logr v1.4.3
 	github.com/go-logr/zapr v1.3.0
@@ -126,7 +127,6 @@ require (
 	github.com/exponent-io/jsonpath v0.0.0-20210407135951-1de76d718b3f // indirect
 	github.com/fatih/color v1.16.0 // indirect
 	github.com/felixge/httpsnoop v1.0.4 // indirect
-	github.com/fsnotify/fsnotify v1.9.0 // indirect
 	github.com/fxamacker/cbor/v2 v2.9.0 // indirect
 	github.com/go-errors/errors v1.5.1 // indirect
 	github.com/go-fed/httpsig v1.1.0 // indirect
diff --git a/internal/server/certloader/certloader.go b/internal/server/certloader/certloader.go
new file mode 100644
index 0000000000..db4ee9db62
--- /dev/null
+++ b/internal/server/certloader/certloader.go
@@ -0,0 +1,95 @@
+package certloader
+
+import (
+	"crypto/tls"
+	"sync"
+
+	"github.com/akuity/kargo/internal/server/certwatcher"
+	"github.com/akuity/kargo/pkg/logging"
+)
+
+// CertLoader holds the most recently loaded certificate and key pair and
+// replaces it whenever its CertWatcher signals a change.
+type CertLoader struct {
+	logger            *logging.Logger
+	certPath, keyPath string
+	done              chan struct{}
+	certWatcher       *certwatcher.CertWatcher
+
+	cert     *tls.Certificate
+	certLock sync.RWMutex
+}
+
+// NewCertLoader creates a watcher for certPath and keyPath, loads the pair
+// once, and starts a goroutine that reloads it on every change signal. It
+// returns an error if the watcher cannot be created or the pair cannot be
+// loaded. Call Close to stop the goroutine and the watcher.
+func NewCertLoader(logger *logging.Logger, certPath, keyPath string) (*CertLoader, error) {
+	certWatcher, err := certwatcher.NewCertWatcher(certPath, keyPath)
+	if err != nil {
+		return nil, err
+	}
+
+	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
+	if err != nil {
+		// Do not leak the fsnotify watchers created above.
+		certWatcher.Close()
+		return nil, err
+	}
+
+	// certPath and keyPath are stored so that run can reload the pair, and
+	// done must be a real channel because Close closes it (closing a nil
+	// channel panics).
+	c := &CertLoader{
+		logger:      logger,
+		certPath:    certPath,
+		keyPath:     keyPath,
+		done:        make(chan struct{}),
+		certWatcher: certWatcher,
+		cert:        &cert,
+	}
+
+	go c.run()
+
+	return c, nil
+}
+
+// GetCertificate returns the currently loaded certificate. Its signature
+// matches the GetCertificate field of tls.Config.
+func (c *CertLoader) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
+	c.certLock.RLock()
+	defer c.certLock.RUnlock()
+	return c.cert, nil
+}
+
+// Close stops the underlying CertWatcher and the reload goroutine. It must
+// be called at most once.
+func (c *CertLoader) Close() {
+	c.certWatcher.Close()
+	close(c.done)
+}
+
+// run starts the CertWatcher and reloads the certificate and key pair each
+// time a change is signalled. It returns when Close is called or when the
+// watcher's event channel is closed.
+func (c *CertLoader) run() {
+	go c.certWatcher.Run()
+	for {
+		select {
+		case <-c.done:
+			return
+		case _, ok := <-c.certWatcher.Events():
+			if !ok {
+				return
+			}
+			cert, err := tls.LoadX509KeyPair(c.certPath, c.keyPath)
+			if err != nil {
+				c.logger.Error(err, "failed to load certificate and key pair, keeping existing certificate")
+				continue
+			}
+			c.certLock.Lock()
+			c.cert = &cert
+			c.certLock.Unlock()
+		}
+	}
+}
diff --git a/internal/server/certwatcher/certwatcher.go b/internal/server/certwatcher/certwatcher.go
new file mode 100644
index 0000000000..b4216b3859
--- /dev/null
+++ b/internal/server/certwatcher/certwatcher.go
@@ -0,0 +1,114 @@
+package certwatcher
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"sync"
+
+	"github.com/fsnotify/fsnotify"
+)
+
+// CertWatcher watches for any changes to a certificate and key pair.
+// It is used to reload the certificate and key pair when they are updated.
+// It is also used to reload the certificate and key pair when the file is
+// created.
+type CertWatcher struct {
+	directories map[string]*directoryWatcher
+	notify      chan struct{}
+}
+
+// NewCertWatcher returns a CertWatcher that watches the directories containing
+// certPath and keyPath. It fails if either path cannot be registered.
+func NewCertWatcher(certPath, keyPath string) (*CertWatcher, error) {
+	certWatcher := &CertWatcher{
+		directories: make(map[string]*directoryWatcher),
+		notify:      make(chan struct{}),
+	}
+	if err := certWatcher.addPath(certPath); err != nil {
+		return nil, err
+	}
+	if err := certWatcher.addPath(keyPath); err != nil {
+		// Release the watcher that was already opened for certPath.
+		certWatcher.Close()
+		return nil, err
+	}
+
+	return certWatcher, nil
+}
+
+// Events returns a channel that will be notified when the certificate or key
+// pair is updated.
+func (c *CertWatcher) Events() <-chan struct{} {
+	return c.notify
+}
+
+// Run starts the certwatcher and watches for changes to the certificate and
+// key pair. Run blocks until the certwatcher is closed.
+func (c *CertWatcher) Run() {
+	defer close(c.notify)
+	wg := sync.WaitGroup{}
+	for _, dirWatcher := range c.directories {
+		wg.Add(1)
+		go func(dirWatcher *directoryWatcher) {
+			defer wg.Done()
+			events := dirWatcher.watch()
+			for range events {
+				select {
+				case c.notify <- struct{}{}:
+				case <-dirWatcher.done:
+					return
+				}
+			}
+		}(dirWatcher)
+	}
+	wg.Wait()
+}
+
+// Close closes the certwatcher and stops watching for changes.
+func (c *CertWatcher) Close() {
+	for _, dirWatcher := range c.directories {
+		close(dirWatcher.done)
+		dirWatcher.watcher.Close()
+	}
+}
+
+// addPath registers path with the watcher of its parent directory, creating
+// that directory watcher first if none exists yet.
+func (c *CertWatcher) addPath(path string) error {
+	if _, err := os.Stat(path); err != nil {
+		return fmt.Errorf("failed to stat %q: %w", path, err)
+	}
+
+	absolutePath, err := filepath.Abs(path)
+	if err != nil {
+		return fmt.Errorf("failed to get absolute path for %q: %w", path, err)
+	}
+	// The error is ignored; fp is whatever EvalSymlinks returned.
+	fp, _ := filepath.EvalSymlinks(absolutePath)
+	fileDir := filepath.Dir(absolutePath)
+
+	dirWatcher, ok := c.directories[fileDir]
+	if ok {
+		dirWatcher.watchedFiles[absolutePath] = fp
+		return nil
+	}
+
+	watcher, err := fsnotify.NewWatcher()
+	if err != nil {
+		return fmt.Errorf("failed to create fsnotify watcher: %w", err)
+	}
+	if err = watcher.Add(fileDir); err != nil {
+		// Do not leak the watcher when the directory cannot be added.
+		_ = watcher.Close()
+		return fmt.Errorf("failed to add %q to fsnotify watcher: %w", fileDir, err)
+	}
+
+	dirWatcher = &directoryWatcher{
+		watcher:      watcher,
+		watchedFiles: map[string]string{absolutePath: fp},
+		done:         make(chan struct{}),
+	}
+	c.directories[fileDir] = dirWatcher
+	return nil
+}
diff --git a/internal/server/certwatcher/certwatcher_test.go b/internal/server/certwatcher/certwatcher_test.go
new file mode 100644
index 0000000000..2b1d5bdec3
--- /dev/null
+++ b/internal/server/certwatcher/certwatcher_test.go
@@ -0,0 +1,91 @@
+package certwatcher
+
+import (
+	"os"
+	"path/filepath"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestNewCertWatcher(t *testing.T) {
+	t.Run("success", func(t *testing.T) {
+		tempDir := t.TempDir()
+		certPath := filepath.Join(tempDir, "tls.crt")
+		keyPath := filepath.Join(tempDir, "tls.key")
+		require.NoError(t, os.WriteFile(certPath, []byte("cert"), 0600))
+		require.NoError(t, os.WriteFile(keyPath, []byte("key"), 0600))
+
+		cw, err := NewCertWatcher(certPath, keyPath)
+		require.NoError(t, err)
+		require.NotNil(t, cw)
+		require.Len(t, cw.directories, 1)
+	})
+
+	t.Run("cert path does not exist", func(t *testing.T) {
+		tempDir := t.TempDir()
+		certPath := filepath.Join(tempDir, "tls.crt")
+		keyPath := filepath.Join(tempDir, "tls.key")
+		require.NoError(t, os.WriteFile(keyPath, []byte("key"), 0600))
+
+		_, err := NewCertWatcher(certPath, keyPath)
+		require.Error(t, err)
+	})
+
+	t.Run("key path does not exist", func(t *testing.T) {
+		tempDir := t.TempDir()
+		certPath := filepath.Join(tempDir, "tls.crt")
+		keyPath := filepath.Join(tempDir, "tls.key")
+		require.NoError(t, os.WriteFile(certPath, []byte("cert"), 0600))
+
+		_, err := NewCertWatcher(certPath, keyPath)
+		require.Error(t, err)
+	})
+}
+
+func TestCertWatcher(t *testing.T) {
+	tempDir := t.TempDir()
+	certPath := filepath.Join(tempDir, "tls.crt")
+	keyPath := filepath.Join(tempDir, "tls.key")
+	require.NoError(t, os.WriteFile(certPath, []byte("cert"), 0600))
+	require.NoError(t, os.WriteFile(keyPath, []byte("key"), 0600))
+
+	cw, err := NewCertWatcher(certPath, keyPath)
+	require.NoError(t, err)
+	require.NotNil(t, cw)
+
+	wg := sync.WaitGroup{}
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		cw.Run()
+	}()
+
+	// Wait a bit for the watcher to start
+	time.Sleep(100 * time.Millisecond)
+
+	// Update the cert file
+	require.NoError(t, os.WriteFile(certPath, []byte("new cert"), 0600))
+
+	select {
+	case <-cw.Events():
+		// All good
+	case <-time.After(5 * time.Second):
+		t.Fatal("timed out waiting for event")
+	}
+
+	// Update the key file
+	require.NoError(t, os.WriteFile(keyPath, []byte("new key"), 0600))
+
+	select {
+	case <-cw.Events():
+		// All good
+	case <-time.After(5 * time.Second):
+		t.Fatal("timed out waiting for event")
+	}
+
+	cw.Close()
+	wg.Wait()
+}
diff --git a/internal/server/certwatcher/directorywatcher.go b/internal/server/certwatcher/directorywatcher.go
new file mode 100644
index 0000000000..ffbf586064
--- /dev/null
+++ b/internal/server/certwatcher/directorywatcher.go
@@ -0,0 +1,70 @@
+package certwatcher
+
+import (
+	"path/filepath"
+	"time"
+
+	"github.com/fsnotify/fsnotify"
+)
+
+type directoryWatcher struct {
+	watcher      *fsnotify.Watcher
+	watchedFiles map[string]string // absolute path -> last symlink-resolved path
+	done         chan struct{}     // closed to stop the watch goroutine
+}
+
+// watch signals on the returned channel for each relevant event until done closes.
+func (d *directoryWatcher) watch() <-chan struct{} {
+	events := make(chan struct{})
+	go func() {
+		defer close(events)
+		for {
+			select {
+			case <-d.done:
+				return
+			case event, ok := <-d.watcher.Events:
+				if !ok {
+					return
+				}
+				if d.shouldSendEvent(event) {
+					select {
+					case events <- struct{}{}:
+					case <-d.done:
+						return
+					}
+				}
+			}
+		}
+	}()
+	return events
+}
+
+// shouldSendEvent reports whether event means that one of the watched files
+// was created, modified, or re-pointed at a different symlink target.
+func (d *directoryWatcher) shouldSendEvent(event fsnotify.Event) bool {
+	sleepTime := 10 * time.Millisecond
+	eventPath, _ := filepath.Abs(event.Name)
+	eventPath, _ = filepath.EvalSymlinks(eventPath)
+
+	for abs, previous := range d.watchedFiles {
+		// Re-resolve on every event: unlike Abs, EvalSymlinks yields "" once
+		// the file is gone and a new path after a symlink swap.
+		currentWatchedPath, _ := filepath.EvalSymlinks(abs)
+		switch {
+		case currentWatchedPath == "":
+			// watched file was removed; wait for write event to trigger reload
+			d.watchedFiles[abs] = ""
+		case currentWatchedPath != previous,
+			eventPath == currentWatchedPath && isUpdatedFileEvent(event):
+			// File (re)appeared, moved target, or was modified; signal the caller
+			time.Sleep(sleepTime)
+			d.watchedFiles[abs] = currentWatchedPath
+			return true
+		}
+	}
+	return false
+}
+
+func isUpdatedFileEvent(event fsnotify.Event) bool {
+	return (event.Op&fsnotify.Write) == fsnotify.Write || (event.Op&fsnotify.Create) == fsnotify.Create
+}
diff --git a/internal/server/certwatcher/directorywatcher_test.go b/internal/server/certwatcher/directorywatcher_test.go
new file mode 100644
index 0000000000..5772d6c4be
--- /dev/null
+++ b/internal/server/certwatcher/directorywatcher_test.go
@@ -0,0 +1,69 @@
+package certwatcher
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/fsnotify/fsnotify"
+	"github.com/stretchr/testify/require"
+)
+
+func TestDirectoryWatcher(t *testing.T) {
+	tempDir := t.TempDir()
+	filePath := filepath.Join(tempDir, "file.txt")
+	require.NoError(t, os.WriteFile(filePath, []byte("hello"), 0600))
+
+	watcher, err := fsnotify.NewWatcher()
+	require.NoError(t, err)
+	require.NoError(t, watcher.Add(tempDir))
+
+	dw := &directoryWatcher{
+		watcher:      watcher,
+		watchedFiles: make(map[string]string),
+		done:         make(chan struct{}),
+	}
+	absPath, err := filepath.Abs(filePath)
+	require.NoError(t, err)
+	dw.watchedFiles[absPath], _ = filepath.EvalSymlinks(absPath)
+
+	events := dw.watch()
+
+	// Test file modification
+	require.NoError(t, os.WriteFile(filePath, []byte("world"), 0600))
+	select {
+	case <-events:
+		// Expected event
+	case <-time.After(1 * time.Second):
+		t.Fatal("timed out waiting for event")
+	}
+
+	// Test file removal and recreation
+	require.NoError(t, os.Remove(filePath))
+	// Without a short pause, the test can be flaky
+	time.Sleep(10 * time.Millisecond)
+	require.NoError(t, os.WriteFile(filePath, []byte("new"), 0600))
+	select {
+	case <-events:
+		// Expected event
+	case <-time.After(1 * time.Second):
+		t.Fatal("timed out waiting for event")
+	}
+
+	close(dw.done)
+
+	timeout := time.After(1 * time.Second)
+CLOSED:
+	for {
+		select {
+		case _, ok := <-events:
+			if !ok {
+				break CLOSED
+			}
+			continue CLOSED
+		case <-timeout:
+			t.Fatal("timed out waiting for events to be closed")
+		}
+	}
+}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 528af3bb65..8569d03c0e 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -2,6 +2,7 @@ package server
 
 import (
 	"context"
+	"crypto/tls"
 	"embed"
 	"errors"
 	"fmt"
@@ -25,6 +26,7 @@ import (
 	"github.com/akuity/kargo/api/service/v1alpha1/svcv1alpha1connect"
 	rolloutsapi "github.com/akuity/kargo/api/stubs/rollouts/v1alpha1"
 	kargoapi "github.com/akuity/kargo/api/v1alpha1"
+	"github.com/akuity/kargo/internal/server/certloader"
 	"github.com/akuity/kargo/pkg/api"
 	rollouts "github.com/akuity/kargo/pkg/api/stubs/rollouts"
 	"github.com/akuity/kargo/pkg/event"
@@ -241,10 +243,22 @@ func (s *server) Serve(ctx context.Context, l net.Listener) error {
 	errCh := make(chan error)
 	go func() {
 		if s.cfg.TLSConfig != nil {
+			certLoader, err := certloader.NewCertLoader(logger, s.cfg.TLSConfig.CertPath, s.cfg.TLSConfig.KeyPath)
+			if err != nil {
+				errCh <- fmt.Errorf("error initializing cert loader: %w", err)
+				return
+			}
+			defer certLoader.Close()
+
+			srv.TLSConfig = &tls.Config{
+				GetCertificate: certLoader.GetCertificate,
+				MinVersion:     tls.VersionTLS13,
+			}
+
 			errCh <- srv.ServeTLS(
 				l,
-				s.cfg.TLSConfig.CertPath,
-				s.cfg.TLSConfig.KeyPath,
+				"", // cert - not used because GetCertificate is used
+				"", // key - not used because GetCertificate is used
 			)
 		} else {
 			errCh <- srv.Serve(l)