16
16
from proxy .http import Url
17
17
from proxy .core .base import TcpUpstreamConnectionHandler
18
18
from proxy .http .parser import HttpParser
19
- from proxy .http .server import HttpWebServerBasePlugin , httpProtocolTypes
19
+ from proxy .http .server import HttpWebServerBasePlugin
20
20
from proxy .common .utils import text_
21
21
from proxy .http .exception import HttpProtocolException
22
22
from proxy .common .constants import (
23
23
HTTPS_PROTO , DEFAULT_HTTP_PORT , DEFAULT_HTTPS_PORT ,
24
24
DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT ,
25
25
)
26
+ from ...common .types import Readables , Writables , Descriptors
26
27
27
28
28
29
if TYPE_CHECKING : # pragma: no cover
@@ -44,6 +45,11 @@ def __init__(self, *args: Any, **kwargs: Any):
44
45
self .uid , self .flags , self .client , self .event_queue , self .upstream_conn_pool ,
45
46
)
46
47
self .plugins .append (plugin )
48
+ self ._upstream_proxy_pass : Optional [str ] = None
49
+
50
+ def do_upgrade (self , request : HttpParser ) -> bool :
51
+ """Signal web protocol handler to not upgrade websocket requests by default."""
52
+ return False
47
53
48
54
def handle_upstream_data (self , raw : memoryview ) -> None :
49
55
# TODO: Parse response and implement plugin hook per parsed response object
@@ -54,8 +60,8 @@ def routes(self) -> List[Tuple[int, str]]:
54
60
r = []
55
61
for plugin in self .plugins :
56
62
for route in plugin .regexes ():
57
- r . append (( httpProtocolTypes . HTTP , route ))
58
- r .append ((httpProtocolTypes . HTTPS , route ))
63
+ for proto in plugin . protocols ():
64
+ r .append ((proto , route ))
59
65
return r
60
66
61
67
def handle_request (self , request : HttpParser ) -> None :
@@ -66,59 +72,123 @@ def handle_request(self, request: HttpParser) -> None:
66
72
raise HttpProtocolException ('before_routing closed connection' )
67
73
request = r
68
74
75
+ needs_upstream = False
76
+
69
77
# routes
70
78
for plugin in self .plugins :
71
79
for route in plugin .routes ():
80
+ # Static routes
72
81
if isinstance (route , tuple ):
73
82
pattern = re .compile (route [0 ])
74
83
if pattern .match (text_ (request .path )):
75
84
self .choice = Url .from_bytes (
76
85
random .choice (route [1 ]),
77
86
)
78
87
break
88
+ # Dynamic routes
79
89
elif isinstance (route , str ):
80
90
pattern = re .compile (route )
81
91
if pattern .match (text_ (request .path )):
82
- self .choice = plugin .handle_route (request , pattern )
92
+ choice = plugin .handle_route (request , pattern )
93
+ if isinstance (choice , Url ):
94
+ self .choice = choice
95
+ needs_upstream = True
96
+ self ._upstream_proxy_pass = str (self .choice )
97
+ elif isinstance (choice , memoryview ):
98
+ self .client .queue (choice )
99
+ self ._upstream_proxy_pass = '{0} bytes' .format (len (choice ))
100
+ else :
101
+ self .upstream = choice
102
+ self ._upstream_proxy_pass = '{0}:{1}' .format (
103
+ * self .upstream .addr ,
104
+ )
83
105
break
84
106
else :
85
107
raise ValueError ('Invalid route' )
86
108
87
- assert self .choice and self .choice .hostname
88
- port = self .choice .port or \
89
- DEFAULT_HTTP_PORT \
90
- if self .choice .scheme == b'http' \
91
- else DEFAULT_HTTPS_PORT
92
- self .initialize_upstream (text_ (self .choice .hostname ), port )
93
- assert self .upstream
94
- try :
95
- self .upstream .connect ()
96
- if self .choice .scheme == HTTPS_PROTO :
97
- self .upstream .wrap (
98
- text_ (
99
- self .choice .hostname ,
109
+ if needs_upstream :
110
+ assert self .choice and self .choice .hostname
111
+ port = (
112
+ self .choice .port or DEFAULT_HTTP_PORT
113
+ if self .choice .scheme == b'http'
114
+ else DEFAULT_HTTPS_PORT
115
+ )
116
+ self .initialize_upstream (text_ (self .choice .hostname ), port )
117
+ assert self .upstream
118
+ try :
119
+ self .upstream .connect ()
120
+ if self .choice .scheme == HTTPS_PROTO :
121
+ self .upstream .wrap (
122
+ text_ (
123
+ self .choice .hostname ,
124
+ ),
125
+ as_non_blocking = True ,
126
+ ca_file = self .flags .ca_file ,
127
+ )
128
+ request .path = self .choice .remainder
129
+ self .upstream .queue (memoryview (request .build ()))
130
+ except ConnectionRefusedError :
131
+ raise HttpProtocolException ( # pragma: no cover
132
+ 'Connection refused by upstream server {0}:{1}' .format (
133
+ text_ (self .choice .hostname ),
134
+ port ,
100
135
),
101
- as_non_blocking = True ,
102
- ca_file = self .flags .ca_file ,
103
136
)
104
- request .path = self .choice .remainder
105
- self .upstream .queue (memoryview (request .build ()))
106
- except ConnectionRefusedError :
107
- raise HttpProtocolException ( # pragma: no cover
108
- 'Connection refused by upstream server {0}:{1}' .format (
109
- text_ (self .choice .hostname ), port ,
110
- ),
111
- )
112
137
113
138
def on_client_connection_close (self ) -> None :
114
139
if self .upstream and not self .upstream .closed :
115
140
logger .debug ('Closing upstream server connection' )
116
141
self .upstream .close ()
117
142
self .upstream = None
118
143
144
+ def on_client_data (
145
+ self ,
146
+ request : HttpParser ,
147
+ raw : memoryview ,
148
+ ) -> Optional [memoryview ]:
149
+ if request .is_websocket_upgrade :
150
+ assert self .upstream
151
+ self .upstream .queue (raw )
152
+ return raw
153
+
119
154
def on_access_log (self , context : Dict [str , Any ]) -> Optional [Dict [str , Any ]]:
120
- context .update ({
121
- 'upstream_proxy_pass' : str (self .choice ) if self .choice else None ,
122
- })
123
- logger .info (DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT .format_map (context ))
155
+ context .update (
156
+ {
157
+ 'upstream_proxy_pass' : self ._upstream_proxy_pass ,
158
+ },
159
+ )
160
+ log_handled = False
161
+ for plugin in self .plugins :
162
+ ctx = plugin .on_access_log (context )
163
+ if ctx is None :
164
+ log_handled = True
165
+ break
166
+ context = ctx
167
+ if not log_handled :
168
+ logger .info (DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT .format_map (context ))
124
169
return None
170
+
171
+ async def get_descriptors (self ) -> Descriptors :
172
+ r , w = await super ().get_descriptors ()
173
+ # TODO(abhinavsingh): We need to keep a mapping of plugin and
174
+ # descriptors registered by them, so that within write/read blocks
175
+ # we can invoke the right plugin callbacks.
176
+ for plugin in self .plugins :
177
+ plugin_read_desc , plugin_write_desc = await plugin .get_descriptors ()
178
+ r .extend (plugin_read_desc )
179
+ w .extend (plugin_write_desc )
180
+ return r , w
181
+
182
+ async def read_from_descriptors (self , r : Readables ) -> bool :
183
+ for plugin in self .plugins :
184
+ teardown = await plugin .read_from_descriptors (r )
185
+ if teardown :
186
+ return True
187
+ return await super ().read_from_descriptors (r )
188
+
189
+ async def write_to_descriptors (self , w : Writables ) -> bool :
190
+ for plugin in self .plugins :
191
+ teardown = await plugin .write_to_descriptors (w )
192
+ if teardown :
193
+ return True
194
+ return await super ().write_to_descriptors (w )
0 commit comments