diff --git a/guardrails/classes/schema/processed_schema.py b/guardrails/classes/schema/processed_schema.py index 5e23544d0..74ec680ee 100644 --- a/guardrails/classes/schema/processed_schema.py +++ b/guardrails/classes/schema/processed_schema.py @@ -1,8 +1,8 @@ from dataclasses import dataclass, field from typing import Any, Dict, List -from guardrails_api_client import ValidatorReference from guardrails.classes.execution.guard_execution_options import GuardExecutionOptions from guardrails.classes.output_type import OutputTypes +from guardrails.classes.validation.validator_reference import ValidatorReference from guardrails.types.validator import ValidatorMap diff --git a/guardrails/classes/validation/validator_reference.py b/guardrails/classes/validation/validator_reference.py index 93890d0b0..7f48f75fb 100644 --- a/guardrails/classes/validation/validator_reference.py +++ b/guardrails/classes/validation/validator_reference.py @@ -1,5 +1,8 @@ +from typing import Any, Dict from guardrails_api_client import ValidatorReference as IValidatorReference +from guardrails.utils.serialization_utils import to_dict + # Docs only class ValidatorReference(IValidatorReference): @@ -17,3 +20,26 @@ class ValidatorReference(IValidatorReference): args (Optional[List[Any]]): Positional arguments. Default None. kwargs (Optional[Dict[str, Any]]): Keyword arguments. Default None. """ + + @classmethod + def from_interface(cls, interface: IValidatorReference) -> "ValidatorReference": + """Create a ValidatorReference from an interface.""" + return cls( + id=interface.id, + on=interface.on, + on_fail=interface.on_fail, # type: ignore + args=interface.args, + kwargs=interface.kwargs, + ) + + def to_dict(self) -> Dict[str, Any]: + ref_dict = super().to_dict() + + # serialize args and kwargs + if self.args: + ref_dict["args"] = [to_dict(a) for a in self.args] + + if self.kwargs: + ref_dict["kwargs"] = {k: to_dict(v) for k, v in self.kwargs.items()} + + return ref_dict diff --git a/guardrails/guard.py b/guardrails/guard.py index 8b20e76af..62895b5c7 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -22,7 +22,6 @@ from guardrails_api_client import ( Guard as IGuard, - ValidatorReference, ValidatePayload, SimpleTypes, ValidationOutcome as IValidationOutcome, @@ -36,6 +35,7 @@ from guardrails.classes.rc import RC from guardrails.classes.validation.validation_result import ErrorSpan from guardrails.classes.validation.validation_summary import ValidationSummary +from guardrails.classes.validation.validator_reference import ValidatorReference from guardrails.classes.validation_outcome import ValidationOutcome from guardrails.classes.execution import GuardExecutionOptions from guardrails.classes.generic import Stack @@ -166,7 +166,7 @@ def __init__( id=id, name=name, description=description, - validators=validators, + validators=[], output_schema=model_schema, history=history, # type: ignore - pyright doesn't understand pydantic overrides ) @@ -180,6 +180,9 @@ def __init__( # self.output_schema: Optional[ModelSchema] = None # self.history = history + ### Overrides ### + self.validators = validators + ### Legacy ## self._num_reasks = None self._rail: Optional[str] = None @@ -208,7 +211,10 @@ def __init__( if loaded_guard: self.id = loaded_guard.id self.description = loaded_guard.description - self.validators = loaded_guard.validators or [] + self.validators = [ # type: ignore + ValidatorReference.from_interface(v) + for v in loaded_guard.validators or [] + ] loaded_output_schema = ( ModelSchema.from_dict( # trims out extra keys @@ -1376,7 +1382,7 @@ def to_dict(self) -> Dict[str, Any]: id=self.id, name=self.name, description=self.description, - validators=self.validators, + validators=self.validators, # type: ignore output_schema=self.output_schema, history=[c.to_interface() for c in self.history], # type: ignore ) @@ -1416,7 +1422,10 @@ def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional["Guard"]: id=i_guard.id, name=i_guard.name, description=i_guard.description, - validators=i_guard.validators, + validators=[ + ValidatorReference.from_interface(i_val) + for i_val in i_guard.validators or [] + ], output_schema=output_schema, ) diff --git a/guardrails/schema/primitive_schema.py b/guardrails/schema/primitive_schema.py index fbb98cfa0..4a758459d 100644 --- a/guardrails/schema/primitive_schema.py +++ b/guardrails/schema/primitive_schema.py @@ -3,10 +3,10 @@ from guardrails_api_client.models.model_schema import ModelSchema from guardrails_api_client.models.simple_types import SimpleTypes from guardrails_api_client.models.validation_type import ValidationType -from guardrails_api_client.models.validator_reference import ValidatorReference from guardrails.classes.output_type import OutputTypes from guardrails.classes.schema.processed_schema import ProcessedSchema +from guardrails.classes.validation.validator_reference import ValidatorReference from guardrails.validator_base import Validator diff --git a/guardrails/schema/pydantic_schema.py b/guardrails/schema/pydantic_schema.py index a29349307..085c8b018 100644 --- a/guardrails/schema/pydantic_schema.py +++ b/guardrails/schema/pydantic_schema.py @@ -13,9 +13,9 @@ from pydantic import AliasChoices, AliasGenerator, AliasPath, BaseModel from pydantic.fields import FieldInfo -from guardrails_api_client import ValidatorReference from guardrails.classes.output_type import OutputTypes from guardrails.classes.schema.processed_schema import ProcessedSchema +from guardrails.classes.validation.validator_reference import ValidatorReference from guardrails.logger import logger from guardrails.types import ( ModelOrListOfModels, diff --git a/guardrails/schema/rail_schema.py b/guardrails/schema/rail_schema.py index 2071f6b01..b23eae128 100644 --- a/guardrails/schema/rail_schema.py +++ b/guardrails/schema/rail_schema.py @@ -7,10 +7,11 @@ from lxml import etree as ET from lxml.etree import _Element, Element, SubElement, XMLParser from xml.etree.ElementTree import canonicalize -from guardrails_api_client import ModelSchema, SimpleTypes, ValidatorReference +from guardrails_api_client import ModelSchema, SimpleTypes from guardrails.classes.execution.guard_execution_options import GuardExecutionOptions from guardrails.classes.output_type import OutputTypes from guardrails.classes.schema.processed_schema import ProcessedSchema +from guardrails.classes.validation.validator_reference import ValidatorReference from guardrails.logger import logger from guardrails.types import RailTypes from guardrails.types.validator import ValidatorMap diff --git a/guardrails/utils/serialization_utils.py b/guardrails/utils/serialization_utils.py index d124a9069..7e2e94cda 100644 --- a/guardrails/utils/serialization_utils.py +++ b/guardrails/utils/serialization_utils.py @@ -2,10 +2,29 @@ import json from typing import Any, Optional import warnings +from dataclasses import asdict, is_dataclass +from pydantic import BaseModel from guardrails.classes.generic.default_json_encoder import DefaultJSONEncoder +# This is the same logic as the DefaultJSONEncoder but without stringifying everything +def to_dict(o): + if hasattr(o, "to_dict"): + return o.to_dict() + elif isinstance(o, BaseModel): + return o.model_dump() + elif is_dataclass(o): + return asdict(o) + elif isinstance(o, set): + return list(o) + elif isinstance(o, datetime): + return o.isoformat() + elif hasattr(o, "__dict__"): + return o.__dict__ + return o + + # TODO: What other common cases we should consider? def serialize(val: Any) -> Optional[str]: try: diff --git a/tests/integration_tests/schema/test_primitive_schema.py b/tests/integration_tests/schema/test_primitive_schema.py index 566c6dcb0..30f0d77b1 100644 --- a/tests/integration_tests/schema/test_primitive_schema.py +++ b/tests/integration_tests/schema/test_primitive_schema.py @@ -1,6 +1,6 @@ import json -from guardrails_api_client.models.validator_reference import ValidatorReference +from guardrails.classes.validation.validator_reference import ValidatorReference from guardrails.classes.schema.processed_schema import ProcessedSchema from guardrails.schema.primitive_schema import primitive_to_schema from guardrails.classes.output_type import OutputTypes diff --git a/tests/integration_tests/schema/test_pydantic_schema.py b/tests/integration_tests/schema/test_pydantic_schema.py index 12913425c..fecc5d43d 100644 --- a/tests/integration_tests/schema/test_pydantic_schema.py +++ b/tests/integration_tests/schema/test_pydantic_schema.py @@ -1,6 +1,6 @@ import json -from guardrails_api_client.models.validator_reference import ValidatorReference +from guardrails.classes.validation.validator_reference import ValidatorReference from guardrails.classes.schema.processed_schema import ProcessedSchema from guardrails.schema.pydantic_schema import pydantic_model_to_schema from guardrails.classes.output_type import OutputTypes diff --git a/tests/integration_tests/schema/test_rail_schema.py b/tests/integration_tests/schema/test_rail_schema.py index 4fb4ebfeb..0fd7b28f9 100644 --- a/tests/integration_tests/schema/test_rail_schema.py +++ b/tests/integration_tests/schema/test_rail_schema.py @@ -3,7 +3,7 @@ from xml.etree.ElementTree import canonicalize -from guardrails_api_client.models.validator_reference import ValidatorReference +from guardrails.classes.validation.validator_reference import ValidatorReference from guardrails.classes.schema.processed_schema import ProcessedSchema from guardrails.schema.rail_schema import ( rail_file_to_schema, diff --git a/tests/integration_tests/test_guard.py b/tests/integration_tests/test_guard.py index 86a258cdb..e4461c3ad 100644 --- a/tests/integration_tests/test_guard.py +++ b/tests/integration_tests/test_guard.py @@ -7,7 +7,7 @@ import pytest from pydantic import BaseModel, Field -from guardrails_api_client import Guard as IGuard, ValidatorReference +from guardrails_api_client import Guard as IGuard import guardrails as gd from guardrails.actions.reask import SkeletonReAsk @@ -15,6 +15,7 @@ from guardrails.classes.llm.llm_response import LLMResponse from guardrails.classes.validation_outcome import ValidationOutcome from guardrails.classes.validation.validation_result import FailResult +from guardrails.classes.validation.validator_reference import ValidatorReference from guardrails.guard import Guard from guardrails.actions.reask import FieldReAsk from tests.integration_tests.test_assets.validators import (