Skip to content

Format with swift-format #158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
repos:
- repo: https://github.com/slessans/pre-commit-swift-format
rev: "fd627de92bdf84a75c924ed95691336d14e94cf1"
hooks:
- id: swift-format
args: ["--configuration", ".swift-format"]
9 changes: 9 additions & 0 deletions .swift-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"version": 1,
"indentation": {
"spaces": 4
},
"lineLength": 120,
"multiElementCollectionTrailingCommas": true,
"spacesAroundRangeFormationOperators": true
}
16 changes: 11 additions & 5 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,30 @@ let package = Package(
],
dependencies: [
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0"),
.package(url: "https://github.com/maiqingqiang/Jinja", from: "1.0.6")
.package(url: "https://github.com/maiqingqiang/Jinja", from: "1.0.6"),
],
targets: [
.executableTarget(
name: "TransformersCLI",
dependencies: [
"Models", "Generation", "Tokenizers",
.product(name: "ArgumentParser", package: "swift-argument-parser")]),
.executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
.product(name: "ArgumentParser", package: "swift-argument-parser"),
]),
.executableTarget(
name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
.target(name: "Hub", resources: [.process("FallbackConfigs")]),
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]),
.target(name: "TensorUtils"),
.target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]),
.target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]),
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]),
.testTarget(
name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"],
resources: [.process("Resources"), .process("Vocabs")]),
.testTarget(name: "HubTests", dependencies: ["Hub"]),
.testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]),
.testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")]),
.testTarget(
name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")]
),
.testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]),
.testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]),
]
Expand Down
34 changes: 23 additions & 11 deletions Sources/Generation/Generation.swift
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//
// Generation.swift
//
//
//
// Created by Pedro Cuenca on 7/5/23.
//

import Tokenizers
import CoreML
import TensorUtils
import Tokenizers

public enum GenerationMode {
case contrastiveSearch
Expand All @@ -29,13 +29,20 @@ public typealias PredictionStringCallback = (String) -> Void

// TODO: callbacks (for streaming)
/// Entry points for producing model output from a prompt or from input tokens.
///
/// Default implementations, with `callback` defaulting to `nil`, are supplied
/// by the `Generation` extension that follows this protocol.
public protocol Generation {
    /// Extends `tokens` using `model` under the limits in `config`.
    ///
    /// - Parameters:
    ///   - config: Generation settings (for example the maximum length).
    ///   - tokens: The tokens generation starts from.
    ///   - model: The model queried for each next token.
    ///   - callback: Optional observer. NOTE(review): presumably invoked with the
    ///     tokens produced so far — confirm against the default implementation.
    /// - Returns: The generated output tokens.
    func greedySearch(
        config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback?
    ) async -> GenerationOutput

