@@ -68,6 +68,7 @@ def __init__(
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
+        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -101,8 +102,11 @@ def __init__(
         # for backward compatibility
         if num_key_value_heads is None:
             num_key_value_heads = num_attention_heads
+        if head_dim is None:
+            head_dim = hidden_size // num_attention_heads

         self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
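The fallback above follows the usual convention of splitting the hidden size evenly across attention heads, while letting a config override it explicitly. A minimal standalone sketch of that logic; the helper name `resolve_head_dim` is illustrative and not part of this change:

from typing import Optional


def resolve_head_dim(hidden_size: int,
                     num_attention_heads: int,
                     head_dim: Optional[int] = None) -> int:
    # Keep an explicit head_dim if one is provided; otherwise fall back
    # to the conventional even split of the hidden size across heads.
    if head_dim is None:
        head_dim = hidden_size // num_attention_heads
    return head_dim


# With hidden_size=4096 and 32 attention heads the default comes out to 128,
# while a config may now override it explicitly (e.g. head_dim=256).
assert resolve_head_dim(4096, 32) == 128
assert resolve_head_dim(4096, 32, head_dim=256) == 256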
@@ -294,6 +298,7 @@ def __init__(
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        head_dim: Optional[int] = None,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         cache_config: Optional[CacheConfig] = None,
@@ -317,7 +322,9 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        if head_dim is None:
+            head_dim = hidden_size // num_heads
+        self.head_dim = head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
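With this change, the per-GPU projection sizes and the attention scaling follow the resolved `head_dim` rather than the value forced from `hidden_size // total_num_heads`. A rough sketch of the arithmetic under tensor parallelism; the concrete numbers (notably `tp_size` and the override value) are assumptions for illustration only:

# Illustrative numbers only; tp_size and the head_dim override are assumed.
hidden_size = 4096
total_num_heads = 32
total_num_kv_heads = 8
tp_size = 4                                           # tensor-parallel world size
head_dim = 256                                        # explicit override; old code forced 4096 // 32 = 128

num_heads = total_num_heads // tp_size                # 8 query heads per GPU
num_kv_heads = max(1, total_num_kv_heads // tp_size)  # 2 KV heads per GPU
q_size = num_heads * head_dim                         # 2048
kv_size = num_kv_heads * head_dim                     # 512
scaling = head_dim ** -0.5                            # 0.0625

print(q_size, kv_size, scaling)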
@@ -387,6 +394,8 @@ def __init__(
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
+            head_dim=getattr(config, "head_dim",
+                             self.hidden_size // config.num_attention_heads),
             rope_theta=rope_theta,
             cache_config=cache_config,
             quant_config=quant_config,
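The `getattr` fallback keeps older config objects that lack a `head_dim` attribute working unchanged. A minimal illustration of the same expression shape, using stand-in config objects that are not part of the codebase:

from types import SimpleNamespace

# Stand-in config objects for illustration only.
old_config = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
new_config = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=256)

# Use head_dim when the attribute exists, otherwise fall back to the even split.
old_head_dim = getattr(old_config, "head_dim",
                       old_config.hidden_size // old_config.num_attention_heads)
new_head_dim = getattr(new_config, "head_dim",
                       new_config.hidden_size // new_config.num_attention_heads)

assert old_head_dim == 128
assert new_head_dim == 256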