Skip to content

Commit c262d1a

Browse files
committed
Integrated segmenter
1 parent adbb10d commit c262d1a

File tree

5 files changed

+39
-38
lines changed

5 files changed

+39
-38
lines changed

pkg/whisper/api/transcribe.go

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ import (
55
"fmt"
66
"mime/multipart"
77
"net/http"
8+
"time"
89

910
// Packages
10-
11-
"github.com/go-audio/wav"
1211
"github.com/mutablelogic/go-server/pkg/httprequest"
1312
"github.com/mutablelogic/go-server/pkg/httpresponse"
1413
"github.com/mutablelogic/go-whisper/pkg/whisper"
14+
"github.com/mutablelogic/go-whisper/pkg/whisper/segmenter"
1515
"github.com/mutablelogic/go-whisper/pkg/whisper/task"
1616

1717
// Namespace imports
@@ -53,12 +53,6 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
5353
return
5454
}
5555

56-
// Check audio format - allow WAV or binary
57-
if req.File.Header.Get("Content-Type") != "audio/wav" && req.File.Header.Get("Content-Type") != httprequest.ContentTypeBinary {
58-
httpresponse.Error(w, http.StatusBadRequest, "unsupported audio format:", req.File.Header.Get("Content-Type"))
59-
return
60-
}
61-
6256
// Open file
6357
f, err := req.File.Open()
6458
if err != nil {
@@ -67,15 +61,14 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
6761
}
6862
defer f.Close()
6963

70-
// Read samples
71-
buf, err := wav.NewDecoder(f).FullPCMBuffer()
64+
// Create a segmenter which reads the audio in segments of 5 minutes
65+
segmenter, err := segmenter.New(f, 5*time.Minute, whisper.SampleRate)
7266
if err != nil {
73-
httpresponse.Error(w, http.StatusInternalServerError, err.Error())
67+
httpresponse.Error(w, http.StatusBadRequest, err.Error())
7468
return
7569
}
7670

7771
// Get context for the model, perform transcription
78-
var result *whisper.Transcription
7972
if err := service.WithModel(model, func(task *task.Context) error {
8073
// Check model
8174
if translate && !task.CanTranslate() {
@@ -93,29 +86,23 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
9386
return err
9487
}
9588
}
96-
// TODO Set prompt and temperature
97-
/*
98-
if req.Prompt != nil {
99-
ctx.SetPrompt(*req.Prompt)
100-
}
101-
if req.Temperature != nil {
102-
ctx.SetTemperature(*req.Temperature)
103-
}
104-
*/
105-
// Perform the transcription, return any errors
106-
return task.Transcribe(ctx, buf.AsFloat32Buffer().Data)
89+
90+
// TODO: Set temperature, etc
91+
92+
// Read samples and transcribe them
93+
return segmenter.Decode(ctx, func(ts time.Duration, buf []float32) error {
94+
fmt.Println("audio segment", ts, len(buf))
95+
96+
// Perform the transcription, return any errors
97+
return task.Transcribe(ctx, buf)
98+
})
10799
}); err != nil {
108-
httpresponse.Error(w, http.StatusBadRequest, err.Error())
100+
httpresponse.Error(w, http.StatusInternalServerError, err.Error())
109101
return
110102
}
111103

112-
// Response - TODO srt, vtt, verbose_json
113-
switch req.ResponseFormat() {
114-
case "text":
115-
httpresponse.Text(w, result.Text, http.StatusOK)
116-
default:
117-
httpresponse.JSON(w, result, http.StatusOK, 2)
118-
}
104+
var result whisper.Transcription
105+
httpresponse.JSON(w, result, http.StatusOK, 2)
119106
}
120107

121108
///////////////////////////////////////////////////////////////////////////////

pkg/whisper/segmenter/segmenter.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ type Segmenter struct {
2424

2525
// SegmentFunc is a callback function which is called when a segment is ready
2626
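// to be processed. The first argument is the timestamp of the segment and the second is the buffer of samples. Returning a non-nil error stops decoding and that error is returned from Decode.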
27-
type SegmentFunc func(time.Duration, []float32)
27+
type SegmentFunc func(time.Duration, []float32) error
2828

2929
//////////////////////////////////////////////////////////////////////////////
3030
// LIFECYCLE
3131

3232
// Create a new segmenter which reads from r and emits segments of duration dur
3333
// If dur is zero then no segmenting is performed
34-
func NewSegmenter(r io.Reader, dur time.Duration, sample_rate int) (*Segmenter, error) {
34+
func New(r io.Reader, dur time.Duration, sample_rate int) (*Segmenter, error) {
3535
segmenter := new(Segmenter)
3636

3737
// Check arguments
@@ -105,7 +105,9 @@ func (s *Segmenter) Decode(ctx context.Context, fn SegmentFunc) error {
105105

106106
// If n != 0 and len(buf) >= n then we have a segment to process
107107
if s.n != 0 && len(s.buf) >= s.n {
108-
fn(s.ts, s.buf)
108+
if err := fn(s.ts, s.buf); err != nil {
109+
return err
110+
}
109111
// Increment the timestamp
110112
s.ts += time.Duration(len(s.buf)) * time.Second / time.Duration(s.sample_rate)
111113
// Clear the buffer

pkg/whisper/segmenter/segmenter_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@ func Test_segmenter_001(t *testing.T) {
2020
if !assert.NoError(err) {
2121
t.SkipNow()
2222
}
23-
segmenter, err := segmenter.NewSegmenter(f, 200*time.Millisecond, 16000)
23+
segmenter, err := segmenter.New(f, 200*time.Millisecond, 16000)
2424
if !assert.NoError(err) {
2525
t.SkipNow()
2626
}
2727
defer segmenter.Close()
2828

29-
assert.NoError(segmenter.Decode(context.Background(), func(ts time.Duration, buf []float32) {
29+
assert.NoError(segmenter.Decode(context.Background(), func(ts time.Duration, buf []float32) error {
3030
t.Log(ts, len(buf))
31+
return nil
3132
}))
3233
}
3334

@@ -40,13 +41,14 @@ func Test_segmenter_002(t *testing.T) {
4041
}
4142

4243
// No segmentation, just output the audio
43-
segmenter, err := segmenter.NewSegmenter(f, 0, 16000)
44+
segmenter, err := segmenter.New(f, 0, 16000)
4445
if !assert.NoError(err) {
4546
t.SkipNow()
4647
}
4748
defer segmenter.Close()
4849

49-
assert.NoError(segmenter.Decode(context.Background(), func(ts time.Duration, buf []float32) {
50+
assert.NoError(segmenter.Decode(context.Background(), func(ts time.Duration, buf []float32) error {
5051
t.Log(ts, len(buf))
52+
return nil
5153
}))
5254
}

pkg/whisper/whisper.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ const (
3636

3737
// This is where the model is downloaded from
3838
defaultModelUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/?download=true"
39+
40+
// Sample Rate
41+
SampleRate = whisper.SampleRate
3942
)
4043

4144
//////////////////////////////////////////////////////////////////////////////

sys/whisper/whisper.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ type (
1919
Context C.struct_whisper_context
2020
)
2121

22+
///////////////////////////////////////////////////////////////////////////////
23+
// Constants
24+
25+
const (
26+
SampleRate = C.WHISPER_SAMPLE_RATE
27+
)
28+
2229
///////////////////////////////////////////////////////////////////////////////
2330
// LIFECYCLE
2431

0 commit comments

Comments
 (0)