From b9a79ce9fd9bb69ae5b513bc7bc682a32e0beca7 Mon Sep 17 00:00:00 2001 From: Brad Moylan Date: Fri, 24 Jan 2025 14:44:30 -0800 Subject: [PATCH] fix: Use url.JoinPath() to more flexibly handle extraneous slashes --- sdk/auth/client_credentials.go | 20 ++++++++++++++++---- sdk/auth/client_secret_authorizer.go | 12 ++++++++++-- sdk/client/client.go | 22 ++++++++-------------- sdk/internal/metadata/client.go | 8 +++++++- 4 files changed, 41 insertions(+), 21 deletions(-) diff --git a/sdk/auth/client_credentials.go b/sdk/auth/client_credentials.go index 571a12ccd8f..890de36415d 100644 --- a/sdk/auth/client_credentials.go +++ b/sdk/auth/client_credentials.go @@ -276,7 +276,11 @@ func (a *ClientAssertionAuthorizer) Token(ctx context.Context, _ *http.Request) if a.conf.Environment.Authorization == nil { return nil, fmt.Errorf("no `authorization` configuration was found for this environment") } - tokenUrl = tokenEndpoint(*a.conf.Environment.Authorization, a.conf.TenantID) + var err error + tokenUrl, err = tokenEndpoint(*a.conf.Environment.Authorization, a.conf.TenantID) + if err != nil { + return nil, err + } } return a.token(ctx, tokenUrl) @@ -300,7 +304,11 @@ func (a *ClientAssertionAuthorizer) AuxiliaryTokens(ctx context.Context, _ *http if a.conf.Environment.Authorization == nil { return nil, fmt.Errorf("no `authorization` configuration was found for this environment") } - tokenUrl = tokenEndpoint(*a.conf.Environment.Authorization, tenantId) + var err error + tokenUrl, err = tokenEndpoint(*a.conf.Environment.Authorization, tenantId) + if err != nil { + return tokens, err + } } token, err := a.token(ctx, tokenUrl) @@ -374,9 +382,13 @@ func clientCredentialsToken(ctx context.Context, endpoint string, params *url.Va return token, nil } -func tokenEndpoint(endpoint environments.Authorization, tenant string) string { +func tokenEndpoint(endpoint environments.Authorization, tenant string) (string, error) { if tenant == "" { tenant = "common" } - return fmt.Sprintf("%s/%s/oauth2/v2.0/token", endpoint.LoginEndpoint, tenant) + uri, err := url.JoinPath(endpoint.LoginEndpoint, tenant, "oauth2/v2.0/token") + if err != nil { + return "", fmt.Errorf("parsing loginEndpoint: %w", err) + } + return uri, nil } diff --git a/sdk/auth/client_secret_authorizer.go b/sdk/auth/client_secret_authorizer.go index dcbd3ffe85a..2593c83605c 100644 --- a/sdk/auth/client_secret_authorizer.go +++ b/sdk/auth/client_secret_authorizer.go @@ -83,7 +83,11 @@ func (a *ClientSecretAuthorizer) Token(ctx context.Context, _ *http.Request) (*o if a.conf.Environment.Authorization == nil { return nil, fmt.Errorf("no `authorization` configuration was found for this environment") } - tokenUrl = tokenEndpoint(*a.conf.Environment.Authorization, a.conf.TenantID) + var err error + tokenUrl, err = tokenEndpoint(*a.conf.Environment.Authorization, a.conf.TenantID) + if err != nil { + return nil, err + } } return clientCredentialsToken(ctx, tokenUrl, &v) @@ -118,7 +122,11 @@ func (a *ClientSecretAuthorizer) AuxiliaryTokens(ctx context.Context, _ *http.Re if a.conf.Environment.Authorization == nil { return nil, fmt.Errorf("no `authorization` configuration was found for this environment") } - tokenUrl = tokenEndpoint(*a.conf.Environment.Authorization, tenantId) + var err error + tokenUrl, err = tokenEndpoint(*a.conf.Environment.Authorization, tenantId) + if err != nil { + return nil, err + } } token, err := clientCredentialsToken(ctx, tokenUrl, &v) diff --git a/sdk/client/client.go b/sdk/client/client.go index a52e5dad664..96a198df1c8 100644 --- a/sdk/client/client.go +++ b/sdk/client/client.go @@ -377,11 +377,14 @@ func (c *Client) ClearResponseMiddlewares() { // NewRequest configures a new *Request func (c *Client) NewRequest(ctx context.Context, input RequestOptions) (*Request, error) { - req := (&http.Request{}).WithContext(ctx) - - req.Method = input.HttpMethod - - req.Header = make(http.Header) + uri, err := url.JoinPath(c.BaseUri, input.Path) + if err != nil { + return nil, fmt.Errorf("parsing URI: %w", err) + } + req, err := http.NewRequestWithContext(ctx, input.HttpMethod, uri, nil) + if err != nil { + return nil, err + } if input.ContentType != "" { req.Header.Add("Content-Type", input.ContentType) @@ -394,15 +397,6 @@ func (c *Client) NewRequest(ctx context.Context, input RequestOptions) (*Request req.Header.Add("X-Ms-Correlation-Request-Id", c.CorrelationId) } - path := strings.TrimPrefix(input.Path, "/") - u, err := url.ParseRequestURI(fmt.Sprintf("%s/%s", c.BaseUri, path)) - if err != nil { - return nil, err - } - - req.Host = u.Host - req.URL = u - ret := Request{ Client: c, Request: req, diff --git a/sdk/internal/metadata/client.go b/sdk/internal/metadata/client.go index c9bc16f7226..614fc756660 100644 --- a/sdk/internal/metadata/client.go +++ b/sdk/internal/metadata/client.go @@ -13,6 +13,7 @@ import ( "log" "net" "net/http" + "net/url" "runtime" "time" ) @@ -52,7 +53,12 @@ func (c *Client) GetMetaData(ctx context.Context) (*MetaData, error) { MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1, }, } - uri := fmt.Sprintf("%s/metadata/endpoints?api-version=2022-09-01", c.endpoint) + uri, err := url.JoinPath(c.endpoint, "/metadata/endpoints") + if err != nil { + return nil, fmt.Errorf("parsing endpoint: %+v", err) + } + uri += "?api-version=2022-09-01" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil) if err != nil { return nil, fmt.Errorf("preparing request: %+v", err)