Skip to content

Commit c754d14

Browse files
authored
fixed logits warper (#49)
1 parent 15bc01e commit c754d14

File tree

6 files changed

+59
-39
lines changed

6 files changed

+59
-39
lines changed

Sources/TensorUtils/LogitsWarper/LogitsProcessor.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public struct LogitsProcessor {
1111
var indexes = Array(arr.indices)
1212
var logits = arr
1313
for warper in logitsWarpers {
14-
(indexes, logits) = warper(logits)
14+
(indexes, logits) = warper(indexes, logits)
1515
}
1616
return (indexes: indexes, logits: logits)
1717
}

Sources/TensorUtils/LogitsWarper/LogitsWarper.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ import Foundation
22

33
/// Protocol for all logit warpers that can be applied during generation
44
public protocol LogitsWarper {
5-
func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float])
6-
func callAsFunction(_ arr: [Float]) -> (indexes: [Int], logits: [Float])
5+
func warp(indexes: [Int], logits: [Float]) -> (indexes: [Int], logits: [Float])
6+
func callAsFunction(_ indexes: [Int], _ logits: [Float]) -> (indexes: [Int], logits: [Float])
77
}
88

99
extension LogitsWarper {
10-
public func callAsFunction(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
11-
warp(arr)
10+
public func callAsFunction(_ indexes: [Int], _ logits: [Float]) -> (indexes: [Int], logits: [Float]) {
11+
warp(indexes: indexes, logits: logits)
1212
}
1313
}

Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ public struct TemperatureLogitsWarper: LogitsWarper {
77
self.temperature = temperature
88
}
99

10-
public func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
11-
let logits = arr.map { $0 / temperature }
12-
return (indexes: Array(logits.indices), logits: logits)
10+
public func warp(indexes: [Int], logits: [Float]) -> (indexes: [Int], logits: [Float]) {
11+
return (indexes: indexes, logits: logits.map { $0 / temperature })
1312
}
1413
}

Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ public struct TopKLogitsWarper: LogitsWarper {
1212
self.k = k
1313
}
1414

