Skip to content

Python: Fix schema handling. Fix function result return for type list. #6370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions python/semantic_kernel/connectors/ai/open_ai/services/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,30 @@ def kernel_function_metadata_to_openai_tool_format(metadata: KernelFunctionMetad

def parse_schema(schema_data):
"""Recursively parse the schema data to include nested properties."""
if schema_data.get("type") == "object":
if schema_data is None:
return {"type": "string", "description": ""}

schema_type = schema_data.get("type")
schema_description = schema_data.get("description", "")

if schema_type == "object":
properties = {key: parse_schema(value) for key, value in schema_data.get("properties", {}).items()}
return {
"type": "object",
"properties": {key: parse_schema(value) for key, value in schema_data.get("properties", {}).items()},
"description": schema_data.get("description", ""),
}
else:
return {
"type": schema_data.get("type", "string"),
"description": schema_data.get("description", ""),
**({"enum": schema_data.get("enum")} if "enum" in schema_data else {}),
"properties": properties,
"description": schema_description,
}

if schema_type == "array":
items = schema_data.get("items", {"type": "string"})
return {"type": "array", "description": schema_description, "items": items}

schema_dict = {"type": schema_type, "description": schema_description}
if "enum" in schema_data:
schema_dict["enum"] = schema_data["enum"]

return schema_dict

return {
"type": "function",
"function": {
Expand Down
180 changes: 180 additions & 0 deletions python/tests/unit/services/test_service_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import Annotated

import pytest
from pydantic import Field

from semantic_kernel.connectors.ai.open_ai.services.utils import kernel_function_metadata_to_openai_tool_format
from semantic_kernel.functions.kernel_function_decorator import kernel_function
from semantic_kernel.kernel import Kernel
from semantic_kernel.kernel_pydantic import KernelBaseModel

# region Test helpers


class BooleanPlugin:
@kernel_function(name="GetBoolean", description="Get a boolean value.")
def get_boolean(self, value: Annotated[bool, "The boolean value."]) -> Annotated[bool, "The boolean value."]:
return value


class StringPlugin:
@kernel_function(name="GetWeather", description="Get the weather for a location.")
def get_weather(
self, location: Annotated[str, "The location to get the weather for."]
) -> Annotated[str, "The weather for the location."]:
return "The weather in {} is sunny.".format(location)


class ComplexRequest(KernelBaseModel):
start_date: str = Field(
...,
description="The start date in ISO 8601 format",
examples=["2024-01-23", "2020-06-15"],
)
end_date: str = Field(
...,
description="The end date in ISO-8601 format",
examples=["2024-01-23", "2020-06-15"],
)


class ComplexTypePlugin:
@kernel_function(name="answer_request", description="Answer a request")
def book_holiday(
self, request: Annotated[ComplexRequest, "A request to answer."]
) -> Annotated[bool, "The result is the boolean value True if successful, False if unsuccessful."]:
return True


class ListPlugin:
@kernel_function(name="get_items", description="Filters a list.")
def get_configuration(
self, items: Annotated[list[str], "The list of items."]
) -> Annotated[list[str], "The filtered list."]:
return [item for item in items if item in ["skip"]]


@pytest.fixture
def setup_kernel():
kernel = Kernel()
kernel.add_plugins(
{
"BooleanPlugin": BooleanPlugin(),
"StringPlugin": StringPlugin(),
"ComplexTypePlugin": ComplexTypePlugin(),
"ListPlugin": ListPlugin(),
}
)
return kernel


# endregion


def test_bool_schema(setup_kernel):
kernel = setup_kernel

boolean_func_metadata = kernel.get_list_of_function_metadata_filters(
filters={"included_plugins": ["BooleanPlugin"]}
)

boolean_schema = kernel_function_metadata_to_openai_tool_format(boolean_func_metadata[0])

expected_schema = {
"type": "function",
"function": {
"name": "BooleanPlugin-GetBoolean",
"description": "Get a boolean value.",
"parameters": {
"type": "object",
"properties": {"value": {"type": "boolean", "description": "The boolean value."}},
"required": ["value"],
},
},
}

assert boolean_schema == expected_schema


def test_string_schema(setup_kernel):
kernel = setup_kernel

string_func_metadata = kernel.get_list_of_function_metadata_filters(filters={"included_plugins": ["StringPlugin"]})

string_schema = kernel_function_metadata_to_openai_tool_format(string_func_metadata[0])

expected_schema = {
"type": "function",
"function": {
"name": "StringPlugin-GetWeather",
"description": "Get the weather for a location.",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string", "description": "The location to get the weather for."}},
"required": ["location"],
},
},
}

assert string_schema == expected_schema


def test_complex_schema(setup_kernel):
kernel = setup_kernel

complex_func_metadata = kernel.get_list_of_function_metadata_filters(
filters={"included_plugins": ["ComplexTypePlugin"]}
)

complex_schema = kernel_function_metadata_to_openai_tool_format(complex_func_metadata[0])

expected_schema = {
"type": "function",
"function": {
"name": "ComplexTypePlugin-answer_request",
"description": "Answer a request",
"parameters": {
"type": "object",
"properties": {
"request": {
"type": "object",
"properties": {
"start_date": {"type": "string", "description": "The start date in ISO 8601 format"},
"end_date": {"type": "string", "description": "The end date in ISO-8601 format"},
},
"description": "A request to answer.",
}
},
"required": ["request"],
},
},
}

assert complex_schema == expected_schema


def test_list_schema(setup_kernel):
kernel = setup_kernel

complex_func_metadata = kernel.get_list_of_function_metadata_filters(filters={"included_plugins": ["ListPlugin"]})

complex_schema = kernel_function_metadata_to_openai_tool_format(complex_func_metadata[0])

expected_schema = {
"type": "function",
"function": {
"name": "ListPlugin-get_items",
"description": "Filters a list.",
"parameters": {
"type": "object",
"properties": {
"items": {"type": "array", "description": "The list of items.", "items": {"type": "string"}}
},
"required": ["items"],
},
},
}

assert complex_schema == expected_schema
Loading