Skip to content

Commit a2d89d4

Browse files
authored
Merge pull request Place1#106 from DasSkelett/feature/generate-client-dns
2 parents 4fae08b + 52be27a commit a2d89d4

File tree

13 files changed

+567
-174
lines changed

13 files changed

+567
-174
lines changed

cmd/serve/main.go

Lines changed: 129 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,38 @@
11
package serve
22

33
import (
4+
"context"
45
"fmt"
56
"io/ioutil"
67
"net"
78
"net/http"
9+
"os"
10+
"os/signal"
811
"strings"
12+
"syscall"
13+
"time"
914

10-
"github.com/docker/libnetwork/resolvconf"
11-
"github.com/docker/libnetwork/types"
15+
"github.com/place1/wg-access-server/internal/config"
16+
"github.com/place1/wg-access-server/internal/devices"
17+
"github.com/place1/wg-access-server/internal/dnsproxy"
18+
"github.com/place1/wg-access-server/internal/network"
1219
"github.com/place1/wg-access-server/internal/services"
1320
"github.com/place1/wg-access-server/internal/storage"
1421
"github.com/place1/wg-access-server/pkg/authnz"
1522
"github.com/place1/wg-access-server/pkg/authnz/authconfig"
1623
"github.com/place1/wg-access-server/pkg/authnz/authsession"
24+
25+
"github.com/docker/libnetwork/resolvconf"
26+
"github.com/docker/libnetwork/types"
27+
"github.com/gorilla/mux"
28+
"github.com/pkg/errors"
29+
"github.com/place1/wg-embed/pkg/wgembed"
30+
"github.com/sirupsen/logrus"
1731
"github.com/vishvananda/netlink"
1832
"golang.org/x/crypto/bcrypt"
1933
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
2034
"gopkg.in/alecthomas/kingpin.v2"
2135
"gopkg.in/yaml.v2"
22-
23-
"github.com/gorilla/mux"
24-
"github.com/place1/wg-embed/pkg/wgembed"
25-
26-
"github.com/pkg/errors"
27-
"github.com/place1/wg-access-server/internal/config"
28-
"github.com/place1/wg-access-server/internal/devices"
29-
"github.com/place1/wg-access-server/internal/dnsproxy"
30-
"github.com/place1/wg-access-server/internal/network"
31-
"github.com/sirupsen/logrus"
3236
)
3337

