Skip to content

Commit b64bf00

Browse files
authored
Add support for print(prefix_str, *tensors) (#140)
1 parent 0e047af commit b64bf00

File tree

6 files changed

+772
-1
lines changed

6 files changed

+772
-1
lines changed

helion/_compiler/device_ir.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..autotuner.config_spec import ReductionLoopSpec
3131
from ..language import _tracing_ops
3232
from ..language._decorators import args_to_proxies
33+
from ..language._decorators import get_device_func_replacement
3334
from .ast_extension import ExtendedAST
3435
from .ast_extension import LoopType
3536
from .ast_extension import NodeVisitor
@@ -788,8 +789,17 @@ def visit_Call(self, node: ast.Call) -> object:
788789
kwargs.update(self._to_proxy(kwarg.value))
789790
else:
790791
kwargs[kwarg.arg] = self._to_proxy(kwarg.value)
792+
793+
if isinstance(
794+
(func_type_info := node.func._type_info), # pyre-ignore[16]
795+
CallableType,
796+
) and (replacement := get_device_func_replacement(func_type_info.value)):
797+
func = replacement
798+
else:
799+
func = self.visit(node.func)
800+
791801
# pyre-ignore[6]
792-
return CheckForIndexCalls.retry_call(self.visit(node.func), args, kwargs)
802+
return CheckForIndexCalls.retry_call(func, args, kwargs)
793803

794804
def visit_Attribute(self, node: ast.Attribute) -> object:
795805
return getattr(self.visit(node.value), node.attr)

helion/_compiler/type_propagation.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from .. import exc
2424
from ..autotuner.config_spec import BlockSizeSpec
25+
from ..language._decorators import get_device_func_replacement
2526
from ..language._decorators import is_api_func
2627
from .ast_extension import ExtendedAST
2728
from .ast_extension import LoopType
@@ -1828,6 +1829,14 @@ def visit_Call(self, node: ast.Call) -> TypeInfo:
18281829
# TODO(jansel): test handling if *args and **kwargs
18291830
# TODO(jansel): check for calling a Kernel here
18301831
func = self.visit(node.func)
1832+
1833+
if (
1834+
isinstance(func, CallableType)
1835+
and self.origin().is_device()
1836+
and (replacement := get_device_func_replacement(func.value))
1837+
):
1838+
func = CallableType(func.origin, replacement)
1839+
18311840
unhandled = []
18321841
args = []
18331842
kwargs = {}

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .constexpr import specialize as specialize
55
from .creation_ops import full as full
66
from .creation_ops import zeros as zeros
7+
from .device_print import device_print as device_print
78
from .loops import Tile as Tile
89
from .loops import grid as grid
910
from .loops import register_block_size as register_block_size

helion/language/_decorators.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,22 @@ def _to_proxy(arg: TypeInfo) -> object:
289289
return arg.proxy()
290290
except NotImplementedError:
291291
raise exc.TracedArgNotSupported(arg) from None
292+
293+
294+
# Tracks 1-1 mapping between Python functions and their Helion API counterparts within device function.
295+
_DEVICE_FUNC_REPLACEMENTS: dict[object, APIFunc] = {}
296+
297+
298+
def device_func_replacement(python_func: object) -> _Decorator:
299+
def _impl(fn: _C) -> _C:
300+
assert is_api_func(fn), (
301+
f"{device_func_replacement.__qualname__} can only be used on API functions"
302+
)
303+
_DEVICE_FUNC_REPLACEMENTS[python_func] = fn
304+
return fn # pyre-ignore[7]
305+
306+
return _impl
307+
308+
309+
def get_device_func_replacement(func: object) -> APIFunc | None:
310+
return _DEVICE_FUNC_REPLACEMENTS.get(func)

helion/language/device_print.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
import builtins
5+
from typing import TYPE_CHECKING
6+
7+
from torch.fx import has_side_effect
8+
9+
from .. import exc
10+
from .._compiler.ast_extension import create
11+
from .._compiler.ast_extension import expr_from_string
12+
from . import _decorators
13+
14+
if TYPE_CHECKING:
15+
from .._compiler.inductor_lowering import CodegenState
16+
17+
18+
@has_side_effect
19+
@_decorators.device_func_replacement(builtins.print)
20+
@_decorators.api(is_device_only=False)
21+
def device_print(prefix: str, *values: object) -> None:
22+
"""
23+
Print values from device code.
24+
25+
:param prefix: A string prefix for the print statement
26+
:param values: Tensor values to print
27+
"""
28+
raise exc.NotInsideKernel
29+
30+
31+
@_decorators.register_fake(device_print)
32+
def _(*values: object, sep: str = " ", end: str = "\n") -> None:
33+
return None
34+
35+
36+
@_decorators.type_propagation(device_print)
37+
def _(*args: object, origin: object, **kwargs: object) -> object:
38+
from .._compiler.type_propagation import LiteralType
39+
from .._compiler.type_propagation import NoType
40+
from .._compiler.type_propagation import TensorType
41+
42+
# Check that we have at least one argument (prefix)
43+
if len(args) == 0:
44+
raise ValueError("print() requires at least one argument (prefix)")
45+
46+
# First argument must be the prefix string
47+
if not (isinstance(args[0], LiteralType) and isinstance(args[0].value, str)):
48+
raise TypeError(
49+
f"First argument to print() must be a string prefix, got {args[0]}"
50+
)
51+
52+
# For compile-time values like tensor shapes, we should error out
53+
for i, arg in enumerate(args[1:]):
54+
if not isinstance(arg, TensorType):
55+
raise TypeError(
56+
f"print() only supports runtime tensor values. "
57+
f"Argument {i + 1} is {arg}, not a tensor. "
58+
f"Compile-time values like tensor shapes are not supported yet."
59+
)
60+
61+
return NoType(origin=origin)
62+
63+
64+
# pyre-fixme[56]
65+
@_decorators.codegen(device_print)
66+
def _(state: CodegenState) -> None:
67+
prefix = state.proxy_arg(0)
68+
call_args = [create(ast.Constant, value=prefix)]
69+
70+
# Handle varargs
71+
if len(state.proxy_args) > 1:
72+
assert len(state.ast_args) > 1
73+
ast_varargs = state.ast_args[1]
74+
call_args.extend(ast_varargs[0]) # pyre-fixme[16]
75+
76+
call_expr = create(
77+
ast.Call,
78+
func=expr_from_string("tl.device_print"),
79+
args=call_args,
80+
keywords=[],
81+
)
82+
stmt = create(ast.Expr, value=call_expr)
83+
state.add_statement(stmt)

0 commit comments

Comments
 (0)