Commit 8aae026

Lint

1 parent 63d0418 commit 8aae026
File tree

1 file changed: +56 -60 lines changed
llama_cpp/_internals.py

Lines changed: 56 additions & 60 deletions
@@ -1,26 +1,23 @@
 from __future__ import annotations
 
-import os
 import ctypes
-
+import os
+from contextlib import ExitStack
+from dataclasses import dataclass, field
 from typing import (
     Dict,
     List,
     Optional,
     Sequence,
 )
-from dataclasses import dataclass, field
-from contextlib import ExitStack
 
 import numpy as np
 import numpy.typing as npt
 
-from .llama_types import *
-from .llama_grammar import LlamaGrammar
+from llama_cpp import llama_cpp
 from ._utils import suppress_stdout_stderr
-
-import llama_cpp.llama_cpp as llama_cpp
-
+from .llama_grammar import LlamaGrammar
+from .llama_types import *
 
 # Python wrappers over llama.h structs
 
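Note: this hunk sorts the imports into the usual lint grouping — standard library (ctypes, os, contextlib, dataclasses, typing), then third-party (numpy), then first-party/local — each group alphabetized, and swaps `import llama_cpp.llama_cpp as llama_cpp` for the equivalent `from llama_cpp import llama_cpp`. The commit message only says "Lint", so the tool is an assumption; isort's defaults produce exactly this grouping, e.g.:

    # Hypothetical reproduction; the commit does not name the linter.
    import isort  # assumed tooling, not part of this repo

    isort.file("llama_cpp/_internals.py")  # rewrites the import block in place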

@@ -351,7 +348,7 @@ def get_state_size(self) -> int:
 
     # TODO: llama_save_session_file
 
-    def decode(self, batch: "_LlamaBatch"):
+    def decode(self, batch: _LlamaBatch):
         assert self.ctx is not None
         assert batch.batch is not None
         return_code = llama_cpp.llama_decode(
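The quotes around "_LlamaBatch" can be dropped because the file begins with `from __future__ import annotations` (first hunk above): under PEP 563, annotations are stored as strings and evaluated lazily, so forward references to not-yet-defined names no longer need manual quoting. The remaining hunks apply the same change to every "_LlamaTokenDataArray" annotation. A minimal self-contained sketch of the pattern (the `Node` class is hypothetical, not from this file):

    from __future__ import annotations


    class Node:
        # Without the __future__ import, `Node` would have to be written
        # as "Node" here, since the class is not yet fully defined when
        # the method signature is evaluated.
        def link(self, other: Node) -> Node:
            self.next = other
            return other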
@@ -385,8 +382,8 @@ def set_rng_seed(self, seed: int):
 
     def sample_repetition_penalties(
         self,
-        candidates: "_LlamaTokenDataArray",
-        last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
+        candidates: _LlamaTokenDataArray,
+        last_tokens_data: llama_cpp.Array[llama_cpp.llama_token],
         penalty_last_n: int,
         penalty_repeat: float,
         penalty_freq: float,
@@ -403,54 +400,54 @@ def sample_repetition_penalties(
             penalty_present,
         )
 
-    def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
+    def sample_softmax(self, candidates: _LlamaTokenDataArray):
         assert self.ctx is not None
         llama_cpp.llama_sample_softmax(
             self.ctx,
             llama_cpp.byref(candidates.candidates),
         )
 
-    def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
+    def sample_top_k(self, candidates: _LlamaTokenDataArray, k: int, min_keep: int):
         assert self.ctx is not None
         llama_cpp.llama_sample_top_k(
             self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
         )
 
-    def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
+    def sample_top_p(self, candidates: _LlamaTokenDataArray, p: float, min_keep: int):
         assert self.ctx is not None
         llama_cpp.llama_sample_top_p(
             self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
         )
 
-    def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
+    def sample_min_p(self, candidates: _LlamaTokenDataArray, p: float, min_keep: int):
         assert self.ctx is not None
         llama_cpp.llama_sample_min_p(
             self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
         )
 
     def sample_tail_free(
-        self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int
+        self, candidates: _LlamaTokenDataArray, z: float, min_keep: int
     ):
         assert self.ctx is not None
         llama_cpp.llama_sample_tail_free(
             self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep
         )
 
     def sample_typical(
-        self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
+        self, candidates: _LlamaTokenDataArray, p: float, min_keep: int
     ):
         assert self.ctx is not None
         llama_cpp.llama_sample_typical(
             self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
         )
 
-    def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
+    def sample_temp(self, candidates: _LlamaTokenDataArray, temp: float):
         assert self.ctx is not None
         llama_cpp.llama_sample_temp(
             self.ctx, llama_cpp.byref(candidates.candidates), temp
         )
 
-    def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
+    def sample_grammar(self, candidates: _LlamaTokenDataArray, grammar: LlamaGrammar):
         assert self.ctx is not None
         assert grammar.grammar is not None
         llama_cpp.llama_sample_grammar(
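Each wrapper above is a thin ctypes call into the corresponding llama.cpp C sampler (llama_sample_top_k, llama_sample_top_p, and so on), and each mutates the candidates array in place, so the calls compose as a filtering pipeline. A usage sketch, assuming `ctx` is an initialized `_LlamaContext` and `candidates` a populated `_LlamaTokenDataArray` (setup omitted):

    # Each call narrows `candidates` in place, so order matters.
    ctx.sample_top_k(candidates, k=40, min_keep=1)    # keep the 40 most likely tokens
    ctx.sample_top_p(candidates, p=0.95, min_keep=1)  # trim to 95% cumulative mass
    ctx.sample_temp(candidates, temp=0.8)             # rescale logits by temperature
    token_id = ctx.sample_token(candidates)           # draw the final token id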
@@ -461,7 +458,7 @@ def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGramm
 
     def sample_token_mirostat(
         self,
-        candidates: "_LlamaTokenDataArray",
+        candidates: _LlamaTokenDataArray,
         tau: float,
         eta: float,
         m: int,
@@ -493,14 +490,14 @@ def sample_token_mirostat_v2(
             mu,
         )
 
-    def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
+    def sample_token_greedy(self, candidates: _LlamaTokenDataArray) -> int:
         assert self.ctx is not None
         return llama_cpp.llama_sample_token_greedy(
             self.ctx,
             llama_cpp.byref(candidates.candidates),
         )
 
-    def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
+    def sample_token(self, candidates: _LlamaTokenDataArray) -> int:
         assert self.ctx is not None
         return llama_cpp.llama_sample_token(
             self.ctx,
@@ -822,44 +819,43 @@ def sample(
             id = token_data_array.candidates_data.id[0]
         elif self.params.temp == 0:
             id = ctx_main.sample_token_greedy(token_data_array)
+        elif self.params.mirostat == 1:
+            mirostat_m = 100
+            ctx_main.sample_temp(token_data_array, self.params.temp)
+            id = ctx_main.sample_token_mirostat(
+                token_data_array,
+                self.params.mirostat_tau,
+                self.params.mirostat_eta,
+                mirostat_m,
+                ctypes.pointer(self.mirostat_mu),
+            )
+        elif self.params.mirostat == 2:
+            ctx_main.sample_temp(token_data_array, self.params.temp)
+            id = ctx_main.sample_token_mirostat_v2(
+                token_data_array,
+                self.params.mirostat_tau,
+                self.params.mirostat_eta,
+                ctypes.pointer(self.mirostat_mu),
+            )
         else:
-            if self.params.mirostat == 1:
-                mirostat_m = 100
-                ctx_main.sample_temp(token_data_array, self.params.temp)
-                id = ctx_main.sample_token_mirostat(
-                    token_data_array,
-                    self.params.mirostat_tau,
-                    self.params.mirostat_eta,
-                    mirostat_m,
-                    ctypes.pointer(self.mirostat_mu),
-                )
-            elif self.params.mirostat == 2:
-                ctx_main.sample_temp(token_data_array, self.params.temp)
-                id = ctx_main.sample_token_mirostat_v2(
-                    token_data_array,
-                    self.params.mirostat_tau,
-                    self.params.mirostat_eta,
-                    ctypes.pointer(self.mirostat_mu),
-                )
-            else:
-                min_keep = max(1, self.params.n_probs)
-                ctx_main.sample_top_k(
-                    token_data_array, self.params.top_k, min_keep=min_keep
-                )
-                ctx_main.sample_tail_free(
-                    token_data_array, self.params.tfs_z, min_keep=min_keep
-                )
-                ctx_main.sample_typical(
-                    token_data_array, self.params.typical_p, min_keep=min_keep
-                )
-                ctx_main.sample_top_p(
-                    token_data_array, self.params.top_p, min_keep=min_keep
-                )
-                ctx_main.sample_min_p(
-                    token_data_array, self.params.min_p, min_keep=min_keep
-                )
-                ctx_main.sample_temp(token_data_array, self.params.temp)
-                id = ctx_main.sample_token(token_data_array)
+            min_keep = max(1, self.params.n_probs)
+            ctx_main.sample_top_k(
+                token_data_array, self.params.top_k, min_keep=min_keep
+            )
+            ctx_main.sample_tail_free(
+                token_data_array, self.params.tfs_z, min_keep=min_keep
+            )
+            ctx_main.sample_typical(
+                token_data_array, self.params.typical_p, min_keep=min_keep
+            )
+            ctx_main.sample_top_p(
+                token_data_array, self.params.top_p, min_keep=min_keep
+            )
+            ctx_main.sample_min_p(
+                token_data_array, self.params.min_p, min_keep=min_keep
+            )
+            ctx_main.sample_temp(token_data_array, self.params.temp)
+            id = ctx_main.sample_token(token_data_array)
         return id
 
     def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
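This last hunk is a pure de-nesting: the `mirostat == 1` and `mirostat == 2` branches move out of the trailing `else:` into the top-level `elif` chain, removing one indentation level without changing behavior, since `else:` followed by a nested `if` is equivalent to `elif`. The transformation in miniature, with hypothetical helpers (`greedy`, `mirostat_v1`, `default_chain` are placeholders, not names from this file):

    # Before: branch ladder nested inside else
    if temp == 0:
        pick = greedy()
    else:
        if mirostat == 1:
            pick = mirostat_v1()
        else:
            pick = default_chain()

    # After: the same logic as a flat elif chain
    if temp == 0:
        pick = greedy()
    elif mirostat == 1:
        pick = mirostat_v1()
    else:
        pick = default_chain()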
