Skip to content

Commit 0ba2c6d

Browse files
committed
[AOT][ExportProc] Use annotation for extra signature details.
1 parent d80899a commit 0ba2c6d

File tree

4 files changed

+35
-32
lines changed

4 files changed

+35
-32
lines changed

python/shark_turbine/aot/compiled_module.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,11 @@ def def_export_proc(self, name, f) -> ExportProcDef:
193193
raise TypeError(
194194
f"exported functions only support positional parameters"
195195
)
196-
param_desc = param.default
197-
if param_desc is inspect.Parameter.empty:
196+
if param.default is not inspect.Parameter.empty:
197+
param_desc = param.default
198+
elif param.annotation is not inspect.Parameter.empty:
199+
param_desc = param.annotation
200+
else:
198201
# TODO: Merge from a decorator?
199202
raise TypeError(
200203
f"export function {name} missing required default value annotation "

tests/aot/args_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
class ArgsTest(unittest.TestCase):
1818
def testProcArgs(self):
1919
class ProcArgsModule(CompiledModule):
20-
def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)):
20+
def foobar(self, a: AbstractTensor(3, 2), b: AbstractTensor(1, 1)):
2121
return b, a
2222

2323
inst = ProcArgsModule(context=Context())
@@ -31,7 +31,7 @@ def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)):
3131

3232
def testProcToJitArgs(self):
3333
class ProcArgsModule(CompiledModule):
34-
def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)):
34+
def foobar(self, a: AbstractTensor(3, 2), b: AbstractTensor(1, 1)):
3535
return self.compute(a, b)
3636

3737
@jittable
@@ -56,7 +56,7 @@ def compute(a, b):
5656

5757
def testProcToJitArgs(self):
5858
class ProcArgsModule(CompiledModule):
59-
def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)):
59+
def foobar(self, a: AbstractTensor(3, 2), b: AbstractTensor(1, 1)):
6060
x = self.compute(a, b)
6161
y = self.compute(x, a)
6262
return y

tests/aot/globals_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class GlobalModule(CompiledModule):
3434
params = export_parameters(m)
3535
compute = jittable(m.forward)
3636

37-
def run(self, x=AbstractTensor(128, 20)):
37+
def run(self, x: AbstractTensor(128, 20)):
3838
return self.compute(x)
3939

4040
inst = GlobalModule(context=Context())
@@ -91,7 +91,7 @@ def testGlobalStoreFromPyTree(self):
9191
class GlobalModule(CompiledModule):
9292
params = export_parameters(m, initialize=False, mutable=True)
9393

94-
def update_params(me, updates=abstractify(params)):
94+
def update_params(me, updates: abstractify(params)):
9595
self.assertIn("classifier.weight", updates)
9696
self.assertIn("classifier.bias", updates)
9797
me.params = updates
@@ -108,7 +108,7 @@ def testGlobalStoreFromLeaf(self):
108108
class GlobalModule(CompiledModule):
109109
params = export_parameters(m, initialize=False, mutable=True)
110110

111-
def update_bias(self, new_bias=abstractify(params["classifier.bias"])):
111+
def update_bias(self, new_bias: abstractify(params["classifier.bias"])):
112112
self.params["classifier.bias"] = new_bias
113113

114114
inst = GlobalModule(context=Context())
@@ -177,7 +177,7 @@ def testUpdateGlobalStateTree(self):
177177
class SingleState(CompiledModule):
178178
state0 = export_global_tree(state_example, mutable=True, initialize=False)
179179

180-
def read_state(self, updates=abstractify(state_example)):
180+
def read_state(self, updates: abstractify(state_example)):
181181
self.state0 = updates
182182

183183
inst = SingleState(context=Context())
@@ -199,7 +199,7 @@ def testTensorUpdateGlobal(self):
199199
class SingleState(CompiledModule):
200200
state0 = export_global(state_example, mutable=True, initialize=False)
201201

202-
def tensor_update_state(self, update=abstractify(update_example)):
202+
def tensor_update_state(self, update: abstractify(update_example)):
203203
IREE.tensor_update(self.state0, update, 0, 0)
204204

205205
inst = SingleState(context=Context())

tests/aot/iree_procedural_test.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
class CompiledModuleAPI(unittest.TestCase):
2020
def testTensorDim(self):
2121
class BasicModule(CompiledModule):
22-
def foobar(self, a=AbstractTensor(None, 3)):
22+
def foobar(self, a: AbstractTensor(None, 3)):
2323
return IREE.tensor_dim(a, 0)
2424

2525
inst = BasicModule(context=Context())
@@ -31,7 +31,7 @@ def foobar(self, a=AbstractTensor(None, 3)):
3131

3232
def testTensorEmpty(self):
3333
class BasicModule(CompiledModule):
34-
def foobar(self, x=AbstractIndex):
34+
def foobar(self, x: AbstractIndex):
3535
empty = IREE.tensor_empty(x, 16)
3636
dim0 = IREE.tensor_dim(empty, 0)
3737
return empty, dim0
@@ -46,7 +46,7 @@ def foobar(self, x=AbstractIndex):
4646

4747
def testTensorSplat(self):
4848
class BasicModule(CompiledModule):
49-
def foobar(self, x=AbstractIndex, y=AbstractF32):
49+
def foobar(self, x: AbstractIndex, y: AbstractF32):
5050
empty = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32)
5151
dim0 = IREE.tensor_dim(empty, 0)
5252
return empty, dim0
@@ -63,7 +63,7 @@ def foobar(self, x=AbstractIndex, y=AbstractF32):
6363

