Skip to content

Commit 89fc399

Browse files
authored
Merge pull request #24 from mutablelogic/v1
Updated home assistant and samantha to allow devices to be turned on and off
2 parents af56a42 + 0059684 commit 89fc399

File tree

3 files changed

+127
-31
lines changed

3 files changed

+127
-31
lines changed

client.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,16 @@ func (client *Client) Request(req *http.Request, out any, opts ...RequestOpt) er
187187
return do(client.Client, req, "", false, out, opts...)
188188
}
189189

190+
// Debugf outputs debug information
191+
func (client *Client) Debugf(f string, args ...any) {
192+
if client.Client.Transport != nil && client.Client.Transport != http.DefaultTransport {
193+
if debug, ok := client.Transport.(*logtransport); ok {
194+
fmt.Fprintf(debug.w, f, args...)
195+
fmt.Fprint(debug.w, "\n")
196+
}
197+
}
198+
}
199+
190200
///////////////////////////////////////////////////////////////////////////////
191201
// PRIVATE METHODS
192202

cmd/api/homeassistant.go

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ package main
22

33
import (
44
"context"
5+
"slices"
56
"strings"
67
"time"
78

89
"github.com/djthorpe/go-tablewriter"
910
"github.com/mutablelogic/go-client"
1011
"github.com/mutablelogic/go-client/pkg/homeassistant"
12+
"golang.org/x/exp/maps"
1113
)
1214

