Commit c81eea0
Simplify prediction API type hinting (#61)
With the addition of GBNF grammar structured response format support, structured responses are no longer required to contain valid JSON. Accordingly, type hinting for the prediction APIs has been simplified:

* `PredictionResult` is no longer a generic type.
* Instead, `parsed` is defined as a `str`/`dict` union field.
* Whether it is a dict is now determined solely at runtime (based on whether a schema was sent to the server and the content field was successfully parsed as a JSON response).
* Other types that were generic only in the kind of prediction they returned are likewise no longer generic.
* Prediction APIs no longer define overloads that attempt to infer whether the result will be structured (these were already inaccurate, as they only considered the `response_format` parameter and ignored the `structured` field in the prediction config).
1 parent 4cfbd6c commit c81eea0
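
As a rough illustration of what the simplified typing means for calling code (not part of this commit): `result.parsed` is now typed as a `str`/`dict` union, so callers narrow it at runtime instead of relying on overload-based inference. The `llm` handle, prompt, and schema below are placeholders, and `response_format` is assumed here to accept a JSON-schema-style dict.

```python
from typing import Any

async def extract_structured(llm: "AsyncLLM", prompt: str, schema: dict[str, Any]) -> dict[str, Any]:
    # Hypothetical call; `llm` is an AsyncLLM handle obtained elsewhere.
    result = await llm.complete(prompt, response_format=schema)
    # `parsed` is a str/dict union; it is only a dict if a schema was sent
    # and the response content was successfully parsed as JSON.
    if isinstance(result.parsed, dict):
        return result.parsed
    # Otherwise fall back to the raw text response.
    return {"text": result.parsed}
```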

5 files changed: +57 additions, -387 deletions

src/lmstudio/async_api.py

Lines changed: 15 additions & 179 deletions
@@ -16,12 +16,10 @@
     Callable,
     Generic,
     Iterable,
-    Literal,
     Sequence,
     Type,
     TypeAlias,
     TypeVar,
-    overload,
 )
 from typing_extensions import (
     # Native in 3.11+
@@ -89,7 +87,6 @@
     RemoteCallHandler,
     ResponseSchema,
     TModelInfo,
-    TPrediction,
     check_model_namespace,
     load_struct,
     _model_spec_to_api_dict,
@@ -902,24 +899,23 @@ async def _fetch_file_handle(self, file_data: _LocalFileData) -> FileHandle:
         return await self._files_session._fetch_file_handle(file_data)


-AsyncPredictionChannel: TypeAlias = AsyncChannel[PredictionResult[T]]
-AsyncPredictionCM: TypeAlias = AsyncContextManager[AsyncPredictionChannel[T]]
+AsyncPredictionChannel: TypeAlias = AsyncChannel[PredictionResult]
+AsyncPredictionCM: TypeAlias = AsyncContextManager[AsyncPredictionChannel]


-class AsyncPredictionStream(PredictionStreamBase[TPrediction]):
+class AsyncPredictionStream(PredictionStreamBase):
     """Async context manager for an ongoing prediction process."""

     def __init__(
         self,
-        channel_cm: AsyncPredictionCM[TPrediction],
-        endpoint: PredictionEndpoint[TPrediction],
+        channel_cm: AsyncPredictionCM,
+        endpoint: PredictionEndpoint,
     ) -> None:
         """Initialize a prediction process representation."""
         self._resource_manager = AsyncExitStack()
-        self._channel_cm: AsyncPredictionCM[TPrediction] = channel_cm
-        self._channel: AsyncPredictionChannel[TPrediction] | None = None
-        # See comments in BasePrediction regarding not calling super().__init__() here
-        self._init_prediction(endpoint)
+        self._channel_cm: AsyncPredictionCM = channel_cm
+        self._channel: AsyncPredictionChannel | None = None
+        super().__init__(endpoint)

     @sdk_public_api_async()
     async def start(self) -> None:
@@ -976,7 +972,7 @@ async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]:
         self._mark_finished()

     @sdk_public_api_async()
-    async def wait_for_result(self) -> PredictionResult[TPrediction]:
+    async def wait_for_result(self) -> PredictionResult:
         """Wait for the result of the prediction."""
         async for _ in self:
             pass
@@ -1011,34 +1007,6 @@ def _create_handle(self, model_identifier: str) -> "AsyncLLM":
         """Create a symbolic handle to the specified LLM model."""
         return AsyncLLM(model_identifier, self)

-    @overload
-    async def _complete_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def _complete_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     async def _complete_stream(
         self,
         model_specifier: AnyModelSpecifier,
@@ -1051,7 +1019,7 @@ async def _complete_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a one-off prediction without any context and stream the generated tokens.

         Note: details of configuration fields may change in SDK feature releases.
@@ -1071,34 +1039,6 @@ async def _complete_stream(
         prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
         return prediction_stream

-    @overload
-    async def _respond_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def _respond_stream(
-        self,
-        model_specifier: AnyModelSpecifier,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     async def _respond_stream(
         self,
         model_specifier: AnyModelSpecifier,
@@ -1111,7 +1051,7 @@ async def _respond_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a response in an ongoing assistant chat session and stream the generated tokens.

         Note: details of configuration fields may change in SDK feature releases.
@@ -1250,32 +1190,6 @@ async def get_context_length(self) -> int:
 class AsyncLLM(AsyncModelHandle[AsyncSessionLlm]):
     """Reference to a loaded LLM model."""

-    @overload
-    async def complete_stream(
-        self,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def complete_stream(
-        self,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     @sdk_public_api_async()
     async def complete_stream(
         self,
@@ -1288,7 +1202,7 @@ async def complete_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a one-off prediction without any context and stream the generated tokens.

         Note: details of configuration fields may change in SDK feature releases.
@@ -1305,32 +1219,6 @@ async def complete_stream(
             on_prompt_processing_progress=on_prompt_processing_progress,
         )

-    @overload
-    async def complete(
-        self,
-        prompt: str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[str]: ...
-    @overload
-    async def complete(
-        self,
-        prompt: str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[DictObject]: ...
     @sdk_public_api_async()
     async def complete(
         self,
@@ -1343,7 +1231,7 @@ async def complete(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionResult[str] | PredictionResult[DictObject]:
+    ) -> PredictionResult:
         """Request a one-off prediction without any context.

         Note: details of configuration fields may change in SDK feature releases.
@@ -1365,32 +1253,6 @@ async def complete(
             pass
         return prediction_stream.result()

-    @overload
-    async def respond_stream(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[str]: ...
-    @overload
-    async def respond_stream(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> AsyncPredictionStream[DictObject]: ...
     @sdk_public_api_async()
     async def respond_stream(
         self,
@@ -1403,7 +1265,7 @@ async def respond_stream(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]:
+    ) -> AsyncPredictionStream:
         """Request a response in an ongoing assistant chat session and stream the generated tokens.

         Note: details of configuration fields may change in SDK feature releases.
@@ -1420,32 +1282,6 @@ async def respond_stream(
             on_prompt_processing_progress=on_prompt_processing_progress,
         )

-    @overload
-    async def respond(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: Literal[None] = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[str]: ...
-    @overload
-    async def respond(
-        self,
-        history: Chat | ChatHistoryDataDict | str,
-        *,
-        response_format: ResponseSchema = ...,
-        config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
-        preset: str | None = ...,
-        on_message: PredictionMessageCallback | None = ...,
-        on_first_token: PredictionFirstTokenCallback | None = ...,
-        on_prediction_fragment: PredictionFragmentCallback | None = ...,
-        on_prompt_processing_progress: PromptProcessingCallback | None = ...,
-    ) -> PredictionResult[DictObject]: ...
     @sdk_public_api_async()
     async def respond(
         self,
@@ -1458,7 +1294,7 @@ async def respond(
         on_first_token: PredictionFirstTokenCallback | None = None,
         on_prediction_fragment: PredictionFragmentCallback | None = None,
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
-    ) -> PredictionResult[str] | PredictionResult[DictObject]:
+    ) -> PredictionResult:
         """Request a response in an ongoing assistant chat session.

         Note: details of configuration fields may change in SDK feature releases.
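
For completeness, a minimal streaming sketch against the now non-generic stream type (illustrative only, not from this commit): `llm` is assumed to be an existing `AsyncLLM` handle, fragments are assumed to expose their text via `.content`, and the drain-then-`result()` pattern mirrors the SDK's own `complete()` implementation shown in the diff above.

```python
async def stream_completion(llm: "AsyncLLM", prompt: str) -> "PredictionResult":
    prediction_stream = await llm.complete_stream(prompt)
    async for fragment in prediction_stream:
        # Each item is an LlmPredictionFragment; `.content` is assumed here.
        print(fragment.content, end="", flush=True)
    # Drain the stream, then read the plain (non-generic) PredictionResult;
    # inspect result.parsed at runtime if a response schema was supplied.
    return prediction_stream.result()
```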
