Simplify prediction API type hinting #61

Merged
merged 1 commit on Mar 21, 2025
194 changes: 15 additions & 179 deletions src/lmstudio/async_api.py
@@ -16,12 +16,10 @@
Callable,
Generic,
Iterable,
Literal,
Sequence,
Type,
TypeAlias,
TypeVar,
overload,
)
from typing_extensions import (
# Native in 3.11+
@@ -89,7 +87,6 @@
RemoteCallHandler,
ResponseSchema,
TModelInfo,
TPrediction,
check_model_namespace,
load_struct,
_model_spec_to_api_dict,
@@ -902,24 +899,23 @@ async def _fetch_file_handle(self, file_data: _LocalFileData) -> FileHandle:
return await self._files_session._fetch_file_handle(file_data)


AsyncPredictionChannel: TypeAlias = AsyncChannel[PredictionResult[T]]
AsyncPredictionCM: TypeAlias = AsyncContextManager[AsyncPredictionChannel[T]]
AsyncPredictionChannel: TypeAlias = AsyncChannel[PredictionResult]
AsyncPredictionCM: TypeAlias = AsyncContextManager[AsyncPredictionChannel]


class AsyncPredictionStream(PredictionStreamBase[TPrediction]):
class AsyncPredictionStream(PredictionStreamBase):
"""Async context manager for an ongoing prediction process."""

def __init__(
self,
channel_cm: AsyncPredictionCM[TPrediction],
endpoint: PredictionEndpoint[TPrediction],
channel_cm: AsyncPredictionCM,
endpoint: PredictionEndpoint,
) -> None:
"""Initialize a prediction process representation."""
self._resource_manager = AsyncExitStack()
self._channel_cm: AsyncPredictionCM[TPrediction] = channel_cm
self._channel: AsyncPredictionChannel[TPrediction] | None = None
# See comments in BasePrediction regarding not calling super().__init__() here
self._init_prediction(endpoint)
self._channel_cm: AsyncPredictionCM = channel_cm
self._channel: AsyncPredictionChannel | None = None
super().__init__(endpoint)

@sdk_public_api_async()
async def start(self) -> None:
@@ -976,7 +972,7 @@ async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]:
self._mark_finished()

@sdk_public_api_async()
async def wait_for_result(self) -> PredictionResult[TPrediction]:
async def wait_for_result(self) -> PredictionResult:
"""Wait for the result of the prediction."""
async for _ in self:
pass
@@ -1011,34 +1007,6 @@ def _create_handle(self, model_identifier: str) -> "AsyncLLM":
"""Create a symbolic handle to the specified LLM model."""
return AsyncLLM(model_identifier, self)

@overload
async def _complete_stream(
self,
model_specifier: AnyModelSpecifier,
prompt: str,
*,
response_format: Literal[None] = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> AsyncPredictionStream[str]: ...
@overload
async def _complete_stream(
self,
model_specifier: AnyModelSpecifier,
prompt: str,
*,
response_format: ResponseSchema = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> AsyncPredictionStream[DictObject]: ...
async def _complete_stream(
self,
model_specifier: AnyModelSpecifier,
@@ -1051,7 +1019,7 @@ async def _complete_stream(
on_first_token: PredictionFirstTokenCallback | None = None,
on_prediction_fragment: PredictionFragmentCallback | None = None,
on_prompt_processing_progress: PromptProcessingCallback | None = None,
) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
) -> AsyncPredictionStream:
"""Request a one-off prediction without any context and stream the generated tokens.

Note: details of configuration fields may change in SDK feature releases.
@@ -1071,34 +1039,6 @@ async def _complete_stream(
prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
return prediction_stream

@overload
async def _respond_stream(
self,
model_specifier: AnyModelSpecifier,
history: Chat | ChatHistoryDataDict | str,
*,
response_format: Literal[None] = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> AsyncPredictionStream[str]: ...
@overload
async def _respond_stream(
self,
model_specifier: AnyModelSpecifier,
history: Chat | ChatHistoryDataDict | str,
*,
response_format: ResponseSchema = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> AsyncPredictionStream[DictObject]: ...
async def _respond_stream(
self,
model_specifier: AnyModelSpecifier,
@@ -1111,7 +1051,7 @@ async def _respond_stream(
on_first_token: PredictionFirstTokenCallback | None = None,
on_prediction_fragment: PredictionFragmentCallback | None = None,
on_prompt_processing_progress: PromptProcessingCallback | None = None,
) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
) -> AsyncPredictionStream:
"""Request a response in an ongoing assistant chat session and stream the generated tokens.

Note: details of configuration fields may change in SDK feature releases.
@@ -1250,32 +1190,6 @@ async def get_context_length(self) -> int:
class AsyncLLM(AsyncModelHandle[AsyncSessionLlm]):
"""Reference to a loaded LLM model."""

@overload
async def complete_stream(
self,
prompt: str,
*,
response_format: Literal[None] = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> AsyncPredictionStream[str]: ...
@overload
async def complete_stream(
self,
prompt: str,
*,
response_format: ResponseSchema = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> AsyncPredictionStream[DictObject]: ...
@sdk_public_api_async()
async def complete_stream(
self,
@@ -1288,7 +1202,7 @@ async def complete_stream(
on_first_token: PredictionFirstTokenCallback | None = None,
on_prediction_fragment: PredictionFragmentCallback | None = None,
on_prompt_processing_progress: PromptProcessingCallback | None = None,
) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
) -> AsyncPredictionStream:
"""Request a one-off prediction without any context and stream the generated tokens.

Note: details of configuration fields may change in SDK feature releases.
@@ -1305,32 +1219,6 @@ async def complete_stream(
on_prompt_processing_progress=on_prompt_processing_progress,
)

@overload
async def complete(
self,
prompt: str,
*,
response_format: Literal[None] = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> PredictionResult[str]: ...
@overload
async def complete(
self,
prompt: str,
*,
response_format: ResponseSchema = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> PredictionResult[DictObject]: ...
@sdk_public_api_async()
async def complete(
self,
@@ -1343,7 +1231,7 @@ async def complete(
on_first_token: PredictionFirstTokenCallback | None = None,
on_prediction_fragment: PredictionFragmentCallback | None = None,
on_prompt_processing_progress: PromptProcessingCallback | None = None,
) -> PredictionResult[str] | PredictionResult[DictObject]:
) -> PredictionResult:
"""Request a one-off prediction without any context.

Note: details of configuration fields may change in SDK feature releases.
@@ -1365,32 +1253,6 @@ async def complete(
pass
return prediction_stream.result()

@overload
async def respond_stream(
self,
history: Chat | ChatHistoryDataDict | str,
*,
response_format: Literal[None] = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> AsyncPredictionStream[str]: ...
@overload
async def respond_stream(
self,
history: Chat | ChatHistoryDataDict | str,
*,
response_format: ResponseSchema = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> AsyncPredictionStream[DictObject]: ...
@sdk_public_api_async()
async def respond_stream(
self,
@@ -1403,7 +1265,7 @@ async def respond_stream(
on_first_token: PredictionFirstTokenCallback | None = None,
on_prediction_fragment: PredictionFragmentCallback | None = None,
on_prompt_processing_progress: PromptProcessingCallback | None = None,
) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
) -> AsyncPredictionStream:
"""Request a response in an ongoing assistant chat session and stream the generated tokens.

Note: details of configuration fields may change in SDK feature releases.
@@ -1420,32 +1282,6 @@ async def respond_stream(
on_prompt_processing_progress=on_prompt_processing_progress,
)

@overload
async def respond(
self,
history: Chat | ChatHistoryDataDict | str,
*,
response_format: Literal[None] = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> PredictionResult[str]: ...
@overload
async def respond(
self,
history: Chat | ChatHistoryDataDict | str,
*,
response_format: ResponseSchema = ...,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
preset: str | None = ...,
on_message: PredictionMessageCallback | None = ...,
on_first_token: PredictionFirstTokenCallback | None = ...,
on_prediction_fragment: PredictionFragmentCallback | None = ...,
on_prompt_processing_progress: PromptProcessingCallback | None = ...,
) -> PredictionResult[DictObject]: ...
@sdk_public_api_async()
async def respond(
self,
@@ -1458,7 +1294,7 @@ async def respond(
on_first_token: PredictionFirstTokenCallback | None = None,
on_prediction_fragment: PredictionFragmentCallback | None = None,
on_prompt_processing_progress: PromptProcessingCallback | None = None,
) -> PredictionResult[str] | PredictionResult[DictObject]:
) -> PredictionResult:
"""Request a response in an ongoing assistant chat session.

Note: details of configuration fields may change in SDK feature releases.
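For reference, here is a minimal caller-side sketch of what this simplification means: after this change, `complete()` and `respond()` are annotated as returning plain `PredictionResult` (and the streaming variants plain `AsyncPredictionStream`) whether or not a structured `response_format` is supplied, instead of the overload-selected `PredictionResult[str]` / `PredictionResult[DictObject]` forms removed above. Only the `complete()` keyword signature is taken from this diff; the client entry point, model lookup, and schema shape below are illustrative assumptions, not part of the PR.

```python
# Illustrative sketch only: AsyncClient and the model lookup are assumptions
# for this example; the complete() call and its PredictionResult return
# annotation are what this PR simplifies.
import asyncio

import lmstudio


async def main() -> None:
    async with lmstudio.AsyncClient() as client:  # hypothetical entry point
        llm = await client.llm.model("some-model")  # hypothetical lookup helper

        # Plain text completion: now annotated simply as PredictionResult.
        plain = await llm.complete("Name three prime numbers.")

        # Structured completion: previously typed as PredictionResult[DictObject]
        # via an overload, now also just PredictionResult. The JSON-schema dict
        # passed as response_format is an assumption about ResponseSchema.
        structured = await llm.complete(
            "Name three prime numbers.",
            response_format={
                "type": "object",
                "properties": {
                    "primes": {"type": "array", "items": {"type": "integer"}}
                },
                "required": ["primes"],
            },
        )

        print(plain, structured, sep="\n")


asyncio.run(main())
```

The practical effect is that static checkers no longer narrow the result type based on whether `response_format` was passed; callers see one result type either way, which is the trade-off this PR accepts in exchange for removing the large overload blocks deleted above.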