Skip to content

Force validator arg serialization #1155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion guardrails/classes/schema/processed_schema.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List
from guardrails_api_client import ValidatorReference
from guardrails.classes.execution.guard_execution_options import GuardExecutionOptions
from guardrails.classes.output_type import OutputTypes
from guardrails.classes.validation.validator_reference import ValidatorReference
from guardrails.types.validator import ValidatorMap


Expand Down
26 changes: 26 additions & 0 deletions guardrails/classes/validation/validator_reference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Dict
from guardrails_api_client import ValidatorReference as IValidatorReference

from guardrails.utils.serialization_utils import to_dict


# Docs only
class ValidatorReference(IValidatorReference):
Expand All @@ -17,3 +20,26 @@ class ValidatorReference(IValidatorReference):
args (Optional[List[Any]]): Positional arguments. Default None.
kwargs (Optional[Dict[str, Any]]): Keyword arguments. Default None.
"""

@classmethod
def from_interface(cls, interface: IValidatorReference) -> "ValidatorReference":
"""Create a ValidatorReference from an interface."""
return cls(
id=interface.id,
on=interface.on,
on_fail=interface.on_fail, # type: ignore
args=interface.args,
kwargs=interface.kwargs,
)

def to_dict(self) -> Dict[str, Any]:
ref_dict = super().to_dict()

# serialize args and kwargs
if self.args:
ref_dict["args"] = [to_dict(a) for a in self.args]

if self.kwargs:
ref_dict["kwargs"] = {k: to_dict(v) for k, v in self.kwargs.items()}

return ref_dict
19 changes: 14 additions & 5 deletions guardrails/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from guardrails_api_client import (
Guard as IGuard,
ValidatorReference,
ValidatePayload,
SimpleTypes,
ValidationOutcome as IValidationOutcome,
Expand All @@ -36,6 +35,7 @@
from guardrails.classes.rc import RC
from guardrails.classes.validation.validation_result import ErrorSpan
from guardrails.classes.validation.validation_summary import ValidationSummary
from guardrails.classes.validation.validator_reference import ValidatorReference
from guardrails.classes.validation_outcome import ValidationOutcome
from guardrails.classes.execution import GuardExecutionOptions
from guardrails.classes.generic import Stack
Expand Down Expand Up @@ -166,7 +166,7 @@ def __init__(
id=id,
name=name,
description=description,
validators=validators,
validators=[],
output_schema=model_schema,
history=history, # type: ignore - pyright doesn't understand pydantic overrides
)
Expand All @@ -180,6 +180,9 @@ def __init__(
# self.output_schema: Optional[ModelSchema] = None
# self.history = history

### Overrides ###
self.validators = validators

### Legacy ##
self._num_reasks = None
self._rail: Optional[str] = None
Expand Down Expand Up @@ -208,7 +211,10 @@ def __init__(
if loaded_guard:
self.id = loaded_guard.id
self.description = loaded_guard.description
self.validators = loaded_guard.validators or []
self.validators = [ # type: ignore
ValidatorReference.from_interface(v)
for v in loaded_guard.validators or []
]

loaded_output_schema = (
ModelSchema.from_dict( # trims out extra keys
Expand Down Expand Up @@ -1376,7 +1382,7 @@ def to_dict(self) -> Dict[str, Any]:
id=self.id,
name=self.name,
description=self.description,
validators=self.validators,
validators=self.validators, # type: ignore
output_schema=self.output_schema,
history=[c.to_interface() for c in self.history], # type: ignore
)
Expand Down Expand Up @@ -1416,7 +1422,10 @@ def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional["Guard"]:
id=i_guard.id,
name=i_guard.name,
description=i_guard.description,
validators=i_guard.validators,
validators=[
ValidatorReference.from_interface(i_val)
for i_val in i_guard.validators or []
],
output_schema=output_schema,
)

Expand Down
2 changes: 1 addition & 1 deletion guardrails/schema/primitive_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from guardrails_api_client.models.model_schema import ModelSchema
from guardrails_api_client.models.simple_types import SimpleTypes
from guardrails_api_client.models.validation_type import ValidationType
from guardrails_api_client.models.validator_reference import ValidatorReference

from guardrails.classes.output_type import OutputTypes
from guardrails.classes.schema.processed_schema import ProcessedSchema
from guardrails.classes.validation.validator_reference import ValidatorReference
from guardrails.validator_base import Validator


Expand Down
2 changes: 1 addition & 1 deletion guardrails/schema/pydantic_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

from pydantic import AliasChoices, AliasGenerator, AliasPath, BaseModel
from pydantic.fields import FieldInfo
from guardrails_api_client import ValidatorReference
from guardrails.classes.output_type import OutputTypes
from guardrails.classes.schema.processed_schema import ProcessedSchema
from guardrails.classes.validation.validator_reference import ValidatorReference
from guardrails.logger import logger
from guardrails.types import (
ModelOrListOfModels,
Expand Down
3 changes: 2 additions & 1 deletion guardrails/schema/rail_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from lxml import etree as ET
from lxml.etree import _Element, Element, SubElement, XMLParser
from xml.etree.ElementTree import canonicalize
from guardrails_api_client import ModelSchema, SimpleTypes, ValidatorReference
from guardrails_api_client import ModelSchema, SimpleTypes
from guardrails.classes.execution.guard_execution_options import GuardExecutionOptions
from guardrails.classes.output_type import OutputTypes
from guardrails.classes.schema.processed_schema import ProcessedSchema
from guardrails.classes.validation.validator_reference import ValidatorReference
from guardrails.logger import logger
from guardrails.types import RailTypes
from guardrails.types.validator import ValidatorMap
Expand Down
19 changes: 19 additions & 0 deletions guardrails/utils/serialization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,29 @@
import json
from typing import Any, Optional
import warnings
from dataclasses import asdict, is_dataclass
from pydantic import BaseModel

from guardrails.classes.generic.default_json_encoder import DefaultJSONEncoder


# This is the same logic as the DefaultJSONEncoder but without stringifying everything
def to_dict(o):
if hasattr(o, "to_dict"):
return o.to_dict()
elif isinstance(o, BaseModel):
return o.model_dump()
elif is_dataclass(o):
return asdict(o)
elif isinstance(o, set):
return list(o)
elif isinstance(o, datetime):
return o.isoformat()
elif hasattr(o, "__dict__"):
return o.__dict__
return o


# TODO: What other common cases we should consider?
def serialize(val: Any) -> Optional[str]:
try:
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/schema/test_primitive_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json

from guardrails_api_client.models.validator_reference import ValidatorReference
from guardrails.classes.validation.validator_reference import ValidatorReference
from guardrails.classes.schema.processed_schema import ProcessedSchema
from guardrails.schema.primitive_schema import primitive_to_schema
from guardrails.classes.output_type import OutputTypes
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/schema/test_pydantic_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json

from guardrails_api_client.models.validator_reference import ValidatorReference
from guardrails.classes.validation.validator_reference import ValidatorReference
from guardrails.classes.schema.processed_schema import ProcessedSchema
from guardrails.schema.pydantic_schema import pydantic_model_to_schema
from guardrails.classes.output_type import OutputTypes
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/schema/test_rail_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from xml.etree.ElementTree import canonicalize

from guardrails_api_client.models.validator_reference import ValidatorReference
from guardrails.classes.validation.validator_reference import ValidatorReference
from guardrails.classes.schema.processed_schema import ProcessedSchema
from guardrails.schema.rail_schema import (
rail_file_to_schema,
Expand Down
3 changes: 2 additions & 1 deletion tests/integration_tests/test_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

import pytest
from pydantic import BaseModel, Field
from guardrails_api_client import Guard as IGuard, ValidatorReference
from guardrails_api_client import Guard as IGuard

import guardrails as gd
from guardrails.actions.reask import SkeletonReAsk
from guardrails.classes.generic.stack import Stack
from guardrails.classes.llm.llm_response import LLMResponse
from guardrails.classes.validation_outcome import ValidationOutcome
from guardrails.classes.validation.validation_result import FailResult
from guardrails.classes.validation.validator_reference import ValidatorReference
from guardrails.guard import Guard
from guardrails.actions.reask import FieldReAsk
from tests.integration_tests.test_assets.validators import (
Expand Down
Loading