TL;DR: Are there any JIT/XLA-compatible methods in JAX that can hash binary arrays for easy lookup (beyond using jnp.searchsorted with jnp.packbits), ideally something akin to a dictionary/lookup table?
Hi All,
I'm trying to speed up a method in some JAX code I have, which essentially involves checking whether an array (represented as a binary 1D array) exists in a table or not. I've written a minimal reproducible example below (each lookup should scale as O(log N), where N is the size of my 'lookup table'). However, once this lookup table reaches on the order of 1e6 entries or more, it quickly becomes too slow for my use case (as I have to repeat this lookup on the order of millions of times).
Ideally, I'd use some form of hashing function (like a dictionary), although after reading through the discussion page it seems that no such function might exist in XLA (#10475). I did also see @YouJiacheng write a cuckoo hashing algorithm in #10475, but as my arrays are binary I should be able to use jnp.packbits to directly map these 1D arrays to an integer, which should (in principle) be hashable, although perhaps not JIT-able?
Any ideas on how to create a JIT/XLA-compatible lookup table would be greatly appreciated as, although jnp.searchsorted scales well, in the extreme case of tables with 1e7 entries, it's quite slow/impractical!
Here's a minimal reproducible example demonstrating my current method:
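To make the array-to-integer mapping concrete, here is a minimal sketch of the same idea without jnp.packbits: weight bit k by 2**k and sum, which should give the same little-endian key. The name `bits_to_int` is hypothetical, and the sketch assumes the state length is at most 63 so the key stays a non-negative int64.

```python
import jax
jax.config.update('jax_enable_x64', True)
from jax import numpy as jnp

@jax.jit
def bits_to_int(state):
    # Little-endian weights: bit k contributes 2**k.
    # Assumes state.shape[-1] <= 63 so the result is a non-negative int64.
    n_bits = state.shape[-1]
    weights = jnp.left_shift(jnp.int64(1), jnp.arange(n_bits, dtype=jnp.int64))
    return jnp.sum(state.astype(jnp.int64) * weights, axis=-1)

# Works on a single state or a batch of states (sum is over the last axis).
single_key = bits_to_int(jnp.array([1, 0, 1, 1]))           # 1 + 4 + 8
batch_keys = bits_to_int(jnp.array([[1, 0, 0, 0], [0, 0, 0, 1]]))
```

Because the sum runs over the last axis, this version handles a batch directly without jax.vmap.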
import jax
jax.config.update('jax_enable_x64',True)
from jax import numpy as jnp
from jax import random as jr
from jax import lax
from jax.typing import ArrayLike
from jax import Array
from functools import partial
key = jr.key(42)
key, subkey = jr.split(key, 2)
num_states = 10 # number of unique entries
state_length = 4 # length of our 1D arrays (max would be 64, with a 64-bit int representing its key)
states = jr.randint(key=subkey, shape=(num_states, state_length), minval=0, maxval=2) # random binary states
states = jnp.unique(states, axis=0) # make sure they're unique entries for Lookup table (all entries must be unique for jnp.searchsorted to work)
# NOTE: As we have binary arrays, we can map arrays to ints (or ints to arrays) for a unique representation
# maps a state_length-bit array to an integer
@partial(jax.jit, static_argnums=(1,))
def state2integer(state: ArrayLike, state_length: int) -> Array:
    bit_64 = jnp.zeros(64, jnp.int64) # pad to 64 bits so the packed bytes fill exactly one int64
    bit_64 = bit_64.at[0:state_length].set(state)
    return jnp.packbits(bit_64, bitorder='little').view('int64')[0]
# maps an integer to a binary representation
def integer2state(integer: int, state_length: int) -> Array:
    int_64 = jnp.array([integer], dtype='int64')
    return jnp.unpackbits(int_64.view('uint8'), bitorder='little', count=state_length).astype(jnp.int64)
integers = jax.vmap(state2integer, in_axes=(0,None))(states, state_length)
sort_idx = jnp.argsort(integers) # sort `integers` in order for searchsorted to work
states = states[sort_idx, :]
integers = integers[sort_idx]
max_integer = jnp.max(integers) # to catch out-of-bounds on the jnp.searchsorted (invalid state is placed at the end)
index_integer = jnp.arange(integers.shape[0]) # array to map the insertion index back to the index in the `states` array.
print('The initial states (what I want to hash for easy lookup)')
for j, (i, s) in enumerate(zip(integers, states)):
    print(f'entry: {j:3d} | key: {i:3d} state: {s}')
key, subkey = jr.split(key, 2)
new_states = jr.randint(key=subkey, shape=(num_states, state_length), minval=0, maxval=2) # generate new random binary states (our look-up values)
new_integers = jax.vmap(state2integer, in_axes=(0,None))(new_states, state_length) # get the index of the new states
@jax.jit
def lookup_searchsorted(integers: ArrayLike, new_integers: ArrayLike, max_integer: ArrayLike) -> Array:
    insert_index = jnp.searchsorted(integers, new_integers, side='left', method='sort') # get insertion index of `new_integers` in `integers`
    # check the key is not larger than every pre-existing key (otherwise the insertion index is past the end)
    # AND it equals a pre-existing integer (i.e. it's not a state in between valid states)
    # NOTE: the bound compares keys with `max_integer`; comparing `insert_index` with `max_integer` would mix an index with a key value
    valid_mask = jnp.logical_and((new_integers <= max_integer), (integers[insert_index] == new_integers)) # define mask of valid new_integers
    new_index = jnp.where(valid_mask, index_integer[insert_index], -1) # if valid, get index in original array, else return -1.
    return new_index
new_index = lookup_searchsorted(integers, new_integers, max_integer) # given `new_integers` what's their index in the original array / do they exist there?
print("\nThe new states (invalid marked with '-1')")
for nj, (ni, ns, n_idx) in enumerate(zip(new_integers, new_states, new_index)):
    print(f'entry: {nj:3d} | key: {ni:3d} state: {ns} | Index: {n_idx:3d}')
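For comparison, the kind of dictionary-like structure I have in mind could look roughly like the sketch below: an open-addressing (linear probing) table that is built once on the host with NumPy and then probed inside jit with lax.while_loop. This is only a sketch under assumptions, not something I have benchmarked against jnp.searchsorted. The names `slot_of`, `build_table`, `lookup_one`, `lookup` and the `EMPTY` sentinel are all hypothetical, the bit-mixing hash is an arbitrary choice, and keys are assumed to be non-negative int64 (i.e. state length at most 63).

```python
import numpy as np
import jax
jax.config.update('jax_enable_x64', True)
from jax import numpy as jnp
from jax import lax

EMPTY = -1  # sentinel for an unused slot; assumes all real keys are >= 0

def slot_of(k, size):
    # Simple xor-shift mix followed by a power-of-two mask (an arbitrary choice).
    # Works the same for Python ints and int64 JAX arrays when k >= 0.
    return (k ^ (k >> 17) ^ (k >> 31)) & (size - 1)

def build_table(keys, load_factor=0.5):
    # Host-side build, run once outside jit. Size is a power of two with
    # at most `load_factor` occupancy, so an empty slot always exists.
    size = 1
    while size * load_factor < max(len(keys), 1):
        size *= 2
    table_keys = np.full(size, EMPTY, dtype=np.int64)
    table_vals = np.full(size, -1, dtype=np.int64)
    for idx, k in enumerate(keys):
        k = int(k)
        slot = slot_of(k, size)
        while table_keys[slot] != EMPTY:  # linear probing on collision
            slot = (slot + 1) & (size - 1)
        table_keys[slot] = k
        table_vals[slot] = idx
    return jnp.asarray(table_keys), jnp.asarray(table_vals)

def lookup_one(table_keys, table_vals, k):
    size = table_keys.shape[0]

    def keep_probing(slot):
        stored = table_keys[slot]
        return (stored != EMPTY) & (stored != k)

    def next_slot(slot):
        return (slot + 1) & (size - 1)

    slot = lax.while_loop(keep_probing, next_slot, slot_of(k, size))
    # Either we stopped on the key (hit) or on an empty slot (miss -> -1).
    return jnp.where(table_keys[slot] == k, table_vals[slot], -1)

# Batched, jitted lookup over many query keys.
lookup = jax.jit(jax.vmap(lookup_one, in_axes=(None, None, 0)))

demo_keys = np.array([3, 7, 13, 42], dtype=np.int64)
demo_table_keys, demo_table_vals = build_table(demo_keys)
demo_queries = jnp.array([13, 5, 3, 42, 8], dtype=jnp.int64)
demo_result = lookup(demo_table_keys, demo_table_vals, demo_queries)
# demo_result -> [2, -1, 0, 3, -1]
```

One caveat with this design: under jax.vmap the while loop keeps iterating until the longest probe sequence in the batch has finished, so long collision chains would slow down the whole batch. Keeping the load factor low should keep the chains short, but I don't know how this compares with jnp.searchsorted in practice.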