Commit fc5b21b

FanShupei authored and arthw committed
llama: use sliding window for phi3 (ggml-org#8627)
* use sliding window for phi3
* fix typo, "data_swa" -> "data"
* [convert_hf_to_gguf.py] add phi3 sliding window
1 parent 65e54b5 commit fc5b21b
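The change wires Phi-3 into the sliding-window KQ mask path that llama.cpp already builds (build_inp_KQ_mask_swa / inp_KQ_mask_swa) and carries the window size through the GGUF metadata into hparams.n_swa. As a minimal standalone sketch, assuming the same cut-off rule used by the existing data_swa path (a KV position is masked once it falls n_swa or more tokens behind the query position), the per-cell mask value looks like this:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Sketch of the per-cell mask value: causal masking plus a sliding-window
// cut-off. 0.0f means "attend", -INFINITY means "masked out".
// n_swa is the window size (hparams.n_swa, read from the GGUF metadata).
static float kq_mask_value(int32_t pos_q, int32_t pos_kv, int32_t n_swa) {
    if (pos_kv > pos_q) {
        return -INFINITY;   // causal: no attention to future tokens
    }
    if (n_swa > 0 && pos_q - pos_kv >= n_swa) {
        return -INFINITY;   // sliding window: token is too far in the past
    }
    return 0.0f;            // inside the window: fully visible
}

int main() {
    // toy example: query at position 10, window of 4 -> only positions 7..10 are visible
    for (int32_t pos_kv = 0; pos_kv <= 10; ++pos_kv) {
        printf("kv %2d: %s\n", pos_kv, kq_mask_value(10, pos_kv, 4) == 0.0f ? "visible" : "masked");
    }
    return 0;
}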

File tree: 2 files changed (+29, -9 lines)

convert_hf_to_gguf.py

Lines changed: 1 addition & 0 deletions
@@ -2084,6 +2084,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_dimension_count(rope_dims)
         self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
         self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_sliding_window(self.find_hparam(["sliding_window"]))

         # write rope scaling for long context (128k) model
         rope_scaling = self.find_hparam(['rope_scaling'], True)
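The converter now copies the model's sliding_window hyperparameter from the HF config into the GGUF header, where the loader change below picks it up as hparams.n_swa. For a quick check that a converted file actually carries the value, here is a hedged sketch using ggml's gguf API; the key name phi3.attention.sliding_window is an assumption based on the usual <arch>.attention.sliding_window naming, so verify it against gguf_get_key on a real file:

#include <cstdio>
#include "ggml.h"   // declares the gguf_* API at this point in the tree; newer ggml splits it into gguf.h

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }

    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    struct gguf_context * ctx = gguf_init_from_file(argv[1], params);
    if (!ctx) {
        fprintf(stderr, "failed to open %s\n", argv[1]);
        return 1;
    }

    // assumed key name, following the "<arch>.attention.sliding_window" pattern
    const int key = gguf_find_key(ctx, "phi3.attention.sliding_window");
    if (key >= 0) {
        printf("sliding window: %u tokens\n", gguf_get_val_u32(ctx, key));
    } else {
        printf("no sliding window metadata found\n");
    }

    gguf_free(ctx);
    return 0;
}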

src/llama.cpp

Lines changed: 28 additions & 9 deletions
@@ -4890,6 +4890,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_PHI3:
             {
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                 switch (hparams.n_layer) {
@@ -10749,7 +10750,7 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();

         for (int il = 0; il < n_layer; ++il) {
             auto residual = inpL;
@@ -10807,7 +10808,7 @@ struct llm_build_context {

                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }

             if (il == n_layer - 1) {
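Swapping KQ_mask for KQ_mask_swa is the only graph-level change for Phi-3: llm_build_kv applies whichever mask tensor it is handed by adding it to the scaled QK^T scores before the softmax, so a -INFINITY entry drives that attention weight to zero. A small numerical illustration of the effect (plain C++, not the ggml kernel):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Masked softmax over one row of attention scores: entries whose mask is
// -INFINITY end up with weight exactly 0, as if the KV cell were absent.
static std::vector<float> masked_softmax(const std::vector<float> & scores,
                                         const std::vector<float> & mask) {
    std::vector<float> out(scores.size());
    float max_val = -INFINITY;
    for (size_t i = 0; i < scores.size(); ++i) {
        out[i] = scores[i] + mask[i];          // -INFINITY knocks the entry out
        max_val = std::max(max_val, out[i]);
    }
    float sum = 0.0f;
    for (float & v : out) {
        v = std::exp(v - max_val);             // exp(-INFINITY - max) == 0
        sum += v;
    }
    for (float & v : out) {
        v /= sum;
    }
    return out;
}

int main() {
    // four KV cells, the two oldest masked out by the sliding window
    const std::vector<float> scores = {1.0f, 2.0f, 3.0f, 4.0f};
    const std::vector<float> mask   = {-INFINITY, -INFINITY, 0.0f, 0.0f};
    for (float w : masked_softmax(scores, mask)) {
        printf("%.3f ", w);
    }
    printf("\n");  // 0.000 0.000 0.269 0.731
    return 0;
}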
@@ -14014,18 +14015,23 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         "causal attention is not supported by this model"
     );

-    if (lctx.inp_KQ_mask) {
+    if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn && !lctx.is_encoding) {
             const int64_t n_kv = kv_self.n;
             const int64_t n_tokens = batch.n_tokens;

-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-            float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = nullptr;
             float * data_swa = nullptr;

+            if (lctx.inp_KQ_mask) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+                data = (float *) lctx.inp_KQ_mask->data;
+            }
+
             if (lctx.inp_KQ_mask_swa) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
                 data_swa = (float *) lctx.inp_KQ_mask_swa->data;
             }

@@ -14048,7 +14054,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                                 f = 0.0f;
                             }
                         }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }

                         // may need to cut off old tokens for sliding window
                         if (data_swa) {
@@ -14060,9 +14069,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                     }
                 }

-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                if (data) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
+                    }
+                }
+
+                if (data_swa) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
                     }
                 }
             }
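Taken together, the llama_set_inputs changes let either mask buffer be absent: a model may use only the regular causal mask, only the sliding-window mask (as Phi-3 now does), or both, so every write is guarded and the padding rows past n_tokens are set to -INFINITY in whichever buffers exist. Below is a self-contained sketch of that fill logic under simplifying assumptions (single sequence, KV cell i holding position i, and a plain round-up standing in for GGML_PAD with GGML_KQ_MASK_PAD):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch of the guarded mask fill, mirroring the structure of llama_set_inputs
// in this commit. Assumptions: one sequence, KV cell i holds position i, and the
// token rows are padded up to a multiple of `pad` (GGML_KQ_MASK_PAD in llama.cpp).
static void fill_kq_masks(
        std::vector<float> * data,        // regular causal mask, or nullptr
        std::vector<float> * data_swa,    // sliding-window mask, or nullptr
        int64_t n_tokens, int64_t n_kv,
        const std::vector<int32_t> & pos, // position of each batch token
        int32_t n_swa, int64_t pad) {
    const int64_t n_tokens_pad = ((n_tokens + pad - 1)/pad)*pad;  // like GGML_PAD(n_tokens, pad)

    if (data)     { data->assign(n_kv*n_tokens_pad, 0.0f); }
    if (data_swa) { data_swa->assign(n_kv*n_tokens_pad, 0.0f); }

    for (int64_t j = 0; j < n_tokens; ++j) {
        for (int64_t i = 0; i < n_kv; ++i) {
            // causal rule: KV cells positioned after the query token are masked
            float f = (int32_t) i > pos[j] ? -INFINITY : 0.0f;

            if (data) {
                (*data)[j*n_kv + i] = f;
            }

            // sliding window: additionally cut off cells that are too far in the past
            if (data_swa) {
                if (pos[j] - (int32_t) i >= n_swa) {
                    f = -INFINITY;
                }
                (*data_swa)[j*n_kv + i] = f;
            }
        }
    }

    // rows that exist only because of padding attend to nothing
    for (int64_t j = n_tokens; j < n_tokens_pad; ++j) {
        for (int64_t i = 0; i < n_kv; ++i) {
            if (data)     { (*data)[j*n_kv + i]     = -INFINITY; }
            if (data_swa) { (*data_swa)[j*n_kv + i] = -INFINITY; }
        }
    }
}

int main() {
    std::vector<float> data, data_swa;
    const std::vector<int32_t> pos = {0, 1, 2, 3, 4, 5, 6, 7};  // 8 prompt tokens

    fill_kq_masks(&data, &data_swa, /*n_tokens=*/8, /*n_kv=*/8, pos, /*n_swa=*/4, /*pad=*/32);

    // token 7 sees all 8 cells in the regular mask, but only cells 4..7 in the SWA mask
    for (int i = 0; i < 8; ++i) {
        printf("cell %d: data=%-7s data_swa=%s\n", i,
               data[7*8 + i]     == 0.0f ? "visible" : "masked",
               data_swa[7*8 + i] == 0.0f ? "visible" : "masked");
    }
    return 0;
}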
