-
Hello,

import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = torch.tensor([4., 5., 6.], requires_grad=True)
z = x * y
if some_dynamic_config_parameter:
    t = z.sum() / 2
else:
    t = z.sum()
t.backward()
print(x.grad, y.grad)

The motivation is that I need to differentiate through computations that are described in config files. I could dynamically produce the code of such a function. I looked around the docs and the GitHub discussions and could not find an answer to this question -- I'm probably not using the same terms that people in this community would use. Thanks in advance!
-
Thanks for the question! Perhaps I'm misunderstanding, but would something like this work for you?

import jax
import jax.numpy as jnp

def f():
    x = jnp.array([1., 2., 3.])
    y = jnp.array([4., 5., 6.])
    def g(x, y):
        z = x * y
        if some_dynamic_config_parameter:
            t = z.sum() / 2
        else:
            t = z.sum()
        return t
    return jax.grad(g, (0, 1))(x, y)

some_dynamic_config_parameter = True
print(f())
# (Array([2. , 2.5, 3. ], dtype=float32), Array([0.5, 1. , 1.5], dtype=float32))
some_dynamic_config_parameter = False
print(f())
# (Array([4., 5., 6.], dtype=float32), Array([1., 2., 3.], dtype=float32))

In other words, you parse the config file in the definition of the function you'd like to take the gradient of.
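In case it helps to see the config-driven version spelled out, here is a minimal sketch of that idea. The config contents, the build_loss name, and the operation choices are all hypothetical (not from the posts above); the point is just that ordinary Python branching on config values can live inside the function that jax.grad differentiates.

import json
import jax
import jax.numpy as jnp

# Hypothetical config; in practice this string would come from a config file on disk.
config = json.loads('{"reduction": "mean", "scale": 0.5}')

def build_loss(config):
    # Plain Python control flow over config values is resolved when jax.grad
    # traces the function, so each config yields its own differentiable computation.
    def loss(x, y):
        z = x * y
        if config["reduction"] == "mean":
            t = z.mean()
        else:
            t = z.sum()
        return config["scale"] * t
    return loss

x = jnp.array([1., 2., 3.])
y = jnp.array([4., 5., 6.])
grads = jax.grad(build_loss(config), argnums=(0, 1))(x, y)
print(grads)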
I don't see any reason why you can't just evaluate the expression directly in JAX, just as you do in PyTorch. For example, this uses the same compute function from your PyTorch example:
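(The rest of that reply is not reproduced above. As a rough, hypothetical sketch of the "evaluate it directly in JAX" idea, a compute function mirroring the question's snippet could look like the following; the name compute and its signature are assumptions, since its definition does not appear in this excerpt.)

import jax
import jax.numpy as jnp

def compute(x, y, some_dynamic_config_parameter):
    # Same arithmetic as the PyTorch snippet in the question.
    z = x * y
    if some_dynamic_config_parameter:
        t = z.sum() / 2
    else:
        t = z.sum()
    return t

x = jnp.array([1., 2., 3.])
y = jnp.array([4., 5., 6.])
# Differentiate with respect to x and y only; the config flag is passed as a plain Python bool.
dx, dy = jax.grad(compute, argnums=(0, 1))(x, y, True)
print(dx, dy)  # [2.  2.5 3. ] [0.5 1.  1.5]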