Skip to content

Commit b55b64d

Browse files
committed
Updated
1 parent e66dc08 commit b55b64d

File tree

10 files changed

+74
-161
lines changed

10 files changed

+74
-161
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ server: mkdir generate go-tidy libwhisper libggml
3535
@PKG_CONFIG_PATH=${ROOT_PATH}/${BUILD_DIR} ${GO} build ${BUILD_FLAGS} -o ${BUILD_DIR}/whisper-server ./cmd/server
3636

3737
# Make cli
38-
cli: mkdir generate go-tidy libwhisper libggml
38+
cli: mkdir generate go-tidy
3939
@echo "Building whisper-cli"
4040
@PKG_CONFIG_PATH=${ROOT_PATH}/${BUILD_DIR} ${GO} build ${BUILD_FLAGS} -o ${BUILD_DIR}/whisper-cli ./cmd/cli
4141

cmd/cli/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
tablewriter "github.com/djthorpe/go-tablewriter"
1212
client "github.com/mutablelogic/go-client"
1313
ctx "github.com/mutablelogic/go-server/pkg/context"
14-
api "github.com/mutablelogic/go-whisper/pkg/client"
14+
api "github.com/mutablelogic/go-whisper/pkg/whisper/client"
1515
)
1616

1717
type Globals struct {

cmd/cli/transcribe.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@ package main
22

33
import (
44
"os"
5+
"time"
56

67
"github.com/djthorpe/go-tablewriter"
7-
"github.com/mutablelogic/go-whisper/pkg/client"
8+
"github.com/mutablelogic/go-whisper/pkg/whisper/client"
89
)
910

1011
type TranscribeCmd struct {
11-
Model string `arg:"" required:"" help:"Model Identifier" type:"string"`
12-
Path string `arg:"" required:"" help:"Audio File Path" type:"string"`
13-
Language string `flag:"language" help:"Source Language" type:"string"`
14-
Prompt string `flag:"prompt" help:"Initial Prompt Identifier" type:"string"`
15-
Temperature *float32 `flag:"temperature" help:"Temperature" type:"float32"`
12+
Model string `arg:"" required:"" help:"Model Identifier" type:"string"`
13+
Path string `arg:"" required:"" help:"Audio File Path" type:"string"`
14+
Language string `flag:"language" help:"Source Language" type:"string"`
15+
SegmentSize *time.Duration `flag:"segment-size" help:"Segment Size" type:"duration"`
16+
ResponseFmt *string `flag:"format" help:"Response Format" enum:"json,verbose_json,text,vtt,srt"`
1617
}
1718

1819
func (cmd *TranscribeCmd) Run(ctx *Globals) error {
@@ -26,11 +27,11 @@ func (cmd *TranscribeCmd) Run(ctx *Globals) error {
2627
if cmd.Language != "" {
2728
opts = append(opts, client.OptLanguage(cmd.Language))
2829
}
29-
if cmd.Prompt != "" {
30-
opts = append(opts, client.OptPrompt(cmd.Prompt))
30+
if cmd.SegmentSize != nil {
31+
opts = append(opts, client.OptSegmentSize(*cmd.SegmentSize))
3132
}
32-
if cmd.Temperature != nil {
33-
opts = append(opts, client.OptTemperature(*cmd.Temperature))
33+
if cmd.ResponseFmt != nil {
34+
opts = append(opts, client.OptResponseFormat(*cmd.ResponseFmt))
3435
}
3536

3637
transcription, err := ctx.api.Transcribe(ctx.ctx, cmd.Model, r, opts...)

cmd/cli/translate.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@ package main
22

33
import (
44
"os"
5+
"time"
56

67
"github.com/djthorpe/go-tablewriter"
7-
"github.com/mutablelogic/go-whisper/pkg/client"
8+
"github.com/mutablelogic/go-whisper/pkg/whisper/client"
89
)
910

1011
type TranslateCmd struct {
11-
Model string `arg:"" required:"" help:"Model Identifier" type:"string"`
12-
Path string `arg:"" required:"" help:"Audio File Path" type:"string"`
13-
Language string `flag:"language" required:"" help:"Target Language" type:"string"`
14-
Prompt string `flag:"prompt" help:"Initial Prompt Identifier" type:"string"`
15-
Temperature *float32 `flag:"temperature" help:"Temperature" type:"float32"`
12+
Model string `arg:"" required:"" help:"Model Identifier" type:"string"`
13+
Path string `arg:"" required:"" help:"Audio File Path" type:"string"`
14+
Language string `flag:"language" required:"" help:"Target Language" type:"string"`
15+
SegmentSize *time.Duration `flag:"segment-size" help:"Segment Size" type:"duration"`
16+
ResponseFmt *string `flag:"format" help:"Response Format" enum:"json,verbose_json,text,vtt,srt"`
1617
}
1718

1819
func (cmd *TranslateCmd) Run(ctx *Globals) error {
@@ -26,11 +27,11 @@ func (cmd *TranslateCmd) Run(ctx *Globals) error {
2627
if cmd.Language != "" {
2728
opts = append(opts, client.OptLanguage(cmd.Language))
2829
}
29-
if cmd.Prompt != "" {
30-
opts = append(opts, client.OptPrompt(cmd.Prompt))
30+
if cmd.SegmentSize != nil {
31+
opts = append(opts, client.OptSegmentSize(*cmd.SegmentSize))
3132
}
32-
if cmd.Temperature != nil {
33-
opts = append(opts, client.OptTemperature(*cmd.Temperature))
33+
if cmd.ResponseFmt != nil {
34+
opts = append(opts, client.OptResponseFormat(*cmd.ResponseFmt))
3435
}
3536

3637
transcription, err := ctx.api.Translate(ctx.ctx, cmd.Model, r, opts...)

pkg/whisper/api/transcribe.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
100100
// Read samples and transcribe them
101101
if err := segmenter.Decode(ctx, func(ts time.Duration, buf []float32) error {
102102
// Perform the transcription, return any errors
103-
return task.Transcribe(ctx, ts, buf, func(segment *transcription.Segment) {
103+
return task.Transcribe(ctx, ts, buf, req.OutputSegments(), func(segment *transcription.Segment) {
104104
fmt.Println("TODO: ", segment)
105105
})
106106
}); err != nil {
@@ -117,7 +117,11 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
117117
return
118118
}
119119

120-
// Set duration
120+
// Set task, duration
121+
result.Task = "transcribe"
122+
if translate {
123+
result.Task = "translate"
124+
}
121125
result.Duration = segmenter.Duration()
122126

123127
// Return transcription
@@ -151,6 +155,16 @@ func (r reqTranscribe) ResponseFormat() string {
151155
return *r.ResponseFmt
152156
}
153157

158+
func (r reqTranscribe) OutputSegments() bool {
159+
// We want to output segments if the response format is "srt", "verbose_json", "vtt"
160+
switch r.ResponseFormat() {
161+
case "srt", "verbose_json", "vtt":
162+
return true
163+
default:
164+
return false
165+
}
166+
}
167+
154168
func (r reqTranscribe) SegmentDur() time.Duration {
155169
if r.SegmentSize == nil {
156170
return defaultSegmentSize

pkg/client/client.go renamed to pkg/whisper/client/client.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ import (
1111
"github.com/mutablelogic/go-client"
1212
"github.com/mutablelogic/go-client/pkg/multipart"
1313
"github.com/mutablelogic/go-server/pkg/httprequest"
14-
"github.com/mutablelogic/go-whisper/pkg/whisper"
1514
"github.com/mutablelogic/go-whisper/pkg/whisper/model"
15+
"github.com/mutablelogic/go-whisper/pkg/whisper/transcription"
1616
)
1717

1818
///////////////////////////////////////////////////////////////////////////////
@@ -94,13 +94,13 @@ func (c *Client) DownloadModel(ctx context.Context, path string, fn func(status
9494
return r.Model, nil
9595
}
9696

97-
func (c *Client) Transcribe(ctx context.Context, model string, r io.Reader, opt ...Opt) (*whisper.Transcription, error) {
97+
func (c *Client) Transcribe(ctx context.Context, model string, r io.Reader, opt ...Opt) (*transcription.Transcription, error) {
9898
var request struct {
9999
File multipart.File `json:"file"`
100100
Model string `json:"model"`
101101
opts
102102
}
103-
var response whisper.Transcription
103+
var response transcription.Transcription
104104

105105
// Get the name from the io.Reader
106106
name := ""
@@ -131,13 +131,13 @@ func (c *Client) Transcribe(ctx context.Context, model string, r io.Reader, opt
131131
return &response, nil
132132
}
133133

134-
func (c *Client) Translate(ctx context.Context, model string, r io.Reader, opt ...Opt) (*whisper.Transcription, error) {
134+
func (c *Client) Translate(ctx context.Context, model string, r io.Reader, opt ...Opt) (*transcription.Transcription, error) {
135135
var request struct {
136136
File multipart.File `json:"file"`
137137
Model string `json:"model"`
138138
opts
139139
}
140-
var response whisper.Transcription
140+
var response transcription.Transcription
141141

142142
// Get the name from the io.Reader
143143
name := ""
Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package client
22

3+
import "time"
4+
35
// Request options
46
type opts struct {
5-
Language string `json:"language,omitempty"`
6-
Prompt string `json:"prompt,omitempty"`
7-
Temperature float32 `json:"temperature,omitempty"`
7+
Language string `json:"language,omitempty"`
8+
SegmentSize time.Duration `json:"segment_size,omitempty"`
9+
ResponseFmt string `json:"response_format,omitempty"`
810
}
911

1012
type Opt func(*opts) error
@@ -19,16 +21,16 @@ func OptLanguage(language string) Opt {
1921
}
2022
}
2123

22-
func OptPrompt(prompt string) Opt {
24+
func OptSegmentSize(v time.Duration) Opt {
2325
return func(o *opts) error {
24-
o.Prompt = prompt
26+
o.SegmentSize = v
2527
return nil
2628
}
2729
}
2830

29-
func OptTemperature(t float32) Opt {
31+
func OptResponseFormat(v string) Opt {
3032
return func(o *opts) error {
31-
o.Temperature = t
33+
o.ResponseFmt = v
3234
return nil
3335
}
3436
}

pkg/whisper/task/context.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,11 @@ func (ctx *Context) Is(model *model.Model) bool {
140140
return ctx.model == model.Id
141141
}
142142

143-
// Copy task parameters from the default
143+
// Reset task context for re-use
144144
func (task *Context) CopyParams() {
145145
task.params = whisper.DefaultFullParams(whisper.SAMPLING_GREEDY)
146146
task.params.SetLanguage("auto")
147+
task.result = nil
147148
}
148149

149150
// Model is multilingual and can translate
@@ -152,8 +153,9 @@ func (task *Context) CanTranslate() bool {
152153
}
153154

154155
// Transcribe samples. The samples should be 16KHz float32 samples in
155-
// a single channel.
156-
func (task *Context) Transcribe(ctx context.Context, ts time.Duration, samples []float32, fn NewSegmentFunc) error {
156+
// a single channel. Appends the transcription to the result, and includes
157+
// segment data if segments is true.
158+
func (task *Context) Transcribe(ctx context.Context, ts time.Duration, samples []float32, segments bool, fn NewSegmentFunc) error {
157159
// Set the 'abort' function
158160
task.params.SetAbortCallback(task.whisper, func() bool {
159161
select {
@@ -193,7 +195,7 @@ func (task *Context) Transcribe(ctx context.Context, ts time.Duration, samples [
193195
if task.result == nil {
194196
task.result = transcription.New()
195197
}
196-
task.result.Append(task.whisper, ts)
198+
task.result.Append(task.whisper, ts, segments)
197199

198200
// Return success
199201
return nil

pkg/whisper/task/context.go_old

Lines changed: 0 additions & 114 deletions
This file was deleted.

0 commit comments

Comments
 (0)