Skip to content

Commit b10d10c

Browse files
committed
Updated whisper
1 parent 0aa5997 commit b10d10c

File tree

11 files changed

+103
-58
lines changed

11 files changed

+103
-58
lines changed

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ require (
88
github.com/djthorpe/go-tablewriter v0.0.8
99
github.com/go-audio/wav v1.1.0
1010
github.com/mutablelogic/go-client v1.0.9
11-
github.com/mutablelogic/go-media v1.6.8
12-
github.com/mutablelogic/go-server v1.4.13
11+
github.com/mutablelogic/go-media v1.6.9
12+
github.com/mutablelogic/go-server v1.4.14
1313
github.com/stretchr/testify v1.9.0
1414
)
1515

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ github.com/mutablelogic/go-client v1.0.9 h1:Eh4sjQOFDldP/L3IizqkcOD3WigZR+u1VaHT
2828
github.com/mutablelogic/go-client v1.0.9/go.mod h1:VLyB8j8IBJSK/FXvvqhmq93PRWDKkyLu8R7V2Vudb6A=
2929
github.com/mutablelogic/go-media v1.6.8 h1:3v4povSQlOnvg9mHx6Bp9NVdCCjrNdDCjMHBGFHnVE8=
3030
github.com/mutablelogic/go-media v1.6.8/go.mod h1:HulNT0yyH63a3FRlbuzNDakhOypYrmtFVkHEXZjDgAY=
31-
github.com/mutablelogic/go-server v1.4.13 h1:k5LJJ/pCvyiw34UX341vRhliBOS6i7V65U/UICcOJOA=
32-
github.com/mutablelogic/go-server v1.4.13/go.mod h1:9nenPAohKu8bFoRgwHJh+3s8h0kLFjUAb8KZvT1TQNU=
31+
github.com/mutablelogic/go-server v1.4.14 h1:MsYyS9MjBoYtWfJo/iw6DnZ8slnhakWhPPqVCuzuaV8=
32+
github.com/mutablelogic/go-server v1.4.14/go.mod h1:9nenPAohKu8bFoRgwHJh+3s8h0kLFjUAb8KZvT1TQNU=
3333
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
3434
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
3535
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=

pkg/whisper/api/models.go

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package api
22

