diff --git a/guardrails/__init__.py b/guardrails/__init__.py
index 6e28dff7f..b1d081fca 100644
--- a/guardrails/__init__.py
+++ b/guardrails/__init__.py
@@ -1,6 +1,7 @@
 # Set up __init__.py so that users can do from guardrails import Response, Schema, etc.
 from guardrails.guard import Guard
+from guardrails.async_guard import AsyncGuard
 from guardrails.llm_providers import PromptCallableBase
 from guardrails.logging_utils import configure_logging
 from guardrails.prompt import Instructions, Prompt
@@ -10,6 +11,7 @@
 __all__ = [
     "Guard",
+    "AsyncGuard",
     "PromptCallableBase",
     "Rail",
     "Validator",
diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py
new file mode 100644
index 000000000..e689ac698
--- /dev/null
+++ b/guardrails/async_guard.py
@@ -0,0 +1,449 @@
+import contextvars
+import inspect
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Union,
+)
+
+from guardrails import Guard
+from guardrails.classes import OT, ValidationOutcome
+from guardrails.classes.history import Call
+from guardrails.classes.history.call_inputs import CallInputs
+from guardrails.llm_providers import get_async_llm_ask, model_is_supported_server_side
+from guardrails.logger import set_scope
+from guardrails.run import AsyncRunner
+from guardrails.stores.context import set_call_kwargs, set_tracer, set_tracer_context
+
+
+class AsyncGuard(Guard):
+    """The AsyncGuard class.
+
+    This class is one of the main entry points for using Guardrails. It is
+    initialized from one of the following class methods:
+
+    - `from_rail`
+    - `from_rail_string`
+    - `from_pydantic`
+    - `from_string`
+
+    The `__call__` method functions as a wrapper around async LLM APIs. It
+    takes in an async LLM API and optional prompt parameters, and returns the
+    raw output stream from the LLM and the validated output stream.
+    """
+
+    async def __call__(
+        self,
+        llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]],
+        prompt_params: Optional[Dict] = None,
+        num_reasks: Optional[int] = None,
+        prompt: Optional[str] = None,
+        instructions: Optional[str] = None,
+        msg_history: Optional[List[Dict]] = None,
+        metadata: Optional[Dict] = None,
+        full_schema_reask: Optional[bool] = None,
+        *args,
+        **kwargs,
+    ) -> Union[
+        Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]],
+        Awaitable[ValidationOutcome[OT]],
+    ]:
+        """Call the LLM asynchronously and validate the output. The LLM API
+        must be an async callable; the call returns a coroutine.
+
+        Args:
+            llm_api: The async LLM API to call
+                (e.g. openai.Completion.acreate or openai.ChatCompletion.acreate)
+            prompt_params: The parameters to pass to the prompt.format() method.
+            num_reasks: The max times to re-ask the LLM for invalid output.
+            prompt: The prompt to use for the LLM.
+            instructions: Instructions for chat models.
+            msg_history: The message history to pass to the LLM.
+            metadata: Metadata to pass to the validators.
+            full_schema_reask: When reasking, whether to regenerate the full schema
+                or just the incorrect values.
+                Defaults to `True` if a base model is provided,
+                `False` otherwise.
+
+        Returns:
+            The raw text output from the LLM and the validated output.
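+
+        Example:
+            An illustrative sketch only; it assumes the OpenAI v0 async API
+            (`openai.ChatCompletion.acreate`) and a simple string schema.
+            Adapt the LLM callable, validators, and prompt to your setup.
+
+                import asyncio
+
+                import openai
+
+                from guardrails import AsyncGuard
+                from guardrails.validator_base import OnFailAction
+                from guardrails.validators import TwoWords
+
+                guard = AsyncGuard.from_string(
+                    validators=[TwoWords(on_fail=OnFailAction.REASK)],
+                    description="A two-word answer.",
+                )
+
+                async def main():
+                    outcome = await guard(
+                        openai.ChatCompletion.acreate,
+                        prompt="What is the capital of France? Answer in two words.",
+                        num_reasks=1,
+                    )
+                    print(outcome.validated_output)
+
+                asyncio.run(main())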
+ """ + + async def __call( + self, + llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]], + prompt_params: Optional[Dict] = None, + num_reasks: Optional[int] = None, + prompt: Optional[str] = None, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict]] = None, + metadata: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + *args, + **kwargs, + ): + if metadata is None: + metadata = {} + if full_schema_reask is None: + full_schema_reask = self.base_model is not None + if prompt_params is None: + prompt_params = {} + + if not self._disable_tracer: + # Create a new span for this guard call + self._hub_telemetry.create_new_span( + span_name="/guard_call", + attributes=[ + ("guard_id", self._guard_id), + ("user_id", self._user_id), + ("llm_api", llm_api.__name__ if llm_api else "None"), + ("custom_reask_prompt", self.reask_prompt is not None), + ( + "custom_reask_instructions", + self.reask_instructions is not None, + ), + ], + is_parent=True, # It will have children + has_parent=False, # Has no parents + ) + + set_call_kwargs(kwargs) + set_tracer(self._tracer) + set_tracer_context(self._tracer_context) + + self.configure(num_reasks) + if self.num_reasks is None: + raise RuntimeError( + "`num_reasks` is `None` after calling `configure()`. " + "This should never happen." + ) + + input_prompt = prompt or (self.prompt._source if self.prompt else None) + input_instructions = instructions or ( + self.instructions._source if self.instructions else None + ) + call_inputs = CallInputs( + llm_api=llm_api, + prompt=input_prompt, + instructions=input_instructions, + msg_history=msg_history, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + metadata=metadata, + full_schema_reask=full_schema_reask, + args=list(args), + kwargs=kwargs, + ) + call_log = Call(inputs=call_inputs) + set_scope(str(id(call_log))) + self.history.push(call_log) + + if self._api_client is not None and model_is_supported_server_side( + llm_api, *args, **kwargs + ): + return self._call_server( + llm_api=llm_api, + num_reasks=self.num_reasks, + prompt_params=prompt_params, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, + ) + + # If the LLM API is not async, fail + # FIXME: it seems like this check isn't actually working? + if not inspect.isawaitable(llm_api) and not inspect.iscoroutinefunction( + llm_api + ): + raise RuntimeError( + f"The LLM API `{llm_api.__name__}` is not a coroutine. " + "Please use an async LLM API." + ) + # Otherwise, call the LLM + return await self._call_async( + llm_api, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + prompt=prompt, + instructions=instructions, + msg_history=msg_history, + metadata=metadata, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, + ) + + guard_context = contextvars.Context() + return await guard_context.run( + __call, + self, + llm_api, + prompt_params, + num_reasks, + prompt, + instructions, + msg_history, + metadata, + full_schema_reask, + *args, + **kwargs, + ) + + async def _call_async( + self, + llm_api: Callable[[Any], Awaitable[Any]], + prompt_params: Dict, + num_reasks: int, + prompt: Optional[str], + instructions: Optional[str], + msg_history: Optional[List[Dict]], + metadata: Dict, + full_schema_reask: bool, + call_log: Call, + *args, + **kwargs, + ) -> ValidationOutcome[OT]: + """Call the LLM asynchronously and validate the output. + + Args: + llm_api: The LLM API to call asynchronously (e.g. 
openai.Completion.acreate) + prompt_params: The parameters to pass to the prompt.format() method. + num_reasks: The max times to re-ask the LLM for invalid output. + prompt: The prompt to use for the LLM. + instructions: Instructions for chat models. + msg_history: The message history to pass to the LLM. + metadata: Metadata to pass to the validators. + full_schema_reask: When reasking, whether to regenerate the full schema + or just the incorrect values. + Defaults to `True` if a base model is provided, + `False` otherwise. + + Returns: + The raw text output from the LLM and the validated output. + """ + instructions_obj = instructions or self.instructions + prompt_obj = prompt or self.prompt + msg_history_obj = msg_history or [] + if prompt_obj is None: + if msg_history_obj is not None and not len(msg_history_obj): + raise RuntimeError( + "You must provide a prompt if msg_history is empty. " + "Alternatively, you can provide a prompt in the RAIL spec." + ) + + runner = AsyncRunner( + instructions=instructions_obj, + prompt=prompt_obj, + msg_history=msg_history_obj, + api=get_async_llm_ask(llm_api, *args, **kwargs), + prompt_schema=self.prompt_schema, + instructions_schema=self.instructions_schema, + msg_history_schema=self.msg_history_schema, + output_schema=self.output_schema, + num_reasks=num_reasks, + metadata=metadata, + base_model=self.base_model, + full_schema_reask=full_schema_reask, + disable_tracer=self._disable_tracer, + ) + call = await runner.async_run(call_log=call_log, prompt_params=prompt_params) + return ValidationOutcome[OT].from_guard_history(call) + + async def parse( + self, + llm_output: str, + metadata: Optional[Dict] = None, + llm_api: Optional[Callable] = None, + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + *args, + **kwargs, + ) -> Union[ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]]]: + """Alternate flow to using Guard where the llm_output is known. + + Args: + llm_output: The output being parsed and validated. + metadata: Metadata to pass to the validators. + llm_api: The LLM API to call + (e.g. openai.Completion.create or openai.Completion.acreate) + num_reasks: The max times to re-ask the LLM for invalid output. + prompt_params: The parameters to pass to the prompt.format() method. + full_schema_reask: When reasking, whether to regenerate the full schema + or just the incorrect values. + + Returns: + The validated response. This is either a string or a dictionary, + determined by the object schema defined in the RAILspec. 
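+
+        Example:
+            A minimal sketch for validating output that was produced elsewhere;
+            no `llm_api` is passed, so no re-asking can occur. Adapt the
+            validators and guard construction to your schema.
+
+                import asyncio
+
+                from guardrails import AsyncGuard
+                from guardrails.validators import TwoWords
+
+                guard = AsyncGuard.from_string(
+                    validators=[TwoWords()],
+                    description="A two-word answer.",
+                )
+
+                async def main():
+                    outcome = await guard.parse("Hello world")
+                    print(outcome.validation_passed)
+
+                asyncio.run(main())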
+ """ + + async def __parse( + self, + llm_output: str, + metadata: Optional[Dict] = None, + llm_api: Optional[Callable] = None, + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + *args, + **kwargs, + ): + final_num_reasks = ( + num_reasks if num_reasks is not None else 0 if llm_api is None else None + ) + + if not self._disable_tracer: + self._hub_telemetry.create_new_span( + span_name="/guard_parse", + attributes=[ + ("guard_id", self._guard_id), + ("user_id", self._user_id), + ("llm_api", llm_api.__name__ if llm_api else "None"), + ("custom_reask_prompt", self.reask_prompt is not None), + ( + "custom_reask_instructions", + self.reask_instructions is not None, + ), + ], + is_parent=True, # It will have children + has_parent=False, # Has no parents + ) + + self.configure(final_num_reasks) + if self.num_reasks is None: + raise RuntimeError( + "`num_reasks` is `None` after calling `configure()`. " + "This should never happen." + ) + if full_schema_reask is None: + full_schema_reask = self.base_model is not None + metadata = metadata or {} + prompt_params = prompt_params or {} + + set_call_kwargs(kwargs) + set_tracer(self._tracer) + set_tracer_context(self._tracer_context) + + input_prompt = self.prompt._source if self.prompt else None + input_instructions = ( + self.instructions._source if self.instructions else None + ) + call_inputs = CallInputs( + llm_api=llm_api, + llm_output=llm_output, + prompt=input_prompt, + instructions=input_instructions, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + metadata=metadata, + full_schema_reask=full_schema_reask, + args=list(args), + kwargs=kwargs, + ) + call_log = Call(inputs=call_inputs) + set_scope(str(id(call_log))) + self.history.push(call_log) + + if self._api_client is not None and model_is_supported_server_side( + llm_api, *args, **kwargs + ): + return self._call_server( + llm_output=llm_output, + llm_api=llm_api, + num_reasks=self.num_reasks, + prompt_params=prompt_params, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, + ) + + # FIXME: checking not llm_api because it can still fall back on defaults and + # function as expected. We should handle this better. + if ( + not llm_api + or inspect.iscoroutinefunction(llm_api) + or inspect.isasyncgenfunction(llm_api) + ): + return await self._async_parse( + llm_output, + metadata, + llm_api=llm_api, + num_reasks=self.num_reasks, + prompt_params=prompt_params, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, + ) + + else: + raise NotImplementedError( + "AsyncGuard does not support non-async LLM APIs. " + "Please use the synchronous API Guard or supply an asynchronous " + "LLM API." + ) + + guard_context = contextvars.Context() + return await guard_context.run( + __parse, + self, + llm_output, + metadata, + llm_api, + num_reasks, + prompt_params, + full_schema_reask, + *args, + **kwargs, + ) + + async def _async_parse( + self, + llm_output: str, + metadata: Dict, + llm_api: Optional[Callable[[Any], Awaitable[Any]]], + num_reasks: int, + prompt_params: Dict, + full_schema_reask: bool, + call_log: Call, + *args, + **kwargs, + ) -> ValidationOutcome[OT]: + """Alternate flow to using Guard where the llm_output is known. + + Args: + llm_output: The output from the LLM. + llm_api: The LLM API to use to re-ask the LLM. + num_reasks: The max times to re-ask the LLM for invalid output. + + Returns: + The validated response. 
+        """
+        runner = AsyncRunner(
+            instructions=kwargs.pop("instructions", None),
+            prompt=kwargs.pop("prompt", None),
+            msg_history=kwargs.pop("msg_history", None),
+            api=get_async_llm_ask(llm_api, *args, **kwargs) if llm_api else None,
+            prompt_schema=self.prompt_schema,
+            instructions_schema=self.instructions_schema,
+            msg_history_schema=self.msg_history_schema,
+            output_schema=self.output_schema,
+            num_reasks=num_reasks,
+            metadata=metadata,
+            output=llm_output,
+            base_model=self.base_model,
+            full_schema_reask=full_schema_reask,
+            disable_tracer=self._disable_tracer,
+        )
+        call = await runner.async_run(call_log=call_log, prompt_params=prompt_params)
+
+        return ValidationOutcome[OT].from_guard_history(call)
diff --git a/guardrails/guard.py b/guardrails/guard.py
index e6ae27e1b..8a5d3dbd8 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -22,14 +22,14 @@
     overload,
 )
 
+from guardrails_api_client.models import AnyObject
+from guardrails_api_client.models import Guard as GuardModel
 from guardrails_api_client.models import (
-    AnyObject,
     History,
     HistoryEvent,
     ValidatePayload,
     ValidationOutput,
 )
-from guardrails_api_client.models import Guard as GuardModel
 from guardrails_api_client.types import UNSET
 from langchain_core.messages import BaseMessage
 from langchain_core.runnables import Runnable, RunnableConfig
@@ -500,8 +500,7 @@ def __call__(
         Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]],
         Awaitable[ValidationOutcome[OT]],
     ]:
-        """Call the LLM and validate the output. Pass an async LLM API to
-        return a coroutine.
+        """Call the LLM and validate the output.
 
         Args:
             llm_api: The LLM API to call
@@ -603,7 +602,8 @@ def __call(
             **kwargs,
         )
 
-        # If the LLM API is async, return a coroutine
+        # If the LLM API is async, return a coroutine. This will be deprecated soon.
+
         if asyncio.iscoroutinefunction(llm_api):
             return self._call_async(
                 llm_api,
@@ -712,6 +712,12 @@ def _call_sync(
         call = runner(call_log=call_log, prompt_params=prompt_params)
         return ValidationOutcome[OT].from_guard_history(call)
 
+    @deprecated(
+        """Async methods within Guard are deprecated and will be removed in 0.5.x.
+        Instead, please use `AsyncGuard()` or pass in a synchronous LLM API.""",
+        category=FutureWarning,
+        stacklevel=2,
+    )
     async def _call_async(
         self,
         llm_api: Callable[[Any], Awaitable[Any]],
@@ -1028,6 +1034,12 @@ def _sync_parse(
 
         return ValidationOutcome[OT].from_guard_history(call)
 
+    @deprecated(
+        """Async methods within Guard are deprecated and will be removed in 0.5.x.
+        Instead, please use `AsyncGuard()` or pass in a synchronous LLM API.""",
+        category=FutureWarning,
+        stacklevel=2,
+    )
     async def _async_parse(
         self,
         llm_output: str,
diff --git a/tests/unit_tests/test_async_guard.py b/tests/unit_tests/test_async_guard.py
new file mode 100644
index 000000000..0f4e33d6d
--- /dev/null
+++ b/tests/unit_tests/test_async_guard.py
@@ -0,0 +1,585 @@
+import openai
+import pytest
+from pydantic import BaseModel
+
+from guardrails import AsyncGuard, Rail, Validator
+from guardrails.datatypes import verify_metadata_requirements
+from guardrails.utils import args, kwargs, on_fail
+from guardrails.utils.openai_utils import OPENAI_VERSION
+from guardrails.validator_base import OnFailAction
+from guardrails.validators import (  # ReadingTime,
+    EndsWith,
+    LowerCase,
+    OneLine,
+    PassResult,
+    TwoWords,
+    UpperCase,
+    ValidLength,
+    register_validator,
+)
+
+
+@register_validator("myrequiringvalidator", data_type="string")
+class RequiringValidator(Validator):
+    required_metadata_keys = ["required_key"]
+
+    def validate(self, value, metadata):
+        return PassResult()
+
+
+@register_validator("myrequiringvalidator2", data_type="string")
+class RequiringValidator2(Validator):
+    required_metadata_keys = ["required_key2"]
+
+    def validate(self, value, metadata):
+        return PassResult()
+
+
+@pytest.mark.parametrize(
+    "spec,metadata,error_message",
+    [
+        (
+            """
+
+
+
+
+
+            """,
+            {"required_key": "a"},
+            "Missing required metadata keys: required_key",
+        ),
+        (
+            """
+
+
+
+
+
+
+
+
+
+
+            """,
+            {"required_key": "a", "required_key2": "b"},
+            "Missing required metadata keys: required_key, required_key2",
+        ),
+        (
+            """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+""",
+            {"required_key": "a"},
+            "Missing required metadata keys: required_key",
+        ),
+    ],
+)
+@pytest.mark.asyncio
+@pytest.mark.skipif(not OPENAI_VERSION.startswith("0"), reason="Only for OpenAI v0")
+async def test_required_metadata(spec, metadata, error_message):
+    guard = AsyncGuard.from_rail_string(spec)
+
+    missing_keys = verify_metadata_requirements({}, guard.output_schema.root_datatype)
+    assert set(missing_keys) == set(metadata)
+
+    not_missing_keys = verify_metadata_requirements(
+        metadata, guard.output_schema.root_datatype
+    )
+    assert not_missing_keys == []
+
+    # test sync guard
+    with pytest.raises(ValueError) as excinfo:
+        guard.parse("{}")
+    assert str(excinfo.value) == error_message
+
+    response = guard.parse("{}", metadata=metadata, num_reasks=0)
+    assert response.error is None
+
+    # test async guard
+    with pytest.raises(ValueError) as excinfo:
+        guard.parse("{}")
+        await guard.parse("{}", llm_api=openai.ChatCompletion.acreate, num_reasks=0)
+    assert str(excinfo.value) == error_message
+
+    response = await guard.parse(
+        "{}", metadata=metadata, llm_api=openai.ChatCompletion.acreate, num_reasks=0
+    )
+    assert response.error is None
+
+
+rail = Rail.from_string_validators([], "empty railspec")
+empty_rail_string = """
+
+"""
+
+
+class EmptyModel(BaseModel):
+    empty_field: str
+
+
+i_guard_none = AsyncGuard(rail)
+i_guard_two = AsyncGuard(rail, 2)
+r_guard_none = AsyncGuard.from_rail("tests/unit_tests/test_assets/empty.rail")
+r_guard_two = AsyncGuard.from_rail("tests/unit_tests/test_assets/empty.rail", 2)
+rs_guard_none = AsyncGuard.from_rail_string(empty_rail_string)
+rs_guard_two = AsyncGuard.from_rail_string(empty_rail_string, 2)
+py_guard_none = AsyncGuard.from_pydantic(output_class=EmptyModel)
+py_guard_two = AsyncGuard.from_pydantic(output_class=EmptyModel, num_reasks=2)
+s_guard_none = 
AsyncGuard.from_string(validators=[], description="empty railspec") +s_guard_two = AsyncGuard.from_string( + validators=[], description="empty railspec", num_reasks=2 +) + + +@pytest.mark.parametrize( + "guard,expected_num_reasks,config_num_reasks", + [ + (i_guard_none, 1, None), + (i_guard_two, 2, None), + (i_guard_none, 3, 3), + (r_guard_none, 1, None), + (r_guard_two, 2, None), + (r_guard_none, 3, 3), + (rs_guard_none, 1, None), + (rs_guard_two, 2, None), + (rs_guard_none, 3, 3), + (py_guard_none, 1, None), + (py_guard_two, 2, None), + (py_guard_none, 3, 3), + (s_guard_none, 1, None), + (s_guard_two, 2, None), + (s_guard_none, 3, 3), + ], +) +def test_configure(guard: AsyncGuard, expected_num_reasks: int, config_num_reasks: int): + guard.configure(config_num_reasks) + assert guard.num_reasks == expected_num_reasks + + +def guard_init_from_rail(): + guard = AsyncGuard.from_rail("tests/unit_tests/test_assets/simple.rail") + assert ( + guard.instructions.format().source.strip() + == "You are a helpful bot, who answers only with valid JSON" + ) + assert guard.prompt.format().source.strip() == "Extract a string from the text" + + +def test_use(): + guard: AsyncGuard = ( + AsyncGuard() + .use(EndsWith("a")) + .use(OneLine()) + .use(LowerCase) + .use(TwoWords, on_fail=OnFailAction.REASK) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) + + # print(guard.__stringify__()) + assert len(guard._validators) == 5 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._kwargs["end"] == "a" + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.FIX + ) # bc this is the default + + assert isinstance(guard._validators[1], OneLine) + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[2], LowerCase) + assert ( + guard._validators[2].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[3], TwoWords) + assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it + + assert isinstance(guard._validators[4], ValidLength) + assert guard._validators[4]._min == 0 + assert guard._validators[4]._kwargs["min"] == 0 + assert guard._validators[4]._max == 12 + assert guard._validators[4]._kwargs["max"] == 12 + assert ( + guard._validators[4].on_fail_descriptor == OnFailAction.REFRAIN + ) # bc we set it + + # Raises error when trying to `use` a validator on a non-string + with pytest.raises(RuntimeError): + + class TestClass(BaseModel): + another_field: str + + py_guard = AsyncGuard.from_pydantic(output_class=TestClass) + py_guard.use( + EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK) + ) + + # Use a combination of prompt, instructions, msg_history and output validators + # Should only have the output validators in the guard, + # everything else is in the schema + guard: AsyncGuard = ( + AsyncGuard() + .use(LowerCase, on="prompt") + .use(OneLine, on="prompt") + .use(UpperCase, on="instructions") + .use(LowerCase, on="msg_history") + .use( + EndsWith, end="a", on="output" + ) # default on="output", still explicitly set + .use( + TwoWords, on_fail=OnFailAction.REASK + ) # default on="output", implicitly set + ) + + # Check schemas for prompt, instructions and msg_history validators + prompt_validators = guard.rail.prompt_schema.root_datatype.validators + assert len(prompt_validators) == 2 + assert prompt_validators[0].__class__.__name__ == "LowerCase" + assert 
prompt_validators[1].__class__.__name__ == "OneLine" + + instructions_validators = guard.rail.instructions_schema.root_datatype.validators + assert len(instructions_validators) == 1 + assert instructions_validators[0].__class__.__name__ == "UpperCase" + + msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators + assert len(msg_history_validators) == 1 + assert msg_history_validators[0].__class__.__name__ == "LowerCase" + + # Check guard for output validators + assert len(guard._validators) == 2 # only 2 output validators, hence 2 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._kwargs["end"] == "a" + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.FIX + ) # bc this is the default + + assert isinstance(guard._validators[1], TwoWords) + assert guard._validators[1].on_fail_descriptor == OnFailAction.REASK # bc we set it + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: AsyncGuard = ( + AsyncGuard() + .use(EndsWith("a"), on="response") # invalid on parameter + .use(OneLine, on="prompt") # valid on parameter + ) + + +def test_use_many_instances(): + guard: AsyncGuard = AsyncGuard().use_many( + EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK) + ) + + # print(guard.__stringify__()) + assert len(guard._validators) == 4 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._end == "a" + assert guard._validators[0]._kwargs["end"] == "a" + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.FIX + ) # bc this is the default + + assert isinstance(guard._validators[1], OneLine) + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[2], LowerCase) + assert ( + guard._validators[2].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[3], TwoWords) + assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it + + # Raises error when trying to `use_many` a validator on a non-string + with pytest.raises(RuntimeError): + + class TestClass(BaseModel): + another_field: str + + py_guard = AsyncGuard.from_pydantic(output_class=TestClass) + py_guard.use_many( + [ + EndsWith("a"), + OneLine(), + LowerCase(), + TwoWords(on_fail=OnFailAction.REASK), + ] + ) + + # Test with explicitly setting the "on" parameter = "output" + guard: AsyncGuard = AsyncGuard().use_many( + EndsWith("a"), + OneLine(), + LowerCase(), + TwoWords(on_fail=OnFailAction.REASK), + on="output", + ) + + assert len(guard._validators) == 4 # still 4 output validators, hence 4 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._end == "a" + assert guard._validators[0]._kwargs["end"] == "a" + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.FIX + ) # bc this is the default + + assert isinstance(guard._validators[1], OneLine) + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[2], LowerCase) + assert ( + guard._validators[2].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[3], TwoWords) + assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it + + # Test with explicitly setting the "on" parameter = "prompt" + guard: AsyncGuard = AsyncGuard().use_many( + 
OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="prompt" + ) + + prompt_validators = guard.rail.prompt_schema.root_datatype.validators + assert len(prompt_validators) == 3 + assert prompt_validators[0].__class__.__name__ == "OneLine" + assert prompt_validators[1].__class__.__name__ == "LowerCase" + assert prompt_validators[2].__class__.__name__ == "TwoWords" + assert len(guard._validators) == 0 # no output validators, hence 0 + + # Test with explicitly setting the "on" parameter = "instructions" + guard: AsyncGuard = AsyncGuard().use_many( + OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="instructions" + ) + + instructions_validators = guard.rail.instructions_schema.root_datatype.validators + assert len(instructions_validators) == 3 + assert instructions_validators[0].__class__.__name__ == "OneLine" + assert instructions_validators[1].__class__.__name__ == "LowerCase" + assert instructions_validators[2].__class__.__name__ == "TwoWords" + assert len(guard._validators) == 0 # no output validators, hence 0 + + # Test with explicitly setting the "on" parameter = "msg_history" + guard: AsyncGuard = AsyncGuard().use_many( + OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="msg_history" + ) + + msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators + assert len(msg_history_validators) == 3 + assert msg_history_validators[0].__class__.__name__ == "OneLine" + assert msg_history_validators[1].__class__.__name__ == "LowerCase" + assert msg_history_validators[2].__class__.__name__ == "TwoWords" + assert len(guard._validators) == 0 # no output validators, hence 0 + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: AsyncGuard = AsyncGuard().use_many( + EndsWith("a", on_fail=OnFailAction.EXCEPTION), OneLine(), on="response" + ) + + +def test_use_many_tuple(): + guard: AsyncGuard = AsyncGuard().use_many( + OneLine, + (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}), + (LowerCase, kwargs(on_fail=OnFailAction.FIX_REASK, some_other_kwarg="kwarg")), + (TwoWords, on_fail(OnFailAction.REASK)), + (ValidLength, args(0, 12), kwargs(on_fail=OnFailAction.REFRAIN)), + ) + + # print(guard.__stringify__()) + assert len(guard._validators) == 5 + + assert isinstance(guard._validators[0], OneLine) + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[1], EndsWith) + assert guard._validators[1]._end == "a" + assert guard._validators[1]._kwargs["end"] == "a" + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.EXCEPTION + ) # bc we set it + + assert isinstance(guard._validators[2], LowerCase) + assert guard._validators[2]._kwargs["some_other_kwarg"] == "kwarg" + assert ( + guard._validators[2].on_fail_descriptor == OnFailAction.FIX_REASK + ) # bc this is the default + + assert isinstance(guard._validators[3], TwoWords) + assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it + + assert isinstance(guard._validators[4], ValidLength) + assert guard._validators[4]._min == 0 + assert guard._validators[4]._kwargs["min"] == 0 + assert guard._validators[4]._max == 12 + assert guard._validators[4]._kwargs["max"] == 12 + assert ( + guard._validators[4].on_fail_descriptor == OnFailAction.REFRAIN + ) # bc we set it + + # Test with explicitly setting the "on" parameter + guard: AsyncGuard = AsyncGuard().use_many( + (EndsWith, ["a"], {"on_fail": 
OnFailAction.EXCEPTION}), + OneLine, + on="output", + ) + + assert len(guard._validators) == 2 # only 2 output validators, hence 2 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._end == "a" + assert guard._validators[0]._kwargs["end"] == "a" + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.EXCEPTION + ) # bc we set it + + assert isinstance(guard._validators[1], OneLine) + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: AsyncGuard = AsyncGuard().use_many( + (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}), + OneLine, + on="response", + ) + + +@pytest.mark.asyncio +async def test_validate(): + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine) + .use( + LowerCase(on_fail=OnFailAction.FIX), on="output" + ) # default on="output", still explicitly set + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) + + llm_output: str = "Oh Canada" # bc it meets our criteria + response = await guard.validate(llm_output) + + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() + llm_output_2 = "Star Spangled Banner" # to stick with the theme + + response_2 = await guard.validate(llm_output_2) + + assert response_2.validation_passed is False + assert response_2.validated_output is None + + # Test with a combination of prompt, output, instructions and msg_history validators + # Should still only use the output validators to validate the output + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine, on="prompt") + .use(LowerCase, on="instructions") + .use(UpperCase, on="msg_history") + .use(LowerCase, on="output", on_fail=OnFailAction.FIX) + .use(TwoWords, on="output") + .use(ValidLength, 0, 12, on="output") + ) + + llm_output: str = "Oh Canada" # bc it meets our criteria + + response = await guard.validate(llm_output) + + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() + + llm_output_2 = "Star Spangled Banner" # to stick with the theme + + response_2 = await guard.validate(llm_output_2) + + assert response_2.validation_passed is False + assert response_2.validated_output is None + + +def test_use_and_use_many(): + guard: AsyncGuard = ( + AsyncGuard() + .use_many(OneLine(), LowerCase(), on="prompt") + .use(UpperCase, on="instructions") + .use(LowerCase, on="msg_history") + .use_many( + TwoWords(on_fail=OnFailAction.REASK), + ValidLength(0, 12, on_fail=OnFailAction.REFRAIN), + on="output", + ) + ) + + # Check schemas for prompt, instructions and msg_history validators + prompt_validators = guard.rail.prompt_schema.root_datatype.validators + assert len(prompt_validators) == 2 + assert prompt_validators[0].__class__.__name__ == "OneLine" + assert prompt_validators[1].__class__.__name__ == "LowerCase" + + instructions_validators = guard.rail.instructions_schema.root_datatype.validators + assert len(instructions_validators) == 1 + assert instructions_validators[0].__class__.__name__ == "UpperCase" + + msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators + assert len(msg_history_validators) == 1 + assert msg_history_validators[0].__class__.__name__ == "LowerCase" + + # Check guard for output validators + assert len(guard._validators) == 2 # only 2 output validators, hence 2 + + assert isinstance(guard._validators[0], TwoWords) + assert 
guard._validators[0].on_fail_descriptor == OnFailAction.REASK # bc we set it + + assert isinstance(guard._validators[1], ValidLength) + assert guard._validators[1]._min == 0 + assert guard._validators[1]._kwargs["min"] == 0 + assert guard._validators[1]._max == 12 + assert guard._validators[1]._kwargs["max"] == 12 + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.REFRAIN + ) # bc we set it + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: AsyncGuard = ( + AsyncGuard() + .use_many(OneLine(), LowerCase(), on="prompt") + .use(UpperCase, on="instructions") + .use(LowerCase, on="msg_history") + .use_many( + TwoWords(on_fail=OnFailAction.REASK), + ValidLength(0, 12, on_fail=OnFailAction.REFRAIN), + on="response", # invalid "on" parameter + ) + )