Skip to content

Commit 4f21324

Browse files
committed
Update Guardrail base class for args parsing.
1 parent 58978b8 commit 4f21324

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

ads/llm/guardrails/base.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import datetime
99
import functools
1010
import operator
11-
from typing import Any, List
11+
from typing import Any, List, Dict, Tuple
1212
from langchain.schema.prompt import PromptValue
1313
from langchain.tools.base import BaseTool, ToolException
1414
from langchain.pydantic_v1 import BaseModel, root_validator
@@ -217,6 +217,12 @@ def _preprocess(self, input: Any) -> str:
217217
return input.to_string()
218218
return str(input)
219219

220+
def _to_args_and_kwargs(self, tool_input: Any) -> Tuple[Tuple, Dict]:
221+
if isinstance(tool_input, dict):
222+
return (), tool_input
223+
else:
224+
return (tool_input,), {}
225+
220226
def _run(self, query: Any, run_manager=None) -> Any:
221227
"""Runs the guardrail.
222228
@@ -247,7 +253,11 @@ def _run(self, query: Any, run_manager=None) -> Any:
247253
# containing the ``kwargs`` used to initialize the object.
248254
# The ``kwargs`` does not contain the defaults.
249255
# Here the ``dict()`` method is used to return a dictionary containing the defaults.
250-
info.parameters = self.dict()
256+
info.parameters = {
257+
"class": self.__class__.__name__,
258+
"path": self.__module__,
259+
"spec": self.dict(),
260+
}
251261
info.metrics = self.compute(data, **kwargs)
252262
info.output = self.moderate(info.metrics, data, **kwargs)
253263

0 commit comments

Comments
 (0)