@@ -13,110 +13,110 @@ final class LogitsWarperTests: XCTestCase {
13
13
14
14
func testTemperatureLogitsWarper( ) {
15
15
let result1 = TemperatureLogitsWarper ( temperature: 0.0 ) ( [ ] , [ ] )
16
- XCTAssertTrue ( result1. indexes . isEmpty)
16
+ XCTAssertTrue ( result1. indices . isEmpty)
17
17
XCTAssertTrue ( result1. logits. isEmpty)
18
18
19
19
let result2 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ ] , [ ] )
20
- XCTAssertTrue ( result2. indexes . isEmpty)
20
+ XCTAssertTrue ( result2. indices . isEmpty)
21
21
XCTAssertTrue ( result2. logits. isEmpty)
22
22
23
23
let result3 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
24
- XCTAssertEqual ( result3. indexes , [ 0 , 1 ] )
24
+ XCTAssertEqual ( result3. indices , [ 0 , 1 ] )
25
25
XCTAssertEqual ( result3. logits, [ 2.0 , 1.0 ] , accuracy: accuracy)
26
26
27
27
let result4 = TemperatureLogitsWarper ( temperature: 2.0 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
28
- XCTAssertEqual ( result4. indexes , [ 0 , 1 ] )
28
+ XCTAssertEqual ( result4. indices , [ 0 , 1 ] )
29
29
XCTAssertEqual ( result4. logits, [ 1.0 , 0.5 ] , accuracy: accuracy)
30
30
31
31
let result5 = TemperatureLogitsWarper ( temperature: 0.5 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
32
- XCTAssertEqual ( result5. indexes , [ 0 , 1 ] )
32
+ XCTAssertEqual ( result5. indices , [ 0 , 1 ] )
33
33
XCTAssertEqual ( result5. logits, [ 4.0 , 2.0 ] , accuracy: accuracy)
34
34
35
35
let result6 = TemperatureLogitsWarper ( temperature: 0.5 ) ( [ 200 , 100 ] , [ 2.0 , 1.0 ] )
36
- XCTAssertEqual ( result6. indexes , [ 200 , 100 ] )
36
+ XCTAssertEqual ( result6. indices , [ 200 , 100 ] )
37
37
XCTAssertEqual ( result6. logits, [ 4.0 , 2.0 ] , accuracy: accuracy)
38
38
}
39
39
40
40
func testTopKLogitsWarper( ) {
41
41
let result1 = TopKLogitsWarper ( k: 0 ) ( [ ] , [ ] )
42
- XCTAssertTrue ( result1. indexes . isEmpty)
42
+ XCTAssertTrue ( result1. indices . isEmpty)
43
43
XCTAssertTrue ( result1. logits. isEmpty)
44
44
45
45
let result2 = TopKLogitsWarper ( k: 3 ) ( [ ] , [ ] )
46
- XCTAssertTrue ( result2. indexes . isEmpty)
46
+ XCTAssertTrue ( result2. indices . isEmpty)
47
47
XCTAssertTrue ( result2. logits. isEmpty)
48
48
49
49
let result3 = TopKLogitsWarper ( k: 3 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
50
- XCTAssertEqual ( result3. indexes , [ 0 , 1 ] )
50
+ XCTAssertEqual ( result3. indices , [ 0 , 1 ] )
51
51
XCTAssertEqual ( result3. logits, [ 2.0 , 1.0 ] , accuracy: accuracy)
52
52
53
53
let result4 = TopKLogitsWarper ( k: 3 ) ( [ 0 , 1 , 2 ] , [ 2.0 , 1.0 , 3.0 ] )
54
- XCTAssertEqual ( result4. indexes , [ 2 , 0 , 1 ] )
54
+ XCTAssertEqual ( result4. indices , [ 2 , 0 , 1 ] )
55
55
XCTAssertEqual ( result4. logits, [ 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
56
56
57
57
let result5 = TopKLogitsWarper ( k: 4 ) ( [ 0 , 1 , 2 , 3 , 4 , 5 ] , [ 2.0 , 1.0 , 3.0 , - 1.0 , 123.0 , 0.0 ] )
58
- XCTAssertEqual ( result5. indexes , [ 4 , 2 , 0 , 1 ] )
58
+ XCTAssertEqual ( result5. indices , [ 4 , 2 , 0 , 1 ] )
59
59
XCTAssertEqual ( result5. logits, [ 123.0 , 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
60
60
61
61
let result6 = TopKLogitsWarper ( k: 3 ) ( [ 10 , 1 , 52 ] , [ 2.0 , 1.0 , 3.0 ] )
62
- XCTAssertEqual ( result6. indexes , [ 52 , 10 , 1 ] )
62
+ XCTAssertEqual ( result6. indices , [ 52 , 10 , 1 ] )
63
63
XCTAssertEqual ( result6. logits, [ 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
64
64
}
65
65
66
66
func testTopPLogitsWarper( ) {
67
67
let result1 = TopPLogitsWarper ( p: 0.99 ) ( [ ] , [ ] )
68
- XCTAssertTrue ( result1. indexes . isEmpty)
68
+ XCTAssertTrue ( result1. indices . isEmpty)
69
69
XCTAssertTrue ( result1. logits. isEmpty)
70
70
71
71
let logits = ( 0 ..< 10 ) . map { Float ( $0) }
72
- let indexes = Array ( logits. indices)
73
- let result2 = TopPLogitsWarper ( p: 0.99 ) ( indexes , logits)
74
- XCTAssertEqual ( result2. indexes , [ 9 , 8 , 7 , 6 , 5 ] )
72
+ let indices = Array ( logits. indices)
73
+ let result2 = TopPLogitsWarper ( p: 0.99 ) ( indices , logits)
74
+ XCTAssertEqual ( result2. indices , [ 9 , 8 , 7 , 6 , 5 ] )
75
75
XCTAssertEqual ( result2. logits, [ 9.0 , 8.0 , 7.0 , 6.0 , 5.0 ] , accuracy: accuracy)
76
76
77
- let result3 = TopPLogitsWarper ( p: 0.95 ) ( indexes , logits)
78
- XCTAssertEqual ( result3. indexes , [ 9 , 8 , 7 ] )
77
+ let result3 = TopPLogitsWarper ( p: 0.95 ) ( indices , logits)
78
+ XCTAssertEqual ( result3. indices , [ 9 , 8 , 7 ] )
79
79
XCTAssertEqual ( result3. logits, [ 9.0 , 8.0 , 7.0 ] , accuracy: accuracy)
80
80
81
- let result4 = TopPLogitsWarper ( p: 0.6321493 ) ( indexes , logits)
82
- XCTAssertEqual ( result4. indexes , [ 9 , 8 ] )
81
+ let result4 = TopPLogitsWarper ( p: 0.6321493 ) ( indices , logits)
82
+ XCTAssertEqual ( result4. indices , [ 9 , 8 ] )
83
83
XCTAssertEqual ( result4. logits, [ 9.0 , 8.0 ] , accuracy: accuracy)
84
84
85
85
let result5 = TopPLogitsWarper ( p: 0.95 ) ( [ 3 , 1 , 8 ] , [ 0 , 1 , 2 ] )
86
- XCTAssertEqual ( result5. indexes , [ 8 , 1 , 3 ] )
86
+ XCTAssertEqual ( result5. indices , [ 8 , 1 , 3 ] )
87
87
XCTAssertEqual ( result5. logits, [ 2 , 1 , 0 ] , accuracy: accuracy)
88
88
}
89
89
90
90
func testLogitsProcessor( ) {
91
91
let processor1 = LogitsProcessor ( logitsWarpers: [ ] )
92
92
let result1 = processor1 ( [ ] )
93
- XCTAssertTrue ( result1. indexes . isEmpty)
93
+ XCTAssertTrue ( result1. indices . isEmpty)
94
94
XCTAssertTrue ( result1. logits. isEmpty)
95
95
96
96
let processor2 = LogitsProcessor ( logitsWarpers: [ ] )
97
97
let result2 = processor2 ( [ 2.0 , 1.0 ] )
98
- XCTAssertEqual ( result2. indexes , [ 0 , 1 ] )
98
+ XCTAssertEqual ( result2. indices , [ 0 , 1 ] )
99
99
XCTAssertEqual ( result2. logits, [ 2.0 , 1.0 ] , accuracy: accuracy)
100
100
101
101
let processor3 = LogitsProcessor (
102
102
logitsWarpers: [ TopKLogitsWarper ( k: 3 ) ]
103
103
)
104
104
let result3 = processor3 ( [ 2.0 , 1.0 , 3.0 , - 5.0 ] )
105
- XCTAssertEqual ( result3. indexes , [ 2 , 0 , 1 ] )
105
+ XCTAssertEqual ( result3. indices , [ 2 , 0 , 1 ] )
106
106
XCTAssertEqual ( result3. logits, [ 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
107
107
108
108
let processor4 = LogitsProcessor (
109
109
logitsWarpers: [ TopKLogitsWarper ( k: 3 ) , TopPLogitsWarper ( p: 0.99 ) ]
110
110
)
111
111
let result4 = processor4 ( [ 2.0 , 1.0 , 3.0 , - 5.0 , - 23.0 , 12.5 ] )
112
- XCTAssertEqual ( result4. indexes , [ 5 ] )
112
+ XCTAssertEqual ( result4. indices , [ 5 ] )
113
113
XCTAssertEqual ( result4. logits, [ 12.5 ] , accuracy: accuracy)
114
114
115
115
let processor5 = LogitsProcessor (
116
116
logitsWarpers: [ TopKLogitsWarper ( k: 4 ) , TopPLogitsWarper ( p: 0.99 ) ]
117
117
)
118
118
let result5 = processor5 ( [ 2.0 , 1.0 , 3.0 , - 5.0 , - 3.0 , 4.5 ] )
119
- XCTAssertEqual ( result5. indexes , [ 5 , 2 , 0 , 1 ] )
119
+ XCTAssertEqual ( result5. indices , [ 5 , 2 , 0 , 1 ] )
120
120
XCTAssertEqual ( result5. logits, [ 4.5 , 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
121
121
}
122
122
}
0 commit comments