From bd762f975eff5e33974afcf71ed48a29c1a7c1d6 Mon Sep 17 00:00:00 2001 From: Alyssa Coghlan Date: Sat, 22 Mar 2025 00:49:46 +1000 Subject: [PATCH] Simplify prediction API type hinting With the addition of GBNF grammar structured response format support, structured responses are no longer required to contain valid JSON. Accordingly, type hinting for prediction APIs has been simplified: * `PredictionResult` is no longer a generic type * instead, `parsed` is defined as a `str`/`dict` union field * whether it is a dict or not is now solely determined at runtime (based on whether a schema was sent to the server, and the content field has been successfully parsed as a JSON response) * other types that were only generic in the kind of prediction they returned are also now no longer generic * prediction APIs no longer define overloads that attempt to infer whether the result will be structured or not (these were already inaccurate, as they only considered the `response_format` parameter, ignoring the `structured` field in the prediction config) --- src/lmstudio/async_api.py | 194 +++----------------------------------- src/lmstudio/json_api.py | 49 +++++----- src/lmstudio/sync_api.py | 194 +++----------------------------------- tests/test_history.py | 3 +- tests/test_inference.py | 4 +- 5 files changed, 57 insertions(+), 387 deletions(-) diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py index 7a3a520..3f8e119 100644 --- a/src/lmstudio/async_api.py +++ b/src/lmstudio/async_api.py @@ -16,12 +16,10 @@ Callable, Generic, Iterable, - Literal, Sequence, Type, TypeAlias, TypeVar, - overload, ) from typing_extensions import ( # Native in 3.11+ @@ -89,7 +87,6 @@ RemoteCallHandler, ResponseSchema, TModelInfo, - TPrediction, check_model_namespace, load_struct, _model_spec_to_api_dict, @@ -902,24 +899,23 @@ async def _fetch_file_handle(self, file_data: _LocalFileData) -> FileHandle: return await self._files_session._fetch_file_handle(file_data) -AsyncPredictionChannel: TypeAlias = AsyncChannel[PredictionResult[T]] -AsyncPredictionCM: TypeAlias = AsyncContextManager[AsyncPredictionChannel[T]] +AsyncPredictionChannel: TypeAlias = AsyncChannel[PredictionResult] +AsyncPredictionCM: TypeAlias = AsyncContextManager[AsyncPredictionChannel] -class AsyncPredictionStream(PredictionStreamBase[TPrediction]): +class AsyncPredictionStream(PredictionStreamBase): """Async context manager for an ongoing prediction process.""" def __init__( self, - channel_cm: AsyncPredictionCM[TPrediction], - endpoint: PredictionEndpoint[TPrediction], + channel_cm: AsyncPredictionCM, + endpoint: PredictionEndpoint, ) -> None: """Initialize a prediction process representation.""" self._resource_manager = AsyncExitStack() - self._channel_cm: AsyncPredictionCM[TPrediction] = channel_cm - self._channel: AsyncPredictionChannel[TPrediction] | None = None - # See comments in BasePrediction regarding not calling super().__init__() here - self._init_prediction(endpoint) + self._channel_cm: AsyncPredictionCM = channel_cm + self._channel: AsyncPredictionChannel | None = None + super().__init__(endpoint) @sdk_public_api_async() async def start(self) -> None: @@ -976,7 +972,7 @@ async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]: self._mark_finished() @sdk_public_api_async() - async def wait_for_result(self) -> PredictionResult[TPrediction]: + async def wait_for_result(self) -> PredictionResult: """Wait for the result of the prediction.""" async for _ in self: pass @@ -1011,34 +1007,6 @@ def _create_handle(self, model_identifier: str) -> "AsyncLLM": """Create a symbolic handle to the specified LLM model.""" return AsyncLLM(model_identifier, self) - @overload - async def _complete_stream( - self, - model_specifier: AnyModelSpecifier, - prompt: str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> AsyncPredictionStream[str]: ... - @overload - async def _complete_stream( - self, - model_specifier: AnyModelSpecifier, - prompt: str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> AsyncPredictionStream[DictObject]: ... async def _complete_stream( self, model_specifier: AnyModelSpecifier, @@ -1051,7 +1019,7 @@ async def _complete_stream( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]: + ) -> AsyncPredictionStream: """Request a one-off prediction without any context and stream the generated tokens. Note: details of configuration fields may change in SDK feature releases. @@ -1071,34 +1039,6 @@ async def _complete_stream( prediction_stream = AsyncPredictionStream(channel_cm, endpoint) return prediction_stream - @overload - async def _respond_stream( - self, - model_specifier: AnyModelSpecifier, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> AsyncPredictionStream[str]: ... - @overload - async def _respond_stream( - self, - model_specifier: AnyModelSpecifier, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> AsyncPredictionStream[DictObject]: ... async def _respond_stream( self, model_specifier: AnyModelSpecifier, @@ -1111,7 +1051,7 @@ async def _respond_stream( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]: + ) -> AsyncPredictionStream: """Request a response in an ongoing assistant chat session and stream the generated tokens. Note: details of configuration fields may change in SDK feature releases. @@ -1250,32 +1190,6 @@ async def get_context_length(self) -> int: class AsyncLLM(AsyncModelHandle[AsyncSessionLlm]): """Reference to a loaded LLM model.""" - @overload - async def complete_stream( - self, - prompt: str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> AsyncPredictionStream[str]: ... - @overload - async def complete_stream( - self, - prompt: str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> AsyncPredictionStream[DictObject]: ... @sdk_public_api_async() async def complete_stream( self, @@ -1288,7 +1202,7 @@ async def complete_stream( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]: + ) -> AsyncPredictionStream: """Request a one-off prediction without any context and stream the generated tokens. Note: details of configuration fields may change in SDK feature releases. @@ -1305,32 +1219,6 @@ async def complete_stream( on_prompt_processing_progress=on_prompt_processing_progress, ) - @overload - async def complete( - self, - prompt: str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionResult[str]: ... - @overload - async def complete( - self, - prompt: str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionResult[DictObject]: ... @sdk_public_api_async() async def complete( self, @@ -1343,7 +1231,7 @@ async def complete( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> PredictionResult[str] | PredictionResult[DictObject]: + ) -> PredictionResult: """Request a one-off prediction without any context. Note: details of configuration fields may change in SDK feature releases. @@ -1365,32 +1253,6 @@ async def complete( pass return prediction_stream.result() - @overload - async def respond_stream( - self, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> AsyncPredictionStream[str]: ... - @overload - async def respond_stream( - self, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> AsyncPredictionStream[DictObject]: ... @sdk_public_api_async() async def respond_stream( self, @@ -1403,7 +1265,7 @@ async def respond_stream( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]: + ) -> AsyncPredictionStream: """Request a response in an ongoing assistant chat session and stream the generated tokens. Note: details of configuration fields may change in SDK feature releases. @@ -1420,32 +1282,6 @@ async def respond_stream( on_prompt_processing_progress=on_prompt_processing_progress, ) - @overload - async def respond( - self, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionResult[str]: ... - @overload - async def respond( - self, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionResult[DictObject]: ... @sdk_public_api_async() async def respond( self, @@ -1458,7 +1294,7 @@ async def respond( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> PredictionResult[str] | PredictionResult[DictObject]: + ) -> PredictionResult: """Request a response in an ongoing assistant chat session. Note: details of configuration fields may change in SDK feature releases. diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index 51452ae..e758a07 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -177,11 +177,13 @@ T = TypeVar("T") TStruct = TypeVar("TStruct", bound=AnyLMStudioStruct) -TPrediction = TypeVar("TPrediction", str, DictObject) DEFAULT_API_HOST = "localhost:1234" DEFAULT_TTL = 60 * 60 # By default, leaves idle models loaded for an hour +UnstructuredPrediction: TypeAlias = str +StructuredPrediction: TypeAlias = DictObject +AnyPrediction = StructuredPrediction | UnstructuredPrediction AnyModelSpecifier: TypeAlias = str | ModelSpecifier | ModelQuery | DictObject AnyLoadConfig: TypeAlias = EmbeddingLoadModelConfig | LlmLoadModelConfig @@ -445,12 +447,12 @@ class ModelLoadResult: @dataclass(kw_only=True, frozen=True, slots=True) -class PredictionResult(Generic[TPrediction]): +class PredictionResult: """The final result of a prediction.""" # fmt: off content: str # The text content of the prediction - parsed: TPrediction # dict for structured predictions, str otherwise + parsed: AnyPrediction # dict for structured predictions, str otherwise stats: LlmPredictionStats # Statistics about the prediction process model_info: LlmInfo # Information about the model used structured: bool = field(init=False) # Whether the result is structured or not @@ -475,13 +477,13 @@ def _to_history_content(self) -> str: @dataclass(kw_only=True, frozen=True, slots=True) -class PredictionRoundResult(PredictionResult[str]): +class PredictionRoundResult(PredictionResult): """The result of a prediction within a multi-round tool using action.""" round_index: int # The round within the action that produced this result @classmethod - def from_result(cls, result: PredictionResult[str], round_index: int) -> Self: + def from_result(cls, result: PredictionResult, round_index: int) -> Self: """Create a prediction round result from its underlying prediction result.""" copied_keys = { k: getattr(result, k) @@ -1110,10 +1112,7 @@ class PredictionToolCallAbortedEvent(ChannelRxEvent[None]): class PredictionEndpoint( - Generic[TPrediction], - ChannelEndpoint[ - PredictionResult[TPrediction], PredictionRxEvent, PredictionChannelRequestDict - ], + ChannelEndpoint[PredictionResult, PredictionRxEvent, PredictionChannelRequestDict], ): """Helper class for prediction endpoint message handling.""" @@ -1264,14 +1263,20 @@ def iter_message_events( # or has been successfully cancelled. Don't try # to parse the received content in the latter case. result_content = "".join(self._fragment_content) + parsed_content: AnyPrediction = result_content if self._structured and not self._is_cancelled: try: + # Check if the content is valid JSON parsed_content = json.loads(result_content) except json.JSONDecodeError: + # This likely indicates a non-JSON GBNF grammar # Fall back to unstructured result reporting - parsed_content = result_content - else: - parsed_content = result_content + pass + else: + if not isinstance(parsed_content, dict): + # This likely indicates a non-JSON GBNF grammar + # Fall back to unstructured result reporting + parsed_content = result_content yield self._set_result( PredictionResult( content=result_content, @@ -1381,7 +1386,7 @@ def mark_cancelled(self) -> None: self._is_cancelled = True -class CompletionEndpoint(PredictionEndpoint[TPrediction]): +class CompletionEndpoint(PredictionEndpoint): """API channel endpoint for requesting text completion from a model.""" _NOTICE_PREFIX = "Completion" @@ -1421,7 +1426,7 @@ def _additional_config_options(cls) -> DictObject: ToolDefinition: TypeAlias = ToolFunctionDef | ToolFunctionDefDict | Callable[..., Any] -class ChatResponseEndpoint(PredictionEndpoint[TPrediction]): +class ChatResponseEndpoint(PredictionEndpoint): """API channel endpoint for requesting a chat response from a model.""" _NOTICE_PREFIX = "Chat response" @@ -1457,26 +1462,20 @@ def parse_tools( return LlmToolUseSettingToolArray(tools=llm_tool_defs), client_tool_map -class PredictionStreamBase(Generic[TPrediction]): +class PredictionStreamBase: """Common base class for sync and async prediction streams.""" def __init__( self, - endpoint: PredictionEndpoint[TPrediction], + endpoint: PredictionEndpoint, ) -> None: """Initialize a prediction process representation.""" - # Split initialisation out to a separate helper function that plays nice with bound generic type vars - # To avoid type errors in mypy, subclasses may call this directly (instead of `super().__init__(endpoint)`) - # https://discuss.python.org/t/how-to-share-type-variables-when-inheriting-from-generic-base-classes/78839 - self._init_prediction(endpoint) - - def _init_prediction(self, endpoint: PredictionEndpoint[TPrediction]) -> None: - self._endpoint: PredictionEndpoint[TPrediction] = endpoint + self._endpoint = endpoint # Final result reporting self._is_started = False self._is_finished = False - self._final_result: PredictionResult[TPrediction] | None = None + self._final_result: PredictionResult | None = None self._error: BaseException | None = None @property @@ -1510,7 +1509,7 @@ def _prediction_config(self) -> LlmPredictionConfig | None: return self._final_result.prediction_config @sdk_public_api() - def result(self) -> PredictionResult[TPrediction]: + def result(self) -> PredictionResult: """Get the result of a completed prediction. This API raises an exception if the result is not available, diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index b9fa27f..720664b 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -27,13 +27,11 @@ Iterator, Callable, Generic, - Literal, NoReturn, Sequence, Type, TypeAlias, TypeVar, - overload, ) from typing_extensions import ( # Native in 3.11+ @@ -116,7 +114,6 @@ RemoteCallHandler, ResponseSchema, TModelInfo, - TPrediction, ToolDefinition, check_model_namespace, load_struct, @@ -1065,24 +1062,23 @@ def _fetch_file_handle(self, file_data: _LocalFileData) -> FileHandle: return self._files_session._fetch_file_handle(file_data) -SyncPredictionChannel: TypeAlias = SyncChannel[PredictionResult[T]] -SyncPredictionCM: TypeAlias = ContextManager[SyncPredictionChannel[T]] +SyncPredictionChannel: TypeAlias = SyncChannel[PredictionResult] +SyncPredictionCM: TypeAlias = ContextManager[SyncPredictionChannel] -class PredictionStream(PredictionStreamBase[TPrediction]): +class PredictionStream(PredictionStreamBase): """Sync context manager for an ongoing prediction process.""" def __init__( self, - channel_cm: SyncPredictionCM[TPrediction], - endpoint: PredictionEndpoint[TPrediction], + channel_cm: SyncPredictionCM, + endpoint: PredictionEndpoint, ) -> None: """Initialize a prediction process representation.""" self._resources = ExitStack() - self._channel_cm: SyncPredictionCM[TPrediction] = channel_cm - self._channel: SyncPredictionChannel[TPrediction] | None = None - # See comments in BasePrediction regarding not calling super().__init__() here - self._init_prediction(endpoint) + self._channel_cm: SyncPredictionCM = channel_cm + self._channel: SyncPredictionChannel | None = None + super().__init__(endpoint) @sdk_public_api() def start(self) -> None: @@ -1141,7 +1137,7 @@ def _iter_events(self) -> Iterator[PredictionRxEvent]: self._mark_finished() @sdk_public_api() - def wait_for_result(self) -> PredictionResult[TPrediction]: + def wait_for_result(self) -> PredictionResult: """Wait for the result of the prediction.""" for _ in self: pass @@ -1176,34 +1172,6 @@ def _create_handle(self, model_identifier: str) -> "LLM": """Create a symbolic handle to the specified LLM model.""" return LLM(model_identifier, self) - @overload - def _complete_stream( - self, - model_specifier: AnyModelSpecifier, - prompt: str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionStream[str]: ... - @overload - def _complete_stream( - self, - model_specifier: AnyModelSpecifier, - prompt: str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionStream[DictObject]: ... def _complete_stream( self, model_specifier: AnyModelSpecifier, @@ -1216,7 +1184,7 @@ def _complete_stream( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> PredictionStream[str] | PredictionStream[DictObject]: + ) -> PredictionStream: """Request a one-off prediction without any context and stream the generated tokens. Note: details of configuration fields may change in SDK feature releases. @@ -1236,34 +1204,6 @@ def _complete_stream( prediction_stream = PredictionStream(channel_cm, endpoint) return prediction_stream - @overload - def _respond_stream( - self, - model_specifier: AnyModelSpecifier, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionStream[str]: ... - @overload - def _respond_stream( - self, - model_specifier: AnyModelSpecifier, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionStream[DictObject]: ... def _respond_stream( self, model_specifier: AnyModelSpecifier, @@ -1276,7 +1216,7 @@ def _respond_stream( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> PredictionStream[str] | PredictionStream[DictObject]: + ) -> PredictionStream: """Request a response in an ongoing assistant chat session and stream the generated tokens. Note: details of configuration fields may change in SDK feature releases. @@ -1411,32 +1351,6 @@ def get_context_length(self) -> int: class LLM(SyncModelHandle[SyncSessionLlm]): """Reference to a loaded LLM model.""" - @overload - def complete_stream( - self, - prompt: str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionStream[str]: ... - @overload - def complete_stream( - self, - prompt: str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionStream[DictObject]: ... @sdk_public_api() def complete_stream( self, @@ -1449,7 +1363,7 @@ def complete_stream( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> PredictionStream[str] | PredictionStream[DictObject]: + ) -> PredictionStream: """Request a one-off prediction without any context and stream the generated tokens. Note: details of configuration fields may change in SDK feature releases. @@ -1466,32 +1380,6 @@ def complete_stream( on_prompt_processing_progress=on_prompt_processing_progress, ) - @overload - def complete( - self, - prompt: str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionResult[str]: ... - @overload - def complete( - self, - prompt: str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionResult[DictObject]: ... @sdk_public_api() def complete( self, @@ -1504,7 +1392,7 @@ def complete( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> PredictionResult[str] | PredictionResult[DictObject]: + ) -> PredictionResult: """Request a one-off prediction without any context. Note: details of configuration fields may change in SDK feature releases. @@ -1526,32 +1414,6 @@ def complete( pass return prediction_stream.result() - @overload - def respond_stream( - self, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionStream[str]: ... - @overload - def respond_stream( - self, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionStream[DictObject]: ... @sdk_public_api() def respond_stream( self, @@ -1564,7 +1426,7 @@ def respond_stream( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> PredictionStream[str] | PredictionStream[DictObject]: + ) -> PredictionStream: """Request a response in an ongoing assistant chat session and stream the generated tokens. Note: details of configuration fields may change in SDK feature releases. @@ -1581,32 +1443,6 @@ def respond_stream( on_prompt_processing_progress=on_prompt_processing_progress, ) - @overload - def respond( - self, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: Literal[None] = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionResult[str]: ... - @overload - def respond( - self, - history: Chat | ChatHistoryDataDict | str, - *, - response_format: ResponseSchema = ..., - config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - preset: str | None = ..., - on_message: PredictionMessageCallback | None = ..., - on_first_token: PredictionFirstTokenCallback | None = ..., - on_prediction_fragment: PredictionFragmentCallback | None = ..., - on_prompt_processing_progress: PromptProcessingCallback | None = ..., - ) -> PredictionResult[DictObject]: ... @sdk_public_api() def respond( self, @@ -1619,7 +1455,7 @@ def respond( on_first_token: PredictionFirstTokenCallback | None = None, on_prediction_fragment: PredictionFragmentCallback | None = None, on_prompt_processing_progress: PromptProcessingCallback | None = None, - ) -> PredictionResult[str] | PredictionResult[DictObject]: + ) -> PredictionResult: """Request a response in an ongoing assistant chat session. Note: details of configuration fields may change in SDK feature releases. diff --git a/tests/test_history.py b/tests/test_history.py index 233cdfd..902588b 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -28,7 +28,6 @@ LlmPredictionConfig, LlmPredictionStats, PredictionResult, - TPrediction, ) from .support import IMAGE_FILEPATH, check_sdk_error @@ -334,7 +333,7 @@ def test_add_entries_class_content() -> None: assert chat._get_history_for_prediction() == EXPECTED_HISTORY -def _make_prediction_result(data: TPrediction) -> PredictionResult[TPrediction]: +def _make_prediction_result(data: str | DictObject) -> PredictionResult: return PredictionResult( content=(data if isinstance(data, str) else json.dumps(data)), parsed=data, diff --git a/tests/test_inference.py b/tests/test_inference.py index d7bb02a..a523b9b 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -67,7 +67,7 @@ async def test_concurrent_predictions(caplog: LogCap, subtests: SubTests) -> Non async with AsyncClient() as client: session = client.llm - async def _request_response() -> PredictionResult[str]: + async def _request_response() -> PredictionResult: llm = await session.model(model_id) return await llm.respond( history=history, @@ -172,7 +172,7 @@ def test_tool_using_agent(caplog: LogCap) -> None: chat.add_user_message("What is the sum of 123 and 3210?") tools = [ADDITION_TOOL_SPEC] # Ensure ignoring the round index passes static type checks - predictions: list[PredictionResult[str]] = [] + predictions: list[PredictionResult] = [] act_result = llm.act(chat, tools, on_prediction_completed=predictions.append) assert len(predictions) > 1