1
1
import math
2
2
from typing import Literal , Optional , Tuple
3
- import warnings
3
+
4
4
import torch
5
5
6
+ from bitsandbytes .functional import get_4bit_type
6
7
from bitsandbytes .utils import QuantState
7
8
8
9
from .base import Backend
9
10
from .cpu_xpu_common import (
10
- double_quant_impl ,
11
- dequant_8bit ,
12
- NF4_QUANT_TABLE ,
13
11
INT8_QUANT_TABLE ,
14
- )
15
- from bitsandbytes .functional import (
16
- QuantState ,
17
- get_4bit_type ,
12
+ NF4_QUANT_TABLE ,
13
+ dequant_8bit ,
18
14
)
19
15
20
16
Tensor = torch .Tensor
21
17
18
+
22
19
def assert_on_hpu (tensors ):
23
20
on_hpu = True
24
21
for t in tensors :
@@ -32,8 +29,8 @@ def assert_on_hpu(tensors):
32
29
)
33
30
return on_hpu
34
31
35
- class HPUBackend (Backend ):
36
32
33
+ class HPUBackend (Backend ):
37
34
def int8_double_quant (
38
35
self ,
39
36
A : torch .Tensor ,
@@ -43,8 +40,7 @@ def int8_double_quant(
43
40
out_row : Optional [torch .Tensor ] = None ,
44
41
threshold = 0.0 ,
45
42
) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , Optional [torch .Tensor ]]:
46
- assert_on_hpu ([A , col_stats , row_stats , out_col , out_row ])
47
- return double_quant_impl (A , col_stats , row_stats , out_col , out_row , threshold )
43
+ raise NotImplementedError ("Not yet implemented for HPU backend" )
48
44
49
45
def transform (
50
46
self ,
@@ -100,7 +96,7 @@ def quantize_4bit(
100
96
assert_on_hpu ([A , absmax , out ])
101
97
assert quant_storage == torch .uint8 , "HPU backend only supports uint8 quant_storage"
102
98
return self .quantize_4bit_impl (A , absmax , out , blocksize , compress_statistics , quant_type )
103
-
99
+
104
100
def quantize_4bit_impl (
105
101
self ,
106
102
A : Tensor ,
@@ -159,10 +155,9 @@ def quantize_4bit_impl(
159
155
code = get_4bit_type (quant_type , device = A .device )
160
156
161
157
if compress_statistics :
162
- raise AssertionError ("Double quantization is not supported for HPU backend" )
163
158
offset = absmax .mean ()
164
159
absmax -= offset
165
- qabsmax , state2 = self .hpu_quantize_4bit_impl (absmax , blocksize = 256 , quant_type = "int8" )
160
+ qabsmax , state2 = self .quantize_4bit_impl (absmax , blocksize = 256 , quant_type = "int8" )
166
161
del absmax
167
162
state = QuantState (
168
163
absmax = qabsmax ,
@@ -196,10 +191,10 @@ def dequantize_nf4_impl(
196
191
HPU dequantization function for NF4 quantized tensors.
197
192
"""
198
193
assert_on_hpu ([input , absmax ])
199
- out_shape = (math .prod (quant_state .shape ), )
200
- out_dq = torch .ops .hpu .dequantize_nf4 (input , absmax , blocksize ,
201
- out_shape = out_shape ,
202
- out_dtype = quant_state . dtype )
194
+ out_shape = (math .prod (quant_state .shape ),)
195
+ out_dq = torch .ops .hpu .dequantize_nf4 (
196
+ input , absmax , blocksize , out_shape = out_shape , out_dtype = quant_state . dtype
197
+ )
203
198
output = out_dq .reshape (quant_state .shape ).T
204
199
return output
205
200
@@ -214,10 +209,9 @@ def dequantize_4bit(
214
209
) -> torch .Tensor :
215
210
if blocksize is None :
216
211
blocksize = 64
217
-
212
+
218
213
assert_on_hpu ([A , absmax , out ])
219
214
if quant_state .nested :
220
- raise AssertionError ("Double quantization is not supported for HPU backend" )
221
215
absmax = dequant_8bit (absmax , quant_state .offset , quant_state .state2 )
222
216
return self .dequantize_nf4_impl (A , absmax , blocksize , quant_state )
223
217
@@ -230,18 +224,7 @@ def gemv_4bit(
230
224
transposed_B = False ,
231
225
state : QuantState = None ,
232
226
) -> torch .Tensor :
233
- assert_on_hpu ([A , B , out ])
234
- if state is None :
235
- raise ValueError (
236
- "state cannot be None. gemv_4bit() requires the state from quantize_4bit()"
237
- )
238
- dqB = self .dequantize_nf4_impl (B , state .absmax , state .blocksize , state )
239
- output = torch .matmul (A , dqB .to (A .dtype ))
240
- if out is not None :
241
- out .copy_ (output )
242
- else :
243
- out = output
244
- return out
227
+ raise NotImplementedError ("Not yet implemented for HPU backend" )
245
228
246
229
def int8_vectorwise_dequant(self, A: torch.Tensor, stats: torch.Tensor):
    """Placeholder for the HPU backend; this operation is not implemented.

    Neither argument is inspected: the call fails unconditionally.

    Args:
        A: Accepted for interface compatibility; unused.
        stats: Accepted for interface compatibility; unused.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("Not yet implemented for HPU backend")
0 commit comments