|
13 | 13 |
|
14 | 14 | import ctypes
|
15 | 15 | import sys
|
16 |
| -from time import time |
17 | 16 | from os import cpu_count, path
|
| 17 | +from time import time |
18 | 18 |
|
19 |
| -import llama_cpp |
20 |
| -from common import GptParams, gpt_params_parse, gpt_random_prompt |
21 | 19 | import util
|
| 20 | +from common import GptParams, gpt_params_parse, gpt_random_prompt |
| 21 | + |
| 22 | +import llama_cpp |
22 | 23 |
|
23 | 24 |
|
24 | 25 | # A LLaMA interactive session
|
@@ -475,63 +476,62 @@ def generate(self):
|
475 | 476 | if self.params.temp <= 0:
|
476 | 477 | # Greedy sampling
|
477 | 478 | id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)
|
| 479 | + elif self.params.mirostat == 1: |
| 480 | + mirostat_mu = 2.0 * self.params.mirostat_tau |
| 481 | + mirostat_m = 100 |
| 482 | + llama_cpp.llama_sample_temperature( |
| 483 | + self.ctx, candidates_p, llama_cpp.c_float(self.params.temp) |
| 484 | + ) |
| 485 | + id = llama_cpp.llama_sample_token_mirostat( |
| 486 | + self.ctx, |
| 487 | + candidates_p, |
| 488 | + llama_cpp.c_float(self.params.mirostat_tau), |
| 489 | + llama_cpp.c_float(self.params.mirostat_eta), |
| 490 | + llama_cpp.c_int(mirostat_m), |
| 491 | + llama_cpp.c_float(mirostat_mu), |
| 492 | + ) |
| 493 | + elif self.params.mirostat == 2: |
| 494 | + mirostat_mu = 2.0 * self.params.mirostat_tau |
| 495 | + llama_cpp.llama_sample_temperature( |
| 496 | + self.ctx, candidates_p, llama_cpp.c_float(self.params.temp) |
| 497 | + ) |
| 498 | + id = llama_cpp.llama_sample_token_mirostat_v2( |
| 499 | + self.ctx, |
| 500 | + candidates_p, |
| 501 | + llama_cpp.c_float(self.params.mirostat_tau), |
| 502 | + llama_cpp.c_float(self.params.mirostat_eta), |
| 503 | + llama_cpp.c_float(mirostat_mu), |
| 504 | + ) |
478 | 505 | else:
|
479 |
| - if self.params.mirostat == 1: |
480 |
| - mirostat_mu = 2.0 * self.params.mirostat_tau |
481 |
| - mirostat_m = 100 |
482 |
| - llama_cpp.llama_sample_temperature( |
483 |
| - self.ctx, candidates_p, llama_cpp.c_float(self.params.temp) |
484 |
| - ) |
485 |
| - id = llama_cpp.llama_sample_token_mirostat( |
486 |
| - self.ctx, |
487 |
| - candidates_p, |
488 |
| - llama_cpp.c_float(self.params.mirostat_tau), |
489 |
| - llama_cpp.c_float(self.params.mirostat_eta), |
490 |
| - llama_cpp.c_int(mirostat_m), |
491 |
| - llama_cpp.c_float(mirostat_mu), |
492 |
| - ) |
493 |
| - elif self.params.mirostat == 2: |
494 |
| - mirostat_mu = 2.0 * self.params.mirostat_tau |
495 |
| - llama_cpp.llama_sample_temperature( |
496 |
| - self.ctx, candidates_p, llama_cpp.c_float(self.params.temp) |
497 |
| - ) |
498 |
| - id = llama_cpp.llama_sample_token_mirostat_v2( |
499 |
| - self.ctx, |
500 |
| - candidates_p, |
501 |
| - llama_cpp.c_float(self.params.mirostat_tau), |
502 |
| - llama_cpp.c_float(self.params.mirostat_eta), |
503 |
| - llama_cpp.c_float(mirostat_mu), |
504 |
| - ) |
505 |
| - else: |
506 |
| - # Temperature sampling |
507 |
| - llama_cpp.llama_sample_top_k( |
508 |
| - self.ctx, |
509 |
| - candidates_p, |
510 |
| - top_k, |
511 |
| - min_keep=llama_cpp.c_size_t(1), |
512 |
| - ) |
513 |
| - llama_cpp.llama_sample_tail_free( |
514 |
| - self.ctx, |
515 |
| - candidates_p, |
516 |
| - llama_cpp.c_float(self.params.tfs_z), |
517 |
| - min_keep=llama_cpp.c_size_t(1), |
518 |
| - ) |
519 |
| - llama_cpp.llama_sample_typical( |
520 |
| - self.ctx, |
521 |
| - candidates_p, |
522 |
| - llama_cpp.c_float(self.params.typical_p), |
523 |
| - min_keep=llama_cpp.c_size_t(1), |
524 |
| - ) |
525 |
| - llama_cpp.llama_sample_top_p( |
526 |
| - self.ctx, |
527 |
| - candidates_p, |
528 |
| - llama_cpp.c_float(self.params.top_p), |
529 |
| - min_keep=llama_cpp.c_size_t(1), |
530 |
| - ) |
531 |
| - llama_cpp.llama_sample_temperature( |
532 |
| - self.ctx, candidates_p, llama_cpp.c_float(self.params.temp) |
533 |
| - ) |
534 |
| - id = llama_cpp.llama_sample_token(self.ctx, candidates_p) |
| 506 | + # Temperature sampling |
| 507 | + llama_cpp.llama_sample_top_k( |
| 508 | + self.ctx, |
| 509 | + candidates_p, |
| 510 | + top_k, |
| 511 | + min_keep=llama_cpp.c_size_t(1), |
| 512 | + ) |
| 513 | + llama_cpp.llama_sample_tail_free( |
| 514 | + self.ctx, |
| 515 | + candidates_p, |
| 516 | + llama_cpp.c_float(self.params.tfs_z), |
| 517 | + min_keep=llama_cpp.c_size_t(1), |
| 518 | + ) |
| 519 | + llama_cpp.llama_sample_typical( |
| 520 | + self.ctx, |
| 521 | + candidates_p, |
| 522 | + llama_cpp.c_float(self.params.typical_p), |
| 523 | + min_keep=llama_cpp.c_size_t(1), |
| 524 | + ) |
| 525 | + llama_cpp.llama_sample_top_p( |
| 526 | + self.ctx, |
| 527 | + candidates_p, |
| 528 | + llama_cpp.c_float(self.params.top_p), |
| 529 | + min_keep=llama_cpp.c_size_t(1), |
| 530 | + ) |
| 531 | + llama_cpp.llama_sample_temperature( |
| 532 | + self.ctx, candidates_p, llama_cpp.c_float(self.params.temp) |
| 533 | + ) |
| 534 | + id = llama_cpp.llama_sample_token(self.ctx, candidates_p) |
535 | 535 | # print("`{}`".format(candidates_p.size))
|
536 | 536 |
|
537 | 537 | self.last_n_tokens.pop(0)
|
|
0 commit comments