Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ca/acmeClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import (
"strings"

"github.com/pkg/errors"

"go.step.sm/crypto/jose"

"github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api"
"go.step.sm/crypto/jose"
)

// ACMEClient implements an HTTP client to an ACME API.
Expand All @@ -29,7 +31,7 @@ type ACMEClient struct {
// NewACMEClient initializes a new ACMEClient.
func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*ACMEClient, error) {
// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
if err := o.apply(opts); err != nil {
return nil, err
}
Expand Down
26 changes: 14 additions & 12 deletions ca/acmeClient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ import (
"time"

"github.com/pkg/errors"

"github.com/smallstep/assert"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"

"github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api"
"github.com/smallstep/certificates/api/render"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
)

func TestNewACMEClient(t *testing.T) {
Expand Down Expand Up @@ -169,7 +171,7 @@ func TestACMEClient_GetNonce(t *testing.T) {
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)}))
tr, err := o.getTransport(srv.URL)
assert.FatalError(t, err)
Expand Down Expand Up @@ -241,7 +243,7 @@ func TestACMEClient_post(t *testing.T) {
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)}))
tr, err := o.getTransport(srv.URL)
assert.FatalError(t, err)
Expand Down Expand Up @@ -372,7 +374,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
NewOrder: srv.URL + "/bar",
}
// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)}))
tr, err := o.getTransport(srv.URL)
assert.FatalError(t, err)
Expand Down Expand Up @@ -507,7 +509,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)}))
tr, err := o.getTransport(srv.URL)
assert.FatalError(t, err)
Expand Down Expand Up @@ -629,7 +631,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)}))
tr, err := o.getTransport(srv.URL)
assert.FatalError(t, err)
Expand Down Expand Up @@ -751,7 +753,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)}))
tr, err := o.getTransport(srv.URL)
assert.FatalError(t, err)
Expand Down Expand Up @@ -874,7 +876,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)}))
tr, err := o.getTransport(srv.URL)
assert.FatalError(t, err)
Expand Down Expand Up @@ -1087,7 +1089,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)}))
tr, err := o.getTransport(srv.URL)
assert.FatalError(t, err)
Expand Down Expand Up @@ -1214,7 +1216,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)}))
tr, err := o.getTransport(srv.URL)
assert.FatalError(t, err)
Expand Down Expand Up @@ -1347,7 +1349,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)}))
tr, err := o.getTransport(srv.URL)
assert.FatalError(t, err)
Expand Down
15 changes: 12 additions & 3 deletions ca/adminClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,23 @@ func (e *AdminClientError) Error() string {
return e.Message
}

// defaultClientOptions returns a new [clientOptions] with a
// default timeout set.
func defaultClientOptions() clientOptions {
return clientOptions{
timeout: 15 * time.Second,
}
}

// NewAdminClient creates a new AdminClient with the given endpoint and options.
func NewAdminClient(endpoint string, opts ...ClientOption) (*AdminClient, error) {
u, err := parseEndpoint(endpoint)
if err != nil {
return nil, err
}

// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
if err := o.apply(opts); err != nil {
return nil, err
}
Expand All @@ -77,7 +86,7 @@ func NewAdminClient(endpoint string, opts ...ClientOption) (*AdminClient, error)
}

