Skip to content

Commit 3785dfa

Browse files
committed
Updated client
1 parent f0bc1bb commit 3785dfa

File tree

10 files changed

+422
-816
lines changed

10 files changed

+422
-816
lines changed

cmd/api/flags.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ func (flags *Flags) PrintCommandFlags(cmd string) {
309309

310310
func printFlag(w io.Writer, f *flag.Flag) {
311311
fmt.Fprintf(w, " -%v", f.Name)
312-
if len(f.DefValue) > 0 {
312+
if len(f.DefValue) > 0 && f.DefValue != "false" && f.DefValue != "0" && f.DefValue != "0s" {
313313
fmt.Fprintf(w, " (default %q)", f.DefValue)
314314
}
315315
fmt.Fprintf(w, "\n %v\n\n", f.Usage)

cmd/api/openai.go

Lines changed: 165 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,38 @@ package main
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"strings"
68

79
// Packages
8-
"github.com/djthorpe/go-tablewriter"
10+
tablewriter "github.com/djthorpe/go-tablewriter"
911
client "github.com/mutablelogic/go-client"
1012
openai "github.com/mutablelogic/go-client/pkg/openai"
13+
"github.com/mutablelogic/go-client/pkg/openai/schema"
14+
15+
// Namespace imports
16+
. "github.com/djthorpe/go-errors"
1117
)
1218

1319
///////////////////////////////////////////////////////////////////////////////
1420
// GLOBALS
1521

1622
var (
17-
openaiName = "openai"
18-
openaiClient *openai.Client
23+
openaiName = "openai"
24+
openaiClient *openai.Client
25+
openaiModel string
26+
openaiQuality bool
27+
openaiResponseFormat string
28+
openaiStyle string
29+
openaiFrequencyPenalty *float64
30+
openaiPresencePenalty *float64
31+
openaiMaxTokens uint64
32+
openaiCount *uint64
33+
openaiStream bool
34+
openaiTemperature *float64
35+
openaiUser string
36+
openaiSystemPrompt string
1937
)
2038

2139
///////////////////////////////////////////////////////////////////////////////
@@ -24,15 +42,29 @@ var (
2442
func openaiRegister(flags *Flags) {
2543
// Register flags
2644
flags.String(openaiName, "openai-api-key", "${OPENAI_API_KEY}", "OpenAI API key")
45+
// TODO flags.String(openaiName, "model", "", "The model to use")
46+
// TODO flags.Unsigned(openaiName, "max-tokens", 0, "The maximum number of tokens that can be generated in the chat completion")
47+
flags.Bool(openaiName, "hd", false, "Create images with finer details and greater consistency across the image")
48+
flags.String(openaiName, "response-format", "", "The format in which the generated images are returned")
49+
flags.String(openaiName, "style", "", "The style of the generated images. Must be one of vivid or natural")
50+
flags.String(openaiName, "user", "", "A unique identifier representing your end-user")
51+
flags.Float(openaiName, "frequency-penalty", 0, "The model's likelihood to repeat the same line verbatim")
52+
flags.Float(openaiName, "presence-penalty", 0, "The model's likelihood to talk about new topics")
53+
flags.Unsigned(openaiName, "n", 0, "How many chat completion choices to generate for each input message")
54+
// TODO flags.String(openaiName, "system", "", "The system prompt")
55+
// TODO flags.Bool(openaiName, "stream", false, "If set, partial message deltas will be sent, like in ChatGPT")
56+
// TODO flags.Float(openaiName, "temperature", 0, "Sampling temperature to use, between 0.0 and 2.0")
2757

2858
// Register commands
2959
flags.Register(Cmd{
3060
Name: openaiName,
3161
Description: "Interact with OpenAI, from https://platform.openai.com/docs/api-reference",
3262
Parse: openaiParse,
3363
Fn: []Fn{
34-
{Name: "models", Call: openaiModels, Description: "Gets a list of available models"},
35-
{Name: "model", Call: openaiModel, Description: "Return model information", MinArgs: 1, MaxArgs: 1, Syntax: "<model>"},
64+
{Name: "models", Call: openaiListModels, Description: "Gets a list of available models"},
65+
{Name: "model", Call: openaiGetModel, Description: "Return model information", MinArgs: 1, MaxArgs: 1, Syntax: "<model>"},
66+
{Name: "image", Call: openaiImage, Description: "Create image from a prompt", MinArgs: 1, Syntax: "<prompt>"},
67+
{Name: "chat", Call: openaiChat, Description: "Create a chat completion", MinArgs: 1, Syntax: "<text>..."},
3668
},
3769
})
3870
}
@@ -47,25 +79,151 @@ func openaiParse(flags *Flags, opts ...client.ClientOpt) error {
4779
openaiClient = client
4880
}
4981

82+
// Set arguments
83+
openaiModel = flags.GetString("model")
84+
openaiQuality = flags.GetBool("hd")
85+
openaiResponseFormat = flags.GetString("response-format")
86+
openaiStyle = flags.GetString("style")
87+
openaiStream = flags.GetBool("stream")
88+
openaiUser = flags.GetString("user")
89+
openaiSystemPrompt = flags.GetString("system")
90+
91+
if temp, err := flags.GetValue("temperature"); err == nil {
92+
t := temp.(float64)
93+
openaiTemperature = &t
94+
}
95+
if value, err := flags.GetValue("frequency-penalty"); err == nil {
96+
v := value.(float64)
97+
openaiFrequencyPenalty = &v
98+
}
99+
if value, err := flags.GetValue("presence-penalty"); err == nil {
100+
v := value.(float64)
101+
openaiPresencePenalty = &v
102+
}
103+
if maxtokens, err := flags.GetValue("max-tokens"); err == nil {
104+
t := maxtokens.(uint64)
105+
openaiMaxTokens = t
106+
}
107+
if count, err := flags.GetValue("n"); err == nil {
108+
v := count.(uint64)
109+
openaiCount = &v
110+
}
111+
50112
// Return success
51113
return nil
52114
}
53115

54116
///////////////////////////////////////////////////////////////////////////////
55117
// METHODS
56118

57-
func openaiModels(ctx context.Context, w *tablewriter.Writer, args []string) error {
119+
func openaiListModels(ctx context.Context, w *tablewriter.Writer, args []string) error {
58120
models, err := openaiClient.ListModels()
59121
if err != nil {
60122
return err
61123
}
62124
return w.Write(models)
63125
}
64126

65-
func openaiModel(ctx context.Context, w *tablewriter.Writer, args []string) error {
127+
func openaiGetModel(ctx context.Context, w *tablewriter.Writer, args []string) error {
66128
model, err := openaiClient.GetModel(args[0])
67129
if err != nil {
68130
return err
69131
}
70132
return w.Write(model)
71133
}
134+
135+
func openaiImage(ctx context.Context, w *tablewriter.Writer, args []string) error {
136+
opts := []openai.Opt{}
137+
prompt := strings.Join(args, " ")
138+
139+
// Process options
140+
if openaiModel != "" {
141+
opts = append(opts, openai.OptModel(openaiModel))
142+
}
143+
if openaiQuality {
144+
opts = append(opts, openai.OptQuality("hd"))
145+
}
146+
if openaiResponseFormat != "" {
147+
opts = append(opts, openai.OptResponseFormat(openaiResponseFormat))
148+
}
149+
if openaiStyle != "" {
150+
opts = append(opts, openai.OptStyle(openaiStyle))
151+
}
152+
if openaiUser != "" {
153+
opts = append(opts, openai.OptUser(openaiUser))
154+
}
155+
156+
// Request->Response
157+
response, err := openaiClient.CreateImages(ctx, prompt, opts...)
158+
if err != nil {
159+
return err
160+
} else if len(response) == 0 {
161+
return ErrUnexpectedResponse.With("no images returned")
162+
}
163+
164+
// Write each image
165+
var result error
166+
for _, image := range response {
167+
if n, err := openaiClient.WriteImage(w.Output(), image); err != nil {
168+
result = errors.Join(result, err)
169+
} else {
170+
openaiClient.Debugf("openaiImage: wrote %v bytes", n)
171+
}
172+
}
173+
174+
// Return success
175+
return nil
176+
}
177+
178+
func openaiChat(ctx context.Context, w *tablewriter.Writer, args []string) error {
179+
var messages []*schema.Message
180+
181+
// Set options
182+
opts := []openai.Opt{}
183+
if openaiModel != "" {
184+
opts = append(opts, openai.OptModel(openaiModel))
185+
}
186+
if openaiFrequencyPenalty != nil {
187+
opts = append(opts, openai.OptFrequencyPenalty(float32(*openaiFrequencyPenalty)))
188+
}
189+
if openaiPresencePenalty != nil {
190+
opts = append(opts, openai.OptPresencePenalty(float32(*openaiPresencePenalty)))
191+
}
192+
if openaiTemperature != nil {
193+
opts = append(opts, openai.OptTemperature(float32(*openaiTemperature)))
194+
}
195+
if openaiMaxTokens != 0 {
196+
opts = append(opts, openai.OptMaxTokens(int(openaiMaxTokens)))
197+
}
198+
if openaiCount != nil && *openaiCount > 1 {
199+
opts = append(opts, openai.OptCount(int(*openaiCount)))
200+
}
201+
if openaiResponseFormat != "" {
202+
// TODO: Should be an object, not a string
203+
opts = append(opts, openai.OptResponseFormat(openaiResponseFormat))
204+
}
205+
if openaiStream {
206+
opts = append(opts, openai.OptStream())
207+
}
208+
if openaiUser != "" {
209+
opts = append(opts, openai.OptUser(openaiUser))
210+
}
211+
if openaiSystemPrompt != "" {
212+
messages = append(messages, schema.NewMessage("system").Add(schema.Text(openaiSystemPrompt)))
213+
}
214+
215+
// Append user message
216+
message := schema.NewMessage("user")
217+
for _, arg := range args {
218+
message.Add(schema.Text(arg))
219+
}
220+
messages = append(messages, message)
221+
222+
// Request->Response
223+
responses, err := openaiClient.Chat(ctx, messages, opts...)
224+
if err != nil {
225+
return err
226+
}
227+
228+
return w.Write(responses)
229+
}

0 commit comments

Comments
 (0)