From 3224288f8a1155da1c1513cf705ef53efe86b98f Mon Sep 17 00:00:00 2001 From: Jackson Tian Date: Fri, 26 Jul 2024 18:17:50 +0800 Subject: [PATCH 1/2] rename URLCredential to URLCredentialsProvider --- credentials/uri_credential.go | 62 ++++++++++++------------------ credentials/uri_credential_test.go | 11 +++--- 2 files changed, 30 insertions(+), 43 deletions(-) diff --git a/credentials/uri_credential.go b/credentials/uri_credential.go index d03006c..335a349 100644 --- a/credentials/uri_credential.go +++ b/credentials/uri_credential.go @@ -12,7 +12,7 @@ import ( ) // URLCredential is a kind of credential -type URLCredential struct { +type URLCredentialsProvider struct { URL string *credentialUpdater *sessionCredential @@ -26,18 +26,18 @@ type URLResponse struct { Expiration string `json:"Expiration" xml:"Expiration"` } -func newURLCredential(URL string) *URLCredential { +func newURLCredential(URL string) *URLCredentialsProvider { credentialUpdater := new(credentialUpdater) if URL == "" { URL = os.Getenv("ALIBABA_CLOUD_CREDENTIALS_URI") } - return &URLCredential{ + return &URLCredentialsProvider{ URL: URL, credentialUpdater: credentialUpdater, } } -func (e *URLCredential) GetCredential() (*CredentialModel, error) { +func (e *URLCredentialsProvider) GetCredential() (*CredentialModel, error) { if e.sessionCredential == nil || e.needUpdateCredential() { err := e.updateCredential() if err != nil { @@ -55,60 +55,48 @@ func (e *URLCredential) GetCredential() (*CredentialModel, error) { // GetAccessKeyId reutrns URLCredential's AccessKeyId // if AccessKeyId is not exist or out of date, the function will update it. -func (e *URLCredential) GetAccessKeyId() (*string, error) { - if e.sessionCredential == nil || e.needUpdateCredential() { - err := e.updateCredential() - if err != nil { - if e.credentialExpiration > (int(time.Now().Unix()) - int(e.lastUpdateTimestamp)) { - return &e.sessionCredential.AccessKeyId, nil - } - return tea.String(""), err - } +func (e *URLCredentialsProvider) GetAccessKeyId() (accessKeyId *string, err error) { + c, err := e.GetCredential() + if err != nil { + return } - return tea.String(e.sessionCredential.AccessKeyId), nil + accessKeyId = c.AccessKeyId + return } // GetAccessSecret reutrns URLCredential's AccessKeySecret // if AccessKeySecret is not exist or out of date, the function will update it. -func (e *URLCredential) GetAccessKeySecret() (*string, error) { - if e.sessionCredential == nil || e.needUpdateCredential() { - err := e.updateCredential() - if err != nil { - if e.credentialExpiration > (int(time.Now().Unix()) - int(e.lastUpdateTimestamp)) { - return &e.sessionCredential.AccessKeySecret, nil - } - return tea.String(""), err - } +func (e *URLCredentialsProvider) GetAccessKeySecret() (accessKeySecret *string, err error) { + c, err := e.GetCredential() + if err != nil { + return } - return tea.String(e.sessionCredential.AccessKeySecret), nil + accessKeySecret = c.AccessKeySecret + return } // GetSecurityToken reutrns URLCredential's SecurityToken // if SecurityToken is not exist or out of date, the function will update it. -func (e *URLCredential) GetSecurityToken() (*string, error) { - if e.sessionCredential == nil || e.needUpdateCredential() { - err := e.updateCredential() - if err != nil { - if e.credentialExpiration > (int(time.Now().Unix()) - int(e.lastUpdateTimestamp)) { - return &e.sessionCredential.SecurityToken, nil - } - return tea.String(""), err - } +func (e *URLCredentialsProvider) GetSecurityToken() (securityToken *string, err error) { + c, err := e.GetCredential() + if err != nil { + return } - return tea.String(e.sessionCredential.SecurityToken), nil + securityToken = c.SecurityToken + return } // GetBearerToken is useless for URLCredential -func (e *URLCredential) GetBearerToken() *string { +func (e *URLCredentialsProvider) GetBearerToken() *string { return tea.String("") } // GetType reutrns URLCredential's type -func (e *URLCredential) GetType() *string { +func (e *URLCredentialsProvider) GetType() *string { return tea.String("credential_uri") } -func (e *URLCredential) updateCredential() (err error) { +func (e *URLCredentialsProvider) updateCredential() (err error) { if e.runtime == nil { e.runtime = new(utils.Runtime) } diff --git a/credentials/uri_credential_test.go b/credentials/uri_credential_test.go index f0fec7d..0b15a45 100644 --- a/credentials/uri_credential_test.go +++ b/credentials/uri_credential_test.go @@ -9,20 +9,19 @@ import ( ) func TestURLCredential_updateCredential(t *testing.T) { - URLCredential := newURLCredential("http://127.0.0.1") + provider := newURLCredential("http://127.0.0.1") hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { return func(req *http.Request) (*http.Response, error) { return mockResponse(300, ``, errors.New("sdk test")) } } - accesskeyId, err := URLCredential.GetAccessKeyId() - // assert.NotNil(t, err) + _, err := provider.GetAccessKeyId() + assert.NotNil(t, err) assert.Equal(t, "refresh Ecs sts token err: sdk test", err.Error()) - assert.Equal(t, "", *accesskeyId) - assert.Equal(t, "credential_uri", *URLCredential.GetType()) + assert.Equal(t, "credential_uri", *provider.GetType()) - cred, err := URLCredential.GetCredential() + cred, err := provider.GetCredential() assert.Equal(t, "refresh Ecs sts token err: sdk test", err.Error()) assert.Nil(t, cred) } From 2b216a12be5c836d23584897a53093366020eb97 Mon Sep 17 00:00:00 2001 From: Jackson Tian Date: Sun, 28 Jul 2024 11:28:55 +0800 Subject: [PATCH 2/2] improve URICredentialsProvider --- credentials/uri_credential.go | 6 +-- credentials/uri_credential_test.go | 80 +++++++++++++++++++++++++++--- 2 files changed, 77 insertions(+), 9 deletions(-) diff --git a/credentials/uri_credential.go b/credentials/uri_credential.go index 335a349..8ac897e 100644 --- a/credentials/uri_credential.go +++ b/credentials/uri_credential.go @@ -105,15 +105,15 @@ func (e *URLCredentialsProvider) updateCredential() (err error) { request.Method = "GET" content, err := doAction(request, e.runtime) if err != nil { - return fmt.Errorf("refresh Ecs sts token err: %s", err.Error()) + return fmt.Errorf("get credentials from %s failed with error: %s", e.URL, err.Error()) } var resp *URLResponse err = json.Unmarshal(content, &resp) if err != nil { - return fmt.Errorf("refresh Ecs sts token err: Json Unmarshal fail: %s", err.Error()) + return fmt.Errorf("get credentials from %s failed with error, json unmarshal fail: %s", e.URL, err.Error()) } if resp.AccessKeyId == "" || resp.AccessKeySecret == "" || resp.SecurityToken == "" || resp.Expiration == "" { - return fmt.Errorf("refresh Ecs sts token err: AccessKeyId: %s, AccessKeySecret: %s, SecurityToken: %s, Expiration: %s", resp.AccessKeyId, resp.AccessKeySecret, resp.SecurityToken, resp.Expiration) + return fmt.Errorf("get credentials failed: AccessKeyId: %s, AccessKeySecret: %s, SecurityToken: %s, Expiration: %s", resp.AccessKeyId, resp.AccessKeySecret, resp.SecurityToken, resp.Expiration) } expirationTime, err := time.Parse("2006-01-02T15:04:05Z", resp.Expiration) diff --git a/credentials/uri_credential_test.go b/credentials/uri_credential_test.go index 0b15a45..47dbc51 100644 --- a/credentials/uri_credential_test.go +++ b/credentials/uri_credential_test.go @@ -8,20 +8,88 @@ import ( "github.com/stretchr/testify/assert" ) -func TestURLCredential_updateCredential(t *testing.T) { +func TestURLCredentialsProvider_updateCredential(t *testing.T) { provider := newURLCredential("http://127.0.0.1") + + origTestHookDo := hookDo + defer func() { hookDo = origTestHookDo }() hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { return func(req *http.Request) (*http.Response, error) { return mockResponse(300, ``, errors.New("sdk test")) } } - _, err := provider.GetAccessKeyId() + + cred, err := provider.GetCredential() assert.NotNil(t, err) - assert.Equal(t, "refresh Ecs sts token err: sdk test", err.Error()) + assert.Equal(t, "get credentials from http://127.0.0.1 failed with error: sdk test", err.Error()) + assert.Nil(t, cred) - assert.Equal(t, "credential_uri", *provider.GetType()) + _, err = provider.GetAccessKeyId() + assert.NotNil(t, err) + assert.Equal(t, "get credentials from http://127.0.0.1 failed with error: sdk test", err.Error()) - cred, err := provider.GetCredential() - assert.Equal(t, "refresh Ecs sts token err: sdk test", err.Error()) + _, err = provider.GetAccessKeySecret() + assert.NotNil(t, err) + assert.Equal(t, "get credentials from http://127.0.0.1 failed with error: sdk test", err.Error()) + + _, err = provider.GetSecurityToken() + assert.NotNil(t, err) + assert.Equal(t, "get credentials from http://127.0.0.1 failed with error: sdk test", err.Error()) + + hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { + return func(req *http.Request) (*http.Response, error) { + return mockResponse(200, `invalid json`, nil) + } + } + + cred, err = provider.GetCredential() + assert.NotNil(t, err) + assert.Equal(t, "get credentials from http://127.0.0.1 failed with error, json unmarshal fail: invalid character 'i' looking for beginning of value", err.Error()) assert.Nil(t, cred) + + hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { + return func(req *http.Request) (*http.Response, error) { + return mockResponse(200, `{}`, nil) + } + } + + cred, err = provider.GetCredential() + assert.NotNil(t, err) + assert.Equal(t, "get credentials failed: AccessKeyId: , AccessKeySecret: , SecurityToken: , Expiration: ", err.Error()) + assert.Nil(t, cred) + + hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { + return func(req *http.Request) (*http.Response, error) { + return mockResponse(200, `{"AccessKeyId":"akid", "AccessKeySecret":"aksecret","SecurityToken":"sts","Expiration":"2006-01-02T15:04:05Z"}`, nil) + } + } + + cred, err = provider.GetCredential() + assert.Nil(t, err) + assert.NotNil(t, cred) + assert.Equal(t, "akid", *cred.AccessKeyId) + assert.Equal(t, "aksecret", *cred.AccessKeySecret) + assert.Equal(t, "sts", *cred.SecurityToken) + + akid, err := provider.GetAccessKeyId() + assert.Nil(t, err) + assert.Equal(t, "akid", *akid) + + aksecret, err := provider.GetAccessKeySecret() + assert.Nil(t, err) + assert.Equal(t, "aksecret", *aksecret) + + sts, err := provider.GetSecurityToken() + assert.Nil(t, err) + assert.Equal(t, "sts", *sts) +} + +func TestURLCredentialsProviderGetBearerToken(t *testing.T) { + provider := newURLCredential("http://127.0.0.1") + assert.Equal(t, "", *provider.GetBearerToken()) +} + +func TestURLCredentialsProviderGetType(t *testing.T) { + provider := newURLCredential("http://127.0.0.1") + assert.Equal(t, "credential_uri", *provider.GetType()) }