
Commit c3eac42

supports HPU double quant (#1630)
1 parent 5e267f5 commit c3eac42

2 files changed (+16, -33 lines)

bitsandbytes/autograd/_functions.py

Lines changed: 1 addition & 1 deletion
@@ -570,7 +570,7 @@ def matmul_4bit(
             return out
         else:
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
-    elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "npu":
+    elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type not in ("npu", "hpu"):
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
                 f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
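The changed condition keeps HPU tensors off the single-token gemv fast path, since the HPU backend's gemv_4bit now raises NotImplementedError (see below). A minimal, standalone illustration of the shape test in that condition, run on a CPU tensor purely for demonstration:

import torch

# During single-token decoding every leading dim of the activation is 1,
# so the total element count equals the last dimension. This mirrors only
# the condition shown above, not the full dispatch logic of matmul_4bit.
A = torch.randn(1, 1, 4096)                      # (batch=1, seq=1, hidden)
is_gemv_shaped = A.numel() == A.shape[-1]        # True for this shape
takes_gemv_path = (
    is_gemv_shaped
    and not A.requires_grad
    and A.device.type not in ("npu", "hpu")      # HPU now excluded by this commit
)
print(is_gemv_shaped, takes_gemv_path)           # True True on CPU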

bitsandbytes/backends/hpu.py

Lines changed: 15 additions & 32 deletions
@@ -1,24 +1,21 @@
 import math
 from typing import Literal, Optional, Tuple
-import warnings
+
 import torch
 
+from bitsandbytes.functional import get_4bit_type
 from bitsandbytes.utils import QuantState
 
 from .base import Backend
 from .cpu_xpu_common import (
-    double_quant_impl,
-    dequant_8bit,
-    NF4_QUANT_TABLE,
     INT8_QUANT_TABLE,
-)
-from bitsandbytes.functional import (
-    QuantState,
-    get_4bit_type,
+    NF4_QUANT_TABLE,
+    dequant_8bit,
 )
 
 Tensor = torch.Tensor
 
+
 def assert_on_hpu(tensors):
     on_hpu = True
     for t in tensors:
@@ -32,8 +29,8 @@ def assert_on_hpu(tensors):
         )
     return on_hpu
 
-class HPUBackend(Backend):
 
+class HPUBackend(Backend):
     def int8_double_quant(
         self,
         A: torch.Tensor,
@@ -43,8 +40,7 @@ def int8_double_quant(
         out_row: Optional[torch.Tensor] = None,
         threshold=0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        assert_on_hpu([A, col_stats, row_stats, out_col, out_row])
-        return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)
+        raise NotImplementedError("Not yet implemented for HPU backend")
 
     def transform(
         self,
@@ -100,7 +96,7 @@ def quantize_4bit(
         assert_on_hpu([A, absmax, out])
         assert quant_storage == torch.uint8, "HPU backend only supports uint8 quant_storage"
         return self.quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
-
+
     def quantize_4bit_impl(
         self,
         A: Tensor,
@@ -159,10 +155,9 @@ def quantize_4bit_impl(
         code = get_4bit_type(quant_type, device=A.device)
 
         if compress_statistics:
-            raise AssertionError("Double quantization is not supported for HPU backend")
             offset = absmax.mean()
             absmax -= offset
-            qabsmax, state2 = self.hpu_quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
+            qabsmax, state2 = self.quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
             del absmax
             state = QuantState(
                 absmax=qabsmax,
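For orientation, the block above is the nested ("double") quantization this commit enables on HPU: the per-block absmax values are offset-centered and then quantized a second time with blocksize 256. A rough, self-contained sketch of the idea in plain torch; the linear int8 scaling here is a simplification of the code-table quantization that quantize_4bit_impl actually applies, and all values are illustrative:

import torch

# Stand-in for the per-block 4-bit absmax statistics of a quantized weight.
absmax = torch.rand(1024) * 3
offset = absmax.mean()                  # stored separately in the QuantState ...
centered = absmax - offset              # ... so the residual is roughly zero-centered

blocksize = 256                         # the diff quantizes absmax with blocksize=256
blocks = centered.view(-1, blocksize)
scale = blocks.abs().max(dim=1, keepdim=True).values   # second-level absmax per block
q8 = torch.clamp((blocks / scale * 127).round(), -127, 127).to(torch.int8)

# Dequantize: int8 -> float, rescale per block, add the offset back.
absmax_restored = (q8.float() / 127 * scale).view(-1) + offset
print((absmax - absmax_restored).abs().max())           # small reconstruction error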
@@ -196,10 +191,10 @@ def dequantize_nf4_impl(
         HPU dequantization function for NF4 quantized tensors.
         """
         assert_on_hpu([input, absmax])
-        out_shape = (math.prod(quant_state.shape), )
-        out_dq = torch.ops.hpu.dequantize_nf4(input, absmax, blocksize,
-                                              out_shape=out_shape,
-                                              out_dtype=quant_state.dtype)
+        out_shape = (math.prod(quant_state.shape),)
+        out_dq = torch.ops.hpu.dequantize_nf4(
+            input, absmax, blocksize, out_shape=out_shape, out_dtype=quant_state.dtype
+        )
         output = out_dq.reshape(quant_state.shape).T
         return output
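The reshape at the end of dequantize_nf4_impl maps the flat buffer requested from the HPU op back to the logical weight shape and transposes it. A shape-only illustration, where quant_shape is a hypothetical quant_state.shape and fake_dequant stands in for the output of torch.ops.hpu.dequantize_nf4:

import math
import torch

quant_shape = (4096, 11008)                     # hypothetical quant_state.shape
out_shape = (math.prod(quant_shape),)           # flat output requested from the op
fake_dequant = torch.empty(out_shape, dtype=torch.bfloat16)
output = fake_dequant.reshape(quant_shape).T    # final result is the transpose
print(output.shape)                             # torch.Size([11008, 4096])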

@@ -214,10 +209,9 @@ def dequantize_4bit(
     ) -> torch.Tensor:
         if blocksize is None:
             blocksize = 64
-
+
         assert_on_hpu([A, absmax, out])
         if quant_state.nested:
-            raise AssertionError("Double quantization is not supported for HPU backend")
             absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)
         return self.dequantize_nf4_impl(A, absmax, blocksize, quant_state)

@@ -230,18 +224,7 @@ def gemv_4bit(
         transposed_B=False,
         state: QuantState = None,
     ) -> torch.Tensor:
-        assert_on_hpu([A, B, out])
-        if state is None:
-            raise ValueError(
-                "state cannot be None. gemv_4bit() requires the state from quantize_4bit()"
-            )
-        dqB = self.dequantize_nf4_impl(B, state.absmax, state.blocksize, state)
-        output = torch.matmul(A, dqB.to(A.dtype))
-        if out is not None:
-            out.copy_(output)
-        else:
-            out = output
-        return out
+        raise NotImplementedError("Not yet implemented for HPU backend")
 
     def int8_vectorwise_dequant(self, A: torch.Tensor, stats: torch.Tensor):
         raise NotImplementedError("Not yet implemented for HPU backend")
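Taken together, a rough sketch of how the new double-quant path might be exercised from the public API. This assumes a Gaudi device is visible to PyTorch as "hpu" and that the multi-backend dispatch routes bitsandbytes.functional calls to HPUBackend; neither assumption is shown in this diff.

import torch
import bitsandbytes.functional as F

# Illustrative only: requires habana_frameworks to be set up so torch.hpu exists.
if not (hasattr(torch, "hpu") and torch.hpu.is_available()):
    raise SystemExit("This sketch needs an HPU (Gaudi) device.")

w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="hpu")

# compress_statistics=True triggers the nested ("double") quantization of the
# absmax statistics that this commit enables in quantize_4bit_impl above.
q, state = F.quantize_4bit(w, blocksize=64, quant_type="nf4", compress_statistics=True)
w_dq = F.dequantize_4bit(q, state)      # goes through dequant_8bit + the NF4 dequant
print((w.float() - w_dq.float()).abs().mean())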
