
Commit 2ca64fe

Avoid corrupting snake_case keys in JSON schemas (#43)
Also still report structured results at runtime when a response format is specified in the config (type-hinted code still needs to use `response_format` for type checkers to correctly infer the type of `response.parsed`)
1 parent 9146aec commit 2ca64fe
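
In practice, both ways of requesting structured output now yield a parsed result at runtime. Below is a minimal sketch against the API surface exercised by the tests in this commit; it assumes the top-level `lmstudio` package exports `Client`, uses a placeholder model identifier, and requires a running LM Studio instance:

```python
import json

from lmstudio import Client  # assumed top-level export, per the test conventions

SCHEMA = {
    "type": "object",
    "properties": {"response": {"type": "string"}},
    "required": ["response"],
}

with Client() as client:
    llm = client.llm.model("placeholder-model-id")  # placeholder identifier

    # Passing response_format lets type checkers infer response.parsed as a dict
    typed = llm.complete("Say hi in JSON", response_format=SCHEMA)
    assert typed.parsed == json.loads(typed.content)

    # Passing the schema via config matches the JS SDK; type hints treat the
    # result as unstructured, but as of this commit it is still parsed at runtime
    untyped = llm.complete(
        "Say hi in JSON",
        config={"structured": {"type": "json", "jsonSchema": SCHEMA}},
    )
    assert untyped.parsed == json.loads(untyped.content)
```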

7 files changed (+120 lines, -22 lines)


src/lmstudio/_kv_config.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -268,7 +268,7 @@ def prediction_config_to_kv_config_stack(
     response_format: Type[ModelSchema] | DictSchema | None,
     config: LlmPredictionConfig | LlmPredictionConfigDict | None,
     for_text_completion: bool = False,
-) -> KvConfigStack:
+) -> tuple[bool, KvConfigStack]:
     dict_config: DictObject
     if config is None:
         dict_config = {}
@@ -279,6 +279,7 @@ def prediction_config_to_kv_config_stack(
         dict_config = LlmPredictionConfig._from_any_dict(config).to_dict()
     response_schema: DictSchema | None = None
     if response_format is not None:
+        structured = True
         if "structured" in dict_config:
             raise LMStudioValueError(
                 "Cannot specify both 'response_format' in API call and 'structured' in config"
@@ -289,6 +290,15 @@ def prediction_config_to_kv_config_stack(
             response_schema = response_format.model_json_schema()
         else:
             response_schema = response_format
+    else:
+        # The response schema may also be passed in via the config
+        # (doing it this way type hints as an unstructured result,
+        # but we still allow it at runtime for consistency with JS)
+        match dict_config:
+            case {"structured": {"type": "json"}}:
+                structured = True
+            case _:
+                structured = False
     fields = _to_kv_config_stack_base(
         dict_config,
         "llm",
@@ -308,7 +318,7 @@ def prediction_config_to_kv_config_stack(
     additional_layers: list[KvConfigStackLayerDict] = []
     if for_text_completion:
         additional_layers.append(_get_completion_config_layer())
-    return _api_override_kv_config_stack(fields, additional_layers)
+    return structured, _api_override_kv_config_stack(fields, additional_layers)


 def _get_completion_config_layer() -> KvConfigStackLayerDict:
```
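
The `match` statement added above relies on Python's structural pattern matching for mappings: a mapping pattern matches any dict that contains the named keys with matching values, and extra entries (such as the schema itself, or other prediction settings) are ignored. A standalone sketch of the same detection logic:

```python
# Standalone illustration of the mapping-pattern check used above.
from typing import Any


def requests_structured_output(dict_config: dict[str, Any]) -> bool:
    match dict_config:
        case {"structured": {"type": "json"}}:
            return True
        case _:
            return False


# Extra keys (the schema, other settings) don't prevent a match
assert requests_structured_output(
    {"structured": {"type": "json", "jsonSchema": {"type": "object"}}, "maxTokens": 50}
)
# No structured-output request in the config -> falls through to case _
assert not requests_structured_output({"maxTokens": 50})
```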

src/lmstudio/json_api.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -1144,7 +1144,7 @@ def __init__(
             config["rawTools"] = llm_tools.to_dict()
         else:
             config.raw_tools = llm_tools
-        config_stack = self._make_config_override(response_format, config)
+        structured, config_stack = self._make_config_override(response_format, config)
         params = PredictionChannelRequest._from_api_dict(
             {
                 "modelSpecifier": _model_spec_to_api_dict(model_specifier),
@@ -1155,7 +1155,7 @@ def __init__(
         super().__init__(params)
         # Status tracking for the prediction progress and result reporting
         self._is_cancelled = False
-        self._structured = response_format is not None
+        self._structured = structured
         self._on_message = on_message
         self._prompt_processing_progress = -1.0
         self._on_prompt_processing_progress = on_prompt_processing_progress
@@ -1172,7 +1172,7 @@ def _make_config_override(
         cls,
         response_format: Type[ModelSchema] | DictSchema | None,
         config: LlmPredictionConfig | LlmPredictionConfigDict | None,
-    ) -> KvConfigStack:
+    ) -> tuple[bool, KvConfigStack]:
         return prediction_config_to_kv_config_stack(
             response_format, config, **cls._additional_config_options()
         )
```
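
Threading the flag through matters because `self._structured` is what decides whether the result's `parsed` field gets JSON-decoded. A simplified sketch of that decision, grounded only in what the tests below assert (this is not the SDK's actual result-processing code):

```python
# Simplified sketch: the tests assert response.parsed == json.loads(response.content)
# for structured predictions, and the raw text otherwise.
import json
from typing import Any


def report_result(content: str, structured: bool) -> Any:
    # Structured prediction -> decoded JSON; otherwise the raw response text
    return json.loads(content) if structured else content


content = '{"response": "hi"}'
assert report_result(content, structured=True) == {"response": "hi"}
# The old `response_format is not None` check left schema-in-config
# predictions on this path, reporting the undecoded string:
assert report_result(content, structured=False) == content
```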

src/lmstudio/schemas.py

Lines changed: 16 additions & 4 deletions
```diff
@@ -93,13 +93,23 @@ def model_json_schema(cls) -> DictSchema:
     "useFp16ForKvCache": "useFp16ForKVCache",
 }

+_SKIP_FIELD_RECURSION = set(
+    (
+        "json_schema",
+        "jsonSchema",
+    )
+)
+

 def _snake_case_to_camelCase(key: str) -> str:
     first, *rest = key.split("_")
     camelCase = "".join((first, *(w.capitalize() for w in rest)))
     return _CAMEL_CASE_OVERRIDES.get(camelCase, camelCase)


+# TODO: Rework this conversion to be based on the API struct definitions
+# * Only recurse into fields that allow substructs
+# * Only check fields with a snake case -> camel case name conversion
 def _snake_case_keys_to_camelCase(data: DictObject) -> DictObject:
     translated_data: dict[str, Any] = {}
     dicts_to_process = [(data, translated_data)]
@@ -114,12 +124,14 @@ def _queue_dict(input_dict: DictObject, output_dict: dict[str, Any]) -> None:

     for input_dict, output_dict in dicts_to_process:
         for k, v in input_dict.items():
-            new_value: Any
             match v:
                 case {}:
-                    new_dict: dict[str, Any] = {}
-                    _queue_dict(v, new_dict)
-                    new_value = new_dict
+                    if k in _SKIP_FIELD_RECURSION:
+                        new_value = v
+                    else:
+                        new_dict: dict[str, Any] = {}
+                        _queue_dict(v, new_dict)
+                        new_value = new_dict
                 case [*_]:
                     new_list: list[Any] = []
                     for item in v:
```
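
With the skip set in place, the translation still camelCases config keys (including the `json_schema` key itself) but no longer recurses into the schema's contents. An illustrative example using the private helper fixed above (a private detail of `lmstudio.schemas`, not a supported public API):

```python
# Illustrative use of the private helper; behavior matches the test comments:
# the json_schema key name is converted, the schema's own keys are not.
from lmstudio.schemas import _snake_case_keys_to_camelCase

config = {
    "max_tokens": 10,
    "structured": {
        "type": "json",
        "json_schema": {
            "type": "object",
            "properties": {"first_word_in_response": {"type": "string"}},
        },
    },
}
translated = _snake_case_keys_to_camelCase(config)

# Config keys are still converted, including the json_schema key itself...
assert "maxTokens" in translated
assert "jsonSchema" in translated["structured"]
# ...but the property names inside the schema are no longer rewritten
props = translated["structured"]["jsonSchema"]["properties"]
assert "first_word_in_response" in props
```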

tests/async/test_inference_async.py

Lines changed: 33 additions & 2 deletions
```diff
@@ -15,6 +15,7 @@
     Chat,
     DictSchema,
     LlmInfo,
+    LlmPredictionConfigDict,
     LlmPredictionFragment,
     LlmPredictionStats,
     LMStudioModelNotFoundError,
@@ -27,6 +28,8 @@
     EXPECTED_LLM_ID,
     PROMPT,
     RESPONSE_FORMATS,
+    RESPONSE_SCHEMA,
+    SCHEMA_FIELDS,
     SHORT_PREDICTION_CONFIG,
     check_sdk_error,
 )
@@ -93,7 +96,7 @@ async def test_complete_stream_async(caplog: LogCap) -> None:
 @pytest.mark.asyncio
 @pytest.mark.lmstudio
 @pytest.mark.parametrize("format_type", RESPONSE_FORMATS)
-async def test_complete_structured_async(
+async def test_complete_response_format_async(
     format_type: Type[ModelSchema] | DictSchema, caplog: LogCap
 ) -> None:
     prompt = PROMPT
@@ -107,7 +110,35 @@ async def test_complete_structured_async(
     assert isinstance(response.content, str)
     assert isinstance(response.parsed, dict)
     assert response.parsed == json.loads(response.content)
-    assert "response" in response.parsed
+    assert SCHEMA_FIELDS.keys() == response.parsed.keys()
+
+
+@pytest.mark.asyncio
+@pytest.mark.lmstudio
+async def test_complete_structured_config_async(caplog: LogCap) -> None:
+    prompt = PROMPT
+    caplog.set_level(logging.DEBUG)
+    model_id = EXPECTED_LLM_ID
+    async with AsyncClient() as client:
+        llm = await client.llm.model(model_id)
+        config: LlmPredictionConfigDict = {
+            # snake_case keys are accepted at runtime,
+            # but the type hinted spelling is the camelCase names
+            # This test case checks the schema field name is converted,
+            # but *not* the snake_case and camelCase field names in the
+            # schema itself
+            "structured": {
+                "type": "json",
+                "json_schema": RESPONSE_SCHEMA,
+            }  # type: ignore[typeddict-item]
+        }
+        response = await llm.complete(prompt, config=config)
+        assert isinstance(response, PredictionResult)
+        logging.info(f"LLM response: {response!r}")
+        assert isinstance(response.content, str)
+        assert isinstance(response.parsed, dict)
+        assert response.parsed == json.loads(response.content)
+        assert SCHEMA_FIELDS.keys() == response.parsed.keys()


 @pytest.mark.asyncio
```

tests/support/__init__.py

Lines changed: 22 additions & 8 deletions
```diff
@@ -82,22 +82,34 @@
 # Structured LLM responses
 ####################################################

+# Schema includes both snake_case and camelCase field
+# names to ensure the special-casing of snake_case
+# fields in dict inputs doesn't corrupt schema inputs
+SCHEMA_FIELDS = {
+    "response": {
+        "type": "string",
+    },
+    "first_word_in_response": {
+        "type": "string",
+    },
+    "lastWordInResponse": {
+        "type": "string",
+    },
+}
+SCHEMA_FIELD_NAMES = list(SCHEMA_FIELDS.keys())
+
 SCHEMA = {
     "$schema": "http://json-schema.org/draft-07/schema#",
     "type": "object",
-    "required": ["response"],
-    "properties": {
-        "response": {
-            "type": "string",
-        }
-    },
+    "required": SCHEMA_FIELD_NAMES,
+    "properties": SCHEMA_FIELDS,
     "additionalProperties": False,
 }
 RESPONSE_SCHEMA = {
     "$defs": {
         "schema": {
-            "properties": {"response": {"type": "string"}},
-            "required": ["response"],
+            "properties": SCHEMA_FIELDS,
+            "required": SCHEMA_FIELD_NAMES,
             "title": "schema",
             "type": "object",
         }
@@ -114,6 +126,8 @@ def model_json_schema(cls) -> DictSchema:

 class LMStudioResponseFormat(BaseModel):
     response: str
+    first_word_in_response: str
+    lastWordInResponse: str


 RESPONSE_FORMATS = (LMStudioResponseFormat, OtherResponseFormat, SCHEMA)
```
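
For reference, the Pydantic fixture above keeps its field names verbatim in the generated JSON schema (standard Pydantic v2 `model_json_schema()` behavior), which is exactly what exercises the corruption path this commit fixes:

```python
# Standard Pydantic v2 behavior, shown for reference: field names appear
# verbatim in the schema, so "first_word_in_response" must survive the
# SDK's snake_case -> camelCase config translation intact.
from pydantic import BaseModel


class LMStudioResponseFormat(BaseModel):
    response: str
    first_word_in_response: str
    lastWordInResponse: str


schema = LMStudioResponseFormat.model_json_schema()
assert set(schema["properties"]) == {
    "response",
    "first_word_in_response",
    "lastWordInResponse",
}
```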

tests/sync/test_inference_sync.py

Lines changed: 32 additions & 2 deletions
```diff
@@ -22,6 +22,7 @@
     Chat,
     DictSchema,
     LlmInfo,
+    LlmPredictionConfigDict,
     LlmPredictionFragment,
     LlmPredictionStats,
     LMStudioModelNotFoundError,
@@ -34,6 +35,8 @@
     EXPECTED_LLM_ID,
     PROMPT,
     RESPONSE_FORMATS,
+    RESPONSE_SCHEMA,
+    SCHEMA_FIELDS,
     SHORT_PREDICTION_CONFIG,
     check_sdk_error,
 )
@@ -96,7 +99,7 @@ def test_complete_stream_sync(caplog: LogCap) -> None:

 @pytest.mark.lmstudio
 @pytest.mark.parametrize("format_type", RESPONSE_FORMATS)
-def test_complete_structured_sync(
+def test_complete_response_format_sync(
     format_type: Type[ModelSchema] | DictSchema, caplog: LogCap
 ) -> None:
     prompt = PROMPT
@@ -110,7 +113,34 @@ def test_complete_structured_sync(
     assert isinstance(response.content, str)
     assert isinstance(response.parsed, dict)
     assert response.parsed == json.loads(response.content)
-    assert "response" in response.parsed
+    assert SCHEMA_FIELDS.keys() == response.parsed.keys()
+
+
+@pytest.mark.lmstudio
+def test_complete_structured_config_sync(caplog: LogCap) -> None:
+    prompt = PROMPT
+    caplog.set_level(logging.DEBUG)
+    model_id = EXPECTED_LLM_ID
+    with Client() as client:
+        llm = client.llm.model(model_id)
+        config: LlmPredictionConfigDict = {
+            # snake_case keys are accepted at runtime,
+            # but the type hinted spelling is the camelCase names
+            # This test case checks the schema field name is converted,
+            # but *not* the snake_case and camelCase field names in the
+            # schema itself
+            "structured": {
+                "type": "json",
+                "json_schema": RESPONSE_SCHEMA,
+            }  # type: ignore[typeddict-item]
+        }
+        response = llm.complete(prompt, config=config)
+        assert isinstance(response, PredictionResult)
+        logging.info(f"LLM response: {response!r}")
+        assert isinstance(response.content, str)
+        assert isinstance(response.parsed, dict)
+        assert response.parsed == json.loads(response.content)
+        assert SCHEMA_FIELDS.keys() == response.parsed.keys()


 @pytest.mark.lmstudio
```

tests/test_kv_config.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -419,7 +419,8 @@ def test_kv_stack_load_config_llm(config_dict: DictObject) -> None:
 def test_kv_stack_prediction_config(config_dict: DictObject) -> None:
     # MyPy complains here that it can't be sure the dict has all the right keys
     # It is correct about that, but we want to ensure it is handled at runtime
-    kv_stack = prediction_config_to_kv_config_stack(None, config_dict)  # type: ignore[arg-type]
+    structured, kv_stack = prediction_config_to_kv_config_stack(None, config_dict)  # type: ignore[arg-type]
+    assert structured
     assert kv_stack.to_dict() == EXPECTED_KV_STACK_PREDICTION
```
