Skip to content

Commit 653d9c4

Browse files
committed
Added home assistant
1 parent 3dbae74 commit 653d9c4

File tree

5 files changed

+203
-9
lines changed

5 files changed

+203
-9
lines changed

cmd/api/homeassistant.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"slices"
7+
"strings"
8+
"time"
9+
10+
"github.com/djthorpe/go-tablewriter"
11+
"github.com/mutablelogic/go-client"
12+
"github.com/mutablelogic/go-client/pkg/homeassistant"
13+
)
14+
15+
///////////////////////////////////////////////////////////////////////////////
16+
// TYPES
17+
18+
type haEntity struct {
19+
Id string `json:"entity_id"`
20+
Name string `json:"name,omitempty"`
21+
Class string `json:"class,omitempty"`
22+
State string `json:"state,omitempty"`
23+
Attributes map[string]interface{} `json:"attributes,omitempty,wrap"`
24+
UpdatedAt time.Time `json:"last_updated,omitempty"`
25+
}
26+
27+
type haClass struct {
28+
Class string `json:"class,omitempty"`
29+
}
30+
31+
///////////////////////////////////////////////////////////////////////////////
32+
// GLOBALS
33+
34+
var (
35+
haName = "homeassistant"
36+
haClient *homeassistant.Client
37+
)
38+
39+
///////////////////////////////////////////////////////////////////////////////
40+
// LIFECYCLE
41+
42+
func haRegister(flags *Flags) {
43+
// Register flags required
44+
flags.String(haName, "ha-endpoint", "${HA_ENDPOINT}", "Token")
45+
flags.String(haName, "ha-token", "${HA_TOKEN}", "Token")
46+
47+
flags.Register(Cmd{
48+
Name: haName,
49+
Description: "Information from home assistant",
50+
Parse: haParse,
51+
Fn: []Fn{
52+
{Name: "classes", Call: haClasses, Description: "Return entity classes"},
53+
{Name: "states", Call: haStates, Description: "Return entity states"},
54+
},
55+
})
56+
}
57+
58+
func haParse(flags *Flags, opts ...client.ClientOpt) error {
59+
// Create home assistant client
60+
if ha, err := homeassistant.New(flags.GetString("ha-endpoint"), flags.GetString("ha-token"), opts...); err != nil {
61+
return err
62+
} else {
63+
haClient = ha
64+
}
65+
66+
// Return success
67+
return nil
68+
}
69+
70+
///////////////////////////////////////////////////////////////////////////////
71+
// METHODS
72+
73+
func haStates(_ context.Context, w *tablewriter.Writer, args []string) error {
74+
if states, err := haGetStates(args); err != nil {
75+
return err
76+
} else {
77+
return w.Write(states)
78+
}
79+
}
80+
81+
func haClasses(_ context.Context, w *tablewriter.Writer, args []string) error {
82+
states, err := haGetStates(nil)
83+
if err != nil {
84+
return err
85+
}
86+
87+
classes := make(map[string]bool)
88+
for _, state := range states {
89+
classes[state.Class] = true
90+
}
91+
92+
result := []haClass{}
93+
for c := range classes {
94+
result = append(result, haClass{Class: c})
95+
}
96+
return w.Write(result)
97+
}
98+
99+
///////////////////////////////////////////////////////////////////////////////
100+
// PRIVATE METHODS
101+
102+
func haGetStates(classes []string) ([]haEntity, error) {
103+
var result []haEntity
104+
105+
// Get states from the remote service
106+
states, err := haClient.States()
107+
if err != nil {
108+
return nil, err
109+
}
110+
111+
// Filter states
112+
for _, state := range states {
113+
entity := haEntity{
114+
Id: state.Entity,
115+
State: state.State,
116+
Attributes: state.Attributes,
117+
UpdatedAt: state.LastChanged,
118+
}
119+
120+
// Ignore entities without state
121+
if entity.State == "" || entity.State == "unknown" || entity.State == "unavailable" {
122+
continue
123+
}
124+
125+
// Set entity type and name from entity id
126+
parts := strings.SplitN(entity.Id, ".", 2)
127+
if len(parts) >= 2 {
128+
entity.Class = strings.ToLower(parts[0])
129+
entity.Name = parts[1]
130+
}
131+
132+
// Set entity type from device class
133+
if t, exists := state.Attributes["device_class"]; exists {
134+
entity.Class = fmt.Sprint(t)
135+
}
136+
137+
// Filter classes
138+
if len(classes) > 0 && !slices.Contains(classes, entity.Class) {
139+
continue
140+
}
141+
142+
// Set entity name from attributes
143+
if name, exists := state.Attributes["friendly_name"]; exists {
144+
entity.Name = fmt.Sprint(name)
145+
}
146+
147+
// Append results
148+
result = append(result, entity)
149+
}
150+
151+
// Return success
152+
return result, nil
153+
}

