diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift
index 4116dcb..fad2875 100644
--- a/Sources/Hub/Hub.swift
+++ b/Sources/Hub/Hub.swift
@@ -100,7 +100,16 @@ public struct Config {
     }
 
     /// Tuple of token identifier and string value
-    public var tokenValue: (UInt, String)? { value as? (UInt, String) }
+    ///
+    /// Accepts a two-element array in either order (`[id, token]` or `[token, id]`)
+    /// and always returns `(id, token)`; any other shape yields `nil`.
+    public var tokenValue: (UInt, String)? {
+        guard let pair = value as? [Any], pair.count == 2 else { return nil }
+        // Cast each element on its own instead of relying on a dynamic tuple cast.
+        if let id = pair[0] as? UInt, let token = pair[1] as? String { return (id, token) }
+        if let token = pair[0] as? String, let id = pair[1] as? UInt { return (id, token) }
+        return nil
+    }
 }
 
 public class LanguageModelConfigurationFromHub {
diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift
index db53337..210c543 100644
--- a/Sources/Tokenizers/Tokenizer.swift
+++ b/Sources/Tokenizers/Tokenizer.swift
@@ -77,6 +77,7 @@ public protocol PreTrainedTokenizerModel: TokenizingModel {
 
 struct TokenizerModel {
     static let knownTokenizers: [String : PreTrainedTokenizerModel.Type] = [
+        "BartTokenizer" : BertTokenizer.self,
         "BertTokenizer" : BertTokenizer.self,
         "DistilbertTokenizer": BertTokenizer.self,
         "DistilBertTokenizer": BertTokenizer.self,
diff --git a/Tests/HubTests/HubTests.swift b/Tests/HubTests/HubTests.swift
index 1d7bc86..a7f6203 100644
--- a/Tests/HubTests/HubTests.swift
+++ b/Tests/HubTests/HubTests.swift
@@ -118,4 +118,16 @@ class HubTests: XCTestCase {
         let vocab_dict = config.dictionary["vocab"] as! [String: Int]
         XCTAssertNotEqual(vocab_dict.count, 2)
     }
+
+    func testConfigTokenValueDifferentOrder() throws {
+        let data = try XCTUnwrap("{\"sep\": [\"\", 2], \"cls\": [0, \"\"]}".data(using: .utf8))
+        let dict = try XCTUnwrap(JSONSerialization.jsonObject(with: data, options: []) as? [NSString: Any])
+        let config = Config(dict)
+        let sep = try XCTUnwrap(config.sep?.tokenValue)
+        let cls = try XCTUnwrap(config.cls?.tokenValue)
+        XCTAssertEqual(sep.0, 2)
+        XCTAssertEqual(sep.1, "")
+        XCTAssertEqual(cls.0, 0)
+        XCTAssertEqual(cls.1, "")
+    }
 }