@@ -199,20 +199,21 @@ def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
     with Client() as client:
         llm = client.llm.model(model_id)
         chat = Chat()
-        chat.add_user_message("What is the sum of 123 and 3210?")
+        # Ensure the first response is a combination of text and tool use requests
+        chat.add_user_message("First say 'Hi'. Then calculate 1 + 3 with the tool.")
         tools = [ADDITION_TOOL_SPEC]
         round_starts: list[int] = []
         round_ends: list[int] = []
         first_tokens: list[int] = []
         predictions: list[PredictionRoundResult] = []
         fragments: list[LlmPredictionFragment] = []
-        last_fragment_round_index = 0
+        fragment_round_indices: set[int] = set()
 
         def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
-            nonlocal last_fragment_round_index
+            last_fragment_round_index = max(fragment_round_indices, default=-1)
             assert round_index >= last_fragment_round_index
             fragments.append(f)
-            last_fragment_round_index = round_index
+            fragment_round_indices.add(round_index)
 
         # TODO: Also check on_prompt_processing_progress and handling invalid messages
         # (although it isn't clear how to provoke calls to the latter without mocking)
@@ -233,8 +234,9 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
     assert round_starts == sequential_round_indices
     assert round_ends == sequential_round_indices
     expected_token_indices = [p.round_index for p in predictions if p.content]
+    assert expected_token_indices == sequential_round_indices
     assert first_tokens == expected_token_indices
-    assert last_fragment_round_index == num_rounds - 1
+    assert fragment_round_indices == set(expected_token_indices)
     assert len(chat._messages) == 2 * num_rounds  # No tool results in last round
 
     cloned_chat = chat.copy()
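
The gap between the two hunks elides the `llm.act()` invocation that drives the callbacks asserted on above. As a rough sketch of how the collectors are presumably registered there (the keyword argument names follow the `on_*` convention visible in the TODO comment, but are assumptions inferred from the collected values, not confirmed by this diff):

# Sketch only: plausible wiring of the collectors defined above.
# The keyword names below are assumptions, not taken from this diff.
act_result = llm.act(
    chat,
    tools,
    on_round_start=round_starts.append,          # called with the round index
    on_round_end=round_ends.append,              # called with the round index
    on_first_token=first_tokens.append,          # called with the round index
    on_prediction_fragment=_append_fragment,     # called with (fragment, round_index)
    on_prediction_completed=predictions.append,  # one PredictionRoundResult per round
)

Under that assumed wiring, the new assertions check that every round that produced content emitted both a first-token notification and at least one fragment, rather than only tracking the last fragment's round index.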