Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/smallstep/certificates/cas"
casapi "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/scep"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/nosql"
Expand All @@ -48,6 +49,7 @@ type Authority struct {
adminDB admin.DB
templates *templates.Templates
linkedCAToken string
wrapTransport httptransport.Wrapper
webhookClient *http.Client
httpClient *http.Client

Expand Down Expand Up @@ -128,10 +130,11 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) {
}

var a = &Authority{
config: cfg,
certificates: new(sync.Map),
validateSCEP: true,
meter: noopMeter{},
config: cfg,
certificates: new(sync.Map),
validateSCEP: true,
meter: noopMeter{},
wrapTransport: httptransport.NoopWrapper(),
}

// Apply options.
Expand All @@ -158,9 +161,10 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) {
// project without the limitations of the config.
func NewEmbedded(opts ...Option) (*Authority, error) {
a := &Authority{
config: &config.Config{},
certificates: new(sync.Map),
meter: noopMeter{},
config: &config.Config{},
certificates: new(sync.Map),
meter: noopMeter{},
wrapTransport: httptransport.NoopWrapper(),
}

// Apply options.
Expand Down Expand Up @@ -496,7 +500,7 @@ func (a *Authority) init() error {
clientRoots := make([]*x509.Certificate, 0, len(a.rootX509Certs)+len(a.federatedX509Certs))
clientRoots = append(clientRoots, a.rootX509Certs...)
clientRoots = append(clientRoots, a.federatedX509Certs...)
a.httpClient, err = newHTTPClient(clientRoots...)
a.httpClient, err = newHTTPClient(a.wrapTransport, clientRoots...)
if err != nil {
return err
}
Expand Down
42 changes: 23 additions & 19 deletions authority/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,34 @@ import (
"crypto/x509"
"fmt"
"net/http"

"github.com/smallstep/certificates/internal/httptransport"
)

// newHTTPClient will return an HTTP client that trusts the system cert pool and
// the given roots, but only if the http.DefaultTransport is an *http.Transport.
// If not, it will return the default HTTP client.
func newHTTPClient(roots ...*x509.Certificate) (*http.Client, error) {
if tr, ok := http.DefaultTransport.(*http.Transport); ok {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("error initializing http client: %w", err)
}
for _, crt := range roots {
pool.AddCert(crt)
}
// the given roots.
func newHTTPClient(wt httptransport.Wrapper, roots ...*x509.Certificate) (*http.Client, error) {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("error initializing http client: %w", err)
}
for _, crt := range roots {
pool.AddCert(crt)
}

tr, ok := http.DefaultTransport.(*http.Transport)
if !ok {
tr = httptransport.New()
} else {
tr = tr.Clone()
tr.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: pool,
}
return &http.Client{
Transport: tr,
}, nil
}

return &http.Client{}, nil
tr.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: pool,
}

return &http.Client{
Transport: wt(tr),
}, nil
}
5 changes: 3 additions & 2 deletions authority/http_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose"
Expand Down Expand Up @@ -113,8 +114,8 @@ func Test_newHTTPClient(t *testing.T) {
}{http.DefaultTransport}
http.DefaultTransport = transport

client, err := newHTTPClient(auth.rootX509Certs...)
client, err := newHTTPClient(httptransport.NoopWrapper(), auth.rootX509Certs...)
assert.NoError(t, err)
assert.Equal(t, &http.Client{}, client)
assert.NotNil(t, client)
})
}
17 changes: 17 additions & 0 deletions authority/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/smallstep/certificates/cas"
casapi "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/scep"
)

Expand Down Expand Up @@ -103,6 +104,22 @@ func WithWebhookClient(c *http.Client) Option {
}
}

// Wrapper wraps the set of functions mapping [http.Transport] references to [http.RoundTripper].
type TransportWrapper = httptransport.Wrapper
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above the above.


