From 00906e3d37131cce5dc2e64d0b878e6a2432b727 Mon Sep 17 00:00:00 2001
From: Saood Karim
Date: Sun, 12 Jan 2025 08:31:08 -0600
Subject: [PATCH] Deepseek V3 support added
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Stanisław Szymczyk
---
 convert_hf_to_gguf.py          |   3 +
 convert_hf_to_gguf_update.py   |   1 +
 gguf-py/gguf/constants.py      |   9 +++
 gguf-py/gguf/gguf_writer.py    |   7 +++
 gguf-py/gguf/tensor_mapping.py |   4 ++
 include/llama.h                |   1 +
 src/llama-vocab.cpp            |   7 +++
 src/llama.cpp                  | 100 +++++++++++++++++++++++++++++++--
 src/unicode.cpp                |   9 ++-
 9 files changed, 136 insertions(+), 5 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index b470a0883..3910aa1dc 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -590,6 +590,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "855059429035d75a914d1eda9f10a876752e281a054a7a3d421ef0533e5b6249":
             # ref: https://huggingface.co/HuggingFaceTB/SmolLM-135M
             res = "smollm"
+        if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
+            # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3
+            res = "deepseek-v3"
 
         if res is None:
             logger.warning("\n")
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index d5a2d925e..40af02f46 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -94,6 +94,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "codeshell",   "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
     {"name": "tekken",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
     {"name": "smollm",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
+    {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
 ]
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 1bea66aa0..90d5efec2 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -89,6 +89,8 @@ class LLM:
         EXPERT_USED_COUNT        = "{arch}.expert_used_count"
         EXPERT_SHARED_COUNT      = "{arch}.expert_shared_count"
         EXPERT_WEIGHTS_SCALE     = "{arch}.expert_weights_scale"
+        EXPERT_WEIGHTS_NORM      = "{arch}.expert_weights_norm"
+        EXPERT_GATING_FUNC       = "{arch}.expert_gating_func"
         POOLING_TYPE             = "{arch}.pooling_type"
         LOGIT_SCALE              = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID   = "{arch}.decoder_start_token_id"
@@ -257,6 +259,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_SHEXP   = auto()
     FFN_DOWN_SHEXP   = auto()
     FFN_UP_SHEXP     = auto()
+    FFN_EXP_PROBS_B  = auto()
     ATTN_Q_NORM      = auto()
     ATTN_K_NORM      = auto()
     LAYER_OUT_NORM   = auto()
@@ -387,6 +390,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_GATE_EXP:     "blk.{bid}.ffn_gate_exps",
     MODEL_TENSOR.FFN_DOWN_EXP:     "blk.{bid}.ffn_down_exps",
     MODEL_TENSOR.FFN_UP_EXP:       "blk.{bid}.ffn_up_exps",
+    MODEL_TENSOR.FFN_EXP_PROBS_B:  "blk.{bid}.exp_probs_b",
     MODEL_TENSOR.LAYER_OUT_NORM:   "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.SSM_IN:           "blk.{bid}.ssm_in",
     MODEL_TENSOR.SSM_CONV1D:       "blk.{bid}.ssm_conv1d",
@@ -978,6 +982,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE_SHEXP,
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B
     ],
     MODEL_ARCH.CHATGLM : [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -1177,6 +1182,10 @@ class GGMLQuantizationType(IntEnum):
     IQ2_TN  = 42,
 
 
+class ExpertGatingFuncType(IntEnum):
+    SOFTMAX = 1
+    SIGMOID = 2
+
 
 # TODO: add GGMLFileType from ggml_ftype in ggml.h
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 76385a828..e31bf97b1 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -26,6 +26,7 @@
     RopeScalingType,
     PoolingType,
     TokenType,
+    ExpertGatingFuncType,
 )
 
 from .quants import quant_shape_from_byte_shape
@@ -670,6 +671,12 @@ def add_expert_shared_count(self, count: int) -> None:
     def add_expert_weights_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
 
+    def add_expert_weights_norm(self, value: bool) -> None:
+        self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
+
+    def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
 
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 9aa2209e2..a70b69c5a 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -251,6 +251,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
         ),
 
