From cd59f54fde8fb3a1e53a9e45766ce0310d03d5c5 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 15 Apr 2025 11:25:38 +0530 Subject: [PATCH 01/36] iter1: poc # Conflicts: # packages/toolbox-core/src/toolbox_core/client.py # packages/toolbox-core/src/toolbox_core/tool.py --- .../toolbox-core/src/toolbox_core/client.py | 58 ++++++++++-- .../toolbox-core/src/toolbox_core/tool.py | 88 ++++++++++++++++--- 2 files changed, 129 insertions(+), 17 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index a534e706..dd117c95 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +import asyncio import types -from typing import Any, Callable, Mapping, Optional, Union +from typing import Any, Awaitable, Callable, Coroutine, Mapping, Optional, Union from aiohttp import ClientSession @@ -37,6 +36,7 @@ def __init__( self, url: str, session: Optional[ClientSession] = None, + client_headers: Optional[dict[str, Union[Callable, Coroutine]]] = None, ): """ Initializes the ToolboxClient. @@ -47,6 +47,7 @@ def __init__( If None (default), a new session is created internally. Note that if a session is provided, its lifecycle (including closing) should typically be managed externally. + client_headers: Headers to include in each request sent through this client. """ self.__base_url = url @@ -55,6 +56,8 @@ def __init__( session = ClientSession() self.__session = session + self.__client_headers = client_headers if client_headers is not None else {} + def __parse_tool( self, name: str, @@ -63,13 +66,16 @@ def __parse_tool( all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" + # TODO: Check if any auth token getters have the same name as client_headers and don't pass those. # sort into reg, authn, and bound params params = [] authn_params: dict[str, list[str]] = {} bound_params: dict[str, Callable[[], str]] = {} + auth_sources: set[str] = set() for p in schema.parameters: if p.authSources: # authn parameter authn_params[p.name] = p.authSources + auth_sources.update(p.authSources) elif p.name in all_bound_params: # bound parameter bound_params[p.name] = all_bound_params[p.name] else: # regular parameter @@ -84,6 +90,7 @@ def __parse_tool( base_url=self.__base_url, name=name, description=schema.description, + client_headers=types.MappingProxyType(self.__client_headers), params=params, # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), @@ -153,9 +160,16 @@ async def load_tool( """ + # Resolve client headers + original_headers = self.__client_headers + resolved_headers = { + header_name: resolve_value(original_headers[header_name]) + for header_name in original_headers + } + # request the definition of the tool from the server url = f"{self.__base_url}/api/tool/{name}" - async with self.__session.get(url) as response: + async with self.__session.get(url, headers=resolved_headers) as response: json = await response.json() manifest: ManifestSchema = ManifestSchema(**json) @@ -185,15 +199,19 @@ async def load_toolset( bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. - - Returns: list[ToolboxTool]: A list of callables, one for each tool defined in the toolset. """ + # Resolve client headers + original_headers = self.__client_headers + resolved_headers = { + header_name: resolve_value(original_headers[header_name]) + for header_name in original_headers + } # Request the definition of the tool from the server url = f"{self.__base_url}/api/toolset/{name or ''}" - async with self.__session.get(url) as response: + async with self.__session.get(url, headers=resolved_headers) as response: json = await response.json() manifest: ManifestSchema = ManifestSchema(**json) @@ -203,3 +221,29 @@ async def load_toolset( for n, s in manifest.tools.items() ] return tools + + async def add_headers(self, headers: Mapping[str, Union[Callable, Coroutine]]): + # TODO: Add logic to update self.__headers + pass + + +async def resolve_value( + source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any], +) -> Any: + """ + Asynchronously or synchronously resolves a given source to its value. + If the `source` is a coroutine function, it will be awaited. + If the `source` is a regular callable, it will be called. + Otherwise (if it's not a callable), the `source` itself is returned directly. + Args: + source: The value, a callable returning a value, or a callable + returning an awaitable value. + Returns: + The resolved value. + """ + + if asyncio.iscoroutinefunction(source): + return await source() + elif callable(source): + return source() + return source diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 3150be94..a5a54683 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,28 +13,27 @@ # limitations under the License. +import asyncio import types from inspect import Signature from typing import ( Any, Callable, + Coroutine, + Iterable, Mapping, Optional, Sequence, + Type, Union, + cast, ) from aiohttp import ClientSession +from pydantic import BaseModel, Field, create_model from toolbox_core.protocol import ParameterSchema -from .utils import ( - create_func_docstring, - identify_required_authn_params, - params_to_pydantic_model, - resolve_value, -) - class ToolboxTool: """ @@ -55,6 +54,7 @@ def __init__( base_url: str, name: str, description: str, + client_headers: Mapping[str, Union[Callable, Coroutine]], params: Sequence[ParameterSchema], required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], @@ -69,6 +69,8 @@ def __init__( base_url: The base URL of the Toolbox server API. name: The name of the remote tool. description: The description of the remote tool. + client_headers: Immutable headers to include in each request sent to this tool. + Tool headers will override the client headers. params: The args of the tool. required_authn_params: A dict of required authenticated parameters to a list of services that provide values for them. @@ -82,6 +84,7 @@ def __init__( self.__base_url: str = base_url self.__url = f"{base_url}/api/tool/{name}/invoke" self.__description = description + self.__client_headers = client_headers self.__params = params self.__pydantic_model = params_to_pydantic_model(name, self.__params) @@ -89,7 +92,7 @@ def __init__( # the following properties are set to help anyone that might inspect it determine usage self.__name__ = name - self.__doc__ = create_func_docstring(self.__description, self.__params) + self.__doc__ = create_docstring(self.__description, self.__params) self.__signature__ = Signature( parameters=inspect_type_params, return_annotation=str ) @@ -97,6 +100,7 @@ def __init__( self.__annotations__ = {p.name: p.annotation for p in inspect_type_params} self.__qualname__ = f"{self.__class__.__qualname__}.{self.__name__}" + # TODO: Add logic to remove any auth params named the same as client_params. Raise a warning. # map of parameter name to auth service required by it self.__required_authn_params = required_authn_params # map of authService -> token_getter @@ -110,6 +114,7 @@ def __copy( base_url: Optional[str] = None, name: Optional[str] = None, description: Optional[str] = None, + client_headers: Optional[Mapping[str, Union[Callable, Coroutine]]] = None, params: Optional[Sequence[ParameterSchema]] = None, required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, @@ -139,6 +144,7 @@ def __copy( name=check(name, self.__name__), description=check(description, self.__description), params=check(params, self.__params), + client_headers=check(client_headers, self.__client_headers), required_authn_params=check( required_authn_params, self.__required_authn_params ), @@ -183,13 +189,18 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # apply bounded parameters for param, value in self.__bound_parameters.items(): - payload[param] = await resolve_value(value) + if asyncio.iscoroutinefunction(value): + value = await value() + elif callable(value): + value = value() + payload[param] = value # create headers for auth services headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): - headers[f"{auth_service}_token"] = await resolve_value(token_getter) + headers[f"{auth_service}_token"] = token_getter() + # TODO: Add client headers async with self.__session.post( self.__url, json=payload, @@ -217,6 +228,7 @@ def add_auth_token_getters( A new ToolboxTool instance with the specified authentication token getters registered. """ + # TODO: Check if any auth token getters are registered as headers in the client. # throw an error if the authentication source is already registered existing_services = self.__auth_service_token_getters.keys() @@ -272,3 +284,59 @@ def bind_parameters( params=new_params, bound_params=types.MappingProxyType(all_bound_params), ) + + +def create_docstring(description: str, params: Sequence[ParameterSchema]) -> str: + """Convert tool description and params into its function docstring""" + docstring = description + if not params: + return docstring + docstring += "\n\nArgs:" + for p in params: + docstring += ( + f"\n {p.name} ({p.to_param().annotation.__name__}): {p.description}" + ) + return docstring + + +def identify_required_authn_params( + req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] +) -> dict[str, list[str]]: + """ + Identifies authentication parameters that are still required; because they + not covered by the provided `auth_service_names`. + + Args: + req_authn_params: A mapping of parameter names to sets of required + authentication services. + auth_service_names: An iterable of authentication service names for which + token getters are available. + + Returns: + A new dictionary representing the subset of required authentication parameters + that are not covered by the provided `auth_services`. + """ + required_params = {} # params that are still required with provided auth_services + for param, services in req_authn_params.items(): + # if we don't have a token_getter for any of the services required by the param, + # the param is still required + required = not any(s in services for s in auth_service_names) + if required: + required_params[param] = services + return required_params + + +def params_to_pydantic_model( + tool_name: str, params: Sequence[ParameterSchema] +) -> Type[BaseModel]: + """Converts the given parameters to a Pydantic BaseModel class.""" + field_definitions = {} + for field in params: + field_definitions[field.name] = cast( + Any, + ( + field.to_param().annotation, + Field(description=field.description), + ), + ) + return create_model(tool_name, **field_definitions) From 297f5a97a0ddfbdac3b19080216c1b2f90965099 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 17 Apr 2025 11:42:31 +0530 Subject: [PATCH 02/36] remove client headers from tool --- packages/toolbox-core/src/toolbox_core/tool.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index a5a54683..d7216fd2 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -54,7 +54,6 @@ def __init__( base_url: str, name: str, description: str, - client_headers: Mapping[str, Union[Callable, Coroutine]], params: Sequence[ParameterSchema], required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], @@ -69,8 +68,6 @@ def __init__( base_url: The base URL of the Toolbox server API. name: The name of the remote tool. description: The description of the remote tool. - client_headers: Immutable headers to include in each request sent to this tool. - Tool headers will override the client headers. params: The args of the tool. required_authn_params: A dict of required authenticated parameters to a list of services that provide values for them. @@ -84,7 +81,6 @@ def __init__( self.__base_url: str = base_url self.__url = f"{base_url}/api/tool/{name}/invoke" self.__description = description - self.__client_headers = client_headers self.__params = params self.__pydantic_model = params_to_pydantic_model(name, self.__params) @@ -100,7 +96,6 @@ def __init__( self.__annotations__ = {p.name: p.annotation for p in inspect_type_params} self.__qualname__ = f"{self.__class__.__qualname__}.{self.__name__}" - # TODO: Add logic to remove any auth params named the same as client_params. Raise a warning. # map of parameter name to auth service required by it self.__required_authn_params = required_authn_params # map of authService -> token_getter @@ -114,7 +109,6 @@ def __copy( base_url: Optional[str] = None, name: Optional[str] = None, description: Optional[str] = None, - client_headers: Optional[Mapping[str, Union[Callable, Coroutine]]] = None, params: Optional[Sequence[ParameterSchema]] = None, required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, @@ -144,7 +138,6 @@ def __copy( name=check(name, self.__name__), description=check(description, self.__description), params=check(params, self.__params), - client_headers=check(client_headers, self.__client_headers), required_authn_params=check( required_authn_params, self.__required_authn_params ), @@ -200,7 +193,6 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: for auth_service, token_getter in self.__auth_service_token_getters.items(): headers[f"{auth_service}_token"] = token_getter() - # TODO: Add client headers async with self.__session.post( self.__url, json=payload, @@ -228,8 +220,6 @@ def add_auth_token_getters( A new ToolboxTool instance with the specified authentication token getters registered. """ - # TODO: Check if any auth token getters are registered as headers in the client. - # throw an error if the authentication source is already registered existing_services = self.__auth_service_token_getters.keys() incoming_services = auth_token_getters.keys() @@ -304,7 +294,7 @@ def identify_required_authn_params( ) -> dict[str, list[str]]: """ Identifies authentication parameters that are still required; because they - not covered by the provided `auth_service_names`. + are not covered by the provided `auth_service_names`. Args: req_authn_params: A mapping of parameter names to sets of required From 7a30444adae9f28c0a7f12a31c5421212c004759 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 17 Apr 2025 11:43:50 +0530 Subject: [PATCH 03/36] merge correction --- .../toolbox-core/src/toolbox_core/tool.py | 56 ------------------- 1 file changed, 56 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index d7216fd2..05bd2c4b 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -274,59 +274,3 @@ def bind_parameters( params=new_params, bound_params=types.MappingProxyType(all_bound_params), ) - - -def create_docstring(description: str, params: Sequence[ParameterSchema]) -> str: - """Convert tool description and params into its function docstring""" - docstring = description - if not params: - return docstring - docstring += "\n\nArgs:" - for p in params: - docstring += ( - f"\n {p.name} ({p.to_param().annotation.__name__}): {p.description}" - ) - return docstring - - -def identify_required_authn_params( - req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] -) -> dict[str, list[str]]: - """ - Identifies authentication parameters that are still required; because they - are not covered by the provided `auth_service_names`. - - Args: - req_authn_params: A mapping of parameter names to sets of required - authentication services. - auth_service_names: An iterable of authentication service names for which - token getters are available. - - Returns: - A new dictionary representing the subset of required authentication parameters - that are not covered by the provided `auth_services`. - """ - required_params = {} # params that are still required with provided auth_services - for param, services in req_authn_params.items(): - # if we don't have a token_getter for any of the services required by the param, - # the param is still required - required = not any(s in services for s in auth_service_names) - if required: - required_params[param] = services - return required_params - - -def params_to_pydantic_model( - tool_name: str, params: Sequence[ParameterSchema] -) -> Type[BaseModel]: - """Converts the given parameters to a Pydantic BaseModel class.""" - field_definitions = {} - for field in params: - field_definitions[field.name] = cast( - Any, - ( - field.to_param().annotation, - Field(description=field.description), - ), - ) - return create_model(tool_name, **field_definitions) From c78b74e85911b72a3fc1589eda852fca87560c01 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 17 Apr 2025 11:51:22 +0530 Subject: [PATCH 04/36] cleanup --- .../toolbox-core/src/toolbox_core/client.py | 31 ++----------------- .../toolbox-core/src/toolbox_core/tool.py | 24 +++++++------- 2 files changed, 14 insertions(+), 41 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index dd117c95..b0f2822d 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import types -from typing import Any, Awaitable, Callable, Coroutine, Mapping, Optional, Union +from typing import Any, Callable, Coroutine, Mapping, Optional, Union from aiohttp import ClientSession from .protocol import ManifestSchema, ToolSchema -from .tool import ToolboxTool, identify_required_authn_params +from .tool import ToolboxTool +from .utils import identify_required_authn_params, resolve_value class ToolboxClient: @@ -71,11 +71,9 @@ def __parse_tool( params = [] authn_params: dict[str, list[str]] = {} bound_params: dict[str, Callable[[], str]] = {} - auth_sources: set[str] = set() for p in schema.parameters: if p.authSources: # authn parameter authn_params[p.name] = p.authSources - auth_sources.update(p.authSources) elif p.name in all_bound_params: # bound parameter bound_params[p.name] = all_bound_params[p.name] else: # regular parameter @@ -90,7 +88,6 @@ def __parse_tool( base_url=self.__base_url, name=name, description=schema.description, - client_headers=types.MappingProxyType(self.__client_headers), params=params, # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), @@ -225,25 +222,3 @@ async def load_toolset( async def add_headers(self, headers: Mapping[str, Union[Callable, Coroutine]]): # TODO: Add logic to update self.__headers pass - - -async def resolve_value( - source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any], -) -> Any: - """ - Asynchronously or synchronously resolves a given source to its value. - If the `source` is a coroutine function, it will be awaited. - If the `source` is a regular callable, it will be called. - Otherwise (if it's not a callable), the `source` itself is returned directly. - Args: - source: The value, a callable returning a value, or a callable - returning an awaitable value. - Returns: - The resolved value. - """ - - if asyncio.iscoroutinefunction(source): - return await source() - elif callable(source): - return source() - return source diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 05bd2c4b..3150be94 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,27 +13,28 @@ # limitations under the License. -import asyncio import types from inspect import Signature from typing import ( Any, Callable, - Coroutine, - Iterable, Mapping, Optional, Sequence, - Type, Union, - cast, ) from aiohttp import ClientSession -from pydantic import BaseModel, Field, create_model from toolbox_core.protocol import ParameterSchema +from .utils import ( + create_func_docstring, + identify_required_authn_params, + params_to_pydantic_model, + resolve_value, +) + class ToolboxTool: """ @@ -88,7 +89,7 @@ def __init__( # the following properties are set to help anyone that might inspect it determine usage self.__name__ = name - self.__doc__ = create_docstring(self.__description, self.__params) + self.__doc__ = create_func_docstring(self.__description, self.__params) self.__signature__ = Signature( parameters=inspect_type_params, return_annotation=str ) @@ -182,16 +183,12 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # apply bounded parameters for param, value in self.__bound_parameters.items(): - if asyncio.iscoroutinefunction(value): - value = await value() - elif callable(value): - value = value() - payload[param] = value + payload[param] = await resolve_value(value) # create headers for auth services headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): - headers[f"{auth_service}_token"] = token_getter() + headers[f"{auth_service}_token"] = await resolve_value(token_getter) async with self.__session.post( self.__url, @@ -220,6 +217,7 @@ def add_auth_token_getters( A new ToolboxTool instance with the specified authentication token getters registered. """ + # throw an error if the authentication source is already registered existing_services = self.__auth_service_token_getters.keys() incoming_services = auth_token_getters.keys() From 4b92acef3dc0850d9a45c94555aa27caffdc06e1 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 17 Apr 2025 14:22:18 +0530 Subject: [PATCH 05/36] client headers functionality --- .../toolbox-core/src/toolbox_core/client.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index b0f2822d..1da8b95b 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import types +import warnings from typing import Any, Callable, Coroutine, Mapping, Optional, Union from aiohttp import ClientSession @@ -66,7 +67,6 @@ def __parse_tool( all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" - # TODO: Check if any auth token getters have the same name as client_headers and don't pass those. # sort into reg, authn, and bound params params = [] authn_params: dict[str, list[str]] = {} @@ -83,6 +83,13 @@ def __parse_tool( authn_params, auth_token_getters.keys() ) + request_headers = self.__client_headers + for auth_token, auth_token_val in auth_token_getters: + if auth_token in request_headers.keys(): + warnings.warn(f"Auth token {auth_token} already bound in client.") + else: + request_headers[auth_token] = auth_token_val + tool = ToolboxTool( session=self.__session, base_url=self.__base_url, @@ -91,7 +98,7 @@ def __parse_tool( params=params, # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), - auth_service_token_getters=types.MappingProxyType(auth_token_getters), + auth_service_token_getters=types.MappingProxyType(request_headers), bound_params=types.MappingProxyType(bound_params), ) return tool @@ -220,5 +227,13 @@ async def load_toolset( return tools async def add_headers(self, headers: Mapping[str, Union[Callable, Coroutine]]): - # TODO: Add logic to update self.__headers - pass + existing_headers = self.__client_headers.keys() + incoming_headers = headers.keys() + duplicates = existing_headers & incoming_headers + if duplicates: + warnings.warn( + f"Client header(s) `{', '.join(duplicates)}` already registered in client. These will not be registered again." + ) + for header_key, header_val in headers: + if header_key not in existing_headers: + self.__client_headers[header_key] = header_val From 37a398437e7fd97623b0baa38cbbf9de85347c55 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 17 Apr 2025 15:17:01 +0530 Subject: [PATCH 06/36] small diff --- packages/toolbox-core/src/toolbox_core/client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 1da8b95b..2c3c9576 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -84,7 +84,7 @@ def __parse_tool( ) request_headers = self.__client_headers - for auth_token, auth_token_val in auth_token_getters: + for auth_token, auth_token_val in auth_token_getters.items(): if auth_token in request_headers.keys(): warnings.warn(f"Auth token {auth_token} already bound in client.") else: @@ -167,7 +167,7 @@ async def load_tool( # Resolve client headers original_headers = self.__client_headers resolved_headers = { - header_name: resolve_value(original_headers[header_name]) + header_name: await resolve_value(original_headers[header_name]) for header_name in original_headers } @@ -210,7 +210,7 @@ async def load_toolset( # Resolve client headers original_headers = self.__client_headers resolved_headers = { - header_name: resolve_value(original_headers[header_name]) + header_name: await resolve_value(original_headers[header_name]) for header_name in original_headers } # Request the definition of the tool from the server @@ -234,6 +234,6 @@ async def add_headers(self, headers: Mapping[str, Union[Callable, Coroutine]]): warnings.warn( f"Client header(s) `{', '.join(duplicates)}` already registered in client. These will not be registered again." ) - for header_key, header_val in headers: + for header_key, header_val in headers.items(): if header_key not in existing_headers: self.__client_headers[header_key] = header_val From bc6ca963295168d8c891a10da37746856556a865 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 17 Apr 2025 15:52:29 +0530 Subject: [PATCH 07/36] mypy --- packages/toolbox-core/src/toolbox_core/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 2c3c9576..e58e3fd0 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -226,7 +226,7 @@ async def load_toolset( ] return tools - async def add_headers(self, headers: Mapping[str, Union[Callable, Coroutine]]): + async def add_headers(self, headers: Mapping[str, Union[Callable, Coroutine]]) -> None: existing_headers = self.__client_headers.keys() incoming_headers = headers.keys() duplicates = existing_headers & incoming_headers From f5fc5263f8be724555fc8934cdaa203514f7db11 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 11:39:39 +0530 Subject: [PATCH 08/36] raise error on duplicate headers --- .../toolbox-core/src/toolbox_core/client.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index e58e3fd0..0b981065 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import types -import warnings from typing import Any, Callable, Coroutine, Mapping, Optional, Union from aiohttp import ClientSession @@ -67,6 +66,15 @@ def __parse_tool( all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" + # Validate conflicting Headers/Auth Tokens + request_header_names = self.__client_headers.keys() + auth_token_names = [auth_token + "_token" for auth_token in auth_token_getters.keys()] + duplicates = request_header_names & auth_token_names + if duplicates: + raise ValueError( + f"Client header(s) `{', '.join(duplicates)}` already registered in client." + ) + # sort into reg, authn, and bound params params = [] authn_params: dict[str, list[str]] = {} @@ -83,13 +91,6 @@ def __parse_tool( authn_params, auth_token_getters.keys() ) - request_headers = self.__client_headers - for auth_token, auth_token_val in auth_token_getters.items(): - if auth_token in request_headers.keys(): - warnings.warn(f"Auth token {auth_token} already bound in client.") - else: - request_headers[auth_token] = auth_token_val - tool = ToolboxTool( session=self.__session, base_url=self.__base_url, @@ -98,7 +99,7 @@ def __parse_tool( params=params, # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), - auth_service_token_getters=types.MappingProxyType(request_headers), + auth_service_token_getters=types.MappingProxyType({**self.__client_headers, **auth_token_getters}), bound_params=types.MappingProxyType(bound_params), ) return tool @@ -163,7 +164,6 @@ async def load_tool( depend on the tool itself. """ - # Resolve client headers original_headers = self.__client_headers resolved_headers = { @@ -227,13 +227,14 @@ async def load_toolset( return tools async def add_headers(self, headers: Mapping[str, Union[Callable, Coroutine]]) -> None: + """ + + """ existing_headers = self.__client_headers.keys() incoming_headers = headers.keys() duplicates = existing_headers & incoming_headers if duplicates: - warnings.warn( - f"Client header(s) `{', '.join(duplicates)}` already registered in client. These will not be registered again." + raise ValueError( + f"Client header(s) `{', '.join(duplicates)}` already registered in client." ) - for header_key, header_val in headers.items(): - if header_key not in existing_headers: - self.__client_headers[header_key] = header_val + self.__client_headers.update(headers) \ No newline at end of file From 5e91f15d790b3b3bda6324da2013886999aa8a2c Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 11:41:44 +0530 Subject: [PATCH 09/36] docs --- packages/toolbox-core/src/toolbox_core/client.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 0b981065..03f73fc9 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -228,13 +228,19 @@ async def load_toolset( async def add_headers(self, headers: Mapping[str, Union[Callable, Coroutine]]) -> None: """ + Asynchronously Add headers to be included in each request sent through this client. + Args: + headers: Headers to include in each request sent through this client. + + Raises: + ValueError: If any of the headers are already registered in the client. """ existing_headers = self.__client_headers.keys() incoming_headers = headers.keys() duplicates = existing_headers & incoming_headers if duplicates: raise ValueError( - f"Client header(s) `{', '.join(duplicates)}` already registered in client." + f"Client header(s) `{', '.join(duplicates)}` already registered in the client." ) self.__client_headers.update(headers) \ No newline at end of file From 154edc1c5fcb2dd968c86c540f1654e03cfb40d5 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 15:54:28 +0530 Subject: [PATCH 10/36] add client headers to tool --- .../toolbox-core/src/toolbox_core/client.py | 23 +++------ .../toolbox-core/src/toolbox_core/tool.py | 50 +++++++++++++++---- 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 03f73fc9..3509f1e6 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -36,7 +36,7 @@ def __init__( self, url: str, session: Optional[ClientSession] = None, - client_headers: Optional[dict[str, Union[Callable, Coroutine]]] = None, + client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None, ): """ Initializes the ToolboxClient. @@ -64,17 +64,9 @@ def __parse_tool( schema: ToolSchema, auth_token_getters: dict[str, Callable[[], str]], all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], + client_headers: Mapping[str, Union[Callable, Coroutine, str]], ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" - # Validate conflicting Headers/Auth Tokens - request_header_names = self.__client_headers.keys() - auth_token_names = [auth_token + "_token" for auth_token in auth_token_getters.keys()] - duplicates = request_header_names & auth_token_names - if duplicates: - raise ValueError( - f"Client header(s) `{', '.join(duplicates)}` already registered in client." - ) - # sort into reg, authn, and bound params params = [] authn_params: dict[str, list[str]] = {} @@ -99,8 +91,9 @@ def __parse_tool( params=params, # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), - auth_service_token_getters=types.MappingProxyType({**self.__client_headers, **auth_token_getters}), + auth_service_token_getters=types.MappingProxyType(auth_token_getters), bound_params=types.MappingProxyType(bound_params), + client_headers=types.MappingProxyType(client_headers), ) return tool @@ -156,8 +149,6 @@ async def load_tool( bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. - - Returns: ToolboxTool: A callable object representing the loaded tool, ready for execution. The specific arguments and behavior of the callable @@ -226,7 +217,9 @@ async def load_toolset( ] return tools - async def add_headers(self, headers: Mapping[str, Union[Callable, Coroutine]]) -> None: + async def add_headers( + self, headers: Mapping[str, Union[Callable, Coroutine, str]] + ) -> None: """ Asynchronously Add headers to be included in each request sent through this client. @@ -243,4 +236,4 @@ async def add_headers(self, headers: Mapping[str, Union[Callable, Coroutine]]) - raise ValueError( f"Client header(s) `{', '.join(duplicates)}` already registered in the client." ) - self.__client_headers.update(headers) \ No newline at end of file + self.__client_headers.update(headers) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index b9f5b8df..bb5affae 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -15,14 +15,7 @@ import types from inspect import Signature -from typing import ( - Any, - Callable, - Mapping, - Optional, - Sequence, - Union, -) +from typing import Any, Callable, Mapping, Optional, Sequence, Union, Coroutine from aiohttp import ClientSession @@ -59,6 +52,7 @@ def __init__( required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], bound_params: Mapping[str, Union[Callable[[], Any], Any]], + client_headers: Mapping[str, Union[Callable, Coroutine, str]], ): """ Initializes a callable that will trigger the tool invocation through the @@ -76,6 +70,7 @@ def __init__( produce a token) bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. + client_headers: Client specific headers bound to the tool. """ # used to invoke the toolbox API self.__session: ClientSession = session @@ -97,12 +92,28 @@ def __init__( self.__annotations__ = {p.name: p.annotation for p in inspect_type_params} self.__qualname__ = f"{self.__class__.__qualname__}.{self.__name__}" + # Validate conflicting Headers/Auth Tokens + request_header_names = client_headers.keys() + auth_token_names = [ + auth_token_name + "_token" + for auth_token_name in auth_service_token_getters.keys() + ] + duplicates = request_header_names & auth_token_names + if duplicates: + raise ValueError( + f"Client header(s) `{', '.join(duplicates)}` already registered in client. " + f"Cannot register client the same headers in the client as well as tool." + ) + # map of parameter name to auth service required by it self.__required_authn_params = required_authn_params # map of authService -> token_getter self.__auth_service_token_getters = auth_service_token_getters # map of parameter name to value (or callable that produces that value) self.__bound_parameters = bound_params + # map of client headers to their value/callable/coroutine + self.__client_headers = client_headers + def __copy( self, @@ -114,6 +125,7 @@ def __copy( required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None, + client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None, ) -> "ToolboxTool": """ Creates a copy of the ToolboxTool, overriding specific fields. @@ -130,7 +142,7 @@ def __copy( that produce a token) bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. - + client_headers: Client specific headers bound to the tool. """ check = lambda val, default: val if val is not None else default return ToolboxTool( @@ -146,6 +158,7 @@ def __copy( auth_service_token_getters, self.__auth_service_token_getters ), bound_params=check(bound_params, self.__bound_parameters), + client_headers=check(client_headers, self.__client_headers), ) async def __call__(self, *args: Any, **kwargs: Any) -> str: @@ -189,6 +202,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): headers[f"{auth_service}_token"] = await resolve_value(token_getter) + for client_header_name, client_header_val in self.__client_headers: + headers[client_header_name] = await resolve_value(client_header_val) async with self.__session.post( self.__url, @@ -216,6 +231,10 @@ def add_auth_token_getters( Returns: A new ToolboxTool instance with the specified authentication token getters registered. + + Raises + ValueError: If the auth source has already been registered either + to the tool or to the corresponding client. """ # throw an error if the authentication source is already registered @@ -227,6 +246,19 @@ def add_auth_token_getters( f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`." ) + # Validate duplicates with client headers + request_header_names = self.__client_headers.keys() + auth_token_names = [ + auth_token_name + "_token" + for auth_token_name in incoming_services + ] + duplicates = request_header_names & auth_token_names + if duplicates: + raise ValueError( + f"Client header(s) `{', '.join(duplicates)}` already registered in client. " + f"Cannot register client the same headers in the client as well as tool." + ) + # create a read-only updated value for new_getters new_getters = types.MappingProxyType( dict(self.__auth_service_token_getters, **auth_token_getters) From b83eaef8b6f52c46f73a4284141170fe8629c0c6 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 15:56:11 +0530 Subject: [PATCH 11/36] lint --- packages/toolbox-core/src/toolbox_core/tool.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index bb5affae..8dadbc15 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -15,7 +15,7 @@ import types from inspect import Signature -from typing import Any, Callable, Mapping, Optional, Sequence, Union, Coroutine +from typing import Any, Callable, Coroutine, Mapping, Optional, Sequence, Union from aiohttp import ClientSession @@ -114,7 +114,6 @@ def __init__( # map of client headers to their value/callable/coroutine self.__client_headers = client_headers - def __copy( self, session: Optional[ClientSession] = None, @@ -249,8 +248,7 @@ def add_auth_token_getters( # Validate duplicates with client headers request_header_names = self.__client_headers.keys() auth_token_names = [ - auth_token_name + "_token" - for auth_token_name in incoming_services + auth_token_name + "_token" for auth_token_name in incoming_services ] duplicates = request_header_names & auth_token_names if duplicates: From 6965ba1e064394c0cf621a22cf12422502f86089 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 15:56:43 +0530 Subject: [PATCH 12/36] lint --- packages/toolbox-core/src/toolbox_core/tool.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 8dadbc15..901275ef 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -182,7 +182,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: for s in self.__required_authn_params.values(): req_auth_services.update(s) raise Exception( - f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" + f"One or more of the following authn services are required to invoke this tool" + f": {','.join(req_auth_services)}" ) # validate inputs to this call using the signature From e8a53c0b5142bf188cb4abe30360be6aa624a884 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 15:58:08 +0530 Subject: [PATCH 13/36] fix --- packages/toolbox-core/src/toolbox_core/client.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 3509f1e6..f54693a2 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -173,7 +173,11 @@ async def load_tool( # TODO: Better exception raise Exception(f"Tool '{name}' not found!") tool = self.__parse_tool( - name, manifest.tools[name], auth_token_getters, bound_params + name, + manifest.tools[name], + auth_token_getters, + bound_params, + self.__client_headers, ) return tool @@ -212,7 +216,9 @@ async def load_toolset( # parse each tools name and schema into a list of ToolboxTools tools = [ - self.__parse_tool(n, s, auth_token_getters, bound_params) + self.__parse_tool( + n, s, auth_token_getters, bound_params, self.__client_headers + ) for n, s in manifest.tools.items() ] return tools From ea319e5268150469a28cb5eea94ab809a22fc0dc Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 16:32:24 +0530 Subject: [PATCH 14/36] add client tests --- packages/toolbox-core/tests/test_client.py | 884 +++++++++++++-------- 1 file changed, 555 insertions(+), 329 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index a9cb091a..2e5de151 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -66,388 +66,614 @@ def test_tool_auth(): ) -@pytest.mark.asyncio -async def test_load_tool_success(aioresponses, test_tool_str): - """ - Tests successfully loading a tool when the API returns a valid manifest. - """ - # Mock out responses from server - TOOL_NAME = "test_tool_1" - manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) - aioresponses.get( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", - payload=manifest.model_dump(), - status=200, - ) - aioresponses.post( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", - payload={"result": "ok"}, - status=200, - ) - - async with ToolboxClient(TEST_BASE_URL) as client: - # Load a Tool - loaded_tool = await client.load_tool(TOOL_NAME) - - # Assertions - assert callable(loaded_tool) - # Assert introspection attributes are set correctly - assert loaded_tool.__name__ == TOOL_NAME - expected_description = ( - test_tool_str.description - + f"\n\nArgs:\n param1 (str): Description of Param1" - ) - assert loaded_tool.__doc__ == expected_description - - # Assert signature inspection - sig = inspect.signature(loaded_tool) - assert list(sig.parameters.keys()) == [p.name for p in test_tool_str.parameters] - - assert await loaded_tool("some value") == "ok" - - -@pytest.mark.asyncio -async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): - """Tests successfully loading a toolset with multiple tools.""" - TOOLSET_NAME = "my_toolset" - TOOL1 = "tool1" - TOOL2 = "tool2" - manifest = ManifestSchema( - serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} - ) - aioresponses.get( - f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", - payload=manifest.model_dump(), - status=200, - ) - - async with ToolboxClient(TEST_BASE_URL) as client: - tools = await client.load_toolset(TOOLSET_NAME) - - assert isinstance(tools, list) - assert len(tools) == len(manifest.tools) - - # Check if tools were created correctly - assert {t.__name__ for t in tools} == manifest.tools.keys() - - -@pytest.mark.asyncio -async def test_invoke_tool_server_error(aioresponses, test_tool_str): - """Tests that invoking a tool raises an Exception when the server returns an - error status.""" - TOOL_NAME = "server_error_tool" - ERROR_MESSAGE = "Simulated Server Error" - manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) - - aioresponses.get( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", - payload=manifest.model_dump(), - status=200, - ) - aioresponses.post( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", - payload={"error": ERROR_MESSAGE}, - status=500, - ) - - async with ToolboxClient(TEST_BASE_URL) as client: - loaded_tool = await client.load_tool(TOOL_NAME) - - with pytest.raises(Exception, match=ERROR_MESSAGE): - await loaded_tool(param1="some input") - - -@pytest.mark.asyncio -async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): - """ - Tests that load_tool raises an Exception when the requested tool name - is not found in the manifest returned by the server, using existing fixtures. - """ - ACTUAL_TOOL_IN_MANIFEST = "actual_tool_abc" - REQUESTED_TOOL_NAME = "non_existent_tool_xyz" - - manifest = ManifestSchema( - serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str} - ) - - aioresponses.get( - f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", - payload=manifest.model_dump(), - status=200, - ) - - async with ToolboxClient(TEST_BASE_URL) as client: - with pytest.raises(Exception, match=f"Tool '{REQUESTED_TOOL_NAME}' not found!"): - await client.load_tool(REQUESTED_TOOL_NAME) - - aioresponses.assert_called_once_with( - f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET" - ) - +# @pytest.mark.asyncio +# async def test_load_tool_success(aioresponses, test_tool_str): +# """ +# Tests successfully loading a tool when the API returns a valid manifest. +# """ +# # Mock out responses from server +# TOOL_NAME = "test_tool_1" +# manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) +# aioresponses.get( +# f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", +# payload=manifest.model_dump(), +# status=200, +# ) +# aioresponses.post( +# f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", +# payload={"result": "ok"}, +# status=200, +# ) +# +# async with ToolboxClient(TEST_BASE_URL) as client: +# # Load a Tool +# loaded_tool = await client.load_tool(TOOL_NAME) +# +# # Assertions +# assert callable(loaded_tool) +# # Assert introspection attributes are set correctly +# assert loaded_tool.__name__ == TOOL_NAME +# expected_description = ( +# test_tool_str.description +# + f"\n\nArgs:\n param1 (str): Description of Param1" +# ) +# assert loaded_tool.__doc__ == expected_description +# +# # Assert signature inspection +# sig = inspect.signature(loaded_tool) +# assert list(sig.parameters.keys()) == [p.name for p in test_tool_str.parameters] +# +# assert await loaded_tool("some value") == "ok" +# +# +# @pytest.mark.asyncio +# async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): +# """Tests successfully loading a toolset with multiple tools.""" +# TOOLSET_NAME = "my_toolset" +# TOOL1 = "tool1" +# TOOL2 = "tool2" +# manifest = ManifestSchema( +# serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} +# ) +# aioresponses.get( +# f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", +# payload=manifest.model_dump(), +# status=200, +# ) +# +# async with ToolboxClient(TEST_BASE_URL) as client: +# tools = await client.load_toolset(TOOLSET_NAME) +# +# assert isinstance(tools, list) +# assert len(tools) == len(manifest.tools) +# +# # Check if tools were created correctly +# assert {t.__name__ for t in tools} == manifest.tools.keys() +# +# +# @pytest.mark.asyncio +# async def test_invoke_tool_server_error(aioresponses, test_tool_str): +# """Tests that invoking a tool raises an Exception when the server returns an +# error status.""" +# TOOL_NAME = "server_error_tool" +# ERROR_MESSAGE = "Simulated Server Error" +# manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) +# +# aioresponses.get( +# f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", +# payload=manifest.model_dump(), +# status=200, +# ) +# aioresponses.post( +# f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", +# payload={"error": ERROR_MESSAGE}, +# status=500, +# ) +# +# async with ToolboxClient(TEST_BASE_URL) as client: +# loaded_tool = await client.load_tool(TOOL_NAME) +# +# with pytest.raises(Exception, match=ERROR_MESSAGE): +# await loaded_tool(param1="some input") +# +# +# @pytest.mark.asyncio +# async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): +# """ +# Tests that load_tool raises an Exception when the requested tool name +# is not found in the manifest returned by the server, using existing fixtures. +# """ +# ACTUAL_TOOL_IN_MANIFEST = "actual_tool_abc" +# REQUESTED_TOOL_NAME = "non_existent_tool_xyz" +# +# manifest = ManifestSchema( +# serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str} +# ) +# +# aioresponses.get( +# f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", +# payload=manifest.model_dump(), +# status=200, +# ) +# +# async with ToolboxClient(TEST_BASE_URL) as client: +# with pytest.raises(Exception, match=f"Tool '{REQUESTED_TOOL_NAME}' not found!"): +# await client.load_tool(REQUESTED_TOOL_NAME) +# +# aioresponses.assert_called_once_with( +# f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET" +# ) +# +# +# class TestAuth: +# +# @pytest.fixture +# def expected_header(self): +# return "some_token_for_testing" +# +# @pytest.fixture +# def tool_name(self): +# return "tool1" +# +# @pytest_asyncio.fixture +# async def client(self, aioresponses, test_tool_auth, tool_name, expected_header): +# manifest = ManifestSchema( +# serverVersion="0.0.0", tools={tool_name: test_tool_auth} +# ) +# +# # mock tool GET call +# aioresponses.get( +# f"{TEST_BASE_URL}/api/tool/{tool_name}", +# payload=manifest.model_dump(), +# status=200, +# ) +# +# # mock tool INVOKE call +# def require_headers(url, **kwargs): +# if kwargs["headers"].get("my-auth-service_token") == expected_header: +# return CallbackResult(status=200, body="{}") +# else: +# return CallbackResult(status=400, body="{}") +# +# aioresponses.post( +# f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", +# payload=manifest.model_dump(), +# callback=require_headers, +# status=200, +# ) +# +# async with ToolboxClient(TEST_BASE_URL) as client: +# yield client +# +# @pytest.mark.asyncio +# async def test_auth_with_load_tool_success( +# self, tool_name, expected_header, client +# ): +# """Tests 'load_tool' with auth token is specified.""" +# +# def token_handler(): +# return expected_header +# +# tool = await client.load_tool( +# tool_name, auth_token_getters={"my-auth-service": token_handler} +# ) +# await tool(5) +# +# @pytest.mark.asyncio +# async def test_auth_with_add_token_success( +# self, tool_name, expected_header, client +# ): +# """Tests 'load_tool' with auth token is specified.""" +# +# def token_handler(): +# return expected_header +# +# tool = await client.load_tool(tool_name) +# tool = tool.add_auth_token_getters({"my-auth-service": token_handler}) +# await tool(5) +# +# @pytest.mark.asyncio +# async def test_auth_with_load_tool_fail_no_token( +# self, tool_name, expected_header, client +# ): +# """Tests 'load_tool' with auth token is specified.""" +# +# tool = await client.load_tool(tool_name) +# with pytest.raises(Exception): +# await tool(5) +# +# @pytest.mark.asyncio +# async def test_add_auth_token_getters_duplicate_fail(self, tool_name, client): +# """ +# Tests that adding a duplicate auth token getter raises ValueError. +# """ +# AUTH_SERVICE = "my-auth-service" +# +# tool = await client.load_tool(tool_name) +# +# authed_tool = tool.add_auth_token_getters({AUTH_SERVICE: {}}) +# assert AUTH_SERVICE in authed_tool._ToolboxTool__auth_service_token_getters +# +# with pytest.raises( +# ValueError, +# match=f"Authentication source\\(s\\) `{AUTH_SERVICE}` already registered in tool `{tool_name}`.", +# ): +# authed_tool.add_auth_token_getters({AUTH_SERVICE: {}}) +# +# +# class TestBoundParameter: +# +# @pytest.fixture +# def tool_name(self): +# return "tool1" +# +# @pytest_asyncio.fixture +# async def client(self, aioresponses, test_tool_int_bool, tool_name): +# manifest = ManifestSchema( +# serverVersion="0.0.0", tools={tool_name: test_tool_int_bool} +# ) +# +# # mock toolset GET call +# aioresponses.get( +# f"{TEST_BASE_URL}/api/toolset/", +# payload=manifest.model_dump(), +# status=200, +# ) +# +# # mock tool GET call +# aioresponses.get( +# f"{TEST_BASE_URL}/api/tool/{tool_name}", +# payload=manifest.model_dump(), +# status=200, +# ) +# +# # mock tool INVOKE call +# def reflect_parameters(url, **kwargs): +# body = {"result": kwargs["json"]} +# return CallbackResult(status=200, body=json.dumps(body)) +# +# aioresponses.post( +# f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", +# payload=manifest.model_dump(), +# callback=reflect_parameters, +# status=200, +# ) +# +# async with ToolboxClient(TEST_BASE_URL) as client: +# yield client +# +# @pytest.mark.asyncio +# async def test_load_tool_success(self, tool_name, client): +# """Tests 'load_tool' with a bound parameter specified.""" +# tool = await client.load_tool(tool_name, bound_params={"argA": lambda: 5}) +# +# assert len(tool.__signature__.parameters) == 1 +# assert "argA" not in tool.__signature__.parameters +# +# res = await tool(True) +# assert "argA" in res +# +# @pytest.mark.asyncio +# async def test_load_toolset_success(self, tool_name, client): +# """Tests 'load_toolset' with a bound parameter specified.""" +# tools = await client.load_toolset("", bound_params={"argB": lambda: "hello"}) +# tool = tools[0] +# +# assert len(tool.__signature__.parameters) == 1 +# assert "argB" not in tool.__signature__.parameters +# +# res = await tool(True) +# assert "argB" in res +# +# @pytest.mark.asyncio +# async def test_bind_param_success(self, tool_name, client): +# """Tests 'bind_param' with a bound parameter specified.""" +# tool = await client.load_tool(tool_name) +# +# assert len(tool.__signature__.parameters) == 2 +# assert "argA" in tool.__signature__.parameters +# +# tool = tool.bind_parameters({"argA": 5}) +# +# assert len(tool.__signature__.parameters) == 1 +# assert "argA" not in tool.__signature__.parameters +# +# res = await tool(True) +# assert "argA" in res +# +# @pytest.mark.asyncio +# async def test_bind_callable_param_success(self, tool_name, client): +# """Tests 'bind_param' with a bound parameter specified.""" +# tool = await client.load_tool(tool_name) +# +# assert len(tool.__signature__.parameters) == 2 +# assert "argA" in tool.__signature__.parameters +# +# tool = tool.bind_parameters({"argA": lambda: 5}) +# +# assert len(tool.__signature__.parameters) == 1 +# assert "argA" not in tool.__signature__.parameters +# +# res = await tool(True) +# assert "argA" in res +# +# @pytest.mark.asyncio +# async def test_bind_param_fail(self, tool_name, client): +# """Tests 'bind_param' with a bound parameter that doesn't exist.""" +# tool = await client.load_tool(tool_name) +# +# assert len(tool.__signature__.parameters) == 2 +# assert "argA" in tool.__signature__.parameters +# +# with pytest.raises(Exception): +# tool = tool.bind_parameters({"argC": lambda: 5}) +# +# @pytest.mark.asyncio +# async def test_bind_param_static_value_success(self, tool_name, client): +# """ +# Tests bind_parameters method with a static value. +# """ +# +# bound_value = "Test value" +# +# tool = await client.load_tool(tool_name) +# bound_tool = tool.bind_parameters({"argB": bound_value}) +# +# assert bound_tool is not tool +# assert "argB" not in bound_tool.__signature__.parameters +# assert "argA" in bound_tool.__signature__.parameters +# +# passed_value_a = 42 +# res_payload = await bound_tool(argA=passed_value_a) +# +# assert res_payload == {"argA": passed_value_a, "argB": bound_value} +# +# @pytest.mark.asyncio +# async def test_bind_param_sync_callable_value_success(self, tool_name, client): +# """ +# Tests bind_parameters method with a sync callable value. +# """ +# +# bound_value_result = True +# bound_sync_callable = Mock(return_value=bound_value_result) +# +# tool = await client.load_tool(tool_name) +# bound_tool = tool.bind_parameters({"argB": bound_sync_callable}) +# +# assert bound_tool is not tool +# assert "argB" not in bound_tool.__signature__.parameters +# assert "argA" in bound_tool.__signature__.parameters +# +# passed_value_a = 42 +# res_payload = await bound_tool(argA=passed_value_a) +# +# assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} +# bound_sync_callable.assert_called_once() +# +# @pytest.mark.asyncio +# async def test_bind_param_async_callable_value_success(self, tool_name, client): +# """ +# Tests bind_parameters method with an async callable value. +# """ +# +# bound_value_result = True +# bound_async_callable = AsyncMock(return_value=bound_value_result) +# +# tool = await client.load_tool(tool_name) +# bound_tool = tool.bind_parameters({"argB": bound_async_callable}) +# +# assert bound_tool is not tool +# assert "argB" not in bound_tool.__signature__.parameters +# assert "argA" in bound_tool.__signature__.parameters +# +# passed_value_a = 42 +# res_payload = await bound_tool(argA=passed_value_a) +# +# assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} +# bound_async_callable.assert_awaited_once() -class TestAuth: +class TestClientHeaders: @pytest.fixture - def expected_header(self): - return "some_token_for_testing" + def static_header(self): + return {"X-Static-Header": "static_value"} @pytest.fixture - def tool_name(self): - return "tool1" - - @pytest_asyncio.fixture - async def client(self, aioresponses, test_tool_auth, tool_name, expected_header): - manifest = ManifestSchema( - serverVersion="0.0.0", tools={tool_name: test_tool_auth} - ) - - # mock tool GET call - aioresponses.get( - f"{TEST_BASE_URL}/api/tool/{tool_name}", - payload=manifest.model_dump(), - status=200, - ) - - # mock tool INVOKE call - def require_headers(url, **kwargs): - if kwargs["headers"].get("my-auth-service_token") == expected_header: - return CallbackResult(status=200, body="{}") - else: - return CallbackResult(status=400, body="{}") - - aioresponses.post( - f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", - payload=manifest.model_dump(), - callback=require_headers, - status=200, - ) - - async with ToolboxClient(TEST_BASE_URL) as client: - yield client + def sync_callable_header_value(self): + return "sync_callable_value" - @pytest.mark.asyncio - async def test_auth_with_load_tool_success( - self, tool_name, expected_header, client - ): - """Tests 'load_tool' with auth token is specified.""" + @pytest.fixture + def sync_callable_header(self, sync_callable_header_value): + return {"X-Sync-Callable-Header": Mock(return_value=sync_callable_header_value)} - def token_handler(): - return expected_header + @pytest.fixture + def async_callable_header_value(self): + return "async_callable_value" - tool = await client.load_tool( - tool_name, auth_token_getters={"my-auth-service": token_handler} - ) - await tool(5) + @pytest.fixture + def async_callable_header(self, async_callable_header_value): + # This mock, when awaited by the client, will return the value. + return { + "X-Async-Callable-Header": AsyncMock( + return_value=async_callable_header_value + ) + } @pytest.mark.asyncio - async def test_auth_with_add_token_success( - self, tool_name, expected_header, client - ): - """Tests 'load_tool' with auth token is specified.""" - - def token_handler(): - return expected_header - - tool = await client.load_tool(tool_name) - tool = tool.add_auth_token_getters({"my-auth-service": token_handler}) - await tool(5) + async def test_client_init_with_headers(self, static_header): + """Tests client initialization with static headers.""" + async with ToolboxClient(TEST_BASE_URL, client_headers=static_header) as client: + assert client._ToolboxClient__client_headers == static_header @pytest.mark.asyncio - async def test_auth_with_load_tool_fail_no_token( - self, tool_name, expected_header, client + async def test_load_tool_with_static_headers( + self, aioresponses, test_tool_str, static_header ): - """Tests 'load_tool' with auth token is specified.""" - - tool = await client.load_tool(tool_name) - with pytest.raises(Exception): - await tool(5) - - @pytest.mark.asyncio - async def test_add_auth_token_getters_duplicate_fail(self, tool_name, client): - """ - Tests that adding a duplicate auth token getter raises ValueError. - """ - AUTH_SERVICE = "my-auth-service" - - tool = await client.load_tool(tool_name) - - authed_tool = tool.add_auth_token_getters({AUTH_SERVICE: {}}) - assert AUTH_SERVICE in authed_tool._ToolboxTool__auth_service_token_getters - - with pytest.raises( - ValueError, - match=f"Authentication source\\(s\\) `{AUTH_SERVICE}` already registered in tool `{tool_name}`.", - ): - authed_tool.add_auth_token_getters({AUTH_SERVICE: {}}) - - -class TestBoundParameter: - - @pytest.fixture - def tool_name(self): - return "tool1" - - @pytest_asyncio.fixture - async def client(self, aioresponses, test_tool_int_bool, tool_name): + """Tests loading and invoking a tool with static client headers.""" + tool_name = "tool_with_static_headers" manifest = ManifestSchema( - serverVersion="0.0.0", tools={tool_name: test_tool_int_bool} + serverVersion="0.0.0", tools={tool_name: test_tool_str} ) + expected_payload = {"result": "ok"} - # mock toolset GET call - aioresponses.get( - f"{TEST_BASE_URL}/api/toolset/", - payload=manifest.model_dump(), - status=200, - ) + # Mock GET for tool definition + def get_callback(url, **kwargs): + # Verify headers + assert kwargs.get("headers") == static_header + return CallbackResult(status=200, payload=manifest.model_dump()) - # mock tool GET call - aioresponses.get( - f"{TEST_BASE_URL}/api/tool/{tool_name}", - payload=manifest.model_dump(), - status=200, - ) + aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) - # mock tool INVOKE call - def reflect_parameters(url, **kwargs): - body = {"result": kwargs["json"]} - return CallbackResult(status=200, body=json.dumps(body)) + # Mock POST for invocation + def post_callback(url, **kwargs): + # Verify headers + assert kwargs.get("headers") == static_header + return CallbackResult(status=200, payload=expected_payload) aioresponses.post( - f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", - payload=manifest.model_dump(), - callback=reflect_parameters, - status=200, + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback ) - async with ToolboxClient(TEST_BASE_URL) as client: - yield client + async with ToolboxClient(TEST_BASE_URL, client_headers=static_header) as client: + tool = await client.load_tool(tool_name) + result = await tool(param1="test") + assert result == expected_payload["result"] @pytest.mark.asyncio - async def test_load_tool_success(self, tool_name, client): - """Tests 'load_tool' with a bound parameter specified.""" - tool = await client.load_tool(tool_name, bound_params={"argA": lambda: 5}) + async def test_load_tool_with_sync_callable_headers(self, aioresponses, test_tool_str, sync_callable_header, + sync_callable_header_value): + """Tests loading and invoking a tool with sync callable client headers.""" + tool_name = "tool_with_sync_callable_headers" + manifest = ManifestSchema(serverVersion="0.0.0", tools={tool_name: test_tool_str}) + expected_payload = {"result": "ok_sync"} + header_key = list(sync_callable_header.keys())[0] + header_mock = sync_callable_header[header_key] + resolved_header = {header_key: sync_callable_header_value} + + # Mock GET + def get_callback(url, **kwargs): + # Verify headers + assert kwargs.get("headers") == resolved_header + return CallbackResult(status=200, payload=manifest.model_dump()) + + aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) + + # Mock POST + def post_callback(url, **kwargs): + # Verify headers + assert kwargs.get("headers") == resolved_header + return CallbackResult(status=200, payload=expected_payload) + + aioresponses.post(f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback) + + async with ToolboxClient(TEST_BASE_URL, client_headers=sync_callable_header) as client: + tool = await client.load_tool(tool_name) + header_mock.assert_called_once() # GET + + header_mock.reset_mock() # Reset before invoke + + result = await tool(param1="test") + assert result == expected_payload["result"] + header_mock.assert_called_once() # POST/invoke - assert len(tool.__signature__.parameters) == 1 - assert "argA" not in tool.__signature__.parameters + @pytest.mark.asyncio + async def test_load_tool_with_async_callable_headers(self, aioresponses, test_tool_str, + async_callable_header, + # Use the header fixture (provides mock) + async_callable_header_value + # Use the value fixture for expected result + ): + """Tests loading and invoking a tool with async callable client headers.""" + tool_name = "tool_with_async_callable_headers" + manifest = ManifestSchema(serverVersion="0.0.0", tools={tool_name: test_tool_str}) + expected_payload = {"result": "ok_async"} - res = await tool(True) - assert "argA" in res + header_key = list(async_callable_header.keys())[0] + header_mock: AsyncMock = async_callable_header[header_key] # Get the AsyncMock - @pytest.mark.asyncio - async def test_load_toolset_success(self, tool_name, client): - """Tests 'load_toolset' with a bound parameter specified.""" - tools = await client.load_toolset("", bound_params={"argB": lambda: "hello"}) - tool = tools[0] + # Calculate expected result using the VALUE fixture + resolved_header = {header_key: async_callable_header_value} - assert len(tool.__signature__.parameters) == 1 - assert "argB" not in tool.__signature__.parameters + # Mock GET callback checks against the RESOLVED value + def get_callback(url, **kwargs): + assert kwargs.get("headers") == resolved_header + return CallbackResult(status=200, payload=manifest.model_dump()) - res = await tool(True) - assert "argB" in res + aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) - @pytest.mark.asyncio - async def test_bind_param_success(self, tool_name, client): - """Tests 'bind_param' with a bound parameter specified.""" - tool = await client.load_tool(tool_name) + # Mock POST callback checks against the RESOLVED value + def post_callback(url, **kwargs): + assert kwargs.get("headers") == resolved_header + return CallbackResult(status=200, payload=expected_payload) - assert len(tool.__signature__.parameters) == 2 - assert "argA" in tool.__signature__.parameters + aioresponses.post(f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback) - tool = tool.bind_parameters({"argA": 5}) + # Pass the dictionary containing the AsyncMock to the client + async with ToolboxClient(TEST_BASE_URL, client_headers=async_callable_header) as client: + tool = await client.load_tool(tool_name) + header_mock.assert_awaited_once() # GET - assert len(tool.__signature__.parameters) == 1 - assert "argA" not in tool.__signature__.parameters + header_mock.reset_mock() - res = await tool(True) - assert "argA" in res + result = await tool(param1="test") + assert result == expected_payload["result"] + header_mock.assert_awaited_once() # POST/invoke @pytest.mark.asyncio - async def test_bind_callable_param_success(self, tool_name, client): - """Tests 'bind_param' with a bound parameter specified.""" - tool = await client.load_tool(tool_name) + async def test_load_toolset_with_headers(self, aioresponses, test_tool_str, static_header): + """Tests loading a toolset with client headers.""" + toolset_name = "toolset_with_headers" + tool_name = "tool_in_set" + manifest = ManifestSchema(serverVersion="0.0.0", tools={tool_name: test_tool_str}) - assert len(tool.__signature__.parameters) == 2 - assert "argA" in tool.__signature__.parameters + # Mock GET + def get_callback(url, **kwargs): + # Verify headers + assert kwargs.get("headers") == static_header + return CallbackResult(status=200, payload=manifest.model_dump()) - tool = tool.bind_parameters({"argA": lambda: 5}) + aioresponses.get(f"{TEST_BASE_URL}/api/toolset/{toolset_name}", callback=get_callback) - assert len(tool.__signature__.parameters) == 1 - assert "argA" not in tool.__signature__.parameters - - res = await tool(True) - assert "argA" in res + async with ToolboxClient(TEST_BASE_URL, client_headers=static_header) as client: + tools = await client.load_toolset(toolset_name) + assert len(tools) == 1 + assert tools[0].__name__ == tool_name @pytest.mark.asyncio - async def test_bind_param_fail(self, tool_name, client): - """Tests 'bind_param' with a bound parameter that doesn't exist.""" - tool = await client.load_tool(tool_name) - - assert len(tool.__signature__.parameters) == 2 - assert "argA" in tool.__signature__.parameters - - with pytest.raises(Exception): - tool = tool.bind_parameters({"argC": lambda: 5}) + async def test_add_headers_success(self, aioresponses, test_tool_str, static_header): + """Tests adding headers after client initialization.""" + tool_name = "tool_after_add_headers" + manifest = ManifestSchema(serverVersion="0.0.0", tools={tool_name: test_tool_str}) + expected_payload = {"result": "added_ok"} - @pytest.mark.asyncio - async def test_bind_param_static_value_success(self, tool_name, client): - """ - Tests bind_parameters method with a static value. - """ + # Mock GET for tool definition - check headers + def get_callback(url, **kwargs): + assert kwargs.get("headers") == static_header # Verify headers in GET + return CallbackResult(status=200, payload=manifest.model_dump()) - bound_value = "Test value" + aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) - tool = await client.load_tool(tool_name) - bound_tool = tool.bind_parameters({"argB": bound_value}) + # Mock POST + def post_callback(url, **kwargs): + # Verify headers + assert kwargs.get("headers") == static_header + return CallbackResult(status=200, payload=expected_payload) - assert bound_tool is not tool - assert "argB" not in bound_tool.__signature__.parameters - assert "argA" in bound_tool.__signature__.parameters + aioresponses.post(f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback) - passed_value_a = 42 - res_payload = await bound_tool(argA=passed_value_a) + async with ToolboxClient(TEST_BASE_URL) as client: + await client.add_headers(static_header) + assert client._ToolboxClient__client_headers == static_header - assert res_payload == {"argA": passed_value_a, "argB": bound_value} + tool = await client.load_tool(tool_name) + result = await tool(param1="test") + assert result == expected_payload["result"] @pytest.mark.asyncio - async def test_bind_param_sync_callable_value_success(self, tool_name, client): - """ - Tests bind_parameters method with a sync callable value. - """ - - bound_value_result = True - bound_sync_callable = Mock(return_value=bound_value_result) - - tool = await client.load_tool(tool_name) - bound_tool = tool.bind_parameters({"argB": bound_sync_callable}) - - assert bound_tool is not tool - assert "argB" not in bound_tool.__signature__.parameters - assert "argA" in bound_tool.__signature__.parameters - - passed_value_a = 42 - res_payload = await bound_tool(argA=passed_value_a) - - assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} - bound_sync_callable.assert_called_once() + async def test_add_headers_duplicate_fail(self, static_header): + """Tests that adding a duplicate header via add_headers raises ValueError.""" + async with ToolboxClient(TEST_BASE_URL, client_headers=static_header) as client: + with pytest.raises(ValueError, match=f"Client header\\(s\\) `X-Static-Header` already registered"): + await client.add_headers(static_header) # Try adding the same header again @pytest.mark.asyncio - async def test_bind_param_async_callable_value_success(self, tool_name, client): + async def test_client_header_auth_token_conflict_fail(self, aioresponses, test_tool_auth): """ - Tests bind_parameters method with an async callable value. + Tests that loading a tool fails if a client header conflicts with an auth token name. """ + tool_name = "auth_conflict_tool" + conflict_key = "my-auth-service_token" + manifest = ManifestSchema(serverVersion="0.0.0", tools={tool_name: test_tool_auth}) - bound_value_result = True - bound_async_callable = AsyncMock(return_value=bound_value_result) - - tool = await client.load_tool(tool_name) - bound_tool = tool.bind_parameters({"argB": bound_async_callable}) + conflicting_headers = {conflict_key: "some_value"} + auth_getters = {"my-auth-service": lambda: "token_val"} - assert bound_tool is not tool - assert "argB" not in bound_tool.__signature__.parameters - assert "argA" in bound_tool.__signature__.parameters - - passed_value_a = 42 - res_payload = await bound_tool(argA=passed_value_a) + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{tool_name}", + payload=manifest.model_dump(), + status=200, + ) - assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} - bound_async_callable.assert_awaited_once() + async with ToolboxClient(TEST_BASE_URL, client_headers=conflicting_headers) as client: + with pytest.raises(ValueError, match=f"Client header\\(s\\) `{conflict_key}` already registered"): + await client.load_tool(tool_name, auth_token_getters=auth_getters) From abeb8decce167aa243a9e084aacf60726199331a Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 16:33:20 +0530 Subject: [PATCH 15/36] add client tests --- packages/toolbox-core/tests/test_client.py | 770 ++++++++++----------- 1 file changed, 385 insertions(+), 385 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 2e5de151..fa532200 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -66,391 +66,391 @@ def test_tool_auth(): ) -# @pytest.mark.asyncio -# async def test_load_tool_success(aioresponses, test_tool_str): -# """ -# Tests successfully loading a tool when the API returns a valid manifest. -# """ -# # Mock out responses from server -# TOOL_NAME = "test_tool_1" -# manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) -# aioresponses.get( -# f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", -# payload=manifest.model_dump(), -# status=200, -# ) -# aioresponses.post( -# f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", -# payload={"result": "ok"}, -# status=200, -# ) -# -# async with ToolboxClient(TEST_BASE_URL) as client: -# # Load a Tool -# loaded_tool = await client.load_tool(TOOL_NAME) -# -# # Assertions -# assert callable(loaded_tool) -# # Assert introspection attributes are set correctly -# assert loaded_tool.__name__ == TOOL_NAME -# expected_description = ( -# test_tool_str.description -# + f"\n\nArgs:\n param1 (str): Description of Param1" -# ) -# assert loaded_tool.__doc__ == expected_description -# -# # Assert signature inspection -# sig = inspect.signature(loaded_tool) -# assert list(sig.parameters.keys()) == [p.name for p in test_tool_str.parameters] -# -# assert await loaded_tool("some value") == "ok" -# -# -# @pytest.mark.asyncio -# async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): -# """Tests successfully loading a toolset with multiple tools.""" -# TOOLSET_NAME = "my_toolset" -# TOOL1 = "tool1" -# TOOL2 = "tool2" -# manifest = ManifestSchema( -# serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} -# ) -# aioresponses.get( -# f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", -# payload=manifest.model_dump(), -# status=200, -# ) -# -# async with ToolboxClient(TEST_BASE_URL) as client: -# tools = await client.load_toolset(TOOLSET_NAME) -# -# assert isinstance(tools, list) -# assert len(tools) == len(manifest.tools) -# -# # Check if tools were created correctly -# assert {t.__name__ for t in tools} == manifest.tools.keys() -# -# -# @pytest.mark.asyncio -# async def test_invoke_tool_server_error(aioresponses, test_tool_str): -# """Tests that invoking a tool raises an Exception when the server returns an -# error status.""" -# TOOL_NAME = "server_error_tool" -# ERROR_MESSAGE = "Simulated Server Error" -# manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) -# -# aioresponses.get( -# f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", -# payload=manifest.model_dump(), -# status=200, -# ) -# aioresponses.post( -# f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", -# payload={"error": ERROR_MESSAGE}, -# status=500, -# ) -# -# async with ToolboxClient(TEST_BASE_URL) as client: -# loaded_tool = await client.load_tool(TOOL_NAME) -# -# with pytest.raises(Exception, match=ERROR_MESSAGE): -# await loaded_tool(param1="some input") -# -# -# @pytest.mark.asyncio -# async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): -# """ -# Tests that load_tool raises an Exception when the requested tool name -# is not found in the manifest returned by the server, using existing fixtures. -# """ -# ACTUAL_TOOL_IN_MANIFEST = "actual_tool_abc" -# REQUESTED_TOOL_NAME = "non_existent_tool_xyz" -# -# manifest = ManifestSchema( -# serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str} -# ) -# -# aioresponses.get( -# f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", -# payload=manifest.model_dump(), -# status=200, -# ) -# -# async with ToolboxClient(TEST_BASE_URL) as client: -# with pytest.raises(Exception, match=f"Tool '{REQUESTED_TOOL_NAME}' not found!"): -# await client.load_tool(REQUESTED_TOOL_NAME) -# -# aioresponses.assert_called_once_with( -# f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET" -# ) -# -# -# class TestAuth: -# -# @pytest.fixture -# def expected_header(self): -# return "some_token_for_testing" -# -# @pytest.fixture -# def tool_name(self): -# return "tool1" -# -# @pytest_asyncio.fixture -# async def client(self, aioresponses, test_tool_auth, tool_name, expected_header): -# manifest = ManifestSchema( -# serverVersion="0.0.0", tools={tool_name: test_tool_auth} -# ) -# -# # mock tool GET call -# aioresponses.get( -# f"{TEST_BASE_URL}/api/tool/{tool_name}", -# payload=manifest.model_dump(), -# status=200, -# ) -# -# # mock tool INVOKE call -# def require_headers(url, **kwargs): -# if kwargs["headers"].get("my-auth-service_token") == expected_header: -# return CallbackResult(status=200, body="{}") -# else: -# return CallbackResult(status=400, body="{}") -# -# aioresponses.post( -# f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", -# payload=manifest.model_dump(), -# callback=require_headers, -# status=200, -# ) -# -# async with ToolboxClient(TEST_BASE_URL) as client: -# yield client -# -# @pytest.mark.asyncio -# async def test_auth_with_load_tool_success( -# self, tool_name, expected_header, client -# ): -# """Tests 'load_tool' with auth token is specified.""" -# -# def token_handler(): -# return expected_header -# -# tool = await client.load_tool( -# tool_name, auth_token_getters={"my-auth-service": token_handler} -# ) -# await tool(5) -# -# @pytest.mark.asyncio -# async def test_auth_with_add_token_success( -# self, tool_name, expected_header, client -# ): -# """Tests 'load_tool' with auth token is specified.""" -# -# def token_handler(): -# return expected_header -# -# tool = await client.load_tool(tool_name) -# tool = tool.add_auth_token_getters({"my-auth-service": token_handler}) -# await tool(5) -# -# @pytest.mark.asyncio -# async def test_auth_with_load_tool_fail_no_token( -# self, tool_name, expected_header, client -# ): -# """Tests 'load_tool' with auth token is specified.""" -# -# tool = await client.load_tool(tool_name) -# with pytest.raises(Exception): -# await tool(5) -# -# @pytest.mark.asyncio -# async def test_add_auth_token_getters_duplicate_fail(self, tool_name, client): -# """ -# Tests that adding a duplicate auth token getter raises ValueError. -# """ -# AUTH_SERVICE = "my-auth-service" -# -# tool = await client.load_tool(tool_name) -# -# authed_tool = tool.add_auth_token_getters({AUTH_SERVICE: {}}) -# assert AUTH_SERVICE in authed_tool._ToolboxTool__auth_service_token_getters -# -# with pytest.raises( -# ValueError, -# match=f"Authentication source\\(s\\) `{AUTH_SERVICE}` already registered in tool `{tool_name}`.", -# ): -# authed_tool.add_auth_token_getters({AUTH_SERVICE: {}}) -# -# -# class TestBoundParameter: -# -# @pytest.fixture -# def tool_name(self): -# return "tool1" -# -# @pytest_asyncio.fixture -# async def client(self, aioresponses, test_tool_int_bool, tool_name): -# manifest = ManifestSchema( -# serverVersion="0.0.0", tools={tool_name: test_tool_int_bool} -# ) -# -# # mock toolset GET call -# aioresponses.get( -# f"{TEST_BASE_URL}/api/toolset/", -# payload=manifest.model_dump(), -# status=200, -# ) -# -# # mock tool GET call -# aioresponses.get( -# f"{TEST_BASE_URL}/api/tool/{tool_name}", -# payload=manifest.model_dump(), -# status=200, -# ) -# -# # mock tool INVOKE call -# def reflect_parameters(url, **kwargs): -# body = {"result": kwargs["json"]} -# return CallbackResult(status=200, body=json.dumps(body)) -# -# aioresponses.post( -# f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", -# payload=manifest.model_dump(), -# callback=reflect_parameters, -# status=200, -# ) -# -# async with ToolboxClient(TEST_BASE_URL) as client: -# yield client -# -# @pytest.mark.asyncio -# async def test_load_tool_success(self, tool_name, client): -# """Tests 'load_tool' with a bound parameter specified.""" -# tool = await client.load_tool(tool_name, bound_params={"argA": lambda: 5}) -# -# assert len(tool.__signature__.parameters) == 1 -# assert "argA" not in tool.__signature__.parameters -# -# res = await tool(True) -# assert "argA" in res -# -# @pytest.mark.asyncio -# async def test_load_toolset_success(self, tool_name, client): -# """Tests 'load_toolset' with a bound parameter specified.""" -# tools = await client.load_toolset("", bound_params={"argB": lambda: "hello"}) -# tool = tools[0] -# -# assert len(tool.__signature__.parameters) == 1 -# assert "argB" not in tool.__signature__.parameters -# -# res = await tool(True) -# assert "argB" in res -# -# @pytest.mark.asyncio -# async def test_bind_param_success(self, tool_name, client): -# """Tests 'bind_param' with a bound parameter specified.""" -# tool = await client.load_tool(tool_name) -# -# assert len(tool.__signature__.parameters) == 2 -# assert "argA" in tool.__signature__.parameters -# -# tool = tool.bind_parameters({"argA": 5}) -# -# assert len(tool.__signature__.parameters) == 1 -# assert "argA" not in tool.__signature__.parameters -# -# res = await tool(True) -# assert "argA" in res -# -# @pytest.mark.asyncio -# async def test_bind_callable_param_success(self, tool_name, client): -# """Tests 'bind_param' with a bound parameter specified.""" -# tool = await client.load_tool(tool_name) -# -# assert len(tool.__signature__.parameters) == 2 -# assert "argA" in tool.__signature__.parameters -# -# tool = tool.bind_parameters({"argA": lambda: 5}) -# -# assert len(tool.__signature__.parameters) == 1 -# assert "argA" not in tool.__signature__.parameters -# -# res = await tool(True) -# assert "argA" in res -# -# @pytest.mark.asyncio -# async def test_bind_param_fail(self, tool_name, client): -# """Tests 'bind_param' with a bound parameter that doesn't exist.""" -# tool = await client.load_tool(tool_name) -# -# assert len(tool.__signature__.parameters) == 2 -# assert "argA" in tool.__signature__.parameters -# -# with pytest.raises(Exception): -# tool = tool.bind_parameters({"argC": lambda: 5}) -# -# @pytest.mark.asyncio -# async def test_bind_param_static_value_success(self, tool_name, client): -# """ -# Tests bind_parameters method with a static value. -# """ -# -# bound_value = "Test value" -# -# tool = await client.load_tool(tool_name) -# bound_tool = tool.bind_parameters({"argB": bound_value}) -# -# assert bound_tool is not tool -# assert "argB" not in bound_tool.__signature__.parameters -# assert "argA" in bound_tool.__signature__.parameters -# -# passed_value_a = 42 -# res_payload = await bound_tool(argA=passed_value_a) -# -# assert res_payload == {"argA": passed_value_a, "argB": bound_value} -# -# @pytest.mark.asyncio -# async def test_bind_param_sync_callable_value_success(self, tool_name, client): -# """ -# Tests bind_parameters method with a sync callable value. -# """ -# -# bound_value_result = True -# bound_sync_callable = Mock(return_value=bound_value_result) -# -# tool = await client.load_tool(tool_name) -# bound_tool = tool.bind_parameters({"argB": bound_sync_callable}) -# -# assert bound_tool is not tool -# assert "argB" not in bound_tool.__signature__.parameters -# assert "argA" in bound_tool.__signature__.parameters -# -# passed_value_a = 42 -# res_payload = await bound_tool(argA=passed_value_a) -# -# assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} -# bound_sync_callable.assert_called_once() -# -# @pytest.mark.asyncio -# async def test_bind_param_async_callable_value_success(self, tool_name, client): -# """ -# Tests bind_parameters method with an async callable value. -# """ -# -# bound_value_result = True -# bound_async_callable = AsyncMock(return_value=bound_value_result) -# -# tool = await client.load_tool(tool_name) -# bound_tool = tool.bind_parameters({"argB": bound_async_callable}) -# -# assert bound_tool is not tool -# assert "argB" not in bound_tool.__signature__.parameters -# assert "argA" in bound_tool.__signature__.parameters -# -# passed_value_a = 42 -# res_payload = await bound_tool(argA=passed_value_a) -# -# assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} -# bound_async_callable.assert_awaited_once() +@pytest.mark.asyncio +async def test_load_tool_success(aioresponses, test_tool_str): + """ + Tests successfully loading a tool when the API returns a valid manifest. + """ + # Mock out responses from server + TOOL_NAME = "test_tool_1" + manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", + payload=manifest.model_dump(), + status=200, + ) + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", + payload={"result": "ok"}, + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + # Load a Tool + loaded_tool = await client.load_tool(TOOL_NAME) + + # Assertions + assert callable(loaded_tool) + # Assert introspection attributes are set correctly + assert loaded_tool.__name__ == TOOL_NAME + expected_description = ( + test_tool_str.description + + f"\n\nArgs:\n param1 (str): Description of Param1" + ) + assert loaded_tool.__doc__ == expected_description + + # Assert signature inspection + sig = inspect.signature(loaded_tool) + assert list(sig.parameters.keys()) == [p.name for p in test_tool_str.parameters] + + assert await loaded_tool("some value") == "ok" + + +@pytest.mark.asyncio +async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): + """Tests successfully loading a toolset with multiple tools.""" + TOOLSET_NAME = "my_toolset" + TOOL1 = "tool1" + TOOL2 = "tool2" + manifest = ManifestSchema( + serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} + ) + aioresponses.get( + f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", + payload=manifest.model_dump(), + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + tools = await client.load_toolset(TOOLSET_NAME) + + assert isinstance(tools, list) + assert len(tools) == len(manifest.tools) + + # Check if tools were created correctly + assert {t.__name__ for t in tools} == manifest.tools.keys() + + +@pytest.mark.asyncio +async def test_invoke_tool_server_error(aioresponses, test_tool_str): + """Tests that invoking a tool raises an Exception when the server returns an + error status.""" + TOOL_NAME = "server_error_tool" + ERROR_MESSAGE = "Simulated Server Error" + manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) + + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", + payload=manifest.model_dump(), + status=200, + ) + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", + payload={"error": ERROR_MESSAGE}, + status=500, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + loaded_tool = await client.load_tool(TOOL_NAME) + + with pytest.raises(Exception, match=ERROR_MESSAGE): + await loaded_tool(param1="some input") + + +@pytest.mark.asyncio +async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): + """ + Tests that load_tool raises an Exception when the requested tool name + is not found in the manifest returned by the server, using existing fixtures. + """ + ACTUAL_TOOL_IN_MANIFEST = "actual_tool_abc" + REQUESTED_TOOL_NAME = "non_existent_tool_xyz" + + manifest = ManifestSchema( + serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str} + ) + + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", + payload=manifest.model_dump(), + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + with pytest.raises(Exception, match=f"Tool '{REQUESTED_TOOL_NAME}' not found!"): + await client.load_tool(REQUESTED_TOOL_NAME) + + aioresponses.assert_called_once_with( + f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET" + ) + + +class TestAuth: + + @pytest.fixture + def expected_header(self): + return "some_token_for_testing" + + @pytest.fixture + def tool_name(self): + return "tool1" + + @pytest_asyncio.fixture + async def client(self, aioresponses, test_tool_auth, tool_name, expected_header): + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_auth} + ) + + # mock tool GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{tool_name}", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool INVOKE call + def require_headers(url, **kwargs): + if kwargs["headers"].get("my-auth-service_token") == expected_header: + return CallbackResult(status=200, body="{}") + else: + return CallbackResult(status=400, body="{}") + + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", + payload=manifest.model_dump(), + callback=require_headers, + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + yield client + + @pytest.mark.asyncio + async def test_auth_with_load_tool_success( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool( + tool_name, auth_token_getters={"my-auth-service": token_handler} + ) + await tool(5) + + @pytest.mark.asyncio + async def test_auth_with_add_token_success( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool(tool_name) + tool = tool.add_auth_token_getters({"my-auth-service": token_handler}) + await tool(5) + + @pytest.mark.asyncio + async def test_auth_with_load_tool_fail_no_token( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + tool = await client.load_tool(tool_name) + with pytest.raises(Exception): + await tool(5) + + @pytest.mark.asyncio + async def test_add_auth_token_getters_duplicate_fail(self, tool_name, client): + """ + Tests that adding a duplicate auth token getter raises ValueError. + """ + AUTH_SERVICE = "my-auth-service" + + tool = await client.load_tool(tool_name) + + authed_tool = tool.add_auth_token_getters({AUTH_SERVICE: {}}) + assert AUTH_SERVICE in authed_tool._ToolboxTool__auth_service_token_getters + + with pytest.raises( + ValueError, + match=f"Authentication source\\(s\\) `{AUTH_SERVICE}` already registered in tool `{tool_name}`.", + ): + authed_tool.add_auth_token_getters({AUTH_SERVICE: {}}) + + +class TestBoundParameter: + + @pytest.fixture + def tool_name(self): + return "tool1" + + @pytest_asyncio.fixture + async def client(self, aioresponses, test_tool_int_bool, tool_name): + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_int_bool} + ) + + # mock toolset GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/toolset/", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{tool_name}", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool INVOKE call + def reflect_parameters(url, **kwargs): + body = {"result": kwargs["json"]} + return CallbackResult(status=200, body=json.dumps(body)) + + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", + payload=manifest.model_dump(), + callback=reflect_parameters, + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + yield client + + @pytest.mark.asyncio + async def test_load_tool_success(self, tool_name, client): + """Tests 'load_tool' with a bound parameter specified.""" + tool = await client.load_tool(tool_name, bound_params={"argA": lambda: 5}) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_load_toolset_success(self, tool_name, client): + """Tests 'load_toolset' with a bound parameter specified.""" + tools = await client.load_toolset("", bound_params={"argB": lambda: "hello"}) + tool = tools[0] + + assert len(tool.__signature__.parameters) == 1 + assert "argB" not in tool.__signature__.parameters + + res = await tool(True) + assert "argB" in res + + @pytest.mark.asyncio + async def test_bind_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool = tool.bind_parameters({"argA": 5}) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_callable_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool = tool.bind_parameters({"argA": lambda: 5}) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_param_fail(self, tool_name, client): + """Tests 'bind_param' with a bound parameter that doesn't exist.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + with pytest.raises(Exception): + tool = tool.bind_parameters({"argC": lambda: 5}) + + @pytest.mark.asyncio + async def test_bind_param_static_value_success(self, tool_name, client): + """ + Tests bind_parameters method with a static value. + """ + + bound_value = "Test value" + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_parameters({"argB": bound_value}) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value} + + @pytest.mark.asyncio + async def test_bind_param_sync_callable_value_success(self, tool_name, client): + """ + Tests bind_parameters method with a sync callable value. + """ + + bound_value_result = True + bound_sync_callable = Mock(return_value=bound_value_result) + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_parameters({"argB": bound_sync_callable}) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} + bound_sync_callable.assert_called_once() + + @pytest.mark.asyncio + async def test_bind_param_async_callable_value_success(self, tool_name, client): + """ + Tests bind_parameters method with an async callable value. + """ + + bound_value_result = True + bound_async_callable = AsyncMock(return_value=bound_value_result) + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_parameters({"argB": bound_async_callable}) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} + bound_async_callable.assert_awaited_once() class TestClientHeaders: From 45c0fc6d0901a96be1feacc87df95922f09285a6 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 16:35:40 +0530 Subject: [PATCH 16/36] fix tests --- packages/toolbox-core/tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index fa532200..ea46990e 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -181,7 +181,7 @@ async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): await client.load_tool(REQUESTED_TOOL_NAME) aioresponses.assert_called_once_with( - f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET" + f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET", headers={} ) From 769301140d434cfca2463d5343cccd2c72a2752d Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 16:36:05 +0530 Subject: [PATCH 17/36] fix --- packages/toolbox-core/src/toolbox_core/tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 901275ef..9c977df2 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -202,7 +202,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): headers[f"{auth_service}_token"] = await resolve_value(token_getter) - for client_header_name, client_header_val in self.__client_headers: + for client_header_name, client_header_val in self.__client_headers.items(): headers[client_header_name] = await resolve_value(client_header_val) async with self.__session.post( From a43ffadc74f11b7206c52cb17086dc7eed418e4f Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 16:36:38 +0530 Subject: [PATCH 18/36] lint --- packages/toolbox-core/tests/test_client.py | 100 +++++++++++++++------ 1 file changed, 73 insertions(+), 27 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index ea46990e..badf8e61 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -520,11 +520,18 @@ def post_callback(url, **kwargs): assert result == expected_payload["result"] @pytest.mark.asyncio - async def test_load_tool_with_sync_callable_headers(self, aioresponses, test_tool_str, sync_callable_header, - sync_callable_header_value): + async def test_load_tool_with_sync_callable_headers( + self, + aioresponses, + test_tool_str, + sync_callable_header, + sync_callable_header_value, + ): """Tests loading and invoking a tool with sync callable client headers.""" tool_name = "tool_with_sync_callable_headers" - manifest = ManifestSchema(serverVersion="0.0.0", tools={tool_name: test_tool_str}) + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_str} + ) expected_payload = {"result": "ok_sync"} header_key = list(sync_callable_header.keys())[0] header_mock = sync_callable_header[header_key] @@ -544,9 +551,13 @@ def post_callback(url, **kwargs): assert kwargs.get("headers") == resolved_header return CallbackResult(status=200, payload=expected_payload) - aioresponses.post(f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback) + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback + ) - async with ToolboxClient(TEST_BASE_URL, client_headers=sync_callable_header) as client: + async with ToolboxClient( + TEST_BASE_URL, client_headers=sync_callable_header + ) as client: tool = await client.load_tool(tool_name) header_mock.assert_called_once() # GET @@ -557,15 +568,20 @@ def post_callback(url, **kwargs): header_mock.assert_called_once() # POST/invoke @pytest.mark.asyncio - async def test_load_tool_with_async_callable_headers(self, aioresponses, test_tool_str, - async_callable_header, - # Use the header fixture (provides mock) - async_callable_header_value - # Use the value fixture for expected result - ): + async def test_load_tool_with_async_callable_headers( + self, + aioresponses, + test_tool_str, + async_callable_header, + # Use the header fixture (provides mock) + async_callable_header_value, + # Use the value fixture for expected result + ): """Tests loading and invoking a tool with async callable client headers.""" tool_name = "tool_with_async_callable_headers" - manifest = ManifestSchema(serverVersion="0.0.0", tools={tool_name: test_tool_str}) + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_str} + ) expected_payload = {"result": "ok_async"} header_key = list(async_callable_header.keys())[0] @@ -586,12 +602,16 @@ def post_callback(url, **kwargs): assert kwargs.get("headers") == resolved_header return CallbackResult(status=200, payload=expected_payload) - aioresponses.post(f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback) + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback + ) # Pass the dictionary containing the AsyncMock to the client - async with ToolboxClient(TEST_BASE_URL, client_headers=async_callable_header) as client: + async with ToolboxClient( + TEST_BASE_URL, client_headers=async_callable_header + ) as client: tool = await client.load_tool(tool_name) - header_mock.assert_awaited_once() # GET + header_mock.assert_awaited_once() # GET header_mock.reset_mock() @@ -600,11 +620,15 @@ def post_callback(url, **kwargs): header_mock.assert_awaited_once() # POST/invoke @pytest.mark.asyncio - async def test_load_toolset_with_headers(self, aioresponses, test_tool_str, static_header): + async def test_load_toolset_with_headers( + self, aioresponses, test_tool_str, static_header + ): """Tests loading a toolset with client headers.""" toolset_name = "toolset_with_headers" tool_name = "tool_in_set" - manifest = ManifestSchema(serverVersion="0.0.0", tools={tool_name: test_tool_str}) + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_str} + ) # Mock GET def get_callback(url, **kwargs): @@ -612,7 +636,9 @@ def get_callback(url, **kwargs): assert kwargs.get("headers") == static_header return CallbackResult(status=200, payload=manifest.model_dump()) - aioresponses.get(f"{TEST_BASE_URL}/api/toolset/{toolset_name}", callback=get_callback) + aioresponses.get( + f"{TEST_BASE_URL}/api/toolset/{toolset_name}", callback=get_callback + ) async with ToolboxClient(TEST_BASE_URL, client_headers=static_header) as client: tools = await client.load_toolset(toolset_name) @@ -620,10 +646,14 @@ def get_callback(url, **kwargs): assert tools[0].__name__ == tool_name @pytest.mark.asyncio - async def test_add_headers_success(self, aioresponses, test_tool_str, static_header): + async def test_add_headers_success( + self, aioresponses, test_tool_str, static_header + ): """Tests adding headers after client initialization.""" tool_name = "tool_after_add_headers" - manifest = ManifestSchema(serverVersion="0.0.0", tools={tool_name: test_tool_str}) + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_str} + ) expected_payload = {"result": "added_ok"} # Mock GET for tool definition - check headers @@ -639,7 +669,9 @@ def post_callback(url, **kwargs): assert kwargs.get("headers") == static_header return CallbackResult(status=200, payload=expected_payload) - aioresponses.post(f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback) + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback + ) async with ToolboxClient(TEST_BASE_URL) as client: await client.add_headers(static_header) @@ -653,17 +685,26 @@ def post_callback(url, **kwargs): async def test_add_headers_duplicate_fail(self, static_header): """Tests that adding a duplicate header via add_headers raises ValueError.""" async with ToolboxClient(TEST_BASE_URL, client_headers=static_header) as client: - with pytest.raises(ValueError, match=f"Client header\\(s\\) `X-Static-Header` already registered"): - await client.add_headers(static_header) # Try adding the same header again + with pytest.raises( + ValueError, + match=f"Client header\\(s\\) `X-Static-Header` already registered", + ): + await client.add_headers( + static_header + ) # Try adding the same header again @pytest.mark.asyncio - async def test_client_header_auth_token_conflict_fail(self, aioresponses, test_tool_auth): + async def test_client_header_auth_token_conflict_fail( + self, aioresponses, test_tool_auth + ): """ Tests that loading a tool fails if a client header conflicts with an auth token name. """ tool_name = "auth_conflict_tool" conflict_key = "my-auth-service_token" - manifest = ManifestSchema(serverVersion="0.0.0", tools={tool_name: test_tool_auth}) + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_auth} + ) conflicting_headers = {conflict_key: "some_value"} auth_getters = {"my-auth-service": lambda: "token_val"} @@ -674,6 +715,11 @@ async def test_client_header_auth_token_conflict_fail(self, aioresponses, test_t status=200, ) - async with ToolboxClient(TEST_BASE_URL, client_headers=conflicting_headers) as client: - with pytest.raises(ValueError, match=f"Client header\\(s\\) `{conflict_key}` already registered"): + async with ToolboxClient( + TEST_BASE_URL, client_headers=conflicting_headers + ) as client: + with pytest.raises( + ValueError, + match=f"Client header\\(s\\) `{conflict_key}` already registered", + ): await client.load_tool(tool_name, auth_token_getters=auth_getters) From f564e603693f200e6810e228f1bbd56f04333235 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 16:51:38 +0530 Subject: [PATCH 19/36] fix tests --- packages/toolbox-core/tests/test_tools.py | 155 ++++++++++++++++++++-- 1 file changed, 147 insertions(+), 8 deletions(-) diff --git a/packages/toolbox-core/tests/test_tools.py b/packages/toolbox-core/tests/test_tools.py index 7cb4f305..5554cc54 100644 --- a/packages/toolbox-core/tests/test_tools.py +++ b/packages/toolbox-core/tests/test_tools.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -from typing import AsyncGenerator +import inspect +from inspect import iscoroutine +from typing import AsyncGenerator, Callable from unittest.mock import AsyncMock, Mock import pytest @@ -40,6 +40,20 @@ def sample_tool_params() -> list[ParameterSchema]: ] +@pytest.fixture +def sample_tool_auth_params() -> list[ParameterSchema]: + """Parameters for a sample tool requiring authentication.""" + return [ + ParameterSchema(name="target", type="string", description="Target system"), + ParameterSchema( + name="token", + type="string", + description="Auth token", + authSources=["test-auth"], + ), + ] + + @pytest.fixture def sample_tool_description() -> str: """Description for the sample tool.""" @@ -53,6 +67,61 @@ async def http_session() -> AsyncGenerator[ClientSession, None]: yield session +# --- Fixtures for Client Headers --- + + +@pytest.fixture +def static_client_header() -> dict[str, str]: + return {"X-Client-Static": "client-static-value"} + + +@pytest.fixture +def sync_callable_client_header_value() -> str: + return "client-sync-callable-value" + + +@pytest.fixture +def sync_callable_client_header(sync_callable_client_header_value) -> dict[str, Mock]: + return {"X-Client-Sync": Mock(return_value=sync_callable_client_header_value)} + + +@pytest.fixture +def async_callable_client_header_value() -> str: + return "client-async-callable-value" + + +@pytest.fixture +def async_callable_client_header( + async_callable_client_header_value, +) -> dict[str, AsyncMock]: + return { + "X-Client-Async": AsyncMock(return_value=async_callable_client_header_value) + } + + +# --- Fixtures for Auth Getters --- + + +@pytest.fixture +def auth_token_value() -> str: + return "auth-token-123" + + +@pytest.fixture +def auth_getters(auth_token_value) -> dict[str, Callable[[], str]]: + return {"test-auth": lambda: auth_token_value} + + +@pytest.fixture +def auth_getters_mock(auth_token_value) -> dict[str, Mock]: + return {"test-auth": Mock(return_value=auth_token_value)} + + +@pytest.fixture +def auth_header_key() -> str: + return "test-auth_token" + + def test_create_func_docstring_one_param_real_schema(): """ Tests create_func_docstring with one real ParameterSchema instance. @@ -168,6 +237,7 @@ async def test_tool_creation_callable_and_run( required_authn_params={}, auth_service_token_getters={}, bound_params={}, + client_headers={}, ) assert callable(tool_instance), "ToolboxTool instance should be callable" @@ -212,17 +282,15 @@ async def test_tool_run_with_pydantic_validation_error( required_authn_params={}, auth_service_token_getters={}, bound_params={}, + client_headers={}, ) assert callable(tool_instance) - with pytest.raises(ValidationError) as exc_info: + expected_pattern = r"1 validation error for sample_tool\ncount\n Input should be a valid integer, unable to parse string as an integer \[\s*type=int_parsing,\s*input_value='not-a-number',\s*input_type=str\s*\]*" + with pytest.raises(ValidationError, match=expected_pattern): await tool_instance(message="hello", count="not-a-number") - assert ( - "1 validation error for sample_tool\ncount\n Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='not-a-number', input_type=str]\n For further information visit https://errors.pydantic.dev/2.11/v/int_parsing" - in str(exc_info.value) - ) m.assert_not_called() @@ -285,3 +353,74 @@ async def test_resolve_value_async_callable(): async_callable.assert_awaited_once() assert resolved == expected_value + + +# --- Tests for ToolboxTool Initialization and Validation --- + + +def test_tool_init_basic(http_session, sample_tool_params, sample_tool_description): + """Tests basic tool initialization without headers or auth.""" + tool_instance = ToolboxTool( + session=http_session, + base_url=TEST_BASE_URL, + name=TEST_TOOL_NAME, + description=sample_tool_description, + params=sample_tool_params, + required_authn_params={}, + auth_service_token_getters={}, + bound_params={}, + client_headers={}, + ) + assert tool_instance.__name__ == TEST_TOOL_NAME + assert inspect.iscoroutinefunction(tool_instance.__call__) + assert "message" in tool_instance.__signature__.parameters + assert "count" in tool_instance.__signature__.parameters + + assert tool_instance._ToolboxTool__client_headers == {} + assert tool_instance._ToolboxTool__auth_service_token_getters == {} + + +def test_tool_init_with_client_headers( + http_session, sample_tool_params, sample_tool_description, static_client_header +): + """Tests tool initialization *with* client headers.""" + tool_instance = ToolboxTool( + session=http_session, + base_url=TEST_BASE_URL, + name=TEST_TOOL_NAME, + description=sample_tool_description, + params=sample_tool_params, + required_authn_params={}, + auth_service_token_getters={}, + bound_params={}, + client_headers=static_client_header, # Pass client headers + ) + assert tool_instance._ToolboxTool__client_headers == static_client_header + + +def test_tool_init_header_auth_conflict( + http_session, + sample_tool_auth_params, + sample_tool_description, + auth_getters, + auth_header_key, +): + """Tests ValueError on init if client header conflicts with auth token.""" + conflicting_client_header = { + auth_header_key: "some-client-value" + } # Header name matches derived auth token name + + with pytest.raises( + ValueError, match=f"Client header\\(s\\) `{auth_header_key}` already registered" + ): + ToolboxTool( + session=http_session, + base_url=TEST_BASE_URL, + name="auth_conflict_tool", + description=sample_tool_description, + params=sample_tool_auth_params, + required_authn_params={}, # Assume auth req satisfied for init check + auth_service_token_getters=auth_getters, # Has 'test-auth' + bound_params={}, + client_headers=conflicting_client_header, # Conflicting header + ) From 2361cdf99800007ee426b2784de6c5bbf466011e Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 16:52:50 +0530 Subject: [PATCH 20/36] cleanup --- packages/toolbox-core/tests/test_tools.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/toolbox-core/tests/test_tools.py b/packages/toolbox-core/tests/test_tools.py index 5554cc54..7b530849 100644 --- a/packages/toolbox-core/tests/test_tools.py +++ b/packages/toolbox-core/tests/test_tools.py @@ -393,7 +393,7 @@ def test_tool_init_with_client_headers( required_authn_params={}, auth_service_token_getters={}, bound_params={}, - client_headers=static_client_header, # Pass client headers + client_headers=static_client_header, ) assert tool_instance._ToolboxTool__client_headers == static_client_header @@ -408,7 +408,7 @@ def test_tool_init_header_auth_conflict( """Tests ValueError on init if client header conflicts with auth token.""" conflicting_client_header = { auth_header_key: "some-client-value" - } # Header name matches derived auth token name + } with pytest.raises( ValueError, match=f"Client header\\(s\\) `{auth_header_key}` already registered" @@ -419,8 +419,8 @@ def test_tool_init_header_auth_conflict( name="auth_conflict_tool", description=sample_tool_description, params=sample_tool_auth_params, - required_authn_params={}, # Assume auth req satisfied for init check - auth_service_token_getters=auth_getters, # Has 'test-auth' + required_authn_params={}, + auth_service_token_getters=auth_getters, bound_params={}, - client_headers=conflicting_client_header, # Conflicting header + client_headers=conflicting_client_header, ) From d46bfd0f20b86c7c93daf84ad425fd9c054b9c9f Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 17:06:35 +0530 Subject: [PATCH 21/36] cleanup --- packages/toolbox-core/tests/test_client.py | 106 ++++++++++++++------- 1 file changed, 73 insertions(+), 33 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index badf8e61..6f47c36b 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -17,9 +17,12 @@ import json from unittest.mock import AsyncMock, Mock +import aioresponses import pytest + +from typing import Optional, Callable, Mapping, Any import pytest_asyncio -from aioresponses import CallbackResult +from aioresponses import CallbackResult, aioresponses from toolbox_core import ToolboxClient from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema @@ -65,6 +68,65 @@ def test_tool_auth(): ], ) +# --- Helper Functions for Mocking --- + +def mock_tool_load( + aio_resp: aioresponses, + tool_name: str, + tool_schema: ToolSchema, + base_url: str = TEST_BASE_URL, + server_version: str = "0.0.0", + status: int = 200, + callback: Optional[Callable] = None, +): + """Mocks the GET /api/tool/{tool_name} endpoint.""" + url = f"{base_url}/api/tool/{tool_name}" + manifest = ManifestSchema(serverVersion=server_version, tools={tool_name: tool_schema}) + aio_resp.get( + url, + payload=manifest.model_dump(), + status=status, + callback=callback, + ) + +def mock_toolset_load( + aio_resp: aioresponses, + toolset_name: str, + tools_dict: Mapping[str, ToolSchema], + base_url: str = TEST_BASE_URL, + server_version: str = "0.0.0", + status: int = 200, + callback: Optional[Callable] = None, +): + """Mocks the GET /api/toolset/{toolset_name} endpoint.""" + # Handle default toolset name (empty string) + url_path = f"toolset/{toolset_name}" if toolset_name else "toolset/" + url = f"{base_url}/api/{url_path}" + manifest = ManifestSchema(serverVersion=server_version, tools=tools_dict) + aio_resp.get( + url, + payload=manifest.model_dump(), + status=status, + callback=callback, + ) + +def mock_tool_invoke( + aio_resp: aioresponses, + tool_name: str, + base_url: str = TEST_BASE_URL, + response_payload: Any = {"result": "ok"}, + status: int = 200, + callback: Optional[Callable] = None, +): + """Mocks the POST /api/tool/{tool_name}/invoke endpoint.""" + url = f"{base_url}/api/tool/{tool_name}/invoke" + aio_resp.post( + url, + payload=response_payload, + status=status, + callback=callback, + ) + @pytest.mark.asyncio async def test_load_tool_success(aioresponses, test_tool_str): @@ -73,17 +135,8 @@ async def test_load_tool_success(aioresponses, test_tool_str): """ # Mock out responses from server TOOL_NAME = "test_tool_1" - manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) - aioresponses.get( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", - payload=manifest.model_dump(), - status=200, - ) - aioresponses.post( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", - payload={"result": "ok"}, - status=200, - ) + mock_tool_load(aioresponses, TOOL_NAME, test_tool_str) + mock_tool_invoke(aioresponses, TOOL_NAME) async with ToolboxClient(TEST_BASE_URL) as client: # Load a Tool @@ -115,11 +168,8 @@ async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_b manifest = ManifestSchema( serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} ) - aioresponses.get( - f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", - payload=manifest.model_dump(), - status=200, - ) + mock_toolset_load(aioresponses, TOOLSET_NAME, manifest.tools) + async with ToolboxClient(TEST_BASE_URL) as client: tools = await client.load_toolset(TOOLSET_NAME) @@ -137,18 +187,9 @@ async def test_invoke_tool_server_error(aioresponses, test_tool_str): error status.""" TOOL_NAME = "server_error_tool" ERROR_MESSAGE = "Simulated Server Error" - manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) - aioresponses.get( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", - payload=manifest.model_dump(), - status=200, - ) - aioresponses.post( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", - payload={"error": ERROR_MESSAGE}, - status=500, - ) + mock_tool_load(aioresponses, TOOL_NAME, test_tool_str) + mock_tool_invoke(aioresponses, TOOL_NAME, response_payload={"error": ERROR_MESSAGE}, status=500) async with ToolboxClient(TEST_BASE_URL) as client: loaded_tool = await client.load_tool(TOOL_NAME) @@ -170,11 +211,10 @@ async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str} ) - aioresponses.get( - f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", - payload=manifest.model_dump(), - status=200, - ) + mock_tool_load(aioresponses, REQUESTED_TOOL_NAME, test_tool_str, server_version="0.0.0") + url = f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}" + aioresponses.get(url, payload=manifest.model_dump(), + status=200) async with ToolboxClient(TEST_BASE_URL) as client: with pytest.raises(Exception, match=f"Tool '{REQUESTED_TOOL_NAME}' not found!"): From f2f5cd2ef337895bc822a7eaf4f43e5321facad0 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 17:13:35 +0530 Subject: [PATCH 22/36] lint --- packages/toolbox-core/tests/test_tools.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/toolbox-core/tests/test_tools.py b/packages/toolbox-core/tests/test_tools.py index 7b530849..8f62baf5 100644 --- a/packages/toolbox-core/tests/test_tools.py +++ b/packages/toolbox-core/tests/test_tools.py @@ -406,9 +406,7 @@ def test_tool_init_header_auth_conflict( auth_header_key, ): """Tests ValueError on init if client header conflicts with auth token.""" - conflicting_client_header = { - auth_header_key: "some-client-value" - } + conflicting_client_header = {auth_header_key: "some-client-value"} with pytest.raises( ValueError, match=f"Client header\\(s\\) `{auth_header_key}` already registered" From 0fd1ca260da6e3435cda6b984c3deaad9eff3600 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 18:25:08 +0530 Subject: [PATCH 23/36] fix --- packages/toolbox-core/tests/test_client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 6f47c36b..73319003 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -211,7 +211,6 @@ async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str} ) - mock_tool_load(aioresponses, REQUESTED_TOOL_NAME, test_tool_str, server_version="0.0.0") url = f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}" aioresponses.get(url, payload=manifest.model_dump(), status=200) @@ -512,7 +511,6 @@ def async_callable_header_value(self): @pytest.fixture def async_callable_header(self, async_callable_header_value): - # This mock, when awaited by the client, will return the value. return { "X-Async-Callable-Header": AsyncMock( return_value=async_callable_header_value From 516d1514369143e796e012a9c8506efc020ca38f Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 18:26:58 +0530 Subject: [PATCH 24/36] cleanup --- packages/toolbox-core/tests/test_client.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 73319003..dcc619c8 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -611,9 +611,7 @@ async def test_load_tool_with_async_callable_headers( aioresponses, test_tool_str, async_callable_header, - # Use the header fixture (provides mock) async_callable_header_value, - # Use the value fixture for expected result ): """Tests loading and invoking a tool with async callable client headers.""" tool_name = "tool_with_async_callable_headers" @@ -628,14 +626,14 @@ async def test_load_tool_with_async_callable_headers( # Calculate expected result using the VALUE fixture resolved_header = {header_key: async_callable_header_value} - # Mock GET callback checks against the RESOLVED value + # Mock GET def get_callback(url, **kwargs): assert kwargs.get("headers") == resolved_header return CallbackResult(status=200, payload=manifest.model_dump()) aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) - # Mock POST callback checks against the RESOLVED value + # Mock POST def post_callback(url, **kwargs): assert kwargs.get("headers") == resolved_header return CallbackResult(status=200, payload=expected_payload) @@ -644,7 +642,6 @@ def post_callback(url, **kwargs): f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback ) - # Pass the dictionary containing the AsyncMock to the client async with ToolboxClient( TEST_BASE_URL, client_headers=async_callable_header ) as client: @@ -694,9 +691,10 @@ async def test_add_headers_success( ) expected_payload = {"result": "added_ok"} - # Mock GET for tool definition - check headers + # Mock GET def get_callback(url, **kwargs): - assert kwargs.get("headers") == static_header # Verify headers in GET + # Verify headers + assert kwargs.get("headers") == static_header return CallbackResult(status=200, payload=manifest.model_dump()) aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) @@ -729,7 +727,7 @@ async def test_add_headers_duplicate_fail(self, static_header): ): await client.add_headers( static_header - ) # Try adding the same header again + ) @pytest.mark.asyncio async def test_client_header_auth_token_conflict_fail( From 80b40d7216e11ca475e562459906a9954d37a5c5 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 18:33:45 +0530 Subject: [PATCH 25/36] lint --- packages/toolbox-core/tests/test_client.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index dcc619c8..ea23fbfd 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -68,8 +68,10 @@ def test_tool_auth(): ], ) + # --- Helper Functions for Mocking --- + def mock_tool_load( aio_resp: aioresponses, tool_name: str, @@ -81,7 +83,9 @@ def mock_tool_load( ): """Mocks the GET /api/tool/{tool_name} endpoint.""" url = f"{base_url}/api/tool/{tool_name}" - manifest = ManifestSchema(serverVersion=server_version, tools={tool_name: tool_schema}) + manifest = ManifestSchema( + serverVersion=server_version, tools={tool_name: tool_schema} + ) aio_resp.get( url, payload=manifest.model_dump(), @@ -89,6 +93,7 @@ def mock_tool_load( callback=callback, ) + def mock_toolset_load( aio_resp: aioresponses, toolset_name: str, @@ -110,6 +115,7 @@ def mock_toolset_load( callback=callback, ) + def mock_tool_invoke( aio_resp: aioresponses, tool_name: str, @@ -170,7 +176,6 @@ async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_b ) mock_toolset_load(aioresponses, TOOLSET_NAME, manifest.tools) - async with ToolboxClient(TEST_BASE_URL) as client: tools = await client.load_toolset(TOOLSET_NAME) @@ -189,7 +194,9 @@ async def test_invoke_tool_server_error(aioresponses, test_tool_str): ERROR_MESSAGE = "Simulated Server Error" mock_tool_load(aioresponses, TOOL_NAME, test_tool_str) - mock_tool_invoke(aioresponses, TOOL_NAME, response_payload={"error": ERROR_MESSAGE}, status=500) + mock_tool_invoke( + aioresponses, TOOL_NAME, response_payload={"error": ERROR_MESSAGE}, status=500 + ) async with ToolboxClient(TEST_BASE_URL) as client: loaded_tool = await client.load_tool(TOOL_NAME) @@ -212,8 +219,7 @@ async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): ) url = f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}" - aioresponses.get(url, payload=manifest.model_dump(), - status=200) + aioresponses.get(url, payload=manifest.model_dump(), status=200) async with ToolboxClient(TEST_BASE_URL) as client: with pytest.raises(Exception, match=f"Tool '{REQUESTED_TOOL_NAME}' not found!"): @@ -725,9 +731,7 @@ async def test_add_headers_duplicate_fail(self, static_header): ValueError, match=f"Client header\\(s\\) `X-Static-Header` already registered", ): - await client.add_headers( - static_header - ) + await client.add_headers(static_header) @pytest.mark.asyncio async def test_client_header_auth_token_conflict_fail( From 70d55b69bc99c3964fbb98069e6f3e23169b62ad Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 18:34:02 +0530 Subject: [PATCH 26/36] lint --- packages/toolbox-core/tests/test_tools.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_tools.py b/packages/toolbox-core/tests/test_tools.py index 8f62baf5..a12fe83f 100644 --- a/packages/toolbox-core/tests/test_tools.py +++ b/packages/toolbox-core/tests/test_tools.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from inspect import iscoroutine from typing import AsyncGenerator, Callable from unittest.mock import AsyncMock, Mock From 551635bc1c0fa7ffc2bfd38754c3a5db12ee7726 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 18:36:09 +0530 Subject: [PATCH 27/36] lint --- packages/toolbox-core/tests/test_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index ea23fbfd..f5c0af9d 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -15,12 +15,11 @@ import inspect import json +from typing import Any, Callable, Mapping, Optional from unittest.mock import AsyncMock, Mock import aioresponses import pytest - -from typing import Optional, Callable, Mapping, Any import pytest_asyncio from aioresponses import CallbackResult, aioresponses From 2880b4d348a9e494010b7942bd7d8d22773bb720 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Mon, 21 Apr 2025 18:48:11 +0530 Subject: [PATCH 28/36] lint --- packages/toolbox-core/src/toolbox_core/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index f54693a2..cbb0242b 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -242,4 +242,6 @@ async def add_headers( raise ValueError( f"Client header(s) `{', '.join(duplicates)}` already registered in the client." ) - self.__client_headers.update(headers) + + merged_headers = {**self.__client_headers, **headers} + self.__client_headers = merged_headers From 7990b99cc973d28497ee193d10932640c2215c8c Mon Sep 17 00:00:00 2001 From: Twisha Bansal <58483338+twishabansal@users.noreply.github.com> Date: Tue, 22 Apr 2025 10:01:37 +0530 Subject: [PATCH 29/36] Update packages/toolbox-core/src/toolbox_core/client.py Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> --- packages/toolbox-core/src/toolbox_core/client.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index cbb0242b..e504b72a 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -156,11 +156,7 @@ async def load_tool( """ # Resolve client headers - original_headers = self.__client_headers - resolved_headers = { - header_name: await resolve_value(original_headers[header_name]) - for header_name in original_headers - } + headers = { name: await resolve_value(value) for name, val in self.__client_headers.items() } # request the definition of the tool from the server url = f"{self.__base_url}/api/tool/{name}" From 60572b476eebac6acf613b784c31b501aaffaff4 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 22 Apr 2025 10:23:33 +0530 Subject: [PATCH 30/36] lint --- packages/toolbox-core/src/toolbox_core/client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index e504b72a..5bb4e0c3 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -156,7 +156,10 @@ async def load_tool( """ # Resolve client headers - headers = { name: await resolve_value(value) for name, val in self.__client_headers.items() } + headers = { + name: await resolve_value(value) + for name, val in self.__client_headers.items() + } # request the definition of the tool from the server url = f"{self.__base_url}/api/tool/{name}" From 1826ad97f62a1125bbbcd0f0862c1d54bcc9c4de Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 22 Apr 2025 10:24:39 +0530 Subject: [PATCH 31/36] fix --- packages/toolbox-core/src/toolbox_core/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 5bb4e0c3..4d4d9db2 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -156,8 +156,8 @@ async def load_tool( """ # Resolve client headers - headers = { - name: await resolve_value(value) + resolved_headers = { + name: await resolve_value(val) for name, val in self.__client_headers.items() } From a4eafe67fa06b0b328fcc49bd218a594886318ec Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 22 Apr 2025 10:26:39 +0530 Subject: [PATCH 32/36] cleanup --- packages/toolbox-core/tests/test_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index f5c0af9d..f5752bd0 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -18,7 +18,6 @@ from typing import Any, Callable, Mapping, Optional from unittest.mock import AsyncMock, Mock -import aioresponses import pytest import pytest_asyncio from aioresponses import CallbackResult, aioresponses From 8502db53b30b732e91771f76880dd328454daca4 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 22 Apr 2025 10:32:48 +0530 Subject: [PATCH 33/36] use mock_tool_load in test --- packages/toolbox-core/tests/test_client.py | 24 ++++++++++++++-------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index f5752bd0..e7fee424 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -78,15 +78,18 @@ def mock_tool_load( server_version: str = "0.0.0", status: int = 200, callback: Optional[Callable] = None, + payload_override: Optional[Callable] = None, ): """Mocks the GET /api/tool/{tool_name} endpoint.""" url = f"{base_url}/api/tool/{tool_name}" - manifest = ManifestSchema( - serverVersion=server_version, tools={tool_name: tool_schema} - ) + if payload_override is not None: + payload = payload_override + else: + manifest = ManifestSchema(serverVersion=server_version, tools={tool_name: tool_schema}) + payload = manifest.model_dump() aio_resp.get( url, - payload=manifest.model_dump(), + payload=payload, status=status, callback=callback, ) @@ -202,7 +205,6 @@ async def test_invoke_tool_server_error(aioresponses, test_tool_str): with pytest.raises(Exception, match=ERROR_MESSAGE): await loaded_tool(param1="some input") - @pytest.mark.asyncio async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): """ @@ -212,12 +214,16 @@ async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): ACTUAL_TOOL_IN_MANIFEST = "actual_tool_abc" REQUESTED_TOOL_NAME = "non_existent_tool_xyz" - manifest = ManifestSchema( + mismatched_manifest_payload = ManifestSchema( serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str} - ) + ).model_dump() - url = f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}" - aioresponses.get(url, payload=manifest.model_dump(), status=200) + mock_tool_load( + aio_resp=aioresponses, + tool_name=REQUESTED_TOOL_NAME, + tool_schema=test_tool_str, + payload_override=mismatched_manifest_payload + ) async with ToolboxClient(TEST_BASE_URL) as client: with pytest.raises(Exception, match=f"Tool '{REQUESTED_TOOL_NAME}' not found!"): From d0731ec6218144d77f13ad391afed2e951d72950 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 22 Apr 2025 12:41:59 +0530 Subject: [PATCH 34/36] test cleanup --- packages/toolbox-core/tests/test_client.py | 121 +++++++++++---------- 1 file changed, 66 insertions(+), 55 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index e7fee424..cce3bf53 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -85,7 +85,9 @@ def mock_tool_load( if payload_override is not None: payload = payload_override else: - manifest = ManifestSchema(serverVersion=server_version, tools={tool_name: tool_schema}) + manifest = ManifestSchema( + serverVersion=server_version, tools={tool_name: tool_schema} + ) payload = manifest.model_dump() aio_resp.get( url, @@ -205,6 +207,7 @@ async def test_invoke_tool_server_error(aioresponses, test_tool_str): with pytest.raises(Exception, match=ERROR_MESSAGE): await loaded_tool(param1="some input") + @pytest.mark.asyncio async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): """ @@ -222,7 +225,7 @@ async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): aio_resp=aioresponses, tool_name=REQUESTED_TOOL_NAME, tool_schema=test_tool_str, - payload_override=mismatched_manifest_payload + payload_override=mismatched_manifest_payload, ) async with ToolboxClient(TEST_BASE_URL) as client: @@ -527,6 +530,22 @@ def async_callable_header(self, async_callable_header_value): ) } + @staticmethod + def create_callback_factory( + expected_header, callback_status, callback_payload + ) -> Callable: + """ + Factory that RETURNS a callback function for aioresponses. + The returned callback will check headers and return the specified payload/status. + """ + + def actual_callback(url, **kwargs): + received_headers = kwargs.get("headers") + assert received_headers == expected_header + return CallbackResult(status=callback_status, payload=callback_payload) + + return actual_callback + @pytest.mark.asyncio async def test_client_init_with_headers(self, static_header): """Tests client initialization with static headers.""" @@ -544,20 +563,18 @@ async def test_load_tool_with_static_headers( ) expected_payload = {"result": "ok"} - # Mock GET for tool definition - def get_callback(url, **kwargs): - # Verify headers - assert kwargs.get("headers") == static_header - return CallbackResult(status=200, payload=manifest.model_dump()) - + get_callback = self.create_callback_factory( + expected_header=static_header, + callback_status=200, + callback_payload=manifest.model_dump(), + ) aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) - # Mock POST for invocation - def post_callback(url, **kwargs): - # Verify headers - assert kwargs.get("headers") == static_header - return CallbackResult(status=200, payload=expected_payload) - + post_callback = self.create_callback_factory( + expected_header=static_header, + callback_status=200, + callback_payload=expected_payload, + ) aioresponses.post( f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback ) @@ -585,20 +602,18 @@ async def test_load_tool_with_sync_callable_headers( header_mock = sync_callable_header[header_key] resolved_header = {header_key: sync_callable_header_value} - # Mock GET - def get_callback(url, **kwargs): - # Verify headers - assert kwargs.get("headers") == resolved_header - return CallbackResult(status=200, payload=manifest.model_dump()) - + get_callback = self.create_callback_factory( + expected_header=resolved_header, + callback_status=200, + callback_payload=manifest.model_dump(), + ) aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) - # Mock POST - def post_callback(url, **kwargs): - # Verify headers - assert kwargs.get("headers") == resolved_header - return CallbackResult(status=200, payload=expected_payload) - + post_callback = self.create_callback_factory( + expected_header=resolved_header, + callback_status=200, + callback_payload=expected_payload, + ) aioresponses.post( f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback ) @@ -636,18 +651,18 @@ async def test_load_tool_with_async_callable_headers( # Calculate expected result using the VALUE fixture resolved_header = {header_key: async_callable_header_value} - # Mock GET - def get_callback(url, **kwargs): - assert kwargs.get("headers") == resolved_header - return CallbackResult(status=200, payload=manifest.model_dump()) - + get_callback = self.create_callback_factory( + expected_header=resolved_header, + callback_status=200, + callback_payload=manifest.model_dump(), + ) aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) - # Mock POST - def post_callback(url, **kwargs): - assert kwargs.get("headers") == resolved_header - return CallbackResult(status=200, payload=expected_payload) - + post_callback = self.create_callback_factory( + expected_header=resolved_header, + callback_status=200, + callback_payload=expected_payload, + ) aioresponses.post( f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback ) @@ -675,16 +690,14 @@ async def test_load_toolset_with_headers( serverVersion="0.0.0", tools={tool_name: test_tool_str} ) - # Mock GET - def get_callback(url, **kwargs): - # Verify headers - assert kwargs.get("headers") == static_header - return CallbackResult(status=200, payload=manifest.model_dump()) - + get_callback = self.create_callback_factory( + expected_header=static_header, + callback_status=200, + callback_payload=manifest.model_dump(), + ) aioresponses.get( f"{TEST_BASE_URL}/api/toolset/{toolset_name}", callback=get_callback ) - async with ToolboxClient(TEST_BASE_URL, client_headers=static_header) as client: tools = await client.load_toolset(toolset_name) assert len(tools) == 1 @@ -701,20 +714,18 @@ async def test_add_headers_success( ) expected_payload = {"result": "added_ok"} - # Mock GET - def get_callback(url, **kwargs): - # Verify headers - assert kwargs.get("headers") == static_header - return CallbackResult(status=200, payload=manifest.model_dump()) - + get_callback = self.create_callback_factory( + expected_header=static_header, + callback_status=200, + callback_payload=manifest.model_dump(), + ) aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) - # Mock POST - def post_callback(url, **kwargs): - # Verify headers - assert kwargs.get("headers") == static_header - return CallbackResult(status=200, payload=expected_payload) - + post_callback = self.create_callback_factory( + expected_header=static_header, + callback_status=200, + callback_payload=expected_payload, + ) aioresponses.post( f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback ) From 560dfec4b1ae196100572db03771d6fb9a829d5c Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 22 Apr 2025 12:46:02 +0530 Subject: [PATCH 35/36] test cleanup --- packages/toolbox-core/tests/test_client.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index cce3bf53..3da37b31 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -532,7 +532,7 @@ def async_callable_header(self, async_callable_header_value): @staticmethod def create_callback_factory( - expected_header, callback_status, callback_payload + expected_header, callback_payload, callback_status: int = 200, ) -> Callable: """ Factory that RETURNS a callback function for aioresponses. @@ -565,14 +565,12 @@ async def test_load_tool_with_static_headers( get_callback = self.create_callback_factory( expected_header=static_header, - callback_status=200, callback_payload=manifest.model_dump(), ) aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) post_callback = self.create_callback_factory( expected_header=static_header, - callback_status=200, callback_payload=expected_payload, ) aioresponses.post( @@ -604,14 +602,12 @@ async def test_load_tool_with_sync_callable_headers( get_callback = self.create_callback_factory( expected_header=resolved_header, - callback_status=200, callback_payload=manifest.model_dump(), ) aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) post_callback = self.create_callback_factory( expected_header=resolved_header, - callback_status=200, callback_payload=expected_payload, ) aioresponses.post( @@ -653,14 +649,12 @@ async def test_load_tool_with_async_callable_headers( get_callback = self.create_callback_factory( expected_header=resolved_header, - callback_status=200, callback_payload=manifest.model_dump(), ) aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) post_callback = self.create_callback_factory( expected_header=resolved_header, - callback_status=200, callback_payload=expected_payload, ) aioresponses.post( @@ -692,7 +686,6 @@ async def test_load_toolset_with_headers( get_callback = self.create_callback_factory( expected_header=static_header, - callback_status=200, callback_payload=manifest.model_dump(), ) aioresponses.get( @@ -716,14 +709,12 @@ async def test_add_headers_success( get_callback = self.create_callback_factory( expected_header=static_header, - callback_status=200, callback_payload=manifest.model_dump(), ) aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) post_callback = self.create_callback_factory( expected_header=static_header, - callback_status=200, callback_payload=expected_payload, ) aioresponses.post( From 1ac9a14d112490934ca72fe17003d79ccd70f240 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 22 Apr 2025 12:47:56 +0530 Subject: [PATCH 36/36] lint --- packages/toolbox-core/tests/test_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 3da37b31..e6d12a0c 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -532,7 +532,9 @@ def async_callable_header(self, async_callable_header_value): @staticmethod def create_callback_factory( - expected_header, callback_payload, callback_status: int = 200, + expected_header, + callback_payload, + callback_status: int = 200, ) -> Callable: """ Factory that RETURNS a callback function for aioresponses.