Commit bd762f9

Simplify prediction API type hinting
With the addition of GBNF grammar structured response format support, structured responses are no longer required to contain valid JSON. Accordingly, type hinting for the prediction APIs has been simplified:

* `PredictionResult` is no longer a generic type
* instead, `parsed` is defined as a `str`/`dict` union field
* whether it is a dict or not is now determined solely at runtime (based on whether a schema was sent to the server, and whether the content field was successfully parsed as a JSON response)
* other types that were generic only in the kind of prediction they returned are also no longer generic
* prediction APIs no longer define overloads that attempt to infer whether the result will be structured or not (these were already inaccurate, as they only considered the `response_format` parameter, ignoring the `structured` field in the prediction config)
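The caller-side effect of this change is shown in the minimal sketch below. It is not part of the commit itself: the `summarise` helper and the `llm` and `book_schema` names are illustrative only, and an `AsyncLLM` handle is assumed to already be available. With the simplified hints, `complete()` returns a plain `PredictionResult`, and whether `parsed` holds a dict is checked at runtime.

```python
# Minimal sketch of the runtime check described above (not part of this commit).
# Assumes an AsyncLLM handle is already available; `llm`, `book_schema`, and
# `summarise` are illustrative names only.
from typing import Any


async def summarise(llm: Any, book_schema: dict[str, Any]) -> None:
    # `complete()` now returns a plain PredictionResult whether or not a
    # response schema was supplied.
    result = await llm.complete(
        "Name a famous science fiction novel.",
        response_format=book_schema,
    )
    if isinstance(result.parsed, dict):
        # A schema was sent and the response content parsed as JSON
        print("Structured response:", result.parsed)
    else:
        # `parsed` is the raw text (no schema, or JSON parsing failed)
        print("Unstructured response:", result.parsed)
```

Static checkers now see `parsed` as a `str`/`dict` union, so the `isinstance` check narrows the type instead of relying on the removed overloads.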
1 parent 4cfbd6c commit bd762f9

5 files changed (+57, -387 lines)

src/lmstudio/async_api.py (15 additions, 179 deletions)

@@ -16,12 +16,10 @@
     Callable,
     Generic,
     Iterable,
-    Literal,
     Sequence,
     Type,
     TypeAlias,
     TypeVar,
-    overload,
 )
 from typing_extensions import (
     # Native in 3.11+
@@ -89,7 +87,6 @@
     RemoteCallHandler,
     ResponseSchema,
     TModelInfo,
-    TPrediction,
     check_model_namespace,
     load_struct,
     _model_spec_to_api_dict,
@@ -902,24 +899,23 @@ async def _fetch_file_handle(self, file_data: _LocalFileData) -> FileHandle:
         return await self._files_session._fetch_file_handle(file_data)
 
 
-AsyncPredictionChannel: TypeAlias = AsyncChannel[PredictionResult[T]]
-AsyncPredictionCM: TypeAlias = AsyncContextManager[AsyncPredictionChannel[T]]
+AsyncPredictionChannel: TypeAlias = AsyncChannel[PredictionResult]
+AsyncPredictionCM: TypeAlias = AsyncContextManager[AsyncPredictionChannel]
 
 
-class AsyncPredictionStream(PredictionStreamBase[TPrediction]):
+class AsyncPredictionStream(PredictionStreamBase):
     """Async context manager for an ongoing prediction process."""
 
     def __init__(
         self,
-        channel_cm: AsyncPredictionCM[TPrediction],
-        endpoint: PredictionEndpoint[TPrediction],
+        channel_cm: AsyncPredictionCM,
+        endpoint: PredictionEndpoint,
     ) -> None:
         """Initialize a prediction process representation."""
         self._resource_manager = AsyncExitStack()
-        self._channel_cm: AsyncPredictionCM[TPrediction] = channel_cm
-        self._channel: AsyncPredictionChannel[TPrediction] | None = None
-        # See comments in BasePrediction regarding not calling super().__init__() here
-        self._init_prediction(endpoint)
+        self._channel_cm: AsyncPredictionCM = channel_cm
+        self._channel: AsyncPredictionChannel | None = None
+        super().__init__(endpoint)
 
     @sdk_public_api_async()
     async def start(self) -> None:
@@ -976,7 +972,7 @@ async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]:
             self._mark_finished()
 
     @sdk_public_api_async()
-    async def wait_for_result(self) -> PredictionResult[TPrediction]:
+    async def wait_for_result(self) -> PredictionResult:
         """Wait for the result of the prediction."""
         async for _ in self:
             pass
@@ -1011,34 +1007,6 @@ def _create_handle(self, model_identifier: str) -> "AsyncLLM":
         """Create a symbolic handle to the specified LLM model."""
         return AsyncLLM(model_identifier, self)
 
-    @overload
-    async def _complete_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def _complete_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     async def _complete_stream(
         self,
         model_specifier: AnyModelSpecifier,
@@ -1051,7 +1019,7 @@ async def _complete_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a one-off prediction without any context and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1071,34 +1039,6 @@ async def _complete_stream(
         prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
         return prediction_stream
 
-    @overload
-    async def _respond_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def _respond_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     async def _respond_stream(
         self,
         model_specifier: AnyModelSpecifier,
@@ -1111,7 +1051,7 @@ async def _respond_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a response in an ongoing assistant chat session and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1250,32 +1190,6 @@ async def get_context_length(self) -> int:
 class AsyncLLM(AsyncModelHandle[AsyncSessionLlm]):
     """Reference to a loaded LLM model."""
 
-    @overload
-    async def complete_stream(
-        self,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def complete_stream(
-        self,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     @sdk_public_api_async()
     async def complete_stream(
         self,
@@ -1288,7 +1202,7 @@ async def complete_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a one-off prediction without any context and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1305,32 +1219,6 @@ async def complete_stream(
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
 
-    @overload
-    async def complete(
-        self,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[str]: ...
-    @overload
-    async def complete(
-        self,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[DictObject]: ...
     @sdk_public_api_async()
     async def complete(
         self,
@@ -1343,7 +1231,7 @@ async def complete(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionResult[str] | PredictionResult[DictObject]:
+    ) -> PredictionResult:
         """Request a one-off prediction without any context.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1365,32 +1253,6 @@ async def complete(
             pass
         return prediction_stream.result()
 
-    @overload
-    async def respond_stream(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def respond_stream(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     @sdk_public_api_async()
     async def respond_stream(
         self,
@@ -1403,7 +1265,7 @@ async def respond_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a response in an ongoing assistant chat session and stream the generated tokens.
 
         Note: details of configuration fields may change in SDK feature releases.
@@ -1420,32 +1282,6 @@ async def respond_stream(
            on_prompt_processing_progress=on_prompt_processing_progress,
        )
 
-    @overload
-    async def respond(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[str]: ...
-    @overload
-    async def respond(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[DictObject]: ...
     @sdk_public_api_async()
     async def respond(
         self,
@@ -1458,7 +1294,7 @@ async def respond(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionResult[str] | PredictionResult[DictObject]:
+    ) -> PredictionResult:
         """Request a response in an ongoing assistant chat session.
 
         Note: details of configuration fields may change in SDK feature releases.
