Skip to content

Commit cc099a4

Browse files
committed
Add tests.
1 parent 8b86cc1 commit cc099a4

File tree

4 files changed

+108
-1
lines changed

4 files changed

+108
-1
lines changed

ads/llm/autogen/client_v02.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _convert_to_langchain_tool(tool):
129129
"""Converts the OpenAI tool spec to LangChain tool spec."""
130130
if tool["type"] == "function":
131131
tool = tool["function"]
132-
required = tool["parameters"]["required"]
132+
required = tool["parameters"].get("required", [])
133133
properties = copy.deepcopy(tool["parameters"]["properties"])
134134
for key in properties.keys():
135135
val = properties[key]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ aqua = ["jupyter_server"]
212212
# Revisit this section continuously and update to recent version of libraries. focus on pyt3.9/3.10 versions.
213213
testsuite = [
214214
"arff",
215+
"autogen-agentchat~=0.2,
215216
"category_encoders==2.6.3", # set version to avoid backtracking
216217
"cohere==4.53", # set version to avoid backtracking
217218
"faiss-cpu",
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding: utf-8
2+
# Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved.
3+
# This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# coding: utf-8
2+
# Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved.
3+
# This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.
4+
import sys
5+
from unittest import TestCase, mock
6+
7+
8+
if sys.version_info < (3, 9):
9+
pytest.skip(allow_module_level=True)
10+
11+
import autogen
12+
from langchain_core.messages import AIMessage, ToolCall
13+
from ads.llm.autogen.client_v02 import (
14+
LangChainModelClient,
15+
register_custom_client,
16+
custom_clients,
17+
)
18+
from ads.llm import ChatOCIModelDeploymentVLLM
19+
20+
21+
ODSC_LLM_CONFIG = {
22+
"model_client_cls": "LangChainModelClient",
23+
"langchain_cls": "ads.llm.ChatOCIModelDeploymentVLLM",
24+
"model": "Mistral",
25+
"client_params": {
26+
"model": "odsc-llm",
27+
"endpoint": "<ODSC_ENDPOINT>",
28+
"model_kwargs": {"temperature": 0, "max_tokens": 500},
29+
},
30+
}
31+
32+
TEST_PAYLOAD = {
33+
"messages": ["hello", "hi"],
34+
"tool": {
35+
"type": "function",
36+
"function": {
37+
"name": "my_tool",
38+
"description": "my_desc",
39+
"parameters": {
40+
"type": "object",
41+
"properties": {
42+
"order_id": {
43+
"type": "string",
44+
"description": "The customer's order ID.",
45+
}
46+
},
47+
"required": ["order_id"],
48+
},
49+
},
50+
},
51+
}
52+
53+
MOCKED_RESPONSE_CONTENT = "hello"
54+
MOCKED_AI_MESSAGE = AIMessage(
55+
content=MOCKED_RESPONSE_CONTENT,
56+
tool_calls=[ToolCall(name="my_tool", args={"arg": "val"}, id="a")],
57+
)
58+
MOCKED_TOOL_CALL = [
59+
{
60+
"id": "a",
61+
"function": {
62+
"name": "my_tool",
63+
"arguments": '{"arg": "val"}',
64+
},
65+
"type": "function",
66+
}
67+
]
68+
69+
70+
class AutoGenTestCase(TestCase):
71+
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
72+
def test_register_client(self, signer):
73+
# There should be no custom client before registration.
74+
self.assertEqual(custom_clients, {})
75+
register_custom_client(LangChainModelClient)
76+
self.assertEqual(custom_clients, {"LangChainModelClient": LangChainModelClient})
77+
# Test LLM config without custom LLM
78+
config_list = [
79+
{
80+
"model": "llama-7B",
81+
"api_key": "123",
82+
}
83+
]
84+
wrapper = autogen.oai.client.OpenAIWrapper(config_list=config_list)
85+
self.assertEqual(type(wrapper._clients[0]), autogen.oai.client.OpenAIClient)
86+
# Test LLM config with custom LLM
87+
config_list = [ODSC_LLM_CONFIG]
88+
wrapper = autogen.oai.client.OpenAIWrapper(config_list=config_list)
89+
self.assertEqual(type(wrapper._clients[0]), LangChainModelClient)
90+
91+
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
92+
@mock.patch(
93+
"ads.llm.ChatOCIModelDeploymentVLLM.invoke", return_value=MOCKED_AI_MESSAGE
94+
)
95+
def test_create_completion(self, mocked_invoke, *args):
96+
client = LangChainModelClient(config=ODSC_LLM_CONFIG)
97+
self.assertEqual(client.model_name, "Mistral")
98+
self.assertEqual(type(client.model), ChatOCIModelDeploymentVLLM)
99+
self.assertEqual(client.model._invocation_params(stop=None)["max_tokens"], 500)
100+
response = client.create(TEST_PAYLOAD)
101+
message = response.choices[0].message
102+
self.assertEqual(message.content, MOCKED_RESPONSE_CONTENT)
103+
self.assertEqual(message.tool_calls, MOCKED_TOOL_CALL)

0 commit comments

Comments
 (0)