return &AdminClient{
client: newClient(tr),
client: newClient(tr, o.timeout),
endpoint: u,
retryFunc: o.retryFunc,
opts: opts,
Expand Down Expand Up @@ -124,7 +133,7 @@ func (c *AdminClient) generateAdminToken(aud *url.URL) (string, error) {
func (c *AdminClient) retryOnError(r *http.Response) bool {
if c.retryFunc != nil {
if c.retryFunc(r.StatusCode) {
o := new(clientOptions)
o := defaultClientOptions()
if err := o.apply(c.opts); err != nil {
return false
}
Expand Down
26 changes: 21 additions & 5 deletions ca/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"path/filepath"
"strconv"
"strings"
"time"

"github.com/pkg/errors"
"golang.org/x/net/http2"
Expand Down Expand Up @@ -53,10 +54,11 @@ type uaClient struct {
Client *http.Client
}

func newClient(transport http.RoundTripper) *uaClient {
func newClient(transport http.RoundTripper, timeout time.Duration) *uaClient {
return &uaClient{
Client: &http.Client{
Transport: transport,
Timeout: timeout,
},
}
}
Expand Down Expand Up @@ -149,6 +151,7 @@ type ClientOption func(o *clientOptions) error

type clientOptions struct {
transport http.RoundTripper
timeout time.Duration
rootSHA256 string
rootFilename string
rootBundle []byte
Expand Down Expand Up @@ -388,6 +391,16 @@ func WithRetryFunc(fn RetryFunc) ClientOption {
}
}

// WithTimeout defines the time limit for requests made by this client. The
// timeout includes connection time, any redirects, and reading the response
// body.
func WithTimeout(d time.Duration) ClientOption {
return func(o *clientOptions) error {
o.timeout = d
return nil
}
}

func getTransportFromFile(filename string) (http.RoundTripper, error) {
data, err := os.ReadFile(filename)
if err != nil {
Expand Down Expand Up @@ -548,6 +561,7 @@ type Client struct {
client *uaClient
endpoint *url.URL
retryFunc RetryFunc
timeout time.Duration
opts []ClientOption
}

Expand All @@ -557,8 +571,9 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
if err != nil {
return nil, err
}

// Retrieve transport from options.
o := new(clientOptions)
o := defaultClientOptions()
if err := o.apply(opts); err != nil {
return nil, err
}
Expand All @@ -568,17 +583,18 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
}

return &Client{
client: newClient(tr),
client: newClient(tr, o.timeout),
endpoint: u,
retryFunc: o.retryFunc,
timeout: o.timeout,
opts: opts,
}, nil
}
Comment on lines 585 to 592
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also set a default timeout? I believe without a default timeout set, it will just wait until the server disconnects based on its timeouts (if any). See the output for different cases below. I know this client will be used with our CA, and we're in control over those timeouts, but it's generally a good practice to set a timeout in any http.Client that is created.

The WithContext methods can also be used, and they seem to behave as expected. However, we're not setting timeouts through the context in places where we use those methods.

# timeout on connection triggered by server; no (default) timeout in the HTTP client or context:
$ go run cmd/step/main.go ca health 
client GET https://127.0.0.1:8443/health failed: stream error: stream ID 1; INTERNAL_ERROR; received from peer

# timeout on the context from the client side:
$ go run cmd/step/main.go ca health
client GET https://127.0.0.1:8443/health failed: context deadline exceeded
exit status 1

# timeout on the client:
$ go run cmd/step/main.go ca health
client GET https://127.0.0.1:8443/health failed: context deadline exceeded (Client.Timeout exceeded while awaiting headers)
exit status 1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maraino I've added the default timeout in 3135a2c.


func (c *Client) retryOnError(r *http.Response) bool {
if c.retryFunc != nil {
if c.retryFunc(r.StatusCode) {
o := new(clientOptions)
o := defaultClientOptions()
if err := o.apply(c.opts); err != nil {
return false
}
Expand Down Expand Up @@ -890,7 +906,7 @@ func (c *Client) RevokeWithContext(ctx context.Context, req *api.RevokeRequest,
var uaClient *uaClient
retry:
if tr != nil {
uaClient = newClient(tr)
uaClient = newClient(tr, c.timeout)
} else {
uaClient = c.client
}
Expand Down
28 changes: 28 additions & 0 deletions ca/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,34 @@ func TestClient_GetCaURL(t *testing.T) {
}
}

func TestClient_WithTimeout(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
render.JSONStatus(w, r, api.HealthResponse{Status: "ok"}, 200)
}))
defer srv.Close()

tests := []struct {
name string
options []ClientOption
assertion assert.ErrorAssertionFunc
}{
{"ok", []ClientOption{WithTransport(http.DefaultTransport)}, assert.NoError},
{"ok with timeout", []ClientOption{WithTransport(http.DefaultTransport), WithTimeout(time.Second)}, assert.NoError},
{"fail with timeout", []ClientOption{WithTransport(http.DefaultTransport), WithTimeout(100 * time.Millisecond)}, assert.Error},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, tt.options...)
require.NoError(t, err)
assert.NotZero(t, c.timeout)
_, err = c.Health()
tt.assertion(t, err)
})
}
}

func Test_enforceRequestID(t *testing.T) {
set := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
set.Header.Set("X-Request-Id", "already-set")
Expand Down