Skip to content

Commit a1a439c

Browse files
Merge pull request #28 from fosrl/dev
MTLS, Connection Monitoring, time zone logger
2 parents f713c29 + e7c8dbc commit a1a439c

File tree

10 files changed

+431
-66
lines changed

10 files changed

+431
-66
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
newt
22
.DS_Store
3-
bin/
3+
bin/
4+
.idea
5+
*.iml
6+
certs/

README.md

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ When Newt receives WireGuard control messages, it will use the information encod
3737
- `dns`: DNS server to use to resolve the endpoint
3838
- `log-level` (optional): The log level to use. Default: INFO
3939
- `updown` (optional): A script to be called when targets are added or removed.
40-
41-
Example:
40+
- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls)
41+
42+
- Example:
4243

4344
```bash
4445
./newt \
@@ -107,6 +108,38 @@ Returning a string from the script in the format of a target (`ip:dst` so `10.0.
107108

108109
You can look at updown.py as a reference script to get started!
109110

111+
### mTLS
112+
Newt supports mutual TLS (mTLS) authentication, if the server has been configured to request a client certificate.
113+
* Only PKCS12 (.p12 or .pfx) file format is accepted
114+
* The PKCS12 file must contain:
115+
* Private key
116+
* Public certificate
117+
* CA certificate
118+
* Encrypted PKCS12 files are currently not supported
119+
120+
Examples:
121+
122+
```bash
123+
./newt \
124+
--id 31frd0uzbjvp721 \
125+
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
126+
--endpoint https://example.com \
127+
--tls-client-cert ./client.p12
128+
```
129+
130+
```yaml
131+
services:
132+
newt:
133+
image: fosrl/newt
134+
container_name: newt
135+
restart: unless-stopped
136+
environment:
137+
- PANGOLIN_ENDPOINT=https://example.com
138+
- NEWT_ID=2ix2t8xk22ubpfy
139+
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
140+
- TLS_CLIENT_CERT=./client.p12
141+
```
142+
110143
## Build
111144
112145
### Container

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
1111
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
1212
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
13+
software.sslmate.com/src/go-pkcs12 v0.5.0
1314
)
1415

1516
require (

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
2020
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
2121
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
2222
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
23+
software.sslmate.com/src/go-pkcs12 v0.5.0 h1:EC6R394xgENTpZ4RltKydeDUjtlM5drOYIG9c6TVj2M=
24+
software.sslmate.com/src/go-pkcs12 v0.5.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=

logger/logger.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,23 @@ func (l *Logger) log(level LogLevel, format string, args ...interface{}) {
5353
if level < l.level {
5454
return
5555
}
56-
timestamp := time.Now().Format("2006/01/02 15:04:05")
56+
57+
// Get timezone from environment variable or use local timezone
58+
timezone := os.Getenv("LOGGER_TIMEZONE")
59+
var location *time.Location
60+
var err error
61+
62+
if timezone != "" {
63+
location, err = time.LoadLocation(timezone)
64+
if err != nil {
65+
// If invalid timezone, fall back to local
66+
location = time.Local
67+
}
68+
} else {
69+
location = time.Local
70+
}
71+
72+
timestamp := time.Now().In(location).Format("2006/01/02 15:04:05")
5773
message := fmt.Sprintf(format, args...)
5874
l.logger.Printf("%s: %s %s", level.String(), timestamp, message)
5975
}

main.go

Lines changed: 149 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,12 @@ func ping(tnet *netstack.Net, dst string) error {
115115
}
116116

117117
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
118-
ticker := time.NewTicker(10 * time.Second)
118+
initialInterval := 10 * time.Second
119+
maxInterval := 60 * time.Second
120+
currentInterval := initialInterval
121+
consecutiveFailures := 0
122+
123+
ticker := time.NewTicker(currentInterval)
119124
defer ticker.Stop()
120125

121126
go func() {
@@ -124,8 +129,34 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
124129
case <-ticker.C:
125130
err := ping(tnet, serverIP)
126131
if err != nil {
127-
logger.Warn("Periodic ping failed: %v", err)
132+
consecutiveFailures++
133+
logger.Warn("Periodic ping failed (%d consecutive failures): %v",
134+
consecutiveFailures, err)
128135
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
136+
137+
// Increase interval if we have consistent failures, with a maximum cap
138+
if consecutiveFailures >= 3 && currentInterval < maxInterval {
139+
// Increase by 50% each time, up to the maximum
140+
currentInterval = time.Duration(float64(currentInterval) * 1.5)
141+
if currentInterval > maxInterval {
142+
currentInterval = maxInterval
143+
}
144+
ticker.Reset(currentInterval)
145+
logger.Info("Increased ping check interval to %v due to consecutive failures",
146+
currentInterval)
147+
}
148+
} else {
149+
// On success, if we've backed off, gradually return to normal interval
150+
if currentInterval > initialInterval {
151+
currentInterval = time.Duration(float64(currentInterval) * 0.8)
152+
if currentInterval < initialInterval {
153+
currentInterval = initialInterval
154+
}
155+
ticker.Reset(currentInterval)
156+
logger.Info("Decreased ping check interval to %v after successful ping",
157+
currentInterval)
158+
}
159+
consecutiveFailures = 0
129160
}
130161
case <-stopChan:
131162
logger.Info("Stopping ping check")
@@ -135,34 +166,97 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
135166
}()
136167
}
137168

