
Commit f8ecd12 (1 parent: aa7dce1)

Extension method for setting llama_context_params from ContextParams

2 files changed (+101 -33)


lib/src/ensemble_llama_base.dart (+19 -1)
@@ -12,14 +12,18 @@ void main() async {
     }
   });

+  final params = ContextParams(gpuLayers: 1, useMmap: false);
   final model = await llama.loadModel(
     "/Users/vczf/models/default/ggml-model-f16.gguf",
-    params: ContextParams(gpuLayers: 1, useMmap: false),
+    params: params,
     progressCallback: (p) => stdout.write("."),
   );

   print(model.rawPointer);

+  // final ctx = await llama.newContext(model, params);
+  // await llama.freeContext(ctx);
+
   await llama.freeModel(model);
   llama.dispose();
 }
@@ -89,4 +93,18 @@ class Llama {
     _controlPort.send(ctl);
     await _response.firstWhere((e) => e is FreeModelResp && e.id == ctl.id);
   }
+
+  // Future<Context> newContext(Model model, ContextParams params) async {
+  //   final ctl = NewContextCtl(model, params);
+  //   _controlPort.send(ctl);
+  //   final resp =
+  //       await _response.firstWhere((e) => e is NewContextResp && e.id == ctl.id)
+  //           as NewContextResp;
+  // }
+
+  // Future<void> freeContext(Context ctx) async {
+  //   final ctl = FreeContextCtl(ctx);
+  //   _controlPort.send(ctl);
+  //   await _response.firstWhere((e) => e is FreeContextResp && e.id == ctl.id);
+  // }
 }
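
Note: the commented-out newContext/freeContext stubs follow the same control-port protocol as loadModel/freeModel: send a control message, then await the response whose id matches. A minimal, self-contained sketch of that id-matching pattern (Request/Response here are hypothetical stand-ins, not the types in this repo):

import 'dart:isolate';

class Request {
  static int _nextId = 0;
  final int id = _nextId++;
  final String payload;
  Request(this.payload);
}

class Response {
  final int id;
  final String payload;
  Response(this.id, this.payload);
}

void _worker(SendPort replyTo) {
  final ctlPort = ReceivePort();
  replyTo.send(ctlPort.sendPort);
  ctlPort.listen((msg) {
    final req = msg as Request;
    replyTo.send(Response(req.id, "echo: ${req.payload}"));
  });
}

void main() async {
  final fromWorker = ReceivePort();
  await Isolate.spawn(_worker, fromWorker.sendPort);

  // The first event is the worker's control port; later events are responses.
  final events = fromWorker.asBroadcastStream();
  final ctl = await events.first as SendPort;

  final req = Request("hello");
  ctl.send(req);
  // Matching on id keeps concurrent requests from stealing each other's replies.
  final resp =
      await events.firstWhere((e) => e is Response && e.id == req.id) as Response;
  print(resp.payload); // echo: hello
}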

lib/src/llama_cpp_isolate_wrapper.dart (+82 -32)
@@ -9,6 +9,33 @@ import 'package:ensemble_llama/ensemble_llama_cpp.dart';
 // -1 (32 bit signed)
 const _int32Max = 0xFFFFFFFF;

+extension on llama_context_params {
+  // Sets most of the context parameters, such as int, double, bool.
+  // Does not set callbacks or pointers to allocated memory.
+  void setSimpleFrom(ContextParams p) {
+    seed = p.seed;
+    n_ctx = p.contextSizeTokens;
+    n_batch = p.batchSizeTokens;
+    n_gpu_layers = p.gpuLayers;
+    main_gpu = p.cudaMainGpu;
+
+    // Skipping: tensor_split
+
+    rope_freq_base = p.ropeFreqBase;
+    rope_freq_scale = p.ropeFreqScale;
+
+    // Skipping: progress_callback{,_user_data}
+
+    low_vram = p.useLessVram;
+    mul_mat_q = p.cudaUseMulMatQ;
+    f16_kv = p.useFloat16KVCache;
+    logits_all = p.calculateAllLogits;
+    vocab_only = p.loadOnlyVocabSkipTensors;
+    use_mmap = p.useMmap;
+    use_mlock = p.useMlock;
+    embedding = p.willUseEmbedding;
+  }
+}
 class ContextParams {
   final int seed;
   final int contextSizeTokens;
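
Note: the point of an extension (rather than a constructor or free function) is that llama_context_default_params() returns a struct whose defaults should survive for anything ContextParams doesn't model, so setSimpleFrom copies only the scalar fields and a cascade applies it in one expression. A minimal sketch of the pattern, with made-up NativeParams/FriendlyParams classes standing in for the FFI-generated struct and its Dart-friendly counterpart:

// Stand-in for the generated llama_context_params struct.
class NativeParams {
  int seed = 0xFFFFFFFF;
  bool useMmap = true;
  // Callback/pointer fields would also live here, deliberately untouched.
}

class FriendlyParams {
  final int seed;
  final bool useMmap;
  const FriendlyParams({this.seed = 0xFFFFFFFF, this.useMmap = true});
}

extension on NativeParams {
  // Mirrors setSimpleFrom above: scalars only, no callbacks or pointers.
  void setSimpleFrom(FriendlyParams p) {
    seed = p.seed;
    useMmap = p.useMmap;
  }
}

NativeParams defaultParams() => NativeParams();

void main() {
  // The cascade mutates the freshly returned defaults in place and still
  // evaluates to the receiver, as in the LoadModelCtl handler below.
  final pc = defaultParams()..setSimpleFrom(const FriendlyParams(useMmap: false));
  print("${pc.seed} ${pc.useMmap}"); // 4294967295 false
}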
@@ -97,6 +124,23 @@ class FreeModelCtl extends ControlMessage {
   FreeModelResp done() => FreeModelResp(id);
 }

+class NewContextCtl extends ControlMessage {
+  final Model model;
+  final ContextParams params;
+  NewContextCtl(this.model, this.params);
+
+  NewContextResp done(Context ctx) => NewContextResp(id, ctx: ctx);
+
+  NewContextResp error(Object err) => NewContextResp(id, err: err);
+}
+
+class FreeContextCtl extends ControlMessage {
+  final Context ctx;
+  FreeContextCtl(this.ctx);
+
+  FreeContextResp done() => FreeContextResp(id);
+}
+
 sealed class ResponseMessage {
   final int id;
   final Object? err;
@@ -132,6 +176,15 @@ class FreeModelResp extends ResponseMessage {
   const FreeModelResp(super.id);
 }

+class NewContextResp extends ResponseMessage {
+  final Context? ctx;
+  const NewContextResp(super.id, {super.err, this.ctx});
+}
+
+class FreeContextResp extends ResponseMessage {
+  const FreeContextResp(super.id);
+}
+
 class EntryArgs {
   final SendPort log, response;
   const EntryArgs({required this.log, required this.response});
@@ -142,20 +195,28 @@ class Model {
   const Model._(this.rawPointer);
 }

-class _Allocations {
-  final Map<int, Set<Pointer>> _map = {};
+class Context {
+  final int rawPointer;
+  const Context._(this.rawPointer);
+}
+
+class _Allocations<E> {
+  final Map<E, Set<Pointer>> _map = {};

-  Set<Pointer>? get(int id) => _map[id];
+  Set<Pointer>? operator [](E key) => _map[key];
+  void operator []=(E key, Set<Pointer> allocs) => _map[key] = allocs;

-  void remove(int id) => _map.remove(id);
+  void clear(E key) => _map.remove(key);

-  void add(int id, Pointer p) {
-    _map[id] ??= {}..add(p);
-    _map[id]!.add(p);
+  void add(E key, Pointer p) {
+    _map[key] ??= {}..add(p);
+    _map[key]!.add(p);
   }
 }

-final _Allocations _allocs = _Allocations();
+// key: rawModelPointer
+final _modelAllocs = _Allocations<int>();
+final _ctxAllocs = _Allocations<int>();

 late final SendPort _log;
 late final SendPort _response;
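
Note: making _Allocations generic over the key type, with index operators replacing get/remove, lets one class track native allocations per model and per context, keyed by raw pointer address instead of control-message id. (In add, the `{}..add(p)` on the ??= path makes the following `_map[key]!.add(p)` redundant on a fresh entry, but it's harmless for a Set.) Roughly how the model-side instance gets used, as a runnable sketch; Allocations is a local copy of the class and fakeModelAddress is a made-up key:

import 'dart:ffi';
import 'package:ffi/ffi.dart';

// Same shape as the generic _Allocations<E> in this commit.
class Allocations<E> {
  final Map<E, Set<Pointer>> _map = {};

  Set<Pointer>? operator [](E key) => _map[key];
  void operator []=(E key, Set<Pointer> allocs) => _map[key] = allocs;
  void clear(E key) => _map.remove(key);
  void add(E key, Pointer p) => (_map[key] ??= {}).add(p);
}

void main() {
  final modelAllocs = Allocations<int>(); // keyed by raw model pointer

  // Pretend this address came back from llama_load_model_from_file.
  final fakeModelAddress = 0x1000;

  final idPointer = calloc.allocate<Uint32>(sizeOf<Uint32>());
  modelAllocs.add(fakeModelAddress, idPointer);

  // On FreeModelCtl: free everything recorded under that model, then forget it.
  modelAllocs[fakeModelAddress]?.forEach(calloc.free);
  modelAllocs.clear(fakeModelAddress);
}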
@@ -192,36 +253,19 @@ void _onControl(ControlMessage ctl) {
       _response.send(ctl.done());

     case LoadModelCtl():
+      final Set<Pointer> allocs = {};
       final pd = ctl.ctxParams;
-      final pc = libllama.llama_context_default_params();
-
-      pc.seed = pd.seed;
-      pc.n_ctx = pd.contextSizeTokens;
-      pc.n_batch = pd.batchSizeTokens;
-      pc.n_gpu_layers = pd.gpuLayers;
-      pc.main_gpu = pd.cudaMainGpu;
+      final pc = libllama.llama_context_default_params()..setSimpleFrom(pd);

       // TODO: can't do this until we track contexts to manage memory allocation
       // pc.tensor_split

-      pc.rope_freq_base = pd.ropeFreqBase;
-      pc.rope_freq_scale = pd.ropeFreqScale;
-
       pc.progress_callback = Pointer.fromFunction(_onModelLoadProgress);
       final idPointer = calloc.allocate<Uint32>(sizeOf<Uint32>());
-      _allocs.add(ctl.id, idPointer);
+      allocs.add(idPointer);
       idPointer.value = ctl.id;
       pc.progress_callback_user_data = idPointer.cast<Void>();

-      pc.low_vram = pd.useLessVram;
-      pc.mul_mat_q = pd.cudaUseMulMatQ;
-      pc.f16_kv = pd.useFloat16KVCache;
-      pc.logits_all = pd.calculateAllLogits;
-      pc.vocab_only = pd.loadOnlyVocabSkipTensors;
-      pc.use_mmap = pd.useMmap;
-      pc.use_mlock = pd.useMlock;
-      pc.embedding = pd.willUseEmbedding;
-
       final rawModelPointer = libllama
           .llama_load_model_from_file(
             ctl.path.toNativeUtf8().cast<Char>(),
@@ -236,19 +280,25 @@ void _onControl(ControlMessage ctl) {
         return;
       }

+      _modelAllocs[rawModelPointer] = allocs;
       _response.send(ctl.done(Model._(rawModelPointer)));

     case FreeModelCtl():
       assert(ctl.model.rawPointer != 0);
-      _allocs.get(ctl.id)?.forEach((p) {
+      _modelAllocs[ctl.model.rawPointer]?.forEach((p) {
         calloc.free(p);
       });
-      _allocs.remove(ctl.id);
+      _modelAllocs.clear(ctl.model.rawPointer);

-      final rawModelPointer =
+      final modelPointer =
           Pointer.fromAddress(ctl.model.rawPointer).cast<llama_model>();
-      libllama.llama_free_model(rawModelPointer);
+      libllama.llama_free_model(modelPointer);

       _response.send(ctl.done());
+
+    case NewContextCtl():
+      assert(ctl.model.rawPointer != 0);
+
+    case FreeContextCtl():
   }
 }
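
Note: Model and Context wrap only an integer address precisely so they can cross the isolate boundary; the wrapper rebuilds a typed pointer with Pointer.fromAddress(...).cast<...>() before calling into llama.cpp, as the FreeModelCtl branch shows. A sketch of that round trip using an ordinary allocation (plain dart:ffi, no llama.cpp involved):

import 'dart:ffi';
import 'package:ffi/ffi.dart';

void main() {
  // Allocate natively and keep only the integer address, as Model._ does.
  final p = calloc.allocate<Uint32>(sizeOf<Uint32>());
  p.value = 42;
  final int rawPointer = p.address;

  // Later, possibly in another isolate: rebuild a typed pointer from it.
  final q = Pointer.fromAddress(rawPointer).cast<Uint32>();
  print(q.value); // 42

  calloc.free(q);
}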
