Skip to content

Commit ee4cedb

Browse files
committed
Streams are iterable + receive_some doesn't require an explicit size
This came out of discussion in gh-959
1 parent f7850e8 commit ee4cedb

19 files changed

+145
-119
lines changed

docs/source/reference-io.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ Abstract base classes
9898
* - :class:`ReceiveStream`
9999
- :class:`AsyncResource`
100100
- :meth:`~ReceiveStream.receive_some`
101-
-
101+
- ``__aiter__``, ``__anext__``
102102
- :class:`~trio.testing.MemoryReceiveStream`
103103
* - :class:`Stream`
104104
- :class:`SendStream`, :class:`ReceiveStream`

docs/source/tutorial.rst

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -908,12 +908,10 @@ And the second task's job is to process the data the server sends back:
908908
:lineno-match:
909909
:pyobject: receiver
910910

911-
It repeatedly calls ``await client_stream.receive_some(...)`` to get
912-
more data from the server (again, all Trio streams provide this
913-
method), and then checks to see if the server has closed the
914-
connection. ``receive_some`` only returns an empty bytestring if the
915-
connection has been closed; otherwise, it waits until some data has
916-
arrived, up to a maximum of ``BUFSIZE`` bytes.
911+
It uses an ``async for`` loop to fetch data from the server.
912+
Alternatively, it could use `~trio.abc.ReceiveStream.receive_some`,
913+
which is the opposite of `~trio.abc.SendStream.send_all`, but using
914+
``async for`` saves some boilerplate.
917915

918916
And now we're ready to look at the server.
919917

@@ -974,11 +972,11 @@ functions we saw in the last section:
974972

975973
The argument ``server_stream`` is provided by :func:`serve_tcp`, and
976974
is the other end of the connection we made in the client: so the data
977-
that the client passes to ``send_all`` will come out of
978-
``receive_some`` here, and vice-versa. Then we have a ``try`` block
979-
discussed below, and finally the server loop which alternates between
980-
reading some data from the socket and then sending it back out again
981-
(unless the socket was closed, in which case we quit).
975+
that the client passes to ``send_all`` will come out here. Then we
976+
have a ``try`` block discussed below, and finally the server loop
977+
which alternates between reading some data from the socket and then
978+
sending it back out again (unless the socket was closed, in which case
979+
we quit).
982980

983981
So what's that ``try`` block for? Remember that in Trio, like Python
984982
in general, exceptions keep propagating until they're caught. Here we
@@ -1029,7 +1027,7 @@ our client could use a single task like::
10291027
while True:
10301028
data = ...
10311029
await client_stream.send_all(data)
1032-
received = await client_stream.receive_some(BUFSIZE)
1030+
received = await client_stream.receive_some()
10331031
if not received:
10341032
sys.exit()
10351033
await trio.sleep(1)
@@ -1046,18 +1044,23 @@ line, any time we're expecting more than one byte of data, we have to
10461044
be prepared to call ``receive_some`` multiple times.
10471045

10481046
And where this would go especially wrong is if we find ourselves in
1049-
the situation where ``len(data) > BUFSIZE``. On each pass through the
1050-
loop, we send ``len(data)`` bytes, but only read *at most* ``BUFSIZE``
1051-
bytes. The result is something like a memory leak: we'll end up with
1052-
more and more data backed up in the network, until eventually
1053-
something breaks.
1047+
the situation where ``data`` is big enough that it passes some
1048+
internal threshold, and the operating system or network decide to
1049+
always break it up into multiple pieces. Now on each pass through the
1050+
loop, we send ``len(data)`` bytes, but read less than that. The result
1051+
is something like a memory leak: we'll end up with more and more data
1052+
backed up in the network, until eventually something breaks.
1053+
1054+
.. note:: If you're curious *how* things break, then you can use
1055+
`~trio.abc.ReceiveStream.receive_some`\'s optional argument to put
1056+
a limit on how many bytes you read each time, and see what happens.
10541057

10551058
We could fix this by keeping track of how much data we're expecting at
10561059
each moment, and then keep calling ``receive_some`` until we get it all::
10571060

