Skip to content

Commit b41570e

Browse files
feat/mtls-support-cert: config support
1 parent 435b638 commit b41570e

File tree

4 files changed

+49
-30
lines changed

4 files changed

+49
-30
lines changed

main.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,8 @@ func main() {
321321
}
322322
var opt websocket.ClientOption
323323
if tlsPrivateKey != "" {
324-
tlsConfig, err := websocket.LoadClientCertificate(tlsPrivateKey)
325-
if err != nil {
326-
logger.Fatal("Failed to load client certificate: %v", err)
327-
}
328-
opt = websocket.WithTLSConfig(tlsConfig)
324+
opt = websocket.WithTLSConfig(tlsPrivateKey)
329325
}
330-
331326
// Create a new client
332327
client, err := websocket.NewClient(
333328
id, // CLI arg takes precedence

websocket/client.go

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@ import (
1919
)
2020

2121
type Client struct {
22-
conn *websocket.Conn
23-
config *Config
24-
baseURL string
25-
handlers map[string]MessageHandler
26-
done chan struct{}
27-
handlersMux sync.RWMutex
28-
tlsConfig *tls.Config
29-
22+
conn *websocket.Conn
23+
config *Config
24+
baseURL string
25+
handlers map[string]MessageHandler
26+
done chan struct{}
27+
handlersMux sync.RWMutex
3028
reconnectInterval time.Duration
3129
isConnected bool
3230
reconnectMux sync.RWMutex
@@ -45,9 +43,9 @@ func WithBaseURL(url string) ClientOption {
4543
}
4644
}
4745

48-
func WithTLSConfig(tlsConfig *tls.Config) ClientOption {
46+
func WithTLSConfig(tlsClientCertPath string) ClientOption {
4947
return func(c *Client) {
50-
c.tlsConfig = tlsConfig
48+
c.config.TlsClientCert = tlsClientCertPath
5149
}
5250
}
5351

@@ -73,8 +71,13 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
7371
}
7472

7573
// Apply options before loading config
76-
for _, opt := range opts {
77-
opt(client)
74+
if opts != nil {
75+
for _, opt := range opts {
76+
if opt == nil {
77+
continue
78+
}
79+
opt(client)
80+
}
7881
}
7982

8083
// Load existing config if available
@@ -187,10 +190,13 @@ func (c *Client) getToken() (string, error) {
187190

188191
// Make the request
189192
client := &http.Client{}
190-
if c.tlsConfig != nil {
191-
logger.Info("Adding tls to req")
193+
if c.config.TlsClientCert != "" {
194+
tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
195+
if err != nil {
196+
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
197+
}
192198
client.Transport = &http.Transport{
193-
TLSClientConfig: c.tlsConfig,
199+
TLSClientConfig: tlsConfig,
194200
}
195201
}
196202
resp, err := client.Do(req)
@@ -236,9 +242,13 @@ func (c *Client) getToken() (string, error) {
236242

237243
// Make the request
238244
client := &http.Client{}
239-
if c.tlsConfig != nil {
245+
if c.config.TlsClientCert != "" {
246+
tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
247+
if err != nil {
248+
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
249+
}
240250
client.Transport = &http.Transport{
241-
TLSClientConfig: c.tlsConfig,
251+
TLSClientConfig: tlsConfig,
242252
}
243253
}
244254
resp, err := client.Do(req)
@@ -317,8 +327,13 @@ func (c *Client) establishConnection() error {
317327

318328
// Connect to WebSocket
319329
dialer := websocket.DefaultDialer
320-
if c.tlsConfig != nil {
321-
dialer.TLSClientConfig = c.tlsConfig
330+
if c.config.TlsClientCert != "" {
331+
logger.Info("Adding tls to req")
332+
tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
333+
if err != nil {
334+
return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
335+
}
336+
dialer.TLSClientConfig = tlsConfig
322337
}
323338
conn, _, err := dialer.Dial(u.String(), nil)
324339
if err != nil {
@@ -381,6 +396,7 @@ func (c *Client) setConnected(status bool) {
381396

382397
// LoadClientCertificate Helper method to load client certificates
383398
func LoadClientCertificate(p12Path string) (*tls.Config, error) {
399+
logger.Info("Loading tls-client-cert %s", p12Path)
384400
// Read the PKCS12 file
385401
p12Data, err := os.ReadFile(p12Path)
386402
if err != nil {
@@ -392,7 +408,7 @@ func LoadClientCertificate(p12Path string) (*tls.Config, error) {
392408
if err != nil {
393409
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
394410
}
395-
411+
396412
// Create certificate
397413
cert := tls.Certificate{
398414
Certificate: [][]byte{certificate.Raw},

websocket/config.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ func (c *Client) loadConfig() error {
5454
if c.config.Secret == "" {
5555
c.config.Secret = config.Secret
5656
}
57+
if c.config.TlsClientCert == "" {
58+
c.config.TlsClientCert = config.TlsClientCert
59+
}
60+
if c.config.Endpoint == "" {
61+
c.config.Endpoint = config.Endpoint
62+
c.baseURL = config.Endpoint
63+
}
5764
if c.config.Endpoint == "" {
5865
c.config.Endpoint = config.Endpoint
5966
c.baseURL = config.Endpoint

websocket/types.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
package websocket
22

33
type Config struct {
4-
NewtID string `json:"newtId"`
5-
Secret string `json:"secret"`
6-
Token string `json:"token"`
7-
Endpoint string `json:"endpoint"`
4+
NewtID string `json:"newtId"`
5+
Secret string `json:"secret"`
6+
Token string `json:"token"`
7+
Endpoint string `json:"endpoint"`
8+
TlsClientCert string `json:"tlsClientCert"`
89
}
910

1011
type TokenResponse struct {

0 commit comments

Comments
 (0)