Commit 54ed201

Accept tool use messages via public Chat APIs (#28)
Also simplify the prompts in the tool use test cases to avoid depending on the language model's mathematical reasoning capabilities. Closes #20
1 parent 200fa67 commit 54ed201
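
For orientation before the diffs: this change promotes the previously private tool call helpers to the public Chat API. A minimal sketch of the resulting surface, using only method names visible in the history.py diff below; the ToolCallRequest constructor fields are an assumption (only the `name` and `id` attribute accesses appear in the json_api.py hunks), and all literal values are invented for illustration:

    import lmstudio as lms
    from lmstudio.history import ToolCallRequest  # made public by this commit

    chat = lms.Chat()
    chat.add_user_message("What is 12345 multiplied by 54321?")

    # Assistant responses can now carry tool call requests alongside the text
    # (see the new add_assistant_response signature in the history.py diff).
    # ToolCallRequest field names here are assumed, not confirmed by the diff.
    request = ToolCallRequest(
        id="call-1",
        name="multiply",
        arguments={"a": 12345, "b": 54321},
    )
    chat.add_assistant_response("Calling the multiply tool.", [request])

    # Tool results accept the dict form matched in _parse_tool_result below.
    chat.add_tool_result({"toolCallId": "call-1", "content": "670592745"})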

File tree: 6 files changed (+151, -87 lines)


examples/tool-use-multiple.py

Lines changed: 4 additions & 2 deletions
@@ -18,9 +18,11 @@ def is_prime(n: int) -> bool:
         return False
     return True
 
-model = lms.llm("qwen2.5-7b-instruct")
+chat = lms.Chat()
+model = lms.llm("qwen2.5-7b-instruct-1m")
 model.act(
     "Is the result of 12345 + 45668 a prime? Think step by step.",
     [add, is_prime],
-    on_message=print,
+    on_message=chat.append,
 )
+print(chat)

examples/tool-use.py

Lines changed: 4 additions & 2 deletions
@@ -7,9 +7,11 @@ def multiply(a: float, b: float) -> float:
     """Given two numbers a and b. Returns the product of them."""
     return a * b
 
-model = lms.llm("qwen2.5-7b-instruct")
+chat = lms.Chat()
+model = lms.llm("qwen2.5-7b-instruct-1m")
 model.act(
     "What is the result of 12345 multiplied by 54321?",
     [multiply],
-    on_message=print,
+    on_message=chat.append,
 )
+print(chat)
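
In both example scripts, `chat.append` is supplied as the `on_message` callback, so each message emitted during the multi-round `act()` call (the assistant's responses and the tool results alike) is accumulated into the Chat history rather than printed as it arrives; the final `print(chat)` then renders the whole transcript in one place.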

src/lmstudio/history.py

Lines changed: 73 additions & 45 deletions
@@ -48,15 +48,15 @@
     ChatMessagePartFileDataDict as _FileHandleDict,
     ChatMessagePartTextData as TextData,
     ChatMessagePartTextDataDict as TextDataDict,
-    ChatMessagePartToolCallRequestData as _ToolCallRequestData,
-    ChatMessagePartToolCallRequestDataDict as _ToolCallRequestDataDict,
-    ChatMessagePartToolCallResultData as _ToolCallResultData,
-    ChatMessagePartToolCallResultDataDict as _ToolCallResultDataDict,
+    ChatMessagePartToolCallRequestData as ToolCallRequestData,
+    ChatMessagePartToolCallRequestDataDict as ToolCallRequestDataDict,
+    ChatMessagePartToolCallResultData as ToolCallResultData,
+    ChatMessagePartToolCallResultDataDict as ToolCallResultDataDict,
     # Private until LM Studio file handle support stabilizes
     # FileType,
     FilesRpcUploadFileBase64Parameter,
-    # Private until user level tool call request management is defined
-    ToolCallRequest as _ToolCallRequest,
+    ToolCallRequest as ToolCallRequest,
+    FunctionToolCallRequestDict as ToolCallRequestDict,
 )
 
 __all__ = [
@@ -81,8 +81,8 @@
     "TextData",
     "TextDataDict",
     # Private until user level tool call request management is defined
-    "_ToolCallRequest", # Other modules need this to be exported
-    "_ToolCallResultData", # Other modules need this to be exported
+    "ToolCallRequest",
+    "ToolCallResultData",
     # "ToolCallRequest",
     # "ToolCallResult",
     "UserMessageContent",
@@ -109,11 +109,11 @@
 SystemPromptContentDict = TextDataDict
 UserMessageContent = TextData | _FileHandle
 UserMessageContentDict = TextDataDict | _FileHandleDict
-AssistantResponseContent = TextData | _FileHandle | _ToolCallRequestData
-AssistantResponseContentDict = TextDataDict | _FileHandleDict | _ToolCallRequestDataDict
-ChatMessageContent = TextData | _FileHandle | _ToolCallRequestData | _ToolCallResultData
+AssistantResponseContent = TextData | _FileHandle
+AssistantResponseContentDict = TextDataDict | _FileHandleDict
+ChatMessageContent = TextData | _FileHandle | ToolCallRequestData | ToolCallResultData
 ChatMessageContentDict = (
-    TextDataDict | _FileHandleDict | _ToolCallRequestData | _ToolCallResultDataDict
+    TextDataDict | _FileHandleDict | ToolCallRequestData | ToolCallResultDataDict
 )
 
 
@@ -132,7 +132,13 @@ def _to_history_content(self) -> str:
 AnyUserMessageInput = UserMessageInput | UserMessageMultiPartInput
 AssistantResponseInput = str | AssistantResponseContent | AssistantResponseContentDict
 AnyAssistantResponseInput = AssistantResponseInput | _ServerAssistantResponse
-_ToolCallResultInput = _ToolCallResultData | _ToolCallResultDataDict
+ToolCallRequestInput = (
+    ToolCallRequest
+    | ToolCallRequestDict
+    | ToolCallRequestData
+    | ToolCallRequestDataDict
+)
+ToolCallResultInput = ToolCallResultData | ToolCallResultDataDict
 ChatMessageInput = str | ChatMessageContent | ChatMessageContentDict
 ChatMessageMultiPartInput = UserMessageMultiPartInput
 AnyChatMessageInput = ChatMessageInput | ChatMessageMultiPartInput
@@ -355,6 +361,21 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
         if role == "user":
             messages = cast(AnyUserMessageInput, content)
             return self.add_user_message(messages)
+        # Assistant responses consist of a text response with zero or more tool requests
+        if role == "assistant":
+            if _is_chat_message_input(content):
+                response = cast(AssistantResponseInput, content)
+                return self.add_assistant_response(response)
+            try:
+                (response_content, *tool_request_contents) = content
+            except ValueError:
+                raise LMStudioValueError(
+                    f"Unable to parse assistant response content: {content}"
+                ) from None
+            response = cast(AssistantResponseInput, response_content)
+            tool_requests = cast(Iterable[ToolCallRequest], tool_request_contents)
+            return self.add_assistant_response(response, tool_requests)
+
         # Other roles do not accept multi-part messages, so ensure there
         # is exactly one content item given. We still accept iterables because
         # that's how the wire format is defined and we want to accept that.
@@ -368,17 +389,13 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
         except ValueError:
             err_msg = f"{role!r} role does not support multi-part message content."
             raise LMStudioValueError(err_msg) from None
-
         match role:
             case "system":
                 prompt = cast(SystemPromptInput, content_item)
                 result = self.add_system_prompt(prompt)
-            case "assistant":
-                response = cast(AssistantResponseInput, content_item)
-                result = self.add_assistant_response(response)
             case "tool":
-                tool_result = cast(_ToolCallResultInput, content_item)
-                result = self._add_tool_result(tool_result)
+                tool_result = cast(ToolCallResultInput, content_item)
+                result = self.add_tool_result(tool_result)
             case _:
                 raise LMStudioValueError(f"Unknown history role: {role}")
         return result
@@ -556,11 +573,14 @@ def add_user_message(
     @classmethod
     def _parse_assistant_response(
         cls, response: AnyAssistantResponseInput
-    ) -> AssistantResponseContent:
+    ) -> TextData | _FileHandle:
+        # Note: tool call requests are NOT accepted here, as they're expected
+        # to follow an initial text response
+        # It's not clear if file handles should be accepted as it's not obvious
+        # how client applications should process those (even though the API
+        # format nominally permits them here)
         match response:
-            # Sadly, we can't use the union type aliases for matching,
-            # since the compiler needs visibility into every match target
-            case TextData() | _FileHandle() | _ToolCallRequestData():
+            case TextData() | _FileHandle():
                 return response
             case str():
                 return TextData(text=response)
@@ -575,59 +595,67 @@ def _parse_assistant_response(
             }:
                 # We accept snake_case here for consistency, but don't really expect it
                 return _FileHandle._from_any_dict(response)
-            case {"toolCallRequest": [*_]} | {"tool_call_request": [*_]}:
-                # We accept snake_case here for consistency, but don't really expect it
-                return _ToolCallRequestData._from_any_dict(response)
             case _:
                 raise LMStudioValueError(
                     f"Unable to parse assistant response content: {response}"
                 )
 
+    @classmethod
+    def _parse_tool_call_request(
+        cls, request: ToolCallRequestInput
+    ) -> ToolCallRequestData:
+        match request:
+            case ToolCallRequestData():
+                return request
+            case ToolCallRequest():
+                return ToolCallRequestData(tool_call_request=request)
+            case {"type": "toolCallRequest"}:
+                return ToolCallRequestData._from_any_dict(request)
+            case {"toolCallRequest": [*_]} | {"tool_call_request": [*_]}:
+                request_details = ToolCallRequest._from_any_dict(request)
+                return ToolCallRequestData(tool_call_request=request_details)
+            case _:
+                raise LMStudioValueError(
+                    f"Unable to parse tool call request content: {request}"
+                )
+
     @sdk_public_api()
     def add_assistant_response(
-        self, response: AnyAssistantResponseInput
+        self,
+        response: AnyAssistantResponseInput,
+        tool_call_requests: Iterable[ToolCallRequestInput] = (),
     ) -> AssistantResponse:
         """Add a new 'assistant' response to the chat history."""
-        self._raise_if_consecutive(AssistantResponse.role, "assistant responses")
-        message_data = self._parse_assistant_response(response)
-        message = AssistantResponse(content=[message_data])
-        self._messages.append(message)
-        return message
-
-    def _add_assistant_tool_requests(
-        self, response: _ServerAssistantResponse, requests: Iterable[_ToolCallRequest]
-    ) -> AssistantResponse:
         self._raise_if_consecutive(AssistantResponse.role, "assistant responses")
         message_text = self._parse_assistant_response(response)
         request_parts = [
-            _ToolCallRequestData(tool_call_request=req) for req in requests
+            self._parse_tool_call_request(req) for req in tool_call_requests
         ]
         message = AssistantResponse(content=[message_text, *request_parts])
         self._messages.append(message)
         return message
 
     @classmethod
-    def _parse_tool_result(cls, result: _ToolCallResultInput) -> _ToolCallResultData:
+    def _parse_tool_result(cls, result: ToolCallResultInput) -> ToolCallResultData:
         match result:
-            # Sadly, we can't use the union type aliases for matching,
-            # since the compiler needs visibility into every match target
-            case _ToolCallResultData():
+            case ToolCallResultData():
                 return result
             case {"toolCallId": _, "content": _} | {"tool_call_id": _, "content": _}:
                 # We accept snake_case here for consistency, but don't really expect it
-                return _ToolCallResultData.from_dict(result)
+                return ToolCallResultData.from_dict(result)
             case _:
                 raise LMStudioValueError(f"Unable to parse tool result: {result}")
 
-    def _add_tool_results(
-        self, results: Iterable[_ToolCallResultInput]
+    def add_tool_results(
+        self, results: Iterable[ToolCallResultInput]
     ) -> ToolResultMessage:
+        """Add multiple tool results to the chat history as a single message."""
         message_content = [self._parse_tool_result(result) for result in results]
         message = ToolResultMessage(content=message_content)
         self._messages.append(message)
         return message
 
-    def _add_tool_result(self, result: _ToolCallResultInput) -> ToolResultMessage:
+    def add_tool_result(self, result: ToolCallResultInput) -> ToolResultMessage:
         """Add a new tool result to the chat history."""
         # Consecutive tool result messages are allowed,
         # so skip checking if the last message was a tool result
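
Taken together, `add_entry` now gives the "assistant" role two accepted shapes: a plain response, or an iterable whose first item is the response and whose remaining items are tool call requests; "tool" entries route to the now-public `add_tool_result`. A rough sketch of both paths (same caveats as the earlier sketch: ToolCallRequest fields assumed, values invented):

    import lmstudio as lms
    from lmstudio.history import ToolCallRequest

    chat = lms.Chat()
    chat.add_entry("user", "Is 97 a prime number?")

    # Multi-part form: add_entry unpacks (response, *tool_requests).
    # ToolCallRequest field names are assumed, not confirmed by the diff.
    request = ToolCallRequest(id="call-1", name="is_prime", arguments={"n": 97})
    chat.add_entry("assistant", ("Checking with the is_prime tool.", request))

    # "tool" role entries now dispatch to the public add_tool_result().
    chat.add_entry("tool", {"toolCallId": "call-1", "content": "true"})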

src/lmstudio/json_api.py

Lines changed: 24 additions & 13 deletions
@@ -40,7 +40,7 @@
     sdk_public_type,
     _truncate_traceback,
 )
-from .history import AssistantResponse, Chat, _ToolCallRequest, _ToolCallResultData
+from .history import AssistantResponse, Chat, ToolCallRequest, ToolCallResultData
 from .schemas import (
     AnyLMStudioStruct,
     DictObject,
@@ -1067,7 +1067,7 @@ class PredictionFragmentEvent(ChannelRxEvent[LlmPredictionFragment]):
     pass
 
 
-class PredictionToolCallEvent(ChannelRxEvent[_ToolCallRequest]):
+class PredictionToolCallEvent(ChannelRxEvent[ToolCallRequest]):
     pass
 
 
@@ -1114,7 +1114,7 @@ def __init__(
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
         # The remaining options are only relevant for multi-round tool actions
         handle_invalid_tool_request: Callable[
-            [LMStudioPredictionError, _ToolCallRequest | None], str
+            [LMStudioPredictionError, ToolCallRequest | None], str
         ]
         | None = None,
         llm_tools: LlmToolUseSettingToolArray | None = None,
@@ -1224,7 +1224,7 @@ def iter_message_events(
                 "toolCallRequest": tool_call_request,
             }:
                 yield PredictionToolCallEvent(
-                    _ToolCallRequest._from_api_dict(tool_call_request)
+                    ToolCallRequest._from_api_dict(tool_call_request)
                 )
             case {
                 "type": "toolCallGenerationFailed",
@@ -1267,10 +1267,17 @@ def handle_rx_event(self, event: PredictionRxEvent) -> None:
                 self._report_prompt_processing_progress(progress)
             case PredictionFragmentEvent(_fragment):
                 if self._on_first_token is not None:
-                    self._on_first_token()
+                    self._logger.debug("Invoking on_first_token callback")
+                    err_msg = f"First token callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        self._on_first_token()
                     self._on_first_token = None
                 if self._on_prediction_fragment is not None:
-                    self._on_prediction_fragment(_fragment)
+                    # TODO: Define an even-spammier-than-debug trace logging level for this
+                    # self._logger.trace("Invoking on_prediction_fragment callback")
+                    err_msg = f"Prediction fragment callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        self._on_prediction_fragment(_fragment)
                 pass
             case PredictionToolCallEvent(_tool_call_request):
                 # Handled externally when iterating over events
@@ -1294,32 +1301,34 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
         assert self._on_prompt_processing_progress is not None
         err_msg = f"Prediction progress callback failed for {self!r}"
         with sdk_callback_invocation(err_msg, self._logger):
+            self._logger.debug("Invoking on_prompt_processing_progress callback")
             self._on_prompt_processing_progress(progress)
 
     def _handle_invalid_tool_request(
-        self, err_msg: str, request: _ToolCallRequest | None = None
+        self, err_msg: str, request: ToolCallRequest | None = None
     ) -> str:
         exc = LMStudioPredictionError(err_msg)
         _on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
         if _on_handle_invalid_tool_request is not None:
             # Allow users to override the error message, or force an exception
+            self._logger.debug("Invoking on_handle_invalid_tool_request callback")
             err_msg = _on_handle_invalid_tool_request(exc, request)
             if request is not None:
                 return err_msg
         # We don't allow users to prevent the exception when there's no request
         raise LMStudioPredictionError(err_msg)
 
     def request_tool_call(
-        self, request: _ToolCallRequest
-    ) -> Callable[[], _ToolCallResultData]:
+        self, request: ToolCallRequest
+    ) -> Callable[[], ToolCallResultData]:
         tool_name = request.name
         tool_call_id = request.id
         client_tool = self._client_tools.get(tool_name, None)
         if client_tool is None:
             err_msg = self._handle_invalid_tool_request(
                 f"Cannot find tool with name {tool_name}.", request
             )
-            result = _ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
+            result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
             return lambda: result
         # Validate parameters against their specification
         params_struct, implementation = client_tool
@@ -1330,14 +1339,14 @@ def request_tool_call(
             err_msg = self._handle_invalid_tool_request(
                 f"Failed to parse arguments for tool {tool_name}: {exc}", request
             )
-            result = _ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
+            result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
             return lambda: result
         kwds = to_builtins(parsed_kwds)
 
         # Allow caller to schedule the tool call request for background execution
-        def _call_requested_tool() -> _ToolCallResultData:
+        def _call_requested_tool() -> ToolCallResultData:
             call_result = implementation(**kwds)
-            return _ToolCallResultData(
+            return ToolCallResultData(
                 content=json.dumps(call_result), tool_call_id=tool_call_id
             )
 
@@ -1980,6 +1989,8 @@ def __init__(self, model_identifier: str, session: TSession) -> None:
         """Initialize the LM Studio model reference."""
         self.identifier = model_identifier
         self._session = session
+        self._logger = logger = get_logger(type(self).__name__)
+        logger.update_context(model_identifier=model_identifier)
 
     def __repr__(self) -> str:
         return f"{type(self).__name__}(identifier={self.identifier!r})"
