Skip to content

Commit 05c6af4

Browse files
authored
Merge pull request #52 from mutablelogic/v1
Add diarization and output formats
2 parents fd498e9 + 9a7843c commit 05c6af4

File tree

16 files changed

+276
-62
lines changed

16 files changed

+276
-62
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ docker: docker-dep submodule
5656
--build-arg OS=${OS} \
5757
--build-arg SOURCE=${BUILD_MODULE} \
5858
--build-arg VERSION=${VERSION} \
59-
-f etc/Dockerfile.${ARCH} .
59+
-f etc/Dockerfile.${OS}-${ARCH} .
6060

6161
# Test whisper bindings
6262
test: generate libwhisper libggml

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ curl -F model=ggml-medium-q5_0 -F file=@samples/jfk.wav localhost:8080/v1/audio/
6262
To translate a media file into a different language, you can use the following command:
6363

6464
```bash
65-
curl -F model=ggml-medium-q5_0 -F file=@samples/ge-podcast.wav -F language=en localhost:8080/v1/audio/translations\?stream=true
65+
curl -F model=ggml-medium-q5_0 -F file=@samples/de-podcast.wav -F language=en localhost:8080/v1/audio/translations\?stream=true
6666
```
6767

6868
There's more information on the API [here](doc/API.md).

cmd/server/flags.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ func NewFlags(name string, args []string) (*Flags, error) {
2020
FlagSet: flag.NewFlagSet(name, flag.ContinueOnError),
2121
}
2222
flags.endpoint = flags.String("endpoint", "/v1", "HTTP endpoint")
23-
flags.listen = flags.String("listen", ":8080", "HTTP Listen address")
23+
flags.listen = flags.String("listen", "127.0.0.1:8080", "HTTP Listen address")
2424
flags.dir = flags.String("dir", "${WHISPER_DATA}", "Model data directory")
25-
flags.debug = flags.Bool("debug", false, "Display debug information")
25+
flags.debug = flags.Bool("debug", false, "Output additional debug information")
2626

2727
// Parse flags and return any error
2828
return flags, flags.Parse(args)

cmd/server/main.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
httpserver "github.com/mutablelogic/go-server/pkg/httpserver"
1616
whisper "github.com/mutablelogic/go-whisper/pkg/whisper"
1717
api "github.com/mutablelogic/go-whisper/pkg/whisper/api"
18+
version "github.com/mutablelogic/go-whisper/pkg/whisper/version"
1819
)
1920

