diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml index 90e9497a..492074e8 100644 --- a/.github/sync-repo-settings.yaml +++ b/.github/sync-repo-settings.yaml @@ -30,16 +30,21 @@ branchProtectionRules: - "conventionalcommits.org" - "header-check" # Add required status checks like presubmit tests - - "langchain-python-sdk-pr-py313 (toolbox-testing-438616)" - - "langchain-python-sdk-pr-py312 (toolbox-testing-438616)" - - "langchain-python-sdk-pr-py311 (toolbox-testing-438616)" - - "langchain-python-sdk-pr-py310 (toolbox-testing-438616)" - - "langchain-python-sdk-pr-py39 (toolbox-testing-438616)" - "core-python-sdk-pr-py313 (toolbox-testing-438616)" - "core-python-sdk-pr-py312 (toolbox-testing-438616)" - "core-python-sdk-pr-py311 (toolbox-testing-438616)" - "core-python-sdk-pr-py310 (toolbox-testing-438616)" - "core-python-sdk-pr-py39 (toolbox-testing-438616)" + - "langchain-python-sdk-pr-py313 (toolbox-testing-438616)" + - "langchain-python-sdk-pr-py312 (toolbox-testing-438616)" + - "langchain-python-sdk-pr-py311 (toolbox-testing-438616)" + - "langchain-python-sdk-pr-py310 (toolbox-testing-438616)" + - "langchain-python-sdk-pr-py39 (toolbox-testing-438616)" + - "llamaindex-python-sdk-pr-py313-1 (toolbox-testing-438616)" + - "llamaindex-python-sdk-pr-py312-1 (toolbox-testing-438616)" + - "llamaindex-python-sdk-pr-py311-1 (toolbox-testing-438616)" + - "llamaindex-python-sdk-pr-py310-1 (toolbox-testing-438616)" + - "llamaindex-python-sdk-pr-py39-1 (toolbox-testing-438616)" requiredApprovingReviewCount: 1 requiresCodeOwnerReviews: true requiresStrictStatusChecks: true diff --git a/.github/workflows/lint-toolbox-llamaindex.yaml b/.github/workflows/lint-toolbox-llamaindex.yaml new file mode 100644 index 00000000..05f10c79 --- /dev/null +++ b/.github/workflows/lint-toolbox-llamaindex.yaml @@ -0,0 +1,84 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: llamaindex +on: + pull_request: + paths: + - 'packages/toolbox-llamaindex/**' + - '!packages/toolbox-llamaindex/**/*.md' + pull_request_target: + types: [labeled] + +# Declare default permissions as read only. +permissions: read-all + +jobs: + lint: + if: "${{ github.event.action != 'labeled' || github.event.label.name == 'tests: run' }}" + name: lint + runs-on: ubuntu-latest + concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + defaults: + run: + working-directory: ./packages/toolbox-llamaindex + permissions: + contents: 'read' + issues: 'write' + pull-requests: 'write' + steps: + - name: Remove PR Label + if: "${{ github.event.action == 'labeled' && github.event.label.name == 'tests: run' }}" + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + try { + await github.rest.issues.removeLabel({ + name: 'tests: run', + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number + }); + } catch (e) { + console.log('Failed to remove label. Another job may have already removed it!'); + } + - name: Checkout code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + ref: ${{ github.event.pull_request.head.sha }} + repository: ${{ github.event.pull_request.head.repo.full_name }} + token: ${{ secrets.GITHUB_TOKEN }} + - name: Setup Python + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + with: + python-version: "3.13" + + - name: Install library requirements + run: pip install -r requirements.txt + + - name: Install test requirements + run: pip install .[test] + + - name: Run linters + run: | + black --check . + isort --check . + + - name: Run type-check + env: + MYPYPATH: './src' + run: mypy --install-types --non-interactive --cache-dir=.mypy_cache/ -p toolbox_llamaindex \ No newline at end of file diff --git a/.github/workflows/schedule_reporter.yml b/.github/workflows/schedule_reporter.yml index 16a84312..db2de963 100644 --- a/.github/workflows/schedule_reporter.yml +++ b/.github/workflows/schedule_reporter.yml @@ -26,4 +26,4 @@ jobs: contents: 'read' uses: ./.github/workflows/cloud_build_failure_reporter.yml with: - trigger_names: "langchain-python-sdk-test-nightly,langchain-python-sdk-test-on-merge,core-python-sdk-test-nightly,core-python-sdk-test-on-merge" + trigger_names: "core-python-sdk-test-nightly,core-python-sdk-test-on-merge,langchain-python-sdk-test-nightly,langchain-python-sdk-test-on-merge,llamaindex-python-sdk-test-nightly,llamaindex-python-sdk-test-on-merge" diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 6182df46..b1eab2f9 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1 +1 @@ -{"packages/toolbox-langchain":"0.1.0","packages/toolbox-core":"0.1.0"} +{"packages/toolbox-langchain":"0.1.0","packages/toolbox-core":"0.1.0","packages/toolbox-llamaindex":"0.1.1"} diff --git a/CHANGELOG.md b/CHANGELOG.md index 056c42bc..a63be9d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,5 +2,7 @@ Please refer to each API's `CHANGELOG.md` file under the `packages/` directory Changelogs ----- -- [toolbox-langchain==0.1.0](https://github.com/googleapis/mcp-toolbox-sdk-python/tree/main/packages/toolbox-langchain/CHANGELOG.md) - [toolbox-core==0.1.0](https://github.com/googleapis/mcp-toolbox-sdk-python/tree/main/packages/toolbox-core/CHANGELOG.md) +- [toolbox-langchain==0.1.0](https://github.com/googleapis/mcp-toolbox-sdk-python/tree/main/packages/toolbox-langchain/CHANGELOG.md) +- [toolbox-llamaindex==0.1.1](https://github.com/googleapis/mcp-toolbox-sdk-python/tree/main/packages/toolbox-llamaindex/CHANGELOG.md) + diff --git a/packages/toolbox-langchain/README.md b/packages/toolbox-langchain/README.md index 8d011a01..e2d70029 100644 --- a/packages/toolbox-langchain/README.md +++ b/packages/toolbox-langchain/README.md @@ -1,8 +1,8 @@ ![MCP Toolbox Logo](https://raw.githubusercontent.com/googleapis/genai-toolbox/main/logo.png) -# MCP Toolbox LangChain SDK +# MCP Toolbox LlamaIndex SDK This SDK allows you to seamlessly integrate the functionalities of -[Toolbox](https://github.com/googleapis/genai-toolbox) into your LangChain LLM +[Toolbox](https://github.com/googleapis/genai-toolbox) into your LlamaIndex LLM applications, enabling advanced orchestration and interaction with GenAI models. @@ -15,10 +15,7 @@ applications, enabling advanced orchestration and interaction with GenAI models. - [Loading Tools](#loading-tools) - [Load a toolset](#load-a-toolset) - [Load a single tool](#load-a-single-tool) -- [Use with LangChain](#use-with-langchain) -- [Use with LangGraph](#use-with-langgraph) - - [Represent Tools as Nodes](#represent-tools-as-nodes) - - [Connect Tools with LLM](#connect-tools-with-llm) +- [Use with LlamaIndex](#use-with-llamaindex) - [Manual usage](#manual-usage) - [Authenticating Tools](#authenticating-tools) - [Supported Authentication Mechanisms](#supported-authentication-mechanisms) @@ -38,33 +35,40 @@ applications, enabling advanced orchestration and interaction with GenAI models. ## Installation ```bash -pip install toolbox-langchain +pip install toolbox-llamaindex ``` ## Quickstart Here's a minimal example to get you started using -[LangGraph](https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent): +# TODO: add link +[LlamaIndex](): ```py -from toolbox_langchain import ToolboxClient -from langchain_google_vertexai import ChatVertexAI -from langgraph.prebuilt import create_react_agent +import asyncio -toolbox = ToolboxClient("http://127.0.0.1:5000") -tools = toolbox.load_toolset() +from llama_index.llms.google_genai import GoogleGenAI +from llama_index.core.agent.workflow import AgentWorkflow + +from toolbox_llamaindex import ToolboxClient -model = ChatVertexAI(model="gemini-1.5-pro-002") -agent = create_react_agent(model, tools) +async def run_agent(): + toolbox = ToolboxClient("http://127.0.0.1:5000") + tools = toolbox.load_toolset() -prompt = "How's the weather today?" + vertex_model = GoogleGenAI( + model="gemini-1.5-pro", + vertexai_config={"project": "project-id", "location": "us-central1"}, + ) + agent = AgentWorkflow.from_tools_or_functions( + tools, + llm=vertex_model, + system_prompt="You are a helpful assistant.", + ) + response = await agent.run(user_msg="Get some response from the agent.") + print(response) -for s in agent.stream({"messages": [("user", prompt)]}, stream_mode="values"): - message = s["messages"][-1] - if isinstance(message, tuple): - print(message) - else: - message.pretty_print() +asyncio.run(run_agent()) ``` ## Usage @@ -72,7 +76,7 @@ for s in agent.stream({"messages": [("user", prompt)]}, stream_mode="values"): Import and initialize the toolbox client. ```py -from toolbox_langchain import ToolboxClient +from toolbox_llamaindex import ToolboxClient # Replace with your Toolbox service's URL toolbox = ToolboxClient("http://127.0.0.1:5000") @@ -102,85 +106,63 @@ tool = toolbox.load_tool("my-tool") Loading individual tools gives you finer-grained control over which tools are available to your LLM agent. -## Use with LangChain +## Use with LlamaIndex LangChain's agents can dynamically choose and execute tools based on the user input. Include tools loaded from the Toolbox SDK in the agent's toolkit: ```py -from langchain_google_vertexai import ChatVertexAI +from llama_index.llms.google_genai import GoogleGenAI +from llama_index.core.agent.workflow import AgentWorkflow -model = ChatVertexAI(model="gemini-1.5-pro-002") +vertex_model = GoogleGenAI( + model="gemini-1.5-pro", + vertexai_config={"project": "project-id", "location": "us-central1"}, +) # Initialize agent with tools -agent = model.bind_tools(tools) - -# Run the agent -result = agent.invoke("Do something with the tools") -``` - -## Use with LangGraph - -Integrate the Toolbox SDK with LangGraph to use Toolbox service tools within a -graph-based workflow. Follow the [official -guide](https://langchain-ai.github.io/langgraph/) with minimal changes. - -### Represent Tools as Nodes - -Represent each tool as a LangGraph node, encapsulating the tool's execution within the node's functionality: - -```py -from toolbox_langchain import ToolboxClient -from langgraph.graph import StateGraph, MessagesState -from langgraph.prebuilt import ToolNode - -# Define the function that calls the model -def call_model(state: MessagesState): - messages = state['messages'] - response = model.invoke(messages) - return {"messages": [response]} # Return a list to add to existing messages - -model = ChatVertexAI(model="gemini-1.5-pro-002") -builder = StateGraph(MessagesState) -tool_node = ToolNode(tools) - -builder.add_node("agent", call_model) -builder.add_node("tools", tool_node) +agent = AgentWorkflow.from_tools_or_functions( + tools, + llm=vertex_model, + system_prompt="You are a helpful assistant.", +) + +# Query the agent +response = await agent.run(user_msg="Get some response from the agent.") +print(response) ``` -### Connect Tools with LLM +### Maintain state -Connect tool nodes with LLM nodes. The LLM decides which tool to use based on -input or context. Tool output can be fed back into the LLM: +To maintain state for the agent, add context as follows: ```py -from typing import Literal -from langgraph.graph import END, START -from langchain_core.messages import HumanMessage - -# Define the function that determines whether to continue or not -def should_continue(state: MessagesState) -> Literal["tools", END]: - messages = state['messages'] - last_message = messages[-1] - if last_message.tool_calls: - return "tools" # Route to "tools" node if LLM makes a tool call - return END # Otherwise, stop - -builder.add_edge(START, "agent") -builder.add_conditional_edges("agent", should_continue) -builder.add_edge("tools", 'agent') - -graph = builder.compile() - -graph.invoke({"messages": [HumanMessage(content="Do something with the tools")]}) +from llama_index.core.agent.workflow import AgentWorkflow +from llama_index.core.workflow import Context +from llama_index.llms.google_genai import GoogleGenAI + +vertex_model = GoogleGenAI( + model="gemini-1.5-pro", + vertexai_config={"project": "twisha-dev", "location": "us-central1"}, +) +agent = AgentWorkflow.from_tools_or_functions( + tools, + llm=vertex_model, + system_prompt="You are a helpful assistant", +) + +# Save memory in agent context +ctx = Context(agent) +response = await agent.run(user_msg="Give me some response.", ctx=ctx) +print(response) ``` ## Manual usage -Execute a tool manually using the `invoke` method: +Execute a tool manually using the `call` method: ```py -result = tools[0].invoke({"name": "Alice", "age": 30}) +result = tools[0].call({"name": "Alice", "age": 30}) ``` This is useful for testing tools or when you need precise control over tool @@ -250,7 +232,7 @@ auth_tools = toolbox.load_toolset(auth_tokens={"my_auth": get_auth_token}) ```py import asyncio -from toolbox_langchain import ToolboxClient +from toolbox_llamaindex import ToolboxClient async def get_auth_token(): # ... Logic to retrieve ID token (e.g., from local storage, OAuth flow) @@ -261,7 +243,7 @@ toolbox = ToolboxClient("http://127.0.0.1:5000") tool = toolbox.load_tool("my-tool") auth_tool = tool.add_auth_token("my_auth", get_auth_token) -result = auth_tool.invoke({"input": "some input"}) +result = auth_tool.call({"input": "some input"}) print(result) ``` @@ -329,7 +311,7 @@ use the asynchronous interfaces of the `ToolboxClient`. ```py import asyncio -from toolbox_langchain import ToolboxClient +from toolbox_llamaindex import ToolboxClient async def main(): toolbox = ToolboxClient("http://127.0.0.1:5000") diff --git a/packages/toolbox-llamaindex/CHANGELOG.md b/packages/toolbox-llamaindex/CHANGELOG.md new file mode 100644 index 00000000..82fd4db5 --- /dev/null +++ b/packages/toolbox-llamaindex/CHANGELOG.md @@ -0,0 +1,32 @@ +# Changelog + +## [0.1.1](https://github.com/googleapis/genai-toolbox-llamaindex-python/compare/v0.1.0...v0.1.1) (2025-04-04) + + +### Bug Fixes + +* **deps:** Update dependency black to v25 ([#46](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/46)) ([ddb60af](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/ddb60afaa78c4e57b01e87a649963df449f3ac6a)) +* **deps:** Update dependency google-cloud-storage to v3 ([#47](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/47)) ([d10d779](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/d10d779ea22c02f04b26825e686ad519b4eec56f)) +* **deps:** Update dependency isort to v6 ([#48](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/48)) ([e27a249](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/e27a249afb52bd0a0aff8a0ddb5b6cc8e1c535ec)) +* **deps:** Update dependency pillow to v11 ([#49](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/49)) ([a467b68](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/a467b680201e796d80d0699fe7b1de711a99be74)) +* **deps:** Update python-nonmajor ([#44](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/44)) ([4c1b88d](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/4c1b88d23d1c0a0b78f6b29200fa32044152c550)) +* **deps:** Update python-nonmajor ([#68](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/68)) ([7595657](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/7595657b2dd5cf7974d751649120a08ba3f7853d)) + +## 0.1.0 (2025-03-17) + + +### Features + +* Add support for sync operations ([#20](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/20)) ([1fa45af](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/1fa45afed49db863bf17641fb5984bf8ceb5a4c6)) +* Add support for Bound Params. ([#10](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/10)) ([1d484a8](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/1d484a8daee5567d5a32d20ea492dbc125daf332)) + +### Bug Fixes + +* Add items to parameter schema ([#9](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/9)) ([769b7f1](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/769b7f1c86dd83c9cd5e19c8bd28890da6f6a6ae)) +* Rename package to 'toolbox_llamaindex' ([#8](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/8)) ([9b71c72](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/9b71c728a7887d783a027fc54367584e0ddd4489)) +* Throw tool errors correctly. ([#35](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/35)) ([11159c6](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/11159c6ac9813d8da21888c70a8550518f64f3ce)) + +### Documentation + +* Update README for new features ([#22](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/22)) ([f5060b9](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/f5060b9057329809073553c88ebd2e677db7b902)) +* Update the README to recommend AgentWorkflow for using LlamaIndex. ([#34](https://github.com/googleapis/genai-toolbox-llamaindex-python/issues/34)) ([fe8e74f](https://github.com/googleapis/genai-toolbox-llamaindex-python/commit/fe8e74fb2c76af6598e6054914b03731c85a2741)) \ No newline at end of file diff --git a/packages/toolbox-llamaindex/DEVELOPER.md b/packages/toolbox-llamaindex/DEVELOPER.md new file mode 100644 index 00000000..7e0a5e56 --- /dev/null +++ b/packages/toolbox-llamaindex/DEVELOPER.md @@ -0,0 +1,37 @@ +# Development + +Below are the details to set up a development environment and run tests. + +## Install +1. Clone the repository: + ```bash + git clone https://github.com/googleapis/mcp-toolbox-sdk-python + ``` +1. Navigate to the package directory: + ```bash + cd mcp-toolbox-sdk-python/packages/toolbox-llamaindex + ``` +1. Install the package in editable mode, so changes are reflected without + reinstall: + ```bash + pip install -e . + ``` +1. Make code changes and contribute to the SDK's development. +> [!TIP] +> Using `-e` option allows you to make changes to the SDK code and have +> those changes reflected immediately without reinstalling the package. + +## Test +1. Navigate to the package directory if needed: + ```bash + cd mcp-toolbox-sdk-python/packages/toolbox-llamaindex + ``` +1. Install the SDK and test dependencies: + ```bash + pip install -e .[test] + ``` +1. Run tests and/or contribute to the SDK's development. + + ```bash + pytest + ``` diff --git a/packages/toolbox-llamaindex/README.md b/packages/toolbox-llamaindex/README.md new file mode 100644 index 00000000..8d011a01 --- /dev/null +++ b/packages/toolbox-llamaindex/README.md @@ -0,0 +1,342 @@ +![MCP Toolbox Logo](https://raw.githubusercontent.com/googleapis/genai-toolbox/main/logo.png) +# MCP Toolbox LangChain SDK + +This SDK allows you to seamlessly integrate the functionalities of +[Toolbox](https://github.com/googleapis/genai-toolbox) into your LangChain LLM +applications, enabling advanced orchestration and interaction with GenAI models. + + +## Table of Contents + + +- [Installation](#installation) +- [Quickstart](#quickstart) +- [Usage](#usage) +- [Loading Tools](#loading-tools) + - [Load a toolset](#load-a-toolset) + - [Load a single tool](#load-a-single-tool) +- [Use with LangChain](#use-with-langchain) +- [Use with LangGraph](#use-with-langgraph) + - [Represent Tools as Nodes](#represent-tools-as-nodes) + - [Connect Tools with LLM](#connect-tools-with-llm) +- [Manual usage](#manual-usage) +- [Authenticating Tools](#authenticating-tools) + - [Supported Authentication Mechanisms](#supported-authentication-mechanisms) + - [Configure Tools](#configure-tools) + - [Configure SDK](#configure-sdk) + - [Add Authentication to a Tool](#add-authentication-to-a-tool) + - [Add Authentication While Loading](#add-authentication-while-loading) + - [Complete Example](#complete-example) +- [Binding Parameter Values](#binding-parameter-values) + - [Binding Parameters to a Tool](#binding-parameters-to-a-tool) + - [Binding Parameters While Loading](#binding-parameters-while-loading) + - [Binding Dynamic Values](#binding-dynamic-values) +- [Asynchronous Usage](#asynchronous-usage) + + + +## Installation + +```bash +pip install toolbox-langchain +``` + +## Quickstart + +Here's a minimal example to get you started using +[LangGraph](https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent): + +```py +from toolbox_langchain import ToolboxClient +from langchain_google_vertexai import ChatVertexAI +from langgraph.prebuilt import create_react_agent + +toolbox = ToolboxClient("http://127.0.0.1:5000") +tools = toolbox.load_toolset() + +model = ChatVertexAI(model="gemini-1.5-pro-002") +agent = create_react_agent(model, tools) + +prompt = "How's the weather today?" + +for s in agent.stream({"messages": [("user", prompt)]}, stream_mode="values"): + message = s["messages"][-1] + if isinstance(message, tuple): + print(message) + else: + message.pretty_print() +``` + +## Usage + +Import and initialize the toolbox client. + +```py +from toolbox_langchain import ToolboxClient + +# Replace with your Toolbox service's URL +toolbox = ToolboxClient("http://127.0.0.1:5000") +``` + +## Loading Tools + +### Load a toolset + +A toolset is a collection of related tools. You can load all tools in a toolset +or a specific one: + +```py +# Load all tools +tools = toolbox.load_toolset() + +# Load a specific toolset +tools = toolbox.load_toolset("my-toolset") +``` + +### Load a single tool + +```py +tool = toolbox.load_tool("my-tool") +``` + +Loading individual tools gives you finer-grained control over which tools are +available to your LLM agent. + +## Use with LangChain + +LangChain's agents can dynamically choose and execute tools based on the user +input. Include tools loaded from the Toolbox SDK in the agent's toolkit: + +```py +from langchain_google_vertexai import ChatVertexAI + +model = ChatVertexAI(model="gemini-1.5-pro-002") + +# Initialize agent with tools +agent = model.bind_tools(tools) + +# Run the agent +result = agent.invoke("Do something with the tools") +``` + +## Use with LangGraph + +Integrate the Toolbox SDK with LangGraph to use Toolbox service tools within a +graph-based workflow. Follow the [official +guide](https://langchain-ai.github.io/langgraph/) with minimal changes. + +### Represent Tools as Nodes + +Represent each tool as a LangGraph node, encapsulating the tool's execution within the node's functionality: + +```py +from toolbox_langchain import ToolboxClient +from langgraph.graph import StateGraph, MessagesState +from langgraph.prebuilt import ToolNode + +# Define the function that calls the model +def call_model(state: MessagesState): + messages = state['messages'] + response = model.invoke(messages) + return {"messages": [response]} # Return a list to add to existing messages + +model = ChatVertexAI(model="gemini-1.5-pro-002") +builder = StateGraph(MessagesState) +tool_node = ToolNode(tools) + +builder.add_node("agent", call_model) +builder.add_node("tools", tool_node) +``` + +### Connect Tools with LLM + +Connect tool nodes with LLM nodes. The LLM decides which tool to use based on +input or context. Tool output can be fed back into the LLM: + +```py +from typing import Literal +from langgraph.graph import END, START +from langchain_core.messages import HumanMessage + +# Define the function that determines whether to continue or not +def should_continue(state: MessagesState) -> Literal["tools", END]: + messages = state['messages'] + last_message = messages[-1] + if last_message.tool_calls: + return "tools" # Route to "tools" node if LLM makes a tool call + return END # Otherwise, stop + +builder.add_edge(START, "agent") +builder.add_conditional_edges("agent", should_continue) +builder.add_edge("tools", 'agent') + +graph = builder.compile() + +graph.invoke({"messages": [HumanMessage(content="Do something with the tools")]}) +``` + +## Manual usage + +Execute a tool manually using the `invoke` method: + +```py +result = tools[0].invoke({"name": "Alice", "age": 30}) +``` + +This is useful for testing tools or when you need precise control over tool +execution outside of an agent framework. + +## Authenticating Tools + +> [!WARNING] +> Always use HTTPS to connect your application with the Toolbox service, +> especially when using tools with authentication configured. Using HTTP exposes +> your application to serious security risks. + +Some tools require user authentication to access sensitive data. + +### Supported Authentication Mechanisms +Toolbox currently supports authentication using the [OIDC +protocol](https://openid.net/specs/openid-connect-core-1_0.html) with [ID +tokens](https://openid.net/specs/openid-connect-core-1_0.html#IDToken) (not +access tokens) for [Google OAuth +2.0](https://cloud.google.com/apigee/docs/api-platform/security/oauth/oauth-home). + +### Configure Tools + +Refer to [these +instructions](https://googleapis.github.io/genai-toolbox/resources/tools/#authenticated-parameters) on +configuring tools for authenticated parameters. + +### Configure SDK + +You need a method to retrieve an ID token from your authentication service: + +```py +async def get_auth_token(): + # ... Logic to retrieve ID token (e.g., from local storage, OAuth flow) + # This example just returns a placeholder. Replace with your actual token retrieval. + return "YOUR_ID_TOKEN" # Placeholder +``` + +#### Add Authentication to a Tool + +```py +toolbox = ToolboxClient("http://127.0.0.1:5000") +tools = toolbox.load_toolset() + +auth_tool = tools[0].add_auth_token("my_auth", get_auth_token) # Single token + +multi_auth_tool = tools[0].add_auth_tokens({"my_auth", get_auth_token}) # Multiple tokens + +# OR + +auth_tools = [tool.add_auth_token("my_auth", get_auth_token) for tool in tools] +``` + +#### Add Authentication While Loading + +```py +auth_tool = toolbox.load_tool(auth_tokens={"my_auth": get_auth_token}) + +auth_tools = toolbox.load_toolset(auth_tokens={"my_auth": get_auth_token}) +``` + +> [!NOTE] +> Adding auth tokens during loading only affect the tools loaded within +> that call. + +### Complete Example + +```py +import asyncio +from toolbox_langchain import ToolboxClient + +async def get_auth_token(): + # ... Logic to retrieve ID token (e.g., from local storage, OAuth flow) + # This example just returns a placeholder. Replace with your actual token retrieval. + return "YOUR_ID_TOKEN" # Placeholder + +toolbox = ToolboxClient("http://127.0.0.1:5000") +tool = toolbox.load_tool("my-tool") + +auth_tool = tool.add_auth_token("my_auth", get_auth_token) +result = auth_tool.invoke({"input": "some input"}) +print(result) +``` + +## Binding Parameter Values + +Predetermine values for tool parameters using the SDK. These values won't be +modified by the LLM. This is useful for: + +* **Protecting sensitive information:** API keys, secrets, etc. +* **Enforcing consistency:** Ensuring specific values for certain parameters. +* **Pre-filling known data:** Providing defaults or context. + +### Binding Parameters to a Tool + +```py +toolbox = ToolboxClient("http://127.0.0.1:5000") +tools = toolbox.load_toolset() + +bound_tool = tool[0].bind_param("param", "value") # Single param + +multi_bound_tool = tools[0].bind_params({"param1": "value1", "param2": "value2"}) # Multiple params + +# OR + +bound_tools = [tool.bind_param("param", "value") for tool in tools] +``` + +### Binding Parameters While Loading + +```py +bound_tool = toolbox.load_tool("my-tool", bound_params={"param": "value"}) + +bound_tools = toolbox.load_toolset(bound_params={"param": "value"}) +``` + +> [!NOTE] +> Bound values during loading only affect the tools loaded in that call. + +### Binding Dynamic Values + +Use a function to bind dynamic values: + +```py +def get_dynamic_value(): + # Logic to determine the value + return "dynamic_value" + +dynamic_bound_tool = tool.bind_param("param", get_dynamic_value) +``` + +> [!IMPORTANT] +> You don't need to modify tool configurations to bind parameter values. + +## Asynchronous Usage + +For better performance through [cooperative +multitasking](https://en.wikipedia.org/wiki/Cooperative_multitasking), you can +use the asynchronous interfaces of the `ToolboxClient`. + +> [!Note] +> Asynchronous interfaces like `aload_tool` and `aload_toolset` require an +> asynchronous environment. For guidance on running asynchronous Python +> programs, see [asyncio +> documentation](https://docs.python.org/3/library/asyncio-runner.html#running-an-asyncio-program). + +```py +import asyncio +from toolbox_langchain import ToolboxClient + +async def main(): + toolbox = ToolboxClient("http://127.0.0.1:5000") + tool = await client.aload_tool("my-tool") + tools = await client.aload_toolset() + response = await tool.ainvoke() + +if __name__ == "__main__": + asyncio.run(main()) +``` diff --git a/packages/toolbox-llamaindex/integration.cloudbuild.yaml b/packages/toolbox-llamaindex/integration.cloudbuild.yaml new file mode 100644 index 00000000..ce32e8a6 --- /dev/null +++ b/packages/toolbox-llamaindex/integration.cloudbuild.yaml @@ -0,0 +1,46 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +steps: + - id: Install library requirements + name: 'python:${_VERSION}' + args: + - install + - '-r' + - 'packages/toolbox-llamaindex/requirements.txt' + - '--user' + entrypoint: pip + - id: Install test requirements + name: 'python:${_VERSION}' + args: + - install + - 'packages/toolbox-llamaindex[test]' + - '--user' + entrypoint: pip + - id: Run integration tests + name: 'python:${_VERSION}' + env: + - TOOLBOX_URL=$_TOOLBOX_URL + - TOOLBOX_VERSION=$_TOOLBOX_VERSION + - GOOGLE_CLOUD_PROJECT=$PROJECT_ID + args: + - '-c' + - >- + python -m pytest packages/toolbox-llamaindex/tests/ + entrypoint: /bin/bash +options: + logging: CLOUD_LOGGING_ONLY +substitutions: + _VERSION: '3.13' + _TOOLBOX_VERSION: '0.3.0' diff --git a/packages/toolbox-llamaindex/pyproject.toml b/packages/toolbox-llamaindex/pyproject.toml new file mode 100644 index 00000000..8ec28663 --- /dev/null +++ b/packages/toolbox-llamaindex/pyproject.toml @@ -0,0 +1,70 @@ +[project] +name = "toolbox-llamindex" +dynamic = ["version"] +readme = "README.md" +description = "Python SDK for interacting with the Toolbox service with LlamaIndex" +license = {file = "LICENSE"} +requires-python = ">=3.9" +authors = [ + {name = "Google LLC", email = "googleapis-packages@google.com"} +] +dependencies = [ + "llama-index>=0.12.0,<1.0.0", + "PyYAML>=6.0.1,<7.0.0", + "pydantic>=2.8.0,<3.0.0", + "aiohttp>=3.8.6,<4.0.0", + "deprecated>=1.2.10,<2.0.0", +] + +classifiers = [ + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] + +# Tells setuptools that packages are under the 'src' directory +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.dynamic] +version = {attr = "toolbox_llamaindex.version.__version__"} + +[project.urls] +Homepage = "https://github.com/googleapis/mcp-toolbox-sdk-python/blob/main/packages/toolbox-llamaindex" +Repository = "https://github.com/googleapis/mcp-toolbox-sdk-python.git" +"Bug Tracker" = "https://github.com/googleapis/mcp-toolbox-sdk-python/issues" +Changelog = "https://github.com/googleapis/mcp-toolbox-sdk-python/blob/main/packages/toolbox-llamaindex/CHANGELOG.md" + +[project.optional-dependencies] +test = [ + "black[jupyter]==25.1.0", + "isort==6.0.1", + "mypy==1.15.0", + "pytest-asyncio==0.26.0", + "pytest==8.3.5", + "pytest-cov==6.1.1", + "Pillow==11.1.0", + "google-cloud-secret-manager==2.23.2", + "google-cloud-storage==3.1.0", +] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.black] +target-version = ['py39'] + +[tool.isort] +profile = "black" + +[tool.mypy] +python_version = "3.9" +warn_unused_configs = true +disallow_incomplete_defs = true diff --git a/packages/toolbox-llamaindex/requirements.txt b/packages/toolbox-llamaindex/requirements.txt new file mode 100644 index 00000000..a6ab8110 --- /dev/null +++ b/packages/toolbox-llamaindex/requirements.txt @@ -0,0 +1,5 @@ +llama-index==0.12.28 +PyYAML==6.0.2 +pydantic==2.11.2 +aiohttp==3.11.16 +deprecated==1.2.18 \ No newline at end of file diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/__init__.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/__init__.py new file mode 100644 index 00000000..5ff0058f --- /dev/null +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .client import ToolboxClient +from .tools import ToolboxTool + +__all__ = ["ToolboxClient", "ToolboxTool"] diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_client.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_client.py new file mode 100644 index 00000000..b65c8ccf --- /dev/null +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_client.py @@ -0,0 +1,171 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Optional, Union +from warnings import warn + +from aiohttp import ClientSession + +from .tools import AsyncToolboxTool +from .utils import ManifestSchema, _load_manifest + + +# This class is an internal implementation detail and is not exposed to the +# end-user. It should not be used directly by external code. Changes to this +# class will not be considered breaking changes to the public API. +class AsyncToolboxClient: + + def __init__( + self, + url: str, + session: ClientSession, + ): + """ + Initializes the AsyncToolboxClient for the Toolbox service at the given URL. + + Args: + url: The base URL of the Toolbox service. + session: An HTTP client session. + """ + self.__url = url + self.__session = session + + async def aload_tool( + self, + tool_name: str, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> AsyncToolboxTool: + """ + Loads the tool with the given tool name from the Toolbox service. + + Args: + tool_name: The name of the tool to load. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A tool loaded from the Toolbox. + """ + if auth_headers: + if auth_tokens: + warn( + "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", + DeprecationWarning, + ) + auth_tokens = auth_headers + + url = f"{self.__url}/api/tool/{tool_name}" + manifest: ManifestSchema = await _load_manifest(url, self.__session) + + return AsyncToolboxTool( + tool_name, + manifest.tools[tool_name], + self.__url, + self.__session, + auth_tokens, + bound_params, + strict, + ) + + async def aload_toolset( + self, + toolset_name: Optional[str] = None, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> list[AsyncToolboxTool]: + """ + Loads tools from the Toolbox service, optionally filtered by toolset + name. + + Args: + toolset_name: The name of the toolset to load. If not provided, + all tools are loaded. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A list of all tools loaded from the Toolbox. + """ + if auth_headers: + if auth_tokens: + warn( + "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", + DeprecationWarning, + ) + auth_tokens = auth_headers + + url = f"{self.__url}/api/toolset/{toolset_name or ''}" + manifest: ManifestSchema = await _load_manifest(url, self.__session) + tools: list[AsyncToolboxTool] = [] + + for tool_name, tool_schema in manifest.tools.items(): + tools.append( + AsyncToolboxTool( + tool_name, + tool_schema, + self.__url, + self.__session, + auth_tokens, + bound_params, + strict, + ) + ) + return tools + + def load_tool( + self, + tool_name: str, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> AsyncToolboxTool: + raise NotImplementedError("Synchronous methods not supported by async client.") + + def load_toolset( + self, + toolset_name: Optional[str] = None, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> list[AsyncToolboxTool]: + raise NotImplementedError("Synchronous methods not supported by async client.") diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_tools.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_tools.py new file mode 100644 index 00000000..879df74e --- /dev/null +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_tools.py @@ -0,0 +1,421 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Any, Callable, TypeVar, Union +from warnings import warn + +from aiohttp import ClientResponseError, ClientSession +from llama_index.core.tools import ToolMetadata +from llama_index.core.tools.types import AsyncBaseTool, ToolOutput + +from .utils import ( + ToolSchema, + _find_auth_params, + _find_bound_params, + _invoke_tool, + _schema_to_model, +) + +T = TypeVar("T") + + +# This class is an internal implementation detail and is not exposed to the +# end-user. It should not be used directly by external code. Changes to this +# class will not be considered breaking changes to the public API. +class AsyncToolboxTool(AsyncBaseTool): + """ + A subclass of LlamaIndex's AsyncBaseTool that supports features specific to + Toolbox, like bound parameters and authenticated tools. + """ + + def __init__( + self, + name: str, + schema: ToolSchema, + url: str, + session: ClientSession, + auth_tokens: dict[str, Callable[[], str]] = {}, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> None: + """ + Initializes an AsyncToolboxTool instance. + + Args: + name: The name of the tool. + schema: The tool schema. + url: The base URL of the Toolbox service. + session: The HTTP client session. + auth_tokens: A mapping of authentication source names to functions + that retrieve ID tokens. + bound_params: A mapping of parameter names to their bound + values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + """ + + # If the schema is not already a ToolSchema instance, we create one from + # its attributes. This allows flexibility in how the schema is provided, + # accepting both a ToolSchema object and a dictionary of schema + # attributes. + if not isinstance(schema, ToolSchema): + schema = ToolSchema(**schema) + + auth_params, non_auth_params = _find_auth_params(schema.parameters) + non_auth_bound_params, non_auth_non_bound_params = _find_bound_params( + non_auth_params, list(bound_params) + ) + + # Check if the user is trying to bind a param that is authenticated or + # is missing from the given schema. + auth_bound_params: list[str] = [] + missing_bound_params: list[str] = [] + for bound_param in bound_params: + if bound_param in [param.name for param in auth_params]: + auth_bound_params.append(bound_param) + elif bound_param not in [param.name for param in non_auth_params]: + missing_bound_params.append(bound_param) + + # Create error messages for any params that are found to be + # authenticated or missing. + messages: list[str] = [] + if auth_bound_params: + messages.append( + f"Parameter(s) {', '.join(auth_bound_params)} already authenticated and cannot be bound." + ) + if missing_bound_params: + messages.append( + f"Parameter(s) {', '.join(missing_bound_params)} missing and cannot be bound." + ) + + # Join any error messages and raise them as an error or warning, + # depending on the value of the strict flag. + if messages: + message = "\n\n".join(messages) + if strict: + raise ValueError(message) + warn(message) + + # Bind values for parameters present in the schema that don't require + # authentication. + bound_params = { + param_name: param_value + for param_name, param_value in bound_params.items() + if param_name in [param.name for param in non_auth_bound_params] + } + + # Update the tools schema to validate only the presence of parameters + # that neither require authentication nor are bound. + schema.parameters = non_auth_non_bound_params + + # Due to how pydantic works, we must initialize the underlying + # AsyncBaseTool class before assigning values to member variables. + super().__init__() + self.__name = name + self.__schema = schema + self.__url = url + self.__session = session + self.__auth_tokens = auth_tokens + self.__auth_params = auth_params + self.__bound_params = bound_params + + # Warn users about any missing authentication so they can add it before + # tool invocation. + self.__validate_auth(strict=False) + + @property + def metadata(self) -> ToolMetadata: + return ToolMetadata( + name=self.__name, + description=self.__schema.description, + fn_schema=_schema_to_model( + model_name=self.__name, schema=self.__schema.parameters + ), + ) + + def call(self, *args: Any, **kwargs: Any) -> ToolOutput: # type: ignore + raise NotImplementedError("Synchronous methods not supported by async tools.") + + async def acall(self, **kwargs: Any) -> ToolOutput: # type: ignore + """ + The coroutine that invokes the tool with the given arguments. + + Args: + kwargs: The arguments to the tool. + + Returns: + A dictionary containing the parsed JSON response from the tool + invocation. + """ + # Validate arguments with the schema + if self.metadata.fn_schema: + self.metadata.fn_schema.model_validate(kwargs) + + # If the tool had parameters that require authentication, then right + # before invoking that tool, we check whether all these required + # authentication sources have been registered or not. + self.__validate_auth() + + # Evaluate dynamic parameter values if any + evaluated_params = {} + for param_name, param_value in self.__bound_params.items(): + if callable(param_value): + evaluated_params[param_name] = param_value() + else: + evaluated_params[param_name] = param_value + + # Merge bound parameters with the provided arguments + kwargs.update(evaluated_params) + try: + response = await _invoke_tool( + self.__url, self.__session, self.__name, kwargs, self.__auth_tokens + ) + return ToolOutput( + content=str(response), + tool_name=self.__name, + raw_input=kwargs, + raw_output=response, + is_error=False, + ) + except ClientResponseError as e: + return ToolOutput( + content="Encountered error: " + str(e), + tool_name=self.__name, + raw_input=kwargs, + raw_output=str(e), + is_error=True, + ) + + def __validate_auth(self, strict: bool = True) -> None: + """ + Checks if a tool meets the authentication requirements. + + A tool is considered authenticated if all of its parameters meet at + least one of the following conditions: + + * The parameter has at least one registered authentication source. + * The parameter requires no authentication. + + Args: + strict: If True, raises a PermissionError if any required + authentication sources are not registered. If False, only issues + a warning. + + Raises: + PermissionError: If strict is True and any required authentication + sources are not registered. + """ + params_missing_auth: list[str] = [] + + # Check each parameter for at least 1 required auth source + for param in self.__auth_params: + if not param.authSources: + raise ValueError("Auth sources cannot be None.") + has_auth = False + for src in param.authSources: + + # Find first auth source that is specified + if src in self.__auth_tokens: + has_auth = True + break + if not has_auth: + params_missing_auth.append(param.name) + + if params_missing_auth: + message = f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use." + + if strict: + raise PermissionError(message) + warn(message) + + def __create_copy( + self, + *, + auth_tokens: dict[str, Callable[[], str]] = {}, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool, + ) -> "AsyncToolboxTool": + """ + Creates a copy of the current AsyncToolboxTool instance, allowing for + modification of auth tokens and bound params. + + This method enables the creation of new tool instances with inherited + properties from the current instance, while optionally updating the auth + tokens and bound params. This is useful for creating variations of the + tool with additional auth tokens or bound params without modifying the + original instance, ensuring immutability. + + Args: + auth_tokens: A dictionary of auth source names to functions that + retrieve ID tokens. These tokens will be merged with the + existing auth tokens. + bound_params: A dictionary of parameter names to their + bound values or functions to retrieve the values. These params + will be merged with the existing bound params. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A new AsyncToolboxTool instance that is a deep copy of the current + instance, with added auth tokens or bound params. + """ + new_schema = deepcopy(self.__schema) + + # Reconstruct the complete parameter schema by merging the auth + # parameters back with the non-auth parameters. This is necessary to + # accurately validate the new combination of auth tokens and bound + # params in the constructor of the new AsyncToolboxTool instance, ensuring + # that any overlaps or conflicts are correctly identified and reported + # as errors or warnings, depending on the given `strict` flag. + new_schema.parameters += self.__auth_params + return AsyncToolboxTool( + name=self.__name, + schema=new_schema, + url=self.__url, + session=self.__session, + auth_tokens={**self.__auth_tokens, **auth_tokens}, + bound_params={**self.__bound_params, **bound_params}, + strict=strict, + ) + + def add_auth_tokens( + self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + ) -> "AsyncToolboxTool": + """ + Registers functions to retrieve ID tokens for the corresponding + authentication sources. + + Args: + auth_tokens: A dictionary of authentication source names to the + functions that return corresponding ID token. + strict: If True, a ValueError is raised if any of the provided auth + tokens are already bound. If False, only a warning is issued. + + Returns: + A new AsyncToolboxTool instance that is a deep copy of the current + instance, with added auth tokens. + + Raises: + ValueError: If the provided auth tokens are already registered. + ValueError: If the provided auth tokens are already bound and strict + is True. + """ + + # Check if the authentication source is already registered. + dupe_tokens: list[str] = [] + for auth_token, _ in auth_tokens.items(): + if auth_token in self.__auth_tokens: + dupe_tokens.append(auth_token) + + if dupe_tokens: + raise ValueError( + f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`." + ) + + return self.__create_copy(auth_tokens=auth_tokens, strict=strict) + + def add_auth_token( + self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + ) -> "AsyncToolboxTool": + """ + Registers a function to retrieve an ID token for a given authentication + source. + + Args: + auth_source: The name of the authentication source. + get_id_token: A function that returns the ID token. + strict: If True, a ValueError is raised if any of the provided auth + token is already bound. If False, only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added auth token. + + Raises: + ValueError: If the provided auth token is already registered. + ValueError: If the provided auth token is already bound and strict + is True. + """ + return self.add_auth_tokens({auth_source: get_id_token}, strict=strict) + + def bind_params( + self, + bound_params: dict[str, Union[Any, Callable[[], Any]]], + strict: bool = True, + ) -> "AsyncToolboxTool": + """ + Registers values or functions to retrieve the value for the + corresponding bound parameters. + + Args: + bound_params: A dictionary of the bound parameter name to the + value or function of the bound value. + strict: If True, a ValueError is raised if any of the provided bound + params are not defined in the tool's schema, or require + authentication. If False, only a warning is issued. + + Returns: + A new AsyncToolboxTool instance that is a deep copy of the current + instance, with added bound params. + + Raises: + ValueError: If the provided bound params are already bound. + ValueError: if the provided bound params are not defined in the tool's schema, or require + authentication, and strict is True. + """ + + # Check if the parameter is already bound. + dupe_params: list[str] = [] + for param_name, _ in bound_params.items(): + if param_name in self.__bound_params: + dupe_params.append(param_name) + + if dupe_params: + raise ValueError( + f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name}`." + ) + + return self.__create_copy(bound_params=bound_params, strict=strict) + + def bind_param( + self, + param_name: str, + param_value: Union[Any, Callable[[], Any]], + strict: bool = True, + ) -> "AsyncToolboxTool": + """ + Registers a value or a function to retrieve the value for a given bound + parameter. + + Args: + param_name: The name of the bound parameter. + param_value: The value of the bound parameter, or a callable that + returns the value. + strict: If True, a ValueError is raised if any of the provided bound + params is not defined in the tool's schema, or requires + authentication. If False, only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added bound param. + + Raises: + ValueError: If the provided bound param is already bound. + ValueError: if the provided bound param is not defined in the tool's + schema, or requires authentication, and strict is True. + """ + return self.bind_params({param_name: param_value}, strict) diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/client.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/client.py new file mode 100644 index 00000000..5079beab --- /dev/null +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/client.py @@ -0,0 +1,237 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from threading import Thread +from typing import Any, Awaitable, Callable, Optional, TypeVar, Union + +from aiohttp import ClientSession + +from .async_client import AsyncToolboxClient +from .tools import ToolboxTool + +T = TypeVar("T") + + +class ToolboxClient: + __session: Optional[ClientSession] = None + __loop: Optional[asyncio.AbstractEventLoop] = None + __thread: Optional[Thread] = None + + def __init__( + self, + url: str, + ) -> None: + """ + Initializes the ToolboxClient for the Toolbox service at the given URL. + + Args: + url: The base URL of the Toolbox service. + """ + + # Running a loop in a background thread allows us to support async + # methods from non-async environments. + if ToolboxClient.__loop is None: + loop = asyncio.new_event_loop() + thread = Thread(target=loop.run_forever, daemon=True) + thread.start() + ToolboxClient.__thread = thread + ToolboxClient.__loop = loop + + async def __start_session() -> None: + + # Use a default session if none is provided. This leverages connection + # pooling for better performance by reusing a single session throughout + # the application's lifetime. + if ToolboxClient.__session is None: + ToolboxClient.__session = ClientSession() + + coro = __start_session() + + asyncio.run_coroutine_threadsafe(coro, ToolboxClient.__loop).result() + + if not ToolboxClient.__session: + raise ValueError("Session cannot be None.") + self.__async_client = AsyncToolboxClient(url, ToolboxClient.__session) + + def __run_as_sync(self, coro: Awaitable[T]) -> T: + """Run an async coroutine synchronously""" + if not self.__loop: + raise Exception( + "Cannot call synchronous methods before the background loop is initialized." + ) + return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() + + async def __run_as_async(self, coro: Awaitable[T]) -> T: + """Run an async coroutine asynchronously""" + + # If a loop has not been provided, attempt to run in current thread. + if not self.__loop: + return await coro + + # Otherwise, run in the background thread. + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self.__loop) + ) + + async def aload_tool( + self, + tool_name: str, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> ToolboxTool: + """ + Loads the tool with the given tool name from the Toolbox service. + + Args: + tool_name: The name of the tool to load. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A tool loaded from the Toolbox. + """ + async_tool = await self.__run_as_async( + self.__async_client.aload_tool( + tool_name, auth_tokens, auth_headers, bound_params, strict + ) + ) + + if not self.__loop or not self.__thread: + raise ValueError("Background loop or thread cannot be None.") + return ToolboxTool(async_tool, self.__loop, self.__thread) + + async def aload_toolset( + self, + toolset_name: Optional[str] = None, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> list[ToolboxTool]: + """ + Loads tools from the Toolbox service, optionally filtered by toolset + name. + + Args: + toolset_name: The name of the toolset to load. If not provided, + all tools are loaded. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A list of all tools loaded from the Toolbox. + """ + async_tools = await self.__run_as_async( + self.__async_client.aload_toolset( + toolset_name, auth_tokens, auth_headers, bound_params, strict + ) + ) + + tools: list[ToolboxTool] = [] + + if not self.__loop or not self.__thread: + raise ValueError("Background loop or thread cannot be None.") + for async_tool in async_tools: + tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) + return tools + + def load_tool( + self, + tool_name: str, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> ToolboxTool: + """ + Loads the tool with the given tool name from the Toolbox service. + + Args: + tool_name: The name of the tool to load. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A tool loaded from the Toolbox. + """ + async_tool = self.__run_as_sync( + self.__async_client.aload_tool( + tool_name, auth_tokens, auth_headers, bound_params, strict + ) + ) + + if not self.__loop or not self.__thread: + raise ValueError("Background loop or thread cannot be None.") + return ToolboxTool(async_tool, self.__loop, self.__thread) + + def load_toolset( + self, + toolset_name: Optional[str] = None, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> list[ToolboxTool]: + """ + Loads tools from the Toolbox service, optionally filtered by toolset + name. + + Args: + toolset_name: The name of the toolset to load. If not provided, + all tools are loaded. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A list of all tools loaded from the Toolbox. + """ + async_tools = self.__run_as_sync( + self.__async_client.aload_toolset( + toolset_name, auth_tokens, auth_headers, bound_params, strict + ) + ) + + if not self.__loop or not self.__thread: + raise ValueError("Background loop or thread cannot be None.") + tools: list[ToolboxTool] = [] + for async_tool in async_tools: + tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) + return tools diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/py.typed b/packages/toolbox-llamaindex/src/toolbox_llamaindex/py.typed new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/py.typed @@ -0,0 +1 @@ + diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py new file mode 100644 index 00000000..00690dca --- /dev/null +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py @@ -0,0 +1,210 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from asyncio import AbstractEventLoop +from threading import Thread +from typing import Any, Awaitable, Callable, TypeVar, Union + +from llama_index.core.tools import ToolMetadata +from llama_index.core.tools.types import AsyncBaseTool, ToolOutput + +from .async_tools import AsyncToolboxTool + +T = TypeVar("T") + + +class ToolboxTool(AsyncBaseTool): + """ + A subclass of LlamaIndex's AsyncBaseTool that supports features specific to + Toolbox, like bound parameters and authenticated tools. + """ + + def __init__( + self, + async_tool: AsyncToolboxTool, + loop: AbstractEventLoop, + thread: Thread, + ) -> None: + """ + Initializes a ToolboxTool instance. + + Args: + async_tool: The underlying AsyncToolboxTool instance. + loop: The event loop used to run asynchronous tasks. + thread: The thread to run blocking operations in. + """ + + # Due to how pydantic works, we must initialize the underlying + # AsyncBaseTool class before assigning values to member variables. + super().__init__() + + self.__async_tool = async_tool + self.__loop = loop + self.__thread = thread + + def __run_as_sync(self, coro: Awaitable[T]) -> T: + """Run an async coroutine synchronously""" + if not self.__loop: + raise Exception( + "Cannot call synchronous methods before the background loop is initialized." + ) + return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() + + async def __run_as_async(self, coro: Awaitable[T]) -> T: + """Run an async coroutine asynchronously""" + + # If a loop has not been provided, attempt to run in current thread. + if not self.__loop: + return await coro + + # Otherwise, run in the background thread. + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self.__loop) + ) + + @property + def metadata(self) -> ToolMetadata: + async_tool = self.__async_tool + return ToolMetadata( + name=async_tool.metadata.name, + description=async_tool.metadata.description, + fn_schema=async_tool.metadata.fn_schema, + ) + + def call(self, **kwargs: Any) -> ToolOutput: # type: ignore + return self.__run_as_sync(self.__async_tool.acall(**kwargs)) + + async def acall(self, **kwargs: Any) -> ToolOutput: # type: ignore + return await self.__run_as_async(self.__async_tool.acall(**kwargs)) + + def add_auth_tokens( + self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + ) -> "ToolboxTool": + """ + Registers functions to retrieve ID tokens for the corresponding + authentication sources. + + Args: + auth_tokens: A dictionary of authentication source names to the + functions that return corresponding ID token. + strict: If True, a ValueError is raised if any of the provided auth + tokens are already bound. If False, only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added auth tokens. + + Raises: + ValueError: If the provided auth tokens are already registered. + ValueError: If the provided auth tokens are already bound and strict + is True. + """ + return ToolboxTool( + self.__async_tool.add_auth_tokens(auth_tokens, strict), + self.__loop, + self.__thread, + ) + + def add_auth_token( + self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + ) -> "ToolboxTool": + """ + Registers a function to retrieve an ID token for a given authentication + source. + + Args: + auth_source: The name of the authentication source. + get_id_token: A function that returns the ID token. + strict: If True, a ValueError is raised if any of the provided auth + token is already bound. If False, only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added auth token. + + Raises: + ValueError: If the provided auth token is already registered. + ValueError: If the provided auth token is already bound and strict + is True. + """ + return ToolboxTool( + self.__async_tool.add_auth_token(auth_source, get_id_token, strict), + self.__loop, + self.__thread, + ) + + def bind_params( + self, + bound_params: dict[str, Union[Any, Callable[[], Any]]], + strict: bool = True, + ) -> "ToolboxTool": + """ + Registers values or functions to retrieve the value for the + corresponding bound parameters. + + Args: + bound_params: A dictionary of the bound parameter name to the + value or function of the bound value. + strict: If True, a ValueError is raised if any of the provided bound + params are not defined in the tool's schema, or require + authentication. If False, only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added bound params. + + Raises: + ValueError: If the provided bound params are already bound. + ValueError: if the provided bound params are not defined in the tool's schema, or require + authentication, and strict is True. + """ + return ToolboxTool( + self.__async_tool.bind_params(bound_params, strict), + self.__loop, + self.__thread, + ) + + def bind_param( + self, + param_name: str, + param_value: Union[Any, Callable[[], Any]], + strict: bool = True, + ) -> "ToolboxTool": + """ + Registers a value or a function to retrieve the value for a given bound + parameter. + + Args: + param_name: The name of the bound parameter. + param_value: The value of the bound parameter, or a callable that + returns the value. + strict: If True, a ValueError is raised if any of the provided bound + params is not defined in the tool's schema, or requires + authentication. If False, only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added bound param. + + Raises: + ValueError: If the provided bound param is already bound. + ValueError: if the provided bound param is not defined in the tool's + schema, or requires authentication, and strict is True. + """ + return ToolboxTool( + self.__async_tool.bind_param(param_name, param_value, strict), + self.__loop, + self.__thread, + ) diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/utils.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/utils.py new file mode 100644 index 00000000..54c55e30 --- /dev/null +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/utils.py @@ -0,0 +1,262 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, Callable, Optional, Type, cast +from warnings import warn + +from aiohttp import ClientSession +from deprecated import deprecated +from pydantic import BaseModel, Field, create_model + + +class ParameterSchema(BaseModel): + """ + Schema for a tool parameter. + """ + + name: str + type: str + description: str + authSources: Optional[list[str]] = None + items: Optional["ParameterSchema"] = None + + +class ToolSchema(BaseModel): + """ + Schema for a tool. + """ + + description: str + parameters: list[ParameterSchema] + + +class ManifestSchema(BaseModel): + """ + Schema for the Toolbox manifest. + """ + + serverVersion: str + tools: dict[str, ToolSchema] + + +async def _load_manifest(url: str, session: ClientSession) -> ManifestSchema: + """ + Asynchronously fetches and parses the JSON manifest schema from the given + URL. + + Args: + url: The URL to fetch the JSON from. + session: The HTTP client session. + + Returns: + The parsed Toolbox manifest. + + Raises: + json.JSONDecodeError: If the response is not valid JSON. + ValueError: If the response is not a valid manifest. + """ + async with session.get(url) as response: + # TODO: Remove as it masks error messages. + response.raise_for_status() + try: + # TODO: Simply use response.json() + parsed_json = json.loads(await response.text()) + except json.JSONDecodeError as e: + raise json.JSONDecodeError( + f"Failed to parse JSON from {url}: {e}", e.doc, e.pos + ) from e + try: + return ManifestSchema(**parsed_json) + except ValueError as e: + raise ValueError(f"Invalid JSON data from {url}: {e}") from e + + +def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[BaseModel]: + """ + Converts the given manifest schema to a Pydantic BaseModel class. + + Args: + model_name: The name of the model to create. + schema: The schema to convert. + + Returns: + A Pydantic BaseModel class. + """ + field_definitions = {} + for field in schema: + field_definitions[field.name] = cast( + Any, + ( + _parse_type(field), + Field(description=field.description), + ), + ) + + return create_model(model_name, **field_definitions) + + +def _parse_type(schema_: ParameterSchema) -> Any: + """ + Converts a schema type to a JSON type. + + Args: + schema_: The ParameterSchema to convert. + + Returns: + A valid JSON type. + + Raises: + ValueError: If the given type is not supported. + """ + type_ = schema_.type + + if type_ == "string": + return str + elif type_ == "integer": + return int + elif type_ == "float": + return float + elif type_ == "boolean": + return bool + elif type_ == "array": + if isinstance(schema_, ParameterSchema) and schema_.items: + return list[_parse_type(schema_.items)] # type: ignore + else: + raise ValueError(f"Schema missing field items") + else: + raise ValueError(f"Unsupported schema type: {type_}") + + +@deprecated("Please use `_get_auth_tokens` instead.") +def _get_auth_headers(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: + """ + Deprecated. Use `_get_auth_tokens` instead. + """ + return _get_auth_tokens(id_token_getters) + + +def _get_auth_tokens(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: + """ + Gets ID tokens for the given auth sources in the getters map and returns + tokens to be included in tool invocation. + + Args: + id_token_getters: A dict that maps auth source names to the functions + that return its ID token. + + Returns: + A dictionary of tokens to be included in the tool invocation. + """ + auth_tokens = {} + for auth_source, get_id_token in id_token_getters.items(): + auth_tokens[f"{auth_source}_token"] = get_id_token() + return auth_tokens + + +async def _invoke_tool( + url: str, + session: ClientSession, + tool_name: str, + data: dict, + id_token_getters: dict[str, Callable[[], str]], +) -> dict: + """ + Asynchronously makes an API call to the Toolbox service to invoke a tool. + + Args: + url: The base URL of the Toolbox service. + session: The HTTP client session. + tool_name: The name of the tool to invoke. + data: The input data for the tool. + id_token_getters: A dict that maps auth source names to the functions + that return its ID token. + + Returns: + A dictionary containing the parsed JSON response from the tool + invocation. + """ + url = f"{url}/api/tool/{tool_name}/invoke" + auth_tokens = _get_auth_tokens(id_token_getters) + + # ID tokens contain sensitive user information (claims). Transmitting these + # over HTTP exposes the data to interception and unauthorized access. Always + # use HTTPS to ensure secure communication and protect user privacy. + if auth_tokens and not url.startswith("https://"): + warn( + "Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication." + ) + + async with session.post( + url, + json=data, + headers=auth_tokens, + ) as response: + # TODO: Remove as it masks error messages. + response.raise_for_status() + return await response.json() + + +def _find_auth_params( + params: list[ParameterSchema], +) -> tuple[list[ParameterSchema], list[ParameterSchema]]: + """ + Separates parameters into those that are authenticated and those that are not. + + Args: + params: A list of ParameterSchema objects. + + Returns: + A tuple containing two lists: + - auth_params: A list of ParameterSchema objects that require authentication. + - non_auth_params: A list of ParameterSchema objects that do not require authentication. + """ + _auth_params: list[ParameterSchema] = [] + _non_auth_params: list[ParameterSchema] = [] + + for param in params: + if param.authSources: + _auth_params.append(param) + else: + _non_auth_params.append(param) + + return (_auth_params, _non_auth_params) + + +def _find_bound_params( + params: list[ParameterSchema], bound_params: list[str] +) -> tuple[list[ParameterSchema], list[ParameterSchema]]: + """ + Separates parameters into those that are bound and those that are not. + + Args: + params: A list of ParameterSchema objects. + bound_params: A list of parameter names that are bound. + + Returns: + A tuple containing two lists: + - bound_params: A list of ParameterSchema objects whose names are in the bound_params list. + - non_bound_params: A list of ParameterSchema objects whose names are not in the bound_params list. + """ + + _bound_params: list[ParameterSchema] = [] + _non_bound_params: list[ParameterSchema] = [] + + for param in params: + if param.name in bound_params: + _bound_params.append(param) + else: + _non_bound_params.append(param) + + return (_bound_params, _non_bound_params) diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/version.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/version.py new file mode 100644 index 00000000..5ff95198 --- /dev/null +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/version.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "0.1.1" diff --git a/packages/toolbox-llamaindex/tests/conftest.py b/packages/toolbox-llamaindex/tests/conftest.py new file mode 100644 index 00000000..231ef349 --- /dev/null +++ b/packages/toolbox-llamaindex/tests/conftest.py @@ -0,0 +1,166 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains pytest fixtures that are accessible from all +files present in the same directory.""" + +from __future__ import annotations + +import os +import platform +import subprocess +import tempfile +import time +from typing import Generator + +import google +import pytest_asyncio +from google.auth import compute_engine +from google.cloud import secretmanager, storage + + +#### Define Utility Functions +def get_env_var(key: str) -> str: + """Gets environment variables.""" + value = os.environ.get(key) + if value is None: + raise ValueError(f"Must set env var {key}") + return value + + +def access_secret_version( + project_id: str, secret_id: str, version_id: str = "latest" +) -> str: + """Accesses the payload of a given secret version from Secret Manager.""" + client = secretmanager.SecretManagerServiceClient() + name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}" + response = client.access_secret_version(request={"name": name}) + return response.payload.data.decode("UTF-8") + + +def create_tmpfile(content: str) -> str: + """Creates a temporary file with the given content.""" + with tempfile.NamedTemporaryFile(delete=False, mode="w") as tmpfile: + tmpfile.write(content) + return tmpfile.name + + +def download_blob( + bucket_name: str, source_blob_name: str, destination_file_name: str +) -> None: + """Downloads a blob from a GCS bucket.""" + storage_client = storage.Client() + + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(source_blob_name) + blob.download_to_filename(destination_file_name) + + print(f"Blob {source_blob_name} downloaded to {destination_file_name}.") + + +def get_toolbox_binary_url(toolbox_version: str) -> str: + """Constructs the GCS path to the toolbox binary.""" + os_system = platform.system().lower() + arch = ( + "arm64" if os_system == "darwin" and platform.machine() == "arm64" else "amd64" + ) + return f"v{toolbox_version}/{os_system}/{arch}/toolbox" + + +def get_auth_token(client_id: str) -> str: + """Retrieves an authentication token""" + request = google.auth.transport.requests.Request() + credentials = compute_engine.IDTokenCredentials( + request=request, + target_audience=client_id, + use_metadata_identity_endpoint=True, + ) + if not credentials.valid: + credentials.refresh(request) + return credentials.token + + +#### Define Fixtures +@pytest_asyncio.fixture(scope="session") +def project_id() -> str: + return get_env_var("GOOGLE_CLOUD_PROJECT") + + +@pytest_asyncio.fixture(scope="session") +def toolbox_version() -> str: + return get_env_var("TOOLBOX_VERSION") + + +@pytest_asyncio.fixture(scope="session") +def tools_file_path(project_id: str) -> Generator[str]: + """Provides a temporary file path containing the tools manifest.""" + tools_manifest = access_secret_version( + project_id=project_id, secret_id="sdk_testing_tools" + ) + tools_file_path = create_tmpfile(tools_manifest) + yield tools_file_path + os.remove(tools_file_path) + + +@pytest_asyncio.fixture(scope="session") +def auth_token1(project_id: str) -> str: + client_id = access_secret_version( + project_id=project_id, secret_id="sdk_testing_client1" + ) + return get_auth_token(client_id) + + +@pytest_asyncio.fixture(scope="session") +def auth_token2(project_id: str) -> str: + client_id = access_secret_version( + project_id=project_id, secret_id="sdk_testing_client2" + ) + return get_auth_token(client_id) + + +@pytest_asyncio.fixture(scope="session") +def toolbox_server(toolbox_version: str, tools_file_path: str) -> Generator[None]: + """Starts the toolbox server as a subprocess.""" + print("Downloading toolbox binary from gcs bucket...") + source_blob_name = get_toolbox_binary_url(toolbox_version) + download_blob("genai-toolbox", source_blob_name, "toolbox") + print("Toolbox binary downloaded successfully.") + try: + print("Opening toolbox server process...") + # Make toolbox executable + os.chmod("toolbox", 0o700) + # Run toolbox binary + toolbox_server = subprocess.Popen( + ["./toolbox", "--tools_file", tools_file_path] + ) + + # Wait for server to start + # Retry logic with a timeout + for _ in range(5): # retries + time.sleep(4) + print("Checking if toolbox is successfully started...") + if toolbox_server.poll() is None: + print("Toolbox server started successfully.") + break + else: + raise RuntimeError("Toolbox server failed to start after 5 retries.") + except subprocess.CalledProcessError as e: + print(e.stderr.decode("utf-8")) + print(e.stdout.decode("utf-8")) + raise RuntimeError(f"{e}\n\n{e.stderr.decode('utf-8')}") from e + yield + + # Clean up toolbox server + toolbox_server.terminate() + toolbox_server.wait() diff --git a/packages/toolbox-llamaindex/tests/test_async_client.py b/packages/toolbox-llamaindex/tests/test_async_client.py new file mode 100644 index 00000000..cdfd2cbc --- /dev/null +++ b/packages/toolbox-llamaindex/tests/test_async_client.py @@ -0,0 +1,194 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from unittest.mock import AsyncMock, patch +from warnings import catch_warnings, simplefilter + +import pytest +from aiohttp import ClientSession + +from toolbox_llamaindex.async_client import AsyncToolboxClient +from toolbox_llamaindex.async_tools import AsyncToolboxTool +from toolbox_llamaindex.utils import ManifestSchema + +URL = "http://test_url" +MANIFEST_JSON = { + "serverVersion": "1.0.0", + "tools": { + "test_tool_1": { + "description": "Test Tool 1 Description", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "Param 1", + } + ], + }, + "test_tool_2": { + "description": "Test Tool 2 Description", + "parameters": [ + { + "name": "param2", + "type": "integer", + "description": "Param 2", + } + ], + }, + }, +} + + +@pytest.mark.asyncio +class TestAsyncToolboxClient: + @pytest.fixture() + def manifest_schema(self): + return ManifestSchema(**MANIFEST_JSON) + + @pytest.fixture() + def mock_session(self): + return AsyncMock(spec=ClientSession) + + @pytest.fixture() + def mock_client(self, mock_session): + return AsyncToolboxClient(URL, session=mock_session) + + async def test_create_with_existing_session(self, mock_client, mock_session): + assert mock_client._AsyncToolboxClient__session == mock_session + + @patch("toolbox_llamaindex.async_client._load_manifest") + async def test_aload_tool( + self, mock_load_manifest, mock_client, mock_session, manifest_schema + ): + tool_name = "test_tool_1" + mock_load_manifest.return_value = manifest_schema + + tool = await mock_client.aload_tool(tool_name) + + mock_load_manifest.assert_called_once_with( + f"{URL}/api/tool/{tool_name}", mock_session + ) + assert isinstance(tool, AsyncToolboxTool) + assert tool._AsyncToolboxTool__name == tool_name + + @patch("toolbox_llamaindex.async_client._load_manifest") + async def test_aload_tool_auth_headers_deprecated( + self, mock_load_manifest, mock_client, manifest_schema + ): + tool_name = "test_tool_1" + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, auth_headers={"Authorization": lambda: "Bearer token"} + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + @patch("toolbox_llamaindex.async_client._load_manifest") + async def test_aload_tool_auth_headers_and_tokens( + self, mock_load_manifest, mock_client, manifest_schema + ): + tool_name = "test_tool_1" + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, + auth_headers={"Authorization": lambda: "Bearer token"}, + auth_tokens={"test": lambda: "token"}, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + @patch("toolbox_llamaindex.async_client._load_manifest") + async def test_aload_toolset( + self, mock_load_manifest, mock_client, mock_session, manifest_schema + ): + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + tools = await mock_client.aload_toolset() + + mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session) + assert len(tools) == 2 + for tool in tools: + assert isinstance(tool, AsyncToolboxTool) + assert tool._AsyncToolboxTool__name in ["test_tool_1", "test_tool_2"] + + @patch("toolbox_llamaindex.async_client._load_manifest") + async def test_aload_toolset_with_toolset_name( + self, mock_load_manifest, mock_client, mock_session, manifest_schema + ): + toolset_name = "test_toolset_1" + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + tools = await mock_client.aload_toolset(toolset_name=toolset_name) + + mock_load_manifest.assert_called_once_with( + f"{URL}/api/toolset/{toolset_name}", mock_session + ) + assert len(tools) == 2 + for tool in tools: + assert isinstance(tool, AsyncToolboxTool) + assert tool._AsyncToolboxTool__name in ["test_tool_1", "test_tool_2"] + + @patch("toolbox_llamaindex.async_client._load_manifest") + async def test_aload_toolset_auth_headers_deprecated( + self, mock_load_manifest, mock_client, manifest_schema + ): + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset( + auth_headers={"Authorization": lambda: "Bearer token"} + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + @patch("toolbox_llamaindex.async_client._load_manifest") + async def test_aload_toolset_auth_headers_and_tokens( + self, mock_load_manifest, mock_client, manifest_schema + ): + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset( + auth_headers={"Authorization": lambda: "Bearer token"}, + auth_tokens={"test": lambda: "token"}, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + async def test_load_tool_not_implemented(self, mock_client): + with pytest.raises(NotImplementedError) as excinfo: + mock_client.load_tool("test_tool") + assert "Synchronous methods not supported by async client." in str( + excinfo.value + ) + + async def test_load_toolset_not_implemented(self, mock_client): + with pytest.raises(NotImplementedError) as excinfo: + mock_client.load_toolset() + assert "Synchronous methods not supported by async client." in str( + excinfo.value + ) diff --git a/packages/toolbox-llamaindex/tests/test_async_tools.py b/packages/toolbox-llamaindex/tests/test_async_tools.py new file mode 100644 index 00000000..16b891e5 --- /dev/null +++ b/packages/toolbox-llamaindex/tests/test_async_tools.py @@ -0,0 +1,270 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +import pytest_asyncio +from pydantic import ValidationError + +from toolbox_llamaindex.async_tools import AsyncToolboxTool + + +@pytest.mark.asyncio +class TestAsyncToolboxTool: + @pytest.fixture + def tool_schema(self): + return { + "description": "Test Tool Description", + "parameters": [ + {"name": "param1", "type": "string", "description": "Param 1"}, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest.fixture + def auth_tool_schema(self): + return { + "description": "Test Tool Description", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "Param 1", + "authSources": ["test-auth-source"], + }, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest_asyncio.fixture + @patch("aiohttp.ClientSession") + async def toolbox_tool(self, mock_client_session, tool_schema): + mock_session = mock_client_session.return_value + mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() + mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"result": "test-result"} + ) + tool = AsyncToolboxTool( + name="test_tool", + schema=tool_schema, + url="http://test_url", + session=mock_session, + ) + return tool + + @pytest_asyncio.fixture + @patch("aiohttp.ClientSession") + async def auth_toolbox_tool(self, mock_client_session, auth_tool_schema): + mock_session = mock_client_session.return_value + mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() + mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"result": "test-result"} + ) + with pytest.warns( + UserWarning, + match=r"Parameter\(s\) `param1` of tool test_tool require authentication", + ): + tool = AsyncToolboxTool( + name="test_tool", + schema=auth_tool_schema, + url="https://test-url", + session=mock_session, + ) + return tool + + @patch("aiohttp.ClientSession") + async def test_toolbox_tool_init(self, mock_client_session, tool_schema): + mock_session = mock_client_session.return_value + tool = AsyncToolboxTool( + name="test_tool", + schema=tool_schema, + url="https://test-url", + session=mock_session, + ) + assert tool.metadata.name == "test_tool" + assert tool.metadata.description == "Test Tool Description" + + @pytest.mark.parametrize( + "params, expected_bound_params", + [ + ({"param1": "bound-value"}, {"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), + ( + {"param1": "bound-value", "param2": 123}, + {"param1": "bound-value", "param2": 123}, + ), + ], + ) + async def test_toolbox_tool_bind_params( + self, toolbox_tool, params, expected_bound_params + ): + tool = toolbox_tool.bind_params(params) + for key, value in expected_bound_params.items(): + if callable(value): + assert value() == tool._AsyncToolboxTool__bound_params[key]() + else: + assert value == tool._AsyncToolboxTool__bound_params[key] + + @pytest.mark.parametrize("strict", [True, False]) + async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool, strict): + if strict: + with pytest.raises(ValueError) as e: + tool = toolbox_tool.bind_params( + {"param3": "bound-value"}, strict=strict + ) + assert "Parameter(s) param3 missing and cannot be bound." in str(e.value) + else: + with pytest.warns(UserWarning) as record: + tool = toolbox_tool.bind_params( + {"param3": "bound-value"}, strict=strict + ) + assert len(record) == 1 + assert "Parameter(s) param3 missing and cannot be bound." in str( + record[0].message + ) + + async def test_toolbox_tool_bind_params_duplicate(self, toolbox_tool): + tool = toolbox_tool.bind_params({"param1": "bound-value"}) + with pytest.raises(ValueError) as e: + tool = tool.bind_params({"param1": "bound-value"}) + assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( + e.value + ) + + async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): + with pytest.raises(ValueError) as e: + auth_toolbox_tool.bind_params({"param1": "bound-value"}) + assert "Parameter(s) param1 already authenticated and cannot be bound." in str( + e.value + ) + + @pytest.mark.parametrize( + "auth_tokens, expected_auth_tokens", + [ + ( + {"test-auth-source": lambda: "test-token"}, + {"test-auth-source": lambda: "test-token"}, + ), + ( + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + ), + ], + ) + async def test_toolbox_tool_add_auth_tokens( + self, auth_toolbox_tool, auth_tokens, expected_auth_tokens + ): + tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) + for source, getter in expected_auth_tokens.items(): + assert tool._AsyncToolboxTool__auth_tokens[source]() == getter() + + async def test_toolbox_tool_add_auth_tokens_duplicate(self, auth_toolbox_tool): + tool = auth_toolbox_tool.add_auth_tokens( + {"test-auth-source": lambda: "test-token"} + ) + with pytest.raises(ValueError) as e: + tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) + assert ( + "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." + in str(e.value) + ) + + async def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): + with pytest.raises(PermissionError) as e: + auth_toolbox_tool._AsyncToolboxTool__validate_auth(strict=True) + assert "Parameter(s) `param1` of tool test_tool require authentication" in str( + e.value + ) + + async def test_toolbox_tool_call(self, toolbox_tool): + result = await toolbox_tool.acall(param1="test-value", param2=123) + assert result.content == str({"result": "test-result"}) + toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "http://test_url/api/tool/test_tool/invoke", + json={"param1": "test-value", "param2": 123}, + headers={}, + ) + + @pytest.mark.parametrize( + "bound_param, expected_value", + [ + ({"param1": "bound-value"}, "bound-value"), + ({"param1": lambda: "dynamic-value"}, "dynamic-value"), + ], + ) + async def test_toolbox_tool_call_with_bound_params( + self, toolbox_tool, bound_param, expected_value + ): + tool = toolbox_tool.bind_params(bound_param) + result = await tool.acall(param2=123) + assert result.content == str({"result": "test-result"}) + toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "http://test_url/api/tool/test_tool/invoke", + json={"param1": expected_value, "param2": 123}, + headers={}, + ) + + async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): + tool = auth_toolbox_tool.add_auth_tokens( + {"test-auth-source": lambda: "test-token"} + ) + result = await tool.acall(param2=123) + assert result.content == str({"result": "test-result"}) + auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "https://test-url/api/tool/test_tool/invoke", + json={"param2": 123}, + headers={"test-auth-source_token": "test-token"}, + ) + + async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_tool): + with pytest.warns( + UserWarning, + match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", + ): + auth_toolbox_tool._AsyncToolboxTool__url = "http://test-url" + tool = auth_toolbox_tool.add_auth_tokens( + {"test-auth-source": lambda: "test-token"} + ) + result = await tool.acall(param2=123) + assert result.content == str({"result": "test-result"}) + auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "http://test-url/api/tool/test_tool/invoke", + json={"param2": 123}, + headers={"test-auth-source_token": "test-token"}, + ) + + async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool): + with pytest.raises(ValidationError) as e: + await toolbox_tool.acall(param1=123, param2="invalid") + assert "2 validation errors for test_tool" in str(e.value) + assert "param1\n Input should be a valid string" in str(e.value) + assert "param2\n Input should be a valid integer" in str(e.value) + + async def test_toolbox_tool_call_with_empty_input(self, toolbox_tool): + with pytest.raises(ValidationError) as e: + await toolbox_tool.acall() + assert "2 validation errors for test_tool" in str(e.value) + assert "param1\n Field required" in str(e.value) + assert "param2\n Field required" in str(e.value) + + async def test_toolbox_tool_run_not_implemented(self, toolbox_tool): + with pytest.raises(NotImplementedError): + toolbox_tool.call() diff --git a/packages/toolbox-llamaindex/tests/test_client.py b/packages/toolbox-llamaindex/tests/test_client.py new file mode 100644 index 00000000..842dae22 --- /dev/null +++ b/packages/toolbox-llamaindex/tests/test_client.py @@ -0,0 +1,305 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock, patch + +import pytest +from pydantic import BaseModel + +from toolbox_llamaindex.async_tools import AsyncToolboxTool +from toolbox_llamaindex.client import ToolboxClient +from toolbox_llamaindex.tools import ToolboxTool +from toolbox_llamaindex.utils import _schema_to_model + +URL = "http://test_url" + + +class TestToolboxClient: + @pytest.fixture + def tool_schema(self): + return { + "description": "Test Tool Description", + "parameters": [ + {"name": "param1", "type": "string", "description": "Param 1"}, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest.fixture() + def toolbox_client(self): + client = ToolboxClient(URL) + assert isinstance(client, ToolboxClient) + assert client._ToolboxClient__async_client is not None + + # Check that the background loop was created and started + assert client._ToolboxClient__loop is not None + assert client._ToolboxClient__loop.is_running() + + return client + + @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) + @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_tool") + def test_load_tool( + self, mock_aload_tool, mock_toolbox_tool_init, toolbox_client, tool_schema + ): + mock_async_tool = Mock(spec=AsyncToolboxTool) + mock_async_tool._AsyncToolboxTool__name = "mock-tool" # Access the mangled name + mock_async_tool._AsyncToolboxTool__schema = ( + tool_schema # Access the mangled name + ) + mock_aload_tool.return_value = mock_async_tool + + tool = toolbox_client.load_tool("test_tool") + mock_toolbox_tool_init.assert_called_once_with( + mock_async_tool, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + + assert ( + tool_schema["description"] + == mock_async_tool._AsyncToolboxTool__schema["description"] + ) + mock_aload_tool.assert_called_once_with("test_tool", {}, None, {}, True) + + @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) + @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_toolset") + def test_load_toolset( + self, mock_aload_toolset, mock_toolbox_tool_init, toolbox_client, tool_schema + ): + mock_async_tool1 = Mock(spec=AsyncToolboxTool) + mock_async_tool1._AsyncToolboxTool__name = "mock-tool-0" + mock_async_tool1._AsyncToolboxTool__schema = tool_schema + + mock_async_tool2 = Mock(spec=AsyncToolboxTool) + mock_async_tool2._AsyncToolboxTool__name = "mock-tool-1" + mock_async_tool2._AsyncToolboxTool__schema = tool_schema + mock_aload_toolset.return_value = [mock_async_tool1, mock_async_tool2] + + tools = toolbox_client.load_toolset() + assert len(tools) == 2 + mock_toolbox_tool_init.assert_any_call( + mock_async_tool1, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + mock_toolbox_tool_init.assert_any_call( + mock_async_tool2, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + + mock_aload_toolset.assert_called_once_with(None, {}, None, {}, True) + + @pytest.mark.asyncio + @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) + @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_tool") + async def test_aload_tool( + self, mock_aload_tool, mock_toolbox_tool_init, toolbox_client, tool_schema + ): + mock_async_tool = Mock(spec=AsyncToolboxTool) + mock_async_tool._AsyncToolboxTool__name = "mock-tool" # Access mangled name + mock_async_tool._AsyncToolboxTool__schema = tool_schema + mock_aload_tool.return_value = mock_async_tool + + tool = await toolbox_client.aload_tool("test_tool") + mock_toolbox_tool_init.assert_called_once_with( + mock_async_tool, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + + assert ( + tool_schema["description"] + == mock_async_tool._AsyncToolboxTool__schema["description"] + ) + mock_aload_tool.assert_called_once_with("test_tool", {}, None, {}, True) + + @pytest.mark.asyncio + @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) + @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_toolset") + async def test_aload_toolset( + self, mock_aload_toolset, mock_toolbox_tool_init, toolbox_client, tool_schema + ): + mock_async_tool1 = Mock(spec=AsyncToolboxTool) + mock_async_tool1._AsyncToolboxTool__name = "mock-tool-0" + mock_async_tool1._AsyncToolboxTool__schema = tool_schema + + mock_async_tool2 = Mock(spec=AsyncToolboxTool) + mock_async_tool2._AsyncToolboxTool__name = "mock-tool-1" + mock_async_tool2._AsyncToolboxTool__schema = tool_schema + + mock_aload_toolset.return_value = [mock_async_tool1, mock_async_tool2] + + tools = await toolbox_client.aload_toolset() + assert len(tools) == 2 + mock_toolbox_tool_init.assert_any_call( + mock_async_tool1, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + mock_toolbox_tool_init.assert_any_call( + mock_async_tool2, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + mock_aload_toolset.assert_called_once_with(None, {}, None, {}, True) + + @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) + @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_tool") + def test_load_tool_with_args( + self, mock_aload_tool, mock_toolbox_tool_init, toolbox_client, tool_schema + ): + mock_async_tool = Mock(spec=AsyncToolboxTool) + mock_async_tool._AsyncToolboxTool__name = "mock-tool" + mock_async_tool._AsyncToolboxTool__schema = tool_schema + mock_aload_tool.return_value = mock_async_tool + + auth_tokens = {"token1": lambda: "value1"} + auth_headers = {"header1": lambda: "value2"} + bound_params = {"param1": "value3"} + + tool = toolbox_client.load_tool( + "test_tool_name", + auth_tokens=auth_tokens, + auth_headers=auth_headers, + bound_params=bound_params, + strict=False, + ) + mock_toolbox_tool_init.assert_called_once_with( + mock_async_tool, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + + assert ( + tool_schema["description"] + == mock_async_tool._AsyncToolboxTool__schema["description"] + ) + mock_aload_tool.assert_called_once_with( + "test_tool_name", auth_tokens, auth_headers, bound_params, False + ) + + @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) + @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_toolset") + def test_load_toolset_with_args( + self, mock_aload_toolset, mock_toolbox_tool_init, toolbox_client, tool_schema + ): + mock_async_tool1 = Mock(spec=AsyncToolboxTool) + mock_async_tool1._AsyncToolboxTool__name = "mock-tool-0" + mock_async_tool1._AsyncToolboxTool__schema = tool_schema + + mock_async_tool2 = Mock(spec=AsyncToolboxTool) + mock_async_tool2._AsyncToolboxTool__name = "mock-tool-1" + mock_async_tool2._AsyncToolboxTool__schema = tool_schema + + mock_aload_toolset.return_value = [mock_async_tool1, mock_async_tool2] + + auth_tokens = {"token1": lambda: "value1"} + auth_headers = {"header1": lambda: "value2"} + bound_params = {"param1": "value3"} + + tools = toolbox_client.load_toolset( + toolset_name="my_toolset", + auth_tokens=auth_tokens, + auth_headers=auth_headers, + bound_params=bound_params, + strict=False, + ) + + assert len(tools) == 2 + mock_toolbox_tool_init.assert_any_call( + mock_async_tool1, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + mock_toolbox_tool_init.assert_any_call( + mock_async_tool2, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + + mock_aload_toolset.assert_called_once_with( + "my_toolset", auth_tokens, auth_headers, bound_params, False + ) + + @pytest.mark.asyncio + @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) + @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_tool") + async def test_aload_tool_with_args( + self, mock_aload_tool, mock_toolbox_tool_init, toolbox_client, tool_schema + ): + mock_async_tool = Mock(spec=AsyncToolboxTool) + mock_async_tool._AsyncToolboxTool__name = "mock-tool" + mock_async_tool._AsyncToolboxTool__schema = tool_schema + mock_aload_tool.return_value = mock_async_tool + + auth_tokens = {"token1": lambda: "value1"} + auth_headers = {"header1": lambda: "value2"} + bound_params = {"param1": "value3"} + + tool = await toolbox_client.aload_tool( + "test_tool", auth_tokens, auth_headers, bound_params, False + ) + mock_toolbox_tool_init.assert_called_once_with( + mock_async_tool, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + + assert ( + tool_schema["description"] + == mock_async_tool._AsyncToolboxTool__schema["description"] + ) + mock_aload_tool.assert_called_once_with( + "test_tool", auth_tokens, auth_headers, bound_params, False + ) + + @pytest.mark.asyncio + @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) + @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_toolset") + async def test_aload_toolset_with_args( + self, mock_aload_toolset, mock_toolbox_tool_init, toolbox_client, tool_schema + ): + mock_async_tool1 = Mock(spec=AsyncToolboxTool) + mock_async_tool1._AsyncToolboxTool__name = "mock-tool-0" + mock_async_tool1._AsyncToolboxTool__schema = tool_schema + + mock_async_tool2 = Mock(spec=AsyncToolboxTool) + mock_async_tool2._AsyncToolboxTool__name = "mock-tool-1" + mock_async_tool2._AsyncToolboxTool__schema = tool_schema + mock_aload_toolset.return_value = [mock_async_tool1, mock_async_tool2] + + auth_tokens = {"token1": lambda: "value1"} + auth_headers = {"header1": lambda: "value2"} + bound_params = {"param1": "value3"} + + tools = await toolbox_client.aload_toolset( + "my_toolset", auth_tokens, auth_headers, bound_params, False + ) + assert len(tools) == 2 + mock_toolbox_tool_init.assert_any_call( + mock_async_tool1, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + mock_toolbox_tool_init.assert_any_call( + mock_async_tool2, + toolbox_client._ToolboxClient__loop, + toolbox_client._ToolboxClient__thread, + ) + + mock_aload_toolset.assert_called_once_with( + "my_toolset", auth_tokens, auth_headers, bound_params, False + ) diff --git a/packages/toolbox-llamaindex/tests/test_e2e.py b/packages/toolbox-llamaindex/tests/test_e2e.py new file mode 100644 index 00000000..55b2e522 --- /dev/null +++ b/packages/toolbox-llamaindex/tests/test_e2e.py @@ -0,0 +1,323 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end tests for the toolbox SDK interacting with the toolbox server. + +This file covers the following use cases: + +1. Loading a tool. +2. Loading a specific toolset. +3. Loading the default toolset (contains all tools). +4. Running a tool with + a. Missing params. + b. Wrong param type. +5. Running a tool with no required auth, with auth provided. +6. Running a tool with required auth: + a. No auth provided. + b. Wrong auth provided: The tool requires a different authentication + than the one provided. + c. Correct auth provided. +7. Running a tool with a parameter that requires auth: + a. No auth provided. + b. Correct auth provided. + c. Auth provided does not contain the required claim. +""" + +import pytest +import pytest_asyncio +from aiohttp import ClientResponseError +from pydantic import ValidationError + +from toolbox_llamaindex.client import ToolboxClient + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestE2EClientAsync: + @pytest.fixture(scope="function") + def toolbox(self): + """Provides a ToolboxClient instance for each test.""" + toolbox = ToolboxClient("http://localhost:5000") + return toolbox + + @pytest_asyncio.fixture(scope="function") + async def get_n_rows_tool(self, toolbox): + tool = await toolbox.aload_tool("get-n-rows") + assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + return tool + + #### Basic e2e tests + @pytest.mark.parametrize( + "toolset_name, expected_length, expected_tools", + [ + ("my-toolset", 1, ["get-row-by-id"]), + ("my-toolset-2", 2, ["get-n-rows", "get-row-by-id"]), + ], + ) + async def test_aload_toolset_specific( + self, toolbox, toolset_name, expected_length, expected_tools + ): + toolset = await toolbox.aload_toolset(toolset_name) + assert len(toolset) == expected_length + for tool in toolset: + name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + assert name in expected_tools + + async def test_aload_toolset_all(self, toolbox): + toolset = await toolbox.aload_toolset() + assert len(toolset) == 5 + tool_names = [ + "get-n-rows", + "get-row-by-id", + "get-row-by-id-auth", + "get-row-by-email-auth", + "get-row-by-content-auth", + ] + for tool in toolset: + name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + assert name in tool_names + + async def test_run_tool_async(self, get_n_rows_tool): + response = await get_n_rows_tool.acall(num_rows="2") + result = response.content + + assert "row1" in result + assert "row2" in result + assert "row3" not in result + + async def test_run_tool_sync(self, get_n_rows_tool): + response = get_n_rows_tool.call(num_rows="2") + result = response.content + + assert "row1" in result + assert "row2" in result + assert "row3" not in result + + async def test_run_tool_missing_params(self, get_n_rows_tool): + with pytest.raises(ValidationError, match="Field required"): + await get_n_rows_tool.acall() + + async def test_run_tool_wrong_param_type(self, get_n_rows_tool): + with pytest.raises(ValidationError, match="Input should be a valid string"): + await get_n_rows_tool.acall(num_rows=2) + + ##### Auth tests + @pytest.mark.asyncio + async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): + """Tests running a tool that doesn't require auth, with auth provided.""" + tool = await toolbox.aload_tool( + "get-row-by-id", auth_tokens={"my-test-auth": lambda: auth_token2} + ) + response = await tool.acall(id="2") + assert "row2" in response.content + + async def test_run_tool_no_auth(self, toolbox): + """Tests running a tool requiring auth without providing auth.""" + tool = await toolbox.aload_tool( + "get-row-by-id-auth", + ) + response = await tool.acall(id="2") + assert response.is_error == True + assert "401, message='Unauthorized'" in response.content + assert isinstance(response.raw_output, str) + + async def test_run_tool_wrong_auth(self, toolbox, auth_token2): + """Tests running a tool with incorrect auth.""" + tool = await toolbox.aload_tool( + "get-row-by-id-auth", + ) + auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token2) + response = await auth_tool.acall(id="2") + assert response.is_error == True + assert "401, message='Unauthorized'" in response.content + assert isinstance(response.raw_output, str) + + async def test_run_tool_auth(self, toolbox, auth_token1): + """Tests running a tool with correct auth.""" + tool = await toolbox.aload_tool( + "get-row-by-id-auth", + ) + auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token1) + response = await auth_tool.acall(id="2") + assert "row2" in response.content + + async def test_run_tool_param_auth_no_auth(self, toolbox): + """Tests runningP a tool with a param requiring auth, without auth.""" + tool = await toolbox.aload_tool("get-row-by-email-auth") + with pytest.raises( + PermissionError, + match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + ): + await tool.acall(email="") + + async def test_run_tool_param_auth(self, toolbox, auth_token1): + """Tests running a tool with a param requiring auth, with correct auth.""" + tool = await toolbox.aload_tool( + "get-row-by-email-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + ) + response = await tool.acall() + result = response.content + assert "row4" in result + assert "row5" in result + assert "row6" in result + + async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): + """Tests running a tool with a param requiring auth, with insufficient auth.""" + tool = await toolbox.aload_tool( + "get-row-by-content-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + ) + response = await tool.acall() + assert response.is_error == True + assert "400, message='Bad Request'" in response.content + assert isinstance(response.raw_output, str) + + +@pytest.mark.usefixtures("toolbox_server") +class TestE2EClientSync: + @pytest.fixture(scope="session") + def toolbox(self): + """Provides a ToolboxClient instance for each test.""" + toolbox = ToolboxClient("http://localhost:5000") + return toolbox + + @pytest.fixture(scope="function") + def get_n_rows_tool(self, toolbox): + tool = toolbox.load_tool("get-n-rows") + assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + return tool + + #### Basic e2e tests + @pytest.mark.parametrize( + "toolset_name, expected_length, expected_tools", + [ + ("my-toolset", 1, ["get-row-by-id"]), + ("my-toolset-2", 2, ["get-n-rows", "get-row-by-id"]), + ], + ) + def test_load_toolset_specific( + self, toolbox, toolset_name, expected_length, expected_tools + ): + toolset = toolbox.load_toolset(toolset_name) + assert len(toolset) == expected_length + for tool in toolset: + name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + assert name in expected_tools + + def test_aload_toolset_all(self, toolbox): + toolset = toolbox.load_toolset() + assert len(toolset) == 5 + tool_names = [ + "get-n-rows", + "get-row-by-id", + "get-row-by-id-auth", + "get-row-by-email-auth", + "get-row-by-content-auth", + ] + for tool in toolset: + name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + assert name in tool_names + + @pytest.mark.asyncio + async def test_run_tool_async(self, get_n_rows_tool): + response = await get_n_rows_tool.acall(num_rows="2") + result = response.content + + assert "row1" in result + assert "row2" in result + assert "row3" not in result + + def test_run_tool_sync(self, get_n_rows_tool): + response = get_n_rows_tool.call(num_rows="2") + result = response.content + + assert "row1" in result + assert "row2" in result + assert "row3" not in result + + def test_run_tool_missing_params(self, get_n_rows_tool): + with pytest.raises(ValidationError, match="Field required"): + get_n_rows_tool.call() + + def test_run_tool_wrong_param_type(self, get_n_rows_tool): + with pytest.raises(ValidationError, match="Input should be a valid string"): + get_n_rows_tool.call(num_rows=2) + + #### Auth tests + def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): + """Tests running a tool that doesn't require auth, with auth provided.""" + tool = toolbox.load_tool( + "get-row-by-id", auth_tokens={"my-test-auth": lambda: auth_token2} + ) + response = tool.call(id="2") + assert "row2" in response.content + + def test_run_tool_no_auth(self, toolbox): + """Tests running a tool requiring auth without providing auth.""" + tool = toolbox.load_tool( + "get-row-by-id-auth", + ) + response = tool.call(id="2") + assert response.is_error == True + assert "401, message='Unauthorized'" in response.content + assert isinstance(response.raw_output, str) + + def test_run_tool_wrong_auth(self, toolbox, auth_token2): + """Tests running a tool with incorrect auth.""" + tool = toolbox.load_tool( + "get-row-by-id-auth", + ) + auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token2) + response = auth_tool.call(id="2") + assert response.is_error == True + assert "401, message='Unauthorized'" in response.content + assert isinstance(response.raw_output, str) + + def test_run_tool_auth(self, toolbox, auth_token1): + """Tests running a tool with correct auth.""" + tool = toolbox.load_tool( + "get-row-by-id-auth", + ) + auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token1) + response = auth_tool.call(id="2") + assert "row2" in response.content + + def test_run_tool_param_auth_no_auth(self, toolbox): + """Tests running a tool with a param requiring auth, without auth.""" + tool = toolbox.load_tool("get-row-by-email-auth") + with pytest.raises( + PermissionError, + match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + ): + tool.call(email="") + + def test_run_tool_param_auth(self, toolbox, auth_token1): + """Tests running a tool with a param requiring auth, with correct auth.""" + tool = toolbox.load_tool( + "get-row-by-email-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + ) + response = tool.call() + result = response.content + assert "row4" in result + assert "row5" in result + assert "row6" in result + + def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): + """Tests running a tool with a param requiring auth, with insufficient auth.""" + tool = toolbox.load_tool( + "get-row-by-content-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + ) + response = tool.call() + assert response.is_error == True + assert "400, message='Bad Request'" in response.content + assert isinstance(response.raw_output, str) diff --git a/packages/toolbox-llamaindex/tests/test_tools.py b/packages/toolbox-llamaindex/tests/test_tools.py new file mode 100644 index 00000000..faeefd20 --- /dev/null +++ b/packages/toolbox-llamaindex/tests/test_tools.py @@ -0,0 +1,243 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import concurrent.futures +from unittest.mock import Mock, patch + +import pytest +from pydantic import BaseModel + +from toolbox_llamaindex.async_tools import AsyncToolboxTool +from toolbox_llamaindex.tools import ToolboxTool + + +class TestToolboxTool: + @pytest.fixture + def tool_schema(self): + return { + "description": "Test Tool Description", + "parameters": [ + {"name": "param1", "type": "string", "description": "Param 1"}, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest.fixture + def auth_tool_schema(self): + return { + "description": "Test Tool Description", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "Param 1", + "authSources": ["test-auth-source"], + }, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest.fixture(scope="function") + def mock_async_tool(self, tool_schema): + mock_async_tool = Mock(spec=AsyncToolboxTool) + mock_async_tool.name = "test_tool" + mock_async_tool.description = "test description" + mock_async_tool.args_schema = BaseModel + mock_async_tool._AsyncToolboxTool__name = "test_tool" + mock_async_tool._AsyncToolboxTool__schema = tool_schema + mock_async_tool._AsyncToolboxTool__url = "http://test_url" + mock_async_tool._AsyncToolboxTool__session = Mock() + mock_async_tool._AsyncToolboxTool__auth_tokens = {} + mock_async_tool._AsyncToolboxTool__bound_params = {} + return mock_async_tool + + @pytest.fixture(scope="function") + def mock_async_auth_tool(self, auth_tool_schema): + mock_async_tool = Mock(spec=AsyncToolboxTool) + mock_async_tool.name = "test_tool" + mock_async_tool.description = "test description" + mock_async_tool.args_schema = BaseModel + mock_async_tool._AsyncToolboxTool__name = "test_tool" + mock_async_tool._AsyncToolboxTool__schema = auth_tool_schema + mock_async_tool._AsyncToolboxTool__url = "http://test_url" + mock_async_tool._AsyncToolboxTool__session = Mock() + mock_async_tool._AsyncToolboxTool__auth_tokens = {} + mock_async_tool._AsyncToolboxTool__bound_params = {} + return mock_async_tool + + @pytest.fixture + def toolbox_tool(self, mock_async_tool): + return ToolboxTool( + async_tool=mock_async_tool, + loop=Mock(), + thread=Mock(), + ) + + @pytest.fixture + def auth_toolbox_tool(self, mock_async_auth_tool): + return ToolboxTool( + async_tool=mock_async_auth_tool, + loop=Mock(), + thread=Mock(), + ) + + def test_toolbox_tool_init(self, mock_async_tool, toolbox_tool): + assert toolbox_tool._ToolboxTool__async_tool == mock_async_tool + + @pytest.mark.parametrize( + "params, expected_bound_params", + [ + ({"param1": "bound-value"}, {"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), + ( + {"param1": "bound-value", "param2": 123}, + {"param1": "bound-value", "param2": 123}, + ), + ], + ) + def test_toolbox_tool_bind_params( + self, + params, + expected_bound_params, + toolbox_tool, + mock_async_tool, + ): + mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_params + mock_async_tool.bind_params.return_value = mock_async_tool + + tool = toolbox_tool.bind_params(params) + mock_async_tool.bind_params.assert_called_once_with(params, True) + + assert isinstance(tool, ToolboxTool) + + for key, value in expected_bound_params.items(): + async_tool_bound_param_val = ( + tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params[key] + ) + if callable(value): + assert value() == async_tool_bound_param_val() + else: + assert value == async_tool_bound_param_val + + def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): + expected_bound_param = {"param1": "bound-value"} + mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_param + mock_async_tool.bind_param.return_value = mock_async_tool + + tool = toolbox_tool.bind_param("param1", "bound-value") + mock_async_tool.bind_param.assert_called_once_with( + "param1", "bound-value", True + ) + + assert ( + tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params + == expected_bound_param + ) + assert isinstance(tool, ToolboxTool) + + @pytest.mark.parametrize( + "auth_tokens, expected_auth_tokens", + [ + ( + {"test-auth-source": lambda: "test-token"}, + {"test-auth-source": lambda: "test-token"}, + ), + ( + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + ), + ], + ) + def test_toolbox_tool_add_auth_tokens( + self, + auth_tokens, + expected_auth_tokens, + mock_async_auth_tool, + auth_toolbox_tool, + ): + auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens = ( + expected_auth_tokens + ) + auth_toolbox_tool._ToolboxTool__async_tool.add_auth_tokens.return_value = ( + mock_async_auth_tool + ) + + tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) + mock_async_auth_tool.add_auth_tokens.assert_called_once_with(auth_tokens, True) + for source, getter in expected_auth_tokens.items(): + assert ( + tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens[source]() + == getter() + ) + assert isinstance(tool, ToolboxTool) + + def test_toolbox_tool_add_auth_token(self, mock_async_auth_tool, auth_toolbox_tool): + get_id_token = lambda: "test-token" + expected_auth_tokens = {"test-auth-source": get_id_token} + auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens = ( + expected_auth_tokens + ) + auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token.return_value = ( + mock_async_auth_tool + ) + + tool = auth_toolbox_tool.add_auth_token("test-auth-source", get_id_token) + mock_async_auth_tool.add_auth_token.assert_called_once_with( + "test-auth-source", get_id_token, True + ) + + assert ( + tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens[ + "test-auth-source" + ]() + == "test-token" + ) + assert isinstance(tool, ToolboxTool) + + @pytest.mark.asyncio + async def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): + auth_toolbox_tool._ToolboxTool__async_tool.acall = Mock( + side_effect=PermissionError( + "Parameter(s) `param1` of tool test_tool require authentication" + ) + ) + with pytest.raises(PermissionError) as e: + await auth_toolbox_tool.acall() + assert "Parameter(s) `param1` of tool test_tool require authentication" in str( + e.value + ) + + @pytest.mark.asyncio + @patch("asyncio.run_coroutine_threadsafe") + async def test_toolbox_tool_run(self, mock_run_coroutine_threadsafe, toolbox_tool): + future = concurrent.futures.Future() + future.set_result({"result": "async success"}) + mock_run_coroutine_threadsafe.return_value = future + result = await toolbox_tool.acall(param1="value1", param2=3) + mock_run_coroutine_threadsafe.assert_called_once() + assert result == {"result": "async success"} + + @patch("asyncio.run_coroutine_threadsafe") + def test_toolbox_tool_sync_run(self, mock_run_coroutine_threadsafe, toolbox_tool): + future = concurrent.futures.Future() + future.set_result({"result": "sync success"}) + mock_run_coroutine_threadsafe.return_value = future + result = toolbox_tool.call(param1="value1", param2=3) + mock_run_coroutine_threadsafe.assert_called_once() + assert result == {"result": "sync success"} diff --git a/packages/toolbox-llamaindex/tests/test_utils.py b/packages/toolbox-llamaindex/tests/test_utils.py new file mode 100644 index 00000000..78d64e26 --- /dev/null +++ b/packages/toolbox-llamaindex/tests/test_utils.py @@ -0,0 +1,291 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import json +import re +import warnings +from typing import Union +from unittest.mock import AsyncMock, Mock, patch + +import aiohttp +import pytest +from pydantic import BaseModel + +from toolbox_llamaindex.utils import ( + ParameterSchema, + _get_auth_headers, + _invoke_tool, + _load_manifest, + _parse_type, + _schema_to_model, +) + +URL = "https://my-toolbox.com/test" +MOCK_MANIFEST = """ +{ + "serverVersion": "0.0.1", + "tools": { + "test_tool": { + "summary": "Test Tool", + "description": "This is a test tool.", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "Parameter 1" + }, + { + "name": "param2", + "type": "integer", + "description": "Parameter 2" + } + ] + } + } +} +""" + + +class TestUtils: + @pytest.fixture(scope="module") + def mock_manifest(self): + return aiohttp.ClientResponse( + method="GET", + url=aiohttp.client.URL(URL), + writer=None, + continue100=None, + timer=None, + request_info=None, + traces=None, + session=None, + loop=asyncio.get_event_loop(), + ) + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession.get") + async def test_load_manifest(self, mock_get, mock_manifest): + mock_manifest.raise_for_status = Mock() + mock_manifest.text = AsyncMock(return_value=MOCK_MANIFEST) + + mock_get.return_value = mock_manifest + session = aiohttp.ClientSession() + manifest = await _load_manifest(URL, session) + await session.close() + mock_get.assert_called_once_with(URL) + + assert manifest.serverVersion == "0.0.1" + assert len(manifest.tools) == 1 + + tool = manifest.tools["test_tool"] + assert tool.description == "This is a test tool." + assert tool.parameters == [ + ParameterSchema(name="param1", type="string", description="Parameter 1"), + ParameterSchema(name="param2", type="integer", description="Parameter 2"), + ] + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession.get") + async def test_load_manifest_invalid_json(self, mock_get, mock_manifest): + mock_manifest.raise_for_status = Mock() + mock_manifest.text = AsyncMock(return_value="{ invalid manifest") + mock_get.return_value = mock_manifest + + with pytest.raises(Exception) as e: + session = aiohttp.ClientSession() + await _load_manifest(URL, session) + + mock_get.assert_called_once_with(URL) + assert isinstance(e.value, json.JSONDecodeError) + assert ( + str(e.value) + == "Failed to parse JSON from https://my-toolbox.com/test: Expecting property name enclosed in double quotes: line 1 column 3 (char 2): line 1 column 3 (char 2)" + ) + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession.get") + async def test_load_manifest_invalid_manifest(self, mock_get, mock_manifest): + mock_manifest.raise_for_status = Mock() + mock_manifest.text = AsyncMock(return_value='{ "something": "invalid" }') + mock_get.return_value = mock_manifest + + with pytest.raises(Exception) as e: + session = aiohttp.ClientSession() + await _load_manifest(URL, session) + + mock_get.assert_called_once_with(URL) + assert isinstance(e.value, ValueError) + assert re.match( + r"Invalid JSON data from https://my-toolbox.com/test: 2 validation errors for ManifestSchema\nserverVersion\n Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing\ntools\n Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing", + str(e.value), + ) + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession.get") + async def test_load_manifest_api_error(self, mock_get, mock_manifest): + error = aiohttp.ClientError("Simulated HTTP Error") + mock_manifest.raise_for_status = Mock() + mock_manifest.text = AsyncMock(side_effect=error) + mock_get.return_value = mock_manifest + + with pytest.raises(aiohttp.ClientError) as exc_info: + session = aiohttp.ClientSession() + await _load_manifest(URL, session) + mock_get.assert_called_once_with(URL) + assert exc_info.value == error + + def test_schema_to_model(self): + schema = [ + ParameterSchema(name="param1", type="string", description="Parameter 1"), + ParameterSchema(name="param2", type="integer", description="Parameter 2"), + ] + model = _schema_to_model("TestModel", schema) + assert issubclass(model, BaseModel) + + assert model.model_fields["param1"].annotation == str + assert model.model_fields["param1"].description == "Parameter 1" + assert model.model_fields["param2"].annotation == int + assert model.model_fields["param2"].description == "Parameter 2" + + def test_schema_to_model_empty(self): + model = _schema_to_model("TestModel", []) + assert issubclass(model, BaseModel) + assert len(model.model_fields) == 0 + + @pytest.mark.parametrize( + "parameter_schema, expected_type", + [ + (ParameterSchema(name="foo", description="bar", type="string"), str), + (ParameterSchema(name="foo", description="bar", type="integer"), int), + (ParameterSchema(name="foo", description="bar", type="float"), float), + (ParameterSchema(name="foo", description="bar", type="boolean"), bool), + ( + ParameterSchema( + name="foo", + description="bar", + type="array", + items=ParameterSchema( + name="foo", description="bar", type="integer" + ), + ), + list[int], + ), + ], + ) + def test_parse_type(self, parameter_schema, expected_type): + assert _parse_type(parameter_schema) == expected_type + + @pytest.mark.parametrize( + "fail_parameter_schema", + [ + (ParameterSchema(name="foo", description="bar", type="invalid")), + ( + ParameterSchema( + name="foo", + description="bar", + type="array", + items=ParameterSchema( + name="foo", description="bar", type="invalid" + ), + ) + ), + ], + ) + def test_parse_type_invalid(self, fail_parameter_schema): + with pytest.raises(ValueError): + _parse_type(fail_parameter_schema) + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession.post") + async def test_invoke_tool(self, mock_post): + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json = AsyncMock(return_value={"key": "value"}) + mock_post.return_value.__aenter__.return_value = mock_response + + result = await _invoke_tool( + "http://localhost:8000", + aiohttp.ClientSession(), + "tool_name", + {"input": "data"}, + {}, + ) + + mock_post.assert_called_once_with( + "http://localhost:8000/api/tool/tool_name/invoke", + json={"input": "data"}, + headers={}, + ) + assert result == {"key": "value"} + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession.post") + async def test_invoke_tool_unsecure_with_auth(self, mock_post): + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json = AsyncMock(return_value={"key": "value"}) + mock_post.return_value.__aenter__.return_value = mock_response + + with pytest.warns( + UserWarning, + match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", + ): + result = await _invoke_tool( + "http://localhost:8000", + aiohttp.ClientSession(), + "tool_name", + {"input": "data"}, + {"my_test_auth": lambda: "fake_id_token"}, + ) + + mock_post.assert_called_once_with( + "http://localhost:8000/api/tool/tool_name/invoke", + json={"input": "data"}, + headers={"my_test_auth_token": "fake_id_token"}, + ) + assert result == {"key": "value"} + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession.post") + async def test_invoke_tool_secure_with_auth(self, mock_post): + session = aiohttp.ClientSession() + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json = AsyncMock(return_value={"key": "value"}) + mock_post.return_value.__aenter__.return_value = mock_response + + with warnings.catch_warnings(): + warnings.simplefilter("error") + result = await _invoke_tool( + "https://localhost:8000", + session, + "tool_name", + {"input": "data"}, + {"my_test_auth": lambda: "fake_id_token"}, + ) + + mock_post.assert_called_once_with( + "https://localhost:8000/api/tool/tool_name/invoke", + json={"input": "data"}, + headers={"my_test_auth_token": "fake_id_token"}, + ) + assert result == {"key": "value"} + + def test_get_auth_headers_deprecation_warning(self): + """Test _get_auth_headers deprecation warning.""" + with pytest.warns( + DeprecationWarning, + match=r"Call to deprecated function \(or staticmethod\) _get_auth_headers\. \(Please use `_get_auth_tokens` instead\.\)$", + ): + _get_auth_headers({"auth_source1": lambda: "test_token"}) diff --git a/release-please-config.json b/release-please-config.json index 34af52df..f5ad840d 100644 --- a/release-please-config.json +++ b/release-please-config.json @@ -21,6 +21,12 @@ "extra-files": [ "src/toolbox_langchain/version.py" ] + }, + "packages/toolbox-llamaindex": { + "component": "toolbox-llamaindex", + "extra-files": [ + "src/toolbox_llamaindex/version.py" + ] } }, "plugins": [ @@ -28,7 +34,7 @@ "type": "linked-versions", "groupName": "toolbox-python-sdks", "components": [ - "toolbox-core", "toolbox-langchain" + "toolbox-core", "toolbox-langchain", "toolbox-llamaindex" ] } ]