diff --git a/pydantic_evals/pydantic_evals/dataset.py b/pydantic_evals/pydantic_evals/dataset.py
index bdf733a24..88236f6d7 100644
--- a/pydantic_evals/pydantic_evals/dataset.py
+++ b/pydantic_evals/pydantic_evals/dataset.py
@@ -18,12 +18,14 @@
 from contextlib import AsyncExitStack, nullcontext
 from contextvars import ContextVar
 from dataclasses import dataclass, field
+from inspect import iscoroutinefunction
 from pathlib import Path
 from typing import Any, Callable, Generic, Literal, Union, cast
 
 import anyio
 import logfire_api
 import yaml
+from anyio import to_thread
 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, ValidationError, model_serializer
 from pydantic._internal import _typing_extra
 from pydantic_core import to_json
@@ -253,7 +255,7 @@ def __init__(
 
     async def evaluate(
         self,
-        task: Callable[[InputsT], Awaitable[OutputT]],
+        task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
@@ -308,7 +310,7 @@ async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name
 
     def evaluate_sync(
         self,
-        task: Callable[[InputsT], Awaitable[OutputT]],
+        task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
@@ -811,7 +813,7 @@ def record_attribute(self, name: str, value: Any) -> None:
 
 
 async def _run_task(
-    task: Callable[[InputsT], Awaitable[OutputT]], case: Case[InputsT, OutputT, MetadataT]
+    task: Callable[[InputsT], Awaitable[OutputT] | OutputT], case: Case[InputsT, OutputT, MetadataT]
 ) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
     """Run a task on a case and return the context for evaluators.
 
@@ -836,7 +838,10 @@ async def _run_task(
         with _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span:
             with context_subtree() as span_tree:
                 t0 = time.perf_counter()
-                task_output = await task(case.inputs)
+                if iscoroutinefunction(task):
+                    task_output = cast(OutputT, await task(case.inputs))
+                else:
+                    task_output = cast(OutputT, await to_thread.run_sync(task, case.inputs))
                 fallback_duration = time.perf_counter() - t0
     finally:
         _CURRENT_TASK_RUN.reset(token)
@@ -873,7 +878,7 @@ async def _run_task(
 
 
 async def _run_task_and_evaluators(
-    task: Callable[[InputsT], Awaitable[OutputT]],
+    task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
     case: Case[InputsT, OutputT, MetadataT],
     report_case_name: str,
     dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
diff --git a/tests/evals/test_dataset.py b/tests/evals/test_dataset.py
index 2cdefe603..d6025f800 100644
--- a/tests/evals/test_dataset.py
+++ b/tests/evals/test_dataset.py
@@ -178,21 +178,35 @@ def evaluate(self, ctx: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]):
         }
 
 
+@pytest.mark.parametrize('async_task', [True, False])
 async def test_evaluate(
     example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata],
     simple_evaluator: type[Evaluator[TaskInput, TaskOutput, TaskMetadata]],
+    async_task: bool,
 ):
     """Test evaluating a dataset."""
     example_dataset.add_evaluator(simple_evaluator())
 
-    async def mock_task(inputs: TaskInput) -> TaskOutput:
-        if inputs.query == 'What is 2+2?':
-            return TaskOutput(answer='4')
-        elif inputs.query == 'What is the capital of France?':
-            return TaskOutput(answer='Paris')
-        return TaskOutput(answer='Unknown')  # pragma: no cover
+    if async_task:
+
+        async def mock_async_task(inputs: TaskInput) -> TaskOutput:
+            if inputs.query == 'What is 2+2?':
+                return TaskOutput(answer='4')
+            elif inputs.query == 'What is the capital of France?':
+                return TaskOutput(answer='Paris')
+            return TaskOutput(answer='Unknown')  # pragma: no cover
+
+        report = await example_dataset.evaluate(mock_async_task)
+    else:
+
+        def mock_sync_task(inputs: TaskInput) -> TaskOutput:
+            if inputs.query == 'What is 2+2?':
+                return TaskOutput(answer='4')
+            elif inputs.query == 'What is the capital of France?':
+                return TaskOutput(answer='Paris')
+            return TaskOutput(answer='Unknown')  # pragma: no cover
 
-    report = await example_dataset.evaluate(mock_task)
+        report = await example_dataset.evaluate(mock_sync_task)
 
     assert report is not None
     assert len(report.cases) == 2
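
Usage sketch (not part of the patch): with this change, `Dataset.evaluate` and `Dataset.evaluate_sync` accept a plain synchronous callable as well as an async one; sync tasks are detected with `inspect.iscoroutinefunction` and executed via `anyio.to_thread.run_sync` so they do not block the event loop. The snippet below is a minimal illustration assuming the public `Case`/`Dataset` API of `pydantic_evals`; the `answer` task and the case data are made up for demonstration.

```python
from pydantic_evals import Case, Dataset


# A plain synchronous task: no `async def` needed after this change.
# Internally, iscoroutinefunction(task) returns False for it, so it is
# run in a worker thread via anyio.to_thread.run_sync.
def answer(question: str) -> str:
    return {'What is 2+2?': '4'}.get(question, 'Unknown')


dataset = Dataset(
    cases=[Case(name='math', inputs='What is 2+2?', expected_output='4')],
)

# evaluate_sync drives the async evaluate() loop; the sync task itself
# runs in a thread, so the loop is never blocked.
report = dataset.evaluate_sync(answer)
print(report)
```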