
Commit 973c51c

Add LangChain client for AutoGen.
1 parent 4607e07 commit 973c51c

File tree

2 files changed: +96 -0 lines changed


ads/llm/autogen/__init__.py

Whitespace-only changes.

ads/llm/autogen/client_v02.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# coding: utf-8
# Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved.
# This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.

"""This module contains the LangChain LLM client for AutoGen.

References:
https://microsoft.github.io/autogen/0.2/docs/notebooks/agentchat_huggingface_langchain/
https://github.com/microsoft/autogen/blob/0.2/notebook/agentchat_custom_model.ipynb
"""
import copy
import importlib
import logging
from typing import Dict, List, Union
from types import SimpleNamespace

from autogen import ModelClient
from langchain_core.messages import AIMessage

logger = logging.getLogger(__name__)


class Message(AIMessage):
    """Represents a message returned from the LLM."""

    @classmethod
    def from_message(cls, message: AIMessage):
        """Converts from a LangChain AIMessage."""
        message = copy.deepcopy(message)
        message.__class__ = cls
        return message

    @property
    def function_call(self):
        """Function calls."""
        return self.tool_calls


class LangChainModelClient(ModelClient):
    """Represents a model client wrapping a LangChain chat model."""

    def __init__(self, config: dict, **kwargs) -> None:
        super().__init__()
        logger.info("LangChain model client config: %s", str(config))
        self.client_class = config.pop("model_client_cls")
        # Parameters for the model
        self.model_name = config.get("model")
        # Import the LangChain class
        if "langchain_cls" not in config:
            raise ValueError("Missing langchain_cls in LangChain Model Client config.")
        module_cls = config.pop("langchain_cls")
        module_name, cls_name = str(module_cls).rsplit(".", 1)
        langchain_module = importlib.import_module(module_name)
        langchain_cls = getattr(langchain_module, cls_name)
        # Initialize the LangChain client
        self.model = langchain_cls(**config)

    def create(self, params) -> ModelClient.ModelClientResponseProtocol:
        streaming = params.get("stream", False)
        num_of_responses = params.get("n", 1)
        messages = params.get("messages", [])

        response = SimpleNamespace()
        response.choices = []
        response.model = self.model_name

        if streaming and messages:
            # If streaming is enabled and there are messages, iterate over the chunks of the response.
            raise NotImplementedError()
        else:
            # If streaming is not enabled, send a regular chat completion request.
            ai_message = self.model.invoke(messages)
            choice = SimpleNamespace()
            choice.message = Message.from_message(ai_message)
            response.choices.append(choice)
        return response

    def message_retrieval(
        self, response: ModelClient.ModelClientResponseProtocol
    ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of
        OpenAI's ChatCompletion Message object, since that is expected for function or tool calling
        in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message.content for choice in response.choices]

    def cost(self, response: ModelClient.ModelClientResponseProtocol) -> float:
        response.cost = 0
        return 0

    @staticmethod
    def get_usage(response: ModelClient.ModelClientResponseProtocol) -> Dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        return {}
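
For reference, a minimal usage sketch (not part of this commit) following the AutoGen 0.2 custom model client pattern linked in the module docstring. The ChatOllama class, the "llama3" model name, and the base_url kwarg are illustrative assumptions; any config key other than "model_client_cls" and "langchain_cls" is forwarded to the LangChain chat model's constructor.

import autogen
from ads.llm.autogen.client_v02 import LangChainModelClient

# Illustrative config: "langchain_cls" names the LangChain chat model class to import;
# the remaining keys ("model", "base_url") are passed to that class's constructor.
config_list = [
    {
        "model": "llama3",  # assumed model name; also reported back as response.model
        "model_client_cls": "LangChainModelClient",
        "langchain_cls": "langchain_community.chat_models.ChatOllama",  # assumed backend
        "base_url": "http://localhost:11434",  # assumed local Ollama endpoint
    }
]

assistant = autogen.AssistantAgent(
    "assistant", llm_config={"config_list": config_list}
)
user_proxy = autogen.UserProxyAgent(
    "user_proxy",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=1,
    code_execution_config=False,
)

# Register the custom client class with the agent before starting the chat,
# as described in the referenced AutoGen custom model notebook.
assistant.register_model_client(model_client_cls=LangChainModelClient)

user_proxy.initiate_chat(assistant, message="Hello!")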
