From 413fa493ebef8c115b7bc0e4411dfa006460568e Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 23 May 2025 15:18:00 +0200 Subject: [PATCH 1/2] support serializing SymCrypt hash objects --- evp.go | 58 ++++--- hash.go | 338 +++++++-------------------------------- hash_test.go | 59 +++++-- params.go | 35 +++- provideropenssl.go | 239 ++++++++++++++++++++++++++++ providersymcrypt.go | 377 ++++++++++++++++++++++++++++++++++++++++++++ shims.h | 9 +- 7 files changed, 789 insertions(+), 326 deletions(-) create mode 100644 provideropenssl.go create mode 100644 providersymcrypt.go diff --git a/evp.go b/evp.go index 8b5b367f..cbf3d178 100644 --- a/evp.go +++ b/evp.go @@ -68,14 +68,22 @@ func hashFuncToMD(fn func() hash.Hash) (C.GO_EVP_MD_PTR, error) { return md, nil } +// provider is an identifier for a known provider. +type provider uint8 + +const ( + providerNone provider = iota + providerOSSLDefault + providerOSSLFIPS + providerSymCrypt +) + type hashAlgorithm struct { - md C.GO_EVP_MD_PTR - ch crypto.Hash - size int - blockSize int - marshallable bool - magic string - marshalledSize int + md C.GO_EVP_MD_PTR + ch crypto.Hash + size int + blockSize int + provider provider } // loadHash converts a crypto.Hash to a EVP_MD. @@ -92,8 +100,6 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { hash.md = C.go_openssl_EVP_md4() case crypto.MD5: hash.md = C.go_openssl_EVP_md5() - hash.magic = md5Magic - hash.marshalledSize = md5MarshaledSize case crypto.MD5SHA1: if vMajor == 1 && vMinor == 0 { // OpenSSL 1.0.2 does not support MD5SHA1. @@ -103,35 +109,21 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { } case crypto.SHA1: hash.md = C.go_openssl_EVP_sha1() - hash.magic = sha1Magic - hash.marshalledSize = sha1MarshaledSize case crypto.SHA224: hash.md = C.go_openssl_EVP_sha224() - hash.magic = magic224 - hash.marshalledSize = marshaledSize256 case crypto.SHA256: hash.md = C.go_openssl_EVP_sha256() - hash.magic = magic256 - hash.marshalledSize = marshaledSize256 case crypto.SHA384: hash.md = C.go_openssl_EVP_sha384() - hash.magic = magic384 - hash.marshalledSize = marshaledSize512 case crypto.SHA512: hash.md = C.go_openssl_EVP_sha512() - hash.magic = magic512 - hash.marshalledSize = marshaledSize512 case crypto.SHA512_224: if versionAtOrAbove(1, 1, 1) { hash.md = C.go_openssl_EVP_sha512_224() - hash.magic = magic512_224 - hash.marshalledSize = marshaledSize512 } case crypto.SHA512_256: if versionAtOrAbove(1, 1, 1) { hash.md = C.go_openssl_EVP_sha512_256() - hash.magic = magic512_256 - hash.marshalledSize = marshaledSize512 } case crypto.SHA3_224: if versionAtOrAbove(1, 1, 1) { @@ -169,7 +161,25 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { hash.md = md } } - hash.marshallable = hash.magic != "" && isHashMarshallable(hash.md) + + switch vMajor { + case 1: + hash.provider = providerOSSLDefault + case 3: + if prov := C.go_openssl_EVP_MD_get0_provider(hash.md); prov != nil { + switch C.GoString(C.go_openssl_OSSL_PROVIDER_get0_name(prov)) { + case "default": + hash.provider = providerOSSLDefault + case "fips": + hash.provider = providerOSSLFIPS + case "symcryptprovider": + hash.provider = providerSymCrypt + } + } + default: + panic(errUnsupportedVersion()) + } + cacheMD.Store(ch, &hash) return &hash } diff --git a/hash.go b/hash.go index b2109857..563154e5 100644 --- a/hash.go +++ b/hash.go @@ -14,6 +14,22 @@ import ( "unsafe" ) +const ( + magicMD5 = "md5\x01" + magic1 = "sha\x01" + magic224 = "sha\x02" + magic256 = "sha\x03" + magic384 = "sha\x04" + magic512_224 = "sha\x05" + magic512_256 = "sha\x06" + magic512 = "sha\x07" + + marshaledSizeMD5 = len(magicMD5) + 4*4 + 64 + 8 + marshaledSize1 = len(magic1) + 5*4 + 64 + 8 + marshaledSize256 = len(magic256) + 8*4 + 64 + 8 + marshaledSize512 = len(magic512) + 8*8 + 128 + 8 +) + // maxHashSize is the size of SHA52 and SHA3_512, the largest hashes we support. const maxHashSize = 64 @@ -215,27 +231,6 @@ func NewSHA3_512() hash.Hash { return newEvpHash(crypto.SHA3_512) } -// isHashMarshallable returns true if the memory layout of md -// is known by this library and can therefore be marshalled. -func isHashMarshallable(md C.GO_EVP_MD_PTR) bool { - if vMajor == 1 { - return true - } - prov := C.go_openssl_EVP_MD_get0_provider(md) - if prov == nil { - return false - } - cname := C.go_openssl_OSSL_PROVIDER_get0_name(prov) - if cname == nil { - return false - } - name := C.GoString(cname) - // We only know the memory layout of the built-in providers. - // See evpHash.hashState for more details. - marshallable := name == "default" || name == "fips" - return marshallable -} - // cloneHash is an interface that defines a Clone method. // // hahs.CloneHash will probably be added in Go 1.25, see https://golang.org/issue/69521, @@ -384,299 +379,74 @@ func (h *evpHash) Clone() hash.Hash { return h2 } -// hashState returns a pointer to the internal hash structure. -// -// The EVP_MD_CTX memory layout has changed in OpenSSL 3 -// and the property holding the internal structure is no longer md_data but algctx. -func hashState(ctx C.GO_EVP_MD_CTX_PTR) unsafe.Pointer { - switch vMajor { - case 1: - // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12. - type mdCtx struct { - _ [2]unsafe.Pointer - _ C.ulong - md_data unsafe.Pointer - } - return (*mdCtx)(unsafe.Pointer(ctx)).md_data - case 3: - // https://github.com/openssl/openssl/blob/5675a5aaf6a2e489022bcfc18330dae9263e598e/crypto/evp/evp_local.h#L16. - type mdCtx struct { - _ [3]unsafe.Pointer - _ C.ulong - _ [3]unsafe.Pointer - algctx unsafe.Pointer - } - return (*mdCtx)(unsafe.Pointer(ctx)).algctx - default: - panic(errUnsupportedVersion()) - } -} +var errHashNotMarshallable = errors.New("openssl: hash state is not marshallable") func (d *evpHash) MarshalBinary() ([]byte, error) { - if !d.alg.marshallable { - return nil, errors.New("openssl: hash state is not marshallable") - } - buf := make([]byte, 0, d.alg.marshalledSize) + buf := make([]byte, 0, marshaledSize512) // stack allocate the buffer by setting the max size we support return d.AppendBinary(buf) } func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) { defer runtime.KeepAlive(d) d.init() - if !d.alg.marshallable { - return nil, errors.New("openssl: hash state is not marshallable") - } - state := hashState(d.ctx) - if state == nil { - return nil, errors.New("openssl: can't retrieve hash state") - } - var appender interface { - AppendBinary([]byte) ([]byte, error) - } - switch d.alg.ch { - case crypto.MD5: - appender = (*md5State)(state) - case crypto.SHA1: - appender = (*sha1State)(state) - case crypto.SHA224: - appender = (*sha256State)(state) - case crypto.SHA256: - appender = (*sha256State)(state) - case crypto.SHA384: - appender = (*sha512State)(state) - case crypto.SHA512: - appender = (*sha512State)(state) - case crypto.SHA512_224: - appender = (*sha512State)(state) - case crypto.SHA512_256: - appender = (*sha512State)(state) + magic, _ := cryptoHashEncodingInfo(d.alg.ch) + if magic == "" { + return nil, errHashNotMarshallable + } + switch d.alg.provider { + case providerOSSLDefault, providerOSSLFIPS: + return osslHashAppendBinary(d.ctx, d.alg.ch, magic, buf) + case providerSymCrypt: + return symCryptHashAppendBinary(d.ctx, d.alg.ch, magic, buf) default: - panic("openssl: unsupported hash function: " + strconv.Itoa(int(d.alg.ch))) + return nil, errHashNotMarshallable } - buf = append(buf, d.alg.magic[:]...) - return appender.AppendBinary(buf) } func (d *evpHash) UnmarshalBinary(b []byte) error { defer runtime.KeepAlive(d) d.init() - if !d.alg.marshallable { - return errors.New("openssl: hash state is not marshallable") + magic, size := cryptoHashEncodingInfo(d.alg.ch) + if magic == "" { + return errHashNotMarshallable } - if len(b) < len(d.alg.magic) || string(b[:len(d.alg.magic)]) != string(d.alg.magic[:]) { + if len(b) < len(magic) || string(b[:len(magic)]) != string(magic[:]) { return errors.New("openssl: invalid hash state identifier") } - if len(b) != d.alg.marshalledSize { + if len(b) != size { return errors.New("openssl: invalid hash state size") } - state := hashState(d.ctx) - if state == nil { - return errors.New("openssl: can't retrieve hash state") - } - b = b[len(d.alg.magic):] - var unmarshaler interface { - UnmarshalBinary([]byte) error + switch d.alg.provider { + case providerOSSLDefault, providerOSSLFIPS: + return osslHashUnmarshalBinary(d.ctx, d.alg.ch, magic, b) + case providerSymCrypt: + return symCryptHashUnmarshalBinary(d.ctx, d.alg.ch, magic, b) + default: + return errHashNotMarshallable } - switch d.alg.ch { +} + +func cryptoHashEncodingInfo(ch crypto.Hash) (magic string, size int) { + switch ch { case crypto.MD5: - unmarshaler = (*md5State)(state) + return magicMD5, marshaledSizeMD5 case crypto.SHA1: - unmarshaler = (*sha1State)(state) + return magic1, marshaledSize1 case crypto.SHA224: - unmarshaler = (*sha256State)(state) + return magic224, marshaledSize256 case crypto.SHA256: - unmarshaler = (*sha256State)(state) + return magic256, marshaledSize256 case crypto.SHA384: - unmarshaler = (*sha512State)(state) - case crypto.SHA512: - unmarshaler = (*sha512State)(state) + return magic384, marshaledSize512 case crypto.SHA512_224: - unmarshaler = (*sha512State)(state) + return magic512_224, marshaledSize512 case crypto.SHA512_256: - unmarshaler = (*sha512State)(state) + return magic512_256, marshaledSize512 + case crypto.SHA512: + return magic512, marshaledSize512 default: - panic("openssl: unsupported hash function: " + strconv.Itoa(int(d.alg.ch))) + return "", 0 } - return unmarshaler.UnmarshalBinary(b) -} - -// md5State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/md5.h#L33. -type md5State struct { - h [4]uint32 - nl, nh uint32 - x [64]byte - nx uint32 -} - -const ( - md5Magic = "md5\x01" - md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8 -) - -func (d *md5State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint32(b) - b, d.h[1] = consumeUint32(b) - b, d.h[2] = consumeUint32(b) - b, d.h[3] = consumeUint32(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = uint32(n << 3) - d.nh = uint32(n >> 29) - d.nx = uint32(n) % 64 - return nil -} - -func (d *md5State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint32(buf, d.h[0]) - buf = appendUint32(buf, d.h[1]) - buf = appendUint32(buf, d.h[2]) - buf = appendUint32(buf, d.h[3]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) - return buf, nil -} - -// sha1State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L34. -type sha1State struct { - h [5]uint32 - nl, nh uint32 - x [64]byte - nx uint32 -} - -const ( - sha1Magic = "sha\x01" - sha1MarshaledSize = len(sha1Magic) + 5*4 + 64 + 8 -) - -func (d *sha1State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint32(b) - b, d.h[1] = consumeUint32(b) - b, d.h[2] = consumeUint32(b) - b, d.h[3] = consumeUint32(b) - b, d.h[4] = consumeUint32(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = uint32(n << 3) - d.nh = uint32(n >> 29) - d.nx = uint32(n) % 64 - return nil -} - -func (d *sha1State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint32(buf, d.h[0]) - buf = appendUint32(buf, d.h[1]) - buf = appendUint32(buf, d.h[2]) - buf = appendUint32(buf, d.h[3]) - buf = appendUint32(buf, d.h[4]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) - return buf, nil -} - -const ( - magic224 = "sha\x02" - magic256 = "sha\x03" - marshaledSize256 = len(magic256) + 8*4 + 64 + 8 -) - -// sha256State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L51. -type sha256State struct { - h [8]uint32 - nl, nh uint32 - x [64]byte - nx uint32 -} - -func (d *sha256State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint32(b) - b, d.h[1] = consumeUint32(b) - b, d.h[2] = consumeUint32(b) - b, d.h[3] = consumeUint32(b) - b, d.h[4] = consumeUint32(b) - b, d.h[5] = consumeUint32(b) - b, d.h[6] = consumeUint32(b) - b, d.h[7] = consumeUint32(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = uint32(n << 3) - d.nh = uint32(n >> 29) - d.nx = uint32(n) % 64 - return nil -} - -func (d *sha256State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint32(buf, d.h[0]) - buf = appendUint32(buf, d.h[1]) - buf = appendUint32(buf, d.h[2]) - buf = appendUint32(buf, d.h[3]) - buf = appendUint32(buf, d.h[4]) - buf = appendUint32(buf, d.h[5]) - buf = appendUint32(buf, d.h[6]) - buf = appendUint32(buf, d.h[7]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) - return buf, nil -} - -// sha512State layout is taken from -// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L95. -type sha512State struct { - h [8]uint64 - nl, nh uint64 - x [128]byte - nx uint32 -} - -const ( - magic384 = "sha\x04" - magic512_224 = "sha\x05" - magic512_256 = "sha\x06" - magic512 = "sha\x07" - marshaledSize512 = len(magic512) + 8*8 + 128 + 8 -) - -func (d *sha512State) MarshalBinary() ([]byte, error) { - buf := make([]byte, 0, marshaledSize512) - return d.AppendBinary(buf) -} - -func (d *sha512State) UnmarshalBinary(b []byte) error { - b, d.h[0] = consumeUint64(b) - b, d.h[1] = consumeUint64(b) - b, d.h[2] = consumeUint64(b) - b, d.h[3] = consumeUint64(b) - b, d.h[4] = consumeUint64(b) - b, d.h[5] = consumeUint64(b) - b, d.h[6] = consumeUint64(b) - b, d.h[7] = consumeUint64(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = n << 3 - d.nh = n >> 61 - d.nx = uint32(n) % 128 - return nil -} - -func (d *sha512State) AppendBinary(buf []byte) ([]byte, error) { - buf = appendUint64(buf, d.h[0]) - buf = appendUint64(buf, d.h[1]) - buf = appendUint64(buf, d.h[2]) - buf = appendUint64(buf, d.h[3]) - buf = appendUint64(buf, d.h[4]) - buf = appendUint64(buf, d.h[5]) - buf = appendUint64(buf, d.h[6]) - buf = appendUint64(buf, d.h[7]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, d.nl>>3|d.nh<<61) - return buf, nil } // appendUint64 appends x into b as a big endian byte sequence. diff --git a/hash_test.go b/hash_test.go index a5c2581e..600b2625 100644 --- a/hash_test.go +++ b/hash_test.go @@ -9,6 +9,12 @@ import ( "strings" "testing" + // Blank imports to ensure that the hash functions are registered. + _ "crypto/md5" + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" + "github.com/golang-fips/openssl/v2" ) @@ -93,6 +99,17 @@ func TestHash(t *testing.T) { } } +type hashEncoding interface { + hash.Hash + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler +} + +type hashEncodingAppender interface { + hashEncoding + AppendBinary(b []byte) ([]byte, error) +} + func TestHash_BinaryMarshaler(t *testing.T) { msg := []byte("testing") for _, ch := range hashes { @@ -102,10 +119,7 @@ func TestHash_BinaryMarshaler(t *testing.T) { t.Skip("hash not supported") } - hashMarshaler, ok := cryptoToHash(ch)().(interface { - hash.Hash - encoding.BinaryMarshaler - }) + hashMarshaler, ok := cryptoToHash(ch)().(hashEncoding) if !ok { t.Fatal("BinaryMarshaler not supported") } @@ -122,10 +136,7 @@ func TestHash_BinaryMarshaler(t *testing.T) { t.Fatalf("MarshalBinary failed: %v", err) } - hashUnmarshaler := cryptoToHash(ch)().(interface { - hash.Hash - encoding.BinaryUnmarshaler - }) + hashUnmarshaler := cryptoToHash(ch)().(hashEncoding) if err := hashUnmarshaler.UnmarshalBinary(state); err != nil { t.Fatalf("UnmarshalBinary failed: %v", err) } @@ -133,6 +144,26 @@ func TestHash_BinaryMarshaler(t *testing.T) { if actual, actual2 := hashMarshaler.Sum(nil), hashUnmarshaler.Sum(nil); !bytes.Equal(actual, actual2) { t.Errorf("0x%x != appended 0x%x", actual, actual2) } + + // Test that the hash state is compatible with native Go. + h, ok := ch.New().(hashEncoding) + if !ok { + // The standard library doesn't support encoding this hash. + // Nothing else to do. + return + } + h.Write(msg) + stateh, err := h.(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + t.Error(err) + } + if !bytes.Equal(state, stateh) { + t.Errorf("got 0x%x != want 0x%x", state, stateh) + } + h = ch.New().(hashEncoding) + if err := h.UnmarshalBinary(state); err != nil { + t.Error(err) + } }) } } @@ -145,10 +176,7 @@ func TestHash_BinaryAppender(t *testing.T) { t.Skip("not supported") } - hashWithBinaryAppender, ok := cryptoToHash(ch)().(interface { - hash.Hash - AppendBinary(b []byte) ([]byte, error) - }) + hashWithBinaryAppender, ok := cryptoToHash(ch)().(hashEncodingAppender) if !ok { t.Fatal("AppendBinary not supported") } @@ -180,10 +208,7 @@ func TestHash_BinaryAppender(t *testing.T) { // Use only the newly appended part of the slice appendedState := state[10:] - h2, ok := cryptoToHash(ch)().(interface { - hash.Hash - encoding.BinaryUnmarshaler - }) + h2, ok := cryptoToHash(ch)().(hashEncoding) if !ok { t.Skip("not supported") } @@ -261,6 +286,7 @@ func TestHash_StringWriter(t *testing.T) { } h := cryptoToHash(ch)() initSum := h.Sum(nil) + h.(io.StringWriter).WriteString("") h.(io.StringWriter).WriteString(string(msg)) h.Reset() sum := h.Sum(nil) @@ -327,6 +353,7 @@ func TestHash_OneShot(t *testing.T) { if !openssl.SupportsHash(tt.h) { t.Skip("not supported") } + _ = tt.oneShot(nil) // test that does not panic got := tt.oneShot(msg) h := cryptoToHash(tt.h)() h.Write(msg) diff --git a/params.go b/params.go index fd5bd405..844bfd23 100644 --- a/params.go +++ b/params.go @@ -5,6 +5,7 @@ package openssl // #include "goopenssl.h" import "C" import ( + "math" "runtime" "unsafe" ) @@ -46,6 +47,34 @@ var ( _OSSL_MAC_PARAM_DIGEST = C.CString("digest") ) +// _OSSL_PARAM is a structure to pass or request object parameters. +// https://docs.openssl.org/3.0/man3/OSSL_PARAM/. +type _OSSL_PARAM struct { + Key *C.char + DataType uint32 + Data unsafe.Pointer + DataSize int + ReturnSize int +} + +func ossl_param_construct(key *C.char, dataType uint32, data unsafe.Pointer, dataSize int) _OSSL_PARAM { + return _OSSL_PARAM{ + Key: key, + DataType: dataType, + Data: data, + DataSize: dataSize, + ReturnSize: math.MaxInt - 1, + } +} + +func _OSSL_PARAM_construct_octet_string(key *C.char, data unsafe.Pointer, dataSize int) _OSSL_PARAM { + return ossl_param_construct(key, C.GO_OSSL_PARAM_OCTET_STRING, data, dataSize) +} + +func _OSSL_PARAM_construct_end() _OSSL_PARAM { + return _OSSL_PARAM{} +} + type bnParam struct { value C.GO_BIGNUM_PTR private bool @@ -65,13 +94,17 @@ type paramBuilder struct { // newParamBuilder creates a new paramBuilder. func newParamBuilder() (*paramBuilder, error) { + return newParamBuilderN(8) // the maximum known number of BIGNUMs to free are 8 for RSA +} + +func newParamBuilderN(n int) (*paramBuilder, error) { bld := C.go_openssl_OSSL_PARAM_BLD_new() if bld == nil { return nil, newOpenSSLError("OSSL_PARAM_BLD_new") } pb := ¶mBuilder{ bld: bld, - bnToFree: make([]bnParam, 0, 8), // the maximum known number of BIGNUMs to free are 8 for RSA + bnToFree: make([]bnParam, 0, n), } runtime.SetFinalizer(pb, (*paramBuilder).finalize) return pb, nil diff --git a/provideropenssl.go b/provideropenssl.go new file mode 100644 index 00000000..912bcbbf --- /dev/null +++ b/provideropenssl.go @@ -0,0 +1,239 @@ +//go:build !cmd_go_bootstrap + +package openssl + +// #include "goopenssl.h" +import "C" +import ( + "crypto" + "errors" + "unsafe" +) + +// This file contains code specific to the built-in OpenSSL providers. + +// _OSSL_MD5_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/md5.h#L33. +type _OSSL_MD5_CTX struct { + h [4]uint32 + nl, nh uint32 + x [64]byte + nx uint32 +} + +func (d *_OSSL_MD5_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint32(b) + b, d.h[1] = consumeUint32(b) + b, d.h[2] = consumeUint32(b) + b, d.h[3] = consumeUint32(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = uint32(n << 3) + d.nh = uint32(n >> 29) + d.nx = uint32(n) % 64 + return nil +} + +func (d *_OSSL_MD5_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint32(buf, d.h[0]) + buf = appendUint32(buf, d.h[1]) + buf = appendUint32(buf, d.h[2]) + buf = appendUint32(buf, d.h[3]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) + return buf, nil +} + +// _OSSL_SHA_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L34. +type _OSSL_SHA_CTX struct { + h [5]uint32 + nl, nh uint32 + x [64]byte + nx uint32 +} + +func (d *_OSSL_SHA_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint32(b) + b, d.h[1] = consumeUint32(b) + b, d.h[2] = consumeUint32(b) + b, d.h[3] = consumeUint32(b) + b, d.h[4] = consumeUint32(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = uint32(n << 3) + d.nh = uint32(n >> 29) + d.nx = uint32(n) % 64 + return nil +} + +func (d *_OSSL_SHA_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint32(buf, d.h[0]) + buf = appendUint32(buf, d.h[1]) + buf = appendUint32(buf, d.h[2]) + buf = appendUint32(buf, d.h[3]) + buf = appendUint32(buf, d.h[4]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) + return buf, nil +} + +// _OSSL_SHA256_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L51. +type _OSSL_SHA256_CTX struct { + h [8]uint32 + nl, nh uint32 + x [64]byte + nx uint32 +} + +func (d *_OSSL_SHA256_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint32(b) + b, d.h[1] = consumeUint32(b) + b, d.h[2] = consumeUint32(b) + b, d.h[3] = consumeUint32(b) + b, d.h[4] = consumeUint32(b) + b, d.h[5] = consumeUint32(b) + b, d.h[6] = consumeUint32(b) + b, d.h[7] = consumeUint32(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = uint32(n << 3) + d.nh = uint32(n >> 29) + d.nx = uint32(n) % 64 + return nil +} + +func (d *_OSSL_SHA256_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint32(buf, d.h[0]) + buf = appendUint32(buf, d.h[1]) + buf = appendUint32(buf, d.h[2]) + buf = appendUint32(buf, d.h[3]) + buf = appendUint32(buf, d.h[4]) + buf = appendUint32(buf, d.h[5]) + buf = appendUint32(buf, d.h[6]) + buf = appendUint32(buf, d.h[7]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) + return buf, nil +} + +// _OSSL_SHA512_CTX layout is taken from +// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L95. +type _OSSL_SHA512_CTX struct { + h [8]uint64 + nl, nh uint64 + x [128]byte + nx uint32 +} + +func (d *_OSSL_SHA512_CTX) UnmarshalBinary(b []byte) error { + b, d.h[0] = consumeUint64(b) + b, d.h[1] = consumeUint64(b) + b, d.h[2] = consumeUint64(b) + b, d.h[3] = consumeUint64(b) + b, d.h[4] = consumeUint64(b) + b, d.h[5] = consumeUint64(b) + b, d.h[6] = consumeUint64(b) + b, d.h[7] = consumeUint64(b) + b = b[copy(d.x[:], b):] + _, n := consumeUint64(b) + d.nl = n << 3 + d.nh = n >> 61 + d.nx = uint32(n) % 128 + return nil +} + +func (d *_OSSL_SHA512_CTX) AppendBinary(buf []byte) ([]byte, error) { + buf = appendUint64(buf, d.h[0]) + buf = appendUint64(buf, d.h[1]) + buf = appendUint64(buf, d.h[2]) + buf = appendUint64(buf, d.h[3]) + buf = appendUint64(buf, d.h[4]) + buf = appendUint64(buf, d.h[5]) + buf = appendUint64(buf, d.h[6]) + buf = appendUint64(buf, d.h[7]) + buf = append(buf, d.x[:d.nx]...) + buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) + buf = appendUint64(buf, d.nl>>3|d.nh<<61) + return buf, nil +} + +func getOSSLDigetsContext(ctx C.GO_EVP_MD_CTX_PTR) unsafe.Pointer { + switch vMajor { + case 1: + // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12. + type mdCtx struct { + _ [2]unsafe.Pointer + _ uint32 + md_data unsafe.Pointer + } + return (*mdCtx)(unsafe.Pointer(ctx)).md_data + case 3: + // The EVP_MD_CTX memory layout has changed in OpenSSL 3 + // and the property holding the internal structure is no longer md_data but algctx. + // https://github.com/openssl/openssl/blob/5675a5aaf6a2e489022bcfc18330dae9263e598e/crypto/evp/evp_local.h#L16. + type mdCtx struct { + _ [3]unsafe.Pointer + _ uint32 + _ [3]unsafe.Pointer + algctx unsafe.Pointer + } + return (*mdCtx)(unsafe.Pointer(ctx)).algctx + default: + panic(errUnsupportedVersion()) + } +} + +var errHashStateInvalid = errors.New("openssl: can't retrieve hash state") + +func osslHashAppendBinary(ctx C.GO_EVP_MD_CTX_PTR, ch crypto.Hash, magic string, buf []byte) ([]byte, error) { + algctx := getOSSLDigetsContext(ctx) + if algctx == nil { + return nil, errHashStateInvalid + } + buf = append(buf, magic...) + switch ch { + case crypto.MD5: + d := (*_OSSL_MD5_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + case crypto.SHA1: + d := (*_OSSL_SHA_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + case crypto.SHA224, crypto.SHA256: + d := (*_OSSL_SHA256_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + d := (*_OSSL_SHA512_CTX)(unsafe.Pointer(algctx)) + return d.AppendBinary(buf) + default: + panic("unsupported hash " + ch.String()) + } +} + +func osslHashUnmarshalBinary(ctx C.GO_EVP_MD_CTX_PTR, ch crypto.Hash, magic string, b []byte) error { + algctx := getOSSLDigetsContext(ctx) + if algctx == nil { + return errHashStateInvalid + } + b = b[len(magic):] + switch ch { + case crypto.MD5: + d := (*_OSSL_MD5_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + case crypto.SHA1: + d := (*_OSSL_SHA_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + case crypto.SHA224, crypto.SHA256: + d := (*_OSSL_SHA256_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + d := (*_OSSL_SHA512_CTX)(unsafe.Pointer(algctx)) + return d.UnmarshalBinary(b) + default: + panic("unsupported hash " + ch.String()) + } +} diff --git a/providersymcrypt.go b/providersymcrypt.go new file mode 100644 index 00000000..f087439a --- /dev/null +++ b/providersymcrypt.go @@ -0,0 +1,377 @@ +//go:build !cmd_go_bootstrap + +package openssl + +// #include "goopenssl.h" +import "C" +import ( + "crypto" + "encoding/binary" + "errors" + "runtime" + "sync" + "unsafe" +) + +// This file contains code specific to the SymCrypt provider. + +var ( + _SCOSSL_DIGEST_PARAM_STATE = C.CString("state") + _SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM = C.CString("recompute_checksum") +) + +const ( + _SYMCRYPT_BLOB_MAGIC = 0x636D7973 // "cysm" in little-endian + + _SymCryptBlobTypeHashState = 0x100 + _SymCryptBlobTypeMd2State = _SymCryptBlobTypeHashState + 1 + _SymCryptBlobTypeMd4State = _SymCryptBlobTypeHashState + 2 + _SymCryptBlobTypeMd5State = _SymCryptBlobTypeHashState + 3 + _SymCryptBlobTypeSha1State = _SymCryptBlobTypeHashState + 4 + _SymCryptBlobTypeSha256State = _SymCryptBlobTypeHashState + 5 + _SymCryptBlobTypeSha384State = _SymCryptBlobTypeHashState + 6 + _SymCryptBlobTypeSha512State = _SymCryptBlobTypeHashState + 7 + _SymCryptBlobTypeSha3_256State = _SymCryptBlobTypeHashState + 8 + _SymCryptBlobTypeSha3_384State = _SymCryptBlobTypeHashState + 9 + _SymCryptBlobTypeSha3_512State = _SymCryptBlobTypeHashState + 10 + _SymCryptBlobTypeSha224State = _SymCryptBlobTypeHashState + 11 + _SymCryptBlobTypeSha512_224State = _SymCryptBlobTypeHashState + 12 + _SymCryptBlobTypeSha512_256State = _SymCryptBlobTypeHashState + 13 + _SymCryptBlobTypeSha3_224State = _SymCryptBlobTypeHashState + 14 + + _SYMCRYPT_MD5_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_MD5_STATE_EXPORT_BLOB{})) + _SYMCRYPT_SHA1_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_SHA1_STATE_EXPORT_BLOB{})) + _SYMCRYPT_SHA256_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_SHA256_STATE_EXPORT_BLOB{})) + _SYMCRYPT_SHA512_STATE_EXPORT_SIZE = uint32(unsafe.Sizeof(_SYMCRYPT_SHA512_STATE_EXPORT_BLOB{})) +) + +type _SYMCRYPT_BLOB_HEADER struct { + magic uint32 + size uint32 + _type uint32 +} + +type _SYMCRYPT_BLOB_TRAILER struct { + checksum [8]uint8 +} + +// _UINT64 is a 64-bit unsigned integer, stored in native endianess. +// It is used to represent a SymCrypt UINT64 type without making the +// parent struct 8-byte aligned, given that the Windows ABI makes +// the struct 4-byte aligned. +type _UINT64 [2]uint32 + +func newUINT64(v uint64) _UINT64 { + var u _UINT64 + if nativeEndian == binary.BigEndian { + u[0], u[1] = uint32(v>>32), uint32(v) + } else { + u[0], u[1] = uint32(v), uint32(v>>32) + } + return u +} + +func (u *_UINT64) uint64() uint64 { + if nativeEndian == binary.BigEndian { + return uint64(u[0])<<32 | (uint64(u[1])) + } + return uint64(u[0]) | (uint64(u[1]) << 32) +} + +// symCryptAppendBinary appends the binary representation of a SymCrypt state +// to the given destination slice. +func symCryptAppendBinary(dst, chain, buffer []byte, blength _UINT64) []byte { + length := blength.uint64() + var nx uint64 + if len(buffer) <= 64 { + nx = length & 0x3f + } else { + nx = length & 0x7f + } + dst = append(dst, chain...) + dst = append(dst, buffer[:nx]...) + dst = append(dst, make([]byte, len(buffer)-int(nx))...) + dst = appendUint64(dst, length) + return dst +} + +// symCryptUnmarshalBinary unmarshals the binary representation of a SymCrypt state +// from the given source slice. It returns the length of the data. +func symCryptUnmarshalBinary(d []byte, chain, buffer []byte) _UINT64 { + copy(chain[:], d) + d = d[len(chain):] + copy(buffer[:], d) + d = d[len(buffer):] + _, length := consumeUint64(d) + return newUINT64(length) +} + +// swapEndianessInt32 swaps the endianness of the given byte slice +// in place. It assumes the slice is a backup of a 32-bit integer array. +func swapEndianessInt32(d []uint8) { + for i := 0; i < len(d); i += 4 { + d[i], d[i+3] = d[i+3], d[i] + d[i+1], d[i+2] = d[i+2], d[i+1] + } + +} + +type _SYMCRYPT_MD5_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [16]uint8 // little endian + length _UINT64 // native endian + buffer [64]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + // b.chain is little endian, but Go expects big endian, + // we need to swap the bytes. + swapEndianessInt32(b.chain[:]) + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil +} + +func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) + swapEndianessInt32(b.chain[:]) +} + +type _SYMCRYPT_SHA1_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [20]uint8 // big endian + length _UINT64 // native endian + buffer [64]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_SHA1_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil +} + +func (b *_SYMCRYPT_SHA1_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) +} + +type _SYMCRYPT_SHA256_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [32]uint8 // big endian + length _UINT64 // native endian + buffer [64]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_SHA256_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil +} + +func (b *_SYMCRYPT_SHA256_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) +} + +type _SYMCRYPT_SHA512_STATE_EXPORT_BLOB struct { + header _SYMCRYPT_BLOB_HEADER + chain [64]uint8 // big endian + lengthL _UINT64 // native endian + lengthH _UINT64 // native endian + buffer [128]uint8 + _ [8]uint8 // reserved + _ _SYMCRYPT_BLOB_TRAILER +} + +func (b *_SYMCRYPT_SHA512_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { + if b.lengthH.uint64() != 0 { + return nil, errors.New("exporting state with more than 2^63-1 bytes of data is not supported") + } + return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.lengthL), nil +} + +func (b *_SYMCRYPT_SHA512_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { + b.lengthL = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) +} + +func symCryptHashAppendBinary(ctx C.GO_EVP_MD_CTX_PTR, ch crypto.Hash, magic string, buf []byte) ([]byte, error) { + size, typ, serializable := symCryptHashStateInfo(ch) + if !serializable { + return nil, errHashNotMarshallable + } + state := make([]byte, size, _SYMCRYPT_SHA512_STATE_EXPORT_SIZE) // 512 is the largest size + var pinner runtime.Pinner + pinner.Pin(&state[0]) + defer pinner.Unpin() + params := [2]_OSSL_PARAM{ + _OSSL_PARAM_construct_octet_string(_SCOSSL_DIGEST_PARAM_STATE, unsafe.Pointer(&state[0]), len(state)), + _OSSL_PARAM_construct_end(), + } + if C.go_openssl_EVP_MD_CTX_get_params(ctx, (C.GO_OSSL_PARAM_PTR)(unsafe.Pointer(¶ms[0]))) != 1 { + return nil, newOpenSSLError("EVP_MD_CTX_get_params") + } + + header := (*_SYMCRYPT_BLOB_HEADER)(unsafe.Pointer(&state[0])) + if header.magic != _SYMCRYPT_BLOB_MAGIC { + return nil, errors.New("invalid blob magic") + } + if header.size != size { + return nil, errors.New("invalid blob size") + } + if header._type != typ { + return nil, errors.New("invalid blob type") + } + + buf = append(buf, magic...) + switch ch { + case crypto.MD5: + blob := (*_SYMCRYPT_MD5_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + case crypto.SHA1: + blob := (*_SYMCRYPT_SHA1_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + case crypto.SHA224, crypto.SHA256: + blob := (*_SYMCRYPT_SHA256_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + blob := (*_SYMCRYPT_SHA512_STATE_EXPORT_BLOB)(unsafe.Pointer(&state[0])) + return blob.appendBinary(buf) + default: + panic("unsupported hash " + ch.String()) + } +} + +func symCryptHashUnmarshalBinary(ctx C.GO_EVP_MD_CTX_PTR, ch crypto.Hash, magic string, b []byte) error { + size, typ, serializable := symCryptHashStateInfo(ch) + if !serializable { + return errHashNotMarshallable + } + hdr := _SYMCRYPT_BLOB_HEADER{ + magic: _SYMCRYPT_BLOB_MAGIC, + size: size, + _type: typ, + } + var blobPtr unsafe.Pointer + b = b[len(magic):] + switch ch { + case crypto.MD5: + var blob _SYMCRYPT_MD5_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + case crypto.SHA1: + var blob _SYMCRYPT_SHA1_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + case crypto.SHA224, crypto.SHA256: + var blob _SYMCRYPT_SHA256_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + case crypto.SHA384, crypto.SHA512_224, crypto.SHA512_256, crypto.SHA512: + var blob _SYMCRYPT_SHA512_STATE_EXPORT_BLOB + blobPtr = unsafe.Pointer(&blob) + blob.header = hdr + blob.unmarshalBinary(b) + default: + panic("unsupported hash " + ch.String()) + } + bld, err := newParamBuilderN(2) + if err != nil { + return err + } + defer bld.finalize() + bld.addOctetString(_SCOSSL_DIGEST_PARAM_STATE, unsafe.Slice((*byte)(blobPtr), hdr.size)) + bld.addInt32(_SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM, 1) + params, err := bld.build() + if err != nil { + return err + } + if C.go_openssl_EVP_MD_CTX_set_params(ctx, params) == 0 { + return newOpenSSLError("EVP_MD_CTX_set_params") + } + return nil +} + +func symCryptHashStateInfo(ch crypto.Hash) (size, typ uint32, serializable bool) { + switch ch { + case crypto.MD5: + return _SYMCRYPT_MD5_STATE_EXPORT_SIZE, _SymCryptBlobTypeMd5State, symCryptHashStateSerializableMD5() + case crypto.SHA1: + return _SYMCRYPT_SHA1_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha1State, symCryptHashStateSerializableSHA1() + case crypto.SHA224: + return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha224State, symCryptHashStateSerializableSHA224() + case crypto.SHA256: + return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha256State, symCryptHashStateSerializableSHA256() + case crypto.SHA384: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha384State, symCryptHashStateSerializableSHA384() + case crypto.SHA512_224: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_224State, symCryptHashStateSerializableSHA512_224() + case crypto.SHA512_256: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_256State, symCryptHashStateSerializableSHA512_256() + case crypto.SHA512: + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512State, symCryptHashStateSerializableSHA512() + default: + panic("unsupported hash " + ch.String()) + } +} + +var ( + symCryptHashStateSerializableMD5 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.MD5) + }) + symCryptHashStateSerializableSHA1 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA1) + }) + symCryptHashStateSerializableSHA224 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA224) + }) + symCryptHashStateSerializableSHA256 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA256) + }) + symCryptHashStateSerializableSHA384 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA384) + }) + symCryptHashStateSerializableSHA512_224 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA512_224) + }) + symCryptHashStateSerializableSHA512_256 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA512_256) + }) + symCryptHashStateSerializableSHA512 = sync.OnceValue(func() bool { + return isSymCryptHashStateSerializable(crypto.SHA512) + }) +) + +// isSymCryptHashStateSerializable checks if the SymCrypt hash state is serializable. +func isSymCryptHashStateSerializable(ch crypto.Hash) bool { + alg := loadHash(ch) + if alg == nil { + return false + } + ctx := C.go_openssl_EVP_MD_CTX_new() + if ctx == nil { + return false + } + defer C.go_openssl_EVP_MD_CTX_free(ctx) + if C.go_openssl_EVP_DigestInit_ex(ctx, alg.md, nil) != 1 { + return false + } + params := C.go_openssl_EVP_MD_CTX_gettable_params(ctx) + if params == nil { + return false + } + if C.go_openssl_OSSL_PARAM_locate_const(params, _SCOSSL_DIGEST_PARAM_STATE) == nil { + return false + } + params = C.go_openssl_EVP_MD_CTX_settable_params(ctx) + if params == nil { + return false + } + if C.go_openssl_OSSL_PARAM_locate_const(params, _SCOSSL_DIGEST_PARAM_STATE) == nil { + return false + } + if C.go_openssl_OSSL_PARAM_locate_const(params, _SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM) == nil { + return false + } + return true +} diff --git a/shims.h b/shims.h index 437312ad..5c41291b 100644 --- a/shims.h +++ b/shims.h @@ -27,7 +27,9 @@ enum { GO_EVP_MAX_MD_SIZE = 64, GO_EVP_PKEY_PUBLIC_KEY = 0x86, - GO_EVP_PKEY_KEYPAIR = 0x87 + GO_EVP_PKEY_KEYPAIR = 0x87, + + GO_OSSL_PARAM_OCTET_STRING = 5 }; // #include @@ -219,6 +221,10 @@ DEFINEFUNC_RENAMED_1_1(GO_EVP_MD_CTX_PTR, EVP_MD_CTX_new, EVP_MD_CTX_create, (vo DEFINEFUNC_RENAMED_1_1(void, EVP_MD_CTX_free, EVP_MD_CTX_destroy, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \ DEFINEFUNC(int, EVP_MD_CTX_copy, (GO_EVP_MD_CTX_PTR out, const GO_EVP_MD_CTX_PTR in), (out, in)) \ DEFINEFUNC(int, EVP_MD_CTX_copy_ex, (GO_EVP_MD_CTX_PTR out, const GO_EVP_MD_CTX_PTR in), (out, in)) \ +DEFINEFUNC_3_0(const GO_OSSL_PARAM_PTR, EVP_MD_CTX_gettable_params, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \ +DEFINEFUNC_3_0(const GO_OSSL_PARAM_PTR, EVP_MD_CTX_settable_params, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \ +DEFINEFUNC_3_0(int, EVP_MD_CTX_get_params, (GO_EVP_MD_CTX_PTR ctx, GO_OSSL_PARAM_PTR params), (ctx, params)) \ +DEFINEFUNC_3_0(int, EVP_MD_CTX_set_params, (GO_EVP_MD_CTX_PTR ctx, const GO_OSSL_PARAM_PTR params), (ctx, params)) \ DEFINEFUNC(int, EVP_Digest, (const void *data, size_t count, unsigned char *md, unsigned int *size, const GO_EVP_MD_PTR type, GO_ENGINE_PTR impl), (data, count, md, size, type, impl)) \ DEFINEFUNC(int, EVP_DigestInit_ex, (GO_EVP_MD_CTX_PTR ctx, const GO_EVP_MD_PTR type, GO_ENGINE_PTR impl), (ctx, type, impl)) \ DEFINEFUNC(int, EVP_DigestInit, (GO_EVP_MD_CTX_PTR ctx, const GO_EVP_MD_PTR type), (ctx, type)) \ @@ -376,6 +382,7 @@ DEFINEFUNC_3_0(int, EVP_MAC_init, (GO_EVP_MAC_CTX_PTR ctx, const unsigned char * DEFINEFUNC_3_0(int, EVP_MAC_update, (GO_EVP_MAC_CTX_PTR ctx, const unsigned char *data, size_t datalen), (ctx, data, datalen)) \ DEFINEFUNC_3_0(int, EVP_MAC_final, (GO_EVP_MAC_CTX_PTR ctx, unsigned char *out, size_t *outl, size_t outsize), (ctx, out, outl, outsize)) \ DEFINEFUNC_3_0(void, OSSL_PARAM_free, (GO_OSSL_PARAM_PTR p), (p)) \ +DEFINEFUNC_3_0(const GO_OSSL_PARAM_PTR, OSSL_PARAM_locate_const, (const GO_OSSL_PARAM_PTR p, const char *key), (p, key)) \ DEFINEFUNC_3_0(GO_OSSL_PARAM_BLD_PTR, OSSL_PARAM_BLD_new, (void), ()) \ DEFINEFUNC_3_0(void, OSSL_PARAM_BLD_free, (GO_OSSL_PARAM_BLD_PTR bld), (bld)) \ DEFINEFUNC_3_0(GO_OSSL_PARAM_PTR, OSSL_PARAM_BLD_to_param, (GO_OSSL_PARAM_BLD_PTR bld), (bld)) \ From 00f166734d899230c6f1f901a172c1263c34adcc Mon Sep 17 00:00:00 2001 From: qmuntal Date: Mon, 26 May 2025 08:09:48 +0200 Subject: [PATCH 2/2] reduce diffs --- evp.go | 37 ++++++++++++++--- hash.go | 58 +++++++++------------------ params.go | 26 +++++++----- providersymcrypt.go | 97 ++++++++++++++------------------------------- shims.h | 6 +++ 5 files changed, 102 insertions(+), 122 deletions(-) diff --git a/evp.go b/evp.go index cbf3d178..8594270d 100644 --- a/evp.go +++ b/evp.go @@ -79,11 +79,14 @@ const ( ) type hashAlgorithm struct { - md C.GO_EVP_MD_PTR - ch crypto.Hash - size int - blockSize int - provider provider + md C.GO_EVP_MD_PTR + ch crypto.Hash + size int + blockSize int + provider provider + marshallable bool + magic string + marshalledSize int } // loadHash converts a crypto.Hash to a EVP_MD. @@ -100,6 +103,8 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { hash.md = C.go_openssl_EVP_md4() case crypto.MD5: hash.md = C.go_openssl_EVP_md5() + hash.magic = magicMD5 + hash.marshalledSize = marshaledSizeMD5 case crypto.MD5SHA1: if vMajor == 1 && vMinor == 0 { // OpenSSL 1.0.2 does not support MD5SHA1. @@ -109,21 +114,35 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { } case crypto.SHA1: hash.md = C.go_openssl_EVP_sha1() + hash.magic = magic1 + hash.marshalledSize = marshaledSize1 case crypto.SHA224: hash.md = C.go_openssl_EVP_sha224() + hash.magic = magic224 + hash.marshalledSize = marshaledSize256 case crypto.SHA256: hash.md = C.go_openssl_EVP_sha256() + hash.magic = magic256 + hash.marshalledSize = marshaledSize256 case crypto.SHA384: hash.md = C.go_openssl_EVP_sha384() + hash.magic = magic384 + hash.marshalledSize = marshaledSize512 case crypto.SHA512: hash.md = C.go_openssl_EVP_sha512() + hash.magic = magic512 + hash.marshalledSize = marshaledSize512 case crypto.SHA512_224: if versionAtOrAbove(1, 1, 1) { hash.md = C.go_openssl_EVP_sha512_224() + hash.magic = magic512_224 + hash.marshalledSize = marshaledSize512 } case crypto.SHA512_256: if versionAtOrAbove(1, 1, 1) { hash.md = C.go_openssl_EVP_sha512_256() + hash.magic = magic512_256 + hash.marshalledSize = marshaledSize512 } case crypto.SHA3_224: if versionAtOrAbove(1, 1, 1) { @@ -161,6 +180,11 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { hash.md = md } } + if hash.magic != "" { + if hash.marshalledSize == 0 { + panic("marshalledSize must be set for " + hash.magic) + } + } switch vMajor { case 1: @@ -170,10 +194,13 @@ func loadHash(ch crypto.Hash) *hashAlgorithm { switch C.GoString(C.go_openssl_OSSL_PROVIDER_get0_name(prov)) { case "default": hash.provider = providerOSSLDefault + hash.marshallable = hash.magic != "" case "fips": hash.provider = providerOSSLFIPS + hash.marshallable = hash.magic != "" case "symcryptprovider": hash.provider = providerSymCrypt + hash.marshallable = hash.magic != "" && isSymCryptHashStateSerializable(hash.md) } } default: diff --git a/hash.go b/hash.go index 563154e5..5f10ffd7 100644 --- a/hash.go +++ b/hash.go @@ -24,10 +24,10 @@ const ( magic512_256 = "sha\x06" magic512 = "sha\x07" - marshaledSizeMD5 = len(magicMD5) + 4*4 + 64 + 8 - marshaledSize1 = len(magic1) + 5*4 + 64 + 8 - marshaledSize256 = len(magic256) + 8*4 + 64 + 8 - marshaledSize512 = len(magic512) + 8*8 + 128 + 8 + marshaledSizeMD5 = len(magicMD5) + 4*4 + 64 + 8 // from crypto/md5 + marshaledSize1 = len(magic1) + 5*4 + 64 + 8 // from crypto/sha1 + marshaledSize256 = len(magic256) + 8*4 + 64 + 8 // from crypto/sha256 + marshaledSize512 = len(magic512) + 8*8 + 128 + 8 // from crypto/sha512 ) // maxHashSize is the size of SHA52 and SHA3_512, the largest hashes we support. @@ -382,70 +382,48 @@ func (h *evpHash) Clone() hash.Hash { var errHashNotMarshallable = errors.New("openssl: hash state is not marshallable") func (d *evpHash) MarshalBinary() ([]byte, error) { - buf := make([]byte, 0, marshaledSize512) // stack allocate the buffer by setting the max size we support + if !d.alg.marshallable { + return nil, errHashNotMarshallable + } + buf := make([]byte, 0, d.alg.marshalledSize) return d.AppendBinary(buf) } func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) { defer runtime.KeepAlive(d) d.init() - magic, _ := cryptoHashEncodingInfo(d.alg.ch) - if magic == "" { + if !d.alg.marshallable { return nil, errHashNotMarshallable } switch d.alg.provider { case providerOSSLDefault, providerOSSLFIPS: - return osslHashAppendBinary(d.ctx, d.alg.ch, magic, buf) + return osslHashAppendBinary(d.ctx, d.alg.ch, d.alg.magic, buf) case providerSymCrypt: - return symCryptHashAppendBinary(d.ctx, d.alg.ch, magic, buf) + return symCryptHashAppendBinary(d.ctx, d.alg.ch, d.alg.magic, buf) default: - return nil, errHashNotMarshallable + panic("openssl: unknown hash provider" + strconv.Itoa(int(d.alg.provider))) } } func (d *evpHash) UnmarshalBinary(b []byte) error { defer runtime.KeepAlive(d) d.init() - magic, size := cryptoHashEncodingInfo(d.alg.ch) - if magic == "" { + if !d.alg.marshallable { return errHashNotMarshallable } - if len(b) < len(magic) || string(b[:len(magic)]) != string(magic[:]) { + if len(b) < len(d.alg.magic) || string(b[:len(d.alg.magic)]) != d.alg.magic { return errors.New("openssl: invalid hash state identifier") } - if len(b) != size { + if len(b) != d.alg.marshalledSize { return errors.New("openssl: invalid hash state size") } switch d.alg.provider { case providerOSSLDefault, providerOSSLFIPS: - return osslHashUnmarshalBinary(d.ctx, d.alg.ch, magic, b) + return osslHashUnmarshalBinary(d.ctx, d.alg.ch, d.alg.magic, b) case providerSymCrypt: - return symCryptHashUnmarshalBinary(d.ctx, d.alg.ch, magic, b) - default: - return errHashNotMarshallable - } -} - -func cryptoHashEncodingInfo(ch crypto.Hash) (magic string, size int) { - switch ch { - case crypto.MD5: - return magicMD5, marshaledSizeMD5 - case crypto.SHA1: - return magic1, marshaledSize1 - case crypto.SHA224: - return magic224, marshaledSize256 - case crypto.SHA256: - return magic256, marshaledSize256 - case crypto.SHA384: - return magic384, marshaledSize512 - case crypto.SHA512_224: - return magic512_224, marshaledSize512 - case crypto.SHA512_256: - return magic512_256, marshaledSize512 - case crypto.SHA512: - return magic512, marshaledSize512 + return symCryptHashUnmarshalBinary(d.ctx, d.alg.ch, d.alg.magic, b) default: - return "", 0 + panic("openssl: unknown hash provider" + strconv.Itoa(int(d.alg.provider))) } } diff --git a/params.go b/params.go index 844bfd23..8cfe5784 100644 --- a/params.go +++ b/params.go @@ -5,7 +5,6 @@ package openssl // #include "goopenssl.h" import "C" import ( - "math" "runtime" "unsafe" ) @@ -47,14 +46,16 @@ var ( _OSSL_MAC_PARAM_DIGEST = C.CString("digest") ) +const _OSSL_PARAM_UNMODIFIED uint = uint(^uintptr(0)) + // _OSSL_PARAM is a structure to pass or request object parameters. // https://docs.openssl.org/3.0/man3/OSSL_PARAM/. type _OSSL_PARAM struct { Key *C.char DataType uint32 Data unsafe.Pointer - DataSize int - ReturnSize int + DataSize uint + ReturnSize uint } func ossl_param_construct(key *C.char, dataType uint32, data unsafe.Pointer, dataSize int) _OSSL_PARAM { @@ -62,8 +63,8 @@ func ossl_param_construct(key *C.char, dataType uint32, data unsafe.Pointer, dat Key: key, DataType: dataType, Data: data, - DataSize: dataSize, - ReturnSize: math.MaxInt - 1, + DataSize: uint(dataSize), + ReturnSize: _OSSL_PARAM_UNMODIFIED, } } @@ -71,10 +72,19 @@ func _OSSL_PARAM_construct_octet_string(key *C.char, data unsafe.Pointer, dataSi return ossl_param_construct(key, C.GO_OSSL_PARAM_OCTET_STRING, data, dataSize) } +func _OSSL_PARAM_construct_int32(key *C.char, data *int32) _OSSL_PARAM { + return ossl_param_construct(key, C.GO_OSSL_PARAM_INTEGER, unsafe.Pointer(data), 4) +} + func _OSSL_PARAM_construct_end() _OSSL_PARAM { return _OSSL_PARAM{} } +func _OSSL_PARAM_modified(param *_OSSL_PARAM) bool { + // If ReturnSize is not set, the parameter has not been modified. + return param != nil && param.ReturnSize != _OSSL_PARAM_UNMODIFIED +} + type bnParam struct { value C.GO_BIGNUM_PTR private bool @@ -94,17 +104,13 @@ type paramBuilder struct { // newParamBuilder creates a new paramBuilder. func newParamBuilder() (*paramBuilder, error) { - return newParamBuilderN(8) // the maximum known number of BIGNUMs to free are 8 for RSA -} - -func newParamBuilderN(n int) (*paramBuilder, error) { bld := C.go_openssl_OSSL_PARAM_BLD_new() if bld == nil { return nil, newOpenSSLError("OSSL_PARAM_BLD_new") } pb := ¶mBuilder{ bld: bld, - bnToFree: make([]bnParam, 0, n), + bnToFree: make([]bnParam, 0, 8), // the maximum known number of BIGNUMs to free are 8 for RSA } runtime.SetFinalizer(pb, (*paramBuilder).finalize) return pb, nil diff --git a/providersymcrypt.go b/providersymcrypt.go index f087439a..c4872b68 100644 --- a/providersymcrypt.go +++ b/providersymcrypt.go @@ -9,7 +9,6 @@ import ( "encoding/binary" "errors" "runtime" - "sync" "unsafe" ) @@ -106,14 +105,13 @@ func symCryptUnmarshalBinary(d []byte, chain, buffer []byte) _UINT64 { return newUINT64(length) } -// swapEndianessInt32 swaps the endianness of the given byte slice +// swapEndianessUint32 swaps the endianness of the given byte slice // in place. It assumes the slice is a backup of a 32-bit integer array. -func swapEndianessInt32(d []uint8) { +func swapEndianessUint32(d []uint8) { for i := 0; i < len(d); i += 4 { d[i], d[i+3] = d[i+3], d[i] d[i+1], d[i+2] = d[i+2], d[i+1] } - } type _SYMCRYPT_MD5_STATE_EXPORT_BLOB struct { @@ -128,13 +126,13 @@ type _SYMCRYPT_MD5_STATE_EXPORT_BLOB struct { func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) appendBinary(d []byte) ([]byte, error) { // b.chain is little endian, but Go expects big endian, // we need to swap the bytes. - swapEndianessInt32(b.chain[:]) + swapEndianessUint32(b.chain[:]) return symCryptAppendBinary(d, b.chain[:], b.buffer[:], b.length), nil } func (b *_SYMCRYPT_MD5_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { b.length = symCryptUnmarshalBinary(d, b.chain[:], b.buffer[:]) - swapEndianessInt32(b.chain[:]) + swapEndianessUint32(b.chain[:]) } type _SYMCRYPT_SHA1_STATE_EXPORT_BLOB struct { @@ -193,10 +191,7 @@ func (b *_SYMCRYPT_SHA512_STATE_EXPORT_BLOB) unmarshalBinary(d []byte) { } func symCryptHashAppendBinary(ctx C.GO_EVP_MD_CTX_PTR, ch crypto.Hash, magic string, buf []byte) ([]byte, error) { - size, typ, serializable := symCryptHashStateInfo(ch) - if !serializable { - return nil, errHashNotMarshallable - } + size, typ := symCryptHashStateInfo(ch) state := make([]byte, size, _SYMCRYPT_SHA512_STATE_EXPORT_SIZE) // 512 is the largest size var pinner runtime.Pinner pinner.Pin(&state[0]) @@ -208,6 +203,9 @@ func symCryptHashAppendBinary(ctx C.GO_EVP_MD_CTX_PTR, ch crypto.Hash, magic str if C.go_openssl_EVP_MD_CTX_get_params(ctx, (C.GO_OSSL_PARAM_PTR)(unsafe.Pointer(¶ms[0]))) != 1 { return nil, newOpenSSLError("EVP_MD_CTX_get_params") } + if !_OSSL_PARAM_modified(¶ms[0]) { + return nil, errors.New("EVP_MD_CTX_get_params did not retrieve the state") + } header := (*_SYMCRYPT_BLOB_HEADER)(unsafe.Pointer(&state[0])) if header.magic != _SYMCRYPT_BLOB_MAGIC { @@ -240,10 +238,7 @@ func symCryptHashAppendBinary(ctx C.GO_EVP_MD_CTX_PTR, ch crypto.Hash, magic str } func symCryptHashUnmarshalBinary(ctx C.GO_EVP_MD_CTX_PTR, ch crypto.Hash, magic string, b []byte) error { - size, typ, serializable := symCryptHashStateInfo(ch) - if !serializable { - return errHashNotMarshallable - } + size, typ := symCryptHashStateInfo(ch) hdr := _SYMCRYPT_BLOB_HEADER{ magic: _SYMCRYPT_BLOB_MAGIC, size: size, @@ -275,85 +270,53 @@ func symCryptHashUnmarshalBinary(ctx C.GO_EVP_MD_CTX_PTR, ch crypto.Hash, magic default: panic("unsupported hash " + ch.String()) } - bld, err := newParamBuilderN(2) - if err != nil { - return err - } - defer bld.finalize() - bld.addOctetString(_SCOSSL_DIGEST_PARAM_STATE, unsafe.Slice((*byte)(blobPtr), hdr.size)) - bld.addInt32(_SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM, 1) - params, err := bld.build() - if err != nil { - return err + var checksum int32 = 1 + var pinner runtime.Pinner + pinner.Pin(blobPtr) + pinner.Pin(&checksum) + defer pinner.Unpin() + params := [3]_OSSL_PARAM{ + _OSSL_PARAM_construct_octet_string(_SCOSSL_DIGEST_PARAM_STATE, blobPtr, int(hdr.size)), + _OSSL_PARAM_construct_int32(_SCOSSL_DIGEST_PARAM_RECOMPUTE_CHECKSUM, &checksum), + _OSSL_PARAM_construct_end(), } - if C.go_openssl_EVP_MD_CTX_set_params(ctx, params) == 0 { + if C.go_openssl_EVP_MD_CTX_set_params(ctx, (C.GO_OSSL_PARAM_PTR)(unsafe.Pointer(¶ms[0]))) != 1 { return newOpenSSLError("EVP_MD_CTX_set_params") } return nil } -func symCryptHashStateInfo(ch crypto.Hash) (size, typ uint32, serializable bool) { +func symCryptHashStateInfo(ch crypto.Hash) (size, typ uint32) { switch ch { case crypto.MD5: - return _SYMCRYPT_MD5_STATE_EXPORT_SIZE, _SymCryptBlobTypeMd5State, symCryptHashStateSerializableMD5() + return _SYMCRYPT_MD5_STATE_EXPORT_SIZE, _SymCryptBlobTypeMd5State case crypto.SHA1: - return _SYMCRYPT_SHA1_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha1State, symCryptHashStateSerializableSHA1() + return _SYMCRYPT_SHA1_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha1State case crypto.SHA224: - return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha224State, symCryptHashStateSerializableSHA224() + return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha224State case crypto.SHA256: - return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha256State, symCryptHashStateSerializableSHA256() + return _SYMCRYPT_SHA256_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha256State case crypto.SHA384: - return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha384State, symCryptHashStateSerializableSHA384() + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha384State case crypto.SHA512_224: - return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_224State, symCryptHashStateSerializableSHA512_224() + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_224State case crypto.SHA512_256: - return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_256State, symCryptHashStateSerializableSHA512_256() + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512_256State case crypto.SHA512: - return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512State, symCryptHashStateSerializableSHA512() + return _SYMCRYPT_SHA512_STATE_EXPORT_SIZE, _SymCryptBlobTypeSha512State default: panic("unsupported hash " + ch.String()) } } -var ( - symCryptHashStateSerializableMD5 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.MD5) - }) - symCryptHashStateSerializableSHA1 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA1) - }) - symCryptHashStateSerializableSHA224 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA224) - }) - symCryptHashStateSerializableSHA256 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA256) - }) - symCryptHashStateSerializableSHA384 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA384) - }) - symCryptHashStateSerializableSHA512_224 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA512_224) - }) - symCryptHashStateSerializableSHA512_256 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA512_256) - }) - symCryptHashStateSerializableSHA512 = sync.OnceValue(func() bool { - return isSymCryptHashStateSerializable(crypto.SHA512) - }) -) - // isSymCryptHashStateSerializable checks if the SymCrypt hash state is serializable. -func isSymCryptHashStateSerializable(ch crypto.Hash) bool { - alg := loadHash(ch) - if alg == nil { - return false - } +func isSymCryptHashStateSerializable(md C.GO_EVP_MD_PTR) bool { ctx := C.go_openssl_EVP_MD_CTX_new() if ctx == nil { return false } defer C.go_openssl_EVP_MD_CTX_free(ctx) - if C.go_openssl_EVP_DigestInit_ex(ctx, alg.md, nil) != 1 { + if C.go_openssl_EVP_DigestInit_ex(ctx, md, nil) != 1 { return false } params := C.go_openssl_EVP_MD_CTX_gettable_params(ctx) diff --git a/shims.h b/shims.h index 5c41291b..f8a300a4 100644 --- a/shims.h +++ b/shims.h @@ -28,7 +28,13 @@ enum { GO_EVP_PKEY_PUBLIC_KEY = 0x86, GO_EVP_PKEY_KEYPAIR = 0x87, +}; +// #if OPENSSL_VERSION_NUMBER >= 0x30000000L +// #include +// #endif +enum { + GO_OSSL_PARAM_INTEGER = 1, GO_OSSL_PARAM_OCTET_STRING = 5 };