Skip to content

Python: Support | and |= operators for KernelArgument #12499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions python/semantic_kernel/functions/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from semantic_kernel.const import DEFAULT_SERVICE_NAME

if TYPE_CHECKING:
from collections.abc import Iterable

from _typeshed import SupportsKeysAndGetItem

from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings


Expand Down Expand Up @@ -49,3 +53,53 @@ def __bool__(self) -> bool:
has_arguments = self.__len__() > 0
has_execution_settings = self.execution_settings is not None and len(self.execution_settings) > 0
return has_arguments or has_execution_settings

def __or__(self, value: dict) -> "KernelArguments":
"""Merges a KernelArguments with another KernelArguments or dict.

This implements the `|` operator for KernelArguments.
"""
if not isinstance(value, dict):
raise TypeError(
f"TypeError: unsupported operand type(s) for |: '{type(self).__name__}' and '{type(value).__name__}'"
)

# Merge execution settings
new_execution_settings = (self.execution_settings or {}).copy()
if isinstance(value, KernelArguments) and value.execution_settings:
new_execution_settings |= value.execution_settings
# Create a new KernelArguments with merged dict values
return KernelArguments(settings=new_execution_settings, **(dict(self) | dict(value)))

def __ror__(self, value: dict) -> "KernelArguments":
"""Merges a dict with a KernelArguments.

This implements the right-side `|` operator for KernelArguments.
"""
if not isinstance(value, dict):
raise TypeError(
f"TypeError: unsupported operand type(s) for |: '{type(value).__name__}' and '{type(self).__name__}'"
)

# Merge execution settings
new_execution_settings = {}
if isinstance(value, KernelArguments) and value.execution_settings:
new_execution_settings = value.execution_settings.copy()
if self.execution_settings:
new_execution_settings |= self.execution_settings

# Create a new KernelArguments with merged dict values
return KernelArguments(settings=new_execution_settings, **(dict(value) | dict(self)))

def __ior__(self, value: "SupportsKeysAndGetItem[Any, Any] | Iterable[tuple[Any, Any]]") -> "KernelArguments":
"""Merges into this KernelArguments with another KernelArguments or dict (in-place)."""
self.update(value)

# In-place merge execution settings
if isinstance(value, KernelArguments) and value.execution_settings:
if self.execution_settings:
self.execution_settings.update(value.execution_settings)
else:
self.execution_settings = value.execution_settings.copy()

return self
133 changes: 133 additions & 0 deletions python/tests/unit/functions/test_kernel_arguments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import pytest

from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.functions.kernel_arguments import KernelArguments

Expand Down Expand Up @@ -46,3 +48,134 @@ def test_kernel_arguments_bool():
assert KernelArguments(settings=PromptExecutionSettings(service_id="test"))
# An KernelArguments object with both keyword arguments and execution_settings should return True
assert KernelArguments(input=10, settings=PromptExecutionSettings(service_id="test"))


@pytest.mark.parametrize(
"lhs, rhs, expected_dict, expected_settings_keys",
[
# Merging different keys
(KernelArguments(a=1), KernelArguments(b=2), {"a": 1, "b": 2}, None),
# RHS overwrites when keys duplicate
(KernelArguments(a=1), KernelArguments(a=99), {"a": 99}, None),
# Merging with a plain dict
(KernelArguments(a=1), {"b": 2}, {"a": 1, "b": 2}, None),
# Merging execution_settings together
(
KernelArguments(settings=PromptExecutionSettings(service_id="s1")),
KernelArguments(settings=PromptExecutionSettings(service_id="s2")),
{},
["s1", "s2"],
),
# Same service_id is overwritten by RHS
(
KernelArguments(settings=PromptExecutionSettings(service_id="shared")),
KernelArguments(settings=PromptExecutionSettings(service_id="shared")),
{},
["shared"],
),
],
)
def test_kernel_arguments_or_operator(lhs, rhs, expected_dict, expected_settings_keys):
"""Test the __or__ operator (lhs | rhs) with various argument combinations."""
result = lhs | rhs
assert isinstance(result, KernelArguments)
assert dict(result) == expected_dict
if expected_settings_keys is None:
assert result.execution_settings is None
else:
assert sorted(result.execution_settings.keys()) == sorted(expected_settings_keys)


@pytest.mark.parametrize("rhs", [42, "foo", None])
def test_kernel_arguments_or_operator_with_invalid_type(rhs):
"""Test the __or__ operator with an invalid type raises TypeError."""
with pytest.raises(TypeError):
KernelArguments() | rhs


@pytest.mark.parametrize(
"lhs, rhs, expected_dict, expected_settings_keys",
[
# Dict merge (in-place)
(KernelArguments(a=1), {"b": 2}, {"a": 1, "b": 2}, None),
# Merging between KernelArguments
(KernelArguments(a=1), KernelArguments(b=2), {"a": 1, "b": 2}, None),
# Retain existing execution_settings after dict merge
(KernelArguments(a=1, settings=PromptExecutionSettings(service_id="s1")), {"b": 2}, {"a": 1, "b": 2}, ["s1"]),
# In-place merge of execution_settings
(
KernelArguments(settings=PromptExecutionSettings(service_id="s1")),
KernelArguments(settings=PromptExecutionSettings(service_id="s2")),
{},
["s1", "s2"],
),
],
)
def test_kernel_arguments_inplace_merge(lhs, rhs, expected_dict, expected_settings_keys):
"""Test the |= operator with various argument combinations without execution_settings."""
original_id = id(lhs)
lhs |= rhs
# Verify this is the same object (in-place)
assert id(lhs) == original_id
assert dict(lhs) == expected_dict
if expected_settings_keys is None:
assert lhs.execution_settings is None
else:
assert sorted(lhs.execution_settings.keys()) == sorted(expected_settings_keys)


@pytest.mark.parametrize(
"rhs, lhs, expected_dict, expected_settings_keys",
[
# Merging different keys
({"b": 2}, KernelArguments(a=1), {"b": 2, "a": 1}, None),
# RHS overwrites when keys duplicate
({"a": 1}, KernelArguments(a=99), {"a": 99}, None),
# Merging with a KernelArguments
({"b": 2}, KernelArguments(a=1), {"b": 2, "a": 1}, None),
# Merging execution_settings together
(
{"test": "value"},
KernelArguments(settings=PromptExecutionSettings(service_id="s2")),
{"test": "value"},
["s2"],
),
# Plain dict on the left with KernelArguments+settings on the right
(
{"a": 1},
KernelArguments(b=2, settings=PromptExecutionSettings(service_id="shared")),
{"a": 1, "b": 2},
["shared"],
),
# KernelArguments on both sides with execution_settings
(
KernelArguments(a=1, settings=PromptExecutionSettings(service_id="s1")),
KernelArguments(b=2, settings=PromptExecutionSettings(service_id="s2")),
{"a": 1, "b": 2},
["s1", "s2"],
),
# Same service_id is overwritten by RHS (KernelArguments)
(
KernelArguments(a=1, settings=PromptExecutionSettings(service_id="shared")),
KernelArguments(b=2, settings=PromptExecutionSettings(service_id="shared")),
{"a": 1, "b": 2},
["shared"],
),
],
)
def test_kernel_arguments_ror_operator(rhs, lhs, expected_dict, expected_settings_keys):
"""Test the __ror__ operator (lhs | rhs) with various argument combinations."""
result = rhs | lhs
assert isinstance(result, KernelArguments)
assert dict(result) == expected_dict
if expected_settings_keys is None:
assert result.execution_settings is None
else:
assert sorted(result.execution_settings.keys()) == sorted(expected_settings_keys)


@pytest.mark.parametrize("lhs", [42, "foo", None])
def test_kernel_arguments_ror_operator_with_invalid_type(lhs):
"""Test the __ror__ operator with an invalid type raises TypeError."""
with pytest.raises(TypeError):
lhs | KernelArguments()
Loading