Skip to content

Commit 98638d2

Browse files
committed
Updates
1 parent d6dd058 commit 98638d2

File tree

8 files changed

+72
-70
lines changed

8 files changed

+72
-70
lines changed

context.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ type Context interface {
1313
// Return the text of the context
1414
Text() string
1515

16-
// Generate a response from a user prompt (with attachments)
16+
// Generate a response from a user prompt (with attachments and
17+
// other empheral options
1718
FromUser(context.Context, string, ...Opt) (Context, error)
1819

19-
// Generate a response from a tool, passing the call identifier or function name, and the result
20+
// Generate a response from a tool, passing the call identifier or
21+
// function name, and the result
2022
FromTool(context.Context, string, any, ...Opt) (Context, error)
2123
}

model.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ type Model interface {
99

1010
// Return am empty session context object for the model,
1111
// setting session options
12-
Context(...Opt) (Context, error)
12+
Context(...Opt) Context
1313

1414
// Convenience method to create a session context object
15-
// with a user prompt, which panics on error
16-
MustUserPrompt(string, ...Opt) Context
15+
// with a user prompt
16+
UserPrompt(string, ...Opt) Context
1717
}

pkg/anthropic/messages.go

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ type reqMessages struct {
6161
opt
6262
}
6363

64-
func (anthropic *Client) Messages(ctx context.Context, model llm.Model, context llm.Context, opts ...llm.Opt) (*Response, error) {
64+
func (anthropic *Client) Messages(ctx context.Context, context llm.Context, opts ...llm.Opt) (*Response, error) {
6565
// Apply options
6666
opt, err := apply(opts...)
6767
if err != nil {
@@ -70,12 +70,12 @@ func (anthropic *Client) Messages(ctx context.Context, model llm.Model, context
7070

7171
// Set max_tokens
7272
if opt.MaxTokens == 0 {
73-
opt.MaxTokens = defaultMaxTokens(model.Name())
73+
opt.MaxTokens = defaultMaxTokens(context.(*session).model.Name())
7474
}
7575

7676
// Request
7777
req, err := client.NewJSONRequest(reqMessages{
78-
Model: model.Name(),
78+
Model: context.(*session).model.Name(),
7979
Messages: context.(*session).seq,
8080
opt: *opt,
8181
})
@@ -222,16 +222,6 @@ func (anthropic *Client) Messages(ctx context.Context, model llm.Model, context
222222
return &response, nil
223223
}
224224

225-
// Generate a response from a prompt
226-
func (anthropic *Client) Generate(ctx context.Context, model llm.Model, context llm.Context, opts ...llm.Opt) (llm.Context, error) {
227-
response, err := anthropic.Messages(ctx, model, context, opts...)
228-
if err != nil {
229-
return nil, err
230-
}
231-
fmt.Println(response)
232-
return nil, llm.ErrNotImplemented
233-
}
234-
235225
///////////////////////////////////////////////////////////////////////////////
236226
// PRIVATE METHODS
237227

pkg/anthropic/messages_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func Test_messages_001(t *testing.T) {
3333
}
3434
defer f.Close()
3535

36-
response, err := client.Messages(context.TODO(), model, client.UserPrompt("what is this image?", anthropic.WithData(f, false, false)))
36+
response, err := client.Messages(context.TODO(), model.UserPrompt("what is this image?", anthropic.WithData(f, false, false)))
3737
if assert.NoError(err) {
3838
t.Log(response)
3939
}
@@ -61,7 +61,7 @@ func Test_messages_002(t *testing.T) {
6161
}
6262
defer f.Close()
6363

64-
response, err := client.Messages(context.TODO(), model, client.UserPrompt("summarize this document for me", anthropic.WithData(f, false, false)))
64+
response, err := client.Messages(context.TODO(), model.UserPrompt("summarize this document for me", anthropic.WithData(f, false, false)))
6565
if assert.NoError(err) {
6666
t.Log(response)
6767
}
@@ -83,7 +83,7 @@ func Test_messages_003(t *testing.T) {
8383
t.FailNow()
8484
}
8585

86-
response, err := client.Messages(context.TODO(), model, client.UserPrompt("why is the sky blue"), anthropic.WithStream(func(r *anthropic.Response) {
86+
response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), anthropic.WithStream(func(r *anthropic.Response) {
8787
t.Log(r)
8888
}))
8989
if assert.NoError(err) {
@@ -114,7 +114,7 @@ func Test_messages_004(t *testing.T) {
114114
t.FailNow()
115115
}
116116

117-
response, err := client.Messages(context.TODO(), model, client.UserPrompt("why is the sky blue"), anthropic.WithTool(weather))
117+
response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), anthropic.WithTool(weather))
118118
if assert.NoError(err) {
119119
t.Log(response)
120120
}
@@ -143,7 +143,7 @@ func Test_messages_005(t *testing.T) {
143143
t.FailNow()
144144
}
145145

146-
response, err := client.Messages(context.TODO(), model, client.UserPrompt("why is the sky blue"), anthropic.WithStream(func(r *anthropic.Response) {
146+
response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), anthropic.WithStream(func(r *anthropic.Response) {
147147
t.Log(r)
148148
}), anthropic.WithTool(weather))
149149
if assert.NoError(err) {

pkg/anthropic/context.go renamed to pkg/anthropic/session.go

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,60 @@
11
package anthropic
22

33
import (
4+
"context"
45
"encoding/json"
56

7+
// Packages
68
llm "github.com/mutablelogic/go-llm"
79
)
810

911
//////////////////////////////////////////////////////////////////
1012
// TYPES
1113

1214
type session struct {
13-
seq []*MessageMeta
15+
model *model
16+
opts []llm.Opt
17+
seq []*MessageMeta
1418
}
1519

1620
var _ llm.Context = (*session)(nil)
1721

1822
///////////////////////////////////////////////////////////////////////////////
1923
// LIFECYCLE
2024

21-
func (*model) Context(...llm.Opt) (llm.Context, error) {
22-
// TODO: Currently ignoring options
23-
return &session{}, nil
25+
// Return am empty session context object for the model,
26+
// setting session options
27+
func (model *model) Context(opts ...llm.Opt) llm.Context {
28+
return &session{
29+
model: model,
30+
opts: opts,
31+
}
32+
}
33+
34+
// Convenience method to create a session context object
35+
// with a user prompt, which panics on error
36+
func (model *model) UserPrompt(prompt string, opts ...llm.Opt) llm.Context {
37+
// Apply attachments
38+
opt, err := apply(opts...)
39+
if err != nil {
40+
panic(err)
41+
}
42+
43+
meta := MessageMeta{
44+
Role: "user",
45+
Content: make([]*Content, 1, len(opt.data)+1),
46+
}
47+
48+
// Append the text
49+
meta.Content[0] = NewTextContent(prompt)
50+
51+
// Append any additional data
52+
for _, data := range opt.data {
53+
meta.Content = append(meta.Content, data)
54+
}
55+
56+
// Return success
57+
return nil
2458
}
2559

2660
///////////////////////////////////////////////////////////////////////////////
@@ -64,32 +98,14 @@ func (session *session) Text() string {
6498
return string(data)
6599
}
66100

67-
// Append user prompt (and attachments) to a context
68-
func (session *session) AppendUserPrompt(text string, opts ...llm.Opt) error {
69-
// Apply attachments
70-
opt, err := apply(opts...)
71-
if err != nil {
72-
return err
73-
}
74-
75-
meta := MessageMeta{
76-
Role: "user",
77-
Content: make([]*Content, 1, len(opt.data)+1),
78-
}
79-
80-
// Append the text
81-
meta.Content[0] = NewTextContent(text)
82-
83-
// Append any additional data
84-
for _, data := range opt.data {
85-
meta.Content = append(meta.Content, data)
86-
}
87-
88-
// Return success
89-
return nil
101+
// Generate a response from a user prompt (with attachments and
102+
// other empheral options
103+
func (session *session) FromUser(context.Context, string, ...llm.Opt) (llm.Context, error) {
104+
return nil, llm.ErrNotImplemented
90105
}
91106

92-
// Append the result of calling a tool to a context
93-
func (session *session) AppendToolResult(string, ...llm.Opt) error {
94-
return llm.ErrNotImplemented
107+
// Generate a response from a tool, passing the call identifier or
108+
// function name, and the result
109+
func (session *session) FromTool(context.Context, string, any, ...llm.Opt) (llm.Context, error) {
110+
return nil, llm.ErrNotImplemented
95111
}

pkg/ollama/chat_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func Test_chat_001(t *testing.T) {
2727

2828
t.Run("ChatStream", func(t *testing.T) {
2929
assert := assert.New(t)
30-
response, err := client.Chat(context.TODO(), model.MustUserPrompt("why is the sky blue?"), ollama.WithStream(func(stream *ollama.Response) {
30+
response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), ollama.WithStream(func(stream *ollama.Response) {
3131
t.Log(stream)
3232
}))
3333
if !assert.NoError(err) {
@@ -38,7 +38,7 @@ func Test_chat_001(t *testing.T) {
3838

3939
t.Run("ChatNoStream", func(t *testing.T) {
4040
assert := assert.New(t)
41-
response, err := client.Chat(context.TODO(), model.MustUserPrompt("why is the sky green?"))
41+
response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky green?"))
4242
if !assert.NoError(err) {
4343
t.FailNow()
4444
}
@@ -63,7 +63,7 @@ func Test_chat_002(t *testing.T) {
6363
t.Run("Tools", func(t *testing.T) {
6464
assert := assert.New(t)
6565
response, err := client.Chat(context.TODO(),
66-
model.MustUserPrompt("what is the weather in berlin?"),
66+
model.UserPrompt("what is the weather in berlin?"),
6767
ollama.WithTool(ollama.MustTool("get_weather", "Return weather conditions in a location", struct {
6868
Location string `help:"Location to get weather for" required:""`
6969
}{})),
@@ -100,7 +100,7 @@ func Test_chat_003(t *testing.T) {
100100
defer f.Close()
101101

102102
response, err := client.Chat(context.TODO(),
103-
model.MustUserPrompt("describe this photo to me", ollama.WithData(f)),
103+
model.UserPrompt("describe this photo to me", ollama.WithData(f)),
104104
)
105105
if !assert.NoError(err) {
106106
t.FailNow()

pkg/ollama/session.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,16 @@ var _ llm.Context = (*session)(nil)
2424
// LIFECYCLE
2525

2626
// Create a new empty context
27-
func (model *model) Context(opts ...llm.Opt) (llm.Context, error) {
27+
func (model *model) Context(opts ...llm.Opt) llm.Context {
2828
return &session{
2929
model: model,
3030
opts: opts,
31-
}, nil
31+
}
3232
}
3333

3434
// Create a new context with a user prompt
35-
func (model *model) MustUserPrompt(prompt string, opts ...llm.Opt) llm.Context {
36-
context, err := model.Context(opts...)
37-
if err != nil {
38-
panic(err)
39-
}
35+
func (model *model) UserPrompt(prompt string, opts ...llm.Opt) llm.Context {
36+
context := model.Context(opts...)
4037
context.(*session).seq = append(context.(*session).seq, &MessageMeta{
4138
Role: "user",
4239
Content: prompt,

pkg/ollama/session_test.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func Test_session_001(t *testing.T) {
2626
// Session with a single user prompt - streaming
2727
t.Run("stream", func(t *testing.T) {
2828
assert := assert.New(t)
29-
session, err := model.Context(ollama.WithStream(func(stream *ollama.Response) {
29+
session := model.Context(ollama.WithStream(func(stream *ollama.Response) {
3030
t.Log("SESSION DELTA", stream)
3131
}))
3232
assert.NotNil(session)
@@ -42,7 +42,7 @@ func Test_session_001(t *testing.T) {
4242
// Session with a single user prompt - not streaming
4343
t.Run("nostream", func(t *testing.T) {
4444
assert := assert.New(t)
45-
session, err := model.Context()
45+
session := model.Context()
4646
assert.NotNil(session)
4747

4848
new_session, err := session.FromUser(context.TODO(), "Why is the sky blue?")
@@ -75,10 +75,7 @@ func Test_session_002(t *testing.T) {
7575
t.FailNow()
7676
}
7777

78-
session, err := model.Context(ollama.WithTool(tool))
79-
if !assert.NoError(err) {
80-
t.FailNow()
81-
}
78+
session := model.Context(ollama.WithTool(tool))
8279
assert.NotNil(session)
8380
new_session, err := session.FromUser(context.TODO(), "What is today's weather?")
8481
if !assert.NoError(err) {

0 commit comments

Comments
 (0)