Skip to content

Commit f56d299

Browse files
authored
[Misc] Respect no_use_tqdm_on_load flag while capturing CUDA graph (#20834)
Signed-off-by: Linkun <github@lkchen.net>
1 parent 147afb4 commit f56d299

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2270,8 +2270,10 @@ def capture_model(self) -> None:
22702270
# Only rank 0 should print progress bar during capture
22712271
compilation_cases = reversed(self.cudagraph_batch_sizes)
22722272
if is_global_first_rank():
2273-
compilation_cases = tqdm(list(compilation_cases),
2274-
desc="Capturing CUDA graph shapes")
2273+
compilation_cases = tqdm(
2274+
list(compilation_cases),
2275+
disable=not self.load_config.use_tqdm_on_load,
2276+
desc="Capturing CUDA graph shapes")
22752277
for num_tokens in compilation_cases:
22762278
# We skip EPLB here since we don't want to record dummy metrics
22772279
for _ in range(

vllm/worker/model_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,6 +1587,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
15871587
if get_tensor_model_parallel_rank() == 0:
15881588
compilation_cases = tqdm(
15891589
list(compilation_cases),
1590+
disable=not self.load_config.use_tqdm_on_load,
15901591
desc="Capturing CUDA graph shapes")
15911592
for batch_size, use_inputs_embeds in compilation_cases:
15921593
attn_metadata = (

0 commit comments

Comments (0)