You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Prevent accidental upcasting in jax.nn.initializers.
Currently distribution parameters such as stddev and scale are expected to be
weakly typed scalars. When they're passed as float32 they can cause an upcast
of the initialized arrays even when the dtype is specified as e.g. bfloat16.
Some users were surprised by this.
PiperOrigin-RevId: 611858446
0 commit comments