In a sparse CSR matrix, the slice bounds into one vector (`indices`) are stored in another (`indptr`). For example, we have a sparse matrix of user-item interactions in CSR format, and we want to check whether a user has interacted with a particular item:

    import jax.experimental.sparse as jsp
    import jax.numpy as jnp

    interactions = jsp.BCSR.fromdense(jnp.array([[1,0,0,0],[0,1,0,0],[0,0,1,1]]))
    no_users, no_items = interactions.shape

    def in_positives(user_id, item_id, interactions) -> bool:
        # Row bounds are converted to Python ints, so the slice has a concrete size.
        start, end = int(interactions.indptr[user_id]), int(interactions.indptr[user_id + 1])
        return item_id in interactions.indices[start:end]

    print(in_positives(2, 3, interactions))

After reviewing multiple threads on the topic of dynamic slicing, I didn't find a solution for the jitted version. Is there any general approach to work around such cases?

    from functools import partial
    import jax
    import jax.experimental.sparse as jsp
    import jax.numpy as jnp

    interactions = jsp.BCSR.fromdense(jnp.array([[1,0,0,0],[0,1,0,0],[0,0,1,1]]))
    no_users, no_items = interactions.shape

    @partial(jax.jit, static_argnums=(0,1,))
    def in_positives(user_id, item_id, interactions) -> bool:
        # Under jit, start and end are traced values, so this slice has a dynamic shape.
        start, end = interactions.indptr[user_id], interactions.indptr[user_id + 1]
        return item_id in interactions.indices[start:end]

    print(in_positives(2, 3, interactions))
Replies: 1 comment 3 replies
Your code is attempting to create a dynamically-shaped array. JAX does not support dynamically-shaped arrays (see JAX Sharp Bits: Dynamic Shapes).
I don't know the full logic of the code you're trying to write, but I would suggest trying to express your logic without creating the dynamic array. You already have the statically-shaped source array and the dynamic start and end indices: can you write your remaining code to operate directly on those, and avoid the need for a dynamically-shaped intermediate array?
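One possible way to follow this suggestion is sketched below. It is only an illustration under assumptions, not necessarily what the original code needs: instead of slicing `indices[start:end]`, it builds a boolean mask over the full, statically-shaped `indices` array and combines it with an equality test. The four-argument signature, which passes `indptr` and `indices` as plain arrays, and the names `positions` and `in_row` are choices made for this sketch.

```python
import jax
import jax.experimental.sparse as jsp
import jax.numpy as jnp

interactions = jsp.BCSR.fromdense(jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]]))


@jax.jit
def in_positives(user_id, item_id, indptr, indices):
    # Dynamic row bounds; under jit these are traced scalars, which is fine
    # as long as they are never used as slice sizes.
    start, end = indptr[user_id], indptr[user_id + 1]
    # Positions 0..nse-1 over the whole indices array (static shape).
    positions = jnp.arange(indices.shape[0])
    # True only for stored entries that belong to the requested row.
    in_row = (positions >= start) & (positions < end)
    # The item is present if any entry in that row has a matching column index.
    return jnp.any(in_row & (indices == item_id))


# Assumed usage: pass the CSR buffers of the matrix defined above.
print(in_positives(2, 3, interactions.indptr, interactions.indices))
```

The trade-off in this sketch is that every query touches all stored entries rather than only the entries of one row, in exchange for shapes that are known at trace time. Because no argument is marked static, the function should not need to recompile for each new `user_id` or `item_id`.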