@@ -1,26 +1,23 @@
from __future__ import annotations

- import os
import ctypes
-
+ import os
+ from contextlib import ExitStack
+ from dataclasses import dataclass, field
from typing import (
    Dict,
    List,
    Optional,
    Sequence,
)
- from dataclasses import dataclass, field
- from contextlib import ExitStack

import numpy as np
import numpy.typing as npt

- from .llama_types import *
- from .llama_grammar import LlamaGrammar
+ from llama_cpp import llama_cpp
from ._utils import suppress_stdout_stderr
-
- import llama_cpp.llama_cpp as llama_cpp
-
+ from .llama_grammar import LlamaGrammar
+ from .llama_types import *

# Python wrappers over llama.h structs

@@ -351,7 +348,7 @@ def get_state_size(self) -> int:

    # TODO: llama_save_session_file

-     def decode(self, batch: "_LlamaBatch"):
+     def decode(self, batch: _LlamaBatch):
        assert self.ctx is not None
        assert batch.batch is not None
        return_code = llama_cpp.llama_decode(
@@ -385,8 +382,8 @@ def set_rng_seed(self, seed: int):

    def sample_repetition_penalties(
        self,
-         candidates: "_LlamaTokenDataArray",
-         last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
+         candidates: _LlamaTokenDataArray,
+         last_tokens_data: llama_cpp.Array[llama_cpp.llama_token],
        penalty_last_n: int,
        penalty_repeat: float,
        penalty_freq: float,
@@ -403,54 +400,54 @@ def sample_repetition_penalties(
            penalty_present,
        )

-     def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
+     def sample_softmax(self, candidates: _LlamaTokenDataArray):
        assert self.ctx is not None
        llama_cpp.llama_sample_softmax(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
        )

-     def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
+     def sample_top_k(self, candidates: _LlamaTokenDataArray, k: int, min_keep: int):
        assert self.ctx is not None
        llama_cpp.llama_sample_top_k(
            self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
        )

-     def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
+     def sample_top_p(self, candidates: _LlamaTokenDataArray, p: float, min_keep: int):
        assert self.ctx is not None
        llama_cpp.llama_sample_top_p(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

-     def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
+     def sample_min_p(self, candidates: _LlamaTokenDataArray, p: float, min_keep: int):
        assert self.ctx is not None
        llama_cpp.llama_sample_min_p(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

    def sample_tail_free(
-         self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int
+         self, candidates: _LlamaTokenDataArray, z: float, min_keep: int
    ):
        assert self.ctx is not None
        llama_cpp.llama_sample_tail_free(
            self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep
        )

    def sample_typical(
-         self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
+         self, candidates: _LlamaTokenDataArray, p: float, min_keep: int
    ):
        assert self.ctx is not None
        llama_cpp.llama_sample_typical(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

-     def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
+     def sample_temp(self, candidates: _LlamaTokenDataArray, temp: float):
        assert self.ctx is not None
        llama_cpp.llama_sample_temp(
            self.ctx, llama_cpp.byref(candidates.candidates), temp
        )

-     def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
+     def sample_grammar(self, candidates: _LlamaTokenDataArray, grammar: LlamaGrammar):
        assert self.ctx is not None
        assert grammar.grammar is not None
        llama_cpp.llama_sample_grammar(
@@ -461,7 +458,7 @@ def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGramm

    def sample_token_mirostat(
        self,
-         candidates: "_LlamaTokenDataArray",
+         candidates: _LlamaTokenDataArray,
        tau: float,
        eta: float,
        m: int,
@@ -493,14 +490,14 @@ def sample_token_mirostat_v2(
            mu,
        )

-     def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
+     def sample_token_greedy(self, candidates: _LlamaTokenDataArray) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_sample_token_greedy(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
        )

-     def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
+     def sample_token(self, candidates: _LlamaTokenDataArray) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_sample_token(
            self.ctx,
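
Note on the annotation hunks above: dropping the quotes from forward references such as _LlamaTokenDataArray and llama_cpp.Array[llama_cpp.llama_token] is safe here because the module keeps from __future__ import annotations at the top, so under PEP 563 every annotation is stored as a string and never evaluated at runtime. A minimal standalone sketch with made-up names (not code from this module):

    from __future__ import annotations  # PEP 563: annotations become lazy strings


    class Context:
        # Batch is defined further down; without the future import this
        # annotation would raise NameError when the method is defined.
        def decode(self, batch: Batch) -> None:
            print("decoding", batch)


    class Batch:
        pass


    Context().decode(Batch())  # runs fine; the annotation is never evaluated
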
@@ -822,44 +819,43 @@ def sample(
            id = token_data_array.candidates_data.id[0]
        elif self.params.temp == 0:
            id = ctx_main.sample_token_greedy(token_data_array)
+         elif self.params.mirostat == 1:
+             mirostat_m = 100
+             ctx_main.sample_temp(token_data_array, self.params.temp)
+             id = ctx_main.sample_token_mirostat(
+                 token_data_array,
+                 self.params.mirostat_tau,
+                 self.params.mirostat_eta,
+                 mirostat_m,
+                 ctypes.pointer(self.mirostat_mu),
+             )
+         elif self.params.mirostat == 2:
+             ctx_main.sample_temp(token_data_array, self.params.temp)
+             id = ctx_main.sample_token_mirostat_v2(
+                 token_data_array,
+                 self.params.mirostat_tau,
+                 self.params.mirostat_eta,
+                 ctypes.pointer(self.mirostat_mu),
+             )
        else:
-             if self.params.mirostat == 1:
-                 mirostat_m = 100
-                 ctx_main.sample_temp(token_data_array, self.params.temp)
-                 id = ctx_main.sample_token_mirostat(
-                     token_data_array,
-                     self.params.mirostat_tau,
-                     self.params.mirostat_eta,
-                     mirostat_m,
-                     ctypes.pointer(self.mirostat_mu),
-                 )
-             elif self.params.mirostat == 2:
-                 ctx_main.sample_temp(token_data_array, self.params.temp)
-                 id = ctx_main.sample_token_mirostat_v2(
-                     token_data_array,
-                     self.params.mirostat_tau,
-                     self.params.mirostat_eta,
-                     ctypes.pointer(self.mirostat_mu),
-                 )
-             else:
-                 min_keep = max(1, self.params.n_probs)
-                 ctx_main.sample_top_k(
-                     token_data_array, self.params.top_k, min_keep=min_keep
-                 )
-                 ctx_main.sample_tail_free(
-                     token_data_array, self.params.tfs_z, min_keep=min_keep
-                 )
-                 ctx_main.sample_typical(
-                     token_data_array, self.params.typical_p, min_keep=min_keep
-                 )
-                 ctx_main.sample_top_p(
-                     token_data_array, self.params.top_p, min_keep=min_keep
-                 )
-                 ctx_main.sample_min_p(
-                     token_data_array, self.params.min_p, min_keep=min_keep
-                 )
-                 ctx_main.sample_temp(token_data_array, self.params.temp)
-                 id = ctx_main.sample_token(token_data_array)
+             min_keep = max(1, self.params.n_probs)
+             ctx_main.sample_top_k(
+                 token_data_array, self.params.top_k, min_keep=min_keep
+             )
+             ctx_main.sample_tail_free(
+                 token_data_array, self.params.tfs_z, min_keep=min_keep
+             )
+             ctx_main.sample_typical(
+                 token_data_array, self.params.typical_p, min_keep=min_keep
+             )
+             ctx_main.sample_top_p(
+                 token_data_array, self.params.top_p, min_keep=min_keep
+             )
+             ctx_main.sample_min_p(
+                 token_data_array, self.params.min_p, min_keep=min_keep
+             )
+             ctx_main.sample_temp(token_data_array, self.params.temp)
+             id = ctx_main.sample_token(token_data_array)
        return id

    def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
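
The sample() hunk above only hoists the two mirostat branches out of the nested else block into top-level elifs; the branch bodies themselves are unchanged. A tiny self-contained sketch (hypothetical names, not this module's API) showing why that flattening preserves behavior:

    def nested(temp: float, mirostat: int) -> str:
        # shape of the old control flow: an else wrapping an if/elif/else chain
        if temp == 0:
            return "greedy"
        else:
            if mirostat == 1:
                return "mirostat v1"
            elif mirostat == 2:
                return "mirostat v2"
            else:
                return "top-k/top-p cascade"


    def flat(temp: float, mirostat: int) -> str:
        # shape of the new control flow: the same branches as top-level elifs
        if temp == 0:
            return "greedy"
        elif mirostat == 1:
            return "mirostat v1"
        elif mirostat == 2:
            return "mirostat v2"
        else:
            return "top-k/top-p cascade"


    # every input picks the same branch either way
    assert all(nested(t, m) == flat(t, m) for t in (0.0, 0.8) for m in (0, 1, 2))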