diff --git a/const.go b/const.go index e4aaf3bb..9f99a9dc 100644 --- a/const.go +++ b/const.go @@ -50,7 +50,7 @@ const ( //checkheader:ignore // KDF names _OSSL_KDF_NAME_HKDF cString = "HKDF\x00" _OSSL_KDF_NAME_PBKDF2 cString = "PBKDF2\x00" - _OSSL_KDF_NAME_TLS1_PRF cString = "TLS1-PRF\x00" + _OSSL_KDF_NAME_TLS1_PRF cString = "TLS1-PRF\x00" _OSSL_KDF_NAME_TLS13_KDF cString = "TLS13-KDF\x00" _OSSL_MAC_NAME_HMAC cString = "HMAC\x00" diff --git a/evp.go b/evp.go index 20777c07..a6ad5590 100644 --- a/evp.go +++ b/evp.go @@ -63,11 +63,22 @@ func hashFuncToMD(fn func() hash.Hash) (ossl.EVP_MD_PTR, error) { return md, nil } +// provider is an identifier for a known provider. +type provider uint8 + +const ( + providerNone provider = iota + providerOSSLDefault + providerOSSLFIPS + providerSymCrypt +) + type hashAlgorithm struct { md ossl.EVP_MD_PTR ch crypto.Hash size int blockSize int + provider provider marshallable bool magic string marshalledSize int @@ -87,14 +98,14 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { hash.md = ossl.EVP_md4() case crypto.MD5: hash.md = ossl.EVP_md5() - hash.magic = md5Magic - hash.marshalledSize = md5MarshaledSize + hash.magic = magicMD5 + hash.marshalledSize = marshaledSizeMD5 case crypto.MD5SHA1: hash.md = ossl.EVP_md5_sha1() case crypto.SHA1: hash.md = ossl.EVP_sha1() - hash.magic = sha1Magic - hash.marshalledSize = sha1MarshaledSize + hash.magic = magic1 + hash.marshalledSize = marshaledSize1 case crypto.SHA224: hash.md = ossl.EVP_sha224() hash.magic = magic224 @@ -159,7 +170,34 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { hash.md = md } } - hash.marshallable = hash.magic != "" && isHashMarshallable(hash.md) + if hash.magic != "" { + if hash.marshalledSize == 0 { + panic("marshalledSize must be set for " + hash.magic) + } + } + + switch vMajor { + case 1: + hash.provider = providerOSSLDefault + case 3: + if prov := ossl.EVP_MD_get0_provider(hash.md); prov != nil { + cname := ossl.OSSL_PROVIDER_get0_name(prov) + switch C.GoString((*C.char)(unsafe.Pointer(cname))) { + case "default": + hash.provider = providerOSSLDefault + hash.marshallable = hash.magic != "" + case "fips": + hash.provider = providerOSSLFIPS + hash.marshallable = hash.magic != "" + case "symcryptprovider": + hash.provider = providerSymCrypt + hash.marshallable = hash.magic != "" && isSymCryptHashStateSerializable(hash.md) + } + } + default: + panic(errUnsupportedVersion()) + } + cacheMD.Store(ch, &hash) return &hash } diff --git a/hash.go b/hash.go index 033169be..85ab3fac 100644 --- a/hash.go +++ b/hash.go @@ -15,6 +15,22 @@ import ( "github.com/golang-fips/openssl/v2/internal/ossl" ) +const ( + magicMD5 = "md5\x01" + magic1 = "sha\x01" + magic224 = "sha\x02" + magic256 = "sha\x03" + magic384 = "sha\x04" + magic512_224 = "sha\x05" + magic512_256 = "sha\x06" + magic512 = "sha\x07" + + marshaledSizeMD5 = len(magicMD5) + 4*4 + 64 + 8 // from crypto/md5 + marshaledSize1 = len(magic1) + 5*4 + 64 + 8 // from crypto/sha1 + marshaledSize256 = len(magic256) + 8*4 + 64 + 8 // from crypto/sha256 + marshaledSize512 = len(magic512) + 8*8 + 128 + 8 // from crypto/sha512 +) + // maxHashSize is the size of SHA52 and SHA3_512, the largest hashes we support. const maxHashSize = 64 @@ -207,27 +223,6 @@ func NewSHA3_512() hash.Hash { return newEvpHash(crypto.SHA3_512) } -// isHashMarshallable returns true if the memory layout of md -// is known by this library and can therefore be marshalled. -func isHashMarshallable(md ossl.EVP_MD_PTR) bool { - if vMajor == 1 { - return true - } - prov := ossl.EVP_MD_get0_provider(md) - if prov == nil { - return false - } - cname := ossl.OSSL_PROVIDER_get0_name(prov) - if cname == nil { - return false - } - name := C.GoString((*C.char)(unsafe.Pointer(cname))) - // We only know the memory layout of the built-in providers. - // See evpHash.hashState for more details. - marshallable := name == "default" || name == "fips" - return marshallable -} - // cloneHash is an interface that defines a Clone method. // // hahs.CloneHash will probably be added in Go 1.25, see https://golang.org/issue/69521, @@ -387,37 +382,11 @@ func (h *evpHash) Clone() hash.Hash { return h2 } -// hashState returns a pointer to the internal hash structure. -// -// The EVP_MD_CTX memory layout has changed in OpenSSL 3 -// and the property holding the internal structure is no longer md_data but algctx. -func hashState(ctx ossl.EVP_MD_CTX_PTR) unsafe.Pointer { - switch vMajor { - case 1: - // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12. - type mdCtx struct { - _ [2]unsafe.Pointer - _ C.ulong - md_data unsafe.Pointer - } - return (*mdCtx)(unsafe.Pointer(ctx)).md_data - case 3: - // https://github.com/openssl/openssl/blob/5675a5aaf6a2e489022bcfc18330dae9263e598e/crypto/evp/evp_local.h#L16. - type mdCtx struct { - _ [3]unsafe.Pointer - _ C.ulong - _ [3]unsafe.Pointer - algctx unsafe.Pointer - } - return (*mdCtx)(unsafe.Pointer(ctx)).algctx - default: - panic(errUnsupportedVersion()) - } -} +var errHashNotMarshallable = errors.New("openssl: hash state is not marshallable") func (d *evpHash) MarshalBinary() ([]byte, error) { if !d.alg.marshallable { - return nil, errors.New("openssl: hash state is not marshallable") + return nil, errHashNotMarshallable } buf := make([]byte, 0, d.alg.marshalledSize) return d.AppendBinary(buf) @@ -425,261 +394,40 @@ func (d *evpHash) MarshalBinary() ([]byte, error) { func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) { defer runtime.KeepAlive(d) - d.init() if !d.alg.marshallable { - return nil, errors.New("openssl: hash state is not marshallable") - } - state := hashState(d.ctx) - if state == nil { - return nil, errors.New("openssl: can't retrieve hash state") - } - var appender interface { - AppendBinary([]byte) ([]byte, error) - } - switch d.alg.ch { - case crypto.MD5: - appender = (*md5State)(state) - case crypto.SHA1: - appender = (*sha1State)(state) - case crypto.SHA224: - appender = (*sha256State)(state) - case crypto.SHA256: - appender = (*sha256State)(state) - case crypto.SHA384: - appender = (*sha512State)(state) - case crypto.SHA512: - appender = (*sha512State)(state) - case crypto.SHA512_224: - appender = (*sha512State)(state) - case crypto.SHA512_256: - appender = (*sha512State)(state) + return nil, errHashNotMarshallable + } + d.init() + switch d.alg.provider { + case providerOSSLDefault, providerOSSLFIPS: + return osslHashAppendBinary(d.ctx, d.alg.ch, d.alg.magic, buf) + case providerSymCrypt: + return symCryptHashAppendBinary(d.ctx, d.alg.ch, d.alg.magic, buf) default: - panic("openssl: unsupported hash function: " + strconv.Itoa(int(d.alg.ch))) + panic("openssl: unknown hash provider" + strconv.Itoa(int(d.alg.provider))) } - buf = append(buf, d.alg.magic[:]...) - return appender.AppendBinary(buf) } func (d *evpHash) UnmarshalBinary(b []byte) error { defer runtime.KeepAlive(d) d.init() if !d.alg.marshallable { - return errors.New("openssl: hash state is not marshallable") + return errHashNotMarshallable } - if len(b) < len(d.alg.magic) || string(b[:len(d.alg.magic)]) != string(d.alg.magic[:]) { + if len(b) < len(d.alg.magic) || string(b[:len(d.alg.magic)]) != d.alg.magic { return errors.New("openssl: invalid hash state identifier") } if len(b) != d.alg.marshalledSize { return errors.New("openssl: invalid hash state size") } - state := hashState(d.ctx) - if state == nil { - return errors.New("openssl: can't retrieve hash state") - } - b = b[len(d.alg.magic):] - var unmarshaler interface { - UnmarshalBinary([]byte) error - } - switch d.alg.ch { - case crypto.MD5: - unmarshaler = (*md5State)(state) - case crypto.SHA1: - unmarshaler = (*sha1State)(state) - case crypto.SHA224: - unmarshaler = (*sha256State)(state) - case crypto.SHA256: - unmarshaler = (*sha256State)(state) - case crypto.SHA384: - unmarshaler = (*sha512State)(state) - case crypto.SHA512: - unmarshaler = (*sha512State)(state) - case crypto.SHA512_224: - unmarshaler = (*sha512State)(state) - case crypto.SHA512_256: - unmarshaler = (*sha512State)(state) + switch d.alg.provider { + case providerOSSLDefault, providerOSSLFIPS: + return osslHashUnmarshalBinary(d.ctx, d.alg.ch, d.alg.magic, b) + case providerSymCrypt: + return symCryptHashUnmarshalBinary(d.ctx, d.alg.ch, d.alg.magic, b) default: - panic("openssl: unsupported hash function: " + strconv.Itoa(int(d.alg.ch))) + panic("openssl: unknown hash provider" + strconv.Itoa(int(d.alg.provider))) } - return unmarshaler.UnmarshalBinary(b) -} - -// md5State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/md5.h#L33. -type md5State struct { - h [4]uint32 - nl, nh uint32 - x [64]byte - nx uint32 -} - -const ( - md5Magic = "md5\x01" - md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8 -) - -func (d *md5State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint32(b) - b, d.h[1] = consumeUint32(b) - b, d.h[2] = consumeUint32(b) - b, d.h[3] = consumeUint32(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = uint32(n << 3) - d.nh = uint32(n >> 29) - d.nx = uint32(n) % 64 - return nil -} - -func (d *md5State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint32(buf, d.h[0]) - buf = appendUint32(buf, d.h[1]) - buf = appendUint32(buf, d.h[2]) - buf = appendUint32(buf, d.h[3]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) - return buf, nil -} - -// sha1State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L34. -type sha1State struct { - h [5]uint32 - nl, nh uint32 - x [64]byte - nx uint32 -} - -const ( - sha1Magic = "sha\x01" - sha1MarshaledSize = len(sha1Magic) + 5*4 + 64 + 8 -) - -func (d *sha1State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint32(b) - b, d.h[1] = consumeUint32(b) - b, d.h[2] = consumeUint32(b) - b, d.h[3] = consumeUint32(b) - b, d.h[4] = consumeUint32(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = uint32(n << 3) - d.nh = uint32(n >> 29) - d.nx = uint32(n) % 64 - return nil -} - -func (d *sha1State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint32(buf, d.h[0]) - buf = appendUint32(buf, d.h[1]) - buf = appendUint32(buf, d.h[2]) - buf = appendUint32(buf, d.h[3]) - buf = appendUint32(buf, d.h[4]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) - return buf, nil -} - -const ( - magic224 = "sha\x02" - magic256 = "sha\x03" - marshaledSize256 = len(magic256) + 8*4 + 64 + 8 -) - -// sha256State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L51. -type sha256State struct { - h [8]uint32 - nl, nh uint32 - x [64]byte - nx uint32 -} - -func (d *sha256State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint32(b) - b, d.h[1] = consumeUint32(b) - b, d.h[2] = consumeUint32(b) - b, d.h[3] = consumeUint32(b) - b, d.h[4] = consumeUint32(b) - b, d.h[5] = consumeUint32(b) - b, d.h[6] = consumeUint32(b) - b, d.h[7] = consumeUint32(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = uint32(n << 3) - d.nh = uint32(n >> 29) - d.nx = uint32(n) % 64 - return nil -} - -func (d *sha256State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint32(buf, d.h[0]) - buf = appendUint32(buf, d.h[1]) - buf = appendUint32(buf, d.h[2]) - buf = appendUint32(buf, d.h[3]) - buf = appendUint32(buf, d.h[4]) - buf = appendUint32(buf, d.h[5]) - buf = appendUint32(buf, d.h[6]) - buf = appendUint32(buf, d.h[7]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) - return buf, nil -} - -// sha512State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L95. -type sha512State struct { - h [8]uint64 - nl, nh uint64 - x [128]byte - nx uint32 -} - -const ( - magic384 = "sha\x04" - magic512_224 = "sha\x05" - magic512_256 = "sha\x06" - magic512 = "sha\x07" - marshaledSize512 = len(magic512) + 8*8 + 128 + 8 -) - -func (d *sha512State) MarshalBinary() ([]byte, error) { - buf := make([]byte, 0, marshaledSize512) - return d.AppendBinary(buf) -} - -func (d *sha512State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint64(b) - b, d.h[1] = consumeUint64(b) - b, d.h[2] = consumeUint64(b) - b, d.h[3] = consumeUint64(b) - b, d.h[4] = consumeUint64(b) - b, d.h[5] = consumeUint64(b) - b, d.h[6] = consumeUint64(b) - b, d.h[7] = consumeUint64(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = n << 3 - d.nh = n >> 61 - d.nx = uint32(n) % 128 - return nil -} - -func (d *sha512State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint64(buf, d.h[0]) - buf = appendUint64(buf, d.h[1]) - buf = appendUint64(buf, d.h[2]) - buf = appendUint64(buf, d.h[3]) - buf = appendUint64(buf, d.h[4]) - buf = appendUint64(buf, d.h[5]) - buf = appendUint64(buf, d.h[6]) - buf = appendUint64(buf, d.h[7]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, d.nl>>3|d.nh<<61) - return buf, nil } // appendUint64 appends x into b as a big endian byte sequence. diff --git a/hash_test.go b/hash_test.go index 569e51c7..c4ecdc4b 100644 --- a/hash_test.go +++ b/hash_test.go @@ -10,6 +10,12 @@ import ( "strings" "testing" + // Blank imports to ensure that the hash functions are registered. + _ "crypto/md5" + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" + "github.com/golang-fips/openssl/v2" ) @@ -94,6 +100,17 @@ func TestHash(t *testing.T) { } } +type hashEncoding interface { + hash.Hash + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler +} + +type hashEncodingAppender interface { + hashEncoding + AppendBinary(b []byte) ([]byte, error) +} + func TestHash_BinaryMarshaler(t *testing.T) { msg := []byte("testing") for _, ch := range hashes { @@ -103,10 +120,7 @@ func TestHash_BinaryMarshaler(t *testing.T) { t.Skip("hash not supported") } - hashMarshaler, ok := cryptoToHash(ch)().(interface { - hash.Hash - encoding.BinaryMarshaler - }) + hashMarshaler, ok := cryptoToHash(ch)().(hashEncoding) if !ok { t.Fatal("BinaryMarshaler not supported") } @@ -123,10 +137,7 @@ func TestHash_BinaryMarshaler(t *testing.T) { t.Fatalf("MarshalBinary failed: %v", err) } - hashUnmarshaler := cryptoToHash(ch)().(interface { - hash.Hash - encoding.BinaryUnmarshaler - }) + hashUnmarshaler := cryptoToHash(ch)().(hashEncoding) if err := hashUnmarshaler.UnmarshalBinary(state); err != nil { t.Fatalf("UnmarshalBinary failed: %v", err) } @@ -134,6 +145,26 @@ func TestHash_BinaryMarshaler(t *testing.T) { if actual, actual2 := hashMarshaler.Sum(nil), hashUnmarshaler.Sum(nil); !bytes.Equal(actual, actual2) { t.Errorf("0x%x != appended 0x%x", actual, actual2) } + + // Test that the hash state is compatible with native Go. + h, ok := ch.New().(hashEncoding) + if !ok { + // The standard library doesn't support encoding this hash. + // Nothing else to do. + return + } + h.Write(msg) + stateh, err := h.(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + t.Error(err) + } + if !bytes.Equal(state, stateh) { + t.Errorf("got 0x%x != want 0x%x", state, stateh) + } + h = ch.New().(hashEncoding) + if err := h.UnmarshalBinary(state); err != nil { + t.Error(err) + } }) } } @@ -146,10 +177,7 @@ func TestHash_BinaryAppender(t *testing.T) { t.Skip("not supported") } - hashWithBinaryAppender, ok := cryptoToHash(ch)().(interface { - hash.Hash - AppendBinary(b []byte) ([]byte, error) - }) + hashWithBinaryAppender, ok := cryptoToHash(ch)().(hashEncodingAppender) if !ok { t.Fatal("AppendBinary not supported") } @@ -181,10 +209,7 @@ func TestHash_BinaryAppender(t *testing.T) { // Use only the newly appended part of the slice appendedState := state[10:] - h2, ok := cryptoToHash(ch)().(interface { - hash.Hash - encoding.BinaryUnmarshaler - }) + h2, ok := cryptoToHash(ch)().(hashEncoding) if !ok { t.Skip("not supported") } diff --git a/internal/ossl/ossl.go b/internal/ossl/ossl.go index 4a64c2f0..e0d0b8e8 100644 --- a/internal/ossl/ossl.go +++ b/internal/ossl/ossl.go @@ -22,7 +22,9 @@ go_hash_sum(const _EVP_MD_CTX_PTR ctx, _EVP_MD_CTX_PTR ctx2, unsigned char *out, } */ import "C" -import "unsafe" +import ( + "unsafe" +) func HashSum(ctx1, ctx2 EVP_MD_CTX_PTR, out []byte) error { var errst C.mkcgo_err_state @@ -38,3 +40,42 @@ func HashSum(ctx1, ctx2 EVP_MD_CTX_PTR, out []byte) error { } return nil } + +const _OSSL_PARAM_UNMODIFIED uint = uint(^uintptr(0)) + +// OSSL_PARAM is a structure to pass or request object parameters. +// https://docs.openssl.org/3.0/man3/OSSL_PARAM/. +type OSSL_PARAM struct { + Key *byte + DataType uint32 + Data unsafe.Pointer + DataSize uint + ReturnSize uint +} + +func ossl_param_construct(key *byte, dataType uint32, data unsafe.Pointer, dataSize int) OSSL_PARAM { + return OSSL_PARAM{ + Key: key, + DataType: dataType, + Data: data, + DataSize: uint(dataSize), + ReturnSize: _OSSL_PARAM_UNMODIFIED, + } +} + +func OSSL_PARAM_construct_octet_string(key *byte, data unsafe.Pointer, dataSize int) OSSL_PARAM { + return ossl_param_construct(key, OSSL_PARAM_OCTET_STRING, data, dataSize) +} + +func OSSL_PARAM_construct_int32(key *byte, data *int32) OSSL_PARAM { + return ossl_param_construct(key, OSSL_PARAM_INTEGER, unsafe.Pointer(data), 4) +} + +func OSSL_PARAM_construct_end() OSSL_PARAM { + return OSSL_PARAM{} +} + +func OSSL_PARAM_modified(param *OSSL_PARAM) bool { + // If ReturnSize is not set, the parameter has not been modified. + return param != nil && param.ReturnSize != _OSSL_PARAM_UNMODIFIED +} diff --git a/internal/ossl/shims.h b/internal/ossl/shims.h index 012d60f8..5790defe 100644 --- a/internal/ossl/shims.h +++ b/internal/ossl/shims.h @@ -23,6 +23,7 @@ // #include // #include // #include +// #include // #endif // #if OPENSSL_VERSION_NUMBER < 0x10100000L // #include @@ -86,6 +87,9 @@ enum { _EVP_PKEY_CTRL_RSA_OAEP_LABEL = 0x100A, _EVP_PKEY_CTRL_DSA_PARAMGEN_BITS = 0x1001, _EVP_PKEY_CTRL_DSA_PARAMGEN_Q_BITS = 0x1002, + + _OSSL_PARAM_INTEGER = 1, + _OSSL_PARAM_OCTET_STRING = 5, }; typedef void* _OPENSSL_INIT_SETTINGS_PTR; @@ -189,6 +193,10 @@ _EVP_MD_CTX_PTR EVP_MD_CTX_new(void); void EVP_MD_CTX_free(_EVP_MD_CTX_PTR ctx); int EVP_MD_CTX_copy(_EVP_MD_CTX_PTR out, const _EVP_MD_CTX_PTR in) __attribute__((noescape,nocallback)); int EVP_MD_CTX_copy_ex(_EVP_MD_CTX_PTR out, const _EVP_MD_CTX_PTR in); +const _OSSL_PARAM_PTR EVP_MD_CTX_gettable_params(_EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); +const _OSSL_PARAM_PTR EVP_MD_CTX_settable_params(_EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); +int EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR ctx, _OSSL_PARAM_PTR params) __attribute__((tag("3"))); +int EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR ctx, const _OSSL_PARAM_PTR params) __attribute__((tag("3"))); int EVP_Digest(const void *data, size_t count, unsigned char *md, unsigned int *size, const _EVP_MD_PTR type, _ENGINE_PTR impl) __attribute__((noescape,nocallback,nocheckptr("data"))); int EVP_DigestInit_ex(_EVP_MD_CTX_PTR ctx, const _EVP_MD_PTR type, _ENGINE_PTR impl); int EVP_DigestInit(_EVP_MD_CTX_PTR ctx, const _EVP_MD_PTR type); @@ -353,6 +361,7 @@ int EVP_MAC_final(_EVP_MAC_CTX_PTR ctx, unsigned char *out, size_t *outl, size_t // OSSL_PARAM API void OSSL_PARAM_free(_OSSL_PARAM_PTR p) __attribute__((tag("3"))); +const _OSSL_PARAM_PTR OSSL_PARAM_locate_const(const _OSSL_PARAM_PTR p, const char *key) __attribute__((tag("3"))); _OSSL_PARAM_BLD_PTR OSSL_PARAM_BLD_new(void) __attribute__((tag("3"))); void OSSL_PARAM_BLD_free(_OSSL_PARAM_BLD_PTR bld) __attribute__((tag("3"))); _OSSL_PARAM_PTR OSSL_PARAM_BLD_to_param(_OSSL_PARAM_BLD_PTR bld) __attribute__((tag("3"))); diff --git a/internal/ossl/zossl.c b/internal/ossl/zossl.c index 0219d15b..1444f551 100644 --- a/internal/ossl/zossl.c +++ b/internal/ossl/zossl.c @@ -98,7 +98,11 @@ int (*_g_EVP_MAC_update)(_EVP_MAC_CTX_PTR, const unsigned char*, size_t); int (*_g_EVP_MD_CTX_copy)(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR); int (*_g_EVP_MD_CTX_copy_ex)(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR); void (*_g_EVP_MD_CTX_free)(_EVP_MD_CTX_PTR); +int (*_g_EVP_MD_CTX_get_params)(_EVP_MD_CTX_PTR, _OSSL_PARAM_PTR); +const _OSSL_PARAM_PTR (*_g_EVP_MD_CTX_gettable_params)(_EVP_MD_CTX_PTR); _EVP_MD_CTX_PTR (*_g_EVP_MD_CTX_new)(void); +int (*_g_EVP_MD_CTX_set_params)(_EVP_MD_CTX_PTR, const _OSSL_PARAM_PTR); +const _OSSL_PARAM_PTR (*_g_EVP_MD_CTX_settable_params)(_EVP_MD_CTX_PTR); _EVP_MD_PTR (*_g_EVP_MD_fetch)(_OSSL_LIB_CTX_PTR, const char*, const char*); void (*_g_EVP_MD_free)(_EVP_MD_PTR); const char* (*_g_EVP_MD_get0_name)(const _EVP_MD_PTR); @@ -212,6 +216,7 @@ int (*_g_OSSL_PARAM_BLD_push_octet_string)(_OSSL_PARAM_BLD_PTR, const char*, con int (*_g_OSSL_PARAM_BLD_push_utf8_string)(_OSSL_PARAM_BLD_PTR, const char*, const char*, size_t); _OSSL_PARAM_PTR (*_g_OSSL_PARAM_BLD_to_param)(_OSSL_PARAM_BLD_PTR); void (*_g_OSSL_PARAM_free)(_OSSL_PARAM_PTR); +const _OSSL_PARAM_PTR (*_g_OSSL_PARAM_locate_const)(const _OSSL_PARAM_PTR, const char*); int (*_g_OSSL_PROVIDER_available)(_OSSL_LIB_CTX_PTR, const char*); const char* (*_g_OSSL_PROVIDER_get0_name)(const _OSSL_PROVIDER_PTR); _OSSL_PROVIDER_PTR (*_g_OSSL_PROVIDER_try_load)(_OSSL_LIB_CTX_PTR, const char*, int); @@ -494,6 +499,10 @@ void __mkcgo_load_3(void* handle) { __mkcgo__dlsym(EVP_MAC_final) __mkcgo__dlsym(EVP_MAC_init) __mkcgo__dlsym(EVP_MAC_update) + __mkcgo__dlsym(EVP_MD_CTX_get_params) + __mkcgo__dlsym(EVP_MD_CTX_gettable_params) + __mkcgo__dlsym(EVP_MD_CTX_set_params) + __mkcgo__dlsym(EVP_MD_CTX_settable_params) __mkcgo__dlsym(EVP_MD_fetch) __mkcgo__dlsym(EVP_MD_free) __mkcgo__dlsym(EVP_MD_get0_name) @@ -531,6 +540,7 @@ void __mkcgo_load_3(void* handle) { __mkcgo__dlsym(OSSL_PARAM_BLD_push_utf8_string) __mkcgo__dlsym(OSSL_PARAM_BLD_to_param) __mkcgo__dlsym(OSSL_PARAM_free) + __mkcgo__dlsym(OSSL_PARAM_locate_const) __mkcgo__dlsym(OSSL_PROVIDER_available) __mkcgo__dlsym(OSSL_PROVIDER_get0_name) __mkcgo__dlsym(OSSL_PROVIDER_try_load) @@ -557,6 +567,10 @@ void __mkcgo_unload_3() { _g_EVP_MAC_final = NULL; _g_EVP_MAC_init = NULL; _g_EVP_MAC_update = NULL; + _g_EVP_MD_CTX_get_params = NULL; + _g_EVP_MD_CTX_gettable_params = NULL; + _g_EVP_MD_CTX_set_params = NULL; + _g_EVP_MD_CTX_settable_params = NULL; _g_EVP_MD_fetch = NULL; _g_EVP_MD_free = NULL; _g_EVP_MD_get0_name = NULL; @@ -594,6 +608,7 @@ void __mkcgo_unload_3() { _g_OSSL_PARAM_BLD_push_utf8_string = NULL; _g_OSSL_PARAM_BLD_to_param = NULL; _g_OSSL_PARAM_free = NULL; + _g_OSSL_PARAM_locate_const = NULL; _g_OSSL_PROVIDER_available = NULL; _g_OSSL_PROVIDER_get0_name = NULL; _g_OSSL_PROVIDER_try_load = NULL; @@ -1253,6 +1268,20 @@ void _mkcgo_EVP_MD_CTX_free(_EVP_MD_CTX_PTR _arg0) { _g_EVP_MD_CTX_free(_arg0); } +int _mkcgo_EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR _arg0, _OSSL_PARAM_PTR _arg1, mkcgo_err_state *_err_state) { + mkcgo_err_clear(); + int _ret = _g_EVP_MD_CTX_get_params(_arg0, _arg1); + if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); + return _ret; +} + +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(_EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { + mkcgo_err_clear(); + const _OSSL_PARAM_PTR _ret = _g_EVP_MD_CTX_gettable_params(_arg0); + if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); + return _ret; +} + _EVP_MD_CTX_PTR _mkcgo_EVP_MD_CTX_new(mkcgo_err_state *_err_state) { mkcgo_err_clear(); _EVP_MD_CTX_PTR _ret = _g_EVP_MD_CTX_new(); @@ -1260,6 +1289,20 @@ _EVP_MD_CTX_PTR _mkcgo_EVP_MD_CTX_new(mkcgo_err_state *_err_state) { return _ret; } +int _mkcgo_EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR _arg0, const _OSSL_PARAM_PTR _arg1, mkcgo_err_state *_err_state) { + mkcgo_err_clear(); + int _ret = _g_EVP_MD_CTX_set_params(_arg0, _arg1); + if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); + return _ret; +} + +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(_EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { + mkcgo_err_clear(); + const _OSSL_PARAM_PTR _ret = _g_EVP_MD_CTX_settable_params(_arg0); + if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); + return _ret; +} + _EVP_MD_PTR _mkcgo_EVP_MD_fetch(_OSSL_LIB_CTX_PTR _arg0, const char* _arg1, const char* _arg2, mkcgo_err_state *_err_state) { mkcgo_err_clear(); _EVP_MD_PTR _ret = _g_EVP_MD_fetch(_arg0, _arg1, _arg2); @@ -1924,6 +1967,13 @@ void _mkcgo_OSSL_PARAM_free(_OSSL_PARAM_PTR _arg0) { _g_OSSL_PARAM_free(_arg0); } +const _OSSL_PARAM_PTR _mkcgo_OSSL_PARAM_locate_const(const _OSSL_PARAM_PTR _arg0, const char* _arg1, mkcgo_err_state *_err_state) { + mkcgo_err_clear(); + const _OSSL_PARAM_PTR _ret = _g_OSSL_PARAM_locate_const(_arg0, _arg1); + if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); + return _ret; +} + int _mkcgo_OSSL_PROVIDER_available(_OSSL_LIB_CTX_PTR _arg0, const char* _arg1) { return _g_OSSL_PROVIDER_available(_arg0, _arg1); } diff --git a/internal/ossl/zossl.go b/internal/ossl/zossl.go index 568e410a..9079f7a0 100644 --- a/internal/ossl/zossl.go +++ b/internal/ossl/zossl.go @@ -61,6 +61,8 @@ const ( EVP_PKEY_CTRL_RSA_OAEP_LABEL = 0x100A EVP_PKEY_CTRL_DSA_PARAMGEN_BITS = 0x1001 EVP_PKEY_CTRL_DSA_PARAMGEN_Q_BITS = 0x1002 + OSSL_PARAM_INTEGER = 1 + OSSL_PARAM_OCTET_STRING = 5 ) type BIGNUM_PTR = C._BIGNUM_PTR @@ -612,12 +614,36 @@ func EVP_MD_CTX_free(ctx EVP_MD_CTX_PTR) { C._mkcgo_EVP_MD_CTX_free(ctx) } +func EVP_MD_CTX_get_params(ctx EVP_MD_CTX_PTR, params OSSL_PARAM_PTR) (int32, error) { + var _err C.mkcgo_err_state + _ret := C._mkcgo_EVP_MD_CTX_get_params(ctx, params, mkcgoNoEscape(&_err)) + return int32(_ret), newMkcgoErr("EVP_MD_CTX_get_params", _err) +} + +func EVP_MD_CTX_gettable_params(ctx EVP_MD_CTX_PTR) (OSSL_PARAM_PTR, error) { + var _err C.mkcgo_err_state + _ret := C._mkcgo_EVP_MD_CTX_gettable_params(ctx, mkcgoNoEscape(&_err)) + return _ret, newMkcgoErr("EVP_MD_CTX_gettable_params", _err) +} + func EVP_MD_CTX_new() (EVP_MD_CTX_PTR, error) { var _err C.mkcgo_err_state _ret := C._mkcgo_EVP_MD_CTX_new(mkcgoNoEscape(&_err)) return _ret, newMkcgoErr("EVP_MD_CTX_new", _err) } +func EVP_MD_CTX_set_params(ctx EVP_MD_CTX_PTR, params OSSL_PARAM_PTR) (int32, error) { + var _err C.mkcgo_err_state + _ret := C._mkcgo_EVP_MD_CTX_set_params(ctx, params, mkcgoNoEscape(&_err)) + return int32(_ret), newMkcgoErr("EVP_MD_CTX_set_params", _err) +} + +func EVP_MD_CTX_settable_params(ctx EVP_MD_CTX_PTR) (OSSL_PARAM_PTR, error) { + var _err C.mkcgo_err_state + _ret := C._mkcgo_EVP_MD_CTX_settable_params(ctx, mkcgoNoEscape(&_err)) + return _ret, newMkcgoErr("EVP_MD_CTX_settable_params", _err) +} + func EVP_MD_fetch(ctx OSSL_LIB_CTX_PTR, algorithm *byte, properties *byte) (EVP_MD_PTR, error) { var _err C.mkcgo_err_state _ret := C._mkcgo_EVP_MD_fetch(ctx, (*C.char)(unsafe.Pointer(algorithm)), (*C.char)(unsafe.Pointer(properties)), mkcgoNoEscape(&_err)) @@ -1218,6 +1244,12 @@ func OSSL_PARAM_free(p OSSL_PARAM_PTR) { C._mkcgo_OSSL_PARAM_free(p) } +func OSSL_PARAM_locate_const(p OSSL_PARAM_PTR, key *byte) (OSSL_PARAM_PTR, error) { + var _err C.mkcgo_err_state + _ret := C._mkcgo_OSSL_PARAM_locate_const(p, (*C.char)(unsafe.Pointer(key)), mkcgoNoEscape(&_err)) + return _ret, newMkcgoErr("OSSL_PARAM_locate_const", _err) +} + func OSSL_PROVIDER_available(libctx OSSL_LIB_CTX_PTR, name *byte) int32 { return int32(C._mkcgo_OSSL_PROVIDER_available(libctx, (*C.char)(unsafe.Pointer(name)))) } diff --git a/internal/ossl/zossl.h b/internal/ossl/zossl.h index 66249349..ea01eed7 100644 --- a/internal/ossl/zossl.h +++ b/internal/ossl/zossl.h @@ -83,6 +83,8 @@ enum { _EVP_PKEY_CTRL_RSA_OAEP_LABEL = 0x100A, _EVP_PKEY_CTRL_DSA_PARAMGEN_BITS = 0x1001, _EVP_PKEY_CTRL_DSA_PARAMGEN_Q_BITS = 0x1002, + _OSSL_PARAM_INTEGER = 1, + _OSSL_PARAM_OCTET_STRING = 5, }; typedef void* mkcgo_err_state; @@ -195,7 +197,11 @@ int _mkcgo_EVP_MAC_update(_EVP_MAC_CTX_PTR, const unsigned char*, size_t, mkcgo_ int _mkcgo_EVP_MD_CTX_copy(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR, mkcgo_err_state *); int _mkcgo_EVP_MD_CTX_copy_ex(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR, mkcgo_err_state *); void _mkcgo_EVP_MD_CTX_free(_EVP_MD_CTX_PTR); +int _mkcgo_EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR, _OSSL_PARAM_PTR, mkcgo_err_state *); +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(_EVP_MD_CTX_PTR, mkcgo_err_state *); _EVP_MD_CTX_PTR _mkcgo_EVP_MD_CTX_new(mkcgo_err_state *); +int _mkcgo_EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR, const _OSSL_PARAM_PTR, mkcgo_err_state *); +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(_EVP_MD_CTX_PTR, mkcgo_err_state *); _EVP_MD_PTR _mkcgo_EVP_MD_fetch(_OSSL_LIB_CTX_PTR, const char*, const char*, mkcgo_err_state *); void _mkcgo_EVP_MD_free(_EVP_MD_PTR); const char* _mkcgo_EVP_MD_get0_name(const _EVP_MD_PTR); @@ -311,6 +317,7 @@ int _mkcgo_OSSL_PARAM_BLD_push_octet_string(_OSSL_PARAM_BLD_PTR, const char*, co int _mkcgo_OSSL_PARAM_BLD_push_utf8_string(_OSSL_PARAM_BLD_PTR, const char*, const char*, size_t, mkcgo_err_state *); _OSSL_PARAM_PTR _mkcgo_OSSL_PARAM_BLD_to_param(_OSSL_PARAM_BLD_PTR, mkcgo_err_state *); void _mkcgo_OSSL_PARAM_free(_OSSL_PARAM_PTR); +const _OSSL_PARAM_PTR _mkcgo_OSSL_PARAM_locate_const(const _OSSL_PARAM_PTR, const char*, mkcgo_err_state *); int _mkcgo_OSSL_PROVIDER_available(_OSSL_LIB_CTX_PTR, const char*); const char* _mkcgo_OSSL_PROVIDER_get0_name(const _OSSL_PROVIDER_PTR); _OSSL_PROVIDER_PTR _mkcgo_OSSL_PROVIDER_try_load(_OSSL_LIB_CTX_PTR, const char*, int, mkcgo_err_state *); diff --git a/provideropenssl.go b/provideropenssl.go new file mode 100644 index 00000000..1d8f3ad0 --- /dev/null +++ b/provideropenssl.go @@ -0,0 +1,239 @@ +//go:build !cmd_go_bootstrap && cgo + +package openssl + +import ( + "crypto" + "errors" + "unsafe" + + "github.com/golang-fips/openssl/v2/internal/ossl" +) + +// This file contains code specific to the built-in OpenSSL providers. + +// _OSSL_MD5_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/md5.h#L33. +type _OSSL_MD5_CTX struct { + h [4]uint32 + nl, nh uint32 + x [64]byte + nx uint32 +} + +func (d *_OSSL_MD5_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint32(b) + b, d.h[1] = consumeUint32(b) + b, d.h[2] = consumeUint32(b) + b, d.h[3] = consumeUint32(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = uint32(n << 3) + d.nh = uint32(n >> 29) + d.nx = uint32(n) % 64 + return nil +} + +func (d *_OSSL_MD5_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint32(buf, d.h[0]) + buf = appendUint32(buf, d.h[1]) + buf = appendUint32(buf, d.h[2]) + buf = appendUint32(buf, d.h[3]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) + return buf, nil +} + +// _OSSL_SHA_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L34. +type _OSSL_SHA_CTX struct { + h [5]uint32 + nl, nh uint32 + x [64]byte + nx uint32 +} + +func (d *_OSSL_SHA_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint32(b) + b, d.h[1] = consumeUint32(b) + b, d.h[2] = consumeUint32(b) + b, d.h[3] = consumeUint32(b) + b, d.h[4] = consumeUint32(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = uint32(n << 3) + d.nh = uint32(n >> 29) + d.nx = uint32(n) % 64 + return nil +} + +func (d *_OSSL_SHA_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint32(buf, d.h[0]) + buf = appendUint32(buf, d.h[1]) + buf = appendUint32(buf, d.h[2]) + buf = appendUint32(buf, d.h[3]) + buf = appendUint32(buf, d.h[4]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) + return buf, nil +} + +// _OSSL_SHA256_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L51. +type _OSSL_SHA256_CTX struct { + h [8]uint32 + nl, nh uint32 + x [64]byte + nx uint32 +} + +func (d *_OSSL_SHA256_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint32(b) + b, d.h[1] = consumeUint32(b) + b, d.h[2] = consumeUint32(b) + b, d.h[3] = consumeUint32(b) + b, d.h[4] = consumeUint32(b) + b, d.h[5] = consumeUint32(b) + b, d.h[6] = consumeUint32(b) + b, d.h[7] = consumeUint32(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = uint32(n << 3) + d.nh = uint32(n >> 29) + d.nx = uint32(n) % 64 + return nil +} + +func (d *_OSSL_SHA256_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint32(buf, d.h[0]) + buf = appendUint32(buf, d.h[1]) + buf = appendUint32(buf, d.h[2]) + buf = appendUint32(buf, d.h[3]) + buf = appendUint32(buf, d.h[4]) + buf = appendUint32(buf, d.h[5]) + buf = appendUint32(buf, d.h[6]) + buf = appendUint32(buf, d.h[7]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) + return buf, nil +} + +// _OSSL_SHA512_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L95. +type _OSSL_SHA512_CTX struct { + h [8]uint64 + nl, nh uint64 + x [128]byte + nx uint32 +} + +func (d *_OSSL_SHA512_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint64(b) + b, d.h[1] = consumeUint64(b) + b, d.h[2] = consumeUint64(b) + b, d.h[3] = consumeUint64(b) + b, d.h[4] = consumeUint64(b) + b, d.h[5] = consumeUint64(b) + b, d.h[6] = consumeUint64(b) + b, d.h[7] = consumeUint64(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = n << 3 + d.nh = n >> 61 + d.nx = uint32(n) % 128 + return nil +} + +func (d *_OSSL_SHA512_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint64(buf, d.h[0]) + buf = appendUint64(buf, d.h[1]) + buf = appendUint64(buf, d.h[2]) + buf = appendUint64(buf, d.h[3]) + buf = appendUint64(buf, d.h[4]) + buf = appendUint64(buf, d.h[5]) + buf = appendUint64(buf, d.h[6]) + buf = appendUint64(buf, d.h[7]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, d.nl>>3|d.nh<<61) + return buf, nil +} + +func getOSSLDigetsContext(ctx ossl.EVP_MD_CTX_PTR) unsafe.Pointer { + switch vMajor { + case 1: + // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12. + type mdCtx struct { + _ [2]unsafe.Pointer + _ uint32 + md_data unsafe.Pointer + } + return (*mdCtx)(unsafe.Pointer(ctx)).md_data + case 3: + // The EVP_MD_CTX memory layout has changed in OpenSSL 3 + // and the property holding the internal structure is no longer md_data but algctx. + // https://github.com/openssl/openssl/blob/5675a5aaf6a2e489022bcfc18330dae9263e598e/crypto/evp/evp_local.h#L16. + type mdCtx struct { + _ [3]unsafe.Pointer + _ uint32 + _ [3]unsafe.Pointer + algctx unsafe.Pointer + } + return (*mdCtx)(unsafe.Pointer(ctx)).algctx + default: + panic(errUnsupportedVersion()) + } +} + +var errHashStateInvalid = errors.New("openssl: can't retrieve hash state") + +func osslHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic string, buf []byte) ([]byte, error) { + algctx := getOSSLDigetsContext(ctx) + if algctx == nil { + return nil, errHashStateInvalid + } + buf = append(buf, magic...) + switch ch { + case crypto.MD5: + d := (*_OSSL_MD5_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + case crypto.SHA1: + d := (*_OSSL_SHA_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + case crypto.SHA224, crypto.SHA256: + d := (*_OSSL_SHA256_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + d := (*_OSSL_SHA512_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + default: + panic("unsupported hash " + ch.String()) + } +} + +func osslHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic string, b []byte) error { + algctx := getOSSLDigetsContext(ctx) + if algctx == nil { + return errHashStateInvalid + } + b = b[len(magic):] + switch ch { + case crypto.MD5: + d := (*_OSSL_MD5_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + case crypto.SHA1: + d := (*_OSSL_SHA_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + case crypto.SHA224, crypto.SHA256: + d := (*_OSSL_SHA256_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + d := (*_OSSL_SHA512_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + default: + panic("unsupported hash " + ch.String()) + } +} diff --git a/providersymcrypt.go b/providersymcrypt.go new file mode 100644 index 00000000..693849f1 --- /dev/null +++ b/providersymcrypt.go @@ -0,0 +1,338 @@ +//go:build !cmd_go_bootstrap && cgo + +package openssl + +import ( + "crypto" + "errors" + "runtime" + "unsafe" + + "github.com/golang-fips/openssl/v2/internal/ossl" +) + +// This file contains code specific to the SymCrypt provider. + +const ( + _SCOSSL_DIGEST_PARAM_STATE cString = "state\x00" + _SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM cString = "recompute_checksum\x00" +) + +const ( + _SYMCRYPT_BLOB_MAGIC = 0x636D7973 // "cysm" in little-endian + + _SymCryptBlobTypeHashState = 0x100 + _SymCryptBlobTypeMd2State = _SymCryptBlobTypeHashState + 1 + _SymCryptBlobTypeMd4State = _SymCryptBlobTypeHashState + 2 + _SymCryptBlobTypeMd5State = _SymCryptBlobTypeHashState + 3 + _SymCryptBlobTypeSha1State = _SymCryptBlobTypeHashState + 4 + _SymCryptBlobTypeSha256State = _SymCryptBlobTypeHashState + 5 + _SymCryptBlobTypeSha384State = _SymCryptBlobTypeHashState + 6 + _SymCryptBlobTypeSha512State = _SymCryptBlobTypeHashState + 7 + _SymCryptBlobTypeSha3_256State = _SymCryptBlobTypeHashState + 8 + _SymCryptBlobTypeSha3_384State = _SymCryptBlobTypeHashState + 9 + _SymCryptBlobTypeSha3_512State = _SymCryptBlobTypeHashState + 10 + _SymCryptBlobTypeSha224State = _SymCryptBlobTypeHashState + 11 + _SymCryptBlobTypeSha512_224State = _SymCryptBlobTypeHashState + 12 + _SymCryptBlobTypeSha512_256State = _SymCryptBlobTypeHashState + 13 + _SymCryptBlobTypeSha3_224State = _SymCryptBlobTypeHashState + 14 + + _SYMCRYPT_MD5_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_MD5_STATE_EXPORT_BLOB{})) + _SYMCRYPT_SHA1_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_SHA1_STATE_EXPORT_BLOB{})) + _SYMCRYPT_SHA256_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_SHA256_STATE_EXPORT_BLOB{})) + _SYMCRYPT_SHA512_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_SHA512_STATE_EXPORT_BLOB{})) +) + +type _SYMCRYPT_BLOB_HEADER struct { + magic uint32 + size uint32 + _type uint32 +} + +type _SYMCRYPT_BLOB_TRAILER struct { + checksum [8]uint8 +} + +// _UINT64 is a 64-bit unsigned integer, stored in native endianess. +// It is used to represent a SymCrypt UINT64 type without making the +// parent struct 8-byte aligned, given that the Windows ABI makes +// the struct 4-byte aligned. +type _UINT64 [2]uint32 + +func newUINT64(v uint64) _UINT64 { + var u _UINT64 + if isBigEndian { + u[0], u[1] = uint32(v>>32), uint32(v) + } else { + u[0], u[1] = uint32(v), uint32(v>>32) + } + return u +} + +func (u *_UINT64) uint64() uint64 { + if isBigEndian { + return uint64(u[0])<<32 | (uint64(u[1])) + } + return uint64(u[0]) | (uint64(u[1]) << 32) +} + +// symCryptAppendBinary appends the binary representation of a SymCrypt state +// to the given destination slice. +func symCryptAppendBinary(dst, chain, buffer []byte, blength _UINT64) []byte { + length := blength.uint64() + var nx uint64 + if len(buffer) <= 64 { + nx = length & 0x3f + } else { + nx = length & 0x7f + } + dst = append(dst, chain...) + dst = append(dst, buffer[:nx]...) + dst = append(dst, make([]byte, len(buffer)-int(nx))...) + dst = appendUint64(dst, length) + return dst +} + +// symCryptUnmarshalBinary unmarshals the binary representation of a SymCrypt state +// from the given source slice. It returns the length of the data. +func symCryptUnmarshalBinary(d []byte, chain, buffer []byte) _UINT64 { + copy(chain[:], d) + d = d[len(chain):] + copy(buffer[:], d) + d = d[len(buffer):] + _, length := consumeUint64(d) + return newUINT64(length) +} + +// swapEndianessUint32 swaps the endianness of the given byte slice +// in place. It assumes the slice is a backup of a 32-bit integer array. +func swapEndianessUint32(d []uint8) { + for i := 0; i < len(d); i += 4 { + d[i], d[i+3] = d[i+3], d[i] + d[i+1], d[i+2] = d[i+2], d[i+1] + } + +} + +type _SYMCRYPT_MD5_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [16]uint8 // little endian + length _UINT64 // native endian + buffer [64]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + // b.chain is little endian, but Go expects big endian, + // we need to swap the bytes. + swapEndianessUint32(b.chain[:]) + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil +} + +func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) + swapEndianessUint32(b.chain[:]) +} + +type _SYMCRYPT_SHA1_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [20]uint8 // big endian + length _UINT64 // native endian + buffer [64]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_SHA1_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil +} + +func (b *_SYMCRYPT_SHA1_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) +} + +type _SYMCRYPT_SHA256_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [32]uint8 // big endian + length _UINT64 // native endian + buffer [64]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_SHA256_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil +} + +func (b *_SYMCRYPT_SHA256_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) +} + +type _SYMCRYPT_SHA512_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [64]uint8 // big endian + lengthL _UINT64 // native endian + lengthH _UINT64 // native endian + buffer [128]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_SHA512_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + if b.lengthH.uint64() != 0 { + return nil, errors.New("exporting state with more than 2^63-1 bytes of data is not supported") + } + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.lengthL), nil +} + +func (b *_SYMCRYPT_SHA512_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.lengthL = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) +} + +func symCryptHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic string, buf []byte) ([]byte, error) { + size, typ := symCryptHashStateInfo(ch) + state := make([]byte, size, _SYMCRYPT_SHA512_STATE_EXPORT_SIZE) // 512 is the largest size + var pinner runtime.Pinner + pinner.Pin(&state[0]) + defer pinner.Unpin() + params := [2]ossl.OSSL_PARAM{ + ossl.OSSL_PARAM_construct_octet_string(_SCOSSL_DIGEST_PARAM_STATE.ptr(), unsafe.Pointer(&state[0]), len(state)), + ossl.OSSL_PARAM_construct_end(), + } + if _, err := ossl.EVP_MD_CTX_get_params(ctx, (ossl.OSSL_PARAM_PTR)(unsafe.Pointer(¶ms[0]))); err != nil { + return nil, err + } + if !ossl.OSSL_PARAM_modified(¶ms[0]) { + return nil, errors.New("EVP_MD_CTX_get_params did not retrieve the state") + } + + header := (*_SYMCRYPT_BLOB_HEADER)(unsafe.Pointer(&state[0])) + if header.magic != _SYMCRYPT_BLOB_MAGIC { + return nil, errors.New("invalid blob magic") + } + if header.size != size { + return nil, errors.New("invalid blob size") + } + if header._type != typ { + return nil, errors.New("invalid blob type") + } + + buf = append(buf, magic...) + switch ch { + case crypto.MD5: + blob := (*_SYMCRYPT_MD5_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + case crypto.SHA1: + blob := (*_SYMCRYPT_SHA1_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + case crypto.SHA224, crypto.SHA256: + blob := (*_SYMCRYPT_SHA256_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + blob := (*_SYMCRYPT_SHA512_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + default: + panic("unsupported hash " + ch.String()) + } +} + +func symCryptHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic string, b []byte) error { + size, typ := symCryptHashStateInfo(ch) + hdr := _SYMCRYPT_BLOB_HEADER{ + magic: _SYMCRYPT_BLOB_MAGIC, + size: size, + _type: typ, + } + var blobPtr unsafe.Pointer + b = b[len(magic):] + switch ch { + case crypto.MD5: + var blob _SYMCRYPT_MD5_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + case crypto.SHA1: + var blob _SYMCRYPT_SHA1_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + case crypto.SHA224, crypto.SHA256: + var blob _SYMCRYPT_SHA256_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + var blob _SYMCRYPT_SHA512_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + default: + panic("unsupported hash " + ch.String()) + } + var checksum int32 = 1 + var pinner runtime.Pinner + pinner.Pin(blobPtr) + pinner.Pin(&checksum) + defer pinner.Unpin() + params := [3]ossl.OSSL_PARAM{ + ossl.OSSL_PARAM_construct_octet_string(_SCOSSL_DIGEST_PARAM_STATE.ptr(), blobPtr, int(hdr.size)), + ossl.OSSL_PARAM_construct_int32(_SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM.ptr(), &checksum), + ossl.OSSL_PARAM_construct_end(), + } + _, err := ossl.EVP_MD_CTX_set_params(ctx, (ossl.OSSL_PARAM_PTR)(unsafe.Pointer(¶ms[0]))) + return err +} + +func symCryptHashStateInfo(ch crypto.Hash) (size, typ uint32) { + switch ch { + case crypto.MD5: + return _SYMCRYPT_MD5_STATE_EXPORT_SIZE, _SymCryptBlobTypeMd5State + case crypto.SHA1: + return _SYMCRYPT_SHA1_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha1State + case crypto.SHA224: + return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha224State + case crypto.SHA256: + return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha256State + case crypto.SHA384: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha384State + case crypto.SHA512_224: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_224State + case crypto.SHA512_256: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_256State + case crypto.SHA512: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512State + default: + panic("unsupported hash " + ch.String()) + } +} + +// isSymCryptHashStateSerializable checks if the SymCrypt hash state is serializable. +func isSymCryptHashStateSerializable(md ossl.EVP_MD_PTR) bool { + ctx, err := ossl.EVP_MD_CTX_new() + if err != nil { + return false + } + defer ossl.EVP_MD_CTX_free(ctx) + if _, err := ossl.EVP_DigestInit_ex(ctx, md, nil); err != nil { + return false + } + params, err := ossl.EVP_MD_CTX_gettable_params(ctx) + if err != nil { + return false + } + if _, err = ossl.OSSL_PARAM_locate_const(params, _SCOSSL_DIGEST_PARAM_STATE.ptr()); err != nil { + return false + } + params, err = ossl.EVP_MD_CTX_settable_params(ctx) + if err != nil { + return false + } + if _, err = ossl.OSSL_PARAM_locate_const(params, _SCOSSL_DIGEST_PARAM_STATE.ptr()); err != nil { + return false + } + if _, err = ossl.OSSL_PARAM_locate_const(params, _SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM.ptr()); err != nil { + return false + } + return true +}