Commit cba0574

Pass tool call failures back to the LLM
Default to passing tool call failures back to the LLM. SDK users can override this via the invalid tool request callback.
1 parent 38f8afc · commit cba0574

File tree

3 files changed: +42 -15 lines

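With this change, a tool that raises no longer aborts the whole `.act()` call by default: the failure is serialized and handed back to the model as the tool result. A client that prefers to fail fast can raise from the invalid tool request callback instead. A minimal sketch, assuming the callback keyword exposed by `.act()` is `handle_invalid_tool_request` and that `LMStudioPredictionError` and `ToolCallRequest` are importable from the package root:

import lmstudio as lms
from lmstudio import LMStudioPredictionError, ToolCallRequest

def divide(numerator: float, denominator: float) -> float:
    """Divide the given numerator by the given denominator. Return the result."""
    return numerator / denominator

def fail_fast(
    exc: LMStudioPredictionError, request: ToolCallRequest | None
) -> str | None:
    # Raising aborts the round; returning a string would instead replace
    # the error text that gets passed back to the model
    raise exc

model = lms.llm()
chat = lms.Chat()
chat.add_user_message("What is 1 divided by 0?")
model.act(chat, [divide], handle_invalid_tool_request=fail_fast)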

src/lmstudio/json_api.py

Lines changed: 17 additions & 3 deletions
@@ -1334,21 +1334,35 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
         self._on_prompt_processing_progress(progress)
 
     def _handle_invalid_tool_request(
-        self, err_msg: str, request: ToolCallRequest | None = None
+        self,
+        err_msg: str,
+        request: ToolCallRequest | None = None,
+        *,
+        exc: Exception | None = None,
     ) -> str:
         _on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
         if _on_handle_invalid_tool_request is not None:
             # Allow users to override the error message, or force an exception
             self._logger.debug("Invoking on_handle_invalid_tool_request callback")
-            exc = LMStudioPredictionError(err_msg)
-            user_err_msg = _on_handle_invalid_tool_request(exc, request)
+            callback_exc = LMStudioPredictionError(err_msg)
+            if exc is not None:
+                callback_exc.__cause__ = exc
+            user_err_msg = _on_handle_invalid_tool_request(callback_exc, request)
             if user_err_msg is not None:
                 err_msg = user_err_msg
         if request is not None:
             return err_msg
         # We don't allow users to prevent the exception when there's no request
         raise LMStudioPredictionError(err_msg)
 
+    def _handle_failed_tool_request(
+        self, exc: Exception, request: ToolCallRequest
+    ) -> ToolCallResultData:
+        err_msg = self._handle_invalid_tool_request(
+            f"Unhandled Python exception: {exc!r}", request, exc=exc
+        )
+        return ToolCallResultData(content=json.dumps(err_msg), tool_call_id=request.id)
+
     def request_tool_call(
         self, request: ToolCallRequest
     ) -> Callable[[], ToolCallResultData]:
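
The `callback_exc.__cause__ = exc` assignment chains the original tool failure onto the SDK error the same way `raise ... from ...` would, but without raising, so the user callback receives the underlying exception as `__cause__`. A standalone sketch of the pattern (class and function names here are illustrative, not SDK API):

class PredictionError(Exception):
    """Illustrative stand-in for LMStudioPredictionError."""

def wrap_failure(err_msg: str, original: Exception | None) -> PredictionError:
    wrapper = PredictionError(err_msg)
    if original is not None:
        # Same effect as `raise wrapper from original`, minus the raise,
        # so the chained error can be handed to a callback instead
        wrapper.__cause__ = original
    return wrapper

try:
    1 / 0
except ZeroDivisionError as exc:
    wrapped = wrap_failure(f"Unhandled Python exception: {exc!r}", exc)

assert isinstance(wrapped.__cause__, ZeroDivisionError)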

src/lmstudio/sync_api.py

Lines changed: 17 additions & 3 deletions
@@ -1591,13 +1591,13 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
         channel_cm = self._session._create_channel(endpoint)
         prediction_stream = PredictionStream(channel_cm, endpoint)
         tool_call_requests: list[ToolCallRequest] = []
-        pending_tool_calls: list[SyncFuture[Any]] = []
+        pending_tool_calls: dict[SyncFuture[Any], ToolCallRequest] = {}
         for event in prediction_stream._iter_events():
             if isinstance(event, PredictionToolCallEvent):
                 tool_call_request = event.arg
                 tool_call_requests.append(tool_call_request)
                 tool_call = endpoint.request_tool_call(tool_call_request)
-                pending_tool_calls.append(pool.submit(tool_call))
+                pending_tool_calls[pool.submit(tool_call)] = tool_call_request
         prediction = prediction_stream.result()
         self._logger.debug(
             "Completed .act() prediction round", round_index=round_index
@@ -1610,8 +1610,22 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
             with sdk_callback_invocation(err_msg, self._logger):
                 on_prediction_completed(round_result)
         if pending_tool_calls:
+
+            def _finish_tool_call(fut: SyncFuture[Any]) -> Any:
+                exc = fut.exception()
+                if exc is not None:
+                    if not isinstance(exc, Exception):
+                        # Don't allow base exceptions to be suppressed
+                        raise exc
+                    failed_request = pending_tool_calls[fut]
+                    return endpoint._handle_failed_tool_request(
+                        exc, failed_request
+                    )
+                return fut.result()
+
             tool_results = [
-                fut.result() for fut in as_completed(pending_tool_calls)
+                _finish_tool_call(fut)
+                for fut in as_completed(pending_tool_calls)
             ]
             requests_message = agent_chat.add_assistant_response(
                 prediction, tool_call_requests
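
Switching `pending_tool_calls` from a list to a future-keyed dict is what lets `_finish_tool_call` map a failed future back to the request that spawned it, since `as_completed()` yields futures in completion order rather than submission order. The same pattern with only the standard library (task ids stand in for `ToolCallRequest` objects):

from concurrent.futures import Future, ThreadPoolExecutor, as_completed

def flaky(task_id: int) -> int:
    if task_id == 2:
        raise ValueError("boom")
    return task_id * 10

with ThreadPoolExecutor() as pool:
    # Key each future by the request that spawned it
    pending: dict[Future[int], int] = {
        pool.submit(flaky, task_id): task_id for task_id in range(4)
    }
    results = []
    for fut in as_completed(pending):
        request_id = pending[fut]
        exc = fut.exception()
        if exc is not None:
            # Convert the failure into a result tied to its originating request
            results.append((request_id, f"failed: {exc!r}"))
        else:
            results.append((request_id, fut.result()))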

tests/test_inference.py

Lines changed: 8 additions & 9 deletions
@@ -247,11 +247,7 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
 
 def divide(numerator: float, denominator: float) -> float | str:
     """Divide the given numerator by the given denominator. Return the result."""
-    try:
-        return numerator / denominator
-    except Exception as exc:
-        # TODO: Perform this exception-to-response-string translation implicitly
-        return f"Unhandled Python exception: {exc!r}"
+    return numerator / denominator
 
 
 @pytest.mark.lmstudio
@@ -268,13 +264,13 @@ def test_tool_using_agent_error_handling(caplog: LogCap) -> None:
     )
     tools = [divide]
     predictions: list[PredictionRoundResult] = []
-    invalid_requests: list[tuple[LMStudioPredictionError, ToolCallRequest]] = []
+    request_failures: list[LMStudioPredictionError] = []
 
     def _handle_invalid_request(
         exc: LMStudioPredictionError, request: ToolCallRequest | None
     ) -> None:
         if request is not None:
-            invalid_requests.append((exc, request))
+            request_failures.append(exc)
 
     act_result = llm.act(
         chat,
@@ -284,8 +280,11 @@ def _handle_invalid_request(
     )
     assert len(predictions) > 1
     assert act_result.rounds == len(predictions)
-    # Test case is currently suppressing exceptions inside the tool call
-    assert invalid_requests == []
+    # Ensure the tool call failure was reported to the user callback
+    assert len(request_failures) == 1
+    tool_failure_exc = request_failures[0]
+    assert isinstance(tool_failure_exc, LMStudioPredictionError)
+    assert isinstance(tool_failure_exc.__cause__, ZeroDivisionError)
     # If the content checks prove flaky in practice, they can be dropped
     assert "divide" in predictions[-1].content
     assert "zero" in predictions[-1].content
