
lax.scan over 3D array #12915

Answered by hericks
dui1234 asked this question in General
Oct 21, 2022 · 2 comments · 2 replies
  1. Setup. I used the following code as setup.

    import jax
    import jax.numpy as jnp
    import numpy as np
    
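    key = jax.random.PRNGKey(0)
    key, key_ijk, key_S_t = jax.random.split(key, 3)
    
    # 159600 random (i, j, k) index triples, each entry in [0, 2**5)
    ijk = jax.random.randint(key_ijk, (159600, 3), 0, 2**5)
    # one row of 21 values per index triple in ijk
    S_t = jax.random.normal(key_S_t, (159600, 21))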
    
  2. Fixing a bug in the provided solution. I believe your code does not solve the problem you actually want to solve, i.e. computing a mean. To see this, note that when an index triple i, j, k appears multiple times in the array ijk, the entries stresses[i, j, k, :] are divided by count[i, j, k] multiple times, once for each occurrence of i, j, k in ijk.

    This can be fixed by dividing all non…
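A minimal sketch of a mean that avoids the repeated division, assuming the goal is the per-cell mean of all rows of S_t that share the same (i, j, k) triple. The helper name `grid_mean` and the grid size `N` are hypothetical, not taken from the thread. It accumulates sums and counts with scatter-adds (`.at[...].add`, which accumulates contributions from repeated indices) and then divides exactly once per cell. Note that this uses vectorized scatter-adds rather than `lax.scan`.

```python
import jax.numpy as jnp

# Hypothetical grid size: the setup draws indices from [0, 2**5).
N = 2**5


def grid_mean(ijk, S_t, n=N):
    """Per-cell mean of the rows of S_t that share the same (i, j, k) triple.

    Returns (mean, counts), where mean has shape (n, n, n, S_t.shape[1])
    and counts has shape (n, n, n).
    """
    i, j, k = ijk[:, 0], ijk[:, 1], ijk[:, 2]
    # Scatter-add: rows with a repeated (i, j, k) are summed into the same cell.
    sums = jnp.zeros((n, n, n, S_t.shape[1]), S_t.dtype).at[i, j, k].add(S_t)
    # Number of rows that landed in each cell.
    counts = jnp.zeros((n, n, n), S_t.dtype).at[i, j, k].add(1.0)
    # Cells that never occur have count 0; divide by 1 there so their mean is 0, not NaN.
    safe_counts = jnp.where(counts > 0, counts, 1.0)
    # A single division per cell, however many times its triple appears in ijk.
    return sums / safe_counts[..., None], counts
```

With the setup arrays, `grid_mean(ijk, S_t)` would return a mean array of shape (32, 32, 32, 21). Because `n` determines array shapes, wrapping the function in `jax.jit` would require marking it static, e.g. `jax.jit(grid_mean, static_argnames="n")`.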

Answer selected by dui1234
3 participants