17
17
//!
18
18
19
19
use core:: ptr;
20
- use core:: ops:: Deref ;
20
+ use core:: ops:: { FnMut , Deref } ;
21
21
22
22
use key:: { SecretKey , PublicKey } ;
23
23
use ffi:: { self , CPtr } ;
24
+ use secp256k1_sys:: types:: { c_int, c_uchar, c_void} ;
24
25
25
26
/// A tag used for recovering the public key from a compact signature
26
27
#[ derive( Copy , Clone ) ]
@@ -68,7 +69,7 @@ impl SharedSecret {
68
69
69
70
impl PartialEq for SharedSecret {
70
71
fn eq ( & self , other : & SharedSecret ) -> bool {
71
- & self . data [ .. self . len ] == & other. data [ ..other . len ]
72
+ self . as_ref ( ) == other. as_ref ( )
72
73
}
73
74
}
74
75
@@ -86,6 +87,22 @@ impl Deref for SharedSecret {
86
87
}
87
88
88
89
90
+ unsafe extern "C" fn hash_callback < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
91
+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
92
+ let callback: & mut F = & mut * ( data as * mut F ) ;
93
+
94
+ let mut x_arr = [ 0 ; 32 ] ;
95
+ let mut y_arr = [ 0 ; 32 ] ;
96
+ ptr:: copy_nonoverlapping ( x, x_arr. as_mut_ptr ( ) , 32 ) ;
97
+ ptr:: copy_nonoverlapping ( y, y_arr. as_mut_ptr ( ) , 32 ) ;
98
+
99
+ let secret = callback ( x_arr, y_arr) ;
100
+ ptr:: copy_nonoverlapping ( secret. as_ptr ( ) , output as * mut u8 , secret. len ( ) ) ;
101
+
102
+ secret. len ( ) as c_int
103
+ }
104
+
105
+
89
106
impl SharedSecret {
90
107
/// Creates a new shared secret from a pubkey and secret key
91
108
#[ inline]
@@ -105,6 +122,44 @@ impl SharedSecret {
105
122
ss. set_len ( 32 ) ; // The default hash function is SHA256, which is 32 bytes long.
106
123
ss
107
124
}
125
+
126
+ /// Creates a new shared secret from a pubkey and secret key with applied custom hash function
127
+ /// # Examples
128
+ /// ```
129
+ /// # use secp256k1::ecdh::SharedSecret;
130
+ /// # use secp256k1::{Secp256k1, PublicKey, SecretKey};
131
+ /// # fn sha2(_a: &[u8], _b: &[u8]) -> [u8; 32] {[0u8; 32]}
132
+ /// # let secp = Secp256k1::signing_only();
133
+ /// # let secret_key = SecretKey::from_slice(&[3u8; 32]).unwrap();
134
+ /// # let secret_key2 = SecretKey::from_slice(&[7u8; 32]).unwrap();
135
+ /// # let public_key = PublicKey::from_secret_key(&secp, &secret_key2);
136
+ ///
137
+ /// let secret = SharedSecret::new_with_hash(&public_key, &secret_key, |x,y| {
138
+ /// let hash: [u8; 32] = sha2(&x,&y);
139
+ /// hash.into()
140
+ /// });
141
+ ///
142
+ /// ```
143
+ pub fn new_with_hash < F > ( point : & PublicKey , scalar : & SecretKey , mut hash_function : F ) -> SharedSecret
144
+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret
145
+ {
146
+ let mut ss = SharedSecret :: empty ( ) ;
147
+ let hashfp: ffi:: EcdhHashFn = hash_callback :: < F > ;
148
+
149
+ let res = unsafe {
150
+ ffi:: secp256k1_ecdh (
151
+ ffi:: secp256k1_context_no_precomp,
152
+ ss. get_data_mut_ptr ( ) ,
153
+ point. as_ptr ( ) ,
154
+ scalar. as_ptr ( ) ,
155
+ hashfp,
156
+ & mut hash_function as * mut F as * mut c_void ,
157
+ )
158
+ } ;
159
+ debug_assert ! ( res >= 16 ) ; // 128 bit is the minimum for a secure hash function and the minimum we let users.
160
+ ss. set_len ( res as usize ) ;
161
+ ss
162
+ }
108
163
}
109
164
110
165
#[ cfg( test) ]
0 commit comments