+        MODEL_TENSOR.FFN_EXP_PROBS_B: (
+            "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
+        ),
+
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
diff --git a/include/llama.h b/include/llama.h
index f5f3b8bf9..3a81dce42 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -93,6 +93,7 @@ extern "C" {
        LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
        LLAMA_VOCAB_PRE_TYPE_SMOLLM         = 21,
        LLAMA_VOCAB_PRE_TYPE_CODESHELL      = 22,
+       LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 23, // llama.cpp lists this as 28
     };
 
     // note: these values should be synchronized with ggml_rope
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 749f85718..4bd5aa815 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -367,6 +367,13 @@ struct llm_tokenizer_bpe {
                     "\\p{N}+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM:
+                regex_exprs = {
+                    "\\p{N}{1,3}",
+                    "[一-龥぀-ゟ゠-ヿ]+",
+                    "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
                 regex_exprs = {
                     "[\r\n]",
diff --git a/src/llama.cpp b/src/llama.cpp
index b983c84b4..29bd14aff 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -106,7 +106,7 @@
 
 // bump if necessary
 #define LLAMA_MAX_LAYERS  512
-#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
+#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
 
 //
 // helpers
@@ -294,6 +294,8 @@ enum llm_kv {
     LLM_KV_EXPERT_USED_COUNT,
     LLM_KV_EXPERT_SHARED_COUNT,
     LLM_KV_EXPERT_WEIGHTS_SCALE,
+    LLM_KV_EXPERT_WEIGHTS_NORM,
+    LLM_KV_EXPERT_GATING_FUNC,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
@@ -399,6 +401,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_USED_COUNT,      "%s.expert_used_count"      },
     { LLM_KV_EXPERT_SHARED_COUNT,    "%s.expert_shared_count"    },
     { LLM_KV_EXPERT_WEIGHTS_SCALE,   "%s.expert_weights_scale"   },
+    { LLM_KV_EXPERT_WEIGHTS_NORM,    "%s.expert_weights_norm"    },
+    { LLM_KV_EXPERT_GATING_FUNC,     "%s.expert_gating_func"     },
     { LLM_KV_POOLING_TYPE ,          "%s.pooling_type"           },
     { LLM_KV_LOGIT_SCALE,            "%s.logit_scale"            },
     { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@@ -520,6 +524,7 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN_SHEXP,
     LLM_TENSOR_FFN_GATE_SHEXP,
     LLM_TENSOR_FFN_UP_SHEXP,
+    LLM_TENSOR_FFN_EXP_PROBS_B,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,
@@ -1211,6 +1216,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_GATE_SHEXP,  "blk.%d.ffn_gate_shexp" },
             { LLM_TENSOR_FFN_DOWN_SHEXP,  "blk.%d.ffn_down_shexp" },
             { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
         },
     },
     {
@@ -2186,6 +2192,7 @@ enum e_model {
     MODEL_70B,
     MODEL_236B,
     MODEL_314B,
+    MODEL_671B,
     MODEL_SMALL,
     MODEL_MEDIUM,
     MODEL_LARGE,
@@ -2203,6 +2210,21 @@ static const size_t kiB = 1024;
 static const size_t MiB = 1024*kiB;
 static const size_t GiB = 1024*MiB;
 
+enum llm_expert_gating_func_type {
+    LLM_EXPERT_GATING_FUNC_SOFTMAX = 1,
+    LLM_EXPERT_GATING_FUNC_SIGMOID = 2,
+};
+
+static const char * llama_expert_gating_func_name(llm_expert_gating_func_type type) {
+    switch (type) {
+        case LLM_EXPERT_GATING_FUNC_SOFTMAX: return "softmax";
+        case LLM_EXPERT_GATING_FUNC_SIGMOID: return "sigmoid";
+        default: return "unknown";
+    }
+}
+
+
+
 struct llama_hparams {
     bool vocab_only;
     bool rope_finetuned;
@@ -2232,6 +2254,8 @@ struct llama_hparams {
     uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float    expert_weights_scale = 0.0;
+    bool     expert_weights_norm  = false;
+    uint32_t expert_gating_func   = LLM_EXPERT_GATING_FUNC_SOFTMAX;
 
     float f_norm_eps;
     float f_norm_rms_eps;
@@ -2502,6 +2526,7 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_b = nullptr; // b2
     struct ggml_tensor * ffn_up_b   = nullptr; // b3
     struct ggml_tensor * ffn_act;
+    struct ggml_tensor * ffn_exp_probs_b = nullptr;
 
     // mamba proj
     struct ggml_tensor * ssm_in;
@@ -4677,6 +4702,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_70B:     return "70B";
         case MODEL_236B:    return "236B";
         case MODEL_314B:    return "314B";
+        case MODEL_671B:    return "671B";
         case MODEL_SMALL:   return "0.1B";
         case MODEL_MEDIUM:  return "0.4B";
         case MODEL_LARGE:   return "0.8B";
@@ -5302,11 +5328,19 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                 ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,        hparams.n_expert_shared);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,       hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,        hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,         hparams.expert_gating_func, false);
+                if (hparams.expert_gating_func == 0) {
+                    // for compatibility with existing DeepSeek V2 and V2.5 GGUFs
+                    // that have no expert_gating_func model parameter set
+                    hparams.expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX;
+                }
                 ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL,  hparams.rope_yarn_log_mul);
 
                 switch (hparams.n_layer) {
                     case 27: model.type = e_model::MODEL_16B; break;
                     case 60: model.type = e_model::MODEL_236B; break;
+                    case 61: model.type = e_model::MODEL_671B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
@@ -5565,6 +5599,10 @@ static void llm_load_vocab(
                 tokenizer_pre == "deepseek-coder") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
                 vocab.tokenizer_clean_spaces = false;
+            } else if (
+                tokenizer_pre == "deepseek-v3") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
+                vocab.tokenizer_clean_spaces = false;
             } else if (
                 tokenizer_pre == "falcon") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
@@ -6075,6 +6113,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n", __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n", __func__, hparams.n_expert_shared);
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+        LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n", __func__, hparams.expert_weights_norm);
+        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n", __func__, llama_expert_gating_func_name((enum llm_expert_gating_func_type) hparams.expert_gating_func));
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n", __func__, hparams.rope_yarn_log_mul);
     }
