Skip to content

Commit e5e0c18

Browse files
FL33TW00Dpcuenca
andauthored
chore: experiment with swift-format for PR (#174)
* Format without preferKeyPath * chore: fix teardown * Apply format --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com>
1 parent 0b07561 commit e5e0c18

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+736
-670
lines changed

.swiftformat

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
--swiftversion 5.9
2+
--acronyms ID,URL,UUID
3+
--allman false
4+
--anonymousforeach convert
5+
--assetliterals visual-width
6+
--asynccapturing
7+
--beforemarks
8+
--binarygrouping 4,8
9+
--categorymark "MARK: %c"
10+
--classthreshold 0
11+
--closingparen balanced
12+
--closurevoid remove
13+
--commas always
14+
--conflictmarkers reject
15+
--decimalgrouping ignore
16+
--elseposition same-line
17+
--emptybraces spaced
18+
--enumthreshold 0
19+
--exponentcase lowercase
20+
--exponentgrouping disabled
21+
--extensionacl on-extension
22+
--extensionlength 0
23+
--extensionmark "MARK: - %t + %c"
24+
--fractiongrouping disabled
25+
--fragment false
26+
--funcattributes preserve
27+
--generictypes
28+
--groupedextension "MARK: %c"
29+
--guardelse auto
30+
--header ignore
31+
--hexgrouping 4,8
32+
--hexliteralcase uppercase
33+
--ifdef no-indent
34+
--importgrouping alpha
35+
--indent 4
36+
--indentcase false
37+
--indentstrings false
38+
--lifecycle
39+
--lineaftermarks true
40+
--linebreaks lf
41+
--markcategories true
42+
--markextensions always
43+
--marktypes always
44+
--maxwidth none
45+
--modifierorder
46+
--nevertrailing
47+
--nospaceoperators
48+
--nowrapoperators
49+
--octalgrouping 4,8
50+
--onelineforeach ignore
51+
--operatorfunc spaced
52+
--organizetypes actor,class,enum,struct
53+
--patternlet hoist
54+
--ranges no-space
55+
--redundanttype infer-locals-only
56+
--self remove
57+
--selfrequired
58+
--semicolons inline
59+
--shortoptionals always
60+
--smarttabs enabled
61+
--someany true
62+
--stripunusedargs unnamed-only
63+
--structthreshold 0
64+
--tabwidth unspecified
65+
--throwcapturing
66+
--trailingclosures
67+
--typeattributes preserve
68+
--typeblanklines remove
69+
--typemark "MARK: - %t"
70+
--varattributes preserve
71+
--voidtype void
72+
--wraparguments preserve
73+
--wrapcollections preserve
74+
--wrapconditions preserve
75+
--wrapeffects preserve
76+
--wrapenumcases always
77+
--wrapparameters preserve
78+
--wrapreturntype preserve
79+
--wrapternary default
80+
--wraptypealiases preserve
81+
--xcodeindentation disabled
82+
--yodaswap always
83+
--disable blankLineAfterImports,unusedArguments
84+
--enable docComments
85+
--disable enumnamespaces
86+
--trimwhitespace nonblank-lines
87+
--disable preferKeyPath

Package.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@ let package = Package(
1313
],
1414
dependencies: [
1515
.package(url: "https://github.com/apple/swift-argument-parser.git", .upToNextMinor(from: "1.4.0")),
16-
.package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0"))
16+
.package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0")),
1717
],
1818
targets: [
1919
.executableTarget(
2020
name: "TransformersCLI",
2121
dependencies: [
2222
"Models", "Generation", "Tokenizers",
23-
.product(name: "ArgumentParser", package: "swift-argument-parser")]),
23+
.product(name: "ArgumentParser", package: "swift-argument-parser"),
24+
]
25+
),
2426
.executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
2527
.target(name: "Hub", resources: [.process("FallbackConfigs")]),
2628
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]),

Sources/Generation/Generation.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
//
22
// Generation.swift
3-
//
3+
//
44
//
55
// Created by Pedro Cuenca on 7/5/23.
66
//
77

8-
import Tokenizers
98
import CoreML
109
import TensorUtils
10+
import Tokenizers
1111

