Commit 422f957

Pass tool call failures back to the LLM (#72)
Default to passing tool call failures back to the LLM. SDK users can override this via the invalid tool request callback.
1 parent 38f8afc commit 422f957
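
As a sketch of what the override looks like in practice, an SDK user can re-raise the reported error from the invalid tool request callback to restore fail-fast behaviour instead of letting the failure text flow back to the LLM. This assumes the callback is passed to `llm.act()` as a `handle_invalid_tool_request` keyword argument (the exact parameter name is not visible in this diff):

```python
import lmstudio as lms
from lmstudio import LMStudioPredictionError, ToolCallRequest

def divide(numerator: float, denominator: float) -> float:
    """Divide the given numerator by the given denominator."""
    return numerator / denominator

def reraise_tool_failures(
    exc: LMStudioPredictionError, request: ToolCallRequest | None
) -> None:
    # With the new default, tool failures are reported here and then sent
    # back to the LLM as the tool call result. Raising instead aborts the
    # .act() round with the original exception attached as __cause__.
    if request is not None and exc.__cause__ is not None:
        raise exc

llm = lms.llm()
chat = lms.Chat("You are a calculator.")
chat.add_user_message("What is 1 divided by 0?")
llm.act(chat, [divide], handle_invalid_tool_request=reraise_tool_failures)
```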

File tree

3 files changed: +42 −15 lines

src/lmstudio/json_api.py

Lines changed: 17 additions & 3 deletions

@@ -1334,21 +1334,35 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
         self._on_prompt_processing_progress(progress)
 
     def _handle_invalid_tool_request(
-        self, err_msg: str, request: ToolCallRequest | None = None
+        self,
+        err_msg: str,
+        request: ToolCallRequest | None = None,
+        *,
+        exc: Exception | None = None,
     ) -> str:
         _on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
         if _on_handle_invalid_tool_request is not None:
             # Allow users to override the error message, or force an exception
             self._logger.debug("Invoking on_handle_invalid_tool_request callback")
-            exc = LMStudioPredictionError(err_msg)
-            user_err_msg = _on_handle_invalid_tool_request(exc, request)
+            callback_exc = LMStudioPredictionError(err_msg)
+            if exc is not None:
+                callback_exc.__cause__ = exc
+            user_err_msg = _on_handle_invalid_tool_request(callback_exc, request)
             if user_err_msg is not None:
                 err_msg = user_err_msg
         if request is not None:
             return err_msg
         # We don't allow users to prevent the exception when there's no request
         raise LMStudioPredictionError(err_msg)
 
+    def _handle_failed_tool_request(
+        self, exc: Exception, request: ToolCallRequest
+    ) -> ToolCallResultData:
+        err_msg = self._handle_invalid_tool_request(
+            f"Unhandled Python exception: {exc!r}", request, exc=exc
+        )
+        return ToolCallResultData(content=json.dumps(err_msg), tool_call_id=request.id)
+
     def request_tool_call(
         self, request: ToolCallRequest
     ) -> Callable[[], ToolCallResultData]:
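
The `callback_exc.__cause__ = exc` assignment above performs the same exception chaining that `raise ... from ...` would, but without raising, since the exception object is handed to the user callback rather than thrown. A minimal standalone sketch of the technique (using a stand-in exception class):

```python
class PredictionError(Exception):
    """Stand-in for LMStudioPredictionError in this sketch."""

try:
    1 / 0
except ZeroDivisionError as exc:
    chained = PredictionError(f"Unhandled Python exception: {exc!r}")
    # Equivalent to `raise chained from exc`, minus the raise itself
    chained.__cause__ = exc
    assert isinstance(chained.__cause__, ZeroDivisionError)
```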

src/lmstudio/sync_api.py

Lines changed: 17 additions & 3 deletions

@@ -1591,13 +1591,13 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
             channel_cm = self._session._create_channel(endpoint)
             prediction_stream = PredictionStream(channel_cm, endpoint)
             tool_call_requests: list[ToolCallRequest] = []
-            pending_tool_calls: list[SyncFuture[Any]] = []
+            pending_tool_calls: dict[SyncFuture[Any], ToolCallRequest] = {}
             for event in prediction_stream._iter_events():
                 if isinstance(event, PredictionToolCallEvent):
                     tool_call_request = event.arg
                     tool_call_requests.append(tool_call_request)
                     tool_call = endpoint.request_tool_call(tool_call_request)
-                    pending_tool_calls.append(pool.submit(tool_call))
+                    pending_tool_calls[pool.submit(tool_call)] = tool_call_request
             prediction = prediction_stream.result()
             self._logger.debug(
                 "Completed .act() prediction round", round_index=round_index
@@ -1610,8 +1610,22 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                 with sdk_callback_invocation(err_msg, self._logger):
                     on_prediction_completed(round_result)
             if pending_tool_calls:
+
+                def _finish_tool_call(fut: SyncFuture[Any]) -> Any:
+                    exc = fut.exception()
+                    if exc is not None:
+                        if not isinstance(exc, Exception):
+                            # Don't allow base exceptions to be suppressed
+                            raise exc
+                        failed_request = pending_tool_calls[fut]
+                        return endpoint._handle_failed_tool_request(
+                            exc, failed_request
+                        )
+                    return fut.result()
+
                 tool_results = [
-                    fut.result() for fut in as_completed(pending_tool_calls)
+                    _finish_tool_call(fut)
+                    for fut in as_completed(pending_tool_calls)
                 ]
                 requests_message = agent_chat.add_assistant_response(
                     prediction, tool_call_requests
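
Keying `pending_tool_calls` by future (rather than keeping a plain list) is what lets the completion handler recover which request a failed future belongs to. A self-contained sketch of the pattern, with hypothetical `run_tool` and `make_failure_result` helpers standing in for the SDK internals:

```python
from concurrent.futures import Future, ThreadPoolExecutor, as_completed

def run_tool(name: str) -> str:
    if name == "divide_by_zero":
        raise ZeroDivisionError("division by zero")
    return f"{name}: ok"

def make_failure_result(exc: Exception, request: str) -> str:
    # Stand-in for endpoint._handle_failed_tool_request(): turn the
    # failure into a result the caller can react to instead of crashing
    return f"{request} failed: {exc!r}"

def finish(fut: Future[str], pending: dict[Future[str], str]) -> str:
    exc = fut.exception()  # blocks until the future is done
    if exc is not None:
        if not isinstance(exc, Exception):
            raise exc  # never suppress KeyboardInterrupt and friends
        return make_failure_result(exc, pending[fut])
    return fut.result()

with ThreadPoolExecutor() as pool:
    # Map each submitted future back to the request that produced it
    pending: dict[Future[str], str] = {
        pool.submit(run_tool, name): name for name in ("add", "divide_by_zero")
    }
    results = [finish(fut, pending) for fut in as_completed(pending)]

print(results)
```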

tests/test_inference.py

Lines changed: 8 additions & 9 deletions

@@ -247,11 +247,7 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
 
 def divide(numerator: float, denominator: float) -> float | str:
     """Divide the given numerator by the given denominator. Return the result."""
-    try:
-        return numerator / denominator
-    except Exception as exc:
-        # TODO: Perform this exception-to-response-string translation implicitly
-        return f"Unhandled Python exception: {exc!r}"
+    return numerator / denominator
 
 
 @pytest.mark.lmstudio
@@ -268,13 +264,13 @@ def test_tool_using_agent_error_handling(caplog: LogCap) -> None:
     )
     tools = [divide]
     predictions: list[PredictionRoundResult] = []
-    invalid_requests: list[tuple[LMStudioPredictionError, ToolCallRequest]] = []
+    request_failures: list[LMStudioPredictionError] = []
 
     def _handle_invalid_request(
         exc: LMStudioPredictionError, request: ToolCallRequest | None
     ) -> None:
         if request is not None:
-            invalid_requests.append((exc, request))
+            request_failures.append(exc)
 
     act_result = llm.act(
         chat,
@@ -284,8 +280,11 @@ def _handle_invalid_request(
     )
     assert len(predictions) > 1
     assert act_result.rounds == len(predictions)
-    # Test case is currently suppressing exceptions inside the tool call
-    assert invalid_requests == []
+    # Ensure the tool call failure was reported to the user callback
+    assert len(request_failures) == 1
+    tool_failure_exc = request_failures[0]
+    assert isinstance(tool_failure_exc, LMStudioPredictionError)
+    assert isinstance(tool_failure_exc.__cause__, ZeroDivisionError)
     # If the content checks prove flaky in practice, they can be dropped
     assert "divide" in predictions[-1].content
     assert "zero" in predictions[-1].content
