@@ -9,7 +9,7 @@
 import os
 import tempfile
 import unittest
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import executorch.exir as exir
 
@@ -71,6 +71,7 @@
 from functorch.experimental import control_flow
 
 from torch import nn
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.export import export
 from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
 from torch.fx import GraphModule, subgraph_rewriter
@@ -121,39 +122,97 @@ def foo_out(
     return a + 1, None
 
 
+def simple_promote_dtype(
+    dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
+) -> torch.dtype:
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
+        return dtype
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
+        return dtype if dtype.is_floating_point else torch.float
+    else:
+        raise Exception(f"Unsupported promotion kind {promotion_kind}")
+
+
+def count_nodes_with_target_asserting_arguments_have_dtype(
+    self, module, target, arg_dtype
+) -> int:
+    count = 0
+    for node in module.graph.nodes:
+        if node.op == "call_function" and node.target == target:
+            count += 1
+            for arg in node.args:
+                self.assertEqual(arg.meta["val"].dtype, arg_dtype)
+    return count
+
+
 class TestPasses(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         register_additional_test_aten_ops()
 
     def test_remove_mixed_type_operators(self) -> None:
-        def count_nodes_with_target_asserting_arguments_have_dtype(
-            new_graph_module, target, arg_dtype
-        ):
-            count = 0
-            for node in new_graph_module.graph.nodes:
-                if node.op == "call_function" and node.target == target:
-                    count += 1
-                    for arg in node.args:
-                        self.assertEqual(arg.meta["val"].dtype, arg_dtype)
-            return count
-
-        class Add(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return (x + y) + x
-
-        class Mult(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x * y
-
-        class Minimum(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return torch.minimum(x, y)
+        def make_module(fwd: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
+            class Module(torch.nn.Module):
+                def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                    return fwd(x, y)
+
+            return Module
+
+        Add = make_module(lambda x, y: (x + y) + x)
+        Mult = make_module(lambda x, y: x * y)
+        Minimum = make_module(torch.minimum)
+        DivWithoutMode = make_module(torch.div)
+        DivWithNoneMode = make_module(lambda x, y: torch.div(x, y, rounding_mode=None))
+        DivWithTruncMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="trunc")
+        )
+        DivWithFloorMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="floor")
+        )
 
-        for module, op, expected_count in (
-            (Add, exir_ops.edge.aten.add.Tensor, 2),
-            (Mult, exir_ops.edge.aten.mul.Tensor, 1),
-            (Minimum, exir_ops.edge.aten.minimum.default, 1),
+        for module, op, expected_count, promotion_kind in (
+            (
+                Add,
+                exir_ops.edge.aten.add.Tensor,
+                2,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Mult,
+                exir_ops.edge.aten.mul.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Minimum,
+                exir_ops.edge.aten.minimum.default,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithoutMode,
+                exir_ops.edge.aten.div.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithNoneMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithTruncMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithFloorMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
         ):
             for second_arg_dtype in (torch.int64, torch.float, torch.double):
                 int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)
@@ -166,8 +225,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                 new_graph_module = new_prog.exported_program().graph_module
                 self.assertIsNotNone(new_graph_module)
 
+                promoted_type = simple_promote_dtype(second_arg_dtype, promotion_kind)
                 count = count_nodes_with_target_asserting_arguments_have_dtype(
-                    new_graph_module, op, second_arg_dtype
+                    self, new_graph_module, op, promoted_type
                 )
                 self.assertEqual(count, expected_count)
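For context, the promotion rules that the new `simple_promote_dtype` helper encodes can be checked directly against eager PyTorch. A minimal standalone sketch (tensor names here are illustrative, not part of the test):

import torch

# True division of two integer tensors uses INT_TO_FLOAT promotion,
# landing on the default float dtype (float32).
int_a = torch.tensor([[1, 2, 3]], dtype=torch.int64)
int_b = torch.tensor([[2, 2, 2]], dtype=torch.int64)
assert torch.div(int_a, int_b).dtype == torch.float32

# With rounding_mode="trunc" or "floor", promotion is DEFAULT,
# so two int64 inputs produce an int64 result.
assert torch.div(int_a, int_b, rounding_mode="trunc").dtype == torch.int64
assert torch.div(int_a, int_b, rounding_mode="floor").dtype == torch.int64

# Mixed int/float inputs promote to the floating dtype under either kind,
# which is why the test loops second_arg_dtype over int64, float, and double.
dbl = torch.tensor([[2.0, 2.0, 2.0]], dtype=torch.double)
assert torch.div(int_a, dbl).dtype == torch.float64
assert (int_a + dbl).dtype == torch.float64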