diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py
index 6c9759b..1b03ecc 100644
--- a/src/lmstudio/async_api.py
+++ b/src/lmstudio/async_api.py
@@ -909,7 +909,7 @@ async def __aexit__(
         self._set_error(exc_val)
         await self.aclose()
 
-    async def __aiter__(self) -> AsyncIterator[str]:
+    async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]:
         endpoint = self._endpoint
         async with self:
             assert self._channel is not None
@@ -917,7 +917,7 @@ async def __aiter__(self) -> AsyncIterator[str]:
                 for event in endpoint.iter_message_events(contents):
                     endpoint.handle_rx_event(event)
                     if isinstance(event, PredictionFragmentEvent):
-                        yield event.arg.content
+                        yield event.arg
                 if endpoint.is_finished:
                     break
         self._mark_finished()
@@ -1008,8 +1008,8 @@ async def _complete_stream(
             on_prompt_processing_progress,
         )
         channel_cm = self._create_channel(endpoint)
-        prediction = AsyncPredictionStream(channel_cm, endpoint)
-        return prediction
+        prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
+        return prediction_stream
 
     @overload
     async def _respond_stream(
@@ -1064,8 +1064,8 @@ async def _respond_stream(
             on_prompt_processing_progress,
         )
         channel_cm = self._create_channel(endpoint)
-        prediction = AsyncPredictionStream(channel_cm, endpoint)
-        return prediction
+        prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
+        return prediction_stream
 
     async def _apply_prompt_template(
         self,
@@ -1264,7 +1264,7 @@ async def complete(
         on_prompt_processing_progress: Callable[[float], None] | None = None,
     ) -> PredictionResult[str] | PredictionResult[DictObject]:
         """Request a one-off prediction without any context."""
-        prediction = await self._session._complete_stream(
+        prediction_stream = await self._session._complete_stream(
             self.identifier,
             prompt,
             response_format=response_format,
@@ -1274,11 +1274,11 @@ async def complete(
             on_prediction_fragment=on_prediction_fragment,
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
-        async for _ in prediction:
+        async for _ in prediction_stream:
             # No yield in body means iterator reliably provides
             # prompt resource cleanup on coroutine cancellation
             pass
-        return prediction.result()
+        return prediction_stream.result()
 
     @overload
     async def respond_stream(
@@ -1365,7 +1365,7 @@ async def respond(
         on_prompt_processing_progress: Callable[[float], None] | None = None,
     ) -> PredictionResult[str] | PredictionResult[DictObject]:
         """Request a response in an ongoing assistant chat session."""
-        prediction = await self._session._respond_stream(
+        prediction_stream = await self._session._respond_stream(
             self.identifier,
             history,
             response_format=response_format,
@@ -1375,11 +1375,11 @@ async def respond(
             on_prediction_fragment=on_prediction_fragment,
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
-        async for _ in prediction:
+        async for _ in prediction_stream:
             # No yield in body means iterator reliably provides
             # prompt resource cleanup on coroutine cancellation
             pass
-        return prediction.result()
+        return prediction_stream.result()
 
     @sdk_public_api_async()
     async def apply_prompt_template(
@@ -1411,7 +1411,7 @@ async def embed(
 TAsyncSession = TypeVar("TAsyncSession", bound=AsyncSession)
 
 _ASYNC_API_STABILITY_WARNING = """\
-Note: the async API is not yet stable and is expected to change in future releases
+Note the async API is not yet stable and is expected to change in future releases
 """
 
 
diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py
index cbd06c3..1b10cf9 100644
--- a/src/lmstudio/sync_api.py
+++ b/src/lmstudio/sync_api.py
@@ -1062,10 +1062,10 @@ def __exit__(
         self._set_error(exc_val)
         self.close()
 
-    def __iter__(self) -> Iterator[str]:
+    def __iter__(self) -> Iterator[LlmPredictionFragment]:
         for event in self._iter_events():
             if isinstance(event, PredictionFragmentEvent):
-                yield event.arg.content
+                yield event.arg
 
     def _iter_events(self) -> Iterator[PredictionRxEvent]:
         endpoint = self._endpoint
@@ -1165,8 +1165,8 @@ def _complete_stream(
             on_prompt_processing_progress,
         )
         channel_cm = self._create_channel(endpoint)
-        prediction = PredictionStream(channel_cm, endpoint)
-        return prediction
+        prediction_stream = PredictionStream(channel_cm, endpoint)
+        return prediction_stream
 
     @overload
     def _respond_stream(
@@ -1221,8 +1221,8 @@ def _respond_stream(
             on_prompt_processing_progress,
         )
         channel_cm = self._create_channel(endpoint)
-        prediction = PredictionStream(channel_cm, endpoint)
-        return prediction
+        prediction_stream = PredictionStream(channel_cm, endpoint)
+        return prediction_stream
 
     def _apply_prompt_template(
         self,
@@ -1419,7 +1419,7 @@ def complete(
         on_prompt_processing_progress: Callable[[float], None] | None = None,
     ) -> PredictionResult[str] | PredictionResult[DictObject]:
         """Request a one-off prediction without any context."""
-        prediction = self._session._complete_stream(
+        prediction_stream = self._session._complete_stream(
             self.identifier,
             prompt,
             response_format=response_format,
@@ -1429,11 +1429,11 @@ def complete(
             on_prediction_fragment=on_prediction_fragment,
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
-        for _ in prediction:
+        for _ in prediction_stream:
             # No yield in body means iterator reliably provides
             # prompt resource cleanup on coroutine cancellation
             pass
-        return prediction.result()
+        return prediction_stream.result()
 
     @overload
     def respond_stream(
@@ -1520,7 +1520,7 @@ def respond(
         on_prompt_processing_progress: Callable[[float], None] | None = None,
     ) -> PredictionResult[str] | PredictionResult[DictObject]:
         """Request a response in an ongoing assistant chat session."""
-        prediction = self._session._respond_stream(
+        prediction_stream = self._session._respond_stream(
             self.identifier,
             history,
             response_format=response_format,
@@ -1530,11 +1530,11 @@ def respond(
             on_prediction_fragment=on_prediction_fragment,
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
-        for _ in prediction:
+        for _ in prediction_stream:
             # No yield in body means iterator reliably provides
             # prompt resource cleanup on coroutine cancellation
             pass
-        return prediction.result()
+        return prediction_stream.result()
 
 # Multi-round predictions are currently a sync-only handle-only feature
 # TODO: Refactor to allow for more code sharing with the async API
diff --git a/tests/async/test_inference_async.py b/tests/async/test_inference_async.py
index 79e8f2c..88b3f77 100644
--- a/tests/async/test_inference_async.py
+++ b/tests/async/test_inference_async.py
@@ -72,17 +72,17 @@ async def test_complete_stream_async(caplog: LogCap) -> None:
     model_id = EXPECTED_LLM_ID
     async with AsyncClient() as client:
         session = client.llm
-        prediction = await session._complete_stream(
+        prediction_stream = await session._complete_stream(
             model_id, prompt, config=SHORT_PREDICTION_CONFIG
         )
-        assert isinstance(prediction, AsyncPredictionStream)
+        assert isinstance(prediction_stream, AsyncPredictionStream)
         # Also exercise the explicit context management interface
-        async with prediction:
-            async for token in prediction:
-                logging.info(f"Token: {token}")
-                assert token
-                assert isinstance(token, str)
-            response = prediction.result()
+        async with prediction_stream:
+            async for fragment in prediction_stream:
+                logging.info(f"Fragment: {fragment}")
+                assert fragment.content
+                assert isinstance(fragment.content, str)
+            response = prediction_stream.result()
     # The continuation from the LLM will change, but it won't be an empty string
     logging.info(f"LLM response: {response!r}")
     assert isinstance(response, PredictionResult)
@@ -151,7 +151,9 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
         # This test case also covers the explicit context management interface
         iteration_content: list[str] = []
         async with prediction_stream:
-            iteration_content = [text async for text in prediction_stream]
+            iteration_content = [
+                fragment.content async for fragment in prediction_stream
+            ]
         assert len(messages) == 1
         message = messages[0]
         assert message.role == "assistant"
@@ -206,7 +208,9 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
         # This test case also covers the explicit context management interface
         iteration_content: list[str] = []
         async with prediction_stream:
-            iteration_content = [text async for text in prediction_stream]
+            iteration_content = [
+                fragment.content async for fragment in prediction_stream
+            ]
         assert len(messages) == 1
         message = messages[0]
         assert message.role == "assistant"
@@ -267,10 +271,10 @@ async def test_invalid_model_request_stream_async(caplog: LogCap) -> None:
         # This should error rather than timing out,
         # but avoid any risk of the client hanging...
         async with asyncio.timeout(30):
-            prediction = await model.complete_stream("Some text")
-            async with prediction:
+            prediction_stream = await model.complete_stream("Some text")
+            async with prediction_stream:
                 with pytest.raises(LMStudioModelNotFoundError) as exc_info:
-                    await prediction.wait_for_result()
+                    await prediction_stream.wait_for_result()
     check_sdk_error(exc_info, __file__)
 
 
@@ -283,11 +287,11 @@ async def test_cancel_prediction_async(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
    async with AsyncClient() as client:
         session = client.llm
-        response = await session._complete_stream(model_id, prompt=prompt)
-        async for _ in response:
-            await response.cancel()
+        stream = await session._complete_stream(model_id, prompt=prompt)
+        async for _ in stream:
+            await stream.cancel()
             num_times += 1
-        assert response.stats
-        assert response.stats.stop_reason == "userStopped"
+        assert stream.stats
+        assert stream.stats.stop_reason == "userStopped"
     # ensure __aiter__ closes correctly
     assert num_times == 1
diff --git a/tests/sync/test_inference_sync.py b/tests/sync/test_inference_sync.py
index b700f1e..1c1533e 100644
--- a/tests/sync/test_inference_sync.py
+++ b/tests/sync/test_inference_sync.py
@@ -76,17 +76,17 @@ def test_complete_stream_sync(caplog: LogCap) -> None:
     model_id = EXPECTED_LLM_ID
     with Client() as client:
         session = client.llm
-        prediction = session._complete_stream(
+        prediction_stream = session._complete_stream(
             model_id, prompt, config=SHORT_PREDICTION_CONFIG
         )
-        assert isinstance(prediction, PredictionStream)
+        assert isinstance(prediction_stream, PredictionStream)
         # Also exercise the explicit context management interface
-        with prediction:
-            for token in prediction:
-                logging.info(f"Token: {token}")
-                assert token
-                assert isinstance(token, str)
-            response = prediction.result()
+        with prediction_stream:
+            for fragment in prediction_stream:
+                logging.info(f"Fragment: {fragment}")
+                assert fragment.content
+                assert isinstance(fragment.content, str)
+            response = prediction_stream.result()
     # The continuation from the LLM will change, but it won't be an empty string
     logging.info(f"LLM response: {response!r}")
     assert isinstance(response, PredictionResult)
@@ -153,7 +153,7 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
         # This test case also covers the explicit context management interface
         iteration_content: list[str] = []
         with prediction_stream:
-            iteration_content = [text for text in prediction_stream]
+            iteration_content = [fragment.content for fragment in prediction_stream]
         assert len(messages) == 1
         message = messages[0]
         assert message.role == "assistant"
@@ -207,7 +207,7 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
         # This test case also covers the explicit context management interface
         iteration_content: list[str] = []
         with prediction_stream:
-            iteration_content = [text for text in prediction_stream]
+            iteration_content = [fragment.content for fragment in prediction_stream]
         assert len(messages) == 1
         message = messages[0]
         assert message.role == "assistant"
@@ -265,10 +265,10 @@ def test_invalid_model_request_stream_sync(caplog: LogCap) -> None:
         # This should error rather than timing out,
        # but avoid any risk of the client hanging...
         with nullcontext():
-            prediction = model.complete_stream("Some text")
-            with prediction:
+            prediction_stream = model.complete_stream("Some text")
+            with prediction_stream:
                 with pytest.raises(LMStudioModelNotFoundError) as exc_info:
-                    prediction.wait_for_result()
+                    prediction_stream.wait_for_result()
     check_sdk_error(exc_info, __file__)
 
 
@@ -280,11 +280,11 @@ def test_cancel_prediction_sync(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     with Client() as client:
         session = client.llm
-        response = session._complete_stream(model_id, prompt=prompt)
-        for _ in response:
-            response.cancel()
+        stream = session._complete_stream(model_id, prompt=prompt)
+        for _ in stream:
+            stream.cancel()
             num_times += 1
-        assert response.stats
-        assert response.stats.stop_reason == "userStopped"
+        assert stream.stats
+        assert stream.stats.stop_reason == "userStopped"
     # ensure __aiter__ closes correctly
     assert num_times == 1
diff --git a/tox.ini b/tox.ini
index 11f1903..f7f9091 100644
--- a/tox.ini
+++ b/tox.ini
@@ -24,11 +24,11 @@ commands =
 
 [testenv:load-test-models]
 commands =
-    python -m tests.load_models
+    python -W "ignore:Note the async API is not yet stable:FutureWarning" -m tests.load_models
 
 [testenv:unload-test-models]
 commands =
-    python -m tests.unload_models
+    python -W "ignore:Note the async API is not yet stable:FutureWarning" -m tests.unload_models
 
 [testenv:coverage]
 # Subprocess coverage based on https://hynek.me/articles/turbo-charge-tox/
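The net effect of this change for callers: prediction streams now yield LlmPredictionFragment objects instead of plain strings, so consumers read the streamed text via fragment.content. The following is a minimal sketch mirroring the updated sync test above; it drives the private session helper _complete_stream() as the test does, and the import location of Client, the model identifier, and the prompt are illustrative assumptions rather than details taken from this diff.

    # Sketch only: consume a prediction stream fragment by fragment.
    from lmstudio import Client  # assumed top-level export

    with Client() as client:
        session = client.llm
        prediction_stream = session._complete_stream(
            "placeholder-model-id",  # hypothetical model identifier
            "Hello, world",          # hypothetical prompt
        )
        with prediction_stream:
            # Iteration now yields LlmPredictionFragment objects, not str tokens,
            # so the streamed text lives on the .content attribute.
            for fragment in prediction_stream:
                print(fragment.content, end="", flush=True)
        result = prediction_stream.result()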