Skip to content

Draft rate limiter with example usage #4121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pkg/sources/postman/postman.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,13 @@ func (s *Source) Init(ctx context.Context, name string, jobId sources.JobID, sou
if conn.GetToken() == "" {
return errors.New("Postman token is empty")
}
s.client = NewClient(conn.GetToken())

client, err := NewClient(conn.GetToken())
if err != nil {
return err
}

s.client = client
s.client.HTTPClient = common.RetryableHTTPClientTimeout(10)
log.RedactGlobally(conn.GetToken())
case *sourcespb.Postman_Unauthenticated:
Expand Down
82 changes: 57 additions & 25 deletions pkg/sources/postman/postman_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"time"

"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"golang.org/x/time/rate"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/rate_limiter"

"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
)
Expand Down Expand Up @@ -193,28 +193,50 @@ type Client struct {

// Rate limiter needed for Postman API workspace and collection requests. Postman API rate limit
// is 10 calls in 10 seconds for GET /collections, GET /workspaces, and GET /workspaces/{id} endpoints.
WorkspaceAndCollectionRateLimiter *rate.Limiter
WorkspaceAndCollectionRateLimiter *rate_limiter.APIRateLimiter

// Rate limiter needed for Postman API. General rate limit is 300 requests per minute.
GeneralRateLimiter *rate.Limiter
GeneralRateLimiter *rate_limiter.APIRateLimiter
}

