Skip to content

feat: support IMDS v2 default for ecs ram role #116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type Config struct {
PublicKeyId *string `json:"public_key_id"`
RoleName *string `json:"role_name"`
EnableIMDSv2 *bool `json:"enable_imds_v2"`
DisableIMDSv1 *bool `json:"disable_imds_v1"`
MetadataTokenDuration *int `json:"metadata_token_duration"`
SessionExpiration *int `json:"session_expiration"`
PrivateKeyFile *string `json:"private_key_file"`
Expand Down Expand Up @@ -248,8 +249,7 @@ func NewCredential(config *Config) (credential Credential, err error) {
case "ecs_ram_role":
provider, err := providers.NewECSRAMRoleCredentialsProviderBuilder().
WithRoleName(tea.StringValue(config.RoleName)).
WithEnableIMDSv2(tea.BoolValue(config.EnableIMDSv2)).
WithMetadataTokenDurationSeconds(tea.IntValue(config.MetadataTokenDuration)).
WithDisableIMDSv1(tea.BoolValue(config.DisableIMDSv1)).
Build()

if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions credentials/credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ this is privatekey`

func TestConfig(t *testing.T) {
config := new(Config)
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.String())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.GoString())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"disable_imds_v1\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.String())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"disable_imds_v1\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.GoString())

config.SetSTSEndpoint("sts.cn-hangzhou.aliyuncs.com")
assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", *config.STSEndpoint)
Expand Down
66 changes: 30 additions & 36 deletions credentials/internal/providers/ecs_ram_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package providers

import (
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
Expand All @@ -13,9 +12,8 @@ import (
)

type ECSRAMRoleCredentialsProvider struct {
roleName string
metadataTokenDurationSeconds int
enableIMDSv2 bool
roleName string
disableIMDSv1 bool
// for sts
session *sessionCredentials
expirationTimestamp int64
Expand All @@ -27,43 +25,31 @@ type ECSRAMRoleCredentialsProviderBuilder struct {

func NewECSRAMRoleCredentialsProviderBuilder() *ECSRAMRoleCredentialsProviderBuilder {
return &ECSRAMRoleCredentialsProviderBuilder{
provider: &ECSRAMRoleCredentialsProvider{
// TBD: 默认启用 IMDS v2
// enableIMDSv2: os.Getenv("ALIBABA_CLOUD_IMDSV2_DISABLED") != "true", // 默认启用 v2
},
provider: &ECSRAMRoleCredentialsProvider{},
}
}

func (builder *ECSRAMRoleCredentialsProviderBuilder) WithMetadataTokenDurationSeconds(metadataTokenDurationSeconds int) *ECSRAMRoleCredentialsProviderBuilder {
builder.provider.metadataTokenDurationSeconds = metadataTokenDurationSeconds
return builder
}

func (builder *ECSRAMRoleCredentialsProviderBuilder) WithRoleName(roleName string) *ECSRAMRoleCredentialsProviderBuilder {
builder.provider.roleName = roleName
return builder
}

func (builder *ECSRAMRoleCredentialsProviderBuilder) WithEnableIMDSv2(enableIMDSv2 bool) *ECSRAMRoleCredentialsProviderBuilder {
builder.provider.enableIMDSv2 = enableIMDSv2
func (builder *ECSRAMRoleCredentialsProviderBuilder) WithDisableIMDSv1(disableIMDSv1 bool) *ECSRAMRoleCredentialsProviderBuilder {
builder.provider.disableIMDSv1 = disableIMDSv1
return builder
}

const defaultMetadataTokenDuration = 21600 // 6 hours

func (builder *ECSRAMRoleCredentialsProviderBuilder) Build() (provider *ECSRAMRoleCredentialsProvider, err error) {

// 设置 roleName 默认值
if builder.provider.roleName == "" {
builder.provider.roleName = os.Getenv("ALIBABA_CLOUD_ECS_METADATA")
}

if builder.provider.metadataTokenDurationSeconds == 0 {
builder.provider.metadataTokenDurationSeconds = defaultMetadataTokenDuration
}

if builder.provider.metadataTokenDurationSeconds < 1 || builder.provider.metadataTokenDurationSeconds > 21600 {
err = errors.New("the metadata token duration seconds must be 1-21600")
return
if !builder.provider.disableIMDSv1 {
builder.provider.disableIMDSv1 = os.Getenv("ALIBABA_CLOUD_IMDSV1_DISABLE") == "true"
}

provider = builder.provider
Expand Down Expand Up @@ -98,11 +84,11 @@ func (provider *ECSRAMRoleCredentialsProvider) getRoleName() (roleName string, e
Headers: map[string]string{},
}

if provider.enableIMDSv2 {
metadataToken, err := provider.getMetadataToken()
if err != nil {
return "", err
}
metadataToken, err := provider.getMetadataToken()
if err != nil {
return "", err
}
if metadataToken != "" {
req.Headers["x-aliyun-ecs-metadata-token"] = metadataToken
}

Expand Down Expand Up @@ -140,11 +126,11 @@ func (provider *ECSRAMRoleCredentialsProvider) getCredentials() (session *sessio
Headers: map[string]string{},
}

if provider.enableIMDSv2 {
metadataToken, err := provider.getMetadataToken()
if err != nil {
return nil, err
}
metadataToken, err := provider.getMetadataToken()
if err != nil {
return nil, err
}
if metadataToken != "" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getMetadataToken 应该抛错吧,如果为空字符串的话

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getMetadataToken中,如果disableIMDSv1开启了,才会抛错,否则忽略报错走v1

req.Headers["x-aliyun-ecs-metadata-token"] = metadataToken
}

Expand Down Expand Up @@ -221,14 +207,22 @@ func (provider *ECSRAMRoleCredentialsProvider) getMetadataToken() (metadataToken
Host: "100.100.100.200",
Path: "/latest/api/token",
Headers: map[string]string{
"X-aliyun-ecs-metadata-token-ttl-seconds": strconv.Itoa(provider.metadataTokenDurationSeconds),
"X-aliyun-ecs-metadata-token-ttl-seconds": strconv.Itoa(defaultMetadataTokenDuration),
},
ConnectTimeout: 5 * time.Second,
ReadTimeout: 5 * time.Second,
}
res, err := httpDo(req)
if err != nil {
err = fmt.Errorf("get metadata token failed: %s", err.Error())
res, _err := httpDo(req)
if _err != nil {
if provider.disableIMDSv1 {
err = fmt.Errorf("get metadata token failed: %s", _err.Error())
}
return
}
if res.StatusCode != 200 {
if provider.disableIMDSv1 {
err = fmt.Errorf("refresh Ecs sts token err, httpStatus: %d, message = %s", res.StatusCode, string(res.Body))
}
return
}
metadataToken = string(res.Body)
Expand Down
62 changes: 54 additions & 8 deletions credentials/internal/providers/ecs_ram_role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package providers

import (
"errors"
"os"
"testing"
"time"

Expand All @@ -13,15 +14,10 @@ func TestNewECSRAMRoleCredentialsProvider(t *testing.T) {
p, err := NewECSRAMRoleCredentialsProviderBuilder().Build()
assert.Nil(t, err)
assert.Equal(t, "", p.roleName)
assert.Equal(t, 21600, p.metadataTokenDurationSeconds)

_, err = NewECSRAMRoleCredentialsProviderBuilder().WithMetadataTokenDurationSeconds(1000000000).Build()
assert.EqualError(t, err, "the metadata token duration seconds must be 1-21600")

p, err = NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("role").WithMetadataTokenDurationSeconds(3600).Build()
p, err = NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("role").Build()
assert.Nil(t, err)
assert.Equal(t, "role", p.roleName)
assert.Equal(t, 3600, p.metadataTokenDurationSeconds)

assert.True(t, p.needUpdateCredential())
}
Expand Down Expand Up @@ -73,7 +69,7 @@ func TestECSRAMRoleCredentialsProvider_getRoleNameWithMetadataV2(t *testing.T) {
originHttpDo := httpDo
defer func() { httpDo = originHttpDo }()

p, err := NewECSRAMRoleCredentialsProviderBuilder().WithEnableIMDSv2(true).Build()
p, err := NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).Build()
assert.Nil(t, err)

// case 1: get metadata token failed
Expand Down Expand Up @@ -281,7 +277,7 @@ func TestECSRAMRoleCredentialsProvider_getCredentialsWithMetadataV2(t *testing.T
originHttpDo := httpDo
defer func() { httpDo = originHttpDo }()

p, err := NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("rolename").WithEnableIMDSv2(true).Build()
p, err := NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).WithRoleName("rolename").Build()
assert.Nil(t, err)

// case 1: get metadata token failed
Expand Down Expand Up @@ -383,9 +379,37 @@ func TestECSRAMRoleCredentialsProvider_getMetadataToken(t *testing.T) {
return
}

_, err = p.getMetadataToken()
assert.Nil(t, err)

p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(false).Build()
assert.Nil(t, err)

_, err = p.getMetadataToken()
assert.Nil(t, err)

os.Setenv("ALIBABA_CLOUD_IMDSV1_DISABLE", "true")
p, err = NewECSRAMRoleCredentialsProviderBuilder().Build()
assert.Nil(t, err)

_, err = p.getMetadataToken()
assert.NotNil(t, err)

os.Setenv("ALIBABA_CLOUD_IMDSV1_DISABLE", "")
p, err = NewECSRAMRoleCredentialsProviderBuilder().Build()
assert.Nil(t, err)

_, err = p.getMetadataToken()
assert.Nil(t, err)

p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).Build()
assert.Nil(t, err)

_, err = p.getMetadataToken()
assert.NotNil(t, err)

assert.Equal(t, "get metadata token failed: mock server error", err.Error())

// case 2: return token
httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
res = &httputil.Response{
Expand All @@ -397,4 +421,26 @@ func TestECSRAMRoleCredentialsProvider_getMetadataToken(t *testing.T) {
metadataToken, err := p.getMetadataToken()
assert.Nil(t, err)
assert.Equal(t, "tokenxxxxx", metadataToken)

// case 3: return 404
p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(false).Build()
assert.Nil(t, err)

httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
res = &httputil.Response{
StatusCode: 404,
Body: []byte("not found"),
}
return
}
metadataToken, err = p.getMetadataToken()
assert.Nil(t, err)
assert.Equal(t, "", metadataToken)

p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).Build()
assert.Nil(t, err)

metadataToken, err = p.getMetadataToken()
assert.NotNil(t, err)
assert.Equal(t, "", metadataToken)
}