
Commit 9f7f69d

l0rinc (Lőrinc) and Lőrinc authored
Add possessive quantifiers to avoid catastrophic backtracking (#258)

Fixes the crash in #245 by prohibiting the regex engine from backtracking catastrophically via [possessive quantifiers](https://www.regular-expressions.info/possessive.html).

Interestingly, these possessives also make the encoding a lot faster again in `fancy-regex`.

Before this change (but with the large byte pair merge PR cherry-picked):

```
num_threads: 1, num_bytes: 98379553
tiktoken 11,946,036 bytes / s
tiktoken 11,961,343 bytes / s
tiktoken 11,995,846 bytes / s
tiktoken 11,951,263 bytes / s
tiktoken 11,983,405 bytes / s
```

Same, with these changes applied:

```
num_threads: 1, num_bytes: 98379553
tiktoken 14,511,827 bytes / s
tiktoken 14,638,134 bytes / s
tiktoken 14,644,029 bytes / s
tiktoken 14,729,030 bytes / s
tiktoken 14,666,903 bytes / s
```

Updating the regex libs makes it a tiny bit faster still:

```
num_threads: 1, num_bytes: 98379553
tiktoken 14,485,590 bytes / s
tiktoken 14,854,049 bytes / s
tiktoken 14,891,086 bytes / s
tiktoken 14,843,007 bytes / s
tiktoken 14,874,520 bytes / s
```

This is almost 2x faster than [before any of the optimizations](#234).

---

Opened an issue for increasing the [default backtrack limit](https://github.com/fancy-regex/fancy-regex/blob/bf2c807447f72ee20ae839e0f8cb3a06fc79982c/src/lib.rs#L407) (see fancy-regex/fancy-regex#134), but it shouldn't be necessary here anymore.

Co-authored-by: Lőrinc <lorinc.pap@gmail.com>
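As background, here is a minimal sketch of the failure mode being fixed. It is not part of this commit, and it assumes fancy-regex 0.13's `RegexBuilder` API and its possessive-quantifier support on groups. The ambiguous alternation below can decompose `"abab…"` in exponentially many ways once the lookahead fails; the possessive variant refuses to re-enter the group, so the failure is reported after one pass.

```rust
use fancy_regex::RegexBuilder;

fn main() {
    // No trailing 'c', so the lookahead fails and forces backtracking.
    let input = "ab".repeat(30) + "!";

    // Greedy: "ab" can match as 'ab' or as 'a' then 'b', so on failure the
    // engine retries exponentially many decompositions and hits the limit.
    let greedy = RegexBuilder::new(r"(?:a|b|ab)*(?=c)")
        .backtrack_limit(100_000)
        .build()
        .unwrap();
    assert!(greedy.is_match(&input).is_err());

    // Possessive: once the starred group has consumed the input it gives
    // nothing back, so the overall failure is reported after a linear scan.
    let possessive = RegexBuilder::new(r"(?:a|b|ab)*+(?=c)")
        .backtrack_limit(100_000)
        .build()
        .unwrap();
    assert_eq!(possessive.is_match(&input).unwrap(), false);
}
```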
1 parent: c0ba74c

File tree: 4 files changed (+43 −11 lines)


Cargo.toml

Lines changed: 2 additions & 2 deletions

```diff
@@ -12,7 +12,7 @@ crate-type = ["cdylib"]
 pyo3 = { version = "0.20.0", features = ["extension-module"] }
 
 # tiktoken dependencies
-fancy-regex = "0.11.0"
-regex = "1.8.3"
+fancy-regex = "0.13.0"
+regex = "1.10.3"
 rustc-hash = "1.1.0"
 bstr = "1.5.0"
```

src/lib.rs

Lines changed: 15 additions & 1 deletion

```diff
@@ -6,6 +6,7 @@ use std::num::NonZeroU64;
 use std::thread;
 
 use fancy_regex::Regex;
+use fancy_regex::RegexBuilder;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::pyclass;
@@ -417,7 +418,7 @@ impl CoreBPE {
         special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
     ) -> PyResult<Self> {
-        let regex = Regex::new(pattern)
+        let regex = RegexBuilder::new(pattern).backtrack_limit(10_000).build()
             .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
 
         let special_regex = {
@@ -572,6 +573,7 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
 
 #[cfg(test)]
 mod tests {
+    use fancy_regex::RegexBuilder;
     use rustc_hash::FxHashMap as HashMap;
 
     use crate::{byte_pair_split, Rank};
@@ -596,4 +598,16 @@ mod tests {
         let res = byte_pair_split(b"abab", &ranks);
         assert_eq!(res, vec![b"ab", b"ab"]);
     }
+
+    #[test]
+    fn test_effect_of_backtrack_limit() {
+        let regex = RegexBuilder::new(r"(a|b|ab)*(?=c)")
+            .backtrack_limit(10)
+            .build()
+            .expect("Failed to build regex")
+            .clone();
+
+        let input = "ab".repeat(100) + "c";
+        assert!(regex.is_match(&input).is_err(), "Should throw");
+    }
 }
```
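For reference, hitting the new 10_000 backtrack limit surfaces as an ordinary `Err` from the match call rather than a hang. A hypothetical sketch of how a caller could detect that case, assuming fancy-regex 0.13's error enums:

```rust
use fancy_regex::{Error, RegexBuilder, RuntimeError};

fn main() {
    // Same pattern and input as test_effect_of_backtrack_limit above.
    let re = RegexBuilder::new(r"(a|b|ab)*(?=c)")
        .backtrack_limit(10)
        .build()
        .expect("Failed to build regex");
    let input = "ab".repeat(100) + "c";

    match re.is_match(&input) {
        // The engine gives up once it exceeds the configured limit.
        Err(Error::RuntimeError(RuntimeError::BacktrackLimitExceeded)) => {
            eprintln!("regex gave up instead of backtracking forever");
        }
        other => eprintln!("unexpected result: {other:?}"),
    }
}
```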

tests/test_encoding.py

Lines changed: 16 additions & 0 deletions

```diff
@@ -11,6 +11,22 @@
 from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES
 
 
+@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
+def test_extremely_big_encoding(make_enc: Callable[[], tiktoken.Encoding]):
+    enc = make_enc()
+    for c in ["^", "0", "a", "'s", " ", "\n"]:
+        print(f"Validating `{c}`")
+
+        big_value = c * 10_000
+        assert big_value == enc.decode(enc.encode(big_value))
+
+        big_value = " " + big_value
+        assert big_value == enc.decode(enc.encode(big_value))
+
+        big_value = big_value + "\n"
+        assert big_value == enc.decode(enc.encode(big_value))
+
+
 def test_simple():
     enc = tiktoken.get_encoding("gpt2")
     assert enc.encode("hello world") == [31373, 995]
```
tiktoken_ext/openai_public.py

Lines changed: 10 additions & 8 deletions

```diff
@@ -6,6 +6,11 @@
 FIM_SUFFIX = "<|fim_suffix|>"
 ENDOFPROMPT = "<|endofprompt|>"
 
+# The pattern in the original GPT-2 release is:
+# r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+# This is equivalent, but executes faster:
+_legacy_splitter_regex = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s"""
+
 
 def gpt2():
     mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
@@ -17,10 +22,7 @@ def gpt2():
     return {
         "name": "gpt2",
         "explicit_n_vocab": 50257,
-        # The pattern in the original GPT-2 release is:
-        # r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
-        # This is equivalent, but executes faster:
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -34,7 +36,7 @@ def r50k_base():
     return {
         "name": "r50k_base",
         "explicit_n_vocab": 50257,
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -48,7 +50,7 @@ def p50k_base():
     return {
         "name": "p50k_base",
         "explicit_n_vocab": 50281,
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -62,7 +64,7 @@ def p50k_edit():
     special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
     return {
         "name": "p50k_edit",
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": special_tokens,
     }
@@ -82,7 +84,7 @@ def cl100k_base():
     }
     return {
         "name": "cl100k_base",
-        "pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
+        "pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s""",
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": special_tokens,
     }
```
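The comment's claim that `_legacy_splitter_regex` splits text identically to the original GPT-2 pattern can be spot-checked. A quick sketch, not part of the commit, with a sample string chosen here purely for illustration:

```rust
use fancy_regex::Regex;

fn main() {
    // Original GPT-2 pattern vs. the possessive rewrite from this commit.
    let old = Regex::new(
        r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+",
    )
    .unwrap();
    let new = Regex::new(
        r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s",
    )
    .unwrap();

    // Exercises contractions, double spaces, digits, punctuation, newlines.
    let text = "we'll go  to 123 places!\n\n";
    let split = |re: &Regex| -> Vec<String> {
        re.find_iter(text)
            .map(|m| m.unwrap().as_str().to_string())
            .collect()
    };
    // Both patterns should produce the same pre-tokenization splits.
    assert_eq!(split(&old), split(&new));
}
```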
