diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..ccf548723 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,21 @@ +name: ci + +on: + pull_request: + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v4 + # with: + # ref: ${{ github.head_ref }} + - uses: actions/setup-python@v5 + - uses: pre-commit/action@v3.0.1 + # - uses: stefanzweifel/git-auto-commit-action@v5 + # with: + # commit_message: 'pre commit fixes' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..84c7865fa --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +repos: + # auto update + - repo: https://gitlab.com/vojko.pribudic.foss/pre-commit-update + rev: "v0.5.0" + hooks: + - id: pre-commit-update + args: [--dry-run, --all-versions] + + # ruff + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.6.7" + hooks: + # Run the linter. + - id: ruff + types_or: [ python, pyi, jupyter ] + args: [ --fix ] + # Run the formatter. + - id: ruff-format + types_or: [ python, pyi, jupyter ] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: "v1.11.2" + hooks: + - id: mypy + args: [ '--ignore-missing-imports', '--disable-error-code=top-level-await', "--disable-error-code=empty-body" ] diff --git a/docker/open_llama/hug_model.py b/docker/open_llama/hug_model.py index 13c5b6b0d..ee23821a6 100644 --- a/docker/open_llama/hug_model.py +++ b/docker/open_llama/hug_model.py @@ -1,26 +1,27 @@ -import requests +import argparse import json import os import struct -import argparse + +import requests + def make_request(url, params=None): print(f"Making request to {url}...") response = requests.get(url, params=params) if response.status_code == 200: return json.loads(response.text) - else: - print(f"Request failed with status code {response.status_code}") - return None + print(f"Request failed with status code {response.status_code}") + return None def check_magic_and_version(filename): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: # Read the first 6 bytes from the file data = f.read(6) # Unpack the binary data, interpreting the first 4 bytes as a little-endian unsigned int # and the next 2 bytes as a little-endian unsigned short - magic, version = struct.unpack('= 10485760: # 10 MB - print('.', end='', flush=True) + print(".", end="", flush=True) total_downloaded = 0 print("\nDownload complete.") - + # Creating a symbolic link from destination to "model.bin" if os.path.isfile("model.bin"): os.remove("model.bin") # remove the existing link if any @@ -61,30 +62,29 @@ def get_user_choice(model_list): if 0 <= index < len(model_list): # Return the chosen model return model_list[index] - else: - print("Invalid choice.") + print("Invalid choice.") except ValueError: print("Invalid input. Please enter a number corresponding to a model.") except IndexError: print("Invalid choice. 
Index out of range.") - + return None def main(): # Create an argument parser - parser = argparse.ArgumentParser(description='Process some parameters.') + parser = argparse.ArgumentParser(description="Process some parameters.") # Arguments - parser.add_argument('-v', '--version', type=int, default=0x0003, - help='hexadecimal version number of ggml file') - parser.add_argument('-a', '--author', type=str, default='TheBloke', - help='HuggingFace author filter') - parser.add_argument('-t', '--tag', type=str, default='llama', - help='HuggingFace tag filter') - parser.add_argument('-s', '--search', type=str, default='', - help='HuggingFace search filter') - parser.add_argument('-f', '--filename', type=str, default='q5_1', - help='HuggingFace model repository filename substring match') + parser.add_argument("-v", "--version", type=int, default=0x0003, + help="hexadecimal version number of ggml file") + parser.add_argument("-a", "--author", type=str, default="TheBloke", + help="HuggingFace author filter") + parser.add_argument("-t", "--tag", type=str, default="llama", + help="HuggingFace tag filter") + parser.add_argument("-s", "--search", type=str, default="", + help="HuggingFace search filter") + parser.add_argument("-f", "--filename", type=str, default="q5_1", + help="HuggingFace model repository filename substring match") # Parse the arguments args = parser.parse_args() @@ -96,20 +96,20 @@ def main(): "search": args.search } - models = make_request('https://huggingface.co/api/models', params=params) + models = make_request("https://huggingface.co/api/models", params=params) if models is None: return model_list = [] # Iterate over the models for model in models: - model_id = model['id'] - model_info = make_request(f'https://huggingface.co/api/models/{model_id}') + model_id = model["id"] + model_info = make_request(f"https://huggingface.co/api/models/{model_id}") if model_info is None: continue - for sibling in model_info.get('siblings', []): - rfilename = sibling.get('rfilename') + for sibling in model_info.get("siblings", []): + rfilename = sibling.get("rfilename") if rfilename and args.filename in rfilename: model_list.append((model_id, rfilename)) @@ -135,5 +135,5 @@ def main(): print("Error - model choice was None") exit(2) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/batch-processing/server.py b/examples/batch-processing/server.py index 0b36746f9..067e1f271 100644 --- a/examples/batch-processing/server.py +++ b/examples/batch-processing/server.py @@ -23,8 +23,6 @@ app = FastAPI() -import openai.types.chat as types - @app.post("/v1/chat/completions") def create_chat_completions(): diff --git a/examples/gradio_chat/local.py b/examples/gradio_chat/local.py index e16bf234a..65bbfb79e 100644 --- a/examples/gradio_chat/local.py +++ b/examples/gradio_chat/local.py @@ -1,13 +1,13 @@ +import gradio as gr + import llama_cpp import llama_cpp.llama_tokenizer -import gradio as gr - llama = llama_cpp.Llama.from_pretrained( repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF", filename="*q8_0.gguf", tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( - "Qwen/Qwen1.5-0.5B" + "Qwen/Qwen1.5-0.5B", ), verbose=False, ) @@ -25,7 +25,7 @@ def predict(message, history): messages.append({"role": "user", "content": message}) response = llama.create_chat_completion_openai_v1( - model=model, messages=messages, stream=True + model=model, messages=messages, stream=True, ) text = "" diff --git a/examples/gradio_chat/server.py b/examples/gradio_chat/server.py index 
52061bea6..516c2dad2 100644 --- a/examples/gradio_chat/server.py +++ b/examples/gradio_chat/server.py @@ -1,5 +1,4 @@ import gradio as gr - from openai import OpenAI client = OpenAI(base_url="http://localhost:8000/v1", api_key="llama.cpp") @@ -17,7 +16,7 @@ def predict(message, history): messages.append({"role": "user", "content": message}) response = client.chat.completions.create( - model=model, messages=messages, stream=True + model=model, messages=messages, stream=True, ) text = "" diff --git a/examples/hf_pull/main.py b/examples/hf_pull/main.py index dfed17516..ed1046b08 100644 --- a/examples/hf_pull/main.py +++ b/examples/hf_pull/main.py @@ -1,12 +1,11 @@ import llama_cpp import llama_cpp.llama_tokenizer - llama = llama_cpp.Llama.from_pretrained( repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF", filename="*q8_0.gguf", tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( - "Qwen/Qwen1.5-0.5B" + "Qwen/Qwen1.5-0.5B", ), verbose=False, ) diff --git a/examples/high_level_api/fastapi_server.py b/examples/high_level_api/fastapi_server.py index ee59767d6..469aa0996 100644 --- a/examples/high_level_api/fastapi_server.py +++ b/examples/high_level_api/fastapi_server.py @@ -26,6 +26,7 @@ """ import os + import uvicorn from llama_cpp.server.app import create_app @@ -34,5 +35,5 @@ app = create_app() uvicorn.run( - app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)) + app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)), ) diff --git a/examples/high_level_api/high_level_api_inference.py b/examples/high_level_api/high_level_api_inference.py index e41f37577..349445c7c 100644 --- a/examples/high_level_api/high_level_api_inference.py +++ b/examples/high_level_api/high_level_api_inference.py @@ -1,5 +1,5 @@ -import json import argparse +import json from llama_cpp import Llama diff --git a/examples/high_level_api/high_level_api_infill.py b/examples/high_level_api/high_level_api_infill.py index 282333e5a..c132ec632 100644 --- a/examples/high_level_api/high_level_api_infill.py +++ b/examples/high_level_api/high_level_api_infill.py @@ -33,5 +33,5 @@ filtered = True print( - f"Fill-in-Middle completion{' (filtered)' if filtered else ''}:\n\n{args.prompt}\033[32m{response}\033[{'33' if filtered else '0'}m{args.suffix}\033[0m" + f"Fill-in-Middle completion{' (filtered)' if filtered else ''}:\n\n{args.prompt}\033[32m{response}\033[{'33' if filtered else '0'}m{args.suffix}\033[0m", ) diff --git a/examples/high_level_api/high_level_api_streaming.py b/examples/high_level_api/high_level_api_streaming.py index 747c6130e..868cdee7c 100644 --- a/examples/high_level_api/high_level_api_streaming.py +++ b/examples/high_level_api/high_level_api_streaming.py @@ -1,5 +1,5 @@ -import json import argparse +import json from llama_cpp import Llama diff --git a/examples/high_level_api/langchain_custom_llm.py b/examples/high_level_api/langchain_custom_llm.py index b91632f5b..091cbbb69 100644 --- a/examples/high_level_api/langchain_custom_llm.py +++ b/examples/high_level_api/langchain_custom_llm.py @@ -1,9 +1,10 @@ import argparse - -from llama_cpp import Llama +from collections.abc import Mapping +from typing import Any, List, Optional from langchain.llms.base import LLM -from typing import Optional, List, Mapping, Any + +from llama_cpp import Llama class LlamaLLM(LLM): @@ -19,7 +20,7 @@ def __init__(self, model_path: str, **kwargs: Any): llm = Llama(model_path=model_path) super().__init__(model_path=model_path, llm=llm, **kwargs) - def _call(self, prompt: str, stop: 
Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: list[str] | None = None) -> str: response = self.llm(prompt, stop=stop or []) return response["choices"][0]["text"] @@ -37,13 +38,13 @@ def _identifying_params(self) -> Mapping[str, Any]: # Basic Q&A answer = llm( - "Question: What is the capital of France? Answer: ", stop=["Question:", "\n"] + "Question: What is the capital of France? Answer: ", stop=["Question:", "\n"], ) print(f"Answer: {answer.strip()}") # Using in a chain -from langchain.prompts import PromptTemplate from langchain.chains import LLMChain +from langchain.prompts import PromptTemplate prompt = PromptTemplate( input_variables=["product"], diff --git a/examples/low_level_api/Chat.py b/examples/low_level_api/Chat.py index a755089b2..4f338f8d0 100644 --- a/examples/low_level_api/Chat.py +++ b/examples/low_level_api/Chat.py @@ -1,5 +1,8 @@ #!/bin/python -import sys, os, datetime +import datetime +import os +import sys + from common import GptParams from low_level_api_chat_cpp import LLaMAInteract @@ -48,7 +51,7 @@ def env_or_def(env, default): {USER_NAME}: What time is it? {AI_NAME}: It is {DATE_TIME}. {USER_NAME}:""" + " ".join( - sys.argv[1:] + sys.argv[1:], ) print("Loading model...") diff --git a/examples/low_level_api/Miku.py b/examples/low_level_api/Miku.py index e072ab1b1..abc1f66c4 100644 --- a/examples/low_level_api/Miku.py +++ b/examples/low_level_api/Miku.py @@ -1,5 +1,7 @@ #!/bin/python -import sys, os +import os +import sys + from common import GptParams from low_level_api_chat_cpp import LLaMAInteract @@ -35,7 +37,7 @@ def env_or_def(env, default): {AI_NAME}: /think I wonder what {USER_NAME} likes to do in his free time? I should ask him about that! {AI_NAME}: What do you like to do in your free time? ^_^ {USER_NAME}:""" + " ".join( - sys.argv[1:] + sys.argv[1:], ) print("Loading model...") diff --git a/examples/low_level_api/ReasonAct.py b/examples/low_level_api/ReasonAct.py index 1f2c59017..ebfcceae9 100644 --- a/examples/low_level_api/ReasonAct.py +++ b/examples/low_level_api/ReasonAct.py @@ -1,5 +1,7 @@ #!/bin/python -import sys, os, datetime +import os +import sys + from common import GptParams from low_level_api_chat_cpp import LLaMAInteract @@ -12,7 +14,7 @@ def env_or_def(env, default): MODEL = env_or_def("MODEL", "./models/llama-13B/ggml-model.bin") -prompt = f"""You run in a loop of Thought, Action, Observation. +prompt = """You run in a loop of Thought, Action, Observation. At the end of the loop either Answer or restate your Thought and Action. Use Thought to describe your thoughts about the question you have been asked. Use Action to run one of these actions available to you: @@ -30,7 +32,7 @@ def env_or_def(env, default): Thought: Do I need to use an action? 
No, I know the answer Answer: Paris is the capital of France Question:""" + " ".join( - sys.argv[1:] + sys.argv[1:], ) print("Loading model...") diff --git a/examples/low_level_api/common.py b/examples/low_level_api/common.py index a0212ff0d..5c3214921 100644 --- a/examples/low_level_api/common.py +++ b/examples/low_level_api/common.py @@ -1,7 +1,6 @@ -import os import argparse +import os import re - from dataclasses import dataclass, field from typing import List @@ -38,7 +37,7 @@ class GptParams: path_session: str = "" input_prefix: str = " " input_suffix: str = "" - antiprompt: List[str] = field(default_factory=list) + antiprompt: list[str] = field(default_factory=list) lora_adapter: str = "" lora_base: str = "" @@ -76,7 +75,7 @@ class GptParams: def gpt_params_parse(argv=None): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "-s", @@ -103,7 +102,7 @@ def gpt_params_parse(argv=None): dest="n_predict", ) parser.add_argument( - "--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts" + "--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts", ) parser.add_argument( "-c", @@ -144,10 +143,10 @@ def gpt_params_parse(argv=None): dest="ignore_eos", ) parser.add_argument( - "--top_k", type=int, default=40, help="top-k sampling", dest="top_k" + "--top_k", type=int, default=40, help="top-k sampling", dest="top_k", ) parser.add_argument( - "--top_p", type=float, default=0.95, help="top-p samplin", dest="top_p" + "--top_p", type=float, default=0.95, help="top-p samplin", dest="top_p", ) parser.add_argument( "--tfs", @@ -157,7 +156,7 @@ def gpt_params_parse(argv=None): dest="tfs_z", ) parser.add_argument( - "--temp", type=float, default=0.80, help="temperature", dest="temp" + "--temp", type=float, default=0.80, help="temperature", dest="temp", ) parser.add_argument( "--repeat_penalty", @@ -218,7 +217,7 @@ def gpt_params_parse(argv=None): dest="model", ) parser.add_argument( - "-p", "--prompt", type=str, default=None, help="initial prompt", dest="prompt" + "-p", "--prompt", type=str, default=None, help="initial prompt", dest="prompt", ) parser.add_argument( "-f", @@ -243,7 +242,7 @@ def gpt_params_parse(argv=None): dest="input_prefix", ) parser.add_argument( - "--in-suffix", type=str, default="", help="append to input", dest="input_suffix" + "--in-suffix", type=str, default="", help="append to input", dest="input_suffix", ) parser.add_argument( "-r", @@ -378,7 +377,7 @@ def gpt_params_parse(argv=None): if params.lora_adapter: params.use_mmap = False - if logit_bias_str != None: + if logit_bias_str is not None: for i in logit_bias_str: if m := re.match(r"(\d+)([-+]\d+)", i): params.logit_bias[int(m.group(1))] = float(m.group(2)) diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index 39081be17..3cadea7a6 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -13,12 +13,13 @@ import ctypes import sys -from time import time from os import cpu_count, path +from time import time -import llama_cpp -from common import GptParams, gpt_params_parse, gpt_random_prompt import util +from common import GptParams, gpt_params_parse, gpt_random_prompt + +import llama_cpp # A LLaMA interactive session @@ -35,14 +36,14 @@ def __init__(self, params: GptParams) -> None: raise NotImplementedError( """************ please use the 
'perplexity' tool for perplexity calculations -************""" +************""", ) if self.params.embedding: raise NotImplementedError( """************ please use the 'embedding' tool for embedding calculations -************""" +************""", ) if self.params.n_ctx > 2048: @@ -80,7 +81,7 @@ def __init__(self, params: GptParams) -> None: self.lparams.use_mmap = self.params.use_mmap self.model = llama_cpp.llama_load_model_from_file( - self.params.model.encode("utf8"), self.lparams + self.params.model.encode("utf8"), self.lparams, ) # Context Params. @@ -174,7 +175,7 @@ def __init__(self, params: GptParams) -> None: file=sys.stderr, ) else: - print(f"session file does not exist, will create", file=sys.stderr) + print("session file does not exist, will create", file=sys.stderr) # tokenize the prompt self.embd = [] @@ -182,7 +183,7 @@ def __init__(self, params: GptParams) -> None: if len(self.embd_inp) > self.n_ctx - 4: raise RuntimeError( - f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})" + f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})", ) # debug message about similarity of saved session, if applicable @@ -197,18 +198,18 @@ def __init__(self, params: GptParams) -> None: self.n_matching_session_tokens += 1 if self.n_matching_session_tokens >= len(self.embd_inp): - print(f"session file has exact match for prompt!") + print("session file has exact match for prompt!") elif self.n_matching_session_tokens < (len(self.embd_inp) / 2): print( - f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated" + f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated", ) else: print( - f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt" + f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt", ) self.need_to_save_session = len( - self.params.path_session + self.params.path_session, ) > 0 and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4) # number of tokens to keep when resetting context @@ -423,7 +424,7 @@ def generate(self): self.ctx, self.params.path_session.encode("utf8"), (llama_cpp.llama_token * len(self.session_tokens))( - *self.session_tokens + *self.session_tokens, ), len(self.session_tokens), ) @@ -441,10 +442,10 @@ def generate(self): *[ llama_cpp.llama_token_data(token_id, logits[token_id], 0.0) for token_id in range(n_vocab) - ] + ], ) candidates_p = llama_cpp.ctypes.pointer( - llama_cpp.llama_token_data_array(_arr, len(_arr), False) + llama_cpp.llama_token_data_array(_arr, len(_arr), False), ) # Apply penalties @@ -452,7 +453,7 @@ def generate(self): last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx) _arr = (llama_cpp.llama_token * last_n_repeat)( - *self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat :] + *self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat :], ) llama_cpp.llama_sample_repetition_penalties( ctx=self.ctx, @@ -475,63 +476,62 @@ def generate(self): if self.params.temp <= 0: # Greedy sampling id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p) + elif self.params.mirostat == 1: + mirostat_mu = 2.0 * self.params.mirostat_tau + mirostat_m = 100 + llama_cpp.llama_sample_temperature( + self.ctx, candidates_p, llama_cpp.c_float(self.params.temp), + ) + id = 
llama_cpp.llama_sample_token_mirostat( + self.ctx, + candidates_p, + llama_cpp.c_float(self.params.mirostat_tau), + llama_cpp.c_float(self.params.mirostat_eta), + llama_cpp.c_int(mirostat_m), + llama_cpp.c_float(mirostat_mu), + ) + elif self.params.mirostat == 2: + mirostat_mu = 2.0 * self.params.mirostat_tau + llama_cpp.llama_sample_temperature( + self.ctx, candidates_p, llama_cpp.c_float(self.params.temp), + ) + id = llama_cpp.llama_sample_token_mirostat_v2( + self.ctx, + candidates_p, + llama_cpp.c_float(self.params.mirostat_tau), + llama_cpp.c_float(self.params.mirostat_eta), + llama_cpp.c_float(mirostat_mu), + ) else: - if self.params.mirostat == 1: - mirostat_mu = 2.0 * self.params.mirostat_tau - mirostat_m = 100 - llama_cpp.llama_sample_temperature( - self.ctx, candidates_p, llama_cpp.c_float(self.params.temp) - ) - id = llama_cpp.llama_sample_token_mirostat( - self.ctx, - candidates_p, - llama_cpp.c_float(self.params.mirostat_tau), - llama_cpp.c_float(self.params.mirostat_eta), - llama_cpp.c_int(mirostat_m), - llama_cpp.c_float(mirostat_mu), - ) - elif self.params.mirostat == 2: - mirostat_mu = 2.0 * self.params.mirostat_tau - llama_cpp.llama_sample_temperature( - self.ctx, candidates_p, llama_cpp.c_float(self.params.temp) - ) - id = llama_cpp.llama_sample_token_mirostat_v2( - self.ctx, - candidates_p, - llama_cpp.c_float(self.params.mirostat_tau), - llama_cpp.c_float(self.params.mirostat_eta), - llama_cpp.c_float(mirostat_mu), - ) - else: - # Temperature sampling - llama_cpp.llama_sample_top_k( - self.ctx, - candidates_p, - top_k, - min_keep=llama_cpp.c_size_t(1), - ) - llama_cpp.llama_sample_tail_free( - self.ctx, - candidates_p, - llama_cpp.c_float(self.params.tfs_z), - min_keep=llama_cpp.c_size_t(1), - ) - llama_cpp.llama_sample_typical( - self.ctx, - candidates_p, - llama_cpp.c_float(self.params.typical_p), - min_keep=llama_cpp.c_size_t(1), - ) - llama_cpp.llama_sample_top_p( - self.ctx, - candidates_p, - llama_cpp.c_float(self.params.top_p), - min_keep=llama_cpp.c_size_t(1), - ) - llama_cpp.llama_sample_temperature( - self.ctx, candidates_p, llama_cpp.c_float(self.params.temp) - ) - id = llama_cpp.llama_sample_token(self.ctx, candidates_p) + # Temperature sampling + llama_cpp.llama_sample_top_k( + self.ctx, + candidates_p, + top_k, + min_keep=llama_cpp.c_size_t(1), + ) + llama_cpp.llama_sample_tail_free( + self.ctx, + candidates_p, + llama_cpp.c_float(self.params.tfs_z), + min_keep=llama_cpp.c_size_t(1), + ) + llama_cpp.llama_sample_typical( + self.ctx, + candidates_p, + llama_cpp.c_float(self.params.typical_p), + min_keep=llama_cpp.c_size_t(1), + ) + llama_cpp.llama_sample_top_p( + self.ctx, + candidates_p, + llama_cpp.c_float(self.params.top_p), + min_keep=llama_cpp.c_size_t(1), + ) + llama_cpp.llama_sample_temperature( + self.ctx, candidates_p, llama_cpp.c_float(self.params.temp), + ) + id = llama_cpp.llama_sample_token(self.ctx, candidates_p) # print("`{}`".format(candidates_p.size)) self.last_n_tokens.pop(0) @@ -600,7 +600,7 @@ def generate(self): # end of text token if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos( - self.ctx + self.ctx, ): if not self.params.instruct: for i in self.llama_token_eot: @@ -636,7 +636,7 @@ def token_to_str(self, token_id: int) -> bytes: size = 32 buffer = (ctypes.c_char * size)() n = llama_cpp.llama_token_to_piece( - self.model, llama_cpp.llama_token(token_id), buffer, size + self.model, llama_cpp.llama_token(token_id), buffer, size, ) assert n <= size return bytes(buffer[:n]) @@ -668,7 +668,7 @@ def output(self): 
self.multibyte_fix[self.multibyte_fix.index(None)] = cur_char # Return completed utf char - if len(self.multibyte_fix) > 0 and not None in self.multibyte_fix: + if len(self.multibyte_fix) > 0 and None not in self.multibyte_fix: yield (b"".join(self.multibyte_fix)).decode("utf8") self.multibyte_fix = [] continue @@ -709,7 +709,7 @@ def interact(self): else: print(self.params.input_prefix, end="") self.input( - f"{self.params.input_prefix}{self.read_input()}{self.params.input_suffix}" + f"{self.params.input_prefix}{self.read_input()}{self.params.input_suffix}", ) print(self.params.input_suffix, end="") self.set_color(util.CONSOLE_COLOR_DEFAULT) diff --git a/examples/low_level_api/low_level_api_llama_cpp.py b/examples/low_level_api/low_level_api_llama_cpp.py index ba3545771..7bfcb8d89 100644 --- a/examples/low_level_api/low_level_api_llama_cpp.py +++ b/examples/low_level_api/low_level_api_llama_cpp.py @@ -1,6 +1,6 @@ import ctypes -import os import multiprocessing +import os import llama_cpp @@ -19,7 +19,7 @@ # determine the required inference memory per token: tmp = [0, 1, 2, 3] llama_cpp.llama_eval( - ctx=ctx, tokens=(llama_cpp.c_int * len(tmp))(*tmp), n_tokens=len(tmp), n_past=0 + ctx=ctx, tokens=(llama_cpp.c_int * len(tmp))(*tmp), n_tokens=len(tmp), n_past=0, ) # Deprecated n_past = 0 @@ -76,10 +76,10 @@ *[ llama_cpp.llama_token_data(token_id, logits[token_id], 0.0) for token_id in range(n_vocab) - ] + ], ) candidates_p = llama_cpp.ctypes.pointer( - llama_cpp.llama_token_data_array(_arr, len(_arr), False) + llama_cpp.llama_token_data_array(_arr, len(_arr), False), ) _arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data) @@ -114,7 +114,7 @@ size = 32 buffer = (ctypes.c_char * size)() n = llama_cpp.llama_token_to_piece( - model, llama_cpp.llama_token(id), buffer, size + model, llama_cpp.llama_token(id), buffer, size, ) assert n <= size print( diff --git a/examples/low_level_api/quantize.py b/examples/low_level_api/quantize.py index 057ac389e..b6c4fa7a9 100644 --- a/examples/low_level_api/quantize.py +++ b/examples/low_level_api/quantize.py @@ -1,5 +1,6 @@ -import os import argparse +import os + import llama_cpp diff --git a/examples/low_level_api/util.py b/examples/low_level_api/util.py index ef8b1c1ee..858bfa2be 100644 --- a/examples/low_level_api/util.py +++ b/examples/low_level_api/util.py @@ -45,14 +45,14 @@ def append(self, elem): def __getitem__(self, val): if isinstance(val, int): - if 0 > val or val >= self.size: + if val < 0 or val >= self.size: raise IndexError("Index out of range") return ( self.list[val] if self.size < self.maxsize else self.list[(self.offset + val) % self.maxsize] ) - elif isinstance(val, slice): + if isinstance(val, slice): start, stop, step = val.start, val.stop, val.step if step is None: step = 1 @@ -71,8 +71,7 @@ def __getitem__(self, val): for i in indices if i < self.size ] - else: - raise TypeError("Invalid argument type") + raise TypeError("Invalid argument type") if __name__ == "__main__": diff --git a/examples/notebooks/Batching.ipynb b/examples/notebooks/Batching.ipynb index 73b28c744..26ef05f17 100644 --- a/examples/notebooks/Batching.ipynb +++ b/examples/notebooks/Batching.ipynb @@ -567,7 +567,6 @@ } ], "source": [ - "import ctypes\n", "\n", "streams = [\"\"] * n_parallel\n", "i_batch = [batch.n_tokens - 1] * n_parallel\n", diff --git a/examples/notebooks/Functions.ipynb b/examples/notebooks/Functions.ipynb index 1f4138165..12438d900 100644 --- a/examples/notebooks/Functions.ipynb +++ b/examples/notebooks/Functions.ipynb @@ -40,9 
+40,9 @@ } ], "source": [ - "import openai\n", "import json\n", "\n", + "import openai\n", "\n", "client = openai.OpenAI(\n", " api_key=\"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\", # can be anything\n", @@ -56,14 +56,13 @@ " \"\"\"Get the current weather in a given location\"\"\"\n", " if \"tokyo\" in location.lower():\n", " return json.dumps({\"location\": \"Tokyo\", \"temperature\": \"10\", \"unit\": \"celsius\"})\n", - " elif \"san francisco\" in location.lower():\n", + " if \"san francisco\" in location.lower():\n", " return json.dumps(\n", " {\"location\": \"San Francisco\", \"temperature\": \"72\", \"unit\": \"fahrenheit\"}\n", " )\n", - " elif \"paris\" in location.lower():\n", + " if \"paris\" in location.lower():\n", " return json.dumps({\"location\": \"Paris\", \"temperature\": \"22\", \"unit\": \"celsius\"})\n", - " else:\n", - " return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n", + " return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n", "\n", "\n", "def run_conversation():\n", @@ -298,7 +297,7 @@ " Class for a multi-class label prediction.\n", " \"\"\"\n", "\n", - " class_labels: List[MultiLabels]\n", + " class_labels: list[MultiLabels]\n", "\n", "\n", "def multi_classify(data: str) -> MultiClassPrediction:\n", @@ -346,11 +345,10 @@ } ], "source": [ - "from typing_extensions import Annotated\n", - "from pydantic import BaseModel, BeforeValidator\n", + "from typing import Annotated\n", "\n", "from instructor import llm_validator\n", - "\n", + "from pydantic import BaseModel, BeforeValidator\n", "\n", "question = \"What is the meaning of life?\"\n", "context = \"The according to the devil the meaning of live is to live a life of sin and debauchery.\"\n", @@ -429,14 +427,13 @@ ], "source": [ "import re\n", - "from typing import List\n", "\n", - "from pydantic import Field, BaseModel, model_validator, FieldValidationInfo\n", + "from pydantic import BaseModel, Field, FieldValidationInfo, model_validator\n", "\n", "\n", "class Fact(BaseModel):\n", " fact: str = Field(...)\n", - " substring_quote: List[str] = Field(...)\n", + " substring_quote: list[str] = Field(...)\n", "\n", " @model_validator(mode=\"after\")\n", " def validate_sources(self, info: FieldValidationInfo) -> \"Fact\":\n", @@ -456,7 +453,7 @@ "\n", "class QuestionAnswer(BaseModel):\n", " question: str = Field(...)\n", - " answer: List[Fact] = Field(...)\n", + " answer: list[Fact] = Field(...)\n", "\n", " @model_validator(mode=\"after\")\n", " def validate_sources(self) -> \"QuestionAnswer\":\n", diff --git a/examples/notebooks/OpenHermesFunctionCalling.ipynb b/examples/notebooks/OpenHermesFunctionCalling.ipynb index 13128be04..e4d9366aa 100644 --- a/examples/notebooks/OpenHermesFunctionCalling.ipynb +++ b/examples/notebooks/OpenHermesFunctionCalling.ipynb @@ -38,8 +38,8 @@ } ], "source": [ - "import json\n", "import inspect\n", + "import json\n", "from typing import get_type_hints\n", "\n", "\n", @@ -61,7 +61,6 @@ " \"\"\"Get the monthly mortgage payment given an interest rate percentage.\"\"\"\n", "\n", " # TODO: you must implement this to actually call it later\n", - " pass\n", "\n", "\n", "def get_article_details(\n", @@ -75,14 +74,12 @@ " date_published: formatted as \"MM/DD/YYYY\"'''\n", "\n", " # TODO: you must implement this to actually call it later\n", - " pass\n", "\n", "\n", "def get_weather(zip_code: str) -> Weather:\n", " \"\"\"Get the current weather given a zip code.\"\"\"\n", "\n", " # TODO: you must implement this to actually call it later\n", - " 
pass\n", "\n", "\n", "def get_directions(start: str, destination: str) -> Directions:\n", @@ -91,15 +88,13 @@ " destination: end address as a string including zipcode (if any)\"\"\"\n", "\n", " # TODO: you must implement this to actually call it later\n", - " pass\n", "\n", "\n", "def get_type_name(t):\n", " name = str(t)\n", " if \"list\" in name or \"dict\" in name:\n", " return name\n", - " else:\n", - " return t.__name__\n", + " return t.__name__\n", "\n", "\n", "def serialize_function_to_json(func):\n", @@ -129,8 +124,8 @@ "metadata": {}, "outputs": [], "source": [ - "import xml.etree.ElementTree as ET\n", "import re\n", + "import xml.etree.ElementTree as ET\n", "\n", "\n", "def extract_function_calls(completion):\n", diff --git a/examples/notebooks/PerformanceTuning.ipynb b/examples/notebooks/PerformanceTuning.ipynb index ba74e4a41..04c1fb1d2 100644 --- a/examples/notebooks/PerformanceTuning.ipynb +++ b/examples/notebooks/PerformanceTuning.ipynb @@ -6,18 +6,17 @@ "metadata": {}, "outputs": [], "source": [ - "import time\n", "import json\n", "import multiprocessing\n", - "\n", - "import llama_cpp\n", + "import time\n", "\n", "import numpy as np\n", "\n", - "np.int = int\n", + "import llama_cpp\n", "\n", - "from skopt.space import Integer, Categorical\n", + "np.int = int\n", "\n", + "from skopt.space import Categorical, Integer\n", "\n", "MODEL_PATH = \"../models/ggml-model.bin\"\n", "\n", diff --git a/examples/ray/llm.py b/examples/ray/llm.py index 2325dd303..21b3a4ffd 100755 --- a/examples/ray/llm.py +++ b/examples/ray/llm.py @@ -1,7 +1,9 @@ -from starlette.requests import Request from typing import Dict + from ray import serve from ray.serve import Application +from starlette.requests import Request + from llama_cpp import Llama @@ -10,12 +12,12 @@ class LlamaDeployment: def __init__(self, model_path: str): self._llm = Llama(model_path=model_path) - async def __call__(self, http_request: Request) -> Dict: + async def __call__(self, http_request: Request) -> dict: input_json = await http_request.json() prompt = input_json["prompt"] max_tokens = input_json.get("max_tokens", 64) return self._llm(prompt, max_tokens=max_tokens) -def llm_builder(args: Dict[str, str]) -> Application: +def llm_builder(args: dict[str, str]) -> Application: return LlamaDeployment.bind(args["model_path"]) diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index db1f1d0fa..cf0720a8f 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ -from .llama_cpp import * -from .llama import * +from .llama_cpp import * # noqa +from .llama import * # noqa __version__ = "0.2.90" diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index 0aff34844..9d0a39d3a 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -1,27 +1,25 @@ from __future__ import annotations -import os import ctypes - +import os +from collections.abc import Sequence +from contextlib import ExitStack +from dataclasses import dataclass, field from typing import ( Dict, List, Tuple, Optional, - Sequence, ) -from dataclasses import dataclass, field -from contextlib import ExitStack import numpy as np import numpy.typing as npt -from .llama_types import * -from .llama_grammar import LlamaGrammar -from ._utils import suppress_stdout_stderr - -import llama_cpp.llama_cpp as llama_cpp +from llama_cpp import llama_cpp +from ._utils import suppress_stdout_stderr +from .llama_grammar import LlamaGrammar +from .llama_types import * # Python wrappers over llama.h structs @@ -152,17 +150,17 @@ def tokenize(self,
text: bytes, add_bos: bool, special: bool): n_ctx = self.n_ctx_train() tokens = (llama_cpp.llama_token * n_ctx)() n_tokens = llama_cpp.llama_tokenize( - self.model, text, len(text), tokens, n_ctx, add_bos, special + self.model, text, len(text), tokens, n_ctx, add_bos, special, ) if n_tokens < 0: n_tokens = abs(n_tokens) tokens = (llama_cpp.llama_token * n_tokens)() n_tokens = llama_cpp.llama_tokenize( - self.model, text, len(text), tokens, n_tokens, add_bos, special + self.model, text, len(text), tokens, n_tokens, add_bos, special, ) if n_tokens < 0: raise RuntimeError( - f'Failed to tokenize: text="{text}" n_tokens={n_tokens}' + f'Failed to tokenize: text="{text}" n_tokens={n_tokens}', ) return list(tokens[:n_tokens]) @@ -177,7 +175,7 @@ def detokenize(self, tokens: List[int], special: bool = False) -> bytes: buffer = (ctypes.c_char * size)() for token in tokens: n = llama_cpp.llama_token_to_piece( - self.model, llama_cpp.llama_token(token), buffer, size, 0, special + self.model, llama_cpp.llama_token(token), buffer, size, 0, special, ) assert n <= size output += bytes(buffer[:n]) @@ -199,23 +197,23 @@ def metadata(self) -> Dict[str, str]: # iterate over model keys for i in range(llama_cpp.llama_model_meta_count(self.model)): nbytes = llama_cpp.llama_model_meta_key_by_index( - self.model, i, buffer, buffer_size + self.model, i, buffer, buffer_size, ) if nbytes > buffer_size: buffer_size = nbytes + 1 buffer = ctypes.create_string_buffer(buffer_size) nbytes = llama_cpp.llama_model_meta_key_by_index( - self.model, i, buffer, buffer_size + self.model, i, buffer, buffer_size, ) key = buffer.value.decode("utf-8") nbytes = llama_cpp.llama_model_meta_val_str_by_index( - self.model, i, buffer, buffer_size + self.model, i, buffer, buffer_size, ) if nbytes > buffer_size: buffer_size = nbytes + 1 buffer = ctypes.create_string_buffer(buffer_size) nbytes = llama_cpp.llama_model_meta_val_str_by_index( - self.model, i, buffer, buffer_size + self.model, i, buffer, buffer_size, ) value = buffer.value.decode("utf-8") metadata[key] = value @@ -324,8 +322,8 @@ def set_rng_seed(self, seed: int): def sample_repetition_penalties( self, - candidates: "_LlamaTokenDataArray", - last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]", + candidates: _LlamaTokenDataArray, + last_tokens_data: llama_cpp.Array[llama_cpp.llama_token], penalty_last_n: int, penalty_repeat: float, penalty_freq: float, @@ -349,36 +347,36 @@ def sample_softmax(self, candidates: "_LlamaTokenDataArray"): def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int): llama_cpp.llama_sample_top_k( - self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep + self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep, ) def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): llama_cpp.llama_sample_top_p( - self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep + self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep, ) def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): llama_cpp.llama_sample_min_p( - self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep + self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep, ) def sample_tail_free( - self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int + self, candidates: _LlamaTokenDataArray, z: float, min_keep: int, ): llama_cpp.llama_sample_tail_free( - self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep + self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep, 
) def sample_typical( - self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int + self, candidates: _LlamaTokenDataArray, p: float, min_keep: int, ): llama_cpp.llama_sample_typical( - self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep + self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep, ) def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float): llama_cpp.llama_sample_temp( - self.ctx, llama_cpp.byref(candidates.candidates), temp + self.ctx, llama_cpp.byref(candidates.candidates), temp, ) def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar): @@ -390,7 +388,7 @@ def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGramm def sample_token_mirostat( self, - candidates: "_LlamaTokenDataArray", + candidates: _LlamaTokenDataArray, tau: float, eta: float, m: int, @@ -407,7 +405,7 @@ def sample_token_mirostat( def sample_token_mirostat_v2( self, - candidates: "_LlamaTokenDataArray", + candidates: _LlamaTokenDataArray, tau: float, eta: float, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float], @@ -451,7 +449,7 @@ def default_params(): class LlamaBatch: def __init__( - self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True + self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True, ): self._n_tokens = n_tokens self.embd = embd @@ -517,7 +515,7 @@ def __init__(self, *, n_vocab: int): self.candidates_data = np.recarray( (self.n_vocab,), dtype=np.dtype( - [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True + [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True, ), ) self.candidates = llama_cpp.llama_token_data_array( @@ -580,7 +578,7 @@ class LlamaSamplingParams: class LlamaSamplingContext: params: LlamaSamplingParams = field(default_factory=LlamaSamplingParams) mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float) - grammar: Optional[LlamaGrammar] = None + grammar: LlamaGrammar | None = None # NOTE: Missing parsed_grammar prev: list[int] = field(default_factory=list) cur: list[llama_cpp.llama_token_data] = field(default_factory=list) @@ -600,11 +598,11 @@ def cp(self): cur=self.cur.copy(), ) - def last(self) -> Optional[int]: + def last(self) -> int | None: if len(self.prev) > 0: return self.prev[-1] - else: - return None + + return None def prev_str(self, ctx_main: LlamaContext, n: int) -> str: return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8") @@ -613,7 +611,7 @@ def sample( self, ctx_main: LlamaContext, idx: int = 0, - logits_array: Optional[npt.NDArray[np.single]] = None, + logits_array: npt.NDArray[np.single] | None = None, ): n_vocab = ctx_main.model.n_vocab() id: int = 0 @@ -661,44 +659,43 @@ def sample( id = token_data_array.candidates_data.id[0] elif self.params.temp == 0: id = ctx_main.sample_token_greedy(token_data_array) + elif self.params.mirostat == 1: + mirostat_m = 100 + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token_mirostat( + token_data_array, + self.params.mirostat_tau, + self.params.mirostat_eta, + mirostat_m, + ctypes.pointer(self.mirostat_mu), + ) + elif self.params.mirostat == 2: + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token_mirostat_v2( + token_data_array, + self.params.mirostat_tau, + self.params.mirostat_eta, + ctypes.pointer(self.mirostat_mu), + ) else: - if self.params.mirostat == 1: - mirostat_m = 100 - ctx_main.sample_temp(token_data_array, self.params.temp) - id = ctx_main.sample_token_mirostat( - 
token_data_array, - self.params.mirostat_tau, - self.params.mirostat_eta, - mirostat_m, - ctypes.pointer(self.mirostat_mu), - ) - elif self.params.mirostat == 2: - ctx_main.sample_temp(token_data_array, self.params.temp) - id = ctx_main.sample_token_mirostat_v2( - token_data_array, - self.params.mirostat_tau, - self.params.mirostat_eta, - ctypes.pointer(self.mirostat_mu), - ) - else: - min_keep = max(1, self.params.n_probs) - ctx_main.sample_top_k( - token_data_array, self.params.top_k, min_keep=min_keep - ) - ctx_main.sample_tail_free( - token_data_array, self.params.tfs_z, min_keep=min_keep - ) - ctx_main.sample_typical( - token_data_array, self.params.typical_p, min_keep=min_keep - ) - ctx_main.sample_top_p( - token_data_array, self.params.top_p, min_keep=min_keep - ) - ctx_main.sample_min_p( - token_data_array, self.params.min_p, min_keep=min_keep - ) - ctx_main.sample_temp(token_data_array, self.params.temp) - id = ctx_main.sample_token(token_data_array) + min_keep = max(1, self.params.n_probs) + ctx_main.sample_top_k( + token_data_array, self.params.top_k, min_keep=min_keep, + ) + ctx_main.sample_tail_free( + token_data_array, self.params.tfs_z, min_keep=min_keep, + ) + ctx_main.sample_typical( + token_data_array, self.params.typical_p, min_keep=min_keep, + ) + ctx_main.sample_top_p( + token_data_array, self.params.top_p, min_keep=min_keep, + ) + ctx_main.sample_min_p( + token_data_array, self.params.min_p, min_keep=min_keep, + ) + ctx_main.sample_temp(token_data_array, self.params.temp) + id = ctx_main.sample_token(token_data_array) return id def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool): diff --git a/llama_cpp/_logger.py b/llama_cpp/_logger.py index 157af692f..b2fd1f8cf 100644 --- a/llama_cpp/_logger.py +++ b/llama_cpp/_logger.py @@ -1,6 +1,6 @@ -import sys import ctypes import logging +import sys import llama_cpp diff --git a/llama_cpp/_utils.py b/llama_cpp/_utils.py index 29628193b..0e1d6e78e 100644 --- a/llama_cpp/_utils.py +++ b/llama_cpp/_utils.py @@ -1,6 +1,5 @@ import os import sys - from typing import Any, Dict # Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor @@ -11,7 +10,7 @@ STDERR_FILENO = 2 -class suppress_stdout_stderr(object): +class suppress_stdout_stderr: # NOTE: these must be "saved" here to avoid exceptions when using # this context manager inside of a __del__ method sys = sys @@ -61,7 +60,7 @@ class MetaSingleton(type): Metaclass for implementing the Singleton pattern. """ - _instances: Dict[type, Any] = {} + _instances: dict[type, Any] = {} def __call__(cls, *args: Any, **kwargs: Any) -> Any: if cls not in cls._instances: @@ -69,7 +68,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: return cls._instances[cls] -class Singleton(object, metaclass=MetaSingleton): +class Singleton(metaclass=MetaSingleton): """ Base class for implementing the Singleton pattern. 
""" diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index ae712dedb..0da0dcb7c 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1,48 +1,32 @@ from __future__ import annotations +import contextlib +import ctypes +import fnmatch +import json +import multiprocessing import os import sys -import uuid import time -import json -import ctypes import typing import random import fnmatch import warnings -import contextlib -import multiprocessing - +from collections import deque +from pathlib import Path from typing import ( Any, + Callable, + Deque, + Dict, + Generator, + Iterator, List, Literal, Optional, - Union, - Generator, Sequence, - Iterator, - Deque, - Callable, - Dict, -) -from collections import deque -from pathlib import Path - - -from .llama_types import * -from .llama_grammar import LlamaGrammar -from .llama_cache import ( - BaseLlamaCache, - LlamaCache, # type: ignore - LlamaDiskCache, # type: ignore - LlamaRAMCache, # type: ignore + Union, ) -from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer -import llama_cpp.llama_cpp as llama_cpp -import llama_cpp.llama_chat_format as llama_chat_format - -from llama_cpp.llama_speculative import LlamaDraftModel import numpy as np import numpy.typing as npt @@ -50,6 +34,12 @@ import llama_cpp._internals as internals from ._logger import set_verbose from ._utils import suppress_stdout_stderr +from .llama_cache import ( + BaseLlamaCache, # type: ignore +) +from .llama_grammar import LlamaGrammar +from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer +from .llama_types import * class Llama: @@ -882,6 +872,8 @@ def generate( else: break if longest_prefix > 0: + if self.verbose: + print("Llama.generate: prefix-match hit", file=sys.stderr) reset = False tokens = tokens[longest_prefix:] self.n_tokens = longest_prefix @@ -1146,7 +1138,7 @@ def _create_completion( ]: assert suffix is None or suffix.__class__ is str - completion_id: str = f"cmpl-{str(uuid.uuid4())}" + completion_id: str = f"cmpl-{uuid.uuid4()!s}" created: int = int(time.time()) bos_token_id: int = self.token_bos() cls_token_id: int = self._model.token_cls() @@ -1414,12 +1406,10 @@ def logit_bias_processor( token_offset = len(prompt_tokens) + returned_tokens logits = self._scores[token_offset - 1, :] current_logprobs = Llama.logits_to_logprobs(logits).tolist() - sorted_logprobs = list( - sorted( + sorted_logprobs = sorted( zip(current_logprobs, range(len(current_logprobs))), reverse=True, ) - ) top_logprob = { self.detokenize([i]).decode( "utf-8", errors="ignore" @@ -1553,12 +1543,10 @@ def logit_bias_processor( token_offset = len(prompt_tokens) + returned_tokens - 1 logits = self._scores[token_offset, :] current_logprobs = Llama.logits_to_logprobs(logits).tolist() - sorted_logprobs = list( - sorted( + sorted_logprobs = sorted( zip(current_logprobs, range(len(current_logprobs))), reverse=True, ) - ) top_logprob = { self.detokenize([i]).decode("utf-8", errors="ignore"): logprob for logprob, i in sorted_logprobs[:logprobs] @@ -1630,8 +1618,7 @@ def logit_bias_processor( if self.verbose: print("Llama._create_completion: cache save", file=sys.stderr) self.cache[prompt_tokens + completion_tokens] = self.save_state() - if self.verbose: - print("Llama._create_completion: cache saved", file=sys.stderr) + print("Llama._create_completion: cache saved", file=sys.stderr) return if self.cache: @@ -1687,11 +1674,9 @@ def logit_bias_processor( ) ) tokens.append(token_str) - sorted_logprobs = list( - sorted( + sorted_logprobs = sorted( zip(logprobs_token, 
range(len(logprobs_token))), reverse=True ) - ) token_logprobs.append(logprobs_token[int(token)]) top_logprob: Optional[Dict[str, float]] = { self.detokenize([i], prev_tokens=all_tokens[:idx]).decode( @@ -2237,7 +2222,7 @@ def from_pretrained( local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", cache_dir: Optional[Union[str, os.PathLike[str]]] = None, **kwargs: Any, - ) -> "Llama": + ) -> Llama: """Create a Llama model from a pretrained model name or path. This method requires the huggingface-hub package. You can install it with `pip install huggingface-hub`. @@ -2253,7 +2238,7 @@ def from_pretrained( Returns: A Llama model.""" try: - from huggingface_hub import hf_hub_download, HfFileSystem + from huggingface_hub import HfFileSystem, hf_hub_download from huggingface_hub.utils import validate_repo_id except ImportError: raise ImportError( @@ -2267,7 +2252,7 @@ def from_pretrained( files = [ file["name"] if isinstance(file, dict) else file - for file in hffs.ls(repo_id, recursive=True) + for file in hffs.ls(repo_id) ] # split each file into repo_id, subfolder, filename diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py index e059e98e1..3046fdea0 100644 --- a/llama_cpp/llama_cache.py +++ b/llama_cpp/llama_cache.py @@ -1,11 +1,11 @@ import sys from abc import ABC, abstractmethod +from collections import OrderedDict from typing import ( Optional, Sequence, Tuple, ) -from collections import OrderedDict import diskcache @@ -41,7 +41,7 @@ def __contains__(self, key: Sequence[int]) -> bool: @abstractmethod def __setitem__( - self, key: Sequence[int], value: "llama_cpp.llama.LlamaState" + self, key: Sequence[int], value: "llama_cpp.llama.LlamaState", ) -> None: raise NotImplementedError @@ -105,7 +105,7 @@ class LlamaDiskCache(BaseLlamaCache): """Cache for a llama.cpp model using disk.""" def __init__( - self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30) + self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30), ): super().__init__(capacity_bytes) self.cache = diskcache.Cache(cache_dir) @@ -132,7 +132,7 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": _key = self._find_longest_prefix_key(key) if _key is None: raise KeyError("Key not found") - value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key) # type: ignore + value: llama_cpp.llama.LlamaState = self.cache.pop(_key) # type: ignore # NOTE: This puts an integer as key in cache, which breaks, # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens # self.cache.push(_key, side="front") # type: ignore diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index dfb0af65e..5e231e966 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -1,39 +1,35 @@ from __future__ import annotations -import os -import sys -import json import ctypes import dataclasses +import json +import os import random import string - +import sys +from collections.abc import Iterator from contextlib import ExitStack from typing import ( Any, Dict, - Iterator, List, Literal, Optional, + Protocol, Tuple, Union, - Protocol, cast, ) import jinja2 -from jinja2.sandbox import ImmutableSandboxedEnvironment - import numpy as np import numpy.typing as npt +from jinja2.sandbox import ImmutableSandboxedEnvironment -import llama_cpp.llama as llama -import llama_cpp.llama_types as llama_types -import llama_cpp.llama_grammar as llama_grammar +from llama_cpp import llama, llama_grammar, llama_types from ._logger 
import logger -from ._utils import suppress_stdout_stderr, Singleton +from ._utils import Singleton, suppress_stdout_stderr ### Common Chat Templates and Special Tokens ### @@ -69,26 +65,24 @@ def __call__( # llama.cpp instance llama: llama.Llama, # openai api parameters - messages: List[llama_types.ChatCompletionRequestMessage], - functions: Optional[List[llama_types.ChatCompletionFunction]] = None, - function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, - tools: Optional[List[llama_types.ChatCompletionTool]] = None, - tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + messages: list[llama_types.ChatCompletionRequestMessage], + functions: list[llama_types.ChatCompletionFunction] | None = None, + function_call: llama_types.ChatCompletionRequestFunctionCall | None = None, + tools: list[llama_types.ChatCompletionTool] | None = None, + tool_choice: llama_types.ChatCompletionToolChoiceOption | None = None, temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - seed: Optional[int] = None, - response_format: Optional[ - llama_types.ChatCompletionRequestResponseFormat - ] = None, - max_tokens: Optional[int] = None, + stop: str | list[str] | None = [], + seed: int | None = None, + response_format: llama_types.ChatCompletionRequestResponseFormat | None = None, + max_tokens: int | None = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, - model: Optional[str] = None, - logit_bias: Optional[Dict[str, float]] = None, + model: str | None = None, + logit_bias: dict[str, float] | None = None, # llama.cpp parameters min_p: float = 0.05, typical_p: float = 1.0, @@ -96,15 +90,12 @@ def __call__( mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, - logits_processor: Optional[llama.LogitsProcessorList] = None, - grammar: Optional[llama.LlamaGrammar] = None, - logprobs: Optional[bool] = None, - top_logprobs: Optional[int] = None, + logits_processor: llama.LogitsProcessorList | None = None, + grammar: llama.LlamaGrammar | None = None, + logprobs: bool | None = None, + top_logprobs: int | None = None, **kwargs, # type: ignore - ) -> Union[ - llama_types.CreateChatCompletionResponse, - Iterator[llama_types.CreateChatCompletionStreamResponse], - ]: ... + ) -> llama_types.CreateChatCompletionResponse | Iterator[llama_types.CreateChatCompletionStreamResponse]: ... class LlamaChatCompletionHandlerNotFoundException(Exception): @@ -112,7 +103,7 @@ class LlamaChatCompletionHandlerNotFoundException(Exception): class LlamaChatCompletionHandlerRegistry(Singleton): - _chat_handlers: Dict[str, LlamaChatCompletionHandler] = {} + _chat_handlers: dict[str, LlamaChatCompletionHandler] = {} def register_chat_completion_handler( self, @@ -122,7 +113,7 @@ def register_chat_completion_handler( ): if not overwrite and name in self._chat_handlers: raise ValueError( - f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it." + f"Formatter with name '{name}' is already registered. 
Use `overwrite=True` to overwrite it.", ) self._chat_handlers[name] = chat_handler @@ -133,20 +124,20 @@ def unregister_chat_handler(self, name: str): raise ValueError(f"No formatter registered under the name '{name}'.") def get_chat_completion_handler_by_name( - self, name: str + self, name: str, ) -> LlamaChatCompletionHandler: try: chat_handler = self._chat_handlers[name] return chat_handler except KeyError: raise LlamaChatCompletionHandlerNotFoundException( - f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})" + f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})", ) def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler: return LlamaChatCompletionHandlerRegistry().get_chat_completion_handler_by_name( - name + name, ) @@ -170,8 +161,8 @@ class ChatFormatterResponse: stop contains the stop token or list of stop tokens to use for the chat format.""" prompt: str - stop: Optional[Union[str, List[str]]] = None - stopping_criteria: Optional[llama.StoppingCriteriaList] = None + stop: str | list[str] | None = None + stopping_criteria: llama.StoppingCriteriaList | None = None added_special: bool = False @@ -184,7 +175,7 @@ class ChatFormatter(Protocol): def __call__( self, *, - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: ... @@ -196,7 +187,7 @@ def __init__( eos_token: str, bos_token: str, add_generation_prompt: bool = True, - stop_token_ids: Optional[List[int]] = None, + stop_token_ids: list[int] | None = None, ): """A chat formatter that uses jinja2 templates to format the prompt.""" self.template = template @@ -216,11 +207,11 @@ def __init__( def __call__( self, *, - messages: List[llama_types.ChatCompletionRequestMessage], - functions: Optional[List[llama_types.ChatCompletionFunction]] = None, - function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, - tools: Optional[List[llama_types.ChatCompletionTool]] = None, - tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + messages: list[llama_types.ChatCompletionRequestMessage], + functions: list[llama_types.ChatCompletionFunction] | None = None, + function_call: llama_types.ChatCompletionRequestFunctionCall | None = None, + tools: list[llama_types.ChatCompletionTool] | None = None, + tool_choice: llama_types.ChatCompletionToolChoiceOption | None = None, **kwargs: Any, ) -> ChatFormatterResponse: def raise_exception(message: str): @@ -242,7 +233,7 @@ def raise_exception(message: str): if self.stop_token_ids is not None: def stop_on_last_token( - tokens: npt.NDArray[np.intc], logits: npt.NDArray[np.single] + tokens: npt.NDArray[np.intc], logits: npt.NDArray[np.single], ) -> bool: return tokens[-1] in self.stop_token_ids @@ -277,7 +268,7 @@ def _convert_text_completion_to_chat( }, "logprobs": completion["choices"][0]["logprobs"], "finish_reason": completion["choices"][0]["finish_reason"], - } + }, ], "usage": completion["usage"], } @@ -301,7 +292,7 @@ def _convert_text_completion_chunks_to_chat( }, "logprobs": None, "finish_reason": None, - } + }, ], } yield { @@ -321,34 +312,25 @@ def _convert_text_completion_chunks_to_chat( ), "logprobs": chunk["choices"][0]["logprobs"], "finish_reason": chunk["choices"][0]["finish_reason"], - } + }, ], } def _convert_completion_to_chat( - completion_or_chunks: Union[ - llama_types.CreateCompletionResponse, - Iterator[llama_types.CreateCompletionStreamResponse], - ], + 
completion_or_chunks: llama_types.CreateCompletionResponse | Iterator[llama_types.CreateCompletionStreamResponse], stream: bool = False, -) -> Union[ - llama_types.CreateChatCompletionResponse, Iterator[llama_types.ChatCompletionChunk] -]: +) -> llama_types.CreateChatCompletionResponse | Iterator[llama_types.ChatCompletionChunk]: if stream: chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks # type: ignore return _convert_text_completion_chunks_to_chat(chunks) - else: - completion: llama_types.Completion = completion_or_chunks # type: ignore - return _convert_text_completion_to_chat(completion) + completion: llama_types.Completion = completion_or_chunks # type: ignore + return _convert_text_completion_to_chat(completion) def _convert_completion_to_chat_function( tool_name: str, - completion_or_chunks: Union[ - llama_types.CreateCompletionResponse, - Iterator[llama_types.CreateCompletionStreamResponse], - ], + completion_or_chunks: llama_types.CreateCompletionResponse | Iterator[llama_types.CreateCompletionStreamResponse], stream: bool, ): if not stream: @@ -379,90 +361,52 @@ def _convert_completion_to_chat_function( "name": tool_name, "arguments": completion["choices"][0]["text"], }, - } + }, ], }, "logprobs": completion["choices"][0]["logprobs"], "finish_reason": "tool_calls", - } + }, ], "usage": completion["usage"], } return chat_completion - else: - chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks # type: ignore - - def _stream_response_to_function_stream( - chunks: Iterator[llama_types.CreateCompletionStreamResponse], - ) -> Iterator[llama_types.CreateChatCompletionStreamResponse]: - # blank first message - first = True - id_ = None - created = None - model = None - tool_id = None - for chunk in chunks: - if first: - id_ = "chat" + chunk["id"] - created = chunk["created"] - model = chunk["model"] - tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"] - yield { - "id": id_, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "finish_reason": None, - "logprobs": None, - "delta": { - "role": "assistant", - "content": None, - "function_call": None, - "tool_calls": None, - }, - } - ], - } - yield { - "id": "chat" + chunk["id"], - "object": "chat.completion.chunk", - "created": chunk["created"], - "model": chunk["model"], - "choices": [ - { - "index": 0, - "finish_reason": None, - "logprobs": chunk["choices"][0]["logprobs"], - "delta": { - "role": None, - "content": None, - "function_call": { - "name": tool_name, - "arguments": chunk["choices"][0]["text"], - }, - "tool_calls": [ - { - "index": 0, - "id": tool_id, - "type": "function", - "function": { - "name": tool_name, - "arguments": chunk["choices"][0][ - "text" - ], - }, - } - ], - }, - } - ], - } - first = False - continue - assert tool_id is not None + chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks # type: ignore + + def _stream_response_to_function_stream( + chunks: Iterator[llama_types.CreateCompletionStreamResponse], + ) -> Iterator[llama_types.CreateChatCompletionStreamResponse]: + # blank first message + first = True + id_ = None + created = None + model = None + tool_id = None + for chunk in chunks: + if first: + id_ = "chat" + chunk["id"] + created = chunk["created"] + model = chunk["model"] + tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"] + yield { + "id": id_, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": 
[ + { + "index": 0, + "finish_reason": None, + "logprobs": None, + "delta": { + "role": "assistant", + "content": None, + "function_call": None, + "tool_calls": None, + }, + }, + ], + } yield { "id": "chat" + chunk["id"], "object": "chat.completion.chunk", @@ -487,37 +431,74 @@ def _stream_response_to_function_stream( "type": "function", "function": { "name": tool_name, - "arguments": chunk["choices"][0]["text"], + "arguments": chunk["choices"][0][ + "text" + ], }, - } + }, ], }, - } + }, ], } - - if id_ is not None and created is not None and model is not None: - yield { - "id": id_, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "finish_reason": "tool_calls", - "logprobs": None, - "delta": { - "role": None, - "content": None, - "function_call": None, - "tool_calls": None, + first = False + continue + assert tool_id is not None + yield { + "id": "chat" + chunk["id"], + "object": "chat.completion.chunk", + "created": chunk["created"], + "model": chunk["model"], + "choices": [ + { + "index": 0, + "finish_reason": None, + "logprobs": chunk["choices"][0]["logprobs"], + "delta": { + "role": None, + "content": None, + "function_call": { + "name": tool_name, + "arguments": chunk["choices"][0]["text"], }, - } - ], - } + "tool_calls": [ + { + "index": 0, + "id": tool_id, + "type": "function", + "function": { + "name": tool_name, + "arguments": chunk["choices"][0]["text"], + }, + }, + ], + }, + }, + ], + } + + if id_ is not None and created is not None and model is not None: + yield { + "id": id_, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "finish_reason": "tool_calls", + "logprobs": None, + "delta": { + "role": None, + "content": None, + "function_call": None, + "tool_calls": None, + }, + }, + ], + } - return _stream_response_to_function_stream(chunks) + return _stream_response_to_function_stream(chunks) def chat_formatter_to_chat_completion_handler( @@ -526,23 +507,21 @@ def chat_formatter_to_chat_completion_handler( def chat_completion_handler( *, llama: llama.Llama, - messages: List[llama_types.ChatCompletionRequestMessage], - functions: Optional[List[llama_types.ChatCompletionFunction]] = None, - function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, - tools: Optional[List[llama_types.ChatCompletionTool]] = None, - tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + messages: list[llama_types.ChatCompletionRequestMessage], + functions: list[llama_types.ChatCompletionFunction] | None = None, + function_call: llama_types.ChatCompletionRequestFunctionCall | None = None, + tools: list[llama_types.ChatCompletionTool] | None = None, + tool_choice: llama_types.ChatCompletionToolChoiceOption | None = None, temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, min_p: float = 0.05, typical_p: float = 1.0, stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - seed: Optional[int] = None, - response_format: Optional[ - llama_types.ChatCompletionRequestResponseFormat - ] = None, - max_tokens: Optional[int] = None, + stop: str | list[str] | None = [], + seed: int | None = None, + response_format: llama_types.ChatCompletionRequestResponseFormat | None = None, + max_tokens: int | None = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, @@ -550,17 +529,14 @@ def chat_completion_handler( mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 
0.1, - model: Optional[str] = None, - logits_processor: Optional[llama.LogitsProcessorList] = None, - grammar: Optional[llama.LlamaGrammar] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - top_logprobs: Optional[int] = None, + model: str | None = None, + logits_processor: llama.LogitsProcessorList | None = None, + grammar: llama.LlamaGrammar | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + top_logprobs: int | None = None, **kwargs, # type: ignore - ) -> Union[ - llama_types.CreateChatCompletionResponse, - Iterator[llama_types.CreateChatCompletionStreamResponse], - ]: + ) -> llama_types.CreateChatCompletionResponse | Iterator[llama_types.CreateChatCompletionStreamResponse]: result = chat_formatter( messages=messages, functions=functions, @@ -584,7 +560,7 @@ def chat_completion_handler( if response_format is not None and response_format["type"] == "json_object": grammar = _grammar_for_response_format( - response_format, verbose=llama.verbose + response_format, verbose=llama.verbose, ) # Convert legacy functions to tools @@ -625,13 +601,13 @@ def chat_completion_handler( try: # create grammar from json schema grammar = llama_grammar.LlamaGrammar.from_json_schema( - json.dumps(schema), verbose=llama.verbose + json.dumps(schema), verbose=llama.verbose, ) except Exception as e: if llama.verbose: print(str(e), file=sys.stderr) grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, verbose=llama.verbose + llama_grammar.JSON_GBNF, verbose=llama.verbose, ) completion_or_chunks = llama.create_completion( @@ -662,7 +638,7 @@ def chat_completion_handler( if tool is not None: tool_name = tool["function"]["name"] return _convert_completion_to_chat_function( - tool_name, completion_or_chunks, stream + tool_name, completion_or_chunks, stream, ) return _convert_completion_to_chat(completion_or_chunks, stream=stream) @@ -670,7 +646,7 @@ def chat_completion_handler( def hf_autotokenizer_to_chat_formatter( - pretrained_model_name_or_path: Union[str, os.PathLike[str]] + pretrained_model_name_or_path: str | os.PathLike[str], ) -> ChatFormatter: # https://huggingface.co/docs/transformers/main/chat_templating # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format @@ -680,7 +656,7 @@ def hf_autotokenizer_to_chat_formatter( tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) # type: ignore def format_autotokenizer( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: tokenizer.use_default_system_prompt = False # type: ignore @@ -688,21 +664,21 @@ def format_autotokenizer( assert isinstance(prompt, str) # Return formatted prompt and eos token by default return ChatFormatterResponse( - prompt=prompt, stop=tokenizer.eos_token, added_special=True + prompt=prompt, stop=tokenizer.eos_token, added_special=True, ) return format_autotokenizer def hf_autotokenizer_to_chat_completion_handler( - pretrained_model_name_or_path: Union[str, os.PathLike[str]] + pretrained_model_name_or_path: str | os.PathLike[str], ) -> LlamaChatCompletionHandler: chat_formatter = hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path) return chat_formatter_to_chat_completion_handler(chat_formatter) def hf_tokenizer_config_to_chat_formatter( - tokenizer_config: Dict[str, Any], + tokenizer_config: dict[str, Any], add_generation_prompt: bool = True, ) -> ChatFormatter: assert 
isinstance(tokenizer_config, dict) @@ -725,7 +701,7 @@ def hf_tokenizer_config_to_chat_formatter( ).from_string(chat_template) def format_tokenizer_config( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: # TODO: veryify this is correct @@ -734,7 +710,7 @@ def format_tokenizer_config( messages = [ *messages, llama_types.ChatCompletionRequestAssistantMessage( - role="assistant", content="" + role="assistant", content="", ), ] prompt = env.render( @@ -743,23 +719,23 @@ def format_tokenizer_config( eos_token=eos_token, ) return ChatFormatterResponse( - prompt=prompt, stop=[eos_token, bos_token], added_special=True + prompt=prompt, stop=[eos_token, bos_token], added_special=True, ) return format_tokenizer_config def hf_tokenizer_config_to_chat_completion_handler( - tokenizer_config: Dict[str, Any], + tokenizer_config: dict[str, Any], add_generation_prompt: bool = True, ) -> LlamaChatCompletionHandler: chat_formatter = hf_tokenizer_config_to_chat_formatter( - tokenizer_config, add_generation_prompt=add_generation_prompt + tokenizer_config, add_generation_prompt=add_generation_prompt, ) return chat_formatter_to_chat_completion_handler(chat_formatter) -def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]: +def guess_chat_format_from_gguf_metadata(metadata: dict[str, str]) -> str | None: if "tokenizer.chat_template" not in metadata: return None @@ -783,7 +759,7 @@ def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[s def _get_system_message( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], ) -> str: """Get the first system message.""" for message in messages: @@ -793,11 +769,11 @@ def _get_system_message( def _map_roles( - messages: List[llama_types.ChatCompletionRequestMessage], - role_map: Dict[str, str], -) -> List[Tuple[str, Optional[str]]]: + messages: list[llama_types.ChatCompletionRequestMessage], + role_map: dict[str, str], +) -> list[tuple[str, str | None]]: """Map the message roles.""" - output: List[Tuple[str, Optional[str]]] = [] + output: list[tuple[str, str | None]] = [] for message in messages: role = message["role"] if role in role_map: @@ -809,7 +785,7 @@ def _map_roles( def _format_llama2( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str + system_message: str, messages: list[tuple[str, str | None]], sep: str, sep2: str, ) -> str: """Format the prompt with the llama2 style.""" seps = [sep, sep2] @@ -826,7 +802,7 @@ def _format_llama2( def _format_add_colon_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str + system_message: str, messages: list[tuple[str, str | None]], sep: str, ) -> str: """Format the prompt with the add-colon-single style.""" ret = system_message + sep @@ -839,7 +815,7 @@ def _format_add_colon_single( def _format_add_colon_two( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str + system_message: str, messages: list[tuple[str, str | None]], sep: str, sep2: str, ) -> str: """Format the prompt with the add-colon-two style.""" seps = [sep, sep2] @@ -853,7 +829,7 @@ def _format_add_colon_two( def _format_no_colon_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str + system_message: str, messages: list[tuple[str, str | None]], sep: str, ) -> str: """Format the prompt with the no-colon-single 
style.""" ret = system_message @@ -866,7 +842,7 @@ def _format_no_colon_single( def _format_add_colon_space_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str + system_message: str, messages: list[tuple[str, str | None]], sep: str, ) -> str: """Format the prompt with the add-colon-space-single style.""" ret = system_message + sep @@ -879,7 +855,7 @@ def _format_add_colon_space_single( def _format_chatml( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str + system_message: str, messages: list[tuple[str, str | None]], sep: str, ) -> str: """Format the prompt with the chatml style.""" ret = "" if system_message == "" else system_message + sep + "\n" @@ -892,7 +868,7 @@ def _format_chatml( def _format_chatglm3( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str + system_message: str, messages: list[tuple[str, str | None]], sep: str, ) -> str: """Format the prompt with the chatglm3 style.""" ret = "" @@ -908,20 +884,19 @@ def _format_chatglm3( def _grammar_for_json(verbose: bool = False): return llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, verbose=verbose + llama_grammar.JSON_GBNF, verbose=verbose, ) def _grammar_for_json_schema( - schema: str, verbose: bool = False, fallback_to_json: bool = True + schema: str, verbose: bool = False, fallback_to_json: bool = True, ): try: return llama_grammar.LlamaGrammar.from_json_schema(schema, verbose=verbose) except Exception as e: if fallback_to_json: return _grammar_for_json(verbose=verbose) - else: - raise e + raise e def _grammar_for_response_format( @@ -933,10 +908,9 @@ def _grammar_for_response_format( if "schema" in response_format: return _grammar_for_json_schema( - json.dumps(response_format["schema"]), verbose=verbose + json.dumps(response_format["schema"]), verbose=verbose, ) - else: - return _grammar_for_json(verbose=verbose) + return _grammar_for_json(verbose=verbose) ### Chat Formats ### @@ -946,7 +920,7 @@ def register_chat_format(name: str): def decorator(f: ChatFormatter): chat_completion_handler = chat_formatter_to_chat_completion_handler(f) LlamaChatCompletionHandlerRegistry().register_chat_completion_handler( - name, chat_completion_handler + name, chat_completion_handler, ) return f @@ -957,7 +931,7 @@ def decorator(f: ChatFormatter): # system prompt is "embedded" in the first message @register_chat_format("llama-2") def format_llama2( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _system_template = "[INST] <>\n{system_message}\n<>" @@ -974,7 +948,7 @@ def format_llama2( # https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202-L229 @register_chat_format("llama-3") def format_llama3( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _roles = dict( @@ -991,7 +965,7 @@ def format_llama3( @register_chat_format("alpaca") def format_alpaca( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _roles = dict(user="### Instruction", assistant="### Response") @@ -1005,7 +979,7 @@ def format_alpaca( @register_chat_format("qwen") def format_qwen( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> 
ChatFormatterResponse: _roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant") @@ -1022,7 +996,7 @@ def format_qwen( @register_chat_format("vicuna") def format( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." @@ -1038,7 +1012,7 @@ def format( @register_chat_format("oasst_llama") def format_oasst_llama( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n" @@ -1054,7 +1028,7 @@ def format_oasst_llama( @register_chat_format("baichuan-2") def format_baichuan2( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _system_template = "{system_message}" @@ -1070,7 +1044,7 @@ def format_baichuan2( @register_chat_format("baichuan") def format_baichuan( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _system_template = "{system_message}" @@ -1086,7 +1060,7 @@ def format_baichuan( @register_chat_format("openbuddy") def format_openbuddy( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _system_message = """You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.
@@ -1108,7 +1082,7 @@ def format_openbuddy( @register_chat_format("redpajama-incite") def format_redpajama_incite( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _system_message = _get_system_message(messages) @@ -1124,7 +1098,7 @@ def format_redpajama_incite( @register_chat_format("snoozy") def format_snoozy( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: system_template = "### Instruction:\n{system_message}" @@ -1146,7 +1120,7 @@ def format_snoozy( @register_chat_format("phind") def format_phind( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _roles = dict(user="### User Message", assistant="### Assistant") @@ -1160,7 +1134,7 @@ def format_phind( @register_chat_format("intel") def format_intel( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _roles = dict(user="### User:", assistant="### Assistant:") @@ -1174,7 +1148,7 @@ def format_intel( @register_chat_format("open-orca") def format_open_orca( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: system_template = "{system_message}" @@ -1193,7 +1167,7 @@ def format_open_orca( # stop_token_ids=[32000, 32001], # "<|end_of_turn|>" stop_str = "User" system_message = system_template.format(system_message=system_message) - _messages = _map_roles(messages, dict(zip(roles, roles))) + _messages = _map_roles(messages, dict(zip(roles, roles, strict=False))) _messages.append((roles[1], None)) _prompt = _format_add_colon_space_single(system_message, _messages, sep) return ChatFormatterResponse(prompt=_prompt, stop=stop_str) @@ -1201,7 +1175,7 @@ def format_open_orca( @register_chat_format("mistrallite") def format_mistrallite( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: _roles = dict(user="<|prompter|>", assistant="\n<|assistant|>") @@ -1217,7 +1191,7 @@ def format_mistrallite( @register_chat_format("zephyr") def format_zephyr( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: system_template = """<|system|> @@ -1234,7 +1208,7 @@ def format_zephyr( @register_chat_format("pygmalion") def format_pygmalion( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: system_template = """<|system|>{system_message}""" @@ -1250,7 +1224,7 @@ def format_pygmalion( @register_chat_format("chatml") def format_chatml( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: system_template = """<|im_start|>system @@ -1267,7 +1241,7 @@ def format_chatml( @register_chat_format("mistral-instruct") def format_mistral_instruct( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> 
ChatFormatterResponse: eos = "</s>" @@ -1288,7 +1262,7 @@ @register_chat_format("chatglm3") def format_chatglm3( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: system_template = """<|system|> @@ -1305,14 +1279,14 @@ def format_chatglm3( @register_chat_format("openchat") def format_openchat( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: system_template = "{system_message}<|end_of_turn|>" system_message = _get_system_message(messages) system_message = system_template.format(system_message=system_message) _roles = dict( - user="GPT4 Correct User: ", assistant="<|end_of_turn|>GPT4 Correct Assistant: " + user="GPT4 Correct User: ", assistant="<|end_of_turn|>GPT4 Correct Assistant: ", ) _sep = "<|end_of_turn|>" _messages = _map_roles(messages, _roles) @@ -1347,13 +1321,13 @@ def format_saiga( # https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b @register_chat_format("gemma") def format_gemma( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: system_message = _get_system_message(messages) if system_message != "": logger.debug( - "`role='system'` messages are not allowed on Google's Gemma models." + "`role='system'` messages are not allowed on Google's Gemma models.", ) _roles = dict(user="<start_of_turn>user\n", assistant="<start_of_turn>model\n") _sep = "<end_of_turn>\n" @@ -1369,20 +1343,20 @@ def format_gemma( @register_chat_completion_handler("functionary") def functionary_chat_handler( llama: llama.Llama, - messages: List[llama_types.ChatCompletionRequestMessage], - functions: Optional[List[llama_types.ChatCompletionFunction]] = None, - function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, - tools: Optional[List[llama_types.ChatCompletionTool]] = None, - tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + messages: list[llama_types.ChatCompletionRequestMessage], + functions: list[llama_types.ChatCompletionFunction] | None = None, + function_call: llama_types.ChatCompletionRequestFunctionCall | None = None, + tools: list[llama_types.ChatCompletionTool] | None = None, + tool_choice: llama_types.ChatCompletionToolChoiceOption | None = None, temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, min_p: float = 0.05, typical_p: float = 1.0, stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, - max_tokens: Optional[int] = None, + stop: str | list[str] | None = [], + response_format: llama_types.ChatCompletionRequestResponseFormat | None = None, + max_tokens: int | None = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, @@ -1390,15 +1364,15 @@ def functionary_chat_handler( mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, - model: Optional[str] = None, - logits_processor: Optional[llama.LogitsProcessorList] = None, - grammar: Optional[llama.LlamaGrammar] = None, + model: str | None = None, + logits_processor: llama.LogitsProcessorList | None = None, + grammar: llama.LlamaGrammar | None = None, **kwargs, # type: ignore -) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: +) ->
llama_types.ChatCompletion | Iterator[llama_types.ChatCompletionChunk]: SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary""" def generate_type_definition( - param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs + param: dict[str, llama_types.JsonType], indent_level: int, shared_defs, ) -> str: indent = " " * indent_level if "$ref" in param: @@ -1407,28 +1381,27 @@ def generate_type_definition( -1 ] # Extract the type name from the reference return ref_name - elif param.get("type") == "array": + if param.get("type") == "array": items = param.get("items", {}) item_type = generate_type_definition(items, indent_level + 1, shared_defs) return f"Array<{item_type}>" - elif param.get("type") == "object": + if param.get("type") == "object": properties = param.get("properties", {}) nested_schema = "{\n" for nested_param_name, nested_param in properties.items(): nested_param_type = generate_type_definition( - nested_param, indent_level + 1, shared_defs + nested_param, indent_level + 1, shared_defs, ) nested_schema += ( f"{indent} {nested_param_name}: {nested_param_type},\n" ) nested_schema += indent + "}" return nested_schema - elif "enum" in param: + if "enum" in param: # Enum type return " | ".join([f'"{enum_value}"' for enum_value in param["enum"]]) - else: - # Simple type - return param.get("type", "any") + # Simple type + return param.get("type", "any") def generate_shared_definitions(shared_defs, indent_level: int) -> str: indent = " " * indent_level @@ -1437,12 +1410,12 @@ def generate_shared_definitions(shared_defs, indent_level: int) -> str: shared_definitions += f"{indent}type {def_name} = " if def_properties.get("type") == "object": shared_definitions += generate_type_definition( - def_properties, indent_level, shared_defs + def_properties, indent_level, shared_defs, ) elif "enum" in def_properties: # Enum type shared_definitions += " | ".join( - [f'"{enum_value}"' for enum_value in def_properties["enum"]] + [f'"{enum_value}"' for enum_value in def_properties["enum"]], ) shared_definitions += ";\n" return shared_definitions @@ -1478,20 +1451,20 @@ def generate_schema_from_functions(functions, namespace="functions") -> str: schema += f" {param_name}{optional_indicator}: {param_type},\n" schema += " }) => any;\n\n" - schema += "}} // namespace {}\n".format(namespace) + schema += f"}} // namespace {namespace}\n" return schema def prepare_messages_for_inference( - messages: List[llama_types.ChatCompletionRequestMessage], - functions: Optional[List[llama_types.ChatCompletionFunctions]] = None, - tools: Optional[List[llama_types.ChatCompletionTool]] = None, + messages: list[llama_types.ChatCompletionRequestMessage], + functions: list[llama_types.ChatCompletionFunctions] | None = None, + tools: list[llama_types.ChatCompletionTool] | None = None, ): - all_messages: List[llama_types.ChatCompletionRequestMessage] = [] + all_messages: list[llama_types.ChatCompletionRequestMessage] = [] if functions is not None: all_messages.append( llama_types.ChatCompletionRequestSystemMessage( - role="system", content=generate_schema_from_functions(functions) - ) + role="system", content=generate_schema_from_functions(functions), + ), ) if tools is not None: @@ -1503,15 +1476,15 @@ def prepare_messages_for_inference( tool["function"] for tool in tools if tool["type"] == "function" - ] + ], ), - ) + ), ) 
all_messages.append( llama_types.ChatCompletionRequestSystemMessage( - role="system", content=SYSTEM_MESSAGE - ) + role="system", content=SYSTEM_MESSAGE, + ), ) for message in messages: @@ -1527,34 +1500,32 @@ def prepare_messages_for_inference( all_messages.append( llama_types.ChatCompletionRequestAssistantMessage( - role="assistant", content=None - ) + role="assistant", content=None, + ), ) def message_to_str(msg: llama_types.ChatCompletionRequestMessage): if msg["role"] == "system": return f"system:\n{msg['content']}\n" - elif msg["role"] == "function" and "name" in msg: + if msg["role"] == "function" and "name" in msg: return f"function name={msg['name']}:\n{msg['content']}\n" - elif msg["role"] == "function" and "function_call" in msg: + if msg["role"] == "function" and "function_call" in msg: return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n" - elif msg["role"] == "tool": + if msg["role"] == "tool": if msg["content"] is not None: return f"function name={msg['tool_call_id']}:\n{msg['content']}\n" - else: - return f"function name={msg['tool_call_id']}\n" - elif msg["role"] == "user": + return f"function name={msg['tool_call_id']}\n" + if msg["role"] == "user": if msg["content"] is None: return "user:\n\n" - else: - return f"user:\n{msg['content']}\n" - elif msg["role"] == "assistant": + return f"user:\n{msg['content']}\n" + if msg["role"] == "assistant": if msg["content"] is not None and "function_call" in msg: return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n" - elif "function_call" in msg: + if "function_call" in msg: return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n" - elif "tool_calls" in msg and len(msg["tool_calls"]) > 0: + if "tool_calls" in msg and len(msg["tool_calls"]) > 0: for tool_call in msg[ "tool_calls" ]: # NOTE: probably doesn't work with the functionary model @@ -1607,7 +1578,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage): ): stop = "\n" completion: llama_types.Completion = llama.create_completion( - prompt=prompt, stop=stop, stream=False + prompt=prompt, stop=stop, stream=False, ) # type: ignore completion_text = completion["choices"][0]["text"] # strip " to=functions." 
and ending ":" @@ -1635,7 +1606,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage): try: with suppress_stdout_stderr(disable=llama.verbose): grammar_text = llama_grammar.json_schema_to_gbnf( - json.dumps(function_body) + json.dumps(function_body), ) grammar = llama_grammar.LlamaGrammar.from_string( llama_grammar.json_schema_to_gbnf(json.dumps(function_body)), @@ -1645,7 +1616,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage): except Exception as e: if llama.verbose: print( - "Failed to parse function body as JSON schema, falling back to default grammar" + "Failed to parse function body as JSON schema, falling back to default grammar", ) print(e) with suppress_stdout_stderr(disable=llama.verbose): @@ -1656,7 +1627,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage): else: with suppress_stdout_stderr(disable=llama.verbose): grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, verbose=llama.verbose + llama_grammar.JSON_GBNF, verbose=llama.verbose, ) completion: llama_types.Completion = llama.create_completion( @@ -1713,12 +1684,12 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage): "name": function_call, "arguments": completion["choices"][0]["text"], }, - } + }, ], }, "logprobs": completion["choices"][0]["logprobs"], "finish_reason": "tool_calls", - } + }, ], usage=completion["usage"], ) @@ -1728,20 +1699,20 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage): @register_chat_completion_handler("functionary-v2") def functionary_v1_v2_chat_handler( llama: llama.Llama, - messages: List[llama_types.ChatCompletionRequestMessage], - functions: Optional[List[llama_types.ChatCompletionFunction]] = None, - function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, - tools: Optional[List[llama_types.ChatCompletionTool]] = None, - tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + messages: list[llama_types.ChatCompletionRequestMessage], + functions: list[llama_types.ChatCompletionFunction] | None = None, + function_call: llama_types.ChatCompletionRequestFunctionCall | None = None, + tools: list[llama_types.ChatCompletionTool] | None = None, + tool_choice: llama_types.ChatCompletionToolChoiceOption | None = None, temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, min_p: float = 0.05, typical_p: float = 1.0, stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, - max_tokens: Optional[int] = None, + stop: str | list[str] | None = [], + response_format: llama_types.ChatCompletionRequestResponseFormat | None = None, + max_tokens: int | None = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, @@ -1749,16 +1720,16 @@ def functionary_v1_v2_chat_handler( mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, - model: Optional[str] = None, - logits_processor: Optional[llama.LogitsProcessorList] = None, - grammar: Optional[llama.LlamaGrammar] = None, + model: str | None = None, + logits_processor: llama.LogitsProcessorList | None = None, + grammar: llama.LlamaGrammar | None = None, **kwargs, # type: ignore -) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: +) -> llama_types.ChatCompletion | Iterator[llama_types.ChatCompletionChunk]: SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary""" tokenizer = llama.tokenizer_ assert hasattr( - tokenizer, "hf_tokenizer" + tokenizer, "hf_tokenizer", ), "Please provide a valid hf_tokenizer_path from https://huggingface.co/meetkai when initializing the Llama class" from transformers import AutoTokenizer @@ -1778,7 +1749,7 @@ def functionary_v1_v2_chat_handler( CONTENT_TOKEN = "<|content|>" def generate_type_definition( - param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs + param: dict[str, llama_types.JsonType], indent_level: int, shared_defs, ) -> str: indent = " " * indent_level if "$ref" in param: @@ -1787,28 +1758,27 @@ def generate_type_definition( -1 ] # Extract the type name from the reference return ref_name - elif param.get("type") == "array": + if param.get("type") == "array": items = param.get("items", {}) item_type = generate_type_definition(items, indent_level + 1, shared_defs) return f"Array<{item_type}>" - elif param.get("type") == "object": + if param.get("type") == "object": properties = param.get("properties", {}) nested_schema = "{\n" for nested_param_name, nested_param in properties.items(): nested_param_type = generate_type_definition( - nested_param, indent_level + 1, shared_defs + nested_param, indent_level + 1, shared_defs, ) nested_schema += ( f"{indent} {nested_param_name}: {nested_param_type},\n" ) nested_schema += indent + "}" return nested_schema - elif "enum" in param: + if "enum" in param: # Enum type return " | ".join([f'"{enum_value}"' for enum_value in param["enum"]]) - else: - # Simple type - return param.get("type", "any") + # Simple type + return param.get("type", "any") def generate_shared_definitions(shared_defs, indent_level: int) -> str: indent = " " * indent_level @@ -1817,12 +1787,12 @@ def generate_shared_definitions(shared_defs, indent_level: int) -> str: shared_definitions += f"{indent}type {def_name} = " if def_properties.get("type") == "object": shared_definitions += generate_type_definition( - def_properties, indent_level, shared_defs + def_properties, indent_level, shared_defs, ) elif "enum" in def_properties: # Enum type shared_definitions += " | ".join( - [f'"{enum_value}"' for enum_value in def_properties["enum"]] + [f'"{enum_value}"' for enum_value in def_properties["enum"]], ) shared_definitions += ";\n" return shared_definitions @@ -1858,49 +1828,48 @@ def generate_schema_from_functions(functions, namespace="functions") -> str: schema += f"{param_name}{optional_indicator}: {param_type},\n" schema += "}) => any;\n\n" - schema += "}} // namespace {}".format(namespace) + schema += f"}} // namespace {namespace}" return schema def prepare_messages_for_inference( - messages: List[llama_types.ChatCompletionRequestMessage], + messages: list[llama_types.ChatCompletionRequestMessage], tokenizer: AutoTokenizer, version: Literal["v1", "v2"], - functions: Optional[List[llama_types.ChatCompletionFunctions]] = None, - tools: Optional[List[llama_types.ChatCompletionTool]] = None, - tool_choice: Union[Dict, str] = "auto", + functions: list[llama_types.ChatCompletionFunctions] | None = None, + tools: list[llama_types.ChatCompletionTool] | None = None, + tool_choice: dict | str = "auto", ): - all_messages: List[llama_types.ChatCompletionRequestMessage] = [] + all_messages: list[llama_types.ChatCompletionRequestMessage] = [] if tool_choice == "none": all_messages.append( llama_types.ChatCompletionRequestSystemMessage( - 
role="system", content=generate_schema_from_functions([]) - ) + role="system", content=generate_schema_from_functions([]), + ), + ) + elif functions is not None: + all_messages.append( + llama_types.ChatCompletionRequestSystemMessage( + role="system", content=generate_schema_from_functions(functions), + ), + ) + elif tools is not None and tool_choice != "none": + all_messages.append( + llama_types.ChatCompletionRequestSystemMessage( + role="system", + content=generate_schema_from_functions( + [ + tool["function"] + for tool in tools + if tool["type"] == "function" + ], + ), + ), ) - else: - if functions is not None: - all_messages.append( - llama_types.ChatCompletionRequestSystemMessage( - role="system", content=generate_schema_from_functions(functions) - ) - ) - elif tools is not None and tool_choice != "none": - all_messages.append( - llama_types.ChatCompletionRequestSystemMessage( - role="system", - content=generate_schema_from_functions( - [ - tool["function"] - for tool in tools - if tool["type"] == "function" - ] - ), - ) - ) all_messages.append( llama_types.ChatCompletionRequestSystemMessage( - role="system", content=SYSTEM_MESSAGE - ) + role="system", content=SYSTEM_MESSAGE, + ), ) for message in messages: @@ -1937,7 +1906,7 @@ def prepare_messages_for_inference( function_call = "auto" prompt = prepare_messages_for_inference( - messages, tokenizer, version, functions, tools, function_call + messages, tokenizer, version, functions, tools, function_call, ) # If no tools/functions are provided @@ -1989,21 +1958,21 @@ def get_grammar(function_call): try: with suppress_stdout_stderr(disable=llama.verbose): grammar_text = llama_grammar.json_schema_to_gbnf( - json.dumps(function_body) + json.dumps(function_body), ) grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.json_schema_to_gbnf(json.dumps(function_body)) + llama_grammar.json_schema_to_gbnf(json.dumps(function_body)), ) print(grammar_text) except Exception as e: if llama.verbose: print( - "Failed to parse function body as JSON schema, falling back to default grammar" + "Failed to parse function body as JSON schema, falling back to default grammar", ) print(e) with suppress_stdout_stderr(disable=llama.verbose): grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, verbose=llama.verbose + llama_grammar.JSON_GBNF, verbose=llama.verbose, ) return grammar @@ -2051,7 +2020,7 @@ def generate_streaming(tools, functions, function_call, prompt): grammar = get_grammar(function_call["name"]) stops = [STOP_TOKEN, FROM_TOKEN] tool_id = "".join( - [random.choice(string.ascii_letters + string.digits) for _ in range(24)] + [random.choice(string.ascii_letters + string.digits) for _ in range(24)], ) completion = create_completion(prompt=prompt, stop=stops, grammar=grammar) completion_text = "" @@ -2070,15 +2039,15 @@ def generate_streaming(tools, functions, function_call, prompt): "name": function_call["name"], "arguments": "", }, - } - ] + }, + ], } else: func_call_dict = { "function_call": { "name": function_call["name"], "arguments": "", - } + }, } yield llama_types.CreateChatCompletionStreamResponse( id="chat" + chunk["id"], @@ -2094,7 +2063,7 @@ def generate_streaming(tools, functions, function_call, prompt): "content": None, **func_call_dict, }, - } + }, ], ) first = False @@ -2109,15 +2078,15 @@ def generate_streaming(tools, functions, function_call, prompt): "name": None, "arguments": chunk["choices"][0]["text"].rstrip(), }, - } - ] + }, + ], } else: func_call_dict = { "function_call": { "name": None, 
"arguments": chunk["choices"][0]["text"].rstrip(), - } + }, } if len(chunk["choices"][0]["text"].rstrip()) > 0: yield llama_types.CreateChatCompletionStreamResponse( @@ -2134,7 +2103,7 @@ def generate_streaming(tools, functions, function_call, prompt): "content": None, **func_call_dict, }, - } + }, ], ) # Yield tool_call/function_call stop message @@ -2156,7 +2125,7 @@ def generate_streaming(tools, functions, function_call, prompt): "function_call": None, "tool_calls": None, }, - } + }, ], ) # If "auto" or no tool_choice/function_call @@ -2167,7 +2136,7 @@ def generate_streaming(tools, functions, function_call, prompt): grammar = None stops = CONTENT_TOKEN completion = create_completion( - prompt=prompt, stop=stops, grammar=grammar + prompt=prompt, stop=stops, grammar=grammar, ) completion_text = "" for chunk in completion: @@ -2191,7 +2160,7 @@ def generate_streaming(tools, functions, function_call, prompt): "delta": {"role": "assistant", "content": ""}, "logprobs": None, "finish_reason": None, - } + }, ], ) else: @@ -2201,7 +2170,7 @@ def generate_streaming(tools, functions, function_call, prompt): [ random.choice(string.ascii_letters + string.digits) for _ in range(24) - ] + ], ) if tools is not None: func_call_dict = { @@ -2214,12 +2183,12 @@ def generate_streaming(tools, functions, function_call, prompt): "name": function_name, "arguments": "", }, - } - ] + }, + ], } else: func_call_dict = { - "function_call": {"name": function_name, "arguments": ""} + "function_call": {"name": function_name, "arguments": ""}, } # Stream function name yield llama_types.CreateChatCompletionStreamResponse( @@ -2236,13 +2205,13 @@ def generate_streaming(tools, functions, function_call, prompt): "content": None, **func_call_dict, }, - } + }, ], ) # Generate content stops = [RECIPIENT_TOKEN, STOP_TOKEN] completion = create_completion( - prompt=prompt, stop=stops, grammar=grammar + prompt=prompt, stop=stops, grammar=grammar, ) if function_name == "all": completion_text = "" @@ -2275,7 +2244,7 @@ def generate_streaming(tools, functions, function_call, prompt): "role": "assistant", "content": buffer.pop(0), }, - } + }, ], ) is_end = False @@ -2304,7 +2273,7 @@ def generate_streaming(tools, functions, function_call, prompt): ].lstrip() ), }, - } + }, ], ) # Check whether the model wants to generate another turn @@ -2336,7 +2305,7 @@ def generate_streaming(tools, functions, function_call, prompt): "delta": {}, "logprobs": None, "finish_reason": "stop", - } + }, ], ) break @@ -2359,8 +2328,8 @@ def generate_streaming(tools, functions, function_call, prompt): "text" ].rstrip(), }, - } - ] + }, + ], } else: func_call_dict = { @@ -2369,7 +2338,7 @@ def generate_streaming(tools, functions, function_call, prompt): "arguments": chunk["choices"][0][ "text" ].rstrip(), - } + }, } yield llama_types.CreateChatCompletionStreamResponse( id="chat" + chunk_id, @@ -2385,16 +2354,16 @@ def generate_streaming(tools, functions, function_call, prompt): "content": None, **func_call_dict, }, - } + }, ], ) prompt += completion_text.strip() grammar = None completion = create_completion( - prompt=prompt, stop=stops, grammar=grammar + prompt=prompt, stop=stops, grammar=grammar, ) completion_text += "".join( - [chunk["choices"][0]["text"] for chunk in completion] + [chunk["choices"][0]["text"] for chunk in completion], ) if ( "<|from|> assistant" in completion_text @@ -2424,210 +2393,201 @@ def generate_streaming(tools, functions, function_call, prompt): "function_call": None, "tool_calls": None, }, - } + }, ], ) break if stream is not 
False: return generate_streaming( - tools=tools, functions=functions, function_call=function_call, prompt=prompt + tools=tools, functions=functions, function_call=function_call, prompt=prompt, ) - else: - if version == "v1": - # If no or "auto" tool_choice/function_call - if isinstance(function_call, str) and function_call == "auto": - stops = ["\n", END_ASSISTANT_TOKEN] - # If tool_choice/function_call is provided - elif isinstance(function_call, dict): - prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n" - stops = END_FUNCTION_CALL_TOKEN - function_call = function_call["name"] - function_calls.append(function_call) - grammar = get_grammar(function_call) - else: - prompt = prompt - stops = ["\n", END_ASSISTANT_TOKEN] + if version == "v1": + # If no or "auto" tool_choice/function_call + if isinstance(function_call, str) and function_call == "auto": + stops = ["\n", END_ASSISTANT_TOKEN] + # If tool_choice/function_call is provided + elif isinstance(function_call, dict): + prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n" + stops = END_FUNCTION_CALL_TOKEN + function_call = function_call["name"] + function_calls.append(function_call) + grammar = get_grammar(function_call) + else: + prompt = prompt + stops = ["\n", END_ASSISTANT_TOKEN] - completion = create_completion(prompt=prompt, stop=stops, grammar=grammar) - completion_text = completion["choices"][0]["text"] - completion_tokens += completion["usage"]["completion_tokens"] + completion = create_completion(prompt=prompt, stop=stops, grammar=grammar) + completion_text = completion["choices"][0]["text"] + completion_tokens += completion["usage"]["completion_tokens"] - # If the generation does not involve a function call - if ( - START_FUNCTION_CALL_TOKEN not in prompt - and START_FUNCTION_CALL_TOKEN not in completion_text - ): - completion["usage"]["completion_tokens"] = completion_tokens - return _convert_completion_to_chat(completion, stream=stream) # type: ignore - # If the generation involves a function call in completion, generate the parameters - elif ( - START_FUNCTION_CALL_TOKEN not in prompt - and START_FUNCTION_CALL_TOKEN in completion_text - ): - prompt += ( - completion_text.replace( - f"{START_FUNCTION_CALL_TOKEN} ", START_FUNCTION_CALL_TOKEN - ) - + "\n" - ) - function_calls.append( - completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip() - ) - grammar = get_grammar(function_calls[-1]) - completion = create_completion( - prompt=prompt, stop=END_FUNCTION_CALL_TOKEN, grammar=grammar + # If the generation does not involve a function call + if ( + START_FUNCTION_CALL_TOKEN not in prompt + and START_FUNCTION_CALL_TOKEN not in completion_text + ): + completion["usage"]["completion_tokens"] = completion_tokens + return _convert_completion_to_chat(completion, stream=stream) # type: ignore + # If the generation involves a function call in completion, generate the parameters + if ( + START_FUNCTION_CALL_TOKEN not in prompt + and START_FUNCTION_CALL_TOKEN in completion_text + ): + prompt += ( + completion_text.replace( + f"{START_FUNCTION_CALL_TOKEN} ", START_FUNCTION_CALL_TOKEN, ) - completion_tokens += completion["usage"]["completion_tokens"] - function_bodies.append(completion["choices"][0]["text"].strip()) - # If the prompt involves a function call, just append generated parameters to function_bodies - else: - function_bodies.append(completion_text.strip()) + + "\n" + ) + function_calls.append( + completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip(), + ) + grammar = 
get_grammar(function_calls[-1]) + completion = create_completion( + prompt=prompt, stop=END_FUNCTION_CALL_TOKEN, grammar=grammar, + ) + completion_tokens += completion["usage"]["completion_tokens"] + function_bodies.append(completion["choices"][0]["text"].strip()) + # If the prompt involves a function call, just append generated parameters to function_bodies else: - # If tool_choice/function_call is provided - if isinstance(function_call, dict): - prompt += f"{function_call['name']}\n{CONTENT_TOKEN}" - function_call = function_call["name"] + function_bodies.append(completion_text.strip()) + elif isinstance(function_call, dict): + prompt += f"{function_call['name']}\n{CONTENT_TOKEN}" + function_call = function_call["name"] + function_calls.append(function_call) + grammar = get_grammar(function_call) + stops = [STOP_TOKEN, FROM_TOKEN] + completion = create_completion( + prompt=prompt, stop=stops, grammar=grammar, + ) + completion_text = completion["choices"][0]["text"] + completion_tokens += completion["usage"]["completion_tokens"] + function_bodies.append(completion_text.strip()) + # If "auto" or no tool_choice/function_call + elif isinstance(function_call, str) and function_call == "auto": + while True: + # Generate function name first + grammar = None + stops = CONTENT_TOKEN + completion = create_completion( + prompt=prompt, stop=stops, grammar=grammar, + ) + completion_text = completion["choices"][0]["text"] + completion_tokens += completion["usage"]["completion_tokens"] + function_name = completion_text.strip() + if function_name == "all": + prompt += "all\n<|content|>" + else: + function_call = completion_text.strip() + prompt += f"{function_call}\n<|content|>" function_calls.append(function_call) grammar = get_grammar(function_call) - stops = [STOP_TOKEN, FROM_TOKEN] + # Generate content + stops = [RECIPIENT_TOKEN, STOP_TOKEN] + completion = create_completion( + prompt=prompt, stop=stops, grammar=grammar, + ) + completion_text = completion["choices"][0]["text"] + completion_tokens += completion["usage"]["completion_tokens"] + if function_name == "all": + if completion_text.endswith("\n<|from|>assistant\n"): + content += completion_text[: -len("\n<|from|>assistant\n")] + if completion_text.endswith("\n<|from|> assistant\n"): + content += completion_text[-len("\n<|from|> assistant\n")] + else: + content += completion_text + content = content.lstrip() + # Check whether the model wants to generate another turn + if ( + "<|from|> assistant" in completion_text + or "<|from|>assistant" in completion_text + ): + if completion_text.endswith("\n<|from|>assistant\n"): + cleaned_completion_text = completion_text[ + : -len("\n<|from|>assistant\n") + ].strip() + elif completion_text.endswith("\n<|from|> assistant\n"): + cleaned_completion_text = completion_text[ + -len("\n<|from|> assistant\n") + ].strip() + else: + cleaned_completion_text = completion_text.strip() + prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>" + else: + break + else: + function_bodies.append(completion_text.strip()) + # Check whether the model wants to generate another turn + prompt += completion_text.strip() + grammar = None completion = create_completion( - prompt=prompt, stop=stops, grammar=grammar + prompt=prompt, stop=stops, grammar=grammar, ) - completion_text = completion["choices"][0]["text"] completion_tokens += completion["usage"]["completion_tokens"] - function_bodies.append(completion_text.strip()) - # If "auto" or no tool_choice/function_call - elif isinstance(function_call, str) and 
function_call == "auto": - while True: - # Generate function name first - grammar = None - stops = CONTENT_TOKEN - completion = create_completion( - prompt=prompt, stop=stops, grammar=grammar - ) - completion_text = completion["choices"][0]["text"] - completion_tokens += completion["usage"]["completion_tokens"] - function_name = completion_text.strip() - if function_name == "all": - prompt += "all\n<|content|>" - else: - function_call = completion_text.strip() - prompt += f"{function_call}\n<|content|>" - function_calls.append(function_call) - grammar = get_grammar(function_call) - # Generate content - stops = [RECIPIENT_TOKEN, STOP_TOKEN] - completion = create_completion( - prompt=prompt, stop=stops, grammar=grammar - ) - completion_text = completion["choices"][0]["text"] - completion_tokens += completion["usage"]["completion_tokens"] - if function_name == "all": - if completion_text.endswith("\n<|from|>assistant\n"): - content += completion_text[: -len("\n<|from|>assistant\n")] - if completion_text.endswith("\n<|from|> assistant\n"): - content += completion_text[-len("\n<|from|> assistant\n")] - else: - content += completion_text - content = content.lstrip() - # Check whether the model wants to generate another turn - if ( - "<|from|> assistant" in completion_text - or "<|from|>assistant" in completion_text - ): - if completion_text.endswith("\n<|from|>assistant\n"): - cleaned_completion_text = completion_text[ - : -len("\n<|from|>assistant\n") - ].strip() - elif completion_text.endswith("\n<|from|> assistant\n"): - cleaned_completion_text = completion_text[ - -len("\n<|from|> assistant\n") - ].strip() - else: - cleaned_completion_text = completion_text.strip() - prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>" - else: - break - else: - function_bodies.append(completion_text.strip()) - # Check whether the model wants to generate another turn - prompt += completion_text.strip() - grammar = None - completion = create_completion( - prompt=prompt, stop=stops, grammar=grammar - ) - completion_tokens += completion["usage"]["completion_tokens"] - if ( - "<|from|> assistant" in completion["choices"][0]["text"] - or "<|from|>assistant" in completion["choices"][0]["text"] - ): - prompt += "\n<|from|>assistant\n<|recipient|>" - else: - break - - assert "usage" in completion - assert len(function_calls) == len(function_bodies) + if ( + "<|from|> assistant" in completion["choices"][0]["text"] + or "<|from|>assistant" in completion["choices"][0]["text"] + ): + prompt += "\n<|from|>assistant\n<|recipient|>" + else: + break - tool_calls: List[llama_types.ChatCompletionMessageToolCall] = [] - for function_call, function_body in zip(function_calls, function_bodies): - tool_calls.append( - { - "id": "call_" - + "".join( - [ - random.choice(string.ascii_letters + string.digits) - for _ in range(24) - ] - ), - "type": "function", - "function": { - "name": function_call, - "arguments": function_body, - }, - } - ) + assert "usage" in completion + assert len(function_calls) == len(function_bodies) - # TODO: support stream mode - function_call_dict: Union[ - Dict[str, str], - Dict[ - Literal["function_call"], - llama_types.ChatCompletionRequestAssistantMessageFunctionCall, - ], - ] = {} - if len(tool_calls) > 0: - if tools is not None: - function_call_dict["tool_calls"] = tool_calls - else: - function_call_dict["function_call"] = { - "name": tool_calls[0]["function"]["name"], - "arguments": tool_calls[0]["function"]["arguments"], - } - completion["usage"]["completion_tokens"] = 
completion_tokens - return llama_types.CreateChatCompletionResponse( - id="chat" + completion["id"], - object="chat.completion", - created=completion["created"], - model=completion["model"], - choices=[ - { - "index": 0, - "logprobs": completion["choices"][0]["logprobs"], - "message": { - "role": "assistant", - "content": None if content == "" else content, - **function_call_dict, - }, - "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop", - } - ], - usage=completion["usage"], + tool_calls: list[llama_types.ChatCompletionMessageToolCall] = [] + for function_call, function_body in zip(function_calls, function_bodies, strict=False): + tool_calls.append( + { + "id": "call_" + + "".join( + [ + random.choice(string.ascii_letters + string.digits) + for _ in range(24) + ], + ), + "type": "function", + "function": { + "name": function_call, + "arguments": function_body, + }, + }, ) + # TODO: support stream mode + function_call_dict: dict[str, str] | dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall] = {} + if len(tool_calls) > 0: + if tools is not None: + function_call_dict["tool_calls"] = tool_calls + else: + function_call_dict["function_call"] = { + "name": tool_calls[0]["function"]["name"], + "arguments": tool_calls[0]["function"]["arguments"], + } + completion["usage"]["completion_tokens"] = completion_tokens + return llama_types.CreateChatCompletionResponse( + id="chat" + completion["id"], + object="chat.completion", + created=completion["created"], + model=completion["model"], + choices=[ + { + "index": 0, + "logprobs": completion["choices"][0]["logprobs"], + "message": { + "role": "assistant", + "content": None if content == "" else content, + **function_call_dict, + }, + "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop", + }, + ], + usage=completion["usage"], + ) + class Llava15ChatHandler: - DEFAULT_SYSTEM_MESSAGE: Optional[str] = ( + DEFAULT_SYSTEM_MESSAGE: str | None = ( "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." 
) @@ -2667,17 +2627,15 @@ class Llava15ChatHandler: ) def __init__(self, clip_model_path: str, verbose: bool = True): - import llama_cpp.llava_cpp as llava_cpp + from llama_cpp import llava_cpp self.clip_model_path = clip_model_path self.verbose = verbose self._llava_cpp = llava_cpp # TODO: Fix self._exit_stack = ExitStack() - self._last_image_embed: Optional[ - llava_cpp.CtypesPointer[llava_cpp.llava_image_embed] - ] = None - self._last_image_hash: Optional[int] = None + self._last_image_embed: llava_cpp.CtypesPointer[llava_cpp.llava_image_embed] | None = None + self._last_image_hash: int | None = None if not os.path.exists(clip_model_path): raise ValueError(f"Clip model path does not exist: {clip_model_path}") @@ -2736,23 +2694,21 @@ def __call__( self, *, llama: llama.Llama, - messages: List[llama_types.ChatCompletionRequestMessage], - functions: Optional[List[llama_types.ChatCompletionFunction]] = None, - function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, - tools: Optional[List[llama_types.ChatCompletionTool]] = None, - tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + messages: list[llama_types.ChatCompletionRequestMessage], + functions: list[llama_types.ChatCompletionFunction] | None = None, + function_call: llama_types.ChatCompletionRequestFunctionCall | None = None, + tools: list[llama_types.ChatCompletionTool] | None = None, + tool_choice: llama_types.ChatCompletionToolChoiceOption | None = None, temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, min_p: float = 0.05, typical_p: float = 1.0, stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - seed: Optional[int] = None, - response_format: Optional[ - llama_types.ChatCompletionRequestResponseFormat - ] = None, - max_tokens: Optional[int] = None, + stop: str | list[str] | None = [], + seed: int | None = None, + response_format: llama_types.ChatCompletionRequestResponseFormat | None = None, + max_tokens: int | None = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, @@ -2760,25 +2716,22 @@ def __call__( mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, - model: Optional[str] = None, - logits_processor: Optional[llama.LogitsProcessorList] = None, - grammar: Optional[llama.LlamaGrammar] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - top_logprobs: Optional[int] = None, + model: str | None = None, + logits_processor: llama.LogitsProcessorList | None = None, + grammar: llama.LlamaGrammar | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + top_logprobs: int | None = None, **kwargs, # type: ignore - ) -> Union[ - llama_types.CreateChatCompletionResponse, - Iterator[llama_types.CreateChatCompletionStreamResponse], - ]: + ) -> llama_types.CreateChatCompletionResponse | Iterator[llama_types.CreateChatCompletionStreamResponse]: assert self.clip_ctx is not None system_prompt = _get_system_message(messages) if system_prompt == "" and self.DEFAULT_SYSTEM_MESSAGE is not None: messages = [ llama_types.ChatCompletionRequestSystemMessage( - role="system", content=self.DEFAULT_SYSTEM_MESSAGE - ) + role="system", content=self.DEFAULT_SYSTEM_MESSAGE, + ), ] + messages image_urls = self.get_image_urls(messages) @@ -2804,11 +2757,11 @@ def __call__( for type_, value in split_text: if type_ == "text": tokens = llama.tokenize( - value.encode("utf8"), add_bos=False, special=True + value.encode("utf8"), 
add_bos=False, special=True, ) if llama.n_tokens + len(tokens) > llama.n_ctx(): raise ValueError( - f"Prompt exceeds n_ctx: {llama.n_tokens + len(tokens)} > {llama.n_ctx()}" + f"Prompt exceeds n_ctx: {llama.n_tokens + len(tokens)} > {llama.n_ctx()}", ) llama.eval(tokens) else: @@ -2816,7 +2769,7 @@ def __call__( embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch) if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx(): raise ValueError( - f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}" + f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}", ) n_past = ctypes.c_int(llama.n_tokens) n_past_p = ctypes.pointer(n_past) @@ -2875,13 +2828,13 @@ def __call__( try: # create grammar from json schema grammar = llama_grammar.LlamaGrammar.from_json_schema( - json.dumps(schema), verbose=llama.verbose + json.dumps(schema), verbose=llama.verbose, ) except Exception as e: if llama.verbose: print(str(e), file=sys.stderr) grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, verbose=llama.verbose + llama_grammar.JSON_GBNF, verbose=llama.verbose, ) completion_or_chunks = llama.create_completion( @@ -2911,7 +2864,7 @@ def __call__( if tool is not None: tool_name = tool["function"]["name"] return _convert_completion_to_chat_function( - tool_name, completion_or_chunks, stream + tool_name, completion_or_chunks, stream, ) return _convert_completion_to_chat(completion_or_chunks, stream=stream) @@ -2923,16 +2876,15 @@ def _load_image(image_url: str) -> bytes: image_bytes = base64.b64decode(image_url.split(",")[1]) return image_bytes - else: - import urllib.request + import urllib.request - with urllib.request.urlopen(image_url) as f: - image_bytes = f.read() - return image_bytes + with urllib.request.urlopen(image_url) as f: + image_bytes = f.read() + return image_bytes @staticmethod - def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]): - image_urls: List[str] = [] + def get_image_urls(messages: list[llama_types.ChatCompletionRequestMessage]): + image_urls: list[str] = [] for message in messages: if message["role"] == "user": if message["content"] is None: @@ -2950,15 +2902,15 @@ def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]): return image_urls @staticmethod - def split_text_on_image_urls(text: str, image_urls: List[str]): - def find_first(s: str, substrs: List[str]): + def split_text_on_image_urls(text: str, image_urls: list[str]): + def find_first(s: str, substrs: list[str]): for i, substr in enumerate(substrs): pos = s.find(substr) if pos != -1: return pos, i return None, None - split_text: List[Tuple[Literal["text", "image_url"], str]] = [] + split_text: list[tuple[Literal["text", "image_url"], str]] = [] remaining = text while remaining: # Find first image_url @@ -2977,22 +2929,22 @@ def find_first(s: str, substrs: List[str]): def from_pretrained( cls, repo_id: str, - filename: Optional[str], - local_dir: Optional[Union[str, os.PathLike[str]]] = None, - local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", - cache_dir: Optional[Union[str, os.PathLike[str]]] = None, + filename: str | None, + local_dir: str | os.PathLike[str] | None = None, + local_dir_use_symlinks: bool | Literal["auto"] = "auto", + cache_dir: str | os.PathLike[str] | None = None, **kwargs: Any, - ) -> "Llava15ChatHandler": + ) -> Llava15ChatHandler: import fnmatch from pathlib import Path try: - from huggingface_hub import hf_hub_download, 
HfFileSystem # type: ignore + from huggingface_hub import HfFileSystem, hf_hub_download # type: ignore from huggingface_hub.utils import validate_repo_id # type: ignore except ImportError: raise ImportError( "Llama.from_pretrained requires the huggingface-hub package. " - "You can install it with `pip install huggingface-hub`." + "You can install it with `pip install huggingface-hub`.", ) validate_repo_id(repo_id) @@ -3005,7 +2957,7 @@ def from_pretrained( ] # split each file into repo_id, subfolder, filename - file_list: List[str] = [] + file_list: list[str] = [] for file in files: rel_path = Path(file).relative_to(repo_id) file_list.append(str(rel_path)) @@ -3015,13 +2967,13 @@ def from_pretrained( if len(matching_files) == 0: raise ValueError( f"No file found in {repo_id} that match {filename}\n\n" - f"Available Files:\n{json.dumps(file_list)}" + f"Available Files:\n{json.dumps(file_list)}", ) if len(matching_files) > 1: raise ValueError( f"Multiple files found in {repo_id} matching {filename}\n\n" - f"Available Files:\n{json.dumps(files)}" + f"Available Files:\n{json.dumps(files)}", ) (matching_file,) = matching_files @@ -3034,9 +2986,9 @@ def from_pretrained( repo_id=repo_id, filename=filename, subfolder=subfolder, - local_dir=cast(Union[str, Path, None], local_dir), + local_dir=cast(str | Path | None, local_dir), local_dir_use_symlinks=local_dir_use_symlinks, - cache_dir=cast(Union[str, Path, None], cache_dir), + cache_dir=cast(str | Path | None, cache_dir), ) if local_dir is None: @@ -3046,7 +2998,7 @@ def from_pretrained( subfolder=subfolder, local_dir=local_dir, local_dir_use_symlinks=local_dir_use_symlinks, - cache_dir=cast(Union[str, Path, None], cache_dir), + cache_dir=cast(str | Path | None, cache_dir), local_files_only=True, ) else: @@ -3353,20 +3305,20 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler): @register_chat_completion_handler("chatml-function-calling") def chatml_function_calling( llama: llama.Llama, - messages: List[llama_types.ChatCompletionRequestMessage], - functions: Optional[List[llama_types.ChatCompletionFunction]] = None, - function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, - tools: Optional[List[llama_types.ChatCompletionTool]] = None, - tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + messages: list[llama_types.ChatCompletionRequestMessage], + functions: list[llama_types.ChatCompletionFunction] | None = None, + function_call: llama_types.ChatCompletionRequestFunctionCall | None = None, + tools: list[llama_types.ChatCompletionTool] | None = None, + tool_choice: llama_types.ChatCompletionToolChoiceOption | None = None, temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, min_p: float = 0.05, typical_p: float = 1.0, stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, - max_tokens: Optional[int] = None, + stop: str | list[str] | None = [], + response_format: llama_types.ChatCompletionRequestResponseFormat | None = None, + max_tokens: int | None = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, @@ -3374,16 +3326,13 @@ def chatml_function_calling( mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, - model: Optional[str] = None, - logits_processor: Optional[llama.LogitsProcessorList] = None, - grammar: Optional[llama.LlamaGrammar] = None, - logprobs: Optional[bool] = None, - top_logprobs: Optional[int] = None, + model: 
str | None = None, + logits_processor: llama.LogitsProcessorList | None = None, + grammar: llama.LlamaGrammar | None = None, + logprobs: bool | None = None, + top_logprobs: int | None = None, **kwargs, # type: ignore -) -> Union[ - llama_types.CreateChatCompletionResponse, - Iterator[llama_types.CreateChatCompletionStreamResponse], -]: +) -> llama_types.CreateChatCompletionResponse | Iterator[llama_types.CreateChatCompletionStreamResponse]: function_calling_template = ( "{% for message in messages %}" "<|im_start|>{{ message.role }}\n" @@ -3517,7 +3466,7 @@ def chatml_function_calling( if isinstance(tool_choice, dict): tool_name = tool_choice["function"]["name"] tool = next( - (tool for tool in tools if tool["function"]["name"] == tool_name), None + (tool for tool in tools if tool["function"]["name"] == tool_name), None, ) if tool is None: raise ValueError(f"Tool with name '{tool_name}' not found in tools") @@ -3530,15 +3479,15 @@ def chatml_function_calling( prompt += f"functions.{tool_name}:\n" try: grammar = llama_grammar.LlamaGrammar.from_json_schema( - json.dumps(tool["function"]["parameters"]), verbose=llama.verbose + json.dumps(tool["function"]["parameters"]), verbose=llama.verbose, ) except Exception as e: grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, verbose=llama.verbose + llama_grammar.JSON_GBNF, verbose=llama.verbose, ) if llama.verbose: print( - "Failed to parse function body as JSON schema, falling back to default grammar" + "Failed to parse function body as JSON schema, falling back to default grammar", ) print(e) completion_or_chunks = llama.create_completion( @@ -3563,13 +3512,13 @@ def chatml_function_calling( grammar=grammar, ) return _convert_completion_to_chat_function( - tool_name, completion_or_chunks, stream + tool_name, completion_or_chunks, stream, ) # Case 3: Automatic tool choice assert isinstance(tool_choice, str) and tool_choice == "auto" function_names = " | ".join( - [f'''"functions.{tool['function']['name']}:"''' for tool in tools] + [f'''"functions.{tool['function']['name']}:"''' for tool in tools], ) initial_gbnf_tool_grammar = ( """root ::= functions | "message:"\n""" @@ -3605,7 +3554,7 @@ def chatml_function_calling( model=model, logits_processor=logits_processor, grammar=llama_grammar.LlamaGrammar.from_string( - initial_gbnf_tool_grammar, verbose=llama.verbose + initial_gbnf_tool_grammar, verbose=llama.verbose, ), ) completion: llama_types.CreateCompletionResponse = completion_or_chunks # type: ignore @@ -3633,7 +3582,7 @@ def chatml_function_calling( model=model, logits_processor=logits_processor, grammar=llama_grammar.LlamaGrammar.from_string( - follow_up_gbnf_tool_grammar, verbose=llama.verbose + follow_up_gbnf_tool_grammar, verbose=llama.verbose, ), ), stream=stream, @@ -3643,21 +3592,21 @@ def chatml_function_calling( tool_name = text[len("functions.") :] tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) if not stream: - completions: List[llama_types.CreateCompletionResponse] = [] - completions_tool_name: List[str] = [] + completions: list[llama_types.CreateCompletionResponse] = [] + completions_tool_name: list[str] = [] while tool is not None: prompt += f"functions.{tool_name}:\n" try: grammar = llama_grammar.LlamaGrammar.from_json_schema( - json.dumps(tool["function"]["parameters"]), verbose=llama.verbose + json.dumps(tool["function"]["parameters"]), verbose=llama.verbose, ) except Exception as e: grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, 
verbose=llama.verbose + llama_grammar.JSON_GBNF, verbose=llama.verbose, ) if llama.verbose: print( - "Failed to parse function body as JSON schema, falling back to default grammar" + "Failed to parse function body as JSON schema, falling back to default grammar", ) print(e) completion_or_chunks = llama.create_completion( @@ -3682,7 +3631,7 @@ def chatml_function_calling( grammar=grammar, ) completion_or_chunks = cast( - llama_types.CreateCompletionResponse, completion_or_chunks + llama_types.CreateCompletionResponse, completion_or_chunks, ) completions.append(completion_or_chunks) completions_tool_name.append(tool_name) @@ -3709,29 +3658,23 @@ def chatml_function_calling( model=model, logits_processor=logits_processor, grammar=llama_grammar.LlamaGrammar.from_string( - follow_up_gbnf_tool_grammar, verbose=llama.verbose + follow_up_gbnf_tool_grammar, verbose=llama.verbose, ), ) response = cast(llama_types.CreateCompletionResponse, response) tool_name = response["choices"][0]["text"][len("functions.") :] tool = next( - (tool for tool in tools if tool["function"]["name"] == tool_name), None + (tool for tool in tools if tool["function"]["name"] == tool_name), None, ) # Merge completions - function_call_dict: Union[ - Dict[str, str], - Dict[ - Literal["function_call"], - llama_types.ChatCompletionRequestAssistantMessageFunctionCall, - ], - ] = ( + function_call_dict: dict[str, str] | dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall] = ( { "function_call": { "name": tool_name, "arguments": completions[0]["choices"][0]["text"], - } + }, } if len(completions) == 1 else {} @@ -3763,12 +3706,12 @@ def chatml_function_calling( }, } for i, (tool_name, completion) in enumerate( - zip(completions_tool_name, completions) + zip(completions_tool_name, completions, strict=False), ) ], **function_call_dict, }, - } + }, ], "usage": { "completion_tokens": sum( diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 6b82753e1..95439694d 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -1,22 +1,22 @@ from __future__ import annotations -import sys -import os import ctypes import functools +import os import pathlib - +import sys from typing import ( + TYPE_CHECKING, Any, Callable, + Generic, List, - Union, NewType, Optional, - TYPE_CHECKING, TypeVar, - Generic, + Union, ) + from typing_extensions import TypeAlias @@ -962,7 +962,7 @@ class llama_context_params(ctypes.Structure): # int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() # enum llama_ftype ftype; // quantize to this llama_ftype # enum ggml_type output_tensor_type; // output tensor type -# enum ggml_type token_embedding_type; // token embeddings tensor type +# enum ggml_type token_embedding_type; // itoken embeddings tensor type # bool allow_requantize; // allow quantizing non-f32/f16 tensors # bool quantize_output_tensor; // quantize output.weight # bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored @@ -978,7 +978,7 @@ class llama_model_quantize_params(ctypes.Structure): nthread (int): number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() ftype (int): quantize to this llama_ftype output_tensor_type (int): output tensor type - token_embedding_type (int): token embeddings tensor type + token_embedding_type (int): itoken embeddings tensor type allow_requantize (bool): allow quantizing non-f32/f16 tensors quantize_output_tensor (bool): 
quantize output.weight only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored @@ -1493,14 +1493,6 @@ def llama_model_has_encoder(model: llama_model_p, /) -> bool: ... -# // Returns true if the model contains a decoder that requires llama_decode() call -# LLAMA_API bool llama_model_has_decoder(const struct llama_model * model); -@ctypes_function("llama_model_has_decoder", [llama_model_p_ctypes], ctypes.c_bool) -def llama_model_has_decoder(model: llama_model_p, /) -> bool: - """Returns true if the model contains a decoder that requires llama_decode() call""" - ... - - # // For encoder-decoder models, this function returns id of the token that must be provided # // to the decoder to start generating output sequence. For other models, it returns -1. # LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model); @@ -1771,7 +1763,7 @@ def llama_kv_cache_view_init( # // Free a KV cache view. (use only for debugging purposes) # LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view); @ctypes_function("llama_kv_cache_view_free", [llama_kv_cache_view_p], None) -def llama_kv_cache_view_free(view: "ctypes.pointer[llama_kv_cache_view]", /): # type: ignore +def llama_kv_cache_view_free(view: ctypes.pointer[llama_kv_cache_view], /): # type: ignore """Free a KV cache view. (use only for debugging purposes)""" ... diff --git a/llama_cpp/llama_speculative.py b/llama_cpp/llama_speculative.py index 39dfb903b..6950c5a64 100644 --- a/llama_cpp/llama_speculative.py +++ b/llama_cpp/llama_speculative.py @@ -1,5 +1,4 @@ import abc - from typing import Any import numpy as np @@ -9,9 +8,9 @@ class LlamaDraftModel(abc.ABC): @abc.abstractmethod def __call__( - self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any + self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any, ) -> npt.NDArray[np.intc]: - raise NotImplementedError() + raise NotImplementedError class LlamaPromptLookupDecoding(LlamaDraftModel): @@ -55,7 +54,7 @@ def find_candidate_pred_tokens( return np.array([], dtype=np.intc) def __call__( - self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any + self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any, ) -> npt.NDArray[np.intc]: return self.find_candidate_pred_tokens( input_ids=input_ids, diff --git a/llama_cpp/llama_tokenizer.py b/llama_cpp/llama_tokenizer.py index 1375e1392..7134c2220 100644 --- a/llama_cpp/llama_tokenizer.py +++ b/llama_cpp/llama_tokenizer.py @@ -2,9 +2,9 @@ import abc from typing import ( + Any, List, Optional, - Any, ) import llama_cpp @@ -14,7 +14,7 @@ class BaseLlamaTokenizer(abc.ABC): @abc.abstractmethod def tokenize( - self, text: bytes, add_bos: bool = True, special: bool = True + self, text: bytes, add_bos: bool = True, special: bool = True, ) -> List[int]: """Tokenize the text into tokens. 
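# A minimal usage sketch, not part of the patch: it exercises the tokenizer wrappers
# whose hunks appear around here, calling only the methods visible in this diff.
# The repo id and model path below are placeholders, not values from the repository.
from llama_cpp.llama_tokenizer import LlamaHFTokenizer, LlamaTokenizer

hf_tok = LlamaHFTokenizer.from_pretrained("org/placeholder-model")  # assumed repo id
ids = hf_tok.tokenize(b"hello world", add_bos=True, special=True)   # -> List[int]
round_trip = hf_tok.detokenize(ids)                                 # -> utf-8 bytes

# The GGUF-backed variant wraps a vocab-only llama_cpp.Llama instance.
gguf_tok = LlamaTokenizer.from_ggml_file("./models/example.gguf")   # hypothetical path
assert isinstance(gguf_tok.encode("hello world"), list)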
@@ -47,7 +47,7 @@ def __init__(self, llama: llama_cpp.Llama): self._model = llama._model # type: ignore def tokenize( - self, text: bytes, add_bos: bool = True, special: bool = True + self, text: bytes, add_bos: bool = True, special: bool = True, ) -> List[int]: return self._model.tokenize(text, add_bos=add_bos, special=special) @@ -60,17 +60,17 @@ def detokenize( return self._model.detokenize(tokens, special=special) def encode( - self, text: str, add_bos: bool = True, special: bool = True + self, text: str, add_bos: bool = True, special: bool = True, ) -> List[int]: return self.tokenize( - text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special + text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special, ) def decode(self, tokens: List[int]) -> str: return self.detokenize(tokens).decode("utf-8", errors="ignore") @classmethod - def from_ggml_file(cls, path: str) -> "LlamaTokenizer": + def from_ggml_file(cls, path: str) -> LlamaTokenizer: return cls(llama_cpp.Llama(model_path=path, vocab_only=True)) @@ -79,10 +79,10 @@ def __init__(self, hf_tokenizer: Any): self.hf_tokenizer = hf_tokenizer def tokenize( - self, text: bytes, add_bos: bool = True, special: bool = True + self, text: bytes, add_bos: bool = True, special: bool = True, ) -> List[int]: return self.hf_tokenizer.encode( - text.decode("utf-8", errors="ignore"), add_special_tokens=special + text.decode("utf-8", errors="ignore"), add_special_tokens=special, ) def detokenize( @@ -106,15 +106,15 @@ def detokenize( ).encode("utf-8", errors="ignore") @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer": + def from_pretrained(cls, pretrained_model_name_or_path: str) -> LlamaHFTokenizer: try: from transformers import AutoTokenizer except ImportError: raise ImportError( "The `transformers` library is required to use the `HFTokenizer`." - "You can install it with `pip install transformers`." + "You can install it with `pip install transformers`.", ) hf_tokenizer = AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name_or_path + pretrained_model_name_or_path=pretrained_model_name_or_path, ) return cls(hf_tokenizer) diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py index bbb58afc3..84bde6ef3 100644 --- a/llama_cpp/llama_types.py +++ b/llama_cpp/llama_types.py @@ -7,9 +7,9 @@ """ -from typing import Any, List, Optional, Dict, Union -from typing_extensions import TypedDict, NotRequired, Literal +from typing import Any, Dict, List, Optional, Union +from typing_extensions import Literal, NotRequired, TypedDict # NOTE: Defining this correctly using annotations seems to break pydantic validation. 
# This is a workaround until we can figure out how to do this correctly @@ -131,7 +131,7 @@ class ChatCompletionStreamResponseDelta(TypedDict): class ChatCompletionStreamResponseChoice(TypedDict): index: int delta: Union[ - ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty + ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty, ] finish_reason: Optional[Literal["stop", "length", "tool_calls", "function_call"]] logprobs: NotRequired[Optional[CompletionLogprobs]] @@ -248,7 +248,7 @@ class ChatCompletionRequestFunctionCallOption(TypedDict): ChatCompletionRequestFunctionCall = Union[ - Literal["none", "auto"], ChatCompletionRequestFunctionCallOption + Literal["none", "auto"], ChatCompletionRequestFunctionCallOption, ] ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific @@ -275,7 +275,7 @@ class ChatCompletionNamedToolChoice(TypedDict): ChatCompletionToolChoiceOption = Union[ - Literal["none", "auto", "required"], ChatCompletionNamedToolChoice + Literal["none", "auto", "required"], ChatCompletionNamedToolChoice, ] diff --git a/llama_cpp/llava_cpp.py b/llama_cpp/llava_cpp.py index 7d97dc0fd..506bded3d 100644 --- a/llama_cpp/llava_cpp.py +++ b/llama_cpp/llava_cpp.py @@ -1,35 +1,35 @@ from __future__ import annotations -import sys -import os import ctypes import functools +import os +import pathlib +import sys +from collections.abc import Callable from ctypes import ( + POINTER, + Structure, + _Pointer, # type: ignore c_bool, c_char_p, + c_float, c_int, c_uint8, - c_float, c_void_p, - POINTER, - _Pointer, # type: ignore - Structure, ) -import pathlib from typing import ( + TYPE_CHECKING, + Any, + Generic, List, - Union, NewType, Optional, + TypeAlias, TypeVar, - Callable, - Any, - TYPE_CHECKING, - Generic, + Union, ) -from typing_extensions import TypeAlias -import llama_cpp.llama_cpp as llama_cpp +from llama_cpp import llama_cpp # Load the library @@ -38,7 +38,7 @@ def _load_shared_library(lib_base_name: str): _base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" # Searching for the library in the current directory under the name "libllama" (default name # for llamacpp) and "llama" (default name for this repo) - _lib_paths: List[pathlib.Path] = [] + _lib_paths: list[pathlib.Path] = [] # Determine the file extension based on the platform if sys.platform.startswith("linux"): _lib_paths += [ @@ -81,7 +81,7 @@ def _load_shared_library(lib_base_name: str): raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") raise FileNotFoundError( - f"Shared library with base name '{lib_base_name}' not found" + f"Shared library with base name '{lib_base_name}' not found", ) @@ -105,9 +105,7 @@ def _load_shared_library(lib_base_name: str): class CtypesRef(Generic[CtypesCData]): pass - CtypesPointerOrRef: TypeAlias = Union[ - CtypesPointer[CtypesCData], CtypesRef[CtypesCData] - ] + CtypesPointerOrRef: TypeAlias = CtypesPointer[CtypesCData] | CtypesRef[CtypesCData] CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore @@ -116,7 +114,7 @@ class CtypesRef(Generic[CtypesCData]): def ctypes_function_for_shared_library(lib: ctypes.CDLL): def ctypes_function( - name: str, argtypes: List[Any], restype: Any, enabled: bool = True + name: str, argtypes: list[Any], restype: Any, enabled: bool = True, ): def decorator(f: F) -> F: if enabled: @@ -125,8 +123,8 @@ def decorator(f: F) -> F: func.restype = restype functools.wraps(f)(func) return func - else: - return f + + return f return decorator @@ 
-178,9 +176,9 @@ def llava_validate_embed_size( ) def llava_image_embed_make_with_bytes( ctx_clip: clip_ctx_p, - n_threads: Union[c_int, int], + n_threads: c_int | int, image_bytes: CtypesArray[c_uint8], - image_bytes_length: Union[c_int, int], + image_bytes_length: c_int | int, /, ) -> "_Pointer[llava_image_embed]": ... @@ -220,9 +218,9 @@ def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /): ) def llava_eval_image_embed( ctx_llama: llama_cpp.llama_context_p, - embed: "_Pointer[llava_image_embed]", - n_batch: Union[c_int, int], - n_past: "_Pointer[c_int]", + embed: _Pointer[llava_image_embed], + n_batch: c_int | int, + n_past: _Pointer[c_int], /, ) -> bool: ... diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index bbac4957e..7b74a0ef5 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -24,20 +24,20 @@ from __future__ import annotations +import argparse import os import sys -import argparse import uvicorn from llama_cpp.server.app import create_app +from llama_cpp.server.cli import add_args_from_model, parse_model_from_args from llama_cpp.server.settings import ( - Settings, - ServerSettings, - ModelSettings, ConfigFileSettings, + ModelSettings, + ServerSettings, + Settings, ) -from llama_cpp.server.cli import add_args_from_model, parse_model_from_args def main(): @@ -62,15 +62,16 @@ def main(): with open(config_file, "rb") as f: # Check if yaml file if config_file.endswith(".yaml") or config_file.endswith(".yml"): - import yaml import json + import yaml + config_file_settings = ConfigFileSettings.model_validate_json( - json.dumps(yaml.safe_load(f)) + json.dumps(yaml.safe_load(f)), ) else: config_file_settings = ConfigFileSettings.model_validate_json( - f.read() + f.read(), ) server_settings = ServerSettings.model_validate(config_file_settings) model_settings = config_file_settings.models diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index cd3255176..ec7da0712 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -1,49 +1,46 @@ from __future__ import annotations -import os +import contextlib import json +import os import typing -import contextlib - -from threading import Lock from functools import partial -from typing import Iterator, List, Optional, Union, Dict - -import llama_cpp +from threading import Lock +from typing import Dict, Iterator, List, Optional, Union import anyio from anyio.streams.memory import MemoryObjectSendStream -from starlette.concurrency import run_in_threadpool, iterate_in_threadpool -from fastapi import Depends, FastAPI, APIRouter, Request, HTTPException, status, Body +from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request, status from fastapi.middleware import Middleware from fastapi.middleware.cors import CORSMiddleware from fastapi.security import HTTPBearer from sse_starlette.sse import EventSourceResponse -from starlette_context.plugins import RequestIdPlugin # type: ignore +from starlette.concurrency import iterate_in_threadpool, run_in_threadpool from starlette_context.middleware import RawContextMiddleware +from starlette_context.plugins import RequestIdPlugin # type: ignore +import llama_cpp +from llama_cpp.server.errors import RouteErrorHandler from llama_cpp.server.model import ( LlamaProxy, ) from llama_cpp.server.settings import ( ConfigFileSettings, - Settings, ModelSettings, ServerSettings, + Settings, ) from llama_cpp.server.types import ( + CreateChatCompletionRequest, CreateCompletionRequest, CreateEmbeddingRequest, - 
CreateChatCompletionRequest, + DetokenizeInputRequest, + DetokenizeInputResponse, ModelList, + TokenizeInputCountResponse, TokenizeInputRequest, TokenizeInputResponse, - TokenizeInputCountResponse, - DetokenizeInputRequest, - DetokenizeInputResponse, ) -from llama_cpp.server.errors import RouteErrorHandler - router = APIRouter(route_class=RouteErrorHandler) @@ -150,7 +147,7 @@ def create_app( set_llama_proxy(model_settings=model_settings) if server_settings.disable_ping_events: - set_ping_message_factory(lambda: bytes()) + set_ping_message_factory(lambda: b"") return app @@ -248,7 +245,7 @@ async def authenticate( "schema": { "type": "string", "title": "Server Side Streaming response, when stream=True. " - + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501 + + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""", } }, @@ -386,7 +383,7 @@ async def create_embedding( "schema": { "type": "string", "title": "Server Side Streaming response, when stream=True" - + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501 + + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""", } }, diff --git a/llama_cpp/server/cli.py b/llama_cpp/server/cli.py index 3dd007676..b2b2ac6cd 100644 --- a/llama_cpp/server/cli.py +++ b/llama_cpp/server/cli.py @@ -1,8 +1,7 @@ from __future__ import annotations import argparse - -from typing import List, Literal, Union, Any, Type, TypeVar +from typing import Any, List, Literal, Type, TypeVar, Union from pydantic import BaseModel diff --git a/llama_cpp/server/errors.py b/llama_cpp/server/errors.py index fbf9fd80d..dae9960ba 100644 --- a/llama_cpp/server/errors.py +++ b/llama_cpp/server/errors.py @@ -1,25 +1,24 @@ from __future__ import annotations import sys -import traceback import time -from re import compile, Match, Pattern -from typing import Callable, Coroutine, Optional, Tuple, Union, Dict -from typing_extensions import TypedDict - +import traceback +from re import Match, Pattern, compile +from typing import Callable, Coroutine, Dict, Optional, Tuple, Union from fastapi import ( + HTTPException, Request, Response, - HTTPException, ) from fastapi.responses import JSONResponse from fastapi.routing import APIRoute +from typing_extensions import TypedDict from llama_cpp.server.types import ( + CreateChatCompletionRequest, CreateCompletionRequest, CreateEmbeddingRequest, - CreateChatCompletionRequest, ) @@ -46,7 +45,7 @@ class ErrorResponseFormatters: @staticmethod def context_length_exceeded( - request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"], + request: Union[CreateCompletionRequest, CreateChatCompletionRequest], match, # type: Match[str] # type: ignore ) -> Tuple[int, ErrorResponse]: """Formatter for context length exceeded error""" @@ -84,7 +83,7 @@ def context_length_exceeded( @staticmethod def model_not_found( - request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"], + request: Union[CreateCompletionRequest, CreateChatCompletionRequest], match, # type: Match[str] # type: ignore ) -> 
Tuple[int, ErrorResponse]: """Formatter for model_not_found error""" @@ -105,11 +104,11 @@ class RouteErrorHandler(APIRoute): # key: regex pattern for original error message from llama_cpp # value: formatter function pattern_and_formatters: Dict[ - "Pattern[str]", + Pattern[str], Callable[ [ - Union["CreateCompletionRequest", "CreateChatCompletionRequest"], - "Match[str]", + Union[CreateCompletionRequest, CreateChatCompletionRequest], + Match[str], ], Tuple[int, ErrorResponse], ], @@ -127,14 +126,14 @@ def error_message_wrapper( error: Exception, body: Optional[ Union[ - "CreateChatCompletionRequest", - "CreateCompletionRequest", - "CreateEmbeddingRequest", + CreateChatCompletionRequest, + CreateCompletionRequest, + CreateEmbeddingRequest, ] ] = None, ) -> Tuple[int, ErrorResponse]: """Wraps error message in OpenAI style error response""" - print(f"Exception: {str(error)}", file=sys.stderr) + print(f"Exception: {error!s}", file=sys.stderr) traceback.print_exc(file=sys.stderr) if body is not None and isinstance( body, diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index c6716f919..a39132650 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -1,13 +1,10 @@ from __future__ import annotations import json - -from typing import Dict, Optional, Union, List +from typing import Dict, List, Optional, Union import llama_cpp -import llama_cpp.llama_speculative as llama_speculative -import llama_cpp.llama_tokenizer as llama_tokenizer - +from llama_cpp import llama_speculative, llama_tokenizer from llama_cpp.server.settings import ModelSettings diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 13c951241..c4d665bf2 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -1,12 +1,11 @@ from __future__ import annotations import multiprocessing - -from typing import Optional, List, Literal, Union, Dict, cast -from typing_extensions import Self +from typing import Dict, List, Literal, Optional, Union, cast from pydantic import Field, model_validator from pydantic_settings import BaseSettings +from typing_extensions import Self import llama_cpp @@ -18,7 +17,7 @@ class ModelSettings(BaseSettings): """Model settings used to load a Llama model.""" model: str = Field( - description="The path to the model to use for generating completions." + description="The path to the model to use for generating completions.", ) model_alias: Optional[str] = Field( default=None, @@ -44,7 +43,7 @@ class ModelSettings(BaseSettings): description="Split layers across multiple GPUs in proportion.", ) vocab_only: bool = Field( - default=False, description="Whether to only return the vocabulary." + default=False, description="Whether to only return the vocabulary.", ) use_mmap: bool = Field( default=llama_cpp.llama_supports_mmap(), @@ -64,11 +63,11 @@ class ModelSettings(BaseSettings): ) # Context Params seed: int = Field( - default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random." + default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random.", ) n_ctx: int = Field(default=2048, ge=0, description="The context size.") n_batch: int = Field( - default=512, ge=1, description="The batch size to use per eval." + default=512, ge=1, description="The batch size to use per eval.", ) n_ubatch: int = Field( default=512, ge=1, description="The physical batch size used by llama.cpp" @@ -84,11 +83,11 @@ class ModelSettings(BaseSettings): description="The number of threads to use when batch processing. 
Use -1 for max cpu threads", ) rope_scaling_type: int = Field( - default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED + default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, ) rope_freq_base: float = Field(default=0.0, description="RoPE base frequency") rope_freq_scale: float = Field( - default=0.0, description="RoPE frequency scaling factor" + default=0.0, description="RoPE frequency scaling factor", ) yarn_ext_factor: float = Field(default=-1.0) yarn_attn_factor: float = Field(default=1.0) @@ -96,15 +95,15 @@ class ModelSettings(BaseSettings): yarn_beta_slow: float = Field(default=1.0) yarn_orig_ctx: int = Field(default=0) mul_mat_q: bool = Field( - default=True, description="if true, use experimental mul_mat_q kernels" + default=True, description="if true, use experimental mul_mat_q kernels", ) logits_all: bool = Field(default=True, description="Whether to return logits.") embedding: bool = Field(default=False, description="Whether to use embeddings.") offload_kqv: bool = Field( - default=True, description="Whether to offload kqv to the GPU." + default=True, description="Whether to offload kqv to the GPU.", ) flash_attn: bool = Field( - default=False, description="Whether to use flash attention." + default=False, description="Whether to use flash attention.", ) # Sampling Params last_n_tokens_size: int = Field( @@ -182,11 +181,11 @@ class ModelSettings(BaseSettings): ) # Misc verbose: bool = Field( - default=True, description="Whether to print debug information." + default=True, description="Whether to print debug information.", ) @model_validator( - mode="before" + mode="before", ) # pre=True to ensure this runs before any other validation def set_dynamic_defaults(self) -> Self: # If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count() @@ -206,10 +205,10 @@ class ServerSettings(BaseSettings): host: str = Field(default="localhost", description="Listen address") port: int = Field(default=8000, description="Listen port") ssl_keyfile: Optional[str] = Field( - default=None, description="SSL key file for HTTPS" + default=None, description="SSL key file for HTTPS", ) ssl_certfile: Optional[str] = Field( - default=None, description="SSL certificate file for HTTPS" + default=None, description="SSL certificate file for HTTPS", ) # FastAPI Settings api_key: Optional[str] = Field( diff --git a/llama_cpp/server/types.py b/llama_cpp/server/types.py index fdd164456..e95ab11ac 100644 --- a/llama_cpp/server/types.py +++ b/llama_cpp/server/types.py @@ -1,13 +1,12 @@ from __future__ import annotations -from typing import List, Optional, Union, Dict -from typing_extensions import TypedDict, Literal +from typing import Dict, List, Optional, Union from pydantic import BaseModel, Field +from typing_extensions import Literal, TypedDict import llama_cpp - model_field = Field( description="The model to use for generating completions.", default=None ) diff --git a/pyproject.toml b/pyproject.toml index 9983ef777..df59f12a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,8 @@ dev = [ "mkdocs-material>=9.1.18", "pytest>=7.4.0", "httpx>=0.24.1", + "pre-commit>=3.8.0", + "ruff>=0.5.7", ] all = [ "llama_cpp_python[server,test,dev]", diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..67644d9c2 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,127 @@ +# Exclude a variety of commonly ignored directories. 
+exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +# Same as Black. +line-length = 80 +indent-width = 4 + +# Assume Python 3.9 +target-version = "py39" + +[lint] +preview = true +explicit-preview-rules = true +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +select = [ + "A", # flake8-builtins + "AIR", # Airflow + "ANN", # flake8-annotations + "ARG", # flake8-unused-arguments + "ASYNC", # flake8-async + "B", # flake8-bugbear + "BLE", # flake8-blind-except + "C4", # flake8-comprehensions + "C90", # McCabe cyclomatic complexity + "COM", # flake8-commas + "CPY", # flake8-copyright + # "D", # pydocstyle + # "DJ", # flake8-django + # "DOC", # pydoclint + "DTZ", # flake8-datetimez + "E", # pycodestyle + "EM", # flake8-errmsg + "ERA", # eradicate + # "EXE", # flake8-executable + "F", # Pyflakes + "FA", # flake8-future-annotations + "FAST", # FastAPI + "FBT", # flake8-boolean-trap + "FIX", # flake8-pp + "FLY", # flynt + "FURB", # refurb + "G", # flake8-logging-format + "I", # isort + "ICN", # flake8-import-conventions + "INP", # flake8-no-pep420 + "INT", # flake8-gettext + "ISC", # flake8-implicit-str-concat + "LOG", # flake8-logging + "N", # pep8-naming + "NPY", # NumPy-specific rules + # "PD", # pandas-vet + "PERF", # Perflint + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PL", # Pylint + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "PYI", # flake8-pyi + "Q", # flake8-quotes + "R", # Refactor + "RET", # flake8-return + "RSE", # flake8-raise + "RUF", # Ruff-specific rules + "S", # flake8-bandit + "SIM", # flake8-simplify + "SLF", # flake8-self + "SLOT", # flake8-slots + "T10", # flake8-debugger + "T20", # flake8-print + "TCH", # flake8-type-checking + # "TD", # flake8-todos + "TID", # flake8-tidy-imports + "TRY", # tryceratops + # "UP", # pyupgrade + "W", # pycodestyle + "YTT", # flake8-2020 +] +ignore = ["A001","A002","ANN001","ANN201","ANN202","COM812","E501","ERA001","F401","ISC001","T201"] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[format] +preview = true +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. 
+line-ending = "auto" diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py index f031bf72b..0379ba4d1 100644 --- a/tests/test_llama_chat_format.py +++ b/tests/test_llama_chat_format.py @@ -4,12 +4,12 @@ from llama_cpp import ( ChatCompletionRequestUserMessage, + llama_chat_format, + llama_types, ) -import llama_cpp.llama_types as llama_types -import llama_cpp.llama_chat_format as llama_chat_format - from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter + def test_mistral_instruct(): chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" chat_formatter = jinja2.Template(chat_template) diff --git a/tests/test_llama_grammar.py b/tests/test_llama_grammar.py index 34ef2874d..7115c495d 100644 --- a/tests/test_llama_grammar.py +++ b/tests/test_llama_grammar.py @@ -1,6 +1,7 @@ -import llama_cpp import json +import llama_cpp + tree = """ leaf ::= "." node ::= leaf | "(" node node ")" diff --git a/tests/test_llama_speculative.py b/tests/test_llama_speculative.py index b5d450567..63e8adbb3 100644 --- a/tests/test_llama_speculative.py +++ b/tests/test_llama_speculative.py @@ -2,6 +2,7 @@ from llama_cpp.llama_speculative import LlamaPromptLookupDecoding + def test_find_candidate_pred_tokens(): find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens
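# Aside — a minimal sketch, assuming only the standard library, of how the PEP 604
# rewrites in this patch interact with the `target-version = "py39"` declared in
# ruff.toml. Inside annotations, `X | None` is harmless wherever
# `from __future__ import annotations` is present (as it is in llama_cpp.py,
# llava_cpp.py, and the server modules), because the annotation is never evaluated.
# Expressions that are evaluated at runtime still need Python 3.10+:
# `cast(str | Path | None, local_dir)` builds a union object, `zip(..., strict=False)`
# uses a keyword added in 3.10, and `from typing import TypeAlias` resolves only on 3.10+.
from __future__ import annotations

from typing import Optional, cast


def annotated(x: int | None) -> str | None:  # not evaluated on 3.9, so this is fine
    return None if x is None else str(x)


# In evaluated positions, the pre-3.10 spellings remain the safe choice on py39.
value = cast(Optional[str], annotated(3))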