Skip to content

Commit ae3ce32

Browse files
authored
Fixed indexes vs indices naming (#54)
1 parent 03d86ac commit ae3ce32

File tree

6 files changed

+44
-44
lines changed

6 files changed

+44
-44
lines changed

Sources/TensorUtils/LogitsWarper/LogitsProcessor.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ public struct LogitsProcessor {
77
self.logitsWarpers = logitsWarpers
88
}
99

10-
public func callAsFunction(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
11-
var indexes = Array(arr.indices)
10+
public func callAsFunction(_ arr: [Float]) -> (indices: [Int], logits: [Float]) {
11+
var indices = Array(arr.indices)
1212
var logits = arr
1313
for warper in logitsWarpers {
14-
(indexes, logits) = warper(indexes, logits)
14+
(indices, logits) = warper(indices, logits)
1515
}
16-
return (indexes: indexes, logits: logits)
16+
return (indices: indices, logits: logits)
1717
}
1818
}

Sources/TensorUtils/LogitsWarper/LogitsWarper.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ import Foundation
22

33
/// Protocol for all logit warpers that can be applied during generation
44
public protocol LogitsWarper {
5-
func warp(indexes: [Int], logits: [Float]) -> (indexes: [Int], logits: [Float])
6-
func callAsFunction(_ indexes: [Int], _ logits: [Float]) -> (indexes: [Int], logits: [Float])
5+
func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float])
6+
func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float])
77
}
88

99
extension LogitsWarper {
10-
public func callAsFunction(_ indexes: [Int], _ logits: [Float]) -> (indexes: [Int], logits: [Float]) {
11-
warp(indexes: indexes, logits: logits)
10+
public func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float]) {
11+
warp(indices: indices, logits: logits)
1212
}
1313
}

Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ public struct TemperatureLogitsWarper: LogitsWarper {
77
self.temperature = temperature
88
}
99

10-
public func warp(indexes: [Int], logits: [Float]) -> (indexes: [Int], logits: [Float]) {
11-
return (indexes: indexes, logits: logits.map { $0 / temperature })
10+
public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
11+
return (indices: indices, logits: logits.map { $0 / temperature })
1212
}
1313
}

Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ public struct TopKLogitsWarper: LogitsWarper {
1212
self.k = k
1313
}
1414

