Vectorization of mask generation for a custom dataset #18449

Answered by jakevdp
mohamad-amin asked this question in Q&A
I think you can compute the result you're after more efficiently using the return_inverse argument of jnp.unique:

import jax
import jax.numpy as jnp

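def make_mask(data):
  # Number of mask columns: one per possible value of the third column.
  m = data[:, 2].max() + 1
  # idx maps each input row to the row of its unique (col 0, col 1) pair.
  pairs, idx = jnp.unique(data[:, :2], axis=0, return_inverse=True)
  # Scatter a 1 at (pair row, third-column value) for every input row.
  return jnp.zeros_like(data, shape=(len(pairs), m)).at[idx, data[:, 2]].set(1)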

m = 1000
N = 10000

key = jax.random.key(0)
data = jax.random.randint(key, shape=(N, 3), minval=0, maxval=m)

print(make_mask(data).shape)
# (9952, 1000)
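
To make the role of return_inverse concrete, here is a minimal sketch on a made-up array (pairs_in and the other variable names are illustrative, not from the question). It assumes that with axis=0 the unique rows come back in sorted order and that the inverse indices say, for each input row, which unique row it matches. The ravel() call is a defensive assumption on my part, in case the inverse comes back with an extra dimension in some versions:

```python
import jax.numpy as jnp

# Hypothetical input: four rows, two of them distinct.
pairs_in = jnp.array([[5, 7],
                      [5, 7],
                      [2, 9],
                      [5, 7]])

unique_rows, inverse = jnp.unique(pairs_in, axis=0, return_inverse=True)
inverse = inverse.ravel()  # flatten defensively to one index per input row

# Indexing the unique rows with the inverse rebuilds the original input.
reconstructed = unique_rows[inverse]

print(unique_rows)
print(inverse)
```

The inverse array is what lets the mask be filled with a single scatter: it is already the row coordinate for every entry of the input.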

A smaller dataset makes it easier to check that this gives what you described:

data = jnp.array([[1, 2, 3],
                  [1, 2, 4],
                  [1, 3, 2],
                  [2, 3, 0]])
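
As a rough sketch of what this small input should produce, the same steps can be run inline. This assumes the goal is one mask row per unique (first column, second column) pair, with a 1 in each column named by the third value; the int() and ravel() calls are defensive additions of mine, not part of the answer above:

```python
import jax.numpy as jnp

data = jnp.array([[1, 2, 3],
                  [1, 2, 4],
                  [1, 3, 2],
                  [2, 3, 0]])

# One column per possible third-column value (0 through 4 here).
m = int(data[:, 2].max()) + 1

# Unique (col 0, col 1) pairs and, per input row, the index of its pair.
pairs, idx = jnp.unique(data[:, :2], axis=0, return_inverse=True)
idx = idx.ravel()  # defensive: one index per input row

# Set mask[pair row, third-column value] = 1 for every input row.
mask = jnp.zeros((len(pairs), m), dtype=data.dtype).at[idx, data[:, 2]].set(1)

print(pairs)
# [[1 2]
#  [1 3]
#  [2 3]]
print(mask)
# [[0 0 0 1 1]
#  [0 0 1 0 0]
#  [1 0 0 0 0]]
```

The pair (1, 2) appears twice in the input, with third values 3 and 4, so its row has two entries set; the other two pairs each contribute a single entry.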

Answer selected by mohamad-amin