@@ -9,6 +9,33 @@ import 'package:ensemble_llama/ensemble_llama_cpp.dart';
9
9
// The bit pattern of -1 as a 32-bit two's-complement value (all 32 bits set).
// NOTE(review): despite the name, this is the *unsigned* 32-bit maximum
// (4294967295), not the signed int32 maximum (0x7FFFFFFF) — confirm intended
// meaning at call sites before relying on the name.
const _int32Max = 0xFFFFFFFF;
11
11
12
/// Helpers for copying values from a [ContextParams] into the native
/// llama_context_params FFI struct.
extension on llama_context_params {
  /// Copies every plain-value field (ints, doubles, bools) from [source].
  ///
  /// Deliberately leaves callback fields and pointers to allocated memory
  /// (tensor_split, progress_callback and its user data) untouched — the
  /// caller wires those up separately.
  void setSimpleFrom(ContextParams source) {
    seed = source.seed;
    n_ctx = source.contextSizeTokens;
    n_batch = source.batchSizeTokens;
    n_gpu_layers = source.gpuLayers;
    main_gpu = source.cudaMainGpu;

    // Skipping: tensor_split (pointer to allocated memory)

    rope_freq_base = source.ropeFreqBase;
    rope_freq_scale = source.ropeFreqScale;

    // Skipping: progress_callback{,_user_data} (callback + pointer)

    low_vram = source.useLessVram;
    mul_mat_q = source.cudaUseMulMatQ;
    f16_kv = source.useFloat16KVCache;
    logits_all = source.calculateAllLogits;
    vocab_only = source.loadOnlyVocabSkipTensors;
    use_mmap = source.useMmap;
    use_mlock = source.useMlock;
    embedding = source.willUseEmbedding;
  }
}
12
39
class ContextParams {
13
40
final int seed;
14
41
final int contextSizeTokens;
@@ -97,6 +124,23 @@ class FreeModelCtl extends ControlMessage {
97
124
FreeModelResp done () => FreeModelResp (id);
98
125
}
99
126
127
/// Control message requesting a new llama context for [model],
/// configured with [params].
class NewContextCtl extends ControlMessage {
  final Model model;
  final ContextParams params;

  NewContextCtl(this.model, this.params);

  /// Builds the success response carrying the freshly created [ctx].
  NewContextResp done(Context ctx) {
    return NewContextResp(id, ctx: ctx);
  }

  /// Builds the failure response carrying [err].
  NewContextResp error(Object err) {
    return NewContextResp(id, err: err);
  }
}
136
+
137
/// Control message requesting that [ctx] be freed.
class FreeContextCtl extends ControlMessage {
  final Context ctx;

  FreeContextCtl(this.ctx);

  /// Builds the acknowledgement response for this request.
  FreeContextResp done() {
    return FreeContextResp(id);
  }
}
143
+
100
144
sealed class ResponseMessage {
101
145
final int id;
102
146
final Object ? err;
@@ -132,6 +176,15 @@ class FreeModelResp extends ResponseMessage {
132
176
const FreeModelResp (super .id);
133
177
}
134
178
179
/// Response to [NewContextCtl].
///
/// Carries the created [ctx] on success; on failure [ctx] is null and
/// the inherited `err` field describes the error.
class NewContextResp extends ResponseMessage {
  final Context? ctx;