15-
public func warp(indexes: [Int], logits: [Float]) -> (indexes: [Int], logits: [Float]) {
15+
public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
1616
guard !logits.isEmpty else {
17-
return (indexes: [], logits: [])
17+
return (indices: [], logits: [])
1818
}
1919
let k = min(k, logits.count)
2020
let arrDescriptor = BNNSNDArrayDescriptor.allocate(
@@ -50,9 +50,9 @@ public struct TopKLogitsWarper: LogitsWarper {
5050
let topkLogits = bestValues.data!.withMemoryRebound(to: Float.self, capacity: k) { ptr in
5151
Array(UnsafeBufferPointer(start: ptr, count: k))
5252
}
53-
let topkIndexes = bestIndices.data!.withMemoryRebound(to: Int32.self, capacity: k) { ptr in
53+
let topkIndices = bestIndices.data!.withMemoryRebound(to: Int32.self, capacity: k) { ptr in
5454
Array(UnsafeBufferPointer(start: ptr, count: k))
5555
}
56-
return (indexes: topkIndexes.map { indexes[Int($0)] }, logits: topkLogits)
56+
return (indices: topkIndices.map { indices[Int($0)] }, logits: topkLogits)
5757
}
5858
}

Sources/TensorUtils/LogitsWarper/TopPLogitsWarper.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ public struct TopPLogitsWarper: LogitsWarper {
1010
self.p = p
1111
}
1212

13-
public func warp(indexes: [Int], logits: [Float]) -> (indexes: [Int], logits: [Float]) {
13+
public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
1414
guard !logits.isEmpty else {
15-
return (indexes: [], logits: [])
15+
return (indices: [], logits: [])
1616
}
1717

1818
let arrSoftmax = Math.softmax(logits)
@@ -30,8 +30,8 @@ public struct TopPLogitsWarper: LogitsWarper {
3030
break
3131
}
3232

33-
let toppIndexes = indexLogitProb[0 ... sliceIndex].map { indexes[$0.index] }
33+
let toppIndices = indexLogitProb[0 ... sliceIndex].map { indices[$0.index] }
3434
let toppLogits = indexLogitProb[0 ... sliceIndex].map(\.logit)
35-
return (indexes: toppIndexes, logits: toppLogits)
35+
return (indices: toppIndices, logits: toppLogits)
3636
}
3737
}

Tests/TensorUtilsTests/LogitsWarperTests.swift

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,110 +13,110 @@ final class LogitsWarperTests: XCTestCase {
1313

1414
func testTemperatureLogitsWarper() {
1515
let result1 = TemperatureLogitsWarper(temperature: 0.0)([], [])
16-
XCTAssertTrue(result1.indexes.isEmpty)
16+
XCTAssertTrue(result1.indices.isEmpty)
1717
XCTAssertTrue(result1.logits.isEmpty)
1818

1919
let result2 = TemperatureLogitsWarper(temperature: 1.0)([], [])
20-
XCTAssertTrue(result2.indexes.isEmpty)
20+
XCTAssertTrue(result2.indices.isEmpty)
2121
XCTAssertTrue(result2.logits.isEmpty)
2222

2323
let result3 = TemperatureLogitsWarper(temperature: 1.0)([0, 1], [2.0, 1.0])
24-
XCTAssertEqual(result3.indexes, [0, 1])
24+
XCTAssertEqual(result3.indices, [0, 1])
2525
XCTAssertEqual(result3.logits, [2.0, 1.0], accuracy: accuracy)
2626

2727
let result4 = TemperatureLogitsWarper(temperature: 2.0)([0, 1], [2.0, 1.0])
28-
XCTAssertEqual(result4.indexes, [0, 1])
28+
XCTAssertEqual(result4.indices, [0, 1])
2929
XCTAssertEqual(result4.logits, [1.0, 0.5], accuracy: accuracy)
3030

3131
let result5 = TemperatureLogitsWarper(temperature: 0.5)([0, 1], [2.0, 1.0])
32-
XCTAssertEqual(result5.indexes, [0, 1])
32+
XCTAssertEqual(result5.indices, [0, 1])
3333
XCTAssertEqual(result5.logits, [4.0, 2.0], accuracy: accuracy)
3434

3535
let result6 = TemperatureLogitsWarper(temperature: 0.5)([200, 100], [2.0, 1.0])
36-
XCTAssertEqual(result6.indexes, [200, 100])
36+
XCTAssertEqual(result6.indices, [200, 100])
3737
XCTAssertEqual(result6.logits, [4.0, 2.0], accuracy: accuracy)
3838
}
3939

4040
func testTopKLogitsWarper() {
4141
let result1 = TopKLogitsWarper(k: 0)([], [])
42-
XCTAssertTrue(result1.indexes.isEmpty)
42+
XCTAssertTrue(result1.indices.isEmpty)
4343
XCTAssertTrue(result1.logits.isEmpty)
4444

4545
let result2 = TopKLogitsWarper(k: 3)([], [])
46-
XCTAssertTrue(result2.indexes.isEmpty)
46+
XCTAssertTrue(result2.indices.isEmpty)
4747
XCTAssertTrue(result2.logits.isEmpty)
4848

4949
let result3 = TopKLogitsWarper(k: 3)([0, 1], [2.0, 1.0])
50-
XCTAssertEqual(result3.indexes, [0, 1])
50+
XCTAssertEqual(result3.indices, [0, 1])
5151
XCTAssertEqual(result3.logits, [2.0, 1.0], accuracy: accuracy)
5252

5353
let result4 = TopKLogitsWarper(k: 3)([0, 1, 2], [2.0, 1.0, 3.0])
54-
XCTAssertEqual(result4.indexes, [2, 0, 1])
54+
XCTAssertEqual(result4.indices, [2, 0, 1])
5555
XCTAssertEqual(result4.logits, [3.0, 2.0, 1.0], accuracy: accuracy)
5656

5757
let result5 = TopKLogitsWarper(k: 4)([0, 1, 2, 3, 4, 5], [2.0, 1.0, 3.0, -1.0, 123.0, 0.0])
58-
XCTAssertEqual(result5.indexes, [4, 2, 0, 1])
58+
XCTAssertEqual(result5.indices, [4, 2, 0, 1])
5959
XCTAssertEqual(result5.logits, [123.0, 3.0, 2.0, 1.0], accuracy: accuracy)
6060

6161
let result6 = TopKLogitsWarper(k: 3)([10, 1, 52], [2.0, 1.0, 3.0])
62-
XCTAssertEqual(result6.indexes, [52, 10, 1])
62+
XCTAssertEqual(result6.indices, [52, 10, 1])
6363
XCTAssertEqual(result6.logits, [3.0, 2.0, 1.0], accuracy: accuracy)
6464
}
6565

6666
func testTopPLogitsWarper() {
6767
let result1 = TopPLogitsWarper(p: 0.99)([], [])
68-
XCTAssertTrue(result1.indexes.isEmpty)
68+
XCTAssertTrue(result1.indices.isEmpty)
6969
XCTAssertTrue(result1.logits.isEmpty)
7070

7171
let logits = (0 ..< 10).map { Float($0) }
72-
let indexes = Array(logits.indices)
73-
let result2 = TopPLogitsWarper(p: 0.99)(indexes, logits)
74-
XCTAssertEqual(result2.indexes, [9, 8, 7, 6, 5])
72+
let indices = Array(logits.indices)
73+
let result2 = TopPLogitsWarper(p: 0.99)(indices, logits)
74+
XCTAssertEqual(result2.indices, [9, 8, 7, 6, 5])
7575
XCTAssertEqual(result2.logits, [9.0, 8.0, 7.0, 6.0, 5.0], accuracy: accuracy)
7676

77-
let result3 = TopPLogitsWarper(p: 0.95)(indexes, logits)
78-
XCTAssertEqual(result3.indexes, [9, 8, 7])
77+
let result3 = TopPLogitsWarper(p: 0.95)(indices, logits)
78+
XCTAssertEqual(result3.indices, [9, 8, 7])
7979
XCTAssertEqual(result3.logits, [9.0, 8.0, 7.0], accuracy: accuracy)
8080

81-
let result4 = TopPLogitsWarper(p: 0.6321493)(indexes, logits)
82-
XCTAssertEqual(result4.indexes, [9, 8])
81+
let result4 = TopPLogitsWarper(p: 0.6321493)(indices, logits)
82+
XCTAssertEqual(result4.indices, [9, 8])
8383
XCTAssertEqual(result4.logits, [9.0, 8.0], accuracy: accuracy)
8484

8585
let result5 = TopPLogitsWarper(p: 0.95)([3, 1, 8], [0, 1, 2])
86-
XCTAssertEqual(result5.indexes, [8, 1, 3])
86+
XCTAssertEqual(result5.indices, [8, 1, 3])
8787
XCTAssertEqual(result5.logits, [2, 1, 0], accuracy: accuracy)
8888
}
8989

9090
func testLogitsProcessor() {
9191
let processor1 = LogitsProcessor(logitsWarpers: [])
9292
let result1 = processor1([])
93-
XCTAssertTrue(result1.indexes.isEmpty)
93+
XCTAssertTrue(result1.indices.isEmpty)
9494
XCTAssertTrue(result1.logits.isEmpty)
9595

9696
let processor2 = LogitsProcessor(logitsWarpers: [])
9797
let result2 = processor2([2.0, 1.0])
98-
XCTAssertEqual(result2.indexes, [0, 1])
98+
XCTAssertEqual(result2.indices, [0, 1])
9999
XCTAssertEqual(result2.logits, [2.0, 1.0], accuracy: accuracy)
100100

101101
let processor3 = LogitsProcessor(
102102
logitsWarpers: [TopKLogitsWarper(k: 3)]
103103
)
104104
let result3 = processor3([2.0, 1.0, 3.0, -5.0])
105-
XCTAssertEqual(result3.indexes, [2, 0, 1])
105+
XCTAssertEqual(result3.indices, [2, 0, 1])
106106
XCTAssertEqual(result3.logits, [3.0, 2.0, 1.0], accuracy: accuracy)
107107

108108
let processor4 = LogitsProcessor(
109109
logitsWarpers: [TopKLogitsWarper(k: 3), TopPLogitsWarper(p: 0.99)]
110110
)
111111
let result4 = processor4([2.0, 1.0, 3.0, -5.0, -23.0, 12.5])
112-
XCTAssertEqual(result4.indexes, [5])
112+
XCTAssertEqual(result4.indices, [5])
113113
XCTAssertEqual(result4.logits, [12.5], accuracy: accuracy)
114114

115115
let processor5 = LogitsProcessor(
116116
logitsWarpers: [TopKLogitsWarper(k: 4), TopPLogitsWarper(p: 0.99)]
117117
)
118118
let result5 = processor5([2.0, 1.0, 3.0, -5.0, -3.0, 4.5])
119-
XCTAssertEqual(result5.indexes, [5, 2, 0, 1])
119+
XCTAssertEqual(result5.indices, [5, 2, 0, 1])
120120
XCTAssertEqual(result5.logits, [4.5, 3.0, 2.0, 1.0], accuracy: accuracy)
121121
}
122122
}

0 commit comments

Comments
 (0)