2021
func main() {
@@ -45,6 +46,16 @@ func main() {
4546
os.Exit(-1)
4647
}
4748

49+
// Print version
50+
if version.GitSource != "" {
51+
log.Println(name, version.GitSource)
52+
} else {
53+
log.Println(name)
54+
}
55+
if version.GitTag != "" {
56+
log.Println("Version:", version.GitTag)
57+
}
58+
4859
// Create a whisper service
4960
log.Println("Storing models at", dir)
5061
opts := []whisper.Opt{
@@ -77,7 +88,7 @@ func main() {
7788
api.RegisterEndpoints(flags.Endpoint(), mux, whisper)
7889

7990
// Create a new HTTP server
80-
log.Println("List address", flags.Listen())
91+
log.Println("Listen address", flags.Listen())
8192
server, err := httpserver.Config{
8293
Listen: flags.Listen(),
8394
Router: mux,

doc/API.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,20 @@ event: ok
144144

145145
### Translation
146146

147-
This is the same as transcription (above) except that the `language` parameter is not optional, and should be the language to translate the audio into.
147+
This is the same as transcription (above) except that the `language` parameter is always set to 'en', to translate the audio into English.
148148

149149
```html
150150
POST /v1/audio/translations
151151
POST /v1/audio/translations?stream={bool}
152152
```
153+
154+
### Diarization
155+
156+
To diarize an English-language audio file, use the following endpoint:
157+
158+
```html
159+
POST /v1/audio/diarize
160+
POST /v1/audio/diarize?stream={bool}
161+
```
162+
163+
The segments returned include a "speaker_turn" field which indicates that the segment starts a new speaker turn. Diarization requires a separate download of a [diarization model](https://huggingface.co/akashmjn/tinydiarize-whisper.cpp).
File renamed without changes.
File renamed without changes.

pkg/whisper/api/register.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func RegisterEndpoints(base string, mux *http.ServeMux, whisper *whisper.Whisper
7070

7171
switch r.Method {
7272
case http.MethodPost:
73-
TranscribeFile(r.Context(), whisper, w, r, true)
73+
TranscribeFile(r.Context(), whisper, w, r, Translate)
7474
default:
7575
httpresponse.Error(w, http.StatusMethodNotAllowed)
7676
}
@@ -84,7 +84,21 @@ func RegisterEndpoints(base string, mux *http.ServeMux, whisper *whisper.Whisper
8484

8585
switch r.Method {
8686
case http.MethodPost:
87-
TranscribeFile(r.Context(), whisper, w, r, false)
87+
TranscribeFile(r.Context(), whisper, w, r, Transcribe)
88+
default:
89+
httpresponse.Error(w, http.StatusMethodNotAllowed)
90+
}
91+
})
92+
93+
// Diarize: POST /v1/audio/diarize
94+
// Transcribes audio into the input language - language parameter should be set to the source
95+
// language of the audio. Output segments are marked with speaker turns.
96+
mux.HandleFunc(joinPath(base, "audio/diarize"), func(w http.ResponseWriter, r *http.Request) {
97+
defer r.Body.Close()
98+
99+
switch r.Method {
100+
case http.MethodPost:
101+
TranscribeFile(r.Context(), whisper, w, r, Diarize)
88102
default:
89103
httpresponse.Error(w, http.StatusMethodNotAllowed)
90104
}

pkg/whisper/api/transcribe.go

Lines changed: 120 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package api
22

33
import (
4+
"bytes"
45
"context"
56
"fmt"
67
"mime/multipart"
78
"net/http"
9+
"strings"
810
"time"
911

1012
// Packages
@@ -35,16 +37,37 @@ type queryTranscribe struct {
3537
Stream bool `json:"stream"`
3638
}
3739

40+
type TaskType int
41+
type ResponseFormat string
42+
43+
///////////////////////////////////////////////////////////////////////////////
44+
// GLOBALS
45+
3846
const (
3947
minSegmentSize = 5 * time.Second
4048
maxSegmentSize = 10 * time.Minute
4149
defaultSegmentSize = 5 * time.Minute
4250
)
4351

52+
const (
53+
_ TaskType = iota
54+
Transcribe // Transcribe audio
55+
Translate // Translate audio into English
56+
Diarize // Diarize audio
57+
)
58+
59+
const (
60+
FormatJson ResponseFormat = "json"
61+
FormatText ResponseFormat = "text"
62+
FormatSrt ResponseFormat = "srt"
63+
FormatVerboseJson ResponseFormat = "verbose_json"
64+
FormatVtt ResponseFormat = "vtt"
65+
)
66+
4467
///////////////////////////////////////////////////////////////////////////////
4568
// PUBLIC METHODS
4669

47-
func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.ResponseWriter, r *http.Request, translate bool) {
70+
func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.ResponseWriter, r *http.Request, t TaskType) {
4871
var req reqTranscribe
4972
var query queryTranscribe
5073
if err := httprequest.Query(&query, r.URL.Query()); err != nil {
@@ -96,54 +119,86 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
96119

97120
// Get context for the model, perform transcription
98121
var result *schema.Transcription
99-
if err := service.WithModel(model, func(task *task.Context) error {
100-
// Check model
101-
if translate && !task.CanTranslate() {
102-
return ErrBadParameter.With("model is not multilingual, cannot translate")
103-
}
122+
if err := service.WithModel(model, func(taskctx *task.Context) error {
123+
result = taskctx.Result()
104124

105-
// Set parameters for transcription & translation, default to english
106-
task.SetTranslate(translate)
107-
if req.Language != nil {
108-
if err := task.SetLanguage(*req.Language); err != nil {
109-
return err
125+
switch t {
126+
case Translate:
127+
// Check model
128+
if !taskctx.CanTranslate() {
129+
return ErrBadParameter.With("model is not multilingual, cannot translate")
110130
}
111-
} else if translate {
112-
if err := task.SetLanguage("en"); err != nil {
131+
taskctx.SetTranslate(true)
132+
taskctx.SetDiarize(false)
133+
result.Task = "translate"
134+
135+
// Set language to EN
136+
if err := taskctx.SetLanguage("en"); err != nil {
113137
return err
114138
}
139+
case Diarize:
140+
taskctx.SetTranslate(false)
141+
taskctx.SetDiarize(true)
142+
result.Task = "diarize"
143+
144+
// Set language
145+
if req.Language != nil {
146+
if err := taskctx.SetLanguage(*req.Language); err != nil {
147+
return err
148+
}
149+
}
150+
default:
151+
// Transcribe
152+
taskctx.SetTranslate(false)
153+
taskctx.SetDiarize(false)
154+
result.Task = "transribe"
155+
156+
// Set language
157+
if req.Language != nil {
158+
if err := taskctx.SetLanguage(*req.Language); err != nil {
159+
return err
160+
}
161+
}
115162
}
116163

117164
// TODO: Set temperature, etc
118165

119166
// Create response
120-
result = task.Result()
121-
result.Task = "transcribe"
122-
if translate {
123-
result.Task = "translate"
124-
}
125-
result.Duration = schema.Timestamp(segmenter.Duration())
126-
result.Language = task.Language()
167+
result.Language = taskctx.Language()
127168

128169
// Output the header
129170
if stream != nil {
130-
stream.Write("task", result)
171+
stream.Write("task", taskctx.Result())
131172
}
132173

133174
// Read samples and transcribe them
134175
if err := segmenter.Decode(ctx, func(ts time.Duration, buf []float32) error {
135176
// Perform the transcription, return any errors
136-
return task.Transcribe(ctx, ts, buf, req.OutputSegments() || stream != nil, func(segment *schema.Segment) {
137-
if stream != nil {
177+
return taskctx.Transcribe(ctx, ts, buf, func(segment *schema.Segment) {
178+
// Segment callback
179+
if stream == nil {
180+
return
181+
}
182+
var buf bytes.Buffer
183+
switch req.ResponseFormat() {
184+
case FormatVerboseJson, FormatJson:
138185
stream.Write("segment", segment)
186+
return
187+
case FormatSrt:
188+
task.WriteSegmentSrt(&buf, segment)
189+
case FormatVtt:
190+
task.WriteSegmentVtt(&buf, segment)
191+
case FormatText:
192+
task.WriteSegmentText(&buf, segment)
139193
}
194+
stream.Write("segment", buf.String())
140195
})
141196
}); err != nil {
142197
return err
143198
}
144199

145200
// Set the language
146-
result.Language = task.Language()
201+
result.Language = taskctx.Language()
147202

148203
// Return success
149204
return nil
@@ -156,11 +211,32 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
156211
return
157212
}
158213

159-
// Return transcription if not streaming
160-
if stream == nil {
161-
httpresponse.JSON(w, result, http.StatusOK, 2)
162-
} else {
214+
// Return streaming ok
215+
if stream != nil {
163216
stream.Write("ok")
217+
return
218+
}
219+
220+
// Return result based on response format
221+
switch req.ResponseFormat() {
222+
case FormatJson, FormatVerboseJson:
223+
httpresponse.JSON(w, result, http.StatusOK, 0)
224+
case FormatText:
225+
httpresponse.Text(w, "", http.StatusOK)
226+
for _, seg := range result.Segments {
227+
task.WriteSegmentText(w, seg)
228+
}
229+
w.Write([]byte("\n"))
230+
case FormatSrt:
231+
httpresponse.Text(w, "", http.StatusOK, "Content-Type", "application/x-subrip")
232+
for _, seg := range result.Segments {
233+
task.WriteSegmentSrt(w, seg)
234+
}
235+
case FormatVtt:
236+
httpresponse.Text(w, "WEBVTT\n\n", http.StatusOK, "Content-Type", "text/vtt")
237+
for _, seg := range result.Segments {
238+
task.WriteSegmentVtt(w, seg)
239+
}
164240
}
165241
}
166242

@@ -184,17 +260,29 @@ func (r reqTranscribe) Validate() error {
184260
}
185261
return nil
186262
}
187-
func (r reqTranscribe) ResponseFormat() string {
263+
func (r reqTranscribe) ResponseFormat() ResponseFormat {
188264
if r.ResponseFmt == nil {
189-
return "json"
265+
return FormatJson
266+
}
267+
switch strings.ToLower(*r.ResponseFmt) {
268+
case "json":
269+
return FormatJson
270+
case "text":
271+
return FormatText
272+
case "srt":
273+
return FormatSrt
274+
case "verbose_json":
275+
return FormatVerboseJson
276+
case "vtt":
277+
return FormatVtt
190278
}
191-
return *r.ResponseFmt
279+
return FormatJson
192280
}
193281

194282
func (r reqTranscribe) OutputSegments() bool {
195283
// We want to output segments if the response format is "srt", "verbose_json", "vtt"
196284
switch r.ResponseFormat() {
197-
case "srt", "verbose_json", "vtt":
285+
case FormatSrt, FormatVerboseJson, FormatVtt:
198286
return true
199287
default:
200288
return false

0 commit comments

Comments
 (0)