Out of Bounds Index Handling #27597
-
Recently, I have been relearning ML and decided to pick up JAX. However, I noticed a peculiarity in the following code:

import jax
import jax.numpy as jnp
import pandas as pd
import numpy as np
# Creating a 3 x 2 matrix:
M = jnp.array([[1, 2],
               [3, 4],
               [5, 6]])
print(M.shape)
print(M[3][2])  # Should raise an error, but instead returns 6

Is this intended behavior? Wouldn't this lead to silent errors in one's code?
Replies: 2 comments
-
I am running the code on JAX 0.5.3 (the CUDA 12 build) on my RTX 3070.
-
Never mind, I did some searching and found the reasoning at https://stackoverflow.com/questions/75770847/jax-vs-numpy-array-indexing-out-of-bounds-behaviour. It is also covered in the JAX "Sharp Bits" documentation, for anyone interested in the future.
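For anyone landing here later, the snippet below is a minimal sketch of the behavior, not an official recipe. It assumes a recent JAX release and uses the same 3 x 2 matrix as the question. As I understand the Sharp Bits page, out-of-bounds indices are clamped on reads and out-of-bounds updates are skipped, because raising an error from code running on an accelerator is difficult. The `mode` and `fill_value` arguments to `.at[...].get()` let you opt into a sentinel value instead of clamping.

```python
import numpy as np
import jax.numpy as jnp

M = jnp.array([[1, 2],
               [3, 4],
               [5, 6]])

# Plain indexing: out-of-bounds indices are clamped to the last valid
# position, so M[3][2] reads the same element as M[2][1].
clamped = M[3][2]

# Explicit clamping via the .at[...] interface.
clipped = M.at[3, 2].get(mode="clip")

# Ask for a sentinel value instead of clamping.
filled = M.at[3, 2].get(mode="fill", fill_value=-1)

# Out-of-bounds updates are skipped, so the result equals M.
updated = M.at[3, 0].set(10)

# NumPy, by contrast, raises IndexError for the same index.
try:
    np.asarray(M)[3][2]
    numpy_raised = False
except IndexError:
    numpy_raised = True

print(clamped, clipped, filled, numpy_raised)
print(updated)
```

If you would rather fail loudly, one option is to validate indices on the host (for example with a plain Python `assert` against `M.shape`) before they reach JAX code.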