Changes from 4 commits

Commits (23)
2766b36
dynamic validation in GEPAState
aria42 Oct 8, 2025
540b7d6
remove uv.lock change
aria42 Oct 8, 2025
7b8815e
fix bugs and undo unneeded changes to reduce churn in PR
aria42 Oct 9, 2025
3bdf2bc
Dataloader protocol for train/validation
aria42 Oct 13, 2025
92143d7
Fix parsing for blocks with language specifiers in InstructionProposa…
CDBiddulph Oct 8, 2025
fe45832
Fix stalled loops when reflective proposer returns no trajectories
Oct 7, 2025
02afba9
Encourage users to highlight their use cases for GEPA
LakshyAAAgrawal Oct 8, 2025
5817e1d
Use cloudpickle if available.
EthanAtLL Oct 3, 2025
1f41b6d
omit cloudpickle import in load() as cloudpickle simply calls pickle…
EthanAtLL Oct 7, 2025
f5c951e
GEPAState serialize with cloudpickle conditional on method flag.
EthanAtLL Oct 7, 2025
a5579d8
GEPAEngine, add use_cloudpickle flag used when saving GEPAState.
EthanAtLL Oct 7, 2025
d2e9b38
Test with use_cloudpickle=True.
EthanAtLL Oct 7, 2025
72700e7
Add use_cloudpickle to optimize.
EthanAtLL Oct 7, 2025
7952ad1
Add new references to README.md
LakshyAAAgrawal Oct 12, 2025
3a9316d
Potential fix for code scanning alert no. 3: Workflow does not contai…
LakshyAAAgrawal Oct 13, 2025
b195409
dynamic validation in GEPAState
aria42 Oct 8, 2025
0aefafa
Refactor data loading in GEPA optimization to accept both lists and D…
aria42 Oct 14, 2025
c6b616c
Introduce `EvaluationPolicy` for sampling validation batches and gaug…
aria42 Oct 14, 2025
62bee80
Fix stranded merge conflict
aria42 Oct 14, 2025
23c01a0
Address pr comments
aria42 Oct 14, 2025
889e1a9
- fix merge proposal code for sparse validation
aria42 Oct 14, 2025
0d651a3
Check that candidate selector is compatible with eval policy and add …
aria42 Oct 14, 2025
0189a8e
Merge branch 'main' into dynamic-validation-squash
aria42 Oct 14, 2025
README.md: 5 additions & 1 deletion
@@ -39,9 +39,12 @@ pip install git+https://github.com/gepa-ai/gepa.git
GEPA can be run in just a few lines of code. In this example, we'll use GEPA to optimize a system prompt for math problems from the AIME benchmark ([full tutorial](https://dspy.ai/tutorials/gepa_aime/)). Run the following in an environment with `OPENAI_API_KEY`:
```python
import gepa
from gepa.core.data_loader import ListDataLoader

# Load AIME dataset
trainset, valset, _ = gepa.examples.aime.init_dataset()
train_loader = ListDataLoader(trainset)
val_loader = ListDataLoader(valset)

seed_prompt = {
"system_prompt": "You are a helpful assistant. You are given a question and you need to answer it. The answer should be given at the end of your response in exactly the format '### <final answer>'"
@@ -50,7 +53,8 @@ seed_prompt = {
# Let's run GEPA optimization process.
gepa_result = gepa.optimize(
seed_candidate=seed_prompt,
trainset=trainset, valset=valset,
train_loader=train_loader,
val_loader=val_loader,
task_lm="openai/gpt-4.1-mini", # <-- This is the model being optimized
max_metric_calls=150, # <-- Set a budget
reflection_lm="openai/gpt-5", # <-- Use a strong model to reflect on mistakes and propose better prompts
src/gepa/api.py: 11 additions & 5 deletions
@@ -6,9 +6,11 @@
from typing import Any, Callable

from gepa.adapters.default_adapter.default_adapter import DefaultAdapter
from gepa.core.adapter import DataInst, GEPAAdapter, RolloutOutput, Trajectory
from gepa.core.adapter import GEPAAdapter, RolloutOutput, Trajectory
from gepa.core.data_loader import DataId, DataInst, DataLoader
from gepa.core.engine import GEPAEngine
from gepa.core.result import GEPAResult
from gepa.core.state import GEPAState
from gepa.logging.experiment_tracker import create_experiment_tracker
from gepa.logging.logger import LoggerProtocol, StdOutLogger
from gepa.proposer.merge import MergeProposer
@@ -25,8 +27,8 @@

def optimize(
seed_candidate: dict[str, str],
trainset: list[DataInst],
valset: list[DataInst] | None = None,
trainset: DataLoader[DataId, DataInst],
valset: DataLoader[DataId, DataInst] | None = None,
adapter: GEPAAdapter[DataInst, Trajectory, RolloutOutput] | None = None,
task_lm: str | Callable | None = None,
# Reflection-based configuration
@@ -40,6 +42,7 @@ def optimize(
# Merge-based configuration
use_merge=False,
max_merge_invocations=5,
merge_val_overlap_floor: int = 5,
# Budget and Stop Condition
max_metric_calls=None,
stop_callbacks: "StopperProtocol | list[StopperProtocol] | None" = None,
@@ -57,6 +60,7 @@ def optimize(
# Reproducibility
seed: int = 0,
raise_on_exception: bool = True,
val_evaluation_policy: Callable[[GEPAState], list[int]] | None = None,
Review comment (Contributor):
Tying back to the point about some candidate selectors not working with partial validation scores, maybe in the gepa.optimize call (entrypoint) we can add the assert here: if val_evaluation_policy is something that doesn't work with the chosen candidate selector, throw an error.

Reply (@aria42, Author, Oct 14, 2025):
Handled in 0d651a and added a test case that this fails.
):
"""
GEPA is an evolutionary optimizer that evolves (multiple) text components of a complex system to optimize them towards a given metric.
@@ -95,8 +99,8 @@ def optimize(

Parameters:
- seed_candidate: The initial candidate to start with.
- trainset: The training set to use for reflective updates.
- valset: The validation set to use for tracking Pareto scores. If not provided, GEPA will use the trainset for both.
- trainset: Data loader yielding training batches for reflective updates.
- valset: The validation data loader to use for tracking Pareto scores. If not provided, GEPA will use the trainset for both.
- adapter: A `GEPAAdapter` instance that implements the adapter interface. This allows GEPA to plug into your system's environment. If not provided, GEPA will use a default adapter: `gepa.adapters.default_adapter.default_adapter.DefaultAdapter`, with model defined by `task_lm`.
- task_lm: Optional. The model to use for the task. This is only used if `adapter` is not provided, and is used to initialize the default adapter.

@@ -255,6 +259,7 @@ def evaluator(inputs, prog):
use_merge=use_merge,
max_merge_invocations=max_merge_invocations,
rng=rng,
val_overlap_floor=merge_val_overlap_floor,
)

engine = GEPAEngine(
@@ -272,6 +277,7 @@
display_progress_bar=display_progress_bar,
raise_on_exception=raise_on_exception,
stop_callback=stop_callback,
val_evaluation_policy=val_evaluation_policy,
)

with experiment_tracker:
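For orientation, here is a minimal sketch of what a custom `val_evaluation_policy` could look like. It assumes only what this diff shows: the policy is a callable that receives the current `GEPAState` and returns the list of validation ids to evaluate, and `state.i` is the engine's iteration counter. The valset size and sample size below are hypothetical.

```python
# Sketch only: a validation-evaluation policy that scores each new candidate
# on a random slice of the validation set instead of the full set.
import random

VALSET_SIZE = 100   # hypothetical: number of ids in the validation loader
SAMPLE_SIZE = 20    # hypothetical: how many ids to evaluate per new candidate

def sampled_val_policy(state) -> list[int]:
    # Seed on the iteration counter so the same iteration always picks the same slice.
    rng = random.Random(state.i)
    return sorted(rng.sample(range(VALSET_SIZE), SAMPLE_SIZE))

# gepa.optimize(..., val_evaluation_policy=sampled_val_policy)
```

As the review thread above notes, a policy that evaluates only a subset has to be paired with a candidate selector that tolerates partial validation scores; commit 0d651a adds a compatibility check for that case.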
src/gepa/core/adapter.py: 6 additions & 1 deletion
@@ -4,10 +4,12 @@
from dataclasses import dataclass
from typing import Any, Generic, Protocol, TypeVar

from .data_loader import DataInst

# Generic type aliases matching your original
RolloutOutput = TypeVar("RolloutOutput")
Trajectory = TypeVar("Trajectory")
DataInst = TypeVar("DataInst")


@dataclass
class EvaluationBatch(Generic[Trajectory, RolloutOutput]):
@@ -22,10 +24,12 @@ class EvaluationBatch(Generic[Trajectory, RolloutOutput]):
a reflective dataset (See `GEPAAdapter.make_reflective_dataset`). If capture_traces=True is passed to `evaluate`, trajectories
should be provided and align one-to-one with `outputs` and `scores`.
"""

outputs: list[RolloutOutput]
scores: list[float]
trajectories: list[Trajectory] | None = None


class ProposalFn(Protocol):
def __call__(
self,
@@ -46,6 +50,7 @@ def __call__(
"""
...


class GEPAAdapter(Protocol[DataInst, Trajectory, RolloutOutput]):
"""
GEPAAdapter is the single integration point between your system
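The docstring above pins down the alignment contract: when `capture_traces=True`, `trajectories` must line up one-to-one with `outputs` and `scores`. A tiny illustration (the payload values are placeholders, not a real adapter's types):

```python
from gepa.core.adapter import EvaluationBatch

# Placeholder payloads: real adapters choose their own RolloutOutput / Trajectory types.
batch = EvaluationBatch(
    outputs=["### 42", "### 7"],                            # one rollout output per example
    scores=[1.0, 0.0],                                       # one score per example
    trajectories=[{"steps": ["..."]}, {"steps": ["..."]}],   # one trace per example
)
assert len(batch.outputs) == len(batch.scores) == len(batch.trajectories)
```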
src/gepa/core/data_loader.py: 48 additions (new file)
@@ -0,0 +1,48 @@
"""Validation data loader protocols and concrete helpers."""

from __future__ import annotations

from typing import Hashable, Protocol, Sequence, TypeVar, runtime_checkable

DataId = TypeVar("DataId", bound=Hashable)
DataInst = TypeVar("DataInst")


@runtime_checkable
class DataLoader(Protocol[DataId, DataInst]):
"""Minimal interface for retrieving validation examples keyed by opaque ids."""

def all_ids(self) -> Sequence[DataId]:
"""Return the ordered universe of ids currently available. This may change over time."""

def fetch(self, ids: Sequence[DataId]) -> list[DataInst]:
"""Materialise the payloads corresponding to `ids`, preserving order."""

def __len__(self) -> int:
"""Return current number of items in the loader."""


class MutableDataLoader(DataLoader[DataId, DataInst], Protocol):
"""A data loader that can be mutated."""

def add_items(self, items: list[DataInst]) -> None:
"""Add items to the loader."""


class ListDataLoader(MutableDataLoader[int, DataInst]):
"""In-memory reference implementation backed by a list."""

def __init__(self, items: Sequence[DataInst]):
self.items = list(items)

def all_ids(self) -> Sequence[DataId]:
return list(range(len(self.items)))

def fetch(self, ids: Sequence[DataId]) -> list[DataInst]:
return [self.items[data_id] for data_id in ids]

def __len__(self) -> int:
return len(self.items)

def add_items(self, items: Sequence[DataInst]) -> None:
self.items.extend(items)
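Because `DataLoader` is a structural protocol with only `all_ids`, `fetch`, and `__len__`, loaders don't have to be in-memory lists. As a hedged illustration (not part of this PR), a JSONL-backed loader keyed by line index might look like this:

```python
import json
from pathlib import Path
from typing import Sequence

class JsonlDataLoader:
    """Sketch of a DataLoader over a JSONL file, keyed by line index.

    Illustrative only, not part of this PR. It satisfies the protocol above
    structurally (all_ids / fetch / __len__), so the @runtime_checkable
    isinstance check against DataLoader would pass.
    """

    def __init__(self, path: str):
        self.path = Path(path)
        # Record the byte offset of each non-empty line so fetch() can seek
        # instead of re-reading the whole file.
        self.offsets: list[int] = []
        with self.path.open("rb") as f:
            while True:
                pos = f.tell()
                line = f.readline()
                if not line:
                    break
                if line.strip():
                    self.offsets.append(pos)

    def all_ids(self) -> Sequence[int]:
        return list(range(len(self.offsets)))

    def fetch(self, ids: Sequence[int]) -> list[dict]:
        out = []
        with self.path.open("rb") as f:
            for data_id in ids:  # preserve the caller's order
                f.seek(self.offsets[data_id])
                out.append(json.loads(f.readline()))
        return out

    def __len__(self) -> int:
        return len(self.offsets)
```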
src/gepa/core/engine.py: 44 additions & 15 deletions
@@ -4,12 +4,13 @@
import traceback
from typing import Any, Callable, Generic

from gepa.core.data_loader import DataId, DataInst, DataLoader
from gepa.core.state import GEPAState, initialize_gepa_state
from gepa.logging.utils import log_detailed_metrics_after_discovering_new_program
from gepa.proposer.merge import MergeProposer
from gepa.proposer.reflective_mutation.reflective_mutation import ReflectiveMutationProposer

from .adapter import DataInst, RolloutOutput, Trajectory
from .adapter import RolloutOutput, Trajectory

# Import tqdm for progress bar functionality
try:
@@ -18,7 +19,7 @@
tqdm = None


class GEPAEngine(Generic[DataInst, Trajectory, RolloutOutput]):
class GEPAEngine(Generic[DataId, DataInst, Trajectory, RolloutOutput]):
"""
Orchestrates the optimization loop. It uses pluggable ProposeNewCandidate strategies.
"""
@@ -27,7 +28,7 @@ def __init__(
self,
run_dir: str | None,
evaluator: Callable[[list[DataInst], dict[str, str]], tuple[list[RolloutOutput], list[float]]],
valset: list[DataInst] | None,
valset: DataLoader[DataId, DataInst] | None,
seed_candidate: dict[str, str],
# Controls
perfect_score: float,
@@ -44,6 +45,7 @@ def __init__(
raise_on_exception: bool = True,
# Budget and Stop Condition
stop_callback: Callable[[Any], bool] | None = None,
val_evaluation_policy: Callable[[GEPAState], list[int]] | None = None,
):
self.logger = logger
self.run_dir = run_dir
@@ -72,10 +74,21 @@ def __init__(
self.display_progress_bar = display_progress_bar

self.raise_on_exception = raise_on_exception
self.val_evaluation_policy = val_evaluation_policy

def _val_evaluator(self) -> Callable[[dict[str, str]], tuple[list[RolloutOutput], list[float]]]:
def _evaluate_on_valset(
self, program: dict[str, str], state: GEPAState
) -> tuple[dict[int, RolloutOutput], dict[int, float]]:
assert self.valset is not None
return lambda prog: self.evaluator(self.valset, prog)

val_ids = self.val_evaluation_policy(state) if self.val_evaluation_policy else list(self.valset.all_ids())
batch = self.valset.fetch(val_ids)
outputs, scores = self.evaluator(batch, program)
assert len(outputs) == len(val_ids), "Eval outputs should match length of selected validation indices"

outputs_by_val_idx = dict(zip(val_ids, outputs, strict=False))
scores_by_val_idx = dict(zip(val_ids, scores, strict=False))
return outputs_by_val_idx, scores_by_val_idx

def _get_pareto_front_programs(self, state: GEPAState) -> list:
return state.program_at_pareto_front_valset
@@ -88,22 +101,24 @@ def _run_full_eval_and_add(
) -> tuple[int, int]:
num_metric_calls_by_discovery = state.total_num_evals

valset_outputs, valset_subscores = self._val_evaluator()(new_program)
valset_score = sum(valset_subscores) / len(valset_subscores)
valset_outputs, valset_subscores = self._evaluate_on_valset(new_program, state)
valset_score = (
sum(valset_subscores.values()) / len(valset_subscores) if len(valset_subscores) > 0 else float("-inf")
Review comment (Contributor):
nit: Any thoughts about float(-inf) vs NaN?

Reply (@aria42, Author, Oct 14, 2025):
Slight preference for -inf so comparisons with real numbers are still coherent, but it's a weak preference; happy to go with NaN if you prefer.
)

state.num_full_ds_evals += 1
state.total_num_evals += len(valset_subscores)

new_program_idx, linear_pareto_front_program_idx = state.update_state_with_new_program(
parent_program_idx=parent_program_idx,
new_program=new_program,
valset_score=valset_score,
valset_outputs=valset_outputs,
valset_subscores=valset_subscores,
run_dir=self.run_dir,
num_metric_calls_by_discovery_of_new_program=num_metric_calls_by_discovery,
)
state.full_program_trace[-1]["new_program_idx"] = new_program_idx
state.full_program_trace[-1]["evaluated_val_indices"] = sorted(valset_subscores.keys())

if new_program_idx == linear_pareto_front_program_idx:
self.logger.log(f"Iteration {state.i + 1}: New program is on the linear pareto front")
@@ -116,6 +131,7 @@ def _run_full_eval_and_add(
valset_subscores=valset_subscores,
experiment_tracker=self.experiment_tracker,
linear_pareto_front_program_idx=linear_pareto_front_program_idx,
valset_size=len(self.valset),
)
return new_program_idx, linear_pareto_front_program_idx

@@ -149,28 +165,37 @@ def run(self) -> GEPAState:
if self.valset is None:
raise ValueError("valset must be provided to GEPAEngine.run()")

def valset_evaluator(program: dict[str, str]):
all_ids = list(self.valset.all_ids())
all_outputs, all_scores = self.evaluator(self.valset.fetch(all_ids), program)
return (
dict(zip(all_ids, all_outputs, strict=False)),
dict(zip(all_ids, all_scores, strict=False)),
)

# Initialize state
state = initialize_gepa_state(
run_dir=self.run_dir,
logger=self.logger,
seed_candidate=self.seed_candidate,
valset_evaluator=self._val_evaluator(),
valset_evaluator=valset_evaluator,
track_best_outputs=self.track_best_outputs,
)

assert len(state.pareto_front_valset) == len(self.valset)

# Log base program score
base_val_avg, base_val_coverage = state.get_program_average(0)
self.experiment_tracker.log_metrics(
{
"base_program_full_valset_score": state.program_full_scores_val_set[0],
"base_program_full_valset_score": base_val_avg,
"base_program_val_coverage": base_val_coverage,
"iteration": state.i + 1,
},
step=state.i + 1,
)

self.logger.log(
f"Iteration {state.i + 1}: Base program full valset score: {state.program_full_scores_val_set[0]}"
f"Iteration {state.i + 1}: Base program full valset score: {base_val_avg} "
f"over {base_val_coverage} / {len(self.valset)} examples"
)

# Merge scheduling
@@ -232,10 +257,14 @@ def run(self) -> GEPAState:
old_sum = sum(proposal.subsample_scores_before or [])
new_sum = sum(proposal.subsample_scores_after or [])
if new_sum <= old_sum:
self.logger.log(f"Iteration {state.i + 1}: New subsample score {new_sum} is not better than old score {old_sum}, skipping")
self.logger.log(
f"Iteration {state.i + 1}: New subsample score {new_sum} is not better than old score {old_sum}, skipping"
)
continue
else:
self.logger.log(f"Iteration {state.i + 1}: New subsample score {new_sum} is better than old score {old_sum}. Continue to full eval and add to candidate pool.")
self.logger.log(
f"Iteration {state.i + 1}: New subsample score {new_sum} is better than old score {old_sum}. Continue to full eval and add to candidate pool."
)

# Accept: full eval + add
self._run_full_eval_and_add(
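The structural change running through this file is that validation results are now dicts keyed by validation id rather than dense lists, which is what lets partial evaluations accumulate and what `get_program_average` (average plus coverage) is summarizing above. Here is a standalone sketch of that bookkeeping, with hypothetical helper names; the real logic lives in `GEPAState`:

```python
# Illustrative sketch of the id-keyed bookkeeping behind sparse validation.
# Function names are hypothetical; in the PR this logic lives in GEPAState
# (e.g. get_program_average) and in _run_full_eval_and_add above.

def merge_subscores(existing: dict[int, float], new: dict[int, float]) -> dict[int, float]:
    """Later evaluations of a validation id overwrite earlier ones."""
    merged = dict(existing)
    merged.update(new)
    return merged

def average_and_coverage(subscores: dict[int, float]) -> tuple[float, int]:
    """Mean over the ids that have actually been scored, plus how many that is."""
    covered = len(subscores)
    avg = sum(subscores.values()) / covered if covered else float("-inf")
    return avg, covered

first_batch = {0: 0.5, 3: 1.0}    # scores from one sampled validation slice
second_batch = {3: 0.0, 7: 1.0}   # a later slice that re-evaluates id 3
merged = merge_subscores(first_batch, second_batch)
print(average_and_coverage(merged))  # (0.5, 3) -- ids 0, 3, 7 covered
```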
src/gepa/core/result.py: 12 additions & 11 deletions
@@ -5,6 +5,7 @@
from typing import Any, Generic

from gepa.core.adapter import RolloutOutput
from gepa.core.state import ProgramIdx, ValId, ValScores


@dataclass(frozen=True)
@@ -39,16 +40,17 @@ class GEPAResult(Generic[RolloutOutput]):
- instance_winners(t): set of candidates on the pareto front for val instance t
- to_dict(...), save_json(...): serialization helpers
"""

# Core data
candidates: list[dict[str, str]]
parents: list[list[int | None]]
parents: list[list[ProgramIdx | None]]
val_aggregate_scores: list[float]
val_subscores: list[list[float]]
per_val_instance_best_candidates: list[set[int]]
val_subscores: list[ValScores]
per_val_instance_best_candidates: dict[ValId, set[ProgramIdx]]
discovery_eval_counts: list[int]

# Optional data
best_outputs_valset: list[list[tuple[int, list[RolloutOutput]]]] | None = None
best_outputs_valset: dict[ValId, list[tuple[ProgramIdx, RolloutOutput]]] | None = None

# Run metadata (optional)
total_metric_calls: int | None = None
@@ -75,18 +77,15 @@ def best_candidate(self) -> dict[str, str]:
return self.candidates[self.best_idx]

def to_dict(self) -> dict[str, Any]:
cands = [
dict(cand.items())
for cand in self.candidates
]
cands = [dict(cand.items()) for cand in self.candidates]

return dict(
candidates=cands,
parents=self.parents,
val_aggregate_scores=self.val_aggregate_scores,
val_subscores=self.val_subscores,
best_outputs_valset=self.best_outputs_valset,
per_val_instance_best_candidates=[list(s) for s in self.per_val_instance_best_candidates],
per_val_instance_best_candidates={k: list(v) for k, v in self.per_val_instance_best_candidates.items()},
discovery_eval_counts=self.discovery_eval_counts,
total_metric_calls=self.total_metric_calls,
num_full_val_evals=self.num_full_val_evals,
@@ -105,8 +104,10 @@ def from_state(state: Any, run_dir: str | None = None, seed: int | None = None)
parents=list(state.parent_program_for_candidate),
val_aggregate_scores=list(state.program_full_scores_val_set),
best_outputs_valset=getattr(state, "best_outputs_valset", None),
val_subscores=[list(s) for s in state.prog_candidate_val_subscores],
per_val_instance_best_candidates=[set(s) for s in state.program_at_pareto_front_valset],
val_subscores=[dict(scores) for scores in state.prog_candidate_val_subscores],
per_val_instance_best_candidates={
val_id: set(front) for val_id, front in state.program_at_pareto_front_valset.items()
},
discovery_eval_counts=list(state.num_metric_calls_by_discovery),
total_metric_calls=getattr(state, "total_num_evals", None),
num_full_val_evals=getattr(state, "num_full_ds_evals", None),
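Since per-instance fields in `GEPAResult` are now keyed by `ValId` (presumably plain ints when `ListDataLoader` is used), consuming a result means iterating dicts rather than parallel lists. A short hedged sketch, assuming `gepa_result` comes from the README example earlier:

```python
# Sketch of reading the id-keyed fields of a GEPAResult. Field names are taken
# from the dataclass above; gepa_result itself is assumed to come from gepa.optimize.
best = gepa_result.best_candidate  # candidates[best_idx], per the property shown above
print(best["system_prompt"])       # "system_prompt" is the key used by the README seed candidate

# Which candidates sit on the Pareto front for each validation instance?
for val_id, winners in gepa_result.per_val_instance_best_candidates.items():
    print(f"val instance {val_id}: candidates {sorted(winners)}")

# Per-candidate sparse subscores: one {val_id: score} dict per candidate.
for prog_idx, subscores in enumerate(gepa_result.val_subscores):
    print(f"candidate {prog_idx}: scored on {len(subscores)} validation instances")
```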