# An Enhancement Proposal Workflow #1674
# Enhancement Proposals (EPs)

Enhancement Proposals document and discuss substantial changes to sbi’s API, internals,
or processes. EPs provide context, motivation, and a concrete design for review and
decision-making.

This page serves as an entry point to all EPs and explains how they are organized.

## Process and scope

- EPs are used for larger changes that affect users or maintainers broadly (APIs,
  architecture, governance, tooling).
- Each EP is a standalone Markdown file that explains the motivation, design,
  alternatives, and migration plan if needed.
- EPs are numbered sequentially (EP-00, EP-01, …).
- Status values: Draft → Discussion → Accepted → Implementing → Completed →
  Rejected/Withdrawn. The default is Draft.

## Current EPs

- [EP-00: Enhancement Proposal Process](proposals/ep-00-process.md)
- [EP-01: Pluggable Training Infrastructure for sbi](proposals/ep-01-pluggable-training.md)

If you would like to propose a new EP, open a discussion and a PR adding a new
`proposals/ep-XX-<short-title>.md` file.

For feedback and community discussion, we use [GitHub
Discussions](https://github.com/sbi-dev/sbi/discussions/categories/enhancement-proposals).
Each EP page links to its corresponding discussion thread when available.
# EP-00: Enhancement Proposal Process

Status: Discussion
Authors: @janfb

## Purpose

This document defines how `sbi` Enhancement Proposals (EPs) are created, discussed,
reviewed, and accepted. EPs are for substantial changes to APIs, architecture,
processes, or tooling that benefit from design discussion and broader visibility.

## When to write an EP

Write an EP when a change:

- Affects multiple modules or user-facing APIs.
- Is architectural (e.g., training loop refactor, storage/logging backends).
- Is process/governance-related (e.g., release, deprecation, contribution rules).
- Is potentially controversial or has multiple viable designs.

Smaller changes (typo fixes, small bugfixes, doc tweaks) don’t need an EP.

## Structure of an EP

Recommended sections:

- Summary: One-paragraph overview of the proposal and motivation.
- Motivation: Context and problems with the status quo.
- Goals and Non-Goals: Clarify scope.
- Design: API changes, data flow, architecture.
- Alternatives Considered: Briefly describe alternatives and trade-offs.
- Backward Compatibility: Migration strategy and deprecations.
- Testing & Validation: How to verify correctness and performance.
- References: Related issues, PRs, or external work.

## Authoring and location

- Canonical source: the sbi docs site under `mkdocs/docs/proposals/ep-XX-title.md`.
- Numbering: Incremental (EP-00, EP-01, …). Use a short kebab-case slug.
- Reviews happen via PRs to the repository, like other changes.

## Discussion and feedback

- Open a GitHub Discussion titled `EP-XX: <Title>` in the “Enhancement Proposals” category.
- The discussion post should include: a brief summary, a link to the EP doc, specific
  questions, and decision criteria if relevant.

## Decision process

- Aim for consensus. Maintainers make final calls after sufficient discussion.
- Acceptance criteria: clear benefits, feasibility, alignment with sbi design
  principles, and maintainability.
# EP-01: Pluggable Training Infrastructure for sbi

Status: Discussion
Feedback: See GitHub Discussion → [EP-01 Discussion](https://github.com/sbi-dev/sbi/discussions/new?category=ideas)

## Summary

This enhancement proposal describes a refactoring of `sbi`’s neural network training
infrastructure to address technical debt, improve maintainability, and provide users
with better experiment tracking capabilities. The refactoring will introduce typed
configuration objects, extract a unified training loop, and implement pluggable
interfaces for logging and early stopping while maintaining full backward compatibility.

## Motivation

The `sbi` package has evolved organically over five years to serve researchers in
simulation-based inference. This growth has resulted in several architectural issues in
the training infrastructure:

1. **Code duplication**: Training logic is duplicated across NPE, NLE, NRE, and other
   inference methods, spanning over 3,000 lines of code.
2. **Type safety issues**: Unstructured dictionary parameters lead to runtime errors and
   poor IDE support.
3. **Limited logging options**: Users are restricted to TensorBoard, with no easy way to
   integrate their preferred experiment tracking tools.
4. **Missing features**: No built-in early stopping, leading to wasted compute and
   potential overfitting.
5. **Integration barriers**: Community contributions like PR #1629 cannot be easily
   integrated due to architectural constraints.

These issues affect both users and maintainers, slowing down development and making the
codebase harder to work with.

## Goals

### For Users

- **Seamless experiment tracking**: Support for TensorBoard, WandB, MLflow, and stdout
  without changing existing code
- **Multiple early stopping strategies**: Patience-based, plateau detection, and custom
  criteria
- **Better debugging**: Typed configurations with clear error messages
- **Zero migration effort**: Full backward compatibility with existing code

### For Maintainers

- **Reduced code duplication**: Extract shared training logic into reusable components
> **Review comment** (on reduced code duplication): I am fine with doing this for this
> PR, but in general, I don't think that code duplication is always a bad thing.
> Introducing "reusable components" adds cognitive load and will make the package harder
> to maintain in the future.
- **Type safety**: Prevent entire classes of bugs through static typing
- **Easier feature integration**: Clean interfaces for community contributions
- **Future extensibility**: Foundation for integration with PyTorch early stopping methods

## Design

### Current API

```python
# Current: Each method has its own training implementation, mixing general
# training options with method-specific loss options.
inference = NPE(prior=prior)
inference.train(
    training_batch_size=50,
    learning_rate=5e-4,
    validation_fraction=0.1,
    stop_after_epochs=20,
    max_num_epochs=2**31 - 1,
    clip_max_norm=5.0,
    exclude_invalid_x=True,
    resume_training=False,
    show_train_summary=False,
)
```
### Proposed API

```python
# Proposed: Cleaner API with typed configurations
from sbi.training import TrainConfig, LossArgsNPE

# Configure training (with IDE autocomplete and validation)
train_config = TrainConfig(
    batch_size=50,
    learning_rate=5e-4,
    max_epochs=1000,
    device="cuda",
)

# Method-specific loss configuration
loss_args = LossArgsNPE(exclude_invalid_x=True)

# Train with the clean API
inference = NPE(prior=prior)
inference.train(train_config, loss_args)
```

> **Review comment** (on `TrainConfig`): I don't really see the point of this. All of
> our train config args are ints, floats, or bools. Why do we need this additional
> abstraction?

> **Review comment** (on `LossArgsNPE`): I would not consider …
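To make the typed-configuration idea concrete, here is a minimal sketch of what
`TrainConfig` and `LossArgsNPE` could look like as frozen dataclasses with eager
validation. The field names follow the example above; the validation rules and defaults
are illustrative, not a committed design:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    """Illustrative typed training configuration (not the final design)."""

    batch_size: int = 50
    learning_rate: float = 5e-4
    max_epochs: int = 1000
    validation_fraction: float = 0.1
    clip_max_norm: float | None = 5.0
    device: str = "cpu"

    def __post_init__(self) -> None:
        # Fail fast with a clear message instead of a runtime error mid-training.
        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {self.batch_size}.")
        if not 0.0 < self.validation_fraction < 1.0:
            raise ValueError(
                f"validation_fraction must be in (0, 1), got {self.validation_fraction}."
            )


@dataclass(frozen=True)
class LossArgsNPE:
    """Illustrative NPE-specific loss options, kept separate from generic training."""

    exclude_invalid_x: bool = True
```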
### Logging Interface

Users can seamlessly switch between logging backends:

```python
from sbi.training import LoggingConfig

# Choose your backend - no other code changes needed
logging = LoggingConfig(backend="wandb", project="my-experiment")
# or: LoggingConfig(backend="tensorboard", log_dir="./runs")
# or: LoggingConfig(backend="mlflow", experiment_name="sbi-run")
# or: LoggingConfig(backend="stdout")  # default

inference.train(train_config, loss_args, logging=logging)
```
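The pluggable part could be realized with a small backend protocol that `LoggingConfig`
resolves internally. The sketch below is hypothetical; the names `TrainLogger`,
`StdoutLogger`, and `log_scalar` are placeholders for whatever interface the
implementation settles on:

```python
from typing import Protocol


class TrainLogger(Protocol):
    """Hypothetical minimal interface a logging backend would implement."""

    def log_scalar(self, name: str, value: float, step: int) -> None: ...
    def close(self) -> None: ...


class StdoutLogger:
    """Trivial backend: print metrics instead of sending them to a tracker."""

    def log_scalar(self, name: str, value: float, step: int) -> None:
        print(f"step {step}: {name} = {value:.4f}")

    def close(self) -> None:
        pass
```

The training loop would then call `logger.log_scalar(...)` uniformly, so switching
backends never requires touching training code.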
### Early Stopping

Multiple strategies are available out of the box:

```python
from sbi.training import EarlyStopping

# Stop when validation loss plateaus
early_stop = EarlyStopping.validation_loss(patience=20, min_delta=1e-4)

# Stop when the learning rate drops too low
early_stop = EarlyStopping.lr_threshold(min_lr=1e-6)

inference.train(train_config, loss_args, early_stopping=early_stop)
```

> **Review comment** (on `EarlyStopping.lr_threshold`): Why is this not a part of …
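For intuition about the maintenance cost, a patience-based criterion is only a few lines
of state. The class below is an illustrative sketch of what
`EarlyStopping.validation_loss(...)` could construct internally; the name
`PatienceEarlyStopping` is hypothetical:

```python
class PatienceEarlyStopping:
    """Illustrative patience-based criterion: stop after `patience` epochs
    without at least `min_delta` improvement in validation loss."""

    def __init__(self, patience: int = 20, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, val_loss: float) -> bool:
        # Reset the counter on sufficient improvement, otherwise count toward patience.
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience
```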
### Backward Compatibility

All existing code continues to work:

```python
# Old API still supported - no breaking changes
inference.train(training_batch_size=100, learning_rate=1e-3)

# Mix old and new as needed during migration
inference.train(
    training_batch_size=100,  # old style
    logging=LoggingConfig(backend="wandb"),  # new feature
)
```
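Mechanically, backward compatibility can be a thin translation layer: legacy keyword
arguments are mapped onto the typed config before the unified loop runs. A hypothetical
sketch (the key map shows only three legacy arguments; `TrainConfig` is the dataclass
sketched above):

```python
# Hypothetical mapping from legacy .train(...) kwargs to TrainConfig fields.
_LEGACY_KEY_MAP = {
    "training_batch_size": "batch_size",
    "learning_rate": "learning_rate",
    "max_num_epochs": "max_epochs",
}


def config_from_legacy_kwargs(**kwargs) -> TrainConfig:
    """Translate old-style keyword arguments into a typed TrainConfig."""
    mapped = {}
    for key, value in kwargs.items():
        if key not in _LEGACY_KEY_MAP:
            raise TypeError(f"Unknown legacy training argument: {key!r}")
        mapped[_LEGACY_KEY_MAP[key]] = value
    return TrainConfig(**mapped)


# Old call sites keep working, now with eager validation:
config = config_from_legacy_kwargs(training_batch_size=100, learning_rate=1e-3)
```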
### Unified Backend

All inference methods share the same training infrastructure:

```python
# NPE, NLE, NRE all use the same configuration
npe = NPE(prior=prior)
npe.train(train_config, loss_args)

nle = NLE(prior=prior)
nle.train(train_config, loss_args)
```
## Example: Complete Training Pipeline

```python
import torch

from sbi import utils
from sbi.inference import NPE, simulate_for_sbi
from sbi.training import TrainConfig, LossArgsNPE, LoggingConfig, EarlyStopping

# Setup simulation
prior = utils.BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
simulator = lambda theta: theta + 0.1 * torch.randn_like(theta)

# Configure training with type safety and autocomplete
config = TrainConfig(
    batch_size=100,
    learning_rate=1e-3,
    max_epochs=1000,
)

# Setup logging and early stopping
logging = LoggingConfig(backend="wandb", project="sbi-experiment")
early_stop = EarlyStopping.validation_loss(patience=20)

# Train with new features
inference = NPE(prior=prior)
theta, x = simulate_for_sbi(simulator, prior, num_simulations=10000)
inference.append_simulations(theta, x)

neural_net = inference.train(
    config,
    LossArgsNPE(exclude_invalid_x=True),
    logging=logging,
    early_stopping=early_stop,
)
```
## Next steps

Centralizing training logic in `base.py` has historically increased the size and
responsibilities of the `NeuralInference` “god class”. As a natural next step, we
propose extracting the entire training loop into a standalone function that takes the
configured options and training components, and returns the trained network (plus
optional artifacts), e.g., something like:

```python
from collections.abc import Callable, Sequence

import torch
from torch.utils.data import DataLoader


def run_training(
    config: TrainConfig,
    model: torch.nn.Module,
    loss_fn: Callable[..., torch.Tensor],
    train_loader: DataLoader,
    val_loader: DataLoader | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: torch.optim.lr_scheduler._LRScheduler | None = None,
    callbacks: Sequence[Callback] | None = None,  # logging, early stopping, etc.
    device: str | torch.device | None = None,
) -> tuple[torch.nn.Module, TrainingSummary]:
    """Runs the unified training loop and returns the trained model and summary."""
    ...
```
Benefits:

- Shrinks `NeuralInference` and makes responsibilities explicit.
- Improves testability (the training loop is covered independently; inference classes
  can be tested with lightweight mocks).
- Enables pluggable logging/early-stopping via callbacks without entangling
  method-specific logic.
- Keeps backward compatibility: inference classes compose `run_training()` internally
  while still exposing the existing `.train(...)` entry point.

This should be tackled in a follow-up EP or PR that would introduce `run_training()`
(and a minimal `Callback` protocol, sketched below), migrate NPE/NLE/NRE to call it, and
add focused unit tests for the training runner.
## Feedback Wanted

We welcome feedback and implementation interest in GitHub Discussions:

1. Which logging backends are most important?
2. What early stopping strategies would be useful?
3. Any concerns about the proposed API?
4. What do you think about the external training function?

- Discussion thread: [EP-01 Discussion](https://github.com/sbi-dev/sbi/discussions/new?category=ideas)
## References

- [PR #1629](https://github.com/sbi-dev/sbi/pull/1629): Community early stopping implementation
- [NUMFOCUS SDG Proposal](https://github.com/numfocus/small-development-grant-proposals/issues/60): Related funding proposal
> **Review comment** (on early stopping strategies): Is there any evidence that
> patience-based early stopping is not working well? I would only add more methods if it
> is really low-effort to implement and maintain them, or if they are really important
> for performance.