26 changes: 13 additions & 13 deletions src/lmstudio/async_api.py
@@ -909,15 +909,15 @@ async def __aexit__(
self._set_error(exc_val)
await self.aclose()

async def __aiter__(self) -> AsyncIterator[str]:
async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]:
endpoint = self._endpoint
async with self:
assert self._channel is not None
async for contents in self._channel.rx_stream():
for event in endpoint.iter_message_events(contents):
endpoint.handle_rx_event(event)
if isinstance(event, PredictionFragmentEvent):
yield event.arg.content
yield event.arg
if endpoint.is_finished:
break
self._mark_finished()
@@ -1008,8 +1008,8 @@ async def _complete_stream(
on_prompt_processing_progress,
)
channel_cm = self._create_channel(endpoint)
prediction = AsyncPredictionStream(channel_cm, endpoint)
return prediction
prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
return prediction_stream

@overload
async def _respond_stream(
@@ -1064,8 +1064,8 @@ async def _respond_stream(
on_prompt_processing_progress,
)
channel_cm = self._create_channel(endpoint)
prediction = AsyncPredictionStream(channel_cm, endpoint)
return prediction
prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
return prediction_stream

async def _apply_prompt_template(
self,
@@ -1264,7 +1264,7 @@ async def complete(
on_prompt_processing_progress: Callable[[float], None] | None = None,
) -> PredictionResult[str] | PredictionResult[DictObject]:
"""Request a one-off prediction without any context."""
prediction = await self._session._complete_stream(
prediction_stream = await self._session._complete_stream(
self.identifier,
prompt,
response_format=response_format,
@@ -1274,11 +1274,11 @@ async def complete(
on_prediction_fragment=on_prediction_fragment,
on_prompt_processing_progress=on_prompt_processing_progress,
)
async for _ in prediction:
async for _ in prediction_stream:
# No yield in body means iterator reliably provides
# prompt resource cleanup on coroutine cancellation
pass
return prediction.result()
return prediction_stream.result()

@overload
async def respond_stream(
@@ -1365,7 +1365,7 @@ async def respond(
on_prompt_processing_progress: Callable[[float], None] | None = None,
) -> PredictionResult[str] | PredictionResult[DictObject]:
"""Request a response in an ongoing assistant chat session."""
prediction = await self._session._respond_stream(
prediction_stream = await self._session._respond_stream(
self.identifier,
history,
response_format=response_format,
@@ -1375,11 +1375,11 @@ async def respond(
on_prediction_fragment=on_prediction_fragment,
on_prompt_processing_progress=on_prompt_processing_progress,
)
async for _ in prediction:
async for _ in prediction_stream:
# No yield in body means iterator reliably provides
# prompt resource cleanup on coroutine cancellation
pass
return prediction.result()
return prediction_stream.result()

@sdk_public_api_async()
async def apply_prompt_template(
@@ -1411,7 +1411,7 @@ async def embed(
TAsyncSession = TypeVar("TAsyncSession", bound=AsyncSession)

_ASYNC_API_STABILITY_WARNING = """\
Note: the async API is not yet stable and is expected to change in future releases
Note the async API is not yet stable and is expected to change in future releases
"""


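For downstream users of the async API, the change above means iterating a prediction stream now yields LlmPredictionFragment objects instead of plain strings. A minimal sketch of the adjustment (the top-level import, the model() accessor, and the model identifier are illustrative assumptions, not part of this diff):

import asyncio

from lmstudio import AsyncClient  # assumed top-level export


async def main() -> None:
    async with AsyncClient() as client:
        # Assumption: some accessor that returns an async LLM handle
        model = await client.llm.model("some-model-id")
        stream = await model.complete_stream("Write a haiku about diffs")
        async with stream:
            async for fragment in stream:
                # Previously each item was a str; now read .content explicitly
                print(fragment.content, end="", flush=True)
        print()
        result = stream.result()  # the aggregated PredictionResult is unchanged
        print(result)


asyncio.run(main())
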
24 changes: 12 additions & 12 deletions src/lmstudio/sync_api.py
@@ -1062,10 +1062,10 @@ def __exit__(
self._set_error(exc_val)
self.close()

def __iter__(self) -> Iterator[str]:
def __iter__(self) -> Iterator[LlmPredictionFragment]:
for event in self._iter_events():
if isinstance(event, PredictionFragmentEvent):
yield event.arg.content
yield event.arg

def _iter_events(self) -> Iterator[PredictionRxEvent]:
endpoint = self._endpoint
@@ -1165,8 +1165,8 @@ def _complete_stream(
on_prompt_processing_progress,
)
channel_cm = self._create_channel(endpoint)
prediction = PredictionStream(channel_cm, endpoint)
return prediction
prediction_stream = PredictionStream(channel_cm, endpoint)
return prediction_stream

@overload
def _respond_stream(
@@ -1221,8 +1221,8 @@ def _respond_stream(
on_prompt_processing_progress,
)
channel_cm = self._create_channel(endpoint)
prediction = PredictionStream(channel_cm, endpoint)
return prediction
prediction_stream = PredictionStream(channel_cm, endpoint)
return prediction_stream

def _apply_prompt_template(
self,
@@ -1419,7 +1419,7 @@ def complete(
on_prompt_processing_progress: Callable[[float], None] | None = None,
) -> PredictionResult[str] | PredictionResult[DictObject]:
"""Request a one-off prediction without any context."""
prediction = self._session._complete_stream(
prediction_stream = self._session._complete_stream(
self.identifier,
prompt,
response_format=response_format,
@@ -1429,11 +1429,11 @@ def complete(
on_prediction_fragment=on_prediction_fragment,
on_prompt_processing_progress=on_prompt_processing_progress,
)
for _ in prediction:
for _ in prediction_stream:
# No yield in body means iterator reliably provides
# prompt resource cleanup on coroutine cancellation
pass
return prediction.result()
return prediction_stream.result()

@overload
def respond_stream(
@@ -1520,7 +1520,7 @@ def respond(
on_prompt_processing_progress: Callable[[float], None] | None = None,
) -> PredictionResult[str] | PredictionResult[DictObject]:
"""Request a response in an ongoing assistant chat session."""
prediction = self._session._respond_stream(
prediction_stream = self._session._respond_stream(
self.identifier,
history,
response_format=response_format,
@@ -1530,11 +1530,11 @@ def respond(
on_prediction_fragment=on_prediction_fragment,
on_prompt_processing_progress=on_prompt_processing_progress,
)
for _ in prediction:
for _ in prediction_stream:
# No yield in body means iterator reliably provides
# prompt resource cleanup on coroutine cancellation
pass
return prediction.result()
return prediction_stream.result()

# Multi-round predictions are currently a sync-only handle-only feature
# TODO: Refactor to allow for more code sharing with the async API
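The sync iterator changes in the same way, so code that previously collected the streamed text directly should join the fragment contents instead. A sketch under the same assumptions as above:

from lmstudio import Client  # assumed top-level export

with Client() as client:
    model = client.llm.model("some-model-id")  # assumed accessor for a model handle
    stream = model.complete_stream("Tell me a joke")
    with stream:
        # Old behaviour: "".join(stream) worked because each item was a str.
        # New behaviour: each item is an LlmPredictionFragment, so join .content.
        text = "".join(fragment.content for fragment in stream)
    print(text)
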
40 changes: 22 additions & 18 deletions tests/async/test_inference_async.py
@@ -72,17 +72,17 @@ async def test_complete_stream_async(caplog: LogCap) -> None:
model_id = EXPECTED_LLM_ID
async with AsyncClient() as client:
session = client.llm
prediction = await session._complete_stream(
prediction_stream = await session._complete_stream(
model_id, prompt, config=SHORT_PREDICTION_CONFIG
)
assert isinstance(prediction, AsyncPredictionStream)
assert isinstance(prediction_stream, AsyncPredictionStream)
# Also exercise the explicit context management interface
async with prediction:
async for token in prediction:
logging.info(f"Token: {token}")
assert token
assert isinstance(token, str)
response = prediction.result()
async with prediction_stream:
async for fragment in prediction_stream:
logging.info(f"Fragment: {fragment}")
assert fragment.content
assert isinstance(fragment.content, str)
response = prediction_stream.result()
# The continuation from the LLM will change, but it won't be an empty string
logging.info(f"LLM response: {response!r}")
assert isinstance(response, PredictionResult)
@@ -151,7 +151,9 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
# This test case also covers the explicit context management interface
iteration_content: list[str] = []
async with prediction_stream:
iteration_content = [text async for text in prediction_stream]
iteration_content = [
fragment.content async for fragment in prediction_stream
]
assert len(messages) == 1
message = messages[0]
assert message.role == "assistant"
@@ -206,7 +208,9 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
# This test case also covers the explicit context management interface
iteration_content: list[str] = []
async with prediction_stream:
iteration_content = [text async for text in prediction_stream]
iteration_content = [
fragment.content async for fragment in prediction_stream
]
assert len(messages) == 1
message = messages[0]
assert message.role == "assistant"
@@ -267,10 +271,10 @@ async def test_invalid_model_request_stream_async(caplog: LogCap) -> None:
# This should error rather than timing out,
# but avoid any risk of the client hanging...
async with asyncio.timeout(30):
prediction = await model.complete_stream("Some text")
async with prediction:
prediction_stream = await model.complete_stream("Some text")
async with prediction_stream:
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
await prediction.wait_for_result()
await prediction_stream.wait_for_result()
check_sdk_error(exc_info, __file__)


@@ -283,11 +287,11 @@ async def test_cancel_prediction_async(caplog: LogCap) -> None:
caplog.set_level(logging.DEBUG)
async with AsyncClient() as client:
session = client.llm
response = await session._complete_stream(model_id, prompt=prompt)
async for _ in response:
await response.cancel()
stream = await session._complete_stream(model_id, prompt=prompt)
async for _ in stream:
await stream.cancel()
num_times += 1
assert response.stats
assert response.stats.stop_reason == "userStopped"
assert stream.stats
assert stream.stats.stop_reason == "userStopped"
# ensure __aiter__ closes correctly
assert num_times == 1
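
The cancellation test above also shows the pattern application code can use to stop a prediction early, for example after a fragment budget (sketch only; the budget and the model handle are illustrative assumptions):

async def bounded_completion(model, prompt: str, max_fragments: int = 20) -> str:
    pieces: list[str] = []
    stream = await model.complete_stream(prompt)
    async with stream:
        async for fragment in stream:
            pieces.append(fragment.content)
            if len(pieces) >= max_fragments:
                # As in the test, cancelling ends the stream, so the loop exits;
                # stream.stats.stop_reason is then reported as "userStopped".
                await stream.cancel()
    return "".join(pieces)
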
36 changes: 18 additions & 18 deletions tests/sync/test_inference_sync.py
@@ -76,17 +76,17 @@ def test_complete_stream_sync(caplog: LogCap) -> None:
model_id = EXPECTED_LLM_ID
with Client() as client:
session = client.llm
prediction = session._complete_stream(
prediction_stream = session._complete_stream(
model_id, prompt, config=SHORT_PREDICTION_CONFIG
)
assert isinstance(prediction, PredictionStream)
assert isinstance(prediction_stream, PredictionStream)
# Also exercise the explicit context management interface
with prediction:
for token in prediction:
logging.info(f"Token: {token}")
assert token
assert isinstance(token, str)
response = prediction.result()
with prediction_stream:
for fragment in prediction_stream:
logging.info(f"Fragment: {fragment}")
assert fragment.content
assert isinstance(fragment.content, str)
response = prediction_stream.result()
# The continuation from the LLM will change, but it won't be an empty string
logging.info(f"LLM response: {response!r}")
assert isinstance(response, PredictionResult)
@@ -153,7 +153,7 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
# This test case also covers the explicit context management interface
iteration_content: list[str] = []
with prediction_stream:
iteration_content = [text for text in prediction_stream]
iteration_content = [fragment.content for fragment in prediction_stream]
assert len(messages) == 1
message = messages[0]
assert message.role == "assistant"
@@ -207,7 +207,7 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
# This test case also covers the explicit context management interface
iteration_content: list[str] = []
with prediction_stream:
iteration_content = [text for text in prediction_stream]
iteration_content = [fragment.content for fragment in prediction_stream]
assert len(messages) == 1
message = messages[0]
assert message.role == "assistant"
@@ -265,10 +265,10 @@ def test_invalid_model_request_stream_sync(caplog: LogCap) -> None:
# This should error rather than timing out,
# but avoid any risk of the client hanging...
with nullcontext():
prediction = model.complete_stream("Some text")
with prediction:
prediction_stream = model.complete_stream("Some text")
with prediction_stream:
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
prediction.wait_for_result()
prediction_stream.wait_for_result()
check_sdk_error(exc_info, __file__)


@@ -280,11 +280,11 @@ def test_cancel_prediction_sync(caplog: LogCap) -> None:
caplog.set_level(logging.DEBUG)
with Client() as client:
session = client.llm
response = session._complete_stream(model_id, prompt=prompt)
for _ in response:
response.cancel()
stream = session._complete_stream(model_id, prompt=prompt)
for _ in stream:
stream.cancel()
num_times += 1
assert response.stats
assert response.stats.stop_reason == "userStopped"
assert stream.stats
assert stream.stats.stop_reason == "userStopped"
# ensure __aiter__ closes correctly
assert num_times == 1
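
As the updated sync tests illustrate, the on_prediction_fragment callback and the stream iterator now both deal in LlmPredictionFragment objects, so a recording callback can be checked directly against the iterated fragments. A sketch (the top-level exports, the model accessor, and passing the callback to the public respond_stream() are assumptions based on the private helpers exercised here):

from lmstudio import Client, LlmPredictionFragment  # assumed top-level exports

fragments: list[LlmPredictionFragment] = []


def record_fragment(fragment: LlmPredictionFragment) -> None:
    fragments.append(fragment)


with Client() as client:
    model = client.llm.model("some-model-id")  # assumed accessor
    # Assumption: the public streaming API accepts the same callback as respond()
    stream = model.respond_stream("Hello!", on_prediction_fragment=record_fragment)
    with stream:
        iterated = [fragment.content for fragment in stream]
    # The callback and the iterator observe the same fragments
    assert [fragment.content for fragment in fragments] == iterated
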
4 changes: 2 additions & 2 deletions tox.ini
@@ -24,11 +24,11 @@ commands =

[testenv:load-test-models]
commands =
python -m tests.load_models
python -W "ignore:Note the async API is not yet stable:FutureWarning" -m tests.load_models

[testenv:unload-test-models]
commands =
python -m tests.unload_models
python -W "ignore:Note the async API is not yet stable:FutureWarning" -m tests.unload_models

[testenv:coverage]
# Subprocess coverage based on https://hynek.me/articles/turbo-charge-tox/
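
The tox change works because Python's -W option takes a colon-delimited filter specification (action:message:category), so the message field cannot itself contain a colon; presumably this is why the colon was dropped from the warning text earlier in this diff. The equivalent in-process suppression would be a sketch like:

import warnings

# Ignore the SDK's async-API stability notice; filterwarnings matches the
# given pattern against the start of the warning message.
warnings.filterwarnings(
    "ignore",
    message="Note the async API is not yet stable",
    category=FutureWarning,
)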