From 6984ba81a608ba5ba6a231d9f6488751ecefdcb2 Mon Sep 17 00:00:00 2001
From: shavit
Date: Fri, 29 Mar 2024 14:08:07 -0400
Subject: [PATCH 1/2] Add repetition penalty warper

---
 Review note: warp() now penalizes each listed position once and skips
 positions that fall outside `logits`. The blob ids on the `index` line
 predate this revision.

 .../RepetitionPenaltyWarper.swift | 37 +++++++++++++++++++
 1 file changed, 37 insertions(+)
 create mode 100644 Sources/TensorUtils/LogitsWarper/RepetitionPenaltyWarper.swift

diff --git a/Sources/TensorUtils/LogitsWarper/RepetitionPenaltyWarper.swift b/Sources/TensorUtils/LogitsWarper/RepetitionPenaltyWarper.swift
new file mode 100644
index 0000000..3d612eb
--- /dev/null
+++ b/Sources/TensorUtils/LogitsWarper/RepetitionPenaltyWarper.swift
@@ -0,0 +1,37 @@
+import Foundation
+
+/// `RepetitionPenaltyWarper` prevents the repetition of previous tokens through a penalty.
+/// This penalty is applied at most once per token.
+/// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L294
+public struct RepetitionPenaltyWarper: LogitsWarper {
+    /// Factor applied to every penalized logit; `1.0` leaves the logits unchanged.
+    public var penalty: Float
+
+    /// Creates a warper that uses `penalty` as its factor.
+    public init(penalty: Float) {
+        self.penalty = penalty
+    }
+
+    /// Penalizes the logits at the positions listed in `indices`.
+    ///
+    /// Negative logits are multiplied by `penalty`; all others are divided by it.
+    /// - Parameters:
+    ///   - indices: Positions in `logits` to penalize. Duplicates are penalized once and
+    ///     positions outside `logits` are skipped.
+    ///   - logits: Scores to adjust.
+    /// - Returns: `indices` unchanged, together with the adjusted logits.
+    public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
+        var logits = logits
+        // A set visits each position once, so duplicates cannot compound the penalty;
+        // the bounds filter skips positions that have no logit instead of trapping.
+        for index in Set(indices) where logits.indices.contains(index) {
+            if logits[index] < 0 {
+                logits[index] *= penalty
+            } else {
+                logits[index] /= penalty
+            }
+        }
+
+        return (indices, logits)
+    }
+}

From 1d822fceb6d7971b9bc4ec6b070e61687f4b14b7 Mon Sep 17 00:00:00 2001
From: shavit
Date: Fri, 29 Mar 2024 14:17:22 -0400
Subject: [PATCH 2/2] Float the penalty

* Add penalty to logits warpers
* Test repetition penalty
---
 Review note: the test gains a duplicate-index case and an out-of-range
 case. The blob ids on the `index` lines predate this revision.

 Sources/Generation/Generation.swift           |  3 ++
 .../RepetitionPenaltyWarper.swift             |  4 +-
 .../TensorUtilsTests/LogitsWarperTests.swift  | 40 +++++++++++++++++++
 3 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift
index 9464925..6cfd8ab 100644
--- a/Sources/Generation/Generation.swift
+++ b/Sources/Generation/Generation.swift
@@ -101,6 +101,9 @@ public extension Generation {
         if config.topP < 1.0 {
             logitsWarpers.append(TopPLogitsWarper(p: Float(config.topP)))
         }
+        if config.repetitionPenalty != 1.0 {
+            logitsWarpers.append(RepetitionPenaltyWarper(penalty: config.repetitionPenalty))
+        }
         return logitsWarpers
     }
 }
