Skip to content

Commit 4744ff5

Browse files
authored
Support dtype option in StaticAttentionIOManager
Differential Revision: D78494128. Pull Request resolved: #12647.
1 parent 04c1a07 commit 4744ff5

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

examples/models/llama/static_attention.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,16 @@ def update(
138138

139139

140140
class StaticAttentionMask:
141-
def __init__(self, input_len, cache_len, style, mask_val=float("-inf")):
141+
def __init__(
142+
self, input_len, cache_len, style, mask_val=float("-inf"), dtype=torch.float32
143+
):
142144
self.input_len = input_len
143145
self.cache_len = cache_len
144146
assert style in ("shift_pointer", "smart_mask")
145147
self.style = style
146148
self.mask_val = mask_val
147149
self.unmasked_len = 0
148-
self.tensor = torch.zeros(1, input_len, input_len + cache_len)
150+
self.tensor = torch.zeros(1, input_len, input_len + cache_len, dtype=dtype)
149151
self.reset()
150152

151153
def reset(self):
@@ -200,44 +202,45 @@ def __init__(
200202
config: ModelArgs,
201203
input_len: int,
202204
cache_len: int,
205+
dtype=torch.float32,
203206
style: str = "shift_pointer",
204207
mask_val: float = float("-inf"),
205208
):
206209
self.mask = StaticAttentionMask(
207-
input_len, cache_len, style=style, mask_val=mask_val
210+
input_len, cache_len, style=style, mask_val=mask_val, dtype=dtype
208211
)
209212

210213
rope = Rope(config)
211214
freqs = rope.get_freqs(None, config.max_seq_len)
212-
self.freqs_cos = freqs[0]
213-
self.freqs_sin = freqs[1]
215+
self.freqs_cos = freqs[0].to(dtype)
216+
self.freqs_sin = freqs[1].to(dtype)
214217

215218
split_mha = config.attention_type in ("static", "static_shas")
216219
if split_mha:
217220
self.k_caches = {
218221
StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
219-
1, cache_len, config.head_dim
222+
1, cache_len, config.head_dim, dtype=dtype
220223
)
221224
for layer_id in range(config.n_layers)
222225
for head_id in range(config.n_kv_heads)
223226
}
224227
self.v_caches = {
225228
StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
226-
1, cache_len, config.head_dim
229+
1, cache_len, config.head_dim, dtype=dtype
227230
)
228231
for layer_id in range(config.n_layers)
229232
for head_id in range(config.n_kv_heads)
230233
}
231234
else:
232235
self.k_caches = {
233236
StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
234-
1, config.n_kv_heads, cache_len, config.head_dim
237+
1, config.n_kv_heads, cache_len, config.head_dim, dtype=dtype
235238
)
236239
for layer_id in range(config.n_layers)
237240
}
238241
self.v_caches = {
239242
StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
240-
1, config.n_kv_heads, cache_len, config.head_dim
243+
1, config.n_kv_heads, cache_len, config.head_dim, dtype=dtype
241244
)
242245
for layer_id in range(config.n_layers)
243246
}

0 commit comments

Comments (0)