diff --git a/changes_revision.go b/changes_revision.go index 719e1bb..8950635 100644 --- a/changes_revision.go +++ b/changes_revision.go @@ -2,6 +2,7 @@ package gerrit import ( "context" + "errors" "fmt" "net/url" ) @@ -140,6 +141,35 @@ type PatchOptions struct { Path string `url:"path,omitempty"` } +// StringPointerWriter is a new type based on *string. +// We would like it to implement Writer interface so that we can avoid +// unmarshalling of non JSON responses by Client.Do +type StringPointerWriter struct { + Target *string +} + +func NewStringPointerWriter(target *string) (*StringPointerWriter, error) { + if target == nil { + return nil, errors.New("StringPointerWriter: target *string cannot be nil") + } + return &StringPointerWriter{Target: target}, nil +} + +// Write implements the io.Writer interface for *StringPointerWriter. +func (spw *StringPointerWriter) Write(p []byte) (n int, err error) { + // Check if the StringPointerWriter pointer itself is nil, or if its Target is nil. + if spw == nil || spw.Target == nil { + return 0, errors.New("StringPointerWriter: receiver or target *string is nil, cannot write") + } + + // *(spw.Target) gives us the actual string value. + // We append the new data (converted from []byte to string) to it. + *(spw.Target) = *(spw.Target) + string(p) + + // Return the number of bytes written and no error. + return len(p), nil +} + // GetDiff gets the diff of a file from a certain revision. // // Gerrit API docs: https://gerrit-review.googlesource.com/Documentation/rest-api-changes.html#get-diff @@ -445,8 +475,17 @@ func (s *ChangesService) GetPatch(ctx context.Context, changeID, revisionID stri return nil, nil, err } - v := new(string) - resp, err := s.client.Do(req, v) + strVal := "" + var v *string = &strVal + + // Create an instance of our writer struct since the /patch endpoint + // returns a base64 encoded string which cannot be marshalled as JSON. + stringWriter, err := NewStringPointerWriter(v) + if err != nil { + return nil, nil, err + } + + resp, err := s.client.Do(req, stringWriter) if err != nil { return nil, resp, err } diff --git a/changes_revision_test.go b/changes_revision_test.go index 93d0d98..4eaf85e 100644 --- a/changes_revision_test.go +++ b/changes_revision_test.go @@ -2,6 +2,7 @@ package gerrit_test import ( "context" + "encoding/base64" "fmt" "net/http" "net/http/httptest" @@ -82,3 +83,41 @@ func TestChangesService_ListFilesReviewed(t *testing.T) { t.Errorf("client.Changes.ListFilesReviewed:\ngot: %q\nwant: %q", got, want) } } + +func TestChangesService_GetPatch(t *testing.T) { + rawPatch := `diff --git a/COMMIT_MSG b/COMMIT_MSG +index 123..456 100644 +--- a/COMMIT_MSG ++++ b/COMMIT_MSG +@@ -1,1 +1,1 @@ +-Old subject ++New subject for A +diff --git a/fileA.txt b/fileA.txt +new file mode 100644 +index 0000000..abc 100644 +--- /dev/null ++++ b/fileA.txt +@@ -0,0 +1 @@ ++Content for A +` + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got, want := r.URL.String(), "/changes/123/revisions/456/patch"; got != want { + t.Errorf("request URL:\ngot: %q\nwant: %q", got, want) + } + encodedPatchContent := base64.StdEncoding.EncodeToString( + []byte(rawPatch)) + fmt.Fprint(w, encodedPatchContent) + })) + defer ts.Close() + + ctx := context.Background() + client := newClient(ctx, t, ts) + got, _, err := client.Changes.GetPatch(ctx, "123", "456", nil) + if err != nil { + t.Fatal(err) + } + want := base64.StdEncoding.EncodeToString([]byte(rawPatch)) + if !reflect.DeepEqual(*got, want) { + t.Errorf("client.Changes.GetPatch:\ngot: %q\nwant: %q", *got, want) + } +}