From a5e130f2d5703d57061a08690bf21ac45e4e2d85 Mon Sep 17 00:00:00 2001
From: John6666 <186692226+John6666cat@users.noreply.github.com>
Date: Tue, 29 Oct 2024 21:30:26 +0900
Subject: [PATCH] Add a series of Mistral templates that more closely
 resemble the standard

The current `MISTRAL` template appears to date from before the Mistral
chat-template specification was finalized, and it diverges from the
current spec in a few places. It works well, and I use it on a daily
basis. However, a user who works with these models daily pointed me to
formats that follow the current Mistral specification more closely, so
this commit adds templates based on that information.

For compatibility, `MISTRAL` is left as is and the new formats are
appended at the end; feel free to change the order of appearance as you
see fit.

See the links below for details on the formats:

https://huggingface.co/spaces/John6666/text2tag-llm/discussions/4#671b481656288fee065beabb
https://github.com/inflatebot/SillyTavern-Mistral-Templates/tree/main/Instruct
https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
---
 src/llama_cpp_agent/messages_formatter.py | 58 ++++++++++++++++++++++-
 1 file changed, 57 insertions(+), 1 deletion(-)

diff --git a/src/llama_cpp_agent/messages_formatter.py b/src/llama_cpp_agent/messages_formatter.py
index fd83a53..58d7b90 100644
--- a/src/llama_cpp_agent/messages_formatter.py
+++ b/src/llama_cpp_agent/messages_formatter.py
@@ -28,6 +28,10 @@ class MessagesFormatterType(Enum):
     AUTOCODER = 15
     GEMMA_2 = 16
     DEEP_SEEK_CODER_2 = 17
+    MISTRAL_V1 = 18
+    MISTRAL_V2 = 19
+    MISTRAL_V3_TEKKEN = 20
+

 @dataclass
 class PromptMarkers:
@@ -178,12 +182,14 @@ def _format_response(
     Roles.assistant: PromptMarkers("""### Assistant:\n""", """\n"""),
     Roles.tool: PromptMarkers("", ""),
 }
+
 gemma_2_prompt_markers = {
     Roles.system: PromptMarkers("""""", """\n\n"""),
     Roles.user: PromptMarkers("""user\n""", """\n"""),
     Roles.assistant: PromptMarkers("""model\n""", """\n"""),
     Roles.tool: PromptMarkers("", ""),
 }
+
 code_ds_prompt_markers = {
     Roles.system: PromptMarkers("", """\n\n"""),
     Roles.user: PromptMarkers("""@@ Instruction\n""", """\n\n"""),
@@ -225,18 +231,21 @@ def _format_response(
     Roles.assistant: PromptMarkers("""<|assistant|>""", """<|end|>\n"""),
     Roles.tool: PromptMarkers("", ""),
 }
+
 open_interpreter_chat_prompt_markers = {
     Roles.system: PromptMarkers("", "\n\n"),
     Roles.user: PromptMarkers("### Instruction:\n", "\n"),
     Roles.assistant: PromptMarkers("### Response:\n", "\n"),
     Roles.tool: PromptMarkers("", ""),
 }
+
 autocoder_chat_prompt_markers = {
     Roles.system: PromptMarkers("", "\n"),
     Roles.user: PromptMarkers("Human: ", "\n"),
     Roles.assistant: PromptMarkers("Assistant: ", "<|EOT|>\n"),
     Roles.tool: PromptMarkers("", ""),
 }
+
 deep_seek_coder_2_chat_prompt_markers = {
     Roles.system: PromptMarkers("""<|begin▁of▁sentence|>""", """\n\n"""),
     Roles.user: PromptMarkers("""User: """, """ \n\n"""),
@@ -244,6 +253,28 @@ def _format_response(
     Roles.tool: PromptMarkers("", ""),
 }

+mistral_v1_markers = {
+    Roles.system: PromptMarkers(""" [INST]""", """ [/INST]"""),
+    Roles.user: PromptMarkers(""" [INST]""", """ [/INST]"""),
+    Roles.assistant: PromptMarkers(""" """, """</s>"""),
+    Roles.tool: PromptMarkers("", ""),
+}
+
+mistral_v2_markers = {
+    Roles.system: PromptMarkers("""[INST] """, """[/INST]"""),
+    Roles.user: PromptMarkers("""[INST] """, """[/INST]"""),
+    Roles.assistant: PromptMarkers(""" """, """</s>"""),
+    Roles.tool: PromptMarkers("", ""),
+}
+
+mistral_v3_tekken_markers = {
+    Roles.system: PromptMarkers("""[INST]""", """[/INST]"""),
+    Roles.user: PromptMarkers("""[INST]""", """[/INST]"""),
+    Roles.assistant: PromptMarkers("""""", """</s>"""),
+    Roles.tool: PromptMarkers("", ""),
+}
+
+
 """
 ### Instruction:
 {prompt}
@@ -381,6 +412,28 @@ def _format_response(
     eos_token="<|end▁of▁sentence|>",
 )

+mistral_v1_formatter = MessagesFormatter(
+    "",
+    mistral_v1_markers,
+    False,
+    ["</s>"],
+)
+
+mistral_v2_formatter = MessagesFormatter(
+    "",
+    mistral_v2_markers,
+    False,
+    ["</s>"],
+)
+
+mistral_v3_tekken_formatter = MessagesFormatter(
+    "",
+    mistral_v3_tekken_markers,
+    False,
+    ["</s>"],
+)
+
+
 predefined_formatter = {
     MessagesFormatterType.MISTRAL: mixtral_formatter,
     MessagesFormatterType.CHATML: chatml_formatter,
@@ -398,7 +451,10 @@ def _format_response(
     MessagesFormatterType.OPEN_INTERPRETER: open_interpreter_chat_formatter,
     MessagesFormatterType.AUTOCODER: autocoder_chat_formatter,
     MessagesFormatterType.GEMMA_2: gemma_2_chat_formatter,
-    MessagesFormatterType.DEEP_SEEK_CODER_2: deep_seek_coder_2_chat_formatter
+    MessagesFormatterType.DEEP_SEEK_CODER_2: deep_seek_coder_2_chat_formatter,
+    MessagesFormatterType.MISTRAL_V1: mistral_v1_formatter,
+    MessagesFormatterType.MISTRAL_V2: mistral_v2_formatter,
+    MessagesFormatterType.MISTRAL_V3_TEKKEN: mistral_v3_tekken_formatter
 }