AWS IAM: lakefs IDP interface #8994

Merged 29 commits on May 14, 2025
Changes from 2 commits
203 changes: 203 additions & 0 deletions pkg/authentication/externalidp/aws_client.go
Contributor

pkg/authentication/external_idp/aws_client.go

@@ -0,0 +1,203 @@
package authentication

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/sts"
	"github.com/aws/smithy-go/middleware"
	smithyhttp "github.com/aws/smithy-go/transport/http"
	"github.com/treeverse/lakefs/pkg/api/apigen"
)

const (
	AWSAuthVersion = "2011-06-15"
	AWSAuthMethod = http.MethodPost
	AWSAuthAction = "GetCallerIdentity"
	AWSAuthAlgorithm = "AWS4-HMAC-SHA256"
	StsGlobalEndpoint = "sts.amazonaws.com"
	AWSAuthActionKey = "Action"
	AWSAuthVersionKey = "Version"
	AWSAuthAlgorithmKey = "X-Amz-Algorithm"
	//nolint:gosec
	AWSAuthCredentialKey = "X-Amz-Credential"
	AWSAuthDateKey = "X-Amz-Date"
	AWSAuthExpiresKey = "X-Amz-Expires"
	//nolint:gosec
	AWSAuthSecurityTokenKey = "X-Amz-Security-Token"
	AWSAuthSignedHeadersKey = "X-Amz-SignedHeaders"
	AWSAuthSignatureKey = "X-Amz-Signature"
	AWSDatetimeFormat = "20060102T150405Z"
	AWSCredentialTimeFormat = "20060102"
	AWSDefaultSTSLoginExpire = 15 * time.Minute
)

var ErrAWSCredentialsExpired = errors.New("AWS credentials expired")
var ErrRetrievingToken = errors.New("failed to retrieve token")

type LoginResponse struct {
	Token string
}

type IDPProvider interface {
Contributor

IDP == Identity Provider -> IDPProvider == IdentityProviderProvider 😋

Suggested change:
- type IDPProvider interface {
+ type Provider interface {

Since the package name already contains idp and external, in Go you don't need to repeat them in the type name, i.e. this is enough:

import (
    ".../externalidp"
)

// usage is clear
externalidp.Provider

	Login() (LoginResponse, error)
}
Contributor

Either rename the file or extract these to a general idp.go file.


type AWSProvider struct {
	params AWSIAMParams
	serverEndpoint string
	client *http.Client
}

type IdentityTokenInfo struct {
	Method string `json:"method"`
	Host string `json:"host"`
	Region string `json:"region"`
	Action string `json:"action"`
	Date string `json:"date"`
	ExpirationDuration string `json:"expiration_duration"`
	AccessKeyID string `json:"access_key_id"`
	Signature string `json:"signature"`
	SignedHeaders []string `json:"signed_headers"`
	Version string `json:"version"`
	Algorithm string `json:"algorithm"`
	SecurityToken string `json:"security_token"`
}

type AWSIAMParams struct {
	ProviderType string
	TokenRequestHeaders map[string]string
	URLPresignTTL time.Duration
	TokenTTL time.Duration
}

func NewAWSProvider(params AWSIAMParams, serverEndpoint string, httpClient *http.Client) *AWSProvider {
Contributor

*http.Client is (a) a confusing name and (b) too low level.

I suggest the following refactor:

  1. Create a minimal interface (easier to test, fewer dependencies) with a single function, ExternalPrincipalLoginWithResponse.

Note: this function is already part of the generated lakeFS client interface ClientWithResponsesInterface; the new interface exposes only this one function.

type ExternalPrincipalLoginClient interface {
	ExternalPrincipalLoginWithResponse(ctx context.Context, body apigen.ExternalPrincipalLoginJSONRequestBody, reqEditors ...apigen.RequestEditorFn) (*apigen.ExternalPrincipalLoginResponse, error)
}

  2. Refactor the existing NewAWSProvider function into NewAWSProviderWithClient (helps with testing and mocking):

func NewAWSProviderWithClient(params AWSIAMParams, client ExternalPrincipalLoginClient) *AWSProvider {
	...
}

  3. Add an additional constructor, NewAWSProvider, that will be used in everest:

func NewAWSProvider(params AWSIAMParams, lakeFSHost string) (*AWSProvider, error) {
	client, err := apigen.NewClientWithResponses(lakeFSHost)
	if err != nil {
		return nil, err
	}
	return NewAWSProviderWithClient(params, client), nil
}

	return &AWSProvider{
		params: params,
		serverEndpoint: serverEndpoint,
		client: httpClient,
	}
}

func (p *AWSProvider) Login() (LoginResponse, error) {
	jwt, err := getJWT(&p.params, p.serverEndpoint, p.client)
	resp := LoginResponse{Token: jwt}
	return resp, err
}

func getJWT(params *AWSIAMParams, serverEndpoint string, httpClient *http.Client) (string, error) {
Contributor

There's no point in a separate getJWT function if it's only used in Login; put the code there, otherwise it just bloats the code.

	ctx := context.TODO()
	identityToken, err := getIdentityToken(ctx, params)
	if err != nil {
		return "", err
	}

	client, err := apigen.NewClientWithResponses(
		serverEndpoint,
		apigen.WithHTTPClient(httpClient),
	)
	if err != nil {
		return "", err
	}

	tokenTTL := int(params.TokenTTL.Seconds())
	externalLoginInfo := apigen.ExternalLoginInformation{
		IdentityRequest: map[string]interface{}{
			"identity_token": identityToken,
		},
		TokenExpirationDuration: &tokenTTL,
	}
	externalPrincipalLoginResp, err := client.ExternalPrincipalLoginWithResponse(ctx, apigen.ExternalPrincipalLoginJSONRequestBody(externalLoginInfo))
	if err != nil {
		return "", err
	}
	if externalPrincipalLoginResp == nil || externalPrincipalLoginResp.JSON200 == nil {
		return "", ErrRetrievingToken
	}
	return externalPrincipalLoginResp.JSON200.Token, nil
Contributor

  1. After the regular error check, check for HTTP-level errors with ResponseAsError.
  2. Return the AuthenticationToken struct, not just the token string; it contains more info and allows extending.
  3. No need for the explicit nil checks.

import "github.com/treeverse/lakefs/pkg/api/helpers"

// ...
if err != nil {
	...
}
err = helpers.ResponseAsError(res)
if err != nil {
	return nil, err
}
return res.JSON200, nil

}

func getIdentityToken(ctx context.Context, params *AWSIAMParams) (string, error) {
Contributor

Please avoid making this func private, so we can easily go-get and test it.
It should also be part of the AWSProvider struct, i.e. a struct method.

	cfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		return "", err
	}
	creds, err := cfg.Credentials.Retrieve(ctx)
	if err != nil {
		return "", err
	}
	if creds.Expired() {
		return "", ErrAWSCredentialsExpired
	}
	stsClient := sts.NewFromConfig(cfg)
	stsPresignClient := sts.NewPresignClient(stsClient, func(o *sts.PresignOptions) {
		o.ClientOptions = append(o.ClientOptions, func(opts *sts.Options) {
			opts.ClientLogMode = aws.LogSigning
		})
	})

	presignGetCallerIdentityResp, err := stsPresignClient.PresignGetCallerIdentity(context.Background(), &sts.GetCallerIdentityInput{},
		sts.WithPresignClientFromClientOptions(sts.WithAPIOptions(setHTTPHeaders(params.TokenRequestHeaders, params.URLPresignTTL))),
	)
	if err != nil {
		return "", err
	}

	parsedURL, err := url.Parse(presignGetCallerIdentityResp.URL)
	if err != nil {
		return "", err
	}

	queryParams := parsedURL.Query()
	credentials := queryParams.Get(AWSAuthCredentialKey)
	splitedCreds := strings.Split(credentials, "/")
	calculatedRegion := splitedCreds[2]
	identityTokenInfo := IdentityTokenInfo{
		Method: "POST",
		Host: parsedURL.Host,
		Region: calculatedRegion,
		Action: AWSAuthAction,
		Date: queryParams.Get(AWSAuthDateKey),
		ExpirationDuration: queryParams.Get(AWSAuthExpiresKey),
		AccessKeyID: creds.AccessKeyID,
		Signature: queryParams.Get(AWSAuthSignatureKey),
		SignedHeaders: strings.Split(queryParams.Get(AWSAuthSignedHeadersKey), ";"),
		Version: queryParams.Get(AWSAuthVersionKey),
		Algorithm: queryParams.Get(AWSAuthAlgorithmKey),
		SecurityToken: queryParams.Get(AWSAuthSecurityTokenKey),
	}

	marshaledIdentityTokenInfo, _ := json.Marshal(identityTokenInfo)
	encodedIdentityTokenInfo := base64.StdEncoding.EncodeToString(marshaledIdentityTokenInfo)
	return encodedIdentityTokenInfo, nil
Contributor

This section is a good candidate for a separate function; it will be easier to test, e.g.:

func NewIdentityTokenInfoFromURL(presignedURL string) (*IdentityTokenInfo, string, error)

Note: returning both the struct and its encoded string form makes proper testing possible; verifying specific fields against only the encoded string is harder.
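
A possible shape for such a function is sketched below, under a few assumptions: the struct is trimmed to a subset of the fields in the diff, ErrInvalidCredentialScope is a hypothetical error name, and the length check on the X-Amz-Credential scope and the handling of the json.Marshal error are additions that the code in the diff does not have.

```go
package main

import (
	"encoding/base64"
	"encoding/json"
	"errors"
	"net/url"
	"strings"
)

// IdentityTokenInfo is trimmed to a few fields of the struct in the diff
// to keep the sketch short.
type IdentityTokenInfo struct {
	Host          string   `json:"host"`
	Region        string   `json:"region"`
	Date          string   `json:"date"`
	Signature     string   `json:"signature"`
	SignedHeaders []string `json:"signed_headers"`
}

// ErrInvalidCredentialScope is a hypothetical error for a malformed
// X-Amz-Credential query parameter.
var ErrInvalidCredentialScope = errors.New("invalid X-Amz-Credential scope")

// NewIdentityTokenInfoFromURL parses a presigned URL and returns both the
// struct (easy to assert on in tests) and its base64-encoded JSON form.
func NewIdentityTokenInfoFromURL(presignedURL string) (*IdentityTokenInfo, string, error) {
	parsed, err := url.Parse(presignedURL)
	if err != nil {
		return nil, "", err
	}
	q := parsed.Query()

	// SigV4 credential scope: <access-key>/<date>/<region>/<service>/aws4_request.
	// Guard the index so a malformed value returns an error instead of panicking.
	scope := strings.Split(q.Get("X-Amz-Credential"), "/")
	if len(scope) < 3 {
		return nil, "", ErrInvalidCredentialScope
	}

	info := &IdentityTokenInfo{
		Host:          parsed.Host,
		Region:        scope[2],
		Date:          q.Get("X-Amz-Date"),
		Signature:     q.Get("X-Amz-Signature"),
		SignedHeaders: strings.Split(q.Get("X-Amz-SignedHeaders"), ";"),
	}

	raw, err := json.Marshal(info)
	if err != nil {
		return nil, "", err
	}
	return info, base64.StdEncoding.EncodeToString(raw), nil
}
```

Because the function takes only a URL string, a test can feed it a hand-written presigned URL and compare individual fields, with no AWS credentials or STS client involved.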

}

func setHTTPHeaders(requestHeaders map[string]string, ttl time.Duration) func(*middleware.Stack) error {
	return func(stack *middleware.Stack) error {
		return stack.Build.Add(middleware.BuildMiddlewareFunc("AddHeaders", func(
			ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
		) (
			middleware.BuildOutput, middleware.Metadata, error,
		) {
			if req, ok := in.Request.(*smithyhttp.Request); ok {
				req.Method = "POST"
				for header, value := range requestHeaders {
					req.Header.Add(header, value)
				}
				queryParams := req.Request.URL.Query()
				queryParams.Set(AWSAuthExpiresKey, fmt.Sprintf("%d", int(ttl.Seconds())))
				req.Request.URL.RawQuery = queryParams.Encode()
			}
			return next.HandleBuild(ctx, in)
		}), middleware.Before)
	}
}