|
15 | 15 | Client,
|
16 | 16 | LlmPredictionConfig,
|
17 | 17 | LlmPredictionFragment,
|
| 18 | + LMStudioPredictionError, |
18 | 19 | LMStudioValueError,
|
19 | 20 | PredictionResult,
|
20 | 21 | PredictionRoundResult,
|
| 22 | + ToolCallRequest, |
21 | 23 | ToolFunctionDef,
|
22 | 24 | ToolFunctionDefDict,
|
23 | 25 | )
|
@@ -162,7 +164,7 @@ def test_duplicate_tool_names_rejected() -> None:
|
162 | 164 |
|
163 | 165 | @pytest.mark.lmstudio
|
164 | 166 | def test_tool_using_agent(caplog: LogCap) -> None:
|
165 |
| - # This is currently a sync-only API (it will be refactored after 1.0.0) |
| 167 | + # This is currently a sync-only API (it will be refactored in a future release) |
166 | 168 |
|
167 | 169 | caplog.set_level(logging.DEBUG)
|
168 | 170 | model_id = TOOL_LLM_ID
|
@@ -192,7 +194,7 @@ def test_tool_using_agent(caplog: LogCap) -> None:
|
192 | 194 |
|
193 | 195 | @pytest.mark.lmstudio
|
194 | 196 | def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
|
195 |
| - # This is currently a sync-only API (it will be refactored after 1.0.0) |
| 197 | + # This is currently a sync-only API (it will be refactored in a future release) |
196 | 198 |
|
197 | 199 | caplog.set_level(logging.DEBUG)
|
198 | 200 | model_id = TOOL_LLM_ID
|
@@ -241,3 +243,49 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
|
241 | 243 |
|
242 | 244 | cloned_chat = chat.copy()
|
243 | 245 | assert cloned_chat._messages == chat._messages
|
| 246 | + |
| 247 | + |
| 248 | +def divide(numerator: float, denominator: float) -> float | str: |
| 249 | + """Divide the given numerator by the given denominator. Return the result.""" |
| 250 | + try: |
| 251 | + return numerator / denominator |
| 252 | + except Exception as exc: |
| 253 | + # TODO: Perform this exception-to-response-string translation implicitly |
| 254 | + return f"Unhandled Python exception: {exc!r}" |
| 255 | + |
| 256 | + |
@pytest.mark.lmstudio
def test_tool_using_agent_error_handling(caplog: LogCap) -> None:
    """Agent round where the tool raises: the error string is fed back to the model."""
    # This is currently a sync-only API (it will be refactored in a future release)

    caplog.set_level(logging.DEBUG)
    with Client() as client:
        llm = client.llm.model(TOOL_LLM_ID)
        chat = Chat()
        chat.add_user_message(
            "Attempt to divide 1 by 0 using the tool. Explain the result."
        )
        round_results: list[PredictionRoundResult] = []
        rejected_calls: list[tuple[LMStudioPredictionError, ToolCallRequest]] = []

        def _record_invalid_request(
            exc: LMStudioPredictionError, request: ToolCallRequest | None
        ) -> None:
            # Only record rejections that carry an actual tool call request.
            if request is not None:
                rejected_calls.append((exc, request))

        act_result = llm.act(
            chat,
            [divide],
            handle_invalid_tool_request=_record_invalid_request,
            on_prediction_completed=round_results.append,
        )
        # The failed division must trigger at least one follow-up round.
        assert len(round_results) > 1
        assert act_result.rounds == len(round_results)
        # Test case is currently suppressing exceptions inside the tool call
        assert rejected_calls == []
        # If the content checks prove flaky in practice, they can be dropped
        assert "divide" in round_results[-1].content
        assert "zero" in round_results[-1].content
0 commit comments