Skip to content

Commit 032628d

Browse files
committed
increase coverage
Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>
1 parent b0ece26 commit 032628d

File tree

3 files changed

+208
-9
lines changed

3 files changed

+208
-9
lines changed

go.mod

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ require (
6363
github.com/google/s2a-go v0.1.9 // indirect
6464
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
6565
github.com/googleapis/gax-go/v2 v2.15.0 // indirect
66-
github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f // indirect
6766
github.com/mattn/go-colorable v0.1.14 // indirect
6867
github.com/mattn/go-isatty v0.0.20 // indirect
6968
github.com/moby/docker-image-spec v1.3.1 // indirect
@@ -72,7 +71,6 @@ require (
7271
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
7372
github.com/opencontainers/go-digest v1.0.0 // indirect
7473
github.com/opencontainers/image-spec v1.1.0 // indirect
75-
github.com/pelletier/go-toml v1.9.5 // indirect
7674
github.com/planetscale/vtprotobuf v0.6.1-0.20240917153116-6f2963f01587 // indirect
7775
github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect
7876
github.com/zeebo/errs v1.4.0 // indirect

go.sum

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,6 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d
159159
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
160160
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
161161
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
162-
github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f h1:7LYC+Yfkj3CTRcShK0KOL/w6iTiKyqqBA9a41Wnggw8=
163-
github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f/go.mod h1:pFlLw2CfqZiIBOx6BuCeRLCrfxBJipTY0nIOF/VbGcI=
164162
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
165163
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
166164
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -215,8 +213,6 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
215213
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
216214
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
217215
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
218-
github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8=
219-
github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
220216
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
221217
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
222218
github.com/planetscale/vtprotobuf v0.6.1-0.20240917153116-6f2963f01587 h1:xzZOeCMQLA/W198ZkdVdt4EKFKJtS26B773zNU377ZY=

pkg/atls/atls_test.go

Lines changed: 208 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,32 @@ func generateTestCertPEMWithSubject(t *testing.T, commonName string) string {
7575
return strings.ReplaceAll(string(certPEM), "\n", "\\n")
7676
}
7777

78+
func generateTestCertificateWithExtensions(t *testing.T, extensions []pkix.Extension) *x509.Certificate {
79+
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
80+
require.NoError(t, err)
81+
82+
template := x509.Certificate{
83+
SerialNumber: big.NewInt(1),
84+
Subject: pkix.Name{
85+
CommonName: "test",
86+
},
87+
NotBefore: time.Now(),
88+
NotAfter: time.Now().Add(365 * 24 * time.Hour),
89+
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
90+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
91+
BasicConstraintsValid: true,
92+
ExtraExtensions: extensions,
93+
}
94+
95+
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
96+
require.NoError(t, err)
97+
98+
cert, err := x509.ParseCertificate(certDER)
99+
require.NoError(t, err)
100+
101+
return cert
102+
}
103+
78104
// TestCertificateSubject tests the CertificateSubject functionality.
79105
func TestDefaultCertificateSubject(t *testing.T) {
80106
subject := DefaultCertificateSubject()
@@ -685,7 +711,8 @@ func TestCertificateVerification(t *testing.T) {
685711
})
686712
}
687713

688-
func TestNewAttestedCAProvider(t *testing.T) {
714+
// TestAttestedCAProvider tests the CA-signed certificate provider.
715+
func TestAttestedCAProvider(t *testing.T) {
689716
mockProvider := new(mocks.Provider)
690717
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
691718
require.NoError(t, err)
@@ -694,8 +721,186 @@ func TestNewAttestedCAProvider(t *testing.T) {
694721
cvmID := "test-cvm-id"
695722
agentToken := "test-token"
696723

697-
provider := NewAttestedCAProvider(attestationProvider, subject, nil, cvmID, agentToken)
698-
assert.NotNil(t, provider)
724+
t.Run("NewAttestedCAProvider", func(t *testing.T) {
725+
provider := NewAttestedCAProvider(attestationProvider, subject, nil, cvmID, agentToken)
726+
assert.NotNil(t, provider)
727+
})
728+
729+
t.Run("SetTTL", func(t *testing.T) {
730+
provider := NewAttestedCAProvider(attestationProvider, subject, nil, cvmID, agentToken)
731+
732+
newTTL := time.Hour * 48
733+
provider.(*attestedCertificateProvider).SetTTL(newTTL)
734+
735+
attestedProvider := provider.(*attestedCertificateProvider)
736+
assert.Equal(t, newTTL, attestedProvider.ttl)
737+
})
738+
}
739+
740+
// TestCASignedCertificateErrors tests error cases in CA-signed certificate generation.
741+
func TestCASignedCertificateErrors(t *testing.T) {
742+
mockProvider := new(mocks.Provider)
743+
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
744+
require.NoError(t, err)
745+
746+
subject := DefaultCertificateSubject()
747+
cvmID := "test-cvm-id"
748+
agentToken := "test-token"
749+
750+
cases := []struct {
751+
name string
752+
certificate string
753+
sdkError error
754+
expectedError string
755+
}{
756+
{"SDKIssueError", "", errors.NewSDKError(errors.New("SDK error")), "SDK error"},
757+
{"InvalidPEMWithRemainingData", "-----BEGIN CERTIFICATE-----\\nVGVzdA==\\n-----END CERTIFICATE-----\\nExtra data here", nil, "unexpected remaining data"},
758+
{"NoPEMBlockFound", "", nil, "no PEM block found"},
759+
}
760+
761+
for _, c := range cases {
762+
t.Run(c.name, func(t *testing.T) {
763+
mockSDK := sdkmocks.NewSDK(t)
764+
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{Certificate: c.certificate}, c.sdkError)
765+
766+
provider := NewAttestedCAProvider(attestationProvider, subject, mockSDK, cvmID, agentToken)
767+
attestedProvider := provider.(*attestedCertificateProvider)
768+
769+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
770+
require.NoError(t, err)
771+
772+
extension := pkix.Extension{
773+
Id: SNPvTPMOID,
774+
Value: []byte("test-data"),
775+
}
776+
777+
_, err = attestedProvider.generateCASignedCertificate(privateKey, extension)
778+
assert.Error(t, err)
779+
assert.Contains(t, err.Error(), c.expectedError)
780+
})
781+
}
782+
}
783+
784+
// TestGetCertificateErrors tests error paths in certificate generation.
785+
func TestGetCertificateErrors(t *testing.T) {
786+
t.Run("InvalidServerNameFormat", func(t *testing.T) {
787+
mockProvider := new(mocks.Provider)
788+
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
789+
require.NoError(t, err)
790+
791+
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
792+
793+
clientHello := &tls.ClientHelloInfo{
794+
ServerName: "invalid-format",
795+
}
796+
797+
_, err = provider.GetCertificate(clientHello)
798+
assert.Error(t, err)
799+
assert.Contains(t, err.Error(), "failed to extract nonce")
800+
})
801+
802+
t.Run("AttestationProviderError", func(t *testing.T) {
803+
mockProvider := new(mocks.Provider)
804+
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
805+
806+
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
807+
require.NoError(t, err)
808+
809+
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
810+
811+
nonce := make([]byte, 64)
812+
_, err = rand.Read(nonce)
813+
require.NoError(t, err)
814+
815+
serverName := hex.EncodeToString(nonce) + ".nonce"
816+
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
817+
818+
_, err = provider.GetCertificate(clientHello)
819+
assert.Error(t, err)
820+
assert.Contains(t, err.Error(), "failed to get attestation")
821+
})
822+
823+
t.Run("CASignedCertificateError", func(t *testing.T) {
824+
mockProvider := new(mocks.Provider)
825+
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)
826+
827+
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
828+
require.NoError(t, err)
829+
830+
mockSDK := sdkmocks.NewSDK(t)
831+
sdkErr := errors.NewSDKError(errors.New("CA error"))
832+
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{}, sdkErr)
833+
834+
provider := NewAttestedCAProvider(attestationProvider, DefaultCertificateSubject(), mockSDK, "test-cvm", "test-token")
835+
836+
nonce := make([]byte, 64)
837+
_, err = rand.Read(nonce)
838+
require.NoError(t, err)
839+
840+
serverName := hex.EncodeToString(nonce) + ".nonce"
841+
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
842+
843+
_, err = provider.GetCertificate(clientHello)
844+
assert.Error(t, err)
845+
assert.Contains(t, err.Error(), "failed to generate certificate")
846+
})
847+
}
848+
849+
// TestCertificateVerificationEdgeCases tests edge cases in certificate verification.
850+
func TestCertificateVerificationEdgeCases(t *testing.T) {
851+
tempDir, err := os.MkdirTemp("", "policy")
852+
require.NoError(t, err)
853+
defer os.RemoveAll(tempDir)
854+
855+
attestationPB := prepVerifyAttReport(t)
856+
err = setAttestationPolicy(attestationPB, tempDir)
857+
require.NoError(t, err)
858+
859+
t.Run("VerifyPeerCertificateWithMultipleCerts", func(t *testing.T) {
860+
verifier := NewCertificateVerifier(nil)
861+
cert1 := createSelfSignedCert(t)
862+
cert2 := createSelfSignedCert(t)
863+
nonce := generateNonce(t)
864+
865+
err := verifier.VerifyPeerCertificate([][]byte{cert1.Raw, cert2.Raw}, nil, nonce)
866+
assert.Error(t, err)
867+
assert.Contains(t, err.Error(), "attestation extension not found")
868+
})
869+
870+
t.Run("VerifyAttestationExtensionWithNoExtensions", func(t *testing.T) {
871+
cert := createSelfSignedCert(t)
872+
verifier := certificateVerifier{}
873+
nonce := generateNonce(t)
874+
875+
err := verifier.verifyAttestationExtension(cert, nonce)
876+
assert.Error(t, err)
877+
assert.Contains(t, err.Error(), "attestation extension not found")
878+
})
879+
880+
t.Run("VerifyAttestationExtensionWithWrongOID", func(t *testing.T) {
881+
wrongOID := asn1.ObjectIdentifier{1, 2, 3, 4, 5}
882+
extension := pkix.Extension{
883+
Id: wrongOID,
884+
Value: []byte("test-data"),
885+
}
886+
887+
cert := generateTestCertificateWithExtensions(t, []pkix.Extension{extension})
888+
verifier := certificateVerifier{}
889+
nonce := generateNonce(t)
890+
891+
err := verifier.verifyAttestationExtension(cert, nonce)
892+
assert.Error(t, err)
893+
assert.Contains(t, err.Error(), "attestation extension not found")
894+
})
895+
896+
t.Run("VerifyCertificateExtensionPlatformVerifierError", func(t *testing.T) {
897+
verifier := certificateVerifier{}
898+
invalidPlatformType := attestation.PlatformType(999)
899+
900+
err := verifier.verifyCertificateExtension([]byte("test-extension"), []byte("test-pubkey"), []byte("test-nonce"), invalidPlatformType)
901+
assert.Error(t, err)
902+
assert.Contains(t, err.Error(), "unsupported platform type")
903+
})
699904
}
700905

701906
// TestCertificateWithAttestationExtension tests certificates with attestation extensions.

0 commit comments

Comments
 (0)