Skip to content

Commit 1412681

Browse files
authored
Merge pull request #580 from smallstep/josh/tpm-capalgs
Add method to obtain TPM capabilities
2 parents c7de661 + 6463150 commit 1412681

File tree

9 files changed

+485
-22
lines changed

9 files changed

+485
-22
lines changed

kms/tpmkms/tpmkms.go

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"go.step.sm/crypto/kms/apiv1"
2929
"go.step.sm/crypto/kms/uri"
3030
"go.step.sm/crypto/tpm"
31+
"go.step.sm/crypto/tpm/algorithm"
3132
"go.step.sm/crypto/tpm/attestation"
3233
"go.step.sm/crypto/tpm/storage"
3334
"go.step.sm/crypto/tpm/tss2"
@@ -39,6 +40,32 @@ func init() {
3940
})
4041
}
4142

43+
// PreferredSignatureAlgorithms indicates the preferred selection of signature
44+
// algorithms when an explicit value is omitted in CreateKeyRequest
45+
var preferredSignatureAlgorithms []apiv1.SignatureAlgorithm
46+
47+
// SetPreferredSignatureAlgorithms sets the preferred signature algorithms
48+
// to select from when explicit values are omitted in CreateKeyRequest
49+
//
50+
// # Experimental
51+
//
52+
// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
53+
// release.
54+
func SetPreferredSignatureAlgorithms(algs []apiv1.SignatureAlgorithm) {
55+
preferredSignatureAlgorithms = algs
56+
}
57+
58+
// PreferredSignatureAlgorithms returns the preferred signature algorithms
59+
// to select from when explicit values are omitted in CreateKeyRequest
60+
//
61+
// # Experimental
62+
//
63+
// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
64+
// release.
65+
func PreferredSignatureAlgorithms() []apiv1.SignatureAlgorithm {
66+
return preferredSignatureAlgorithms
67+
}
68+
4269
// Scheme is the scheme used in TPM KMS URIs, the string "tpmkms".
4370
const Scheme = string(apiv1.TPMKMS)
4471

@@ -73,21 +100,22 @@ type TPMKMS struct {
73100
}
74101

75102
type algorithmAttributes struct {
76-
Type string
77-
Curve int
103+
Type string
104+
Curve int
105+
Requires []algorithm.Algorithm
78106
}
79107

80108
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]algorithmAttributes{
81-
apiv1.UnspecifiedSignAlgorithm: {"RSA", -1},
82-
apiv1.SHA256WithRSA: {"RSA", -1},
83-
apiv1.SHA384WithRSA: {"RSA", -1},
84-
apiv1.SHA512WithRSA: {"RSA", -1},
85-
apiv1.SHA256WithRSAPSS: {"RSA", -1},
86-
apiv1.SHA384WithRSAPSS: {"RSA", -1},
87-
apiv1.SHA512WithRSAPSS: {"RSA", -1},
88-
apiv1.ECDSAWithSHA256: {"ECDSA", 256},
89-
apiv1.ECDSAWithSHA384: {"ECDSA", 384},
90-
apiv1.ECDSAWithSHA512: {"ECDSA", 521},
109+
apiv1.UnspecifiedSignAlgorithm: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSA}},
110+
apiv1.SHA256WithRSA: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSA, algorithm.AlgorithmSHA256}},
111+
apiv1.SHA384WithRSA: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSA, algorithm.AlgorithmSHA384}},
112+
apiv1.SHA512WithRSA: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSA, algorithm.AlgorithmSHA512}},
113+
apiv1.SHA256WithRSAPSS: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSAPSS, algorithm.AlgorithmSHA256}},
114+
apiv1.SHA384WithRSAPSS: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSAPSS, algorithm.AlgorithmSHA384}},
115+
apiv1.SHA512WithRSAPSS: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSAPSS, algorithm.AlgorithmSHA512}},
116+
apiv1.ECDSAWithSHA256: {"ECDSA", 256, []algorithm.Algorithm{algorithm.AlgorithmECDSA, algorithm.AlgorithmSHA256}},
117+
apiv1.ECDSAWithSHA384: {"ECDSA", 384, []algorithm.Algorithm{algorithm.AlgorithmECDSA, algorithm.AlgorithmSHA384}},
118+
apiv1.ECDSAWithSHA512: {"ECDSA", 521, []algorithm.Algorithm{algorithm.AlgorithmECDSA, algorithm.AlgorithmSHA512}},
91119
}
92120