@@ -7540,6 +7580,7 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
                     } else {
                         layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
+                        layer.ffn_exp_probs_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert});
 
                         GGML_ASSERT(n_expert      > 0);
                         GGML_ASSERT(n_expert_used > 0);
@@ -8346,12 +8387,14 @@ static struct ggml_tensor * llm_build_moe_ffn(
          struct ggml_tensor * up_exps,
          struct ggml_tensor * gate_exps,
          struct ggml_tensor * down_exps,
+         struct ggml_tensor * exp_probs_b,
                     int64_t   n_expert,
                     int64_t   n_expert_used,
             llm_ffn_op_type   type_op,
                        bool   norm_w,
                        bool   scale_w,
                       float   w_scale,
+llm_expert_gating_func_type   gating_op,
          const llm_build_cb & cb,
                         int   il) {
     int64_t n_embd = cur->ne[0];
@@ -8360,11 +8403,32 @@ static struct ggml_tensor * llm_build_moe_ffn(
     ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
 
-    ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+    //ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+    ggml_tensor * probs = nullptr;
+    switch (gating_op) {
+        case LLM_EXPERT_GATING_FUNC_SOFTMAX:
+            {
+                probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+            } break;
+        case LLM_EXPERT_GATING_FUNC_SIGMOID:
+            {
+                probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
     cb(probs, "ffn_moe_probs", il);
 
+    // add experts selection bias - introduced in DeepSeek V3
+    // leave probs unbiased as it's later used to get expert weights
+    ggml_tensor * selection_probs = probs;
+    if (exp_probs_b != nullptr) {
+        selection_probs = ggml_add(ctx, probs, exp_probs_b);
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
     // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+    ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
     cb(selected_experts, "ffn_moe_topk", il);
@@ -9180,9 +9244,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, true,
                         false, 0.0,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
             }
@@ -9673,9 +9739,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_GELU, true,
                         false, 0.0,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
 
@@ -9814,9 +9882,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, true,
                         false, 0.0,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
 
@@ -10944,9 +11014,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, false,
                         false, 0.0,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
 
@@ -13109,9 +13181,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, true,
                         false, 0.0,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
 
@@ -13324,9 +13398,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        model.layers[il].ffn_exp_probs_b,
                         n_expert, n_expert_used,
-                        LLM_FFN_SILU, false,
+                        LLM_FFN_SILU, hparams.expert_weights_norm,
                         true, hparams.expert_weights_scale,
+                        (enum llm_expert_gating_func_type) hparams.expert_gating_func,
                         cb, il);
                 cb(moe_out, "ffn_moe_out", il);
 
@@ -18547,6 +18623,7 @@ struct llama_data_read {
         read_to(&n_seq_id, sizeof(n_seq_id));
 
         if (n_seq_id != 0) {
+            llama_batch_free(batch);
             LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
             return false;
         }
@@ -19732,6 +19809,21 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "Assistant:";
         }
+    } else if (tmpl == "deepseek3" || tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) {
+        // DeepSeek-V3
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "\n\n";
+            } else if (role == "user") {
+                ss << LU8("<|User|>") << message->content;
+            } else if (role == "assistant") {
+                ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>");
+            }
+        }
+        if (add_ass) {
+            ss << LU8("<|Assistant|>");
+        }
     } else {
         // template not supported
         return -1;
diff --git a/src/unicode.cpp b/src/unicode.cpp
index 46650bff0..cfffde0d9 100644
--- a/src/unicode.cpp
+++ b/src/unicode.cpp
@@ -648,18 +648,25 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
         { "\\p{N}", codepoint_flags::NUMBER },
         { "\\p{L}", codepoint_flags::LETTER },
         { "\\p{P}", codepoint_flags::PUNCTUATION },
+        { "\\p{M}", codepoint_flags::ACCENT_MARK },
+        { "\\p{S}", codepoint_flags::SYMBOL },
     };
 
     static const std::map<int, int> k_ucat_cpt = {
         { codepoint_flags::NUMBER,      0xD1 },
        { codepoint_flags::LETTER,      0xD2 },
        { codepoint_flags::PUNCTUATION, 0xD3 },
+       { codepoint_flags::ACCENT_MARK, 0xD4 },
+       { codepoint_flags::SYMBOL,      0xD5 },
+
     };
 
     static const std::map<int, std::string> k_ucat_map = {
         { codepoint_flags::NUMBER,      "\x30-\x39" }, // 0-9
         { codepoint_flags::LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
-        { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { codepoint_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
+        { codepoint_flags::SYMBOL,      "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
     };
 
     // compute collapsed codepoints only if needed by at least one regex