Kernel crashing for large datasets initialization #177

@JSignoretGenest

Hello,

We've been successfully using keypoint-moseq on several datasets, but we are encountering a consistent issue with larger datasets, where the kernel crashes during model initialization.

Steps Taken to Isolate the Issue:

  1. Tested with smaller datasets: the model initializes and runs fine.
  2. Simulated larger dataset: we artificially inflated previously working datasets to match the size of the problematic dataset, and the kernel crashed as well, confirming that the issue is size-related and not specific to the data itself.
  3. Isolated problematic function: we traced the issue to init_states in jax_moseq/models/arhmm/initialize.py, particularly in the function resample_discrete_stateseqs.
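
For anyone who wants to reproduce step 2, here is a minimal sketch of how a working dataset can be inflated. It assumes the keypoints are held as a dict mapping recording names to arrays of shape (frames, keypoints, dims); the function name `inflate_coordinates` and the `_copy` suffix are hypothetical and not part of keypoint-moseq.

```python
import numpy as np

def inflate_coordinates(coordinates, target_frames):
    """Duplicate recordings under new names until the total frame count
    reaches at least target_frames. Hypothetical helper for testing only."""
    total = sum(arr.shape[0] for arr in coordinates.values())
    if total == 0:
        raise ValueError("coordinates must contain at least one frame")

    inflated = dict(coordinates)
    names = list(coordinates)
    copy_idx = 0
    while total < target_frames:
        # Cycle through the original recordings in order
        name = names[copy_idx % len(names)]
        new_name = f"{name}_copy{copy_idx // len(names)}"
        inflated[new_name] = coordinates[name].copy()
        total += coordinates[name].shape[0]
        copy_idx += 1
    return inflated

# Toy example: two short recordings inflated to at least 400 frames
small = {"rec_a": np.zeros((100, 8, 2)), "rec_b": np.zeros((50, 8, 2))}
big = inflate_coordinates(small, target_frames=400)
total_frames = sum(arr.shape[0] for arr in big.values())
```

Because the copies are exact duplicates, any crash that appears only on the inflated dataset points to size rather than content.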

Behavior Observed:

  1. Without @jax.jit on init_states: the kernel dies after log_likelihoods is computed correctly but before it is returned to init_states. Print statements added for debugging confirm that the computation completes and that the crash occurs before the value reaches init_states.
  2. With @jax.jit: the crash is deferred to the return of z. Again, z is computed with the correct array dimensions (as confirmed by logs), but the kernel crashes when the value is returned to init_states. If the value is not returned, there is no crash and init_states continues to execute normally, which suggests that the crash happens during the return itself.
  3. JAX Debugging: for working datasets, jax.debug.print successfully outputs e.g. array shapes, but for datasets that cause the crash, jax.debug.print produces no output, even though execution does proceed to the following steps.
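
One caveat on locating the crash: JAX dispatches work asynchronously, so a failure can surface later than the operation that caused it, which may be why it appears to happen "on return". Below is a minimal sketch of a checkpoint that forces each intermediate result to be fully computed before logging it. The `checked` wrapper and `toy_log_likelihoods` are hypothetical stand-ins, not the jax_moseq code.

```python
import jax
import jax.numpy as jnp

def checked(name, fn, *args):
    """Run fn, wait until its outputs are actually computed, then log shapes."""
    out = fn(*args)
    # Force completion here so a failure is attributed to this step
    out = jax.block_until_ready(out)
    shapes = jax.tree_util.tree_map(lambda a: a.shape, out)
    print(f"[ok] {name}: {shapes}", flush=True)
    return out

def toy_log_likelihoods(x):
    # Hypothetical stand-in for the real log-likelihood computation
    return jax.nn.log_softmax(x, axis=-1)

x = jnp.ones((4, 1000, 10))
ll = checked("log_likelihoods", toy_log_likelihoods, x)
```

If the "[ok]" line for a step never prints, the failure is in that step's computation rather than in handing the value back to the caller.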

Attempts to Work Around the Issue:

  1. Manual batching and vmap/lax batching: we tried processing z in smaller batches, either manually or with vmap/lax, to reduce the size of the returned array. However, the kernel still crashes, now when init_states returns z.
  2. System Monitoring: we did not observe any VRAM or RAM saturation, so memory exhaustion does not seem to be the cause.
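
For reference, the manual batching idea is roughly of the following form. This is a simplified sketch and not the jax_moseq code: `toy_sample_states` is a hypothetical per-batch function (a plain argmax instead of actual state resampling), and `process_in_batches` is a hypothetical helper.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def toy_sample_states(log_likelihoods):
    # Hypothetical stand-in: pick the most likely state per frame
    return jnp.argmax(log_likelihoods, axis=-1)

def process_in_batches(fn, x, batch_size):
    """Apply fn to slices along the first axis and gather results on the host."""
    outputs = []
    for start in range(0, x.shape[0], batch_size):
        chunk = fn(x[start:start + batch_size])
        # Copy each result to host memory so only one batch lives on the device
        outputs.append(np.asarray(chunk))
    return np.concatenate(outputs, axis=0)

# Toy input: 10 sequences, 500 frames, 8 states, with state 3 always most likely
ll = jnp.zeros((10, 500, 8)).at[..., 3].set(1.0)
z = process_in_batches(toy_sample_states, ll, batch_size=4)
```

Each batch is copied to the host as a NumPy array before the next one runs, so the full result is never held on the device at once.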

We would be grateful for any workaround for this step, even if it means a slower initialization!

Thank you for your help!
