diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml new file mode 100644 index 00000000..be02d2a1 --- /dev/null +++ b/.ci/integration.cloudbuild.yaml @@ -0,0 +1,46 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +steps: + - id: Install library requirements + name: 'python:${_VERSION}' + args: + - install + - '-r' + - './requirements.txt' + - '--user' + entrypoint: pip + - id: Install test requirements + name: 'python:${_VERSION}' + args: + - install + - '.[test]' + - '--user' + entrypoint: pip + - id: Run integration tests + name: 'python:${_VERSION}' + env: + - TOOLBOX_URL=$_TOOLBOX_URL + - TOOLBOX_VERSION=$_TOOLBOX_VERSION + - GOOGLE_CLOUD_PROJECT=$PROJECT_ID + args: + - '-c' + - >- + python -m pytest ./tests/ + entrypoint: /bin/bash +options: + logging: CLOUD_LOGGING_ONLY +substitutions: + _VERSION: '3.13' + _TOOLBOX_VERSION: '0.0.5' diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..2170d838 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,43 @@ +--- +name: Bug report +about: Create a report to help us improve + +--- + +Thanks for stopping by to let us know something could be better! + +**PLEASE READ**: If you have a support contract with Google, please create an issue in the [support console](https://cloud.google.com/support/) instead of filing on GitHub. This will ensure a timely response. 
+ +Please run down the following list and make sure you've tried the usual "quick fixes": + + - Search the issues already opened: https://github.com/googleapis/genai-toolbox-langchain-python/issues + - Search StackOverflow: https://stackoverflow.com/questions/tagged/google-cloud-platform+python + +If you are still having issues, please be sure to include as much information as possible: + +#### Environment details + + - OS type and version: + - Python version: `python --version` + - pip version: `pip --version` + - `toolbox-langchain-sdk` version: `pip show toolbox-langchain-sdk` + +#### Steps to reproduce + + 1. ? + 2. ? + +#### Code example + +```python +# example +``` + +#### Stack trace +``` +# example +``` + +Making sure to follow these steps will guarantee the quickest resolution possible. + +Thanks! \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..b8d7217c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,18 @@ +--- +name: Feature request +about: Suggest an idea for this library + +--- + +Thanks for stopping by to let us know something could be better! + +**PLEASE READ**: If you have a support contract with Google, please create an issue in the [support console](https://cloud.google.com/support/) instead of filing on GitHub. This will ensure a timely response. + + **Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + **Describe the solution you'd like** +A clear and concise description of what you want to happen. + **Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + **Additional context** +Add any other context or screenshots about the feature request here. 
\ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/support_request.md b/.github/ISSUE_TEMPLATE/support_request.md new file mode 100644 index 00000000..9e5b2532 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/support_request.md @@ -0,0 +1,7 @@ +--- +name: Support request +about: If you have a support contract with Google, please create an issue in the Google Cloud Support console. + +--- + +**PLEASE READ**: If you have a support contract with Google, please create an issue in the [support console](https://cloud.google.com/support/) instead of filing on GitHub. This will ensure a timely response. \ No newline at end of file diff --git a/.github/auto-label.yaml b/.github/auto-label.yaml new file mode 100644 index 00000000..57437684 --- /dev/null +++ b/.github/auto-label.yaml @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +enabled: false diff --git a/.github/blunderbuss.yml b/.github/blunderbuss.yml new file mode 100644 index 00000000..68d529c3 --- /dev/null +++ b/.github/blunderbuss.yml @@ -0,0 +1,4 @@ +assign_issues: + - googleapis/genai-toolbox-langchain-python +assign_prs: + - googleapis/genai-toolbox-langchain-python \ No newline at end of file diff --git a/.github/header-checker-lint.yml b/.github/header-checker-lint.yml index bd097511..42529751 100644 --- a/.github/header-checker-lint.yml +++ b/.github/header-checker-lint.yml @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/.github/labels.yaml b/.github/labels.yaml new file mode 100644 index 00000000..552ca4de --- /dev/null +++ b/.github/labels.yaml @@ -0,0 +1,70 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +- name: duplicate + color: ededed + description: "" + +- name: 'type: bug' + color: db4437 + description: Error or flaw in code with unintended results or allowing sub-optimal + usage patterns. +- name: 'type: cleanup' + color: c5def5 + description: An internal cleanup or hygiene concern. +- name: 'type: docs' + color: 0000A0 + description: Improvement to the documentation for an API. 
+- name: 'type: feature request' + color: c5def5 + description: ‘Nice-to-have’ improvement, new feature or different behavior or design. +- name: 'type: process' + color: c5def5 + description: A process-related concern. May include testing, release, or the like. +- name: 'type: question' + color: c5def5 + description: Request for information or clarification. + +- name: 'priority: p0' + color: b60205 + description: Highest priority. Critical issue. P0 implies highest priority. +- name: 'priority: p1' + color: ffa03e + description: Important issue which blocks shipping the next release. Will be fixed + prior to next release. +- name: 'priority: p2' + color: fef2c0 + description: Moderately-important priority. Fix may not be included in next release. +- name: 'priority: p3' + color: ffffc7 + description: Desirable enhancement or fix. May not be included in next release. + +- name: do not merge + color: d93f0b + description: Indicates a pull request not ready for merge, due to either quality + or timing. + +- name: 'autorelease: pending' + color: ededed + description: Release please needs to do its work on this. +- name: 'autorelease: triggered' + color: ededed + description: Release please has triggered a release for this. +- name: 'autorelease: tagged' + color: ededed + description: Release please has completed a release for this. + +- name: 'tests: run' + color: 3DED97 + description: Label to trigger Github Action tests. diff --git a/.github/release-please.yml b/.github/release-please.yml new file mode 100644 index 00000000..0b469937 --- /dev/null +++ b/.github/release-please.yml @@ -0,0 +1,18 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +handleGHRelease: true +packageName: genai-toolbox-langchain-python +releaseType: simple +versionFile: "cmd/version.txt" diff --git a/.github/release-trigger.yml b/.github/release-trigger.yml new file mode 100644 index 00000000..7fe36225 --- /dev/null +++ b/.github/release-trigger.yml @@ -0,0 +1 @@ +enabled: true \ No newline at end of file diff --git a/.github/renovate.json5 b/.github/renovate.json5 new file mode 100644 index 00000000..61a303ac --- /dev/null +++ b/.github/renovate.json5 @@ -0,0 +1,45 @@ +{ + "extends": [ + "config:base", // https://docs.renovatebot.com/presets-config/#configbase + ":semanticCommits", // https://docs.renovatebot.com/presets-default/#semanticcommits + ":ignoreUnstable", // https://docs.renovatebot.com/presets-default/#ignoreunstable + "group:allNonMajor", // https://docs.renovatebot.com/presets-group/#groupallnonmajor + ":separateMajorReleases", // https://docs.renovatebot.com/presets-default/#separatemajorreleases + ":prConcurrentLimitNone", // View complete backlog as PRs. https://docs.renovatebot.com/presets-default/#prconcurrentlimitnone + ":prHourlyLimitNone", // https://docs.renovatebot.com/presets-default/#prhourlylimitnone + ":preserveSemverRanges" + ], + + // Give ecosystem time to catch up. + // npm allows maintainers to unpublish a release up to 3 days later. + // https://docs.renovatebot.com/configuration-options/#minimumreleaseage + "minimumReleaseAge": "3", + + // Create PRs, but do not update them without manual action. + // Reduces spurious retesting in repositories that have many PRs at a time. 
+ // https://docs.renovatebot.com/configuration-options/#rebasewhen + "rebaseWhen": "conflicted", + + // Organizational processes. + // https://docs.renovatebot.com/configuration-options/#dependencydashboardlabels + "dependencyDashboardLabels": [ + "type: process" + ], + "packageRules": [ + { + "groupName": "GitHub Actions", + "matchManagers": ["github-actions"], + "pinDigests": true + }, + // Python Specific + { + "matchPackageNames": ["pytest"], + "matchUpdateTypes": ["minor", "major"] + }, + { + "groupName": "python-nonmajor", + "matchLanguages": ["python"], + "matchUpdateTypes": ["minor", "patch"] + } + ] +} \ No newline at end of file diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml index ea980b27..3e117dd3 100644 --- a/.github/sync-repo-settings.yaml +++ b/.github/sync-repo-settings.yaml @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- # Synchronize repository settings from a centralized config # https://github.com/googleapis/repo-automation-bots/tree/main/packages/sync-repo-settings # Install: https://github.com/apps/sync-repo-settings @@ -23,14 +22,18 @@ squashMergeAllowed: true mergeCommitAllowed: false # Enable branch protection branchProtectionRules: -- pattern: main - isAdminEnforced: true - requiredStatusCheckContexts: - - 'cla/google' - # - Add required status checks like presubmit tests - requiredApprovingReviewCount: 1 - requiresCodeOwnerReviews: true - requiresStrictStatusChecks: true + - pattern: main + isAdminEnforced: true + requiredStatusCheckContexts: + - "cla/google" + - "lint" + - "conventionalcommits.org" + - "header-check" + # - Add required status checks like presubmit tests + - "langchain-python-sdk-pr-py313 (toolbox-testing-438616)" + requiredApprovingReviewCount: 1 + requiresCodeOwnerReviews: true + requiresStrictStatusChecks: true # Set team access permissionRules: diff --git a/.github/trusted-contribution.yml b/.github/trusted-contribution.yml new file mode 100644 index 00000000..244cb6d1 --- /dev/null +++ b/.github/trusted-contribution.yml @@ -0,0 +1,28 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Trigger presubmit tests for trusted contributors +# https://github.com/googleapis/repo-automation-bots/tree/main/packages/trusted-contribution +# Install: https://github.com/apps/trusted-contributions-gcf + +trustedContributors: + - "dependabot[bot]" + - "renovate-bot" + - "renovate[bot]" + - "forking-renovate[bot]" + - "release-please[bot]" +annotations: + # Trigger Cloud Build tests + - type: comment + text: "/gcbrun" \ No newline at end of file diff --git a/.github/workflows/cloud_build_failure_reporter.yml b/.github/workflows/cloud_build_failure_reporter.yml new file mode 100644 index 00000000..39116bbf --- /dev/null +++ b/.github/workflows/cloud_build_failure_reporter.yml @@ -0,0 +1,179 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Cloud Build Failure Reporter + +on: + workflow_call: + inputs: + trigger_names: + required: true + type: string + workflow_dispatch: + inputs: + trigger_names: + description: 'Cloud Build trigger names separated by comma.' 
+ required: true + default: '' + +jobs: + report: + + permissions: + issues: 'write' + checks: 'read' + + runs-on: 'ubuntu-latest' + + steps: + - uses: 'actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea' # v7 + with: + script: |- + // parse test names + const testNameSubstring = '${{ inputs.trigger_names }}'; + const testNameFound = new Map(); //keeps track of whether each test is found + testNameSubstring.split(',').forEach(testName => { + testNameFound.set(testName, false); + }); + + // label for all issues opened by reporter + const periodicLabel = 'periodic-failure'; + + // check if any reporter opened any issues previously + const prevIssues = await github.paginate(github.rest.issues.listForRepo, { + ...context.repo, + state: 'open', + creator: 'github-actions[bot]', + labels: [periodicLabel] + }); + + // createOrCommentIssue creates a new issue or comments on an existing issue. + const createOrCommentIssue = async function (title, txt) { + if (prevIssues.length < 1) { + console.log('no previous issues found, creating one'); + await github.rest.issues.create({ + ...context.repo, + title: title, + body: txt, + labels: [periodicLabel] + }); + return; + } + // only comment on issue related to the current test + for (const prevIssue of prevIssues) { + if (prevIssue.title.includes(title)){ + console.log( + `found previous issue ${prevIssue.html_url}, adding comment` + ); + + await github.rest.issues.createComment({ + ...context.repo, + issue_number: prevIssue.number, + body: txt + }); + return; + } + } + }; + + // updateIssues comments on any existing issues. No-op if no issue exists. 
+ const updateIssues = async function (checkName, txt) { + if (prevIssues.length < 1) { + console.log('no previous issues found.'); + return; + } + // only comment on issue related to the current test + for (const prevIssue of prevIssues) { + if (prevIssue.title.includes(checkName)){ + console.log(`found previous issue ${prevIssue.html_url}, adding comment`); + await github.rest.issues.createComment({ + ...context.repo, + issue_number: prevIssue.number, + body: txt + }); + } + } + }; + + // Find status of check runs. + // We will find check runs for each commit and then filter for the periodic. + // Checks API only allows for ref and if we use main there could be edge cases where + // the check run happened on a SHA that is different from head. + const commits = await github.paginate(github.rest.repos.listCommits, { + ...context.repo + }); + + const relevantChecks = new Map(); + for (const commit of commits) { + console.log( + `checking runs at ${commit.html_url}: ${commit.commit.message}` + ); + const checks = await github.rest.checks.listForRef({ + ...context.repo, + ref: commit.sha + }); + + // Iterate through each check and find matching names + for (const check of checks.data.check_runs) { + console.log(`Handling test name ${check.name}`); + for (const testName of testNameFound.keys()) { + if (testNameFound.get(testName) === true){ + //skip if a check is already found for this name + continue; + } + if (check.name.includes(testName)) { + relevantChecks.set(check, commit); + testNameFound.set(testName, true); + } + } + } + // Break out of the loop early if all tests are found + const allTestsFound = Array.from(testNameFound.values()).every(value => value === true); + if (allTestsFound){ + break; + } + } + + // Handle each relevant check + relevantChecks.forEach((commit, check) => { + if ( + check.status === 'completed' && + check.conclusion === 'success' + ) { + updateIssues( + check.name, + `[Tests are passing](${check.html_url}) for commit 
[${commit.sha}](${commit.html_url}).` + ); + } else if (check.status === 'in_progress') { + console.log( + `Check is pending ${check.html_url} for ${commit.html_url}. Retry again later.` + ); + } else { + createOrCommentIssue( + `Cloud Build Failure Reporter: ${check.name} failed`, + `Cloud Build Failure Reporter found test failure for [**${check.name}** ](${check.html_url}) at [${commit.sha}](${commit.html_url}). Please fix the error and then close the issue after the **${check.name}** test passes.` + ); + } + }); + + // no periodic checks found across all commits, report it + const noTestFound = Array.from(testNameFound.values()).every(value => value === false); + if (noTestFound){ + createOrCommentIssue( + 'Missing periodic tests: ${{ inputs.trigger_names }}', + `No periodic test is found for triggers: ${{ inputs.trigger_names }}. Last checked from ${ + commits[0].html_url + } to ${commits[commits.length - 1].html_url}.` + ); + } diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 00000000..a77f8627 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,78 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: lint +on: + pull_request: + pull_request_target: + types: [labeled] + +# Declare default permissions as read only. 
+permissions: read-all + +jobs: + lint: + if: "${{ github.event.action != 'labeled' || github.event.label.name == 'tests: run' }}" + name: lint + runs-on: ubuntu-latest + concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + permissions: + contents: 'read' + issues: 'write' + pull-requests: 'write' + steps: + - name: Remove PR Label + if: "${{ github.event.action == 'labeled' && github.event.label.name == 'tests: run' }}" + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + try { + await github.rest.issues.removeLabel({ + name: 'tests: run', + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number + }); + } catch (e) { + console.log('Failed to remove label. Another job may have already removed it!'); + } + - name: Checkout code + uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6 + with: + ref: ${{ github.event.pull_request.head.sha }} + repository: ${{ github.event.pull_request.head.repo.full_name }} + token: ${{ secrets.GITHUB_TOKEN }} + - name: Setup Python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: "3.13" + + - name: Install library requirements + run: pip install -r requirements.txt + + - name: Install test requirements + run: pip install .[test] + + - name: Run linters + run: | + black --check . + isort --check . 
+ + - name: Run type-check + env: + MYPYPATH: './src' + run: mypy --install-types --non-interactive --cache-dir=.mypy_cache/ -p toolbox_langchain_sdk \ No newline at end of file diff --git a/.github/workflows/lint_fallback.yml b/.github/workflows/lint_fallback.yml new file mode 100644 index 00000000..ad258146 --- /dev/null +++ b/.github/workflows/lint_fallback.yml @@ -0,0 +1,28 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: lint +on: + pull_request: + paths: # These paths are the inverse of lint.yaml + - "./**/*.md" + +jobs: + lint: + runs-on: ubuntu-latest + permissions: + contents: none + + steps: + - run: echo "No tests required." \ No newline at end of file diff --git a/.github/workflows/schedule_reporter.yml b/.github/workflows/schedule_reporter.yml new file mode 100644 index 00000000..25f6bfab --- /dev/null +++ b/.github/workflows/schedule_reporter.yml @@ -0,0 +1,25 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +name: Schedule Reporter + +on: + schedule: + - cron: '0 6 * * *' # Runs at 6 AM every morning + +jobs: + run_reporter: + uses: ./.github/workflows/cloud_build_failure_reporter.yml + with: + trigger_names: "langchain-python-sdk-test-nightly,langchain-python-sdk-test-on-merge" \ No newline at end of file diff --git a/.github/workflows/sync-labels.yaml b/.github/workflows/sync-labels.yaml new file mode 100644 index 00000000..22df93c9 --- /dev/null +++ b/.github/workflows/sync-labels.yaml @@ -0,0 +1,37 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Sync Labels +on: + push: + branches: + - main + +# Declare default permissions as read only. +permissions: read-all + +jobs: + build: + runs-on: ubuntu-latest + permissions: + contents: 'read' + issues: 'write' + pull-requests: 'write' + steps: + - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6 + - uses: micnncim/action-label-syncer@3abd5ab72fda571e69fffd97bd4e0033dd5f495c # v1.3.0 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + manifest: .github/labels.yaml diff --git a/DEVELOPER.md b/DEVELOPER.md index a3bf5819..4634c7bd 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -3,35 +3,35 @@ Below are the details to set up a development environment and run tests. ## Install - 1. 
Clone the repository: - ```bash - git clone https://github.com/googleapis/genai-toolbox-langchain-sdk-python.git + + git clone https://github.com/googleapis/genai-toolbox-langchain-python + ``` +1. Navigate to the repo directory: + ```bash + cd genai-toolbox-langchain-python ``` - 1. Install the package in editable mode, so changes are reflected without reinstall: - ```bash pip install -e . ``` - 1. Make code changes and contribute to the SDK's development. - > [!TIP] Using `-e` option allows you to make changes to the SDK code and have > those changes reflected immediately without reinstalling the package. ## Test - +1. Navigate to the repo directory if needed: + ```bash + cd genai-toolbox-langchain-python + ``` 1. Install the SDK and test dependencies: - ```bash pip install -e .[test] ``` - 1. Run tests and/or contribute to the SDK's development. ```bash pytest - ``` + ``` \ No newline at end of file diff --git a/README.md b/README.md index 67fe8b09..19094a7c 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,346 @@ -# Toolbox Langchain SDK +# GenAI Toolbox LangChain SDK This SDK allows you to seamlessly integrate the functionalities of -[Toolbox](https://github.com/googleapis/genai-toolbox) into your LLM +[Toolbox](https://github.com/googleapis/genai-toolbox) into your LangChain LLM applications, enabling advanced orchestration and interaction with GenAI models. 
+ ## Table of Contents + -## Getting Started +- [Quickstart](#quickstart) +- [Installation](#installation) +- [Usage](#usage) +- [Loading Tools](#loading-tools) + - [Load a toolset](#load-a-toolset) + - [Load a single tool](#load-a-single-tool) +- [Use with LangChain](#use-with-langchain) +- [Use with LangGraph](#use-with-langgraph) + - [Represent Tools as Nodes](#represent-tools-as-nodes) + - [Connect Tools with LLM](#connect-tools-with-llm) +- [Manual usage](#manual-usage) +- [Authenticating Tools](#authenticating-tools) + - [Supported Authentication Mechanisms](#supported-authentication-mechanisms) + - [Configure Tools](#configure-tools) + - [Configure SDK](#configure-sdk) + - [Add Authentication to a Tool](#add-authentication-to-a-tool) + - [Add Authentication While Loading](#add-authentication-while-loading) + - [Complete Example](#complete-example) +- [Binding Parameter Values](#binding-parameter-values) + - [Binding Parameters to a Tool](#binding-parameters-to-a-tool) + - [Binding Parameters While Loading](#binding-parameters-while-loading) + - [Binding Dynamic Values](#binding-dynamic-values) +- [Error Handling](#error-handling) -### Prerequisites + -## Contributing +## Quickstart -Contributions to this library are always welcome and highly encouraged. +Here's a minimal example to get you started: -See [CONTRIBUTING](CONTRIBUTING.md) for more information how to get started. +```py +import asyncio +from toolbox_langchain_sdk import ToolboxClient +from langchain_google_vertexai import ChatVertexAI -Please note that this project is released with a Contributor Code of Conduct. By participating in -this project you agree to abide by its terms. See [Code of Conduct](CODE_OF_CONDUCT.md) for more -information. 
+async def main(): + toolbox = ToolboxClient("http://127.0.0.1:5000") + tools = await toolbox.load_toolset() + + model = ChatVertexAI(model="gemini-1.5-pro-002") + agent = model.bind_tools(tools) + result = agent.invoke("How's the weather today?") + print(result) -## License +if __name__ == "__main__": + asyncio.run(main()) +``` -Apache 2.0 - See [LICENSE](LICENSE) for more information. \ No newline at end of file +## Installation + +> [!IMPORTANT] +> This SDK is not yet available on PyPI. For now, install it from source by +> following these [installation instructions](DEVELOPER.md). + +You can install the Toolbox SDK for LangChain using `pip`. + +```bash +pip install toolbox-langchain-sdk +``` + +## Usage + +Import and initialize the toolbox client. + +```py +from toolbox_langchain_sdk import ToolboxClient + +# Replace with your Toolbox service's URL +toolbox = ToolboxClient("http://127.0.0.1:5000") +``` + +> [!IMPORTANT] +> The toolbox client requires an asynchronous environment. +> For guidance on running asynchronous Python programs, see +> [asyncio documentation](https://docs.python.org/3/library/asyncio-runner.html#running-an-asyncio-program). + +> [!TIP] +> You can also pass your own `ClientSession` to reuse the same session: +> ```py +> async with ClientSession() as session: +> toolbox = ToolboxClient("http://localhost:5000", session) +> ``` + +## Loading Tools + +### Load a toolset + +A toolset is a collection of related tools. You can load all tools in a toolset +or a specific one: + +```py +# Load all tools +tools = await toolbox.load_toolset() + +# Load a specific toolset +tools = await toolbox.load_toolset("my-toolset") +``` + +### Load a single tool + +```py +tool = await toolbox.load_tool("my-tool") +``` + +Loading individual tools gives you finer-grained control over which tools are +available to your LLM agent. + +## Use with LangChain + +LangChain's agents can dynamically choose and execute tools based on the user +input. 
Include tools loaded from the Toolbox SDK in the agent's toolkit: + +```py +from langchain_google_vertexai import ChatVertexAI + +model = ChatVertexAI(model="gemini-1.5-pro-002") + +# Initialize agent with tools +agent = model.bind_tools(tools) + +# Run the agent +result = agent.invoke("Do something with the tools") +``` + +## Use with LangGraph + +Integrate the Toolbox SDK with LangGraph to use Toolbox service tools within a +graph-based workflow. Follow the [official +guide](https://langchain-ai.github.io/langgraph/) with minimal changes. + +### Represent Tools as Nodes + +Represent each tool as a LangGraph node, encapsulating the tool's execution within the node's functionality: + +```py +from toolbox_langchain_sdk import ToolboxClient +from langgraph.graph import StateGraph, MessagesState +from langgraph.prebuilt import ToolNode + +# Define the function that calls the model +def call_model(state: MessagesState): + messages = state['messages'] + response = model.invoke(messages) + return {"messages": [response]} # Return a list to add to existing messages + +model = ChatVertexAI(model="gemini-1.5-pro-002") +builder = StateGraph(MessagesState) +tool_node = ToolNode(tools) + +builder.add_node("agent", call_model) +builder.add_node("tools", tool_node) +``` + +### Connect Tools with LLM + +Connect tool nodes with LLM nodes. The LLM decides which tool to use based on +input or context. 
Tool output can be fed back into the LLM: + +```py +from typing import Literal +from langgraph.graph import END, START +from langchain_core.messages import HumanMessage + +# Define the function that determines whether to continue or not +def should_continue(state: MessagesState) -> Literal["tools", END]: + messages = state['messages'] + last_message = messages[-1] + if last_message.tool_calls: + return "tools" # Route to "tools" node if LLM makes a tool call + return END # Otherwise, stop + +builder.add_edge(START, "agent") +builder.add_conditional_edges("agent", should_continue) +builder.add_edge("tools", 'agent') + +graph = builder.compile() + +graph.invoke({"messages": [HumanMessage(content="Do something with the tools")]}) +``` + +## Manual usage + +Execute a tool manually using the `ainvoke` method: + +```py +result = await tools[0].ainvoke({"name": "Alice", "age": 30}) +``` + +This is useful for testing tools or when you need precise control over tool +execution outside of an agent framework. + +## Authenticating Tools + +> [!WARNING] +> Always use HTTPS to connect your application with the Toolbox service, +> especially when using tools with authentication configured. Using HTTP exposes +> your application to serious security risks. + +Some tools require user authentication to access sensitive data. + +### Supported Authentication Mechanisms +Toolbox currently supports authentication using the [OIDC +protocol](https://openid.net/specs/openid-connect-core-1_0.html) with [ID +tokens](https://openid.net/specs/openid-connect-core-1_0.html#IDToken) (not +access tokens) for [Google OAuth +2.0](https://cloud.google.com/apigee/docs/api-platform/security/oauth/oauth-home). + +### Configure Tools + +Refer to [these +instructions](../../docs/tools/README.md#authenticated-parameters) on +configuring tools for authenticated parameters. 
+ +### Configure SDK + +You need a method to retrieve an ID token from your authentication service: + +```py +async def get_auth_token(): + # ... Logic to retrieve ID token (e.g., from local storage, OAuth flow) + # This example just returns a placeholder. Replace with your actual token retrieval. + return "YOUR_ID_TOKEN" # Placeholder +``` + +#### Add Authentication to a Tool + +```py +toolbox = ToolboxClient("http://localhost:5000") +tools = await toolbox.load_toolset() + +auth_tool = tools[0].add_auth_token("my_auth", get_auth_token) # Single token + +multi_auth_tool = tools[0].add_auth_tokens({"my_auth": get_auth_token}) # Multiple tokens + +# OR + +auth_tools = [tool.add_auth_token("my_auth", get_auth_token) for tool in tools] +``` + +#### Add Authentication While Loading + +```py +auth_tool = await toolbox.load_tool("my-tool", auth_tokens={"my_auth": get_auth_token}) + +auth_tools = await toolbox.load_toolset(auth_tokens={"my_auth": get_auth_token}) +``` + +> [!NOTE] +> Adding auth tokens during loading only affects the tools loaded within +> that call. + +### Complete Example + +```py +import asyncio +from toolbox_langchain_sdk import ToolboxClient + +async def get_auth_token(): + # ... Logic to retrieve ID token (e.g., from local storage, OAuth flow) + # This example just returns a placeholder. Replace with your actual token retrieval. + return "YOUR_ID_TOKEN" # Placeholder + +async def main(): + toolbox = ToolboxClient("http://localhost:5000") + tool = await toolbox.load_tool("my-tool") + + auth_tool = tool.add_auth_token("my_auth", get_auth_token) + result = await auth_tool.ainvoke({"input": "some input"}) + print(result) + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## Binding Parameter Values + +Predetermine values for tool parameters using the SDK. These values won't be +modified by the LLM. This is useful for: + +* **Protecting sensitive information:** API keys, secrets, etc. 
+* **Enforcing consistency:** Ensuring specific values for certain parameters. +* **Pre-filling known data:** Providing defaults or context. + +### Binding Parameters to a Tool + +```py +toolbox = ToolboxClient("http://localhost:5000") +tools = await toolbox.load_toolset() + +bound_tool = tools[0].bind_param("param", "value") # Single param + +multi_bound_tool = tools[0].bind_params({"param1": "value1", "param2": "value2"}) # Multiple params + +# OR + +bound_tools = [tool.bind_param("param", "value") for tool in tools] +``` + +### Binding Parameters While Loading + +```py +bound_tool = await toolbox.load_tool("my-tool", bound_params={"param": "value"}) + +bound_tools = await toolbox.load_toolset(bound_params={"param": "value"}) +``` + +> [!NOTE] +> Bound values during loading only affect the tools loaded in that call. + +### Binding Dynamic Values + +Use a function to bind dynamic values: + +```py +def get_dynamic_value(): + # Logic to determine the value + return "dynamic_value" + +dynamic_bound_tool = tool.bind_param("param", get_dynamic_value) +``` + +> [!IMPORTANT] +> You don't need to modify tool configurations to bind parameter values. + +## Error Handling + +When interacting with the Toolbox service or executing tools, you might +encounter errors. 
Handle potential exceptions gracefully: + +```py +try: + result = await tool.ainvoke({"input": "some input"}) +except Exception as e: + print(f"An error occurred: {e}") + # Implement error recovery logic, e.g., retrying the request or logging the error +``` \ No newline at end of file diff --git a/cmd/version.txt b/cmd/version.txt new file mode 100644 index 00000000..8acdd82b --- /dev/null +++ b/cmd/version.txt @@ -0,0 +1 @@ +0.0.1 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..8a75956a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,60 @@ +[project] +name = "toolbox-langchain-sdk" +version="0.0.1" +description = "Python SDK for interacting with the Toolbox service with LangChain" +license = {file = "LICENSE"} +requires-python = ">=3.9" +authors = [ + {name = "Google LLC", email = "googleapis-packages@google.com"} +] +dependencies = [ + "langchain-core>=0.2.23,<1.0.0", + "PyYAML>=6.0.1,<7.0.0", + "pydantic>=2.7.0,<3.0.0", + "aiohttp>=3.8.6,<4.0.0", + "deprecated>=1.1.0,<2.0.0", +] + +classifiers = [ + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +[project.urls] +Homepage = "https://github.com/googleapis/genai-toolbox" +Repository = "https://github.com/googleapis/genai-toolbox.git" +"Bug Tracker" = "https://github.com/googleapis/genai-toolbox/issues" + +[project.optional-dependencies] +test = [ + "black[jupyter]==24.10.0", + "isort==5.13.2", + "mypy==1.13.0", + "pytest-asyncio==0.24.0", + "pytest==8.3.3", + "pytest-cov==6.0.0", + "Pillow==10.4.0", + "google-cloud-secret-manager==2.22.0", + "google-cloud-storage==2.19.0", +] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.black] 
+target-version = ['py39'] + +[tool.isort] +profile = "black" + +[tool.mypy] +python_version = "3.9" +warn_unused_configs = true +disallow_incomplete_defs = true diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..bebf7f51 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +langchain-core==0.3.21 +PyYAML==6.0.2 +pydantic==2.10.2 +aiohttp==3.11.7 +deprecated==1.2.15 \ No newline at end of file diff --git a/src/toolbox_langchain_sdk/__init__.py b/src/toolbox_langchain_sdk/__init__.py new file mode 100644 index 00000000..5ff0058f --- /dev/null +++ b/src/toolbox_langchain_sdk/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .client import ToolboxClient +from .tools import ToolboxTool + +__all__ = ["ToolboxClient", "ToolboxTool"] diff --git a/src/toolbox_langchain_sdk/client.py b/src/toolbox_langchain_sdk/client.py new file mode 100644 index 00000000..660f3ceb --- /dev/null +++ b/src/toolbox_langchain_sdk/client.py @@ -0,0 +1,201 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import Any, Callable, Optional, Union +from warnings import warn + +from aiohttp import ClientSession + +from .tools import ToolboxTool +from .utils import ManifestSchema, _load_manifest + + +class ToolboxClient: + def __init__(self, url: str, session: Optional[ClientSession] = None): + """ + Initializes the ToolboxClient for the Toolbox service at the given URL. + + Args: + url: The base URL of the Toolbox service. + session: An optional HTTP client session. If not provided, a new + session will be created. + """ + self._url: str = url + self._should_close_session: bool = session is None + self._session: ClientSession = session or ClientSession() + + async def close(self) -> None: + """ + Closes the HTTP client session if it was created by this client. + """ + # We check whether _should_close_session is set or not since we do not + # want to close the session in case the user had passed their own + # ClientSession object, since then we expect the user to be owning its + # lifecycle. + if self._session and self._should_close_session: + await self._session.close() + + def __del__(self): + """ + Ensures the HTTP client session is closed when the client is garbage + collected. + """ + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(self.close()) + else: + loop.run_until_complete(self.close()) + except Exception: + # We "pass" assuming that the exception is thrown because the event + # loop is no longer running, but at that point the Session should + # have been closed already anyway. 
+ pass + + async def _load_tool_manifest(self, tool_name: str) -> ManifestSchema: + """ + Fetches and parses the manifest schema for the given tool from the + Toolbox service. + + Args: + tool_name: The name of the tool to load. + + Returns: + The parsed Toolbox manifest. + """ + url = f"{self._url}/api/tool/{tool_name}" + return await _load_manifest(url, self._session) + + async def _load_toolset_manifest( + self, toolset_name: Optional[str] = None + ) -> ManifestSchema: + """ + Fetches and parses the manifest schema from the Toolbox service. + + Args: + toolset_name: The name of the toolset to load. If not provided, + the manifest for all available tools is loaded. + + Returns: + The parsed Toolbox manifest. + """ + url = f"{self._url}/api/toolset/{toolset_name or ''}" + return await _load_manifest(url, self._session) + + async def load_tool( + self, + tool_name: str, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> ToolboxTool: + """ + Loads the tool with the given tool name from the Toolbox service. + + Args: + tool_name: The name of the tool to load. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A tool loaded from the Toolbox. + """ + if auth_headers: + if auth_tokens: + warn( + "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. 
Use `auth_tokens` instead.", + DeprecationWarning, + ) + auth_tokens = auth_headers + + manifest: ManifestSchema = await self._load_tool_manifest(tool_name) + return ToolboxTool( + tool_name, + manifest.tools[tool_name], + self._url, + self._session, + auth_tokens, + bound_params, + strict, + ) + + async def load_toolset( + self, + toolset_name: Optional[str] = None, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> list[ToolboxTool]: + """ + Loads tools from the Toolbox service, optionally filtered by toolset + name. + + Args: + toolset_name: The name of the toolset to load. If not provided, + all tools are loaded. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A list of all tools loaded from the Toolbox. + """ + if auth_headers: + if auth_tokens: + warn( + "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. 
Use `auth_tokens` instead.", + DeprecationWarning, + ) + auth_tokens = auth_headers + + tools: list[ToolboxTool] = [] + manifest: ManifestSchema = await self._load_toolset_manifest(toolset_name) + + for tool_name, tool_schema in manifest.tools.items(): + tools.append( + ToolboxTool( + tool_name, + tool_schema, + self._url, + self._session, + auth_tokens, + bound_params, + strict, + ) + ) + return tools diff --git a/src/toolbox_langchain_sdk/tools.py b/src/toolbox_langchain_sdk/tools.py new file mode 100644 index 00000000..0e46bccb --- /dev/null +++ b/src/toolbox_langchain_sdk/tools.py @@ -0,0 +1,374 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Any, Callable, Union +from warnings import warn + +from aiohttp import ClientSession +from langchain_core.tools import BaseTool +from typing_extensions import Self + +from .utils import ( + ParameterSchema, + ToolSchema, + _find_auth_params, + _find_bound_params, + _invoke_tool, + _schema_to_model, +) + + +class ToolboxTool(BaseTool): + """ + A subclass of LangChain's StructuredTool that supports features specific to + Toolbox, like bound parameters and authenticated tools. 
+ """ + + def __init__( + self, + name: str, + schema: ToolSchema, + url: str, + session: ClientSession, + auth_tokens: dict[str, Callable[[], str]] = {}, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> None: + """ + Initializes a ToolboxTool instance. + + Args: + name: The name of the tool. + schema: The tool schema. + url: The base URL of the Toolbox service. + session: The HTTP client session. + auth_tokens: A mapping of authentication source names to functions + that retrieve ID tokens. + bound_params: A mapping of parameter names to their bound + values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + """ + + # If the schema is not already a ToolSchema instance, we create one from + # its attributes. This allows flexibility in how the schema is provided, + # accepting both a ToolSchema object and a dictionary of schema + # attributes. + if not isinstance(schema, ToolSchema): + schema = ToolSchema(**schema) + + auth_params, non_auth_params = _find_auth_params(schema.parameters) + non_auth_bound_params, non_auth_non_bound_params = _find_bound_params( + non_auth_params, list(bound_params) + ) + + # Check if the user is trying to bind a param that is authenticated or + # is missing from the given schema. + auth_bound_params: list[str] = [] + missing_bound_params: list[str] = [] + for bound_param in bound_params: + if bound_param in [param.name for param in auth_params]: + auth_bound_params.append(bound_param) + elif bound_param not in [param.name for param in non_auth_params]: + missing_bound_params.append(bound_param) + + # Create error messages for any params that are found to be + # authenticated or missing. + messages: list[str] = [] + if auth_bound_params: + messages.append( + f"Parameter(s) {', '.join(auth_bound_params)} already authenticated and cannot be bound." 
+ ) + if missing_bound_params: + messages.append( + f"Parameter(s) {', '.join(missing_bound_params)} missing and cannot be bound." + ) + + # Join any error messages and raise them as an error or warning, + # depending on the value of the strict flag. + if messages: + message = "\n\n".join(messages) + if strict: + raise ValueError(message) + warn(message) + + # Bind values for parameters present in the schema that don't require + # authentication. + bound_params = { + param_name: param_value + for param_name, param_value in bound_params.items() + if param_name in [param.name for param in non_auth_bound_params] + } + + # Update the tools schema to validate only the presence of parameters + # that neither require authentication nor are bound. + schema.parameters = non_auth_non_bound_params + + # Due to how pydantic works, we must initialize the underlying + # StructuredTool class before assigning values to member variables. + super().__init__( + name=name, + description=schema.description, + args_schema=_schema_to_model(model_name=name, schema=schema.parameters), + ) + + self._name: str = name + self._schema: ToolSchema = schema + self._url: str = url + self._session: ClientSession = session + self._auth_tokens: dict[str, Callable[[], str]] = auth_tokens + self._auth_params: list[ParameterSchema] = auth_params + self._bound_params: dict[str, Union[Any, Callable[[], Any]]] = bound_params + + # Warn users about any missing authentication so they can add it before + # tool invocation. + self.__validate_auth(strict=False) + + async def _arun(self, **kwargs: Any) -> dict[str, Any]: + """ + The coroutine that invokes the tool with the given arguments. + + Args: + **kwargs: The arguments to the tool. + + Returns: + A dictionary containing the parsed JSON response from the tool + invocation. 
+ """ + + # If the tool had parameters that require authentication, then right + # before invoking that tool, we check whether all these required + # authentication sources have been registered or not. + self.__validate_auth() + + # Evaluate dynamic parameter values if any + evaluated_params = {} + for param_name, param_value in self._bound_params.items(): + if callable(param_value): + evaluated_params[param_name] = param_value() + else: + evaluated_params[param_name] = param_value + + # Merge bound parameters with the provided arguments + kwargs.update(evaluated_params) + + return await _invoke_tool( + self._url, self._session, self._name, kwargs, self._auth_tokens + ) + + def _run(self, **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Sync tool calls not supported yet.") + + def __validate_auth(self, strict: bool = True) -> None: + """ + Checks if a tool meets the authentication requirements. + + A tool is considered authenticated if all of its parameters meet at + least one of the following conditions: + + * The parameter has at least one registered authentication source. + * The parameter requires no authentication. + + Args: + strict: If True, raises a PermissionError if any required + authentication sources are not registered. If False, only issues + a warning. + + Raises: + PermissionError: If strict is True and any required authentication + sources are not registered. + """ + params_missing_auth: list[str] = [] + + # Check each parameter for at least 1 required auth source + for param in self._auth_params: + assert param.authSources is not None + has_auth = False + for src in param.authSources: + # Find first auth source that is specified + if src in self._auth_tokens: + has_auth = True + break + if not has_auth: + params_missing_auth.append(param.name) + + if params_missing_auth: + message = f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self._name} require authentication, but no valid authentication sources are registered. 
Please register the required sources before use." + + if strict: + raise PermissionError(message) + warn(message) + + def __create_copy( + self, + *, + auth_tokens: dict[str, Callable[[], str]] = {}, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool, + ) -> Self: + """ + Creates a deep copy of the current ToolboxTool instance, allowing for + modification of auth tokens and bound params. + + This method enables the creation of new tool instances with inherited + properties from the current instance, while optionally updating the auth + tokens and bound params. This is useful for creating variations of the + tool with additional auth tokens or bound params without modifying the + original instance, ensuring immutability. + + Args: + auth_tokens: A dictionary of auth source names to functions that + retrieve ID tokens. These tokens will be merged with the + existing auth tokens. + bound_params: A dictionary of parameter names to their + bound values or functions to retrieve the values. These params + will be merged with the existing bound params. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added auth tokens or bound params. + """ + new_schema = deepcopy(self._schema) + + # Reconstruct the complete parameter schema by merging the auth + # parameters back with the non-auth parameters. This is necessary to + # accurately validate the new combination of auth tokens and bound + # params in the constructor of the new ToolboxTool instance, ensuring + # that any overlaps or conflicts are correctly identified and reported + # as errors or warnings, depending on the given `strict` flag. 
+ new_schema.parameters += self._auth_params + return type(self)( + name=self._name, + schema=new_schema, + url=self._url, + session=self._session, + auth_tokens={**self._auth_tokens, **auth_tokens}, + bound_params={**self._bound_params, **bound_params}, + strict=strict, + ) + + def add_auth_tokens( + self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + ) -> Self: + """ + Registers functions to retrieve ID tokens for the corresponding + authentication sources. + + Args: + auth_tokens: A dictionary of authentication source names to the + functions that return corresponding ID token. + strict: If True, a ValueError is raised if any of the provided auth + tokens are already registered, or are already bound. If False, + only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added auth tokens. + """ + + # Check if the authentication source is already registered. + dupe_tokens: list[str] = [] + for auth_token, _ in auth_tokens.items(): + if auth_token in self._auth_tokens: + dupe_tokens.append(auth_token) + + if dupe_tokens: + raise ValueError( + f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self._name}`." + ) + + return self.__create_copy(auth_tokens=auth_tokens, strict=strict) + + def add_auth_token( + self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + ) -> Self: + """ + Registers a function to retrieve an ID token for a given authentication + source. + + Args: + auth_source: The name of the authentication source. + get_id_token: A function that returns the ID token. + strict: If True, a ValueError is raised if any of the provided auth + tokens are already registered, or are already bound. If False, + only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added auth tokens. 
+ """ + return self.add_auth_tokens({auth_source: get_id_token}, strict=strict) + + def bind_params( + self, + bound_params: dict[str, Union[Any, Callable[[], Any]]], + strict: bool = True, + ) -> Self: + """ + Registers values or functions to retrieve the value for the + corresponding bound parameters. + + Args: + bound_params: A dictionary of the bound parameter name to the + value or function of the bound value. + strict: If True, a ValueError is raised if any of the provided bound + params are already bound, not defined in the tool's schema, or + require authentication. If False, only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added bound params. + """ + + # Check if the parameter is already bound. + dupe_params: list[str] = [] + for param_name, _ in bound_params.items(): + if param_name in self._bound_params: + dupe_params.append(param_name) + + if dupe_params: + raise ValueError( + f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self._name}`." + ) + + return self.__create_copy(bound_params=bound_params, strict=strict) + + def bind_param( + self, + param_name: str, + param_value: Union[Any, Callable[[], Any]], + strict: bool = True, + ) -> Self: + """ + Registers a value or a function to retrieve the value for a given + bound parameter. + + Args: + param_name: The name of the bound parameter. + param_value: The value of the bound parameter, or a callable + that returns the value. + strict: If True, a ValueError is raised if any of the provided bound + params are already bound, not defined in the tool's schema, or + require authentication. If False, only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added bound params. 
+ """ + return self.bind_params({param_name: param_value}, strict) diff --git a/src/toolbox_langchain_sdk/utils.py b/src/toolbox_langchain_sdk/utils.py new file mode 100644 index 00000000..229c3f2f --- /dev/null +++ b/src/toolbox_langchain_sdk/utils.py @@ -0,0 +1,255 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, Callable, Optional, Type, cast +from warnings import warn + +from aiohttp import ClientSession +from deprecated import deprecated +from pydantic import BaseModel, Field, create_model + + +class ParameterSchema(BaseModel): + """ + Schema for a tool parameter. + """ + + name: str + type: str + description: str + authSources: Optional[list[str]] = None + + +class ToolSchema(BaseModel): + """ + Schema for a tool. + """ + + description: str + parameters: list[ParameterSchema] + + +class ManifestSchema(BaseModel): + """ + Schema for the Toolbox manifest. + """ + + serverVersion: str + tools: dict[str, ToolSchema] + + +async def _load_manifest(url: str, session: ClientSession) -> ManifestSchema: + """ + Asynchronously fetches and parses the JSON manifest schema from the given + URL. + + Args: + url: The URL to fetch the JSON from. + session: The HTTP client session. + + Returns: + The parsed Toolbox manifest. + + Raises: + json.JSONDecodeError: If the response is not valid JSON. + ValueError: If the response is not a valid manifest. 
+ """ + async with session.get(url) as response: + response.raise_for_status() + try: + parsed_json = json.loads(await response.text()) + except json.JSONDecodeError as e: + raise json.JSONDecodeError( + f"Failed to parse JSON from {url}: {e}", e.doc, e.pos + ) from e + try: + return ManifestSchema(**parsed_json) + except ValueError as e: + raise ValueError(f"Invalid JSON data from {url}: {e}") from e + + +def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[BaseModel]: + """ + Converts the given manifest schema to a Pydantic BaseModel class. + + Args: + model_name: The name of the model to create. + schema: The schema to convert. + + Returns: + A Pydantic BaseModel class. + """ + field_definitions = {} + for field in schema: + field_definitions[field.name] = cast( + Any, + ( + # TODO: Remove the hardcoded optional types once optional fields + # are supported by Toolbox. + Optional[_parse_type(field.type)], + Field(description=field.description), + ), + ) + + return create_model(model_name, **field_definitions) + + +def _parse_type(type_: str) -> Any: + """ + Converts a schema type to a JSON type. + + Args: + type_: The type name to convert. + + Returns: + A valid JSON type. + + Raises: + ValueError: If the given type is not supported. + """ + + if type_ == "string": + return str + elif type_ == "integer": + return int + elif type_ == "number": + return float + elif type_ == "boolean": + return bool + elif type_ == "array": + return list + else: + raise ValueError(f"Unsupported schema type: {type_}") + + +@deprecated("Please use `_get_auth_tokens` instead.") +def _get_auth_headers(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: + """ + Deprecated. Use `_get_auth_tokens` instead. 
+ """ + return _get_auth_tokens(id_token_getters) + + +def _get_auth_tokens(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: + """ + Gets ID tokens for the given auth sources in the getters map and returns + tokens to be included in tool invocation. + + Args: + id_token_getters: A dict that maps auth source names to the functions + that return its ID token. + + Returns: + A dictionary of tokens to be included in the tool invocation. + """ + auth_tokens = {} + for auth_source, get_id_token in id_token_getters.items(): + auth_tokens[f"{auth_source}_token"] = get_id_token() + return auth_tokens + + +async def _invoke_tool( + url: str, + session: ClientSession, + tool_name: str, + data: dict, + id_token_getters: dict[str, Callable[[], str]], +) -> dict: + """ + Asynchronously makes an API call to the Toolbox service to invoke a tool. + + Args: + url: The base URL of the Toolbox service. + session: The HTTP client session. + tool_name: The name of the tool to invoke. + data: The input data for the tool. + id_token_getters: A dict that maps auth source names to the functions + that return its ID token. + + Returns: + A dictionary containing the parsed JSON response from the tool + invocation. + """ + url = f"{url}/api/tool/{tool_name}/invoke" + auth_tokens = _get_auth_tokens(id_token_getters) + + # ID tokens contain sensitive user information (claims). Transmitting these + # over HTTP exposes the data to interception and unauthorized access. Always + # use HTTPS to ensure secure communication and protect user privacy. + if auth_tokens and not url.startswith("https://"): + warn( + "Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication." 
+ ) + + async with session.post( + url, + json=_convert_none_to_empty_string(data), + headers=auth_tokens, + ) as response: + response.raise_for_status() + return await response.json() + + +def _convert_none_to_empty_string(input_dict): + """ + Temporary fix to convert None values to empty strings in the input data. + This is needed because the current version of the Toolbox service does not + support optional fields. + + TODO: Remove this once optional fields are supported by Toolbox. + + Args: + input_dict: The input data dictionary. + + Returns: + A new dictionary with None values replaced by empty strings. + """ + new_dict = {} + for key, value in input_dict.items(): + if value is None: + new_dict[key] = "" + else: + new_dict[key] = value + return new_dict + + +def _find_auth_params( + params: list[ParameterSchema], +) -> tuple[list[ParameterSchema], list[ParameterSchema]]: + _auth_params: list[ParameterSchema] = [] + _non_auth_params: list[ParameterSchema] = [] + + for param in params: + if param.authSources: + _auth_params.append(param) + else: + _non_auth_params.append(param) + + return (_auth_params, _non_auth_params) + + +def _find_bound_params( + params: list[ParameterSchema], bound_params: list[str] +) -> tuple[list[ParameterSchema], list[ParameterSchema]]: + _bound_params: list[ParameterSchema] = [] + _non_bound_params: list[ParameterSchema] = [] + + for param in params: + if param.name in bound_params: + _bound_params.append(param) + else: + _non_bound_params.append(param) + + return (_bound_params, _non_bound_params) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..9875fffe --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,166 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def create_tmpfile(content: str) -> str:
    """Creates a temporary file with the given content and returns its path.

    The file is kept on disk after it is closed (``delete=False``), so the
    caller is responsible for removing it once it is no longer needed.

    Args:
        content: The text to write into the file.

    Returns:
        The filesystem path of the newly created file.
    """
    # Pin the encoding: without it, text mode falls back to the platform's
    # locale encoding, so non-ASCII content could be written differently (or
    # fail to encode) depending on the machine running the tests.
    with tempfile.NamedTemporaryFile(
        delete=False, mode="w", encoding="utf-8"
    ) as tmpfile:
        tmpfile.write(content)
    return tmpfile.name
def get_toolbox_binary_url(toolbox_version: str) -> str:
    """Builds the storage path of the toolbox binary for the current machine.

    The result has the form ``v<version>/<os>/<arch>/toolbox``, where ``<os>``
    is the lower-cased ``platform.system()`` value and ``<arch>`` is "arm64"
    only on macOS running on an arm64 machine, and "amd64" everywhere else.
    """
    system_name = platform.system().lower()
    on_arm_mac = system_name == "darwin" and platform.machine() == "arm64"
    if on_arm_mac:
        cpu_arch = "arm64"
    else:
        cpu_arch = "amd64"
    return "/".join((f"v{toolbox_version}", system_name, cpu_arch, "toolbox"))
+ print("Downloading toolbox binary from gcs bucket...") + source_blob_name = get_toolbox_binary_url(toolbox_version) + download_blob("genai-toolbox", source_blob_name, "toolbox") + print("Toolbox binary downloaded successfully.") + try: + print("Opening toolbox server process...") + # Make toolbox executable + os.chmod("toolbox", 0o700) + # Run toolbox binary + toolbox_server = subprocess.Popen( + ["./toolbox", "--tools_file", tools_file_path] + ) + + # Wait for server to start + # Retry logic with a timeout + for _ in range(5): # retries + time.sleep(2) + print("Checking if toolbox is successfully started...") + if toolbox_server.poll() is None: + print("Toolbox server started successfully.") + break + else: + raise RuntimeError("Toolbox server failed to start after 5 retries.") + except subprocess.CalledProcessError as e: + print(e.stderr.decode("utf-8")) + print(e.stdout.decode("utf-8")) + raise RuntimeError(f"{e}\n\n{e.stderr.decode('utf-8')}") from e + yield + + # Clean up toolbox server + toolbox_server.terminate() + toolbox_server.wait() diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..5ad4e8e1 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,439 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from aiohttp import ClientSession + +from toolbox_langchain_sdk.client import ToolboxClient +from toolbox_langchain_sdk.utils import ManifestSchema + + +@pytest.fixture +def manifest_schema(): + return ManifestSchema( + **{ + "serverVersion": "1.0.0", + "tools": { + "test_tool_1": { + "description": "Test Tool 1 Description", + "parameters": [ + {"name": "param1", "type": "string", "description": "Param 1"} + ], + }, + "test_tool_2": { + "description": "Test Tool 2 Description", + "parameters": [ + {"name": "param2", "type": "integer", "description": "Param 2"} + ], + }, + }, + } + ) + + +@pytest.fixture +def mock_auth_tokens(): + return {"test-auth-source": lambda: "test-token"} + + +@pytest.fixture +def mock_bound_params(): + return {"param1": "bound-value"} + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client.ClientSession") +async def test_toolbox_client_init(mock_client): + client = ToolboxClient(url="https://test-url", session=mock_client) + assert client._url == "https://test-url" + assert client._session == mock_client + + +@pytest.fixture(params=[True, False]) +@patch("toolbox_langchain_sdk.client.ClientSession") +def toolbox_client(MockClientSession, request): + """ + Fixture to provide a ToolboxClient with and without a provided session. 
@pytest.mark.asyncio
@patch("toolbox_langchain_sdk.client.ClientSession")
async def test_toolbox_client_close(MockClientSession, toolbox_client):
    """Checks that close() only closes a session the client created itself."""
    MockClientSession.return_value.close = AsyncMock()
    for client in toolbox_client:
        assert not client._session.close.called
        await client.close()
        if client._should_close_session:
            # Assert session is closed only if it was created by the client.
            # Check the `close` call itself: reading `.closed` on a mock just
            # auto-creates a truthy attribute, so asserting on it would pass
            # even if close() never touched the session.
            # NOTE(review): assumes the client-created session is the patched
            # ClientSession mock whose `close` is the AsyncMock set above.
            client._session.close.assert_awaited_once()
        else:
            # Assert session is NOT closed if it was provided
            assert not client._session.close.called
test_toolbox_client_load_toolset_manifest(mock_load_manifest): + mock_load_manifest.return_value = AsyncMock( + return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} + ) + async with ClientSession() as session: + client = ToolboxClient(url="https://test-url", session=session) + manifest = await client._load_toolset_manifest("test_toolset") + assert manifest == ( # Call the mock object to get its return value + mock_load_manifest.return_value # This will return the dictionary + ) + mock_load_manifest.assert_called_once_with( + "https://test-url/api/toolset/test_toolset", session + ) + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client._load_manifest") +async def test_toolbox_client_load_toolset_manifest_no_toolset(mock_load_manifest): + mock_load_manifest.return_value = AsyncMock( + return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} + ) + async with ClientSession() as session: + client = ToolboxClient(url="https://test-url", session=session) + manifest = await client._load_toolset_manifest() + assert manifest == ( # Call the mock object to get its return value + mock_load_manifest.return_value # This will return the dictionary + ) + mock_load_manifest.assert_called_once_with( + "https://test-url/api/toolset/", session + ) + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client.ToolboxTool") +@patch("toolbox_langchain_sdk.client._load_manifest") +async def test_toolbox_client_load_tool(mock_load_manifest, MockToolboxTool): + mock_load_manifest.return_value = AsyncMock( + return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} + ) + async with ClientSession() as session: + client = ToolboxClient(url="https://test-url", session=session) + tool = await client.load_tool("test_tool") + assert tool == MockToolboxTool.return_value + MockToolboxTool.assert_called_once_with( + "test_tool", + mock_load_manifest.return_value.tools.__getitem__( + "test_tool" + ), # Correctly access 
the tool schema + "https://test-url", + session, + {}, + {}, + True, + ) + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client.ToolboxTool") +@patch("toolbox_langchain_sdk.client._load_manifest") +async def test_toolbox_client_load_tool_with_auth( + mock_load_manifest, MockToolboxTool, mock_auth_tokens +): + mock_load_manifest.return_value = AsyncMock( + return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} + ) + async with ClientSession() as session: + client = ToolboxClient(url="https://test-url", session=session) + tool = await client.load_tool("test_tool", auth_tokens=mock_auth_tokens) + assert tool == MockToolboxTool.return_value + MockToolboxTool.assert_called_once_with( + "test_tool", + mock_load_manifest.return_value.tools.__getitem__("test_tool"), + "https://test-url", + session, + mock_auth_tokens, + {}, + True, + ) + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client.ToolboxTool") +@patch("toolbox_langchain_sdk.client._load_manifest") +async def test_toolbox_client_load_tool_with_auth_headers( + mock_load_manifest, MockToolboxTool, mock_auth_tokens +): + mock_load_manifest.return_value = AsyncMock( + return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} + ) + async with ClientSession() as session: + client = ToolboxClient(url="https://test-url", session=session) + with pytest.warns( + DeprecationWarning, + match="Argument `auth_headers` is deprecated. 
Use `auth_tokens` instead.", + ): + tool = await client.load_tool("test_tool", auth_headers=mock_auth_tokens) + assert tool == MockToolboxTool.return_value + MockToolboxTool.assert_called_once_with( + "test_tool", + mock_load_manifest.return_value.tools.__getitem__("test_tool"), + "https://test-url", + session, + mock_auth_tokens, + {}, + True, + ) + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client.ToolboxTool") +@patch("toolbox_langchain_sdk.client._load_manifest") +async def test_toolbox_client_load_tool_with_auth_and_headers( + mock_load_manifest, MockToolboxTool, mock_auth_tokens +): + mock_load_manifest.return_value = AsyncMock( + return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} + ) + async with ClientSession() as session: + client = ToolboxClient(url="https://test-url", session=session) + with pytest.warns( + DeprecationWarning, + match="Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + ): + tool = await client.load_tool( + "test_tool", auth_tokens=mock_auth_tokens, auth_headers=mock_auth_tokens + ) + assert tool == MockToolboxTool.return_value + MockToolboxTool.assert_called_once_with( + "test_tool", + mock_load_manifest.return_value.tools.__getitem__("test_tool"), + "https://test-url", + session, + mock_auth_tokens, + {}, + True, + ) + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client.ToolboxTool") +@patch("toolbox_langchain_sdk.client._load_manifest") +async def test_toolbox_client_load_tool_with_bound_params( + mock_load_manifest, MockToolboxTool, mock_bound_params +): + mock_load_manifest.return_value = AsyncMock( + return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} + ) + async with ClientSession() as session: + client = ToolboxClient(url="https://test-url", session=session) + tool = await client.load_tool("test_tool", bound_params=mock_bound_params) + assert tool == MockToolboxTool.return_value + 
MockToolboxTool.assert_called_once_with( + "test_tool", + mock_load_manifest.return_value.tools.__getitem__("test_tool"), + "https://test-url", + session, + {}, + mock_bound_params, + True, + ) + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client._load_manifest") +async def test_toolbox_client_load_toolset( + mock_load_manifest, toolbox_client, manifest_schema +): + mock_load_manifest.return_value = manifest_schema + for client in toolbox_client: + tools = await client.load_toolset() + assert [tool._schema for tool in tools] == list(manifest_schema.tools.values()) + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client.ToolboxTool") +@patch("toolbox_langchain_sdk.client._load_manifest") +async def test_toolbox_client_load_toolset_with_auth( + mock_load_manifest, + mock_toolbox_tool, + toolbox_client, + manifest_schema, + mock_auth_tokens, +): + mock_load_manifest.return_value = manifest_schema + for client in toolbox_client: + tools = await client.load_toolset(auth_tokens=mock_auth_tokens) + + for i, (tool_name, tool_schema) in enumerate(manifest_schema.tools.items()): + call_args, _ = mock_toolbox_tool.call_args_list[i] + assert call_args[0] == tool_name + assert call_args[1] == tool_schema + assert call_args[2] == client._url + assert call_args[3] == client._session + assert call_args[4] == mock_auth_tokens + assert call_args[5] == {} + + assert len(tools) == len(manifest_schema.tools) + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client.ToolboxTool") +@patch("toolbox_langchain_sdk.client._load_manifest") +async def test_toolbox_client_load_toolset_with_auth_headers( + mock_load_manifest, + mock_toolbox_tool, + toolbox_client, + manifest_schema, + mock_auth_tokens, +): + mock_load_manifest.return_value = manifest_schema + for client in toolbox_client: + with pytest.warns( + DeprecationWarning, + match="Argument `auth_headers` is deprecated. 
Use `auth_tokens` instead.", + ): + tools = await client.load_toolset(auth_headers=mock_auth_tokens) + + for i, (tool_name, tool_schema) in enumerate(manifest_schema.tools.items()): + call_args, _ = mock_toolbox_tool.call_args_list[i] + assert call_args[0] == tool_name + assert call_args[1] == tool_schema + assert call_args[2] == client._url + assert call_args[3] == client._session + assert call_args[4] == mock_auth_tokens + assert call_args[5] == {} + + assert len(tools) == len(manifest_schema.tools) + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client.ToolboxTool") +@patch("toolbox_langchain_sdk.client._load_manifest") +async def test_toolbox_client_load_toolset_with_auth_and_headers( + mock_load_manifest, + mock_toolbox_tool, + toolbox_client, + manifest_schema, + mock_auth_tokens, +): + mock_load_manifest.return_value = manifest_schema + for client in toolbox_client: + with pytest.warns( + DeprecationWarning, + match="Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + ): + tools = await client.load_toolset( + auth_tokens=mock_auth_tokens, auth_headers=mock_auth_tokens + ) + + for i, (tool_name, tool_schema) in enumerate(manifest_schema.tools.items()): + call_args, _ = mock_toolbox_tool.call_args_list[i] + assert call_args[0] == tool_name + assert call_args[1] == tool_schema + assert call_args[2] == client._url + assert call_args[3] == client._session + assert call_args[4] == mock_auth_tokens + assert call_args[5] == {} + + assert len(tools) == len(manifest_schema.tools) + + +@pytest.mark.asyncio +@patch("toolbox_langchain_sdk.client.ToolboxTool") +@patch("toolbox_langchain_sdk.client._load_manifest") +async def test_toolbox_client_load_toolset_with_bound_params( + mock_load_manifest, + mock_toolbox_tool, + toolbox_client, + manifest_schema, + mock_bound_params, +): + mock_load_manifest.return_value = manifest_schema + for client in toolbox_client: + tools = await 
client.load_toolset(bound_params=mock_bound_params) + + for i, (tool_name, tool_schema) in enumerate(manifest_schema.tools.items()): + call_args, _ = mock_toolbox_tool.call_args_list[i] + assert call_args[0] == tool_name + assert call_args[1] == tool_schema + assert call_args[2] == client._url + assert call_args[3] == client._session + assert call_args[4] == {} + assert call_args[5] == mock_bound_params + + assert len(tools) == len(manifest_schema.tools) + + +@pytest.mark.asyncio +async def test_toolbox_client_del_loop_not_running(): + """Test __del__ when the loop is not running.""" + mock_loop = Mock() + mock_loop.is_running.return_value = False + mock_close = Mock(spec=ToolboxClient.close) + + with patch("asyncio.get_event_loop", return_value=mock_loop): + client = ToolboxClient(url="https://test-url") + client.close = mock_close + client.__del__() + + +@pytest.mark.asyncio +async def test_toolbox_client_del_exception(): + """Test __del__ when an exception occurs.""" + mock_loop = Mock() + mock_loop.is_running.return_value = True + mock_loop.create_task.side_effect = Exception("Test Exception") + + with patch("asyncio.get_event_loop", return_value=mock_loop): + client = ToolboxClient(url="https://test-url") + client.__del__() + + # Assert that create_task was called (despite the exception) + mock_loop.create_task.assert_called_once() diff --git a/tests/test_e2e.py b/tests/test_e2e.py new file mode 100644 index 00000000..7ea53f94 --- /dev/null +++ b/tests/test_e2e.py @@ -0,0 +1,174 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end tests for the toolbox SDK interacting with the toolbox server. + +This file covers the following use cases: + +1. Loading a tool. +2. Loading a specific toolset. +3. Loading the default toolset (contains all tools). +4. Running a tool with + a. Missing params. + b. Wrong param type. +5. Running a tool with no required auth, with auth provided. +6. Running a tool with required auth: + a. No auth provided. + b. Wrong auth provided: The tool requires a different authentication + than the one provided. + c. Correct auth provided. +7. Running a tool with a parameter that requires auth: + a. No auth provided. + b. Correct auth provided. + c. Auth provided does not contain the required claim. +""" + +import pytest +import pytest_asyncio +from aiohttp import ClientResponseError +from pydantic import ValidationError + +from toolbox_langchain_sdk.client import ToolboxClient + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestE2EClient: + @pytest_asyncio.fixture(scope="function") + async def toolbox(self): + """Provides a ToolboxClient instance for each test.""" + toolbox = ToolboxClient("http://localhost:5000") + yield toolbox + await toolbox.close() + + #### Basic e2e tests + @pytest.mark.asyncio + async def test_load_tool(self, toolbox): + tool = await toolbox.load_tool("get-n-rows") + response = await tool.ainvoke({"num_rows": "2"}) + result = response["result"] + + assert "row1" in result + assert "row2" in result + assert "row3" not in result + + @pytest.mark.asyncio + async def test_load_toolset_specific(self, toolbox): + toolset = await toolbox.load_toolset("my-toolset") + assert len(toolset) == 1 + assert toolset[0].name == "get-row-by-id" + + toolset = await toolbox.load_toolset("my-toolset-2") + assert len(toolset) == 2 + tool_names = ["get-n-rows", "get-row-by-id"] + assert toolset[0].name in tool_names + assert 
toolset[1].name in tool_names + + @pytest.mark.asyncio + async def test_load_toolset_all(self, toolbox): + toolset = await toolbox.load_toolset() + assert len(toolset) == 5 + tool_names = [ + "get-n-rows", + "get-row-by-id", + "get-row-by-id-auth", + "get-row-by-email-auth", + "get-row-by-content-auth", + ] + for tool in toolset: + assert tool.name in tool_names + + @pytest.mark.asyncio + async def test_run_tool_missing_params(self, toolbox): + tool = await toolbox.load_tool("get-n-rows") + with pytest.raises(ValidationError, match="Field required"): + await tool.ainvoke({}) + + @pytest.mark.asyncio + async def test_run_tool_wrong_param_type(self, toolbox): + tool = await toolbox.load_tool("get-n-rows") + with pytest.raises(ValidationError, match="Input should be a valid string"): + await tool.ainvoke({"num_rows": 2}) + + ##### Auth tests + @pytest.mark.asyncio + @pytest.mark.skip(reason="b/389574566") + async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): + """Tests running a tool that doesn't require auth, with auth provided.""" + tool = await toolbox.load_tool( + "get-row-by-id", auth_tokens={"my-test-auth": lambda: auth_token2} + ) + response = await tool.ainvoke({"id": "2"}) + assert "row2" in response["result"] + + @pytest.mark.asyncio + async def test_run_tool_no_auth(self, toolbox): + """Tests running a tool requiring auth without providing auth.""" + tool = await toolbox.load_tool( + "get-row-by-id-auth", + ) + with pytest.raises(ClientResponseError, match="401, message='Unauthorized'"): + await tool.ainvoke({"id": "2"}) + + @pytest.mark.asyncio + async def test_run_tool_wrong_auth(self, toolbox, auth_token2): + """Tests running a tool with incorrect auth.""" + tool = await toolbox.load_tool( + "get-row-by-id-auth", + ) + auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token2) + # TODO: Fix error message (b/389577313) + with pytest.raises(ClientResponseError, match="400, message='Bad Request'"): + await 
auth_tool.ainvoke({"id": "2"}) + + @pytest.mark.asyncio + async def test_run_tool_auth(self, toolbox, auth_token1): + """Tests running a tool with correct auth.""" + tool = await toolbox.load_tool( + "get-row-by-id-auth", + ) + auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token1) + response = await auth_tool.ainvoke({"id": "2"}) + assert "row2" in response["result"] + + @pytest.mark.asyncio + async def test_run_tool_param_auth_no_auth(self, toolbox): + """Tests running a tool with a param requiring auth, without auth.""" + tool = await toolbox.load_tool("get-row-by-email-auth") + with pytest.raises( + PermissionError, + match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + ): + await tool.ainvoke({}) + + @pytest.mark.asyncio + async def test_run_tool_param_auth(self, toolbox, auth_token1): + """Tests running a tool with a param requiring auth, with correct auth.""" + tool = await toolbox.load_tool( + "get-row-by-email-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + ) + response = await tool.ainvoke({}) + result = response["result"] + assert "row4" in result + assert "row5" in result + assert "row6" in result + + @pytest.mark.asyncio + async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): + """Tests running a tool with a param requiring auth, with insufficient auth.""" + tool = await toolbox.load_tool( + "get-row-by-content-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + ) + with pytest.raises(ClientResponseError, match="400, message='Bad Request'"): + await tool.ainvoke({}) diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 00000000..d8b1fffd --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,325 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in 
@pytest.fixture
@patch("aiohttp.ClientSession")
async def auth_toolbox_tool(MockClientSession, auth_tool_schema):
    """Yields a ToolboxTool built from `auth_tool_schema` on a mocked session.

    The mocked session's `post(...)` context manager returns a response whose
    `raise_for_status` is a plain Mock and whose `json()` resolves to
    `{"result": "test-result"}`. Constructing the tool is expected to emit a
    UserWarning about `param1` requiring authentication with no source
    registered; the fixture asserts that warning while building the tool.
    """
    mock_session = MockClientSession.return_value
    mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock()
    mock_session.post.return_value.__aenter__.return_value.json = AsyncMock(
        return_value={"result": "test-result"}
    )
    # The pattern is a regular expression, so it is written as a raw string:
    # in a normal literal, `\(`, `\``, `\,` and `\.` are invalid string
    # escapes that newer Python versions warn about at compile time. The
    # resulting pattern is unchanged.
    with pytest.warns(
        UserWarning,
        match=r"Parameter\(s\) \`param1\` of tool test_tool require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.",
    ):
        tool = ToolboxTool(
            name="test_tool",
            schema=auth_tool_schema,
            url="https://test-url",
            session=mock_session,
        )
    yield tool
in str(e.value) + else: + with pytest.warns(UserWarning) as record: + tool = tool.bind_params({"param3": "bound-value"}, strict=strict) + assert len(record) == 1 + assert "Parameter(s) param3 missing and cannot be bound." in str( + record[0].message + ) + + +@pytest.mark.asyncio +async def test_toolbox_tool_bind_params_duplicate(toolbox_tool): + async for tool in toolbox_tool: + tool = tool.bind_params({"param1": "bound-value"}) + with pytest.raises(ValueError) as e: + tool = tool.bind_params({"param1": "bound-value"}) + assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( + e.value + ) + + +@pytest.mark.asyncio +async def test_toolbox_tool_bind_params_invalid_params(auth_toolbox_tool): + async for tool in auth_toolbox_tool: + with pytest.raises(ValueError) as e: + tool = tool.bind_params({"param1": "bound-value"}) + assert "Parameter(s) param1 already authenticated and cannot be bound." in str( + e.value + ) + + +@pytest.mark.asyncio +async def test_toolbox_tool_bind_param(toolbox_tool): + async for tool in toolbox_tool: + tool = tool.bind_param("param1", "bound-value") + assert tool._bound_params == {"param1": "bound-value"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize("strict", [True, False]) +async def test_toolbox_tool_bind_param_invalid(toolbox_tool, strict): + async for tool in toolbox_tool: + if strict: + with pytest.raises(ValueError) as e: + tool = tool.bind_param("param3", "bound-value", strict=strict) + assert "Parameter(s) param3 missing and cannot be bound." in str(e.value) + else: + with pytest.warns(UserWarning) as record: + tool = tool.bind_param("param3", "bound-value", strict=strict) + assert len(record) == 1 + assert "Parameter(s) param3 missing and cannot be bound." 
in str( + record[0].message + ) + + +@pytest.mark.asyncio +async def test_toolbox_tool_bind_param_duplicate(toolbox_tool): + async for tool in toolbox_tool: + tool = tool.bind_param("param1", "bound-value") + with pytest.raises(ValueError) as e: + tool = tool.bind_param("param1", "bound-value") + assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( + e.value + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "auth_tokens, expected_auth_tokens", + [ + ( + {"test-auth-source": lambda: "test-token"}, + {"test-auth-source": lambda: "test-token"}, + ), + ( + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + ), + ], +) +async def test_toolbox_tool_add_auth_tokens( + auth_toolbox_tool, auth_tokens, expected_auth_tokens +): + async for tool in auth_toolbox_tool: + tool = tool.add_auth_tokens(auth_tokens) + for source, getter in expected_auth_tokens.items(): + assert tool._auth_tokens[source]() == getter() + + +@pytest.mark.asyncio +async def test_toolbox_tool_add_auth_tokens_duplicate(auth_toolbox_tool): + async for tool in auth_toolbox_tool: + tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) + with pytest.raises(ValueError) as e: + tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) + assert ( + "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." 
+ in str(e.value) + ) + + +@pytest.mark.asyncio +async def test_toolbox_tool_add_auth_token(auth_toolbox_tool): + async for tool in auth_toolbox_tool: + tool = tool.add_auth_token("test-auth-source", lambda: "test-token") + assert tool._auth_tokens["test-auth-source"]() == "test-token" + + +@pytest.mark.asyncio +async def test_toolbox_tool_validate_auth_strict(auth_toolbox_tool): + async for tool in auth_toolbox_tool: + with pytest.raises(PermissionError) as e: + tool._ToolboxTool__validate_auth(strict=True) + assert ( + "Parameter(s) `param1` of tool test_tool require authentication, but no valid authentication sources are registered. Please register the required sources before use." + in str(e.value) + ) + + +@pytest.mark.asyncio +async def test_toolbox_tool_call_with_callable_bound_params(toolbox_tool): + async for tool in toolbox_tool: + tool = tool.bind_param("param1", lambda: "bound-value") + result = await tool.ainvoke({"param2": 123}) + assert result == {"result": "test-result"} + + +@pytest.mark.asyncio +async def test_toolbox_tool_call(toolbox_tool): + async for tool in toolbox_tool: + result = await tool.ainvoke({"param1": "test-value", "param2": 123}) + assert result == {"result": "test-result"} + + +@pytest.mark.asyncio +async def test_toolbox_sync_tool_call_(toolbox_tool): + async for tool in toolbox_tool: + with pytest.raises(NotImplementedError) as e: + result = tool.invoke({"param1": "test-value", "param2": 123}) + assert "Sync tool calls not supported yet." 
in str(e.value) + + +@pytest.mark.asyncio +async def test_toolbox_tool_call_with_bound_params(toolbox_tool): + async for tool in toolbox_tool: + tool = tool.bind_params({"param1": "bound-value"}) + result = await tool.ainvoke({"param2": 123}) + assert result == {"result": "test-result"} + + +@pytest.mark.asyncio +async def test_toolbox_tool_call_with_auth_tokens(auth_toolbox_tool): + async for tool in auth_toolbox_tool: + tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) + result = await tool.ainvoke({"param2": 123}) + assert result == {"result": "test-result"} + + +@pytest.mark.asyncio +async def test_toolbox_tool_call_with_auth_tokens_insecure(auth_toolbox_tool): + async for tool in auth_toolbox_tool: + with pytest.warns( + UserWarning, + match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", + ): + tool._url = "http://test-url" + tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) + result = await tool.ainvoke({"param2": 123}) + assert result == {"result": "test-result"} + + +@pytest.mark.asyncio +async def test_toolbox_tool_call_with_invalid_input(toolbox_tool): + async for tool in toolbox_tool: + with pytest.raises(ValidationError) as e: + await tool.ainvoke({"param1": 123, "param2": "invalid"}) + assert "2 validation errors for test_tool" in str(e.value) + assert "param1\n Input should be a valid string" in str(e.value) + assert "param2\n Input should be a valid integer" in str(e.value) + + +@pytest.mark.asyncio +async def test_toolbox_tool_call_with_empty_input(toolbox_tool): + async for tool in toolbox_tool: + with pytest.raises(ValidationError) as e: + await tool.ainvoke({}) + assert "2 validation errors for test_tool" in str(e.value) + assert "param1\n Field required" in str(e.value) + assert "param2\n Field required" in str(e.value) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..748e2b4a --- /dev/null +++ 
# b/tests/test_utils.py @@ -0,0 +1,271 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the helpers in ``toolbox_langchain_sdk.utils``."""

import asyncio
import json
import re
import warnings
from typing import Union
from unittest.mock import AsyncMock, Mock, patch

import aiohttp
import pytest
from pydantic import BaseModel

from toolbox_langchain_sdk.utils import (
    ParameterSchema,
    _convert_none_to_empty_string,
    _get_auth_headers,
    _invoke_tool,
    _load_manifest,
    _parse_type,
    _schema_to_model,
)

URL = "https://my-toolbox.com/test"
MOCK_MANIFEST = """
{
    "serverVersion": "0.0.1",
    "tools": {
        "test_tool": {
            "summary": "Test Tool",
            "description": "This is a test tool.",
            "parameters": [
                {
                    "name": "param1",
                    "type": "string",
                    "description": "Parameter 1"
                },
                {
                    "name": "param2",
                    "type": "integer",
                    "description": "Parameter 2"
                }
            ]
        }
    }
}
"""


class TestUtils:
    """Tests for manifest loading, schema conversion and tool invocation."""

    @pytest.fixture(scope="module")
    def mock_manifest(self):
        """Return a bare ``aiohttp.ClientResponse`` for tests to stub.

        Each test replaces ``raise_for_status`` and ``text`` on the returned
        object before use.

        NOTE(review): the fixture is module-scoped, so the same response
        object (and the stubs set on it) is shared across tests. Also,
        ``asyncio.get_event_loop()`` is deprecated when no loop is current on
        recent Python versions; confirm this is still acceptable for the
        supported interpreters.
        """
        return aiohttp.ClientResponse(
            method="GET",
            url=aiohttp.client.URL(URL),
            writer=None,
            continue100=None,
            timer=None,
            request_info=None,
            traces=None,
            session=None,
            loop=asyncio.get_event_loop(),
        )

    @pytest.mark.asyncio
    @patch("aiohttp.ClientSession.get")
    async def test_load_manifest(self, mock_get, mock_manifest):
        """A valid manifest is parsed into version, tools and parameters."""
        mock_manifest.raise_for_status = Mock()
        mock_manifest.text = AsyncMock(return_value=MOCK_MANIFEST)
        mock_get.return_value = mock_manifest

        # ``async with`` closes the session even if loading fails.
        async with aiohttp.ClientSession() as session:
            manifest = await _load_manifest(URL, session)
        mock_get.assert_called_once_with(URL)

        assert manifest.serverVersion == "0.0.1"
        assert len(manifest.tools) == 1

        tool = manifest.tools["test_tool"]
        assert tool.description == "This is a test tool."
        assert tool.parameters == [
            ParameterSchema(name="param1", type="string", description="Parameter 1"),
            ParameterSchema(name="param2", type="integer", description="Parameter 2"),
        ]

    @pytest.mark.asyncio
    @patch("aiohttp.ClientSession.get")
    async def test_load_manifest_invalid_json(self, mock_get, mock_manifest):
        """Malformed JSON surfaces as ``json.JSONDecodeError`` naming the URL."""
        mock_manifest.raise_for_status = Mock()
        mock_manifest.text = AsyncMock(return_value="{ invalid manifest")
        mock_get.return_value = mock_manifest

        async with aiohttp.ClientSession() as session:
            with pytest.raises(json.JSONDecodeError) as e:
                await _load_manifest(URL, session)

        mock_get.assert_called_once_with(URL)
        assert (
            str(e.value)
            == "Failed to parse JSON from https://my-toolbox.com/test: Expecting property name enclosed in double quotes: line 1 column 3 (char 2): line 1 column 3 (char 2)"
        )

    @pytest.mark.asyncio
    @patch("aiohttp.ClientSession.get")
    async def test_load_manifest_invalid_manifest(self, mock_get, mock_manifest):
        """Valid JSON that is not a manifest surfaces as ``ValueError``."""
        mock_manifest.raise_for_status = Mock()
        mock_manifest.text = AsyncMock(return_value='{ "something": "invalid" }')
        mock_get.return_value = mock_manifest

        async with aiohttp.ClientSession() as session:
            with pytest.raises(ValueError) as e:
                await _load_manifest(URL, session)

        mock_get.assert_called_once_with(URL)
        # pydantic indents the message by two spaces and the docs link by
        # four; the version segment of the link is matched loosely.
        assert re.match(
            r"Invalid JSON data from https://my-toolbox.com/test: 2 validation errors for ManifestSchema\nserverVersion\n  Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n    For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing\ntools\n  Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n    For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing",
            str(e.value),
        )

    @pytest.mark.asyncio
    @patch("aiohttp.ClientSession.get")
    async def test_load_manifest_api_error(self, mock_get, mock_manifest):
        """A client error while reading the response propagates unchanged."""
        error = aiohttp.ClientError("Simulated HTTP Error")
        mock_manifest.raise_for_status = Mock()
        mock_manifest.text = AsyncMock(side_effect=error)
        mock_get.return_value = mock_manifest

        async with aiohttp.ClientSession() as session:
            with pytest.raises(aiohttp.ClientError) as exc_info:
                await _load_manifest(URL, session)

        mock_get.assert_called_once_with(URL)
        assert exc_info.value == error

    def test_schema_to_model(self):
        """Each parameter becomes an optional, described pydantic field."""
        schema = [
            ParameterSchema(name="param1", type="string", description="Parameter 1"),
            ParameterSchema(name="param2", type="integer", description="Parameter 2"),
        ]
        model = _schema_to_model("TestModel", schema)
        assert issubclass(model, BaseModel)

        assert model.model_fields["param1"].annotation == Union[str, None]
        assert model.model_fields["param1"].description == "Parameter 1"
        assert model.model_fields["param2"].annotation == Union[int, None]
        assert model.model_fields["param2"].description == "Parameter 2"

    def test_schema_to_model_empty(self):
        """An empty schema yields a model with no fields."""
        model = _schema_to_model("TestModel", [])
        assert issubclass(model, BaseModel)
        assert len(model.model_fields) == 0

    @pytest.mark.parametrize(
        "type_string, expected_type",
        [
            ("string", str),
            ("integer", int),
            ("number", float),
            ("boolean", bool),
            ("array", list),
        ],
    )
    def test_parse_type(self, type_string, expected_type):
        """Each supported schema type name maps to its Python type."""
        assert _parse_type(type_string) == expected_type

    def test_parse_type_invalid(self):
        """An unknown schema type name raises ``ValueError``."""
        with pytest.raises(ValueError):
            _parse_type("invalid")

    @pytest.mark.asyncio
    @patch("aiohttp.ClientSession.post")
    async def test_invoke_tool(self, mock_post):
        """The tool's invoke endpoint is POSTed to and its JSON returned."""
        mock_response = Mock()
        mock_response.raise_for_status = Mock()
        mock_response.json = AsyncMock(return_value={"key": "value"})
        mock_post.return_value.__aenter__.return_value = mock_response

        async with aiohttp.ClientSession() as session:
            result = await _invoke_tool(
                "http://localhost:8000",
                session,
                "tool_name",
                {"input": "data"},
                {},
            )

        mock_post.assert_called_once_with(
            "http://localhost:8000/api/tool/tool_name/invoke",
            json=_convert_none_to_empty_string({"input": "data"}),
            headers={},
        )
        assert result == {"key": "value"}

    @pytest.mark.asyncio
    @patch("aiohttp.ClientSession.post")
    async def test_invoke_tool_unsecure_with_auth(self, mock_post):
        """Sending auth tokens over plain HTTP warns but still sends the header."""
        mock_response = Mock()
        mock_response.raise_for_status = Mock()
        mock_response.json = AsyncMock(return_value={"key": "value"})
        mock_post.return_value.__aenter__.return_value = mock_response

        async with aiohttp.ClientSession() as session:
            with pytest.warns(
                UserWarning,
                match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.",
            ):
                result = await _invoke_tool(
                    "http://localhost:8000",
                    session,
                    "tool_name",
                    {"input": "data"},
                    {"my_test_auth": lambda: "fake_id_token"},
                )

        mock_post.assert_called_once_with(
            "http://localhost:8000/api/tool/tool_name/invoke",
            json=_convert_none_to_empty_string({"input": "data"}),
            headers={"my_test_auth_token": "fake_id_token"},
        )
        assert result == {"key": "value"}

    @pytest.mark.asyncio
    @patch("aiohttp.ClientSession.post")
    async def test_invoke_tool_secure_with_auth(self, mock_post):
        """Sending auth tokens over HTTPS emits no warning."""
        mock_response = Mock()
        mock_response.raise_for_status = Mock()
        mock_response.json = AsyncMock(return_value={"key": "value"})
        mock_post.return_value.__aenter__.return_value = mock_response

        async with aiohttp.ClientSession() as session:
            with warnings.catch_warnings():
                # Escalate any warning to an error so the test fails on it.
                warnings.simplefilter("error")
                result = await _invoke_tool(
                    "https://localhost:8000",
                    session,
                    "tool_name",
                    {"input": "data"},
                    {"my_test_auth": lambda: "fake_id_token"},
                )

        mock_post.assert_called_once_with(
            "https://localhost:8000/api/tool/tool_name/invoke",
            json=_convert_none_to_empty_string({"input": "data"}),
            headers={"my_test_auth_token": "fake_id_token"},
        )
        assert result == {"key": "value"}

    def test_convert_none_to_empty_string(self):
        """``None`` values become ``""``; other values are left untouched."""
        input_dict = {"a": None, "b": 123}
        expected_output = {"a": "", "b": 123}
        assert _convert_none_to_empty_string(input_dict) == expected_output

    def test_get_auth_headers_deprecation_warning(self):
        """Test _get_auth_headers deprecation warning."""
        with pytest.warns(
            DeprecationWarning,
            match=r"Call to deprecated function \(or staticmethod\) _get_auth_headers\. \(Please use `_get_auth_tokens` instead\.\)$",
        ):
            _get_auth_headers({"auth_source1": lambda: "test_token"})