1212
public enum GenerationMode {
1313
case contrastiveSearch
@@ -57,7 +57,7 @@ public extension Generation {
5757
let logitsProcessor = LogitsProcessor(logitsWarpers: logitsWarpers(config: config))
5858
while outputTokens.count < config.maxLength {
5959
let outputs = model(outputTokens, config)
60-
/// `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case
60+
// `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case
6161
let logits = (outputs as? MLShapedArraySlice<Float>)?.floats ?? outputs.scalars as! [Float]
6262
let (indexes, processedLogits) = logitsProcessor(logits)
6363
let nextToken = Math.sample(indexes: indexes, probs: Math.softmax(processedLogits))
@@ -92,7 +92,7 @@ public extension Generation {
9292

9393
private func logitsWarpers(config: GenerationConfig) -> [any LogitsWarper] {
9494
var logitsWarpers = [any LogitsWarper]()
95-
if config.temperature > 0 && config.temperature != 1 {
95+
if config.temperature > 0, config.temperature != 1 {
9696
logitsWarpers.append(TemperatureLogitsWarper(temperature: Float(config.temperature)))
9797
}
9898
if config.topK > 0 {

Sources/Generation/GenerationConfig.swift

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//
22
// GenerationConfig.swift
3-
//
3+
//
44
//
55
// Created by Pedro Cuenca on 7/5/23.
66
//
@@ -14,15 +14,15 @@ public struct GenerationConfig {
1414
public var doSample = false
1515
public var numBeams = 1
1616
public var numBeamGroups = 1
17-
public var penaltyAlpha: Double? = nil
17+
public var penaltyAlpha: Double?
1818
public var temperature = 1.0
1919
public var topK = 50
2020
public var topP = 1.0
2121
public var repetitionPenalty = 1.0
2222

23-
public var padTokenId: Int? = nil
24-
public var bosTokenId: Int? = nil
25-
public var eosTokenId: Int? = nil
23+
public var padTokenId: Int?
24+
public var bosTokenId: Int?
25+
public var eosTokenId: Int?
2626

2727
public init(maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, repetitionPenalty: Double = 1.0) {
2828
self.maxLength = maxLength
@@ -41,18 +41,18 @@ public struct GenerationConfig {
4141
public extension GenerationConfig {
4242
var generationMode: GenerationMode {
4343
// Exclude this case from the pattern matching below
44-
if topK > 1 && !doSample && penaltyAlpha != nil && penaltyAlpha! > 0 {
44+
if topK > 1, !doSample, penaltyAlpha != nil, penaltyAlpha! > 0 {
4545
return .contrastiveSearch
4646
}
4747

4848
switch (numBeams, numBeamGroups, doSample) {
49-
case (1, 1, false) : return .greedy
50-
case (1, 1, true) : return .sample
49+
case (1, 1, false): return .greedy
50+
case (1, 1, true): return .sample
5151
case (2..., 1, false): return .beam
52-
case (2..., 2..., _) : return .groupBeam
53-
default : return .unsupported
52+
case (2..., 2..., _): return .groupBeam
53+
default: return .unsupported
5454
}
5555
}
5656
}
5757

58-
extension GenerationConfig: Decodable {}
58+
extension GenerationConfig: Decodable { }

Sources/Hub/Downloader.swift

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
77
//
88

9-
import Foundation
109
import Combine
10+
import Foundation
1111

1212
class Downloader: NSObject, ObservableObject {
1313
private(set) var destination: URL
1414

15-
private let chunkSize = 10 * 1024 * 1024 // 10MB
15+
private let chunkSize = 10 * 1024 * 1024 // 10MB
1616

1717
enum DownloadState {
1818
case notStarted
@@ -53,7 +53,7 @@ class Downloader: NSObject, ObservableObject {
5353
config.sessionSendsLaunchEvents = true
5454
}
5555

56-
self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)
56+
urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)
5757

5858
setupDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries)
5959
}
@@ -106,14 +106,13 @@ class Downloader: NSObject, ObservableObject {
106106
var requestHeaders = headers ?? [:]
107107

108108
// Populate header auth and range fields
109-
if let authToken = authToken {
109+
if let authToken {
110110
requestHeaders["Authorization"] = "Bearer \(authToken)"
111111
}
112112
if resumeSize > 0 {
113113
requestHeaders["Range"] = "bytes=\(resumeSize)-"
114114
}
115115

116-
117116
request.timeoutInterval = timeout
118117
request.allHTTPHeaderFields = requestHeaders
119118

@@ -157,7 +156,7 @@ class Downloader: NSObject, ObservableObject {
157156
numRetries: Int,
158157
expectedSize: Int?
159158
) async throws {
160-
guard let session = self.urlSession else {
159+
guard let session = urlSession else {
161160
throw DownloadError.unexpectedError
162161
}
163162

@@ -194,7 +193,7 @@ class Downloader: NSObject, ObservableObject {
194193
buffer.removeAll(keepingCapacity: true)
195194
downloadedSize += chunkSize
196195
newNumRetries = 5
197-
guard let expectedSize = expectedSize else { continue }
196+
guard let expectedSize else { continue }
198197
let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0
199198
downloadState.value = .downloading(progress)
200199
}
@@ -227,7 +226,7 @@ class Downloader: NSObject, ObservableObject {
227226

228227
// Verify the downloaded file size matches the expected size
229228
let actualSize = try tempFile.seekToEnd()
230-
if let expectedSize = expectedSize, expectedSize != actualSize {
229+
if let expectedSize, expectedSize != actualSize {
231230
throw DownloadError.unexpectedError
232231
}
233232
}
@@ -239,16 +238,16 @@ class Downloader: NSObject, ObservableObject {
239238
stateSubscriber = downloadState.sink { state in
240239
switch state {
241240
case .completed: semaphore.signal()
242-
case .failed: semaphore.signal()
243-
default: break
241+
case .failed: semaphore.signal()
242+
default: break
244243
}
245244
}
246245
semaphore.wait()
247246

248247
switch downloadState.value {
249-
case .completed(let url): return url
250-
case .failed(let error): throw error
251-
default: throw DownloadError.unexpectedError
248+
case let .completed(url): return url
249+
case let .failed(error): throw error
250+
default: throw DownloadError.unexpectedError
252251
}
253252
}
254253

@@ -265,15 +264,15 @@ extension Downloader: URLSessionDownloadDelegate {
265264
func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
266265
do {
267266
// If the downloaded file already exists on the filesystem, overwrite it
268-
try FileManager.default.moveDownloadedFile(from: location, to: self.destination)
267+
try FileManager.default.moveDownloadedFile(from: location, to: destination)
269268
downloadState.value = .completed(destination)
270269
} catch {
271270
downloadState.value = .failed(error)
272271
}
273272
}
274273

275274
func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
276-
if let error = error {
275+
if let error {
277276
downloadState.value = .failed(error)
278277
// } else if let response = task.response as? HTTPURLResponse {
279278
// print("HTTP response status code: \(response.statusCode)")

0 commit comments

Comments
 (0)