-
Hello, I am using jax.debug.print to print the shape of an array. However, when I actually call …
-
You can use the following code:

```python
import jax
import jax.tree_util as jtu

# A pytree with no leaves: the shape lives entirely in the static aux data,
# so jax.debug.print has no array leaves to trace and rebuilds the object
# from the aux data when it prints.
@jtu.register_pytree_node_class
class StaticShape:
    def __init__(self, value: tuple[int, ...]):
        self.value = value

    def tree_flatten(self):
        # (children, aux_data): no dynamic children, the shape as aux data
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, _):
        self = object.__new__(cls)
        self.value = aux_data
        return self

    def __repr__(self):
        # Used when the format string renders the object
        return f"{self.value!r}"

    def __hash__(self):
        return hash(self.value)

    def __eq__(self, other):
        return (self.value == other.value) if isinstance(other, StaticShape) else False


@jax.jit
def f(x):
    jax.debug.print("shape={x}", x=StaticShape(x.shape))
    return x

f(jax.numpy.ones([4, 4]))
# shape=(4, 4)
# Array([[1., 1., 1., 1.],
#        [1., 1., 1., 1.],
#        [1., 1., 1., 1.],
#        [1., 1., 1., 1.]], dtype=float32)
```

I suspect this issue is due to …
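Because `x.shape` is already a concrete Python tuple while the function is being traced, a simpler alternative (a sketch, not taken from the reply above) is to build the message with a Python f-string before handing it to `jax.debug.print`, so no pytree wrapper is needed. This assumes the rendered shape contains no `{` or `}`, which holds for a tuple of ints, since `jax.debug.print` still applies its own formatting to the string it receives.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # x.shape is a plain tuple at trace time, so the f-string bakes it
    # into a constant message; jax.debug.print has no arguments to format.
    jax.debug.print(f"shape={x.shape}")
    return x


f(jnp.ones([4, 4]))  # prints: shape=(4, 4)
```

The custom pytree approach is still useful when you want to mix the static shape with dynamic values in one format string; for a shape alone, the f-string is shorter.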
-
`jax.debug.print` is designed for printing dynamic values. Array shapes are always static values (known at trace time), so the better tool for them is a standard Python print:

```python
# print dynamic array contents with debug print:
jax.debug.print("{}", results)
# print static array attributes with Python print:
print(results.shape)
print(results.dtype)
```
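Putting both kinds of print inside one jitted function shows the difference in when they run. This is a minimal sketch; `summarize` is a hypothetical name standing in for whatever function produces `results`. The Python `print` runs only while JAX traces the function (once per new input shape/dtype), whereas `jax.debug.print` runs on every call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def summarize(results):
    # Static attributes: known at trace time, so a plain Python print works.
    # It runs only when jit traces (a new shape/dtype), not on cached calls.
    print("shape:", results.shape, "dtype:", results.dtype)
    # Dynamic values: only known at run time, so use jax.debug.print.
    jax.debug.print("values: {}", results)
    return results.sum()


summarize(jnp.ones((2, 3)))   # traces: Python print and jax.debug.print both fire
summarize(jnp.zeros((2, 3)))  # same shape/dtype, cached: only jax.debug.print fires
```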