Commit 05c7157

Publish config retrieval APIs (#53)
* model handles provide a public "get_load_config" method
* prediction results report the prediction config and model config

Part of #33
1 parent 1f5340f commit 05c7157

11 files changed (+77 additions, −45 deletions)

src/lmstudio/_kv_config.py

Lines changed: 23 additions & 4 deletions
@@ -123,7 +123,7 @@ def update_client_config(
     "contextLength": ConfigField("contextLength"),
 }

-_SUPPORTED_SERVER_KEYS: dict[str, DictObject] = {
+SUPPORTED_SERVER_KEYS: dict[str, DictObject] = {
     "load": {
         "gpuSplitConfig": MultiPartField(
             "gpuOffload", ("mainGpu", "splitStrategy", "disabledGpus")
@@ -189,7 +189,7 @@ def _iter_server_keys(*namespaces: str) -> Iterable[tuple[str, ConfigField]]:
     # Map dotted config field names to their client config field counterparts
     for namespace in namespaces:
         scopes: list[tuple[str, DictObject]] = [
-            (namespace, _SUPPORTED_SERVER_KEYS[namespace])
+            (namespace, SUPPORTED_SERVER_KEYS[namespace])
         ]
         for prefix, scope in scopes:
             for k, v in scope.items():
@@ -204,6 +204,7 @@ def _iter_server_keys(*namespaces: str) -> Iterable[tuple[str, ConfigField]]:
 FROM_SERVER_LOAD_LLM = dict(_iter_server_keys("load", "llm.load"))
 FROM_SERVER_LOAD_EMBEDDING = dict(_iter_server_keys("load", "embedding.load"))
 FROM_SERVER_PREDICTION = dict(_iter_server_keys("llm.prediction"))
+FROM_SERVER_CONFIG = dict(_iter_server_keys(*SUPPORTED_SERVER_KEYS))


 # Define mappings to translate client config instances to server KV configs
@@ -237,8 +238,26 @@ def dict_from_kvconfig(config: KvConfig) -> DictObject:
     return {kv.key: kv.value for kv in config.fields}


-def dict_from_fields_key(config: DictObject) -> DictObject:
-    return {kv["key"]: kv["value"] for kv in config.get("fields", [])}
+def parse_server_config(server_config: DictObject) -> DictObject:
+    """Map server config fields to client config fields."""
+    result: MutableDictObject = {}
+    for kv in server_config.get("fields", []):
+        key = kv["key"]
+        config_field = FROM_SERVER_CONFIG.get(key, None)
+        if config_field is None:
+            # Skip unknown keys (server might be newer than the SDK)
+            continue
+        value = kv["value"]
+        config_field.update_client_config(result, value)
+    return result
+
+
+def parse_llm_load_config(server_config: DictObject) -> LlmLoadModelConfig:
+    return LlmLoadModelConfig._from_any_api_dict(parse_server_config(server_config))
+
+
+def parse_prediction_config(server_config: DictObject) -> LlmPredictionConfig:
+    return LlmPredictionConfig._from_any_api_dict(parse_server_config(server_config))


 def _api_override_kv_config_stack(

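The net effect of the changes above: instead of returning a raw {server key: value} dict (the old dict_from_fields_key), the SDK now maps known server fields onto their client config counterparts and can wrap the result in the public config dataclasses. A minimal sketch of the new helpers follows; the "fields" list shape matches what parse_server_config() iterates over, but the specific key names and values are invented for illustration, and these helpers live in the private _kv_config module:

from lmstudio._kv_config import parse_llm_load_config, parse_server_config

# Hypothetical server-side KV config payload (key names are illustrative)
server_config = {
    "fields": [
        {"key": "llm.load.contextLength", "value": 4096},
        {"key": "someFutureServerOnlyKey", "value": True},  # unknown keys are skipped
    ]
}

# Known server fields are translated to their client config field names
client_fields = parse_server_config(server_config)

# The parse_* wrappers return the public config dataclasses instead of dicts
load_config = parse_llm_load_config(server_config)
print(type(load_config).__name__)  # LlmLoadModelConfig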
src/lmstudio/async_api.py

Lines changed: 11 additions & 6 deletions
@@ -43,6 +43,7 @@
     _LocalFileData,
 )
 from .json_api import (
+    AnyLoadConfig,
     AnyModelSpecifier,
     AvailableModelBase,
     ChannelEndpoint,
@@ -93,7 +94,7 @@
     _model_spec_to_api_dict,
     _redact_json,
 )
-from ._kv_config import TLoadConfig, TLoadConfigDict, dict_from_fields_key
+from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
 from ._sdk_models import (
     EmbeddingRpcEmbedStringParameter,
     EmbeddingRpcTokenizeParameter,
@@ -693,7 +694,9 @@ def _system_session(self) -> AsyncSessionSystem:
     def _files_session(self) -> _AsyncSessionFiles:
         return self._client.files

-    async def _get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObject:
+    async def _get_load_config(
+        self, model_specifier: AnyModelSpecifier
+    ) -> AnyLoadConfig:
         """Get the model load config for the specified model."""
         # Note that the configuration reported here uses the *server* config names,
         # not the attributes used to set the configuration in the client SDK
@@ -703,7 +706,8 @@ async def _get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObje
             }
         )
         config = await self.remote_call("getLoadConfig", params)
-        return dict_from_fields_key(config)
+        result_type = self._API_TYPES.MODEL_LOAD_CONFIG
+        return result_type._from_any_api_dict(parse_server_config(config))

     async def _get_api_model_info(self, model_specifier: AnyModelSpecifier) -> Any:
         """Get the raw model info (if any) for a model matching the given criteria."""
@@ -1158,7 +1162,9 @@ async def _embed(
         )


-class AsyncModelHandle(ModelHandleBase[TAsyncSessionModel]):
+class AsyncModelHandle(
+    Generic[TAsyncSessionModel], ModelHandleBase[TAsyncSessionModel]
+):
     """Reference to a loaded LM Studio model."""

     @sdk_public_api_async()
@@ -1171,9 +1177,8 @@ async def get_info(self) -> ModelInstanceInfo:
         """Get the model info for this model."""
         return await self._session.get_model_info(self.identifier)

-    # Private until this API can emit the client config types
     @sdk_public_api_async()
-    async def _get_load_config(self) -> DictObject:
+    async def get_load_config(self) -> AnyLoadConfig:
         """Get the model load config for this model."""
         return await self._session._get_load_config(self.identifier)


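With the session method now returning AnyLoadConfig, the async model handle's load-config accessor is published as get_load_config(). A minimal usage sketch, assuming a local LM Studio server with a loaded model; the model identifier is illustrative:

import asyncio

from lmstudio import AsyncClient, LlmLoadModelConfig


async def main() -> None:
    async with AsyncClient() as client:
        model = await client.llm.model("my-model")  # identifier is illustrative
        # Newly public in this commit: returns a typed client config object
        # (LlmLoadModelConfig for LLMs, EmbeddingLoadModelConfig for embeddings)
        load_config = await model.get_load_config()
        assert isinstance(load_config, LlmLoadModelConfig)
        print(load_config)


asyncio.run(main())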
src/lmstudio/json_api.py

Lines changed: 13 additions & 13 deletions
@@ -59,8 +59,9 @@
 from ._kv_config import (
     TLoadConfig,
     TLoadConfigDict,
-    dict_from_fields_key,
     load_config_to_kv_config_stack,
+    parse_llm_load_config,
+    parse_prediction_config,
     prediction_config_to_kv_config_stack,
 )
 from ._sdk_models import (
@@ -128,6 +129,7 @@
 # implicitly as part of the top-level `lmstudio` API.
 __all__ = [
     "ActResult",
+    "AnyLoadConfig",
     "AnyModelSpecifier",
     "DownloadFinalizedCallback",
     "DownloadProgressCallback",
@@ -180,6 +182,7 @@
 DEFAULT_TTL = 60 * 60  # By default, leaves idle models loaded for an hour

 AnyModelSpecifier: TypeAlias = str | ModelSpecifier | ModelQuery | DictObject
+AnyLoadConfig: TypeAlias = EmbeddingLoadModelConfig | LlmLoadModelConfig

 GetOrLoadChannelRequest: TypeAlias = (
     EmbeddingChannelGetOrLoadCreationParameter | LlmChannelGetOrLoadCreationParameter
@@ -441,12 +444,9 @@ class PredictionResult(Generic[TPrediction]):
     parsed: TPrediction  # dict for structured predictions, str otherwise
     stats: LlmPredictionStats  # Statistics about the prediction process
     model_info: LlmInfo  # Information about the model used
-    structured: bool = field(init=False)  # Whether the result is structured or not
-    # Note that the configuration reported here uses the *server* config names,
-    # not the attributes used to set the configuration in the client SDK
-    # Private until these attributes store the client config types
-    _load_config: DictObject  # The configuration used to load the model
-    _prediction_config: DictObject  # The configuration used for the prediction
+    structured: bool = field(init=False)  # Whether the result is structured or not
+    load_config: LlmLoadModelConfig  # The configuration used to load the model
+    prediction_config: LlmPredictionConfig  # The configuration used for the prediction
     # fmt: on

     def __post_init__(self) -> None:
@@ -1262,8 +1262,8 @@ def iter_message_events(
                         parsed=parsed_content,
                         stats=LlmPredictionStats._from_any_api_dict(stats),
                         model_info=LlmInfo._from_any_api_dict(model_info),
-                        _load_config=dict_from_fields_key(load_kvconfig),
-                        _prediction_config=dict_from_fields_key(prediction_kvconfig),
+                        load_config=parse_llm_load_config(load_kvconfig),
+                        prediction_config=parse_prediction_config(prediction_kvconfig),
                     )
                 )
             case unmatched:
@@ -1477,19 +1477,19 @@ def model_info(self) -> LlmInfo | None:

     # Private until this API can emit the client config types
     @property
-    def _load_config(self) -> DictObject | None:
+    def _load_config(self) -> LlmLoadModelConfig | None:
         """Get the load configuration used for the current prediction if available."""
         if self._final_result is None:
             return None
-        return self._final_result._load_config
+        return self._final_result.load_config

     # Private until this API can emit the client config types
     @property
-    def _prediction_config(self) -> DictObject | None:
+    def _prediction_config(self) -> LlmPredictionConfig | None:
         """Get the prediction configuration used for the current prediction if available."""
         if self._final_result is None:
             return None
-        return self._final_result._prediction_config
+        return self._final_result.prediction_config

     @sdk_public_api()
     def result(self) -> PredictionResult[TPrediction]:

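With these changes, PredictionResult carries the configs as public, typed attributes rather than private server-keyed dicts. A minimal sketch of reading them from a sync prediction, assuming a local server and an illustrative model identifier:

from lmstudio import Client, LlmLoadModelConfig, LlmPredictionConfig

with Client() as client:
    model = client.llm.model("my-model")  # identifier is illustrative
    result = model.complete("Say hello in one word.")
    # Formerly result._load_config / result._prediction_config (raw dicts);
    # now typed client config objects
    assert isinstance(result.load_config, LlmLoadModelConfig)
    assert isinstance(result.prediction_config, LlmPredictionConfig)
    print(result.prediction_config)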
src/lmstudio/sync_api.py

Lines changed: 6 additions & 4 deletions
@@ -66,6 +66,7 @@
 )
 from .json_api import (
     ActResult,
+    AnyLoadConfig,
     AnyModelSpecifier,
     AvailableModelBase,
     ChannelEndpoint,
@@ -121,7 +122,7 @@
     _model_spec_to_api_dict,
     _redact_json,
 )
-from ._kv_config import TLoadConfig, TLoadConfigDict, dict_from_fields_key
+from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
 from ._sdk_models import (
     EmbeddingRpcEmbedStringParameter,
     EmbeddingRpcTokenizeParameter,
@@ -866,7 +867,7 @@ def _system_session(self) -> SyncSessionSystem:
     def _files_session(self) -> _SyncSessionFiles:
         return self._client.files

-    def _get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObject:
+    def _get_load_config(self, model_specifier: AnyModelSpecifier) -> AnyLoadConfig:
         """Get the model load config for the specified model."""
         # Note that the configuration reported here uses the *server* config names,
         # not the attributes used to set the configuration in the client SDK
@@ -876,7 +877,8 @@ def _get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObject:
             }
         )
         config = self.remote_call("getLoadConfig", params)
-        return dict_from_fields_key(config)
+        result_type = self._API_TYPES.MODEL_LOAD_CONFIG
+        return result_type._from_any_api_dict(parse_server_config(config))

     def _get_api_model_info(self, model_specifier: AnyModelSpecifier) -> Any:
         """Get the raw model info (if any) for a model matching the given criteria."""
@@ -1339,7 +1341,7 @@ def get_info(self) -> ModelInstanceInfo:

     # Private until this API can emit the client config types
     @sdk_public_api()
-    def _get_load_config(self) -> DictObject:
+    def _get_load_config(self) -> AnyLoadConfig:
         """Get the model load config for this model."""
         return self._session._get_load_config(self.identifier)


tests/async/test_embedding_async.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 import pytest
 from pytest import LogCaptureFixture as LogCap

-from lmstudio import AsyncClient, LMStudioModelNotFoundError
+from lmstudio import AsyncClient, EmbeddingLoadModelConfig, LMStudioModelNotFoundError

 from ..support import (
     EXPECTED_EMBEDDING,
@@ -114,7 +114,7 @@ async def test_get_load_config_async(model_id: str, caplog: LogCap) -> None:
         response = await client.embedding._get_load_config(model_id)
     logging.info(f"Load config response: {response}")
     assert response
-    assert isinstance(response, dict)
+    assert isinstance(response, EmbeddingLoadModelConfig)


 @pytest.mark.asyncio

tests/async/test_inference_async.py

Lines changed: 6 additions & 4 deletions
@@ -15,6 +15,8 @@
     Chat,
     DictSchema,
     LlmInfo,
+    LlmLoadModelConfig,
+    LlmPredictionConfig,
     LlmPredictionConfigDict,
     LlmPredictionFragment,
     LlmPredictionStats,
@@ -269,12 +271,12 @@ async def test_complete_prediction_metadata_async(caplog: LogCap) -> None:
     logging.info(f"LLM response: {response.content!r}")
     assert response.stats
     assert response.model_info
-    assert response._load_config
-    assert response._prediction_config
+    assert response.load_config
+    assert response.prediction_config
     assert isinstance(response.stats, LlmPredictionStats)
     assert isinstance(response.model_info, LlmInfo)
-    assert isinstance(response._load_config, dict)
-    assert isinstance(response._prediction_config, dict)
+    assert isinstance(response.load_config, LlmLoadModelConfig)
+    assert isinstance(response.prediction_config, LlmPredictionConfig)


 @pytest.mark.asyncio

tests/async/test_llm_async.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 import pytest
 from pytest import LogCaptureFixture as LogCap

-from lmstudio import AsyncClient, history
+from lmstudio import AsyncClient, LlmLoadModelConfig, history

 from ..support import EXPECTED_LLM, EXPECTED_LLM_ID

@@ -96,7 +96,7 @@ async def test_get_load_config_async(model_id: str, caplog: LogCap) -> None:
         response = await client.llm._get_load_config(model_id)
     logging.info(f"Load config response: {response}")
     assert response
-    assert isinstance(response, dict)
+    assert isinstance(response, LlmLoadModelConfig)


 @pytest.mark.asyncio

tests/sync/test_embedding_sync.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 import pytest
 from pytest import LogCaptureFixture as LogCap

-from lmstudio import Client, LMStudioModelNotFoundError
+from lmstudio import Client, EmbeddingLoadModelConfig, LMStudioModelNotFoundError

 from ..support import (
     EXPECTED_EMBEDDING,
@@ -115,7 +115,7 @@ def test_get_load_config_sync(model_id: str, caplog: LogCap) -> None:
         response = client.embedding._get_load_config(model_id)
     logging.info(f"Load config response: {response}")
     assert response
-    assert isinstance(response, dict)
+    assert isinstance(response, EmbeddingLoadModelConfig)


 @pytest.mark.lmstudio

tests/sync/test_inference_sync.py

Lines changed: 6 additions & 4 deletions
@@ -22,6 +22,8 @@
     Chat,
     DictSchema,
     LlmInfo,
+    LlmLoadModelConfig,
+    LlmPredictionConfig,
     LlmPredictionConfigDict,
     LlmPredictionFragment,
     LlmPredictionStats,
@@ -264,12 +266,12 @@ def test_complete_prediction_metadata_sync(caplog: LogCap) -> None:
     logging.info(f"LLM response: {response.content!r}")
     assert response.stats
     assert response.model_info
-    assert response._load_config
-    assert response._prediction_config
+    assert response.load_config
+    assert response.prediction_config
     assert isinstance(response.stats, LlmPredictionStats)
     assert isinstance(response.model_info, LlmInfo)
-    assert isinstance(response._load_config, dict)
-    assert isinstance(response._prediction_config, dict)
+    assert isinstance(response.load_config, LlmLoadModelConfig)
+    assert isinstance(response.prediction_config, LlmPredictionConfig)


 @pytest.mark.lmstudio

tests/sync/test_llm_sync.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 import pytest
 from pytest import LogCaptureFixture as LogCap

-from lmstudio import Client, history
+from lmstudio import Client, LlmLoadModelConfig, history

 from ..support import EXPECTED_LLM, EXPECTED_LLM_ID

@@ -96,7 +96,7 @@ def test_get_load_config_sync(model_id: str, caplog: LogCap) -> None:
         response = client.llm._get_load_config(model_id)
     logging.info(f"Load config response: {response}")
     assert response
-    assert isinstance(response, dict)
+    assert isinstance(response, LlmLoadModelConfig)


 @pytest.mark.lmstudio

tests/test_history.py

Lines changed: 4 additions & 2 deletions
@@ -24,6 +24,8 @@
 )
 from lmstudio.json_api import (
     LlmInfo,
+    LlmLoadModelConfig,
+    LlmPredictionConfig,
     LlmPredictionStats,
     PredictionResult,
     TPrediction,
@@ -347,8 +349,8 @@ def _make_prediction_result(data: TPrediction) -> PredictionResult[TPrediction]:
             trained_for_tool_use=False,
             max_context_length=32,
         ),
-        _load_config={},
-        _prediction_config={},
+        load_config=LlmLoadModelConfig(),
+        prediction_config=LlmPredictionConfig(),
     )
