Labels: bug (Something isn't working), good first issue (Good for newcomers)
Description
The snippet
xla/torch_xla/experimental/scan.py
Lines 217 to 226 in 00fac78
```python
# Make some fake tensors to trace the user function and obtain the
# forward and backward graphs. Note that the init/carry fake tensor
# always requires grad. That's because even if the user passed in some
# `init` that does not require grad, we still want gradients to flow
# through the `carry` from one iteration of the user function to the
# next. In summary, the `carry` argument used to trace a user function
# to get a correct backward pass always requires grad.
def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor:
  return torch.empty_like(
      v, dtype=v.dtype, device=v.device, requires_grad=requires_grad)
```
sets `requires_grad=True` on all carry inputs, and that won't work if one of the carries is a `LongTensor`.
The most obvious example: if one of the inputs is an integer tensor, it can't possibly have gradients.
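A quick way to reproduce the failure, followed by a hedged sketch of one possible fix. The `make_fake_tensor` variant below is a hypothetical illustration, not the upstream implementation: it only requests grad for dtypes that actually support autograd.

```python
import torch

# Reproducing the failure: requesting grad on an integer tensor raises,
# because only floating point and complex dtypes can require gradients.
long_carry = torch.zeros(3, dtype=torch.long)
try:
    torch.empty_like(long_carry, requires_grad=True)
except RuntimeError as e:
    print(e)


# Hypothetical fix sketch: gate requires_grad on dtype support, so float
# carries still get gradients while integer carries are left alone.
def make_fake_tensor(v: torch.Tensor, requires_grad: bool = True) -> torch.Tensor:
    supports_grad = v.dtype.is_floating_point or v.dtype.is_complex
    return torch.empty_like(
        v, dtype=v.dtype, device=v.device,
        requires_grad=requires_grad and supports_grad)


float_fake = make_fake_tensor(torch.zeros(3))
long_fake = make_fake_tensor(long_carry)
print(float_fake.requires_grad, long_fake.requires_grad)  # True False
```

Note that silently dropping `requires_grad` on integer carries changes the traced backward graph for those inputs, so an alternative is to raise a clearer error up front; either way, the unconditional `requires_grad=True` in the current snippet is the problem.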