Fairness through Aleatoric Uncertainty (JAX Bayes by Backprop) #17157
aniquetahir asked this question in Show and tell
I would like to introduce our paper, Fairness through Aleatoric Uncertainty, which was just accepted at CIKM:
https://arxiv.org/abs/2304.03646
JAX played a vital role in making this possible: its speed let us prototype faster, and it made handling the BNN parameters very intuitive. Our JAX/Haiku implementation of Bayes by Backprop can be found here:
https://github.com/aniquetahir/GAIA/blob/master/utils/jax/models/bnn.py
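For readers new to Bayes by Backprop, here is a minimal sketch of the core idea in plain JAX. It is not Haiku and not the code in the linked repository. Assumptions: a single linear layer with a mean-field Gaussian posterior parameterized by `mu` and `rho` (with `sigma = softplus(rho)`); a standard normal prior so the KL term has a closed form (the original Bayes by Backprop paper uses a scale-mixture prior with a Monte Carlo KL estimate); a fixed Gaussian observation noise; and plain SGD. All function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_in, n_out):
    """Variational parameters (mu, rho) for one mean-field Gaussian dense layer."""
    k_w, k_b = jax.random.split(key)
    return {
        "w_mu": 0.1 * jax.random.normal(k_w, (n_in, n_out)),
        "w_rho": jnp.full((n_in, n_out), -3.0),  # sigma = softplus(-3), about 0.049
        "b_mu": 0.1 * jax.random.normal(k_b, (n_out,)),
        "b_rho": jnp.full((n_out,), -3.0),
    }


def sample_weights(params, key):
    """Reparameterization trick: w = mu + softplus(rho) * eps, eps ~ N(0, I)."""
    k_w, k_b = jax.random.split(key)
    w_sigma = jax.nn.softplus(params["w_rho"])
    b_sigma = jax.nn.softplus(params["b_rho"])
    w = params["w_mu"] + w_sigma * jax.random.normal(k_w, w_sigma.shape)
    b = params["b_mu"] + b_sigma * jax.random.normal(k_b, b_sigma.shape)
    return w, b


def kl_to_std_normal(mu, sigma):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over all entries."""
    return jnp.sum(0.5 * (sigma**2 + mu**2 - 1.0) - jnp.log(sigma))


def neg_elbo(params, key, x, y, noise_std=0.1):
    """One-sample Monte Carlo estimate of KL(q || prior) - E_q[log p(y | x, w)]."""
    w, b = sample_weights(params, key)
    preds = x @ w + b
    # Gaussian log-likelihood with a fixed observation noise (an assumption here).
    log_lik = jnp.sum(
        -0.5 * ((y - preds) / noise_std) ** 2
        - jnp.log(noise_std)
        - 0.5 * jnp.log(2.0 * jnp.pi)
    )
    kl = kl_to_std_normal(params["w_mu"], jax.nn.softplus(params["w_rho"]))
    kl += kl_to_std_normal(params["b_mu"], jax.nn.softplus(params["b_rho"]))
    return kl - log_lik


@jax.jit
def sgd_step(params, key, x, y, lr=1e-4):
    """Plain SGD on the variational parameters (both mu and rho)."""
    loss, grads = jax.value_and_grad(neg_elbo)(params, key, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


def fit_linear_bnn(x, y, steps=500, seed=0):
    """Hypothetical training loop: a fresh weight sample is drawn every step."""
    key, init_key = jax.random.split(jax.random.PRNGKey(seed))
    params = init_params(init_key, x.shape[1], y.shape[1])
    for _ in range(steps):
        key, step_key = jax.random.split(key)
        params, _ = sgd_step(params, step_key, x, y)
    return params
```

Because the weights are sampled through the reparameterization trick, gradients reach both `mu` and `rho`. At prediction time one would typically average `x @ w + b` over several weight samples.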
To the best of my knowledge, this is the fastest implementation of this approach. It can serve as a template for a more generic version (some fairness-specific components are built in). I wrapped the JIT compilation in a class method so that it's easy to swap out individual components via subclasses.
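The pattern of wrapping JIT compilation in a class method can be illustrated with a small, hypothetical sketch; the class and method names below are assumptions, not the repository's actual structure. The base class builds `self.step = jax.jit(self._step)` once per instance, and `_step` calls `self.loss_fn`, so a subclass that overrides only `loss_fn` gets its own compiled training step for free.

```python
import jax
import jax.numpy as jnp


class Trainer:
    """Hypothetical base trainer: subclasses override `predict` or `loss_fn`."""

    def __init__(self, lr=0.1):
        self.lr = lr
        # jit is applied once per instance to a bound method. `self` is not a
        # traced argument; whatever `loss_fn` the instance has is what gets traced.
        self.step = jax.jit(self._step)

    def predict(self, params, x):
        return x @ params["w"] + params["b"]

    def loss_fn(self, params, x, y):
        return jnp.mean((self.predict(params, x) - y) ** 2)

    def _step(self, params, x, y):
        loss, grads = jax.value_and_grad(self.loss_fn)(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - self.lr * g, params, grads)
        return params, loss


class AbsoluteErrorTrainer(Trainer):
    """Swaps only the loss; the jitted step from the base class is reused."""

    def loss_fn(self, params, x, y):
        return jnp.mean(jnp.abs(self.predict(params, x) - y))
```

One trade-off of closing over `self`: attributes such as `self.lr` are baked in as constants at the first trace, so changing them afterwards does not trigger recompilation.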
Hopefully the JAX folks don't forget us non-TPU academic users (at least not until they offer me a job at Google 😅).