Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ require (
github.com/expr-lang/expr v1.17.6
github.com/fatih/structtag v1.2.0
github.com/fluxcd/pkg/kustomize v1.22.0
github.com/fsnotify/fsnotify v1.9.0
github.com/go-git/go-git/v5 v5.16.2
github.com/go-logr/logr v1.4.3
github.com/go-logr/zapr v1.3.0
Expand Down Expand Up @@ -126,7 +127,6 @@ require (
github.com/exponent-io/jsonpath v0.0.0-20210407135951-1de76d718b3f // indirect
github.com/fatih/color v1.16.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
github.com/go-errors/errors v1.5.1 // indirect
github.com/go-fed/httpsig v1.1.0 // indirect
Expand Down
74 changes: 74 additions & 0 deletions internal/server/certloader/certloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package certloader

import (
"crypto/tls"
"sync"

"github.com/akuity/kargo/internal/server/certwatcher"
"github.com/akuity/kargo/pkg/logging"
)

type CertLoader struct {
logger *logging.Logger
certPath, keyPath string
done chan struct{}
certWatcher *certwatcher.CertWatcher

cert *tls.Certificate
certLock sync.RWMutex
}

func NewCertLoader(logger *logging.Logger, certPath, keyPath string) (*CertLoader, error) {
certWatcher, err := certwatcher.NewCertWatcher(certPath, keyPath)
if err != nil {
return nil, err
}

cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, err
}

c := &CertLoader{
logger: logger,
certWatcher: certWatcher,
cert: &cert,
}

go c.run()

return c, nil
}

func (c *CertLoader) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
c.certLock.RLock()
defer c.certLock.RUnlock()
return c.cert, nil
}

func (c *CertLoader) Close() {
c.certWatcher.Close()
close(c.done)
}

func (c *CertLoader) run() {
go c.certWatcher.Run()
for {
select {
case <-c.done:
return
case _, ok := <-c.certWatcher.Events():
if !ok {
return
}
cert, err := tls.LoadX509KeyPair(c.certPath, c.keyPath)
if err != nil {
c.logger.Error(err, "failed to load certificate and key pair, keeping existing certificate")
continue
}
c.certLock.Lock()
c.cert = &cert
c.certLock.Unlock()
}
}
}
114 changes: 114 additions & 0 deletions internal/server/certwatcher/certwatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package certwatcher

import (
"fmt"
"os"
"path/filepath"
"sync"

"github.com/fsnotify/fsnotify"
)

// Certwatcher watches for any changes to a certificate and key pair.
// It is used to reload the certificate and key pair when they are updated.
// It is also used to reload the certificate and key pair when the file is
// created.
type CertWatcher struct {
directories map[string]*directoryWatcher
notify chan struct{}
}

func NewCertWatcher(certPath, keyPath string) (*CertWatcher, error) {
certWatcher := &CertWatcher{
directories: make(map[string]*directoryWatcher),
notify: make(chan struct{}),
}

err := certWatcher.addPath(certPath)
if err != nil {
return nil, err
}

err = certWatcher.addPath(keyPath)
if err != nil {
return nil, err
}

return certWatcher, nil
}

// Events returns a channel that will be notified when the certificate or key
// pair is updated.
func (c *CertWatcher) Events() <-chan struct{} {
return c.notify
}

// Run starts the certwatcher and watches for changes to the certificate and
// key pair. Run blocks until the certwatcher is closed.
func (c *CertWatcher) Run() {
defer close(c.notify)
wg := sync.WaitGroup{}
for _, dirWatcher := range c.directories {
wg.Add(1)
go func(dirWatcher *directoryWatcher) {
defer wg.Done()
events := dirWatcher.watch()
for range events {
select {
case c.notify <- struct{}{}:
case <-dirWatcher.done:
return
}
}
}(dirWatcher)
}
wg.Wait()
}

// Close closes the certwatcher and stops watching for changes.
func (c *CertWatcher) Close() {
for _, dirWatcher := range c.directories {
close(dirWatcher.done)
dirWatcher.watcher.Close()
}
}

func (c *CertWatcher) addPath(path string) error {
_, err := os.Stat(path)
if err != nil {
return fmt.Errorf("failed to stat %q: %w", path, err)
}

absolutePath, err := filepath.Abs(path)
if err != nil {
return fmt.Errorf("failed to get absolute path for %q: %w", path, err)
}
fp, _ := filepath.EvalSymlinks(absolutePath)

fileDir := filepath.Dir(absolutePath)

dirWatcher, ok := c.directories[fileDir]
if ok {
dirWatcher.watchedFiles[absolutePath] = fp
return nil
}

watcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("failed to create fsnotify watcher: %w", err)
}
err = watcher.Add(fileDir)
if err != nil {
return fmt.Errorf("failed to add %q to fsnotify watcher: %w", fileDir, err)
}

