Methods Efficiently Finding Unique Rows in a 2D Integer Array Without Using jnp.unique Along Axis? #23530
First of all, there is no efficient operation in XLA for finding unique values, period. We have a passable implementation of
Depending on the size of your problem, you might consider the spectral analysis approach demonstrated in #17370 (comment); it's only about a dozen lines of code, it's JIT-compatible, and it should be pretty efficient. What do you think?
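For comparison, here is a minimal sketch of a different JIT-compatible technique. It is not the approach from #17370, whose code is not reproduced here. The idea is to sort the rows of a 2D integer array A of shape (n, m) lexicographically with jnp.lexsort, flag each row that differs from its predecessor, and compact the flagged rows into a fixed-size output with jnp.nonzero(..., size=n). The function name unique_rows_lexsort and its (rows, index, count) return convention, with zero / -1 padding, are illustrative assumptions and not an existing JAX API.

```python
import jax
import jax.numpy as jnp


@jax.jit
def unique_rows_lexsort(A):
    """Sort-based unique rows (illustrative sketch, not the #17370 approach).

    Returns (rows, index, count). `rows` has shape (n, m): its first `count`
    rows are the distinct rows of A in lexicographic order, and the rest are
    zero padding. `index[i]` is a position in A where `rows[i]` occurs;
    padding entries are -1.
    """
    n, m = A.shape  # shapes are static under jit
    # jnp.lexsort treats the LAST key as the primary key, so pass the
    # columns in reverse order to sort by column 0 first.
    order = jnp.lexsort(tuple(A[:, j] for j in reversed(range(m))))
    S = A[order]
    # After sorting, equal rows are adjacent: a row starts a new group
    # if it differs from the row just before it.
    is_first = jnp.concatenate(
        [jnp.array([True]), jnp.any(S[1:] != S[:-1], axis=1)]
    )
    count = jnp.sum(is_first)
    # Fixed-size compaction: sorted positions of group starts, padded with 0.
    (pos,) = jnp.nonzero(is_first, size=n, fill_value=0)
    valid = jnp.arange(n) < count
    rows = jnp.where(valid[:, None], S[pos], 0)
    index = jnp.where(valid, order[pos], -1)
    return rows, index, count
```

Because the output size is fixed at n, only the first count entries are meaningful; the padding keeps every shape static so the function stays jittable.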
Hi everyone,
I have a subroutine in my application that needs to be jittable and involves finding unique rows in a 2D integer array A. I'm currently running into performance issues with the jnp.unique function in JAX when using the axis argument, and I'm hoping to get some advice on a more efficient approach. Here are the properties of A:
- A has a fixed shape (n, m), where n (the number of rows) is in the range of …
- The elements of A are integers and lie within the range [0, k], where k is fixed.
- A is already sorted, and no elements are repeated within a row.
What I've Tried So Far:
I initially tried:
unique_A, unique_index = jit(lambda A: jnp.unique(A, axis=0, return_index=True, size=n, fill_value=jnp.zeros(m, dtype=jnp.int32)))(A)
However, I found that using jnp.unique along an axis is very slow (see issue #17370).
Workaround Considered:
One workaround I considered is to use a hashing function to map each row to an integer and then perform jnp.unique on the resulting 1D array. However, if the hashing function is collision-free (e.g., using powers of k+1), the resulting hash values can easily exceed the range of uint32. Since the application needs to run in single precision (so 64-bit integer types are not an option), partitioning the hash values into multiple uint32 segments would still require using jnp.unique along an axis, which brings us back to the original problem.
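To make the arithmetic concrete: with entries in [0, k], the base-(k+1) hash of a row is collision-free and lies in [0, (k+1)^m - 1], so it fits in uint32 only when (k+1)^m <= 2^32. For example, k = 3 and m = 3 give 64, which fits, while k = 15 and m = 9 give 2^36, which does not. The sketch below assumes the bound holds, computes the hash with Horner's scheme, and then calls jnp.unique on the 1D hash array. The helper names fits_in_uint32 and unique_rows_via_hash are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


def fits_in_uint32(k, m):
    # Base-(k+1) hashes lie in [0, (k+1)**m - 1], so they fit in uint32
    # exactly when (k+1)**m <= 2**32.
    return (k + 1) ** m <= 2**32


@partial(jax.jit, static_argnums=1)
def unique_rows_via_hash(A, k):
    """Unique rows of A via a 1D jnp.unique on base-(k+1) row hashes.

    Assumes every entry of A lies in [0, k] and fits_in_uint32(k, m) holds.
    Returns (rows, index, count); entries at positions >= count are padding.
    """
    n, m = A.shape
    if not fits_in_uint32(k, m):  # checked at trace time (k, m are static)
        raise ValueError("(k+1)**m exceeds the uint32 range")
    base = jnp.uint32(k + 1)
    h = jnp.zeros(n, dtype=jnp.uint32)
    for j in range(m):  # m is static, so this loop unrolls at trace time
        # Horner's scheme: column 0 ends up as the most significant digit,
        # so sorted hashes correspond to lexicographically sorted rows.
        h = h * base + A[:, j].astype(jnp.uint32)
    # 1D unique with a static output size; padded slots are masked below.
    _, index = jnp.unique(h, return_index=True, size=n)
    hs = jnp.sort(h)
    count = 1 + jnp.sum(hs[1:] != hs[:-1])
    valid = jnp.arange(n) < count
    rows = jnp.where(valid[:, None], A[index], 0)
    return rows, jnp.where(valid, index, -1), count
```

The count is computed separately from jnp.unique so the sketch does not depend on how jnp.unique fills its padded slots.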
Question:
Is there an efficient, jittable algorithm that can find unique integer rows in A and avoid the performance issues with jnp.unique along the axis? Any suggestions or alternative approaches would be greatly appreciated!
Thanks in advance for your help!
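Given the listed properties, another collision-free encoding may be worth considering; it is a technique swapped in here, not one proposed in the thread. If each row is sorted in ascending order with distinct entries in [0, k] (an assumption about what "sorted" means above), then a row is fully determined by the set of values it contains, so a (k+1)-bit mask identifies it. That needs k + 1 bits regardless of m, so it fits in uint32 whenever k < 32. The name unique_rows_via_bitmask is hypothetical, and unlike a lexicographic sort, the output comes out in bitmask order.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def unique_rows_via_bitmask(A, k):
    """Unique rows via a (k+1)-bit set encoding of each row (sketch).

    Assumes each row is sorted ascending with distinct entries in [0, k],
    and k < 32. Rows with the same values in a different order would be
    merged, so the per-row sorting assumption matters.
    Returns (rows, index, count); entries at positions >= count are padding.
    """
    n, m = A.shape
    if k >= 32:  # checked at trace time (k is static)
        raise ValueError("a (k+1)-bit mask does not fit in uint32")
    bits = jnp.left_shift(jnp.uint32(1), A.astype(jnp.uint32))
    # Entries within a row are distinct, so summing their bits never carries.
    mask = jnp.sum(bits, axis=1, dtype=jnp.uint32)
    _, index = jnp.unique(mask, return_index=True, size=n)
    ms = jnp.sort(mask)
    count = 1 + jnp.sum(ms[1:] != ms[:-1])
    valid = jnp.arange(n) < count
    rows = jnp.where(valid[:, None], A[index], 0)
    return rows, jnp.where(valid, index, -1), count
```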