Skip to content

Commit 1df5c74

Browse files
committed
chore: format
1 parent eee774b commit 1df5c74

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+645
-670
lines changed

Package.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@ let package = Package(
1313
],
1414
dependencies: [
1515
.package(url: "https://github.com/apple/swift-argument-parser.git", .upToNextMinor(from: "1.4.0")),
16-
.package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0"))
16+
.package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0")),
1717
],
1818
targets: [
1919
.executableTarget(
2020
name: "TransformersCLI",
2121
dependencies: [
2222
"Models", "Generation", "Tokenizers",
23-
.product(name: "ArgumentParser", package: "swift-argument-parser")]),
23+
.product(name: "ArgumentParser", package: "swift-argument-parser"),
24+
]
25+
),
2426
.executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
2527
.target(name: "Hub", resources: [.process("FallbackConfigs")]),
2628
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]),

Sources/Generation/Generation.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
//
22
// Generation.swift
3-
//
3+
//
44
//
55
// Created by Pedro Cuenca on 7/5/23.
66
//
77

8-
import Tokenizers
98
import CoreML
109
import TensorUtils
10+
import Tokenizers
1111

1212
public enum GenerationMode {
1313
case contrastiveSearch
@@ -57,7 +57,7 @@ public extension Generation {
5757
let logitsProcessor = LogitsProcessor(logitsWarpers: logitsWarpers(config: config))
5858
while outputTokens.count < config.maxLength {
5959
let outputs = model(outputTokens, config)
60-
/// `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case
60+
// `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case
6161
let logits = (outputs as? MLShapedArraySlice<Float>)?.floats ?? outputs.scalars as! [Float]
6262
let (indexes, processedLogits) = logitsProcessor(logits)
6363
let nextToken = Math.sample(indexes: indexes, probs: Math.softmax(processedLogits))
@@ -92,7 +92,7 @@ public extension Generation {
9292

9393
private func logitsWarpers(config: GenerationConfig) -> [any LogitsWarper] {
9494
var logitsWarpers = [any LogitsWarper]()
95-
if config.temperature > 0 && config.temperature != 1 {
95+
if config.temperature > 0, config.temperature != 1 {
9696
logitsWarpers.append(TemperatureLogitsWarper(temperature: Float(config.temperature)))
9797
}
9898
if config.topK > 0 {

Sources/Generation/GenerationConfig.swift

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//
22
// GenerationConfig.swift
3-
//
3+
//
44
//
55
// Created by Pedro Cuenca on 7/5/23.
66
//
@@ -14,15 +14,15 @@ public struct GenerationConfig {
1414
public var doSample = false
1515
public var numBeams = 1
1616
public var numBeamGroups = 1
17-
public var penaltyAlpha: Double? = nil
17+
public var penaltyAlpha: Double?
1818
public var temperature = 1.0
1919
public var topK = 50
2020
public var topP = 1.0
2121
public var repetitionPenalty = 1.0
2222

23-
public var padTokenId: Int? = nil
24-
public var bosTokenId: Int? = nil
25-
public var eosTokenId: Int? = nil
23+
public var padTokenId: Int?
24+
public var bosTokenId: Int?
25+
public var eosTokenId: Int?
2626

2727
public init(maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, repetitionPenalty: Double = 1.0) {
2828
self.maxLength = maxLength
@@ -41,18 +41,18 @@ public struct GenerationConfig {
4141
public extension GenerationConfig {
4242
var generationMode: GenerationMode {
4343
// Exclude this case from the pattern matching below
44-
if topK > 1 && !doSample && penaltyAlpha != nil && penaltyAlpha! > 0 {
44+
if topK > 1, !doSample, penaltyAlpha != nil, penaltyAlpha! > 0 {
4545
return .contrastiveSearch
4646
}
4747

4848
switch (numBeams, numBeamGroups, doSample) {
49-
case (1, 1, false) : return .greedy
50-
case (1, 1, true) : return .sample
49+
case (1, 1, false): return .greedy
50+
case (1, 1, true): return .sample
5151
case (2..., 1, false): return .beam
52-
case (2..., 2..., _) : return .groupBeam
53-
default : return .unsupported
52+
case (2..., 2..., _): return .groupBeam
53+
default: return .unsupported
5454
}
5555
}
5656
}
5757

58-
extension GenerationConfig: Decodable {}
58+
extension GenerationConfig: Decodable { }

Sources/Hub/Downloader.swift

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
77
//
88

9-
import Foundation
109
import Combine
10+
import Foundation
1111

1212
class Downloader: NSObject, ObservableObject {
1313
private(set) var destination: URL
1414

15-
private let chunkSize = 10 * 1024 * 1024 // 10MB
15+
private let chunkSize = 10 * 1024 * 1024 // 10MB
1616

1717
enum DownloadState {
1818
case notStarted
@@ -53,7 +53,7 @@ class Downloader: NSObject, ObservableObject {
5353
config.sessionSendsLaunchEvents = true
5454
}
5555

56-
self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)
56+
urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)
5757

5858
setupDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries)
5959
}
@@ -106,14 +106,13 @@ class Downloader: NSObject, ObservableObject {
106106
var requestHeaders = headers ?? [:]
107107

108108
// Populate header auth and range fields
109-
if let authToken = authToken {
109+
if let authToken {
110110
requestHeaders["Authorization"] = "Bearer \(authToken)"
111111
}
112112
if resumeSize > 0 {
113113
requestHeaders["Range"] = "bytes=\(resumeSize)-"
114114
}
115115

116-
117116
request.timeoutInterval = timeout
118117
request.allHTTPHeaderFields = requestHeaders
119118

@@ -157,7 +156,7 @@ class Downloader: NSObject, ObservableObject {
157156
numRetries: Int,
158157
expectedSize: Int?
159158
) async throws {
160-
guard let session = self.urlSession else {
159+
guard let session = urlSession else {
161160
throw DownloadError.unexpectedError
162161
}
163162

@@ -194,7 +193,7 @@ class Downloader: NSObject, ObservableObject {
194193
buffer.removeAll(keepingCapacity: true)
195194
downloadedSize += chunkSize
196195
newNumRetries = 5
197-
guard let expectedSize = expectedSize else { continue }
196+
guard let expectedSize else { continue }
198197
let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0
199198
downloadState.value = .downloading(progress)
200199
}
@@ -227,7 +226,7 @@ class Downloader: NSObject, ObservableObject {
227226

228227
// Verify the downloaded file size matches the expected size
229228
let actualSize = try tempFile.seekToEnd()
230-
if let expectedSize = expectedSize, expectedSize != actualSize {
229+
if let expectedSize, expectedSize != actualSize {
231230
throw DownloadError.unexpectedError
232231
}
233232
}
@@ -239,16 +238,16 @@ class Downloader: NSObject, ObservableObject {
239238
stateSubscriber = downloadState.sink { state in
240239
switch state {
241240
case .completed: semaphore.signal()
242-
case .failed: semaphore.signal()
243-
default: break
241+
case .failed: semaphore.signal()
242+
default: break
244243
}
245244
}
246245
semaphore.wait()
247246

248247
switch downloadState.value {
249-
case .completed(let url): return url
250-
case .failed(let error): throw error
251-
default: throw DownloadError.unexpectedError
248+
case let .completed(url): return url
249+
case let .failed(error): throw error
250+
default: throw DownloadError.unexpectedError
252251
}
253252
}
254253

@@ -265,15 +264,15 @@ extension Downloader: URLSessionDownloadDelegate {
265264
func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
266265
do {
267266
// If the downloaded file already exists on the filesystem, overwrite it
268-
try FileManager.default.moveDownloadedFile(from: location, to: self.destination)
267+
try FileManager.default.moveDownloadedFile(from: location, to: destination)
269268
downloadState.value = .completed(destination)
270269
} catch {
271270
downloadState.value = .failed(error)
272271
}
273272
}
274273

275274
func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
276-
if let error = error {
275+
if let error {
277276
downloadState.value = .failed(error)
278277
// } else if let response = task.response as? HTTPURLResponse {
279278
// print("HTTP response status code: \(response.statusCode)")

Sources/Hub/Hub.swift

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import Foundation
99

10-
public struct Hub {}
10+
public struct Hub { }
1111

1212
public extension Hub {
1313
enum HubClientError: LocalizedError {
@@ -25,28 +25,28 @@ public extension Hub {
2525

2626
public var errorDescription: String? {
2727
switch self {
28-
case .authorizationRequired:
29-
return String(localized: "Authentication required. Please provide a valid Hugging Face token.")
30-
case .httpStatusCode(let code):
31-
return String(localized: "HTTP error with status code: \(code)")
32-
case .parse:
33-
return String(localized: "Failed to parse server response.")
34-
case .unexpectedError:
35-
return String(localized: "An unexpected error occurred.")
36-
case .downloadError(let message):
37-
return String(localized: "Download failed: \(message)")
38-
case .fileNotFound(let filename):
39-
return String(localized: "File not found: \(filename)")
40-
case .networkError(let error):
41-
return String(localized: "Network error: \(error.localizedDescription)")
42-
case .resourceNotFound(let resource):
43-
return String(localized: "Resource not found: \(resource)")
44-
case .configurationMissing(let file):
45-
return String(localized: "Required configuration file missing: \(file)")
46-
case .fileSystemError(let error):
47-
return String(localized: "File system error: \(error.localizedDescription)")
48-
case .parseError(let message):
49-
return String(localized: "Parse error: \(message)")
28+
case .authorizationRequired:
29+
String(localized: "Authentication required. Please provide a valid Hugging Face token.")
30+
case let .httpStatusCode(code):
31+
String(localized: "HTTP error with status code: \(code)")
32+
case .parse:
33+
String(localized: "Failed to parse server response.")
34+
case .unexpectedError:
35+
String(localized: "An unexpected error occurred.")
36+
case let .downloadError(message):
37+
String(localized: "Download failed: \(message)")
38+
case let .fileNotFound(filename):
39+
String(localized: "File not found: \(filename)")
40+
case let .networkError(error):
41+
String(localized: "Network error: \(error.localizedDescription)")
42+
case let .resourceNotFound(resource):
43+
String(localized: "Resource not found: \(resource)")
44+
case let .configurationMissing(file):
45+
String(localized: "Required configuration file missing: \(file)")
46+
case let .fileSystemError(error):
47+
String(localized: "File system error: \(error.localizedDescription)")
48+
case let .parseError(message):
49+
String(localized: "Parse error: \(message)")
5050
}
5151
}
5252
}
@@ -79,7 +79,7 @@ public struct Config {
7979
}
8080

8181
func camelCase(_ string: String) -> String {
82-
return string
82+
string
8383
.split(separator: "_")
8484
.enumerated()
8585
.map { $0.offset == 0 ? $0.element.lowercased() : $0.element.capitalized }
@@ -108,7 +108,6 @@ public struct Config {
108108
return result
109109
}
110110

111-
112111
public subscript(dynamicMember member: String) -> Config? {
113112
let key = (dictionary[member as NSString] != nil ? member : uncamelCase(member)) as NSString
114113
if let value = dictionary[key] as? [NSString: Any] {
@@ -120,17 +119,17 @@ public struct Config {
120119
}
121120

122121
public var value: Any? {
123-
return dictionary["value"]
122+
dictionary["value"]
124123
}
125124

126125
public var intValue: Int? { value as? Int }
127126
public var boolValue: Bool? { value as? Bool }
128127
public var stringValue: String? { value as? String }
129128

130-
// Instead of doing this we could provide custom classes and decode to them
129+
/// Instead of doing this we could provide custom classes and decode to them
131130
public var arrayValue: [Config]? {
132131
guard let list = value as? [Any] else { return nil }
133-
return list.map { Config($0 as! [NSString : Any]) }
132+
return list.map { Config($0 as! [NSString: Any]) }
134133
}
135134

136135
/// Tuple of token identifier and string value
@@ -144,23 +143,23 @@ public class LanguageModelConfigurationFromHub {
144143
var tokenizerData: Config
145144
}
146145

147-
private var configPromise: Task<Configurations, Error>? = nil
146+
private var configPromise: Task<Configurations, Error>?
148147

149148
public init(
150149
modelName: String,
151150
hubApi: HubApi = .shared
152151
) {
153-
self.configPromise = Task.init {
154-
return try await self.loadConfig(modelName: modelName, hubApi: hubApi)
152+
configPromise = Task.init {
153+
try await self.loadConfig(modelName: modelName, hubApi: hubApi)
155154
}
156155
}
157156

158157
public init(
159158
modelFolder: URL,
160159
hubApi: HubApi = .shared
161160
) {
162-
self.configPromise = Task {
163-
return try await self.loadConfig(modelFolder: modelFolder, hubApi: hubApi)
161+
configPromise = Task {
162+
try await self.loadConfig(modelFolder: modelFolder, hubApi: hubApi)
164163
}
165164
}
166165

@@ -221,12 +220,12 @@ public class LanguageModelConfigurationFromHub {
221220
// Convert generic errors to more specific ones
222221
if let urlError = error as? URLError {
223222
switch urlError.code {
224-
case .notConnectedToInternet, .networkConnectionLost:
225-
throw Hub.HubClientError.networkError(urlError)
226-
case .resourceUnavailable:
227-
throw Hub.HubClientError.resourceNotFound(modelName)
228-
default:
229-
throw Hub.HubClientError.networkError(urlError)
223+
case .notConnectedToInternet, .networkConnectionLost:
224+
throw Hub.HubClientError.networkError(urlError)
225+
case .resourceUnavailable:
226+
throw Hub.HubClientError.resourceNotFound(modelName)
227+
default:
228+
throw Hub.HubClientError.networkError(urlError)
230229
}
231230
} else {
232231
throw error
@@ -265,7 +264,8 @@ public class LanguageModelConfigurationFromHub {
265264
let chatTemplateURL = modelFolder.appending(path: "chat_template.json")
266265
if FileManager.default.fileExists(atPath: chatTemplateURL.path),
267266
let chatTemplateConfig = try? hubApi.configuration(fileURL: chatTemplateURL),
268-
let chatTemplate = chatTemplateConfig.chatTemplate?.stringValue {
267+
let chatTemplate = chatTemplateConfig.chatTemplate?.stringValue
268+
{
269269
// Create or update tokenizer config with chat template
270270
if var configDict = tokenizerConfig?.dictionary {
271271
configDict["chat_template"] = chatTemplate
@@ -284,7 +284,7 @@ public class LanguageModelConfigurationFromHub {
284284
throw error
285285
} catch {
286286
if let nsError = error as NSError? {
287-
if nsError.domain == NSCocoaErrorDomain && nsError.code == NSFileReadNoSuchFileError {
287+
if nsError.domain == NSCocoaErrorDomain, nsError.code == NSFileReadNoSuchFileError {
288288
throw Hub.HubClientError.fileSystemError(error)
289289
} else if nsError.domain == "NSJSONSerialization" {
290290
throw Hub.HubClientError.parseError("Invalid JSON format: \(nsError.localizedDescription)")

0 commit comments

Comments
 (0)