diff --git a/python/shark_turbine/aot/compiled_module.py b/python/shark_turbine/aot/compiled_module.py index 723e36ffa..77ba55a34 100644 --- a/python/shark_turbine/aot/compiled_module.py +++ b/python/shark_turbine/aot/compiled_module.py @@ -193,8 +193,11 @@ def def_export_proc(self, name, f) -> ExportProcDef: raise TypeError( f"exported functions only support positional parameters" ) - param_desc = param.default - if param_desc is inspect.Parameter.empty: + if param.default is not inspect.Parameter.empty: + param_desc = param.default + elif param.annotation is not inspect.Parameter.empty: + param_desc = param.annotation + else: # TODO: Merge from a decorator? raise TypeError( f"export function {name} missing required default value annotation " diff --git a/tests/aot/args_test.py b/tests/aot/args_test.py index f833b888e..8b5450620 100644 --- a/tests/aot/args_test.py +++ b/tests/aot/args_test.py @@ -17,7 +17,7 @@ class ArgsTest(unittest.TestCase): def testProcArgs(self): class ProcArgsModule(CompiledModule): - def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): + def foobar(self, a: AbstractTensor(3, 2), b: AbstractTensor(1, 1)): return b, a inst = ProcArgsModule(context=Context()) @@ -31,7 +31,7 @@ def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): def testProcToJitArgs(self): class ProcArgsModule(CompiledModule): - def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): + def foobar(self, a: AbstractTensor(3, 2), b: AbstractTensor(1, 1)): return self.compute(a, b) @jittable @@ -56,7 +56,7 @@ def compute(a, b): def testProcToJitArgs(self): class ProcArgsModule(CompiledModule): - def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): + def foobar(self, a: AbstractTensor(3, 2), b: AbstractTensor(1, 1)): x = self.compute(a, b) y = self.compute(x, a) return y diff --git a/tests/aot/globals_test.py b/tests/aot/globals_test.py index b3fa71eee..8461b1554 100644 --- a/tests/aot/globals_test.py +++ b/tests/aot/globals_test.py @@ -34,7 +34,7 @@ class GlobalModule(CompiledModule): params = export_parameters(m) compute = jittable(m.forward) - def run(self, x=AbstractTensor(128, 20)): + def run(self, x: AbstractTensor(128, 20)): return self.compute(x) inst = GlobalModule(context=Context()) @@ -91,7 +91,7 @@ def testGlobalStoreFromPyTree(self): class GlobalModule(CompiledModule): params = export_parameters(m, initialize=False, mutable=True) - def update_params(me, updates=abstractify(params)): + def update_params(me, updates: abstractify(params)): self.assertIn("classifier.weight", updates) self.assertIn("classifier.bias", updates) me.params = updates @@ -108,7 +108,7 @@ def testGlobalStoreFromLeaf(self): class GlobalModule(CompiledModule): params = export_parameters(m, initialize=False, mutable=True) - def update_bias(self, new_bias=abstractify(params["classifier.bias"])): + def update_bias(self, new_bias: abstractify(params["classifier.bias"])): self.params["classifier.bias"] = new_bias inst = GlobalModule(context=Context()) @@ -177,7 +177,7 @@ def testUpdateGlobalStateTree(self): class SingleState(CompiledModule): state0 = export_global_tree(state_example, mutable=True, initialize=False) - def read_state(self, updates=abstractify(state_example)): + def read_state(self, updates: abstractify(state_example)): self.state0 = updates inst = SingleState(context=Context()) @@ -199,7 +199,7 @@ def testTensorUpdateGlobal(self): class SingleState(CompiledModule): state0 = export_global(state_example, mutable=True, initialize=False) - def tensor_update_state(self, update=abstractify(update_example)): + def tensor_update_state(self, update: abstractify(update_example)): IREE.tensor_update(self.state0, update, 0, 0) inst = SingleState(context=Context()) diff --git a/tests/aot/iree_procedural_test.py b/tests/aot/iree_procedural_test.py index aa0228e26..4bd7ff34b 100644 --- a/tests/aot/iree_procedural_test.py +++ b/tests/aot/iree_procedural_test.py @@ -19,7 +19,7 @@ class CompiledModuleAPI(unittest.TestCase): def testTensorDim(self): class BasicModule(CompiledModule): - def foobar(self, a=AbstractTensor(None, 3)): + def foobar(self, a: AbstractTensor(None, 3)): return IREE.tensor_dim(a, 0) inst = BasicModule(context=Context()) @@ -31,7 +31,7 @@ def foobar(self, a=AbstractTensor(None, 3)): def testTensorEmpty(self): class BasicModule(CompiledModule): - def foobar(self, x=AbstractIndex): + def foobar(self, x: AbstractIndex): empty = IREE.tensor_empty(x, 16) dim0 = IREE.tensor_dim(empty, 0) return empty, dim0 @@ -46,7 +46,7 @@ def foobar(self, x=AbstractIndex): def testTensorSplat(self): class BasicModule(CompiledModule): - def foobar(self, x=AbstractIndex, y=AbstractF32): + def foobar(self, x: AbstractIndex, y: AbstractF32): empty = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32) dim0 = IREE.tensor_dim(empty, 0) return empty, dim0 @@ -63,7 +63,7 @@ def foobar(self, x=AbstractIndex, y=AbstractF32): def testTensorTrace(self): class BasicModule(CompiledModule): - def foobar(self, x=AbstractTensor(None), y=AbstractTensor(3)): + def foobar(self, x: AbstractTensor(None), y: AbstractTensor(3)): IREE.tensor_trace("DEBUG", x, y) inst = BasicModule(context=Context()) @@ -75,7 +75,7 @@ def testStoreDynamic(self): class BasicModule(CompiledModule): x = export_global(AbstractTensor(None, 34), mutable=True) - def foobar(self, x=AbstractIndex, y=AbstractF32): + def foobar(self, x: AbstractIndex, y: AbstractF32): splat = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32) self.x = splat @@ -91,7 +91,7 @@ def foobar(self, x=AbstractIndex, y=AbstractF32): def testTensorSliceStatic(self): class BasicModule(CompiledModule): - def foobar(self, x=AbstractTensor(3, 4)): + def foobar(self, x: AbstractTensor(3, 4)): return IREE.tensor_slice(x, 0, (1, 3)) inst = BasicModule(context=Context()) @@ -104,7 +104,7 @@ def foobar(self, x=AbstractTensor(3, 4)): def testTensorSliceDynamicIndex(self): class SliceDynamicIndex(CompiledModule): - def foobar(self, x=AbstractIndex): + def foobar(self, x: AbstractIndex): empty = IREE.tensor_empty(x, 16) return IREE.tensor_slice(empty, x, 4) @@ -118,7 +118,7 @@ def foobar(self, x=AbstractIndex): def testTensorSliceDynamicLength(self): class SliceDynamicIndex(CompiledModule): - def foobar(self, x=AbstractIndex, y=AbstractIndex): + def foobar(self, x: AbstractIndex, y: AbstractIndex): empty = IREE.tensor_empty(x, 16) return IREE.tensor_slice(empty, (x, y), 4) @@ -134,10 +134,10 @@ def testTensorUpdateStatic(self): class UpdateStatic(CompiledModule): def foobar( self, - target=AbstractTensor(4, 4), - update=AbstractTensor(2, 2), - i=AbstractIndex, - j=AbstractIndex, + target: AbstractTensor(4, 4), + update: AbstractTensor(2, 2), + i: AbstractIndex, + j: AbstractIndex, ): return IREE.tensor_update(target, update, i, j) @@ -153,11 +153,11 @@ def testTensorUpdateDynamic(self): class UpdateDynamic(CompiledModule): def foobar( self, - x=AbstractIndex, - y=AbstractIndex, - i=AbstractIndex, - j=AbstractIndex, - value=AbstractF32, + x: AbstractIndex, + y: AbstractIndex, + i: AbstractIndex, + j: AbstractIndex, + value: AbstractF32, ): target = IREE.tensor_empty(x, y) update = IREE.tensor_splat(i, j, value=value, dtype=torch.float32) @@ -173,7 +173,7 @@ def foobar( def testTensorReshape(self): class ReshapeModule(CompiledModule): - def foobar(self, x=AbstractIndex, y=AbstractIndex): + def foobar(self, x: AbstractIndex, y: AbstractIndex): empty = IREE.tensor_empty(x, 16) reshaped = IREE.tensor_reshape(empty, 1, y, y) return reshaped @@ -188,7 +188,7 @@ def foobar(self, x=AbstractIndex, y=AbstractIndex): def testScalarAddInt(self): class ArithModule(CompiledModule): - def foobar(self, a=AbstractI32, b=AbstractI32): + def foobar(self, a: AbstractI32, b: AbstractI32): return a + b inst = ArithModule(context=Context()) @@ -197,7 +197,7 @@ def foobar(self, a=AbstractI32, b=AbstractI32): def testScalarAddFloat(self): class ArithModule(CompiledModule): - def foobar(self, a=AbstractF32, b=AbstractF32): + def foobar(self, a: AbstractF32, b: AbstractF32): return a + b inst = ArithModule(context=Context()) @@ -206,7 +206,7 @@ def foobar(self, a=AbstractF32, b=AbstractF32): def testScalarAddLiteral(self): class ArithModule(CompiledModule): - def foobar(self, a=AbstractI32): + def foobar(self, a: AbstractI32): return a + 1 inst = ArithModule(context=Context()) @@ -216,7 +216,7 @@ def foobar(self, a=AbstractI32): def testScalarAddLiteralMixedType(self): class ArithModule(CompiledModule): - def foobar(self, a=AbstractI32): + def foobar(self, a: AbstractI32): return a + 3.23 inst = ArithModule(context=Context())