@@ -2,6 +2,7 @@ package api
22
33import (
44 "bufio"
5+ "context"
56 "fmt"
67 "io"
78 "net"
@@ -12,7 +13,7 @@ import (
1213 "time"
1314)
1415
15- // createReverseProxy creates a reverse proxy for a tunnel
16+ // createReverseProxy creates a reverse proxy for a tunnel using netstack
1617func (s * Server ) createReverseProxy (targetIP string , port uint16 ) * httputil.ReverseProxy {
1718 target := & url.URL {
1819 Scheme : "http" ,
@@ -21,13 +22,12 @@ func (s *Server) createReverseProxy(targetIP string, port uint16) *httputil.Reve
2122
2223 proxy := httputil .NewSingleHostReverseProxy (target )
2324
24- // Customize the transport for better performance
25+ // Get netstack from tunnel for userspace networking
26+ tnet := s .tun .GetNetstack ()
27+
28+ // Customize the transport to use netstack (userspace WireGuard networking)
2529 proxy .Transport = & http.Transport {
26- Proxy : http .ProxyFromEnvironment ,
27- DialContext : (& net.Dialer {
28- Timeout : 30 * time .Second ,
29- KeepAlive : 30 * time .Second ,
30- }).DialContext ,
30+ DialContext : tnet .DialContext , // Use netstack instead of kernel networking
3131 ForceAttemptHTTP2 : true ,
3232 MaxIdleConns : 100 ,
3333 IdleConnTimeout : 90 * time .Second ,
@@ -75,21 +75,29 @@ func (s *Server) createReverseProxy(targetIP string, port uint16) *httputil.Reve
7575 return proxy
7676}
7777
78- // handleTunnelTrafficWithProxy handles incoming traffic and proxies it to the tunnel
79- func ( s * Server ) handleTunnelTrafficWithProxy ( w http. ResponseWriter , r * http. Request ) {
80- // Extract subdomain from host
81- host := r . Host
82- if idx := strings .Index (host , ":" ); idx != - 1 {
78+ // extractSubdomain extracts the subdomain from a host header value.
79+ // It handles port stripping and returns just the subdomain portion.
80+ func extractSubdomain ( host string ) string {
81+ // Remove port if present
82+ if idx := strings .IndexByte (host , ':' ); idx != - 1 {
8383 host = host [:idx ]
8484 }
8585
86- parts := strings .Split (host , "." )
87- if len (parts ) < 2 {
88- http .Error (w , "Invalid host" , http .StatusBadRequest )
86+ // Extract subdomain (first part before first dot)
87+ if idx := strings .IndexByte (host , '.' ); idx != - 1 {
88+ return host [:idx ]
89+ }
90+ return host
91+ }
92+
93+ // handleTunnelTrafficWithProxy handles incoming traffic and proxies it to the tunnel
94+ func (s * Server ) handleTunnelTrafficWithProxy (w http.ResponseWriter , r * http.Request ) {
95+ // Extract subdomain from host
96+ subdomain := extractSubdomain (r .Host )
97+ if subdomain == "" {
98+ http .Error (w , "Invalid host header" , http .StatusBadRequest )
8999 return
90100 }
91-
92- subdomain := parts [0 ]
93101 tunnel := s .registry .GetTunnelBySubdomain (subdomain )
94102 if tunnel == nil {
95103 http .Error (w , "Tunnel not found" , http .StatusNotFound )
@@ -121,7 +129,7 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request, targetI
121129 targetURL += "?" + r .URL .RawQuery
122130 }
123131
124- targetConn , resp , err := websocketDial (targetURL , r .Header )
132+ targetConn , resp , err := s . websocketDial (targetURL , r .Header )
125133 if err != nil {
126134 s .logger .Error ("websocket dial error" , "error" , err , "target" , targetURL )
127135 http .Error (w , "Bad Gateway" , http .StatusBadGateway )
@@ -150,31 +158,47 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request, targetI
150158 return
151159 }
152160
153- // Proxy data between connections
161+ // Use context for proper cancellation
162+ ctx , cancel := context .WithCancel (r .Context ())
163+ defer cancel ()
164+
165+ // Proxy data between connections with proper cleanup
154166 errc := make (chan error , 2 )
155167 go func () {
168+ defer cancel () // Cancel context when one direction completes
156169 _ , err := io .Copy (targetConn , clientConn )
157170 errc <- err
158171 }()
159172 go func () {
173+ defer cancel () // Cancel context when one direction completes
160174 _ , err := io .Copy (clientConn , targetConn )
161175 errc <- err
162176 }()
163177
164- // Wait for either copy to complete
165- <- errc
178+ // Wait for either copy to complete or context cancellation
179+ select {
180+ case <- ctx .Done ():
181+ return
182+ case <- errc :
183+ return
184+ }
166185}
167186
168- // websocketDial dials a WebSocket connection
169- func websocketDial (targetURL string , headers http.Header ) (net.Conn , * http.Response , error ) {
187+ // websocketDial dials a WebSocket connection using the tunnel's netstack
188+ func ( s * Server ) websocketDial (targetURL string , headers http.Header ) (net.Conn , * http.Response , error ) {
170189 // Parse the URL
171190 u , err := url .Parse (targetURL )
172191 if err != nil {
173192 return nil , nil , err
174193 }
175194
176- // Dial TCP connection
177- conn , err := net .DialTimeout ("tcp" , u .Host , 10 * time .Second )
195+ // Get netstack from tunnel for userspace networking
196+ tnet := s .tun .GetNetstack ()
197+
198+ // Dial TCP connection using netstack (userspace WireGuard networking)
199+ ctx , cancel := context .WithTimeout (context .Background (), 10 * time .Second )
200+ defer cancel ()
201+ conn , err := tnet .DialContext (ctx , "tcp" , u .Host )
178202 if err != nil {
179203 return nil , nil , err
180204 }
0 commit comments