From 3ae76f74ded6ef2d71279c29c1eac4fa57f717ef Mon Sep 17 00:00:00 2001 From: laviphon Date: Thu, 27 Feb 2025 18:02:06 -0500 Subject: [PATCH 1/4] this should allow custom endpoints to use functions right out of the box --- assistant/commands/admin.py | 12 +++++++++++ assistant/common/calls.py | 40 +++++++++++++++++++++---------------- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/assistant/commands/admin.py b/assistant/commands/admin.py index 2351f26a..88632bb6 100644 --- a/assistant/commands/admin.py +++ b/assistant/commands/admin.py @@ -2131,3 +2131,15 @@ async def toggle_bot_listen(self, ctx: commands.Context): self.db.listen_to_bots = True await ctx.send(_("Assistant will listen to other bot messages")) await self.save_conf() + + @assistant.command(name="toolformat") + async def toggle_tool_formatting(self, ctx: commands.Context, true_or_false: bool): + + match true_or_false: + case True: + self.db.tool_format = True + await ctx.send("Assistant will now send functions via tools.") + case False: + self.db.tool_format = False + await ctx.send("Assistant will now send functions via functions.") + await self.save_conf() \ No newline at end of file diff --git a/assistant/common/calls.py b/assistant/common/calls.py index 05fb3f9f..3a1d181b 100644 --- a/assistant/common/calls.py +++ b/assistant/common/calls.py @@ -9,6 +9,8 @@ from openai.types.chat import ChatCompletion from pydantic import BaseModel from sentry_sdk import add_breadcrumb +from ..common.models import DB + from tenacity import ( retry, retry_if_exception_type, @@ -45,6 +47,7 @@ async def request_chat_completion_raw( seed: int = None, base_url: Optional[str] = None, reasoning_effort: Optional[str] = None, + db: Optional[any] = None, ) -> ChatCompletion: client = openai.AsyncOpenAI(api_key=api_key, base_url=base_url) @@ -66,27 +69,30 @@ async def request_chat_completion_raw( if seed and model in SUPPORTS_SEED: kwargs["seed"] = seed - if functions and model not in NO_DEVELOPER_ROLE: - if model in SUPPORTS_TOOLS: - tools = [] - for func in functions: - function = {"type": "function", "function": func, "name": func["name"]} - tools.append(function) - if tools: - kwargs["tools"] = tools - # If passing tools, make sure the messages payload has no "function_call" key + if functions: + if model not in NO_DEVELOPER_ROLE or db.endpoint_override: + + if model in SUPPORTS_TOOLS or db.tool_format: + tools = [] + for func in functions: + function = {"type": "function", "function": func, "name": func["name"]} + tools.append(function) + if tools: + kwargs["tools"] = tools + # If passing tools, make sure the messages payload has no "function_call" key + for idx, message in enumerate(messages): + if "function_call" in message: + # Remove the message from the payload + del kwargs["messages"][idx] + + else: + kwargs["functions"] = functions + # If passing functions, make sure the messages payload has no tool calls for idx, message in enumerate(messages): - if "function_call" in message: + if "tool_calls" in message: # Remove the message from the payload del kwargs["messages"][idx] - else: - kwargs["functions"] = functions - # If passing functions, make sure the messages payload has no tool calls - for idx, message in enumerate(messages): - if "tool_calls" in message: - # Remove the message from the payload - del kwargs["messages"][idx] add_breadcrumb( category="api", From 507f9f156bb818af0dd9f80b3c7e96037ed50fa6 Mon Sep 17 00:00:00 2001 From: laviphon Date: Thu, 27 Feb 2025 18:31:47 -0500 Subject: [PATCH 2/4] fixes tweaks and polish --- assistant/commands/admin.py | 6 ++++++ assistant/common/api.py | 1 + assistant/common/models.py | 1 + 3 files changed, 8 insertions(+) diff --git a/assistant/commands/admin.py b/assistant/commands/admin.py index 88632bb6..a584dd5c 100644 --- a/assistant/commands/admin.py +++ b/assistant/commands/admin.py @@ -100,6 +100,7 @@ async def view_settings(self, ctx: commands.Context, private: bool = False): + _("`System Prompt: `{} tokens\n").format(humanize_number(system_tokens)) + _("`User Prompt: `{} tokens\n").format(humanize_number(prompt_tokens)) + _("`Endpoint Override: `{}\n").format(self.db.endpoint_override) + + _("'Tool Output Format: `{}\n").format(self.db.tool_format) ) embed = discord.Embed( @@ -2133,8 +2134,13 @@ async def toggle_bot_listen(self, ctx: commands.Context): await self.save_conf() @assistant.command(name="toolformat") + @commands.is_owner() async def toggle_tool_formatting(self, ctx: commands.Context, true_or_false: bool): + """ + Assistant will submit enabled functions to your endpoint as tools instead of functions. + Useful for troubleshooting function calling with a custom endpoint. + """ match true_or_false: case True: self.db.tool_format = True diff --git a/assistant/common/api.py b/assistant/common/api.py index 898ce51a..f4cb5a17 100644 --- a/assistant/common/api.py +++ b/assistant/common/api.py @@ -96,6 +96,7 @@ async def request_response( presence_penalty=conf.presence_penalty, seed=conf.seed, base_url=self.db.endpoint_override, + db=self.db, ) message: ChatCompletionMessage = response.choices[0].message diff --git a/assistant/common/models.py b/assistant/common/models.py index cba20619..90308739 100644 --- a/assistant/common/models.py +++ b/assistant/common/models.py @@ -345,6 +345,7 @@ class DB(AssistantBaseModel): listen_to_bots: bool = False brave_api_key: t.Optional[str] = None endpoint_override: t.Optional[str] = None + tool_format: t.Optional[bool] = False def get_conf(self, guild: t.Union[discord.Guild, int]) -> GuildSettings: gid = guild if isinstance(guild, int) else guild.id From e3fe22c70ae90b302865b05a82739c2c9db447fd Mon Sep 17 00:00:00 2001 From: laviphon Date: Thu, 27 Feb 2025 19:25:04 -0500 Subject: [PATCH 3/4] i may be dyslexic --- assistant/common/calls.py | 40 +++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/assistant/common/calls.py b/assistant/common/calls.py index 3a1d181b..8a2436b8 100644 --- a/assistant/common/calls.py +++ b/assistant/common/calls.py @@ -69,30 +69,30 @@ async def request_chat_completion_raw( if seed and model in SUPPORTS_SEED: kwargs["seed"] = seed - if functions: - if model not in NO_DEVELOPER_ROLE or db.endpoint_override: - - if model in SUPPORTS_TOOLS or db.tool_format: - tools = [] - for func in functions: - function = {"type": "function", "function": func, "name": func["name"]} - tools.append(function) - if tools: - kwargs["tools"] = tools - # If passing tools, make sure the messages payload has no "function_call" key - for idx, message in enumerate(messages): - if "function_call" in message: - # Remove the message from the payload - del kwargs["messages"][idx] - - else: - kwargs["functions"] = functions - # If passing functions, make sure the messages payload has no tool calls + if functions: + if model not in NO_DEVELOPER_ROLE or db.endpoint_override: + + if model in SUPPORTS_TOOLS or db.tool_format: + tools = [] + for func in functions: + function = {"type": "function", "function": func, "name": func["name"]} + tools.append(function) + if tools: + kwargs["tools"] = tools + # If passing tools, make sure the messages payload has no "function_call" key for idx, message in enumerate(messages): - if "tool_calls" in message: + if "function_call" in message: # Remove the message from the payload del kwargs["messages"][idx] + else: + kwargs["functions"] = functions + # If passing functions, make sure the messages payload has no tool calls + for idx, message in enumerate(messages): + if "tool_calls" in message: + # Remove the message from the payload + del kwargs["messages"][idx] + add_breadcrumb( category="api", From 3fbdfc3bce451c199e4c4de326f2be010cb87f5e Mon Sep 17 00:00:00 2001 From: laviphon Date: Thu, 27 Feb 2025 20:10:05 -0500 Subject: [PATCH 4/4] fix view embed error --- assistant/commands/admin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assistant/commands/admin.py b/assistant/commands/admin.py index a584dd5c..b09c0657 100644 --- a/assistant/commands/admin.py +++ b/assistant/commands/admin.py @@ -100,7 +100,6 @@ async def view_settings(self, ctx: commands.Context, private: bool = False): + _("`System Prompt: `{} tokens\n").format(humanize_number(system_tokens)) + _("`User Prompt: `{} tokens\n").format(humanize_number(prompt_tokens)) + _("`Endpoint Override: `{}\n").format(self.db.endpoint_override) - + _("'Tool Output Format: `{}\n").format(self.db.tool_format) ) embed = discord.Embed( @@ -170,6 +169,7 @@ async def view_settings(self, ctx: commands.Context, private: bool = False): custom_func_field = ( _("`Function Calling: `{}\n").format(conf.use_function_calls) + + _("`Tool Output Format: `{}\n").format(self.db.tool_format) + _("`Maximum Recursion: `{}\n").format(conf.max_function_calls) + _("`Function Tokens: `{}\n").format(humanize_number(func_tokens)) )