diff --git a/hash.go b/hash.go index 85ab3fac..62182210 100644 --- a/hash.go +++ b/hash.go @@ -223,18 +223,8 @@ func NewSHA3_512() hash.Hash { return newEvpHash(crypto.SHA3_512) } -// cloneHash is an interface that defines a Clone method. -// -// hahs.CloneHash will probably be added in Go 1.25, see https://golang.org/issue/69521, -// but we need it now. -type cloneHash interface { - hash.Hash - // Clone returns a separate Hash instance with the same state as h. - Clone() hash.Hash -} - var _ hash.Hash = (*evpHash)(nil) -var _ cloneHash = (*evpHash)(nil) +var _ HashCloner = (*evpHash)(nil) // evpHash implements generic hash methods. type evpHash struct { @@ -359,7 +349,7 @@ func (h *evpHash) Sum(in []byte) []byte { // Clone returns a new evpHash object that is a deep clone of itself. // The duplicate object contains all state and data contained in the // original object at the point of duplication. -func (h *evpHash) Clone() hash.Hash { +func (h *evpHash) Clone() (HashCloner, error) { h2 := &evpHash{alg: h.alg} if h.ctx != nil { var err error @@ -379,7 +369,7 @@ func (h *evpHash) Clone() hash.Hash { runtime.SetFinalizer(h2, (*evpHash).finalize) } runtime.KeepAlive(h) - return h2 + return h2, nil } var errHashNotMarshallable = errors.New("openssl: hash state is not marshallable") diff --git a/hash_test.go b/hash_test.go index c4ecdc4b..d133b7d0 100644 --- a/hash_test.go +++ b/hash_test.go @@ -232,21 +232,42 @@ func TestHash_Clone(t *testing.T) { if !openssl.SupportsHash(ch) { t.Skip("not supported") } - h := cryptoToHash(ch)() - if _, ok := h.(encoding.BinaryMarshaler); !ok { - t.Skip("not supported") - } + h := cryptoToHash(ch)().(openssl.HashCloner) _, err := h.Write(msg) if err != nil { t.Fatal(err) } - // We don't define an interface for the Clone method to avoid other - // packages from depending on it. Use type assertion to call it. - h2 := h.(interface{ Clone() hash.Hash }).Clone() - h.Write(msg) - h2.Write(msg) - if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) { - t.Errorf("%s(%q) = 0x%x != cloned 0x%x", ch.String(), msg, actual, actual2) + + h3, err := h.Clone() + if err != nil { + t.Fatalf("Clone failed: %v", err) + } + prefix := []byte("tmp") + writeToHash(t, h, prefix) + h2, err := h.Clone() + if err != nil { + t.Fatalf("Clone failed: %v", err) + } + prefixSum := h.Sum(nil) + if !bytes.Equal(prefixSum, h2.Sum(nil)) { + t.Fatalf("%T Clone results are inconsistent", h) + } + suffix := []byte("tmp2") + writeToHash(t, h, suffix) + writeToHash(t, h3, append(prefix, suffix...)) + compositeSum := h3.Sum(nil) + if !bytes.Equal(h.Sum(nil), compositeSum) { + t.Fatalf("%T Clone results are inconsistent", h) + } + if !bytes.Equal(h2.Sum(nil), prefixSum) { + t.Fatalf("%T Clone results are inconsistent", h) + } + writeToHash(t, h2, suffix) + if !bytes.Equal(h.Sum(nil), compositeSum) { + t.Fatalf("%T Clone results are inconsistent", h) + } + if !bytes.Equal(h2.Sum(nil), compositeSum) { + t.Fatalf("%T Clone results are inconsistent", h) } }) } @@ -519,3 +540,20 @@ func (h *stubHash) Sum(in []byte) []byte { return in } func (h *stubHash) Reset() {} func (h *stubHash) Size() int { return 0 } func (h *stubHash) BlockSize() int { return 0 } + +// Helper function for writing. Verifies that Write does not error. +func writeToHash(t *testing.T, h hash.Hash, p []byte) { + t.Helper() + + before := make([]byte, len(p)) + copy(before, p) + + n, err := h.Write(p) + if err != nil || n != len(p) { + t.Errorf("Write returned error; got (%v, %v), want (nil, %v)", err, n, len(p)) + } + + if !bytes.Equal(p, before) { + t.Errorf("Write modified input slice; got %x, want %x", p, before) + } +} diff --git a/hashclone.go b/hashclone.go new file mode 100644 index 00000000..c44be3d8 --- /dev/null +++ b/hashclone.go @@ -0,0 +1,14 @@ +//go:build !cmd_go_bootstrap + +package openssl + +import ( + "hash" +) + +// HashCloner is an interface that defines a Clone method. +type HashCloner interface { + hash.Hash + // Clone returns a separate Hash instance with the same state as h. + Clone() (HashCloner, error) +} diff --git a/hashclone_go125.go b/hashclone_go125.go new file mode 100644 index 00000000..f1f2364c --- /dev/null +++ b/hashclone_go125.go @@ -0,0 +1,9 @@ +//go:build go1.25 && !cmd_go_bootstrap + +package openssl + +import ( + "hash" +) + +type HashCloner = hash.Cloner diff --git a/hmac.go b/hmac.go index a6bb884d..f600f0f4 100644 --- a/hmac.go +++ b/hmac.go @@ -6,6 +6,7 @@ import "C" import ( "hash" "runtime" + "slices" "sync" "unsafe" @@ -242,3 +243,41 @@ func (h *opensslHMAC) Sum(in []byte) []byte { } return append(in, h.sum[:h.size]...) } + +func (h *opensslHMAC) Clone() (HashCloner, error) { + switch vMajor { + case 1: + ctx2, err := ossl.HMAC_CTX_new() + if err != nil { + panic(err) + } + if _, err := ossl.HMAC_CTX_copy(ctx2, h.ctx1.ctx); err != nil { + ossl.HMAC_CTX_free(ctx2) + panic(err) + } + cl := &opensslHMAC{ + ctx1: hmacCtx1{ctx: ctx2}, + size: h.size, + blockSize: h.blockSize, + } + runtime.SetFinalizer(cl, (*opensslHMAC).finalize) + return cl, nil + + case 3: + ctx2, err := ossl.EVP_MAC_CTX_dup(h.ctx3.ctx) + if err != nil { + panic(err) + } + + cl := &opensslHMAC{ + ctx3: hmacCtx3{ctx: ctx2, key: slices.Clone(h.ctx3.key)}, + size: h.size, + blockSize: h.blockSize, + } + runtime.SetFinalizer(cl, (*opensslHMAC).finalize) + return cl, nil + + default: + panic(errUnsupportedVersion()) + } +}