Commit 566045a

Use tensor device reference in persistent kernels (#317)
1 parent: 00d13b6

File tree: 3 files changed (+47, -37 lines)

helion/_compiler/program_id.py

Lines changed: 12 additions & 2 deletions
@@ -420,13 +420,23 @@ def __init__(self, is_blocked: bool = False) -> None:
             "step": NUM_SM_VAR,
         }
         if device_function.constexpr_arg(NUM_SM_VAR):
-            device = CompileEnvironment.current().device
             device_function.codegen.host_statements.append(
                 statement_from_string(
-                    f"{NUM_SM_VAR} = helion.runtime.get_num_sm(torch.{device!r})"
+                    f"{NUM_SM_VAR} = helion.runtime.get_num_sm({self.get_device_str()})"
                 )
             )

+    def get_device_str(self) -> str:
+        """Get the device string for the current device, reusing the first tensor's origin."""
+        host_function = HostFunction.current()
+        device = CompileEnvironment.current().device
+        origins = [
+            o for t, o in host_function.tensor_to_origin.items() if t.device == device
+        ]
+        if origins:
+            return f"{origins[0].host_str()}.device"
+        return f"torch.{device!r}"
+
     def codegen_grid(self) -> ast.AST:
         # Use num_sms for persistent kernels
         return expr_from_string(f"({NUM_SM_VAR},)")
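Note: the fallback logic in the new helper is self-contained enough to sketch outside the compiler. Below is a minimal reconstruction for illustration; the Origin class is a hypothetical stand-in for Helion's real origin objects (only its host_str() method, which renders an origin as host-side source such as "q_in", is assumed). The selection logic mirrors get_device_str above: prefer emitting a reference to an existing tensor argument's .device, and fall back to a hardcoded device repr only when no tracked tensor lives on the compile device.

import torch

class Origin:
    # Hypothetical stand-in for Helion's origin objects; only host_str() is assumed.
    def __init__(self, name: str) -> None:
        self.name = name

    def host_str(self) -> str:
        return self.name

def get_device_str(tensor_to_origin: dict, device: torch.device) -> str:
    # Prefer referencing an existing argument's .device so the generated
    # host code follows whatever device the caller actually passes in.
    origins = [o for t, o in tensor_to_origin.items() if t.device == device]
    if origins:
        return f"{origins[0].host_str()}.device"
    # Fall back to a hardcoded device repr when no tensor matches.
    return f"torch.{device!r}"

# Demonstrates both branches, even on a CPU-only build:
t = torch.empty(1)
print(get_device_str({t: Origin("q_in")}, t.device))  # -> q_in.device
print(get_device_str({}, torch.device("cpu")))        # -> torch.device(type='cpu')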

test/test_examples.expected

Lines changed: 1 addition & 1 deletion
@@ -259,7 +259,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     v_view = v_in.reshape([-1, n_dim, head_dim])
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
-    _NUM_SM = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
+    _NUM_SM = helion.runtime.get_num_sm(q_in.device)
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_3 = 64
     _launcher(_attention_kernel, (_NUM_SM,), q_view, k_view, v_view, out, _NUM_SM, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
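In the regenerated expected output, the host code now derives the SM count from the kernel's own input (q_in.device) rather than a repr that froze the device at compile time, so the same compiled kernel launches correctly when its inputs live on, say, cuda:1. For context, a helper with get_num_sm's assumed behavior can be written with PyTorch's device-properties API; this is a sketch of that assumption, not necessarily Helion's actual implementation:

import torch

def get_num_sm(device: torch.device) -> int:
    # Assumed behavior: report the number of streaming multiprocessors
    # (SMs), which persistent kernels use as their launch grid size.
    assert device.type == "cuda", "SM count is only defined for CUDA devices"
    return torch.cuda.get_device_properties(device).multi_processor_count

# Usage mirroring the generated host code above:
#   _NUM_SM = get_num_sm(q_in.device)
#   _launcher(_attention_kernel, (_NUM_SM,), ...)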
