Skip to content

Commit 5e02089

Browse files
authored
Add repetition penalty warper (#85)
* Add repetition penalty warper * Float the penalty * Add penalty to logits warpers * Test repetition penalty
1 parent 9df94c1 commit 5e02089

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

Sources/Generation/Generation.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ public extension Generation {
101101
if config.topP < 1.0 {
102102
logitsWarpers.append(TopPLogitsWarper(p: Float(config.topP)))
103103
}
104+
if config.repetitionPenalty != 1.0 {
105+
logitsWarpers.append(RepetitionPenaltyWarper(penalty: config.repetitionPenalty))
106+
}
104107
return logitsWarpers
105108
}
106109
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import Foundation
2+
3+
/// `RepetitionPenaltyWarper` prevents the repetition of previous tokens through a penalty.
4+
/// This penalty is applied at most once per token.
5+
/// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L294
6+
public struct RepetitionPenaltyWarper: LogitsWarper {
7+
public var penalty: Float
8+
9+
public init(penalty: Double) {
10+
self.penalty = Float(penalty)
11+
}
12+
13+
public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
14+
var logits = logits
15+
for index in indices {
16+
if logits[index] < 0 {
17+
logits[index] *= penalty
18+
} else {
19+
logits[index] /= penalty
20+
}
21+
}
22+
23+
return (indices, logits)
24+
}
25+
}

Tests/TensorUtilsTests/LogitsWarperTests.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,36 @@ final class LogitsWarperTests: XCTestCase {
8787
XCTAssertEqual(result5.logits, [2, 1, 0], accuracy: accuracy)
8888
}
8989

90+
func testRepetitionPenaltyWarper() {
91+
let indices = Array(0..<10)
92+
let logits = indices.map({ Float($0) })
93+
94+
let result1 = RepetitionPenaltyWarper(penalty: 1.0)(indices, logits)
95+
XCTAssertEqual(result1.indices, indices)
96+
XCTAssertEqual(result1.logits, logits, accuracy: accuracy)
97+
98+
let result2 = RepetitionPenaltyWarper(penalty: 3.75)(indices, logits)
99+
XCTAssertEqual(result2.indices, indices)
100+
let logits2 = indices.map({ Float($0) / 3.75 })
101+
XCTAssertEqual(result2.logits, logits2, accuracy: accuracy)
102+
103+
let result3 = RepetitionPenaltyWarper(penalty: 0.75)([0, 1, 2], [0.8108, 0.9954, 0.0119])
104+
XCTAssertEqual(result3.indices, [0, 1, 2])
105+
XCTAssertEqual(result3.logits, [1.0811, 1.3272, 0.0158], accuracy: 1e-4)
106+
107+
let result4 = RepetitionPenaltyWarper(penalty: 1.11)([2, 3, 4], [0.5029, 0.8694, 0.4765, 0.9967, 0.4190, 0.9158])
108+
XCTAssertEqual(result4.indices, [2, 3, 4])
109+
XCTAssertEqual(result4.logits, [0.5029, 0.8694, 0.4293, 0.8980, 0.3775, 0.9158], accuracy: 1e-4)
110+
111+
let result5 = RepetitionPenaltyWarper(penalty: 0.9)([0, 1, 2], [-0.7433, -0.4738, -0.2966])
112+
XCTAssertEqual(result5.indices, [0, 1, 2])
113+
XCTAssertEqual(result5.logits, [-0.6690, -0.4264, -0.2669], accuracy: 1e-4)
114+
115+
let result6 = RepetitionPenaltyWarper(penalty: 1.125)([3, 1, 2], [0.1674, 0.6431, 0.6780, 0.2755])
116+
XCTAssertEqual(result6.indices, [3, 1, 2])
117+
XCTAssertEqual(result6.logits, [0.1674, 0.5716, 0.6026, 0.2449], accuracy: 1e-4)
118+
}
119+
90120
func testLogitsProcessor() {
91121
let processor1 = LogitsProcessor(logitsWarpers: [])
92122
let result1 = processor1([])

0 commit comments

Comments
 (0)