Dynamic Validation in GEPAState #100
base: main
@@ -6,9 +6,11 @@
 from typing import Any, Callable

 from gepa.adapters.default_adapter.default_adapter import DefaultAdapter
-from gepa.core.adapter import DataInst, GEPAAdapter, RolloutOutput, Trajectory
+from gepa.core.adapter import GEPAAdapter, RolloutOutput, Trajectory
+from gepa.core.data_loader import DataId, DataInst, DataLoader
 from gepa.core.engine import GEPAEngine
 from gepa.core.result import GEPAResult
 from gepa.core.state import GEPAState
 from gepa.logging.experiment_tracker import create_experiment_tracker
 from gepa.logging.logger import LoggerProtocol, StdOutLogger
 from gepa.proposer.merge import MergeProposer

@@ -25,8 +27,8 @@
 def optimize(
     seed_candidate: dict[str, str],
-    trainset: list[DataInst],
-    valset: list[DataInst] | None = None,
+    trainset: DataLoader[DataId, DataInst],
+    valset: DataLoader[DataId, DataInst] | None = None,
     adapter: GEPAAdapter[DataInst, Trajectory, RolloutOutput] | None = None,
     task_lm: str | Callable | None = None,
     # Reflection-based configuration

@@ -40,6 +42,7 @@ def optimize(
     # Merge-based configuration
     use_merge=False,
     max_merge_invocations=5,
+    merge_val_overlap_floor: int = 5,
     # Budget and Stop Condition
     max_metric_calls=None,
     stop_callbacks: "StopperProtocol | list[StopperProtocol] | None" = None,

@@ -57,6 +60,7 @@ def optimize(
     # Reproducibility
     seed: int = 0,
     raise_on_exception: bool = True,
+    val_evaluation_policy: Callable[[GEPAState], list[int]] | None = None,
 ):
     """
     GEPA is an evolutionary optimizer that evolves (multiple) text components of a complex system to optimize them towards a given metric.

@@ -95,8 +99,8 @@ def optimize(
     Parameters:
     - seed_candidate: The initial candidate to start with.
-    - trainset: The training set to use for reflective updates.
-    - valset: The validation set to use for tracking Pareto scores. If not provided, GEPA will use the trainset for both.
+    - trainset: Data loader yielding training batches for reflective updates.
+    - valset: The validation data loader to use for tracking Pareto scores. If not provided, GEPA will use the trainset for both.
     - adapter: A `GEPAAdapter` instance that implements the adapter interface. This allows GEPA to plug into your system's environment. If not provided, GEPA will use a default adapter: `gepa.adapters.default_adapter.default_adapter.DefaultAdapter`, with model defined by `task_lm`.
     - task_lm: Optional. The model to use for the task. This is only used if `adapter` is not provided, and is used to initialize the default adapter.

@@ -255,6 +259,7 @@ def evaluator(inputs, prog):
         use_merge=use_merge,
         max_merge_invocations=max_merge_invocations,
         rng=rng,
+        val_overlap_floor=merge_val_overlap_floor,
     )

     engine = GEPAEngine(

@@ -272,6 +277,7 @@ def evaluator(inputs, prog):
         display_progress_bar=display_progress_bar,
         raise_on_exception=raise_on_exception,
         stop_callback=stop_callback,
+        val_evaluation_policy=val_evaluation_policy,
     )

     with experiment_tracker:
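For orientation, here is a hedged caller-side sketch of the reworked optimize() signature. It is not part of the diff: the data contents, the model name, the budget, and the eval_first_100 policy are illustrative assumptions, and parameters not shown keep their defaults.

# Illustrative only: wiring the DataLoader-based optimize() entry point from this PR.
import gepa
from gepa.core.data_loader import ListDataLoader
from gepa.core.state import GEPAState

train_examples = [...]  # your adapter-specific DataInst objects
val_examples = [...]    # likewise; omit valset below to reuse the trainset loader


def eval_first_100(state: GEPAState) -> list[int]:
    # Hypothetical val_evaluation_policy: score only the first 100 validation ids
    # (ids are list indices when using ListDataLoader).
    return list(range(100))


result = gepa.optimize(
    seed_candidate={"system_prompt": "You are a helpful assistant."},
    trainset=ListDataLoader(train_examples),
    valset=ListDataLoader(val_examples),
    task_lm="openai/gpt-4o-mini",          # placeholder model name
    max_metric_calls=2000,
    merge_val_overlap_floor=5,             # new knob, forwarded to MergeProposer as val_overlap_floor
    val_evaluation_policy=eval_first_100,  # new hook; defaults to evaluating every id
)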
@@ -0,0 +1,48 @@ (new file)
+"""Validation data loader protocols and concrete helpers."""
+
+from __future__ import annotations
+
+from typing import Hashable, Protocol, Sequence, TypeVar, runtime_checkable
+
+DataId = TypeVar("DataId", bound=Hashable)
+DataInst = TypeVar("DataInst")
+
+
+@runtime_checkable
+class DataLoader(Protocol[DataId, DataInst]):
+    """Minimal interface for retrieving validation examples keyed by opaque ids."""
+
+    def all_ids(self) -> Sequence[DataId]:
+        """Return the ordered universe of ids currently available. This may change over time."""
+
+    def fetch(self, ids: Sequence[DataId]) -> list[DataInst]:
+        """Materialise the payloads corresponding to `ids`, preserving order."""
+
+    def __len__(self) -> int:
+        """Return current number of items in the loader."""
+
+
+class MutableDataLoader(DataLoader[DataId, DataInst], Protocol):
+    """A data loader that can be mutated."""
+
+    def add_items(self, items: list[DataInst]) -> None:
+        """Add items to the loader."""
+
+
+class ListDataLoader(MutableDataLoader[int, DataInst]):
+    """In-memory reference implementation backed by a list."""
+
+    def __init__(self, items: Sequence[DataInst]):
+        self.items = list(items)
+
+    def all_ids(self) -> Sequence[DataId]:
+        return list(range(len(self.items)))
+
+    def fetch(self, ids: Sequence[DataId]) -> list[DataInst]:
+        return [self.items[data_id] for data_id in ids]
+
+    def __len__(self) -> int:
+        return len(self.items)
+
+    def add_items(self, items: Sequence[DataInst]) -> None:
+        self.items.extend(items)
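A quick usage sketch of the protocol above (illustrative, not part of the PR); the example strings are placeholders:

from gepa.core.data_loader import DataLoader, ListDataLoader

loader = ListDataLoader(["example a", "example b"])
assert isinstance(loader, DataLoader)                       # runtime_checkable structural check
assert loader.fetch([1, 0]) == ["example b", "example a"]   # order of the requested ids is preserved

# Because ListDataLoader is mutable, the validation pool can grow mid-run;
# all_ids() and __len__() reflect the new items.
loader.add_items(["example c"])
assert list(loader.all_ids()) == [0, 1, 2] and len(loader) == 3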
@@ -4,12 +4,13 @@
 import traceback
 from typing import Any, Callable, Generic

+from gepa.core.data_loader import DataId, DataInst, DataLoader
 from gepa.core.state import GEPAState, initialize_gepa_state
 from gepa.logging.utils import log_detailed_metrics_after_discovering_new_program
 from gepa.proposer.merge import MergeProposer
 from gepa.proposer.reflective_mutation.reflective_mutation import ReflectiveMutationProposer

-from .adapter import DataInst, RolloutOutput, Trajectory
+from .adapter import RolloutOutput, Trajectory

 # Import tqdm for progress bar functionality
 try:
@@ -18,7 +19,7 @@
     tqdm = None


-class GEPAEngine(Generic[DataInst, Trajectory, RolloutOutput]):
+class GEPAEngine(Generic[DataId, DataInst, Trajectory, RolloutOutput]):
     """
     Orchestrates the optimization loop. It uses pluggable ProposeNewCandidate strategies.
     """
@@ -27,7 +28,7 @@ def __init__(
         self,
         run_dir: str | None,
         evaluator: Callable[[list[DataInst], dict[str, str]], tuple[list[RolloutOutput], list[float]]],
-        valset: list[DataInst] | None,
+        valset: DataLoader[DataId, DataInst] | None,
         seed_candidate: dict[str, str],
         # Controls
         perfect_score: float,
@@ -44,6 +45,7 @@ def __init__(
         raise_on_exception: bool = True,
         # Budget and Stop Condition
         stop_callback: Callable[[Any], bool] | None = None,
+        val_evaluation_policy: Callable[[GEPAState], list[int]] | None = None,
     ):
         self.logger = logger
         self.run_dir = run_dir
@@ -72,10 +74,21 @@ def __init__(
         self.display_progress_bar = display_progress_bar

         self.raise_on_exception = raise_on_exception
+        self.val_evaluation_policy = val_evaluation_policy

-    def _val_evaluator(self) -> Callable[[dict[str, str]], tuple[list[RolloutOutput], list[float]]]:
+    def _evaluate_on_valset(
+        self, program: dict[str, str], state: GEPAState
+    ) -> tuple[dict[int, RolloutOutput], dict[int, float]]:
         assert self.valset is not None
-        return lambda prog: self.evaluator(self.valset, prog)
+        val_ids = self.val_evaluation_policy(state) if self.val_evaluation_policy else list(self.valset.all_ids())
+        batch = self.valset.fetch(val_ids)
+        outputs, scores = self.evaluator(batch, program)
+        assert len(outputs) == len(val_ids), "Eval outputs should match length of selected validation indices"
+
+        outputs_by_val_idx = dict(zip(val_ids, outputs, strict=False))
+        scores_by_val_idx = dict(zip(val_ids, scores, strict=False))
+        return outputs_by_val_idx, scores_by_val_idx

     def _get_pareto_front_programs(self, state: GEPAState) -> list:
         return state.program_at_pareto_front_valset
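A sketch of a policy matching the Callable[[GEPAState], list[int]] hook above; the random-subset strategy and the use of state.i and state.pareto_front_valset for sizing are illustrative assumptions, not part of the PR.

import random

from gepa.core.state import GEPAState


def random_val_subset(state: GEPAState, k: int = 32) -> list[int]:
    # Hypothetical policy: evaluate at most k randomly chosen validation ids per full eval.
    # Assumes integer ids 0..N-1 (as produced by ListDataLoader), with N taken from the
    # per-id Pareto bookkeeping that the engine asserts matches the valset at init.
    n = len(state.pareto_front_valset)
    rng = random.Random(state.i)  # vary the subset as the iteration counter advances
    return sorted(rng.sample(range(n), min(k, n)))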
@@ -88,22 +101,24 @@ def _run_full_eval_and_add(
     ) -> tuple[int, int]:
         num_metric_calls_by_discovery = state.total_num_evals

-        valset_outputs, valset_subscores = self._val_evaluator()(new_program)
-        valset_score = sum(valset_subscores) / len(valset_subscores)
+        valset_outputs, valset_subscores = self._evaluate_on_valset(new_program, state)
+        valset_score = (
+            sum(valset_subscores.values()) / len(valset_subscores) if len(valset_subscores) > 0 else float("-inf")
+        )

[Review comment] nit: Any thoughts about float(-inf) vs NaN?
[Reply] Slight preference for -inf so comparisons with real numbers are still coherent, but a weak preference and happy to go with NaN if you prefer.

         state.num_full_ds_evals += 1
         state.total_num_evals += len(valset_subscores)

         new_program_idx, linear_pareto_front_program_idx = state.update_state_with_new_program(
             parent_program_idx=parent_program_idx,
             new_program=new_program,
             valset_score=valset_score,
             valset_outputs=valset_outputs,
             valset_subscores=valset_subscores,
             run_dir=self.run_dir,
             num_metric_calls_by_discovery_of_new_program=num_metric_calls_by_discovery,
         )
         state.full_program_trace[-1]["new_program_idx"] = new_program_idx
+        state.full_program_trace[-1]["evaluated_val_indices"] = sorted(valset_subscores.keys())

         if new_program_idx == linear_pareto_front_program_idx:
             self.logger.log(f"Iteration {state.i + 1}: New program is on the linear pareto front")
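On the -inf vs NaN question in the thread above, a small plain-Python illustration of why -inf keeps comparisons coherent while NaN does not:

print(float("-inf") < 0.3)     # True: -inf orders below every real score
print(float("nan") < 0.3)      # False: every comparison involving NaN is False
print(max(0.3, float("nan")))  # 0.3, but max(float("nan"), 0.3) returns nan: order-dependent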
@@ -116,6 +131,7 @@ def _run_full_eval_and_add(
             valset_subscores=valset_subscores,
             experiment_tracker=self.experiment_tracker,
             linear_pareto_front_program_idx=linear_pareto_front_program_idx,
+            valset_size=len(self.valset),
         )
         return new_program_idx, linear_pareto_front_program_idx
@@ -149,28 +165,37 @@ def run(self) -> GEPAState:
         if self.valset is None:
             raise ValueError("valset must be provided to GEPAEngine.run()")

+        def valset_evaluator(program: dict[str, str]):
+            all_ids = list(self.valset.all_ids())
+            all_outputs, all_scores = self.evaluator(self.valset.fetch(all_ids), program)
+            return (
+                dict(zip(all_ids, all_outputs, strict=False)),
+                dict(zip(all_ids, all_scores, strict=False)),
+            )
+
         # Initialize state
         state = initialize_gepa_state(
             run_dir=self.run_dir,
             logger=self.logger,
             seed_candidate=self.seed_candidate,
-            valset_evaluator=self._val_evaluator(),
+            valset_evaluator=valset_evaluator,
             track_best_outputs=self.track_best_outputs,
         )

         assert len(state.pareto_front_valset) == len(self.valset)

         # Log base program score
+        base_val_avg, base_val_coverage = state.get_program_average(0)
         self.experiment_tracker.log_metrics(
             {
-                "base_program_full_valset_score": state.program_full_scores_val_set[0],
+                "base_program_full_valset_score": base_val_avg,
+                "base_program_val_coverage": base_val_coverage,
                 "iteration": state.i + 1,
             },
             step=state.i + 1,
         )

         self.logger.log(
-            f"Iteration {state.i + 1}: Base program full valset score: {state.program_full_scores_val_set[0]}"
+            f"Iteration {state.i + 1}: Base program full valset score: {base_val_avg} "
+            f"over {base_val_coverage} / {len(self.valset)} examples"
         )

         # Merge scheduling
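The coverage-aware logging above reports the mean over only the evaluated validation ids plus how many ids that covers; a toy illustration (variable names hypothetical, not the engine's internals):

valset_size = 100
subscores_by_val_id = {0: 0.5, 7: 1.0, 42: 0.0}  # only 3 of 100 ids scored so far

base_val_avg = sum(subscores_by_val_id.values()) / len(subscores_by_val_id)  # 0.5
base_val_coverage = len(subscores_by_val_id)                                 # 3
print(f"Base program full valset score: {base_val_avg} over {base_val_coverage} / {valset_size} examples")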
@@ -232,10 +257,14 @@ def run(self) -> GEPAState:
             old_sum = sum(proposal.subsample_scores_before or [])
             new_sum = sum(proposal.subsample_scores_after or [])
             if new_sum <= old_sum:
-                self.logger.log(f"Iteration {state.i + 1}: New subsample score {new_sum} is not better than old score {old_sum}, skipping")
+                self.logger.log(
+                    f"Iteration {state.i + 1}: New subsample score {new_sum} is not better than old score {old_sum}, skipping"
+                )
                 continue
             else:
-                self.logger.log(f"Iteration {state.i + 1}: New subsample score {new_sum} is better than old score {old_sum}. Continue to full eval and add to candidate pool.")
+                self.logger.log(
+                    f"Iteration {state.i + 1}: New subsample score {new_sum} is better than old score {old_sum}. Continue to full eval and add to candidate pool."
+                )

             # Accept: full eval + add
             self._run_full_eval_and_add(