Skip to content

Commit 33dc4e5

Browse files
committed
Added chat
1 parent f908b42 commit 33dc4e5

File tree

4 files changed

+104
-1
lines changed

4 files changed

+104
-1
lines changed

pkg/openai/chat.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package openai
2+
3+
import (
4+
// Packages
5+
"github.com/mutablelogic/go-client/pkg/client"
6+
)
7+
8+
const (
9+
defaultChatCompletion = "gpt-3.5-turbo"
10+
)
11+
12+
///////////////////////////////////////////////////////////////////////////////
13+
// API CALLS
14+
15+
// Chat creates a model response for the given chat conversation.
16+
func (c *Client) Chat(opts ...Opt) (Chat, error) {
17+
// Create the request
18+
var request reqChat
19+
request.Model = defaultChatCompletion
20+
for _, opt := range opts {
21+
if err := opt(&request); err != nil {
22+
return Chat{}, err
23+
}
24+
}
25+
26+
// Return the response
27+
var response Chat
28+
if payload, err := client.NewJSONRequest(request, client.ContentTypeJson); err != nil {
29+
return Chat{}, err
30+
} else if err := c.Do(payload.Post(), &response, client.OptPath("chat/completions")); err != nil {
31+
return Chat{}, err
32+
}
33+
34+
// Return success
35+
return response, nil
36+
}

pkg/openai/chat_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package openai_test
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
opts "github.com/mutablelogic/go-client/pkg/client"
8+
openai "github.com/mutablelogic/go-client/pkg/openai"
9+
assert "github.com/stretchr/testify/assert"
10+
)
11+
12+
func Test_chat_001(t *testing.T) {
13+
assert := assert.New(t)
14+
client, err := openai.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
15+
assert.NoError(err)
16+
assert.NotNil(client)
17+
18+
response, err := client.Chat()
19+
assert.NoError(err)
20+
assert.NotNil(response)
21+
}

pkg/openai/opts.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
type Opt func(Request) error
1616

1717
///////////////////////////////////////////////////////////////////////////////
18-
// CreateEmbedding request options
18+
// setModel
1919

2020
// Set the model identifier
2121
func (req *reqCreateEmbedding) setModel(value string) error {
@@ -27,6 +27,16 @@ func (req *reqCreateEmbedding) setModel(value string) error {
2727
}
2828
}
2929

30+
// Set the model identifier
31+
func (req *reqChat) setModel(value string) error {
32+
if value == "" {
33+
return ErrBadParameter.With("Model")
34+
} else {
35+
req.Model = value
36+
return nil
37+
}
38+
}
39+
3040
///////////////////////////////////////////////////////////////////////////////
3141
// PUBLIC METHODS
3242

pkg/openai/schema.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,33 @@ type Embeddings struct {
3131
} `json:"usage"`
3232
}
3333

34+
// A chat completion object
35+
type Chat struct {
36+
Id string `json:"id"`
37+
Object string `json:"object"`
38+
Created int64 `json:"created"`
39+
Model string `json:"model"`
40+
SystemFingerprint string `json:"system_fingerprint"`
41+
Choices []*MessageChoice `json:"choices"`
42+
Usage struct {
43+
PromptTokens int `json:"prompt_tokens"`
44+
CompletionTokens int `json:"completion_tokens"`
45+
TotalTokens int `json:"total_tokens"`
46+
} `json:"usage"`
47+
}
48+
49+
// A message choice object
50+
type MessageChoice struct {
51+
Index int `json:"index"`
52+
FinishReason string `json:"finish_reason"`
53+
}
54+
55+
// A message choice object
56+
type Message struct {
57+
Role string `json:"role"`
58+
Content string `json:"content"`
59+
}
60+
3461
///////////////////////////////////////////////////////////////////////////////
3562
// REQUESTS
3663

@@ -41,6 +68,15 @@ type reqCreateEmbedding struct {
4168
User string `json:"user,omitempty"`
4269
}
4370

71+
type reqChat struct {
72+
Model string `json:"model"`
73+
Messages []Message `json:"messages"`
74+
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
75+
PresencePenalty float64 `json:"presence_penalty,omitempty"`
76+
MaxTokens int `json:"max_tokens,omitempty"`
77+
Count int `json:"n,omitempty"`
78+
}
79+
4480
///////////////////////////////////////////////////////////////////////////////
4581
// RESPONSES
4682

0 commit comments

Comments
 (0)