33
import (
44
"context"
5-
"encoding/json"
65
"errors"
76
"fmt"
87
"net/http"
@@ -49,15 +48,13 @@ func ListModels(ctx context.Context, w http.ResponseWriter, service *whisper.Whi
4948
}
5049

5150
func DownloadModel(ctx context.Context, w http.ResponseWriter, r *http.Request, service *whisper.Whisper) {
52-
// Get query
51+
// Get query and body
5352
var query queryDownloadModel
53+
var req reqDownloadModel
5454
if err := httprequest.Query(&query, r.URL.Query()); err != nil {
5555
httpresponse.Error(w, http.StatusBadRequest, err.Error())
5656
return
5757
}
58-
59-
// Get request body
60-
var req reqDownloadModel
6158
if err := httprequest.Body(&req, r); err != nil {
6259
httpresponse.Error(w, http.StatusBadRequest, err.Error())
6360
return
@@ -69,34 +66,31 @@ func DownloadModel(ctx context.Context, w http.ResponseWriter, r *http.Request,
6966
return
7067
}
7168

72-
// If we're streaming, then set response to streaming
69+
// Create a text stream
70+
var stream *httpresponse.TextStream
7371
if query.Stream {
74-
httpresponse.JSON(w, respDownloadModelStatus{
75-
Status: fmt.Sprint("downloading ", req.Name()),
76-
}, http.StatusProcessing, 0)
72+
if stream = httpresponse.NewTextStream(w); stream == nil {
73+
httpresponse.Error(w, http.StatusInternalServerError, "Cannot create text stream")
74+
return
75+
}
76+
defer stream.Close()
7777
}
7878

7979
// Download the model
8080
t := time.Now()
8181
model, err := service.DownloadModel(ctx, req.Name(), func(curBytes, totalBytes uint64) {
82-
if time.Since(t) > time.Second && query.Stream {
82+
if time.Since(t) > time.Second && stream != nil {
8383
t = time.Now()
84-
json.NewEncoder(w).Encode(respDownloadModelStatus{
84+
stream.Write("progress", respDownloadModelStatus{
8585
Status: fmt.Sprint("downloading ", req.Name()),
8686
Total: totalBytes,
8787
Completed: curBytes,
8888
})
89-
// Flush the response
90-
if f, ok := w.(http.Flusher); ok {
91-
f.Flush()
92-
}
9389
}
9490
})
9591
if err != nil {
96-
if query.Stream {
97-
json.NewEncoder(w).Encode(respDownloadModelStatus{
98-
Status: fmt.Sprint("error ", err.Error()),
99-
})
92+
if stream != nil {
93+
stream.Write("error", err.Error())
10094
} else {
10195
httpresponse.Error(w, http.StatusBadGateway, err.Error())
10296
}
@@ -105,7 +99,7 @@ func DownloadModel(ctx context.Context, w http.ResponseWriter, r *http.Request,
10599

106100
// Return the model information
107101
if query.Stream {
108-
json.NewEncoder(w).Encode(model)
102+
stream.Write("ok", model)
109103
} else {
110104
httpresponse.JSON(w, model, http.StatusCreated, 2)
111105
}

pkg/whisper/api/transcribe.go

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ type reqTranscribe struct {
3131
ResponseFmt *string `json:"response_format"`
3232
}
3333

34+
type queryTranscribe struct {
35+
Stream bool `json:"stream"`
36+
}
37+
3438
const (
3539
minSegmentSize = 5 * time.Second
3640
maxSegmentSize = 10 * time.Minute
@@ -42,6 +46,11 @@ const (
4246

4347
func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.ResponseWriter, r *http.Request, translate bool) {
4448
var req reqTranscribe
49+
var query queryTranscribe
50+
if err := httprequest.Query(&query, r.URL.Query()); err != nil {
51+
httpresponse.Error(w, http.StatusBadRequest, err.Error())
52+
return
53+
}
4554
if err := httprequest.Body(&req, r); err != nil {
4655
httpresponse.Error(w, http.StatusBadRequest, err.Error())
4756
return
@@ -75,6 +84,16 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
7584
return
7685
}
7786

87+
// Create a text stream
88+
var stream *httpresponse.TextStream
89+
if query.Stream {
90+
if stream = httpresponse.NewTextStream(w); stream == nil {
91+
httpresponse.Error(w, http.StatusInternalServerError, "Cannot create text stream")
92+
return
93+
}
94+
defer stream.Close()
95+
}
96+
7897
// Get context for the model, perform transcription
7998
var result *schema.Transcription
8099
if err := service.WithModel(model, func(task *task.Context) error {
@@ -97,35 +116,52 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
97116

98117
// TODO: Set temperature, etc
99118

119+
// Create response
120+
result = task.Result()
121+
result.Task = "transcribe"
122+
if translate {
123+
result.Task = "translate"
124+
}
125+
result.Duration = schema.Timestamp(segmenter.Duration())
126+
result.Language = task.Language()
127+
128+
// Output the header
129+
if stream != nil {
130+
stream.Write("task", result)
131+
}
132+
100133
// Read samples and transcribe them
101134
if err := segmenter.Decode(ctx, func(ts time.Duration, buf []float32) error {
102135
// Perform the transcription, return any errors
103-
return task.Transcribe(ctx, ts, buf, req.OutputSegments(), func(segment *schema.Segment) {
104-
fmt.Println("TODO: ", segment)
136+
return task.Transcribe(ctx, ts, buf, req.OutputSegments() || stream != nil, func(segment *schema.Segment) {
137+
if stream != nil {
138+
stream.Write("segment", segment)
139+
}
105140
})
106141
}); err != nil {
107142
return err
108143
}
109144

110-
// End of transcription, get result
111-
result = task.Result()
145+
// Set the language
146+
result.Language = task.Language()
112147

113148
// Return success
114149
return nil
115150
}); err != nil {
116-
httpresponse.Error(w, http.StatusInternalServerError, err.Error())
151+
if stream != nil {
152+
stream.Write("error", err.Error())
153+
} else {
154+
httpresponse.Error(w, http.StatusInternalServerError, err.Error())
155+
}
117156
return
118157
}
119158

120-
// Set task, duration
121-
result.Task = "transcribe"
122-
if translate {
123-
result.Task = "translate"
159+
// Return transcription if not streaming
160+
if stream == nil {
161+
httpresponse.JSON(w, result, http.StatusOK, 2)
162+
} else {
163+
stream.Write("ok")
124164
}
125-
result.Duration = segmenter.Duration()
126-
127-
// Return transcription
128-
httpresponse.JSON(w, result, http.StatusOK, 2)
129165
}
130166

131167
///////////////////////////////////////////////////////////////////////////////

pkg/whisper/model/store.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,12 @@ func (s *Store) Rescan() error {
120120
}
121121

122122
// Return a model by its Id
123-
func (s *Store) ById(name string) *schema.Model {
123+
func (s *Store) ById(id string) *schema.Model {
124124
s.RLock()
125125
defer s.RUnlock()
126-
name = modelNameToId(name)
126+
127127
for _, model := range s.models {
128-
if model.Id == name {
128+
if model.Id == id {
129129
return model
130130
}
131131
}

pkg/whisper/schema/segment.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,17 @@ package schema
22

33
import (
44
"encoding/json"
5-
"time"
65
)
76

87
//////////////////////////////////////////////////////////////////////////////
98
// TYPES
109

1110
type Segment struct {
12-
Id int32 `json:"id"`
13-
Start time.Duration `json:"start"`
14-
End time.Duration `json:"end"`
15-
Text string `json:"text"`
16-
SpeakerTurn bool `json:"speaker_turn,omitempty"` // TODO
11+
Id int32 `json:"id"`
12+
Start Timestamp `json:"start"`
13+
End Timestamp `json:"end"`
14+
Text string `json:"text"`
15+
SpeakerTurn bool `json:"speaker_turn,omitempty"` // TODO
1716
}
1817

1918
//////////////////////////////////////////////////////////////////////////////

pkg/whisper/schema/transcription.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ import (
88
//////////////////////////////////////////////////////////////////////////////
99
// TYPES
1010

11+
type Timestamp time.Duration
12+
1113
type Transcription struct {
12-
Task string `json:"task,omitempty"`
13-
Language string `json:"language,omitempty" writer:",width:8"`
14-
Duration time.Duration `json:"duration,omitempty" writer:",width:8,right"`
15-
Text string `json:"text" writer:",width:60,wrap"`
16-
Segments []*Segment `json:"segments,omitempty" writer:",width:40,wrap"`
14+
Task string `json:"task,omitempty"`
15+
Language string `json:"language,omitempty" writer:",width:8"`
16+
Duration Timestamp `json:"duration,omitempty" writer:",width:8,right"`
17+
Text string `json:"text,omitempty" writer:",width:60,wrap"`
18+
Segments []*Segment `json:"segments,omitempty" writer:",width:40,wrap"`
1719
}
1820

1921
//////////////////////////////////////////////////////////////////////////////
@@ -26,3 +28,8 @@ func (t *Transcription) String() string {
2628
}
2729
return string(data)
2830
}
31+
32+
func (t Timestamp) MarshalJSON() ([]byte, error) {
33+
// We convert durations into float64 seconds
34+
return json.Marshal(time.Duration(t).Seconds())
35+
}

pkg/whisper/task/context.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ func (ctx *Context) Is(model *schema.Model) bool {
143143
func (task *Context) CopyParams() {
144144
task.params = whisper.DefaultFullParams(whisper.SAMPLING_GREEDY)
145145
task.params.SetLanguage("auto")
146-
task.result = nil
146+
task.result = new(schema.Transcription)
147147
}
148148

149149
// Model is multilingual and can translate
@@ -191,9 +191,6 @@ func (task *Context) Transcribe(ctx context.Context, ts time.Duration, samples [
191191
task.params.SetSegmentCallback(task.whisper, nil)
192192

193193
// Append the transcription
194-
if task.result == nil {
195-
task.result = new(schema.Transcription)
196-
}
197194
task.appendResult(ts, segments)
198195

199196
// Return success
@@ -216,6 +213,10 @@ func (ctx *Context) SetLanguage(v string) error {
216213
return nil
217214
}
218215

216+
func (ctx *Context) Language() string {
217+
return ctx.params.Language()
218+
}
219+
219220
// Set translate to true or false
220221
func (ctx *Context) SetTranslate(v bool) {
221222
ctx.params.SetTranslate(v)

pkg/whisper/task/transcription.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ func newSegment(ts time.Duration, seg *whisper.Segment) *schema.Segment {
1616
return &schema.Segment{
1717
Id: seg.Id,
1818
Text: seg.Text,
19-
Start: seg.T0 + ts,
20-
End: seg.T1 + ts,
19+
Start: schema.Timestamp(seg.T0 + ts),
20+
End: schema.Timestamp(seg.T1 + ts),
2121
SpeakerTurn: seg.SpeakerTurn,
2222
}
2323
}

sys/whisper/fullparams.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ func (c *FullParams) SetLanguage(v string) {
253253
}
254254
}
255255

256+
func (c *FullParams) Language() string {
257+
v := C.GoString(c.language)
258+
if v == "" {
259+
return "auto"
260+
}
261+
return v
262+
}
263+
256264
func (c *FullParams) SetProgressCallback(ctx *Context, cb ProgressCallback) {
257265
key := cbkey(unsafe.Pointer(ctx))
258266
if cb == nil {

0 commit comments

Comments
 (0)