Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type Authority struct {
templates *templates.Templates
linkedCAToken string
webhookClient *http.Client
httpClient *http.Client

// X509 CA
password []byte
Expand Down Expand Up @@ -491,6 +492,15 @@ func (a *Authority) init() error {
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
}

// Initialize HTTPClient with all root certs
clientRoots := make([]*x509.Certificate, 0, len(a.rootX509Certs)+len(a.federatedX509Certs))
clientRoots = append(clientRoots, a.rootX509Certs...)
clientRoots = append(clientRoots, a.federatedX509Certs...)
a.httpClient, err = newHTTPClient(clientRoots...)
if err != nil {
return err
}

// Decrypt and load SSH keys
var tmplVars templates.Step
if a.config.SSH != nil {
Expand Down
34 changes: 34 additions & 0 deletions authority/http_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package authority

import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
)

// newHTTPClient returns an HTTP client that trusts the system cert pool and the
// given roots.
func newHTTPClient(roots ...*x509.Certificate) (*http.Client, error) {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("error initializing http client: %w", err)
}
for _, crt := range roots {
pool.AddCert(crt)
}

tr, ok := http.DefaultTransport.(*http.Transport)
Copy link
Member

@hslatman hslatman Aug 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could turn out to be problematic in certain cases. We can follow up with a fix if so.

if !ok {
return nil, fmt.Errorf("error initializing http client: type is not *http.Transport")
}
tr = tr.Clone()
tr.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: pool,
}

return &http.Client{
Transport: tr,
}, nil
}
105 changes: 105 additions & 0 deletions authority/http_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package authority

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/smallstep/certificates/authority/provisioner"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util"
)

func mustCertificate(t *testing.T, a *Authority, csr *x509.CertificateRequest) []*x509.Certificate {
t.Helper()

ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)

now := time.Now()
signOpts := provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(now),
NotAfter: provisioner.NewTimeDuration(now.Add(5 * time.Minute)),
Backdate: 1 * time.Minute,
}

sans := []string{}
sans = append(sans, csr.DNSNames...)
sans = append(sans, csr.EmailAddresses...)
for _, s := range csr.IPAddresses {
sans = append(sans, s.String())
}
for _, s := range csr.URIs {
sans = append(sans, s.String())
}

key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
require.NoError(t, err)

token, err := generateToken(csr.Subject.CommonName, "step-cli", testAudiences.Sign[0], sans, now, key)
require.NoError(t, err)

extraOpts, err := a.Authorize(ctx, token)
require.NoError(t, err)

chain, err := a.SignWithContext(ctx, csr, signOpts, extraOpts...)
require.NoError(t, err)

return chain
}

func Test_newHTTPClient(t *testing.T) {
signer, err := keyutil.GenerateDefaultSigner()
require.NoError(t, err)

csr, err := x509util.CreateCertificateRequest("test", []string{"localhost", "127.0.0.1", "[::1]"}, signer)
require.NoError(t, err)

auth := testAuthority(t)
chain := mustCertificate(t, auth, csr)

t.Run("SystemCertPool", func(t *testing.T) {
resp, err := auth.httpClient.Get("https://smallstep.com")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.NotEmpty(t, b)
assert.NoError(t, resp.Body.Close())
})

t.Run("LocalCertPool", func(t *testing.T) {
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "ok")
}))
srv.TLS = &tls.Config{
Certificates: []tls.Certificate{
{Certificate: [][]byte{chain[0].Raw, chain[1].Raw}, PrivateKey: signer, Leaf: chain[0]},
},
}
srv.StartTLS()
defer srv.Close()

resp, err := auth.httpClient.Get(srv.URL)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, []byte("ok"), b)
assert.NoError(t, resp.Body.Close())

