From fe23bf9ace03d30f95bd7c3297acb66ad082ea55 Mon Sep 17 00:00:00 2001
From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com>
Date: Fri, 10 May 2024 14:04:33 -0700
Subject: [PATCH 1/9] initial commit for testing async streaming guard
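
A rough sketch of the call pattern this series is building toward (the Movie
model, prompt, and prompt_params below are illustrative placeholders, not
part of this diff):

    import openai
    from pydantic import BaseModel

    from guardrails import AsyncGuard


    class Movie(BaseModel):
        title: str
        year: int


    guard = AsyncGuard.from_pydantic(output_class=Movie, prompt="${question}")

    # Inside an async context: __call__ wraps an async LLM API and returns
    # the raw LLM output together with the validated output once all
    # validators (and any reasks) have run.
    outcome = await guard(
        openai.ChatCompletion.acreate,
        prompt_params={"question": "Name a movie released in 1994."},
        num_reasks=1,
    )
    print(outcome.validated_output)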
---
guardrails/__init__.py | 1 +
guardrails/async_guard.py | 1318 +++++++++++++++++++++++++++++++++++++
guardrails/guard.py | 63 +-
3 files changed, 1341 insertions(+), 41 deletions(-)
create mode 100644 guardrails/async_guard.py
diff --git a/guardrails/__init__.py b/guardrails/__init__.py
index 6e28dff7f..8883668f0 100644
--- a/guardrails/__init__.py
+++ b/guardrails/__init__.py
@@ -1,6 +1,7 @@
# Set up __init__.py so that users can do from guardrails import Response, Schema, etc.
from guardrails.guard import Guard
+from guardrails.async_guard import AsyncGuard
from guardrails.llm_providers import PromptCallableBase
from guardrails.logging_utils import configure_logging
from guardrails.prompt import Instructions, Prompt
diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py
new file mode 100644
index 000000000..4f986e56c
--- /dev/null
+++ b/guardrails/async_guard.py
@@ -0,0 +1,1318 @@
+import asyncio
+import contextvars
+import json
+import os
+import time
+import warnings
+from copy import deepcopy
+from string import Template
+from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
+ Optional, Sequence, Tuple, Type, Union, cast, overload)
+
+from guardrails_api_client.models import AnyObject
+from guardrails_api_client.models import Guard as GuardModel
+from guardrails_api_client.models import (History, HistoryEvent,
+ ValidatePayload, ValidationOutput)
+from guardrails_api_client.types import UNSET
+from langchain_core.messages import BaseMessage
+from langchain_core.runnables import Runnable, RunnableConfig
+from pydantic import BaseModel
+from pydantic.version import VERSION as PYDANTIC_VERSION
+from typing_extensions import deprecated
+
+from guardrails.api_client import GuardrailsApiClient
+from guardrails.classes import OT, InputType, ValidationOutcome
+from guardrails.classes.credentials import Credentials
+from guardrails.classes.generic import Stack
+from guardrails.classes.history import Call
+from guardrails.classes.history.call_inputs import CallInputs
+from guardrails.classes.history.inputs import Inputs
+from guardrails.classes.history.iteration import Iteration
+from guardrails.classes.history.outputs import Outputs
+from guardrails.errors import ValidationError
+from guardrails.llm_providers import (get_async_llm_ask, get_llm_api_enum,
+ get_llm_ask,
+ model_is_supported_server_side)
+from guardrails.logger import logger, set_scope
+from guardrails.prompt import Instructions, Prompt
+from guardrails.rail import Rail
+from guardrails.run import AsyncRunner, Runner, StreamRunner
+from guardrails.schema import Schema, StringSchema
+from guardrails.stores.context import (Tracer, get_call_kwarg,
+ get_tracer_context, set_call_kwargs,
+ set_tracer, set_tracer_context)
+from guardrails.utils.hub_telemetry_utils import HubTelemetry
+from guardrails.utils.llm_response import LLMResponse
+from guardrails.utils.reask_utils import FieldReAsk
+from guardrails.utils.validator_utils import get_validator
+from guardrails.validator_base import FailResult, Validator
+
+
+class AsyncGuard(Runnable, Generic[OT]):
+ """The Guard class.
+
+ This class one of the main entry point for using Guardrails. It is
+ initialized from one of the following class methods:
+
+ - `from_rail`
+ - `from_rail_string`
+ - `from_pydantic`
+ - `from_string`
+
+ The `__call__`
+ method functions as a wrapper around LLM APIs. It takes in an Async LLM
+ API, and optional prompt parameters, and returns the raw output stream from
+ the LLM and the validated output stream.
+ """
+
+ _tracer = None
+ _tracer_context = None
+ _hub_telemetry = None
+ _guard_id = None
+ _user_id = None
+ _validators: List[Validator]
+ _api_client: Optional[GuardrailsApiClient] = None
+
+ def __init__(
+ self,
+ rail: Optional[Rail] = None,
+ num_reasks: Optional[int] = None,
+ base_model: Optional[
+ Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
+ ] = None,
+ tracer: Optional[Tracer] = None,
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ):
+ """Initialize the Guard with optional Rail instance, num_reasks, and
+ base_model."""
+ if not rail:
+ rail = (
+ Rail.from_pydantic(base_model)
+ if base_model
+ else Rail.from_string_validators([])
+ )
+ self.rail = rail
+ self.num_reasks = num_reasks
+ # TODO: Support a sink for history so that it is not solely held in memory
+ self.history: Stack[Call] = Stack()
+ self.base_model = base_model
+ self._set_tracer(tracer)
+
+ credentials = Credentials.from_rc_file(logger)
+
+ # Get unique id of user from credentials
+ self._user_id = credentials.id or ""
+
+ # Get metrics opt-out from credentials
+ self._disable_tracer = credentials.no_metrics
+
+ # Get id of guard object (that is unique)
+ self._guard_id = id(self) # id of guard object; not the class
+
+ # Initialize Hub Telemetry singleton and get the tracer
+ # if it is not disabled
+ if not self._disable_tracer:
+ self._hub_telemetry = HubTelemetry()
+ self._validators = []
+
+        # Guardrails As A Service Initialization
+ self.description = description
+ self.name = name
+
+ api_key = os.environ.get("GUARDRAILS_API_KEY")
+ if api_key is not None:
+ if self.name is None:
+ self.name = f"gr-{str(self._guard_id)}"
+ logger.warn("Warning: No name passed to guard!")
+ logger.warn(
+ "Use this auto-generated name to re-use this guard: {name}".format(
+ name=self.name
+ )
+ )
+ self._api_client = GuardrailsApiClient(api_key=api_key)
+ self.upsert_guard()
+
+ @property
+ def prompt_schema(self) -> Optional[StringSchema]:
+ """Return the input schema."""
+ return self.rail.prompt_schema
+
+ @property
+ def instructions_schema(self) -> Optional[StringSchema]:
+ """Return the input schema."""
+ return self.rail.instructions_schema
+
+ @property
+ def msg_history_schema(self) -> Optional[StringSchema]:
+ """Return the input schema."""
+ return self.rail.msg_history_schema
+
+ @property
+ def output_schema(self) -> Schema:
+ """Return the output schema."""
+ return self.rail.output_schema
+
+ @property
+ def instructions(self) -> Optional[Instructions]:
+ """Return the instruction-prompt."""
+ return self.rail.instructions
+
+ @property
+ def prompt(self) -> Optional[Prompt]:
+ """Return the prompt."""
+ return self.rail.prompt
+
+ @property
+ def raw_prompt(self) -> Optional[Prompt]:
+ """Return the prompt, alias for `prompt`."""
+ return self.prompt
+
+ @property
+ def base_prompt(self) -> Optional[str]:
+ """Return the base prompt i.e. prompt.source."""
+ if self.prompt is None:
+ return None
+ return self.prompt.source
+
+ @property
+ def reask_prompt(self) -> Optional[Prompt]:
+ """Return the reask prompt."""
+ return self.output_schema.reask_prompt_template
+
+ @reask_prompt.setter
+ def reask_prompt(self, reask_prompt: Optional[str]):
+ """Set the reask prompt."""
+ self.output_schema.reask_prompt_template = reask_prompt
+
+ @property
+ def reask_instructions(self) -> Optional[Instructions]:
+ """Return the reask prompt."""
+ return self.output_schema.reask_instructions_template
+
+ @reask_instructions.setter
+ def reask_instructions(self, reask_instructions: Optional[str]):
+ """Set the reask prompt."""
+ self.output_schema.reask_instructions_template = reask_instructions
+
+ def configure(
+ self,
+ num_reasks: Optional[int] = None,
+ ):
+ """Configure the Guard."""
+ self.num_reasks = (
+ num_reasks
+ if num_reasks is not None
+ else self.num_reasks
+ if self.num_reasks is not None
+ else 1
+ )
+
+ def _set_tracer(self, tracer: Optional[Tracer] = None) -> None:
+ self._tracer = tracer
+ set_tracer(tracer)
+ set_tracer_context()
+ self._tracer_context = get_tracer_context()
+
+ @classmethod
+ def from_rail(
+ cls,
+ rail_file: str,
+ num_reasks: Optional[int] = None,
+ tracer: Optional[Tracer] = None,
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ):
+ """Create a Schema from a `.rail` file.
+
+ Args:
+ rail_file: The path to the `.rail` file.
+ num_reasks: The max times to re-ask the LLM for invalid output.
+
+ Returns:
+            An instance of the `AsyncGuard` class.
+ """
+
+ # We have to set the tracer in the ContextStore before the Rail,
+ # and therefore the Validators, are initialized
+ cls._set_tracer(cls, tracer) # type: ignore
+
+ rail = Rail.from_file(rail_file)
+ if rail.output_type == "str":
+ return cast(
+ AsyncGuard[str],
+ cls(
+ rail=rail,
+ num_reasks=num_reasks,
+ tracer=tracer,
+ name=name,
+ description=description,
+ ),
+ )
+ elif rail.output_type == "list":
+ return cast(
+ AsyncGuard[List],
+ cls(
+ rail=rail,
+ num_reasks=num_reasks,
+ tracer=tracer,
+ name=name,
+ description=description,
+ ),
+ )
+ return cast(
+ AsyncGuard[Dict],
+ cls(
+ rail=rail,
+ num_reasks=num_reasks,
+ tracer=tracer,
+ name=name,
+ description=description,
+ ),
+ )
+
+ @classmethod
+ def from_rail_string(
+ cls,
+ rail_string: str,
+ num_reasks: Optional[int] = None,
+ tracer: Optional[Tracer] = None,
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ):
+ """Create a Schema from a `.rail` string.
+
+ Args:
+ rail_string: The `.rail` string.
+ num_reasks: The max times to re-ask the LLM for invalid output.
+
+ Returns:
+            An instance of the `AsyncGuard` class.
+ """
+ # We have to set the tracer in the ContextStore before the Rail,
+ # and therefore the Validators, are initialized
+ cls._set_tracer(cls, tracer) # type: ignore
+
+ rail = Rail.from_string(rail_string)
+ if rail.output_type == "str":
+ return cast(
+ AsyncGuard[str],
+ cls(
+ rail=rail,
+ num_reasks=num_reasks,
+ tracer=tracer,
+ name=name,
+ description=description,
+ ),
+ )
+ elif rail.output_type == "list":
+ return cast(
+ AsyncGuard[List],
+ cls(
+ rail=rail,
+ num_reasks=num_reasks,
+ tracer=tracer,
+ name=name,
+ description=description,
+ ),
+ )
+ return cast(
+ AsyncGuard[Dict],
+ cls(
+ rail=rail,
+ num_reasks=num_reasks,
+ tracer=tracer,
+ name=name,
+ description=description,
+ ),
+ )
+
+ @classmethod
+ def from_pydantic(
+ cls,
+ output_class: Union[Type[BaseModel], Type[List[Type[BaseModel]]]],
+ prompt: Optional[str] = None,
+ instructions: Optional[str] = None,
+ num_reasks: Optional[int] = None,
+ reask_prompt: Optional[str] = None,
+ reask_instructions: Optional[str] = None,
+ tracer: Optional[Tracer] = None,
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ):
+ """Create a Guard instance from a Pydantic model and prompt."""
+ if PYDANTIC_VERSION.startswith("1"):
+ warnings.warn(
+ """Support for Pydantic v1.x is deprecated and will be removed in
+ Guardrails 0.5.x. Please upgrade to the latest Pydantic v2.x to
+ continue receiving future updates and support.""",
+ FutureWarning,
+ )
+ # We have to set the tracer in the ContextStore before the Rail,
+ # and therefore the Validators, are initialized
+ cls._set_tracer(cls, tracer) # type: ignore
+
+ rail = Rail.from_pydantic(
+ output_class=output_class,
+ prompt=prompt,
+ instructions=instructions,
+ reask_prompt=reask_prompt,
+ reask_instructions=reask_instructions,
+ )
+ if rail.output_type == "list":
+ return cast(
+ AsyncGuard[List],
+ cls(rail, num_reasks=num_reasks, base_model=output_class),
+ )
+ return cast(
+ AsyncGuard[Dict],
+ cls(
+ rail,
+ num_reasks=num_reasks,
+ base_model=output_class,
+ tracer=tracer,
+ name=name,
+ description=description,
+ ),
+ )
+
+ @classmethod
+ def from_string(
+ cls,
+ validators: Sequence[Validator],
+ description: Optional[str] = None,
+ prompt: Optional[str] = None,
+ instructions: Optional[str] = None,
+ reask_prompt: Optional[str] = None,
+ reask_instructions: Optional[str] = None,
+ num_reasks: Optional[int] = None,
+ tracer: Optional[Tracer] = None,
+ *,
+ name: Optional[str] = None,
+ guard_description: Optional[str] = None,
+ ):
+ """Create a Guard instance for a string response with prompt,
+ instructions, and validations.
+
+ Args:
+            validators (List[Validator]): The list of validators to apply to the string output.
+ description (str, optional): A description for the string to be generated. Defaults to None.
+ prompt (str, optional): The prompt used to generate the string. Defaults to None.
+ instructions (str, optional): Instructions for chat models. Defaults to None.
+ reask_prompt (str, optional): An alternative prompt to use during reasks. Defaults to None.
+ reask_instructions (str, optional): Alternative instructions to use during reasks. Defaults to None.
+ num_reasks (int, optional): The max times to re-ask the LLM for invalid output.
+ """ # noqa
+
+ cls._set_tracer(cls, tracer) # type: ignore
+
+ rail = Rail.from_string_validators(
+ validators=validators,
+ description=description,
+ prompt=prompt,
+ instructions=instructions,
+ reask_prompt=reask_prompt,
+ reask_instructions=reask_instructions,
+ )
+ return cast(
+ AsyncGuard[str],
+ cls(
+ rail,
+ num_reasks=num_reasks,
+ tracer=tracer,
+ name=name,
+ description=guard_description,
+ ),
+ )
+
+ @overload
+ def __call__(
+ self,
+ llm_api: Callable,
+ prompt_params: Optional[Dict] = None,
+ num_reasks: Optional[int] = None,
+ prompt: Optional[str] = None,
+ instructions: Optional[str] = None,
+ msg_history: Optional[List[Dict]] = None,
+ metadata: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ stream: Optional[bool] = False,
+ *args,
+ **kwargs,
+ ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: ...
+
+ @overload
+ def __call__(
+ self,
+ llm_api: Callable[[Any], Awaitable[Any]],
+ prompt_params: Optional[Dict] = None,
+ num_reasks: Optional[int] = None,
+ prompt: Optional[str] = None,
+ instructions: Optional[str] = None,
+ msg_history: Optional[List[Dict]] = None,
+ metadata: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ) -> Awaitable[ValidationOutcome[OT]]: ...
+
+ async def __call__(
+ self,
+ llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]],
+ prompt_params: Optional[Dict] = None,
+ num_reasks: Optional[int] = None,
+ prompt: Optional[str] = None,
+ instructions: Optional[str] = None,
+ msg_history: Optional[List[Dict]] = None,
+ metadata: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ) -> Union[
+ Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]],
+ Awaitable[ValidationOutcome[OT]],
+ ]:
+ """Call the LLM and validate the output. Pass an async LLM API to
+ return a coroutine.
+
+ Args:
+            llm_api: The async LLM API to call
+                     (e.g. openai.Completion.acreate or openai.ChatCompletion.acreate)
+ prompt_params: The parameters to pass to the prompt.format() method.
+ num_reasks: The max times to re-ask the LLM for invalid output.
+ prompt: The prompt to use for the LLM.
+ instructions: Instructions for chat models.
+ msg_history: The message history to pass to the LLM.
+ metadata: Metadata to pass to the validators.
+ full_schema_reask: When reasking, whether to regenerate the full schema
+ or just the incorrect values.
+ Defaults to `True` if a base model is provided,
+ `False` otherwise.
+
+ Returns:
+ The raw text output from the LLM and the validated output.
+ """
+
+ async def __call(
+ self,
+ llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]],
+ prompt_params: Optional[Dict] = None,
+ num_reasks: Optional[int] = None,
+ prompt: Optional[str] = None,
+ instructions: Optional[str] = None,
+ msg_history: Optional[List[Dict]] = None,
+ metadata: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ):
+ if metadata is None:
+ metadata = {}
+ if full_schema_reask is None:
+ full_schema_reask = self.base_model is not None
+ if prompt_params is None:
+ prompt_params = {}
+
+ if not self._disable_tracer:
+ # Create a new span for this guard call
+ self._hub_telemetry.create_new_span(
+ span_name="/guard_call",
+ attributes=[
+ ("guard_id", self._guard_id),
+ ("user_id", self._user_id),
+ ("llm_api", llm_api.__name__ if llm_api else "None"),
+ ("custom_reask_prompt", self.reask_prompt is not None),
+ (
+ "custom_reask_instructions",
+ self.reask_instructions is not None,
+ ),
+ ],
+ is_parent=True, # It will have children
+ has_parent=False, # Has no parents
+ )
+
+ set_call_kwargs(kwargs)
+ set_tracer(self._tracer)
+ set_tracer_context(self._tracer_context)
+
+ self.configure(num_reasks)
+ if self.num_reasks is None:
+ raise RuntimeError(
+ "`num_reasks` is `None` after calling `configure()`. "
+ "This should never happen."
+ )
+
+ input_prompt = prompt or (self.prompt._source if self.prompt else None)
+ input_instructions = instructions or (
+ self.instructions._source if self.instructions else None
+ )
+ call_inputs = CallInputs(
+ llm_api=llm_api,
+ prompt=input_prompt,
+ instructions=input_instructions,
+ msg_history=msg_history,
+ prompt_params=prompt_params,
+ num_reasks=self.num_reasks,
+ metadata=metadata,
+ full_schema_reask=full_schema_reask,
+ args=list(args),
+ kwargs=kwargs,
+ )
+ call_log = Call(inputs=call_inputs)
+ set_scope(str(id(call_log)))
+ self.history.push(call_log)
+
+ if self._api_client is not None and model_is_supported_server_side(
+ llm_api, *args, **kwargs
+ ):
+ return self._call_server(
+ llm_api=llm_api,
+ num_reasks=self.num_reasks,
+ prompt_params=prompt_params,
+ full_schema_reask=full_schema_reask,
+ call_log=call_log,
+ *args,
+ **kwargs,
+ )
+
+ # If the LLM API is not async, fail
+ if not asyncio.iscoroutinefunction(llm_api):
+ raise RuntimeError(
+ f"The LLM API `{llm_api.__name__}` is not a coroutine. "
+ "Please use an async LLM API."
+ )
+ # Otherwise, call the LLM
+ return await self._call_async(
+ llm_api,
+ prompt_params=prompt_params,
+ num_reasks=self.num_reasks,
+ prompt=prompt,
+ instructions=instructions,
+ msg_history=msg_history,
+ metadata=metadata,
+ full_schema_reask=full_schema_reask,
+ call_log=call_log,
+ *args,
+ **kwargs,
+ )
+
+ guard_context = contextvars.Context()
+ return guard_context.run(
+ __call,
+ self,
+ llm_api,
+ prompt_params,
+ num_reasks,
+ prompt,
+ instructions,
+ msg_history,
+ metadata,
+ full_schema_reask,
+ *args,
+ **kwargs,
+ )
+
+ async def _call_async(
+ self,
+ llm_api: Callable[[Any], Awaitable[Any]],
+ prompt_params: Dict,
+ num_reasks: int,
+ prompt: Optional[str],
+ instructions: Optional[str],
+ msg_history: Optional[List[Dict]],
+ metadata: Dict,
+ full_schema_reask: bool,
+ call_log: Call,
+ *args,
+ **kwargs,
+ ) -> ValidationOutcome[OT]:
+ """Call the LLM asynchronously and validate the output.
+
+ Args:
+ llm_api: The LLM API to call asynchronously (e.g. openai.Completion.acreate)
+ prompt_params: The parameters to pass to the prompt.format() method.
+ num_reasks: The max times to re-ask the LLM for invalid output.
+ prompt: The prompt to use for the LLM.
+ instructions: Instructions for chat models.
+ msg_history: The message history to pass to the LLM.
+ metadata: Metadata to pass to the validators.
+ full_schema_reask: When reasking, whether to regenerate the full schema
+ or just the incorrect values.
+ Defaults to `True` if a base model is provided,
+ `False` otherwise.
+
+ Returns:
+ The raw text output from the LLM and the validated output.
+ """
+ instructions_obj = instructions or self.instructions
+ prompt_obj = prompt or self.prompt
+ msg_history_obj = msg_history or []
+ if prompt_obj is None:
+ if msg_history_obj is not None and not len(msg_history_obj):
+ raise RuntimeError(
+ "You must provide a prompt if msg_history is empty. "
+ "Alternatively, you can provide a prompt in the RAIL spec."
+ )
+
+ runner = AsyncRunner(
+ instructions=instructions_obj,
+ prompt=prompt_obj,
+ msg_history=msg_history_obj,
+ api=get_async_llm_ask(llm_api, *args, **kwargs),
+ prompt_schema=self.prompt_schema,
+ instructions_schema=self.instructions_schema,
+ msg_history_schema=self.msg_history_schema,
+ output_schema=self.output_schema,
+ num_reasks=num_reasks,
+ metadata=metadata,
+ base_model=self.base_model,
+ full_schema_reask=full_schema_reask,
+ disable_tracer=self._disable_tracer,
+ )
+ call = await runner.async_run(call_log=call_log, prompt_params=prompt_params)
+ return ValidationOutcome[OT].from_guard_history(call)
+
+ def __repr__(self):
+ return f"Guard(RAIL={self.rail})"
+
+ def __rich_repr__(self):
+ yield "RAIL", self.rail
+
+ def __stringify__(self):
+ if self.rail and self.rail.output_type == "str":
+ template = Template(
+ """
+ Guard {
+ validators: [
+ ${validators}
+ ]
+ }
+ """
+ )
+ return template.safe_substitute(
+ {
+ "validators": ",\n".join(
+ [v.__stringify__() for v in self._validators]
+ )
+ }
+ )
+ return self.__repr__()
+
+ @overload
+ def parse(
+ self,
+ llm_output: str,
+ metadata: Optional[Dict] = None,
+ llm_api: None = None,
+ num_reasks: Optional[int] = None,
+ prompt_params: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ) -> ValidationOutcome[OT]: ...
+
+ @overload
+ def parse(
+ self,
+ llm_output: str,
+ metadata: Optional[Dict] = None,
+ llm_api: Callable[[Any], Awaitable[Any]] = ...,
+ num_reasks: Optional[int] = None,
+ prompt_params: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ) -> Awaitable[ValidationOutcome[OT]]: ...
+
+ @overload
+ def parse(
+ self,
+ llm_output: str,
+ metadata: Optional[Dict] = None,
+ llm_api: Optional[Callable] = None,
+ num_reasks: Optional[int] = None,
+ prompt_params: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ) -> ValidationOutcome[OT]: ...
+
+ def parse(
+ self,
+ llm_output: str,
+ metadata: Optional[Dict] = None,
+ llm_api: Optional[Callable] = None,
+ num_reasks: Optional[int] = None,
+ prompt_params: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ) -> Union[ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]]]:
+ """Alternate flow to using Guard where the llm_output is known.
+
+ Args:
+ llm_output: The output being parsed and validated.
+ metadata: Metadata to pass to the validators.
+            llm_api: The async LLM API to call
+                     (e.g. openai.Completion.acreate or openai.ChatCompletion.acreate)
+ num_reasks: The max times to re-ask the LLM for invalid output.
+ prompt_params: The parameters to pass to the prompt.format() method.
+ full_schema_reask: When reasking, whether to regenerate the full schema
+ or just the incorrect values.
+
+ Returns:
+ The validated response. This is either a string or a dictionary,
+ determined by the object schema defined in the RAILspec.
+ """
+
+ def __parse(
+ self,
+ llm_output: str,
+ metadata: Optional[Dict] = None,
+ llm_api: Optional[Callable] = None,
+ num_reasks: Optional[int] = None,
+ prompt_params: Optional[Dict] = None,
+ full_schema_reask: Optional[bool] = None,
+ *args,
+ **kwargs,
+ ):
+ final_num_reasks = (
+ num_reasks if num_reasks is not None else 0 if llm_api is None else None
+ )
+
+ if not self._disable_tracer:
+ self._hub_telemetry.create_new_span(
+ span_name="/guard_parse",
+ attributes=[
+ ("guard_id", self._guard_id),
+ ("user_id", self._user_id),
+ ("llm_api", llm_api.__name__ if llm_api else "None"),
+ ("custom_reask_prompt", self.reask_prompt is not None),
+ (
+ "custom_reask_instructions",
+ self.reask_instructions is not None,
+ ),
+ ],
+ is_parent=True, # It will have children
+ has_parent=False, # Has no parents
+ )
+
+ self.configure(final_num_reasks)
+ if self.num_reasks is None:
+ raise RuntimeError(
+ "`num_reasks` is `None` after calling `configure()`. "
+ "This should never happen."
+ )
+ if full_schema_reask is None:
+ full_schema_reask = self.base_model is not None
+ metadata = metadata or {}
+ prompt_params = prompt_params or {}
+
+ set_call_kwargs(kwargs)
+ set_tracer(self._tracer)
+ set_tracer_context(self._tracer_context)
+
+ input_prompt = self.prompt._source if self.prompt else None
+ input_instructions = (
+ self.instructions._source if self.instructions else None
+ )
+ call_inputs = CallInputs(
+ llm_api=llm_api,
+ llm_output=llm_output,
+ prompt=input_prompt,
+ instructions=input_instructions,
+ prompt_params=prompt_params,
+ num_reasks=self.num_reasks,
+ metadata=metadata,
+ full_schema_reask=full_schema_reask,
+ args=list(args),
+ kwargs=kwargs,
+ )
+ call_log = Call(inputs=call_inputs)
+ set_scope(str(id(call_log)))
+ self.history.push(call_log)
+
+ if self._api_client is not None and model_is_supported_server_side(
+ llm_api, *args, **kwargs
+ ):
+ return self._call_server(
+ llm_output=llm_output,
+ llm_api=llm_api,
+ num_reasks=self.num_reasks,
+ prompt_params=prompt_params,
+ full_schema_reask=full_schema_reask,
+ call_log=call_log,
+ *args,
+ **kwargs,
+ )
+
+ # If the LLM API is async, return a coroutine
+ if asyncio.iscoroutinefunction(llm_api):
+ return self._async_parse(
+ llm_output,
+ metadata,
+ llm_api=llm_api,
+ num_reasks=self.num_reasks,
+ prompt_params=prompt_params,
+ full_schema_reask=full_schema_reask,
+ call_log=call_log,
+ *args,
+ **kwargs,
+ )
+ else:
+                raise NotImplementedError(
+                    "AsyncGuard does not support non-async LLM APIs. "
+                    "Please use the synchronous Guard or supply an async LLM API."
+                )
+
+ guard_context = contextvars.Context()
+ return guard_context.run(
+ __parse,
+ self,
+ llm_output,
+ metadata,
+ llm_api,
+ num_reasks,
+ prompt_params,
+ full_schema_reask,
+ *args,
+ **kwargs,
+ )
+
+ async def _async_parse(
+ self,
+ llm_output: str,
+ metadata: Dict,
+ llm_api: Optional[Callable[[Any], Awaitable[Any]]],
+ num_reasks: int,
+ prompt_params: Dict,
+ full_schema_reask: bool,
+ call_log: Call,
+ *args,
+ **kwargs,
+ ) -> ValidationOutcome[OT]:
+ """Alternate flow to using Guard where the llm_output is known.
+
+ Args:
+ llm_output: The output from the LLM.
+ llm_api: The LLM API to use to re-ask the LLM.
+ num_reasks: The max times to re-ask the LLM for invalid output.
+
+ Returns:
+ The validated response.
+ """
+ runner = AsyncRunner(
+ instructions=kwargs.pop("instructions", None),
+ prompt=kwargs.pop("prompt", None),
+ msg_history=kwargs.pop("msg_history", None),
+ api=get_async_llm_ask(llm_api, *args, **kwargs) if llm_api else None,
+ prompt_schema=self.prompt_schema,
+ instructions_schema=self.instructions_schema,
+ msg_history_schema=self.msg_history_schema,
+ output_schema=self.output_schema,
+ num_reasks=num_reasks,
+ metadata=metadata,
+ output=llm_output,
+ base_model=self.base_model,
+ full_schema_reask=full_schema_reask,
+ disable_tracer=self._disable_tracer,
+ )
+ call = await runner.async_run(call_log=call_log, prompt_params=prompt_params)
+
+ return ValidationOutcome[OT].from_guard_history(call)
+
+ @deprecated(
+ """The `with_prompt_validation` method is deprecated,
+ and will be removed in 0.5.x. Instead, please use
+        `AsyncGuard().use(YourValidator, on='prompt')`.""",
+ category=FutureWarning,
+ stacklevel=2,
+ )
+ def with_prompt_validation(
+ self,
+ validators: Sequence[Validator],
+ ):
+ """Add prompt validation to the Guard.
+
+ Args:
+ validators: The validators to add to the prompt.
+ """
+ if self.rail.prompt_schema:
+ warnings.warn("Overriding existing prompt validators.")
+ schema = StringSchema.from_string(
+ validators=validators,
+ )
+ self.rail.prompt_schema = schema
+ return self
+
+ @deprecated(
+ """The `with_instructions_validation` method is deprecated,
+ and will be removed in 0.5.x. Instead, please use
+        `AsyncGuard().use(YourValidator, on='instructions')`.""",
+ category=FutureWarning,
+ stacklevel=2,
+ )
+ def with_instructions_validation(
+ self,
+ validators: Sequence[Validator],
+ ):
+ """Add instructions validation to the Guard.
+
+ Args:
+ validators: The validators to add to the instructions.
+ """
+ if self.rail.instructions_schema:
+ warnings.warn("Overriding existing instructions validators.")
+ schema = StringSchema.from_string(
+ validators=validators,
+ )
+ self.rail.instructions_schema = schema
+ return self
+
+ @deprecated(
+ """The `with_msg_history_validation` method is deprecated,
+ and will be removed in 0.5.x. Instead, please use
+        `AsyncGuard().use(YourValidator, on='msg_history')`.""",
+ category=FutureWarning,
+ stacklevel=2,
+ )
+ def with_msg_history_validation(
+ self,
+ validators: Sequence[Validator],
+ ):
+ """Add msg_history validation to the Guard.
+
+ Args:
+ validators: The validators to add to the msg_history.
+ """
+ if self.rail.msg_history_schema:
+ warnings.warn("Overriding existing msg_history validators.")
+ schema = StringSchema.from_string(
+ validators=validators,
+ )
+ self.rail.msg_history_schema = schema
+ return self
+
+ def __add_validator(self, validator: Validator, on: str = "output"):
+ # Only available for string output types
+ if self.rail.output_type != "str":
+ raise RuntimeError(
+ "The `use` method is only available for string output types."
+ )
+
+ if on == "prompt":
+ # If the prompt schema exists, add the validator to it
+ if self.rail.prompt_schema:
+ self.rail.prompt_schema.root_datatype.validators.append(validator)
+ else:
+ # Otherwise, create a new schema with the validator
+ schema = StringSchema.from_string(
+ validators=[validator],
+ )
+ self.rail.prompt_schema = schema
+ elif on == "instructions":
+ # If the instructions schema exists, add the validator to it
+ if self.rail.instructions_schema:
+ self.rail.instructions_schema.root_datatype.validators.append(validator)
+ else:
+ # Otherwise, create a new schema with the validator
+ schema = StringSchema.from_string(
+ validators=[validator],
+ )
+ self.rail.instructions_schema = schema
+ elif on == "msg_history":
+ # If the msg_history schema exists, add the validator to it
+ if self.rail.msg_history_schema:
+ self.rail.msg_history_schema.root_datatype.validators.append(validator)
+ else:
+ # Otherwise, create a new schema with the validator
+ schema = StringSchema.from_string(
+ validators=[validator],
+ )
+ self.rail.msg_history_schema = schema
+ elif on == "output":
+ self._validators.append(validator)
+ self.rail.output_schema.root_datatype.validators.append(validator)
+ else:
+ raise ValueError(
+ """Invalid value for `on`. Must be one of the following:
+ 'output', 'prompt', 'instructions', 'msg_history'."""
+ )
+
+ @overload
+ def use(self, validator: Validator, *, on: str = "output") -> "AsyncGuard": ...
+
+ @overload
+ def use(
+ self, validator: Type[Validator], *args, on: str = "output", **kwargs
+ ) -> "AsyncGuard": ...
+
+ def use(
+ self,
+ validator: Union[Validator, Type[Validator]],
+ *args,
+ on: str = "output",
+ **kwargs,
+ ) -> "AsyncGuard":
+ """Use a validator to validate either of the following:
+ - The output of an LLM request
+ - The prompt
+ - The instructions
+ - The message history
+
+ *Note*: For on="output", `use` is only available for string output types.
+
+ Args:
+ validator: The validator to use. Either the class or an instance.
+ on: The part of the LLM request to validate. Defaults to "output".
+ """
+ hydrated_validator = get_validator(validator, *args, **kwargs)
+ self.__add_validator(hydrated_validator, on=on)
+ return self
+
+ @overload
+ def use_many(self, *validators: Validator, on: str = "output") -> "Async": ...
+
+ @overload
+ def use_many(
+ self,
+ *validators: Tuple[
+ Type[Validator],
+ Optional[Union[List[Any], Dict[str, Any]]],
+ Optional[Dict[str, Any]],
+ ],
+ on: str = "output",
+ ) -> "AsyncGuard": ...
+
+ def use_many(
+ self,
+ *validators: Union[
+ Validator,
+ Tuple[
+ Type[Validator],
+ Optional[Union[List[Any], Dict[str, Any]]],
+ Optional[Dict[str, Any]],
+ ],
+ ],
+ on: str = "output",
+ ) -> "AsyncGuard":
+ """Use a validator to validate results of an LLM request.
+
+ *Note*: `use_many` is only available for string output types.
+ """
+ if self.rail.output_type != "str":
+ raise RuntimeError(
+ "The `use_many` method is only available for string output types."
+ )
+
+ # Loop through the validators
+ for v in validators:
+ hydrated_validator = get_validator(v)
+ self.__add_validator(hydrated_validator, on=on)
+ return self
+
+ def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
+ if (
+ not self.rail
+ or self.rail.output_schema.root_datatype.validators != self._validators
+ ):
+ self.rail = Rail.from_string_validators(
+ validators=self._validators,
+ prompt=self.prompt.source if self.prompt else None,
+ instructions=self.instructions.source if self.instructions else None,
+ reask_prompt=self.reask_prompt.source if self.reask_prompt else None,
+ reask_instructions=self.reask_instructions.source
+ if self.reask_instructions
+ else None,
+ )
+
+ return self.parse(llm_output=llm_output, *args, **kwargs)
+
+ # No call support for this until
+ # https://github.com/guardrails-ai/guardrails/pull/525 is merged
+ # def __call__(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
+ # return self.validate(llm_output, *args, **kwargs)
+
+ def invoke(
+ self, input: InputType, config: Optional[RunnableConfig] = None
+ ) -> InputType:
+ output = BaseMessage(content="", type="")
+ str_input = None
+ input_is_chat_message = False
+ if isinstance(input, BaseMessage):
+ input_is_chat_message = True
+ str_input = str(input.content)
+ output = deepcopy(input)
+ else:
+ str_input = str(input)
+
+ response = self.validate(str_input)
+
+ validated_output = response.validated_output
+ if not validated_output:
+            raise ValidationError(
+                (
+                    "The response from the LLM failed validation! "
+                    "See `guard.history` for more details."
+                )
+            )
+
+ if isinstance(validated_output, Dict):
+ validated_output = json.dumps(validated_output)
+
+ if input_is_chat_message:
+ output.content = validated_output
+ return cast(InputType, output)
+ return cast(InputType, validated_output)
+
+ def _to_request(self) -> Dict:
+ return {
+ "name": self.name,
+ "description": self.description,
+ "railspec": self.rail._to_request(),
+ "numReasks": self.num_reasks,
+ }
+
+ def upsert_guard(self):
+ if self._api_client:
+ guard_dict = self._to_request()
+ self._api_client.upsert_guard(GuardModel.from_dict(guard_dict))
+ else:
+ raise ValueError("Guard does not have an api client!")
+
+ def _call_server(
+ self,
+ *args,
+ llm_output: Optional[str] = None,
+ llm_api: Optional[Callable] = None,
+ num_reasks: Optional[int] = None,
+ prompt_params: Optional[Dict] = None,
+ metadata: Optional[Dict] = {},
+ full_schema_reask: Optional[bool] = True,
+ call_log: Optional[Call],
+ # prompt: Optional[str],
+ # instructions: Optional[str],
+ # msg_history: Optional[List[Dict]],
+ **kwargs,
+ ):
+ if self._api_client:
+ payload: Dict[str, Any] = {"args": list(args)}
+ payload.update(**kwargs)
+ if llm_output is not None:
+ payload["llmOutput"] = llm_output
+ if num_reasks is not None:
+ payload["numReasks"] = num_reasks
+ if prompt_params is not None:
+ payload["promptParams"] = prompt_params
+ if llm_api is not None:
+ payload["llmApi"] = get_llm_api_enum(llm_api)
+ # TODO: get enum for llm_api
+ validation_output: Optional[ValidationOutput] = self._api_client.validate(
+ guard=self, # type: ignore
+ payload=ValidatePayload.from_dict(payload),
+ openai_api_key=get_call_kwarg("api_key"),
+ )
+
+ if not validation_output:
+ return ValidationOutcome[OT](
+ raw_llm_output=None,
+ validated_output=None,
+ validation_passed=False,
+ error="The response from the server was empty!",
+ )
+
+ call_log = call_log or Call()
+ if llm_api is not None:
+ llm_api = get_llm_ask(llm_api)
+ if asyncio.iscoroutinefunction(llm_api):
+ llm_api = get_async_llm_ask(llm_api)
+ session_history = (
+ validation_output.session_history
+ if validation_output is not None and validation_output.session_history
+ else []
+ )
+ history: History
+ for history in session_history:
+ history_events: Optional[List[HistoryEvent]] = ( # type: ignore
+ history.history if history.history != UNSET else None
+ )
+ if history_events is None:
+ continue
+
+ iterations = [
+ Iteration(
+ inputs=Inputs(
+ llm_api=llm_api,
+ llm_output=llm_output,
+ instructions=(
+ Instructions(h.instructions) if h.instructions else None
+ ),
+ prompt=(
+ Prompt(h.prompt.source) # type: ignore
+ if h.prompt is not None and h.prompt != UNSET
+ else None
+ ),
+ prompt_params=prompt_params,
+ num_reasks=(num_reasks or 0),
+ metadata=metadata,
+ full_schema_reask=full_schema_reask,
+ ),
+ outputs=Outputs(
+ llm_response_info=LLMResponse(
+ output=h.output # type: ignore
+ ),
+ raw_output=h.output,
+ parsed_output=(
+ h.parsed_output.to_dict()
+ if isinstance(h.parsed_output, AnyObject)
+ else h.parsed_output
+ ),
+ validation_output=(
+ h.validated_output.to_dict()
+ if isinstance(h.validated_output, AnyObject)
+ else h.validated_output
+ ),
+ reasks=list(
+ [
+ FieldReAsk(
+ incorrect_value=r.to_dict().get(
+ "incorrect_value"
+ ),
+ path=r.to_dict().get("path"),
+ fail_results=[
+ FailResult(
+ error_message=r.to_dict().get(
+ "error_message"
+ ),
+ fix_value=r.to_dict().get("fix_value"),
+ )
+ ],
+ )
+ for r in h.reasks # type: ignore
+ ]
+ if h.reasks != UNSET
+ else []
+ ),
+ ),
+ )
+ for h in history_events
+ ]
+ call_log.iterations.extend(iterations)
+ if self.history.length == 0:
+ self.history.push(call_log)
+
+ # Our interfaces are too different for this to work right now.
+ # Once we move towards shared interfaces for both the open source
+ # and the api we can re-enable this.
+ # return ValidationOutcome[OT].from_guard_history(call_log)
+ return ValidationOutcome[OT](
+ raw_llm_output=validation_output.raw_llm_response, # type: ignore
+ validated_output=cast(OT, validation_output.validated_output),
+ validation_passed=validation_output.result,
+ )
+ else:
+ raise ValueError("Guard does not have an api client!")
diff --git a/guardrails/guard.py b/guardrails/guard.py
index e6ae27e1b..c5b873ae0 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -5,31 +5,13 @@
import warnings
from copy import deepcopy
from string import Template
-from typing import (
- Any,
- Awaitable,
- Callable,
- Dict,
- Generic,
- Iterable,
- List,
- Optional,
- Sequence,
- Tuple,
- Type,
- Union,
- cast,
- overload,
-)
-
-from guardrails_api_client.models import (
- AnyObject,
- History,
- HistoryEvent,
- ValidatePayload,
- ValidationOutput,
-)
+from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
+ Optional, Sequence, Tuple, Type, Union, cast, overload)
+
+from guardrails_api_client.models import AnyObject
from guardrails_api_client.models import Guard as GuardModel
+from guardrails_api_client.models import (History, HistoryEvent,
+ ValidatePayload, ValidationOutput)
from guardrails_api_client.types import UNSET
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
@@ -47,25 +29,17 @@
from guardrails.classes.history.iteration import Iteration
from guardrails.classes.history.outputs import Outputs
from guardrails.errors import ValidationError
-from guardrails.llm_providers import (
- get_async_llm_ask,
- get_llm_api_enum,
- get_llm_ask,
- model_is_supported_server_side,
-)
+from guardrails.llm_providers import (get_async_llm_ask, get_llm_api_enum,
+ get_llm_ask,
+ model_is_supported_server_side)
from guardrails.logger import logger, set_scope
from guardrails.prompt import Instructions, Prompt
from guardrails.rail import Rail
from guardrails.run import AsyncRunner, Runner, StreamRunner
from guardrails.schema import Schema, StringSchema
-from guardrails.stores.context import (
- Tracer,
- get_call_kwarg,
- get_tracer_context,
- set_call_kwargs,
- set_tracer,
- set_tracer_context,
-)
+from guardrails.stores.context import (Tracer, get_call_kwarg,
+ get_tracer_context, set_call_kwargs,
+ set_tracer, set_tracer_context)
from guardrails.utils.hub_telemetry_utils import HubTelemetry
from guardrails.utils.llm_response import LLMResponse
from guardrails.utils.reask_utils import FieldReAsk
@@ -500,8 +474,7 @@ def __call__(
Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]],
Awaitable[ValidationOutcome[OT]],
]:
- """Call the LLM and validate the output. Pass an async LLM API to
- return a coroutine.
+ """Call the LLM and validate the output.
Args:
llm_api: The LLM API to call
@@ -603,7 +576,8 @@ def __call(
**kwargs,
)
- # If the LLM API is async, return a coroutine
+        # If the LLM API is async, return a coroutine (deprecated; use AsyncGuard).
+
if asyncio.iscoroutinefunction(llm_api):
return self._call_async(
llm_api,
@@ -712,6 +686,7 @@ def _call_sync(
call = runner(call_log=call_log, prompt_params=prompt_params)
return ValidationOutcome[OT].from_guard_history(call)
+
async def _call_async(
self,
llm_api: Callable[[Any], Awaitable[Any]],
@@ -744,6 +719,12 @@ async def _call_async(
Returns:
The raw text output from the LLM and the validated output.
"""
+ warnings.warn(
+ "Using an async LLM client is deprecated in Guard and will "
+ "be removed in a future release. "
+ "Please use `AsyncGuard` or a non async llm client instead.",
+ DeprecationWarning,
+ )
instructions_obj = instructions or self.instructions
prompt_obj = prompt or self.prompt
msg_history_obj = msg_history or []
From b7deef25e31dbb41a3bb2b600216c157cd84c049 Mon Sep 17 00:00:00 2001
From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com>
Date: Mon, 13 May 2024 13:40:51 -0700
Subject: [PATCH 2/9] Adding async guard, test, and deprecating async calls in
Guard()
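
Condensed from the new unit tests, a sketch of where the deprecation points:
async LLM clients go through AsyncGuard, while validator wiring stays the
same as in Guard (the llm output string below is a placeholder):

    import openai

    from guardrails import AsyncGuard
    from guardrails.validator_base import OnFailAction
    from guardrails.validators import LowerCase, OneLine, TwoWords, UpperCase

    # Validator wiring is unchanged from Guard: `use` targets the output by
    # default, or prompt/instructions/msg_history via the `on=` argument.
    guard = (
        AsyncGuard()
        .use(LowerCase, on="prompt")
        .use(UpperCase, on="instructions")
        .use(OneLine)
        .use(TwoWords, on_fail=OnFailAction.REASK)
    )

    # Inside an async context: async LLM clients are passed to AsyncGuard;
    # passing them to Guard now emits a FutureWarning and support will be
    # removed in 0.5.x.
    outcome = await guard.parse(
        "some llm output",
        llm_api=openai.ChatCompletion.acreate,
        num_reasks=0,
    )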
---
guardrails/async_guard.py | 95 +++--
guardrails/guard.py | 71 +++-
tests/unit_tests/test_async_guard.py | 601 +++++++++++++++++++++++++++
3 files changed, 720 insertions(+), 47 deletions(-)
create mode 100644 tests/unit_tests/test_async_guard.py
diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py
index 4f986e56c..09affb0fc 100644
--- a/guardrails/async_guard.py
+++ b/guardrails/async_guard.py
@@ -1,18 +1,36 @@
import asyncio
import contextvars
+import inspect
import json
import os
-import time
import warnings
from copy import deepcopy
from string import Template
-from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
- Optional, Sequence, Tuple, Type, Union, cast, overload)
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ overload,
+)
from guardrails_api_client.models import AnyObject
from guardrails_api_client.models import Guard as GuardModel
-from guardrails_api_client.models import (History, HistoryEvent,
- ValidatePayload, ValidationOutput)
+from guardrails_api_client.models import (
+ History,
+ HistoryEvent,
+ ValidatePayload,
+ ValidationOutput,
+)
from guardrails_api_client.types import UNSET
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
@@ -30,17 +48,25 @@
from guardrails.classes.history.iteration import Iteration
from guardrails.classes.history.outputs import Outputs
from guardrails.errors import ValidationError
-from guardrails.llm_providers import (get_async_llm_ask, get_llm_api_enum,
- get_llm_ask,
- model_is_supported_server_side)
+from guardrails.llm_providers import (
+ get_async_llm_ask,
+ get_llm_api_enum,
+ get_llm_ask,
+ model_is_supported_server_side,
+)
from guardrails.logger import logger, set_scope
from guardrails.prompt import Instructions, Prompt
from guardrails.rail import Rail
-from guardrails.run import AsyncRunner, Runner, StreamRunner
+from guardrails.run import AsyncRunner
from guardrails.schema import Schema, StringSchema
-from guardrails.stores.context import (Tracer, get_call_kwarg,
- get_tracer_context, set_call_kwargs,
- set_tracer, set_tracer_context)
+from guardrails.stores.context import (
+ Tracer,
+ get_call_kwarg,
+ get_tracer_context,
+ set_call_kwargs,
+ set_tracer,
+ set_tracer_context,
+)
from guardrails.utils.hub_telemetry_utils import HubTelemetry
from guardrails.utils.llm_response import LLMResponse
from guardrails.utils.reask_utils import FieldReAsk
@@ -460,7 +486,7 @@ def __call__(
**kwargs,
) -> Awaitable[ValidationOutcome[OT]]: ...
- async def __call__(
+ def __call__(
self,
llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]],
prompt_params: Optional[Dict] = None,
@@ -497,7 +523,7 @@ async def __call__(
The raw text output from the LLM and the validated output.
"""
- async def __call(
+ def __call(
self,
llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]],
prompt_params: Optional[Dict] = None,
@@ -580,13 +606,16 @@ async def __call(
)
# If the LLM API is not async, fail
- if not asyncio.iscoroutinefunction(llm_api):
+ # FIXME: it seems like this check isn't actually working?
+ if not inspect.isawaitable(llm_api) and not inspect.iscoroutinefunction(
+ llm_api
+ ):
raise RuntimeError(
f"The LLM API `{llm_api.__name__}` is not a coroutine. "
"Please use an async LLM API."
)
# Otherwise, call the LLM
- return await self._call_async(
+ return self._call_async(
llm_api,
prompt_params=prompt_params,
num_reasks=self.num_reasks,
@@ -616,7 +645,7 @@ async def __call(
**kwargs,
)
- async def _call_async(
+ def _call_async(
self,
llm_api: Callable[[Any], Awaitable[Any]],
prompt_params: Dict,
@@ -673,7 +702,9 @@ async def _call_async(
full_schema_reask=full_schema_reask,
disable_tracer=self._disable_tracer,
)
- call = await runner.async_run(call_log=call_log, prompt_params=prompt_params)
+ call = asyncio.run(
+ runner.async_run(call_log=call_log, prompt_params=prompt_params)
+ )
return ValidationOutcome[OT].from_guard_history(call)
def __repr__(self):
@@ -851,17 +882,21 @@ def __parse(
)
# If the LLM API is async, return a coroutine
- if asyncio.iscoroutinefunction(llm_api):
- return self._async_parse(
- llm_output,
- metadata,
- llm_api=llm_api,
- num_reasks=self.num_reasks,
- prompt_params=prompt_params,
- full_schema_reask=full_schema_reask,
- call_log=call_log,
- *args,
- **kwargs,
+        # FIXME: Add null llm_api support because we want to be able to call on
+ # static text (e.g. validate("foobar")). Is there a better way to do this?
+ if not llm_api or inspect.iscoroutinefunction(llm_api):
+ return asyncio.run(
+ self._async_parse(
+ llm_output,
+ metadata,
+ llm_api=llm_api,
+ num_reasks=self.num_reasks,
+ prompt_params=prompt_params,
+ full_schema_reask=full_schema_reask,
+ call_log=call_log,
+ *args,
+ **kwargs,
+ )
)
else:
raise NotImplementedError(
@@ -1075,7 +1110,7 @@ def use(
return self
@overload
- def use_many(self, *validators: Validator, on: str = "output") -> "Async": ...
+ def use_many(self, *validators: Validator, on: str = "output") -> "AsyncGuard": ...
@overload
def use_many(
diff --git a/guardrails/guard.py b/guardrails/guard.py
index c5b873ae0..84b4858ba 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -5,13 +5,31 @@
import warnings
from copy import deepcopy
from string import Template
-from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
- Optional, Sequence, Tuple, Type, Union, cast, overload)
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ overload,
+)
from guardrails_api_client.models import AnyObject
from guardrails_api_client.models import Guard as GuardModel
-from guardrails_api_client.models import (History, HistoryEvent,
- ValidatePayload, ValidationOutput)
+from guardrails_api_client.models import (
+ History,
+ HistoryEvent,
+ ValidatePayload,
+ ValidationOutput,
+)
from guardrails_api_client.types import UNSET
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
@@ -29,17 +47,25 @@
from guardrails.classes.history.iteration import Iteration
from guardrails.classes.history.outputs import Outputs
from guardrails.errors import ValidationError
-from guardrails.llm_providers import (get_async_llm_ask, get_llm_api_enum,
- get_llm_ask,
- model_is_supported_server_side)
+from guardrails.llm_providers import (
+ get_async_llm_ask,
+ get_llm_api_enum,
+ get_llm_ask,
+ model_is_supported_server_side,
+)
from guardrails.logger import logger, set_scope
from guardrails.prompt import Instructions, Prompt
from guardrails.rail import Rail
from guardrails.run import AsyncRunner, Runner, StreamRunner
from guardrails.schema import Schema, StringSchema
-from guardrails.stores.context import (Tracer, get_call_kwarg,
- get_tracer_context, set_call_kwargs,
- set_tracer, set_tracer_context)
+from guardrails.stores.context import (
+ Tracer,
+ get_call_kwarg,
+ get_tracer_context,
+ set_call_kwargs,
+ set_tracer,
+ set_tracer_context,
+)
from guardrails.utils.hub_telemetry_utils import HubTelemetry
from guardrails.utils.llm_response import LLMResponse
from guardrails.utils.reask_utils import FieldReAsk
@@ -686,7 +712,18 @@ def _call_sync(
call = runner(call_log=call_log, prompt_params=prompt_params)
return ValidationOutcome[OT].from_guard_history(call)
-
+    @deprecated(
+        """Async methods within Guard are deprecated and will be removed in 0.5.x.
+        Instead, please use `AsyncGuard()` or pass in a synchronous llm api.""",
+        category=FutureWarning,
+        stacklevel=2,
+    )
async def _call_async(
self,
llm_api: Callable[[Any], Awaitable[Any]],
@@ -719,12 +756,6 @@ async def _call_async(
Returns:
The raw text output from the LLM and the validated output.
"""
- warnings.warn(
- "Using an async LLM client is deprecated in Guard and will "
- "be removed in a future release. "
- "Please use `AsyncGuard` or a non async llm client instead.",
- DeprecationWarning,
- )
instructions_obj = instructions or self.instructions
prompt_obj = prompt or self.prompt
msg_history_obj = msg_history or []
@@ -1009,6 +1040,12 @@ def _sync_parse(
return ValidationOutcome[OT].from_guard_history(call)
+ @deprecated(
+ """Async methods within Guard are deprecated and will be removed in 0.5.x.
+        Instead, please use `AsyncGuard()` or pass in a synchronous llm api.""",
+ category=FutureWarning,
+ stacklevel=2,
+ )
async def _async_parse(
self,
llm_output: str,
diff --git a/tests/unit_tests/test_async_guard.py b/tests/unit_tests/test_async_guard.py
new file mode 100644
index 000000000..281ca0949
--- /dev/null
+++ b/tests/unit_tests/test_async_guard.py
@@ -0,0 +1,601 @@
+import openai
+import pytest
+from pydantic import BaseModel
+
+from guardrails import AsyncGuard, Rail, Validator
+from guardrails.datatypes import verify_metadata_requirements
+from guardrails.utils import args, kwargs, on_fail
+from guardrails.utils.openai_utils import OPENAI_VERSION
+from guardrails.validator_base import OnFailAction
+from guardrails.validators import ( # ReadingTime,
+ EndsWith,
+ LowerCase,
+ OneLine,
+ PassResult,
+ TwoWords,
+ UpperCase,
+ ValidLength,
+ register_validator,
+)
+
+
+@register_validator("myrequiringvalidator", data_type="string")
+class RequiringValidator(Validator):
+ required_metadata_keys = ["required_key"]
+
+ def validate(self, value, metadata):
+ return PassResult()
+
+
+@register_validator("myrequiringvalidator2", data_type="string")
+class RequiringValidator2(Validator):
+ required_metadata_keys = ["required_key2"]
+
+ def validate(self, value, metadata):
+ return PassResult()
+
+
+@pytest.mark.parametrize(
+ "spec,metadata,error_message",
+ [
+ (
+ """
+
+
+
+ """,
+ {"required_key": "a"},
+ "Missing required metadata keys: required_key",
+ ),
+ (
+ """
+
+
+
+ """,
+ {"required_key": "a", "required_key2": "b"},
+ "Missing required metadata keys: required_key, required_key2",
+ ),
+ (
+ """
+
+
+
+""",
+ {"required_key": "a"},
+ "Missing required metadata keys: required_key",
+ ),
+ ],
+)
+@pytest.mark.asyncio
+@pytest.mark.skipif(not OPENAI_VERSION.startswith("0"), reason="Only for OpenAI v0")
+async def test_required_metadata(spec, metadata, error_message):
+ guard = AsyncGuard.from_rail_string(spec)
+
+ missing_keys = verify_metadata_requirements({}, guard.output_schema.root_datatype)
+ assert set(missing_keys) == set(metadata)
+
+ not_missing_keys = verify_metadata_requirements(
+ metadata, guard.output_schema.root_datatype
+ )
+ assert not_missing_keys == []
+
+ # test sync guard
+ with pytest.raises(ValueError) as excinfo:
+ guard.parse("{}")
+ assert str(excinfo.value) == error_message
+
+ response = guard.parse("{}", metadata=metadata, num_reasks=0)
+ assert response.error is None
+
+ # test async guard
+ with pytest.raises(ValueError) as excinfo:
+ guard.parse("{}")
+ await guard.parse("{}", llm_api=openai.ChatCompletion.acreate, num_reasks=0)
+ assert str(excinfo.value) == error_message
+
+ response = await guard.parse(
+ "{}", metadata=metadata, llm_api=openai.ChatCompletion.acreate, num_reasks=0
+ )
+ assert response.error is None
+
+
+rail = Rail.from_string_validators([], "empty railspec")
+empty_rail_string = """
+
+"""
+
+
+class EmptyModel(BaseModel):
+ empty_field: str
+
+
+i_guard_none = AsyncGuard(rail)
+i_guard_two = AsyncGuard(rail, 2)
+r_guard_none = AsyncGuard.from_rail("tests/unit_tests/test_assets/empty.rail")
+r_guard_two = AsyncGuard.from_rail("tests/unit_tests/test_assets/empty.rail", 2)
+rs_guard_none = AsyncGuard.from_rail_string(empty_rail_string)
+rs_guard_two = AsyncGuard.from_rail_string(empty_rail_string, 2)
+py_guard_none = AsyncGuard.from_pydantic(output_class=EmptyModel)
+py_guard_two = AsyncGuard.from_pydantic(output_class=EmptyModel, num_reasks=2)
+s_guard_none = AsyncGuard.from_string(validators=[], description="empty railspec")
+s_guard_two = AsyncGuard.from_string(
+ validators=[], description="empty railspec", num_reasks=2
+)
+
+
+@pytest.mark.parametrize(
+ "guard,expected_num_reasks,config_num_reasks",
+ [
+ (i_guard_none, 1, None),
+ (i_guard_two, 2, None),
+ (i_guard_none, 3, 3),
+ (r_guard_none, 1, None),
+ (r_guard_two, 2, None),
+ (r_guard_none, 3, 3),
+ (rs_guard_none, 1, None),
+ (rs_guard_two, 2, None),
+ (rs_guard_none, 3, 3),
+ (py_guard_none, 1, None),
+ (py_guard_two, 2, None),
+ (py_guard_none, 3, 3),
+ (s_guard_none, 1, None),
+ (s_guard_two, 2, None),
+ (s_guard_none, 3, 3),
+ ],
+)
+def test_configure(guard: AsyncGuard, expected_num_reasks: int, config_num_reasks: int):
+ guard.configure(config_num_reasks)
+ assert guard.num_reasks == expected_num_reasks
+
+
+def guard_init_from_rail():
+ guard = AsyncGuard.from_rail("tests/unit_tests/test_assets/simple.rail")
+ assert (
+ guard.instructions.format().source.strip()
+ == "You are a helpful bot, who answers only with valid JSON"
+ )
+ assert guard.prompt.format().source.strip() == "Extract a string from the text"
+
+
+def test_use():
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use(EndsWith("a"))
+ .use(OneLine())
+ .use(LowerCase)
+ .use(TwoWords, on_fail=OnFailAction.REASK)
+ .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN)
+ )
+
+ # print(guard.__stringify__())
+ assert len(guard._validators) == 5
+
+ assert isinstance(guard._validators[0], EndsWith)
+ assert guard._validators[0]._kwargs["end"] == "a"
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.FIX
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[1], OneLine)
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[2], LowerCase)
+ assert (
+ guard._validators[2].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[3], TwoWords)
+ assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ assert isinstance(guard._validators[4], ValidLength)
+ assert guard._validators[4]._min == 0
+ assert guard._validators[4]._kwargs["min"] == 0
+ assert guard._validators[4]._max == 12
+ assert guard._validators[4]._kwargs["max"] == 12
+ assert (
+ guard._validators[4].on_fail_descriptor == OnFailAction.REFRAIN
+ ) # bc we set it
+
+ # Raises error when trying to `use` a validator on a non-string
+ with pytest.raises(RuntimeError):
+
+ class TestClass(BaseModel):
+ another_field: str
+
+ py_guard = AsyncGuard.from_pydantic(output_class=TestClass)
+ py_guard.use(
+ EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK)
+ )
+
+ # Use a combination of prompt, instructions, msg_history and output validators
+ # Should only have the output validators in the guard,
+ # everything else is in the schema
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use(LowerCase, on="prompt")
+ .use(OneLine, on="prompt")
+ .use(UpperCase, on="instructions")
+ .use(LowerCase, on="msg_history")
+ .use(
+ EndsWith, end="a", on="output"
+ ) # default on="output", still explicitly set
+ .use(
+ TwoWords, on_fail=OnFailAction.REASK
+ ) # default on="output", implicitly set
+ )
+
+ # Check schemas for prompt, instructions and msg_history validators
+ prompt_validators = guard.rail.prompt_schema.root_datatype.validators
+ assert len(prompt_validators) == 2
+ assert prompt_validators[0].__class__.__name__ == "LowerCase"
+ assert prompt_validators[1].__class__.__name__ == "OneLine"
+
+ instructions_validators = guard.rail.instructions_schema.root_datatype.validators
+ assert len(instructions_validators) == 1
+ assert instructions_validators[0].__class__.__name__ == "UpperCase"
+
+ msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators
+ assert len(msg_history_validators) == 1
+ assert msg_history_validators[0].__class__.__name__ == "LowerCase"
+
+ # Check guard for output validators
+ assert len(guard._validators) == 2 # only 2 output validators, hence 2
+
+ assert isinstance(guard._validators[0], EndsWith)
+ assert guard._validators[0]._kwargs["end"] == "a"
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.FIX
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[1], TwoWords)
+ assert guard._validators[1].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ # Test with an invalid "on" parameter, should raise a ValueError
+ with pytest.raises(ValueError):
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use(EndsWith("a"), on="response") # invalid on parameter
+ .use(OneLine, on="prompt") # valid on parameter
+ )
+
+
+def test_use_many_instances():
+ guard: AsyncGuard = AsyncGuard().use_many(
+ EndsWith("a"), OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK)
+ )
+
+ # print(guard.__stringify__())
+ assert len(guard._validators) == 4
+
+ assert isinstance(guard._validators[0], EndsWith)
+ assert guard._validators[0]._end == "a"
+ assert guard._validators[0]._kwargs["end"] == "a"
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.FIX
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[1], OneLine)
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[2], LowerCase)
+ assert (
+ guard._validators[2].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[3], TwoWords)
+ assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ # Raises error when trying to `use_many` a validator on a non-string
+ with pytest.raises(RuntimeError):
+
+ class TestClass(BaseModel):
+ another_field: str
+
+ py_guard = AsyncGuard.from_pydantic(output_class=TestClass)
+ py_guard.use_many(
+ [
+ EndsWith("a"),
+ OneLine(),
+ LowerCase(),
+ TwoWords(on_fail=OnFailAction.REASK),
+ ]
+ )
+
+ # Test with explicitly setting the "on" parameter = "output"
+ guard: AsyncGuard = AsyncGuard().use_many(
+ EndsWith("a"),
+ OneLine(),
+ LowerCase(),
+ TwoWords(on_fail=OnFailAction.REASK),
+ on="output",
+ )
+
+ assert len(guard._validators) == 4 # still 4 output validators, hence 4
+
+ assert isinstance(guard._validators[0], EndsWith)
+ assert guard._validators[0]._end == "a"
+ assert guard._validators[0]._kwargs["end"] == "a"
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.FIX
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[1], OneLine)
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[2], LowerCase)
+ assert (
+ guard._validators[2].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[3], TwoWords)
+ assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ # Test with explicitly setting the "on" parameter = "prompt"
+ guard: AsyncGuard = AsyncGuard().use_many(
+ OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="prompt"
+ )
+
+ prompt_validators = guard.rail.prompt_schema.root_datatype.validators
+ assert len(prompt_validators) == 3
+ assert prompt_validators[0].__class__.__name__ == "OneLine"
+ assert prompt_validators[1].__class__.__name__ == "LowerCase"
+ assert prompt_validators[2].__class__.__name__ == "TwoWords"
+ assert len(guard._validators) == 0 # no output validators, hence 0
+
+ # Test with explicitly setting the "on" parameter = "instructions"
+ guard: AsyncGuard = AsyncGuard().use_many(
+ OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="instructions"
+ )
+
+ instructions_validators = guard.rail.instructions_schema.root_datatype.validators
+ assert len(instructions_validators) == 3
+ assert instructions_validators[0].__class__.__name__ == "OneLine"
+ assert instructions_validators[1].__class__.__name__ == "LowerCase"
+ assert instructions_validators[2].__class__.__name__ == "TwoWords"
+ assert len(guard._validators) == 0 # no output validators, hence 0
+
+ # Test with explicitly setting the "on" parameter = "msg_history"
+ guard: AsyncGuard = AsyncGuard().use_many(
+ OneLine(), LowerCase(), TwoWords(on_fail=OnFailAction.REASK), on="msg_history"
+ )
+
+ msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators
+ assert len(msg_history_validators) == 3
+ assert msg_history_validators[0].__class__.__name__ == "OneLine"
+ assert msg_history_validators[1].__class__.__name__ == "LowerCase"
+ assert msg_history_validators[2].__class__.__name__ == "TwoWords"
+ assert len(guard._validators) == 0 # no output validators, hence 0
+
+ # Test with an invalid "on" parameter, should raise a ValueError
+ with pytest.raises(ValueError):
+ guard: AsyncGuard = AsyncGuard().use_many(
+ EndsWith("a", on_fail=OnFailAction.EXCEPTION), OneLine(), on="response"
+ )
+
+
+def test_use_many_tuple():
+ guard: AsyncGuard = AsyncGuard().use_many(
+ OneLine,
+ (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}),
+ (LowerCase, kwargs(on_fail=OnFailAction.FIX_REASK, some_other_kwarg="kwarg")),
+ (TwoWords, on_fail(OnFailAction.REASK)),
+ (ValidLength, args(0, 12), kwargs(on_fail=OnFailAction.REFRAIN)),
+ )
+
+ # print(guard.__stringify__())
+ assert len(guard._validators) == 5
+
+ assert isinstance(guard._validators[0], OneLine)
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ assert isinstance(guard._validators[1], EndsWith)
+ assert guard._validators[1]._end == "a"
+ assert guard._validators[1]._kwargs["end"] == "a"
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.EXCEPTION
+ ) # bc we set it
+
+ assert isinstance(guard._validators[2], LowerCase)
+ assert guard._validators[2]._kwargs["some_other_kwarg"] == "kwarg"
+ assert (
+ guard._validators[2].on_fail_descriptor == OnFailAction.FIX_REASK
+ ) # bc we set it via kwargs
+
+ assert isinstance(guard._validators[3], TwoWords)
+ assert guard._validators[3].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ assert isinstance(guard._validators[4], ValidLength)
+ assert guard._validators[4]._min == 0
+ assert guard._validators[4]._kwargs["min"] == 0
+ assert guard._validators[4]._max == 12
+ assert guard._validators[4]._kwargs["max"] == 12
+ assert (
+ guard._validators[4].on_fail_descriptor == OnFailAction.REFRAIN
+ ) # bc we set it
+
+ # Test with explicitly setting the "on" parameter
+ guard: AsyncGuard = AsyncGuard().use_many(
+ (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}),
+ OneLine,
+ on="output",
+ )
+
+ assert len(guard._validators) == 2 # only 2 output validators, hence 2
+
+ assert isinstance(guard._validators[0], EndsWith)
+ assert guard._validators[0]._end == "a"
+ assert guard._validators[0]._kwargs["end"] == "a"
+ assert (
+ guard._validators[0].on_fail_descriptor == OnFailAction.EXCEPTION
+ ) # bc we set it
+
+ assert isinstance(guard._validators[1], OneLine)
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.NOOP
+ ) # bc this is the default
+
+ # Test with an invalid "on" parameter, should raise a ValueError
+ with pytest.raises(ValueError):
+ guard: AsyncGuard = AsyncGuard().use_many(
+ (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}),
+ OneLine,
+ on="response",
+ )
+
+
+def test_validate():
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use(OneLine)
+ .use(
+ LowerCase(on_fail=OnFailAction.FIX), on="output"
+ ) # default on="output", still explicitly set
+ .use(TwoWords)
+ .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN)
+ )
+
+ llm_output: str = "Oh Canada" # bc it meets our criteria
+
+ response = guard.validate(llm_output)
+
+ assert response.validation_passed is True
+ assert response.validated_output == llm_output.lower()
+
+ llm_output_2 = "Star Spangled Banner" # to stick with the theme
+
+ response_2 = guard.validate(llm_output_2)
+
+ assert response_2.validation_passed is False
+ assert response_2.validated_output is None
+
+ # Test with a combination of prompt, output, instructions and msg_history validators
+ # Should still only use the output validators to validate the output
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use(OneLine, on="prompt")
+ .use(LowerCase, on="instructions")
+ .use(UpperCase, on="msg_history")
+ .use(LowerCase, on="output", on_fail=OnFailAction.FIX)
+ .use(TwoWords, on="output")
+ .use(ValidLength, 0, 12, on="output")
+ )
+
+ llm_output: str = "Oh Canada" # bc it meets our criteria
+
+ response = guard.validate(llm_output)
+
+ assert response.validation_passed is True
+ assert response.validated_output == llm_output.lower()
+
+ llm_output_2 = "Star Spangled Banner" # to stick with the theme
+
+ response_2 = guard.validate(llm_output_2)
+
+ assert response_2.validation_passed is False
+ assert response_2.validated_output is None
+
+
+def test_use_and_use_many():
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use_many(OneLine(), LowerCase(), on="prompt")
+ .use(UpperCase, on="instructions")
+ .use(LowerCase, on="msg_history")
+ .use_many(
+ TwoWords(on_fail=OnFailAction.REASK),
+ ValidLength(0, 12, on_fail=OnFailAction.REFRAIN),
+ on="output",
+ )
+ )
+
+ # Check schemas for prompt, instructions and msg_history validators
+ prompt_validators = guard.rail.prompt_schema.root_datatype.validators
+ assert len(prompt_validators) == 2
+ assert prompt_validators[0].__class__.__name__ == "OneLine"
+ assert prompt_validators[1].__class__.__name__ == "LowerCase"
+
+ instructions_validators = guard.rail.instructions_schema.root_datatype.validators
+ assert len(instructions_validators) == 1
+ assert instructions_validators[0].__class__.__name__ == "UpperCase"
+
+ msg_history_validators = guard.rail.msg_history_schema.root_datatype.validators
+ assert len(msg_history_validators) == 1
+ assert msg_history_validators[0].__class__.__name__ == "LowerCase"
+
+ # Check guard for output validators
+ assert len(guard._validators) == 2 # only 2 output validators, hence 2
+
+ assert isinstance(guard._validators[0], TwoWords)
+ assert guard._validators[0].on_fail_descriptor == OnFailAction.REASK # bc we set it
+
+ assert isinstance(guard._validators[1], ValidLength)
+ assert guard._validators[1]._min == 0
+ assert guard._validators[1]._kwargs["min"] == 0
+ assert guard._validators[1]._max == 12
+ assert guard._validators[1]._kwargs["max"] == 12
+ assert (
+ guard._validators[1].on_fail_descriptor == OnFailAction.REFRAIN
+ ) # bc we set it
+
+ # Test with an invalid "on" parameter, should raise a ValueError
+ with pytest.raises(ValueError):
+ guard: AsyncGuard = (
+ AsyncGuard()
+ .use_many(OneLine(), LowerCase(), on="prompt")
+ .use(UpperCase, on="instructions")
+ .use(LowerCase, on="msg_history")
+ .use_many(
+ TwoWords(on_fail=OnFailAction.REASK),
+ ValidLength(0, 12, on_fail=OnFailAction.REFRAIN),
+ on="response", # invalid "on" parameter
+ )
+ )
+
+
+# def test_call():
+# five_seconds = 5 / 60
+# response = AsyncGuard().use_many(
+# ReadingTime(five_seconds, on_fail=OnFailAction.EXCEPTION),
+# OneLine,
+# (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}),
+# (LowerCase, kwargs(on_fail=OnFailAction.FIX_REASK, some_other_kwarg="kwarg")),
+# (TwoWords, on_fail(OnFailAction.REASK)),
+# (ValidLength, args(0, 12), kwargs(on_fail=OnFailAction.REFRAIN)),
+# )("Oh Canada")
+
+# assert response.validation_passed is True
+# assert response.validated_output == "oh canada"
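A minimal usage sketch of the fluent API the tests above exercise. The validator and OnFailAction import paths below are assumptions (the test module's own imports sit outside this hunk), so treat the snippet as illustrative rather than canonical:

    # Sketch only -- import locations are assumed, not taken from this diff.
    from guardrails import AsyncGuard
    from guardrails.validators import LowerCase, OneLine, TwoWords, OnFailAction  # assumed paths

    guard: AsyncGuard = (
        AsyncGuard()
        .use(OneLine, on="prompt")                  # class form, attaches to the prompt schema
        .use(LowerCase(on_fail=OnFailAction.FIX))   # instance form, default on="output"
        .use_many(TwoWords(on_fail=OnFailAction.REASK), on="output")
    )
    assert len(guard._validators) == 2  # only the two output validators live on the guard itself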
From 58f8e2d4130a5903934181f3a54d24d4e63b0516 Mon Sep 17 00:00:00 2001
From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com>
Date: Mon, 13 May 2024 13:41:59 -0700
Subject: [PATCH 3/9] removing extra deprecation warning
---
guardrails/guard.py | 6 ------
1 file changed, 6 deletions(-)
diff --git a/guardrails/guard.py b/guardrails/guard.py
index 84b4858ba..a0f51f7d0 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -712,12 +712,6 @@ def _call_sync(
call = runner(call_log=call_log, prompt_params=prompt_params)
return ValidationOutcome[OT].from_guard_history(call)
- @deprecated(
- """Async methods within Guard are deprecated and will be removed in 0.5.x.
- Instead, please use `AsyncGuard() or pass in a sequential llm api.""",
- category=FutureWarning,
- stacklevel=2,
- )
@deprecated(
"""Async methods within Guard are deprecated and will be removed in 0.5.x.
Instead, please use `AsyncGuard() or pass in a sequential llm api.""",
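A small sketch of why the duplicated decorator was worth removing: typing_extensions.deprecated wraps the callable, so stacking it twice makes every call emit two FutureWarnings instead of one (the message text below is abbreviated, not the project's wording):

    import warnings
    from typing_extensions import deprecated

    @deprecated("use AsyncGuard instead", category=FutureWarning)
    @deprecated("use AsyncGuard instead", category=FutureWarning)
    def call_async_stub():
        return "ok"

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        call_async_stub()
    assert len(caught) == 2  # one warning per stacked decorator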
From 1442c485fa1134d3b3e7b67be78faca042fc1e47 Mon Sep 17 00:00:00 2001
From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com>
Date: Mon, 13 May 2024 13:43:02 -0700
Subject: [PATCH 4/9] removing unused test
---
tests/unit_tests/test_async_guard.py | 15 ---------------
1 file changed, 15 deletions(-)
diff --git a/tests/unit_tests/test_async_guard.py b/tests/unit_tests/test_async_guard.py
index 281ca0949..1e02f750c 100644
--- a/tests/unit_tests/test_async_guard.py
+++ b/tests/unit_tests/test_async_guard.py
@@ -584,18 +584,3 @@ def test_use_and_use_many():
on="response", # invalid "on" parameter
)
)
-
-
-# def test_call():
-# five_seconds = 5 / 60
-# response = AsyncGuard().use_many(
-# ReadingTime(five_seconds, on_fail=OnFailAction.EXCEPTION),
-# OneLine,
-# (EndsWith, ["a"], {"on_fail": OnFailAction.EXCEPTION}),
-# (LowerCase, kwargs(on_fail=OnFailAction.FIX_REASK, some_other_kwarg="kwarg")),
-# (TwoWords, on_fail(OnFailAction.REASK)),
-# (ValidLength, args(0, 12), kwargs(on_fail=OnFailAction.REFRAIN)),
-# )("Oh Canada")
-
-# assert response.validation_passed is True
-# assert response.validated_output == "oh canada"
From f80285a1b9c490a12e5c7b39be8836fbe89de19f Mon Sep 17 00:00:00 2001
From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com>
Date: Mon, 13 May 2024 14:18:32 -0700
Subject: [PATCH 5/9] fixing init
---
guardrails/__init__.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/guardrails/__init__.py b/guardrails/__init__.py
index 8883668f0..b1d081fca 100644
--- a/guardrails/__init__.py
+++ b/guardrails/__init__.py
@@ -11,6 +11,7 @@
__all__ = [
"Guard",
+ "AsyncGuard",
"PromptCallableBase",
"Rail",
"Validator",
From af6b72eb9ba95042179f92266fd7b9f9c3a57439 Mon Sep 17 00:00:00 2001
From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com>
Date: Thu, 16 May 2024 15:07:35 -0700
Subject: [PATCH 6/9] Updating Async guard to inherit from guard
---
guardrails/async_guard.py | 886 +-------------------------------------
1 file changed, 13 insertions(+), 873 deletions(-)
diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py
index 09affb0fc..607e88776 100644
--- a/guardrails/async_guard.py
+++ b/guardrails/async_guard.py
@@ -1,80 +1,37 @@
import asyncio
import contextvars
import inspect
-import json
-import os
-import warnings
-from copy import deepcopy
-from string import Template
from typing import (
Any,
Awaitable,
Callable,
Dict,
- Generic,
Iterable,
List,
Optional,
- Sequence,
- Tuple,
- Type,
Union,
- cast,
overload,
)
-from guardrails_api_client.models import AnyObject
-from guardrails_api_client.models import Guard as GuardModel
-from guardrails_api_client.models import (
- History,
- HistoryEvent,
- ValidatePayload,
- ValidationOutput,
-)
-from guardrails_api_client.types import UNSET
-from langchain_core.messages import BaseMessage
-from langchain_core.runnables import Runnable, RunnableConfig
-from pydantic import BaseModel
-from pydantic.version import VERSION as PYDANTIC_VERSION
-from typing_extensions import deprecated
-from guardrails.api_client import GuardrailsApiClient
-from guardrails.classes import OT, InputType, ValidationOutcome
-from guardrails.classes.credentials import Credentials
-from guardrails.classes.generic import Stack
+from guardrails import Guard
+from guardrails.classes import OT, ValidationOutcome
from guardrails.classes.history import Call
from guardrails.classes.history.call_inputs import CallInputs
-from guardrails.classes.history.inputs import Inputs
-from guardrails.classes.history.iteration import Iteration
-from guardrails.classes.history.outputs import Outputs
-from guardrails.errors import ValidationError
from guardrails.llm_providers import (
get_async_llm_ask,
- get_llm_api_enum,
- get_llm_ask,
model_is_supported_server_side,
)
-from guardrails.logger import logger, set_scope
-from guardrails.prompt import Instructions, Prompt
-from guardrails.rail import Rail
+from guardrails.logger import set_scope
from guardrails.run import AsyncRunner
-from guardrails.schema import Schema, StringSchema
from guardrails.stores.context import (
- Tracer,
- get_call_kwarg,
- get_tracer_context,
set_call_kwargs,
set_tracer,
set_tracer_context,
)
-from guardrails.utils.hub_telemetry_utils import HubTelemetry
-from guardrails.utils.llm_response import LLMResponse
-from guardrails.utils.reask_utils import FieldReAsk
-from guardrails.utils.validator_utils import get_validator
-from guardrails.validator_base import FailResult, Validator
-class AsyncGuard(Runnable, Generic[OT]):
+class AsyncGuard(Guard):
"""The Guard class.
This class one of the main entry point for using Guardrails. It is
@@ -91,370 +48,6 @@ class AsyncGuard(Runnable, Generic[OT]):
the LLM and the validated output stream.
"""
- _tracer = None
- _tracer_context = None
- _hub_telemetry = None
- _guard_id = None
- _user_id = None
- _validators: List[Validator]
- _api_client: Optional[GuardrailsApiClient] = None
-
- def __init__(
- self,
- rail: Optional[Rail] = None,
- num_reasks: Optional[int] = None,
- base_model: Optional[
- Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
- ] = None,
- tracer: Optional[Tracer] = None,
- *,
- name: Optional[str] = None,
- description: Optional[str] = None,
- ):
- """Initialize the Guard with optional Rail instance, num_reasks, and
- base_model."""
- if not rail:
- rail = (
- Rail.from_pydantic(base_model)
- if base_model
- else Rail.from_string_validators([])
- )
- self.rail = rail
- self.num_reasks = num_reasks
- # TODO: Support a sink for history so that it is not solely held in memory
- self.history: Stack[Call] = Stack()
- self.base_model = base_model
- self._set_tracer(tracer)
-
- credentials = Credentials.from_rc_file(logger)
-
- # Get unique id of user from credentials
- self._user_id = credentials.id or ""
-
- # Get metrics opt-out from credentials
- self._disable_tracer = credentials.no_metrics
-
- # Get id of guard object (that is unique)
- self._guard_id = id(self) # id of guard object; not the class
-
- # Initialize Hub Telemetry singleton and get the tracer
- # if it is not disabled
- if not self._disable_tracer:
- self._hub_telemetry = HubTelemetry()
- self._validators = []
-
- # Gaurdrails As A Service Initialization
- self.description = description
- self.name = name
-
- api_key = os.environ.get("GUARDRAILS_API_KEY")
- if api_key is not None:
- if self.name is None:
- self.name = f"gr-{str(self._guard_id)}"
- logger.warn("Warning: No name passed to guard!")
- logger.warn(
- "Use this auto-generated name to re-use this guard: {name}".format(
- name=self.name
- )
- )
- self._api_client = GuardrailsApiClient(api_key=api_key)
- self.upsert_guard()
-
- @property
- def prompt_schema(self) -> Optional[StringSchema]:
- """Return the input schema."""
- return self.rail.prompt_schema
-
- @property
- def instructions_schema(self) -> Optional[StringSchema]:
- """Return the input schema."""
- return self.rail.instructions_schema
-
- @property
- def msg_history_schema(self) -> Optional[StringSchema]:
- """Return the input schema."""
- return self.rail.msg_history_schema
-
- @property
- def output_schema(self) -> Schema:
- """Return the output schema."""
- return self.rail.output_schema
-
- @property
- def instructions(self) -> Optional[Instructions]:
- """Return the instruction-prompt."""
- return self.rail.instructions
-
- @property
- def prompt(self) -> Optional[Prompt]:
- """Return the prompt."""
- return self.rail.prompt
-
- @property
- def raw_prompt(self) -> Optional[Prompt]:
- """Return the prompt, alias for `prompt`."""
- return self.prompt
-
- @property
- def base_prompt(self) -> Optional[str]:
- """Return the base prompt i.e. prompt.source."""
- if self.prompt is None:
- return None
- return self.prompt.source
-
- @property
- def reask_prompt(self) -> Optional[Prompt]:
- """Return the reask prompt."""
- return self.output_schema.reask_prompt_template
-
- @reask_prompt.setter
- def reask_prompt(self, reask_prompt: Optional[str]):
- """Set the reask prompt."""
- self.output_schema.reask_prompt_template = reask_prompt
-
- @property
- def reask_instructions(self) -> Optional[Instructions]:
- """Return the reask prompt."""
- return self.output_schema.reask_instructions_template
-
- @reask_instructions.setter
- def reask_instructions(self, reask_instructions: Optional[str]):
- """Set the reask prompt."""
- self.output_schema.reask_instructions_template = reask_instructions
-
- def configure(
- self,
- num_reasks: Optional[int] = None,
- ):
- """Configure the Guard."""
- self.num_reasks = (
- num_reasks
- if num_reasks is not None
- else self.num_reasks
- if self.num_reasks is not None
- else 1
- )
-
- def _set_tracer(self, tracer: Optional[Tracer] = None) -> None:
- self._tracer = tracer
- set_tracer(tracer)
- set_tracer_context()
- self._tracer_context = get_tracer_context()
-
- @classmethod
- def from_rail(
- cls,
- rail_file: str,
- num_reasks: Optional[int] = None,
- tracer: Optional[Tracer] = None,
- *,
- name: Optional[str] = None,
- description: Optional[str] = None,
- ):
- """Create a Schema from a `.rail` file.
-
- Args:
- rail_file: The path to the `.rail` file.
- num_reasks: The max times to re-ask the LLM for invalid output.
-
- Returns:
- An instance of the `Guard` class.
- """
-
- # We have to set the tracer in the ContextStore before the Rail,
- # and therefore the Validators, are initialized
- cls._set_tracer(cls, tracer) # type: ignore
-
- rail = Rail.from_file(rail_file)
- if rail.output_type == "str":
- return cast(
- AsyncGuard[str],
- cls(
- rail=rail,
- num_reasks=num_reasks,
- tracer=tracer,
- name=name,
- description=description,
- ),
- )
- elif rail.output_type == "list":
- return cast(
- AsyncGuard[List],
- cls(
- rail=rail,
- num_reasks=num_reasks,
- tracer=tracer,
- name=name,
- description=description,
- ),
- )
- return cast(
- AsyncGuard[Dict],
- cls(
- rail=rail,
- num_reasks=num_reasks,
- tracer=tracer,
- name=name,
- description=description,
- ),
- )
-
- @classmethod
- def from_rail_string(
- cls,
- rail_string: str,
- num_reasks: Optional[int] = None,
- tracer: Optional[Tracer] = None,
- *,
- name: Optional[str] = None,
- description: Optional[str] = None,
- ):
- """Create a Schema from a `.rail` string.
-
- Args:
- rail_string: The `.rail` string.
- num_reasks: The max times to re-ask the LLM for invalid output.
-
- Returns:
- An instance of the `Guard` class.
- """
- # We have to set the tracer in the ContextStore before the Rail,
- # and therefore the Validators, are initialized
- cls._set_tracer(cls, tracer) # type: ignore
-
- rail = Rail.from_string(rail_string)
- if rail.output_type == "str":
- return cast(
- AsyncGuard[str],
- cls(
- rail=rail,
- num_reasks=num_reasks,
- tracer=tracer,
- name=name,
- description=description,
- ),
- )
- elif rail.output_type == "list":
- return cast(
- AsyncGuard[List],
- cls(
- rail=rail,
- num_reasks=num_reasks,
- tracer=tracer,
- name=name,
- description=description,
- ),
- )
- return cast(
- AsyncGuard[Dict],
- cls(
- rail=rail,
- num_reasks=num_reasks,
- tracer=tracer,
- name=name,
- description=description,
- ),
- )
-
- @classmethod
- def from_pydantic(
- cls,
- output_class: Union[Type[BaseModel], Type[List[Type[BaseModel]]]],
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- num_reasks: Optional[int] = None,
- reask_prompt: Optional[str] = None,
- reask_instructions: Optional[str] = None,
- tracer: Optional[Tracer] = None,
- *,
- name: Optional[str] = None,
- description: Optional[str] = None,
- ):
- """Create a Guard instance from a Pydantic model and prompt."""
- if PYDANTIC_VERSION.startswith("1"):
- warnings.warn(
- """Support for Pydantic v1.x is deprecated and will be removed in
- Guardrails 0.5.x. Please upgrade to the latest Pydantic v2.x to
- continue receiving future updates and support.""",
- FutureWarning,
- )
- # We have to set the tracer in the ContextStore before the Rail,
- # and therefore the Validators, are initialized
- cls._set_tracer(cls, tracer) # type: ignore
-
- rail = Rail.from_pydantic(
- output_class=output_class,
- prompt=prompt,
- instructions=instructions,
- reask_prompt=reask_prompt,
- reask_instructions=reask_instructions,
- )
- if rail.output_type == "list":
- return cast(
- AsyncGuard[List],
- cls(rail, num_reasks=num_reasks, base_model=output_class),
- )
- return cast(
- AsyncGuard[Dict],
- cls(
- rail,
- num_reasks=num_reasks,
- base_model=output_class,
- tracer=tracer,
- name=name,
- description=description,
- ),
- )
-
- @classmethod
- def from_string(
- cls,
- validators: Sequence[Validator],
- description: Optional[str] = None,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- reask_prompt: Optional[str] = None,
- reask_instructions: Optional[str] = None,
- num_reasks: Optional[int] = None,
- tracer: Optional[Tracer] = None,
- *,
- name: Optional[str] = None,
- guard_description: Optional[str] = None,
- ):
- """Create a Guard instance for a string response with prompt,
- instructions, and validations.
-
- Args:
- validators: (List[Validator]): The list of validators to apply to the string output.
- description (str, optional): A description for the string to be generated. Defaults to None.
- prompt (str, optional): The prompt used to generate the string. Defaults to None.
- instructions (str, optional): Instructions for chat models. Defaults to None.
- reask_prompt (str, optional): An alternative prompt to use during reasks. Defaults to None.
- reask_instructions (str, optional): Alternative instructions to use during reasks. Defaults to None.
- num_reasks (int, optional): The max times to re-ask the LLM for invalid output.
- """ # noqa
-
- cls._set_tracer(cls, tracer) # type: ignore
-
- rail = Rail.from_string_validators(
- validators=validators,
- description=description,
- prompt=prompt,
- instructions=instructions,
- reask_prompt=reask_prompt,
- reask_instructions=reask_instructions,
- )
- return cast(
- AsyncGuard[str],
- cls(
- rail,
- num_reasks=num_reasks,
- tracer=tracer,
- name=name,
- description=guard_description,
- ),
- )
-
@overload
def __call__(
self,
@@ -707,71 +300,6 @@ def _call_async(
)
return ValidationOutcome[OT].from_guard_history(call)
- def __repr__(self):
- return f"Guard(RAIL={self.rail})"
-
- def __rich_repr__(self):
- yield "RAIL", self.rail
-
- def __stringify__(self):
- if self.rail and self.rail.output_type == "str":
- template = Template(
- """
- Guard {
- validators: [
- ${validators}
- ]
- }
- """
- )
- return template.safe_substitute(
- {
- "validators": ",\n".join(
- [v.__stringify__() for v in self._validators]
- )
- }
- )
- return self.__repr__()
-
- @overload
- def parse(
- self,
- llm_output: str,
- metadata: Optional[Dict] = None,
- llm_api: None = None,
- num_reasks: Optional[int] = None,
- prompt_params: Optional[Dict] = None,
- full_schema_reask: Optional[bool] = None,
- *args,
- **kwargs,
- ) -> ValidationOutcome[OT]: ...
-
- @overload
- def parse(
- self,
- llm_output: str,
- metadata: Optional[Dict] = None,
- llm_api: Callable[[Any], Awaitable[Any]] = ...,
- num_reasks: Optional[int] = None,
- prompt_params: Optional[Dict] = None,
- full_schema_reask: Optional[bool] = None,
- *args,
- **kwargs,
- ) -> Awaitable[ValidationOutcome[OT]]: ...
-
- @overload
- def parse(
- self,
- llm_output: str,
- metadata: Optional[Dict] = None,
- llm_api: Optional[Callable] = None,
- num_reasks: Optional[int] = None,
- prompt_params: Optional[Dict] = None,
- full_schema_reask: Optional[bool] = None,
- *args,
- **kwargs,
- ) -> ValidationOutcome[OT]: ...
-
def parse(
self,
llm_output: str,
@@ -881,10 +409,13 @@ def __parse(
**kwargs,
)
- # If the LLM API is async, return a coroutine
- # FIXME: Add null llm_api support becuase we want to be able to call on
- # static text (e.g. validate("foobar")). Is there a better way to do this?
- if not llm_api or inspect.iscoroutinefunction(llm_api):
+ # FIXME: checking not llm_api because it can still fall back on defaults and
+ # function as expected. We should handle this better.
+ if (
+ not llm_api
+ or inspect.iscoroutinefunction(llm_api)
+ or inspect.isasyncgenfunction(llm_api)
+ ):
return asyncio.run(
self._async_parse(
llm_output,
@@ -901,7 +432,8 @@ def __parse(
else:
raise NotImplementedError(
"AsyncGuard does not support non-async LLM APIs. "
- "Please use the sync API Guard or supply an async LLM API."
+ "Please use the synchronous API Guard or supply an asynchronous "
+ "LLM API."
)
guard_context = contextvars.Context()
@@ -959,395 +491,3 @@ async def _async_parse(
call = await runner.async_run(call_log=call_log, prompt_params=prompt_params)
return ValidationOutcome[OT].from_guard_history(call)
-
- @deprecated(
- """The `with_prompt_validation` method is deprecated,
- and will be removed in 0.5.x. Instead, please use
- `Guard().use(YourValidator, on='prompt')`.""",
- category=FutureWarning,
- stacklevel=2,
- )
- def with_prompt_validation(
- self,
- validators: Sequence[Validator],
- ):
- """Add prompt validation to the Guard.
-
- Args:
- validators: The validators to add to the prompt.
- """
- if self.rail.prompt_schema:
- warnings.warn("Overriding existing prompt validators.")
- schema = StringSchema.from_string(
- validators=validators,
- )
- self.rail.prompt_schema = schema
- return self
-
- @deprecated(
- """The `with_instructions_validation` method is deprecated,
- and will be removed in 0.5.x. Instead, please use
- `Guard().use(YourValidator, on='instructions')`.""",
- category=FutureWarning,
- stacklevel=2,
- )
- def with_instructions_validation(
- self,
- validators: Sequence[Validator],
- ):
- """Add instructions validation to the Guard.
-
- Args:
- validators: The validators to add to the instructions.
- """
- if self.rail.instructions_schema:
- warnings.warn("Overriding existing instructions validators.")
- schema = StringSchema.from_string(
- validators=validators,
- )
- self.rail.instructions_schema = schema
- return self
-
- @deprecated(
- """The `with_msg_history_validation` method is deprecated,
- and will be removed in 0.5.x. Instead, please use
- `Guard().use(YourValidator, on='msg_history')`.""",
- category=FutureWarning,
- stacklevel=2,
- )
- def with_msg_history_validation(
- self,
- validators: Sequence[Validator],
- ):
- """Add msg_history validation to the Guard.
-
- Args:
- validators: The validators to add to the msg_history.
- """
- if self.rail.msg_history_schema:
- warnings.warn("Overriding existing msg_history validators.")
- schema = StringSchema.from_string(
- validators=validators,
- )
- self.rail.msg_history_schema = schema
- return self
-
- def __add_validator(self, validator: Validator, on: str = "output"):
- # Only available for string output types
- if self.rail.output_type != "str":
- raise RuntimeError(
- "The `use` method is only available for string output types."
- )
-
- if on == "prompt":
- # If the prompt schema exists, add the validator to it
- if self.rail.prompt_schema:
- self.rail.prompt_schema.root_datatype.validators.append(validator)
- else:
- # Otherwise, create a new schema with the validator
- schema = StringSchema.from_string(
- validators=[validator],
- )
- self.rail.prompt_schema = schema
- elif on == "instructions":
- # If the instructions schema exists, add the validator to it
- if self.rail.instructions_schema:
- self.rail.instructions_schema.root_datatype.validators.append(validator)
- else:
- # Otherwise, create a new schema with the validator
- schema = StringSchema.from_string(
- validators=[validator],
- )
- self.rail.instructions_schema = schema
- elif on == "msg_history":
- # If the msg_history schema exists, add the validator to it
- if self.rail.msg_history_schema:
- self.rail.msg_history_schema.root_datatype.validators.append(validator)
- else:
- # Otherwise, create a new schema with the validator
- schema = StringSchema.from_string(
- validators=[validator],
- )
- self.rail.msg_history_schema = schema
- elif on == "output":
- self._validators.append(validator)
- self.rail.output_schema.root_datatype.validators.append(validator)
- else:
- raise ValueError(
- """Invalid value for `on`. Must be one of the following:
- 'output', 'prompt', 'instructions', 'msg_history'."""
- )
-
- @overload
- def use(self, validator: Validator, *, on: str = "output") -> "AsyncGuard": ...
-
- @overload
- def use(
- self, validator: Type[Validator], *args, on: str = "output", **kwargs
- ) -> "AsyncGuard": ...
-
- def use(
- self,
- validator: Union[Validator, Type[Validator]],
- *args,
- on: str = "output",
- **kwargs,
- ) -> "AsyncGuard":
- """Use a validator to validate either of the following:
- - The output of an LLM request
- - The prompt
- - The instructions
- - The message history
-
- *Note*: For on="output", `use` is only available for string output types.
-
- Args:
- validator: The validator to use. Either the class or an instance.
- on: The part of the LLM request to validate. Defaults to "output".
- """
- hydrated_validator = get_validator(validator, *args, **kwargs)
- self.__add_validator(hydrated_validator, on=on)
- return self
-
- @overload
- def use_many(self, *validators: Validator, on: str = "output") -> "AsyncGuard": ...
-
- @overload
- def use_many(
- self,
- *validators: Tuple[
- Type[Validator],
- Optional[Union[List[Any], Dict[str, Any]]],
- Optional[Dict[str, Any]],
- ],
- on: str = "output",
- ) -> "AsyncGuard": ...
-
- def use_many(
- self,
- *validators: Union[
- Validator,
- Tuple[
- Type[Validator],
- Optional[Union[List[Any], Dict[str, Any]]],
- Optional[Dict[str, Any]],
- ],
- ],
- on: str = "output",
- ) -> "AsyncGuard":
- """Use a validator to validate results of an LLM request.
-
- *Note*: `use_many` is only available for string output types.
- """
- if self.rail.output_type != "str":
- raise RuntimeError(
- "The `use_many` method is only available for string output types."
- )
-
- # Loop through the validators
- for v in validators:
- hydrated_validator = get_validator(v)
- self.__add_validator(hydrated_validator, on=on)
- return self
-
- def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
- if (
- not self.rail
- or self.rail.output_schema.root_datatype.validators != self._validators
- ):
- self.rail = Rail.from_string_validators(
- validators=self._validators,
- prompt=self.prompt.source if self.prompt else None,
- instructions=self.instructions.source if self.instructions else None,
- reask_prompt=self.reask_prompt.source if self.reask_prompt else None,
- reask_instructions=self.reask_instructions.source
- if self.reask_instructions
- else None,
- )
-
- return self.parse(llm_output=llm_output, *args, **kwargs)
-
- # No call support for this until
- # https://github.com/guardrails-ai/guardrails/pull/525 is merged
- # def __call__(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
- # return self.validate(llm_output, *args, **kwargs)
-
- def invoke(
- self, input: InputType, config: Optional[RunnableConfig] = None
- ) -> InputType:
- output = BaseMessage(content="", type="")
- str_input = None
- input_is_chat_message = False
- if isinstance(input, BaseMessage):
- input_is_chat_message = True
- str_input = str(input.content)
- output = deepcopy(input)
- else:
- str_input = str(input)
-
- response = self.validate(str_input)
-
- validated_output = response.validated_output
- if not validated_output:
- raise ValidationError(
- (
- "The response from the LLM failed validation!"
- "See `guard.history` for more details."
- )
- )
-
- if isinstance(validated_output, Dict):
- validated_output = json.dumps(validated_output)
-
- if input_is_chat_message:
- output.content = validated_output
- return cast(InputType, output)
- return cast(InputType, validated_output)
-
- def _to_request(self) -> Dict:
- return {
- "name": self.name,
- "description": self.description,
- "railspec": self.rail._to_request(),
- "numReasks": self.num_reasks,
- }
-
- def upsert_guard(self):
- if self._api_client:
- guard_dict = self._to_request()
- self._api_client.upsert_guard(GuardModel.from_dict(guard_dict))
- else:
- raise ValueError("Guard does not have an api client!")
-
- def _call_server(
- self,
- *args,
- llm_output: Optional[str] = None,
- llm_api: Optional[Callable] = None,
- num_reasks: Optional[int] = None,
- prompt_params: Optional[Dict] = None,
- metadata: Optional[Dict] = {},
- full_schema_reask: Optional[bool] = True,
- call_log: Optional[Call],
- # prompt: Optional[str],
- # instructions: Optional[str],
- # msg_history: Optional[List[Dict]],
- **kwargs,
- ):
- if self._api_client:
- payload: Dict[str, Any] = {"args": list(args)}
- payload.update(**kwargs)
- if llm_output is not None:
- payload["llmOutput"] = llm_output
- if num_reasks is not None:
- payload["numReasks"] = num_reasks
- if prompt_params is not None:
- payload["promptParams"] = prompt_params
- if llm_api is not None:
- payload["llmApi"] = get_llm_api_enum(llm_api)
- # TODO: get enum for llm_api
- validation_output: Optional[ValidationOutput] = self._api_client.validate(
- guard=self, # type: ignore
- payload=ValidatePayload.from_dict(payload),
- openai_api_key=get_call_kwarg("api_key"),
- )
-
- if not validation_output:
- return ValidationOutcome[OT](
- raw_llm_output=None,
- validated_output=None,
- validation_passed=False,
- error="The response from the server was empty!",
- )
-
- call_log = call_log or Call()
- if llm_api is not None:
- llm_api = get_llm_ask(llm_api)
- if asyncio.iscoroutinefunction(llm_api):
- llm_api = get_async_llm_ask(llm_api)
- session_history = (
- validation_output.session_history
- if validation_output is not None and validation_output.session_history
- else []
- )
- history: History
- for history in session_history:
- history_events: Optional[List[HistoryEvent]] = ( # type: ignore
- history.history if history.history != UNSET else None
- )
- if history_events is None:
- continue
-
- iterations = [
- Iteration(
- inputs=Inputs(
- llm_api=llm_api,
- llm_output=llm_output,
- instructions=(
- Instructions(h.instructions) if h.instructions else None
- ),
- prompt=(
- Prompt(h.prompt.source) # type: ignore
- if h.prompt is not None and h.prompt != UNSET
- else None
- ),
- prompt_params=prompt_params,
- num_reasks=(num_reasks or 0),
- metadata=metadata,
- full_schema_reask=full_schema_reask,
- ),
- outputs=Outputs(
- llm_response_info=LLMResponse(
- output=h.output # type: ignore
- ),
- raw_output=h.output,
- parsed_output=(
- h.parsed_output.to_dict()
- if isinstance(h.parsed_output, AnyObject)
- else h.parsed_output
- ),
- validation_output=(
- h.validated_output.to_dict()
- if isinstance(h.validated_output, AnyObject)
- else h.validated_output
- ),
- reasks=list(
- [
- FieldReAsk(
- incorrect_value=r.to_dict().get(
- "incorrect_value"
- ),
- path=r.to_dict().get("path"),
- fail_results=[
- FailResult(
- error_message=r.to_dict().get(
- "error_message"
- ),
- fix_value=r.to_dict().get("fix_value"),
- )
- ],
- )
- for r in h.reasks # type: ignore
- ]
- if h.reasks != UNSET
- else []
- ),
- ),
- )
- for h in history_events
- ]
- call_log.iterations.extend(iterations)
- if self.history.length == 0:
- self.history.push(call_log)
-
- # Our interfaces are too different for this to work right now.
- # Once we move towards shared interfaces for both the open source
- # and the api we can re-enable this.
- # return ValidationOutcome[OT].from_guard_history(call_log)
- return ValidationOutcome[OT](
- raw_llm_output=validation_output.raw_llm_response, # type: ignore
- validated_output=cast(OT, validation_output.validated_output),
- validation_passed=validation_output.result,
- )
- else:
- raise ValueError("Guard does not have an api client!")
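The widened check in __parse above is the detail that enables streaming: an async streaming LLM callable is an async generator function, which inspect.iscoroutinefunction does not report. A minimal sketch of the distinction (the function names here are hypothetical):

    import inspect

    async def chat_completion(prompt: str) -> str:   # ordinary async LLM call
        return "response"

    async def chat_completion_stream(prompt: str):   # async streaming LLM call
        yield "partial "
        yield "response"

    assert inspect.iscoroutinefunction(chat_completion)
    assert not inspect.iscoroutinefunction(chat_completion_stream)
    assert inspect.isasyncgenfunction(chat_completion_stream)  # why the extra check was added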
From 64f2bd185917602b09d14e7ffd852bc8f1073560 Mon Sep 17 00:00:00 2001
From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com>
Date: Thu, 16 May 2024 15:09:01 -0700
Subject: [PATCH 7/9] fixing sequential -> synchronous in guard deprecations
---
guardrails/guard.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/guardrails/guard.py b/guardrails/guard.py
index a0f51f7d0..d5f25ba7e 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -714,7 +714,7 @@ def _call_sync(
@deprecated(
"""Async methods within Guard are deprecated and will be removed in 0.5.x.
- Instead, please use `AsyncGuard() or pass in a sequential llm api.""",
+ Instead, please use `AsyncGuard()` or pass in a synchronous LLM API.""",
category=FutureWarning,
stacklevel=2,
)
@@ -1036,7 +1036,7 @@ def _sync_parse(
@deprecated(
"""Async methods within Guard are deprecated and will be removed in 0.5.x.
- Instead, please use `AsyncGuard() or pass in a sequential llm api.""",
+ Instead, please use `AsyncGuard()` or pass in a synchronous LLM API.""",
category=FutureWarning,
stacklevel=2,
)
@@ -1473,3 +1473,4 @@ def _call_server(
)
else:
raise ValueError("Guard does not have an api client!")
+ raise ValueError("Guard does not have an api client!")
From 88cd5dcf5e8acfde3e0245086a0384f750801436 Mon Sep 17 00:00:00 2001
From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com>
Date: Thu, 16 May 2024 16:40:31 -0700
Subject: [PATCH 8/9] __call__ is now async to match standard llm apis rather
than handling internally
---
guardrails/async_guard.py | 88 +++++++---------------------
tests/unit_tests/test_async_guard.py | 13 ++--
2 files changed, 28 insertions(+), 73 deletions(-)
diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py
index 607e88776..e689ac698 100644
--- a/guardrails/async_guard.py
+++ b/guardrails/async_guard.py
@@ -1,4 +1,3 @@
-import asyncio
import contextvars
import inspect
from typing import (
@@ -10,25 +9,16 @@
List,
Optional,
Union,
- overload,
)
-
from guardrails import Guard
from guardrails.classes import OT, ValidationOutcome
from guardrails.classes.history import Call
from guardrails.classes.history.call_inputs import CallInputs
-from guardrails.llm_providers import (
- get_async_llm_ask,
- model_is_supported_server_side,
-)
+from guardrails.llm_providers import get_async_llm_ask, model_is_supported_server_side
from guardrails.logger import set_scope
from guardrails.run import AsyncRunner
-from guardrails.stores.context import (
- set_call_kwargs,
- set_tracer,
- set_tracer_context,
-)
+from guardrails.stores.context import set_call_kwargs, set_tracer, set_tracer_context
class AsyncGuard(Guard):
@@ -48,38 +38,7 @@ class AsyncGuard(Guard):
the LLM and the validated output stream.
"""
- @overload
- def __call__(
- self,
- llm_api: Callable,
- prompt_params: Optional[Dict] = None,
- num_reasks: Optional[int] = None,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
- metadata: Optional[Dict] = None,
- full_schema_reask: Optional[bool] = None,
- stream: Optional[bool] = False,
- *args,
- **kwargs,
- ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: ...
-
- @overload
- def __call__(
- self,
- llm_api: Callable[[Any], Awaitable[Any]],
- prompt_params: Optional[Dict] = None,
- num_reasks: Optional[int] = None,
- prompt: Optional[str] = None,
- instructions: Optional[str] = None,
- msg_history: Optional[List[Dict]] = None,
- metadata: Optional[Dict] = None,
- full_schema_reask: Optional[bool] = None,
- *args,
- **kwargs,
- ) -> Awaitable[ValidationOutcome[OT]]: ...
-
- def __call__(
+ async def __call__(
self,
llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]],
prompt_params: Optional[Dict] = None,
@@ -116,7 +75,7 @@ def __call__(
The raw text output from the LLM and the validated output.
"""
- def __call(
+ async def __call(
self,
llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]],
prompt_params: Optional[Dict] = None,
@@ -208,7 +167,7 @@ def __call(
"Please use an async LLM API."
)
# Otherwise, call the LLM
- return self._call_async(
+ return await self._call_async(
llm_api,
prompt_params=prompt_params,
num_reasks=self.num_reasks,
@@ -223,7 +182,7 @@ def __call(
)
guard_context = contextvars.Context()
- return guard_context.run(
+ return await guard_context.run(
__call,
self,
llm_api,
@@ -238,7 +197,7 @@ def __call(
**kwargs,
)
- def _call_async(
+ async def _call_async(
self,
llm_api: Callable[[Any], Awaitable[Any]],
prompt_params: Dict,
@@ -295,12 +254,10 @@ def _call_async(
full_schema_reask=full_schema_reask,
disable_tracer=self._disable_tracer,
)
- call = asyncio.run(
- runner.async_run(call_log=call_log, prompt_params=prompt_params)
- )
+ call = await runner.async_run(call_log=call_log, prompt_params=prompt_params)
return ValidationOutcome[OT].from_guard_history(call)
- def parse(
+ async def parse(
self,
llm_output: str,
metadata: Optional[Dict] = None,
@@ -328,7 +285,7 @@ def parse(
determined by the object schema defined in the RAILspec.
"""
- def __parse(
+ async def __parse(
self,
llm_output: str,
metadata: Optional[Dict] = None,
@@ -416,19 +373,18 @@ def __parse(
or inspect.iscoroutinefunction(llm_api)
or inspect.isasyncgenfunction(llm_api)
):
- return asyncio.run(
- self._async_parse(
- llm_output,
- metadata,
- llm_api=llm_api,
- num_reasks=self.num_reasks,
- prompt_params=prompt_params,
- full_schema_reask=full_schema_reask,
- call_log=call_log,
- *args,
- **kwargs,
- )
+ return await self._async_parse(
+ llm_output,
+ metadata,
+ llm_api=llm_api,
+ num_reasks=self.num_reasks,
+ prompt_params=prompt_params,
+ full_schema_reask=full_schema_reask,
+ call_log=call_log,
+ *args,
+ **kwargs,
)
+
else:
raise NotImplementedError(
"AsyncGuard does not support non-async LLM APIs. "
@@ -437,7 +393,7 @@ def __parse(
)
guard_context = contextvars.Context()
- return guard_context.run(
+ return await guard_context.run(
__parse,
self,
llm_output,
diff --git a/tests/unit_tests/test_async_guard.py b/tests/unit_tests/test_async_guard.py
index 1e02f750c..0f4e33d6d 100644
--- a/tests/unit_tests/test_async_guard.py
+++ b/tests/unit_tests/test_async_guard.py
@@ -477,7 +477,8 @@ def test_use_many_tuple():
)
-def test_validate():
+@pytest.mark.asyncio
+async def test_validate():
guard: AsyncGuard = (
AsyncGuard()
.use(OneLine)
@@ -489,15 +490,13 @@ def test_validate():
)
llm_output: str = "Oh Canada" # bc it meets our criteria
-
- response = guard.validate(llm_output)
+ response = await guard.validate(llm_output)
assert response.validation_passed is True
assert response.validated_output == llm_output.lower()
-
llm_output_2 = "Star Spangled Banner" # to stick with the theme
- response_2 = guard.validate(llm_output_2)
+ response_2 = await guard.validate(llm_output_2)
assert response_2.validation_passed is False
assert response_2.validated_output is None
@@ -516,14 +515,14 @@ def test_validate():
llm_output: str = "Oh Canada" # bc it meets our criteria
- response = guard.validate(llm_output)
+ response = await guard.validate(llm_output)
assert response.validation_passed is True
assert response.validated_output == llm_output.lower()
llm_output_2 = "Star Spangled Banner" # to stick with the theme
- response_2 = guard.validate(llm_output_2)
+ response_2 = await guard.validate(llm_output_2)
assert response_2.validation_passed is False
assert response_2.validated_output is None
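With __call__, parse, and validate now returning coroutines, callers await the guard instead of relying on an internal asyncio.run. A minimal usage sketch, assuming a string-output AsyncGuard with its validators elided for brevity:

    import asyncio
    from guardrails import AsyncGuard

    async def main() -> None:
        guard = AsyncGuard()  # validators omitted for brevity
        outcome = await guard.validate("oh canada")  # mirrors the updated test_validate
        parsed = await guard.parse("oh canada")      # parse is awaited the same way
        print(outcome.validation_passed, parsed.validation_passed)

    asyncio.run(main())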
From 2d95bcb343d7c4a571486ebe3de82860f4461066 Mon Sep 17 00:00:00 2001
From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com>
Date: Thu, 16 May 2024 17:45:23 -0700
Subject: [PATCH 9/9] removing duplicated line
---
guardrails/guard.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/guardrails/guard.py b/guardrails/guard.py
index d5f25ba7e..8a5d3dbd8 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -1473,4 +1473,3 @@ def _call_server(
)
else:
raise ValueError("Guard does not have an api client!")
- raise ValueError("Guard does not have an api client!")