10581061
expected = len(data)
10591062
while expected > 0:
1060-
received = await client_stream.receive_some(BUFSIZE)
1063+
received = await client_stream.receive_some(expected)
10611064
if not received:
10621065
sys.exit(1)
10631066
expected -= len(received)

docs/source/tutorial/echo-client.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
# - can't be in use by some other program on your computer
99
# - must match what we set in our echo server
1010
PORT = 12345
11-
# How much memory to spend (at most) on each call to recv. Pretty arbitrary,
12-
# but shouldn't be too big or too small.
13-
BUFSIZE = 16384
1411

1512
async def sender(client_stream):
1613
print("sender: started!")
@@ -22,12 +19,10 @@ async def sender(client_stream):
2219

2320
async def receiver(client_stream):
2421
print("receiver: started!")
25-
while True:
26-
data = await client_stream.receive_some(BUFSIZE)
22+
async for data in client_stream:
2723
print("receiver: got data {!r}".format(data))
28-
if not data:
29-
print("receiver: connection closed")
30-
sys.exit()
24+
print("receiver: connection closed")
25+
sys.exit()
3126

3227
async def parent():
3328
print("parent: connecting to 127.0.0.1:{}".format(PORT))

docs/source/tutorial/echo-server.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
# - can't be in use by some other program on your computer
99
# - must match what we set in our echo client
1010
PORT = 12345
11-
# How much memory to spend (at most) on each call to recv. Pretty arbitrary,
12-
# but shouldn't be too big or too small.
13-
BUFSIZE = 16384
1411

1512
CONNECTION_COUNTER = count()
1613

@@ -20,14 +17,11 @@ async def echo_server(server_stream):
2017
ident = next(CONNECTION_COUNTER)
2118
print("echo_server {}: started".format(ident))
2219
try:
23-
while True:
24-
data = await server_stream.receive_some(BUFSIZE)
20+
async for data in server_stream:
2521
print("echo_server {}: received data {!r}".format(ident, data))
26-
if not data:
27-
print("echo_server {}: connection closed".format(ident))
28-
return
29-
print("echo_server {}: sending data {!r}".format(ident, data))
3022
await server_stream.send_all(data)
23+
print("echo_server {}: connection closed".format(ident))
24+
return
3125
# FIXME: add discussion of MultiErrors to the tutorial, and use
3226
# MultiError.catch here. (Not important in this case, but important if the
3327
# server code uses nurseries internally.)

newsfragments/959.feature.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
If you have a `~trio.abc.ReceiveStream` object, you can now use
2+
``async for data in stream: ...`` instead of calling
3+
`~trio.abc.ReceiveStream.receive_some` repeatedly. And the best part
4+
is, it automatically checks for EOF for you, so you don't have to.
5+
Also, you no longer have to choose a magic buffer size value before
6+
calling `~trio.abc.ReceiveStream.receive_some`; you can now call
7+
``await stream.receive_some()`` and the stream will automatically pick
8+
a reasonable value for you.

notes-to-self/graceful-shutdown-idea.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def shutting_down(self):
3030
async def stream_handler(stream):
3131
while True:
3232
with gsm.cancel_on_graceful_shutdown():
33-
data = await stream.receive_some(...)
33+
data = await stream.receive_some()
3434
if gsm.shutting_down:
3535
break
3636

trio/_abc.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -378,26 +378,26 @@ class ReceiveStream(AsyncResource):
378378
If you want to receive Python objects rather than raw bytes, see
379379
:class:`ReceiveChannel`.
380380
381+
`ReceiveStream` objects can be used in ``async for`` loops. Each iteration
382+
will produce an arbitrary size
383+
381384
"""
382385
__slots__ = ()
383386

384387
@abstractmethod
385-
async def receive_some(self, max_bytes):
388+
async def receive_some(self, max_bytes=None):
386389
"""Wait until there is data available on this stream, and then return
387-
at most ``max_bytes`` of it.
390+
some of it.
388391
389392
A return value of ``b""`` (an empty bytestring) indicates that the
390393
stream has reached end-of-file. Implementations should be careful that
391394
they return ``b""`` if, and only if, the stream has reached
392395
end-of-file!
393396
394-
This method will return as soon as any data is available, so it may
395-
return fewer than ``max_bytes`` of data. But it will never return
396-
more.
397-
398397
Args:
399398
max_bytes (int): The maximum number of bytes to return. Must be
400-
greater than zero.
399+
greater than zero. Optional; if omitted, then the stream object
400+
is free to pick a reasonable default.
401401
402402
Returns:
403403
bytes or bytearray: The data received.
@@ -413,6 +413,15 @@ async def receive_some(self, max_bytes):
413413
414414
"""
415415

416+
def __aiter__(self):
417+
return self
418+
419+
async def __anext__(self):
420+
data = await self.receive_some()
421+
if not data:
422+
raise StopAsyncIteration
423+
return data
424+
416425

417426
class Stream(SendStream, ReceiveStream):
418427
"""A standard interface for interacting with bidirectional byte streams.