169+
// Function to track connection status and trigger reconnection as needed
170+
func monitorConnectionStatus(tnet *netstack.Net, serverIP string, client *websocket.Client) {
171+
const checkInterval = 30 * time.Second
172+
connectionLost := false
173+
ticker := time.NewTicker(checkInterval)
174+
defer ticker.Stop()
175+
176+
for {
177+
select {
178+
case <-ticker.C:
179+
// Try a ping to see if connection is alive
180+
err := ping(tnet, serverIP)
181+
182+
if err != nil && !connectionLost {
183+
// We just lost connection
184+
connectionLost = true
185+
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
186+
187+
// Notify the user they might need to check their network
188+
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.")
189+
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.")
190+
} else if err == nil && connectionLost {
191+
// Connection has been restored
192+
connectionLost = false
193+
logger.Info("Connection to server restored!")
194+
195+
// Tell the server we're back
196+
err := client.SendMessage("newt/wg/register", map[string]interface{}{
197+
"publicKey": fmt.Sprintf("%s", privateKey.PublicKey()),
198+
})
199+
200+
if err != nil {
201+
logger.Error("Failed to send registration message after reconnection: %v", err)
202+
} else {
203+
logger.Info("Successfully re-registered with server after reconnection")
204+
}
205+
}
206+
}
207+
}
208+
}
209+
138210
func pingWithRetry(tnet *netstack.Net, dst string) error {
139211
const (
140-
maxAttempts = 15
141-
retryDelay = 2 * time.Second
212+
initialMaxAttempts = 15
213+
initialRetryDelay = 2 * time.Second
214+
maxRetryDelay = 60 * time.Second // Cap the maximum delay
142215
)
143216

144-
var lastErr error
145-
for attempt := 1; attempt <= maxAttempts; attempt++ {
146-
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
217+
attempt := 1
218+
retryDelay := initialRetryDelay
147219

148-
if err := ping(tnet, dst); err != nil {
149-
lastErr = err
150-
logger.Warn("Ping attempt %d failed: %v", attempt, err)
220+
// First try with the initial parameters
221+
logger.Info("Ping attempt %d", attempt)
222+
if err := ping(tnet, dst); err == nil {
223+
// Successful ping
224+
return nil
225+
} else {
226+
logger.Warn("Ping attempt %d failed: %v", attempt, err)
227+
}
228+
229+
// Start a goroutine that will attempt pings indefinitely with increasing delays
230+
go func() {
231+
attempt = 2 // Continue from attempt 2
232+
233+
for {
234+
logger.Info("Ping attempt %d", attempt)
235+
236+
if err := ping(tnet, dst); err != nil {
237+
logger.Warn("Ping attempt %d failed: %v", attempt, err)
238+
239+
// Increase delay after certain thresholds but cap it
240+
if attempt%5 == 0 && retryDelay < maxRetryDelay {
241+
retryDelay = time.Duration(float64(retryDelay) * 1.5)
242+
if retryDelay > maxRetryDelay {
243+
retryDelay = maxRetryDelay
244+
}
245+
logger.Info("Increasing ping retry delay to %v", retryDelay)
246+
}
151247

152-
if attempt < maxAttempts {
153248
time.Sleep(retryDelay)
154-
continue
249+
attempt++
250+
} else {
251+
// Successful ping
252+
logger.Info("Ping succeeded after %d attempts", attempt)
253+
return
155254
}
156-
return fmt.Errorf("all ping attempts failed after %d tries, last error: %w",
157-
maxAttempts, lastErr)
158255
}
256+
}()
159257

160-
// Successful ping
161-
return nil
162-
}
163-
164-
// This shouldn't be reached due to the return in the loop, but added for completeness
165-
return fmt.Errorf("unexpected error: all ping attempts failed")
258+
// Return an error for the first batch of attempts (to maintain compatibility with existing code)
259+
return fmt.Errorf("initial ping attempts failed, continuing in background")
166260
}
167261

168262
func parseLogLevel(level string) logger.LogLevel {
@@ -246,16 +340,17 @@ func resolveDomain(domain string) (string, error) {
246340
}
247341

248342
var (
249-
endpoint string
250-
id string
251-
secret string
252-
mtu string
253-
mtuInt int
254-
dns string
255-
privateKey wgtypes.Key
256-
err error
257-
logLevel string
258-
updownScript string
343+
endpoint string
344+
id string
345+
secret string
346+
mtu string
347+
mtuInt int
348+
dns string
349+
privateKey wgtypes.Key
350+
err error
351+
logLevel string
352+
updownScript string
353+
tlsPrivateKey string
259354
)
260355

261356
func main() {
@@ -267,6 +362,7 @@ func main() {
267362
dns = os.Getenv("DNS")
268363
logLevel = os.Getenv("LOG_LEVEL")
269364
updownScript = os.Getenv("UPDOWN_SCRIPT")
365+
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
270366

271367
if endpoint == "" {
272368
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -289,6 +385,9 @@ func main() {
289385
if updownScript == "" {
290386
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
291387
}
388+
if tlsPrivateKey == "" {
389+
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS")
390+
}
292391

293392
// do a --version check
294393
version := flag.Bool("version", false, "Print the version")
@@ -314,12 +413,16 @@ func main() {
314413
if err != nil {
315414
logger.Fatal("Failed to generate private key: %v", err)
316415
}
317-
416+
var opt websocket.ClientOption
417+
if tlsPrivateKey != "" {
418+
opt = websocket.WithTLSConfig(tlsPrivateKey)
419+
}
318420
// Create a new client
319421
client, err := websocket.NewClient(
320422
id, // CLI arg takes precedence
321423
secret, // CLI arg takes precedence
322424
endpoint,
425+
opt,
323426
)
324427
if err != nil {
325428
logger.Fatal("Failed to create client: %v", err)
@@ -353,13 +456,8 @@ func main() {
353456

354457
if connected {
355458
logger.Info("Already connected! But I will send a ping anyway...")
356-
// ping(tnet, wgData.ServerIP)
357-
err = pingWithRetry(tnet, wgData.ServerIP)
358-
if err != nil {
359-
// Handle complete failure after all retries
360-
logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err)
361-
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
362-
}
459+
// Even if pingWithRetry returns an error, it will continue trying in the background
460+
_ = pingWithRetry(tnet, wgData.ServerIP) // Ignoring initial error as pings will continue
363461
return
364462
}
365463

@@ -414,17 +512,18 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
414512
}
415513

416514
logger.Info("WireGuard device created. Lets ping the server now...")
417-
// Ping to bring the tunnel up on the server side quickly
418-
// ping(tnet, wgData.ServerIP)
419-
err = pingWithRetry(tnet, wgData.ServerIP)
420-
if err != nil {
421-
// Handle complete failure after all retries
422-
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
423-
}
424515

516+
// Even if pingWithRetry returns an error, it will continue trying in the background
517+
_ = pingWithRetry(tnet, wgData.ServerIP)
518+
519+
// Always mark as connected and start the proxy manager regardless of initial ping result
520+
// as the pings will continue in the background
425521
if !connected {
426522
logger.Info("Starting ping check")
427523
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
524+
525+
// Start connection monitoring in a separate goroutine
526+
go monitorConnectionStatus(tnet, wgData.ServerIP, client)
428527
}
429528

430529
// Create proxy manager
@@ -552,10 +651,13 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
552651
// Wait for interrupt signal
553652
sigCh := make(chan os.Signal, 1)
554653
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
555-
<-sigCh
654+
sigReceived := <-sigCh
556655

557656
// Cleanup
558-
dev.Close()
657+
logger.Info("Received %s signal, stopping", sigReceived.String())
658+
if dev != nil {
659+
dev.Close()
660+
}
559661
}
560662

561663
func parseTargetData(data interface{}) (TargetData, error) {

0 commit comments

Comments
 (0)