Skip to content

sec: fix s3 and gcs host checks #512

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion detect_gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (d *GCSDetector) Detect(src, _ string) (string, bool, error) {
return "", false, nil
}

if strings.Contains(src, "googleapis.com/") {
if strings.Contains(src, ".googleapis.com/") {
return d.detectHTTP(src)
}

Expand Down
59 changes: 59 additions & 0 deletions detect_gcs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func TestGCSDetector(t *testing.T) {
"www.googleapis.com/storage/v1/foo/bar.baz",
"gcs::https://www.googleapis.com/storage/v1/foo/bar.baz",
},
{
"www.googleapis.com/storage/v2/foo/bar/toor.baz",
"gcs::https://www.googleapis.com/storage/v2/foo/bar/toor.baz",
},
}

pwd := "/pwd"
Expand All @@ -42,3 +46,58 @@ func TestGCSDetector(t *testing.T) {
}
}
}

// TestGCSDetector_MalformedDetectHTTP runs GCSDetector.Detect over a table of
// well-formed and malformed inputs and verifies the detected source string
// for each of them.
//
// NOTE(review): the Expected error text is compared only when Detect actually
// returns an error. A case that comes back with a nil error still passes as
// long as its output matches, even if Expected is non-empty. Confirm whether
// a missing error should be treated as a failure.
func TestGCSDetector_MalformedDetectHTTP(t *testing.T) {
	type testCase struct {
		Name     string
		Input    string
		Expected string
		Output   string
	}

	table := []testCase{
		{
			Name:     "valid url",
			Input:    "www.googleapis.com/storage/v1/my-bucket/foo/bar",
			Expected: "",
			Output:   "gcs::https://www.googleapis.com/storage/v1/my-bucket/foo/bar",
		},
		{
			Name:     "empty url",
			Input:    "",
			Expected: "",
			Output:   "",
		},
		{
			Name:     "not valid url",
			Input:    "storage/v1/my-bucket/foo/bar",
			Expected: "error parsing GCS URL",
			Output:   "",
		},
		{
			Name:     "not valid url domain",
			Input:    "www.googleapis.com.invalid/storage/v1/",
			Expected: "URL is not a valid GCS URL",
			Output:   "",
		},
		{
			Name:     "not valid url length",
			Input:    "http://www.googleapis.com/storage",
			Expected: "URL is not a valid GCS URL",
			Output:   "",
		},
	}

	const pwd = "/pwd"
	detector := &GCSDetector{}
	for i := range table {
		c := table[i]
		got, _, err := detector.Detect(c.Input, pwd)

		// The error message is only inspected when an error was returned.
		if err != nil && err.Error() != c.Expected {
			t.Fatalf("expected error %s, got %s for %s", c.Expected, err.Error(), c.Name)
		}

		if got != c.Output {
			t.Fatalf("expected %s, got %s for %s", c.Output, got, c.Name)
		}
	}
}
55 changes: 55 additions & 0 deletions detect_s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,58 @@ func TestS3Detector(t *testing.T) {
}
}
}

// TestS3Detector_MalformedDetectHTTP runs S3Detector.Detect over a table of
// well-formed and malformed inputs and verifies the detected source string
// for each of them.
//
// NOTE(review): the Expected error text is compared only when Detect actually
// returns an error. A case that comes back with a nil error still passes as
// long as its output matches, even if Expected is non-empty. Confirm whether
// a missing error should be treated as a failure.
func TestS3Detector_MalformedDetectHTTP(t *testing.T) {
	type testCase struct {
		Name     string
		Input    string
		Expected string
		Output   string
	}

	table := []testCase{
		{
			Name:     "valid url",
			Input:    "s3.amazonaws.com/bucket/foo/bar",
			Expected: "",
			Output:   "s3::https://s3.amazonaws.com/bucket/foo/bar",
		},
		{
			Name:     "empty url",
			Input:    "",
			Expected: "",
			Output:   "",
		},
		{
			Name:     "not valid url",
			Input:    "bucket/foo/bar",
			Expected: "error parsing S3 URL",
			Output:   "",
		},
		{
			Name:     "not valid url domain",
			Input:    "s3.amazonaws.com.invalid/bucket/foo/bar",
			Expected: "error parsing S3 URL",
			Output:   "",
		},
		{
			Name:     "not valid url lenght",
			Input:    "http://s3.amazonaws.com",
			Expected: "URL is not a valid S3 URL",
			Output:   "",
		},
	}

	const pwd = "/pwd"
	detector := &S3Detector{}
	for i := range table {
		c := table[i]
		got, _, err := detector.Detect(c.Input, pwd)

		// The error message is only inspected when an error was returned.
		if err != nil && err.Error() != c.Expected {
			t.Fatalf("expected error %s, got %s for %s", c.Expected, err.Error(), c.Name)
		}

		if got != c.Output {
			t.Fatalf("expected %s, got %s for %s", c.Output, got, c.Name)
		}
	}
}
4 changes: 3 additions & 1 deletion get_gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func (g *GCSGetter) getObject(ctx context.Context, client *storage.Client, dst,
}