t.Run("DefaultClient", func(t *testing.T) {
client := &http.Client{}
_, err := client.Get(srv.URL)
assert.Error(t, err)
})
})
}
4 changes: 2 additions & 2 deletions authority/provisioner/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,14 @@ func (p *Azure) Init(config Config) (err error) {
p.assertConfig()

// Decode and validate openid-configuration endpoint
if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
if err = getAndDecode(http.DefaultClient, p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
return
}
if err := p.oidcConfig.Validate(); err != nil {
return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL)
}
// Get JWK key set
if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil {
if p.keyStore, err = newKeyStore(http.DefaultClient, p.oidcConfig.JWKSetURI); err != nil {
return
}

Expand Down
11 changes: 11 additions & 0 deletions authority/provisioner/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type Controller struct {
policy *policyEngine
webhookClient *http.Client
webhooks []*Webhook
httpClient *http.Client
}

// NewController initializes a new provisioner controller.
Expand All @@ -48,9 +49,19 @@ func NewController(p Interface, claims *Claims, config Config, options *Options)
policy: policy,
webhookClient: config.WebhookClient,
webhooks: options.GetWebhooks(),
httpClient: config.HTTPClient,
}, nil
}

// GetHTTPClient returns the configured HTTP client or the default one if none
// is configured.
func (c *Controller) GetHTTPClient() *http.Client {
if c.httpClient != nil {
return c.httpClient
}
return &http.Client{}
}

// GetIdentity returns the identity for a given email.
func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) {
if c.IdentityFunc != nil {
Expand Down
42 changes: 34 additions & 8 deletions authority/provisioner/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (
"testing"
"time"

"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/webhook"
"github.com/stretchr/testify/assert"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"
"go.step.sm/linkedca"
"golang.org/x/crypto/ssh"

"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/webhook"
)

var trueValue = true
Expand Down Expand Up @@ -79,12 +79,14 @@ func TestNewController(t *testing.T) {
wantErr bool
}{
{"ok", args{&JWK{}, nil, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
Claims: globalProvisionerClaims,
Audiences: testAudiences,
HTTPClient: &http.Client{},
}, nil}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
httpClient: &http.Client{},
}, false},
{"ok with claims", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
Expand Down Expand Up @@ -145,6 +147,30 @@ func TestNewController(t *testing.T) {
}
}

func TestController_GetHTTPClient(t *testing.T) {
srv := generateTLSJWKServer(2)
defer srv.Close()
type fields struct {
httpClient *http.Client
}
tests := []struct {
name string
fields fields
want *http.Client
}{
{"ok custom", fields{srv.Client()}, srv.Client()},
{"ok default", fields{http.DefaultClient}, http.DefaultClient},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
httpClient: tt.fields.httpClient,
}
assert.Equal(t, tt.want, c.GetHTTPClient())
})
}
}

func TestController_GetIdentity(t *testing.T) {
ctx := context.Background()
type fields struct {
Expand Down
2 changes: 1 addition & 1 deletion authority/provisioner/gcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (p *GCP) Init(config Config) (err error) {
p.assertConfig()

// Initialize key store
if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil {
if p.keyStore, err = newKeyStore(http.DefaultClient, p.config.CertsURL); err != nil {
return
}

Expand Down
12 changes: 7 additions & 5 deletions authority/provisioner/keystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,21 @@ var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`)

type keyStore struct {
sync.RWMutex
client *http.Client
uri string
keySet jose.JSONWebKeySet
timer *time.Timer
expiry time.Time
jitter time.Duration
}

func newKeyStore(uri string) (*keyStore, error) {
keys, age, err := getKeysFromJWKsURI(uri)
func newKeyStore(client *http.Client, uri string) (*keyStore, error) {
keys, age, err := getKeysFromJWKsURI(client, uri)
if err != nil {
return nil, err
}
ks := &keyStore{
client: client,
uri: uri,
keySet: keys,
expiry: getExpirationTime(age),
Expand Down Expand Up @@ -64,7 +66,7 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {

func (ks *keyStore) reload() {
var next time.Duration
keys, age, err := getKeysFromJWKsURI(ks.uri)
keys, age, err := getKeysFromJWKsURI(ks.client, ks.uri)
if err != nil {
next = ks.nextReloadDuration(ks.jitter / 2)
} else {
Expand All @@ -90,9 +92,9 @@ func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
return abs(age)
}

func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
func getKeysFromJWKsURI(client *http.Client, uri string) (jose.JSONWebKeySet, time.Duration, error) {
var keys jose.JSONWebKeySet
resp, err := http.Get(uri) //nolint:gosec // openid-configuration jwks_uri
resp, err := client.Get(uri)
if err != nil {
return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
}
Expand Down
Loading