@@ -138,14 +138,16 @@ def update(
 
 
 class StaticAttentionMask:
-    def __init__(self, input_len, cache_len, style, mask_val=float("-inf")):
+    def __init__(
+        self, input_len, cache_len, style, mask_val=float("-inf"), dtype=torch.float32
+    ):
         self.input_len = input_len
         self.cache_len = cache_len
         assert style in ("shift_pointer", "smart_mask")
         self.style = style
         self.mask_val = mask_val
         self.unmasked_len = 0
-        self.tensor = torch.zeros(1, input_len, input_len + cache_len)
+        self.tensor = torch.zeros(1, input_len, input_len + cache_len, dtype=dtype)
         self.reset()
 
     def reset(self):
@@ -200,44 +202,45 @@ def __init__(
         config: ModelArgs,
         input_len: int,
         cache_len: int,
+        dtype=torch.float32,
         style: str = "shift_pointer",
         mask_val: float = float("-inf"),
     ):
         self.mask = StaticAttentionMask(
-            input_len, cache_len, style=style, mask_val=mask_val
+            input_len, cache_len, style=style, mask_val=mask_val, dtype=dtype
         )
 
         rope = Rope(config)
         freqs = rope.get_freqs(None, config.max_seq_len)
-        self.freqs_cos = freqs[0]
-        self.freqs_sin = freqs[1]
+        self.freqs_cos = freqs[0].to(dtype)
+        self.freqs_sin = freqs[1].to(dtype)
 
         split_mha = config.attention_type in ("static", "static_shas")
         if split_mha:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
-                    1, cache_len, config.head_dim
+                    1, cache_len, config.head_dim, dtype=dtype
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(config.n_kv_heads)
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
-                    1, cache_len, config.head_dim
+                    1, cache_len, config.head_dim, dtype=dtype
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(config.n_kv_heads)
             }
         else:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
-                    1, config.n_kv_heads, cache_len, config.head_dim
+                    1, config.n_kv_heads, cache_len, config.head_dim, dtype=dtype
                 )
                 for layer_id in range(config.n_layers)
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
-                    1, config.n_kv_heads, cache_len, config.head_dim
+                    1, config.n_kv_heads, cache_len, config.head_dim, dtype=dtype
                 )
                 for layer_id in range(config.n_layers)
             }
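
A minimal usage sketch, assuming StaticAttentionMask is importable from the module this diff modifies (the import path below is hypothetical): the new dtype argument allocates the mask buffer directly in the requested precision instead of the default torch.float32, e.g. half precision.

import torch

# Hypothetical import path; adjust to wherever StaticAttentionMask lives in your checkout.
from static_attention import StaticAttentionMask

# Allocate the attention mask in fp16; the IO manager threads the same dtype
# through its KV cache buffers and RoPE frequency tables.
mask = StaticAttentionMask(input_len=32, cache_len=512, style="smart_mask", dtype=torch.float16)
assert mask.tensor.dtype == torch.float16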