@@ -65,19 +65,17 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
     def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
-        using_cache = exists(cache) and cache_key in cache
+        offset = cache.get('offset', 0) if exists(cache) else 0

         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

         if exists(rotary_pos_emb):
-            if using_cache:
-                rotary_pos_emb = rotary_pos_emb[..., n - 1:, :] # FIXME: Fix rotary index here
-            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
+            q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))

         q = q * self.scale

-        if using_cache:
+        if offset > 0:
             k_top, v_top = cache[cache_key]
             k = torch.cat([k_top, k], dim = -2)
             v = torch.cat([v_top, v], dim = -2)
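Aside (not part of this commit): a minimal, self-contained sketch of the cache layout the new code appears to assume, i.e. a single dict holding a running 'offset' plus cached (k, v) tensors stored per layer under cache_key. All names and shapes below are illustrative only.

import torch

# hypothetical cache state after two decoding steps of a model with one attention layer
cache = {
    'offset': 2,                                # number of tokens already processed
    'layer_0': (torch.randn(1, 8, 2, 64),       # cached keys   (b, heads, offset, dim_head)
                torch.randn(1, 8, 2, 64)),      # cached values (b, heads, offset, dim_head)
}

offset = cache.get('offset', 0)                 # mirrors the added line above
k_new = torch.randn(1, 8, 1, 64)                # key/value for the single new token
v_new = torch.randn(1, 8, 1, 64)
k_top, v_top = cache['layer_0']
assert k_top.shape[-2] == offset                # cached length matches the running offset
k = torch.cat([k_top, k_new], dim = -2)         # attention now covers offset + 1 positions
v = torch.cat([v_top, v_new], dim = -2)
print(k.shape)                                  # torch.Size([1, 8, 3, 64])

Slicing rotary_pos_emb[..., offset:, :] with the same offset is what gives the new token its absolute rotary position (index 2 here) instead of restarting at zero, which is what the removed FIXME line was approximating with n - 1.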
@@ -92,7 +90,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
             dots.masked_fill_(~mask, mask_value)
             del mask

-        if self.causal and not using_cache: # causality is naturally enforced if we run the cached inference
+        if self.causal and offset == 0: # causality is naturally enforced during cached inference
             i, j = dots.shape[-2:]
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, mask_value)
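Aside (not part of this commit): why skipping the causal mask is safe once offset > 0, assuming tokens are decoded one at a time — each forward pass then contributes a single new query that may attend to every cached position plus itself, so the triangular mask built in the branch above would be all False anyway. A small illustrative check:

import torch

i, j = 1, 5                                      # one new query over 5 keys (4 cached + itself)
mask = torch.ones(i, j).triu_(j - i + 1).bool()  # same construction as in the causal branch
print(mask)                                      # tensor([[False, False, False, False, False]])
# nothing would be masked, so the branch can be skipped entirely when offset > 0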