From fe23bf9ace03d30f95bd7c3297acb66ad082ea55 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Fri, 10 May 2024 14:04:33 -0700 Subject: [PATCH 1/9] initial commit for testing async streaming guard --- guardrails/__init__.py | 1 + guardrails/async_guard.py | 1318 +++++++++++++++++++++++++++++++++++++ guardrails/guard.py | 63 +- 3 files changed, 1341 insertions(+), 41 deletions(-) create mode 100644 guardrails/async_guard.py diff --git a/guardrails/__init__.py b/guardrails/__init__.py index 6e28dff7f..8883668f0 100644 --- a/guardrails/__init__.py +++ b/guardrails/__init__.py @@ -1,6 +1,7 @@ # Set up __init__.py so that users can do from guardrails import Response, Schema, etc. from guardrails.guard import Guard +from guardrails.async_guard import AsyncGuard from guardrails.llm_providers import PromptCallableBase from guardrails.logging_utils import configure_logging from guardrails.prompt import Instructions, Prompt diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py new file mode 100644 index 000000000..4f986e56c --- /dev/null +++ b/guardrails/async_guard.py @@ -0,0 +1,1318 @@ +import asyncio +import contextvars +import json +import os +import time +import warnings +from copy import deepcopy +from string import Template +from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, + Optional, Sequence, Tuple, Type, Union, cast, overload) + +from guardrails_api_client.models import AnyObject +from guardrails_api_client.models import Guard as GuardModel +from guardrails_api_client.models import (History, HistoryEvent, + ValidatePayload, ValidationOutput) +from guardrails_api_client.types import UNSET +from langchain_core.messages import BaseMessage +from langchain_core.runnables import Runnable, RunnableConfig +from pydantic import BaseModel +from pydantic.version import VERSION as PYDANTIC_VERSION +from typing_extensions import deprecated + +from guardrails.api_client import GuardrailsApiClient +from guardrails.classes import OT, InputType, ValidationOutcome +from guardrails.classes.credentials import Credentials +from guardrails.classes.generic import Stack +from guardrails.classes.history import Call +from guardrails.classes.history.call_inputs import CallInputs +from guardrails.classes.history.inputs import Inputs +from guardrails.classes.history.iteration import Iteration +from guardrails.classes.history.outputs import Outputs +from guardrails.errors import ValidationError +from guardrails.llm_providers import (get_async_llm_ask, get_llm_api_enum, + get_llm_ask, + model_is_supported_server_side) +from guardrails.logger import logger, set_scope +from guardrails.prompt import Instructions, Prompt +from guardrails.rail import Rail +from guardrails.run import AsyncRunner, Runner, StreamRunner +from guardrails.schema import Schema, StringSchema +from guardrails.stores.context import (Tracer, get_call_kwarg, + get_tracer_context, set_call_kwargs, + set_tracer, set_tracer_context) +from guardrails.utils.hub_telemetry_utils import HubTelemetry +from guardrails.utils.llm_response import LLMResponse +from guardrails.utils.reask_utils import FieldReAsk +from guardrails.utils.validator_utils import get_validator +from guardrails.validator_base import FailResult, Validator + + +class AsyncGuard(Runnable, Generic[OT]): + """The Guard class. + + This class one of the main entry point for using Guardrails. 
It is + initialized from one of the following class methods: + + - `from_rail` + - `from_rail_string` + - `from_pydantic` + - `from_string` + + The `__call__` + method functions as a wrapper around LLM APIs. It takes in an Async LLM + API, and optional prompt parameters, and returns the raw output stream from + the LLM and the validated output stream. + """ + + _tracer = None + _tracer_context = None + _hub_telemetry = None + _guard_id = None + _user_id = None + _validators: List[Validator] + _api_client: Optional[GuardrailsApiClient] = None + + def __init__( + self, + rail: Optional[Rail] = None, + num_reasks: Optional[int] = None, + base_model: Optional[ + Union[Type[BaseModel], Type[List[Type[BaseModel]]]] + ] = None, + tracer: Optional[Tracer] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, + ): + """Initialize the Guard with optional Rail instance, num_reasks, and + base_model.""" + if not rail: + rail = ( + Rail.from_pydantic(base_model) + if base_model + else Rail.from_string_validators([]) + ) + self.rail = rail + self.num_reasks = num_reasks + # TODO: Support a sink for history so that it is not solely held in memory + self.history: Stack[Call] = Stack() + self.base_model = base_model + self._set_tracer(tracer) + + credentials = Credentials.from_rc_file(logger) + + # Get unique id of user from credentials + self._user_id = credentials.id or "" + + # Get metrics opt-out from credentials + self._disable_tracer = credentials.no_metrics + + # Get id of guard object (that is unique) + self._guard_id = id(self) # id of guard object; not the class + + # Initialize Hub Telemetry singleton and get the tracer + # if it is not disabled + if not self._disable_tracer: + self._hub_telemetry = HubTelemetry() + self._validators = [] + + # Gaurdrails As A Service Initialization + self.description = description + self.name = name + + api_key = os.environ.get("GUARDRAILS_API_KEY") + if api_key is not None: + if self.name is None: + self.name = f"gr-{str(self._guard_id)}" + logger.warn("Warning: No name passed to guard!") + logger.warn( + "Use this auto-generated name to re-use this guard: {name}".format( + name=self.name + ) + ) + self._api_client = GuardrailsApiClient(api_key=api_key) + self.upsert_guard() + + @property + def prompt_schema(self) -> Optional[StringSchema]: + """Return the input schema.""" + return self.rail.prompt_schema + + @property + def instructions_schema(self) -> Optional[StringSchema]: + """Return the input schema.""" + return self.rail.instructions_schema + + @property + def msg_history_schema(self) -> Optional[StringSchema]: + """Return the input schema.""" + return self.rail.msg_history_schema + + @property + def output_schema(self) -> Schema: + """Return the output schema.""" + return self.rail.output_schema + + @property + def instructions(self) -> Optional[Instructions]: + """Return the instruction-prompt.""" + return self.rail.instructions + + @property + def prompt(self) -> Optional[Prompt]: + """Return the prompt.""" + return self.rail.prompt + + @property + def raw_prompt(self) -> Optional[Prompt]: + """Return the prompt, alias for `prompt`.""" + return self.prompt + + @property + def base_prompt(self) -> Optional[str]: + """Return the base prompt i.e. 
prompt.source.""" + if self.prompt is None: + return None + return self.prompt.source + + @property + def reask_prompt(self) -> Optional[Prompt]: + """Return the reask prompt.""" + return self.output_schema.reask_prompt_template + + @reask_prompt.setter + def reask_prompt(self, reask_prompt: Optional[str]): + """Set the reask prompt.""" + self.output_schema.reask_prompt_template = reask_prompt + + @property + def reask_instructions(self) -> Optional[Instructions]: + """Return the reask prompt.""" + return self.output_schema.reask_instructions_template + + @reask_instructions.setter + def reask_instructions(self, reask_instructions: Optional[str]): + """Set the reask prompt.""" + self.output_schema.reask_instructions_template = reask_instructions + + def configure( + self, + num_reasks: Optional[int] = None, + ): + """Configure the Guard.""" + self.num_reasks = ( + num_reasks + if num_reasks is not None + else self.num_reasks + if self.num_reasks is not None + else 1 + ) + + def _set_tracer(self, tracer: Optional[Tracer] = None) -> None: + self._tracer = tracer + set_tracer(tracer) + set_tracer_context() + self._tracer_context = get_tracer_context() + + @classmethod + def from_rail( + cls, + rail_file: str, + num_reasks: Optional[int] = None, + tracer: Optional[Tracer] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, + ): + """Create a Schema from a `.rail` file. + + Args: + rail_file: The path to the `.rail` file. + num_reasks: The max times to re-ask the LLM for invalid output. + + Returns: + An instance of the `Guard` class. + """ + + # We have to set the tracer in the ContextStore before the Rail, + # and therefore the Validators, are initialized + cls._set_tracer(cls, tracer) # type: ignore + + rail = Rail.from_file(rail_file) + if rail.output_type == "str": + return cast( + AsyncGuard[str], + cls( + rail=rail, + num_reasks=num_reasks, + tracer=tracer, + name=name, + description=description, + ), + ) + elif rail.output_type == "list": + return cast( + AsyncGuard[List], + cls( + rail=rail, + num_reasks=num_reasks, + tracer=tracer, + name=name, + description=description, + ), + ) + return cast( + AsyncGuard[Dict], + cls( + rail=rail, + num_reasks=num_reasks, + tracer=tracer, + name=name, + description=description, + ), + ) + + @classmethod + def from_rail_string( + cls, + rail_string: str, + num_reasks: Optional[int] = None, + tracer: Optional[Tracer] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, + ): + """Create a Schema from a `.rail` string. + + Args: + rail_string: The `.rail` string. + num_reasks: The max times to re-ask the LLM for invalid output. + + Returns: + An instance of the `Guard` class. 
+ """ + # We have to set the tracer in the ContextStore before the Rail, + # and therefore the Validators, are initialized + cls._set_tracer(cls, tracer) # type: ignore + + rail = Rail.from_string(rail_string) + if rail.output_type == "str": + return cast( + AsyncGuard[str], + cls( + rail=rail, + num_reasks=num_reasks, + tracer=tracer, + name=name, + description=description, + ), + ) + elif rail.output_type == "list": + return cast( + AsyncGuard[List], + cls( + rail=rail, + num_reasks=num_reasks, + tracer=tracer, + name=name, + description=description, + ), + ) + return cast( + AsyncGuard[Dict], + cls( + rail=rail, + num_reasks=num_reasks, + tracer=tracer, + name=name, + description=description, + ), + ) + + @classmethod + def from_pydantic( + cls, + output_class: Union[Type[BaseModel], Type[List[Type[BaseModel]]]], + prompt: Optional[str] = None, + instructions: Optional[str] = None, + num_reasks: Optional[int] = None, + reask_prompt: Optional[str] = None, + reask_instructions: Optional[str] = None, + tracer: Optional[Tracer] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, + ): + """Create a Guard instance from a Pydantic model and prompt.""" + if PYDANTIC_VERSION.startswith("1"): + warnings.warn( + """Support for Pydantic v1.x is deprecated and will be removed in + Guardrails 0.5.x. Please upgrade to the latest Pydantic v2.x to + continue receiving future updates and support.""", + FutureWarning, + ) + # We have to set the tracer in the ContextStore before the Rail, + # and therefore the Validators, are initialized + cls._set_tracer(cls, tracer) # type: ignore + + rail = Rail.from_pydantic( + output_class=output_class, + prompt=prompt, + instructions=instructions, + reask_prompt=reask_prompt, + reask_instructions=reask_instructions, + ) + if rail.output_type == "list": + return cast( + AsyncGuard[List], + cls(rail, num_reasks=num_reasks, base_model=output_class), + ) + return cast( + AsyncGuard[Dict], + cls( + rail, + num_reasks=num_reasks, + base_model=output_class, + tracer=tracer, + name=name, + description=description, + ), + ) + + @classmethod + def from_string( + cls, + validators: Sequence[Validator], + description: Optional[str] = None, + prompt: Optional[str] = None, + instructions: Optional[str] = None, + reask_prompt: Optional[str] = None, + reask_instructions: Optional[str] = None, + num_reasks: Optional[int] = None, + tracer: Optional[Tracer] = None, + *, + name: Optional[str] = None, + guard_description: Optional[str] = None, + ): + """Create a Guard instance for a string response with prompt, + instructions, and validations. + + Args: + validators: (List[Validator]): The list of validators to apply to the string output. + description (str, optional): A description for the string to be generated. Defaults to None. + prompt (str, optional): The prompt used to generate the string. Defaults to None. + instructions (str, optional): Instructions for chat models. Defaults to None. + reask_prompt (str, optional): An alternative prompt to use during reasks. Defaults to None. + reask_instructions (str, optional): Alternative instructions to use during reasks. Defaults to None. + num_reasks (int, optional): The max times to re-ask the LLM for invalid output. 
+ """ # noqa + + cls._set_tracer(cls, tracer) # type: ignore + + rail = Rail.from_string_validators( + validators=validators, + description=description, + prompt=prompt, + instructions=instructions, + reask_prompt=reask_prompt, + reask_instructions=reask_instructions, + ) + return cast( + AsyncGuard[str], + cls( + rail, + num_reasks=num_reasks, + tracer=tracer, + name=name, + description=guard_description, + ), + ) + + @overload + def __call__( + self, + llm_api: Callable, + prompt_params: Optional[Dict] = None, + num_reasks: Optional[int] = None, + prompt: Optional[str] = None, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict]] = None, + metadata: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + stream: Optional[bool] = False, + *args, + **kwargs, + ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: ... + + @overload + def __call__( + self, + llm_api: Callable[[Any], Awaitable[Any]], + prompt_params: Optional[Dict] = None, + num_reasks: Optional[int] = None, + prompt: Optional[str] = None, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict]] = None, + metadata: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + *args, + **kwargs, + ) -> Awaitable[ValidationOutcome[OT]]: ... + + async def __call__( + self, + llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]], + prompt_params: Optional[Dict] = None, + num_reasks: Optional[int] = None, + prompt: Optional[str] = None, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict]] = None, + metadata: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + *args, + **kwargs, + ) -> Union[ + Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]], + Awaitable[ValidationOutcome[OT]], + ]: + """Call the LLM and validate the output. Pass an async LLM API to + return a coroutine. + + Args: + llm_api: The LLM API to call + (e.g. openai.Completion.create or openai.Completion.acreate) + prompt_params: The parameters to pass to the prompt.format() method. + num_reasks: The max times to re-ask the LLM for invalid output. + prompt: The prompt to use for the LLM. + instructions: Instructions for chat models. + msg_history: The message history to pass to the LLM. + metadata: Metadata to pass to the validators. + full_schema_reask: When reasking, whether to regenerate the full schema + or just the incorrect values. + Defaults to `True` if a base model is provided, + `False` otherwise. + + Returns: + The raw text output from the LLM and the validated output. 
+ """ + + async def __call( + self, + llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]], + prompt_params: Optional[Dict] = None, + num_reasks: Optional[int] = None, + prompt: Optional[str] = None, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict]] = None, + metadata: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + *args, + **kwargs, + ): + if metadata is None: + metadata = {} + if full_schema_reask is None: + full_schema_reask = self.base_model is not None + if prompt_params is None: + prompt_params = {} + + if not self._disable_tracer: + # Create a new span for this guard call + self._hub_telemetry.create_new_span( + span_name="/guard_call", + attributes=[ + ("guard_id", self._guard_id), + ("user_id", self._user_id), + ("llm_api", llm_api.__name__ if llm_api else "None"), + ("custom_reask_prompt", self.reask_prompt is not None), + ( + "custom_reask_instructions", + self.reask_instructions is not None, + ), + ], + is_parent=True, # It will have children + has_parent=False, # Has no parents + ) + + set_call_kwargs(kwargs) + set_tracer(self._tracer) + set_tracer_context(self._tracer_context) + + self.configure(num_reasks) + if self.num_reasks is None: + raise RuntimeError( + "`num_reasks` is `None` after calling `configure()`. " + "This should never happen." + ) + + input_prompt = prompt or (self.prompt._source if self.prompt else None) + input_instructions = instructions or ( + self.instructions._source if self.instructions else None + ) + call_inputs = CallInputs( + llm_api=llm_api, + prompt=input_prompt, + instructions=input_instructions, + msg_history=msg_history, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + metadata=metadata, + full_schema_reask=full_schema_reask, + args=list(args), + kwargs=kwargs, + ) + call_log = Call(inputs=call_inputs) + set_scope(str(id(call_log))) + self.history.push(call_log) + + if self._api_client is not None and model_is_supported_server_side( + llm_api, *args, **kwargs + ): + return self._call_server( + llm_api=llm_api, + num_reasks=self.num_reasks, + prompt_params=prompt_params, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, + ) + + # If the LLM API is not async, fail + if not asyncio.iscoroutinefunction(llm_api): + raise RuntimeError( + f"The LLM API `{llm_api.__name__}` is not a coroutine. " + "Please use an async LLM API." + ) + # Otherwise, call the LLM + return await self._call_async( + llm_api, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + prompt=prompt, + instructions=instructions, + msg_history=msg_history, + metadata=metadata, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, + ) + + guard_context = contextvars.Context() + return guard_context.run( + __call, + self, + llm_api, + prompt_params, + num_reasks, + prompt, + instructions, + msg_history, + metadata, + full_schema_reask, + *args, + **kwargs, + ) + + async def _call_async( + self, + llm_api: Callable[[Any], Awaitable[Any]], + prompt_params: Dict, + num_reasks: int, + prompt: Optional[str], + instructions: Optional[str], + msg_history: Optional[List[Dict]], + metadata: Dict, + full_schema_reask: bool, + call_log: Call, + *args, + **kwargs, + ) -> ValidationOutcome[OT]: + """Call the LLM asynchronously and validate the output. + + Args: + llm_api: The LLM API to call asynchronously (e.g. openai.Completion.acreate) + prompt_params: The parameters to pass to the prompt.format() method. 
+ num_reasks: The max times to re-ask the LLM for invalid output. + prompt: The prompt to use for the LLM. + instructions: Instructions for chat models. + msg_history: The message history to pass to the LLM. + metadata: Metadata to pass to the validators. + full_schema_reask: When reasking, whether to regenerate the full schema + or just the incorrect values. + Defaults to `True` if a base model is provided, + `False` otherwise. + + Returns: + The raw text output from the LLM and the validated output. + """ + instructions_obj = instructions or self.instructions + prompt_obj = prompt or self.prompt + msg_history_obj = msg_history or [] + if prompt_obj is None: + if msg_history_obj is not None and not len(msg_history_obj): + raise RuntimeError( + "You must provide a prompt if msg_history is empty. " + "Alternatively, you can provide a prompt in the RAIL spec." + ) + + runner = AsyncRunner( + instructions=instructions_obj, + prompt=prompt_obj, + msg_history=msg_history_obj, + api=get_async_llm_ask(llm_api, *args, **kwargs), + prompt_schema=self.prompt_schema, + instructions_schema=self.instructions_schema, + msg_history_schema=self.msg_history_schema, + output_schema=self.output_schema, + num_reasks=num_reasks, + metadata=metadata, + base_model=self.base_model, + full_schema_reask=full_schema_reask, + disable_tracer=self._disable_tracer, + ) + call = await runner.async_run(call_log=call_log, prompt_params=prompt_params) + return ValidationOutcome[OT].from_guard_history(call) + + def __repr__(self): + return f"Guard(RAIL={self.rail})" + + def __rich_repr__(self): + yield "RAIL", self.rail + + def __stringify__(self): + if self.rail and self.rail.output_type == "str": + template = Template( + """ + Guard { + validators: [ + ${validators} + ] + } + """ + ) + return template.safe_substitute( + { + "validators": ",\n".join( + [v.__stringify__() for v in self._validators] + ) + } + ) + return self.__repr__() + + @overload + def parse( + self, + llm_output: str, + metadata: Optional[Dict] = None, + llm_api: None = None, + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + *args, + **kwargs, + ) -> ValidationOutcome[OT]: ... + + @overload + def parse( + self, + llm_output: str, + metadata: Optional[Dict] = None, + llm_api: Callable[[Any], Awaitable[Any]] = ..., + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + *args, + **kwargs, + ) -> Awaitable[ValidationOutcome[OT]]: ... + + @overload + def parse( + self, + llm_output: str, + metadata: Optional[Dict] = None, + llm_api: Optional[Callable] = None, + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + *args, + **kwargs, + ) -> ValidationOutcome[OT]: ... + + def parse( + self, + llm_output: str, + metadata: Optional[Dict] = None, + llm_api: Optional[Callable] = None, + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + *args, + **kwargs, + ) -> Union[ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]]]: + """Alternate flow to using Guard where the llm_output is known. + + Args: + llm_output: The output being parsed and validated. + metadata: Metadata to pass to the validators. + llm_api: The LLM API to call + (e.g. openai.Completion.create or openai.Completion.acreate) + num_reasks: The max times to re-ask the LLM for invalid output. 
+ prompt_params: The parameters to pass to the prompt.format() method. + full_schema_reask: When reasking, whether to regenerate the full schema + or just the incorrect values. + + Returns: + The validated response. This is either a string or a dictionary, + determined by the object schema defined in the RAILspec. + """ + + def __parse( + self, + llm_output: str, + metadata: Optional[Dict] = None, + llm_api: Optional[Callable] = None, + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + full_schema_reask: Optional[bool] = None, + *args, + **kwargs, + ): + final_num_reasks = ( + num_reasks if num_reasks is not None else 0 if llm_api is None else None + ) + + if not self._disable_tracer: + self._hub_telemetry.create_new_span( + span_name="/guard_parse", + attributes=[ + ("guard_id", self._guard_id), + ("user_id", self._user_id), + ("llm_api", llm_api.__name__ if llm_api else "None"), + ("custom_reask_prompt", self.reask_prompt is not None), + ( + "custom_reask_instructions", + self.reask_instructions is not None, + ), + ], + is_parent=True, # It will have children + has_parent=False, # Has no parents + ) + + self.configure(final_num_reasks) + if self.num_reasks is None: + raise RuntimeError( + "`num_reasks` is `None` after calling `configure()`. " + "This should never happen." + ) + if full_schema_reask is None: + full_schema_reask = self.base_model is not None + metadata = metadata or {} + prompt_params = prompt_params or {} + + set_call_kwargs(kwargs) + set_tracer(self._tracer) + set_tracer_context(self._tracer_context) + + input_prompt = self.prompt._source if self.prompt else None + input_instructions = ( + self.instructions._source if self.instructions else None + ) + call_inputs = CallInputs( + llm_api=llm_api, + llm_output=llm_output, + prompt=input_prompt, + instructions=input_instructions, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + metadata=metadata, + full_schema_reask=full_schema_reask, + args=list(args), + kwargs=kwargs, + ) + call_log = Call(inputs=call_inputs) + set_scope(str(id(call_log))) + self.history.push(call_log) + + if self._api_client is not None and model_is_supported_server_side( + llm_api, *args, **kwargs + ): + return self._call_server( + llm_output=llm_output, + llm_api=llm_api, + num_reasks=self.num_reasks, + prompt_params=prompt_params, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, + ) + + # If the LLM API is async, return a coroutine + if asyncio.iscoroutinefunction(llm_api): + return self._async_parse( + llm_output, + metadata, + llm_api=llm_api, + num_reasks=self.num_reasks, + prompt_params=prompt_params, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, + ) + else: + raise NotImplementedError( + "AsyncGuard does not support non-async LLM APIs. " + "Please use the sync API Guard or supply an async LLM API." + ) + + guard_context = contextvars.Context() + return guard_context.run( + __parse, + self, + llm_output, + metadata, + llm_api, + num_reasks, + prompt_params, + full_schema_reask, + *args, + **kwargs, + ) + + async def _async_parse( + self, + llm_output: str, + metadata: Dict, + llm_api: Optional[Callable[[Any], Awaitable[Any]]], + num_reasks: int, + prompt_params: Dict, + full_schema_reask: bool, + call_log: Call, + *args, + **kwargs, + ) -> ValidationOutcome[OT]: + """Alternate flow to using Guard where the llm_output is known. + + Args: + llm_output: The output from the LLM. + llm_api: The LLM API to use to re-ask the LLM. 
+ num_reasks: The max times to re-ask the LLM for invalid output. + + Returns: + The validated response. + """ + runner = AsyncRunner( + instructions=kwargs.pop("instructions", None), + prompt=kwargs.pop("prompt", None), + msg_history=kwargs.pop("msg_history", None), + api=get_async_llm_ask(llm_api, *args, **kwargs) if llm_api else None, + prompt_schema=self.prompt_schema, + instructions_schema=self.instructions_schema, + msg_history_schema=self.msg_history_schema, + output_schema=self.output_schema, + num_reasks=num_reasks, + metadata=metadata, + output=llm_output, + base_model=self.base_model, + full_schema_reask=full_schema_reask, + disable_tracer=self._disable_tracer, + ) + call = await runner.async_run(call_log=call_log, prompt_params=prompt_params) + + return ValidationOutcome[OT].from_guard_history(call) + + @deprecated( + """The `with_prompt_validation` method is deprecated, + and will be removed in 0.5.x. Instead, please use + `Guard().use(YourValidator, on='prompt')`.""", + category=FutureWarning, + stacklevel=2, + ) + def with_prompt_validation( + self, + validators: Sequence[Validator], + ): + """Add prompt validation to the Guard. + + Args: + validators: The validators to add to the prompt. + """ + if self.rail.prompt_schema: + warnings.warn("Overriding existing prompt validators.") + schema = StringSchema.from_string( + validators=validators, + ) + self.rail.prompt_schema = schema + return self + + @deprecated( + """The `with_instructions_validation` method is deprecated, + and will be removed in 0.5.x. Instead, please use + `Guard().use(YourValidator, on='instructions')`.""", + category=FutureWarning, + stacklevel=2, + ) + def with_instructions_validation( + self, + validators: Sequence[Validator], + ): + """Add instructions validation to the Guard. + + Args: + validators: The validators to add to the instructions. + """ + if self.rail.instructions_schema: + warnings.warn("Overriding existing instructions validators.") + schema = StringSchema.from_string( + validators=validators, + ) + self.rail.instructions_schema = schema + return self + + @deprecated( + """The `with_msg_history_validation` method is deprecated, + and will be removed in 0.5.x. Instead, please use + `Guard().use(YourValidator, on='msg_history')`.""", + category=FutureWarning, + stacklevel=2, + ) + def with_msg_history_validation( + self, + validators: Sequence[Validator], + ): + """Add msg_history validation to the Guard. + + Args: + validators: The validators to add to the msg_history. + """ + if self.rail.msg_history_schema: + warnings.warn("Overriding existing msg_history validators.") + schema = StringSchema.from_string( + validators=validators, + ) + self.rail.msg_history_schema = schema + return self + + def __add_validator(self, validator: Validator, on: str = "output"): + # Only available for string output types + if self.rail.output_type != "str": + raise RuntimeError( + "The `use` method is only available for string output types." 
+ ) + + if on == "prompt": + # If the prompt schema exists, add the validator to it + if self.rail.prompt_schema: + self.rail.prompt_schema.root_datatype.validators.append(validator) + else: + # Otherwise, create a new schema with the validator + schema = StringSchema.from_string( + validators=[validator], + ) + self.rail.prompt_schema = schema + elif on == "instructions": + # If the instructions schema exists, add the validator to it + if self.rail.instructions_schema: + self.rail.instructions_schema.root_datatype.validators.append(validator) + else: + # Otherwise, create a new schema with the validator + schema = StringSchema.from_string( + validators=[validator], + ) + self.rail.instructions_schema = schema + elif on == "msg_history": + # If the msg_history schema exists, add the validator to it + if self.rail.msg_history_schema: + self.rail.msg_history_schema.root_datatype.validators.append(validator) + else: + # Otherwise, create a new schema with the validator + schema = StringSchema.from_string( + validators=[validator], + ) + self.rail.msg_history_schema = schema + elif on == "output": + self._validators.append(validator) + self.rail.output_schema.root_datatype.validators.append(validator) + else: + raise ValueError( + """Invalid value for `on`. Must be one of the following: + 'output', 'prompt', 'instructions', 'msg_history'.""" + ) + + @overload + def use(self, validator: Validator, *, on: str = "output") -> "AsyncGuard": ... + + @overload + def use( + self, validator: Type[Validator], *args, on: str = "output", **kwargs + ) -> "AsyncGuard": ... + + def use( + self, + validator: Union[Validator, Type[Validator]], + *args, + on: str = "output", + **kwargs, + ) -> "AsyncGuard": + """Use a validator to validate either of the following: + - The output of an LLM request + - The prompt + - The instructions + - The message history + + *Note*: For on="output", `use` is only available for string output types. + + Args: + validator: The validator to use. Either the class or an instance. + on: The part of the LLM request to validate. Defaults to "output". + """ + hydrated_validator = get_validator(validator, *args, **kwargs) + self.__add_validator(hydrated_validator, on=on) + return self + + @overload + def use_many(self, *validators: Validator, on: str = "output") -> "Async": ... + + @overload + def use_many( + self, + *validators: Tuple[ + Type[Validator], + Optional[Union[List[Any], Dict[str, Any]]], + Optional[Dict[str, Any]], + ], + on: str = "output", + ) -> "AsyncGuard": ... + + def use_many( + self, + *validators: Union[ + Validator, + Tuple[ + Type[Validator], + Optional[Union[List[Any], Dict[str, Any]]], + Optional[Dict[str, Any]], + ], + ], + on: str = "output", + ) -> "AsyncGuard": + """Use a validator to validate results of an LLM request. + + *Note*: `use_many` is only available for string output types. + """ + if self.rail.output_type != "str": + raise RuntimeError( + "The `use_many` method is only available for string output types." 
+ ) + + # Loop through the validators + for v in validators: + hydrated_validator = get_validator(v) + self.__add_validator(hydrated_validator, on=on) + return self + + def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]: + if ( + not self.rail + or self.rail.output_schema.root_datatype.validators != self._validators + ): + self.rail = Rail.from_string_validators( + validators=self._validators, + prompt=self.prompt.source if self.prompt else None, + instructions=self.instructions.source if self.instructions else None, + reask_prompt=self.reask_prompt.source if self.reask_prompt else None, + reask_instructions=self.reask_instructions.source + if self.reask_instructions + else None, + ) + + return self.parse(llm_output=llm_output, *args, **kwargs) + + # No call support for this until + # https://github.com/guardrails-ai/guardrails/pull/525 is merged + # def __call__(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]: + # return self.validate(llm_output, *args, **kwargs) + + def invoke( + self, input: InputType, config: Optional[RunnableConfig] = None + ) -> InputType: + output = BaseMessage(content="", type="") + str_input = None + input_is_chat_message = False + if isinstance(input, BaseMessage): + input_is_chat_message = True + str_input = str(input.content) + output = deepcopy(input) + else: + str_input = str(input) + + response = self.validate(str_input) + + validated_output = response.validated_output + if not validated_output: + raise ValidationError( + ( + "The response from the LLM failed validation!" + "See `guard.history` for more details." + ) + ) + + if isinstance(validated_output, Dict): + validated_output = json.dumps(validated_output) + + if input_is_chat_message: + output.content = validated_output + return cast(InputType, output) + return cast(InputType, validated_output) + + def _to_request(self) -> Dict: + return { + "name": self.name, + "description": self.description, + "railspec": self.rail._to_request(), + "numReasks": self.num_reasks, + } + + def upsert_guard(self): + if self._api_client: + guard_dict = self._to_request() + self._api_client.upsert_guard(GuardModel.from_dict(guard_dict)) + else: + raise ValueError("Guard does not have an api client!") + + def _call_server( + self, + *args, + llm_output: Optional[str] = None, + llm_api: Optional[Callable] = None, + num_reasks: Optional[int] = None, + prompt_params: Optional[Dict] = None, + metadata: Optional[Dict] = {}, + full_schema_reask: Optional[bool] = True, + call_log: Optional[Call], + # prompt: Optional[str], + # instructions: Optional[str], + # msg_history: Optional[List[Dict]], + **kwargs, + ): + if self._api_client: + payload: Dict[str, Any] = {"args": list(args)} + payload.update(**kwargs) + if llm_output is not None: + payload["llmOutput"] = llm_output + if num_reasks is not None: + payload["numReasks"] = num_reasks + if prompt_params is not None: + payload["promptParams"] = prompt_params + if llm_api is not None: + payload["llmApi"] = get_llm_api_enum(llm_api) + # TODO: get enum for llm_api + validation_output: Optional[ValidationOutput] = self._api_client.validate( + guard=self, # type: ignore + payload=ValidatePayload.from_dict(payload), + openai_api_key=get_call_kwarg("api_key"), + ) + + if not validation_output: + return ValidationOutcome[OT]( + raw_llm_output=None, + validated_output=None, + validation_passed=False, + error="The response from the server was empty!", + ) + + call_log = call_log or Call() + if llm_api is not None: + llm_api = 
get_llm_ask(llm_api) + if asyncio.iscoroutinefunction(llm_api): + llm_api = get_async_llm_ask(llm_api) + session_history = ( + validation_output.session_history + if validation_output is not None and validation_output.session_history + else [] + ) + history: History + for history in session_history: + history_events: Optional[List[HistoryEvent]] = ( # type: ignore + history.history if history.history != UNSET else None + ) + if history_events is None: + continue + + iterations = [ + Iteration( + inputs=Inputs( + llm_api=llm_api, + llm_output=llm_output, + instructions=( + Instructions(h.instructions) if h.instructions else None + ), + prompt=( + Prompt(h.prompt.source) # type: ignore + if h.prompt is not None and h.prompt != UNSET + else None + ), + prompt_params=prompt_params, + num_reasks=(num_reasks or 0), + metadata=metadata, + full_schema_reask=full_schema_reask, + ), + outputs=Outputs( + llm_response_info=LLMResponse( + output=h.output # type: ignore + ), + raw_output=h.output, + parsed_output=( + h.parsed_output.to_dict() + if isinstance(h.parsed_output, AnyObject) + else h.parsed_output + ), + validation_output=( + h.validated_output.to_dict() + if isinstance(h.validated_output, AnyObject) + else h.validated_output + ), + reasks=list( + [ + FieldReAsk( + incorrect_value=r.to_dict().get( + "incorrect_value" + ), + path=r.to_dict().get("path"), + fail_results=[ + FailResult( + error_message=r.to_dict().get( + "error_message" + ), + fix_value=r.to_dict().get("fix_value"), + ) + ], + ) + for r in h.reasks # type: ignore + ] + if h.reasks != UNSET + else [] + ), + ), + ) + for h in history_events + ] + call_log.iterations.extend(iterations) + if self.history.length == 0: + self.history.push(call_log) + + # Our interfaces are too different for this to work right now. + # Once we move towards shared interfaces for both the open source + # and the api we can re-enable this. 
+ # return ValidationOutcome[OT].from_guard_history(call_log) + return ValidationOutcome[OT]( + raw_llm_output=validation_output.raw_llm_response, # type: ignore + validated_output=cast(OT, validation_output.validated_output), + validation_passed=validation_output.result, + ) + else: + raise ValueError("Guard does not have an api client!") diff --git a/guardrails/guard.py b/guardrails/guard.py index e6ae27e1b..c5b873ae0 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -5,31 +5,13 @@ import warnings from copy import deepcopy from string import Template -from typing import ( - Any, - Awaitable, - Callable, - Dict, - Generic, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - Union, - cast, - overload, -) - -from guardrails_api_client.models import ( - AnyObject, - History, - HistoryEvent, - ValidatePayload, - ValidationOutput, -) +from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, + Optional, Sequence, Tuple, Type, Union, cast, overload) + +from guardrails_api_client.models import AnyObject from guardrails_api_client.models import Guard as GuardModel +from guardrails_api_client.models import (History, HistoryEvent, + ValidatePayload, ValidationOutput) from guardrails_api_client.types import UNSET from langchain_core.messages import BaseMessage from langchain_core.runnables import Runnable, RunnableConfig @@ -47,25 +29,17 @@ from guardrails.classes.history.iteration import Iteration from guardrails.classes.history.outputs import Outputs from guardrails.errors import ValidationError -from guardrails.llm_providers import ( - get_async_llm_ask, - get_llm_api_enum, - get_llm_ask, - model_is_supported_server_side, -) +from guardrails.llm_providers import (get_async_llm_ask, get_llm_api_enum, + get_llm_ask, + model_is_supported_server_side) from guardrails.logger import logger, set_scope from guardrails.prompt import Instructions, Prompt from guardrails.rail import Rail from guardrails.run import AsyncRunner, Runner, StreamRunner from guardrails.schema import Schema, StringSchema -from guardrails.stores.context import ( - Tracer, - get_call_kwarg, - get_tracer_context, - set_call_kwargs, - set_tracer, - set_tracer_context, -) +from guardrails.stores.context import (Tracer, get_call_kwarg, + get_tracer_context, set_call_kwargs, + set_tracer, set_tracer_context) from guardrails.utils.hub_telemetry_utils import HubTelemetry from guardrails.utils.llm_response import LLMResponse from guardrails.utils.reask_utils import FieldReAsk @@ -500,8 +474,7 @@ def __call__( Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]], Awaitable[ValidationOutcome[OT]], ]: - """Call the LLM and validate the output. Pass an async LLM API to - return a coroutine. + """Call the LLM and validate the output. Args: llm_api: The LLM API to call @@ -603,7 +576,8 @@ def __call( **kwargs, ) - # If the LLM API is async, return a coroutine + # If the LLM API is async, return a coroutine. This will be deprecated soon. + if asyncio.iscoroutinefunction(llm_api): return self._call_async( llm_api, @@ -712,6 +686,7 @@ def _call_sync( call = runner(call_log=call_log, prompt_params=prompt_params) return ValidationOutcome[OT].from_guard_history(call) + async def _call_async( self, llm_api: Callable[[Any], Awaitable[Any]], @@ -744,6 +719,12 @@ async def _call_async( Returns: The raw text output from the LLM and the validated output. """ + warnings.warn( + "Using an async LLM client is deprecated in Guard and will " + "be removed in a future release. 
" + "Please use `AsyncGuard` or a non async llm client instead.", + DeprecationWarning, + ) instructions_obj = instructions or self.instructions prompt_obj = prompt or self.prompt msg_history_obj = msg_history or [] From b7deef25e31dbb41a3bb2b600216c157cd84c049 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Mon, 13 May 2024 13:40:51 -0700 Subject: [PATCH 2/9] Adding async guard, test, and deprecating async calls in Guard() --- guardrails/async_guard.py | 95 +++-- guardrails/guard.py | 71 +++- tests/unit_tests/test_async_guard.py | 601 +++++++++++++++++++++++++++ 3 files changed, 720 insertions(+), 47 deletions(-) create mode 100644 tests/unit_tests/test_async_guard.py diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index 4f986e56c..09affb0fc 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -1,18 +1,36 @@ import asyncio import contextvars +import inspect import json import os -import time import warnings from copy import deepcopy from string import Template -from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, - Optional, Sequence, Tuple, Type, Union, cast, overload) +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generic, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, + overload, +) from guardrails_api_client.models import AnyObject from guardrails_api_client.models import Guard as GuardModel -from guardrails_api_client.models import (History, HistoryEvent, - ValidatePayload, ValidationOutput) +from guardrails_api_client.models import ( + History, + HistoryEvent, + ValidatePayload, + ValidationOutput, +) from guardrails_api_client.types import UNSET from langchain_core.messages import BaseMessage from langchain_core.runnables import Runnable, RunnableConfig @@ -30,17 +48,25 @@ from guardrails.classes.history.iteration import Iteration from guardrails.classes.history.outputs import Outputs from guardrails.errors import ValidationError -from guardrails.llm_providers import (get_async_llm_ask, get_llm_api_enum, - get_llm_ask, - model_is_supported_server_side) +from guardrails.llm_providers import ( + get_async_llm_ask, + get_llm_api_enum, + get_llm_ask, + model_is_supported_server_side, +) from guardrails.logger import logger, set_scope from guardrails.prompt import Instructions, Prompt from guardrails.rail import Rail -from guardrails.run import AsyncRunner, Runner, StreamRunner +from guardrails.run import AsyncRunner from guardrails.schema import Schema, StringSchema -from guardrails.stores.context import (Tracer, get_call_kwarg, - get_tracer_context, set_call_kwargs, - set_tracer, set_tracer_context) +from guardrails.stores.context import ( + Tracer, + get_call_kwarg, + get_tracer_context, + set_call_kwargs, + set_tracer, + set_tracer_context, +) from guardrails.utils.hub_telemetry_utils import HubTelemetry from guardrails.utils.llm_response import LLMResponse from guardrails.utils.reask_utils import FieldReAsk @@ -460,7 +486,7 @@ def __call__( **kwargs, ) -> Awaitable[ValidationOutcome[OT]]: ... - async def __call__( + def __call__( self, llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]], prompt_params: Optional[Dict] = None, @@ -497,7 +523,7 @@ async def __call__( The raw text output from the LLM and the validated output. 
""" - async def __call( + def __call( self, llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]], prompt_params: Optional[Dict] = None, @@ -580,13 +606,16 @@ async def __call( ) # If the LLM API is not async, fail - if not asyncio.iscoroutinefunction(llm_api): + # FIXME: it seems like this check isn't actually working? + if not inspect.isawaitable(llm_api) and not inspect.iscoroutinefunction( + llm_api + ): raise RuntimeError( f"The LLM API `{llm_api.__name__}` is not a coroutine. " "Please use an async LLM API." ) # Otherwise, call the LLM - return await self._call_async( + return self._call_async( llm_api, prompt_params=prompt_params, num_reasks=self.num_reasks, @@ -616,7 +645,7 @@ async def __call( **kwargs, ) - async def _call_async( + def _call_async( self, llm_api: Callable[[Any], Awaitable[Any]], prompt_params: Dict, @@ -673,7 +702,9 @@ async def _call_async( full_schema_reask=full_schema_reask, disable_tracer=self._disable_tracer, ) - call = await runner.async_run(call_log=call_log, prompt_params=prompt_params) + call = asyncio.run( + runner.async_run(call_log=call_log, prompt_params=prompt_params) + ) return ValidationOutcome[OT].from_guard_history(call) def __repr__(self): @@ -851,17 +882,21 @@ def __parse( ) # If the LLM API is async, return a coroutine - if asyncio.iscoroutinefunction(llm_api): - return self._async_parse( - llm_output, - metadata, - llm_api=llm_api, - num_reasks=self.num_reasks, - prompt_params=prompt_params, - full_schema_reask=full_schema_reask, - call_log=call_log, - *args, - **kwargs, + # FIXME: Add null llm_api support becuase we want to be able to call on + # static text (e.g. validate("foobar")). Is there a better way to do this? + if not llm_api or inspect.iscoroutinefunction(llm_api): + return asyncio.run( + self._async_parse( + llm_output, + metadata, + llm_api=llm_api, + num_reasks=self.num_reasks, + prompt_params=prompt_params, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, + ) ) else: raise NotImplementedError( @@ -1075,7 +1110,7 @@ def use( return self @overload - def use_many(self, *validators: Validator, on: str = "output") -> "Async": ... + def use_many(self, *validators: Validator, on: str = "output") -> "AsyncGuard": ... 
@overload def use_many( diff --git a/guardrails/guard.py b/guardrails/guard.py index c5b873ae0..84b4858ba 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -5,13 +5,31 @@ import warnings from copy import deepcopy from string import Template -from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, - Optional, Sequence, Tuple, Type, Union, cast, overload) +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generic, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, + overload, +) from guardrails_api_client.models import AnyObject from guardrails_api_client.models import Guard as GuardModel -from guardrails_api_client.models import (History, HistoryEvent, - ValidatePayload, ValidationOutput) +from guardrails_api_client.models import ( + History, + HistoryEvent, + ValidatePayload, + ValidationOutput, +) from guardrails_api_client.types import UNSET from langchain_core.messages import BaseMessage from langchain_core.runnables import Runnable, RunnableConfig @@ -29,17 +47,25 @@ from guardrails.classes.history.iteration import Iteration from guardrails.classes.history.outputs import Outputs from guardrails.errors import ValidationError -from guardrails.llm_providers import (get_async_llm_ask, get_llm_api_enum, - get_llm_ask, - model_is_supported_server_side) +from guardrails.llm_providers import ( + get_async_llm_ask, + get_llm_api_enum, + get_llm_ask, + model_is_supported_server_side, +) from guardrails.logger import logger, set_scope from guardrails.prompt import Instructions, Prompt from guardrails.rail import Rail from guardrails.run import AsyncRunner, Runner, StreamRunner from guardrails.schema import Schema, StringSchema -from guardrails.stores.context import (Tracer, get_call_kwarg, - get_tracer_context, set_call_kwargs, - set_tracer, set_tracer_context) +from guardrails.stores.context import ( + Tracer, + get_call_kwarg, + get_tracer_context, + set_call_kwargs, + set_tracer, + set_tracer_context, +) from guardrails.utils.hub_telemetry_utils import HubTelemetry from guardrails.utils.llm_response import LLMResponse from guardrails.utils.reask_utils import FieldReAsk @@ -686,7 +712,18 @@ def _call_sync( call = runner(call_log=call_log, prompt_params=prompt_params) return ValidationOutcome[OT].from_guard_history(call) - + @deprecated( + """Async methods within Guard are deprecated and will be removed in 0.5.x. + Instead, please use `AsyncGuard() or pass in a sequential llm api.""", + category=FutureWarning, + stacklevel=2, + ) + @deprecated( + """Async methods within Guard are deprecated and will be removed in 0.5.x. + Instead, please use `AsyncGuard() or pass in a sequential llm api.""", + category=FutureWarning, + stacklevel=2, + ) async def _call_async( self, llm_api: Callable[[Any], Awaitable[Any]], @@ -719,12 +756,6 @@ async def _call_async( Returns: The raw text output from the LLM and the validated output. """ - warnings.warn( - "Using an async LLM client is deprecated in Guard and will " - "be removed in a future release. " - "Please use `AsyncGuard` or a non async llm client instead.", - DeprecationWarning, - ) instructions_obj = instructions or self.instructions prompt_obj = prompt or self.prompt msg_history_obj = msg_history or [] @@ -1009,6 +1040,12 @@ def _sync_parse( return ValidationOutcome[OT].from_guard_history(call) + @deprecated( + """Async methods within Guard are deprecated and will be removed in 0.5.x. 
+ Instead, please use `AsyncGuard() or pass in a sequential llm api.""", + category=FutureWarning, + stacklevel=2, + ) async def _async_parse( self, llm_output: str, diff --git a/tests/unit_tests/test_async_guard.py b/tests/unit_tests/test_async_guard.py new file mode 100644 index 000000000..281ca0949 --- /dev/null +++ b/tests/unit_tests/test_async_guard.py @@ -0,0 +1,601 @@ +import openai +import pytest +from pydantic import BaseModel + +from guardrails import AsyncGuard, Rail, Validator +from guardrails.datatypes import verify_metadata_requirements +from guardrails.utils import args, kwargs, on_fail +from guardrails.utils.openai_utils import OPENAI_VERSION +from guardrails.validator_base import OnFailAction +from guardrails.validators import ( # ReadingTime, + EndsWith, + LowerCase, + OneLine, + PassResult, + TwoWords, + UpperCase, + ValidLength, + register_validator, +) + + +@register_validator("myrequiringvalidator", data_type="string") +class RequiringValidator(Validator): + required_metadata_keys = ["required_key"] + + def validate(self, value, metadata): + return PassResult() + + +@register_validator("myrequiringvalidator2", data_type="string") +class RequiringValidator2(Validator): + required_metadata_keys = ["required_key2"] + + def validate(self, value, metadata): + return PassResult() + + +@pytest.mark.parametrize( + "spec,metadata,error_message", + [ + ( + """ + + + + + + """, + {"required_key": "a"}, + "Missing required metadata keys: required_key", + ), + ( + """ + + + + + + + + + + + """, + {"required_key": "a", "required_key2": "b"}, + "Missing required metadata keys: required_key, required_key2", + ), + ( + """ + + + + + + + + + + + + + + + + +""", + {"required_key": "a"}, + "Missing required metadata keys: required_key", + ), + ], +) +@pytest.mark.asyncio +@pytest.mark.skipif(not OPENAI_VERSION.startswith("0"), reason="Only for OpenAI v0") +async def test_required_metadata(spec, metadata, error_message): + guard = AsyncGuard.from_rail_string(spec) + + missing_keys = verify_metadata_requirements({}, guard.output_schema.root_datatype) + assert set(missing_keys) == set(metadata) + + not_missing_keys = verify_metadata_requirements( + metadata, guard.output_schema.root_datatype + ) + assert not_missing_keys == [] + + # test sync guard + with pytest.raises(ValueError) as excinfo: + guard.parse("{}") + assert str(excinfo.value) == error_message + + response = guard.parse("{}", metadata=metadata, num_reasks=0) + assert response.error is None + + # test async guard + with pytest.raises(ValueError) as excinfo: + guard.parse("{}") + await guard.parse("{}", llm_api=openai.ChatCompletion.acreate, num_reasks=0) + assert str(excinfo.value) == error_message + + response = await guard.parse( + "{}", metadata=metadata, llm_api=openai.ChatCompletion.acreate, num_reasks=0 + ) + assert response.error is None + + +rail = Rail.from_string_validators([], "empty railspec") +empty_rail_string = """ + +""" + + +class EmptyModel(BaseModel): + empty_field: str + + +i_guard_none = AsyncGuard(rail) +i_guard_two = AsyncGuard(rail, 2) +r_guard_none = AsyncGuard.from_rail("tests/unit_tests/test_assets/empty.rail") +r_guard_two = AsyncGuard.from_rail("tests/unit_tests/test_assets/empty.rail", 2) +rs_guard_none = AsyncGuard.from_rail_string(empty_rail_string) +rs_guard_two = AsyncGuard.from_rail_string(empty_rail_string, 2) +py_guard_none = AsyncGuard.from_pydantic(output_class=EmptyModel) +py_guard_two = AsyncGuard.from_pydantic(output_class=EmptyModel, num_reasks=2) +s_guard_none = 
AsyncGuard.from_string(validators=[], description="empty railspec") +s_guard_two = AsyncGuard.from_string( + validators=[], description="empty railspec", num_reasks=2 +) + + +@pytest.mark.parametrize( + "guard,expected_num_reasks,config_num_reasks", + [ + (i_guard_none, 1, None), + (i_guard_two, 2, None), + (i_guard_none, 3, 3), + (r_guard_none, 1, None), + (r_guard_two, 2, None), + (r_guard_none, 3, 3), + (rs_guard_none, 1, None), + (rs_guard_two, 2, None), + (rs_guard_none, 3, 3), + (py_guard_none, 1, None), + (py_guard_two, 2, None), + (py_guard_none, 3, 3), + (s_guard_none, 1, None), + (s_guard_two, 2, None), + (s_guard_none, 3, 3), + ], +) +def test_configure(guard: AsyncGuard, expected_num_reasks: int, config_num_reasks: int): + guard.configure(config_num_reasks) + assert guard.num_reasks == expected_num_reasks + + +def guard_init_from_rail(): + guard = AsyncGuard.from_rail("tests/unit_tests/test_assets/simple.rail") + assert ( + guard.instructions.format().source.strip() + == "You are a helpful bot, who answers only with valid JSON" + ) + assert guard.prompt.format().source.strip() == "Extract a string from the text" + + +def test_use(): + guard: AsyncGuard = ( + AsyncGuard() + .use(EndsWith("a")) + .use(OneLine()) + .use(LowerCase) + .use(TwoWords, on_fail=OnFailAction.REASK) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) + + # print(guard.__stringify__()) + assert len(guard._validators) == 5 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._kwargs["end"] == "a" + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.FIX + ) # bc this is the default + + assert isinstance(guard._validators[1], OneLine) + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[2], LowerCase) + assert ( + guard._validators[2].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[3], TwoWords) + assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it + + assert isinstance(guard._validators[4], ValidLength) + assert guard._validators[4]._min == 0 + assert guard._validators[4]._kwargs["min"] == 0 + assert guard._validators[4]._max == 12 + assert guard._validators[4]._kwargs["max"] == 12 + assert ( + guard._validators[4].on_fail_descriptor == OnFailAction.REFRAIN + ) # bc we set it + + # Raises error when trying to `use` a validator on a non-string + with pytest.raises(RuntimeError): + + class TestClass(BaseModel): + another_field: str + + py_guard = AsyncGuard.from_pydantic(output_class=TestClass) + py_guard.use( + EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK) + ) + + # Use a combination of prompt, instructions, msg_history and output validators + # Should only have the output validators in the guard, + # everything else is in the schema + guard: AsyncGuard = ( + AsyncGuard() + .use(LowerCase, on="prompt") + .use(OneLine, on="prompt") + .use(UpperCase, on="instructions") + .use(LowerCase, on="msg_history") + .use( + EndsWith, end="a", on="output" + ) # default on="output", still explicitly set + .use( + TwoWords, on_fail=OnFailAction.REASK + ) # default on="output", implicitly set + ) + + # Check schemas for prompt, instructions and msg_history validators + prompt_validators = guard.rail.prompt_schema.root_datatype.validators + assert len(prompt_validators) == 2 + assert prompt_validators[0].__class__.__name__ == "LowerCase" + assert 
prompt_validators[1].__class__.__name__ == "OneLine" + + instructions_validators = guard.rail.instructions_schema.root_datatype.validators + assert len(instructions_validators) == 1 + assert instructions_validators[0].__class__.__name__ == "UpperCase" + + msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators + assert len(msg_history_validators) == 1 + assert msg_history_validators[0].__class__.__name__ == "LowerCase" + + # Check guard for output validators + assert len(guard._validators) == 2 # only 2 output validators, hence 2 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._kwargs["end"] == "a" + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.FIX + ) # bc this is the default + + assert isinstance(guard._validators[1], TwoWords) + assert guard._validators[1].on_fail_descriptor == OnFailAction.REASK # bc we set it + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: AsyncGuard = ( + AsyncGuard() + .use(EndsWith("a"), on="response") # invalid on parameter + .use(OneLine, on="prompt") # valid on parameter + ) + + +def test_use_many_instances(): + guard: AsyncGuard = AsyncGuard().use_many( + EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK) + ) + + # print(guard.__stringify__()) + assert len(guard._validators) == 4 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._end == "a" + assert guard._validators[0]._kwargs["end"] == "a" + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.FIX + ) # bc this is the default + + assert isinstance(guard._validators[1], OneLine) + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[2], LowerCase) + assert ( + guard._validators[2].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[3], TwoWords) + assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it + + # Raises error when trying to `use_many` a validator on a non-string + with pytest.raises(RuntimeError): + + class TestClass(BaseModel): + another_field: str + + py_guard = AsyncGuard.from_pydantic(output_class=TestClass) + py_guard.use_many( + [ + EndsWith("a"), + OneLine(), + LowerCase(), + TwoWords(on_fail=OnFailAction.REASK), + ] + ) + + # Test with explicitly setting the "on" parameter = "output" + guard: AsyncGuard = AsyncGuard().use_many( + EndsWith("a"), + OneLine(), + LowerCase(), + TwoWords(on_fail=OnFailAction.REASK), + on="output", + ) + + assert len(guard._validators) == 4 # still 4 output validators, hence 4 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._end == "a" + assert guard._validators[0]._kwargs["end"] == "a" + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.FIX + ) # bc this is the default + + assert isinstance(guard._validators[1], OneLine) + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[2], LowerCase) + assert ( + guard._validators[2].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[3], TwoWords) + assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it + + # Test with explicitly setting the "on" parameter = "prompt" + guard: AsyncGuard = AsyncGuard().use_many( + 
OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="prompt" + ) + + prompt_validators = guard.rail.prompt_schema.root_datatype.validators + assert len(prompt_validators) == 3 + assert prompt_validators[0].__class__.__name__ == "OneLine" + assert prompt_validators[1].__class__.__name__ == "LowerCase" + assert prompt_validators[2].__class__.__name__ == "TwoWords" + assert len(guard._validators) == 0 # no output validators, hence 0 + + # Test with explicitly setting the "on" parameter = "instructions" + guard: AsyncGuard = AsyncGuard().use_many( + OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="instructions" + ) + + instructions_validators = guard.rail.instructions_schema.root_datatype.validators + assert len(instructions_validators) == 3 + assert instructions_validators[0].__class__.__name__ == "OneLine" + assert instructions_validators[1].__class__.__name__ == "LowerCase" + assert instructions_validators[2].__class__.__name__ == "TwoWords" + assert len(guard._validators) == 0 # no output validators, hence 0 + + # Test with explicitly setting the "on" parameter = "msg_history" + guard: AsyncGuard = AsyncGuard().use_many( + OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="msg_history" + ) + + msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators + assert len(msg_history_validators) == 3 + assert msg_history_validators[0].__class__.__name__ == "OneLine" + assert msg_history_validators[1].__class__.__name__ == "LowerCase" + assert msg_history_validators[2].__class__.__name__ == "TwoWords" + assert len(guard._validators) == 0 # no output validators, hence 0 + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: AsyncGuard = AsyncGuard().use_many( + EndsWith("a", on_fail=OnFailAction.EXCEPTION), OneLine(), on="response" + ) + + +def test_use_many_tuple(): + guard: AsyncGuard = AsyncGuard().use_many( + OneLine, + (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}), + (LowerCase, kwargs(on_fail=OnFailAction.FIX_REASK, some_other_kwarg="kwarg")), + (TwoWords, on_fail(OnFailAction.REASK)), + (ValidLength, args(0, 12), kwargs(on_fail=OnFailAction.REFRAIN)), + ) + + # print(guard.__stringify__()) + assert len(guard._validators) == 5 + + assert isinstance(guard._validators[0], OneLine) + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + assert isinstance(guard._validators[1], EndsWith) + assert guard._validators[1]._end == "a" + assert guard._validators[1]._kwargs["end"] == "a" + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.EXCEPTION + ) # bc we set it + + assert isinstance(guard._validators[2], LowerCase) + assert guard._validators[2]._kwargs["some_other_kwarg"] == "kwarg" + assert ( + guard._validators[2].on_fail_descriptor == OnFailAction.FIX_REASK + ) # bc this is the default + + assert isinstance(guard._validators[3], TwoWords) + assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it + + assert isinstance(guard._validators[4], ValidLength) + assert guard._validators[4]._min == 0 + assert guard._validators[4]._kwargs["min"] == 0 + assert guard._validators[4]._max == 12 + assert guard._validators[4]._kwargs["max"] == 12 + assert ( + guard._validators[4].on_fail_descriptor == OnFailAction.REFRAIN + ) # bc we set it + + # Test with explicitly setting the "on" parameter + guard: AsyncGuard = AsyncGuard().use_many( + (EndsWith, ["a"], {"on_fail": 
OnFailAction.EXCEPTION}), + OneLine, + on="output", + ) + + assert len(guard._validators) == 2 # only 2 output validators, hence 2 + + assert isinstance(guard._validators[0], EndsWith) + assert guard._validators[0]._end == "a" + assert guard._validators[0]._kwargs["end"] == "a" + assert ( + guard._validators[0].on_fail_descriptor == OnFailAction.EXCEPTION + ) # bc we set it + + assert isinstance(guard._validators[1], OneLine) + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.NOOP + ) # bc this is the default + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: AsyncGuard = AsyncGuard().use_many( + (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}), + OneLine, + on="response", + ) + + +def test_validate(): + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine) + .use( + LowerCase(on_fail=OnFailAction.FIX), on="output" + ) # default on="output", still explicitly set + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) + + llm_output: str = "Oh Canada" # bc it meets our criteria + + response = guard.validate(llm_output) + + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() + + llm_output_2 = "Star Spangled Banner" # to stick with the theme + + response_2 = guard.validate(llm_output_2) + + assert response_2.validation_passed is False + assert response_2.validated_output is None + + # Test with a combination of prompt, output, instructions and msg_history validators + # Should still only use the output validators to validate the output + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine, on="prompt") + .use(LowerCase, on="instructions") + .use(UpperCase, on="msg_history") + .use(LowerCase, on="output", on_fail=OnFailAction.FIX) + .use(TwoWords, on="output") + .use(ValidLength, 0, 12, on="output") + ) + + llm_output: str = "Oh Canada" # bc it meets our criteria + + response = guard.validate(llm_output) + + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() + + llm_output_2 = "Star Spangled Banner" # to stick with the theme + + response_2 = guard.validate(llm_output_2) + + assert response_2.validation_passed is False + assert response_2.validated_output is None + + +def test_use_and_use_many(): + guard: AsyncGuard = ( + AsyncGuard() + .use_many(OneLine(), LowerCase(), on="prompt") + .use(UpperCase, on="instructions") + .use(LowerCase, on="msg_history") + .use_many( + TwoWords(on_fail=OnFailAction.REASK), + ValidLength(0, 12, on_fail=OnFailAction.REFRAIN), + on="output", + ) + ) + + # Check schemas for prompt, instructions and msg_history validators + prompt_validators = guard.rail.prompt_schema.root_datatype.validators + assert len(prompt_validators) == 2 + assert prompt_validators[0].__class__.__name__ == "OneLine" + assert prompt_validators[1].__class__.__name__ == "LowerCase" + + instructions_validators = guard.rail.instructions_schema.root_datatype.validators + assert len(instructions_validators) == 1 + assert instructions_validators[0].__class__.__name__ == "UpperCase" + + msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators + assert len(msg_history_validators) == 1 + assert msg_history_validators[0].__class__.__name__ == "LowerCase" + + # Check guard for output validators + assert len(guard._validators) == 2 # only 2 output validators, hence 2 + + assert isinstance(guard._validators[0], TwoWords) + assert guard._validators[0].on_fail_descriptor == 
OnFailAction.REASK # bc we set it + + assert isinstance(guard._validators[1], ValidLength) + assert guard._validators[1]._min == 0 + assert guard._validators[1]._kwargs["min"] == 0 + assert guard._validators[1]._max == 12 + assert guard._validators[1]._kwargs["max"] == 12 + assert ( + guard._validators[1].on_fail_descriptor == OnFailAction.REFRAIN + ) # bc we set it + + # Test with an invalid "on" parameter, should raise a ValueError + with pytest.raises(ValueError): + guard: AsyncGuard = ( + AsyncGuard() + .use_many(OneLine(), LowerCase(), on="prompt") + .use(UpperCase, on="instructions") + .use(LowerCase, on="msg_history") + .use_many( + TwoWords(on_fail=OnFailAction.REASK), + ValidLength(0, 12, on_fail=OnFailAction.REFRAIN), + on="response", # invalid "on" parameter + ) + ) + + +# def test_call(): +# five_seconds = 5 / 60 +# response = AsyncGuard().use_many( +# ReadingTime(five_seconds, on_fail=OnFailAction.EXCEPTION), +# OneLine, +# (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}), +# (LowerCase, kwargs(on_fail=OnFailAction.FIX_REASK, some_other_kwarg="kwarg")), +# (TwoWords, on_fail(OnFailAction.REASK)), +# (ValidLength, args(0, 12), kwargs(on_fail=OnFailAction.REFRAIN)), +# )("Oh Canada") + +# assert response.validation_passed is True +# assert response.validated_output == "oh canada" From 58f8e2d4130a5903934181f3a54d24d4e63b0516 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Mon, 13 May 2024 13:41:59 -0700 Subject: [PATCH 3/9] removing extra deprecation warning --- guardrails/guard.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index 84b4858ba..a0f51f7d0 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -712,12 +712,6 @@ def _call_sync( call = runner(call_log=call_log, prompt_params=prompt_params) return ValidationOutcome[OT].from_guard_history(call) - @deprecated( - """Async methods within Guard are deprecated and will be removed in 0.5.x. - Instead, please use `AsyncGuard() or pass in a sequential llm api.""", - category=FutureWarning, - stacklevel=2, - ) @deprecated( """Async methods within Guard are deprecated and will be removed in 0.5.x. 
Instead, please use `AsyncGuard() or pass in a sequential llm api.""", From 1442c485fa1134d3b3e7b67be78faca042fc1e47 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Mon, 13 May 2024 13:43:02 -0700 Subject: [PATCH 4/9] removing unused test --- tests/unit_tests/test_async_guard.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/unit_tests/test_async_guard.py b/tests/unit_tests/test_async_guard.py index 281ca0949..1e02f750c 100644 --- a/tests/unit_tests/test_async_guard.py +++ b/tests/unit_tests/test_async_guard.py @@ -584,18 +584,3 @@ def test_use_and_use_many(): on="response", # invalid "on" parameter ) ) - - -# def test_call(): -# five_seconds = 5 / 60 -# response = AsyncGuard().use_many( -# ReadingTime(five_seconds, on_fail=OnFailAction.EXCEPTION), -# OneLine, -# (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}), -# (LowerCase, kwargs(on_fail=OnFailAction.FIX_REASK, some_other_kwarg="kwarg")), -# (TwoWords, on_fail(OnFailAction.REASK)), -# (ValidLength, args(0, 12), kwargs(on_fail=OnFailAction.REFRAIN)), -# )("Oh Canada") - -# assert response.validation_passed is True -# assert response.validated_output == "oh canada" From f80285a1b9c490a12e5c7b39be8836fbe89de19f Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Mon, 13 May 2024 14:18:32 -0700 Subject: [PATCH 5/9] fixing init --- guardrails/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/guardrails/__init__.py b/guardrails/__init__.py index 8883668f0..b1d081fca 100644 --- a/guardrails/__init__.py +++ b/guardrails/__init__.py @@ -11,6 +11,7 @@ __all__ = [ "Guard", + "AsyncGuard", "PromptCallableBase", "Rail", "Validator", From af6b72eb9ba95042179f92266fd7b9f9c3a57439 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 16 May 2024 15:07:35 -0700 Subject: [PATCH 6/9] Updating Async guard to inherit from guard --- guardrails/async_guard.py | 886 +------------------------------------- 1 file changed, 13 insertions(+), 873 deletions(-) diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index 09affb0fc..607e88776 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -1,80 +1,37 @@ import asyncio import contextvars import inspect -import json -import os -import warnings -from copy import deepcopy -from string import Template from typing import ( Any, Awaitable, Callable, Dict, - Generic, Iterable, List, Optional, - Sequence, - Tuple, - Type, Union, - cast, overload, ) -from guardrails_api_client.models import AnyObject -from guardrails_api_client.models import Guard as GuardModel -from guardrails_api_client.models import ( - History, - HistoryEvent, - ValidatePayload, - ValidationOutput, -) -from guardrails_api_client.types import UNSET -from langchain_core.messages import BaseMessage -from langchain_core.runnables import Runnable, RunnableConfig -from pydantic import BaseModel -from pydantic.version import VERSION as PYDANTIC_VERSION -from typing_extensions import deprecated -from guardrails.api_client import GuardrailsApiClient -from guardrails.classes import OT, InputType, ValidationOutcome -from guardrails.classes.credentials import Credentials -from guardrails.classes.generic import Stack +from guardrails import Guard +from guardrails.classes import OT, ValidationOutcome from guardrails.classes.history import Call from guardrails.classes.history.call_inputs import CallInputs -from guardrails.classes.history.inputs 
import Inputs -from guardrails.classes.history.iteration import Iteration -from guardrails.classes.history.outputs import Outputs -from guardrails.errors import ValidationError from guardrails.llm_providers import ( get_async_llm_ask, - get_llm_api_enum, - get_llm_ask, model_is_supported_server_side, ) -from guardrails.logger import logger, set_scope -from guardrails.prompt import Instructions, Prompt -from guardrails.rail import Rail +from guardrails.logger import set_scope from guardrails.run import AsyncRunner -from guardrails.schema import Schema, StringSchema from guardrails.stores.context import ( - Tracer, - get_call_kwarg, - get_tracer_context, set_call_kwargs, set_tracer, set_tracer_context, ) -from guardrails.utils.hub_telemetry_utils import HubTelemetry -from guardrails.utils.llm_response import LLMResponse -from guardrails.utils.reask_utils import FieldReAsk -from guardrails.utils.validator_utils import get_validator -from guardrails.validator_base import FailResult, Validator -class AsyncGuard(Runnable, Generic[OT]): +class AsyncGuard(Guard): """The Guard class. This class one of the main entry point for using Guardrails. It is @@ -91,370 +48,6 @@ class AsyncGuard(Runnable, Generic[OT]): the LLM and the validated output stream. """ - _tracer = None - _tracer_context = None - _hub_telemetry = None - _guard_id = None - _user_id = None - _validators: List[Validator] - _api_client: Optional[GuardrailsApiClient] = None - - def __init__( - self, - rail: Optional[Rail] = None, - num_reasks: Optional[int] = None, - base_model: Optional[ - Union[Type[BaseModel], Type[List[Type[BaseModel]]]] - ] = None, - tracer: Optional[Tracer] = None, - *, - name: Optional[str] = None, - description: Optional[str] = None, - ): - """Initialize the Guard with optional Rail instance, num_reasks, and - base_model.""" - if not rail: - rail = ( - Rail.from_pydantic(base_model) - if base_model - else Rail.from_string_validators([]) - ) - self.rail = rail - self.num_reasks = num_reasks - # TODO: Support a sink for history so that it is not solely held in memory - self.history: Stack[Call] = Stack() - self.base_model = base_model - self._set_tracer(tracer) - - credentials = Credentials.from_rc_file(logger) - - # Get unique id of user from credentials - self._user_id = credentials.id or "" - - # Get metrics opt-out from credentials - self._disable_tracer = credentials.no_metrics - - # Get id of guard object (that is unique) - self._guard_id = id(self) # id of guard object; not the class - - # Initialize Hub Telemetry singleton and get the tracer - # if it is not disabled - if not self._disable_tracer: - self._hub_telemetry = HubTelemetry() - self._validators = [] - - # Gaurdrails As A Service Initialization - self.description = description - self.name = name - - api_key = os.environ.get("GUARDRAILS_API_KEY") - if api_key is not None: - if self.name is None: - self.name = f"gr-{str(self._guard_id)}" - logger.warn("Warning: No name passed to guard!") - logger.warn( - "Use this auto-generated name to re-use this guard: {name}".format( - name=self.name - ) - ) - self._api_client = GuardrailsApiClient(api_key=api_key) - self.upsert_guard() - - @property - def prompt_schema(self) -> Optional[StringSchema]: - """Return the input schema.""" - return self.rail.prompt_schema - - @property - def instructions_schema(self) -> Optional[StringSchema]: - """Return the input schema.""" - return self.rail.instructions_schema - - @property - def msg_history_schema(self) -> Optional[StringSchema]: - """Return the input schema.""" 
- return self.rail.msg_history_schema - - @property - def output_schema(self) -> Schema: - """Return the output schema.""" - return self.rail.output_schema - - @property - def instructions(self) -> Optional[Instructions]: - """Return the instruction-prompt.""" - return self.rail.instructions - - @property - def prompt(self) -> Optional[Prompt]: - """Return the prompt.""" - return self.rail.prompt - - @property - def raw_prompt(self) -> Optional[Prompt]: - """Return the prompt, alias for `prompt`.""" - return self.prompt - - @property - def base_prompt(self) -> Optional[str]: - """Return the base prompt i.e. prompt.source.""" - if self.prompt is None: - return None - return self.prompt.source - - @property - def reask_prompt(self) -> Optional[Prompt]: - """Return the reask prompt.""" - return self.output_schema.reask_prompt_template - - @reask_prompt.setter - def reask_prompt(self, reask_prompt: Optional[str]): - """Set the reask prompt.""" - self.output_schema.reask_prompt_template = reask_prompt - - @property - def reask_instructions(self) -> Optional[Instructions]: - """Return the reask prompt.""" - return self.output_schema.reask_instructions_template - - @reask_instructions.setter - def reask_instructions(self, reask_instructions: Optional[str]): - """Set the reask prompt.""" - self.output_schema.reask_instructions_template = reask_instructions - - def configure( - self, - num_reasks: Optional[int] = None, - ): - """Configure the Guard.""" - self.num_reasks = ( - num_reasks - if num_reasks is not None - else self.num_reasks - if self.num_reasks is not None - else 1 - ) - - def _set_tracer(self, tracer: Optional[Tracer] = None) -> None: - self._tracer = tracer - set_tracer(tracer) - set_tracer_context() - self._tracer_context = get_tracer_context() - - @classmethod - def from_rail( - cls, - rail_file: str, - num_reasks: Optional[int] = None, - tracer: Optional[Tracer] = None, - *, - name: Optional[str] = None, - description: Optional[str] = None, - ): - """Create a Schema from a `.rail` file. - - Args: - rail_file: The path to the `.rail` file. - num_reasks: The max times to re-ask the LLM for invalid output. - - Returns: - An instance of the `Guard` class. - """ - - # We have to set the tracer in the ContextStore before the Rail, - # and therefore the Validators, are initialized - cls._set_tracer(cls, tracer) # type: ignore - - rail = Rail.from_file(rail_file) - if rail.output_type == "str": - return cast( - AsyncGuard[str], - cls( - rail=rail, - num_reasks=num_reasks, - tracer=tracer, - name=name, - description=description, - ), - ) - elif rail.output_type == "list": - return cast( - AsyncGuard[List], - cls( - rail=rail, - num_reasks=num_reasks, - tracer=tracer, - name=name, - description=description, - ), - ) - return cast( - AsyncGuard[Dict], - cls( - rail=rail, - num_reasks=num_reasks, - tracer=tracer, - name=name, - description=description, - ), - ) - - @classmethod - def from_rail_string( - cls, - rail_string: str, - num_reasks: Optional[int] = None, - tracer: Optional[Tracer] = None, - *, - name: Optional[str] = None, - description: Optional[str] = None, - ): - """Create a Schema from a `.rail` string. - - Args: - rail_string: The `.rail` string. - num_reasks: The max times to re-ask the LLM for invalid output. - - Returns: - An instance of the `Guard` class. 
- """ - # We have to set the tracer in the ContextStore before the Rail, - # and therefore the Validators, are initialized - cls._set_tracer(cls, tracer) # type: ignore - - rail = Rail.from_string(rail_string) - if rail.output_type == "str": - return cast( - AsyncGuard[str], - cls( - rail=rail, - num_reasks=num_reasks, - tracer=tracer, - name=name, - description=description, - ), - ) - elif rail.output_type == "list": - return cast( - AsyncGuard[List], - cls( - rail=rail, - num_reasks=num_reasks, - tracer=tracer, - name=name, - description=description, - ), - ) - return cast( - AsyncGuard[Dict], - cls( - rail=rail, - num_reasks=num_reasks, - tracer=tracer, - name=name, - description=description, - ), - ) - - @classmethod - def from_pydantic( - cls, - output_class: Union[Type[BaseModel], Type[List[Type[BaseModel]]]], - prompt: Optional[str] = None, - instructions: Optional[str] = None, - num_reasks: Optional[int] = None, - reask_prompt: Optional[str] = None, - reask_instructions: Optional[str] = None, - tracer: Optional[Tracer] = None, - *, - name: Optional[str] = None, - description: Optional[str] = None, - ): - """Create a Guard instance from a Pydantic model and prompt.""" - if PYDANTIC_VERSION.startswith("1"): - warnings.warn( - """Support for Pydantic v1.x is deprecated and will be removed in - Guardrails 0.5.x. Please upgrade to the latest Pydantic v2.x to - continue receiving future updates and support.""", - FutureWarning, - ) - # We have to set the tracer in the ContextStore before the Rail, - # and therefore the Validators, are initialized - cls._set_tracer(cls, tracer) # type: ignore - - rail = Rail.from_pydantic( - output_class=output_class, - prompt=prompt, - instructions=instructions, - reask_prompt=reask_prompt, - reask_instructions=reask_instructions, - ) - if rail.output_type == "list": - return cast( - AsyncGuard[List], - cls(rail, num_reasks=num_reasks, base_model=output_class), - ) - return cast( - AsyncGuard[Dict], - cls( - rail, - num_reasks=num_reasks, - base_model=output_class, - tracer=tracer, - name=name, - description=description, - ), - ) - - @classmethod - def from_string( - cls, - validators: Sequence[Validator], - description: Optional[str] = None, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - reask_prompt: Optional[str] = None, - reask_instructions: Optional[str] = None, - num_reasks: Optional[int] = None, - tracer: Optional[Tracer] = None, - *, - name: Optional[str] = None, - guard_description: Optional[str] = None, - ): - """Create a Guard instance for a string response with prompt, - instructions, and validations. - - Args: - validators: (List[Validator]): The list of validators to apply to the string output. - description (str, optional): A description for the string to be generated. Defaults to None. - prompt (str, optional): The prompt used to generate the string. Defaults to None. - instructions (str, optional): Instructions for chat models. Defaults to None. - reask_prompt (str, optional): An alternative prompt to use during reasks. Defaults to None. - reask_instructions (str, optional): Alternative instructions to use during reasks. Defaults to None. - num_reasks (int, optional): The max times to re-ask the LLM for invalid output. 
- """ # noqa - - cls._set_tracer(cls, tracer) # type: ignore - - rail = Rail.from_string_validators( - validators=validators, - description=description, - prompt=prompt, - instructions=instructions, - reask_prompt=reask_prompt, - reask_instructions=reask_instructions, - ) - return cast( - AsyncGuard[str], - cls( - rail, - num_reasks=num_reasks, - tracer=tracer, - name=name, - description=guard_description, - ), - ) - @overload def __call__( self, @@ -707,71 +300,6 @@ def _call_async( ) return ValidationOutcome[OT].from_guard_history(call) - def __repr__(self): - return f"Guard(RAIL={self.rail})" - - def __rich_repr__(self): - yield "RAIL", self.rail - - def __stringify__(self): - if self.rail and self.rail.output_type == "str": - template = Template( - """ - Guard { - validators: [ - ${validators} - ] - } - """ - ) - return template.safe_substitute( - { - "validators": ",\n".join( - [v.__stringify__() for v in self._validators] - ) - } - ) - return self.__repr__() - - @overload - def parse( - self, - llm_output: str, - metadata: Optional[Dict] = None, - llm_api: None = None, - num_reasks: Optional[int] = None, - prompt_params: Optional[Dict] = None, - full_schema_reask: Optional[bool] = None, - *args, - **kwargs, - ) -> ValidationOutcome[OT]: ... - - @overload - def parse( - self, - llm_output: str, - metadata: Optional[Dict] = None, - llm_api: Callable[[Any], Awaitable[Any]] = ..., - num_reasks: Optional[int] = None, - prompt_params: Optional[Dict] = None, - full_schema_reask: Optional[bool] = None, - *args, - **kwargs, - ) -> Awaitable[ValidationOutcome[OT]]: ... - - @overload - def parse( - self, - llm_output: str, - metadata: Optional[Dict] = None, - llm_api: Optional[Callable] = None, - num_reasks: Optional[int] = None, - prompt_params: Optional[Dict] = None, - full_schema_reask: Optional[bool] = None, - *args, - **kwargs, - ) -> ValidationOutcome[OT]: ... - def parse( self, llm_output: str, @@ -881,10 +409,13 @@ def __parse( **kwargs, ) - # If the LLM API is async, return a coroutine - # FIXME: Add null llm_api support becuase we want to be able to call on - # static text (e.g. validate("foobar")). Is there a better way to do this? - if not llm_api or inspect.iscoroutinefunction(llm_api): + # FIXME: checking not llm_api because it can still fall back on defaults and + # function as expected. We should handle this better. + if ( + not llm_api + or inspect.iscoroutinefunction(llm_api) + or inspect.isasyncgenfunction(llm_api) + ): return asyncio.run( self._async_parse( llm_output, @@ -901,7 +432,8 @@ def __parse( else: raise NotImplementedError( "AsyncGuard does not support non-async LLM APIs. " - "Please use the sync API Guard or supply an async LLM API." + "Please use the synchronous API Guard or supply an asynchronous " + "LLM API." ) guard_context = contextvars.Context() @@ -959,395 +491,3 @@ async def _async_parse( call = await runner.async_run(call_log=call_log, prompt_params=prompt_params) return ValidationOutcome[OT].from_guard_history(call) - - @deprecated( - """The `with_prompt_validation` method is deprecated, - and will be removed in 0.5.x. Instead, please use - `Guard().use(YourValidator, on='prompt')`.""", - category=FutureWarning, - stacklevel=2, - ) - def with_prompt_validation( - self, - validators: Sequence[Validator], - ): - """Add prompt validation to the Guard. - - Args: - validators: The validators to add to the prompt. 
- """ - if self.rail.prompt_schema: - warnings.warn("Overriding existing prompt validators.") - schema = StringSchema.from_string( - validators=validators, - ) - self.rail.prompt_schema = schema - return self - - @deprecated( - """The `with_instructions_validation` method is deprecated, - and will be removed in 0.5.x. Instead, please use - `Guard().use(YourValidator, on='instructions')`.""", - category=FutureWarning, - stacklevel=2, - ) - def with_instructions_validation( - self, - validators: Sequence[Validator], - ): - """Add instructions validation to the Guard. - - Args: - validators: The validators to add to the instructions. - """ - if self.rail.instructions_schema: - warnings.warn("Overriding existing instructions validators.") - schema = StringSchema.from_string( - validators=validators, - ) - self.rail.instructions_schema = schema - return self - - @deprecated( - """The `with_msg_history_validation` method is deprecated, - and will be removed in 0.5.x. Instead, please use - `Guard().use(YourValidator, on='msg_history')`.""", - category=FutureWarning, - stacklevel=2, - ) - def with_msg_history_validation( - self, - validators: Sequence[Validator], - ): - """Add msg_history validation to the Guard. - - Args: - validators: The validators to add to the msg_history. - """ - if self.rail.msg_history_schema: - warnings.warn("Overriding existing msg_history validators.") - schema = StringSchema.from_string( - validators=validators, - ) - self.rail.msg_history_schema = schema - return self - - def __add_validator(self, validator: Validator, on: str = "output"): - # Only available for string output types - if self.rail.output_type != "str": - raise RuntimeError( - "The `use` method is only available for string output types." - ) - - if on == "prompt": - # If the prompt schema exists, add the validator to it - if self.rail.prompt_schema: - self.rail.prompt_schema.root_datatype.validators.append(validator) - else: - # Otherwise, create a new schema with the validator - schema = StringSchema.from_string( - validators=[validator], - ) - self.rail.prompt_schema = schema - elif on == "instructions": - # If the instructions schema exists, add the validator to it - if self.rail.instructions_schema: - self.rail.instructions_schema.root_datatype.validators.append(validator) - else: - # Otherwise, create a new schema with the validator - schema = StringSchema.from_string( - validators=[validator], - ) - self.rail.instructions_schema = schema - elif on == "msg_history": - # If the msg_history schema exists, add the validator to it - if self.rail.msg_history_schema: - self.rail.msg_history_schema.root_datatype.validators.append(validator) - else: - # Otherwise, create a new schema with the validator - schema = StringSchema.from_string( - validators=[validator], - ) - self.rail.msg_history_schema = schema - elif on == "output": - self._validators.append(validator) - self.rail.output_schema.root_datatype.validators.append(validator) - else: - raise ValueError( - """Invalid value for `on`. Must be one of the following: - 'output', 'prompt', 'instructions', 'msg_history'.""" - ) - - @overload - def use(self, validator: Validator, *, on: str = "output") -> "AsyncGuard": ... - - @overload - def use( - self, validator: Type[Validator], *args, on: str = "output", **kwargs - ) -> "AsyncGuard": ... 
- - def use( - self, - validator: Union[Validator, Type[Validator]], - *args, - on: str = "output", - **kwargs, - ) -> "AsyncGuard": - """Use a validator to validate either of the following: - - The output of an LLM request - - The prompt - - The instructions - - The message history - - *Note*: For on="output", `use` is only available for string output types. - - Args: - validator: The validator to use. Either the class or an instance. - on: The part of the LLM request to validate. Defaults to "output". - """ - hydrated_validator = get_validator(validator, *args, **kwargs) - self.__add_validator(hydrated_validator, on=on) - return self - - @overload - def use_many(self, *validators: Validator, on: str = "output") -> "AsyncGuard": ... - - @overload - def use_many( - self, - *validators: Tuple[ - Type[Validator], - Optional[Union[List[Any], Dict[str, Any]]], - Optional[Dict[str, Any]], - ], - on: str = "output", - ) -> "AsyncGuard": ... - - def use_many( - self, - *validators: Union[ - Validator, - Tuple[ - Type[Validator], - Optional[Union[List[Any], Dict[str, Any]]], - Optional[Dict[str, Any]], - ], - ], - on: str = "output", - ) -> "AsyncGuard": - """Use a validator to validate results of an LLM request. - - *Note*: `use_many` is only available for string output types. - """ - if self.rail.output_type != "str": - raise RuntimeError( - "The `use_many` method is only available for string output types." - ) - - # Loop through the validators - for v in validators: - hydrated_validator = get_validator(v) - self.__add_validator(hydrated_validator, on=on) - return self - - def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]: - if ( - not self.rail - or self.rail.output_schema.root_datatype.validators != self._validators - ): - self.rail = Rail.from_string_validators( - validators=self._validators, - prompt=self.prompt.source if self.prompt else None, - instructions=self.instructions.source if self.instructions else None, - reask_prompt=self.reask_prompt.source if self.reask_prompt else None, - reask_instructions=self.reask_instructions.source - if self.reask_instructions - else None, - ) - - return self.parse(llm_output=llm_output, *args, **kwargs) - - # No call support for this until - # https://github.com/guardrails-ai/guardrails/pull/525 is merged - # def __call__(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]: - # return self.validate(llm_output, *args, **kwargs) - - def invoke( - self, input: InputType, config: Optional[RunnableConfig] = None - ) -> InputType: - output = BaseMessage(content="", type="") - str_input = None - input_is_chat_message = False - if isinstance(input, BaseMessage): - input_is_chat_message = True - str_input = str(input.content) - output = deepcopy(input) - else: - str_input = str(input) - - response = self.validate(str_input) - - validated_output = response.validated_output - if not validated_output: - raise ValidationError( - ( - "The response from the LLM failed validation!" - "See `guard.history` for more details." 
- ) - ) - - if isinstance(validated_output, Dict): - validated_output = json.dumps(validated_output) - - if input_is_chat_message: - output.content = validated_output - return cast(InputType, output) - return cast(InputType, validated_output) - - def _to_request(self) -> Dict: - return { - "name": self.name, - "description": self.description, - "railspec": self.rail._to_request(), - "numReasks": self.num_reasks, - } - - def upsert_guard(self): - if self._api_client: - guard_dict = self._to_request() - self._api_client.upsert_guard(GuardModel.from_dict(guard_dict)) - else: - raise ValueError("Guard does not have an api client!") - - def _call_server( - self, - *args, - llm_output: Optional[str] = None, - llm_api: Optional[Callable] = None, - num_reasks: Optional[int] = None, - prompt_params: Optional[Dict] = None, - metadata: Optional[Dict] = {}, - full_schema_reask: Optional[bool] = True, - call_log: Optional[Call], - # prompt: Optional[str], - # instructions: Optional[str], - # msg_history: Optional[List[Dict]], - **kwargs, - ): - if self._api_client: - payload: Dict[str, Any] = {"args": list(args)} - payload.update(**kwargs) - if llm_output is not None: - payload["llmOutput"] = llm_output - if num_reasks is not None: - payload["numReasks"] = num_reasks - if prompt_params is not None: - payload["promptParams"] = prompt_params - if llm_api is not None: - payload["llmApi"] = get_llm_api_enum(llm_api) - # TODO: get enum for llm_api - validation_output: Optional[ValidationOutput] = self._api_client.validate( - guard=self, # type: ignore - payload=ValidatePayload.from_dict(payload), - openai_api_key=get_call_kwarg("api_key"), - ) - - if not validation_output: - return ValidationOutcome[OT]( - raw_llm_output=None, - validated_output=None, - validation_passed=False, - error="The response from the server was empty!", - ) - - call_log = call_log or Call() - if llm_api is not None: - llm_api = get_llm_ask(llm_api) - if asyncio.iscoroutinefunction(llm_api): - llm_api = get_async_llm_ask(llm_api) - session_history = ( - validation_output.session_history - if validation_output is not None and validation_output.session_history - else [] - ) - history: History - for history in session_history: - history_events: Optional[List[HistoryEvent]] = ( # type: ignore - history.history if history.history != UNSET else None - ) - if history_events is None: - continue - - iterations = [ - Iteration( - inputs=Inputs( - llm_api=llm_api, - llm_output=llm_output, - instructions=( - Instructions(h.instructions) if h.instructions else None - ), - prompt=( - Prompt(h.prompt.source) # type: ignore - if h.prompt is not None and h.prompt != UNSET - else None - ), - prompt_params=prompt_params, - num_reasks=(num_reasks or 0), - metadata=metadata, - full_schema_reask=full_schema_reask, - ), - outputs=Outputs( - llm_response_info=LLMResponse( - output=h.output # type: ignore - ), - raw_output=h.output, - parsed_output=( - h.parsed_output.to_dict() - if isinstance(h.parsed_output, AnyObject) - else h.parsed_output - ), - validation_output=( - h.validated_output.to_dict() - if isinstance(h.validated_output, AnyObject) - else h.validated_output - ), - reasks=list( - [ - FieldReAsk( - incorrect_value=r.to_dict().get( - "incorrect_value" - ), - path=r.to_dict().get("path"), - fail_results=[ - FailResult( - error_message=r.to_dict().get( - "error_message" - ), - fix_value=r.to_dict().get("fix_value"), - ) - ], - ) - for r in h.reasks # type: ignore - ] - if h.reasks != UNSET - else [] - ), - ), - ) - for h in history_events - ] - 
call_log.iterations.extend(iterations) - if self.history.length == 0: - self.history.push(call_log) - - # Our interfaces are too different for this to work right now. - # Once we move towards shared interfaces for both the open source - # and the api we can re-enable this. - # return ValidationOutcome[OT].from_guard_history(call_log) - return ValidationOutcome[OT]( - raw_llm_output=validation_output.raw_llm_response, # type: ignore - validated_output=cast(OT, validation_output.validated_output), - validation_passed=validation_output.result, - ) - else: - raise ValueError("Guard does not have an api client!") From 64f2bd185917602b09d14e7ffd852bc8f1073560 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 16 May 2024 15:09:01 -0700 Subject: [PATCH 7/9] fixing sequential -> synchronous in guard deprecations --- guardrails/guard.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index a0f51f7d0..d5f25ba7e 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -714,7 +714,7 @@ def _call_sync( @deprecated( """Async methods within Guard are deprecated and will be removed in 0.5.x. - Instead, please use `AsyncGuard() or pass in a sequential llm api.""", + Instead, please use `AsyncGuard() or pass in a synchronous llm api.""", category=FutureWarning, stacklevel=2, ) @@ -1036,7 +1036,7 @@ def _sync_parse( @deprecated( """Async methods within Guard are deprecated and will be removed in 0.5.x. - Instead, please use `AsyncGuard() or pass in a sequential llm api.""", + Instead, please use `AsyncGuard() or pass in a synchronous llm api.""", category=FutureWarning, stacklevel=2, ) @@ -1473,3 +1473,4 @@ def _call_server( ) else: raise ValueError("Guard does not have an api client!") + raise ValueError("Guard does not have an api client!") From 88cd5dcf5e8acfde3e0245086a0384f750801436 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 16 May 2024 16:40:31 -0700 Subject: [PATCH 8/9] __call__ is now async to match standard llm apis rather than handling internally --- guardrails/async_guard.py | 88 +++++++--------------------- tests/unit_tests/test_async_guard.py | 13 ++-- 2 files changed, 28 insertions(+), 73 deletions(-) diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index 607e88776..e689ac698 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -1,4 +1,3 @@ -import asyncio import contextvars import inspect from typing import ( @@ -10,25 +9,16 @@ List, Optional, Union, - overload, ) - from guardrails import Guard from guardrails.classes import OT, ValidationOutcome from guardrails.classes.history import Call from guardrails.classes.history.call_inputs import CallInputs -from guardrails.llm_providers import ( - get_async_llm_ask, - model_is_supported_server_side, -) +from guardrails.llm_providers import get_async_llm_ask, model_is_supported_server_side from guardrails.logger import set_scope from guardrails.run import AsyncRunner -from guardrails.stores.context import ( - set_call_kwargs, - set_tracer, - set_tracer_context, -) +from guardrails.stores.context import set_call_kwargs, set_tracer, set_tracer_context class AsyncGuard(Guard): @@ -48,38 +38,7 @@ class AsyncGuard(Guard): the LLM and the validated output stream. 
""" - @overload - def __call__( - self, - llm_api: Callable, - prompt_params: Optional[Dict] = None, - num_reasks: Optional[int] = None, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, - metadata: Optional[Dict] = None, - full_schema_reask: Optional[bool] = None, - stream: Optional[bool] = False, - *args, - **kwargs, - ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: ... - - @overload - def __call__( - self, - llm_api: Callable[[Any], Awaitable[Any]], - prompt_params: Optional[Dict] = None, - num_reasks: Optional[int] = None, - prompt: Optional[str] = None, - instructions: Optional[str] = None, - msg_history: Optional[List[Dict]] = None, - metadata: Optional[Dict] = None, - full_schema_reask: Optional[bool] = None, - *args, - **kwargs, - ) -> Awaitable[ValidationOutcome[OT]]: ... - - def __call__( + async def __call__( self, llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]], prompt_params: Optional[Dict] = None, @@ -116,7 +75,7 @@ def __call__( The raw text output from the LLM and the validated output. """ - def __call( + async def __call( self, llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]], prompt_params: Optional[Dict] = None, @@ -208,7 +167,7 @@ def __call( "Please use an async LLM API." ) # Otherwise, call the LLM - return self._call_async( + return await self._call_async( llm_api, prompt_params=prompt_params, num_reasks=self.num_reasks, @@ -223,7 +182,7 @@ def __call( ) guard_context = contextvars.Context() - return guard_context.run( + return await guard_context.run( __call, self, llm_api, @@ -238,7 +197,7 @@ def __call( **kwargs, ) - def _call_async( + async def _call_async( self, llm_api: Callable[[Any], Awaitable[Any]], prompt_params: Dict, @@ -295,12 +254,10 @@ def _call_async( full_schema_reask=full_schema_reask, disable_tracer=self._disable_tracer, ) - call = asyncio.run( - runner.async_run(call_log=call_log, prompt_params=prompt_params) - ) + call = await runner.async_run(call_log=call_log, prompt_params=prompt_params) return ValidationOutcome[OT].from_guard_history(call) - def parse( + async def parse( self, llm_output: str, metadata: Optional[Dict] = None, @@ -328,7 +285,7 @@ def parse( determined by the object schema defined in the RAILspec. """ - def __parse( + async def __parse( self, llm_output: str, metadata: Optional[Dict] = None, @@ -416,19 +373,18 @@ def __parse( or inspect.iscoroutinefunction(llm_api) or inspect.isasyncgenfunction(llm_api) ): - return asyncio.run( - self._async_parse( - llm_output, - metadata, - llm_api=llm_api, - num_reasks=self.num_reasks, - prompt_params=prompt_params, - full_schema_reask=full_schema_reask, - call_log=call_log, - *args, - **kwargs, - ) + return await self._async_parse( + llm_output, + metadata, + llm_api=llm_api, + num_reasks=self.num_reasks, + prompt_params=prompt_params, + full_schema_reask=full_schema_reask, + call_log=call_log, + *args, + **kwargs, ) + else: raise NotImplementedError( "AsyncGuard does not support non-async LLM APIs. 
" @@ -437,7 +393,7 @@ def __parse( ) guard_context = contextvars.Context() - return guard_context.run( + return await guard_context.run( __parse, self, llm_output, diff --git a/tests/unit_tests/test_async_guard.py b/tests/unit_tests/test_async_guard.py index 1e02f750c..0f4e33d6d 100644 --- a/tests/unit_tests/test_async_guard.py +++ b/tests/unit_tests/test_async_guard.py @@ -477,7 +477,8 @@ def test_use_many_tuple(): ) -def test_validate(): +@pytest.mark.asyncio +async def test_validate(): guard: AsyncGuard = ( AsyncGuard() .use(OneLine) @@ -489,15 +490,13 @@ def test_validate(): ) llm_output: str = "Oh Canada" # bc it meets our criteria - - response = guard.validate(llm_output) + response = await guard.validate(llm_output) assert response.validation_passed is True assert response.validated_output == llm_output.lower() - llm_output_2 = "Star Spangled Banner" # to stick with the theme - response_2 = guard.validate(llm_output_2) + response_2 = await guard.validate(llm_output_2) assert response_2.validation_passed is False assert response_2.validated_output is None @@ -516,14 +515,14 @@ def test_validate(): llm_output: str = "Oh Canada" # bc it meets our criteria - response = guard.validate(llm_output) + response = await guard.validate(llm_output) assert response.validation_passed is True assert response.validated_output == llm_output.lower() llm_output_2 = "Star Spangled Banner" # to stick with the theme - response_2 = guard.validate(llm_output_2) + response_2 = await guard.validate(llm_output_2) assert response_2.validation_passed is False assert response_2.validated_output is None From 2d95bcb343d7c4a571486ebe3de82860f4461066 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 16 May 2024 17:45:23 -0700 Subject: [PATCH 9/9] removing duplicated line --- guardrails/guard.py | 1 - 1 file changed, 1 deletion(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index d5f25ba7e..8a5d3dbd8 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1473,4 +1473,3 @@ def _call_server( ) else: raise ValueError("Guard does not have an api client!") - raise ValueError("Guard does not have an api client!")