
Commit 812ef06

Add support for Intel Gaudi/HPU backend (#1662)
* supports hpu backend in main branch
* Update bitsandbytes/backends/hpu/ops.py: updates the assertion message
* Update bitsandbytes/backends/hpu/ops.py
* Update ops.py: fix lint issue
* Update ops.py

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
1 parent e9fc96a · commit 812ef06

6 files changed (+82, -3 lines)

bitsandbytes/__init__.py
Lines changed: 4 additions & 1 deletion

```diff
@@ -26,7 +26,7 @@
     "cpu",
     "cuda",  # NVIDIA/AMD GPU
     "xpu",  # Intel GPU
-    "hpu",  # Gaudi
+    "hpu",  # Intel Gaudi
     "npu",  # Ascend NPU
     "mps",  # Apple Silicon
 }
@@ -37,6 +37,9 @@
 if hasattr(torch, "xpu") and torch.xpu.is_available():
     from .backends.xpu import ops as xpu_ops
 
+if hasattr(torch, "hpu") and torch.hpu.is_available():
+    from .backends.hpu import ops as hpu_ops
+
 
 def _import_backends():
     """
```

bitsandbytes/autograd/_functions.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -451,7 +451,7 @@ def matmul_4bit(
         else:
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
 
-    if A.numel() == A.shape[-1] and A.requires_grad == False:
+    if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
                 f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
```

bitsandbytes/backends/hpu/__init__.py

Whitespace-only changes.

bitsandbytes/backends/hpu/ops.py
Lines changed: 53 additions & 0 deletions

```python
from collections.abc import Sequence
import math

import torch

from bitsandbytes.utils import _reverse_4bit_compress_format

from ..._ops import register_kernel
from ..utils import GAUDI_SW_VER


@register_kernel("bitsandbytes::dequantize_4bit", "hpu")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.uint8],
        lambda: f"quant_storage supports uint8 or bfloat16, but got {A.dtype}",
    )

    # Enable non uint8 dtype
    if A.dtype != torch.uint8:
        A = A.view(torch.uint8)

    transpose = False if len(A.shape) == 2 and A.shape[0] == 1 else True

    A = A.reshape(-1)

    if GAUDI_SW_VER and (GAUDI_SW_VER.major < 1 or GAUDI_SW_VER.minor < 22):
        A = _reverse_4bit_compress_format(A)

    # HPU dequantization function for NF4 quantized tensors.
    out_dq = torch.ops.hpu.dequantize_nf4(
        A,
        absmax.to(dtype),
        blocksize,
        out_shape=(math.prod(shape),),
        out_dtype=dtype,
    )

    output = out_dq.reshape(shape)

    if transpose:
        output = output.t()

    return output
```
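A usage sketch for the kernel above, dispatching through the `bitsandbytes::dequantize_4bit` op rather than calling the function directly. It assumes a working Gaudi stack (so `torch.ops.hpu.dequantize_nf4` exists and tensors can be moved to `"hpu"`); the shapes and the quantize-on-host step are illustrative:

```python
import torch

import bitsandbytes.functional as F

# Quantize a weight to NF4 on the host, then dequantize on the HPU.
W = torch.randn(4096, 4096, dtype=torch.bfloat16)
packed, qs = F.quantize_4bit(W, quant_type="nf4")

# The dispatcher routes this call to the HPU kernel registered above.
W_dq = torch.ops.bitsandbytes.dequantize_4bit(
    packed.to("hpu"), qs.absmax.to("hpu"), qs.blocksize, "nf4", qs.shape, torch.bfloat16
)
```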

bitsandbytes/backends/utils.py
Lines changed: 23 additions & 0 deletions

```diff
@@ -1,3 +1,6 @@
+import subprocess
+
+from packaging import version
 import torch
 
 try:
@@ -59,3 +62,23 @@
     else "cpu",  # Only cpu/xpu use this table for now.
 )
 CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}
+
+
+def get_gaudi_sw_version():
+    """
+    Returns the installed version of Gaudi SW.
+    """
+    output = subprocess.run(
+        "pip list | grep habana-torch-plugin",
+        shell=True,
+        text=True,
+        capture_output=True,
+    )
+    # If grep return nothing
+    if not output.stdout.strip():
+        return None
+
+    return version.parse(output.stdout.split("\n")[0].split()[-1])
+
+
+GAUDI_SW_VER = get_gaudi_sw_version()
```
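As an aside, the same lookup can be done without shelling out to `pip` and `grep` (both of which must be on PATH for the subprocess above to succeed). A sketch of an equivalent using `importlib.metadata`; this helper is not part of the commit:

```python
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as pkg_version

from packaging import version


def get_gaudi_sw_version_via_metadata():
    # Reads the installed habana-torch-plugin version from package metadata.
    try:
        return version.parse(pkg_version("habana-torch-plugin"))
    except PackageNotFoundError:
        return None
```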

bitsandbytes/nn/modules.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -443,7 +443,7 @@ def __init__(
         )
         # self.persistent_buffers = [] # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
-        self.compute_type_is_set = False
+        self.compute_type_is_set = False if compute_dtype is None else True
         self.quant_state = None
         self.quant_storage = quant_storage
         self.ipex_linear_is_set = False
```
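A short sketch of the behavioral change: passing `compute_dtype` to the constructor now marks the compute type as already set, so the first forward pass no longer infers and overrides it from the input dtype. The layer sizes are illustrative:

```python
import torch

import bitsandbytes as bnb

# With an explicit compute_dtype, the flag is set at construction time.
layer = bnb.nn.Linear4bit(4096, 4096, compute_dtype=torch.bfloat16, quant_type="nf4")
print(layer.compute_type_is_set)  # True immediately, not deferred to forward()
```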
