
Commit 2b20c24: Merge branch 'dev'
2 parents: 0a3d420 + a08ef4f

9 files changed: +156, -97 lines changed

exllamav2/exllamav2_ext/cuda/cache.cu

Lines changed: 82 additions & 0 deletions
@@ -1,4 +1,7 @@
 #include "cache.cuh"
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_fp16.h>

 #include "quant/qdq_util.cuh"
 #include "util.cuh"
@@ -492,3 +495,82 @@ void array_q_to_fp16_kv_cuda
         dim, offset, stride
     );
 }
+
+#define NUM_THREADS 512
+#define NUM_BLOCKS 128
+#define CEIL_DIVIDE(x, size) (((x) + (size) - 1) / (size))
+
+__global__ __launch_bounds__(NUM_THREADS)
+void cache_rotate_kernel
+(
+    uint8_t* __restrict__ cache,
+    const uint32_t* __restrict__ order,
+    uint8_t* __restrict__ temp,
+    size_t page_size,
+    size_t rotate_len
+)
+{
+    // Chunk for current CTA
+    size_t block_size = CEIL_DIVIDE(page_size, gridDim.x);
+    size_t block_beg = blockIdx.x * block_size;
+    size_t block_end = min(block_beg + block_size, page_size);
+    block_size = block_end - block_beg;
+    if (!block_size) return;
+
+    // Rotate pages
+    auto copy = [&](uint8_t* dst, uint8_t* src)
+    {
+        for (int offset = threadIdx.x * 16; offset < block_size; offset += NUM_THREADS * 16)
+            *((uint4*) (dst + offset)) = *((uint4*) (src + offset));
+    };
+
+    int i;
+    copy(temp + block_beg, cache + page_size * (uint64_t) order[0] + block_beg);
+    for (i = 0; i < rotate_len - 1; ++i)
+        copy(cache + page_size * (uint64_t) order[i] + block_beg, cache + page_size * (uint64_t) order[i + 1] + block_beg);
+    copy(cache + page_size * (uint64_t) order[i] + block_beg, temp + block_beg);
+}
+
+/*
+Reorder cache pages
+- cache, paged cache, shape (num_pages, ...), any dtype, contiguous
+- order, sequence to rotate, shape (n,), dtype long
+- temp, temp storage, sized as one cache page
+
+Performs:
+
+    temp <- page[order[0]]
+    for a, b in pairwise(order):
+        page[a] <- page[b]
+    page[order[-1]] <- temp
+*/
+
+void cache_rotate
+(
+    const at::Tensor& cache,
+    const at::Tensor& order,
+    const at::Tensor& temp
+)
+{
+    const at::cuda::OptionalCUDAGuard device_guard(cache.device());
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
+    TORCH_CHECK(cache.dim() > 1, "cache argument must have dim >= 2")
+    TORCH_CHECK(order.dim() == 1, "order argument must have dim == 1")
+    // TORCH_CHECK_DTYPE(order, kInt);
+
+    size_t num_pages = cache.size(0);
+    size_t page_size = cache.nbytes() / num_pages;
+    size_t rotate_len = order.size(0);
+
+    TORCH_CHECK(temp.nbytes() == page_size, "temp tensor incorrect size");
+
+    cache_rotate_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>
+    (
+        (uint8_t*) cache.data_ptr(),
+        (const uint32_t*) order.data_ptr(),
+        (uint8_t*) temp.data_ptr(),
+        page_size,
+        rotate_len
+    );
+}
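
Note: the comment block above fixes the kernel's contract. `order` names a single rotation cycle over page indices, each consecutive pair shifts one page along the cycle, and `temp` is the one spare page that lets the cycle close. As a rough cross-check, the same rotation can be sketched in plain PyTorch (a toy illustration with invented sizes, not part of the commit; the kernel itself works on raw bytes, each CTA covering a disjoint slice of every page and each thread moving 16-byte uint4 chunks):

# Toy restatement of the cycle performed by cache_rotate_kernel (sizes invented)
import torch

num_pages, page_size = 8, 64
cache = (torch.arange(num_pages * page_size) % 256).to(torch.uint8).view(num_pages, page_size)
order = [3, 5, 2]                       # one rotation cycle over page indices

temp = cache[order[0]].clone()          # temp <- page[order[0]]
for a, b in zip(order, order[1:]):      # for a, b in pairwise(order):
    cache[a] = cache[b]                 #     page[a] <- page[b]
cache[order[-1]] = temp                 # page[order[-1]] <- temp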

exllamav2/exllamav2_ext/cuda/cache.cuh

Lines changed: 9 additions & 0 deletions
@@ -6,6 +6,8 @@
 #include <cstdint>
 #include <cstdio>

+#include <ATen/Tensor.h>
+
 void array_fp16_to_fp8_cuda
 (
     cudaStream_t stream,
@@ -100,4 +102,11 @@ void array_q_to_fp16_kv_paged_cuda
 // void array_fp16_to_fp8_ref_cuda(const half* pIn, unsigned char *pOut, int size);
 // void array_fp8_to_fp16_ref_cuda(const unsigned char* pIn, half* pOut, int size);

+void cache_rotate
+(
+    const at::Tensor& cache,
+    const at::Tensor& order,
+    const at::Tensor& temp
+);
+
 #endif

exllamav2/exllamav2_ext/ext_bindings.cpp

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,8 @@
 #include "ext_element.h"
 #include "ext_tp.h"

+#include "cuda/cache.cuh"
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     // quant
@@ -95,6 +97,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("count_match", &count_match, "count_match");
     // m.def("array_fp16_to_fp8_ref", &array_fp16_to_fp8_ref, "array_fp16_to_fp8_ref");
     // m.def("array_fp8_to_fp16_ref", &array_fp8_to_fp16_ref, "array_fp8_to_fp16_ref");
+    m.def("cache_rotate", &cache_rotate, "cache_rotate");

     // hadamard
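
Note: with the binding in place, the function is callable from Python on tensors shaped to satisfy the TORCH_CHECKs in cache_rotate. A hedged sketch of the calling convention (the import path follows the ext_c alias used elsewhere in the codebase and is assumed here, as are the sizes; requires a CUDA device):

# Hypothetical direct call of the new binding; shapes and import path assumed
import torch
from exllamav2.ext import exllamav2_ext as ext_c

num_pages, page_bytes = 16, 4096
cache = torch.zeros((num_pages, page_bytes), dtype = torch.uint8, device = "cuda")  # dim >= 2, contiguous
order = torch.tensor([3, 5, 2], dtype = torch.int, device = "cuda")                 # matches the uint32_t cast
temp = torch.zeros(page_bytes, dtype = torch.uint8, device = "cuda")                # exactly one page

ext_c.cache_rotate(cache, order, temp)  # page 3 <- page 5 <- page 2 <- old page 3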

exllamav2/generator/dynamic.py

Lines changed: 42 additions & 14 deletions
@@ -1352,27 +1352,34 @@ def defrag_cache(self):
         if not self.paged:
             return

+        # Defragment once job queue is empty after touching all the cache pages
         if self.access_serial < self.last_defrag_serial + self.max_pages:
             return
         self.last_defrag_serial = self.access_serial

         assert not self.referenced_pages

-        @dataclass
         class CacheNode:
             page: CachePage | None
-            parent: CachePage | None = None
-            children: set[CacheNode] = None
-            left_page: int = len(self.all_pages)
+            parent: CacheNode | None
+            children: set[CacheNode] | None
+            children_sorted: deque[CacheNode] | None
+            left_page: int = 0
             def __init__(self, page_):
                 self.page = page_
-                if self.page:
-                    self.left_page = page_.access_serial
+                self.parent = None
                 self.children = set()
+                self.children_sorted = None
+                self.left_page = page_.access_serial if page_ else 0
             def __hash__(self):
                 return id(self)
             def __eq__(self, other):
                 return self is other
+            def presort(self, recursive = True):
+                self.children_sorted = deque(sorted(self.children, key = lambda x: x.left_page))
+                if recursive:
+                    for c in self.children:
+                        c.presort()

         # Build a tree of the current cache
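
Note: the new presort trades the old code's repeated `min(children, key = ...)` scans for a single sort per node into a deque, after which the oldest child is an O(1) popleft. A minimal standalone sketch of the pattern, with invented stand-in values for nodes and their left_page keys:

# Sort once, then pop the smallest element cheaply instead of re-scanning with min()
from collections import deque

children = [("c", 30), ("a", 10), ("b", 20)]    # stand-ins for (node, left_page)
children_sorted = deque(sorted(children, key = lambda x: x[1]))

while children_sorted:
    oldest = children_sorted.popleft()          # replaces min(children, key = ...)
    print(oldest)                               # ("a", 10), then ("b", 20), then ("c", 30)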

@@ -1393,28 +1400,50 @@ def __eq__(self, other):

         # Remove oldest branch until tree is empty

+        root_node.presort()
+        shift_counts = {}
+
         new_page_index = 0
         while root_node.children:
-            oldest = min(root_node.children, key = lambda x: x.left_page)
+            oldest = root_node.children_sorted[0]
             node = oldest
             skipped_nodes = set()
             while True:
                 node.page.new_page_index = new_page_index
+                shift = node.page.new_page_index - node.page.page_index
+                if shift in shift_counts:
+                    shift_counts[shift] += 1
+                else:
+                    shift_counts[shift] = 1
                 new_page_index += 1
                 if not node.children: break
-                next_node = min(node.children, key = lambda x: x.left_page)
-                skipped_nodes |= set([n for n in node.children if n != next_node])
+                next_node = node.children_sorted[0]
+                if len(node.children_sorted) > 1:
+                    skipped_nodes |= set([n for n in node.children if n != next_node])
                 node = next_node
             root_node.children.remove(oldest)
+            root_node.children_sorted.popleft()
             root_node.children |= skipped_nodes
+            if len(skipped_nodes):
+                root_node.presort(False)
+
+        # Adjust overall shift to minimize page copies
+
+        shift_adjust = max(shift_counts, key = shift_counts.get)

         # Order of operations

         defrag_map = {}
         for page in self.all_pages:
+            page.new_page_index = (page.new_page_index - shift_adjust + self.max_pages) % self.max_pages
             if page.page_index != page.new_page_index:
                 defrag_map[page.new_page_index] = page.page_index

+        # Don't bother if less than 10% of cache is fragmented
+
+        if len(defrag_map) <= self.max_pages // 10:
+            return
+
         # Shuffle pages

         cache_tensors = self.cache.all_tensors()
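
Note: the shift bookkeeping works because the defragmented layout is only meaningful up to a cyclic offset. Re-basing every target on the most common `new_page_index - page_index` shift leaves the plurality of pages exactly where they already are, and the 10% threshold then decides whether the remaining moves are worth doing. A toy example with made-up counts:

# Re-basing the layout on the most frequent shift turns that shift into a no-op
max_pages = 8
shift_counts = {0: 1, 3: 5, 5: 2}       # shift -> number of pages moving by that amount
shift_adjust = max(shift_counts, key = shift_counts.get)    # 3, the plurality shift

page_index, new_page_index = 2, 5       # this page was going to move by 3
new_page_index = (new_page_index - shift_adjust + max_pages) % max_pages
assert new_page_index == page_index     # after re-basing, it stays put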
@@ -1435,12 +1464,11 @@ def __eq__(self, other):
                 source = defrag_map[target]
                 del defrag_map[target]

-            rotation = [r * self.page_size for r in rotation]
+            rotation = torch.tensor(rotation, dtype = torch.int)
             for cache, buffer in zip(cache_tensors, defrag_buffers):
-                buffer[:, :, :, :].copy_(cache[:, rotation[0] : rotation[0] + self.page_size, :, :])
-                for a, b in pairwise(rotation):
-                    cache[:, a : a + self.page_size, :, :].copy_(cache[:, b : b + self.page_size, :, :])
-                cache[:, rotation[-1] : rotation[-1] + self.page_size, :, :].copy_(buffer[:, :, :, :])
+                rotation = rotation.to(cache.device)
+                cache = cache.view(cache.shape[1] // self.page_size, -1)
+                ext_c.cache_rotate(cache, rotation, buffer)

         # Update page table
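
Note: `defrag_map` is a destination-to-source permutation over page indices, and the loop this hunk belongs to (mostly outside the visible context) peels it into disjoint cycles, each of which becomes one `rotation` tensor for `ext_c.cache_rotate`. A hedged sketch of that decomposition with invented indices; the actual surrounding code may differ in details this diff does not show:

# Decompose a destination -> source map into rotation cycles (toy values)
defrag_map = {0: 2, 2: 5, 5: 0, 1: 3, 3: 1}

while defrag_map:
    target = next(iter(defrag_map))
    rotation = []
    while target in defrag_map:
        rotation.append(target)
        target = defrag_map.pop(target)  # follow the chain source by source
    print(rotation)                      # [0, 2, 5], then [1, 3]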

exllamav2/tokenizer/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 from exllamav2.version import __version__

 from exllamav2.tokenizer.base import ExLlamaV2TokenizerBase
-from exllamav2.tokenizer.spm import ExLlamaV2TokenizerSPM
 from exllamav2.tokenizer.hf import ExLlamaV2TokenizerHF

exllamav2/tokenizer/spm.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

exllamav2/tokenizer/tokenizer.py

Lines changed: 19 additions & 23 deletions
@@ -5,7 +5,6 @@
 import os, json, re
 from exllamav2.tokenizer import (
     ExLlamaV2TokenizerBase,
-    ExLlamaV2TokenizerSPM,
     ExLlamaV2TokenizerHF
 )
 import threading
@@ -93,13 +92,12 @@ def __init__(
         Defer initialization of some data structures to speed up loading

     :param force_json:
-        No effect from v0.2.3. tokenizer.json is now preferred over tokenizer.model by default.
-        If True and no tokenizer.json is present in the model directory, will emit a warning before
-        falling back to SPM
+        No effect from v0.2.3. tokenizer.json is now preferred over tokenizer.model by default. From v0.3.1
+        tokenizer.model is not used at all

     :param force_spm:
-        Use only tokenizer.model (SentencePiece) even if tokenizer.model (HF Tokenizers)
-        is available
+        Deprecated, Sentencepiece is abandoned and no longer supported. All SPM tokenizers should
+        still load correctly via the Tokenizers library
     """

     self.config = config
@@ -123,33 +121,31 @@ def __init__(

         # Detect tokenizer model type and initialize

-        path_spm = os.path.join(self.config.model_dir, "tokenizer.model")
+        assert not force_spm, "tokenizer.py: force_spm is deprecated. Sentencepiece is no longer supported."
         path_hf = os.path.join(self.config.model_dir, "tokenizer.json")

-        if os.path.exists(path_hf) and not force_spm:
-            self.tokenizer_model = ExLlamaV2TokenizerHF(path_hf)
-        elif os.path.exists(path_spm):
-            if force_json:
-                print(" !! Warning: Tokenizer loading with force_json = True but no tokenizer.json found, falling back to tokenizer.model")
-            self.tokenizer_model = ExLlamaV2TokenizerSPM(path_spm)
-        else:
+        if not os.path.exists(path_hf):
             raise FileNotFoundError("No supported tokenizer found.")

+        self.tokenizer_model = ExLlamaV2TokenizerHF(path_hf)
+
         # Attempt to load added tokens from tokenizer.json

         self.extended_piece_to_id = {}
         self.unspecial_piece_to_id = {}

         tokenizer_json_path = os.path.join(self.config.model_dir, "tokenizer.json")
-        if os.path.exists(tokenizer_json_path):
-            with open(tokenizer_json_path, encoding = "utf8") as f:
-                tokenizer_json = json.load(f)
-            if "added_tokens" in tokenizer_json:
-                for v in tokenizer_json["added_tokens"]:
-                    if v["special"]:
-                        self.extended_piece_to_id[v["content"]] = v["id"]
-                    else:
-                        self.unspecial_piece_to_id[v["content"]] = v["id"]
+        if not os.path.exists(tokenizer_json_path):
+            raise ValueError(" ## Model does not include a tokenizer.json file. SentencePiece-only tokenizers are no longer supported")
+
+        with open(tokenizer_json_path, encoding = "utf8") as f:
+            tokenizer_json = json.load(f)
+        if "added_tokens" in tokenizer_json:
+            for v in tokenizer_json["added_tokens"]:
+                if v["special"]:
+                    self.extended_piece_to_id[v["content"]] = v["id"]
+                else:
+                    self.unspecial_piece_to_id[v["content"]] = v["id"]

         # Attempt to load tokenizer_config.json
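
Note: the added-token loop assumes each entry of tokenizer.json's "added_tokens" list carries at least "content", "id" and "special" fields. A self-contained sketch of the same split, with invented token values:

# Mirror of the added-token split in tokenizer.py (token data invented)
tokenizer_json = {
    "added_tokens": [
        {"id": 32000, "content": "<|im_start|>", "special": True},
        {"id": 32002, "content": "<custom_piece>", "special": False},
    ]
}

extended_piece_to_id = {}       # special added pieces
unspecial_piece_to_id = {}      # regular added pieces
for v in tokenizer_json.get("added_tokens", []):
    if v["special"]:
        extended_piece_to_id[v["content"]] = v["id"]
    else:
        unspecial_piece_to_id[v["content"]] = v["id"]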

exllamav2/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "0.3.0"
+__version__ = "0.3.1"

requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@ setuptools
 fastparquet
 torch>=2.2.0
 safetensors>=0.4.3
-sentencepiece>=0.1.97
 pygments
 websockets
 regex
