Logging library for JAX that is compatible with transformations and primitives such as vmap and scan.

huterguier/lox

Accelerated logging in JAX

Lox is a lightweight and flexible logging library for JAX that provides a simple interface for logging data during function execution. Logging is implemented with its own primitive, which lets it work seamlessly with JAX's built-in function transformations such as jit and vmap, as well as control-flow primitives such as scan. All you need to do is instrument your code with lox.log statements, and Lox does the rest. Working on JAX's intermediate representation of your function, Lox can either dynamically insert callbacks that log your data as it is produced, or collect the logs that would have been generated during execution and return them as part of your function's output. You could implement this functionality yourself, but Lox provides a simple and efficient way to do so without carrying boilerplate code through your functions.
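The intermediate representation referred to above is, presumably, the jaxpr that JAX produces when it traces a function; how Lox rewrites it internally is not described here, so treat this as background rather than a description of Lox's implementation. The plain-JAX sketch below (no lox involved) traces the same scan computation used in the examples, with the lox.log calls removed, and lists the primitives in the resulting jaxpr. The whole loop shows up as a single scan equation whose body is itself a nested jaxpr.

```python
import jax
import jax.numpy as jnp

def f(xs):
    # Same computation as the README example, minus the lox.log calls.
    def step(carry, x):
        carry = carry + x
        return carry, x
    y, _ = jax.lax.scan(step, 0, xs)
    return y

# make_jaxpr traces f and returns a ClosedJaxpr, the intermediate
# representation that JAX transformations inspect and rewrite.
closed = jax.make_jaxpr(f)(jnp.arange(3))
print(closed)

# Each top-level equation records one primitive application.
prims = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prims)
```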

>>> import jax
>>> import jax.numpy as jnp
>>> import lox

>>> def f(xs):
...     lox.log({"xs": xs})
...     def step(carry, x):
...         carry += x
...         lox.log({"carry": carry})
...         return carry, x
...     y, _ = jax.lax.scan(step, 0, xs)
...     return y

>>> xs = jnp.arange(3)

The first transformation, lox.tap, lets you "tap into" function execution by attaching a callback that receives logs as they're generated. It streams logs in real time, making it great for debugging or live monitoring.

>>> def callback(logs):
...     print("Logging:", logs)
>>> y = lox.tap(f, callback=callback)(xs)
Logging: {'xs': [0, 1, 2]}
Logging: {'carry': 0}
Logging: {'carry': 1}
Logging: {'carry': 3}
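For comparison, here is roughly what lox.tap saves you from writing by hand. This sketch uses only core JAX: jax.debug.callback sends values to a host-side Python function from inside jit and scan. It is not Lox's implementation, and the `received` list is a hypothetical addition used only to record what the callback saw. Every log site has to name the callback explicitly, and ordered=True is passed so the callbacks are guaranteed to fire in program order.

```python
import jax
import jax.numpy as jnp

received = []  # hypothetical host-side record of every logged payload

def callback(logs):
    # Runs on the host each time a log site executes.
    received.append(logs)
    print("Logging:", logs)

def f(xs):
    # Without lox, each log site needs its own explicit host callback.
    jax.debug.callback(callback, {"xs": xs}, ordered=True)

    def step(carry, x):
        carry = carry + x
        jax.debug.callback(callback, {"carry": carry}, ordered=True)
        return carry, x

    y, _ = jax.lax.scan(step, 0, xs)
    return y

y = jax.jit(f)(jnp.arange(3))
# jit dispatches asynchronously; wait until all pending callbacks have run.
jax.effects_barrier()
```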

The second transformation, lox.spool, "spools up" all logs during execution and returns them alongside the function's output. This is especially useful when frequent callbacks would be too expensive. For instance, instead of logging on every iteration, you can collect all logs for a training step and emit them in a single call.

>>> y, logs = lox.spool(f)(xs)
>>> print("Collected Logs:", logs)
Collected Logs: {'xs': [0, 1, 2], 'carry': [0, 1, 3]}
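The manual counterpart of lox.spool threads the logs out through the function's return values. In plain JAX, scan stacks whatever the step function returns as its second output along a new leading axis, so per-iteration values can be collected without any callbacks. Again, this is a sketch for comparison, not how Lox works internally.

```python
import jax
import jax.numpy as jnp

def f(xs):
    logs = {"xs": xs}

    def step(carry, x):
        carry = carry + x
        # Emit the value to log as this iteration's output; scan stacks
        # these along a new leading axis.
        return carry, {"carry": carry}

    y, step_logs = jax.lax.scan(step, 0, xs)
    logs.update(step_logs)
    # The return signature has to change so the logs can leave the function.
    return y, logs

y, logs = jax.jit(f)(jnp.arange(3))
print("Collected Logs:", logs)
```

Every function between the log site and the caller now has to return and forward its logs, which is exactly the kind of boilerplate Lox is meant to remove.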

Installation

Lox can be installed via pip directly from the GitHub repository.

pip install git+https://github.com/huterguier/lox