diff --git a/src/llama_cpp_agent/messages_formatter.py b/src/llama_cpp_agent/messages_formatter.py
index fd83a53..58d7b90 100644
--- a/src/llama_cpp_agent/messages_formatter.py
+++ b/src/llama_cpp_agent/messages_formatter.py
@@ -28,6 +28,10 @@ class MessagesFormatterType(Enum):
     AUTOCODER = 15
     GEMMA_2 = 16
     DEEP_SEEK_CODER_2 = 17
+    MISTRAL_V1 = 18
+    MISTRAL_V2 = 19
+    MISTRAL_V3_TEKKEN = 20
+
 
 @dataclass
 class PromptMarkers:
@@ -178,12 +182,14 @@ def _format_response(
     Roles.assistant: PromptMarkers("""### Assistant:\n""", """\n"""),
     Roles.tool: PromptMarkers("", ""),
 }
+
 gemma_2_prompt_markers = {
     Roles.system: PromptMarkers("""""", """\n\n"""),
     Roles.user: PromptMarkers("""<start_of_turn>user\n""", """<end_of_turn>\n"""),
     Roles.assistant: PromptMarkers("""<start_of_turn>model\n""", """<end_of_turn>\n"""),
     Roles.tool: PromptMarkers("", ""),
 }
+
 code_ds_prompt_markers = {
     Roles.system: PromptMarkers("", """\n\n"""),
     Roles.user: PromptMarkers("""@@ Instruction\n""", """\n\n"""),
@@ -225,18 +231,21 @@ def _format_response(
     Roles.assistant: PromptMarkers("""<|assistant|>""", """<|end|>\n"""),
     Roles.tool: PromptMarkers("", ""),
 }
+
 open_interpreter_chat_prompt_markers = {
     Roles.system: PromptMarkers("", "\n\n"),
     Roles.user: PromptMarkers("### Instruction:\n", "\n"),
     Roles.assistant: PromptMarkers("### Response:\n", "\n"),
     Roles.tool: PromptMarkers("", ""),
 }
+
 autocoder_chat_prompt_markers = {
     Roles.system: PromptMarkers("", "\n"),
     Roles.user: PromptMarkers("Human: ", "\n"),
     Roles.assistant: PromptMarkers("Assistant: ", "<|EOT|>\n"),
     Roles.tool: PromptMarkers("", ""),
 }
+
 deep_seek_coder_2_chat_prompt_markers = {
     Roles.system: PromptMarkers("""<|begin▁of▁sentence|>""", """\n\n"""),
     Roles.user: PromptMarkers("""User: """, """ \n\n"""),
@@ -244,6 +253,28 @@ def _format_response(
     Roles.tool: PromptMarkers("", ""),
 }
 
+mistral_v1_markers = {
+    Roles.system: PromptMarkers(""" [INST]""", """ [/INST]"""),
+    Roles.user: PromptMarkers(""" [INST]""", """ [/INST]"""),
+    Roles.assistant: PromptMarkers(""" """, """</s>"""),
+    Roles.tool: PromptMarkers("", ""),
+}
+
+mistral_v2_markers = {
+    Roles.system: PromptMarkers("""[INST] """, """[/INST]"""),
+    Roles.user: PromptMarkers("""[INST] """, """[/INST]"""),
+    Roles.assistant: PromptMarkers(""" """, """</s>"""),
+    Roles.tool: PromptMarkers("", ""),
+}
+
+mistral_v3_tekken_markers = {
+    Roles.system: PromptMarkers("""[INST]""", """[/INST]"""),
+    Roles.user: PromptMarkers("""[INST]""", """[/INST]"""),
+    Roles.assistant: PromptMarkers("""""", """</s>"""),
+    Roles.tool: PromptMarkers("", ""),
+}
+
+
 """
 ### Instruction:
 {prompt}
@@ -381,6 +412,28 @@ def _format_response(
     eos_token="<|end▁of▁sentence|>",
 )
 
+mistral_v1_formatter = MessagesFormatter(
+    "",
+    mistral_v1_markers,
+    False,
+    ["</s>"],
+)
+
+mistral_v2_formatter = MessagesFormatter(
+    "",
+    mistral_v2_markers,
+    False,
+    ["</s>"],
+)
+
+mistral_v3_tekken_formatter = MessagesFormatter(
+    "",
+    mistral_v3_tekken_markers,
+    False,
+    ["</s>"],
+)
+
+
 predefined_formatter = {
     MessagesFormatterType.MISTRAL: mixtral_formatter,
     MessagesFormatterType.CHATML: chatml_formatter,
@@ -398,7 +451,10 @@ def _format_response(
     MessagesFormatterType.OPEN_INTERPRETER: open_interpreter_chat_formatter,
     MessagesFormatterType.AUTOCODER: autocoder_chat_formatter,
     MessagesFormatterType.GEMMA_2: gemma_2_chat_formatter,
-    MessagesFormatterType.DEEP_SEEK_CODER_2: deep_seek_coder_2_chat_formatter
+    MessagesFormatterType.DEEP_SEEK_CODER_2: deep_seek_coder_2_chat_formatter,
+    MessagesFormatterType.MISTRAL_V1: mistral_v1_formatter,
+    MessagesFormatterType.MISTRAL_V2: mistral_v2_formatter,
+    MessagesFormatterType.MISTRAL_V3_TEKKEN: mistral_v3_tekken_formatter
 }
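
For context, a minimal usage sketch of the new formatter types (not part of the patch). It assumes the library's existing high-level entry points (`LlamaCppAgent`, `LlamaCppPythonProvider`, and the `predefined_messages_formatter_type` parameter) as shown in its current usage examples; the GGUF path and sampler settings are placeholders.

```python
# Sketch only: selecting one of the new Mistral formatter types when building an agent.
# Assumes llama-cpp-python is installed and a local Mistral GGUF file is available;
# the model path and context settings below are placeholders.
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider

llama_model = Llama("path/to/mistral-model.gguf", n_ctx=4096, n_gpu_layers=-1)
provider = LlamaCppPythonProvider(llama_model)

agent = LlamaCppAgent(
    provider,
    system_prompt="You are a helpful assistant.",
    # New in this patch: MISTRAL_V1, MISTRAL_V2, MISTRAL_V3_TEKKEN
    predefined_messages_formatter_type=MessagesFormatterType.MISTRAL_V3_TEKKEN,
)

print(agent.get_chat_response("Write a haiku about autumn."))
```

The three marker sets differ only in how whitespace is placed around `[INST]`/`[/INST]` (leading space in V1, trailing space in V2, none in V3-Tekken), mirroring the spacing differences between Mistral's V1, V2, and V3-Tekken chat templates.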