Skip to content

Commit aa149c1

Browse files
authored
add optional params for audio api, e.g. prompt (#183)
* Compatible with the situation where the mask is empty in CreateEditImage. * Fix the test for the unnecessary removal of the mask.png file. * add image variation implementation * fix image variation bugs * fix ci-lint problem with max line character limit * add offitial doc link * just for codeball test * fix lint problem * add optional params for audio api, e.g. prompt * add comment for new args in translation
1 parent d529d13 commit aa149c1

File tree

2 files changed

+100
-2
lines changed

2 files changed

+100
-2
lines changed

audio.go

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@ const (
1616
)
1717

1818
// AudioRequest represents a request structure for audio API.
19+
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
1920
type AudioRequest struct {
20-
Model string
21-
FilePath string
21+
Model string
22+
FilePath string
23+
Prompt string // For translation, it should be in English
24+
Temperature float32
25+
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
2226
}
2327

2428
// AudioResponse represents a response structure for audio API.
@@ -94,6 +98,47 @@ func audioMultipartForm(request AudioRequest, w *multipart.Writer) error {
9498
if _, err = io.Copy(fw, modelName); err != nil {
9599
return fmt.Errorf("writing model name: %w", err)
96100
}
101+
102+
// Create a form field for the prompt (if provided)
103+
if request.Prompt != "" {
104+
fw, err = w.CreateFormField("prompt")
105+
if err != nil {
106+
return fmt.Errorf("creating form field: %w", err)
107+
}
108+
109+
prompt := bytes.NewReader([]byte(request.Prompt))
110+
if _, err = io.Copy(fw, prompt); err != nil {
111+
return fmt.Errorf("writing prompt: %w", err)
112+
}
113+
}
114+
115+
// Create a form field for the temperature (if provided)
116+
if request.Temperature != 0 {
117+
fw, err = w.CreateFormField("temperature")
118+
if err != nil {
119+
return fmt.Errorf("creating form field: %w", err)
120+
}
121+
122+
temperature := bytes.NewReader([]byte(fmt.Sprintf("%.2f", request.Temperature)))
123+
if _, err = io.Copy(fw, temperature); err != nil {
124+
return fmt.Errorf("writing temperature: %w", err)
125+
}
126+
}
127+
128+
// Create a form field for the language (if provided)
129+
if request.Language != "" {
130+
fw, err = w.CreateFormField("language")
131+
if err != nil {
132+
return fmt.Errorf("creating form field: %w", err)
133+
}
134+
135+
language := bytes.NewReader([]byte(request.Language))
136+
if _, err = io.Copy(fw, language); err != nil {
137+
return fmt.Errorf("writing language: %w", err)
138+
}
139+
}
140+
141+
// Close the multipart writer
97142
w.Close()
98143

99144
return nil

audio_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,59 @@ func TestAudio(t *testing.T) {
6969
}
7070
}
7171

72+
func TestAudioWithOptionalArgs(t *testing.T) {
73+
server := test.NewTestServer()
74+
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
75+
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
76+
// create the test server
77+
var err error
78+
ts := server.OpenAITestServer()
79+
ts.Start()
80+
defer ts.Close()
81+
82+
config := DefaultConfig(test.GetTestToken())
83+
config.BaseURL = ts.URL + "/v1"
84+
client := NewClientWithConfig(config)
85+
86+
testcases := []struct {
87+
name string
88+
createFn func(context.Context, AudioRequest) (AudioResponse, error)
89+
}{
90+
{
91+
"transcribe",
92+
client.CreateTranscription,
93+
},
94+
{
95+
"translate",
96+
client.CreateTranslation,
97+
},
98+
}
99+
100+
ctx := context.Background()
101+
102+
dir, cleanup := createTestDirectory(t)
103+
defer cleanup()
104+
105+
for _, tc := range testcases {
106+
t.Run(tc.name, func(t *testing.T) {
107+
path := filepath.Join(dir, "fake.mp3")
108+
createTestFile(t, path)
109+
110+
req := AudioRequest{
111+
FilePath: path,
112+
Model: "whisper-3",
113+
Prompt: "用简体中文",
114+
Temperature: 0.5,
115+
Language: "zh",
116+
}
117+
_, err = tc.createFn(ctx, req)
118+
if err != nil {
119+
t.Fatalf("audio API error: %v", err)
120+
}
121+
})
122+
}
123+
}
124+
72125
// createTestFile creates a fake file with "hello" as the content.
73126
func createTestFile(t *testing.T, path string) {
74127
file, err := os.Create(path)

0 commit comments

Comments
 (0)