Skip to content

Commit 6b9e053

Browse files
authored
Merge pull request #21 from mutablelogic/v1
Added an assistant called samantha
2 parents aab9fd9 + 3dbae74 commit 6b9e053

File tree

4 files changed

+232
-6
lines changed

4 files changed

+232
-6
lines changed

cmd/api/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ func main() {
2222
anthropicRegister(flags)
2323
newsapiRegister(flags)
2424
weatherapiRegister(flags)
25+
samRegister(flags)
2526

2627
// Parse command line and return function to run
2728
fn, args, err := flags.Parse(os.Args[1:])

cmd/api/newsapi.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func newsapiRegister(flags *Flags) {
3535
flags.Register(Cmd{
3636
Name: newsapiName,
3737
Description: "Obtain news headlines from https://newsapi.org/",
38-
Parse: inewsapiParse,
38+
Parse: newsapiParse,
3939
Fn: []Fn{
4040
{Name: "sources", Call: newsapiSources, Description: "Return sources of news"},
4141
{Name: "headlines", Call: newsapiHeadlines, Description: "Return top headlines from news sources"},
@@ -44,7 +44,7 @@ func newsapiRegister(flags *Flags) {
4444
})
4545
}
4646

47-
func inewsapiParse(flags *Flags, opts ...client.ClientOpt) error {
47+
func newsapiParse(flags *Flags, opts ...client.ClientOpt) error {
4848
apiKey := flags.GetString("newsapi-key")
4949
if apiKey == "" {
5050
return fmt.Errorf("missing -newsapi-key flag")

cmd/api/samantha.go

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
package main
2+
3+
import (
4+
"bufio"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"os"
9+
"strings"
10+
11+
// Packages
12+
"github.com/djthorpe/go-tablewriter"
13+
"github.com/mutablelogic/go-client"
14+
"github.com/mutablelogic/go-client/pkg/anthropic"
15+
"github.com/mutablelogic/go-client/pkg/newsapi"
16+
"github.com/mutablelogic/go-client/pkg/openai/schema"
17+
)
18+
19+
///////////////////////////////////////////////////////////////////////////////
20+
// GLOBALS
21+
22+
var (
23+
samName = "sam"
24+
samWeatherTool = schema.NewTool("get_weather", "Get the weather for a location")
25+
samNewsHeadlinesTool = schema.NewTool("get_news_headlines", "Get the news headlines")
26+
samNewsSearchTool = schema.NewTool("search_news", "Search news articles")
27+
samSystemPrompt = `Your name is Samantha, you are a friendly AI assistant, here to help you with
28+
anything you need. Your responses should be short and to the point, and you should always be polite.`
29+
)
30+
31+
///////////////////////////////////////////////////////////////////////////////
32+
// LIFECYCLE
33+
34+
func samRegister(flags *Flags) {
35+
flags.Register(Cmd{
36+
Name: samName,
37+
Description: "Interact with Samantha, a friendly AI assistant, to query news and weather",
38+
Parse: samParse,
39+
Fn: []Fn{
40+
{Name: "chat", Call: samChat, Description: "Chat with Sam"},
41+
},
42+
})
43+
}
44+
45+
func samParse(flags *Flags, opts ...client.ClientOpt) error {
46+
// Initialize weather
47+
if err := weatherapiParse(flags, opts...); err != nil {
48+
return err
49+
}
50+
// Initialize news
51+
if err := newsapiParse(flags, opts...); err != nil {
52+
return err
53+
}
54+
55+
// Initialize anthropic
56+
opts = append(opts, client.OptHeader("Anthropic-Beta", "tools-2024-04-04"))
57+
if err := anthropicParse(flags, opts...); err != nil {
58+
return err
59+
}
60+
61+
// Add tool parameters
62+
if err := samWeatherTool.AddParameter("location", "The city to get the weather for", true); err != nil {
63+
return err
64+
}
65+
if err := samNewsHeadlinesTool.AddParameter("category", "The cateogry of news, which should be one of business, entertainment, general, health, science, sports or technology", true); err != nil {
66+
return err
67+
}
68+
if err := samNewsSearchTool.AddParameter("query", "The query with which to search news", true); err != nil {
69+
return err
70+
}
71+
72+
// Return success
73+
return nil
74+
}
75+
76+
///////////////////////////////////////////////////////////////////////////////
77+
// METHODS
78+
79+
func samChat(ctx context.Context, w *tablewriter.Writer, _ []string) error {
80+
var toolResult bool
81+
82+
messages := []*schema.Message{}
83+
for {
84+
if ctx.Err() != nil {
85+
return nil
86+
}
87+
88+
// Read if there hasn't been any tool results yet
89+
if !toolResult {
90+
reader := bufio.NewReader(os.Stdin)
91+
fmt.Print("Chat: ")
92+
text, err := reader.ReadString('\n')
93+
if err != nil {
94+
return err
95+
}
96+
messages = append(messages, schema.NewMessage("user", schema.Text(strings.TrimSpace(text))))
97+
}
98+
99+
// Curtail requests to the last N history
100+
if len(messages) > 10 {
101+
messages = messages[len(messages)-10:]
102+
// First message must have role 'user'
103+
for {
104+
if len(messages) == 0 || messages[0].Role == "user" {
105+
break
106+
}
107+
messages = messages[1:]
108+
}
109+
// TODO: We must remove the first instance tool_result if there is no tool_use
110+
}
111+
112+
// Request -> Response
113+
responses, err := anthropicClient.Messages(ctx, messages, anthropic.OptSystem(samSystemPrompt), anthropic.OptTool(samWeatherTool), anthropic.OptTool(samNewsHeadlinesTool), anthropic.OptTool(samNewsSearchTool))
114+
if err != nil {
115+
return err
116+
}
117+
toolResult = false
118+
119+
for _, response := range responses {
120+
switch response.Type {
121+
case "text":
122+
messages = samAppend(messages, schema.NewMessage("assistant", schema.Text(response.Text)))
123+
fmt.Println(response.Text)
124+
fmt.Println("")
125+
case "tool_use":
126+
messages = samAppend(messages, schema.NewMessage("assistant", response))
127+
result := samCall(ctx, response)
128+
messages = samAppend(messages, schema.NewMessage("user", result))
129+
toolResult = true
130+
}
131+
}
132+
}
133+
}
134+
135+
func samCall(_ context.Context, content schema.Content) *schema.Content {
136+
if content.Type != "tool_use" {
137+
return schema.ToolResult(content.Id, fmt.Sprint("unexpected content type:", content.Type))
138+
}
139+
switch content.Name {
140+
case samWeatherTool.Name:
141+
var location string
142+
if v, exists := content.GetString(content.Name, "location"); exists {
143+
location = v
144+
} else {
145+
location = "auto:ip"
146+
}
147+
if weather, err := weatherapiClient.Current(location); err != nil {
148+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to get the current weather, the error is ", err))
149+
} else if data, err := json.MarshalIndent(weather, "", " "); err != nil {
150+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to marshal the weather data, the error is ", err))
151+
} else {
152+
return schema.ToolResult(content.Id, string(data))
153+
}
154+
case samNewsHeadlinesTool.Name:
155+
var category string
156+
if v, exists := content.GetString(content.Name, "category"); exists {
157+
category = v
158+
} else {
159+
category = "general"
160+
}
161+
if headlines, err := newsapiClient.Headlines(newsapi.OptCategory(category)); err != nil {
162+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to get the news headlines, the error is ", err))
163+
} else if data, err := json.MarshalIndent(headlines, "", " "); err != nil {
164+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to marshal the headlines data, the error is ", err))
165+
} else {
166+
return schema.ToolResult(content.Id, string(data))
167+
}
168+
case samNewsSearchTool.Name:
169+
var query string
170+
if v, exists := content.GetString(content.Name, "query"); exists {
171+
query = v
172+
} else {
173+
return schema.ToolResult(content.Id, "Unable to search news due to missing query")
174+
}
175+
if articles, err := newsapiClient.Articles(newsapi.OptQuery(query), newsapi.OptLimit(5)); err != nil {
176+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to search news, the error is ", err))
177+
} else if data, err := json.MarshalIndent(articles, "", " "); err != nil {
178+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to marshal the articles data, the error is ", err))
179+
} else {
180+
return schema.ToolResult(content.Id, string(data))
181+
}
182+
}
183+
return schema.ToolResult(content.Id, fmt.Sprint("unable to call:", content.Name))
184+
}
185+
186+
func samAppend(messages []*schema.Message, message *schema.Message) []*schema.Message {
187+
// if the previous message was of the same role, then append the new message to the previous one
188+
if len(messages) > 0 && messages[len(messages)-1].Role == message.Role {
189+
messages[len(messages)-1].Add(message.Content)
190+
return messages
191+
} else {
192+
return append(messages, message)
193+
}
194+
}

pkg/openai/schema/message.go

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ type Content struct {
4646
Text string `json:"text,omitempty,wrap,width:60"`
4747
Source *contentSource `json:"source,omitempty"`
4848
toolUse
49+
50+
ToolId string `json:"tool_use_id,omitempty"`
51+
Result string `json:"content,omitempty"`
4952
}
5053

5154
// Content Source
@@ -55,12 +58,18 @@ type contentSource struct {
5558
Data string `json:"data,omitempty"`
5659
}
5760

58-
// Tool arguments
61+
// Tool call
5962
type toolUse struct {
6063
Name string `json:"name,omitempty"`
6164
Input map[string]any `json:"input,omitempty"`
6265
}
6366

67+
// Tool result
68+
type toolResult struct {
69+
ToolId string `json:"tool_use_id,omitempty"`
70+
Result string `json:"content,omitempty"`
71+
}
72+
6473
///////////////////////////////////////////////////////////////////////////////
6574
// LIFECYCLE
6675

@@ -112,6 +121,11 @@ func ImageData(path string) (*Content, error) {
112121
return Image(r)
113122
}
114123

124+
// Return a tool result
125+
func ToolResult(id string, result string) *Content {
126+
return &Content{Type: "tool_result", ToolId: id, Result: result}
127+
}
128+
115129
///////////////////////////////////////////////////////////////////////////////
116130
// STRINGIFY
117131

@@ -148,6 +162,19 @@ func (m *Message) Add(content ...any) *Message {
148162
return m
149163
}
150164

165+
// Return an input parameter as a string, returns false if the name
166+
// is incorrect or the input doesn't exist
167+
func (c Content) GetString(name, input string) (string, bool) {
168+
if c.Name == name {
169+
if value, exists := c.Input[input]; exists {
170+
if value, ok := value.(string); ok {
171+
return value, true
172+
}
173+
}
174+
}
175+
return "", false
176+
}
177+
151178
///////////////////////////////////////////////////////////////////////////////
152179
// PRIVATE METHODS
153180

@@ -210,6 +237,10 @@ func (m *Message) append(v any) error {
210237
// []Content, Content => []Content
211238
m.Content = append(m.Content.([]Content), v)
212239
return nil
240+
case []Content:
241+
// []Content, []Content => []Content
242+
m.Content = append(m.Content.([]Content), v...)
243+
return nil
213244
}
214245
}
215246
return ErrBadParameter.With("append: not implemented for ", reflect.TypeOf(m.Content), ",", reflect.TypeOf(v))
@@ -218,16 +249,16 @@ func (m *Message) append(v any) error {
218249
// Set the message content
219250
func (m *Message) set(v any) error {
220251
// Append content to messages,
221-
// m.Content will be of type string, []string, Content or []Content
252+
// m.Content will be of type string, []string or []Content
222253
switch v := v.(type) {
223254
case string:
224255
m.Content = v
225256
case []string:
226257
m.Content = v
227258
case *Content:
228-
m.Content = *v
259+
m.Content = []Content{*v}
229260
case Content:
230-
m.Content = v
261+
m.Content = []Content{v}
231262
case []*Content:
232263
if len(v) > 0 {
233264
m.Content = make([]Content, 0, len(v))

0 commit comments

Comments
 (0)