func (g *GCSGetter) parseURL(u *url.URL) (bucket, path, fragment string, err error) {
if strings.Contains(u.Host, "googleapis.com") {
if strings.HasSuffix(u.Host, ".googleapis.com") {
hostParts := strings.Split(u.Host, ".")
if len(hostParts) != 3 {
err = fmt.Errorf("URL is not a valid GCS URL")
Expand All @@ -208,6 +208,8 @@ func (g *GCSGetter) parseURL(u *url.URL) (bucket, path, fragment string, err err
bucket = pathParts[3]
path = pathParts[4]
fragment = u.Fragment
} else {
err = fmt.Errorf("URL is not a valid GCS URL")
}
return
}
Expand Down
56 changes: 56 additions & 0 deletions get_gcs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,59 @@ func TestGCSGetter_GetFile_OAuthAccessToken(t *testing.T) {
}
assertContents(t, dst, "# Main\n")
}

// Test_GCSGetter_ParseUrl checks that GCSGetter.parseURL returns no error for
// each URL in the table; the single case uses the www.googleapis.com host.
func Test_GCSGetter_ParseUrl(t *testing.T) {
	cases := []struct {
		name string
		url  string
	}{
		{"valid host", "https://www.googleapis.com/storage/v1/hc-go-getter-test/go-getter/foobar"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The raw string must parse as a URL before parseURL is exercised.
			parsed, parseErr := url.Parse(tc.url)
			if parseErr != nil {
				t.Fatalf("unexpected error: %s", parseErr)
			}

			getter := &GCSGetter{}
			if _, _, _, err := getter.parseURL(parsed); err != nil {
				t.Fatalf("wasn't expecting error, got %s", err)
			}
		})
	}
}
// Test_GCSGetter_ParseUrl_Malformed checks that GCSGetter.parseURL rejects
// URLs whose host is not a googleapis.com host, and that the rejection carries
// the message "URL is not a valid GCS URL".
func Test_GCSGetter_ParseUrl_Malformed(t *testing.T) {
	cases := []struct {
		name string
		url  string
	}{
		{"invalid host suffix", "https://www.googleapis.com.invalid"},
		{"host suffix with a typo", "https://www.googleapi.com."},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The inputs are syntactically valid URLs; only the host is wrong.
			parsed, parseErr := url.Parse(tc.url)
			if parseErr != nil {
				t.Fatalf("unexpected error: %s", parseErr)
			}

			getter := &GCSGetter{}
			_, _, _, err := getter.parseURL(parsed)
			switch {
			case err == nil:
				t.Fatalf("expected error, got none")
			case err.Error() != "URL is not a valid GCS URL":
				t.Fatalf("expected error 'URL is not a valid GCS URL', got %s", err.Error())
			}
		})
	}
}
4 changes: 2 additions & 2 deletions get_s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
// This just check whether we are dealing with S3 or
// any other S3 compliant service. S3 has a predictable
// url as others do not
if strings.Contains(u.Host, "amazonaws.com") {
if strings.HasSuffix(u.Host, ".amazonaws.com") {
// Amazon S3 supports both virtual-hosted–style and path-style URLs to access a bucket, although path-style is deprecated
// In both cases few older regions supports dash-style region indication (s3-Region) even if AWS discourages their use.
// The same bucket could be reached with:
Expand Down Expand Up @@ -304,7 +304,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
path = pathParts[1]

}
if len(hostParts) < 3 && len(hostParts) > 5 {
if len(hostParts) < 3 || len(hostParts) > 5 {
err = fmt.Errorf("URL is not a valid S3 URL")
return
}
Expand Down
36 changes: 25 additions & 11 deletions get_s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,35 +278,49 @@ func TestS3Getter_Url(t *testing.T) {

func Test_S3Getter_ParseUrl_Malformed(t *testing.T) {
tests := []struct {
name string
url string
name string
input string
expected string
}{
{
name: "path style",
url: "https://s3.amazonaws.com/bucket",
name: "path style",
input: "https://s3.amazonaws.com/bucket",
expected: "URL is not a valid S3 URL",
},
{
name: "vhost-style, dash region indication",
url: "https://bucket.s3-us-east-1.amazonaws.com",
name: "vhost-style, dash region indication",
input: "https://bucket.s3-us-east-1.amazonaws.com",
expected: "URL is not a valid S3 URL",
},
{
name: "vhost-style, dot region indication",
url: "https://bucket.s3.us-east-1.amazonaws.com",
name: "vhost-style, dot region indication",
input: "https://bucket.s3.us-east-1.amazonaws.com",
expected: "URL is not a valid S3 URL",
},
{
name: "invalid host parts",
input: "https://invalid.host.parts.lenght.s3.us-east-1.amazonaws.com",
expected: "URL is not a valid S3 URL",
},
{
name: "invalid host suffix",
input: "https://bucket.s3.amazonaws.com.invalid",
expected: "URL is not a valid S3 compliant URL",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := new(S3Getter)
u, err := url.Parse(tt.url)
u, err := url.Parse(tt.input)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
_, _, _, _, _, err = g.parseUrl(u)
if err == nil {
t.Fatalf("expected error, got none")
}
if err.Error() != "URL is not a valid S3 URL" {
t.Fatalf("expected error 'URL is not a valid S3 URL', got %s", err.Error())
if err.Error() != tt.expected {
t.Fatalf("expected error '%s', got %s for %s", tt.expected, err.Error(), tt.name)
}
})
}
Expand Down
Loading