Vectorization of mask generation for a custom dataset #18449
Hey, I have an input dataset in the form of an N x 3 matrix of (i, j, k) triples. This is how I currently generate the mask:

# Given: dataset, an N x 3 matrix of input (i, j, k) triples.
# p is assumed to be defined elsewhere as the number of possible k values.
def maskup(data, row):
    # Keep k where the row's (i, j) matches; use -1 everywhere else.
    vals = jnp.where((data[:, 0] == row[0]) & (data[:, 1] == row[1]), data[:, 2], -1)
    # The -1 entries land in the extra last slot, which is sliced off.
    return jnp.zeros(p+1).at[vals].set(jnp.ones(len(vals)))[:p]

v_maskup = jax.jit(jax.vmap(maskup, in_axes=(None, 0)))
dataset_ij = jnp.unique(dataset[:, :2], axis=0)  # Find unique (i, j) pairs
mask = v_maskup(dataset, dataset_ij)

I'm wondering how this can be improved to avoid the OOM, but I'm not sure how to approach it. I'd be grateful for any feedback or comments. Thanks!
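A rough sketch of where the memory may be going, not a profile of the real run: under `jax.vmap`, the `jnp.where` inside `maskup` is evaluated for every combination of unique pair and dataset row, so the logical intermediate has shape (U, N) for U unique pairs and N rows. The tiny `dataset` below and the helper name `select_k` are made up for illustration, and whether XLA actually materializes the full intermediate is an assumption.

```python
import jax
import jax.numpy as jnp

# Hypothetical small stand-in for the real dataset: N = 4 rows of (i, j, k).
dataset = jnp.array([[1, 2, 3],
                     [1, 2, 4],
                     [1, 3, 2],
                     [2, 3, 0]])
dataset_ij = jnp.unique(dataset[:, :2], axis=0)  # U = 3 unique (i, j) pairs

def select_k(data, row):
    # Same comparison as in maskup: produces one value per dataset row.
    return jnp.where((data[:, 0] == row[0]) & (data[:, 1] == row[1]), data[:, 2], -1)

# Batched over the unique pairs, exactly like v_maskup batches maskup.
vals = jax.vmap(select_k, in_axes=(None, 0))(dataset, dataset_ij)
print(vals.shape)  # (3, 4), i.e. (U, N)
```

When most (i, j) pairs are unique, U is close to N, so this intermediate grows roughly quadratically with the dataset size.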
Replies: 1 comment
I think you can compute the result you're after more efficiently using the return_inverse argument of jnp.unique:

import jax
import jax.numpy as jnp

def make_mask(data):
    # Number of mask columns: one per possible k value.
    m = data[:, 2].max() + 1
    # pairs: unique (i, j) rows; idx: for each input row, the index of its pair.
    pairs, idx = jnp.unique(data[:, :2], axis=0, return_inverse=True)
    # A single scatter sets mask[idx[n], k[n]] = 1 for every input row n.
    return jnp.zeros_like(data, shape=(len(pairs), m)).at[idx, data[:, 2]].set(1)

m = 1000
N = 10000
key = jax.random.key(0)
data = jax.random.randint(key, shape=(N, 3), minval=0, maxval=m)
print(make_mask(data).shape)
# (9952, 1000)

With a smaller dataset, I think you can see that it gives you what you described:

data = jnp.array([[1, 2, 3],
                  [1, 2, 4],
                  [1, 3, 2],
                  [2, 3, 0]])
print(make_mask(data))
# [[0 0 0 1 1]
#  [0 0 1 0 0]
#  [1 0 0 0 0]]
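To make the mechanism a bit more concrete, here is a minimal sketch that unpacks the two steps on the same small dataset. The variable name `num_k` and the defensive `ravel()` are my additions rather than part of the answer above; the `ravel()` assumes the shape of the inverse returned with `axis=0` may differ between JAX versions.

```python
import jax.numpy as jnp

data = jnp.array([[1, 2, 3],
                  [1, 2, 4],
                  [1, 3, 2],
                  [2, 3, 0]])

# Step 1: unique (i, j) rows, plus for each input row the index of its pair.
pairs, idx = jnp.unique(data[:, :2], axis=0, return_inverse=True)
idx = idx.ravel()  # flatten defensively (assumption: shape may vary by version)

print(pairs)  # rows (1, 2), (1, 3), (2, 3)
print(idx)    # 0, 0, 1, 2: the first two input rows share the pair (1, 2)

# Step 2: one scatter that sets mask[idx[n], k[n]] = 1 for every input row n.
num_k = int(data[:, 2].max()) + 1
mask = jnp.zeros((len(pairs), num_k), dtype=jnp.int32).at[idx, data[:, 2]].set(1)
print(mask)
```

The scatter only touches N index pairs, so nothing of size (number of unique pairs) x N is ever built; the largest array is the output mask itself.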