Skip to content

Commit 3e83705

Browse files
authored
Add TPM option for master key (#1)
Add an option to store master keys on a TPM. When this option is used, the data can only be decrypted with the same TPM.
1 parent 87be173 commit 3e83705

File tree

8 files changed

+592
-254
lines changed

8 files changed

+592
-254
lines changed

crypto/aes.go

Lines changed: 172 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@
2424
package crypto
2525

2626
import (
27+
"crypto"
2728
"crypto/aes"
2829
"crypto/cipher"
2930
"crypto/hmac"
3031
"crypto/rand"
32+
"crypto/rsa"
3133
"crypto/sha256"
3234
"encoding/binary"
3335
"errors"
@@ -38,6 +40,8 @@ import (
3840
"path/filepath"
3941
"runtime"
4042

43+
"github.com/c2FmZQ/tpm"
44+
"golang.org/x/crypto/cryptobyte"
4145
"golang.org/x/crypto/pbkdf2"
4246
)
4347

@@ -58,6 +62,7 @@ type AESKey struct {
5862

5963
logger Logger
6064
strictWipe bool
65+
tpmKey *tpm.Key
6166
}
6267

6368
func (k *AESKey) Logger() Logger {
@@ -98,24 +103,24 @@ type AESMasterKey struct {
98103

99104
// CreateAESMasterKey creates a new master key.
100105
func CreateAESMasterKey(opts ...Option) (MasterKey, error) {
101-
var logger Logger = defaultLogger{}
102-
var strictWipe bool
103-
for _, opt := range opts {
104-
if opt.logger != nil {
105-
logger = opt.logger
106-
}
107-
if opt.strictWipe != nil {
108-
strictWipe = *opt.strictWipe
109-
}
110-
}
106+
var opt option
107+
opt.apply(opts)
111108
b := make([]byte, 64)
112109
if _, err := rand.Read(b); err != nil {
113110
return nil, err
114111
}
115112
key := aesKeyFromBytes(b)
116-
key.logger = logger
117-
key.strictWipe = strictWipe
118-
return &AESMasterKey{key}, nil
113+
key.logger = opt.logger
114+
key.strictWipe = opt.strictWipe
115+
mk := &AESMasterKey{key}
116+
if opt.tpm != nil {
117+
tpmkey, err := opt.tpm.CreateKey()
118+
if err != nil {
119+
return nil, err
120+
}
121+
mk.tpmKey = tpmkey
122+
}
123+
return mk, nil
119124
}
120125

121126
// CreateAESMasterKeyForTest creates a new master key to tests.
@@ -133,51 +138,90 @@ func CreateAESMasterKeyForTest() (MasterKey, error) {
133138

134139
// ReadAESMasterKey reads an encrypted master key from file and decrypts it.
135140
func ReadAESMasterKey(passphrase []byte, file string, opts ...Option) (MasterKey, error) {
136-
var logger Logger = defaultLogger{}
137-
var strictWipe bool
138-
for _, opt := range opts {
139-
if opt.logger != nil {
140-
logger = opt.logger
141-
}
142-
if opt.strictWipe != nil {
143-
strictWipe = *opt.strictWipe
144-
}
145-
}
141+
var opt option
142+
opt.apply(opts)
146143
b, err := os.ReadFile(file)
147144
if err != nil {
148145
return nil, err
149146
}
150147
if len(b) < 64 {
151148
return nil, ErrDecryptFailed
152149
}
153-
version, b := b[0], b[1:]
154-
if version != 1 {
155-
logger.Debugf("ReadMasterKey: unexpected version: %d", version)
150+
str := cryptobyte.String(b)
151+
var version uint8
152+
if !str.ReadUint8(&version) {
156153
return nil, ErrDecryptFailed
157154
}
158-
salt, b := b[:16], b[16:]
159-
numIter, b := int(binary.BigEndian.Uint32(b[:4])), b[4:]
160-
dk := pbkdf2.Key(passphrase, salt, numIter, 32, sha256.New)
155+
if version != 1 && version != 3 {
156+
opt.logger.Debugf("ReadMasterKey: unexpected version: %d", version)
157+
return nil, ErrDecryptFailed
158+
}
159+
if version == 3 && opt.tpm == nil {
160+
opt.logger.Debug("ReadMasterKey: missing WithTPM option")
161+
return nil, ErrDecryptFailed
162+
}
163+
salt := make([]byte, 16)
164+
if !str.ReadBytes(&salt, 16) {
165+
return nil, ErrDecryptFailed
166+
}
167+
var numIter uint32
168+
if !str.ReadUint32(&numIter) {
169+
return nil, ErrDecryptFailed
170+
}
171+
dk := pbkdf2.Key(passphrase, salt, int(numIter), 32, sha256.New)
161172
block, err := aes.NewCipher(dk)
162173
if err != nil {
163-
logger.Debug(err)
174+
opt.logger.Debug(err)
164175
return nil, ErrDecryptFailed
165176
}
166177
gcm, err := cipher.NewGCM(block)
167178
if err != nil {
168-
logger.Debug(err)
179+
opt.logger.Debug(err)
169180
return nil, ErrDecryptFailed
170181
}
171-
nonce := b[:gcm.NonceSize()]
172-
encMasterKey := b[gcm.NonceSize():]
173-
mkBytes, err := gcm.Open(nil, nonce, encMasterKey, nil)
182+
nonce := make([]byte, gcm.NonceSize())
183+
if !str.ReadBytes(&nonce, len(nonce)) {
184+
return nil, ErrDecryptFailed
185+
}
186+
mkBytes, err := gcm.Open(nil, nonce, []byte(str), nil)
174187
if err != nil {
175-
logger.Debug(err)
188+
opt.logger.Debug(err)
176189
return nil, ErrDecryptFailed
177190
}
178-
key := aesKeyFromBytes(mkBytes)
179-
key.logger = logger
180-
key.strictWipe = strictWipe
191+
var key *AESKey
192+
if version == 1 {
193+
key = aesKeyFromBytes(mkBytes)
194+
} else { // version == 3
195+
str := cryptobyte.String(mkBytes)
196+
var length uint16
197+
if !str.ReadUint16(&length) {
198+
return nil, ErrDecryptFailed
199+
}
200+
encKey := make([]byte, length)
201+
if !str.ReadBytes(&encKey, len(encKey)) {
202+
return nil, ErrDecryptFailed
203+
}
204+
if !str.ReadUint16(&length) {
205+
return nil, ErrDecryptFailed
206+
}
207+
tpmCtx := make([]byte, length)
208+
if !str.ReadBytes(&tpmCtx, len(tpmCtx)) {
209+
return nil, ErrDecryptFailed
210+
}
211+
tpmKey, err := opt.tpm.UnmarshalKey(tpmCtx)
212+
if err != nil {
213+
return nil, err
214+
}
215+
decKey, err := tpmKey.Decrypt(nil, encKey, nil)
216+
if err != nil {
217+
opt.logger.Debug(err)
218+
return nil, ErrDecryptFailed
219+
}
220+
key = aesKeyFromBytes(decKey)
221+
key.tpmKey = tpmKey
222+
}
223+
key.logger = opt.logger
224+
key.strictWipe = opt.strictWipe
181225
return &AESMasterKey{key}, nil
182226
}
183227

@@ -191,8 +235,6 @@ func (mk AESMasterKey) Save(passphrase []byte, file string) error {
191235
if len(passphrase) == 0 {
192236
numIter = 10
193237
}
194-
numIterBin := make([]byte, 4)
195-
binary.BigEndian.PutUint32(numIterBin, uint32(numIter))
196238
dk := pbkdf2.Key(passphrase, salt, numIter, 32, sha256.New)
197239
block, err := aes.NewCipher(dk)
198240
if err != nil {
@@ -209,11 +251,44 @@ func (mk AESMasterKey) Save(passphrase []byte, file string) error {
209251
mk.Logger().Debug(err)
210252
return ErrEncryptFailed
211253
}
212-
encMasterKey := gcm.Seal(nonce, nonce, mk.key(), nil)
213-
data := []byte{1} // version
214-
data = append(data, salt...)
215-
data = append(data, numIterBin...)
216-
data = append(data, encMasterKey...)
254+
var version uint8
255+
var payload []byte
256+
if mk.tpmKey == nil {
257+
version = 1
258+
payload = mk.key()
259+
} else {
260+
version = 3
261+
buf := cryptobyte.NewBuilder(nil)
262+
// encKey, err := mk.tpmKey.Encrypt(mk.key())
263+
encKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, mk.tpmKey.Public().(*rsa.PublicKey), mk.key(), nil)
264+
if err != nil {
265+
mk.Logger().Debug(err)
266+
return ErrEncryptFailed
267+
}
268+
buf.AddUint16(uint16(len(encKey)))
269+
buf.AddBytes(encKey)
270+
keyctx, err := mk.tpmKey.Marshal()
271+
if err != nil {
272+
mk.Logger().Debug(err)
273+
return ErrEncryptFailed
274+
}
275+
buf.AddUint16(uint16(len(keyctx)))
276+
buf.AddBytes(keyctx)
277+
if payload, err = buf.Bytes(); err != nil {
278+
mk.Logger().Debug(err)
279+
return ErrEncryptFailed
280+
}
281+
}
282+
encMasterKey := gcm.Seal(nonce, nonce, payload, nil)
283+
buf := cryptobyte.NewBuilder([]byte{version})
284+
buf.AddBytes(salt)
285+
buf.AddUint32(uint32(numIter))
286+
buf.AddBytes(encMasterKey)
287+
data, err := buf.Bytes()
288+
if err != nil {
289+
mk.Logger().Debug(err)
290+
return ErrEncryptFailed
291+
}
217292
dir, _ := filepath.Split(file)
218293
if err := os.MkdirAll(dir, 0700); err != nil {
219294
return err
@@ -237,6 +312,23 @@ func (k AESKey) Hash(b []byte) []byte {
237312

238313
// Decrypt decrypts data that was encrypted with Encrypt and the same key.
239314
func (k AESKey) Decrypt(data []byte) ([]byte, error) {
315+
if k.tpmKey != nil {
316+
sigSize := k.tpmKey.Bits() / 8
317+
if len(data) < 1+sigSize {
318+
return nil, ErrDecryptFailed
319+
}
320+
version, data := data[0], data[1:]
321+
if version != 3 {
322+
return nil, ErrDecryptFailed
323+
}
324+
encData, data := data[:len(data)-sigSize], data[len(data)-sigSize:]
325+
sig := data[:sigSize]
326+
hashed := sha256.Sum256(encData)
327+
if err := rsa.VerifyPKCS1v15(k.tpmKey.Public().(*rsa.PublicKey), crypto.SHA256, hashed[:], sig); err != nil {
328+
return nil, ErrDecryptFailed
329+
}
330+
return k.tpmKey.Decrypt(nil, encData, nil)
331+
}
240332
if len(k.maskedKey) == 0 {
241333
k.Logger().Fatal("key is not set")
242334
}
@@ -274,6 +366,23 @@ func (k AESKey) Decrypt(data []byte) ([]byte, error) {
274366

275367
// Encrypt encrypts data using the key.
276368
func (k AESKey) Encrypt(data []byte) ([]byte, error) {
369+
if k.tpmKey != nil {
370+
// encData, err := k.tpmKey.Encrypt(data)
371+
encData, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, k.tpmKey.Public().(*rsa.PublicKey), data, nil)
372+
if err != nil {
373+
return nil, ErrEncryptFailed
374+
}
375+
hashed := sha256.Sum256(encData)
376+
sig, err := k.tpmKey.Sign(nil, hashed[:], crypto.SHA256)
377+
if err != nil {
378+
return nil, ErrEncryptFailed
379+
}
380+
out := make([]byte, 1+len(encData)+len(sig))
381+
out[0] = 3 // version
382+
copy(out[1:], encData)
383+
copy(out[1+len(encData):], sig)
384+
return out, nil
385+
}
277386
if len(k.maskedKey) == 0 {
278387
k.Logger().Fatal("key is not set")
279388
}
@@ -347,10 +456,17 @@ func (k AESKey) NewKey() (EncryptionKey, error) {
347456
return ek, nil
348457
}
349458

459+
func (k AESKey) keysize() int {
460+
if k.tpmKey != nil {
461+
return 2*k.tpmKey.Bits()/8 + 1
462+
}
463+
return aesEncryptedKeySize
464+
}
465+
350466
// DecryptKey decrypts an encrypted key.
351467
func (k AESKey) DecryptKey(encryptedKey []byte) (EncryptionKey, error) {
352-
if len(encryptedKey) != aesEncryptedKeySize {
353-
k.Logger().Debugf("DecryptKey: unexpected encrypted key size %d != %d", len(encryptedKey), aesEncryptedKeySize)
468+
if len(encryptedKey) != k.keysize() {
469+
k.Logger().Debugf("DecryptKey: unexpected encrypted key size %d != %d", len(encryptedKey), k.keysize())
354470
return nil, ErrDecryptFailed
355471
}
356472
b, err := k.Decrypt(encryptedKey)
@@ -509,6 +625,9 @@ func (r *AESStreamReader) Close() error {
509625

510626
// StartReader opens a reader to decrypt a stream of data.
511627
func (k AESKey) StartReader(ctx []byte, r io.Reader) (StreamReader, error) {
628+
if k.tpmKey != nil {
629+
return nil, errors.New("operation not supported with TPM key")
630+
}
512631
var start int64
513632
if seeker, ok := r.(io.Seeker); ok {
514633
off, err := seeker.Seek(0, io.SeekCurrent)
@@ -577,6 +696,9 @@ func (w *AESStreamWriter) Close() (err error) {
577696

578697
// StartWriter opens a writer to encrypt a stream of data.
579698
func (k AESKey) StartWriter(ctx []byte, w io.Writer) (StreamWriter, error) {
699+
if k.tpmKey != nil {
700+
return nil, errors.New("operation not supported with TPM key")
701+
}
580702
block, err := aes.NewCipher(k.key()[:32])
581703
if err != nil {
582704
k.Logger().Debug(err)
@@ -592,7 +714,7 @@ func (k AESKey) StartWriter(ctx []byte, w io.Writer) (StreamWriter, error) {
592714

593715
// ReadEncryptedKey reads an encrypted key and decrypts it.
594716
func (k AESKey) ReadEncryptedKey(r io.Reader) (EncryptionKey, error) {
595-
buf := make([]byte, aesEncryptedKeySize)
717+
buf := make([]byte, k.keysize())
596718
if _, err := io.ReadFull(r, buf); err != nil {
597719
k.Logger().Debug(err)
598720
return nil, ErrDecryptFailed
@@ -603,8 +725,8 @@ func (k AESKey) ReadEncryptedKey(r io.Reader) (EncryptionKey, error) {
603725
// WriteEncryptedKey writes the encrypted key to the writer.
604726
func (k AESKey) WriteEncryptedKey(w io.Writer) error {
605727
n, err := w.Write(k.encryptedKey)
606-
if n != aesEncryptedKeySize {
607-
k.Logger().Debugf("WriteEncryptedKey: unexpected key size: %d != %d", n, aesEncryptedKeySize)
728+
if n == 0 {
729+
k.Logger().Debugf("WriteEncryptedKey: unexpected key size: %d", n)
608730
return ErrEncryptFailed
609731
}
610732
return err

0 commit comments

Comments
 (0)