Skip to content

Commit f1a5b14

Browse files
committed
Updated
1 parent 850f8a2 commit f1a5b14

File tree

3 files changed

+41
-11
lines changed

3 files changed

+41
-11
lines changed

pkg/whisper/task/context.go

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package task
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"path/filepath"
@@ -20,7 +21,9 @@ import (
2021
type Context struct {
2122
model string
2223
whisper *whisper.Context
23-
params whisper.FullParams
24+
25+
// Parameters for the next transcription
26+
params whisper.FullParams
2427
}
2528

2629
//////////////////////////////////////////////////////////////////////////////
@@ -60,8 +63,6 @@ func (m *Context) Init(path string, model *model.Model, gpu int) error {
6063
// Set resources
6164
m.whisper = ctx
6265
m.model = model.Id
63-
m.params = whisper.DefaultFullParams(whisper.SAMPLING_GREEDY)
64-
m.params.SetLanguage("auto")
6566

6667
// Return success
6768
return nil
@@ -125,13 +126,39 @@ func (ctx *Context) Is(model *model.Model) bool {
125126
return ctx.model == model.Id
126127
}
127128

129+
// Copy task parameters from the default
130+
func (task *Context) CopyParams() {
131+
task.params = whisper.DefaultFullParams(whisper.SAMPLING_GREEDY)
132+
task.params.SetLanguage("auto")
133+
}
134+
128135
// Transcribe samples. The samples should be 16KHz float32 samples in
129136
// a single channel.
130137
// TODO: We need a low-latency streaming version of this function.
131138
// TODO: We need a callback for segment progress.
132-
func (ctx *Context) Transcribe(samples []float32) error {
139+
func (task *Context) Transcribe(ctx context.Context, samples []float32) error {
140+
// Set the 'abort' function
141+
task.params.SetAbortCallback(task.whisper, func() bool {
142+
select {
143+
case <-ctx.Done():
144+
return true
145+
default:
146+
return false
147+
}
148+
})
149+
150+
// Set the 'progress' function
151+
task.params.SetProgressCallback(task.whisper, func(percent int) {
152+
fmt.Printf("Progress: %v\n", percent)
153+
})
154+
133155
// Perform the transcription
134-
return whisper.Whisper_full(ctx.whisper, ctx.params, samples)
156+
if err := whisper.Whisper_full(task.whisper, task.params, samples); err != nil {
157+
return err
158+
}
159+
160+
// Return success
161+
return nil
135162
}
136163

137164
// Set the language. For transcription, this is the language of the

pkg/whisper/whisper.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ func (w *Whisper) WithModel(model *model.Model, fn func(task *task.Context) erro
175175
}
176176
defer w.pool.Put(task)
177177

178+
// Copy parameters
179+
task.CopyParams()
180+
178181
// Execute the function
179182
return fn(task)
180183
}

pkg/whisper/whisper_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ func Test_whisper_005(t *testing.T) {
225225

226226
assert.NoError(service.WithModel(model, func(task *task.Context) error {
227227
t.Log("Transcribing", len(samples), "samples")
228-
return task.Transcribe(samples)
228+
return task.Transcribe(context.Background(), samples)
229229
}))
230230
})
231231

@@ -241,7 +241,7 @@ func Test_whisper_005(t *testing.T) {
241241

242242
assert.NoError(service.WithModel(model, func(task *task.Context) error {
243243
t.Log("Transcribing", len(samples), "samples")
244-
return task.Transcribe(samples)
244+
return task.Transcribe(context.Background(), samples)
245245
}))
246246
})
247247

@@ -257,7 +257,7 @@ func Test_whisper_005(t *testing.T) {
257257

258258
assert.NoError(service.WithModel(model, func(task *task.Context) error {
259259
t.Log("Transcribing", len(samples), "samples")
260-
return task.Transcribe(samples)
260+
return task.Transcribe(context.Background(), samples)
261261
}))
262262
})
263263
}
@@ -294,7 +294,7 @@ func Test_whisper_006(t *testing.T) {
294294

295295
assert.NoError(service.WithModel(model, func(task *task.Context) error {
296296
t.Log("Transcribing", len(samples), "samples")
297-
return task.Transcribe(samples)
297+
return task.Transcribe(context.Background(), samples)
298298
}))
299299
})
300300

@@ -312,7 +312,7 @@ func Test_whisper_006(t *testing.T) {
312312

313313
assert.NoError(service.WithModel(model, func(task *task.Context) error {
314314
t.Log("Transcribing", len(samples), "samples")
315-
return task.Transcribe(samples)
315+
return task.Transcribe(context.Background(), samples)
316316
}))
317317
})
318318

@@ -330,7 +330,7 @@ func Test_whisper_006(t *testing.T) {
330330

331331
assert.NoError(service.WithModel(model, func(task *task.Context) error {
332332
t.Log("Transcribing", len(samples), "samples")
333-
return task.Transcribe(samples)
333+
return task.Transcribe(context.Background(), samples)
334334
}))
335335
})
336336
})

0 commit comments

Comments
 (0)