-
Hello!

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as Mod

mod = Mod(lambda x: (x + 1).mean(), in_keys=["a"], out_keys=["b"])
td = TensorDict(a=torch.randn(10, 11, requires_grad=True), batch_size=[10])

# vmap the module over the batch dimension
vmap_mod = torch.vmap(mod, (0,))
td_out = vmap_mod(td)
print(td_out)

# grad of the scalar output "b" w.r.t. the input tensordict
grad = torch.func.grad(lambda td: mod(td)["b"])
print(grad(td[0])["a"])

# vmap the grad over the batch
vmap_grad = torch.vmap(grad, (0,))
td_out = vmap_grad(td)
print(td_out["a"])
```

So vmapping a grad works (gradding a vmap doesn't, I believe). Is there a version of this script that explains what you're trying to do and where it breaks?
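For reference, the same vmap-over-grad pattern also works on plain tensors, and `argnums` lets you choose which positional argument to differentiate. A minimal sketch (the `energy` function and its names here are hypothetical stand-ins, not the asker's model):

```python
import torch

# Hypothetical per-sample "energy": a scalar function of a distance vector.
def energy(distance, weight):
    return (weight * distance).sum() ** 2

# Differentiate w.r.t. argument 0 (the distance), holding the weight fixed.
denergy_ddistance = torch.func.grad(energy, argnums=0)

distances = torch.randn(8, 3)  # batch of 8 distance vectors
weight = torch.randn(3)

# vmap over the batch dimension of `distances` only; `weight` is shared.
per_sample_grads = torch.vmap(denergy_ddistance, in_dims=(0, None))(distances, weight)
print(per_sample_grads.shape)  # torch.Size([8, 3])
```

`in_dims=(0, None)` is what maps only the batched argument, which is one way around "vmap can only specify the position of arguments".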
-
Hi everyone,
I am trying to move my code to tensordict (td) and vmap along the batch dimension, following https://discuss.pytorch.org/t/how-to-apply-vmap-on-a-heterogeneous-tensor/214109/5. Since vmap does not support batched tensordicts, for now I just use batch_size = 1.
Everything works very well until I need to rewrite the derivative module. I have a td with distances between atoms, I predict the energy with a model, and then I need to compute the force by differentiating the energy w.r.t. the distance.
Before, I used torch.autograd.grad, which was very straightforward. But within vmap it raises: "element 0 of tensors does not require grad and does not have a grad_fn". I find that the BatchedTensor of the energy has a grad_fn, but for the "vmapped" tensor, the underlying real tensor has requires_grad = False. Apparently(?) we cannot combine vmap with torch.autograd.grad, but we can with torch.func.grad, according to https://discuss.pytorch.org/t/use-vmap-and-grad-to-calculate-gradients-for-one-layer-independently-for-each-input-in-batch/187556/2?u=roy-kid and https://discuss.pytorch.org/t/simple-use-case-compete-per-sample-gradient-with-autograd/207317/4?u=roy-kid.
Since torch.func.vmap can only specify the position of arguments, how can I differentiate energy w.r.t. distance? Here is my pseudo code:
So thanks for your help!
Also posted in the forum: https://discuss.pytorch.org/t/combine-vmap-func-grad-with-tensordict/215086?u=roy-kid
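A minimal sketch of the force = -dE/d(distance) pattern with torch.func.grad inside vmap, on plain tensors (the `energy_fn` here is a hypothetical stand-in for the real energy model):

```python
import torch

# Hypothetical stand-in for the energy model: any scalar-valued
# function of the interatomic distances works here.
def energy_fn(distance):
    return distance.pow(2).sum()

# Force is minus the gradient of the energy w.r.t. the distances.
force_fn = lambda d: -torch.func.grad(energy_fn)(d)

distances = torch.randn(4, 10)   # batch of 4 samples, 10 distances each
forces = torch.vmap(force_fn)(distances)
print(forces.shape)              # torch.Size([4, 10])
```

The key point is that the gradient is taken with torch.func.grad (functional, composable) rather than torch.autograd.grad, so it can be wrapped by vmap without hitting the "does not require grad" error.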