From 44ad8eb4f7370299b36597b9776de5e760fbd5b9 Mon Sep 17 00:00:00 2001 From: Alyssa Coghlan Date: Sat, 22 Mar 2025 00:20:29 +1000 Subject: [PATCH] Add GBNF grammar support --- src/lmstudio/_kv_config.py | 66 ++++++++++++++++++++--------- src/lmstudio/async_api.py | 27 ++++++------ src/lmstudio/json_api.py | 17 +++++--- src/lmstudio/schemas.py | 11 +++-- src/lmstudio/sync_api.py | 27 ++++++------ tests/async/test_inference_async.py | 39 ++++++++++++++--- tests/support/__init__.py | 43 ++++++++++++++++++- tests/sync/test_inference_sync.py | 38 ++++++++++++++--- 8 files changed, 198 insertions(+), 70 deletions(-) diff --git a/src/lmstudio/_kv_config.py b/src/lmstudio/_kv_config.py index e1126f5..04036c2 100644 --- a/src/lmstudio/_kv_config.py +++ b/src/lmstudio/_kv_config.py @@ -3,10 +3,10 @@ # Known KV config settings are defined in # https://github.com/lmstudio-ai/lmstudio-js/blob/main/packages/lms-kv-config/src/schema.ts from dataclasses import dataclass -from typing import Any, Container, Iterable, Sequence, Type, TypeVar +from typing import Any, Container, Iterable, Sequence, Type, TypeAlias, TypeVar, cast from .sdk_api import LMStudioValueError -from .schemas import DictSchema, DictObject, ModelSchema, MutableDictObject +from .schemas import DictObject, DictSchema, ModelSchema, MutableDictObject from ._sdk_models import ( EmbeddingLoadModelConfig, EmbeddingLoadModelConfigDict, @@ -18,6 +18,8 @@ LlmLoadModelConfigDict, LlmPredictionConfig, LlmPredictionConfigDict, + LlmStructuredPredictionSetting, + LlmStructuredPredictionSettingDict, ) @@ -330,12 +332,20 @@ def load_config_to_kv_config_stack( return _client_config_to_kv_config_stack(dict_config, TO_SERVER_LOAD_EMBEDDING) +ResponseSchema: TypeAlias = ( + DictSchema + | LlmStructuredPredictionSetting + | LlmStructuredPredictionSettingDict + | type[ModelSchema] +) + + def prediction_config_to_kv_config_stack( - response_format: Type[ModelSchema] | DictSchema | None, + response_format: Type[ModelSchema] | ResponseSchema | None, config: LlmPredictionConfig | LlmPredictionConfigDict | None, for_text_completion: bool = False, ) -> tuple[bool, KvConfigStack]: - dict_config: DictObject + dict_config: LlmPredictionConfigDict if config is None: dict_config = {} elif isinstance(config, LlmPredictionConfig): @@ -343,39 +353,55 @@ def prediction_config_to_kv_config_stack( else: assert isinstance(config, dict) dict_config = LlmPredictionConfig._from_any_dict(config).to_dict() - response_schema: DictSchema | None = None if response_format is not None: - structured = True if "structured" in dict_config: raise LMStudioValueError( "Cannot specify both 'response_format' in API call and 'structured' in config" ) - if isinstance(response_format, type) and issubclass( + response_schema: LlmStructuredPredictionSettingDict + structured = True + if isinstance(response_format, LlmStructuredPredictionSetting): + response_schema = response_format.to_dict() + elif isinstance(response_format, type) and issubclass( response_format, ModelSchema ): - response_schema = response_format.model_json_schema() + response_schema = { + "type": "json", + "jsonSchema": response_format.model_json_schema(), + } else: - response_schema = response_format + # Casts are needed as mypy doesn't detect that the given case patterns + # conform to the definition of LlmStructuredPredictionSettingDict + match response_format: + case {"type": "json", "jsonSchema": _} as json_schema: + response_schema = cast( + LlmStructuredPredictionSettingDict, json_schema + ) + case {"type": "gbnf", 
"gbnfGrammar": _} as gbnf_schema: + response_schema = cast( + LlmStructuredPredictionSettingDict, gbnf_schema + ) + case {"type": _}: + # Assume any other input with a type key is a JSON schema definition + response_schema = { + "type": "json", + "jsonSchema": response_format, + } + case _: + raise LMStudioValueError( + f"Failed to parse response format: {response_format!r}" + ) + dict_config["structured"] = response_schema else: # The response schema may also be passed in via the config # (doing it this way type hints as an unstructured result, # but we still allow it at runtime for consistency with JS) match dict_config: - case {"structured": {"type": "json"}}: + case {"structured": {"type": "json" | "gbnf"}}: structured = True case _: structured = False fields = _to_kv_config_stack_base(dict_config, TO_SERVER_PREDICTION) - if response_schema is not None: - fields.append( - { - "key": "llm.prediction.structured", - "value": { - "type": "json", - "jsonSchema": response_schema, - }, - } - ) additional_layers: list[KvConfigStackLayerDict] = [] if for_text_completion: additional_layers.append(_get_completion_config_layer()) diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py index 47c2c40..7a3a520 100644 --- a/src/lmstudio/async_api.py +++ b/src/lmstudio/async_api.py @@ -34,7 +34,7 @@ from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException from .sdk_api import LMStudioRuntimeError, sdk_public_api, sdk_public_api_async -from .schemas import AnyLMStudioStruct, DictObject, DictSchema, ModelSchema +from .schemas import AnyLMStudioStruct, DictObject from .history import ( Chat, ChatHistoryDataDict, @@ -87,6 +87,7 @@ PredictionResult, PromptProcessingCallback, RemoteCallHandler, + ResponseSchema, TModelInfo, TPrediction, check_model_namespace, @@ -1030,7 +1031,7 @@ async def _complete_stream( model_specifier: AnyModelSpecifier, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: PredictionMessageCallback | None = ..., @@ -1043,7 +1044,7 @@ async def _complete_stream( model_specifier: AnyModelSpecifier, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, on_message: PredictionMessageCallback | None = None, @@ -1090,7 +1091,7 @@ async def _respond_stream( model_specifier: AnyModelSpecifier, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: PredictionMessageCallback | None = ..., @@ -1103,7 +1104,7 @@ async def _respond_stream( model_specifier: AnyModelSpecifier, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, on_message: PredictionMessageCallback | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, @@ -1267,7 +1268,7 @@ async def complete_stream( self, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: 
PredictionMessageCallback | None = ..., @@ -1280,7 +1281,7 @@ async def complete_stream( self, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, on_message: PredictionMessageCallback | None = None, @@ -1322,7 +1323,7 @@ async def complete( self, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: PredictionMessageCallback | None = ..., @@ -1335,7 +1336,7 @@ async def complete( self, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, on_message: PredictionMessageCallback | None = None, @@ -1382,7 +1383,7 @@ async def respond_stream( self, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: PredictionMessageCallback | None = ..., @@ -1395,7 +1396,7 @@ async def respond_stream( self, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, on_message: PredictionMessageCallback | None = None, @@ -1437,7 +1438,7 @@ async def respond( self, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: PredictionMessageCallback | None = ..., @@ -1450,7 +1451,7 @@ async def respond( self, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, on_message: PredictionMessageCallback | None = None, diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index 3ac2037..51452ae 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -48,15 +48,14 @@ from .schemas import ( AnyLMStudioStruct, DictObject, - DictSchema, LMStudioStruct, - ModelSchema, TWireFormat, _format_json, _snake_case_keys_to_camelCase, _to_json_schema, ) from ._kv_config import ( + ResponseSchema, TLoadConfig, TLoadConfigDict, load_config_to_kv_config_stack, @@ -168,6 +167,7 @@ "PredictionResult", "PredictionRoundResult", "PromptProcessingCallback", + "ResponseSchema", "SerializedLMSExtendedError", "ToolDefinition", "ToolFunctionDef", @@ -185,6 +185,7 @@ AnyModelSpecifier: TypeAlias = str | ModelSpecifier | ModelQuery | DictObject AnyLoadConfig: TypeAlias = EmbeddingLoadModelConfig | LlmLoadModelConfig + GetOrLoadChannelRequest: TypeAlias = ( EmbeddingChannelGetOrLoadCreationParameter | LlmChannelGetOrLoadCreationParameter ) @@ -1122,7 +1123,7 @@ def __init__( self, model_specifier: AnyModelSpecifier, history: Chat, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, 
config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset_config: str | None = None, on_message: PredictionMessageCallback | None = None, @@ -1181,7 +1182,7 @@ def __init__( @classmethod def _make_config_override( cls, - response_format: Type[ModelSchema] | DictSchema | None, + response_format: ResponseSchema | None, config: LlmPredictionConfig | LlmPredictionConfigDict | None, ) -> tuple[bool, KvConfigStack]: return prediction_config_to_kv_config_stack( @@ -1264,7 +1265,11 @@ def iter_message_events( # to parse the received content in the latter case. result_content = "".join(self._fragment_content) if self._structured and not self._is_cancelled: - parsed_content = json.loads(result_content) + try: + parsed_content = json.loads(result_content) + except json.JSONDecodeError: + # Fall back to unstructured result reporting + parsed_content = result_content else: parsed_content = result_content yield self._set_result( @@ -1385,7 +1390,7 @@ def __init__( self, model_specifier: AnyModelSpecifier, prompt: str, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset_config: str | None = None, on_message: PredictionMessageCallback | None = None, diff --git a/src/lmstudio/schemas.py b/src/lmstudio/schemas.py index 41f1b1d..2e4c3dd 100644 --- a/src/lmstudio/schemas.py +++ b/src/lmstudio/schemas.py @@ -10,6 +10,7 @@ MutableMapping, Protocol, Sequence, + TypeAlias, TypeVar, cast, runtime_checkable, @@ -26,14 +27,16 @@ __all__ = [ "BaseModel", - "ModelSchema", "DictObject", "DictSchema", + "ModelSchema", ] -DictObject = Mapping[str, Any] # Any JSON-compatible string-keyed dict -MutableDictObject = MutableMapping[str, Any] -DictSchema = Mapping[str, Any] # JSON schema as a string-keyed dict +DictObject: TypeAlias = Mapping[str, Any] # Any JSON-compatible string-keyed dict +MutableDictObject: TypeAlias = MutableMapping[str, Any] +DictSchema: TypeAlias = Mapping[str, Any] # JSON schema as a string-keyed dict +# It would be nice to require a "type" key in DictSchema, but that's currently tricky +# without "extra_items" support in TypedDict: https://peps.python.org/pep-0728/ def _format_json(data: Any, *, sort_keys: bool = True) -> str: diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index 9cb9f73..b9fa27f 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -53,7 +53,7 @@ sdk_callback_invocation, sdk_public_api, ) -from .schemas import AnyLMStudioStruct, DictObject, DictSchema, ModelSchema +from .schemas import AnyLMStudioStruct, DictObject from .history import ( AssistantResponse, ToolResultMessage, @@ -114,6 +114,7 @@ PredictionToolCallEvent, PromptProcessingCallback, RemoteCallHandler, + ResponseSchema, TModelInfo, TPrediction, ToolDefinition, @@ -1195,7 +1196,7 @@ def _complete_stream( model_specifier: AnyModelSpecifier, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: PredictionMessageCallback | None = ..., @@ -1208,7 +1209,7 @@ def _complete_stream( model_specifier: AnyModelSpecifier, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, on_message: PredictionMessageCallback | 
None = None, @@ -1255,7 +1256,7 @@ def _respond_stream( model_specifier: AnyModelSpecifier, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: PredictionMessageCallback | None = ..., @@ -1268,7 +1269,7 @@ def _respond_stream( model_specifier: AnyModelSpecifier, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, on_message: PredictionMessageCallback | None = None, @@ -1428,7 +1429,7 @@ def complete_stream( self, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: PredictionMessageCallback | None = ..., @@ -1441,7 +1442,7 @@ def complete_stream( self, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, on_message: PredictionMessageCallback | None = None, @@ -1483,7 +1484,7 @@ def complete( self, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: PredictionMessageCallback | None = ..., @@ -1496,7 +1497,7 @@ def complete( self, prompt: str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, on_message: PredictionMessageCallback | None = None, @@ -1543,7 +1544,7 @@ def respond_stream( self, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: PredictionMessageCallback | None = ..., @@ -1556,7 +1557,7 @@ def respond_stream( self, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, on_message: PredictionMessageCallback | None = None, @@ -1598,7 +1599,7 @@ def respond( self, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema = ..., + response_format: ResponseSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., preset: str | None = ..., on_message: PredictionMessageCallback | None = ..., @@ -1611,7 +1612,7 @@ def respond( self, history: Chat | ChatHistoryDataDict | str, *, - response_format: Type[ModelSchema] | DictSchema | None = None, + response_format: ResponseSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, preset: str | None = None, on_message: PredictionMessageCallback | None = None, diff --git a/tests/async/test_inference_async.py b/tests/async/test_inference_async.py index 4f619da..48ae8e2 100644 --- 
a/tests/async/test_inference_async.py +++ b/tests/async/test_inference_async.py @@ -6,14 +6,12 @@ import anyio import pytest from pytest import LogCaptureFixture as LogCap -from typing import Type from lmstudio import ( AssistantResponse, AsyncClient, AsyncPredictionStream, Chat, - DictSchema, LlmInfo, LlmLoadModelConfig, LlmPredictionConfig, @@ -22,13 +20,14 @@ LlmPredictionStats, LMStudioModelNotFoundError, LMStudioPresetNotFoundError, - ModelSchema, PredictionResult, + ResponseSchema, TextData, ) from ..support import ( EXPECTED_LLM_ID, + GBNF_GRAMMAR, PROMPT, RESPONSE_FORMATS, RESPONSE_SCHEMA, @@ -99,8 +98,8 @@ async def test_complete_stream_async(caplog: LogCap) -> None: @pytest.mark.asyncio @pytest.mark.lmstudio @pytest.mark.parametrize("format_type", RESPONSE_FORMATS) -async def test_complete_response_format_async( - format_type: Type[ModelSchema] | DictSchema, caplog: LogCap +async def test_complete_structured_response_format_async( + format_type: ResponseSchema, caplog: LogCap ) -> None: prompt = PROMPT caplog.set_level(logging.DEBUG) @@ -118,7 +117,7 @@ async def test_complete_response_format_async( @pytest.mark.asyncio @pytest.mark.lmstudio -async def test_complete_structured_config_async(caplog: LogCap) -> None: +async def test_complete_structured_config_json_async(caplog: LogCap) -> None: prompt = PROMPT caplog.set_level(logging.DEBUG) model_id = EXPECTED_LLM_ID @@ -144,6 +143,34 @@ async def test_complete_structured_config_async(caplog: LogCap) -> None: assert SCHEMA_FIELDS.keys() == response.parsed.keys() +@pytest.mark.asyncio +@pytest.mark.lmstudio +async def test_complete_structured_config_gbnf_async(caplog: LogCap) -> None: + prompt = PROMPT + caplog.set_level(logging.DEBUG) + model_id = EXPECTED_LLM_ID + async with AsyncClient() as client: + llm = await client.llm.model(model_id) + config: LlmPredictionConfigDict = { + # snake_case keys are accepted at runtime, + # but the type hinted spelling is the camelCase names + # This test case checks the schema field name is converted, + # but *not* the snake_case and camelCase field names in the + # schema itself + "structured": { + "type": "gbnf", + "gbnf_grammar": GBNF_GRAMMAR, + } # type: ignore[typeddict-item] + } + response = await llm.complete(prompt, config=config) + assert isinstance(response, PredictionResult) + logging.info(f"LLM response: {response!r}") + assert isinstance(response.content, str) + assert isinstance(response.parsed, dict) + assert response.parsed == json.loads(response.content) + assert SCHEMA_FIELDS.keys() == response.parsed.keys() + + @pytest.mark.asyncio @pytest.mark.lmstudio async def test_callbacks_text_completion_async(caplog: LogCap) -> None: diff --git a/tests/support/__init__.py b/tests/support/__init__.py index 57aab78..230ef5f 100644 --- a/tests/support/__init__.py +++ b/tests/support/__init__.py @@ -23,7 +23,7 @@ LMStudioChannelClosedError, ) from lmstudio.json_api import ChannelEndpoint -from lmstudio._sdk_models import LlmPredictionConfigDict +from lmstudio._sdk_models import LlmPredictionConfigDict, LlmStructuredPredictionSetting # Imports from the nominal "SDK" used in some test cases from .lmstudio import ErrFunc @@ -98,6 +98,22 @@ } SCHEMA_FIELD_NAMES = list(SCHEMA_FIELDS.keys()) +# Specify a JSON response format, so this can pass the JSON test cases +# String field definition is from the Llama JSON GBNF example at: +# https://github.com/ggml-org/llama.cpp/blob/960e72607761eb2dd170b33f02a5a2840ec412fe/grammars/json.gbnf#L16C1-L20C13 +# Note: comments and blank lines in the grammar are 
not yet supported +GBNF_GRAMMAR = r""" +root ::= "{\"response\":" response ",\"first_word_in_response\":" first-word-in-response ",\"lastWordInResponse\":" last-word-in-response "}" +response ::= string +first-word-in-response ::= string +last-word-in-response ::= string +string ::= + "\"" ( + [^"\\\x7f\x00-\x1f] | + "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) + )* "\"" +""".lstrip() + SCHEMA = { "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", @@ -130,7 +146,30 @@ class LMStudioResponseFormat(BaseModel): lastWordInResponse: str -RESPONSE_FORMATS = (LMStudioResponseFormat, OtherResponseFormat, SCHEMA) +TYPED_JSON_SCHEMA = LlmStructuredPredictionSetting(type="json", json_schema=SCHEMA) +TYPED_JSON_SCHEMA_DICT = { + "type": "json", + "jsonSchema": SCHEMA, +} + +TYPED_GBNF_GRAMMAR = LlmStructuredPredictionSetting( + type="gbnf", gbnf_grammar=GBNF_GRAMMAR +) +TYPED_GBNF_GRAMMAR_DICT = { + "type": "gbnf", + "gbnfGrammar": GBNF_GRAMMAR, +} + + +RESPONSE_FORMATS = ( + LMStudioResponseFormat, + OtherResponseFormat, + SCHEMA, + TYPED_JSON_SCHEMA, + TYPED_JSON_SCHEMA_DICT, + TYPED_GBNF_GRAMMAR, + TYPED_GBNF_GRAMMAR_DICT, +) #################################################### # Provoke/emulate connection issues diff --git a/tests/sync/test_inference_sync.py b/tests/sync/test_inference_sync.py index 7c6b896..e96c183 100644 --- a/tests/sync/test_inference_sync.py +++ b/tests/sync/test_inference_sync.py @@ -13,14 +13,12 @@ import pytest from pytest import LogCaptureFixture as LogCap -from typing import Type from lmstudio import ( AssistantResponse, Client, PredictionStream, Chat, - DictSchema, LlmInfo, LlmLoadModelConfig, LlmPredictionConfig, @@ -29,13 +27,14 @@ LlmPredictionStats, LMStudioModelNotFoundError, LMStudioPresetNotFoundError, - ModelSchema, PredictionResult, + ResponseSchema, TextData, ) from ..support import ( EXPECTED_LLM_ID, + GBNF_GRAMMAR, PROMPT, RESPONSE_FORMATS, RESPONSE_SCHEMA, @@ -102,8 +101,8 @@ def test_complete_stream_sync(caplog: LogCap) -> None: @pytest.mark.lmstudio @pytest.mark.parametrize("format_type", RESPONSE_FORMATS) -def test_complete_response_format_sync( - format_type: Type[ModelSchema] | DictSchema, caplog: LogCap +def test_complete_structured_response_format_sync( + format_type: ResponseSchema, caplog: LogCap ) -> None: prompt = PROMPT caplog.set_level(logging.DEBUG) @@ -120,7 +119,7 @@ def test_complete_response_format_sync( @pytest.mark.lmstudio -def test_complete_structured_config_sync(caplog: LogCap) -> None: +def test_complete_structured_config_json_sync(caplog: LogCap) -> None: prompt = PROMPT caplog.set_level(logging.DEBUG) model_id = EXPECTED_LLM_ID @@ -146,6 +145,33 @@ def test_complete_structured_config_sync(caplog: LogCap) -> None: assert SCHEMA_FIELDS.keys() == response.parsed.keys() +@pytest.mark.lmstudio +def test_complete_structured_config_gbnf_sync(caplog: LogCap) -> None: + prompt = PROMPT + caplog.set_level(logging.DEBUG) + model_id = EXPECTED_LLM_ID + with Client() as client: + llm = client.llm.model(model_id) + config: LlmPredictionConfigDict = { + # snake_case keys are accepted at runtime, + # but the type hinted spelling is the camelCase names + # This test case checks the schema field name is converted, + # but *not* the snake_case and camelCase field names in the + # schema itself + "structured": { + "type": "gbnf", + "gbnf_grammar": GBNF_GRAMMAR, + } # type: ignore[typeddict-item] + } + response = llm.complete(prompt, config=config) + assert isinstance(response, PredictionResult) + logging.info(f"LLM response: 
{response!r}") + assert isinstance(response.content, str) + assert isinstance(response.parsed, dict) + assert response.parsed == json.loads(response.content) + assert SCHEMA_FIELDS.keys() == response.parsed.keys() + + @pytest.mark.lmstudio def test_callbacks_text_completion_sync(caplog: LogCap) -> None: messages: list[AssistantResponse] = []