Skip to content

Start Command #828

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 0 additions & 27 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,40 +1,13 @@
MKDOCS_SERVE_ADDR ?= localhost:8000 # Default address for mkdocs serve, format: <host>:<port>, override with `make docs-serve MKDOCS_SERVE_ADDR=<host>:<port>`

# Extract major package versions for OpenAI and Pydantic
OPENAI_VERSION_MAJOR := $(shell poetry run python -c 'import openai; print(openai.__version__.split(".")[0])')
PYDANTIC_VERSION_MAJOR := $(shell poetry run python -c 'import pydantic; print(pydantic.__version__.split(".")[0])')

# Construct the typing command using only major versions
TYPING_CMD := type-pydantic-v$(PYDANTIC_VERSION_MAJOR)-openai-v$(OPENAI_VERSION_MAJOR)

autoformat:
poetry run ruff check guardrails/ tests/ --fix
poetry run ruff format guardrails/ tests/
poetry run docformatter --in-place --recursive guardrails tests

.PHONY: type
type:
@make $(TYPING_CMD)

type-pydantic-v1-openai-v0:
echo '{"reportDeprecated": true, "exclude": ["guardrails/utils/pydantic_utils/v2.py", "guardrails/utils/openai_utils/v1.py"]}' > pyrightconfig.json
poetry run pyright guardrails/
rm pyrightconfig.json

type-pydantic-v1-openai-v1:
echo '{"reportDeprecated": true, "exclude": ["guardrails/utils/pydantic_utils/v2.py", "guardrails/utils/openai_utils/v0.py"]}' > pyrightconfig.json
poetry run pyright guardrails/
rm pyrightconfig.json

type-pydantic-v2-openai-v0:
echo '{"reportDeprecated": true, "exclude": ["guardrails/utils/pydantic_utils/v1.py", "guardrails/utils/openai_utils/v1.py"]}' > pyrightconfig.json
poetry run pyright guardrails/
rm pyrightconfig.json

type-pydantic-v2-openai-v1:
echo '{"reportDeprecated": true, "exclude": ["guardrails/utils/pydantic_utils/v1.py", "guardrails/utils/openai_utils/v0.py"]}' > pyrightconfig.json
poetry run pyright guardrails/
rm pyrightconfig.json

lint:
poetry run ruff check guardrails/ tests/
Expand Down
40 changes: 39 additions & 1 deletion guardrails/actions/reask.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,44 @@ class ReAsk(IReask):
incorrect_value: Any
fail_results: List[FailResult]

@classmethod
def from_interface(cls, reask: IReask) -> "ReAsk":
    """Build the most specific ReAsk subclass for an interface object.

    Chooses FieldReAsk when a reask path is present, NonParseableReAsk or
    SkeletonReAsk based on the single failure's error message, and falls
    back to a plain ReAsk otherwise.
    """
    converted: List[FailResult] = []
    if reask.fail_results:
        for interface_result in reask.fail_results:
            converted.append(FailResult.from_interface(interface_result))

    reask_path = reask.additional_properties.get("path")
    if reask_path:
        return FieldReAsk(
            incorrect_value=reask.incorrect_value,
            fail_results=converted,
            path=reask_path,
        )

    if len(converted) == 1:
        message = converted[0].error_message
        if message == "Output is not parseable as JSON":
            return NonParseableReAsk(
                incorrect_value=reask.incorrect_value,
                fail_results=converted,
            )
        if "JSON does not match schema" in message:
            return SkeletonReAsk(
                incorrect_value=reask.incorrect_value,
                fail_results=converted,
            )

    return cls(incorrect_value=reask.incorrect_value, fail_results=converted)

@classmethod
def from_dict(cls, obj: Dict[str, Any]) -> Optional["ReAsk"]:
    """Deserialize a dict into the appropriate ReAsk subclass, or None
    when the interface-level parse yields nothing."""
    interface_reask = super().from_dict(obj)
    return cls.from_interface(interface_reask) if interface_reask else None


