Skip to content

Commit c4c3102

Browse files
committed
new features and fixes #1
- gpt params wrapper - inference aka sampler params - get embed to extract features - decode function - streaming prompt
1 parent dc09242 commit c4c3102

File tree

7 files changed

+504
-13
lines changed

7 files changed

+504
-13
lines changed

lib/float_array.dart

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import 'dart:ffi';
2+
3+
import 'package:ffi/ffi.dart';
4+
5+
import 'llama_cpp.dart';
6+
7+
class FloatArray {
8+
late llm_float_array native;
9+
10+
FloatArray();
11+
12+
llm_float_array get() => native;
13+
14+
factory FloatArray.fromNative(llm_float_array native) {
15+
FloatArray output = FloatArray();
16+
output.native = native;
17+
return output;
18+
}
19+
20+
List<double> get data {
21+
return List<double>.generate(
22+
native.size, (i) => native.data.elementAt(i).value);
23+
}
24+
25+
void dispose() {
26+
calloc.free(native.data);
27+
}
28+
}

lib/gpt_params.dart

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
import 'dart:ffi';
2+
3+
import 'package:ffi/ffi.dart';
4+
5+
import 'inference_parameters.dart';
6+
import 'llama_cpp.dart';
7+
8+
class GptParams {
9+
late llama_cpp _lib;
10+
late llm_gpt_params _parameters;
11+
late InferenceParameters _inferenceParameters;
12+
13+
GptParams() {
14+
_lib = llama_cpp(DynamicLibrary.process());
15+
_parameters = _lib.llm_create_gpt_params();
16+
}
17+
18+
llm_gpt_params get() => _parameters;
19+
20+
InferenceParameters get inferenceParameters => _inferenceParameters;
21+
set cfgNegativePrompt(InferenceParameters inferenceParameters) {
22+
_inferenceParameters = inferenceParameters;
23+
_parameters.sparams = inferenceParameters.get();
24+
}
25+
26+
int get seed => _parameters.seed;
27+
set seed(int value) => _parameters.seed = value;
28+
29+
int get nThreads => _parameters.n_threads;
30+
set nThreads(int value) => _parameters.n_threads = value;
31+
32+
int get nThreadsBatch => _parameters.n_threads_batch;
33+
set nThreadsBatch(int value) => _parameters.n_threads_batch = value;
34+
35+
int get nPredict => _parameters.n_predict;
36+
set nPredict(int value) => _parameters.n_predict = value;
37+
38+
int get nCtx => _parameters.n_ctx;
39+
set nCtx(int value) => _parameters.n_ctx = value;
40+
41+
int get nBatch => _parameters.n_batch;
42+
set nBatch(int value) => _parameters.n_batch = value;
43+
44+
int get nKeep => _parameters.n_keep;
45+
set nKeep(int value) => _parameters.n_keep = value;
46+
47+
int get nDraft => _parameters.n_draft;
48+
set nDraft(int value) => _parameters.n_draft = value;
49+
50+
int get nChunks => _parameters.n_chunks;
51+
set nChunks(int value) => _parameters.n_chunks = value;
52+
53+
int get nParallel => _parameters.n_parallel;
54+
set nParallel(int value) => _parameters.n_parallel = value;
55+
56+
int get nSequences => _parameters.n_sequences;
57+
set nSequences(int value) => _parameters.n_sequences = value;
58+
59+
double get pAccept => _parameters.p_accept;
60+
set pAccept(double value) => _parameters.p_accept = value;
61+
62+
double get pSplit => _parameters.p_split;
63+
set pSplit(double value) => _parameters.p_split = value;
64+
65+
int get nGpuLayers => _parameters.n_gpu_layers;
66+
set nGpuLayers(int value) => _parameters.n_gpu_layers = value;
67+
68+
int get nGpuLayersDraft => _parameters.n_gpu_layers_draft;
69+
set nGpuLayersDraft(int value) => _parameters.n_gpu_layers_draft = value;
70+
71+
int get mainGpu => _parameters.main_gpu;
72+
set mainGpu(int value) => _parameters.main_gpu = value;
73+
74+
List<double> get tensorSplit =>
75+
List<double>.generate(16, (i) => _parameters.tensor_split[i]);
76+
set tensorSplit(List<double> values) {
77+
for (int i = 0; i < 16; i++) {
78+
_parameters.tensor_split[i] = values[i];
79+
}
80+
}
81+
82+
int get nBeams => _parameters.n_beams;
83+
set nBeams(int value) => _parameters.n_beams = value;
84+
85+
double get ropeFreqBase => _parameters.rope_freq_base;
86+
set ropeFreqBase(double value) => _parameters.rope_freq_base = value;
87+
88+
double get ropeFreqScale => _parameters.rope_freq_scale;
89+
set ropeFreqScale(double value) => _parameters.rope_freq_scale = value;
90+
91+
double get yarnExtFactor => _parameters.yarn_ext_factor;
92+
set yarnExtFactor(double value) => _parameters.yarn_ext_factor = value;
93+
94+
double get yarnAttnFactor => _parameters.yarn_attn_factor;
95+
set yarnAttnFactor(double value) => _parameters.yarn_attn_factor = value;
96+
97+
double get yarnBetaFast => _parameters.yarn_beta_fast;
98+
set yarnBetaFast(double value) => _parameters.yarn_beta_fast = value;
99+
100+
double get yarnBetaSlow => _parameters.yarn_beta_slow;
101+
set yarnBetaSlow(double value) => _parameters.yarn_beta_slow = value;
102+
103+
int get yarnOrigCtx => _parameters.yarn_orig_ctx;
104+
set yarnOrigCtx(int value) => _parameters.yarn_orig_ctx = value;
105+
106+
int get ropeScalingType => _parameters.rope_scaling_type;
107+
set ropeScalingType(int value) => _parameters.rope_scaling_type = value;
108+
109+
String get model => _parameters.model.cast<Utf8>().toDartString();
110+
set model(String value) =>
111+
_parameters.model = value.toNativeUtf8().cast<Char>();
112+
113+
String get modelDraft => _parameters.model_draft.cast<Utf8>().toDartString();
114+
set modelDraft(String value) =>
115+
_parameters.model_draft = value.toNativeUtf8().cast<Char>();
116+
117+
String get modelAlias => _parameters.model_alias.cast<Utf8>().toDartString();
118+
set modelAlias(String value) =>
119+
_parameters.model_alias = value.toNativeUtf8().cast<Char>();
120+
121+
String get prompt => _parameters.prompt.cast<Utf8>().toDartString();
122+
set prompt(String value) =>
123+
_parameters.prompt = value.toNativeUtf8().cast<Char>();
124+
125+
String get promptFile => _parameters.prompt_file.cast<Utf8>().toDartString();
126+
set promptFile(String value) =>
127+
_parameters.prompt_file = value.toNativeUtf8().cast<Char>();
128+
129+
String get pathPromptCache =>
130+
_parameters.path_prompt_cache.cast<Utf8>().toDartString();
131+
set pathPromptCache(String value) =>
132+
_parameters.path_prompt_cache = value.toNativeUtf8().cast<Char>();
133+
134+
String get inputPrefix =>
135+
_parameters.input_prefix.cast<Utf8>().toDartString();
136+
set inputPrefix(String value) =>
137+
_parameters.input_prefix = value.toNativeUtf8().cast<Char>();
138+
139+
String get inputSuffix =>
140+
_parameters.input_suffix.cast<Utf8>().toDartString();
141+
set inputSuffix(String value) =>
142+
_parameters.input_suffix = value.toNativeUtf8().cast<Char>();
143+
144+
String get logdir => _parameters.logdir.cast<Utf8>().toDartString();
145+
set logdir(String value) =>
146+
_parameters.logdir = value.toNativeUtf8().cast<Char>();
147+
148+
String get loraBase => _parameters.lora_base.cast<Utf8>().toDartString();
149+
set loraBase(String value) =>
150+
_parameters.lora_base = value.toNativeUtf8().cast<Char>();
151+
152+
int get pplStride => _parameters.ppl_stride;
153+
set pplStride(int value) => _parameters.ppl_stride = value;
154+
155+
int get pplOutputType => _parameters.ppl_output_type;
156+
set pplOutputType(int value) => _parameters.ppl_output_type = value;
157+
158+
int get hellaswag => _parameters.hellaswag;
159+
set hellaswag(int value) => _parameters.hellaswag = value;
160+
161+
int get hellaswagTasks => _parameters.hellaswag_tasks;
162+
set hellaswagTasks(int value) => _parameters.hellaswag_tasks = value;
163+
164+
int get mulMatQ => _parameters.mul_mat_q;
165+
set mulMatQ(int value) => _parameters.mul_mat_q = value;
166+
167+
int get randomPrompt => _parameters.random_prompt;
168+
set randomPrompt(int value) => _parameters.random_prompt = value;
169+
170+
int get useColor => _parameters.use_color;
171+
set useColor(int value) => _parameters.use_color = value;
172+
173+
int get interactive => _parameters.interactive;
174+
set interactive(int value) => _parameters.interactive = value;
175+
176+
int get chatml => _parameters.chatml;
177+
set chatml(int value) => _parameters.chatml = value;
178+
179+
int get promptCacheAll => _parameters.prompt_cache_all;
180+
set promptCacheAll(int value) => _parameters.prompt_cache_all = value;
181+
182+
int get promptCacheRo => _parameters.prompt_cache_ro;
183+
set promptCacheRo(int value) => _parameters.prompt_cache_ro = value;
184+
185+
int get embedding => _parameters.embedding;
186+
set embedding(int value) => _parameters.embedding = value;
187+
188+
int get escape => _parameters.escape;
189+
set escape(int value) => _parameters.escape = value;
190+
191+
int get interactiveFirst => _parameters.interactive_first;
192+
set interactiveFirst(int value) => _parameters.interactive_first = value;
193+
194+
int get multilineInput => _parameters.multiline_input;
195+
set multilineInput(int value) => _parameters.multiline_input = value;
196+
197+
int get simpleIo => _parameters.simple_io;
198+
set simpleIo(int value) => _parameters.simple_io = value;
199+
200+
int get contBatching => _parameters.cont_batching;
201+
set contBatching(int value) => _parameters.cont_batching = value;
202+
203+
int get inputPrefixBos => _parameters.input_prefix_bos;
204+
set inputPrefixBos(int value) => _parameters.input_prefix_bos = value;
205+
206+
int get ignoreEos => _parameters.ignore_eos;
207+
set ignoreEos(int value) => _parameters.ignore_eos = value;
208+
209+
int get instruct => _parameters.instruct;
210+
set instruct(int value) => _parameters.instruct = value;
211+
212+
int get logitsAll => _parameters.logits_all;
213+
set logitsAll(int value) => _parameters.logits_all = value;
214+
215+
int get useMmap => _parameters.use_mmap;
216+
set useMmap(int value) => _parameters.use_mmap = value;
217+
218+
int get useMlock => _parameters.use_mlock;
219+
set useMlock(int value) => _parameters.use_mlock = value;
220+
221+
int get numa => _parameters.numa;
222+
set numa(int value) => _parameters.numa = value;
223+
224+
int get verbosePrompt => _parameters.verbose_prompt;
225+
set verbosePrompt(int value) => _parameters.verbose_prompt = value;
226+
227+
int get infill => _parameters.infill;
228+
set infill(int value) => _parameters.infill = value;
229+
230+
int get dumpKvCache => _parameters.dump_kv_cache;
231+
set dumpKvCache(int value) => _parameters.dump_kv_cache = value;
232+
233+
int get noKvOffload => _parameters.no_kv_offload;
234+
set noKvOffload(int value) => _parameters.no_kv_offload = value;
235+
236+
List<int> get cacheTypeK =>
237+
List<int>.generate(4, (i) => _parameters.cache_type_k[i]);
238+
set cacheTypeK(List<int> values) {
239+
for (int i = 0; i < 4; i++) {
240+
_parameters.cache_type_k[i] = values[i];
241+
}
242+
}
243+
244+
List<int> get cacheTypeV =>
245+
List<int>.generate(4, (i) => _parameters.cache_type_v[i]);
246+
set cacheTypeV(List<int> values) {
247+
for (int i = 0; i < 4; i++) {
248+
_parameters.cache_type_v[i] = values[i];
249+
}
250+
}
251+
252+
String get mmproj => _parameters.mmproj.cast<Utf8>().toDartString();
253+
set mmproj(String value) =>
254+
_parameters.mmproj = value.toNativeUtf8().cast<Char>();
255+
256+
String get image => _parameters.image.cast<Utf8>().toDartString();
257+
set image(String value) =>
258+
_parameters.image = value.toNativeUtf8().cast<Char>();
259+
260+
void dispose() {
261+
_inferenceParameters.dispose();
262+
263+
calloc.free(_parameters.model);
264+
calloc.free(_parameters.model_draft);
265+
calloc.free(_parameters.model_alias);
266+
calloc.free(_parameters.prompt);
267+
calloc.free(_parameters.prompt_file);
268+
calloc.free(_parameters.path_prompt_cache);
269+
calloc.free(_parameters.input_prefix);
270+
calloc.free(_parameters.input_suffix);
271+
calloc.free(_parameters.logdir);
272+
calloc.free(_parameters.lora_base);
273+
}
274+
}

0 commit comments

Comments
 (0)