    /// Encodes `prompt` with `tokenizer`, runs generation, and decodes the result.
    ///
    /// - Parameters:
    ///   - config: Generation settings; selects the generation mode.
    ///   - prompt: The text to continue.
    ///   - model: The model queried for each next token.
    ///   - tokenizer: Used to encode the prompt and decode the output.
    ///   - callback: Optional observer receiving strings during generation.
    /// - Returns: The decoded output text.
    func generate(
        config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer,
        callback: PredictionStringCallback?
    ) async -> String
}

public extension Generation {
func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput {
extension Generation {
public func greedySearch(
config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil
) async -> GenerationOutput {
// Iterate until we find the eos token or reach the max length
// TODO: additional stopping criteria
var outputTokens = tokens
Expand All @@ -48,9 +55,11 @@ public extension Generation {
}
return outputTokens
}

/// https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L1552
func sample(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput {
public func sample(
config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil
) async -> GenerationOutput {
// Iterate until we find the eos token or reach the max length
// TODO: additional stopping criteria
var outputTokens = tokens
Expand All @@ -68,7 +77,10 @@ public extension Generation {
return outputTokens
}

func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback? = nil) async -> String {
public func generate(
config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer,
callback: PredictionStringCallback? = nil
) async -> String {
let tokens = tokenizer.encode(text: prompt)
var generationConfig = config
generationConfig.maxLength = config.maxNewTokens + tokens.count
Expand All @@ -86,7 +98,7 @@ public extension Generation {
default:
fatalError("Generation mode \(generationConfig.generationMode) not implemented yet")
}

return tokenizer.decode(tokens: output)
}

Expand Down
26 changes: 15 additions & 11 deletions Sources/Generation/GenerationConfig.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//
// GenerationConfig.swift
//
//
//
// Created by Pedro Cuenca on 7/5/23.
//
Expand All @@ -19,12 +19,16 @@ public struct GenerationConfig {
public var topK = 50
public var topP = 1.0
public var repetitionPenalty = 1.0

public var padTokenId: Int? = nil
public var bosTokenId: Int? = nil
public var eosTokenId: Int? = nil

public init(maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, repetitionPenalty: Double = 1.0) {

public init(
maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1,
penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0,
repetitionPenalty: Double = 1.0
) {
self.maxLength = maxLength
self.maxNewTokens = maxNewTokens
self.doSample = doSample
Expand All @@ -38,19 +42,19 @@ public struct GenerationConfig {
}
}

extension GenerationConfig {
    /// The generation strategy implied by this configuration.
    ///
    /// Contrastive search is checked first because it depends on `topK` and
    /// `penaltyAlpha`, which the tuple match below does not inspect. All other
    /// modes are derived from `(numBeams, numBeamGroups, doSample)`.
    public var generationMode: GenerationMode {
        // Exclude this case from the pattern matching below.
        // Optional binding replaces the previous `!= nil` check plus force unwrap.
        if topK > 1, !doSample, let alpha = penaltyAlpha, alpha > 0 {
            return .contrastiveSearch
        }

        switch (numBeams, numBeamGroups, doSample) {
        case (1, 1, false): return .greedy
        case (1, 1, true): return .sample
        case (2..., 1, false): return .beam
        case (2..., 2..., _): return .groupBeam
        default: return .unsupported
        }
    }
}
Expand Down
23 changes: 13 additions & 10 deletions Sources/Hub/Downloader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//

import Foundation
import Combine
import Foundation

class Downloader: NSObject, ObservableObject {
private(set) var destination: URL
Expand Down Expand Up @@ -86,16 +86,16 @@ class Downloader: NSObject, ObservableObject {
stateSubscriber = downloadState.sink { state in
switch state {
case .completed: semaphore.signal()
case .failed: semaphore.signal()
default: break
case .failed: semaphore.signal()
default: break
}
}
semaphore.wait()

switch downloadState.value {
case .completed(let url): return url
case .failed(let error): throw error
default: throw DownloadError.unexpectedError
case .failed(let error): throw error
default: throw DownloadError.unexpectedError
}
}

Expand All @@ -105,7 +105,10 @@ class Downloader: NSObject, ObservableObject {
}

extension Downloader: URLSessionDownloadDelegate {
/// Publishes download progress as a fraction of the expected total.
///
/// - Parameters:
///   - totalBytesWritten: Bytes received so far.
///   - totalBytesExpectedToWrite: Expected size of the download. URLSession
///     reports `NSURLSessionTransferSizeUnknown` (-1) when the size is not known.
func urlSession(
    _: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64,
    totalBytesExpectedToWrite: Int64
) {
    // Without a positive expected size the ratio would be negative, infinite or
    // NaN, so keep the previously published state instead of emitting it.
    guard totalBytesExpectedToWrite > 0 else { return }
    downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite))
}

Expand All @@ -122,10 +125,10 @@ extension Downloader: URLSessionDownloadDelegate {
/// Records a failed transfer in `downloadState`.
///
/// When `error` is `nil` nothing is changed here.
func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
    guard let failure = error else {
        // Disabled diagnostics, kept for reference:
        // if let response = task.response as? HTTPURLResponse {
        //     print("HTTP response status code: \(response.statusCode)")
        //     let headers = response.allHeaderFields
        //     print("HTTP response headers: \(headers)")
        // }
        return
    }
    downloadState.value = .failed(failure)
}
}
Expand Down
49 changes: 26 additions & 23 deletions Sources/Hub/Hub.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//
// Hub.swift
//
//
//
// Created by Pedro Cuenca on 18/5/23.
//
Expand All @@ -9,24 +9,24 @@ import Foundation

public struct Hub {}

public extension Hub {
enum HubClientError: Error {
extension Hub {
public enum HubClientError: Error {
case parse
case authorizationRequired
case unexpectedError
case httpStatusCode(Int)
}
enum RepoType: String {

public enum RepoType: String {
case models
case datasets
case spaces
}
struct Repo {

public struct Repo {
public let id: String
public let type: RepoType

public init(id: String, type: RepoType = .models) {
self.id = id
self.type = type
Expand All @@ -45,17 +45,18 @@ public struct Config {
}

/// Converts a snake_case identifier to lowerCamelCase.
///
/// The input is split on `_` (empty segments are dropped). The first segment is
/// lowercased and every following segment is capitalized, then all segments are
/// concatenated. An input with no segments yields the empty string.
///
/// - Parameter string: The identifier to convert, e.g. `"tokenizer_class"`.
/// - Returns: The converted identifier, e.g. `"tokenizerClass"`.
func camelCase(_ string: String) -> String {
    let segments = string.split(separator: "_")
    guard let head = segments.first else { return "" }
    // Single return path: the conversion must not be preceded by an early
    // `return string`, which would hand the input back unchanged.
    return segments.dropFirst().reduce(into: head.lowercased()) { partial, segment in
        partial += segment.capitalized
    }
}

func uncamelCase(_ string: String) -> String {
let scalars = string.unicodeScalars
var result = ""

var previousCharacterIsLowercase = false
for scalar in scalars {
if CharacterSet.uppercaseLetters.contains(scalar) {
Expand All @@ -70,11 +71,10 @@ public struct Config {
previousCharacterIsLowercase = true
}
}

return result
}


public subscript(dynamicMember member: String) -> Config? {
let key = (dictionary[member as NSString] != nil ? member : uncamelCase(member)) as NSString
if let value = dictionary[key] as? [NSString: Any] {
Expand All @@ -88,17 +88,17 @@ public struct Config {
/// The entry stored under the `"value"` key of the backing dictionary, or `nil` if there is none.
public var value: Any? {
return dictionary["value"]
}

/// `value` conditionally cast to `Int`; `nil` when the cast fails.
public var intValue: Int? { value as? Int }
/// `value` conditionally cast to `Bool`; `nil` when the cast fails.
public var boolValue: Bool? { value as? Bool }
/// `value` conditionally cast to `String`; `nil` when the cast fails.
public var stringValue: String? { value as? String }

// Instead of doing this we could provide custom classes and decode to them
/// `value` as a list of nested configurations.
///
/// - Returns: One `Config` per element, or `nil` when `value` is not a list
///   or any element is not a `[NSString: Any]` dictionary.
public var arrayValue: [Config]? {
    guard let list = value as? [Any] else { return nil }

    var configs: [Config] = []
    configs.reserveCapacity(list.count)
    for element in list {
        // A conditional cast replaces the previous `as!`, which crashed on
        // lists holding anything other than dictionaries (e.g. plain strings).
        guard let dictionary = element as? [NSString: Any] else { return nil }
        configs.append(Config(dictionary))
    }
    return configs
}

/// Tuple of token identifier and string value.
///
/// Returns `nil` unless `value` conditionally casts to `(UInt, String)`.
// NOTE(review): presumably a JSON-decoded two-element array does not satisfy this tuple cast — confirm against callers.
public var tokenValue: (UInt, String)? { value as? (UInt, String) }
}
Expand All @@ -120,7 +120,7 @@ public class LanguageModelConfigurationFromHub {
return try await self.loadConfig(modelName: modelName, hubApi: hubApi)
}
}

public init(
modelFolder: URL,
hubApi: HubApi = .shared
Expand All @@ -140,12 +140,13 @@ public class LanguageModelConfigurationFromHub {
get async throws {
if let hubConfig = try await configPromise!.value.tokenizerConfig {
// Try to guess the class if it's not present and the modelType is
if let _ = hubConfig.tokenizerClass?.stringValue { return hubConfig }
if hubConfig.tokenizerClass?.stringValue != nil { return hubConfig }
guard let modelType = try await modelType else { return hubConfig }

// If the config exists but doesn't contain a tokenizerClass, use a fallback config if we have it
if let fallbackConfig = Self.fallbackTokenizerConfig(for: modelType) {
let configuration = fallbackConfig.dictionary.merging(hubConfig.dictionary, uniquingKeysWith: { current, _ in current })
let configuration = fallbackConfig.dictionary.merging(
hubConfig.dictionary, uniquingKeysWith: { current, _ in current })
return Config(configuration)
}

Expand Down Expand Up @@ -183,7 +184,7 @@ public class LanguageModelConfigurationFromHub {

return try await loadConfig(modelFolder: downloadedModelFolder, hubApi: hubApi)
}

func loadConfig(
modelFolder: URL,
hubApi: HubApi = .shared
Expand All @@ -192,7 +193,7 @@ public class LanguageModelConfigurationFromHub {
let modelConfig = try hubApi.configuration(fileURL: modelFolder.appending(path: "config.json"))
let tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json"))
let tokenizerVocab = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json"))

let configs = Configurations(
modelConfig: modelConfig,
tokenizerConfig: tokenizerConfig,
Expand All @@ -202,7 +203,9 @@ public class LanguageModelConfigurationFromHub {
}

static func fallbackTokenizerConfig(for modelType: String) -> Config? {
guard let url = Bundle.module.url(forResource: "\(modelType)_tokenizer_config", withExtension: "json") else { return nil }
guard let url = Bundle.module.url(forResource: "\(modelType)_tokenizer_config", withExtension: "json") else {
return nil
}
do {
let data = try Data(contentsOf: url)
let parsed = try JSONSerialization.jsonObject(with: data, options: [])
Expand Down
Loading