Skip to content

Commit e967c37

Browse files
committed
Added audio samples and transcription
1 parent 3785dfa commit e967c37

File tree

6 files changed

+114
-8
lines changed

6 files changed

+114
-8
lines changed

cmd/api/openai.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"os"
78
"strings"
89

910
// Packages
@@ -34,6 +35,8 @@ var (
3435
openaiTemperature *float64
3536
openaiUser string
3637
openaiSystemPrompt string
38+
openaiPrompt string
39+
openaiLanguage string
3740
)
3841

3942
///////////////////////////////////////////////////////////////////////////////
@@ -54,6 +57,8 @@ func openaiRegister(flags *Flags) {
5457
// TODO flags.String(openaiName, "system", "", "The system prompt")
5558
// TODO flags.Bool(openaiName, "stream", false, "If set, partial message deltas will be sent, like in ChatGPT")
5659
// TODO flags.Float(openaiName, "temperature", 0, "Sampling temperature to use, between 0.0 and 2.0")
60+
flags.String(openaiName, "prompt", "", "An optional text to guide the model's style or continue a previous audio segment")
61+
//flags.String(openaiName, "language", "", "The language of the input audio in ISO-639-1 format")
5762

5863
// Register commands
5964
flags.Register(Cmd{
@@ -65,6 +70,8 @@ func openaiRegister(flags *Flags) {
6570
{Name: "model", Call: openaiGetModel, Description: "Return model information", MinArgs: 1, MaxArgs: 1, Syntax: "<model>"},
6671
{Name: "image", Call: openaiImage, Description: "Create image from a prompt", MinArgs: 1, Syntax: "<prompt>"},
6772
{Name: "chat", Call: openaiChat, Description: "Create a chat completion", MinArgs: 1, Syntax: "<text>..."},
73+
{Name: "transcribe", Call: openaiTranscribe, Description: "Transcribes audio into the input language", MinArgs: 1, MaxArgs: 1, Syntax: "<filename>"},
74+
{Name: "translate", Call: openaiTranslate, Description: "Translates audio into English", MinArgs: 1, MaxArgs: 1, Syntax: "<filename>"},
6875
},
6976
})
7077
}
@@ -87,6 +94,8 @@ func openaiParse(flags *Flags, opts ...client.ClientOpt) error {
8794
openaiStream = flags.GetBool("stream")
8895
openaiUser = flags.GetString("user")
8996
openaiSystemPrompt = flags.GetString("system")
97+
openaiPrompt = flags.GetString("prompt")
98+
openaiLanguage = flags.GetString("language")
9099

91100
if temp, err := flags.GetValue("temperature"); err == nil {
92101
t := temp.(float64)
@@ -227,3 +236,70 @@ func openaiChat(ctx context.Context, w *tablewriter.Writer, args []string) error
227236

228237
return w.Write(responses)
229238
}
239+
240+
func openaiTranscribe(ctx context.Context, w *tablewriter.Writer, args []string) error {
241+
opts := []openai.Opt{}
242+
if openaiModel != "" {
243+
opts = append(opts, openai.OptModel(openaiModel))
244+
}
245+
if openaiPrompt != "" {
246+
opts = append(opts, openai.OptPrompt(openaiPrompt))
247+
}
248+
if openaiLanguage != "" {
249+
opts = append(opts, openai.OptLanguage(openaiLanguage))
250+
}
251+
if openaiResponseFormat != "" {
252+
opts = append(opts, openai.OptResponseFormat(openaiResponseFormat))
253+
}
254+
if openaiTemperature != nil {
255+
opts = append(opts, openai.OptTemperature(float32(*openaiTemperature)))
256+
}
257+
258+
// Open audio file for reading
259+
r, err := os.Open(args[0])
260+
if err != nil {
261+
return err
262+
}
263+
defer r.Close()
264+
265+
// Perform transcription
266+
transcription, err := openaiClient.Transcribe(ctx, r, opts...)
267+
if err != nil {
268+
return err
269+
}
270+
271+
// Write output
272+
return w.Write(transcription)
273+
}
274+
275+
func openaiTranslate(ctx context.Context, w *tablewriter.Writer, args []string) error {
276+
opts := []openai.Opt{}
277+
if openaiModel != "" {
278+
opts = append(opts, openai.OptModel(openaiModel))
279+
}
280+
if openaiPrompt != "" {
281+
opts = append(opts, openai.OptPrompt(openaiPrompt))
282+
}
283+
if openaiResponseFormat != "" {
284+
opts = append(opts, openai.OptResponseFormat(openaiResponseFormat))
285+
}
286+
if openaiTemperature != nil {
287+
opts = append(opts, openai.OptTemperature(float32(*openaiTemperature)))
288+
}
289+
290+
// Open audio file for reading
291+
r, err := os.Open(args[0])
292+
if err != nil {
293+
return err
294+
}
295+
defer r.Close()
296+
297+
// Perform translation
298+
transcription, err := openaiClient.Translate(ctx, r, opts...)
299+
if err != nil {
300+
return err
301+
}
302+
303+
// Write output
304+
return w.Write(transcription)
305+
}

etc/test/Jean_de_La_Fontaine.mp3

1.04 MB
Binary file not shown.

etc/test/harvard.wav

3.1 MB
Binary file not shown.

image.jpg

3 MB
Loading

pkg/multipart/multipart.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,9 @@ func (enc *Encoder) writeField(name string, value any) error {
212212
// Check field index for a parent, which should be ignored
213213
func hasParentIndex(ignore [][]int, index []int) bool {
214214
for _, ignore := range ignore {
215+
if len(index) < len(ignore) {
216+
continue
217+
}
215218
if slices.Equal(ignore, index[:len(ignore)]) {
216219
return true
217220
}

pkg/openai/audio.go

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package openai
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"io"
8+
"os"
9+
"path/filepath"
710

811
// Packages
912
"github.com/mutablelogic/go-client"
@@ -37,7 +40,7 @@ type Transcription struct {
3740
Task string `json:"task,omitempty"`
3841
Language string `json:"language,omitempty"` // The language of the input audio.
3942
Duration float64 `json:"duration,omitempty"` // The duration of the input audio.
40-
Text string `json:"text"`
43+
Text string `json:"text" writer:",wrap,width:40"`
4144
Words []struct {
4245
Word string `json:"word"` // The text content of the word.
4346
Start float64 `json:"start"` // Start time of the word in seconds.
@@ -54,7 +57,7 @@ type Transcription struct {
5457
AvgLogProbability float64 `json:"avg_logprob,omitempty"` // Average logprob of the segment. If the value is lower than -1, consider the logprobs failed.
5558
CompressionRatio float64 `json:"compression_ratio,omitempty"` // Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed.
5659
NoSpeechProbability float64 `json:"no_speech_prob,omitempty"` // Probability of no speech in the segment. If the value is higher than 1.0 and the avg_logprob is below -1, consider this segment silent.
57-
} `json:"segments,omitempty"`
60+
} `json:"segments,omitempty" writer:",wrap"`
5861
}
5962

6063
///////////////////////////////////////////////////////////////////////////////
@@ -65,6 +68,14 @@ const (
6568
defaultTranscribeModel = "whisper-1"
6669
)
6770

71+
///////////////////////////////////////////////////////////////////////////////
72+
// STRINGIFY
73+
74+
func (r reqTranscribe) String() string {
75+
data, _ := json.MarshalIndent(r, "", " ")
76+
return string(data)
77+
}
78+
6879
///////////////////////////////////////////////////////////////////////////////
6980
// API CALLS
7081

@@ -99,14 +110,19 @@ func (c *Client) Speech(w io.Writer, voice, text string, opts ...Opt) (int64, er
99110
}
100111

101112
// Transcribe transcribes audio read from r
102-
func (c *Client) Transcribe(r io.Reader, opts ...Opt) (*Transcription, error) {
113+
func (c *Client) Transcribe(ctx context.Context, r io.Reader, opts ...Opt) (*Transcription, error) {
103114
var request reqTranscribe
104115
response := new(Transcription)
105116

117+
name := ""
118+
if f, ok := r.(*os.File); ok {
119+
name = filepath.Base(f.Name())
120+
}
121+
106122
// Create the request and set up the response
107123
request.Model = defaultTranscribeModel
108124
request.File = multipart.File{
109-
Path: "output.mp3", // TODO: Change this
125+
Path: name,
110126
Body: r,
111127
}
112128

@@ -117,10 +133,13 @@ func (c *Client) Transcribe(r io.Reader, opts ...Opt) (*Transcription, error) {
117133
}
118134
}
119135

136+
// Debugging
137+
c.Debugf("transcribe: %v", request)
138+
120139
// Make a response object, write the data
121140
if payload, err := client.NewMultipartRequest(request, client.ContentTypeJson); err != nil {
122141
return nil, err
123-
} else if err := c.Do(payload, response, client.OptPath("audio/transcriptions")); err != nil {
142+
} else if err := c.DoWithContext(ctx, payload, response, client.OptPath("audio/transcriptions")); err != nil {
124143
return nil, err
125144
}
126145

@@ -129,14 +148,19 @@ func (c *Client) Transcribe(r io.Reader, opts ...Opt) (*Transcription, error) {
129148
}
130149

131150
// Translate translates audio read from r into English
132-
func (c *Client) Translate(r io.Reader, opts ...Opt) (*Transcription, error) {
151+
func (c *Client) Translate(ctx context.Context, r io.Reader, opts ...Opt) (*Transcription, error) {
133152
var request reqTranscribe
134153
response := new(Transcription)
135154

155+
name := ""
156+
if f, ok := r.(*os.File); ok {
157+
name = filepath.Base(f.Name())
158+
}
159+
136160
// Create the request and set up the response
137161
request.Model = defaultTranscribeModel
138162
request.File = multipart.File{
139-
Path: "output.mp3", // TODO: Change this
163+
Path: name,
140164
Body: r,
141165
}
142166

@@ -147,10 +171,13 @@ func (c *Client) Translate(r io.Reader, opts ...Opt) (*Transcription, error) {
147171
}
148172
}
149173

174+
// Debugging
175+
c.Debugf("translate: %v", request)
176+
150177
// Make a response object, write the data
151178
if payload, err := client.NewMultipartRequest(request, client.ContentTypeJson); err != nil {
152179
return nil, err
153-
} else if err := c.Do(payload, response, client.OptPath("audio/translations")); err != nil {
180+
} else if err := c.DoWithContext(ctx, payload, response, client.OptPath("audio/translations")); err != nil {
154181
return nil, err
155182
}
156183

0 commit comments

Comments
 (0)