24
24
package crypto
25
25
26
26
import (
27
+ "crypto"
27
28
"crypto/aes"
28
29
"crypto/cipher"
29
30
"crypto/hmac"
30
31
"crypto/rand"
32
+ "crypto/rsa"
31
33
"crypto/sha256"
32
34
"encoding/binary"
33
35
"errors"
@@ -38,6 +40,8 @@ import (
38
40
"path/filepath"
39
41
"runtime"
40
42
43
+ "github.com/c2FmZQ/tpm"
44
+ "golang.org/x/crypto/cryptobyte"
41
45
"golang.org/x/crypto/pbkdf2"
42
46
)
43
47
@@ -58,6 +62,7 @@ type AESKey struct {
58
62
59
63
logger Logger
60
64
strictWipe bool
65
+ tpmKey * tpm.Key
61
66
}
62
67
63
68
func (k * AESKey ) Logger () Logger {
@@ -98,24 +103,24 @@ type AESMasterKey struct {
98
103
99
104
// CreateAESMasterKey creates a new master key.
100
105
func CreateAESMasterKey (opts ... Option ) (MasterKey , error ) {
101
- var logger Logger = defaultLogger {}
102
- var strictWipe bool
103
- for _ , opt := range opts {
104
- if opt .logger != nil {
105
- logger = opt .logger
106
- }
107
- if opt .strictWipe != nil {
108
- strictWipe = * opt .strictWipe
109
- }
110
- }
106
+ var opt option
107
+ opt .apply (opts )
111
108
b := make ([]byte , 64 )
112
109
if _ , err := rand .Read (b ); err != nil {
113
110
return nil , err
114
111
}
115
112
key := aesKeyFromBytes (b )
116
- key .logger = logger
117
- key .strictWipe = strictWipe
118
- return & AESMasterKey {key }, nil
113
+ key .logger = opt .logger
114
+ key .strictWipe = opt .strictWipe
115
+ mk := & AESMasterKey {key }
116
+ if opt .tpm != nil {
117
+ tpmkey , err := opt .tpm .CreateKey ()
118
+ if err != nil {
119
+ return nil , err
120
+ }
121
+ mk .tpmKey = tpmkey
122
+ }
123
+ return mk , nil
119
124
}
120
125
121
126
// CreateAESMasterKeyForTest creates a new master key to tests.
@@ -133,51 +138,90 @@ func CreateAESMasterKeyForTest() (MasterKey, error) {
133
138
134
139
// ReadAESMasterKey reads an encrypted master key from file and decrypts it.
135
140
func ReadAESMasterKey (passphrase []byte , file string , opts ... Option ) (MasterKey , error ) {
136
- var logger Logger = defaultLogger {}
137
- var strictWipe bool
138
- for _ , opt := range opts {
139
- if opt .logger != nil {
140
- logger = opt .logger
141
- }
142
- if opt .strictWipe != nil {
143
- strictWipe = * opt .strictWipe
144
- }
145
- }
141
+ var opt option
142
+ opt .apply (opts )
146
143
b , err := os .ReadFile (file )
147
144
if err != nil {
148
145
return nil , err
149
146
}
150
147
if len (b ) < 64 {
151
148
return nil , ErrDecryptFailed
152
149
}
153
- version , b := b [ 0 ], b [ 1 :]
154
- if version != 1 {
155
- logger . Debugf ( "ReadMasterKey: unexpected version: %d" , version )
150
+ str := cryptobyte . String ( b )
151
+ var version uint8
152
+ if ! str . ReadUint8 ( & version ) {
156
153
return nil , ErrDecryptFailed
157
154
}
158
- salt , b := b [:16 ], b [16 :]
159
- numIter , b := int (binary .BigEndian .Uint32 (b [:4 ])), b [4 :]
160
- dk := pbkdf2 .Key (passphrase , salt , numIter , 32 , sha256 .New )
155
+ if version != 1 && version != 3 {
156
+ opt .logger .Debugf ("ReadMasterKey: unexpected version: %d" , version )
157
+ return nil , ErrDecryptFailed
158
+ }
159
+ if version == 3 && opt .tpm == nil {
160
+ opt .logger .Debug ("ReadMasterKey: missing WithTPM option" )
161
+ return nil , ErrDecryptFailed
162
+ }
163
+ salt := make ([]byte , 16 )
164
+ if ! str .ReadBytes (& salt , 16 ) {
165
+ return nil , ErrDecryptFailed
166
+ }
167
+ var numIter uint32
168
+ if ! str .ReadUint32 (& numIter ) {
169
+ return nil , ErrDecryptFailed
170
+ }
171
+ dk := pbkdf2 .Key (passphrase , salt , int (numIter ), 32 , sha256 .New )
161
172
block , err := aes .NewCipher (dk )
162
173
if err != nil {
163
- logger .Debug (err )
174
+ opt . logger .Debug (err )
164
175
return nil , ErrDecryptFailed
165
176
}
166
177
gcm , err := cipher .NewGCM (block )
167
178
if err != nil {
168
- logger .Debug (err )
179
+ opt . logger .Debug (err )
169
180
return nil , ErrDecryptFailed
170
181
}
171
- nonce := b [:gcm .NonceSize ()]
172
- encMasterKey := b [gcm .NonceSize ():]
173
- mkBytes , err := gcm .Open (nil , nonce , encMasterKey , nil )
182
+ nonce := make ([]byte , gcm .NonceSize ())
183
+ if ! str .ReadBytes (& nonce , len (nonce )) {
184
+ return nil , ErrDecryptFailed
185
+ }
186
+ mkBytes , err := gcm .Open (nil , nonce , []byte (str ), nil )
174
187
if err != nil {
175
- logger .Debug (err )
188
+ opt . logger .Debug (err )
176
189
return nil , ErrDecryptFailed
177
190
}
178
- key := aesKeyFromBytes (mkBytes )
179
- key .logger = logger
180
- key .strictWipe = strictWipe
191
+ var key * AESKey
192
+ if version == 1 {
193
+ key = aesKeyFromBytes (mkBytes )
194
+ } else { // version == 3
195
+ str := cryptobyte .String (mkBytes )
196
+ var length uint16
197
+ if ! str .ReadUint16 (& length ) {
198
+ return nil , ErrDecryptFailed
199
+ }
200
+ encKey := make ([]byte , length )
201
+ if ! str .ReadBytes (& encKey , len (encKey )) {
202
+ return nil , ErrDecryptFailed
203
+ }
204
+ if ! str .ReadUint16 (& length ) {
205
+ return nil , ErrDecryptFailed
206
+ }
207
+ tpmCtx := make ([]byte , length )
208
+ if ! str .ReadBytes (& tpmCtx , len (tpmCtx )) {
209
+ return nil , ErrDecryptFailed
210
+ }
211
+ tpmKey , err := opt .tpm .UnmarshalKey (tpmCtx )
212
+ if err != nil {
213
+ return nil , err
214
+ }
215
+ decKey , err := tpmKey .Decrypt (nil , encKey , nil )
216
+ if err != nil {
217
+ opt .logger .Debug (err )
218
+ return nil , ErrDecryptFailed
219
+ }
220
+ key = aesKeyFromBytes (decKey )
221
+ key .tpmKey = tpmKey
222
+ }
223
+ key .logger = opt .logger
224
+ key .strictWipe = opt .strictWipe
181
225
return & AESMasterKey {key }, nil
182
226
}
183
227
@@ -191,8 +235,6 @@ func (mk AESMasterKey) Save(passphrase []byte, file string) error {
191
235
if len (passphrase ) == 0 {
192
236
numIter = 10
193
237
}
194
- numIterBin := make ([]byte , 4 )
195
- binary .BigEndian .PutUint32 (numIterBin , uint32 (numIter ))
196
238
dk := pbkdf2 .Key (passphrase , salt , numIter , 32 , sha256 .New )
197
239
block , err := aes .NewCipher (dk )
198
240
if err != nil {
@@ -209,11 +251,44 @@ func (mk AESMasterKey) Save(passphrase []byte, file string) error {
209
251
mk .Logger ().Debug (err )
210
252
return ErrEncryptFailed
211
253
}
212
- encMasterKey := gcm .Seal (nonce , nonce , mk .key (), nil )
213
- data := []byte {1 } // version
214
- data = append (data , salt ... )
215
- data = append (data , numIterBin ... )
216
- data = append (data , encMasterKey ... )
254
+ var version uint8
255
+ var payload []byte
256
+ if mk .tpmKey == nil {
257
+ version = 1
258
+ payload = mk .key ()
259
+ } else {
260
+ version = 3
261
+ buf := cryptobyte .NewBuilder (nil )
262
+ // encKey, err := mk.tpmKey.Encrypt(mk.key())
263
+ encKey , err := rsa .EncryptOAEP (sha256 .New (), rand .Reader , mk .tpmKey .Public ().(* rsa.PublicKey ), mk .key (), nil )
264
+ if err != nil {
265
+ mk .Logger ().Debug (err )
266
+ return ErrEncryptFailed
267
+ }
268
+ buf .AddUint16 (uint16 (len (encKey )))
269
+ buf .AddBytes (encKey )
270
+ keyctx , err := mk .tpmKey .Marshal ()
271
+ if err != nil {
272
+ mk .Logger ().Debug (err )
273
+ return ErrEncryptFailed
274
+ }
275
+ buf .AddUint16 (uint16 (len (keyctx )))
276
+ buf .AddBytes (keyctx )
277
+ if payload , err = buf .Bytes (); err != nil {
278
+ mk .Logger ().Debug (err )
279
+ return ErrEncryptFailed
280
+ }
281
+ }
282
+ encMasterKey := gcm .Seal (nonce , nonce , payload , nil )
283
+ buf := cryptobyte .NewBuilder ([]byte {version })
284
+ buf .AddBytes (salt )
285
+ buf .AddUint32 (uint32 (numIter ))
286
+ buf .AddBytes (encMasterKey )
287
+ data , err := buf .Bytes ()
288
+ if err != nil {
289
+ mk .Logger ().Debug (err )
290
+ return ErrEncryptFailed
291
+ }
217
292
dir , _ := filepath .Split (file )
218
293
if err := os .MkdirAll (dir , 0700 ); err != nil {
219
294
return err
@@ -237,6 +312,23 @@ func (k AESKey) Hash(b []byte) []byte {
237
312
238
313
// Decrypt decrypts data that was encrypted with Encrypt and the same key.
239
314
func (k AESKey ) Decrypt (data []byte ) ([]byte , error ) {
315
+ if k .tpmKey != nil {
316
+ sigSize := k .tpmKey .Bits () / 8
317
+ if len (data ) < 1 + sigSize {
318
+ return nil , ErrDecryptFailed
319
+ }
320
+ version , data := data [0 ], data [1 :]
321
+ if version != 3 {
322
+ return nil , ErrDecryptFailed
323
+ }
324
+ encData , data := data [:len (data )- sigSize ], data [len (data )- sigSize :]
325
+ sig := data [:sigSize ]
326
+ hashed := sha256 .Sum256 (encData )
327
+ if err := rsa .VerifyPKCS1v15 (k .tpmKey .Public ().(* rsa.PublicKey ), crypto .SHA256 , hashed [:], sig ); err != nil {
328
+ return nil , ErrDecryptFailed
329
+ }
330
+ return k .tpmKey .Decrypt (nil , encData , nil )
331
+ }
240
332
if len (k .maskedKey ) == 0 {
241
333
k .Logger ().Fatal ("key is not set" )
242
334
}
@@ -274,6 +366,23 @@ func (k AESKey) Decrypt(data []byte) ([]byte, error) {
274
366
275
367
// Encrypt encrypts data using the key.
276
368
func (k AESKey ) Encrypt (data []byte ) ([]byte , error ) {
369
+ if k .tpmKey != nil {
370
+ // encData, err := k.tpmKey.Encrypt(data)
371
+ encData , err := rsa .EncryptOAEP (sha256 .New (), rand .Reader , k .tpmKey .Public ().(* rsa.PublicKey ), data , nil )
372
+ if err != nil {
373
+ return nil , ErrEncryptFailed
374
+ }
375
+ hashed := sha256 .Sum256 (encData )
376
+ sig , err := k .tpmKey .Sign (nil , hashed [:], crypto .SHA256 )
377
+ if err != nil {
378
+ return nil , ErrEncryptFailed
379
+ }
380
+ out := make ([]byte , 1 + len (encData )+ len (sig ))
381
+ out [0 ] = 3 // version
382
+ copy (out [1 :], encData )
383
+ copy (out [1 + len (encData ):], sig )
384
+ return out , nil
385
+ }
277
386
if len (k .maskedKey ) == 0 {
278
387
k .Logger ().Fatal ("key is not set" )
279
388
}
@@ -347,10 +456,17 @@ func (k AESKey) NewKey() (EncryptionKey, error) {
347
456
return ek , nil
348
457
}
349
458
459
+ func (k AESKey ) keysize () int {
460
+ if k .tpmKey != nil {
461
+ return 2 * k .tpmKey .Bits ()/ 8 + 1
462
+ }
463
+ return aesEncryptedKeySize
464
+ }
465
+
350
466
// DecryptKey decrypts an encrypted key.
351
467
func (k AESKey ) DecryptKey (encryptedKey []byte ) (EncryptionKey , error ) {
352
- if len (encryptedKey ) != aesEncryptedKeySize {
353
- k .Logger ().Debugf ("DecryptKey: unexpected encrypted key size %d != %d" , len (encryptedKey ), aesEncryptedKeySize )
468
+ if len (encryptedKey ) != k . keysize () {
469
+ k .Logger ().Debugf ("DecryptKey: unexpected encrypted key size %d != %d" , len (encryptedKey ), k . keysize () )
354
470
return nil , ErrDecryptFailed
355
471
}
356
472
b , err := k .Decrypt (encryptedKey )
@@ -509,6 +625,9 @@ func (r *AESStreamReader) Close() error {
509
625
510
626
// StartReader opens a reader to decrypt a stream of data.
511
627
func (k AESKey ) StartReader (ctx []byte , r io.Reader ) (StreamReader , error ) {
628
+ if k .tpmKey != nil {
629
+ return nil , errors .New ("operation not supported with TPM key" )
630
+ }
512
631
var start int64
513
632
if seeker , ok := r .(io.Seeker ); ok {
514
633
off , err := seeker .Seek (0 , io .SeekCurrent )
@@ -577,6 +696,9 @@ func (w *AESStreamWriter) Close() (err error) {
577
696
578
697
// StartWriter opens a writer to encrypt a stream of data.
579
698
func (k AESKey ) StartWriter (ctx []byte , w io.Writer ) (StreamWriter , error ) {
699
+ if k .tpmKey != nil {
700
+ return nil , errors .New ("operation not supported with TPM key" )
701
+ }
580
702
block , err := aes .NewCipher (k .key ()[:32 ])
581
703
if err != nil {
582
704
k .Logger ().Debug (err )
@@ -592,7 +714,7 @@ func (k AESKey) StartWriter(ctx []byte, w io.Writer) (StreamWriter, error) {
592
714
593
715
// ReadEncryptedKey reads an encrypted key and decrypts it.
594
716
func (k AESKey ) ReadEncryptedKey (r io.Reader ) (EncryptionKey , error ) {
595
- buf := make ([]byte , aesEncryptedKeySize )
717
+ buf := make ([]byte , k . keysize () )
596
718
if _ , err := io .ReadFull (r , buf ); err != nil {
597
719
k .Logger ().Debug (err )
598
720
return nil , ErrDecryptFailed
@@ -603,8 +725,8 @@ func (k AESKey) ReadEncryptedKey(r io.Reader) (EncryptionKey, error) {
603
725
// WriteEncryptedKey writes the encrypted key to the writer.
604
726
func (k AESKey ) WriteEncryptedKey (w io.Writer ) error {
605
727
n , err := w .Write (k .encryptedKey )
606
- if n != aesEncryptedKeySize {
607
- k .Logger ().Debugf ("WriteEncryptedKey: unexpected key size: %d != %d " , n , aesEncryptedKeySize )
728
+ if n == 0 {
729
+ k .Logger ().Debugf ("WriteEncryptedKey: unexpected key size: %d" , n )
608
730
return ErrEncryptFailed
609
731
}
610
732
return err
0 commit comments