@@ -59,11 +59,8 @@ def extract_helper_function(helper_fn: object) -> types.FunctionType:
59
59
60
60
61
61
CombineFunctionBasic = Callable [[torch .Tensor , torch .Tensor ], torch .Tensor ]
62
- CombineFunctionTuple = Callable [
63
- [tuple [torch .Tensor , ...], tuple [torch .Tensor , ...]], tuple [torch .Tensor , ...]
64
- ]
65
- CombineFunctionUnpacked = Callable [[torch .Tensor , ...], tuple [torch .Tensor , ...]] # pyright: ignore[reportInvalidTypeForm]
66
- CombineFunction = CombineFunctionBasic | CombineFunctionTuple | CombineFunctionUnpacked
62
+ CombineFunctionTuple = Callable [..., tuple [torch .Tensor , ...]]
63
+ CombineFunction = CombineFunctionBasic | CombineFunctionTuple
67
64
68
65
69
66
def create_combine_function_wrapper (
@@ -104,6 +101,7 @@ def create_combine_function_wrapper(
104
101
# If the original format matches target format, no conversion needed
105
102
if target_format == original_format :
106
103
return actual_fn
104
+ combine_fn = cast ("CombineFunctionTuple" , combine_fn )
107
105
108
106
# Create conversion wrapper
109
107
if target_format == "tuple" and original_format == "unpacked" :
@@ -113,11 +111,8 @@ def create_combine_function_wrapper(
113
111
def tuple_wrapper (
114
112
left_tuple : tuple [torch .Tensor , ...], right_tuple : tuple [torch .Tensor , ...]
115
113
) -> tuple [torch .Tensor , ...]:
116
- return inner_unpacked (* left_tuple , * right_tuple )
114
+ return combine_fn (* left_tuple , * right_tuple )
117
115
118
- inner_unpacked : CombineFunctionUnpacked = cast (
119
- "CombineFunctionUnpacked" , actual_fn
120
- )
121
116
return tuple_wrapper
122
117
123
118
if target_format == "unpacked" and original_format == "tuple" :
@@ -130,9 +125,8 @@ def unpacked_wrapper(*args: torch.Tensor) -> tuple[torch.Tensor, ...]:
130
125
half = num_args // 2
131
126
left_tuple = args [:half ]
132
127
right_tuple = args [half :]
133
- return inner_tuple ((* left_tuple ,), (* right_tuple ,))
128
+ return combine_fn ((* left_tuple ,), (* right_tuple ,))
134
129
135
- inner_tuple : CombineFunctionTuple = cast ("CombineFunctionTuple" , actual_fn )
136
130
return unpacked_wrapper
137
131
138
132
# Should not reach here
0 commit comments