
Commit e950b15

Fixing a bug from transformers==4.52. config.head_dim is now explicitly set to None (#552)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
1 parent 8a67a53 commit e950b15
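
For context, `getattr(config, "head_dim", default)` only falls back to the default when the attribute is missing; with transformers==4.52 the config carries `head_dim = None` explicitly, so the old one-liners returned None and later arithmetic such as `self.head_dim**-0.5` failed. Below is a minimal sketch of the before/after behaviour, using a hypothetical FakeConfig stand-in rather than a real transformers config or vLLM code:

class FakeConfig:
    # Hypothetical stand-in for a transformers>=4.52 model config, which
    # now sets head_dim explicitly to None instead of omitting the attribute.
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = None

config = FakeConfig()

# Old pattern: getattr only uses the default when the attribute is absent,
# so this now evaluates to None; downstream math like head_dim**-0.5
# would raise TypeError.
old = getattr(config, "head_dim",
              config.hidden_size // config.num_attention_heads)
assert old is None

# Pattern applied by this commit: fall back explicitly when the value is None.
new = getattr(config, "head_dim", None)
if new is None:
    new = config.hidden_size // config.num_attention_heads
assert new == 128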

File tree

8 files changed: +27 −18 lines changed

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 3 additions & 2 deletions

@@ -44,8 +44,9 @@ def get_model_args(self, model_executable: torch.nn.Module):
             head_size = model_config.qk_nope_head_dim + \
                 model_config.qk_rope_head_dim
         else:
-            head_size = getattr(model_config, "head_dim",
-                                int(hidden_size // num_attention_heads))
+            head_size = getattr(model_config, "head_dim", None)
+            if head_size is None:
+                head_size = int(hidden_size // num_attention_heads)
 
         return num_heads, head_size

vllm/model_executor/models/exaone.py

Lines changed: 3 additions & 2 deletions

@@ -127,8 +127,9 @@ def __init__(
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5

vllm/model_executor/models/granite.py

Lines changed: 3 additions & 2 deletions

@@ -122,8 +122,9 @@ def __init__(
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = config.attention_multiplier

vllm/model_executor/models/minimax_text_01.py

Lines changed: 6 additions & 4 deletions

@@ -604,8 +604,9 @@ def __init__(
 
         rope_theta = getattr(config, "rope_theta", 10000)
 
-        head_dim = getattr(config, "head_dim",
-                           config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if head_dim is None:
+            head_dim = config.hidden_size // config.num_attention_heads
         if hasattr(config, "max_model_len") and isinstance(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,
@@ -861,8 +862,9 @@ def layer_fn(prefix):
                                  cache_shape=self.cache_shape)
 
         rope_theta = getattr(config, "rope_theta", 10000)
-        head_dim = getattr(config, "head_dim",
-                           config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if head_dim is None:
+            head_dim = config.hidden_size // config.num_attention_heads
         if hasattr(config, "max_model_len") and isinstance(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,

vllm/model_executor/models/mixtral.py

Lines changed: 3 additions & 2 deletions

@@ -138,8 +138,9 @@ def __init__(
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MixtralConfig has an optional head_dim argument
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5

vllm/model_executor/models/mixtral_quant.py

Lines changed: 3 additions & 2 deletions

@@ -193,8 +193,9 @@ def __init__(
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MixtralConfig has an optional head_dim argument
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5

vllm/model_executor/models/nemotron.py

Lines changed: 3 additions & 2 deletions

@@ -158,8 +158,9 @@ def __init__(
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5

vllm/model_executor/models/solar.py

Lines changed: 3 additions & 2 deletions

@@ -126,8 +126,9 @@ def __init__(
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
