diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 8c772f608..5773bad5a 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -6,6 +6,7 @@ * Enabled asynchronous token refreshes by default ([#1208](https://github.com/databricks/databricks-sdk-go/pull/1208)). ### Bug Fixes +* Tolerate trailing slashes in hostnames in `Config`. ### Documentation diff --git a/config/config.go b/config/config.go index 6c14f0a27..7bb035c76 100644 --- a/config/config.go +++ b/config/config.go @@ -458,15 +458,17 @@ func (c *Config) getOidcEndpoints(ctx context.Context) (*u2m.OAuthAuthorizationS oauthClient := &u2m.BasicOAuthEndpointSupplier{ Client: c.refreshClient, } + host := c.CanonicalHostName() if c.IsAccountClient() { - return oauthClient.GetAccountOAuthEndpoints(ctx, c.Host, c.AccountID) + return oauthClient.GetAccountOAuthEndpoints(ctx, host, c.AccountID) } - return oauthClient.GetWorkspaceOAuthEndpoints(ctx, c.Host) + return oauthClient.GetWorkspaceOAuthEndpoints(ctx, host) } func (c *Config) getOAuthArgument() (u2m.OAuthArgument, error) { + host := c.CanonicalHostName() if c.IsAccountClient() { - return u2m.NewBasicAccountOAuthArgument(c.Host, c.AccountID) + return u2m.NewBasicAccountOAuthArgument(host, c.AccountID) } - return u2m.NewBasicWorkspaceOAuthArgument(c.Host) + return u2m.NewBasicWorkspaceOAuthArgument(host) } diff --git a/config/config_test.go b/config/config_test.go index 5a5ac7041..7b117bdbd 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -70,34 +70,136 @@ func TestAuthenticate_InvalidHostSet(t *testing.T) { } func TestConfig_getOidcEndpoints_account(t *testing.T) { - c := &Config{ - Host: "https://accounts.cloud.databricks.com", - AccountID: "abc", + tests := []struct { + name string + host string + accountID string + }{ + { + name: "without trailing slash", + host: "https://accounts.cloud.databricks.com", + accountID: "abc", + }, + { + name: "with trailing slash", + host: "https://accounts.cloud.databricks.com/", + accountID: "abc", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Config{ + Host: tt.host, + AccountID: tt.accountID, + } + got, err := c.getOidcEndpoints(context.Background()) + assert.NoError(t, err) + assert.Equal(t, &u2m.OAuthAuthorizationServer{ + AuthorizationEndpoint: "https://accounts.cloud.databricks.com/oidc/accounts/abc/v1/authorize", + TokenEndpoint: "https://accounts.cloud.databricks.com/oidc/accounts/abc/v1/token", + }, got) + }) } - got, err := c.getOidcEndpoints(context.Background()) - assert.NoError(t, err) - assert.Equal(t, &u2m.OAuthAuthorizationServer{ - AuthorizationEndpoint: "https://accounts.cloud.databricks.com/oidc/accounts/abc/v1/authorize", - TokenEndpoint: "https://accounts.cloud.databricks.com/oidc/accounts/abc/v1/token", - }, got) } func TestConfig_getOidcEndpoints_workspace(t *testing.T) { - c := &Config{ - Host: "https://myworkspace.cloud.databricks.com", - HTTPTransport: fixtures.SliceTransport{ - { - Method: "GET", - Resource: "/oidc/.well-known/oauth-authorization-server", - Status: 200, - Response: `{"authorization_endpoint": "https://myworkspace.cloud.databricks.com/oidc/v1/authorize", "token_endpoint": "https://myworkspace.cloud.databricks.com/oidc/v1/token"}`, - }, + tests := []struct { + name string + host string + }{ + { + name: "without trailing slash", + host: "https://myworkspace.cloud.databricks.com", + }, + { + name: "with trailing slash", + host: "https://myworkspace.cloud.databricks.com/", }, } - got, err := c.getOidcEndpoints(context.Background()) - assert.NoError(t, err) - assert.Equal(t, &u2m.OAuthAuthorizationServer{ - AuthorizationEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/authorize", - TokenEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/token", - }, got) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Config{ + Host: tt.host, + HTTPTransport: fixtures.SliceTransport{ + { + Method: "GET", + Resource: "/oidc/.well-known/oauth-authorization-server", + Status: 200, + Response: `{"authorization_endpoint": "https://myworkspace.cloud.databricks.com/oidc/v1/authorize", "token_endpoint": "https://myworkspace.cloud.databricks.com/oidc/v1/token"}`, + }, + }, + } + got, err := c.getOidcEndpoints(context.Background()) + assert.NoError(t, err) + assert.Equal(t, &u2m.OAuthAuthorizationServer{ + AuthorizationEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/authorize", + TokenEndpoint: "https://myworkspace.cloud.databricks.com/oidc/v1/token", + }, got) + }) + } +} + +func TestConfig_getOAuthArgument_account(t *testing.T) { + tests := []struct { + name string + host string + accountID string + }{ + { + name: "without trailing slash", + host: "https://accounts.cloud.databricks.com", + accountID: "abc", + }, + { + name: "with trailing slash", + host: "https://accounts.cloud.databricks.com/", + accountID: "abc", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Config{ + Host: tt.host, + AccountID: tt.accountID, + } + rawGot, err := c.getOAuthArgument() + assert.NoError(t, err) + got, ok := rawGot.(u2m.BasicAccountOAuthArgument) + assert.True(t, ok) + assert.Equal(t, "https://accounts.cloud.databricks.com", got.GetAccountHost()) + assert.Equal(t, "abc", got.GetAccountId()) + }) + } +} + +func TestConfig_getOAuthArgument_workspace(t *testing.T) { + tests := []struct { + name string + host string + }{ + { + name: "without trailing slash", + host: "https://myworkspace.cloud.databricks.com", + }, + { + name: "with trailing slash", + host: "https://myworkspace.cloud.databricks.com/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Config{ + Host: tt.host, + } + rawGot, err := c.getOAuthArgument() + assert.NoError(t, err) + got, ok := rawGot.(u2m.BasicWorkspaceOAuthArgument) + assert.True(t, ok) + assert.Equal(t, "https://myworkspace.cloud.databricks.com", got.GetWorkspaceHost()) + }) + } }