@@ -1054,11 +1054,10 @@ def forward(
         ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            if not self.torchair_graph_enabled:
-                kv_c, k_pe = self.kv_a_proj_with_mqa(
-                    hidden_states_or_kv_c_normed)[0].split(
-                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            kv_c, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states_or_kv_c_normed)[0].split(
+                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
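This hunk drops the `if not self.torchair_graph_enabled:` guard, so the latent-KV down-projection and split always run on this non-graph path. Below is a minimal standalone sketch of what the split does, with illustrative sizes and plain torch modules standing in for the model's layers; note that vLLM's parallel linear layers return an (output, bias) tuple, which is why the diff indexes with [0], while nn.Linear returns the tensor directly.

import torch
import torch.nn as nn

# Illustrative sizes only, not the model's real config.
kv_lora_rank, qk_rope_head_dim, hidden_size = 512, 64, 2048
# Fused projection to latent KV plus rotary part; nn.RMSNorm (PyTorch >= 2.4)
# stands in for the model's kv_a_layernorm here.
kv_a_proj_with_mqa = nn.Linear(hidden_size,
                               kv_lora_rank + qk_rope_head_dim,
                               bias=False)
kv_a_layernorm = nn.RMSNorm(kv_lora_rank)

hidden_states = torch.randn(4, hidden_size)  # [num_tokens, hidden_size]
# Split the fused output into the latent KV and the rotary-position part.
kv_c, k_pe = kv_a_proj_with_mqa(hidden_states).split(
    [kv_lora_rank, qk_rope_head_dim], dim=-1)
kv_c_normed = kv_a_layernorm(kv_c.contiguous())
assert kv_c_normed.shape == (4, kv_lora_rank)
assert k_pe.shape == (4, qk_rope_head_dim)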
@@ -1077,12 +1076,13 @@ def forward(
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            if not self.torchair_graph_enabled:
-                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-                k_pe = k_pe[:num_actual_toks, ...]
-                k_pe = k_pe.unsqueeze(1)
-                decode_k_pe = k_pe[:num_decode_tokens]
-                prefill_k_pe = k_pe[num_decode_tokens:]
+            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+            prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
+            # if not self.torchair_graph_enabled:
+            k_pe = k_pe[:num_actual_toks, ...]
+            k_pe = k_pe.unsqueeze(1)
+            decode_k_pe = k_pe[:num_decode_tokens]
+            prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
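This hunk flattens the same guard and introduces `prefill_hs`, the prefill-only slice of the raw latent input, which `exec_kv_prefill` consumes in the last hunk. All the slicing relies on one batch layout: decode tokens first, prefill tokens after, graph padding at the end. A sketch of that layout with made-up sizes:

import torch

# Assumed token layout: [decode | prefill | padding].
num_decode_tokens, num_prefill_tokens = 2, 3
num_actual_toks = num_decode_tokens + num_prefill_tokens
qk_rope_head_dim = 64

k_pe = torch.randn(num_actual_toks + 1, qk_rope_head_dim)  # +1 padded token
k_pe = k_pe[:num_actual_toks, ...]       # drop graph padding
k_pe = k_pe.unsqueeze(1)                 # [tokens, 1, rope_dim]: one shared KV head
decode_k_pe = k_pe[:num_decode_tokens]   # leading segment: decode tokens
prefill_k_pe = k_pe[num_decode_tokens:]  # trailing segment: prefill tokens
assert decode_k_pe.shape[0] + prefill_k_pe.shape[0] == num_actual_toks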
@@ -1146,11 +1146,11 @@ def forward(
 
             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
             prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
-                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                attn_metadata.slot_mapping)
+                prefill_hs, cos, sin, kv_cache,
+                attn_metadata.slot_mapping[num_decode_tokens:])
 
             kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
-            prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
+            prefill_k_c_normed = prefill_k_nope
             prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                              -1)
             prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
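The last hunk makes `exec_kv_prefill` consume only the prefill tokens (`prefill_hs`), so its `slot_mapping` argument is sliced with the same `num_decode_tokens` offset, and `prefill_k_c_normed` no longer needs a further slice since `prefill_k_nope` is already prefill-only. A sketch of why the two slices must stay aligned (slot values are illustrative):

import torch

# Both tensors are indexed per token, decode tokens first, so they
# must be sliced with the same offset.
num_decode_tokens = 2
slot_mapping = torch.tensor([10, 11, 40, 41, 42])  # one cache slot per token
prefill_slots = slot_mapping[num_decode_tokens:]   # slots 40..42: prefill only
# Passing prefill-only hidden states together with the full slot_mapping
# would write each prefill token's KV into a decode token's cache slot;
# slicing both by num_decode_tokens keeps token i paired with slot i.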