Skip to content

Commit 92611e4

Browse files
committed
Split by regexp with capture groups
The other split helpers we have don't work for capture groups. We had to resort to raw `NSRegularExpression`s
1 parent 5e02089 commit 92611e4

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

Sources/Tokenizers/PreTokenizer.swift

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ extension StringSplitPattern {
248248
}
249249
}
250250

251-
extension String {
251+
public extension String {
252252
func ranges(of string: String, options: CompareOptions = .regularExpression) -> [Range<Index>] {
253253
var result: [Range<Index>] = []
254254
var start = startIndex
@@ -277,6 +277,42 @@ extension String {
277277
return result
278278
}
279279

280+
/// This version supports capture groups, wheres the one above doesn't
281+
func split(by captureRegex: NSRegularExpression) -> [String] {
282+
// Find the matching capture groups
283+
let selfRange = NSRange(startIndex..<endIndex, in: self)
284+
let matches = captureRegex.matches(in: self, options: [], range: selfRange)
285+
286+
if matches.first == nil { return [self] }
287+
288+
var result: [String] = []
289+
var start = startIndex
290+
for match in matches {
291+
// Append prefix before matched separator
292+
let prefixEnd = index(startIndex, offsetBy: match.range.lowerBound)
293+
if start < prefixEnd {
294+
result.append(String(self[start..<prefixEnd]))
295+
}
296+
start = index(startIndex, offsetBy: match.range.upperBound)
297+
298+
// Append separator, supporting capture groups
299+
for r in (0..<match.numberOfRanges).reversed() {
300+
let matchRange = match.range(at: r)
301+
if let sepRange = Range(matchRange, in:self) {
302+
result.append(String(self[sepRange]))
303+
break
304+
}
305+
}
306+
}
307+
308+
// Append remaining suffix
309+
let beginningOfEnd = index(startIndex, offsetBy: matches.last!.range.upperBound)
310+
if beginningOfEnd < endIndex {
311+
result.append(String(self[beginningOfEnd...]))
312+
}
313+
314+
return result
315+
}
280316
}
281317

282318
public enum SplitDelimiterBehavior {
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//
2+
// AddedTokensTests.swift
3+
//
4+
//
5+
// Created by Pedro Cuenca on 20240426.
6+
//
7+
8+
import XCTest
9+
import Tokenizers
10+
import Hub
11+
12+
class AddedTokensTests: XCTestCase {
13+
func testPhiAddedEnd() async throws {
14+
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Phi-3-mini-128k-instruct-4bit")
15+
let inputIds = tokenizer("This is the <|end|>. My only friend, the <|end|>")
16+
XCTAssertEqual(inputIds, [1, 910, 338, 278, 29871, 32007, 29889, 1619, 871, 5121, 29892, 278, 29871, 32007])
17+
18+
let decoded = tokenizer.decode(tokens: inputIds)
19+
XCTAssertEqual(decoded, "<s> This is the <|end|>. My only friend, the <|end|>")
20+
}
21+
22+
func testSplitWithCaptureGroups() {
23+
let addedTokensRegexp = #"(<\|end\|>)\s*|(<\|raw\|>)\s*"#
24+
let captureRegex = try! NSRegularExpression(pattern: addedTokensRegexp, options: [])
25+
26+
XCTAssertEqual(
27+
"eating <|raw|> meat <|end|> That's all".split(by: captureRegex),
28+
["eating ", "<|raw|>", "meat ", "<|end|>", "That's all"]
29+
)
30+
31+
XCTAssertEqual(
32+
"<|raw|>".split(by: captureRegex),
33+
["<|raw|>"]
34+
)
35+
36+
XCTAssertEqual(
37+
"This string doesn't have those separators".split(by: captureRegex),
38+
["This string doesn't have those separators"]
39+
)
40+
41+
XCTAssertEqual(
42+
"start <|end|>".split(by: captureRegex),
43+
["start ", "<|end|>"]
44+
)
45+
46+
XCTAssertEqual(
47+
"start <|end|> ".split(by: captureRegex),
48+
["start ", "<|end|>"]
49+
)
50+
51+
XCTAssertEqual(
52+
"start <|end|> ".split(by: captureRegex),
53+
["start ", "<|end|>"]
54+
)
55+
56+
XCTAssertEqual(
57+
"start <|end|> for real".split(by: captureRegex),
58+
["start ", "<|end|>", "for real"]
59+
)
60+
61+
XCTAssertEqual(
62+
"<|raw|><|end|>".split(by: captureRegex),
63+
["<|raw|>", "<|end|>"]
64+
)
65+
66+
}
67+
}

0 commit comments

Comments
 (0)