Skip to content

Commit c02b5cc

Browse files
authored
Merge pull request #22 from mutablelogic/v1
Added home assistant to samantha
2 parents 6b9e053 + e8ff351 commit c02b5cc

File tree

5 files changed

+219
-25
lines changed

5 files changed

+219
-25
lines changed

cmd/api/homeassistant.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"slices"
7+
"strings"
8+
"time"
9+
10+
"github.com/djthorpe/go-tablewriter"
11+
"github.com/mutablelogic/go-client"
12+
"github.com/mutablelogic/go-client/pkg/homeassistant"
13+
)
14+
15+
///////////////////////////////////////////////////////////////////////////////
16+
// TYPES
17+
18+
type haEntity struct {
19+
Id string `json:"entity_id"`
20+
Name string `json:"name,omitempty"`
21+
Class string `json:"class,omitempty"`
22+
State string `json:"state,omitempty"`
23+
Attributes map[string]interface{} `json:"attributes,omitempty,wrap"`
24+
UpdatedAt time.Time `json:"last_updated,omitempty"`
25+
}
26+
27+
type haClass struct {
28+
Class string `json:"class,omitempty"`
29+
}
30+
31+
///////////////////////////////////////////////////////////////////////////////
32+
// GLOBALS
33+
34+
var (
35+
haName = "homeassistant"
36+
haClient *homeassistant.Client
37+
)
38+
39+
///////////////////////////////////////////////////////////////////////////////
40+
// LIFECYCLE
41+
42+
func haRegister(flags *Flags) {
43+
// Register flags required
44+
flags.String(haName, "ha-endpoint", "${HA_ENDPOINT}", "Token")
45+
flags.String(haName, "ha-token", "${HA_TOKEN}", "Token")
46+
47+
flags.Register(Cmd{
48+
Name: haName,
49+
Description: "Information from home assistant",
50+
Parse: haParse,
51+
Fn: []Fn{
52+
{Name: "classes", Call: haClasses, Description: "Return entity classes"},
53+
{Name: "states", Call: haStates, Description: "Return entity states"},
54+
},
55+
})
56+
}
57+
58+
func haParse(flags *Flags, opts ...client.ClientOpt) error {
59+
// Create home assistant client
60+
if ha, err := homeassistant.New(flags.GetString("ha-endpoint"), flags.GetString("ha-token"), opts...); err != nil {
61+
return err
62+
} else {
63+
haClient = ha
64+
}
65+
66+
// Return success
67+
return nil
68+
}
69+
70+
///////////////////////////////////////////////////////////////////////////////
71+
// METHODS
72+
73+
func haStates(_ context.Context, w *tablewriter.Writer, args []string) error {
74+
if states, err := haGetStates(args); err != nil {
75+
return err
76+
} else {
77+
return w.Write(states)
78+
}
79+
}
80+
81+
func haClasses(_ context.Context, w *tablewriter.Writer, args []string) error {
82+
states, err := haGetStates(nil)
83+
if err != nil {
84+
return err
85+
}
86+
87+
classes := make(map[string]bool)
88+
for _, state := range states {
89+
classes[state.Class] = true
90+
}
91+
92+
result := []haClass{}
93+
for c := range classes {
94+
result = append(result, haClass{Class: c})
95+
}
96+
return w.Write(result)
97+
}
98+
99+
///////////////////////////////////////////////////////////////////////////////
100+
// PRIVATE METHODS
101+
102+
func haGetStates(classes []string) ([]haEntity, error) {
103+
var result []haEntity
104+
105+
// Get states from the remote service
106+
states, err := haClient.States()
107+
if err != nil {
108+
return nil, err
109+
}
110+
111+
// Filter states
112+
for _, state := range states {
113+
entity := haEntity{
114+
Id: state.Entity,
115+
State: state.State,
116+
Attributes: state.Attributes,
117+
UpdatedAt: state.LastChanged,
118+
}
119+
120+
// Ignore entities without state
121+
if entity.State == "" || entity.State == "unknown" || entity.State == "unavailable" {
122+
continue
123+
}
124+
125+
// Set entity type and name from entity id
126+
parts := strings.SplitN(entity.Id, ".", 2)
127+
if len(parts) >= 2 {
128+
entity.Class = strings.ToLower(parts[0])
129+
entity.Name = parts[1]
130+
}
131+
132+
// Set entity type from device class
133+
if t, exists := state.Attributes["device_class"]; exists {
134+
entity.Class = fmt.Sprint(t)
135+
}
136+
137+
// Filter classes
138+
if len(classes) > 0 && !slices.Contains(classes, entity.Class) {
139+
continue
140+
}
141+
142+
// Set entity name from attributes
143+
if name, exists := state.Attributes["friendly_name"]; exists {
144+
entity.Name = fmt.Sprint(name)
145+
}
146+
147+
// Append results
148+
result = append(result, entity)
149+
}
150+
151+
// Return success
152+
return result, nil
153+
}

