Skip to content

Python: Add chat completion agent code interpreter sample #12393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,43 +1,11 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import datetime

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import DefaultAzureCredential

from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion
from semantic_kernel.core_plugins.sessions_python_tool.sessions_python_plugin import SessionsPythonTool
from semantic_kernel.exceptions.function_exceptions import FunctionExecutionException
from semantic_kernel.kernel import Kernel

auth_token: AccessToken | None = None

ACA_TOKEN_ENDPOINT: str = "https://acasessions.io/.default" # nosec


async def auth_callback() -> str:
"""Auth callback for the SessionsPythonTool.
This is a sample auth callback that shows how to use Azure's DefaultAzureCredential
to get an access token.
"""
global auth_token
current_utc_timestamp = int(datetime.datetime.now(datetime.timezone.utc).timestamp())

if not auth_token or auth_token.expires_on < current_utc_timestamp:
credential = DefaultAzureCredential()

try:
auth_token = credential.get_token(ACA_TOKEN_ENDPOINT)
except ClientAuthenticationError as cae:
err_messages = getattr(cae, "messages", [])
raise FunctionExecutionException(
f"Failed to retrieve the client auth token with messages: {' '.join(err_messages)}"
) from cae

return auth_token.token


async def main():
kernel = Kernel()
Expand All @@ -48,7 +16,7 @@ async def main():
)
kernel.add_service(chat_service)

python_code_interpreter = SessionsPythonTool(auth_callback=auth_callback)
python_code_interpreter = SessionsPythonTool()

sessions_tool = kernel.add_plugin(python_code_interpreter, "PythonCodeInterpreter")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import os

from semantic_kernel.agents import ChatCompletionAgent
from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
from semantic_kernel.contents import ChatMessageContent, FunctionCallContent, FunctionResultContent
from semantic_kernel.core_plugins import SessionsPythonTool

"""
The following sample demonstrates how to create a chat completion agent with
code interpreter capabilities using the Azure Container Apps session pool service.
"""


async def handle_intermediate_steps(message: ChatMessageContent) -> None:
for item in message.items or []:
if isinstance(item, FunctionResultContent):
print(f"# Function Result:> {item.result}")
elif isinstance(item, FunctionCallContent):
print(f"# Function Call:> {item.name} with arguments: {item.arguments}")
else:
print(f"# {message.name}: {message} ")


async def main():
# 1. Create the python code interpreter tool using the SessionsPythonTool
python_code_interpreter = SessionsPythonTool()

# 2. Create the agent
agent = ChatCompletionAgent(
service=AzureChatCompletion(),
name="Host",
instructions="Answer questions about the menu.",
plugins=[python_code_interpreter],
)

# 3. Upload a CSV file to the session
csv_file_path = os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "resources", "sales.csv")
file_metadata = await python_code_interpreter.upload_file(local_file_path=csv_file_path)

# 4. Invoke the agent for a response to a task
TASK = (
"What's the total sum of all sales for all segments using Python? "
f"Use the uploaded file {file_metadata.full_path} for reference."
)
print(f"# User: '{TASK}'")
async for response in agent.invoke(
messages=TASK,
on_intermediate_message=handle_intermediate_steps,
):
print(f"# {response.name}: {response} ")

"""
Sample output:
# User: 'What's the total sum of all sales for all segments using Python?
Use the uploaded file /mnt/data/sales.csv for reference.'
# Function Call:> SessionsPythonTool-execute_code with arguments: {
"code": "
import pandas as pd

# Load the sales data
file_path = '/mnt/data/sales.csv'
sales_data = pd.read_csv(file_path)

# Calculate the total sum of sales
# Assuming there's a column named 'Sales' which contains the sales amounts
total_sales = sales_data['Sales'].sum()
total_sales"
}
# Function Result:> Status:
Success
Result:
118726350.28999999
Stdout:

Stderr:
# Host: The total sum of all sales for all segments is approximately $118,726,350.29.
"""


