Possible regression in jaxlib 0.4.25+ causing training deadlocks on GPU #25453
-
Unfortunately it's impossible to say what's going wrong with only this information. It looks like a deadlock, but we don't know how or why without knowing more. I can think of two things that might help: b) a reproducer that we could run would also help.
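A minimal sketch of one way to gather that kind of information from a hung process (illustrative only, assuming a Unix host where you can send the process a signal):

```python
# Illustrative sketch: let a hung process dump all Python thread stacks on
# demand, e.g. via `kill -USR1 <pid>`, as a lightweight complement to py-spy.
import faulthandler
import signal

faulthandler.register(signal.SIGUSR1, all_threads=True)
```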
-
Thank you for your help @hawkinsp and @mattjj. Shortly after our initial post, JAX 0.4.38 was released, which we believe fixed the issue. (JAX 0.4.36 and 0.4.37 did not deadlock, but produced many warnings; we can provide logs if desired.) That said, we're posting this follow-up so that it may help some other poor soul and/or keep this issue from resurfacing in the future. We finally managed to isolate the deadlock reproducibly in the attached script. Unfortunately, it requires specific data files; the issue is a very rare edge case, and we were unable to simplify the code to something with an obvious source. Here's what we know about triggering the deadlock: it is GPU-agnostic (we reproduced it on more than one GPU model), it only occurs when the tolerance passed to the solver exactly equals the first computed diff, and inserting a jax.debug.callback into the while-loop condition avoids it.
We're sorry we were unable to further isolate this MWE, but we hope the result is reproducible for you. If it is not, please let us know and we will do what we can to ensure reproducibility. If JAX 0.4.38 has fixed this issue for good, please let us know!

import equinox as eqx  # Any version >= 0.11.09
import jax
import jax.numpy as jnp
from jax import vmap
class Frame(eqx.Module):
    rotation: jnp.ndarray
    translation: jnp.ndarray

    def __matmul__(self, v: jnp.ndarray):
        assert v.shape == (3,)
        return self.rotation @ v + self.translation


# Weighted Kabsch alignment: returns the rigid transform (rotation + translation)
# that maps the weighted points A onto B.
def kabsch(w: jax.Array, A: jax.Array, B: jax.Array, eps: float = 1e-8) -> Frame:
    A_mean = (A * w[:, None]).sum(axis=0) / (w.sum() + eps)
    B_mean = (B * w[:, None]).sum(axis=0) / (w.sum() + eps)
    A_centered = A - A_mean[None, :]
    B_centered = B - B_mean[None, :]

    def _project_to_SO3(m):
        u, _, vt = jnp.linalg.svd(m, full_matrices=False)
        d = jnp.sign(jnp.linalg.det(vt.T @ u.T))
        return u @ jnp.diag(jnp.array([1, 1, d])) @ vt

    S = A_centered.T @ (w[:, None] * B_centered)
    R = _project_to_SO3(S).T
    t = B_mean - R @ A_mean
    return Frame(R, t)

def prox_kabsch_solver(
    p,
    displacements: jax.Array,
    F_anchor,
    *,
    max_iters: int,
    tol: float,
    lam: float,
    disable_hang: bool,
):
    # Parallel block-coordinate descent over per-frame Kabsch subproblems,
    # iterating until the change between successive iterates is no larger than
    # `tol` or `max_iters` is reached.
    n = p.shape[0]
    assert displacements.shape == (n, n, 3)
    p = p.at[jnp.diag_indices(n)].set(0)

    def _solve_subproblem(F_targ_i, F, displacements_row, displacements_col, p_row, p_col):
        A_1 = displacements_row
        B_1 = F.translation
        A_2 = jnp.zeros((n, 3))
        B_2 = vmap(lambda F_j, d_ji: F_j @ d_ji)(F, displacements_col)
        X = jnp.eye(3)
        Y = vmap(lambda v: F_targ_i @ v)(X)
        F_i = kabsch(
            jnp.concatenate((p_row, p_col, lam * jnp.ones(3))),
            jnp.concatenate((A_1, A_2, X)),
            jnp.concatenate((B_1, B_2, Y)),
        )
        return F_i

    def _parallel_block_coordinate_descent_step(F):
        F = vmap(
            lambda targ, dr, dc, pr, pc: _solve_subproblem(targ, F, dr, dc, pr, pc),
            in_axes=(0, 0, 1, 0, 1),
        )(F_anchor, displacements, displacements, p, p)
        return Frame(F.rotation, F.translation - F.translation.mean(axis=0, keepdims=True))

    def _compute_diff(F_old, F):
        diff_t = jnp.linalg.norm(F_old.translation - F.translation)
        diff_r = jnp.linalg.norm(F_old.rotation - F.rotation)
        diff = (diff_t + diff_r) / n
        return diff

    def body_fn(carry):
        i, _, F_old, _, diff_hist = carry
        F = _parallel_block_coordinate_descent_step(F_old)
        diff = _compute_diff(F_old, F)
        diff_hist = diff_hist.at[i + 1].set(diff)
        return i + 1, F_old, F, diff, diff_hist

    def cond_fn(carry):
        i, F_old, F, diff_carry, _ = carry
        diff = _compute_diff(F_old, F)
        # Can use a callback to disable the hang
        if disable_hang:
            jax.debug.callback(lambda _: None, diff)
        return (i < max_iters) & (diff > tol)

    # NOTE: We perform one iteration of the while loop manually.
    # This ensures that `diff == tol` is achieved in `cond_fn` explicitly at the first iteration.
    F_old, F = F_anchor, _parallel_block_coordinate_descent_step(F_anchor)
    diff = _compute_diff(F_old, F)
    diff_hist = jnp.zeros(max_iters, dtype=jnp.float32)
    diff_hist = diff_hist.at[0].set(diff)
    it, _, F, diff, diff_hist = jax.lax.while_loop(cond_fn, body_fn, (0, F_old, F, diff, diff_hist))
    return F, it, diff_hist

@eqx.filter_jit
def single_step(w, d, R, t, max_iters=4, tol=1e-1, lam=2.0, disable_hang=False):
    return vmap(
        lambda w, d, f: prox_kabsch_solver(
            w, d, f, max_iters=max_iters, tol=tol, lam=lam, disable_hang=disable_hang
        )
    )(w, d, Frame(R, t))

# Load data from npz file and convert to jax arrays
path_to_data = "/path/to/data"
data = jnp.load(f"{path_to_data}/data.npz")
data = {k: jnp.array(v) for k, v in data.items()}

print("single step: Non-exact tolerance -- Does not hang")
F, it, diff_hist = single_step(**data, tol=1e-1 - 1e-8, disable_hang=False)
print(it)
print(diff_hist)

# NOTE: On an H100 a tolerance of 1e-1 will work, but on other hardware the tolerance may need to be modified
new_tol = 1e-1
# Use the exact diff recorded at the first iteration (first batch element) as the tolerance
new_tol = diff_hist[0, 0].item()

print("single step: Exact tolerance with disable_hang=True -- Does not hang")
F, it, diff_hist = single_step(**data, tol=new_tol, disable_hang=True)
print(it)
print(diff_hist)

print("single step: Exact tolerance -- hangs")
F, it, diff_hist = single_step(**data, tol=new_tol, disable_hang=False)
print(it)
print(diff_hist)

Data file to run mwe.py: data.npz.zip
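For anyone skimming the script above, the control-flow pattern it exercises boils down to roughly the following distilled sketch (hypothetical and for illustration only; we have not verified that this stripped-down form hangs on its own):

```python
# Hypothetical distillation of the pattern in mwe.py (not verified to hang on its own):
# a lax.while_loop whose condition compares a computed diff against a tolerance
# that can be bitwise-equal to it, with an optional host callback in the condition.
import jax
import jax.numpy as jnp


def run(tol: float, disable_hang: bool):
    def cond_fn(carry):
        i, x = carry
        diff = jnp.linalg.norm(x)
        if disable_hang:
            # In mwe.py, inserting a jax.debug.callback here avoids the hang.
            jax.debug.callback(lambda _: None, diff)
        return (i < 4) & (diff > tol)

    def body_fn(carry):
        i, x = carry
        return i + 1, x * 0.5

    return jax.lax.while_loop(cond_fn, body_fn, (0, jnp.ones(8, dtype=jnp.float32)))
```

The full reproducer still matters, since the hang only shows up with the real data and the Kabsch solver; the sketch just names the ingredients: an exact tolerance match in `cond_fn` and the presence or absence of a `jax.debug.callback`.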
-
We have a model training script that began to experience deadlocks during GPU computation after upgrading from jax 0.4.13 to 0.4.25+. In particular, the issue appears with jaxlib 0.4.25, disappears with 0.4.26, and is present again from jaxlib 0.4.27 onwards. We'd appreciate any insights into how we can better understand what's going on. We're attempting to create an MRE in the meantime, but our training code is quite complicated and we're still bisecting the issue. The deadlocks occur in both single- and multi-GPU training runs, and when testing the single-GPU case, removing all sharding-related code does not resolve the issue.
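(For clarity, each run can confirm which builds and devices it is actually using with something like the following; this is a hypothetical snippet, not our actual logging:)

```python
# Hypothetical snippet: record the exact jax/jaxlib builds and backend devices
# for a run, to keep the version bisection honest.
import jax

print("jax:", jax.__version__)
jax.print_environment_info()  # prints jax/jaxlib versions, devices, and platform details
```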
Regression description:
At some point during training we call into a jitted `single_step` function (computing loss and gradients) and this function never exits (nor does it crash), as evidenced by a py-spy trace. This happens non-deterministically, minutes to hours into training runs. We're using Weights & Biases for logging, and from the system resource logs we can see that at the time of the deadlock our GPU power usage decreases to a nontrivial level and stays there with extremely low variation (image below); looking at the Python process, we can see that it's waiting for control to return. To reiterate, this appears to be a regression: our training runs just fine on jaxlib 0.4.24 and below. Here's what the GPU power usage looks like, with the hang occurring at around ~26k on the x-axis:

For reference, our H100s idle at ~100W, so something is happening.
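In case it helps anyone instrumenting a similar hang, below is a minimal sketch of the kind of watchdog that could be wrapped around a training step so a stuck step dumps Python stacks instead of hanging silently (hypothetical; not the code we actually run):

```python
# Hypothetical watchdog: arm a timer around each training step; if the step
# (including the device work) does not finish within `timeout` seconds, dump
# all Python thread stacks to stderr.
import faulthandler

import jax


def guarded_step(step_fn, *args, timeout=600.0):
    faulthandler.dump_traceback_later(timeout, exit=False)
    try:
        out = step_fn(*args)
        jax.block_until_ready(out)  # wait for device execution, not just dispatch
        return out
    finally:
        faulthandler.cancel_dump_traceback_later()
```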
Attempt at diagnostics:
When I exec into the training pod after the hang has occurred, I see that the training Python process is alive (PID 1), but it's waiting on a futex. I obtained a backtrace from `gdb`; I can provide the rest of the `bt` if anyone thinks it would be helpful.
Obligatory environment dump:
Any insights into what's happening or suggestions for debugging this issue would be massively appreciated!