1
1
package serve
2
2
3
3
import (
4
+ "context"
4
5
"fmt"
5
6
"io/ioutil"
6
7
"net"
7
8
"net/http"
9
+ "os"
10
+ "os/signal"
8
11
"strings"
12
+ "syscall"
13
+ "time"
9
14
10
- "github.com/docker/libnetwork/resolvconf"
11
- "github.com/docker/libnetwork/types"
15
+ "github.com/place1/wg-access-server/internal/config"
16
+ "github.com/place1/wg-access-server/internal/devices"
17
+ "github.com/place1/wg-access-server/internal/dnsproxy"
18
+ "github.com/place1/wg-access-server/internal/network"
12
19
"github.com/place1/wg-access-server/internal/services"
13
20
"github.com/place1/wg-access-server/internal/storage"
14
21
"github.com/place1/wg-access-server/pkg/authnz"
15
22
"github.com/place1/wg-access-server/pkg/authnz/authconfig"
16
23
"github.com/place1/wg-access-server/pkg/authnz/authsession"
24
+
25
+ "github.com/docker/libnetwork/resolvconf"
26
+ "github.com/docker/libnetwork/types"
27
+ "github.com/gorilla/mux"
28
+ "github.com/pkg/errors"
29
+ "github.com/place1/wg-embed/pkg/wgembed"
30
+ "github.com/sirupsen/logrus"
17
31
"github.com/vishvananda/netlink"
18
32
"golang.org/x/crypto/bcrypt"
19
33
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
20
34
"gopkg.in/alecthomas/kingpin.v2"
21
35
"gopkg.in/yaml.v2"
22
-
23
- "github.com/gorilla/mux"
24
- "github.com/place1/wg-embed/pkg/wgembed"
25
-
26
- "github.com/pkg/errors"
27
- "github.com/place1/wg-access-server/internal/config"
28
- "github.com/place1/wg-access-server/internal/devices"
29
- "github.com/place1/wg-access-server/internal/dnsproxy"
30
- "github.com/place1/wg-access-server/internal/network"
31
- "github.com/sirupsen/logrus"
32
36
)
33
37
34
38
func Register (app * kingpin.Application ) * servecmd {
@@ -52,7 +56,8 @@ func Register(app *kingpin.Application) *servecmd {
52
56
cli .Flag ("vpn-gateway-interface" , "The gateway network interface (i.e. eth0)" ).Envar ("WG_VPN_GATEWAY_INTERFACE" ).Default (detectDefaultInterface ()).StringVar (& cmd .AppConfig .VPN .GatewayInterface )
53
57
cli .Flag ("vpn-allowed-ips" , "A list of networks that VPN clients will be allowed to connect to via the VPN" ).Envar ("WG_VPN_ALLOWED_IPS" ).Default ("0.0.0.0/0" , "::/0" ).StringsVar (& cmd .AppConfig .VPN .AllowedIPs )
54
58
cli .Flag ("dns-enabled" , "Enable or disable the embedded dns proxy server (useful for development)" ).Envar ("WG_DNS_ENABLED" ).Default ("true" ).BoolVar (& cmd .AppConfig .DNS .Enabled )
55
- cli .Flag ("dns-upstream" , "An upstream DNS server to proxy DNS traffic to. Defaults to resolveconf with Cloudflare DNS as fallback" ).Envar ("WG_DNS_UPSTREAM" ).StringsVar (& cmd .AppConfig .DNS .Upstream )
59
+ cli .Flag ("dns-upstream" , "An upstream DNS server to proxy DNS traffic to. Defaults to resolvconf with Cloudflare DNS as fallback" ).Envar ("WG_DNS_UPSTREAM" ).StringsVar (& cmd .AppConfig .DNS .Upstream )
60
+ cli .Flag ("dns-domain" , "A domain to serve configured device names authoritatively" ).Envar ("WG_DNS_DOMAIN" ).StringVar (& cmd .AppConfig .DNS .Domain )
56
61
return cmd
57
62
}
58
63
@@ -74,6 +79,9 @@ func (cmd *servecmd) Run() {
74
79
if conf .VPN .CIDRv6 == "0" {
75
80
conf .VPN .CIDRv6 = ""
76
81
}
82
+ if conf .DNS .Domain == "0" {
83
+ conf .DNS .Domain = ""
84
+ }
77
85
78
86
// Get the server's IP addresses within the VPN
79
87
var vpnip , vpnipv6 * net.IPNet
@@ -98,6 +106,13 @@ func (cmd *servecmd) Run() {
98
106
conf .VPN .AllowedIPs = append (conf .VPN .AllowedIPs , fmt .Sprintf ("%s/128" , vpnipv6 .IP .String ()))
99
107
vpnipstrings = append (vpnipstrings , vpnipv6 .String ())
100
108
}
109
+ vpnips := make ([]net.IP , 0 , 2 )
110
+ if vpnip != nil {
111
+ vpnips = append (vpnips , vpnip .IP )
112
+ }
113
+ if vpnipv6 != nil {
114
+ vpnips = append (vpnips , vpnipv6 .IP )
115
+ }
101
116
102
117
// WireGuard Server
103
118
wg := wgembed .NewNoOpInterface ()
@@ -120,44 +135,76 @@ func (cmd *servecmd) Run() {
120
135
}
121
136
122
137
if err := wg .LoadConfig (wgconfig ); err != nil {
123
- logrus .Fatal (errors .Wrap (err , "failed to load wireguard config" ))
138
+ logrus .Error (errors .Wrap (err , "failed to load wireguard config" ))
139
+ return
124
140
}
125
141
126
142
logrus .Infof ("wireguard VPN network is %s" , network .StringJoinIPNets (vpnip , vpnipv6 ))
127
143
128
144
if err := network .ConfigureForwarding (conf .VPN .GatewayInterface , conf .VPN .CIDR , conf .VPN .CIDRv6 , conf .VPN .NAT44 , conf .VPN .NAT66 , conf .VPN .AllowedIPs ); err != nil {
129
- logrus .Fatal (err )
145
+ logrus .Error (err )
146
+ return
130
147
}
131
148
}
132
149
150
+ // Storage
151
+ storageBackend , err := storage .NewStorage (conf .Storage )
152
+ if err != nil {
153
+ logrus .Error (errors .Wrap (err , "failed to create storage backend" ))
154
+ return
155
+ }
156
+ if err := storageBackend .Open (); err != nil {
157
+ logrus .Error (errors .Wrap (err , "failed to connect/open storage backend" ))
158
+ return
159
+ }
160
+ defer storageBackend .Close ()
161
+
162
+ // Device manager
163
+ deviceManager := devices .New (wg , storageBackend , conf .VPN .CIDR , conf .VPN .CIDRv6 )
164
+
133
165
// DNS Server
134
166
if conf .DNS .Enabled {
135
- if conf .DNS .Upstream == nil {
167
+ if conf .DNS .Upstream == nil || len ( conf . DNS . Upstream ) <= 0 {
136
168
conf .DNS .Upstream = detectDNSUpstream (conf .VPN .CIDR != "" , conf .VPN .CIDRv6 != "" )
137
169
}
170
+ listenAddr := make ([]string , 0 , 2 )
171
+ for _ , addr := range vpnips {
172
+ listenAddr = append (listenAddr , net .JoinHostPort (addr .String (), "53" ))
173
+ }
138
174
dns , err := dnsproxy .New (dnsproxy.DNSServerOpts {
139
- Upstream : conf .DNS .Upstream ,
175
+ Upstream : conf .DNS .Upstream ,
176
+ Domain : conf .DNS .Domain ,
177
+ ListenAddr : listenAddr ,
140
178
})
141
179
if err != nil {
142
- logrus .Fatal (errors .Wrap (err , "failed to start dns server" ))
180
+ logrus .Error (errors .Wrap (err , "failed to start dns server" ))
181
+ return
143
182
}
144
183
defer dns .Close ()
184
+ if conf .DNS .Domain != "" {
185
+ // Generate initial DNS zone for registered devices
186
+ zone := generateZone (deviceManager , vpnips )
187
+ dns .PushAuthZone (zone )
188
+ // Update the zone in the background whenever a device changes
189
+ storageBackend .OnAdd (
190
+ func (_ * storage.Device ) {
191
+ zone := generateZone (deviceManager , vpnips )
192
+ dns .PushAuthZone (zone )
193
+ },
194
+ )
195
+ storageBackend .OnDelete (
196
+ func (_ * storage.Device ) {
197
+ zone := generateZone (deviceManager , vpnips )
198
+ dns .PushAuthZone (zone )
199
+ },
200
+ )
201
+ }
145
202
}
146
203
147
- // Storage
148
- storageBackend , err := storage .NewStorage (conf .Storage )
149
- if err != nil {
150
- logrus .Fatal (errors .Wrap (err , "failed to create storage backend" ))
151
- }
152
- if err := storageBackend .Open (); err != nil {
153
- logrus .Fatal (errors .Wrap (err , "failed to connect/open storage backend" ))
154
- }
155
- defer storageBackend .Close ()
156
-
157
204
// Services
158
- deviceManager := devices .New (wg , storageBackend , conf .VPN .CIDR , conf .VPN .CIDRv6 )
159
205
if err := deviceManager .StartSync (conf .DisableMetadata ); err != nil {
160
- logrus .Fatal (errors .Wrap (err , "failed to sync" ))
206
+ logrus .Error (errors .Wrap (err , "failed to sync" ))
207
+ return
161
208
}
162
209
163
210
router := mux .NewRouter ()
@@ -171,7 +218,8 @@ func (cmd *servecmd) Run() {
171
218
if conf .Auth .IsEnabled () {
172
219
middleware , err := authnz .NewMiddleware (conf .Auth , claimsMiddleware (conf ))
173
220
if err != nil {
174
- logrus .Fatal (errors .Wrap (err , "failed to set up authnz middleware" ))
221
+ logrus .Error (errors .Wrap (err , "failed to set up authnz middleware" ))
222
+ return
175
223
}
176
224
router .Use (middleware )
177
225
} else {
@@ -203,17 +251,38 @@ func (cmd *servecmd) Run() {
203
251
204
252
publicRouter := router
205
253
254
+ signalChan := make (chan os.Signal , 2 )
255
+ signal .Notify (signalChan , os .Interrupt , syscall .SIGTERM )
256
+ errChan := make (chan error )
257
+
206
258
// Listen
207
259
address := fmt .Sprintf (":%d" , conf .Port )
208
260
srv := & http.Server {
209
261
Addr : address ,
210
262
Handler : publicRouter ,
211
263
}
212
264
213
- // Start Web server
214
- logrus .Infof ("web ui listening on %v" , address )
215
- if err := srv .ListenAndServe (); err != nil {
216
- logrus .Fatal (errors .Wrap (err , "unable to start http server" ))
265
+ go func () {
266
+ // Start Web server
267
+ logrus .Infof ("web ui listening on %v" , address )
268
+ err := srv .ListenAndServe ()
269
+ if err != nil && ! errors .Is (err , http .ErrServerClosed ) {
270
+ errChan <- errors .Wrap (err , "unable to start http server" )
271
+ }
272
+ }()
273
+
274
+ select {
275
+ case <- signalChan :
276
+ ctx := context .Background ()
277
+ ctx , cancel := context .WithTimeout (ctx , 5 * time .Second )
278
+ err = srv .Shutdown (ctx )
279
+ if err != nil {
280
+ logrus .Error (err )
281
+ }
282
+ cancel () // always call cancel to clean up the context
283
+ case err = <- errChan :
284
+ logrus .Error (err )
285
+ return
217
286
}
218
287
}
219
288
@@ -318,6 +387,31 @@ func detectDefaultInterface() string {
318
387
return ""
319
388
}
320
389
390
+ func generateZone (deviceManager * devices.DeviceManager , vpnips []net.IP ) dnsproxy.Zone {
391
+ devs , err := deviceManager .ListAllDevices ()
392
+ if err != nil {
393
+ logrus .Error (errors .Wrap (err , "could not query devices to generate the DNS zone" ))
394
+ }
395
+
396
+ zone := make (dnsproxy.Zone )
397
+ for _ , device := range devs {
398
+ owner := device .Owner
399
+ name := device .Name
400
+ addressStrings := network .SplitAddresses (device .Address )
401
+ addresses := make ([]net.IP , 0 , 2 )
402
+ for _ , str := range addressStrings {
403
+ addr , _ , err := net .ParseCIDR (str )
404
+ if err != nil {
405
+ continue
406
+ }
407
+ addresses = append (addresses , addr )
408
+ }
409
+ zone [dnsproxy.ZoneKey {Owner : owner , Name : name }] = addresses
410
+ }
411
+ zone [dnsproxy.ZoneKey {}] = vpnips
412
+ return zone
413
+ }
414
+
321
415
var missingPrivateKey = `missing wireguard private key:
322
416
323
417
create a key:
0 commit comments