3438
func Register(app *kingpin.Application) *servecmd {
@@ -52,7 +56,8 @@ func Register(app *kingpin.Application) *servecmd {
5256
cli.Flag("vpn-gateway-interface", "The gateway network interface (i.e. eth0)").Envar("WG_VPN_GATEWAY_INTERFACE").Default(detectDefaultInterface()).StringVar(&cmd.AppConfig.VPN.GatewayInterface)
5357
cli.Flag("vpn-allowed-ips", "A list of networks that VPN clients will be allowed to connect to via the VPN").Envar("WG_VPN_ALLOWED_IPS").Default("0.0.0.0/0", "::/0").StringsVar(&cmd.AppConfig.VPN.AllowedIPs)
5458
cli.Flag("dns-enabled", "Enable or disable the embedded dns proxy server (useful for development)").Envar("WG_DNS_ENABLED").Default("true").BoolVar(&cmd.AppConfig.DNS.Enabled)
55-
cli.Flag("dns-upstream", "An upstream DNS server to proxy DNS traffic to. Defaults to resolveconf with Cloudflare DNS as fallback").Envar("WG_DNS_UPSTREAM").StringsVar(&cmd.AppConfig.DNS.Upstream)
59+
cli.Flag("dns-upstream", "An upstream DNS server to proxy DNS traffic to. Defaults to resolvconf with Cloudflare DNS as fallback").Envar("WG_DNS_UPSTREAM").StringsVar(&cmd.AppConfig.DNS.Upstream)
60+
cli.Flag("dns-domain", "A domain to serve configured device names authoritatively").Envar("WG_DNS_DOMAIN").StringVar(&cmd.AppConfig.DNS.Domain)
5661
return cmd
5762
}
5863

@@ -74,6 +79,9 @@ func (cmd *servecmd) Run() {
7479
if conf.VPN.CIDRv6 == "0" {
7580
conf.VPN.CIDRv6 = ""
7681
}
82+
if conf.DNS.Domain == "0" {
83+
conf.DNS.Domain = ""
84+
}
7785

7886
// Get the server's IP addresses within the VPN
7987
var vpnip, vpnipv6 *net.IPNet
@@ -98,6 +106,13 @@ func (cmd *servecmd) Run() {
98106
conf.VPN.AllowedIPs = append(conf.VPN.AllowedIPs, fmt.Sprintf("%s/128", vpnipv6.IP.String()))
99107
vpnipstrings = append(vpnipstrings, vpnipv6.String())
100108
}
109+
vpnips := make([]net.IP, 0, 2)
110+
if vpnip != nil {
111+
vpnips = append(vpnips, vpnip.IP)
112+
}
113+
if vpnipv6 != nil {
114+
vpnips = append(vpnips, vpnipv6.IP)
115+
}
101116

102117
// WireGuard Server
103118
wg := wgembed.NewNoOpInterface()
@@ -120,44 +135,76 @@ func (cmd *servecmd) Run() {
120135
}
121136

122137
if err := wg.LoadConfig(wgconfig); err != nil {
123-
logrus.Fatal(errors.Wrap(err, "failed to load wireguard config"))
138+
logrus.Error(errors.Wrap(err, "failed to load wireguard config"))
139+
return
124140
}
125141

126142
logrus.Infof("wireguard VPN network is %s", network.StringJoinIPNets(vpnip, vpnipv6))
127143

128144
if err := network.ConfigureForwarding(conf.VPN.GatewayInterface, conf.VPN.CIDR, conf.VPN.CIDRv6, conf.VPN.NAT44, conf.VPN.NAT66, conf.VPN.AllowedIPs); err != nil {
129-
logrus.Fatal(err)
145+
logrus.Error(err)
146+
return
130147
}
131148
}
132149

150+
// Storage
151+
storageBackend, err := storage.NewStorage(conf.Storage)
152+
if err != nil {
153+
logrus.Error(errors.Wrap(err, "failed to create storage backend"))
154+
return
155+
}
156+
if err := storageBackend.Open(); err != nil {
157+
logrus.Error(errors.Wrap(err, "failed to connect/open storage backend"))
158+
return
159+
}
160+
defer storageBackend.Close()
161+
162+
// Device manager
163+
deviceManager := devices.New(wg, storageBackend, conf.VPN.CIDR, conf.VPN.CIDRv6)
164+
133165
// DNS Server
134166
if conf.DNS.Enabled {
135-
if conf.DNS.Upstream == nil {
167+
if conf.DNS.Upstream == nil || len(conf.DNS.Upstream) <= 0 {
136168
conf.DNS.Upstream = detectDNSUpstream(conf.VPN.CIDR != "", conf.VPN.CIDRv6 != "")
137169
}
170+
listenAddr := make([]string, 0, 2)
171+
for _, addr := range vpnips {
172+
listenAddr = append(listenAddr, net.JoinHostPort(addr.String(), "53"))
173+
}
138174
dns, err := dnsproxy.New(dnsproxy.DNSServerOpts{
139-
Upstream: conf.DNS.Upstream,
175+
Upstream: conf.DNS.Upstream,
176+
Domain: conf.DNS.Domain,
177+
ListenAddr: listenAddr,
140178
})
141179
if err != nil {
142-
logrus.Fatal(errors.Wrap(err, "failed to start dns server"))
180+
logrus.Error(errors.Wrap(err, "failed to start dns server"))
181+
return
143182
}
144183
defer dns.Close()
184+
if conf.DNS.Domain != "" {
185+
// Generate initial DNS zone for registered devices
186+
zone := generateZone(deviceManager, vpnips)
187+
dns.PushAuthZone(zone)
188+
// Update the zone in the background whenever a device changes
189+
storageBackend.OnAdd(
190+
func(_ *storage.Device) {
191+
zone := generateZone(deviceManager, vpnips)
192+
dns.PushAuthZone(zone)
193+
},
194+
)
195+
storageBackend.OnDelete(
196+
func(_ *storage.Device) {
197+
zone := generateZone(deviceManager, vpnips)
198+
dns.PushAuthZone(zone)
199+
},
200+
)
201+
}
145202
}
146203

147-
// Storage
148-
storageBackend, err := storage.NewStorage(conf.Storage)
149-
if err != nil {
150-
logrus.Fatal(errors.Wrap(err, "failed to create storage backend"))
151-
}
152-
if err := storageBackend.Open(); err != nil {
153-
logrus.Fatal(errors.Wrap(err, "failed to connect/open storage backend"))
154-
}
155-
defer storageBackend.Close()
156-
157204
// Services
158-
deviceManager := devices.New(wg, storageBackend, conf.VPN.CIDR, conf.VPN.CIDRv6)
159205
if err := deviceManager.StartSync(conf.DisableMetadata); err != nil {
160-
logrus.Fatal(errors.Wrap(err, "failed to sync"))
206+
logrus.Error(errors.Wrap(err, "failed to sync"))
207+
return
161208
}
162209

163210
router := mux.NewRouter()
@@ -171,7 +218,8 @@ func (cmd *servecmd) Run() {
171218
if conf.Auth.IsEnabled() {
172219
middleware, err := authnz.NewMiddleware(conf.Auth, claimsMiddleware(conf))
173220
if err != nil {
174-
logrus.Fatal(errors.Wrap(err, "failed to set up authnz middleware"))
221+
logrus.Error(errors.Wrap(err, "failed to set up authnz middleware"))
222+
return
175223
}
176224
router.Use(middleware)
177225
} else {
@@ -203,17 +251,38 @@ func (cmd *servecmd) Run() {
203251

204252
publicRouter := router
205253

254+
signalChan := make(chan os.Signal, 2)
255+
signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
256+
errChan := make(chan error)
257+
206258
// Listen
207259
address := fmt.Sprintf(":%d", conf.Port)
208260
srv := &http.Server{
209261
Addr: address,
210262
Handler: publicRouter,
211263
}
212264

213-
// Start Web server
214-
logrus.Infof("web ui listening on %v", address)
215-
if err := srv.ListenAndServe(); err != nil {
216-
logrus.Fatal(errors.Wrap(err, "unable to start http server"))
265+
go func() {
266+
// Start Web server
267+
logrus.Infof("web ui listening on %v", address)
268+
err := srv.ListenAndServe()
269+
if err != nil && !errors.Is(err, http.ErrServerClosed) {
270+
errChan <- errors.Wrap(err, "unable to start http server")
271+
}
272+
}()
273+
274+
select {
275+
case <-signalChan:
276+
ctx := context.Background()
277+
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
278+
err = srv.Shutdown(ctx)
279+
if err != nil {
280+
logrus.Error(err)
281+
}
282+
cancel() // always call cancel to clean up the context
283+
case err = <-errChan:
284+
logrus.Error(err)
285+
return
217286
}
218287
}
219288

@@ -318,6 +387,31 @@ func detectDefaultInterface() string {
318387
return ""
319388
}
320389

390+
func generateZone(deviceManager *devices.DeviceManager, vpnips []net.IP) dnsproxy.Zone {
391+
devs, err := deviceManager.ListAllDevices()
392+
if err != nil {
393+
logrus.Error(errors.Wrap(err, "could not query devices to generate the DNS zone"))
394+
}
395+
396+
zone := make(dnsproxy.Zone)
397+
for _, device := range devs {
398+
owner := device.Owner
399+
name := device.Name
400+
addressStrings := network.SplitAddresses(device.Address)
401+
addresses := make([]net.IP, 0, 2)
402+
for _, str := range addressStrings {
403+
addr, _, err := net.ParseCIDR(str)
404+
if err != nil {
405+
continue
406+
}
407+
addresses = append(addresses, addr)
408+
}
409+
zone[dnsproxy.ZoneKey{Owner: owner, Name: name}] = addresses
410+
}
411+
zone[dnsproxy.ZoneKey{}] = vpnips
412+
return zone
413+
}
414+
321415
var missingPrivateKey = `missing wireguard private key:
322416
323417
create a key:

0 commit comments

Comments
 (0)