if __name__ == "__main__":
asyncio.run(main())
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,10 @@ async def upload_file(
files = {"file": (remote_file_path, data, "application/octet-stream")}
response = await self.http_client.post(url=url, files=files)
response.raise_for_status()
response_json = response.json()
return SessionsRemoteFileMetadata.from_dict(response_json["value"][0]["properties"])
uploaded_files = await self.list_files()
return next(
file_metadata for file_metadata in uploaded_files if file_metadata.full_path == remote_file_path
)
except HTTPStatusError as e:
error_message = e.response.text if e.response.text else e.response.reason_phrase
raise FunctionExecutionException(
Expand Down
64 changes: 51 additions & 13 deletions python/tests/unit/core_plugins/test_sessions_python_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,9 @@ async def test_empty_call_to_container_fails_raises_exception(aca_python_session
await plugin.execute_code(code="")


@patch("httpx.AsyncClient.get")
@patch("httpx.AsyncClient.post")
async def test_upload_file_with_local_path(mock_post, aca_python_sessions_unit_test_env):
async def test_upload_file_with_local_path(mock_post, mock_get, aca_python_sessions_unit_test_env):
"""Test upload_file when providing a local file path."""

async def async_return(result):
Expand All @@ -196,8 +197,18 @@ async def async_return(result):
patch("builtins.open", mock_open(read_data=b"file data")),
):
mock_request = httpx.Request(method="POST", url="https://example.com/files/upload?identifier=None")

mock_response = httpx.Response(
status_code=200,
json={
"$id": "1",
"value": [],
},
request=mock_request,
)
mock_post.return_value = await async_return(mock_response)

mock_get_request = httpx.Request(method="GET", url="https://example.com/files?identifier=None")
mock_get_response = httpx.Response(
status_code=200,
json={
"$id": "1",
Expand All @@ -213,9 +224,9 @@ async def async_return(result):
},
],
},
request=mock_request,
request=mock_get_request,
)
mock_post.return_value = await async_return(mock_response)
mock_get.return_value = await async_return(mock_get_response)

plugin = SessionsPythonTool(
auth_callback=lambda: "sample_token",
Expand All @@ -229,8 +240,9 @@ async def async_return(result):
mock_post.assert_awaited_once()


@patch("httpx.AsyncClient.get")
@patch("httpx.AsyncClient.post")
async def test_upload_file_with_local_path_and_no_remote(mock_post, aca_python_sessions_unit_test_env):
async def test_upload_file_with_local_path_and_no_remote(mock_post, mock_get, aca_python_sessions_unit_test_env):
"""Test upload_file when providing a local file path."""

async def async_return(result):
Expand All @@ -243,9 +255,19 @@ async def async_return(result):
),
patch("builtins.open", mock_open(read_data=b"file data")),
):
mock_request = httpx.Request(method="POST", url="https://example.com/files/upload?identifier=None")
mock_post_request = httpx.Request(method="POST", url="https://example.com/files/upload?identifier=None")
mock_post_response = httpx.Response(
status_code=200,
json={
"$id": "1",
"value": [],
},
request=mock_post_request,
)
mock_post.return_value = await async_return(mock_post_response)

mock_response = httpx.Response(
mock_get_request = httpx.Request(method="GET", url="https://example.com/files?identifier=None")
mock_get_response = httpx.Response(
status_code=200,
json={
"$id": "1",
Expand All @@ -261,9 +283,9 @@ async def async_return(result):
},
],
},
request=mock_request,
request=mock_get_request,
)
mock_post.return_value = await async_return(mock_response)
mock_get.return_value = await async_return(mock_get_response)

plugin = SessionsPythonTool(
auth_callback=lambda: "sample_token",
Expand Down Expand Up @@ -313,9 +335,15 @@ async def async_raise_http_error(*args, **kwargs):
("./file.py", "/mnt/data/input.py", "/mnt/data/input.py"),
],
)
@patch("httpx.AsyncClient.get")
@patch("httpx.AsyncClient.post")
async def test_upload_file_with_buffer(
mock_post, local_file_path, input_remote_file_path, expected_remote_file_path, aca_python_sessions_unit_test_env
mock_post,
mock_get,
local_file_path,
input_remote_file_path,
expected_remote_file_path,
aca_python_sessions_unit_test_env,
):
"""Test upload_file when providing file data as a BufferedReader."""

Expand All @@ -330,8 +358,18 @@ async def async_return(result):
patch("builtins.open", mock_open(read_data="print('hello, world~')")),
):
mock_request = httpx.Request(method="POST", url="https://example.com/files/upload?identifier=None")

mock_response = httpx.Response(
status_code=200,
json={
"$id": "1",
"value": [],
},
request=mock_request,
)
mock_post.return_value = await async_return(mock_response)

mock_get_request = httpx.Request(method="GET", url="https://example.com/files?identifier=None")
mock_get_response = httpx.Response(
status_code=200,
json={
"$id": "1",
Expand All @@ -347,9 +385,9 @@ async def async_return(result):
},
],
},
request=mock_request,
request=mock_get_request,
)
mock_post.return_value = await async_return(mock_response)
mock_get.return_value = await async_return(mock_get_response)

plugin = SessionsPythonTool(auth_callback=lambda: "sample_token")

Expand Down
Loading