Skip to content

Commit 7b76c37

Browse files
committed
Added streamable messages
1 parent 32d1864 commit 7b76c37

File tree

5 files changed

+115
-12
lines changed

5 files changed

+115
-12
lines changed

pkg/anthropic/message.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ type Usage struct {
8282
type Delta struct {
8383
Type string `json:"type"`
8484
Text string `json:"text,omitempty"`
85+
Json string `json:"partial_json,omitempty"`
8586
}
8687

8788
///////////////////////////////////////////////////////////////////////////////
@@ -136,6 +137,14 @@ func (c *Client) Messages(ctx context.Context, messages []*schema.Message, opt .
136137
return response.Content, nil
137138
}
138139

140+
///////////////////////////////////////////////////////////////////////////////
141+
// STRINGIFY
142+
143+
func (d Delta) String() string {
144+
data, _ := json.MarshalIndent(d, "", " ")
145+
return string(data)
146+
}
147+
139148
///////////////////////////////////////////////////////////////////////////////
140149
// UNMARSHAL TEXT STREAM
141150

@@ -182,6 +191,7 @@ func (m *respMessage) decodeTextStream(r io.Reader) error {
182191
// TODO: Set input_tokens from stream.Usage.InputTokens
183192
case "message_stop":
184193
if m.delta != nil {
194+
// Callback with nil to indicate end of message
185195
m.delta(nil)
186196
}
187197
case "content_block_start":
@@ -192,11 +202,13 @@ func (m *respMessage) decodeTextStream(r io.Reader) error {
192202
case stream.Delta.Type == "text_delta" && content.Type == "text":
193203
// Append text
194204
content.Text += stream.Delta.Text
205+
case stream.Delta.Type == "input_json_delta":
206+
// Append partial_json
207+
content.Text += stream.Delta.Json
195208
}
196209
case "content_block_stop":
197210
// Append content
198211
m.Content = append(m.Content, content)
199-
200212
// Reset content
201213
content = schema.Content{}
202214
case "message_delta":

pkg/anthropic/message_test.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package anthropic_test
33
import (
44
"context"
55
"os"
6+
"reflect"
67
"testing"
78

89
opts "github.com/mutablelogic/go-client"
@@ -44,10 +45,29 @@ func Test_message_003(t *testing.T) {
4445

4546
// Create the weather tool
4647
weather := schema.NewTool("weather", "Get the weather for a location")
47-
weather.AddParameter("location", "The location to get the weather for", true)
48+
assert.NoError(weather.Add("location", "The location to get the weather for", true, reflect.TypeOf("")))
4849

4950
// Request -> Response
5051
content, err := client.Messages(context.Background(), []*schema.Message{msg}, anthropic.OptTool(weather))
5152
assert.NoError(err)
5253
t.Log(content)
5354
}
55+
56+
func Test_message_004(t *testing.T) {
57+
assert := assert.New(t)
58+
client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true), opts.OptHeader("Anthropic-Beta", "tools-2024-04-04"))
59+
assert.NoError(err)
60+
assert.NotNil(client)
61+
msg := schema.NewMessage("user", "What is the weather today in Berlin, Germany")
62+
63+
// Create the weather tool
64+
weather := schema.NewTool("weather", "Get the weather for a location")
65+
assert.NoError(weather.Add("location", "The location to get the weather for", true, reflect.TypeOf("")))
66+
67+
// Request -> Response
68+
content, err := client.Messages(context.Background(), []*schema.Message{msg}, anthropic.OptTool(weather), anthropic.OptStream(func(v *anthropic.Delta) {
69+
t.Log(v)
70+
}))
71+
assert.NoError(err)
72+
t.Log(content)
73+
}

pkg/openai/schema/tool.go

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@ package schema
22

33
import (
44
"encoding/json"
5+
"fmt"
6+
"reflect"
57

68
// Namespace imports
79
. "github.com/djthorpe/go-errors"
10+
"github.com/djthorpe/go-tablewriter/pkg/meta"
811
)
912

1013
///////////////////////////////////////////////////////////////////////////////
@@ -21,7 +24,7 @@ type Tool struct {
2124
type toolParameters struct {
2225
Type string `json:"type,omitempty"`
2326
Properties map[string]toolParameter `json:"properties,omitempty"`
24-
Required []string `json:"required"`
27+
Required []string `json:"required,omitempty"`
2528
}
2629

2730
// Tool function call parameter
@@ -32,6 +35,19 @@ type toolParameter struct {
3235
Description string `json:"description"`
3336
}
3437

38+
///////////////////////////////////////////////////////////////////////////////
39+
// GLOBALS
40+
41+
const (
42+
tagParameter = "json"
43+
)
44+
45+
var (
46+
typeString = reflect.TypeOf("")
47+
typeBool = reflect.TypeOf(true)
48+
typeInt = reflect.TypeOf(int(0))
49+
)
50+
3551
///////////////////////////////////////////////////////////////////////////////
3652
// LIFECYCLE
3753

@@ -46,6 +62,29 @@ func NewTool(name, description string) *Tool {
4662
}
4763
}
4864

65+
func NewToolEx(name, description string, parameters any) (*Tool, error) {
66+
t := NewTool(name, description)
67+
if parameters == nil {
68+
return t, nil
69+
}
70+
71+
// Get tool metadata
72+
meta, err := meta.New(parameters, tagParameter)
73+
if err != nil {
74+
return nil, err
75+
}
76+
77+
// Iterate over fields, and add parameters
78+
for _, field := range meta.Fields() {
79+
if err := t.Add(field.Name(), field.Tag("description"), !field.Is("omitempty"), field.Type()); err != nil {
80+
return nil, fmt.Errorf("field %q: %w", field.Name(), err)
81+
}
82+
}
83+
84+
// Return the tool
85+
return t, nil
86+
}
87+
4988
///////////////////////////////////////////////////////////////////////////////
5089
// STRINGIFY
5190

@@ -57,22 +96,42 @@ func (t Tool) String() string {
5796
///////////////////////////////////////////////////////////////////////////////
5897
// PUBLIC METHODS
5998

60-
func (t *Tool) AddParameter(name, description string, required bool) error {
99+
func (tool *Tool) Add(name, description string, required bool, t reflect.Type) error {
61100
if name == "" {
62101
return ErrBadParameter.With("missing name")
63102
}
64-
if _, exists := t.Parameters.Properties[name]; exists {
103+
if _, exists := tool.Parameters.Properties[name]; exists {
65104
return ErrDuplicateEntry.With(name)
66105
}
67-
t.Parameters.Properties[name] = toolParameter{
106+
typ, err := typeOf(t)
107+
if err != nil {
108+
return err
109+
}
110+
tool.Parameters.Properties[name] = toolParameter{
68111
Name: name,
69-
Type: "string",
112+
Type: typ,
70113
Description: description,
71114
}
72115
if required {
73-
t.Parameters.Required = append(t.Parameters.Required, name)
116+
tool.Parameters.Required = append(tool.Parameters.Required, name)
74117
}
75118

76119
// Return success
77120
return nil
78121
}
122+
123+
///////////////////////////////////////////////////////////////////////////////
124+
// PRIVATE METHODS
125+
126+
func typeOf(v reflect.Type) (string, error) {
127+
switch v {
128+
case typeString:
129+
return "string", nil
130+
case typeBool:
131+
return "boolean", nil
132+
case typeInt:
133+
return "integer", nil
134+
default:
135+
return "", ErrBadParameter.Withf("unsupported type %q", v)
136+
}
137+
}

pkg/openai/schema/tool_test.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package schema_test
22

33
import (
4+
"reflect"
45
"testing"
56

67
"github.com/mutablelogic/go-client/pkg/openai/schema"
@@ -11,6 +12,16 @@ func Test_tool_001(t *testing.T) {
1112
assert := assert.New(t)
1213
tool := schema.NewTool("get_stock_price", "Get the current stock price for a given ticker symbol.")
1314
assert.NotNil(tool)
14-
assert.NoError(tool.AddParameter("ticker", "The stock ticker symbol, e.g. AAPL for Apple Inc.", true))
15+
assert.NoError(tool.Add("ticker", "The stock ticker symbol, e.g. AAPL for Apple Inc.", true, reflect.TypeOf("")))
16+
t.Log(tool)
17+
}
18+
19+
func Test_tool_002(t *testing.T) {
20+
assert := assert.New(t)
21+
tool, err := schema.NewToolEx("get_stock_price", "Get the current stock price for a given ticker symbol.", struct {
22+
Ticker string `json:"ticker,omitempty" description:"The stock ticker symbol, e.g. AAPL for Apple Inc."`
23+
}{})
24+
assert.NoError(err)
25+
assert.NotNil(tool)
1526
t.Log(tool)
1627
}

transport.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"io"
99
"mime"
1010
"net/http"
11+
"strings"
1112
"time"
1213

1314
// Packages
@@ -101,15 +102,15 @@ func (transport *logtransport) RoundTrip(req *http.Request) (*http.Response, err
101102
resp.Body = io.NopCloser(bytes.NewReader(body))
102103
defer resp.Body.Close()
103104

104-
switch contentType {
105-
case ContentTypeJson:
105+
switch {
106+
case contentType == ContentTypeJson:
106107
dst := &bytes.Buffer{}
107108
if err := json.Indent(dst, body, " ", " "); err != nil {
108109
fmt.Fprintf(transport.w, " <= %q\n", string(body))
109110
} else {
110111
fmt.Fprintf(transport.w, " <= %v\n", dst.String())
111112
}
112-
case ContentTypeTextPlain:
113+
case strings.HasPrefix(contentType, "text/"):
113114
fmt.Fprintf(transport.w, " <= %q\n", string(body))
114115
default:
115116
fmt.Fprintf(transport.w, " <= (not displaying response of type %q)\n", contentType)

0 commit comments

Comments
 (0)