Skip to content

Commit 2cd2ebe

Browse files
authored
Merge pull request #38 from mutablelogic/v1
Segmenter
2 parents 7aef1cf + b55b64d commit 2cd2ebe

File tree

17 files changed

+315
-289
lines changed

17 files changed

+315
-289
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ server: mkdir generate go-tidy libwhisper libggml
3535
@PKG_CONFIG_PATH=${ROOT_PATH}/${BUILD_DIR} ${GO} build ${BUILD_FLAGS} -o ${BUILD_DIR}/whisper-server ./cmd/server
3636

3737
# Make cli
38-
cli: mkdir generate go-tidy libwhisper libggml
38+
cli: mkdir generate go-tidy
3939
@echo "Building whisper-cli"
4040
@PKG_CONFIG_PATH=${ROOT_PATH}/${BUILD_DIR} ${GO} build ${BUILD_FLAGS} -o ${BUILD_DIR}/whisper-cli ./cmd/cli
4141

cmd/cli/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
tablewriter "github.com/djthorpe/go-tablewriter"
1212
client "github.com/mutablelogic/go-client"
1313
ctx "github.com/mutablelogic/go-server/pkg/context"
14-
api "github.com/mutablelogic/go-whisper/pkg/client"
14+
api "github.com/mutablelogic/go-whisper/pkg/whisper/client"
1515
)
1616

1717
type Globals struct {

cmd/cli/transcribe.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@ package main
22

33
import (
44
"os"
5+
"time"
56

67
"github.com/djthorpe/go-tablewriter"
7-
"github.com/mutablelogic/go-whisper/pkg/client"
8+
"github.com/mutablelogic/go-whisper/pkg/whisper/client"
89
)
910

1011
type TranscribeCmd struct {
11-
Model string `arg:"" required:"" help:"Model Identifier" type:"string"`
12-
Path string `arg:"" required:"" help:"Audio File Path" type:"string"`
13-
Language string `flag:"language" help:"Source Language" type:"string"`
14-
Prompt string `flag:"prompt" help:"Initial Prompt Identifier" type:"string"`
15-
Temperature *float32 `flag:"temperature" help:"Temperature" type:"float32"`
12+
Model string `arg:"" required:"" help:"Model Identifier" type:"string"`
13+
Path string `arg:"" required:"" help:"Audio File Path" type:"string"`
14+
Language string `flag:"language" help:"Source Language" type:"string"`
15+
SegmentSize *time.Duration `flag:"segment-size" help:"Segment Size" type:"duration"`
16+
ResponseFmt *string `flag:"format" help:"Response Format" enum:"json,verbose_json,text,vtt,srt"`
1617
}
1718

1819
func (cmd *TranscribeCmd) Run(ctx *Globals) error {
@@ -26,11 +27,11 @@ func (cmd *TranscribeCmd) Run(ctx *Globals) error {
2627
if cmd.Language != "" {
2728
opts = append(opts, client.OptLanguage(cmd.Language))
2829
}
29-
if cmd.Prompt != "" {
30-
opts = append(opts, client.OptPrompt(cmd.Prompt))
30+
if cmd.SegmentSize != nil {
31+
opts = append(opts, client.OptSegmentSize(*cmd.SegmentSize))
3132
}
32-
if cmd.Temperature != nil {
33-
opts = append(opts, client.OptTemperature(*cmd.Temperature))
33+
if cmd.ResponseFmt != nil {
34+
opts = append(opts, client.OptResponseFormat(*cmd.ResponseFmt))
3435
}
3536

3637
transcription, err := ctx.api.Transcribe(ctx.ctx, cmd.Model, r, opts...)

cmd/cli/translate.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@ package main
22

33
import (
44
"os"
5+
"time"
56

67
"github.com/djthorpe/go-tablewriter"
7-
"github.com/mutablelogic/go-whisper/pkg/client"
8+
"github.com/mutablelogic/go-whisper/pkg/whisper/client"
89
)
910

1011
type TranslateCmd struct {
11-
Model string `arg:"" required:"" help:"Model Identifier" type:"string"`
12-
Path string `arg:"" required:"" help:"Audio File Path" type:"string"`
13-
Language string `flag:"language" required:"" help:"Target Language" type:"string"`
14-
Prompt string `flag:"prompt" help:"Initial Prompt Identifier" type:"string"`
15-
Temperature *float32 `flag:"temperature" help:"Temperature" type:"float32"`
12+
Model string `arg:"" required:"" help:"Model Identifier" type:"string"`
13+
Path string `arg:"" required:"" help:"Audio File Path" type:"string"`
14+
Language string `flag:"language" required:"" help:"Target Language" type:"string"`
15+
SegmentSize *time.Duration `flag:"segment-size" help:"Segment Size" type:"duration"`
16+
ResponseFmt *string `flag:"format" help:"Response Format" enum:"json,verbose_json,text,vtt,srt"`
1617
}
1718

1819
func (cmd *TranslateCmd) Run(ctx *Globals) error {
@@ -26,11 +27,11 @@ func (cmd *TranslateCmd) Run(ctx *Globals) error {
2627
if cmd.Language != "" {
2728
opts = append(opts, client.OptLanguage(cmd.Language))
2829
}
29-
if cmd.Prompt != "" {
30-
opts = append(opts, client.OptPrompt(cmd.Prompt))
30+
if cmd.SegmentSize != nil {
31+
opts = append(opts, client.OptSegmentSize(*cmd.SegmentSize))
3132
}
32-
if cmd.Temperature != nil {
33-
opts = append(opts, client.OptTemperature(*cmd.Temperature))
33+
if cmd.ResponseFmt != nil {
34+
opts = append(opts, client.OptResponseFormat(*cmd.ResponseFmt))
3435
}
3536

3637
transcription, err := ctx.api.Translate(ctx.ctx, cmd.Model, r, opts...)

pkg/whisper/api/transcribe.go

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@ import (
55
"fmt"
66
"mime/multipart"
77
"net/http"
8+
"time"
89

910
// Packages
10-
11-
"github.com/go-audio/wav"
1211
"github.com/mutablelogic/go-server/pkg/httprequest"
1312
"github.com/mutablelogic/go-server/pkg/httpresponse"
1413
"github.com/mutablelogic/go-whisper/pkg/whisper"
14+
"github.com/mutablelogic/go-whisper/pkg/whisper/segmenter"
1515
"github.com/mutablelogic/go-whisper/pkg/whisper/task"
16+
"github.com/mutablelogic/go-whisper/pkg/whisper/transcription"
1617

1718
// Namespace imports
1819
. "github.com/djthorpe/go-errors"
@@ -25,11 +26,17 @@ type reqTranscribe struct {
2526
File *multipart.FileHeader `json:"file"`
2627
Model string `json:"model"`
2728
Language *string `json:"language"`
28-
Prompt *string `json:"prompt"`
29-
ResponseFmt *string `json:"response_format"`
3029
Temperature *float32 `json:"temperature"`
30+
SegmentSize *time.Duration `json:"segment_size"`
31+
ResponseFmt *string `json:"response_format"`
3132
}
3233

34+
const (
35+
minSegmentSize = 5 * time.Second
36+
maxSegmentSize = 10 * time.Minute
37+
defaultSegmentSize = 5 * time.Minute
38+
)
39+
3340
///////////////////////////////////////////////////////////////////////////////
3441
// PUBLIC METHODS
3542

@@ -53,12 +60,6 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
5360
return
5461
}
5562

56-
// Check audio format - allow WAV or binary
57-
if req.File.Header.Get("Content-Type") != "audio/wav" && req.File.Header.Get("Content-Type") != httprequest.ContentTypeBinary {
58-
httpresponse.Error(w, http.StatusBadRequest, "unsupported audio format:", req.File.Header.Get("Content-Type"))
59-
return
60-
}
61-
6263
// Open file
6364
f, err := req.File.Open()
6465
if err != nil {
@@ -67,15 +68,15 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
6768
}
6869
defer f.Close()
6970

70-
// Read samples
71-
buf, err := wav.NewDecoder(f).FullPCMBuffer()
71+
// Create a segmenter - read segments based on requested segment size
72+
segmenter, err := segmenter.New(f, req.SegmentDur(), whisper.SampleRate)
7273
if err != nil {
73-
httpresponse.Error(w, http.StatusInternalServerError, err.Error())
74+
httpresponse.Error(w, http.StatusBadRequest, err.Error())
7475
return
7576
}
7677

7778
// Get context for the model, perform transcription
78-
var result *whisper.Transcription
79+
var result *transcription.Transcription
7980
if err := service.WithModel(model, func(task *task.Context) error {
8081
// Check model
8182
if translate && !task.CanTranslate() {
@@ -93,29 +94,38 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
9394
return err
9495
}
9596
}
96-
// TODO Set prompt and temperature
97-
/*
98-
if req.Prompt != nil {
99-
ctx.SetPrompt(*req.Prompt)
100-
}
101-
if req.Temperature != nil {
102-
ctx.SetTemperature(*req.Temperature)
103-
}
104-
*/
105-
// Perform the transcription, return any errors
106-
return task.Transcribe(ctx, buf.AsFloat32Buffer().Data)
97+
98+
// TODO: Set temperature, etc. (req.Temperature is still accepted in the request but is not applied to the task in this hunk)
99+
100+
// Read samples and transcribe them
101+
if err := segmenter.Decode(ctx, func(ts time.Duration, buf []float32) error {
102+
// Perform the transcription, return any errors
103+
return task.Transcribe(ctx, ts, buf, req.OutputSegments(), func(segment *transcription.Segment) {
104+
fmt.Println("TODO: ", segment)
105+
})
106+
}); err != nil {
107+
return err
108+
}
109+
110+
// End of transcription, get result
111+
result = task.Result()
112+
113+
// Return success
114+
return nil
107115
}); err != nil {
108-
httpresponse.Error(w, http.StatusBadRequest, err.Error())
116+
httpresponse.Error(w, http.StatusInternalServerError, err.Error())
109117
return
110118
}
111119

112-
// Response - TODO srt, vtt, verbose_json
113-
switch req.ResponseFormat() {
114-
case "text":
115-
httpresponse.Text(w, result.Text, http.StatusOK)
116-
default:
117-
httpresponse.JSON(w, result, http.StatusOK, 2)
120+
// Set task, duration
121+
result.Task = "transcribe"
122+
if translate {
123+
result.Task = "translate"
118124
}
125+
result.Duration = segmenter.Duration()
126+
127+
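// Return transcription as JSON (NOTE(review): the requested response_format is not consulted when writing the response in this hunk - confirm text/srt/vtt output is handled elsewhere)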
128+
httpresponse.JSON(w, result, http.StatusOK, 2)
119129
}
120130

121131
///////////////////////////////////////////////////////////////////////////////
@@ -144,3 +154,26 @@ func (r reqTranscribe) ResponseFormat() string {
144154
}
145155
return *r.ResponseFmt
146156
}
157+
158+
func (r reqTranscribe) OutputSegments() bool {
159+
// Segments are only output when the response format is "srt", "verbose_json" or "vtt"
160+
switch r.ResponseFormat() {
161+
case "srt", "verbose_json", "vtt":
162+
return true
163+
default:
164+
return false
165+
}
166+
}
167+
168+
func (r reqTranscribe) SegmentDur() time.Duration {
169+
if r.SegmentSize == nil {
170+
return defaultSegmentSize
171+
}
172+
if *r.SegmentSize < minSegmentSize {
173+
return minSegmentSize
174+
}
175+
if *r.SegmentSize > maxSegmentSize {
176+
return maxSegmentSize
177+
}
178+
return *r.SegmentSize
179+
}

pkg/client/client.go renamed to pkg/whisper/client/client.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ import (
1111
"github.com/mutablelogic/go-client"
1212
"github.com/mutablelogic/go-client/pkg/multipart"
1313
"github.com/mutablelogic/go-server/pkg/httprequest"
14-
"github.com/mutablelogic/go-whisper/pkg/whisper"
1514
"github.com/mutablelogic/go-whisper/pkg/whisper/model"
15+
"github.com/mutablelogic/go-whisper/pkg/whisper/transcription"
1616
)
1717

1818
///////////////////////////////////////////////////////////////////////////////
@@ -94,13 +94,13 @@ func (c *Client) DownloadModel(ctx context.Context, path string, fn func(status
9494
return r.Model, nil
9595
}
9696

97-
func (c *Client) Transcribe(ctx context.Context, model string, r io.Reader, opt ...Opt) (*whisper.Transcription, error) {
97+
func (c *Client) Transcribe(ctx context.Context, model string, r io.Reader, opt ...Opt) (*transcription.Transcription, error) {
9898
var request struct {
9999
File multipart.File `json:"file"`
100100
Model string `json:"model"`
101101
opts
102102
}
103-
var response whisper.Transcription
103+
var response transcription.Transcription
104104

105105
// Get the name from the io.Reader
106106
name := ""
@@ -131,13 +131,13 @@ func (c *Client) Transcribe(ctx context.Context, model string, r io.Reader, opt
131131
return &response, nil
132132
}
133133

134-
func (c *Client) Translate(ctx context.Context, model string, r io.Reader, opt ...Opt) (*whisper.Transcription, error) {
134+
func (c *Client) Translate(ctx context.Context, model string, r io.Reader, opt ...Opt) (*transcription.Transcription, error) {
135135
var request struct {
136136
File multipart.File `json:"file"`
137137
Model string `json:"model"`
138138
opts
139139
}
140-
var response whisper.Transcription
140+
var response transcription.Transcription
141141

142142
// Get the name from the io.Reader
143143
name := ""
Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package client
22

3+
import "time"
4+
35
// Request options
46
type opts struct {
5-
Language string `json:"language,omitempty"`
6-
Prompt string `json:"prompt,omitempty"`
7-
Temperature float32 `json:"temperature,omitempty"`
7+
Language string `json:"language,omitempty"`
8+
SegmentSize time.Duration `json:"segment_size,omitempty"`
9+
ResponseFmt string `json:"response_format,omitempty"`
810
}
911

1012
type Opt func(*opts) error
@@ -19,16 +21,16 @@ func OptLanguage(language string) Opt {
1921
}
2022
}
2123

22-
func OptPrompt(prompt string) Opt {
24+
func OptSegmentSize(v time.Duration) Opt {
2325
return func(o *opts) error {
24-
o.Prompt = prompt
26+
o.SegmentSize = v
2527
return nil
2628
}
2729
}
2830

29-
func OptTemperature(t float32) Opt {
31+
func OptResponseFormat(v string) Opt {
3032
return func(o *opts) error {
31-
o.Temperature = t
33+
o.ResponseFmt = v
3234
return nil
3335
}
3436
}

pkg/whisper/model/store.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,9 @@ func listModels(path, ext string) ([]*Model, error) {
309309
}
310310

311311
func modelNameToId(name string) string {
312+
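// Lowercase the name, remove the extension (NOTE(review): the extension is taken from the original-case name, so an upper-case extension such as ".BIN" will not match the lowercased name and is left in place - confirm whether this is intended)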
313+
name = strings.TrimSuffix(strings.ToLower(name), filepath.Ext(name))
314+
312315
// We replace all non-alphanumeric characters with underscores
313316
return strings.Map(func(r rune) rune {
314317
if r >= 'a' && r <= 'z' {

0 commit comments

Comments
 (0)