from ..._ops import register_kernel
from ...cextension import lib
+ from ..utils import ipex_cpu

# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
@@ -26,22 +27,42 @@ def _(A: torch.Tensor, B: torch.Tensor):
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
-     torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")

    n = A.numel()
-     blocks = -(n // -blocksize)

-     absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
-     out = torch.empty_like(A, dtype=torch.uint8)
-
-     lib.cquantize_blockwise_cpu_fp32(
-         get_ptr(code),
-         get_ptr(A),
-         get_ptr(absmax),
-         get_ptr(out),
-         ct.c_longlong(blocksize),
-         ct.c_longlong(n),
-     )
+     # Only FP32 has a C++ kernel
+     if A.dtype == torch.float32:
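+         # -(n // -blocksize) == ceil(n / blocksize), so a partial trailing block still gets an absmax slot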
+         blocks = -(n // -blocksize)
+
+         absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+         out = torch.empty_like(A, dtype=torch.uint8)
+
+         lib.cquantize_blockwise_cpu_fp32(
+             get_ptr(code),
+             get_ptr(A),
+             get_ptr(absmax),
+             get_ptr(out),
+             ct.c_longlong(blocksize),
+             ct.c_longlong(n),
+         )
+     else:
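+         # Fallback for non-fp32 dtypes: pure-PyTorch blockwise absmax scaling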
+         rem = n % blocksize
+         has_rem = rem > 0
+         blocks = n // blocksize + has_rem
+         absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+         A_reshaped = A.reshape(n)
+         A_com = A_reshaped[: n - rem]
+         A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
+         absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
+         scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
+         scaled_A = scaled_A.reshape(-1)
+         if has_rem:
+             absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
+             scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
+             scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
+
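+         # Nearest-code lookup: compare each scaled value against all 256 code entries, keep the closest index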
+         diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
+         out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)

    return out, absmax

@@ -50,144 +71,50 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
-     torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")
-
-     out = torch.empty_like(A, dtype=dtype)

-     lib.cdequantize_blockwise_cpu_fp32(
-         get_ptr(code),
-         get_ptr(A),
-         get_ptr(absmax),
-         get_ptr(out),
-         ct.c_longlong(blocksize),
-         ct.c_longlong(A.numel()),
-     )
+     # Only FP32 has a C++ kernel
+     if dtype == torch.float32:
+         out = torch.empty_like(A, dtype=dtype)
+
+         lib.cdequantize_blockwise_cpu_fp32(
+             get_ptr(code),
+             get_ptr(A),
+             get_ptr(absmax),
+             get_ptr(out),
+             ct.c_longlong(blocksize),
+             ct.c_longlong(A.numel()),
+         )
+     else:
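+         # Decode each stored byte by indexing the fp32 codebook, then rescale per block by absmax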
+         out = code[A.reshape(-1).int()]
+         blocks = out.shape[-1] // blocksize
+         res = out.shape[-1] % blocksize
+         if res != 0:
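+             # Zero-pad the partial tail block so the (-1, blocksize) view lines up; trimmed back below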
+             out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
+         out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
+         out = out[: blocks * blocksize + res]
+         out = out.reshape(A.shape)

    return out

- _NF4_QUANT_TABLE = torch.tensor(
-     [
-         -1.0,
-         -0.6961928009986877,
-         -0.5250730514526367,
-         -0.39491748809814453,
-         -0.28444138169288635,
-         -0.18477343022823334,
-         -0.09105003625154495,
-         0.0,
-         0.07958029955625534,
-         0.16093020141124725,
-         0.24611230194568634,
-         0.33791524171829224,
-         0.44070982933044434,
-         0.5626170039176941,
-         0.7229568362236023,
-         1.0,
-     ],
-     dtype=torch.float32,
-     device="cpu",
- )
-
-
- @register_kernel("bitsandbytes::quantize_4bit", "cpu")
- def _(
-     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
- ) -> tuple[torch.Tensor, torch.Tensor]:
-     torch._check_is_size(blocksize)
-     torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
-     torch._check(
-         A.dtype in [torch.bfloat16, torch.float16, torch.float32],
-         lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
-     )
-
-     n = A.numel()
-
-     # TODO: Support when weight matrix is not divisible by blocksize
-     torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
-
-     # Divide into blocks and normalize
-     blocks = A.reshape(-1, blocksize)
-     absmax = blocks.abs().max(dim=1).values.float()
-     scaled = blocks / absmax.unsqueeze(-1)
-
-     # Quantize with the lookup table
-     quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8)
-
-     # Pack two quantized values per byte
-     packed = quantized[::2] << 4 | quantized[1::2]
-
-     if quant_storage != torch.uint8:
-         packed = packed.squeeze().view(quant_storage).unsqueeze(1)
-
-     return packed, absmax.float()
-
-
- @register_kernel("bitsandbytes::dequantize_4bit", "cpu")
- def _(
-     A: torch.Tensor,
-     absmax: torch.Tensor,
-     blocksize: int,
-     quant_type: str,
-     shape: Sequence[int],
-     dtype: torch.dtype,
- ) -> torch.Tensor:
-     torch._check_is_size(blocksize)
-     torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
-     torch._check(
-         dtype in [torch.bfloat16, torch.float16, torch.float32],
-         lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
-     )
-     torch._check(
-         A.dtype == torch.uint8,
-         lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
-     )
-
-     A = A.view(-1, 1)
-
-     # Grab upper and lower nibbles. Using int64 for indexing in the LUT.
-     upper = (A >> 4).to(torch.int64)
-     lower = (A & 0x0F).to(torch.int64)
-
-     # Expand to blocks
-     blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
-
-     # Dequantize
-     blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]
-
-     # Reshape to original shape
-     blocks = blocks.reshape(-1, *shape[1:])
-
-     return blocks.to(dtype)
-
-
- @register_kernel("bitsandbytes::gemv_4bit", "cpu")
- def _(
-     A: torch.Tensor,
-     B: torch.Tensor,
-     shapeB: Sequence[int],
-     absmax: torch.Tensor,
-     code: torch.Tensor,
-     blocksize: int,
- ) -> torch.Tensor:
-     # TODO: We need to determine whether `code` is NF4, FP4, or other.
-     # Right now we assume NF4, as this is the only one supported on CPU.
-
-     B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
-         B,
-         absmax,
-         blocksize,
-         "nf4",
-         shape=shapeB,
-         dtype=A.dtype,
-     )
-
-     # User called gemv with B.t(), so we need to transpose it back.
-     # if B.shape[0] == 1:
-     #     B_dq = B_dq.t()
-
-     return torch.nn.functional.linear(
-         A,
-         B_dq,
-         bias=None,
-     )
+ if ipex_cpu:
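+     # Only registered when intel_extension_for_pytorch (IPEX) is available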
+     from bitsandbytes.utils import _reverse_4bit_compress_format
+
+     @register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu")
+     def _(
+         A: torch.Tensor,
+         absmax: torch.Tensor,
+         blocksize: int,
+         shape: Sequence[int],
+         dtype: torch.dtype,
+     ) -> torch.Tensor:
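+         # Unpack the IPEX-prepacked weight and restore the packed layout expected by bitsandbytes::dequantize_4bit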
+         ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2)
+         A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1)
+         return torch.ops.bitsandbytes.dequantize_4bit.default(
+             A,
+             absmax,
+             blocksize,
+             "nf4",
+             shape,
+             dtype,
+         )
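
A minimal round-trip sketch of the new non-fp32 fallback path, assuming the CPU kernels are registered on import; the uniform codebook and blocksize here are illustrative stand-ins, not the library's dynamic map:

    import torch
    import bitsandbytes  # noqa: F401 -- importing registers the CPU kernels

    code = torch.linspace(-1.0, 1.0, 256)  # stand-in 256-entry codebook; real callers pass the dynamic map
    A = torch.randn(10, 10, dtype=torch.bfloat16)  # non-fp32 dtype routes to the pure-PyTorch branch
    q, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, 64)
    A_dq = torch.ops.bitsandbytes.dequantize_blockwise(q, absmax, code, 64, torch.bfloat16)
    assert q.dtype == torch.uint8 and A_dq.shape == A.shape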