@@ -127,85 +127,44 @@ def setUpClass(cls) -> None:
127
127
register_additional_test_aten_ops ()
128
128
129
129
def test_remove_mixed_type_operators(self) -> None:
    """Verify RemoveMixedTypeOperators promotes mixed-dtype binary ops.

    For add and mul modules exported with one int64 operand and one
    operand of a varying dtype, the pass must rewrite the graph so that
    every argument of the target edge op has the promoted (wider) dtype.
    """

    def count_nodes_with_target_asserting_arguments_have_dtype(
        new_graph_module, target, arg_dtype
    ):
        """Count call_function nodes matching `target`, asserting each
        of their args carries `arg_dtype` in its "val" meta."""
        count = 0
        for node in new_graph_module.graph.nodes:
            if node.op == "call_function" and node.target == target:
                count += 1
                for arg in node.args:
                    self.assertEqual(arg.meta["val"].dtype, arg_dtype)
        return count

    class Add(torch.nn.Module):
        # (x + y) + x lowers to two aten.add.Tensor nodes.
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return (x + y) + x

    class Mult(torch.nn.Module):
        # x * y lowers to a single aten.mul.Tensor node.
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x * y

    for module, op, expected_count in (
        (Add, exir_ops.edge.aten.add.Tensor, 2),
        (Mult, exir_ops.edge.aten.mul.Tensor, 1),
    ):
        # int64 x int64 / float / double: after the pass, all args of the
        # op must match the second (promoted) dtype per type-promotion rules.
        for second_arg_dtype in (torch.int64, torch.float, torch.double):
            int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)
            float_tensor = torch.tensor([[1.0, 2.0, 3.0]], dtype=second_arg_dtype)
            edge_prog = to_edge(
                export(module(), (int_tensor, float_tensor), strict=True)
            )

            new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
            new_graph_module = new_prog.exported_program().graph_module
            self.assertIsNotNone(new_graph_module)

            count = count_nodes_with_target_asserting_arguments_have_dtype(
                new_graph_module, op, second_arg_dtype
            )
            self.assertEqual(count, expected_count)
210
169
def test_remove_noop_pass (self ) -> None :
211
170
class Foo (torch .nn .Module ):
0 commit comments