Skip to content

Define complex eval case through code in a pytest-class fashion #2287

@lionpeloux

Description

@lionpeloux

Description

@DouweM suggested to submit a proposal in this slack thread

Defining test cases is straightforward when both inputs and expected outputs are simple — often just a one-liner.

However, in many cases it falls short as we would rather define complex inputs and expected outputs directly in code (think "builder pattern").

Here is a MWE to draft a proposal.

The cases here are trivial. But how would you write the expected output of an eval case for an agent that goes from a prompt to a 3D drawing?! You would likely build the expected 3D model with your drawing API and match it against the result of your agent.

from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, TypeVar

from pydantic_evals import Case
from pydantic_evals.evaluators import Evaluator

# ------------------------------------------------------------------------------
# Singleton Pattern
# ------------------------------------------------------------------------------


class SingletonABCMeta(ABCMeta):
    """ABCMeta variant that memoizes one instance per concrete class.

    The first call to a class constructs its instance and stores it in the
    shared ``_instances`` cache; every subsequent call returns that same
    cached object.
    """

    # Shared cache mapping each class to its sole instance.
    _instances: dict[type, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        # EAFP: the common case after warm-up is a cache hit.
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance


# ------------------------------------------------------------------------------
# Pytest like Case
# ------------------------------------------------------------------------------

# Type variables for a case's inputs, expected output, and metadata payloads.
# NOTE(review): the class below uses PEP 695 type parameters with the same
# names, which shadow these module-level TypeVars inside the class body —
# confirm whether these three are still needed elsewhere.
InputsT = TypeVar("InputsT")
OutputT = TypeVar("OutputT")
MetadataT = TypeVar("MetadataT")
# Bound TypeVar so that GenericBaseCase.cases() returns the calling subclass type.
ClsT = TypeVar("ClsT", bound="GenericBaseCase[Any, Any, Any]")


@dataclass(init=False)
class GenericBaseCase[InputsT, OutputT, MetadataT](
    Case[InputsT, OutputT, MetadataT],
    ABC,
    metaclass=SingletonABCMeta,
):
    """Generic Abstract base class for evaluation cases, similar to pytest test classes.

    Uses a singleton pattern to ensure each concrete case class is instantiated only once.
    """

    def __init__(self):
        # Get the class name as default case name
        self.name: str | None = "-".join(self.__class__.__name__.split("_")).lower()

        # Populate inputs and expected_output by calling abstract methods
        self.inputs: InputsT = self._inputs()
        self.expected_output: OutputT | None = self._expected_output()
        self.metadata: MetadataT | None = self._metadata() if hasattr(self, "_metadata") else None
        self.evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = (
            self._evaluators() if hasattr(self, "_evaluators") else []
        )

    @abstractmethod
    def _inputs(self) -> InputsT:
        """Return the list of prompts/inputs for this case."""
        ...

    @abstractmethod
    def _expected_output(self) -> OutputT:
        """Return the expected output for this case."""
        ...

    def _metadata(self) -> MetadataT | None:
        return None

    def _evaluators(self) -> list[Evaluator[InputsT, OutputT, MetadataT]]:
        return []

    @classmethod
    def cases(cls: type[ClsT]) -> list[ClsT]:
        """Return the singleton instances for **all** concrete subclasses of ``GenericBaseCase``.

        This helper will instantiate each concrete subclass exactly once (thanks to
        the underlying ``SingletonABCMeta``) and then return the cached instances. It
        is the recommended way to collect all defined cases.
        """

        def find_subclasses(klass: type) -> set[type]:
            """Recursively find all subclasses of a given class."""
            subclasses: set[type] = set()
            queue = [klass]
            while queue:
                parent = queue.pop()
                for subclass in parent.__subclasses__():
                    if subclass not in subclasses:
                        subclasses.add(subclass)
                        queue.append(subclass)
            return subclasses

        # Ensure every concrete subclass has been instantiated at least once so that
        # the singleton cache is fully populated.
        for subclass in find_subclasses(cls):
            # Instantiation is idempotent due to the singleton metaclass. Abstract
            # classes cannot be instantiated, so we skip them.
            if not getattr(subclass, "__abstractmethods__", False):
                subclass()

        # The metaclass holds the cache of all singleton instances.
        instances_dict: dict[type, Any] = getattr(SingletonABCMeta, "_instances")
        return [instance for instance_cls, instance in instances_dict.items() if issubclass(instance_cls, cls)]
if __name__ == "__main__":
    from pydantic_ai import Agent
    from pydantic_evals import Dataset
    from pydantic_evals.evaluators import EvaluatorContext

    @dataclass
    class CaseMetadata:
        # Difficulty level of the question; 0 is the easiest.
        level: int = 0

    class CapitalEvalCase(GenericBaseCase[str, str | None, CaseMetadata], ABC):
        """Abstract base for capital-city eval cases; concrete cases subclass this."""

    class Case_01(CapitalEvalCase):
        def _inputs(self):
            return "What is the capital of France?"

        def _expected_output(self):
            return "Paris"

        def _metadata(self):
            return CaseMetadata()

    class Case_02(CapitalEvalCase):
        def _inputs(self):
            return "What is the capital of Liechtenstein?"

        def _expected_output(self):
            return "Vaduz"

        def _metadata(self):
            return CaseMetadata(level=1)

    class Case_03(CapitalEvalCase):
        # Trick question: Europe has no capital, so no output is expected.
        def _inputs(self):
            return "What is the capital of Europe?"

        def _expected_output(self):
            return None

        def _metadata(self):
            return CaseMetadata(level=3)

    # Collect all singleton case instances, sorted by derived name for stable output.
    cases = sorted(CapitalEvalCase.cases(), key=lambda x: x.name or "")
    for c in cases:
        print(f"{c.name}: inputs={c.inputs!r}, expected={c.expected_output!r}, metadata={c.metadata}")

    class LowerCaseEvaluator(Evaluator[str, str | None, CaseMetadata]):
        """Case-insensitive equality check; treats (None, None) as a match."""

        def evaluate(self, ctx: EvaluatorContext[str, str | None, CaseMetadata]) -> bool:
            if ctx.expected_output is None and ctx.output is None:
                return True
            if ctx.output is not None and ctx.expected_output is not None:
                if ctx.output.lower() == ctx.expected_output.lower():
                    return True
            return False

    dataset = Dataset(
        cases=cases,
        evaluators=[LowerCaseEvaluator()],
    )

    # Fix: corrected "answser" -> "answer" in the system prompt sent to the model.
    agent = Agent(
        "openai:gpt-4o",
        system_prompt="You are a helpful assistant that answers questions about the capital of countries. Return only the capital in lowercase or None if the answer is not a capital.",
    )

    async def guess_capital(question: str) -> str:
        """Task under evaluation: ask the agent and return its raw output."""
        result = await agent.run(question)
        return result.output

    report = dataset.evaluate_sync(guess_capital)
    report.print(include_input=True, include_output=True, include_durations=False)

"""
Evaluating guess_capital ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00
                     Evaluation Summary: guess_capital                     
┏━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Case ID  ┃ Inputs                                ┃ Outputs ┃ Assertions ┃
┡━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━┩
│ case-01  │ What is the capital of France?        │ paris   │ ✔          │
├──────────┼───────────────────────────────────────┼─────────┼────────────┤
│ case-02  │ What is the capital of Liechtenstein? │ vaduz   │ ✔          │
├──────────┼───────────────────────────────────────┼─────────┼────────────┤
│ case-03  │ What is the capital of Europe?        │ None    │ ✗          │
├──────────┼───────────────────────────────────────┼─────────┼────────────┤
│ Averages │                                       │         │ 66.7% ✔    │
└──────────┴───────────────────────────────────────┴─────────┴────────────┘
"""

References

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions