Global way to automatically convert all floating-point computations to bfloat16 #30106
Replies: 1 comment
Hi, for this you can use MPX (https://github.com/Data-Science-in-Mechanical-Engineering/mixed_precision_for_JAX), an update and extension of JMP that we released this month. MPX is a toolbox for mixed-precision training and allows you to port your training pipeline to mixed precision with just a few changes. In case you don't want to do mixed-precision training but just want to cast a single function, the toolbox contains `mpx.cast_function`. Example:

```python
import jax
import jax.numpy as jnp
import mpx

def my_func(x):
    x = x + 1
    print(x.dtype)
    return x

my_func(jnp.zeros((42,)))  # prints float32
mpx.cast_function(my_func, dtype=jnp.bfloat16)(jnp.zeros((42,)))  # prints bfloat16
```

But you must make sure that the function does not create a constant in a higher dtype, e.g.:

```python
def my_func(x):
    my_const = jnp.zeros((42,))  # change this to jnp.zeros((42,), dtype=jnp.bfloat16)
    x = x + my_const
    print(x.dtype)
    return x
```

Best, Alex
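If the function closes over many constants, or over a whole parameter pytree, casting each one by hand gets tedious. Below is a minimal sketch of doing that step in plain JAX (no MPX required), using `jax.tree_util.tree_map`; the helper name `cast_float_leaves` is made up for illustration:

```python
import jax
import jax.numpy as jnp

def cast_float_leaves(tree, dtype=jnp.bfloat16):
    """Cast every floating-point leaf of a pytree to `dtype`;
    integer and bool leaves pass through unchanged."""
    def cast(leaf):
        leaf = jnp.asarray(leaf)
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf
    return jax.tree_util.tree_map(cast, tree)

params = {"w": jnp.ones((4, 4)), "step": jnp.array(0)}
params = cast_float_leaves(params)
print(params["w"].dtype)     # bfloat16
print(params["step"].dtype)  # int32 (unchanged)
```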
Given an existing program, is there any global way to automatically convert all of its floating-point computations to bfloat16? (Without having to manually litter the code with `dtype=jnp.bfloat16` arguments inside every constructor.) If not, would it be worth considering adding such functionality as a feature, perhaps via a context manager or a configuration argument to `jit`?
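For illustration, a rough sketch of the kind of wrapper being asked for, in plain JAX; the decorator name `as_bfloat16` is hypothetical, not an existing API. Note that, as the reply above explains, casting inputs alone cannot reach constants created inside the function:

```python
import functools
import jax
import jax.numpy as jnp

def as_bfloat16(fun):
    """Hypothetical decorator: cast every floating-point
    argument to bfloat16 before calling `fun`."""
    @functools.wraps(fun)
    def wrapped(*args, **kwargs):
        def cast(x):
            x = jnp.asarray(x)
            if jnp.issubdtype(x.dtype, jnp.floating):
                return x.astype(jnp.bfloat16)
            return x
        args, kwargs = jax.tree_util.tree_map(cast, (args, kwargs))
        return fun(*args, **kwargs)
    return wrapped

@as_bfloat16
def f(x):
    return 2 * x + 1

print(f(jnp.ones(3)).dtype)  # bfloat16
```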