Commit d2ce49e

Add server side token counting to API (#57)
Closes #24
1 parent 10a8c12 commit d2ce49e

File tree

6 files changed: +147 -17 lines changed

    src/lmstudio/async_api.py
    src/lmstudio/sync_api.py
    tests/async/test_embedding_async.py
    tests/async/test_llm_async.py
    tests/sync/test_embedding_sync.py
    tests/sync/test_llm_sync.py

src/lmstudio/async_api.py

Lines changed: 18 additions & 1 deletion

@@ -96,6 +96,7 @@
 )
 from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
 from ._sdk_models import (
+    EmbeddingRpcCountTokensParameter,
     EmbeddingRpcEmbedStringParameter,
     EmbeddingRpcTokenizeParameter,
     LlmApplyPromptTemplateOpts,
@@ -733,6 +734,18 @@ async def _get_context_length(self, model_specifier: AnyModelSpecifier) -> int:
         raw_model_info = await self._get_api_model_info(model_specifier)
         return int(raw_model_info.get("contextLength", -1))
 
+    async def _count_tokens(
+        self, model_specifier: AnyModelSpecifier, input: str
+    ) -> int:
+        params = EmbeddingRpcCountTokensParameter._from_api_dict(
+            {
+                "specifier": _model_spec_to_api_dict(model_specifier),
+                "inputString": input,
+            }
+        )
+        response = await self.remote_call("countTokens", params)
+        return int(response["tokenCount"])
+
     # Private helper method to allow the main API to easily accept iterables
     async def _tokenize_text(
         self, model_specifier: AnyModelSpecifier, input: str
@@ -748,7 +761,6 @@ async def _tokenize_text(
 
     # Alas, type hints don't properly support distinguishing str vs Iterable[str]:
     # https://github.com/python/typing/issues/256
-    @sdk_public_api_async()
     async def _tokenize(
         self, model_specifier: AnyModelSpecifier, input: str | Iterable[str]
     ) -> Sequence[int] | Sequence[Sequence[int]]:
@@ -1191,6 +1203,11 @@ async def tokenize(
         """Tokenize the input string(s) using this model."""
         return await self._session._tokenize(self.identifier, input)
 
+    @sdk_public_api_async()
+    async def count_tokens(self, input: str) -> int:
+        """Report the number of tokens needed for the input string using this model."""
+        return await self._session._count_tokens(self.identifier, input)
+
     @sdk_public_api_async()
     async def get_context_length(self) -> int:
         """Get the context length of this model."""

src/lmstudio/sync_api.py

Lines changed: 16 additions & 0 deletions

@@ -124,6 +124,7 @@
 )
 from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
 from ._sdk_models import (
+    EmbeddingRpcCountTokensParameter,
     EmbeddingRpcEmbedStringParameter,
     EmbeddingRpcTokenizeParameter,
     LlmApplyPromptTemplateOpts,
@@ -902,6 +903,16 @@ def _get_context_length(self, model_specifier: AnyModelSpecifier) -> int:
         raw_model_info = self._get_api_model_info(model_specifier)
         return int(raw_model_info.get("contextLength", -1))
 
+    def _count_tokens(self, model_specifier: AnyModelSpecifier, input: str) -> int:
+        params = EmbeddingRpcCountTokensParameter._from_api_dict(
+            {
+                "specifier": _model_spec_to_api_dict(model_specifier),
+                "inputString": input,
+            }
+        )
+        response = self.remote_call("countTokens", params)
+        return int(response["tokenCount"])
+
     # Private helper method to allow the main API to easily accept iterables
     def _tokenize_text(
         self, model_specifier: AnyModelSpecifier, input: str
@@ -1353,6 +1364,11 @@ def tokenize(
         """Tokenize the input string(s) using this model."""
         return self._session._tokenize(self.identifier, input)
 
+    @sdk_public_api()
+    def count_tokens(self, input: str) -> int:
+        """Report the number of tokens needed for the input string using this model."""
+        return self._session._count_tokens(self.identifier, input)
+
     @sdk_public_api()
     def get_context_length(self) -> int:
         """Get the context length of this model."""

tests/async/test_embedding_async.py

Lines changed: 12 additions & 4 deletions

@@ -63,11 +63,15 @@ async def test_tokenize_async(model_id: str, caplog: LogCap) -> None:
 
     caplog.set_level(logging.DEBUG)
     async with AsyncClient() as client:
-        session = client.embedding
-        response = await session._tokenize(model_id, input=text)
+        model = await client.embedding.model(model_id)
+        num_tokens = await model.count_tokens(text)
+        response = await model.tokenize(text)
     logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
+    # Ensure token count and tokenization are consistent
+    # (embedding models add extra start/end markers during actual tokenization)
+    assert len(response) == num_tokens + 2
     # the response should be deterministic if we set constant seed
     # so we can also check the value if desired
@@ -80,8 +84,8 @@ async def test_tokenize_list_async(model_id: str, caplog: LogCap) -> None:
 
     caplog.set_level(logging.DEBUG)
     async with AsyncClient() as client:
-        session = client.embedding
-        response = await session._tokenize(model_id, input=text)
+        model = await client.embedding.model(model_id)
+        response = await model.tokenize(text)
     logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
@@ -142,6 +146,10 @@ async def test_invalid_model_request_async(caplog: LogCap) -> None:
             with pytest.raises(LMStudioModelNotFoundError) as exc_info:
                 await model.embed("Some text")
             check_sdk_error(exc_info, __file__)
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.count_tokens("Some text")
+            check_sdk_error(exc_info, __file__)
         with anyio.fail_after(30):
             with pytest.raises(LMStudioModelNotFoundError) as exc_info:
                 await model.tokenize("Some text")
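The "+ 2" assertion encodes the relationship this commit establishes between the two calls: for embedding models, tokenize() includes the extra start/end markers, while the server-side count_tokens() reports the count without them. A standalone sketch of the same check outside pytest (the embedding model key is illustrative and may differ from your local setup):

import asyncio

from lmstudio import AsyncClient

# Placeholder embedding model key, for illustration only.
EMBEDDING_MODEL = "nomic-embed-text-v1.5"


async def main() -> None:
    async with AsyncClient() as client:
        model = await client.embedding.model(EMBEDDING_MODEL)
        text = "Some text"
        num_tokens = await model.count_tokens(text)
        tokens = await model.tokenize(text)
        # Expect the tokenized form to be two tokens longer than the
        # reported count (the extra start/end markers noted in the test).
        print(num_tokens, len(tokens))


asyncio.run(main())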

tests/async/test_llm_async.py

Lines changed: 45 additions & 4 deletions

@@ -2,12 +2,18 @@
 
 import logging
 
+import anyio
 import pytest
 from pytest import LogCaptureFixture as LogCap
 
-from lmstudio import AsyncClient, LlmLoadModelConfig, history
+from lmstudio import (
+    AsyncClient,
+    LlmLoadModelConfig,
+    LMStudioModelNotFoundError,
+    history,
+)
 
-from ..support import EXPECTED_LLM, EXPECTED_LLM_ID
+from ..support import EXPECTED_LLM, EXPECTED_LLM_ID, check_sdk_error
 
 
 @pytest.mark.asyncio
@@ -52,10 +58,14 @@ async def test_tokenize_async(model_id: str, caplog: LogCap) -> None:
 
     caplog.set_level(logging.DEBUG)
     async with AsyncClient() as client:
-        response = await client.llm._tokenize(model_id, input=text)
+        model = await client.llm.model(model_id)
+        num_tokens = await model.count_tokens(text)
+        response = await model.tokenize(text)
     logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
+    # Ensure token count and tokenization are consistent
+    assert len(response) == num_tokens
 
 
 @pytest.mark.asyncio
@@ -66,7 +76,8 @@ async def test_tokenize_list_async(model_id: str, caplog: LogCap) -> None:
 
     caplog.set_level(logging.DEBUG)
     async with AsyncClient() as client:
-        response = await client.llm._tokenize(model_id, input=text)
+        model = await client.llm.model(model_id)
+        response = await model.tokenize(text)
     logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
@@ -109,3 +120,33 @@ async def test_get_model_info_async(model_id: str, caplog: LogCap) -> None:
         response = await client.llm.get_model_info(model_id)
     logging.info(f"Model config response: {response}")
     assert response
+
+
+@pytest.mark.asyncio
+@pytest.mark.lmstudio
+async def test_invalid_model_request_async(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+    async with AsyncClient() as client:
+        # Deliberately create an invalid model handle
+        model = client.llm._create_handle("No such model")
+        # This should error rather than timing out,
+        # but avoid any risk of the client hanging...
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.complete("Some text")
+            check_sdk_error(exc_info, __file__)
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.respond("Some text")
+            check_sdk_error(exc_info, __file__)
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.count_tokens("Some text")
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.tokenize("Some text")
+            check_sdk_error(exc_info, __file__)
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.get_context_length()
+            check_sdk_error(exc_info, __file__)
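As this new test exercises, count_tokens() (like the other model methods) raises LMStudioModelNotFoundError rather than hanging when the server does not know the referenced model. A minimal sketch of guarding against that in application code (the model identifier and the unavailable-model scenario are illustrative, not part of this commit):

import asyncio

from lmstudio import AsyncClient, LMStudioModelNotFoundError


async def main() -> None:
    async with AsyncClient() as client:
        # Placeholder identifier; the model may become unavailable after the handle is obtained.
        model = await client.llm.model("my-model")
        try:
            num_tokens = await model.count_tokens("Some text")
            print(f"Prompt requires {num_tokens} tokens")
        except LMStudioModelNotFoundError:
            # The referenced model is no longer available on the server.
            print("Model is not available on the server")


asyncio.run(main())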

tests/sync/test_embedding_sync.py

Lines changed: 12 additions & 4 deletions

@@ -67,11 +67,15 @@ def test_tokenize_sync(model_id: str, caplog: LogCap) -> None:
 
     caplog.set_level(logging.DEBUG)
     with Client() as client:
-        session = client.embedding
-        response = session._tokenize(model_id, input=text)
+        model = client.embedding.model(model_id)
+        num_tokens = model.count_tokens(text)
+        response = model.tokenize(text)
     logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
+    # Ensure token count and tokenization are consistent
+    # (embedding models add extra start/end markers during actual tokenization)
+    assert len(response) == num_tokens + 2
     # the response should be deterministic if we set constant seed
     # so we can also check the value if desired
@@ -83,8 +87,8 @@ def test_tokenize_list_sync(model_id: str, caplog: LogCap) -> None:
 
     caplog.set_level(logging.DEBUG)
     with Client() as client:
-        session = client.embedding
-        response = session._tokenize(model_id, input=text)
+        model = client.embedding.model(model_id)
+        response = model.tokenize(text)
     logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
@@ -141,6 +145,10 @@ def test_invalid_model_request_sync(caplog: LogCap) -> None:
             with pytest.raises(LMStudioModelNotFoundError) as exc_info:
                 model.embed("Some text")
            check_sdk_error(exc_info, __file__)
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.count_tokens("Some text")
+            check_sdk_error(exc_info, __file__)
         with nullcontext():
            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
                model.tokenize("Some text")

tests/sync/test_llm_sync.py

Lines changed: 44 additions & 4 deletions

@@ -8,13 +8,19 @@
 """Test non-inference methods on LLMs."""
 
 import logging
+from contextlib import nullcontext
 
 import pytest
 from pytest import LogCaptureFixture as LogCap
 
-from lmstudio import Client, LlmLoadModelConfig, history
+from lmstudio import (
+    Client,
+    LlmLoadModelConfig,
+    LMStudioModelNotFoundError,
+    history,
+)
 
-from ..support import EXPECTED_LLM, EXPECTED_LLM_ID
+from ..support import EXPECTED_LLM, EXPECTED_LLM_ID, check_sdk_error
 
 
 @pytest.mark.lmstudio
@@ -55,10 +61,14 @@ def test_tokenize_sync(model_id: str, caplog: LogCap) -> None:
 
     caplog.set_level(logging.DEBUG)
     with Client() as client:
-        response = client.llm._tokenize(model_id, input=text)
+        model = client.llm.model(model_id)
+        num_tokens = model.count_tokens(text)
+        response = model.tokenize(text)
     logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
+    # Ensure token count and tokenization are consistent
+    assert len(response) == num_tokens
 
 
 @pytest.mark.lmstudio
@@ -68,7 +78,8 @@ def test_tokenize_list_sync(model_id: str, caplog: LogCap) -> None:
 
     caplog.set_level(logging.DEBUG)
     with Client() as client:
-        response = client.llm._tokenize(model_id, input=text)
+        model = client.llm.model(model_id)
+        response = model.tokenize(text)
     logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
@@ -108,3 +119,32 @@ def test_get_model_info_sync(model_id: str, caplog: LogCap) -> None:
         response = client.llm.get_model_info(model_id)
     logging.info(f"Model config response: {response}")
     assert response
+
+
+@pytest.mark.lmstudio
+def test_invalid_model_request_sync(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+    with Client() as client:
+        # Deliberately create an invalid model handle
+        model = client.llm._create_handle("No such model")
+        # This should error rather than timing out,
+        # but avoid any risk of the client hanging...
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.complete("Some text")
+            check_sdk_error(exc_info, __file__)
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.respond("Some text")
+            check_sdk_error(exc_info, __file__)
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.count_tokens("Some text")
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.tokenize("Some text")
+            check_sdk_error(exc_info, __file__)
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.get_context_length()
+            check_sdk_error(exc_info, __file__)
