diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 1479a55..db55d7c 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -39,7 +39,7 @@ jobs: run: | TAG=${{ env.TAG }} if [ -f main.go ]; then - sed -i 's/Newt version replaceme/Newt version '"$TAG"'/' main.go + sed -i 's/version_replaceme/'"$TAG"'/' main.go echo "Updated main.go with version $TAG" else echo "main.go not found" diff --git a/.gitignore b/.gitignore index ba74660..d14efa9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ newt .DS_Store bin/ +nohup.out .idea *.iml -certs/ \ No newline at end of file +certs/ +newt_arm64 diff --git a/README.md b/README.md index 590a5de..9d88096 100644 --- a/README.md +++ b/README.md @@ -33,11 +33,13 @@ When Newt receives WireGuard control messages, it will use the information encod - `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket. - `id`: Newt ID generated by Pangolin to identify the client. - `secret`: A unique secret (not shared and kept private) used to authenticate the client ID with the websocket in order to receive commands. +- `mtu`: MTU for the internal WG interface. Default: 1280 - `dns`: DNS server to use to resolve the endpoint - `log-level` (optional): The log level to use. Default: INFO - `updown` (optional): A script to be called when targets are added or removed. - `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls) - `docker-socket` (optional): Set the Docker socket to use the container discovery integration +- `docker-enforce-network-validation` (optional): Validate the container target is on the same network as the newt process - Example: @@ -99,6 +101,26 @@ services: - DOCKER_SOCKET=/var/run/docker.sock ``` +#### Hostnames vs IPs + +When the Docker Socket Integration is used, depending on the network which Newt is run with, either the hostname (generally considered the container name) or the IP address of the container will be sent to Pangolin. Here are some of the scenarios where IPs or hostname of the container will be utilised: +- **Running in Network Mode 'host'**: IP addresses will be used +- **Running in Network Mode 'bridge'**: IP addresses will be used +- **Running in docker-compose without a network specification**: Docker compose creates a network for the compose by default, hostnames will be used +- **Running on docker-compose with defined network**: Hostnames will be used + +### Docker Enforce Network Validation + +When run as a Docker container, Newt can validate that the target being provided is on the same network as the Newt container and only return containers directly accessible by Newt. Validation will be carried out against either the hostname/IP Address and the Port number to ensure the running container is exposing the ports to Newt. + +It is important to note that if the Newt container is run with a network mode of `host` that this feature will not work. Running in `host` mode causes the container to share its resources with the host machine, therefore making it so the specific host container information for Newt cannot be retrieved to be able to carry out network validation. + +**Configuration:** + +Validation is `false` by default. It can be enabled via setting the `--docker-enforce-network-validation` CLI argument or by setting the `DOCKER_ENFORCE_NETWORK_VALIDATION` environment variable. + +If validation is enforced and the Docker socket is available, Newt will **not** add the target as it cannot be verified. A warning will be presented in the Newt logs. + ### Updown You can pass in a updown script for Newt to call when it is adding or removing a target: diff --git a/docker/client.go b/docker/client.go index 98936fe..6a3bdb7 100644 --- a/docker/client.go +++ b/docker/client.go @@ -4,10 +4,13 @@ import ( "context" "fmt" "net" + "os" + "strconv" "strings" "time" "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/filters" "github.com/docker/docker/client" "github.com/fosrl/newt/logger" ) @@ -67,13 +70,60 @@ func CheckSocket(socketPath string) bool { return true } +// IsWithinHostNetwork checks if a provided target is within the host container network +func IsWithinHostNetwork(socketPath string, targetAddress string, targetPort int) (bool, error) { + // Always enforce network validation + containers, err := ListContainers(socketPath, true) + if err != nil { + return false, err + } + + // Determine if given an IP address + var parsedTargetAddressIp = net.ParseIP(targetAddress) + + // If we can find the passed hostname/IP address in the networks or as the container name, it is valid and can add it + for _, c := range containers { + for _, network := range c.Networks { + // If the target address is not an IP address, use the container name + if parsedTargetAddressIp == nil { + if c.Name == targetAddress { + for _, port := range c.Ports { + if port.PublicPort == targetPort || port.PrivatePort == targetPort { + return true, nil + } + } + } + } else { + //If the IP address matches, check the ports being mapped too + if network.IPAddress == targetAddress { + for _, port := range c.Ports { + if port.PublicPort == targetPort || port.PrivatePort == targetPort { + return true, nil + } + } + } + } + } + } + + combinedTargetAddress := targetAddress + ":" + strconv.Itoa(targetPort) + return false, fmt.Errorf("target address not within host container network: %s", combinedTargetAddress) +} + // ListContainers lists all Docker containers with their network information -func ListContainers(socketPath string) ([]Container, error) { +func ListContainers(socketPath string, enforceNetworkValidation bool) ([]Container, error) { // Use the provided socket path or default to standard location if socketPath == "" { socketPath = "/var/run/docker.sock" } + // Used to filter down containers returned to Pangolin + containerFilters := filters.NewArgs() + + // Used to determine if we will send IP addresses or hostnames to Pangolin + useContainerIpAddresses := true + hostContainerId := "" + // Create a new Docker client ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -86,16 +136,54 @@ func ListContainers(socketPath string) ([]Container, error) { if err != nil { return nil, fmt.Errorf("failed to create Docker client: %v", err) } + defer cli.Close() + hostContainer, err := getHostContainer(ctx, cli) + if enforceNetworkValidation && err != nil { + return nil, fmt.Errorf("network validation enforced, cannot validate due to: %w", err) + } + + // We may not be able to get back host container in scenarios like running the container in network mode 'host' + if hostContainer != nil { + // We can use the host container to filter out the list of returned containers + hostContainerId = hostContainer.ID + + for hostContainerNetworkName := range hostContainer.NetworkSettings.Networks { + // If we're enforcing network validation, we'll filter on the host containers networks + if enforceNetworkValidation { + containerFilters.Add("network", hostContainerNetworkName) + } + + // If the container is on the docker bridge network, we will use IP addresses over hostnames + if useContainerIpAddresses && hostContainerNetworkName != "bridge" { + useContainerIpAddresses = false + } + } + } + // List containers - containers, err := cli.ContainerList(ctx, container.ListOptions{All: true}) + containers, err := cli.ContainerList(ctx, container.ListOptions{All: true, Filters: containerFilters}) if err != nil { return nil, fmt.Errorf("failed to list containers: %v", err) } var dockerContainers []Container for _, c := range containers { + // Short ID like docker ps + shortId := c.ID[:12] + + // Skip host container if set + if hostContainerId != "" && c.ID == hostContainerId { + continue + } + + // Get container name (remove leading slash) + name := "" + if len(c.Names) > 0 { + name = strings.TrimPrefix(c.Names[0], "/") + } + // Convert ports var ports []Port for _, port := range c.Ports { @@ -112,44 +200,36 @@ func ListContainers(socketPath string) ([]Container, error) { ports = append(ports, dockerPort) } - // Get container name (remove leading slash) - name := "" - if len(c.Names) > 0 { - name = strings.TrimPrefix(c.Names[0], "/") - } - // Get network information by inspecting the container networks := make(map[string]Network) - // Inspect container to get detailed network information - containerInfo, err := cli.ContainerInspect(ctx, c.ID) - if err != nil { - logger.Debug("Failed to inspect container %s for network info: %v", c.ID[:12], err) - // Continue without network info if inspection fails - } else { - // Extract network information from inspection - if containerInfo.NetworkSettings != nil && containerInfo.NetworkSettings.Networks != nil { - for networkName, endpoint := range containerInfo.NetworkSettings.Networks { - dockerNetwork := Network{ - NetworkID: endpoint.NetworkID, - EndpointID: endpoint.EndpointID, - Gateway: endpoint.Gateway, - IPAddress: endpoint.IPAddress, - IPPrefixLen: endpoint.IPPrefixLen, - IPv6Gateway: endpoint.IPv6Gateway, - GlobalIPv6Address: endpoint.GlobalIPv6Address, - GlobalIPv6PrefixLen: endpoint.GlobalIPv6PrefixLen, - MacAddress: endpoint.MacAddress, - Aliases: endpoint.Aliases, - DNSNames: endpoint.DNSNames, - } - networks[networkName] = dockerNetwork + // Extract network information from inspection + if c.NetworkSettings != nil && c.NetworkSettings.Networks != nil { + for networkName, endpoint := range c.NetworkSettings.Networks { + dockerNetwork := Network{ + NetworkID: endpoint.NetworkID, + EndpointID: endpoint.EndpointID, + Gateway: endpoint.Gateway, + IPPrefixLen: endpoint.IPPrefixLen, + IPv6Gateway: endpoint.IPv6Gateway, + GlobalIPv6Address: endpoint.GlobalIPv6Address, + GlobalIPv6PrefixLen: endpoint.GlobalIPv6PrefixLen, + MacAddress: endpoint.MacAddress, + Aliases: endpoint.Aliases, + DNSNames: endpoint.DNSNames, + } + + // Use IPs over hostnames/containers as we're on the bridge network + if useContainerIpAddresses { + dockerNetwork.IPAddress = endpoint.IPAddress } + + networks[networkName] = dockerNetwork } } dockerContainer := Container{ - ID: c.ID[:12], // Show short ID like docker ps + ID: shortId, Name: name, Image: c.Image, State: c.State, @@ -159,8 +239,26 @@ func ListContainers(socketPath string) ([]Container, error) { Created: c.Created, Networks: networks, } + dockerContainers = append(dockerContainers, dockerContainer) } return dockerContainers, nil } + +// getHostContainer gets the current container for the current host if possible +func getHostContainer(dockerContext context.Context, dockerClient *client.Client) (*container.InspectResponse, error) { + // Get hostname from the os + hostContainerName, err := os.Hostname() + if err != nil { + return nil, fmt.Errorf("failed to find hostname for container") + } + + // Get host container from the docker socket + hostContainer, err := dockerClient.ContainerInspect(dockerContext, hostContainerName) + if err != nil { + return nil, fmt.Errorf("failed to find host container") + } + + return &hostContainer, nil +} diff --git a/go.mod b/go.mod index b6c3839..7b469bb 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,15 @@ toolchain go1.23.2 require ( github.com/docker/docker v28.3.1+incompatible + github.com/google/gopacket v1.1.19 github.com/gorilla/websocket v1.5.3 + github.com/vishvananda/netlink v1.3.0 + golang.org/x/crypto v0.39.0 + golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa golang.org/x/net v0.41.0 - golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 - golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 - gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 + golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 + gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c software.sslmate.com/src/go-pkcs12 v0.5.0 ) @@ -18,7 +22,6 @@ require ( github.com/Microsoft/go-winio v0.6.0 // indirect github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect - github.com/containerd/log v0.1.0 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.4.0 // indirect @@ -27,6 +30,11 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/btree v1.1.2 // indirect + github.com/google/go-cmp v0.7.0 // indirect + github.com/josharian/native v1.1.0 // indirect + github.com/mdlayher/genetlink v1.3.2 // indirect + github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/socket v0.5.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/sys/atomicwriter v0.1.0 // indirect github.com/moby/term v0.5.2 // indirect @@ -34,16 +42,17 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/vishvananda/netns v0.0.4 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect go.opentelemetry.io/otel v1.36.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.36.0 // indirect go.opentelemetry.io/otel/metric v1.36.0 // indirect go.opentelemetry.io/otel/trace v1.36.0 // indirect - golang.org/x/crypto v0.39.0 // indirect - golang.org/x/mod v0.12.0 // indirect + golang.org/x/mod v0.23.0 // indirect + golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/time v0.7.0 // indirect - golang.org/x/tools v0.13.0 // indirect + golang.org/x/tools v0.30.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect ) diff --git a/go.sum b/go.sum index e074360..1d6d20d 100644 --- a/go.sum +++ b/go.sum @@ -34,14 +34,26 @@ github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5ukBEgSGXEN89zeH1Jo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= +github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= +github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= +github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw= @@ -64,6 +76,10 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= @@ -91,10 +107,14 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4= +golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= +golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -104,11 +124,13 @@ golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -119,23 +141,26 @@ golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= -golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= -golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= -golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= -golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= -google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f h1:BWUVssLB0HVOSY78gIdvk1dTVYtT1y8SBWtPYuTJ/6w= +golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A= +golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= +google.golang.org/genproto v0.0.0-20230920204549-e6e6cdab5c13 h1:vlzZttNJGVqTsRFU9AmdnrcO1Znh8Ew9kCD//yjigk0= google.golang.org/genproto/googleapis/api v0.0.0-20250519155744-55703ea1f237 h1:Kog3KlB4xevJlAcbbbzPfRG0+X9fdoGM+UBRKVz6Wr0= +google.golang.org/genproto/googleapis/api v0.0.0-20250519155744-55703ea1f237/go.mod h1:ezi0AVyMKDWy5xAncvjLWH7UcLBB5n7y2fQ8MzjJcto= google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 h1:cJfm9zPbe1e873mHJzmQ1nwVEeRDU/T1wXDK2kUSU34= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.72.1 h1:HR03wO6eyZ7lknl75XlxABNVLLFc2PAb6mHlYh756mA= google.golang.org/grpc v1.72.1/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= @@ -144,7 +169,7 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= -gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= -gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= +gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI= +gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= software.sslmate.com/src/go-pkcs12 v0.5.0 h1:EC6R394xgENTpZ4RltKydeDUjtlM5drOYIG9c6TVj2M= software.sslmate.com/src/go-pkcs12 v0.5.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= diff --git a/linux.go b/linux.go new file mode 100644 index 0000000..790f634 --- /dev/null +++ b/linux.go @@ -0,0 +1,83 @@ +//go:build linux + +package main + +import ( + "fmt" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/proxy" + "github.com/fosrl/newt/websocket" + "github.com/fosrl/newt/wg" + "github.com/fosrl/newt/wgtester" +) + +var wgService *wg.WireGuardService +var wgTesterServer *wgtester.Server + +func setupClients(client *websocket.Client) { + var host = endpoint + if strings.HasPrefix(host, "http://") { + host = strings.TrimPrefix(host, "http://") + } else if strings.HasPrefix(host, "https://") { + host = strings.TrimPrefix(host, "https://") + } + + host = strings.TrimSuffix(host, "/") + + // Create WireGuard service + wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client) + if err != nil { + logger.Fatal("Failed to create WireGuard service: %v", err) + } + defer wgService.Close(rm) + + wgTesterServer = wgtester.NewServer("0.0.0.0", wgService.Port, id) // TODO: maybe make this the same ip of the wg server? + err := wgTesterServer.Start() + if err != nil { + logger.Error("Failed to start WireGuard tester server: %v", err) + } else { + // Make sure to stop the server on exit + defer wgTesterServer.Stop() + } + + client.OnTokenUpdate(func(token string) { + wgService.SetToken(token) + }) +} + +func closeClients() { + if wgService != nil { + wgService.Close(rm) + wgService = nil + } + + if wgTesterServer != nil { + wgTesterServer.Stop() + wgTesterServer = nil + } +} + +func clientsHandleNewtConnection(publicKey string) { + if wgService == nil { + return + } + wgService.SetServerPubKey(publicKey) +} + +func clientsOnConnect() { + if wgService == nil { + return + } + wgService.LoadRemoteConfig() +} + +func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) { + if wgService == nil { + return + } + // add a udp proxy for localost and the wgService port + // TODO: make sure this port is not used in a target + pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port)) +} diff --git a/main.go b/main.go index fdece97..8af6403 100644 --- a/main.go +++ b/main.go @@ -1,18 +1,15 @@ package main import ( - "bytes" - "encoding/base64" - "encoding/hex" "encoding/json" "flag" "fmt" - "math/rand" "net" + "net/http" "net/netip" "os" - "os/exec" "os/signal" + "runtime" "strconv" "strings" "syscall" @@ -21,10 +18,9 @@ import ( "github.com/fosrl/newt/docker" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/proxy" + "github.com/fosrl/newt/updates" "github.com/fosrl/newt/websocket" - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" @@ -49,310 +45,54 @@ type TargetData struct { Targets []string `json:"targets"` } -func fixKey(key string) string { - // Remove any whitespace - key = strings.TrimSpace(key) - - // Decode from base64 - decoded, err := base64.StdEncoding.DecodeString(key) - if err != nil { - logger.Fatal("Error decoding base64: %v", err) - } - - // Convert to hex - return hex.EncodeToString(decoded) -} - -func ping(tnet *netstack.Net, dst string) error { - logger.Info("Pinging %s", dst) - socket, err := tnet.Dial("ping4", dst) - if err != nil { - return fmt.Errorf("failed to create ICMP socket: %w", err) - } - defer socket.Close() - - requestPing := icmp.Echo{ - Seq: rand.Intn(1 << 16), - Data: []byte("gopher burrow"), - } - - icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) - if err != nil { - return fmt.Errorf("failed to marshal ICMP message: %w", err) - } - - if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil { - return fmt.Errorf("failed to set read deadline: %w", err) - } - - start := time.Now() - _, err = socket.Write(icmpBytes) - if err != nil { - return fmt.Errorf("failed to write ICMP packet: %w", err) - } - - n, err := socket.Read(icmpBytes[:]) - if err != nil { - return fmt.Errorf("failed to read ICMP packet: %w", err) - } - - replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) - if err != nil { - return fmt.Errorf("failed to parse ICMP packet: %w", err) - } - - replyPing, ok := replyPacket.Body.(*icmp.Echo) - if !ok { - return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body) - } - - if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { - return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q", - replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data) - } - - logger.Info("Ping latency: %v", time.Since(start)) - return nil -} - -func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) { - initialInterval := 10 * time.Second - maxInterval := 60 * time.Second - currentInterval := initialInterval - consecutiveFailures := 0 - - ticker := time.NewTicker(currentInterval) - defer ticker.Stop() - - go func() { - for { - select { - case <-ticker.C: - err := ping(tnet, serverIP) - if err != nil { - consecutiveFailures++ - logger.Warn("Periodic ping failed (%d consecutive failures): %v", - consecutiveFailures, err) - logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?") - - // Increase interval if we have consistent failures, with a maximum cap - if consecutiveFailures >= 3 && currentInterval < maxInterval { - // Increase by 50% each time, up to the maximum - currentInterval = time.Duration(float64(currentInterval) * 1.5) - if currentInterval > maxInterval { - currentInterval = maxInterval - } - ticker.Reset(currentInterval) - logger.Info("Increased ping check interval to %v due to consecutive failures", - currentInterval) - } - } else { - // On success, if we've backed off, gradually return to normal interval - if currentInterval > initialInterval { - currentInterval = time.Duration(float64(currentInterval) * 0.8) - if currentInterval < initialInterval { - currentInterval = initialInterval - } - ticker.Reset(currentInterval) - logger.Info("Decreased ping check interval to %v after successful ping", - currentInterval) - } - consecutiveFailures = 0 - } - case <-stopChan: - logger.Info("Stopping ping check") - return - } - } - }() +type ExitNodeData struct { + ExitNodes []ExitNode `json:"exitNodes"` } -// Function to track connection status and trigger reconnection as needed -func monitorConnectionStatus(tnet *netstack.Net, serverIP string, client *websocket.Client) { - const checkInterval = 30 * time.Second - connectionLost := false - ticker := time.NewTicker(checkInterval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - // Try a ping to see if connection is alive - err := ping(tnet, serverIP) - - if err != nil && !connectionLost { - // We just lost connection - connectionLost = true - logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.") - - // Notify the user they might need to check their network - logger.Warn("Please check your internet connection and ensure the Pangolin server is online.") - logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.") - } else if err == nil && connectionLost { - // Connection has been restored - connectionLost = false - logger.Info("Connection to server restored!") - - // Tell the server we're back - err := client.SendMessage("newt/wg/register", map[string]interface{}{ - "publicKey": privateKey.PublicKey().String(), - }) - - if err != nil { - logger.Error("Failed to send registration message after reconnection: %v", err) - } else { - logger.Info("Successfully re-registered with server after reconnection") - } - } - } - } +// ExitNode represents an exit node with an ID, endpoint, and weight. +type ExitNode struct { + ID int `json:"exitNodeId"` + Name string `json:"exitNodeName"` + Endpoint string `json:"endpoint"` + Weight float64 `json:"weight"` + WasPreviouslyConnected bool `json:"wasPreviouslyConnected"` } -func pingWithRetry(tnet *netstack.Net, dst string) error { - const ( - initialMaxAttempts = 15 - initialRetryDelay = 2 * time.Second - maxRetryDelay = 60 * time.Second // Cap the maximum delay - ) - - attempt := 1 - retryDelay := initialRetryDelay - - // First try with the initial parameters - logger.Info("Ping attempt %d", attempt) - if err := ping(tnet, dst); err == nil { - // Successful ping - return nil - } else { - logger.Warn("Ping attempt %d failed: %v", attempt, err) - } - - // Start a goroutine that will attempt pings indefinitely with increasing delays - go func() { - attempt = 2 // Continue from attempt 2 - - for { - logger.Info("Ping attempt %d", attempt) - - if err := ping(tnet, dst); err != nil { - logger.Warn("Ping attempt %d failed: %v", attempt, err) - - // Increase delay after certain thresholds but cap it - if attempt%5 == 0 && retryDelay < maxRetryDelay { - retryDelay = time.Duration(float64(retryDelay) * 1.5) - if retryDelay > maxRetryDelay { - retryDelay = maxRetryDelay - } - logger.Info("Increasing ping retry delay to %v", retryDelay) - } - - time.Sleep(retryDelay) - attempt++ - } else { - // Successful ping - logger.Info("Ping succeeded after %d attempts", attempt) - return - } - } - }() - - // Return an error for the first batch of attempts (to maintain compatibility with existing code) - return fmt.Errorf("initial ping attempts failed, continuing in background") -} - -func parseLogLevel(level string) logger.LogLevel { - switch strings.ToUpper(level) { - case "DEBUG": - return logger.DEBUG - case "INFO": - return logger.INFO - case "WARN": - return logger.WARN - case "ERROR": - return logger.ERROR - case "FATAL": - return logger.FATAL - default: - return logger.INFO // default to INFO if invalid level provided - } -} - -func mapToWireGuardLogLevel(level logger.LogLevel) int { - switch level { - case logger.DEBUG: - return device.LogLevelVerbose - // case logger.INFO: - // return device.LogLevel - case logger.WARN: - return device.LogLevelError - case logger.ERROR, logger.FATAL: - return device.LogLevelSilent - default: - return device.LogLevelSilent - } -} - -func resolveDomain(domain string) (string, error) { - // Check if there's a port in the domain - host, port, err := net.SplitHostPort(domain) - if err != nil { - // No port found, use the domain as is - host = domain - port = "" - } - - // Remove any protocol prefix if present - if strings.HasPrefix(host, "http://") { - host = strings.TrimPrefix(host, "http://") - } else if strings.HasPrefix(host, "https://") { - host = strings.TrimPrefix(host, "https://") - } - - // Lookup IP addresses - ips, err := net.LookupIP(host) - if err != nil { - return "", fmt.Errorf("DNS lookup failed: %v", err) - } - - if len(ips) == 0 { - return "", fmt.Errorf("no IP addresses found for domain %s", host) - } - - // Get the first IPv4 address if available - var ipAddr string - for _, ip := range ips { - if ipv4 := ip.To4(); ipv4 != nil { - ipAddr = ipv4.String() - break - } - } - - // If no IPv4 found, use the first IP (might be IPv6) - if ipAddr == "" { - ipAddr = ips[0].String() - } - - // Add port back if it existed - if port != "" { - ipAddr = net.JoinHostPort(ipAddr, port) - } - - return ipAddr, nil +type ExitNodePingResult struct { + ExitNodeID int `json:"exitNodeId"` + LatencyMs int64 `json:"latencyMs"` + Weight float64 `json:"weight"` + Error string `json:"error,omitempty"` + Name string `json:"exitNodeName"` + Endpoint string `json:"endpoint"` + WasPreviouslyConnected bool `json:"wasPreviouslyConnected"` } var ( - endpoint string - id string - secret string - mtu string - mtuInt int - dns string - privateKey wgtypes.Key - err error - logLevel string - updownScript string - tlsPrivateKey string - dockerSocket string + endpoint string + id string + secret string + mtu string + mtuInt int + dns string + privateKey wgtypes.Key + err error + logLevel string + interfaceName string + generateAndSaveKeyTo string + rm bool + acceptClients bool + updownScript string + tlsPrivateKey string + dockerSocket string + dockerEnforceNetworkValidation string + dockerEnforceNetworkValidationBool bool + pingInterval = 2 * time.Second + pingTimeout = 3 * time.Second + publicKey wgtypes.Key + pingStopChan chan struct{} + stopFunc func() + healthFile string ) func main() { @@ -364,8 +104,16 @@ func main() { dns = os.Getenv("DNS") logLevel = os.Getenv("LOG_LEVEL") updownScript = os.Getenv("UPDOWN_SCRIPT") + // interfaceName = os.Getenv("INTERFACE") + // generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") + // rm = os.Getenv("RM") == "true" + // acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true" tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") dockerSocket = os.Getenv("DOCKER_SOCKET") + pingIntervalStr := os.Getenv("PING_INTERVAL") + pingTimeoutStr := os.Getenv("PING_TIMEOUT") + dockerEnforceNetworkValidation = os.Getenv("DOCKER_ENFORCE_NETWORK_VALIDATION") + healthFile = os.Getenv("HEALTH_FILE") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -388,29 +136,70 @@ func main() { if updownScript == "" { flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed") } + // if interfaceName == "" { + // flag.StringVar(&interfaceName, "interface", "wg1", "Name of the WireGuard interface") + // } + // if generateAndSaveKeyTo == "" { + // flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "/tmp/newtkey", "Path to save generated private key") + // } + // flag.BoolVar(&rm, "rm", false, "Remove the WireGuard interface") + // flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface") if tlsPrivateKey == "" { flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS") } if dockerSocket == "" { flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)") } + if pingIntervalStr == "" { + flag.StringVar(&pingIntervalStr, "ping-interval", "1s", "Interval for pinging the server (default 1s)") + } + if pingTimeoutStr == "" { + flag.StringVar(&pingTimeoutStr, "ping-timeout", "2s", " Timeout for each ping (default 2s)") + } + + if pingIntervalStr != "" { + pingInterval, err = time.ParseDuration(pingIntervalStr) + if err != nil { + fmt.Printf("Invalid PING_INTERVAL value: %s, using default 1 second\n", pingIntervalStr) + pingInterval = 1 * time.Second + } + } + + if pingTimeoutStr != "" { + pingTimeout, err = time.ParseDuration(pingTimeoutStr) + if err != nil { + fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 2 seconds\n", pingTimeoutStr) + pingTimeout = 2 * time.Second + } + } + + if dockerEnforceNetworkValidation == "" { + flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)") + } + if healthFile == "" { + flag.StringVar(&healthFile, "health-file", "", "Path to health file (if unset, health file won’t be written)") + } // do a --version check version := flag.Bool("version", false, "Print the version") flag.Parse() - newtVersion := "Newt version replaceme" + logger.Init() + loggerLevel := parseLogLevel(logLevel) + logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + + newtVersion := "version_replaceme" if *version { - fmt.Println(newtVersion) + fmt.Println("Newt version " + newtVersion) os.Exit(0) } else { - logger.Info(newtVersion) + logger.Info("Newt version " + newtVersion) } - logger.Init() - loggerLevel := parseLogLevel(logLevel) - logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + if err := updates.CheckForUpdate("fosrl", "newt", newtVersion); err != nil { + logger.Error("Error checking for updates: %v\n", err) + } // parse the mtu string into an int mtuInt, err = strconv.Atoi(mtu) @@ -418,6 +207,13 @@ func main() { logger.Fatal("Failed to parse MTU: %v", err) } + // parse if we want to enforce container network validation + dockerEnforceNetworkValidationBool, err = strconv.ParseBool(dockerEnforceNetworkValidation) + if err != nil { + logger.Info("Docker enforce network validation cannot be parsed. Defaulting to 'false'") + dockerEnforceNetworkValidationBool = false + } + privateKey, err = wgtypes.GeneratePrivateKey() if err != nil { logger.Fatal("Failed to generate private key: %v", err) @@ -431,12 +227,32 @@ func main() { id, // CLI arg takes precedence secret, // CLI arg takes precedence endpoint, + pingInterval, + pingTimeout, opt, ) if err != nil { logger.Fatal("Failed to create client: %v", err) } + // output env var values if set + logger.Debug("Endpoint: %v", endpoint) + logger.Debug("Log Level: %v", logLevel) + logger.Debug("Docker Network Validation Enabled: %v", dockerEnforceNetworkValidationBool) + logger.Debug("TLS Private Key Set: %v", tlsPrivateKey != "") + if dns != "" { + logger.Debug("Dns: %v", dns) + } + if dockerSocket != "" { + logger.Debug("Docker Socket: %v", dockerSocket) + } + if mtu != "" { + logger.Debug("MTU: %v", mtu) + } + if updownScript != "" { + logger.Debug("Up Down Script: %v", updownScript) + } + // Create TUN device and network stack var tun tun.Device var tnet *netstack.Net @@ -445,29 +261,61 @@ func main() { var connected bool var wgData WgData - client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) { - logger.Info("Received terminate message") + if acceptClients { + // make sure we are running on linux + if runtime.GOOS != "linux" { + logger.Fatal("Tunnel management is only supported on Linux right now!") + os.Exit(1) + } + + setupClients(client) + } + + var pingWithRetryStopChan chan struct{} + + closeWgTunnel := func() { + if pingStopChan != nil { + // Stop the ping check + close(pingStopChan) + pingStopChan = nil + } + + // Stop proxy manager if running if pm != nil { pm.Stop() + pm = nil } + + // Close WireGuard device first - this will automatically close the TUN device if dev != nil { dev.Close() + dev = nil } - client.Close() - }) - pingStopChan := make(chan struct{}) - defer close(pingStopChan) + // Clear references but don't manually close since dev.Close() already did it + if tnet != nil { + tnet = nil + } + if tun != nil { + tun = nil // Don't call tun.Close() here since dev.Close() already closed it + } + + } // Register handlers for different message types client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) { logger.Info("Received registration message") + if stopFunc != nil { + stopFunc() // stop the ws from sending more requests + stopFunc = nil // reset stopFunc to nil to avoid double stopping + } if connected { - logger.Info("Already connected! But I will send a ping anyway...") - // Even if pingWithRetry returns an error, it will continue trying in the background - _ = pingWithRetry(tnet, wgData.ServerIP) // Ignoring initial error as pings will continue - return + // Mark as disconnected + + closeWgTunnel() + + connected = false } jsonData, err := json.Marshal(msg.Data) @@ -481,7 +329,9 @@ func main() { return } - logger.Info("Received: %+v", msg) + clientsHandleNewtConnection(wgData.PublicKey) + + logger.Debug("Received: %+v", msg) tun, tnet, err = netstack.CreateNetTUN( []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)}, []netip.Addr{netip.MustParseAddr(dns)}, @@ -496,6 +346,14 @@ func main() { "wireguard: ", )) + host, _, err := net.SplitHostPort(wgData.Endpoint) + if err != nil { + logger.Error("Failed to split endpoint: %v", err) + return + } + + logger.Info("Connecting to endpoint: %s", host) + endpoint, err := resolveDomain(wgData.Endpoint) if err != nil { logger.Error("Failed to resolve endpoint: %v", err) @@ -520,19 +378,21 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Error("Failed to bring up WireGuard device: %v", err) } - logger.Info("WireGuard device created. Lets ping the server now...") + logger.Debug("WireGuard device created. Lets ping the server now...") // Even if pingWithRetry returns an error, it will continue trying in the background - _ = pingWithRetry(tnet, wgData.ServerIP) + if pingWithRetryStopChan != nil { + // Stop the previous pingWithRetry if it exists + close(pingWithRetryStopChan) + pingWithRetryStopChan = nil + } + pingWithRetryStopChan, _ = pingWithRetry(tnet, wgData.ServerIP, pingTimeout) // Always mark as connected and start the proxy manager regardless of initial ping result // as the pings will continue in the background if !connected { - logger.Info("Starting ping check") - startPingCheck(tnet, wgData.ServerIP, pingStopChan) - - // Start connection monitoring in a separate goroutine - go monitorConnectionStatus(tnet, wgData.ServerIP, client) + logger.Debug("Starting ping check") + pingStopChan = startPingCheck(tnet, wgData.ServerIP, client) } // Create proxy manager @@ -549,14 +409,153 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP}) } + clientsAddProxyTarget(pm, wgData.TunnelIP) + err = pm.Start() if err != nil { logger.Error("Failed to start proxy manager: %v", err) } }) + client.RegisterHandler("newt/wg/reconnect", func(msg websocket.WSMessage) { + logger.Info("Received reconnect message") + + // Close the WireGuard device and TUN + closeWgTunnel() + + // Mark as disconnected + connected = false + + if stopFunc != nil { + stopFunc() // stop the ws from sending more requests + stopFunc = nil // reset stopFunc to nil to avoid double stopping + } + + // Request exit nodes from the server + stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) + + logger.Info("Tunnel destroyed, ready for reconnection") + }) + + client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) { + logger.Info("Received termination message") + + // Close the WireGuard device and TUN + closeWgTunnel() + + // Mark as disconnected + connected = false + + logger.Info("Tunnel destroyed") + }) + + client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) { + logger.Info("Received ping message") + if stopFunc != nil { + stopFunc() // stop the ws from sending more requests + stopFunc = nil // reset stopFunc to nil to avoid double stopping + } + + // Parse the incoming list of exit nodes + var exitNodeData ExitNodeData + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + if err := json.Unmarshal(jsonData, &exitNodeData); err != nil { + logger.Info("Error unmarshaling exit node data: %v", err) + return + } + exitNodes := exitNodeData.ExitNodes + + if len(exitNodes) == 0 { + logger.Info("No exit nodes provided") + return + } + + type nodeResult struct { + Node ExitNode + Latency time.Duration + Err error + } + + results := make([]nodeResult, len(exitNodes)) + const pingAttempts = 3 + for i, node := range exitNodes { + if connected && node.WasPreviouslyConnected { + logger.Info("Skipping ping for previously connected exit node so we pick another %d (%s)", node.ID, node.Endpoint) + continue + } + + var totalLatency time.Duration + var lastErr error + successes := 0 + client := &http.Client{ + Timeout: 5 * time.Second, + } + url := node.Endpoint + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + url = "http://" + url + } + if !strings.HasSuffix(url, "/ping") { + url = strings.TrimRight(url, "/") + "/ping" + } + for j := 0; j < pingAttempts; j++ { + start := time.Now() + resp, err := client.Get(url) + latency := time.Since(start) + if err != nil { + lastErr = err + logger.Warn("Failed to ping exit node %d (%s) attempt %d: %v", node.ID, url, j+1, err) + continue + } + resp.Body.Close() + totalLatency += latency + successes++ + } + var avgLatency time.Duration + if successes > 0 { + avgLatency = totalLatency / time.Duration(successes) + } + if successes == 0 { + results[i] = nodeResult{Node: node, Latency: 0, Err: lastErr} + } else { + results[i] = nodeResult{Node: node, Latency: avgLatency, Err: nil} + } + } + + // Prepare data to send to the cloud for selection + var pingResults []ExitNodePingResult + for _, res := range results { + errMsg := "" + if res.Err != nil { + errMsg = res.Err.Error() + } + pingResults = append(pingResults, ExitNodePingResult{ + ExitNodeID: res.Node.ID, + LatencyMs: res.Latency.Milliseconds(), + Weight: res.Node.Weight, + Error: errMsg, + Name: res.Node.Name, + Endpoint: res.Node.Endpoint, + WasPreviouslyConnected: res.Node.WasPreviouslyConnected, + }) + } + + // Send the ping results to the cloud for selection + stopFunc = client.SendMessageInterval("newt/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "pingResults": pingResults, + "newtVersion": newtVersion, + }, 1*time.Second) + + logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults) + }) + client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) + logger.Debug("Received: %+v", msg) // if there is no wgData or pm, we can't add targets if wgData.TunnelIP == "" || pm == nil { @@ -676,7 +675,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub } // List Docker containers - containers, err := docker.ListContainers(dockerSocket) + containers, err := docker.ListContainers(dockerSocket, dockerEnforceNetworkValidationBool) if err != nil { logger.Error("Failed to list Docker containers: %v", err) return @@ -686,6 +685,10 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub err = client.SendMessage("newt/socket/containers", map[string]interface{}{ "containers": containers, }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + } + if err != nil { logger.Error("Failed to send Docker container list: %v", err) } else { @@ -694,18 +697,29 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub }) client.OnConnect(func() error { - publicKey := privateKey.PublicKey() + publicKey = privateKey.PublicKey() logger.Debug("Public key: %s", publicKey) + logger.Info("Websocket connected") + + if !connected { + // request from the server the list of nodes to ping at newt/ping/request + stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) + logger.Info("Requesting exit nodes from server") + clientsOnConnect() + } + // Send registration message to the server for backward compatibility err := client.SendMessage("newt/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), + "publicKey": publicKey.String(), + "newtVersion": newtVersion, + "backwardsCompatible": true, }) + if err != nil { logger.Error("Failed to send registration message: %v", err) return err } - logger.Info("Sent registration message") return nil }) @@ -718,135 +732,19 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub // Wait for interrupt signal sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - sigReceived := <-sigCh + <-sigCh - // Cleanup - logger.Info("Received %s signal, stopping", sigReceived.String()) - if dev != nil { - dev.Close() - } -} + dev.Close() -func parseTargetData(data interface{}) (TargetData, error) { - var targetData TargetData - jsonData, err := json.Marshal(data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return targetData, err - } + closeClients() - if err := json.Unmarshal(jsonData, &targetData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return targetData, err + if pm != nil { + pm.Stop() } - return targetData, nil -} - -func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { - for _, t := range targetData.Targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 3 { - logger.Info("Invalid target format: %s", t) - continue - } - - // Get the port as an int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) - if err != nil { - logger.Info("Invalid port: %s", parts[0]) - continue - } - if action == "add" { - target := parts[1] + ":" + parts[2] - - // Call updown script if provided - processedTarget := target - if updownScript != "" { - newTarget, err := executeUpdownScript(action, proto, target) - if err != nil { - logger.Warn("Updown script error: %v", err) - } else if newTarget != "" { - processedTarget = newTarget - } - } - - // Only remove the specific target if it exists - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - // Ignore "target not found" errors as this is expected for new targets - if !strings.Contains(err.Error(), "target not found") { - logger.Error("Failed to remove existing target: %v", err) - } - } - - // Add the new target - pm.AddTarget(proto, tunnelIP, port, processedTarget) - - } else if action == "remove" { - logger.Info("Removing target with port %d", port) - - target := parts[1] + ":" + parts[2] - - // Call updown script if provided - if updownScript != "" { - _, err := executeUpdownScript(action, proto, target) - if err != nil { - logger.Warn("Updown script error: %v", err) - } - } - - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - logger.Error("Failed to remove target: %v", err) - return err - } - } - } - - return nil -} - -func executeUpdownScript(action, proto, target string) (string, error) { - if updownScript == "" { - return target, nil - } - - // Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py") - parts := strings.Fields(updownScript) - if len(parts) == 0 { - return target, fmt.Errorf("invalid updown script command") - } - - var cmd *exec.Cmd - if len(parts) == 1 { - // If it's a single executable - logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target) - cmd = exec.Command(parts[0], action, proto, target) - } else { - // If it includes interpreter and script - args := append(parts[1:], action, proto, target) - logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target) - cmd = exec.Command(parts[0], args...) - } - - output, err := cmd.Output() - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - return "", fmt.Errorf("updown script execution failed (exit code %d): %s", - exitErr.ExitCode(), string(exitErr.Stderr)) - } - return "", fmt.Errorf("updown script execution failed: %v", err) - } - - // If the script returns a new target, use it - newTarget := strings.TrimSpace(string(output)) - if newTarget != "" { - logger.Info("Updown script returned new target: %s", newTarget) - return newTarget, nil + if client != nil { + client.Close() } - - return target, nil + logger.Info("Exiting...") + os.Exit(0) } diff --git a/network/network.go b/network/network.go new file mode 100644 index 0000000..e359219 --- /dev/null +++ b/network/network.go @@ -0,0 +1,195 @@ +package network + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "log" + "net" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/vishvananda/netlink" + "golang.org/x/net/bpf" + "golang.org/x/net/ipv4" +) + +const ( + udpProtocol = 17 + // EmptyUDPSize is the size of an empty UDP packet + EmptyUDPSize = 28 + timeout = time.Second * 10 +) + +// Server stores data relating to the server +type Server struct { + Hostname string + Addr *net.IPAddr + Port uint16 +} + +// PeerNet stores data about a peer's endpoint +type PeerNet struct { + Resolved bool + IP net.IP + Port uint16 + NewtID string +} + +// GetClientIP gets source ip address that will be used when sending data to dstIP +func GetClientIP(dstIP net.IP) net.IP { + routes, err := netlink.RouteGet(dstIP) + if err != nil { + log.Fatalln("Error getting route:", err) + } + return routes[0].Src +} + +// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr +func HostToAddr(hostStr string) *net.IPAddr { + remoteAddrs, err := net.LookupHost(hostStr) + if err != nil { + log.Fatalln("Error parsing remote address:", err) + } + + for _, addrStr := range remoteAddrs { + if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil { + return remoteAddr + } + } + return nil +} + +// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering +func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn { + packetConn, err := net.ListenPacket("ip4:udp", client.IP.String()) + if err != nil { + log.Fatalln("Error creating packetConn:", err) + } + + rawConn, err := ipv4.NewRawConn(packetConn) + if err != nil { + log.Fatalln("Error creating rawConn:", err) + } + + ApplyBPF(rawConn, server, client) + + return rawConn +} + +// ApplyBPF constructs a BPF program and applies it to the RawConn +func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) { + const ipv4HeaderLen = 20 + const srcIPOffset = 12 + const srcPortOffset = ipv4HeaderLen + 0 + const dstPortOffset = ipv4HeaderLen + 2 + + ipArr := []byte(server.Addr.IP.To4()) + ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3]) + + bpfRaw, err := bpf.Assemble([]bpf.Instruction{ + bpf.LoadAbsolute{Off: srcIPOffset, Size: 4}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0}, + + bpf.LoadAbsolute{Off: srcPortOffset, Size: 2}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0}, + + bpf.LoadAbsolute{Off: dstPortOffset, Size: 2}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0}, + + bpf.RetConstant{Val: 1<<(8*4) - 1}, + bpf.RetConstant{Val: 0}, + }) + + if err != nil { + log.Fatalln("Error assembling BPF:", err) + } + + err = rawConn.SetBPF(bpfRaw) + if err != nil { + log.Fatalln("Error setting BPF:", err) + } +} + +// MakePacket constructs a request packet to send to the server +func MakePacket(payload []byte, server *Server, client *PeerNet) []byte { + buf := gopacket.NewSerializeBuffer() + + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + ipHeader := layers.IPv4{ + SrcIP: client.IP, + DstIP: server.Addr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + + udpHeader := layers.UDP{ + SrcPort: layers.UDPPort(client.Port), + DstPort: layers.UDPPort(server.Port), + } + + payloadLayer := gopacket.Payload(payload) + + udpHeader.SetNetworkLayerForChecksum(&ipHeader) + + gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer) + + return buf.Bytes() +} + +// SendPacket sends packet to the Server +func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error { + fullPacket := MakePacket(packet, server, client) + _, err := conn.WriteToIP(fullPacket, server.Addr) + return err +} + +// SendDataPacket sends a JSON payload to the Server +func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error { + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal payload: %v", err) + } + + return SendPacket(jsonData, conn, server, client) +} + +// RecvPacket receives a UDP packet from server +func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) { + err := conn.SetReadDeadline(time.Now().Add(timeout)) + if err != nil { + return nil, 0, err + } + + response := make([]byte, 4096) + n, err := conn.Read(response) + if err != nil { + return nil, n, err + } + return response, n, nil +} + +// RecvDataPacket receives and unmarshals a JSON packet from server +func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) { + response, n, err := RecvPacket(conn, server, client) + if err != nil { + return nil, err + } + + // Extract payload from UDP packet + payload := response[EmptyUDPSize:n] + return payload, nil +} + +// ParseResponse takes a response packet and parses it into an IP and port +func ParseResponse(response []byte) (net.IP, uint16) { + ip := net.IP(response[:4]) + port := binary.BigEndian.Uint16(response[4:6]) + return ip, port +} diff --git a/proxy/manager.go b/proxy/manager.go index 0792acb..35d023a 100644 --- a/proxy/manager.go +++ b/proxy/manager.go @@ -213,7 +213,8 @@ func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr return fmt.Errorf("unsupported protocol: %s", proto) } - logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr) + logger.Info("Started %s proxy to %s", proto, targetAddr) + logger.Debug("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr) return nil } diff --git a/stub.go b/stub.go new file mode 100644 index 0000000..e2360ff --- /dev/null +++ b/stub.go @@ -0,0 +1,32 @@ +//go:build !linux + +package main + +import ( + "github.com/fosrl/newt/proxy" + "github.com/fosrl/newt/websocket" +) + +func setupClients(client *websocket.Client) { + return // This function is not implemented for non-Linux systems. +} + +func closeClients() { + // This function is not implemented for non-Linux systems. + return +} + +func clientsHandleNewtConnection(publicKey string) { + // This function is not implemented for non-Linux systems. + return +} + +func clientsOnConnect() { + // This function is not implemented for non-Linux systems. + return +} + +func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) { + // This function is not implemented for non-Linux systems. + return +} diff --git a/updates/updates.go b/updates/updates.go new file mode 100644 index 0000000..8d7de5e --- /dev/null +++ b/updates/updates.go @@ -0,0 +1,173 @@ +package updates + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" +) + +// GitHubRelease represents the GitHub API response for a release +type GitHubRelease struct { + TagName string `json:"tag_name"` + Name string `json:"name"` + HTMLURL string `json:"html_url"` +} + +// Version represents a semantic version +type Version struct { + Major int + Minor int + Patch int +} + +// parseVersion parses a semantic version string (e.g., "v1.2.3" or "1.2.3") +func parseVersion(versionStr string) (Version, error) { + // Remove 'v' prefix if present + versionStr = strings.TrimPrefix(versionStr, "v") + + parts := strings.Split(versionStr, ".") + if len(parts) != 3 { + return Version{}, fmt.Errorf("invalid version format: %s", versionStr) + } + + major, err := strconv.Atoi(parts[0]) + if err != nil { + return Version{}, fmt.Errorf("invalid major version: %s", parts[0]) + } + + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return Version{}, fmt.Errorf("invalid minor version: %s", parts[1]) + } + + patch, err := strconv.Atoi(parts[2]) + if err != nil { + return Version{}, fmt.Errorf("invalid patch version: %s", parts[2]) + } + + return Version{Major: major, Minor: minor, Patch: patch}, nil +} + +// isNewer returns true if v2 is newer than v1 +func (v1 Version) isNewer(v2 Version) bool { + if v2.Major > v1.Major { + return true + } + if v2.Major < v1.Major { + return false + } + + if v2.Minor > v1.Minor { + return true + } + if v2.Minor < v1.Minor { + return false + } + + return v2.Patch > v1.Patch +} + +// String returns the version as a string +func (v Version) String() string { + return fmt.Sprintf("%d.%d.%d", v.Major, v.Minor, v.Patch) +} + +// CheckForUpdate checks GitHub for a newer version and prints an update banner if found +func CheckForUpdate(owner, repo, currentVersion string) error { + if currentVersion == "version_replaceme" { + return nil + } + + // GitHub API URL for latest release + url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", owner, repo) + + // Create HTTP client with timeout + client := &http.Client{ + Timeout: 10 * time.Second, + } + + // Make the request + resp, err := client.Get(url) + if err != nil { + return fmt.Errorf("failed to fetch release info: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("GitHub API returned status: %d", resp.StatusCode) + } + + // Parse the JSON response + var release GitHubRelease + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return fmt.Errorf("failed to parse release info: %w", err) + } + + // Parse current and latest versions + currentVer, err := parseVersion(currentVersion) + if err != nil { + return fmt.Errorf("invalid current version: %w", err) + } + + latestVer, err := parseVersion(release.TagName) + if err != nil { + return fmt.Errorf("invalid latest version: %w", err) + } + + // Check if update is available + if currentVer.isNewer(latestVer) { + printUpdateBanner(currentVer.String(), latestVer.String(), release.HTMLURL) + } + + return nil +} + +// printUpdateBanner prints a colorful update notification banner +func printUpdateBanner(currentVersion, latestVersion, releaseURL string) { + const contentWidth = 70 // width between the border lines + + borderTop := "╔" + strings.Repeat("═", contentWidth) + "╗" + borderMid := "╠" + strings.Repeat("═", contentWidth) + "╣" + borderBot := "╚" + strings.Repeat("═", contentWidth) + "╝" + emptyLine := "║" + strings.Repeat(" ", contentWidth) + "║" + + lines := []string{ + borderTop, + "║" + centerText("UPDATE AVAILABLE", contentWidth) + "║", + borderMid, + emptyLine, + "║ Current Version: " + padRight(currentVersion, contentWidth-19) + "║", + "║ Latest Version: " + padRight(latestVersion, contentWidth-19) + "║", + emptyLine, + "║ A newer version is available! Please update to get the" + padRight("", contentWidth-56) + "║", + "║ latest features, bug fixes, and security improvements." + padRight("", contentWidth-56) + "║", + emptyLine, + "║ Release URL: " + padRight(releaseURL, contentWidth-15) + "║", + emptyLine, + borderBot, + } + + for _, line := range lines { + fmt.Println(line) + } +} + +// padRight pads s with spaces on the right to the given width +func padRight(s string, width int) string { + if len(s) > width { + return s[:width] + } + return s + strings.Repeat(" ", width-len(s)) +} + +// centerText centers s in a field of width w +func centerText(s string, w int) string { + if len(s) >= w { + return s[:w] + } + padding := (w - len(s)) / 2 + return strings.Repeat(" ", padding) + s + strings.Repeat(" ", w-len(s)-padding) +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..9bdab59 --- /dev/null +++ b/util.go @@ -0,0 +1,459 @@ +package main + +import ( + "bytes" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "net" + "os" + "os/exec" + "strings" + "time" + + "math/rand" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/proxy" + "github.com/fosrl/newt/websocket" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +func fixKey(key string) string { + // Remove any whitespace + key = strings.TrimSpace(key) + + // Decode from base64 + decoded, err := base64.StdEncoding.DecodeString(key) + if err != nil { + logger.Fatal("Error decoding base64: %v", err) + } + + // Convert to hex + return hex.EncodeToString(decoded) +} + +func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration, error) { + logger.Debug("Pinging %s", dst) + socket, err := tnet.Dial("ping4", dst) + if err != nil { + return 0, fmt.Errorf("failed to create ICMP socket: %w", err) + } + defer socket.Close() + + requestPing := icmp.Echo{ + Seq: rand.Intn(1 << 16), + Data: []byte("f"), + } + + icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) + if err != nil { + return 0, fmt.Errorf("failed to marshal ICMP message: %w", err) + } + + if err := socket.SetReadDeadline(time.Now().Add(timeout)); err != nil { + return 0, fmt.Errorf("failed to set read deadline: %w", err) + } + + start := time.Now() + _, err = socket.Write(icmpBytes) + if err != nil { + return 0, fmt.Errorf("failed to write ICMP packet: %w", err) + } + + n, err := socket.Read(icmpBytes[:]) + if err != nil { + return 0, fmt.Errorf("failed to read ICMP packet: %w", err) + } + + replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) + if err != nil { + return 0, fmt.Errorf("failed to parse ICMP packet: %w", err) + } + + replyPing, ok := replyPacket.Body.(*icmp.Echo) + if !ok { + return 0, fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body) + } + + if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { + return 0, fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q", + replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data) + } + + latency := time.Since(start) + + logger.Debug("Ping to %s successful, latency: %v", dst, latency) + + return latency, nil +} + +func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopChan chan struct{}, err error) { + + if healthFile != "" { + err = os.Remove(healthFile) + if err != nil { + logger.Error("Failed to remove health file: %v", err) + } + } + + const ( + initialMaxAttempts = 5 + initialRetryDelay = 2 * time.Second + maxRetryDelay = 60 * time.Second // Cap the maximum delay + ) + + stopChan = make(chan struct{}) + attempt := 1 + retryDelay := initialRetryDelay + + // First try with the initial parameters + logger.Debug("Ping attempt %d", attempt) + if latency, err := ping(tnet, dst, timeout); err == nil { + // Successful ping + logger.Debug("Ping latency: %v", latency) + logger.Info("Tunnel connection to server established successfully!") + if healthFile != "" { + err := os.WriteFile(healthFile, []byte("ok"), 0644) + if err != nil { + logger.Warn("Failed to write health file: %v", err) + } + } + return stopChan, nil + } else { + logger.Warn("Ping attempt %d failed: %v", attempt, err) + } + + // Start a goroutine that will attempt pings indefinitely with increasing delays + go func() { + attempt = 2 // Continue from attempt 2 + + for { + select { + case <-stopChan: + return + default: + logger.Debug("Ping attempt %d", attempt) + + if latency, err := ping(tnet, dst, timeout); err != nil { + logger.Warn("Ping attempt %d failed: %v", attempt, err) + + // Increase delay after certain thresholds but cap it + if attempt%5 == 0 && retryDelay < maxRetryDelay { + retryDelay = time.Duration(float64(retryDelay) * 1.5) + if retryDelay > maxRetryDelay { + retryDelay = maxRetryDelay + } + logger.Info("Increasing ping retry delay to %v", retryDelay) + } + + time.Sleep(retryDelay) + attempt++ + } else { + // Successful ping + logger.Debug("Ping succeeded after %d attempts", attempt) + logger.Debug("Ping latency: %v", latency) + logger.Info("Tunnel connection to server established successfully!") + if healthFile != "" { + err := os.WriteFile(healthFile, []byte("ok"), 0644) + if err != nil { + logger.Warn("Failed to write health file: %v", err) + } + } + return + } + } + } + }() + + // Return an error for the first batch of attempts (to maintain compatibility with existing code) + return stopChan, fmt.Errorf("initial ping attempts failed, continuing in background") +} + +func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} { + initialInterval := pingInterval + maxInterval := 3 * time.Second + currentInterval := initialInterval + consecutiveFailures := 0 + connectionLost := false + + pingStopChan := make(chan struct{}) + + go func() { + ticker := time.NewTicker(currentInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + _, err := ping(tnet, serverIP, pingTimeout) + if err != nil { + consecutiveFailures++ + logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err) + if consecutiveFailures >= 3 && currentInterval < maxInterval { + if !connectionLost { + connectionLost = true + logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.") + stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) + // Send registration message to the server for backward compatibility + err := client.SendMessage("newt/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "backwardsCompatible": true, + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + } + if healthFile != "" { + err = os.Remove(healthFile) + if err != nil { + logger.Error("Failed to remove health file: %v", err) + } + } + } + currentInterval = time.Duration(float64(currentInterval) * 1.5) + if currentInterval > maxInterval { + currentInterval = maxInterval + } + ticker.Reset(currentInterval) + logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval) + } + } else { + if connectionLost { + connectionLost = false + logger.Info("Connection to server restored!") + if healthFile != "" { + err := os.WriteFile(healthFile, []byte("ok"), 0644) + if err != nil { + logger.Warn("Failed to write health file: %v", err) + } + } + } + if currentInterval > initialInterval { + currentInterval = time.Duration(float64(currentInterval) * 0.8) + if currentInterval < initialInterval { + currentInterval = initialInterval + } + ticker.Reset(currentInterval) + logger.Info("Decreased ping check interval to %v after successful ping", currentInterval) + } + consecutiveFailures = 0 + } + case <-pingStopChan: + logger.Info("Stopping ping check") + return + } + } + }() + + return pingStopChan +} + +func parseLogLevel(level string) logger.LogLevel { + switch strings.ToUpper(level) { + case "DEBUG": + return logger.DEBUG + case "INFO": + return logger.INFO + case "WARN": + return logger.WARN + case "ERROR": + return logger.ERROR + case "FATAL": + return logger.FATAL + default: + return logger.INFO // default to INFO if invalid level provided + } +} + +func mapToWireGuardLogLevel(level logger.LogLevel) int { + switch level { + case logger.DEBUG: + return device.LogLevelVerbose + // case logger.INFO: + // return device.LogLevel + case logger.WARN: + return device.LogLevelError + case logger.ERROR, logger.FATAL: + return device.LogLevelSilent + default: + return device.LogLevelSilent + } +} + +func resolveDomain(domain string) (string, error) { + // Check if there's a port in the domain + host, port, err := net.SplitHostPort(domain) + if err != nil { + // No port found, use the domain as is + host = domain + port = "" + } + + // Remove any protocol prefix if present + if strings.HasPrefix(host, "http://") { + host = strings.TrimPrefix(host, "http://") + } else if strings.HasPrefix(host, "https://") { + host = strings.TrimPrefix(host, "https://") + } + + // if there are any trailing slashes, remove them + host = strings.TrimSuffix(host, "/") + + // Lookup IP addresses + ips, err := net.LookupIP(host) + if err != nil { + return "", fmt.Errorf("DNS lookup failed: %v", err) + } + + if len(ips) == 0 { + return "", fmt.Errorf("no IP addresses found for domain %s", host) + } + + // Get the first IPv4 address if available + var ipAddr string + for _, ip := range ips { + if ipv4 := ip.To4(); ipv4 != nil { + ipAddr = ipv4.String() + break + } + } + + // If no IPv4 found, use the first IP (might be IPv6) + if ipAddr == "" { + ipAddr = ips[0].String() + } + + // Add port back if it existed + if port != "" { + ipAddr = net.JoinHostPort(ipAddr, port) + } + + return ipAddr, nil +} + +func parseTargetData(data interface{}) (TargetData, error) { + var targetData TargetData + jsonData, err := json.Marshal(data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return targetData, err + } + + if err := json.Unmarshal(jsonData, &targetData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return targetData, err + } + return targetData, nil +} + +func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { + for _, t := range targetData.Targets { + // Split the first number off of the target with : separator and use as the port + parts := strings.Split(t, ":") + if len(parts) != 3 { + logger.Info("Invalid target format: %s", t) + continue + } + + // Get the port as an int + port := 0 + _, err := fmt.Sscanf(parts[0], "%d", &port) + if err != nil { + logger.Info("Invalid port: %s", parts[0]) + continue + } + + if action == "add" { + target := parts[1] + ":" + parts[2] + + // Call updown script if provided + processedTarget := target + if updownScript != "" { + newTarget, err := executeUpdownScript(action, proto, target) + if err != nil { + logger.Warn("Updown script error: %v", err) + } else if newTarget != "" { + processedTarget = newTarget + } + } + + // Only remove the specific target if it exists + err := pm.RemoveTarget(proto, tunnelIP, port) + if err != nil { + // Ignore "target not found" errors as this is expected for new targets + if !strings.Contains(err.Error(), "target not found") { + logger.Error("Failed to remove existing target: %v", err) + } + } + + // Add the new target + pm.AddTarget(proto, tunnelIP, port, processedTarget) + + } else if action == "remove" { + logger.Info("Removing target with port %d", port) + + target := parts[1] + ":" + parts[2] + + // Call updown script if provided + if updownScript != "" { + _, err := executeUpdownScript(action, proto, target) + if err != nil { + logger.Warn("Updown script error: %v", err) + } + } + + err := pm.RemoveTarget(proto, tunnelIP, port) + if err != nil { + logger.Error("Failed to remove target: %v", err) + return err + } + } + } + + return nil +} + +func executeUpdownScript(action, proto, target string) (string, error) { + if updownScript == "" { + return target, nil + } + + // Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py") + parts := strings.Fields(updownScript) + if len(parts) == 0 { + return target, fmt.Errorf("invalid updown script command") + } + + var cmd *exec.Cmd + if len(parts) == 1 { + // If it's a single executable + logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target) + cmd = exec.Command(parts[0], action, proto, target) + } else { + // If it includes interpreter and script + args := append(parts[1:], action, proto, target) + logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target) + cmd = exec.Command(parts[0], args...) + } + + output, err := cmd.Output() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + return "", fmt.Errorf("updown script execution failed (exit code %d): %s", + exitErr.ExitCode(), string(exitErr.Stderr)) + } + return "", fmt.Errorf("updown script execution failed: %v", err) + } + + // If the script returns a new target, use it + newTarget := strings.TrimSpace(string(output)) + if newTarget != "" { + logger.Info("Updown script returned new target: %s", newTarget) + return newTarget, nil + } + + return target, nil +} diff --git a/websocket/client.go b/websocket/client.go index 3d221e1..98f07e6 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -9,11 +9,12 @@ import ( "net/http" "net/url" "os" - "software.sslmate.com/src/go-pkcs12" "strings" "sync" "time" + "software.sslmate.com/src/go-pkcs12" + "github.com/fosrl/newt/logger" "github.com/gorilla/websocket" ) @@ -28,8 +29,11 @@ type Client struct { reconnectInterval time.Duration isConnected bool reconnectMux sync.RWMutex - - onConnect func() error + pingInterval time.Duration + pingTimeout time.Duration + onConnect func() error + onTokenUpdate func(token string) + writeMux sync.Mutex } type ClientOption func(*Client) @@ -53,8 +57,12 @@ func (c *Client) OnConnect(callback func() error) { c.onConnect = callback } +func (c *Client) OnTokenUpdate(callback func(token string)) { + c.onTokenUpdate = callback +} + // NewClient creates a new Newt client -func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) { +func NewClient(newtID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { config := &Config{ NewtID: newtID, Secret: secret, @@ -66,18 +74,18 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C baseURL: endpoint, // default value handlers: make(map[string]MessageHandler), done: make(chan struct{}), - reconnectInterval: 10 * time.Second, + reconnectInterval: 3 * time.Second, isConnected: false, + pingInterval: pingInterval, + pingTimeout: pingTimeout, } // Apply options before loading config - if opts != nil { - for _, opt := range opts { - if opt == nil { - continue - } - opt(client) + for _, opt := range opts { + if opt == nil { + continue } + opt(client) } // Load existing config if available @@ -94,16 +102,31 @@ func (c *Client) Connect() error { return nil } -// Close closes the WebSocket connection +// Close closes the WebSocket connection gracefully func (c *Client) Close() error { - close(c.done) - if c.conn != nil { - return c.conn.Close() + // Signal shutdown to all goroutines first + select { + case <-c.done: + // Already closed + return nil + default: + close(c.done) } - // stop the ping monitor + // Set connection status to false c.setConnected(false) + // Close the WebSocket connection gracefully + if c.conn != nil { + // Send close message + c.writeMux.Lock() + c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + c.writeMux.Unlock() + + // Close the connection + return c.conn.Close() + } + return nil } @@ -118,9 +141,39 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { Data: data, } + logger.Debug("Sending message: %s, data: %+v", messageType, data) + + c.writeMux.Lock() + defer c.writeMux.Unlock() return c.conn.WriteJSON(msg) } +func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) { + stopChan := make(chan struct{}) + go func() { + err := c.SendMessage(messageType, data) // Send immediately + if err != nil { + logger.Error("Failed to send initial message: %v", err) + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + err = c.SendMessage(messageType, data) + if err != nil { + logger.Error("Failed to send message: %v", err) + } + case <-stopChan: + return + } + } + }() + return func() { + close(stopChan) + } +} + // RegisterHandler registers a handler for a specific message type func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { c.handlersMux.Lock() @@ -128,30 +181,6 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { c.handlers[messageType] = handler } -// readPump pumps messages from the WebSocket connection -func (c *Client) readPump() { - defer c.conn.Close() - - for { - select { - case <-c.done: - return - default: - var msg WSMessage - err := c.conn.ReadJSON(&msg) - if err != nil { - return - } - - c.handlersMux.RLock() - if handler, ok := c.handlers[msg.Type]; ok { - handler(msg) - } - c.handlersMux.RUnlock() - } - } -} - func (c *Client) getToken() (string, error) { // Parse the base URL to ensure we have the correct hostname baseURL, err := url.Parse(c.baseURL) @@ -170,56 +199,6 @@ func (c *Client) getToken() (string, error) { } } - // If we already have a token, try to use it - if c.config.Token != "" { - tokenCheckData := map[string]interface{}{ - "newtId": c.config.NewtID, - "secret": c.config.Secret, - "token": c.config.Token, - } - jsonData, err := json.Marshal(tokenCheckData) - if err != nil { - return "", fmt.Errorf("failed to marshal token check data: %w", err) - } - - // Create a new request - req, err := http.NewRequest( - "POST", - baseEndpoint+"/api/v1/auth/newt/get-token", - bytes.NewBuffer(jsonData), - ) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-CSRF-Token", "x-csrf-protection") - - // Make the request - client := &http.Client{} - if tlsConfig != nil { - client.Transport = &http.Transport{ - TLSClientConfig: tlsConfig, - } - } - resp, err := client.Do(req) - if err != nil { - return "", fmt.Errorf("failed to check token validity: %w", err) - } - defer resp.Body.Close() - - var tokenResp TokenResponse - if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - return "", fmt.Errorf("failed to decode token check response: %w", err) - } - - // If token is still valid, return it - if tokenResp.Success && tokenResp.Message == "Token session already valid" { - return c.config.Token, nil - } - } - // Get a new token tokenData := map[string]interface{}{ "newtId": c.config.NewtID, @@ -257,12 +236,14 @@ func (c *Client) getToken() (string, error) { } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + logger.Error("Failed to get token with status code: %d", resp.StatusCode) + return "", fmt.Errorf("failed to get token with status code: %d", resp.StatusCode) + } + var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - // print out the token response for debugging - buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) - logger.Info("Token response: %s", buf.String()) + logger.Error("Failed to decode token response.") return "", fmt.Errorf("failed to decode token response: %w", err) } @@ -274,6 +255,8 @@ func (c *Client) getToken() (string, error) { return "", fmt.Errorf("received empty token from server") } + logger.Debug("Received token: %s", tokenResp.Data.Token) + return tokenResp.Data.Token, nil } @@ -301,6 +284,10 @@ func (c *Client) establishConnection() error { return fmt.Errorf("failed to get token: %w", err) } + if c.onTokenUpdate != nil { + c.onTokenUpdate(token) + } + // Parse the base URL to determine protocol and hostname baseURL, err := url.Parse(c.baseURL) if err != nil { @@ -323,6 +310,7 @@ func (c *Client) establishConnection() error { // Add token to query parameters q := u.Query() q.Set("token", token) + q.Set("clientType", "newt") u.RawQuery = q.Encode() // Connect to WebSocket @@ -345,8 +333,8 @@ func (c *Client) establishConnection() error { // Start the ping monitor go c.pingMonitor() - // Start the read pump - go c.readPump() + // Start the read pump with disconnect detection + go c.readPumpWithDisconnectDetection() if c.onConnect != nil { err := c.saveConfig() @@ -361,8 +349,9 @@ func (c *Client) establishConnection() error { return nil } +// pingMonitor sends pings at a short interval and triggers reconnect on failure func (c *Client) pingMonitor() { - ticker := time.NewTicker(30 * time.Second) + ticker := time.NewTicker(c.pingInterval) defer ticker.Stop() for { @@ -370,11 +359,74 @@ func (c *Client) pingMonitor() { case <-c.done: return case <-ticker.C: - if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { - logger.Error("Ping failed: %v", err) - c.reconnect() + if c.conn == nil { return } + c.writeMux.Lock() + err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout)) + c.writeMux.Unlock() + if err != nil { + // Check if we're shutting down before logging error and reconnecting + select { + case <-c.done: + // Expected during shutdown + return + default: + logger.Error("Ping failed: %v", err) + c.reconnect() + return + } + } + } + } +} + +// readPumpWithDisconnectDetection reads messages and triggers reconnect on error +func (c *Client) readPumpWithDisconnectDetection() { + defer func() { + if c.conn != nil { + c.conn.Close() + } + // Only attempt reconnect if we're not shutting down + select { + case <-c.done: + // Shutting down, don't reconnect + return + default: + c.reconnect() + } + }() + + for { + select { + case <-c.done: + return + default: + var msg WSMessage + err := c.conn.ReadJSON(&msg) + if err != nil { + // Check if we're shutting down before logging error + select { + case <-c.done: + // Expected during shutdown, don't log as error + logger.Debug("WebSocket connection closed during shutdown") + return + default: + // Unexpected error during normal operation + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { + logger.Error("WebSocket read error: %v", err) + } else { + logger.Debug("WebSocket connection closed: %v", err) + } + return // triggers reconnect via defer + } + } + + c.handlersMux.RLock() + if handler, ok := c.handlers[msg.Type]; ok { + handler(msg) + } + c.handlersMux.RUnlock() } } } @@ -383,9 +435,16 @@ func (c *Client) reconnect() { c.setConnected(false) if c.conn != nil { c.conn.Close() + c.conn = nil } - go c.connectWithRetry() + // Only reconnect if we're not shutting down + select { + case <-c.done: + return + default: + go c.connectWithRetry() + } } func (c *Client) setConnected(status bool) { diff --git a/websocket/config.go b/websocket/config.go index e2b0055..fe11c5a 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -48,9 +48,6 @@ func (c *Client) loadConfig() error { if c.config.NewtID == "" { c.config.NewtID = config.NewtID } - if c.config.Token == "" { - c.config.Token = config.Token - } if c.config.Secret == "" { c.config.Secret = config.Secret } diff --git a/websocket/types.go b/websocket/types.go index 0ea24fc..54d33f1 100644 --- a/websocket/types.go +++ b/websocket/types.go @@ -3,7 +3,6 @@ package websocket type Config struct { NewtID string `json:"newtId"` Secret string `json:"secret"` - Token string `json:"token"` Endpoint string `json:"endpoint"` TlsClientCert string `json:"tlsClientCert"` } diff --git a/wg/wg.go b/wg/wg.go new file mode 100644 index 0000000..1c378ea --- /dev/null +++ b/wg/wg.go @@ -0,0 +1,981 @@ +//go:build linux + +package wg + +import ( + "encoding/json" + "fmt" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + "github.com/fosrl/newt/websocket" + "github.com/vishvananda/netlink" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + "golang.org/x/exp/rand" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type WgConfig struct { + IpAddress string `json:"ipAddress"` + Peers []Peer `json:"peers"` +} + +type Peer struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps"` + Endpoint string `json:"endpoint"` +} + +type PeerBandwidth struct { + PublicKey string `json:"publicKey"` + BytesIn float64 `json:"bytesIn"` + BytesOut float64 `json:"bytesOut"` +} + +type PeerReading struct { + BytesReceived int64 + BytesTransmitted int64 + LastChecked time.Time +} + +type WireGuardService struct { + interfaceName string + mtu int + client *websocket.Client + wgClient *wgctrl.Client + config WgConfig + key wgtypes.Key + newtId string + lastReadings map[string]PeerReading + mu sync.Mutex + Port uint16 + stopHolepunch chan struct{} + host string + serverPubKey string + token string + stopGetConfig chan struct{} +} + +// Add this type definition +type fixedPortBind struct { + port uint16 + conn.Bind +} + +func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { + // Ignore the requested port and use our fixed port + return b.Bind.Open(b.port) +} + +func NewFixedPortBind(port uint16) conn.Bind { + return &fixedPortBind{ + port: port, + Bind: conn.NewDefaultBind(), + } +} + +// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester +func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { + if maxPort < minPort { + return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) + } + + // We need to check port+1 as well, so adjust the max port to avoid going out of range + adjustedMaxPort := maxPort - 1 + if adjustedMaxPort < minPort { + return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort) + } + + // Create a slice of all ports in the range (excluding the last one) + portRange := make([]uint16, adjustedMaxPort-minPort+1) + for i := range portRange { + portRange[i] = minPort + uint16(i) + } + + // Fisher-Yates shuffle to randomize the port order + rand.Seed(uint64(time.Now().UnixNano())) + for i := len(portRange) - 1; i > 0; i-- { + j := rand.Intn(i + 1) + portRange[i], portRange[j] = portRange[j], portRange[i] + } + + // Try each port in the randomized order + for _, port := range portRange { + // Check if port is available + addr1 := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: int(port), + } + conn1, err1 := net.ListenUDP("udp", addr1) + if err1 != nil { + continue // Port is in use or there was an error, try next port + } + + // Check if port+1 is also available + addr2 := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: int(port + 1), + } + conn2, err2 := net.ListenUDP("udp", addr2) + if err2 != nil { + // The next port is not available, so close the first connection and try again + conn1.Close() + continue + } + + // Both ports are available, close connections and return the first port + conn1.Close() + conn2.Close() + return port, nil + } + + return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort) +} + +func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) { + wgClient, err := wgctrl.New() + if err != nil { + return nil, fmt.Errorf("failed to create WireGuard client: %v", err) + } + + var key wgtypes.Key + // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file + if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { + // generate a new private key + key, err = wgtypes.GeneratePrivateKey() + if err != nil { + logger.Fatal("Failed to generate private key: %v", err) + } + // save the key to the file + err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) + if err != nil { + logger.Fatal("Failed to save private key: %v", err) + } + } else { + keyData, err := os.ReadFile(generateAndSaveKeyTo) + if err != nil { + logger.Fatal("Failed to read private key: %v", err) + } + key, err = wgtypes.ParseKey(string(keyData)) + if err != nil { + logger.Fatal("Failed to parse private key: %v", err) + } + } + + service := &WireGuardService{ + interfaceName: interfaceName, + mtu: mtu, + client: wsClient, + wgClient: wgClient, + key: key, + newtId: newtId, + host: host, + lastReadings: make(map[string]PeerReading), + stopHolepunch: make(chan struct{}), + stopGetConfig: make(chan struct{}), + } + + // Get the existing wireguard port (keep this part) + device, err := service.wgClient.Device(service.interfaceName) + if err == nil { + service.Port = uint16(device.ListenPort) + logger.Info("WireGuard interface %s already exists with port %d\n", service.interfaceName, service.Port) + } else { + service.Port, err = FindAvailableUDPPort(49152, 65535) + if err != nil { + fmt.Printf("Error finding available port: %v\n", err) + return nil, err + } + } + + // Register websocket handlers + wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) + wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) + wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) + wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer) + + if err := service.sendUDPHolePunch(service.host + ":21820"); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + + // start the UDP holepunch + go service.keepSendingUDPHolePunch(service.host) + + return service, nil +} + +func (s *WireGuardService) Close(rm bool) { + select { + case <-s.stopGetConfig: + // Already closed, do nothing + default: + close(s.stopGetConfig) + } + + s.wgClient.Close() + // Remove the WireGuard interface + if rm { + if err := s.removeInterface(); err != nil { + logger.Error("Failed to remove WireGuard interface: %v", err) + } + + // Remove the private key file + if err := os.Remove(s.key.String()); err != nil { + logger.Error("Failed to remove private key file: %v", err) + } + } +} + +func (s *WireGuardService) SetServerPubKey(serverPubKey string) { + s.serverPubKey = serverPubKey +} + +func (s *WireGuardService) SetToken(token string) { + s.token = token +} + +func (s *WireGuardService) LoadRemoteConfig() error { + // Send the initial message + err := s.sendGetConfigMessage() + if err != nil { + logger.Error("Failed to send initial get-config message: %v", err) + return err + } + + // Start goroutine to periodically send the message until config is received + go s.keepSendingGetConfig() + + go s.periodicBandwidthCheck() + + return nil +} + +func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { + var config WgConfig + + logger.Info("Received message: %v", msg) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &config); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + s.config = config + + close(s.stopGetConfig) + + // Ensure the WireGuard interface and peers are configured + if err := s.ensureWireguardInterface(config); err != nil { + logger.Error("Failed to ensure WireGuard interface: %v", err) + } + + if err := s.ensureWireguardPeers(config.Peers); err != nil { + logger.Error("Failed to ensure WireGuard peers: %v", err) + } +} + +func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { + // Check if the WireGuard interface exists + _, err := netlink.LinkByName(s.interfaceName) + if err != nil { + if _, ok := err.(netlink.LinkNotFoundError); ok { + // Interface doesn't exist, so create it + err = s.createWireGuardInterface() + if err != nil { + logger.Fatal("Failed to create WireGuard interface: %v", err) + } + logger.Info("Created WireGuard interface %s\n", s.interfaceName) + } else { + logger.Fatal("Error checking for WireGuard interface: %v", err) + } + } else { + logger.Info("WireGuard interface %s already exists\n", s.interfaceName) + + // get the exising wireguard port + device, err := s.wgClient.Device(s.interfaceName) + if err != nil { + return fmt.Errorf("failed to get device: %v", err) + } + + // get the existing port + s.Port = uint16(device.ListenPort) + logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port) + + return nil + } + + logger.Info("Assigning IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName) + // Assign IP address to the interface + err = s.assignIPAddress(wgconfig.IpAddress) + if err != nil { + logger.Fatal("Failed to assign IP address: %v", err) + } + + // Check if the interface already exists + _, err = s.wgClient.Device(s.interfaceName) + if err != nil { + return fmt.Errorf("interface %s does not exist", s.interfaceName) + } + + // Parse the private key + key, err := wgtypes.ParseKey(s.key.String()) + if err != nil { + return fmt.Errorf("failed to parse private key: %v", err) + } + + config := wgtypes.Config{ + PrivateKey: &key, + ListenPort: new(int), + } + + // Use the service's fixed port instead of the config port + *config.ListenPort = int(s.Port) + + // Create and configure the WireGuard interface + err = s.wgClient.ConfigureDevice(s.interfaceName, config) + if err != nil { + return fmt.Errorf("failed to configure WireGuard device: %v", err) + } + + // bring up the interface + link, err := netlink.LinkByName(s.interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + + if err := netlink.LinkSetMTU(link, s.mtu); err != nil { + return fmt.Errorf("failed to set MTU: %v", err) + } + + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up interface: %v", err) + } + + // if err := s.ensureMSSClamping(); err != nil { + // logger.Warn("Failed to ensure MSS clamping: %v", err) + // } + + logger.Info("WireGuard interface %s created and configured", s.interfaceName) + + return nil +} + +func (s *WireGuardService) createWireGuardInterface() error { + wgLink := &netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName}, + LinkType: "wireguard", + } + return netlink.LinkAdd(wgLink) +} + +func (s *WireGuardService) assignIPAddress(ipAddress string) error { + link, err := netlink.LinkByName(s.interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + + addr, err := netlink.ParseAddr(ipAddress) + if err != nil { + return fmt.Errorf("failed to parse IP address: %v", err) + } + + return netlink.AddrAdd(link, addr) +} + +func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { + // get the current peers + device, err := s.wgClient.Device(s.interfaceName) + if err != nil { + return fmt.Errorf("failed to get device: %v", err) + } + + // get the peer public keys + var currentPeers []string + for _, peer := range device.Peers { + currentPeers = append(currentPeers, peer.PublicKey.String()) + } + + // remove any peers that are not in the config + for _, peer := range currentPeers { + found := false + for _, configPeer := range peers { + if peer == configPeer.PublicKey { + found = true + break + } + } + if !found { + err := s.removePeer(peer) + if err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + } + } + + // add any peers that are in the config but not in the current peers + for _, configPeer := range peers { + found := false + for _, peer := range currentPeers { + if configPeer.PublicKey == peer { + found = true + break + } + } + if !found { + err := s.addPeer(configPeer) + if err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + } + } + + return nil +} + +func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { + logger.Info("Received message: %v", msg.Data) + var peer Peer + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + } + + if err := json.Unmarshal(jsonData, &peer); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + } + + err = s.addPeer(peer) + if err != nil { + logger.Info("Error adding peer: %v", err) + return + } +} + +func (s *WireGuardService) addPeer(peer Peer) error { + pubKey, err := wgtypes.ParseKey(peer.PublicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + // parse allowed IPs into array of net.IPNet + var allowedIPs []net.IPNet + for _, ipStr := range peer.AllowedIPs { + _, ipNet, err := net.ParseCIDR(ipStr) + if err != nil { + return fmt.Errorf("failed to parse allowed IP: %v", err) + } + allowedIPs = append(allowedIPs, *ipNet) + } + // add keep alive using *time.Duration of 1 second + keepalive := time.Second + + var peerConfig wgtypes.PeerConfig + if peer.Endpoint != "" { + endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint) + if err != nil { + return fmt.Errorf("failed to resolve endpoint address: %w", err) + } + + peerConfig = wgtypes.PeerConfig{ + PublicKey: pubKey, + AllowedIPs: allowedIPs, + PersistentKeepaliveInterval: &keepalive, + Endpoint: endpoint, + } + } else { + peerConfig = wgtypes.PeerConfig{ + PublicKey: pubKey, + AllowedIPs: allowedIPs, + PersistentKeepaliveInterval: &keepalive, + } + logger.Info("Added peer with no endpoint!") + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + + if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + + logger.Info("Peer %s added successfully", peer.PublicKey) + + return nil +} + +func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { + logger.Info("Received message: %v", msg.Data) + // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } + type RemoveRequest struct { + PublicKey string `json:"publicKey"` + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + } + + var request RemoveRequest + if err := json.Unmarshal(jsonData, &request); err != nil { + logger.Info("Error unmarshaling data: %v", err) + return + } + + if err := s.removePeer(request.PublicKey); err != nil { + logger.Info("Error removing peer: %v", err) + return + } +} + +func (s *WireGuardService) removePeer(publicKey string) error { + pubKey, err := wgtypes.ParseKey(publicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + + if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + + logger.Info("Peer %s removed successfully", publicKey) + + return nil +} + +func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { + logger.Info("Received message: %v", msg.Data) + // Define a struct to match the incoming message structure with optional fields + type UpdatePeerRequest struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + } + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + var request UpdatePeerRequest + if err := json.Unmarshal(jsonData, &request); err != nil { + logger.Info("Error unmarshaling peer data: %v", err) + return + } + // First, get the current peer configuration to preserve any unmodified fields + device, err := s.wgClient.Device(s.interfaceName) + if err != nil { + logger.Info("Error getting WireGuard device: %v", err) + return + } + pubKey, err := wgtypes.ParseKey(request.PublicKey) + if err != nil { + logger.Info("Error parsing public key: %v", err) + return + } + // Find the existing peer configuration + var currentPeer *wgtypes.Peer + for _, p := range device.Peers { + if p.PublicKey == pubKey { + currentPeer = &p + break + } + } + if currentPeer == nil { + logger.Info("Peer %s not found, cannot update", request.PublicKey) + return + } + // Create the update peer config + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + UpdateOnly: true, + } + // Keep the default persistent keepalive of 1 second + keepalive := time.Second + peerConfig.PersistentKeepaliveInterval = &keepalive + + // Handle Endpoint field special case + // If Endpoint is included in the request but empty, we want to remove the endpoint + // If Endpoint is not included, we don't modify it + endpointSpecified := false + for key := range msg.Data.(map[string]interface{}) { + if key == "endpoint" { + endpointSpecified = true + break + } + } + + // Only update AllowedIPs if provided in the request + if request.AllowedIPs != nil && len(request.AllowedIPs) > 0 { + var allowedIPs []net.IPNet + for _, ipStr := range request.AllowedIPs { + _, ipNet, err := net.ParseCIDR(ipStr) + if err != nil { + logger.Info("Error parsing allowed IP %s: %v", ipStr, err) + return + } + allowedIPs = append(allowedIPs, *ipNet) + } + peerConfig.AllowedIPs = allowedIPs + peerConfig.ReplaceAllowedIPs = true + logger.Info("Updating AllowedIPs for peer %s", request.PublicKey) + } else if endpointSpecified && request.Endpoint == "" { + peerConfig.ReplaceAllowedIPs = false + } + + if endpointSpecified { + if request.Endpoint != "" { + // Update to new endpoint + endpoint, err := net.ResolveUDPAddr("udp", request.Endpoint) + if err != nil { + logger.Info("Error resolving endpoint address %s: %v", request.Endpoint, err) + return + } + peerConfig.Endpoint = endpoint + logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint) + } else { + // specify any address to listen for any incoming packets + peerConfig.Endpoint = &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + } + logger.Info("Removing Endpoint for peer %s", request.PublicKey) + } + } + + // Apply the configuration update + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { + logger.Info("Error updating peer configuration: %v", err) + return + } + logger.Info("Peer %s updated successfully", request.PublicKey) +} + +func (s *WireGuardService) periodicBandwidthCheck() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if err := s.reportPeerBandwidth(); err != nil { + logger.Info("Failed to report peer bandwidth: %v", err) + } + } +} + +func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { + device, err := s.wgClient.Device(s.interfaceName) + if err != nil { + return nil, fmt.Errorf("failed to get device: %v", err) + } + + peerBandwidths := []PeerBandwidth{} + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + for _, peer := range device.Peers { + publicKey := peer.PublicKey.String() + currentReading := PeerReading{ + BytesReceived: peer.ReceiveBytes, + BytesTransmitted: peer.TransmitBytes, + LastChecked: now, + } + + var bytesInDiff, bytesOutDiff float64 + lastReading, exists := s.lastReadings[publicKey] + + if exists { + timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() + if timeDiff > 0 { + // Calculate bytes transferred since last reading + bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived) + bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted) + + // Handle counter wraparound (if the counter resets or overflows) + if bytesInDiff < 0 { + bytesInDiff = float64(currentReading.BytesReceived) + } + if bytesOutDiff < 0 { + bytesOutDiff = float64(currentReading.BytesTransmitted) + } + + // Convert to MB + bytesInMB := bytesInDiff / (1024 * 1024) + bytesOutMB := bytesOutDiff / (1024 * 1024) + + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: bytesInMB, + BytesOut: bytesOutMB, + }) + } else { + // If readings are too close together or time hasn't passed, report 0 + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: 0, + BytesOut: 0, + }) + } + } else { + // For first reading of a peer, report 0 to establish baseline + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: 0, + BytesOut: 0, + }) + } + + // Update the last reading + s.lastReadings[publicKey] = currentReading + } + + // Clean up old peers + for publicKey := range s.lastReadings { + found := false + for _, peer := range device.Peers { + if peer.PublicKey.String() == publicKey { + found = true + break + } + } + if !found { + delete(s.lastReadings, publicKey) + } + } + + return peerBandwidths, nil +} + +func (s *WireGuardService) reportPeerBandwidth() error { + bandwidths, err := s.calculatePeerBandwidth() + if err != nil { + return fmt.Errorf("failed to calculate peer bandwidth: %v", err) + } + + err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{ + "bandwidthData": bandwidths, + }) + if err != nil { + return fmt.Errorf("failed to send bandwidth data: %v", err) + } + + return nil +} + +func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { + + if s.serverPubKey == "" || s.token == "" { + logger.Debug("Server public key or token not set, skipping UDP hole punch") + return nil + } + + // Parse server address + serverSplit := strings.Split(serverAddr, ":") + if len(serverSplit) < 2 { + return fmt.Errorf("invalid server address format, expected hostname:port") + } + + serverHostname := serverSplit[0] + serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16) + if err != nil { + return fmt.Errorf("failed to parse server port: %v", err) + } + + // Resolve server hostname to IP + serverIPAddr := network.HostToAddr(serverHostname) + if serverIPAddr == nil { + return fmt.Errorf("failed to resolve server hostname") + } + + // Get client IP based on route to server + clientIP := network.GetClientIP(serverIPAddr.IP) + + // Create server and client configs + server := &network.Server{ + Hostname: serverHostname, + Addr: serverIPAddr, + Port: uint16(serverPort), + } + + client := &network.PeerNet{ + IP: clientIP, + Port: s.Port, + NewtID: s.newtId, + } + + // Setup raw connection with BPF filtering + rawConn := network.SetupRawConn(server, client) + defer rawConn.Close() + + // Create JSON payload + payload := struct { + NewtID string `json:"newtId"` + Token string `json:"token"` + }{ + NewtID: s.newtId, + Token: s.token, + } + + // Convert payload to JSON + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %v", err) + } + + // Encrypt the payload using the server's WireGuard public key + encryptedPayload, err := s.encryptPayload(payloadBytes) + if err != nil { + return fmt.Errorf("failed to encrypt payload: %v", err) + } + + // Send the encrypted packet using the raw connection + err = network.SendDataPacket(encryptedPayload, rawConn, server, client) + if err != nil { + return fmt.Errorf("failed to send UDP packet: %v", err) + } + + return nil +} + +func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) { + // Generate an ephemeral keypair for this message + ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) + } + ephemeralPublicKey := ephemeralPrivateKey.PublicKey() + + // Parse the server's public key + serverPubKey, err := wgtypes.ParseKey(s.serverPubKey) + if err != nil { + return nil, fmt.Errorf("failed to parse server public key: %v", err) + } + + // Use X25519 for key exchange (replacing deprecated ScalarMult) + var ephPrivKeyFixed [32]byte + copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) + + // Perform X25519 key exchange + sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) + if err != nil { + return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) + } + + // Create an AEAD cipher using the shared secret + aead, err := chacha20poly1305.New(sharedSecret) + if err != nil { + return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) + } + + // Generate a random nonce + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %v", err) + } + + // Encrypt the payload + ciphertext := aead.Seal(nil, nonce, payload, nil) + + // Prepare the final encrypted message + encryptedMsg := struct { + EphemeralPublicKey string `json:"ephemeralPublicKey"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` + }{ + EphemeralPublicKey: ephemeralPublicKey.String(), + Nonce: nonce, + Ciphertext: ciphertext, + } + + return encryptedMsg, nil +} + +func (s *WireGuardService) keepSendingUDPHolePunch(host string) { + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + + for { + select { + case <-s.stopHolepunch: + logger.Info("Stopping UDP holepunch") + return + case <-ticker.C: + if err := s.sendUDPHolePunch(host + ":21820"); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + } + } +} + +func (s *WireGuardService) removeInterface() error { + // Remove the WireGuard interface + link, err := netlink.LinkByName(s.interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + + err = netlink.LinkDel(link) + if err != nil { + return fmt.Errorf("failed to delete interface: %v", err) + } + + logger.Info("WireGuard interface %s removed successfully", s.interfaceName) + + return nil +} + +func (s *WireGuardService) sendGetConfigMessage() error { + err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{ + "publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()), + "port": s.Port, + }) + if err != nil { + logger.Error("Failed to send get-config message: %v", err) + return err + } + logger.Info("Requesting WireGuard configuration from remote server") + return nil +} + +func (s *WireGuardService) keepSendingGetConfig() { + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + + for { + select { + case <-s.stopGetConfig: + logger.Info("Stopping get-config messages") + return + case <-ticker.C: + if err := s.sendGetConfigMessage(); err != nil { + logger.Error("Failed to send periodic get-config: %v", err) + } + } + } +} diff --git a/wgtester/wgtester.go b/wgtester/wgtester.go new file mode 100644 index 0000000..b302fd4 --- /dev/null +++ b/wgtester/wgtester.go @@ -0,0 +1,164 @@ +package wgtester + +import ( + "encoding/binary" + "fmt" + "net" + "sync" + "time" + + "github.com/fosrl/newt/logger" +) + +const ( + // Magic bytes to identify our packets + magicHeader uint32 = 0xDEADBEEF + // Request packet type + packetTypeRequest uint8 = 1 + // Response packet type + packetTypeResponse uint8 = 2 + // Packet format: + // - 4 bytes: magic header (0xDEADBEEF) + // - 1 byte: packet type (1 = request, 2 = response) + // - 8 bytes: timestamp (for round-trip timing) + packetSize = 13 +) + +// Server handles listening for connection check requests using UDP +type Server struct { + conn *net.UDPConn + serverAddr string + serverPort uint16 + shutdownCh chan struct{} + isRunning bool + runningLock sync.Mutex + newtID string + outputPrefix string +} + +// NewServer creates a new connection test server using UDP +func NewServer(serverAddr string, serverPort uint16, newtID string) *Server { + return &Server{ + serverAddr: serverAddr, + serverPort: serverPort + 1, // use the next port for the server + shutdownCh: make(chan struct{}), + newtID: newtID, + outputPrefix: "[WGTester] ", + } +} + +// Start begins listening for connection test packets using UDP +func (s *Server) Start() error { + s.runningLock.Lock() + defer s.runningLock.Unlock() + + if s.isRunning { + return nil + } + + //create the address to listen on + addr := net.JoinHostPort(s.serverAddr, fmt.Sprintf("%d", s.serverPort)) + + // Create UDP address to listen on + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return err + } + + // Create UDP connection + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return err + } + s.conn = conn + + s.isRunning = true + go s.handleConnections() + + logger.Info("%sServer started on %s:%d", s.outputPrefix, s.serverAddr, s.serverPort) + return nil +} + +// Stop shuts down the server +func (s *Server) Stop() { + s.runningLock.Lock() + defer s.runningLock.Unlock() + + if !s.isRunning { + return + } + + close(s.shutdownCh) + if s.conn != nil { + s.conn.Close() + } + s.isRunning = false + logger.Info(s.outputPrefix + "Server stopped") +} + +// handleConnections processes incoming packets +func (s *Server) handleConnections() { + buffer := make([]byte, 2000) // Buffer large enough for any UDP packet + + for { + select { + case <-s.shutdownCh: + return + default: + // Set read deadline to avoid blocking forever + err := s.conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + if err != nil { + logger.Error(s.outputPrefix+"Error setting read deadline: %v", err) + continue + } + + // Read from UDP connection + n, addr, err := s.conn.ReadFromUDP(buffer) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + // Just a timeout, keep going + continue + } + logger.Error(s.outputPrefix+"Error reading from UDP: %v", err) + continue + } + + // Process packet only if it meets minimum size requirements + if n < packetSize { + continue // Too small to be our packet + } + + // Check magic header + magic := binary.BigEndian.Uint32(buffer[0:4]) + if magic != magicHeader { + continue // Not our packet + } + + // Check packet type + packetType := buffer[4] + if packetType != packetTypeRequest { + continue // Not a request packet + } + + // Create response packet + responsePacket := make([]byte, packetSize) + // Copy the same magic header + binary.BigEndian.PutUint32(responsePacket[0:4], magicHeader) + // Change the packet type to response + responsePacket[4] = packetTypeResponse + // Copy the timestamp (for RTT calculation) + copy(responsePacket[5:13], buffer[5:13]) + + // Log response being sent for debugging + logger.Debug(s.outputPrefix+"Sending response to %s", addr.String()) + + // Send the response packet directly to the source address + _, err = s.conn.WriteToUDP(responsePacket, addr) + if err != nil { + logger.Error(s.outputPrefix+"Error sending response: %v", err) + } else { + logger.Debug(s.outputPrefix + "Response sent successfully") + } + } + } +}