diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py
index 7a3a520..3f8e119 100644
--- a/src/lmstudio/async_api.py
+++ b/src/lmstudio/async_api.py
@@ -16,12 +16,10 @@
     Callable,
     Generic,
     Iterable,
-    Literal,
     Sequence,
     Type,
     TypeAlias,
     TypeVar,
-    overload,
 )
 from typing_extensions import (
     # Native in 3.11+
@@ -89,7 +87,6 @@
     RemoteCallHandler,
     ResponseSchema,
     TModelInfo,
-    TPrediction,
     check_model_namespace,
     load_struct,
     _model_spec_to_api_dict,
@@ -902,24 +899,23 @@ async def _fetch_file_handle(self, file_data: _LocalFileData) -> FileHandle:
         return await self._files_session._fetch_file_handle(file_data)
 
 
-AsyncPredictionChannel: TypeAlias = AsyncChannel[PredictionResult[T]]
-AsyncPredictionCM: TypeAlias = AsyncContextManager[AsyncPredictionChannel[T]]
+AsyncPredictionChannel: TypeAlias = AsyncChannel[PredictionResult]
+AsyncPredictionCM: TypeAlias = AsyncContextManager[AsyncPredictionChannel]
 
 
-class AsyncPredictionStream(PredictionStreamBase[TPrediction]):
+class AsyncPredictionStream(PredictionStreamBase):
     """Async context manager for an ongoing prediction process."""
 
     def __init__(
         self,
-        channel_cm: AsyncPredictionCM[TPrediction],
-        endpoint: PredictionEndpoint[TPrediction],
+        channel_cm: AsyncPredictionCM,
+        endpoint: PredictionEndpoint,
     ) -> None:
         """Initialize a prediction process representation."""
         self._resource_manager = AsyncExitStack()
-        self._channel_cm: AsyncPredictionCM[TPrediction] = channel_cm
-        self._channel: AsyncPredictionChannel[TPrediction] | None = None
-        # See comments in BasePrediction regarding not calling super().__init__() here
-        self._init_prediction(endpoint)
+        self._channel_cm: AsyncPredictionCM = channel_cm
+        self._channel: AsyncPredictionChannel | None = None
+        super().__init__(endpoint)
 
     @sdk_public_api_async()
     async def start(self) -> None:
@@ -976,7 +972,7 @@ async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]:
             self._mark_finished()
 
     @sdk_public_api_async()
-    async def wait_for_result(self) -> PredictionResult[TPrediction]:
+    async def wait_for_result(self) -> PredictionResult:
         """Wait for the result of the prediction."""
         async for _ in self:
             pass
@@ -1011,34 +1007,6 @@ def _create_handle(self, model_identifier: str) -> "AsyncLLM":
         """Create a symbolic handle to the specified LLM model."""
         return AsyncLLM(model_identifier, self)
 
-    @overload
-    async def _complete_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def _complete_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     async def _complete_stream(
         self,
         model_specifier: AnyModelSpecifier,
@@ -1051,7 +1019,7 @@ async def _complete_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a one-off prediction without any context and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1071,34 +1039,6 @@ async def _complete_stream(
         prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
         return prediction_stream
 
-    @overload
-    async def _respond_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def _respond_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     async def _respond_stream(
         self,
         model_specifier: AnyModelSpecifier,
@@ -1111,7 +1051,7 @@ async def _respond_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a response in an ongoing assistant chat session and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1250,32 +1190,6 @@ async def get_context_length(self) -> int:
 class AsyncLLM(AsyncModelHandle[AsyncSessionLlm]):
     """Reference to a loaded LLM model."""
 
-    @overload
-    async def complete_stream(
-        self,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def complete_stream(
-        self,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     @sdk_public_api_async()
     async def complete_stream(
         self,
@@ -1288,7 +1202,7 @@ async def complete_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a one-off prediction without any context and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1305,32 +1219,6 @@ async def complete_stream(
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
 
-    @overload
-    async def complete(
-        self,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[str]: ...
-    @overload
-    async def complete(
-        self,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[DictObject]: ...
     @sdk_public_api_async()
     async def complete(
         self,
@@ -1343,7 +1231,7 @@ async def complete(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionResult[str] | PredictionResult[DictObject]:
+    ) -> PredictionResult:
         """Request a one-off prediction without any context.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1365,32 +1253,6 @@ async def complete(
             pass
         return prediction_stream.result()
 
-    @overload
-    async def respond_stream(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def respond_stream(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     @sdk_public_api_async()
     async def respond_stream(
         self,
@@ -1403,7 +1265,7 @@ async def respond_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a response in an ongoing assistant chat session and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1420,32 +1282,6 @@ async def respond_stream(
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
 
-    @overload
-    async def respond(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[str]: ...
-    @overload
-    async def respond(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[DictObject]: ...
     @sdk_public_api_async()
     async def respond(
         self,
@@ -1458,7 +1294,7 @@ async def respond(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionResult[str] | PredictionResult[DictObject]:
+    ) -> PredictionResult:
         """Request a response in an ongoing assistant chat session.
 
         Note: details of configuration fields may change in SDK feature releases.
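With the async overloads removed, callers no longer get a statically narrowed `AsyncPredictionStream[str]` vs `AsyncPredictionStream[DictObject]`: everything returns the plain `AsyncPredictionStream` / `PredictionResult`. A minimal caller-side sketch of what that implies, assuming `AsyncClient` is importable from the top-level package as in the SDK's tests; the model identifier and schema below are placeholders, not part of the patch:

```python
from lmstudio import AsyncClient

# Placeholder schema and model name, for illustration only
SCHEMA = {"type": "object", "properties": {"answer": {"type": "string"}}}


async def demo() -> None:
    async with AsyncClient() as client:
        llm = await client.llm.model("some-model")
        result = await llm.respond("Reply in JSON.", response_format=SCHEMA)
        # result is a plain PredictionResult; .parsed is dict | str,
        # so callers narrow it at runtime instead of relying on overloads
        if isinstance(result.parsed, dict):
            print(result.parsed.get("answer"))
        else:
            print(result.parsed)
```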
diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py
index 51452ae..e758a07 100644
--- a/src/lmstudio/json_api.py
+++ b/src/lmstudio/json_api.py
@@ -177,11 +177,13 @@
 
 T = TypeVar("T")
 TStruct = TypeVar("TStruct", bound=AnyLMStudioStruct)
-TPrediction = TypeVar("TPrediction", str, DictObject)
 
 DEFAULT_API_HOST = "localhost:1234"
 DEFAULT_TTL = 60 * 60  # By default, leaves idle models loaded for an hour
 
+UnstructuredPrediction: TypeAlias = str
+StructuredPrediction: TypeAlias = DictObject
+AnyPrediction = StructuredPrediction | UnstructuredPrediction
 AnyModelSpecifier: TypeAlias = str | ModelSpecifier | ModelQuery | DictObject
 
 AnyLoadConfig: TypeAlias = EmbeddingLoadModelConfig | LlmLoadModelConfig
@@ -445,12 +447,12 @@ class ModelLoadResult:
 
 
 @dataclass(kw_only=True, frozen=True, slots=True)
-class PredictionResult(Generic[TPrediction]):
+class PredictionResult:
    """The final result of a prediction."""
 
     # fmt: off
     content: str  # The text content of the prediction
-    parsed: TPrediction  # dict for structured predictions, str otherwise
+    parsed: AnyPrediction  # dict for structured predictions, str otherwise
     stats: LlmPredictionStats  # Statistics about the prediction process
     model_info: LlmInfo  # Information about the model used
     structured: bool = field(init=False)  # Whether the result is structured or not
@@ -475,13 +477,13 @@ def _to_history_content(self) -> str:
 
 
 @dataclass(kw_only=True, frozen=True, slots=True)
-class PredictionRoundResult(PredictionResult[str]):
+class PredictionRoundResult(PredictionResult):
     """The result of a prediction within a multi-round tool using action."""
 
     round_index: int  # The round within the action that produced this result
 
     @classmethod
-    def from_result(cls, result: PredictionResult[str], round_index: int) -> Self:
+    def from_result(cls, result: PredictionResult, round_index: int) -> Self:
         """Create a prediction round result from its underlying prediction result."""
         copied_keys = {
             k: getattr(result, k)
@@ -1110,10 +1112,7 @@ class PredictionToolCallAbortedEvent(ChannelRxEvent[None]):
 
 
 class PredictionEndpoint(
-    Generic[TPrediction],
-    ChannelEndpoint[
-        PredictionResult[TPrediction], PredictionRxEvent, PredictionChannelRequestDict
-    ],
+    ChannelEndpoint[PredictionResult, PredictionRxEvent, PredictionChannelRequestDict],
 ):
     """Helper class for prediction endpoint message handling."""
 
@@ -1264,14 +1263,20 @@ def iter_message_events(
                 # or has been successfully cancelled. Don't try
                 # to parse the received content in the latter case.
                 result_content = "".join(self._fragment_content)
+                parsed_content: AnyPrediction = result_content
                 if self._structured and not self._is_cancelled:
                     try:
+                        # Check if the content is valid JSON
                         parsed_content = json.loads(result_content)
                     except json.JSONDecodeError:
+                        # This likely indicates a non-JSON GBNF grammar
                         # Fall back to unstructured result reporting
-                        parsed_content = result_content
-                else:
-                    parsed_content = result_content
+                        pass
+                    else:
+                        if not isinstance(parsed_content, dict):
+                            # This likely indicates a non-JSON GBNF grammar
+                            # Fall back to unstructured result reporting
+                            parsed_content = result_content
                 yield self._set_result(
                     PredictionResult(
                         content=result_content,
@@ -1381,7 +1386,7 @@ def mark_cancelled(self) -> None:
         self._is_cancelled = True
 
 
-class CompletionEndpoint(PredictionEndpoint[TPrediction]):
+class CompletionEndpoint(PredictionEndpoint):
     """API channel endpoint for requesting text completion from a model."""
 
     _NOTICE_PREFIX = "Completion"
@@ -1421,7 +1426,7 @@ def _additional_config_options(cls) -> DictObject:
 ToolDefinition: TypeAlias = ToolFunctionDef | ToolFunctionDefDict | Callable[..., Any]
 
 
-class ChatResponseEndpoint(PredictionEndpoint[TPrediction]):
+class ChatResponseEndpoint(PredictionEndpoint):
     """API channel endpoint for requesting a chat response from a model."""
 
     _NOTICE_PREFIX = "Chat response"
@@ -1457,26 +1462,20 @@ def parse_tools(
         return LlmToolUseSettingToolArray(tools=llm_tool_defs), client_tool_map
 
 
-class PredictionStreamBase(Generic[TPrediction]):
+class PredictionStreamBase:
     """Common base class for sync and async prediction streams."""
 
     def __init__(
         self,
-        endpoint: PredictionEndpoint[TPrediction],
+        endpoint: PredictionEndpoint,
     ) -> None:
         """Initialize a prediction process representation."""
-        # Split initialisation out to a separate helper function that plays nice with bound generic type vars
-        # To avoid type errors in mypy, subclasses may call this directly (instead of `super().__init__(endpoint)`)
-        # https://discuss.python.org/t/how-to-share-type-variables-when-inheriting-from-generic-base-classes/78839
-        self._init_prediction(endpoint)
-
-    def _init_prediction(self, endpoint: PredictionEndpoint[TPrediction]) -> None:
-        self._endpoint: PredictionEndpoint[TPrediction] = endpoint
+        self._endpoint = endpoint
         # Final result reporting
         self._is_started = False
         self._is_finished = False
-        self._final_result: PredictionResult[TPrediction] | None = None
+        self._final_result: PredictionResult | None = None
         self._error: BaseException | None = None
 
     @property
@@ -1510,7 +1509,7 @@ def _prediction_config(self) -> LlmPredictionConfig | None:
         return self._final_result.prediction_config
 
     @sdk_public_api()
-    def result(self) -> PredictionResult[TPrediction]:
+    def result(self) -> PredictionResult:
         """Get the result of a completed prediction.
 
         This API raises an exception if the result is not available,
diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py
index b9fa27f..720664b 100644
--- a/src/lmstudio/sync_api.py
+++ b/src/lmstudio/sync_api.py
@@ -27,13 +27,11 @@
     Iterator,
     Callable,
     Generic,
-    Literal,
     NoReturn,
     Sequence,
     Type,
     TypeAlias,
     TypeVar,
-    overload,
 )
 from typing_extensions import (
     # Native in 3.11+
@@ -116,7 +114,6 @@
     RemoteCallHandler,
     ResponseSchema,
     TModelInfo,
-    TPrediction,
     ToolDefinition,
     check_model_namespace,
     load_struct,
@@ -1065,24 +1062,23 @@ def _fetch_file_handle(self, file_data: _LocalFileData) -> FileHandle:
         return self._files_session._fetch_file_handle(file_data)
 
 
-SyncPredictionChannel: TypeAlias = SyncChannel[PredictionResult[T]]
-SyncPredictionCM: TypeAlias = ContextManager[SyncPredictionChannel[T]]
+SyncPredictionChannel: TypeAlias = SyncChannel[PredictionResult]
+SyncPredictionCM: TypeAlias = ContextManager[SyncPredictionChannel]
 
 
-class PredictionStream(PredictionStreamBase[TPrediction]):
+class PredictionStream(PredictionStreamBase):
     """Sync context manager for an ongoing prediction process."""
 
     def __init__(
         self,
-        channel_cm: SyncPredictionCM[TPrediction],
-        endpoint: PredictionEndpoint[TPrediction],
+        channel_cm: SyncPredictionCM,
+        endpoint: PredictionEndpoint,
     ) -> None:
         """Initialize a prediction process representation."""
         self._resources = ExitStack()
-        self._channel_cm: SyncPredictionCM[TPrediction] = channel_cm
-        self._channel: SyncPredictionChannel[TPrediction] | None = None
-        # See comments in BasePrediction regarding not calling super().__init__() here
-        self._init_prediction(endpoint)
+        self._channel_cm: SyncPredictionCM = channel_cm
+        self._channel: SyncPredictionChannel | None = None
+        super().__init__(endpoint)
 
     @sdk_public_api()
     def start(self) -> None:
@@ -1141,7 +1137,7 @@ def _iter_events(self) -> Iterator[PredictionRxEvent]:
             self._mark_finished()
 
     @sdk_public_api()
-    def wait_for_result(self) -> PredictionResult[TPrediction]:
+    def wait_for_result(self) -> PredictionResult:
         """Wait for the result of the prediction."""
         for _ in self:
             pass
@@ -1176,34 +1172,6 @@ def _create_handle(self, model_identifier: str) -> "LLM":
         """Create a symbolic handle to the specified LLM model."""
         return LLM(model_identifier, self)
 
-    @overload
-    def _complete_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionStream[str]: ...
-    @overload
-    def _complete_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionStream[DictObject]: ...
     def _complete_stream(
         self,
         model_specifier: AnyModelSpecifier,
@@ -1216,7 +1184,7 @@ def _complete_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionStream[str] | PredictionStream[DictObject]:
+    ) -> PredictionStream:
         """Request a one-off prediction without any context and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1236,34 +1204,6 @@ def _complete_stream(
         prediction_stream = PredictionStream(channel_cm, endpoint)
         return prediction_stream
 
-    @overload
-    def _respond_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionStream[str]: ...
-    @overload
-    def _respond_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionStream[DictObject]: ...
     def _respond_stream(
         self,
         model_specifier: AnyModelSpecifier,
@@ -1276,7 +1216,7 @@ def _respond_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionStream[str] | PredictionStream[DictObject]:
+    ) -> PredictionStream:
         """Request a response in an ongoing assistant chat session and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1411,32 +1351,6 @@ def get_context_length(self) -> int:
 class LLM(SyncModelHandle[SyncSessionLlm]):
     """Reference to a loaded LLM model."""
 
-    @overload
-    def complete_stream(
-        self,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionStream[str]: ...
-    @overload
-    def complete_stream(
-        self,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionStream[DictObject]: ...
     @sdk_public_api()
     def complete_stream(
         self,
@@ -1449,7 +1363,7 @@ def complete_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionStream[str] | PredictionStream[DictObject]:
+    ) -> PredictionStream:
        """Request a one-off prediction without any context and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1466,32 +1380,6 @@ def complete_stream(
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
 
-    @overload
-    def complete(
-        self,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[str]: ...
-    @overload
-    def complete(
-        self,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[DictObject]: ...
     @sdk_public_api()
     def complete(
         self,
@@ -1504,7 +1392,7 @@ def complete(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionResult[str] | PredictionResult[DictObject]:
+    ) -> PredictionResult:
         """Request a one-off prediction without any context.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1526,32 +1414,6 @@ def complete(
             pass
         return prediction_stream.result()
 
-    @overload
-    def respond_stream(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionStream[str]: ...
-    @overload
-    def respond_stream(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionStream[DictObject]: ...
     @sdk_public_api()
     def respond_stream(
         self,
@@ -1564,7 +1426,7 @@ def respond_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionStream[str] | PredictionStream[DictObject]:
+    ) -> PredictionStream:
         """Request a response in an ongoing assistant chat session and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1581,32 +1443,6 @@ def respond_stream(
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
 
-    @overload
-    def respond(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[str]: ...
-    @overload
-    def respond(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[DictObject]: ...
     @sdk_public_api()
     def respond(
         self,
@@ -1619,7 +1455,7 @@ def respond(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionResult[str] | PredictionResult[DictObject]:
+    ) -> PredictionResult:
         """Request a response in an ongoing assistant chat session.
 
         Note: details of configuration fields may change in SDK feature releases.
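The sync streaming annotations simplify the same way. A rough usage sketch, assuming `LLM`, `PredictionStream`, and `PredictionResult` are exported at the package level; the prompt and helper name are illustrative:

```python
import lmstudio as lms


def stream_demo(llm: lms.LLM) -> lms.PredictionResult:
    # The stream is now annotated as plain PredictionStream (no [str] / [DictObject])
    stream: lms.PredictionStream = llm.respond_stream("Tell me a joke.")
    for fragment in stream:
        print(fragment.content, end="", flush=True)
    return stream.result()
```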
diff --git a/tests/test_history.py b/tests/test_history.py
index 233cdfd..902588b 100644
--- a/tests/test_history.py
+++ b/tests/test_history.py
@@ -28,7 +28,6 @@
     LlmPredictionConfig,
     LlmPredictionStats,
     PredictionResult,
-    TPrediction,
 )
 
 from .support import IMAGE_FILEPATH, check_sdk_error
@@ -334,7 +333,7 @@ def test_add_entries_class_content() -> None:
     assert chat._get_history_for_prediction() == EXPECTED_HISTORY
 
 
-def _make_prediction_result(data: TPrediction) -> PredictionResult[TPrediction]:
+def _make_prediction_result(data: str | DictObject) -> PredictionResult:
     return PredictionResult(
         content=(data if isinstance(data, str) else json.dumps(data)),
         parsed=data,
diff --git a/tests/test_inference.py b/tests/test_inference.py
index d7bb02a..a523b9b 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -67,7 +67,7 @@ async def test_concurrent_predictions(caplog: LogCap, subtests: SubTests) -> Non
     async with AsyncClient() as client:
         session = client.llm
 
-        async def _request_response() -> PredictionResult[str]:
+        async def _request_response() -> PredictionResult:
            llm = await session.model(model_id)
            return await llm.respond(
                history=history,
@@ -172,7 +172,7 @@ def test_tool_using_agent(caplog: LogCap) -> None:
         chat.add_user_message("What is the sum of 123 and 3210?")
         tools = [ADDITION_TOOL_SPEC]
         # Ensure ignoring the round index passes static type checks
-        predictions: list[PredictionResult[str]] = []
+        predictions: list[PredictionResult] = []
         act_result = llm.act(chat, tools, on_prediction_completed=predictions.append)
         assert len(predictions) > 1
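With the overloads gone in both APIs, the str/DictObject split is a runtime property of the result rather than a static one. A sketch of the narrowing pattern callers can use, assuming `PredictionResult` is exported at the package level; the helper name is illustrative:

```python
import lmstudio as lms


def summarize(result: lms.PredictionResult) -> str:
    # .parsed is a dict for structured predictions and the raw text otherwise
    if isinstance(result.parsed, dict):
        return f"structured keys: {sorted(result.parsed)}"
    return result.parsed[:80]
```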