Commit 448c7b7

Allow definition of Model Settings for LLMJudge (#1662)

1 parent c0afc0d commit 448c7b7

File tree: 4 files changed, +140 −8 lines

pydantic_evals/pydantic_evals/evaluators/common.py
pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py
tests/evals/test_evaluator_common.py
tests/evals/test_llm_as_a_judge.py

pydantic_evals/pydantic_evals/evaluators/common.py

Lines changed: 6 additions & 2 deletions

@@ -5,6 +5,7 @@
 from typing import Any, cast
 
 from pydantic_ai import models
+from pydantic_ai.settings import ModelSettings
 
 from ..otel.span_tree import SpanQuery
 from .context import EvaluatorContext
@@ -164,6 +165,7 @@ class LLMJudge(Evaluator[object, object, object]):
     rubric: str
     model: models.Model | models.KnownModelName | None = None
     include_input: bool = False
+    model_settings: ModelSettings | None = None
 
     async def evaluate(
         self,
@@ -172,11 +174,13 @@ async def evaluate(
         if self.include_input:
             from .llm_as_a_judge import judge_input_output
 
-            grading_output = await judge_input_output(ctx.inputs, ctx.output, self.rubric, self.model)
+            grading_output = await judge_input_output(
+                ctx.inputs, ctx.output, self.rubric, self.model, self.model_settings
+            )
         else:
             from .llm_as_a_judge import judge_output
 
-            grading_output = await judge_output(ctx.output, self.rubric, self.model)
+            grading_output = await judge_output(ctx.output, self.rubric, self.model, self.model_settings)
         return EvaluationReason(value=grading_output.pass_, reason=grading_output.reason)
 
     def build_serialization_arguments(self):
pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py

Lines changed: 16 additions & 4 deletions

@@ -7,6 +7,7 @@
 from pydantic_core import to_json
 
 from pydantic_ai import Agent, models
+from pydantic_ai.settings import ModelSettings
 
 __all__ = ('GradingOutput', 'judge_input_output', 'judge_output', 'set_default_judge_model')
 
@@ -44,15 +45,20 @@ class GradingOutput(BaseModel, populate_by_name=True):
 
 
 async def judge_output(
-    output: Any, rubric: str, model: models.Model | models.KnownModelName | None = None
+    output: Any,
+    rubric: str,
+    model: models.Model | models.KnownModelName | None = None,
+    model_settings: ModelSettings | None = None,
 ) -> GradingOutput:
     """Judge the output of a model based on a rubric.
 
     If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o',
     but this can be changed using the `set_default_judge_model` function.
     """
     user_prompt = f'<Output>\n{_stringify(output)}\n</Output>\n<Rubric>\n{rubric}\n</Rubric>'
-    return (await _judge_output_agent.run(user_prompt, model=model or _default_model)).output
+    return (
+        await _judge_output_agent.run(user_prompt, model=model or _default_model, model_settings=model_settings)
+    ).output
 
 
 _judge_input_output_agent = Agent(
@@ -79,15 +85,21 @@ async def judge_output(
 
 
 async def judge_input_output(
-    inputs: Any, output: Any, rubric: str, model: models.Model | models.KnownModelName | None = None
+    inputs: Any,
+    output: Any,
+    rubric: str,
+    model: models.Model | models.KnownModelName | None = None,
+    model_settings: ModelSettings | None = None,
 ) -> GradingOutput:
     """Judge the output of a model based on the inputs and a rubric.
 
     If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o',
     but this can be changed using the `set_default_judge_model` function.
     """
     user_prompt = f'<Input>\n{_stringify(inputs)}\n</Input>\n<Output>\n{_stringify(output)}\n</Output>\n<Rubric>\n{rubric}\n</Rubric>'
-    return (await _judge_input_output_agent.run(user_prompt, model=model or _default_model)).output
+    return (
+        await _judge_input_output_agent.run(user_prompt, model=model or _default_model, model_settings=model_settings)
+    ).output
 
 
 def set_default_judge_model(model: models.Model | models.KnownModelName) -> None:  # pragma: no cover
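
The judge helpers also accept the new model_settings parameter when called directly. A sketch of such a call is shown below, assuming judge_output is imported from pydantic_evals.evaluators.llm_as_a_judge as in the tests; the rubric text and temperature are illustrative:

import asyncio

from pydantic_ai.settings import ModelSettings
from pydantic_evals.evaluators.llm_as_a_judge import judge_output


async def main() -> None:
    # Forward custom settings (e.g. temperature) to the underlying judge agent run.
    grading = await judge_output(
        'Hello world',
        'Content contains a greeting',
        model_settings=ModelSettings(temperature=0.0),
    )
    print(grading.pass_, grading.reason)


asyncio.run(main())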

tests/evals/test_evaluator_common.py

Lines changed: 59 additions & 2 deletions

@@ -7,6 +7,8 @@
 from inline_snapshot import snapshot
 from pytest_mock import MockerFixture
 
+from pydantic_ai.settings import ModelSettings
+
 from ..conftest import try_import
 
 with try_import() as imports_successful:
@@ -222,7 +224,7 @@ async def test_llm_judge_evaluator(mocker: MockerFixture):
     assert result.value is True
     assert result.reason == 'Test passed'
 
-    mock_judge_output.assert_called_once_with('Hello world', 'Content contains a greeting', None)
+    mock_judge_output.assert_called_once_with('Hello world', 'Content contains a greeting', None, None)
 
     # Test with input
     evaluator = LLMJudge(rubric='Output contains input', include_input=True, model='openai:gpt-4o')
@@ -232,7 +234,7 @@ async def test_llm_judge_evaluator(mocker: MockerFixture):
     assert result.reason == 'Test passed'
 
     mock_judge_input_output.assert_called_once_with(
-        {'prompt': 'Hello'}, 'Hello world', 'Output contains input', 'openai:gpt-4o'
+        {'prompt': 'Hello'}, 'Hello world', 'Output contains input', 'openai:gpt-4o', None
     )
 
     # Test with failing result
@@ -244,6 +246,61 @@ async def test_llm_judge_evaluator(mocker: MockerFixture):
     assert result.reason == 'Test failed'
 
 
+@pytest.mark.anyio
+async def test_llm_judge_evaluator_with_model_settings(mocker: MockerFixture):
+    """Test LLMJudge evaluator with specific model_settings."""
+    mock_grading_output = mocker.MagicMock()
+    mock_grading_output.pass_ = True
+    mock_grading_output.reason = 'Test passed with settings'
+
+    mock_judge_output = mocker.patch('pydantic_evals.evaluators.llm_as_a_judge.judge_output')
+    mock_judge_output.return_value = mock_grading_output
+
+    mock_judge_input_output = mocker.patch('pydantic_evals.evaluators.llm_as_a_judge.judge_input_output')
+    mock_judge_input_output.return_value = mock_grading_output
+
+    custom_model_settings = ModelSettings(temperature=0.77)
+
+    ctx = EvaluatorContext(
+        name='test_custom_settings',
+        inputs={'prompt': 'Hello Custom'},
+        metadata=None,
+        expected_output=None,
+        output='Hello world custom settings',
+        duration=0.0,
+        _span_tree=SpanTreeRecordingError('spans were not recorded'),
+        attributes={},
+        metrics={},
+    )
+
+    # Test without input, with custom model_settings
+    evaluator_no_input = LLMJudge(rubric='Greeting with custom settings', model_settings=custom_model_settings)
+    result_no_input = await evaluator_no_input.evaluate(ctx)
+    assert result_no_input.value is True
+    assert result_no_input.reason == 'Test passed with settings'
+    mock_judge_output.assert_called_once_with(
+        'Hello world custom settings', 'Greeting with custom settings', None, custom_model_settings
+    )
+
+    # Test with input, with custom model_settings
+    evaluator_with_input = LLMJudge(
+        rubric='Output contains input with custom settings',
+        include_input=True,
+        model='openai:gpt-3.5-turbo',
+        model_settings=custom_model_settings,
+    )
+    result_with_input = await evaluator_with_input.evaluate(ctx)
+    assert result_with_input.value is True
+    assert result_with_input.reason == 'Test passed with settings'
+    mock_judge_input_output.assert_called_once_with(
+        {'prompt': 'Hello Custom'},
+        'Hello world custom settings',
+        'Output contains input with custom settings',
+        'openai:gpt-3.5-turbo',
+        custom_model_settings,
+    )
+
+
 async def test_python():
     """Test Python evaluator."""
     evaluator = Python(expression='ctx.output > 0')

tests/evals/test_llm_as_a_judge.py

Lines changed: 59 additions & 0 deletions

@@ -6,6 +6,7 @@
 from ..conftest import try_import
 
 with try_import() as imports_successful:
+    from pydantic_ai.settings import ModelSettings
     from pydantic_evals.evaluators.llm_as_a_judge import (
         GradingOutput,
         _stringify,  # pyright: ignore[reportPrivateUsage]
@@ -87,6 +88,34 @@ async def test_judge_output_mock(mocker: MockerFixture):
     assert '<Rubric>\nContent contains a greeting\n</Rubric>' in call_args[0]
 
 
+@pytest.mark.anyio
+async def test_judge_output_with_model_settings_mock(mocker: MockerFixture):
+    """Test judge_output function with model_settings and mocked agent."""
+    mock_result = mocker.MagicMock()
+    mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0)
+    mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result)
+
+    test_model_settings = ModelSettings(temperature=1)
+
+    grading_output = await judge_output(
+        'Hello world settings',
+        'Content contains a greeting with settings',
+        model_settings=test_model_settings,
+    )
+    assert isinstance(grading_output, GradingOutput)
+    assert grading_output.reason == 'Test passed with settings'
+    assert grading_output.pass_ is True
+    assert grading_output.score == 1.0
+
+    mock_run.assert_called_once()
+    call_args, call_kwargs = mock_run.call_args
+    assert '<Output>\nHello world settings\n</Output>' in call_args[0]
+    assert '<Rubric>\nContent contains a greeting with settings\n</Rubric>' in call_args[0]
+    assert call_kwargs['model_settings'] == test_model_settings
+    # Check if 'model' kwarg is passed, its value will be the default model or None
+    assert 'model' in call_kwargs
+
+
 @pytest.mark.anyio
 async def test_judge_input_output_mock(mocker: MockerFixture):
     """Test judge_input_output function with mocked agent."""
@@ -108,3 +137,33 @@ async def test_judge_input_output_mock(mocker: MockerFixture):
     assert '<Input>\nHello\n</Input>' in call_args[0]
     assert '<Output>\nHello world\n</Output>' in call_args[0]
     assert '<Rubric>\nOutput contains input\n</Rubric>' in call_args[0]
+
+
+@pytest.mark.anyio
+async def test_judge_input_output_with_model_settings_mock(mocker: MockerFixture):
+    """Test judge_input_output function with model_settings and mocked agent."""
+    mock_result = mocker.MagicMock()
+    mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0)
+    mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result)
+
+    test_model_settings = ModelSettings(temperature=1)
+
+    result = await judge_input_output(
+        'Hello settings',
+        'Hello world with settings',
+        'Output contains input with settings',
+        model_settings=test_model_settings,
+    )
+    assert isinstance(result, GradingOutput)
+    assert result.reason == 'Test passed with settings'
+    assert result.pass_ is True
+    assert result.score == 1.0
+
+    mock_run.assert_called_once()
+    call_args, call_kwargs = mock_run.call_args
+    assert '<Input>\nHello settings\n</Input>' in call_args[0]
+    assert '<Output>\nHello world with settings\n</Output>' in call_args[0]
+    assert '<Rubric>\nOutput contains input with settings\n</Rubric>' in call_args[0]
+    assert call_kwargs['model_settings'] == test_model_settings
+    # Check if 'model' kwarg is passed, its value will be the default model or None
+    assert 'model' in call_kwargs
