-
Notifications
You must be signed in to change notification settings - Fork 131
Add repetition penalty warper #85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import Foundation | ||
|
||
/// `RepetitionPenaltyWarper` prevents the repetition of previous tokens through a penalty. | ||
/// This penalty is applied at most once per token. | ||
/// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L294 | ||
public struct RepetitionPenaltyWarper: LogitsWarper { | ||
public var penalty: Float | ||
|
||
public init(penalty: Double) { | ||
self.penalty = Float(penalty) | ||
} | ||
|
||
public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) { | ||
var logits = logits | ||
for index in indices { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you try to find an accelerated version of this? (not critical right now, but good to keep in mind for the future) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wanted to use gather scatter, but didn't find scatter method in vDSP. |
||
if logits[index] < 0 { | ||
logits[index] *= penalty | ||
} else { | ||
logits[index] /= penalty | ||
} | ||
} | ||
|
||
return (indices, logits) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,6 +87,36 @@ final class LogitsWarperTests: XCTestCase { | |
XCTAssertEqual(result5.logits, [2, 1, 0], accuracy: accuracy) | ||
} | ||
|
||
func testRepetitionPenaltyWarper() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for adding tests! Did you verify results against the reference transformers implementation? |
||
let indices = Array(0..<10) | ||
let logits = indices.map({ Float($0) }) | ||
|
||
let result1 = RepetitionPenaltyWarper(penalty: 1.0)(indices, logits) | ||
XCTAssertEqual(result1.indices, indices) | ||
XCTAssertEqual(result1.logits, logits, accuracy: accuracy) | ||
|
||
let result2 = RepetitionPenaltyWarper(penalty: 3.75)(indices, logits) | ||
XCTAssertEqual(result2.indices, indices) | ||
let logits2 = indices.map({ Float($0) / 3.75 }) | ||
XCTAssertEqual(result2.logits, logits2, accuracy: accuracy) | ||
|
||
let result3 = RepetitionPenaltyWarper(penalty: 0.75)([0, 1, 2], [0.8108, 0.9954, 0.0119]) | ||
XCTAssertEqual(result3.indices, [0, 1, 2]) | ||
XCTAssertEqual(result3.logits, [1.0811, 1.3272, 0.0158], accuracy: 1e-4) | ||
|
||
let result4 = RepetitionPenaltyWarper(penalty: 1.11)([2, 3, 4], [0.5029, 0.8694, 0.4765, 0.9967, 0.4190, 0.9158]) | ||
XCTAssertEqual(result4.indices, [2, 3, 4]) | ||
XCTAssertEqual(result4.logits, [0.5029, 0.8694, 0.4293, 0.8980, 0.3775, 0.9158], accuracy: 1e-4) | ||
|
||
let result5 = RepetitionPenaltyWarper(penalty: 0.9)([0, 1, 2], [-0.7433, -0.4738, -0.2966]) | ||
XCTAssertEqual(result5.indices, [0, 1, 2]) | ||
XCTAssertEqual(result5.logits, [-0.6690, -0.4264, -0.2669], accuracy: 1e-4) | ||
|
||
let result6 = RepetitionPenaltyWarper(penalty: 1.125)([3, 1, 2], [0.1674, 0.6431, 0.6780, 0.2755]) | ||
XCTAssertEqual(result6.indices, [3, 1, 2]) | ||
XCTAssertEqual(result6.logits, [0.1674, 0.5716, 0.6026, 0.2449], accuracy: 1e-4) | ||
} | ||
|
||
func testLogitsProcessor() { | ||
let processor1 = LogitsProcessor(logitsWarpers: []) | ||
let result1 = processor1([]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe let's throw here if
penalty <= 0
(it must be strictly positive)