diff --git a/authority/authority.go b/authority/authority.go index 4a9123685..7b93a39ad 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -33,6 +33,7 @@ import ( "github.com/smallstep/certificates/cas" casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/scep" "github.com/smallstep/certificates/templates" "github.com/smallstep/nosql" @@ -48,6 +49,7 @@ type Authority struct { adminDB admin.DB templates *templates.Templates linkedCAToken string + wrapTransport httptransport.Wrapper webhookClient *http.Client httpClient *http.Client @@ -128,10 +130,11 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) { } var a = &Authority{ - config: cfg, - certificates: new(sync.Map), - validateSCEP: true, - meter: noopMeter{}, + config: cfg, + certificates: new(sync.Map), + validateSCEP: true, + meter: noopMeter{}, + wrapTransport: httptransport.NoopWrapper(), } // Apply options. @@ -158,9 +161,10 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) { // project without the limitations of the config. func NewEmbedded(opts ...Option) (*Authority, error) { a := &Authority{ - config: &config.Config{}, - certificates: new(sync.Map), - meter: noopMeter{}, + config: &config.Config{}, + certificates: new(sync.Map), + meter: noopMeter{}, + wrapTransport: httptransport.NoopWrapper(), } // Apply options. @@ -496,7 +500,7 @@ func (a *Authority) init() error { clientRoots := make([]*x509.Certificate, 0, len(a.rootX509Certs)+len(a.federatedX509Certs)) clientRoots = append(clientRoots, a.rootX509Certs...) clientRoots = append(clientRoots, a.federatedX509Certs...) - a.httpClient, err = newHTTPClient(clientRoots...) + a.httpClient, err = newHTTPClient(a.wrapTransport, clientRoots...) if err != nil { return err } diff --git a/authority/http_client.go b/authority/http_client.go index ff61e45fe..d06464b33 100644 --- a/authority/http_client.go +++ b/authority/http_client.go @@ -5,30 +5,34 @@ import ( "crypto/x509" "fmt" "net/http" + + "github.com/smallstep/certificates/internal/httptransport" ) // newHTTPClient will return an HTTP client that trusts the system cert pool and -// the given roots, but only if the http.DefaultTransport is an *http.Transport. -// If not, it will return the default HTTP client. -func newHTTPClient(roots ...*x509.Certificate) (*http.Client, error) { - if tr, ok := http.DefaultTransport.(*http.Transport); ok { - pool, err := x509.SystemCertPool() - if err != nil { - return nil, fmt.Errorf("error initializing http client: %w", err) - } - for _, crt := range roots { - pool.AddCert(crt) - } +// the given roots. +func newHTTPClient(wt httptransport.Wrapper, roots ...*x509.Certificate) (*http.Client, error) { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("error initializing http client: %w", err) + } + for _, crt := range roots { + pool.AddCert(crt) + } + tr, ok := http.DefaultTransport.(*http.Transport) + if !ok { + tr = httptransport.New() + } else { tr = tr.Clone() - tr.TLSClientConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - RootCAs: pool, - } - return &http.Client{ - Transport: tr, - }, nil } - return &http.Client{}, nil + tr.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + RootCAs: pool, + } + + return &http.Client{ + Transport: wt(tr), + }, nil } diff --git a/authority/http_client_test.go b/authority/http_client_test.go index 979c884df..5a77331a0 100644 --- a/authority/http_client_test.go +++ b/authority/http_client_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/internal/httptransport" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" @@ -113,8 +114,8 @@ func Test_newHTTPClient(t *testing.T) { }{http.DefaultTransport} http.DefaultTransport = transport - client, err := newHTTPClient(auth.rootX509Certs...) + client, err := newHTTPClient(httptransport.NoopWrapper(), auth.rootX509Certs...) assert.NoError(t, err) - assert.Equal(t, &http.Client{}, client) + assert.NotNil(t, client) }) } diff --git a/authority/options.go b/authority/options.go index 9738b391e..6a75bdd99 100644 --- a/authority/options.go +++ b/authority/options.go @@ -18,6 +18,7 @@ import ( "github.com/smallstep/certificates/cas" casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/scep" ) @@ -103,6 +104,22 @@ func WithWebhookClient(c *http.Client) Option { } } +// Wrapper wraps the set of functions mapping [http.Transport] references to [http.RoundTripper]. +type TransportWrapper = httptransport.Wrapper + +// WithTransportWrapper sets the transport wrapper of the authority to the provided one or, in case +// that one is nil, to a noop one. +func WithTransportWrapper(tw httptransport.Wrapper) Option { + if tw == nil { + tw = httptransport.NoopWrapper() + } + + return func(a *Authority) error { + a.wrapTransport = tw + return nil + } +} + // WithGetIdentityFunc sets a custom function to retrieve the identity from // an external resource. func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, email string) (*provisioner.Identity, error)) Option { diff --git a/authority/provisioner/controller.go b/authority/provisioner/controller.go index 8d95aaf89..13f786954 100644 --- a/authority/provisioner/controller.go +++ b/authority/provisioner/controller.go @@ -9,6 +9,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/webhook" "go.step.sm/linkedca" "golang.org/x/crypto/ssh" @@ -27,6 +28,7 @@ type Controller struct { webhookClient *http.Client webhooks []*Webhook httpClient *http.Client + wrapTransport httptransport.Wrapper } // NewController initializes a new provisioner controller. @@ -50,6 +52,7 @@ func NewController(p Interface, claims *Claims, config Config, options *Options) webhookClient: config.WebhookClient, webhooks: options.GetWebhooks(), httpClient: config.HTTPClient, + wrapTransport: config.WrapTransport, }, nil } @@ -89,16 +92,25 @@ func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificat } func (c *Controller) newWebhookController(templateData WebhookSetter, certType linkedca.Webhook_CertType, opts ...webhook.RequestBodyOption) *WebhookController { + wt := c.wrapTransport + if wt == nil { + wt = httptransport.NoopWrapper() + } + client := c.webhookClient if client == nil { - client = http.DefaultClient + client = &http.Client{ + Transport: wt(httptransport.New()), + } } + return &WebhookController{ - TemplateData: templateData, - client: client, - webhooks: c.webhooks, - certType: certType, - options: opts, + TemplateData: templateData, + client: client, + wrapTransport: wt, + webhooks: c.webhooks, + certType: certType, + options: opts, } } diff --git a/authority/provisioner/controller_test.go b/authority/provisioner/controller_test.go index fe0641e65..86a5b4995 100644 --- a/authority/provisioner/controller_test.go +++ b/authority/provisioner/controller_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/webhook" "github.com/stretchr/testify/assert" "go.step.sm/crypto/pemutil" @@ -512,11 +513,18 @@ func Test_newWebhookController(t *testing.T) { options: opts, }}, } + for _, tt := range tests { - c := &Controller{} - got := c.newWebhookController(tt.args.templateData, tt.args.certType, tt.args.opts...) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("newWebhookController() = %v, want %v", got, tt.want) + c := Controller{ + webhookClient: new(http.Client), + wrapTransport: httptransport.NoopWrapper(), } + got := c.newWebhookController(tt.args.templateData, tt.args.certType, tt.args.opts...) + + assert.Equal(t, tt.args.templateData, got.TemplateData) + assert.Same(t, c.webhookClient, got.client) + assert.Equal(t, c.webhooks, got.webhooks) + assert.Equal(t, tt.args.opts, got.options) + assert.Equal(t, tt.args.certType, got.certType) } } diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 24792f667..86fdd5ec4 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -264,6 +264,9 @@ type Config struct { // HTTPClient is an HTTP client that trusts the system cert pool and the CA // roots. HTTPClient *http.Client + // WrapTransport references the function that should wrap any [http.Transport] initialized + // down the Config's chain. + WrapTransport TransportWrapper } type provisioner struct { diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index 0bdaf3e99..bdbd2497f 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -18,6 +18,7 @@ import ( "go.step.sm/crypto/x509util" "go.step.sm/linkedca" + "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/webhook" ) @@ -112,13 +113,14 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration { } type challengeValidationController struct { - client *http.Client - webhooks []*Webhook + client *http.Client + wrapTransport httptransport.Wrapper + webhooks []*Webhook } // newChallengeValidationController creates a new challengeValidationController // that performs challenge validation through webhooks. -func newChallengeValidationController(client *http.Client, webhooks []*Webhook) *challengeValidationController { +func newChallengeValidationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController { scepHooks := []*Webhook{} for _, wh := range webhooks { if wh.Kind != linkedca.Webhook_SCEPCHALLENGE.String() { @@ -130,8 +132,9 @@ func newChallengeValidationController(client *http.Client, webhooks []*Webhook) scepHooks = append(scepHooks, wh) } return &challengeValidationController{ - client: client, - webhooks: scepHooks, + client: client, + wrapTransport: tw, + webhooks: scepHooks, } } @@ -157,7 +160,7 @@ func (c *challengeValidationController) Validate(ctx context.Context, csr *x509. req.ProvisionerName = provisionerName req.SCEPChallenge = challenge req.SCEPTransactionID = transactionID - resp, err := wh.DoWithContext(ctx, c.client, req, nil) // TODO(hs): support templated URL? Requires some refactoring + resp, err := wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil) // TODO(hs): support templated URL? Requires some refactoring if err != nil { return nil, fmt.Errorf("failed executing webhook request: %w", err) } @@ -176,13 +179,14 @@ func (c *challengeValidationController) Validate(ctx context.Context, csr *x509. } type notificationController struct { - client *http.Client - webhooks []*Webhook + client *http.Client + wrapTransport httptransport.Wrapper + webhooks []*Webhook } // newNotificationController creates a new notificationController // that performs SCEP notifications through webhooks. -func newNotificationController(client *http.Client, webhooks []*Webhook) *notificationController { +func newNotificationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *notificationController { scepHooks := []*Webhook{} for _, wh := range webhooks { if wh.Kind != linkedca.Webhook_NOTIFYING.String() { @@ -194,8 +198,9 @@ func newNotificationController(client *http.Client, webhooks []*Webhook) *notifi scepHooks = append(scepHooks, wh) } return ¬ificationController{ - client: client, - webhooks: scepHooks, + client: client, + wrapTransport: tw, + webhooks: scepHooks, } } @@ -207,7 +212,7 @@ func (c *notificationController) Success(ctx context.Context, csr *x509.Certific } req.X509Certificate.Raw = cert.Raw // adding the full certificate DER bytes req.SCEPTransactionID = transactionID - if _, err = wh.DoWithContext(ctx, c.client, req, nil); err != nil { + if _, err = wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil); err != nil { return fmt.Errorf("failed executing webhook request: %w: %w", ErrSCEPNotificationFailed, err) } } @@ -224,7 +229,7 @@ func (c *notificationController) Failure(ctx context.Context, csr *x509.Certific req.SCEPTransactionID = transactionID req.SCEPErrorCode = errorCode req.SCEPErrorDescription = errorDescription - if _, err = wh.DoWithContext(ctx, c.client, req, nil); err != nil { + if _, err = wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil); err != nil { return fmt.Errorf("failed executing webhook request: %w: %w", ErrSCEPNotificationFailed, err) } } @@ -267,12 +272,14 @@ func (s *SCEP) Init(config Config) (err error) { // Prepare the SCEP challenge validator s.challengeValidationController = newChallengeValidationController( config.WebhookClient, + config.WrapTransport, s.GetOptions().GetWebhooks(), ) // Prepare the SCEP notification controller s.notificationController = newNotificationController( config.WebhookClient, + config.WrapTransport, s.GetOptions().GetWebhooks(), ) diff --git a/authority/provisioner/scep_test.go b/authority/provisioner/scep_test.go index dbed83d5a..c94a4b507 100644 --- a/authority/provisioner/scep_test.go +++ b/authority/provisioner/scep_test.go @@ -201,7 +201,7 @@ func Test_challengeValidationController_Validate(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := newChallengeValidationController(tt.fields.client, tt.fields.webhooks) + c := newChallengeValidationController(tt.fields.client, nil, tt.fields.webhooks) ctx := context.Background() got, err := c.Validate(ctx, dummyCSR, tt.args.provisionerName, tt.args.challenge, tt.args.transactionID) if tt.expErr != nil { diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index 1c20066bf..962a17b9e 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -31,11 +31,12 @@ type WebhookSetter interface { } type WebhookController struct { - client *http.Client - webhooks []*Webhook - certType linkedca.Webhook_CertType - options []webhook.RequestBodyOption - TemplateData WebhookSetter + client *http.Client + wrapTransport httptransport.Wrapper + webhooks []*Webhook + certType linkedca.Webhook_CertType + options []webhook.RequestBodyOption + TemplateData WebhookSetter } // Enrich fetches data from remote servers and adds returned data to the @@ -63,7 +64,7 @@ func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBod whCtx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() //nolint:gocritic // every request canceled with its own timeout - resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData) + resp, err := wh.DoWithContext(whCtx, wc.client, wc.wrapTransport, req, wc.TemplateData) if err != nil { return err } @@ -102,7 +103,7 @@ func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.Request whCtx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() //nolint:gocritic // every request canceled with its own timeout - resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData) + resp, err := wh.DoWithContext(whCtx, wc.client, wc.wrapTransport, req, wc.TemplateData) if err != nil { return err } @@ -141,7 +142,11 @@ type Webhook struct { } `json:"-"` } -func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { +// TransportWrapper wraps the set of functions mapping [http.Transport] references to +// [http.RoundTripper]. +type TransportWrapper = httptransport.Wrapper + +func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, tw TransportWrapper, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL) if err != nil { return nil, err @@ -214,7 +219,7 @@ retry: } client = &http.Client{ - Transport: transport, + Transport: tw(transport), } } resp, err := client.Do(req) diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 9cfd6e8fb..50136f418 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -627,7 +627,7 @@ func TestWebhook_Do(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() - got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg) + got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, httptransport.NoopWrapper(), reqBody, tc.dataArg) if tc.expectErr != nil { assert.Equal(t, tc.expectErr.Error(), err.Error()) return @@ -663,14 +663,14 @@ func TestWebhook_Do(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - _, err = wh.DoWithContext(ctx, client, reqBody, nil) + _, err = wh.DoWithContext(ctx, client, httptransport.NoopWrapper(), reqBody, nil) require.NoError(t, err) ctx, cancel = context.WithTimeout(context.Background(), time.Second*10) defer cancel() wh.DisableTLSClientAuth = true - _, err = wh.DoWithContext(ctx, client, reqBody, nil) + _, err = wh.DoWithContext(ctx, client, httptransport.NoopWrapper(), reqBody, nil) require.Error(t, err) }) } diff --git a/authority/provisioners.go b/authority/provisioners.go index 390a318d2..991c85099 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -202,6 +202,7 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner. AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc, WebhookClient: a.webhookClient, HTTPClient: a.httpClient, + WrapTransport: a.wrapTransport, SCEPKeyManager: a.scepKeyManager, }, nil } diff --git a/ca/ca.go b/ca/ca.go index 4711d13f2..066332a4f 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -198,7 +198,9 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } webhookTransport := httptransport.New() - opts = append(opts, authority.WithWebhookClient(&http.Client{Transport: webhookTransport})) + opts = append(opts, + authority.WithWebhookClient(&http.Client{Transport: webhookTransport}), + ) auth, err := authority.New(cfg, opts...) if err != nil { diff --git a/internal/httptransport/httptransport.go b/internal/httptransport/httptransport.go index b14862488..76146df22 100644 --- a/internal/httptransport/httptransport.go +++ b/internal/httptransport/httptransport.go @@ -8,6 +8,17 @@ import ( "time" ) +// Wrapper wraps the set of functions mapping [http.Transport] references to [http.RoundTripper]. +type Wrapper func(*http.Transport) http.RoundTripper + +// NoopWrapper returns a [Wrapper] that simply casts its provided [http.Transport] to an +// [http.RoundTripper]. +func NoopWrapper() Wrapper { + return func(t *http.Transport) http.RoundTripper { + return t + } +} + // New returns a reference to an [http.Transport] that's initialized just like the // [http.DefaultTransport] is by the standard library. func New() *http.Transport {