Skip to content

Commit 9b3c826

Browse files
feat/mtls-support
1 parent 2ff8df9 commit 9b3c826

File tree

5 files changed

+127
-14
lines changed

5 files changed

+127
-14
lines changed

README.md

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ When Newt receives WireGuard control messages, it will use the information encod
3737
- `dns`: DNS server to use to resolve the endpoint
3838
- `log-level` (optional): The log level to use. Default: INFO
3939
- `updown` (optional): A script to be called when targets are added or removed.
40-
41-
Example:
40+
- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls)
41+
42+
- Example:
4243

4344
```bash
4445
./newt \
@@ -107,6 +108,38 @@ Returning a string from the script in the format of a target (`ip:dst` so `10.0.
107108

108109
You can look at updown.py as a reference script to get started!
109110

111+
### mTLS
112+
Newt supports mutual TLS (mTLS) authentication, if the server has been configured to request a client certificate.
113+
* Only PKCS12 (.p12 or .pfx) file format is accepted
114+
* The PKCS12 file must contain:
115+
* Private key
116+
* Public certificate
117+
* CA certificate
118+
* Encrypted PKCS12 files are currently not supported
119+
120+
Examples:
121+
122+
```bash
123+
./newt \
124+
--id 31frd0uzbjvp721 \
125+
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
126+
--endpoint https://example.com \
127+
--tls-client-cert /client.p12
128+
```
129+
130+
```yaml
131+
services:
132+
newt:
133+
image: fosrl/newt
134+
container_name: newt
135+
restart: unless-stopped
136+
environment:
137+
- PANGOLIN_ENDPOINT=https://example.com
138+
- NEWT_ID=2ix2t8xk22ubpfy
139+
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
140+
- TLS_CLIENT_CERT=/client.p12
141+
```
142+
110143
## Build
111144
112145
### Container

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
1111
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
1212
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
13+
software.sslmate.com/src/go-pkcs12 v0.5.0
1314
)
1415

