-
Notifications
You must be signed in to change notification settings - Fork 15
[Ref Mode] PyTorch reference mode (eager only) #339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: yf225/stack/39
Are you sure you want to change the base?
Conversation
stack-info: PR: #339, branch: yf225/stack/34
helion/ref/hl_patch.py
Outdated
) | ||
|
||
# Step 3: Handle block_size (in ref mode, full dim size is always used as block_size) | ||
block_size_list = [None] * len(end_list) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Always use full dim size in ref modes regardless of block_size
value
examples/concatenate.py
Outdated
x_part = hl.load( | ||
x, [tile0, tile1], extra_mask=(tile1.index < x.size(1))[None, :] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we are treating tile
as a Python slice object in ref mode, tile.index
no longer works and we have to use hl.tile_index()
.
To make the UX better, in a follow-up PR I am thinking of adding a RefTile
class that Dynamo can understand, and support tile APIs like .index
/ .begin
/ .end
in that class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should make tile.index work. Maybe rather than changing the examples we should skip tests in reference mode.
test/test_ref_compile.py
Outdated
class TestExamplesRefCompile(test_examples.TestExamples): | ||
"""Run all TestExamples tests in reference torch.compile mode via HELION_REF_COMPILE=1.""" | ||
|
||
# NOTE: All tests in TestExamples are run in ref torch.compile(fullgraph=True) mode by default in this test file. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently all examples in TestExamples
pass with ref eager mode and ref compile mode.
Planning to add more ref mode unit tests to cover test_reduce.py
, test_associative_scan.py
etc. in the next PR.
Stacked PRs:
[Ref Mode] PyTorch reference mode (eager only)
Part of #77.
Please see inline code comments on the PR.