Skip to content

feat!: Implement a Base SDK for Toolbox #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 53 additions & 31 deletions src/toolbox_langchain/async_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from copy import deepcopy
from typing import Any, Callable, TypeVar, Union
from typing import Any, Callable, Union
from warnings import warn

from aiohttp import ClientSession
from langchain_core.tools import BaseTool

from .utils import (
ToolSchema,
_find_auth_params,
_find_bound_params,
_invoke_tool,
_parse_type,
_schema_to_docstring,
_schema_to_model,
)

T = TypeVar("T")


# This class is an internal implementation detail and is not exposed to the
# end-user. It should not be used directly by external code. Changes to this
# class will not be considered breaking changes to the public API.
class AsyncToolboxTool(BaseTool):
class AsyncToolboxTool:
"""
A subclass of LangChain's BaseTool that supports features specific to
Toolbox, like bound parameters and authenticated tools.
A class that supports features specific to Toolbox, like bound parameters
and authenticated tools.
"""

def __init__(
Expand Down Expand Up @@ -110,51 +110,69 @@ def __init__(

# Bind values for parameters present in the schema that don't require
# authentication.
bound_params = {
__bound_params = {
param_name: param_value
for param_name, param_value in bound_params.items()
if param_name in [param.name for param in non_auth_bound_params]
}

# Update the tools schema to validate only the presence of parameters
# that neither require authentication nor are bound.
schema.parameters = non_auth_non_bound_params

# Due to how pydantic works, we must initialize the underlying
# BaseTool class before assigning values to member variables.
super().__init__(
name=name,
description=schema.description,
args_schema=_schema_to_model(model_name=name, schema=schema.parameters),
)
__updated_schema = deepcopy(schema)
__updated_schema.parameters = non_auth_non_bound_params

self.__name = name
self.__schema = schema
self.__schema = __updated_schema
self.__model = _schema_to_model(name, self.__schema.parameters)
self.__url = url
self.__session = session
self.__auth_tokens = auth_tokens
self.__auth_params = auth_params
self.__bound_params = bound_params
self.__bound_params = __bound_params

# Warn users about any missing authentication so they can add it before
# tool invocation.
self.__validate_auth(strict=False)

def _run(self, **kwargs: Any) -> dict[str, Any]:
raise NotImplementedError("Synchronous methods not supported by async tools.")
# Store parameter definitions for the function signature and annotations
sig_params = []
annotations = {}
for param in self.__schema.parameters:
param_type = _parse_type(param)
annotations[param.name] = param_type
sig_params.append(
inspect.Parameter(
param.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=param_type,
)
)

# Set function name, docstring, signature and annotations
self.__name__ = name
self.__qualname__ = name
self.__doc__ = _schema_to_docstring(self.__schema)
self.__signature__ = inspect.Signature(
parameters=sig_params, return_annotation=dict[str, Any]
)
self.__annotations__ = annotations

async def _arun(self, **kwargs: Any) -> dict[str, Any]:
async def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
"""
The coroutine that invokes the tool with the given arguments.

Args:
**kwargs: The arguments to the tool.
*args: The positional arguments to the tool.
**kwargs: The keyword arguments to the tool.

Returns:
A dictionary containing the parsed JSON response from the tool
invocation.
"""

# Validate arguments
validated_args = self.__signature__.bind(*args, **kwargs).arguments
self.__model.model_validate(validated_args)

# If the tool had parameters that require authentication, then right
# before invoking that tool, we check whether all these required
# authentication sources have been registered or not.
Expand All @@ -169,10 +187,14 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]:
evaluated_params[param_name] = param_value

# Merge bound parameters with the provided arguments
kwargs.update(evaluated_params)
validated_args.update(evaluated_params)

return await _invoke_tool(
self.__url, self.__session, self.__name, kwargs, self.__auth_tokens
self.__url,
self.__session,
self.__name__,
validated_args,
self.__auth_tokens,
)

def __validate_auth(self, strict: bool = True) -> None:
Expand Down Expand Up @@ -221,12 +243,12 @@ def __validate_auth(self, strict: bool = True) -> None:

if not is_authenticated:
messages.append(
f"Tool {self.__name} requires authentication, but no valid authentication sources are registered. Please register the required sources before use."
f"Tool {self.__name__} requires authentication, but no valid authentication sources are registered. Please register the required sources before use."
)

if params_missing_auth:
messages.append(
f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use."
f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name__} require authentication, but no valid authentication sources are registered. Please register the required sources before use."
)

if messages:
Expand Down Expand Up @@ -277,7 +299,7 @@ def __create_copy(
# as errors or warnings, depending on the given `strict` flag.
new_schema.parameters += self.__auth_params
return AsyncToolboxTool(
name=self.__name,
name=self.__name__,
schema=new_schema,
url=self.__url,
session=self.__session,
Expand Down Expand Up @@ -317,7 +339,7 @@ def add_auth_tokens(

if dupe_tokens:
raise ValueError(
f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`."
f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name__}`."
)

return self.__create_copy(auth_tokens=auth_tokens, strict=strict)
Expand Down Expand Up @@ -380,7 +402,7 @@ def bind_params(

if dupe_params:
raise ValueError(
f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name}`."
f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name__}`."
)

return self.__create_copy(bound_params=bound_params, strict=strict)
Expand Down
23 changes: 5 additions & 18 deletions src/toolbox_langchain/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,15 @@
from threading import Thread
from typing import Any, Awaitable, Callable, TypeVar, Union

from langchain_core.tools import BaseTool

from .async_tools import AsyncToolboxTool

T = TypeVar("T")


class ToolboxTool(BaseTool):
class ToolboxTool:
"""
A subclass of LangChain's BaseTool that supports features specific to
Toolbox, like bound parameters and authenticated tools.
A class that supports features specific to Toolbox, like bound parameters
and authenticated tools.
"""

def __init__(
Expand All @@ -45,14 +43,6 @@ def __init__(
thread: The thread to run blocking operations in.
"""

# Due to how pydantic works, we must initialize the underlying
# BaseTool class before assigning values to member variables.
super().__init__(
name=async_tool.name,
description=async_tool.description,
args_schema=async_tool.args_schema,
)

self.__async_tool = async_tool
self.__loop = loop
self.__thread = thread
Expand All @@ -77,11 +67,8 @@ async def __run_as_async(self, coro: Awaitable[T]) -> T:
asyncio.run_coroutine_threadsafe(coro, self.__loop)
)

def _run(self, **kwargs: Any) -> dict[str, Any]:
return self.__run_as_sync(self.__async_tool._arun(**kwargs))

async def _arun(self, **kwargs: Any) -> dict[str, Any]:
return await self.__run_as_async(self.__async_tool._arun(**kwargs))
def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
    """
    Invokes the wrapped async tool from synchronous code.

    Args:
        *args: The positional arguments forwarded to the async tool.
        **kwargs: The keyword arguments forwarded to the async tool.

    Returns:
        Whatever `__run_as_sync` returns for the async tool's invocation
        (annotated as a dictionary).
    """
    # Calling the async tool only creates the awaitable; it is executed by
    # `__run_as_sync` (presumably on the background loop -- TODO confirm).
    pending_invocation = self.__async_tool(*args, **kwargs)
    return self.__run_as_sync(pending_invocation)

def add_auth_tokens(
self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True
Expand Down
26 changes: 26 additions & 0 deletions src/toolbox_langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,29 @@ def _find_bound_params(
_non_bound_params.append(param)

return (_bound_params, _non_bound_params)


def _schema_to_docstring(tool_schema: ToolSchema) -> str:
    """Generates a Google Style docstring from a ToolSchema object.

    If the schema has parameters, the docstring includes an 'Args:' section
    detailing each parameter's name, type, and description. If no parameters are
    present, only the tool's description is returned.

    Args:
        tool_schema: The schema object defining the tool's interface,
            including its description and parameters.

    Returns:
        str: A Google Style formatted docstring.
    """

    # A schema without parameters (empty or None) needs no 'Args:' section.
    if not tool_schema.parameters:
        return tool_schema.description

    # One entry per parameter; the type label is the `__name__` of whatever
    # `_parse_type` returns for that parameter.
    arg_lines = [
        f" {param.name} ({_parse_type(param).__name__}): {param.description}"
        for param in tool_schema.parameters
    ]

    # Assemble with a single join instead of repeated `+=` concatenation.
    return "\n".join([f"{tool_schema.description}\n\nArgs:", *arg_lines])