Commit 0ddc8e3 (parent 081fe43)

llama : move sampling code into llama-sampling

ggml-ci

File tree: 7 files changed, +758 −699 lines

Makefile

Lines changed: 9 additions & 0 deletions

@@ -876,6 +876,7 @@ OBJ_GGML += \
 
 OBJ_LLAMA = \
     src/llama.o \
+    src/llama-sampling.o \
     src/unicode.o \
     src/unicode-data.o
 
@@ -1055,6 +1056,7 @@ src/unicode-data.o: \
 
 src/llama.o: \
     src/llama.cpp \
+    src/llama-impl.h \
     src/unicode.h \
     include/llama.h \
     ggml/include/ggml-cuda.h \
@@ -1064,6 +1066,13 @@ src/llama.o: \
     ggml/include/ggml-backend.h
     $(CXX) $(CXXFLAGS) -c $< -o $@
 
+src/llama-sampling.o: \
+    src/llama-sampling.cpp \
+    src/llama-sampling.h \
+    src/llama-impl.h \
+    include/llama.h
+    $(CXX) $(CXXFLAGS) -c $< -o $@
+
 $(LIB_LLAMA): \
     $(OBJ_LLAMA) \
     $(LIB_GGML)
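The new rule lists the headers that src/llama-sampling.o depends on. The file itself is not shown in this excerpt, so the following is only a hypothetical skeleton of a translation unit consistent with those prerequisites; the function name and body are illustrative, not the commit's actual code:

// Hypothetical skeleton of src/llama-sampling.cpp, inferred solely from the
// Makefile prerequisites above.
#include "llama-sampling.h" // declarations for the relocated sampling code
#include "llama-impl.h"     // internal helpers: logging macros, LLAMA_MAX_* limits

#include <algorithm>

// Illustrative example of a sampler that would live here after the move:
// greedy sampling returns the candidate with the highest logit.
static llama_token sample_greedy(llama_token_data_array * candidates) {
    const auto * best = std::max_element(
        candidates->data, candidates->data + candidates->size,
        [](const llama_token_data & a, const llama_token_data & b) {
            return a.logit < b.logit;
        });
    return best->id;
}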

include/llama.h

Lines changed: 6 additions & 6 deletions

@@ -1084,12 +1084,6 @@ extern "C" {
           llama_token_data_array * candidates,
                            float   temp);
 
-    /// @details Apply constraints from grammar
-    LLAMA_API void llama_sample_grammar(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-      const struct llama_grammar * grammar);
-
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1127,6 +1121,12 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates);
 
+    /// @details Apply constraints from grammar
+    LLAMA_API void llama_sample_grammar(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+      const struct llama_grammar * grammar);
+
     /// @details Accepts the sampled token into the grammar
     LLAMA_API void llama_grammar_accept_token(
             struct llama_context * ctx,
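This hunk only moves the llama_sample_grammar declaration within the header; its signature and call pattern are unchanged. For context, a minimal hedged sketch of how the two grammar entry points are used together in a sampling loop. llama_n_vocab, llama_get_model, llama_get_logits_ith, and llama_sample_token are assumed from the rest of the public API, which this diff does not show:

// Hedged usage sketch (not from the diff): sample one grammar-constrained token.
#include <vector>

static llama_token sample_with_grammar(llama_context * ctx, llama_grammar * grammar) {
    const int     n_vocab = llama_n_vocab(llama_get_model(ctx)); // assumed API
    const float * logits  = llama_get_logits_ith(ctx, -1);       // assumed API

    // Build the candidate array from the last token's logits.
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; ++id) {
        candidates.push_back(llama_token_data{ id, logits[id], 0.0f });
    }
    llama_token_data_array cur = { candidates.data(), candidates.size(), false };

    llama_sample_grammar(ctx, &cur, grammar);       // suppress tokens the grammar forbids
    const llama_token tok = llama_sample_token(ctx, &cur);
    llama_grammar_accept_token(ctx, grammar, tok);  // advance the grammar state
    return tok;
}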

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ endif()
 add_library(llama
             ../include/llama.h
             llama.cpp
+            llama-sampling.cpp
             unicode.h
             unicode.cpp
             unicode-data.cpp

src/llama-impl.h

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+#pragma once
+
+#define LLAMA_API_INTERNAL
+#include "llama.h"
+
+#include <array>
+#include <set>
+#include <map>
+#include <cstdint>
+#include <random>
+
+#ifdef __has_include
+#if __has_include(<unistd.h>)
+#include <unistd.h>
+#if defined(_POSIX_MAPPED_FILES)
+#include <sys/mman.h>
+#include <fcntl.h>
+#endif
+#if defined(_POSIX_MEMLOCK_RANGE)
+#include <sys/resource.h>
+#endif
+#endif
+#endif
+
+// bump if necessary
+#define LLAMA_MAX_NODES   8192
+#define LLAMA_MAX_LAYERS  256
+#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
+
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define LLAMA_ATTRIBUTE_FORMAT(...)
+#endif
+
+//
+// logging
+//
+
+LLAMA_ATTRIBUTE_FORMAT(2, 3)
+void llama_log_internal        (ggml_log_level level, const char * format, ...);
+void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
+
+#define LLAMA_LOG_INFO(...)  llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+#define LLAMA_LOG_WARN(...)  llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
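LLAMA_ATTRIBUTE_FORMAT(2, 3) tells GCC/Clang to type-check printf-style calls to llama_log_internal: argument 2 is the format string and the variadic arguments start at position 3 (position 1 is the log level). A short illustrative example of how internal code uses these macros; the function and message text are made up for this sketch:

// Illustrative use of the logging macros from llama-impl.h.
static void example_check_layers(int n_layer) {
    if (n_layer > LLAMA_MAX_LAYERS) {
        // Thanks to the format attribute, the compiler warns here if the
        // "%d" specifiers do not match the argument types.
        LLAMA_LOG_ERROR("%s: n_layer = %d exceeds LLAMA_MAX_LAYERS (%d)\n",
                __func__, n_layer, LLAMA_MAX_LAYERS);
        return;
    }
    LLAMA_LOG_INFO("%s: n_layer = %d\n", __func__, n_layer);
}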