15-
public func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
16-
guard !arr.isEmpty else {
15+
public func warp(indexes: [Int], logits: [Float]) -> (indexes: [Int], logits: [Float]) {
16+
guard !logits.isEmpty else {
1717
return (indexes: [], logits: [])
1818
}
19-
let k = min(k, arr.count)
19+
let k = min(k, logits.count)
2020
let arrDescriptor = BNNSNDArrayDescriptor.allocate(
21-
initializingFrom: arr,
22-
shape: .vector(arr.count)
21+
initializingFrom: logits,
22+
shape: .vector(logits.count)
2323
)
2424
defer {
2525
arrDescriptor.deallocate()
@@ -47,12 +47,12 @@ public struct TopKLogitsWarper: LogitsWarper {
4747
batchSize: 1,
4848
filterParameters: nil
4949
)
50-
let distances = bestValues.data!.withMemoryRebound(to: Float.self, capacity: k) { ptr in
50+
let topkLogits = bestValues.data!.withMemoryRebound(to: Float.self, capacity: k) { ptr in
5151
Array(UnsafeBufferPointer(start: ptr, count: k))
5252
}
53-
let indices = bestIndices.data!.withMemoryRebound(to: Int32.self, capacity: k) { ptr in
53+
let topkIndexes = bestIndices.data!.withMemoryRebound(to: Int32.self, capacity: k) { ptr in
5454
Array(UnsafeBufferPointer(start: ptr, count: k))
5555
}
56-
return (indexes: indices.map { Int($0) }, logits: distances)
56+
return (indexes: topkIndexes.map { indexes[Int($0)] }, logits: topkLogits)
5757
}
5858
}

Sources/TensorUtils/LogitsWarper/TopPLogitsWarper.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ public struct TopPLogitsWarper: LogitsWarper {
1010
self.p = p
1111
}
1212

13-
public func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
14-
guard !arr.isEmpty else {
13+
public func warp(indexes: [Int], logits: [Float]) -> (indexes: [Int], logits: [Float]) {
14+
guard !logits.isEmpty else {
1515
return (indexes: [], logits: [])
1616
}
1717

18-
let arrSoftmax = Math.softmax(arr)
18+
let arrSoftmax = Math.softmax(logits)
1919
var indexLogitProb = [(index: Int, logit: Float, prob: Float)]()
20-
indexLogitProb.reserveCapacity(arr.count)
21-
for (index, data) in zip(arr, arrSoftmax).enumerated() {
20+
indexLogitProb.reserveCapacity(logits.count)
21+
for (index, data) in zip(logits, arrSoftmax).enumerated() {
2222
indexLogitProb.append((index: index, logit: data.0, prob: data.1))
2323
}
2424
indexLogitProb.sort { $0.prob > $1.prob }
@@ -30,8 +30,8 @@ public struct TopPLogitsWarper: LogitsWarper {
3030
break
3131
}
3232

33-
let indexes = indexLogitProb[0 ... sliceIndex].map(\.index)
34-
let logits = indexLogitProb[0 ... sliceIndex].map(\.logit)
35-
return (indexes: indexes, logits: logits)
33+
let toppIndexes = indexLogitProb[0 ... sliceIndex].map { indexes[$0.index] }
34+
let toppLogits = indexLogitProb[0 ... sliceIndex].map(\.logit)
35+
return (indexes: toppIndexes, logits: toppLogits)
3636
}
3737
}

Tests/TensorUtilsTests/LogitsWarperTests.swift

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,65 +12,79 @@ final class LogitsWarperTests: XCTestCase {
1212
private let accuracy: Float = 0.00001
1313

1414
func testTemperatureLogitsWarper() {
15-
let result1 = TemperatureLogitsWarper(temperature: 0.0)([])
15+
let result1 = TemperatureLogitsWarper(temperature: 0.0)([], [])
1616
XCTAssertTrue(result1.indexes.isEmpty)
1717
XCTAssertTrue(result1.logits.isEmpty)
1818

19-
let result2 = TemperatureLogitsWarper(temperature: 1.0)([])
19+
let result2 = TemperatureLogitsWarper(temperature: 1.0)([], [])
2020
XCTAssertTrue(result2.indexes.isEmpty)
2121
XCTAssertTrue(result2.logits.isEmpty)
2222

23-
let result3 = TemperatureLogitsWarper(temperature: 1.0)([2.0, 1.0])
23+
let result3 = TemperatureLogitsWarper(temperature: 1.0)([0, 1], [2.0, 1.0])
2424
XCTAssertEqual(result3.indexes, [0, 1])
2525
XCTAssertEqual(result3.logits, [2.0, 1.0], accuracy: accuracy)
2626

27-
let result4 = TemperatureLogitsWarper(temperature: 2.0)([2.0, 1.0])
27+
let result4 = TemperatureLogitsWarper(temperature: 2.0)([0, 1], [2.0, 1.0])
2828
XCTAssertEqual(result4.indexes, [0, 1])
2929
XCTAssertEqual(result4.logits, [1.0, 0.5], accuracy: accuracy)
3030

31-
let result5 = TemperatureLogitsWarper(temperature: 0.5)([2.0, 1.0])
31+
let result5 = TemperatureLogitsWarper(temperature: 0.5)([0, 1], [2.0, 1.0])
3232
XCTAssertEqual(result5.indexes, [0, 1])
3333
XCTAssertEqual(result5.logits, [4.0, 2.0], accuracy: accuracy)
34+
35+
let result6 = TemperatureLogitsWarper(temperature: 0.5)([200, 100], [2.0, 1.0])
36+
XCTAssertEqual(result6.indexes, [200, 100])
37+
XCTAssertEqual(result6.logits, [4.0, 2.0], accuracy: accuracy)
3438
}
3539

3640
func testTopKLogitsWarper() {
37-
let result1 = TopKLogitsWarper(k: 0)([])
41+
let result1 = TopKLogitsWarper(k: 0)([], [])
3842
XCTAssertTrue(result1.indexes.isEmpty)
3943
XCTAssertTrue(result1.logits.isEmpty)
4044

41-
let result2 = TopKLogitsWarper(k: 3)([])
45+
let result2 = TopKLogitsWarper(k: 3)([], [])
4246
XCTAssertTrue(result2.indexes.isEmpty)
4347
XCTAssertTrue(result2.logits.isEmpty)
4448

45-
let result3 = TopKLogitsWarper(k: 3)([2.0, 1.0])
49+
let result3 = TopKLogitsWarper(k: 3)([0, 1], [2.0, 1.0])
4650
XCTAssertEqual(result3.indexes, [0, 1])
4751
XCTAssertEqual(result3.logits, [2.0, 1.0], accuracy: accuracy)
4852

49-
let result4 = TopKLogitsWarper(k: 3)([2.0, 1.0, 3.0])
53+
let result4 = TopKLogitsWarper(k: 3)([0, 1, 2], [2.0, 1.0, 3.0])
5054
XCTAssertEqual(result4.indexes, [2, 0, 1])
5155
XCTAssertEqual(result4.logits, [3.0, 2.0, 1.0], accuracy: accuracy)
5256

53-
let result5 = TopKLogitsWarper(k: 4)([2.0, 1.0, 3.0, -1.0, 123.0, 0.0])
57+
let result5 = TopKLogitsWarper(k: 4)([0, 1, 2, 3, 4, 5], [2.0, 1.0, 3.0, -1.0, 123.0, 0.0])
5458
XCTAssertEqual(result5.indexes, [4, 2, 0, 1])
5559
XCTAssertEqual(result5.logits, [123.0, 3.0, 2.0, 1.0], accuracy: accuracy)
60+
61+
let result6 = TopKLogitsWarper(k: 3)([10, 1, 52], [2.0, 1.0, 3.0])
62+
XCTAssertEqual(result6.indexes, [52, 10, 1])
63+
XCTAssertEqual(result6.logits, [3.0, 2.0, 1.0], accuracy: accuracy)
5664
}
5765

5866
func testTopPLogitsWarper() {
59-
let result1 = TopPLogitsWarper(p: 0.99)([])
67+
let result1 = TopPLogitsWarper(p: 0.99)([], [])
6068
XCTAssertTrue(result1.indexes.isEmpty)
6169
XCTAssertTrue(result1.logits.isEmpty)
6270

63-
let result2 = TopPLogitsWarper(p: 0.99)((0 ..< 10).map { Float($0) })
71+
let logits = (0 ..< 10).map { Float($0) }
72+
let indexes = Array(logits.indices)
73+
let result2 = TopPLogitsWarper(p: 0.99)(indexes, logits)
6474
XCTAssertEqual(result2.indexes, [9, 8, 7, 6, 5])
6575
XCTAssertEqual(result2.logits, [9.0, 8.0, 7.0, 6.0, 5.0], accuracy: accuracy)
6676

67-
let result3 = TopPLogitsWarper(p: 0.95)((0 ..< 10).map { Float($0) })
77+
let result3 = TopPLogitsWarper(p: 0.95)(indexes, logits)
6878
XCTAssertEqual(result3.indexes, [9, 8, 7])
6979
XCTAssertEqual(result3.logits, [9.0, 8.0, 7.0], accuracy: accuracy)
7080

71-
let result4 = TopPLogitsWarper(p: 0.6321493)((0 ..< 10).map { Float($0) })
81+
let result4 = TopPLogitsWarper(p: 0.6321493)(indexes, logits)
7282
XCTAssertEqual(result4.indexes, [9, 8])
7383
XCTAssertEqual(result4.logits, [9.0, 8.0], accuracy: accuracy)
84+
85+
let result5 = TopPLogitsWarper(p: 0.95)([3, 1, 8], [0, 1, 2])
86+
XCTAssertEqual(result5.indexes, [8, 1, 3])
87+
XCTAssertEqual(result5.logits, [2, 1, 0], accuracy: accuracy)
7488
}
7589

7690
func testLogitsProcessor() {
@@ -95,7 +109,14 @@ final class LogitsWarperTests: XCTestCase {
95109
logitsWarpers: [TopKLogitsWarper(k: 3), TopPLogitsWarper(p: 0.99)]
96110
)
97111
let result4 = processor4([2.0, 1.0, 3.0, -5.0, -23.0, 12.5])
98-
XCTAssertEqual(result4.indexes, [0])
112+
XCTAssertEqual(result4.indexes, [5])
99113
XCTAssertEqual(result4.logits, [12.5], accuracy: accuracy)
114+
115+
let processor5 = LogitsProcessor(
116+
logitsWarpers: [TopKLogitsWarper(k: 4), TopPLogitsWarper(p: 0.99)]
117+
)
118+
let result5 = processor5([2.0, 1.0, 3.0, -5.0, -3.0, 4.5])
119+
XCTAssertEqual(result5.indexes, [5, 2, 0, 1])
120+
XCTAssertEqual(result5.logits, [4.5, 3.0, 2.0, 1.0], accuracy: accuracy)
100121
}
101122
}

0 commit comments

Comments
 (0)