Skip to content

Commit 8667f6a

Browse files
committed
[AOT][ExportProc] Use annotation for extra signature details.
1 parent d80899a commit 8667f6a

File tree

2 files changed

+27
-24
lines changed

2 files changed

+27
-24
lines changed

python/shark_turbine/aot/compiled_module.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,11 @@ def def_export_proc(self, name, f) -> ExportProcDef:
193193
raise TypeError(
194194
f"exported functions only support positional parameters"
195195
)
196-
param_desc = param.default
197-
if param_desc is inspect.Parameter.empty:
196+
if param.default is not inspect.Parameter.empty:
197+
param_desc = param.default
198+
elif param.annotation is not inspect.Parameter.empty:
199+
param_desc = param.annotation
200+
else:
198201
# TODO: Merge from a decorator?
199202
raise TypeError(
200203
f"export function {name} missing required default value annotation "

tests/aot/iree_procedural_test.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
class CompiledModuleAPI(unittest.TestCase):
2020
def testTensorDim(self):
2121
class BasicModule(CompiledModule):
22-
def foobar(self, a=AbstractTensor(None, 3)):
22+
def foobar(self, a: AbstractTensor(None, 3)):
2323
return IREE.tensor_dim(a, 0)
2424

2525
inst = BasicModule(context=Context())
@@ -31,7 +31,7 @@ def foobar(self, a=AbstractTensor(None, 3)):
3131

3232
def testTensorEmpty(self):
3333
class BasicModule(CompiledModule):
34-
def foobar(self, x=AbstractIndex):
34+
def foobar(self, x: AbstractIndex):
3535
empty = IREE.tensor_empty(x, 16)
3636
dim0 = IREE.tensor_dim(empty, 0)
3737
return empty, dim0
@@ -46,7 +46,7 @@ def foobar(self, x=AbstractIndex):
4646

4747
def testTensorSplat(self):
4848
class BasicModule(CompiledModule):
49-
def foobar(self, x=AbstractIndex, y=AbstractF32):
49+
def foobar(self, x: AbstractIndex, y: AbstractF32):
5050
empty = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32)
5151
dim0 = IREE.tensor_dim(empty, 0)
5252
return empty, dim0
@@ -63,7 +63,7 @@ def foobar(self, x=AbstractIndex, y=AbstractF32):
6363

6464
def testTensorTrace(self):
6565
class BasicModule(CompiledModule):
66-
def foobar(self, x=AbstractTensor(None), y=AbstractTensor(3)):
66+
def foobar(self, x: AbstractTensor(None), y: AbstractTensor(3)):
6767
IREE.tensor_trace("DEBUG", x, y)
6868

6969
inst = BasicModule(context=Context())
@@ -75,7 +75,7 @@ def testStoreDynamic(self):
7575
class BasicModule(CompiledModule):
7676
x = export_global(AbstractTensor(None, 34), mutable=True)
7777

78-
def foobar(self, x=AbstractIndex, y=AbstractF32):
78+
def foobar(self, x: AbstractIndex, y: AbstractF32):
7979
splat = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32)
8080
self.x = splat
8181

@@ -91,7 +91,7 @@ def foobar(self, x=AbstractIndex, y=AbstractF32):
9191

9292
def testTensorSliceStatic(self):
9393
class BasicModule(CompiledModule):
94-
def foobar(self, x=AbstractTensor(3, 4)):
94+
def foobar(self, x: AbstractTensor(3, 4)):
9595
return IREE.tensor_slice(x, 0, (1, 3))
9696

9797
inst = BasicModule(context=Context())
@@ -104,7 +104,7 @@ def foobar(self, x=AbstractTensor(3, 4)):
104104

105105
def testTensorSliceDynamicIndex(self):
106106
class SliceDynamicIndex(CompiledModule):
107-
def foobar(self, x=AbstractIndex):
107+
def foobar(self, x: AbstractIndex):
108108
empty = IREE.tensor_empty(x, 16)
109109
return IREE.tensor_slice(empty, x, 4)
110110

