Skip to content

Commit 827129f

Browse files
twishabansalkurtisvgYuan325anubhav756
authored
tests(toolbox-code): add e2e tests (#122)
* feat: add authenticated parameters support * chore: add asyncio dep * chore: run itest * chore: add type hint * fix: call tool instead of client * chore: correct arg name * feat: add support for bound parameters * chore: add tests for bound parameters * docs: update syntax error on readme (#121) * ci: added release please config (#112) * ci: add release please config * chore: add initial version * chore: specify initial version as string * chore: Update .release-please-manifest.json * chore: add empty json * chore: small change * chore: try fixing config * chore: try fixing config again * chore: remove release-as * chore: add changelog sections * chore: better release notes * chore: better release notes * chore: change toolbox-langchain version * chore: separate PRs for packages * chore: change PR style * added basic e2e tests * change license year * add test deps * fix tests * fix tests * fix tests * add new test case * fix docstring * added todo * cleanup * add bind param test case * make bind params dynamic * try fix test errors * lint * remove redundant test * test fix * fix docstring * feat: add authenticated parameters support * chore: add asyncio dep * chore: run itest * chore: add type hint * fix: call tool instead of client * chore: correct arg name * chore: address feedback * chore: address more feedback * feat: add support for bound parameters * chore: add tests for bound parameters * chore: address feedback * revert package file changes * fix error message * revert package files * lint * fix error message * Update packages/toolbox-core/tests/test_e2e.py Co-authored-by: Anubhav Dhawan <anubhav756@gmail.com> * add new test case * change docstring to reflect new test cases * clean up docstring * lint * Move tests to different classes * add timeout --------- Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Co-authored-by: Yuan <45984206+Yuan325@users.noreply.github.com> Co-authored-by: Anubhav Dhawan <anubhav756@gmail.com>
1 parent 6454676 commit 827129f

File tree

3 files changed

+354
-0
lines changed

3 files changed

+354
-0
lines changed

packages/toolbox-core/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ test = [
4646
"pytest==8.3.5",
4747
"pytest-aioresponses==0.3.0",
4848
"pytest-asyncio==0.25.3",
49+
"google-cloud-secret-manager==2.23.2",
50+
"google-cloud-storage==3.1.0",
4951
]
5052
[build-system]
5153
requires = ["setuptools"]
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Contains pytest fixtures that are accessible from all
16+
files present in the same directory."""
17+
18+
from __future__ import annotations
19+
20+
import os
21+
import platform
22+
import subprocess
23+
import tempfile
24+
import time
25+
from typing import Generator
26+
27+
import google
28+
import pytest_asyncio
29+
from google.auth import compute_engine
30+
from google.cloud import secretmanager, storage
31+
32+
33+
#### Define Utility Functions
34+
def get_env_var(key: str) -> str:
    """Return the value of the environment variable ``key``.

    Raises:
        ValueError: If the variable is not set in the environment.
    """
    # Guard clause: fail loudly when the variable is missing.
    if key not in os.environ:
        raise ValueError(f"Must set env var {key}")
    return os.environ[key]
40+
41+
42+
def access_secret_version(
    project_id: str, secret_id: str, version_id: str = "latest"
) -> str:
    """Fetch a secret version from Secret Manager and return it as text.

    The resource name is built from ``project_id``, ``secret_id`` and
    ``version_id``; the returned payload bytes are decoded as UTF-8.
    """
    secret_path = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}"
    sm_client = secretmanager.SecretManagerServiceClient()
    secret = sm_client.access_secret_version(request={"name": secret_path})
    raw_payload = secret.payload.data
    return raw_payload.decode("UTF-8")
50+
51+
52+
def create_tmpfile(content: str) -> str:
    """Write ``content`` to a new temporary file and return the file's path.

    The file is created with ``delete=False``, so it is not removed when it is
    closed; the caller is responsible for deleting it.
    """
    # Pin the encoding so that writing does not depend on the platform's
    # locale (non-ASCII content would otherwise fail on non-UTF-8 locales).
    with tempfile.NamedTemporaryFile(
        delete=False, mode="w", encoding="utf-8"
    ) as tmpfile:
        tmpfile.write(content)
    return tmpfile.name
57+
58+
59+
def download_blob(
    bucket_name: str, source_blob_name: str, destination_file_name: str
) -> None:
    """Download one object from a GCS bucket to a local file.

    Fetches ``source_blob_name`` from ``bucket_name``, writes it to
    ``destination_file_name`` and prints a confirmation line.
    """
    gcs_blob = storage.Client().bucket(bucket_name).blob(source_blob_name)
    gcs_blob.download_to_filename(destination_file_name)

    print(f"Blob {source_blob_name} downloaded to {destination_file_name}.")
70+
71+
72+
def get_toolbox_binary_url(toolbox_version: str) -> str:
    """Build the bucket-relative path of the toolbox binary for this host.

    The result has the form ``v<version>/<os>/<arch>/toolbox``, where ``<os>``
    is the lower-cased ``platform.system()`` value and ``<arch>`` is ``arm64``
    only on arm64 macOS hosts and ``amd64`` everywhere else.
    """
    os_system = platform.system().lower()
    # Short-circuit: the machine type is only consulted on macOS.
    if os_system == "darwin" and platform.machine() == "arm64":
        arch = "arm64"
    else:
        arch = "amd64"
    return "/".join((f"v{toolbox_version}", os_system, arch, "toolbox"))
79+
80+
81+
def get_auth_token(client_id: str) -> str:
    """Retrieve an ID token whose target audience is ``client_id``.

    Credentials are built with ``compute_engine.IDTokenCredentials`` using the
    metadata identity endpoint, and refreshed if they are not already valid.

    Returns:
        The token string held by the refreshed credentials.
    """
    # Import the submodule explicitly: a bare ``import google`` does not
    # guarantee that ``google.auth.transport.requests`` has been loaded, so
    # attribute access on it would only work if another import pulled it in.
    from google.auth.transport.requests import Request

    request = Request()
    credentials = compute_engine.IDTokenCredentials(
        request=request,
        target_audience=client_id,
        use_metadata_identity_endpoint=True,
    )
    if not credentials.valid:
        credentials.refresh(request)
    return credentials.token
92+
93+
94+
#### Define Fixtures
95+
@pytest_asyncio.fixture(scope="session")
def project_id() -> str:
    """Session fixture: the value of the ``GOOGLE_CLOUD_PROJECT`` env var."""
    env_key = "GOOGLE_CLOUD_PROJECT"
    return get_env_var(env_key)
98+
99+
100+
@pytest_asyncio.fixture(scope="session")
def toolbox_version() -> str:
    """Session fixture: the value of the ``TOOLBOX_VERSION`` env var."""
    env_key = "TOOLBOX_VERSION"
    return get_env_var(env_key)
103+
104+
105+
@pytest_asyncio.fixture(scope="session")
def tools_file_path(project_id: str) -> Generator[str]:
    """Session fixture yielding the path of a temp file with the tools manifest.

    The manifest text is read from the ``sdk_testing_tools`` secret and written
    to a temporary file, which is removed again at fixture teardown.
    """
    manifest_path = create_tmpfile(
        access_secret_version(project_id=project_id, secret_id="sdk_testing_tools")
    )
    yield manifest_path
    # Teardown: delete the temporary manifest file.
    os.remove(manifest_path)
114+
115+
116+
@pytest_asyncio.fixture(scope="session")
def auth_token1(project_id: str) -> str:
    """Session fixture: an auth token for the ``sdk_testing_client1`` client ID.

    The client ID is read from Secret Manager and used as the token audience.
    """
    return get_auth_token(
        access_secret_version(project_id=project_id, secret_id="sdk_testing_client1")
    )
122+
123+
124+
@pytest_asyncio.fixture(scope="session")
def auth_token2(project_id: str) -> str:
    """Session fixture: an auth token for the ``sdk_testing_client2`` client ID.

    The client ID is read from Secret Manager and used as the token audience.
    """
    return get_auth_token(
        access_secret_version(project_id=project_id, secret_id="sdk_testing_client2")
    )
130+
131+
132+
@pytest_asyncio.fixture(scope="session")
def toolbox_server(toolbox_version: str, tools_file_path: str) -> Generator[None]:
    """Download the toolbox binary and run it as a subprocess for the session.

    The binary for ``toolbox_version`` is downloaded from the ``genai-toolbox``
    bucket into ``./toolbox``, made executable and launched with
    ``--tools_file tools_file_path``. The fixture yields once the process has
    survived a short startup window and stops the process at teardown.

    Raises:
        RuntimeError: If the server process exits during the startup window.
    """
    print("Downloading toolbox binary from gcs bucket...")
    source_blob_name = get_toolbox_binary_url(toolbox_version)
    download_blob("genai-toolbox", source_blob_name, "toolbox")
    print("Toolbox binary downloaded successfully.")

    print("Opening toolbox server process...")
    # Make toolbox executable
    os.chmod("toolbox", 0o700)
    # Run toolbox binary
    process = subprocess.Popen(["./toolbox", "--tools_file", tools_file_path])

    try:
        # Give the process a moment to either settle or crash. A process that
        # has exited cannot come back, so a single liveness check is enough;
        # re-polling an already dead process would only delay the failure.
        time.sleep(2)
        print("Checking if toolbox is successfully started...")
        exit_code = process.poll()
        if exit_code is not None:
            raise RuntimeError(
                f"Toolbox server failed to start (exit code {exit_code})."
            )
        print("Toolbox server started successfully.")
        yield
    finally:
        # Clean up toolbox server. terminate() is a no-op if it already exited.
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # The server ignored the termination request; force it down so the
            # test session does not leave a stray process behind.
            process.kill()
            process.wait()
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytest
15+
import pytest_asyncio
16+
17+
from toolbox_core.client import ToolboxClient
18+
from toolbox_core.tool import ToolboxTool
19+
20+
21+
# --- Shared Fixtures Defined at Module Level ---
22+
@pytest_asyncio.fixture(scope="function")
async def toolbox():
    """Yield a ToolboxClient for ``http://localhost:5000``.

    The fixture is function-scoped; the client is closed once the requesting
    test has finished.
    """
    client = ToolboxClient("http://localhost:5000")
    try:
        yield client
    finally:
        # Always release the client, even when the test failed.
        await client.close()
30+
31+
32+
@pytest_asyncio.fixture(scope="function")
async def get_n_rows_tool(toolbox: ToolboxClient) -> ToolboxTool:
    """Load the 'get-n-rows' tool through the ``toolbox`` client fixture.

    Asserts that the loaded tool reports the expected name before returning it.
    """
    loaded_tool = await toolbox.load_tool("get-n-rows")
    loaded_name = loaded_tool.__name__
    assert loaded_name == "get-n-rows"
    return loaded_tool
38+
39+
40+
@pytest.mark.asyncio
@pytest.mark.usefixtures("toolbox_server")
class TestBasicE2E:
    """End-to-end tests for loading toolsets and invoking a tool."""

    @pytest.mark.parametrize(
        "toolset_name, expected_length, expected_tools",
        [
            ("my-toolset", 1, ["get-row-by-id"]),
            ("my-toolset-2", 2, ["get-n-rows", "get-row-by-id"]),
        ],
    )
    async def test_load_toolset_specific(
        self,
        toolbox: ToolboxClient,
        toolset_name: str,
        expected_length: int,
        expected_tools: list[str],
    ):
        """Loading a named toolset returns exactly the expected tools."""
        loaded_tools = await toolbox.load_toolset(toolset_name)
        assert len(loaded_tools) == expected_length
        loaded_names = {loaded.__name__ for loaded in loaded_tools}
        assert loaded_names == set(expected_tools)

    async def test_run_tool(self, get_n_rows_tool: ToolboxTool):
        """Invoking the tool with a valid argument returns the requested rows."""
        result = await get_n_rows_tool(num_rows="2")

        assert isinstance(result, str)
        for expected_row in ("row1", "row2"):
            assert expected_row in result
        assert "row3" not in result

    async def test_run_tool_missing_params(self, get_n_rows_tool: ToolboxTool):
        """Invoking the tool without its required parameter raises TypeError."""
        expected_error = "missing a required argument: 'num_rows'"
        with pytest.raises(TypeError, match=expected_error):
            await get_n_rows_tool()

    async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool):
        """Invoking the tool with a wrongly typed parameter raises an error."""
        expected_error = 'provided parameters were invalid: unable to parse value for "num_rows": .* not type "string"'
        with pytest.raises(Exception, match=expected_error):
            await get_n_rows_tool(num_rows=2)
84+
85+
86+
@pytest.mark.asyncio
@pytest.mark.usefixtures("toolbox_server")
class TestBindParams:
    """End-to-end tests for binding parameter values to an existing tool."""

    async def test_bind_params(
        self, toolbox: ToolboxClient, get_n_rows_tool: ToolboxTool
    ):
        """A tool with a bound static value can be invoked without arguments."""
        bound_tool = get_n_rows_tool.bind_parameters({"num_rows": "3"})
        result = await bound_tool()
        assert isinstance(result, str)
        for expected_row in ("row1", "row2", "row3"):
            assert expected_row in result
        assert "row4" not in result

    async def test_bind_params_callable(
        self, toolbox: ToolboxClient, get_n_rows_tool: ToolboxTool
    ):
        """A tool with a bound callable value can be invoked without arguments."""
        bound_tool = get_n_rows_tool.bind_parameters({"num_rows": lambda: "3"})
        result = await bound_tool()
        assert isinstance(result, str)
        for expected_row in ("row1", "row2", "row3"):
            assert expected_row in result
        assert "row4" not in result
112+
113+
114+
@pytest.mark.asyncio
@pytest.mark.usefixtures("toolbox_server")
class TestAuth:
    """End-to-end tests for tool-level and parameter-level authentication."""

    async def test_run_tool_unauth_with_auth(
        self, toolbox: ToolboxClient, auth_token2: str
    ):
        """A tool that needs no auth still works when a token getter is given."""
        token_getters = {"my-test-auth": lambda: auth_token2}
        tool = await toolbox.load_tool(
            "get-row-by-id", auth_token_getters=token_getters
        )
        result = await tool(id="2")
        assert "row2" in result

    async def test_run_tool_no_auth(self, toolbox: ToolboxClient):
        """A tool requiring auth rejects an invocation made without auth."""
        tool = await toolbox.load_tool("get-row-by-id-auth")
        expected_error = "tool invocation not authorized. Please make sure your specify correct auth headers"
        with pytest.raises(Exception, match=expected_error):
            await tool(id="2")

    async def test_run_tool_wrong_auth(self, toolbox: ToolboxClient, auth_token2: str):
        """A tool requiring auth rejects a token other than the one it needs."""
        tool = await toolbox.load_tool("get-row-by-id-auth")
        token_getters = {"my-test-auth": lambda: auth_token2}
        auth_tool = tool.add_auth_token_getters(token_getters)
        with pytest.raises(Exception, match="tool invocation not authorized"):
            await auth_tool(id="2")

    async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str):
        """A tool requiring auth runs when the correct token is supplied."""
        tool = await toolbox.load_tool("get-row-by-id-auth")
        token_getters = {"my-test-auth": lambda: auth_token1}
        auth_tool = tool.add_auth_token_getters(token_getters)
        result = await auth_tool(id="2")
        assert "row2" in result

    async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient):
        """A tool with an auth-backed parameter fails when no auth is given."""
        tool = await toolbox.load_tool("get-row-by-email-auth")
        expected_error = "One or more of the following authn services are required to invoke this tool: my-test-auth"
        with pytest.raises(Exception, match=expected_error):
            await tool()

    async def test_run_tool_param_auth(self, toolbox: ToolboxClient, auth_token1: str):
        """A tool with an auth-backed parameter runs with the correct token."""
        token_getters = {"my-test-auth": lambda: auth_token1}
        tool = await toolbox.load_tool(
            "get-row-by-email-auth",
            auth_token_getters=token_getters,
        )
        result = await tool()
        for expected_row in ("row4", "row5", "row6"):
            assert expected_row in result

    async def test_run_tool_param_auth_no_field(
        self, toolbox: ToolboxClient, auth_token1: str
    ):
        """A tool with an auth-backed parameter fails when the claim is absent."""
        token_getters = {"my-test-auth": lambda: auth_token1}
        tool = await toolbox.load_tool(
            "get-row-by-content-auth",
            auth_token_getters=token_getters,
        )
        with pytest.raises(Exception, match="no field named row_data in claims"):
            await tool()

0 commit comments

Comments
 (0)