Skip to content

Commit babd70d

Browse files
committed
Upgraded segmenter
1 parent e3103a8 commit babd70d

File tree

4 files changed

+98
-20
lines changed

4 files changed

+98
-20
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ require (
88
github.com/djthorpe/go-tablewriter v0.0.8
99
github.com/go-audio/wav v1.1.0
1010
github.com/mutablelogic/go-client v1.0.9
11-
github.com/mutablelogic/go-media v1.6.7
11+
github.com/mutablelogic/go-media v1.6.8
1212
github.com/mutablelogic/go-server v1.4.13
1313
github.com/stretchr/testify v1.9.0
1414
)

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T
2626
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
2727
github.com/mutablelogic/go-client v1.0.9 h1:Eh4sjQOFDldP/L3IizqkcOD3WigZR+u1VaHTUM4ujYw=
2828
github.com/mutablelogic/go-client v1.0.9/go.mod h1:VLyB8j8IBJSK/FXvvqhmq93PRWDKkyLu8R7V2Vudb6A=
29-
github.com/mutablelogic/go-media v1.6.7 h1:0hCr89EVJg7xw8ChABb7Cscr0UiZ1+Tl9xDXong0lu0=
30-
github.com/mutablelogic/go-media v1.6.7/go.mod h1:vWKq6QKqUQ+sAwfbU/DgakJGIk2Uq7ozH0qSxhysCkM=
29+
github.com/mutablelogic/go-media v1.6.8 h1:3v4povSQlOnvg9mHx6Bp9NVdCCjrNdDCjMHBGFHnVE8=
30+
github.com/mutablelogic/go-media v1.6.8/go.mod h1:HulNT0yyH63a3FRlbuzNDakhOypYrmtFVkHEXZjDgAY=
3131
github.com/mutablelogic/go-server v1.4.13 h1:k5LJJ/pCvyiw34UX341vRhliBOS6i7V65U/UICcOJOA=
3232
github.com/mutablelogic/go-server v1.4.13/go.mod h1:9nenPAohKu8bFoRgwHJh+3s8h0kLFjUAb8KZvT1TQNU=
3333
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

pkg/whisper/segmenter/segmenter.go

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,50 @@ package segmenter
33
import (
44
"context"
55
"errors"
6-
"fmt"
76
"io"
7+
"time"
88

99
// Packages
1010
media "github.com/mutablelogic/go-media"
1111
ffmpeg "github.com/mutablelogic/go-media/pkg/ffmpeg"
12+
13+
// Namespace imports
14+
. "github.com/djthorpe/go-errors"
1215
)
1316

1417
type Segmenter struct {
15-
reader *ffmpeg.Reader
18+
ts time.Duration
19+
sample_rate int
20+
n int
21+
buf []float32
22+
reader *ffmpeg.Reader
1623
}
1724

25+
// SegmentFunc is a callback function which is called when a segment is ready
26+
// to be processed. The first argument is the timestamp of the segment.
27+
type SegmentFunc func(time.Duration, []float32)
28+
1829
//////////////////////////////////////////////////////////////////////////////
1930
// LIFECYCLE
2031

