Commit 38f8afc (parent: 05a58ce)

Add initial tool error handling test case (#71)

The SDK doesn't currently manage unhandled exceptions in tool calls. Add an initial test case that suppresses the exception inside the tool call. A subsequent PR will update the test case to make this exception handling the default behaviour.
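
As the message notes, the SDK does not yet translate unhandled tool exceptions into tool results, so for now a tool is expected to catch its own exceptions and report them to the model as a string. A minimal sketch of that interim pattern, mirroring the divide() tool added in this commit (the parse_number tool itself is hypothetical, not part of the change):

def parse_number(text: str) -> float | str:
    """Parse the given text as a number. Return the result."""
    try:
        return float(text)
    except Exception as exc:
        # Hypothetical example: report the failure to the model as a string
        # until the SDK performs this translation implicitly
        return f"Unhandled Python exception: {exc!r}"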

File tree

3 files changed: +56 -6 lines

src/lmstudio/json_api.py

Lines changed: 5 additions & 3 deletions

@@ -1131,7 +1131,7 @@ def __init__(
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
         # The remaining options are only relevant for multi-round tool actions
         handle_invalid_tool_request: Callable[
-            [LMStudioPredictionError, ToolCallRequest | None], str
+            [LMStudioPredictionError, ToolCallRequest | None], str | None
         ]
         | None = None,
         llm_tools: LlmToolUseSettingToolArray | None = None,

@@ -1336,12 +1336,14 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
     def _handle_invalid_tool_request(
         self, err_msg: str, request: ToolCallRequest | None = None
     ) -> str:
-        exc = LMStudioPredictionError(err_msg)
         _on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
         if _on_handle_invalid_tool_request is not None:
             # Allow users to override the error message, or force an exception
             self._logger.debug("Invoking on_handle_invalid_tool_request callback")
-            err_msg = _on_handle_invalid_tool_request(exc, request)
+            exc = LMStudioPredictionError(err_msg)
+            user_err_msg = _on_handle_invalid_tool_request(exc, request)
+            if user_err_msg is not None:
+                err_msg = user_err_msg
         if request is not None:
             return err_msg
         # We don't allow users to prevent the exception when there's no request
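
The practical effect of the widened return type: a handle_invalid_tool_request callback may now return None to keep the SDK's default error message, rather than being required to return a replacement string, while raising from the callback still forces an exception. A sketch of a callback written against the new signature (the top-level lmstudio imports are assumed, matching the test suite's usage):

import logging

from lmstudio import LMStudioPredictionError, ToolCallRequest

def log_invalid_request(
    exc: LMStudioPredictionError, request: ToolCallRequest | None
) -> str | None:
    # Record the problem, then return None so the default err_msg is kept;
    # returning a string would override it. When request is None the SDK
    # raises the exception regardless of the callback's return value.
    logging.getLogger(__name__).warning("Invalid tool request: %s", exc)
    return None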

src/lmstudio/sync_api.py

Lines changed: 1 addition & 1 deletion

@@ -1499,7 +1499,7 @@ def act(
         on_prediction_completed: Callable[[PredictionRoundResult], Any] | None = None,
         on_prompt_processing_progress: Callable[[float, int], Any] | None = None,
         handle_invalid_tool_request: Callable[
-            [LMStudioPredictionError, ToolCallRequest | None], str
+            [LMStudioPredictionError, ToolCallRequest | None], str | None
         ]
         | None = None,
     ) -> ActResult:
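
The public act() API picks up the same widened callback type. A sketch of wiring an override-style callback into act() (the model identifier, prompt, and add tool are illustrative placeholders, not from this commit):

from lmstudio import Chat, Client, LMStudioPredictionError, ToolCallRequest

def add(a: int, b: int) -> int:
    """Add the given numbers together."""
    return a + b

def describe_invalid_call(
    exc: LMStudioPredictionError, request: ToolCallRequest | None
) -> str | None:
    if request is not None:
        # Override the default message reported back to the model
        return f"Invalid tool request: {exc}. Check the argument types."
    return None  # no associated request: the SDK still raises

with Client() as client:
    llm = client.llm.model("tool-capable-model")  # placeholder identifier
    chat = Chat()
    chat.add_user_message("Add 3 and 5 using the tool.")
    result = llm.act(
        chat,
        [add],
        handle_invalid_tool_request=describe_invalid_call,
    )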

tests/test_inference.py

Lines changed: 50 additions & 2 deletions

@@ -15,9 +15,11 @@
     Client,
     LlmPredictionConfig,
     LlmPredictionFragment,
+    LMStudioPredictionError,
     LMStudioValueError,
     PredictionResult,
     PredictionRoundResult,
+    ToolCallRequest,
     ToolFunctionDef,
     ToolFunctionDefDict,
 )

@@ -162,7 +164,7 @@ def test_duplicate_tool_names_rejected() -> None:

 @pytest.mark.lmstudio
 def test_tool_using_agent(caplog: LogCap) -> None:
-    # This is currently a sync-only API (it will be refactored after 1.0.0)
+    # This is currently a sync-only API (it will be refactored in a future release)

     caplog.set_level(logging.DEBUG)
     model_id = TOOL_LLM_ID

@@ -192,7 +194,7 @@ def test_tool_using_agent(caplog: LogCap) -> None:

 @pytest.mark.lmstudio
 def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
-    # This is currently a sync-only API (it will be refactored after 1.0.0)
+    # This is currently a sync-only API (it will be refactored in a future release)

     caplog.set_level(logging.DEBUG)
     model_id = TOOL_LLM_ID

@@ -241,3 +243,49 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:

     cloned_chat = chat.copy()
     assert cloned_chat._messages == chat._messages
+
+
+def divide(numerator: float, denominator: float) -> float | str:
+    """Divide the given numerator by the given denominator. Return the result."""
+    try:
+        return numerator / denominator
+    except Exception as exc:
+        # TODO: Perform this exception-to-response-string translation implicitly
+        return f"Unhandled Python exception: {exc!r}"
+
+
+@pytest.mark.lmstudio
+def test_tool_using_agent_error_handling(caplog: LogCap) -> None:
+    # This is currently a sync-only API (it will be refactored in a future release)
+
+    caplog.set_level(logging.DEBUG)
+    model_id = TOOL_LLM_ID
+    with Client() as client:
+        llm = client.llm.model(model_id)
+        chat = Chat()
+        chat.add_user_message(
+            "Attempt to divide 1 by 0 using the tool. Explain the result."
+        )
+        tools = [divide]
+        predictions: list[PredictionRoundResult] = []
+        invalid_requests: list[tuple[LMStudioPredictionError, ToolCallRequest]] = []
+
+        def _handle_invalid_request(
+            exc: LMStudioPredictionError, request: ToolCallRequest | None
+        ) -> None:
+            if request is not None:
+                invalid_requests.append((exc, request))
+
+        act_result = llm.act(
+            chat,
+            tools,
+            handle_invalid_tool_request=_handle_invalid_request,
+            on_prediction_completed=predictions.append,
+        )
+        assert len(predictions) > 1
+        assert act_result.rounds == len(predictions)
+        # Test case is currently suppressing exceptions inside the tool call
+        assert invalid_requests == []
+        # If the content checks prove flaky in practice, they can be dropped
+        assert "divide" in predictions[-1].content
+        assert "zero" in predictions[-1].content
