
Commit 0d1b3a3

Last minute pre-release changes
1 parent 1d4ea6a commit 0d1b3a3

6 files changed: +111 additions, -83 deletions


bitsandbytes/backends/cuda/ops.py

Lines changed: 16 additions & 14 deletions
@@ -445,20 +445,22 @@ def _gemv_4bit_impl(
     out: torch.Tensor,
 ) -> None:
     torch._check_is_size(blocksize)
-    torch._check(
-        A.numel() == A.size(-1),
-        lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
-    )
-    torch._check(
-        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
-        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
-    )
-    torch._check(
-        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
-        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
-    )
-    torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
-    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
+
+    # Note: these checks are not strictly necessary, and cost more than they are worth, so they are commented out for now.
+    # torch._check(
+    #     A.numel() == A.size(-1),
+    #     lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
+    # )
+    # torch._check(
+    #     A.dtype in [torch.float16, torch.bfloat16, torch.float32],
+    #     lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
+    # )
+    # torch._check(
+    #     B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
+    #     lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
+    # )
+    # torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
+    # torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
 
     m = ct.c_int32(shapeB[0])
     n = ct.c_int32(1)
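The validation disabled here ran on every call of a latency-sensitive 4-bit GEMV, so even cheap Python-level checks add measurable per-call overhead. A rough way to quantify that cost in isolation; the shapes, dtypes, and iteration count below are illustrative assumptions, not the library's benchmark:

import time

import torch

A = torch.randn(1, 4096, dtype=torch.float16)

def run_checks() -> None:
    # Same style of validation as the checks commented out above.
    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector, got {A.shape}")
    torch._check(
        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
    )

iterations = 100_000
start = time.perf_counter()
for _ in range(iterations):
    run_checks()
elapsed = time.perf_counter() - start
print(f"~{elapsed / iterations * 1e6:.2f} us of check overhead per call")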

bitsandbytes/cextension.py

Lines changed: 8 additions & 5 deletions
@@ -1,4 +1,5 @@
 import ctypes as ct
+import functools
 import logging
 import os
 from pathlib import Path
@@ -29,10 +30,8 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
         library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)
         logger.warning(
             f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
-            "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"
+            "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n"
             "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
-            "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n"
-            "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n",
         )
 
     return PACKAGE_DIR / library_name
@@ -45,10 +44,14 @@ class BNBNativeLibrary:
     def __init__(self, lib: ct.CDLL):
         self._lib = lib
 
+    @functools.cache  # noqa: B019
     def __getattr__(self, name):
+        fn = getattr(self._lib, name, None)
+
+        if fn is not None:
+            return fn
+
         def throw_on_call(*args, **kwargs):
-            if hasattr(self._lib, name):
-                return getattr(self._lib, name)(*args, **kwargs)
             raise RuntimeError(
                 f"Method '{name}' not available in CPU-only version of bitsandbytes.\n"
                 "Reinstall with GPU support or use CUDA-enabled hardware."

bitsandbytes/diagnostics/cuda.py

Lines changed: 1 addition & 22 deletions
@@ -6,7 +6,6 @@
 import torch
 
 from bitsandbytes.cextension import get_cuda_bnb_library_path
-from bitsandbytes.consts import NONPYTORCH_DOC_URL
 from bitsandbytes.cuda_specs import CUDASpecs
 from bitsandbytes.diagnostics.utils import print_dedented
 
@@ -114,26 +113,10 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
     if not binary_path.exists():
         print_dedented(
             f"""
-            Library not found: {binary_path}. Maybe you need to compile it from source?
-            If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`,
-            for example, `make CUDA_VERSION=113`.
-
-            The CUDA version for the compile might depend on your conda install, if using conda.
-            Inspect CUDA version via `conda list | grep cuda`.
-            """,
-        )
-
-    cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
-    if cuda_major < 11:
-        print_dedented(
-            """
-            WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
-            You will be only to use 8-bit optimizers and quantization routines!
+            Library not found: {binary_path}. Maybe you need to compile it from source?
             """,
         )
 
-    print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
-
     # 7.5 is the minimum CC for int8 tensor cores
     if not cuda_specs.has_imma:
         print_dedented(
@@ -144,10 +127,6 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
             """,
         )
 
-    # TODO:
-    # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
-    # (2) Multiple CUDA versions installed
-
 
 def print_cuda_runtime_diagnostics() -> None:
    cudart_paths = list(find_cudart_libraries())
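The surviving capability check relies on cuda_specs.has_imma, matching the comment that 7.5 is the minimum compute capability for int8 tensor cores. A hedged sketch of how such a flag can be derived with plain PyTorch; has_imma_sketch is an illustrative helper, not the library's implementation:

import torch

def has_imma_sketch() -> bool:
    # IMMA (int8 tensor core) instructions require compute capability >= 7.5 (Turing).
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (7, 5)

print(f"int8 tensor cores available: {has_imma_sketch()}")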

bitsandbytes/diagnostics/main.py

Lines changed: 83 additions & 39 deletions
@@ -1,16 +1,30 @@
+import importlib
+import platform
 import sys
 import traceback
 
 import torch
 
+from bitsandbytes import __version__ as bnb_version
 from bitsandbytes.consts import PACKAGE_GITHUB_URL
 from bitsandbytes.cuda_specs import get_cuda_specs
 from bitsandbytes.diagnostics.cuda import (
     print_cuda_diagnostics,
-    print_cuda_runtime_diagnostics,
 )
 from bitsandbytes.diagnostics.utils import print_dedented, print_header
 
+_RELATED_PACKAGES = [
+    "accelerate",
+    "diffusers",
+    "numpy",
+    "pip",
+    "peft",
+    "safetensors",
+    "transformers",
+    "triton",
+    "trl",
+]
+
 
 def sanity_check():
     from bitsandbytes.optim import Adam
@@ -27,47 +41,77 @@ def sanity_check():
     assert p1 != p2
 
 
+def get_package_version(name: str) -> str:
+    try:
+        version = importlib.metadata.version(name)
+    except importlib.metadata.PackageNotFoundError:
+        version = "not found"
+    return version
+
+
+def show_environment():
+    """Simple utility to print out environment information."""
+
+    print(f"Platform: {platform.platform()}")
+    if platform.system() == "Linux":
+        print(f"  libc: {'-'.join(platform.libc_ver())}")
+
+    print(f"Python: {platform.python_version()}")
+
+    print(f"PyTorch: {torch.__version__}")
+    print(f"  CUDA: {torch.version.cuda or 'N/A'}")
+    print(f"  HIP: {torch.version.hip or 'N/A'}")
+    print(f"  XPU: {getattr(torch.version, 'xpu', 'N/A') or 'N/A'}")
+
+    print("Related packages:")
+    for pkg in _RELATED_PACKAGES:
+        version = get_package_version(pkg)
+        print(f"  {pkg}: {version}")
+
+
 def main():
-    print_header("")
-    print_header("BUG REPORT INFORMATION")
+    print_header(f"bitsandbytes v{bnb_version}")
+    show_environment()
     print_header("")
 
-    print_header("OTHER")
     cuda_specs = get_cuda_specs()
-    print("CUDA specs:", cuda_specs)
-    if not torch.cuda.is_available():
-        print("Torch says CUDA is not available. Possible reasons:")
-        print("1. CUDA driver not installed")
-        print("2. CUDA not installed")
-        print("3. You have multiple conflicting CUDA libraries")
+
     if cuda_specs:
         print_cuda_diagnostics(cuda_specs)
-    print_cuda_runtime_diagnostics()
-    print_header("")
-    print_header("DEBUG INFO END")
-    print_header("")
-    print("Checking that the library is importable and CUDA is callable...")
-    try:
-        sanity_check()
-        print("SUCCESS!")
-        print("Installation was successful!")
-        return
-    except RuntimeError as e:
-        if "not available in CPU-only" in str(e):
-            print(
-                f"WARNING: {__package__} is currently running as CPU-only!\n"
-                "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
-                f"If you think that this is so erroneously,\nplease report an issue!",
-            )
-        else:
-            raise e
-    except Exception:
-        traceback.print_exc()
-    print_dedented(
-        f"""
-        Above we output some debug information.
-        Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose
-        WARNING: Please be sure to sanitize sensitive info from the output before posting it.
-        """,
-    )
-    sys.exit(1)
+
+    # TODO: There's a lot of noise in this; needs improvement.
+    # print_cuda_runtime_diagnostics()
+
+    if not torch.cuda.is_available():
+        print("PyTorch says CUDA is not available. Possible reasons:")
+        print("1. CUDA driver not installed")
+        print("2. Using a CPU-only PyTorch build")
+        print("3. No GPU detected")
+
+    else:
+        print("Checking that the library is importable and CUDA is callable...")
+
+        try:
+            sanity_check()
+            print("SUCCESS!")
+            return
+        except RuntimeError as e:
+            if "not available in CPU-only" in str(e):
+                print(
+                    f"WARNING: {__package__} is currently running as CPU-only!\n"
+                    "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
+                    f"If you think that this is so erroneously,\nplease report an issue!",
+                )
            else:
                raise e
        except Exception:
            traceback.print_exc()

        print_dedented(
            f"""
            Above we output some debug information.
            Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose
            WARNING: Please be sure to sanitize sensitive info from the output before posting it.
            """,
        )
        sys.exit(1)
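Both get_package_version and show_environment are plain helpers, so the new flow can be exercised end to end without a GPU. A minimal way to run it, assuming the package's __main__ wiring routes python -m bitsandbytes to this main() as recent releases do:

# Hedged sketch: invoke the new diagnostics programmatically.
from bitsandbytes.diagnostics.main import main

main()  # prints the version header and environment info; may call sys.exit(1) on failure

This output is what users are asked to paste into bug reports, which is why the related-package versions were added.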

bitsandbytes/diagnostics/utils.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 HEADER_WIDTH = 60
 
 
-def print_header(txt: str, width: int = HEADER_WIDTH, filler: str = "+") -> None:
+def print_header(txt: str, width: int = HEADER_WIDTH, filler: str = "=") -> None:
     txt = f" {txt} " if txt else ""
     print(txt.center(width, filler))
 
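print_header centers its text with str.center, so this change only swaps the padding character from + to =. For illustration; the width and version string here are arbitrary:

txt = " bitsandbytes v0.0.0 "
print(txt.center(30, "="))  # prints: ==== bitsandbytes v0.0.0 =====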

bitsandbytes/functional.py

Lines changed: 2 additions & 2 deletions
@@ -851,8 +851,8 @@ def dequantize_blockwise(
     torch.ops.bitsandbytes.dequantize_blockwise.out(
         A,
         absmax,
-        code.to(A.device),
-        blocksize,
+        quant_state.code.to(A.device),
+        quant_state.blocksize,
         quant_state.dtype,
         out=out,
     )
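This is the substantive bug fix of the commit: when a caller supplies a quant_state, the function-level code and blocksize parameters can still hold their defaults, so the old dispatch could dequantize against the wrong codebook and block size. Reading both from the QuantState keeps the op consistent with how the tensor was quantized. A round-trip sketch of the affected API; the device and sizes are illustrative:

import torch

import bitsandbytes.functional as F

# With the fix, dequantize_blockwise honors the codebook and blocksize recorded
# in quant_state even though neither is passed explicitly here.
A = torch.randn(4096, device="cuda")
qA, quant_state = F.quantize_blockwise(A, blocksize=256)
A_restored = F.dequantize_blockwise(qA, quant_state)
print((A - A_restored).abs().max())  # small quantization error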
