48
48
ChatMessagePartFileDataDict as _FileHandleDict ,
49
49
ChatMessagePartTextData as TextData ,
50
50
ChatMessagePartTextDataDict as TextDataDict ,
51
- ChatMessagePartToolCallRequestData as _ToolCallRequestData ,
52
- ChatMessagePartToolCallRequestDataDict as _ToolCallRequestDataDict ,
53
- ChatMessagePartToolCallResultData as _ToolCallResultData ,
54
- ChatMessagePartToolCallResultDataDict as _ToolCallResultDataDict ,
51
+ ChatMessagePartToolCallRequestData as ToolCallRequestData ,
52
+ ChatMessagePartToolCallRequestDataDict as ToolCallRequestDataDict ,
53
+ ChatMessagePartToolCallResultData as ToolCallResultData ,
54
+ ChatMessagePartToolCallResultDataDict as ToolCallResultDataDict ,
55
55
# Private until LM Studio file handle support stabilizes
56
56
# FileType,
57
57
FilesRpcUploadFileBase64Parameter ,
58
- # Private until user level tool call request management is defined
59
- ToolCallRequest as _ToolCallRequest ,
58
+ ToolCallRequest as ToolCallRequest ,
59
+ FunctionToolCallRequestDict as ToolCallRequestDict ,
60
60
)
61
61
62
62
__all__ = [
81
81
"TextData" ,
82
82
"TextDataDict" ,
83
83
# Private until user level tool call request management is defined
84
- "_ToolCallRequest" , # Other modules need this to be exported
85
- "_ToolCallResultData" , # Other modules need this to be exported
84
+ "ToolCallRequest" ,
85
+ "ToolCallResultData" ,
86
86
# "ToolCallRequest",
87
87
# "ToolCallResult",
88
88
"UserMessageContent" ,
109
109
SystemPromptContentDict = TextDataDict
110
110
UserMessageContent = TextData | _FileHandle
111
111
UserMessageContentDict = TextDataDict | _FileHandleDict
112
- AssistantResponseContent = TextData | _FileHandle | _ToolCallRequestData
113
- AssistantResponseContentDict = TextDataDict | _FileHandleDict | _ToolCallRequestDataDict
114
- ChatMessageContent = TextData | _FileHandle | _ToolCallRequestData | _ToolCallResultData
112
+ AssistantResponseContent = TextData | _FileHandle
113
+ AssistantResponseContentDict = TextDataDict | _FileHandleDict
114
+ ChatMessageContent = TextData | _FileHandle | ToolCallRequestData | ToolCallResultData
115
115
ChatMessageContentDict = (
116
- TextDataDict | _FileHandleDict | _ToolCallRequestData | _ToolCallResultDataDict
116
+ TextDataDict | _FileHandleDict | ToolCallRequestData | ToolCallResultDataDict
117
117
)
118
118
119
119
@@ -132,7 +132,13 @@ def _to_history_content(self) -> str:
132
132
AnyUserMessageInput = UserMessageInput | UserMessageMultiPartInput
133
133
AssistantResponseInput = str | AssistantResponseContent | AssistantResponseContentDict
134
134
AnyAssistantResponseInput = AssistantResponseInput | _ServerAssistantResponse
135
- _ToolCallResultInput = _ToolCallResultData | _ToolCallResultDataDict
135
+ ToolCallRequestInput = (
136
+ ToolCallRequest
137
+ | ToolCallRequestDict
138
+ | ToolCallRequestData
139
+ | ToolCallRequestDataDict
140
+ )
141
+ ToolCallResultInput = ToolCallResultData | ToolCallResultDataDict
136
142
ChatMessageInput = str | ChatMessageContent | ChatMessageContentDict
137
143
ChatMessageMultiPartInput = UserMessageMultiPartInput
138
144
AnyChatMessageInput = ChatMessageInput | ChatMessageMultiPartInput
@@ -355,6 +361,21 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
355
361
if role == "user" :
356
362
messages = cast (AnyUserMessageInput , content )
357
363
return self .add_user_message (messages )
364
+ # Assistant responses consist of a text response with zero or more tool requests
365
+ if role == "assistant" :
366
+ if _is_chat_message_input (content ):
367
+ response = cast (AssistantResponseInput , content )
368
+ return self .add_assistant_response (response )
369
+ try :
370
+ (response_content , * tool_request_contents ) = content
371
+ except ValueError :
372
+ raise LMStudioValueError (
373
+ f"Unable to parse assistant response content: { content } "
374
+ ) from None
375
+ response = cast (AssistantResponseInput , response_content )
376
+ tool_requests = cast (Iterable [ToolCallRequest ], tool_request_contents )
377
+ return self .add_assistant_response (response , tool_requests )
378
+
358
379
# Other roles do not accept multi-part messages, so ensure there
359
380
# is exactly one content item given. We still accept iterables because
360
381
# that's how the wire format is defined and we want to accept that.
@@ -368,17 +389,13 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
368
389
except ValueError :
369
390
err_msg = f"{ role !r} role does not support multi-part message content."
370
391
raise LMStudioValueError (err_msg ) from None
371
-
372
392
match role :
373
393
case "system" :
374
394
prompt = cast (SystemPromptInput , content_item )
375
395
result = self .add_system_prompt (prompt )
376
- case "assistant" :
377
- response = cast (AssistantResponseInput , content_item )
378
- result = self .add_assistant_response (response )
379
396
case "tool" :
380
- tool_result = cast (_ToolCallResultInput , content_item )
381
- result = self ._add_tool_result (tool_result )
397
+ tool_result = cast (ToolCallResultInput , content_item )
398
+ result = self .add_tool_result (tool_result )
382
399
case _:
383
400
raise LMStudioValueError (f"Unknown history role: { role } " )
384
401
return result
@@ -556,11 +573,14 @@ def add_user_message(
556
573
@classmethod
557
574
def _parse_assistant_response (
558
575
cls , response : AnyAssistantResponseInput
559
- ) -> AssistantResponseContent :
576
+ ) -> TextData | _FileHandle :
577
+ # Note: tool call requests are NOT accepted here, as they're expected
578
+ # to follow an initial text response
579
+ # It's not clear if file handles should be accepted as it's not obvious
580
+ # how client applications should process those (even though the API
581
+ # format nominally permits them here)
560
582
match response :
561
- # Sadly, we can't use the union type aliases for matching,
562
- # since the compiler needs visibility into every match target
563
- case TextData () | _FileHandle () | _ToolCallRequestData ():
583
+ case TextData () | _FileHandle ():
564
584
return response
565
585
case str ():
566
586
return TextData (text = response )
@@ -575,59 +595,67 @@ def _parse_assistant_response(
575
595
}:
576
596
# We accept snake_case here for consistency, but don't really expect it
577
597
return _FileHandle ._from_any_dict (response )
578
- case {"toolCallRequest" : [* _]} | {"tool_call_request" : [* _]}:
579
- # We accept snake_case here for consistency, but don't really expect it
580
- return _ToolCallRequestData ._from_any_dict (response )
581
598
case _:
582
599
raise LMStudioValueError (
583
600
f"Unable to parse assistant response content: { response } "
584
601
)
585
602
603
+ @classmethod
604
+ def _parse_tool_call_request (
605
+ cls , request : ToolCallRequestInput
606
+ ) -> ToolCallRequestData :
607
+ match request :
608
+ case ToolCallRequestData ():
609
+ return request
610
+ case ToolCallRequest ():
611
+ return ToolCallRequestData (tool_call_request = request )
612
+ case {"type" : "toolCallRequest" }:
613
+ return ToolCallRequestData ._from_any_dict (request )
614
+ case {"toolCallRequest" : [* _]} | {"tool_call_request" : [* _]}:
615
+ request_details = ToolCallRequest ._from_any_dict (request )
616
+ return ToolCallRequestData (tool_call_request = request_details )
617
+ case _:
618
+ raise LMStudioValueError (
619
+ f"Unable to parse tool call request content: { request } "
620
+ )
621
+
586
622
@sdk_public_api ()
587
623
def add_assistant_response (
588
- self , response : AnyAssistantResponseInput
624
+ self ,
625
+ response : AnyAssistantResponseInput ,
626
+ tool_call_requests : Iterable [ToolCallRequestInput ] = (),
589
627
) -> AssistantResponse :
590
628
"""Add a new 'assistant' response to the chat history."""
591
- self ._raise_if_consecutive (AssistantResponse .role , "assistant responses" )
592
- message_data = self ._parse_assistant_response (response )
593
- message = AssistantResponse (content = [message_data ])
594
- self ._messages .append (message )
595
- return message
596
-
597
- def _add_assistant_tool_requests (
598
- self , response : _ServerAssistantResponse , requests : Iterable [_ToolCallRequest ]
599
- ) -> AssistantResponse :
600
629
self ._raise_if_consecutive (AssistantResponse .role , "assistant responses" )
601
630
message_text = self ._parse_assistant_response (response )
602
631
request_parts = [
603
- _ToolCallRequestData ( tool_call_request = req ) for req in requests
632
+ self . _parse_tool_call_request ( req ) for req in tool_call_requests
604
633
]
605
634
message = AssistantResponse (content = [message_text , * request_parts ])
606
635
self ._messages .append (message )
607
636
return message
608
637
609
638
@classmethod
610
- def _parse_tool_result (cls , result : _ToolCallResultInput ) -> _ToolCallResultData :
639
+ def _parse_tool_result (cls , result : ToolCallResultInput ) -> ToolCallResultData :
611
640
match result :
612
- # Sadly, we can't use the union type aliases for matching,
613
- # since the compiler needs visibility into every match target
614
- case _ToolCallResultData ():
641
+ case ToolCallResultData ():
615
642
return result
616
643
case {"toolCallId" : _, "content" : _} | {"tool_call_id" : _, "content" : _}:
617
644
# We accept snake_case here for consistency, but don't really expect it
618
- return _ToolCallResultData .from_dict (result )
645
+ return ToolCallResultData .from_dict (result )
619
646
case _:
620
647
raise LMStudioValueError (f"Unable to parse tool result: { result } " )
621
648
622
- def _add_tool_results (
623
- self , results : Iterable [_ToolCallResultInput ]
649
+ def add_tool_results (
650
+ self , results : Iterable [ToolCallResultInput ]
624
651
) -> ToolResultMessage :
652
+ """Add multiple tool results to the chat history as a single message."""
625
653
message_content = [self ._parse_tool_result (result ) for result in results ]
626
654
message = ToolResultMessage (content = message_content )
627
655
self ._messages .append (message )
628
656
return message
629
657
630
- def _add_tool_result (self , result : _ToolCallResultInput ) -> ToolResultMessage :
658
+ def add_tool_result (self , result : ToolCallResultInput ) -> ToolResultMessage :
631
659
"""Add a new tool result to the chat history."""
632
660
# Consecutive tool result messages are allowed,
633
661
# so skip checking if the last message was a tool result
0 commit comments