blaxbird [blækbɜːd]

A high-level API to build and train NNX models

About

Blaxbird [blækbɜːd] is a high-level API to easily build NNX models and train them on CPU or GPU.

Using blaxbird, one can

  • concisely define models and loss functions without the usual JAX/Flax verbosity,
  • easily define checkpointers that save the best and most current network weights,
  • distribute data and model weights over multiple processes or GPUs,
  • define hooks that are periodically called during training.

In addition, blaxbird offers high-quality implementations of common neural network modules and algorithms, such as:

  • MLPs, DiTs, UNets,
  • Flow Matching and Denoising Score Matching (EDM schedules) models with Euler and Heun samplers,
  • Consistency Distillation/Matching models.

Example

To use blaxbird, one only needs to define a model, a loss function, and train and validation step functions:

import optax
from flax import nnx

class CNN(nnx.Module):
  ...

def loss_fn(model, images, labels):
  logits = model(images)
  return optax.losses.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=labels
  ).mean()

def train_step(model, rng_key, batch):
  return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])

def val_step(model, rng_key, batch):
  return loss_fn(model, batch["image"], batch["label"])
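For reference, a minimal CNN might look like the sketch below. The architecture and layer sizes are illustrative assumptions only; any nnx.Module that maps images to logits works:

class CNN(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    # hypothetical layer sizes for a 10-class image classifier
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.linear = nnx.Linear(64, 10, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.conv1(x))
    x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nnx.relu(self.conv2(x))
    x = x.mean(axis=(1, 2))  # global average pooling
    return self.linear(x)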

You can then construct (and use) a training function like this:

import optax
from flax import nnx
from jax import random as jr

from blaxbird import train_fn

model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

train = train_fn(
  fns=(train_step, val_step),
  n_steps=100,
  eval_every_n_steps=10,
  n_eval_batches=10
)
train(jr.key(2), model, optimizer, train_itr, val_itr)

See the entire self-contained example in examples/mnist_classification.

Usage

train_fn is a higher-order function with the following signature:

def train_fn(
  *,
  fns: tuple[Callable, Callable],
  shardings: Optional[tuple[jax.NamedSharding, jax.NamedSharding]] = None,
  n_steps: int,
  eval_every_n_steps: int,
  n_eval_batches: int,
  log_to_wandb: bool = False,
  hooks: Iterable[Callable] = (),
) -> Callable:
  ...

We briefly explain the more ambiguous argument types below.

fns

fns is a required argument consisting of a tuple of two functions: a step function and a validation function. In the simplest case, they look like this:

def train_step(model, rng_key, batch):
  return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])

def val_step(model, rng_key, batch):
  return loss_fn(model, batch["image"], batch["label"])

Both train_step and val_step have the same arguments and argument types:

  • model specifies an nnx.Module, i.e., a neural network like the CNN shown above.
  • rng_key is a jax.random.key in case you need to generate random numbers.
  • batch is a sample from a data loader (to be specified later).

The loss function that is called by both computes a scalar loss value. While train_step has to return the loss and the gradients, val_step only needs to return the loss.
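If a step function needs randomness, draw it from rng_key. Below is a minimal sketch, assuming (purely for illustration) that we perturb the inputs with Gaussian noise before computing the loss:

from jax import random as jr

def train_step(model, rng_key, batch):
  # illustrative only: augment the images with a bit of noise
  images = batch["image"] + 0.01 * jr.normal(rng_key, batch["image"].shape)
  return nnx.value_and_grad(loss_fn)(model, images, batch["label"])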

shardings

To specify how data and model weights are distributed over devices and processes, blaxbird uses JAX's sharding functionality.

shardings is again a tuple: one element for the model sharding, the other for the data sharding. An example is shown below, where we only distribute the data over num_devices devices. If you don't want to distribute anything, set the argument to None or simply omit it.

import jax
from jax.experimental import mesh_utils

def get_sharding():
  num_devices = jax.local_device_count()
  mesh = jax.sharding.Mesh(
    mesh_utils.create_device_mesh((num_devices,)), ("data",)
  )
  model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
  data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
  return model_sharding, data_sharding
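The two shardings are then passed to train_fn as a tuple, model sharding first:

model_sharding, data_sharding = get_sharding()
train = train_fn(
  fns=(train_step, val_step),
  shardings=(model_sharding, data_sharding),
  n_steps=100,
  eval_every_n_steps=10,
  n_eval_batches=10,
)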

hooks

hooks is a list of callables which are periodically called during training. Each hook has to have the following signature:

def hook_fn(step, *, model, **kwargs) -> None:
  ...

It takes an integer step specifying the current training iteration and the model itself. For instance, if you want to track custom metrics during validation, you could create a hook like this:

import logging

import jax
import optax
from flax import nnx

def hook_fn(metrics, val_iter, hook_every_n_steps):
  def fn(step, *, model, **kwargs):
    if step % hook_every_n_steps != 0:
      return
    for batch in val_iter:
      logits = model(batch["image"])
      loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch["label"]
      ).mean()
      metrics.update(loss=loss, logits=logits, labels=batch["label"])
    if jax.process_index() == 0:
      curr_metrics = ", ".join(
        [f"{k}: {v}" for k, v in metrics.compute().items()]
      )
      logging.info(f"metrics at step {step}: {curr_metrics}")
    metrics.reset()
  return fn

metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average("loss"),
)
hook = hook_fn(metrics, val_iter, hook_every_n_steps)

This creates a hook function hook that, every hook_every_n_steps steps, iterates over the validation set, computes accuracy and loss, and then logs everything.

To provide multiple hooks to the train function, just pass them as a list, as shown in the sketch below.
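For instance, combining the metrics hook from above with the checkpointing hook hook_save constructed in the next section:

hooks = [hook, hook_save]
train = train_fn(
  fns=(train_step, val_step),
  n_steps=100,
  eval_every_n_steps=10,
  n_eval_batches=10,
  hooks=hooks,
)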

A checkpointing hook

We provide a convenient hook for checkpointing which can be constructed using get_default_checkpointer. The checkpointer saves both the k best checkpoints (those with the lowest validation loss) and the most recent training checkpoint.

The signature of get_default_checkpointer is:

def get_default_checkpointer(
  outfolder: str,
  *,
  save_every_n_steps: int,
  max_to_keep: int = 5,
) -> tuple[Callable, Callable, Callable]:
  ...

Its arguments are:

  • outfolder: a folder specifying where to store the checkpoints.
  • save_every_n_steps: after how many training steps to store a checkpoint.
  • max_to_keep: the number of checkpoints to keep before old ones start being removed (so that storage does not fill up).

For instance, you would then construct the checkpointing hook like this:

from blaxbird import get_default_checkpointer

hook_save, *_ = get_default_checkpointer(
  "checkpoints", save_every_n_steps=100
)

Restoring a run

You can also use get_default_checkpointer to resume a run where you left off. get_default_checkpointer in fact returns three functions: one for saving checkpoints and two for restoring them:

from blaxbird import get_default_checkpointer

save, restore_best, restore_last = get_default_checkpointer(
  "checkpoints", save_every_n_steps=100
)

You can then do either of:

model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

model, optimizer = restore_best(model, optimizer)
model, optimizer = restore_last(model, optimizer)

Doing training

After having defined the step functions, hooks, and shardings, you can train your model like this:

train = train_fn(
  fns=(train_step, val_step),
  n_steps=n_steps,
  eval_every_n_steps=eval_every_n_steps,
  n_eval_batches=n_eval_batches,
  shardings=(model_sharding, data_sharding),
  hooks=hooks,
  log_to_wandb=False,
)
train(jr.key(1), model, optimizer, train_itr, val_itr)

Self-contained examples that also show what the data loaders should look like can be found in examples.
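As a rough sketch (inferred from the step functions above, which index batches by "image" and "label"; the loader itself is hypothetical and not part of blaxbird), an iterator that yields dictionaries of arrays suffices:

from jax import random as jr

def data_iterator(rng_key, images, labels, batch_size):
  # hypothetical loader: yields dictionary batches indefinitely
  n = images.shape[0]
  while True:
    rng_key, subkey = jr.split(rng_key)
    idxs = jr.choice(subkey, n, shape=(batch_size,), replace=False)
    yield {"image": images[idxs], "label": labels[idxs]}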

Installation

To install the package from PyPI, call:

pip install blaxbird

To install the latest GitHub release, just call the following on the command line:

pip install git+https://github.com/dirmeier/blaxbird@<RELEASE>

Author

Simon Dirmeier simd@mailbox.org
