Commit e4e4999

Add model_request_stream_sync to direct API (#2116)
1 parent cefe6a3 · commit e4e4999

File tree

3 files changed: +291 −3 lines

- docs/direct.md
- pydantic_ai_slim/pydantic_ai/direct.py
- tests/test_direct.py

docs/direct.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ The following functions are available:
 - [`model_request`][pydantic_ai.direct.model_request]: Make a non-streamed async request to a model
 - [`model_request_sync`][pydantic_ai.direct.model_request_sync]: Make a non-streamed synchronous request to a model
 - [`model_request_stream`][pydantic_ai.direct.model_request_stream]: Make a streamed async request to a model
+- [`model_request_stream_sync`][pydantic_ai.direct.model_request_stream_sync]: Make a streamed sync request to a model
 
 ## Basic Example
 
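A quick, network-free way to try the new function is the built-in `test` model that the test suite below relies on; a minimal sketch (the streamed events are canned placeholders, not a real model response):

```py
from pydantic_ai.direct import model_request_stream_sync
from pydantic_ai.messages import ModelRequest

# The built-in 'test' model streams canned chunks, so no API key is needed.
messages = [ModelRequest.user_text_prompt('Who was Albert Einstein?')]
with model_request_stream_sync('test', messages) as stream:
    for event in stream:
        print(event)  # a PartStartEvent, then PartDeltaEvent chunks
```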

pydantic_ai_slim/pydantic_ai/direct.py

Lines changed: 191 additions & 3 deletions
@@ -8,14 +8,29 @@
 
 from __future__ import annotations as _annotations
 
+import queue
+import threading
+from collections.abc import Iterator
 from contextlib import AbstractAsyncContextManager
+from dataclasses import dataclass, field
+from datetime import datetime
+from types import TracebackType
 
+from pydantic_ai.usage import Usage
 from pydantic_graph._utils import get_event_loop as _get_event_loop
 
 from . import agent, messages, models, settings
-from .models import instrumented as instrumented_models
+from .models import StreamedResponse, instrumented as instrumented_models
 
-__all__ = 'model_request', 'model_request_sync', 'model_request_stream'
+__all__ = (
+    'model_request',
+    'model_request_sync',
+    'model_request_stream',
+    'model_request_stream_sync',
+    'StreamedResponseSync',
+)
+
+STREAM_INITIALIZATION_TIMEOUT = 30
 
 
 async def model_request(
@@ -144,7 +159,7 @@ def model_request_stream(
 
     async def main():
         messages = [ModelRequest.user_text_prompt('Who was Albert Einstein?')]  # (1)!
-        async with model_request_stream( 'openai:gpt-4.1-mini', messages) as stream:
+        async with model_request_stream('openai:gpt-4.1-mini', messages) as stream:
             chunks = []
             async for chunk in stream:
                 chunks.append(chunk)
@@ -181,6 +196,63 @@ async def main():
     )
 
 
+def model_request_stream_sync(
+    model: models.Model | models.KnownModelName | str,
+    messages: list[messages.ModelMessage],
+    *,
+    model_settings: settings.ModelSettings | None = None,
+    model_request_parameters: models.ModelRequestParameters | None = None,
+    instrument: instrumented_models.InstrumentationSettings | bool | None = None,
+) -> StreamedResponseSync:
+    """Make a streamed synchronous request to a model.
+
+    This is the synchronous version of [`model_request_stream`][pydantic_ai.direct.model_request_stream].
+    It uses threading to run the asynchronous stream in the background while providing a synchronous iterator interface.
+
+    ```py {title="model_request_stream_sync_example.py"}
+
+    from pydantic_ai.direct import model_request_stream_sync
+    from pydantic_ai.messages import ModelRequest
+
+    messages = [ModelRequest.user_text_prompt('Who was Albert Einstein?')]
+    with model_request_stream_sync('openai:gpt-4.1-mini', messages) as stream:
+        chunks = []
+        for chunk in stream:
+            chunks.append(chunk)
+        print(chunks)
+        '''
+        [
+            PartStartEvent(index=0, part=TextPart(content='Albert Einstein was ')),
+            PartDeltaEvent(
+                index=0, delta=TextPartDelta(content_delta='a German-born theoretical ')
+            ),
+            PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='physicist.')),
+        ]
+        '''
+    ```
+
+    Args:
+        model: The model to make a request to. We allow `str` here since the actual list of allowed models changes frequently.
+        messages: Messages to send to the model
+        model_settings: optional model settings
+        model_request_parameters: optional model request parameters
+        instrument: Whether to instrument the request with OpenTelemetry/Logfire, if `None` the value from
+            [`logfire.instrument_pydantic_ai`][logfire.Logfire.instrument_pydantic_ai] is used.
+
+    Returns:
+        A [sync stream response][pydantic_ai.direct.StreamedResponseSync] context manager.
+    """
+    async_stream_cm = model_request_stream(
+        model=model,
+        messages=messages,
+        model_settings=model_settings,
+        model_request_parameters=model_request_parameters,
+        instrument=instrument,
+    )
+
+    return StreamedResponseSync(async_stream_cm)
+
+
 def _prepare_model(
     model: models.Model | models.KnownModelName | str,
     instrument: instrumented_models.InstrumentationSettings | bool | None,
@@ -191,3 +263,119 @@ def _prepare_model(
         instrument = agent.Agent._instrument_default  # pyright: ignore[reportPrivateUsage]
 
     return instrumented_models.instrument_model(model_instance, instrument)
+
+
+@dataclass
+class StreamedResponseSync:
+    """Synchronous wrapper to async streaming responses by running the async producer in a background thread and providing a synchronous iterator.
+
+    This class must be used as a context manager with the `with` statement.
+    """
+
+    _async_stream_cm: AbstractAsyncContextManager[StreamedResponse]
+    _queue: queue.Queue[messages.ModelResponseStreamEvent | Exception | None] = field(
+        default_factory=queue.Queue, init=False
+    )
+    _thread: threading.Thread | None = field(default=None, init=False)
+    _stream_response: StreamedResponse | None = field(default=None, init=False)
+    _exception: Exception | None = field(default=None, init=False)
+    _context_entered: bool = field(default=False, init=False)
+    _stream_ready: threading.Event = field(default_factory=threading.Event, init=False)
+
+    def __enter__(self) -> StreamedResponseSync:
+        self._context_entered = True
+        self._start_producer()
+        return self
+
+    def __exit__(
+        self,
+        _exc_type: type[BaseException] | None,
+        _exc_val: BaseException | None,
+        _exc_tb: TracebackType | None,
+    ) -> None:
+        self._cleanup()
+
+    def __iter__(self) -> Iterator[messages.ModelResponseStreamEvent]:
+        """Stream the response as an iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
+        self._check_context_manager_usage()
+
+        while True:
+            item = self._queue.get()
+            if item is None:  # End of stream
+                break
+            elif isinstance(item, Exception):
+                raise item
+            else:
+                yield item
+
+    def __repr__(self) -> str:
+        if self._stream_response:
+            return repr(self._stream_response)
+        else:
+            return f'{self.__class__.__name__}(context_entered={self._context_entered})'
+
+    __str__ = __repr__
+
+    def _check_context_manager_usage(self) -> None:
+        if not self._context_entered:
+            raise RuntimeError(
+                'StreamedResponseSync must be used as a context manager. '
+                'Use: `with model_request_stream_sync(...) as stream:`'
+            )
+
+    def _ensure_stream_ready(self) -> StreamedResponse:
+        self._check_context_manager_usage()
+
+        if self._stream_response is None:
+            # Wait for the background thread to signal that the stream is ready
+            if not self._stream_ready.wait(timeout=STREAM_INITIALIZATION_TIMEOUT):
+                raise RuntimeError('Stream failed to initialize within timeout')
+
+            if self._stream_response is None:  # pragma: no cover
+                raise RuntimeError('Stream failed to initialize')
+
+        return self._stream_response
+
+    def _start_producer(self):
+        self._thread = threading.Thread(target=self._async_producer, daemon=True)
+        self._thread.start()
+
+    def _async_producer(self):
+        async def _consume_async_stream():
+            try:
+                async with self._async_stream_cm as stream:
+                    self._stream_response = stream
+                    # Signal that the stream is ready
+                    self._stream_ready.set()
+                    async for event in stream:
+                        self._queue.put(event)
+            except Exception as e:
+                # Signal ready even on error so waiting threads don't hang
+                self._stream_ready.set()
+                self._queue.put(e)
+            finally:
+                self._queue.put(None)  # Signal end
+
+        _get_event_loop().run_until_complete(_consume_async_stream())
+
+    def _cleanup(self):
+        if self._thread and self._thread.is_alive():
+            self._thread.join()
+
+    def get(self) -> messages.ModelResponse:
+        """Build a ModelResponse from the data received from the stream so far."""
+        return self._ensure_stream_ready().get()
+
+    def usage(self) -> Usage:
+        """Get the usage of the response so far."""
+        return self._ensure_stream_ready().usage()
+
+    @property
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        return self._ensure_stream_ready().model_name
+
+    @property
+    def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
+        return self._ensure_stream_ready().timestamp
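The class above bridges async to sync by running the producer's event loop in a daemon thread and handing events to the consumer through a `queue.Queue`, with `None` as the end-of-stream sentinel and exceptions re-raised on the consumer side. A minimal standalone sketch of the same pattern, independent of pydantic-ai (the `aproduce` generator is hypothetical):

```py
import asyncio
import queue
import threading
from collections.abc import AsyncIterator, Iterator


def bridge(make_aiter) -> Iterator[str]:
    """Drive an async iterator on a background thread; consume it synchronously."""
    q: queue.Queue = queue.Queue()

    async def produce():
        try:
            async for item in make_aiter():
                q.put(item)
        except Exception as e:
            q.put(e)  # surface producer errors to the consumer
        finally:
            q.put(None)  # end-of-stream sentinel

    threading.Thread(target=lambda: asyncio.run(produce()), daemon=True).start()
    while True:
        item = q.get()
        if item is None:
            break
        if isinstance(item, Exception):
            raise item
        yield item


async def aproduce() -> AsyncIterator[str]:  # hypothetical async source
    for word in ('hello', 'world'):
        await asyncio.sleep(0)
        yield word


print(list(bridge(aproduce)))
#> ['hello', 'world']
```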

tests/test_direct.py

Lines changed: 99 additions & 0 deletions
@@ -1,17 +1,23 @@
+import asyncio
+import re
 from contextlib import contextmanager
 from datetime import timezone
+from unittest.mock import AsyncMock, patch
 
 import pytest
 from inline_snapshot import snapshot
 
 from pydantic_ai import Agent
 from pydantic_ai.direct import (
+    StreamedResponseSync,
     _prepare_model,  # pyright: ignore[reportPrivateUsage]
     model_request,
     model_request_stream,
+    model_request_stream_sync,
     model_request_sync,
 )
 from pydantic_ai.messages import (
+    ModelMessage,
     ModelRequest,
     ModelResponse,
     PartDeltaEvent,
@@ -76,6 +82,24 @@ def test_model_request_sync():
     )
 
 
+def test_model_request_stream_sync():
+    with model_request_stream_sync('test', [ModelRequest.user_text_prompt('x')]) as stream:
+        chunks = list(stream)
+        assert chunks == snapshot(
+            [
+                PartStartEvent(index=0, part=TextPart(content='')),
+                PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='success ')),
+                PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='(no ')),
+                PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='tool ')),
+                PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='calls)')),
+            ]
+        )
+
+        repr_str = repr(stream)
+        assert 'TestStreamedResponse' in repr_str
+        assert 'test' in repr_str
+
+
 async def test_model_request_stream():
     async with model_request_stream('test', [ModelRequest.user_text_prompt('x')]) as stream:
         chunks = [chunk async for chunk in stream]
@@ -90,6 +114,81 @@ async def test_model_request_stream():
     )
 
 
+def test_model_request_stream_sync_without_context_manager():
+    """Test that accessing properties or iterating without context manager raises RuntimeError."""
+    messages: list[ModelMessage] = [ModelRequest.user_text_prompt('x')]
+
+    expected_error_msg = re.escape(
+        'StreamedResponseSync must be used as a context manager. Use: `with model_request_stream_sync(...) as stream:`'
+    )
+
+    stream_cm = model_request_stream_sync('test', messages)
+
+    stream_repr = repr(stream_cm)
+    assert 'StreamedResponseSync' in stream_repr
+    assert 'context_entered=False' in stream_repr
+
+    with pytest.raises(RuntimeError, match=expected_error_msg):
+        _ = stream_cm.model_name
+
+    with pytest.raises(RuntimeError, match=expected_error_msg):
+        _ = stream_cm.timestamp
+
+    with pytest.raises(RuntimeError, match=expected_error_msg):
+        stream_cm.get()
+
+    with pytest.raises(RuntimeError, match=expected_error_msg):
+        stream_cm.usage()
+
+    with pytest.raises(RuntimeError, match=expected_error_msg):
+        list(stream_cm)
+
+    with pytest.raises(RuntimeError, match=expected_error_msg):
+        for _ in stream_cm:
+            break
+
+
+def test_model_request_stream_sync_exception_in_stream():
+    """Test handling of exceptions raised during streaming."""
+    async_stream_mock = AsyncMock()
+    async_stream_mock.__aenter__ = AsyncMock(side_effect=ValueError('Stream error'))
+
+    stream_sync = StreamedResponseSync(_async_stream_cm=async_stream_mock)
+
+    with stream_sync:
+        with pytest.raises(ValueError, match='Stream error'):
+            list(stream_sync)
+
+
+def test_model_request_stream_sync_timeout():
+    """Test timeout when stream fails to initialize."""
+    async_stream_mock = AsyncMock()
+
+    async def slow_init():
+        await asyncio.sleep(0.1)
+
+    async_stream_mock.__aenter__ = AsyncMock(side_effect=slow_init)
+
+    stream_sync = StreamedResponseSync(_async_stream_cm=async_stream_mock)
+
+    with patch('pydantic_ai.direct.STREAM_INITIALIZATION_TIMEOUT', 0.01):
+        with stream_sync:
+            with pytest.raises(RuntimeError, match='Stream failed to initialize within timeout'):
+                stream_sync.get()
+
+
+def test_model_request_stream_sync_intermediate_get():
+    """Test getting properties of StreamedResponse before consuming all events."""
+    messages: list[ModelMessage] = [ModelRequest.user_text_prompt('x')]
+
+    with model_request_stream_sync('test', messages) as stream:
+        response = stream.get()
+        assert response is not None
+
+        usage = stream.usage()
+        assert usage is not None
+
+
 @contextmanager
 def set_instrument_default(value: bool):
     """Context manager to temporarily set the default instrumentation value."""