diff --git a/Sources/TensorUtils/LogitsWarper/RepetitionPenaltyWarper.swift b/Sources/TensorUtils/LogitsWarper/RepetitionPenaltyWarper.swift
index 3d612eb..cbc5c70 100644
--- a/Sources/TensorUtils/LogitsWarper/RepetitionPenaltyWarper.swift
+++ b/Sources/TensorUtils/LogitsWarper/RepetitionPenaltyWarper.swift
@@ -8,8 +8,8 @@ public struct RepetitionPenaltyWarper: LogitsWarper {
     public var penalty: Float
 
     /// Creates a warper that uses `penalty` as its factor.
-    public init(penalty: Float) {
-        self.penalty = penalty
+    public init(penalty: Double) {
+        self.penalty = Float(penalty)
     }
 
     /// Penalizes the logits at the positions listed in `indices`.
diff --git a/Tests/TensorUtilsTests/LogitsWarperTests.swift b/Tests/TensorUtilsTests/LogitsWarperTests.swift
index fa038c1..0260967 100644
--- a/Tests/TensorUtilsTests/LogitsWarperTests.swift
+++ b/Tests/TensorUtilsTests/LogitsWarperTests.swift
@@ -87,6 +87,46 @@ final class LogitsWarperTests: XCTestCase {
         XCTAssertEqual(result5.logits, [2, 1, 0], accuracy: accuracy)
     }
 
+    func testRepetitionPenaltyWarper() {
+        let indices = Array(0..<10)
+        let logits = indices.map({ Float($0) })
+
+        let result1 = RepetitionPenaltyWarper(penalty: 1.0)(indices, logits)
+        XCTAssertEqual(result1.indices, indices)
+        XCTAssertEqual(result1.logits, logits, accuracy: accuracy)
+
+        let result2 = RepetitionPenaltyWarper(penalty: 3.75)(indices, logits)
+        XCTAssertEqual(result2.indices, indices)
+        let logits2 = indices.map({ Float($0) / 3.75 })
+        XCTAssertEqual(result2.logits, logits2, accuracy: accuracy)
+
+        let result3 = RepetitionPenaltyWarper(penalty: 0.75)([0, 1, 2], [0.8108, 0.9954, 0.0119])
+        XCTAssertEqual(result3.indices, [0, 1, 2])
+        XCTAssertEqual(result3.logits, [1.0811, 1.3272, 0.0158], accuracy: 1e-4)
+
+        let result4 = RepetitionPenaltyWarper(penalty: 1.11)([2, 3, 4], [0.5029, 0.8694, 0.4765, 0.9967, 0.4190, 0.9158])
+        XCTAssertEqual(result4.indices, [2, 3, 4])
+        XCTAssertEqual(result4.logits, [0.5029, 0.8694, 0.4293, 0.8980, 0.3775, 0.9158], accuracy: 1e-4)
+
+        let result5 = RepetitionPenaltyWarper(penalty: 0.9)([0, 1, 2], [-0.7433, -0.4738, -0.2966])
+        XCTAssertEqual(result5.indices, [0, 1, 2])
+        XCTAssertEqual(result5.logits, [-0.6690, -0.4264, -0.2669], accuracy: 1e-4)
+
+        let result6 = RepetitionPenaltyWarper(penalty: 1.125)([3, 1, 2], [0.1674, 0.6431, 0.6780, 0.2755])
+        XCTAssertEqual(result6.indices, [3, 1, 2])
+        XCTAssertEqual(result6.logits, [0.1674, 0.5716, 0.6026, 0.2449], accuracy: 1e-4)
+
+        // A duplicated index must be penalized only once.
+        let result7 = RepetitionPenaltyWarper(penalty: 2.0)([1, 1], [1.0, 4.0])
+        XCTAssertEqual(result7.indices, [1, 1])
+        XCTAssertEqual(result7.logits, [1.0, 2.0], accuracy: accuracy)
+
+        // An index with no matching logit is skipped rather than trapping.
+        let result8 = RepetitionPenaltyWarper(penalty: 2.0)([0, 5], [-1.0, 4.0])
+        XCTAssertEqual(result8.indices, [0, 5])
+        XCTAssertEqual(result8.logits, [-2.0, 4.0], accuracy: accuracy)
+    }
+
     func testLogitsProcessor() {
         let processor1 = LogitsProcessor(logitsWarpers: [])
         let result1 = processor1([])