Skip to content

Commit 2da590b

Browse files
committed
Refactor AI functions
1 parent f9a1682 commit 2da590b

File tree

6 files changed

+59
-182
lines changed

6 files changed

+59
-182
lines changed

pkg/ai/cmd_question.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ func CmdQuestion(history []string, chatPrompt string, conf *config.Config) (stri
3434

3535
question := historyQuestion + questionNoHistory
3636

37-
// fmt.Println(question)
38-
3937
messages := []ai_types.ChatMessage{
4038
{
4139
Content: constants.COMMAND_QUESTION_PROMPT,

pkg/ai/commit.go

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,15 @@
11
package ai
22

33
import (
4-
"fmt"
5-
6-
gopenai "github.com/CasualCodersProjects/gopenai"
7-
ai_types "github.com/CasualCodersProjects/gopenai/types"
8-
"github.com/chand1012/ottodocs/pkg/calc"
94
"github.com/chand1012/ottodocs/pkg/config"
105
"github.com/chand1012/ottodocs/pkg/constants"
116
)
127

138
func CommitMessage(diff string, conventional bool, conf *config.Config) (string, error) {
14-
openai := gopenai.NewOpenAI(&gopenai.OpenAIOpts{
15-
APIKey: conf.APIKey,
16-
})
17-
189
sysMessage := constants.GIT_DIFF_PROMPT_STD
1910
if conventional {
2011
sysMessage = constants.GIT_DIFF_PROMPT_CONVENTIONAL
2112
}
2213

23-
messages := []ai_types.ChatMessage{
24-
{
25-
Content: sysMessage,
26-
Role: "system",
27-
},
28-
{
29-
Content: diff,
30-
Role: "user",
31-
},
32-
}
33-
34-
tokens, err := calc.PreciseTokens(messages[0].Content, messages[1].Content)
35-
if err != nil {
36-
return "", fmt.Errorf("could not calculate tokens: %s", err)
37-
}
38-
39-
maxTokens := calc.GetMaxTokens(conf.Model) - tokens
40-
41-
if maxTokens < 0 {
42-
return "", fmt.Errorf("the prompt is too long. max length is %d. Got %d", calc.GetMaxTokens(conf.Model), tokens)
43-
}
44-
45-
req := ai_types.NewDefaultChatRequest("")
46-
req.Messages = messages
47-
req.MaxTokens = maxTokens
48-
req.Model = conf.Model
49-
50-
resp, err := openai.CreateChat(req)
51-
if err != nil {
52-
return "", err
53-
}
54-
55-
message := resp.Choices[0].Message.Content
56-
57-
return message, nil
14+
return request(sysMessage, diff, conf)
5815
}

pkg/ai/markdown.go

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,15 @@
11
package ai
22

33
import (
4-
"fmt"
54
"strings"
65

7-
gopenai "github.com/CasualCodersProjects/gopenai"
8-
ai_types "github.com/CasualCodersProjects/gopenai/types"
9-
"github.com/chand1012/ottodocs/pkg/calc"
106
"github.com/chand1012/ottodocs/pkg/config"
117
"github.com/chand1012/ottodocs/pkg/constants"
128
)
139

1410
func Markdown(filePath, contents, chatPrompt string, conf *config.Config) (string, error) {
15-
openai := gopenai.NewOpenAI(&gopenai.OpenAIOpts{
16-
APIKey: conf.APIKey,
17-
})
1811

1912
question := chatPrompt + "\n\n" + strings.TrimRight(contents, " \n")
2013

21-
messages := []ai_types.ChatMessage{
22-
{
23-
Content: constants.DOCUMENT_MARKDOWN_PROMPT,
24-
Role: "system",
25-
},
26-
{
27-
Content: question,
28-
Role: "user",
29-
},
30-
}
31-
32-
tokens, err := calc.PreciseTokens(messages[0].Content, messages[1].Content)
33-
if err != nil {
34-
return "", fmt.Errorf("could not calculate tokens: %s", err)
35-
}
36-
37-
maxTokens := calc.GetMaxTokens(conf.Model) - tokens
38-
39-
if maxTokens < 0 {
40-
return "", fmt.Errorf("the prompt is too long. max length is %d. Got %d", calc.GetMaxTokens(conf.Model), tokens)
41-
}
42-
43-
req := ai_types.NewDefaultChatRequest("")
44-
req.Messages = messages
45-
req.MaxTokens = maxTokens
46-
req.Model = conf.Model
47-
// lower the temperature to make the model more deterministic
48-
// req.Temperature = 0.3
49-
50-
// ask ChatGPT the question
51-
resp, err := openai.CreateChat(req)
52-
if err != nil {
53-
return "", err
54-
}
55-
56-
message := resp.Choices[0].Message.Content
57-
58-
return message, nil
14+
return request(constants.DOCUMENT_MARKDOWN_PROMPT, question, conf)
5915
}

pkg/ai/question.go

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,14 @@
11
package ai
22

33
import (
4-
"fmt"
54
"strings"
65

7-
gopenai "github.com/CasualCodersProjects/gopenai"
8-
ai_types "github.com/CasualCodersProjects/gopenai/types"
9-
"github.com/chand1012/ottodocs/pkg/calc"
106
"github.com/chand1012/ottodocs/pkg/config"
117
"github.com/chand1012/ottodocs/pkg/constants"
128
)
139

1410
func Question(filePath, fileContent, chatPrompt string, conf *config.Config) (string, error) {
15-
openai := gopenai.NewOpenAI(&gopenai.OpenAIOpts{
16-
APIKey: conf.APIKey,
17-
})
18-
1911
question := "File Name: " + filePath + "\nQuestion: " + chatPrompt + "\n\n" + strings.TrimRight(string(fileContent), " \n") + "\nAnswer:"
2012

21-
messages := []ai_types.ChatMessage{
22-
{
23-
Content: constants.QUESTION_PROMPT,
24-
Role: "system",
25-
},
26-
{
27-
Content: question,
28-
Role: "user",
29-
},
30-
}
31-
32-
tokens, err := calc.PreciseTokens(messages[0].Content, messages[1].Content)
33-
if err != nil {
34-
return "", fmt.Errorf("could not calculate tokens: %s", err)
35-
}
36-
37-
maxTokens := calc.GetMaxTokens(conf.Model) - tokens
38-
39-
if maxTokens < 0 {
40-
return "", fmt.Errorf("the prompt is too long. max length is %d. Got %d", calc.GetMaxTokens(conf.Model), tokens)
41-
}
42-
43-
req := ai_types.NewDefaultChatRequest("")
44-
req.Messages = messages
45-
req.MaxTokens = maxTokens
46-
req.Model = conf.Model
47-
// lower the temperature to make the model more deterministic
48-
// req.Temperature = 0.3
49-
50-
resp, err := openai.CreateChat(req)
51-
if err != nil {
52-
fmt.Printf("Error: %s", err)
53-
return "", err
54-
}
55-
56-
message := resp.Choices[0].Message.Content
57-
58-
return message, nil
13+
return request(constants.QUESTION_PROMPT, question, conf)
5914
}

pkg/ai/req.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package ai
2+
3+
import (
	"errors"
	"fmt"

	gopenai "github.com/CasualCodersProjects/gopenai"
	ai_types "github.com/CasualCodersProjects/gopenai/types"

	"github.com/chand1012/ottodocs/pkg/calc"
	"github.com/chand1012/ottodocs/pkg/config"
)
11+
12+
func request(systemMsg, userMsg string, conf *config.Config) (string, error) {
13+
14+
openai := gopenai.NewOpenAI(&gopenai.OpenAIOpts{
15+
APIKey: conf.APIKey,
16+
})
17+
18+
messages := []ai_types.ChatMessage{
19+
{
20+
Content: systemMsg,
21+
Role: "system",
22+
},
23+
{
24+
Content: userMsg,
25+
Role: "user",
26+
},
27+
}
28+
29+
tokens, err := calc.PreciseTokens(messages[0].Content, messages[1].Content)
30+
if err != nil {
31+
return "", err
32+
}
33+
34+
req := ai_types.NewDefaultChatRequest("")
35+
req.Messages = messages
36+
req.MaxTokens = calc.GetMaxTokens(conf.Model) - tokens
37+
req.Model = conf.Model
38+
39+
if req.MaxTokens < 0 {
40+
return "", errors.New("the prompt is too long")
41+
}
42+
43+
resp, err := openai.CreateChat(req)
44+
if err != nil {
45+
return "", err
46+
}
47+
48+
if len(resp.Choices) == 0 {
49+
return "", errors.New("no choices returned. Check your OpenAI API key")
50+
}
51+
52+
message := resp.Choices[0].Message.Content
53+
54+
return message, nil
55+
}

pkg/ai/single_file.go

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,6 @@ import (
66
"strconv"
77
"strings"
88

9-
gopenai "github.com/CasualCodersProjects/gopenai"
10-
ai_types "github.com/CasualCodersProjects/gopenai/types"
11-
12-
"github.com/chand1012/ottodocs/pkg/calc"
139
"github.com/chand1012/ottodocs/pkg/config"
1410
"github.com/chand1012/ottodocs/pkg/constants"
1511
"github.com/chand1012/ottodocs/pkg/textfile"
@@ -38,10 +34,6 @@ func extractLineNumber(line string) (int, error) {
3834
// Document a file using the OpenAI ChatGPT API. Takes a file path, a prompt, and an API key as arguments.
3935
func SingleFile(filePath, contents, chatPrompt string, conf *config.Config) (string, error) {
4036

41-
openai := gopenai.NewOpenAI(&gopenai.OpenAIOpts{
42-
APIKey: conf.APIKey,
43-
})
44-
4537
fileEnding := filepath.Ext(filePath)
4638

4739
commentOperator, ok := constants.CommentOperators[fileEnding]
@@ -51,47 +43,11 @@ func SingleFile(filePath, contents, chatPrompt string, conf *config.Config) (str
5143

5244
question := chatPrompt + "\n\n" + strings.TrimRight(contents, " \n")
5345

54-
messages := []ai_types.ChatMessage{
55-
{
56-
Content: constants.DOCUMENT_FILE_PROMPT,
57-
Role: "system",
58-
},
59-
{
60-
Content: question,
61-
Role: "user",
62-
},
63-
}
64-
65-
tokens, err := calc.PreciseTokens(messages[0].Content, messages[1].Content)
46+
message, err := request(constants.DOCUMENT_FILE_PROMPT, question, conf)
6647
if err != nil {
67-
return "", fmt.Errorf("could not calculate tokens: %s", err)
68-
}
69-
70-
maxTokens := calc.GetMaxTokens(conf.Model) - tokens
71-
72-
if maxTokens < 0 {
73-
return "", fmt.Errorf("the prompt is too long. max length is %d. Got %d", calc.GetMaxTokens(conf.Model), tokens)
74-
}
75-
76-
req := ai_types.NewDefaultChatRequest("")
77-
req.Messages = messages
78-
req.MaxTokens = maxTokens
79-
req.Model = conf.Model
80-
// lower the temperature to make the model more deterministic
81-
req.Temperature = 0.3
82-
83-
// ask ChatGPT the question
84-
resp, err := openai.CreateChat(req)
85-
if err != nil {
86-
fmt.Printf("Error: %s", err)
8748
return "", err
8849
}
8950

90-
message := resp.Choices[0].Message.Content
91-
92-
// fmt.Println(message)
93-
// fmt.Println("------------------------")
94-
9551
lineNumbers := []int{}
9652
comments := []string{}
9753

0 commit comments

Comments
 (0)