
Commit f529c49

Handle JIT model unloading notification messages (#26)
Closes #25
1 parent: 184acc1

9 files changed: +229, -46 lines changed

src/lmstudio/json_api.py

Lines changed: 21 additions & 0 deletions
@@ -862,6 +862,27 @@ def iter_message_events(
                 yield from self._update_progress(0.0)
             case {"type": "loadProgress" | "progress", "progress": progress}:
                 yield from self._update_progress(progress)
+            case {"type": "unloadingOtherJITModel", "info": other_model_info} if (
+                "modelKey" in other_model_info
+            ):
+                jit_unload_event = "Unloading other JIT model"
+                unloaded_model_key = other_model_info["modelKey"]
+                suggestion = (
+                    "You can disable this behavior by going to "
+                    "LM Studio -> Settings -> Developer -> Turn OFF JIT models auto-evict"
+                )
+                # Report the JIT unload
+                self._logger.info(
+                    jit_unload_event,
+                    unloaded_model_key=unloaded_model_key,
+                    suggestion=suggestion,
+                )
+                # Report further details on the unloaded model if debug messages are enabled
+                self._logger.debug(
+                    jit_unload_event,
+                    unloaded_model_key=unloaded_model_key,
+                    unloaded_model=other_model_info,
+                )
             case {
                 "type": "success" | "alreadyLoaded" | "loadSuccess",
                 "info": {

tests/README.md

Lines changed: 14 additions & 2 deletions
@@ -13,18 +13,24 @@ conditions must also be met for the test suite to pass:
 - the API server must be enabled and running on port 1234
 - the following models model must be loaded with their default identifiers
   - `text-embedding-nomic-embed-text-v1.5` (text embedding model)
-  - `llama-3.2-1b-instruct` (chat oriented text LLM)
+  - `llama-3.2-1b-instruct` (text LLM)
   - `ZiangWu/MobileVLM_V2-1.7B-GGUF` (visual LLM)
   - `qwen2.5-7b-instruct-1m` (tool using LLM)
 
 Additional models should NOT be loaded when running the test suite,
 as some model querying tests may fail in that case.
 
-However, there's no problem with having additional models downloaded.
+There are also some JIT model loading/unloading test cases which
+expect `smollm2-135m` (small text LLM) to already be downloaded.
+A full test run will download this model (since it is also the
+model used for the end-to-end search-and-download test case).
+
+There's no problem with having additional models downloaded.
 The only impact is that the test that checks all of the expected
 models can be found in the list of downloaded models will take a
 little longer to run.
 
+
 # Loading and unloading the required models
 
 The `load-test-models` `tox` environment can be used to ensure the required
@@ -44,6 +50,12 @@ explicitly unload the test models:
 $ tox -m unload-test-models
 ```
 
+The model downloading test cases can be specifically run with:
+
+```console
+$ tox -m test -- -k test_download_model
+```
+
 
 ## Adding new tests

tests/async/test_model_catalog_async.py

Lines changed: 87 additions & 15 deletions
@@ -3,6 +3,8 @@
 import asyncio
 import logging
 
+from contextlib import suppress
+
 import pytest
 from pytest import LogCaptureFixture as LogCap
 from pytest_subtests import SubTests
@@ -13,12 +15,11 @@
 from ..support import (
     LLM_LOAD_CONFIG,
     EXPECTED_LLM,
-    EXPECTED_LLM_DEFAULT_ID,
     EXPECTED_LLM_ID,
     EXPECTED_EMBEDDING,
-    EXPECTED_EMBEDDING_DEFAULT_ID,
     EXPECTED_EMBEDDING_ID,
     EXPECTED_VLM_ID,
+    SMALL_LLM_ID,
     TOOL_LLM_ID,
     check_sdk_error,
 )
@@ -291,16 +292,17 @@ async def test_get_or_load_when_unloaded_llm_async(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     async with AsyncClient() as client:
         llm = client.llm
-        await llm.unload(EXPECTED_LLM_ID)
-        model = await llm.model(EXPECTED_LLM_DEFAULT_ID, config=LLM_LOAD_CONFIG)
-        assert model.identifier == EXPECTED_LLM_DEFAULT_ID
+        with suppress(LMStudioModelNotFoundError):
+            await llm.unload(EXPECTED_LLM_ID)
+        model = await llm.model(EXPECTED_LLM_ID, config=LLM_LOAD_CONFIG)
+        assert model.identifier == EXPECTED_LLM_ID
         # LM Studio may default to JIT handling for models loaded with `getOrLoad`,
         # so ensure we restore a regular non-JIT instance with no TTL set
-        await llm.unload(EXPECTED_LLM_ID)
+        await model.unload()
         model = await llm.load_new_instance(
-            EXPECTED_LLM_DEFAULT_ID, config=LLM_LOAD_CONFIG, ttl=None
+            EXPECTED_LLM_ID, config=LLM_LOAD_CONFIG, ttl=None
         )
-        assert model.identifier == EXPECTED_LLM_DEFAULT_ID
+        assert model.identifier == EXPECTED_LLM_ID
 
 
 @pytest.mark.asyncio
@@ -310,13 +312,83 @@ async def test_get_or_load_when_unloaded_embedding_async(caplog: LogCap) -> None
     caplog.set_level(logging.DEBUG)
     async with AsyncClient() as client:
         embedding = client.embedding
-        await embedding.unload(EXPECTED_EMBEDDING_ID)
-        model = await embedding.model(EXPECTED_EMBEDDING_DEFAULT_ID)
-        assert model.identifier == EXPECTED_EMBEDDING_DEFAULT_ID
+        with suppress(LMStudioModelNotFoundError):
+            await embedding.unload(EXPECTED_EMBEDDING_ID)
+        model = await embedding.model(EXPECTED_EMBEDDING_ID)
+        assert model.identifier == EXPECTED_EMBEDDING_ID
         # LM Studio may default to JIT handling for models loaded with `getOrLoad`,
         # so ensure we restore a regular non-JIT instance with no TTL set
-        await embedding.unload(EXPECTED_EMBEDDING_ID)
-        model = await embedding.load_new_instance(
-            EXPECTED_EMBEDDING_DEFAULT_ID, ttl=None
+        await model.unload()
+        model = await embedding.load_new_instance(EXPECTED_EMBEDDING_ID, ttl=None)
+        assert model.identifier == EXPECTED_EMBEDDING_ID
+
+
+@pytest.mark.asyncio
+@pytest.mark.slow
+@pytest.mark.lmstudio
+async def test_jit_unloading_async(caplog: LogCap) -> None:
+    # For the time being, only test the embedding vs LLM cross-namespace
+    # JIT unloading (since that ensures the info type mixing is handled).
+    # Assuming LM Studio eventually switches to per-namespace JIT unloading,
+    # this can be split into separate LLM and embedding test cases at that time.
+    caplog.set_level(logging.DEBUG)
+    async with AsyncClient() as client:
+        # Unload the non-JIT instance of the embedding model
+        with suppress(LMStudioModelNotFoundError):
+            await client.embedding.unload(EXPECTED_EMBEDDING_ID)
+        # Load a JIT instance of the embedding model
+        model1 = await client.embedding.model(EXPECTED_EMBEDDING_ID, ttl=300)
+        assert model1.identifier == EXPECTED_EMBEDDING_ID
+        model1_info = await model1.get_info()
+        assert model1_info.identifier == model1.identifier
+        # Load a JIT instance of the small testing LLM
+        # This will unload the JIT instance of the testing embedding model
+        model2 = await client.llm.model(SMALL_LLM_ID, ttl=300)
+        assert model2.identifier == SMALL_LLM_ID
+        model2_info = await model2.get_info()
+        assert model2_info.identifier == model2.identifier
+        # Attempting to query the now unloaded JIT embedding model will fail
+        with pytest.raises(LMStudioModelNotFoundError):
+            await model1.get_info()
+        # Restore things to the way other test cases expect them to be
+        await model2.unload()
+        model = await client.embedding.load_new_instance(
+            EXPECTED_EMBEDDING_ID, ttl=None
         )
-        assert model.identifier == EXPECTED_EMBEDDING_DEFAULT_ID
+        assert model.identifier == EXPECTED_EMBEDDING_ID
+
+    # Check for expected log messages
+    jit_unload_event = "Unloading other JIT model"
+    jit_unload_messages_debug: list[str] = []
+    jit_unload_messages_info: list[str] = []
+    jit_unload_messages = {
+        logging.DEBUG: jit_unload_messages_debug,
+        logging.INFO: jit_unload_messages_info,
+    }
+    for _logger_name, log_level, message in caplog.record_tuples:
+        if jit_unload_event not in message:
+            continue
+        jit_unload_messages[log_level].append(message)
+
+    assert len(jit_unload_messages_info) == 1
+    assert len(jit_unload_messages_debug) == 1
+
+    info_message = jit_unload_messages_info[0]
+    debug_message = jit_unload_messages_debug[0]
+    # Ensure info message omits model info, but includes config guidance
+    unload_notice = f'"event": "{jit_unload_event}"'
+    assert unload_notice in info_message
+    loading_model_notice = f'"model_key": "{SMALL_LLM_ID}"'
+    assert loading_model_notice in info_message
+    unloaded_model_notice = f'"unloaded_model_key": "{EXPECTED_EMBEDDING_ID}"'
+    assert unloaded_model_notice in info_message
+    assert '"suggestion": ' in info_message
+    assert "disable this behavior" in info_message
+    assert '"unloaded_model": ' not in info_message
+    # Ensure debug message includes model info, but omits config guidance
+    assert unload_notice in debug_message
+    assert loading_model_notice in info_message
+    assert unloaded_model_notice in debug_message
+    assert '"suggestion": ' not in debug_message
+    assert "disable this behavior" not in debug_message
+    assert '"unloaded_model": ' in debug_message

tests/async/test_repository_async.py

Lines changed: 4 additions & 4 deletions
@@ -7,7 +7,7 @@
 
 from lmstudio import AsyncClient, LMStudioClientError
 
-from ..support import EXPECTED_DOWNLOAD_SEARCH_TERM
+from ..support import SMALL_LLM_SEARCH_TERM
 
 
 # N.B. We can maybe provide a reference list for what should be available
@@ -21,7 +21,7 @@
 async def test_download_model_async(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     async with AsyncClient() as client:
-        models = await client.repository.search_models(EXPECTED_DOWNLOAD_SEARCH_TERM)
+        models = await client.repository.search_models(SMALL_LLM_SEARCH_TERM)
         logging.info(f"Models: {models}")
         assert models
         assert isinstance(models, list)
@@ -45,7 +45,7 @@ async def test_download_model_async(caplog: LogCap) -> None:
 async def test_get_options_out_of_session_async(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     async with AsyncClient() as client:
-        models = await client.repository.search_models(EXPECTED_DOWNLOAD_SEARCH_TERM)
+        models = await client.repository.search_models(SMALL_LLM_SEARCH_TERM)
         assert models
         assert isinstance(models, list)
         assert len(models) > 0
@@ -60,7 +60,7 @@ async def test_get_options_out_of_session_async(caplog: LogCap) -> None:
 async def test_download_out_of_session_async(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     async with AsyncClient() as client:
-        models = await client.repository.search_models(EXPECTED_DOWNLOAD_SEARCH_TERM)
+        models = await client.repository.search_models(SMALL_LLM_SEARCH_TERM)
         logging.info(f"Models: {models}")
         assert models
         assert isinstance(models, list)

tests/support/__init__.py

Lines changed: 6 additions & 3 deletions
@@ -24,14 +24,12 @@
 THIS_DIR = Path(__file__).parent
 
 LOCAL_API_HOST = "localhost:1234"
-EXPECTED_DOWNLOAD_SEARCH_TERM = "smollm2-135m"
 
 ####################################################
 # Embedding model testing
 ####################################################
 EXPECTED_EMBEDDING = "nomic-ai/nomic-embed-text-v1.5"
 EXPECTED_EMBEDDING_ID = "text-embedding-nomic-embed-text-v1.5"
-EXPECTED_EMBEDDING_DEFAULT_ID = EXPECTED_EMBEDDING_ID  # the same for now
 EXPECTED_EMBEDDING_LENGTH = 768  # nomic has embedding dimension 768
 EXPECTED_EMBEDDING_CONTEXT_LENGTH = 2048  # nomic accepts a 2048 token context
 
@@ -40,7 +38,6 @@
 ####################################################
 EXPECTED_LLM = "hugging-quants/llama-3.2-1b-instruct"
 EXPECTED_LLM_ID = "llama-3.2-1b-instruct"
-EXPECTED_LLM_DEFAULT_ID = EXPECTED_LLM_ID  # the same for now
 PROMPT = "Hello"
 MAX_PREDICTED_TOKENS = 50
 # Use a dict here to ensure dicts are accepted in all config APIs,
@@ -68,6 +65,12 @@
 ####################################################
 TOOL_LLM_ID = "qwen2.5-7b-instruct-1m"
 
+####################################################
+# Other specific models needed for testing
+####################################################
+SMALL_LLM_SEARCH_TERM = "smollm2-135m"
+SMALL_LLM_ID = "smollm2-135m-instruct"
+
 ####################################################
 # Structured LLM responses
 ####################################################