// NewClient returns a new Postman API client.
func NewClient(postmanToken string) *Client {
func NewClient(postmanToken string) (*Client, error) {
bh := map[string]string{
"Content-Type": defaultContentType,
"User-Agent": userAgent,
"X-API-Key": postmanToken,
}

workspaceAndCollectionRateLimiter, err := rate_limiter.NewAPIRateLimiter(
"api.getpostman.com",
map[string]rate_limiter.APIRateLimit{
"1r/s": rate_limiter.NewSimpleRateLimit(1),
},
)

if err != nil {
return nil, err
}

generalRateLimiter, err := rate_limiter.NewAPIRateLimiter(
"api.getpostman.com",
map[string]rate_limiter.APIRateLimit{
"5r/s": rate_limiter.NewSimpleRateLimit(5),
},
)

if err != nil {
return nil, err
}

c := &Client{
HTTPClient: http.DefaultClient,
Headers: bh,
WorkspaceAndCollectionRateLimiter: rate.NewLimiter(rate.Every(time.Second), 1),
GeneralRateLimiter: rate.NewLimiter(rate.Every(time.Second/5), 1),
WorkspaceAndCollectionRateLimiter: workspaceAndCollectionRateLimiter,
GeneralRateLimiter: generalRateLimiter,
}

return c
return c, nil
}

// NewRequest creates an API request (Only GET needed for our interaction w/ Postman)
Expand Down Expand Up @@ -247,13 +269,15 @@ func checkResponseStatus(r *http.Response) error {
}

// getPostmanResponseBodyBytes makes a request to the Postman API and returns the response body as bytes.
func (c *Client) getPostmanResponseBodyBytes(ctx context.Context, url string, headers map[string]string) ([]byte, error) {
func (c *Client) getPostmanResponseBodyBytes(ctx context.Context, url string, headers map[string]string, rl *rate_limiter.APIRateLimiter) ([]byte, error) {
req, err := c.NewRequest(url, headers)
if err != nil {
return nil, err
}

resp, err := c.HTTPClient.Do(req)
resp, err := rl.DoWithRateLimiting(ctx, req, func() (*http.Response, error) {
return c.HTTPClient.Do(req)
})
if err != nil {
return nil, err
}
Expand All @@ -280,10 +304,12 @@ func (c *Client) EnumerateWorkspaces(ctx context.Context) ([]Workspace, error) {
Workspaces []Workspace `json:"workspaces"`
}{}

if err := c.WorkspaceAndCollectionRateLimiter.Wait(ctx); err != nil {
return nil, fmt.Errorf("could not wait for rate limiter during workspaces enumeration getting: %w", err)
}
body, err := c.getPostmanResponseBodyBytes(ctx, "https://api.getpostman.com/workspaces", nil)
body, err := c.getPostmanResponseBodyBytes(
ctx,
"https://api.getpostman.com/workspaces",
nil,
c.WorkspaceAndCollectionRateLimiter,
)
if err != nil {
return nil, fmt.Errorf("could not get postman workspace response bytes during enumeration: %w", err)
}
Expand Down Expand Up @@ -315,10 +341,12 @@ func (c *Client) GetWorkspace(ctx context.Context, workspaceUUID string) (Worksp
}{}

url := fmt.Sprintf(WORKSPACE_URL, workspaceUUID)
if err := c.WorkspaceAndCollectionRateLimiter.Wait(ctx); err != nil {
return Workspace{}, fmt.Errorf("could not wait for rate limiter during workspace getting: %w", err)
}
body, err := c.getPostmanResponseBodyBytes(ctx, url, nil)
body, err := c.getPostmanResponseBodyBytes(
ctx,
url,
nil,
c.WorkspaceAndCollectionRateLimiter,
)
if err != nil {
return Workspace{}, fmt.Errorf("could not get postman workspace (%s) response bytes: %w", workspaceUUID, err)
}
Expand All @@ -336,10 +364,12 @@ func (c *Client) GetEnvironmentVariables(ctx context.Context, environment_uuid s
}{}

url := fmt.Sprintf(ENVIRONMENTS_URL, environment_uuid)
if err := c.GeneralRateLimiter.Wait(ctx); err != nil {
return VariableData{}, fmt.Errorf("could not wait for rate limiter during environment variable getting: %w", err)
}
body, err := c.getPostmanResponseBodyBytes(ctx, url, nil)
body, err := c.getPostmanResponseBodyBytes(
ctx,
url,
nil,
c.GeneralRateLimiter,
)
if err != nil {
return VariableData{}, fmt.Errorf("could not get postman environment (%s) response bytes: %w", environment_uuid, err)
}
Expand All @@ -357,10 +387,12 @@ func (c *Client) GetCollection(ctx context.Context, collection_uuid string) (Col
}{}

url := fmt.Sprintf(COLLECTIONS_URL, collection_uuid)
if err := c.WorkspaceAndCollectionRateLimiter.Wait(ctx); err != nil {
return Collection{}, fmt.Errorf("could not wait for rate limiter during collection getting: %w", err)
}
body, err := c.getPostmanResponseBodyBytes(ctx, url, nil)
body, err := c.getPostmanResponseBodyBytes(
ctx,
url,
nil,
c.WorkspaceAndCollectionRateLimiter,
)
if err != nil {
return Collection{}, fmt.Errorf("could not get postman collection (%s) response bytes: %w", collection_uuid, err)
}
Expand Down
68 changes: 68 additions & 0 deletions pkg/sources/rate_limiter/api_rate_limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package rate_limiter

import (
"net/http"
"time"

"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

// APIRateLimit describes a single rate limit of an API.
//
// Implementation requirements:
//   - Be goroutine safe.
//   - MaybeWait must *NEVER* sleep for a duration; it may only sleep _UNTIL_ a
//     time.
//
// Usage requirements:
//   - An APIRateLimit may only be used against a single API. This spares
//     implementations from having to track request counts (etc.) across
//     different APIs.
//   - An APIRateLimit may be shared between multiple APIRateLimiters, as long
//     as all of those APIRateLimiters are only used against the same API.
//
// For example, if an API has two endpoints, A with a 1r/s limit and B with a
// 5r/s limit, and the API in general has a 500r/month limit, this
// configuration is possible (sketch only: the limit constructors and MONTH are
// illustrative, and error handling is omitted):
//
//	oneReqPerSecond := NewTokenBucketRateLimit(rate.Every(time.Second), 1)
//	fiveReqsPerSecond := NewTokenBucketRateLimit(rate.Every(time.Second/5), 1)
//	fiveHundredReqsPerMonth := NewPersistentRateLimit(500, MONTH)
//	rateLimiterA, _ := NewAPIRateLimiter("api.example.com", map[string]APIRateLimit{
//		"1r/s":      oneReqPerSecond,
//		"500r/month": fiveHundredReqsPerMonth,
//	})
//	rateLimiterB, _ := NewAPIRateLimiter("api.example.com", map[string]APIRateLimit{
//		"5r/s":      fiveReqsPerSecond,
//		"500r/month": fiveHundredReqsPerMonth,
//	})
type APIRateLimit interface {
// MaybeWait potentially sleeps in order to honor a rate limit before req is
// sent. It does not send the request itself: it returns nil when the
// request may proceed and an error when it may not. now is the time at
// which the APIRateLimiter started processing the request. Implementations
// should:
//   - Be goroutine safe
//   - Check if ctx has been canceled
//   - Not modify req
//   - *NEVER* sleep for a duration; only sleep _UNTIL_ a time
//
// APIRateLimiter calls the MaybeWait method of all its APIRateLimits
// concurrently. Any returned errors are combined into a single error, but
// returning an error doesn't stop the APIRateLimiter from (maybe) waiting on
// the other limits. Returning an error will, however, prevent any further
// processing of the HTTP request (sending the request and updating the
// APIRateLimits).
MaybeWait(ctx context.Context, req *http.Request, now time.Time) error

// Update updates the state of an APIRateLimit from an HTTP response, e.g. by
// checking for HTTP status 429 or reading a Retry-After header. now is the
// time at which the response was received. Implementations should:
//   - Be goroutine safe
//   - Check if ctx has been canceled
//   - Not modify res
//
// Services may only return rate limits as durations, e.g. `Retry-After: 60`
// (units are in seconds; cf. RFC 6585 for status 429 and RFC 9110 for the
// header), which is incompatible with MaybeWait as it can't wait a duration
// like 60 seconds, it can only wait until a time. Therefore it's incumbent
// on Update to handle this somehow, generally by converting the duration
// into a time in the future using the `now` arg. It's also recommended to
// pad the time somewhat.
//
// APIRateLimiter calls the Update method of all its APIRateLimits
// concurrently. Any returned errors are combined into a single error, but
// returning an error doesn't stop the APIRateLimiter from updating the other
// limits.
Update(ctx context.Context, res *http.Response, now time.Time) error
}
139 changes: 139 additions & 0 deletions pkg/sources/rate_limiter/api_rate_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package rate_limiter

import (
"errors"
"fmt"
"net/http"
"sync"
"time"

"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

// APIRateLimiter provides a facility for honoring an API's rate limits. To use
// it:
//   - Create an APIRateLimiter with its APIRateLimits (see NewAPIRateLimiter)
//   - Call DoWithRateLimiting instead of what you would normally call to make
//     a request
//   - Process the response (returned from DoWithRateLimiting) as normal
//
// An APIRateLimiter should only be used on a single API. If you're making
// requests to multiple APIs, use multiple APIRateLimiters.
type APIRateLimiter struct {
// hostname is the host of the API this limiter is for. When at least one
// limit is configured, DoWithRateLimiting rejects requests whose URL
// hostname differs from it.
hostname string
// limits holds the API's rate limits, keyed by a name that
// DoWithRateLimiting includes in its error messages.
limits map[string]APIRateLimit
}

// NewAPIRateLimiter returns a rate limiter for the API served at hostname that
// honors every limit in limits.
//
// The keys of limits are human-readable names used only to label errors; they
// are not hostnames. For example:
//
//	NewAPIRateLimiter("api.example.com", map[string]APIRateLimit{
//		"5r/s": fiveRequestsPerSecondLimit,
//	})
//
// The map is copied, so later changes to the caller's map do not affect the
// returned limiter. An error is returned if any limit is nil, because a nil
// limit can be neither waited on nor updated.
func NewAPIRateLimiter(
	hostname string,
	limits map[string]APIRateLimit,
) (*APIRateLimiter, error) {
	// The map keys are limit names such as "5r/s", so they must not be
	// compared against hostname: doing so rejects every real configuration.
	owned := make(map[string]APIRateLimit, len(limits))
	for name, lim := range limits {
		if lim == nil {
			return nil, fmt.Errorf(
				"rate limit %q for API %q is nil",
				name,
				hostname,
			)
		}
		owned[name] = lim
	}

	return &APIRateLimiter{hostname: hostname, limits: owned}, nil
}

// DoWithRateLimiting makes an HTTP request to an API while honoring its
// limits.
//
// makeRequest performs the actual request for req; it is called at most once.
// The sequence is:
//  1. If the limiter has no limits, makeRequest is called directly and its
//     result returned unchanged (no hostname check is performed).
//  2. req must target the limiter's hostname, otherwise an error is returned.
//  3. MaybeWait is called on every limit concurrently. If any of them fails,
//     the request is not sent and the combined error is returned.
//  4. makeRequest is called.
//  5. Update is called on every limit concurrently with the response. If any
//     of them fails, the response body is closed and the combined error is
//     returned instead of the response.
//
// On success the caller owns the returned response and must close its body.
func (api *APIRateLimiter) DoWithRateLimiting(
	ctx context.Context,
	req *http.Request,
	makeRequest func() (*http.Response, error),
) (*http.Response, error) {
	if len(api.limits) == 0 {
		return makeRequest()
	}

	if req == nil || req.URL == nil {
		return nil, errors.New("cannot rate limit a request without a URL")
	}

	if host := req.URL.Hostname(); host != api.hostname {
		return nil, fmt.Errorf(
			"cannot rate limit requests to API %q with a rate limiter for API %q",
			host,
			api.hostname,
		)
	}

	// Every limit is handed the same instant so they agree on when the
	// request started being processed.
	now := time.Now()
	waitErr := api.forEachLimit("error waiting on rate limit", func(lim APIRateLimit) error {
		return lim.MaybeWait(ctx, req, now)
	})
	if waitErr != nil {
		return nil, fmt.Errorf("error honoring rate limits: %w", waitErr)
	}

	res, err := makeRequest()
	if err != nil {
		return nil, fmt.Errorf("error making HTTP request: %w", err)
	}

	// Re-read the clock: waiting and the request itself may have taken a while.
	now = time.Now()
	updateErr := api.forEachLimit("error updating rate limit", func(lim APIRateLimit) error {
		return lim.Update(ctx, res, now)
	})
	if updateErr != nil {
		// The response is not handed to the caller on this path, so nobody
		// else can close the body; close it here to release the connection.
		if res != nil && res.Body != nil {
			_ = res.Body.Close()
		}
		return nil, fmt.Errorf("error updating rate limits: %w", updateErr)
	}

	return res, nil
}

// forEachLimit runs fn against every limit concurrently and waits for all of
// the calls to finish; one failing limit does not stop the others. Each
// failure is prefixed with action and the limit's name, and all failures are
// joined into a single error. It returns nil if every call succeeds.
//
// errgroup.Group is deliberately not used: it is built around cancelling the
// remaining work once one task fails, which is not wanted here.
func (api *APIRateLimiter) forEachLimit(
	action string,
	fn func(lim APIRateLimit) error,
) error {
	var (
		wg   sync.WaitGroup
		mu   sync.Mutex
		errs []error
	)

	for name, lim := range api.limits {
		wg.Add(1)
		go func(name string, lim APIRateLimit) {
			defer wg.Done()

			if err := fn(lim); err != nil {
				mu.Lock()
				errs = append(errs, fmt.Errorf("%s %s: %w", action, name, err))
				mu.Unlock()
			}
		}(name, lim)
	}

	wg.Wait()

	// errors.Join returns nil when errs is empty.
	return errors.Join(errs...)
}
Loading
Loading