diff --git a/extension_cpp/__init__.py b/extension_cpp/__init__.py index 769c697..a5b254a 100644 --- a/extension_cpp/__init__.py +++ b/extension_cpp/__init__.py @@ -1,2 +1,10 @@ import torch -from . import _C, ops +from pathlib import Path + +so_files = list(Path(__file__).parent.glob("_C*.so")) +assert ( + len(so_files) == 1 +), f"Expected one _C*.so file, found {len(so_files)}" +torch.ops.load_library(so_files[0]) + +from . import ops diff --git a/extension_cpp/csrc/muladd.cpp b/extension_cpp/csrc/muladd.cpp index 73b8f18..85f9fce 100644 --- a/extension_cpp/csrc/muladd.cpp +++ b/extension_cpp/csrc/muladd.cpp @@ -61,9 +61,6 @@ void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) { } } -// Registers _C as a Python extension module. -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} - // Defines the operators TORCH_LIBRARY(extension_cpp, m) { m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); diff --git a/setup.py b/setup.py index 1408cb3..6f699bc 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,11 @@ library_name = "extension_cpp" +if torch.__version__ >= "2.6.0": + py_limited_api = True +else: + py_limited_api = False + def get_extensions(): debug_mode = os.getenv("DEBUG", "0") == "1" @@ -59,6 +64,7 @@ def get_extensions(): sources, extra_compile_args=extra_compile_args, extra_link_args=extra_link_args, + py_limited_api=py_limited_api, ) ] @@ -71,9 +77,10 @@ def get_extensions(): packages=find_packages(), ext_modules=get_extensions(), install_requires=["torch"], - description="Example of PyTorch cpp and CUDA extensions", + description="Example of PyTorch C++ and CUDA extensions", long_description=open("README.md").read(), long_description_content_type="text/markdown", url="https://github.com/pytorch/extension-cpp", cmdclass={"build_ext": BuildExtension}, + options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {}, )