
Added RobertaTokenizer #188


Merged · 3 commits · Mar 27, 2025
Changes from all commits
10 changes: 9 additions & 1 deletion Sources/Hub/Hub.swift
@@ -133,7 +133,15 @@ public struct Config {
    }

    /// Tuple of token identifier and string value
-   public var tokenValue: (UInt, String)? { value as? (UInt, String) }
+   public var tokenValue: (UInt, String)? {
+       guard let value = value as? [Any] else {
+           return nil
+       }
+       guard let stringValue = value.first as? String, let intValue = value.dropFirst().first as? UInt else {
+           return nil
+       }
+       return (intValue, stringValue)
+   }
}

public class LanguageModelConfigurationFromHub {
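Why the change matters: values parsed from JSON bridge to [Any] arrays of NSString/NSNumber elements, and a cast from Any to a Swift tuple never succeeds on them, so the old one-line accessor always returned nil for configs loaded from the Hub. A minimal standalone sketch of both casts (illustrative, not part of the diff):

import Foundation

let data = #"{"cls": ["<s>", 0]}"#.data(using: .utf8)!
let parsed = try! JSONSerialization.jsonObject(with: data) as! [String: Any]
let value = parsed["cls"]!

// A JSON array bridges to [Any]; it is never a Swift tuple:
print(value as? (UInt, String) as Any)  // nil, the cast the old accessor relied on

// Unpacking element by element, as the new accessor does, works:
if let array = value as? [Any],
   let name = array.first as? String,
   let id = array.dropFirst().first as? UInt {
    print((id, name))  // (0, "<s>")
}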
3 changes: 2 additions & 1 deletion Sources/Tokenizers/Tokenizer.swift
@@ -101,6 +101,7 @@ struct TokenizerModel {
"BertTokenizer": BertTokenizer.self,
"DistilbertTokenizer": BertTokenizer.self,
"DistilBertTokenizer": BertTokenizer.self,
"RobertaTokenizer": BPETokenizer.self,
"CodeGenTokenizer": CodeGenTokenizer.self,
"CodeLlamaTokenizer": CodeLlamaTokenizer.self,
"FalconTokenizer": FalconTokenizer.self,
@@ -230,7 +231,7 @@ public extension Tokenizer {
func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] {
encode(text: text, addSpecialTokens: addSpecialTokens)
}

func decode(tokens: [Int]) -> String {
decode(tokens: tokens, skipSpecialTokens: false)
}
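With the registry entry above, RoBERTa checkpoints resolve to the existing byte-level BPE implementation rather than a new tokenizer class. A usage sketch (assumes network access to the Hub; mirrors the test added further down):

import Tokenizers

let tokenizer = try await AutoTokenizer.from(pretrained: "FacebookAI/roberta-base")
let ids = tokenizer("Who are you?")       // callAsFunction adds special tokens by default
let text = tokenizer.decode(tokens: ids)  // special tokens kept, since skipSpecialTokens defaults to false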
14 changes: 14 additions & 0 deletions Tests/HubTests/HubTests.swift
@@ -117,4 +117,18 @@ class HubTests: XCTestCase {
        let vocab_dict = config.dictionary["vocab"] as! [String: Int]
        XCTAssertNotEqual(vocab_dict.count, 2)
    }
+
+   func testConfigTokenValue() throws {
+       let config1 = Config(["cls": ["str" as String, 100 as UInt] as [Any]])
+       let tokenValue1 = config1.cls?.tokenValue
+       XCTAssertEqual(tokenValue1?.0, 100)
+       XCTAssertEqual(tokenValue1?.1, "str")
+
+       let data = #"{"cls": ["str", 100]}"#.data(using: .utf8)!
+       let dict = try JSONSerialization.jsonObject(with: data, options: []) as! [NSString: Any]
+       let config2 = Config(dict)
+       let tokenValue2 = config2.cls?.tokenValue
+       XCTAssertEqual(tokenValue2?.0, 100)
+       XCTAssertEqual(tokenValue2?.1, "str")
+   }
}
24 changes: 12 additions & 12 deletions Tests/PostProcessorTests/PostProcessorTests.swift
@@ -7,8 +7,8 @@ class PostProcessorTests: XCTestCase {
let testCases: [(Config, [String], [String]?, [String])] = [
// Should keep spaces; uneven spaces; ignore `addPrefixSpace`.
(
Config(["cls": (0, "[HEAD]") as (UInt, String),
"sep": (0, "[END]") as (UInt, String),
Config(["cls": ["[HEAD]", 0 as UInt],
"sep": ["[END]", 0 as UInt],
"trimOffset": false,
"addPrefixSpace": true]),
[" The", " sun", "sets ", " in ", " the ", "west"],
@@ -17,8 +17,8 @@
),
// Should leave only one space around each token.
(
Config(["cls": (0, "[START]") as (UInt, String),
"sep": (0, "[BREAK]") as (UInt, String),
Config(["cls": ["[START]", 0 as UInt],
"sep": ["[BREAK]", 0 as UInt],
"trimOffset": true,
"addPrefixSpace": true]),
[" The ", " sun", "sets ", " in ", " the ", "west"],
@@ -27,8 +27,8 @@
),
// Should ignore empty tokens pair.
(
Config(["cls": (0, "[START]") as (UInt, String),
"sep": (0, "[BREAK]") as (UInt, String),
Config(["cls": ["[START]", 0 as UInt],
"sep": ["[BREAK]", 0 as UInt],
"trimOffset": true,
"addPrefixSpace": true]),
[" The ", " sun", "sets ", " in ", " the ", "west"],
@@ -37,8 +37,8 @@
),
// Should trim all whitespace.
(
Config(["cls": (0, "[CLS]") as (UInt, String),
"sep": (0, "[SEP]") as (UInt, String),
Config(["cls": ["[CLS]", 0 as UInt],
"sep": ["[SEP]", 0 as UInt],
"trimOffset": true,
"addPrefixSpace": false]),
[" The ", " sun", "sets ", " in ", " the ", "west"],
@@ -47,8 +47,8 @@
),
// Should add tokens.
(
Config(["cls": (0, "[CLS]") as (UInt, String),
"sep": (0, "[SEP]") as (UInt, String),
Config(["cls": ["[CLS]", 0 as UInt],
"sep": ["[SEP]", 0 as UInt],
"trimOffset": true,
"addPrefixSpace": true]),
[" The ", " sun", "sets ", " in ", " the ", "west"],
@@ -58,8 +58,8 @@
"mat", "[SEP]"]
),
(
Config(["cls": (0, "[CLS]") as (UInt, String),
"sep": (0, "[SEP]") as (UInt, String),
Config(["cls": ["[CLS]", 0 as UInt],
"sep": ["[SEP]", 0 as UInt],
"trimOffset": true,
"addPrefixSpace": true]),
[" 你 ", " 好 ", ","],
62 changes: 45 additions & 17 deletions Tests/TokenizersTests/TokenizerTests.swift
@@ -212,6 +212,34 @@ class BertSpacesTests: XCTestCase {
}
}

+class RobertaTests: XCTestCase {
+    func testEncodeDecode() async throws {
+        guard let tokenizer = try await AutoTokenizer.from(pretrained: "FacebookAI/roberta-base") as? PreTrainedTokenizer else {
+            XCTFail()
+            return
+        }
+
+        XCTAssertEqual(tokenizer.tokenize(text: "l'eure"), ["l", "'", "e", "ure"])
+        XCTAssertEqual(tokenizer.encode(text: "l'eure"), [0, 462, 108, 242, 2407, 2])
+        XCTAssertEqual(tokenizer.decode(tokens: tokenizer.encode(text: "l'eure"), skipSpecialTokens: true), "l'eure")
+
+        XCTAssertEqual(tokenizer.tokenize(text: "mąka"), ["m", "Ä", "ħ", "ka"])
+        XCTAssertEqual(tokenizer.encode(text: "mąka"), [0, 119, 649, 5782, 2348, 2])
+
+        XCTAssertEqual(tokenizer.tokenize(text: "département"), ["d", "Ã©", "part", "ement"])
+        XCTAssertEqual(tokenizer.encode(text: "département"), [0, 417, 1140, 7755, 6285, 2])
+
+        XCTAssertEqual(tokenizer.tokenize(text: "Who are you?"), ["Who", "Ġare", "Ġyou", "?"])
+        XCTAssertEqual(tokenizer.encode(text: "Who are you?"), [0, 12375, 32, 47, 116, 2])
+
+        XCTAssertEqual(tokenizer.tokenize(text: " Who are you? "), ["ĠWho", "Ġare", "Ġyou", "?", "Ġ"])
+        XCTAssertEqual(tokenizer.encode(text: " Who are you? "), [0, 3394, 32, 47, 116, 1437, 2])
+
+        XCTAssertEqual(tokenizer.tokenize(text: "<s>Who are you?</s>"), ["<s>", "Who", "Ġare", "Ġyou", "?", "</s>"])
+        XCTAssertEqual(tokenizer.encode(text: "<s>Who are you?</s>"), [0, 0, 12375, 32, 47, 116, 2, 2])
+    }
+}
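The unusual tokens in the expected values (Ġ, Ä, ħ) are not mojibake: byte-level BPE first maps every UTF-8 byte to a printable scalar, so a space becomes Ġ (U+0120) and the two bytes of "ą" (0xC4 0x85) become Ä and ħ. A sketch of the standard GPT-2 byte-to-unicode table behind these expectations (the function name is illustrative, not from this codebase):

// Printable bytes map to themselves; the remaining 68 bytes shift into U+0100 and up.
func bytesToUnicode() -> [UInt8: Character] {
    var table: [UInt8: Character] = [:]
    var shift = 0
    for b in 0...255 {
        let printable = (33...126).contains(b) || (161...172).contains(b) || (174...255).contains(b)
        if printable {
            table[UInt8(b)] = Character(UnicodeScalar(UInt32(b))!)
        } else {
            table[UInt8(b)] = Character(UnicodeScalar(UInt32(256 + shift))!)
            shift += 1
        }
    }
    return table
}

let table = bytesToUnicode()
print(table[0x20]!)                // Ġ, which is why non-initial words carry a leading Ġ
print(table[0xC4]!, table[0x85]!)  // Ä ħ, the two UTF-8 bytes of "ą" in "mąka"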

struct EncodedTokenizerSamplesDataset: Decodable {
let text: String
// Bad naming, not just for BPE.
@@ -239,16 +267,16 @@ struct EncodedData: Decodable {
class TokenizerTester {
let encodedSamplesFilename: String
let unknownTokenId: Int?

private var configuration: LanguageModelConfigurationFromHub?
private var edgeCases: [EdgeCase]?
private var _tokenizer: Tokenizer?

init(hubModelName: String, encodedSamplesFilename: String, unknownTokenId: Int?, hubApi: HubApi) {
configuration = LanguageModelConfigurationFromHub(modelName: hubModelName, hubApi: hubApi)
self.encodedSamplesFilename = encodedSamplesFilename
self.unknownTokenId = unknownTokenId

// Read the edge cases dataset
edgeCases = {
let url = Bundle.module.url(forResource: "tokenizer_tests", withExtension: "json")!
@@ -259,15 +287,15 @@ class TokenizerTester {
return cases[hubModelName]
}()
}

lazy var dataset: EncodedTokenizerSamplesDataset = {
let url = Bundle.module.url(forResource: encodedSamplesFilename, withExtension: "json")!
let json = try! Data(contentsOf: url)
let decoder = JSONDecoder()
let dataset = try! decoder.decode(EncodedTokenizerSamplesDataset.self, from: json)
return dataset
}()

var tokenizer: Tokenizer? {
get async {
guard _tokenizer == nil else { return _tokenizer! }
@@ -283,39 +311,39 @@ class TokenizerTester {
return _tokenizer
}
}

var tokenizerModel: TokenizingModel? {
get async {
// The model is not usually accessible; maybe it should be
guard let tokenizer = await tokenizer else { return nil }
return (tokenizer as! PreTrainedTokenizer).model
}
}

func testTokenize() async {
let tokenized = await tokenizer?.tokenize(text: dataset.text)
XCTAssertEqual(
tokenized,
dataset.bpe_tokens
)
}

func testEncode() async {
let encoded = await tokenizer?.encode(text: dataset.text)
XCTAssertEqual(
encoded,
dataset.token_ids
)
}

func testDecode() async {
let decoded = await tokenizer?.decode(tokens: dataset.token_ids)
XCTAssertEqual(
decoded,
dataset.decoded_text
)
}

/// Test encode and decode for a few edge cases
func testEdgeCases() async {
guard let edgeCases else {
Expand All @@ -339,7 +367,7 @@ class TokenizerTester {
)
}
}

func testUnknownToken() async {
guard let model = await tokenizerModel else { return }
XCTAssertEqual(model.unknownTokenId, unknownTokenId)
@@ -361,10 +389,10 @@
class TokenizerTests: XCTestCase {
/// Parallel testing in Xcode (when enabled) uses different processes, so this shouldn't be a problem
static var _tester: TokenizerTester? = nil

class var hubModelName: String? { nil }
class var encodedSamplesFilename: String? { nil }

/// Known id retrieved from Python, to verify it was parsed correctly
class var unknownTokenId: Int? { nil }

@@ -399,25 +427,25 @@ class TokenizerTests: XCTestCase {
await tester.testTokenize()
}
}

func testEncode() async {
if let tester = Self._tester {
await tester.testEncode()
}
}

func testDecode() async {
if let tester = Self._tester {
await tester.testDecode()
}
}

func testEdgeCases() async {
if let tester = Self._tester {
await tester.testEdgeCases()
}
}

func testUnknownToken() async {
if let tester = Self._tester {
await tester.testUnknownToken()