Skip to content

Commit ad72e7a

Browse files
committed
Added OpenAI speech
1 parent c928835 commit ad72e7a

File tree

10 files changed

+402
-155
lines changed

10 files changed

+402
-155
lines changed

cmd/cli/flags.go

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"os"
77
"path/filepath"
88
"strconv"
9+
"strings"
910
"time"
1011

1112
// Packages
@@ -65,13 +66,24 @@ func (flags *Flags) Timeout() time.Duration {
6566
}
6667

6768
func (flags *Flags) GetOut() string {
68-
v, _ := flags.GetString("out")
69-
return v
69+
return flags.GetString("out")
70+
}
71+
72+
func (flags *Flags) GetOutExt() string {
73+
out := flags.GetOut()
74+
if out == "" {
75+
return ""
76+
}
77+
if ext := filepath.Ext(out); ext == "" {
78+
return out
79+
} else {
80+
return strings.TrimPrefix(ext, ".")
81+
}
7082
}
7183

7284
// Return a filename for output, returns an empty string if the output
7385
// argument is not a filename (it requires an extension)
74-
func (flags *Flags) GetOutFilename(def string, n uint) string {
86+
func (flags *Flags) GetOutFilename(def string, n int) string {
7587
filename := flags.GetOut()
7688
if filename == "" {
7789
filename = filepath.Base(def)
@@ -91,11 +103,11 @@ func (flags *Flags) GetOutFilename(def string, n uint) string {
91103
return filepath.Clean(filename)
92104
}
93105

94-
func (flags *Flags) GetString(key string) (string, error) {
106+
func (flags *Flags) GetString(key string) string {
95107
if flag := flags.Lookup(key); flag == nil {
96-
return "", ErrNotFound.With(key)
108+
return ""
97109
} else {
98-
return os.ExpandEnv(flag.Value.String()), nil
110+
return os.ExpandEnv(flag.Value.String())
99111
}
100112
}
101113

@@ -109,6 +121,26 @@ func (flags *Flags) GetUint(key string) (uint, error) {
109121
}
110122
}
111123

124+
func (flags *Flags) GetInt(key string) (int, error) {
125+
if flag := flags.Lookup(key); flag == nil {
126+
return 0, ErrNotFound.With(key)
127+
} else if v, err := strconv.ParseInt(os.ExpandEnv(flag.Value.String()), 10, 64); err != nil {
128+
return 0, ErrBadParameter.With(key)
129+
} else {
130+
return int(v), nil
131+
}
132+
}
133+
134+
func (flags *Flags) GetBool(key string) bool {
135+
if flag := flags.Lookup(key); flag == nil {
136+
return false
137+
} else if v, err := strconv.ParseBool(os.ExpandEnv(flag.Value.String())); err != nil {
138+
return false
139+
} else {
140+
return v
141+
}
142+
}
143+
112144
func (flags *Flags) Write(v any) error {
113145
opts := []tablewriter.TableOpt{}
114146

cmd/cli/homeassistant.go

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,8 @@ func HomeAssistantFlags(flags *Flags) {
1515
}
1616

1717
func HomeAssistantRegister(cmd []Client, opts []client.ClientOpt, flags *Flags) ([]Client, error) {
18-
// Get API key
19-
key, err := flags.GetString("ha-token")
20-
if err != nil {
21-
return nil, err
22-
}
23-
24-
// Get endpoint
25-
endPoint, err := flags.GetString("ha-endpoint")
26-
if err != nil {
27-
return nil, err
28-
}
29-
30-
// Create ipify client
31-
ha, err := homeassistant.New(endPoint, key, opts...)
18+
// Create home assistant client
19+
ha, err := homeassistant.New(flags.GetString("ha-endpoint"), flags.GetString("ha-token"), opts...)
3220
if err != nil {
3321
return nil, err
3422
}

cmd/cli/open.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package main
2+
3+
import (
4+
"os/exec"
5+
"runtime"
6+
)
7+
8+
// open asks the operating system to open each of the given files or URLs with
// its default application. The helper processes are started but not waited
// for. It returns nil without starting anything when no arguments are given,
// and stops at the first helper that fails to start.
func open(url ...string) error {
	if len(url) == 0 {
		return nil
	}

	var cmd string
	var args []string
	batch := false

	switch runtime.GOOS {
	case "windows":
		cmd = "cmd"
		// The empty argument is the window title: without it, "start" treats
		// a quoted path as the title rather than as the thing to open
		args = []string{"/c", "start", ""}
	case "darwin":
		// "open" accepts any number of files in a single invocation
		cmd = "open"
		batch = true
	default: // "linux", "freebsd", "openbsd", "netbsd"
		cmd = "xdg-open"
	}

	if batch {
		return exec.Command(cmd, append(args, url...)...).Start()
	}

	// "start" and "xdg-open" open a single target, so launch one per argument
	for _, target := range url {
		argv := append(append([]string{}, args...), target)
		if err := exec.Command(cmd, argv...).Start(); err != nil {
			return err
		}
	}
	return nil
}

cmd/cli/openai.go

Lines changed: 178 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,51 @@
11
package main
22

33
import (
4+
"errors"
5+
"fmt"
6+
"net/url"
7+
"os"
8+
"path/filepath"
9+
"regexp"
10+
"strconv"
11+
412
"github.com/mutablelogic/go-client/pkg/client"
513
"github.com/mutablelogic/go-client/pkg/openai"
614
)
715

16+
/////////////////////////////////////////////////////////////////////
17+
// TYPES
18+
19+
// openaiImageResponse describes one generated image that has been written to
// a local file.
type openaiImageResponse struct {
	// Url is the address the image came from; it is excluded from JSON output
	Url string `json:"-"`

	// Path is the name of the file the image was written to
	Path string `json:"path"`

	// Bytes is the count reported by the image writer for this file
	Bytes uint `json:"bytes_written"`
}
24+
25+
/////////////////////////////////////////////////////////////////////
26+
// GLOBALS
27+
28+
var (
	// reOpenAISize matches an image size written as <width>x<height>, where
	// both parts are decimal digits, capturing the width and the height
	reOpenAISize = regexp.MustCompile(`^(\d+)x(\d+)$`)

	// defaultVoice is the voice used for speech when none is given
	defaultVoice = "alloy"
)
32+
833
/////////////////////////////////////////////////////////////////////
934
// REGISTER FUNCTIONS
1035

1136
func OpenAIFlags(flags *Flags) {
1237
flags.String("openai-api-key", "${OPENAI_API_KEY}", "OpenAI API key")
13-
flags.String("openai-model", "", "OpenAI Model")
14-
flags.Uint("openai-count", 0, "Number of results to return")
38+
flags.String("model", "", "Model to use for generation")
39+
flags.Uint("count", 0, "Number of results to return")
40+
flags.Bool("natural", false, "Create more natural images")
41+
flags.Bool("hd", false, "Create images with finer details and greater consistency across the image")
42+
flags.String("size", "", "Size of output image (256x256, 512x512, 1024x1024, 1792x1024 or 1024x1792)")
43+
flags.Bool("open", false, "Open images in default viewer")
1544
}
1645

1746
func OpenAIRegister(cmd []Client, opts []client.ClientOpt, flags *Flags) ([]Client, error) {
18-
// Get API key
19-
key, err := flags.GetString("openai-api-key")
20-
if err != nil {
21-
return nil, err
22-
}
23-
2447
// Create client
25-
openai, err := openai.New(key, opts...)
48+
openai, err := openai.New(flags.GetString("openai-api-key"), opts...)
2649
if err != nil {
2750
return nil, err
2851
}
@@ -32,6 +55,9 @@ func OpenAIRegister(cmd []Client, opts []client.ClientOpt, flags *Flags) ([]Clie
3255
ns: "openai",
3356
cmd: []Command{
3457
{Name: "models", Description: "Return registered models", MinArgs: 2, MaxArgs: 2, Fn: openaiModels(openai, flags)},
58+
{Name: "model", Description: "Return model information", Syntax: "<model>", MinArgs: 3, MaxArgs: 3, Fn: openaiModel(openai, flags)},
59+
{Name: "image", Description: "Create images from a prompt", Syntax: "<prompt>", MinArgs: 3, MaxArgs: 3, Fn: openaiImages(openai, flags)},
60+
{Name: "speak", Description: "Create speech from a prompt", Syntax: "(<voice>) <prompt>", MinArgs: 3, MaxArgs: 4, Fn: openaiSpeak(openai, flags)},
3561
},
3662
})
3763

@@ -52,3 +78,146 @@ func openaiModels(client *openai.Client, flags *Flags) CommandFn {
5278
return nil
5379
}
5480
}
81+
82+
func openaiModel(client *openai.Client, flags *Flags) CommandFn {
83+
return func() error {
84+
if model, err := client.GetModel(flags.Arg(2)); err != nil {
85+
return err
86+
} else if err := flags.Write(model); err != nil {
87+
return err
88+
}
89+
return nil
90+
}
91+
}
92+
93+
func openaiSpeak(client *openai.Client, flags *Flags) CommandFn {
94+
return func() error {
95+
// Set options
96+
opts := []openai.Opt{}
97+
if model := flags.GetString("model"); model != "" {
98+
opts = append(opts, openai.OptModel(model))
99+
}
100+
if format := flags.GetOutExt(); format != "" {
101+
opts = append(opts, openai.OptResponseFormat(format))
102+
}
103+
var voice, prompt string
104+
if flags.NArg() == 4 {
105+
voice = flags.Arg(2)
106+
prompt = flags.Arg(3)
107+
} else {
108+
voice = defaultVoice
109+
prompt = flags.Arg(2)
110+
}
111+
112+
// Determine the filename
113+
w, err := os.Create("output.mp3")
114+
if err != nil {
115+
return err
116+
}
117+
defer w.Close()
118+
119+
// Create the audio
120+
response, err := client.Speech(w, voice, prompt, opts...)
121+
if err != nil {
122+
return err
123+
}
124+
125+
// Open images
126+
if flags.GetBool("open") {
127+
if err := open("output.mp3"); err != nil {
128+
return err
129+
}
130+
}
131+
132+
fmt.Println(response, "bytes written")
133+
134+
// Return any errors
135+
return nil
136+
}
137+
}
138+
139+
func openaiImages(client *openai.Client, flags *Flags) CommandFn {
140+
return func() error {
141+
// Set options
142+
opts := []openai.Opt{}
143+
if model := flags.GetString("model"); model != "" {
144+
opts = append(opts, openai.OptModel(model))
145+
}
146+
if count, err := flags.GetInt("count"); err != nil {
147+
return err
148+
} else if count > 0 {
149+
opts = append(opts, openai.OptCount(count))
150+
}
151+
if flags.GetBool("hd") {
152+
opts = append(opts, openai.OptQuality("hd"), openai.OptModel("dall-e-3"))
153+
}
154+
if flags.GetBool("natural") {
155+
opts = append(opts, openai.OptStyle("natural"))
156+
}
157+
if size := flags.GetString("size"); size != "" {
158+
if width, height, err := openaiSize(size); err != nil {
159+
return err
160+
} else {
161+
opts = append(opts, openai.OptSize(width, height))
162+
}
163+
}
164+
if format := flags.GetOutExt(); format != "" {
165+
opts = append(opts, openai.OptResponseFormat(format))
166+
}
167+
168+
// Create images
169+
response, err := client.CreateImages(flags.Arg(2), opts...)
170+
if err != nil {
171+
return err
172+
}
173+
174+
// Write out images
175+
var result error
176+
var written []openaiImageResponse
177+
for _, image := range response {
178+
if url, err := url.Parse(image.Url); err != nil {
179+
result = errors.Join(result, err)
180+
} else if w, err := os.Create(filepath.Base(url.Path)); err != nil {
181+
result = errors.Join(result, err)
182+
} else {
183+
defer w.Close()
184+
if n, err := client.WriteImage(w, image); err != nil {
185+
result = errors.Join(result, err)
186+
} else {
187+
written = append(written, openaiImageResponse{Url: image.Url, Bytes: uint(n), Path: w.Name()})
188+
}
189+
}
190+
}
191+
192+
// Open images
193+
if flags.GetBool("open") {
194+
var paths []string
195+
for _, image := range written {
196+
paths = append(paths, image.Path)
197+
}
198+
if err := open(paths...); err != nil {
199+
result = errors.Join(result, err)
200+
}
201+
} else if err := flags.Write(written); err != nil {
202+
result = errors.Join(result, err)
203+
}
204+
205+
// Return any errors
206+
return result
207+
}
208+
}
209+
210+
/////////////////////////////////////////////////////////////////////
211+
// PRIVATE METHODS
212+
213+
func openaiSize(size string) (uint, uint, error) {
214+
if n := reOpenAISize.FindStringSubmatch(size); n == nil || len(n) != 3 {
215+
return 0, 0, errors.New("invalid size, should be <width>x<height>")
216+
} else if w, err := strconv.ParseUint(n[1], 10, 64); err != nil {
217+
return 0, 0, err
218+
} else if h, err := strconv.ParseUint(n[2], 10, 64); err != nil {
219+
return 0, 0, err
220+
} else {
221+
return uint(w), uint(h), nil
222+
}
223+
}

0 commit comments

Comments
 (0)