
Check the scan op input for requires_grad #9083


Open · wants to merge 6 commits into master
Conversation

@haifeng-jin (Collaborator) commented May 2, 2025

Resolves #8783

Added a test for LongTensor inputs, which would fail without this PR.
Only set carry.requires_grad to True when the dtype is floating point.

Return None, None from the backward() function if none of the outputs has a gradient. (A sketch of both changes follows below.)
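
A minimal sketch of the two changes, assuming a simplified scan written as a torch.autograd.Function; the real torch_xla implementation differs, and `Scan` and the add-based combine step here are illustrative only:

```python
import torch

class Scan(torch.autograd.Function):
    @staticmethod
    def forward(ctx, init, xs):
        # Let backward() receive None (rather than zero tensors) for
        # outputs that got no gradient, so it can short-circuit.
        ctx.set_materialize_grads(False)
        carry = init.clone()
        # Only floating-point tensors may require grad; calling
        # requires_grad_(True) on a LongTensor carry would raise
        # "only Tensors of floating point dtype can require gradients".
        if carry.is_floating_point():
            carry.requires_grad_(True)
        ys = []
        for x in xs:
            carry = carry + x  # stand-in for the user's combine fn
            ys.append(carry)
        return carry, torch.stack(ys)

    @staticmethod
    def backward(ctx, grad_carry, grad_ys):
        # If neither output received a gradient, there is nothing to
        # propagate: return None for both inputs (init, xs).
        if grad_carry is None and grad_ys is None:
            return None, None
        # Full gradient computation elided; this sketch only shows the
        # short-circuit path described above.
        raise NotImplementedError("sketch only")
```

With a LongTensor init, the guard simply skips the requires_grad_ call, so e.g. `Scan.apply(torch.zeros(3, dtype=torch.long), torch.ones(4, 3, dtype=torch.long))` runs instead of raising.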

Questions:

  1. Why don't we just set carry.requires_grad to match the user-provided init? (I ran into some errors when I tried that.)
  2. Do we also need to return None for init.grad in backward() when only x requires grad?
  3. It fails when init is a LongTensor and x is float32, because scan requires carry.dtype == init.dtype while the carry is promoted to float (it is produced by int + float). Do we need to support this case? (See the sketch after this list.)
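
A minimal repro of the promotion in question 3, assuming the user's scan function adds carry and x (names here are illustrative only):

```python
import torch

init = torch.zeros(3, dtype=torch.long)  # LongTensor init
x = torch.ones(3, dtype=torch.float32)   # float32 slice of xs

carry = init + x                  # int64 + float32 promotes to float32
print(carry.dtype)                # torch.float32
print(carry.dtype == init.dtype)  # False: the carry.dtype == init.dtype check fails
```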

@haifeng-jin marked this pull request as ready for review May 6, 2025 21:44
@haifeng-jin requested a review from tengyifei May 6, 2025 21:44
@miladm requested a review from bhavya01 May 7, 2025 19:38
Successfully merging this pull request may close these issues.

torch_xla scan forces inputs to be differentiable