diff --git a/credentials/internal/providers/hook.go b/credentials/internal/providers/hook.go index f09e65b..6839abd 100644 --- a/credentials/internal/providers/hook.go +++ b/credentials/internal/providers/hook.go @@ -1,22 +1,7 @@ package providers import ( - "io" - "net/http" - httputil "github.com/aliyun/credentials-go/credentials/internal/http" ) var httpDo = httputil.Do - -type newReuqest func(method, url string, body io.Reader) (*http.Request, error) - -var hookNewRequest = func(fn newReuqest) newReuqest { - return fn -} - -type do func(req *http.Request) (*http.Response, error) - -var hookDo = func(fn do) do { - return fn -} diff --git a/credentials/internal/providers/http.go b/credentials/internal/providers/http.go deleted file mode 100644 index 352ed9f..0000000 --- a/credentials/internal/providers/http.go +++ /dev/null @@ -1,21 +0,0 @@ -package providers - -import ( - "bytes" - "io/ioutil" - "net/http" - "strconv" -) - -func mockResponse(statusCode int, content string) (res *http.Response) { - status := strconv.Itoa(statusCode) - res = &http.Response{ - Proto: "HTTP/1.1", - ProtoMajor: 1, - Header: map[string][]string{"sdk": {"test"}}, - StatusCode: statusCode, - Status: status + " " + http.StatusText(statusCode), - } - res.Body = ioutil.NopCloser(bytes.NewReader([]byte(content))) - return -} diff --git a/credentials/internal/providers/ram_role_arn.go b/credentials/internal/providers/ram_role_arn.go index a1a83c7..119efeb 100644 --- a/credentials/internal/providers/ram_role_arn.go +++ b/credentials/internal/providers/ram_role_arn.go @@ -4,13 +4,13 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" "net/http" "net/url" "strconv" "strings" "time" + httputil "github.com/aliyun/credentials-go/credentials/internal/http" "github.com/aliyun/credentials-go/credentials/internal/utils" ) @@ -157,7 +157,13 @@ func (builder *RAMRoleARNCredentialsProviderBuilder) Build() (provider *RAMRoleA func (provider *RAMRoleARNCredentialsProvider) getCredentials(cc *Credentials) (session *sessionCredentials, err error) { method := "POST" - host := provider.stsEndpoint + req := &httputil.Request{ + Method: method, + Protocol: "https", + Host: provider.stsEndpoint, + Headers: map[string]string{}, + } + queries := make(map[string]string) queries["Version"] = "2015-04-01" queries["Action"] = "AssumeRole" @@ -167,6 +173,7 @@ func (provider *RAMRoleARNCredentialsProvider) getCredentials(cc *Credentials) ( queries["SignatureVersion"] = "1.0" queries["SignatureNonce"] = utils.GetNonce() queries["AccessKeyId"] = cc.AccessKeyId + if cc.SecurityToken != "" { queries["SecurityToken"] = cc.SecurityToken } @@ -181,6 +188,7 @@ func (provider *RAMRoleARNCredentialsProvider) getCredentials(cc *Credentials) ( } bodyForm["RoleSessionName"] = provider.roleSessionName bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds) + req.Form = bodyForm // caculate signature signParams := make(map[string]string) @@ -200,58 +208,30 @@ func (provider *RAMRoleARNCredentialsProvider) getCredentials(cc *Credentials) ( secret := cc.AccessKeySecret + "&" queries["Signature"] = utils.ShaHmac1(stringToSign, secret) - querystring := utils.GetURLFormedMap(queries) - // do request - httpUrl := fmt.Sprintf("https://%s/?%s", host, querystring) - - body := utils.GetURLFormedMap(bodyForm) - - httpRequest, err := hookNewRequest(http.NewRequest)(method, httpUrl, strings.NewReader(body)) - if err != nil { - return - } + req.Queries = queries // set headers - httpRequest.Header["Accept-Encoding"] = []string{"identity"} - httpRequest.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"} - httpRequest.Header["x-credentials-provider"] = []string{cc.ProviderName} - httpClient := &http.Client{} + req.Headers["Accept-Encoding"] = "identity" + req.Headers["Content-Type"] = "application/x-www-form-urlencoded" + req.Headers["x-acs-credentials-provider"] = cc.ProviderName if provider.httpOptions != nil { - httpClient.Timeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Second - proxy := &url.URL{} - if provider.httpOptions.Proxy != "" { - proxy, err = url.Parse(provider.httpOptions.Proxy) - if err != nil { - return - } - } - trans := &http.Transport{} - if proxy != nil && provider.httpOptions.Proxy != "" { - trans.Proxy = http.ProxyURL(proxy) - } - trans.DialContext = utils.Timeout(time.Duration(provider.httpOptions.ConnectTimeout) * time.Second) - httpClient.Transport = trans + req.ConnectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Second + req.ReadTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Second + req.Proxy = provider.httpOptions.Proxy } - httpResponse, err := hookDo(httpClient.Do)(httpRequest) - if err != nil { - return - } - - defer httpResponse.Body.Close() - - responseBody, err := ioutil.ReadAll(httpResponse.Body) + res, err := httpDo(req) if err != nil { return } - if httpResponse.StatusCode != http.StatusOK { - err = errors.New("refresh session token failed: " + string(responseBody)) + if res.StatusCode != http.StatusOK { + err = errors.New("refresh session token failed: " + string(res.Body)) return } var data assumeRoleResponse - err = json.Unmarshal(responseBody, &data) + err = json.Unmarshal(res.Body, &data) if err != nil { err = fmt.Errorf("refresh RoleArn sts token err, json.Unmarshal fail: %s", err.Error()) return diff --git a/credentials/internal/providers/ram_role_arn_test.go b/credentials/internal/providers/ram_role_arn_test.go index f6b9f6f..9551e91 100644 --- a/credentials/internal/providers/ram_role_arn_test.go +++ b/credentials/internal/providers/ram_role_arn_test.go @@ -2,25 +2,14 @@ package providers import ( "errors" - "io" - "io/ioutil" - "net/http" - "strconv" "strings" "testing" "time" + httputil "github.com/aliyun/credentials-go/credentials/internal/http" "github.com/stretchr/testify/assert" ) -type errorReader struct { -} - -func (r *errorReader) Read(p []byte) (n int, err error) { - err = errors.New("read failed") - return -} - func TestNewRAMRoleARNCredentialsProvider(t *testing.T) { // case 1: no credentials provider _, err := NewRAMRoleARNCredentialsProviderBuilder(). @@ -120,6 +109,9 @@ func TestNewRAMRoleARNCredentialsProvider(t *testing.T) { } func TestRAMRoleARNCredentialsProvider_getCredentials(t *testing.T) { + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() + akProvider, err := NewStaticAKCredentialsProviderBuilder(). WithAccessKeyId("akid"). WithAccessKeySecret("aksecret"). @@ -136,104 +128,71 @@ func TestRAMRoleARNCredentialsProvider_getCredentials(t *testing.T) { cc, err := akProvider.GetCredentials() assert.Nil(t, err) - originNewRequest := hookNewRequest - defer func() { hookNewRequest = originNewRequest }() - - // case 1: mock new http request failed - hookNewRequest = func(fn newReuqest) newReuqest { - return func(method, url string, body io.Reader) (*http.Request, error) { - return nil, errors.New("new http request failed") - } - } - _, err = p.getCredentials(cc) - assert.NotNil(t, err) - assert.Equal(t, "new http request failed", err.Error()) - // reset new request - hookNewRequest = originNewRequest - - originDo := hookDo - defer func() { hookDo = originDo }() - - // case 2: server error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - err = errors.New("mock server error") - return - } + // case 1: server error + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + err = errors.New("mock server error") + return } _, err = p.getCredentials(cc) assert.NotNil(t, err) assert.Equal(t, "mock server error", err.Error()) - // case 3: mock read response error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - status := strconv.Itoa(200) - res = &http.Response{ - Proto: "HTTP/1.1", - ProtoMajor: 1, - Header: map[string][]string{}, - StatusCode: 200, - Status: status + " " + http.StatusText(200), - } - res.Body = ioutil.NopCloser(&errorReader{}) - return + // case 2: 4xx error + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 400, + Body: []byte("4xx error"), } + return } - _, err = p.getCredentials(cc) - assert.NotNil(t, err) - assert.Equal(t, "read failed", err.Error()) - // case 4: 4xx error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(400, "4xx error") - return - } - } _, err = p.getCredentials(cc) assert.NotNil(t, err) assert.Equal(t, "refresh session token failed: 4xx error", err.Error()) - // case 5: invalid json - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, "invalid json") - return + // case 3: invalid json + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("invalid json"), } + return } _, err = p.getCredentials(cc) assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err, json.Unmarshal fail: invalid character 'i' looking for beginning of value", err.Error()) - // case 6: empty response json - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, "null") - return + // case 4: empty response json + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("null"), } + return } _, err = p.getCredentials(cc) assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err, fail to get credentials", err.Error()) - // case 7: empty session ak response json - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, `{"Credentials": {}}`) - return + // case 5: empty session ak response json + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"Credentials": {}}`), } + return } _, err = p.getCredentials(cc) assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err, fail to get credentials", err.Error()) - // case 8: mock ok value - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, `{"Credentials": {"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token"}}`) - return + // case 6: mock ok value + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"Credentials": {"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token"}}`), } + return } creds, err := p.getCredentials(cc) assert.Nil(t, err) @@ -252,8 +211,8 @@ func TestRAMRoleARNCredentialsProvider_getCredentials(t *testing.T) { } func TestRAMRoleARNCredentialsProvider_getCredentialsWithRequestCheck(t *testing.T) { - originDo := hookDo - defer func() { hookDo = originDo }() + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() stsProvider, err := NewStaticSTSCredentialsProviderBuilder(). WithAccessKeyId("akid"). @@ -273,21 +232,16 @@ func TestRAMRoleARNCredentialsProvider_getCredentialsWithRequestCheck(t *testing assert.Nil(t, err) // case 1: server error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - assert.Equal(t, "sts.cn-beijing.aliyuncs.com", req.Host) - assert.Contains(t, req.URL.String(), "SecurityToken=ststoken") - body, err := ioutil.ReadAll(req.Body) - assert.Nil(t, err) - bodyString := string(body) - assert.Contains(t, bodyString, "Policy=policy") - assert.Contains(t, bodyString, "RoleArn=roleArn") - assert.Contains(t, bodyString, "RoleSessionName=rsn") - assert.Contains(t, bodyString, "DurationSeconds=1000") - - err = errors.New("mock server error") - return - } + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + assert.Equal(t, "sts.cn-beijing.aliyuncs.com", req.Host) + assert.Equal(t, "ststoken", req.Queries["SecurityToken"]) + assert.Equal(t, "policy", req.Form["Policy"]) + assert.Equal(t, "roleArn", req.Form["RoleArn"]) + assert.Equal(t, "rsn", req.Form["RoleSessionName"]) + assert.Equal(t, "1000", req.Form["DurationSeconds"]) + + err = errors.New("mock server error") + return } cc, err := stsProvider.GetCredentials() @@ -310,8 +264,8 @@ func (p *errorCredentialsProvider) GetProviderName() string { } func TestRAMRoleARNCredentialsProviderGetCredentials(t *testing.T) { - originDo := hookDo - defer func() { hookDo = originDo }() + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() // case 0: get previous credentials failed p, err := NewRAMRoleARNCredentialsProviderBuilder(). @@ -339,33 +293,33 @@ func TestRAMRoleARNCredentialsProviderGetCredentials(t *testing.T) { assert.Nil(t, err) // case 1: get credentials failed - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - err = errors.New("mock server error") - return - } + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + err = errors.New("mock server error") + return } _, err = p.GetCredentials() assert.NotNil(t, err) assert.Equal(t, "mock server error", err.Error()) // case 2: get invalid expiration - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, `{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"invalidexpiration","SecurityToken":"ststoken"}}`) - return + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"invalidexpiration","SecurityToken":"ststoken"}}`), } + return } _, err = p.GetCredentials() assert.NotNil(t, err) assert.Equal(t, "parsing time \"invalidexpiration\" as \"2006-01-02T15:04:05Z\": cannot parse \"invalidexpiration\" as \"2006\"", err.Error()) // case 3: happy result - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, `{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"ststoken"}}`) - return + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"ststoken"}}`), } + return } cc, err := p.GetCredentials() assert.Nil(t, err) @@ -401,3 +355,26 @@ func TestRAMRoleARNCredentialsProviderGetCredentialsWithError(t *testing.T) { assert.NotNil(t, err) assert.Contains(t, err.Error(), "InvalidAccessKeyId.NotFound") } + +func TestRAMRoleARNCredentialsProviderWithHttpOptions(t *testing.T) { + akProvider, err := NewStaticAKCredentialsProviderBuilder(). + WithAccessKeyId("akid"). + WithAccessKeySecret("aksecret"). + Build() + assert.Nil(t, err) + p, err := NewRAMRoleARNCredentialsProviderBuilder(). + WithCredentialsProvider(akProvider). + WithRoleArn("roleArn"). + WithRoleSessionName("rsn"). + WithDurationSeconds(1000). + WithHttpOptions(&HttpOptions{ + ConnectTimeout: 1, + ReadTimeout: 1, + Proxy: "localhost:3999", + }). + Build() + assert.Nil(t, err) + _, err = p.GetCredentials() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "proxyconnect tcp:") +}