Skip to content

Commit 38ec45e

Browse files
authored
Merge pull request #109 from pytorch/_C-instead-of-load_library
Create dummymodule _C instead of using load_library(blah.so)
2 parents 4be2205 + 58ac996 commit 38ec45e

File tree

3 files changed

+25
-9
lines changed

3 files changed

+25
-9
lines changed

extension_cpp/__init__.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
11
import torch
22
from pathlib import Path
3-
4-
so_files = list(Path(__file__).parent.glob("_C*.so"))
5-
assert (
6-
len(so_files) == 1
7-
), f"Expected one _C*.so file, found {len(so_files)}"
8-
torch.ops.load_library(so_files[0])
9-
10-
from . import ops
3+
from . import _C, ops

extension_cpp/csrc/muladd.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,29 @@
1-
#include <torch/extension.h>
1+
#include <Python.h>
2+
#include <ATen/Operators.h>
3+
#include <torch/all.h>
4+
#include <torch/library.h>
25

36
#include <vector>
47

8+
extern "C" {
9+
/* Creates a dummy empty _C module that can be imported from Python.
10+
The import from Python will load the .so consisting of this file
11+
in this extension, so that the TORCH_LIBRARY static initializers
12+
below are run. */
13+
PyObject* PyInit__C(void)
14+
{
15+
static struct PyModuleDef module_def = {
16+
PyModuleDef_HEAD_INIT,
17+
"_C", /* name of module */
18+
NULL, /* module documentation, may be NULL */
19+
-1, /* size of per-interpreter state of the module,
20+
or -1 if the module keeps state in global variables. */
21+
NULL, /* methods */
22+
};
23+
return PyModule_Create(&module_def);
24+
}
25+
}
26+
527
namespace extension_cpp {
628

729
at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def get_extensions():
3838
"cxx": [
3939
"-O3" if not debug_mode else "-O0",
4040
"-fdiagnostics-color=always",
41+
"-DPy_LIMITED_API=0x03090000", # min CPython version 3.9
4142
],
4243
"nvcc": [
4344
"-O3" if not debug_mode else "-O0",

0 commit comments

Comments
 (0)