diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 2b677881..a1c6d1ba 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -83,8 +83,8 @@ def __parse_tool( session=self.__session, base_url=self.__base_url, name=name, - desc=schema.description, - params=[p.to_param() for p in params], + description=schema.description, + params=params, # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), auth_service_token_getters=types.MappingProxyType(auth_token_getters), diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index c83051ca..15374b02 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -15,7 +15,7 @@ import asyncio import types -from inspect import Parameter, Signature +from inspect import Signature from typing import ( Any, Callable, @@ -28,6 +28,8 @@ from aiohttp import ClientSession +from toolbox_core.protocol import ParameterSchema + class ToolboxTool: """ @@ -47,8 +49,8 @@ def __init__( session: ClientSession, base_url: str, name: str, - desc: str, - params: Sequence[Parameter], + description: str, + params: Sequence[ParameterSchema], required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], bound_params: Mapping[str, Union[Callable[[], Any], Any]], @@ -61,31 +63,30 @@ def __init__( session: The `aiohttp.ClientSession` used for making API requests. base_url: The base URL of the Toolbox server API. name: The name of the remote tool. - desc: The description of the remote tool (used as its docstring). - params: A list of `inspect.Parameter` objects defining the tool's - arguments and their types/defaults. + description: The description of the remote tool. + params: The args of the tool. required_authn_params: A dict of required authenticated parameters to a list of services that provide values for them. auth_service_token_getters: A dict of authService -> token (or callables that produce a token) bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. - """ - # used to invoke the toolbox API self.__session: ClientSession = session self.__base_url: str = base_url self.__url = f"{base_url}/api/tool/{name}/invoke" - - self.__desc = desc + self.__description = description self.__params = params + inspect_type_params = [param.to_param() for param in self.__params] # the following properties are set to help anyone that might inspect it determine usage self.__name__ = name - self.__doc__ = desc - self.__signature__ = Signature(parameters=params, return_annotation=str) - self.__annotations__ = {p.name: p.annotation for p in params} + self.__doc__ = create_docstring(self.__description, self.__params) + self.__signature__ = Signature( + parameters=inspect_type_params, return_annotation=str + ) + self.__annotations__ = {p.name: p.annotation for p in inspect_type_params} # TODO: self.__qualname__ ?? # map of parameter name to auth service required by it @@ -100,8 +101,8 @@ def __copy( session: Optional[ClientSession] = None, base_url: Optional[str] = None, name: Optional[str] = None, - desc: Optional[str] = None, - params: Optional[list[Parameter]] = None, + description: Optional[str] = None, + params: Optional[Sequence[ParameterSchema]] = None, required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None, @@ -113,9 +114,8 @@ def __copy( session: The `aiohttp.ClientSession` used for making API requests. base_url: The base URL of the Toolbox server API. name: The name of the remote tool. - desc: The description of the remote tool (used as its docstring). - params: A list of `inspect.Parameter` objects defining the tool's - arguments and their types/defaults. + description: The description of the remote tool. + params: The args of the tool. required_authn_params: A dict of required authenticated parameters that need a auth_service_token_getter set for them yet. auth_service_token_getters: A dict of authService -> token (or callables @@ -129,7 +129,7 @@ def __copy( session=check(session, self.__session), base_url=check(base_url, self.__base_url), name=check(name, self.__name__), - desc=check(desc, self.__desc), + description=check(description, self.__description), params=check(params, self.__params), required_authn_params=check( required_authn_params, self.__required_authn_params @@ -258,7 +258,6 @@ def bind_parameters( for p in self.__params: if p.name not in bound_params: new_params.append(p) - all_bound_params = dict(self.__bound_parameters) all_bound_params.update(bound_params) @@ -268,6 +267,19 @@ def bind_parameters( ) +def create_docstring(description: str, params: Sequence[ParameterSchema]) -> str: + """Convert tool description and params into its function docstring""" + docstring = description + if not params: + return docstring + docstring += "\n\nArgs:" + for p in params: + docstring += ( + f"\n {p.name} ({p.to_param().annotation.__name__}): {p.description}" + ) + return docstring + + def identify_required_authn_params( req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] ) -> dict[str, list[str]]: diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index e25573e2..2ce600c3 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -92,7 +92,11 @@ async def test_load_tool_success(aioresponses, test_tool_str): assert callable(loaded_tool) # Assert introspection attributes are set correctly assert loaded_tool.__name__ == TOOL_NAME - assert loaded_tool.__doc__ == test_tool_str.description + expected_description = ( + test_tool_str.description + + f"\n\nArgs:\n param1 (str): Description of Param1" + ) + assert loaded_tool.__doc__ == expected_description # Assert signature inspection sig = inspect.signature(loaded_tool)