
get_durations(z, mask) raises IndexError when running a kappa scan but not when training the AR-HMM or full model #174

@amorsi1


Hi! I'm running into a problem during my kappa scan when fit_model tries to plot the median syllable duration as part of its progress plots.

 ---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[15], line 17
     15 # stage 1: fit the model with AR only
     16 model = kpms.update_hypparams(model, kappa=kappa)
---> 17 model = kpms.fit_model(
     18     model,
     19     data,
     20     metadata,
     21     project_dir,
     22     model_name,
     23     ar_only=True,
     24     num_iters=num_ar_iters,
     25     save_every_n_iters=25,
     26     parallel_message_passing=False
     27 )[0];
     29 # stage 2: fit the full model
     30 model = kpms.update_hypparams(model, kappa=kappa/decrease_kappa_factor)

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/fitting.py:272, in fit_model(model, data, metadata, project_dir, model_name, num_iters, start_iter, verbose, ar_only, parallel_message_passing, jitter, generate_progress_plots, save_every_n_iters, location_aware, **kwargs)
    270                 save_hdf5(checkpoint_path, model, f"model_snapshots/{iteration}")
    271                 if generate_progress_plots:
--> 272                     plot_progress(
    273                         model,
    274                         data,
    275                         checkpoint_path,
    276                         iteration,
    277                         project_dir,
    278                         model_name,
    279                         savefig=True,
    280                     )
    282 return model, model_name

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/viz.py:620, in plot_progress(model, data, checkpoint_path, iteration, project_dir, model_name, path, savefig, fig_size, window_size, min_frequency, min_histogram_length)
    618         z = np.array(f[f"model_snapshots/{i}/states/z"])
    619         sample_state_history.append(z[batch_ix, start : start + window_size])
--> 620         median_durations.append(np.median(get_durations(z, mask)))
    622 axs[2].scatter(saved_iterations, median_durations)
    623 axs[2].set_ylim([-1, np.max(median_durations) * 1.1])

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/utils.py:82, in get_durations(stateseqs, mask)
     80 print(mask)
     81 #AM edits
---> 82 stateseq_flat = concatenate_stateseqs(stateseqs, mask=mask).astype(int)
     83 stateseq_padded = np.hstack([[-1], stateseq_flat, [-1]])
     84 changepoints = np.diff(stateseq_padded).nonzero()[0]

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/utils.py:40, in concatenate_stateseqs(stateseqs, mask)
     38     stateseq_flat = np.hstack(stateseqs)
     39 elif mask is not None:
---> 40     stateseq_flat = stateseqs[mask[:, -stateseqs.shape[1] :] > 0]
     41 else:
     42     stateseq_flat = stateseqs.flatten()

IndexError: boolean index did not match indexed array along dimension 0; dimension is 102 but corresponding boolean dimension is 116
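The failure can be reproduced with NumPy alone. The sketch below uses dummy arrays with the shapes reported in this issue (a mask of shape (116, 10030) and state sequences of shape (116, 10027) on one call and (102, 10027) on another) and copies the indexing expression from line 40 of `concatenate_stateseqs`. The values are placeholders, not real model output; only the shapes matter.

```python
import numpy as np

# Placeholder arrays with the shapes reported in this issue; values are dummies.
mask = np.ones((116, 10030))  # (num_sequences, num_timepoints)
snapshots = {
    "first call": np.zeros((116, 10027), dtype=int),
    "second call": np.zeros((102, 10027), dtype=int),
}

results = {}
for name, stateseqs in snapshots.items():
    # Same expression as line 40 of concatenate_stateseqs: the slice trims the
    # mask's time axis (10030 -> 10027) but leaves its first axis at 116.
    trimmed = mask[:, -stateseqs.shape[1]:]
    try:
        flat = stateseqs[trimmed > 0]
        results[name] = f"ok, {flat.size} timepoints"
    except IndexError as err:
        # Boolean indexing requires the mask to match along every axis it covers.
        results[name] = f"IndexError: {err}"

for name, outcome in results.items():
    print(name, "->", outcome)
```

The first call succeeds because only the time axis differs, and the slice absorbs that; the second call fails exactly as in the traceback because nothing reconciles 102 rows against 116.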

This is on keypoint_moseq 0.4.10 and jax_moseq 0.2.2. Since debugging in notebooks can be cumbersome, I added print statements to get_durations in jax_moseq's utils and let the code crash. The shapes of stateseqs and mask consistently looked like this:

stateseqs len (116, 10027)
[[ 3  3  3 ... 79 79 46]
 [95 95  5 ... 90 90 90]
 [90 90 90 ... 58 58 58]
 ...
 [90 90 90 ... 54 54 54]
 [36 36 36 ... 90 90 90]
 [90 90 90 ... 67 90 90]]
mask len (116, 10030)
[[1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]]
stateseqs len (102, 10027)
[[72 72 96 ... 91 91 91]
 [72 72 72 ... 52 52 52]
 [52 52 52 ... 23 23 23]
 ...
 [23 23 23 ... 77 77 77]
 [23 23 23 ... 12 12 12]
 [12 12 12 ... 87 87 87]]
mask len (116, 10030)
[[1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]]

Note that only the first dimension is responsible for the mismatch: concatenate_stateseqs already handles a mismatch in the second (time) dimension by slicing the mask with mask[:, -stateseqs.shape[1]:]. In the first printout both arrays have 116 rows, but in the second, stateseqs has only 102 rows while mask still has 116. I don't encounter this issue when training the AR-HMM or the full model, so as a workaround I ran the AR-HMM cell a few times with different kappa values, picked a good one, and then trained the full model. I wanted to open this issue in case anyone else runs into it, or in case Caleb has any thoughts on why this is happening.
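A guarded version of the duration computation would fail with a clearer message than the bare IndexError. This is a sketch under assumptions, not the jax_moseq implementation: the name `durations_with_check` is hypothetical, it only covers the array-plus-mask branch of `concatenate_stateseqs`, its body mirrors line 40 of `concatenate_stateseqs` and lines 82 to 84 of `get_durations` as shown in the traceback, and the final `np.diff(changepoints)` step is assumed because it is not visible there.

```python
import numpy as np


def durations_with_check(stateseqs, mask):
    """Hypothetical sketch of the get_durations logic seen in the traceback,
    with an explicit check on the first axis (number of sequences), which is
    the only axis the mask slicing cannot reconcile."""
    stateseqs = np.asarray(stateseqs)
    mask = np.asarray(mask)
    if stateseqs.shape[0] != mask.shape[0]:
        raise ValueError(
            f"stateseqs has {stateseqs.shape[0]} sequences but mask has "
            f"{mask.shape[0]}"
        )
    # Trim the mask's time axis from the left, as concatenate_stateseqs does,
    # then keep only the unmasked timepoints (flattened row by row).
    stateseq_flat = stateseqs[mask[:, -stateseqs.shape[1]:] > 0].astype(int)
    # Pad with -1 so the first and last runs are bounded by changepoints.
    stateseq_padded = np.hstack([[-1], stateseq_flat, [-1]])
    changepoints = np.diff(stateseq_padded).nonzero()[0]
    # Assumed final step: run lengths are the gaps between changepoints.
    return np.diff(changepoints)


z = np.array([[3, 3, 3, 5, 5], [7, 7, 1, 1, 1]])
m = np.ones((2, 6))  # one extra leading timepoint, trimmed away by the slice
print(durations_with_check(z, m))  # [3 2 2 3]
```

With matching first axes the runs are 3, 5, 7 and 1 with lengths 3, 2, 2 and 3; with 102 versus 116 sequences the function raises a ValueError naming both counts instead of failing inside the boolean index.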
