
Commit 222bec4

Add ability to specify the evaluation name for all provided Evaluators (#1725)
1 parent 3d5199b commit 222bec4
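
Illustrative usage (not part of the diff; a minimal sketch assuming pydantic_evals at this commit, with made-up values):

from pydantic_evals.evaluators import Contains, Equals

# Every provided evaluator now accepts an optional `evaluation_name`, used as the
# result name in evaluation reports instead of the class name.
checks = [
    Equals(value='Paris', evaluation_name='capital_is_paris'),
    Contains(value='France', evaluation_name='mentions_france'),
]
assert checks[0].get_default_evaluation_name() == 'capital_is_paris'
assert checks[1].get_default_evaluation_name() == 'mentions_france'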

9 files changed: +170 -47 lines changed


pydantic_evals/pydantic_evals/dataset.py

Lines changed: 2 additions & 2 deletions
@@ -1036,14 +1036,14 @@ def _get_registry(
             raise ValueError(
                 f'All custom evaluator classes must be decorated with `@dataclass`, but {evaluator_class} is not'
             )
-        name = evaluator_class.name()
+        name = evaluator_class.get_serialization_name()
         if name in registry:
             raise ValueError(f'Duplicate evaluator class name: {name!r}')
         registry[name] = evaluator_class

     for evaluator_class in DEFAULT_EVALUATORS:
         # Allow overriding the default evaluators with custom evaluators raising an error
-        registry.setdefault(evaluator_class.name(), evaluator_class)
+        registry.setdefault(evaluator_class.get_serialization_name(), evaluator_class)

     return registry
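
Illustrative sketch (not part of the diff): how a custom evaluator is keyed in this registry after the rename; `MyEvaluator` is a made-up name.

from dataclasses import dataclass

from pydantic_evals.evaluators import Evaluator, EvaluatorContext


@dataclass  # custom evaluator classes must be decorated with @dataclass
class MyEvaluator(Evaluator[object, object, object]):
    def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> bool:
        return ctx.output == ctx.expected_output


# The registry key is now the serialization name (the class name by default);
# two custom evaluators sharing a serialization name raise a ValueError.
assert MyEvaluator.get_serialization_name() == 'MyEvaluator'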

pydantic_evals/pydantic_evals/evaluators/__init__.py

Lines changed: 12 additions & 1 deletion
@@ -1,4 +1,14 @@
-from .common import Contains, Equals, EqualsExpected, HasMatchingSpan, IsInstance, LLMJudge, MaxDuration, Python
+from .common import (
+    Contains,
+    Equals,
+    EqualsExpected,
+    HasMatchingSpan,
+    IsInstance,
+    LLMJudge,
+    MaxDuration,
+    OutputConfig,
+    Python,
+)
 from .context import EvaluatorContext
 from .evaluator import EvaluationReason, EvaluationResult, Evaluator, EvaluatorOutput

@@ -11,6 +21,7 @@
     'MaxDuration',
     'LLMJudge',
     'HasMatchingSpan',
+    'OutputConfig',
     'Python',
     # context
     'EvaluatorContext',

pydantic_evals/pydantic_evals/evaluators/_run_evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ async def run_evaluator(
     except ValidationError as e:
         raise ValueError(f'{evaluator!r}.evaluate returned a value of an invalid type: {raw_results!r}.') from e

-    results = _convert_to_mapping(results, scalar_name=evaluator.name())
+    results = _convert_to_mapping(results, scalar_name=evaluator.get_default_evaluation_name())

     details: list[EvaluationResult] = []
     for name, result in results.items():
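
In practice (a hedged sketch, not part of the diff, using the provided Equals evaluator with made-up values): a scalar returned from `evaluate` is now reported under `get_default_evaluation_name()`, so setting `evaluation_name` changes the key used for that result.

from pydantic_evals.evaluators import Equals

assert Equals(value=42).get_default_evaluation_name() == 'Equals'
assert Equals(value=42, evaluation_name='answer_is_42').get_default_evaluation_name() == 'answer_is_42'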

pydantic_evals/pydantic_evals/evaluators/_spec.py

Lines changed: 4 additions & 3 deletions
@@ -32,18 +32,19 @@ class EvaluatorSpec(BaseModel):
     * `{'MyEvaluator': {k1: v1, k2: v2}}` - Multiple kwargs are passed to `MyEvaluator.__init__`

     Args:
-        name: The name of the evaluator to use. Unless overridden, this is the snake_case version of the class name.
+        name: The serialization name of the evaluator class returned by `EvaluatorClass.get_serialization_name()`;
+            this is usually just the class name itself.
         arguments: The arguments to pass to the evaluator's constructor. Can be None (for no arguments),
             a tuple (for a single positional argument), or a dict (for multiple keyword arguments).
     """

     name: str
-    """The name of the evaluator class; should be the value returned by EvaluatorClass.name()"""
+    """The name of the evaluator class; should be the value returned by `EvaluatorClass.get_serialization_name()`"""

     arguments: None | tuple[Any] | dict[str, Any]
     """The arguments to pass to the evaluator's constructor.

-    Can be None (no arguments), a tuple (positional arguments), or a dict (keyword arguments).
+    Can be None (no arguments), a tuple (a single positional argument), or a dict (keyword arguments).
     """

     @property
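
For illustration (not part of the diff; `MyEvaluator` and the argument values are made up), the three serialized forms correspond to `MyEvaluator()`, `MyEvaluator(arg)`, and `MyEvaluator(k1=v1, k2=v2)`:

specs = [
    'MyEvaluator',                                                  # no arguments
    {'MyEvaluator': 'single positional argument'},                  # one positional argument
    {'MyEvaluator': {'value': 42, 'evaluation_name': 'my_check'}},  # keyword arguments
]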

pydantic_evals/pydantic_evals/evaluators/common.py

Lines changed: 51 additions & 5 deletions
@@ -1,15 +1,17 @@
 from __future__ import annotations as _annotations

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Any, cast
+from typing import Any, Literal, cast
+
+from typing_extensions import TypedDict

 from pydantic_ai import models
 from pydantic_ai.settings import ModelSettings

 from ..otel.span_tree import SpanQuery
 from .context import EvaluatorContext
-from .evaluator import EvaluationReason, Evaluator, EvaluatorOutput
+from .evaluator import EvaluationReason, EvaluationScalar, Evaluator, EvaluatorOutput

 __all__ = (
     'Equals',
@@ -20,6 +22,7 @@
     'LLMJudge',
     'HasMatchingSpan',
     'Python',
+    'OutputConfig',
 )


@@ -28,6 +31,7 @@ class Equals(Evaluator[object, object, object]):
     """Check if the output exactly equals the provided value."""

     value: Any
+    evaluation_name: str | None = field(default=None, repr=False)

     def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> bool:
         return ctx.output == self.value
@@ -37,6 +41,8 @@ def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> bool:
 class EqualsExpected(Evaluator[object, object, object]):
     """Check if the output exactly equals the expected output."""

+    evaluation_name: str | None = field(default=None, repr=False)
+
     def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> bool | dict[str, bool]:
         if ctx.expected_output is None:
             return {}  # Only compare if expected output is provided
@@ -68,6 +74,7 @@ class Contains(Evaluator[object, object, object]):
     value: Any
     case_sensitive: bool = True
     as_strings: bool = False
+    evaluation_name: str | None = field(default=None, repr=False)

     def evaluate(
         self,
@@ -127,6 +134,7 @@ class IsInstance(Evaluator[object, object, object]):
     """Check if the output is an instance of a type with the given name."""

     type_name: str
+    evaluation_name: str | None = field(default=None, repr=False)

     def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> EvaluationReason:
         output = ctx.output
@@ -154,6 +162,27 @@ def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> bool:
         return duration <= seconds


+class OutputConfig(TypedDict, total=False):
+    """Configuration for the score and assertion outputs of the LLMJudge evaluator."""
+
+    evaluation_name: str
+    include_reason: bool
+
+
+def _update_combined_output(
+    combined_output: dict[str, EvaluationScalar | EvaluationReason],
+    value: EvaluationScalar,
+    reason: str | None,
+    config: OutputConfig,
+    default_name: str,
+) -> None:
+    name = config.get('evaluation_name') or default_name
+    if config.get('include_reason') and reason is not None:
+        combined_output[name] = EvaluationReason(value=value, reason=reason)
+    else:
+        combined_output[name] = value
+
+
 @dataclass
 class LLMJudge(Evaluator[object, object, object]):
     """Judge whether the output of a language model meets the criteria of a provided rubric.
@@ -166,11 +195,13 @@ class LLMJudge(Evaluator[object, object, object]):
     model: models.Model | models.KnownModelName | None = None
     include_input: bool = False
     model_settings: ModelSettings | None = None
+    score: OutputConfig | Literal[False] = False
+    assertion: OutputConfig | Literal[False] = field(default_factory=lambda: OutputConfig(include_reason=True))

     async def evaluate(
         self,
         ctx: EvaluatorContext[object, object, object],
-    ) -> EvaluationReason:
+    ) -> EvaluatorOutput:
         if self.include_input:
             from .llm_as_a_judge import judge_input_output

@@ -181,7 +212,20 @@ async def evaluate(
             from .llm_as_a_judge import judge_output

             grading_output = await judge_output(ctx.output, self.rubric, self.model, self.model_settings)
-        return EvaluationReason(value=grading_output.pass_, reason=grading_output.reason)
+
+        output: dict[str, EvaluationScalar | EvaluationReason] = {}
+        include_both = self.score is not False and self.assertion is not False
+        evaluation_name = self.get_default_evaluation_name()
+
+        if self.score is not False:
+            default_name = f'{evaluation_name}_score' if include_both else evaluation_name
+            _update_combined_output(output, grading_output.score, grading_output.reason, self.score, default_name)
+
+        if self.assertion is not False:
+            default_name = f'{evaluation_name}_pass' if include_both else evaluation_name
+            _update_combined_output(output, grading_output.pass_, grading_output.reason, self.assertion, default_name)
+
+        return output

     def build_serialization_arguments(self):
         result = super().build_serialization_arguments()
@@ -200,6 +244,7 @@ class HasMatchingSpan(Evaluator[object, object, object]):
     """Check if the span tree contains a span that matches the specified query."""

     query: SpanQuery
+    evaluation_name: str | None = field(default=None, repr=False)

     def evaluate(
         self,
@@ -217,6 +262,7 @@ class Python(Evaluator[object, object, object]):
     """

     expression: str
+    evaluation_name: str | None = field(default=None, repr=False)

     def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> EvaluatorOutput:
         # Evaluate the condition, exposing access to the evaluator context as `ctx`.
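
Illustrative usage of the new LLMJudge options (not part of the diff; the rubric text and names are made up). By default only the assertion is emitted, named after the evaluator's default evaluation name and including the grading reason; when both score and assertion are enabled, their default names become `<name>_score` and `<name>_pass`.

from pydantic_evals.evaluators import LLMJudge, OutputConfig

# Default behaviour: assertion only, reported as 'LLMJudge' with the reason attached.
judge = LLMJudge(rubric='The answer cites at least one primary source')

# Score and assertion together, each under a custom name; the score keeps its reason.
judge_both = LLMJudge(
    rubric='The answer cites at least one primary source',
    score=OutputConfig(evaluation_name='source_score', include_reason=True),
    assertion=OutputConfig(evaluation_name='cites_sources'),
)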

pydantic_evals/pydantic_evals/evaluators/evaluator.py

Lines changed: 29 additions & 6 deletions
@@ -12,7 +12,7 @@
 )
 from pydantic_core import to_jsonable_python
 from pydantic_core.core_schema import SerializationInfo
-from typing_extensions import TypeVar
+from typing_extensions import TypeVar, deprecated

 from .._utils import get_event_loop
 from ._spec import EvaluatorSpec
@@ -146,17 +146,38 @@ def evaluate(self, ctx: EvaluatorContext) -> bool:
     __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True)

     @classmethod
-    def name(cls) -> str:
+    def get_serialization_name(cls) -> str:
         """Return the 'name' of this Evaluator to use during serialization.

         Returns:
             The name of the Evaluator, which is typically the class name.
         """
-        # Note: if we wanted to prefer snake_case, we could use:
-        # from pydantic.alias_generators import to_snake
-        # return to_snake(cls.__name__)
         return cls.__name__

+    @classmethod
+    @deprecated('`name` has been renamed, use `get_serialization_name` instead.')
+    def name(cls) -> str:
+        """`name` has been renamed, use `get_serialization_name` instead."""
+        return cls.get_serialization_name()
+
+    def get_default_evaluation_name(self) -> str:
+        """Return the default name to use in reports for the output of this evaluator.
+
+        By default, if the evaluator has an attribute called `evaluation_name` of type string, that will be used.
+        Otherwise, the serialization name of the evaluator (which is usually the class name) will be used.
+
+        This can be overridden to get a more descriptive name in evaluation reports, e.g. using instance information.
+
+        Note that evaluators that return a mapping of results will always use the keys of that mapping as the names
+        of the associated evaluation results.
+        """
+        evaluation_name = getattr(self, 'evaluation_name', None)
+        if isinstance(evaluation_name, str):
+            # If the evaluator has an attribute `evaluation_name` of type string, use that
+            return evaluation_name
+
+        return self.get_serialization_name()
+
     @abstractmethod
     def evaluate(
         self, ctx: EvaluatorContext[InputsT, OutputT, MetadataT]
@@ -233,7 +254,9 @@ def serialize(self, info: SerializationInfo) -> Any:
         else:
             arguments = raw_arguments
         return to_jsonable_python(
-            EvaluatorSpec(name=self.name(), arguments=arguments), context=info.context, serialize_unknown=True
+            EvaluatorSpec(name=self.get_serialization_name(), arguments=arguments),
+            context=info.context,
+            serialize_unknown=True,
         )

     def build_serialization_arguments(self) -> dict[str, Any]:
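
Illustrative sketch (not part of the diff; `ShorterThan` is a made-up evaluator): `get_default_evaluation_name` can be overridden to include instance information in report names, while `get_serialization_name` keeps identifying the class; the old `name()` classmethod still works but is deprecated.

from dataclasses import dataclass

from pydantic_evals.evaluators import Evaluator, EvaluatorContext


@dataclass
class ShorterThan(Evaluator[object, object, object]):
    max_len: int

    def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> bool:
        return len(str(ctx.output)) <= self.max_len

    def get_default_evaluation_name(self) -> str:
        # Include instance information so reports show which threshold was checked.
        return f'shorter_than_{self.max_len}'


assert ShorterThan(100).get_default_evaluation_name() == 'shorter_than_100'
assert ShorterThan.get_serialization_name() == 'ShorterThan'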

tests/evals/test_evaluator_base.py

Lines changed: 3 additions & 2 deletions
@@ -172,10 +172,11 @@ async def test_evaluator_async():
     assert result is True


-async def test_evaluator_name():
+async def test_evaluation_name():
     """Test evaluator name method."""
     evaluator = SimpleEvaluator()
-    assert evaluator.name() == 'SimpleEvaluator'
+    assert evaluator.get_serialization_name() == 'SimpleEvaluator'
+    assert evaluator.get_default_evaluation_name() == 'SimpleEvaluator'


 async def test_evaluator_serialization():

tests/evals/test_evaluator_common.py

Lines changed: 32 additions & 18 deletions
@@ -5,6 +5,7 @@

 import pytest
 from inline_snapshot import snapshot
+from pydantic_core import to_jsonable_python
 from pytest_mock import MockerFixture

 from pydantic_ai.settings import ModelSettings
@@ -25,6 +26,7 @@
     IsInstance,
     LLMJudge,
     MaxDuration,
+    OutputConfig,
     Python,
 )
 from pydantic_evals.otel._context_in_memory_span_exporter import context_subtree
@@ -194,6 +196,7 @@ async def test_llm_judge_evaluator(mocker: MockerFixture):
     """Test LLMJudge evaluator."""
     # Create a mock GradingOutput
     mock_grading_output = mocker.MagicMock()
+    mock_grading_output.score = 1.0
     mock_grading_output.pass_ = True
     mock_grading_output.reason = 'Test passed'

@@ -219,31 +222,42 @@ async def test_llm_judge_evaluator(mocker: MockerFixture):

     # Test without input
     evaluator = LLMJudge(rubric='Content contains a greeting')
-    result = await evaluator.evaluate(ctx)
-    assert isinstance(result, EvaluationReason)
-    assert result.value is True
-    assert result.reason == 'Test passed'
+    assert to_jsonable_python(await evaluator.evaluate(ctx)) == snapshot(
+        {'LLMJudge': {'value': True, 'reason': 'Test passed'}}
+    )

     mock_judge_output.assert_called_once_with('Hello world', 'Content contains a greeting', None, None)

     # Test with input
     evaluator = LLMJudge(rubric='Output contains input', include_input=True, model='openai:gpt-4o')
-    result = await evaluator.evaluate(ctx)
-    assert isinstance(result, EvaluationReason)
-    assert result.value is True
-    assert result.reason == 'Test passed'
+    assert to_jsonable_python(await evaluator.evaluate(ctx)) == snapshot(
+        {'LLMJudge': {'value': True, 'reason': 'Test passed'}}
+    )

     mock_judge_input_output.assert_called_once_with(
         {'prompt': 'Hello'}, 'Hello world', 'Output contains input', 'openai:gpt-4o', None
     )

     # Test with failing result
+    mock_grading_output.score = 0.0
     mock_grading_output.pass_ = False
     mock_grading_output.reason = 'Test failed'
-    result = await evaluator.evaluate(ctx)
-    assert isinstance(result, EvaluationReason)
-    assert result.value is False
-    assert result.reason == 'Test failed'
+    assert to_jsonable_python(await evaluator.evaluate(ctx)) == snapshot(
+        {'LLMJudge': {'value': False, 'reason': 'Test failed'}}
+    )
+
+    # Test with overridden configs
+    evaluator = LLMJudge(rubric='Mock rubric', assertion=False)
+    assert to_jsonable_python(await evaluator.evaluate(ctx)) == snapshot({})
+
+    evaluator = LLMJudge(
+        rubric='Mock rubric',
+        score=OutputConfig(evaluation_name='my_score', include_reason=True),
+        assertion=OutputConfig(evaluation_name='my_assertion'),
+    )
+    assert to_jsonable_python(await evaluator.evaluate(ctx)) == snapshot(
+        {'my_assertion': False, 'my_score': {'reason': 'Test failed', 'value': 0.0}}
+    )


 @pytest.mark.anyio
@@ -275,9 +289,9 @@ async def test_llm_judge_evaluator_with_model_settings(mocker: MockerFixture):

     # Test without input, with custom model_settings
     evaluator_no_input = LLMJudge(rubric='Greeting with custom settings', model_settings=custom_model_settings)
-    result_no_input = await evaluator_no_input.evaluate(ctx)
-    assert result_no_input.value is True
-    assert result_no_input.reason == 'Test passed with settings'
+    assert to_jsonable_python(await evaluator_no_input.evaluate(ctx)) == snapshot(
+        {'LLMJudge': {'value': True, 'reason': 'Test passed with settings'}}
+    )
     mock_judge_output.assert_called_once_with(
         'Hello world custom settings', 'Greeting with custom settings', None, custom_model_settings
     )
@@ -289,9 +303,9 @@ async def test_llm_judge_evaluator_with_model_settings(mocker: MockerFixture):
         model='openai:gpt-3.5-turbo',
         model_settings=custom_model_settings,
     )
-    result_with_input = await evaluator_with_input.evaluate(ctx)
-    assert result_with_input.value is True
-    assert result_with_input.reason == 'Test passed with settings'
+    assert to_jsonable_python(await evaluator_with_input.evaluate(ctx)) == snapshot(
+        {'LLMJudge': {'value': True, 'reason': 'Test passed with settings'}}
+    )
     mock_judge_input_output.assert_called_once_with(
         {'prompt': 'Hello Custom'},
         'Hello world custom settings',
