@@ -374,6 +374,178 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
374
374
return bpe_offsets;
375
375
}
376
376
377
+ // K2 system regex patterns (from tokenization_kimi.py):
378
+ // [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
379
+ static std::vector<size_t > unicode_regex_split_custom_kimi_k2 (const std::string & text, const std::vector<size_t > & offsets) {
380
+ std::vector<size_t > bpe_offsets;
381
+ bpe_offsets.reserve (offsets.size ());
382
+
383
+ const auto cpts = unicode_cpts_from_utf8 (text);
384
+
385
+ size_t start = 0 ;
386
+ for (auto offset : offsets) {
387
+ const size_t offset_ini = start;
388
+ const size_t offset_end = start + offset;
389
+ assert (offset_end <= cpts.size ());
390
+ start = offset_end;
391
+
392
+ static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF ;
393
+ auto _get_cpt = [&] (const size_t pos) -> uint32_t {
394
+ return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
395
+ };
396
+
397
+ auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
398
+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags (cpts[pos]) : codepoint_flags{};
399
+ };
400
+
401
+ size_t _prev_end = offset_ini;
402
+ auto _add_token = [&] (const size_t end) -> size_t {
403
+ assert (_prev_end <= end && end <= offset_end);
404
+ size_t len = end - _prev_end;
405
+ if (len > 0 ) {
406
+ bpe_offsets.push_back (len);
407
+ }
408
+ _prev_end = end;
409
+ return len;
410
+ };
411
+
412
+ for (size_t pos = offset_ini; pos < offset_end; /* pos++*/ ) {
413
+ const uint32_t cpt = _get_cpt (pos);
414
+ const auto flags = _get_flags (pos);
415
+
416
+ // Pattern 1: [\p{Han}]+ (Chinese characters)
417
+ if (unicode_cpt_is_han (cpt)) {
418
+ while (unicode_cpt_is_han (_get_cpt (pos))) {
419
+ pos++;
420
+ }
421
+ _add_token (pos);
422
+ continue ;
423
+ }
424
+
425
+ // Pattern 2 & 3: Letter words excluding Han characters with optional contractions
426
+ // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)?
427
+ // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)?
428
+ // Check if current char is a letter OR if current char could be a leading char and next char is a letter
429
+ bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han (cpt)) ||
430
+ (!(cpt == ' \r ' || cpt == ' \n ' || flags.is_letter || flags.is_number ) &&
431
+ _get_flags (pos + 1 ).is_letter && !unicode_cpt_is_han (_get_cpt (pos + 1 )));
432
+
433
+ if (is_letter_pattern) {
434
+ // Handle optional leading non-letter/non-number character
435
+ bool has_leading_char = false ;
436
+ if (!(cpt == ' \r ' || cpt == ' \n ' || flags.is_letter || flags.is_number )) {
437
+ has_leading_char = true ;
438
+ pos++;
439
+ }
440
+
441
+ // Match letter sequence (excluding Han characters)
442
+ bool has_letters = false ;
443
+ while (_get_flags (pos).is_letter && !unicode_cpt_is_han (_get_cpt (pos))) {
444
+ has_letters = true ;
445
+ pos++;
446
+ }
447
+
448
+ // Only proceed if we found letters (after potentially skipping leading char)
449
+ if (has_letters || (!has_leading_char && _get_flags (pos).is_letter && !unicode_cpt_is_han (_get_cpt (pos)))) {
450
+ if (!has_letters) pos++; // consume the first letter if we didn't already
451
+
452
+ // Continue consuming letters
453
+ while (_get_flags (pos).is_letter && !unicode_cpt_is_han (_get_cpt (pos))) {
454
+ pos++;
455
+ }
456
+
457
+ // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d)
458
+ if (_get_cpt (pos) == ' \' ' && pos + 1 < offset_end) {
459
+ uint32_t cpt_next = unicode_tolower (_get_cpt (pos + 1 ));
460
+ if (cpt_next == ' s' || cpt_next == ' t' || cpt_next == ' m' || cpt_next == ' d' ) {
461
+ pos += 2 ;
462
+ } else if (pos + 2 < offset_end) {
463
+ uint32_t cpt_next_next = unicode_tolower (_get_cpt (pos + 2 ));
464
+ if ((cpt_next == ' r' && cpt_next_next == ' e' ) ||
465
+ (cpt_next == ' v' && cpt_next_next == ' e' ) ||
466
+ (cpt_next == ' l' && cpt_next_next == ' l' )) {
467
+ pos += 3 ;
468
+ }
469
+ }
470
+ }
471
+
472
+ _add_token (pos);
473
+ continue ;
474
+ } else if (has_leading_char) {
475
+ // We consumed a leading char but found no letters, backtrack
476
+ pos--;
477
+ }
478
+ }
479
+
480
+ // Pattern 4: \p{N}{1,3} (numbers 1-3 digits)
481
+ if (flags.is_number ) {
482
+ size_t ini = pos;
483
+ while (_get_flags (pos).is_number ) {
484
+ if (++pos - ini >= 3 ) {
485
+ _add_token (pos);
486
+ ini = pos;
487
+ }
488
+ }
489
+ _add_token (pos);
490
+ continue ;
491
+ }
492
+
493
+ // Pattern 5: ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines)
494
+ auto flags2 = (cpt == ' ' ? _get_flags (pos + 1 ) : flags);
495
+ if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number ) && flags2.as_uint ()) {
496
+ pos += (cpt == ' ' );
497
+ while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number ) && flags2.as_uint ()) {
498
+ flags2 = _get_flags (++pos);
499
+ }
500
+ // Match optional [\r\n]*
501
+ uint32_t cpt2 = _get_cpt (pos);
502
+ while (cpt2 == ' \r ' || cpt2 == ' \n ' ) {
503
+ cpt2 = _get_cpt (++pos);
504
+ }
505
+ _add_token (pos);
506
+ continue ;
507
+ }
508
+
509
+ // Count whitespace characters
510
+ size_t num_whitespaces = 0 ;
511
+ size_t last_end_r_or_n = 0 ;
512
+ while (_get_flags (pos + num_whitespaces).is_whitespace ) {
513
+ uint32_t cpt2 = _get_cpt (pos + num_whitespaces);
514
+ if (cpt2 == ' \r ' || cpt2 == ' \n ' ) {
515
+ last_end_r_or_n = pos + num_whitespaces + 1 ;
516
+ }
517
+ num_whitespaces++;
518
+ }
519
+
520
+ // Pattern 6: \s*[\r\n]+ (whitespace with newlines)
521
+ if (last_end_r_or_n > 0 ) {
522
+ pos = last_end_r_or_n;
523
+ _add_token (pos);
524
+ continue ;
525
+ }
526
+
527
+ // Pattern 7: \s+(?!\S) (trailing whitespace)
528
+ if (num_whitespaces > 1 && _get_cpt (pos + num_whitespaces) != OUT_OF_RANGE) {
529
+ pos += num_whitespaces - 1 ;
530
+ _add_token (pos);
531
+ continue ;
532
+ }
533
+
534
+ // Pattern 8: \s+ (general whitespace)
535
+ if (num_whitespaces > 0 ) {
536
+ pos += num_whitespaces;
537
+ _add_token (pos);
538
+ continue ;
539
+ }
540
+
541
+ // No matches - consume single character
542
+ _add_token (++pos);
543
+ }
544
+ }
545
+
546
+ return bpe_offsets;
547
+ }
548
+
377
549
// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
378
550
static std::vector<size_t > unicode_regex_split_custom_llama3 (const std::string& text, const std::vector<size_t >& offsets) {
379
551
std::vector<size_t > bpe_offsets; // store the offset of each word
@@ -587,6 +759,10 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string& text, c
587
759
588
760
bpe_offsets = unicode_regex_split_custom_llama3 (text, offsets);
589
761
}
762
+ else if (regex_expr == " \\ p{Han}+" ) {
763
+ // K2's first pattern - handle all K2 patterns together
764
+ bpe_offsets = unicode_regex_split_custom_kimi_k2 (text, offsets);
765
+ }
590
766
591
767
return bpe_offsets;
592
768
}
@@ -662,6 +838,38 @@ codepoint_flags unicode_cpt_flags(const std::string& utf8) {
662
838
return unicode_cpt_flags (unicode_cpt_from_utf8 (utf8, offset));
663
839
}
664
840
841
// Returns true when `cpt` lies inside one of the Unicode blocks that hold Han (CJK) ideographs.
//
// The test is block-based: every code point of a listed block is accepted, including positions
// that are currently unassigned within that block.
//
// NOTE(review): a few non-ideograph code points with Script=Han (e.g. U+3005, U+3007, the CJK/Kangxi
// radicals) are not covered by this block list - confirm whether they matter for the callers.
bool unicode_cpt_is_han(uint32_t cpt) {
    struct cpt_range {
        uint32_t first;
        uint32_t last;
    };

    static const cpt_range han_blocks[] = {
        { 0x4E00,  0x9FFF  }, // CJK Unified Ideographs (most common, checked first)
        { 0x3400,  0x4DBF  }, // CJK Unified Ideographs Extension A
        { 0x20000, 0x2A6DF }, // CJK Unified Ideographs Extension B
        { 0x2A700, 0x2B73F }, // CJK Unified Ideographs Extension C
        { 0x2B740, 0x2B81F }, // CJK Unified Ideographs Extension D
        { 0x2B820, 0x2CEAF }, // CJK Unified Ideographs Extension E
        { 0x2CEB0, 0x2EBEF }, // CJK Unified Ideographs Extension F
        { 0x2EBF0, 0x2EE5F }, // CJK Unified Ideographs Extension I
        { 0x30000, 0x3134F }, // CJK Unified Ideographs Extension G
        { 0x31350, 0x323AF }, // CJK Unified Ideographs Extension H
        { 0xF900,  0xFAFF  }, // CJK Compatibility Ideographs
        { 0x2F800, 0x2FA1F }, // CJK Compatibility Ideographs Supplement
    };

    for (const auto & block : han_blocks) {
        if (cpt >= block.first && cpt <= block.last) {
            return true;
        }
    }

    return false;
}
872
+
665
873
std::string unicode_byte_to_utf8 (uint8_t byte) {
666
874
static std::unordered_map<uint8_t , std::string> map = unicode_byte_to_utf8_map ();
667
875
return map.at (byte);
0 commit comments