From 5a76325112330d1d5dc1c37748c4d32feb9ac299 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 15 May 2025 16:16:07 +0200 Subject: [PATCH 01/11] support serializing SymCrypt hash objects --- const.go | 2 +- evp.go | 59 ++++--- hash.go | 337 ++++++------------------------------- hash_test.go | 49 ++++-- internal/ossl/ossl.go | 33 +++- internal/ossl/shims.h | 7 + internal/ossl/zossl.c | 50 ++++++ internal/ossl/zossl.go | 31 ++++ internal/ossl/zossl.h | 6 + params.go | 6 +- provideropenssl.go | 239 ++++++++++++++++++++++++++ providersymcrypt.go | 372 +++++++++++++++++++++++++++++++++++++++++ 12 files changed, 864 insertions(+), 327 deletions(-) create mode 100644 provideropenssl.go create mode 100644 providersymcrypt.go diff --git a/const.go b/const.go index e4aaf3bb..9f99a9dc 100644 --- a/const.go +++ b/const.go @@ -50,7 +50,7 @@ const ( //checkheader:ignore // KDF names _OSSL_KDF_NAME_HKDF cString = "HKDF\x00" _OSSL_KDF_NAME_PBKDF2 cString = "PBKDF2\x00" - _OSSL_KDF_NAME_TLS1_PRF cString = "TLS1-PRF\x00" + _OSSL_KDF_NAME_TLS1_PRF cString = "TLS1-PRF\x00" _OSSL_KDF_NAME_TLS13_KDF cString = "TLS13-KDF\x00" _OSSL_MAC_NAME_HMAC cString = "HMAC\x00" diff --git a/evp.go b/evp.go index 20777c07..4189ee27 100644 --- a/evp.go +++ b/evp.go @@ -63,14 +63,22 @@ func hashFuncToMD(fn func() hash.Hash) (ossl.EVP_MD_PTR, error) { return md, nil } +// provider is an identifier for a known provider. +type provider uint8 + +const ( + providerNone provider = iota + providerOSSLDefault + providerOSSLFIPS + providerSymCrypt +) + type hashAlgorithm struct { - md ossl.EVP_MD_PTR - ch crypto.Hash - size int - blockSize int - marshallable bool - magic string - marshalledSize int + md ossl.EVP_MD_PTR + ch crypto.Hash + size int + blockSize int + provider provider } // loadHash converts a crypto.Hash to a EVP_MD. @@ -87,41 +95,25 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { hash.md = ossl.EVP_md4() case crypto.MD5: hash.md = ossl.EVP_md5() - hash.magic = md5Magic - hash.marshalledSize = md5MarshaledSize case crypto.MD5SHA1: hash.md = ossl.EVP_md5_sha1() case crypto.SHA1: hash.md = ossl.EVP_sha1() - hash.magic = sha1Magic - hash.marshalledSize = sha1MarshaledSize case crypto.SHA224: hash.md = ossl.EVP_sha224() - hash.magic = magic224 - hash.marshalledSize = marshaledSize256 case crypto.SHA256: hash.md = ossl.EVP_sha256() - hash.magic = magic256 - hash.marshalledSize = marshaledSize256 case crypto.SHA384: hash.md = ossl.EVP_sha384() - hash.magic = magic384 - hash.marshalledSize = marshaledSize512 case crypto.SHA512: hash.md = ossl.EVP_sha512() - hash.magic = magic512 - hash.marshalledSize = marshaledSize512 case crypto.SHA512_224: if versionAtOrAbove(1, 1, 1) { hash.md = ossl.EVP_sha512_224() - hash.magic = magic512_224 - hash.marshalledSize = marshaledSize512 } case crypto.SHA512_256: if versionAtOrAbove(1, 1, 1) { hash.md = ossl.EVP_sha512_256() - hash.magic = magic512_256 - hash.marshalledSize = marshaledSize512 } case crypto.SHA3_224: if versionAtOrAbove(1, 1, 1) { @@ -159,7 +151,26 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { hash.md = md } } - hash.marshallable = hash.magic != "" && isHashMarshallable(hash.md) + + switch vMajor { + case 1: + hash.provider = providerOSSLDefault + case 3: + if prov := ossl.EVP_MD_get0_provider(hash.md); prov != nil { + cname := ossl.OSSL_PROVIDER_get0_name(prov) + switch C.GoString((*C.char)(unsafe.Pointer(cname))) { + case "default": + hash.provider = providerOSSLDefault + case "fips": + hash.provider = providerOSSLFIPS + case "symcryptprovider": + hash.provider = providerSymCrypt + } + } + default: + panic(errUnsupportedVersion()) + } + cacheMD.Store(ch, &hash) return &hash } diff --git a/hash.go b/hash.go index 033169be..993c785b 100644 --- a/hash.go +++ b/hash.go @@ -15,6 +15,22 @@ import ( "github.com/golang-fips/openssl/v2/internal/ossl" ) +const ( + magicMD5 = "md5\x01" + magic1 = "sha\x01" + magic224 = "sha\x02" + magic256 = "sha\x03" + magic384 = "sha\x04" + magic512_224 = "sha\x05" + magic512_256 = "sha\x06" + magic512 = "sha\x07" + + marshaledSizeMD5 = len(magicMD5) + 4*4 + 64 + 8 + marshaledSize1 = len(magic1) + 5*4 + 64 + 8 + marshaledSize256 = len(magic256) + 8*4 + 64 + 8 + marshaledSize512 = len(magic512) + 8*8 + 128 + 8 +) + // maxHashSize is the size of SHA52 and SHA3_512, the largest hashes we support. const maxHashSize = 64 @@ -207,27 +223,6 @@ func NewSHA3_512() hash.Hash { return newEvpHash(crypto.SHA3_512) } -// isHashMarshallable returns true if the memory layout of md -// is known by this library and can therefore be marshalled. -func isHashMarshallable(md ossl.EVP_MD_PTR) bool { - if vMajor == 1 { - return true - } - prov := ossl.EVP_MD_get0_provider(md) - if prov == nil { - return false - } - cname := ossl.OSSL_PROVIDER_get0_name(prov) - if cname == nil { - return false - } - name := C.GoString((*C.char)(unsafe.Pointer(cname))) - // We only know the memory layout of the built-in providers. - // See evpHash.hashState for more details. - marshallable := name == "default" || name == "fips" - return marshallable -} - // cloneHash is an interface that defines a Clone method. // // hahs.CloneHash will probably be added in Go 1.25, see https://golang.org/issue/69521, @@ -387,299 +382,73 @@ func (h *evpHash) Clone() hash.Hash { return h2 } -// hashState returns a pointer to the internal hash structure. -// -// The EVP_MD_CTX memory layout has changed in OpenSSL 3 -// and the property holding the internal structure is no longer md_data but algctx. -func hashState(ctx ossl.EVP_MD_CTX_PTR) unsafe.Pointer { - switch vMajor { - case 1: - // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12. - type mdCtx struct { - _ [2]unsafe.Pointer - _ C.ulong - md_data unsafe.Pointer - } - return (*mdCtx)(unsafe.Pointer(ctx)).md_data - case 3: - // https://github.com/openssl/openssl/blob/5675a5aaf6a2e489022bcfc18330dae9263e598e/crypto/evp/evp_local.h#L16. - type mdCtx struct { - _ [3]unsafe.Pointer - _ C.ulong - _ [3]unsafe.Pointer - algctx unsafe.Pointer - } - return (*mdCtx)(unsafe.Pointer(ctx)).algctx - default: - panic(errUnsupportedVersion()) - } -} +var errHashNotMarshallable = errors.New("openssl: hash state is not marshallable") func (d *evpHash) MarshalBinary() ([]byte, error) { - if !d.alg.marshallable { - return nil, errors.New("openssl: hash state is not marshallable") - } - buf := make([]byte, 0, d.alg.marshalledSize) + buf := make([]byte, 0, marshaledSize512) // stack allocate the buffer by setting the max size we support return d.AppendBinary(buf) } func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) { defer runtime.KeepAlive(d) d.init() - if !d.alg.marshallable { - return nil, errors.New("openssl: hash state is not marshallable") - } - state := hashState(d.ctx) - if state == nil { - return nil, errors.New("openssl: can't retrieve hash state") - } - var appender interface { - AppendBinary([]byte) ([]byte, error) - } - switch d.alg.ch { - case crypto.MD5: - appender = (*md5State)(state) - case crypto.SHA1: - appender = (*sha1State)(state) - case crypto.SHA224: - appender = (*sha256State)(state) - case crypto.SHA256: - appender = (*sha256State)(state) - case crypto.SHA384: - appender = (*sha512State)(state) - case crypto.SHA512: - appender = (*sha512State)(state) - case crypto.SHA512_224: - appender = (*sha512State)(state) - case crypto.SHA512_256: - appender = (*sha512State)(state) + if magic, _ := cryptoHashEncodingInfo(d.alg.ch); magic == "" { + return nil, errHashNotMarshallable + } + switch d.alg.provider { + case providerOSSLDefault, providerOSSLFIPS: + return osslHashAppendBinary(d.ctx, d.alg.ch, buf) + case providerSymCrypt: + return symCryptHashAppendBinary(d.ctx, d.alg.ch, buf) default: - panic("openssl: unsupported hash function: " + strconv.Itoa(int(d.alg.ch))) + return nil, errHashNotMarshallable } - buf = append(buf, d.alg.magic[:]...) - return appender.AppendBinary(buf) } func (d *evpHash) UnmarshalBinary(b []byte) error { defer runtime.KeepAlive(d) d.init() - if !d.alg.marshallable { - return errors.New("openssl: hash state is not marshallable") + magic, size := cryptoHashEncodingInfo(d.alg.ch) + if magic == "" { + return errHashNotMarshallable } - if len(b) < len(d.alg.magic) || string(b[:len(d.alg.magic)]) != string(d.alg.magic[:]) { + if len(b) < len(magic) || string(b[:len(magic)]) != string(magic[:]) { return errors.New("openssl: invalid hash state identifier") } - if len(b) != d.alg.marshalledSize { + if len(b) != size { return errors.New("openssl: invalid hash state size") } - state := hashState(d.ctx) - if state == nil { - return errors.New("openssl: can't retrieve hash state") - } - b = b[len(d.alg.magic):] - var unmarshaler interface { - UnmarshalBinary([]byte) error + switch d.alg.provider { + case providerOSSLDefault, providerOSSLFIPS: + return osslHashUnmarshalBinary(d.ctx, d.alg.ch, b) + case providerSymCrypt: + return symCryptHashUnmarshalBinary(d.ctx, d.alg.ch, b) + default: + return errHashNotMarshallable } - switch d.alg.ch { +} + +func cryptoHashEncodingInfo(ch crypto.Hash) (magic string, size int) { + switch ch { case crypto.MD5: - unmarshaler = (*md5State)(state) + return magicMD5, marshaledSizeMD5 case crypto.SHA1: - unmarshaler = (*sha1State)(state) + return magic1, marshaledSize1 case crypto.SHA224: - unmarshaler = (*sha256State)(state) + return magic224, marshaledSize256 case crypto.SHA256: - unmarshaler = (*sha256State)(state) + return magic256, marshaledSize256 case crypto.SHA384: - unmarshaler = (*sha512State)(state) - case crypto.SHA512: - unmarshaler = (*sha512State)(state) + return magic384, marshaledSize512 case crypto.SHA512_224: - unmarshaler = (*sha512State)(state) + return magic512_224, marshaledSize512 case crypto.SHA512_256: - unmarshaler = (*sha512State)(state) + return magic512_256, marshaledSize512 + case crypto.SHA512: + return magic512, marshaledSize512 default: - panic("openssl: unsupported hash function: " + strconv.Itoa(int(d.alg.ch))) + return "", 0 } - return unmarshaler.UnmarshalBinary(b) -} - -// md5State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/md5.h#L33. -type md5State struct { - h [4]uint32 - nl, nh uint32 - x [64]byte - nx uint32 -} - -const ( - md5Magic = "md5\x01" - md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8 -) - -func (d *md5State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint32(b) - b, d.h[1] = consumeUint32(b) - b, d.h[2] = consumeUint32(b) - b, d.h[3] = consumeUint32(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = uint32(n << 3) - d.nh = uint32(n >> 29) - d.nx = uint32(n) % 64 - return nil -} - -func (d *md5State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint32(buf, d.h[0]) - buf = appendUint32(buf, d.h[1]) - buf = appendUint32(buf, d.h[2]) - buf = appendUint32(buf, d.h[3]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) - return buf, nil -} - -// sha1State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L34. -type sha1State struct { - h [5]uint32 - nl, nh uint32 - x [64]byte - nx uint32 -} - -const ( - sha1Magic = "sha\x01" - sha1MarshaledSize = len(sha1Magic) + 5*4 + 64 + 8 -) - -func (d *sha1State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint32(b) - b, d.h[1] = consumeUint32(b) - b, d.h[2] = consumeUint32(b) - b, d.h[3] = consumeUint32(b) - b, d.h[4] = consumeUint32(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = uint32(n << 3) - d.nh = uint32(n >> 29) - d.nx = uint32(n) % 64 - return nil -} - -func (d *sha1State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint32(buf, d.h[0]) - buf = appendUint32(buf, d.h[1]) - buf = appendUint32(buf, d.h[2]) - buf = appendUint32(buf, d.h[3]) - buf = appendUint32(buf, d.h[4]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) - return buf, nil -} - -const ( - magic224 = "sha\x02" - magic256 = "sha\x03" - marshaledSize256 = len(magic256) + 8*4 + 64 + 8 -) - -// sha256State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L51. -type sha256State struct { - h [8]uint32 - nl, nh uint32 - x [64]byte - nx uint32 -} - -func (d *sha256State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint32(b) - b, d.h[1] = consumeUint32(b) - b, d.h[2] = consumeUint32(b) - b, d.h[3] = consumeUint32(b) - b, d.h[4] = consumeUint32(b) - b, d.h[5] = consumeUint32(b) - b, d.h[6] = consumeUint32(b) - b, d.h[7] = consumeUint32(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = uint32(n << 3) - d.nh = uint32(n >> 29) - d.nx = uint32(n) % 64 - return nil -} - -func (d *sha256State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint32(buf, d.h[0]) - buf = appendUint32(buf, d.h[1]) - buf = appendUint32(buf, d.h[2]) - buf = appendUint32(buf, d.h[3]) - buf = appendUint32(buf, d.h[4]) - buf = appendUint32(buf, d.h[5]) - buf = appendUint32(buf, d.h[6]) - buf = appendUint32(buf, d.h[7]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) - return buf, nil -} - -// sha512State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L95. -type sha512State struct { - h [8]uint64 - nl, nh uint64 - x [128]byte - nx uint32 -} - -const ( - magic384 = "sha\x04" - magic512_224 = "sha\x05" - magic512_256 = "sha\x06" - magic512 = "sha\x07" - marshaledSize512 = len(magic512) + 8*8 + 128 + 8 -) - -func (d *sha512State) MarshalBinary() ([]byte, error) { - buf := make([]byte, 0, marshaledSize512) - return d.AppendBinary(buf) -} - -func (d *sha512State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint64(b) - b, d.h[1] = consumeUint64(b) - b, d.h[2] = consumeUint64(b) - b, d.h[3] = consumeUint64(b) - b, d.h[4] = consumeUint64(b) - b, d.h[5] = consumeUint64(b) - b, d.h[6] = consumeUint64(b) - b, d.h[7] = consumeUint64(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = n << 3 - d.nh = n >> 61 - d.nx = uint32(n) % 128 - return nil -} - -func (d *sha512State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint64(buf, d.h[0]) - buf = appendUint64(buf, d.h[1]) - buf = appendUint64(buf, d.h[2]) - buf = appendUint64(buf, d.h[3]) - buf = appendUint64(buf, d.h[4]) - buf = appendUint64(buf, d.h[5]) - buf = appendUint64(buf, d.h[6]) - buf = appendUint64(buf, d.h[7]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, d.nl>>3|d.nh<<61) - return buf, nil } // appendUint64 appends x into b as a big endian byte sequence. diff --git a/hash_test.go b/hash_test.go index 569e51c7..4c3ceb16 100644 --- a/hash_test.go +++ b/hash_test.go @@ -10,6 +10,13 @@ import ( "strings" "testing" + // Blank imports to ensure that the hash functions are registered. + _ "crypto/md5" + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha3" + _ "crypto/sha512" + "github.com/golang-fips/openssl/v2" ) @@ -94,6 +101,13 @@ func TestHash(t *testing.T) { } } +type hashEncoding interface { + hash.Hash + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler + AppendBinary(b []byte) ([]byte, error) +} + func TestHash_BinaryMarshaler(t *testing.T) { msg := []byte("testing") for _, ch := range hashes { @@ -103,10 +117,7 @@ func TestHash_BinaryMarshaler(t *testing.T) { t.Skip("hash not supported") } - hashMarshaler, ok := cryptoToHash(ch)().(interface { - hash.Hash - encoding.BinaryMarshaler - }) + hashMarshaler, ok := cryptoToHash(ch)().(hashEncoding) if !ok { t.Fatal("BinaryMarshaler not supported") } @@ -123,10 +134,7 @@ func TestHash_BinaryMarshaler(t *testing.T) { t.Fatalf("MarshalBinary failed: %v", err) } - hashUnmarshaler := cryptoToHash(ch)().(interface { - hash.Hash - encoding.BinaryUnmarshaler - }) + hashUnmarshaler := cryptoToHash(ch)().(hashEncoding) if err := hashUnmarshaler.UnmarshalBinary(state); err != nil { t.Fatalf("UnmarshalBinary failed: %v", err) } @@ -134,6 +142,21 @@ func TestHash_BinaryMarshaler(t *testing.T) { if actual, actual2 := hashMarshaler.Sum(nil), hashUnmarshaler.Sum(nil); !bytes.Equal(actual, actual2) { t.Errorf("0x%x != appended 0x%x", actual, actual2) } + + // Test that the hash state is compatible with native Go. + h := ch.New().(hashEncoding) + h.Write(msg) + stateh, err := h.(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + t.Error(err) + } + if !bytes.Equal(state, stateh) { + t.Errorf("got 0x%x != want 0x%x", state, stateh) + } + h = ch.New().(hashEncoding) + if err := h.UnmarshalBinary(state); err != nil { + t.Error(err) + } }) } } @@ -146,10 +169,7 @@ func TestHash_BinaryAppender(t *testing.T) { t.Skip("not supported") } - hashWithBinaryAppender, ok := cryptoToHash(ch)().(interface { - hash.Hash - AppendBinary(b []byte) ([]byte, error) - }) + hashWithBinaryAppender, ok := cryptoToHash(ch)().(hashEncoding) if !ok { t.Fatal("AppendBinary not supported") } @@ -181,10 +201,7 @@ func TestHash_BinaryAppender(t *testing.T) { // Use only the newly appended part of the slice appendedState := state[10:] - h2, ok := cryptoToHash(ch)().(interface { - hash.Hash - encoding.BinaryUnmarshaler - }) + h2, ok := cryptoToHash(ch)().(hashEncoding) if !ok { t.Skip("not supported") } diff --git a/internal/ossl/ossl.go b/internal/ossl/ossl.go index 4a64c2f0..7d18d727 100644 --- a/internal/ossl/ossl.go +++ b/internal/ossl/ossl.go @@ -22,7 +22,10 @@ go_hash_sum(const _EVP_MD_CTX_PTR ctx, _EVP_MD_CTX_PTR ctx2, unsigned char *out, } */ import "C" -import "unsafe" +import ( + "math" + "unsafe" +) func HashSum(ctx1, ctx2 EVP_MD_CTX_PTR, out []byte) error { var errst C.mkcgo_err_state @@ -38,3 +41,31 @@ func HashSum(ctx1, ctx2 EVP_MD_CTX_PTR, out []byte) error { } return nil } + +// OSSL_PARAM is a structure to pass or request object parameters. +// https://docs.openssl.org/3.0/man3/OSSL_PARAM/. +type OSSL_PARAM struct { + Key *byte + DataType uint32 + Data unsafe.Pointer + DataSize int + ReturnSize int +} + +func ossl_param_construct(key *byte, dataType uint32, data unsafe.Pointer, dataSize int) OSSL_PARAM { + return OSSL_PARAM{ + Key: key, + DataType: dataType, + Data: data, + DataSize: dataSize, + ReturnSize: math.MaxInt - 1, + } +} + +func OSSL_PARAM_construct_octet_string(key *byte, data unsafe.Pointer, dataSize int) OSSL_PARAM { + return ossl_param_construct(key, OSSL_PARAM_OCTET_STRING, data, dataSize) +} + +func OSSL_PARAM_construct_end() OSSL_PARAM { + return OSSL_PARAM{} +} diff --git a/internal/ossl/shims.h b/internal/ossl/shims.h index 012d60f8..76a08503 100644 --- a/internal/ossl/shims.h +++ b/internal/ossl/shims.h @@ -86,6 +86,8 @@ enum { _EVP_PKEY_CTRL_RSA_OAEP_LABEL = 0x100A, _EVP_PKEY_CTRL_DSA_PARAMGEN_BITS = 0x1001, _EVP_PKEY_CTRL_DSA_PARAMGEN_Q_BITS = 0x1002, + + _OSSL_PARAM_OCTET_STRING = 5, }; typedef void* _OPENSSL_INIT_SETTINGS_PTR; @@ -189,6 +191,10 @@ _EVP_MD_CTX_PTR EVP_MD_CTX_new(void); void EVP_MD_CTX_free(_EVP_MD_CTX_PTR ctx); int EVP_MD_CTX_copy(_EVP_MD_CTX_PTR out, const _EVP_MD_CTX_PTR in) __attribute__((noescape,nocallback)); int EVP_MD_CTX_copy_ex(_EVP_MD_CTX_PTR out, const _EVP_MD_CTX_PTR in); +_OSSL_PARAM_PTR EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); +_OSSL_PARAM_PTR EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); +int EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR ctx, _OSSL_PARAM_PTR params) __attribute__((tag("3"))); +int EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR ctx, const _OSSL_PARAM_PTR params) __attribute__((tag("3"))); int EVP_Digest(const void *data, size_t count, unsigned char *md, unsigned int *size, const _EVP_MD_PTR type, _ENGINE_PTR impl) __attribute__((noescape,nocallback,nocheckptr("data"))); int EVP_DigestInit_ex(_EVP_MD_CTX_PTR ctx, const _EVP_MD_PTR type, _ENGINE_PTR impl); int EVP_DigestInit(_EVP_MD_CTX_PTR ctx, const _EVP_MD_PTR type); @@ -353,6 +359,7 @@ int EVP_MAC_final(_EVP_MAC_CTX_PTR ctx, unsigned char *out, size_t *outl, size_t // OSSL_PARAM API void OSSL_PARAM_free(_OSSL_PARAM_PTR p) __attribute__((tag("3"))); +const _OSSL_PARAM_PTR OSSL_PARAM_locate_const(const _OSSL_PARAM_PTR p, const char *key) __attribute__((tag("3"))); _OSSL_PARAM_BLD_PTR OSSL_PARAM_BLD_new(void) __attribute__((tag("3"))); void OSSL_PARAM_BLD_free(_OSSL_PARAM_BLD_PTR bld) __attribute__((tag("3"))); _OSSL_PARAM_PTR OSSL_PARAM_BLD_to_param(_OSSL_PARAM_BLD_PTR bld) __attribute__((tag("3"))); diff --git a/internal/ossl/zossl.c b/internal/ossl/zossl.c index 0219d15b..6d33f49d 100644 --- a/internal/ossl/zossl.c +++ b/internal/ossl/zossl.c @@ -98,7 +98,11 @@ int (*_g_EVP_MAC_update)(_EVP_MAC_CTX_PTR, const unsigned char*, size_t); int (*_g_EVP_MD_CTX_copy)(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR); int (*_g_EVP_MD_CTX_copy_ex)(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR); void (*_g_EVP_MD_CTX_free)(_EVP_MD_CTX_PTR); +int (*_g_EVP_MD_CTX_get_params)(_EVP_MD_CTX_PTR, _OSSL_PARAM_PTR); +_OSSL_PARAM_PTR (*_g_EVP_MD_CTX_gettable_params)(const _EVP_MD_CTX_PTR); _EVP_MD_CTX_PTR (*_g_EVP_MD_CTX_new)(void); +int (*_g_EVP_MD_CTX_set_params)(_EVP_MD_CTX_PTR, const _OSSL_PARAM_PTR); +_OSSL_PARAM_PTR (*_g_EVP_MD_CTX_settable_params)(const _EVP_MD_CTX_PTR); _EVP_MD_PTR (*_g_EVP_MD_fetch)(_OSSL_LIB_CTX_PTR, const char*, const char*); void (*_g_EVP_MD_free)(_EVP_MD_PTR); const char* (*_g_EVP_MD_get0_name)(const _EVP_MD_PTR); @@ -212,6 +216,7 @@ int (*_g_OSSL_PARAM_BLD_push_octet_string)(_OSSL_PARAM_BLD_PTR, const char*, con int (*_g_OSSL_PARAM_BLD_push_utf8_string)(_OSSL_PARAM_BLD_PTR, const char*, const char*, size_t); _OSSL_PARAM_PTR (*_g_OSSL_PARAM_BLD_to_param)(_OSSL_PARAM_BLD_PTR); void (*_g_OSSL_PARAM_free)(_OSSL_PARAM_PTR); +const _OSSL_PARAM_PTR (*_g_OSSL_PARAM_locate_const)(const _OSSL_PARAM_PTR, const char*); int (*_g_OSSL_PROVIDER_available)(_OSSL_LIB_CTX_PTR, const char*); const char* (*_g_OSSL_PROVIDER_get0_name)(const _OSSL_PROVIDER_PTR); _OSSL_PROVIDER_PTR (*_g_OSSL_PROVIDER_try_load)(_OSSL_LIB_CTX_PTR, const char*, int); @@ -494,6 +499,10 @@ void __mkcgo_load_3(void* handle) { __mkcgo__dlsym(EVP_MAC_final) __mkcgo__dlsym(EVP_MAC_init) __mkcgo__dlsym(EVP_MAC_update) + __mkcgo__dlsym(EVP_MD_CTX_get_params) + __mkcgo__dlsym(EVP_MD_CTX_gettable_params) + __mkcgo__dlsym(EVP_MD_CTX_set_params) + __mkcgo__dlsym(EVP_MD_CTX_settable_params) __mkcgo__dlsym(EVP_MD_fetch) __mkcgo__dlsym(EVP_MD_free) __mkcgo__dlsym(EVP_MD_get0_name) @@ -531,6 +540,7 @@ void __mkcgo_load_3(void* handle) { __mkcgo__dlsym(OSSL_PARAM_BLD_push_utf8_string) __mkcgo__dlsym(OSSL_PARAM_BLD_to_param) __mkcgo__dlsym(OSSL_PARAM_free) + __mkcgo__dlsym(OSSL_PARAM_locate_const) __mkcgo__dlsym(OSSL_PROVIDER_available) __mkcgo__dlsym(OSSL_PROVIDER_get0_name) __mkcgo__dlsym(OSSL_PROVIDER_try_load) @@ -557,6 +567,10 @@ void __mkcgo_unload_3() { _g_EVP_MAC_final = NULL; _g_EVP_MAC_init = NULL; _g_EVP_MAC_update = NULL; + _g_EVP_MD_CTX_get_params = NULL; + _g_EVP_MD_CTX_gettable_params = NULL; + _g_EVP_MD_CTX_set_params = NULL; + _g_EVP_MD_CTX_settable_params = NULL; _g_EVP_MD_fetch = NULL; _g_EVP_MD_free = NULL; _g_EVP_MD_get0_name = NULL; @@ -594,6 +608,7 @@ void __mkcgo_unload_3() { _g_OSSL_PARAM_BLD_push_utf8_string = NULL; _g_OSSL_PARAM_BLD_to_param = NULL; _g_OSSL_PARAM_free = NULL; + _g_OSSL_PARAM_locate_const = NULL; _g_OSSL_PROVIDER_available = NULL; _g_OSSL_PROVIDER_get0_name = NULL; _g_OSSL_PROVIDER_try_load = NULL; @@ -1253,6 +1268,20 @@ void _mkcgo_EVP_MD_CTX_free(_EVP_MD_CTX_PTR _arg0) { _g_EVP_MD_CTX_free(_arg0); } +int _mkcgo_EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR _arg0, _OSSL_PARAM_PTR _arg1, mkcgo_err_state *_err_state) { + mkcgo_err_clear(); + int _ret = _g_EVP_MD_CTX_get_params(_arg0, _arg1); + if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); + return _ret; +} + +_OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { + mkcgo_err_clear(); + _OSSL_PARAM_PTR _ret = _g_EVP_MD_CTX_gettable_params(_arg0); + if (_ret == NULL) *_err_state = mkcgo_err_retrieve(); + return _ret; +} + _EVP_MD_CTX_PTR _mkcgo_EVP_MD_CTX_new(mkcgo_err_state *_err_state) { mkcgo_err_clear(); _EVP_MD_CTX_PTR _ret = _g_EVP_MD_CTX_new(); @@ -1260,6 +1289,20 @@ _EVP_MD_CTX_PTR _mkcgo_EVP_MD_CTX_new(mkcgo_err_state *_err_state) { return _ret; } +int _mkcgo_EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR _arg0, const _OSSL_PARAM_PTR _arg1, mkcgo_err_state *_err_state) { + mkcgo_err_clear(); + int _ret = _g_EVP_MD_CTX_set_params(_arg0, _arg1); + if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); + return _ret; +} + +_OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { + mkcgo_err_clear(); + _OSSL_PARAM_PTR _ret = _g_EVP_MD_CTX_settable_params(_arg0); + if (_ret == NULL) *_err_state = mkcgo_err_retrieve(); + return _ret; +} + _EVP_MD_PTR _mkcgo_EVP_MD_fetch(_OSSL_LIB_CTX_PTR _arg0, const char* _arg1, const char* _arg2, mkcgo_err_state *_err_state) { mkcgo_err_clear(); _EVP_MD_PTR _ret = _g_EVP_MD_fetch(_arg0, _arg1, _arg2); @@ -1924,6 +1967,13 @@ void _mkcgo_OSSL_PARAM_free(_OSSL_PARAM_PTR _arg0) { _g_OSSL_PARAM_free(_arg0); } +const _OSSL_PARAM_PTR _mkcgo_OSSL_PARAM_locate_const(const _OSSL_PARAM_PTR _arg0, const char* _arg1, mkcgo_err_state *_err_state) { + mkcgo_err_clear(); + const _OSSL_PARAM_PTR _ret = _g_OSSL_PARAM_locate_const(_arg0, _arg1); + if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); + return _ret; +} + int _mkcgo_OSSL_PROVIDER_available(_OSSL_LIB_CTX_PTR _arg0, const char* _arg1) { return _g_OSSL_PROVIDER_available(_arg0, _arg1); } diff --git a/internal/ossl/zossl.go b/internal/ossl/zossl.go index 568e410a..94a3dc6b 100644 --- a/internal/ossl/zossl.go +++ b/internal/ossl/zossl.go @@ -61,6 +61,7 @@ const ( EVP_PKEY_CTRL_RSA_OAEP_LABEL = 0x100A EVP_PKEY_CTRL_DSA_PARAMGEN_BITS = 0x1001 EVP_PKEY_CTRL_DSA_PARAMGEN_Q_BITS = 0x1002 + OSSL_PARAM_OCTET_STRING = 5 ) type BIGNUM_PTR = C._BIGNUM_PTR @@ -612,12 +613,36 @@ func EVP_MD_CTX_free(ctx EVP_MD_CTX_PTR) { C._mkcgo_EVP_MD_CTX_free(ctx) } +func EVP_MD_CTX_get_params(ctx EVP_MD_CTX_PTR, params OSSL_PARAM_PTR) (int32, error) { + var _err C.mkcgo_err_state + _ret := C._mkcgo_EVP_MD_CTX_get_params(ctx, params, mkcgoNoEscape(&_err)) + return int32(_ret), newMkcgoErr("EVP_MD_CTX_get_params", _err) +} + +func EVP_MD_CTX_gettable_params(ctx EVP_MD_CTX_PTR) (OSSL_PARAM_PTR, error) { + var _err C.mkcgo_err_state + _ret := C._mkcgo_EVP_MD_CTX_gettable_params(ctx, mkcgoNoEscape(&_err)) + return _ret, newMkcgoErr("EVP_MD_CTX_gettable_params", _err) +} + func EVP_MD_CTX_new() (EVP_MD_CTX_PTR, error) { var _err C.mkcgo_err_state _ret := C._mkcgo_EVP_MD_CTX_new(mkcgoNoEscape(&_err)) return _ret, newMkcgoErr("EVP_MD_CTX_new", _err) } +func EVP_MD_CTX_set_params(ctx EVP_MD_CTX_PTR, params OSSL_PARAM_PTR) (int32, error) { + var _err C.mkcgo_err_state + _ret := C._mkcgo_EVP_MD_CTX_set_params(ctx, params, mkcgoNoEscape(&_err)) + return int32(_ret), newMkcgoErr("EVP_MD_CTX_set_params", _err) +} + +func EVP_MD_CTX_settable_params(ctx EVP_MD_CTX_PTR) (OSSL_PARAM_PTR, error) { + var _err C.mkcgo_err_state + _ret := C._mkcgo_EVP_MD_CTX_settable_params(ctx, mkcgoNoEscape(&_err)) + return _ret, newMkcgoErr("EVP_MD_CTX_settable_params", _err) +} + func EVP_MD_fetch(ctx OSSL_LIB_CTX_PTR, algorithm *byte, properties *byte) (EVP_MD_PTR, error) { var _err C.mkcgo_err_state _ret := C._mkcgo_EVP_MD_fetch(ctx, (*C.char)(unsafe.Pointer(algorithm)), (*C.char)(unsafe.Pointer(properties)), mkcgoNoEscape(&_err)) @@ -1218,6 +1243,12 @@ func OSSL_PARAM_free(p OSSL_PARAM_PTR) { C._mkcgo_OSSL_PARAM_free(p) } +func OSSL_PARAM_locate_const(p OSSL_PARAM_PTR, key *byte) (OSSL_PARAM_PTR, error) { + var _err C.mkcgo_err_state + _ret := C._mkcgo_OSSL_PARAM_locate_const(p, (*C.char)(unsafe.Pointer(key)), mkcgoNoEscape(&_err)) + return _ret, newMkcgoErr("OSSL_PARAM_locate_const", _err) +} + func OSSL_PROVIDER_available(libctx OSSL_LIB_CTX_PTR, name *byte) int32 { return int32(C._mkcgo_OSSL_PROVIDER_available(libctx, (*C.char)(unsafe.Pointer(name)))) } diff --git a/internal/ossl/zossl.h b/internal/ossl/zossl.h index 66249349..4f55cd7c 100644 --- a/internal/ossl/zossl.h +++ b/internal/ossl/zossl.h @@ -83,6 +83,7 @@ enum { _EVP_PKEY_CTRL_RSA_OAEP_LABEL = 0x100A, _EVP_PKEY_CTRL_DSA_PARAMGEN_BITS = 0x1001, _EVP_PKEY_CTRL_DSA_PARAMGEN_Q_BITS = 0x1002, + _OSSL_PARAM_OCTET_STRING = 5, }; typedef void* mkcgo_err_state; @@ -195,7 +196,11 @@ int _mkcgo_EVP_MAC_update(_EVP_MAC_CTX_PTR, const unsigned char*, size_t, mkcgo_ int _mkcgo_EVP_MD_CTX_copy(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR, mkcgo_err_state *); int _mkcgo_EVP_MD_CTX_copy_ex(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR, mkcgo_err_state *); void _mkcgo_EVP_MD_CTX_free(_EVP_MD_CTX_PTR); +int _mkcgo_EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR, _OSSL_PARAM_PTR, mkcgo_err_state *); +_OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR, mkcgo_err_state *); _EVP_MD_CTX_PTR _mkcgo_EVP_MD_CTX_new(mkcgo_err_state *); +int _mkcgo_EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR, const _OSSL_PARAM_PTR, mkcgo_err_state *); +_OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR, mkcgo_err_state *); _EVP_MD_PTR _mkcgo_EVP_MD_fetch(_OSSL_LIB_CTX_PTR, const char*, const char*, mkcgo_err_state *); void _mkcgo_EVP_MD_free(_EVP_MD_PTR); const char* _mkcgo_EVP_MD_get0_name(const _EVP_MD_PTR); @@ -311,6 +316,7 @@ int _mkcgo_OSSL_PARAM_BLD_push_octet_string(_OSSL_PARAM_BLD_PTR, const char*, co int _mkcgo_OSSL_PARAM_BLD_push_utf8_string(_OSSL_PARAM_BLD_PTR, const char*, const char*, size_t, mkcgo_err_state *); _OSSL_PARAM_PTR _mkcgo_OSSL_PARAM_BLD_to_param(_OSSL_PARAM_BLD_PTR, mkcgo_err_state *); void _mkcgo_OSSL_PARAM_free(_OSSL_PARAM_PTR); +const _OSSL_PARAM_PTR _mkcgo_OSSL_PARAM_locate_const(const _OSSL_PARAM_PTR, const char*, mkcgo_err_state *); int _mkcgo_OSSL_PROVIDER_available(_OSSL_LIB_CTX_PTR, const char*); const char* _mkcgo_OSSL_PROVIDER_get0_name(const _OSSL_PROVIDER_PTR); _OSSL_PROVIDER_PTR _mkcgo_OSSL_PROVIDER_try_load(_OSSL_LIB_CTX_PTR, const char*, int, mkcgo_err_state *); diff --git a/params.go b/params.go index a5b6cdb9..337bba9a 100644 --- a/params.go +++ b/params.go @@ -37,13 +37,17 @@ type paramBuilder struct { // newParamBuilder creates a new paramBuilder. func newParamBuilder() (*paramBuilder, error) { + return newParamBuilderN(8) // the maximum known number of BIGNUMs to free are 8 for RSA +} + +func newParamBuilderN(n int) (*paramBuilder, error) { bld, err := ossl.OSSL_PARAM_BLD_new() if err != nil { return nil, err } pb := ¶mBuilder{ bld: bld, - bnToFree: make([]bnParam, 0, 8), // the maximum known number of BIGNUMs to free are 8 for RSA + bnToFree: make([]bnParam, 0, n), } runtime.SetFinalizer(pb, (*paramBuilder).finalize) return pb, nil diff --git a/provideropenssl.go b/provideropenssl.go new file mode 100644 index 00000000..c1e699cd --- /dev/null +++ b/provideropenssl.go @@ -0,0 +1,239 @@ +package openssl + +import ( + "crypto" + "errors" + "unsafe" + + "github.com/golang-fips/openssl/v2/internal/ossl" +) + +// This file contains code specific to the built-in OpenSSL providers. + +// _OSSL_MD5_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/md5.h#L33. +type _OSSL_MD5_CTX struct { + h [4]uint32 + nl, nh uint32 + x [64]byte + nx uint32 +} + +func (d *_OSSL_MD5_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint32(b) + b, d.h[1] = consumeUint32(b) + b, d.h[2] = consumeUint32(b) + b, d.h[3] = consumeUint32(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = uint32(n << 3) + d.nh = uint32(n >> 29) + d.nx = uint32(n) % 64 + return nil +} + +func (d *_OSSL_MD5_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint32(buf, d.h[0]) + buf = appendUint32(buf, d.h[1]) + buf = appendUint32(buf, d.h[2]) + buf = appendUint32(buf, d.h[3]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) + return buf, nil +} + +// _OSSL_SHA_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L34. +type _OSSL_SHA_CTX struct { + h [5]uint32 + nl, nh uint32 + x [64]byte + nx uint32 +} + +func (d *_OSSL_SHA_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint32(b) + b, d.h[1] = consumeUint32(b) + b, d.h[2] = consumeUint32(b) + b, d.h[3] = consumeUint32(b) + b, d.h[4] = consumeUint32(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = uint32(n << 3) + d.nh = uint32(n >> 29) + d.nx = uint32(n) % 64 + return nil +} + +func (d *_OSSL_SHA_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint32(buf, d.h[0]) + buf = appendUint32(buf, d.h[1]) + buf = appendUint32(buf, d.h[2]) + buf = appendUint32(buf, d.h[3]) + buf = appendUint32(buf, d.h[4]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) + return buf, nil +} + +// _OSSL_SHA256_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L51. +type _OSSL_SHA256_CTX struct { + h [8]uint32 + nl, nh uint32 + x [64]byte + nx uint32 +} + +func (d *_OSSL_SHA256_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint32(b) + b, d.h[1] = consumeUint32(b) + b, d.h[2] = consumeUint32(b) + b, d.h[3] = consumeUint32(b) + b, d.h[4] = consumeUint32(b) + b, d.h[5] = consumeUint32(b) + b, d.h[6] = consumeUint32(b) + b, d.h[7] = consumeUint32(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = uint32(n << 3) + d.nh = uint32(n >> 29) + d.nx = uint32(n) % 64 + return nil +} + +func (d *_OSSL_SHA256_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint32(buf, d.h[0]) + buf = appendUint32(buf, d.h[1]) + buf = appendUint32(buf, d.h[2]) + buf = appendUint32(buf, d.h[3]) + buf = appendUint32(buf, d.h[4]) + buf = appendUint32(buf, d.h[5]) + buf = appendUint32(buf, d.h[6]) + buf = appendUint32(buf, d.h[7]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) + return buf, nil +} + +// _OSSL_SHA512_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L95. +type _OSSL_SHA512_CTX struct { + h [8]uint64 + nl, nh uint64 + x [128]byte + nx uint32 +} + +func (d *_OSSL_SHA512_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint64(b) + b, d.h[1] = consumeUint64(b) + b, d.h[2] = consumeUint64(b) + b, d.h[3] = consumeUint64(b) + b, d.h[4] = consumeUint64(b) + b, d.h[5] = consumeUint64(b) + b, d.h[6] = consumeUint64(b) + b, d.h[7] = consumeUint64(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = n << 3 + d.nh = n >> 61 + d.nx = uint32(n) % 128 + return nil +} + +func (d *_OSSL_SHA512_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint64(buf, d.h[0]) + buf = appendUint64(buf, d.h[1]) + buf = appendUint64(buf, d.h[2]) + buf = appendUint64(buf, d.h[3]) + buf = appendUint64(buf, d.h[4]) + buf = appendUint64(buf, d.h[5]) + buf = appendUint64(buf, d.h[6]) + buf = appendUint64(buf, d.h[7]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, d.nl>>3|d.nh<<61) + return buf, nil +} + +func getOSSLDigetsContext(ctx ossl.EVP_MD_CTX_PTR) unsafe.Pointer { + switch vMajor { + case 1: + // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12. + type mdCtx struct { + _ [2]unsafe.Pointer + _ uint32 + md_data unsafe.Pointer + } + return (*mdCtx)(unsafe.Pointer(ctx)).md_data + case 3: + // The EVP_MD_CTX memory layout has changed in OpenSSL 3 + // and the property holding the internal structure is no longer md_data but algctx. + // https://github.com/openssl/openssl/blob/5675a5aaf6a2e489022bcfc18330dae9263e598e/crypto/evp/evp_local.h#L16. + type mdCtx struct { + _ [3]unsafe.Pointer + _ uint32 + _ [3]unsafe.Pointer + algctx unsafe.Pointer + } + return (*mdCtx)(unsafe.Pointer(ctx)).algctx + default: + panic(errUnsupportedVersion()) + } +} + +var errHashStateInvalid = errors.New("openssl: can't retrieve hash state") + +func osslHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, buf []byte) ([]byte, error) { + algctx := getOSSLDigetsContext(ctx) + if algctx == nil { + return nil, errHashStateInvalid + } + magic, _ := cryptoHashEncodingInfo(ch) + buf = append(buf, magic...) + switch ch { + case crypto.MD5: + d := (*_OSSL_MD5_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + case crypto.SHA1: + d := (*_OSSL_SHA_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + case crypto.SHA224, crypto.SHA256: + d := (*_OSSL_SHA256_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + d := (*_OSSL_SHA512_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + default: + panic("unsupported hash " + ch.String()) + } +} + +func osslHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, b []byte) error { + algctx := getOSSLDigetsContext(ctx) + if algctx == nil { + return errHashStateInvalid + } + magic, _ := cryptoHashEncodingInfo(ch) + b = b[len(magic):] + switch ch { + case crypto.MD5: + d := (*_OSSL_MD5_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + case crypto.SHA1: + d := (*_OSSL_SHA_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + case crypto.SHA224, crypto.SHA256: + d := (*_OSSL_SHA256_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + d := (*_OSSL_SHA512_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + default: + panic("unsupported hash " + ch.String()) + } +} diff --git a/providersymcrypt.go b/providersymcrypt.go new file mode 100644 index 00000000..456a2f1a --- /dev/null +++ b/providersymcrypt.go @@ -0,0 +1,372 @@ +package openssl + +import ( + "crypto" + "errors" + "runtime" + "sync" + "unsafe" + + "github.com/golang-fips/openssl/v2/internal/ossl" +) + +// This file contains code specific to the SymCrypt provider. + +const ( + _SCOSSL_DIGEST_PARAM_STATE cString = "state\x00" + _SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM cString = "recompute_checksum\x00" +) + +const ( + _SYMCRYPT_BLOB_MAGIC = 0x636D7973 // "cysm" in little-endian + + _SymCryptBlobTypeHashState = 0x100 + _SymCryptBlobTypeMd2State = _SymCryptBlobTypeHashState + 1 + _SymCryptBlobTypeMd4State = _SymCryptBlobTypeHashState + 2 + _SymCryptBlobTypeMd5State = _SymCryptBlobTypeHashState + 3 + _SymCryptBlobTypeSha1State = _SymCryptBlobTypeHashState + 4 + _SymCryptBlobTypeSha256State = _SymCryptBlobTypeHashState + 5 + _SymCryptBlobTypeSha384State = _SymCryptBlobTypeHashState + 6 + _SymCryptBlobTypeSha512State = _SymCryptBlobTypeHashState + 7 + _SymCryptBlobTypeSha3_256State = _SymCryptBlobTypeHashState + 8 + _SymCryptBlobTypeSha3_384State = _SymCryptBlobTypeHashState + 9 + _SymCryptBlobTypeSha3_512State = _SymCryptBlobTypeHashState + 10 + _SymCryptBlobTypeSha224State = _SymCryptBlobTypeHashState + 11 + _SymCryptBlobTypeSha512_224State = _SymCryptBlobTypeHashState + 12 + _SymCryptBlobTypeSha512_256State = _SymCryptBlobTypeHashState + 13 + _SymCryptBlobTypeSha3_224State = _SymCryptBlobTypeHashState + 14 + + _SYMCRYPT_MD5_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_MD5_STATE_EXPORT_BLOB{})) + _SYMCRYPT_SHA1_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_SHA1_STATE_EXPORT_BLOB{})) + _SYMCRYPT_SHA256_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_SHA256_STATE_EXPORT_BLOB{})) + _SYMCRYPT_SHA512_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_SHA512_STATE_EXPORT_BLOB{})) +) + +type _SYMCRYPT_BLOB_HEADER struct { + magic uint32 + size uint32 + _type uint32 +} + +type _SYMCRYPT_BLOB_TRAILER struct { + checksum [8]uint8 +} + +// _UINT64 is a 64-bit unsigned integer, stored in native endianess. +// It is used to represent a SymCrypt UINT64 type without making the +// parent struct 8-byte aligned, given that the Windows ABI makes +// the struct 4-byte aligned. +type _UINT64 [2]uint32 + +func newUINT64(v uint64) _UINT64 { + var u _UINT64 + if isBigEndian { + u[0], u[1] = uint32(v>>32), uint32(v) + } else { + u[0], u[1] = uint32(v), uint32(v>>32) + } + return u +} + +func (u *_UINT64) uint64() uint64 { + if isBigEndian { + return uint64(u[0])<<32 | (uint64(u[1])) + } + return uint64(u[0]) | (uint64(u[1]) << 32) +} + +// symCryptAppendBinary appends the binary representation of a SymCrypt state +// to the given destination slice. +func symCryptAppendBinary(dst, chain, buffer []byte, blength _UINT64) []byte { + length := blength.uint64() + nx := length & 0x3f + dst = append(dst, chain...) + dst = append(dst, buffer[:nx]...) + dst = append(dst, make([]byte, len(buffer)-int(nx))...) + dst = appendUint64(dst, length) + return dst +} + +// symCryptUnmarshalBinary unmarshals the binary representation of a SymCrypt state +// from the given source slice. It returns the length of the data. +func symCryptUnmarshalBinary(d []byte, chain, buffer []byte) _UINT64 { + copy(chain[:], d) + d = d[len(chain):] + copy(buffer[:], d) + d = d[len(buffer):] + _, length := consumeUint64(d) + return newUINT64(length) +} + +type _SYMCRYPT_MD5_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [16]uint8 // little endian + length _UINT64 // native endian + buffer [64]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + // b.chain is little endian, but Go expects big endian, + // we need to swap the bytes. + for i := 0; i < len(b.chain); i += 4 { + b.chain[i], b.chain[i+3] = b.chain[i+3], b.chain[i] + b.chain[i+1], b.chain[i+2] = b.chain[i+2], b.chain[i+1] + } + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil +} + +func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) + // b.chain should be little endian, but Go uses big endian, + // we need to swap the bytes. + for i := 0; i < len(b.chain); i += 4 { + b.chain[i], b.chain[i+3] = b.chain[i+3], b.chain[i] + b.chain[i+1], b.chain[i+2] = b.chain[i+2], b.chain[i+1] + } +} + +type _SYMCRYPT_SHA1_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [20]uint8 // big endian + length _UINT64 // native endian + buffer [64]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_SHA1_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil +} + +func (b *_SYMCRYPT_SHA1_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) +} + +type _SYMCRYPT_SHA256_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [32]uint8 // big endian + length _UINT64 // native endian + buffer [64]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_SHA256_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil +} + +func (b *_SYMCRYPT_SHA256_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) +} + +type _SYMCRYPT_SHA512_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [64]uint8 // big endian + lengthL _UINT64 // native endian + lengthH _UINT64 // native endian + buffer [128]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_SHA512_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + if b.lengthH.uint64() != 0 { + return nil, errors.New("exporting state with more than 2^63-1 bytes of data is not supported") + } + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.lengthL), nil +} + +func (b *_SYMCRYPT_SHA512_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.lengthL = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) +} + +func symCryptHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, buf []byte) ([]byte, error) { + size, typ, serializable := symCryptHashStateInfo(ch) + if !serializable { + return nil, errHashNotMarshallable + } + state := make([]byte, size, _SYMCRYPT_SHA512_STATE_EXPORT_SIZE) // 512 is the largest size + var pinner runtime.Pinner + pinner.Pin(&state[0]) + defer pinner.Unpin() + params := [2]ossl.OSSL_PARAM{ + ossl.OSSL_PARAM_construct_octet_string(_SCOSSL_DIGEST_PARAM_STATE.ptr(), unsafe.Pointer(&state[0]), len(state)), + ossl.OSSL_PARAM_construct_end(), + } + if _, err := ossl.EVP_MD_CTX_get_params(ctx, (ossl.OSSL_PARAM_PTR)(unsafe.Pointer(¶ms[0]))); err != nil { + // Old versions of SCOSSL don't support SCOSSL_DIGEST_PARAM_STATE. + return nil, errHashNotMarshallable + } + + header := (*_SYMCRYPT_BLOB_HEADER)(unsafe.Pointer(&state[0])) + if header.magic != _SYMCRYPT_BLOB_MAGIC { + return nil, errors.New("invalid blob magic") + } + if header.size != size { + return nil, errors.New("invalid blob size") + } + if header._type != typ { + return nil, errors.New("invalid blob type") + } + + magic, _ := cryptoHashEncodingInfo(ch) + buf = append(buf, magic...) + switch ch { + case crypto.MD5: + blob := (*_SYMCRYPT_MD5_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + case crypto.SHA1: + blob := (*_SYMCRYPT_SHA1_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + case crypto.SHA224, crypto.SHA256: + blob := (*_SYMCRYPT_SHA256_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + blob := (*_SYMCRYPT_SHA512_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + default: + panic("unsupported hash " + ch.String()) + } +} + +func symCryptHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, b []byte) error { + size, typ, serializable := symCryptHashStateInfo(ch) + if !serializable { + return errHashNotMarshallable + } + hdr := _SYMCRYPT_BLOB_HEADER{ + magic: _SYMCRYPT_BLOB_MAGIC, + size: size, + _type: typ, + } + var blobPtr unsafe.Pointer + magic, _ := cryptoHashEncodingInfo(ch) + b = b[len(magic):] + switch ch { + case crypto.MD5: + var blob _SYMCRYPT_MD5_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + case crypto.SHA1: + var blob _SYMCRYPT_SHA1_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + case crypto.SHA224, crypto.SHA256: + var blob _SYMCRYPT_SHA256_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + var blob _SYMCRYPT_SHA512_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + default: + panic("unsupported hash " + ch.String()) + } + bld, err := newParamBuilderN(2) + if err != nil { + return err + } + defer bld.finalize() + bld.addOctetString(_SCOSSL_DIGEST_PARAM_STATE, unsafe.Slice((*byte)(blobPtr), hdr.size)) + bld.addInt32(_SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM, 1) + params, err := bld.build() + if err != nil { + return err + } + if _, err := ossl.EVP_MD_CTX_set_params(ctx, params); err != nil { + // Old versions of SCOSSL don't support SCOSSL_DIGEST_PARAM_STATE + // nor _SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM. + return errHashNotMarshallable + } + return nil +} + +func symCryptHashStateInfo(ch crypto.Hash) (size, typ uint32, serializable bool) { + switch ch { + case crypto.MD5: + return _SYMCRYPT_MD5_STATE_EXPORT_SIZE, _SymCryptBlobTypeMd5State, symCryptHashStateSerializableMD5() + case crypto.SHA1: + return _SYMCRYPT_SHA1_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha1State, symCryptHashStateSerializableSHA1() + case crypto.SHA224: + return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha224State, symCryptHashStateSerializableSHA224() + case crypto.SHA256: + return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha256State, symCryptHashStateSerializableSHA256() + case crypto.SHA384: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha384State, symCryptHashStateSerializableSHA384() + case crypto.SHA512_224: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_224State, symCryptHashStateSerializableSHA512_224() + case crypto.SHA512_256: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_256State, symCryptHashStateSerializableSHA512_256() + case crypto.SHA512: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512State, symCryptHashStateSerializableSHA512() + default: + panic("unsupported hash " + ch.String()) + } +} + +var ( + symCryptHashStateSerializableMD5 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.MD5) + }) + symCryptHashStateSerializableSHA1 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA1) + }) + symCryptHashStateSerializableSHA224 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA224) + }) + symCryptHashStateSerializableSHA256 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA256) + }) + symCryptHashStateSerializableSHA384 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA384) + }) + symCryptHashStateSerializableSHA512_224 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA512_224) + }) + symCryptHashStateSerializableSHA512_256 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA512_256) + }) + symCryptHashStateSerializableSHA512 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA512) + }) +) + +// isSymCryptHashStateSerializable checks if the SymCrypt hash state is serializable. +func isSymCryptHashStateSerializable(ch crypto.Hash) bool { + alg := loadHash(ch) + if alg == nil { + return false + } + ctx, err := ossl.EVP_MD_CTX_new() + if err != nil { + return false + } + defer ossl.EVP_MD_CTX_free(ctx) + if _, err := ossl.EVP_DigestInit_ex(ctx, alg.md, nil); err != nil { + return false + } + params, err := ossl.EVP_MD_CTX_gettable_params(ctx) + if err != nil { + return false + } + if _, err = ossl.OSSL_PARAM_locate_const(params, _SCOSSL_DIGEST_PARAM_STATE.ptr()); err != nil { + return false + } + params, err = ossl.EVP_MD_CTX_settable_params(ctx) + if err != nil { + return false + } + if _, err = ossl.OSSL_PARAM_locate_const(params, _SCOSSL_DIGEST_PARAM_STATE.ptr()); err != nil { + return false + } + if _, err = ossl.OSSL_PARAM_locate_const(params, _SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM.ptr()); err != nil { + return false + } + return true +} From ad6c0bf4007478c84e92b0a65a36658207892207 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 15 May 2025 16:22:59 +0200 Subject: [PATCH 02/11] fix function signatures --- internal/ossl/shims.h | 4 ++-- internal/ossl/zossl.c | 16 ++++++++-------- internal/ossl/zossl.h | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/ossl/shims.h b/internal/ossl/shims.h index 76a08503..a1c15a75 100644 --- a/internal/ossl/shims.h +++ b/internal/ossl/shims.h @@ -191,8 +191,8 @@ _EVP_MD_CTX_PTR EVP_MD_CTX_new(void); void EVP_MD_CTX_free(_EVP_MD_CTX_PTR ctx); int EVP_MD_CTX_copy(_EVP_MD_CTX_PTR out, const _EVP_MD_CTX_PTR in) __attribute__((noescape,nocallback)); int EVP_MD_CTX_copy_ex(_EVP_MD_CTX_PTR out, const _EVP_MD_CTX_PTR in); -_OSSL_PARAM_PTR EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); -_OSSL_PARAM_PTR EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); +const _OSSL_PARAM_PTR EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); +const _OSSL_PARAM_PTR EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); int EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR ctx, _OSSL_PARAM_PTR params) __attribute__((tag("3"))); int EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR ctx, const _OSSL_PARAM_PTR params) __attribute__((tag("3"))); int EVP_Digest(const void *data, size_t count, unsigned char *md, unsigned int *size, const _EVP_MD_PTR type, _ENGINE_PTR impl) __attribute__((noescape,nocallback,nocheckptr("data"))); diff --git a/internal/ossl/zossl.c b/internal/ossl/zossl.c index 6d33f49d..47e75f2e 100644 --- a/internal/ossl/zossl.c +++ b/internal/ossl/zossl.c @@ -99,10 +99,10 @@ int (*_g_EVP_MD_CTX_copy)(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR); int (*_g_EVP_MD_CTX_copy_ex)(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR); void (*_g_EVP_MD_CTX_free)(_EVP_MD_CTX_PTR); int (*_g_EVP_MD_CTX_get_params)(_EVP_MD_CTX_PTR, _OSSL_PARAM_PTR); -_OSSL_PARAM_PTR (*_g_EVP_MD_CTX_gettable_params)(const _EVP_MD_CTX_PTR); +const _OSSL_PARAM_PTR (*_g_EVP_MD_CTX_gettable_params)(const _EVP_MD_CTX_PTR); _EVP_MD_CTX_PTR (*_g_EVP_MD_CTX_new)(void); int (*_g_EVP_MD_CTX_set_params)(_EVP_MD_CTX_PTR, const _OSSL_PARAM_PTR); -_OSSL_PARAM_PTR (*_g_EVP_MD_CTX_settable_params)(const _EVP_MD_CTX_PTR); +const _OSSL_PARAM_PTR (*_g_EVP_MD_CTX_settable_params)(const _EVP_MD_CTX_PTR); _EVP_MD_PTR (*_g_EVP_MD_fetch)(_OSSL_LIB_CTX_PTR, const char*, const char*); void (*_g_EVP_MD_free)(_EVP_MD_PTR); const char* (*_g_EVP_MD_get0_name)(const _EVP_MD_PTR); @@ -1275,10 +1275,10 @@ int _mkcgo_EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR _arg0, _OSSL_PARAM_PTR _arg1, m return _ret; } -_OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { mkcgo_err_clear(); - _OSSL_PARAM_PTR _ret = _g_EVP_MD_CTX_gettable_params(_arg0); - if (_ret == NULL) *_err_state = mkcgo_err_retrieve(); + const _OSSL_PARAM_PTR _ret = _g_EVP_MD_CTX_gettable_params(_arg0); + if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); return _ret; } @@ -1296,10 +1296,10 @@ int _mkcgo_EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR _arg0, const _OSSL_PARAM_PTR _a return _ret; } -_OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { mkcgo_err_clear(); - _OSSL_PARAM_PTR _ret = _g_EVP_MD_CTX_settable_params(_arg0); - if (_ret == NULL) *_err_state = mkcgo_err_retrieve(); + const _OSSL_PARAM_PTR _ret = _g_EVP_MD_CTX_settable_params(_arg0); + if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); return _ret; } diff --git a/internal/ossl/zossl.h b/internal/ossl/zossl.h index 4f55cd7c..ddaaed0c 100644 --- a/internal/ossl/zossl.h +++ b/internal/ossl/zossl.h @@ -197,10 +197,10 @@ int _mkcgo_EVP_MD_CTX_copy(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR, mkcgo_err_sta int _mkcgo_EVP_MD_CTX_copy_ex(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR, mkcgo_err_state *); void _mkcgo_EVP_MD_CTX_free(_EVP_MD_CTX_PTR); int _mkcgo_EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR, _OSSL_PARAM_PTR, mkcgo_err_state *); -_OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR, mkcgo_err_state *); +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR, mkcgo_err_state *); _EVP_MD_CTX_PTR _mkcgo_EVP_MD_CTX_new(mkcgo_err_state *); int _mkcgo_EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR, const _OSSL_PARAM_PTR, mkcgo_err_state *); -_OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR, mkcgo_err_state *); +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR, mkcgo_err_state *); _EVP_MD_PTR _mkcgo_EVP_MD_fetch(_OSSL_LIB_CTX_PTR, const char*, const char*, mkcgo_err_state *); void _mkcgo_EVP_MD_free(_EVP_MD_PTR); const char* _mkcgo_EVP_MD_get0_name(const _EVP_MD_PTR); From 1ae93ae3f1ac5978bca57bb3fc1735c48132ef6e Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 15 May 2025 16:28:24 +0200 Subject: [PATCH 03/11] fix tests --- hash_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/hash_test.go b/hash_test.go index 4c3ceb16..b932ecad 100644 --- a/hash_test.go +++ b/hash_test.go @@ -14,7 +14,6 @@ import ( _ "crypto/md5" _ "crypto/sha1" _ "crypto/sha256" - _ "crypto/sha3" _ "crypto/sha512" "github.com/golang-fips/openssl/v2" From ce8b1a78d5b9179fe92809712bb30a3448cb7987 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 15 May 2025 16:33:18 +0200 Subject: [PATCH 04/11] fix tests --- hash_test.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/hash_test.go b/hash_test.go index b932ecad..c4ecdc4b 100644 --- a/hash_test.go +++ b/hash_test.go @@ -104,6 +104,10 @@ type hashEncoding interface { hash.Hash encoding.BinaryMarshaler encoding.BinaryUnmarshaler +} + +type hashEncodingAppender interface { + hashEncoding AppendBinary(b []byte) ([]byte, error) } @@ -143,7 +147,12 @@ func TestHash_BinaryMarshaler(t *testing.T) { } // Test that the hash state is compatible with native Go. - h := ch.New().(hashEncoding) + h, ok := ch.New().(hashEncoding) + if !ok { + // The standard library doesn't support encoding this hash. + // Nothing else to do. + return + } h.Write(msg) stateh, err := h.(encoding.BinaryMarshaler).MarshalBinary() if err != nil { @@ -168,7 +177,7 @@ func TestHash_BinaryAppender(t *testing.T) { t.Skip("not supported") } - hashWithBinaryAppender, ok := cryptoToHash(ch)().(hashEncoding) + hashWithBinaryAppender, ok := cryptoToHash(ch)().(hashEncodingAppender) if !ok { t.Fatal("AppendBinary not supported") } From baed898cfe8fefd88f64a991c1821f17a8532202 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 15 May 2025 16:38:11 +0200 Subject: [PATCH 05/11] fix cgoless build --- provideropenssl.go | 2 ++ providersymcrypt.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/provideropenssl.go b/provideropenssl.go index c1e699cd..44f06519 100644 --- a/provideropenssl.go +++ b/provideropenssl.go @@ -1,3 +1,5 @@ +//go:build !cmd_go_bootstrap && cgo + package openssl import ( diff --git a/providersymcrypt.go b/providersymcrypt.go index 456a2f1a..bca1a833 100644 --- a/providersymcrypt.go +++ b/providersymcrypt.go @@ -1,3 +1,5 @@ +//go:build !cmd_go_bootstrap && cgo + package openssl import ( From a689068c24f687b83e3813182816eb6bd1f6966c Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 15 May 2025 16:39:35 +0200 Subject: [PATCH 06/11] fix headers --- internal/ossl/shims.h | 4 ++-- internal/ossl/zossl.c | 8 ++++---- internal/ossl/zossl.h | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/internal/ossl/shims.h b/internal/ossl/shims.h index a1c15a75..f43a9ae5 100644 --- a/internal/ossl/shims.h +++ b/internal/ossl/shims.h @@ -191,8 +191,8 @@ _EVP_MD_CTX_PTR EVP_MD_CTX_new(void); void EVP_MD_CTX_free(_EVP_MD_CTX_PTR ctx); int EVP_MD_CTX_copy(_EVP_MD_CTX_PTR out, const _EVP_MD_CTX_PTR in) __attribute__((noescape,nocallback)); int EVP_MD_CTX_copy_ex(_EVP_MD_CTX_PTR out, const _EVP_MD_CTX_PTR in); -const _OSSL_PARAM_PTR EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); -const _OSSL_PARAM_PTR EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); +const _OSSL_PARAM_PTR EVP_MD_CTX_gettable_params(_EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); +const _OSSL_PARAM_PTR EVP_MD_CTX_settable_params(_EVP_MD_CTX_PTR ctx) __attribute__((tag("3"))); int EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR ctx, _OSSL_PARAM_PTR params) __attribute__((tag("3"))); int EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR ctx, const _OSSL_PARAM_PTR params) __attribute__((tag("3"))); int EVP_Digest(const void *data, size_t count, unsigned char *md, unsigned int *size, const _EVP_MD_PTR type, _ENGINE_PTR impl) __attribute__((noescape,nocallback,nocheckptr("data"))); diff --git a/internal/ossl/zossl.c b/internal/ossl/zossl.c index 47e75f2e..1444f551 100644 --- a/internal/ossl/zossl.c +++ b/internal/ossl/zossl.c @@ -99,10 +99,10 @@ int (*_g_EVP_MD_CTX_copy)(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR); int (*_g_EVP_MD_CTX_copy_ex)(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR); void (*_g_EVP_MD_CTX_free)(_EVP_MD_CTX_PTR); int (*_g_EVP_MD_CTX_get_params)(_EVP_MD_CTX_PTR, _OSSL_PARAM_PTR); -const _OSSL_PARAM_PTR (*_g_EVP_MD_CTX_gettable_params)(const _EVP_MD_CTX_PTR); +const _OSSL_PARAM_PTR (*_g_EVP_MD_CTX_gettable_params)(_EVP_MD_CTX_PTR); _EVP_MD_CTX_PTR (*_g_EVP_MD_CTX_new)(void); int (*_g_EVP_MD_CTX_set_params)(_EVP_MD_CTX_PTR, const _OSSL_PARAM_PTR); -const _OSSL_PARAM_PTR (*_g_EVP_MD_CTX_settable_params)(const _EVP_MD_CTX_PTR); +const _OSSL_PARAM_PTR (*_g_EVP_MD_CTX_settable_params)(_EVP_MD_CTX_PTR); _EVP_MD_PTR (*_g_EVP_MD_fetch)(_OSSL_LIB_CTX_PTR, const char*, const char*); void (*_g_EVP_MD_free)(_EVP_MD_PTR); const char* (*_g_EVP_MD_get0_name)(const _EVP_MD_PTR); @@ -1275,7 +1275,7 @@ int _mkcgo_EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR _arg0, _OSSL_PARAM_PTR _arg1, m return _ret; } -const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(_EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { mkcgo_err_clear(); const _OSSL_PARAM_PTR _ret = _g_EVP_MD_CTX_gettable_params(_arg0); if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); @@ -1296,7 +1296,7 @@ int _mkcgo_EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR _arg0, const _OSSL_PARAM_PTR _a return _ret; } -const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(_EVP_MD_CTX_PTR _arg0, mkcgo_err_state *_err_state) { mkcgo_err_clear(); const _OSSL_PARAM_PTR _ret = _g_EVP_MD_CTX_settable_params(_arg0); if (_ret <= 0) *_err_state = mkcgo_err_retrieve(); diff --git a/internal/ossl/zossl.h b/internal/ossl/zossl.h index ddaaed0c..31e766d8 100644 --- a/internal/ossl/zossl.h +++ b/internal/ossl/zossl.h @@ -197,10 +197,10 @@ int _mkcgo_EVP_MD_CTX_copy(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR, mkcgo_err_sta int _mkcgo_EVP_MD_CTX_copy_ex(_EVP_MD_CTX_PTR, const _EVP_MD_CTX_PTR, mkcgo_err_state *); void _mkcgo_EVP_MD_CTX_free(_EVP_MD_CTX_PTR); int _mkcgo_EVP_MD_CTX_get_params(_EVP_MD_CTX_PTR, _OSSL_PARAM_PTR, mkcgo_err_state *); -const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(const _EVP_MD_CTX_PTR, mkcgo_err_state *); +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_gettable_params(_EVP_MD_CTX_PTR, mkcgo_err_state *); _EVP_MD_CTX_PTR _mkcgo_EVP_MD_CTX_new(mkcgo_err_state *); int _mkcgo_EVP_MD_CTX_set_params(_EVP_MD_CTX_PTR, const _OSSL_PARAM_PTR, mkcgo_err_state *); -const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(const _EVP_MD_CTX_PTR, mkcgo_err_state *); +const _OSSL_PARAM_PTR _mkcgo_EVP_MD_CTX_settable_params(_EVP_MD_CTX_PTR, mkcgo_err_state *); _EVP_MD_PTR _mkcgo_EVP_MD_fetch(_OSSL_LIB_CTX_PTR, const char*, const char*, mkcgo_err_state *); void _mkcgo_EVP_MD_free(_EVP_MD_PTR); const char* _mkcgo_EVP_MD_get0_name(const _EVP_MD_PTR); From b95750907c5e29d0a541e8af8ffc945bc9a8d420 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 21 May 2025 12:24:30 +0200 Subject: [PATCH 07/11] fix sha512 buffer length --- providersymcrypt.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/providersymcrypt.go b/providersymcrypt.go index bca1a833..597b9730 100644 --- a/providersymcrypt.go +++ b/providersymcrypt.go @@ -81,7 +81,12 @@ func (u *_UINT64) uint64() uint64 { // to the given destination slice. func symCryptAppendBinary(dst, chain, buffer []byte, blength _UINT64) []byte { length := blength.uint64() - nx := length & 0x3f + var nx uint64 + if len(buffer) <= 64 { + nx = length & 0x3f + } else { + nx = length & 0x7f + } dst = append(dst, chain...) dst = append(dst, buffer[:nx]...) dst = append(dst, make([]byte, len(buffer)-int(nx))...) From cc31f890e20f284addad1d7a64b0f149c67144ec Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 23 May 2025 14:13:18 +0200 Subject: [PATCH 08/11] return err --- providersymcrypt.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/providersymcrypt.go b/providersymcrypt.go index 597b9730..a80c3059 100644 --- a/providersymcrypt.go +++ b/providersymcrypt.go @@ -203,8 +203,7 @@ func symCryptHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, buf []byt ossl.OSSL_PARAM_construct_end(), } if _, err := ossl.EVP_MD_CTX_get_params(ctx, (ossl.OSSL_PARAM_PTR)(unsafe.Pointer(¶ms[0]))); err != nil { - // Old versions of SCOSSL don't support SCOSSL_DIGEST_PARAM_STATE. - return nil, errHashNotMarshallable + return nil, err } header := (*_SYMCRYPT_BLOB_HEADER)(unsafe.Pointer(&state[0])) From 66cbbca11f8863303454cefe1b62ba74be4a8f3f Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 23 May 2025 14:15:58 +0200 Subject: [PATCH 09/11] pass magic as parameter --- hash.go | 11 ++++++----- provideropenssl.go | 6 ++---- providersymcrypt.go | 6 ++---- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/hash.go b/hash.go index 993c785b..4bade6a3 100644 --- a/hash.go +++ b/hash.go @@ -392,14 +392,15 @@ func (d *evpHash) MarshalBinary() ([]byte, error) { func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) { defer runtime.KeepAlive(d) d.init() - if magic, _ := cryptoHashEncodingInfo(d.alg.ch); magic == "" { + magic, _ := cryptoHashEncodingInfo(d.alg.ch) + if magic == "" { return nil, errHashNotMarshallable } switch d.alg.provider { case providerOSSLDefault, providerOSSLFIPS: - return osslHashAppendBinary(d.ctx, d.alg.ch, buf) + return osslHashAppendBinary(d.ctx, d.alg.ch, magic, buf) case providerSymCrypt: - return symCryptHashAppendBinary(d.ctx, d.alg.ch, buf) + return symCryptHashAppendBinary(d.ctx, d.alg.ch, magic, buf) default: return nil, errHashNotMarshallable } @@ -420,9 +421,9 @@ func (d *evpHash) UnmarshalBinary(b []byte) error { } switch d.alg.provider { case providerOSSLDefault, providerOSSLFIPS: - return osslHashUnmarshalBinary(d.ctx, d.alg.ch, b) + return osslHashUnmarshalBinary(d.ctx, d.alg.ch, magic, b) case providerSymCrypt: - return symCryptHashUnmarshalBinary(d.ctx, d.alg.ch, b) + return symCryptHashUnmarshalBinary(d.ctx, d.alg.ch, magic, b) default: return errHashNotMarshallable } diff --git a/provideropenssl.go b/provideropenssl.go index 44f06519..1d8f3ad0 100644 --- a/provideropenssl.go +++ b/provideropenssl.go @@ -190,12 +190,11 @@ func getOSSLDigetsContext(ctx ossl.EVP_MD_CTX_PTR) unsafe.Pointer { var errHashStateInvalid = errors.New("openssl: can't retrieve hash state") -func osslHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, buf []byte) ([]byte, error) { +func osslHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic string, buf []byte) ([]byte, error) { algctx := getOSSLDigetsContext(ctx) if algctx == nil { return nil, errHashStateInvalid } - magic, _ := cryptoHashEncodingInfo(ch) buf = append(buf, magic...) switch ch { case crypto.MD5: @@ -215,12 +214,11 @@ func osslHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, buf []byte) ( } } -func osslHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, b []byte) error { +func osslHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic string, b []byte) error { algctx := getOSSLDigetsContext(ctx) if algctx == nil { return errHashStateInvalid } - magic, _ := cryptoHashEncodingInfo(ch) b = b[len(magic):] switch ch { case crypto.MD5: diff --git a/providersymcrypt.go b/providersymcrypt.go index a80c3059..4016930f 100644 --- a/providersymcrypt.go +++ b/providersymcrypt.go @@ -189,7 +189,7 @@ func (b *_SYMCRYPT_SHA512_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { b.lengthL = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) } -func symCryptHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, buf []byte) ([]byte, error) { +func symCryptHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic string, buf []byte) ([]byte, error) { size, typ, serializable := symCryptHashStateInfo(ch) if !serializable { return nil, errHashNotMarshallable @@ -217,7 +217,6 @@ func symCryptHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, buf []byt return nil, errors.New("invalid blob type") } - magic, _ := cryptoHashEncodingInfo(ch) buf = append(buf, magic...) switch ch { case crypto.MD5: @@ -237,7 +236,7 @@ func symCryptHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, buf []byt } } -func symCryptHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, b []byte) error { +func symCryptHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic string, b []byte) error { size, typ, serializable := symCryptHashStateInfo(ch) if !serializable { return errHashNotMarshallable @@ -248,7 +247,6 @@ func symCryptHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, b []by _type: typ, } var blobPtr unsafe.Pointer - magic, _ := cryptoHashEncodingInfo(ch) b = b[len(magic):] switch ch { case crypto.MD5: From 0bf146cb4cbac8d0043ca6b797480f5c0c16e0c7 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 23 May 2025 14:34:19 +0200 Subject: [PATCH 10/11] simplify --- providersymcrypt.go | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/providersymcrypt.go b/providersymcrypt.go index 4016930f..2857772d 100644 --- a/providersymcrypt.go +++ b/providersymcrypt.go @@ -105,6 +105,16 @@ func symCryptUnmarshalBinary(d []byte, chain, buffer []byte) _UINT64 { return newUINT64(length) } +// swapEndianessInt32 swaps the endianness of the given byte slice +// in place. It assumes the slice is a backup of a 32-bit integer array. +func swapEndianessInt32(d []uint8) { + for i := 0; i < len(d); i += 4 { + d[i], d[i+3] = d[i+3], d[i] + d[i+1], d[i+2] = d[i+2], d[i+1] + } + +} + type _SYMCRYPT_MD5_STATE_EXPORT_BLOB struct { header _SYMCRYPT_BLOB_HEADER chain [16]uint8 // little endian @@ -117,21 +127,13 @@ type _SYMCRYPT_MD5_STATE_EXPORT_BLOB struct { func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { // b.chain is little endian, but Go expects big endian, // we need to swap the bytes. - for i := 0; i < len(b.chain); i += 4 { - b.chain[i], b.chain[i+3] = b.chain[i+3], b.chain[i] - b.chain[i+1], b.chain[i+2] = b.chain[i+2], b.chain[i+1] - } + swapEndianessInt32(b.chain[:]) return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil } func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) - // b.chain should be little endian, but Go uses big endian, - // we need to swap the bytes. - for i := 0; i < len(b.chain); i += 4 { - b.chain[i], b.chain[i+3] = b.chain[i+3], b.chain[i] - b.chain[i+1], b.chain[i+2] = b.chain[i+2], b.chain[i+1] - } + swapEndianessInt32(b.chain[:]) } type _SYMCRYPT_SHA1_STATE_EXPORT_BLOB struct { @@ -283,12 +285,8 @@ func symCryptHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic if err != nil { return err } - if _, err := ossl.EVP_MD_CTX_set_params(ctx, params); err != nil { - // Old versions of SCOSSL don't support SCOSSL_DIGEST_PARAM_STATE - // nor _SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM. - return errHashNotMarshallable - } - return nil + _, err = ossl.EVP_MD_CTX_set_params(ctx, params) + return err } func symCryptHashStateInfo(ch crypto.Hash) (size, typ uint32, serializable bool) { From dd36184002a280e6d2356c92f84f522f0f9ea823 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Mon, 26 May 2025 11:28:51 +0200 Subject: [PATCH 11/11] reduce diffs --- evp.go | 37 +++++++++++++--- hash.go | 60 +++++++++----------------- internal/ossl/ossl.go | 20 ++++++--- internal/ossl/shims.h | 2 + internal/ossl/zossl.go | 1 + internal/ossl/zossl.h | 1 + params.go | 6 +-- providersymcrypt.go | 96 +++++++++++++----------------------------- 8 files changed, 101 insertions(+), 122 deletions(-) diff --git a/evp.go b/evp.go index 4189ee27..a6ad5590 100644 --- a/evp.go +++ b/evp.go @@ -74,11 +74,14 @@ const ( ) type hashAlgorithm struct { - md ossl.EVP_MD_PTR - ch crypto.Hash - size int - blockSize int - provider provider + md ossl.EVP_MD_PTR + ch crypto.Hash + size int + blockSize int + provider provider + marshallable bool + magic string + marshalledSize int } // loadHash converts a crypto.Hash to a EVP_MD. @@ -95,25 +98,41 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { hash.md = ossl.EVP_md4() case crypto.MD5: hash.md = ossl.EVP_md5() + hash.magic = magicMD5 + hash.marshalledSize = marshaledSizeMD5 case crypto.MD5SHA1: hash.md = ossl.EVP_md5_sha1() case crypto.SHA1: hash.md = ossl.EVP_sha1() + hash.magic = magic1 + hash.marshalledSize = marshaledSize1 case crypto.SHA224: hash.md = ossl.EVP_sha224() + hash.magic = magic224 + hash.marshalledSize = marshaledSize256 case crypto.SHA256: hash.md = ossl.EVP_sha256() + hash.magic = magic256 + hash.marshalledSize = marshaledSize256 case crypto.SHA384: hash.md = ossl.EVP_sha384() + hash.magic = magic384 + hash.marshalledSize = marshaledSize512 case crypto.SHA512: hash.md = ossl.EVP_sha512() + hash.magic = magic512 + hash.marshalledSize = marshaledSize512 case crypto.SHA512_224: if versionAtOrAbove(1, 1, 1) { hash.md = ossl.EVP_sha512_224() + hash.magic = magic512_224 + hash.marshalledSize = marshaledSize512 } case crypto.SHA512_256: if versionAtOrAbove(1, 1, 1) { hash.md = ossl.EVP_sha512_256() + hash.magic = magic512_256 + hash.marshalledSize = marshaledSize512 } case crypto.SHA3_224: if versionAtOrAbove(1, 1, 1) { @@ -151,6 +170,11 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { hash.md = md } } + if hash.magic != "" { + if hash.marshalledSize == 0 { + panic("marshalledSize must be set for " + hash.magic) + } + } switch vMajor { case 1: @@ -161,10 +185,13 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { switch C.GoString((*C.char)(unsafe.Pointer(cname))) { case "default": hash.provider = providerOSSLDefault + hash.marshallable = hash.magic != "" case "fips": hash.provider = providerOSSLFIPS + hash.marshallable = hash.magic != "" case "symcryptprovider": hash.provider = providerSymCrypt + hash.marshallable = hash.magic != "" && isSymCryptHashStateSerializable(hash.md) } } default: diff --git a/hash.go b/hash.go index 4bade6a3..85ab3fac 100644 --- a/hash.go +++ b/hash.go @@ -25,10 +25,10 @@ const ( magic512_256 = "sha\x06" magic512 = "sha\x07" - marshaledSizeMD5 = len(magicMD5) + 4*4 + 64 + 8 - marshaledSize1 = len(magic1) + 5*4 + 64 + 8 - marshaledSize256 = len(magic256) + 8*4 + 64 + 8 - marshaledSize512 = len(magic512) + 8*8 + 128 + 8 + marshaledSizeMD5 = len(magicMD5) + 4*4 + 64 + 8 // from crypto/md5 + marshaledSize1 = len(magic1) + 5*4 + 64 + 8 // from crypto/sha1 + marshaledSize256 = len(magic256) + 8*4 + 64 + 8 // from crypto/sha256 + marshaledSize512 = len(magic512) + 8*8 + 128 + 8 // from crypto/sha512 ) // maxHashSize is the size of SHA52 and SHA3_512, the largest hashes we support. @@ -385,70 +385,48 @@ func (h *evpHash) Clone() hash.Hash { var errHashNotMarshallable = errors.New("openssl: hash state is not marshallable") func (d *evpHash) MarshalBinary() ([]byte, error) { - buf := make([]byte, 0, marshaledSize512) // stack allocate the buffer by setting the max size we support + if !d.alg.marshallable { + return nil, errHashNotMarshallable + } + buf := make([]byte, 0, d.alg.marshalledSize) return d.AppendBinary(buf) } func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) { defer runtime.KeepAlive(d) - d.init() - magic, _ := cryptoHashEncodingInfo(d.alg.ch) - if magic == "" { + if !d.alg.marshallable { return nil, errHashNotMarshallable } + d.init() switch d.alg.provider { case providerOSSLDefault, providerOSSLFIPS: - return osslHashAppendBinary(d.ctx, d.alg.ch, magic, buf) + return osslHashAppendBinary(d.ctx, d.alg.ch, d.alg.magic, buf) case providerSymCrypt: - return symCryptHashAppendBinary(d.ctx, d.alg.ch, magic, buf) + return symCryptHashAppendBinary(d.ctx, d.alg.ch, d.alg.magic, buf) default: - return nil, errHashNotMarshallable + panic("openssl: unknown hash provider" + strconv.Itoa(int(d.alg.provider))) } } func (d *evpHash) UnmarshalBinary(b []byte) error { defer runtime.KeepAlive(d) d.init() - magic, size := cryptoHashEncodingInfo(d.alg.ch) - if magic == "" { + if !d.alg.marshallable { return errHashNotMarshallable } - if len(b) < len(magic) || string(b[:len(magic)]) != string(magic[:]) { + if len(b) < len(d.alg.magic) || string(b[:len(d.alg.magic)]) != d.alg.magic { return errors.New("openssl: invalid hash state identifier") } - if len(b) != size { + if len(b) != d.alg.marshalledSize { return errors.New("openssl: invalid hash state size") } switch d.alg.provider { case providerOSSLDefault, providerOSSLFIPS: - return osslHashUnmarshalBinary(d.ctx, d.alg.ch, magic, b) + return osslHashUnmarshalBinary(d.ctx, d.alg.ch, d.alg.magic, b) case providerSymCrypt: - return symCryptHashUnmarshalBinary(d.ctx, d.alg.ch, magic, b) - default: - return errHashNotMarshallable - } -} - -func cryptoHashEncodingInfo(ch crypto.Hash) (magic string, size int) { - switch ch { - case crypto.MD5: - return magicMD5, marshaledSizeMD5 - case crypto.SHA1: - return magic1, marshaledSize1 - case crypto.SHA224: - return magic224, marshaledSize256 - case crypto.SHA256: - return magic256, marshaledSize256 - case crypto.SHA384: - return magic384, marshaledSize512 - case crypto.SHA512_224: - return magic512_224, marshaledSize512 - case crypto.SHA512_256: - return magic512_256, marshaledSize512 - case crypto.SHA512: - return magic512, marshaledSize512 + return symCryptHashUnmarshalBinary(d.ctx, d.alg.ch, d.alg.magic, b) default: - return "", 0 + panic("openssl: unknown hash provider" + strconv.Itoa(int(d.alg.provider))) } } diff --git a/internal/ossl/ossl.go b/internal/ossl/ossl.go index 7d18d727..e0d0b8e8 100644 --- a/internal/ossl/ossl.go +++ b/internal/ossl/ossl.go @@ -23,7 +23,6 @@ go_hash_sum(const _EVP_MD_CTX_PTR ctx, _EVP_MD_CTX_PTR ctx2, unsigned char *out, */ import "C" import ( - "math" "unsafe" ) @@ -42,14 +41,16 @@ func HashSum(ctx1, ctx2 EVP_MD_CTX_PTR, out []byte) error { return nil } +const _OSSL_PARAM_UNMODIFIED uint = uint(^uintptr(0)) + // OSSL_PARAM is a structure to pass or request object parameters. // https://docs.openssl.org/3.0/man3/OSSL_PARAM/. type OSSL_PARAM struct { Key *byte DataType uint32 Data unsafe.Pointer - DataSize int - ReturnSize int + DataSize uint + ReturnSize uint } func ossl_param_construct(key *byte, dataType uint32, data unsafe.Pointer, dataSize int) OSSL_PARAM { @@ -57,8 +58,8 @@ func ossl_param_construct(key *byte, dataType uint32, data unsafe.Pointer, dataS Key: key, DataType: dataType, Data: data, - DataSize: dataSize, - ReturnSize: math.MaxInt - 1, + DataSize: uint(dataSize), + ReturnSize: _OSSL_PARAM_UNMODIFIED, } } @@ -66,6 +67,15 @@ func OSSL_PARAM_construct_octet_string(key *byte, data unsafe.Pointer, dataSize return ossl_param_construct(key, OSSL_PARAM_OCTET_STRING, data, dataSize) } +func OSSL_PARAM_construct_int32(key *byte, data *int32) OSSL_PARAM { + return ossl_param_construct(key, OSSL_PARAM_INTEGER, unsafe.Pointer(data), 4) +} + func OSSL_PARAM_construct_end() OSSL_PARAM { return OSSL_PARAM{} } + +func OSSL_PARAM_modified(param *OSSL_PARAM) bool { + // If ReturnSize is not set, the parameter has not been modified. + return param != nil && param.ReturnSize != _OSSL_PARAM_UNMODIFIED +} diff --git a/internal/ossl/shims.h b/internal/ossl/shims.h index f43a9ae5..5790defe 100644 --- a/internal/ossl/shims.h +++ b/internal/ossl/shims.h @@ -23,6 +23,7 @@ // #include // #include // #include +// #include // #endif // #if OPENSSL_VERSION_NUMBER < 0x10100000L // #include @@ -87,6 +88,7 @@ enum { _EVP_PKEY_CTRL_DSA_PARAMGEN_BITS = 0x1001, _EVP_PKEY_CTRL_DSA_PARAMGEN_Q_BITS = 0x1002, + _OSSL_PARAM_INTEGER = 1, _OSSL_PARAM_OCTET_STRING = 5, }; diff --git a/internal/ossl/zossl.go b/internal/ossl/zossl.go index 94a3dc6b..9079f7a0 100644 --- a/internal/ossl/zossl.go +++ b/internal/ossl/zossl.go @@ -61,6 +61,7 @@ const ( EVP_PKEY_CTRL_RSA_OAEP_LABEL = 0x100A EVP_PKEY_CTRL_DSA_PARAMGEN_BITS = 0x1001 EVP_PKEY_CTRL_DSA_PARAMGEN_Q_BITS = 0x1002 + OSSL_PARAM_INTEGER = 1 OSSL_PARAM_OCTET_STRING = 5 ) diff --git a/internal/ossl/zossl.h b/internal/ossl/zossl.h index 31e766d8..ea01eed7 100644 --- a/internal/ossl/zossl.h +++ b/internal/ossl/zossl.h @@ -83,6 +83,7 @@ enum { _EVP_PKEY_CTRL_RSA_OAEP_LABEL = 0x100A, _EVP_PKEY_CTRL_DSA_PARAMGEN_BITS = 0x1001, _EVP_PKEY_CTRL_DSA_PARAMGEN_Q_BITS = 0x1002, + _OSSL_PARAM_INTEGER = 1, _OSSL_PARAM_OCTET_STRING = 5, }; diff --git a/params.go b/params.go index 337bba9a..a5b6cdb9 100644 --- a/params.go +++ b/params.go @@ -37,17 +37,13 @@ type paramBuilder struct { // newParamBuilder creates a new paramBuilder. func newParamBuilder() (*paramBuilder, error) { - return newParamBuilderN(8) // the maximum known number of BIGNUMs to free are 8 for RSA -} - -func newParamBuilderN(n int) (*paramBuilder, error) { bld, err := ossl.OSSL_PARAM_BLD_new() if err != nil { return nil, err } pb := ¶mBuilder{ bld: bld, - bnToFree: make([]bnParam, 0, n), + bnToFree: make([]bnParam, 0, 8), // the maximum known number of BIGNUMs to free are 8 for RSA } runtime.SetFinalizer(pb, (*paramBuilder).finalize) return pb, nil diff --git a/providersymcrypt.go b/providersymcrypt.go index 2857772d..693849f1 100644 --- a/providersymcrypt.go +++ b/providersymcrypt.go @@ -6,7 +6,6 @@ import ( "crypto" "errors" "runtime" - "sync" "unsafe" "github.com/golang-fips/openssl/v2/internal/ossl" @@ -105,9 +104,9 @@ func symCryptUnmarshalBinary(d []byte, chain, buffer []byte) _UINT64 { return newUINT64(length) } -// swapEndianessInt32 swaps the endianness of the given byte slice +// swapEndianessUint32 swaps the endianness of the given byte slice // in place. It assumes the slice is a backup of a 32-bit integer array. -func swapEndianessInt32(d []uint8) { +func swapEndianessUint32(d []uint8) { for i := 0; i < len(d); i += 4 { d[i], d[i+3] = d[i+3], d[i] d[i+1], d[i+2] = d[i+2], d[i+1] @@ -127,13 +126,13 @@ type _SYMCRYPT_MD5_STATE_EXPORT_BLOB struct { func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { // b.chain is little endian, but Go expects big endian, // we need to swap the bytes. - swapEndianessInt32(b.chain[:]) + swapEndianessUint32(b.chain[:]) return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil } func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) - swapEndianessInt32(b.chain[:]) + swapEndianessUint32(b.chain[:]) } type _SYMCRYPT_SHA1_STATE_EXPORT_BLOB struct { @@ -192,10 +191,7 @@ func (b *_SYMCRYPT_SHA512_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { } func symCryptHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic string, buf []byte) ([]byte, error) { - size, typ, serializable := symCryptHashStateInfo(ch) - if !serializable { - return nil, errHashNotMarshallable - } + size, typ := symCryptHashStateInfo(ch) state := make([]byte, size, _SYMCRYPT_SHA512_STATE_EXPORT_SIZE) // 512 is the largest size var pinner runtime.Pinner pinner.Pin(&state[0]) @@ -207,6 +203,9 @@ func symCryptHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic str if _, err := ossl.EVP_MD_CTX_get_params(ctx, (ossl.OSSL_PARAM_PTR)(unsafe.Pointer(¶ms[0]))); err != nil { return nil, err } + if !ossl.OSSL_PARAM_modified(¶ms[0]) { + return nil, errors.New("EVP_MD_CTX_get_params did not retrieve the state") + } header := (*_SYMCRYPT_BLOB_HEADER)(unsafe.Pointer(&state[0])) if header.magic != _SYMCRYPT_BLOB_MAGIC { @@ -239,10 +238,7 @@ func symCryptHashAppendBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic str } func symCryptHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic string, b []byte) error { - size, typ, serializable := symCryptHashStateInfo(ch) - if !serializable { - return errHashNotMarshallable - } + size, typ := symCryptHashStateInfo(ch) hdr := _SYMCRYPT_BLOB_HEADER{ magic: _SYMCRYPT_BLOB_MAGIC, size: size, @@ -274,83 +270,51 @@ func symCryptHashUnmarshalBinary(ctx ossl.EVP_MD_CTX_PTR, ch crypto.Hash, magic default: panic("unsupported hash " + ch.String()) } - bld, err := newParamBuilderN(2) - if err != nil { - return err - } - defer bld.finalize() - bld.addOctetString(_SCOSSL_DIGEST_PARAM_STATE, unsafe.Slice((*byte)(blobPtr), hdr.size)) - bld.addInt32(_SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM, 1) - params, err := bld.build() - if err != nil { - return err + var checksum int32 = 1 + var pinner runtime.Pinner + pinner.Pin(blobPtr) + pinner.Pin(&checksum) + defer pinner.Unpin() + params := [3]ossl.OSSL_PARAM{ + ossl.OSSL_PARAM_construct_octet_string(_SCOSSL_DIGEST_PARAM_STATE.ptr(), blobPtr, int(hdr.size)), + ossl.OSSL_PARAM_construct_int32(_SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM.ptr(), &checksum), + ossl.OSSL_PARAM_construct_end(), } - _, err = ossl.EVP_MD_CTX_set_params(ctx, params) + _, err := ossl.EVP_MD_CTX_set_params(ctx, (ossl.OSSL_PARAM_PTR)(unsafe.Pointer(¶ms[0]))) return err } -func symCryptHashStateInfo(ch crypto.Hash) (size, typ uint32, serializable bool) { +func symCryptHashStateInfo(ch crypto.Hash) (size, typ uint32) { switch ch { case crypto.MD5: - return _SYMCRYPT_MD5_STATE_EXPORT_SIZE, _SymCryptBlobTypeMd5State, symCryptHashStateSerializableMD5() + return _SYMCRYPT_MD5_STATE_EXPORT_SIZE, _SymCryptBlobTypeMd5State case crypto.SHA1: - return _SYMCRYPT_SHA1_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha1State, symCryptHashStateSerializableSHA1() + return _SYMCRYPT_SHA1_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha1State case crypto.SHA224: - return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha224State, symCryptHashStateSerializableSHA224() + return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha224State case crypto.SHA256: - return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha256State, symCryptHashStateSerializableSHA256() + return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha256State case crypto.SHA384: - return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha384State, symCryptHashStateSerializableSHA384() + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha384State case crypto.SHA512_224: - return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_224State, symCryptHashStateSerializableSHA512_224() + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_224State case crypto.SHA512_256: - return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_256State, symCryptHashStateSerializableSHA512_256() + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_256State case crypto.SHA512: - return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512State, symCryptHashStateSerializableSHA512() + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512State default: panic("unsupported hash " + ch.String()) } } -var ( - symCryptHashStateSerializableMD5 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.MD5) - }) - symCryptHashStateSerializableSHA1 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA1) - }) - symCryptHashStateSerializableSHA224 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA224) - }) - symCryptHashStateSerializableSHA256 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA256) - }) - symCryptHashStateSerializableSHA384 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA384) - }) - symCryptHashStateSerializableSHA512_224 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA512_224) - }) - symCryptHashStateSerializableSHA512_256 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA512_256) - }) - symCryptHashStateSerializableSHA512 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA512) - }) -) - // isSymCryptHashStateSerializable checks if the SymCrypt hash state is serializable. -func isSymCryptHashStateSerializable(ch crypto.Hash) bool { - alg := loadHash(ch) - if alg == nil { - return false - } +func isSymCryptHashStateSerializable(md ossl.EVP_MD_PTR) bool { ctx, err := ossl.EVP_MD_CTX_new() if err != nil { return false } defer ossl.EVP_MD_CTX_free(ctx) - if _, err := ossl.EVP_DigestInit_ex(ctx, alg.md, nil); err != nil { + if _, err := ossl.EVP_DigestInit_ex(ctx, md, nil); err != nil { return false } params, err := ossl.EVP_MD_CTX_gettable_params(ctx)