Skip to content

Commit e66dc08

Browse files
committed
Updated go-whisper
1 parent c262d1a commit e66dc08

File tree

7 files changed

+218
-106
lines changed

7 files changed

+218
-106
lines changed

pkg/whisper/api/transcribe.go

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/mutablelogic/go-whisper/pkg/whisper"
1414
"github.com/mutablelogic/go-whisper/pkg/whisper/segmenter"
1515
"github.com/mutablelogic/go-whisper/pkg/whisper/task"
16+
"github.com/mutablelogic/go-whisper/pkg/whisper/transcription"
1617

1718
// Namespace imports
1819
. "github.com/djthorpe/go-errors"
@@ -25,11 +26,17 @@ type reqTranscribe struct {
2526
File *multipart.FileHeader `json:"file"`
2627
Model string `json:"model"`
2728
Language *string `json:"language"`
28-
Prompt *string `json:"prompt"`
29-
ResponseFmt *string `json:"response_format"`
3029
Temperature *float32 `json:"temperature"`
30+
SegmentSize *time.Duration `json:"segment_size"`
31+
ResponseFmt *string `json:"response_format"`
3132
}
3233

34+
const (
35+
minSegmentSize = 5 * time.Second
36+
maxSegmentSize = 10 * time.Minute
37+
defaultSegmentSize = 5 * time.Minute
38+
)
39+
3340
///////////////////////////////////////////////////////////////////////////////
3441
// PUBLIC METHODS
3542

@@ -61,14 +68,15 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
6168
}
6269
defer f.Close()
6370

64-
// Create a segmenter - read segments of 5 min samples
65-
segmenter, err := segmenter.New(f, 5*time.Minute, whisper.SampleRate)
71+
// Create a segmenter - read segments based on requested segment size
72+
segmenter, err := segmenter.New(f, req.SegmentDur(), whisper.SampleRate)
6673
if err != nil {
6774
httpresponse.Error(w, http.StatusBadRequest, err.Error())
6875
return
6976
}
7077

7178
// Get context for the model, perform transcription
79+
var result *transcription.Transcription
7280
if err := service.WithModel(model, func(task *task.Context) error {
7381
// Check model
7482
if translate && !task.CanTranslate() {
@@ -90,18 +98,29 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
9098
// TODO: Set temperature, etc
9199

92100
// Read samples and transcribe them
93-
return segmenter.Decode(ctx, func(ts time.Duration, buf []float32) error {
94-
fmt.Println("audio segment", ts, len(buf))
95-
101+
if err := segmenter.Decode(ctx, func(ts time.Duration, buf []float32) error {
96102
// Perform the transcription, return any errors
97-
return task.Transcribe(ctx, buf)
98-
})
103+
return task.Transcribe(ctx, ts, buf, func(segment *transcription.Segment) {
104+
fmt.Println("TODO: ", segment)
105+
})
106+
}); err != nil {
107+
return err
108+
}
109+
110+
// End of transcription, get result
111+
result = task.Result()
112+
113+
// Return success
114+
return nil
99115
}); err != nil {
100116
httpresponse.Error(w, http.StatusInternalServerError, err.Error())
101117
return
102118
}
103119

104-
var result whisper.Transcription
120+
// Set duration
121+
result.Duration = segmenter.Duration()
122+
123+
// Return transcription
105124
httpresponse.JSON(w, result, http.StatusOK, 2)
106125
}
107126

@@ -131,3 +150,16 @@ func (r reqTranscribe) ResponseFormat() string {
131150
}
132151
return *r.ResponseFmt
133152
}
153+
154+
func (r reqTranscribe) SegmentDur() time.Duration {
155+
if r.SegmentSize == nil {
156+
return defaultSegmentSize
157+
}
158+
if *r.SegmentSize < minSegmentSize {
159+
return minSegmentSize
160+
}
161+
if *r.SegmentSize > maxSegmentSize {
162+
return maxSegmentSize
163+
}
164+
return *r.SegmentSize
165+
}

pkg/whisper/model/store.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,9 @@ func listModels(path, ext string) ([]*Model, error) {
309309
}
310310

