Skip to content

Commit ab7e734

Browse files
committed
Add tool calling support.
1 parent 973c51c commit ab7e734

File tree

1 file changed

+83
-5
lines changed

1 file changed

+83
-5
lines changed

ads/llm/autogen/client_v02.py

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,60 @@
99
"""
1010
import copy
1111
import importlib
12+
import json
1213
import logging
1314
from typing import Dict, List, Union
1415
from types import SimpleNamespace
1516

16-
1717
from autogen import ModelClient
18+
from autogen.oai.client import OpenAIWrapper
1819
from langchain_core.messages import AIMessage
1920

21+
2022
logger = logging.getLogger(__name__)
2123

2224

25+
def register_custom_client(client_class):
    """Registers custom client with AutoGen.

    Appends ``client_class`` to the ``OpenAIWrapper.custom_clients`` registry
    unless it is already present, so AutoGen can resolve it by name.

    Raises
    ------
    AttributeError
        If the installed AutoGen version exposes no ``custom_clients``
        registry (auto-registration unsupported).
    """
    if not hasattr(OpenAIWrapper, "custom_clients"):
        raise AttributeError(
            "The AutoGen version you install does not support auto custom client registration."
        )
    registry = OpenAIWrapper.custom_clients
    # Idempotent: registering the same class twice keeps a single entry.
    if client_class not in registry:
        registry.append(client_class)
33+
34+
35+
def _convert_to_langchain_tool(tool):
36+
if tool["type"] == "function":
37+
tool = tool["function"]
38+
required = tool["parameters"]["required"]
39+
properties = copy.deepcopy(tool["parameters"]["properties"])
40+
for key in properties.keys():
41+
val = properties[key]
42+
val["default"] = key in required
43+
return {
44+
"title": tool["name"],
45+
"description": tool["description"],
46+
"properties": properties,
47+
}
48+
raise NotImplementedError(f"Type {tool['type']} is not supported.")
49+
50+
51+
def _convert_to_openai_tool_call(tool_call):
52+
return {
53+
"id": tool_call.get("id"),
54+
"function": {
55+
"name": tool_call.get("name"),
56+
"arguments": (
57+
""
58+
if tool_call.get("args") is None
59+
else json.dumps(tool_call.get("args"))
60+
),
61+
},
62+
"type": "function",
63+
}
64+
65+
2366
class Message(AIMessage):
2467
"""Represents message returned from the LLM."""
2568

@@ -28,6 +71,9 @@ def from_message(cls, message: AIMessage):
2871
"""Converts from LangChain AIMessage."""
2972
message = copy.deepcopy(message)
3073
message.__class__ = cls
74+
message.tool_calls = [
75+
_convert_to_openai_tool_call(tool) for tool in message.tool_calls
76+
]
3177
return message
3278

3379
@property
@@ -42,23 +88,55 @@ class LangChainModelClient(ModelClient):
4288
def __init__(self, config: dict, **kwargs) -> None:
    """Initializes the client from an AutoGen ``llm_config`` entry.

    Parameters
    ----------
    config : dict
        Client configuration. Must contain ``langchain_cls`` (fully
        qualified class name of the LangChain model). ``model_client_cls``
        is removed, and the optional ``function_call_params`` dict is kept
        aside to be forwarded when invoking with tools. Every remaining key
        is passed to the LangChain class constructor.

    Raises
    ------
    ValueError
        If ``langchain_cls`` is missing from the config.
    """
    super().__init__()
    logger.info("LangChain model client config: %s", str(config))
    # Work on a copy: keys are popped below and the caller's dict must
    # remain intact.
    config = copy.deepcopy(config)
    self.client_class = config.pop("model_client_cls")
    # Extra parameters forwarded to invoke() when tools are bound.
    self.function_call_params = config.pop("function_call_params", {})
    self.model_name = config.get("model")

    # Resolve the LangChain chat-model class from its dotted path.
    if "langchain_cls" not in config:
        raise ValueError("Missing langchain_cls in LangChain Model Client config.")
    module_name, cls_name = str(config.pop("langchain_cls")).rsplit(".", 1)
    langchain_cls = getattr(importlib.import_module(module_name), cls_name)

    # The remaining config keys become constructor arguments.
    self.model = langchain_cls(**config)
57109

58110
def create(self, params) -> ModelClient.ModelClientResponseProtocol:
111+
"""Creates a LLM completion for a given config.
112+
113+
Parameters
114+
----------
115+
params : dict
116+
OpenAI API compatible parameters, including all the keys from llm_config.
117+
118+
Returns
119+
-------
120+
ModelClientResponseProtocol
121+
Response from LLM
122+
123+
"""
59124
streaming = params.get("stream", False)
125+
# TODO: num_of_responses
60126
num_of_responses = params.get("n", 1)
61-
messages = params.get("messages", [])
127+
messages = params.pop("messages", [])
128+
129+
invoke_params = {}
130+
131+
tools = params.get("tools")
132+
if tools:
133+
model = self.model.bind_tools(
134+
[_convert_to_langchain_tool(tool) for tool in tools]
135+
)
136+
# invoke_params["tools"] = tools
137+
invoke_params.update(self.function_call_params)
138+
else:
139+
model = self.model
62140

63141
response = SimpleNamespace()
64142
response.choices = []
@@ -69,7 +147,7 @@ def create(self, params) -> ModelClient.ModelClientResponseProtocol:
69147
raise NotImplementedError()
70148
else:
71149
# If streaming is not enabled, send a regular chat completion request
72-
ai_message = self.model.invoke(messages)
150+
ai_message = model.invoke(messages, **invoke_params)
73151
choice = SimpleNamespace()
74152
choice.message = Message.from_message(ai_message)
75153
response.choices.append(choice)
@@ -84,7 +162,7 @@ def message_retrieval(
84162
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
85163
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
86164
"""
87-
return [choice.message.content for choice in response.choices]
165+
return [choice.message for choice in response.choices]
88166

89167
def cost(self, response: ModelClient.ModelClientResponseProtocol) -> float:
90168
response.cost = 0

0 commit comments

Comments
 (0)