Skip to content

Commit 6106611

Browse files
authored
Merge pull request #27 from mutablelogic/v1
Updated the mistral client
2 parents f96eb6f + 51f3c6b commit 6106611

22 files changed

+655
-346
lines changed

cmd/api/anthropic.go

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@ import (
1515
// GLOBALS
1616

1717
var (
18-
anthropicName = "claude"
19-
anthropicClient *anthropic.Client
18+
anthropicName = "claude"
19+
anthropicClient *anthropic.Client
20+
anthropicModel string
21+
anthropicTemperature *float64
22+
anthropicMaxTokens *uint64
23+
anthropicStream bool
2024
)
2125

2226
///////////////////////////////////////////////////////////////////////////////
@@ -46,6 +50,21 @@ func anthropicParse(flags *Flags, opts ...client.ClientOpt) error {
4650
anthropicClient = client
4751
}
4852

53+
// Get the command-line parameters
54+
anthropicModel = flags.GetString("model")
55+
if temp, err := flags.GetValue("temperature"); err == nil {
56+
t := temp.(float64)
57+
anthropicTemperature = &t
58+
}
59+
if maxtokens, err := flags.GetValue("max-tokens"); err == nil {
60+
t := maxtokens.(uint64)
61+
anthropicMaxTokens = &t
62+
}
63+
if stream, err := flags.GetValue("stream"); err == nil {
64+
t := stream.(bool)
65+
anthropicStream = t
66+
}
67+
4968
// Return success
5069
return nil
5170
}
@@ -54,7 +73,36 @@ func anthropicParse(flags *Flags, opts ...client.ClientOpt) error {
5473
// METHODS
5574

5675
func anthropicChat(ctx context.Context, w *tablewriter.Writer, args []string) error {
57-
// Request -> Response
76+
77+
// Set options
78+
opts := []anthropic.Opt{}
79+
if anthropicModel != "" {
80+
opts = append(opts, anthropic.OptModel(anthropicModel))
81+
}
82+
if anthropicTemperature != nil {
83+
opts = append(opts, anthropic.OptTemperature(float32(*anthropicTemperature)))
84+
}
85+
if anthropicMaxTokens != nil {
86+
opts = append(opts, anthropic.OptMaxTokens(int(*anthropicMaxTokens)))
87+
}
88+
if anthropicStream {
89+
opts = append(opts, anthropic.OptStream(func(choice schema.MessageChoice) {
90+
w := w.Output()
91+
if choice.Delta != nil {
92+
if choice.Delta.Role != "" {
93+
fmt.Fprintf(w, "\n%v: ", choice.Delta.Role)
94+
}
95+
if choice.Delta.Content != "" {
96+
fmt.Fprintf(w, "%v", choice.Delta.Content)
97+
}
98+
}
99+
if choice.FinishReason != "" {
100+
fmt.Printf("\nfinish_reason: %q\n", choice.FinishReason)
101+
}
102+
}))
103+
}
104+
105+
// Append user message
58106
message := schema.NewMessage("user")
59107
for _, arg := range args {
60108
message.Add(schema.Text(arg))
@@ -63,11 +111,15 @@ func anthropicChat(ctx context.Context, w *tablewriter.Writer, args []string) er
63111
// Request -> Response
64112
responses, err := anthropicClient.Messages(ctx, []*schema.Message{
65113
message,
66-
})
114+
}, opts...)
67115
if err != nil {
68116
return err
69117
}
70118

71-
// Write table
72-
return w.Write(responses)
119+
// Write table (if not streaming)
120+
if !anthropicStream {
121+
return w.Write(responses)
122+
} else {
123+
return nil
124+
}
73125
}

cmd/api/mistral.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func mistralChat(ctx context.Context, w *tablewriter.Writer, args []string) erro
131131
opts = append(opts, mistral.OptModel(mistralModel))
132132
}
133133
if mistralTemperature != nil {
134-
opts = append(opts, mistral.OptTemperature(*mistralTemperature))
134+
opts = append(opts, mistral.OptTemperature(float32(*mistralTemperature)))
135135
}
136136
if mistralMaxTokens != nil {
137137
opts = append(opts, mistral.OptMaxTokens(int(*mistralMaxTokens)))

cmd/api/samantha.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ func samChat(ctx context.Context, w *tablewriter.Writer, _ []string) error {
181181
}
182182
}
183183

184-
func samCall(_ context.Context, content schema.Content) *schema.Content {
184+
func samCall(_ context.Context, content *schema.Content) *schema.Content {
185185
anthropicClient.Debugf("%v: %v: %v", content.Type, content.Name, content.Input)
186186
if content.Type != "tool_use" {
187187
return schema.ToolResult(content.Id, fmt.Sprint("unexpected content type:", content.Type))

pkg/anthropic/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ const (
2222
endPoint = "https://api.anthropic.com/v1"
2323
defaultVersion = "2023-06-01"
2424
defaultMessageModel = "claude-3-haiku-20240307"
25-
defaultMaxTokens = 4096
25+
defaultMaxTokens = 1024
2626
)
2727

2828
///////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)