2132
// Create a new segmenter for "NumSamples" with a reader r
2233
// If NumSamples is zero then no segmenting is performed
23-
func NewSegmenter(r io.Reader) (*Segmenter, error) {
34+
func NewSegmenter(r io.Reader, dur time.Duration, sample_rate int) (*Segmenter, error) {
2435
segmenter := new(Segmenter)
2536

37+
// Check arguments
38+
if dur < 0 || sample_rate <= 0 {
39+
return nil, ErrBadParameter.With("invalid duration or sample rate arguments")
40+
} else {
41+
segmenter.sample_rate = sample_rate
42+
}
43+
44+
// Sample buffer is duration * sample rate
45+
if dur > 0 {
46+
segmenter.n = int(dur.Seconds()) * sample_rate
47+
segmenter.buf = make([]float32, 0, int(dur.Seconds())*sample_rate)
48+
}
49+
2650
// Open the file
2751
media, err := ffmpeg.NewReader(r)
2852
if err != nil {
@@ -42,6 +66,7 @@ func (s *Segmenter) Close() error {
4266
result = errors.Join(result, s.reader.Close())
4367
}
4468
s.reader = nil
69+
s.buf = nil
4570

4671
// Return any errors
4772
return result
@@ -53,19 +78,51 @@ func (s *Segmenter) Close() error {
5378
// TODO: segments are output through a callback, with the samples and a timestamp
5479
// TODO: we could do some basic silence and voice detection to segment to ensure
5580
// we don't overtax the CPU/GPU with silence and non-speech
56-
// TODO: We have hard-coded the sample format, sample rate and number of channels
57-
// here. We should make this configurable
58-
func (s *Segmenter) Decode(ctx context.Context) error {
81+
func (s *Segmenter) Decode(ctx context.Context, fn SegmentFunc) error {
82+
// Check input parameters
83+
if fn == nil {
84+
return ErrBadParameter.With("SegmentFunc is nil")
85+
}
86+
87+
// Map function chooses the best audio stream
5988
mapFunc := func(stream int, params *ffmpeg.Par) (*ffmpeg.Par, error) {
6089
if stream == s.reader.BestStream(media.AUDIO) {
61-
return ffmpeg.NewAudioPar("flt", "mono", 16000)
90+
return ffmpeg.NewAudioPar("flt", "mono", s.sample_rate)
6291
}
6392
// Ignore no-audio streams
6493
return nil, nil
6594
}
66-
return s.reader.Decode(ctx, mapFunc, func(stream int, frame *ffmpeg.Frame) error {
67-
// Append float32 samples to buffer
68-
fmt.Println("TODO: Implement Decode", frame)
95+
96+
// Decode samples and segment
97+
if err := s.reader.Decode(ctx, mapFunc, func(stream int, frame *ffmpeg.Frame) error {
98+
// We get null frames sometimes, ignore them
99+
if frame == nil {
100+
return nil
101+
}
102+
103+
// Append float32 samples from plane 0 to buffer
104+
s.buf = append(s.buf, frame.Float32(0)...)
105+
106+
// n != 0 and len(buf) >= n we have a segment to process
107+
if s.n != 0 && len(s.buf) >= s.n {
108+
fn(s.ts, s.buf)
109+
// Clear the buffer
110+
s.buf = s.buf[:0]
111+
// Increment the timestamp
112+
s.ts += time.Duration(float64(s.n)/float64(s.sample_rate)) * time.Second
113+
}
114+
115+
// Continue processing
69116
return nil
70-
})
117+
}); err != nil {
118+
return err
119+
}
120+
121+
// Output any remaining samples
122+
if len(s.buf) > 0 {
123+
fn(s.ts, s.buf)
124+
}
125+
126+
// Return success
127+
return nil
71128
}

pkg/whisper/segmenter/segmenter_test.go

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,49 @@ import (
44
"context"
55
"os"
66
"testing"
7+
"time"
78

89
// Packages
910
segmenter "github.com/mutablelogic/go-whisper/pkg/whisper/segmenter"
1011
assert "github.com/stretchr/testify/assert"
1112
)
1213

13-
const SAMPLE_EN = "../../../samples/jfk.wav"
14-
const SAMPLE_FR = "../../../samples/OlivierL.wav"
15-
const SAMPLE_DE = "../../../samples/ge-podcast.wav"
14+
const SAMPLE = "../../../samples/OlivierL.wav"
1615

1716
func Test_segmenter_001(t *testing.T) {
1817
assert := assert.New(t)
1918

20-
f, err := os.Open(SAMPLE_EN)
19+
f, err := os.Open(SAMPLE)
2120
if !assert.NoError(err) {
2221
t.SkipNow()
2322
}
24-
segmenter, err := segmenter.NewSegmenter(f)
23+
segmenter, err := segmenter.NewSegmenter(f, time.Second, 16000)
2524
if !assert.NoError(err) {
2625
t.SkipNow()
2726
}
2827
defer segmenter.Close()
2928

30-
assert.NoError(segmenter.Decode(context.Background()))
29+
assert.NoError(segmenter.Decode(context.Background(), func(ts time.Duration, buf []float32) {
30+
t.Log(ts, len(buf))
31+
}))
32+
}
33+
34+
func Test_segmenter_002(t *testing.T) {
35+
assert := assert.New(t)
36+
37+
f, err := os.Open(SAMPLE)
38+
if !assert.NoError(err) {
39+
t.SkipNow()
40+
}
41+
42+
// No segmentation, just output the audio
43+
segmenter, err := segmenter.NewSegmenter(f, 0, 16000)
44+
if !assert.NoError(err) {
45+
t.SkipNow()
46+
}
47+
defer segmenter.Close()
48+
49+
assert.NoError(segmenter.Decode(context.Background(), func(ts time.Duration, buf []float32) {
50+
t.Log(ts, len(buf))
51+
}))
3152
}

0 commit comments

Comments
 (0)