diff --git a/authority/provisioner/controller.go b/authority/provisioner/controller.go index 13f786954..de9f54007 100644 --- a/authority/provisioner/controller.go +++ b/authority/provisioner/controller.go @@ -41,6 +41,10 @@ func NewController(p Interface, claims *Claims, config Config, options *Options) if err != nil { return nil, err } + wt := config.WrapTransport + if wt == nil { + wt = httptransport.NoopWrapper() + } return &Controller{ Interface: p, Audiences: &config.Audiences, @@ -52,7 +56,7 @@ func NewController(p Interface, claims *Claims, config Config, options *Options) webhookClient: config.WebhookClient, webhooks: options.GetWebhooks(), httpClient: config.HTTPClient, - wrapTransport: config.WrapTransport, + wrapTransport: wt, }, nil } @@ -92,22 +96,17 @@ func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificat } func (c *Controller) newWebhookController(templateData WebhookSetter, certType linkedca.Webhook_CertType, opts ...webhook.RequestBodyOption) *WebhookController { - wt := c.wrapTransport - if wt == nil { - wt = httptransport.NoopWrapper() - } - client := c.webhookClient if client == nil { client = &http.Client{ - Transport: wt(httptransport.New()), + Transport: c.wrapTransport(httptransport.New()), } } return &WebhookController{ TemplateData: templateData, client: client, - wrapTransport: wt, + wrapTransport: c.wrapTransport, webhooks: c.webhooks, certType: certType, options: opts, diff --git a/authority/provisioner/controller_test.go b/authority/provisioner/controller_test.go index 86a5b4995..11da95b6c 100644 --- a/authority/provisioner/controller_test.go +++ b/authority/provisioner/controller_test.go @@ -80,14 +80,16 @@ func TestNewController(t *testing.T) { wantErr bool }{ {"ok", args{&JWK{}, nil, Config{ - Claims: globalProvisionerClaims, - Audiences: testAudiences, - HTTPClient: &http.Client{}, + Claims: globalProvisionerClaims, + Audiences: testAudiences, + HTTPClient: &http.Client{}, + WrapTransport: httptransport.NoopWrapper(), }, nil}, &Controller{ - Interface: &JWK{}, - Audiences: &testAudiences, - Claimer: mustClaimer(t, nil, globalProvisionerClaims), - httpClient: &http.Client{}, + Interface: &JWK{}, + Audiences: &testAudiences, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + httpClient: &http.Client{}, + wrapTransport: httptransport.NoopWrapper(), }, false}, {"ok with claims", args{&JWK{}, &Claims{ DisableRenewal: &defaultDisableRenewal, @@ -100,6 +102,7 @@ func TestNewController(t *testing.T) { Claimer: mustClaimer(t, &Claims{ DisableRenewal: &defaultDisableRenewal, }, globalProvisionerClaims), + wrapTransport: httptransport.NoopWrapper(), }, false}, {"ok with claims and options", args{&JWK{}, &Claims{ DisableRenewal: &defaultDisableRenewal, @@ -112,7 +115,8 @@ func TestNewController(t *testing.T) { Claimer: mustClaimer(t, &Claims{ DisableRenewal: &defaultDisableRenewal, }, globalProvisionerClaims), - policy: mustNewPolicyEngine(t, options), + policy: mustNewPolicyEngine(t, options), + wrapTransport: httptransport.NoopWrapper(), }, false}, {"fail claimer", args{&JWK{}, &Claims{ MinTLSDur: mustDuration(t, "24h"), @@ -141,6 +145,14 @@ func TestNewController(t *testing.T) { t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr) return } + + // A function can only be compared to nil + if tt.want != nil && got != nil { + assert.NotNil(t, got.wrapTransport) + tt.want.wrapTransport = nil + got.wrapTransport = nil + } + if !reflect.DeepEqual(got, tt.want) { t.Errorf("NewController() = %v, want %v", got, tt.want) }