A high-level API to build and train NNX models
Blaxbird [blækbɜːd] is a high-level API to easily build NNX models and train them on CPU or GPU.
Using blaxbird, one can
- concisely define models and loss functions without the usual JAX/Flax verbosity,
- easily define checkpointers that save the best and most current network weights,
- distribute data and model weights over multiple processes or GPUs,
- define hooks that are periodically called during training.
In addition, blaxbird offers high-quality implementations of common neural network modules and algorithms, such as:
- MLPs, DiTs, UNets,
- Flow Matching and Denoising Score Matching (EDM schedules) models with Euler and Heun samplers,
- Consistency Distillation/Matching models.
To use blaxbird, one only needs to define a model, a loss function, and train and validation step functions:
import optax
from flax import nnx


class CNN(nnx.Module):
    ...


def loss_fn(model, images, labels):
    logits = model(images)
    return optax.losses.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()


def train_step(model, rng_key, batch):
    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])


def val_step(model, rng_key, batch):
    return loss_fn(model, batch["image"], batch["label"])
You can then construct (and use) a training function like this:
import optax
from flax import nnx
from jax import random as jr

from blaxbird import train_fn

model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

train = train_fn(
    fns=(train_step, val_step),
    n_steps=100,
    eval_every_n_steps=10,
    n_eval_batches=10,
)
train(jr.key(2), model, optimizer, train_itr, val_itr)
See the entire self-contained example in examples/mnist_classification.
train_fn is a higher-order function with the following signature:
def train_fn(
    *,
    fns: tuple[Callable, Callable],
    shardings: Optional[tuple[jax.NamedSharding, jax.NamedSharding]] = None,
    n_steps: int,
    eval_every_n_steps: int,
    n_eval_batches: int,
    log_to_wandb: bool = False,
    hooks: Iterable[Callable] = (),
) -> Callable:
    ...
We briefly explain the more ambiguous argument types below.
fns is a required argument consisting of a tuple of two functions, a train step function and a validation step function. In the simplest case they look like this:
def train_step(model, rng_key, batch):
    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])


def val_step(model, rng_key, batch):
    return loss_fn(model, batch["image"], batch["label"])
Both train_step and val_step have the same arguments and argument types:
- model specifies an nnx.Module, i.e., a neural network like the CNN shown above.
- rng_key is a jax.random.key in case you need to generate random numbers.
- batch is a sample from a data loader (to be specified later).

The loss function that is called by both computes a scalar loss value. While train_step has to return the loss and the gradients, val_step only needs to return the loss.
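If a step function needs randomness, it can use rng_key directly. Below is a purely illustrative sketch (not part of the example above) that adds Gaussian input noise while keeping the same return convention of loss and gradients:

# Illustrative sketch: use rng_key for simple input-noise augmentation.
# The return convention stays (loss, grads), as required for a train step.
def noisy_train_step(model, rng_key, batch):
    images = batch["image"] + 0.01 * jr.normal(rng_key, batch["image"].shape)
    return nnx.value_and_grad(loss_fn)(model, images, batch["label"])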
To specify how data and model weights are distributed over devices and processes, blaxbird uses JAX's sharding functionality. shardings is again a tuple: one sharding for the model weights and one for the data. An example is shown below, where we only distribute the data over num_devices devices. If you don't want to distribute anything, just set the argument to None or don't specify it at all.
import jax
from jax.experimental import mesh_utils


def get_sharding():
    num_devices = jax.local_device_count()
    mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh((num_devices,)), ("data",)
    )
    model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
    data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
    return model_sharding, data_sharding
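The returned shardings can then be unpacked and passed to train_fn via the shardings argument, as in the full call at the end of this section:

model_sharding, data_sharding = get_sharding()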
hooks is a list of callables which are periodically called during training. Each hook has to have the following signature:

def hook_fn(step, *, model, **kwargs) -> None:
    ...

It takes an integer step specifying the current training iteration, and the model itself.
For instance, if you want to track custom metrics during validation, you could create a hook like this:
import logging


def hook_fn(metrics, val_iter, hook_every_n_steps):
    def fn(step, *, model, **kwargs):
        if step % hook_every_n_steps != 0:
            return
        for batch in val_iter:
            logits = model(batch["image"])
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits=logits, labels=batch["label"]
            ).mean()
            metrics.update(loss=loss, logits=logits, labels=batch["label"])
        if jax.process_index() == 0:
            curr_metrics = ", ".join(
                [f"{k}: {v}" for k, v in metrics.compute().items()]
            )
            logging.info(f"metrics at step {step}: {curr_metrics}")
        metrics.reset()

    return fn


metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average("loss"),
)
hook = hook_fn(metrics, val_iter, hook_every_n_steps)
This creates a hook function hook that, after every hook_every_n_steps steps, iterates over the validation set, computes accuracy and loss, and then logs everything.
To provide multiple hooks to the train function, just collect them in a list.
We provide a convenient hook for checkpointing which can be constructed using get_default_checkpointer. The checkpointer saves both the k checkpoints with the lowest validation loss and the last training checkpoint. The signature of get_default_checkpointer is:
def get_default_checkpointer(
    outfolder: str,
    *,
    save_every_n_steps: int,
    max_to_keep: int = 5,
) -> tuple[Callable, Callable, Callable]:
    ...
Its arguments are:
- outfolder: a folder specifying where to store the checkpoints.
- save_every_n_steps: after how many training steps to store a checkpoint.
- max_to_keep: the number of checkpoints to keep before old checkpoints start getting removed (to not clog the disk).
For instance, you would then construct the checkpointing hook like this:
from blaxbird import get_default_checkpointer

hook_save, *_ = get_default_checkpointer(
    "checkpoints", save_every_n_steps=100
)
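With the metrics hook and the checkpointing hook from above defined, you can, for example, combine them in a list and pass that list to the train function later on:

hooks = [hook, hook_save]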
You can also use get_default_checkpointer to restart a run where you left off. get_default_checkpointer in fact returns three functions: one for saving checkpoints and two for restoring checkpoints:
from blaxbird import get_default_checkpointer

save, restore_best, restore_last = get_default_checkpointer(
    "checkpoints", save_every_n_steps=100
)
You can then do either of:
model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))
model, optimizer = restore_best(model, optimizer)
model, optimizer = restore_last(model, optimizer)
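To resume training afterwards, one possibility (a sketch only, assuming you simply continue with a fresh call to the train function constructed below) is to restore the last checkpoint and train again:

# Sketch: restore the latest weights and optimizer state, then keep training.
model, optimizer = restore_last(model, optimizer)
train(jr.key(3), model, optimizer, train_itr, val_itr)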
After having defined the train and validation step functions, hooks, and shardings, you can train your model like this:
train = train_fn(
    fns=(train_step, val_step),
    n_steps=n_steps,
    eval_every_n_steps=eval_every_n_steps,
    n_eval_batches=n_eval_batches,
    shardings=(model_sharding, data_sharding),
    hooks=hooks,
    log_to_wandb=False,
)
train(jr.key(1), model, optimizer, train_itr, val_itr)
Self-contained examples that also explain what the data loaders should look like can be found in examples.
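For orientation only, a data loader just needs to produce the dictionary batches consumed by the step functions above. The following is a hypothetical sketch (dict_batches, images, and labels are made-up names; the actual setup is shown in the examples):

def dict_batches(images, labels, batch_size):
    # Hypothetical iterable: yields dictionary batches with the
    # "image" and "label" keys used by train_step/val_step above.
    for i in range(0, len(images) - batch_size + 1, batch_size):
        yield {"image": images[i:i + batch_size], "label": labels[i:i + batch_size]}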
To install the package from PyPI, call:
pip install blaxbird
To install the latest GitHub release, just call the following on the command line:
pip install git+https://github.com/dirmeier/blaxbird@<RELEASE>
Simon Dirmeier simd@mailbox.org