Skip to content

Commit 2ddab75

Browse files
authored
Relax typing for CombineFunction (#297)
1 parent 03b58ca commit 2ddab75

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

helion/_compiler/helper_function.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,8 @@ def extract_helper_function(helper_fn: object) -> types.FunctionType:
5959

6060

6161
CombineFunctionBasic = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
62-
CombineFunctionTuple = Callable[
63-
[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]
64-
]
65-
CombineFunctionUnpacked = Callable[[torch.Tensor, ...], tuple[torch.Tensor, ...]] # pyright: ignore[reportInvalidTypeForm]
66-
CombineFunction = CombineFunctionBasic | CombineFunctionTuple | CombineFunctionUnpacked
62+
CombineFunctionTuple = Callable[..., tuple[torch.Tensor, ...]]
63+
CombineFunction = CombineFunctionBasic | CombineFunctionTuple
6764

6865

6966
def create_combine_function_wrapper(
@@ -104,6 +101,7 @@ def create_combine_function_wrapper(
104101
# If the original format matches target format, no conversion needed
105102
if target_format == original_format:
106103
return actual_fn
104+
combine_fn = cast("CombineFunctionTuple", combine_fn)
107105

108106
# Create conversion wrapper
109107
if target_format == "tuple" and original_format == "unpacked":
@@ -113,11 +111,8 @@ def create_combine_function_wrapper(
113111
def tuple_wrapper(
114112
left_tuple: tuple[torch.Tensor, ...], right_tuple: tuple[torch.Tensor, ...]
115113
) -> tuple[torch.Tensor, ...]:
116-
return inner_unpacked(*left_tuple, *right_tuple)
114+
return combine_fn(*left_tuple, *right_tuple)
117115

118-
inner_unpacked: CombineFunctionUnpacked = cast(
119-
"CombineFunctionUnpacked", actual_fn
120-
)
121116
return tuple_wrapper
122117

123118
if target_format == "unpacked" and original_format == "tuple":
@@ -130,9 +125,8 @@ def unpacked_wrapper(*args: torch.Tensor) -> tuple[torch.Tensor, ...]:
130125
half = num_args // 2
131126
left_tuple = args[:half]
132127
right_tuple = args[half:]
133-
return inner_tuple((*left_tuple,), (*right_tuple,))
128+
return combine_fn((*left_tuple,), (*right_tuple,))
134129

135-
inner_tuple: CombineFunctionTuple = cast("CombineFunctionTuple", actual_fn)
136130
return unpacked_wrapper
137131

138132
# Should not reach here

0 commit comments

Comments
 (0)