diff --git a/guardrails/__init__.py b/guardrails/__init__.py
index 6e28dff7f..b1d081fca 100644
--- a/guardrails/__init__.py
+++ b/guardrails/__init__.py
@@ -1,6 +1,7 @@
# Set up __init__.py so that users can do from guardrails import Response, Schema, etc.
from guardrails.guard import Guard
+from guardrails.async_guard import AsyncGuard
from guardrails.llm_providers import PromptCallableBase
from guardrails.logging_utils import configure_logging
from guardrails.prompt import Instructions, Prompt
@@ -10,6 +11,7 @@
__all__ = [
"Guard",
+ "AsyncGuard",
"PromptCallableBase",
"Rail",
"Validator",
diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py
new file mode 100644
index 000000000..e689ac698
--- /dev/null
+++ b/guardrails/async_guard.py
@@ -0,0 +1,449 @@
+import contextvars
+import inspect
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Union,
+)
+
+from guardrails import Guard
+from guardrails.classes import OT, ValidationOutcome
+from guardrails.classes.history import Call
+from guardrails.classes.history.call_inputs import CallInputs
+from guardrails.llm_providers import get_async_llm_ask, model_is_supported_server_side
+from guardrails.logger import set_scope
+from guardrails.run import AsyncRunner
+from guardrails.stores.context import set_call_kwargs, set_tracer, set_tracer_context
+
+
+class AsyncGuard(Guard):
+ """The Guard class.
+
+ This class one of the main entry point for using Guardrails. It is
+ initialized from one of the following class methods:
+
+ - `from_rail`
+ - `from_rail_string`
+ - `from_pydantic`
+ - `from_string`
+
+ The `__call__`
+ method functions as a wrapper around LLM APIs. It takes in an Async LLM
+ API, and optional prompt parameters, and returns the raw output stream from
+ the LLM and the validated output stream.
+ """
+
+ async def __call__(
+ self,
+ llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]],
+ prompt_params: Optional[Dict] = None,
+ num_reasks: Optional[int] = None,
+ prompt: Optional[str] = None,
+ instructions: Optional[str] = None,
+ msg_history: Optional[List[Dict]] = None,
+ metadata: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ) -> Union[
+ Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]],
+ Awaitable[ValidationOutcome[OT]],
+ ]:
+ """Call the LLM and validate the output. Pass an async LLM API to
+ return a coroutine.
+
+ Args:
+ llm_api: The async LLM API to call
+ (e.g. openai.Completion.acreate)
+ prompt_params: The parameters to pass to the prompt.format() method.
+ num_reasks: The max times to re-ask the LLM for invalid output.
+ prompt: The prompt to use for the LLM.
+ instructions: Instructions for chat models.
+ msg_history: The message history to pass to the LLM.
+ metadata: Metadata to pass to the validators.
+ full_schema_reask: When reasking, whether to regenerate the full schema
+ or just the incorrect values.
+ Defaults to `True` if a base model is provided,
+ `False` otherwise.
+
+ Returns:
+ The raw text output from the LLM and the validated output.
+ """
+
+ async def __call(
+ self,
+ llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]],
+ prompt_params: Optional[Dict] = None,
+ num_reasks: Optional[int] = None,
+ prompt: Optional[str] = None,
+ instructions: Optional[str] = None,
+ msg_history: Optional[List[Dict]] = None,
+ metadata: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ):
+ if metadata is None:
+ metadata = {}
+ if full_schema_reask is None:
+ full_schema_reask = self.base_model is not None
+ if prompt_params is None:
+ prompt_params = {}
+
+ if not self._disable_tracer:
+ # Create a new span for this guard call
+ self._hub_telemetry.create_new_span(
+ span_name="/guard_call",
+ attributes=[
+ ("guard_id", self._guard_id),
+ ("user_id", self._user_id),
+ ("llm_api", llm_api.__name__ if llm_api else "None"),
+ ("custom_reask_prompt", self.reask_prompt is not None),
+ (
+ "custom_reask_instructions",
+ self.reask_instructions is not None,
+ ),
+ ],
+ is_parent=True, # It will have children
+ has_parent=False, # Has no parents
+ )
+
+ set_call_kwargs(kwargs)
+ set_tracer(self._tracer)
+ set_tracer_context(self._tracer_context)
+
+ self.configure(num_reasks)
+ if self.num_reasks is None:
+ raise RuntimeError(
+ "`num_reasks` is `None` after calling `configure()`. "
+ "This should never happen."
+ )
+
+ input_prompt = prompt or (self.prompt._source if self.prompt else None)
+ input_instructions = instructions or (
+ self.instructions._source if self.instructions else None
+ )
+ call_inputs = CallInputs(
+ llm_api=llm_api,
+ prompt=input_prompt,
+ instructions=input_instructions,
+ msg_history=msg_history,
+ prompt_params=prompt_params,
+ num_reasks=self.num_reasks,
+ metadata=metadata,
+ full_schema_reask=full_schema_reask,
+ args=list(args),
+ kwargs=kwargs,
+ )
+ call_log = Call(inputs=call_inputs)
+ set_scope(str(id(call_log)))
+ self.history.push(call_log)
+
+ if self._api_client is not None and model_is_supported_server_side(
+ llm_api, *args, **kwargs
+ ):
+ return self._call_server(
+ llm_api=llm_api,
+ num_reasks=self.num_reasks,
+ prompt_params=prompt_params,
+ full_schema_reask=full_schema_reask,
+ call_log=call_log,
+ *args,
+ **kwargs,
+ )
+
+ # If the LLM API is not async, fail
+ # FIXME: it seems like this check isn't actually working?
+ if not inspect.isawaitable(llm_api) and not inspect.iscoroutinefunction(
+ llm_api
+ ):
+ raise RuntimeError(
+ f"The LLM API `{llm_api.__name__}` is not a coroutine. "
+ "Please use an async LLM API."
+ )
+ # Otherwise, call the LLM
+ return await self._call_async(
+ llm_api,
+ prompt_params=prompt_params,
+ num_reasks=self.num_reasks,
+ prompt=prompt,
+ instructions=instructions,
+ msg_history=msg_history,
+ metadata=metadata,
+ full_schema_reask=full_schema_reask,
+ call_log=call_log,
+ *args,
+ **kwargs,
+ )
+
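+ # Mirror the synchronous Guard.__call__: run the inner coroutine in its own
+ # contextvars.Context so the call-scoped state it sets (call kwargs, tracer,
+ # logging scope) is established per call.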
+ guard_context = contextvars.Context()
+ return await guard_context.run(
+ __call,
+ self,
+ llm_api,
+ prompt_params,
+ num_reasks,
+ prompt,
+ instructions,
+ msg_history,
+ metadata,
+ full_schema_reask,
+ *args,
+ **kwargs,
+ )
+
+ async def _call_async(
+ self,
+ llm_api: Callable[[Any], Awaitable[Any]],
+ prompt_params: Dict,
+ num_reasks: int,
+ prompt: Optional[str],
+ instructions: Optional[str],
+ msg_history: Optional[List[Dict]],
+ metadata: Dict,
+ full_schema_reask: bool,
+ call_log: Call,
+ *args,
+ **kwargs,
+ ) -> ValidationOutcome[OT]:
+ """Call the LLM asynchronously and validate the output.
+
+ Args:
+ llm_api: The LLM API to call asynchronously (e.g. openai.Completion.acreate)
+ prompt_params: The parameters to pass to the prompt.format() method.
+ num_reasks: The max times to re-ask the LLM for invalid output.
+ prompt: The prompt to use for the LLM.
+ instructions: Instructions for chat models.
+ msg_history: The message history to pass to the LLM.
+ metadata: Metadata to pass to the validators.
+ full_schema_reask: When reasking, whether to regenerate the full schema
+ or just the incorrect values.
+ Defaults to `True` if a base model is provided,
+ `False` otherwise.
+
+ Returns:
+ The raw text output from the LLM and the validated output.
+ """
+ instructions_obj = instructions or self.instructions
+ prompt_obj = prompt or self.prompt
+ msg_history_obj = msg_history or []
+ if prompt_obj is None:
+ if msg_history_obj is not None and not len(msg_history_obj):
+ raise RuntimeError(
+ "You must provide a prompt if msg_history is empty. "
+ "Alternatively, you can provide a prompt in the RAIL spec."
+ )
+
+ runner = AsyncRunner(
+ instructions=instructions_obj,
+ prompt=prompt_obj,
+ msg_history=msg_history_obj,
+ api=get_async_llm_ask(llm_api, *args, **kwargs),
+ prompt_schema=self.prompt_schema,
+ instructions_schema=self.instructions_schema,
+ msg_history_schema=self.msg_history_schema,
+ output_schema=self.output_schema,
+ num_reasks=num_reasks,
+ metadata=metadata,
+ base_model=self.base_model,
+ full_schema_reask=full_schema_reask,
+ disable_tracer=self._disable_tracer,
+ )
+ call = await runner.async_run(call_log=call_log, prompt_params=prompt_params)
+ return ValidationOutcome[OT].from_guard_history(call)
+
+ async def parse(
+ self,
+ llm_output: str,
+ metadata: Optional[Dict] = None,
+ llm_api: Optional[Callable] = None,
+ num_reasks: Optional[int] = None,
+ prompt_params: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ) -> Union[ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]]]:
+ """Alternate flow to using Guard where the llm_output is known.
+
+ Args:
+ llm_output: The output being parsed and validated.
+ metadata: Metadata to pass to the validators.
+ llm_api: The async LLM API to call, if reasking is needed
+ (e.g. openai.Completion.acreate)
+ num_reasks: The max times to re-ask the LLM for invalid output.
+ prompt_params: The parameters to pass to the prompt.format() method.
+ full_schema_reask: When reasking, whether to regenerate the full schema
+ or just the incorrect values.
+
+ Returns:
+ The validated response. This is either a string or a dictionary,
+ determined by the object schema defined in the RAILspec.
+ """
+
+ async def __parse(
+ self,
+ llm_output: str,
+ metadata: Optional[Dict] = None,
+ llm_api: Optional[Callable] = None,
+ num_reasks: Optional[int] = None,
+ prompt_params: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ):
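+ # An explicitly passed num_reasks wins; with no llm_api there is nothing
+ # to re-ask, so default to 0; otherwise leave None and let configure()
+ # fall back to the guard's configured default.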
+ final_num_reasks = (
+ num_reasks if num_reasks is not None else 0 if llm_api is None else None
+ )
+
+ if not self._disable_tracer:
+ self._hub_telemetry.create_new_span(
+ span_name="/guard_parse",
+ attributes=[
+ ("guard_id", self._guard_id),
+ ("user_id", self._user_id),
+ ("llm_api", llm_api.__name__ if llm_api else "None"),
+ ("custom_reask_prompt", self.reask_prompt is not None),
+ (
+ "custom_reask_instructions",
+ self.reask_instructions is not None,
+ ),
+ ],
+ is_parent=True, # It will have children
+ has_parent=False, # Has no parents
+ )
+
+ self.configure(final_num_reasks)
+ if self.num_reasks is None:
+ raise RuntimeError(
+ "`num_reasks` is `None` after calling `configure()`. "
+ "This should never happen."
+ )
+ if full_schema_reask is None:
+ full_schema_reask = self.base_model is not None
+ metadata = metadata or {}
+ prompt_params = prompt_params or {}
+
+ set_call_kwargs(kwargs)
+ set_tracer(self._tracer)
+ set_tracer_context(self._tracer_context)
+
+ input_prompt = self.prompt._source if self.prompt else None
+ input_instructions = (
+ self.instructions._source if self.instructions else None
+ )
+ call_inputs = CallInputs(
+ llm_api=llm_api,
+ llm_output=llm_output,
+ prompt=input_prompt,
+ instructions=input_instructions,
+ prompt_params=prompt_params,
+ num_reasks=self.num_reasks,
+ metadata=metadata,
+ full_schema_reask=full_schema_reask,
+ args=list(args),
+ kwargs=kwargs,
+ )
+ call_log = Call(inputs=call_inputs)
+ set_scope(str(id(call_log)))
+ self.history.push(call_log)
+
+ if self._api_client is not None and model_is_supported_server_side(
+ llm_api, *args, **kwargs
+ ):
+ return self._call_server(
+ llm_output=llm_output,
+ llm_api=llm_api,
+ num_reasks=self.num_reasks,
+ prompt_params=prompt_params,
+ full_schema_reask=full_schema_reask,
+ call_log=call_log,
+ *args,
+ **kwargs,
+ )
+
+ # FIXME: checking not llm_api because it can still fall back on defaults and
+ # function as expected. We should handle this better.
+ if (
+ not llm_api
+ or inspect.iscoroutinefunction(llm_api)
+ or inspect.isasyncgenfunction(llm_api)
+ ):
+ return await self._async_parse(
+ llm_output,
+ metadata,
+ llm_api=llm_api,
+ num_reasks=self.num_reasks,
+ prompt_params=prompt_params,
+ full_schema_reask=full_schema_reask,
+ call_log=call_log,
+ *args,
+ **kwargs,
+ )
+
+ else:
+ raise NotImplementedError(
+ "AsyncGuard does not support non-async LLM APIs. "
+ "Please use the synchronous API Guard or supply an asynchronous "
+ "LLM API."
+ )
+
+ guard_context = contextvars.Context()
+ return await guard_context.run(
+ __parse,
+ self,
+ llm_output,
+ metadata,
+ llm_api,
+ num_reasks,
+ prompt_params,
+ full_schema_reask,
+ *args,
+ **kwargs,
+ )
+
+ async def _async_parse(
+ self,
+ llm_output: str,
+ metadata: Dict,
+ llm_api: Optional[Callable[[Any], Awaitable[Any]]],
+ num_reasks: int,
+ prompt_params: Dict,
+ full_schema_reask: bool,
+ call_log: Call,
+ *args,
+ **kwargs,
+ ) -> ValidationOutcome[OT]:
+ """Alternate flow to using Guard where the llm_output is known.
+
+ Args:
+ llm_output: The output from the LLM.
+ llm_api: The LLM API to use to re-ask the LLM.
+ num_reasks: The max times to re-ask the LLM for invalid output.
+
+ Returns:
+ The validated response.
+ """
+ runner = AsyncRunner(
+ instructions=kwargs.pop("instructions", None),
+ prompt=kwargs.pop("prompt", None),
+ msg_history=kwargs.pop("msg_history", None),
+ api=get_async_llm_ask(llm_api, *args, **kwargs) if llm_api else None,
+ prompt_schema=self.prompt_schema,
+ instructions_schema=self.instructions_schema,
+ msg_history_schema=self.msg_history_schema,
+ output_schema=self.output_schema,
+ num_reasks=num_reasks,
+ metadata=metadata,
+ output=llm_output,
+ base_model=self.base_model,
+ full_schema_reask=full_schema_reask,
+ disable_tracer=self._disable_tracer,
+ )
+ call = await runner.async_run(call_log=call_log, prompt_params=prompt_params)
+
+ return ValidationOutcome[OT].from_guard_history(call)
diff --git a/guardrails/guard.py b/guardrails/guard.py
index e6ae27e1b..8a5d3dbd8 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -22,14 +22,14 @@
overload,
)
+from guardrails_api_client.models import AnyObject
+from guardrails_api_client.models import Guard as GuardModel
from guardrails_api_client.models import (
- AnyObject,
History,
HistoryEvent,
ValidatePayload,
ValidationOutput,
)
-from guardrails_api_client.models import Guard as GuardModel
from guardrails_api_client.types import UNSET
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
@@ -500,8 +500,7 @@ def __call__(
Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]],
Awaitable[ValidationOutcome[OT]],
]:
- """Call the LLM and validate the output. Pass an async LLM API to
- return a coroutine.
+ """Call the LLM and validate the output.
Args:
llm_api: The LLM API to call
@@ -603,7 +602,8 @@ def __call(
**kwargs,
)
- # If the LLM API is async, return a coroutine
+ # If the LLM API is async, return a coroutine. This path is deprecated in favor of AsyncGuard.
+
if asyncio.iscoroutinefunction(llm_api):
return self._call_async(
llm_api,
@@ -712,6 +712,12 @@ def _call_sync(
call = runner(call_log=call_log, prompt_params=prompt_params)
return ValidationOutcome[OT].from_guard_history(call)
+ @deprecated(
+ """Async methods within Guard are deprecated and will be removed in 0.5.x.
+ Instead, please use `AsyncGuard()` or pass in a synchronous LLM API.""",
+ category=FutureWarning,
+ stacklevel=2,
+ )
async def _call_async(
self,
llm_api: Callable[[Any], Awaitable[Any]],
@@ -1028,6 +1034,12 @@ def _sync_parse(
return ValidationOutcome[OT].from_guard_history(call)
+ @deprecated(
+ """Async methods within Guard are deprecated and will be removed in 0.5.x.
+ Instead, please use `AsyncGuard()` or pass in a synchronous LLM API.""",
+ category=FutureWarning,
+ stacklevel=2,
+ )
async def _async_parse(
self,
llm_output: str,
diff --git a/tests/unit_tests/test_async_guard.py b/tests/unit_tests/test_async_guard.py
new file mode 100644
index 000000000..0f4e33d6d
--- /dev/null
+++ b/tests/unit_tests/test_async_guard.py
@@ -0,0 +1,585 @@
+import openai
+import pytest
+from pydantic import BaseModel
+
+from guardrails import AsyncGuard, Rail, Validator
+from guardrails.datatypes import verify_metadata_requirements
+from guardrails.utils import args, kwargs, on_fail
+from guardrails.utils.openai_utils import OPENAI_VERSION
+from guardrails.validator_base import OnFailAction
+from guardrails.validators import ( # ReadingTime,
+ EndsWith,
+ LowerCase,
+ OneLine,
+ PassResult,
+ TwoWords,
+ UpperCase,
+ ValidLength,
+ register_validator,
+)
+
+
+@register_validator("myrequiringvalidator", data_type="string")
+class RequiringValidator(Validator):
+ required_metadata_keys = ["required_key"]
+
+ def validate(self, value, metadata):
+ return PassResult()
+
+
+@register_validator("myrequiringvalidator2", data_type="string")
+class RequiringValidator2(Validator):
+ required_metadata_keys = ["required_key2"]
+
+ def validate(self, value, metadata):
+ return PassResult()
+
+
+@pytest.mark.parametrize(
+ "spec,metadata,error_message",
+ [
+ (
+ """
+
+
+
+ """,
+ {"required_key": "a"},
+ "Missing required metadata keys: required_key",
+ ),
+ (
+ """
+
+
+
+ """,
+ {"required_key": "a", "required_key2": "b"},
+ "Missing required metadata keys: required_key, required_key2",
+ ),
+ (
+ """
+
+
+
+""",
+ {"required_key": "a"},
+ "Missing required metadata keys: required_key",
+ ),
+ ],
+)
+@pytest.mark.asyncio
+@pytest.mark.skipif(not OPENAI_VERSION.startswith("0"), reason="Only for OpenAI v0")
+async def test_required_metadata(spec, metadata, error_message):
+ guard = AsyncGuard.from_rail_string(spec)
+
+ missing_keys = verify_metadata_requirements({}, guard.output_schema.root_datatype)
+ assert set(missing_keys) == set(metadata)
+
+ not_missing_keys = verify_metadata_requirements(
+ metadata, guard.output_schema.root_datatype
+ )
+ assert not_missing_keys == []
+
+ # test sync guard
+ with pytest.raises(ValueError) as excinfo:
+ guard.parse("{}")
+ assert str(excinfo.value) == error_message
+
+ response = guard.parse("{}", metadata=metadata, num_reasks=0)
+ assert response.error is None
+
+ # test async guard
+ with pytest.raises(ValueError) as excinfo:
+ guard.parse("{}")
+ await guard.parse("{}", llm_api=openai.ChatCompletion.acreate, num_reasks=0)
+ assert str(excinfo.value) == error_message
+
+ response = await guard.parse(
+ "{}", metadata=metadata, llm_api=openai.ChatCompletion.acreate, num_reasks=0
+ )
+ assert response.error is None
+
+
+rail = Rail.from_string_validators([], "empty railspec")
+empty_rail_string = """
+
+"""
+
+
+class EmptyModel(BaseModel):
+ empty_field: str
+
+
+i_guard_none = AsyncGuard(rail)
+i_guard_two = AsyncGuard(rail, 2)
+r_guard_none = AsyncGuard.from_rail("tests/unit_tests/test_assets/empty.rail")
+r_guard_two = AsyncGuard.from_rail("tests/unit_tests/test_assets/empty.rail", 2)
+rs_guard_none = AsyncGuard.from_rail_string(empty_rail_string)
+rs_guard_two = AsyncGuard.from_rail_string(empty_rail_string, 2)
+py_guard_none = AsyncGuard.from_pydantic(output_class=EmptyModel)
+py_guard_two = AsyncGuard.from_pydantic(output_class=EmptyModel, num_reasks=2)
+s_guard_none = AsyncGuard.from_string(validators=[], description="empty railspec")
+s_guard_two = AsyncGuard.from_string(
+ validators=[], description="empty railspec", num_reasks=2
+)
+
+
+@pytest.mark.parametrize(
+ "guard,expected_num_reasks,config_num_reasks",
+ [
+ (i_guard_none, 1, None),
+ (i_guard_two, 2, None),
+ (i_guard_none, 3, 3),
+ (r_guard_none, 1, None),
+ (r_guard_two, 2, None),
+ (r_guard_none, 3, 3),
+ (rs_guard_none, 1, None),
+ (rs_guard_two, 2, None),
+ (rs_guard_none, 3, 3),
+ (py_guard_none, 1, None),
+ (py_guard_two, 2, None),
+ (py_guard_none, 3, 3),
+ (s_guard_none, 1, None),
+ (s_guard_two, 2, None),
+ (s_guard_none, 3, 3),
+ ],
+)
+def test_configure(guard: AsyncGuard, expected_num_reasks: int, config_num_reasks: int):
+ guard.configure(config_num_reasks)
+ assert guard.num_reasks == expected_num_reasks
+
+
+def guard_init_from_rail():
+ guard = AsyncGuard.from_rail("tests/unit_tests/test_assets/simple.rail")
+ assert (
+ guard.instructions.format().source.strip()
+ == "You are a helpful bot, who answers only with valid JSON"
+ )
+ assert guard.prompt.format().source.strip() == "Extract a string from the text"
+
+
+def test_use():
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use(EndsWith("a"))
+ .use(OneLine())
+ .use(LowerCase)
+ .use(TwoWords, on_fail=OnFailAction.REASK)
+ .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN)
+ )
+
+ # print(guard.__stringify__())
+ assert len(guard._validators) == 5
+
+ assert isinstance(guard._validators[0], EndsWith)
+ assert guard._validators[0]._kwargs["end"] == "a"
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.FIX
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[1], OneLine)
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[2], LowerCase)
+ assert (
+ guard._validators[2].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[3], TwoWords)
+ assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ assert isinstance(guard._validators[4], ValidLength)
+ assert guard._validators[4]._min == 0
+ assert guard._validators[4]._kwargs["min"] == 0
+ assert guard._validators[4]._max == 12
+ assert guard._validators[4]._kwargs["max"] == 12
+ assert (
+ guard._validators[4].on_fail_descriptor == OnFailAction.REFRAIN
+ ) # bc we set it
+
+ # Raises error when trying to `use` a validator on a non-string
+ with pytest.raises(RuntimeError):
+
+ class TestClass(BaseModel):
+ another_field: str
+
+ py_guard = AsyncGuard.from_pydantic(output_class=TestClass)
+ py_guard.use(
+ EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK)
+ )
+
+ # Use a combination of prompt, instructions, msg_history and output validators
+ # Should only have the output validators in the guard,
+ # everything else is in the schema
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use(LowerCase, on="prompt")
+ .use(OneLine, on="prompt")
+ .use(UpperCase, on="instructions")
+ .use(LowerCase, on="msg_history")
+ .use(
+ EndsWith, end="a", on="output"
+ ) # default on="output", still explicitly set
+ .use(
+ TwoWords, on_fail=OnFailAction.REASK
+ ) # default on="output", implicitly set
+ )
+
+ # Check schemas for prompt, instructions and msg_history validators
+ prompt_validators = guard.rail.prompt_schema.root_datatype.validators
+ assert len(prompt_validators) == 2
+ assert prompt_validators[0].__class__.__name__ == "LowerCase"
+ assert prompt_validators[1].__class__.__name__ == "OneLine"
+
+ instructions_validators = guard.rail.instructions_schema.root_datatype.validators
+ assert len(instructions_validators) == 1
+ assert instructions_validators[0].__class__.__name__ == "UpperCase"
+
+ msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators
+ assert len(msg_history_validators) == 1
+ assert msg_history_validators[0].__class__.__name__ == "LowerCase"
+
+ # Check guard for output validators
+ assert len(guard._validators) == 2 # only 2 output validators, hence 2
+
+ assert isinstance(guard._validators[0], EndsWith)
+ assert guard._validators[0]._kwargs["end"] == "a"
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.FIX
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[1], TwoWords)
+ assert guard._validators[1].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ # Test with an invalid "on" parameter, should raise a ValueError
+ with pytest.raises(ValueError):
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use(EndsWith("a"), on="response") # invalid on parameter
+ .use(OneLine, on="prompt") # valid on parameter
+ )
+
+
+def test_use_many_instances():
+ guard: AsyncGuard = AsyncGuard().use_many(
+ EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK)
+ )
+
+ # print(guard.__stringify__())
+ assert len(guard._validators) == 4
+
+ assert isinstance(guard._validators[0], EndsWith)
+ assert guard._validators[0]._end == "a"
+ assert guard._validators[0]._kwargs["end"] == "a"
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.FIX
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[1], OneLine)
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[2], LowerCase)
+ assert (
+ guard._validators[2].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[3], TwoWords)
+ assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ # Raises error when trying to `use_many` a validator on a non-string
+ with pytest.raises(RuntimeError):
+
+ class TestClass(BaseModel):
+ another_field: str
+
+ py_guard = AsyncGuard.from_pydantic(output_class=TestClass)
+ py_guard.use_many(
+ [
+ EndsWith("a"),
+ OneLine(),
+ LowerCase(),
+ TwoWords(on_fail=OnFailAction.REASK),
+ ]
+ )
+
+ # Test with explicitly setting the "on" parameter = "output"
+ guard: AsyncGuard = AsyncGuard().use_many(
+ EndsWith("a"),
+ OneLine(),
+ LowerCase(),
+ TwoWords(on_fail=OnFailAction.REASK),
+ on="output",
+ )
+
+ assert len(guard._validators) == 4 # still 4 output validators, hence 4
+
+ assert isinstance(guard._validators[0], EndsWith)
+ assert guard._validators[0]._end == "a"
+ assert guard._validators[0]._kwargs["end"] == "a"
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.FIX
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[1], OneLine)
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[2], LowerCase)
+ assert (
+ guard._validators[2].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[3], TwoWords)
+ assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ # Test with explicitly setting the "on" parameter = "prompt"
+ guard: AsyncGuard = AsyncGuard().use_many(
+ OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="prompt"
+ )
+
+ prompt_validators = guard.rail.prompt_schema.root_datatype.validators
+ assert len(prompt_validators) == 3
+ assert prompt_validators[0].__class__.__name__ == "OneLine"
+ assert prompt_validators[1].__class__.__name__ == "LowerCase"
+ assert prompt_validators[2].__class__.__name__ == "TwoWords"
+ assert len(guard._validators) == 0 # no output validators, hence 0
+
+ # Test with explicitly setting the "on" parameter = "instructions"
+ guard: AsyncGuard = AsyncGuard().use_many(
+ OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="instructions"
+ )
+
+ instructions_validators = guard.rail.instructions_schema.root_datatype.validators
+ assert len(instructions_validators) == 3
+ assert instructions_validators[0].__class__.__name__ == "OneLine"
+ assert instructions_validators[1].__class__.__name__ == "LowerCase"
+ assert instructions_validators[2].__class__.__name__ == "TwoWords"
+ assert len(guard._validators) == 0 # no output validators, hence 0
+
+ # Test with explicitly setting the "on" parameter = "msg_history"
+ guard: AsyncGuard = AsyncGuard().use_many(
+ OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="msg_history"
+ )
+
+ msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators
+ assert len(msg_history_validators) == 3
+ assert msg_history_validators[0].__class__.__name__ == "OneLine"
+ assert msg_history_validators[1].__class__.__name__ == "LowerCase"
+ assert msg_history_validators[2].__class__.__name__ == "TwoWords"
+ assert len(guard._validators) == 0 # no output validators, hence 0
+
+ # Test with an invalid "on" parameter, should raise a ValueError
+ with pytest.raises(ValueError):
+ guard: AsyncGuard = AsyncGuard().use_many(
+ EndsWith("a", on_fail=OnFailAction.EXCEPTION), OneLine(), on="response"
+ )
+
+
+def test_use_many_tuple():
+ guard: AsyncGuard = AsyncGuard().use_many(
+ OneLine,
+ (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}),
+ (LowerCase, kwargs(on_fail=OnFailAction.FIX_REASK, some_other_kwarg="kwarg")),
+ (TwoWords, on_fail(OnFailAction.REASK)),
+ (ValidLength, args(0, 12), kwargs(on_fail=OnFailAction.REFRAIN)),
+ )
+
+ # print(guard.__stringify__())
+ assert len(guard._validators) == 5
+
+ assert isinstance(guard._validators[0], OneLine)
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[1], EndsWith)
+ assert guard._validators[1]._end == "a"
+ assert guard._validators[1]._kwargs["end"] == "a"
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.EXCEPTION
+ ) # bc we set it
+
+ assert isinstance(guard._validators[2], LowerCase)
+ assert guard._validators[2]._kwargs["some_other_kwarg"] == "kwarg"
+ assert (
+ guard._validators[2].on_fail_descriptor == OnFailAction.FIX_REASK
+ ) # bc we set it
+
+ assert isinstance(guard._validators[3], TwoWords)
+ assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ assert isinstance(guard._validators[4], ValidLength)
+ assert guard._validators[4]._min == 0
+ assert guard._validators[4]._kwargs["min"] == 0
+ assert guard._validators[4]._max == 12
+ assert guard._validators[4]._kwargs["max"] == 12
+ assert (
+ guard._validators[4].on_fail_descriptor == OnFailAction.REFRAIN
+ ) # bc we set it
+
+ # Test with explicitly setting the "on" parameter
+ guard: AsyncGuard = AsyncGuard().use_many(
+ (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}),
+ OneLine,
+ on="output",
+ )
+
+ assert len(guard._validators) == 2 # only 2 output validators, hence 2
+
+ assert isinstance(guard._validators[0], EndsWith)
+ assert guard._validators[0]._end == "a"
+ assert guard._validators[0]._kwargs["end"] == "a"
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.EXCEPTION
+ ) # bc we set it
+
+ assert isinstance(guard._validators[1], OneLine)
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ # Test with an invalid "on" parameter, should raise a ValueError
+ with pytest.raises(ValueError):
+ guard: AsyncGuard = AsyncGuard().use_many(
+ (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}),
+ OneLine,
+ on="response",
+ )
+
+
+@pytest.mark.asyncio
+async def test_validate():
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use(OneLine)
+ .use(
+ LowerCase(on_fail=OnFailAction.FIX), on="output"
+ ) # default on="output", still explicitly set
+ .use(TwoWords)
+ .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN)
+ )
+
+ llm_output: str = "Oh Canada" # bc it meets our criteria
+ response = await guard.validate(llm_output)
+
+ assert response.validation_passed is True
+ assert response.validated_output == llm_output.lower()
+ llm_output_2 = "Star Spangled Banner" # to stick with the theme
+
+ response_2 = await guard.validate(llm_output_2)
+
+ assert response_2.validation_passed is False
+ assert response_2.validated_output is None
+
+ # Test with a combination of prompt, output, instructions and msg_history validators
+ # Should still only use the output validators to validate the output
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use(OneLine, on="prompt")
+ .use(LowerCase, on="instructions")
+ .use(UpperCase, on="msg_history")
+ .use(LowerCase, on="output", on_fail=OnFailAction.FIX)
+ .use(TwoWords, on="output")
+ .use(ValidLength, 0, 12, on="output")
+ )
+
+ llm_output: str = "Oh Canada" # bc it meets our criteria
+
+ response = await guard.validate(llm_output)
+
+ assert response.validation_passed is True
+ assert response.validated_output == llm_output.lower()
+
+ llm_output_2 = "Star Spangled Banner" # to stick with the theme
+
+ response_2 = await guard.validate(llm_output_2)
+
+ assert response_2.validation_passed is False
+ assert response_2.validated_output is None
+
+
+def test_use_and_use_many():
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use_many(OneLine(), LowerCase(), on="prompt")
+ .use(UpperCase, on="instructions")
+ .use(LowerCase, on="msg_history")
+ .use_many(
+ TwoWords(on_fail=OnFailAction.REASK),
+ ValidLength(0, 12, on_fail=OnFailAction.REFRAIN),
+ on="output",
+ )
+ )
+
+ # Check schemas for prompt, instructions and msg_history validators
+ prompt_validators = guard.rail.prompt_schema.root_datatype.validators
+ assert len(prompt_validators) == 2
+ assert prompt_validators[0].__class__.__name__ == "OneLine"
+ assert prompt_validators[1].__class__.__name__ == "LowerCase"
+
+ instructions_validators = guard.rail.instructions_schema.root_datatype.validators
+ assert len(instructions_validators) == 1
+ assert instructions_validators[0].__class__.__name__ == "UpperCase"
+
+ msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators
+ assert len(msg_history_validators) == 1
+ assert msg_history_validators[0].__class__.__name__ == "LowerCase"
+
+ # Check guard for output validators
+ assert len(guard._validators) == 2 # only 2 output validators, hence 2
+
+ assert isinstance(guard._validators[0], TwoWords)
+ assert guard._validators[0].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ assert isinstance(guard._validators[1], ValidLength)
+ assert guard._validators[1]._min == 0
+ assert guard._validators[1]._kwargs["min"] == 0
+ assert guard._validators[1]._max == 12
+ assert guard._validators[1]._kwargs["max"] == 12
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.REFRAIN
+ ) # bc we set it
+
+ # Test with an invalid "on" parameter, should raise a ValueError
+ with pytest.raises(ValueError):
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use_many(OneLine(), LowerCase(), on="prompt")
+ .use(UpperCase, on="instructions")
+ .use(LowerCase, on="msg_history")
+ .use_many(
+ TwoWords(on_fail=OnFailAction.REASK),
+ ValidLength(0, 12, on_fail=OnFailAction.REFRAIN),
+ on="response", # invalid "on" parameter
+ )
+ )