Skip to content

Commit 67706ac

Browse files
Reverse proxy ability to return Url, memoryview or TcpServerConnection object (#1397)
* Reverse proxy enhancements
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 380e0cc commit 67706ac

File tree

6 files changed

+194
-67
lines changed

6 files changed

+194
-67
lines changed

proxy/core/acceptor/pool.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,17 @@ def setup(self) -> None:
9898
"""Setup acceptors."""
9999
self._start()
100100
execution_mode = (
101-
'threadless (local)'
102-
if self.flags.local_executor
103-
else 'threadless (remote)'
104-
) if self.flags.threadless else 'threaded'
105-
logger.info(
106-
'Started %d acceptors in %s mode' % (
101+
(
102+
'threadless (local)'
103+
if self.flags.local_executor
104+
else 'threadless (remote)'
105+
)
106+
if self.flags.threadless
107+
else 'threaded'
108+
)
109+
logger.debug(
110+
'Started %d acceptors in %s mode'
111+
% (
107112
self.flags.num_acceptors,
108113
execution_mode,
109114
),
@@ -122,7 +127,7 @@ def setup(self) -> None:
122127
self.fd_queues[index].close()
123128

124129
def shutdown(self) -> None:
125-
logger.info('Shutting down %d acceptors' % self.flags.num_acceptors)
130+
logger.debug('Shutting down %d acceptors' % self.flags.num_acceptors)
126131
for acceptor in self.acceptors:
127132
acceptor.running.set()
128133
for acceptor in self.acceptors:

proxy/core/listener/tcp.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ def listen(self) -> socket.socket:
9292
sock.listen(self.flags.backlog)
9393
sock.setblocking(False)
9494
self._port = sock.getsockname()[1]
95-
logger.info(
96-
'Listening on %s:%s' %
97-
(self.hostname, self._port),
95+
logger.debug(
96+
'Listening on %s:%s' % (self.hostname, self._port),
9897
)
9998
return sock

proxy/http/server/plugin.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@
2222
from ..descriptors import DescriptorsHandlerMixin
2323
from ...common.types import RePattern
2424
from ...common.utils import bytes_
25+
from ...http.server.protocols import httpProtocolTypes
2526

2627

2728
if TYPE_CHECKING: # pragma: no cover
28-
from ...core.connection import UpstreamConnectionPool
29+
from ...core.connection import TcpServerConnection, UpstreamConnectionPool
2930

3031

3132
class HttpWebServerBasePlugin(DescriptorsHandlerMixin, ABC):
@@ -64,7 +65,7 @@ def serve_static_file(path: str, min_compression_length: int) -> memoryview:
6465
# TODO: Should we really close or take advantage of keep-alive?
6566
conn_close=True,
6667
)
67-
except FileNotFoundError:
68+
except OSError:
6869
return NOT_FOUND_RESPONSE_PKT
6970

7071
def name(self) -> str:
@@ -88,6 +89,17 @@ def on_client_connection_close(self) -> None:
8889
"""Client has closed the connection, do any clean up task now."""
8990
pass
9091

92+
def do_upgrade(self, request: HttpParser) -> bool:
93+
return True
94+
95+
def on_client_data(
96+
self,
97+
request: HttpParser,
98+
raw: memoryview,
99+
) -> Optional[memoryview]:
100+
"""Return None to avoid default webserver parsing of client data."""
101+
return raw
102+
91103
# No longer abstract since v2.4.0
92104
#
93105
# @abstractmethod
@@ -125,7 +137,7 @@ def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
125137
return context
126138

127139

128-
class ReverseProxyBasePlugin(ABC):
140+
class ReverseProxyBasePlugin(DescriptorsHandlerMixin, ABC):
129141
"""ReverseProxy base plugin class."""
130142

131143
def __init__(
@@ -161,13 +173,24 @@ def routes(self) -> List[Union[str, Tuple[str, List[bytes]]]]:
161173
must return the url to serve."""
162174
raise NotImplementedError() # pragma: no cover
163175

176+
def protocols(self) -> List[int]:
177+
return [
178+
httpProtocolTypes.HTTP,
179+
httpProtocolTypes.HTTPS,
180+
httpProtocolTypes.WEBSOCKET,
181+
]
182+
164183
def before_routing(self, request: HttpParser) -> Optional[HttpParser]:
165184
"""Plugins can modify request, return response, close connection.
166185
167186
If None is returned, request will be dropped and closed."""
168187
return request # pragma: no cover
169188

170-
def handle_route(self, request: HttpParser, pattern: RePattern) -> Url:
189+
def handle_route(
190+
self,
191+
request: HttpParser,
192+
pattern: RePattern,
193+
) -> Union[memoryview, Url, 'TcpServerConnection']:
171194
"""Implement this method if you have configured dynamic routes."""
172195
raise NotImplementedError()
173196

@@ -182,3 +205,13 @@ def regexes(self) -> List[str]:
182205
else:
183206
raise ValueError('Invalid route type')
184207
return routes
208+
209+
def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
210+
"""Use this method to override default access log format (see
211+
DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT) or to add/update/modify passed context
212+
for usage by default access logger.
213+
214+
Return updated log context to use for default logging format, OR
215+
Return None if plugin has logged the request.
216+
"""
217+
return context

proxy/http/server/reverse.py

Lines changed: 101 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
from proxy.http import Url
1717
from proxy.core.base import TcpUpstreamConnectionHandler
1818
from proxy.http.parser import HttpParser
19-
from proxy.http.server import HttpWebServerBasePlugin, httpProtocolTypes
19+
from proxy.http.server import HttpWebServerBasePlugin
2020
from proxy.common.utils import text_
2121
from proxy.http.exception import HttpProtocolException
2222
from proxy.common.constants import (
2323
HTTPS_PROTO, DEFAULT_HTTP_PORT, DEFAULT_HTTPS_PORT,
2424
DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT,
2525
)
26+
from ...common.types import Readables, Writables, Descriptors
2627

2728

2829
if TYPE_CHECKING: # pragma: no cover
@@ -44,6 +45,11 @@ def __init__(self, *args: Any, **kwargs: Any):
4445
self.uid, self.flags, self.client, self.event_queue, self.upstream_conn_pool,
4546
)
4647
self.plugins.append(plugin)
48+
self._upstream_proxy_pass: Optional[str] = None
49+
50+
def do_upgrade(self, request: HttpParser) -> bool:
51+
"""Signal web protocol handler to not upgrade websocket requests by default."""
52+
return False
4753

4854
def handle_upstream_data(self, raw: memoryview) -> None:
4955
# TODO: Parse response and implement plugin hook per parsed response object
@@ -54,8 +60,8 @@ def routes(self) -> List[Tuple[int, str]]:
5460
r = []
5561
for plugin in self.plugins:
5662
for route in plugin.regexes():
57-
r.append((httpProtocolTypes.HTTP, route))
58-
r.append((httpProtocolTypes.HTTPS, route))
63+
for proto in plugin.protocols():
64+
r.append((proto, route))
5965
return r
6066

6167
def handle_request(self, request: HttpParser) -> None:
@@ -66,59 +72,123 @@ def handle_request(self, request: HttpParser) -> None:
6672
raise HttpProtocolException('before_routing closed connection')
6773
request = r
6874

75+
needs_upstream = False
76+
6977
# routes
7078
for plugin in self.plugins:
7179
for route in plugin.routes():
80+
# Static routes
7281
if isinstance(route, tuple):
7382
pattern = re.compile(route[0])
7483
if pattern.match(text_(request.path)):
7584
self.choice = Url.from_bytes(
7685
random.choice(route[1]),
7786
)
7887
break
88+
# Dynamic routes
7989
elif isinstance(route, str):
8090
pattern = re.compile(route)
8191
if pattern.match(text_(request.path)):
82-
self.choice = plugin.handle_route(request, pattern)
92+
choice = plugin.handle_route(request, pattern)
93+
if isinstance(choice, Url):
94+
self.choice = choice
95+
needs_upstream = True
96+
self._upstream_proxy_pass = str(self.choice)
97+
elif isinstance(choice, memoryview):
98+
self.client.queue(choice)
99+
self._upstream_proxy_pass = '{0} bytes'.format(len(choice))
100+
else:
101+
self.upstream = choice
102+
self._upstream_proxy_pass = '{0}:{1}'.format(
103+
*self.upstream.addr,
104+
)
83105
break
84106
else:
85107
raise ValueError('Invalid route')
86108

87-
assert self.choice and self.choice.hostname
88-
port = self.choice.port or \
89-
DEFAULT_HTTP_PORT \
90-
if self.choice.scheme == b'http' \
91-
else DEFAULT_HTTPS_PORT
92-
self.initialize_upstream(text_(self.choice.hostname), port)
93-
assert self.upstream
94-
try:
95-
self.upstream.connect()
96-
if self.choice.scheme == HTTPS_PROTO:
97-
self.upstream.wrap(
98-
text_(
99-
self.choice.hostname,
109+
if needs_upstream:
110+
assert self.choice and self.choice.hostname
111+
port = (
112+
self.choice.port or DEFAULT_HTTP_PORT
113+
if self.choice.scheme == b'http'
114+
else DEFAULT_HTTPS_PORT
115+
)
116+
self.initialize_upstream(text_(self.choice.hostname), port)
117+
assert self.upstream
118+
try:
119+
self.upstream.connect()
120+
if self.choice.scheme == HTTPS_PROTO:
121+
self.upstream.wrap(
122+
text_(
123+
self.choice.hostname,
124+
),
125+
as_non_blocking=True,
126+
ca_file=self.flags.ca_file,
127+
)
128+
request.path = self.choice.remainder
129+
self.upstream.queue(memoryview(request.build()))
130+
except ConnectionRefusedError:
131+
raise HttpProtocolException( # pragma: no cover
132+
'Connection refused by upstream server {0}:{1}'.format(
133+
text_(self.choice.hostname),
134+
port,
100135
),
101-
as_non_blocking=True,
102-
ca_file=self.flags.ca_file,
103136
)
104-
request.path = self.choice.remainder
105-
self.upstream.queue(memoryview(request.build()))
106-
except ConnectionRefusedError:
107-
raise HttpProtocolException( # pragma: no cover
108-
'Connection refused by upstream server {0}:{1}'.format(
109-
text_(self.choice.hostname), port,
110-
),
111-
)
112137

113138
def on_client_connection_close(self) -> None:
114139
if self.upstream and not self.upstream.closed:
115140
logger.debug('Closing upstream server connection')
116141
self.upstream.close()
117142
self.upstream = None
118143

144+
def on_client_data(
145+
self,
146+
request: HttpParser,
147+
raw: memoryview,
148+
) -> Optional[memoryview]:
149+
if request.is_websocket_upgrade:
150+
assert self.upstream
151+
self.upstream.queue(raw)
152+
return raw
153+
119154
def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
120-
context.update({
121-
'upstream_proxy_pass': str(self.choice) if self.choice else None,
122-
})
123-
logger.info(DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT.format_map(context))
155+
context.update(
156+
{
157+
'upstream_proxy_pass': self._upstream_proxy_pass,
158+
},
159+
)
160+
log_handled = False
161+
for plugin in self.plugins:
162+
ctx = plugin.on_access_log(context)
163+
if ctx is None:
164+
log_handled = True
165+
break
166+
context = ctx
167+
if not log_handled:
168+
logger.info(DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT.format_map(context))
124169
return None
170+
171+
async def get_descriptors(self) -> Descriptors:
172+
r, w = await super().get_descriptors()
173+
# TODO(abhinavsingh): We need to keep a mapping of plugin and
174+
# descriptors registered by them, so that within write/read blocks
175+
# we can invoke the right plugin callbacks.
176+
for plugin in self.plugins:
177+
plugin_read_desc, plugin_write_desc = await plugin.get_descriptors()
178+
r.extend(plugin_read_desc)
179+
w.extend(plugin_write_desc)
180+
return r, w
181+
182+
async def read_from_descriptors(self, r: Readables) -> bool:
183+
for plugin in self.plugins:
184+
teardown = await plugin.read_from_descriptors(r)
185+
if teardown:
186+
return True
187+
return await super().read_from_descriptors(r)
188+
189+
async def write_to_descriptors(self, w: Writables) -> bool:
190+
for plugin in self.plugins:
191+
teardown = await plugin.write_to_descriptors(w)
192+
if teardown:
193+
return True
194+
return await super().write_to_descriptors(w)

0 commit comments

Comments
 (0)