Skip to content

Commit 03d9e53

Browse files
committed
fix: add clone method for openssl hmac and hash
1 parent 6200bfa commit 03d9e53

File tree

5 files changed

+118
-24
lines changed

5 files changed

+118
-24
lines changed

hash.go

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -223,18 +223,8 @@ func NewSHA3_512() hash.Hash {
223223
return newEvpHash(crypto.SHA3_512)
224224
}
225225

226-
// cloneHash is an interface that defines a Clone method.
227-
//
228-
// hahs.CloneHash will probably be added in Go 1.25, see https://golang.org/issue/69521,
229-
// but we need it now.
230-
type cloneHash interface {
231-
hash.Hash
232-
// Clone returns a separate Hash instance with the same state as h.
233-
Clone() hash.Hash
234-
}
235-
236226
var _ hash.Hash = (*evpHash)(nil)
237-
var _ cloneHash = (*evpHash)(nil)
227+
var _ HashCloner = (*evpHash)(nil)
238228

239229
// evpHash implements generic hash methods.
240230
type evpHash struct {
@@ -359,7 +349,7 @@ func (h *evpHash) Sum(in []byte) []byte {
359349
// Clone returns a new evpHash object that is a deep clone of itself.
360350
// The duplicate object contains all state and data contained in the
361351
// original object at the point of duplication.
362-
func (h *evpHash) Clone() hash.Hash {
352+
func (h *evpHash) Clone() (HashCloner, error) {
363353
h2 := &evpHash{alg: h.alg}
364354
if h.ctx != nil {
365355
var err error
@@ -379,7 +369,7 @@ func (h *evpHash) Clone() hash.Hash {
379369
runtime.SetFinalizer(h2, (*evpHash).finalize)
380370
}
381371
runtime.KeepAlive(h)
382-
return h2
372+
return h2, nil
383373
}
384374

385375
var errHashNotMarshallable = errors.New("openssl: hash state is not marshallable")

hash_test.go

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -232,21 +232,42 @@ func TestHash_Clone(t *testing.T) {
232232
if !openssl.SupportsHash(ch) {
233233
t.Skip("not supported")
234234
}
235-
h := cryptoToHash(ch)()
236-
if _, ok := h.(encoding.BinaryMarshaler); !ok {
237-
t.Skip("not supported")
238-
}
235+
h := cryptoToHash(ch)().(openssl.HashCloner)
239236
_, err := h.Write(msg)
240237
if err != nil {
241238
t.Fatal(err)
242239
}
243-
// We don't define an interface for the Clone method to avoid other
244-
// packages from depending on it. Use type assertion to call it.
245-
h2 := h.(interface{ Clone() hash.Hash }).Clone()
246-
h.Write(msg)
247-
h2.Write(msg)
248-
if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) {
249-
t.Errorf("%s(%q) = 0x%x != cloned 0x%x", ch.String(), msg, actual, actual2)
240+
241+
h3, err := h.Clone()
242+
if err != nil {
243+
t.Fatalf("Clone failed: %v", err)
244+
}
245+
prefix := []byte("tmp")
246+
writeToHash(t, h, prefix)
247+
h2, err := h.Clone()
248+
if err != nil {
249+
t.Fatalf("Clone failed: %v", err)
250+
}
251+
prefixSum := h.Sum(nil)
252+
if !bytes.Equal(prefixSum, h2.Sum(nil)) {
253+
t.Fatalf("%T Clone results are inconsistent", h)
254+
}
255+
suffix := []byte("tmp2")
256+
writeToHash(t, h, suffix)
257+
writeToHash(t, h3, append(prefix, suffix...))
258+
compositeSum := h3.Sum(nil)
259+
if !bytes.Equal(h.Sum(nil), compositeSum) {
260+
t.Fatalf("%T Clone results are inconsistent", h)
261+
}
262+
if !bytes.Equal(h2.Sum(nil), prefixSum) {
263+
t.Fatalf("%T Clone results are inconsistent", h)
264+
}
265+
writeToHash(t, h2, suffix)
266+
if !bytes.Equal(h.Sum(nil), compositeSum) {
267+
t.Fatalf("%T Clone results are inconsistent", h)
268+
}
269+
if !bytes.Equal(h2.Sum(nil), compositeSum) {
270+
t.Fatalf("%T Clone results are inconsistent", h)
250271
}
251272
})
252273
}
@@ -519,3 +540,20 @@ func (h *stubHash) Sum(in []byte) []byte { return in }
519540
func (h *stubHash) Reset() {}
520541
func (h *stubHash) Size() int { return 0 }
521542
func (h *stubHash) BlockSize() int { return 0 }
543+
544+
// Helper function for writing. Verifies that Write does not error.Add commentMore actions
545+
func writeToHash(t *testing.T, h hash.Hash, p []byte) {
546+
t.Helper()
547+
548+
before := make([]byte, len(p))
549+
copy(before, p)
550+
551+
n, err := h.Write(p)
552+
if err != nil || n != len(p) {
553+
t.Errorf("Write returned error; got (%v, %v), want (nil, %v)", err, n, len(p))
554+
}
555+
556+
if !bytes.Equal(p, before) {
557+
t.Errorf("Write modified input slice; got %x, want %x", p, before)
558+
}
559+
}

hashclone.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//go:build !cmd_go_bootstrap
2+
3+
package openssl
4+
5+
import (
6+
"hash"
7+
)
8+
9+
// HashCloner is an interface that defines a Clone method.
10+
type HashCloner interface {
11+
hash.Hash
12+
// Clone returns a separate Hash instance with the same state as h.
13+
Clone() (HashCloner, error)
14+
}

hashclone_go125.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//go:build go1.25 && !cmd_go_bootstrap
2+
3+
package openssl
4+
5+
import (
6+
"hash"
7+
)
8+
9+
type HashCloner = hash.Cloner

hmac.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import "C"
66
import (
77
"hash"
88
"runtime"
9+
"slices"
910
"sync"
1011
"unsafe"
1112

@@ -242,3 +243,45 @@ func (h *opensslHMAC) Sum(in []byte) []byte {
242243
}
243244
return append(in, h.sum[:h.size]...)
244245
}
246+
247+
func (h *opensslHMAC) Clone() (HashCloner, error) {
248+
// Make copy of context because Go hash.Hash mandates that Clone
249+
// has no effect on the underlying stream.
250+
switch vMajor {
251+
case 1:
252+
ctx2, err := ossl.HMAC_CTX_new()
253+
if err != nil {
254+
return nil, err
255+
}
256+
if _, err := ossl.HMAC_CTX_copy(ctx2, h.ctx1.ctx); err != nil {
257+
ossl.HMAC_CTX_free(ctx2)
258+
return nil, err
259+
}
260+
cl := &opensslHMAC{
261+
ctx1: hmacCtx1{ctx: ctx2},
262+
size: h.size,
263+
blockSize: h.blockSize,
264+
}
265+
runtime.SetFinalizer(cl, (*opensslHMAC).finalize)
266+
return cl, nil
267+
268+
case 3:
269+
ctx2, err := ossl.EVP_MAC_CTX_dup(h.ctx3.ctx)
270+
if err != nil {
271+
return nil, err
272+
}
273+
274+
// For OpenSSL 3.0.0, 3.0.1, and 3.0.2 we need to copy the key
275+
// from the original context to the new one.
276+
cl := &opensslHMAC{
277+
ctx3: hmacCtx3{ctx: ctx2, key: slices.Clone(h.ctx3.key)},
278+
size: h.size,
279+
blockSize: h.blockSize,
280+
}
281+
runtime.SetFinalizer(cl, (*opensslHMAC).finalize)
282+
return cl, nil
283+
284+
default:
285+
panic(errUnsupportedVersion())
286+
}
287+
}

0 commit comments

Comments
 (0)