Commit d952326

fix type promotion for div in RemoveMixedTypeOperators (#12157)
The promotion strategy depends on the rounding mode (see the div decomp in PyTorch, https://github.com/pytorch/pytorch/blob/main/torch/_refs/__init__.py#L1214, and the promotion annotations on the true_divide/trunc_divide/floor_divide functions it calls). I also had to restructure the test so that lint would not complain it was too complex.
1 parent ef3cefe commit d952326
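
As a quick eager-mode sanity check (mine, not part of this commit), the mode-dependent promotion is visible directly in PyTorch:

import torch

a = torch.tensor([7, 8, 9], dtype=torch.int64)
b = torch.tensor([2, 2, 2], dtype=torch.int64)

# No rounding mode (or rounding_mode=None) is true division:
# integer inputs are promoted to float (INT_TO_FLOAT).
print(torch.div(a, b).dtype)                         # torch.float32
print(torch.div(a, b, rounding_mode=None).dtype)     # torch.float32

# trunc/floor division uses default promotion: int64 inputs stay int64.
print(torch.div(a, b, rounding_mode="trunc").dtype)  # torch.int64
print(torch.div(a, b, rounding_mode="floor").dtype)  # torch.int64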

2 files changed: 97 additions, 29 deletions

exir/passes/remove_mixed_type_operators.py

Lines changed: 9 additions & 1 deletion
@@ -23,12 +23,20 @@ def call_operator(self, op, args, kwargs, meta: NodeMetadata):  # noqa: C901
         promotion_type_allow_list = {
             torch.ops.aten.add.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
             torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
-            torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            # The correct promotion for div depends on the mode! If there is no mode,
+            # it's INT_TO_FLOAT, otherwise it's default.
+            torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            torch.ops.aten.div.Tensor_mode: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
             torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
         }
 
         if op in promotion_type_allow_list:
             promotion_kind = promotion_type_allow_list[op]
+            if (
+                op == torch.ops.aten.div.Tensor_mode
+                and kwargs.get("rounding_mode") is None
+            ):
+                promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
         else:
             # Not in allow list, do nothing
             return super().call_operator(op, args, kwargs, meta)
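
To see what the two promotion kinds in the allow list mean concretely, here is a minimal sketch using torch._prims_common.elementwise_dtypes, the helper family the PyTorch refs/decomps build on; the exact calls below are my illustration, not code from this commit. It returns a (computation_dtype, result_dtype) pair:

import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes

i = torch.tensor([1, 2], dtype=torch.int64)
f = torch.tensor([1.0, 2.0], dtype=torch.float32)

# DEFAULT: ordinary elementwise promotion; int64 with float32 gives float32,
# int64 with int64 stays int64.
print(elementwise_dtypes(i, f, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT))
print(elementwise_dtypes(i, i, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT))

# INT_TO_FLOAT: integer inputs are additionally promoted to the default float
# dtype, which is what true division (no rounding mode) requires.
print(elementwise_dtypes(i, i, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT))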

exir/tests/test_passes.py

Lines changed: 88 additions & 28 deletions
@@ -9,7 +9,7 @@
 import os
 import tempfile
 import unittest
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import executorch.exir as exir
 
@@ -71,6 +71,7 @@
 from functorch.experimental import control_flow
 
 from torch import nn
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.export import export
 from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
 from torch.fx import GraphModule, subgraph_rewriter
@@ -121,39 +122,97 @@ def foo_out(
     return a + 1, None
 
 
+def simple_promote_dtype(
+    dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
+) -> torch.dtype:
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
+        return dtype
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
+        return dtype if dtype.is_floating_point else torch.float
+    else:
+        raise Exception(f"Unsupported promotion kind {promotion_kind}")
+
+
+def count_nodes_with_target_asserting_arguments_have_dtype(
+    self, module, target, arg_dtype
+) -> int:
+    count = 0
+    for node in module.graph.nodes:
+        if node.op == "call_function" and node.target == target:
+            count += 1
+            for arg in node.args:
+                self.assertEqual(arg.meta["val"].dtype, arg_dtype)
+    return count
+
+
 class TestPasses(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         register_additional_test_aten_ops()
 
     def test_remove_mixed_type_operators(self) -> None:
-        def count_nodes_with_target_asserting_arguments_have_dtype(
-            new_graph_module, target, arg_dtype
-        ):
-            count = 0
-            for node in new_graph_module.graph.nodes:
-                if node.op == "call_function" and node.target == target:
-                    count += 1
-                    for arg in node.args:
-                        self.assertEqual(arg.meta["val"].dtype, arg_dtype)
-            return count
-
-        class Add(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return (x + y) + x
-
-        class Mult(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x * y
-
-        class Minimum(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return torch.minimum(x, y)
+        def make_module(fwd: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
+            class Module(torch.nn.Module):
+                def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                    return fwd(x, y)
+
+            return Module
+
+        Add = make_module(lambda x, y: (x + y) + x)
+        Mult = make_module(lambda x, y: x * y)
+        Minimum = make_module(torch.minimum)
+        DivWithoutMode = make_module(torch.div)
+        DivWithNoneMode = make_module(lambda x, y: torch.div(x, y, rounding_mode=None))
+        DivWithTruncMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="trunc")
+        )
+        DivWithFloorMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="floor")
+        )
 
-        for module, op, expected_count in (
-            (Add, exir_ops.edge.aten.add.Tensor, 2),
-            (Mult, exir_ops.edge.aten.mul.Tensor, 1),
-            (Minimum, exir_ops.edge.aten.minimum.default, 1),
+        for module, op, expected_count, promotion_kind in (
+            (
+                Add,
+                exir_ops.edge.aten.add.Tensor,
+                2,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Mult,
+                exir_ops.edge.aten.mul.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Minimum,
+                exir_ops.edge.aten.minimum.default,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithoutMode,
+                exir_ops.edge.aten.div.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithNoneMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithTruncMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithFloorMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
         ):
             for second_arg_dtype in (torch.int64, torch.float, torch.double):
                 int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)
@@ -166,8 +225,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                 new_graph_module = new_prog.exported_program().graph_module
                 self.assertIsNotNone(new_graph_module)
 
+                promoted_type = simple_promote_dtype(second_arg_dtype, promotion_kind)
                 count = count_nodes_with_target_asserting_arguments_have_dtype(
-                    new_graph_module, op, second_arg_dtype
+                    self, new_graph_module, op, promoted_type
                 )
                 self.assertEqual(count, expected_count)
 
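For completeness, a minimal end-to-end sketch of what the updated test checks for the no-mode div case. The export/to_edge/transform flow and the RemoveMixedTypeOperators import path are my assumptions based on the file paths in this commit, not code copied from the diff:

import torch
from torch.export import export

from executorch.exir import to_edge
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators

class Div(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.div(x, y)  # aten.div.Tensor: expect INT_TO_FLOAT promotion

int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)
float_tensor = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float)

prog = to_edge(export(Div(), (int_tensor, float_tensor)))
new_prog = prog.transform([RemoveMixedTypeOperators()])
graph_module = new_prog.exported_program().graph_module

# After the pass, both arguments of the edge div node should be float, since
# simple_promote_dtype(torch.float, INT_TO_FLOAT) == torch.float.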