1
1
package api
2
2
3
3
import (
4
+ "bytes"
4
5
"context"
5
6
"fmt"
6
7
"mime/multipart"
7
8
"net/http"
9
+ "strings"
8
10
"time"
9
11
10
12
// Packages
@@ -35,16 +37,37 @@ type queryTranscribe struct {
35
37
Stream bool `json:"stream"`
36
38
}
37
39
40
// TaskType selects which operation TranscribeFile performs
// (see the Transcribe, Translate and Diarize constants).
+ type TaskType int
41
// ResponseFormat names the output encoding requested for a transcription
// result (see the Format* constants).
+ type ResponseFormat string
42
+
43
+ ///////////////////////////////////////////////////////////////////////////////
44
+ // GLOBALS
45
+
38
46
const (
39
47
minSegmentSize = 5 * time .Second
40
48
maxSegmentSize = 10 * time .Minute
41
49
defaultSegmentSize = 5 * time .Minute
42
50
)
43
51
52
// Task types accepted by TranscribeFile. The zero value is deliberately
// skipped, so an unset TaskType matches none of the named constants;
// TranscribeFile treats any value other than Translate or Diarize as a
// plain transcription.
+ const (
53
+ _ TaskType = iota
54
+ Transcribe // Transcribe audio
55
+ Translate // Translate audio into English
56
+ Diarize // Diarize audio
57
+ )
58
+
59
// Response formats recognised by reqTranscribe.ResponseFormat. The string
// values are the lower-case names matched against the request's format field.
+ const (
60
+ FormatJson ResponseFormat = "json"
61
+ FormatText ResponseFormat = "text"
62
+ FormatSrt ResponseFormat = "srt"
63
+ FormatVerboseJson ResponseFormat = "verbose_json"
64
+ FormatVtt ResponseFormat = "vtt"
65
+ )
66
+
44
67
///////////////////////////////////////////////////////////////////////////////
45
68
// PUBLIC METHODS
46
69
47
- func TranscribeFile (ctx context.Context , service * whisper.Whisper , w http.ResponseWriter , r * http.Request , translate bool ) {
70
+ func TranscribeFile (ctx context.Context , service * whisper.Whisper , w http.ResponseWriter , r * http.Request , t TaskType ) {
48
71
var req reqTranscribe
49
72
var query queryTranscribe
50
73
if err := httprequest .Query (& query , r .URL .Query ()); err != nil {
@@ -96,54 +119,86 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
96
119
97
120
// Get context for the model, perform transcription
98
121
var result * schema.Transcription
99
- if err := service .WithModel (model , func (task * task.Context ) error {
100
- // Check model
101
- if translate && ! task .CanTranslate () {
102
- return ErrBadParameter .With ("model is not multilingual, cannot translate" )
103
- }
122
+ if err := service .WithModel (model , func (taskctx * task.Context ) error {
123
+ result = taskctx .Result ()
104
124
105
- // Set parameters for transcription & translation, default to english
106
- task . SetTranslate ( translate )
107
- if req . Language != nil {
108
- if err := task . SetLanguage ( * req . Language ); err != nil {
109
- return err
125
+ switch t {
126
+ case Translate :
127
+ // Check model
128
+ if ! taskctx . CanTranslate () {
129
+ return ErrBadParameter . With ( "model is not multilingual, cannot translate" )
110
130
}
111
- } else if translate {
112
- if err := task .SetLanguage ("en" ); err != nil {
131
+ taskctx .SetTranslate (true )
132
+ taskctx .SetDiarize (false )
133
+ result .Task = "translate"
134
+
135
+ // Set language to EN
136
+ if err := taskctx .SetLanguage ("en" ); err != nil {
113
137
return err
114
138
}
139
+ case Diarize :
140
+ taskctx .SetTranslate (false )
141
+ taskctx .SetDiarize (true )
142
+ result .Task = "diarize"
143
+
144
+ // Set language
145
+ if req .Language != nil {
146
+ if err := taskctx .SetLanguage (* req .Language ); err != nil {
147
+ return err
148
+ }
149
+ }
150
+ default :
151
+ // Transcribe
152
+ taskctx .SetTranslate (false )
153
+ taskctx .SetDiarize (false )
154
+ result .Task = "transribe"
155
+
156
+ // Set language
157
+ if req .Language != nil {
158
+ if err := taskctx .SetLanguage (* req .Language ); err != nil {
159
+ return err
160
+ }
161
+ }
115
162
}
116
163
117
164
// TODO: Set temperature, etc
118
165
119
166
// Create response
120
- result = task .Result ()
121
- result .Task = "transcribe"
122
- if translate {
123
- result .Task = "translate"
124
- }
125
- result .Duration = schema .Timestamp (segmenter .Duration ())
126
- result .Language = task .Language ()
167
+ result .Language = taskctx .Language ()
127
168
128
169
// Output the header
129
170
if stream != nil {
130
- stream .Write ("task" , result )
171
+ stream .Write ("task" , taskctx . Result () )
131
172
}
132
173
133
174
// Read samples and transcribe them
134
175
if err := segmenter .Decode (ctx , func (ts time.Duration , buf []float32 ) error {
135
176
// Perform the transcription, return any errors
136
- return task .Transcribe (ctx , ts , buf , req .OutputSegments () || stream != nil , func (segment * schema.Segment ) {
137
- if stream != nil {
177
+ return taskctx .Transcribe (ctx , ts , buf , func (segment * schema.Segment ) {
178
+ // Segment callback
179
+ if stream == nil {
180
+ return
181
+ }
182
+ var buf bytes.Buffer
183
+ switch req .ResponseFormat () {
184
+ case FormatVerboseJson , FormatJson :
138
185
stream .Write ("segment" , segment )
186
+ return
187
+ case FormatSrt :
188
+ task .WriteSegmentSrt (& buf , segment )
189
+ case FormatVtt :
190
+ task .WriteSegmentVtt (& buf , segment )
191
+ case FormatText :
192
+ task .WriteSegmentText (& buf , segment )
139
193
}
194
+ stream .Write ("segment" , buf .String ())
140
195
})
141
196
}); err != nil {
142
197
return err
143
198
}
144
199
145
200
// Set the language
146
- result .Language = task .Language ()
201
+ result .Language = taskctx .Language ()
147
202
148
203
// Return success
149
204
return nil
@@ -156,11 +211,32 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
156
211
return
157
212
}
158
213
159
- // Return transcription if not streaming
160
- if stream == nil {
161
- httpresponse .JSON (w , result , http .StatusOK , 2 )
162
- } else {
214
+ // Return streaming ok
215
+ if stream != nil {
163
216
stream .Write ("ok" )
217
+ return
218
+ }
219
+
220
+ // Return result based on response format
221
+ switch req .ResponseFormat () {
222
+ case FormatJson , FormatVerboseJson :
223
+ httpresponse .JSON (w , result , http .StatusOK , 0 )
224
+ case FormatText :
225
+ httpresponse .Text (w , "" , http .StatusOK )
226
+ for _ , seg := range result .Segments {
227
+ task .WriteSegmentText (w , seg )
228
+ }
229
+ w .Write ([]byte ("\n " ))
230
+ case FormatSrt :
231
+ httpresponse .Text (w , "" , http .StatusOK , "Content-Type" , "application/x-subrip" )
232
+ for _ , seg := range result .Segments {
233
+ task .WriteSegmentSrt (w , seg )
234
+ }
235
+ case FormatVtt :
236
+ httpresponse .Text (w , "WEBVTT\n \n " , http .StatusOK , "Content-Type" , "text/vtt" )
237
+ for _ , seg := range result .Segments {
238
+ task .WriteSegmentVtt (w , seg )
239
+ }
164
240
}
165
241
}
166
242
@@ -184,17 +260,29 @@ func (r reqTranscribe) Validate() error {
184
260
}
185
261
return nil
186
262
}
187
// ResponseFormat returns the response format requested by the client.
// When no format was supplied it returns FormatJson. A supplied value is
// matched case-insensitively against the known format names, and any
// unrecognised value also falls back to FormatJson.
- func (r reqTranscribe ) ResponseFormat () string {
263
+ func (r reqTranscribe ) ResponseFormat () ResponseFormat {
188
264
if r .ResponseFmt == nil {
189
- return "json"
265
+ return FormatJson
266
+ }
267
// Lower-case the requested name so that matching ignores case
+ switch strings .ToLower (* r .ResponseFmt ) {
268
+ case "json" :
269
+ return FormatJson
270
+ case "text" :
271
+ return FormatText
272
+ case "srt" :
273
+ return FormatSrt
274
+ case "verbose_json" :
275
+ return FormatVerboseJson
276
+ case "vtt" :
277
+ return FormatVtt
190
278
}
191
- return * r . ResponseFmt
279
// Unrecognised format names fall back to JSON
+ return FormatJson
192
280
}
193
281
194
282
func (r reqTranscribe ) OutputSegments () bool {
195
283
// We want to output segments if the response format is "srt", "verbose_json", "vtt"
196
284
switch r .ResponseFormat () {
197
- case "srt" , "verbose_json" , "vtt" :
285
+ case FormatSrt , FormatVerboseJson , FormatVtt :
198
286
return true
199
287
default :
200
288
return false
0 commit comments