@@ -87,6 +87,36 @@ final class LogitsWarperTests: XCTestCase {
87
87
XCTAssertEqual ( result5. logits, [ 2 , 1 , 0 ] , accuracy: accuracy)
88
88
}
89
89
90
+ func testRepetitionPenaltyWarper( ) {
91
+ let indices = Array ( 0 ..< 10 )
92
+ let logits = indices. map ( { Float ( $0) } )
93
+
94
+ let result1 = RepetitionPenaltyWarper ( penalty: 1.0 ) ( indices, logits)
95
+ XCTAssertEqual ( result1. indices, indices)
96
+ XCTAssertEqual ( result1. logits, logits, accuracy: accuracy)
97
+
98
+ let result2 = RepetitionPenaltyWarper ( penalty: 3.75 ) ( indices, logits)
99
+ XCTAssertEqual ( result2. indices, indices)
100
+ let logits2 = indices. map ( { Float ( $0) / 3.75 } )
101
+ XCTAssertEqual ( result2. logits, logits2, accuracy: accuracy)
102
+
103
+ let result3 = RepetitionPenaltyWarper ( penalty: 0.75 ) ( [ 0 , 1 , 2 ] , [ 0.8108 , 0.9954 , 0.0119 ] )
104
+ XCTAssertEqual ( result3. indices, [ 0 , 1 , 2 ] )
105
+ XCTAssertEqual ( result3. logits, [ 1.0811 , 1.3272 , 0.0158 ] , accuracy: 1e-4 )
106
+
107
+ let result4 = RepetitionPenaltyWarper ( penalty: 1.11 ) ( [ 2 , 3 , 4 ] , [ 0.5029 , 0.8694 , 0.4765 , 0.9967 , 0.4190 , 0.9158 ] )
108
+ XCTAssertEqual ( result4. indices, [ 2 , 3 , 4 ] )
109
+ XCTAssertEqual ( result4. logits, [ 0.5029 , 0.8694 , 0.4293 , 0.8980 , 0.3775 , 0.9158 ] , accuracy: 1e-4 )
110
+
111
+ let result5 = RepetitionPenaltyWarper ( penalty: 0.9 ) ( [ 0 , 1 , 2 ] , [ - 0.7433 , - 0.4738 , - 0.2966 ] )
112
+ XCTAssertEqual ( result5. indices, [ 0 , 1 , 2 ] )
113
+ XCTAssertEqual ( result5. logits, [ - 0.6690 , - 0.4264 , - 0.2669 ] , accuracy: 1e-4 )
114
+
115
+ let result6 = RepetitionPenaltyWarper ( penalty: 1.125 ) ( [ 3 , 1 , 2 ] , [ 0.1674 , 0.6431 , 0.6780 , 0.2755 ] )
116
+ XCTAssertEqual ( result6. indices, [ 3 , 1 , 2 ] )
117
+ XCTAssertEqual ( result6. logits, [ 0.1674 , 0.5716 , 0.6026 , 0.2449 ] , accuracy: 1e-4 )
118
+ }
119
+
90
120
func testLogitsProcessor( ) {
91
121
let processor1 = LogitsProcessor ( logitsWarpers: [ ] )
92
122
let result1 = processor1 ( [ ] )
0 commit comments