  const NewContextResp(super.id, {super.err, this.ctx});
}
183
+
184
/// Response to [FreeContextCtl]; acknowledges that the context was freed.
class FreeContextResp extends ResponseMessage {
  const FreeContextResp(super.id);
}
187
+
135
188
class EntryArgs {
136
189
final SendPort log, response;
137
190
const EntryArgs ({required this .log, required this .response});
@@ -142,20 +195,28 @@ class Model {
142
195
const Model ._(this .rawPointer);
143
196
}
144
197
145
- class _Allocations {
146
- final Map <int , Set <Pointer >> _map = {};
198
/// Opaque handle to a native llama context.
///
/// [rawPointer] holds the native pointer's address as an int so the handle
/// can cross isolate boundaries; only this library can construct one.
class Context {
  final int rawPointer;

  const Context._(this.rawPointer);
}
202
+
203
/// Tracks native allocations per owner key (e.g. a raw model or context
/// pointer address) so they can all be freed together when the owner is
/// destroyed.
///
/// Does not free anything itself — callers iterate `this[key]`, free each
/// pointer, then call [clear].
class _Allocations<E> {
  final Map<E, Set<Pointer>> _map = {};

  /// The allocations recorded for [key], or null if none are recorded.
  Set<Pointer>? operator [](E key) => _map[key];

  /// Replaces the allocation set recorded for [key].
  void operator []=(E key, Set<Pointer> allocs) => _map[key] = allocs;

  /// Forgets all allocations recorded for [key] (without freeing them).
  void clear(E key) => _map.remove(key);

  /// Records [p] as owned by [key].
  void add(E key, Pointer p) {
    // `??=` creates the set on first use and evaluates to the (possibly new)
    // set, so a single insert suffices. The original did
    // `_map[key] ??= {}..add(p); _map[key]!.add(p);`, inserting [p] twice —
    // harmless only because the value is a Set, and confusing to read.
    (_map[key] ??= {}).add(p);
  }
}
157
216
158
- final _Allocations _allocs = _Allocations ();
217
// Native allocations made on behalf of each loaded model,
// keyed by the model's raw pointer address (rawModelPointer).
final _modelAllocs = _Allocations<int>();
// Allocations per context — presumably keyed by the context's raw pointer
// address, mirroring _modelAllocs; confirm once context handling populates it.
final _ctxAllocs = _Allocations<int>();
159
220
160
221
// SendPorts back to the spawning isolate: one for log output, one for
// replies to control messages. `late final` — assumed to be assigned exactly
// once during isolate entry, before first use (entry code not visible here).
late final SendPort _log;
late final SendPort _response;
@@ -192,36 +253,19 @@ void _onControl(ControlMessage ctl) {
192
253
_response.send (ctl.done ());
193
254
194
255
case LoadModelCtl ():
256
+ final Set <Pointer > allocs = {};
195
257
final pd = ctl.ctxParams;
196
- final pc = libllama.llama_context_default_params ();
197
-
198
- pc.seed = pd.seed;
199
- pc.n_ctx = pd.contextSizeTokens;
200
- pc.n_batch = pd.batchSizeTokens;
201
- pc.n_gpu_layers = pd.gpuLayers;
202
- pc.main_gpu = pd.cudaMainGpu;
258
+ final pc = libllama.llama_context_default_params ()..setSimpleFrom (pd);
203
259
204
260
// TODO: can't do this until we track contexts to manage memory allocation
205
261
// pc.tensor_split
206
262
207
- pc.rope_freq_base = pd.ropeFreqBase;
208
- pc.rope_freq_scale = pd.ropeFreqScale;
209
-
210
263
pc.progress_callback = Pointer .fromFunction (_onModelLoadProgress);
211
264
final idPointer = calloc.allocate <Uint32 >(sizeOf <Uint32 >());
212
- _allocs .add (ctl.id, idPointer);
265
+ allocs .add (idPointer);
213
266
idPointer.value = ctl.id;
214
267
pc.progress_callback_user_data = idPointer.cast <Void >();
215
268
216
- pc.low_vram = pd.useLessVram;
217
- pc.mul_mat_q = pd.cudaUseMulMatQ;
218
- pc.f16_kv = pd.useFloat16KVCache;
219
- pc.logits_all = pd.calculateAllLogits;
220
- pc.vocab_only = pd.loadOnlyVocabSkipTensors;
221
- pc.use_mmap = pd.useMmap;
222
- pc.use_mlock = pd.useMlock;
223
- pc.embedding = pd.willUseEmbedding;
224
-
225
269
final rawModelPointer = libllama
226
270
.llama_load_model_from_file (
227
271
ctl.path.toNativeUtf8 ().cast <Char >(),
@@ -236,19 +280,25 @@ void _onControl(ControlMessage ctl) {
236
280
return ;
237
281
}
238
282
283
+ _modelAllocs[rawModelPointer] = allocs;
239
284
_response.send (ctl.done (Model ._(rawModelPointer)));
240
285
241
286
case FreeModelCtl ():
242
287
assert (ctl.model.rawPointer != 0 );
243
- _allocs. get ( ctl.id) ? .forEach ((p) {
288
+ _modelAllocs[ ctl.model.rawPointer] ? .forEach ((p) {
244
289
calloc.free (p);
245
290
});
246
- _allocs. remove (ctl.id );
291
+ _modelAllocs. clear (ctl.model.rawPointer );
247
292
248
- final rawModelPointer =
293
+ final modelPointer =
249
294
Pointer .fromAddress (ctl.model.rawPointer).cast <llama_model>();
250
- libllama.llama_free_model (rawModelPointer );
295
+ libllama.llama_free_model (modelPointer );
251
296
252
297
_response.send (ctl.done ());
298
+
299
+ case NewContextCtl ():
300
+ assert (ctl.model.rawPointer != 0 );
301
+
302
+ case FreeContextCtl ():
253
303
}
254
304
}
0 commit comments