311311
func modelNameToId(name string) string {
312+
// Lowercase the name, remove the extension
313+
name = strings.TrimSuffix(strings.ToLower(name), filepath.Ext(name))
314+
312315
// We replace all non-alphanumeric characters with underscores
313316
return strings.Map(func(r rune) rune {
314317
if r >= 'a' && r <= 'z' {

pkg/whisper/segmenter/segmenter.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,12 @@ func (s *Segmenter) Decode(ctx context.Context, fn SegmentFunc) error {
128128
// Return success
129129
return nil
130130
}
131+
132+
// Return the duration from the file or timestamp
133+
func (s *Segmenter) Duration() time.Duration {
134+
if s.reader != nil {
135+
return s.reader.Duration()
136+
} else {
137+
return s.ts + time.Duration(len(s.buf))*time.Second/time.Duration(s.sample_rate)
138+
}
139+
}

pkg/whisper/task/context.go

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ import (
66
"fmt"
77
"path/filepath"
88
"sync"
9+
"time"
910

1011
// Packages
1112
model "github.com/mutablelogic/go-whisper/pkg/whisper/model"
13+
transcription "github.com/mutablelogic/go-whisper/pkg/whisper/transcription"
1214
whisper "github.com/mutablelogic/go-whisper/sys/whisper"
1315

1416
// Namespace imports
@@ -22,13 +24,20 @@ import (
2224
type Context struct {
2325
sync.Mutex
2426

27+
// Model Id and whisper context
2528
model string
2629
whisper *whisper.Context
2730

2831
// Parameters for the next transcription
2932
params whisper.FullParams
33+
34+
// Collect the transcription
35+
result *transcription.Transcription
3036
}
3137

38+
// Callback for new segments during the transcription process
39+
type NewSegmentFunc func(*transcription.Segment)
40+
3241
//////////////////////////////////////////////////////////////////////////////
3342
// LIFECYCLE
3443

@@ -84,7 +93,6 @@ func (ctx *Context) Close() error {
8493

8594
// Release resources
8695
if ctx.whisper != nil {
87-
fmt.Printf("Release model resources %v\n", ctx)
8896
whisper.Whisper_free(ctx.whisper)
8997
}
9098
ctx.whisper = nil
@@ -97,7 +105,7 @@ func (ctx *Context) Close() error {
97105
//////////////////////////////////////////////////////////////////////////////
98106
// STRINGIFY
99107

100-
func (ctx Context) MarshalJSON() ([]byte, error) {
108+
func (ctx *Context) MarshalJSON() ([]byte, error) {
101109
type j struct {
102110
Model string `json:"model"`
103111
Params whisper.FullParams `json:"params"`
@@ -110,7 +118,7 @@ func (ctx Context) MarshalJSON() ([]byte, error) {
110118
})
111119
}
112120

113-
func (ctx Context) String() string {
121+
func (ctx *Context) String() string {
114122
data, err := json.MarshalIndent(ctx, "", " ")
115123
if err != nil {
116124
return err.Error()
@@ -145,34 +153,47 @@ func (task *Context) CanTranslate() bool {
145153

146154
// Transcribe samples. The samples should be 16KHz float32 samples in
147155
// a single channel.
148-
// TODO: We need a low-latency streaming version of this function.
149-
// TODO: We need a callback for segment progress.
150-
func (task *Context) Transcribe(ctx context.Context, samples []float32) error {
156+
func (task *Context) Transcribe(ctx context.Context, ts time.Duration, samples []float32, fn NewSegmentFunc) error {
151157
// Set the 'abort' function
152-
/*task.params.SetAbortCallback(task.whisper, func() bool {
158+
task.params.SetAbortCallback(task.whisper, func() bool {
153159
select {
154160
case <-ctx.Done():
155161
return true
156162
default:
157163
return false
158164
}
159-
})*/
165+
})
166+
167+
// Set the new segment function
168+
if fn != nil {
169+
task.params.SetSegmentCallback(task.whisper, func(new_segments int) {
170+
num_segments := task.whisper.NumSegments()
171+
for i := num_segments - new_segments; i < num_segments; i++ {
172+
fn(transcription.NewSegment(ts, task.whisper.Segment(i)))
173+
}
174+
})
175+
}
160176

161-
// Set the 'progress' function
162-
//task.params.SetProgressCallback(task.whisper, func(percent int) {
163-
// fmt.Printf("Progress: %v\n", percent)
164-
//})
177+
// TODO: Set the initial prompt tokens from any previous transcription call
165178

166179
// Perform the transcription
167180
if err := whisper.Whisper_full(task.whisper, task.params, samples); err != nil {
168-
return err
181+
if ctx.Err() != nil {
182+
return ctx.Err()
183+
} else {
184+
return err
185+
}
169186
}
170187

171-
// Get segments
172-
for i := 0; i < task.whisper.NumSegments(); i++ {
173-
segment := task.whisper.Segment(i)
174-
fmt.Printf("Segment: %v\n", segment.Text)
188+
// Remove the callbacks
189+
task.params.SetAbortCallback(task.whisper, nil)
190+
task.params.SetSegmentCallback(task.whisper, nil)
191+
192+
// Append the transcription
193+
if task.result == nil {
194+
task.result = transcription.New()
175195
}
196+
task.result.Append(task.whisper, ts)
176197

177198
// Return success
178199
return nil
@@ -194,6 +215,12 @@ func (ctx *Context) SetLanguage(v string) error {
194215
return nil
195216
}
196217

218+
// Set translate to true or false
197219
func (ctx *Context) SetTranslate(v bool) {
198220
ctx.params.SetTranslate(v)
199221
}
222+
223+
// Return the transcription result
224+
func (ctx *Context) Result() *transcription.Transcription {
225+
return ctx.result
226+
}

pkg/whisper/transcription.go

Lines changed: 0 additions & 32 deletions
This file was deleted.
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package transcription
2+
3+
import (
4+
"encoding/json"
5+
"time"
6+
7+
"github.com/mutablelogic/go-whisper/sys/whisper"
8+
)
9+
10+
//////////////////////////////////////////////////////////////////////////////
11+
// TYPES
12+
13+
type Transcription struct {
14+
Task string `json:"task,omitempty"`
15+
Language string `json:"language,omitempty" writer:",width:8"`
16+
Duration time.Duration `json:"duration,omitempty" writer:",width:8,right"`
17+
Text string `json:"text" writer:",width:60,wrap"`
18+
Segments []Segment `json:"segments,omitempty" writer:",width:40,wrap"`
19+
}
20+
21+
type Segment struct {
22+
Id int32 `json:"id"`
23+
Start time.Duration `json:"start"`
24+
End time.Duration `json:"end"`
25+
Text string `json:"text"`
26+
SpeakerTurn bool `json:"speaker_turn,omitempty"`
27+
}
28+
29+
//////////////////////////////////////////////////////////////////////////////
30+
// LIFECYCLE
31+
32+
// Create a new transcription from a context
33+
func New() *Transcription {
34+
return new(Transcription)
35+
}
36+
37+
func NewSegment(ts time.Duration, seg *whisper.Segment) *Segment {
38+
// Dumb copy function
39+
return &Segment{
40+
Id: seg.Id,
41+
Text: seg.Text,
42+
Start: seg.T0 + ts,
43+
End: seg.T1 + ts,
44+
SpeakerTurn: seg.SpeakerTurn,
45+
}
46+
}
47+
48+
//////////////////////////////////////////////////////////////////////////////
49+
// STRINGIFY
50+
51+
func (t *Transcription) String() string {
52+
data, err := json.MarshalIndent(t, "", " ")
53+
if err != nil {
54+
return err.Error()
55+
}
56+
return string(data)
57+
}
58+
59+
func (s *Segment) String() string {
60+
data, err := json.MarshalIndent(s, "", " ")
61+
if err != nil {
62+
return err.Error()
63+
}
64+
return string(data)
65+
}
66+
67+
//////////////////////////////////////////////////////////////////////////////
68+
// PUBLIC METHODS
69+
70+
// Append a transcription to the current transcription, with the offset timestamp
71+
func (t *Transcription) Append(ctx *whisper.Context, ts time.Duration) {
72+
// Append the segment text
73+
for i := 0; i < ctx.NumSegments(); i++ {
74+
seg := ctx.Segment(i)
75+
if t.Text == "" {
76+
t.Text = seg.Text
77+
} else {
78+
t.Text += " " + seg.Text
79+
}
80+
}
81+
}

0 commit comments

Comments
 (0)