
Function whose return type follows the argument type despite an intermediate change #17243

Closed. Answered by jakevdp
jecampagne asked this question in Q&A

In internal JAX code, I often use a pattern that looks something like this:

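import jax.numpy as jnp

def f(x):
  out_shape = jnp.shape(x)       # remember the caller's shape: () for a scalar
  x = jnp.atleast_1d(x)          # work on an array with at least one axis internally
  res = x**3
  return res.reshape(out_shape)  # restore the original shape, so scalar in -> 0-d out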
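Here is a slightly fuller sketch of the same pattern, for a case where the intermediate `atleast_1d` step actually matters. It assumes the core computation is a scalar-only function that has to be mapped over the elements with `jax.vmap`, which needs at least one axis to map over. `scalar_kernel` is a hypothetical stand-in, not something from the thread.

```python
import jax
import jax.numpy as jnp

def scalar_kernel(t):
  # Hypothetical per-element function written for scalar input only:
  # cube positive values, negate the rest.
  return jax.lax.cond(t > 0, lambda u: u**3, lambda u: -u, t)

def f(x):
  out_shape = jnp.shape(x)          # () for scalars, the full shape for arrays
  flat = jnp.atleast_1d(x).ravel()  # vmap needs a leading axis, so flatten to 1-D
  res = jax.vmap(scalar_kernel)(flat)
  return res.reshape(out_shape)     # scalar in -> 0-d array out; array in -> same shape

# Shapes are static under jit, so the final reshape still works when compiled.
f_jit = jax.jit(f)
```

With this sketch, `f(2.0)` returns a 0-d array, while `f` applied to a `(2, 2)` array returns a `(2, 2)` array. Note that a Python float input comes back as a 0-d JAX array rather than a Python float.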

Answer selected by jecampagne