
Commit c5dba06

qk norm before rope arg

Adds a qk_norm_before_rope flag to ModelArgs so attention can apply q/k normalization either before or after the rotary embedding, and enables it in the Qwen3 example configs.

1 parent: fa7ff5d

5 files changed: +13 -4 lines

examples/models/llama/attention.py (+6 -1)

@@ -178,6 +178,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.dim = args.dim
         self.attention_qkv_bias = args.attention_qkv_bias
         self.use_qk_norm = args.use_qk_norm
+        self.qk_norm_before_rope = args.qk_norm_before_rope

         if self.use_qk_norm:
             q_norm_dim = self.head_dim
@@ -243,7 +244,7 @@ def forward(
         k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

-        if self.use_qk_norm:
+        if self.use_qk_norm and self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
             k = self.k_norm_fn(k)
@@ -254,6 +255,10 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

+        if self.use_qk_norm and not self.qk_norm_before_rope:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         if self.use_kv_cache:
             assert input_pos is not None
             k, v = self.kv_cache.update(input_pos, k, v)
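Functionally, the flag only moves where q/k normalization sits relative to the rotary embedding. A minimal sketch of the two orderings follows; rms_norm, apply_rope, and qk_with_norm are illustrative stand-ins for the repo's q_norm_fn/k_norm_fn and Rope, not its actual helpers, and the head transposes are omitted:

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm over the head dimension with a learned elementwise scale,
    # standing in for q_norm_fn / k_norm_fn.
    return weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def apply_rope(x: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    # Toy GPT-NeoX-style rotary embedding for (bsz, seqlen, n_heads, head_dim):
    # rotate the (i, i + head_dim // 2) dimension pairs by position-dependent angles.
    _, seqlen, _, head_dim = x.shape
    half = head_dim // 2
    inv_freq = theta ** (-torch.arange(half, dtype=x.dtype) / half)
    angles = torch.arange(seqlen, dtype=x.dtype)[:, None] * inv_freq[None, :]
    cos = angles.cos()[None, :, None, :]  # broadcast over batch and heads
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


def qk_with_norm(q, k, w_q, w_k, qk_norm_before_rope: bool):
    # Mirrors the ordering this commit introduces in Attention.forward
    # (use_qk_norm is assumed True).
    if qk_norm_before_rope:
        q, k = rms_norm(q, w_q), rms_norm(k, w_k)
    q, k = apply_rope(q), apply_rope(k)
    if not qk_norm_before_rope:
        q, k = rms_norm(q, w_q), rms_norm(k, w_k)
    return q, k


q = torch.randn(1, 4, 2, 8)  # (bsz, seqlen, n_heads, head_dim)
k = torch.randn(1, 4, 2, 8)
w = torch.linspace(0.5, 1.5, 8)  # non-uniform scale, as a trained norm would have

q_pre, _ = qk_with_norm(q, k, w, w, qk_norm_before_rope=True)
q_post, _ = qk_with_norm(q, k, w, w, qk_norm_before_rope=False)
print((q_pre - q_post).abs().max())  # nonzero: the orderings are not equivalent
```

With a parameter-free RMSNorm the two orderings would actually coincide, because the rotation preserves the per-head RMS; it is the learned elementwise scale that makes the order observable, which is why matching a checkpoint's ordering matters when porting a model.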

examples/models/llama/model_args.py (+1 -0)

@@ -38,6 +38,7 @@ class ModelArgs:
     apply_embedding: bool = True  # Use embedding inside the transformer
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
+    qk_norm_before_rope: bool = False  # if True, apply qk norm before RoPE rather than after
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
     partial_rotary_factor: float = 1.0
     rope_theta: Optional[float] = (
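The new field defaults to False, so existing models keep the post-RoPE behavior. A usage sketch, assuming the remaining ModelArgs fields have defaults and that the import path below matches the repo layout (both are assumptions):

```python
from examples.models.llama.model_args import ModelArgs  # path assumed from this diff

# Qwen3-style attention: normalize q and k first, then apply rotary embeddings.
qwen3_like = ModelArgs(use_qk_norm=True, qk_norm_before_rope=True)

# Default: qk norm (when enabled) still runs after RoPE, as before this commit.
legacy = ModelArgs(use_qk_norm=True)
assert legacy.qk_norm_before_rope is False
```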

examples/models/qwen3/0_6b_config.json (+2 -1)

@@ -12,5 +12,6 @@
   "vocab_size": 151936,
   "use_hf_rope": true,
   "attention_qkv_bias": false,
-  "use_qk_norm": true
+  "use_qk_norm": true,
+  "qk_norm_before_rope": true
 }

examples/models/qwen3/1_7b_config.json (+2 -1)

@@ -12,5 +12,6 @@
   "vocab_size": 151936,
   "use_hf_rope": true,
   "attention_qkv_bias": false,
-  "use_qk_norm": true
+  "use_qk_norm": true,
+  "qk_norm_before_rope": true
 }

examples/models/qwen3/4b_config.json (+2 -1)

@@ -12,5 +12,6 @@
   "vocab_size": 151936,
   "use_hf_rope": true,
   "attention_qkv_bias": false,
-  "use_qk_norm": true
+  "use_qk_norm": true,
+  "qk_norm_before_rope": true
 }
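All three Qwen3 example configs enable the pre-RoPE ordering, matching the Qwen3 architecture, which normalizes the query and key heads before applying rotary embeddings. Since the JSON keys mirror ModelArgs field names, a config can plausibly be loaded as below (a sketch only; the repo's actual loading path may differ):

```python
import json

from examples.models.llama.model_args import ModelArgs  # import path assumed

with open("examples/models/qwen3/0_6b_config.json") as f:
    cfg = json.load(f)

args = ModelArgs(**cfg)  # assumes every JSON key names a ModelArgs field
assert args.use_qk_norm and args.qk_norm_before_rope
```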
