supports HPU double quant #1630

Merged
2 changes: 1 addition & 1 deletion bitsandbytes/autograd/_functions.py
@@ -570,7 +570,7 @@ def matmul_4bit(
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "npu":
elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type not in ("npu", "hpu"):
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
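
With the change above, the `elif` that selects the fast gemv path now also excludes `"hpu"` devices, so single-row inference inputs on Gaudi presumably take the generic `MatMul4Bit.apply` route instead of `gemv_4bit` (which this PR's HPU backend now leaves unimplemented, see below). A minimal sketch of that dispatch rule; `pick_matmul_4bit_path` is a hypothetical helper written only to illustrate the condition:

```python
import torch


def pick_matmul_4bit_path(A: torch.Tensor) -> str:
    """Illustrative only: mirrors the dispatch condition in matmul_4bit.

    A single-row activation (A.numel() == A.shape[-1]) that does not need
    gradients would normally take the fused gemv_4bit kernel, but NPU and HPU
    devices are excluded, so they fall back to the MatMul4Bit.apply path.
    """
    is_gemv_shaped = A.numel() == A.shape[-1]
    if is_gemv_shaped and not A.requires_grad and A.device.type not in ("npu", "hpu"):
        return "gemv_4bit"
    return "MatMul4Bit.apply"


# Example: a 1 x 4096 fp16 activation on CPU takes the gemv path; the same
# shape on an (hypothetical here) "hpu" device would not.
a_cpu = torch.empty(1, 4096, dtype=torch.float16)
print(pick_matmul_4bit_path(a_cpu))  # -> "gemv_4bit"
```
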
47 changes: 15 additions & 32 deletions bitsandbytes/backends/hpu.py
@@ -1,24 +1,21 @@
import math
from typing import Literal, Optional, Tuple
import warnings

import torch

from bitsandbytes.functional import get_4bit_type
from bitsandbytes.utils import QuantState

from .base import Backend
from .cpu_xpu_common import (
double_quant_impl,
dequant_8bit,
NF4_QUANT_TABLE,
INT8_QUANT_TABLE,
)
from bitsandbytes.functional import (
QuantState,
get_4bit_type,
NF4_QUANT_TABLE,
dequant_8bit,
)

Tensor = torch.Tensor


def assert_on_hpu(tensors):
on_hpu = True
for t in tensors:
@@ -32,8 +29,8 @@ def assert_on_hpu(tensors):
)
return on_hpu

class HPUBackend(Backend):

class HPUBackend(Backend):
def int8_double_quant(
self,
A: torch.Tensor,
@@ -43,8 +40,7 @@ def int8_double_quant(
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert_on_hpu([A, col_stats, row_stats, out_col, out_row])
return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)
raise NotImplementedError("Not yet implemented for HPU backend")

def transform(
self,
@@ -100,7 +96,7 @@ def quantize_4bit(
assert_on_hpu([A, absmax, out])
assert quant_storage == torch.uint8, "HPU backend only supports uint8 quant_storage"
return self.quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)

def quantize_4bit_impl(
self,
A: Tensor,
@@ -159,10 +155,9 @@ def quantize_4bit_impl(
code = get_4bit_type(quant_type, device=A.device)

if compress_statistics:
raise AssertionError("Double quantization is not supported for HPU backend")
offset = absmax.mean()
absmax -= offset
qabsmax, state2 = self.hpu_quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
qabsmax, state2 = self.quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
del absmax
state = QuantState(
absmax=qabsmax,
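
With the assertion removed, the `compress_statistics` branch above now double-quantizes the statistics: the mean of `absmax` is split off as an offset and the centered values are themselves blockwise-quantized to int8 via `quantize_4bit_impl`. A rough standalone sketch of that encode step, using a hypothetical symmetric int8 blockwise quantizer rather than the backend's actual kernels:

```python
import torch


def blockwise_int8_quant(x: torch.Tensor, blocksize: int = 256):
    """Hypothetical symmetric int8 blockwise quantizer (illustration only)."""
    flat = x.flatten()
    pad = (-flat.numel()) % blocksize
    flat = torch.nn.functional.pad(flat, (0, pad))
    blocks = flat.view(-1, blocksize)
    scales = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    q = torch.clamp((blocks / scales * 127).round(), -127, 127).to(torch.int8)
    return q, scales


# Double-quant encode of the 4-bit statistics, mirroring the diff above:
absmax = torch.rand(1024)          # per-block absmax from the NF4 step
offset = absmax.mean()             # stored in QuantState.offset
qabsmax, scales = blockwise_int8_quant(absmax - offset, blocksize=256)
```
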
@@ -196,10 +191,10 @@ def dequantize_nf4_impl(
HPU dequantization function for NF4 quantized tensors.
"""
assert_on_hpu([input, absmax])
out_shape = (math.prod(quant_state.shape), )
out_dq = torch.ops.hpu.dequantize_nf4(input, absmax, blocksize,
out_shape=out_shape,
out_dtype=quant_state.dtype)
out_shape = (math.prod(quant_state.shape),)
out_dq = torch.ops.hpu.dequantize_nf4(
input, absmax, blocksize, out_shape=out_shape, out_dtype=quant_state.dtype
)
output = out_dq.reshape(quant_state.shape).T
return output

@@ -214,10 +209,9 @@ def dequantize_4bit(
) -> torch.Tensor:
if blocksize is None:
blocksize = 64

assert_on_hpu([A, absmax, out])
if quant_state.nested:
raise AssertionError("Double quantization is not supported for HPU backend")
absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)
return self.dequantize_nf4_impl(A, absmax, blocksize, quant_state)
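
On the decode side, `dequantize_4bit` above now handles a nested `quant_state` by first restoring `absmax` with `dequant_8bit(absmax, quant_state.offset, quant_state.state2)` and only then expanding the NF4 blocks. A matching sketch of that unpack step, using a hypothetical inverse of the int8 quantizer from the previous example (not the real `dequant_8bit`):

```python
import torch


def blockwise_int8_dequant(q: torch.Tensor, scales: torch.Tensor,
                           offset: torch.Tensor, numel: int) -> torch.Tensor:
    """Hypothetical inverse of the int8 blockwise quantizer sketched above."""
    blocks = q.to(torch.float32) / 127 * scales   # undo the symmetric int8 scaling per block
    absmax = blocks.flatten()[:numel] + offset    # drop padding and add the stored mean back
    return absmax


# Decode order mirrored from dequantize_4bit above:
#   1. recover the real absmax from its nested int8 form (quant_state.state2 + offset),
#   2. then expand the NF4 blocks with that absmax (dequantize_nf4_impl).
```
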

@@ -230,18 +224,7 @@ def gemv_4bit(
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
assert_on_hpu([A, B, out])
if state is None:
raise ValueError(
"state cannot be None. gemv_4bit() requires the state from quantize_4bit()"
)
dqB = self.dequantize_nf4_impl(B, state.absmax, state.blocksize, state)
output = torch.matmul(A, dqB.to(A.dtype))
if out is not None:
out.copy_(output)
else:
out = output
return out
raise NotImplementedError("Not yet implemented for HPU backend")

def int8_vectorwise_dequant(self, A: torch.Tensor, stats: torch.Tensor):
raise NotImplementedError("Not yet implemented for HPU backend")
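
Taken together, the branch should let the usual 4-bit API run with nested (double-quantized) statistics on Gaudi. A hedged end-to-end sketch, assuming a working `habana_frameworks` install and that this multi-backend build dispatches `"hpu"` tensors to `HPUBackend`:

```python
import torch
import bitsandbytes.functional as F
# import habana_frameworks.torch  # assumed: required for the "hpu" device

W = torch.randn(4096, 4096, dtype=torch.bfloat16, device="hpu")

# NF4 quantization with double (nested) quantization of the statistics.
qW, state = F.quantize_4bit(W, blocksize=64, compress_statistics=True, quant_type="nf4")

# state.nested is True, so dequantize_4bit first restores absmax from its
# int8 form before expanding the NF4 blocks.
W_deq = F.dequantize_4bit(qW, state)
print((W.float() - W_deq.float()).abs().max())  # small quantization error expected
```
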