// WithTransportWrapper sets the transport wrapper of the authority to the provided one or, in case
// that one is nil, to a noop one.
func WithTransportWrapper(tw httptransport.Wrapper) Option {
if tw == nil {
tw = httptransport.NoopWrapper()
}

return func(a *Authority) error {
a.wrapTransport = tw
return nil
}
}

// WithGetIdentityFunc sets a custom function to retrieve the identity from
// an external resource.
func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, email string) (*provisioner.Identity, error)) Option {
Expand Down
24 changes: 18 additions & 6 deletions authority/provisioner/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/webhook"
"go.step.sm/linkedca"
"golang.org/x/crypto/ssh"
Expand All @@ -27,6 +28,7 @@ type Controller struct {
webhookClient *http.Client
webhooks []*Webhook
httpClient *http.Client
wrapTransport httptransport.Wrapper
}

// NewController initializes a new provisioner controller.
Expand All @@ -50,6 +52,7 @@ func NewController(p Interface, claims *Claims, config Config, options *Options)
webhookClient: config.WebhookClient,
webhooks: options.GetWebhooks(),
httpClient: config.HTTPClient,
wrapTransport: config.WrapTransport,
}, nil
}

Expand Down Expand Up @@ -89,16 +92,25 @@ func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificat
}

func (c *Controller) newWebhookController(templateData WebhookSetter, certType linkedca.Webhook_CertType, opts ...webhook.RequestBodyOption) *WebhookController {
wt := c.wrapTransport
if wt == nil {
wt = httptransport.NoopWrapper()
}

client := c.webhookClient
if client == nil {
client = http.DefaultClient
client = &http.Client{
Transport: wt(httptransport.New()),
}
}

return &WebhookController{
TemplateData: templateData,
client: client,
webhooks: c.webhooks,
certType: certType,
options: opts,
TemplateData: templateData,
client: client,
wrapTransport: wt,
webhooks: c.webhooks,
certType: certType,
options: opts,
}
}

Expand Down
16 changes: 12 additions & 4 deletions authority/provisioner/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/webhook"
"github.com/stretchr/testify/assert"
"go.step.sm/crypto/pemutil"
Expand Down Expand Up @@ -512,11 +513,18 @@ func Test_newWebhookController(t *testing.T) {
options: opts,
}},
}

