Skip to content

Commit caa5caf

Browse files
fix swift 6 warnings - thread safe tokenizer and model config (#126)
- replaces #125 with simpler mechanism (NSLock) Co-authored-by: John Mai <maiqingqiang@gmail.com>
1 parent ee94992 commit caa5caf

File tree

7 files changed

+72
-22
lines changed

7 files changed

+72
-22
lines changed

Libraries/LLM/Configuration.swift

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,21 @@ public enum StringOrNumber: Codable, Equatable, Sendable {
2626
}
2727
}
2828

29-
public struct ModelType: RawRepresentable, Codable, Sendable {
30-
public let rawValue: String
29+
private class ModelTypeRegistry: @unchecked Sendable {
3130

32-
public init(rawValue: String) {
33-
self.rawValue = rawValue
34-
}
31+
// Note: using NSLock as we have very small (just dictionary get/set)
32+
// critical sections and expect no contention. this allows the methods
33+
// to remain synchronous.
34+
private let lock = NSLock()
3535

36+
@Sendable
3637
private static func createLlamaModel(url: URL) throws -> LLMModel {
3738
let configuration = try JSONDecoder().decode(
3839
LlamaConfiguration.self, from: Data(contentsOf: url))
3940
return LlamaModel(configuration)
4041
}
4142

42-
private static var creators: [String: (URL) throws -> LLMModel] = [
43+
private var creators: [String: @Sendable (URL) throws -> LLMModel] = [
4344
"mistral": createLlamaModel,
4445
"llama": createLlamaModel,
4546
"phi": { url in
@@ -89,18 +90,44 @@ public struct ModelType: RawRepresentable, Codable, Sendable {
8990
},
9091
]
9192

92-
public static func registerModelType(
93-
_ type: String, creator: @escaping (URL) throws -> LLMModel
93+
public func registerModelType(
94+
_ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel
9495
) {
95-
creators[type] = creator
96+
lock.withLock {
97+
creators[type] = creator
98+
}
9699
}
97100

98-
public func createModel(configuration: URL) throws -> LLMModel {
99-
guard let creator = ModelType.creators[rawValue] else {
101+
public func createModel(configuration: URL, rawValue: String) throws -> LLMModel {
102+
let creator = lock.withLock {
103+
creators[rawValue]
104+
}
105+
guard let creator else {
100106
throw LLMError(message: "Unsupported model type.")
101107
}
102108
return try creator(configuration)
103109
}
110+
111+
}
112+
113+
private let modelTypeRegistry = ModelTypeRegistry()
114+
115+
public struct ModelType: RawRepresentable, Codable, Sendable {
116+
public let rawValue: String
117+
118+
public init(rawValue: String) {
119+
self.rawValue = rawValue
120+
}
121+
122+
public static func registerModelType(
123+
_ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel
124+
) {
125+
modelTypeRegistry.registerModelType(type, creator: creator)
126+
}
127+
128+
public func createModel(configuration: URL) throws -> LLMModel {
129+
try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue)
130+
}
104131
}
105132

106133
public struct BaseConfiguration: Codable, Sendable {

Libraries/LLM/Tokenizer.swift

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,35 @@ private func updateTokenizerConfig(_ tokenizerConfig: Config) -> Config {
6767
return tokenizerConfig
6868
}
6969

70-
/// overrides for TokenizerModel/knownTokenizers
71-
public var replacementTokenizers = [
72-
"InternLM2Tokenizer": "PreTrainedTokenizer",
73-
"Qwen2Tokenizer": "PreTrainedTokenizer",
74-
"CohereTokenizer": "PreTrainedTokenizer",
75-
]
70+
public class TokenizerReplacementRegistry: @unchecked Sendable {
71+
72+
// Note: using NSLock as we have very small (just dictionary get/set)
73+
// critical sections and expect no contention. this allows the methods
74+
// to remain synchronous.
75+
private let lock = NSLock()
76+
77+
/// overrides for TokenizerModel/knownTokenizers
78+
private var replacementTokenizers = [
79+
"InternLM2Tokenizer": "PreTrainedTokenizer",
80+
"Qwen2Tokenizer": "PreTrainedTokenizer",
81+
"CohereTokenizer": "PreTrainedTokenizer",
82+
]
83+
84+
public subscript(key: String) -> String? {
85+
get {
86+
lock.withLock {
87+
replacementTokenizers[key]
88+
}
89+
}
90+
set {
91+
lock.withLock {
92+
replacementTokenizers[key] = newValue
93+
}
94+
}
95+
}
96+
}
97+
98+
public let replacementTokenizers = TokenizerReplacementRegistry()
7699

77100
public protocol StreamingDetokenizer: IteratorProtocol<String> {
78101

Tools/LinearModelTraining/LinearModelTraining.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import MLXNN
77
import MLXOptimizers
88
import MLXRandom
99

10-
#if swift(>=6.0)
10+
#if swift(>=5.10)
1111
extension MLX.DeviceType: @retroactive ExpressibleByArgument {
1212
public init?(argument: String) {
1313
self.init(rawValue: argument)

Tools/image-tool/Arguments.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import ArgumentParser
44
import Foundation
55
import MLX
66

7-
#if swift(>=6.0)
7+
#if swift(>=5.10)
88
/// Extension to allow URL command line arguments.
99
extension URL: @retroactive ExpressibleByArgument {
1010
public init?(argument: String) {

Tools/image-tool/ImageTool.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct StableDiffusionTool: AsyncParsableCommand {
2222
)
2323
}
2424

25-
#if swift(>=6.0)
25+
#if swift(>=5.10)
2626
extension StableDiffusionConfiguration.Preset: @retroactive ExpressibleByArgument {}
2727
#else
2828
extension StableDiffusionConfiguration.Preset: ExpressibleByArgument {}

Tools/llm-tool/Arguments.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import ArgumentParser
44
import Foundation
55

66
/// Extension to allow URL command line arguments.
7-
#if swift(>=6.0)
7+
#if swift(>=5.10)
88
extension URL: @retroactive ExpressibleByArgument {
99
public init?(argument: String) {
1010
if argument.contains("://") {

Tools/mnist-tool/MNISTTool.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ struct MNISTTool: AsyncParsableCommand {
1616
defaultSubcommand: Train.self)
1717
}
1818

19-
#if swift(>=6.0)
19+
#if swift(>=5.10)
2020
extension MLX.DeviceType: @retroactive ExpressibleByArgument {
2121
public init?(argument: String) {
2222
self.init(rawValue: argument)

0 commit comments

Comments
 (0)