Skip to content

Commit d945489

Browse files
committed
Added chat completions
1 parent 33dc4e5 commit d945489

File tree

5 files changed

+486
-173
lines changed

5 files changed

+486
-173
lines changed

pkg/client/transport.go

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ type logtransport struct {
1919
v bool
2020
}
2121

22+
type readwrapper struct {
23+
r io.ReadCloser
24+
data bytes.Buffer
25+
}
26+
2227
///////////////////////////////////////////////////////////////////////////////
2328
// LIFECYCLE
2429

@@ -58,27 +63,72 @@ func (transport *logtransport) RoundTrip(req *http.Request) (*http.Response, err
5863
defer func() {
5964
fmt.Fprintln(transport.w, " Took", time.Since(then).Milliseconds(), "ms")
6065
}()
66+
67+
// Wrap the request
68+
req.Body = &readwrapper{r: req.Body}
69+
70+
// Perform the roundtrip
6171
resp, err := transport.RoundTripper.RoundTrip(req)
6272
if err != nil {
6373
fmt.Fprintln(transport.w, "error:", err)
64-
} else {
65-
fmt.Fprintln(transport.w, "response:", resp.Status)
66-
for k, v := range resp.Header {
67-
fmt.Fprintf(transport.w, " <= %v: %q\n", k, v)
74+
return resp, err
75+
}
76+
77+
// If verbose is switched on, output the payload
78+
if transport.v {
79+
data, err := req.Body.(*readwrapper).as(req.Header.Get("Content-Type"))
80+
if err == nil {
81+
fmt.Fprintln(transport.w, " ", string(data))
6882
}
69-
// If verbose is switched on, read the body
70-
if transport.v && resp.Body != nil {
71-
contentType := resp.Header.Get("Content-Type")
72-
if contentType == ContentTypeJson || contentType == ContentTypeTextPlain {
73-
defer resp.Body.Close()
74-
body, err := io.ReadAll(resp.Body)
75-
if err == nil {
76-
fmt.Fprintln(transport.w, " ", string(body))
77-
}
78-
resp.Body = io.NopCloser(bytes.NewReader(body))
83+
}
84+
85+
fmt.Fprintln(transport.w, "response:", resp.Status)
86+
for k, v := range resp.Header {
87+
fmt.Fprintf(transport.w, " <= %v: %q\n", k, v)
88+
}
89+
90+
// If verbose is switched on, read the body
91+
if transport.v && resp.Body != nil {
92+
contentType := resp.Header.Get("Content-Type")
93+
if contentType == ContentTypeJson || contentType == ContentTypeTextPlain {
94+
defer resp.Body.Close()
95+
body, err := io.ReadAll(resp.Body)
96+
if err == nil {
97+
fmt.Fprintln(transport.w, " ", string(body))
7998
}
99+
resp.Body = io.NopCloser(bytes.NewReader(body))
80100
}
81101
}
82102

103+
// Return success
83104
return resp, err
84105
}
106+
107+
///////////////////////////////////////////////////////////////////////////////
108+
// PRIVATE METHODS
109+
110+
func (w *readwrapper) Read(b []byte) (n int, err error) {
111+
n, err = w.r.Read(b)
112+
if err == nil {
113+
_, err = w.data.Write(b[:n])
114+
}
115+
return n, err
116+
}
117+
118+
func (w *readwrapper) Close() error {
119+
return w.r.Close()
120+
}
121+
122+
func (w *readwrapper) as(mimetype string) ([]byte, error) {
123+
switch mimetype {
124+
case ContentTypeJson:
125+
dest := bytes.NewBuffer(nil)
126+
if err := json.Indent(dest, w.data.Bytes(), " ", " "); err != nil {
127+
return nil, err
128+
} else {
129+
return dest.Bytes(), nil
130+
}
131+
default:
132+
return w.data.Bytes(), nil
133+
}
134+
}

pkg/openai/chat.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,36 @@ const (
99
defaultChatCompletion = "gpt-3.5-turbo"
1010
)
1111

12+
///////////////////////////////////////////////////////////////////////////////
13+
// PUBLIC METHODS
14+
15+
func NewUserMessage(text string) Message {
16+
return Message{
17+
Role: "user", Content: &text,
18+
}
19+
}
20+
21+
func NewSystemMessage(text string) Message {
22+
return Message{
23+
Role: "system", Content: &text,
24+
}
25+
}
26+
27+
func NewAssistantMessage(text string) Message {
28+
return Message{
29+
Role: "assistant", Content: &text,
30+
}
31+
}
32+
1233
///////////////////////////////////////////////////////////////////////////////
1334
// API CALLS
1435

1536
// Chat creates a model response for the given chat conversation.
16-
func (c *Client) Chat(opts ...Opt) (Chat, error) {
37+
func (c *Client) Chat(messages []Message, opts ...Opt) (Chat, error) {
1738
// Create the request
1839
var request reqChat
1940
request.Model = defaultChatCompletion
41+
request.Messages = messages
2042
for _, opt := range opts {
2143
if err := opt(&request); err != nil {
2244
return Chat{}, err

pkg/openai/chat_test.go

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package openai_test
22

33
import (
4+
"encoding/json"
45
"os"
56
"testing"
67

@@ -15,7 +16,49 @@ func Test_chat_001(t *testing.T) {
1516
assert.NoError(err)
1617
assert.NotNil(client)
1718

18-
response, err := client.Chat()
19+
response, err := client.Chat([]openai.Message{
20+
openai.NewUserMessage("What would be the best app to use to get the weather in berlin today?"),
21+
})
1922
assert.NoError(err)
2023
assert.NotNil(response)
24+
assert.NotEmpty(response)
25+
26+
data, err := json.MarshalIndent(response, "", " ")
27+
assert.NoError(err)
28+
t.Log(string(data))
29+
30+
}
31+
32+
func Test_chat_002(t *testing.T) {
33+
assert := assert.New(t)
34+
client, err := openai.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
35+
assert.NoError(err)
36+
assert.NotNil(client)
37+
38+
response, err := client.Chat([]openai.Message{
39+
openai.NewUserMessage("What is the weather in berlin today?"),
40+
}, openai.OptFunction("get_weather", "Get the weather in a specific city and country", openai.ToolParameter{
41+
Name: "city",
42+
Type: "string",
43+
Description: "The city to get the weather for",
44+
Required: true,
45+
}, openai.ToolParameter{
46+
Name: "country",
47+
Type: "string",
48+
Description: "The country to get the weather for",
49+
Required: true,
50+
}, openai.ToolParameter{
51+
Name: "time",
52+
Type: "string",
53+
Description: "When to get the weather for. If not specified, defaults to the current time",
54+
Required: true,
55+
}))
56+
assert.NoError(err)
57+
assert.NotNil(response)
58+
assert.NotEmpty(response)
59+
60+
data, err := json.MarshalIndent(response, "", " ")
61+
assert.NoError(err)
62+
t.Log(string(data))
63+
2164
}

0 commit comments

Comments
 (0)