for _, tt := range tests {
c := &Controller{}
got := c.newWebhookController(tt.args.templateData, tt.args.certType, tt.args.opts...)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("newWebhookController() = %v, want %v", got, tt.want)
c := Controller{
webhookClient: new(http.Client),
wrapTransport: httptransport.NoopWrapper(),
}
got := c.newWebhookController(tt.args.templateData, tt.args.certType, tt.args.opts...)

assert.Equal(t, tt.args.templateData, got.TemplateData)
assert.Same(t, c.webhookClient, got.client)
assert.Equal(t, c.webhooks, got.webhooks)
assert.Equal(t, tt.args.opts, got.options)
assert.Equal(t, tt.args.certType, got.certType)
}
}
3 changes: 3 additions & 0 deletions authority/provisioner/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ type Config struct {
// HTTPClient is an HTTP client that trusts the system cert pool and the CA
// roots.
HTTPClient *http.Client
// WrapTransport references the function that should wrap any [http.Transport] initialized
// down the Config's chain.
WrapTransport func(*http.Transport) http.RoundTripper
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not httptransport.Wrapper?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it'll be a type alias but it's coming right up.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

type provisioner struct {
Expand Down
33 changes: 20 additions & 13 deletions authority/provisioner/scep.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"go.step.sm/crypto/x509util"
"go.step.sm/linkedca"

"github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/webhook"
)

Expand Down Expand Up @@ -112,13 +113,14 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration {
}

type challengeValidationController struct {
client *http.Client
webhooks []*Webhook
client *http.Client
wrapTransport httptransport.Wrapper
webhooks []*Webhook
}

// newChallengeValidationController creates a new challengeValidationController
// that performs challenge validation through webhooks.
func newChallengeValidationController(client *http.Client, webhooks []*Webhook) *challengeValidationController {
func newChallengeValidationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController {
scepHooks := []*Webhook{}
for _, wh := range webhooks {
if wh.Kind != linkedca.Webhook_SCEPCHALLENGE.String() {
Expand All @@ -130,8 +132,9 @@ func newChallengeValidationController(client *http.Client, webhooks []*Webhook)
scepHooks = append(scepHooks, wh)
}
return &challengeValidationController{
client: client,
webhooks: scepHooks,
client: client,
wrapTransport: tw,
webhooks: scepHooks,
}
}

Expand All @@ -157,7 +160,7 @@ func (c *challengeValidationController) Validate(ctx context.Context, csr *x509.
req.ProvisionerName = provisionerName
req.SCEPChallenge = challenge
req.SCEPTransactionID = transactionID
resp, err := wh.DoWithContext(ctx, c.client, req, nil) // TODO(hs): support templated URL? Requires some refactoring
resp, err := wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil) // TODO(hs): support templated URL? Requires some refactoring
if err != nil {
return nil, fmt.Errorf("failed executing webhook request: %w", err)
}
Expand All @@ -176,13 +179,14 @@ func (c *challengeValidationController) Validate(ctx context.Context, csr *x509.
}

type notificationController struct {
client *http.Client
webhooks []*Webhook
client *http.Client
wrapTransport httptransport.Wrapper
webhooks []*Webhook
}

// newNotificationController creates a new notificationController
// that performs SCEP notifications through webhooks.
func newNotificationController(client *http.Client, webhooks []*Webhook) *notificationController {
func newNotificationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *notificationController {
scepHooks := []*Webhook{}
for _, wh := range webhooks {
if wh.Kind != linkedca.Webhook_NOTIFYING.String() {
Expand All @@ -194,8 +198,9 @@ func newNotificationController(client *http.Client, webhooks []*Webhook) *notifi
scepHooks = append(scepHooks, wh)
}
return &notificationController{
client: client,
webhooks: scepHooks,
client: client,
wrapTransport: tw,
webhooks: scepHooks,
}
}

Expand All @@ -207,7 +212,7 @@ func (c *notificationController) Success(ctx context.Context, csr *x509.Certific
}
req.X509Certificate.Raw = cert.Raw // adding the full certificate DER bytes
req.SCEPTransactionID = transactionID
if _, err = wh.DoWithContext(ctx, c.client, req, nil); err != nil {
if _, err = wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil); err != nil {
return fmt.Errorf("failed executing webhook request: %w: %w", ErrSCEPNotificationFailed, err)
}
}
Expand All @@ -224,7 +229,7 @@ func (c *notificationController) Failure(ctx context.Context, csr *x509.Certific
req.SCEPTransactionID = transactionID
req.SCEPErrorCode = errorCode
req.SCEPErrorDescription = errorDescription
if _, err = wh.DoWithContext(ctx, c.client, req, nil); err != nil {
if _, err = wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil); err != nil {
return fmt.Errorf("failed executing webhook request: %w: %w", ErrSCEPNotificationFailed, err)
}
}
Expand Down Expand Up @@ -267,12 +272,14 @@ func (s *SCEP) Init(config Config) (err error) {
// Prepare the SCEP challenge validator
s.challengeValidationController = newChallengeValidationController(
config.WebhookClient,
config.WrapTransport,
s.GetOptions().GetWebhooks(),
)

// Prepare the SCEP notification controller
s.notificationController = newNotificationController(
config.WebhookClient,
config.WrapTransport,
s.GetOptions().GetWebhooks(),
)

Expand Down
2 changes: 1 addition & 1 deletion authority/provisioner/scep_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func Test_challengeValidationController_Validate(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := newChallengeValidationController(tt.fields.client, tt.fields.webhooks)
c := newChallengeValidationController(tt.fields.client, nil, tt.fields.webhooks)
ctx := context.Background()
got, err := c.Validate(ctx, dummyCSR, tt.args.provisionerName, tt.args.challenge, tt.args.transactionID)
if tt.expErr != nil {
Expand Down
Loading
Loading