Add GBNF grammar support #60

Merged
merged 1 commit into from Mar 21, 2025

src/lmstudio/_kv_config.py: 66 changes (46 additions, 20 deletions)

@@ -3,10 +3,10 @@
 # Known KV config settings are defined in
 # https://github.com/lmstudio-ai/lmstudio-js/blob/main/packages/lms-kv-config/src/schema.ts
 from dataclasses import dataclass
-from typing import Any, Container, Iterable, Sequence, Type, TypeVar
+from typing import Any, Container, Iterable, Sequence, Type, TypeAlias, TypeVar, cast

 from .sdk_api import LMStudioValueError
-from .schemas import DictSchema, DictObject, ModelSchema, MutableDictObject
+from .schemas import DictObject, DictSchema, ModelSchema, MutableDictObject
 from ._sdk_models import (
     EmbeddingLoadModelConfig,
     EmbeddingLoadModelConfigDict,
@@ -18,6 +18,8 @@
     LlmLoadModelConfigDict,
     LlmPredictionConfig,
     LlmPredictionConfigDict,
+    LlmStructuredPredictionSetting,
+    LlmStructuredPredictionSettingDict,
 )
@@ -330,52 +332,76 @@ def load_config_to_kv_config_stack(
     return _client_config_to_kv_config_stack(dict_config, TO_SERVER_LOAD_EMBEDDING)


+ResponseSchema: TypeAlias = (
+    DictSchema
+    | LlmStructuredPredictionSetting
+    | LlmStructuredPredictionSettingDict
+    | type[ModelSchema]
+)
+
+
 def prediction_config_to_kv_config_stack(
-    response_format: Type[ModelSchema] | DictSchema | None,
+    response_format: Type[ModelSchema] | ResponseSchema | None,
     config: LlmPredictionConfig | LlmPredictionConfigDict | None,
     for_text_completion: bool = False,
 ) -> tuple[bool, KvConfigStack]:
-    dict_config: DictObject
+    dict_config: LlmPredictionConfigDict
     if config is None:
         dict_config = {}
     elif isinstance(config, LlmPredictionConfig):
         dict_config = config.to_dict()
     else:
         assert isinstance(config, dict)
         dict_config = LlmPredictionConfig._from_any_dict(config).to_dict()
-    response_schema: DictSchema | None = None
     if response_format is not None:
-        structured = True
         if "structured" in dict_config:
             raise LMStudioValueError(
                 "Cannot specify both 'response_format' in API call and 'structured' in config"
             )
-        if isinstance(response_format, type) and issubclass(
+        response_schema: LlmStructuredPredictionSettingDict
+        structured = True
+        if isinstance(response_format, LlmStructuredPredictionSetting):
+            response_schema = response_format.to_dict()
+        elif isinstance(response_format, type) and issubclass(
             response_format, ModelSchema
         ):
-            response_schema = response_format.model_json_schema()
+            response_schema = {
+                "type": "json",
+                "jsonSchema": response_format.model_json_schema(),
+            }
         else:
-            response_schema = response_format
+            # Casts are needed as mypy doesn't detect that the given case patterns
+            # conform to the definition of LlmStructuredPredictionSettingDict
+            match response_format:
+                case {"type": "json", "jsonSchema": _} as json_schema:
+                    response_schema = cast(
+                        LlmStructuredPredictionSettingDict, json_schema
+                    )
+                case {"type": "gbnf", "gbnfGrammar": _} as gbnf_schema:
+                    response_schema = cast(
+                        LlmStructuredPredictionSettingDict, gbnf_schema
+                    )
+                case {"type": _}:
+                    # Assume any other input with a type key is a JSON schema definition
+                    response_schema = {
+                        "type": "json",
+                        "jsonSchema": response_format,
+                    }
+                case _:
+                    raise LMStudioValueError(
+                        f"Failed to parse response format: {response_format!r}"
+                    )
+        dict_config["structured"] = response_schema
     else:
         # The response schema may also be passed in via the config
         # (doing it this way type hints as an unstructured result,
         # but we still allow it at runtime for consistency with JS)
        match dict_config:
-            case {"structured": {"type": "json"}}:
+            case {"structured": {"type": "json" | "gbnf"}}:
                 structured = True
             case _:
                 structured = False
     fields = _to_kv_config_stack_base(dict_config, TO_SERVER_PREDICTION)
-    if response_schema is not None:
-        fields.append(
-            {
-                "key": "llm.prediction.structured",
-                "value": {
-                    "type": "json",
-                    "jsonSchema": response_schema,
-                },
-            }
-        )
     additional_layers: list[KvConfigStackLayerDict] = []
     if for_text_completion:
         additional_layers.append(_get_completion_config_layer())
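
For reference, the match statement above normalizes every accepted response_format shape into the structured prediction wire format. A standalone sketch of those rules (the function name and grammar below are illustrative, not part of this PR):

# Standalone sketch of the normalization rules above (illustrative only).
def normalize_response_format(response_format: dict) -> dict:
    match response_format:
        case {"type": "json", "jsonSchema": _} | {"type": "gbnf", "gbnfGrammar": _}:
            # Already in the structured prediction wire format
            return response_format
        case {"type": _}:
            # Any other dict with a "type" key is assumed to be a JSON schema
            return {"type": "json", "jsonSchema": response_format}
        case _:
            raise ValueError(f"Failed to parse response format: {response_format!r}")

# GBNF grammars pass through unchanged:
print(normalize_response_format(
    {"type": "gbnf", "gbnfGrammar": 'root ::= "yes" | "no"'}
))
# Bare JSON schemas are wrapped automatically:
print(normalize_response_format({"type": "object"}))
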
src/lmstudio/async_api.py: 27 changes (14 additions, 13 deletions)

@@ -34,7 +34,7 @@
 from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException

 from .sdk_api import LMStudioRuntimeError, sdk_public_api, sdk_public_api_async
-from .schemas import AnyLMStudioStruct, DictObject, DictSchema, ModelSchema
+from .schemas import AnyLMStudioStruct, DictObject
 from .history import (
     Chat,
     ChatHistoryDataDict,
@@ -87,6 +87,7 @@
     PredictionResult,
     PromptProcessingCallback,
     RemoteCallHandler,
+    ResponseSchema,
     TModelInfo,
     TPrediction,
     check_model_namespace,
@@ -1030,7 +1031,7 @@ async def _complete_stream(
         model_specifier: AnyModelSpecifier,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1043,7 +1044,7 @@ async def _complete_stream(
         model_specifier: AnyModelSpecifier,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
         on_message: PredictionMessageCallback | None = None,
@@ -1090,7 +1091,7 @@ async def _respond_stream(
         model_specifier: AnyModelSpecifier,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1103,7 +1104,7 @@ async def _respond_stream(
         model_specifier: AnyModelSpecifier,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         on_message: PredictionMessageCallback | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
@@ -1267,7 +1268,7 @@ async def complete_stream(
         self,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1280,7 +1281,7 @@
         self,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
         on_message: PredictionMessageCallback | None = None,
@@ -1322,7 +1323,7 @@ async def complete(
         self,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1335,7 +1336,7 @@
         self,
         prompt: str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
         on_message: PredictionMessageCallback | None = None,
@@ -1382,7 +1383,7 @@ async def respond_stream(
         self,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1395,7 +1396,7 @@
         self,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
         on_message: PredictionMessageCallback | None = None,
@@ -1437,7 +1438,7 @@ async def respond(
         self,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema = ...,
+        response_format: ResponseSchema = ...,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = ...,
         preset: str | None = ...,
         on_message: PredictionMessageCallback | None = ...,
@@ -1450,7 +1451,7 @@
         self,
         history: Chat | ChatHistoryDataDict | str,
         *,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset: str | None = None,
         on_message: PredictionMessageCallback | None = None,
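
Taken together, these overload updates mean the respond/complete entry points accept a GBNF grammar anywhere a JSON schema was previously allowed. A minimal usage sketch, assuming a client-based async session and an already-loaded model (the grammar, prompt, and setup are illustrative, not taken from this diff):

import asyncio

import lmstudio

# Hypothetical grammar for illustration; any valid GBNF grammar works here.
GRAMMAR = 'root ::= "yes" | "no"'

async def main() -> None:
    async with lmstudio.AsyncClient() as client:
        model = await client.llm.model()  # any currently loaded LLM
        result = await model.respond(
            "Is Python dynamically typed?",
            response_format={"type": "gbnf", "gbnfGrammar": GRAMMAR},
        )
        print(result.content)

asyncio.run(main())
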
src/lmstudio/json_api.py: 17 changes (11 additions, 6 deletions)

@@ -48,15 +48,14 @@
 from .schemas import (
     AnyLMStudioStruct,
     DictObject,
-    DictSchema,
     LMStudioStruct,
     ModelSchema,
     TWireFormat,
     _format_json,
     _snake_case_keys_to_camelCase,
     _to_json_schema,
 )
 from ._kv_config import (
+    ResponseSchema,
     TLoadConfig,
     TLoadConfigDict,
     load_config_to_kv_config_stack,
@@ -168,6 +167,7 @@
     "PredictionResult",
     "PredictionRoundResult",
     "PromptProcessingCallback",
+    "ResponseSchema",
     "SerializedLMSExtendedError",
     "ToolDefinition",
     "ToolFunctionDef",
@@ -185,6 +185,7 @@
 AnyModelSpecifier: TypeAlias = str | ModelSpecifier | ModelQuery | DictObject
 AnyLoadConfig: TypeAlias = EmbeddingLoadModelConfig | LlmLoadModelConfig

+
 GetOrLoadChannelRequest: TypeAlias = (
     EmbeddingChannelGetOrLoadCreationParameter | LlmChannelGetOrLoadCreationParameter
 )
@@ -1122,7 +1123,7 @@ def __init__(
         self,
         model_specifier: AnyModelSpecifier,
         history: Chat,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset_config: str | None = None,
         on_message: PredictionMessageCallback | None = None,
@@ -1181,7 +1182,7 @@
     @classmethod
     def _make_config_override(
         cls,
-        response_format: Type[ModelSchema] | DictSchema | None,
+        response_format: ResponseSchema | None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None,
     ) -> tuple[bool, KvConfigStack]:
         return prediction_config_to_kv_config_stack(
@@ -1264,7 +1265,11 @@ def iter_message_events(
             # to parse the received content in the latter case.
             result_content = "".join(self._fragment_content)
             if self._structured and not self._is_cancelled:
-                parsed_content = json.loads(result_content)
+                try:
+                    parsed_content = json.loads(result_content)
+                except json.JSONDecodeError:
+                    # Fall back to unstructured result reporting
+                    parsed_content = result_content
             else:
                 parsed_content = result_content
             yield self._set_result(
@@ -1385,7 +1390,7 @@ def __init__(
         self,
         model_specifier: AnyModelSpecifier,
         prompt: str,
-        response_format: Type[ModelSchema] | DictSchema | None = None,
+        response_format: ResponseSchema | None = None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
         preset_config: str | None = None,
         on_message: PredictionMessageCallback | None = None,
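
The try/except added to iter_message_events is what makes GBNF output workable: grammar-constrained text is structured but generally not valid JSON, so decode failures now fall back to reporting the raw text instead of raising. A minimal sketch of the resulting behavior (the helper name is illustrative):

import json

def parse_structured(result_content: str) -> object:
    try:
        return json.loads(result_content)
    except json.JSONDecodeError:
        # GBNF output is constrained, but not necessarily valid JSON
        return result_content

assert parse_structured('{"answer": "yes"}') == {"answer": "yes"}
assert parse_structured("yes") == "yes"
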
src/lmstudio/schemas.py: 11 changes (7 additions, 4 deletions)

@@ -10,6 +10,7 @@
     MutableMapping,
     Protocol,
     Sequence,
+    TypeAlias,
     TypeVar,
     cast,
     runtime_checkable,
@@ -26,14 +27,16 @@

 __all__ = [
     "BaseModel",
-    "ModelSchema",
     "DictObject",
     "DictSchema",
+    "ModelSchema",
 ]

-DictObject = Mapping[str, Any]  # Any JSON-compatible string-keyed dict
-MutableDictObject = MutableMapping[str, Any]
-DictSchema = Mapping[str, Any]  # JSON schema as a string-keyed dict
+DictObject: TypeAlias = Mapping[str, Any]  # Any JSON-compatible string-keyed dict
+MutableDictObject: TypeAlias = MutableMapping[str, Any]
+DictSchema: TypeAlias = Mapping[str, Any]  # JSON schema as a string-keyed dict
+# It would be nice to require a "type" key in DictSchema, but that's currently tricky
+# without "extra_items" support in TypedDict: https://peps.python.org/pep-0728/


 def _format_json(data: Any, *, sort_keys: bool = True) -> str:
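
The explicit TypeAlias annotations don't change runtime behavior; they make the aliases unambiguous to type checkers rather than relying on implicit alias inference. A small illustration of how DictSchema is used, with a hypothetical helper (not from this diff) showing why the "type" key must be checked at runtime rather than required by the alias:

from typing import Any, Mapping, TypeAlias

DictSchema: TypeAlias = Mapping[str, Any]  # usable directly in annotations

def has_type_key(schema: DictSchema) -> bool:
    # Per the PEP 728 comment above, the alias itself cannot require a
    # "type" key, so callers check for it at runtime instead.
    return "type" in schema

assert has_type_key({"type": "object"})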