diff --git a/proxy/core/base/tcp_server.py b/proxy/core/base/tcp_server.py
index 4d32949d55..3fa9eda2e0 100644
--- a/proxy/core/base/tcp_server.py
+++ b/proxy/core/base/tcp_server.py
@@ -104,7 +104,7 @@ class BaseTcpServerHandler(Work[T]):
     is ready to accept new data before flushing data to it.
 
     Most importantly, BaseTcpServerHandler ensures that pending buffers
-    to the client are flushed before connection is closed.
+    to the client are flushed before the connection with the client is closed.
 
     Implementations must provide::
@@ -170,9 +170,9 @@ async def handle_events(
     async def handle_writables(self, writables: Writables) -> bool:
         teardown = False
         if self.work.connection.fileno() in writables and self.work.has_buffer():
-            logger.debug(
-                'Flushing buffer to client {0}'.format(self.work.address),
-            )
+            # logger.debug(
+            #     'Flushing buffer to client {0}'.format(self.work.address),
+            # )
             self.work.flush(self.flags.max_sendbuf_size)
             if self.must_flush_before_shutdown is True and \
                     not self.work.has_buffer():
diff --git a/proxy/core/connection/connection.py b/proxy/core/connection/connection.py
index 2200f93e6c..9fb53009d0 100644
--- a/proxy/core/connection/connection.py
+++ b/proxy/core/connection/connection.py
@@ -14,11 +14,13 @@
 from .types import tcpConnectionTypes
 from ...common.types import TcpOrTlsSocket
+from ...common.leakage import Leakage
 from ...common.constants import DEFAULT_BUFFER_SIZE, DEFAULT_MAX_SEND_SIZE
 
 
 logger = logging.getLogger(__name__)
 
 
+EMPTY_MV = memoryview(b'')
+
+
 class TcpConnectionUninitializedException(Exception):
     pass
@@ -34,12 +36,19 @@ class TcpConnection(ABC):
     a socket connection object.
     """
 
-    def __init__(self, tag: int) -> None:
+    def __init__(
+        self,
+        tag: int,
+        flush_bps: int = 512,
+        recv_bps: int = 512,
+    ) -> None:
         self.tag: str = 'server' if tag == tcpConnectionTypes.SERVER else 'client'
         self.buffer: List[memoryview] = []
         self.closed: bool = False
         self._reusable: bool = False
         self._num_buffer = 0
+        self._flush_leakage = Leakage(rate=flush_bps) if flush_bps > 0 else None
+        self._recv_leakage = Leakage(rate=recv_bps) if recv_bps > 0 else None
 
     @property
     @abstractmethod
@@ -55,14 +64,20 @@ def recv(
         self,
         buffer_size: int = DEFAULT_BUFFER_SIZE,
     ) -> Optional[memoryview]:
         """Users must handle socket.error exceptions"""
+        if self._recv_leakage is not None:
+            allowed_bytes = self._recv_leakage.consume(buffer_size)
+            if allowed_bytes == 0:
+                return EMPTY_MV
+            buffer_size = min(buffer_size, allowed_bytes)
         data: bytes = self.connection.recv(buffer_size)
-        if len(data) == 0:
+        size = len(data)
+        unused = buffer_size - size
+        if self._recv_leakage is not None and unused > 0:
+            self._recv_leakage.release(unused)
+        if size == 0:
             return None
-        logger.debug(
-            'received %d bytes from %s' %
-            (len(data), self.tag),
-        )
-        # logger.info(data)
+        logger.debug('received %d bytes from %s' % (size, self.tag))
+        logger.info(data)
         return memoryview(data)
 
     def close(self) -> bool:
@@ -75,6 +90,8 @@ def has_buffer(self) -> bool:
         return self._num_buffer != 0
 
     def queue(self, mv: memoryview) -> None:
+        if len(mv) == 0:
+            return
         self.buffer.append(mv)
         self._num_buffer += 1
@@ -86,18 +103,32 @@ def flush(self, max_send_size: Optional[int] = None) -> int:
         # TODO: Assemble multiple packets if total
         # size remains below max send size.
         max_send_size = max_send_size or DEFAULT_MAX_SEND_SIZE
-        try:
-            sent: int = self.send(mv[:max_send_size])
-        except BlockingIOError:
-            logger.warning('BlockingIOError when trying send to {0}'.format(self.tag))
-            return 0
+        allowed_bytes = (
+            self._flush_leakage.consume(min(len(mv), max_send_size))
+            if self._flush_leakage is not None
+            else max_send_size
+        )
+        sent: int = 0
+        if allowed_bytes > 0:
+            try:
+                sent = self.send(mv[:allowed_bytes])
+            except BlockingIOError:
+                logger.warning(
+                    'BlockingIOError when trying send to {0}'.format(self.tag),
+                )
+                del mv
+                return 0
+            finally:
+                unused = allowed_bytes - sent
+                if self._flush_leakage is not None and unused > 0:
+                    self._flush_leakage.release(unused)
         if sent == len(mv):
             self.buffer.pop(0)
             self._num_buffer -= 1
         else:
             self.buffer[0] = mv[sent:]
         logger.debug('flushed %d bytes to %s' % (sent, self.tag))
-        # logger.info(mv[:sent].tobytes())
+        logger.info(mv[:sent].tobytes())
         del mv
         return sent
diff --git a/proxy/core/work/fd/fd.py b/proxy/core/work/fd/fd.py
index a13024769c..e17a33ae7a 100644
--- a/proxy/core/work/fd/fd.py
+++ b/proxy/core/work/fd/fd.py
@@ -41,6 +41,7 @@ def work(self, *args: Any) -> None:
             publisher_id=self.__class__.__qualname__,
         )
         try:
+            logger.debug('Initializing work#{0}'.format(fileno))
             self.works[fileno].initialize()
             self._total += 1
         except Exception as e:
@@ -48,7 +49,7 @@ def work(self, *args: Any) -> None:
                 'Exception occurred during initialization',
                 exc_info=e,
             )
-            self._cleanup(fileno)
+            self._cleanup(fileno, 'error')
 
     @property
     @abstractmethod
diff --git a/proxy/core/work/threadless.py b/proxy/core/work/threadless.py
index e0d722c191..3587cb9d2f 100644
--- a/proxy/core/work/threadless.py
+++ b/proxy/core/work/threadless.py
@@ -318,16 +318,18 @@ def _cleanup_inactive(self) -> None:
             if self.works[work_id].is_inactive():
                 inactive_works.append(work_id)
         for work_id in inactive_works:
-            self._cleanup(work_id)
+            self._cleanup(work_id, 'inactive')
 
     # TODO: HttpProtocolHandler.shutdown can call flush which may block
-    def _cleanup(self, work_id: int) -> None:
+    def _cleanup(self, work_id: int, reason: str) -> None:
         if work_id in self.registered_events_by_work_ids:
             assert self.selector
             for fileno in self.registered_events_by_work_ids[work_id]:
                 logger.debug(
-                    'fd#{0} unregistered by work#{1}'.format(
-                        fileno, work_id,
+                    'fd#{0} unregistered by work#{1}, reason: {2}'.format(
+                        fileno,
+                        work_id,
+                        reason,
                     ),
                 )
                 self.selector.unregister(fileno)
@@ -384,7 +386,7 @@ async def _run_once(self) -> bool:
             return False
         # Invoke Threadless.handle_events
         self.unfinished.update(self._create_tasks(work_by_ids))
-        # logger.debug('Executing {0} works'.format(len(self.unfinished)))
+        # logger.debug("Executing {0} works".format(len(self.unfinished)))
         # Cleanup finished tasks
         for task in await self._wait_for_tasks():
             # Checking for result can raise exception e.g.
@@ -398,11 +400,12 @@ async def _run_once(self) -> bool:
                 teardown = True
             finally:
                 if teardown:
-                    self._cleanup(work_id)
+                    self._cleanup(work_id, 'teardown')
                 # self.cleanup(int(task.get_name()))
         # logger.debug(
-        #     'Done executing works, {0} pending, {1} registered'.format(
-        #         len(self.unfinished), len(self.registered_events_by_work_ids),
+        #     "Done executing works, {0} pending, {1} registered".format(
+        #         len(self.unfinished),
+        #         len(self.registered_events_by_work_ids),
         #     ),
         # )
         return False
diff --git a/proxy/http/handler.py b/proxy/http/handler.py
index 5120d5b32c..0b69087dc2 100644
--- a/proxy/http/handler.py
+++ b/proxy/http/handler.py
@@ -190,7 +190,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
 
     async def handle_writables(self, writables: Writables) -> bool:
         if self.work.connection.fileno() in writables and self.work.has_buffer():
-            logger.debug('Client is write ready, flushing...')
+            # logger.debug('Client is write ready, flushing...')
             self.last_activity = time.time()
             # TODO(abhinavsingh): This hook could just reside within server recv block
             # instead of invoking when flushed to client.
@@ -219,7 +219,7 @@ async def handle_writables(self, writables: Writables) -> bool:
 
     async def handle_readables(self, readables: Readables) -> bool:
         if self.work.connection.fileno() in readables:
-            logger.debug('Client is read ready, receiving...')
+            # logger.debug('Client is read ready, receiving...')
             self.last_activity = time.time()
             try:
                 teardown = await super().handle_readables(readables)
diff --git a/proxy/http/server/reverse.py b/proxy/http/server/reverse.py
index d1d5e631da..34bf569825 100644
--- a/proxy/http/server/reverse.py
+++ b/proxy/http/server/reverse.py
@@ -85,6 +85,9 @@ def handle_request(self, request: HttpParser) -> None:
                         random.choice(route[1]),
                     )
                     needs_upstream = True
+                    logger.debug(
+                        'Starting connection to upstream {0}'.format(self.choice),
+                    )
                     break
             # Dynamic routes
             elif isinstance(route, str):
@@ -95,14 +98,23 @@ def handle_request(self, request: HttpParser) -> None:
                         self.choice = choice
                         needs_upstream = True
                         self._upstream_proxy_pass = str(self.choice)
+                        logger.debug(
+                            'Starting connection to upstream {0}'.format(choice),
+                        )
                     elif isinstance(choice, memoryview):
                         self.client.queue(choice)
                         self._upstream_proxy_pass = '{0} bytes'.format(len(choice))
+                        logger.debug('Sending raw response to client')
                     else:
                         self.upstream = choice
                         self._upstream_proxy_pass = '{0}:{1}'.format(
                             *self.upstream.addr,
                         )
+                        logger.debug(
+                            'Using existing connection to upstream {0}'.format(
+                                self.upstream.addr,
+                            ),
+                        )
                     break
             else:
                 raise ValueError('Invalid route')
diff --git a/tests/common/test_leakage.py b/tests/common/test_leakage.py
index 564139cfac..fe451733c1 100644
--- a/tests/common/test_leakage.py
+++ b/tests/common/test_leakage.py
@@ -22,8 +22,7 @@ def test_initial_consume_no_tokens(self) -> None:
         rate = 100  # bytes per second
         bucket = Leakage(rate)
         self.assertEqual(
-            bucket.consume(150),
-            100,
+            bucket.consume(150), 100,
         )  # No tokens yet, so expect 0 bytes to be sent
 
     def test_consume_with_refill(self) -> None:
diff --git a/tests/http/parser/test_http_parser.py b/tests/http/parser/test_http_parser.py
index 427553c310..a5254c86d7 100644
--- a/tests/http/parser/test_http_parser.py
+++ b/tests/http/parser/test_http_parser.py
@@ -908,3 +908,17 @@ def test_cannot_parse_sip_protocol(self) -> None:
                 b'\r\n',
             ),
         )
+
+    def test_byte_by_byte(self) -> None:
+        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
+        request = [
+            # pylint: disable=line-too-long
+            b'HTTP/1.1 200 OK\r\naccess-control-allow-credentials: true\r\naccess-control-allow-origin: *\r\ncontent-type: application/json; charset=utf-8\r\ndate: Thu, 14 Nov 2024 10:24:11 GMT\r\ncontent-length: 550\r\nserver: Fly/a40a59d0 (2024-11-12)\r\nvia: 1.1 fly.io\r\nfly-request-id: 01JCN37CEK4TB4DRWZDFFQYSD9-bom\r\n\r\n{\n "args": {},\n "headers": {\n "Accept": [\n "*/*"\n ],\n "Host": [\n "httpbingo.org"\n ],\n "User-Agent": [\n "curl/8.6.0"\n ],\n "Via": [\n "1.1 fly.io"\n ],\n "X-Forwarded-For": [\n "183.82.162.68, 66.241.125.232"\n ],\n "X-Forwarded-Port": [\n "443"\n ],\n "X-Forwarded-Proto": [\n "https"\n ],\n "X-Forwarded-Ssl',
+            b'": [\n "on"\n ],\n "X-Request-Start": [\n "t=1731579851219982"\n ]\n },\n "method": "GET",\n "origin": "183.82.162.68",\n "url": "https://httpbingo.org/get"\n}\n',
+        ]
+        response.parse(memoryview(request[0]))
+        self.assertEqual(response.state, httpParserStates.RCVING_BODY)
+        self.assertEqual(response.code, b'200')
+        for byte in (bytes([b]) for b in request[1]):
+            response.parse(memoryview(byte))
+        self.assertEqual(response.state, httpParserStates.COMPLETE)
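Note: the TcpConnection.recv() and flush() hunks above share one pattern: consume() a byte budget from a Leakage bucket before the socket call, then release() whatever was granted but not actually transferred. The sketch below illustrates that pattern in isolation. It is only a sketch under assumptions: _Bucket and throttled_send are hypothetical names, and the internals of proxy.common.leakage.Leakage are not shown in this diff; only Leakage(rate), consume() and release(), plus the behaviour asserted in tests/common/test_leakage.py, are taken from the hunks above.

import socket
import time


class _Bucket:
    """Hypothetical stand-in for proxy.common.leakage.Leakage; internals assumed."""

    def __init__(self, rate: int) -> None:
        self.rate = rate      # bytes per second
        self.tokens = rate    # first consume() can grant up to `rate`, matching test_leakage.py
        self.last = time.time()

    def _refill(self) -> None:
        # Top the bucket back up in proportion to elapsed time, capped at `rate`.
        now = time.time()
        self.tokens = min(self.rate, self.tokens + int((now - self.last) * self.rate))
        self.last = now

    def consume(self, amount: int) -> int:
        # Grant at most `amount` bytes, bounded by the tokens currently available.
        self._refill()
        granted = min(amount, self.tokens)
        self.tokens -= granted
        return granted

    def release(self, amount: int) -> None:
        # Return tokens that were granted but not actually used.
        self.tokens = min(self.rate, self.tokens + amount)


def throttled_send(sock: socket.socket, data: bytes, bucket: _Bucket) -> int:
    # Mirrors the shape of TcpConnection.flush() above: consume before send,
    # release whatever the socket did not accept.
    allowed = bucket.consume(len(data))
    if allowed == 0:
        return 0
    sent = sock.send(data[:allowed])
    unused = allowed - sent
    if unused > 0:
        bucket.release(unused)
    return sent

Releasing the unused portion of the grant is what keeps a short send() (or an empty recv()) from permanently shrinking the connection's byte budget.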