Commit fd723b7

Merge pull request #1041 from akx/cuda-wagh
Rework CUDA/native-library setup and diagnostics
2 parents ce597c6 + 79d1ccc commit fd723b7

14 files changed: +484, -610 lines

bitsandbytes/__init__.py

Lines changed: 2 additions & 7 deletions
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from . import cuda_setup, research, utils
+from . import research, utils
 from .autograd._functions import (
     MatmulLtState,
     bmm_cublas,
@@ -12,11 +12,8 @@
     matmul_cublas,
     mm_cublas,
 )
-from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
-
-if COMPILED_WITH_CUDA:
-    from .optim import adam
+from .optim import adam

 __pdoc__ = {
     "libbitsandbytes": False,
@@ -25,5 +22,3 @@
 }

 __version__ = "0.44.0.dev"
-
-PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes"
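
Note: `PACKAGE_GITHUB_URL` is not gone; it moves to the new `bitsandbytes/consts.py` (see below), and `adam` is now imported unconditionally rather than being gated on `COMPILED_WITH_CUDA`. A minimal sketch of the updated import for downstream code that used this constant:

    from bitsandbytes.consts import PACKAGE_GITHUB_URL  # moved here by this PR

    print(PACKAGE_GITHUB_URL)  # https://github.com/TimDettmers/bitsandbytes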

bitsandbytes/__main__.py

Lines changed: 2 additions & 106 deletions
@@ -1,108 +1,4 @@
-import glob
-import os
-import sys
-from warnings import warn
-
-import torch
-
-HEADER_WIDTH = 60
-
-
-def find_dynamic_library(folder, filename):
-    for ext in ("so", "dll", "dylib"):
-        yield from glob.glob(os.path.join(folder, "**", filename + ext))
-
-
-def generate_bug_report_information():
-    print_header("")
-    print_header("BUG REPORT INFORMATION")
-    print_header("")
-    print('')
-
-    path_sources = [
-        ("ANACONDA CUDA PATHS", os.environ.get("CONDA_PREFIX")),
-        ("/usr/local CUDA PATHS", "/usr/local"),
-        ("CUDA PATHS", os.environ.get("CUDA_PATH")),
-        ("WORKING DIRECTORY CUDA PATHS", os.getcwd()),
-    ]
-    try:
-        ld_library_path = os.environ.get("LD_LIBRARY_PATH")
-        if ld_library_path:
-            for path in set(ld_library_path.strip().split(os.pathsep)):
-                path_sources.append((f"LD_LIBRARY_PATH {path} CUDA PATHS", path))
-    except Exception as e:
-        print(f"Could not parse LD_LIBRARY_PATH: {e}")
-
-    for name, path in path_sources:
-        if path and os.path.isdir(path):
-            print_header(name)
-            print(list(find_dynamic_library(path, '*cuda*')))
-            print("")
-
-
-def print_header(
-    txt: str, width: int = HEADER_WIDTH, filler: str = "+"
-) -> None:
-    txt = f" {txt} " if txt else ""
-    print(txt.center(width, filler))
-
-
-def print_debug_info() -> None:
-    from . import PACKAGE_GITHUB_URL
-    print(
-        "\nAbove we output some debug information. Please provide this info when "
-        f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n"
-    )
-
-
-def main():
-    generate_bug_report_information()
-
-    from . import COMPILED_WITH_CUDA
-    from .cuda_setup.main import get_compute_capabilities
-
-    print_header("OTHER")
-    print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}")
-    print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}")
-    print_header("")
-    print_header("DEBUG INFO END")
-    print_header("")
-    print("Checking that the library is importable and CUDA is callable...")
-    print("\nWARNING: Please be sure to sanitize sensitive info from any such env vars!\n")
-
-    try:
-        from bitsandbytes.optim import Adam
-
-        p = torch.nn.Parameter(torch.rand(10, 10).cuda())
-        a = torch.rand(10, 10).cuda()
-
-        p1 = p.data.sum().item()
-
-        adam = Adam([p])
-
-        out = a * p
-        loss = out.sum()
-        loss.backward()
-        adam.step()
-
-        p2 = p.data.sum().item()
-
-        assert p1 != p2
-        print("SUCCESS!")
-        print("Installation was successful!")
-    except ImportError:
-        print()
-        warn(
-            f"WARNING: {__package__} is currently running as CPU-only!\n"
-            "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
-            f"If you think that this is so erroneously,\nplease report an issue!"
-        )
-        print_debug_info()
-    except Exception as e:
-        print(e)
-        print_debug_info()
-        sys.exit(1)
-
-
 if __name__ == "__main__":
+    from bitsandbytes.diagnostics.main import main
+
     main()
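
Note: the bug-report and sanity-check logic that used to live here moves to the new `bitsandbytes.diagnostics` package; `python -m bitsandbytes` is now a thin wrapper. A minimal sketch of invoking the same entry point programmatically, based only on the import shown in the diff:

    # Equivalent to running `python -m bitsandbytes` from a shell.
    from bitsandbytes.diagnostics.main import main

    main()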

bitsandbytes/cextension.py

Lines changed: 117 additions & 32 deletions
@@ -1,39 +1,124 @@
+"""
+extract factors the build is dependent on:
+[X] compute capability
+    [ ] TODO: Q - What if we have multiple GPUs of different makes?
+- CUDA version
+- Software:
+    - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
+    - CuBLAS-LT: full-build 8-bit optimizer
+    - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
+
+evaluation:
+    - if paths faulty, return meaningful error
+    - else:
+        - determine CUDA version
+        - determine capabilities
+        - based on that set the default path
+"""
+
 import ctypes as ct
-from warnings import warn
+import logging
+import os
+from pathlib import Path

 import torch

-from bitsandbytes.cuda_setup.main import CUDASetup
+from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
+from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs
+
+logger = logging.getLogger(__name__)
+
+
+def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
+    """
+    Get the disk path to the CUDA BNB native library specified by the
+    given CUDA specs, taking into account the `BNB_CUDA_VERSION` override environment variable.
+
+    The library is not guaranteed to exist at the returned path.
+    """
+    library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}"
+    if not cuda_specs.has_cublaslt:
+        # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
+        library_name += "_nocublaslt"
+    library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"
+
+    override_value = os.environ.get("BNB_CUDA_VERSION")
+    if override_value:
+        library_name_stem, _, library_name_ext = library_name.rpartition(".")
+        # `library_name_stem` will now be e.g. `libbitsandbytes_cuda118`;
+        # let's remove any trailing numbers:
+        library_name_stem = library_name_stem.rstrip("0123456789")
+        # `library_name_stem` will now be e.g. `libbitsandbytes_cuda`;
+        # let's tack the new version number and the original extension back on.
+        library_name = f"{library_name_stem}{override_value}.{library_name_ext}"
+        logger.warning(
+            f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
+            "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"
+            "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
+            "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n"
+            "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n"
+        )
+
+    return PACKAGE_DIR / library_name
+
+
+class BNBNativeLibrary:
+    _lib: ct.CDLL
+    compiled_with_cuda = False
+
+    def __init__(self, lib: ct.CDLL):
+        self._lib = lib
+
+    def __getattr__(self, item):
+        return getattr(self._lib, item)
+
+
+class CudaBNBNativeLibrary(BNBNativeLibrary):
+    compiled_with_cuda = True
+
+    def __init__(self, lib: ct.CDLL):
+        super().__init__(lib)
+        lib.get_context.restype = ct.c_void_p
+        lib.get_cusparse.restype = ct.c_void_p
+        lib.cget_managed_ptr.restype = ct.c_void_p
+
+
+def get_native_library() -> BNBNativeLibrary:
+    binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}"
+    cuda_specs = get_cuda_specs()
+    if cuda_specs:
+        cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)
+        if cuda_binary_path.exists():
+            binary_path = cuda_binary_path
+        else:
+            logger.warning("Could not find the bitsandbytes CUDA binary at %r", cuda_binary_path)
+    logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
+    dll = ct.cdll.LoadLibrary(str(binary_path))
+
+    if hasattr(dll, "get_context"):  # only a CUDA-built library exposes this
+        return CudaBNBNativeLibrary(dll)
+
+    logger.warning(
+        "The installed version of bitsandbytes was compiled without GPU support. "
+        "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable."
+    )
+    return BNBNativeLibrary(dll)

-setup = CUDASetup.get_instance()
-if setup.initialized != True:
-    setup.run_cuda_setup()

-lib = setup.lib
 try:
-    if lib is None and torch.cuda.is_available():
-        CUDASetup.get_instance().generate_instructions()
-        CUDASetup.get_instance().print_log_stack()
-        raise RuntimeError('''
-        CUDA Setup failed despite GPU being available. Please run the following command to get more information:
-
-        python -m bitsandbytes
-
-        Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
-        to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
-        and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''')
-    _ = lib.cadam32bit_grad_fp32  # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False
-    lib.get_context.restype = ct.c_void_p
-    lib.get_cusparse.restype = ct.c_void_p
-    lib.cget_managed_ptr.restype = ct.c_void_p
-    COMPILED_WITH_CUDA = True
-except AttributeError as ex:
-    warn("The installed version of bitsandbytes was compiled without GPU support. "
-         "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.")
-    COMPILED_WITH_CUDA = False
-    print(str(ex))
-
-
-# print the setup details after checking for errors so we do not print twice
-#if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
-#    setup.print_log_stack()
+    lib = get_native_library()
+except Exception as e:
+    lib = None
+    logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True)
+    if torch.cuda.is_available():
+        logger.warning(
+            """
+CUDA Setup failed despite CUDA being available. Please run the following command to get more information:
+
+python -m bitsandbytes
+
+Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
+to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
+and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
+"""
+        )
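
To make the name resolution in `get_cuda_bnb_library_path` concrete, here is a hedged sketch. `CUDASpecs` is defined in `bitsandbytes.cuda_specs`, which is not part of this diff, so a `SimpleNamespace` stub stands in for it; only the two attributes the function actually reads are stubbed, and the printed names assume Linux (`.so`):

    import os
    from types import SimpleNamespace

    from bitsandbytes.cextension import get_cuda_bnb_library_path

    # Hypothetical stand-in for CUDASpecs: just the attributes the function reads.
    specs = SimpleNamespace(cuda_version_string="118", has_cublaslt=True)
    print(get_cuda_bnb_library_path(specs).name)  # libbitsandbytes_cuda118.so

    # BNB_CUDA_VERSION strips the trailing version digits and substitutes its own;
    # a warning is logged when the override is active.
    os.environ["BNB_CUDA_VERSION"] = "122"
    print(get_cuda_bnb_library_path(specs).name)  # libbitsandbytes_cuda122.so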

bitsandbytes/consts.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+from pathlib import Path
+import platform
+
+DYNAMIC_LIBRARY_SUFFIX = {
+    "Darwin": ".dylib",
+    "Linux": ".so",
+    "Windows": ".dll",
+}.get(platform.system(), ".so")
+
+PACKAGE_DIR = Path(__file__).parent
+PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes"
+NONPYTORCH_DOC_URL = "https://github.com/TimDettmers/bitsandbytes/blob/main/docs/source/nonpytorchcuda.mdx"
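
These constants feed the library-file naming in `cextension.py` above. A small sketch of how they compose into the CPU fallback path that `get_native_library` starts from (the printed path is illustrative for a Linux install):

    from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR

    # Mirrors the default binary_path computed in get_native_library().
    cpu_binary = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}"
    print(cpu_binary)  # e.g. <site-packages>/bitsandbytes/libbitsandbytes_cpu.so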

bitsandbytes/cuda_setup/env_vars.py

Lines changed: 0 additions & 53 deletions
This file was deleted.