93121
const (
@@ -326,9 +354,36 @@ func (k *TPMKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons
326354
return nil, fmt.Errorf("failed parsing %q: %w", req.Name, err)
327355
}
328356

329-
v, ok := signatureAlgorithmMapping[req.SignatureAlgorithm]
330-
if !ok {
331-
return nil, fmt.Errorf("TPMKMS does not support signature algorithm %q", req.SignatureAlgorithm)
357+
ctx := context.Background()
358+
caps, err := k.tpm.GetCapabilities(ctx)
359+
if err != nil {
360+
return nil, fmt.Errorf("could not get TPM capabilities: %w", err)
361+
}
362+
363+
var (
364+
v algorithmAttributes
365+
ok bool
366+
)
367+
if !properties.ak && req.SignatureAlgorithm == apiv1.UnspecifiedSignAlgorithm && len(preferredSignatureAlgorithms) > 0 {
368+
for _, alg := range preferredSignatureAlgorithms {
369+
v, ok = signatureAlgorithmMapping[alg]
370+
if !ok {
371+
return nil, fmt.Errorf("TPMKMS does not support signature algorithm %q", alg)
372+
}
373+
374+
if caps.SupportsAlgorithms(v.Requires) {
375+
break
376+
}
377+
}
378+
} else {
379+
v, ok = signatureAlgorithmMapping[req.SignatureAlgorithm]
380+
if !ok {
381+
return nil, fmt.Errorf("TPMKMS does not support signature algorithm %q", req.SignatureAlgorithm)
382+
}
383+
384+
if !caps.SupportsAlgorithms(v.Requires) {
385+
return nil, fmt.Errorf("signature algorithm %q not supported by the TPM device", req.SignatureAlgorithm)
386+
}
332387
}
333388

334389
if properties.ak && v.Type == "ECDSA" {
@@ -348,8 +403,6 @@ func (k *TPMKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons
348403
size = v.Curve
349404
}
350405

351-
ctx := context.Background()
352-
353406
var privateKey any
354407
if properties.ak {
355408
ak, err := k.tpm.CreateAK(ctx, properties.name) // NOTE: size is never passed for AKs; it's hardcoded to 2048 in lower levels.

kms/tpmkms/tpmkms_simulator_test.go

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,19 @@ import (
3737
"go.step.sm/crypto/tpm/tss2"
3838
)
3939

40-
type newSimulatedTPMOption func(t *testing.T, tpm *tpmp.TPM)
40+
type newSimulatedTPMOption any
4141

42-
func withAK(name string) newSimulatedTPMOption {
42+
type newSimulatedTPMPreparerOption func(t *testing.T, tpm *tpmp.TPM)
43+
44+
func withAK(name string) newSimulatedTPMPreparerOption {
4345
return func(t *testing.T, tpm *tpmp.TPM) {
4446
t.Helper()
4547
_, err := tpm.CreateAK(context.Background(), name)
4648
require.NoError(t, err)
4749
}
4850
}
4951

50-
func withKey(name string) newSimulatedTPMOption {
52+
func withKey(name string) newSimulatedTPMPreparerOption {
5153
return func(t *testing.T, tpm *tpmp.TPM) {
5254
t.Helper()
5355
config := tpmp.CreateKeyConfig{
@@ -59,14 +61,38 @@ func withKey(name string) newSimulatedTPMOption {
5961
}
6062
}
6163

64+
func withCapabilities(caps *tpmp.Capabilities) tpmp.NewTPMOption {
65+
return tpmp.WithCapabilities(caps)
66+
}
67+
6268
func newSimulatedTPM(t *testing.T, opts ...newSimulatedTPMOption) *tpmp.TPM {
6369
t.Helper()
70+
6471
tmpDir := t.TempDir()
65-
tpm, err := tpmp.New(withSimulator(t), tpmp.WithStore(storage.NewDirstore(tmpDir)))
72+
tpmOpts := []tpmp.NewTPMOption{
73+
withSimulator(t),
74+
tpmp.WithStore(storage.NewDirstore(tmpDir)),
75+
}
76+
77+
var preparers []newSimulatedTPMPreparerOption
78+
for _, opt := range opts {
79+
switch o := opt.(type) {
80+
case tpmp.NewTPMOption:
81+
tpmOpts = append(tpmOpts, o)
82+
case newSimulatedTPMPreparerOption:
83+
preparers = append(preparers, o)
84+
default:
85+
require.Fail(t, "invalid TPM option type provided", `TPM option type "%T"`, o)
86+
}
87+
}
88+
89+
tpm, err := tpmp.New(tpmOpts...)
6690
require.NoError(t, err)
67-
for _, applyTo := range opts {
91+
92+
for _, applyTo := range preparers {
6893
applyTo(t, tpm)
6994
}
95+
7096
return tpm
7197
}
7298

@@ -87,6 +113,60 @@ func withSimulator(t *testing.T) tpmp.NewTPMOption {
87113
return tpmp.WithSimulator(sim)
88114
}
89115

116+
func TestTPMKMS_CreateKey_Capabilities(t *testing.T) {
117+
tpmWithNoCaps := newSimulatedTPM(t, withCapabilities(&tpmp.Capabilities{}))
118+
type fields struct {
119+
tpm *tpmp.TPM
120+
}
121+
type args struct {
122+
req *apiv1.CreateKeyRequest
123+
}
124+
tests := []struct {
125+
name string
126+
fields fields
127+
args args
128+
assertFunc assert.ValueAssertionFunc
129+
expErr error
130+
}{
131+
{
132+
name: "fail/unsupported-algorithm",
133+
fields: fields{
134+
tpm: tpmWithNoCaps,
135+
},
136+
args: args{
137+
req: &apiv1.CreateKeyRequest{
138+
Name: "tpmkms:name=key1",
139+
SignatureAlgorithm: apiv1.SHA256WithRSA,
140+
Bits: 2048,
141+
},
142+
},
143+
assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool {
144+
if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) {
145+
r, _ := i1.(*apiv1.CreateKeyResponse)
146+
return assert.Nil(t, r)
147+
}
148+
return false
149+
},
150+
expErr: errors.New(`signature algorithm "SHA256-RSA" not supported by the TPM device`),
151+
},
152+
}
153+
for _, tt := range tests {
154+
t.Run(tt.name, func(t *testing.T) {
155+
k := &TPMKMS{
156+
tpm: tt.fields.tpm,
157+
}
158+
got, err := k.CreateKey(tt.args.req)
159+
if tt.expErr != nil {
160+
assert.EqualError(t, err, tt.expErr.Error())
161+
return
162+
}
163+
164+
assert.NoError(t, err)
165+
assert.True(t, tt.assertFunc(t, got))
166+
})
167+
}
168+
}
169+
90170
func TestTPMKMS_CreateKey(t *testing.T) {
91171
tpmWithAK := newSimulatedTPM(t, withAK("ak1"))
92172
type fields struct {

kms/tpmkms/tpmkms_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,17 @@ func Test_notFoundError(t *testing.T) {
138138
})
139139
}
140140
}
141+
142+
func Test_SetPreferredSignatureAlgorithms(t *testing.T) {
143+
old := preferredSignatureAlgorithms
144+
want := []apiv1.SignatureAlgorithm{
145+
apiv1.ECDSAWithSHA256,
146+
}
147+
SetPreferredSignatureAlgorithms(want)
148+
assert.Equal(t, preferredSignatureAlgorithms, want)
149+
SetPreferredSignatureAlgorithms(old)
150+
}
151+
152+
func Test_PreferredSignatureAlgorithms(t *testing.T) {
153+
assert.Equal(t, PreferredSignatureAlgorithms(), preferredSignatureAlgorithms)
154+
}

tpm/algorithm/algorithm.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package algorithm
2+
3+
import (
4+
"encoding/json"
5+
)
6+
7+
// Supported Algorithms.
8+
const (
9+
AlgorithmUnknown Algorithm = 0x0000
10+
AlgorithmRSA Algorithm = 0x0001
11+
Algorithm3DES Algorithm = 0x0003
12+
AlgorithmSHA1 Algorithm = 0x0004
13+
AlgorithmHMAC Algorithm = 0x0005
14+
AlgorithmAES Algorithm = 0x0006
15+
AlgorithmMGF1 Algorithm = 0x0007
16+
AlgorithmKeyedHash Algorithm = 0x0008
17+
AlgorithmXOR Algorithm = 0x000A
18+
AlgorithmSHA256 Algorithm = 0x000B
19+
AlgorithmSHA384 Algorithm = 0x000C
20+
AlgorithmSHA512 Algorithm = 0x000D
21+
AlgorithmNull Algorithm = 0x0010
22+
AlgorithmSM3256 Algorithm = 0x0012
23+
AlgorithmSM4 Algorithm = 0x0013
24+
AlgorithmRSASSA Algorithm = 0x0014
25+
AlgorithmRSAES Algorithm = 0x0015
26+
AlgorithmRSAPSS Algorithm = 0x0016
27+
AlgorithmOAEP Algorithm = 0x0017
28+
AlgorithmECDSA Algorithm = 0x0018
29+
AlgorithmECDH Algorithm = 0x0019
30+
AlgorithmECDAA Algorithm = 0x001A
31+
AlgorithmECSchnorr Algorithm = 0x001C
32+
AlgorithmKDF1_56A Algorithm = 0x0020
33+
AlgorithmKDF2 Algorithm = 0x0021
34+
AlgorithmKDF1_108 Algorithm = 0x0022
35+
AlgorithmECC Algorithm = 0x0023
36+
AlgorithmSymCipher Algorithm = 0x0025
37+
AlgorithmCamellia Algorithm = 0x0026
38+
AlgorithmSHA3_256 Algorithm = 0x0027
39+
AlgorithmSHA3_384 Algorithm = 0x0028
40+
AlgorithmSHA3_512 Algorithm = 0x0029
41+
AlgorithmCMAC Algorithm = 0x003F
42+
AlgorithmCTR Algorithm = 0x0040
43+
AlgorithmOFB Algorithm = 0x0041
44+
AlgorithmCBC Algorithm = 0x0042
45+
AlgorithmCFB Algorithm = 0x0043
46+
AlgorithmECB Algorithm = 0x0044
47+
)
48+
49+
// https://trustedcomputinggroup.org/wp-content/uploads/TCG_TPM2_r1p59_Part2_Structures_pub.pdf
50+
var algs = map[Algorithm]string{
51+
// object types
52+
AlgorithmRSA: "RSA",
53+
AlgorithmECC: "ECC",
54+
55+
// encryption algs
56+
AlgorithmRSAES: "RSAES",
57+
58+
// block ciphers
59+
Algorithm3DES: "3DES",
60+
AlgorithmAES: "AES",
61+
AlgorithmCamellia: "Camellia",
62+
AlgorithmECB: "ECB",
63+
AlgorithmCFB: "CFB",
64+
AlgorithmOFB: "OFB",
65+
AlgorithmCBC: "CBC",
66+
AlgorithmCTR: "CTR",
67+
AlgorithmSymCipher: "Symmetric Cipher",
68+
AlgorithmCMAC: "CMAC",
69+
70+
// other ciphers
71+
AlgorithmXOR: "XOR",
72+
AlgorithmNull: "Null Cipher",
73+
74+
// hash algs
75+
AlgorithmSHA1: "SHA-1",
76+
AlgorithmHMAC: "HMAC",
77+
AlgorithmMGF1: "MGF1",
78+
AlgorithmKeyedHash: "Keyed Hash",
79+
AlgorithmSM3256: "SM3-256",
80+
AlgorithmSHA256: "SHA-256",
81+
AlgorithmSHA384: "SHA-384",
82+
AlgorithmSHA512: "SHA-512",
83+
AlgorithmSHA3_256: "SHA3-256",
84+
AlgorithmSHA3_384: "SHA3-384",
85+
AlgorithmSHA3_512: "SHA3-512",
86+
87+
// signature algs
88+
AlgorithmSM4: "SM4",
89+
AlgorithmRSASSA: "RSA-SSA",
90+
AlgorithmRSAPSS: "RSA-PSS",
91+
AlgorithmECDSA: "ECDSA",
92+
AlgorithmECDAA: "ECDAA",
93+
AlgorithmECSchnorr: "EC-Schnorr",
94+
95+
// encryption schemes
96+
AlgorithmOAEP: "OAEP",
97+
AlgorithmECDH: "ECDH",
98+
99+
// key derivation
100+
AlgorithmKDF1_56A: "KDF1-SP800-56A",
101+
AlgorithmKDF1_108: "KDF1-SP800-108",
102+
AlgorithmKDF2: "KDF2",
103+
}
104+
105+
type Algorithm uint16
106+
107+
func (a Algorithm) String() string {
108+
return algs[Algorithm(int(a))]
109+
}
110+
111+
func (a Algorithm) MarshalJSON() ([]byte, error) {
112+
return json.Marshal(a.String())
113+
}

0 commit comments

Comments
 (0)