Skip to content

Commit 4c0e37d

Browse files
committed
Patch AutoGen OpenAIWrapper
1 parent 9f2c0f4 commit 4c0e37d

File tree

1 file changed

+108
-13
lines changed

1 file changed

+108
-13
lines changed

ads/llm/autogen/client_v02.py

Lines changed: 108 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,131 @@
22
# Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved.
33
# This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.
44

5-
"""This module contains the LangChain LLM client for AutoGen
6-
# References:
7-
# https://microsoft.github.io/autogen/0.2/docs/notebooks/agentchat_huggingface_langchain/
8-
# https://github.com/microsoft/autogen/blob/0.2/notebook/agentchat_custom_model.ipynb
5+
"""This module contains the custom LLM client for AutoGen v0.2 to use LangChain chat models.
6+
https://microsoft.github.io/autogen/0.2/blog/2024/01/26/Custom-Models/
7+
8+
To use the custom client:
9+
1. Prepare the LLM config, including the parameters for initializing the LangChain client.
10+
2. Register the custom LLM
11+
12+
The LLM config should config the following keys:
13+
* model_client_cls: Required by AutoGen to identify the custom client. It should be "LangChainModelClient"
14+
* langchain_cls: LangChain class including the full import path.
15+
* model: Name of the model to be used by AutoGen
16+
* client_params: A dictionary containing the parameters to initialize the LangChain chat model.
17+
18+
Although the `LangChainModelClient` is designed to be generic and can potentially support any LangChain chat model,
19+
the invocation depends on the server API spec and it may not be compatible with some implementations.
20+
21+
Following is an example config for OCI Generative AI service:
22+
{
23+
"model_client_cls": "LangChainModelClient",
24+
"langchain_cls": "langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI",
25+
"model": "cohere.command-r-plus",
26+
# client_params will be used to initialize the LangChain ChatOCIGenAI class.
27+
"client_params": {
28+
"model_id": "cohere.command-r-plus",
29+
"compartment_id": COMPARTMENT_OCID,
30+
"model_kwargs": {"temperature": 0, "max_tokens": 2048},
31+
# Update the authentication method as needed
32+
"auth_type": "SECURITY_TOKEN",
33+
"auth_profile": "DEFAULT",
34+
# You may need to specify `service_endpoint` if the service is in a different region.
35+
},
36+
}
37+
38+
Following is an example config for OCI Data Science Model Deployment:
39+
{
40+
"model_client_cls": "LangChainModelClient",
41+
"langchain_cls": "ads.llm.ChatOCIModelDeploymentVLLM",
42+
"model": "odsc-llm",
43+
"endpoint": "https://MODEL_DEPLOYMENT_URL/predict",
44+
"model_kwargs": {"temperature": 0.1, "max_tokens": 2048},
45+
# function_call_params will only be added to the API call when function/tools are added.
46+
"function_call_params": {
47+
"tool_choice": "auto",
48+
"chat_template": ChatTemplates.mistral(),
49+
},
50+
}
51+
52+
Note that if `client_params` is not specified in the config, all arguments from the config except
53+
`model_client_cls` and `langchain_cls`, and `function_call_params`, will be used to initialize
54+
the LangChain chat model.
55+
56+
The `function_call_params` will only be used for function/tool calling when tools are specified.
57+
58+
To register the custom client:
59+
60+
from ads.llm.autogen.client_v02 import LangChainModelClient, register_custom_client
61+
register_custom_client(LangChainModelClient)
62+
63+
Once registered with ADS, the custom LLM class will be auto-registered for all new agents.
64+
There is no need to call `register_model_client()` on each agent.
65+
66+
References:
67+
https://microsoft.github.io/autogen/0.2/docs/notebooks/agentchat_huggingface_langchain/
68+
https://github.com/microsoft/autogen/blob/0.2/notebook/agentchat_custom_model.ipynb
69+
970
"""
1071
import copy
1172
import importlib
1273
import json
1374
import logging
14-
from typing import Dict, List, Union
75+
from typing import Any, Dict, List, Union
1576
from types import SimpleNamespace
1677

1778
from autogen import ModelClient
18-
from autogen.oai.client import OpenAIWrapper
79+
from autogen.oai.client import OpenAIWrapper, PlaceHolderClient
1980
from langchain_core.messages import AIMessage
2081

2182

2283
logger = logging.getLogger(__name__)
2384

85+
# custom_clients is a dictionary mapping the name of the class to the actual class
86+
custom_clients = {}
87+
88+
# There is a bug in GroupChat when using custom client:
89+
# https://github.com/microsoft/autogen/issues/2956
90+
# Here we will be patching the OpenAIWrapper to fix the issue.
91+
# With this patch, you only need to register the client once with ADS.
92+
# For example:
93+
#
94+
# from ads.llm.autogen.client_v02 import LangChainModelClient, register_custom_client
95+
# register_custom_client(LangChainModelClient)
96+
#
97+
# This patch will auto-register the custom LLM to all new agents.
98+
# So there is no need to call `register_model_client()` on each agent.
99+
OpenAIWrapper._original_register_default_client = OpenAIWrapper._register_default_client
100+
101+
102+
def _new_register_default_client(
103+
self: OpenAIWrapper, config: Dict[str, Any], openai_config: Dict[str, Any]
104+
) -> None:
105+
"""This is a patched version of the _register_default_client() method
106+
to automatically register custom client for agents.
107+
"""
108+
model_client_cls_name = config.get("model_client_cls")
109+
if model_client_cls_name in custom_clients:
110+
self._clients.append(PlaceHolderClient(config))
111+
self.register_model_client(custom_clients[model_client_cls_name])
112+
else:
113+
self._original_register_default_client(
114+
config=config, openai_config=openai_config
115+
)
116+
117+
118+
# Patch the _register_default_client() method
119+
OpenAIWrapper._register_default_client = _new_register_default_client
120+
24121

25122
def register_custom_client(client_class):
26-
"""Registers custom client with AutoGen."""
27-
if not hasattr(OpenAIWrapper, "custom_clients"):
28-
raise AttributeError(
29-
"The AutoGen version you install does not support auto custom client registration."
30-
)
31-
if client_class not in OpenAIWrapper.custom_clients:
32-
OpenAIWrapper.custom_clients.append(client_class)
123+
"""Registers custom client for AutoGen."""
124+
if client_class.__name__ not in custom_clients:
125+
custom_clients[client_class.__name__] = client_class
33126

34127

35128
def _convert_to_langchain_tool(tool):
129+
"""Converts the OpenAI tool spec to LangChain tool spec."""
36130
if tool["type"] == "function":
37131
tool = tool["function"]
38132
required = tool["parameters"]["required"]
@@ -49,6 +143,7 @@ def _convert_to_langchain_tool(tool):
49143

50144

51145
def _convert_to_openai_tool_call(tool_call):
146+
"""Converts the LangChain tool call in AI message to OpenAI tool call."""
52147
return {
53148
"id": tool_call.get("id"),
54149
"function": {

0 commit comments

Comments
 (0)