2
2
3
3
from collections import defaultdict
4
4
from collections .abc import Mapping
5
- from dataclasses import dataclass , field
5
+ from dataclasses import dataclass
6
6
from io import StringIO
7
- from typing import Any , Callable , Literal , Protocol , TypeVar
7
+ from typing import Any , Callable , Generic , Literal , Protocol
8
8
9
- from pydantic import BaseModel
9
+ from pydantic import BaseModel , TypeAdapter
10
10
from rich .console import Console
11
11
from rich .table import Table
12
- from typing_extensions import TypedDict
12
+ from typing_extensions import TypedDict , TypeVar
13
13
14
14
from pydantic_evals ._utils import UNSET , Unset
15
15
24
24
25
25
__all__ = (
26
26
'EvaluationReport' ,
27
+ 'EvaluationReportAdapter' ,
27
28
'ReportCase' ,
29
+ 'ReportCaseAdapter' ,
28
30
'EvaluationRenderer' ,
29
31
'RenderValueConfig' ,
30
32
'RenderNumberConfig' ,
35
37
EMPTY_CELL_STR = '-'
36
38
EMPTY_AGGREGATE_CELL_STR = ''
37
39
40
+ InputsT = TypeVar ('InputsT' , default = Any )
41
+ OutputT = TypeVar ('OutputT' , default = Any )
42
+ MetadataT = TypeVar ('MetadataT' , default = Any )
38
43
39
- class ReportCase (BaseModel ):
44
+
45
+ @dataclass
46
+ class ReportCase (Generic [InputsT , OutputT , MetadataT ]):
40
47
"""A single case in an evaluation report."""
41
48
42
49
name : str
43
50
"""The name of the [case][pydantic_evals.Case]."""
44
- inputs : Any
51
+ inputs : InputsT
45
52
"""The inputs to the task, from [`Case.inputs`][pydantic_evals.Case.inputs]."""
46
- metadata : Any
53
+ metadata : MetadataT | None
47
54
"""Any metadata associated with the case, from [`Case.metadata`][pydantic_evals.Case.metadata]."""
48
- expected_output : Any
55
+ expected_output : OutputT | None
49
56
"""The expected output of the task, from [`Case.expected_output`][pydantic_evals.Case.expected_output]."""
50
- output : Any
57
+ output : OutputT
51
58
"""The output of the task execution."""
52
59
53
60
metrics : dict [str , float | int ]
54
61
attributes : dict [str , Any ]
55
62
56
- scores : dict [str , EvaluationResult [int | float ]] = field ( init = False )
57
- labels : dict [str , EvaluationResult [str ]] = field ( init = False )
58
- assertions : dict [str , EvaluationResult [bool ]] = field ( init = False )
63
+ scores : dict [str , EvaluationResult [int | float ]]
64
+ labels : dict [str , EvaluationResult [str ]]
65
+ assertions : dict [str , EvaluationResult [bool ]]
59
66
60
67
task_duration : float
61
68
total_duration : float # includes evaluator execution time
@@ -65,6 +72,9 @@ class ReportCase(BaseModel):
65
72
span_id : str
66
73
67
74
75
+ ReportCaseAdapter = TypeAdapter (ReportCase [Any , Any , Any ])
76
+
77
+
68
78
class ReportCaseAggregate (BaseModel ):
69
79
"""A synthetic case that summarizes a set of cases."""
70
80
@@ -142,12 +152,13 @@ def _labels_averages(labels_by_name: list[dict[str, str]]) -> dict[str, dict[str
142
152
)
143
153
144
154
145
- class EvaluationReport (BaseModel ):
155
+ @dataclass
156
+ class EvaluationReport (Generic [InputsT , OutputT , MetadataT ]):
146
157
"""A report of the results of evaluating a model on a set of cases."""
147
158
148
159
name : str
149
160
"""The name of the report."""
150
- cases : list [ReportCase ]
161
+ cases : list [ReportCase [ InputsT , OutputT , MetadataT ] ]
151
162
"""The cases in the report."""
152
163
153
164
def averages (self ) -> ReportCaseAggregate :
@@ -156,7 +167,7 @@ def averages(self) -> ReportCaseAggregate:
156
167
def print (
157
168
self ,
158
169
width : int | None = None ,
159
- baseline : EvaluationReport | None = None ,
170
+ baseline : EvaluationReport [ InputsT , OutputT , MetadataT ] | None = None ,
160
171
include_input : bool = False ,
161
172
include_metadata : bool = False ,
162
173
include_expected_output : bool = False ,
@@ -199,7 +210,7 @@ def print(
199
210
200
211
def console_table (
201
212
self ,
202
- baseline : EvaluationReport | None = None ,
213
+ baseline : EvaluationReport [ InputsT , OutputT , MetadataT ] | None = None ,
203
214
include_input : bool = False ,
204
215
include_metadata : bool = False ,
205
216
include_expected_output : bool = False ,
@@ -250,6 +261,9 @@ def __str__(self) -> str: # pragma: lax no cover
250
261
return io_file .getvalue ()
251
262
252
263
264
+ EvaluationReportAdapter = TypeAdapter (EvaluationReport [Any , Any , Any ])
265
+
266
+
253
267
class RenderValueConfig (TypedDict , total = False ):
254
268
"""A configuration for rendering a values in an Evaluation report."""
255
269
0 commit comments