Skip to content

Commit 8c65b35

Browse files
authored
update image api *os.File to io.Reader (#994)
* update image api *os.File to io.Reader * update code style * add reader test * supplementary reader test * update the reader in the form builder test * add commnet * update comment * update code style
1 parent 4d2e7ab commit 8c65b35

File tree

4 files changed

+88
-27
lines changed

4 files changed

+88
-27
lines changed

image.go

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ package openai
33
import (
44
"bytes"
55
"context"
6+
"io"
67
"net/http"
7-
"os"
88
"strconv"
99
)
1010

@@ -134,31 +134,32 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
134134

135135
// ImageEditRequest represents the request structure for the image API.
136136
type ImageEditRequest struct {
137-
Image *os.File `json:"image,omitempty"`
138-
Mask *os.File `json:"mask,omitempty"`
139-
Prompt string `json:"prompt,omitempty"`
140-
Model string `json:"model,omitempty"`
141-
N int `json:"n,omitempty"`
142-
Size string `json:"size,omitempty"`
143-
ResponseFormat string `json:"response_format,omitempty"`
144-
Quality string `json:"quality,omitempty"`
145-
User string `json:"user,omitempty"`
137+
Image io.Reader `json:"image,omitempty"`
138+
Mask io.Reader `json:"mask,omitempty"`
139+
Prompt string `json:"prompt,omitempty"`
140+
Model string `json:"model,omitempty"`
141+
N int `json:"n,omitempty"`
142+
Size string `json:"size,omitempty"`
143+
ResponseFormat string `json:"response_format,omitempty"`
144+
Quality string `json:"quality,omitempty"`
145+
User string `json:"user,omitempty"`
146146
}
147147

148148
// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
149149
func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) (response ImageResponse, err error) {
150150
body := &bytes.Buffer{}
151151
builder := c.createFormBuilder(body)
152152

153-
// image
154-
err = builder.CreateFormFile("image", request.Image)
153+
// image, filename is not required
154+
err = builder.CreateFormFileReader("image", request.Image, "")
155155
if err != nil {
156156
return
157157
}
158158

159159
// mask, it is optional
160160
if request.Mask != nil {
161-
err = builder.CreateFormFile("mask", request.Mask)
161+
// mask, filename is not required
162+
err = builder.CreateFormFileReader("mask", request.Mask, "")
162163
if err != nil {
163164
return
164165
}
@@ -206,12 +207,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
206207

207208
// ImageVariRequest represents the request structure for the image API.
208209
type ImageVariRequest struct {
209-
Image *os.File `json:"image,omitempty"`
210-
Model string `json:"model,omitempty"`
211-
N int `json:"n,omitempty"`
212-
Size string `json:"size,omitempty"`
213-
ResponseFormat string `json:"response_format,omitempty"`
214-
User string `json:"user,omitempty"`
210+
Image io.Reader `json:"image,omitempty"`
211+
Model string `json:"model,omitempty"`
212+
N int `json:"n,omitempty"`
213+
Size string `json:"size,omitempty"`
214+
ResponseFormat string `json:"response_format,omitempty"`
215+
User string `json:"user,omitempty"`
215216
}
216217

217218
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
@@ -220,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
220221
body := &bytes.Buffer{}
221222
builder := c.createFormBuilder(body)
222223

223-
// image
224-
err = builder.CreateFormFile("image", request.Image)
224+
// image, filename is not required
225+
err = builder.CreateFormFileReader("image", request.Image, "")
225226
if err != nil {
226227
return
227228
}

image_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) {
5454
}
5555

5656
mockFailedErr := fmt.Errorf("mock form builder fail")
57-
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
57+
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
5858
return mockFailedErr
5959
}
6060
_, err := client.CreateEditImage(ctx, req)
6161
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
6262

63-
mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error {
63+
mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error {
6464
if name == "mask" {
6565
return mockFailedErr
6666
}
@@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
119119
req := ImageVariRequest{}
120120

121121
mockFailedErr := fmt.Errorf("mock form builder fail")
122-
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
122+
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
123123
return mockFailedErr
124124
}
125125
_, err := client.CreateVariImage(ctx, req)
126126
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
127127

128-
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
128+
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
129129
return nil
130130
}
131131

internal/form_builder.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import (
44
"fmt"
55
"io"
66
"mime/multipart"
7+
"net/textproto"
78
"os"
8-
"path"
9+
"path/filepath"
10+
"strings"
911
)
1012

1113
type FormBuilder interface {
@@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er
3032
return fb.createFormFile(fieldname, file, file.Name())
3133
}
3234

35+
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
36+
37+
func escapeQuotes(s string) string {
38+
return quoteEscaper.Replace(s)
39+
}
40+
41+
// CreateFormFileReader creates a form field with a file reader.
42+
// The filename in parameters can be an empty string.
43+
// The filename in Content-Disposition is required, But it can be an empty string.
3344
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
34-
return fb.createFormFile(fieldname, r, path.Base(filename))
45+
h := make(textproto.MIMEHeader)
46+
h.Set(
47+
"Content-Disposition",
48+
fmt.Sprintf(
49+
`form-data; name="%s"; filename="%s"`,
50+
escapeQuotes(fieldname),
51+
escapeQuotes(filepath.Base(filename)),
52+
),
53+
)
54+
55+
fieldWriter, err := fb.writer.CreatePart(h)
56+
if err != nil {
57+
return err
58+
}
59+
60+
_, err = io.Copy(fieldWriter, r)
61+
if err != nil {
62+
return err
63+
}
64+
65+
return nil
3566
}
3667

3768
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {

internal/form_builder_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
4343
checks.HasError(t, err, "formbuilder should return error if file is closed")
4444
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
4545
}
46+
47+
type failingReader struct {
48+
}
49+
50+
var errMockFailingReaderError = errors.New("mock reader failed")
51+
52+
func (*failingReader) Read([]byte) (int, error) {
53+
return 0, errMockFailingReaderError
54+
}
55+
56+
func TestFormBuilderWithReader(t *testing.T) {
57+
file, err := os.CreateTemp(t.TempDir(), "")
58+
if err != nil {
59+
t.Fatalf("Error creating tmp file: %v", err)
60+
}
61+
defer file.Close()
62+
builder := NewFormBuilder(&failingWriter{})
63+
err = builder.CreateFormFileReader("file", file, file.Name())
64+
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
65+
66+
builder = NewFormBuilder(&bytes.Buffer{})
67+
reader := &failingReader{}
68+
err = builder.CreateFormFileReader("file", reader, "")
69+
checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails")
70+
71+
successReader := &bytes.Buffer{}
72+
err = builder.CreateFormFileReader("file", successReader, "")
73+
checks.NoError(t, err, "formbuilder should not return error")
74+
}

0 commit comments

Comments
 (0)