cmd/api/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ func main() {
2323
newsapiRegister(flags)
2424
weatherapiRegister(flags)
2525
samRegister(flags)
26+
haRegister(flags)
2627

2728
// Parse command line and return function to run
2829
fn, args, err := flags.Parse(os.Args[1:])

cmd/api/samantha.go

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ var (
2424
samWeatherTool = schema.NewTool("get_weather", "Get the weather for a location")
2525
samNewsHeadlinesTool = schema.NewTool("get_news_headlines", "Get the news headlines")
2626
samNewsSearchTool = schema.NewTool("search_news", "Search news articles")
27-
samSystemPrompt = `Your name is Samantha, you are a friendly AI assistant, here to help you with
28-
anything you need. Your responses should be short and to the point, and you should always be polite.`
27+
samHomeAssistantTool = schema.NewTool("get_home_devices", "Return information about home devices")
28+
samSystemPrompt = `Your name is Samantha, you are a friendly and occasionally sarcastic assistant,
29+
here to help with anything. Your responses should be short and to the point, and you should always be polite.`
2930
)
3031

3132
///////////////////////////////////////////////////////////////////////////////
@@ -51,23 +52,32 @@ func samParse(flags *Flags, opts ...client.ClientOpt) error {
5152
if err := newsapiParse(flags, opts...); err != nil {
5253
return err
5354
}
54-
55+
// Initialize home assistant
56+
if err := haParse(flags, opts...); err != nil {
57+
return err
58+
}
5559
// Initialize anthropic
5660
opts = append(opts, client.OptHeader("Anthropic-Beta", "tools-2024-04-04"))
5761
if err := anthropicParse(flags, opts...); err != nil {
5862
return err
5963
}
6064

6165
// Add tool parameters
62-
if err := samWeatherTool.AddParameter("location", "The city to get the weather for", true); err != nil {
66+
if err := samWeatherTool.AddParameter("location", `City to get the weather for. If a country, use the capital city. To get weather for the current location, use "auto:ip"`, true); err != nil {
6367
return err
6468
}
6569
if err := samNewsHeadlinesTool.AddParameter("category", "The cateogry of news, which should be one of business, entertainment, general, health, science, sports or technology", true); err != nil {
6670
return err
6771
}
72+
if err := samNewsHeadlinesTool.AddParameter("country", "Headlines from agencies in a specific country. Optional. Use ISO 3166 country code.", false); err != nil {
73+
return err
74+
}
6875
if err := samNewsSearchTool.AddParameter("query", "The query with which to search news", true); err != nil {
6976
return err
7077
}
78+
if err := samHomeAssistantTool.AddParameter("class", "The class of device, which should be one or more of door,lock,occupancy,motion,climate,light,switch,sensor,speaker,media_player,temperature,humidity,battery,tv,remote,light,vacuum separated by spaces", true); err != nil {
79+
return err
80+
}
7181

7282
// Return success
7383
return nil
@@ -99,18 +109,32 @@ func samChat(ctx context.Context, w *tablewriter.Writer, _ []string) error {
99109
// Curtail requests to the last N history
100110
if len(messages) > 10 {
101111
messages = messages[len(messages)-10:]
102-
// First message must have role 'user'
112+
113+
// First message must have role 'user' and not be a tool_result
103114
for {
104-
if len(messages) == 0 || messages[0].Role == "user" {
115+
if len(messages) == 0 {
105116
break
106117
}
118+
if messages[0].Role == "user" {
119+
if content, ok := messages[0].Content.([]schema.Content); ok {
120+
if len(content) > 0 && content[0].Type != "tool_result" {
121+
break
122+
}
123+
} else {
124+
break
125+
}
126+
}
107127
messages = messages[1:]
108128
}
109-
// TODO: We must remove the first instance tool_result if there is no tool_use
110129
}
111130

112131
// Request -> Response
113-
responses, err := anthropicClient.Messages(ctx, messages, anthropic.OptSystem(samSystemPrompt), anthropic.OptTool(samWeatherTool), anthropic.OptTool(samNewsHeadlinesTool), anthropic.OptTool(samNewsSearchTool))
132+
responses, err := anthropicClient.Messages(ctx, messages, anthropic.OptSystem(samSystemPrompt),
133+
anthropic.OptTool(samWeatherTool),
134+
anthropic.OptTool(samNewsHeadlinesTool),
135+
anthropic.OptTool(samNewsSearchTool),
136+
anthropic.OptTool(samHomeAssistantTool),
137+
)
114138
if err != nil {
115139
return err
116140
}
@@ -158,7 +182,8 @@ func samCall(_ context.Context, content schema.Content) *schema.Content {
158182
} else {
159183
category = "general"
160184
}
161-
if headlines, err := newsapiClient.Headlines(newsapi.OptCategory(category)); err != nil {
185+
country, _ := content.GetString(content.Name, "country")
186+
if headlines, err := newsapiClient.Headlines(newsapi.OptCategory(category), newsapi.OptCountry(country)); err != nil {
162187
return schema.ToolResult(content.Id, fmt.Sprint("Unable to get the news headlines, the error is ", err))
163188
} else if data, err := json.MarshalIndent(headlines, "", " "); err != nil {
164189
return schema.ToolResult(content.Id, fmt.Sprint("Unable to marshal the headlines data, the error is ", err))
@@ -179,6 +204,18 @@ func samCall(_ context.Context, content schema.Content) *schema.Content {
179204
} else {
180205
return schema.ToolResult(content.Id, string(data))
181206
}
207+
case samHomeAssistantTool.Name:
208+
classes, exists := content.GetString(content.Name, "class")
209+
if !exists || classes == "" {
210+
return schema.ToolResult(content.Id, "Unable to get home devices due to missing class")
211+
}
212+
if states, err := haGetStates(strings.Fields(classes)); err != nil {
213+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to get home devices, the error is ", err))
214+
} else if data, err := json.MarshalIndent(states, "", " "); err != nil {
215+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to marshal the states data, the error is ", err))
216+
} else {
217+
return schema.ToolResult(content.Id, string(data))
218+
}
182219
}
183220
return schema.ToolResult(content.Id, fmt.Sprint("unable to call:", content.Name))
184221
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ require (
1212
github.com/stretchr/testify v1.9.0
1313
github.com/xdg-go/pbkdf2 v1.0.0
1414
golang.org/x/crypto v0.23.0
15+
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
1516
golang.org/x/term v0.20.0
1617
)
1718

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
2121
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
2222
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
2323
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
24+
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
25+
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
2426
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
2527
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
2628
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=

0 commit comments

Comments
 (0)