
Commit 1a82283

Merge branch 'refs/heads/dev'
2 parents 57ee846 + f1d8909

14 files changed: +130 -35 lines

exllamav2/exllamav2_ext/cpp/safetensors.cpp

Lines changed: 62 additions & 1 deletion

@@ -453,4 +453,65 @@ void safetensors_read_fb(uintptr_t handle, size_t beg, size_t size, torch::Tensor target)
             remaining -= chunk;
         }
     }
-}
+}
+
+void tensor_remap
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+)
+{
+    TORCH_CHECK_SHAPES(tensor, 1, index, 0, 1);
+    TORCH_CHECK_DTYPE(tensor, kInt);
+    TORCH_CHECK_DTYPE(index, kInt);
+
+    int rows = tensor.size(0);
+    int cols = tensor.size(1);
+    uint32_t* temp = (uint32_t*) calloc(cols, sizeof(int));
+    uint32_t* a = (uint32_t*) tensor.data_ptr();
+    uint32_t* idx = (uint32_t*) index.data_ptr();
+
+    for (int r = 0; r < rows; ++r)
+    {
+        memcpy(temp, a, sizeof(uint32_t) * cols);
+        for (int c = 0; c < cols; ++c)
+        {
+            *a++ = temp[idx[c]];
+        }
+    }
+    free(temp);
+}
+
+void tensor_remap_4bit
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+)
+{
+    TORCH_CHECK_SHAPES(index, 0, tensor, 1, 8);
+    TORCH_CHECK_DTYPE(tensor, kInt);
+    TORCH_CHECK_DTYPE(index, kInt);
+
+    int rows = tensor.size(0);
+    int cols = index.size(0);
+    uint32_t* temp = (uint32_t*) calloc(cols / 8, sizeof(int));
+    uint32_t* a = (uint32_t*) tensor.data_ptr();
+    uint32_t* idx = (uint32_t*) index.data_ptr();
+
+    for (int r = 0; r < rows; ++r)
+    {
+        memcpy(temp, a, sizeof(uint32_t) * cols / 8);
+        for (int c = 0; c < cols;)
+        {
+            uint32_t rv = 0;
+            for (int b = 0; b < 8; ++b, ++c)
+            {
+                uint32_t i = idx[c];
+                uint32_t v = (temp[i / 8] >> ((i & 7) * 4) & 0x0f);
+                rv |= v << (b * 4);
+            }
+            *a++ = rv;
+        }
+    }
+    free(temp);
+}
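For orientation: tensor_remap permutes the columns of an int32 matrix in place (new[r][c] = old[r][idx[c]]), and tensor_remap_4bit applies the same column permutation to values packed eight per int32 as 4-bit nibbles. A rough pure-PyTorch sketch of the semantics, not of the implementation (function names here are illustrative):

    import torch

    def remap_ref(t: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # Same permutation tensor_remap applies in place: new[r, c] = old[r, index[c]]
        return t[:, index.long()]

    def remap_4bit_ref(packed: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # Unpack eight 4-bit nibbles per int32, permute columns, repack
        rows = packed.shape[0]
        shifts = torch.arange(8, dtype = torch.int32) * 4
        nibbles = (packed.unsqueeze(-1) >> shifts) & 0x0F     # (rows, cols/8, 8)
        nibbles = nibbles.reshape(rows, -1)[:, index.long()]  # (rows, cols), permuted
        nibbles = nibbles.reshape(rows, -1, 8)
        return (nibbles << shifts).sum(-1, dtype = torch.int32)

Unlike the sketch, the C++ versions reuse a single row-sized scratch buffer and write back in place, so no per-call tensor allocation is needed.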

exllamav2/exllamav2_ext/cpp/safetensors.h

Lines changed: 13 additions & 0 deletions

@@ -47,4 +47,17 @@ uintptr_t safetensors_open_fb(const char* filename);
 void safetensors_close_fb(uintptr_t handle);
 void safetensors_read_fb(uintptr_t handle, size_t beg, size_t size, torch::Tensor target);

+void tensor_remap
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+);
+
+void tensor_remap_4bit
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+);
+
+
 #endif

exllamav2/exllamav2_ext/cuda/graph.cu

Lines changed: 12 additions & 9 deletions

@@ -133,7 +133,7 @@ void Graph::attach_label(cudaStream_t stream, int label, int sublabel)
 }

 template <typename T>
-void Graph::update_param(int label, int sublabel, int param, T value)
+void Graph::update_param(int label, int sublabel, int param, T value, bool debug)
 {
     for (int i = 0; i < node_labels.size(); ++i)
     {
@@ -145,19 +145,22 @@ void Graph::update_param(int label, int sublabel, int param, T value)

         node_needs_update[i] = true;

-        // printf("-----------------------------------------------------\n");
-        // printf("UPDATED:\n");
-        // DBGI(i);
-        // inspect_graph();
+        if (debug)
+        {
+            printf("-----------------------------------------------------\n");
+            printf("UPDATED: ");
+            DBGI(i);
+            inspect_graph();
+        }
     }
 }

-void Graph::update_param_ptr(int label, int sublabel, int param, void* value)
+void Graph::update_param_ptr(int label, int sublabel, int param, void* value, bool debug)
 {
-    update_param<void*>(label, sublabel, param, value);
+    update_param<void*>(label, sublabel, param, value, debug);
 }

-void Graph::update_param_int(int label, int sublabel, int param, int value)
+void Graph::update_param_int(int label, int sublabel, int param, int value, bool debug)
 {
-    update_param<int>(label, sublabel, param, value);
+    update_param<int>(label, sublabel, param, value, debug);
 }

exllamav2/exllamav2_ext/cuda/graph.cuh

Lines changed: 3 additions & 3 deletions

@@ -46,10 +46,10 @@ public:
     void attach_label(cudaStream_t stream, int label, int sublabel);

     template <typename T>
-    void update_param(int label, int sublabel, int param, T value);
+    void update_param(int label, int sublabel, int param, T value, bool debug);

-    void update_param_ptr(int label, int sublabel, int param, void* value);
-    void update_param_int(int label, int sublabel, int param, int value);
+    void update_param_ptr(int label, int sublabel, int param, void* value, bool debug = false);
+    void update_param_int(int label, int sublabel, int param, int value, bool debug = false);
 };

exllamav2/exllamav2_ext/cuda/q_mlp.cu

Lines changed: 2 additions & 2 deletions

@@ -109,7 +109,7 @@ void QMLP::forward_
     if (graph->count())
     {
         graph->begin_capture(stream);
-        forward_run_(stream, cublas_handle, (half*) x, rows, columns, loras, lora_temp, graph);
+        forward_run_(stream, cublas_handle, (void*) x, rows, columns, loras, lora_temp, graph);
         graph->end_capture(stream);
         // printf("**** record ****\n");
         // DBGI2(rows, columns);
@@ -225,7 +225,7 @@ void QMLP::forward_run_

     else
     {
-        gemm_half_q_half_cuda(stream, cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq, graph, 0);
+        gemm_half_q_half_cuda(stream, cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq, false, NULL, 0, false, graph, 0);
         if (layernorm_is_rms)
             rms_norm_cuda(stream, temp_state, post_layernorm, x, norm_epsilon, rows, columns, true, false, residual_fp32, graph, KernelLabels::POST_NORM);
         else

exllamav2/exllamav2_ext/ext_bindings.cpp

Lines changed: 2 additions & 0 deletions

@@ -55,6 +55,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("safetensors_pinned_buffer", &safetensors_pinned_buffer, "safetensors_pinned_buffer");
     m.def("safetensors_free_pinned_buffer", &safetensors_free_pinned_buffer, "safetensors_free_pinned_buffer");
     m.def("safetensors_read_fb", &safetensors_read_fb, "safetensors_read_fb");
+    m.def("tensor_remap", &tensor_remap, "tensor_remap");
+    m.def("tensor_remap_4bit", &tensor_remap_4bit, "tensor_remap_4bit");

     // qmatrix
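Once the extension is rebuilt, the new functions are callable from Python. A minimal smoke test, assuming the compiled module is imported as ext_c (the convention used in exllamav2/linear.py below):

    import torch
    from exllamav2.ext import exllamav2_ext as ext_c

    t = torch.randint(0, 1 << 30, (4, 16), dtype = torch.int)  # CPU int32; the C++ reads raw host pointers
    perm = torch.randperm(16).to(torch.int)
    ref = t[:, perm.long()].clone()
    ext_c.tensor_remap(t, perm)  # permutes the columns of t in place
    assert torch.equal(t, ref)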

exllamav2/ext.py

Lines changed: 2 additions & 2 deletions

@@ -173,9 +173,9 @@ def find_msvc():
 # gcc / cl.exe flags

 if windows:
-    extra_cflags = ["/Ox", "/openmp"]
+    extra_cflags = ["/Ox"]
 else:
-    extra_cflags = ["-Ofast", "-fopenmp"]
+    extra_cflags = ["-Ofast"]

 if ext_debug:
     extra_cflags += ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]

exllamav2/fasttensors.py

Lines changed: 8 additions & 1 deletion

@@ -189,6 +189,8 @@ def get_tensor(self,
                    out_dtype = None) -> torch.Tensor:
         global global_tensorcache

+        torch.cuda.synchronize()
+
         if self.tensor_remap and (not_fast or not self.fast):
             key = self.tensor_remap[key]

@@ -211,6 +213,8 @@ def get_tensor(self,
             size = end - beg
             numel = size // esize
             shape = h["shape"]
+            if device != "cpu":
+                torch.cuda.set_stream(torch.cuda.default_stream(device))
             tensor = torch.zeros(shape, dtype = dtype, device = device)
             assert tensor.is_contiguous, "Non-contiguous tensor"
             ext_c.safetensors_read_fb(self.handle_fb, beg + self.header_size, size, tensor)
@@ -224,7 +228,8 @@ def get_tensor(self,
             offset = data_offsets[0] + self.header_size
             length = data_offsets[1] - data_offsets[0]
             assert np.prod(sh) * dts == length, f"Tensor shape doesn't match storage size: {key}"
-
+            if device != "cpu":
+                torch.cuda.set_stream(torch.cuda.default_stream(device))
             tensor = torch.empty(sh, device = device, dtype = dtt)
             ext_c.safetensors_load(self.handle, tensor, offset, length)

@@ -236,4 +241,6 @@ def get_tensor(self,
                 global_tensorcache = global_tensorcache[1:]
             global_tensorcache.append((cachekey, tensor))

+        torch.cuda.synchronize()
+
         return tensor
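The synchronize/set_stream additions pin the tensor allocation, and the fill done by the extension outside CUDA stream ordering, to the device's default stream, with a full fence before and after, so work queued on other streams cannot race the load. A condensed sketch of the pattern with illustrative shapes (not code from the commit):

    import torch

    device = torch.device("cuda:0")
    torch.cuda.synchronize()                                  # drain any in-flight work first
    torch.cuda.set_stream(torch.cuda.default_stream(device))  # leave any side stream
    tensor = torch.empty((4096, 4096), dtype = torch.float16, device = device)
    # ...the C extension fills `tensor` here, outside stream ordering...
    torch.cuda.synchronize()                                  # make the fill visible to all streams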

exllamav2/linear.py

Lines changed: 7 additions & 4 deletions

@@ -8,6 +8,7 @@
 from exllamav2.compat import safe_move_tensor
 from exllamav2.tensor_p import BROADCAST_VC
 from exllamav2.util import unpack_4bit, pack_4bit
+import gc

 from typing import TYPE_CHECKING

@@ -118,7 +119,7 @@ def load(self,
         cfg = self.model.config

         if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features, self.altpack_qkv)
-        if w is None: w = self.load_weight()
+        if w is None: w = self.load_weight(cpu = output_map is not None)

         # Load quantized linear layer from dictionary

@@ -137,7 +138,7 @@ def load(self,
             self.q_tensors = w

             if unmap and "q_perm" in w:
-                perm = w["q_perm"]
+                perm = w["q_perm"].cpu()
                 del w["q_perm"]
                 del w["q_invperm"]
                 # w["q_perm"] = torch.arange(0, w["q_perm"].shape[-1], dtype = w["q_perm"].dtype, device = w["q_perm"].device)
@@ -146,8 +147,10 @@ def load(self,
                 perm = None

             if output_map is not None:
-                w["q_weight"] = w["q_weight"][:, output_map]
-                w["q_scale"] = pack_4bit(unpack_4bit(w["q_scale"])[:, output_map])
+                ext_c.tensor_remap(w["q_weight"], output_map)
+                ext_c.tensor_remap_4bit(w["q_scale"], output_map)
+                for k in w.keys():
+                    w[k] = safe_move_tensor(w[k], self.device())

             self.q_handle = ext.make_q_matrix(w,
                 self.temp_dq,
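The two ext_c calls replace the previous pure-Python remap and run on CPU copies before the tensors are moved to the device (hence load_weight(cpu = ...) above). On matching inputs the in-place remap should agree with the old indexing path; a hypothetical equivalence check, assuming w holds CPU int32 tensors and output_map is a CPU int32 permutation:

    import torch
    from exllamav2.ext import exllamav2_ext as ext_c
    from exllamav2.util import unpack_4bit, pack_4bit

    ref_weight = w["q_weight"][:, output_map.long()].clone()  # old q_weight path
    ref_scale = pack_4bit(unpack_4bit(w["q_scale"])[:, output_map.long()])  # old q_scale path

    ext_c.tensor_remap(w["q_weight"], output_map)        # in place
    ext_c.tensor_remap_4bit(w["q_scale"], output_map)    # in place

    assert torch.equal(w["q_weight"], ref_weight)
    assert torch.equal(w["q_scale"], ref_scale)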

exllamav2/model.py

Lines changed: 4 additions & 0 deletions

@@ -989,6 +989,10 @@ def forward_chunk(self,
         if self.tp_context:
             self.tp_context.wait_streams()

+        if x is not None and x.is_cuda:
+            context = self.get_device_context(x.device.index)
+            torch.cuda.set_stream(context.stream)
+
         # Apply logit scale

         # if x is not None and self.config.logit_scale != 1:
