Add async support for dspy.Evaluate #8504
base: main
Changes from 1 commit
@@ -6,6 +6,8 @@

if TYPE_CHECKING:
    import pandas as pd

import asyncio

import tqdm

import dspy
@@ -51,6 +53,7 @@ class EvaluationResult(Prediction):

    - score: A float value (e.g., 67.30) representing the overall performance
    - results: a list of (example, prediction, score) tuples for each example in devset
    """

    def __init__(self, score: float, results: list[Tuple["dspy.Example", "dspy.Example", Any]]):
        super().__init__(score=score, results=results)
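For context, a minimal sketch of how the returned object might be consumed. The `evaluator` and `program` names are placeholders, not part of this diff; the fields match the docstring above.

```python
# Hypothetical usage sketch (not part of this PR).
# EvaluationResult is a dspy.Prediction carrying `score` and `results`.
result = evaluator(program)   # __call__ returns an EvaluationResult
print(result.score)           # overall percentage score, e.g. 67.30

for example, prediction, score in result.results:
    # each entry pairs a devset example with the program's prediction
    # and the per-example metric value
    print(score, example.inputs())
```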
@@ -126,71 +129,132 @@ def __call__(

        Returns:
            The evaluation results are returned as a dspy.EvaluationResult object containing the following attributes:

            - score: A float percentage score (e.g., 67.30) representing overall performance
            - results: a list of (example, prediction, score) tuples for each example in devset
        """
        metric = metric if metric is not None else self.metric
        devset = devset if devset is not None else self.devset
        num_threads = num_threads if num_threads is not None else self.num_threads
        display_progress = display_progress if display_progress is not None else self.display_progress
        display_table = display_table if display_table is not None else self.display_table

        metric, devset, num_threads, display_progress, display_table = self._resolve_call_args(
            metric, devset, num_threads, display_progress, display_table
        )
        if callback_metadata:
            logger.debug(f"Evaluate is called with callback metadata: {callback_metadata}")

        tqdm.tqdm._instances.clear()
        results = self._execute_with_multithreading(program, metric, devset, num_threads, display_progress)
        return self._process_evaluate_result(devset, results, metric, display_table)

        executor = ParallelExecutor(
            num_threads=num_threads,
            disable_progress_bar=not display_progress,
            max_errors=(
                self.max_errors
                if self.max_errors is not None
                else dspy.settings.max_errors
            ),
            provide_traceback=self.provide_traceback,
            compare_results=True,

    @with_callbacks
    async def acall(
        self,
        program: "dspy.Module",
        metric: Callable | None = None,
        devset: List["dspy.Example"] | None = None,
        num_threads: int | None = None,
        display_progress: bool | None = None,
        display_table: bool | int | None = None,
        callback_metadata: dict[str, Any] | None = None,
    ) -> EvaluationResult:
        """Async version of `Evaluate.__call__`."""
        metric, devset, num_threads, display_progress, display_table = self._resolve_call_args(
            metric, devset, num_threads, display_progress, display_table
        )
        if callback_metadata:
            logger.debug(f"Evaluate.acall is called with callback metadata: {callback_metadata}")
        tqdm.tqdm._instances.clear()
        results = await self._execute_with_event_loop(program, metric, devset, num_threads)
        return self._process_evaluate_result(devset, results, metric, display_table)
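For illustration, a minimal sketch of how `acall` might be driven from async user code. The `program`, `devset`, and `metric` objects are placeholders, and the constructor arguments assume the existing `dspy.Evaluate` signature.

```python
import asyncio
import dspy

# Hypothetical driver (not part of this PR). Assumes `program`, `devset`,
# and `metric` are defined elsewhere.
async def main():
    evaluator = dspy.Evaluate(
        devset=devset,
        metric=metric,
        num_threads=8,
        display_progress=True,
    )
    result = await evaluator.acall(program)  # async counterpart of evaluator(program)
    print(result.score)

asyncio.run(main())
```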
    def _resolve_call_args(self, metric, devset, num_threads, display_progress, display_table):
        return (
            metric or self.metric,
            devset or self.devset,
            num_threads or self.num_threads,
            display_progress or self.display_progress,
            display_table or self.display_table,
        )
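One behavioral detail of the `or`-based resolution above, shown as a small illustration (an observation about the code as written, not something stated elsewhere in the diff): any falsy override falls back to the instance default.

```python
# Illustration only: with `or`, an explicit falsy argument (None, False, 0)
# resolves to the instance default.
default_display_progress = True
override = False
resolved = override or default_display_progress
print(resolved)  # True, not False
```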
        def process_item(example):
            prediction = program(**example.inputs())
            score = metric(example, prediction)

            # Increment assert and suggest failures to program's attributes
            if hasattr(program, "_assert_failures"):
                program._assert_failures += dspy.settings.get("assert_failures")
            if hasattr(program, "_suggest_failures"):
                program._suggest_failures += dspy.settings.get("suggest_failures")

            return prediction, score

        results = executor.execute(process_item, devset)

    def _process_evaluate_result(self, devset, results, metric, display_table):
        assert len(devset) == len(results)

        results = [((dspy.Prediction(), self.failure_score) if r is None else r) for r in results]
        results = [(example, prediction, score) for example, (prediction, score) in zip(devset, results, strict=False)]
        ncorrect, ntotal = sum(score for *_, score in results), len(devset)

        logger.info(f"Average Metric: {ncorrect} / {ntotal} ({round(100 * ncorrect / ntotal, 1)}%)")

        if display_table:
            if importlib.util.find_spec("pandas") is not None:
                # Rename the 'correct' column to the name of the metric object
                metric_name = metric.__name__ if isinstance(metric, types.FunctionType) else metric.__class__.__name__
                # Construct a pandas DataFrame from the results
                result_df = self._construct_result_table(results, metric_name)

                self._display_result_table(result_df, display_table, metric_name)
            else:
                logger.warning("Skipping table display since `pandas` is not installed.")

        return EvaluationResult(
            score=round(100 * ncorrect / ntotal, 2),
            results=results,
        )
    def _execute_with_multithreading(
        self,
        program: "dspy.Module",
        metric: Callable,
        devset: List["dspy.Example"],
        num_threads: int,
        disable_progress_bar: bool,
    ):
        executor = ParallelExecutor(
            num_threads=num_threads,
            disable_progress_bar=disable_progress_bar,
            max_errors=(self.max_errors or dspy.settings.max_errors),
            provide_traceback=self.provide_traceback,
            compare_results=True,
        )

        def process_item(example):
            prediction = program(**example.inputs())
            score = metric(example, prediction)
            return prediction, score

        return executor.execute(process_item, devset)
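For readers unfamiliar with the threaded path, here is a rough standalone analogue using only the standard library. It is illustration only: the real code uses dspy's `ParallelExecutor`, which additionally handles error limits, tracebacks, and progress display.

```python
from concurrent.futures import ThreadPoolExecutor

# Sketch of the fan-out-over-threads idea (not dspy's ParallelExecutor).
def run_threaded(process_item, devset, num_threads):
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        # map preserves input order, mirroring the index-aligned results
        # that _process_evaluate_result expects.
        return list(pool.map(process_item, devset))
```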
    async def _execute_with_event_loop(
        self,
        program: "dspy.Module",
        metric: Callable,
        devset: List["dspy.Example"],
        num_threads: int,
    ):
        queue = asyncio.Queue()
        results = [None for _ in range(len(devset))]
        for i, example in enumerate(devset):
            await queue.put((i, example))

        for _ in range(num_threads):
            # Add a sentinel value to indicate that the worker should exit
            await queue.put((-1, None))

        # Create tqdm progress bar
        pbar = tqdm.tqdm(total=len(devset), dynamic_ncols=True)

        async def worker():
            while True:
                index, example = await queue.get()
                if index == -1:
                    break
                prediction = await program.acall(**example.inputs())
                score = metric(example, prediction)
                results[index] = (prediction, score)

                vals = [r[-1] for r in results if r is not None]
                nresults = sum(vals)
                ntotal = len(vals)
                pct = round(100 * nresults / ntotal, 1) if ntotal else 0
                pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({pct}%)")
                pbar.update(1)
                queue.task_done()

        workers = [asyncio.create_task(worker()) for _ in range(num_threads)]
        await asyncio.gather(*workers)
Review thread on `await asyncio.gather(*workers)`:

- q: Does this actually use multiple threads?
- It doesn't use multiple threads, but it does put those async workers to run asynchronously in the same event loop.
- Yeah, that's my understanding. Might it be misleading if we accept `num_threads` here?
        pbar.close()

        return results
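To make the concurrency model discussed in the review thread above concrete, here is a minimal standalone sketch of the same sentinel-terminated queue/worker pattern (illustration only, independent of dspy). All workers are coroutines on one event loop; concurrency comes from awaiting I/O, not from OS threads.

```python
import asyncio

async def gather_with_queue(items, process, num_workers=4):
    # Same idea as _execute_with_event_loop: pre-fill a queue with
    # (index, item) pairs plus one sentinel per worker, then let async
    # workers drain it and write results back by index.
    queue: asyncio.Queue = asyncio.Queue()
    results = [None] * len(items)

    for i, item in enumerate(items):
        await queue.put((i, item))
    for _ in range(num_workers):
        await queue.put((-1, None))  # sentinel: tells a worker to exit

    async def worker():
        while True:
            index, item = await queue.get()
            if index == -1:
                break
            results[index] = await process(item)

    # All workers share a single event loop; no threads are involved.
    await asyncio.gather(*(worker() for _ in range(num_workers)))
    return results

# Example usage with a toy async workload:
async def slow_double(x):
    await asyncio.sleep(0.01)
    return 2 * x

print(asyncio.run(gather_with_queue(range(10), slow_double)))
```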
    def _construct_result_table(
        self, results: list[Tuple["dspy.Example", "dspy.Example", Any]], metric_name: str
Review thread on `_construct_result_table`:

- nit: …
- done!