@@ -12,65 +12,79 @@ final class LogitsWarperTests: XCTestCase {
12
12
private let accuracy : Float = 0.00001
13
13
14
14
func testTemperatureLogitsWarper( ) {
15
- let result1 = TemperatureLogitsWarper ( temperature: 0.0 ) ( [ ] )
15
+ let result1 = TemperatureLogitsWarper ( temperature: 0.0 ) ( [ ] , [ ] )
16
16
XCTAssertTrue ( result1. indexes. isEmpty)
17
17
XCTAssertTrue ( result1. logits. isEmpty)
18
18
19
- let result2 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ ] )
19
+ let result2 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ ] , [ ] )
20
20
XCTAssertTrue ( result2. indexes. isEmpty)
21
21
XCTAssertTrue ( result2. logits. isEmpty)
22
22
23
- let result3 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ 2.0 , 1.0 ] )
23
+ let result3 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
24
24
XCTAssertEqual ( result3. indexes, [ 0 , 1 ] )
25
25
XCTAssertEqual ( result3. logits, [ 2.0 , 1.0 ] , accuracy: accuracy)
26
26
27
- let result4 = TemperatureLogitsWarper ( temperature: 2.0 ) ( [ 2.0 , 1.0 ] )
27
+ let result4 = TemperatureLogitsWarper ( temperature: 2.0 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
28
28
XCTAssertEqual ( result4. indexes, [ 0 , 1 ] )
29
29
XCTAssertEqual ( result4. logits, [ 1.0 , 0.5 ] , accuracy: accuracy)
30
30
31
- let result5 = TemperatureLogitsWarper ( temperature: 0.5 ) ( [ 2.0 , 1.0 ] )
31
+ let result5 = TemperatureLogitsWarper ( temperature: 0.5 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
32
32
XCTAssertEqual ( result5. indexes, [ 0 , 1 ] )
33
33
XCTAssertEqual ( result5. logits, [ 4.0 , 2.0 ] , accuracy: accuracy)
34
+
35
+ let result6 = TemperatureLogitsWarper ( temperature: 0.5 ) ( [ 200 , 100 ] , [ 2.0 , 1.0 ] )
36
+ XCTAssertEqual ( result6. indexes, [ 200 , 100 ] )
37
+ XCTAssertEqual ( result6. logits, [ 4.0 , 2.0 ] , accuracy: accuracy)
34
38
}
35
39
36
40
func testTopKLogitsWarper( ) {
37
- let result1 = TopKLogitsWarper ( k: 0 ) ( [ ] )
41
+ let result1 = TopKLogitsWarper ( k: 0 ) ( [ ] , [ ] )
38
42
XCTAssertTrue ( result1. indexes. isEmpty)
39
43
XCTAssertTrue ( result1. logits. isEmpty)
40
44
41
- let result2 = TopKLogitsWarper ( k: 3 ) ( [ ] )
45
+ let result2 = TopKLogitsWarper ( k: 3 ) ( [ ] , [ ] )
42
46
XCTAssertTrue ( result2. indexes. isEmpty)
43
47
XCTAssertTrue ( result2. logits. isEmpty)
44
48
45
- let result3 = TopKLogitsWarper ( k: 3 ) ( [ 2.0 , 1.0 ] )
49
+ let result3 = TopKLogitsWarper ( k: 3 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
46
50
XCTAssertEqual ( result3. indexes, [ 0 , 1 ] )
47
51
XCTAssertEqual ( result3. logits, [ 2.0 , 1.0 ] , accuracy: accuracy)
48
52
49
- let result4 = TopKLogitsWarper ( k: 3 ) ( [ 2.0 , 1.0 , 3.0 ] )
53
+ let result4 = TopKLogitsWarper ( k: 3 ) ( [ 0 , 1 , 2 ] , [ 2.0 , 1.0 , 3.0 ] )
50
54
XCTAssertEqual ( result4. indexes, [ 2 , 0 , 1 ] )
51
55
XCTAssertEqual ( result4. logits, [ 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
52
56
53
- let result5 = TopKLogitsWarper ( k: 4 ) ( [ 2.0 , 1.0 , 3.0 , - 1.0 , 123.0 , 0.0 ] )
57
+ let result5 = TopKLogitsWarper ( k: 4 ) ( [ 0 , 1 , 2 , 3 , 4 , 5 ] , [ 2.0 , 1.0 , 3.0 , - 1.0 , 123.0 , 0.0 ] )
54
58
XCTAssertEqual ( result5. indexes, [ 4 , 2 , 0 , 1 ] )
55
59
XCTAssertEqual ( result5. logits, [ 123.0 , 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
60
+
61
+ let result6 = TopKLogitsWarper ( k: 3 ) ( [ 10 , 1 , 52 ] , [ 2.0 , 1.0 , 3.0 ] )
62
+ XCTAssertEqual ( result6. indexes, [ 52 , 10 , 1 ] )
63
+ XCTAssertEqual ( result6. logits, [ 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
56
64
}
57
65
58
66
func testTopPLogitsWarper( ) {
59
- let result1 = TopPLogitsWarper ( p: 0.99 ) ( [ ] )
67
+ let result1 = TopPLogitsWarper ( p: 0.99 ) ( [ ] , [ ] )
60
68
XCTAssertTrue ( result1. indexes. isEmpty)
61
69
XCTAssertTrue ( result1. logits. isEmpty)
62
70
63
- let result2 = TopPLogitsWarper ( p: 0.99 ) ( ( 0 ..< 10 ) . map { Float ( $0) } )
71
+ let logits = ( 0 ..< 10 ) . map { Float ( $0) }
72
+ let indexes = Array ( logits. indices)
73
+ let result2 = TopPLogitsWarper ( p: 0.99 ) ( indexes, logits)
64
74
XCTAssertEqual ( result2. indexes, [ 9 , 8 , 7 , 6 , 5 ] )
65
75
XCTAssertEqual ( result2. logits, [ 9.0 , 8.0 , 7.0 , 6.0 , 5.0 ] , accuracy: accuracy)
66
76
67
- let result3 = TopPLogitsWarper ( p: 0.95 ) ( ( 0 ..< 10 ) . map { Float ( $0 ) } )
77
+ let result3 = TopPLogitsWarper ( p: 0.95 ) ( indexes , logits )
68
78
XCTAssertEqual ( result3. indexes, [ 9 , 8 , 7 ] )
69
79
XCTAssertEqual ( result3. logits, [ 9.0 , 8.0 , 7.0 ] , accuracy: accuracy)
70
80
71
- let result4 = TopPLogitsWarper ( p: 0.6321493 ) ( ( 0 ..< 10 ) . map { Float ( $0 ) } )
81
+ let result4 = TopPLogitsWarper ( p: 0.6321493 ) ( indexes , logits )
72
82
XCTAssertEqual ( result4. indexes, [ 9 , 8 ] )
73
83
XCTAssertEqual ( result4. logits, [ 9.0 , 8.0 ] , accuracy: accuracy)
84
+
85
+ let result5 = TopPLogitsWarper ( p: 0.95 ) ( [ 3 , 1 , 8 ] , [ 0 , 1 , 2 ] )
86
+ XCTAssertEqual ( result5. indexes, [ 8 , 1 , 3 ] )
87
+ XCTAssertEqual ( result5. logits, [ 2 , 1 , 0 ] , accuracy: accuracy)
74
88
}
75
89
76
90
func testLogitsProcessor( ) {
@@ -95,7 +109,14 @@ final class LogitsWarperTests: XCTestCase {
95
109
logitsWarpers: [ TopKLogitsWarper ( k: 3 ) , TopPLogitsWarper ( p: 0.99 ) ]
96
110
)
97
111
let result4 = processor4 ( [ 2.0 , 1.0 , 3.0 , - 5.0 , - 23.0 , 12.5 ] )
98
- XCTAssertEqual ( result4. indexes, [ 0 ] )
112
+ XCTAssertEqual ( result4. indexes, [ 5 ] )
99
113
XCTAssertEqual ( result4. logits, [ 12.5 ] , accuracy: accuracy)
114
+
115
+ let processor5 = LogitsProcessor (
116
+ logitsWarpers: [ TopKLogitsWarper ( k: 4 ) , TopPLogitsWarper ( p: 0.99 ) ]
117
+ )
118
+ let result5 = processor5 ( [ 2.0 , 1.0 , 3.0 , - 5.0 , - 3.0 , 4.5 ] )
119
+ XCTAssertEqual ( result5. indexes, [ 5 , 2 , 0 , 1 ] )
120
+ XCTAssertEqual ( result5. logits, [ 4.5 , 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
100
121
}
101
122
}
0 commit comments