Skip to content

Commit 5518fa7

Browse files
committed
Updated
1 parent 043e7e4 commit 5518fa7

File tree

5 files changed

+73
-6
lines changed

5 files changed

+73
-6
lines changed

cmd/cli/flags.go

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package main
22

33
import (
44
"flag"
5+
"fmt"
56
"os"
6-
"strings"
7+
"path/filepath"
8+
"strconv"
79
"time"
810

911
// Packages
@@ -31,7 +33,7 @@ func NewFlags(name string, args []string, register ...FlagsRegister) (*Flags, er
3133
// Register flags
3234
flags.Bool("debug", false, "Enable debug logging")
3335
flags.Duration("timeout", 0, "Timeout")
34-
flags.String("out", "", "Output format (text, csv, json) or file name (.txt, .csv, .tsv, .json)")
36+
flags.String("out", "txt", "Output format (txt, csv, tsv, json) or file name (.txt, .csv, .tsv, .json)")
3537
for _, fn := range register {
3638
fn(flags)
3739
}
@@ -61,7 +63,29 @@ func (flags *Flags) Timeout() time.Duration {
6163

6264
func (flags *Flags) GetOut() string {
6365
v, _ := flags.GetString("out")
64-
return strings.ToLower(v)
66+
return v
67+
}
68+
69+
// Return a filename for output, returns an empty string if the output
70+
// argument is not a filename (it requires an extension)
71+
func (flags *Flags) GetOutFilename(def string, n uint) string {
72+
filename := flags.GetOut()
73+
if filename == "" {
74+
filename = filepath.Base(def)
75+
}
76+
if filename == "" {
77+
return ""
78+
}
79+
ext := filepath.Ext(filename)
80+
if ext == "" {
81+
return ""
82+
}
83+
if n > 0 {
84+
filename = filename[:len(filename)-len(ext)] + "-" + fmt.Sprint(n) + ext
85+
} else {
86+
filename = filename[:len(filename)-len(ext)] + ext
87+
}
88+
return filepath.Clean(filename)
6589
}
6690

6791
func (flags *Flags) GetString(key string) (string, error) {
@@ -72,6 +96,16 @@ func (flags *Flags) GetString(key string) (string, error) {
7296
}
7397
}
7498

99+
func (flags *Flags) GetUint(key string) (uint, error) {
100+
if flag := flags.Lookup(key); flag == nil {
101+
return 0, errors.ErrNotFound.With(key)
102+
} else if v, err := strconv.ParseUint(os.ExpandEnv(flag.Value.String()), 10, 64); err != nil {
103+
return 0, errors.ErrBadParameter.With(key)
104+
} else {
105+
return uint(v), nil
106+
}
107+
}
108+
75109
func (flags *Flags) Write(v any) error {
76110
opts := []writer.TableOpt{}
77111

cmd/cli/openai.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package main
22

33
import (
44
// Packages
5+
6+
"fmt"
7+
58
"github.com/mutablelogic/go-client/pkg/client"
69
"github.com/mutablelogic/go-client/pkg/openai"
710

@@ -15,6 +18,7 @@ import (
1518
func OpenAIFlags(flags *Flags) {
1619
flags.String("openai-api-key", "${OPENAI_API_KEY}", "OpenAI API key")
1720
flags.String("openai-model", "", "OpenAI Model")
21+
flags.Uint("openai-count", 0, "Number of results to return")
1822
}
1923

2024
func OpenAIRegister(cmd []Client, opts []client.ClientOpt, flags *Flags) ([]Client, error) {
@@ -76,9 +80,16 @@ func openaiImage(client *openai.Client, flags *Flags) CommandFn {
7680
opts := []openai.ImageOpt{
7781
openai.OptImageModel("dall-e-3"),
7882
}
79-
if model, err := flags.GetString("openai-model"); err != nil && model != "" {
83+
if model, err := flags.GetString("openai-model"); err != nil {
84+
return err
85+
} else if model != "" {
8086
opts = append(opts, openai.OptImageModel(model))
8187
}
88+
if count, err := flags.GetUint("openai-count"); err != nil {
89+
return err
90+
} else if count > 0 {
91+
opts = append(opts, openai.OptImageCount(int(count)))
92+
}
8293

8394
// Call API
8495
prompt := flags.Arg(2)
@@ -90,10 +101,16 @@ func openaiImage(client *openai.Client, flags *Flags) CommandFn {
90101
}
91102

92103
// Write images out
93-
for _, image := range images {
94-
if _, err := image.Write(client, flags.Output()); err != nil {
104+
for i, image := range images {
105+
if filename, err := image.Filename(); err != nil {
95106
return err
107+
} else {
108+
filename := flags.GetOutFilename(filename, uint(i))
109+
fmt.Println(i, filename)
96110
}
111+
//if _, err := image.Write(client, flags.Output()); err != nil {
112+
// return err
113+
//}
97114
}
98115
// Return success
99116
return nil

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ require (
66
github.com/djthorpe/go-errors v1.0.3
77
github.com/pkg/errors v0.9.1
88
github.com/stretchr/testify v1.8.4
9+
github.com/veandco/go-sdl2 v0.4.36
910
golang.org/x/exp v0.0.0-20231127185646-65229373498e
1011
golang.org/x/term v0.15.0
1112
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
88
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
99
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
1010
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
11+
github.com/veandco/go-sdl2 v0.4.36 h1:Ltydev536rRQodmIrTWFZ3dRp5A+/6t5CYvbi4Kvia0=
12+
github.com/veandco/go-sdl2 v0.4.36/go.mod h1:OROqMhHD43nT4/i9crJukyVecjPNYYuCofep6SNiAjY=
1113
golang.org/x/exp v0.0.0-20231127185646-65229373498e h1:Gvh4YaCaXNs6dKTlfgismwWZKyjVZXwOPfIyUaqU3No=
1214
golang.org/x/exp v0.0.0-20231127185646-65229373498e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
1315
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=

pkg/openai/image.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"encoding/base64"
55
"io"
66
"net/http"
7+
"net/url"
8+
"path/filepath"
79

810
// Packages
911
"github.com/mutablelogic/go-client/pkg/client"
@@ -86,6 +88,17 @@ func (c *Client) ImageGenerate(prompt string, opts ...ImageOpt) ([]Image, error)
8688
return response.Data, nil
8789
}
8890

91+
// Return the filename from the image
92+
func (i Image) Filename() (string, error) {
93+
if i.Url == "" {
94+
return "", ErrBadParameter.With("Missing URL in image")
95+
} else if url, err := url.Parse(i.Url); err != nil {
96+
return "", err
97+
} else {
98+
return filepath.Base(url.Path), nil
99+
}
100+
}
101+
89102
// Write an image to a writer object and return the mimetype
90103
func (i Image) Write(c *Client, w io.Writer) (string, error) {
91104
var response imageWriter

0 commit comments

Comments
 (0)