Skip to content

Commit d28e3ca

Browse files
feat/mtls-support-cert: doc update, removing config.Endpoint loading duplicates, handling null-pointer case and some logging
1 parent b41570e commit d28e3ca

File tree

4 files changed

+20
-21
lines changed

4 files changed

+20
-21
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ Examples:
124124
--id 31frd0uzbjvp721 \
125125
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
126126
--endpoint https://example.com \
127-
--tls-client-cert /client.p12
127+
--tls-client-cert ./client.p12
128128
```
129129

130130
```yaml
@@ -137,7 +137,7 @@ services:
137137
- PANGOLIN_ENDPOINT=https://example.com
138138
- NEWT_ID=2ix2t8xk22ubpfy
139139
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
140-
- TLS_CLIENT_CERT=/client.p12
140+
- TLS_CLIENT_CERT=./client.p12
141141
```
142142
143143
## Build

main.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -561,10 +561,13 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
561561
// Wait for interrupt signal
562562
sigCh := make(chan os.Signal, 1)
563563
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
564-
<-sigCh
564+
sigReceived := <-sigCh
565565

566566
// Cleanup
567-
dev.Close()
567+
logger.Info("Received %s signal, stopping", sigReceived.String())
568+
if dev != nil {
569+
dev.Close()
570+
}
568571
}
569572

570573
func parseTargetData(data interface{}) (TargetData, error) {

websocket/client.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,14 @@ func (c *Client) getToken() (string, error) {
162162
// Ensure we have the base URL without trailing slashes
163163
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
164164

165+
var tlsConfig *tls.Config = nil
166+
if c.config.TlsClientCert != "" {
167+
tlsConfig, err = loadClientCertificate(c.config.TlsClientCert)
168+
if err != nil {
169+
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
170+
}
171+
}
172+
165173
// If we already have a token, try to use it
166174
if c.config.Token != "" {
167175
tokenCheckData := map[string]interface{}{
@@ -190,11 +198,7 @@ func (c *Client) getToken() (string, error) {
190198

191199
// Make the request
192200
client := &http.Client{}
193-
if c.config.TlsClientCert != "" {
194-
tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
195-
if err != nil {
196-
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
197-
}
201+
if tlsConfig != nil {
198202
client.Transport = &http.Transport{
199203
TLSClientConfig: tlsConfig,
200204
}
@@ -242,11 +246,7 @@ func (c *Client) getToken() (string, error) {
242246

243247
// Make the request
244248
client := &http.Client{}
245-
if c.config.TlsClientCert != "" {
246-
tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
247-
if err != nil {
248-
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
249-
}
249+
if tlsConfig != nil {
250250
client.Transport = &http.Transport{
251251
TLSClientConfig: tlsConfig,
252252
}
@@ -329,7 +329,7 @@ func (c *Client) establishConnection() error {
329329
dialer := websocket.DefaultDialer
330330
if c.config.TlsClientCert != "" {
331331
logger.Info("Adding tls to req")
332-
tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
332+
tlsConfig, err := loadClientCertificate(c.config.TlsClientCert)
333333
if err != nil {
334334
return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
335335
}
@@ -395,7 +395,7 @@ func (c *Client) setConnected(status bool) {
395395
}
396396

397397
// LoadClientCertificate Helper method to load client certificates
398-
func LoadClientCertificate(p12Path string) (*tls.Config, error) {
398+
func loadClientCertificate(p12Path string) (*tls.Config, error) {
399399
logger.Info("Loading tls-client-cert %s", p12Path)
400400
// Read the PKCS12 file
401401
p12Data, err := os.ReadFile(p12Path)
@@ -408,7 +408,7 @@ func LoadClientCertificate(p12Path string) (*tls.Config, error) {
408408
if err != nil {
409409
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
410410
}
411-
411+
412412
// Create certificate
413413
cert := tls.Certificate{
414414
Certificate: [][]byte{certificate.Raw},

websocket/config.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,6 @@ func (c *Client) loadConfig() error {
6161
c.config.Endpoint = config.Endpoint
6262
c.baseURL = config.Endpoint
6363
}
64-
if c.config.Endpoint == "" {
65-
c.config.Endpoint = config.Endpoint
66-
c.baseURL = config.Endpoint
67-
}
6864

6965
return nil
7066
}

0 commit comments

Comments
 (0)