Commit b342f83

Add last_token_pos in llama_transformer (#11793)
Differential Revision: D76440105
Pull Request resolved: #12239
1 parent: 1decf7a

2 files changed (+3, -1 lines)

examples/models/llama/attention.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -19,6 +19,7 @@ class ForwardOptions(TypedDict, total=False):
     freqs_sin_override: Optional[torch.Tensor]
     in_cache_state: Optional[Any]
     out_cache_state: Optional[Any]
+    last_valid_token_pos: Optional[torch.LongTensor]


 class Attention(nn.Module, ABC):
```
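
For context, ForwardOptions is declared with total=False, so every key, including the new last_valid_token_pos, is optional for callers. A minimal sketch of the resulting type, reproducing only the fields visible in the diff context above (the real class may declare more):

```python
from typing import Any, Optional, TypedDict

import torch

class ForwardOptions(TypedDict, total=False):
    # Only the fields visible in this diff are reproduced here.
    freqs_sin_override: Optional[torch.Tensor]
    in_cache_state: Optional[Any]
    out_cache_state: Optional[Any]
    last_valid_token_pos: Optional[torch.LongTensor]  # added by this commit

# total=False makes every key optional, so callers pass only what they need:
opts: ForwardOptions = {"last_valid_token_pos": torch.tensor(6)}
```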

examples/models/llama/llama_transformer.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -204,7 +204,8 @@ def forward(

         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
-            h = h[:, -1, :]
+            pos = attn_options.get("last_valid_token_pos", -1)
+            h = h[:, pos, :]

         h = self.norm(h)

```
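
The practical effect: when generate_full_logits is disabled, the model previously always fed the hidden state at position -1 to the output head, which reads a padding position whenever the sequence is padded past the last real token. The new option lets the caller point at the last valid token instead, with -1 preserved as the default. Below is a minimal sketch of how a caller might use it; the model variable and the forward call signature are assumptions not shown in this diff, only the last_valid_token_pos key comes from this commit:

```python
import torch

# Suppose a prompt of 7 real tokens is right-padded to sequence length 10.
# With the old code, h[:, -1, :] would take the hidden state of a pad token;
# passing the index of the last real token (6) selects the intended logits.
tokens = torch.zeros(1, 10, dtype=torch.long)  # hypothetical padded input ids
options = {"last_valid_token_pos": torch.tensor(6)}  # key added by this commit

# Hypothetical call; the actual forward signature is not shown in this diff:
# logits = model.forward(tokens, attn_options=options)
```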