Skip to content

Commit e41a991

Browse files
committed
Updated
1 parent 1853a41 commit e41a991

File tree

3 files changed

+100
-9
lines changed

3 files changed

+100
-9
lines changed

cmd/cli/flags.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func NewFlags(name string, args []string, register ...FlagsRegister) (*Flags, er
3636
// Register flags
3737
flags.Bool("debug", false, "Enable debug logging")
3838
flags.Duration("timeout", 0, "Timeout")
39-
flags.String("out", "txt", "Output format <txt|csv|tsv|json> or file name <filename>.<txt|csv|tsv|json>")
39+
flags.String("out", "", "Output format or file name")
4040
flags.String("cols", "", "Comma-separated list of columns to output")
4141
for _, fn := range register {
4242
fn(flags)
@@ -141,6 +141,16 @@ func (flags *Flags) GetBool(key string) bool {
141141
}
142142
}
143143

144+
// GetFloat64 returns the value of the flag named key parsed as a float64.
//
// The flag's string value is passed through os.ExpandEnv before it is parsed,
// so environment variable references in the value are expanded first. The
// result is nil when flags.Lookup returns nil for key, and also when the
// expanded value cannot be parsed as a 64-bit float; callers cannot tell these
// two cases apart. Otherwise a pointer to the parsed value is returned.
func (flags *Flags) GetFloat64(key string) *float64 {
145+
if flag := flags.Lookup(key); flag == nil {
146+
return nil // Lookup returned no flag for key
147+
} else if v, err := strconv.ParseFloat(os.ExpandEnv(flag.Value.String()), 64); err != nil {
148+
return nil // the parse error is discarded, not reported to the caller
149+
} else {
150+
return &v
151+
}
152+
}
153+
144154
func (flags *Flags) Write(v any) error {
145155
opts := []tablewriter.TableOpt{}
146156

cmd/cli/openai.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ func OpenAIFlags(flags *Flags) {
4242
flags.String("size", "", "Size of output image (256x256, 512x512, 1024x1024, 1792x1024 or 1024x1792)")
4343
flags.Bool("open", false, "Open images in default viewer")
4444
flags.String("language", "", "Audio language")
45+
flags.String("prompt", "", "Text to guide the transcription style or continue a previous audio segment")
46+
flags.Float64("temperature", 0, "Sampling temperature for generation")
4547
}
4648

4749
func OpenAIRegister(cmd []Client, opts []client.ClientOpt, flags *Flags) ([]Client, error) {
@@ -60,6 +62,7 @@ func OpenAIRegister(cmd []Client, opts []client.ClientOpt, flags *Flags) ([]Clie
6062
{Name: "image", Description: "Create images from a prompt", Syntax: "<prompt>", MinArgs: 3, MaxArgs: 3, Fn: openaiImages(openai, flags)},
6163
{Name: "speak", Description: "Create speech from a prompt", Syntax: "(<voice>) <prompt>", MinArgs: 3, MaxArgs: 4, Fn: openaiSpeak(openai, flags)},
6264
{Name: "transcribe", Description: "Transcribe audio to text", Syntax: "<filename>", MinArgs: 3, MaxArgs: 3, Fn: openaiTranscribe(openai, flags)},
65+
{Name: "translate", Description: "Translate audio to English", Syntax: "<filename>", MinArgs: 3, MaxArgs: 3, Fn: openaiTranslate(openai, flags)},
6366
},
6467
})
6568

@@ -99,9 +102,51 @@ func openaiTranscribe(client *openai.Client, flags *Flags) CommandFn {
99102
if model := flags.GetString("model"); model != "" {
100103
opts = append(opts, openai.OptModel(model))
101104
}
105+
if prompt := flags.GetString("prompt"); prompt != "" {
106+
opts = append(opts, openai.OptPrompt(prompt))
107+
}
102108
if language := flags.GetString("language"); language != "" {
103109
opts = append(opts, openai.OptLanguage(language))
104110
}
111+
if temp := flags.GetFloat64("temperature"); temp != nil && *temp > 0 {
112+
opts = append(opts, openai.OptTemperature(*temp))
113+
}
114+
if format := flags.GetOutExt(); format != "" {
115+
opts = append(opts, openai.OptResponseFormat(format))
116+
}
117+
118+
// Open audio file for reading
119+
r, err := os.Open(flags.Arg(2))
120+
if err != nil {
121+
return err
122+
}
123+
defer r.Close()
124+
125+
// Perform transcription
126+
if transcription, err := client.Transcribe(r, opts...); err != nil {
127+
return err
128+
} else if err := flags.Write(transcription); err != nil {
129+
return err
130+
}
131+
132+
// Return success
133+
return nil
134+
}
135+
}
136+
137+
func openaiTranslate(client *openai.Client, flags *Flags) CommandFn {
138+
return func() error {
139+
// Set options
140+
opts := []openai.Opt{}
141+
if model := flags.GetString("model"); model != "" {
142+
opts = append(opts, openai.OptModel(model))
143+
}
144+
if prompt := flags.GetString("prompt"); prompt != "" {
145+
opts = append(opts, openai.OptPrompt(prompt))
146+
}
147+
if temp := flags.GetFloat64("temperature"); temp != nil && *temp > 0 {
148+
opts = append(opts, openai.OptTemperature(*temp))
149+
}
105150
if format := flags.GetOutExt(); format != "" {
106151
opts = append(opts, openai.OptResponseFormat(format))
107152
}

pkg/openai/audio.go

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,28 @@ type reqTranscribe struct {
3939
TimestampGranularities []string `json:"timestamp_granularities,omitempty"`
4040
}
4141

42+
// Transcription represents a transcription response returned by the model, based on the provided input.
4243
type Transcription struct {
4344
Task string `json:"task,omitempty"` // NOTE(review): presumably the kind of task performed — confirm against the API reference.
44-
Language string `json:"language,omitempty"`
45-
Duration float64 `json:"duration,omitempty"`
45+
Language string `json:"language,omitempty"` // The language of the input audio.
46+
Duration float64 `json:"duration,omitempty"` // The duration of the input audio.
4647
Text string `json:"text"` // NOTE(review): presumably the full output text — confirm.
48+
Words []struct {
49+
Word string `json:"word"` // The text content of the word.
50+
Start float64 `json:"start"` // Start time of the word in seconds.
51+
End float64 `json:"end"` // End time of the word in seconds.
52+
} `json:"words,omitempty"` // Extracted words and their corresponding timestamps.
4753
Segments []struct {
4854
Id uint `json:"id"` // NOTE(review): presumably the segment identifier — confirm.
4955
Seek uint `json:"seek"` // NOTE(review): presumably the seek offset of the segment — confirm.
5056
Start float64 `json:"start"` // NOTE(review): presumably the segment start time in seconds, by analogy with Words — confirm.
5157
End float64 `json:"end"` // NOTE(review): presumably the segment end time in seconds, by analogy with Words — confirm.
5258
Text string `json:"text"` // NOTE(review): presumably the text content of the segment — confirm.
53-
Tokens []uint `json:"tokens"`
54-
Temperature float64 `json:"temperature,omitempty"`
55-
AvgLogProbability float64 `json:"avg_logprob,omitempty"`
56-
CompressionRatio float64 `json:"compression_ratio,omitempty"`
57-
NoSpeechProbability float64 `json:"no_speech_prob,omitempty"`
59+
Tokens []uint `json:"tokens"` // Array of token IDs for the text content.
60+
Temperature float64 `json:"temperature,omitempty"` // Temperature parameter used for generating the segment.
61+
AvgLogProbability float64 `json:"avg_logprob,omitempty"` // Average logprob of the segment. If the value is lower than -1, consider the logprobs failed.
62+
CompressionRatio float64 `json:"compression_ratio,omitempty"` // Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed.
63+
NoSpeechProbability float64 `json:"no_speech_prob,omitempty"` // Probability of no speech in the segment. If the value is higher than 1.0 and the avg_logprob is below -1, consider this segment silent.
5864
} `json:"segments,omitempty"` // NOTE(review): presumably the segments of the output text with their details — confirm.
5965
}
6066

@@ -107,7 +113,7 @@ func (c *Client) Transcribe(r io.Reader, opts ...Opt) (*Transcription, error) {
107113
// Create the request and set up the response
108114
request.Model = defaultTranscribeModel
109115
request.File = multipart.File{
110-
Path: "output.mp3",
116+
Path: "output.mp3", // TODO: Change this
111117
Body: r,
112118
}
113119

@@ -129,6 +135,36 @@ func (c *Client) Transcribe(r io.Reader, opts ...Opt) (*Transcription, error) {
129135
return response, nil
130136
}
131137

138+
// Translate translates audio into English.
//
// The audio is read from r and uploaded as a multipart request to the
// "audio/translations" path. The request starts out with
// defaultTranscribeModel and every opt is then applied to it in order; the
// first opt that returns an error aborts the call with that error. On success
// the response populated by c.Do is returned; on any failure the result is nil
// together with the error.
139+
func (c *Client) Translate(r io.Reader, opts ...Opt) (*Transcription, error) {
140+
var request reqTranscribe
141+
response := new(Transcription)
142+
143+
// Create the request and set up the response
144+
request.Model = defaultTranscribeModel
145+
request.File = multipart.File{
146+
Path: "output.mp3", // TODO: Change this (the name is hard-coded and does not depend on r)
147+
Body: r,
148+
}
149+
150+
// Set options
151+
for _, opt := range opts {
152+
if err := opt(&request); err != nil {
153+
return nil, err
154+
}
155+
}
156+
157+
// Build the multipart payload from the request, then perform the call and fill in the response
158+
// NOTE(review): client.ContentTypeJson is presumably the accepted response content type — confirm
159+
if payload, err := client.NewMultipartRequest(request, client.ContentTypeJson); err != nil {
160+
return nil, err
161+
} else if err := c.Do(payload, response, client.OptPath("audio/translations")); err != nil {
162+
return nil, err
163+
}
164+
165+
// Return success
166+
return response, nil
167+
}
167+
132168
///////////////////////////////////////////////////////////////////////////////
133169
// Unmarshal
134170

0 commit comments

Comments
 (0)