class FieldReAsk(ReAsk):
# FIXME: This shouldn't be optional
Expand Down Expand Up @@ -363,7 +401,7 @@ def get_reask_setup_for_json(
def reask_decoder(obj: ReAsk):
decoded = {}
for k, v in obj.__dict__.items():
if k in ["path"]:
if k in ["path", "additional_properties"]:
continue
if k == "fail_results":
k = "error_messages"
Expand Down
12 changes: 12 additions & 0 deletions guardrails/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from guardrails_api_client.api.validate_api import ValidateApi
from guardrails_api_client.models import Guard, ValidatePayload

from guardrails.logger import logger


class GuardrailsApiClient:
_api_client: ApiClient
Expand Down Expand Up @@ -39,6 +41,13 @@ def upsert_guard(self, guard: Guard):
guard_name=guard.name, body=guard, _request_timeout=self.timeout
)

def fetch_guard(self, guard_name: str) -> Optional[Guard]:
    """Look up a guard on the server by name.

    Best-effort: any failure is logged at debug level and reported as
    None rather than raised.
    """
    result: Optional[Guard] = None
    try:
        result = self._guard_api.get_guard(guard_name=guard_name)
    except Exception as fetch_error:
        logger.debug(f"Error fetching guard {guard_name}: {fetch_error}")
    return result

