Skip to content

Commit 04f6bfa

Browse files
levskayajax authors
authored andcommitted
Prevent accidental upcasting in jax.nn.initializers.
Currently distribution parameters such as stddev and scale are expected to be weakly typed scalars. When they're passed as float32 they can cause an upcast of the initialized arrays even when the dtype is specified as e.g. bfloat16. Some users were surprised by this. PiperOrigin-RevId: 611858446
1 parent 5c9c57f commit 04f6bfa

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

jax/_src/nn/initializers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def init(key: KeyArray,
130130
shape: core.Shape,
131131
dtype: DTypeLikeInexact = dtype) -> Array:
132132
dtype = dtypes.canonicalize_dtype(dtype)
133-
return random.uniform(key, shape, dtype) * scale
133+
return random.uniform(key, shape, dtype) * jnp.array(scale, dtype)
134134
return init
135135

136136
@export
@@ -156,7 +156,7 @@ def init(key: KeyArray,
156156
shape: core.Shape,
157157
dtype: DTypeLikeInexact = dtype) -> Array:
158158
dtype = dtypes.canonicalize_dtype(dtype)
159-
return random.normal(key, shape, dtype) * stddev
159+
return random.normal(key, shape, dtype) * jnp.array(stddev, dtype)
160160
return init
161161

162162
@export
@@ -193,7 +193,8 @@ def init(key: KeyArray,
193193
shape: core.Shape,
194194
dtype: DTypeLikeInexact = dtype) -> Array:
195195
dtype = dtypes.canonicalize_dtype(dtype)
196-
return random.truncated_normal(key, lower, upper, shape, dtype) * stddev
196+
return random.truncated_normal(
197+
key, lower, upper, shape, dtype) * jnp.array(stddev, dtype)
197198
return init
198199

199200
@export
@@ -613,7 +614,7 @@ def init(key: KeyArray,
613614
if n_rows < n_cols: Q = Q.T
614615
Q = jnp.reshape(Q, tuple(np.delete(shape, column_axis)) + (shape[column_axis],))
615616
Q = jnp.moveaxis(Q, -1, column_axis)
616-
return scale * Q
617+
return jnp.array(scale, dtype) * Q
617618
return init
618619

619620
@export

tests/nn_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,17 @@ def testVarianceScalingError(self):
403403
):
404404
initializer(rng, shape)
405405

406+
def testAccidentalUpcasting(self):
407+
rng = random.PRNGKey(0)
408+
shape = (4, 4)
409+
scalar_param = jnp.array(1.0, dtype=jnp.float32)
410+
for init_fn in (nn.initializers.uniform(scalar_param, jnp.bfloat16),
411+
nn.initializers.normal(scalar_param, jnp.bfloat16),
412+
nn.initializers.truncated_normal(scalar_param, jnp.bfloat16),
413+
):
414+
sub_rng, rng = random.split(rng)
415+
val = init_fn(sub_rng, shape)
416+
self.assertEqual(val.dtype, jnp.bfloat16)
406417

407418
if __name__ == "__main__":
408419
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)