Skip to content

Commit 9af4907

Browse files
committed
Merge branch 'upstream' into concedo_experimental
# Conflicts: # README.md # ci/run.sh
2 parents 6ac4025 + 8b94e79 commit 9af4907

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

llama.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6788,7 +6788,7 @@ static struct ggml_tensor * llm_build_kqv(
67886788

67896789
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
67906790

6791-
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
6791+
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
67926792
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
67936793
}
67946794

@@ -6797,7 +6797,7 @@ static struct ggml_tensor * llm_build_kqv(
67976797
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
67986798
cb(kq, "kq", il);
67996799

6800-
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
6800+
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
68016801
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
68026802
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
68036803
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -17724,6 +17724,14 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_
1772417724
ctx->cparams.n_threads_batch = n_threads_batch;
1772517725
}
1772617726

17727+
uint32_t llama_n_threads(struct llama_context * ctx) {
17728+
return ctx->cparams.n_threads;
17729+
}
17730+
17731+
uint32_t llama_n_threads_batch(struct llama_context * ctx) {
17732+
return ctx->cparams.n_threads_batch;
17733+
}
17734+
1772717735
void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
1772817736
ctx->abort_callback = abort_callback;
1772917737
ctx->abort_callback_data = abort_callback_data;

llama.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,12 @@ extern "C" {
761761
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
762762
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
763763

764+
// Get the number of threads used for generation of a single token.
765+
LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
766+
767+
// Get the number of threads used for prompt and batch processing (multiple token).
768+
LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
769+
764770
// Set whether to use causal attention or not
765771
// If set to true, the model will only attend to the past tokens
766772
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);

0 commit comments

Comments
 (0)