9
9
"""
10
10
import copy
11
11
import importlib
12
+ import json
12
13
import logging
13
14
from typing import Dict , List , Union
14
15
from types import SimpleNamespace
15
16
16
-
17
17
from autogen import ModelClient
18
+ from autogen .oai .client import OpenAIWrapper
18
19
from langchain_core .messages import AIMessage
19
20
21
+
20
22
logger = logging .getLogger (__name__ )
21
23
22
24
25
def register_custom_client(client_class):
    """Registers custom client with AutoGen.

    Appends *client_class* to ``OpenAIWrapper.custom_clients`` so AutoGen can
    resolve it automatically; registering the same class twice is a no-op.
    """
    _MISSING = object()
    registry = getattr(OpenAIWrapper, "custom_clients", _MISSING)
    if registry is _MISSING:
        # Older AutoGen releases have no class-level registry to append to.
        raise AttributeError(
            "The AutoGen version you install does not support auto custom client registration."
        )
    if client_class not in registry:
        registry.append(client_class)
33
+
34
+
35
+ def _convert_to_langchain_tool (tool ):
36
+ if tool ["type" ] == "function" :
37
+ tool = tool ["function" ]
38
+ required = tool ["parameters" ]["required" ]
39
+ properties = copy .deepcopy (tool ["parameters" ]["properties" ])
40
+ for key in properties .keys ():
41
+ val = properties [key ]
42
+ val ["default" ] = key in required
43
+ return {
44
+ "title" : tool ["name" ],
45
+ "description" : tool ["description" ],
46
+ "properties" : properties ,
47
+ }
48
+ raise NotImplementedError (f"Type { tool ['type' ]} is not supported." )
49
+
50
+
51
+ def _convert_to_openai_tool_call (tool_call ):
52
+ return {
53
+ "id" : tool_call .get ("id" ),
54
+ "function" : {
55
+ "name" : tool_call .get ("name" ),
56
+ "arguments" : (
57
+ ""
58
+ if tool_call .get ("args" ) is None
59
+ else json .dumps (tool_call .get ("args" ))
60
+ ),
61
+ },
62
+ "type" : "function" ,
63
+ }
64
+
65
+
23
66
class Message (AIMessage ):
24
67
"""Represents message returned from the LLM."""
25
68
@@ -28,6 +71,9 @@ def from_message(cls, message: AIMessage):
28
71
"""Converts from LangChain AIMessage."""
29
72
message = copy .deepcopy (message )
30
73
message .__class__ = cls
74
+ message .tool_calls = [
75
+ _convert_to_openai_tool_call (tool ) for tool in message .tool_calls
76
+ ]
31
77
return message
32
78
33
79
@property
@@ -42,23 +88,55 @@ class LangChainModelClient(ModelClient):
42
88
def __init__(self, config: dict, **kwargs) -> None:
    """Build a LangChain-backed AutoGen model client from *config*.

    Expects ``model_client_cls`` and ``langchain_cls`` keys; every remaining
    key (e.g. ``model``) is forwarded to the LangChain chat-model constructor.
    """
    super().__init__()
    logger.info("LangChain model client config: %s", str(config))
    # Work on a private copy: the keys below are popped destructively.
    config = copy.deepcopy(config)

    self.client_class = config.pop("model_client_cls")
    self.function_call_params = config.pop("function_call_params", {})
    self.model_name = config.get("model")

    # Resolve the dotted path "package.module.ClassName" to a class object.
    if "langchain_cls" not in config:
        raise ValueError("Missing langchain_cls in LangChain Model Client config.")
    module_name, cls_name = str(config.pop("langchain_cls")).rsplit(".", 1)
    langchain_cls = getattr(importlib.import_module(module_name), cls_name)

    # Whatever is left in the config becomes constructor kwargs.
    self.model = langchain_cls(**config)
57
109
58
110
def create (self , params ) -> ModelClient .ModelClientResponseProtocol :
111
+ """Creates a LLM completion for a given config.
112
+
113
+ Parameters
114
+ ----------
115
+ params : dict
116
+ OpenAI API compatible parameters, including all the keys from llm_config.
117
+
118
+ Returns
119
+ -------
120
+ ModelClientResponseProtocol
121
+ Response from LLM
122
+
123
+ """
59
124
streaming = params .get ("stream" , False )
125
+ # TODO: num_of_responses
60
126
num_of_responses = params .get ("n" , 1 )
61
- messages = params .get ("messages" , [])
127
+ messages = params .pop ("messages" , [])
128
+
129
+ invoke_params = {}
130
+
131
+ tools = params .get ("tools" )
132
+ if tools :
133
+ model = self .model .bind_tools (
134
+ [_convert_to_langchain_tool (tool ) for tool in tools ]
135
+ )
136
+ # invoke_params["tools"] = tools
137
+ invoke_params .update (self .function_call_params )
138
+ else :
139
+ model = self .model
62
140
63
141
response = SimpleNamespace ()
64
142
response .choices = []
@@ -69,7 +147,7 @@ def create(self, params) -> ModelClient.ModelClientResponseProtocol:
69
147
raise NotImplementedError ()
70
148
else :
71
149
# If streaming is not enabled, send a regular chat completion request
72
- ai_message = self . model .invoke (messages )
150
+ ai_message = model .invoke (messages , ** invoke_params )
73
151
choice = SimpleNamespace ()
74
152
choice .message = Message .from_message (ai_message )
75
153
response .choices .append (choice )
@@ -84,7 +162,7 @@ def message_retrieval(
84
162
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
85
163
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
86
164
"""
87
- return [choice .message . content for choice in response .choices ]
165
+ return [choice .message for choice in response .choices ]
88
166
89
167
def cost (self , response : ModelClient .ModelClientResponseProtocol ) -> float :
90
168
response .cost = 0
0 commit comments