
torch_xla scan forces inputs to be differentiable #8783


Open
tengyifei opened this issue Mar 4, 2025 · 1 comment · May be fixed by #9083
Labels: bug (Something isn't working), good first issue (Good for newcomers)

Comments

tengyifei (Collaborator) commented Mar 4, 2025

The snippet

# Make some fake tensors to trace the user function and obtain the
# forward and backward graphs. Note that the init/carry fake tensor
# always requires grad. That's because even if the user passed in some
# `init` that does not require grad, we still want gradients to flow
# through the `carry` from one iteration of the user function to the
# next. In summary, the `carry` argument used to trace a user function
# to get a correct backward pass always requires grad.
def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor:
  return torch.empty_like(
      v, dtype=v.dtype, device=v.device, requires_grad=requires_grad)
is probably wrong. It sets requires_grad=True on every carry input, and that won't work if one of the carry tensors is a LongTensor.

The most obvious example: if one of the inputs is an integer tensor, it can't possibly have gradients.
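As a concrete illustration (a minimal sketch in plain PyTorch, outside of scan), asking for gradients on an integer tensor raises an error, so make_fake_tensor would fail on such a carry:

import torch

idx = torch.zeros(4, dtype=torch.long)
# Raises: RuntimeError: Only Tensors of floating point and complex dtype
# can require gradients
fake = torch.empty_like(idx, dtype=idx.dtype, device=idx.device, requires_grad=True)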

@tengyifei tengyifei self-assigned this Mar 4, 2025
@ysiraichi ysiraichi added the bug Something isn't working label Mar 5, 2025
tengyifei (Collaborator, Author) commented Mar 17, 2025

After consulting the LLMs, it looks like PyTorch autograd will discard any gradient we return from backward() if the corresponding tensor is not supposed to require gradients (we should write a test for this). That's why the current logic "works". The fix is to set requires_grad=True only when the tensor is a floating-point one. We should also return None for any input tensors that do not require gradients, to keep things clean.
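A minimal sketch of that direction, assuming only the make_fake_tensor helper shown above needs to change (the None-gradient cleanup in the backward pass is not shown):

def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor:
  # Only floating-point tensors can require grad (complex dtypes can too,
  # but are not considered here). Integer carries such as LongTensors must
  # be traced with requires_grad=False.
  requires_grad = requires_grad and v.is_floating_point()
  return torch.empty_like(
      v, dtype=v.dtype, device=v.device, requires_grad=requires_grad)

With this, a LongTensor carry is traced without gradients instead of raising, while floating-point carries keep the requires_grad=True behavior that lets gradients flow between iterations.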

@tengyifei tengyifei changed the title torch_xla scan forces inputs to have gradients torch_xla scan forces inputs to be differentiable Mar 17, 2025
@tengyifei tengyifei added the good first issue Good for newcomers label Apr 22, 2025
@haifeng-jin haifeng-jin self-assigned this Apr 30, 2025
@haifeng-jin haifeng-jin linked a pull request May 2, 2025 that will close this issue