Skip to content

Commit cb85103

Browse files
authored
x/ref/runtime/internal: use aws imds v1 first and then v2 (#174)
AWS' instance meta data service cannot be accessed from a docker container running in bridge mode. This PR changes the behaviour to use the v1 service first and only if that fails, to use the v2 service. This fallback is required since it's possible to configure aws instances to allow v2 only.
1 parent 047b918 commit cb85103

File tree

5 files changed

+80
-34
lines changed

5 files changed

+80
-34
lines changed

x/ref/runtime/internal/cloudvm.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ type asyncChooser struct {
6060
func (ac *asyncChooser) ChooseAddresses(protocol string, candidates []net.Addr) ([]net.Addr, error) {
6161
select {
6262
case <-ac.ch:
63+
if cvmErr != nil {
64+
return nil, cvmErr
65+
}
6366
return cvm.ChooseAddresses(protocol, candidates)
6467
case <-ac.ctx.Done():
6568
return nil, ac.ctx.Err()
@@ -115,7 +118,7 @@ func newCloudVM(ctx context.Context, logger logging.Logger, fl *flags.Virtualize
115118

116119
switch fl.VirtualizationProvider.Get().(flags.VirtualizationProvider) {
117120
case flags.AWS:
118-
if !cloudvm.OnAWS(ctx, time.Second) {
121+
if !cloudvm.OnAWS(ctx, cvm.logger, time.Second) {
119122
if fl.DissallowNativeFallback {
120123
return nil, fmt.Errorf("this process is not running on AWS even though its command line says it is")
121124
}

x/ref/runtime/internal/cloudvm/aws.go

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"sync"
1818
"time"
1919

20+
"v.io/v23/logging"
2021
"v.io/x/ref/lib/stats"
2122
"v.io/x/ref/runtime/internal/cloudvm/cloudpaths"
2223
)
@@ -56,65 +57,88 @@ const (
5657
)
5758

5859
var (
59-
onceAWS sync.Once
60-
onAWS bool
60+
onceAWS sync.Once
61+
onAWS bool
62+
onIMDSv2 bool
6163
)
6264

6365
// OnAWS returns true if this process is running on Amazon Web Services.
6466
// If true, the the stats variables AWSAccountIDStatName and GCPRegionStatName
6567
// are set.
66-
func OnAWS(ctx context.Context, timeout time.Duration) bool {
68+
func OnAWS(ctx context.Context, logger logging.Logger, timeout time.Duration) bool {
6769
onceAWS.Do(func() {
68-
onAWS = awsInit(ctx, timeout)
70+
onAWS, onIMDSv2 = awsInit(ctx, logger, timeout)
71+
logger.VI(1).Infof("OnAWS: onAWS: %v, onIMDSv2: %v", onAWS, onIMDSv2)
6972
})
7073
return onAWS
7174
}
7275

7376
// AWSPublicAddrs returns the current public IP of this AWS instance.
77+
// Must be called after OnAWS.
7478
func AWSPublicAddrs(ctx context.Context, timeout time.Duration) ([]net.Addr, error) {
75-
return awsGetAddr(ctx, awsExternalURL(), timeout)
79+
return awsGetAddr(ctx, onIMDSv2, awsExternalURL(), timeout)
7680
}
7781

7882
// AWSPrivateAddrs returns the current private Addrs of this AWS instance.
83+
// Must be called after OnAWS.
7984
func AWSPrivateAddrs(ctx context.Context, timeout time.Duration) ([]net.Addr, error) {
80-
return awsGetAddr(ctx, awsInternalURL(), timeout)
85+
return awsGetAddr(ctx, onIMDSv2, awsInternalURL(), timeout)
8186
}
8287

83-
func awsGet(ctx context.Context, url string, timeout time.Duration) ([]byte, error) {
88+
func awsGet(ctx context.Context, imdsv2 bool, url string, timeout time.Duration) ([]byte, error) {
8489
client := &http.Client{Timeout: timeout}
85-
token, err := awsSetIMDSv2Token(ctx, awsTokenURL(), timeout)
86-
if err != nil {
87-
return nil, err
90+
var token string
91+
var err error
92+
if imdsv2 {
93+
token, err = awsSetIMDSv2Token(ctx, awsTokenURL(), timeout)
94+
if err != nil {
95+
return nil, err
96+
}
8897
}
8998
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
90-
req.Header.Add("X-aws-ec2-metadata-token", token)
9199
if err != nil {
92100
return nil, err
93101
}
102+
if len(token) > 0 {
103+
req.Header.Add("X-aws-ec2-metadata-token", token)
104+
}
94105
resp, err := client.Do(req)
95106
if err != nil {
96107
return nil, err
97108
}
98109
defer resp.Body.Close()
99110
if resp.StatusCode != 200 {
100-
return nil, err
111+
return nil, fmt.Errorf("HTTP Error: %v %v", url, resp.StatusCode)
101112
}
102113
if server := resp.Header["Server"]; len(server) != 1 || server[0] != "EC2ws" {
103114
return nil, fmt.Errorf("wrong headers")
104115
}
105116
return ioutil.ReadAll(resp.Body)
106117
}
107118

108-
// awsInit returns true if it can access AWS project metadata. It also
119+
// awsInit returns true if it can access AWS project metadata and the version
120+
// of the metadata service it was able to access. It also
109121
// creates two stats variables with the account ID and zone.
110-
func awsInit(ctx context.Context, timeout time.Duration) bool {
111-
body, err := awsGet(ctx, awsIdentityDocURL(), timeout)
122+
func awsInit(ctx context.Context, logger logging.Logger, timeout time.Duration) (bool, bool) {
123+
v2 := false
124+
// Try the v1 service first since it should always work unless v2
125+
// is specifically configured (and hence v1 is disabled), in which
126+
// case the expectation is that it fails fast with a 4xx HTTP error.
127+
body, err := awsGet(ctx, false, awsIdentityDocURL(), timeout)
112128
if err != nil {
113-
return false
129+
logger.VI(1).Infof("failed to access v1 metadata service: %v", err)
130+
// can't access v1, try v2.
131+
body, err = awsGet(ctx, true, awsIdentityDocURL(), timeout)
132+
if err != nil {
133+
logger.VI(1).Infof("failed to access v2 metadata service: %v", err)
134+
return false, false
135+
}
136+
v2 = true
114137
}
115138
doc := map[string]interface{}{}
116139
if err := json.Unmarshal(body, &doc); err != nil {
117-
return false
140+
logger.VI(1).Infof("failed to unmarshal metadata service response: %s: %v", body, err)
141+
return false, false
118142
}
119143
found := 0
120144
for _, v := range []struct {
@@ -130,11 +154,11 @@ func awsInit(ctx context.Context, timeout time.Duration) bool {
130154
}
131155
}
132156
}
133-
return found == 2
157+
return found == 2, v2
134158
}
135159

136-
func awsGetAddr(ctx context.Context, url string, timeout time.Duration) ([]net.Addr, error) {
137-
body, err := awsGet(ctx, url, timeout)
160+
func awsGetAddr(ctx context.Context, imdsv2 bool, url string, timeout time.Duration) ([]net.Addr, error) {
161+
body, err := awsGet(ctx, imdsv2, url, timeout)
138162
if err != nil {
139163
return nil, err
140164
}

x/ref/runtime/internal/cloudvm/aws_test.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,30 @@ import (
99
"testing"
1010
"time"
1111

12+
"v.io/x/ref/internal/logger"
1213
"v.io/x/ref/runtime/internal/cloudvm/cloudpaths"
1314
"v.io/x/ref/runtime/internal/cloudvm/cloudvmtest"
1415
)
1516

16-
func startAWSMetadataServer(t *testing.T) (string, func()) {
17-
host, close := cloudvmtest.StartAWSMetadataServer(t)
17+
func startAWSMetadataServer(t *testing.T, imdsv2Only bool) (string, func()) {
18+
host, close := cloudvmtest.StartAWSMetadataServer(t, imdsv2Only)
1819
SetAWSMetadataHost(host)
1920
return host, close
2021
}
2122

2223
func TestAWS(t *testing.T) {
24+
testAWSIDMSVersion(t, false)
25+
testAWSIDMSVersion(t, true)
26+
}
27+
28+
func testAWSIDMSVersion(t *testing.T, imdsv2Only bool) {
2329
ctx := context.Background()
24-
host, stop := startAWSMetadataServer(t)
30+
host, stop := startAWSMetadataServer(t, imdsv2Only)
2531
defer stop()
2632

27-
if got, want := OnAWS(ctx, time.Second), true; got != want {
33+
logger := logger.NewLogger("test")
34+
35+
if got, want := OnAWS(ctx, logger, time.Second), true; got != want {
2836
t.Errorf("got %v, want %v", got, want)
2937
}
3038

@@ -45,8 +53,9 @@ func TestAWS(t *testing.T) {
4553
if got, want := pub[0].String(), cloudvmtest.WellKnownPublicIP; got != want {
4654
t.Errorf("got %v, want %v", got, want)
4755
}
56+
4857
externalURL := host + cloudpaths.AWSPublicIPPath + "/noip"
49-
noip, err := awsGetAddr(ctx, externalURL, time.Second)
58+
noip, err := awsGetAddr(ctx, imdsv2Only, externalURL, time.Second)
5059
if err != nil {
5160
t.Fatal(err)
5261
}

x/ref/runtime/internal/cloudvm/cloudvmtest/aws_mock.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ import (
1515
"v.io/x/ref/runtime/internal/cloudvm/cloudpaths"
1616
)
1717

18-
func StartAWSMetadataServer(t *testing.T) (string, func()) {
18+
func StartAWSMetadataServer(t *testing.T, imdsv2Only bool) (string, func()) {
1919
l, err := net.Listen("tcp", "127.0.0.1:0")
2020
if err != nil {
2121
t.Fatal(err)
2222
}
2323
var token string
24-
http.HandleFunc(cloudpaths.AWSTokenPath, func(w http.ResponseWriter, req *http.Request) {
24+
mux := &http.ServeMux{}
25+
mux.HandleFunc(cloudpaths.AWSTokenPath, func(w http.ResponseWriter, req *http.Request) {
2526
token = time.Now().String()
2627
w.Header().Add("Server", "EC2ws")
2728
fmt.Fprint(w, token)
@@ -32,7 +33,13 @@ func StartAWSMetadataServer(t *testing.T) (string, func()) {
3233
return requestToken == token
3334
}
3435

35-
http.HandleFunc(cloudpaths.AWSIdentityDocPath, func(w http.ResponseWriter, r *http.Request) {
36+
mux.HandleFunc(cloudpaths.AWSIdentityDocPath, func(w http.ResponseWriter, r *http.Request) {
37+
if imdsv2Only {
38+
if len(r.Header.Get("X-aws-ec2-metadata-token")) == 0 {
39+
w.WriteHeader(http.StatusUnauthorized)
40+
return
41+
}
42+
}
3643
if !validSession(r) {
3744
w.WriteHeader(http.StatusForbidden)
3845
return
@@ -58,19 +65,22 @@ func StartAWSMetadataServer(t *testing.T) (string, func()) {
5865
fmt.Fprintf(w, format, args...)
5966
}
6067

61-
http.HandleFunc(cloudpaths.AWSPrivateIPPath,
68+
mux.HandleFunc(cloudpaths.AWSPrivateIPPath,
6269
func(w http.ResponseWriter, r *http.Request) {
6370
respond(w, r, WellKnownPrivateIP)
6471
})
65-
http.HandleFunc(cloudpaths.AWSPublicIPPath,
72+
mux.HandleFunc(cloudpaths.AWSPublicIPPath,
6673
func(w http.ResponseWriter, r *http.Request) {
6774
respond(w, r, WellKnownPublicIP)
6875
})
69-
http.HandleFunc(cloudpaths.AWSPublicIPPath+"/noip",
76+
mux.HandleFunc(cloudpaths.AWSPublicIPPath+"/noip",
7077
func(w http.ResponseWriter, r *http.Request) {
7178
respond(w, r, "")
7279
})
7380

74-
go http.Serve(l, nil)
81+
srv := http.Server{
82+
Handler: mux,
83+
}
84+
go srv.Serve(l)
7585
return "http://" + l.Addr().String(), func() { l.Close() }
7686
}

x/ref/runtime/internal/cloudvm_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func hasAddr(addrs []net.Addr, host string) bool {
4040
}
4141

4242
func TestCloudVMProviders(t *testing.T) {
43-
awsHost, awsClose := cloudvmtest.StartAWSMetadataServer(t)
43+
awsHost, awsClose := cloudvmtest.StartAWSMetadataServer(t, true)
4444
defer awsClose()
4545
cloudvm.SetAWSMetadataHost(awsHost)
4646

0 commit comments

Comments
 (0)