1315
///////////////////////////////////////////////////////////////////////////////
@@ -26,7 +28,7 @@ type haEntity struct {
2628

2729
type haDomain struct {
2830
Name string `json:"domain"`
29-
Services string `json:"services,omitempty"`
31+
Services string `json:"services,omitempty,width:40,wrap"`
3032
}
3133

3234
///////////////////////////////////////////////////////////////////////////////
@@ -50,6 +52,7 @@ func haRegister(flags *Flags) {
5052
Description: "Information from home assistant",
5153
Parse: haParse,
5254
Fn: []Fn{
55+
{Name: "health", Call: haHealth, Description: "Return status of home assistant"},
5356
{Name: "domains", Call: haDomains, Description: "Enumerate entity domains"},
5457
{Name: "states", Call: haStates, Description: "Show current entity states", MaxArgs: 1, Syntax: "(<name>)"},
5558
{Name: "services", Call: haServices, Description: "Show services for an entity", MinArgs: 1, MaxArgs: 1, Syntax: "<entity>"},
@@ -73,40 +76,65 @@ func haParse(flags *Flags, opts ...client.ClientOpt) error {
7376
///////////////////////////////////////////////////////////////////////////////
7477
// METHODS
7578

76-
func haStates(_ context.Context, w *tablewriter.Writer, args []string) error {
77-
var result []haEntity
78-
states, err := haGetStates(nil)
79+
func haHealth(_ context.Context, w *tablewriter.Writer, args []string) error {
80+
type respHealth struct {
81+
Status string `json:"status"`
82+
}
83+
status, err := haClient.Health()
7984
if err != nil {
8085
return err
8186
}
87+
return w.Write(respHealth{Status: status})
88+
}
8289

83-
for _, state := range states {
84-
if len(args) == 1 {
85-
if !haMatchString(args[0], state.Name, state.Id) {
86-
continue
87-
}
90+
func haStates(_ context.Context, w *tablewriter.Writer, args []string) error {
91+
var q string
92+
if len(args) > 0 {
93+
q = args[0]
94+
}
8895

89-
}
90-
result = append(result, state)
96+
states, err := haGetStates(q, nil)
97+
if err != nil {
98+
return err
9199
}
92-
return w.Write(result)
100+
101+
return w.Write(states)
93102
}
94103

95104
func haDomains(_ context.Context, w *tablewriter.Writer, args []string) error {
96-
states, err := haGetStates(nil)
105+
// Get all states
106+
states, err := haGetStates("", nil)
97107
if err != nil {
98108
return err
99109
}
100110

111+
// Enumerate all the classes
101112
classes := make(map[string]bool)
102113
for _, state := range states {
103114
classes[state.Class] = true
104115
}
105116

117+
// Get all the domains, and make a map of them
118+
domains, err := haClient.Domains()
119+
if err != nil {
120+
return err
121+
}
122+
map_domains := make(map[string]*homeassistant.Domain)
123+
for _, domain := range domains {
124+
map_domains[domain.Domain] = domain
125+
}
126+
106127
result := []haDomain{}
107128
for c := range classes {
129+
var services []string
130+
if domain, exists := map_domains[c]; exists {
131+
if v := domain.Services; v != nil {
132+
services = maps.Keys(v)
133+
}
134+
}
108135
result = append(result, haDomain{
109-
Name: c,
136+
Name: c,
137+
Services: strings.Join(services, ", "),
110138
})
111139
}
112140
return w.Write(result)
@@ -148,7 +176,7 @@ func haMatchString(q string, values ...string) bool {
148176
return false
149177
}
150178

151-
func haGetStates(domains []string) ([]haEntity, error) {
179+
func haGetStates(name string, domains []string) ([]haEntity, error) {
152180
var result []haEntity
153181

154182
// Get states from the remote service
@@ -175,16 +203,25 @@ func haGetStates(domains []string) ([]haEntity, error) {
175203
continue
176204
}
177205

206+
// Filter name
207+
if name != "" {
208+
if !haMatchString(name, entity.Name, entity.Id) {
209+
continue
210+
}
211+
}
212+
213+
// Filter domains
214+
if len(domains) > 0 {
215+
if !slices.Contains(domains, entity.Domain) {
216+
continue
217+
}
218+
}
219+
178220
// Add unit of measurement
179221
if unit := state.UnitOfMeasurement(); unit != "" {
180222
entity.State += " " + unit
181223
}
182224

183-
// Filter domains
184-
//if len(domains) > 0 && !slices.Contains(domains, entity.Domain) {
185-
// continue
186-
//}
187-
188225
// Append results
189226
result = append(result, entity)
190227
}

cmd/api/samantha.go

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@ import (
2020
// GLOBALS
2121

2222
var (
23-
samName = "sam"
24-
samWeatherTool = schema.NewTool("get_current_weather", "Get the current weather conditions for a location")
25-
samNewsHeadlinesTool = schema.NewTool("get_news_headlines", "Get the news headlines")
26-
samNewsSearchTool = schema.NewTool("search_news", "Search news articles")
27-
samHomeAssistantTool = schema.NewTool("get_home_devices", "Return information about home devices")
28-
samSystemPrompt = `Your name is Samantha, you are a personal assistant modelled on the personality of Samantha from the movie "Her". Your responses should be short and friendly.`
23+
samName = "sam"
24+
samWeatherTool = schema.NewTool("get_current_weather", "Get the current weather conditions for a location")
25+
samNewsHeadlinesTool = schema.NewTool("get_news_headlines", "Get the news headlines")
26+
samNewsSearchTool = schema.NewTool("search_news", "Search news articles")
27+
samHomeAssistantTool = schema.NewTool("get_home_devices", "Return information about home devices by type, including their state and entity_id")
28+
samHomeAssistantSearch = schema.NewTool("search_home_devices", "Return information about home devices by name, including their state and entity_id")
29+
samHomeAssistantTurnOn = schema.NewTool("turn_on_device", "Turn on a device")
30+
samHomeAssistantTurnOff = schema.NewTool("turn_off_device", "Turn off a device")
31+
samSystemPrompt = `Your name is Samantha, you are a personal assistant modelled on the personality of Samantha from the movie "Her". Your responses should be short and friendly.`
2932
)
3033

3134
///////////////////////////////////////////////////////////////////////////////
@@ -74,7 +77,16 @@ func samParse(flags *Flags, opts ...client.ClientOpt) error {
7477
if err := samNewsSearchTool.AddParameter("query", "The query with which to search news", true); err != nil {
7578
return err
7679
}
77-
if err := samHomeAssistantTool.AddParameter("class", "The class of device, which should be one or more of door,lock,occupancy,motion,climate,light,switch,sensor,speaker,media_player,temperature,humidity,battery,tv,remote,light,vacuum separated by spaces", true); err != nil {
80+
if err := samHomeAssistantTool.AddParameter("type", "Query for a device type, which could one or more of door,lock,occupancy,motion,climate,light,switch,sensor,speaker,media_player,temperature,humidity,battery,tv,remote,light,vacuum separated by spaces", true); err != nil {
81+
return err
82+
}
83+
if err := samHomeAssistantSearch.AddParameter("name", "Search for device state by name", true); err != nil {
84+
return err
85+
}
86+
if err := samHomeAssistantTurnOn.AddParameter("entity_id", "The device entity_id to turn on", true); err != nil {
87+
return err
88+
}
89+
if err := samHomeAssistantTurnOff.AddParameter("entity_id", "The device entity_id to turn off", true); err != nil {
7890
return err
7991
}
8092

@@ -128,14 +140,20 @@ func samChat(ctx context.Context, w *tablewriter.Writer, _ []string) error {
128140
}
129141

130142
// Request -> Response
131-
responses, err := anthropicClient.Messages(ctx, messages, anthropic.OptSystem(samSystemPrompt),
143+
responses, err := anthropicClient.Messages(ctx, messages,
144+
anthropic.OptSystem(samSystemPrompt),
145+
anthropic.OptMaxTokens(1000),
132146
anthropic.OptTool(samWeatherTool),
133147
anthropic.OptTool(samNewsHeadlinesTool),
134148
anthropic.OptTool(samNewsSearchTool),
135149
anthropic.OptTool(samHomeAssistantTool),
150+
anthropic.OptTool(samHomeAssistantSearch),
151+
anthropic.OptTool(samHomeAssistantTurnOn),
152+
anthropic.OptTool(samHomeAssistantTurnOff),
136153
)
137154
toolResult = false
138155
if err != nil {
156+
messages = samAppend(messages, schema.NewMessage("assistant", schema.Text(fmt.Sprint("An error occurred: ", err))))
139157
fmt.Println(err)
140158
fmt.Println("")
141159
} else {
@@ -157,6 +175,7 @@ func samChat(ctx context.Context, w *tablewriter.Writer, _ []string) error {
157175
}
158176

159177
func samCall(_ context.Context, content schema.Content) *schema.Content {
178+
anthropicClient.Debugf("%v: %v: %v", content.Type, content.Name, content.Input)
160179
if content.Type != "tool_use" {
161180
return schema.ToolResult(content.Id, fmt.Sprint("unexpected content type:", content.Type))
162181
}
@@ -205,17 +224,47 @@ func samCall(_ context.Context, content schema.Content) *schema.Content {
205224
return schema.ToolResult(content.Id, string(data))
206225
}
207226
case samHomeAssistantTool.Name:
208-
classes, exists := content.GetString(content.Name, "class")
227+
classes, exists := content.GetString(content.Name, "type")
209228
if !exists || classes == "" {
210-
return schema.ToolResult(content.Id, "Unable to get home devices due to missing class")
229+
return schema.ToolResult(content.Id, "Unable to get home devices due to missing type")
211230
}
212-
if states, err := haGetStates(strings.Fields(classes)); err != nil {
231+
if states, err := haGetStates("", strings.Fields(classes)); err != nil {
213232
return schema.ToolResult(content.Id, fmt.Sprint("Unable to get home devices, the error is ", err))
214233
} else if data, err := json.MarshalIndent(states, "", " "); err != nil {
215234
return schema.ToolResult(content.Id, fmt.Sprint("Unable to marshal the states data, the error is ", err))
216235
} else {
217236
return schema.ToolResult(content.Id, string(data))
218237
}
238+
case samHomeAssistantSearch.Name:
239+
name, exists := content.GetString(content.Name, "name")
240+
if !exists || name == "" {
241+
return schema.ToolResult(content.Id, "Unable to search home devices due to missing name")
242+
}
243+
if states, err := haGetStates(name, nil); err != nil {
244+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to get home devices, the error is ", err))
245+
} else if data, err := json.MarshalIndent(states, "", " "); err != nil {
246+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to marshal the states data, the error is ", err))
247+
} else {
248+
return schema.ToolResult(content.Id, string(data))
249+
}
250+
case samHomeAssistantTurnOn.Name:
251+
entity, _ := content.GetString(content.Name, "entity_id")
252+
if _, err := haClient.Call("turn_on", entity); err != nil {
253+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to turn on device, the error is ", err))
254+
} else if state, err := haClient.State(entity); err != nil {
255+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to get device state, the error is ", err))
256+
} else {
257+
return schema.ToolResult(content.Id, fmt.Sprint("The updated state is: ", state))
258+
}
259+
case samHomeAssistantTurnOff.Name:
260+
entity, _ := content.GetString(content.Name, "entity_id")
261+
if _, err := haClient.Call("turn_off", entity); err != nil {
262+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to turn off device, the error is ", err))
263+
} else if state, err := haClient.State(entity); err != nil {
264+
return schema.ToolResult(content.Id, fmt.Sprint("Unable to get device state, the error is ", err))
265+
} else {
266+
return schema.ToolResult(content.Id, fmt.Sprint("The updated state is: ", state))
267+
}
219268
}
220269
return schema.ToolResult(content.Id, fmt.Sprint("unable to call:", content.Name))
221270
}

0 commit comments

Comments
 (0)