cmd/api/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ func main() {
2323
newsapiRegister(flags)
2424
weatherapiRegister(flags)
2525
samRegister(flags)
26+
haRegister(flags)
2627

2728
// Parse command line and return function to run
2829
fn, args, err := flags.Parse(os.Args[1:])

cmd/api/samantha.go

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ import (
2121

2222
var (
2323
samName = "sam"
24-
samWeatherTool = schema.NewTool("get_weather", "Get the weather for a location")
24+
samWeatherTool = schema.NewTool("get_current_weather", "Get the current weather conditions for a location")
2525
samNewsHeadlinesTool = schema.NewTool("get_news_headlines", "Get the news headlines")
2626
samNewsSearchTool = schema.NewTool("search_news", "Search news articles")
27-
samSystemPrompt = `Your name is Samantha, you are a friendly AI assistant, here to help you with
28-
anything you need. Your responses should be short and to the point, and you should always be polite.`
27+
samHomeAssistantTool = schema.NewTool("get_home_devices", "Return information about home devices")
28+
samSystemPrompt = `Your name is Samantha, you are a personal assistant modelled on the personality of Samantha from the movie "Her". Your responses should be short and friendly.`
2929
)
3030

3131
///////////////////////////////////////////////////////////////////////////////
@@ -51,23 +51,32 @@ func samParse(flags *Flags, opts ...client.ClientOpt) error {
5151
if err := newsapiParse(flags, opts...); err != nil {
5252
return err
5353
}
54-
54+
// Initialize home assistant
55+
if err := haParse(flags, opts...); err != nil {
56+
return err
57+
}
5558
// Initialize anthropic
5659
opts = append(opts, client.OptHeader("Anthropic-Beta", "tools-2024-04-04"))
5760
if err := anthropicParse(flags, opts...); err != nil {
5861
return err
5962
}
6063

6164
// Add tool parameters
62-
if err := samWeatherTool.AddParameter("location", "The city to get the weather for", true); err != nil {
65+
if err := samWeatherTool.AddParameter("location", `City to get the weather for. If a country, use the capital city. To get weather for the current location, use "auto:ip"`, true); err != nil {
6366
return err
6467
}
6568
if err := samNewsHeadlinesTool.AddParameter("category", "The cateogry of news, which should be one of business, entertainment, general, health, science, sports or technology", true); err != nil {
6669
return err
6770
}
71+
if err := samNewsHeadlinesTool.AddParameter("country", "Headlines from agencies in a specific country. Optional. Use ISO 3166 country code.", false); err != nil {
72+
return err
73+
}
6874
if err := samNewsSearchTool.AddParameter("query", "The query with which to search news", true); err != nil {
6975
return err
7076
}
77+
if err := samHomeAssistantTool.AddParameter("class", "The class of device, which should be one or more of door,lock,occupancy,motion,climate,light,switch,sensor,speaker,media_player,temperature,humidity,battery,tv,remote,light,vacuum separated by spaces", true); err != nil {
78+
return err
79+
}
7180

7281
// Return success
7382
return nil
@@ -99,34 +108,49 @@ func samChat(ctx context.Context, w *tablewriter.Writer, _ []string) error {
99108
// Curtail requests to the last N history
100109
if len(messages) > 10 {
101110
messages = messages[len(messages)-10:]
102-
// First message must have role 'user'
111+
112+
// First message must have role 'user' and not be a tool_result
103113
for {
104-
if len(messages) == 0 || messages[0].Role == "user" {
114+
if len(messages) == 0 {
105115
break
106116
}
117+
if messages[0].Role == "user" {
118+
if content, ok := messages[0].Content.([]schema.Content); ok {
119+
if len(content) > 0 && content[0].Type != "tool_result" {
120+
break
121+
}
122+
} else {
123+
break
124+
}
125+
}
107126
messages = messages[1:]
108127
}
109-
// TODO: We must remove the first instance tool_result if there is no tool_use
110128
}
111129

112130
// Request -> Response
113-
responses, err := anthropicClient.Messages(ctx, messages, anthropic.OptSystem(samSystemPrompt), anthropic.OptTool(samWeatherTool), anthropic.OptTool(samNewsHeadlinesTool), anthropic.OptTool(samNewsSearchTool))
114-
if err != nil {
115-
return err
116-
}
131+
responses, err := anthropicClient.Messages(ctx, messages, anthropic.OptSystem(samSystemPrompt),
132+
anthropic.OptTool(samWeatherTool),
133+
anthropic.OptTool(samNewsHeadlinesTool),
134+
anthropic.OptTool(samNewsSearchTool),
135+
anthropic.OptTool(samHomeAssistantTool),
136+
)
117137
toolResult = false
118-
119-
for _, response := range responses {
120-
switch response.Type {
121-
case "text":
122-
messages = samAppend(messages, schema.NewMessage("assistant", schema.Text(response.Text)))
123-
fmt.Println(response.Text)
124-
fmt.Println("")
125-
case "tool_use":
126-
messages = samAppend(messages, schema.NewMessage("assistant", response))
127-
result := samCall(ctx, response)
128-
messages = samAppend(messages, schema.NewMessage("user", result))
129-
toolResult = true
138+
if err != nil {
139+
fmt.Println(err)
140+
fmt.Println("")
141+
} else {
142+
for _, response := range responses {
143+
switch response.Type {
144+
case "text":
145+
messages = samAppend(messages, schema.NewMessage("assistant", schema.Text(response.Text)))
146+
fmt.Println(response.Text)
147+
fmt.Println("")
148+
case "tool_use":
149+
messages = samAppend(messages, schema.NewMessage("assistant", response))
150+
result := samCall(ctx, response)
151+
messages = samAppend(messages, schema.NewMessage("user", result))
152+
toolResult = true
153+
}
130154
}
131155
}
132156
}
@@ -158,7 +182,8 @@ func samCall(_ context.Context, content schema.Content) *schema.Content {
158182
} else {
159183
category = "general"
160184
}
161-
if headlines, err := newsapiClient.Headlines(newsapi.OptCategory(category)); err != nil {
185+
country, _ := content.GetString(content.Name, "country")
186+
if headlines, err := newsapiClient.Headlines(newsapi.OptCategory(category), newsapi.OptCountry(country)); err != nil {
162187
return schema.ToolResult(content.Id, fmt.Sprint("Unable to get the news headlines, the error is ", err))
163188
} else if data, err := json.MarshalIndent(headlines, "", " "); err != nil {
164189
return schema.ToolResult(content.Id, fmt.Sprint("Unable to marshal the headlines data, the error is ", err))
@@ -179,6 +204,18 @@ func samCall(_ context.Context, content schema.Content) *schema.Content {
179204
} else {
180205
return schema.ToolResult(content.Id, string(data))
181206
}
207+
case samHomeAssistantTool.Name:
208+
classes, exists := content.GetString(content.Name, "class")
209+
if !exists || classes == "" {
210+
return schema.ToolResult(content.Id, "Unable to get home devices due to missing class")
211+
}
212+
if states, err := haGetStates(strings.Fields(classes)); err != nil {
213+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to get home devices, the error is ", err))
214+
} else if data, err := json.MarshalIndent(states, "", " "); err != nil {
215+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to marshal the states data, the error is ", err))
216+
} else {
217+
return schema.ToolResult(content.Id, string(data))
218+
}
182219
}
183220
return schema.ToolResult(content.Id, fmt.Sprint("unable to call:", content.Name))
184221
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ require (
1212
github.com/stretchr/testify v1.9.0
1313
github.com/xdg-go/pbkdf2 v1.0.0
1414
golang.org/x/crypto v0.23.0
15+
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
1516
golang.org/x/term v0.20.0
1617
)
1718

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
2121
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
2222
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
2323
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
24+
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
25+
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
2426
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
2527
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
2628
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=

0 commit comments

Comments
 (0)