
torch_xla scan forces inputs to be differentiable #8783

Open

Description

@tengyifei

The snippet

# Make some fake tensors to trace the user function and obtain the
# forward and backward graphs. Note that the init/carry fake tensor
# always requires grad. That's because even if the user passed in some
# `init` that does not require grad, we still want gradients to flow
# through the `carry` from one iteration of the user function to the
# next. In summary, the `carry` argument used to trace a user function
# to get a correct backward pass always requires grad.
def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor:
  return torch.empty_like(
      v, dtype=v.dtype, device=v.device, requires_grad=requires_grad)
is probably wrong. It sets requires_grad=True on all carry inputs, and that won't work if one of the carry tensors is a LongTensor.
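To make the failure concrete, here is a minimal repro (not taken from the issue; the error message is paraphrased from PyTorch's): asking for a fake integer tensor with requires_grad=True raises immediately.

```python
import torch

# A LongTensor carry, as it would reach make_fake_tensor.
v = torch.zeros(3, dtype=torch.int64)

# This is what the snippet above effectively does for every carry input:
torch.empty_like(v, dtype=v.dtype, device=v.device, requires_grad=True)
# RuntimeError: Only Tensors of floating point and complex dtype can require gradients
```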

The most obvious example: if one of the inputs is an integer tensor, it can't possibly have gradients.
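One possible direction, sketched here as an illustration rather than a vetted patch: request gradients only for dtypes that support them, so integer carries are traced without requires_grad while the floating-point carry behavior described in the comment above is preserved.

```python
import torch

def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
  # Gradients should still flow through floating-point (and complex) carries
  # even when the user's `init` does not require grad, but integer/bool
  # tensors cannot require grad at all, so skip them.
  requires_grad = v.is_floating_point() or v.is_complex()
  return torch.empty_like(
      v, dtype=v.dtype, device=v.device, requires_grad=requires_grad)
```

Whether the backward pass should then treat non-differentiable carries as constants is a separate question this sketch does not answer.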

Labels

bug (Something isn't working), good first issue (Good for newcomers)