def validate(
self,
guard: Guard,
Expand Down Expand Up @@ -86,3 +95,6 @@ def stream_validate(
if line:
json_output = json.loads(line)
yield json_output

def get_history(self, guard_name: str, call_id: str):
    """Fetch the recorded history of one call for the named guard."""
    history_api = self._guard_api
    return history_api.get_guard_history(guard_name, call_id)
50 changes: 24 additions & 26 deletions guardrails/async_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
cast,
)

from guardrails_api_client.models import ValidatePayload
from guardrails_api_client.models import (
ValidatePayload,
ValidationOutcome as IValidationOutcome,
)

from guardrails import Guard
from guardrails.classes import OT, ValidationOutcome
Expand Down Expand Up @@ -273,9 +276,6 @@ async def __exec(
args=list(args),
kwargs=kwargs,
)
call_log = Call(inputs=call_inputs)
set_scope(str(object_id(call_log)))
self._history.push(call_log)

if self._api_client is not None and model_is_supported_server_side(
llm_api, *args, **kwargs
Expand All @@ -287,13 +287,15 @@ async def __exec(
prompt_params=prompt_params,
metadata=metadata,
full_schema_reask=full_schema_reask,
call_log=call_log,
*args,
**kwargs,
)

# If the LLM API is async, return a coroutine
else:
call_log = Call(inputs=call_inputs)
set_scope(str(object_id(call_log)))
self.history.push(call_log)
result = await self._exec(
llm_api=llm_api,
llm_output=llm_output,
Expand Down Expand Up @@ -538,20 +540,12 @@ async def parse(
)

async def _stream_server_call(
self,
*,
payload: Dict[str, Any],
llm_output: Optional[str] = None,
num_reasks: Optional[int] = None,
prompt_params: Optional[Dict] = None,
metadata: Optional[Dict] = {},
full_schema_reask: Optional[bool] = True,
call_log: Optional[Call],
self, *, payload: Dict[str, Any]
) -> AsyncIterable[ValidationOutcome[OT]]:
# TODO: Once server side supports async streaming, this function will need to
# yield async generators, not generators
if self._api_client:
validation_output: Optional[Any] = None
validation_output: Optional[IValidationOutcome] = None
response = self._api_client.stream_validate(
guard=self, # type: ignore
payload=ValidatePayload.from_dict(payload), # type: ignore
Expand All @@ -561,26 +555,30 @@ async def _stream_server_call(
validation_output = fragment
if validation_output is None:
yield ValidationOutcome[OT](
call_id="0", # type: ignore
raw_llm_output=None,
validated_output=None,
validation_passed=False,
error="The response from the server was empty!",
)
else:
validated_output = (
cast(OT, validation_output.validated_output.actual_instance)
if validation_output.validated_output
else None
)
yield ValidationOutcome[OT](
raw_llm_output=validation_output.raw_llm_response, # type: ignore
validated_output=cast(OT, validation_output.validated_output),
validation_passed=validation_output.result,
call_id=validation_output.call_id, # type: ignore
raw_llm_output=validation_output.raw_llm_output, # type: ignore
validated_output=validated_output,
validation_passed=(validation_output.validation_passed is True),
)
if validation_output:
self._construct_history_from_server_response(
validation_output=validation_output,
llm_output=llm_output,
num_reasks=num_reasks,
prompt_params=prompt_params,
metadata=metadata,
full_schema_reask=full_schema_reask,
call_log=call_log,
guard_history = self._api_client.get_history(
self.name, validation_output.call_id
)
self.history.extend(
[Call.from_interface(call) for call in guard_history]
)
else:
raise ValueError("AsyncGuard does not have an api client!")
Expand Down
74 changes: 41 additions & 33 deletions guardrails/classes/history/call.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, Dict, List, Optional, Union

from pydantic import Field, PrivateAttr
from builtins import id as object_id
from pydantic import Field
from rich.panel import Panel
from rich.pretty import pretty_repr
from rich.tree import Tree

from guardrails_api_client import Call as ICall, CallException
from guardrails_api_client import Call as ICall
from guardrails.actions.filter import Filter
from guardrails.actions.refrain import Refrain
from guardrails.actions.reask import merge_reask_output
Expand Down Expand Up @@ -36,7 +36,10 @@ class Call(ICall, ArbitraryModel):
inputs: CallInputs = Field(
description="The inputs as passed in to Guard.__call__ or Guard.parse"
)
_exception: Optional[Exception] = PrivateAttr()
exception: Optional[Exception] = Field(
description="The exception that interrupted the run.",
default=None,
)

# Prevent Pydantic from changing our types
# Without this, Pydantic casts iterations to a list
Expand All @@ -46,16 +49,13 @@ def __init__(
inputs: Optional[CallInputs] = None,
exception: Optional[Exception] = None,
):
call_id = str(object_id(self))
iterations = iterations or Stack()
inputs = inputs or CallInputs()
super().__init__(
iterations=iterations, # type: ignore
inputs=inputs, # type: ignore
i_exception=CallException(message=str(exception)), # type: ignore
)
super().__init__(id=call_id, iterations=iterations, inputs=inputs) # type: ignore - pyright doesn't understand pydantic overrides
self.iterations = iterations
self.inputs = inputs
self._exception = exception
self.exception = exception

@property
def prompt(self) -> Optional[str]:
Expand Down Expand Up @@ -312,25 +312,12 @@ def validator_logs(self) -> Stack[ValidatorLogs]:
def error(self) -> Optional[str]:
"""The error message from any exception that raised and interrupted the
run."""
if self._exception:
return str(self._exception)
if self.exception:
return str(self.exception)
elif self.iterations.empty():
return None
return self.iterations.last.error # type: ignore

@property
def exception(self) -> Optional[Exception]:
"""The exception that interrupted the run."""
if self._exception:
return self._exception
elif self.iterations.empty():
return None
return self.iterations.last.exception # type: ignore

def _set_exception(self, exception: Optional[Exception]):
self._exception = exception
self.i_exception = CallException(message=str(exception))

@property
def failed_validations(self) -> Stack[ValidatorLogs]:
"""The validator logs for any validations that failed during the
Expand Down Expand Up @@ -408,14 +395,35 @@ def tree(self) -> Tree:
def __str__(self) -> str:
return pretty_repr(self)

def to_dict(self) -> Dict[str, Any]:
i_call = ICall(
iterations=list(self.iterations),
inputs=self.inputs,
def to_interface(self) -> ICall:
return ICall(
id=self.id,
iterations=[i.to_interface() for i in self.iterations],
inputs=self.inputs.to_interface(),
exception=self.error,
)

i_call_dict = i_call.to_dict()
def to_dict(self) -> Dict[str, Any]:
return self.to_interface().to_dict()

if self._exception:
i_call_dict["exception"] = str(self._exception)
return i_call_dict
@classmethod
def from_interface(cls, i_call: ICall) -> "Call":
iterations = Stack(
*[Iteration.from_interface(i) for i in (i_call.iterations or [])]
)
inputs = (
CallInputs.from_interface(i_call.inputs) if i_call.inputs else CallInputs()
)
exception = Exception(i_call.exception) if i_call.exception else None
call_inst = cls(iterations=iterations, inputs=inputs, exception=exception)
call_inst.id = i_call.id
return call_inst

# TODO: Necessary to GET /guards/{guard_name}/history/{call_id}
@classmethod
def from_dict(cls, obj: Dict[str, Any]) -> "Call":
i_call = ICall.from_dict(obj)

if i_call:
return cls.from_interface(i_call)
return Call()
42 changes: 42 additions & 0 deletions guardrails/classes/history/call_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,45 @@ class CallInputs(Inputs, ICallInputs, ArbitraryModel):
description="Additional keyword-arguments for the LLM as provided by the user.",
default_factory=dict,
)

def to_interface(self) -> ICallInputs:
    """Serialize to the interface model, redacting credential-like kwargs.

    Any kwarg whose name contains "key" or "token" is masked so that at
    most its last four characters remain visible.

    Returns:
        ICallInputs built from the serialized inputs plus args/kwargs.
    """
    inputs = super().to_interface().to_dict() or {}
    inputs["args"] = self.args
    # TODO: Better way to prevent creds from being logged,
    # if they're passed in as kwargs to the LLM
    redacted_kwargs = {}
    for k, v in self.kwargs.items():
        if "key" in k.lower() or "token" in k.lower():
            # Coerce to str: non-string secrets would crash on len()/slicing.
            value = v if isinstance(v, str) else str(v)
            if len(value) > 4:
                redaction_length = len(value) - 4
                stars = "*" * redaction_length
                redacted_kwargs[k] = f"{stars}{value[-4:]}"
            else:
                # Values of length <= 4 would otherwise be fully revealed
                # (zero stars plus the whole value in the -4 suffix);
                # mask them entirely instead.
                redacted_kwargs[k] = "*" * len(value)
        else:
            redacted_kwargs[k] = v
    inputs["kwargs"] = redacted_kwargs
    return ICallInputs(**inputs)

def to_dict(self) -> Dict[str, Any]:
    """Dictionary form of the serialized (redacted) interface model."""
    interface_form = self.to_interface()
    return interface_form.to_dict()

@classmethod
def from_interface(cls, i_call_inputs: ICallInputs) -> "CallInputs":
    """Construct CallInputs from its interface counterpart.

    The llm_api callable is not serializable, so it is always restored
    as None.
    """
    field_values = {
        "llm_api": None,
        "llm_output": i_call_inputs.llm_output,
        "instructions": i_call_inputs.instructions,
        "prompt": i_call_inputs.prompt,
        "msg_history": i_call_inputs.msg_history,
        "prompt_params": i_call_inputs.prompt_params,
        "num_reasks": i_call_inputs.num_reasks,
        "metadata": i_call_inputs.metadata,
        # Flags are normalized to strict booleans (None -> False).
        "full_schema_reask": i_call_inputs.full_schema_reask is True,
        "stream": i_call_inputs.stream is True,
        "args": i_call_inputs.args or [],
        "kwargs": i_call_inputs.kwargs or {},
    }
    return cls(**field_values)

@classmethod
def from_dict(cls, obj: Dict[str, Any]):
    """Deserialize a plain dict into CallInputs via the interface model,
    falling back to an empty ICallInputs when parsing yields nothing."""
    candidate = ICallInputs.from_dict(obj)
    if not candidate:
        candidate = ICallInputs()
    return cls.from_interface(candidate)
Loading
Loading