Function with return type according to argument type, but with an intermediate change #17243
-
Hello, the problem is to find a way to accept a scalar or a vector as the argument, turn it into a vector (at least 1-D) so that an algorithm works on it, and finally return either a scalar or a vector according to the argument type. The function should also be compatible with `jacrev`. Here is a snippet that illustrates a possible solution (n.b. the "algo" here would of course be more complicated in real life; think of a `vmap`, for instance, as a possible reason to convert the scalar to a vector):

```python
import jax.numpy as jnp

def f(x):
    nd = jnp.ndim(x)       # 0 for a scalar, 1 for a vector
    x = jnp.atleast_1d(x)  # promote scalars to shape (1,)
    res = x**3             # here a very simple algo
    if nd == 0:
        return res[0]      # give back a scalar for scalar input
    else:
        return res
```

One can figure out that
So far so good, but my solution is certainly neither the most elegant nor the most robust. If you find something better, I'd be grateful.
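As a quick check of the approach above (a minimal sketch that assumes nothing beyond the standard `jax.jit` and `jax.jacrev` transforms), the same `f` can be run under both. Because `jnp.ndim(x)` only reads the array's shape, which is static while tracing, the Python `if` is resolved at trace time rather than on traced values, so no `lax.cond` is needed.

```python
import jax
import jax.numpy as jnp


def f(x):
    # jnp.ndim reads the static shape, so nd is a plain Python int at trace time
    nd = jnp.ndim(x)
    x = jnp.atleast_1d(x)
    res = x**3  # placeholder for the real algorithm
    # Python-level branch on a static value: allowed under jit and jacrev
    if nd == 0:
        return res[0]
    return res


f_jit = jax.jit(f)
scalar_out = f_jit(2.0)                    # scalar in, scalar out
vector_out = f_jit(jnp.array([1.0, 2.0]))  # vector in, vector out
d_scalar = jax.jacrev(f)(2.0)              # derivative 3 * x**2 at x = 2
jac_vector = jax.jacrev(f)(jnp.array([1.0, 2.0]))  # 2x2 diagonal Jacobian
print(scalar_out.shape, vector_out.shape)
print(d_scalar, jac_vector.shape)
```

Note that `jit` retraces once per input shape, so the scalar and vector calls each get their own compiled version with the branch already fixed.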
-
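In internal JAX code, I often use a pattern that looks something like this:

```python
import jax.numpy as jnp

def f(x):
    out_shape = jnp.shape(x)       # () for a scalar, (n,) for a vector
    x = jnp.atleast_1d(x)          # work on at least a 1-D array
    res = x**3
    return res.reshape(out_shape)  # restore the caller's original shape
```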
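To connect this with the `vmap` case mentioned in the question, here is a minimal sketch of the same reshape pattern wrapped around a `vmap`-based algorithm. `algo_1d` is a hypothetical stand-in for the real computation; the only assumption is that it maps a 1-D array to a 1-D array of the same length. The sketch also checks that `jax.jacrev` returns a scalar derivative for scalar input and a square Jacobian for vector input.

```python
import jax
import jax.numpy as jnp


def algo_1d(x):
    # Hypothetical algorithm that requires a 1-D array:
    # vmap applies a per-element function along the leading axis
    return jax.vmap(lambda xi: xi**3)(x)


def f(x):
    out_shape = jnp.shape(x)       # () for a scalar, (n,) for a vector
    x = jnp.atleast_1d(x)          # promote scalars to shape (1,)
    res = algo_1d(x)
    return res.reshape(out_shape)  # (1,) -> () for scalars; no-op for vectors


scalar_out = f(2.0)
vector_out = f(jnp.array([1.0, 2.0, 3.0]))
jac_scalar = jax.jacrev(f)(2.0)
jac_vector = jax.jacrev(f)(jnp.array([1.0, 2.0, 3.0]))
print(scalar_out.shape, vector_out.shape)  # () (3,)
print(jac_scalar.shape, jac_vector.shape)  # () (3, 3)
```

Compared with branching on `jnp.ndim`, the reshape avoids a Python `if` altogether: the output shape is simply copied from the input, which keeps the function a single code path.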