Add repetition penalty warper #85
* Add penalty to logits warpers
* Test repetition penalty
Thanks a lot! This is a very nice addition to the library 🔥 I'd just like to confirm whether results are the same as in transformers, will run some tests later.
public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
    var logits = logits
    for index in indices {
Did you try to find an accelerated version of this? (not critical right now, but good to keep in mind for the future)
I wanted to use gather/scatter, but couldn't find a scatter method in vDSP.
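For reference, the scalar rule under discussion can be sketched as follows. This is a minimal illustration, not the code from this PR: `applyRepetitionPenalty` and its parameter names are hypothetical, and it assumes the convention used by the reference transformers implementation, where a negative logit is multiplied by the penalty and a non-negative one is divided by it. Token ids are deduplicated so that each position is penalized only once, which is the result a gather-then-scatter formulation would give.

```swift
// Hypothetical helper (not part of the PR): applies a repetition penalty
// to the logits at the given token ids.
func applyRepetitionPenalty(logits: [Float], tokenIds: [Int], penalty: Float) -> [Float] {
    var result = logits
    // Deduplicate ids so a repeated token is penalized once, and skip
    // any id that falls outside the logits array.
    for id in Set(tokenIds) where result.indices.contains(id) {
        let value = result[id]
        // Negative logits are scaled up in magnitude, non-negative ones are
        // scaled down, so both become less likely when penalty > 1.
        result[id] = value < 0 ? value * penalty : value / penalty
    }
    return result
}

let penalized = applyRepetitionPenalty(logits: [2, -1, 4], tokenIds: [0, 1, 1], penalty: 2)
print(penalized)
```

With a penalty of 2, the logit at index 0 is halved, the logit at index 1 is doubled in magnitude, and index 2 is left untouched.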
@@ -87,6 +87,36 @@ final class LogitsWarperTests: XCTestCase {
        XCTAssertEqual(result5.logits, [2, 1, 0], accuracy: accuracy)
    }

    func testRepetitionPenaltyWarper() {
Thanks for adding tests! Did you verify results against the reference transformers implementation?
public struct RepetitionPenaltyWarper: LogitsWarper {
    public var penalty: Float

    public init(penalty: Double) {
Maybe we should throw here if penalty <= 0 (it must be strictly positive).
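One way the suggested check could look is a throwing initializer. The sketch below is only an illustration of that idea: `PenaltyError` and `ValidatedPenalty` are hypothetical names, not types from this PR or the library, and the actual warper may handle invalid values differently.

```swift
// Hypothetical error type for an invalid penalty value.
enum PenaltyError: Error {
    case notStrictlyPositive(Double)
}

// Hypothetical wrapper that only accepts strictly positive penalties.
struct ValidatedPenalty {
    let value: Float

    init(_ penalty: Double) throws {
        // `penalty > 0` is false for zero, negative values, and NaN,
        // so all of those are rejected here.
        guard penalty > 0 else {
            throw PenaltyError.notStrictlyPositive(penalty)
        }
        value = Float(penalty)
    }
}

do {
    let ok = try ValidatedPenalty(1.2)
    print("accepted:", ok.value)
    _ = try ValidatedPenalty(0)
} catch {
    print("rejected:", error)
}
```

Throwing at construction time surfaces a bad configuration immediately, rather than letting a zero penalty cause a division by zero during generation.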
Test runners are not working for some reason. I verified locally that the tests pass, so merging now. Thanks @shavit!
Closes #84