Skip to content

Commit 00ffe0c

Browse files
Add progress bar on evaluate (#1871)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
1 parent a341e56 commit 00ffe0c

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

pydantic_evals/pydantic_evals/dataset.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import time
1616
import warnings
1717
from collections.abc import Awaitable, Mapping, Sequence
18-
from contextlib import AsyncExitStack
18+
from contextlib import AsyncExitStack, nullcontext
1919
from contextvars import ContextVar
2020
from dataclasses import dataclass, field
2121
from pathlib import Path
@@ -28,6 +28,7 @@
2828
from pydantic._internal import _typing_extra
2929
from pydantic_core import to_json
3030
from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapHandler
31+
from rich.progress import Progress
3132
from typing_extensions import NotRequired, Self, TypedDict, TypeVar
3233

3334
from pydantic_evals._utils import get_event_loop
@@ -251,7 +252,11 @@ def __init__(
251252
)
252253

253254
async def evaluate(
254-
self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None
255+
self,
256+
task: Callable[[InputsT], Awaitable[OutputT]],
257+
name: str | None = None,
258+
max_concurrency: int | None = None,
259+
progress: bool = True,
255260
) -> EvaluationReport:
256261
"""Evaluates the test cases in the dataset using the given task.
257262
@@ -265,18 +270,26 @@ async def evaluate(
265270
If omitted, the name of the task function will be used.
266271
max_concurrency: The maximum number of concurrent evaluations of the task to allow.
267272
If None, all cases will be evaluated concurrently.
273+
progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
268274
269275
Returns:
270276
A report containing the results of the evaluation.
271277
"""
272278
name = name or get_unwrapped_function_name(task)
279+
total_cases = len(self.cases)
280+
progress_bar = Progress() if progress else None
273281

274282
limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()
275-
with _logfire.span('evaluate {name}', name=name) as eval_span:
283+
284+
with _logfire.span('evaluate {name}', name=name) as eval_span, progress_bar or nullcontext():
285+
task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None
276286

277287
async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
278288
async with limiter:
279-
return await _run_task_and_evaluators(task, case, report_case_name, self.evaluators)
289+
result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators)
290+
if progress_bar and task_id is not None: # pragma: no branch
291+
progress_bar.update(task_id, advance=1)
292+
return result
280293

281294
report = EvaluationReport(
282295
name=name,
@@ -291,11 +304,14 @@ async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name
291304
eval_span.set_attribute('cases', report.cases)
292305
# TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel
293306
eval_span.set_attribute('averages', report.averages())
294-
295307
return report
296308

297309
def evaluate_sync(
298-
self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None
310+
self,
311+
task: Callable[[InputsT], Awaitable[OutputT]],
312+
name: str | None = None,
313+
max_concurrency: int | None = None,
314+
progress: bool = True,
299315
) -> EvaluationReport:
300316
"""Evaluates the test cases in the dataset using the given task.
301317
@@ -308,11 +324,14 @@ def evaluate_sync(
308324
If omitted, the name of the task function will be used.
309325
max_concurrency: The maximum number of concurrent evaluations of the task to allow.
310326
If None, all cases will be evaluated concurrently.
327+
progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
311328
312329
Returns:
313330
A report containing the results of the evaluation.
314331
"""
315-
return get_event_loop().run_until_complete(self.evaluate(task, name=name, max_concurrency=max_concurrency))
332+
return get_event_loop().run_until_complete(
333+
self.evaluate(task, name=name, max_concurrency=max_concurrency, progress=progress)
334+
)
316335

317336
def add_case(
318337
self,

0 commit comments

Comments
 (0)