@@ -15,7 +15,7 @@
 import time
 import warnings
 from collections.abc import Awaitable, Mapping, Sequence
-from contextlib import AsyncExitStack
+from contextlib import AsyncExitStack, nullcontext
 from contextvars import ContextVar
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -28,6 +28,7 @@
 from pydantic._internal import _typing_extra
 from pydantic_core import to_json
 from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapHandler
+from rich.progress import Progress
 from typing_extensions import NotRequired, Self, TypedDict, TypeVar
 
 from pydantic_evals._utils import get_event_loop
@@ -251,7 +252,11 @@ def __init__(
         )
 
     async def evaluate(
-        self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None
+        self,
+        task: Callable[[InputsT], Awaitable[OutputT]],
+        name: str | None = None,
+        max_concurrency: int | None = None,
+        progress: bool = True,
     ) -> EvaluationReport:
         """Evaluates the test cases in the dataset using the given task.
 
@@ -265,18 +270,26 @@ async def evaluate(
                 If omitted, the name of the task function will be used.
             max_concurrency: The maximum number of concurrent evaluations of the task to allow.
                 If None, all cases will be evaluated concurrently.
+            progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
 
         Returns:
             A report containing the results of the evaluation.
         """
         name = name or get_unwrapped_function_name(task)
+        total_cases = len(self.cases)
+        progress_bar = Progress() if progress else None
 
         limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()
-        with _logfire.span('evaluate {name}', name=name) as eval_span:
+
+        with _logfire.span('evaluate {name}', name=name) as eval_span, progress_bar or nullcontext():
+            task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None
 
             async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
                 async with limiter:
-                    return await _run_task_and_evaluators(task, case, report_case_name, self.evaluators)
+                    result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators)
+                    if progress_bar and task_id is not None:  # pragma: no branch
+                        progress_bar.update(task_id, advance=1)
+                    return result
 
             report = EvaluationReport(
                 name=name,
@@ -291,11 +304,14 @@ async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name
             eval_span.set_attribute('cases', report.cases)
             # TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel
             eval_span.set_attribute('averages', report.averages())
-
         return report
 
     def evaluate_sync(
-        self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None
+        self,
+        task: Callable[[InputsT], Awaitable[OutputT]],
+        name: str | None = None,
+        max_concurrency: int | None = None,
+        progress: bool = True,
     ) -> EvaluationReport:
         """Evaluates the test cases in the dataset using the given task.
 
@@ -308,11 +324,14 @@ def evaluate_sync(
                 If omitted, the name of the task function will be used.
             max_concurrency: The maximum number of concurrent evaluations of the task to allow.
                 If None, all cases will be evaluated concurrently.
+            progress: Whether to show a progress bar for the evaluation. Defaults to True.
 
         Returns:
             A report containing the results of the evaluation.
         """
-        return get_event_loop().run_until_complete(self.evaluate(task, name=name, max_concurrency=max_concurrency))
+        return get_event_loop().run_until_complete(
+            self.evaluate(task, name=name, max_concurrency=max_concurrency, progress=progress)
+        )
 
     def add_case(
         self,
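
Usage sketch (not part of the diff above): a minimal example of how calling code would exercise the new `progress` flag, assuming the public `Case`/`Dataset` API exported by `pydantic_evals`; the task and case names here are hypothetical.

from pydantic_evals import Case, Dataset


async def answer(question: str) -> str:
    # Hypothetical task under evaluation; a real task would call a model or application code.
    return question.upper()


dataset = Dataset(cases=[Case(name='upper_hello', inputs='hello', expected_output='HELLO')])

# New default behaviour: a rich progress bar advances as each case finishes.
report = dataset.evaluate_sync(answer, progress=True)

# Opt out, e.g. in CI logs where a live progress bar is just noise.
report = dataset.evaluate_sync(answer, max_concurrency=4, progress=False)
report.print()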