@@ -6788,7 +6788,7 @@ static struct ggml_tensor * llm_build_kqv(
6788
6788
6789
6789
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
6790
6790
6791
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
6791
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ) {
6792
6792
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
6793
6793
}
6794
6794
@@ -6797,7 +6797,7 @@ static struct ggml_tensor * llm_build_kqv(
6797
6797
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
6798
6798
cb(kq, "kq", il);
6799
6799
6800
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
6800
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ) {
6801
6801
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
6802
6802
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
6803
6803
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -17724,6 +17724,14 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_
17724
17724
ctx->cparams.n_threads_batch = n_threads_batch;
17725
17725
}
17726
17726
17727
+ uint32_t llama_n_threads(struct llama_context * ctx) {
17728
+ return ctx->cparams.n_threads;
17729
+ }
17730
+
17731
+ uint32_t llama_n_threads_batch(struct llama_context * ctx) {
17732
+ return ctx->cparams.n_threads_batch;
17733
+ }
17734
+
17727
17735
void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
17728
17736
ctx->abort_callback = abort_callback;
17729
17737
ctx->abort_callback_data = abort_callback_data;
0 commit comments