Skip to content

Commit 71522c4

Browse files
authored
Clean up TestPasses.test_remove_mixed_type_operators (#12155)
I want to add more tests here. First, clean up the existing coverage.
1 parent 6669637 commit 71522c4

File tree

1 file changed

+28
-69
lines changed

1 file changed

+28
-69
lines changed

exir/tests/test_passes.py

Lines changed: 28 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -127,85 +127,44 @@ def setUpClass(cls) -> None:
127127
register_additional_test_aten_ops()
128128

129129
def test_remove_mixed_type_operators(self) -> None:
130+
def count_nodes_with_target_asserting_arguments_have_dtype(
131+
new_graph_module, target, arg_dtype
132+
):
133+
count = 0
134+
for node in new_graph_module.graph.nodes:
135+
if node.op == "call_function" and node.target == target:
136+
count += 1
137+
for arg in node.args:
138+
self.assertEqual(arg.meta["val"].dtype, arg_dtype)
139+
return count
140+
130141
class Add(torch.nn.Module):
131142
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
132143
return (x + y) + x
133144

134-
add = Add()
135-
136-
int_tensor = torch.tensor([[1, 2, 3]])
137-
float_tensor = torch.tensor([[1.0, 2.0, 3.0]])
138-
edge_prog = to_edge(export(add, (int_tensor, float_tensor), strict=True))
139-
140-
new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
141-
new_graph_module = new_prog.exported_program().graph_module
142-
self.assertIsNotNone(new_graph_module)
143-
144-
add_count = 0
145-
146-
for node in new_graph_module.graph.nodes:
147-
if (
148-
node.op == "call_function"
149-
and node.target == exir_ops.edge.aten.add.Tensor
150-
):
151-
add_count += 1
152-
node_args = node.args
153-
for arg in node_args:
154-
self.assertEqual(arg.meta["val"].dtype, torch.float)
155-
156-
self.assertEqual(add_count, 2)
157-
158-
double_tensor = torch.tensor([[1.0, 2.0, 3.0]])
159-
double_tensor = double_tensor.to(torch.double)
160-
161-
double_prog = to_edge(export(add, (int_tensor, double_tensor), strict=True))
162-
163-
double_prog.transform([RemoveMixedTypeOperators()])
164-
new_graph_module_double = double_prog.exported_program().graph_module
165-
self.assertIsNotNone(new_graph_module_double)
166-
167-
add_count_double = 0
168-
169-
for node in new_graph_module_double.graph.nodes:
170-
if (
171-
node.op == "call_function"
172-
and node.target == exir_ops.edge.aten.add.Tensor
173-
):
174-
add_count_double += 1
175-
node_args = node.args
176-
for arg in node_args:
177-
self.assertEqual(arg.meta["val"].dtype, torch.double)
178-
179-
self.assertEqual(add_count_double, 2)
180-
181145
class Mult(torch.nn.Module):
182146
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
183147
return x * y
184148

185-
mult = Mult()
186-
187-
float_tensor_vert = float_tensor.T
188-
mult_prog = to_edge(export(mult, (int_tensor, float_tensor_vert), strict=True))
189-
190-
# graph_module_mult.graph.print_tabular()
191-
192-
mult_prog = mult_prog.transform([RemoveMixedTypeOperators()])
193-
new_graph_module_mult = mult_prog.exported_program().graph_module
194-
self.assertIsNotNone(new_graph_module_mult)
149+
for module, op, expected_count in (
150+
(Add, exir_ops.edge.aten.add.Tensor, 2),
151+
(Mult, exir_ops.edge.aten.mul.Tensor, 1),
152+
):
153+
for second_arg_dtype in (torch.int64, torch.float, torch.double):
154+
int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)
155+
float_tensor = torch.tensor([[1.0, 2.0, 3.0]], dtype=second_arg_dtype)
156+
edge_prog = to_edge(
157+
export(module(), (int_tensor, float_tensor), strict=True)
158+
)
195159

196-
mult_count = 0
160+
new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
161+
new_graph_module = new_prog.exported_program().graph_module
162+
self.assertIsNotNone(new_graph_module)
197163

198-
for node in new_graph_module_mult.graph.nodes:
199-
if (
200-
node.op == "call_function"
201-
and node.target == exir_ops.edge.aten.mul.Tensor
202-
):
203-
mult_count += 1
204-
node_args = node.args
205-
for arg in node_args:
206-
self.assertEqual(arg.meta["val"].dtype, torch.float)
207-
208-
self.assertEqual(mult_count, 1)
164+
count = count_nodes_with_target_asserting_arguments_have_dtype(
165+
new_graph_module, op, second_arg_dtype
166+
)
167+
self.assertEqual(count, expected_count)
209168

210169
def test_remove_noop_pass(self) -> None:
211170
class Foo(torch.nn.Module):

0 commit comments

Comments
 (0)