Skip to content

Commit 5bcb326

Browse files
authored
Mf/world size check (#1586)
* add world size check * add world size check, bump version
1 parent 1de1766 commit 5bcb326

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.9.81rc001"
3+
version = "0.9.81rc003"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/base/trt_llm_config.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,11 +600,20 @@ def trt_llm_validation(config: "TrussConfig") -> "TrussConfig":
600600
"FP8 quantization is only supported on L4, H100, H200 "
601601
"accelerators or newer (CUDA_COMPUTE>=89)"
602602
)
603-
tensor_parallel_count = config.trt_llm.build.tensor_parallel_count
603+
world_size = (
604+
config.trt_llm.build.tensor_parallel_count
605+
* config.trt_llm.build.pipeline_parallel_count
606+
* config.trt_llm.build.sequence_parallel_count
607+
)
604608

605-
if tensor_parallel_count != config.resources.accelerator.count:
609+
if world_size != config.resources.accelerator.count:
606610
raise ValueError(
607611
"Tensor parallelism and GPU count must be the same for TRT-LLM. "
612+
f"You have set tensor_parallel_count={config.trt_llm.build.tensor_parallel_count}, "
613+
f"pipeline_parallel_count={config.trt_llm.build.pipeline_parallel_count}, "
614+
f"sequence_parallel_count={config.trt_llm.build.sequence_parallel_count} "
615+
f"== world_size->{world_size} "
616+
f"and accelerator.count={config.resources.accelerator.count}. "
608617
)
609618

610619
return config

0 commit comments

Comments
 (0)