Skip to content

Commit 92af8ea

Browse files
committed
Fix unmanaged self retain missing corrisponding release (caused memory leak)
1 parent d2231f4 commit 92af8ea

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

Sources/SwiftWhisper/Whisper.swift

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import whisper_cpp
33

44
public class Whisper {
55
private let whisperContext: OpaquePointer
6+
private var unmanagedSelf: Unmanaged<Whisper>?
67

78
public var delegate: WhisperDelegate?
89
public var params: WhisperParams
@@ -14,17 +15,13 @@ public class Whisper {
1415
public init(fromFileURL fileURL: URL, withParams params: WhisperParams = .default) {
1516
self.whisperContext = fileURL.relativePath.withCString { whisper_init_from_file($0) }
1617
self.params = params
17-
18-
prepareCallbacks()
1918
}
2019

2120
public init(fromData data: Data, withParams params: WhisperParams = .default) {
2221
var copy = data // Need to copy memory so we can gaurentee exclusive ownership over pointer
2322

2423
self.whisperContext = copy.withUnsafeMutableBytes { whisper_init_from_buffer($0.baseAddress!, data.count) }
2524
self.params = params
26-
27-
prepareCallbacks()
2825
}
2926

3027
deinit {
@@ -38,8 +35,11 @@ public class Whisper {
3835

3936
We can unwrap that and obtain a copy of self inside the callback.
4037
*/
41-
params.new_segment_callback_user_data = Unmanaged.passRetained(self).toOpaque()
42-
params.encoder_begin_callback_user_data = Unmanaged.passRetained(self).toOpaque()
38+
cleanupCallbacks()
39+
let unmanagedSelf = Unmanaged.passRetained(self)
40+
self.unmanagedSelf = unmanagedSelf
41+
params.new_segment_callback_user_data = unmanagedSelf.toOpaque()
42+
params.encoder_begin_callback_user_data = unmanagedSelf.toOpaque()
4343

4444
// swiftlint:disable line_length
4545
params.new_segment_callback = { (ctx: OpaquePointer?, _: OpaquePointer?, newSegmentCount: Int32, userData: UnsafeMutableRawPointer?) in
@@ -94,32 +94,45 @@ public class Whisper {
9494
}
9595
}
9696

97+
private func cleanupCallbacks() {
98+
guard let unmanagedSelf else { return }
99+
100+
unmanagedSelf.release()
101+
self.unmanagedSelf = nil
102+
}
103+
97104
public func transcribe(audioFrames: [Float], completionHandler: @escaping (Result<[Segment], Error>) -> Void) {
105+
prepareCallbacks()
106+
107+
let wrappedCompletionHandler: (Result<[Segment], Error>) -> Void = { result in
108+
self.cleanupCallbacks()
109+
completionHandler(result)
110+
}
111+
98112
guard !inProgress else {
99-
completionHandler(.failure(WhisperError.instanceBusy))
113+
wrappedCompletionHandler(.failure(WhisperError.instanceBusy))
100114
return
101115
}
102116
guard audioFrames.count > 0 else {
103-
completionHandler(.failure(WhisperError.invalidFrames))
117+
wrappedCompletionHandler(.failure(WhisperError.invalidFrames))
104118
return
105119
}
106120

107121
inProgress = true
108122
frameCount = audioFrames.count
109123

110-
DispatchQueue.global(qos: .userInitiated).async { [unowned self] in
111-
112-
whisper_full(whisperContext, params.whisperParams, audioFrames, Int32(audioFrames.count))
124+
DispatchQueue.global(qos: .userInitiated).async {
125+
whisper_full(self.whisperContext, self.params.whisperParams, audioFrames, Int32(audioFrames.count))
113126

114-
let segmentCount = whisper_full_n_segments(whisperContext)
127+
let segmentCount = whisper_full_n_segments(self.whisperContext)
115128

116129
var segments: [Segment] = []
117130
segments.reserveCapacity(Int(segmentCount))
118131

119132
for index in 0..<segmentCount {
120-
guard let text = whisper_full_get_segment_text(whisperContext, index) else { continue }
121-
let startTime = whisper_full_get_segment_t0(whisperContext, index)
122-
let endTime = whisper_full_get_segment_t1(whisperContext, index)
133+
guard let text = whisper_full_get_segment_text(self.whisperContext, index) else { continue }
134+
let startTime = whisper_full_get_segment_t0(self.whisperContext, index)
135+
let endTime = whisper_full_get_segment_t1(self.whisperContext, index)
123136

124137
segments.append(
125138
.init(
@@ -130,20 +143,20 @@ public class Whisper {
130143
)
131144
}
132145

133-
if let cancelCallback {
146+
if let cancelCallback = self.cancelCallback {
134147
DispatchQueue.main.async {
135148
// Should cancel callback be called after delegate and completionHandler?
136149
cancelCallback()
137150

138151
let error = WhisperError.cancelled
139152

140153
self.delegate?.whisper(self, didErrorWith: error)
141-
completionHandler(.failure(error))
154+
wrappedCompletionHandler(.failure(error))
142155
}
143156
} else {
144157
DispatchQueue.main.async {
145158
self.delegate?.whisper(self, didCompleteWithSegments: segments)
146-
completionHandler(.success(segments))
159+
wrappedCompletionHandler(.success(segments))
147160
}
148161
}
149162

0 commit comments

Comments
 (0)