29 changes: 29 additions & 0 deletions mkdocs/docs/proposals.md
@@ -0,0 +1,29 @@
# Enhancement Proposals (EPs)

Enhancement Proposals document and discuss substantial changes to sbi’s API, internals,
or processes. EPs provide context, motivation, and a concrete design for review and
decision-making.

This page serves as an entry point to all EPs and explains how they are organized.

## Process and scope

- EPs are used for larger changes that affect users or maintainers broadly (APIs,
architecture, governance, tooling).
- Each EP is a standalone Markdown file that explains the motivation, design,
alternatives, and migration plan if needed.
- EPs are numbered sequentially (EP-00, EP-01, …).
- Status values: Draft → Discussion → Accepted → Implementing → Completed →
Rejected/Withdrawn. The default is Draft.

## Current EPs

- [EP-00: Enhancement Proposal Process](proposals/ep-00-process.md)
- [EP-01: Pluggable Training Infrastructure for sbi](proposals/ep-01-pluggable-training.md)

If you would like to propose a new EP, open a discussion and a PR adding a new
`proposals/ep-XX-<short-title>.md` file.

For feedback and community discussion, we use [GitHub
Discussions](https://github.com/sbi-dev/sbi/discussions/categories/enhancement-proposals).
Each EP page links to its corresponding discussion thread when available.
52 changes: 52 additions & 0 deletions mkdocs/docs/proposals/ep-00-process.md
@@ -0,0 +1,52 @@
# EP-00: Enhancement Proposal Process

Status: Discussion
Authors: @janfb

## Purpose

This document defines how `sbi` Enhancement Proposals (EPs) are created, discussed,
reviewed, and accepted. EPs are for substantial changes to APIs, architecture,
processes, or tooling that benefit from design discussion and broader visibility.

## When to write an EP

Write an EP when a change:

- Affects multiple modules or user-facing APIs.
- Is architectural (e.g., training loop refactor, storage/logging backends).
- Touches process or governance (e.g., release, deprecation, contribution rules).
- Is potentially controversial or has multiple viable designs.

Smaller changes (typo fixes, small bugfixes, doc tweaks) don’t need an EP.

## Structure of an EP

Recommended sections:

- Summary: One-paragraph overview of the proposal and motivation.
- Motivation: Context, problems with the status quo.
- Goals and Non-Goals: Clarify scope.
- Design: API changes, data flow, architecture.
- Alternatives Considered: Briefly describe alternatives and trade-offs.
- Backward Compatibility: Migration strategy and deprecations.
- Testing & Validation: How to verify correctness and performance.
- References: Related issues, PRs, or external work.

## Authoring and location

- Canonical source: sbi docs site under `mkdocs/docs/proposals/ep-XX-title.md`.
- Numbering: Incremental (EP-00, EP-01, …). Use a short kebab-case slug.
- Reviews happen via PRs to the repository, like other changes.

## Discussion and feedback

- Open a GitHub Discussion titled `EP-XX: <Title>` in the “Enhancement Proposals” category.
- The discussion post should include a brief summary, a link to the EP doc, specific
  questions, and decision criteria if relevant.

## Decision process

- Aim for consensus. Maintainers make final calls after sufficient discussion.
- Acceptance criteria: clear benefits, feasibility, alignment with sbi design
principles, and maintainability.
240 changes: 240 additions & 0 deletions mkdocs/docs/proposals/ep-01-pluggable-training.md
@@ -0,0 +1,240 @@
# EP-01: Pluggable Training Infrastructure for sbi

Status: Discussion
Feedback: See GitHub Discussion → [EP-01 Discussion](https://github.com/sbi-dev/sbi/discussions/new?category=ideas)

## Summary

This enhancement proposal describes a refactoring of `sbi`'s neural network training
infrastructure to address technical debt, improve maintainability, and provide users
with better experiment tracking capabilities. The refactoring will introduce typed
configuration objects, extract a unified training loop, and implement pluggable
interfaces for logging and early stopping while maintaining full backward compatibility.

## Motivation

The `sbi` package has evolved organically over five years to serve researchers in
simulation-based inference. This growth has resulted in several architectural issues in
the training infrastructure:

1. **Code duplication**: Training logic is duplicated across NPE, NLE, NRE, and other
inference methods, spanning over 3,000 lines of code
2. **Type safety issues**: Unstructured dictionary parameters lead to runtime errors and
poor IDE support
3. **Limited logging options**: Users are restricted to TensorBoard with no easy way to
integrate their preferred experiment tracking tools
4. **Missing features**: No built-in early stopping, leading to wasted compute and
potential overfitting
5. **Integration barriers**: Community contributions like PR #1629 cannot be easily
integrated due to architectural constraints

These issues affect both users and maintainers, slowing down development and making the
codebase harder to work with.

## Goals

### For Users

- **Seamless experiment tracking**: Support for TensorBoard, WandB, MLflow, and stdout
without changing existing code
- **Multiple early stopping strategies**: Patience-based, plateau detection, and custom
  criteria
- **Better debugging**: Typed configurations with clear error messages
- **Zero migration effort**: Full backward compatibility with existing code

> **Reviewer comment (Contributor)** on the early stopping strategies: Is there any evidence
> that patience-based early stopping is not working well? I would only add more methods if it
> is really low-effort to implement and maintain them, or if they are really important for
> performance.

### For Maintainers

- **Reduced code duplication**: Extract shared training logic into reusable components
- **Type safety**: Prevent entire classes of bugs through static typing
- **Easier feature integration**: Clean interfaces for community contributions
- **Future extensibility**: Foundation for integration with PyTorch early stopping methods

> **Reviewer comment (Contributor)** on code duplication: I am fine with doing this for this
> PR, but in general, I don't think that code duplication is always a bad thing. Introducing
> “reusable components” adds cognitive load and will make the package harder to maintain in
> the future.

## Design

### Current API

```python
# Current: Each method has its own training implementation, mixing general
# training options with method-specific loss options.
inference = NPE(prior=prior)
inference.train(
    training_batch_size=50,
    learning_rate=5e-4,
    validation_fraction=0.1,
    stop_after_epochs=20,
    max_num_epochs=2**31 - 1,
    clip_max_norm=5.0,
    exclude_invalid_x=True,
    resume_training=False,
    show_train_summary=False,
)
```

### Proposed API

```python
# Proposed: Cleaner API with typed configurations
from sbi.training import TrainConfig, LossArgsNPE

# Configure training (with IDE autocomplete and validation)
train_config = TrainConfig(
    batch_size=50,
    learning_rate=5e-4,
    max_epochs=1000,
    device="cuda",
)

# Method-specific loss configuration
loss_args = LossArgsNPE(exclude_invalid_x=True)

# Train with clean API
inference = NPE(prior=prior)
inference.train(train_config, loss_args)
```

> **Reviewer comment (Contributor)** on `TrainConfig`: I don't really see the point of this.
> All of our train config args are ints, floats, or bools. Why do we need this additional
> abstraction?

> **Reviewer comment (Contributor)** on `LossArgsNPE`: I would not consider exclude_invalid_x
> to live in LossArgsNPE.
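To make the typed-configuration idea concrete, here is a minimal sketch of what `TrainConfig`
could look like; the field names, defaults, and validation below are illustrative assumptions,
not the proposed specification.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    """Illustrative typed training configuration (sketch, not the final design)."""

    batch_size: int = 50
    learning_rate: float = 5e-4
    max_epochs: int = 1000
    device: str = "cpu"

    def __post_init__(self) -> None:
        # Fail fast with a readable message instead of erroring deep inside the
        # training loop, which is the "better debugging" benefit listed above.
        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {self.batch_size}.")
        if self.learning_rate <= 0:
            raise ValueError(f"learning_rate must be positive, got {self.learning_rate}.")
        if self.max_epochs <= 0:
            raise ValueError(f"max_epochs must be positive, got {self.max_epochs}.")
```

Compared to an untyped keyword dictionary, a dataclass like this gives IDE autocomplete and
catches typos and invalid values at construction time.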

### Logging Interface

Users can seamlessly switch between logging backends:

```python
from sbi.training import LoggingConfig

# Choose your backend - no other code changes needed
logging = LoggingConfig(backend="wandb", project="my-experiment")
# or: LoggingConfig(backend="tensorboard", log_dir="./runs")
# or: LoggingConfig(backend="mlflow", experiment_name="sbi-run")
# or: LoggingConfig(backend="stdout") # default

inference.train(train_config, loss_args, logging=logging)
```
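One way to realize the pluggable backends behind `LoggingConfig` is a small logger protocol
that every backend implements; the interface below is a sketch under that assumption, not an
agreed-upon design.

```python
from typing import Mapping, Protocol


class TrainLogger(Protocol):
    """Minimal interface a logging backend could implement (illustrative only)."""

    def log_scalars(self, metrics: Mapping[str, float], step: int) -> None:
        """Record scalar metrics, e.g. train/validation loss, at a given step."""
        ...

    def close(self) -> None:
        """Flush buffers and release any resources held by the backend."""
        ...


class StdoutLogger:
    """Trivial backend that prints metrics; TensorBoard/WandB/MLflow backends would mirror it."""

    def log_scalars(self, metrics: Mapping[str, float], step: int) -> None:
        formatted = ", ".join(f"{key}={value:.4f}" for key, value in metrics.items())
        print(f"[step {step}] {formatted}")

    def close(self) -> None:
        pass
```

Under this design, `LoggingConfig(backend=...)` would only select and construct the matching
`TrainLogger`; the training loop never needs to know which backend is active.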

### Early Stopping

Multiple strategies available out of the box:

```python
from sbi.training import EarlyStopping

# Stop when validation loss plateaus
early_stop = EarlyStopping.validation_loss(patience=20, min_delta=1e-4)

# Stop when learning rate drops too low
early_stop = EarlyStopping.lr_threshold(min_lr=1e-6)

inference.train(train_config, loss_args, early_stopping=early_stop)
```

> **Reviewer comment (Contributor)** on the early stopping options: Why is this not a part of
> train_config?
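The custom criteria mentioned under Goals could be as simple as a user-supplied callable over
the validation-loss history; the `EarlyStopping.custom(...)` constructor used below is a
hypothetical name for illustration, not an API the proposal commits to.

```python
from typing import Sequence


def no_relative_improvement(val_losses: Sequence[float]) -> bool:
    """Stop once the best recent loss is not at least 1% better than the earlier best.

    Assumes positive validation losses; purely illustrative.
    """
    if len(val_losses) < 11:
        return False
    best_recent = min(val_losses[-10:])
    best_before = min(val_losses[:-10])
    return best_recent > 0.99 * best_before


# Hypothetical constructor for user-defined criteria (naming is illustrative only).
early_stop = EarlyStopping.custom(no_relative_improvement)
```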

### Backward Compatibility

All existing code continues to work:

```python
# Old API still supported - no breaking changes
inference.train(training_batch_size=100, learning_rate=1e-3)

# Mix old and new as needed during migration
inference.train(
training_batch_size=100, # old style
logging=LoggingConfig(backend="wandb") # new feature
)
```
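Internally, backward compatibility could be handled by a thin shim in `.train(...)` that maps
the legacy keyword arguments onto a `TrainConfig`; the mapping below is a sketch of that idea
with assumed names, not the actual implementation.

```python
def _legacy_kwargs_to_config(**kwargs) -> TrainConfig:
    """Translate old-style .train(...) keyword arguments into a TrainConfig.

    Only a subset of the legacy names is shown; unknown keys raise a clear error.
    """
    name_map = {
        "training_batch_size": "batch_size",
        "learning_rate": "learning_rate",
        "max_num_epochs": "max_epochs",
    }
    translated = {}
    for old_name, value in kwargs.items():
        if old_name not in name_map:
            raise TypeError(f"Unknown legacy training argument: {old_name!r}")
        translated[name_map[old_name]] = value
    return TrainConfig(**translated)


# An old-style call is converted internally, so existing user code keeps working.
config = _legacy_kwargs_to_config(training_batch_size=100, learning_rate=1e-3)
```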

### Unified Backend

All inference methods share the same training infrastructure:

```python
# NPE, NLE, NRE all use the same configuration
npe = NPE(prior=prior)
npe.train(train_config, loss_args)

nle = NLE(prior=prior)
nle.train(train_config, loss_args)
```

## Example: Complete Training Pipeline

```python
import torch

from sbi import utils
from sbi.inference import NPE, simulate_for_sbi
from sbi.training import TrainConfig, LossArgsNPE, LoggingConfig, EarlyStopping

# Setup simulation
prior = utils.BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
simulator = lambda theta: theta + 0.1 * torch.randn_like(theta)

# Configure training with type safety and autocomplete
config = TrainConfig(
    batch_size=100,
    learning_rate=1e-3,
    max_epochs=1000,
)

# Setup logging and early stopping
logging = LoggingConfig(backend="wandb", project="sbi-experiment")
early_stop = EarlyStopping.validation_loss(patience=20)

# Train with new features
inference = NPE(prior=prior)
theta, x = simulate_for_sbi(simulator, prior, num_simulations=10000)
inference.append_simulations(theta, x)

neural_net = inference.train(
    config,
    LossArgsNPE(exclude_invalid_x=True),
    logging=logging,
    early_stopping=early_stop,
)
```

## Next steps

Centralizing training logic in `base.py` has historically increased the size and
responsibilities of the `NeuralInference` “god class”. As a natural next step, we
propose extracting the entire training loop into a standalone function that takes the
configured options and training components, and returns the trained network (plus
optional artifacts), e.g., something like:

```python
def run_training(
    config: TrainConfig,
    model: torch.nn.Module,
    loss_fn: Callable[..., torch.Tensor],
    train_loader: DataLoader,
    val_loader: DataLoader | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: torch.optim.lr_scheduler._LRScheduler | None = None,
    callbacks: Sequence[Callback] | None = None,  # logging, early stopping, etc.
    device: str | torch.device | None = None,
) -> tuple[torch.nn.Module, TrainingSummary]:
    """Runs the unified training loop and returns the trained model and summary."""
```
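To sketch how callbacks would plug into `run_training`, a minimal `Callback` protocol and the
per-epoch dispatch it implies could look like the following; the hook names (`on_epoch_end`,
`should_stop`) and the helper are illustrative assumptions, not a finalized interface.

```python
from typing import Protocol, Sequence


class Callback(Protocol):
    """Hooks a logging or early-stopping plugin could implement (illustrative only)."""

    def on_epoch_end(self, epoch: int, train_loss: float, val_loss: float | None) -> None:
        ...

    def should_stop(self) -> bool:
        ...


def _notify_epoch_end(
    callbacks: Sequence[Callback],
    epoch: int,
    train_loss: float,
    val_loss: float | None,
) -> bool:
    """Dispatch the epoch-end hook and report whether any callback requests stopping.

    run_training would call this once per epoch after computing the losses and
    break out of its loop when it returns True.
    """
    for callback in callbacks:
        callback.on_epoch_end(epoch, train_loss, val_loss)
    return any(callback.should_stop() for callback in callbacks)
```

An early-stopping callback would keep its own patience counter and eventually return True from
`should_stop()`, while a logging callback would simply forward the losses to its backend and
never request stopping.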

Benefits:

- Shrinks `NeuralInference` and makes responsibilities explicit.
- Improves testability (train loop covered independently; inference classes can be
tested with lightweight mocks).
- Enables pluggable logging/early-stopping via callbacks without entangling method-specific
  logic.
- Keeps backward compatibility: inference classes compose `run_training()` internally
while still exposing the existing `.train(...)` entry point.

This should be tackled in a follow-up EP or PR that would introduce `run_training()`
(and a minimal `Callback` protocol), migrate NPE/NLE/NRE to call it, and add focused
unit tests for the training runner.

## Feedback Wanted

We welcome feedback and implementation interest in GitHub Discussions:

1. Which logging backends are most important?
2. What early stopping strategies would be useful?
3. Any concerns about the proposed API?
4. What do you think about the external training function?

- Discussion thread: [EP-01 Discussion](https://github.com/sbi-dev/sbi/discussions/new?category=ideas)

## References

- [PR #1629](https://github.com/sbi-dev/sbi/pull/1629): Community early stopping implementation
- [NUMFOCUS SDG Proposal](https://github.com/numfocus/small-development-grant-proposals/issues/60): Related funding proposal
1 change: 1 addition & 0 deletions mkdocs/mkdocs.yml
@@ -9,6 +9,7 @@ nav:
- Contributing:
- How to contribute: contribute.md
- Code of Conduct: code_of_conduct.md
- Enhancement Proposals: proposals.md
- Citation: citation.md
- Credits: credits.md
