Skip to content

Commit cb6361b

Browse files
committed
Watch the tls certs for changes update the served certs
Signed-off-by: Dan Bason <dan.bason@dronedeploy.com>
1 parent b46adb6 commit cb6361b

File tree

7 files changed

+435
-3
lines changed

7 files changed

+435
-3
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ require (
2626
github.com/expr-lang/expr v1.17.6
2727
github.com/fatih/structtag v1.2.0
2828
github.com/fluxcd/pkg/kustomize v1.22.0
29+
github.com/fsnotify/fsnotify v1.9.0
2930
github.com/go-git/go-git/v5 v5.16.2
3031
github.com/go-logr/logr v1.4.3
3132
github.com/go-logr/zapr v1.3.0
@@ -126,7 +127,6 @@ require (
126127
github.com/exponent-io/jsonpath v0.0.0-20210407135951-1de76d718b3f // indirect
127128
github.com/fatih/color v1.16.0 // indirect
128129
github.com/felixge/httpsnoop v1.0.4 // indirect
129-
github.com/fsnotify/fsnotify v1.9.0 // indirect
130130
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
131131
github.com/go-errors/errors v1.5.1 // indirect
132132
github.com/go-fed/httpsig v1.1.0 // indirect
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package certloader
2+
3+
import (
4+
"crypto/tls"
5+
"sync"
6+
7+
"github.com/akuity/kargo/internal/server/certwatcher"
8+
"github.com/akuity/kargo/pkg/logging"
9+
)
10+
11+
type CertLoader struct {
12+
logger *logging.Logger
13+
certPath, keyPath string
14+
done chan struct{}
15+
certWatcher *certwatcher.CertWatcher
16+
17+
cert *tls.Certificate
18+
certLock sync.RWMutex
19+
}
20+
21+
func NewCertLoader(logger *logging.Logger, certPath, keyPath string) (*CertLoader, error) {
22+
certWatcher, err := certwatcher.NewCertWatcher(certPath, keyPath)
23+
if err != nil {
24+
return nil, err
25+
}
26+
27+
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
28+
if err != nil {
29+
return nil, err
30+
}
31+
32+
c := &CertLoader{
33+
logger: logger,
34+
certWatcher: certWatcher,
35+
cert: &cert,
36+
}
37+
38+
go c.run()
39+
40+
return c, nil
41+
}
42+
43+
func (c *CertLoader) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
44+
c.certLock.RLock()
45+
defer c.certLock.RUnlock()
46+
return c.cert, nil
47+
}
48+
49+
func (c *CertLoader) Close() {
50+
c.certWatcher.Close()
51+
close(c.done)
52+
}
53+
54+
func (c *CertLoader) run() {
55+
go c.certWatcher.Run()
56+
for {
57+
select {
58+
case <-c.done:
59+
return
60+
case _, ok := <-c.certWatcher.Events():
61+
if !ok {
62+
return
63+
}
64+
cert, err := tls.LoadX509KeyPair(c.certPath, c.keyPath)
65+
if err != nil {
66+
c.logger.Error(err, "failed to load certificate and key pair, keeping existing certificate")
67+
continue
68+
}
69+
c.certLock.Lock()
70+
c.cert = &cert
71+
c.certLock.Unlock()
72+
}
73+
}
74+
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package certwatcher
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"path/filepath"
7+
"sync"
8+
9+
"github.com/fsnotify/fsnotify"
10+
)
11+
12+
// Certwatcher watches for any changes to a certificate and key pair.
13+
// It is used to reload the certificate and key pair when they are updated.
14+
// It is also used to reload the certificate and key pair when the file is
15+
// created.
16+
type CertWatcher struct {
17+
directories map[string]*directoryWatcher
18+
notify chan struct{}
19+
}
20+
21+
func NewCertWatcher(certPath, keyPath string) (*CertWatcher, error) {
22+
certWatcher := &CertWatcher{
23+
directories: make(map[string]*directoryWatcher),
24+
notify: make(chan struct{}),
25+
}
26+
27+
err := certWatcher.addPath(certPath)
28+
if err != nil {
29+
return nil, err
30+
}
31+
32+
err = certWatcher.addPath(keyPath)
33+
if err != nil {
34+
return nil, err
35+
}
36+
37+
return certWatcher, nil
38+
}
39+
40+
// Events returns a channel that will be notified when the certificate or key
41+
// pair is updated.
42+
func (c *CertWatcher) Events() <-chan struct{} {
43+
return c.notify
44+
}
45+
46+
// Run starts the certwatcher and watches for changes to the certificate and
47+
// key pair. Run blocks until the certwatcher is closed.
48+
func (c *CertWatcher) Run() {
49+
defer close(c.notify)
50+
wg := sync.WaitGroup{}
51+
for _, dirWatcher := range c.directories {
52+
wg.Add(1)
53+
go func(dirWatcher *directoryWatcher) {
54+
defer wg.Done()
55+
events := dirWatcher.watch()
56+
for range events {
57+
select {
58+
case c.notify <- struct{}{}:
59+
case <-dirWatcher.done:
60+
return
61+
}
62+
}
63+
}(dirWatcher)
64+
}
65+
wg.Wait()
66+
}
67+
68+
// Close closes the certwatcher and stops watching for changes.
69+
func (c *CertWatcher) Close() {
70+
for _, dirWatcher := range c.directories {
71+
close(dirWatcher.done)
72+
dirWatcher.watcher.Close()
73+
}
74+
}
75+
76+
func (c *CertWatcher) addPath(path string) error {
77+
_, err := os.Stat(path)
78+
if err != nil {
79+
return fmt.Errorf("failed to stat %q: %w", path, err)
80+
}
81+
82+
absolutePath, err := filepath.Abs(path)
83+
if err != nil {
84+
return fmt.Errorf("failed to get absolute path for %q: %w", path, err)
85+
}
86+
fp, _ := filepath.EvalSymlinks(absolutePath)
87+
88+
fileDir := filepath.Dir(absolutePath)
89+
90+
dirWatcher, ok := c.directories[fileDir]
91+
if ok {
92+
dirWatcher.watchedFiles[absolutePath] = fp
93+
return nil
94+
}
95+
96+
watcher, err := fsnotify.NewWatcher()
97+
if err != nil {
98+
return fmt.Errorf("failed to create fsnotify watcher: %w", err)
99+
}
100+
err = watcher.Add(fileDir)
101+
if err != nil {
102+
return fmt.Errorf("failed to add %q to fsnotify watcher: %w", fileDir, err)
103+
}
104+
105+
dirWatcher = &directoryWatcher{
106+
watcher: watcher,
107+
watchedFiles: make(map[string]string),
108+
done: make(chan struct{}),
109+
}
110+
111+
dirWatcher.watchedFiles[absolutePath] = fp
112+
c.directories[fileDir] = dirWatcher
113+
return nil
114+
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package certwatcher
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"sync"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestNewCertWatcher(t *testing.T) {
14+
t.Run("success", func(t *testing.T) {
15+
tempDir := t.TempDir()
16+
certPath := filepath.Join(tempDir, "tls.crt")
17+
keyPath := filepath.Join(tempDir, "tls.key")
18+
require.NoError(t, os.WriteFile(certPath, []byte("cert"), 0600))
19+
require.NoError(t, os.WriteFile(keyPath, []byte("key"), 0600))
20+
21+
cw, err := NewCertWatcher(certPath, keyPath)
22+
require.NoError(t, err)
23+
require.NotNil(t, cw)
24+
require.Len(t, cw.directories, 1)
25+
})
26+
27+
t.Run("cert path does not exist", func(t *testing.T) {
28+
tempDir := t.TempDir()
29+
certPath := filepath.Join(tempDir, "tls.crt")
30+
keyPath := filepath.Join(tempDir, "tls.key")
31+
require.NoError(t, os.WriteFile(keyPath, []byte("key"), 0600))
32+
33+
_, err := NewCertWatcher(certPath, keyPath)
34+
require.Error(t, err)
35+
})
36+
37+
t.Run("key path does not exist", func(t *testing.T) {
38+
tempDir := t.TempDir()
39+
certPath := filepath.Join(tempDir, "tls.crt")
40+
keyPath := filepath.Join(tempDir, "tls.key")
41+
require.NoError(t, os.WriteFile(certPath, []byte("cert"), 0600))
42+
43+
_, err := NewCertWatcher(certPath, keyPath)
44+
require.Error(t, err)
45+
})
46+
}
47+
48+
func TestCertWatcher(t *testing.T) {
49+
tempDir := t.TempDir()
50+
certPath := filepath.Join(tempDir, "tls.crt")
51+
keyPath := filepath.Join(tempDir, "tls.key")
52+
require.NoError(t, os.WriteFile(certPath, []byte("cert"), 0600))
53+
require.NoError(t, os.WriteFile(keyPath, []byte("key"), 0600))
54+
55+
cw, err := NewCertWatcher(certPath, keyPath)
56+
require.NoError(t, err)
57+
require.NotNil(t, cw)
58+
59+
wg := sync.WaitGroup{}
60+
wg.Add(1)
61+
go func() {
62+
defer wg.Done()
63+
cw.Run()
64+
}()
65+
66+
// Wait a bit for the watcher to start
67+
time.Sleep(100 * time.Millisecond)
68+
69+
// Update the cert file
70+
require.NoError(t, os.WriteFile(certPath, []byte("new cert"), 0600))
71+
72+
select {
73+
case <-cw.Events():
74+
// All good
75+
case <-time.After(5 * time.Second):
76+
t.Fatal("timed out waiting for event")
77+
}
78+
79+
// Update the key file
80+
require.NoError(t, os.WriteFile(keyPath, []byte("new key"), 0600))
81+
82+
select {
83+
case <-cw.Events():
84+
// All good
85+
case <-time.After(5 * time.Second):
86+
t.Fatal("timed out waiting for event")
87+
}
88+
89+
cw.Close()
90+
wg.Wait()
91+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package certwatcher
2+
3+
import (
4+
"path/filepath"
5+
"time"
6+
7+
"github.com/fsnotify/fsnotify"
8+
)
9+
10+
type directoryWatcher struct {
11+
watcher *fsnotify.Watcher
12+
watchedFiles map[string]string
13+
done chan struct{}
14+
}
15+
16+
func (d *directoryWatcher) watch() <-chan struct{} {
17+
events := make(chan struct{})
18+
go func() {
19+
defer close(events)
20+
for {
21+
select {
22+
case <-d.done:
23+
return
24+
case event, ok := <-d.watcher.Events:
25+
if !ok {
26+
return
27+
}
28+
if d.shouldSendEvent(event) {
29+
select {
30+
case events <- struct{}{}:
31+
case <-d.done:
32+
return
33+
}
34+
}
35+
}
36+
}
37+
}()
38+
39+
return events
40+
}
41+
42+
func (d *directoryWatcher) shouldSendEvent(event fsnotify.Event) bool {
43+
sleepTime := 10 * time.Millisecond
44+
eventPath, _ := filepath.Abs(event.Name)
45+
eventPath, _ = filepath.EvalSymlinks(eventPath)
46+
47+
for abs, previous := range d.watchedFiles {
48+
currentWatchedPath, _ := filepath.Abs(abs)
49+
switch {
50+
case currentWatchedPath == "":
51+
// watched file was removed; wait for write event to trigger reload
52+
d.watchedFiles[abs] = ""
53+
case currentWatchedPath != previous:
54+
// File previously didn't exist; send a signal to the caller
55+
time.Sleep(sleepTime)
56+
d.watchedFiles[abs] = currentWatchedPath
57+
return true
58+
case eventPath == currentWatchedPath && isUpdatedFileEvent(event):
59+
// File was modified so send a signal to the caller
60+
time.Sleep(sleepTime)
61+
d.watchedFiles[abs] = currentWatchedPath
62+
return true
63+
}
64+
}
65+
return false
66+
}
67+
68+
func isUpdatedFileEvent(event fsnotify.Event) bool {
69+
return (event.Op&fsnotify.Write) == fsnotify.Write || (event.Op&fsnotify.Create) == fsnotify.Create
70+
}

0 commit comments

Comments
 (0)