
Commit 5f4661f

BREAKING CHANGE: Make EvaluationReport and ReportCase into generic dataclasses (#1799)
1 parent 98650cd commit 5f4661f
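
What this breaking change means for downstream code: `EvaluationReport` and `ReportCase` are no longer Pydantic `BaseModel` subclasses but plain dataclasses, generic over `InputsT`, `OutputT`, and `MetadataT`. `BaseModel` methods such as `model_dump()` therefore disappear from report objects; serialization now goes through the new `ReportCaseAdapter` and `EvaluationReportAdapter` exported from `pydantic_evals.reporting`. A minimal migration sketch; `report` here stands in for any report produced by `Dataset.evaluate_sync` and is not defined in this commit:

```python
from pydantic_evals.reporting import ReportCaseAdapter

# `report` is assumed to be an existing EvaluationReport, e.g. the result of
# `my_dataset.evaluate_sync(my_task)`; it is illustrative, not part of this diff.

# Before this commit (ReportCase was a pydantic BaseModel):
#     case_dict = report.cases[0].model_dump()
# After this commit (ReportCase is a dataclass), serialize via the new TypeAdapter:
case_dict = ReportCaseAdapter.dump_python(report.cases[0])
```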

File tree

4 files changed: +41 -25 lines

- pydantic_evals/pydantic_evals/dataset.py
- pydantic_evals/pydantic_evals/reporting/__init__.py
- tests/evals/test_dataset.py
- tests/evals/test_reports.py

pydantic_evals/pydantic_evals/dataset.py

Lines changed: 4 additions & 4 deletions
@@ -257,7 +257,7 @@ async def evaluate(
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
-    ) -> EvaluationReport:
+    ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.

         This method runs the task on each case in the dataset, applies evaluators,
@@ -312,7 +312,7 @@ def evaluate_sync(
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
-    ) -> EvaluationReport:
+    ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.

         This is a synchronous wrapper around [`evaluate`][pydantic_evals.Dataset.evaluate] provided for convenience.
@@ -877,7 +877,7 @@ async def _run_task_and_evaluators(
     case: Case[InputsT, OutputT, MetadataT],
     report_case_name: str,
     dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
-) -> ReportCase:
+) -> ReportCase[InputsT, OutputT, MetadataT]:
     """Run a task on a case and evaluate the results.

     Args:
@@ -927,7 +927,7 @@ async def _run_task_and_evaluators(
     span_id = f'{context.span_id:016x}'
     fallback_duration = time.time() - t0

-    return ReportCase(
+    return ReportCase[InputsT, OutputT, MetadataT](
         name=report_case_name,
         inputs=case.inputs,
         metadata=case.metadata,
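
The practical effect of the new return annotations above: reports produced by `evaluate` and `evaluate_sync` now carry the dataset's type parameters, so case fields are typed instead of `Any`. A minimal sketch, assuming the public `Dataset`/`Case` API from `pydantic_evals`; the dataset and task below are illustrative, not from this commit:

```python
from pydantic_evals import Case, Dataset

# Illustrative dataset parameterized as Dataset[str, str, dict]; not part of this diff.
dataset = Dataset[str, str, dict](cases=[Case(inputs='hello', expected_output='HELLO')])


async def upper_task(inputs: str) -> str:
    return inputs.upper()


report = dataset.evaluate_sync(upper_task)
# With this change, `report` is an EvaluationReport[str, str, dict], so a type checker
# sees report.cases[0].inputs and report.cases[0].output as `str` rather than `Any`.
```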

pydantic_evals/pydantic_evals/reporting/__init__.py

Lines changed: 30 additions & 16 deletions
@@ -2,14 +2,14 @@

 from collections import defaultdict
 from collections.abc import Mapping
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from io import StringIO
-from typing import Any, Callable, Literal, Protocol, TypeVar
+from typing import Any, Callable, Generic, Literal, Protocol

-from pydantic import BaseModel
+from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
 from rich.table import Table
-from typing_extensions import TypedDict
+from typing_extensions import TypedDict, TypeVar

 from pydantic_evals._utils import UNSET, Unset

@@ -24,7 +24,9 @@

 __all__ = (
     'EvaluationReport',
+    'EvaluationReportAdapter',
     'ReportCase',
+    'ReportCaseAdapter',
     'EvaluationRenderer',
     'RenderValueConfig',
     'RenderNumberConfig',
@@ -35,27 +37,32 @@
 EMPTY_CELL_STR = '-'
 EMPTY_AGGREGATE_CELL_STR = ''

+InputsT = TypeVar('InputsT', default=Any)
+OutputT = TypeVar('OutputT', default=Any)
+MetadataT = TypeVar('MetadataT', default=Any)

-class ReportCase(BaseModel):
+
+@dataclass
+class ReportCase(Generic[InputsT, OutputT, MetadataT]):
     """A single case in an evaluation report."""

     name: str
     """The name of the [case][pydantic_evals.Case]."""
-    inputs: Any
+    inputs: InputsT
     """The inputs to the task, from [`Case.inputs`][pydantic_evals.Case.inputs]."""
-    metadata: Any
+    metadata: MetadataT | None
     """Any metadata associated with the case, from [`Case.metadata`][pydantic_evals.Case.metadata]."""
-    expected_output: Any
+    expected_output: OutputT | None
     """The expected output of the task, from [`Case.expected_output`][pydantic_evals.Case.expected_output]."""
-    output: Any
+    output: OutputT
     """The output of the task execution."""

     metrics: dict[str, float | int]
     attributes: dict[str, Any]

-    scores: dict[str, EvaluationResult[int | float]] = field(init=False)
-    labels: dict[str, EvaluationResult[str]] = field(init=False)
-    assertions: dict[str, EvaluationResult[bool]] = field(init=False)
+    scores: dict[str, EvaluationResult[int | float]]
+    labels: dict[str, EvaluationResult[str]]
+    assertions: dict[str, EvaluationResult[bool]]

     task_duration: float
     total_duration: float # includes evaluator execution time
@@ -65,6 +72,9 @@ class ReportCase(BaseModel):
     span_id: str


+ReportCaseAdapter = TypeAdapter(ReportCase[Any, Any, Any])
+
+
 class ReportCaseAggregate(BaseModel):
     """A synthetic case that summarizes a set of cases."""

@@ -142,12 +152,13 @@ def _labels_averages(labels_by_name: list[dict[str, str]]) -> dict[str, dict[str
     )


-class EvaluationReport(BaseModel):
+@dataclass
+class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
     """A report of the results of evaluating a model on a set of cases."""

     name: str
     """The name of the report."""
-    cases: list[ReportCase]
+    cases: list[ReportCase[InputsT, OutputT, MetadataT]]
     """The cases in the report."""

     def averages(self) -> ReportCaseAggregate:
@@ -156,7 +167,7 @@ def averages(self) -> ReportCaseAggregate:
     def print(
         self,
         width: int | None = None,
-        baseline: EvaluationReport | None = None,
+        baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None,
         include_input: bool = False,
         include_metadata: bool = False,
         include_expected_output: bool = False,
@@ -199,7 +210,7 @@ def print(

     def console_table(
         self,
-        baseline: EvaluationReport | None = None,
+        baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None,
         include_input: bool = False,
         include_metadata: bool = False,
         include_expected_output: bool = False,
@@ -250,6 +261,9 @@ def __str__(self) -> str: # pragma: lax no cover
         return io_file.getvalue()


+EvaluationReportAdapter = TypeAdapter(EvaluationReport[Any, Any, Any])
+
+
 class RenderValueConfig(TypedDict, total=False):
     """A configuration for rendering a values in an Evaluation report."""
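
Two details of the new module worth noting: the `TypeVar`s come from `typing_extensions` because they use PEP 696 defaults (`default=Any`), so un-parameterized `ReportCase` and `EvaluationReport` behave as before; and whole-report serialization, previously `report.model_dump()`, now goes through `EvaluationReportAdapter`. A minimal sketch, assuming standard pydantic `TypeAdapter` behavior; `report` stands in for any existing `EvaluationReport` and is not defined in this commit:

```python
from pydantic_evals.reporting import EvaluationReportAdapter

# `report` is assumed to be an existing EvaluationReport instance; illustrative only.
serialized = EvaluationReportAdapter.dump_python(report)  # dict, replaces report.model_dump()
as_json = EvaluationReportAdapter.dump_json(report)       # bytes, replaces report.model_dump_json()
```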

tests/evals/test_dataset.py

Lines changed: 3 additions & 3 deletions
@@ -33,7 +33,7 @@ class MockEvaluator(Evaluator[object, object, object]):
     def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> EvaluatorOutput:
         return self.output

-from pydantic_evals.reporting import ReportCase
+from pydantic_evals.reporting import ReportCase, ReportCaseAdapter

 pytestmark = [pytest.mark.skipif(not imports_successful(), reason='pydantic-evals not installed'), pytest.mark.anyio]

@@ -196,7 +196,7 @@ async def mock_task(inputs: TaskInput) -> TaskOutput:

     assert report is not None
     assert len(report.cases) == 2
-    assert report.cases[0].model_dump() == snapshot(
+    assert ReportCaseAdapter.dump_python(report.cases[0]) == snapshot(
         {
             'assertions': {
                 'correct': {
@@ -248,7 +248,7 @@ async def mock_task(inputs: TaskInput) -> TaskOutput:

     assert report is not None
     assert len(report.cases) == 2
-    assert report.cases[0].model_dump() == snapshot(
+    assert ReportCaseAdapter.dump_python(report.cases[0]) == snapshot(
         {
             'assertions': {
                 'correct': {

tests/evals/test_reports.py

Lines changed: 4 additions & 2 deletions
@@ -12,9 +12,11 @@
 from pydantic_evals.evaluators import EvaluationResult, Evaluator, EvaluatorContext
 from pydantic_evals.reporting import (
     EvaluationReport,
+    EvaluationReportAdapter,
     RenderNumberConfig,
     RenderValueConfig,
     ReportCase,
+    ReportCaseAdapter,
     ReportCaseAggregate,
 )

@@ -157,7 +159,7 @@ async def test_report_case_aggregate():
 async def test_report_serialization(sample_report: EvaluationReport):
     """Test serializing a report to dict."""
     # Serialize the report
-    serialized = sample_report.model_dump()
+    serialized = EvaluationReportAdapter.dump_python(sample_report)

     # Check the serialized structure
     assert 'cases' in serialized
@@ -202,7 +204,7 @@ async def test_report_with_error(mock_evaluator: Evaluator[TaskInput, TaskOutput
         name='error_report',
     )

-    assert report.cases[0].model_dump() == snapshot(
+    assert ReportCaseAdapter.dump_python(report.cases[0]) == snapshot(
         {
             'assertions': {
                 'error_evaluator': {
