Skip to content

Commit d80899a

Browse files
authored
[AOT][primitives] Implemented __add__ support for IrScalar (#66)
1 parent 36ec590 commit d80899a

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

python/shark_turbine/aot/support/ir_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,12 @@ def build_tensor_dim_value(
280280
) -> Value:
281281
dim_value = build_index_value(dim, constant_cache=constant_cache)
282282
return tensor_d.DimOp(t, dim_value).result
283+
284+
285+
# API name inspired by mlir/python/mlir/dialects/_arith_ops_ext.py
286+
def _is_float_type(type):
287+
return isinstance(type, (BF16Type, F16Type, F32Type, F64Type))
288+
289+
290+
def _is_integer_like_type(type):
291+
return isinstance(type, (IntegerType, IndexType))

python/shark_turbine/aot/support/procedural/primitives.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,17 @@
2020
import torch
2121

2222
from ..ir_imports import (
23+
F32Type,
2324
IrType,
2425
RankedTensorType,
2526
Value,
27+
arith_d,
2628
)
2729

2830
from ..ir_utils import (
2931
build_tensor_dim_value,
32+
_is_float_type,
33+
_is_integer_like_type,
3034
)
3135

3236
from ..utils import (
@@ -38,6 +42,7 @@
3842
Intrinsic,
3943
IrTrace,
4044
ShapedTypeDynamicSizeSentinel,
45+
current_ir_trace,
4146
)
4247

4348
###############################################################################
@@ -58,6 +63,36 @@ class IrScalar(Intrinsic):
5863
def __init__(self, ir_type: IrType):
5964
self.ir_type = ir_type
6065

66+
def __add__(self, other):
67+
t = current_ir_trace()
68+
with t.ip, t.loc:
69+
# Type check and promotion.
70+
# TODO: Add more comprehensive type promotion hiearchy as seen in
71+
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
72+
lhs = self.ir_value
73+
if isinstance(other, IrScalar):
74+
# Assumes when both are Value, they have same type.
75+
rhs = other.ir_value
76+
elif isinstance(other, (int, bool)):
77+
rhs = arith_d.ConstantOp(lhs.type, other).result
78+
elif isinstance(other, float) and _is_integer_like_type(self.ir_type):
79+
lhs = arith_d.SIToFPOp(F32Type.get(), lhs).result
80+
rhs = arith_d.ConstantOp(F32Type.get(), other).result
81+
82+
# Checks that lhs and rhs has same type.
83+
if lhs.type != rhs.type:
84+
raise ValueError("Mismatch type between lhs and rhs.")
85+
86+
# Emit computation.
87+
if _is_integer_like_type(lhs.type):
88+
return arith_d.AddIOp(lhs, rhs).result
89+
elif _is_float_type(lhs.type):
90+
return arith_d.AddFOp(lhs, rhs).result
91+
else:
92+
raise ValueError(
93+
f"Expected operand to be either Int or Float but got {self.ir_type} instead."
94+
)
95+
6196

6297
class IrImmediateScalar(IrScalar):
6398
"""Represents an IR scalar value."""

tests/aot/iree_procedural_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,45 @@ def foobar(self, x=AbstractIndex, y=AbstractIndex):
186186
module_str,
187187
)
188188

189+
def testScalarAddInt(self):
190+
class ArithModule(CompiledModule):
191+
def foobar(self, a=AbstractI32, b=AbstractI32):
192+
return a + b
193+
194+
inst = ArithModule(context=Context())
195+
module_str = str(CompiledModule.get_mlir_module(inst))
196+
self.assertIn("arith.addi %arg0, %arg1 : i32", module_str)
197+
198+
def testScalarAddFloat(self):
199+
class ArithModule(CompiledModule):
200+
def foobar(self, a=AbstractF32, b=AbstractF32):
201+
return a + b
202+
203+
inst = ArithModule(context=Context())
204+
module_str = str(CompiledModule.get_mlir_module(inst))
205+
self.assertIn("arith.addf %arg0, %arg1 : f32", module_str)
206+
207+
def testScalarAddLiteral(self):
208+
class ArithModule(CompiledModule):
209+
def foobar(self, a=AbstractI32):
210+
return a + 1
211+
212+
inst = ArithModule(context=Context())
213+
module_str = str(CompiledModule.get_mlir_module(inst))
214+
self.assertIn("%c1_i32 = arith.constant 1 : i32", module_str)
215+
self.assertIn("arith.addi %arg0, %c1_i32 : i32", module_str)
216+
217+
def testScalarAddLiteralMixedType(self):
218+
class ArithModule(CompiledModule):
219+
def foobar(self, a=AbstractI32):
220+
return a + 3.23
221+
222+
inst = ArithModule(context=Context())
223+
module_str = str(CompiledModule.get_mlir_module(inst))
224+
self.assertIn("%0 = arith.sitofp %arg0 : i32 to f32", module_str)
225+
self.assertIn("%cst = arith.constant 3.230000e+00 : f32", module_str)
226+
self.assertIn("arith.addf %0, %cst : f32", module_str)
227+
189228

190229
if __name__ == "__main__":
191230
logging.basicConfig(level=logging.DEBUG)

0 commit comments

Comments
 (0)