Skip to content

Commit 28d6a74

Browse files
committed
Add context-aware editing and improve token management in edit command
1 parent 8168275 commit 28d6a74

File tree

3 files changed

+93
-28
lines changed

3 files changed

+93
-28
lines changed

cmd/edit.go

Lines changed: 81 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@ Copyright © 2023 Chandler <chandler@chand1012.dev>
44
package cmd
55

66
import (
7+
"context"
78
"fmt"
89
"os"
910
"strings"
1011

11-
"github.com/chand1012/ottodocs/pkg/ai"
12+
"github.com/chand1012/ottodocs/pkg/calc"
1213
"github.com/chand1012/ottodocs/pkg/config"
1314
"github.com/chand1012/ottodocs/pkg/constants"
1415
"github.com/chand1012/ottodocs/pkg/textfile"
1516
"github.com/chand1012/ottodocs/pkg/utils"
1617
l "github.com/charmbracelet/log"
18+
"github.com/sashabaranov/go-openai"
1719
"github.com/spf13/cobra"
1820
)
1921

@@ -84,11 +86,14 @@ Example: otto edit main.go --start 1 --end 10 --goal "Refactor the function"`,
8486
}
8587
}
8688

89+
var messages []openai.ChatCompletionMessage
90+
var newCode string
91+
8792
var prompt string
8893
if editCode != "" {
89-
prompt = constants.EDIT_CODE_PROMPT + "\nEDIT: " + editCode + "\n\nGOAL: " + chatPrompt + "\n\nFILE: " + filePath + "\n\n" + contents
94+
prompt = "EDIT: " + editCode + "\n\nGOAL: " + chatPrompt + "\n\nFILE: " + filePath + "\n\n" + contents + "\n\nBe sure to only output the edited code, do not print the entire file."
9095
} else {
91-
prompt = constants.EDIT_CODE_PROMPT + "\nGOAL: " + chatPrompt + "\n\nFILE: " + filePath + "\n\n" + contents
96+
prompt = "GOAL: " + chatPrompt + "\n\nFILE: " + filePath + "\n\n" + contents
9297
}
9398

9499
if len(contextFiles) > 0 {
@@ -103,35 +108,87 @@ Example: otto edit main.go --start 1 --end 10 --goal "Refactor the function"`,
103108
}
104109
}
105110

106-
stream, err := ai.SimpleStreamRequest(prompt, c)
107-
if err != nil {
108-
log.Errorf("Error requesting from OpenAI: %s", err)
109-
os.Exit(1)
111+
client := openai.NewClient(c.APIKey)
112+
113+
messages = []openai.ChatCompletionMessage{
114+
{
115+
Role: openai.ChatMessageRoleSystem,
116+
Content: constants.EDIT_CODE_PROMPT,
117+
},
118+
{
119+
Role: openai.ChatMessageRoleUser,
120+
Content: prompt,
121+
},
110122
}
111123

112-
// print the response
113-
utils.PrintColoredTextLn("New Code:", c.OttoColor)
114-
newCode, err := utils.PrintChatCompletionStream(stream)
115-
if err != nil {
116-
log.Errorf("Error printing chat completion stream: %s", err)
117-
os.Exit(1)
118-
}
124+
for {
119125

120-
confirmMsg := "Would you like to overwrite the file with the new code? (y/N): "
121-
if appendFile {
122-
confirmMsg = "Would you like to append the new code to the file? (y/N): "
123-
}
126+
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
127+
Model: c.Model,
128+
Messages: messages,
129+
})
130+
131+
if err != nil {
132+
log.Errorf("Error requesting from OpenAI: %s", err)
133+
os.Exit(1)
134+
}
124135

125-
if !force {
126-
confirm, err := utils.Input(confirmMsg)
136+
// print the response
137+
utils.PrintColoredTextLn("New Code:", c.OttoColor)
138+
newCode, err = utils.PrintChatCompletionStream(stream)
127139
if err != nil {
128-
log.Errorf("Error getting input: %s", err)
140+
log.Errorf("Error printing chat completion stream: %s", err)
129141
os.Exit(1)
130142
}
131143

132-
confirm = strings.ToLower(confirm)
133-
if confirm != "y" && confirm != "yes" {
134-
os.Exit(0)
144+
confirmMsg := "Would you like to write the file with the new code? (y/N). Type your input to keep editing: "
145+
if appendFile {
146+
confirmMsg = "Would you like to append the new code to the file? (y/N). Type your input to keep editing: "
147+
}
148+
149+
if !force {
150+
confirm, err := utils.Input(confirmMsg)
151+
if err != nil {
152+
log.Errorf("Error getting input: %s", err)
153+
os.Exit(1)
154+
}
155+
156+
confirm = strings.ToLower(confirm)
157+
if confirm == "n" || confirm == "no" {
158+
os.Exit(0)
159+
} else if confirm == "y" || confirm == "yes" {
160+
break
161+
} else {
162+
codeTokens, err := calc.PreciseTokens(newCode)
163+
if err != nil {
164+
log.Errorf("Error calculating tokens: %s", err)
165+
os.Exit(1)
166+
}
167+
168+
maxTokens := calc.GetMaxTokens(c.Model) - codeTokens
169+
170+
var newMessages []openai.ChatCompletionMessage
171+
newMessages = []openai.ChatCompletionMessage{
172+
{
173+
Role: openai.ChatMessageRoleUser,
174+
Content: "Use the following input to edit the code: " + confirm,
175+
},
176+
{
177+
Role: openai.ChatMessageRoleAssistant,
178+
Content: newCode,
179+
},
180+
}
181+
utils.ReverseSlice(messages)
182+
for _, message := range messages {
183+
if calc.PreciseTokensFromMessages(newMessages, c.Model) < maxTokens {
184+
newMessages = append(newMessages, message)
185+
}
186+
}
187+
utils.ReverseSlice(newMessages)
188+
messages = newMessages
189+
utils.PrintColoredText("Otto: ", c.OttoColor)
190+
fmt.Println("Ok! Here is the new code, taking your input into account.")
191+
}
135192
}
136193
}
137194

pkg/calc/tokens.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func GetMaxTokens(model string) int {
4646
return 4096
4747
}
4848

49-
func PreciseTokensFromModel(messages []openai.ChatCompletionMessage, model string) (num_tokens int) {
49+
func PreciseTokensFromMessages(messages []openai.ChatCompletionMessage, model string) (num_tokens int) {
5050
tkm, err := tiktoken.EncodingForModel(model)
5151
if err != nil {
5252
err = fmt.Errorf("EncodingForModel: %v", err)

pkg/utils/reverse.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
package utils
22

3+
// ReverseSlice reverses the order of the elements in a slice of any type.
34
func ReverseSlice[T any](s []T) {
4-
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
5-
s[i], s[j] = s[j], s[i]
5+
// Initialize two pointers, left and right, at the beginning and end of the slice.
6+
left := 0
7+
right := len(s) - 1
8+
9+
// Swap elements at the left and right indices until they meet in the middle.
10+
for left < right {
11+
s[left], s[right] = s[right], s[left]
12+
left++
13+
right--
614
}
7-
}
15+
}

0 commit comments

Comments
 (0)