Skip to content

Commit a9d09f9

Browse files
authored
Ensure tool use test covers mixed output (#29)
1 parent 54ed201 commit a9d09f9

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

tests/test_inference.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,20 +199,21 @@ def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
199199
with Client() as client:
200200
llm = client.llm.model(model_id)
201201
chat = Chat()
202-
chat.add_user_message("What is the sum of 123 and 3210?")
202+
# Ensure the first response is a combination of text and tool use requests
203+
chat.add_user_message("First say 'Hi'. Then calculate 1 + 3 with the tool.")
203204
tools = [ADDITION_TOOL_SPEC]
204205
round_starts: list[int] = []
205206
round_ends: list[int] = []
206207
first_tokens: list[int] = []
207208
predictions: list[PredictionRoundResult] = []
208209
fragments: list[LlmPredictionFragment] = []
209-
last_fragment_round_index = 0
210+
fragment_round_indices: set[int] = set()
210211

211212
def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
212-
nonlocal last_fragment_round_index
213+
last_fragment_round_index = max(fragment_round_indices, default=-1)
213214
assert round_index >= last_fragment_round_index
214215
fragments.append(f)
215-
last_fragment_round_index = round_index
216+
fragment_round_indices.add(round_index)
216217

217218
# TODO: Also check on_prompt_processing_progress and handling invalid messages
218219
# (although it isn't clear how to provoke calls to the latter without mocking)
@@ -233,8 +234,9 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
233234
assert round_starts == sequential_round_indices
234235
assert round_ends == sequential_round_indices
235236
expected_token_indices = [p.round_index for p in predictions if p.content]
237+
assert expected_token_indices == sequential_round_indices
236238
assert first_tokens == expected_token_indices
237-
assert last_fragment_round_index == num_rounds - 1
239+
assert fragment_round_indices == set(expected_token_indices)
238240
assert len(chat._messages) == 2 * num_rounds # No tool results in last round
239241

240242
cloned_chat = chat.copy()

0 commit comments

Comments
 (0)