diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index d224cfe1c..b14d2024c 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -570,7 +570,7 @@ def matmul_4bit(
             return out
         else:
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
-    elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "npu":
+    elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type not in ("npu", "hpu"):
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
                 f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
diff --git a/bitsandbytes/backends/hpu.py b/bitsandbytes/backends/hpu.py
index 03308cd5d..2bc367078 100644
--- a/bitsandbytes/backends/hpu.py
+++ b/bitsandbytes/backends/hpu.py
@@ -1,24 +1,21 @@
 import math
 from typing import Literal, Optional, Tuple
-import warnings
+
 import torch
+from bitsandbytes.functional import get_4bit_type
 from bitsandbytes.utils import QuantState
 
 from .base import Backend
 from .cpu_xpu_common import (
-    double_quant_impl,
-    dequant_8bit,
-    NF4_QUANT_TABLE,
     INT8_QUANT_TABLE,
-)
-from bitsandbytes.functional import (
-    QuantState,
-    get_4bit_type,
+    NF4_QUANT_TABLE,
+    dequant_8bit,
 )
 
 
 Tensor = torch.Tensor
 
+
 def assert_on_hpu(tensors):
     on_hpu = True
     for t in tensors:
@@ -32,8 +29,8 @@ def assert_on_hpu(tensors):
             )
     return on_hpu
 
-class HPUBackend(Backend):
 
+class HPUBackend(Backend):
     def int8_double_quant(
         self,
         A: torch.Tensor,
@@ -43,8 +40,7 @@ def int8_double_quant(
         out_row: Optional[torch.Tensor] = None,
         threshold=0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        assert_on_hpu([A, col_stats, row_stats, out_col, out_row])
-        return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)
+        raise NotImplementedError("Not yet implemented for HPU backend")
 
     def transform(
         self,
@@ -100,7 +96,7 @@ def quantize_4bit(
         assert_on_hpu([A, absmax, out])
         assert quant_storage == torch.uint8, "HPU backend only supports uint8 quant_storage"
         return self.quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
-    
+
     def quantize_4bit_impl(
         self,
         A: Tensor,
@@ -159,10 +155,9 @@ def quantize_4bit_impl(
         code = get_4bit_type(quant_type, device=A.device)
 
         if compress_statistics:
-            raise AssertionError("Double quantization is not supported for HPU backend")
             offset = absmax.mean()
             absmax -= offset
-            qabsmax, state2 = self.hpu_quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
+            qabsmax, state2 = self.quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
             del absmax
             state = QuantState(
                 absmax=qabsmax,
@@ -196,10 +191,10 @@ def dequantize_nf4_impl(
         HPU dequantization function for NF4 quantized tensors.
         """
         assert_on_hpu([input, absmax])
-        out_shape = (math.prod(quant_state.shape), )
-        out_dq = torch.ops.hpu.dequantize_nf4(input, absmax, blocksize,
-                                              out_shape=out_shape,
-                                              out_dtype=quant_state.dtype)
+        out_shape = (math.prod(quant_state.shape),)
+        out_dq = torch.ops.hpu.dequantize_nf4(
+            input, absmax, blocksize, out_shape=out_shape, out_dtype=quant_state.dtype
+        )
         output = out_dq.reshape(quant_state.shape).T
         return output
 
@@ -214,10 +209,9 @@ def dequantize_4bit(
    ) -> torch.Tensor:
         if blocksize is None:
             blocksize = 64
-        
+
         assert_on_hpu([A, absmax, out])
         if quant_state.nested:
-            raise AssertionError("Double quantization is not supported for HPU backend")
             absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)
         return self.dequantize_nf4_impl(A, absmax, blocksize, quant_state)
 
@@ -230,18 +224,7 @@ def gemv_4bit(
         transposed_B=False,
         state: QuantState = None,
     ) -> torch.Tensor:
-        assert_on_hpu([A, B, out])
-        if state is None:
-            raise ValueError(
-                "state cannot be None. gemv_4bit() requires the state from quantize_4bit()"
-            )
-        dqB = self.dequantize_nf4_impl(B, state.absmax, state.blocksize, state)
-        output = torch.matmul(A, dqB.to(A.dtype))
-        if out is not None:
-            out.copy_(output)
-        else:
-            out = output
-        return out
+        raise NotImplementedError("Not yet implemented for HPU backend")
 
     def int8_vectorwise_dequant(self, A: torch.Tensor, stats: torch.Tensor):
         raise NotImplementedError("Not yet implemented for HPU backend")