Skip to content

Commit 1d822fc

Browse files
committed
Float the penalty
* Add penalty to logits warpers
* Test repetition penalty
1 parent 6984ba8 commit 1d822fc

File tree

3 files changed

+35
-2
lines changed

3 files changed

+35
-2
lines changed

Sources/Generation/Generation.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ public extension Generation {
101101
if config.topP < 1.0 {
102102
logitsWarpers.append(TopPLogitsWarper(p: Float(config.topP)))
103103
}
104+
if config.repetitionPenalty != 1.0 {
105+
logitsWarpers.append(RepetitionPenaltyWarper(penalty: config.repetitionPenalty))
106+
}
104107
return logitsWarpers
105108
}
106109
}

Sources/TensorUtils/LogitsWarper/RepetitionPenaltyWarper.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ import Foundation
66
public struct RepetitionPenaltyWarper: LogitsWarper {
77
public var penalty: Float
88

9-
public init(penalty: Float) {
10-
self.penalty = penalty
9+
public init(penalty: Double) {
10+
self.penalty = Float(penalty)
1111
}
1212

1313
public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {

Tests/TensorUtilsTests/LogitsWarperTests.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,36 @@ final class LogitsWarperTests: XCTestCase {
8787
XCTAssertEqual(result5.logits, [2, 1, 0], accuracy: accuracy)
8888
}
8989

90+
/// Exercises `RepetitionPenaltyWarper` with neutral, greater-than-one and less-than-one penalties.
///
/// The expectations below pin this behaviour: positive logits are divided by the penalty,
/// negative logits are multiplied by it, only the logit positions named in `indices` are
/// modified, and `indices` itself is returned unchanged.
func testRepetitionPenaltyWarper() {
91+
let indices = Array(0..<10)
92+
let logits = indices.map({ Float($0) })
93+
94+
// A penalty of 1.0 is expected to leave every logit as it was.
let result1 = RepetitionPenaltyWarper(penalty: 1.0)(indices, logits)
95+
XCTAssertEqual(result1.indices, indices)
96+
XCTAssertEqual(result1.logits, logits, accuracy: accuracy)
97+
98+
// A penalty above 1 with non-negative logits: each logit is expected to be divided by the penalty.
let result2 = RepetitionPenaltyWarper(penalty: 3.75)(indices, logits)
99+
XCTAssertEqual(result2.indices, indices)
100+
let logits2 = indices.map({ Float($0) / 3.75 })
101+
XCTAssertEqual(result2.logits, logits2, accuracy: accuracy)
102+
103+
// A penalty below 1 with positive logits: dividing by 0.75 makes the expected values larger.
let result3 = RepetitionPenaltyWarper(penalty: 0.75)([0, 1, 2], [0.8108, 0.9954, 0.0119])
104+
XCTAssertEqual(result3.indices, [0, 1, 2])
105+
XCTAssertEqual(result3.logits, [1.0811, 1.3272, 0.0158], accuracy: 1e-4)
106+
107+
// Only positions 2, 3 and 4 are listed, so positions 0, 1 and 5 are expected to stay unchanged.
let result4 = RepetitionPenaltyWarper(penalty: 1.11)([2, 3, 4], [0.5029, 0.8694, 0.4765, 0.9967, 0.4190, 0.9158])
108+
XCTAssertEqual(result4.indices, [2, 3, 4])
109+
XCTAssertEqual(result4.logits, [0.5029, 0.8694, 0.4293, 0.8980, 0.3775, 0.9158], accuracy: 1e-4)
110+
111+
// Negative logits: the expected values equal logit * penalty (e.g. -0.7433 * 0.9 ≈ -0.6690).
let result5 = RepetitionPenaltyWarper(penalty: 0.9)([0, 1, 2], [-0.7433, -0.4738, -0.2966])
112+
XCTAssertEqual(result5.indices, [0, 1, 2])
113+
XCTAssertEqual(result5.logits, [-0.6690, -0.4264, -0.2669], accuracy: 1e-4)
114+
115+
// Unordered indices: position 0 is not listed and stays unchanged; indices come back in input order.
let result6 = RepetitionPenaltyWarper(penalty: 1.125)([3, 1, 2], [0.1674, 0.6431, 0.6780, 0.2755])
116+
XCTAssertEqual(result6.indices, [3, 1, 2])
117+
XCTAssertEqual(result6.logits, [0.1674, 0.5716, 0.6026, 0.2449], accuracy: 1e-4)
118+
}
119+
90120
func testLogitsProcessor() {
91121
let processor1 = LogitsProcessor(logitsWarpers: [])
92122
let result1 = processor1([])

0 commit comments

Comments
 (0)