Skip to content

Commit 82c6836

Browse files
authored
feat: Add support for checking tool-level auth before tool invocation. (#72)
This is to prevent tool invocation if the required tool-level auth is missing. This is similar to how we prevent tool invocation if a parameter-level auth is missing.
1 parent 7260209 commit 82c6836

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/toolbox_langchain/async_tools.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,15 @@ def __validate_auth(self, strict: bool = True) -> None:
194194
PermissionError: If strict is True and any required authentication
195195
sources are not registered.
196196
"""
197+
is_authenticated: bool = not self.__schema.authRequired
197198
params_missing_auth: list[str] = []
198199

200+
# Check tool for at least 1 required auth source
201+
for src in self.__schema.authRequired:
202+
if src in self.__auth_tokens:
203+
is_authenticated = True
204+
break
205+
199206
# Check each parameter for at least 1 required auth source
200207
for param in self.__auth_params:
201208
if not param.authSources:
@@ -210,9 +217,20 @@ def __validate_auth(self, strict: bool = True) -> None:
210217
if not has_auth:
211218
params_missing_auth.append(param.name)
212219

220+
messages: list[str] = []
221+
222+
if not is_authenticated:
223+
messages.append(
224+
f"Tool {self.__name} requires authentication, but no valid authentication sources are registered. Please register the required sources before use."
225+
)
226+
213227
if params_missing_auth:
214-
message = f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use."
228+
messages.append(
229+
f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use."
230+
)
215231

232+
if messages:
233+
message = "\n\n".join(messages)
216234
if strict:
217235
raise PermissionError(message)
218236
warn(message)

src/toolbox_langchain/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class ToolSchema(BaseModel):
4040

4141
description: str
4242
parameters: list[ParameterSchema]
43+
authRequired: list[str] = []
4344

4445

4546
class ManifestSchema(BaseModel):

0 commit comments

Comments
 (0)