Hi! I'm running into a problem during my kappa scan when it tries to plot the median duration of syllables.
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[15], line 17
15 # stage 1: fit the model with AR only
16 model = kpms.update_hypparams(model, kappa=kappa)
---> 17 model = kpms.fit_model(
18 model,
19 data,
20 metadata,
21 project_dir,
22 model_name,
23 ar_only=True,
24 num_iters=num_ar_iters,
25 save_every_n_iters=25,
26 parallel_message_passing=False
27 )[0];
29 # stage 2: fit the full model
30 model = kpms.update_hypparams(model, kappa=kappa/decrease_kappa_factor)
File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/fitting.py:272, in fit_model(model, data, metadata, project_dir, model_name, num_iters, start_iter, verbose, ar_only, parallel_message_passing, jitter, generate_progress_plots, save_every_n_iters, location_aware, **kwargs)
270 save_hdf5(checkpoint_path, model, f"model_snapshots/{iteration}")
271 if generate_progress_plots:
--> 272 plot_progress(
273 model,
274 data,
275 checkpoint_path,
276 iteration,
277 project_dir,
278 model_name,
279 savefig=True,
280 )
282 return model, model_name
File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/viz.py:620, in plot_progress(model, data, checkpoint_path, iteration, project_dir, model_name, path, savefig, fig_size, window_size, min_frequency, min_histogram_length)
618 z = np.array(f[f"model_snapshots/{i}/states/z"])
619 sample_state_history.append(z[batch_ix, start : start + window_size])
--> 620 median_durations.append(np.median(get_durations(z, mask)))
622 axs[2].scatter(saved_iterations, median_durations)
623 axs[2].set_ylim([-1, np.max(median_durations) * 1.1])
File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/utils.py:82, in get_durations(stateseqs, mask)
80 print(mask)
81 #AM edits
---> 82 stateseq_flat = concatenate_stateseqs(stateseqs, mask=mask).astype(int)
83 stateseq_padded = np.hstack([[-1], stateseq_flat, [-1]])
84 changepoints = np.diff(stateseq_padded).nonzero()[0]
File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/utils.py:40, in concatenate_stateseqs(stateseqs, mask)
38 stateseq_flat = np.hstack(stateseqs)
39 elif mask is not None:
---> 40 stateseq_flat = stateseqs[mask[:, -stateseqs.shape[1] :] > 0]
41 else:
42 stateseq_flat = stateseqs.flatten()
IndexError: boolean index did not match indexed array along dimension 0; dimension is 102 but corresponding boolean dimension is 116
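For context, the duration statistic that plot_progress takes the median of can be sketched in NumPy from the lines visible in the traceback (line 40 of concatenate_stateseqs and lines 82-84 of get_durations). This is a minimal sketch, not jax_moseq's actual code: the function names are hypothetical, and the last step (differencing consecutive changepoints to get run lengths) is an assumption, since the traceback stops at line 84.

```python
import numpy as np

def concatenate_stateseqs_sketch(stateseqs, mask):
    """Hypothetical stand-in: flatten (batch, time) state sequences, keeping unmasked frames.

    Uses the same expression as line 40 of concatenate_stateseqs: the mask is
    trimmed to its last stateseqs.shape[1] frames so the time axes line up.
    The batch axes are NOT aligned, so they must already match.
    """
    return stateseqs[mask[:, -stateseqs.shape[1]:] > 0]

def get_durations_sketch(stateseqs, mask):
    """Hypothetical stand-in: length of every run of identical states."""
    stateseq_flat = concatenate_stateseqs_sketch(stateseqs, mask).astype(int)
    # Pad with -1 so the first and last runs also produce changepoints.
    stateseq_padded = np.hstack([[-1], stateseq_flat, [-1]])
    # Indices where the state changes.
    changepoints = np.diff(stateseq_padded).nonzero()[0]
    # Assumed final step: the gaps between changepoints are the run lengths.
    return np.diff(changepoints)

stateseqs = np.array([[3, 3, 5], [5, 5, 7]])
mask = np.array([[0, 1, 1, 1], [0, 1, 1, 0]])  # one extra leading frame, last frame of row 2 masked
durations = get_durations_sketch(stateseqs, mask)
print(durations, np.median(durations))  # [2 3] 2.5
```

Because the sequences are flattened before the changepoints are found, a run that continues across the boundary between two sequences (the 5s above) is counted as a single run in this sketch.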
This is on keypoint_moseq version 0.4.10 and jax_moseq version 0.2.2. Since debugging in notebooks can be cumbersome, I added some print statements to the get_durations function in jax_moseq's utils and let the fit run until it crashed. The printed shapes of stateseqs and mask consistently looked like this (the first call has matching numbers of sequences; the second, which crashes, does not):
stateseqs len (116, 10027)
[[ 3 3 3 ... 79 79 46]
[95 95 5 ... 90 90 90]
[90 90 90 ... 58 58 58]
...
[90 90 90 ... 54 54 54]
[36 36 36 ... 90 90 90]
[90 90 90 ... 67 90 90]]
mask len (116, 10030)
[[1. 1. 1. ... 1. 1. 1.]
[1. 1. 1. ... 1. 1. 1.]
[1. 1. 1. ... 0. 0. 0.]
...
[1. 1. 1. ... 0. 0. 0.]
[1. 1. 1. ... 1. 1. 1.]
[1. 1. 1. ... 0. 0. 0.]]
stateseqs len (102, 10027)
[[72 72 96 ... 91 91 91]
[72 72 72 ... 52 52 52]
[52 52 52 ... 23 23 23]
...
[23 23 23 ... 77 77 77]
[23 23 23 ... 12 12 12]
[12 12 12 ... 87 87 87]]
mask len (116, 10030)
[[1. 1. 1. ... 1. 1. 1.]
[1. 1. 1. ... 1. 1. 1.]
[1. 1. 1. ... 0. 0. 0.]
...
[1. 1. 1. ... 0. 0. 0.]
[1. 1. 1. ... 1. 1. 1.]
[1. 1. 1. ... 0. 0. 0.]]
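Instead of print statements, a small check placed just before the boolean indexing would report the same shapes as a readable error. This is a hypothetical helper (check_batch_alignment is not part of jax_moseq); it encodes the rule implied by line 40 of concatenate_stateseqs, namely that the mask may be longer in time than stateseqs but must have the same number of sequences.

```python
import numpy as np

def check_batch_alignment(stateseqs, mask):
    """Hypothetical guard: raise a descriptive error before mask-based indexing."""
    n_seqs, n_frames = stateseqs.shape
    n_mask_seqs, n_mask_frames = mask.shape
    # The batch (sequence) axis is never trimmed, so it must match exactly.
    if n_seqs != n_mask_seqs:
        raise ValueError(
            f"stateseqs has {n_seqs} sequences but mask has {n_mask_seqs}"
        )
    # The mask is trimmed to its last n_frames frames, so it cannot be shorter.
    if n_frames > n_mask_frames:
        raise ValueError(
            f"stateseqs has {n_frames} frames but mask only has {n_mask_frames}"
        )

# Toy shapes mirroring the printed output (same relationships, smaller sizes).
check_batch_alignment(np.zeros((116, 27)), np.ones((116, 30)))  # passes
try:
    check_batch_alignment(np.zeros((102, 27)), np.ones((116, 30)))
except ValueError as err:
    print(err)  # stateseqs has 102 sequences but mask has 116
```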
Keep in mind that only the first (sequence) dimension causes this mismatch; the function already handles a mismatch along the second (time) dimension, because line 40 of concatenate_stateseqs slices the mask to its last stateseqs.shape[1] frames. I don't encounter this issue when training the AR-HMM or the full model outside the kappa scan, so I ran the AR-HMM cell a few times with different kappa values to pick a good one, then trained the full model. I wanted to open this issue in case anyone else runs into it, or in case Caleb has any thoughts on why this is happening.
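The asymmetry can be reproduced in plain NumPy with the same indexing expression as line 40 of concatenate_stateseqs. This is a sketch with made-up toy shapes; only their relationship matches the printed output (the mask is longer in time and, in the failing case, has more sequences than stateseqs).

```python
import numpy as np

mask = np.ones((4, 10))                       # 4 sequences, 10 frames
stateseqs_ok = np.zeros((4, 7), dtype=int)    # shorter in time only
stateseqs_bad = np.zeros((3, 7), dtype=int)   # also fewer sequences

def flatten(stateseqs, mask):
    # Same expression as line 40: trims the mask's time axis, not its batch axis.
    return stateseqs[mask[:, -stateseqs.shape[1]:] > 0]

print(flatten(stateseqs_ok, mask).shape)  # (28,)

try:
    flatten(stateseqs_bad, mask)
except IndexError as err:
    # Same kind of error as in the traceback, with 3 vs 4 instead of 102 vs 116.
    print(err)
```

The time mismatch is absorbed by the trailing slice, while a different number of sequences reaches the boolean index unchanged and raises the IndexError.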