@@ -552,6 +552,7 @@ def beam_search(
552
552
prompts : list [Union [TokensPrompt , TextPrompt ]],
553
553
params : BeamSearchParams ,
554
554
lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
555
+ use_tqdm : bool = False ,
555
556
) -> list [BeamSearchOutput ]:
556
557
"""
557
558
Generate sequences using beam search.
@@ -561,6 +562,7 @@ def beam_search(
561
562
of token IDs.
562
563
params: The beam search parameters.
563
564
lora_request: LoRA request to use for generation, if any.
565
+ use_tqdm: Whether to use tqdm to display the progress bar.
564
566
"""
565
567
# TODO: how does beam search work together with length penalty,
566
568
# frequency penalty, and stopping criteria, etc.?
@@ -623,7 +625,18 @@ def create_tokens_prompt_from_beam(
623
625
** mm_kwargs ,
624
626
), )
625
627
626
- for _ in range (max_tokens ):
628
+ token_iter = range (max_tokens )
629
+ if use_tqdm :
630
+ token_iter = tqdm (token_iter ,
631
+ desc = "Beam search" ,
632
+ unit = "token" ,
633
+ unit_scale = False )
634
+ logger .warning (
635
+ "The progress bar shows the upper bound on token steps and "
636
+ "may finish early due to stopping conditions. It does not "
637
+ "reflect instance-level progress." )
638
+
639
+ for _ in token_iter :
627
640
all_beams : list [BeamSearchSequence ] = list (
628
641
sum ((instance .beams for instance in instances ), []))
629
642
pos = [0 ] + list (
0 commit comments