trio/_highlevel_generic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class StapledStream(HalfCloseableStream):
5252
left, right = trio.testing.memory_stream_pair()
5353
echo_stream = StapledStream(SocketStream(left), SocketStream(right))
5454
await echo_stream.send_all(b"x")
55-
assert await echo_stream.receive_some(1) == b"x"
55+
assert await echo_stream.receive_some() == b"x"
5656
5757
:class:`StapledStream` objects implement the methods in the
5858
:class:`~trio.abc.HalfCloseableStream` interface. They also have two
@@ -96,7 +96,7 @@ async def send_eof(self):
9696
else:
9797
return await self.send_stream.aclose()
9898

99-
async def receive_some(self, max_bytes):
99+
async def receive_some(self, max_bytes=None):
100100
"""Calls ``self.receive_stream.receive_some``.
101101
102102
"""

trio/_highlevel_socket.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010

1111
__all__ = ["SocketStream", "SocketListener"]
1212

13+
# XX TODO: this number was picked arbitrarily. We should do experiments to
14+
# tune it. (Or make it dynamic -- one idea is to start small and increase it
15+
# if we observe single reads filling up the whole buffer, at least within some
16+
# limits.)
17+
DEFAULT_RECEIVE_SIZE = 65536
18+
1319
_closed_stream_errnos = {
1420
# Unix
1521
errno.EBADF,
@@ -129,7 +135,9 @@ async def send_eof(self):
129135
with _translate_socket_errors_to_stream_errors():
130136
self.socket.shutdown(tsocket.SHUT_WR)
131137

132-
async def receive_some(self, max_bytes):
138+
async def receive_some(self, max_bytes=None):
139+
if max_bytes is None:
140+
max_bytes = DEFAULT_RECEIVE_SIZE
133141
if max_bytes < 1:
134142
raise ValueError("max_bytes must be >= 1")
135143
with _translate_socket_errors_to_stream_errors():

trio/_ssl.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,16 @@
159159
from ._highlevel_generic import aclose_forcefully
160160
from . import _sync
161161
from ._util import ConflictDetector
162+
from ._deprecate import warn_deprecated
162163

163164
################################################################
164165
# SSLStream
165166
################################################################
166167

168+
# XX TODO: this number was pulled out of a hat. We should tune it with
169+
# science.
170+
DEFAULT_RECEIVE_SIZE = 65536
171+
167172

168173
class NeedHandshakeError(Exception):
169174
"""Some :class:`SSLStream` methods can't return any meaningful data until
@@ -197,8 +202,6 @@ def done(self):
197202

198203
_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"])
199204

200-
_default_max_refill_bytes = 32 * 1024
201-
202205

203206
class SSLStream(Stream):
204207
r"""Encrypted communication using SSL/TLS.
@@ -269,15 +272,6 @@ class SSLStream(Stream):
269272
that :class:`~ssl.SSLSocket` implements the
270273
``https_compatible=True`` behavior by default.
271274
272-
max_refill_bytes (int): :class:`~ssl.SSLSocket` maintains an internal
273-
buffer of incoming data, and when it runs low then it calls
274-
:meth:`receive_some` on the underlying transport stream to refill
275-
it. This argument lets you set the ``max_bytes`` argument passed to
276-
the *underlying* :meth:`receive_some` call. It doesn't affect calls
277-
to *this* class's :meth:`receive_some`, or really anything else
278-
user-observable except possibly performance. You probably don't need
279-
to worry about this.
280-
281275
Attributes:
282276
transport_stream (trio.abc.Stream): The underlying transport stream
283277
that was passed to ``__init__``. An example of when this would be
@@ -313,11 +307,14 @@ def __init__(
313307
server_hostname=None,
314308
server_side=False,
315309
https_compatible=False,
316-
max_refill_bytes=_default_max_refill_bytes
310+
max_refill_bytes="unused and deprecated",
317311
):
318312
self.transport_stream = transport_stream
319313
self._state = _State.OK
320-
self._max_refill_bytes = max_refill_bytes
314+
if max_refill_bytes != "unused and deprecated":
315+
warn_deprecated(
316+
"max_refill_bytes=...", "0.12.0", issue=959, instead=None
317+
)
321318
self._https_compatible = https_compatible
322319
self._outgoing = _stdlib_ssl.MemoryBIO()
323320
self._incoming = _stdlib_ssl.MemoryBIO()
@@ -536,9 +533,7 @@ async def _retry(self, fn, *args, ignore_want_read=False):
536533
async with self._inner_recv_lock:
537534
yielded = True
538535
if recv_count == self._inner_recv_count:
539-
data = await self.transport_stream.receive_some(
540-
self._max_refill_bytes
541-
)
536+
data = await self.transport_stream.receive_some()
542537
if not data:
543538
self._incoming.write_eof()
544539
else:
@@ -590,7 +585,7 @@ async def do_handshake(self):
590585
# https://bugs.python.org/issue30141
591586
# So we *definitely* have to make sure that do_handshake is called
592587
# before doing anything else.
593-
async def receive_some(self, max_bytes):
588+
async def receive_some(self, max_bytes=None):
594589
"""Read some data from the underlying transport, decrypt it, and
595590
return it.
596591
@@ -621,9 +616,15 @@ async def receive_some(self, max_bytes):
621616
return b""
622617
else:
623618
raise
624-
max_bytes = _operator.index(max_bytes)
625-
if max_bytes < 1:
626-
raise ValueError("max_bytes must be >= 1")
619+
if max_bytes is None:
620+
# Heuristic: normally we use DEFAULT_RECEIVE_SIZE, but if
621+
# the transport gave us a bunch of data last time then we'll
622+
# try to decrypt and pass it all back at once.
623+
max_bytes = max(DEFAULT_RECEIVE_SIZE, self._incoming.pending)
624+
else:
625+
max_bytes = _operator.index(max_bytes)
626+
if max_bytes < 1:
627+
raise ValueError("max_bytes must be >= 1")
627628
try:
628629
return await self._retry(self._ssl_object.read, max_bytes)
629630
except trio.BrokenResourceError as exc:
@@ -837,8 +838,6 @@ class SSLListener(Listener[SSLStream]):
837838
838839
https_compatible (bool): Passed on to :class:`SSLStream`.
839840
840-
max_refill_bytes (int): Passed on to :class:`SSLStream`.
841-
842841
Attributes:
843842
transport_listener (trio.abc.Listener): The underlying listener that was
844843
passed to ``__init__``.
@@ -851,12 +850,15 @@ def __init__(
851850
ssl_context,
852851
*,
853852
https_compatible=False,
854-
max_refill_bytes=_default_max_refill_bytes
853+
max_refill_bytes="unused and deprecated",
855854
):
855+
if max_refill_bytes != "unused and deprecated":
856+
warn_deprecated(
857+
"max_refill_bytes=...", "0.12.0", issue=959, instead=None
858+
)
856859
self.transport_listener = transport_listener
857860
self._ssl_context = ssl_context
858861
self._https_compatible = https_compatible
859-
self._max_refill_bytes = max_refill_bytes
860862

861863
async def accept(self):
862864
"""Accept the next connection and wrap it in an :class:`SSLStream`.
@@ -870,7 +872,6 @@ async def accept(self):
870872
self._ssl_context,
871873
server_side=True,
872874
https_compatible=self._https_compatible,
873-
max_refill_bytes=self._max_refill_bytes,
874875
)
875876

876877
async def aclose(self):

0 commit comments

Comments
 (0)