diff --git a/python/shark_turbine/aot/passes/functorch.py b/python/shark_turbine/aot/passes/functorch.py index 52f3feb66..72485f13c 100644 --- a/python/shark_turbine/aot/passes/functorch.py +++ b/python/shark_turbine/aot/passes/functorch.py @@ -41,7 +41,7 @@ # ONNX applies this post export, which suffers from the loss of output # destructuring rewrites that torch.export does. def functorch_functionalize(gm: GraphModule, *args) -> GraphModule: - functionalized_callable = _functionalize_callabale(gm) + functionalized_callable = _functionalize_callable(gm) # TODO: There is more of a dance needed if the user has entered with a fake_mode. with proxy_tensor.maybe_disable_fake_tensor_mode(): new_gm = proxy_tensor.make_fx( @@ -55,7 +55,7 @@ def functorch_functionalize(gm: GraphModule, *args) -> GraphModule: return new_gm -def _functionalize_callabale(function: Callable) -> Callable: +def _functionalize_callable(function: Callable) -> Callable: def wrapped(*args): args_functional = pytree.tree_map_only( torch.Tensor, torch._to_functional_tensor, args