
Commit 873827c

[Model] Enhance error reporting for invalid tensor-parallel settings (#2566)
This PR enhances error reporting for multi-GPU model compilation so that as many invalid tensor-parallel settings as possible are reported before the models are loaded and run (the common pattern is sketched after the file summary below).
1 parent a231ae1 · commit 873827c


42 files changed: +172 −4 lines (only a subset of the changed files is shown below)
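
Every per-model check in this commit has the same shape: verify that the quantity can be divided by the shard count before computing the per-shard size, and raise a ValueError naming the offending quantity. A minimal, self-contained sketch of that pattern (the helper name and the example numbers are illustrative, not taken from the commit):

def _split_evenly(value: int, tensor_parallel_shards: int, what: str) -> int:
    """Fail before any weights are loaded if `value` cannot be sharded evenly."""
    if value % tensor_parallel_shards != 0:
        raise ValueError(
            f"Cannot split {what} {value} evenly to {tensor_parallel_shards} GPUs."
        )
    return value // tensor_parallel_shards

num_heads = _split_evenly(32, 4, "attention heads")  # 32 heads over 4 GPUs -> 8 per shard
try:
    _split_evenly(32, 3, "attention heads")           # 32 % 3 != 0, so this raises
except ValueError as err:
    print(err)  # Cannot split attention heads 32 evenly to 3 GPUs.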

python/mlc_llm/model/baichuan/baichuan_model.py

Lines changed: 10 additions & 0 deletions
@@ -87,6 +87,11 @@ def __post_init__(self):
 class BaichuanAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: BaichuanConfig):
         self.hidden_size = config.hidden_size
+        if config.num_attention_heads % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split {config.num_attention_heads} attention heads "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
         self.num_heads = config.num_attention_heads // config.tensor_parallel_shards
         self.head_dim = config.head_dim
         self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=False)

@@ -106,6 +111,11 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:

 class BaichuanMLP(nn.Module):
     def __init__(self, config: BaichuanConfig):
+        if config.intermediate_size % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split MLP intermediate size {config.intermediate_size} "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
         self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards
         self.gate_up_proj = nn.Linear(
             in_features=config.hidden_size,

python/mlc_llm/model/baichuan/baichuan_quantization.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ def group_quant(
     model: nn.Module = BaichuanForCausalLM(model_config)
     model.to(quantization.model_dtype)
     quant_map = QuantizeMapping({}, {})
+    quantization.tensor_parallel_shards = model_config.tensor_parallel_shards
     model = quantization.quantize_model(
         model,
         quant_map,
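
The same one-line addition appears in every *_quantization.py touched by this commit: the shard count is copied from the model config onto the quantization config before quantize_model runs, presumably so quantization-time logic can account for how the model will be sharded. A self-contained sketch of that hand-off with hypothetical stand-in classes (the real BaichuanConfig and quantization objects carry many more fields):

from dataclasses import dataclass

@dataclass
class ModelConfig:                 # stand-in for e.g. BaichuanConfig
    tensor_parallel_shards: int = 1

@dataclass
class QuantizeConfig:              # stand-in for the `quantization` argument
    model_dtype: str = "float16"
    tensor_parallel_shards: int = 1

def group_quant(model_config: ModelConfig, quantization: QuantizeConfig) -> QuantizeConfig:
    # The line this commit adds: forward the shard count before quantizing.
    quantization.tensor_parallel_shards = model_config.tensor_parallel_shards
    return quantization

quant = group_quant(ModelConfig(tensor_parallel_shards=4), QuantizeConfig())
print(quant.tensor_parallel_shards)  # 4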

python/mlc_llm/model/bert/bert_model.py

Lines changed: 5 additions & 0 deletions
@@ -83,6 +83,11 @@ def __post_init__(self):

 class BertSelfAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: BertConfig):
+        if config.num_attention_heads % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split {config.num_attention_heads} attention heads "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
         self.num_heads = config.num_attention_heads // config.tensor_parallel_shards
         self.head_dim = config.head_dim

python/mlc_llm/model/bert/bert_quantization.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ def group_quant(
     model: nn.Module = BertModel(model_config)
     model.to(quantization.model_dtype)
     quant_map = QuantizeMapping({}, {})
+    quantization.tensor_parallel_shards = model_config.tensor_parallel_shards
     model = quantization.quantize_model(
         model,
         quant_map,

python/mlc_llm/model/chatglm3/chatglm3_model.py

Lines changed: 10 additions & 0 deletions
@@ -93,6 +93,11 @@ def __post_init__(self):
 class GLMAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: GLMConfig):
         self.hidden_size = config.hidden_size
+        if config.num_attention_heads % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split {config.num_attention_heads} attention heads "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
         self.num_heads = config.num_attention_heads // config.tensor_parallel_shards
         self.multi_query_attention = config.multi_query_attention
         self.num_key_value_heads = (

@@ -125,6 +130,11 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:

 class GLMMLP(nn.Module):
     def __init__(self, config: GLMConfig):
+        if config.ffn_hidden_size % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split ffn hidden size {config.ffn_hidden_size} "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
         self.ffn_hidden_size = config.ffn_hidden_size // config.tensor_parallel_shards

         self.dense_h_to_4h = nn.Linear(

python/mlc_llm/model/chatglm3/chatglm3_quantization.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ def group_quant(
     model: nn.Module = ChatGLMForCausalLM(model_config)
     model.to(quantization.model_dtype)
     quant_map = QuantizeMapping({}, {})
+    quantization.tensor_parallel_shards = model_config.tensor_parallel_shards
     model = quantization.quantize_model(
         model,
         quant_map,

python/mlc_llm/model/eagle/eagle_quantization.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ def group_quant(
     model: nn.Module = EagleForCasualLM(model_config)
     model.to(quantization.model_dtype)
     quant_map = QuantizeMapping({}, {})
+    quantization.tensor_parallel_shards = model_config.tensor_parallel_shards
     model = quantization.quantize_model(
         model,
         quant_map,

python/mlc_llm/model/gemma/gemma_model.py

Lines changed: 5 additions & 0 deletions
@@ -102,6 +102,11 @@ def lm_head_forward(self, x: nn.Tensor):
 class GemmaMLP(nn.Module):
     def __init__(self, config: GemmaConfig):
         super().__init__()
+        if config.intermediate_size % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split MLP intermediate size {config.intermediate_size} "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
         self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards
         self.gate_up_proj = nn.Linear(
             in_features=config.hidden_size,

python/mlc_llm/model/gemma/gemma_quantization.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ def group_quant(
     model: nn.Module = GemmaForCausalLM(model_config)
     model.to(quantization.model_dtype)
     quant_map = QuantizeMapping({}, {})
+    quantization.tensor_parallel_shards = model_config.tensor_parallel_shards
     model = quantization.quantize_model(
         model,
         quant_map,

python/mlc_llm/model/gpt2/gpt2_model.py

Lines changed: 10 additions & 0 deletions
@@ -84,6 +84,11 @@ def __post_init__(self):
 class GPT2Attention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: GPT2Config):
         self.embed_dim = config.n_embd
+        if config.n_head % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split {config.n_head} attention heads "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
         self.num_heads = config.n_head // config.tensor_parallel_shards
         self.head_dim = config.head_dim
         self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx

@@ -120,6 +125,11 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
 class GPT2MLP(nn.Module):
     def __init__(self, config: GPT2Config):
         embed_dim = config.n_embd
+        if config.n_inner % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split MLP intermediate size {config.n_inner} "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
         intermediate_size = config.n_inner // config.tensor_parallel_shards
         self.c_fc = nn.Linear(embed_dim, intermediate_size)
         self.c_proj = nn.Linear(intermediate_size, embed_dim)
