import torch

# Mapping from each torchao MPS linear op overload to the C declaration(s) of
# its AOTInductor c-shim wrapper.  AOTInductor consults this table when
# generating calls into the compiled extension.
# Check out TestUIntxWeightOnlyLinearQuantizer.test_export_accuracy on how to use it
#
# One entry per weight bit-width (1..7), matching the
# DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(1..7) instantiations in
# linear_fp_act_xbit_weight_aten.mm.  Note the shim symbol contains a double
# underscore: "aoti_torch_mps_" + "_linear_fp_act_{n}bit_weight".
torchao_op_c_shim: dict[torch.ops.OpOverload, list[str]] = {
    getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight").default: [
        f"AOTITorchError aoti_torch_mps__linear_fp_act_{nbit}bit_weight(AtenTensorHandle A, AtenTensorHandle B, int64_t group_size, AtenTensorHandle S, AtenTensorHandle Z, AtenTensorHandle* ret)",
    ]
    for nbit in range(1, 8)
}
// clang-format off +#include +#include #include #include #include @@ -239,3 +241,44 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) { } } // namespace torchao::kernels::mps::lowbit::aten + + +// c-shim wrappers for AOTInductor +// Check out TestUIntxWeightOnlyLinearQuantizer.test_export_accuracy on how to use it +#define DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(BITS) \ +extern "C" { \ + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__linear_fp_act_##BITS##bit_weight( \ + AtenTensorHandle A, \ + AtenTensorHandle B, \ + int64_t group_size, \ + AtenTensorHandle S, \ + AtenTensorHandle Z, \ + AtenTensorHandle* ret) { \ + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ \ + auto op_handle = \ + c10::Dispatcher::singleton() \ + .findSchemaOrThrow("torchao::_linear_fp_act_" #BITS "bit_weight", "") \ + .typed(); \ + auto tmp_result = op_handle.call( \ + torch::aot_inductor::resolve_tensor_dispatch_flags(A), \ + torch::aot_inductor::resolve_tensor_dispatch_flags(B), \ + group_size, \ + torch::aot_inductor::resolve_tensor_dispatch_flags(S), \ + torch::aot_inductor::resolve_tensor_dispatch_flags(Z)); \ + *ret = torch::aot_inductor::new_tensor_handle(std::move(tmp_result)); \ + }); \ + } \ +} + +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(1) +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(2) +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(3) +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(4) +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(5) +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(6) +DECLARE_LINEAR_FP_ACT_WEIGHT_FUNCTION(7)