1516
require (

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
2020
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
2121
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
2222
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
23+
software.sslmate.com/src/go-pkcs12 v0.5.0 h1:EC6R394xgENTpZ4RltKydeDUjtlM5drOYIG9c6TVj2M=
24+
software.sslmate.com/src/go-pkcs12 v0.5.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=

main.go

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -246,16 +246,17 @@ func resolveDomain(domain string) (string, error) {
246246
}
247247

248248
var (
249-
endpoint string
250-
id string
251-
secret string
252-
mtu string
253-
mtuInt int
254-
dns string
255-
privateKey wgtypes.Key
256-
err error
257-
logLevel string
258-
updownScript string
249+
endpoint string
250+
id string
251+
secret string
252+
mtu string
253+
mtuInt int
254+
dns string
255+
privateKey wgtypes.Key
256+
err error
257+
logLevel string
258+
updownScript string
259+
tlsPrivateKey string
259260
)
260261

261262
func main() {
@@ -267,6 +268,7 @@ func main() {
267268
dns = os.Getenv("DNS")
268269
logLevel = os.Getenv("LOG_LEVEL")
269270
updownScript = os.Getenv("UPDOWN_SCRIPT")
271+
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
270272

271273
if endpoint == "" {
272274
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -289,6 +291,9 @@ func main() {
289291
if updownScript == "" {
290292
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
291293
}
294+
if tlsPrivateKey == "" {
295+
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS")
296+
}
292297

293298
// do a --version check
294299
version := flag.Bool("version", false, "Print the version")
@@ -314,12 +319,21 @@ func main() {
314319
if err != nil {
315320
logger.Fatal("Failed to generate private key: %v", err)
316321
}
322+
var opt websocket.ClientOption
323+
if tlsPrivateKey != "" {
324+
tlsConfig, err := websocket.LoadClientCertificate(tlsPrivateKey)
325+
if err != nil {
326+
logger.Fatal("Failed to load client certificate: %v", err)
327+
}
328+
opt = websocket.WithTLSConfig(tlsConfig)
329+
}
317330

318331
// Create a new client
319332
client, err := websocket.NewClient(
320333
id, // CLI arg takes precedence
321334
secret, // CLI arg takes precedence
322335
endpoint,
336+
opt,
323337
)
324338
if err != nil {
325339
logger.Fatal("Failed to create client: %v", err)

websocket/client.go

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@ package websocket
22

33
import (
44
"bytes"
5+
"crypto/tls"
6+
"crypto/x509"
57
"encoding/json"
68
"fmt"
79
"net/http"
810
"net/url"
11+
"os"
12+
"software.sslmate.com/src/go-pkcs12"
913
"strings"
1014
"sync"
1115
"time"
1216

1317
"github.com/fosrl/newt/logger"
14-
1518
"github.com/gorilla/websocket"
1619
)
1720

@@ -22,6 +25,7 @@ type Client struct {
2225
handlers map[string]MessageHandler
2326
done chan struct{}
2427
handlersMux sync.RWMutex
28+
tlsConfig *tls.Config
2529

2630
reconnectInterval time.Duration
2731
isConnected bool
@@ -41,6 +45,12 @@ func WithBaseURL(url string) ClientOption {
4145
}
4246
}
4347

48+
func WithTLSConfig(tlsConfig *tls.Config) ClientOption {
49+
return func(c *Client) {
50+
c.tlsConfig = tlsConfig
51+
}
52+
}
53+
4454
func (c *Client) OnConnect(callback func() error) {
4555
c.onConnect = callback
4656
}
@@ -177,6 +187,12 @@ func (c *Client) getToken() (string, error) {
177187

178188
// Make the request
179189
client := &http.Client{}
190+
if c.tlsConfig != nil {
191+
logger.Info("Adding tls to req")
192+
client.Transport = &http.Transport{
193+
TLSClientConfig: c.tlsConfig,
194+
}
195+
}
180196
resp, err := client.Do(req)
181197
if err != nil {
182198
return "", fmt.Errorf("failed to check token validity: %w", err)
@@ -220,6 +236,11 @@ func (c *Client) getToken() (string, error) {
220236

221237
// Make the request
222238
client := &http.Client{}
239+
if c.tlsConfig != nil {
240+
client.Transport = &http.Transport{
241+
TLSClientConfig: c.tlsConfig,
242+
}
243+
}
223244
resp, err := client.Do(req)
224245
if err != nil {
225246
return "", fmt.Errorf("failed to request new token: %w", err)
@@ -295,7 +316,11 @@ func (c *Client) establishConnection() error {
295316
u.RawQuery = q.Encode()
296317

297318
// Connect to WebSocket
298-
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
319+
dialer := websocket.DefaultDialer
320+
if c.tlsConfig != nil {
321+
dialer.TLSClientConfig = c.tlsConfig
322+
}
323+
conn, _, err := dialer.Dial(u.String(), nil)
299324
if err != nil {
300325
return fmt.Errorf("failed to connect to WebSocket: %w", err)
301326
}
@@ -353,3 +378,41 @@ func (c *Client) setConnected(status bool) {
353378
defer c.reconnectMux.Unlock()
354379
c.isConnected = status
355380
}
381+
382+
// LoadClientCertificate Helper method to load client certificates
383+
func LoadClientCertificate(p12Path string) (*tls.Config, error) {
384+
// Read the PKCS12 file
385+
p12Data, err := os.ReadFile(p12Path)
386+
if err != nil {
387+
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
388+
}
389+
390+
// Parse PKCS12 with empty password for non-encrypted files
391+
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
392+
if err != nil {
393+
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
394+
}
395+
396+
// Create certificate
397+
cert := tls.Certificate{
398+
Certificate: [][]byte{certificate.Raw},
399+
PrivateKey: privateKey,
400+
}
401+
402+
// Optional: Add CA certificates if present
403+
rootCAs, err := x509.SystemCertPool()
404+
if err != nil {
405+
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
406+
}
407+
if len(caCerts) > 0 {
408+
for _, caCert := range caCerts {
409+
rootCAs.AddCert(caCert)
410+
}
411+
}
412+
413+
// Create TLS configuration
414+
return &tls.Config{
415+
Certificates: []tls.Certificate{cert},
416+
RootCAs: rootCAs,
417+
}, nil
418+
}

0 commit comments

Comments
 (0)