6464
def testTensorTrace(self):
6565
class BasicModule(CompiledModule):
66-
def foobar(self, x=AbstractTensor(None), y=AbstractTensor(3)):
66+
def foobar(self, x: AbstractTensor(None), y: AbstractTensor(3)):
6767
IREE.tensor_trace("DEBUG", x, y)
6868

6969
inst = BasicModule(context=Context())
@@ -75,7 +75,7 @@ def testStoreDynamic(self):
7575
class BasicModule(CompiledModule):
7676
x = export_global(AbstractTensor(None, 34), mutable=True)
7777

78-
def foobar(self, x=AbstractIndex, y=AbstractF32):
78+
def foobar(self, x: AbstractIndex, y: AbstractF32):
7979
splat = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32)
8080
self.x = splat
8181

@@ -91,7 +91,7 @@ def foobar(self, x=AbstractIndex, y=AbstractF32):
9191

9292
def testTensorSliceStatic(self):
9393
class BasicModule(CompiledModule):
94-
def foobar(self, x=AbstractTensor(3, 4)):
94+
def foobar(self, x: AbstractTensor(3, 4)):
9595
return IREE.tensor_slice(x, 0, (1, 3))
9696

9797
inst = BasicModule(context=Context())
@@ -104,7 +104,7 @@ def foobar(self, x=AbstractTensor(3, 4)):
104104

105105
def testTensorSliceDynamicIndex(self):
106106
class SliceDynamicIndex(CompiledModule):
107-
def foobar(self, x=AbstractIndex):
107+
def foobar(self, x: AbstractIndex):
108108
empty = IREE.tensor_empty(x, 16)
109109
return IREE.tensor_slice(empty, x, 4)
110110

@@ -118,7 +118,7 @@ def foobar(self, x=AbstractIndex):
118118

119119
def testTensorSliceDynamicLength(self):
120120
class SliceDynamicIndex(CompiledModule):
121-
def foobar(self, x=AbstractIndex, y=AbstractIndex):
121+
def foobar(self, x: AbstractIndex, y: AbstractIndex):
122122
empty = IREE.tensor_empty(x, 16)
123123
return IREE.tensor_slice(empty, (x, y), 4)
124124

@@ -134,10 +134,10 @@ def testTensorUpdateStatic(self):
134134
class UpdateStatic(CompiledModule):
135135
def foobar(
136136
self,
137-
target=AbstractTensor(4, 4),
138-
update=AbstractTensor(2, 2),
139-
i=AbstractIndex,
140-
j=AbstractIndex,
137+
target: AbstractTensor(4, 4),
138+
update: AbstractTensor(2, 2),
139+
i: AbstractIndex,
140+
j: AbstractIndex,
141141
):
142142
return IREE.tensor_update(target, update, i, j)
143143

@@ -153,11 +153,11 @@ def testTensorUpdateDynamic(self):
153153
class UpdateDynamic(CompiledModule):
154154
def foobar(
155155
self,
156-
x=AbstractIndex,
157-
y=AbstractIndex,
158-
i=AbstractIndex,
159-
j=AbstractIndex,
160-
value=AbstractF32,
156+
x: AbstractIndex,
157+
y: AbstractIndex,
158+
i: AbstractIndex,
159+
j: AbstractIndex,
160+
value: AbstractF32,
161161
):
162162
target = IREE.tensor_empty(x, y)
163163
update = IREE.tensor_splat(i, j, value=value, dtype=torch.float32)
@@ -173,7 +173,7 @@ def foobar(
173173

174174
def testTensorReshape(self):
175175
class ReshapeModule(CompiledModule):
176-
def foobar(self, x=AbstractIndex, y=AbstractIndex):
176+
def foobar(self, x: AbstractIndex, y: AbstractIndex):
177177
empty = IREE.tensor_empty(x, 16)
178178
reshaped = IREE.tensor_reshape(empty, 1, y, y)
179179
return reshaped
@@ -188,7 +188,7 @@ def foobar(self, x=AbstractIndex, y=AbstractIndex):
188188

189189
def testScalarAddInt(self):
190190
class ArithModule(CompiledModule):
191-
def foobar(self, a=AbstractI32, b=AbstractI32):
191+
def foobar(self, a: AbstractI32, b: AbstractI32):
192192
return a + b
193193

194194
inst = ArithModule(context=Context())
@@ -197,7 +197,7 @@ def foobar(self, a=AbstractI32, b=AbstractI32):
197197

198198
def testScalarAddFloat(self):
199199
class ArithModule(CompiledModule):
200-
def foobar(self, a=AbstractF32, b=AbstractF32):
200+
def foobar(self, a: AbstractF32, b: AbstractF32):
201201
return a + b
202202

203203
inst = ArithModule(context=Context())
@@ -206,7 +206,7 @@ def foobar(self, a=AbstractF32, b=AbstractF32):
206206

207207
def testScalarAddLiteral(self):
208208
class ArithModule(CompiledModule):
209-
def foobar(self, a=AbstractI32):
209+
def foobar(self, a: AbstractI32):
210210
return a + 1
211211

212212
inst = ArithModule(context=Context())
@@ -216,7 +216,7 @@ def foobar(self, a=AbstractI32):
216216

217217
def testScalarAddLiteralMixedType(self):
218218
class ArithModule(CompiledModule):
219-
def foobar(self, a=AbstractI32):
219+
def foobar(self, a: AbstractI32):
220220
return a + 3.23
221221

222222
inst = ArithModule(context=Context())

0 commit comments

Comments
 (0)