@@ -118,7 +118,7 @@ def foobar(self, x=AbstractIndex):
118118

119119
def testTensorSliceDynamicLength(self):
120120
class SliceDynamicIndex(CompiledModule):
121-
def foobar(self, x=AbstractIndex, y=AbstractIndex):
121+
def foobar(self, x: AbstractIndex, y: AbstractIndex):
122122
empty = IREE.tensor_empty(x, 16)
123123
return IREE.tensor_slice(empty, (x, y), 4)
124124

@@ -134,10 +134,10 @@ def testTensorUpdateStatic(self):
134134
class UpdateStatic(CompiledModule):
135135
def foobar(
136136
self,
137-
target=AbstractTensor(4, 4),
138-
update=AbstractTensor(2, 2),
139-
i=AbstractIndex,
140-
j=AbstractIndex,
137+
target: AbstractTensor(4, 4),
138+
update: AbstractTensor(2, 2),
139+
i: AbstractIndex,
140+
j: AbstractIndex,
141141
):
142142
return IREE.tensor_update(target, update, i, j)
143143

@@ -153,11 +153,11 @@ def testTensorUpdateDynamic(self):
153153
class UpdateDynamic(CompiledModule):
154154
def foobar(
155155
self,
156-
x=AbstractIndex,
157-
y=AbstractIndex,
158-
i=AbstractIndex,
159-
j=AbstractIndex,
160-
value=AbstractF32,
156+
x: AbstractIndex,
157+
y: AbstractIndex,
158+
i: AbstractIndex,
159+
j: AbstractIndex,
160+
value: AbstractF32,
161161
):
162162
target = IREE.tensor_empty(x, y)
163163
update = IREE.tensor_splat(i, j, value=value, dtype=torch.float32)
@@ -173,7 +173,7 @@ def foobar(
173173

174174
def testTensorReshape(self):
175175
class ReshapeModule(CompiledModule):
176-
def foobar(self, x=AbstractIndex, y=AbstractIndex):
176+
def foobar(self, x: AbstractIndex, y: AbstractIndex):
177177
empty = IREE.tensor_empty(x, 16)
178178
reshaped = IREE.tensor_reshape(empty, 1, y, y)
179179
return reshaped
@@ -188,7 +188,7 @@ def foobar(self, x=AbstractIndex, y=AbstractIndex):
188188

189189
def testScalarAddInt(self):
190190
class ArithModule(CompiledModule):
191-
def foobar(self, a=AbstractI32, b=AbstractI32):
191+
def foobar(self, a: AbstractI32, b: AbstractI32):
192192
return a + b
193193

194194
inst = ArithModule(context=Context())
@@ -197,7 +197,7 @@ def foobar(self, a=AbstractI32, b=AbstractI32):
197197

198198
def testScalarAddFloat(self):
199199
class ArithModule(CompiledModule):
200-
def foobar(self, a=AbstractF32, b=AbstractF32):
200+
def foobar(self, a: AbstractF32, b: AbstractF32):
201201
return a + b
202202

203203
inst = ArithModule(context=Context())
@@ -206,7 +206,7 @@ def foobar(self, a=AbstractF32, b=AbstractF32):
206206

207207
def testScalarAddLiteral(self):
208208
class ArithModule(CompiledModule):
209-
def foobar(self, a=AbstractI32):
209+
def foobar(self, a: AbstractI32):
210210
return a + 1
211211

212212
inst = ArithModule(context=Context())
@@ -216,7 +216,7 @@ def foobar(self, a=AbstractI32):
216216

217217
def testScalarAddLiteralMixedType(self):
218218
class ArithModule(CompiledModule):
219-
def foobar(self, a=AbstractI32):
219+
def foobar(self, a: AbstractI32):
220220
return a + 3.23
221221

222222
inst = ArithModule(context=Context())

0 commit comments

Comments
 (0)