dirWatcher = &directoryWatcher{
watcher: watcher,
watchedFiles: make(map[string]string),
done: make(chan struct{}),
}

dirWatcher.watchedFiles[absolutePath] = fp
c.directories[fileDir] = dirWatcher
return nil
}
91 changes: 91 additions & 0 deletions internal/server/certwatcher/certwatcher_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package certwatcher

import (
"os"
"path/filepath"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestNewCertWatcher(t *testing.T) {
t.Run("success", func(t *testing.T) {
tempDir := t.TempDir()
certPath := filepath.Join(tempDir, "tls.crt")
keyPath := filepath.Join(tempDir, "tls.key")
require.NoError(t, os.WriteFile(certPath, []byte("cert"), 0600))
require.NoError(t, os.WriteFile(keyPath, []byte("key"), 0600))

cw, err := NewCertWatcher(certPath, keyPath)
require.NoError(t, err)
require.NotNil(t, cw)
require.Len(t, cw.directories, 1)
})

t.Run("cert path does not exist", func(t *testing.T) {
tempDir := t.TempDir()
certPath := filepath.Join(tempDir, "tls.crt")
keyPath := filepath.Join(tempDir, "tls.key")
require.NoError(t, os.WriteFile(keyPath, []byte("key"), 0600))

_, err := NewCertWatcher(certPath, keyPath)
require.Error(t, err)
})

t.Run("key path does not exist", func(t *testing.T) {
tempDir := t.TempDir()
certPath := filepath.Join(tempDir, "tls.crt")
keyPath := filepath.Join(tempDir, "tls.key")
require.NoError(t, os.WriteFile(certPath, []byte("cert"), 0600))

_, err := NewCertWatcher(certPath, keyPath)
require.Error(t, err)
})
}

func TestCertWatcher(t *testing.T) {
tempDir := t.TempDir()
certPath := filepath.Join(tempDir, "tls.crt")
keyPath := filepath.Join(tempDir, "tls.key")
require.NoError(t, os.WriteFile(certPath, []byte("cert"), 0600))
require.NoError(t, os.WriteFile(keyPath, []byte("key"), 0600))

cw, err := NewCertWatcher(certPath, keyPath)
require.NoError(t, err)
require.NotNil(t, cw)

wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
cw.Run()
}()

// Wait a bit for the watcher to start
time.Sleep(100 * time.Millisecond)

// Update the cert file
require.NoError(t, os.WriteFile(certPath, []byte("new cert"), 0600))

select {
case <-cw.Events():
// All good
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for event")
}

// Update the key file
require.NoError(t, os.WriteFile(keyPath, []byte("new key"), 0600))

select {
case <-cw.Events():
// All good
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for event")
}

cw.Close()
wg.Wait()
}
70 changes: 70 additions & 0 deletions internal/server/certwatcher/directorywatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package certwatcher

import (
"path/filepath"
"time"

"github.com/fsnotify/fsnotify"
)

type directoryWatcher struct {
watcher *fsnotify.Watcher
watchedFiles map[string]string
done chan struct{}
}

func (d *directoryWatcher) watch() <-chan struct{} {
events := make(chan struct{})
go func() {
defer close(events)
for {
select {
case <-d.done:
return
case event, ok := <-d.watcher.Events:
if !ok {
return
}
if d.shouldSendEvent(event) {
select {
case events <- struct{}{}:
case <-d.done:
return
}
}
}
}
}()

return events
}

func (d *directoryWatcher) shouldSendEvent(event fsnotify.Event) bool {
sleepTime := 10 * time.Millisecond
eventPath, _ := filepath.Abs(event.Name)
eventPath, _ = filepath.EvalSymlinks(eventPath)

for abs, previous := range d.watchedFiles {
currentWatchedPath, _ := filepath.Abs(abs)
switch {
case currentWatchedPath == "":
// watched file was removed; wait for write event to trigger reload
d.watchedFiles[abs] = ""
case currentWatchedPath != previous:
// File previously didn't exist; send a signal to the caller
time.Sleep(sleepTime)
d.watchedFiles[abs] = currentWatchedPath
return true
case eventPath == currentWatchedPath && isUpdatedFileEvent(event):
// File was modified so send a signal to the caller
time.Sleep(sleepTime)
d.watchedFiles[abs] = currentWatchedPath
return true
}
}
return false
}

func isUpdatedFileEvent(event fsnotify.Event) bool {
return (event.Op&fsnotify.Write) == fsnotify.Write || (event.Op&fsnotify.Create) == fsnotify.Create
}
Loading