Hello JAX community!
In my application I need to stack multiple matrices row by row for later computation. However, the number of such matrices varies at runtime: the user can toggle each matrix "on" or "off" via a boolean mask, and a matrix that is toggled off is simply computed as a dummy zero matrix of matching shape.
So the naively stacked matrix may be very tall and contain many zero rows. What I need is to accumulate only the non-zero rows into a shorter matrix, so as to make the final computations lighter.
The shorter matrix needs to have a fixed size (some padding is acceptable) to support jit, which is essential.
I have come up with a solution based on lax.scan to do the accumulation (see the code at the bottom of the post), which relies on JAX's out-of-bounds behavior (set operations at out-of-bounds indices do nothing).
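A minimal sketch of the out-of-bounds behavior referred to here, separate from the actual solution. The array names and shapes are illustrative assumptions, and `mode="drop"` is passed explicitly so the example does not depend on the default mode:

```python
import jax.numpy as jp

# Illustrative 3x2 target; the shapes and values are arbitrary.
target = jp.zeros((3, 2))
row = jp.ones(2)

# In-bounds index: row 1 is overwritten.
in_bounds = target.at[1].set(row, mode="drop")

# Out-of-bounds index (3 == number of rows): the update is dropped,
# so the result equals the original array.
out_of_bounds = target.at[3].set(row, mode="drop")

print(in_bounds)
print(out_of_bounds)
```

Writing to an index equal to the number of rows is therefore a way to express "skip this update" without changing any shapes.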
I would appreciate the community's feedback on whether this is indeed a "good" or recommended pattern, whether there are shortcomings I might have missed, and whether better or more optimized solutions exist.
Thanks!
Arturo
```python
import jax
import jax.numpy as jp
import numpy as np
from functools import partial


# dummy computations (imagine heavy math inside)
def compute_matrix_1(x):
    return jp.full((2, 5), x)


def compute_matrix_2(x):
    return jp.full((2, 5), x**2)


def compute_matrix_3(x):
    return jp.full((2, 5), x**3)


# adapt to use a boolean mask to decide whether to compute the matrix
def make_maybe_compute_matrix(fn):
    def _zeros(x):
        return jp.zeros((2, 5))

    def maybe_compute_matrix(x, m):
        return jax.lax.cond(m, fn, _zeros, x)

    return maybe_compute_matrix


# store inside list and apply adapter
# (list(...) so the adapted functions can be iterated more than once;
# a bare map object would be exhausted after the first call)
fn_list = [compute_matrix_1, compute_matrix_2, compute_matrix_3]
fn_list = list(map(make_maybe_compute_matrix, fn_list))


# evaluate stacked matrix using runtime boolean masks from the user
# we also compute a vector of boolean masks, one per each row
def compute_stacked_matrix(x, m):
    matrices = []
    masks = []
    for i, fn in enumerate(fn_list):
        matrices.append(fn(x, m[i]))
        masks.append(jp.repeat(m[i][jp.newaxis], 2, axis=0))
    return jp.vstack(matrices), jp.hstack(masks)


# test
mat, masks = compute_stacked_matrix(3.0, jp.array([True, False, True], dtype=bool))
print('stacked matrix:\n', mat)
print('row masks:', masks)


# imagine we have many more than 3 functions, so the stacked matrix
# could be huge and contains many zero rows
# so, now I want to accumulate just the non-zero rows into a smaller matrix
# of constant size maxrows (we may allow for some padding)
@partial(jax.jit, static_argnums=2)
def accumulate_non_zero_rows(mat, masks, maxrows):
    # define fn to apply lax.scan to
    def accumulate(carry, mask_idx):
        mat_nz_idx, mat_nz = carry
        # out-of-bound index if mask is False
        idx = jp.where(masks[mask_idx], mat_nz_idx, maxrows)
        # set current idx to required row only if index is within bounds!
        mat_nz = mat_nz.at[idx].set(mat[mask_idx])
        # increment index if mask is True
        mat_nz_idx = jp.where(masks[mask_idx], mat_nz_idx + 1, mat_nz_idx)
        return (mat_nz_idx, mat_nz), None

    # initialize carry
    mat_nz = jp.zeros((maxrows, mat.shape[1]))
    mat_nz_idx = 0
    carry = (mat_nz_idx, mat_nz)
    # apply lax.scan
    carry, _ = jax.lax.scan(accumulate, carry, jp.arange(len(masks)))
    mat_nz_idx, mat_nz = carry
    return mat_nz


# test
mat_nz = accumulate_non_zero_rows(mat, masks, 5)
print('accumulated non-zero rows:\n', mat_nz)

# test
mat_nz = accumulate_non_zero_rows(mat, masks, 3)
print('accumulated non-zero rows:\n', mat_nz)
```
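For comparison with the lax.scan approach, here is a hedged sketch of a scan-free variant. It is not part of the original solution: the function name `gather_masked_rows` is hypothetical, and it assumes that `jp.nonzero` with a static `size` and a gather with `mode="fill"` are acceptable in the application. The example data mimics the stacked matrix for `x = 3.0` with the middle matrix toggled off.

```python
from functools import partial

import jax
import jax.numpy as jp


# Hypothetical helper, shown only as an alternative sketch.
@partial(jax.jit, static_argnums=2)
def gather_masked_rows(mat, masks, maxrows):
    # Indices of the first `maxrows` True entries of the row mask; unused
    # slots are filled with an out-of-bounds index (the row count of `mat`).
    (idx,) = jp.nonzero(masks, size=maxrows, fill_value=mat.shape[0])
    # Out-of-bounds gathers return fill_value, which yields zero padding rows.
    return mat.at[idx].get(mode="fill", fill_value=0)


# Example data: three stacked 2x5 blocks, the middle one toggled off.
mat = jp.vstack([jp.full((2, 5), 3.0), jp.zeros((2, 5)), jp.full((2, 5), 27.0)])
masks = jp.array([True, True, False, False, True, True])

out5 = gather_masked_rows(mat, masks, 5)  # four kept rows plus one padding row
out3 = gather_masked_rows(mat, masks, 3)  # only the first three kept rows fit
print(out5)
print(out3)
```

As in the scan version, rows beyond `maxrows` are silently discarded, so `maxrows` still has to be chosen as an upper bound on the number of active rows.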