
Commit 4cfbd6c

Add GBNF grammar support (#60)

1 parent c08ab50 · commit 4cfbd6c

8 files changed: +198 −70 lines

src/lmstudio/_kv_config.py

Lines changed: 46 additions & 20 deletions
@@ -3,10 +3,10 @@
 # Known KV config settings are defined in
 # https://github.com/lmstudio-ai/lmstudio-js/blob/main/packages/lms-kv-config/src/schema.ts
 from dataclasses import dataclass
-from typing import Any, Container, Iterable, Sequence, Type, TypeVar
+from typing import Any, Container, Iterable, Sequence, Type, TypeAlias, TypeVar, cast
 
 from .sdk_api import LMStudioValueError
-from .schemas import DictSchema, DictObject, ModelSchema, MutableDictObject
+from .schemas import DictObject, DictSchema, ModelSchema, MutableDictObject
 from ._sdk_models import (
     EmbeddingLoadModelConfig,
     EmbeddingLoadModelConfigDict,
@@ -18,6 +18,8 @@
     LlmLoadModelConfigDict,
     LlmPredictionConfig,
     LlmPredictionConfigDict,
+    LlmStructuredPredictionSetting,
+    LlmStructuredPredictionSettingDict,
 )
 
 
@@ -330,52 +332,76 @@ def load_config_to_kv_config_stack(
     return _client_config_to_kv_config_stack(dict_config, TO_SERVER_LOAD_EMBEDDING)
 
 
+ResponseSchema: TypeAlias = (
+    DictSchema
+    | LlmStructuredPredictionSetting
+    | LlmStructuredPredictionSettingDict
+    | type[ModelSchema]
+)
+
+
 def prediction_config_to_kv_config_stack(
-    response_format: Type[ModelSchema] | DictSchema | None,
+    response_format: Type[ModelSchema] | ResponseSchema | None,
     config: LlmPredictionConfig | LlmPredictionConfigDict | None,
     for_text_completion: bool = False,
 ) -> tuple[bool, KvConfigStack]:
-    dict_config: DictObject
+    dict_config: LlmPredictionConfigDict
     if config is None:
         dict_config = {}
     elif isinstance(config, LlmPredictionConfig):
         dict_config = config.to_dict()
     else:
         assert isinstance(config, dict)
         dict_config = LlmPredictionConfig._from_any_dict(config).to_dict()
-    response_schema: DictSchema | None = None
     if response_format is not None:
-        structured = True
         if "structured" in dict_config:
             raise LMStudioValueError(
                 "Cannot specify both 'response_format' in API call and 'structured' in config"
             )
-        if isinstance(response_format, type) and issubclass(
+        response_schema: LlmStructuredPredictionSettingDict
+        structured = True
+        if isinstance(response_format, LlmStructuredPredictionSetting):
+            response_schema = response_format.to_dict()
+        elif isinstance(response_format, type) and issubclass(
             response_format, ModelSchema
         ):
-            response_schema = response_format.model_json_schema()
+            response_schema = {
+                "type": "json",
+                "jsonSchema": response_format.model_json_schema(),
+            }
         else:
-            response_schema = response_format
+            # Casts are needed as mypy doesn't detect that the given case patterns
+            # conform to the definition of LlmStructuredPredictionSettingDict
+            match response_format:
+                case {"type": "json", "jsonSchema": _} as json_schema:
+                    response_schema = cast(
+                        LlmStructuredPredictionSettingDict, json_schema
+                    )
+                case {"type": "gbnf", "gbnfGrammar": _} as gbnf_schema:
+                    response_schema = cast(
+                        LlmStructuredPredictionSettingDict, gbnf_schema
+                    )
+                case {"type": _}:
+                    # Assume any other input with a type key is a JSON schema definition
+                    response_schema = {
+                        "type": "json",
+                        "jsonSchema": response_format,
+                    }
+                case _:
+                    raise LMStudioValueError(
+                        f"Failed to parse response format: {response_format!r}"
+                    )
+        dict_config["structured"] = response_schema
     else:
         # The response schema may also be passed in via the config
        # (doing it this way type hints as an unstructured result,
         # but we still allow it at runtime for consistency with JS)
         match dict_config:
-            case {"structured": {"type": "json"}}:
+            case {"structured": {"type": "json" | "gbnf"}}:
                 structured = True
             case _:
                 structured = False
     fields = _to_kv_config_stack_base(dict_config, TO_SERVER_PREDICTION)
-    if response_schema is not None:
-        fields.append(
-            {
-                "key": "llm.prediction.structured",
-                "value": {
-                    "type": "json",
-                    "jsonSchema": response_schema,
-                },
-            }
-        )
     additional_layers: list[KvConfigStackLayerDict] = []
     if for_text_completion:
         additional_layers.append(_get_completion_config_layer())
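
In practical terms, response_format now accepts four shapes: a ModelSchema subclass, an LlmStructuredPredictionSetting instance, its dict form (including the new GBNF variant), or a bare JSON schema dict. A minimal sketch of the two dict forms; the grammar and schema contents below are invented examples, only the function under change is real:

# Sketch: response_format values accepted after this change
from lmstudio._kv_config import prediction_config_to_kv_config_stack

# GBNF grammar form (toy grammar, not from the commit)
gbnf_format = {"type": "gbnf", "gbnfGrammar": 'root ::= "yes" | "no"'}
# Explicit JSON schema form
json_format = {
    "type": "json",
    "jsonSchema": {"type": "object", "properties": {"answer": {"type": "string"}}},
}

structured, kv_stack = prediction_config_to_kv_config_stack(gbnf_format, None)
assert structured  # GBNF formats are reported as structured predictions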

src/lmstudio/async_api.py

Lines changed: 14 additions & 13 deletions
@@ -34,7 +34,7 @@
 from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException
 
 from .sdk_api import LMStudioRuntimeError, sdk_public_api, sdk_public_api_async
-from .schemas import AnyLMStudioStruct, DictObject, DictSchema, ModelSchema
+from .schemas import AnyLMStudioStruct, DictObject
 from .history import (
     Chat,
     ChatHistoryDataDict,
@@ -87,6 +87,7 @@
     PredictionResult,
     PromptProcessingCallback,
     RemoteCallHandler,
+    ResponseSchema,
     TModelInfo,
     TPrediction,
     check_model_namespace,
@@ -1030,7 +1031,7 @@ async def _complete_stream(
         model_specifier: AnyModelSpecifier,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1043,7 +1044,7 @@ async def _complete_stream(
         model_specifier: AnyModelSpecifier,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
         on_message: PredictionMessageCallback | None = None,
@@ -1090,7 +1091,7 @@ async def _respond_stream(
         model_specifier: AnyModelSpecifier,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1103,7 +1104,7 @@ async def _respond_stream(
         model_specifier: AnyModelSpecifier,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         on_message: PredictionMessageCallback | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
@@ -1267,7 +1268,7 @@ async def complete_stream(
         self,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1280,7 +1281,7 @@ async def complete_stream(
         self,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
         on_message: PredictionMessageCallback | None = None,
@@ -1322,7 +1323,7 @@ async def complete(
         self,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1335,7 +1336,7 @@ async def complete(
         self,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
         on_message: PredictionMessageCallback | None = None,
@@ -1382,7 +1383,7 @@ async def respond_stream(
         self,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1395,7 +1396,7 @@ async def respond_stream(
         self,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
         on_message: PredictionMessageCallback | None = None,
@@ -1437,7 +1438,7 @@ async def respond(
         self,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1450,7 +1451,7 @@ async def respond(
         self,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
         on_message: PredictionMessageCallback | None = None,
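
Downstream, any ResponseSchema value, including a GBNF grammar dict, can be passed straight to the async prediction methods. A hypothetical usage sketch; the AsyncClient and client.llm.model() entry points and the printed attribute are assumptions, since this diff only shows the signature changes:

import asyncio
import lmstudio

async def main() -> None:
    # AsyncClient and client.llm.model() are assumed accessors here
    async with lmstudio.AsyncClient() as client:
        model = await client.llm.model()
        result = await model.respond(
            "Is water wet? Answer yes or no.",
            response_format={"type": "gbnf", "gbnfGrammar": 'root ::= "yes" | "no"'},
        )
        print(result.content)

asyncio.run(main())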

src/lmstudio/json_api.py

Lines changed: 11 additions & 6 deletions
@@ -48,15 +48,14 @@
 from .schemas import (
     AnyLMStudioStruct,
     DictObject,
-    DictSchema,
     LMStudioStruct,
-    ModelSchema,
     TWireFormat,
     _format_json,
     _snake_case_keys_to_camelCase,
     _to_json_schema,
 )
 from ._kv_config import (
+    ResponseSchema,
     TLoadConfig,
     TLoadConfigDict,
     load_config_to_kv_config_stack,
@@ -168,6 +167,7 @@
     "PredictionResult",
     "PredictionRoundResult",
     "PromptProcessingCallback",
+    "ResponseSchema",
     "SerializedLMSExtendedError",
     "ToolDefinition",
     "ToolFunctionDef",
@@ -185,6 +185,7 @@
 AnyModelSpecifier: TypeAlias = str | ModelSpecifier | ModelQuery | DictObject
 AnyLoadConfig: TypeAlias = EmbeddingLoadModelConfig | LlmLoadModelConfig
 
+
 GetOrLoadChannelRequest: TypeAlias = (
     EmbeddingChannelGetOrLoadCreationParameter | LlmChannelGetOrLoadCreationParameter
 )
@@ -1122,7 +1123,7 @@ def __init__(
         self,
         model_specifier: AnyModelSpecifier,
         history: Chat,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset_config: str | None = None,
         on_message: PredictionMessageCallback | None = None,
@@ -1181,7 +1182,7 @@ def __init__(
     @classmethod
     def _make_config_override(
         cls,
-        response_format: Type[ModelSchema] | DictSchema | None,
+        response_format: ResponseSchema | None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None,
     ) -> tuple[bool, KvConfigStack]:
         return prediction_config_to_kv_config_stack(
@@ -1264,7 +1265,11 @@ def iter_message_events(
         # to parse the received content in the latter case.
         result_content = "".join(self._fragment_content)
         if self._structured and not self._is_cancelled:
-            parsed_content = json.loads(result_content)
+            try:
+                parsed_content = json.loads(result_content)
+            except json.JSONDecodeError:
+                # Fall back to unstructured result reporting
+                parsed_content = result_content
         else:
             parsed_content = result_content
         yield self._set_result(
@@ -1385,7 +1390,7 @@ def __init__(
         self,
         model_specifier: AnyModelSpecifier,
         prompt: str,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset_config: str | None = None,
         on_message: PredictionMessageCallback | None = None,
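
The iter_message_events change matters for cancelled or truncated structured predictions, whose accumulated text may not be valid JSON. The new fallback, shown in isolation (helper name invented for the example):

import json

def parse_structured(result_content: str) -> object:
    # Mirrors the new behaviour: malformed JSON (e.g. from a truncated
    # prediction) degrades to the raw text instead of raising
    try:
        return json.loads(result_content)
    except json.JSONDecodeError:
        return result_content

print(parse_structured('{"answer": "yes"}'))  # -> {'answer': 'yes'}
print(parse_structured('{"answer": "ye'))     # -> the raw string, unparsed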

src/lmstudio/schemas.py

Lines changed: 7 additions & 4 deletions
@@ -10,6 +10,7 @@
     MutableMapping,
     Protocol,
     Sequence,
+    TypeAlias,
     TypeVar,
     cast,
     runtime_checkable,
@@ -26,14 +27,16 @@
 
 
 __all__ = [
     "BaseModel",
-    "ModelSchema",
     "DictObject",
     "DictSchema",
+    "ModelSchema",
 ]
 
-DictObject = Mapping[str, Any]  # Any JSON-compatible string-keyed dict
-MutableDictObject = MutableMapping[str, Any]
-DictSchema = Mapping[str, Any]  # JSON schema as a string-keyed dict
+DictObject: TypeAlias = Mapping[str, Any]  # Any JSON-compatible string-keyed dict
+MutableDictObject: TypeAlias = MutableMapping[str, Any]
+DictSchema: TypeAlias = Mapping[str, Any]  # JSON schema as a string-keyed dict
+# It would be nice to require a "type" key in DictSchema, but that's currently tricky
+# without "extra_items" support in TypedDict: https://peps.python.org/pep-0728/
 
 
 def _format_json(data: Any, *, sort_keys: bool = True) -> str:
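
The explicit TypeAlias annotations are a typing hygiene fix rather than a behaviour change: they tell type checkers that these module-level assignments are aliases, not ordinary variables. A standalone illustration (the function name is invented):

from typing import Any, Mapping, TypeAlias

# Without the annotation, some type checkers treat the assignment as a
# plain variable; with it, the alias is reliably usable in annotations
DictSchema: TypeAlias = Mapping[str, Any]

def takes_schema(schema: DictSchema) -> None:
    print(sorted(schema))

takes_schema({"type": "object"})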
