diff --git a/packages/toolbox-core/pyproject.toml b/packages/toolbox-core/pyproject.toml index ee8a5f73..e9ee66f1 100644 --- a/packages/toolbox-core/pyproject.toml +++ b/packages/toolbox-core/pyproject.toml @@ -46,6 +46,8 @@ test = [ "pytest==8.3.5", "pytest-aioresponses==0.3.0", "pytest-asyncio==0.25.3", + "google-cloud-secret-manager==2.23.2", + "google-cloud-storage==3.1.0", ] [build-system] requires = ["setuptools"] diff --git a/packages/toolbox-core/tests/conftest.py b/packages/toolbox-core/tests/conftest.py new file mode 100644 index 00000000..e579f843 --- /dev/null +++ b/packages/toolbox-core/tests/conftest.py @@ -0,0 +1,166 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains pytest fixtures that are accessible from all +files present in the same directory.""" + +from __future__ import annotations + +import os +import platform +import subprocess +import tempfile +import time +from typing import Generator + +import google +import pytest_asyncio +from google.auth import compute_engine +from google.cloud import secretmanager, storage + + +#### Define Utility Functions +def get_env_var(key: str) -> str: + """Gets environment variables.""" + value = os.environ.get(key) + if value is None: + raise ValueError(f"Must set env var {key}") + return value + + +def access_secret_version( + project_id: str, secret_id: str, version_id: str = "latest" +) -> str: + """Accesses the payload of a given secret version from Secret Manager.""" + client = secretmanager.SecretManagerServiceClient() + name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}" + response = client.access_secret_version(request={"name": name}) + return response.payload.data.decode("UTF-8") + + +def create_tmpfile(content: str) -> str: + """Creates a temporary file with the given content.""" + with tempfile.NamedTemporaryFile(delete=False, mode="w") as tmpfile: + tmpfile.write(content) + return tmpfile.name + + +def download_blob( + bucket_name: str, source_blob_name: str, destination_file_name: str +) -> None: + """Downloads a blob from a GCS bucket.""" + storage_client = storage.Client() + + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(source_blob_name) + blob.download_to_filename(destination_file_name) + + print(f"Blob {source_blob_name} downloaded to {destination_file_name}.") + + +def get_toolbox_binary_url(toolbox_version: str) -> str: + """Constructs the GCS path to the toolbox binary.""" + os_system = platform.system().lower() + arch = ( + "arm64" if os_system == "darwin" and platform.machine() == "arm64" else "amd64" + ) + return f"v{toolbox_version}/{os_system}/{arch}/toolbox" + + +def get_auth_token(client_id: str) -> str: + """Retrieves an authentication token""" + request = google.auth.transport.requests.Request() + credentials = compute_engine.IDTokenCredentials( + request=request, + target_audience=client_id, + use_metadata_identity_endpoint=True, + ) + if not credentials.valid: + credentials.refresh(request) + return credentials.token + + +#### Define Fixtures +@pytest_asyncio.fixture(scope="session") +def project_id() -> str: + return get_env_var("GOOGLE_CLOUD_PROJECT") + + +@pytest_asyncio.fixture(scope="session") +def toolbox_version() -> str: + return get_env_var("TOOLBOX_VERSION") + + +@pytest_asyncio.fixture(scope="session") +def tools_file_path(project_id: str) -> Generator[str]: + """Provides a temporary file path containing the tools manifest.""" + tools_manifest = access_secret_version( + project_id=project_id, secret_id="sdk_testing_tools" + ) + tools_file_path = create_tmpfile(tools_manifest) + yield tools_file_path + os.remove(tools_file_path) + + +@pytest_asyncio.fixture(scope="session") +def auth_token1(project_id: str) -> str: + client_id = access_secret_version( + project_id=project_id, secret_id="sdk_testing_client1" + ) + return get_auth_token(client_id) + + +@pytest_asyncio.fixture(scope="session") +def auth_token2(project_id: str) -> str: + client_id = access_secret_version( + project_id=project_id, secret_id="sdk_testing_client2" + ) + return get_auth_token(client_id) + + +@pytest_asyncio.fixture(scope="session") +def toolbox_server(toolbox_version: str, tools_file_path: str) -> Generator[None]: + """Starts the toolbox server as a subprocess.""" + print("Downloading toolbox binary from gcs bucket...") + source_blob_name = get_toolbox_binary_url(toolbox_version) + download_blob("genai-toolbox", source_blob_name, "toolbox") + print("Toolbox binary downloaded successfully.") + try: + print("Opening toolbox server process...") + # Make toolbox executable + os.chmod("toolbox", 0o700) + # Run toolbox binary + toolbox_server = subprocess.Popen( + ["./toolbox", "--tools_file", tools_file_path] + ) + + # Wait for server to start + # Retry logic with a timeout + for _ in range(5): # retries + time.sleep(2) + print("Checking if toolbox is successfully started...") + if toolbox_server.poll() is None: + print("Toolbox server started successfully.") + break + else: + raise RuntimeError("Toolbox server failed to start after 5 retries.") + except subprocess.CalledProcessError as e: + print(e.stderr.decode("utf-8")) + print(e.stdout.decode("utf-8")) + raise RuntimeError(f"{e}\n\n{e.stderr.decode('utf-8')}") from e + yield + + # Clean up toolbox server + toolbox_server.terminate() + toolbox_server.wait(timeout=5) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py new file mode 100644 index 00000000..43f4d0f8 --- /dev/null +++ b/packages/toolbox-core/tests/test_e2e.py @@ -0,0 +1,186 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import pytest_asyncio + +from toolbox_core.client import ToolboxClient +from toolbox_core.tool import ToolboxTool + + +# --- Shared Fixtures Defined at Module Level --- +@pytest_asyncio.fixture(scope="function") +async def toolbox(): + """Creates a ToolboxClient instance shared by all tests in this module.""" + toolbox = ToolboxClient("http://localhost:5000") + try: + yield toolbox + finally: + await toolbox.close() + + +@pytest_asyncio.fixture(scope="function") +async def get_n_rows_tool(toolbox: ToolboxClient) -> ToolboxTool: + """Load the 'get-n-rows' tool using the shared toolbox client.""" + tool = await toolbox.load_tool("get-n-rows") + assert tool.__name__ == "get-n-rows" + return tool + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestBasicE2E: + @pytest.mark.parametrize( + "toolset_name, expected_length, expected_tools", + [ + ("my-toolset", 1, ["get-row-by-id"]), + ("my-toolset-2", 2, ["get-n-rows", "get-row-by-id"]), + ], + ) + async def test_load_toolset_specific( + self, + toolbox: ToolboxClient, + toolset_name: str, + expected_length: int, + expected_tools: list[str], + ): + """Load a specific toolset""" + toolset = await toolbox.load_toolset(toolset_name) + assert len(toolset) == expected_length + tool_names = {tool.__name__ for tool in toolset} + assert tool_names == set(expected_tools) + + async def test_run_tool(self, get_n_rows_tool: ToolboxTool): + """Invoke a tool.""" + response = await get_n_rows_tool(num_rows="2") + + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" not in response + + async def test_run_tool_missing_params(self, get_n_rows_tool: ToolboxTool): + """Invoke a tool with missing params.""" + with pytest.raises(TypeError, match="missing a required argument: 'num_rows'"): + await get_n_rows_tool() + + async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): + """Invoke a tool with wrong param type.""" + with pytest.raises( + Exception, + match='provided parameters were invalid: unable to parse value for "num_rows": .* not type "string"', + ): + await get_n_rows_tool(num_rows=2) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestBindParams: + async def test_bind_params( + self, toolbox: ToolboxClient, get_n_rows_tool: ToolboxTool + ): + """Bind a param to an existing tool.""" + new_tool = get_n_rows_tool.bind_parameters({"num_rows": "3"}) + response = await new_tool() + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" in response + assert "row4" not in response + + async def test_bind_params_callable( + self, toolbox: ToolboxClient, get_n_rows_tool: ToolboxTool + ): + """Bind a callable param to an existing tool.""" + new_tool = get_n_rows_tool.bind_parameters({"num_rows": lambda: "3"}) + response = await new_tool() + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" in response + assert "row4" not in response + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestAuth: + async def test_run_tool_unauth_with_auth( + self, toolbox: ToolboxClient, auth_token2: str + ): + """Tests running a tool that doesn't require auth, with auth provided.""" + tool = await toolbox.load_tool( + "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} + ) + response = await tool(id="2") + assert "row2" in response + + async def test_run_tool_no_auth(self, toolbox: ToolboxClient): + """Tests running a tool requiring auth without providing auth.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + with pytest.raises( + Exception, + match="tool invocation not authorized. Please make sure your specify correct auth headers", + ): + await tool(id="2") + + async def test_run_tool_wrong_auth(self, toolbox: ToolboxClient, auth_token2: str): + """Tests running a tool with incorrect auth. The tool + requires a different authentication than the one provided.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token2}) + with pytest.raises( + Exception, + match="tool invocation not authorized", + ): + await auth_tool(id="2") + + async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with correct auth.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token1}) + response = await auth_tool(id="2") + assert "row2" in response + + async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient): + """Tests running a tool with a param requiring auth, without auth.""" + tool = await toolbox.load_tool("get-row-by-email-auth") + with pytest.raises( + Exception, + match="One or more of the following authn services are required to invoke this tool: my-test-auth", + ): + await tool() + + async def test_run_tool_param_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with a param requiring auth, with correct auth.""" + tool = await toolbox.load_tool( + "get-row-by-email-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, + ) + response = await tool() + assert "row4" in response + assert "row5" in response + assert "row6" in response + + async def test_run_tool_param_auth_no_field( + self, toolbox: ToolboxClient, auth_token1: str + ): + """Tests running a tool with a param requiring auth, with insufficient auth.""" + tool = await toolbox.load_tool( + "get-row-by-content-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, + ) + with pytest.raises( + Exception, + match="no field named row_data in claims", + ): + await tool()